Merge remote-tracking branch 'origin' into persist-worktree-2-archived-model

Anthony Eid created

Change summary

.github/workflows/compliance_check.yml                |  55 +
.github/workflows/release.yml                         |  84 +
Cargo.lock                                            | 151 ++
Cargo.toml                                            |   5 
crates/agent/src/tests/edit_file_thread_test.rs       | 211 ++++
crates/agent/src/tests/mod.rs                         | 112 ++
crates/agent/src/tests/test_tools.rs                  |  67 +
crates/agent/src/thread.rs                            | 312 +++--
crates/agent/src/tools/streaming_edit_file_tool.rs    | 550 ++++++----
crates/agent_ui/src/agent_panel.rs                    |  12 
crates/agent_ui/src/conversation_view.rs              |   1 
crates/agent_ui/src/threads_archive_view.rs           |   8 
crates/dev_container/src/devcontainer_manifest.rs     |  29 
crates/edit_prediction/src/edit_prediction_tests.rs   |  59 +
crates/edit_prediction/src/example_spec.rs            | 111 --
crates/edit_prediction/src/zeta.rs                    |  22 
crates/edit_prediction_cli/src/parse_output.rs        |  43 
crates/gpui/src/elements/list.rs                      |  13 
crates/language_model/src/fake_provider.rs            |  10 
crates/picker/src/picker.rs                           |  17 
crates/recent_projects/src/recent_projects.rs         |  10 
crates/rules_library/src/rules_library.rs             |   4 
crates/settings_ui/src/page_data.rs                   |  71 
crates/settings_ui/src/settings_ui.rs                 |   1 
crates/sidebar/src/sidebar.rs                         |  59 
crates/sidebar/src/sidebar_tests.rs                   | 425 ++++++--
crates/title_bar/src/title_bar.rs                     |   2 
crates/workspace/src/multi_workspace.rs               | 309 ++++-
crates/workspace/src/multi_workspace_tests.rs         |  24 
crates/workspace/src/pane.rs                          |  90 +
crates/workspace/src/persistence.rs                   |  29 
crates/workspace/src/workspace.rs                     |  32 
crates/zed/src/visual_test_runner.rs                  |  28 
crates/zed/src/zed.rs                                 |  68 
crates/zeta_prompt/Cargo.toml                         |   1 
crates/zeta_prompt/src/udiff.rs                       | 200 ++++
crates/zeta_prompt/src/zeta_prompt.rs                 | 157 +++
tooling/compliance/Cargo.toml                         |  38 
tooling/compliance/LICENSE-GPL                        |   1 
tooling/compliance/src/checks.rs                      | 647 +++++++++++++
tooling/compliance/src/git.rs                         | 591 +++++++++++
tooling/compliance/src/github.rs                      | 424 ++++++++
tooling/compliance/src/lib.rs                         |   4 
tooling/compliance/src/report.rs                      | 446 ++++++++
tooling/xtask/Cargo.toml                              |   6 
tooling/xtask/src/main.rs                             |   2 
tooling/xtask/src/tasks.rs                            |   1 
tooling/xtask/src/tasks/compliance.rs                 | 135 ++
tooling/xtask/src/tasks/workflows.rs                  |   2 
tooling/xtask/src/tasks/workflows/compliance_check.rs |  66 +
tooling/xtask/src/tasks/workflows/release.rs          | 113 ++
51 files changed, 4,985 insertions(+), 873 deletions(-)

Detailed changes

.github/workflows/compliance_check.yml 🔗

@@ -0,0 +1,55 @@
+# Generated from xtask::workflows::compliance_check
+# Rebuild with `cargo xtask workflows`.
+name: compliance_check
+env:
+  CARGO_TERM_COLOR: always
+on:
+  schedule:
+  - cron: 30 17 * * 2
+jobs:
+  scheduled_compliance_check:
+    if: (github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions')
+    runs-on: namespace-profile-2x4-ubuntu-2404
+    steps:
+    - name: steps::checkout_repo
+      uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd
+      with:
+        clean: false
+        fetch-depth: 0
+    - name: steps::cache_rust_dependencies_namespace
+      uses: namespacelabs/nscloud-cache-action@a90bb5d4b27522ce881c6e98eebd7d7e6d1653f9
+      with:
+        cache: rust
+        path: ~/.rustup
+    - id: determine-version
+      name: compliance_check::scheduled_compliance_check
+      run: |
+        VERSION=$(sed -n 's/^version = "\(.*\)"/\1/p' crates/zed/Cargo.toml | tr -d '[:space:]')
+        if [ -z "$VERSION" ]; then
+            echo "Could not determine version from crates/zed/Cargo.toml"
+            exit 1
+        fi
+        TAG="v${VERSION}-pre"
+        echo "Checking compliance for $TAG"
+        echo "tag=$TAG" >> "$GITHUB_OUTPUT"
+    - id: run-compliance-check
+      name: compliance_check::scheduled_compliance_check::run_compliance_check
+      run: cargo xtask compliance "$LATEST_TAG" --branch main --report-path target/compliance-report
+      env:
+        LATEST_TAG: ${{ steps.determine-version.outputs.tag }}
+        GITHUB_APP_ID: ${{ secrets.ZED_ZIPPY_APP_ID }}
+        GITHUB_APP_KEY: ${{ secrets.ZED_ZIPPY_APP_PRIVATE_KEY }}
+    - name: compliance_check::scheduled_compliance_check::send_failure_slack_notification
+      if: failure()
+      run: |
+        MESSAGE="⚠️ Scheduled compliance check failed for upcoming preview release $LATEST_TAG: There are PRs with missing reviews."
+
+        curl -X POST -H 'Content-type: application/json' \
+            --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \
+            "$SLACK_WEBHOOK"
+      env:
+        SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_WORKFLOW_FAILURES }}
+        LATEST_TAG: ${{ steps.determine-version.outputs.tag }}
+defaults:
+  run:
+    shell: bash -euxo pipefail {0}

.github/workflows/release.yml 🔗

@@ -293,6 +293,51 @@ jobs:
       env:
         GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
     timeout-minutes: 60
+  compliance_check:
+    if: (github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions')
+    runs-on: namespace-profile-16x32-ubuntu-2204
+    env:
+      COMPLIANCE_FILE_PATH: compliance.md
+    steps:
+    - name: steps::checkout_repo
+      uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd
+      with:
+        clean: false
+        fetch-depth: 0
+        ref: ${{ github.ref }}
+    - name: steps::cache_rust_dependencies_namespace
+      uses: namespacelabs/nscloud-cache-action@a90bb5d4b27522ce881c6e98eebd7d7e6d1653f9
+      with:
+        cache: rust
+        path: ~/.rustup
+    - id: run-compliance-check
+      name: release::compliance_check::run_compliance_check
+      run: cargo xtask compliance "$GITHUB_REF_NAME" --report-path "$COMPLIANCE_FILE_PATH"
+      env:
+        GITHUB_APP_ID: ${{ secrets.ZED_ZIPPY_APP_ID }}
+        GITHUB_APP_KEY: ${{ secrets.ZED_ZIPPY_APP_PRIVATE_KEY }}
+    - name: release::compliance_check::send_compliance_slack_notification
+      if: always()
+      run: |
+        if [ "$COMPLIANCE_OUTCOME" == "success" ]; then
+            STATUS="✅ Compliance check passed for $GITHUB_REF_NAME"
+        else
+            STATUS="❌ Compliance check failed for $GITHUB_REF_NAME"
+        fi
+
+        REPORT_CONTENT=""
+        if [ -f "$COMPLIANCE_FILE_PATH" ]; then
+            REPORT_CONTENT=$(cat "$COMPLIANCE_FILE_PATH")
+        fi
+        fi
+
+        MESSAGE=$(printf "%s\n\n%s" "$STATUS" "$REPORT_CONTENT")
+
+        curl -X POST -H 'Content-type: application/json' \
+            --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \
+            "$SLACK_WEBHOOK"
+      env:
+        SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_WORKFLOW_FAILURES }}
+        COMPLIANCE_OUTCOME: ${{ steps.run-compliance-check.outcome }}
   bundle_linux_aarch64:
     needs:
     - run_tests_linux
@@ -613,6 +658,45 @@ jobs:
         echo "All expected assets are present in the release."
       env:
         GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+    - name: steps::checkout_repo
+      uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd
+      with:
+        clean: false
+        fetch-depth: 0
+        ref: ${{ github.ref }}
+    - name: steps::cache_rust_dependencies_namespace
+      uses: namespacelabs/nscloud-cache-action@a90bb5d4b27522ce881c6e98eebd7d7e6d1653f9
+      with:
+        cache: rust
+        path: ~/.rustup
+    - id: run-post-upload-compliance-check
+      name: release::validate_release_assets::run_post_upload_compliance_check
+      run: cargo xtask compliance "$GITHUB_REF_NAME" --report-path target/compliance-report
+      env:
+        GITHUB_APP_ID: ${{ secrets.ZED_ZIPPY_APP_ID }}
+        GITHUB_APP_KEY: ${{ secrets.ZED_ZIPPY_APP_PRIVATE_KEY }}
+    - name: release::validate_release_assets::send_post_upload_compliance_notification
+      if: always()
+      run: |
+        if [ -z "$COMPLIANCE_OUTCOME" ] || [ "$COMPLIANCE_OUTCOME" == "skipped" ]; then
+            echo "Compliance check was skipped, not sending notification"
+            exit 0
+        fi
+
+        TAG="$GITHUB_REF_NAME"
+
+        if [ "$COMPLIANCE_OUTCOME" == "success" ]; then
+            MESSAGE="✅ Post-upload compliance re-check passed for $TAG"
+        else
+            MESSAGE="❌ Post-upload compliance re-check failed for $TAG"
+        fi
+
+        curl -X POST -H 'Content-type: application/json' \
+            --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \
+            "$SLACK_WEBHOOK"
+      env:
+        SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_WORKFLOW_FAILURES }}
+        COMPLIANCE_OUTCOME: ${{ steps.run-post-upload-compliance-check.outcome }}
   auto_release_preview:
     needs:
     - validate_release_assets

Cargo.lock 🔗

@@ -677,6 +677,15 @@ dependencies = [
  "derive_arbitrary",
 ]
 
+[[package]]
+name = "arc-swap"
+version = "1.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a07d1f37ff60921c83bdfc7407723bdefe89b44b98a9b772f225c8f9d67141a6"
+dependencies = [
+ "rustversion",
+]
+
 [[package]]
 name = "arg_enum_proc_macro"
 version = "0.3.4"
@@ -2530,6 +2539,16 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "cargo-platform"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "87a0c0e6148f11f01f32650a2ea02d532b2ad4e81d8bd41e6e565b5adc5e6082"
+dependencies = [
+ "serde",
+ "serde_core",
+]
+
 [[package]]
 name = "cargo_metadata"
 version = "0.19.2"
@@ -2537,7 +2556,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "dd5eb614ed4c27c5d706420e4320fbe3216ab31fa1c33cd8246ac36dae4479ba"
 dependencies = [
  "camino",
- "cargo-platform",
+ "cargo-platform 0.1.9",
+ "semver",
+ "serde",
+ "serde_json",
+ "thiserror 2.0.17",
+]
+
+[[package]]
+name = "cargo_metadata"
+version = "0.23.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ef987d17b0a113becdd19d3d0022d04d7ef41f9efe4f3fb63ac44ba61df3ade9"
+dependencies = [
+ "camino",
+ "cargo-platform 0.3.2",
  "semver",
  "serde",
  "serde_json",
@@ -3284,6 +3317,25 @@ dependencies = [
  "workspace",
 ]
 
+[[package]]
+name = "compliance"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "async-trait",
+ "derive_more",
+ "futures 0.3.32",
+ "indoc",
+ "itertools 0.14.0",
+ "jsonwebtoken",
+ "octocrab",
+ "regex",
+ "semver",
+ "serde",
+ "serde_json",
+ "tokio",
+]
+
 [[package]]
 name = "component"
 version = "0.1.0"
@@ -8324,6 +8376,7 @@ dependencies = [
  "http 1.3.1",
  "hyper 1.7.0",
  "hyper-util",
+ "log",
  "rustls 0.23.33",
  "rustls-native-certs 0.8.2",
  "rustls-pki-types",
@@ -8332,6 +8385,19 @@ dependencies = [
  "tower-service",
 ]
 
+[[package]]
+name = "hyper-timeout"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0"
+dependencies = [
+ "hyper 1.7.0",
+ "hyper-util",
+ "pin-project-lite",
+ "tokio",
+ "tower-service",
+]
+
 [[package]]
 name = "hyper-tls"
 version = "0.5.0"
@@ -10008,7 +10074,7 @@ dependencies = [
 [[package]]
 name = "lsp-types"
 version = "0.95.1"
-source = "git+https://github.com/zed-industries/lsp-types?rev=a4f410987660bf560d1e617cb78117c6b6b9f599#a4f410987660bf560d1e617cb78117c6b6b9f599"
+source = "git+https://github.com/zed-industries/lsp-types?rev=c7396459fefc7886b4adfa3b596832405ae1e880#c7396459fefc7886b4adfa3b596832405ae1e880"
 dependencies = [
  "bitflags 1.3.2",
  "serde",
@@ -11380,6 +11446,48 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "octocrab"
+version = "0.49.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "63f6687a23731011d0117f9f4c3cdabaa7b5e42ca671f42b5cc0657c492540e3"
+dependencies = [
+ "arc-swap",
+ "async-trait",
+ "base64 0.22.1",
+ "bytes 1.11.1",
+ "cargo_metadata 0.23.1",
+ "cfg-if",
+ "chrono",
+ "either",
+ "futures 0.3.32",
+ "futures-core",
+ "futures-util",
+ "getrandom 0.2.16",
+ "http 1.3.1",
+ "http-body 1.0.1",
+ "http-body-util",
+ "hyper 1.7.0",
+ "hyper-rustls 0.27.7",
+ "hyper-timeout",
+ "hyper-util",
+ "jsonwebtoken",
+ "once_cell",
+ "percent-encoding",
+ "pin-project",
+ "secrecy",
+ "serde",
+ "serde_json",
+ "serde_path_to_error",
+ "serde_urlencoded",
+ "snafu",
+ "tokio",
+ "tower 0.5.2",
+ "tower-http 0.6.6",
+ "url",
+ "web-time",
+]
+
 [[package]]
 name = "ollama"
 version = "0.1.0"
@@ -15381,6 +15489,15 @@ dependencies = [
  "zeroize",
 ]
 
+[[package]]
+name = "secrecy"
+version = "0.10.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
+dependencies = [
+ "zeroize",
+]
+
 [[package]]
 name = "security-framework"
 version = "2.11.1"
@@ -16085,6 +16202,27 @@ version = "0.3.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0f7a918bd2a9951d18ee6e48f076843e8e73a9a5d22cf05bcd4b7a81bdd04e17"
 
+[[package]]
+name = "snafu"
+version = "0.8.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6e84b3f4eacbf3a1ce05eac6763b4d629d60cbc94d632e4092c54ade71f1e1a2"
+dependencies = [
+ "snafu-derive",
+]
+
+[[package]]
+name = "snafu-derive"
+version = "0.8.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451"
+dependencies = [
+ "heck 0.5.0",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.117",
+]
+
 [[package]]
 name = "snippet"
 version = "0.1.0"
@@ -18089,8 +18227,10 @@ dependencies = [
  "pin-project-lite",
  "sync_wrapper 1.0.2",
  "tokio",
+ "tokio-util",
  "tower-layer",
  "tower-service",
+ "tracing",
 ]
 
 [[package]]
@@ -18128,6 +18268,7 @@ dependencies = [
  "tower 0.5.2",
  "tower-layer",
  "tower-service",
+ "tracing",
 ]
 
 [[package]]
@@ -19974,6 +20115,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
 dependencies = [
  "js-sys",
+ "serde",
  "wasm-bindgen",
 ]
 
@@ -21711,9 +21853,10 @@ dependencies = [
  "annotate-snippets",
  "anyhow",
  "backtrace",
- "cargo_metadata",
+ "cargo_metadata 0.19.2",
  "cargo_toml",
  "clap",
+ "compliance",
  "gh-workflow",
  "indexmap",
  "indoc",
@@ -21723,6 +21866,7 @@ dependencies = [
  "serde_json",
  "serde_yaml",
  "strum 0.27.2",
+ "tokio",
  "toml 0.8.23",
  "toml_edit 0.22.27",
 ]
@@ -22398,6 +22542,7 @@ name = "zeta_prompt"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "imara-diff",
  "indoc",
  "serde",
  "strum 0.27.2",

Cargo.toml 🔗

@@ -242,6 +242,7 @@ members = [
     # Tooling
     #
 
+    "tooling/compliance",
     "tooling/perf",
     "tooling/xtask",
 ]
@@ -289,6 +290,7 @@ collab_ui = { path = "crates/collab_ui" }
 collections = { path = "crates/collections", version = "0.1.0" }
 command_palette = { path = "crates/command_palette" }
 command_palette_hooks = { path = "crates/command_palette_hooks" }
+compliance = { path = "tooling/compliance" }
 component = { path = "crates/component" }
 component_preview = { path = "crates/component_preview" }
 context_server = { path = "crates/context_server" }
@@ -547,6 +549,7 @@ derive_more = { version = "2.1.1", features = [
     "add_assign",
     "deref",
     "deref_mut",
+    "display",
     "from_str",
     "mul",
     "mul_assign",
@@ -596,7 +599,7 @@ linkify = "0.10.0"
 libwebrtc = "0.3.26"
 livekit = { version = "0.7.32", features = ["tokio", "rustls-tls-native-roots"] }
 log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] }
-lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "a4f410987660bf560d1e617cb78117c6b6b9f599" }
+lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "c7396459fefc7886b4adfa3b596832405ae1e880" }
 mach2 = "0.5"
 markup5ever_rcdom = "0.3.0"
 metal = "0.33"

crates/agent/src/tests/edit_file_thread_test.rs 🔗

@@ -202,3 +202,214 @@ async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) {
         );
     });
 }
+
+#[gpui::test]
+async fn test_streaming_edit_json_parse_error_does_not_cause_unsaved_changes(
+    cx: &mut TestAppContext,
+) {
+    super::init_test(cx);
+    super::always_allow_tools(cx);
+
+    // Enable the streaming edit file tool feature flag.
+    cx.update(|cx| {
+        cx.update_flags(true, vec!["streaming-edit-file-tool".to_string()]);
+    });
+
+    let fs = FakeFs::new(cx.executor());
+    fs.insert_tree(
+        path!("/project"),
+        json!({
+            "src": {
+                "main.rs": "fn main() {\n    println!(\"Hello, world!\");\n}\n"
+            }
+        }),
+    )
+    .await;
+
+    let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+    let project_context = cx.new(|_cx| ProjectContext::default());
+    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+    let context_server_registry =
+        cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
+    let model = Arc::new(FakeLanguageModel::default());
+    model.as_fake().set_supports_streaming_tools(true);
+    let fake_model = model.as_fake();
+
+    let thread = cx.new(|cx| {
+        let mut thread = crate::Thread::new(
+            project.clone(),
+            project_context,
+            context_server_registry,
+            crate::Templates::new(),
+            Some(model.clone()),
+            cx,
+        );
+        let language_registry = project.read(cx).languages().clone();
+        thread.add_tool(crate::StreamingEditFileTool::new(
+            project.clone(),
+            cx.weak_entity(),
+            thread.action_log().clone(),
+            language_registry,
+        ));
+        thread
+    });
+
+    let _events = thread
+        .update(cx, |thread, cx| {
+            thread.send(
+                UserMessageId::new(),
+                ["Write new content to src/main.rs"],
+                cx,
+            )
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    let tool_use_id = "edit_1";
+    let partial_1 = LanguageModelToolUse {
+        id: tool_use_id.into(),
+        name: EditFileTool::NAME.into(),
+        raw_input: json!({
+            "display_description": "Rewrite main.rs",
+            "path": "project/src/main.rs",
+            "mode": "write"
+        })
+        .to_string(),
+        input: json!({
+            "display_description": "Rewrite main.rs",
+            "path": "project/src/main.rs",
+            "mode": "write"
+        }),
+        is_input_complete: false,
+        thought_signature: None,
+    };
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_1));
+    cx.run_until_parked();
+
+    let partial_2 = LanguageModelToolUse {
+        id: tool_use_id.into(),
+        name: EditFileTool::NAME.into(),
+        raw_input: json!({
+            "display_description": "Rewrite main.rs",
+            "path": "project/src/main.rs",
+            "mode": "write",
+            "content": "fn main() { /* rewritten */ }"
+        })
+        .to_string(),
+        input: json!({
+            "display_description": "Rewrite main.rs",
+            "path": "project/src/main.rs",
+            "mode": "write",
+            "content": "fn main() { /* rewritten */ }"
+        }),
+        is_input_complete: false,
+        thought_signature: None,
+    };
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_2));
+    cx.run_until_parked();
+
+    // Now send a json parse error. At this point we have started writing content to the buffer.
+    fake_model.send_last_completion_stream_event(
+        LanguageModelCompletionEvent::ToolUseJsonParseError {
+            id: tool_use_id.into(),
+            tool_name: EditFileTool::NAME.into(),
+            raw_input: r#"{"display_description":"Rewrite main.rs","path":"project/src/main.rs","mode":"write","content":"fn main() { /* rewritten "#.into(),
+            json_parse_error: "EOF while parsing a string at line 1 column 95".into(),
+        },
+    );
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // cx.executor().advance_clock(Duration::from_secs(5));
+    // cx.run_until_parked();
+
+    assert!(
+        !fake_model.pending_completions().is_empty(),
+        "Thread should have retried after the error"
+    );
+
+    // Respond with a new, well-formed, complete edit_file tool use.
+    let tool_use = LanguageModelToolUse {
+        id: "edit_2".into(),
+        name: EditFileTool::NAME.into(),
+        raw_input: json!({
+            "display_description": "Rewrite main.rs",
+            "path": "project/src/main.rs",
+            "mode": "write",
+            "content": "fn main() {\n    println!(\"Hello, rewritten!\");\n}\n"
+        })
+        .to_string(),
+        input: json!({
+            "display_description": "Rewrite main.rs",
+            "path": "project/src/main.rs",
+            "mode": "write",
+            "content": "fn main() {\n    println!(\"Hello, rewritten!\");\n}\n"
+        }),
+        is_input_complete: true,
+        thought_signature: None,
+    };
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    let pending_completions = fake_model.pending_completions();
+    assert!(
+        pending_completions.len() == 1,
+        "Expected only the follow-up completion containing the successful tool result"
+    );
+
+    let completion = pending_completions
+        .into_iter()
+        .last()
+        .expect("Expected a completion containing the tool result for edit_2");
+
+    let tool_result = completion
+        .messages
+        .iter()
+        .flat_map(|msg| &msg.content)
+        .find_map(|content| match content {
+            language_model::MessageContent::ToolResult(result)
+                if result.tool_use_id == language_model::LanguageModelToolUseId::from("edit_2") =>
+            {
+                Some(result)
+            }
+            _ => None,
+        })
+        .expect("Should have a tool result for edit_2");
+
+    // Ensure that the second tool call completed successfully and edits were applied.
+    assert!(
+        !tool_result.is_error,
+        "Tool result should succeed, got: {:?}",
+        tool_result
+    );
+    let content_text = match &tool_result.content {
+        language_model::LanguageModelToolResultContent::Text(t) => t.to_string(),
+        other => panic!("Expected text content, got: {:?}", other),
+    };
+    assert!(
+        !content_text.contains("file has been modified since you last read it"),
+        "Did not expect a stale last-read error, got: {content_text}"
+    );
+    assert!(
+        !content_text.contains("This file has unsaved changes"),
+        "Did not expect an unsaved-changes error, got: {content_text}"
+    );
+
+    let file_content = fs
+        .load(path!("/project/src/main.rs").as_ref())
+        .await
+        .expect("file should exist");
+    super::assert_eq!(
+        file_content,
+        "fn main() {\n    println!(\"Hello, rewritten!\");\n}\n",
+        "The second edit should be applied and saved gracefully"
+    );
+
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+}

crates/agent/src/tests/mod.rs 🔗

@@ -3903,6 +3903,117 @@ async fn test_streaming_tool_completes_when_llm_stream_ends_without_final_input(
     });
 }
 
+#[gpui::test]
+async fn test_streaming_tool_json_parse_error_is_forwarded_to_running_tool(
+    cx: &mut TestAppContext,
+) {
+    init_test(cx);
+    always_allow_tools(cx);
+
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    thread.update(cx, |thread, _cx| {
+        thread.add_tool(StreamingJsonErrorContextTool);
+    });
+
+    let _events = thread
+        .update(cx, |thread, cx| {
+            thread.send(
+                UserMessageId::new(),
+                ["Use the streaming_json_error_context tool"],
+                cx,
+            )
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    let tool_use = LanguageModelToolUse {
+        id: "tool_1".into(),
+        name: StreamingJsonErrorContextTool::NAME.into(),
+        raw_input: r#"{"text": "partial"#.into(),
+        input: json!({"text": "partial"}),
+        is_input_complete: false,
+        thought_signature: None,
+    };
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
+    cx.run_until_parked();
+
+    fake_model.send_last_completion_stream_event(
+        LanguageModelCompletionEvent::ToolUseJsonParseError {
+            id: "tool_1".into(),
+            tool_name: StreamingJsonErrorContextTool::NAME.into(),
+            raw_input: r#"{"text": "partial"#.into(),
+            json_parse_error: "EOF while parsing a string at line 1 column 17".into(),
+        },
+    );
+    fake_model
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    cx.executor().advance_clock(Duration::from_secs(5));
+    cx.run_until_parked();
+
+    let completion = fake_model
+        .pending_completions()
+        .pop()
+        .expect("No running turn");
+
+    let tool_results: Vec<_> = completion
+        .messages
+        .iter()
+        .flat_map(|message| &message.content)
+        .filter_map(|content| match content {
+            MessageContent::ToolResult(result)
+                if result.tool_use_id == language_model::LanguageModelToolUseId::from("tool_1") =>
+            {
+                Some(result)
+            }
+            _ => None,
+        })
+        .collect();
+
+    assert_eq!(
+        tool_results.len(),
+        1,
+        "Expected exactly 1 tool result for tool_1, got {}: {:#?}",
+        tool_results.len(),
+        tool_results
+    );
+
+    let result = tool_results[0];
+    assert!(result.is_error);
+    let content_text = match &result.content {
+        language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
+        other => panic!("Expected text content, got {:?}", other),
+    };
+    assert!(
+        content_text.contains("Saw partial text 'partial' before invalid JSON"),
+        "Expected tool-enriched partial context, got: {content_text}"
+    );
+    assert!(
+        content_text
+            .contains("Error parsing input JSON: EOF while parsing a string at line 1 column 17"),
+        "Expected forwarded JSON parse error, got: {content_text}"
+    );
+    assert!(
+        !content_text.contains("tool input was not fully received"),
+        "Should not contain orphaned sender error, got: {content_text}"
+    );
+
+    fake_model.send_last_completion_stream_text_chunk("Done");
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+
+    thread.read_with(cx, |thread, _cx| {
+        assert!(
+            thread.is_turn_complete(),
+            "Thread should not be stuck; the turn should have completed",
+        );
+    });
+}
+
 /// Filters out the stop events for asserting against in tests
 fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
     result_events
@@ -3959,6 +4070,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
                             InfiniteTool::NAME: true,
                             CancellationAwareTool::NAME: true,
                             StreamingEchoTool::NAME: true,
+                            StreamingJsonErrorContextTool::NAME: true,
                             StreamingFailingEchoTool::NAME: true,
                             TerminalTool::NAME: true,
                             UpdatePlanTool::NAME: true,

crates/agent/src/tests/test_tools.rs 🔗

@@ -56,13 +56,12 @@ impl AgentTool for StreamingEchoTool {
 
     fn run(
         self: Arc<Self>,
-        mut input: ToolInput<Self::Input>,
+        input: ToolInput<Self::Input>,
         _event_stream: ToolCallEventStream,
         cx: &mut App,
     ) -> Task<Result<String, String>> {
         let wait_until_complete_rx = self.wait_until_complete_rx.lock().unwrap().take();
         cx.spawn(async move |_cx| {
-            while input.recv_partial().await.is_some() {}
             let input = input
                 .recv()
                 .await
@@ -75,6 +74,68 @@ impl AgentTool for StreamingEchoTool {
     }
 }
 
+#[derive(JsonSchema, Serialize, Deserialize)]
+pub struct StreamingJsonErrorContextToolInput {
+    /// The text to echo.
+    pub text: String,
+}
+
+pub struct StreamingJsonErrorContextTool;
+
+impl AgentTool for StreamingJsonErrorContextTool {
+    type Input = StreamingJsonErrorContextToolInput;
+    type Output = String;
+
+    const NAME: &'static str = "streaming_json_error_context";
+
+    fn supports_input_streaming() -> bool {
+        true
+    }
+
+    fn kind() -> acp::ToolKind {
+        acp::ToolKind::Other
+    }
+
+    fn initial_title(
+        &self,
+        _input: Result<Self::Input, serde_json::Value>,
+        _cx: &mut App,
+    ) -> SharedString {
+        "Streaming JSON Error Context".into()
+    }
+
+    fn run(
+        self: Arc<Self>,
+        mut input: ToolInput<Self::Input>,
+        _event_stream: ToolCallEventStream,
+        cx: &mut App,
+    ) -> Task<Result<String, String>> {
+        cx.spawn(async move |_cx| {
+            let mut last_partial_text = None;
+
+            loop {
+                match input.next().await {
+                    Ok(ToolInputPayload::Partial(partial)) => {
+                        if let Some(text) = partial.get("text").and_then(|value| value.as_str()) {
+                            last_partial_text = Some(text.to_string());
+                        }
+                    }
+                    Ok(ToolInputPayload::Full(input)) => return Ok(input.text),
+                    Ok(ToolInputPayload::InvalidJson { error_message }) => {
+                        let partial_text = last_partial_text.unwrap_or_default();
+                        return Err(format!(
+                            "Saw partial text '{partial_text}' before invalid JSON: {error_message}"
+                        ));
+                    }
+                    Err(error) => {
+                        return Err(format!("Failed to receive tool input: {error}"));
+                    }
+                }
+            }
+        })
+    }
+}
+
 /// A streaming tool that echoes its input, used to test streaming tool
 /// lifecycle (e.g. partial delivery and cleanup when the LLM stream ends
 /// before `is_input_complete`).
@@ -119,7 +180,7 @@ impl AgentTool for StreamingFailingEchoTool {
     ) -> Task<Result<Self::Output, Self::Output>> {
         cx.spawn(async move |_cx| {
             for _ in 0..self.receive_chunks_until_failure {
-                let _ = input.recv_partial().await;
+                let _ = input.next().await;
             }
             Err("failed".into())
         })

crates/agent/src/thread.rs 🔗

@@ -22,13 +22,13 @@ use client::UserStore;
 use cloud_api_types::Plan;
 use collections::{HashMap, HashSet, IndexMap};
 use fs::Fs;
-use futures::stream;
 use futures::{
     FutureExt,
     channel::{mpsc, oneshot},
     future::Shared,
     stream::FuturesUnordered,
 };
+use futures::{StreamExt, stream};
 use gpui::{
     App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
 };
@@ -47,7 +47,6 @@ use schemars::{JsonSchema, Schema};
 use serde::de::DeserializeOwned;
 use serde::{Deserialize, Serialize};
 use settings::{LanguageModelSelection, Settings, ToolPermissionMode, update_settings_file};
-use smol::stream::StreamExt;
 use std::{
     collections::BTreeMap,
     marker::PhantomData,
@@ -2095,7 +2094,7 @@ impl Thread {
         this.update(cx, |this, _cx| {
             this.pending_message()
                 .tool_results
-                .insert(tool_result.tool_use_id.clone(), tool_result);
+                .insert(tool_result.tool_use_id.clone(), tool_result)
         })?;
         Ok(())
     }
@@ -2195,15 +2194,15 @@ impl Thread {
                 raw_input,
                 json_parse_error,
             } => {
-                return Ok(Some(Task::ready(
-                    self.handle_tool_use_json_parse_error_event(
-                        id,
-                        tool_name,
-                        raw_input,
-                        json_parse_error,
-                        event_stream,
-                    ),
-                )));
+                return Ok(self.handle_tool_use_json_parse_error_event(
+                    id,
+                    tool_name,
+                    raw_input,
+                    json_parse_error,
+                    event_stream,
+                    cancellation_rx,
+                    cx,
+                ));
             }
             UsageUpdate(usage) => {
                 telemetry::event!(
@@ -2304,12 +2303,12 @@ impl Thread {
         if !tool_use.is_input_complete {
             if tool.supports_input_streaming() {
                 let running_turn = self.running_turn.as_mut()?;
-                if let Some(sender) = running_turn.streaming_tool_inputs.get(&tool_use.id) {
+                if let Some(sender) = running_turn.streaming_tool_inputs.get_mut(&tool_use.id) {
                     sender.send_partial(tool_use.input);
                     return None;
                 }
 
-                let (sender, tool_input) = ToolInputSender::channel();
+                let (mut sender, tool_input) = ToolInputSender::channel();
                 sender.send_partial(tool_use.input);
                 running_turn
                     .streaming_tool_inputs
@@ -2331,13 +2330,13 @@ impl Thread {
             }
         }
 
-        if let Some(sender) = self
+        if let Some(mut sender) = self
             .running_turn
             .as_mut()?
             .streaming_tool_inputs
             .remove(&tool_use.id)
         {
-            sender.send_final(tool_use.input);
+            sender.send_full(tool_use.input);
             return None;
         }
 
@@ -2410,10 +2409,12 @@ impl Thread {
         raw_input: Arc<str>,
         json_parse_error: String,
         event_stream: &ThreadEventStream,
-    ) -> LanguageModelToolResult {
+        cancellation_rx: watch::Receiver<bool>,
+        cx: &mut Context<Self>,
+    ) -> Option<Task<LanguageModelToolResult>> {
         let tool_use = LanguageModelToolUse {
-            id: tool_use_id.clone(),
-            name: tool_name.clone(),
+            id: tool_use_id,
+            name: tool_name,
             raw_input: raw_input.to_string(),
             input: serde_json::json!({}),
             is_input_complete: true,
@@ -2426,14 +2427,43 @@ impl Thread {
             event_stream,
         );
 
-        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
-        LanguageModelToolResult {
-            tool_use_id,
-            tool_name,
-            is_error: true,
-            content: LanguageModelToolResultContent::Text(tool_output.into()),
-            output: Some(serde_json::Value::String(raw_input.to_string())),
+        let tool = self.tool(tool_use.name.as_ref());
+
+        let Some(tool) = tool else {
+            let content = format!("No tool named {} exists", tool_use.name);
+            return Some(Task::ready(LanguageModelToolResult {
+                content: LanguageModelToolResultContent::Text(Arc::from(content)),
+                tool_use_id: tool_use.id,
+                tool_name: tool_use.name,
+                is_error: true,
+                output: None,
+            }));
+        };
+
+        let error_message = format!("Error parsing input JSON: {json_parse_error}");
+
+        if tool.supports_input_streaming()
+            && let Some(mut sender) = self
+                .running_turn
+                .as_mut()?
+                .streaming_tool_inputs
+                .remove(&tool_use.id)
+        {
+            sender.send_invalid_json(error_message);
+            return None;
         }
+
+        log::debug!("Running tool {}. Received invalid JSON", tool_use.name);
+        let tool_input = ToolInput::invalid_json(error_message);
+        Some(self.run_tool(
+            tool,
+            tool_input,
+            tool_use.id,
+            tool_use.name,
+            event_stream,
+            cancellation_rx,
+            cx,
+        ))
     }
 
     fn send_or_update_tool_use(
@@ -3114,8 +3144,7 @@ impl EventEmitter<TitleUpdated> for Thread {}
 /// For streaming tools, partial JSON snapshots arrive via `.next()` as the LLM
 /// streams them, followed by the final complete input available through `.recv()`.
 pub struct ToolInput<T> {
-    partial_rx: mpsc::UnboundedReceiver<serde_json::Value>,
-    final_rx: oneshot::Receiver<serde_json::Value>,
+    rx: mpsc::UnboundedReceiver<ToolInputPayload<serde_json::Value>>,
     _phantom: PhantomData<T>,
 }
 
@@ -3127,13 +3156,20 @@ impl<T: DeserializeOwned> ToolInput<T> {
     }
 
     pub fn ready(value: serde_json::Value) -> Self {
-        let (partial_tx, partial_rx) = mpsc::unbounded();
-        drop(partial_tx);
-        let (final_tx, final_rx) = oneshot::channel();
-        final_tx.send(value).ok();
+        let (tx, rx) = mpsc::unbounded();
+        tx.unbounded_send(ToolInputPayload::Full(value)).ok();
         Self {
-            partial_rx,
-            final_rx,
+            rx,
+            _phantom: PhantomData,
+        }
+    }
+
+    pub fn invalid_json(error_message: String) -> Self {
+        let (tx, rx) = mpsc::unbounded();
+        tx.unbounded_send(ToolInputPayload::InvalidJson { error_message })
+            .ok();
+        Self {
+            rx,
             _phantom: PhantomData,
         }
     }
@@ -3147,65 +3183,89 @@ impl<T: DeserializeOwned> ToolInput<T> {
     /// Wait for the final deserialized input, ignoring all partial updates.
     /// Non-streaming tools can use this to wait until the whole input is available.
     pub async fn recv(mut self) -> Result<T> {
-        // Drain any remaining partials
-        while self.partial_rx.next().await.is_some() {}
+        while let Ok(value) = self.next().await {
+            match value {
+                ToolInputPayload::Full(value) => return Ok(value),
+                ToolInputPayload::Partial(_) => {}
+                ToolInputPayload::InvalidJson { error_message } => {
+                    return Err(anyhow!(error_message));
+                }
+            }
+        }
+        Err(anyhow!("tool input was not fully received"))
+    }
+
+    pub async fn next(&mut self) -> Result<ToolInputPayload<T>> {
         let value = self
-            .final_rx
+            .rx
+            .next()
             .await
-            .map_err(|_| anyhow!("tool input was not fully received"))?;
-        serde_json::from_value(value).map_err(Into::into)
-    }
+            .ok_or_else(|| anyhow!("tool input was not fully received"))?;
 
-    /// Returns the next partial JSON snapshot, or `None` when input is complete.
-    /// Once this returns `None`, call `recv()` to get the final input.
-    pub async fn recv_partial(&mut self) -> Option<serde_json::Value> {
-        self.partial_rx.next().await
+        Ok(match value {
+            ToolInputPayload::Partial(payload) => ToolInputPayload::Partial(payload),
+            ToolInputPayload::Full(payload) => {
+                ToolInputPayload::Full(serde_json::from_value(payload)?)
+            }
+            ToolInputPayload::InvalidJson { error_message } => {
+                ToolInputPayload::InvalidJson { error_message }
+            }
+        })
     }
 
     fn cast<U: DeserializeOwned>(self) -> ToolInput<U> {
         ToolInput {
-            partial_rx: self.partial_rx,
-            final_rx: self.final_rx,
+            rx: self.rx,
             _phantom: PhantomData,
         }
     }
 }
 
+/// A single message delivered to a tool from the model's streamed input.
+pub enum ToolInputPayload<T> {
+    /// An in-progress JSON snapshot; left untyped since it may be incomplete.
+    Partial(serde_json::Value),
+    /// The complete, deserialized input. Terminal for a successful stream.
+    Full(T),
+    /// The model's input failed to parse as JSON. Terminal.
+    InvalidJson { error_message: String },
+}
+
 pub struct ToolInputSender {
-    partial_tx: mpsc::UnboundedSender<serde_json::Value>,
-    final_tx: Option<oneshot::Sender<serde_json::Value>>,
+    has_received_final: bool,
+    tx: mpsc::UnboundedSender<ToolInputPayload<serde_json::Value>>,
 }
 
 impl ToolInputSender {
+    /// Create a sender/receiver pair. The receiver side yields untyped JSON
+    /// payloads which `ToolInput` deserializes on demand.
     pub(crate) fn channel() -> (Self, ToolInput<serde_json::Value>) {
-        let (partial_tx, partial_rx) = mpsc::unbounded();
-        let (final_tx, final_rx) = oneshot::channel();
+        let (tx, rx) = mpsc::unbounded();
         let sender = Self {
-            partial_tx,
-            final_tx: Some(final_tx),
+            tx,
+            has_received_final: false,
         };
         let input = ToolInput {
-            partial_rx,
-            final_rx,
+            rx,
             _phantom: PhantomData,
         };
         (sender, input)
     }
 
+    /// True once a terminal payload (`Full` or `InvalidJson`) has been sent.
     pub(crate) fn has_received_final(&self) -> bool {
-        self.final_tx.is_none()
+        self.has_received_final
     }
 
-    pub(crate) fn send_partial(&self, value: serde_json::Value) {
-        self.partial_tx.unbounded_send(value).ok();
+    /// Forward a partial JSON snapshot; a dropped receiver is ignored.
+    pub fn send_partial(&mut self, payload: serde_json::Value) {
+        self.tx
+            .unbounded_send(ToolInputPayload::Partial(payload))
+            .ok();
     }
 
-    pub(crate) fn send_final(mut self, value: serde_json::Value) {
-        // Close the partial channel so recv_partial() returns None
-        self.partial_tx.close_channel();
-        if let Some(final_tx) = self.final_tx.take() {
-            final_tx.send(value).ok();
-        }
+    /// Send the complete input and mark the stream as finished.
+    pub fn send_full(&mut self, payload: serde_json::Value) {
+        self.has_received_final = true;
+        self.tx.unbounded_send(ToolInputPayload::Full(payload)).ok();
+    }
+
+    /// Report that the model's input failed to parse, and mark the stream as
+    /// finished so the turn does not wait for further input.
+    pub fn send_invalid_json(&mut self, error_message: String) {
+        self.has_received_final = true;
+        self.tx
+            .unbounded_send(ToolInputPayload::InvalidJson { error_message })
+            .ok();
     }
 }
 
@@ -4251,68 +4311,78 @@ mod tests {
     ) {
         let (thread, event_stream) = setup_thread_for_test(cx).await;
 
-        cx.update(|cx| {
-            thread.update(cx, |thread, _cx| {
-                let tool_use_id = LanguageModelToolUseId::from("test_tool_id");
-                let tool_name: Arc<str> = Arc::from("test_tool");
-                let raw_input: Arc<str> = Arc::from("{invalid json");
-                let json_parse_error = "expected value at line 1 column 1".to_string();
-
-                // Call the function under test
-                let result = thread.handle_tool_use_json_parse_error_event(
-                    tool_use_id.clone(),
-                    tool_name.clone(),
-                    raw_input.clone(),
-                    json_parse_error,
-                    &event_stream,
-                );
-
-                // Verify the result is an error
-                assert!(result.is_error);
-                assert_eq!(result.tool_use_id, tool_use_id);
-                assert_eq!(result.tool_name, tool_name);
-                assert!(matches!(
-                    result.content,
-                    LanguageModelToolResultContent::Text(_)
-                ));
-
-                // Verify the tool use was added to the message content
-                {
-                    let last_message = thread.pending_message();
-                    assert_eq!(
-                        last_message.content.len(),
-                        1,
-                        "Should have one tool_use in content"
-                    );
-
-                    match &last_message.content[0] {
-                        AgentMessageContent::ToolUse(tool_use) => {
-                            assert_eq!(tool_use.id, tool_use_id);
-                            assert_eq!(tool_use.name, tool_name);
-                            assert_eq!(tool_use.raw_input, raw_input.to_string());
-                            assert!(tool_use.is_input_complete);
-                            // Should fall back to empty object for invalid JSON
-                            assert_eq!(tool_use.input, json!({}));
-                        }
-                        _ => panic!("Expected ToolUse content"),
-                    }
-                }
-
-                // Insert the tool result (simulating what the caller does)
-                thread
-                    .pending_message()
-                    .tool_results
-                    .insert(result.tool_use_id.clone(), result);
+        let tool_use_id = LanguageModelToolUseId::from("test_tool_id");
+        let tool_name: Arc<str> = Arc::from("test_tool");
+        let raw_input: Arc<str> = Arc::from("{invalid json");
+        let json_parse_error = "expected value at line 1 column 1".to_string();
+
+        let (_cancellation_tx, cancellation_rx) = watch::channel(false);
+
+        let result = cx
+            .update(|cx| {
+                thread.update(cx, |thread, cx| {
+                    // Call the function under test
+                    thread
+                        .handle_tool_use_json_parse_error_event(
+                            tool_use_id.clone(),
+                            tool_name.clone(),
+                            raw_input.clone(),
+                            json_parse_error,
+                            &event_stream,
+                            cancellation_rx,
+                            cx,
+                        )
+                        .unwrap()
+                })
+            })
+            .await;
+
+        // Verify the result is an error
+        assert!(result.is_error);
+        assert_eq!(result.tool_use_id, tool_use_id);
+        assert_eq!(result.tool_name, tool_name);
+        assert!(matches!(
+            result.content,
+            LanguageModelToolResultContent::Text(_)
+        ));
 
-                // Verify the tool result was added
+        thread.update(cx, |thread, _cx| {
+            // Verify the tool use was added to the message content
+            {
                 let last_message = thread.pending_message();
                 assert_eq!(
-                    last_message.tool_results.len(),
+                    last_message.content.len(),
                     1,
-                    "Should have one tool_result"
+                    "Should have one tool_use in content"
                 );
-                assert!(last_message.tool_results.contains_key(&tool_use_id));
-            });
-        });
+
+                match &last_message.content[0] {
+                    AgentMessageContent::ToolUse(tool_use) => {
+                        assert_eq!(tool_use.id, tool_use_id);
+                        assert_eq!(tool_use.name, tool_name);
+                        assert_eq!(tool_use.raw_input, raw_input.to_string());
+                        assert!(tool_use.is_input_complete);
+                        // Should fall back to empty object for invalid JSON
+                        assert_eq!(tool_use.input, json!({}));
+                    }
+                    _ => panic!("Expected ToolUse content"),
+                }
+            }
+
+            // Insert the tool result (simulating what the caller does)
+            thread
+                .pending_message()
+                .tool_results
+                .insert(result.tool_use_id.clone(), result);
+
+            // Verify the tool result was added
+            let last_message = thread.pending_message();
+            assert_eq!(
+                last_message.tool_results.len(),
+                1,
+                "Should have one tool_result"
+            );
+            assert!(last_message.tool_results.contains_key(&tool_use_id));
+        })
     }
 }

crates/agent/src/tools/streaming_edit_file_tool.rs 🔗

@@ -2,6 +2,7 @@ use super::edit_file_tool::EditFileTool;
 use super::restore_file_from_disk_tool::RestoreFileFromDiskTool;
 use super::save_file_tool::SaveFileTool;
 use super::tool_edit_parser::{ToolEditEvent, ToolEditParser};
+use crate::ToolInputPayload;
 use crate::{
     AgentTool, Thread, ToolCallEventStream, ToolInput,
     edit_agent::{
@@ -12,7 +13,7 @@ use crate::{
 use acp_thread::Diff;
 use action_log::ActionLog;
 use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
-use anyhow::{Context as _, Result};
+use anyhow::Result;
 use collections::HashSet;
 use futures::FutureExt as _;
 use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
@@ -188,6 +189,10 @@ pub enum StreamingEditFileToolOutput {
     },
     Error {
         error: String,
+        #[serde(default)]
+        input_path: Option<PathBuf>,
+        #[serde(default)]
+        diff: String,
     },
 }
 
@@ -195,6 +200,8 @@ impl StreamingEditFileToolOutput {
     pub fn error(error: impl Into<String>) -> Self {
         Self::Error {
             error: error.into(),
+            input_path: None,
+            diff: String::new(),
         }
     }
 }
@@ -215,7 +222,24 @@ impl std::fmt::Display for StreamingEditFileToolOutput {
                     )
                 }
             }
-            StreamingEditFileToolOutput::Error { error } => write!(f, "{error}"),
+            StreamingEditFileToolOutput::Error {
+                error,
+                diff,
+                input_path,
+            } => {
+                write!(f, "{error}\n")?;
+                if let Some(input_path) = input_path
+                    && !diff.is_empty()
+                {
+                    write!(
+                        f,
+                        "Edited {}:\n\n```diff\n{diff}\n```",
+                        input_path.display()
+                    )
+                } else {
+                    write!(f, "No edits were made.")
+                }
+            }
         }
     }
 }
@@ -233,6 +257,14 @@ pub struct StreamingEditFileTool {
     language_registry: Arc<LanguageRegistry>,
 }
 
+/// Outcome of `process_streaming_edits`: either the edit completed cleanly,
+/// or it failed — possibly after a session was created and partial edits were
+/// already applied to the buffer.
+enum EditSessionResult {
+    Completed(EditSession),
+    Failed {
+        error: String,
+        /// `Some` when a session existed at failure time, so the caller can
+        /// still save the buffer and report the partial diff.
+        session: Option<EditSession>,
+    },
+}
+
 impl StreamingEditFileTool {
     pub fn new(
         project: Entity<Project>,
@@ -276,6 +308,158 @@ impl StreamingEditFileTool {
             });
         }
     }
+
+    /// Best-effort flush of `buffer` to disk after an edit session: runs
+    /// format-on-save when the buffer's language settings enable it, saves the
+    /// buffer, and records the edit in the action log. Format/save failures
+    /// are logged and swallowed rather than propagated.
+    ///
+    /// NOTE(review): `read_with`/`update` on `AsyncApp` commonly yield a
+    /// `Result` — confirm these calls type-check as written here.
+    async fn ensure_buffer_saved(&self, buffer: &Entity<Buffer>, cx: &mut AsyncApp) {
+        let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| {
+            let settings = language_settings::LanguageSettings::for_buffer(buffer, cx);
+            settings.format_on_save != FormatOnSave::Off
+        });
+
+        if format_on_save_enabled {
+            self.project
+                .update(cx, |project, cx| {
+                    project.format(
+                        HashSet::from_iter([buffer.clone()]),
+                        LspFormatTarget::Buffers,
+                        false,
+                        // Use the save trigger so formatters behave as if the
+                        // user had saved manually.
+                        FormatTrigger::Save,
+                        cx,
+                    )
+                })
+                .await
+                .log_err();
+        }
+
+        self.project
+            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+            .await
+            .log_err();
+
+        // Record the edit so the agent's action log tracks this change.
+        self.action_log.update(cx, |log, cx| {
+            log.buffer_edited(buffer.clone(), cx);
+        });
+    }
+
+    /// Drive one streaming edit-file invocation end to end.
+    ///
+    /// Consumes payloads from `input`, lazily creating an `EditSession` once
+    /// the path (plus description and mode) has finished streaming, feeding
+    /// each later partial into the session, and finalizing on the full input.
+    /// Returns `Completed` on success; `Failed` carries the session (if one
+    /// was created) so partial edits can still be saved and reported. User
+    /// cancellation also yields `Failed`.
+    async fn process_streaming_edits(
+        &self,
+        input: &mut ToolInput<StreamingEditFileToolInput>,
+        event_stream: &ToolCallEventStream,
+        cx: &mut AsyncApp,
+    ) -> EditSessionResult {
+        let mut session: Option<EditSession> = None;
+        let mut last_partial: Option<StreamingEditFileToolPartialInput> = None;
+
+        loop {
+            futures::select! {
+                payload = input.next().fuse() => {
+                    match payload {
+                        Ok(payload) => match payload {
+                            ToolInputPayload::Partial(partial) => {
+                                // Snapshots that don't yet parse into the partial-input
+                                // shape are skipped; a later snapshot will.
+                                if let Ok(parsed) = serde_json::from_value::<StreamingEditFileToolPartialInput>(partial) {
+                                    // Heuristic: treat the path as fully streamed once it is
+                                    // unchanged across two consecutive partial snapshots.
+                                    let path_complete = parsed.path.is_some()
+                                        && parsed.path.as_ref() == last_partial.as_ref().and_then(|partial| partial.path.as_ref());
+
+                                    last_partial = Some(parsed.clone());
+
+                                    // Create the session as soon as all fields needed to
+                                    // open/authorize the buffer are available.
+                                    if session.is_none()
+                                        && path_complete
+                                        && let StreamingEditFileToolPartialInput {
+                                            path: Some(path),
+                                            display_description: Some(display_description),
+                                            mode: Some(mode),
+                                            ..
+                                        } = &parsed
+                                    {
+                                        match EditSession::new(
+                                            PathBuf::from(path),
+                                            display_description,
+                                            *mode,
+                                            self,
+                                            event_stream,
+                                            cx,
+                                        )
+                                        .await
+                                        {
+                                            Ok(created_session) => session = Some(created_session),
+                                            Err(error) => {
+                                                log::error!("Failed to create edit session: {}", error);
+                                                return EditSessionResult::Failed {
+                                                    error,
+                                                    session: None,
+                                                };
+                                            }
+                                        }
+                                    }
+
+                                    // Stream this partial into the live session, if any.
+                                    if let Some(current_session) = &mut session
+                                        && let Err(error) = current_session.process(parsed, self, event_stream, cx)
+                                    {
+                                        log::error!("Failed to process edit: {}", error);
+                                        return EditSessionResult::Failed { error, session };
+                                    }
+                                }
+                            }
+                            ToolInputPayload::Full(full_input) => {
+                                // Fallback for non-streamed input: create the session now
+                                // if no partial ever completed the path.
+                                let mut session = if let Some(session) = session {
+                                    session
+                                } else {
+                                    match EditSession::new(
+                                        full_input.path.clone(),
+                                        &full_input.display_description,
+                                        full_input.mode,
+                                        self,
+                                        event_stream,
+                                        cx,
+                                    )
+                                    .await
+                                    {
+                                        Ok(created_session) => created_session,
+                                        Err(error) => {
+                                            log::error!("Failed to create edit session: {}", error);
+                                            return EditSessionResult::Failed {
+                                                error,
+                                                session: None,
+                                            };
+                                        }
+                                    }
+                                };
+
+                                // Full input is terminal either way: completed or failed-with-session.
+                                return match session.finalize(full_input, self, event_stream, cx).await {
+                                    Ok(()) => EditSessionResult::Completed(session),
+                                    Err(error) => {
+                                        log::error!("Failed to finalize edit: {}", error);
+                                        EditSessionResult::Failed {
+                                            error,
+                                            session: Some(session),
+                                        }
+                                    }
+                                };
+                            }
+                            ToolInputPayload::InvalidJson { error_message } => {
+                                // Terminal: report the parse error, keeping any session so
+                                // partial edits can be surfaced by the caller.
+                                log::error!("Received invalid JSON: {error_message}");
+                                return EditSessionResult::Failed {
+                                    error: error_message,
+                                    session,
+                                };
+                            }
+                        },
+                        Err(error) => {
+                            return EditSessionResult::Failed {
+                                error: format!("Failed to receive tool input: {error}"),
+                                session,
+                            };
+                        }
+                    }
+                }
+                // Race the input stream against user cancellation.
+                _ = event_stream.cancelled_by_user().fuse() => {
+                    return EditSessionResult::Failed {
+                        error: "Edit cancelled by user".to_string(),
+                        session,
+                    };
+                }
+            }
+        }
+    }
 }
 
 impl AgentTool for StreamingEditFileTool {
@@ -348,94 +532,40 @@ impl AgentTool for StreamingEditFileTool {
         cx: &mut App,
     ) -> Task<Result<Self::Output, Self::Output>> {
         cx.spawn(async move |cx: &mut AsyncApp| {
-            let mut state: Option<EditSession> = None;
-            let mut last_partial: Option<StreamingEditFileToolPartialInput> = None;
-            loop {
-                futures::select! {
-                    partial = input.recv_partial().fuse() => {
-                        let Some(partial_value) = partial else { break };
-                        if let Ok(parsed) = serde_json::from_value::<StreamingEditFileToolPartialInput>(partial_value) {
-                            let path_complete = parsed.path.is_some()
-                                && parsed.path.as_ref() == last_partial.as_ref().and_then(|p| p.path.as_ref());
-
-                            last_partial = Some(parsed.clone());
-
-                            if state.is_none()
-                                && path_complete
-                                && let StreamingEditFileToolPartialInput {
-                                    path: Some(path),
-                                    display_description: Some(display_description),
-                                    mode: Some(mode),
-                                    ..
-                                } = &parsed
-                            {
-                                match EditSession::new(
-                                    &PathBuf::from(path),
-                                    display_description,
-                                    *mode,
-                                    &self,
-                                    &event_stream,
-                                    cx,
-                                )
-                                .await
-                                {
-                                    Ok(session) => state = Some(session),
-                                    Err(e) => {
-                                        log::error!("Failed to create edit session: {}", e);
-                                        return Err(e);
-                                    }
-                                }
-                            }
-
-                            if let Some(state) = &mut state {
-                                if let Err(e) = state.process(parsed, &self, &event_stream, cx) {
-                                    log::error!("Failed to process edit: {}", e);
-                                    return Err(e);
-                                }
-                            }
-                        }
-                    }
-                    _ = event_stream.cancelled_by_user().fuse() => {
-                        return Err(StreamingEditFileToolOutput::error("Edit cancelled by user"));
-                    }
-                }
-            }
-            let full_input =
-                input
-                    .recv()
-                    .await
-                    .map_err(|e| {
-                        let err = StreamingEditFileToolOutput::error(format!("Failed to receive tool input: {e}"));
-                        log::error!("Failed to receive tool input: {e}");
-                        err
-                    })?;
-
-            let mut state = if let Some(state) = state {
-                state
-            } else {
-                match EditSession::new(
-                    &full_input.path,
-                    &full_input.display_description,
-                    full_input.mode,
-                    &self,
-                    &event_stream,
-                    cx,
-                )
+            match self
+                .process_streaming_edits(&mut input, &event_stream, cx)
                 .await
-                {
-                    Ok(session) => session,
-                    Err(e) => {
-                        log::error!("Failed to create edit session: {}", e);
-                        return Err(e);
-                    }
+            {
+                EditSessionResult::Completed(session) => {
+                    self.ensure_buffer_saved(&session.buffer, cx).await;
+                    let (new_text, diff) = session.compute_new_text_and_diff(cx).await;
+                    Ok(StreamingEditFileToolOutput::Success {
+                        old_text: session.old_text.clone(),
+                        new_text,
+                        input_path: session.input_path,
+                        diff,
+                    })
                 }
-            };
-            match state.finalize(full_input, &self, &event_stream, cx).await {
-                Ok(output) => Ok(output),
-                Err(e) => {
-                    log::error!("Failed to finalize edit: {}", e);
-                    Err(e)
+                EditSessionResult::Failed {
+                    error,
+                    session: Some(session),
+                } => {
+                    self.ensure_buffer_saved(&session.buffer, cx).await;
+                    let (_new_text, diff) = session.compute_new_text_and_diff(cx).await;
+                    Err(StreamingEditFileToolOutput::Error {
+                        error,
+                        input_path: Some(session.input_path),
+                        diff,
+                    })
                 }
+                EditSessionResult::Failed {
+                    error,
+                    session: None,
+                } => Err(StreamingEditFileToolOutput::Error {
+                    error,
+                    input_path: None,
+                    diff: String::new(),
+                }),
             }
         })
     }
@@ -472,6 +602,7 @@ impl AgentTool for StreamingEditFileTool {
 
 pub struct EditSession {
     abs_path: PathBuf,
+    input_path: PathBuf,
     buffer: Entity<Buffer>,
     old_text: Arc<String>,
     diff: Entity<Diff>,
@@ -518,23 +649,21 @@ impl EditPipeline {
 
 impl EditSession {
     async fn new(
-        path: &PathBuf,
+        path: PathBuf,
         display_description: &str,
         mode: StreamingEditFileMode,
         tool: &StreamingEditFileTool,
         event_stream: &ToolCallEventStream,
         cx: &mut AsyncApp,
-    ) -> Result<Self, StreamingEditFileToolOutput> {
-        let project_path = cx
-            .update(|cx| resolve_path(mode, &path, &tool.project, cx))
-            .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
+    ) -> Result<Self, String> {
+        let project_path = cx.update(|cx| resolve_path(mode, &path, &tool.project, cx))?;
 
         let Some(abs_path) = cx.update(|cx| tool.project.read(cx).absolute_path(&project_path, cx))
         else {
-            return Err(StreamingEditFileToolOutput::error(format!(
+            return Err(format!(
                 "Worktree at '{}' does not exist",
                 path.to_string_lossy()
-            )));
+            ));
         };
 
         event_stream.update_fields(
@@ -543,13 +672,13 @@ impl EditSession {
 
         cx.update(|cx| tool.authorize(&path, &display_description, event_stream, cx))
             .await
-            .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
+            .map_err(|e| e.to_string())?;
 
         let buffer = tool
             .project
             .update(cx, |project, cx| project.open_buffer(project_path, cx))
             .await
-            .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
+            .map_err(|e| e.to_string())?;
 
         ensure_buffer_saved(&buffer, &abs_path, tool, cx)?;
 
@@ -578,6 +707,7 @@ impl EditSession {
 
         Ok(Self {
             abs_path,
+            input_path: path,
             buffer,
             old_text,
             diff,
@@ -594,22 +724,20 @@ impl EditSession {
         tool: &StreamingEditFileTool,
         event_stream: &ToolCallEventStream,
         cx: &mut AsyncApp,
-    ) -> Result<StreamingEditFileToolOutput, StreamingEditFileToolOutput> {
-        let old_text = self.old_text.clone();
-
+    ) -> Result<(), String> {
         match input.mode {
             StreamingEditFileMode::Write => {
-                let content = input.content.ok_or_else(|| {
-                    StreamingEditFileToolOutput::error("'content' field is required for write mode")
-                })?;
+                let content = input
+                    .content
+                    .ok_or_else(|| "'content' field is required for write mode".to_string())?;
 
                 let events = self.parser.finalize_content(&content);
                 self.process_events(&events, tool, event_stream, cx)?;
             }
             StreamingEditFileMode::Edit => {
-                let edits = input.edits.ok_or_else(|| {
-                    StreamingEditFileToolOutput::error("'edits' field is required for edit mode")
-                })?;
+                let edits = input
+                    .edits
+                    .ok_or_else(|| "'edits' field is required for edit mode".to_string())?;
                 let events = self.parser.finalize_edits(&edits);
                 self.process_events(&events, tool, event_stream, cx)?;
 
@@ -625,53 +753,15 @@ impl EditSession {
                 }
             }
         }
+        Ok(())
+    }
 
-        let format_on_save_enabled = self.buffer.read_with(cx, |buffer, cx| {
-            let settings = language_settings::LanguageSettings::for_buffer(buffer, cx);
-            settings.format_on_save != FormatOnSave::Off
-        });
-
-        if format_on_save_enabled {
-            tool.action_log.update(cx, |log, cx| {
-                log.buffer_edited(self.buffer.clone(), cx);
-            });
-
-            let format_task = tool.project.update(cx, |project, cx| {
-                project.format(
-                    HashSet::from_iter([self.buffer.clone()]),
-                    LspFormatTarget::Buffers,
-                    false,
-                    FormatTrigger::Save,
-                    cx,
-                )
-            });
-            futures::select! {
-                result = format_task.fuse() => { result.log_err(); },
-                _ = event_stream.cancelled_by_user().fuse() => {
-                    return Err(StreamingEditFileToolOutput::error("Edit cancelled by user"));
-                }
-            };
-        }
-
-        let save_task = tool.project.update(cx, |project, cx| {
-            project.save_buffer(self.buffer.clone(), cx)
-        });
-        futures::select! {
-            result = save_task.fuse() => { result.map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; },
-            _ = event_stream.cancelled_by_user().fuse() => {
-                return Err(StreamingEditFileToolOutput::error("Edit cancelled by user"));
-            }
-        };
-
-        tool.action_log.update(cx, |log, cx| {
-            log.buffer_edited(self.buffer.clone(), cx);
-        });
-
+    async fn compute_new_text_and_diff(&self, cx: &mut AsyncApp) -> (String, String) {
         let new_snapshot = self.buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
         let (new_text, unified_diff) = cx
             .background_spawn({
                 let new_snapshot = new_snapshot.clone();
-                let old_text = old_text.clone();
+                let old_text = self.old_text.clone();
                 async move {
                     let new_text = new_snapshot.text();
                     let diff = language::unified_diff(&old_text, &new_text);
@@ -679,14 +769,7 @@ impl EditSession {
                 }
             })
             .await;
-
-        let output = StreamingEditFileToolOutput::Success {
-            input_path: input.path,
-            new_text,
-            old_text: old_text.clone(),
-            diff: unified_diff,
-        };
-        Ok(output)
+        (new_text, unified_diff)
     }
 
     fn process(
@@ -695,7 +778,7 @@ impl EditSession {
         tool: &StreamingEditFileTool,
         event_stream: &ToolCallEventStream,
         cx: &mut AsyncApp,
-    ) -> Result<(), StreamingEditFileToolOutput> {
+    ) -> Result<(), String> {
         match &self.mode {
             StreamingEditFileMode::Write => {
                 if let Some(content) = &partial.content {
@@ -719,7 +802,7 @@ impl EditSession {
         tool: &StreamingEditFileTool,
         event_stream: &ToolCallEventStream,
         cx: &mut AsyncApp,
-    ) -> Result<(), StreamingEditFileToolOutput> {
+    ) -> Result<(), String> {
         for event in events {
             match event {
                 ToolEditEvent::ContentChunk { chunk } => {
@@ -969,14 +1052,14 @@ fn extract_match(
     buffer: &Entity<Buffer>,
     edit_index: &usize,
     cx: &mut AsyncApp,
-) -> Result<Range<usize>, StreamingEditFileToolOutput> {
+) -> Result<Range<usize>, String> {
     match matches.len() {
-        0 => Err(StreamingEditFileToolOutput::error(format!(
+        0 => Err(format!(
             "Could not find matching text for edit at index {}. \
                 The old_text did not match any content in the file. \
                 Please read the file again to get the current content.",
             edit_index,
-        ))),
+        )),
         1 => Ok(matches.into_iter().next().unwrap()),
         _ => {
             let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
@@ -985,12 +1068,12 @@ fn extract_match(
                 .map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string())
                 .collect::<Vec<_>>()
                 .join(", ");
-            Err(StreamingEditFileToolOutput::error(format!(
+            Err(format!(
                 "Edit {} matched multiple locations in the file at lines: {}. \
                     Please provide more context in old_text to uniquely \
                     identify the location.",
                 edit_index, lines
-            )))
+            ))
         }
     }
 }
@@ -1022,7 +1105,7 @@ fn ensure_buffer_saved(
     abs_path: &PathBuf,
     tool: &StreamingEditFileTool,
     cx: &mut AsyncApp,
-) -> Result<(), StreamingEditFileToolOutput> {
+) -> Result<(), String> {
     let last_read_mtime = tool
         .action_log
         .read_with(cx, |log, _| log.file_read_time(abs_path));
@@ -1063,15 +1146,14 @@ fn ensure_buffer_saved(
                          then ask them to save or revert the file manually and inform you when it's ok to proceed."
             }
         };
-        return Err(StreamingEditFileToolOutput::error(message));
+        return Err(message.to_string());
     }
 
     if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
         if current != last_read {
-            return Err(StreamingEditFileToolOutput::error(
-                "The file has been modified since you last read it. \
-                             Please read the file again to get the current state before editing it.",
-            ));
+            return Err("The file has been modified since you last read it. \
+                    Please read the file again to get the current state before editing it."
+                .to_string());
         }
     }
 
@@ -1083,56 +1165,63 @@ fn resolve_path(
     path: &PathBuf,
     project: &Entity<Project>,
     cx: &mut App,
-) -> Result<ProjectPath> {
+) -> Result<ProjectPath, String> {
     let project = project.read(cx);
 
     match mode {
         StreamingEditFileMode::Edit => {
             let path = project
                 .find_project_path(&path, cx)
-                .context("Can't edit file: path not found")?;
+                .ok_or_else(|| "Can't edit file: path not found".to_string())?;
 
             let entry = project
                 .entry_for_path(&path, cx)
-                .context("Can't edit file: path not found")?;
+                .ok_or_else(|| "Can't edit file: path not found".to_string())?;
 
-            anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
-            Ok(path)
+            if entry.is_file() {
+                Ok(path)
+            } else {
+                Err("Can't edit file: path is a directory".to_string())
+            }
         }
         StreamingEditFileMode::Write => {
             if let Some(path) = project.find_project_path(&path, cx)
                 && let Some(entry) = project.entry_for_path(&path, cx)
             {
-                anyhow::ensure!(entry.is_file(), "Can't write to file: path is a directory");
-                return Ok(path);
+                if entry.is_file() {
+                    return Ok(path);
+                } else {
+                    return Err("Can't write to file: path is a directory".to_string());
+                }
             }
 
-            let parent_path = path.parent().context("Can't create file: incorrect path")?;
+            let parent_path = path
+                .parent()
+                .ok_or_else(|| "Can't create file: incorrect path".to_string())?;
 
             let parent_project_path = project.find_project_path(&parent_path, cx);
 
             let parent_entry = parent_project_path
                 .as_ref()
                 .and_then(|path| project.entry_for_path(path, cx))
-                .context("Can't create file: parent directory doesn't exist")?;
+                .ok_or_else(|| "Can't create file: parent directory doesn't exist".to_string())?;
 
-            anyhow::ensure!(
-                parent_entry.is_dir(),
-                "Can't create file: parent is not a directory"
-            );
+            if !parent_entry.is_dir() {
+                return Err("Can't create file: parent is not a directory".to_string());
+            }
 
             let file_name = path
                 .file_name()
                 .and_then(|file_name| file_name.to_str())
                 .and_then(|file_name| RelPath::unix(file_name).ok())
-                .context("Can't create file: invalid filename")?;
+                .ok_or_else(|| "Can't create file: invalid filename".to_string())?;
 
             let new_file_path = parent_project_path.map(|parent| ProjectPath {
                 path: parent.path.join(file_name),
                 ..parent
             });
 
-            new_file_path.context("Can't create file")
+            new_file_path.ok_or_else(|| "Can't create file".to_string())
         }
     }
 }
@@ -1382,10 +1471,17 @@ mod tests {
             })
             .await;
 
-        let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else {
+        let StreamingEditFileToolOutput::Error {
+            error,
+            diff,
+            input_path,
+        } = result.unwrap_err()
+        else {
             panic!("expected error");
         };
         assert_eq!(error, "Can't edit file: path not found");
+        assert!(diff.is_empty());
+        assert_eq!(input_path, None);
     }
 
     #[gpui::test]
@@ -1411,7 +1507,7 @@ mod tests {
             })
             .await;
 
-        let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else {
+        let StreamingEditFileToolOutput::Error { error, .. } = result.unwrap_err() else {
             panic!("expected error");
         };
         assert!(
@@ -1424,7 +1520,7 @@ mod tests {
     async fn test_streaming_early_buffer_open(cx: &mut TestAppContext) {
         let (tool, _project, _action_log, _fs, _thread) =
             setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
 
@@ -1447,7 +1543,7 @@ mod tests {
         cx.run_until_parked();
 
         // Now send the final complete input
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Edit lines",
             "path": "root/file.txt",
             "mode": "edit",
@@ -1465,7 +1561,7 @@ mod tests {
     async fn test_streaming_path_completeness_heuristic(cx: &mut TestAppContext) {
         let (tool, _project, _action_log, _fs, _thread) =
             setup_test(cx, json!({"file.txt": "hello world"})).await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
 
@@ -1485,7 +1581,7 @@ mod tests {
         cx.run_until_parked();
 
         // Send final
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Overwrite file",
             "path": "root/file.txt",
             "mode": "write",
@@ -1503,7 +1599,7 @@ mod tests {
     async fn test_streaming_cancellation_during_partials(cx: &mut TestAppContext) {
         let (tool, _project, _action_log, _fs, _thread) =
             setup_test(cx, json!({"file.txt": "hello world"})).await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver, mut cancellation_tx) =
             ToolCallEventStream::test_with_cancellation();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1521,7 +1617,7 @@ mod tests {
         drop(sender);
 
         let result = task.await;
-        let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else {
+        let StreamingEditFileToolOutput::Error { error, .. } = result.unwrap_err() else {
             panic!("expected error");
         };
         assert!(
@@ -1537,7 +1633,7 @@ mod tests {
             json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}),
         )
         .await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
 
@@ -1578,7 +1674,7 @@ mod tests {
         cx.run_until_parked();
 
         // Send final complete input
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Edit multiple lines",
             "path": "root/file.txt",
             "mode": "edit",
@@ -1601,7 +1697,7 @@ mod tests {
     #[gpui::test]
     async fn test_streaming_create_file_with_partials(cx: &mut TestAppContext) {
         let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
 
@@ -1625,7 +1721,7 @@ mod tests {
         cx.run_until_parked();
 
         // Final with full content
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Create new file",
             "path": "root/dir/new_file.txt",
             "mode": "write",
@@ -1643,12 +1739,12 @@ mod tests {
     async fn test_streaming_no_partials_direct_final(cx: &mut TestAppContext) {
         let (tool, _project, _action_log, _fs, _thread) =
             setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
 
         // Send final immediately with no partials (simulates non-streaming path)
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Edit lines",
             "path": "root/file.txt",
             "mode": "edit",
@@ -1669,7 +1765,7 @@ mod tests {
             json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}),
         )
         .await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
 
@@ -1739,7 +1835,7 @@ mod tests {
         );
 
         // Send final complete input
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Edit multiple lines",
             "path": "root/file.txt",
             "mode": "edit",
@@ -1767,7 +1863,7 @@ mod tests {
     async fn test_streaming_incremental_three_edits(cx: &mut TestAppContext) {
         let (tool, project, _action_log, _fs, _thread) =
             setup_test(cx, json!({"file.txt": "aaa\nbbb\nccc\nddd\neee\n"})).await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
 
@@ -1835,7 +1931,7 @@ mod tests {
         assert_eq!(buffer_text.as_deref(), Some("AAA\nbbb\nCCC\nddd\nEEEeee\n"));
 
         // Send final
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Edit three lines",
             "path": "root/file.txt",
             "mode": "edit",
@@ -1857,7 +1953,7 @@ mod tests {
     async fn test_streaming_edit_failure_mid_stream(cx: &mut TestAppContext) {
         let (tool, project, _action_log, _fs, _thread) =
             setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
 
@@ -1893,16 +1989,17 @@ mod tests {
         }));
         cx.run_until_parked();
 
-        // Verify edit 1 was applied
-        let buffer_text = project.update(cx, |project, cx| {
+        let buffer = project.update(cx, |project, cx| {
             let pp = project
                 .find_project_path(&PathBuf::from("root/file.txt"), cx)
                 .unwrap();
-            project.get_open_buffer(&pp, cx).map(|b| b.read(cx).text())
+            project.get_open_buffer(&pp, cx).unwrap()
         });
+
+        // Verify edit 1 was applied
+        let buffer_text = buffer.read_with(cx, |buffer, _cx| buffer.text());
         assert_eq!(
-            buffer_text.as_deref(),
-            Some("MODIFIED\nline 2\nline 3\n"),
+            buffer_text, "MODIFIED\nline 2\nline 3\n",
             "First edit should be applied even though second edit will fail"
         );
 
@@ -1925,20 +2022,32 @@ mod tests {
         drop(sender);
 
         let result = task.await;
-        let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else {
+        let StreamingEditFileToolOutput::Error {
+            error,
+            diff,
+            input_path,
+        } = result.unwrap_err()
+        else {
             panic!("expected error");
         };
+
         assert!(
             error.contains("Could not find matching text for edit at index 1"),
             "Expected error about edit 1 failing, got: {error}"
         );
+        // Ensure that first edit was applied successfully and that we saved the buffer
+        assert_eq!(input_path, Some(PathBuf::from("root/file.txt")));
+        assert_eq!(
+            diff,
+            "@@ -1,3 +1,3 @@\n-line 1\n+MODIFIED\n line 2\n line 3\n"
+        );
     }
 
     #[gpui::test]
     async fn test_streaming_single_edit_no_incremental(cx: &mut TestAppContext) {
         let (tool, project, _action_log, _fs, _thread) =
             setup_test(cx, json!({"file.txt": "hello world\n"})).await;
-        let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+        let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
         let (event_stream, _receiver) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
 
@@ -1975,7 +2084,7 @@ mod tests {
         );
 
         // Send final — the edit is applied during finalization
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Single edit",
             "path": "root/file.txt",
             "mode": "edit",
@@ -1993,7 +2102,7 @@ mod tests {
     async fn test_streaming_input_partials_then_final(cx: &mut TestAppContext) {
         let (tool, _project, _action_log, _fs, _thread) =
             setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await;
-        let (sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
+        let (mut sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
             ToolInput::test();
         let (event_stream, _event_rx) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -2020,7 +2129,7 @@ mod tests {
         cx.run_until_parked();
 
         // Send the final complete input
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Edit lines",
             "path": "root/file.txt",
             "mode": "edit",
@@ -2038,7 +2147,7 @@ mod tests {
     async fn test_streaming_input_sender_dropped_before_final(cx: &mut TestAppContext) {
         let (tool, _project, _action_log, _fs, _thread) =
             setup_test(cx, json!({"file.txt": "hello world\n"})).await;
-        let (sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
+        let (mut sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
             ToolInput::test();
         let (event_stream, _event_rx) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -2064,7 +2173,7 @@ mod tests {
         // Create a channel and send multiple partials before a final, then use
         // ToolInput::resolved-style immediate delivery to confirm recv() works
         // when partials are already buffered.
-        let (sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
+        let (mut sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
             ToolInput::test();
         let (event_stream, _event_rx) = ToolCallEventStream::test();
         let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -2077,7 +2186,7 @@ mod tests {
             "path": "root/dir/new.txt",
             "mode": "write"
         }));
-        sender.send_final(json!({
+        sender.send_full(json!({
             "display_description": "Create",
             "path": "root/dir/new.txt",
             "mode": "write",
@@ -2109,13 +2218,13 @@ mod tests {
 
         let result = test_resolve_path(&mode, "root/dir/subdir", cx);
         assert_eq!(
-            result.await.unwrap_err().to_string(),
+            result.await.unwrap_err(),
             "Can't write to file: path is a directory"
         );
 
         let result = test_resolve_path(&mode, "root/dir/nonexistent_dir/new.txt", cx);
         assert_eq!(
-            result.await.unwrap_err().to_string(),
+            result.await.unwrap_err(),
             "Can't create file: parent directory doesn't exist"
         );
     }
@@ -2133,14 +2242,11 @@ mod tests {
         assert_resolved_path_eq(result.await, rel_path(path_without_root));
 
         let result = test_resolve_path(&mode, "root/nonexistent.txt", cx);
-        assert_eq!(
-            result.await.unwrap_err().to_string(),
-            "Can't edit file: path not found"
-        );
+        assert_eq!(result.await.unwrap_err(), "Can't edit file: path not found");
 
         let result = test_resolve_path(&mode, "root/dir", cx);
         assert_eq!(
-            result.await.unwrap_err().to_string(),
+            result.await.unwrap_err(),
             "Can't edit file: path is a directory"
         );
     }

crates/agent_ui/src/agent_panel.rs 🔗

@@ -5175,7 +5175,7 @@ mod tests {
         multi_workspace
             .read_with(cx, |multi_workspace, _cx| {
                 assert_eq!(
-                    multi_workspace.workspaces().len(),
+                    multi_workspace.workspaces().count(),
                     1,
                     "LocalProject should not create a new workspace"
                 );
@@ -5451,6 +5451,11 @@ mod tests {
 
         let multi_workspace =
             cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+        multi_workspace
+            .update(cx, |multi_workspace, _, cx| {
+                multi_workspace.open_sidebar(cx);
+            })
+            .unwrap();
 
         let workspace = multi_workspace
             .read_with(cx, |multi_workspace, _cx| {
@@ -5538,15 +5543,14 @@ mod tests {
             .read_with(cx, |multi_workspace, cx| {
                 // There should be more than one workspace now (the original + the new worktree).
                 assert!(
-                    multi_workspace.workspaces().len() > 1,
+                    multi_workspace.workspaces().count() > 1,
                     "expected a new workspace to have been created, found {}",
-                    multi_workspace.workspaces().len(),
+                    multi_workspace.workspaces().count(),
                 );
 
                 // Check the newest workspace's panel for the correct agent.
                 let new_workspace = multi_workspace
                     .workspaces()
-                    .iter()
                     .find(|ws| ws.entity_id() != workspace.entity_id())
                     .expect("should find the new workspace");
                 let new_panel = new_workspace

crates/agent_ui/src/conversation_view.rs 🔗

@@ -3375,7 +3375,6 @@ pub(crate) mod tests {
         // Verify workspace1 is no longer the active workspace
         multi_workspace_handle
             .read_with(cx, |mw, _cx| {
-                assert_eq!(mw.active_workspace_index(), 1);
                 assert_ne!(mw.workspace(), &workspace1);
             })
             .unwrap();

crates/agent_ui/src/threads_archive_view.rs 🔗

@@ -218,6 +218,13 @@ impl ThreadsArchiveView {
         handle.focus(window, cx);
     }
 
+    pub fn is_filter_editor_focused(&self, window: &Window, cx: &App) -> bool {
+        self.filter_editor
+            .read(cx)
+            .focus_handle(cx)
+            .is_focused(window)
+    }
+
     fn update_items(&mut self, cx: &mut Context<Self>) {
         let sessions = ThreadMetadataStore::global(cx)
             .read(cx)
@@ -346,7 +353,6 @@ impl ThreadsArchiveView {
             .map(|mw| {
                 mw.read(cx)
                     .workspaces()
-                    .iter()
                     .filter_map(|ws| ws.read(cx).database_id())
                     .collect()
             })

crates/dev_container/src/devcontainer_manifest.rs 🔗

@@ -883,7 +883,13 @@ RUN sed -i -E 's/((^|\s)PATH=)([^\$]*)$/\1\${{PATH:-\3}}/g' /etc/profile || true
                         labels: None,
                         build: Some(DockerComposeServiceBuild {
                             context: Some(
-                                features_build_info.empty_context_dir.display().to_string(),
+                                main_service
+                                    .build
+                                    .as_ref()
+                                    .and_then(|b| b.context.clone())
+                                    .unwrap_or_else(|| {
+                                        features_build_info.empty_context_dir.display().to_string()
+                                    }),
                             ),
                             dockerfile: Some(dockerfile_path.display().to_string()),
                             args: Some(build_args),
@@ -3546,6 +3552,27 @@ ENV DOCKER_BUILDKIT=1
 "#
         );
 
+        let build_override = files
+            .iter()
+            .find(|f| {
+                f.file_name()
+                    .is_some_and(|s| s.display().to_string() == "docker_compose_build.json")
+            })
+            .expect("to be found");
+        let build_override = test_dependencies.fs.load(build_override).await.unwrap();
+        let build_config: DockerComposeConfig =
+            serde_json_lenient::from_str(&build_override).unwrap();
+        let build_context = build_config
+            .services
+            .get("app")
+            .and_then(|s| s.build.as_ref())
+            .and_then(|b| b.context.clone())
+            .expect("build override should have a context");
+        assert_eq!(
+            build_context, ".",
+            "build override should preserve the original build context from docker-compose.yml"
+        );
+
         let runtime_override = files
             .iter()
             .find(|f| {

crates/edit_prediction/src/edit_prediction_tests.rs 🔗

@@ -2707,6 +2707,65 @@ async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppConte
     });
 }
 
+#[gpui::test]
+async fn test_v3_prediction_strips_cursor_marker_from_edit_text(cx: &mut TestAppContext) {
+    let (ep_store, mut requests) = init_test_with_fake_client(cx);
+    let fs = FakeFs::new(cx.executor());
+
+    fs.insert_tree(
+        "/root",
+        json!({
+            "foo.txt": "hello"
+        }),
+    )
+    .await;
+    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+    let buffer = project
+        .update(cx, |project, cx| {
+            let path = project
+                .find_project_path(path!("root/foo.txt"), cx)
+                .unwrap();
+            project.open_buffer(path, cx)
+        })
+        .await
+        .unwrap();
+
+    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+    let position = snapshot.anchor_before(language::Point::new(0, 5));
+
+    ep_store.update(cx, |ep_store, cx| {
+        ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+    });
+
+    let (request, respond_tx) = requests.predict.next().await.unwrap();
+    let excerpt_length = request.input.cursor_excerpt.len();
+    respond_tx
+        .send(PredictEditsV3Response {
+            request_id: Uuid::new_v4().to_string(),
+            output: "hello<|user_cursor|> world".to_string(),
+            editable_range: 0..excerpt_length,
+            model_version: None,
+        })
+        .unwrap();
+
+    cx.run_until_parked();
+
+    ep_store.update(cx, |ep_store, cx| {
+        let prediction = ep_store
+            .prediction_at(&buffer, None, &project, cx)
+            .expect("should have prediction");
+        let snapshot = buffer.read(cx).snapshot();
+        let edits: Vec<_> = prediction
+            .edits
+            .iter()
+            .map(|(range, text)| (range.to_offset(&snapshot), text.clone()))
+            .collect();
+
+        assert_eq!(edits, vec![(5..5, " world".into())]);
+    });
+}
+
 fn init_test(cx: &mut TestAppContext) {
     cx.update(|cx| {
         let settings_store = SettingsStore::test(cx);

crates/edit_prediction/src/example_spec.rs 🔗

@@ -1,10 +1,11 @@
-use crate::udiff::DiffLine;
 use anyhow::{Context as _, Result};
 use serde::{Deserialize, Serialize};
 use std::{borrow::Cow, fmt::Write as _, mem, path::Path, sync::Arc};
 use telemetry_events::EditPredictionRating;
 
-pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]";
+pub use zeta_prompt::udiff::{
+    CURSOR_POSITION_MARKER, encode_cursor_in_patch, extract_cursor_from_patch,
+};
 pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>";
 
 /// Maximum cursor file size to capture (64KB).
@@ -12,64 +13,6 @@ pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>";
 /// falling back to git-based loading.
 pub const MAX_CURSOR_FILE_SIZE: usize = 64 * 1024;
 
-/// Encodes a cursor position into a diff patch by adding a comment line with a caret
-/// pointing to the cursor column.
-///
-/// The cursor offset is relative to the start of the new text content (additions and context lines).
-/// Returns the patch with cursor marker comment lines inserted after the relevant addition line.
-pub fn encode_cursor_in_patch(patch: &str, cursor_offset: Option<usize>) -> String {
-    let Some(cursor_offset) = cursor_offset else {
-        return patch.to_string();
-    };
-
-    let mut result = String::new();
-    let mut line_start_offset = 0usize;
-
-    for line in patch.lines() {
-        if matches!(
-            DiffLine::parse(line),
-            DiffLine::Garbage(content)
-                if content.starts_with('#') && content.contains(CURSOR_POSITION_MARKER)
-        ) {
-            continue;
-        }
-
-        if !result.is_empty() {
-            result.push('\n');
-        }
-        result.push_str(line);
-
-        match DiffLine::parse(line) {
-            DiffLine::Addition(content) => {
-                let line_end_offset = line_start_offset + content.len();
-
-                if cursor_offset >= line_start_offset && cursor_offset <= line_end_offset {
-                    let cursor_column = cursor_offset - line_start_offset;
-
-                    result.push('\n');
-                    result.push('#');
-                    for _ in 0..cursor_column {
-                        result.push(' ');
-                    }
-                    write!(result, "^{}", CURSOR_POSITION_MARKER).unwrap();
-                }
-
-                line_start_offset = line_end_offset + 1;
-            }
-            DiffLine::Context(content) => {
-                line_start_offset += content.len() + 1;
-            }
-            _ => {}
-        }
-    }
-
-    if patch.ends_with('\n') {
-        result.push('\n');
-    }
-
-    result
-}
-
 #[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
 pub struct ExampleSpec {
     #[serde(default)]
@@ -509,53 +452,7 @@ impl ExampleSpec {
     pub fn expected_patches_with_cursor_positions(&self) -> Vec<(String, Option<usize>)> {
         self.expected_patches
             .iter()
-            .map(|patch| {
-                let mut clean_patch = String::new();
-                let mut cursor_offset: Option<usize> = None;
-                let mut line_start_offset = 0usize;
-                let mut prev_line_start_offset = 0usize;
-
-                for line in patch.lines() {
-                    let diff_line = DiffLine::parse(line);
-
-                    match &diff_line {
-                        DiffLine::Garbage(content)
-                            if content.starts_with('#')
-                                && content.contains(CURSOR_POSITION_MARKER) =>
-                        {
-                            let caret_column = if let Some(caret_pos) = content.find('^') {
-                                caret_pos
-                            } else if let Some(_) = content.find('<') {
-                                0
-                            } else {
-                                continue;
-                            };
-                            let cursor_column = caret_column.saturating_sub('#'.len_utf8());
-                            cursor_offset = Some(prev_line_start_offset + cursor_column);
-                        }
-                        _ => {
-                            if !clean_patch.is_empty() {
-                                clean_patch.push('\n');
-                            }
-                            clean_patch.push_str(line);
-
-                            match diff_line {
-                                DiffLine::Addition(content) | DiffLine::Context(content) => {
-                                    prev_line_start_offset = line_start_offset;
-                                    line_start_offset += content.len() + 1;
-                                }
-                                _ => {}
-                            }
-                        }
-                    }
-                }
-
-                if patch.ends_with('\n') && !clean_patch.is_empty() {
-                    clean_patch.push('\n');
-                }
-
-                (clean_patch, cursor_offset)
-            })
+            .map(|patch| extract_cursor_from_patch(patch))
             .collect()
     }
 

crates/edit_prediction/src/zeta.rs 🔗

@@ -24,8 +24,9 @@ use zeta_prompt::{ParsedOutput, ZetaPromptInput};
 
 use std::{env, ops::Range, path::Path, sync::Arc};
 use zeta_prompt::{
-    CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
-    prompt_input_contains_special_tokens, stop_tokens_for_format,
+    ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
+    parsed_output_from_editable_region, prompt_input_contains_special_tokens,
+    stop_tokens_for_format,
     zeta1::{self, EDITABLE_REGION_END_MARKER},
 };
 
@@ -181,6 +182,7 @@ pub fn request_prediction_with_zeta(
                             let parsed_output = output_text.map(|text| ParsedOutput {
                                 new_editable_region: text,
                                 range_in_excerpt: editable_range_in_excerpt,
+                                cursor_offset_in_new_editable_region: None,
                             });
 
                             (request_id, parsed_output, None, None)
@@ -283,10 +285,10 @@ pub fn request_prediction_with_zeta(
                     let request_id = EditPredictionId(response.request_id.into());
                     let output_text = Some(response.output).filter(|s| !s.is_empty());
                     let model_version = response.model_version;
-                    let parsed_output = ParsedOutput {
-                        new_editable_region: output_text.unwrap_or_default(),
-                        range_in_excerpt: response.editable_range,
-                    };
+                    let parsed_output = parsed_output_from_editable_region(
+                        response.editable_range,
+                        output_text.unwrap_or_default(),
+                    );
 
                     Some((request_id, Some(parsed_output), model_version, usage))
                 })
@@ -299,6 +301,7 @@ pub fn request_prediction_with_zeta(
             let Some(ParsedOutput {
                 new_editable_region: mut output_text,
                 range_in_excerpt: editable_range_in_excerpt,
+                cursor_offset_in_new_editable_region: cursor_offset_in_output,
             }) = output
             else {
                 return Ok((Some((request_id, None)), None));
@@ -312,13 +315,6 @@ pub fn request_prediction_with_zeta(
                 .text_for_range(editable_range_in_buffer.clone())
                 .collect::<String>();
 
-            // Client-side cursor marker processing (applies to both raw and v3 responses)
-            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
-            if let Some(offset) = cursor_offset_in_output {
-                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
-                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
-            }
-
             if let Some(debug_tx) = &debug_tx {
                 debug_tx
                     .unbounded_send(DebugEvent::EditPredictionFinished(

crates/edit_prediction_cli/src/parse_output.rs 🔗

@@ -5,8 +5,7 @@ use crate::{
     repair,
 };
 use anyhow::{Context as _, Result};
-use edit_prediction::example_spec::encode_cursor_in_patch;
-use zeta_prompt::{CURSOR_MARKER, ZetaFormat, parse_zeta2_model_output};
+use zeta_prompt::{ZetaFormat, parse_zeta2_model_output, parsed_output_to_patch};
 
 pub fn run_parse_output(example: &mut Example) -> Result<()> {
     example
@@ -65,46 +64,18 @@ fn parse_zeta2_output(
         .context("prompt_inputs required")?;
 
     let parsed = parse_zeta2_model_output(actual_output, format, prompt_inputs)?;
-    let range_in_excerpt = parsed.range_in_excerpt;
-
+    let range_in_excerpt = parsed.range_in_excerpt.clone();
     let excerpt = prompt_inputs.cursor_excerpt.as_ref();
-    let old_text = excerpt[range_in_excerpt.clone()].to_string();
-    let mut new_text = parsed.new_editable_region;
-
-    let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
-        new_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
-        Some(offset)
-    } else {
-        None
-    };
+    let editable_region_offset = range_in_excerpt.start;
+    let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count();
 
-    // Normalize trailing newlines for diff generation
-    let mut old_text_normalized = old_text;
+    let mut new_text = parsed.new_editable_region.clone();
     if !new_text.is_empty() && !new_text.ends_with('\n') {
         new_text.push('\n');
     }
-    if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
-        old_text_normalized.push('\n');
-    }
-
-    let editable_region_offset = range_in_excerpt.start;
-    let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count();
-    let editable_region_lines = old_text_normalized.lines().count() as u32;
-
-    let diff = language::unified_diff_with_context(
-        &old_text_normalized,
-        &new_text,
-        editable_region_start_line as u32,
-        editable_region_start_line as u32,
-        editable_region_lines,
-    );
-
-    let formatted_diff = format!(
-        "--- a/{path}\n+++ b/{path}\n{diff}",
-        path = example.spec.cursor_path.to_string_lossy(),
-    );
 
-    let formatted_diff = encode_cursor_in_patch(&formatted_diff, cursor_offset);
+    let cursor_offset = parsed.cursor_offset_in_new_editable_region;
+    let formatted_diff = parsed_output_to_patch(prompt_inputs, parsed)?;
 
     let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
         ActualCursor::from_editable_region(

crates/gpui/src/elements/list.rs 🔗

@@ -462,6 +462,13 @@ impl ListState {
 
         let current_offset = self.logical_scroll_top();
         let state = &mut *self.0.borrow_mut();
+
+        if distance < px(0.) {
+            if let FollowState::Tail { is_following } = &mut state.follow_state {
+                *is_following = false;
+            }
+        }
+
         let mut cursor = state.items.cursor::<ListItemSummary>(());
         cursor.seek(&Count(current_offset.item_ix), Bias::Right);
 
@@ -536,6 +543,12 @@ impl ListState {
             scroll_top.offset_in_item = px(0.);
         }
 
+        if scroll_top.item_ix < item_count {
+            if let FollowState::Tail { is_following } = &mut state.follow_state {
+                *is_following = false;
+            }
+        }
+
         state.logical_scroll_top = Some(scroll_top);
     }
 

crates/language_model/src/fake_provider.rs 🔗

@@ -125,6 +125,7 @@ pub struct FakeLanguageModel {
     >,
     forbid_requests: AtomicBool,
     supports_thinking: AtomicBool,
+    supports_streaming_tools: AtomicBool,
 }
 
 impl Default for FakeLanguageModel {
@@ -137,6 +138,7 @@ impl Default for FakeLanguageModel {
             current_completion_txs: Mutex::new(Vec::new()),
             forbid_requests: AtomicBool::new(false),
             supports_thinking: AtomicBool::new(false),
+            supports_streaming_tools: AtomicBool::new(false),
         }
     }
 }
@@ -169,6 +171,10 @@ impl FakeLanguageModel {
         self.supports_thinking.store(supports, SeqCst);
     }
 
+    pub fn set_supports_streaming_tools(&self, supports: bool) {
+        self.supports_streaming_tools.store(supports, SeqCst);
+    }
+
     pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
         self.current_completion_txs
             .lock()
@@ -282,6 +288,10 @@ impl LanguageModel for FakeLanguageModel {
         self.supports_thinking.load(SeqCst)
     }
 
+    fn supports_streaming_tools(&self) -> bool {
+        self.supports_streaming_tools.load(SeqCst)
+    }
+
     fn telemetry_id(&self) -> String {
         "fake".to_string()
     }

crates/picker/src/picker.rs 🔗

@@ -121,6 +121,9 @@ pub trait PickerDelegate: Sized + 'static {
     ) -> bool {
         true
     }
+    fn select_on_hover(&self) -> bool {
+        true
+    }
 
     // Allows binding some optional effect to when the selection changes.
     fn selected_index_changed(
@@ -788,12 +791,14 @@ impl<D: PickerDelegate> Picker<D> {
                     this.handle_click(ix, event.modifiers.platform, window, cx)
                 }),
             )
-            .on_hover(cx.listener(move |this, hovered: &bool, window, cx| {
-                if *hovered {
-                    this.set_selected_index(ix, None, false, window, cx);
-                    cx.notify();
-                }
-            }))
+            .when(self.delegate.select_on_hover(), |this| {
+                this.on_hover(cx.listener(move |this, hovered: &bool, window, cx| {
+                    if *hovered {
+                        this.set_selected_index(ix, None, false, window, cx);
+                        cx.notify();
+                    }
+                }))
+            })
             .children(self.delegate.render_match(
                 ix,
                 ix == self.delegate.selected_index(),

crates/recent_projects/src/recent_projects.rs 🔗

@@ -357,7 +357,6 @@ pub fn init(cx: &mut App) {
                         .update(cx, |multi_workspace, window, cx| {
                             let sibling_workspace_ids: HashSet<WorkspaceId> = multi_workspace
                                 .workspaces()
-                                .iter()
                                 .filter_map(|ws| ws.read(cx).database_id())
                                 .collect();
 
@@ -1113,7 +1112,6 @@ impl PickerDelegate for RecentProjectsDelegate {
                             .update(cx, |multi_workspace, window, cx| {
                                 let workspace = multi_workspace
                                     .workspaces()
-                                    .iter()
                                     .find(|ws| ws.read(cx).database_id() == Some(workspace_id))
                                     .cloned();
                                 if let Some(workspace) = workspace {
@@ -1932,7 +1930,6 @@ impl RecentProjectsDelegate {
                     .update(cx, |multi_workspace, window, cx| {
                         let workspace = multi_workspace
                             .workspaces()
-                            .iter()
                             .find(|ws| ws.read(cx).database_id() == Some(workspace_id))
                             .cloned();
                         if let Some(workspace) = workspace {
@@ -2055,6 +2052,11 @@ mod tests {
         assert_eq!(cx.update(|cx| cx.windows().len()), 1);
 
         let multi_workspace = cx.update(|cx| cx.windows()[0].downcast::<MultiWorkspace>().unwrap());
+        multi_workspace
+            .update(cx, |multi_workspace, _, cx| {
+                multi_workspace.open_sidebar(cx);
+            })
+            .unwrap();
         multi_workspace
             .update(cx, |multi_workspace, _, cx| {
                 assert!(!multi_workspace.workspace().read(cx).is_edited())
@@ -2141,7 +2143,7 @@ mod tests {
                 );
 
                 assert!(
-                    multi_workspace.workspaces().contains(&dirty_workspace),
+                    multi_workspace.workspaces().any(|w| w == &dirty_workspace),
                     "The dirty workspace should still be present in multi-workspace mode"
                 );
 

crates/rules_library/src/rules_library.rs 🔗

@@ -225,6 +225,10 @@ impl PickerDelegate for RulePickerDelegate {
         }
     }
 
+    fn select_on_hover(&self) -> bool {
+        false
+    }
+
     fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
         "Search…".into()
     }

crates/settings_ui/src/page_data.rs 🔗

@@ -4433,7 +4433,7 @@ fn window_and_layout_page() -> SettingsPage {
 }
 
 fn panels_page() -> SettingsPage {
-    fn project_panel_section() -> [SettingsPageItem; 24] {
+    fn project_panel_section() -> [SettingsPageItem; 28] {
         [
             SettingsPageItem::SectionHeader("Project Panel"),
             SettingsPageItem::SettingItem(SettingItem {
@@ -4914,31 +4914,25 @@ fn panels_page() -> SettingsPage {
                 files: USER,
             }),
             SettingsPageItem::SettingItem(SettingItem {
-                title: "Hidden Files",
-                description: "Globs to match files that will be considered \"hidden\" and can be hidden from the project panel.",
-                field: Box::new(
-                    SettingField {
-                        json_path: Some("worktree.hidden_files"),
-                        pick: |settings_content| {
-                            settings_content.project.worktree.hidden_files.as_ref()
-                        },
-                        write: |settings_content, value| {
-                            settings_content.project.worktree.hidden_files = value;
-                        },
-                    }
-                    .unimplemented(),
-                ),
+                title: "Sort Mode",
+                description: "Sort order for entries in the project panel.",
+                field: Box::new(SettingField {
+                    json_path: Some("project_panel.sort_mode"),
+                    pick: |settings_content| {
+                        settings_content.project_panel.as_ref()?.sort_mode.as_ref()
+                    },
+                    write: |settings_content, value| {
+                        settings_content
+                            .project_panel
+                            .get_or_insert_default()
+                            .sort_mode = value;
+                    },
+                }),
                 metadata: None,
                 files: USER,
             }),
-        ]
-    }
-
-    fn auto_open_files_section() -> [SettingsPageItem; 5] {
-        [
-            SettingsPageItem::SectionHeader("Auto Open Files"),
             SettingsPageItem::SettingItem(SettingItem {
-                title: "On Create",
+                title: "Auto Open Files On Create",
                 description: "Whether to automatically open newly created files in the editor.",
                 field: Box::new(SettingField {
                     json_path: Some("project_panel.auto_open.on_create"),
@@ -4964,7 +4958,7 @@ fn panels_page() -> SettingsPage {
                 files: USER,
             }),
             SettingsPageItem::SettingItem(SettingItem {
-                title: "On Paste",
+                title: "Auto Open Files On Paste",
                 description: "Whether to automatically open files after pasting or duplicating them.",
                 field: Box::new(SettingField {
                     json_path: Some("project_panel.auto_open.on_paste"),
@@ -4990,7 +4984,7 @@ fn panels_page() -> SettingsPage {
                 files: USER,
             }),
             SettingsPageItem::SettingItem(SettingItem {
-                title: "On Drop",
+                title: "Auto Open Files On Drop",
                 description: "Whether to automatically open files dropped from external sources.",
                 field: Box::new(SettingField {
                     json_path: Some("project_panel.auto_open.on_drop"),
@@ -5016,20 +5010,20 @@ fn panels_page() -> SettingsPage {
                 files: USER,
             }),
             SettingsPageItem::SettingItem(SettingItem {
-                title: "Sort Mode",
-                description: "Sort order for entries in the project panel.",
-                field: Box::new(SettingField {
-                    pick: |settings_content| {
-                        settings_content.project_panel.as_ref()?.sort_mode.as_ref()
-                    },
-                    write: |settings_content, value| {
-                        settings_content
-                            .project_panel
-                            .get_or_insert_default()
-                            .sort_mode = value;
-                    },
-                    json_path: Some("project_panel.sort_mode"),
-                }),
+                title: "Hidden Files",
+                description: "Globs to match files that will be considered \"hidden\" and can be hidden from the project panel.",
+                field: Box::new(
+                    SettingField {
+                        json_path: Some("worktree.hidden_files"),
+                        pick: |settings_content| {
+                            settings_content.project.worktree.hidden_files.as_ref()
+                        },
+                        write: |settings_content, value| {
+                            settings_content.project.worktree.hidden_files = value;
+                        },
+                    }
+                    .unimplemented(),
+                ),
                 metadata: None,
                 files: USER,
             }),
@@ -5807,7 +5801,6 @@ fn panels_page() -> SettingsPage {
         title: "Panels",
         items: concat_sections![
             project_panel_section(),
-            auto_open_files_section(),
             terminal_panel_section(),
             outline_panel_section(),
             git_panel_section(),

crates/settings_ui/src/settings_ui.rs 🔗

@@ -3753,7 +3753,6 @@ fn all_projects(
                 .flat_map(|multi_workspace| {
                     multi_workspace
                         .workspaces()
-                        .iter()
                         .map(|workspace| workspace.read(cx).project().clone())
                         .collect::<Vec<_>>()
                 }),

crates/sidebar/src/sidebar.rs 🔗

@@ -434,7 +434,7 @@ impl Sidebar {
         })
         .detach();
 
-        let workspaces = multi_workspace.read(cx).workspaces().to_vec();
+        let workspaces: Vec<_> = multi_workspace.read(cx).workspaces().cloned().collect();
         cx.defer_in(window, move |this, window, cx| {
             for workspace in &workspaces {
                 this.subscribe_to_workspace(workspace, window, cx);
@@ -673,7 +673,6 @@ impl Sidebar {
         let mw = self.multi_workspace.upgrade()?;
         let mw = mw.read(cx);
         mw.workspaces()
-            .iter()
             .find(|ws| ws.read(cx).project_group_key(cx).path_list() == path_list)
             .cloned()
     }
@@ -716,8 +715,8 @@ impl Sidebar {
             return;
         };
         let mw = multi_workspace.read(cx);
-        let workspaces = mw.workspaces().to_vec();
-        let active_workspace = mw.workspaces().get(mw.active_workspace_index()).cloned();
+        let workspaces: Vec<_> = mw.workspaces().cloned().collect();
+        let active_workspace = Some(mw.workspace().clone());
 
         let agent_server_store = workspaces
             .first()
@@ -1769,7 +1768,11 @@ impl Sidebar {
         dispatch_context.add("ThreadsSidebar");
         dispatch_context.add("menu");
 
-        let identifier = if self.filter_editor.focus_handle(cx).is_focused(window) {
+        let is_archived_search_focused = matches!(&self.view, SidebarView::Archive(archive) if archive.read(cx).is_filter_editor_focused(window, cx));
+
+        let identifier = if self.filter_editor.focus_handle(cx).is_focused(window)
+            || is_archived_search_focused
+        {
             "searching"
         } else {
             "not_searching"
@@ -1989,7 +1992,6 @@ impl Sidebar {
                 let workspace = window.read(cx).ok().and_then(|multi_workspace| {
                     multi_workspace
                         .workspaces()
-                        .iter()
                         .find(|workspace| predicate(workspace, cx))
                         .cloned()
                 })?;
@@ -2006,7 +2008,6 @@ impl Sidebar {
             multi_workspace
                 .read(cx)
                 .workspaces()
-                .iter()
                 .find(|workspace| predicate(workspace, cx))
                 .cloned()
         })
@@ -2199,12 +2200,10 @@ impl Sidebar {
             return;
         }
 
-        let active_workspace = self.multi_workspace.upgrade().and_then(|w| {
-            w.read(cx)
-                .workspaces()
-                .get(w.read(cx).active_workspace_index())
-                .cloned()
-        });
+        let active_workspace = self
+            .multi_workspace
+            .upgrade()
+            .map(|w| w.read(cx).workspace().clone());
 
         if let Some(workspace) = active_workspace {
             self.activate_thread_locally(&metadata, &workspace, window, cx);
@@ -2339,7 +2338,7 @@ impl Sidebar {
             return;
         };
 
-        let workspaces = multi_workspace.read(cx).workspaces().to_vec();
+        let workspaces: Vec<_> = multi_workspace.read(cx).workspaces().cloned().collect();
         for workspace in workspaces {
             if let Some(agent_panel) = workspace.read(cx).panel::<AgentPanel>(cx) {
                 let cancelled =
@@ -2932,7 +2931,6 @@ impl Sidebar {
             .map(|mw| {
                 mw.read(cx)
                     .workspaces()
-                    .iter()
                     .filter_map(|ws| ws.read(cx).database_id())
                     .collect()
             })
@@ -3400,12 +3398,9 @@ impl Sidebar {
     }
 
     fn active_workspace(&self, cx: &App) -> Option<Entity<Workspace>> {
-        self.multi_workspace.upgrade().and_then(|w| {
-            w.read(cx)
-                .workspaces()
-                .get(w.read(cx).active_workspace_index())
-                .cloned()
-        })
+        self.multi_workspace
+            .upgrade()
+            .map(|w| w.read(cx).workspace().clone())
     }
 
     fn show_thread_import_modal(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -3513,12 +3508,11 @@ impl Sidebar {
     }
 
     fn show_archive(&mut self, window: &mut Window, cx: &mut Context<Self>) {
-        let Some(active_workspace) = self.multi_workspace.upgrade().and_then(|w| {
-            w.read(cx)
-                .workspaces()
-                .get(w.read(cx).active_workspace_index())
-                .cloned()
-        }) else {
+        let Some(active_workspace) = self
+            .multi_workspace
+            .upgrade()
+            .map(|w| w.read(cx).workspace().clone())
+        else {
             return;
         };
         let Some(agent_panel) = active_workspace.read(cx).panel::<AgentPanel>(cx) else {
@@ -3820,12 +3814,12 @@ pub fn dump_workspace_info(
 
     let multi_workspace = workspace.multi_workspace().and_then(|weak| weak.upgrade());
     let workspaces: Vec<gpui::Entity<Workspace>> = match &multi_workspace {
-        Some(mw) => mw.read(cx).workspaces().to_vec(),
+        Some(mw) => mw.read(cx).workspaces().cloned().collect(),
         None => vec![this_entity.clone()],
     };
-    let active_index = multi_workspace
+    let active_workspace = multi_workspace
         .as_ref()
-        .map(|mw| mw.read(cx).active_workspace_index());
+        .map(|mw| mw.read(cx).workspace().clone());
 
     writeln!(output, "MultiWorkspace: {} workspace(s)", workspaces.len()).ok();
 
@@ -3837,13 +3831,10 @@ pub fn dump_workspace_info(
         }
     }
 
-    if let Some(index) = active_index {
-        writeln!(output, "Active workspace index: {index}").ok();
-    }
     writeln!(output).ok();
 
     for (index, ws) in workspaces.iter().enumerate() {
-        let is_active = active_index == Some(index);
+        let is_active = active_workspace.as_ref() == Some(ws);
         writeln!(
             output,
             "--- Workspace {index}{} ---",

crates/sidebar/src/sidebar_tests.rs 🔗

@@ -77,6 +77,18 @@ async fn init_test_project(
 fn setup_sidebar(
     multi_workspace: &Entity<MultiWorkspace>,
     cx: &mut gpui::VisualTestContext,
+) -> Entity<Sidebar> {
+    let sidebar = setup_sidebar_closed(multi_workspace, cx);
+    multi_workspace.update_in(cx, |mw, window, cx| {
+        mw.toggle_sidebar(window, cx);
+    });
+    cx.run_until_parked();
+    sidebar
+}
+
+fn setup_sidebar_closed(
+    multi_workspace: &Entity<MultiWorkspace>,
+    cx: &mut gpui::VisualTestContext,
 ) -> Entity<Sidebar> {
     let multi_workspace = multi_workspace.clone();
     let sidebar =
@@ -172,16 +184,7 @@ fn save_thread_metadata(
     cx.run_until_parked();
 }
 
-fn open_and_focus_sidebar(sidebar: &Entity<Sidebar>, cx: &mut gpui::VisualTestContext) {
-    let multi_workspace = sidebar.read_with(cx, |s, _| s.multi_workspace.upgrade());
-    if let Some(multi_workspace) = multi_workspace {
-        multi_workspace.update_in(cx, |mw, window, cx| {
-            if !mw.sidebar_open() {
-                mw.toggle_sidebar(window, cx);
-            }
-        });
-    }
-    cx.run_until_parked();
+fn focus_sidebar(sidebar: &Entity<Sidebar>, cx: &mut gpui::VisualTestContext) {
     sidebar.update_in(cx, |_, window, cx| {
         cx.focus_self(window);
     });
@@ -544,7 +547,7 @@ async fn test_workspace_lifecycle(cx: &mut TestAppContext) {
 
     // Remove the second workspace
     multi_workspace.update_in(cx, |mw, window, cx| {
-        let workspace = mw.workspaces()[1].clone();
+        let workspace = mw.workspaces().nth(1).cloned().unwrap();
         mw.remove(&workspace, window, cx);
     });
     cx.run_until_parked();
@@ -604,7 +607,7 @@ async fn test_view_more_batched_expansion(cx: &mut TestAppContext) {
     assert!(entries.iter().any(|e| e.contains("View More")));
 
     // Focus and navigate to View More, then confirm to expand by one batch
-    open_and_focus_sidebar(&sidebar, cx);
+    focus_sidebar(&sidebar, cx);
     for _ in 0..7 {
         cx.dispatch_action(SelectNext);
     }
@@ -915,7 +918,7 @@ async fn test_keyboard_select_next_and_previous(cx: &mut TestAppContext) {
     // Entries: [header, thread3, thread2, thread1]
     // Focusing the sidebar does not set a selection; select_next/select_previous
     // handle None gracefully by starting from the first or last entry.
-    open_and_focus_sidebar(&sidebar, cx);
+    focus_sidebar(&sidebar, cx);
     assert_eq!(sidebar.read_with(cx, |s, _| s.selection), None);
 
     // First SelectNext from None starts at index 0
@@ -970,7 +973,7 @@ async fn test_keyboard_select_first_and_last(cx: &mut TestAppContext) {
     multi_workspace.update_in(cx, |_, _window, cx| cx.notify());
     cx.run_until_parked();
 
-    open_and_focus_sidebar(&sidebar, cx);
+    focus_sidebar(&sidebar, cx);
 
     // SelectLast jumps to the end
     cx.dispatch_action(SelectLast);
@@ -993,7 +996,7 @@ async fn test_keyboard_focus_in_does_not_set_selection(cx: &mut TestAppContext)
 
     // Open the sidebar so it's rendered, then focus it to trigger focus_in.
     // focus_in no longer sets a default selection.
-    open_and_focus_sidebar(&sidebar, cx);
+    focus_sidebar(&sidebar, cx);
     assert_eq!(sidebar.read_with(cx, |s, _| s.selection), None);
 
     // Manually set a selection, blur, then refocus — selection should be preserved
@@ -1030,7 +1033,7 @@ async fn test_keyboard_confirm_on_project_header_toggles_collapse(cx: &mut TestA
     );
 
     // Focus the sidebar and select the header (index 0)
-    open_and_focus_sidebar(&sidebar, cx);
+    focus_sidebar(&sidebar, cx);
     sidebar.update_in(cx, |sidebar, _window, _cx| {
         sidebar.selection = Some(0);
     });
@@ -1071,7 +1074,7 @@ async fn test_keyboard_confirm_on_view_more_expands(cx: &mut TestAppContext) {
     assert!(entries.iter().any(|e| e.contains("View More")));
 
     // Focus sidebar (selection starts at None), then navigate down to the "View More" entry (index 6)
-    open_and_focus_sidebar(&sidebar, cx);
+    focus_sidebar(&sidebar, cx);
     for _ in 0..7 {
         cx.dispatch_action(SelectNext);
     }
@@ -1105,7 +1108,7 @@ async fn test_keyboard_expand_and_collapse_selected_entry(cx: &mut TestAppContex
     );
 
     // Focus sidebar and manually select the header (index 0). Press left to collapse.
-    open_and_focus_sidebar(&sidebar, cx);
+    focus_sidebar(&sidebar, cx);
     sidebar.update_in(cx, |sidebar, _window, _cx| {
         sidebar.selection = Some(0);
     });
@@ -1144,7 +1147,7 @@ async fn test_keyboard_collapse_from_child_selects_parent(cx: &mut TestAppContex
     cx.run_until_parked();
 
     // Focus sidebar (selection starts at None), then navigate down to the thread (child)
-    open_and_focus_sidebar(&sidebar, cx);
+    focus_sidebar(&sidebar, cx);
     cx.dispatch_action(SelectNext);
     cx.dispatch_action(SelectNext);
     assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(1));
@@ -1179,7 +1182,7 @@ async fn test_keyboard_navigation_on_empty_list(cx: &mut TestAppContext) {
     );
 
     // Focus sidebar — focus_in does not set a selection
-    open_and_focus_sidebar(&sidebar, cx);
+    focus_sidebar(&sidebar, cx);
     assert_eq!(sidebar.read_with(cx, |s, _| s.selection), None);
 
     // First SelectNext from None starts at index 0 (header)
@@ -1211,7 +1214,7 @@ async fn test_selection_clamps_after_entry_removal(cx: &mut TestAppContext) {
     cx.run_until_parked();
 
     // Focus sidebar (selection starts at None), navigate down to the thread (index 1)
-    open_and_focus_sidebar(&sidebar, cx);
+    focus_sidebar(&sidebar, cx);
     cx.dispatch_action(SelectNext);
     cx.dispatch_action(SelectNext);
     assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(1));
@@ -1492,7 +1495,7 @@ async fn test_escape_clears_search_and_restores_full_list(cx: &mut TestAppContex
     );
 
     // User types a search query to filter down.
-    open_and_focus_sidebar(&sidebar, cx);
+    focus_sidebar(&sidebar, cx);
     type_in_search(&sidebar, "alpha", cx);
     assert_eq!(
         visible_entries_as_strings(&sidebar, cx),
@@ -1540,8 +1543,9 @@ async fn test_search_only_shows_workspace_headers_with_matches(cx: &mut TestAppC
     });
     cx.run_until_parked();
 
-    let project_b =
-        multi_workspace.read_with(cx, |mw, cx| mw.workspaces()[1].read(cx).project().clone());
+    let project_b = multi_workspace.read_with(cx, |mw, cx| {
+        mw.workspaces().nth(1).unwrap().read(cx).project().clone()
+    });
 
     for (id, title, hour) in [
         ("b1", "Refactor sidebar layout", 3),
@@ -1621,8 +1625,9 @@ async fn test_search_matches_workspace_name(cx: &mut TestAppContext) {
     });
     cx.run_until_parked();
 
-    let project_b =
-        multi_workspace.read_with(cx, |mw, cx| mw.workspaces()[1].read(cx).project().clone());
+    let project_b = multi_workspace.read_with(cx, |mw, cx| {
+        mw.workspaces().nth(1).unwrap().read(cx).project().clone()
+    });
 
     for (id, title, hour) in [
         ("b1", "Refactor sidebar layout", 3),
@@ -1764,7 +1769,7 @@ async fn test_search_finds_threads_inside_collapsed_groups(cx: &mut TestAppConte
 
     // User focuses the sidebar and collapses the group using keyboard:
     // manually select the header, then press SelectParent to collapse.
-    open_and_focus_sidebar(&sidebar, cx);
+    focus_sidebar(&sidebar, cx);
     sidebar.update_in(cx, |sidebar, _window, _cx| {
         sidebar.selection = Some(0);
     });
@@ -1807,7 +1812,7 @@ async fn test_search_then_keyboard_navigate_and_confirm(cx: &mut TestAppContext)
     }
     cx.run_until_parked();
 
-    open_and_focus_sidebar(&sidebar, cx);
+    focus_sidebar(&sidebar, cx);
 
     // User types "fix" — two threads match.
     type_in_search(&sidebar, "fix", cx);
@@ -1856,6 +1861,13 @@ async fn test_confirm_on_historical_thread_activates_workspace(cx: &mut TestAppC
     });
     cx.run_until_parked();
 
+    let (workspace_0, workspace_1) = multi_workspace.read_with(cx, |mw, _| {
+        (
+            mw.workspaces().next().unwrap().clone(),
+            mw.workspaces().nth(1).unwrap().clone(),
+        )
+    });
+
     save_thread_metadata(
         acp::SessionId::new(Arc::from("hist-1")),
         "Historical Thread".into(),
@@ -1875,13 +1887,13 @@ async fn test_confirm_on_historical_thread_activates_workspace(cx: &mut TestAppC
 
     // Switch to workspace 1 so we can verify the confirm switches back.
     multi_workspace.update_in(cx, |mw, window, cx| {
-        let workspace = mw.workspaces()[1].clone();
+        let workspace = mw.workspaces().nth(1).unwrap().clone();
         mw.activate(workspace, window, cx);
     });
     cx.run_until_parked();
     assert_eq!(
-        multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()),
-        1
+        multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()),
+        workspace_1
     );
 
     // Confirm on the historical (non-live) thread at index 1.
@@ -1895,8 +1907,8 @@ async fn test_confirm_on_historical_thread_activates_workspace(cx: &mut TestAppC
     cx.run_until_parked();
 
     assert_eq!(
-        multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()),
-        0
+        multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()),
+        workspace_0
     );
 }
 
@@ -2037,7 +2049,8 @@ async fn test_focused_thread_tracks_user_intent(cx: &mut TestAppContext) {
     let panel_b = add_agent_panel(&workspace_b, cx);
     cx.run_until_parked();
 
-    let workspace_a = multi_workspace.read_with(cx, |mw, _cx| mw.workspaces()[0].clone());
+    let workspace_a =
+        multi_workspace.read_with(cx, |mw, _cx| mw.workspaces().next().unwrap().clone());
 
     // ── 1. Initial state: focused thread derived from active panel ─────
     sidebar.read_with(cx, |sidebar, _cx| {
@@ -2135,7 +2148,7 @@ async fn test_focused_thread_tracks_user_intent(cx: &mut TestAppContext) {
     });
 
     multi_workspace.update_in(cx, |mw, window, cx| {
-        let workspace = mw.workspaces()[0].clone();
+        let workspace = mw.workspaces().next().unwrap().clone();
         mw.activate(workspace, window, cx);
     });
     cx.run_until_parked();
@@ -2190,8 +2203,8 @@ async fn test_focused_thread_tracks_user_intent(cx: &mut TestAppContext) {
     // Switching workspaces via the multi_workspace (simulates clicking
     // a workspace header) should clear focused_thread.
     multi_workspace.update_in(cx, |mw, window, cx| {
-        if let Some(index) = mw.workspaces().iter().position(|w| w == &workspace_b) {
-            let workspace = mw.workspaces()[index].clone();
+        let workspace = mw.workspaces().find(|w| *w == &workspace_b).cloned();
+        if let Some(workspace) = workspace {
             mw.activate(workspace, window, cx);
         }
     });
@@ -2477,6 +2490,8 @@ async fn test_cmd_n_shows_new_thread_entry_in_absorbed_worktree(cx: &mut TestApp
     let (multi_workspace, cx) =
         cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx));
 
+    let sidebar = setup_sidebar(&multi_workspace, cx);
+
     let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| {
         mw.test_add_workspace(worktree_project.clone(), window, cx)
     });
@@ -2485,12 +2500,10 @@ async fn test_cmd_n_shows_new_thread_entry_in_absorbed_worktree(cx: &mut TestApp
 
     // Switch to the worktree workspace.
     multi_workspace.update_in(cx, |mw, window, cx| {
-        let workspace = mw.workspaces()[1].clone();
+        let workspace = mw.workspaces().nth(1).unwrap().clone();
         mw.activate(workspace, window, cx);
     });
 
-    let sidebar = setup_sidebar(&multi_workspace, cx);
-
     // Create a non-empty thread in the worktree workspace.
     let connection = StubAgentConnection::new();
     connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk(
@@ -3027,6 +3040,8 @@ async fn test_absorbed_worktree_running_thread_shows_live_status(cx: &mut TestAp
     let (multi_workspace, cx) =
         cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx));
 
+    let sidebar = setup_sidebar(&multi_workspace, cx);
+
     let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| {
         mw.test_add_workspace(worktree_project.clone(), window, cx)
     });
@@ -3037,12 +3052,10 @@ async fn test_absorbed_worktree_running_thread_shows_live_status(cx: &mut TestAp
 
     // Switch back to the main workspace before setting up the sidebar.
     multi_workspace.update_in(cx, |mw, window, cx| {
-        let workspace = mw.workspaces()[0].clone();
+        let workspace = mw.workspaces().next().unwrap().clone();
         mw.activate(workspace, window, cx);
     });
 
-    let sidebar = setup_sidebar(&multi_workspace, cx);
-
     // Start a thread in the worktree workspace's panel and keep it
     // generating (don't resolve it).
     let connection = StubAgentConnection::new();
@@ -3127,6 +3140,8 @@ async fn test_absorbed_worktree_completion_triggers_notification(cx: &mut TestAp
     let (multi_workspace, cx) =
         cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx));
 
+    let sidebar = setup_sidebar(&multi_workspace, cx);
+
     let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| {
         mw.test_add_workspace(worktree_project.clone(), window, cx)
     });
@@ -3134,12 +3149,10 @@ async fn test_absorbed_worktree_completion_triggers_notification(cx: &mut TestAp
     let worktree_panel = add_agent_panel(&worktree_workspace, cx);
 
     multi_workspace.update_in(cx, |mw, window, cx| {
-        let workspace = mw.workspaces()[0].clone();
+        let workspace = mw.workspaces().next().unwrap().clone();
         mw.activate(workspace, window, cx);
     });
 
-    let sidebar = setup_sidebar(&multi_workspace, cx);
-
     let connection = StubAgentConnection::new();
     open_thread_with_connection(&worktree_panel, connection.clone(), cx);
     send_message(&worktree_panel, cx);
@@ -3231,12 +3244,12 @@ async fn test_clicking_worktree_thread_opens_workspace_when_none_exists(cx: &mut
 
     // Only 1 workspace should exist.
     assert_eq!(
-        multi_workspace.read_with(cx, |mw, _| mw.workspaces().len()),
+        multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
         1,
     );
 
     // Focus the sidebar and select the worktree thread.
-    open_and_focus_sidebar(&sidebar, cx);
+    focus_sidebar(&sidebar, cx);
     sidebar.update_in(cx, |sidebar, _window, _cx| {
         sidebar.selection = Some(1); // index 0 is header, 1 is the thread
     });
@@ -3248,11 +3261,11 @@ async fn test_clicking_worktree_thread_opens_workspace_when_none_exists(cx: &mut
     // A new workspace should have been created for the worktree path.
     let new_workspace = multi_workspace.read_with(cx, |mw, _| {
         assert_eq!(
-            mw.workspaces().len(),
+            mw.workspaces().count(),
             2,
             "confirming a worktree thread without a workspace should open one",
         );
-        mw.workspaces()[1].clone()
+        mw.workspaces().nth(1).unwrap().clone()
     });
 
     let new_path_list =
@@ -3318,7 +3331,7 @@ async fn test_clicking_worktree_thread_does_not_briefly_render_as_separate_proje
         vec!["v [project]", "  WT Thread {wt-feature-a}"],
     );
 
-    open_and_focus_sidebar(&sidebar, cx);
+    focus_sidebar(&sidebar, cx);
     sidebar.update_in(cx, |sidebar, _window, _cx| {
         sidebar.selection = Some(1); // index 0 is header, 1 is the thread
     });
@@ -3444,18 +3457,19 @@ async fn test_clicking_absorbed_worktree_thread_activates_worktree_workspace(
     let (multi_workspace, cx) =
         cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx));
 
+    let sidebar = setup_sidebar(&multi_workspace, cx);
+
     let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| {
         mw.test_add_workspace(worktree_project.clone(), window, cx)
     });
 
     // Activate the main workspace before setting up the sidebar.
-    multi_workspace.update_in(cx, |mw, window, cx| {
-        let workspace = mw.workspaces()[0].clone();
-        mw.activate(workspace, window, cx);
+    let main_workspace = multi_workspace.update_in(cx, |mw, window, cx| {
+        let workspace = mw.workspaces().next().unwrap().clone();
+        mw.activate(workspace.clone(), window, cx);
+        workspace
     });
 
-    let sidebar = setup_sidebar(&multi_workspace, cx);
-
     save_named_thread_metadata("thread-main", "Main Thread", &main_project, cx).await;
     save_named_thread_metadata("thread-wt", "WT Thread", &worktree_project, cx).await;
 
@@ -3475,13 +3489,13 @@ async fn test_clicking_absorbed_worktree_thread_activates_worktree_workspace(
         .expect("should find the worktree thread entry");
 
     assert_eq!(
-        multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()),
-        0,
+        multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()),
+        main_workspace,
         "main workspace should be active initially"
     );
 
     // Focus the sidebar and select the absorbed worktree thread.
-    open_and_focus_sidebar(&sidebar, cx);
+    focus_sidebar(&sidebar, cx);
     sidebar.update_in(cx, |sidebar, _window, _cx| {
         sidebar.selection = Some(wt_thread_index);
     });
@@ -3491,9 +3505,7 @@ async fn test_clicking_absorbed_worktree_thread_activates_worktree_workspace(
     cx.run_until_parked();
 
     // The worktree workspace should now be active, not the main one.
-    let active_workspace = multi_workspace.read_with(cx, |mw, _| {
-        mw.workspaces()[mw.active_workspace_index()].clone()
-    });
+    let active_workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
     assert_eq!(
         active_workspace, worktree_workspace,
         "clicking an absorbed worktree thread should activate the worktree workspace"
@@ -3520,25 +3532,27 @@ async fn test_activate_archived_thread_with_saved_paths_activates_matching_works
     let (multi_workspace, cx) =
         cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a.clone(), window, cx));
 
-    multi_workspace.update_in(cx, |mw, window, cx| {
-        mw.test_add_workspace(project_b.clone(), window, cx);
-    });
-
     let sidebar = setup_sidebar(&multi_workspace, cx);
 
+    let workspace_b = multi_workspace.update_in(cx, |mw, window, cx| {
+        mw.test_add_workspace(project_b.clone(), window, cx)
+    });
+    let workspace_a =
+        multi_workspace.read_with(cx, |mw, _| mw.workspaces().next().unwrap().clone());
+
     // Save a thread with path_list pointing to project-b.
     let session_id = acp::SessionId::new(Arc::from("archived-1"));
     save_test_thread_metadata(&session_id, &project_b, cx).await;
 
     // Ensure workspace A is active.
     multi_workspace.update_in(cx, |mw, window, cx| {
-        let workspace = mw.workspaces()[0].clone();
+        let workspace = mw.workspaces().next().unwrap().clone();
         mw.activate(workspace, window, cx);
     });
     cx.run_until_parked();
     assert_eq!(
-        multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()),
-        0
+        multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()),
+        workspace_a
     );
 
     // Call activate_archived_thread – should resolve saved paths and
@@ -3562,8 +3576,8 @@ async fn test_activate_archived_thread_with_saved_paths_activates_matching_works
     cx.run_until_parked();
 
     assert_eq!(
-        multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()),
-        1,
+        multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()),
+        workspace_b,
         "should have activated the workspace matching the saved path_list"
     );
 }
@@ -3588,21 +3602,23 @@ async fn test_activate_archived_thread_cwd_fallback_with_matching_workspace(
     let (multi_workspace, cx) =
         cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx));
 
-    multi_workspace.update_in(cx, |mw, window, cx| {
-        mw.test_add_workspace(project_b, window, cx);
-    });
-
     let sidebar = setup_sidebar(&multi_workspace, cx);
 
+    let workspace_b = multi_workspace.update_in(cx, |mw, window, cx| {
+        mw.test_add_workspace(project_b, window, cx)
+    });
+    let workspace_a =
+        multi_workspace.read_with(cx, |mw, _| mw.workspaces().next().unwrap().clone());
+
     // Start with workspace A active.
     multi_workspace.update_in(cx, |mw, window, cx| {
-        let workspace = mw.workspaces()[0].clone();
+        let workspace = mw.workspaces().next().unwrap().clone();
         mw.activate(workspace, window, cx);
     });
     cx.run_until_parked();
     assert_eq!(
-        multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()),
-        0
+        multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()),
+        workspace_a
     );
 
     // No thread saved to the store – cwd is the only path hint.
@@ -3625,8 +3641,8 @@ async fn test_activate_archived_thread_cwd_fallback_with_matching_workspace(
     cx.run_until_parked();
 
     assert_eq!(
-        multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()),
-        1,
+        multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()),
+        workspace_b,
         "should have activated the workspace matching the cwd"
     );
 }
@@ -3651,21 +3667,21 @@ async fn test_activate_archived_thread_no_paths_no_cwd_uses_active_workspace(
     let (multi_workspace, cx) =
         cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx));
 
-    multi_workspace.update_in(cx, |mw, window, cx| {
-        mw.test_add_workspace(project_b, window, cx);
-    });
-
     let sidebar = setup_sidebar(&multi_workspace, cx);
 
+    let workspace_b = multi_workspace.update_in(cx, |mw, window, cx| {
+        mw.test_add_workspace(project_b, window, cx)
+    });
+
     // Activate workspace B (index 1) to make it the active one.
     multi_workspace.update_in(cx, |mw, window, cx| {
-        let workspace = mw.workspaces()[1].clone();
+        let workspace = mw.workspaces().nth(1).unwrap().clone();
         mw.activate(workspace, window, cx);
     });
     cx.run_until_parked();
     assert_eq!(
-        multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()),
-        1
+        multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()),
+        workspace_b
     );
 
     // No saved thread, no cwd – should fall back to the active workspace.
@@ -3688,8 +3704,8 @@ async fn test_activate_archived_thread_no_paths_no_cwd_uses_active_workspace(
     cx.run_until_parked();
 
     assert_eq!(
-        multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()),
-        1,
+        multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()),
+        workspace_b,
         "should have stayed on the active workspace when no path info is available"
     );
 }
@@ -3719,7 +3735,7 @@ async fn test_activate_archived_thread_saved_paths_opens_new_workspace(cx: &mut
     let session_id = acp::SessionId::new(Arc::from("archived-new-ws"));
 
     assert_eq!(
-        multi_workspace.read_with(cx, |mw, _| mw.workspaces().len()),
+        multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
         1,
         "should start with one workspace"
     );
@@ -3743,7 +3759,7 @@ async fn test_activate_archived_thread_saved_paths_opens_new_workspace(cx: &mut
     cx.run_until_parked();
 
     assert_eq!(
-        multi_workspace.read_with(cx, |mw, _| mw.workspaces().len()),
+        multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
         2,
         "should have opened a second workspace for the archived thread's saved paths"
     );
@@ -3768,6 +3784,10 @@ async fn test_activate_archived_thread_reuses_workspace_in_another_window(cx: &m
         cx.add_window(|window, cx| MultiWorkspace::test_new(project_b, window, cx));
 
     let multi_workspace_a_entity = multi_workspace_a.root(cx).unwrap();
+    let multi_workspace_b_entity = multi_workspace_b.root(cx).unwrap();
+
+    let cx_b = &mut gpui::VisualTestContext::from_window(multi_workspace_b.into(), cx);
+    let _sidebar_b = setup_sidebar(&multi_workspace_b_entity, cx_b);
 
     let cx_a = &mut gpui::VisualTestContext::from_window(multi_workspace_a.into(), cx);
     let sidebar = setup_sidebar(&multi_workspace_a_entity, cx_a);
@@ -3794,14 +3814,14 @@ async fn test_activate_archived_thread_reuses_workspace_in_another_window(cx: &m
 
     assert_eq!(
         multi_workspace_a
-            .read_with(cx_a, |mw, _| mw.workspaces().len())
+            .read_with(cx_a, |mw, _| mw.workspaces().count())
             .unwrap(),
         1,
         "should not add the other window's workspace into the current window"
     );
     assert_eq!(
         multi_workspace_b
-            .read_with(cx_a, |mw, _| mw.workspaces().len())
+            .read_with(cx_a, |mw, _| mw.workspaces().count())
             .unwrap(),
         1,
         "should reuse the existing workspace in the other window"
@@ -3871,14 +3891,14 @@ async fn test_activate_archived_thread_reuses_workspace_in_another_window_with_t
 
     assert_eq!(
         multi_workspace_a
-            .read_with(cx_a, |mw, _| mw.workspaces().len())
+            .read_with(cx_a, |mw, _| mw.workspaces().count())
             .unwrap(),
         1,
         "should not add the other window's workspace into the current window"
     );
     assert_eq!(
         multi_workspace_b
-            .read_with(cx_a, |mw, _| mw.workspaces().len())
+            .read_with(cx_a, |mw, _| mw.workspaces().count())
             .unwrap(),
         1,
         "should reuse the existing workspace in the other window"
@@ -3921,6 +3941,10 @@ async fn test_activate_archived_thread_prefers_current_window_for_matching_paths
         cx.add_window(|window, cx| MultiWorkspace::test_new(project_a, window, cx));
 
     let multi_workspace_a_entity = multi_workspace_a.root(cx).unwrap();
+    let multi_workspace_b_entity = multi_workspace_b.root(cx).unwrap();
+
+    let cx_b = &mut gpui::VisualTestContext::from_window(multi_workspace_b.into(), cx);
+    let _sidebar_b = setup_sidebar(&multi_workspace_b_entity, cx_b);
 
     let cx_a = &mut gpui::VisualTestContext::from_window(multi_workspace_a.into(), cx);
     let sidebar_a = setup_sidebar(&multi_workspace_a_entity, cx_a);
@@ -3958,14 +3982,14 @@ async fn test_activate_archived_thread_prefers_current_window_for_matching_paths
     });
     assert_eq!(
         multi_workspace_a
-            .read_with(cx_a, |mw, _| mw.workspaces().len())
+            .read_with(cx_a, |mw, _| mw.workspaces().count())
             .unwrap(),
         1,
         "current window should continue reusing its existing workspace"
     );
     assert_eq!(
         multi_workspace_b
-            .read_with(cx_a, |mw, _| mw.workspaces().len())
+            .read_with(cx_a, |mw, _| mw.workspaces().count())
             .unwrap(),
         1,
         "other windows should not be activated just because they also match the saved paths"
@@ -4029,19 +4053,20 @@ async fn test_archive_thread_uses_next_threads_own_workspace(cx: &mut TestAppCon
     let (multi_workspace, cx) =
         cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx));
 
+    let sidebar = setup_sidebar(&multi_workspace, cx);
+
     let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| {
         mw.test_add_workspace(worktree_project.clone(), window, cx)
     });
 
     // Activate main workspace so the sidebar tracks the main panel.
     multi_workspace.update_in(cx, |mw, window, cx| {
-        let workspace = mw.workspaces()[0].clone();
+        let workspace = mw.workspaces().next().unwrap().clone();
         mw.activate(workspace, window, cx);
     });
 
-    let sidebar = setup_sidebar(&multi_workspace, cx);
-
-    let main_workspace = multi_workspace.read_with(cx, |mw, _| mw.workspaces()[0].clone());
+    let main_workspace =
+        multi_workspace.read_with(cx, |mw, _| mw.workspaces().next().unwrap().clone());
     let main_panel = add_agent_panel(&main_workspace, cx);
     let _worktree_panel = add_agent_panel(&worktree_workspace, cx);
 
@@ -4195,10 +4220,10 @@ async fn test_linked_worktree_threads_not_duplicated_across_groups(cx: &mut Test
 
     let (multi_workspace, cx) =
         cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_only.clone(), window, cx));
+    let sidebar = setup_sidebar(&multi_workspace, cx);
     multi_workspace.update_in(cx, |mw, window, cx| {
         mw.test_add_workspace(multi_root.clone(), window, cx);
     });
-    let sidebar = setup_sidebar(&multi_workspace, cx);
 
     // Save a thread under the linked worktree path.
     save_named_thread_metadata("wt-thread", "Worktree Thread", &worktree_project, cx).await;
@@ -4313,8 +4338,8 @@ async fn test_thread_switcher_ordering(cx: &mut TestAppContext) {
     // so all three have last_accessed_at set.
     // Access order is: A (most recent), B, C (oldest).
 
-    // ── 1. Open switcher: threads sorted by last_accessed_at ───────────
-    open_and_focus_sidebar(&sidebar, cx);
+    // ── 1. Open switcher: threads sorted by last_accessed_at ─────────────────
+    focus_sidebar(&sidebar, cx);
     sidebar.update_in(cx, |sidebar, window, cx| {
         sidebar.on_toggle_thread_switcher(&ToggleThreadSwitcher::default(), window, cx);
     });
@@ -4759,6 +4784,170 @@ async fn test_linked_worktree_workspace_shows_main_worktree_threads(cx: &mut Tes
     );
 }
 
+async fn init_multi_project_test(
+    paths: &[&str],
+    cx: &mut TestAppContext,
+) -> (Arc<FakeFs>, Entity<project::Project>) { // shared setup: one fake tree per path; the project opens paths[0]
+    agent_ui::test_support::init_test(cx);
+    cx.update(|cx| {
+        cx.update_flags(false, vec!["agent-v2".into()]); // turn on the "agent-v2" feature flag for these tests
+        ThreadStore::init_global(cx);
+        ThreadMetadataStore::init_global(cx);
+        language_model::LanguageModelRegistry::test(cx);
+        prompt_store::init(cx);
+    });
+    let fs = FakeFs::new(cx.executor());
+    for path in paths {
+        fs.insert_tree(path, serde_json::json!({ ".git": {}, "src": {} })) // minimal repo-like layout per project root
+            .await;
+    }
+    cx.update(|cx| <dyn fs::Fs>::set_global(fs.clone(), cx));
+    let project =
+        project::Project::test(fs.clone() as Arc<dyn fs::Fs>, [paths[0].as_ref()], cx).await; // only the first path becomes the initial project
+    (fs, project)
+}
+
+async fn add_test_project(
+    path: &str,
+    fs: &Arc<FakeFs>,
+    multi_workspace: &Entity<MultiWorkspace>,
+    cx: &mut gpui::VisualTestContext,
+) -> Entity<Workspace> { // builds a project for `path` on the shared fake fs and adds it as a workspace
+    let project = project::Project::test(fs.clone() as Arc<dyn fs::Fs>, [path.as_ref()], cx).await;
+    let workspace = multi_workspace.update_in(cx, |mw, window, cx| {
+        mw.test_add_workspace(project, window, cx)
+    });
+    cx.run_until_parked(); // let workspace-added effects settle before callers assert on state
+    workspace
+}
+
+#[gpui::test]
+async fn test_transient_workspace_lifecycle(cx: &mut TestAppContext) {
+    let (fs, project_a) =
+        init_multi_project_test(&["/project-a", "/project-b", "/project-c"], cx).await;
+    let (multi_workspace, cx) =
+        cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx));
+    let _sidebar = setup_sidebar_closed(&multi_workspace, cx); // sidebar exists but stays closed throughout
+
+    // Sidebar starts closed. Initial workspace A is transient.
+    let workspace_a = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()); // snapshot the initial active workspace
+    assert!(!multi_workspace.read_with(cx, |mw, _| mw.sidebar_open()));
+    assert_eq!(
+        multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+        1
+    );
+    assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_a));
+
+    // Add B — replaces A as the transient workspace.
+    let workspace_b = add_test_project("/project-b", &fs, &multi_workspace, cx).await;
+    assert_eq!(
+        multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+        1 // count stays at 1: A was dropped, not retained
+    );
+    assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_b));
+
+    // Add C — replaces B as the transient workspace.
+    let workspace_c = add_test_project("/project-c", &fs, &multi_workspace, cx).await;
+    assert_eq!(
+        multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+        1
+    );
+    assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_c));
+}
+
+#[gpui::test]
+async fn test_transient_workspace_retained(cx: &mut TestAppContext) {
+    let (fs, project_a) = init_multi_project_test(
+        &["/project-a", "/project-b", "/project-c", "/project-d"],
+        cx,
+    )
+    .await;
+    let (multi_workspace, cx) =
+        cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx));
+    let _sidebar = setup_sidebar(&multi_workspace, cx);
+    assert!(multi_workspace.read_with(cx, |mw, _| mw.sidebar_open()));
+
+    // Add B — retained since sidebar is open.
+    let _workspace_b = add_test_project("/project-b", &fs, &multi_workspace, cx).await;
+    assert_eq!(
+        multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+        2
+    );
+    // Switch back to A (the first workspace) — B survives. (Retained-to-retained switch.)
+    let workspace_a = multi_workspace.read_with(cx, |mw, _| mw.workspaces().next().unwrap().clone());
+    multi_workspace.update_in(cx, |mw, window, cx| mw.activate(workspace_a, window, cx));
+    cx.run_until_parked();
+    assert_eq!(
+        multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+        2
+    );
+
+    // Close sidebar — both A and B remain retained.
+    multi_workspace.update_in(cx, |mw, window, cx| mw.close_sidebar(window, cx));
+    cx.run_until_parked();
+    assert_eq!(
+        multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+        2
+    );
+
+    // Add C — added as a new transient workspace (retained-to-transient switch).
+    let workspace_c = add_test_project("/project-c", &fs, &multi_workspace, cx).await;
+    assert_eq!(
+        multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+        3
+    );
+    assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_c));
+
+    // Add D — replaces C (with retained and transient present, only the transient is dropped).
+    let workspace_d = add_test_project("/project-d", &fs, &multi_workspace, cx).await;
+    assert_eq!(
+        multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+        3
+    );
+    assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_d));
+}
+
+#[gpui::test]
+async fn test_transient_workspace_promotion(cx: &mut TestAppContext) {
+    let (fs, project_a) =
+        init_multi_project_test(&["/project-a", "/project-b", "/project-c"], cx).await;
+    let (multi_workspace, cx) =
+        cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx));
+    setup_sidebar_closed(&multi_workspace, cx);
+
+    // Add B — replaces A as the transient workspace (A is discarded).
+    let workspace_b = add_test_project("/project-b", &fs, &multi_workspace, cx).await;
+    assert_eq!(
+        multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+        1
+    );
+    assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_b));
+
+    // Open sidebar — promotes the transient B to retained.
+    multi_workspace.update_in(cx, |mw, window, cx| {
+        mw.toggle_sidebar(window, cx);
+    });
+    cx.run_until_parked();
+    assert_eq!(
+        multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+        1
+    );
+    assert!(multi_workspace.read_with(cx, |mw, _| mw.workspaces().any(|w| w == &workspace_b)));
+
+    // Close sidebar — the retained B remains.
+    multi_workspace.update_in(cx, |mw, window, cx| {
+        mw.toggle_sidebar(window, cx);
+    });
+    cx.run_until_parked(); // settle the sidebar-close effects before adding C (matches the other toggles)
+    // Add C — added as a new transient workspace.
+    let workspace_c = add_test_project("/project-c", &fs, &multi_workspace, cx).await;
+    assert_eq!(
+        multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+        2
+    );
+    assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_c));
+}
+
 #[gpui::test]
 async fn test_legacy_thread_with_canonical_path_opens_main_repo_workspace(cx: &mut TestAppContext) {
     init_test(cx);
@@ -4843,12 +5032,12 @@ async fn test_legacy_thread_with_canonical_path_opens_main_repo_workspace(cx: &m
 
     // Verify only 1 workspace before clicking.
     assert_eq!(
-        multi_workspace.read_with(cx, |mw, _| mw.workspaces().len()),
+        multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
         1,
     );
 
     // Focus and select the legacy thread, then confirm.
-    open_and_focus_sidebar(&sidebar, cx);
+    focus_sidebar(&sidebar, cx);
     let thread_index = sidebar.read_with(cx, |sidebar, _| {
         sidebar
             .contents
@@ -5057,7 +5246,12 @@ mod property_test {
         match operation {
             Operation::SaveThread { workspace_index } => {
                 let project = multi_workspace.read_with(cx, |mw, cx| {
-                    mw.workspaces()[workspace_index].read(cx).project().clone()
+                    mw.workspaces()
+                        .nth(workspace_index)
+                        .unwrap()
+                        .read(cx)
+                        .project()
+                        .clone()
                 });
                 save_thread_to_path(state, &project, cx);
             }
@@ -5144,7 +5338,7 @@ mod property_test {
             }
             Operation::RemoveWorkspace { index } => {
                 let removed = multi_workspace.update_in(cx, |mw, window, cx| {
-                    let workspace = mw.workspaces()[index].clone();
+                    let workspace = mw.workspaces().nth(index).unwrap().clone();
                     mw.remove(&workspace, window, cx)
                 });
                 if removed {
@@ -5158,8 +5352,8 @@ mod property_test {
                 }
             }
             Operation::SwitchWorkspace { index } => {
-                let workspace =
-                    multi_workspace.read_with(cx, |mw, _| mw.workspaces()[index].clone());
+                let workspace = multi_workspace
+                    .read_with(cx, |mw, _| mw.workspaces().nth(index).unwrap().clone());
                 multi_workspace.update_in(cx, |mw, window, cx| {
                     mw.activate(workspace, window, cx);
                 });
@@ -5209,8 +5403,9 @@ mod property_test {
                     .await;
 
                 // Re-scan the main workspace's project so it discovers the new worktree.
-                let main_workspace =
-                    multi_workspace.read_with(cx, |mw, _| mw.workspaces()[workspace_index].clone());
+                let main_workspace = multi_workspace.read_with(cx, |mw, _| {
+                    mw.workspaces().nth(workspace_index).unwrap().clone()
+                });
                 let main_project = main_workspace.read_with(cx, |ws, _| ws.project().clone());
                 main_project
                     .update(cx, |p, cx| p.git_scans_complete(cx))
@@ -5297,7 +5492,11 @@ mod property_test {
         let Some(multi_workspace) = sidebar.multi_workspace.upgrade() else {
             anyhow::bail!("sidebar should still have an associated multi-workspace");
         };
-        let workspaces = multi_workspace.read(cx).workspaces().to_vec();
+        let workspaces = multi_workspace
+            .read(cx)
+            .workspaces()
+            .cloned()
+            .collect::<Vec<_>>();
         let thread_store = ThreadMetadataStore::global(cx);
 
         let sidebar_thread_ids: HashSet<acp::SessionId> = sidebar

crates/title_bar/src/title_bar.rs 🔗

@@ -740,7 +740,6 @@ impl TitleBar {
             .map(|mw| {
                 mw.read(cx)
                     .workspaces()
-                    .iter()
                     .filter_map(|ws| ws.read(cx).database_id())
                     .collect()
             })
@@ -803,7 +802,6 @@ impl TitleBar {
             .map(|mw| {
                 mw.read(cx)
                     .workspaces()
-                    .iter()
                     .filter_map(|ws| ws.read(cx).database_id())
                     .collect()
             })

crates/workspace/src/multi_workspace.rs 🔗

@@ -6,9 +6,7 @@ use gpui::{
     ManagedView, MouseButton, Pixels, Render, Subscription, Task, Tiling, Window, WindowId,
     actions, deferred, px,
 };
-#[cfg(any(test, feature = "test-support"))]
-use project::Project;
-use project::{DirectoryLister, DisableAiSettings, ProjectGroupKey};
+use project::{DirectoryLister, DisableAiSettings, Project, ProjectGroupKey};
 use settings::Settings;
 pub use settings::SidebarSide;
 use std::future::Future;
@@ -42,10 +40,7 @@ actions!(
         CloseWorkspaceSidebar,
         /// Moves focus to or from the workspace sidebar without closing it.
         FocusWorkspaceSidebar,
-        /// Switches to the next workspace.
-        NextWorkspace,
-        /// Switches to the previous workspace.
-        PreviousWorkspace,
+        // TODO: Restore next/previous workspace
     ]
 );
 
@@ -223,10 +218,57 @@ impl<T: Sidebar> SidebarHandle for Entity<T> {
     }
 }
 
+/// Tracks which workspace the user is currently looking at.
+///
+/// `Persistent` workspaces live in the `workspaces` vec and are shown in the
+/// sidebar. `Transient` workspaces exist outside the vec and are discarded
+/// when the user switches away.
+enum ActiveWorkspace {
+    /// A persistent workspace, identified by index into the `workspaces` vec.
+    Persistent(usize),
+    /// A workspace not in the `workspaces` vec that will be discarded on
+    /// switch or promoted to persistent when the sidebar is opened.
+    Transient(Entity<Workspace>),
+}
+
+impl ActiveWorkspace {
+    fn persistent_index(&self) -> Option<usize> {
+        match self {
+            Self::Persistent(index) => Some(*index),
+            Self::Transient(_) => None,
+        }
+    }
+
+    fn transient_workspace(&self) -> Option<&Entity<Workspace>> {
+        match self {
+            Self::Transient(workspace) => Some(workspace),
+            Self::Persistent(_) => None,
+        }
+    }
+
+    /// Sets the active workspace to transient, returning the previous
+    /// transient workspace (if any).
+    fn set_transient(&mut self, workspace: Entity<Workspace>) -> Option<Entity<Workspace>> {
+        match std::mem::replace(self, Self::Transient(workspace)) {
+            Self::Transient(old) => Some(old),
+            Self::Persistent(_) => None,
+        }
+    }
+
+    /// Sets the active workspace to persistent at the given index,
+    /// returning the previous transient workspace (if any).
+    fn set_persistent(&mut self, index: usize) -> Option<Entity<Workspace>> {
+        match std::mem::replace(self, Self::Persistent(index)) {
+            Self::Transient(workspace) => Some(workspace),
+            Self::Persistent(_) => None,
+        }
+    }
+}
+
 pub struct MultiWorkspace {
     window_id: WindowId,
     workspaces: Vec<Entity<Workspace>>,
-    active_workspace_index: usize,
+    active_workspace: ActiveWorkspace,
     project_group_keys: Vec<ProjectGroupKey>,
     sidebar: Option<Box<dyn SidebarHandle>>,
     sidebar_open: bool,
@@ -262,12 +304,15 @@ impl MultiWorkspace {
             }
         });
         let quit_subscription = cx.on_app_quit(Self::app_will_quit);
-        let settings_subscription =
-            cx.observe_global_in::<settings::SettingsStore>(window, |this, window, cx| {
-                if DisableAiSettings::get_global(cx).disable_ai && this.sidebar_open {
-                    this.close_sidebar(window, cx);
+        let settings_subscription = cx.observe_global_in::<settings::SettingsStore>(window, {
+            let mut previous_disable_ai = DisableAiSettings::get_global(cx).disable_ai;
+            move |this, window, cx| {
+                if DisableAiSettings::get_global(cx).disable_ai != previous_disable_ai {
+                    this.collapse_to_single_workspace(window, cx);
+                    previous_disable_ai = DisableAiSettings::get_global(cx).disable_ai;
                 }
-            });
+            }
+        });
         Self::subscribe_to_workspace(&workspace, window, cx);
         let weak_self = cx.weak_entity();
         workspace.update(cx, |workspace, cx| {
@@ -275,9 +320,9 @@ impl MultiWorkspace {
         });
         Self {
             window_id: window.window_handle().window_id(),
-            project_group_keys: vec![workspace.read(cx).project_group_key(cx)],
-            workspaces: vec![workspace],
-            active_workspace_index: 0,
+            project_group_keys: Vec::new(),
+            workspaces: Vec::new(),
+            active_workspace: ActiveWorkspace::Transient(workspace),
             sidebar: None,
             sidebar_open: false,
             sidebar_overlay: None,
@@ -339,7 +384,7 @@ impl MultiWorkspace {
             return;
         }
 
-        if self.sidebar_open {
+        if self.sidebar_open() {
             self.close_sidebar(window, cx);
         } else {
             self.open_sidebar(cx);
@@ -355,7 +400,7 @@ impl MultiWorkspace {
             return;
         }
 
-        if self.sidebar_open {
+        if self.sidebar_open() {
             self.close_sidebar(window, cx);
         }
     }
@@ -365,7 +410,7 @@ impl MultiWorkspace {
             return;
         }
 
-        if self.sidebar_open {
+        if self.sidebar_open() {
             let sidebar_is_focused = self
                 .sidebar
                 .as_ref()
@@ -390,8 +435,13 @@ impl MultiWorkspace {
 
     pub fn open_sidebar(&mut self, cx: &mut Context<Self>) {
         self.sidebar_open = true;
+        if let ActiveWorkspace::Transient(workspace) = &self.active_workspace {
+            let workspace = workspace.clone();
+            let index = self.promote_transient(workspace, cx);
+            self.active_workspace = ActiveWorkspace::Persistent(index);
+        }
         let sidebar_focus_handle = self.sidebar.as_ref().map(|s| s.focus_handle(cx));
-        for workspace in &self.workspaces {
+        for workspace in self.workspaces.iter() {
             workspace.update(cx, |workspace, _cx| {
                 workspace.set_sidebar_focus_handle(sidebar_focus_handle.clone());
             });
@@ -402,7 +452,7 @@ impl MultiWorkspace {
 
     pub fn close_sidebar(&mut self, window: &mut Window, cx: &mut Context<Self>) {
         self.sidebar_open = false;
-        for workspace in &self.workspaces {
+        for workspace in self.workspaces.iter() {
             workspace.update(cx, |workspace, _cx| {
                 workspace.set_sidebar_focus_handle(None);
             });
@@ -417,7 +467,7 @@ impl MultiWorkspace {
     pub fn close_window(&mut self, _: &CloseWindow, window: &mut Window, cx: &mut Context<Self>) {
         cx.spawn_in(window, async move |this, cx| {
             let workspaces = this.update(cx, |multi_workspace, _cx| {
-                multi_workspace.workspaces().to_vec()
+                multi_workspace.workspaces().cloned().collect::<Vec<_>>()
             })?;
 
             for workspace in workspaces {
@@ -468,6 +518,9 @@ impl MultiWorkspace {
     }
 
     pub fn add_project_group_key(&mut self, project_group_key: ProjectGroupKey) {
+        if project_group_key.path_list().paths().is_empty() {
+            return;
+        }
         if self.project_group_keys.contains(&project_group_key) {
             return;
         }
@@ -656,6 +709,12 @@ impl MultiWorkspace {
             return Task::ready(Ok(workspace));
         }
 
+        if let Some(transient) = self.active_workspace.transient_workspace() {
+            if transient.read(cx).project_group_key(cx).path_list() == &path_list {
+                return Task::ready(Ok(transient.clone()));
+            }
+        }
+
         let paths = path_list.paths().to_vec();
         let app_state = self.workspace().read(cx).app_state().clone();
         let requesting_window = window.window_handle().downcast::<MultiWorkspace>();
@@ -679,25 +738,23 @@ impl MultiWorkspace {
     }
 
     pub fn workspace(&self) -> &Entity<Workspace> {
-        &self.workspaces[self.active_workspace_index]
-    }
-
-    pub fn workspaces(&self) -> &[Entity<Workspace>] {
-        &self.workspaces
+        match &self.active_workspace {
+            ActiveWorkspace::Persistent(index) => &self.workspaces[*index],
+            ActiveWorkspace::Transient(workspace) => workspace,
+        }
     }
 
-    pub fn active_workspace_index(&self) -> usize {
-        self.active_workspace_index
+    pub fn workspaces(&self) -> impl Iterator<Item = &Entity<Workspace>> {
+        self.workspaces
+            .iter()
+            .chain(self.active_workspace.transient_workspace())
     }
 
-    /// Adds a workspace to this window without changing which workspace is
-    /// active.
+    /// Adds a workspace to this window as persistent without changing which
+    /// workspace is active. Unlike `activate()`, this always inserts into the
+    /// persistent list regardless of sidebar state — it's used for system-
+    /// initiated additions like deserialization and worktree discovery.
     pub fn add(&mut self, workspace: Entity<Workspace>, window: &Window, cx: &mut Context<Self>) {
-        if !self.multi_workspace_enabled(cx) {
-            self.set_single_workspace(workspace, cx);
-            return;
-        }
-
         self.insert_workspace(workspace, window, cx);
     }
 
@@ -708,26 +765,74 @@ impl MultiWorkspace {
         window: &mut Window,
         cx: &mut Context<Self>,
     ) {
-        if !self.multi_workspace_enabled(cx) {
-            self.set_single_workspace(workspace, cx);
+        // Re-activating the current workspace is a no-op.
+        if self.workspace() == &workspace {
+            self.focus_active_workspace(window, cx);
             return;
         }
 
-        let index = self.insert_workspace(workspace, &*window, cx);
-        let changed = self.active_workspace_index != index;
-        self.active_workspace_index = index;
-        if changed {
-            cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged);
-            self.serialize(cx);
+        // Resolve the destination: an existing persistent index, a fresh insertion (sidebar open), or none (remain transient).
+        let new_index = if let Some(index) = self.workspaces.iter().position(|w| *w == workspace) {
+            Some(index)
+        } else if self.sidebar_open {
+            Some(self.insert_workspace(workspace.clone(), &*window, cx))
+        } else {
+            None
+        };
+
+        // Transition the active workspace.
+        if let Some(index) = new_index {
+            if let Some(old) = self.active_workspace.set_persistent(index) {
+                if self.sidebar_open {
+                    self.promote_transient(old, cx);
+                } else {
+                    self.detach_workspace(&old, cx);
+                    cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(old.entity_id()));
+                }
+            }
+        } else {
+            Self::subscribe_to_workspace(&workspace, window, cx);
+            let weak_self = cx.weak_entity();
+            workspace.update(cx, |workspace, cx| {
+                workspace.set_multi_workspace(weak_self, cx);
+            });
+            if let Some(old) = self.active_workspace.set_transient(workspace) {
+                self.detach_workspace(&old, cx);
+                cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(old.entity_id()));
+            }
         }
+
+        cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged);
+        self.serialize(cx);
         self.focus_active_workspace(window, cx);
         cx.notify();
     }
 
-    fn set_single_workspace(&mut self, workspace: Entity<Workspace>, cx: &mut Context<Self>) {
-        self.workspaces[0] = workspace;
-        self.active_workspace_index = 0;
-        cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged);
+    /// Promotes a former transient workspace into the persistent list.
+    /// Returns the index of the newly inserted workspace.
+    fn promote_transient(&mut self, workspace: Entity<Workspace>, cx: &mut Context<Self>) -> usize {
+        let project_group_key = workspace.read(cx).project().read(cx).project_group_key(cx);
+        self.add_project_group_key(project_group_key);
+        self.workspaces.push(workspace.clone());
+        cx.emit(MultiWorkspaceEvent::WorkspaceAdded(workspace));
+        self.workspaces.len() - 1
+    }
+
+    /// Collapses to a single transient workspace, discarding all persistent
+    /// workspaces. Invoked whenever the `disable_ai` setting toggles (in either direction).
+    fn collapse_to_single_workspace(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+        if self.sidebar_open {
+            self.close_sidebar(window, cx);
+        }
+        let active = self.workspace().clone();
+        for workspace in std::mem::take(&mut self.workspaces) {
+            if workspace != active {
+                self.detach_workspace(&workspace, cx);
+                cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(workspace.entity_id()));
+            }
+        }
+        self.project_group_keys.clear();
+        self.active_workspace = ActiveWorkspace::Transient(active);
         cx.notify();
     }
 
@@ -783,7 +888,7 @@ impl MultiWorkspace {
     }
 
     fn sync_sidebar_to_workspace(&self, workspace: &Entity<Workspace>, cx: &mut Context<Self>) {
-        if self.sidebar_open {
+        if self.sidebar_open() {
             let sidebar_focus_handle = self.sidebar.as_ref().map(|s| s.focus_handle(cx));
             workspace.update(cx, |workspace, _| {
                 workspace.set_sidebar_focus_handle(sidebar_focus_handle);
@@ -791,30 +896,6 @@ impl MultiWorkspace {
         }
     }
 
-    fn cycle_workspace(&mut self, delta: isize, window: &mut Window, cx: &mut Context<Self>) {
-        let count = self.workspaces.len() as isize;
-        if count <= 1 {
-            return;
-        }
-        let current = self.active_workspace_index as isize;
-        let next = ((current + delta).rem_euclid(count)) as usize;
-        let workspace = self.workspaces[next].clone();
-        self.activate(workspace, window, cx);
-    }
-
-    fn next_workspace(&mut self, _: &NextWorkspace, window: &mut Window, cx: &mut Context<Self>) {
-        self.cycle_workspace(1, window, cx);
-    }
-
-    fn previous_workspace(
-        &mut self,
-        _: &PreviousWorkspace,
-        window: &mut Window,
-        cx: &mut Context<Self>,
-    ) {
-        self.cycle_workspace(-1, window, cx);
-    }
-
     pub(crate) fn serialize(&mut self, cx: &mut Context<Self>) {
         self._serialize_task = Some(cx.spawn(async move |this, cx| {
             let Some((window_id, state)) = this
@@ -1040,26 +1121,82 @@ impl MultiWorkspace {
         let Some(index) = self.workspaces.iter().position(|w| w == workspace) else {
             return false;
         };
+
+        let old_key = workspace.read(cx).project_group_key(cx);
+
         if self.workspaces.len() <= 1 {
-            return false;
-        }
+            let has_worktrees = workspace.read(cx).visible_worktrees(cx).next().is_some();
 
-        let removed_workspace = self.workspaces.remove(index);
+            if !has_worktrees {
+                return false;
+            }
 
-        if self.active_workspace_index >= self.workspaces.len() {
-            self.active_workspace_index = self.workspaces.len() - 1;
-        } else if self.active_workspace_index > index {
-            self.active_workspace_index -= 1;
+            let old_workspace = workspace.clone();
+            let old_entity_id = old_workspace.entity_id();
+
+            let app_state = old_workspace.read(cx).app_state().clone();
+
+            let project = Project::local(
+                app_state.client.clone(),
+                app_state.node_runtime.clone(),
+                app_state.user_store.clone(),
+                app_state.languages.clone(),
+                app_state.fs.clone(),
+                None,
+                project::LocalProjectFlags::default(),
+                cx,
+            );
+
+            let new_workspace = cx.new(|cx| Workspace::new(None, project, app_state, window, cx));
+
+            self.workspaces[0] = new_workspace.clone();
+            self.active_workspace = ActiveWorkspace::Persistent(0);
+
+            Self::subscribe_to_workspace(&new_workspace, window, cx);
+
+            self.sync_sidebar_to_workspace(&new_workspace, cx);
+
+            let weak_self = cx.weak_entity();
+
+            new_workspace.update(cx, |workspace, cx| {
+                workspace.set_multi_workspace(weak_self, cx);
+            });
+
+            self.detach_workspace(&old_workspace, cx);
+
+            cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(old_entity_id));
+            cx.emit(MultiWorkspaceEvent::WorkspaceAdded(new_workspace));
+            cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged);
+        } else {
+            let removed_workspace = self.workspaces.remove(index);
+
+            if let Some(active_index) = self.active_workspace.persistent_index() {
+                if active_index >= self.workspaces.len() {
+                    self.active_workspace = ActiveWorkspace::Persistent(self.workspaces.len() - 1);
+                } else if active_index > index {
+                    self.active_workspace = ActiveWorkspace::Persistent(active_index - 1);
+                }
+            }
+
+            self.detach_workspace(&removed_workspace, cx);
+
+            cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(
+                removed_workspace.entity_id(),
+            ));
+            cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged);
         }
 
-        self.detach_workspace(&removed_workspace, cx);
+        let key_still_in_use = self
+            .workspaces
+            .iter()
+            .any(|ws| ws.read(cx).project_group_key(cx) == old_key);
+
+        if !key_still_in_use {
+            self.project_group_keys.retain(|k| k != &old_key);
+        }
 
         self.serialize(cx);
         self.focus_active_workspace(window, cx);
-        cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(
-            removed_workspace.entity_id(),
-        ));
-        cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged);
         cx.notify();
 
         true
@@ -1288,8 +1425,6 @@ impl Render for MultiWorkspace {
                             this.focus_sidebar(window, cx);
                         },
                     ))
-                    .on_action(cx.listener(Self::next_workspace))
-                    .on_action(cx.listener(Self::previous_workspace))
                     .on_action(cx.listener(Self::move_active_workspace_to_new_window))
                     .on_action(cx.listener(
                         |this: &mut Self, action: &ToggleThreadSwitcher, window, cx| {

crates/workspace/src/multi_workspace_tests.rs 🔗

@@ -99,6 +99,10 @@ async fn test_project_group_keys_initial(cx: &mut TestAppContext) {
     let (multi_workspace, cx) =
         cx.add_window_view(|window, cx| MultiWorkspace::test_new(project, window, cx));
 
+    multi_workspace.update(cx, |mw, cx| {
+        mw.open_sidebar(cx);
+    });
+
     multi_workspace.read_with(cx, |mw, _cx| {
         let keys: Vec<&ProjectGroupKey> = mw.project_group_keys().collect();
         assert_eq!(keys.len(), 1, "should have exactly one key on creation");
@@ -125,6 +129,10 @@ async fn test_project_group_keys_add_workspace(cx: &mut TestAppContext) {
     let (multi_workspace, cx) =
         cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx));
 
+    multi_workspace.update(cx, |mw, cx| {
+        mw.open_sidebar(cx);
+    });
+
     multi_workspace.read_with(cx, |mw, _cx| {
         assert_eq!(mw.project_group_keys().count(), 1);
     });
@@ -162,6 +170,10 @@ async fn test_project_group_keys_duplicate_not_added(cx: &mut TestAppContext) {
     let (multi_workspace, cx) =
         cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx));
 
+    multi_workspace.update(cx, |mw, cx| {
+        mw.open_sidebar(cx);
+    });
+
     multi_workspace.update_in(cx, |mw, window, cx| {
         mw.test_add_workspace(project_a2, window, cx);
     });
@@ -189,6 +201,10 @@ async fn test_project_group_keys_on_worktree_added(cx: &mut TestAppContext) {
     let (multi_workspace, cx) =
         cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
 
+    multi_workspace.update(cx, |mw, cx| {
+        mw.open_sidebar(cx);
+    });
+
     // Add a second worktree to the same project.
     let (worktree, _) = project
         .update(cx, |project, cx| {
@@ -232,6 +248,10 @@ async fn test_project_group_keys_on_worktree_removed(cx: &mut TestAppContext) {
     let (multi_workspace, cx) =
         cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
 
+    multi_workspace.update(cx, |mw, cx| {
+        mw.open_sidebar(cx);
+    });
+
     // Remove one worktree.
     let worktree_b_id = project.read_with(cx, |project, cx| {
         project
@@ -282,6 +302,10 @@ async fn test_project_group_keys_across_multiple_workspaces_and_worktree_changes
     let (multi_workspace, cx) =
         cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a.clone(), window, cx));
 
+    multi_workspace.update(cx, |mw, cx| {
+        mw.open_sidebar(cx);
+    });
+
     multi_workspace.update_in(cx, |mw, window, cx| {
         mw.test_add_workspace(project_b, window, cx);
     });

crates/workspace/src/pane.rs 🔗

@@ -3670,6 +3670,11 @@ impl Pane {
                 this.drag_split_direction = None;
                 this.handle_external_paths_drop(paths, window, cx)
             }))
+            .on_click(cx.listener(move |this, event: &ClickEvent, window, cx| {
+                if event.click_count() == 2 {
+                    window.dispatch_action(this.double_click_dispatch_action.boxed_clone(), cx);
+                }
+            }))
     }
 
     pub fn render_menu_overlay(menu: &Entity<ContextMenu>) -> Div {
@@ -4917,14 +4922,17 @@ impl Render for DraggedTab {
 
 #[cfg(test)]
 mod tests {
-    use std::{cell::Cell, iter::zip, num::NonZero};
+    use std::{cell::Cell, iter::zip, num::NonZero, rc::Rc};
 
     use super::*;
     use crate::{
         Member,
         item::test::{TestItem, TestProjectItem},
     };
-    use gpui::{AppContext, Axis, TestAppContext, VisualTestContext, size};
+    use gpui::{
+        AppContext, Axis, Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent,
+        TestAppContext, VisualTestContext, size,
+    };
     use project::FakeFs;
     use settings::SettingsStore;
     use theme::LoadThemes;
@@ -6649,8 +6657,6 @@ mod tests {
 
     #[gpui::test]
     async fn test_drag_tab_to_middle_tab_with_mouse_events(cx: &mut TestAppContext) {
-        use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent};
-
         init_test(cx);
         let fs = FakeFs::new(cx.executor());
 
@@ -6702,8 +6708,6 @@ mod tests {
     async fn test_drag_pinned_tab_when_show_pinned_tabs_in_separate_row_enabled(
         cx: &mut TestAppContext,
     ) {
-        use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent};
-
         init_test(cx);
         set_pinned_tabs_separate_row(cx, true);
         let fs = FakeFs::new(cx.executor());
@@ -6779,8 +6783,6 @@ mod tests {
     async fn test_drag_unpinned_tab_when_show_pinned_tabs_in_separate_row_enabled(
         cx: &mut TestAppContext,
     ) {
-        use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent};
-
         init_test(cx);
         set_pinned_tabs_separate_row(cx, true);
         let fs = FakeFs::new(cx.executor());
@@ -6833,8 +6835,6 @@ mod tests {
     async fn test_drag_mixed_tabs_when_show_pinned_tabs_in_separate_row_enabled(
         cx: &mut TestAppContext,
     ) {
-        use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent};
-
         init_test(cx);
         set_pinned_tabs_separate_row(cx, true);
         let fs = FakeFs::new(cx.executor());
@@ -6900,8 +6900,6 @@ mod tests {
 
     #[gpui::test]
     async fn test_middle_click_pinned_tab_does_not_close(cx: &mut TestAppContext) {
-        use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseUpEvent};
-
         init_test(cx);
         let fs = FakeFs::new(cx.executor());
 
@@ -6971,6 +6969,74 @@ mod tests {
         assert_item_labels(&pane, ["A*!"], cx);
     }
 
+    #[gpui::test]
+    async fn test_double_click_pinned_tab_bar_empty_space_creates_new_tab(cx: &mut TestAppContext) {
+        init_test(cx);
+        let fs = FakeFs::new(cx.executor());
+
+        let project = Project::test(fs, None, cx).await;
+        let (workspace, cx) =
+            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+        let pane = workspace.read_with(cx, |workspace, _| workspace.active_pane().clone());
+
+        // The real NewFile handler lives in editor::init, which isn't initialized
+        // in workspace tests. Register a global action handler that sets a flag so
+        // we can verify the action is dispatched without depending on the editor crate.
+        // TODO: If editor::init is ever available in workspace tests, remove this
+        // flag and assert the resulting tab bar state directly instead.
+        let new_file_dispatched = Rc::new(Cell::new(false));
+        cx.update(|_, cx| {
+            let new_file_dispatched = new_file_dispatched.clone();
+            cx.on_action(move |_: &NewFile, _cx| {
+                new_file_dispatched.set(true);
+            });
+        });
+
+        set_pinned_tabs_separate_row(cx, true);
+
+        let item_a = add_labeled_item(&pane, "A", false, cx);
+        add_labeled_item(&pane, "B", false, cx);
+
+        pane.update_in(cx, |pane, window, cx| {
+            let ix = pane
+                .index_for_item_id(item_a.item_id())
+                .expect("item A should exist");
+            pane.pin_tab_at(ix, window, cx);
+        });
+        assert_item_labels(&pane, ["A!", "B*"], cx);
+        cx.run_until_parked();
+
+        let pinned_drop_target_bounds = cx
+            .debug_bounds("pinned_tabs_border")
+            .expect("pinned_tabs_border should have debug bounds");
+
+        cx.simulate_event(MouseDownEvent {
+            position: pinned_drop_target_bounds.center(),
+            button: MouseButton::Left,
+            modifiers: Modifiers::default(),
+            click_count: 2,
+            first_mouse: false,
+        });
+
+        cx.run_until_parked();
+
+        cx.simulate_event(MouseUpEvent {
+            position: pinned_drop_target_bounds.center(),
+            button: MouseButton::Left,
+            modifiers: Modifiers::default(),
+            click_count: 2,
+        });
+
+        cx.run_until_parked();
+
+        // TODO: If editor::init is ever available in workspace tests, replace this
+        // with an assert_item_labels check that verifies a new tab is actually created.
+        assert!(
+            new_file_dispatched.get(),
+            "Double-clicking pinned tab bar empty space should dispatch the new file action"
+        );
+    }
+
     #[gpui::test]
     async fn test_add_item_with_new_item(cx: &mut TestAppContext) {
         init_test(cx);

crates/workspace/src/persistence.rs 🔗

@@ -2535,6 +2535,10 @@ mod tests {
         let (multi_workspace, cx) =
             cx.add_window_view(|window, cx| MultiWorkspace::test_new(project1.clone(), window, cx));
 
+        multi_workspace.update(cx, |mw, cx| {
+            mw.open_sidebar(cx);
+        });
+
         multi_workspace.update_in(cx, |mw, _, cx| {
             mw.set_random_database_id(cx);
         });
@@ -2564,7 +2568,7 @@ mod tests {
 
         // --- Remove the second workspace (index 1) ---
         multi_workspace.update_in(cx, |mw, window, cx| {
-            let ws = mw.workspaces()[1].clone();
+            let ws = mw.workspaces().nth(1).unwrap().clone();
             mw.remove(&ws, window, cx);
         });
 
@@ -4191,6 +4195,10 @@ mod tests {
         let (multi_workspace, cx) =
             cx.add_window_view(|window, cx| MultiWorkspace::test_new(project1.clone(), window, cx));
 
+        multi_workspace.update(cx, |mw, cx| {
+            mw.open_sidebar(cx);
+        });
+
         multi_workspace.update_in(cx, |mw, _, cx| {
             mw.set_random_database_id(cx);
         });
@@ -4233,7 +4241,7 @@ mod tests {
 
         // Remove workspace at index 1 (the second workspace).
         multi_workspace.update_in(cx, |mw, window, cx| {
-            let ws = mw.workspaces()[1].clone();
+            let ws = mw.workspaces().nth(1).unwrap().clone();
             mw.remove(&ws, window, cx);
         });
 
@@ -4288,6 +4296,10 @@ mod tests {
         let (multi_workspace, cx) =
             cx.add_window_view(|window, cx| MultiWorkspace::test_new(project1.clone(), window, cx));
 
+        multi_workspace.update(cx, |mw, cx| {
+            mw.open_sidebar(cx);
+        });
+
         multi_workspace.update_in(cx, |mw, _, cx| {
             mw.workspace().update(cx, |ws, _cx| {
                 ws.set_database_id(ws1_id);
@@ -4339,7 +4351,7 @@ mod tests {
 
         // Remove workspace2 (index 1).
         multi_workspace.update_in(cx, |mw, window, cx| {
-            let ws = mw.workspaces()[1].clone();
+            let ws = mw.workspaces().nth(1).unwrap().clone();
             mw.remove(&ws, window, cx);
         });
 
@@ -4385,6 +4397,10 @@ mod tests {
         let (multi_workspace, cx) =
             cx.add_window_view(|window, cx| MultiWorkspace::test_new(project1.clone(), window, cx));
 
+        multi_workspace.update(cx, |mw, cx| {
+            mw.open_sidebar(cx);
+        });
+
         multi_workspace.update_in(cx, |mw, _, cx| {
             mw.set_random_database_id(cx);
         });
@@ -4418,7 +4434,7 @@ mod tests {
 
         // Remove workspace2 — this pushes a task to pending_removal_tasks.
         multi_workspace.update_in(cx, |mw, window, cx| {
-            let ws = mw.workspaces()[1].clone();
+            let ws = mw.workspaces().nth(1).unwrap().clone();
             mw.remove(&ws, window, cx);
         });
 
@@ -4427,7 +4443,6 @@ mod tests {
         let all_tasks = multi_workspace.update_in(cx, |mw, window, cx| {
             let mut tasks: Vec<Task<()>> = mw
                 .workspaces()
-                .iter()
                 .map(|workspace| {
                     workspace.update(cx, |workspace, cx| {
                         workspace.flush_serialization(window, cx)
@@ -4747,6 +4762,10 @@ mod tests {
         let (multi_workspace, cx) = cx
             .add_window_view(|window, cx| MultiWorkspace::test_new(project_2.clone(), window, cx));
 
+        multi_workspace.update(cx, |mw, cx| {
+            mw.open_sidebar(cx);
+        });
+
         multi_workspace.update_in(cx, |mw, window, cx| {
             mw.test_add_workspace(project_1.clone(), window, cx);
         });

crates/workspace/src/workspace.rs 🔗

@@ -32,8 +32,8 @@ pub use crate::notifications::NotificationFrame;
 pub use dock::Panel;
 pub use multi_workspace::{
     CloseWorkspaceSidebar, DraggedSidebar, FocusWorkspaceSidebar, MultiWorkspace,
-    MultiWorkspaceEvent, NextWorkspace, PreviousWorkspace, Sidebar, SidebarEvent, SidebarHandle,
-    SidebarRenderState, SidebarSide, ToggleWorkspaceSidebar, sidebar_side_context_menu,
+    MultiWorkspaceEvent, Sidebar, SidebarEvent, SidebarHandle, SidebarRenderState, SidebarSide,
+    ToggleWorkspaceSidebar, sidebar_side_context_menu,
 };
 pub use path_list::{PathList, SerializedPathList};
 pub use toast_layer::{ToastAction, ToastLayer, ToastView};
@@ -9079,7 +9079,7 @@ pub fn workspace_windows_for_location(
             };
 
             multi_workspace.read(cx).is_ok_and(|multi_workspace| {
-                multi_workspace.workspaces().iter().any(|workspace| {
+                multi_workspace.workspaces().any(|workspace| {
                     match workspace.read(cx).workspace_location(cx) {
                         WorkspaceLocation::Location(location, _) => {
                             match (&location, serialized_location) {
@@ -10741,6 +10741,12 @@ mod tests {
             cx.add_window(|window, cx| MultiWorkspace::test_new(project_a.clone(), window, cx));
         cx.run_until_parked();
 
+        multi_workspace_handle
+            .update(cx, |mw, _window, cx| {
+                mw.open_sidebar(cx);
+            })
+            .unwrap();
+
         let workspace_a = multi_workspace_handle
             .read_with(cx, |mw, _| mw.workspace().clone())
             .unwrap();
@@ -10754,7 +10760,7 @@ mod tests {
         // Activate workspace A
         multi_workspace_handle
             .update(cx, |mw, window, cx| {
-                let workspace = mw.workspaces()[0].clone();
+                let workspace = mw.workspaces().next().unwrap().clone();
                 mw.activate(workspace, window, cx);
             })
             .unwrap();
@@ -10776,7 +10782,7 @@ mod tests {
         // Verify workspace A is active
         multi_workspace_handle
             .read_with(cx, |mw, _| {
-                assert_eq!(mw.active_workspace_index(), 0);
+                assert_eq!(mw.workspace(), &workspace_a);
             })
             .unwrap();
 
@@ -10792,8 +10798,8 @@ mod tests {
         multi_workspace_handle
             .read_with(cx, |mw, _| {
                 assert_eq!(
-                    mw.active_workspace_index(),
-                    1,
+                    mw.workspace(),
+                    &workspace_b,
                     "workspace B should be activated when it prompts"
                 );
             })
@@ -14511,6 +14517,12 @@ mod tests {
             cx.add_window(|window, cx| MultiWorkspace::test_new(project_a.clone(), window, cx));
         cx.run_until_parked();
 
+        multi_workspace_handle
+            .update(cx, |mw, _window, cx| {
+                mw.open_sidebar(cx);
+            })
+            .unwrap();
+
         let workspace_a = multi_workspace_handle
             .read_with(cx, |mw, _| mw.workspace().clone())
             .unwrap();
@@ -14524,7 +14536,7 @@ mod tests {
         // Switch to workspace A
         multi_workspace_handle
             .update(cx, |mw, window, cx| {
-                let workspace = mw.workspaces()[0].clone();
+                let workspace = mw.workspaces().next().unwrap().clone();
                 mw.activate(workspace, window, cx);
             })
             .unwrap();
@@ -14570,7 +14582,7 @@ mod tests {
         // Switch to workspace B
         multi_workspace_handle
             .update(cx, |mw, window, cx| {
-                let workspace = mw.workspaces()[1].clone();
+                let workspace = mw.workspaces().nth(1).unwrap().clone();
                 mw.activate(workspace, window, cx);
             })
             .unwrap();
@@ -14579,7 +14591,7 @@ mod tests {
         // Switch back to workspace A
         multi_workspace_handle
             .update(cx, |mw, window, cx| {
-                let workspace = mw.workspaces()[0].clone();
+                let workspace = mw.workspaces().next().unwrap().clone();
                 mw.activate(workspace, window, cx);
             })
             .unwrap();

crates/zed/src/visual_test_runner.rs 🔗

@@ -2606,7 +2606,7 @@ fn run_multi_workspace_sidebar_visual_tests(
     // Add worktree to workspace 1 (index 0) so it shows as "private-test-remote"
     let add_worktree1_task = multi_workspace_window
         .update(cx, |multi_workspace, _window, cx| {
-            let workspace1 = &multi_workspace.workspaces()[0];
+            let workspace1 = multi_workspace.workspaces().next().unwrap();
             let project = workspace1.read(cx).project().clone();
             project.update(cx, |project, cx| {
                 project.find_or_create_worktree(&workspace1_dir, true, cx)
@@ -2625,7 +2625,7 @@ fn run_multi_workspace_sidebar_visual_tests(
     // Add worktree to workspace 2 (index 1) so it shows as "zed"
     let add_worktree2_task = multi_workspace_window
         .update(cx, |multi_workspace, _window, cx| {
-            let workspace2 = &multi_workspace.workspaces()[1];
+            let workspace2 = multi_workspace.workspaces().nth(1).unwrap();
             let project = workspace2.read(cx).project().clone();
             project.update(cx, |project, cx| {
                 project.find_or_create_worktree(&workspace2_dir, true, cx)
@@ -2644,7 +2644,7 @@ fn run_multi_workspace_sidebar_visual_tests(
     // Switch to workspace 1 so it's highlighted as active (index 0)
     multi_workspace_window
         .update(cx, |multi_workspace, window, cx| {
-            let workspace = multi_workspace.workspaces()[0].clone();
+            let workspace = multi_workspace.workspaces().next().unwrap().clone();
             multi_workspace.activate(workspace, window, cx);
         })
         .context("Failed to activate workspace 1")?;
@@ -2672,7 +2672,7 @@ fn run_multi_workspace_sidebar_visual_tests(
     let save_tasks = multi_workspace_window
         .update(cx, |multi_workspace, _window, cx| {
             let thread_store = agent::ThreadStore::global(cx);
-            let workspaces = multi_workspace.workspaces().to_vec();
+            let workspaces: Vec<_> = multi_workspace.workspaces().cloned().collect();
             let mut tasks = Vec::new();
 
             for (index, workspace) in workspaces.iter().enumerate() {
@@ -3211,7 +3211,7 @@ edition = "2021"
     // Add the git project as a worktree
     let add_worktree_task = workspace_window
         .update(cx, |multi_workspace, _window, cx| {
-            let workspace = &multi_workspace.workspaces()[0];
+            let workspace = multi_workspace.workspaces().next().unwrap();
             let project = workspace.read(cx).project().clone();
             project.update(cx, |project, cx| {
                 project.find_or_create_worktree(&project_path, true, cx)
@@ -3236,7 +3236,7 @@ edition = "2021"
     // Open the project panel
     let (weak_workspace, async_window_cx) = workspace_window
         .update(cx, |multi_workspace, window, cx| {
-            let workspace = &multi_workspace.workspaces()[0];
+            let workspace = multi_workspace.workspaces().next().unwrap();
             (workspace.read(cx).weak_handle(), window.to_async(cx))
         })
         .context("Failed to get workspace handle")?;
@@ -3250,7 +3250,7 @@ edition = "2021"
 
     workspace_window
         .update(cx, |multi_workspace, window, cx| {
-            let workspace = &multi_workspace.workspaces()[0];
+            let workspace = multi_workspace.workspaces().next().unwrap();
             workspace.update(cx, |workspace, cx| {
                 workspace.add_panel(project_panel, window, cx);
                 workspace.open_panel::<ProjectPanel>(window, cx);
@@ -3263,7 +3263,7 @@ edition = "2021"
     // Open main.rs in the editor
     let open_file_task = workspace_window
         .update(cx, |multi_workspace, window, cx| {
-            let workspace = &multi_workspace.workspaces()[0];
+            let workspace = multi_workspace.workspaces().next().unwrap();
             workspace.update(cx, |workspace, cx| {
                 let worktree = workspace.project().read(cx).worktrees(cx).next();
                 if let Some(worktree) = worktree {
@@ -3291,7 +3291,7 @@ edition = "2021"
     // Load the AgentPanel
     let (weak_workspace, async_window_cx) = workspace_window
         .update(cx, |multi_workspace, window, cx| {
-            let workspace = &multi_workspace.workspaces()[0];
+            let workspace = multi_workspace.workspaces().next().unwrap();
             (workspace.read(cx).weak_handle(), window.to_async(cx))
         })
         .context("Failed to get workspace handle for agent panel")?;
@@ -3335,7 +3335,7 @@ edition = "2021"
 
     workspace_window
         .update(cx, |multi_workspace, window, cx| {
-            let workspace = &multi_workspace.workspaces()[0];
+            let workspace = multi_workspace.workspaces().next().unwrap();
             workspace.update(cx, |workspace, cx| {
                 workspace.add_panel(panel.clone(), window, cx);
                 workspace.open_panel::<AgentPanel>(window, cx);
@@ -3512,7 +3512,7 @@ edition = "2021"
                 .is_none()
         });
         let workspace_count = workspace_window.update(cx, |multi_workspace, _window, _cx| {
-            multi_workspace.workspaces().len()
+            multi_workspace.workspaces().count()
         })?;
         if workspace_count == 2 && status_cleared {
             creation_complete = true;
@@ -3531,7 +3531,7 @@ edition = "2021"
     // error state by injecting the stub server, and shrink the panel so the
     // editor content is visible.
     workspace_window.update(cx, |multi_workspace, window, cx| {
-        let new_workspace = &multi_workspace.workspaces()[1];
+        let new_workspace = multi_workspace.workspaces().nth(1).unwrap();
         new_workspace.update(cx, |workspace, cx| {
             if let Some(new_panel) = workspace.panel::<AgentPanel>(cx) {
                 new_panel.update(cx, |panel, cx| {
@@ -3544,7 +3544,7 @@ edition = "2021"
 
     // Type and send a message so the thread target dropdown disappears.
     let new_panel = workspace_window.update(cx, |multi_workspace, _window, cx| {
-        let new_workspace = &multi_workspace.workspaces()[1];
+        let new_workspace = multi_workspace.workspaces().nth(1).unwrap();
         new_workspace.read(cx).panel::<AgentPanel>(cx)
     })?;
     if let Some(new_panel) = new_panel {
@@ -3585,7 +3585,7 @@ edition = "2021"
 
     workspace_window
         .update(cx, |multi_workspace, _window, cx| {
-            let workspace = &multi_workspace.workspaces()[0];
+            let workspace = multi_workspace.workspaces().next().unwrap();
             let project = workspace.read(cx).project().clone();
             project.update(cx, |project, cx| {
                 let worktree_ids: Vec<_> =

crates/zed/src/zed.rs 🔗

@@ -1524,7 +1524,7 @@ fn quit(_: &Quit, cx: &mut App) {
             let window = *window;
             let workspaces = window
                 .update(cx, |multi_workspace, _, _| {
-                    multi_workspace.workspaces().to_vec()
+                    multi_workspace.workspaces().cloned().collect::<Vec<_>>()
                 })
                 .log_err();
 
@@ -2458,7 +2458,6 @@ mod tests {
             .update(cx, |multi_workspace, window, cx| {
                 let mut tasks = multi_workspace
                     .workspaces()
-                    .iter()
                     .map(|workspace| {
                         workspace.update(cx, |workspace, cx| {
                             workspace.flush_serialization(window, cx)
@@ -2610,7 +2609,7 @@ mod tests {
         cx.run_until_parked();
         multi_workspace_1
             .update(cx, |multi_workspace, _window, cx| {
-                assert_eq!(multi_workspace.workspaces().len(), 2);
+                assert_eq!(multi_workspace.workspaces().count(), 2);
                 assert!(multi_workspace.sidebar_open());
                 let workspace = multi_workspace.workspace().read(cx);
                 assert_eq!(
@@ -5512,6 +5511,11 @@ mod tests {
             let project = project1.clone();
             |window, cx| MultiWorkspace::test_new(project, window, cx)
         });
+        window
+            .update(cx, |multi_workspace, _, cx| {
+                multi_workspace.open_sidebar(cx);
+            })
+            .unwrap();
 
         cx.run_until_parked();
         assert_eq!(cx.windows().len(), 1, "Should start with 1 window");
@@ -5534,7 +5538,7 @@ mod tests {
 
         let workspace1 = window
             .read_with(cx, |multi_workspace, _| {
-                multi_workspace.workspaces()[0].clone()
+                multi_workspace.workspaces().next().unwrap().clone()
             })
             .unwrap();
 
@@ -5543,8 +5547,8 @@ mod tests {
                 multi_workspace.activate(workspace2.clone(), window, cx);
                 multi_workspace.activate(workspace3.clone(), window, cx);
                 // Switch back to workspace1 for test setup
-                multi_workspace.activate(workspace1, window, cx);
-                assert_eq!(multi_workspace.active_workspace_index(), 0);
+                multi_workspace.activate(workspace1.clone(), window, cx);
+                assert_eq!(multi_workspace.workspace(), &workspace1);
             })
             .unwrap();
 
@@ -5553,8 +5557,8 @@ mod tests {
         // Verify setup: 3 workspaces, workspace 0 active, still 1 window
         window
             .read_with(cx, |multi_workspace, _| {
-                assert_eq!(multi_workspace.workspaces().len(), 3);
-                assert_eq!(multi_workspace.active_workspace_index(), 0);
+                assert_eq!(multi_workspace.workspaces().count(), 3);
+                assert_eq!(multi_workspace.workspace(), &workspace1);
             })
             .unwrap();
         assert_eq!(cx.windows().len(), 1);
@@ -5577,8 +5581,8 @@ mod tests {
         window
             .read_with(cx, |multi_workspace, cx| {
                 assert_eq!(
-                    multi_workspace.active_workspace_index(),
-                    2,
+                    multi_workspace.workspace(),
+                    &workspace3,
                     "Should have switched to workspace 3 which contains /dir3"
                 );
                 let active_item = multi_workspace
@@ -5611,8 +5615,8 @@ mod tests {
         window
             .read_with(cx, |multi_workspace, cx| {
                 assert_eq!(
-                    multi_workspace.active_workspace_index(),
-                    1,
+                    multi_workspace.workspace(),
+                    &workspace2,
                     "Should have switched to workspace 2 which contains /dir2"
                 );
                 let active_item = multi_workspace
@@ -5660,8 +5664,8 @@ mod tests {
         window
             .read_with(cx, |multi_workspace, cx| {
                 assert_eq!(
-                    multi_workspace.active_workspace_index(),
-                    0,
+                    multi_workspace.workspace(),
+                    &workspace1,
                     "Should have switched back to workspace 0 which contains /dir1"
                 );
                 let active_item = multi_workspace
@@ -5711,6 +5715,11 @@ mod tests {
             let project = project1.clone();
             |window, cx| MultiWorkspace::test_new(project, window, cx)
         });
+        window1
+            .update(cx, |multi_workspace, _, cx| {
+                multi_workspace.open_sidebar(cx);
+            })
+            .unwrap();
 
         cx.run_until_parked();
 
@@ -5737,6 +5746,11 @@ mod tests {
             let project = project3.clone();
             |window, cx| MultiWorkspace::test_new(project, window, cx)
         });
+        window2
+            .update(cx, |multi_workspace, _, cx| {
+                multi_workspace.open_sidebar(cx);
+            })
+            .unwrap();
 
         cx.run_until_parked();
         assert_eq!(cx.windows().len(), 2);
@@ -5771,7 +5785,7 @@ mod tests {
         // Verify workspace1_1 is active
         window1
             .read_with(cx, |multi_workspace, _| {
-                assert_eq!(multi_workspace.active_workspace_index(), 0);
+                assert_eq!(multi_workspace.workspace(), &workspace1_1);
             })
             .unwrap();
 
@@ -5837,7 +5851,7 @@ mod tests {
         // Verify workspace1_1 is still active (not workspace1_2 with dirty item)
         window1
             .read_with(cx, |multi_workspace, _| {
-                assert_eq!(multi_workspace.active_workspace_index(), 0);
+                assert_eq!(multi_workspace.workspace(), &workspace1_1);
             })
             .unwrap();
 
@@ -5848,8 +5862,8 @@ mod tests {
         window1
             .read_with(cx, |multi_workspace, _| {
                 assert_eq!(
-                    multi_workspace.active_workspace_index(),
-                    1,
+                    multi_workspace.workspace(),
+                    &workspace1_2,
                     "Case 2: Non-active workspace should be activated when it has dirty item"
                 );
             })
@@ -6002,6 +6016,12 @@ mod tests {
             .await
             .expect("failed to open first workspace");
 
+        window_a
+            .update(cx, |multi_workspace, _, cx| {
+                multi_workspace.open_sidebar(cx);
+            })
+            .unwrap();
+
         window_a
             .update(cx, |multi_workspace, window, cx| {
                 multi_workspace.open_project(vec![dir2.into()], OpenMode::Activate, window, cx)
@@ -6028,13 +6048,19 @@ mod tests {
             .await
             .expect("failed to open third workspace");
 
+        window_b
+            .update(cx, |multi_workspace, _, cx| {
+                multi_workspace.open_sidebar(cx);
+            })
+            .unwrap();
+
         // Currently dir2 is active because it was added last.
         // So, switch window_a's active workspace to dir1 (index 0).
         // This sets up a non-trivial assertion: after restore, dir1 should
         // still be active rather than whichever workspace happened to restore last.
         window_a
             .update(cx, |multi_workspace, window, cx| {
-                let workspace = multi_workspace.workspaces()[0].clone();
+                let workspace = multi_workspace.workspaces().next().unwrap().clone();
                 multi_workspace.activate(workspace, window, cx);
             })
             .unwrap();
@@ -6150,7 +6176,7 @@ mod tests {
                         ProjectGroupKey::new(None, PathList::new(&[dir2])),
                     ]
                 );
-                assert_eq!(mw.workspaces().len(), 1);
+                assert_eq!(mw.workspaces().count(), 1);
             })
             .unwrap();
 
@@ -6161,7 +6187,7 @@ mod tests {
                     mw.project_group_keys().cloned().collect::<Vec<_>>(),
                     vec![ProjectGroupKey::new(None, PathList::new(&[dir3]))]
                 );
-                assert_eq!(mw.workspaces().len(), 1);
+                assert_eq!(mw.workspaces().count(), 1);
             })
             .unwrap();
     }

crates/zeta_prompt/Cargo.toml 🔗

@@ -13,6 +13,7 @@ path = "src/zeta_prompt.rs"
 
 [dependencies]
 anyhow.workspace = true
+imara-diff.workspace = true
 serde.workspace = true
 strum.workspace = true
 

crates/zeta_prompt/src/udiff.rs 🔗

@@ -6,6 +6,10 @@ use std::{
 };
 
 use anyhow::{Context as _, Result, anyhow};
+use imara_diff::{
+    Algorithm, Sink, diff,
+    intern::{InternedInput, Interner, Token},
+};
 
 pub fn strip_diff_path_prefix<'a>(diff: &'a str, prefix: &str) -> Cow<'a, str> {
     if prefix.is_empty() {
@@ -221,6 +225,181 @@ pub fn disambiguate_by_line_number(
     }
 }
 
+pub fn unified_diff_with_context(
+    old_text: &str,
+    new_text: &str,
+    old_start_line: u32,
+    new_start_line: u32,
+    context_lines: u32,
+) -> String {
+    let input = InternedInput::new(old_text, new_text);
+    diff(
+        Algorithm::Histogram,
+        &input,
+        OffsetUnifiedDiffBuilder::new(&input, old_start_line, new_start_line, context_lines),
+    )
+}
+
+struct OffsetUnifiedDiffBuilder<'a> {
+    before: &'a [Token],
+    after: &'a [Token],
+    interner: &'a Interner<&'a str>,
+    pos: u32,
+    before_hunk_start: u32,
+    after_hunk_start: u32,
+    before_hunk_len: u32,
+    after_hunk_len: u32,
+    old_line_offset: u32,
+    new_line_offset: u32,
+    context_lines: u32,
+    buffer: String,
+    dst: String,
+}
+
+impl<'a> OffsetUnifiedDiffBuilder<'a> {
+    fn new(
+        input: &'a InternedInput<&'a str>,
+        old_line_offset: u32,
+        new_line_offset: u32,
+        context_lines: u32,
+    ) -> Self {
+        Self {
+            before_hunk_start: 0,
+            after_hunk_start: 0,
+            before_hunk_len: 0,
+            after_hunk_len: 0,
+            old_line_offset,
+            new_line_offset,
+            context_lines,
+            buffer: String::with_capacity(8),
+            dst: String::new(),
+            interner: &input.interner,
+            before: &input.before,
+            after: &input.after,
+            pos: 0,
+        }
+    }
+
+    fn print_tokens(&mut self, tokens: &[Token], prefix: char) {
+        for &token in tokens {
+            writeln!(&mut self.buffer, "{prefix}{}", self.interner[token]).unwrap();
+        }
+    }
+
+    fn flush(&mut self) {
+        if self.before_hunk_len == 0 && self.after_hunk_len == 0 {
+            return;
+        }
+
+        let end = (self.pos + self.context_lines).min(self.before.len() as u32);
+        self.update_pos(end, end);
+
+        writeln!(
+            &mut self.dst,
+            "@@ -{},{} +{},{} @@",
+            self.before_hunk_start + 1 + self.old_line_offset,
+            self.before_hunk_len,
+            self.after_hunk_start + 1 + self.new_line_offset,
+            self.after_hunk_len,
+        )
+        .unwrap();
+        write!(&mut self.dst, "{}", &self.buffer).unwrap();
+        self.buffer.clear();
+        self.before_hunk_len = 0;
+        self.after_hunk_len = 0;
+    }
+
+    fn update_pos(&mut self, print_to: u32, move_to: u32) {
+        self.print_tokens(&self.before[self.pos as usize..print_to as usize], ' ');
+        let len = print_to - self.pos;
+        self.before_hunk_len += len;
+        self.after_hunk_len += len;
+        self.pos = move_to;
+    }
+}
+
+impl Sink for OffsetUnifiedDiffBuilder<'_> {
+    type Out = String;
+
+    fn process_change(&mut self, before: Range<u32>, after: Range<u32>) {
+        if before.start - self.pos > self.context_lines * 2 {
+            self.flush();
+        }
+        if self.before_hunk_len == 0 && self.after_hunk_len == 0 {
+            self.pos = before.start.saturating_sub(self.context_lines);
+            self.before_hunk_start = self.pos;
+            self.after_hunk_start = after.start.saturating_sub(self.context_lines);
+        }
+
+        self.update_pos(before.start, before.end);
+        self.before_hunk_len += before.end - before.start;
+        self.after_hunk_len += after.end - after.start;
+        self.print_tokens(
+            &self.before[before.start as usize..before.end as usize],
+            '-',
+        );
+        self.print_tokens(&self.after[after.start as usize..after.end as usize], '+');
+    }
+
+    fn finish(mut self) -> Self::Out {
+        self.flush();
+        self.dst
+    }
+}
+
+pub fn encode_cursor_in_patch(patch: &str, cursor_offset: Option<usize>) -> String {
+    let Some(cursor_offset) = cursor_offset else {
+        return patch.to_string();
+    };
+
+    let mut result = String::new();
+    let mut line_start_offset = 0usize;
+
+    for line in patch.lines() {
+        if matches!(
+            DiffLine::parse(line),
+            DiffLine::Garbage(content)
+                if content.starts_with('#') && content.contains(CURSOR_POSITION_MARKER)
+        ) {
+            continue;
+        }
+
+        if !result.is_empty() {
+            result.push('\n');
+        }
+        result.push_str(line);
+
+        match DiffLine::parse(line) {
+            DiffLine::Addition(content) => {
+                let line_end_offset = line_start_offset + content.len();
+
+                if cursor_offset >= line_start_offset && cursor_offset <= line_end_offset {
+                    let cursor_column = cursor_offset - line_start_offset;
+
+                    result.push('\n');
+                    result.push('#');
+                    for _ in 0..cursor_column {
+                        result.push(' ');
+                    }
+                    write!(result, "^{}", CURSOR_POSITION_MARKER).unwrap();
+                }
+
+                line_start_offset = line_end_offset + 1;
+            }
+            DiffLine::Context(content) => {
+                line_start_offset += content.len() + 1;
+            }
+            _ => {}
+        }
+    }
+
+    if patch.ends_with('\n') {
+        result.push('\n');
+    }
+
+    result
+}
+
 pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
     apply_diff_to_string_with_hunk_offset(diff_str, text).map(|(text, _)| text)
 }
@@ -1203,4 +1382,25 @@ mod tests {
         // Edit range end should be clamped to 7 (new context length).
         assert_eq!(hunk.edits[0].range, 4..7);
     }
+
+    #[test]
+    fn test_unified_diff_with_context_matches_expected_context_window() {
+        let old_text = "line1\nline2\nline3\nline4\nline5\nCHANGE_ME\nline7\nline8\n";
+        let new_text = "line1\nline2\nline3\nline4\nline5\nCHANGED\nline7\nline8\n";
+
+        let diff_default = unified_diff_with_context(old_text, new_text, 0, 0, 3);
+        assert_eq!(
+            diff_default,
+            "@@ -3,6 +3,6 @@\n line3\n line4\n line5\n-CHANGE_ME\n+CHANGED\n line7\n line8\n"
+        );
+
+        let diff_full_context = unified_diff_with_context(old_text, new_text, 0, 0, 8);
+        assert_eq!(
+            diff_full_context,
+            "@@ -1,8 +1,8 @@\n line1\n line2\n line3\n line4\n line5\n-CHANGE_ME\n+CHANGED\n line7\n line8\n"
+        );
+
+        let diff_no_context = unified_diff_with_context(old_text, new_text, 0, 0, 0);
+        assert_eq!(diff_no_context, "@@ -6,1 +6,1 @@\n-CHANGE_ME\n+CHANGED\n");
+    }
 }

crates/zeta_prompt/src/zeta_prompt.rs 🔗

@@ -106,10 +106,19 @@ impl std::fmt::Display for ZetaFormat {
 
 impl ZetaFormat {
     pub fn parse(format_name: &str) -> Result<Self> {
+        let lower = format_name.to_lowercase();
+
+        // Exact case-insensitive match takes priority, bypassing ambiguity checks.
+        for variant in ZetaFormat::iter() {
+            if <&'static str>::from(&variant).to_lowercase() == lower {
+                return Ok(variant);
+            }
+        }
+
         let mut results = ZetaFormat::iter().filter(|version| {
             <&'static str>::from(version)
                 .to_lowercase()
-                .contains(&format_name.to_lowercase())
+                .contains(&lower)
         });
         let Some(result) = results.next() else {
             anyhow::bail!(
@@ -927,11 +936,39 @@ fn cursor_in_new_text(
     })
 }
 
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
 pub struct ParsedOutput {
     /// Text that should replace the editable region
     pub new_editable_region: String,
     /// The byte range within `cursor_excerpt` that this replacement applies to
     pub range_in_excerpt: Range<usize>,
+    /// Byte offset of the cursor marker within `new_editable_region`, if present
+    pub cursor_offset_in_new_editable_region: Option<usize>,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+pub struct CursorPosition {
+    pub path: String,
+    pub row: usize,
+    pub column: usize,
+    pub offset: usize,
+    pub editable_region_offset: usize,
+}
+
+pub fn parsed_output_from_editable_region(
+    range_in_excerpt: Range<usize>,
+    mut new_editable_region: String,
+) -> ParsedOutput {
+    let cursor_offset_in_new_editable_region = new_editable_region.find(CURSOR_MARKER);
+    if let Some(offset) = cursor_offset_in_new_editable_region {
+        new_editable_region.replace_range(offset..offset + CURSOR_MARKER.len(), "");
+    }
+
+    ParsedOutput {
+        new_editable_region,
+        range_in_excerpt,
+        cursor_offset_in_new_editable_region,
+    }
 }
 
 /// Parse model output for the given zeta format
@@ -999,12 +1036,97 @@ pub fn parse_zeta2_model_output(
     let range_in_excerpt =
         range_in_context.start + context_start..range_in_context.end + context_start;
 
-    Ok(ParsedOutput {
-        new_editable_region: output,
-        range_in_excerpt,
+    Ok(parsed_output_from_editable_region(range_in_excerpt, output))
+}
+
+pub fn parse_zeta2_model_output_as_patch(
+    output: &str,
+    format: ZetaFormat,
+    prompt_inputs: &ZetaPromptInput,
+) -> Result<String> {
+    let parsed = parse_zeta2_model_output(output, format, prompt_inputs)?;
+    parsed_output_to_patch(prompt_inputs, parsed)
+}
+
+pub fn cursor_position_from_parsed_output(
+    prompt_inputs: &ZetaPromptInput,
+    parsed: &ParsedOutput,
+) -> Option<CursorPosition> {
+    let cursor_offset = parsed.cursor_offset_in_new_editable_region?;
+    let editable_region_offset = parsed.range_in_excerpt.start;
+    let excerpt = prompt_inputs.cursor_excerpt.as_ref();
+
+    let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count();
+
+    let new_editable_region = &parsed.new_editable_region;
+    let prefix_end = cursor_offset.min(new_editable_region.len());
+    let new_region_prefix = &new_editable_region[..prefix_end];
+
+    let row = editable_region_start_line + new_region_prefix.matches('\n').count();
+
+    let column = match new_region_prefix.rfind('\n') {
+        Some(last_newline) => cursor_offset - last_newline - 1,
+        None => {
+            let content_prefix = &excerpt[..editable_region_offset];
+            let content_column = match content_prefix.rfind('\n') {
+                Some(last_newline) => editable_region_offset - last_newline - 1,
+                None => editable_region_offset,
+            };
+            content_column + cursor_offset
+        }
+    };
+
+    Some(CursorPosition {
+        path: prompt_inputs.cursor_path.to_string_lossy().into_owned(),
+        row,
+        column,
+        offset: editable_region_offset + cursor_offset,
+        editable_region_offset: cursor_offset,
     })
 }
 
+pub fn parsed_output_to_patch(
+    prompt_inputs: &ZetaPromptInput,
+    parsed: ParsedOutput,
+) -> Result<String> {
+    let range_in_excerpt = parsed.range_in_excerpt;
+    let excerpt = prompt_inputs.cursor_excerpt.as_ref();
+    let old_text = excerpt[range_in_excerpt.clone()].to_string();
+    let mut new_text = parsed.new_editable_region;
+
+    let mut old_text_normalized = old_text;
+    if !new_text.is_empty() && !new_text.ends_with('\n') {
+        new_text.push('\n');
+    }
+    if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
+        old_text_normalized.push('\n');
+    }
+
+    let editable_region_offset = range_in_excerpt.start;
+    let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count() as u32;
+    let editable_region_lines = old_text_normalized.lines().count() as u32;
+
+    let diff = udiff::unified_diff_with_context(
+        &old_text_normalized,
+        &new_text,
+        editable_region_start_line,
+        editable_region_start_line,
+        editable_region_lines,
+    );
+
+    let path = prompt_inputs
+        .cursor_path
+        .to_string_lossy()
+        .trim_start_matches('/')
+        .to_string();
+    let formatted_diff = format!("--- a/{path}\n+++ b/{path}\n{diff}");
+
+    Ok(udiff::encode_cursor_in_patch(
+        &formatted_diff,
+        parsed.cursor_offset_in_new_editable_region,
+    ))
+}
+
 pub fn excerpt_range_for_format(
     format: ZetaFormat,
     ranges: &ExcerptRanges,
@@ -5400,6 +5522,33 @@ mod tests {
         assert_eq!(apply_edit(excerpt, &output1), "new content\n");
     }
 
+    #[test]
+    fn test_parsed_output_to_patch_round_trips_through_udiff_application() {
+        let excerpt = "before ctx\nctx start\neditable old\nctx end\nafter ctx\n";
+        let context_start = excerpt.find("ctx start").unwrap();
+        let context_end = excerpt.find("after ctx").unwrap();
+        let editable_start = excerpt.find("editable old").unwrap();
+        let editable_end = editable_start + "editable old\n".len();
+        let input = make_input_with_context_range(
+            excerpt,
+            editable_start..editable_end,
+            context_start..context_end,
+            editable_start,
+        );
+
+        let parsed = parse_zeta2_model_output(
+            "editable new\n>>>>>>> UPDATED\n",
+            ZetaFormat::V0131GitMergeMarkersPrefix,
+            &input,
+        )
+        .unwrap();
+        let expected = apply_edit(excerpt, &parsed);
+        let patch = parsed_output_to_patch(&input, parsed).unwrap();
+        let patched = udiff::apply_diff_to_string(&patch, excerpt).unwrap();
+
+        assert_eq!(patched, expected);
+    }
+
     #[test]
     fn test_special_tokens_not_triggered_by_comment_separator() {
         // Regression test for https://github.com/zed-industries/zed/issues/52489

tooling/compliance/Cargo.toml 🔗

@@ -0,0 +1,38 @@
+[package]
+name = "compliance"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[features]
+octo-client = ["dep:octocrab", "dep:jsonwebtoken", "dep:futures", "dep:tokio"]
+
+[dependencies]
+anyhow.workspace = true
+async-trait.workspace = true
+derive_more.workspace = true
+futures = { workspace = true, optional = true }
+itertools.workspace = true
+jsonwebtoken = { version = "10.2", features = ["use_pem"], optional = true }
+octocrab = { version = "0.49", default-features = false, features = [
+    "default-client",
+    "jwt-aws-lc-rs",
+    "retry",
+    "rustls",
+    "rustls-aws-lc-rs",
+    "stream",
+    "timeout"
+], optional = true }
+regex.workspace = true
+semver.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+tokio = { workspace = true, optional = true }
+
+[dev-dependencies]
+indoc.workspace = true
+tokio = { workspace = true, features = ["rt", "macros"] }

tooling/compliance/src/checks.rs 🔗

@@ -0,0 +1,647 @@
+use std::{fmt, ops::Not as _};
+
+use itertools::Itertools as _;
+
+use crate::{
+    git::{CommitDetails, CommitList},
+    github::{
+        CommitAuthor, GitHubClient, GitHubUser, GithubLogin, PullRequestComment, PullRequestData,
+        PullRequestReview, ReviewState,
+    },
+    report::Report,
+};
+
+const ZED_ZIPPY_COMMENT_APPROVAL_PATTERN: &str = "@zed-zippy approve";
+const ZED_ZIPPY_GROUP_APPROVAL: &str = "@zed-industries/approved";
+
+#[derive(Debug)]
+pub enum ReviewSuccess {
+    ApprovingComment(Vec<PullRequestComment>),
+    CoAuthored(Vec<CommitAuthor>),
+    ExternalMergedContribution { merged_by: GitHubUser },
+    PullRequestReviewed(Vec<PullRequestReview>),
+}
+
+impl ReviewSuccess {
+    pub(crate) fn reviewers(&self) -> anyhow::Result<String> {
+        let reviewers = match self {
+            Self::CoAuthored(authors) => authors.iter().map(ToString::to_string).collect_vec(),
+            Self::PullRequestReviewed(reviews) => reviews
+                .iter()
+                .filter_map(|review| review.user.as_ref())
+                .map(|user| format!("@{}", user.login))
+                .collect_vec(),
+            Self::ApprovingComment(comments) => comments
+                .iter()
+                .map(|comment| format!("@{}", comment.user.login))
+                .collect_vec(),
+            Self::ExternalMergedContribution { merged_by } => {
+                vec![format!("@{}", merged_by.login)]
+            }
+        };
+
+        let reviewers = reviewers.into_iter().unique().collect_vec();
+
+        reviewers
+            .is_empty()
+            .not()
+            .then(|| reviewers.join(", "))
+            .ok_or_else(|| anyhow::anyhow!("Expected at least one reviewer"))
+    }
+}
+
+impl fmt::Display for ReviewSuccess {
+    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Self::CoAuthored(_) => formatter.write_str("Co-authored by an organization member"),
+            Self::PullRequestReviewed(_) => {
+                formatter.write_str("Approved by an organization review")
+            }
+            Self::ApprovingComment(_) => {
+                formatter.write_str("Approved by an organization approval comment")
+            }
+            Self::ExternalMergedContribution { .. } => {
+                formatter.write_str("External merged contribution")
+            }
+        }
+    }
+}
+
+#[derive(Debug)]
+pub enum ReviewFailure {
+    // TODO: We could still query the GitHub API here to search for one
+    NoPullRequestFound,
+    Unreviewed,
+    UnableToDetermineReviewer,
+    Other(anyhow::Error),
+}
+
+impl fmt::Display for ReviewFailure {
+    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Self::NoPullRequestFound => formatter.write_str("No pull request found"),
+            Self::Unreviewed => formatter
+                .write_str("No qualifying organization approval found for the pull request"),
+            Self::UnableToDetermineReviewer => formatter.write_str("Could not determine reviewer"),
+            Self::Other(error) => write!(formatter, "Failed to inspect review state: {error}"),
+        }
+    }
+}
+
+pub(crate) type ReviewResult = Result<ReviewSuccess, ReviewFailure>;
+
+impl<E: Into<anyhow::Error>> From<E> for ReviewFailure {
+    fn from(err: E) -> Self {
+        Self::Other(anyhow::anyhow!(err))
+    }
+}
+
+pub struct Reporter<'a> {
+    commits: CommitList,
+    github_client: &'a GitHubClient,
+}
+
+impl<'a> Reporter<'a> {
+    pub fn new(commits: CommitList, github_client: &'a GitHubClient) -> Self {
+        Self {
+            commits,
+            github_client,
+        }
+    }
+
+    /// Checks a single commit for compliance, trying approval sources in priority order
+    async fn check_commit(&self, commit: &CommitDetails) -> Result<ReviewSuccess, ReviewFailure> {
+        let Some(pr_number) = commit.pr_number() else {
+            return Err(ReviewFailure::NoPullRequestFound);
+        };
+
+        let pull_request = self.github_client.get_pull_request(pr_number).await?;
+
+        if let Some(approval) = self.check_pull_request_approved(&pull_request).await? {
+            return Ok(approval);
+        }
+
+        if let Some(approval) = self
+            .check_approving_pull_request_comment(&pull_request)
+            .await?
+        {
+            return Ok(approval);
+        }
+
+        if let Some(approval) = self.check_commit_co_authors(commit).await? {
+            return Ok(approval);
+        }
+
+        // if let Some(approval) = self.check_external_merged_pr(pull_request).await? {
+        //     return Ok(approval);
+        // }
+
+        Err(ReviewFailure::Unreviewed)
+    }
+
+    async fn check_commit_co_authors(
+        &self,
+        commit: &CommitDetails,
+    ) -> Result<Option<ReviewSuccess>, ReviewFailure> {
+        if commit.co_authors().is_some()
+            && let Some(commit_authors) = self
+                .github_client
+                .get_commit_authors([commit.sha()])
+                .await?
+                .get(commit.sha())
+                .and_then(|authors| authors.co_authors())
+        {
+            let mut org_co_authors = Vec::new();
+            for co_author in commit_authors {
+                if let Some(github_login) = co_author.user()
+                    && self
+                        .github_client
+                        .check_org_membership(github_login)
+                        .await?
+                {
+                    org_co_authors.push(co_author.clone());
+                }
+            }
+
+            Ok(org_co_authors
+                .is_empty()
+                .not()
+                .then_some(ReviewSuccess::CoAuthored(org_co_authors)))
+        } else {
+            Ok(None)
+        }
+    }
+
+    #[allow(unused)]
+    async fn check_external_merged_pr(
+        &self,
+        pull_request: PullRequestData,
+    ) -> Result<Option<ReviewSuccess>, ReviewFailure> {
+        if let Some(user) = pull_request.user
+            && self
+                .github_client
+                .check_org_membership(&GithubLogin::new(user.login))
+                .await?
+                .not()
+        {
+            pull_request.merged_by.map_or(
+                Err(ReviewFailure::UnableToDetermineReviewer),
+                |merged_by| {
+                    Ok(Some(ReviewSuccess::ExternalMergedContribution {
+                        merged_by,
+                    }))
+                },
+            )
+        } else {
+            Ok(None)
+        }
+    }
+
+    async fn check_pull_request_approved(
+        &self,
+        pull_request: &PullRequestData,
+    ) -> Result<Option<ReviewSuccess>, ReviewFailure> {
+        let pr_reviews = self
+            .github_client
+            .get_pull_request_reviews(pull_request.number)
+            .await?;
+
+        if !pr_reviews.is_empty() {
+            let mut org_approving_reviews = Vec::new();
+            for review in pr_reviews {
+                if let Some(github_login) = review.user.as_ref()
+                    && pull_request
+                        .user
+                        .as_ref()
+                        .is_none_or(|pr_user| pr_user.login != github_login.login)
+                    && review
+                        .state
+                        .is_some_and(|state| state == ReviewState::Approved)
+                    && self
+                        .github_client
+                        .check_org_membership(&GithubLogin::new(github_login.login.clone()))
+                        .await?
+                {
+                    org_approving_reviews.push(review);
+                }
+            }
+
+            Ok(org_approving_reviews
+                .is_empty()
+                .not()
+                .then_some(ReviewSuccess::PullRequestReviewed(org_approving_reviews)))
+        } else {
+            Ok(None)
+        }
+    }
+
+    async fn check_approving_pull_request_comment(
+        &self,
+        pull_request: &PullRequestData,
+    ) -> Result<Option<ReviewSuccess>, ReviewFailure> {
+        let other_comments = self
+            .github_client
+            .get_pull_request_comments(pull_request.number)
+            .await?;
+
+        if !other_comments.is_empty() {
+            let mut org_approving_comments = Vec::new();
+
+            for comment in other_comments {
+                if pull_request
+                    .user
+                    .as_ref()
+                    .is_some_and(|pr_author| pr_author.login != comment.user.login)
+                    && comment.body.as_ref().is_some_and(|body| {
+                        body.contains(ZED_ZIPPY_COMMENT_APPROVAL_PATTERN)
+                            || body.contains(ZED_ZIPPY_GROUP_APPROVAL)
+                    })
+                    && self
+                        .github_client
+                        .check_org_membership(&GithubLogin::new(comment.user.login.clone()))
+                        .await?
+                {
+                    org_approving_comments.push(comment);
+                }
+            }
+
+            Ok(org_approving_comments
+                .is_empty()
+                .not()
+                .then_some(ReviewSuccess::ApprovingComment(org_approving_comments)))
+        } else {
+            Ok(None)
+        }
+    }
+
+    pub async fn generate_report(mut self) -> anyhow::Result<Report> {
+        let mut report = Report::new();
+
+        let commits_to_check = std::mem::take(&mut self.commits);
+        let total_commits = commits_to_check.len();
+
+        for (i, commit) in commits_to_check.into_iter().enumerate() {
+            println!(
+                "Checking commit {:?} ({current}/{total})",
+                commit.sha().short(),
+                current = i + 1,
+                total = total_commits
+            );
+
+            let review_result = self.check_commit(&commit).await;
+
+            if let Err(err) = &review_result {
+                println!("Commit {:?} failed review: {:?}", commit.sha().short(), err);
+            }
+
+            report.add(commit, review_result);
+        }
+
+        Ok(report)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::rc::Rc;
+    use std::str::FromStr;
+
+    use crate::git::{CommitDetails, CommitList, CommitSha};
+    use crate::github::{
+        AuthorsForCommits, GitHubApiClient, GitHubClient, GitHubUser, GithubLogin,
+        PullRequestComment, PullRequestData, PullRequestReview, ReviewState,
+    };
+
+    use super::{Reporter, ReviewFailure, ReviewSuccess};
+
+    struct MockGitHubApi {
+        pull_request: PullRequestData,
+        reviews: Vec<PullRequestReview>,
+        comments: Vec<PullRequestComment>,
+        commit_authors_json: serde_json::Value,
+        org_members: Vec<String>,
+    }
+
+    #[async_trait::async_trait(?Send)]
+    impl GitHubApiClient for MockGitHubApi {
+        async fn get_pull_request(&self, _pr_number: u64) -> anyhow::Result<PullRequestData> {
+            Ok(self.pull_request.clone())
+        }
+
+        async fn get_pull_request_reviews(
+            &self,
+            _pr_number: u64,
+        ) -> anyhow::Result<Vec<PullRequestReview>> {
+            Ok(self.reviews.clone())
+        }
+
+        async fn get_pull_request_comments(
+            &self,
+            _pr_number: u64,
+        ) -> anyhow::Result<Vec<PullRequestComment>> {
+            Ok(self.comments.clone())
+        }
+
+        async fn get_commit_authors(
+            &self,
+            _commit_shas: &[&CommitSha],
+        ) -> anyhow::Result<AuthorsForCommits> {
+            serde_json::from_value(self.commit_authors_json.clone()).map_err(Into::into)
+        }
+
+        async fn check_org_membership(&self, login: &GithubLogin) -> anyhow::Result<bool> {
+            Ok(self
+                .org_members
+                .iter()
+                .any(|member| member == login.as_str()))
+        }
+
+        async fn ensure_pull_request_has_label(
+            &self,
+            _label: &str,
+            _pr_number: u64,
+        ) -> anyhow::Result<()> {
+            Ok(())
+        }
+    }
+
+    fn make_commit(
+        sha: &str,
+        author_name: &str,
+        author_email: &str,
+        title: &str,
+        body: &str,
+    ) -> CommitDetails {
+        let formatted = format!(
+            "{sha}|field-delimiter|{author_name}|field-delimiter|{author_email}|field-delimiter|\
+             {title}|body-delimiter|{body}|commit-delimiter|"
+        );
+        CommitList::from_str(&formatted)
+            .expect("test commit should parse")
+            .into_iter()
+            .next()
+            .expect("should have one commit")
+    }
+
+    fn review(login: &str, state: ReviewState) -> PullRequestReview {
+        PullRequestReview {
+            user: Some(GitHubUser {
+                login: login.to_owned(),
+            }),
+            state: Some(state),
+        }
+    }
+
+    fn comment(login: &str, body: &str) -> PullRequestComment {
+        PullRequestComment {
+            user: GitHubUser {
+                login: login.to_owned(),
+            },
+            body: Some(body.to_owned()),
+        }
+    }
+
+    struct TestScenario {
+        pull_request: PullRequestData,
+        reviews: Vec<PullRequestReview>,
+        comments: Vec<PullRequestComment>,
+        commit_authors_json: serde_json::Value,
+        org_members: Vec<String>,
+        commit: CommitDetails,
+    }
+
+    impl TestScenario {
+        fn single_commit() -> Self {
+            Self {
+                pull_request: PullRequestData {
+                    number: 1234,
+                    user: Some(GitHubUser {
+                        login: "alice".to_owned(),
+                    }),
+                    merged_by: None,
+                },
+                reviews: vec![],
+                comments: vec![],
+                commit_authors_json: serde_json::json!({}),
+                org_members: vec![],
+                commit: make_commit(
+                    "abc12345abc12345",
+                    "Alice",
+                    "alice@test.com",
+                    "Fix thing (#1234)",
+                    "",
+                ),
+            }
+        }
+
+        fn with_reviews(mut self, reviews: Vec<PullRequestReview>) -> Self {
+            self.reviews = reviews;
+            self
+        }
+
+        fn with_comments(mut self, comments: Vec<PullRequestComment>) -> Self {
+            self.comments = comments;
+            self
+        }
+
+        fn with_org_members(mut self, members: Vec<&str>) -> Self {
+            self.org_members = members.into_iter().map(str::to_owned).collect();
+            self
+        }
+
+        fn with_commit_authors_json(mut self, json: serde_json::Value) -> Self {
+            self.commit_authors_json = json;
+            self
+        }
+
+        fn with_commit(mut self, commit: CommitDetails) -> Self {
+            self.commit = commit;
+            self
+        }
+
+        async fn run_scenario(self) -> Result<ReviewSuccess, ReviewFailure> {
+            let mock = MockGitHubApi {
+                pull_request: self.pull_request,
+                reviews: self.reviews,
+                comments: self.comments,
+                commit_authors_json: self.commit_authors_json,
+                org_members: self.org_members,
+            };
+            let client = GitHubClient::new(Rc::new(mock));
+            let reporter = Reporter::new(CommitList::default(), &client);
+            reporter.check_commit(&self.commit).await
+        }
+    }
+
+    #[tokio::test]
+    async fn approved_review_by_org_member_succeeds() {
+        let result = TestScenario::single_commit()
+            .with_reviews(vec![review("bob", ReviewState::Approved)])
+            .with_org_members(vec!["bob"])
+            .run_scenario()
+            .await;
+        assert!(matches!(result, Ok(ReviewSuccess::PullRequestReviewed(_))));
+    }
+
+    #[tokio::test]
+    async fn non_approved_review_state_is_not_accepted() {
+        let result = TestScenario::single_commit()
+            .with_reviews(vec![review("bob", ReviewState::Other)])
+            .with_org_members(vec!["bob"])
+            .run_scenario()
+            .await;
+        assert!(matches!(result, Err(ReviewFailure::Unreviewed)));
+    }
+
+    #[tokio::test]
+    async fn review_by_non_org_member_is_not_accepted() {
+        let result = TestScenario::single_commit()
+            .with_reviews(vec![review("bob", ReviewState::Approved)])
+            .run_scenario()
+            .await;
+        assert!(matches!(result, Err(ReviewFailure::Unreviewed)));
+    }
+
+    #[tokio::test]
+    async fn pr_author_own_approval_review_is_rejected() {
+        let result = TestScenario::single_commit()
+            .with_reviews(vec![review("alice", ReviewState::Approved)])
+            .with_org_members(vec!["alice"])
+            .run_scenario()
+            .await;
+        assert!(matches!(result, Err(ReviewFailure::Unreviewed)));
+    }
+
+    #[tokio::test]
+    async fn pr_author_own_approval_comment_is_rejected() {
+        let result = TestScenario::single_commit()
+            .with_comments(vec![comment("alice", "@zed-zippy approve")])
+            .with_org_members(vec!["alice"])
+            .run_scenario()
+            .await;
+        assert!(matches!(result, Err(ReviewFailure::Unreviewed)));
+    }
+
+    #[tokio::test]
+    async fn approval_comment_by_org_member_succeeds() {
+        let result = TestScenario::single_commit()
+            .with_comments(vec![comment("bob", "@zed-zippy approve")])
+            .with_org_members(vec!["bob"])
+            .run_scenario()
+            .await;
+        assert!(matches!(result, Ok(ReviewSuccess::ApprovingComment(_))));
+    }
+
+    #[tokio::test]
+    async fn group_approval_comment_by_org_member_succeeds() {
+        let result = TestScenario::single_commit()
+            .with_comments(vec![comment("bob", "@zed-industries/approved")])
+            .with_org_members(vec!["bob"])
+            .run_scenario()
+            .await;
+        assert!(matches!(result, Ok(ReviewSuccess::ApprovingComment(_))));
+    }
+
+    #[tokio::test]
+    async fn comment_without_approval_pattern_is_not_accepted() {
+        let result = TestScenario::single_commit()
+            .with_comments(vec![comment("bob", "looks good")])
+            .with_org_members(vec!["bob"])
+            .run_scenario()
+            .await;
+        assert!(matches!(result, Err(ReviewFailure::Unreviewed)));
+    }
+
+    #[tokio::test]
+    async fn commit_without_pr_number_is_no_pr_found() {
+        let result = TestScenario::single_commit()
+            .with_commit(make_commit(
+                "abc12345abc12345",
+                "Alice",
+                "alice@test.com",
+                "Fix thing without PR number",
+                "",
+            ))
+            .run_scenario()
+            .await;
+        assert!(matches!(result, Err(ReviewFailure::NoPullRequestFound)));
+    }
+
+    #[tokio::test]
+    async fn pr_review_takes_precedence_over_comment() {
+        let result = TestScenario::single_commit()
+            .with_reviews(vec![review("bob", ReviewState::Approved)])
+            .with_comments(vec![comment("charlie", "@zed-zippy approve")])
+            .with_org_members(vec!["bob", "charlie"])
+            .run_scenario()
+            .await;
+        assert!(matches!(result, Ok(ReviewSuccess::PullRequestReviewed(_))));
+    }
+
+    #[tokio::test]
+    async fn comment_takes_precedence_over_co_author() {
+        let result = TestScenario::single_commit()
+            .with_comments(vec![comment("bob", "@zed-zippy approve")])
+            .with_commit_authors_json(serde_json::json!({
+                "abc12345abc12345": {
+                    "author": {
+                        "name": "Alice",
+                        "email": "alice@test.com",
+                        "user": { "login": "alice" }
+                    },
+                    "authors": [{
+                        "name": "Charlie",
+                        "email": "charlie@test.com",
+                        "user": { "login": "charlie" }
+                    }]
+                }
+            }))
+            .with_commit(make_commit(
+                "abc12345abc12345",
+                "Alice",
+                "alice@test.com",
+                "Fix thing (#1234)",
+                "Co-authored-by: Charlie <charlie@test.com>",
+            ))
+            .with_org_members(vec!["bob", "charlie"])
+            .run_scenario()
+            .await;
+        assert!(matches!(result, Ok(ReviewSuccess::ApprovingComment(_))));
+    }
+
+    #[tokio::test]
+    async fn co_author_org_member_succeeds() {
+        let result = TestScenario::single_commit()
+            .with_commit_authors_json(serde_json::json!({
+                "abc12345abc12345": {
+                    "author": {
+                        "name": "Alice",
+                        "email": "alice@test.com",
+                        "user": { "login": "alice" }
+                    },
+                    "authors": [{
+                        "name": "Bob",
+                        "email": "bob@test.com",
+                        "user": { "login": "bob" }
+                    }]
+                }
+            }))
+            .with_commit(make_commit(
+                "abc12345abc12345",
+                "Alice",
+                "alice@test.com",
+                "Fix thing (#1234)",
+                "Co-authored-by: Bob <bob@test.com>",
+            ))
+            .with_org_members(vec!["bob"])
+            .run_scenario()
+            .await;
+        assert!(matches!(result, Ok(ReviewSuccess::CoAuthored(_))));
+    }
+
+    #[tokio::test]
+    async fn no_reviews_no_comments_no_coauthors_is_unreviewed() {
+        let result = TestScenario::single_commit().run_scenario().await;
+        assert!(matches!(result, Err(ReviewFailure::Unreviewed)));
+    }
+}

tooling/compliance/src/git.rs 🔗

@@ -0,0 +1,591 @@
+#![allow(clippy::disallowed_methods, reason = "This is only used in xtasks")]
+use std::{
+    fmt::{self, Debug},
+    ops::Not,
+    process::Command,
+    str::FromStr,
+    sync::LazyLock,
+};
+
+use anyhow::{Context, Result, anyhow};
+use derive_more::{Deref, DerefMut, FromStr};
+
+use itertools::Itertools;
+use regex::Regex;
+use semver::Version;
+use serde::Deserialize;
+
+/// A `git` subcommand whose stdout can be parsed into a typed value.
+pub trait Subcommand {
+    /// Typed representation parsed from the subcommand's trimmed stdout.
+    type ParsedOutput: FromStr<Err = anyhow::Error>;
+
+    /// Arguments passed to `git` (excluding the `git` binary itself).
+    fn args(&self) -> impl IntoIterator<Item = String>;
+}
+
+/// Wrapper that runs a [`Subcommand`] through the `git` CLI and parses its output.
+#[derive(Deref, DerefMut)]
+pub struct GitCommand<G: Subcommand> {
+    #[deref]
+    #[deref_mut]
+    subcommand: G,
+}
+
+impl<G: Subcommand> GitCommand<G> {
+    /// Runs `git` with the subcommand's arguments and parses its stdout.
+    ///
+    /// Errors if the process cannot be spawned, exits non-zero, produces
+    /// non-UTF-8 stdout, or its output fails to parse.
+    #[must_use]
+    pub fn run(subcommand: G) -> Result<G::ParsedOutput> {
+        Self { subcommand }.run_impl()
+    }
+
+    fn run_impl(self) -> Result<G::ParsedOutput> {
+        let command_output = Command::new("git")
+            .args(self.subcommand.args())
+            .output()
+            .context("Failed to spawn command")?;
+
+        if command_output.status.success() {
+            // Preserve the underlying UTF-8 error instead of discarding it.
+            let stdout = String::from_utf8(command_output.stdout)
+                .context("Command produced invalid UTF-8 on stdout")?;
+            // Trim to drop git's trailing newline before handing off to the parser.
+            G::ParsedOutput::from_str(stdout.trim())
+                .map_err(|e| anyhow!("Failed to parse from string: {e:?}"))
+        } else {
+            // Render stderr lossily so diagnostics survive invalid UTF-8
+            // (previously invalid stderr was silently replaced by "").
+            anyhow::bail!(
+                "Command failed with exit code {}, stderr: {}",
+                command_output.status.code().unwrap_or_default(),
+                String::from_utf8_lossy(&command_output.stderr)
+            )
+        }
+    }
+}
+
+/// Release channel a version tag belongs to.
+// Derived `Ord` follows declaration order: Stable < Preview.
+#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub enum ReleaseChannel {
+    Stable,
+    Preview,
+}
+
+impl ReleaseChannel {
+    /// Suffix appended to a version tag for this channel
+    /// (empty for stable, `-pre` for preview).
+    pub(crate) fn tag_suffix(&self) -> &'static str {
+        if matches!(self, ReleaseChannel::Preview) {
+            "-pre"
+        } else {
+            ""
+        }
+    }
+}
+
+/// A parsed `vX.Y.Z[-pre]` git tag: semantic version plus release channel.
+#[derive(Debug, Clone)]
+pub struct VersionTag(Version, ReleaseChannel);
+
+impl VersionTag {
+    /// Parses a tag like `v0.172.8` or `v0.172.1-pre` into a version and
+    /// release channel. The leading `v` is optional.
+    pub fn parse(input: &str) -> Result<Self, anyhow::Error> {
+        // Being a bit more lenient for human inputs
+        let version = input.strip_prefix('v').unwrap_or(input);
+
+        // A `-pre` suffix marks the preview channel; anything else is stable.
+        let (version_str, channel) = version
+            .strip_suffix("-pre")
+            .map_or((version, ReleaseChannel::Stable), |version_str| {
+                (version_str, ReleaseChannel::Preview)
+            });
+
+        // Keep the input and the semver error in the message so bad tags
+        // are diagnosable (previously both were discarded).
+        Version::parse(version_str)
+            .map(|version| Self(version, channel))
+            .map_err(|e| anyhow::anyhow!("Failed to parse version from tag {input:?}: {e}"))
+    }
+
+    /// The semantic version encoded in the tag (channel suffix excluded).
+    pub fn version(&self) -> &Version {
+        &self.0
+    }
+}
+
+// Implement `Display` rather than `ToString` directly (clippy:
+// `to_string_trait_impl`). The blanket impl still provides `.to_string()`
+// for existing callers, and the tag can now also be used in `format!`.
+impl fmt::Display for VersionTag {
+    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            formatter,
+            "v{version}{channel_suffix}",
+            version = self.0,
+            channel_suffix = self.1.tag_suffix()
+        )
+    }
+}
+
+/// Full commit SHA as printed by `git log` (`%H`).
+#[derive(Debug, Deref, FromStr, PartialEq, Eq, Hash, Deserialize)]
+pub struct CommitSha(pub(crate) String);
+
+impl CommitSha {
+    /// Returns the abbreviated (first 8 characters) form of the SHA.
+    ///
+    /// Falls back to the full string when it is shorter than 8 bytes,
+    /// instead of panicking like the previous `split_at(8)` did.
+    pub fn short(&self) -> &str {
+        self.0.get(..8).unwrap_or(&self.0)
+    }
+}
+
+/// One commit parsed from `git log` output (see [`CommitDetails::FORMAT_STRING`]).
+#[derive(Debug)]
+pub struct CommitDetails {
+    sha: CommitSha,
+    author: Committer,
+    // Subject line (`%s`); for squash merges it ends with a `(#<pr>)` suffix.
+    title: String,
+    // Raw body (`%b`), searched for `Co-authored-by:` trailers.
+    body: String,
+}
+
+/// A commit author or co-author identified by git name and email.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct Committer {
+    name: String,
+    email: String,
+}
+
+impl Committer {
+    /// Builds a committer from borrowed name and email parts.
+    pub fn new(name: &str, email: &str) -> Self {
+        let name = String::from(name);
+        let email = String::from(email);
+        Self { name, email }
+    }
+}
+
+impl fmt::Display for Committer {
+    /// Formats as `Name (email)`.
+    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(formatter, "{} ({})", self.name, self.email)
+    }
+}
+
+impl CommitDetails {
+    /// Delimiters used in [`Self::FORMAT_STRING`] so `git log` output can be
+    /// split unambiguously (commit messages may contain almost anything).
+    const BODY_DELIMITER: &str = "|body-delimiter|";
+    const COMMIT_DELIMITER: &str = "|commit-delimiter|";
+    const FIELD_DELIMITER: &str = "|field-delimiter|";
+    /// `git log --pretty=format:` string: SHA, author name, author email,
+    /// subject, then body, joined by the delimiters above.
+    const FORMAT_STRING: &str = "%H|field-delimiter|%an|field-delimiter|%ae|field-delimiter|%s|body-delimiter|%b|commit-delimiter|";
+
+    /// Parses the header `line` (fields joined by [`Self::FIELD_DELIMITER`])
+    /// and attaches the already-separated `body`.
+    fn parse(line: &str, body: &str) -> Result<Self, anyhow::Error> {
+        let Some([sha, author_name, author_email, title]) =
+            line.splitn(4, Self::FIELD_DELIMITER).collect_array()
+        else {
+            return Err(anyhow!("Failed to parse commit fields from input {line}"));
+        };
+
+        Ok(CommitDetails {
+            sha: CommitSha(sha.to_owned()),
+            author: Committer::new(author_name, author_email),
+            title: title.to_owned(),
+            body: body.to_owned(),
+        })
+    }
+
+    /// Extracts the pull-request number from the subject line, if present.
+    /// Uses the *last* ` (#…)` occurrence in the title.
+    pub fn pr_number(&self) -> Option<u64> {
+        // Since we use squash merge, all commit titles end with the '(#12345)' pattern.
+        // While we could strictly speaking index into this directly, go for a slightly
+        // less prone approach to errors
+        const PATTERN: &str = " (#";
+        self.title
+            .rfind(PATTERN)
+            .and_then(|location| {
+                self.title[location..]
+                    .find(')')
+                    .map(|relative_end| location + PATTERN.len()..location + relative_end)
+            })
+            .and_then(|range| self.title[range].parse().ok())
+    }
+
+    /// Collects `Co-authored-by:` trailers from the commit body, or `None`
+    /// when there are none.
+    pub(crate) fn co_authors(&self) -> Option<Vec<Committer>> {
+        static CO_AUTHOR_REGEX: LazyLock<Regex> =
+            LazyLock::new(|| Regex::new(r"Co-authored-by: (.+) <(.+)>").unwrap());
+
+        // `&self.body` deref-coerces to `&str`; the previous
+        // `&self.body.as_ref()` double borrow was a clippy needless_borrow.
+        let co_authors: Vec<Committer> = CO_AUTHOR_REGEX
+            .captures_iter(&self.body)
+            .filter_map(|cap| {
+                let name = cap.get(1)?.as_str();
+                let email = cap.get(2)?.as_str();
+                Some(Committer::new(name, email))
+            })
+            .collect();
+
+        co_authors.is_empty().not().then_some(co_authors)
+    }
+
+    /// The commit's primary author.
+    pub(crate) fn author(&self) -> &Committer {
+        &self.author
+    }
+
+    /// The subject line (`%s`).
+    pub(crate) fn title(&self) -> &str {
+        &self.title
+    }
+
+    /// The full commit SHA.
+    pub(crate) fn sha(&self) -> &CommitSha {
+        &self.sha
+    }
+}
+
+/// Ordered list of commits as produced by `git log`.
+#[derive(Debug, Deref, Default, DerefMut)]
+pub struct CommitList(Vec<CommitDetails>);
+
+impl CommitList {
+    /// Builds a `first..last` git revision range over the listed commits,
+    /// or `None` when the list is empty.
+    pub fn range(&self) -> Option<String> {
+        match (self.0.first(), self.0.last()) {
+            (Some(first), Some(last)) => Some(format!("{}..{}", first.sha().0, last.sha().0)),
+            _ => None,
+        }
+    }
+}
+
+impl IntoIterator for CommitList {
+    type IntoIter = std::vec::IntoIter<CommitDetails>;
+    type Item = CommitDetails;
+
+    /// Consuming iteration over the underlying commits, oldest-listed first.
+    fn into_iter(self) -> std::vec::IntoIter<Self::Item> {
+        self.0.into_iter()
+    }
+}
+
+impl FromStr for CommitList {
+    type Err = anyhow::Error;
+
+    /// Parses `git log --pretty=format:…` output (see
+    /// [`CommitDetails::FORMAT_STRING`]) into a list of commits.
+    ///
+    /// Malformed records are now reported as `Err` instead of panicking via
+    /// `expect`, honoring the `FromStr` contract.
+    fn from_str(input: &str) -> Result<Self, Self::Err> {
+        let commits = input
+            .split(CommitDetails::COMMIT_DELIMITER)
+            // Trim before the emptiness check so a trailing newline-only
+            // segment is skipped rather than treated as a malformed record.
+            .map(str::trim)
+            .filter(|commit_details| !commit_details.is_empty())
+            .map(|commit_details| {
+                let (line, body) = commit_details
+                    .split_once(CommitDetails::BODY_DELIMITER)
+                    .ok_or_else(|| anyhow!("Missing body delimiter in {commit_details:?}"))?;
+                CommitDetails::parse(line, body)
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+
+        Ok(CommitList(commits))
+    }
+}
+
+/// `git tag -l v*` — lists every version-like tag in the repository.
+pub struct GetVersionTags;
+
+impl Subcommand for GetVersionTags {
+    type ParsedOutput = VersionTagList;
+
+    /// Lists all tags matching the `v*` glob.
+    fn args(&self) -> impl IntoIterator<Item = String> {
+        vec!["tag".to_string(), "-l".to_string(), "v*".to_string()]
+    }
+}
+
+/// Version tags parsed from `git tag` output (non-version tags are skipped).
+pub struct VersionTagList(Vec<VersionTag>);
+
+impl VersionTagList {
+    pub fn sorted(mut self) -> Self {
+        self.0.sort_by(|a, b| a.version().cmp(b.version()));
+        self
+    }
+
+    pub fn find_previous_minor_version(&self, version_tag: &VersionTag) -> Option<&VersionTag> {
+        self.0
+            .iter()
+            .take_while(|tag| tag.version() < version_tag.version())
+            .collect_vec()
+            .into_iter()
+            .rev()
+            .find(|tag| {
+                (tag.version().major < version_tag.version().major
+                    || (tag.version().major == version_tag.version().major
+                        && tag.version().minor < version_tag.version().minor))
+                    && tag.version().patch == 0
+            })
+    }
+}
+
+impl FromStr for VersionTagList {
+    type Err = anyhow::Error;
+
+    /// Parses one tag per line, skipping lines that are not valid version
+    /// tags; errors only when no valid tag was found at all.
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let version_tags: Vec<_> = s
+            .lines()
+            .filter_map(|line| VersionTag::parse(line).ok())
+            .collect();
+
+        if version_tags.is_empty() {
+            Err(anyhow::anyhow!("No version tags found"))
+        } else {
+            Ok(Self(version_tags))
+        }
+    }
+}
+
+/// `git log <tag>..<branch>` — commits reachable from `branch` but not from
+/// the tagged release.
+pub struct CommitsFromVersionToHead {
+    version_tag: VersionTag,
+    branch: String,
+}
+
+impl CommitsFromVersionToHead {
+    /// Creates the range query from a starting tag and a target branch.
+    pub fn new(version_tag: VersionTag, branch: String) -> Self {
+        Self {
+            version_tag,
+            branch,
+        }
+    }
+}
+
+impl Subcommand for CommitsFromVersionToHead {
+    type ParsedOutput = CommitList;
+
+    fn args(&self) -> impl IntoIterator<Item = String> {
+        [
+            "log".to_string(),
+            // Custom pretty format so the output can be split unambiguously.
+            format!("--pretty=format:{}", CommitDetails::FORMAT_STRING),
+            // Exclusive range: commits in `branch` that are not in `version`.
+            format!(
+                "{version}..{branch}",
+                version = self.version_tag.to_string(),
+                branch = self.branch
+            ),
+        ]
+    }
+}
+
+/// Marker for subcommands whose output is irrelevant; parsing always succeeds.
+pub struct NoOutput;
+
+impl FromStr for NoOutput {
+    type Err = anyhow::Error;
+
+    // Accepts any input (including empty) since the output is discarded.
+    fn from_str(_: &str) -> Result<Self, Self::Err> {
+        Ok(NoOutput)
+    }
+}
+
+// Unit tests covering version-tag parsing/formatting/sorting, previous-minor
+// selection, and commit/commit-list parsing from git log output.
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use indoc::indoc;
+
+    // --- VersionTag parsing and round-tripping ---
+
+    #[test]
+    fn parse_stable_version_tag() {
+        let tag = VersionTag::parse("v0.172.8").unwrap();
+        assert_eq!(tag.version().major, 0);
+        assert_eq!(tag.version().minor, 172);
+        assert_eq!(tag.version().patch, 8);
+        assert_eq!(tag.1, ReleaseChannel::Stable);
+    }
+
+    #[test]
+    fn parse_preview_version_tag() {
+        let tag = VersionTag::parse("v0.172.1-pre").unwrap();
+        assert_eq!(tag.version().major, 0);
+        assert_eq!(tag.version().minor, 172);
+        assert_eq!(tag.version().patch, 1);
+        assert_eq!(tag.1, ReleaseChannel::Preview);
+    }
+
+    #[test]
+    fn parse_version_tag_without_v_prefix() {
+        let tag = VersionTag::parse("0.172.8").unwrap();
+        assert_eq!(tag.version().major, 0);
+        assert_eq!(tag.version().minor, 172);
+        assert_eq!(tag.version().patch, 8);
+    }
+
+    #[test]
+    fn parse_invalid_version_tag() {
+        let result = VersionTag::parse("vConradTest");
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn version_tag_stable_roundtrip() {
+        let tag = VersionTag::parse("v0.172.8").unwrap();
+        assert_eq!(tag.to_string(), "v0.172.8");
+    }
+
+    #[test]
+    fn version_tag_preview_roundtrip() {
+        let tag = VersionTag::parse("v0.172.1-pre").unwrap();
+        assert_eq!(tag.to_string(), "v0.172.1-pre");
+    }
+
+    // --- VersionTagList ordering and previous-minor selection ---
+
+    #[test]
+    fn sorted_orders_by_semver() {
+        let input = indoc! {"
+            v0.172.8
+            v0.170.1
+            v0.171.4
+            v0.170.2
+            v0.172.11
+            v0.171.3
+            v0.172.9
+        "};
+        let list = VersionTagList::from_str(input).unwrap().sorted();
+        for window in list.0.windows(2) {
+            assert!(
+                window[0].version() <= window[1].version(),
+                "{} should come before {}",
+                window[0].to_string(),
+                window[1].to_string()
+            );
+        }
+        assert_eq!(list.0[0].to_string(), "v0.170.1");
+        assert_eq!(list.0[list.0.len() - 1].to_string(), "v0.172.11");
+    }
+
+    #[test]
+    fn find_previous_minor_for_173_returns_172() {
+        let input = indoc! {"
+            v0.170.1
+            v0.170.2
+            v0.171.3
+            v0.171.4
+            v0.172.0
+            v0.172.8
+            v0.172.9
+            v0.172.11
+        "};
+        let list = VersionTagList::from_str(input).unwrap().sorted();
+        let target = VersionTag::parse("v0.173.0").unwrap();
+        let previous = list.find_previous_minor_version(&target).unwrap();
+        assert_eq!(previous.version().major, 0);
+        assert_eq!(previous.version().minor, 172);
+        assert_eq!(previous.version().patch, 0);
+    }
+
+    #[test]
+    fn find_previous_minor_skips_same_minor() {
+        let input = indoc! {"
+            v0.172.8
+            v0.172.9
+            v0.172.11
+        "};
+        let list = VersionTagList::from_str(input).unwrap().sorted();
+        let target = VersionTag::parse("v0.172.8").unwrap();
+        assert!(list.find_previous_minor_version(&target).is_none());
+    }
+
+    #[test]
+    fn find_previous_minor_with_major_version_gap() {
+        let input = indoc! {"
+            v0.172.0
+            v0.172.9
+            v0.172.11
+        "};
+        let list = VersionTagList::from_str(input).unwrap().sorted();
+        let target = VersionTag::parse("v1.0.0").unwrap();
+        let previous = list.find_previous_minor_version(&target).unwrap();
+        assert_eq!(previous.to_string(), "v0.172.0");
+    }
+
+    #[test]
+    fn find_previous_minor_requires_zero_patch_version() {
+        let input = indoc! {"
+            v0.172.1
+            v0.172.9
+            v0.172.11
+        "};
+        let list = VersionTagList::from_str(input).unwrap().sorted();
+        let target = VersionTag::parse("v1.0.0").unwrap();
+        assert!(list.find_previous_minor_version(&target).is_none());
+    }
+
+    #[test]
+    fn parse_tag_list_from_real_tags() {
+        let input = indoc! {"
+            v0.9999-temporary
+            vConradTest
+            v0.172.8
+        "};
+        let list = VersionTagList::from_str(input).unwrap();
+        assert_eq!(list.0.len(), 1);
+        assert_eq!(list.0[0].to_string(), "v0.172.8");
+    }
+
+    #[test]
+    fn parse_empty_tag_list_fails() {
+        let result = VersionTagList::from_str("");
+        assert!(result.is_err());
+    }
+
+    // --- CommitDetails / CommitList parsing ---
+
+    #[test]
+    fn pr_number_from_squash_merge_title() {
+        let line = format!(
+            "abc123{d}Author Name{d}author@email.com{d}Add cool feature (#12345)",
+            d = CommitDetails::FIELD_DELIMITER
+        );
+        let commit = CommitDetails::parse(&line, "").unwrap();
+        assert_eq!(commit.pr_number(), Some(12345));
+    }
+
+    #[test]
+    fn pr_number_missing() {
+        let line = format!(
+            "abc123{d}Author Name{d}author@email.com{d}Some commit without PR ref",
+            d = CommitDetails::FIELD_DELIMITER
+        );
+        let commit = CommitDetails::parse(&line, "").unwrap();
+        assert_eq!(commit.pr_number(), None);
+    }
+
+    #[test]
+    fn pr_number_takes_last_match() {
+        let line = format!(
+            "abc123{d}Author Name{d}author@email.com{d}Fix (#123) and refactor (#456)",
+            d = CommitDetails::FIELD_DELIMITER
+        );
+        let commit = CommitDetails::parse(&line, "").unwrap();
+        assert_eq!(commit.pr_number(), Some(456));
+    }
+
+    #[test]
+    fn co_authors_parsed_from_body() {
+        let line = format!(
+            "abc123{d}Author Name{d}author@email.com{d}Some title",
+            d = CommitDetails::FIELD_DELIMITER
+        );
+        let body = indoc! {"
+            Co-authored-by: Alice Smith <alice@example.com>
+            Co-authored-by: Bob Jones <bob@example.com>
+        "};
+        let commit = CommitDetails::parse(&line, body).unwrap();
+        let co_authors = commit.co_authors().unwrap();
+        assert_eq!(co_authors.len(), 2);
+        assert_eq!(
+            co_authors[0],
+            Committer::new("Alice Smith", "alice@example.com")
+        );
+        assert_eq!(
+            co_authors[1],
+            Committer::new("Bob Jones", "bob@example.com")
+        );
+    }
+
+    #[test]
+    fn no_co_authors_returns_none() {
+        let line = format!(
+            "abc123{d}Author Name{d}author@email.com{d}Some title",
+            d = CommitDetails::FIELD_DELIMITER
+        );
+        let commit = CommitDetails::parse(&line, "").unwrap();
+        assert!(commit.co_authors().is_none());
+    }
+
+    #[test]
+    fn commit_sha_short_returns_first_8_chars() {
+        let sha = CommitSha("abcdef1234567890abcdef1234567890abcdef12".into());
+        assert_eq!(sha.short(), "abcdef12");
+    }
+
+    #[test]
+    fn parse_commit_list_from_git_log_format() {
+        let fd = CommitDetails::FIELD_DELIMITER;
+        let bd = CommitDetails::BODY_DELIMITER;
+        let cd = CommitDetails::COMMIT_DELIMITER;
+
+        let input = format!(
+            "sha111{fd}Alice{fd}alice@test.com{fd}First commit (#100){bd}First body{cd}sha222{fd}Bob{fd}bob@test.com{fd}Second commit (#200){bd}Second body{cd}"
+        );
+
+        let list = CommitList::from_str(&input).unwrap();
+        assert_eq!(list.0.len(), 2);
+
+        assert_eq!(list.0[0].sha().0, "sha111");
+        assert_eq!(
+            list.0[0].author(),
+            &Committer::new("Alice", "alice@test.com")
+        );
+        assert_eq!(list.0[0].title(), "First commit (#100)");
+        assert_eq!(list.0[0].pr_number(), Some(100));
+        assert_eq!(list.0[0].body, "First body");
+
+        assert_eq!(list.0[1].sha().0, "sha222");
+        assert_eq!(list.0[1].author(), &Committer::new("Bob", "bob@test.com"));
+        assert_eq!(list.0[1].title(), "Second commit (#200)");
+        assert_eq!(list.0[1].pr_number(), Some(200));
+        assert_eq!(list.0[1].body, "Second body");
+    }
+}

tooling/compliance/src/github.rs 🔗

@@ -0,0 +1,424 @@
+use std::{collections::HashMap, fmt, ops::Not, rc::Rc};
+
+use anyhow::Result;
+use derive_more::Deref;
+use serde::Deserialize;
+
+use crate::git::CommitSha;
+
+/// Label ensured on pull requests that still need review
+/// (see `GitHubApiClient::ensure_pull_request_has_label`).
+pub const PR_REVIEW_LABEL: &str = "PR state:needs review";
+
+/// A GitHub account referenced by its login name.
+#[derive(Debug, Clone)]
+pub struct GitHubUser {
+    pub login: String,
+}
+
+/// Minimal pull-request data used by the compliance checks.
+#[derive(Debug, Clone)]
+pub struct PullRequestData {
+    pub number: u64,
+    // Author of the PR; `None` when the API omits it.
+    pub user: Option<GitHubUser>,
+    // Who merged the PR; `None` when the API omits it (e.g. unmerged).
+    pub merged_by: Option<GitHubUser>,
+}
+
+/// Review verdict, collapsed to the only distinction the checks care about.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ReviewState {
+    Approved,
+    // Any non-approved state (changes requested, commented, …).
+    Other,
+}
+
+/// A single review on a pull request.
+#[derive(Debug, Clone)]
+pub struct PullRequestReview {
+    pub user: Option<GitHubUser>,
+    pub state: Option<ReviewState>,
+}
+
+/// An issue-style comment on a pull request.
+#[derive(Debug, Clone)]
+pub struct PullRequestComment {
+    pub user: GitHubUser,
+    pub body: Option<String>,
+}
+
+/// A GitHub username, deserialized from the GraphQL `user { login }` field.
+#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
+pub struct GithubLogin {
+    login: String,
+}
+
+impl GithubLogin {
+    /// Wraps a raw login string.
+    pub(crate) fn new(login: String) -> Self {
+        Self { login }
+    }
+}
+
+impl fmt::Display for GithubLogin {
+    /// Formats as an `@login` mention.
+    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(formatter, "@{}", self.login)
+    }
+}
+
+/// Commit author data from the GraphQL commit `author`/`authors` fields.
+#[derive(Debug, Deserialize, Clone)]
+pub struct CommitAuthor {
+    name: String,
+    email: String,
+    // Linked GitHub account, when GitHub can associate one with the email.
+    user: Option<GithubLogin>,
+}
+
+impl CommitAuthor {
+    /// The linked GitHub login, if any.
+    pub(crate) fn user(&self) -> Option<&GithubLogin> {
+        self.user.as_ref()
+    }
+}
+
+// NOTE(review): equality is deliberately loose — when both sides have a
+// GitHub login, compare logins; otherwise match on EITHER email OR name.
+// That fallback makes the relation non-transitive in edge cases (a matches
+// b by email, b matches c by name); confirm callers only use it for
+// de-duplicating the primary author from the co-author list.
+impl PartialEq for CommitAuthor {
+    fn eq(&self, other: &Self) -> bool {
+        self.user.as_ref().zip(other.user.as_ref()).map_or_else(
+            || self.email == other.email || self.name == other.name,
+            |(l, r)| l == r,
+        )
+    }
+}
+
+impl fmt::Display for CommitAuthor {
+    /// Shows `Name (@login)` when the GitHub account is known,
+    /// otherwise `Name (email)`.
+    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+        if let Some(user) = self.user.as_ref() {
+            write!(formatter, "{} ({user})", self.name)
+        } else {
+            write!(formatter, "{} ({})", self.name, self.email)
+        }
+    }
+}
+
+/// Primary author plus co-authors for a single commit, as returned by the
+/// GraphQL `author` / `authors(first: N)` fields.
+#[derive(Debug, Deserialize)]
+pub struct CommitAuthors {
+    #[serde(rename = "author")]
+    primary_author: CommitAuthor,
+    #[serde(rename = "authors")]
+    co_authors: Vec<CommitAuthor>,
+}
+
+impl CommitAuthors {
+    /// Iterates the co-authors that differ from the primary author,
+    /// or `None` when the commit lists no co-authors at all.
+    pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
+        if self.co_authors.is_empty() {
+            return None;
+        }
+        let primary = &self.primary_author;
+        Some(
+            self.co_authors
+                .iter()
+                .filter(move |co_author| *co_author != primary),
+        )
+    }
+}
+
+/// Map from commit SHA to that commit's author data.
+#[derive(Debug, Deserialize, Deref)]
+pub struct AuthorsForCommits(HashMap<CommitSha, CommitAuthors>);
+
+/// Abstraction over the subset of the GitHub API used by the compliance
+/// checks; implemented by the octocrab-backed client below (and presumably
+/// by test doubles elsewhere — confirm against the test scenarios).
+#[async_trait::async_trait(?Send)]
+pub trait GitHubApiClient {
+    async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData>;
+    async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>>;
+    async fn get_pull_request_comments(&self, pr_number: u64) -> Result<Vec<PullRequestComment>>;
+    async fn get_commit_authors(&self, commit_shas: &[&CommitSha]) -> Result<AuthorsForCommits>;
+    async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool>;
+    async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()>;
+}
+
+/// High-level facade delegating to a [`GitHubApiClient`] implementation.
+pub struct GitHubClient {
+    api: Rc<dyn GitHubApiClient>,
+}
+
+impl GitHubClient {
+    /// Wraps an API implementation (real or fake) behind the facade.
+    pub fn new(api: Rc<dyn GitHubApiClient>) -> Self {
+        Self { api }
+    }
+
+    /// Builds a client authenticated as a GitHub App installation.
+    #[cfg(feature = "octo-client")]
+    pub async fn for_app(app_id: u64, app_private_key: &str) -> Result<Self> {
+        let client = OctocrabClient::new(app_id, app_private_key).await?;
+        Ok(Self::new(Rc::new(client)))
+    }
+
+    /// Fetches a pull request by number.
+    pub async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData> {
+        self.api.get_pull_request(pr_number).await
+    }
+
+    /// Fetches all reviews on a pull request.
+    pub async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>> {
+        self.api.get_pull_request_reviews(pr_number).await
+    }
+
+    /// Fetches all issue-style comments on a pull request.
+    pub async fn get_pull_request_comments(
+        &self,
+        pr_number: u64,
+    ) -> Result<Vec<PullRequestComment>> {
+        self.api.get_pull_request_comments(pr_number).await
+    }
+
+    /// Fetches author/co-author data for the given commits in one batch.
+    pub async fn get_commit_authors<'a>(
+        &self,
+        commit_shas: impl IntoIterator<Item = &'a CommitSha>,
+    ) -> Result<AuthorsForCommits> {
+        let shas: Vec<&CommitSha> = commit_shas.into_iter().collect();
+        self.api.get_commit_authors(&shas).await
+    }
+
+    /// Checks whether `login` is a member of the organization.
+    pub async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool> {
+        self.api.check_org_membership(login).await
+    }
+
+    /// Ensures `label` is present on the pull request (no-op if already set).
+    pub async fn add_label_to_pull_request(&self, label: &str, pr_number: u64) -> Result<()> {
+        self.api
+            .ensure_pull_request_has_label(label, pr_number)
+            .await
+    }
+}
+
+#[cfg(feature = "octo-client")]
+mod octo_client {
+    use anyhow::{Context, Result};
+    use futures::TryStreamExt as _;
+    use itertools::Itertools;
+    use jsonwebtoken::EncodingKey;
+    use octocrab::{
+        Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
+        service::middleware::cache::mem::InMemoryCache,
+    };
+    use serde::de::DeserializeOwned;
+    use tokio::pin;
+
+    use crate::git::CommitSha;
+
+    use super::{
+        AuthorsForCommits, GitHubApiClient, GitHubUser, GithubLogin, PullRequestComment,
+        PullRequestData, PullRequestReview, ReviewState,
+    };
+
+    const PAGE_SIZE: u8 = 100;
+    const ORG: &str = "zed-industries";
+    const REPO: &str = "zed";
+
+    /// [`GitHubApiClient`] implementation backed by octocrab, authenticated
+    /// as a GitHub App installation in the org.
+    pub struct OctocrabClient {
+        client: Octocrab,
+    }
+
+    impl OctocrabClient {
+        /// Builds an installation-scoped client for the org's App install.
+        pub async fn new(app_id: u64, app_private_key: &str) -> Result<Self> {
+            let octocrab = Octocrab::builder()
+                .cache(InMemoryCache::new())
+                .app(
+                    app_id.into(),
+                    EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
+                )
+                .build()?;
+
+            let installations = octocrab
+                .apps()
+                .installations()
+                .send()
+                .await
+                .context("Failed to fetch installations")?
+                .take_items();
+
+            let installation_id = installations
+                .into_iter()
+                .find(|installation| installation.account.login == ORG)
+                .context("Could not find Zed repository in installations")?
+                .id;
+
+            let client = octocrab.installation(installation_id)?;
+            Ok(Self { client })
+        }
+
+        /// Builds a single GraphQL query fetching author + co-author data for
+        /// every given commit, aliased as `commit<sha>` per object.
+        fn build_co_authors_query<'a>(shas: impl IntoIterator<Item = &'a CommitSha>) -> String {
+            const FRAGMENT: &str = r#"
+                ... on Commit {
+                    author {
+                        name
+                        email
+                        user { login }
+                    }
+                    authors(first: 10) {
+                        nodes {
+                            name
+                            email
+                            user { login }
+                        }
+                    }
+                }
+            "#;
+
+            let objects: String = shas
+                .into_iter()
+                .map(|commit_sha| {
+                    format!(
+                        "commit{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
+                        sha = **commit_sha
+                    )
+                })
+                .join("\n");
+
+            format!("{{  repository(owner: \"{ORG}\", name: \"{REPO}\") {{ {objects}  }} }}")
+                .replace("\n", "")
+        }
+
+        async fn graphql<R: octocrab::FromResponse>(
+            &self,
+            query: &serde_json::Value,
+        ) -> octocrab::Result<R> {
+            self.client.graphql(query).await
+        }
+
+        /// Drains every page of a paginated result.
+        async fn get_all<T: DeserializeOwned + 'static>(
+            &self,
+            page: Page<T>,
+        ) -> octocrab::Result<Vec<T>> {
+            self.get_filtered(page, |_| true).await
+        }
+
+        /// Drains every page of a paginated result, keeping only items that
+        /// match `predicate`.
+        async fn get_filtered<T: DeserializeOwned + 'static>(
+            &self,
+            page: Page<T>,
+            predicate: impl Fn(&T) -> bool,
+        ) -> octocrab::Result<Vec<T>> {
+            let stream = page.into_stream(&self.client);
+            pin!(stream);
+
+            let mut results = Vec::new();
+
+            // Iterate the FULL stream. The previous
+            // `while let Some(item) = ... && predicate(&item)` form stopped at
+            // the first non-matching item (take_while semantics), silently
+            // dropping all later matches.
+            while let Some(item) = stream.try_next().await? {
+                if predicate(&item) {
+                    results.push(item);
+                }
+            }
+
+            Ok(results)
+        }
+    }
+
+    #[async_trait::async_trait(?Send)]
+    impl GitHubApiClient for OctocrabClient {
+        async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData> {
+            let pr = self.client.pulls(ORG, REPO).get(pr_number).await?;
+            Ok(PullRequestData {
+                number: pr.number,
+                user: pr.user.map(|user| GitHubUser { login: user.login }),
+                merged_by: pr.merged_by.map(|user| GitHubUser { login: user.login }),
+            })
+        }
+
+        async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>> {
+            let page = self
+                .client
+                .pulls(ORG, REPO)
+                .list_reviews(pr_number)
+                .per_page(PAGE_SIZE)
+                .send()
+                .await?;
+
+            let reviews = self.get_all(page).await?;
+
+            Ok(reviews
+                .into_iter()
+                .map(|review| PullRequestReview {
+                    user: review.user.map(|user| GitHubUser { login: user.login }),
+                    // Collapse every non-approved state into `Other`.
+                    state: review.state.map(|state| match state {
+                        OctocrabReviewState::Approved => ReviewState::Approved,
+                        _ => ReviewState::Other,
+                    }),
+                })
+                .collect())
+        }
+
+        async fn get_pull_request_comments(
+            &self,
+            pr_number: u64,
+        ) -> Result<Vec<PullRequestComment>> {
+            let page = self
+                .client
+                .issues(ORG, REPO)
+                .list_comments(pr_number)
+                .per_page(PAGE_SIZE)
+                .send()
+                .await?;
+
+            let comments = self.get_all(page).await?;
+
+            Ok(comments
+                .into_iter()
+                .map(|comment| PullRequestComment {
+                    user: GitHubUser {
+                        login: comment.user.login,
+                    },
+                    body: comment.body,
+                })
+                .collect())
+        }
+
+        async fn get_commit_authors(
+            &self,
+            commit_shas: &[&CommitSha],
+        ) -> Result<AuthorsForCommits> {
+            let query = Self::build_co_authors_query(commit_shas.iter().copied());
+            let query = serde_json::json!({ "query": query });
+            let mut response = self.graphql::<serde_json::Value>(&query).await?;
+
+            response
+                .get_mut("data")
+                .and_then(|data| data.get_mut("repository"))
+                .and_then(|repo| repo.as_object_mut())
+                .ok_or_else(|| anyhow::anyhow!("Unexpected response format!"))
+                .and_then(|commit_data| {
+                    let mut response_map = serde_json::Map::with_capacity(commit_data.len());
+
+                    for (key, value) in commit_data.iter_mut() {
+                        // Undo the `commit<sha>` aliasing from the query so the
+                        // map is keyed by plain SHA.
+                        let key_without_prefix = key.strip_prefix("commit").unwrap_or(key);
+                        // Flatten `authors.nodes` to `authors` to match the
+                        // `CommitAuthors` deserialization shape.
+                        if let Some(authors) = value.get_mut("authors") {
+                            if let Some(nodes) = authors.get("nodes") {
+                                *authors = nodes.clone();
+                            }
+                        }
+
+                        response_map.insert(key_without_prefix.to_owned(), value.clone());
+                    }
+
+                    serde_json::from_value(serde_json::Value::Object(response_map))
+                        .context("Failed to deserialize commit authors")
+                })
+        }
+
+        async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool> {
+            let page = self
+                .client
+                .orgs(ORG)
+                .list_members()
+                .per_page(PAGE_SIZE)
+                .send()
+                .await?;
+
+            let members = self.get_all(page).await?;
+
+            Ok(members
+                .into_iter()
+                .any(|member| member.login == login.as_str()))
+        }
+
+        async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()> {
+            // Propagate listing errors with `?` instead of the previous
+            // `is_ok_and(..)`, which silently skipped the add on failure.
+            let matching = self
+                .get_filtered(
+                    self.client
+                        .issues(ORG, REPO)
+                        .list_labels_for_issue(pr_number)
+                        .per_page(PAGE_SIZE)
+                        .send()
+                        .await?,
+                    |pr_label| pr_label.name == label,
+                )
+                .await?;
+
+            // Only add the label when it is not already present.
+            if matching.is_empty() {
+                self.client
+                    .issues(ORG, REPO)
+                    .add_labels(pr_number, &[label.to_owned()])
+                    .await?;
+            }
+
+            Ok(())
+        }
+    }
+}
+
+#[cfg(feature = "octo-client")]
+pub use octo_client::OctocrabClient;

tooling/compliance/src/report.rs 🔗

@@ -0,0 +1,446 @@
+use std::{
+    fs::{self, File},
+    io::{BufWriter, Write},
+    path::Path,
+};
+
+use anyhow::Context as _;
+use derive_more::Display;
+use itertools::{Either, Itertools};
+
+use crate::{
+    checks::{ReviewFailure, ReviewResult, ReviewSuccess},
+    git::CommitDetails,
+};
+
+/// Base URL used to link commit titles to their pull requests on GitHub.
+const PULL_REQUEST_BASE_URL: &str = "https://github.com/zed-industries/zed/pull";
+
+/// A single commit paired with the outcome (`R`) of its review check.
+#[derive(Debug)]
+pub struct ReportEntry<R> {
+    pub commit: CommitDetails,
+    /// Outcome payload: a `ReviewSuccess`, a `ReviewFailure`, or a full `ReviewResult`.
+    reason: R,
+}
+
+impl<R: ToString> ReportEntry<R> {
+    /// Renders the "Commit" cell: the escaped title linked to its PR when one
+    /// exists, otherwise the escaped plain title.
+    fn commit_cell(&self) -> String {
+        let title = escape_markdown_link_text(self.commit.title());
+
+        match self.commit.pr_number() {
+            Some(pr_number) => format!("[{title}]({PULL_REQUEST_BASE_URL}/{pr_number})"),
+            // Not link text, so only table (backslash/pipe/newline) escaping applies.
+            None => escape_markdown_table_text(self.commit.title()),
+        }
+    }
+
+    /// Renders the "PR" cell as `#<number>`, or an em dash when there is none.
+    fn pull_request_cell(&self) -> String {
+        self.commit
+            .pr_number()
+            .map(|pr_number| format!("#{pr_number}"))
+            .unwrap_or_else(|| "—".to_owned())
+    }
+
+    /// Renders the commit author, escaped for use inside a markdown table.
+    fn author_cell(&self) -> String {
+        escape_markdown_table_text(&self.commit.author().to_string())
+    }
+
+    /// Renders the outcome/failure reason, escaped for the markdown table.
+    fn reason_cell(&self) -> String {
+        escape_markdown_table_text(&self.reason.to_string())
+    }
+}
+
+impl ReportEntry<ReviewFailure> {
+    /// Buckets a failure into its table "Outcome": `Other` is a hard error;
+    /// every remaining variant means the commit simply was not reviewed.
+    fn issue_kind(&self) -> IssueKind {
+        match self.reason {
+            ReviewFailure::Other(_) => IssueKind::Error,
+            _ => IssueKind::NotReviewed,
+        }
+    }
+}
+
+impl ReportEntry<ReviewSuccess> {
+    /// Renders the "Reviewers" cell: the reviewer list on success, or an em
+    /// dash when the reviewer names could not be determined.
+    fn reviewers_cell(&self) -> String {
+        // Match the returned `Result` directly instead of borrowing the
+        // temporary (`match &…`), which previously forced a redundant
+        // `&reviewers` double reference (clippy: needless_borrow).
+        match self.reason.reviewers() {
+            Ok(reviewers) => escape_markdown_table_text(&reviewers),
+            Err(_) => "—".to_owned(),
+        }
+    }
+}
+
+/// Aggregate counts over a report's entries.
+#[derive(Debug, Default)]
+pub struct ReportSummary {
+    /// Distinct pull requests referenced by the checked commits.
+    pub pull_requests: usize,
+    /// Commits whose review check succeeded.
+    pub reviewed: usize,
+    /// Commits with no PR found, or whose PR lacks a review.
+    pub not_reviewed: usize,
+    /// Commits whose check failed for some other reason.
+    pub errors: usize,
+}
+
+/// Coarse overall outcome, used to pick the exit status and messaging.
+pub enum ReportReviewSummary {
+    MissingReviews,
+    MissingReviewsWithErrors,
+    NoIssuesFound,
+}
+
+impl ReportSummary {
+    /// Tallies the summary counts from the report entries in a single pass.
+    fn from_entries(entries: &[ReportEntry<ReviewResult>]) -> Self {
+        let mut seen_prs = std::collections::HashSet::new();
+        let mut summary = Self::default();
+
+        for entry in entries {
+            // PRs are deduplicated: several commits may share one PR number.
+            if let Some(pr_number) = entry.commit.pr_number() {
+                seen_prs.insert(pr_number);
+            }
+
+            if entry.reason.is_ok() {
+                summary.reviewed += 1;
+            } else if matches!(
+                entry.reason,
+                Err(ReviewFailure::NoPullRequestFound | ReviewFailure::Unreviewed)
+            ) {
+                summary.not_reviewed += 1;
+            } else if matches!(entry.reason, Err(ReviewFailure::Other(_))) {
+                summary.errors += 1;
+            }
+        }
+
+        summary.pull_requests = seen_prs.len();
+        summary
+    }
+
+    /// Classifies the overall review outcome for reporting purposes.
+    pub fn review_summary(&self) -> ReportReviewSummary {
+        if self.errors > 0 {
+            ReportReviewSummary::MissingReviewsWithErrors
+        } else if self.not_reviewed > 0 {
+            ReportReviewSummary::MissingReviews
+        } else {
+            ReportReviewSummary::NoIssuesFound
+        }
+    }
+
+    /// Whether any entry failed with a hard (`Other`) error.
+    fn has_errors(&self) -> bool {
+        self.errors > 0
+    }
+}
+
+// Variant order matters: `Ord` is derived and the issue table is sorted by
+// this kind, so `Error` rows are listed before `NotReviewed` rows.
+#[derive(Clone, Copy, Debug, Display, PartialEq, Eq, PartialOrd, Ord)]
+enum IssueKind {
+    #[display("Error")]
+    Error,
+    #[display("Not reviewed")]
+    NotReviewed,
+}
+
+/// Accumulates per-commit review outcomes and renders them as markdown.
+#[derive(Debug, Default)]
+pub struct Report {
+    entries: Vec<ReportEntry<ReviewResult>>,
+}
+
+impl Report {
+    /// Creates an empty report.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Records the review outcome for one commit.
+    pub fn add(&mut self, commit: CommitDetails, result: ReviewResult) {
+        self.entries.push(ReportEntry {
+            commit,
+            reason: result,
+        });
+    }
+
+    /// All entries whose check did not succeed.
+    ///
+    /// NOTE(review): despite the name this yields every `Err` variant —
+    /// unreviewed commits as well as hard `Other` errors.
+    pub fn errors(&self) -> impl Iterator<Item = &ReportEntry<ReviewResult>> {
+        self.entries.iter().filter(|entry| entry.reason.is_err())
+    }
+
+    /// Computes aggregate counts over the current entries.
+    pub fn summary(&self) -> ReportSummary {
+        ReportSummary::from_entries(&self.entries)
+    }
+
+    /// Consumes the report and writes it as a markdown document at `path`,
+    /// creating any missing parent directories first.
+    pub fn write_markdown(self, path: impl AsRef<Path>) -> anyhow::Result<()> {
+        let path = path.as_ref();
+
+        if let Some(parent) = path
+            .parent()
+            .filter(|parent| !parent.as_os_str().is_empty())
+        {
+            fs::create_dir_all(parent).with_context(|| {
+                format!(
+                    "Failed to create parent directory for markdown report at {}",
+                    path.display()
+                )
+            })?;
+        }
+
+        // Summarize before `self.entries` is consumed by the partition below.
+        let summary = self.summary();
+        // Split entries into successes and failures, unwrapping the `Result`
+        // payload into the entry's `reason` on each side.
+        let (successes, mut issues): (Vec<_>, Vec<_>) =
+            self.entries
+                .into_iter()
+                .partition_map(|entry| match entry.reason {
+                    Ok(success) => Either::Left(ReportEntry {
+                        reason: success,
+                        commit: entry.commit,
+                    }),
+                    Err(fail) => Either::Right(ReportEntry {
+                        reason: fail,
+                        commit: entry.commit,
+                    }),
+                });
+
+        // `IssueKind` derives `Ord` with `Error` before `NotReviewed`, so hard
+        // errors sort to the top of the issue table.
+        issues.sort_by_key(|entry| entry.issue_kind());
+
+        let file = File::create(path)
+            .with_context(|| format!("Failed to create markdown report at {}", path.display()))?;
+        let mut writer = BufWriter::new(file);
+
+        writeln!(writer, "# Compliance report")?;
+        writeln!(writer)?;
+        writeln!(writer, "## Overview")?;
+        writeln!(writer)?;
+        writeln!(writer, "- PRs: {}", summary.pull_requests)?;
+        writeln!(writer, "- Reviewed: {}", summary.reviewed)?;
+        writeln!(writer, "- Not reviewed: {}", summary.not_reviewed)?;
+        // The error count line is omitted entirely when there are no errors.
+        if summary.has_errors() {
+            writeln!(writer, "- Errors: {}", summary.errors)?;
+        }
+        writeln!(writer)?;
+
+        write_issue_table(&mut writer, &issues, &summary)?;
+        write_success_table(&mut writer, &successes)?;
+
+        // Flush explicitly so write errors surface instead of being silently
+        // swallowed by `BufWriter`'s `Drop`.
+        writer
+            .flush()
+            .with_context(|| format!("Failed to flush markdown report to {}", path.display()))
+    }
+}
+
+/// Writes the issues table (errors and/or unreviewed commits) to `writer`.
+fn write_issue_table(
+    writer: &mut impl Write,
+    issues: &[ReportEntry<ReviewFailure>],
+    summary: &ReportSummary,
+) -> std::io::Result<()> {
+    // Both the heading and the empty-state text depend on whether any hard
+    // errors occurred, so select them together up front.
+    let (heading, empty_message) = if summary.has_errors() {
+        (
+            "## Errors and unreviewed commits",
+            "No errors or unreviewed commits found.",
+        )
+    } else {
+        ("## Unreviewed commits", "No unreviewed commits found.")
+    };
+
+    writeln!(writer, "{heading}")?;
+    writeln!(writer)?;
+
+    if issues.is_empty() {
+        writeln!(writer, "{empty_message}")?;
+        writeln!(writer)?;
+        return Ok(());
+    }
+
+    writeln!(writer, "| Commit | PR | Author | Outcome | Reason |")?;
+    writeln!(writer, "| --- | --- | --- | --- | --- |")?;
+
+    for entry in issues {
+        writeln!(
+            writer,
+            "| {} | {} | {} | {} | {} |",
+            entry.commit_cell(),
+            entry.pull_request_cell(),
+            entry.author_cell(),
+            entry.issue_kind(),
+            entry.reason_cell(),
+        )?;
+    }
+
+    writeln!(writer)?;
+    Ok(())
+}
+
+/// Writes the table of successfully reviewed commits to `writer`.
+fn write_success_table(
+    writer: &mut impl Write,
+    successful_entries: &[ReportEntry<ReviewSuccess>],
+) -> std::io::Result<()> {
+    writeln!(writer, "## Successful commits")?;
+    writeln!(writer)?;
+
+    if successful_entries.is_empty() {
+        writeln!(writer, "No successful commits found.")?;
+        writeln!(writer)?;
+        return Ok(());
+    }
+
+    writeln!(writer, "| Commit | PR | Author | Reviewers | Reason |")?;
+    writeln!(writer, "| --- | --- | --- | --- | --- |")?;
+
+    for entry in successful_entries {
+        // Assemble the cells first, then join into one markdown table row.
+        let cells = [
+            entry.commit_cell(),
+            entry.pull_request_cell(),
+            entry.author_cell(),
+            entry.reviewers_cell(),
+            entry.reason_cell(),
+        ];
+        writeln!(writer, "| {} |", cells.join(" | "))?;
+    }
+
+    writeln!(writer)?;
+    Ok(())
+}
+
+/// Escapes text for use as markdown link text: table escaping plus the square
+/// brackets that would otherwise terminate the link label.
+fn escape_markdown_link_text(input: &str) -> String {
+    // Table escaping must run first so the backslashes it inserts are not
+    // re-escaped when the brackets are handled.
+    let table_safe = escape_markdown_table_text(input);
+    table_safe.replace('[', r"\[").replace(']', r"\]")
+}
+
+/// Escapes text for a markdown table cell: backslashes and pipes are escaped,
+/// carriage returns are dropped, and newlines become `<br>` so a multi-line
+/// value stays inside one cell.
+fn escape_markdown_table_text(input: &str) -> String {
+    // Single pass over the characters instead of four chained `replace` calls;
+    // the per-character mapping is identical.
+    let mut escaped = String::with_capacity(input.len());
+    for ch in input.chars() {
+        match ch {
+            '\\' => escaped.push_str(r"\\"),
+            '|' => escaped.push_str(r"\|"),
+            '\r' => {}
+            '\n' => escaped.push_str("<br>"),
+            other => escaped.push(other),
+        }
+    }
+    escaped
+}
+
+#[cfg(test)]
+mod tests {
+    use std::str::FromStr;
+
+    use crate::{
+        checks::{ReviewFailure, ReviewSuccess},
+        git::{CommitDetails, CommitList},
+        github::{GitHubUser, PullRequestReview, ReviewState},
+    };
+
+    use super::{Report, ReportReviewSummary};
+
+    /// Builds a `CommitDetails` by round-tripping through the `CommitList`
+    /// text format (the same delimiters the git log parser consumes).
+    fn make_commit(
+        sha: &str,
+        author_name: &str,
+        author_email: &str,
+        title: &str,
+        body: &str,
+    ) -> CommitDetails {
+        let formatted = format!(
+            "{sha}|field-delimiter|{author_name}|field-delimiter|{author_email}|field-delimiter|{title}|body-delimiter|{body}|commit-delimiter|"
+        );
+        CommitList::from_str(&formatted)
+            .expect("test commit should parse")
+            .into_iter()
+            .next()
+            .expect("should have one commit")
+    }
+
+    /// A minimal approved-review success payload shared by the tests below.
+    fn reviewed() -> ReviewSuccess {
+        ReviewSuccess::PullRequestReviewed(vec![PullRequestReview {
+            user: Some(GitHubUser {
+                login: "reviewer".to_owned(),
+            }),
+            state: Some(ReviewState::Approved),
+        }])
+    }
+
+    // One entry of each kind: the counts must land in the right buckets, and
+    // `NoPullRequestFound` must not contribute to the PR count.
+    #[test]
+    fn report_summary_counts_are_accurate() {
+        let mut report = Report::new();
+
+        report.add(
+            make_commit(
+                "aaa",
+                "Alice",
+                "alice@test.com",
+                "Reviewed commit (#100)",
+                "",
+            ),
+            Ok(reviewed()),
+        );
+        report.add(
+            make_commit("bbb", "Bob", "bob@test.com", "Unreviewed commit (#200)", ""),
+            Err(ReviewFailure::Unreviewed),
+        );
+        report.add(
+            make_commit("ccc", "Carol", "carol@test.com", "No PR commit", ""),
+            Err(ReviewFailure::NoPullRequestFound),
+        );
+        report.add(
+            make_commit("ddd", "Dave", "dave@test.com", "Error commit (#300)", ""),
+            Err(ReviewFailure::Other(anyhow::anyhow!("some error"))),
+        );
+
+        let summary = report.summary();
+        assert_eq!(summary.pull_requests, 3);
+        assert_eq!(summary.reviewed, 1);
+        assert_eq!(summary.not_reviewed, 2);
+        assert_eq!(summary.errors, 1);
+    }
+
+    #[test]
+    fn report_summary_all_reviewed_is_no_issues() {
+        let mut report = Report::new();
+
+        report.add(
+            make_commit("aaa", "Alice", "alice@test.com", "First (#100)", ""),
+            Ok(reviewed()),
+        );
+        report.add(
+            make_commit("bbb", "Bob", "bob@test.com", "Second (#200)", ""),
+            Ok(reviewed()),
+        );
+
+        let summary = report.summary();
+        assert!(matches!(
+            summary.review_summary(),
+            ReportReviewSummary::NoIssuesFound
+        ));
+    }
+
+    // Missing reviews without hard errors must map to `MissingReviews`.
+    #[test]
+    fn report_summary_missing_reviews_only() {
+        let mut report = Report::new();
+
+        report.add(
+            make_commit("aaa", "Alice", "alice@test.com", "Reviewed (#100)", ""),
+            Ok(reviewed()),
+        );
+        report.add(
+            make_commit("bbb", "Bob", "bob@test.com", "Unreviewed (#200)", ""),
+            Err(ReviewFailure::Unreviewed),
+        );
+
+        let summary = report.summary();
+        assert!(matches!(
+            summary.review_summary(),
+            ReportReviewSummary::MissingReviews
+        ));
+    }
+
+    // Any hard error escalates the summary to `MissingReviewsWithErrors`.
+    #[test]
+    fn report_summary_errors_and_missing_reviews() {
+        let mut report = Report::new();
+
+        report.add(
+            make_commit("aaa", "Alice", "alice@test.com", "Unreviewed (#100)", ""),
+            Err(ReviewFailure::Unreviewed),
+        );
+        report.add(
+            make_commit("bbb", "Bob", "bob@test.com", "Errored (#200)", ""),
+            Err(ReviewFailure::Other(anyhow::anyhow!("check failed"))),
+        );
+
+        let summary = report.summary();
+        assert!(matches!(
+            summary.review_summary(),
+            ReportReviewSummary::MissingReviewsWithErrors
+        ));
+    }
+
+    // Two commits sharing PR #100 must count as a single pull request.
+    #[test]
+    fn report_summary_deduplicates_pull_requests() {
+        let mut report = Report::new();
+
+        report.add(
+            make_commit("aaa", "Alice", "alice@test.com", "First change (#100)", ""),
+            Ok(reviewed()),
+        );
+        report.add(
+            make_commit("bbb", "Bob", "bob@test.com", "Second change (#100)", ""),
+            Ok(reviewed()),
+        );
+
+        let summary = report.summary();
+        assert_eq!(summary.pull_requests, 1);
+    }
+}

tooling/xtask/Cargo.toml 🔗

@@ -15,7 +15,8 @@ backtrace.workspace = true
 cargo_metadata.workspace = true
 cargo_toml.workspace = true
 clap = { workspace = true, features = ["derive"] }
-toml.workspace = true
+compliance = { workspace = true, features = ["octo-client"] }
+gh-workflow.workspace = true
 indoc.workspace = true
 indexmap.workspace = true
 itertools.workspace = true
@@ -24,5 +25,6 @@ serde.workspace = true
 serde_json.workspace = true
 serde_yaml = "0.9.34"
 strum.workspace = true
+tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
+toml.workspace = true
 toml_edit.workspace = true
-gh-workflow.workspace = true

tooling/xtask/src/main.rs 🔗

@@ -15,6 +15,7 @@ struct Args {
 enum CliCommand {
     /// Runs `cargo clippy`.
     Clippy(tasks::clippy::ClippyArgs),
+    Compliance(tasks::compliance::ComplianceArgs),
     Licenses(tasks::licenses::LicensesArgs),
     /// Checks that packages conform to a set of standards.
     PackageConformity(tasks::package_conformity::PackageConformityArgs),
@@ -31,6 +32,7 @@ fn main() -> Result<()> {
 
     match args.command {
         CliCommand::Clippy(args) => tasks::clippy::run_clippy(args),
+        CliCommand::Compliance(args) => tasks::compliance::check_compliance(args),
         CliCommand::Licenses(args) => tasks::licenses::run_licenses(args),
         CliCommand::PackageConformity(args) => {
             tasks::package_conformity::run_package_conformity(args)

tooling/xtask/src/tasks.rs 🔗

@@ -1,4 +1,5 @@
 pub mod clippy;
+pub mod compliance;
 pub mod licenses;
 pub mod package_conformity;
 pub mod publish_gpui;

tooling/xtask/src/tasks/compliance.rs 🔗

@@ -0,0 +1,135 @@
+use std::path::PathBuf;
+
+use anyhow::{Context, Result};
+use clap::Parser;
+
+use compliance::{
+    checks::Reporter,
+    git::{CommitsFromVersionToHead, GetVersionTags, GitCommand, VersionTag},
+    github::GitHubClient,
+    report::ReportReviewSummary,
+};
+
+// The field comments below were plain `//` comments; clap's derive only picks
+// up `///` doc comments as `--help` text, so the flags rendered with no help.
+#[derive(Parser)]
+pub struct ComplianceArgs {
+    /// The version to be on the lookout for
+    #[arg(value_parser = VersionTag::parse)]
+    pub(crate) version_tag: VersionTag,
+    /// The markdown file to write the compliance report to
+    #[arg(long)]
+    report_path: PathBuf,
+    /// An optional branch to use instead of the determined version branch
+    #[arg(long)]
+    branch: Option<String>,
+}
+
+impl ComplianceArgs {
+    /// The version tag this compliance run targets.
+    pub(crate) fn version_tag(&self) -> &VersionTag {
+        &self.version_tag
+    }
+
+    /// The branch to scan: the explicit `--branch` override if given,
+    /// otherwise the `v{major}.{minor}.x` release branch derived from the tag.
+    fn version_branch(&self) -> String {
+        if let Some(branch) = &self.branch {
+            return branch.clone();
+        }
+        let version = self.version_tag().version();
+        format!("v{}.{}.x", version.major, version.minor)
+    }
+}
+
+/// Runs the compliance check end to end: finds the previous minor release,
+/// gathers the commit range, generates the review report, labels PRs that
+/// need attention, writes the markdown report, and returns an error when any
+/// commit is missing review (so CI fails).
+async fn check_compliance_impl(args: ComplianceArgs) -> Result<()> {
+    // GitHub App credentials are injected by the workflow environment.
+    let app_id = std::env::var("GITHUB_APP_ID").context("Missing GITHUB_APP_ID")?;
+    let key = std::env::var("GITHUB_APP_KEY").context("Missing GITHUB_APP_KEY")?;
+
+    let tag = args.version_tag();
+
+    // The previous minor version tag is the base of the range to audit.
+    let previous_version = GitCommand::run(GetVersionTags)?
+        .sorted()
+        .find_previous_minor_version(&tag)
+        .cloned()
+        // `VersionTag` is `Display` (it is interpolated elsewhere), so the old
+        // `tag = tag.to_string()` named argument was a redundant allocation.
+        .ok_or_else(|| anyhow::anyhow!("Could not find previous version for tag {tag}"))?;
+
+    println!(
+        "Checking compliance for version {} with version {} as base",
+        tag.version(),
+        previous_version.version()
+    );
+
+    let commits = GitCommand::run(CommitsFromVersionToHead::new(
+        previous_version,
+        args.version_branch(),
+    ))?;
+
+    let Some(range) = commits.range() else {
+        anyhow::bail!("No commits found to check");
+    };
+
+    println!("Checking commit range {range}, {} total", commits.len());
+
+    let client = GitHubClient::for_app(
+        app_id.parse().context("Failed to parse app ID as int")?,
+        key.as_ref(),
+    )
+    .await?;
+
+    println!("Initialized GitHub client for app ID {app_id}");
+
+    let report = Reporter::new(commits, &client).generate_report().await?;
+
+    // `to_string()` inside a format argument is redundant (clippy:
+    // to_string_in_format_args); `Display` is used directly.
+    println!("Generated report for version {}", args.version_tag());
+
+    let summary = report.summary();
+
+    println!(
+        "Applying compliance labels to {} pull requests",
+        summary.pull_requests
+    );
+
+    // `report.errors()` yields every failed entry (unreviewed commits as well
+    // as hard errors); each one with an associated PR gets the review label.
+    // Renamed from `report` to avoid shadowing the `Report` above.
+    for entry in report.errors() {
+        if let Some(pr_number) = entry.commit.pr_number() {
+            println!("Adding review label to PR {pr_number}...");
+
+            client
+                .add_label_to_pull_request(compliance::github::PR_REVIEW_LABEL, pr_number)
+                .await?;
+        }
+    }
+
+    // Normalize to a `.md` file regardless of the extension passed on the CLI.
+    let report_path = args.report_path.with_extension("md");
+
+    report.write_markdown(&report_path)?;
+
+    println!("Wrote compliance report to {}", report_path.display());
+
+    // The exit status mirrors the review summary.
+    match summary.review_summary() {
+        ReportReviewSummary::MissingReviews => Err(anyhow::anyhow!(
+            "Compliance check failed, found {} commits not reviewed",
+            summary.not_reviewed
+        )),
+        ReportReviewSummary::MissingReviewsWithErrors => Err(anyhow::anyhow!(
+            "Compliance check failed with {} unreviewed commits and {} other issues",
+            summary.not_reviewed,
+            summary.errors
+        )),
+        ReportReviewSummary::NoIssuesFound => {
+            println!("No issues found, compliance check passed.");
+            Ok(())
+        }
+    }
+}
+
+/// Entry point for `cargo xtask compliance`: builds a tokio runtime and
+/// blocks on the async implementation.
+pub fn check_compliance(args: ComplianceArgs) -> Result<()> {
+    let runtime =
+        tokio::runtime::Runtime::new().context("Failed to create tokio runtime")?;
+    runtime.block_on(check_compliance_impl(args))
+}

tooling/xtask/src/tasks/workflows.rs 🔗

@@ -11,6 +11,7 @@ mod autofix_pr;
 mod bump_patch_version;
 mod cherry_pick;
 mod compare_perf;
+mod compliance_check;
 mod danger;
 mod deploy_collab;
 mod extension_auto_bump;
@@ -197,6 +198,7 @@ pub fn run_workflows(args: GenerateWorkflowArgs) -> Result<()> {
         WorkflowFile::zed(bump_patch_version::bump_patch_version),
         WorkflowFile::zed(cherry_pick::cherry_pick),
         WorkflowFile::zed(compare_perf::compare_perf),
+        WorkflowFile::zed(compliance_check::compliance_check),
         WorkflowFile::zed(danger::danger),
         WorkflowFile::zed(deploy_collab::deploy_collab),
         WorkflowFile::zed(extension_bump::extension_bump),

tooling/xtask/src/tasks/workflows/compliance_check.rs 🔗

@@ -0,0 +1,66 @@
+use gh_workflow::{Event, Expression, Job, Run, Schedule, Step, Workflow};
+
+use crate::tasks::workflows::{
+    runners,
+    steps::{self, CommonJobConditions, named},
+    vars::{self, StepOutput},
+};
+
+/// Workflow entry point: runs the scheduled compliance check job.
+pub fn compliance_check() -> Workflow {
+    let check = scheduled_compliance_check();
+
+    named::workflow()
+        // Cron: minute 30, hour 17 (UTC), every Tuesday.
+        .on(Event::default().schedule([Schedule::new("30 17 * * 2")]))
+        .add_env(("CARGO_TERM_COLOR", "always"))
+        .add_job(check.name, check.job)
+}
+
+/// Builds the scheduled job: read the upcoming version from
+/// `crates/zed/Cargo.toml`, run the compliance xtask against `main`, and ping
+/// Slack if the check fails.
+fn scheduled_compliance_check() -> steps::NamedJob {
+    // Derives the upcoming preview tag (`v<version>-pre`) from the crate
+    // version and exposes it as the step output `tag`.
+    let determine_version_step = named::bash(indoc::indoc! {r#"
+        VERSION=$(sed -n 's/^version = "\(.*\)"/\1/p' crates/zed/Cargo.toml | tr -d '[:space:]')
+        if [ -z "$VERSION" ]; then
+            echo "Could not determine version from crates/zed/Cargo.toml"
+            exit 1
+        fi
+        TAG="v${VERSION}-pre"
+        echo "Checking compliance for $TAG"
+        echo "tag=$TAG" >> "$GITHUB_OUTPUT"
+    "#})
+    .id("determine-version");
+
+    // Reference to the `tag` output of the step above; presumably renders as a
+    // `${{ steps.determine-version.outputs.tag }}` expression — confirm
+    // `StepOutput`'s `Display` implementation.
+    let tag_output = StepOutput::new(&determine_version_step, "tag");
+
+    // The scheduled run checks `main` (not a release branch) for the tag.
+    fn run_compliance_check(tag: &StepOutput) -> Step<Run> {
+        named::bash(
+            r#"cargo xtask compliance "$LATEST_TAG" --branch main --report-path target/compliance-report"#,
+        )
+        .id("run-compliance-check")
+        .add_env(("LATEST_TAG", tag.to_string()))
+        .add_env(("GITHUB_APP_ID", vars::ZED_ZIPPY_APP_ID))
+        .add_env(("GITHUB_APP_KEY", vars::ZED_ZIPPY_APP_PRIVATE_KEY))
+    }
+
+    // Only runs when an earlier step failed (`failure()`).
+    fn send_failure_slack_notification(tag: &StepOutput) -> Step<Run> {
+        named::bash(indoc::indoc! {r#"
+            MESSAGE="⚠️ Scheduled compliance check failed for upcoming preview release $LATEST_TAG: There are PRs with missing reviews."
+
+            curl -X POST -H 'Content-type: application/json' \
+                --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \
+                "$SLACK_WEBHOOK"
+        "#})
+        .if_condition(Expression::new("failure()"))
+        .add_env(("SLACK_WEBHOOK", vars::SLACK_WEBHOOK_WORKFLOW_FAILURES))
+        .add_env(("LATEST_TAG", tag.to_string()))
+    }
+
+    named::job(
+        Job::default()
+            .with_repository_owner_guard()
+            .runs_on(runners::LINUX_SMALL)
+            // Full history is required so git can walk back to the previous tag.
+            .add_step(steps::checkout_repo().with_full_history())
+            .add_step(steps::cache_rust_dependencies_namespace())
+            .add_step(determine_version_step)
+            .add_step(run_compliance_check(&tag_output))
+            .add_step(send_failure_slack_notification(&tag_output)),
+    )
+}

tooling/xtask/src/tasks/workflows/release.rs 🔗

@@ -1,11 +1,13 @@
-use gh_workflow::{Event, Expression, Push, Run, Step, Use, Workflow, ctx::Context};
+use gh_workflow::{Event, Expression, Job, Push, Run, Step, Use, Workflow, ctx::Context};
 use indoc::formatdoc;
 
 use crate::tasks::workflows::{
     run_bundling::{bundle_linux, bundle_mac, bundle_windows},
     run_tests,
     runners::{self, Arch, Platform},
-    steps::{self, FluentBuilder, NamedJob, dependant_job, named, release_job},
+    steps::{
+        self, CommonJobConditions, FluentBuilder, NamedJob, dependant_job, named, release_job,
+    },
     vars::{self, StepOutput, assets},
 };
 
@@ -22,6 +24,7 @@ pub(crate) fn release() -> Workflow {
     let check_scripts = run_tests::check_scripts();
 
     let create_draft_release = create_draft_release();
+    let compliance = compliance_check();
 
     let bundle = ReleaseBundleJobs {
         linux_aarch64: bundle_linux(
@@ -92,6 +95,7 @@ pub(crate) fn release() -> Workflow {
         .add_job(windows_clippy.name, windows_clippy.job)
         .add_job(check_scripts.name, check_scripts.job)
         .add_job(create_draft_release.name, create_draft_release.job)
+        .add_job(compliance.name, compliance.job)
         .map(|mut workflow| {
             for job in bundle.into_jobs() {
                 workflow = workflow.add_job(job.name, job.job);
@@ -149,6 +153,59 @@ pub(crate) fn create_sentry_release() -> Step<Use> {
     .add_with(("environment", "production"))
 }
 
+/// Release-workflow job: runs the compliance check for the pushed tag and
+/// always posts the outcome (plus the markdown report, when present) to Slack.
+fn compliance_check() -> NamedJob {
+    fn run_compliance_check() -> Step<Run> {
+        // `COMPLIANCE_FILE_PATH` is the job-level env var defined below, so it
+        // is visible to every step in this job. The script previously
+        // referenced an undefined `$COMPLIANCE_FILE_OUTPUT`, so the report was
+        // written to an empty path and never found by the notification step.
+        named::bash(
+            r#"cargo xtask compliance "$GITHUB_REF_NAME" --report-path "$COMPLIANCE_FILE_PATH""#,
+        )
+        .id("run-compliance-check")
+        .add_env(("GITHUB_APP_ID", vars::ZED_ZIPPY_APP_ID))
+        .add_env(("GITHUB_APP_KEY", vars::ZED_ZIPPY_APP_PRIVATE_KEY))
+    }
+
+    // Runs unconditionally (`always()`) and reports pass/fail based on the
+    // compliance step's outcome, appending the markdown report when it exists.
+    fn send_compliance_slack_notification() -> Step<Run> {
+        named::bash(indoc::indoc! {r#"
+            if [ "$COMPLIANCE_OUTCOME" == "success" ]; then
+                STATUS="✅ Compliance check passed for $GITHUB_REF_NAME"
+            else
+                STATUS="❌ Compliance check failed for $GITHUB_REF_NAME"
+            fi
+
+            REPORT_CONTENT=""
+            if [ -f "$COMPLIANCE_FILE_PATH" ]; then
+                REPORT_CONTENT=$(cat "$COMPLIANCE_FILE_PATH")
+            fi
+
+            MESSAGE=$(printf "%s\n\n%s" "$STATUS" "$REPORT_CONTENT")
+
+            curl -X POST -H 'Content-type: application/json' \
+                --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \
+                "$SLACK_WEBHOOK"
+        "#})
+        .if_condition(Expression::new("always()"))
+        .add_env(("SLACK_WEBHOOK", vars::SLACK_WEBHOOK_WORKFLOW_FAILURES))
+        .add_env((
+            "COMPLIANCE_OUTCOME",
+            "${{ steps.run-compliance-check.outcome }}",
+        ))
+    }
+
+    named::job(
+        Job::default()
+            // Single source of truth for the report location; the xtask
+            // normalizes the extension to `.md`, so this path is stable.
+            .add_env(("COMPLIANCE_FILE_PATH", "compliance.md"))
+            .with_repository_owner_guard()
+            .runs_on(runners::LINUX_DEFAULT)
+            .add_step(
+                steps::checkout_repo()
+                    .with_full_history()
+                    .with_ref(Context::github().ref_()),
+            )
+            .add_step(steps::cache_rust_dependencies_namespace())
+            .add_step(run_compliance_check())
+            .add_step(send_compliance_slack_notification()),
+    )
+}
+
 fn validate_release_assets(deps: &[&NamedJob]) -> NamedJob {
     let expected_assets: Vec<String> = assets::all().iter().map(|a| format!("\"{a}\"")).collect();
     let expected_assets_json = format!("[{}]", expected_assets.join(", "));
@@ -171,10 +228,54 @@ fn validate_release_assets(deps: &[&NamedJob]) -> NamedJob {
         "#,
     };
 
+    /// Re-runs the compliance xtask for the release tag after assets are
+    /// validated; the report lands at `target/compliance-report.md`.
+    fn run_post_upload_compliance_check() -> Step<Run> {
+        named::bash(
+            r#"cargo xtask compliance "$GITHUB_REF_NAME" --report-path target/compliance-report"#,
+        )
+        .id("run-post-upload-compliance-check")
+        .add_env(("GITHUB_APP_ID", vars::ZED_ZIPPY_APP_ID))
+        .add_env(("GITHUB_APP_KEY", vars::ZED_ZIPPY_APP_PRIVATE_KEY))
+    }
+
+    /// Posts the post-upload re-check result to Slack; exits quietly when the
+    /// compliance step never ran (empty or `skipped` outcome).
+    fn send_post_upload_compliance_notification() -> Step<Run> {
+        named::bash(indoc::indoc! {r#"
+            if [ -z "$COMPLIANCE_OUTCOME" ] || [ "$COMPLIANCE_OUTCOME" == "skipped" ]; then
+                echo "Compliance check was skipped, not sending notification"
+                exit 0
+            fi
+
+            TAG="$GITHUB_REF_NAME"
+
+            if [ "$COMPLIANCE_OUTCOME" == "success" ]; then
+                MESSAGE="✅ Post-upload compliance re-check passed for $TAG"
+            else
+                MESSAGE="❌ Post-upload compliance re-check failed for $TAG"
+            fi
+
+            curl -X POST -H 'Content-type: application/json' \
+                --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \
+                "$SLACK_WEBHOOK"
+        "#})
+        .if_condition(Expression::new("always()"))
+        .add_env(("SLACK_WEBHOOK", vars::SLACK_WEBHOOK_WORKFLOW_FAILURES))
+        .add_env((
+            "COMPLIANCE_OUTCOME",
+            "${{ steps.run-post-upload-compliance-check.outcome }}",
+        ))
+    }
+
     named::job(
-        dependant_job(deps).runs_on(runners::LINUX_SMALL).add_step(
-            named::bash(&validation_script).add_env(("GITHUB_TOKEN", vars::GITHUB_TOKEN)),
-        ),
+        dependant_job(deps)
+            .runs_on(runners::LINUX_SMALL)
+            .add_step(named::bash(&validation_script).add_env(("GITHUB_TOKEN", vars::GITHUB_TOKEN)))
+            .add_step(
+                steps::checkout_repo()
+                    .with_full_history()
+                    .with_ref(Context::github().ref_()),
+            )
+            .add_step(steps::cache_rust_dependencies_namespace())
+            .add_step(run_post_upload_compliance_check())
+            .add_step(send_post_upload_compliance_notification()),
     )
 }
 
@@ -255,7 +356,7 @@ fn create_draft_release() -> NamedJob {
             .add_step(
                 steps::checkout_repo()
                     .with_custom_fetch_depth(25)
-                    .with_ref("${{ github.ref }}"),
+                    .with_ref(Context::github().ref_()),
             )
             .add_step(steps::script("script/determine-release-channel"))
             .add_step(steps::script("mkdir -p target/"))