diff --git a/.github/workflows/compliance_check.yml b/.github/workflows/compliance_check.yml new file mode 100644 index 0000000000000000000000000000000000000000..f09c460c233b04e78df01e7828b4def737dec16e --- /dev/null +++ b/.github/workflows/compliance_check.yml @@ -0,0 +1,55 @@ +# Generated from xtask::workflows::compliance_check +# Rebuild with `cargo xtask workflows`. +name: compliance_check +env: + CARGO_TERM_COLOR: always +on: + schedule: + - cron: 30 17 * * 2 +jobs: + scheduled_compliance_check: + if: (github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions') + runs-on: namespace-profile-2x4-ubuntu-2404 + steps: + - name: steps::checkout_repo + uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd + with: + clean: false + fetch-depth: 0 + - name: steps::cache_rust_dependencies_namespace + uses: namespacelabs/nscloud-cache-action@a90bb5d4b27522ce881c6e98eebd7d7e6d1653f9 + with: + cache: rust + path: ~/.rustup + - id: determine-version + name: compliance_check::scheduled_compliance_check + run: | + VERSION=$(sed -n 's/^version = "\(.*\)"/\1/p' crates/zed/Cargo.toml | tr -d '[:space:]') + if [ -z "$VERSION" ]; then + echo "Could not determine version from crates/zed/Cargo.toml" + exit 1 + fi + TAG="v${VERSION}-pre" + echo "Checking compliance for $TAG" + echo "tag=$TAG" >> "$GITHUB_OUTPUT" + - id: run-compliance-check + name: compliance_check::scheduled_compliance_check::run_compliance_check + run: cargo xtask compliance "$LATEST_TAG" --branch main --report-path target/compliance-report + env: + LATEST_TAG: ${{ steps.determine-version.outputs.tag }} + GITHUB_APP_ID: ${{ secrets.ZED_ZIPPY_APP_ID }} + GITHUB_APP_KEY: ${{ secrets.ZED_ZIPPY_APP_PRIVATE_KEY }} + - name: compliance_check::scheduled_compliance_check::send_failure_slack_notification + if: failure() + run: | + MESSAGE="⚠️ Scheduled compliance check failed for upcoming preview release $LATEST_TAG: There are PRs with missing reviews." + + curl -X POST -H 'Content-type: application/json' \ + --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \ + "$SLACK_WEBHOOK" + env: + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_WORKFLOW_FAILURES }} + LATEST_TAG: ${{ steps.determine-version.outputs.tag }} +defaults: + run: + shell: bash -euxo pipefail {0} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 35efafcfcd97c0139f8225ce7b15a05946c385ad..1401144ab3abda17dd4f526edd42166d37a47a49 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -293,6 +293,51 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} timeout-minutes: 60 + compliance_check: + if: (github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions') + runs-on: namespace-profile-16x32-ubuntu-2204 + env: + COMPLIANCE_FILE_PATH: compliance.md + steps: + - name: steps::checkout_repo + uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd + with: + clean: false + fetch-depth: 0 + ref: ${{ github.ref }} + - name: steps::cache_rust_dependencies_namespace + uses: namespacelabs/nscloud-cache-action@a90bb5d4b27522ce881c6e98eebd7d7e6d1653f9 + with: + cache: rust + path: ~/.rustup + - id: run-compliance-check + name: release::compliance_check::run_compliance_check + run: cargo xtask compliance "$GITHUB_REF_NAME" --report-path "$COMPLIANCE_FILE_OUTPUT" + env: + GITHUB_APP_ID: ${{ secrets.ZED_ZIPPY_APP_ID }} + GITHUB_APP_KEY: ${{ secrets.ZED_ZIPPY_APP_PRIVATE_KEY }} + - name: release::compliance_check::send_compliance_slack_notification + if: always() + run: | + if [ "$COMPLIANCE_OUTCOME" == "success" ]; then + STATUS="✅ Compliance check passed for $GITHUB_REF_NAME" + else + STATUS="❌ Compliance check failed for $GITHUB_REF_NAME" + fi + + REPORT_CONTENT="" + if [ -f "$COMPLIANCE_FILE_OUTPUT" ]; then + REPORT_CONTENT=$(cat "$REPORT_FILE") + fi + + MESSAGE=$(printf "%s\n\n%s" "$STATUS" "$REPORT_CONTENT") + + curl -X POST -H 'Content-type: application/json' \ + --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \ + "$SLACK_WEBHOOK" + env: + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_WORKFLOW_FAILURES }} + COMPLIANCE_OUTCOME: ${{ steps.run-compliance-check.outcome }} bundle_linux_aarch64: needs: - run_tests_linux @@ -613,6 +658,45 @@ jobs: echo "All expected assets are present in the release." env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: steps::checkout_repo + uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd + with: + clean: false + fetch-depth: 0 + ref: ${{ github.ref }} + - name: steps::cache_rust_dependencies_namespace + uses: namespacelabs/nscloud-cache-action@a90bb5d4b27522ce881c6e98eebd7d7e6d1653f9 + with: + cache: rust + path: ~/.rustup + - id: run-post-upload-compliance-check + name: release::validate_release_assets::run_post_upload_compliance_check + run: cargo xtask compliance "$GITHUB_REF_NAME" --report-path target/compliance-report + env: + GITHUB_APP_ID: ${{ secrets.ZED_ZIPPY_APP_ID }} + GITHUB_APP_KEY: ${{ secrets.ZED_ZIPPY_APP_PRIVATE_KEY }} + - name: release::validate_release_assets::send_post_upload_compliance_notification + if: always() + run: | + if [ -z "$COMPLIANCE_OUTCOME" ] || [ "$COMPLIANCE_OUTCOME" == "skipped" ]; then + echo "Compliance check was skipped, not sending notification" + exit 0 + fi + + TAG="$GITHUB_REF_NAME" + + if [ "$COMPLIANCE_OUTCOME" == "success" ]; then + MESSAGE="✅ Post-upload compliance re-check passed for $TAG" + else + MESSAGE="❌ Post-upload compliance re-check failed for $TAG" + fi + + curl -X POST -H 'Content-type: application/json' \ + --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \ + "$SLACK_WEBHOOK" + env: + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_WORKFLOW_FAILURES }} + COMPLIANCE_OUTCOME: ${{ steps.run-post-upload-compliance-check.outcome }} auto_release_preview: needs: - validate_release_assets diff --git a/Cargo.lock b/Cargo.lock index a782de048787ce7bde37ba16bbb869a4026b4a6d..94aca307210e19bc97c14002ba5f136edfd76778 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -677,6 +677,15 @@ dependencies = [ "derive_arbitrary", ] +[[package]] +name = "arc-swap" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a07d1f37ff60921c83bdfc7407723bdefe89b44b98a9b772f225c8f9d67141a6" +dependencies = [ + "rustversion", +] + [[package]] name = "arg_enum_proc_macro" version = "0.3.4" @@ -2530,6 +2539,16 @@ dependencies = [ "serde", ] +[[package]] +name = "cargo-platform" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87a0c0e6148f11f01f32650a2ea02d532b2ad4e81d8bd41e6e565b5adc5e6082" +dependencies = [ + "serde", + "serde_core", +] + [[package]] name = "cargo_metadata" version = "0.19.2" @@ -2537,7 +2556,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd5eb614ed4c27c5d706420e4320fbe3216ab31fa1c33cd8246ac36dae4479ba" dependencies = [ "camino", - "cargo-platform", + "cargo-platform 0.1.9", + "semver", + "serde", + "serde_json", + "thiserror 2.0.17", +] + +[[package]] +name = "cargo_metadata" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef987d17b0a113becdd19d3d0022d04d7ef41f9efe4f3fb63ac44ba61df3ade9" +dependencies = [ + "camino", + "cargo-platform 0.3.2", "semver", "serde", "serde_json", @@ -3284,6 +3317,25 @@ dependencies = [ "workspace", ] +[[package]] +name = "compliance" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "derive_more", + "futures 0.3.32", + "indoc", + "itertools 0.14.0", + "jsonwebtoken", + "octocrab", + "regex", + "semver", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "component" version = "0.1.0" @@ -8324,6 +8376,7 @@ dependencies = [ "http 1.3.1", "hyper 1.7.0", "hyper-util", + "log", "rustls 0.23.33", "rustls-native-certs 0.8.2", "rustls-pki-types", @@ -8332,6 +8385,19 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-timeout" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" +dependencies = [ + "hyper 1.7.0", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "hyper-tls" version = "0.5.0" @@ -10012,7 +10078,7 @@ dependencies = [ [[package]] name = "lsp-types" version = "0.95.1" -source = "git+https://github.com/zed-industries/lsp-types?rev=a4f410987660bf560d1e617cb78117c6b6b9f599#a4f410987660bf560d1e617cb78117c6b6b9f599" +source = "git+https://github.com/zed-industries/lsp-types?rev=c7396459fefc7886b4adfa3b596832405ae1e880#c7396459fefc7886b4adfa3b596832405ae1e880" dependencies = [ "bitflags 1.3.2", "serde", @@ -11384,6 +11450,48 @@ dependencies = [ "memchr", ] +[[package]] +name = "octocrab" +version = "0.49.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63f6687a23731011d0117f9f4c3cdabaa7b5e42ca671f42b5cc0657c492540e3" +dependencies = [ + "arc-swap", + "async-trait", + "base64 0.22.1", + "bytes 1.11.1", + "cargo_metadata 0.23.1", + "cfg-if", + "chrono", + "either", + "futures 0.3.32", + "futures-core", + "futures-util", + "getrandom 0.2.16", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "hyper 1.7.0", + "hyper-rustls 0.27.7", + "hyper-timeout", + "hyper-util", + "jsonwebtoken", + "once_cell", + "percent-encoding", + "pin-project", + "secrecy", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "snafu", + "tokio", + "tower 0.5.2", + "tower-http 0.6.6", + "url", + "web-time", +] + [[package]] name = "ollama" version = "0.1.0" @@ -15385,6 +15493,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "secrecy" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a" +dependencies = [ + "zeroize", +] + [[package]] name = "security-framework" version = "2.11.1" @@ -16089,6 +16206,27 @@ version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0f7a918bd2a9951d18ee6e48f076843e8e73a9a5d22cf05bcd4b7a81bdd04e17" +[[package]] +name = "snafu" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e84b3f4eacbf3a1ce05eac6763b4d629d60cbc94d632e4092c54ade71f1e1a2" +dependencies = [ + "snafu-derive", +] + +[[package]] +name = "snafu-derive" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "snippet" version = "0.1.0" @@ -18093,8 +18231,10 @@ dependencies = [ "pin-project-lite", "sync_wrapper 1.0.2", "tokio", + "tokio-util", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -18132,6 +18272,7 @@ dependencies = [ "tower 0.5.2", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -19978,6 +20119,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" dependencies = [ "js-sys", + "serde", "wasm-bindgen", ] @@ -21715,9 +21857,10 @@ dependencies = [ "annotate-snippets", "anyhow", "backtrace", - "cargo_metadata", + "cargo_metadata 0.19.2", "cargo_toml", "clap", + "compliance", "gh-workflow", "indexmap", "indoc", @@ -21727,6 +21870,7 @@ dependencies = [ "serde_json", "serde_yaml", "strum 0.27.2", + "tokio", "toml 0.8.23", "toml_edit 0.22.27", ] @@ -22402,6 +22546,7 @@ name = "zeta_prompt" version = "0.1.0" dependencies = [ "anyhow", + "imara-diff", "indoc", "serde", "strum 0.27.2", diff --git a/Cargo.toml b/Cargo.toml index 81bbb1176ddddcc117fc9082586cbc08dbb95d61..5cb5b991b645ec1b78b16f48493c7c8dc1426344 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -242,6 +242,7 @@ members = [ # Tooling # + "tooling/compliance", "tooling/perf", "tooling/xtask", ] @@ -289,6 +290,7 @@ collab_ui = { path = "crates/collab_ui" } collections = { path = "crates/collections", version = "0.1.0" } command_palette = { path = "crates/command_palette" } command_palette_hooks = { path = "crates/command_palette_hooks" } +compliance = { path = "tooling/compliance" } component = { path = "crates/component" } component_preview = { path = "crates/component_preview" } context_server = { path = "crates/context_server" } @@ -547,6 +549,7 @@ derive_more = { version = "2.1.1", features = [ "add_assign", "deref", "deref_mut", + "display", "from_str", "mul", "mul_assign", @@ -596,7 +599,7 @@ linkify = "0.10.0" libwebrtc = "0.3.26" livekit = { version = "0.7.32", features = ["tokio", "rustls-tls-native-roots"] } log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] } -lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "a4f410987660bf560d1e617cb78117c6b6b9f599" } +lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "c7396459fefc7886b4adfa3b596832405ae1e880" } mach2 = "0.5" markup5ever_rcdom = "0.3.0" metal = "0.33" diff --git a/assets/settings/default.json b/assets/settings/default.json index 5e1eb0e68d2f8a17f89422597aa29b99516333e8..63e906e3b11206fc458f8d7353f3ecba0abeb825 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -2417,6 +2417,7 @@ "toggle_relative_line_numbers": false, "use_system_clipboard": "always", "use_smartcase_find": false, + "use_regex_search": true, "gdefault": false, "highlight_on_yank_duration": 200, "custom_digraphs": {}, diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 0bcb8254c8b8123eef3faaa913bb360de8dcc76d..36c9fb40c4a573e09da05618a29c1898cced60ad 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1032,6 +1032,7 @@ pub struct AcpThread { connection: Rc, token_usage: Option, prompt_capabilities: acp::PromptCapabilities, + available_commands: Vec, _observe_prompt_capabilities: Task>, terminals: HashMap>, pending_terminal_output: HashMap>>, @@ -1220,6 +1221,7 @@ impl AcpThread { session_id, token_usage: None, prompt_capabilities, + available_commands: Vec::new(), _observe_prompt_capabilities: task, terminals: HashMap::default(), pending_terminal_output: HashMap::default(), @@ -1239,6 +1241,10 @@ impl AcpThread { self.prompt_capabilities.clone() } + pub fn available_commands(&self) -> &[acp::AvailableCommand] { + &self.available_commands + } + pub fn draft_prompt(&self) -> Option<&[acp::ContentBlock]> { self.draft_prompt.as_deref() } @@ -1419,7 +1425,10 @@ impl AcpThread { acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate { available_commands, .. - }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)), + }) => { + self.available_commands = available_commands.clone(); + cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)); + } acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate { current_mode_id, .. diff --git a/crates/agent/src/tests/edit_file_thread_test.rs b/crates/agent/src/tests/edit_file_thread_test.rs index 3beb5cb0d51abc55fbf3cf0849ced248a9d1fa5c..b5ce6441e790e0b79b2798dfe0008cc74eec69b8 100644 --- a/crates/agent/src/tests/edit_file_thread_test.rs +++ b/crates/agent/src/tests/edit_file_thread_test.rs @@ -202,3 +202,214 @@ async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) { ); }); } + +#[gpui::test] +async fn test_streaming_edit_json_parse_error_does_not_cause_unsaved_changes( + cx: &mut TestAppContext, +) { + super::init_test(cx); + super::always_allow_tools(cx); + + // Enable the streaming edit file tool feature flag. + cx.update(|cx| { + cx.update_flags(true, vec!["streaming-edit-file-tool".to_string()]); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/project"), + json!({ + "src": { + "main.rs": "fn main() {\n println!(\"Hello, world!\");\n}\n" + } + }), + ) + .await; + + let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let project_context = cx.new(|_cx| ProjectContext::default()); + let context_server_store = project.read_with(cx, |project, _| project.context_server_store()); + let context_server_registry = + cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + model.as_fake().set_supports_streaming_tools(true); + let fake_model = model.as_fake(); + + let thread = cx.new(|cx| { + let mut thread = crate::Thread::new( + project.clone(), + project_context, + context_server_registry, + crate::Templates::new(), + Some(model.clone()), + cx, + ); + let language_registry = project.read(cx).languages().clone(); + thread.add_tool(crate::StreamingEditFileTool::new( + project.clone(), + cx.weak_entity(), + thread.action_log().clone(), + language_registry, + )); + thread + }); + + let _events = thread + .update(cx, |thread, cx| { + thread.send( + UserMessageId::new(), + ["Write new content to src/main.rs"], + cx, + ) + }) + .unwrap(); + cx.run_until_parked(); + + let tool_use_id = "edit_1"; + let partial_1 = LanguageModelToolUse { + id: tool_use_id.into(), + name: EditFileTool::NAME.into(), + raw_input: json!({ + "display_description": "Rewrite main.rs", + "path": "project/src/main.rs", + "mode": "write" + }) + .to_string(), + input: json!({ + "display_description": "Rewrite main.rs", + "path": "project/src/main.rs", + "mode": "write" + }), + is_input_complete: false, + thought_signature: None, + }; + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_1)); + cx.run_until_parked(); + + let partial_2 = LanguageModelToolUse { + id: tool_use_id.into(), + name: EditFileTool::NAME.into(), + raw_input: json!({ + "display_description": "Rewrite main.rs", + "path": "project/src/main.rs", + "mode": "write", + "content": "fn main() { /* rewritten */ }" + }) + .to_string(), + input: json!({ + "display_description": "Rewrite main.rs", + "path": "project/src/main.rs", + "mode": "write", + "content": "fn main() { /* rewritten */ }" + }), + is_input_complete: false, + thought_signature: None, + }; + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_2)); + cx.run_until_parked(); + + // Now send a json parse error. At this point we have started writing content to the buffer. + fake_model.send_last_completion_stream_event( + LanguageModelCompletionEvent::ToolUseJsonParseError { + id: tool_use_id.into(), + tool_name: EditFileTool::NAME.into(), + raw_input: r#"{"display_description":"Rewrite main.rs","path":"project/src/main.rs","mode":"write","content":"fn main() { /* rewritten "#.into(), + json_parse_error: "EOF while parsing a string at line 1 column 95".into(), + }, + ); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // cx.executor().advance_clock(Duration::from_secs(5)); + // cx.run_until_parked(); + + assert!( + !fake_model.pending_completions().is_empty(), + "Thread should have retried after the error" + ); + + // Respond with a new, well-formed, complete edit_file tool use. + let tool_use = LanguageModelToolUse { + id: "edit_2".into(), + name: EditFileTool::NAME.into(), + raw_input: json!({ + "display_description": "Rewrite main.rs", + "path": "project/src/main.rs", + "mode": "write", + "content": "fn main() {\n println!(\"Hello, rewritten!\");\n}\n" + }) + .to_string(), + input: json!({ + "display_description": "Rewrite main.rs", + "path": "project/src/main.rs", + "mode": "write", + "content": "fn main() {\n println!(\"Hello, rewritten!\");\n}\n" + }), + is_input_complete: true, + thought_signature: None, + }; + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use)); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let pending_completions = fake_model.pending_completions(); + assert!( + pending_completions.len() == 1, + "Expected only the follow-up completion containing the successful tool result" + ); + + let completion = pending_completions + .into_iter() + .last() + .expect("Expected a completion containing the tool result for edit_2"); + + let tool_result = completion + .messages + .iter() + .flat_map(|msg| &msg.content) + .find_map(|content| match content { + language_model::MessageContent::ToolResult(result) + if result.tool_use_id == language_model::LanguageModelToolUseId::from("edit_2") => + { + Some(result) + } + _ => None, + }) + .expect("Should have a tool result for edit_2"); + + // Ensure that the second tool call completed successfully and edits were applied. + assert!( + !tool_result.is_error, + "Tool result should succeed, got: {:?}", + tool_result + ); + let content_text = match &tool_result.content { + language_model::LanguageModelToolResultContent::Text(t) => t.to_string(), + other => panic!("Expected text content, got: {:?}", other), + }; + assert!( + !content_text.contains("file has been modified since you last read it"), + "Did not expect a stale last-read error, got: {content_text}" + ); + assert!( + !content_text.contains("This file has unsaved changes"), + "Did not expect an unsaved-changes error, got: {content_text}" + ); + + let file_content = fs + .load(path!("/project/src/main.rs").as_ref()) + .await + .expect("file should exist"); + super::assert_eq!( + file_content, + "fn main() {\n println!(\"Hello, rewritten!\");\n}\n", + "The second edit should be applied and saved gracefully" + ); + + fake_model.end_last_completion_stream(); + cx.run_until_parked(); +} diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index f7b52b2573144e4c2fd378cfb19c9ee2473a37db..ff53136a0ded4bbc283fea30598d8d30e6e29709 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -3903,6 +3903,117 @@ async fn test_streaming_tool_completes_when_llm_stream_ends_without_final_input( }); } +#[gpui::test] +async fn test_streaming_tool_json_parse_error_is_forwarded_to_running_tool( + cx: &mut TestAppContext, +) { + init_test(cx); + always_allow_tools(cx); + + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + thread.update(cx, |thread, _cx| { + thread.add_tool(StreamingJsonErrorContextTool); + }); + + let _events = thread + .update(cx, |thread, cx| { + thread.send( + UserMessageId::new(), + ["Use the streaming_json_error_context tool"], + cx, + ) + }) + .unwrap(); + cx.run_until_parked(); + + let tool_use = LanguageModelToolUse { + id: "tool_1".into(), + name: StreamingJsonErrorContextTool::NAME.into(), + raw_input: r#"{"text": "partial"#.into(), + input: json!({"text": "partial"}), + is_input_complete: false, + thought_signature: None, + }; + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use)); + cx.run_until_parked(); + + fake_model.send_last_completion_stream_event( + LanguageModelCompletionEvent::ToolUseJsonParseError { + id: "tool_1".into(), + tool_name: StreamingJsonErrorContextTool::NAME.into(), + raw_input: r#"{"text": "partial"#.into(), + json_parse_error: "EOF while parsing a string at line 1 column 17".into(), + }, + ); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + cx.executor().advance_clock(Duration::from_secs(5)); + cx.run_until_parked(); + + let completion = fake_model + .pending_completions() + .pop() + .expect("No running turn"); + + let tool_results: Vec<_> = completion + .messages + .iter() + .flat_map(|message| &message.content) + .filter_map(|content| match content { + MessageContent::ToolResult(result) + if result.tool_use_id == language_model::LanguageModelToolUseId::from("tool_1") => + { + Some(result) + } + _ => None, + }) + .collect(); + + assert_eq!( + tool_results.len(), + 1, + "Expected exactly 1 tool result for tool_1, got {}: {:#?}", + tool_results.len(), + tool_results + ); + + let result = tool_results[0]; + assert!(result.is_error); + let content_text = match &result.content { + language_model::LanguageModelToolResultContent::Text(text) => text.to_string(), + other => panic!("Expected text content, got {:?}", other), + }; + assert!( + content_text.contains("Saw partial text 'partial' before invalid JSON"), + "Expected tool-enriched partial context, got: {content_text}" + ); + assert!( + content_text + .contains("Error parsing input JSON: EOF while parsing a string at line 1 column 17"), + "Expected forwarded JSON parse error, got: {content_text}" + ); + assert!( + !content_text.contains("tool input was not fully received"), + "Should not contain orphaned sender error, got: {content_text}" + ); + + fake_model.send_last_completion_stream_text_chunk("Done"); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + thread.read_with(cx, |thread, _cx| { + assert!( + thread.is_turn_complete(), + "Thread should not be stuck; the turn should have completed", + ); + }); +} + /// Filters out the stop events for asserting against in tests fn stop_events(result_events: Vec>) -> Vec { result_events @@ -3959,6 +4070,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { InfiniteTool::NAME: true, CancellationAwareTool::NAME: true, StreamingEchoTool::NAME: true, + StreamingJsonErrorContextTool::NAME: true, StreamingFailingEchoTool::NAME: true, TerminalTool::NAME: true, UpdatePlanTool::NAME: true, diff --git a/crates/agent/src/tests/test_tools.rs b/crates/agent/src/tests/test_tools.rs index f36549a6c42f9e810c7794d8ec683613b6ae6933..4744204fae1213d49af92339b8847e9d1f470125 100644 --- a/crates/agent/src/tests/test_tools.rs +++ b/crates/agent/src/tests/test_tools.rs @@ -56,13 +56,12 @@ impl AgentTool for StreamingEchoTool { fn run( self: Arc, - mut input: ToolInput, + input: ToolInput, _event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { let wait_until_complete_rx = self.wait_until_complete_rx.lock().unwrap().take(); cx.spawn(async move |_cx| { - while input.recv_partial().await.is_some() {} let input = input .recv() .await @@ -75,6 +74,68 @@ impl AgentTool for StreamingEchoTool { } } +#[derive(JsonSchema, Serialize, Deserialize)] +pub struct StreamingJsonErrorContextToolInput { + /// The text to echo. + pub text: String, +} + +pub struct StreamingJsonErrorContextTool; + +impl AgentTool for StreamingJsonErrorContextTool { + type Input = StreamingJsonErrorContextToolInput; + type Output = String; + + const NAME: &'static str = "streaming_json_error_context"; + + fn supports_input_streaming() -> bool { + true + } + + fn kind() -> acp::ToolKind { + acp::ToolKind::Other + } + + fn initial_title( + &self, + _input: Result, + _cx: &mut App, + ) -> SharedString { + "Streaming JSON Error Context".into() + } + + fn run( + self: Arc, + mut input: ToolInput, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + cx.spawn(async move |_cx| { + let mut last_partial_text = None; + + loop { + match input.next().await { + Ok(ToolInputPayload::Partial(partial)) => { + if let Some(text) = partial.get("text").and_then(|value| value.as_str()) { + last_partial_text = Some(text.to_string()); + } + } + Ok(ToolInputPayload::Full(input)) => return Ok(input.text), + Ok(ToolInputPayload::InvalidJson { error_message }) => { + let partial_text = last_partial_text.unwrap_or_default(); + return Err(format!( + "Saw partial text '{partial_text}' before invalid JSON: {error_message}" + )); + } + Err(error) => { + return Err(format!("Failed to receive tool input: {error}")); + } + } + } + }) + } +} + /// A streaming tool that echoes its input, used to test streaming tool /// lifecycle (e.g. partial delivery and cleanup when the LLM stream ends /// before `is_input_complete`). @@ -119,7 +180,7 @@ impl AgentTool for StreamingFailingEchoTool { ) -> Task> { cx.spawn(async move |_cx| { for _ in 0..self.receive_chunks_until_failure { - let _ = input.recv_partial().await; + let _ = input.next().await; } Err("failed".into()) }) diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index bcb5b7b2d2f3eb8cffd5be8b70fc08fef8e9fe37..ea342e8db4e4d97d5eccc849121cd0fd2e403017 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -22,13 +22,13 @@ use client::UserStore; use cloud_api_types::Plan; use collections::{HashMap, HashSet, IndexMap}; use fs::Fs; -use futures::stream; use futures::{ FutureExt, channel::{mpsc, oneshot}, future::Shared, stream::FuturesUnordered, }; +use futures::{StreamExt, stream}; use gpui::{ App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity, }; @@ -47,7 +47,6 @@ use schemars::{JsonSchema, Schema}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use settings::{LanguageModelSelection, Settings, ToolPermissionMode, update_settings_file}; -use smol::stream::StreamExt; use std::{ collections::BTreeMap, marker::PhantomData, @@ -2095,7 +2094,7 @@ impl Thread { this.update(cx, |this, _cx| { this.pending_message() .tool_results - .insert(tool_result.tool_use_id.clone(), tool_result); + .insert(tool_result.tool_use_id.clone(), tool_result) })?; Ok(()) } @@ -2195,15 +2194,15 @@ impl Thread { raw_input, json_parse_error, } => { - return Ok(Some(Task::ready( - self.handle_tool_use_json_parse_error_event( - id, - tool_name, - raw_input, - json_parse_error, - event_stream, - ), - ))); + return Ok(self.handle_tool_use_json_parse_error_event( + id, + tool_name, + raw_input, + json_parse_error, + event_stream, + cancellation_rx, + cx, + )); } UsageUpdate(usage) => { telemetry::event!( @@ -2304,12 +2303,12 @@ impl Thread { if !tool_use.is_input_complete { if tool.supports_input_streaming() { let running_turn = self.running_turn.as_mut()?; - if let Some(sender) = running_turn.streaming_tool_inputs.get(&tool_use.id) { + if let Some(sender) = running_turn.streaming_tool_inputs.get_mut(&tool_use.id) { sender.send_partial(tool_use.input); return None; } - let (sender, tool_input) = ToolInputSender::channel(); + let (mut sender, tool_input) = ToolInputSender::channel(); sender.send_partial(tool_use.input); running_turn .streaming_tool_inputs @@ -2331,13 +2330,13 @@ impl Thread { } } - if let Some(sender) = self + if let Some(mut sender) = self .running_turn .as_mut()? .streaming_tool_inputs .remove(&tool_use.id) { - sender.send_final(tool_use.input); + sender.send_full(tool_use.input); return None; } @@ -2410,10 +2409,12 @@ impl Thread { raw_input: Arc, json_parse_error: String, event_stream: &ThreadEventStream, - ) -> LanguageModelToolResult { + cancellation_rx: watch::Receiver, + cx: &mut Context, + ) -> Option> { let tool_use = LanguageModelToolUse { - id: tool_use_id.clone(), - name: tool_name.clone(), + id: tool_use_id, + name: tool_name, raw_input: raw_input.to_string(), input: serde_json::json!({}), is_input_complete: true, @@ -2426,14 +2427,43 @@ impl Thread { event_stream, ); - let tool_output = format!("Error parsing input JSON: {json_parse_error}"); - LanguageModelToolResult { - tool_use_id, - tool_name, - is_error: true, - content: LanguageModelToolResultContent::Text(tool_output.into()), - output: Some(serde_json::Value::String(raw_input.to_string())), + let tool = self.tool(tool_use.name.as_ref()); + + let Some(tool) = tool else { + let content = format!("No tool named {} exists", tool_use.name); + return Some(Task::ready(LanguageModelToolResult { + content: LanguageModelToolResultContent::Text(Arc::from(content)), + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: true, + output: None, + })); + }; + + let error_message = format!("Error parsing input JSON: {json_parse_error}"); + + if tool.supports_input_streaming() + && let Some(mut sender) = self + .running_turn + .as_mut()? + .streaming_tool_inputs + .remove(&tool_use.id) + { + sender.send_invalid_json(error_message); + return None; } + + log::debug!("Running tool {}. Received invalid JSON", tool_use.name); + let tool_input = ToolInput::invalid_json(error_message); + Some(self.run_tool( + tool, + tool_input, + tool_use.id, + tool_use.name, + event_stream, + cancellation_rx, + cx, + )) } fn send_or_update_tool_use( @@ -3114,8 +3144,7 @@ impl EventEmitter for Thread {} /// For streaming tools, partial JSON snapshots arrive via `.recv_partial()` as the LLM streams /// them, followed by the final complete input available through `.recv()`. pub struct ToolInput { - partial_rx: mpsc::UnboundedReceiver, - final_rx: oneshot::Receiver, + rx: mpsc::UnboundedReceiver>, _phantom: PhantomData, } @@ -3127,13 +3156,20 @@ impl ToolInput { } pub fn ready(value: serde_json::Value) -> Self { - let (partial_tx, partial_rx) = mpsc::unbounded(); - drop(partial_tx); - let (final_tx, final_rx) = oneshot::channel(); - final_tx.send(value).ok(); + let (tx, rx) = mpsc::unbounded(); + tx.unbounded_send(ToolInputPayload::Full(value)).ok(); Self { - partial_rx, - final_rx, + rx, + _phantom: PhantomData, + } + } + + pub fn invalid_json(error_message: String) -> Self { + let (tx, rx) = mpsc::unbounded(); + tx.unbounded_send(ToolInputPayload::InvalidJson { error_message }) + .ok(); + Self { + rx, _phantom: PhantomData, } } @@ -3147,65 +3183,89 @@ impl ToolInput { /// Wait for the final deserialized input, ignoring all partial updates. /// Non-streaming tools can use this to wait until the whole input is available. pub async fn recv(mut self) -> Result { - // Drain any remaining partials - while self.partial_rx.next().await.is_some() {} + while let Ok(value) = self.next().await { + match value { + ToolInputPayload::Full(value) => return Ok(value), + ToolInputPayload::Partial(_) => {} + ToolInputPayload::InvalidJson { error_message } => { + return Err(anyhow!(error_message)); + } + } + } + Err(anyhow!("tool input was not fully received")) + } + + pub async fn next(&mut self) -> Result> { let value = self - .final_rx + .rx + .next() .await - .map_err(|_| anyhow!("tool input was not fully received"))?; - serde_json::from_value(value).map_err(Into::into) - } + .ok_or_else(|| anyhow!("tool input was not fully received"))?; - /// Returns the next partial JSON snapshot, or `None` when input is complete. - /// Once this returns `None`, call `recv()` to get the final input. - pub async fn recv_partial(&mut self) -> Option { - self.partial_rx.next().await + Ok(match value { + ToolInputPayload::Partial(payload) => ToolInputPayload::Partial(payload), + ToolInputPayload::Full(payload) => { + ToolInputPayload::Full(serde_json::from_value(payload)?) + } + ToolInputPayload::InvalidJson { error_message } => { + ToolInputPayload::InvalidJson { error_message } + } + }) } fn cast(self) -> ToolInput { ToolInput { - partial_rx: self.partial_rx, - final_rx: self.final_rx, + rx: self.rx, _phantom: PhantomData, } } } +pub enum ToolInputPayload { + Partial(serde_json::Value), + Full(T), + InvalidJson { error_message: String }, +} + pub struct ToolInputSender { - partial_tx: mpsc::UnboundedSender, - final_tx: Option>, + has_received_final: bool, + tx: mpsc::UnboundedSender>, } impl ToolInputSender { pub(crate) fn channel() -> (Self, ToolInput) { - let (partial_tx, partial_rx) = mpsc::unbounded(); - let (final_tx, final_rx) = oneshot::channel(); + let (tx, rx) = mpsc::unbounded(); let sender = Self { - partial_tx, - final_tx: Some(final_tx), + tx, + has_received_final: false, }; let input = ToolInput { - partial_rx, - final_rx, + rx, _phantom: PhantomData, }; (sender, input) } pub(crate) fn has_received_final(&self) -> bool { - self.final_tx.is_none() + self.has_received_final } - pub(crate) fn send_partial(&self, value: serde_json::Value) { - self.partial_tx.unbounded_send(value).ok(); + pub fn send_partial(&mut self, payload: serde_json::Value) { + self.tx + .unbounded_send(ToolInputPayload::Partial(payload)) + .ok(); } - pub(crate) fn send_final(mut self, value: serde_json::Value) { - // Close the partial channel so recv_partial() returns None - self.partial_tx.close_channel(); - if let Some(final_tx) = self.final_tx.take() { - final_tx.send(value).ok(); - } + pub fn send_full(&mut self, payload: serde_json::Value) { + self.has_received_final = true; + self.tx.unbounded_send(ToolInputPayload::Full(payload)).ok(); + } + + pub fn send_invalid_json(&mut self, error_message: String) { + self.has_received_final = true; + self.tx + .unbounded_send(ToolInputPayload::InvalidJson { error_message }) + .ok(); } } @@ -4251,68 +4311,78 @@ mod tests { ) { let (thread, event_stream) = setup_thread_for_test(cx).await; - cx.update(|cx| { - thread.update(cx, |thread, _cx| { - let tool_use_id = LanguageModelToolUseId::from("test_tool_id"); - let tool_name: Arc = Arc::from("test_tool"); - let raw_input: Arc = Arc::from("{invalid json"); - let json_parse_error = "expected value at line 1 column 1".to_string(); - - // Call the function under test - let result = thread.handle_tool_use_json_parse_error_event( - tool_use_id.clone(), - tool_name.clone(), - raw_input.clone(), - json_parse_error, - &event_stream, - ); - - // Verify the result is an error - assert!(result.is_error); - assert_eq!(result.tool_use_id, tool_use_id); - assert_eq!(result.tool_name, tool_name); - assert!(matches!( - result.content, - LanguageModelToolResultContent::Text(_) - )); - - // Verify the tool use was added to the message content - { - let last_message = thread.pending_message(); - assert_eq!( - last_message.content.len(), - 1, - "Should have one tool_use in content" - ); - - match &last_message.content[0] { - AgentMessageContent::ToolUse(tool_use) => { - assert_eq!(tool_use.id, tool_use_id); - assert_eq!(tool_use.name, tool_name); - assert_eq!(tool_use.raw_input, raw_input.to_string()); - assert!(tool_use.is_input_complete); - // Should fall back to empty object for invalid JSON - assert_eq!(tool_use.input, json!({})); - } - _ => panic!("Expected ToolUse content"), - } - } - - // Insert the tool result (simulating what the caller does) - thread - .pending_message() - .tool_results - .insert(result.tool_use_id.clone(), result); + let tool_use_id = LanguageModelToolUseId::from("test_tool_id"); + let tool_name: Arc = Arc::from("test_tool"); + let raw_input: Arc = Arc::from("{invalid json"); + let json_parse_error = "expected value at line 1 column 1".to_string(); + + let (_cancellation_tx, cancellation_rx) = watch::channel(false); + + let result = cx + .update(|cx| { + thread.update(cx, |thread, cx| { + // Call the function under test + thread + .handle_tool_use_json_parse_error_event( + tool_use_id.clone(), + tool_name.clone(), + raw_input.clone(), + json_parse_error, + &event_stream, + cancellation_rx, + cx, + ) + .unwrap() + }) + }) + .await; + + // Verify the result is an error + assert!(result.is_error); + assert_eq!(result.tool_use_id, tool_use_id); + assert_eq!(result.tool_name, tool_name); + assert!(matches!( + result.content, + LanguageModelToolResultContent::Text(_) + )); - // Verify the tool result was added + thread.update(cx, |thread, _cx| { + // Verify the tool use was added to the message content + { let last_message = thread.pending_message(); assert_eq!( - last_message.tool_results.len(), + last_message.content.len(), 1, - "Should have one tool_result" + "Should have one tool_use in content" ); - assert!(last_message.tool_results.contains_key(&tool_use_id)); - }); - }); + + match &last_message.content[0] { + AgentMessageContent::ToolUse(tool_use) => { + assert_eq!(tool_use.id, tool_use_id); + assert_eq!(tool_use.name, tool_name); + assert_eq!(tool_use.raw_input, raw_input.to_string()); + assert!(tool_use.is_input_complete); + // Should fall back to empty object for invalid JSON + assert_eq!(tool_use.input, json!({})); + } + _ => panic!("Expected ToolUse content"), + } + } + + // Insert the tool result (simulating what the caller does) + thread + .pending_message() + .tool_results + .insert(result.tool_use_id.clone(), result); + + // Verify the tool result was added + let last_message = thread.pending_message(); + assert_eq!( + last_message.tool_results.len(), + 1, + "Should have one tool_result" + ); + assert!(last_message.tool_results.contains_key(&tool_use_id)); + }) } } diff --git a/crates/agent/src/tools/streaming_edit_file_tool.rs b/crates/agent/src/tools/streaming_edit_file_tool.rs index bc99515e499696e3df11101be8b813afa027c8f4..47da35bbf25ad188f3f6b98e843b2955910bb7ac 100644 --- a/crates/agent/src/tools/streaming_edit_file_tool.rs +++ b/crates/agent/src/tools/streaming_edit_file_tool.rs @@ -2,6 +2,7 @@ use super::edit_file_tool::EditFileTool; use super::restore_file_from_disk_tool::RestoreFileFromDiskTool; use super::save_file_tool::SaveFileTool; use super::tool_edit_parser::{ToolEditEvent, ToolEditParser}; +use crate::ToolInputPayload; use crate::{ AgentTool, Thread, ToolCallEventStream, ToolInput, edit_agent::{ @@ -12,7 +13,7 @@ use crate::{ use acp_thread::Diff; use action_log::ActionLog; use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields}; -use anyhow::{Context as _, Result}; +use anyhow::Result; use collections::HashSet; use futures::FutureExt as _; use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; @@ -188,6 +189,10 @@ pub enum StreamingEditFileToolOutput { }, Error { error: String, + #[serde(default)] + input_path: Option, + #[serde(default)] + diff: String, }, } @@ -195,6 +200,8 @@ impl StreamingEditFileToolOutput { pub fn error(error: impl Into) -> Self { Self::Error { error: error.into(), + input_path: None, + diff: String::new(), } } } @@ -215,7 +222,24 @@ impl std::fmt::Display for StreamingEditFileToolOutput { ) } } - StreamingEditFileToolOutput::Error { error } => write!(f, "{error}"), + StreamingEditFileToolOutput::Error { + error, + diff, + input_path, + } => { + write!(f, "{error}\n")?; + if let Some(input_path) = input_path + && !diff.is_empty() + { + write!( + f, + "Edited {}:\n\n```diff\n{diff}\n```", + input_path.display() + ) + } else { + write!(f, "No edits were made.") + } + } } } } @@ -233,6 +257,14 @@ pub struct StreamingEditFileTool { language_registry: Arc, } +enum EditSessionResult { + Completed(EditSession), + Failed { + error: String, + session: Option, + }, +} + impl StreamingEditFileTool { pub fn new( project: Entity, @@ -276,6 +308,158 @@ impl StreamingEditFileTool { }); } } + + async fn ensure_buffer_saved(&self, buffer: &Entity, cx: &mut AsyncApp) { + let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| { + let settings = language_settings::LanguageSettings::for_buffer(buffer, cx); + settings.format_on_save != FormatOnSave::Off + }); + + if format_on_save_enabled { + self.project + .update(cx, |project, cx| { + project.format( + HashSet::from_iter([buffer.clone()]), + LspFormatTarget::Buffers, + false, + FormatTrigger::Save, + cx, + ) + }) + .await + .log_err(); + } + + self.project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .log_err(); + + self.action_log.update(cx, |log, cx| { + log.buffer_edited(buffer.clone(), cx); + }); + } + + async fn process_streaming_edits( + &self, + input: &mut ToolInput, + event_stream: &ToolCallEventStream, + cx: &mut AsyncApp, + ) -> EditSessionResult { + let mut session: Option = None; + let mut last_partial: Option = None; + + loop { + futures::select! { + payload = input.next().fuse() => { + match payload { + Ok(payload) => match payload { + ToolInputPayload::Partial(partial) => { + if let Ok(parsed) = serde_json::from_value::(partial) { + let path_complete = parsed.path.is_some() + && parsed.path.as_ref() == last_partial.as_ref().and_then(|partial| partial.path.as_ref()); + + last_partial = Some(parsed.clone()); + + if session.is_none() + && path_complete + && let StreamingEditFileToolPartialInput { + path: Some(path), + display_description: Some(display_description), + mode: Some(mode), + .. + } = &parsed + { + match EditSession::new( + PathBuf::from(path), + display_description, + *mode, + self, + event_stream, + cx, + ) + .await + { + Ok(created_session) => session = Some(created_session), + Err(error) => { + log::error!("Failed to create edit session: {}", error); + return EditSessionResult::Failed { + error, + session: None, + }; + } + } + } + + if let Some(current_session) = &mut session + && let Err(error) = current_session.process(parsed, self, event_stream, cx) + { + log::error!("Failed to process edit: {}", error); + return EditSessionResult::Failed { error, session }; + } + } + } + ToolInputPayload::Full(full_input) => { + let mut session = if let Some(session) = session { + session + } else { + match EditSession::new( + full_input.path.clone(), + &full_input.display_description, + full_input.mode, + self, + event_stream, + cx, + ) + .await + { + Ok(created_session) => created_session, + Err(error) => { + log::error!("Failed to create edit session: {}", error); + return EditSessionResult::Failed { + error, + session: None, + }; + } + } + }; + + return match session.finalize(full_input, self, event_stream, cx).await { + Ok(()) => EditSessionResult::Completed(session), + Err(error) => { + log::error!("Failed to finalize edit: {}", error); + EditSessionResult::Failed { + error, + session: Some(session), + } + } + }; + } + ToolInputPayload::InvalidJson { error_message } => { + log::error!("Received invalid JSON: {error_message}"); + return EditSessionResult::Failed { + error: error_message, + session, + }; + } + }, + Err(error) => { + return EditSessionResult::Failed { + error: format!("Failed to receive tool input: {error}"), + session, + }; + } + } + } + _ = event_stream.cancelled_by_user().fuse() => { + return EditSessionResult::Failed { + error: "Edit cancelled by user".to_string(), + session, + }; + } + } + } + } } impl AgentTool for StreamingEditFileTool { @@ -348,94 +532,40 @@ impl AgentTool for StreamingEditFileTool { cx: &mut App, ) -> Task> { cx.spawn(async move |cx: &mut AsyncApp| { - let mut state: Option = None; - let mut last_partial: Option = None; - loop { - futures::select! { - partial = input.recv_partial().fuse() => { - let Some(partial_value) = partial else { break }; - if let Ok(parsed) = serde_json::from_value::(partial_value) { - let path_complete = parsed.path.is_some() - && parsed.path.as_ref() == last_partial.as_ref().and_then(|p| p.path.as_ref()); - - last_partial = Some(parsed.clone()); - - if state.is_none() - && path_complete - && let StreamingEditFileToolPartialInput { - path: Some(path), - display_description: Some(display_description), - mode: Some(mode), - .. - } = &parsed - { - match EditSession::new( - &PathBuf::from(path), - display_description, - *mode, - &self, - &event_stream, - cx, - ) - .await - { - Ok(session) => state = Some(session), - Err(e) => { - log::error!("Failed to create edit session: {}", e); - return Err(e); - } - } - } - - if let Some(state) = &mut state { - if let Err(e) = state.process(parsed, &self, &event_stream, cx) { - log::error!("Failed to process edit: {}", e); - return Err(e); - } - } - } - } - _ = event_stream.cancelled_by_user().fuse() => { - return Err(StreamingEditFileToolOutput::error("Edit cancelled by user")); - } - } - } - let full_input = - input - .recv() - .await - .map_err(|e| { - let err = StreamingEditFileToolOutput::error(format!("Failed to receive tool input: {e}")); - log::error!("Failed to receive tool input: {e}"); - err - })?; - - let mut state = if let Some(state) = state { - state - } else { - match EditSession::new( - &full_input.path, - &full_input.display_description, - full_input.mode, - &self, - &event_stream, - cx, - ) + match self + .process_streaming_edits(&mut input, &event_stream, cx) .await - { - Ok(session) => session, - Err(e) => { - log::error!("Failed to create edit session: {}", e); - return Err(e); - } + { + EditSessionResult::Completed(session) => { + self.ensure_buffer_saved(&session.buffer, cx).await; + let (new_text, diff) = session.compute_new_text_and_diff(cx).await; + Ok(StreamingEditFileToolOutput::Success { + old_text: session.old_text.clone(), + new_text, + input_path: session.input_path, + diff, + }) } - }; - match state.finalize(full_input, &self, &event_stream, cx).await { - Ok(output) => Ok(output), - Err(e) => { - log::error!("Failed to finalize edit: {}", e); - Err(e) + EditSessionResult::Failed { + error, + session: Some(session), + } => { + self.ensure_buffer_saved(&session.buffer, cx).await; + let (_new_text, diff) = session.compute_new_text_and_diff(cx).await; + Err(StreamingEditFileToolOutput::Error { + error, + input_path: Some(session.input_path), + diff, + }) } + EditSessionResult::Failed { + error, + session: None, + } => Err(StreamingEditFileToolOutput::Error { + error, + input_path: None, + diff: String::new(), + }), } }) } @@ -472,6 +602,7 @@ impl AgentTool for StreamingEditFileTool { pub struct EditSession { abs_path: PathBuf, + input_path: PathBuf, buffer: Entity, old_text: Arc, diff: Entity, @@ -518,23 +649,21 @@ impl EditPipeline { impl EditSession { async fn new( - path: &PathBuf, + path: PathBuf, display_description: &str, mode: StreamingEditFileMode, tool: &StreamingEditFileTool, event_stream: &ToolCallEventStream, cx: &mut AsyncApp, - ) -> Result { - let project_path = cx - .update(|cx| resolve_path(mode, &path, &tool.project, cx)) - .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; + ) -> Result { + let project_path = cx.update(|cx| resolve_path(mode, &path, &tool.project, cx))?; let Some(abs_path) = cx.update(|cx| tool.project.read(cx).absolute_path(&project_path, cx)) else { - return Err(StreamingEditFileToolOutput::error(format!( + return Err(format!( "Worktree at '{}' does not exist", path.to_string_lossy() - ))); + )); }; event_stream.update_fields( @@ -543,13 +672,13 @@ impl EditSession { cx.update(|cx| tool.authorize(&path, &display_description, event_stream, cx)) .await - .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; + .map_err(|e| e.to_string())?; let buffer = tool .project .update(cx, |project, cx| project.open_buffer(project_path, cx)) .await - .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; + .map_err(|e| e.to_string())?; ensure_buffer_saved(&buffer, &abs_path, tool, cx)?; @@ -578,6 +707,7 @@ impl EditSession { Ok(Self { abs_path, + input_path: path, buffer, old_text, diff, @@ -594,22 +724,20 @@ impl EditSession { tool: &StreamingEditFileTool, event_stream: &ToolCallEventStream, cx: &mut AsyncApp, - ) -> Result { - let old_text = self.old_text.clone(); - + ) -> Result<(), String> { match input.mode { StreamingEditFileMode::Write => { - let content = input.content.ok_or_else(|| { - StreamingEditFileToolOutput::error("'content' field is required for write mode") - })?; + let content = input + .content + .ok_or_else(|| "'content' field is required for write mode".to_string())?; let events = self.parser.finalize_content(&content); self.process_events(&events, tool, event_stream, cx)?; } StreamingEditFileMode::Edit => { - let edits = input.edits.ok_or_else(|| { - StreamingEditFileToolOutput::error("'edits' field is required for edit mode") - })?; + let edits = input + .edits + .ok_or_else(|| "'edits' field is required for edit mode".to_string())?; let events = self.parser.finalize_edits(&edits); self.process_events(&events, tool, event_stream, cx)?; @@ -625,53 +753,15 @@ impl EditSession { } } } + Ok(()) + } - let format_on_save_enabled = self.buffer.read_with(cx, |buffer, cx| { - let settings = language_settings::LanguageSettings::for_buffer(buffer, cx); - settings.format_on_save != FormatOnSave::Off - }); - - if format_on_save_enabled { - tool.action_log.update(cx, |log, cx| { - log.buffer_edited(self.buffer.clone(), cx); - }); - - let format_task = tool.project.update(cx, |project, cx| { - project.format( - HashSet::from_iter([self.buffer.clone()]), - LspFormatTarget::Buffers, - false, - FormatTrigger::Save, - cx, - ) - }); - futures::select! { - result = format_task.fuse() => { result.log_err(); }, - _ = event_stream.cancelled_by_user().fuse() => { - return Err(StreamingEditFileToolOutput::error("Edit cancelled by user")); - } - }; - } - - let save_task = tool.project.update(cx, |project, cx| { - project.save_buffer(self.buffer.clone(), cx) - }); - futures::select! { - result = save_task.fuse() => { result.map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; }, - _ = event_stream.cancelled_by_user().fuse() => { - return Err(StreamingEditFileToolOutput::error("Edit cancelled by user")); - } - }; - - tool.action_log.update(cx, |log, cx| { - log.buffer_edited(self.buffer.clone(), cx); - }); - + async fn compute_new_text_and_diff(&self, cx: &mut AsyncApp) -> (String, String) { let new_snapshot = self.buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); let (new_text, unified_diff) = cx .background_spawn({ let new_snapshot = new_snapshot.clone(); - let old_text = old_text.clone(); + let old_text = self.old_text.clone(); async move { let new_text = new_snapshot.text(); let diff = language::unified_diff(&old_text, &new_text); @@ -679,14 +769,7 @@ impl EditSession { } }) .await; - - let output = StreamingEditFileToolOutput::Success { - input_path: input.path, - new_text, - old_text: old_text.clone(), - diff: unified_diff, - }; - Ok(output) + (new_text, unified_diff) } fn process( @@ -695,7 +778,7 @@ impl EditSession { tool: &StreamingEditFileTool, event_stream: &ToolCallEventStream, cx: &mut AsyncApp, - ) -> Result<(), StreamingEditFileToolOutput> { + ) -> Result<(), String> { match &self.mode { StreamingEditFileMode::Write => { if let Some(content) = &partial.content { @@ -719,7 +802,7 @@ impl EditSession { tool: &StreamingEditFileTool, event_stream: &ToolCallEventStream, cx: &mut AsyncApp, - ) -> Result<(), StreamingEditFileToolOutput> { + ) -> Result<(), String> { for event in events { match event { ToolEditEvent::ContentChunk { chunk } => { @@ -969,14 +1052,14 @@ fn extract_match( buffer: &Entity, edit_index: &usize, cx: &mut AsyncApp, -) -> Result, StreamingEditFileToolOutput> { +) -> Result, String> { match matches.len() { - 0 => Err(StreamingEditFileToolOutput::error(format!( + 0 => Err(format!( "Could not find matching text for edit at index {}. \ The old_text did not match any content in the file. \ Please read the file again to get the current content.", edit_index, - ))), + )), 1 => Ok(matches.into_iter().next().unwrap()), _ => { let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); @@ -985,12 +1068,12 @@ fn extract_match( .map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string()) .collect::>() .join(", "); - Err(StreamingEditFileToolOutput::error(format!( + Err(format!( "Edit {} matched multiple locations in the file at lines: {}. \ Please provide more context in old_text to uniquely \ identify the location.", edit_index, lines - ))) + )) } } } @@ -1022,7 +1105,7 @@ fn ensure_buffer_saved( abs_path: &PathBuf, tool: &StreamingEditFileTool, cx: &mut AsyncApp, -) -> Result<(), StreamingEditFileToolOutput> { +) -> Result<(), String> { let last_read_mtime = tool .action_log .read_with(cx, |log, _| log.file_read_time(abs_path)); @@ -1063,15 +1146,14 @@ fn ensure_buffer_saved( then ask them to save or revert the file manually and inform you when it's ok to proceed." } }; - return Err(StreamingEditFileToolOutput::error(message)); + return Err(message.to_string()); } if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) { if current != last_read { - return Err(StreamingEditFileToolOutput::error( - "The file has been modified since you last read it. \ - Please read the file again to get the current state before editing it.", - )); + return Err("The file has been modified since you last read it. \ + Please read the file again to get the current state before editing it." + .to_string()); } } @@ -1083,56 +1165,63 @@ fn resolve_path( path: &PathBuf, project: &Entity, cx: &mut App, -) -> Result { +) -> Result { let project = project.read(cx); match mode { StreamingEditFileMode::Edit => { let path = project .find_project_path(&path, cx) - .context("Can't edit file: path not found")?; + .ok_or_else(|| "Can't edit file: path not found".to_string())?; let entry = project .entry_for_path(&path, cx) - .context("Can't edit file: path not found")?; + .ok_or_else(|| "Can't edit file: path not found".to_string())?; - anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory"); - Ok(path) + if entry.is_file() { + Ok(path) + } else { + Err("Can't edit file: path is a directory".to_string()) + } } StreamingEditFileMode::Write => { if let Some(path) = project.find_project_path(&path, cx) && let Some(entry) = project.entry_for_path(&path, cx) { - anyhow::ensure!(entry.is_file(), "Can't write to file: path is a directory"); - return Ok(path); + if entry.is_file() { + return Ok(path); + } else { + return Err("Can't write to file: path is a directory".to_string()); + } } - let parent_path = path.parent().context("Can't create file: incorrect path")?; + let parent_path = path + .parent() + .ok_or_else(|| "Can't create file: incorrect path".to_string())?; let parent_project_path = project.find_project_path(&parent_path, cx); let parent_entry = parent_project_path .as_ref() .and_then(|path| project.entry_for_path(path, cx)) - .context("Can't create file: parent directory doesn't exist")?; + .ok_or_else(|| "Can't create file: parent directory doesn't exist")?; - anyhow::ensure!( - parent_entry.is_dir(), - "Can't create file: parent is not a directory" - ); + if !parent_entry.is_dir() { + return Err("Can't create file: parent is not a directory".to_string()); + } let file_name = path .file_name() .and_then(|file_name| file_name.to_str()) .and_then(|file_name| RelPath::unix(file_name).ok()) - .context("Can't create file: invalid filename")?; + .ok_or_else(|| "Can't create file: invalid filename".to_string())?; let new_file_path = parent_project_path.map(|parent| ProjectPath { path: parent.path.join(file_name), ..parent }); - new_file_path.context("Can't create file") + new_file_path.ok_or_else(|| "Can't create file".to_string()) } } } @@ -1382,10 +1471,17 @@ mod tests { }) .await; - let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + let StreamingEditFileToolOutput::Error { + error, + diff, + input_path, + } = result.unwrap_err() + else { panic!("expected error"); }; assert_eq!(error, "Can't edit file: path not found"); + assert!(diff.is_empty()); + assert_eq!(input_path, None); } #[gpui::test] @@ -1411,7 +1507,7 @@ mod tests { }) .await; - let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + let StreamingEditFileToolOutput::Error { error, .. } = result.unwrap_err() else { panic!("expected error"); }; assert!( @@ -1424,7 +1520,7 @@ mod tests { async fn test_streaming_early_buffer_open(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1447,7 +1543,7 @@ mod tests { cx.run_until_parked(); // Now send the final complete input - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit lines", "path": "root/file.txt", "mode": "edit", @@ -1465,7 +1561,7 @@ mod tests { async fn test_streaming_path_completeness_heuristic(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello world"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1485,7 +1581,7 @@ mod tests { cx.run_until_parked(); // Send final - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Overwrite file", "path": "root/file.txt", "mode": "write", @@ -1503,7 +1599,7 @@ mod tests { async fn test_streaming_cancellation_during_partials(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello world"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver, mut cancellation_tx) = ToolCallEventStream::test_with_cancellation(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1521,7 +1617,7 @@ mod tests { drop(sender); let result = task.await; - let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + let StreamingEditFileToolOutput::Error { error, .. } = result.unwrap_err() else { panic!("expected error"); }; assert!( @@ -1537,7 +1633,7 @@ mod tests { json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1578,7 +1674,7 @@ mod tests { cx.run_until_parked(); // Send final complete input - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit multiple lines", "path": "root/file.txt", "mode": "edit", @@ -1601,7 +1697,7 @@ mod tests { #[gpui::test] async fn test_streaming_create_file_with_partials(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1625,7 +1721,7 @@ mod tests { cx.run_until_parked(); // Final with full content - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Create new file", "path": "root/dir/new_file.txt", "mode": "write", @@ -1643,12 +1739,12 @@ mod tests { async fn test_streaming_no_partials_direct_final(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Send final immediately with no partials (simulates non-streaming path) - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit lines", "path": "root/file.txt", "mode": "edit", @@ -1669,7 +1765,7 @@ mod tests { json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1739,7 +1835,7 @@ mod tests { ); // Send final complete input - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit multiple lines", "path": "root/file.txt", "mode": "edit", @@ -1767,7 +1863,7 @@ mod tests { async fn test_streaming_incremental_three_edits(cx: &mut TestAppContext) { let (tool, project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "aaa\nbbb\nccc\nddd\neee\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1835,7 +1931,7 @@ mod tests { assert_eq!(buffer_text.as_deref(), Some("AAA\nbbb\nCCC\nddd\nEEEeee\n")); // Send final - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit three lines", "path": "root/file.txt", "mode": "edit", @@ -1857,7 +1953,7 @@ mod tests { async fn test_streaming_edit_failure_mid_stream(cx: &mut TestAppContext) { let (tool, project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1893,16 +1989,17 @@ mod tests { })); cx.run_until_parked(); - // Verify edit 1 was applied - let buffer_text = project.update(cx, |project, cx| { + let buffer = project.update(cx, |project, cx| { let pp = project .find_project_path(&PathBuf::from("root/file.txt"), cx) .unwrap(); - project.get_open_buffer(&pp, cx).map(|b| b.read(cx).text()) + project.get_open_buffer(&pp, cx).unwrap() }); + + // Verify edit 1 was applied + let buffer_text = buffer.read_with(cx, |buffer, _cx| buffer.text()); assert_eq!( - buffer_text.as_deref(), - Some("MODIFIED\nline 2\nline 3\n"), + buffer_text, "MODIFIED\nline 2\nline 3\n", "First edit should be applied even though second edit will fail" ); @@ -1925,20 +2022,32 @@ mod tests { drop(sender); let result = task.await; - let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + let StreamingEditFileToolOutput::Error { + error, + diff, + input_path, + } = result.unwrap_err() + else { panic!("expected error"); }; + assert!( error.contains("Could not find matching text for edit at index 1"), "Expected error about edit 1 failing, got: {error}" ); + // Ensure that first edit was applied successfully and that we saved the buffer + assert_eq!(input_path, Some(PathBuf::from("root/file.txt"))); + assert_eq!( + diff, + "@@ -1,3 +1,3 @@\n-line 1\n+MODIFIED\n line 2\n line 3\n" + ); } #[gpui::test] async fn test_streaming_single_edit_no_incremental(cx: &mut TestAppContext) { let (tool, project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello world\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -1975,7 +2084,7 @@ mod tests { ); // Send final — the edit is applied during finalization - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Single edit", "path": "root/file.txt", "mode": "edit", @@ -1993,7 +2102,7 @@ mod tests { async fn test_streaming_input_partials_then_final(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; - let (sender, input): (ToolInputSender, ToolInput) = + let (mut sender, input): (ToolInputSender, ToolInput) = ToolInput::test(); let (event_stream, _event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -2020,7 +2129,7 @@ mod tests { cx.run_until_parked(); // Send the final complete input - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit lines", "path": "root/file.txt", "mode": "edit", @@ -2038,7 +2147,7 @@ mod tests { async fn test_streaming_input_sender_dropped_before_final(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello world\n"})).await; - let (sender, input): (ToolInputSender, ToolInput) = + let (mut sender, input): (ToolInputSender, ToolInput) = ToolInput::test(); let (event_stream, _event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -2064,7 +2173,7 @@ mod tests { // Create a channel and send multiple partials before a final, then use // ToolInput::resolved-style immediate delivery to confirm recv() works // when partials are already buffered. - let (sender, input): (ToolInputSender, ToolInput) = + let (mut sender, input): (ToolInputSender, ToolInput) = ToolInput::test(); let (event_stream, _event_rx) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -2077,7 +2186,7 @@ mod tests { "path": "root/dir/new.txt", "mode": "write" })); - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Create", "path": "root/dir/new.txt", "mode": "write", @@ -2109,13 +2218,13 @@ mod tests { let result = test_resolve_path(&mode, "root/dir/subdir", cx); assert_eq!( - result.await.unwrap_err().to_string(), + result.await.unwrap_err(), "Can't write to file: path is a directory" ); let result = test_resolve_path(&mode, "root/dir/nonexistent_dir/new.txt", cx); assert_eq!( - result.await.unwrap_err().to_string(), + result.await.unwrap_err(), "Can't create file: parent directory doesn't exist" ); } @@ -2133,14 +2242,11 @@ mod tests { assert_resolved_path_eq(result.await, rel_path(path_without_root)); let result = test_resolve_path(&mode, "root/nonexistent.txt", cx); - assert_eq!( - result.await.unwrap_err().to_string(), - "Can't edit file: path not found" - ); + assert_eq!(result.await.unwrap_err(), "Can't edit file: path not found"); let result = test_resolve_path(&mode, "root/dir", cx); assert_eq!( - result.await.unwrap_err().to_string(), + result.await.unwrap_err(), "Can't edit file: path is a directory" ); } @@ -2149,7 +2255,7 @@ mod tests { mode: &StreamingEditFileMode, path: &str, cx: &mut TestAppContext, - ) -> anyhow::Result { + ) -> Result { init_test(cx); let fs = project::FakeFs::new(cx.executor()); @@ -2170,7 +2276,7 @@ mod tests { } #[track_caller] - fn assert_resolved_path_eq(path: anyhow::Result, expected: &RelPath) { + fn assert_resolved_path_eq(path: Result, expected: &RelPath) { let actual = path.expect("Should return valid path").path; assert_eq!(actual.as_ref(), expected); } @@ -2259,7 +2365,7 @@ mod tests { }); // Use streaming pattern so executor can pump the LSP request/response - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -2271,7 +2377,7 @@ mod tests { })); cx.run_until_parked(); - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Create main function", "path": "root/src/main.rs", "mode": "write", @@ -2310,7 +2416,7 @@ mod tests { }); }); - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let tool2 = Arc::new(StreamingEditFileTool::new( @@ -2329,7 +2435,7 @@ mod tests { })); cx.run_until_parked(); - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Update main function", "path": "root/src/main.rs", "mode": "write", @@ -3288,14 +3394,22 @@ mod tests { }) .await; - let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + let StreamingEditFileToolOutput::Error { + error, + diff, + input_path, + } = result.unwrap_err() + else { panic!("expected error"); }; + assert!( error.contains("has been modified since you last read it"), "Error should mention file modification, got: {}", error ); + assert!(diff.is_empty()); + assert!(input_path.is_none()); } #[gpui::test] @@ -3362,7 +3476,12 @@ mod tests { }) .await; - let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + let StreamingEditFileToolOutput::Error { + error, + diff, + input_path, + } = result.unwrap_err() + else { panic!("expected error"); }; assert!( @@ -3380,6 +3499,8 @@ mod tests { "Error should ask user to manually save or revert when tools aren't available, got: {}", error ); + assert!(diff.is_empty()); + assert!(input_path.is_none()); } #[gpui::test] @@ -3390,7 +3511,7 @@ mod tests { // the modified buffer and succeeds. let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "aaa\nbbb\nccc\nddd\neee\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3420,7 +3541,7 @@ mod tests { cx.run_until_parked(); // Send the final input with all three edits. - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Overlapping edits", "path": "root/file.txt", "mode": "edit", @@ -3441,7 +3562,7 @@ mod tests { #[gpui::test] async fn test_streaming_create_content_streamed(cx: &mut TestAppContext) { let (tool, project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3495,7 +3616,7 @@ mod tests { ); // Send final input - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Create new file", "path": "root/dir/new_file.txt", "mode": "write", @@ -3516,7 +3637,7 @@ mod tests { json!({"file.txt": "old line 1\nold line 2\nold line 3\n"}), ) .await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, mut receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3559,7 +3680,7 @@ mod tests { }); // Send final input - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Overwrite file", "path": "root/file.txt", "mode": "write", @@ -3587,7 +3708,7 @@ mod tests { json!({"file.txt": "old line 1\nold line 2\nold line 3\n"}), ) .await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3634,7 +3755,7 @@ mod tests { ); // Send final input with complete content - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Overwrite file", "path": "root/file.txt", "mode": "write", @@ -3656,7 +3777,7 @@ mod tests { async fn test_streaming_edit_json_fixer_escape_corruption(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello\nworld\nfoo\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3690,7 +3811,7 @@ mod tests { cx.run_until_parked(); // Send final. - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit", "path": "root/file.txt", "mode": "edit", @@ -3708,7 +3829,7 @@ mod tests { async fn test_streaming_final_input_stringified_edits_succeeds(cx: &mut TestAppContext) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello\nworld\n"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3719,7 +3840,7 @@ mod tests { })); cx.run_until_parked(); - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Edit", "path": "root/file.txt", "mode": "edit", @@ -3823,7 +3944,7 @@ mod tests { ) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "old_content"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3849,7 +3970,7 @@ mod tests { cx.run_until_parked(); // Send final. - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Overwrite file", "mode": "write", "content": "new_content", @@ -3869,7 +3990,7 @@ mod tests { ) { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "old_content"})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); @@ -3902,7 +4023,7 @@ mod tests { cx.run_until_parked(); // Send final. - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Overwrite file", "mode": "edit", "edits": [{"old_text": "old_content", "new_text": "new_content"}], @@ -3939,11 +4060,11 @@ mod tests { let old_text = "}\n\n\n\nfn render_search"; let new_text = "}\n\nfn render_search"; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "Remove extra blank lines", "path": "root/file.rs", "mode": "edit", @@ -3980,11 +4101,11 @@ mod tests { let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.rs": file_content})).await; - let (sender, input) = ToolInput::::test(); + let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); - sender.send_final(json!({ + sender.send_full(json!({ "display_description": "description", "path": "root/file.rs", "mode": "edit", diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 5fd39509df4ec2263e47c7e87b3e4b7852eaf154..41900e71e5d3ad7e5327ee7e04f73cb05eed5a5b 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -5175,7 +5175,7 @@ mod tests { multi_workspace .read_with(cx, |multi_workspace, _cx| { assert_eq!( - multi_workspace.workspaces().len(), + multi_workspace.workspaces().count(), 1, "LocalProject should not create a new workspace" ); @@ -5451,6 +5451,11 @@ mod tests { let multi_workspace = cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx)); + multi_workspace + .update(cx, |multi_workspace, _, cx| { + multi_workspace.open_sidebar(cx); + }) + .unwrap(); let workspace = multi_workspace .read_with(cx, |multi_workspace, _cx| { @@ -5538,15 +5543,14 @@ mod tests { .read_with(cx, |multi_workspace, cx| { // There should be more than one workspace now (the original + the new worktree). assert!( - multi_workspace.workspaces().len() > 1, + multi_workspace.workspaces().count() > 1, "expected a new workspace to have been created, found {}", - multi_workspace.workspaces().len(), + multi_workspace.workspaces().count(), ); // Check the newest workspace's panel for the correct agent. let new_workspace = multi_workspace .workspaces() - .iter() .find(|ws| ws.entity_id() != workspace.entity_id()) .expect("should find the new workspace"); let new_panel = new_workspace diff --git a/crates/agent_ui/src/conversation_view.rs b/crates/agent_ui/src/conversation_view.rs index ce125a5d7c901ccb6fc89f405f482cbf52b94f5d..7c9acfdf27d5b750afe4b8817af7f657f5fcdecc 100644 --- a/crates/agent_ui/src/conversation_view.rs +++ b/crates/agent_ui/src/conversation_view.rs @@ -812,7 +812,7 @@ impl ConversationView { let agent_id = self.agent.agent_id(); let session_capabilities = Arc::new(RwLock::new(SessionCapabilities::new( thread.read(cx).prompt_capabilities(), - vec![], + thread.read(cx).available_commands().to_vec(), ))); let action_log = thread.read(cx).action_log().clone(); @@ -1448,40 +1448,24 @@ impl ConversationView { self.emit_token_limit_telemetry_if_needed(thread, cx); } AcpThreadEvent::AvailableCommandsUpdated(available_commands) => { - let mut available_commands = available_commands.clone(); + if let Some(thread_view) = self.thread_view(&thread_id) { + let has_commands = !available_commands.is_empty(); - if thread - .read(cx) - .connection() - .auth_methods() - .iter() - .any(|method| method.id().0.as_ref() == "claude-login") - { - available_commands.push(acp::AvailableCommand::new("login", "Authenticate")); - available_commands.push(acp::AvailableCommand::new("logout", "Authenticate")); - } - - let has_commands = !available_commands.is_empty(); - if let Some(active) = self.active_thread() { - active.update(cx, |active, _cx| { - active - .session_capabilities - .write() - .set_available_commands(available_commands); - }); - } - - let agent_display_name = self - .agent_server_store - .read(cx) - .agent_display_name(&self.agent.agent_id()) - .unwrap_or_else(|| self.agent.agent_id().0.to_string().into()); + let agent_display_name = self + .agent_server_store + .read(cx) + .agent_display_name(&self.agent.agent_id()) + .unwrap_or_else(|| self.agent.agent_id().0.to_string().into()); - if let Some(active) = self.active_thread() { let new_placeholder = placeholder_text(agent_display_name.as_ref(), has_commands); - active.update(cx, |active, cx| { - active.message_editor.update(cx, |editor, cx| { + + thread_view.update(cx, |thread_view, cx| { + thread_view + .session_capabilities + .write() + .set_available_commands(available_commands.clone()); + thread_view.message_editor.update(cx, |editor, cx| { editor.set_placeholder_text(&new_placeholder, window, cx); }); }); @@ -2348,9 +2332,9 @@ impl ConversationView { } } + #[cfg(feature = "audio")] fn play_notification_sound(&self, window: &Window, cx: &mut App) { - let settings = AgentSettings::get_global(cx); - let _visible = window.is_window_active() + let visible = window.is_window_active() && if let Some(mw) = window.root::().flatten() { self.agent_panel_visible(&mw, cx) } else { @@ -2358,8 +2342,8 @@ impl ConversationView { .upgrade() .is_some_and(|workspace| AgentPanel::is_visible(&workspace, cx)) }; - #[cfg(feature = "audio")] - if settings.play_sound_when_agent_done.should_play(_visible) { + let settings = AgentSettings::get_global(cx); + if settings.play_sound_when_agent_done.should_play(visible) { Audio::play_sound(Sound::AgentDone, cx); } } @@ -2989,6 +2973,166 @@ pub(crate) mod tests { }); } + #[derive(Clone)] + struct RestoredAvailableCommandsConnection; + + impl AgentConnection for RestoredAvailableCommandsConnection { + fn agent_id(&self) -> AgentId { + AgentId::new("restored-available-commands") + } + + fn telemetry_id(&self) -> SharedString { + "restored-available-commands".into() + } + + fn new_session( + self: Rc, + project: Entity, + _work_dirs: PathList, + cx: &mut App, + ) -> Task>> { + let thread = build_test_thread( + self, + project, + "RestoredAvailableCommandsConnection", + SessionId::new("new-session"), + cx, + ); + Task::ready(Ok(thread)) + } + + fn supports_load_session(&self) -> bool { + true + } + + fn load_session( + self: Rc, + session_id: acp::SessionId, + project: Entity, + _work_dirs: PathList, + _title: Option, + cx: &mut App, + ) -> Task>> { + let thread = build_test_thread( + self, + project, + "RestoredAvailableCommandsConnection", + session_id, + cx, + ); + + thread + .update(cx, |thread, cx| { + thread.handle_session_update( + acp::SessionUpdate::AvailableCommandsUpdate( + acp::AvailableCommandsUpdate::new(vec![acp::AvailableCommand::new( + "help", "Get help", + )]), + ), + cx, + ) + }) + .expect("available commands update should succeed"); + + Task::ready(Ok(thread)) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn authenticate( + &self, + _method_id: acp::AuthMethodId, + _cx: &mut App, + ) -> Task> { + Task::ready(Ok(())) + } + + fn prompt( + &self, + _id: Option, + _params: acp::PromptRequest, + _cx: &mut App, + ) -> Task> { + Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))) + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {} + + fn into_any(self: Rc) -> Rc { + self + } + } + + #[gpui::test] + async fn test_restored_threads_keep_available_commands(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let (multi_workspace, cx) = + cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx)); + let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()); + + let thread_store = cx.update(|_window, cx| cx.new(|cx| ThreadStore::new(cx))); + let connection_store = + cx.update(|_window, cx| cx.new(|cx| AgentConnectionStore::new(project.clone(), cx))); + + let conversation_view = cx.update(|window, cx| { + cx.new(|cx| { + ConversationView::new( + Rc::new(StubAgentServer::new(RestoredAvailableCommandsConnection)), + connection_store, + Agent::Custom { id: "Test".into() }, + Some(SessionId::new("restored-session")), + None, + None, + None, + workspace.downgrade(), + project, + Some(thread_store), + None, + window, + cx, + ) + }) + }); + + cx.run_until_parked(); + + let message_editor = message_editor(&conversation_view, cx); + let editor = + message_editor.update(cx, |message_editor, _cx| message_editor.editor().clone()); + let placeholder = editor.update(cx, |editor, cx| editor.placeholder_text(cx)); + + active_thread(&conversation_view, cx).read_with(cx, |view, _cx| { + let available_commands = view + .session_capabilities + .read() + .available_commands() + .to_vec(); + assert_eq!(available_commands.len(), 1); + assert_eq!(available_commands[0].name.as_str(), "help"); + assert_eq!(available_commands[0].description.as_str(), "Get help"); + }); + + assert_eq!( + placeholder, + Some("Message Test — @ to include context, / for commands".to_string()) + ); + + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("/help", window, cx); + }); + + let contents_result = message_editor + .update(cx, |editor, cx| editor.contents(false, cx)) + .await; + + assert!(contents_result.is_ok()); + } + #[gpui::test] async fn test_resume_thread_uses_session_cwd_when_inside_project(cx: &mut TestAppContext) { init_test(cx); @@ -3375,7 +3519,6 @@ pub(crate) mod tests { // Verify workspace1 is no longer the active workspace multi_workspace_handle .read_with(cx, |mw, _cx| { - assert_eq!(mw.active_workspace_index(), 1); assert_ne!(mw.workspace(), &workspace1); }) .unwrap(); diff --git a/crates/agent_ui/src/conversation_view/thread_view.rs b/crates/agent_ui/src/conversation_view/thread_view.rs index 25af09832f3473aa690c7b205e1b56bab86e9709..685621eb3c93632f1e7410bbbad22b623d5e18c7 100644 --- a/crates/agent_ui/src/conversation_view/thread_view.rs +++ b/crates/agent_ui/src/conversation_view/thread_view.rs @@ -344,7 +344,8 @@ impl ThreadView { ) -> Self { let id = thread.read(cx).session_id().clone(); - let placeholder = placeholder_text(agent_display_name.as_ref(), false); + let has_commands = !session_capabilities.read().available_commands().is_empty(); + let placeholder = placeholder_text(agent_display_name.as_ref(), has_commands); let history_subscription = history.as_ref().map(|h| { cx.observe(h, |this, history, cx| { @@ -7389,9 +7390,8 @@ impl ThreadView { .gap_2() .map(|this| { if card_layout { - this.when(context_ix > 0, |this| { - this.pt_2() - .border_t_1() + this.p_2().when(context_ix > 0, |this| { + this.border_t_1() .border_color(self.tool_card_border_color(cx)) }) } else { diff --git a/crates/agent_ui/src/threads_archive_view.rs b/crates/agent_ui/src/threads_archive_view.rs index f0c02eefc34a03c5c45730ac4b53645c5b15a2e1..13b2aa1a37cd506c338d13db78bce751882e426a 100644 --- a/crates/agent_ui/src/threads_archive_view.rs +++ b/crates/agent_ui/src/threads_archive_view.rs @@ -218,6 +218,13 @@ impl ThreadsArchiveView { handle.focus(window, cx); } + pub fn is_filter_editor_focused(&self, window: &Window, cx: &App) -> bool { + self.filter_editor + .read(cx) + .focus_handle(cx) + .is_focused(window) + } + fn update_items(&mut self, cx: &mut Context) { let sessions = ThreadMetadataStore::global(cx) .read(cx) @@ -346,7 +353,6 @@ impl ThreadsArchiveView { .map(|mw| { mw.read(cx) .workspaces() - .iter() .filter_map(|ws| ws.read(cx).database_id()) .collect() }) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 7ed488b0ba62c10326a0e2154f0d2ba895e20a4f..20316fc3403de0e6212d13d455c5b619000d71b1 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -435,6 +435,7 @@ impl Server { .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(forward_read_only_project_request::) + .add_request_handler(forward_read_only_project_request::) .add_request_handler(forward_mutating_project_request::) .add_request_handler(disallow_guest_request::) .add_request_handler(disallow_guest_request::) diff --git a/crates/collab/tests/integration/git_tests.rs b/crates/collab/tests/integration/git_tests.rs index fdaacd768444bd44d8414247f922f38afb7e81d5..2fa67b072f1c3d49ef5ca1b90056fd08d57df1ba 100644 --- a/crates/collab/tests/integration/git_tests.rs +++ b/crates/collab/tests/integration/git_tests.rs @@ -424,6 +424,58 @@ async fn test_remote_git_worktrees( ); } +#[gpui::test] +async fn test_remote_git_head_sha( + executor: BackgroundExecutor, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + let mut server = TestServer::start(executor.clone()).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + server + .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)]) + .await; + let active_call_a = cx_a.read(ActiveCall::global); + + client_a + .fs() + .insert_tree( + path!("/project"), + json!({ ".git": {}, "file.txt": "content" }), + ) + .await; + + let (project_a, _) = client_a.build_local_project(path!("/project"), cx_a).await; + let local_head_sha = cx_a.update(|cx| { + project_a + .read(cx) + .active_repository(cx) + .unwrap() + .update(cx, |repository, _| repository.head_sha()) + }); + let local_head_sha = local_head_sha.await.unwrap().unwrap(); + + let project_id = active_call_a + .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx)) + .await + .unwrap(); + let project_b = client_b.join_remote_project(project_id, cx_b).await; + + executor.run_until_parked(); + + let remote_head_sha = cx_b.update(|cx| { + project_b + .read(cx) + .active_repository(cx) + .unwrap() + .update(cx, |repository, _| repository.head_sha()) + }); + let remote_head_sha = remote_head_sha.await.unwrap(); + + assert_eq!(remote_head_sha.unwrap(), local_head_sha); +} + #[gpui::test] async fn test_linked_worktrees_sync( executor: BackgroundExecutor, diff --git a/crates/dev_container/src/devcontainer_manifest.rs b/crates/dev_container/src/devcontainer_manifest.rs index e3a09ae548b68bb4d589d8a214ca1ba5daa9cfa4..5ef82fa3eb2a3ac5d13810e0f6102bec4f42295a 100644 --- a/crates/dev_container/src/devcontainer_manifest.rs +++ b/crates/dev_container/src/devcontainer_manifest.rs @@ -883,7 +883,13 @@ RUN sed -i -E 's/((^|\s)PATH=)([^\$]*)$/\1\${{PATH:-\3}}/g' /etc/profile || true labels: None, build: Some(DockerComposeServiceBuild { context: Some( - features_build_info.empty_context_dir.display().to_string(), + main_service + .build + .as_ref() + .and_then(|b| b.context.clone()) + .unwrap_or_else(|| { + features_build_info.empty_context_dir.display().to_string() + }), ), dockerfile: Some(dockerfile_path.display().to_string()), args: Some(build_args), @@ -3546,6 +3552,27 @@ ENV DOCKER_BUILDKIT=1 "# ); + let build_override = files + .iter() + .find(|f| { + f.file_name() + .is_some_and(|s| s.display().to_string() == "docker_compose_build.json") + }) + .expect("to be found"); + let build_override = test_dependencies.fs.load(build_override).await.unwrap(); + let build_config: DockerComposeConfig = + serde_json_lenient::from_str(&build_override).unwrap(); + let build_context = build_config + .services + .get("app") + .and_then(|s| s.build.as_ref()) + .and_then(|b| b.context.clone()) + .expect("build override should have a context"); + assert_eq!( + build_context, ".", + "build override should preserve the original build context from docker-compose.yml" + ); + let runtime_override = files .iter() .find(|f| { diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 1ba8b27aa785024a47a09c3299a1f3786a028ccf..ea7233cd976148f5eb726730635e0efaf6ceef86 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -2707,6 +2707,65 @@ async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppConte }); } +#[gpui::test] +async fn test_v3_prediction_strips_cursor_marker_from_edit_text(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + + fs.insert_tree( + "/root", + json!({ + "foo.txt": "hello" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project + .find_project_path(path!("root/foo.txt"), cx) + .unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(0, 5)); + + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (request, respond_tx) = requests.predict.next().await.unwrap(); + let excerpt_length = request.input.cursor_excerpt.len(); + respond_tx + .send(PredictEditsV3Response { + request_id: Uuid::new_v4().to_string(), + output: "hello<|user_cursor|> world".to_string(), + editable_range: 0..excerpt_length, + model_version: None, + }) + .unwrap(); + + cx.run_until_parked(); + + ep_store.update(cx, |ep_store, cx| { + let prediction = ep_store + .prediction_at(&buffer, None, &project, cx) + .expect("should have prediction"); + let snapshot = buffer.read(cx).snapshot(); + let edits: Vec<_> = prediction + .edits + .iter() + .map(|(range, text)| (range.to_offset(&snapshot), text.clone())) + .collect(); + + assert_eq!(edits, vec![(5..5, " world".into())]); + }); +} + fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); diff --git a/crates/edit_prediction/src/example_spec.rs b/crates/edit_prediction/src/example_spec.rs index 4486cde22c3429568bf29f152d0f5f2ded59e8f4..a7da51173eefbcdb9e014f7dcca917e6ebebebf5 100644 --- a/crates/edit_prediction/src/example_spec.rs +++ b/crates/edit_prediction/src/example_spec.rs @@ -1,10 +1,11 @@ -use crate::udiff::DiffLine; use anyhow::{Context as _, Result}; use serde::{Deserialize, Serialize}; use std::{borrow::Cow, fmt::Write as _, mem, path::Path, sync::Arc}; use telemetry_events::EditPredictionRating; -pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]"; +pub use zeta_prompt::udiff::{ + CURSOR_POSITION_MARKER, encode_cursor_in_patch, extract_cursor_from_patch, +}; pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>"; /// Maximum cursor file size to capture (64KB). @@ -12,64 +13,6 @@ pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>"; /// falling back to git-based loading. pub const MAX_CURSOR_FILE_SIZE: usize = 64 * 1024; -/// Encodes a cursor position into a diff patch by adding a comment line with a caret -/// pointing to the cursor column. -/// -/// The cursor offset is relative to the start of the new text content (additions and context lines). -/// Returns the patch with cursor marker comment lines inserted after the relevant addition line. -pub fn encode_cursor_in_patch(patch: &str, cursor_offset: Option) -> String { - let Some(cursor_offset) = cursor_offset else { - return patch.to_string(); - }; - - let mut result = String::new(); - let mut line_start_offset = 0usize; - - for line in patch.lines() { - if matches!( - DiffLine::parse(line), - DiffLine::Garbage(content) - if content.starts_with('#') && content.contains(CURSOR_POSITION_MARKER) - ) { - continue; - } - - if !result.is_empty() { - result.push('\n'); - } - result.push_str(line); - - match DiffLine::parse(line) { - DiffLine::Addition(content) => { - let line_end_offset = line_start_offset + content.len(); - - if cursor_offset >= line_start_offset && cursor_offset <= line_end_offset { - let cursor_column = cursor_offset - line_start_offset; - - result.push('\n'); - result.push('#'); - for _ in 0..cursor_column { - result.push(' '); - } - write!(result, "^{}", CURSOR_POSITION_MARKER).unwrap(); - } - - line_start_offset = line_end_offset + 1; - } - DiffLine::Context(content) => { - line_start_offset += content.len() + 1; - } - _ => {} - } - } - - if patch.ends_with('\n') { - result.push('\n'); - } - - result -} - #[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)] pub struct ExampleSpec { #[serde(default)] @@ -509,53 +452,7 @@ impl ExampleSpec { pub fn expected_patches_with_cursor_positions(&self) -> Vec<(String, Option)> { self.expected_patches .iter() - .map(|patch| { - let mut clean_patch = String::new(); - let mut cursor_offset: Option = None; - let mut line_start_offset = 0usize; - let mut prev_line_start_offset = 0usize; - - for line in patch.lines() { - let diff_line = DiffLine::parse(line); - - match &diff_line { - DiffLine::Garbage(content) - if content.starts_with('#') - && content.contains(CURSOR_POSITION_MARKER) => - { - let caret_column = if let Some(caret_pos) = content.find('^') { - caret_pos - } else if let Some(_) = content.find('<') { - 0 - } else { - continue; - }; - let cursor_column = caret_column.saturating_sub('#'.len_utf8()); - cursor_offset = Some(prev_line_start_offset + cursor_column); - } - _ => { - if !clean_patch.is_empty() { - clean_patch.push('\n'); - } - clean_patch.push_str(line); - - match diff_line { - DiffLine::Addition(content) | DiffLine::Context(content) => { - prev_line_start_offset = line_start_offset; - line_start_offset += content.len() + 1; - } - _ => {} - } - } - } - } - - if patch.ends_with('\n') && !clean_patch.is_empty() { - clean_patch.push('\n'); - } - - (clean_patch, cursor_offset) - }) + .map(|patch| extract_cursor_from_patch(patch)) .collect() } diff --git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index fdfe3ebcf06c8319f5ce00066fa279d79eda7eea..b4556e58b9247624e2d4caeddb5614ff5000d854 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -24,8 +24,9 @@ use zeta_prompt::{ParsedOutput, ZetaPromptInput}; use std::{env, ops::Range, path::Path, sync::Arc}; use zeta_prompt::{ - CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output, - prompt_input_contains_special_tokens, stop_tokens_for_format, + ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output, + parsed_output_from_editable_region, prompt_input_contains_special_tokens, + stop_tokens_for_format, zeta1::{self, EDITABLE_REGION_END_MARKER}, }; @@ -181,6 +182,7 @@ pub fn request_prediction_with_zeta( let parsed_output = output_text.map(|text| ParsedOutput { new_editable_region: text, range_in_excerpt: editable_range_in_excerpt, + cursor_offset_in_new_editable_region: None, }); (request_id, parsed_output, None, None) @@ -283,10 +285,10 @@ pub fn request_prediction_with_zeta( let request_id = EditPredictionId(response.request_id.into()); let output_text = Some(response.output).filter(|s| !s.is_empty()); let model_version = response.model_version; - let parsed_output = ParsedOutput { - new_editable_region: output_text.unwrap_or_default(), - range_in_excerpt: response.editable_range, - }; + let parsed_output = parsed_output_from_editable_region( + response.editable_range, + output_text.unwrap_or_default(), + ); Some((request_id, Some(parsed_output), model_version, usage)) }) @@ -299,6 +301,7 @@ pub fn request_prediction_with_zeta( let Some(ParsedOutput { new_editable_region: mut output_text, range_in_excerpt: editable_range_in_excerpt, + cursor_offset_in_new_editable_region: cursor_offset_in_output, }) = output else { return Ok((Some((request_id, None)), None)); @@ -312,13 +315,6 @@ pub fn request_prediction_with_zeta( .text_for_range(editable_range_in_buffer.clone()) .collect::(); - // Client-side cursor marker processing (applies to both raw and v3 responses) - let cursor_offset_in_output = output_text.find(CURSOR_MARKER); - if let Some(offset) = cursor_offset_in_output { - log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}"); - output_text.replace_range(offset..offset + CURSOR_MARKER.len(), ""); - } - if let Some(debug_tx) = &debug_tx { debug_tx .unbounded_send(DebugEvent::EditPredictionFinished( diff --git a/crates/edit_prediction_cli/src/parse_output.rs b/crates/edit_prediction_cli/src/parse_output.rs index 2b41384e176ac7a6cc5c3dc7f93ddbba3cf027ae..fc85afa371a4edfe8080d602000c38ecedb98c86 100644 --- a/crates/edit_prediction_cli/src/parse_output.rs +++ b/crates/edit_prediction_cli/src/parse_output.rs @@ -5,8 +5,7 @@ use crate::{ repair, }; use anyhow::{Context as _, Result}; -use edit_prediction::example_spec::encode_cursor_in_patch; -use zeta_prompt::{CURSOR_MARKER, ZetaFormat, parse_zeta2_model_output}; +use zeta_prompt::{ZetaFormat, parse_zeta2_model_output, parsed_output_to_patch}; pub fn run_parse_output(example: &mut Example) -> Result<()> { example @@ -65,46 +64,18 @@ fn parse_zeta2_output( .context("prompt_inputs required")?; let parsed = parse_zeta2_model_output(actual_output, format, prompt_inputs)?; - let range_in_excerpt = parsed.range_in_excerpt; - + let range_in_excerpt = parsed.range_in_excerpt.clone(); let excerpt = prompt_inputs.cursor_excerpt.as_ref(); - let old_text = excerpt[range_in_excerpt.clone()].to_string(); - let mut new_text = parsed.new_editable_region; - - let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) { - new_text.replace_range(offset..offset + CURSOR_MARKER.len(), ""); - Some(offset) - } else { - None - }; + let editable_region_offset = range_in_excerpt.start; + let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count(); - // Normalize trailing newlines for diff generation - let mut old_text_normalized = old_text; + let mut new_text = parsed.new_editable_region.clone(); if !new_text.is_empty() && !new_text.ends_with('\n') { new_text.push('\n'); } - if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') { - old_text_normalized.push('\n'); - } - - let editable_region_offset = range_in_excerpt.start; - let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count(); - let editable_region_lines = old_text_normalized.lines().count() as u32; - - let diff = language::unified_diff_with_context( - &old_text_normalized, - &new_text, - editable_region_start_line as u32, - editable_region_start_line as u32, - editable_region_lines, - ); - - let formatted_diff = format!( - "--- a/{path}\n+++ b/{path}\n{diff}", - path = example.spec.cursor_path.to_string_lossy(), - ); - let formatted_diff = encode_cursor_in_patch(&formatted_diff, cursor_offset); + let cursor_offset = parsed.cursor_offset_in_new_editable_region; + let formatted_diff = parsed_output_to_patch(prompt_inputs, parsed)?; let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| { ActualCursor::from_editable_region( diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index c25b0ded5daea0674629ce4bea00736cb2eb3ffb..751796fb83164b78dc5d6789f0ae7870eff16ce1 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -36,8 +36,16 @@ pub struct FakeGitRepository { pub(crate) is_trusted: Arc, } +#[derive(Debug, Clone)] +pub struct FakeCommitSnapshot { + pub head_contents: HashMap, + pub index_contents: HashMap, + pub sha: String, +} + #[derive(Debug, Clone)] pub struct FakeGitRepositoryState { + pub commit_history: Vec, pub event_emitter: smol::channel::Sender, pub unmerged_paths: HashMap, pub head_contents: HashMap, @@ -74,6 +82,7 @@ impl FakeGitRepositoryState { oids: Default::default(), remotes: HashMap::default(), graph_commits: Vec::new(), + commit_history: Vec::new(), stash_entries: Default::default(), } } @@ -217,11 +226,52 @@ impl GitRepository for FakeGitRepository { fn reset( &self, - _commit: String, - _mode: ResetMode, + commit: String, + mode: ResetMode, _env: Arc>, ) -> BoxFuture<'_, Result<()>> { - unimplemented!() + self.with_state_async(true, move |state| { + let pop_count = if commit == "HEAD~" || commit == "HEAD^" { + 1 + } else if let Some(suffix) = commit.strip_prefix("HEAD~") { + suffix + .parse::() + .with_context(|| format!("Invalid HEAD~ offset: {commit}"))? + } else { + match state + .commit_history + .iter() + .rposition(|entry| entry.sha == commit) + { + Some(index) => state.commit_history.len() - index, + None => anyhow::bail!("Unknown commit ref: {commit}"), + } + }; + + if pop_count == 0 || pop_count > state.commit_history.len() { + anyhow::bail!( + "Cannot reset {pop_count} commit(s): only {} in history", + state.commit_history.len() + ); + } + + let target_index = state.commit_history.len() - pop_count; + let snapshot = state.commit_history[target_index].clone(); + state.commit_history.truncate(target_index); + + match mode { + ResetMode::Soft => { + state.head_contents = snapshot.head_contents; + } + ResetMode::Mixed => { + state.head_contents = snapshot.head_contents; + state.index_contents = state.head_contents.clone(); + } + } + + state.refs.insert("HEAD".into(), snapshot.sha); + Ok(()) + }) } fn checkout_files( @@ -490,7 +540,7 @@ impl GitRepository for FakeGitRepository { fn create_worktree( &self, - branch_name: String, + branch_name: Option, path: PathBuf, from_commit: Option, ) -> BoxFuture<'_, Result<()>> { @@ -505,8 +555,10 @@ impl GitRepository for FakeGitRepository { if let Some(message) = &state.simulated_create_worktree_error { anyhow::bail!("{message}"); } - if state.branches.contains(&branch_name) { - bail!("a branch named '{}' already exists", branch_name); + if let Some(ref name) = branch_name { + if state.branches.contains(name) { + bail!("a branch named '{}' already exists", name); + } } Ok(()) })??; @@ -515,13 +567,22 @@ impl GitRepository for FakeGitRepository { fs.create_dir(&path).await?; // Create .git/worktrees// directory with HEAD, commondir, gitdir. - let ref_name = format!("refs/heads/{branch_name}"); - let worktrees_entry_dir = common_dir_path.join("worktrees").join(&branch_name); + let worktree_entry_name = branch_name + .as_deref() + .unwrap_or_else(|| path.file_name().unwrap().to_str().unwrap()); + let worktrees_entry_dir = common_dir_path.join("worktrees").join(worktree_entry_name); fs.create_dir(&worktrees_entry_dir).await?; + let sha = from_commit.unwrap_or_else(|| "fake-sha".to_string()); + let head_content = if let Some(ref branch_name) = branch_name { + let ref_name = format!("refs/heads/{branch_name}"); + format!("ref: {ref_name}") + } else { + sha.clone() + }; fs.write_file_internal( worktrees_entry_dir.join("HEAD"), - format!("ref: {ref_name}").into_bytes(), + head_content.into_bytes(), false, )?; fs.write_file_internal( @@ -544,10 +605,12 @@ impl GitRepository for FakeGitRepository { )?; // Update git state: add ref and branch. - let sha = from_commit.unwrap_or_else(|| "fake-sha".to_string()); fs.with_git_state(&dot_git_path, true, move |state| { - state.refs.insert(ref_name, sha); - state.branches.insert(branch_name); + if let Some(branch_name) = branch_name { + let ref_name = format!("refs/heads/{branch_name}"); + state.refs.insert(ref_name, sha); + state.branches.insert(branch_name); + } Ok::<(), anyhow::Error>(()) })??; Ok(()) @@ -822,11 +885,30 @@ impl GitRepository for FakeGitRepository { &self, _message: gpui::SharedString, _name_and_email: Option<(gpui::SharedString, gpui::SharedString)>, - _options: CommitOptions, + options: CommitOptions, _askpass: AskPassDelegate, _env: Arc>, ) -> BoxFuture<'_, Result<()>> { - async { Ok(()) }.boxed() + self.with_state_async(true, move |state| { + if !options.allow_empty && !options.amend && state.index_contents == state.head_contents + { + anyhow::bail!("nothing to commit (use allow_empty to create an empty commit)"); + } + + let old_sha = state.refs.get("HEAD").cloned().unwrap_or_default(); + state.commit_history.push(FakeCommitSnapshot { + head_contents: state.head_contents.clone(), + index_contents: state.index_contents.clone(), + sha: old_sha, + }); + + state.head_contents = state.index_contents.clone(); + + let new_sha = format!("fake-commit-{}", state.commit_history.len()); + state.refs.insert("HEAD".into(), new_sha); + + Ok(()) + }) } fn run_hook( @@ -1210,6 +1292,24 @@ impl GitRepository for FakeGitRepository { anyhow::bail!("commit_data_reader not supported for FakeGitRepository") } + fn update_ref(&self, ref_name: String, commit: String) -> BoxFuture<'_, Result<()>> { + self.with_state_async(true, move |state| { + state.refs.insert(ref_name, commit); + Ok(()) + }) + } + + fn delete_ref(&self, ref_name: String) -> BoxFuture<'_, Result<()>> { + self.with_state_async(true, move |state| { + state.refs.remove(&ref_name); + Ok(()) + }) + } + + fn repair_worktrees(&self) -> BoxFuture<'_, Result<()>> { + async { Ok(()) }.boxed() + } + fn set_trusted(&self, trusted: bool) { self.is_trusted .store(trusted, std::sync::atomic::Ordering::Release); diff --git a/crates/fs/tests/integration/fake_git_repo.rs b/crates/fs/tests/integration/fake_git_repo.rs index 6428083c161235001ef29daf3583520e7f7d25a2..f4192a22bb42f88f8769ef59f817b2bf2a288fb9 100644 --- a/crates/fs/tests/integration/fake_git_repo.rs +++ b/crates/fs/tests/integration/fake_git_repo.rs @@ -24,7 +24,7 @@ async fn test_fake_worktree_lifecycle(cx: &mut TestAppContext) { // Create a worktree let worktree_1_dir = worktrees_dir.join("feature-branch"); repo.create_worktree( - "feature-branch".to_string(), + Some("feature-branch".to_string()), worktree_1_dir.clone(), Some("abc123".to_string()), ) @@ -47,9 +47,13 @@ async fn test_fake_worktree_lifecycle(cx: &mut TestAppContext) { // Create a second worktree (without explicit commit) let worktree_2_dir = worktrees_dir.join("bugfix-branch"); - repo.create_worktree("bugfix-branch".to_string(), worktree_2_dir.clone(), None) - .await - .unwrap(); + repo.create_worktree( + Some("bugfix-branch".to_string()), + worktree_2_dir.clone(), + None, + ) + .await + .unwrap(); let worktrees = repo.worktrees().await.unwrap(); assert_eq!(worktrees.len(), 3); diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index b03fe1b0c63904bfc751ab7946f92a7c8595db00..c42d2e28cf041e40404c1b8276ddcf5d10ca5f01 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -329,6 +329,7 @@ impl Upstream { pub struct CommitOptions { pub amend: bool, pub signoff: bool, + pub allow_empty: bool, } #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] @@ -715,7 +716,7 @@ pub trait GitRepository: Send + Sync { fn create_worktree( &self, - branch_name: String, + branch_name: Option, path: PathBuf, from_commit: Option, ) -> BoxFuture<'_, Result<()>>; @@ -916,6 +917,12 @@ pub trait GitRepository: Send + Sync { fn commit_data_reader(&self) -> Result; + fn update_ref(&self, ref_name: String, commit: String) -> BoxFuture<'_, Result<()>>; + + fn delete_ref(&self, ref_name: String) -> BoxFuture<'_, Result<()>>; + + fn repair_worktrees(&self) -> BoxFuture<'_, Result<()>>; + fn set_trusted(&self, trusted: bool); fn is_trusted(&self) -> bool; } @@ -1660,19 +1667,20 @@ impl GitRepository for RealGitRepository { fn create_worktree( &self, - branch_name: String, + branch_name: Option, path: PathBuf, from_commit: Option, ) -> BoxFuture<'_, Result<()>> { let git_binary = self.git_binary(); - let mut args = vec![ - OsString::from("worktree"), - OsString::from("add"), - OsString::from("-b"), - OsString::from(branch_name.as_str()), - OsString::from("--"), - OsString::from(path.as_os_str()), - ]; + let mut args = vec![OsString::from("worktree"), OsString::from("add")]; + if let Some(branch_name) = &branch_name { + args.push(OsString::from("-b")); + args.push(OsString::from(branch_name.as_str())); + } else { + args.push(OsString::from("--detach")); + } + args.push(OsString::from("--")); + args.push(OsString::from(path.as_os_str())); if let Some(from_commit) = from_commit { args.push(OsString::from(from_commit)); } else { @@ -2165,6 +2173,10 @@ impl GitRepository for RealGitRepository { cmd.arg("--signoff"); } + if options.allow_empty { + cmd.arg("--allow-empty"); + } + if let Some((name, email)) = name_and_email { cmd.arg("--author").arg(&format!("{name} <{email}>")); } @@ -2176,6 +2188,39 @@ impl GitRepository for RealGitRepository { .boxed() } + fn update_ref(&self, ref_name: String, commit: String) -> BoxFuture<'_, Result<()>> { + let git_binary = self.git_binary(); + self.executor + .spawn(async move { + let args: Vec = vec!["update-ref".into(), ref_name.into(), commit.into()]; + git_binary?.run(&args).await?; + Ok(()) + }) + .boxed() + } + + fn delete_ref(&self, ref_name: String) -> BoxFuture<'_, Result<()>> { + let git_binary = self.git_binary(); + self.executor + .spawn(async move { + let args: Vec = vec!["update-ref".into(), "-d".into(), ref_name.into()]; + git_binary?.run(&args).await?; + Ok(()) + }) + .boxed() + } + + fn repair_worktrees(&self) -> BoxFuture<'_, Result<()>> { + let git_binary = self.git_binary(); + self.executor + .spawn(async move { + let args: Vec = vec!["worktree".into(), "repair".into()]; + git_binary?.run(&args).await?; + Ok(()) + }) + .boxed() + } + fn push( &self, branch_name: String, @@ -4009,7 +4054,7 @@ mod tests { // Create a new worktree repo.create_worktree( - "test-branch".to_string(), + Some("test-branch".to_string()), worktree_path.clone(), Some("HEAD".to_string()), ) @@ -4068,7 +4113,7 @@ mod tests { // Create a worktree let worktree_path = worktrees_dir.join("worktree-to-remove"); repo.create_worktree( - "to-remove".to_string(), + Some("to-remove".to_string()), worktree_path.clone(), Some("HEAD".to_string()), ) @@ -4092,7 +4137,7 @@ mod tests { // Create a worktree let worktree_path = worktrees_dir.join("dirty-wt"); repo.create_worktree( - "dirty-wt".to_string(), + Some("dirty-wt".to_string()), worktree_path.clone(), Some("HEAD".to_string()), ) @@ -4162,7 +4207,7 @@ mod tests { // Create a worktree let old_path = worktrees_dir.join("old-worktree-name"); repo.create_worktree( - "old-name".to_string(), + Some("old-name".to_string()), old_path.clone(), Some("HEAD".to_string()), ) diff --git a/crates/git_graph/src/git_graph.rs b/crates/git_graph/src/git_graph.rs index 83cd01eda5c509583f24fd424426d20a55bbfbed..aa5f6bc6e1293cfd057baa0c5e9f77819da71086 100644 --- a/crates/git_graph/src/git_graph.rs +++ b/crates/git_graph/src/git_graph.rs @@ -2394,9 +2394,8 @@ impl GitGraph { let local_y = position_y - canvas_bounds.origin.y; if local_y >= px(0.) && local_y < canvas_bounds.size.height { - let row_in_viewport = (local_y / self.row_height).floor() as usize; - let scroll_rows = (scroll_offset_y / self.row_height).floor() as usize; - let absolute_row = scroll_rows + row_in_viewport; + let absolute_y = local_y + scroll_offset_y; + let absolute_row = (absolute_y / self.row_height).floor() as usize; if absolute_row < self.graph_data.commits.len() { return Some(absolute_row); @@ -4006,4 +4005,76 @@ mod tests { }); assert_eq!(reloaded_shas, vec![updated_head, updated_stash]); } + + #[gpui::test] + async fn test_git_graph_row_at_position_rounding(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + Path::new("/project"), + serde_json::json!({ + ".git": {}, + "file.txt": "content", + }), + ) + .await; + + let mut rng = StdRng::seed_from_u64(42); + let commits = generate_random_commit_dag(&mut rng, 10, false); + fs.set_graph_commits(Path::new("/project/.git"), commits.clone()); + + let project = Project::test(fs.clone(), [Path::new("/project")], cx).await; + cx.run_until_parked(); + + let repository = project.read_with(cx, |project, cx| { + project + .active_repository(cx) + .expect("should have a repository") + }); + + let (multi_workspace, cx) = cx.add_window_view(|window, cx| { + workspace::MultiWorkspace::test_new(project.clone(), window, cx) + }); + + let workspace_weak = + multi_workspace.read_with(&*cx, |multi, _| multi.workspace().downgrade()); + + let git_graph = cx.new_window_entity(|window, cx| { + GitGraph::new( + repository.read(cx).id, + project.read(cx).git_store().clone(), + workspace_weak, + window, + cx, + ) + }); + cx.run_until_parked(); + + git_graph.update(cx, |graph, cx| { + assert!( + graph.graph_data.commits.len() >= 10, + "graph should load dummy commits" + ); + + graph.row_height = px(20.0); + let origin_y = px(100.0); + graph.graph_canvas_bounds.set(Some(Bounds { + origin: point(px(0.0), origin_y), + size: gpui::size(px(100.0), px(1000.0)), + })); + + graph.table_interaction_state.update(cx, |state, _| { + state.set_scroll_offset(point(px(0.0), px(-15.0))) + }); + let pos_y = origin_y + px(10.0); + let absolute_calc_row = graph.row_at_position(pos_y, cx); + + assert_eq!( + absolute_calc_row, + Some(1), + "Row calculation should yield absolute row exactly" + ); + }); + } } diff --git a/crates/git_ui/src/commit_modal.rs b/crates/git_ui/src/commit_modal.rs index 432da803e6eedfec304836198f6111f5418084cc..2088ad77ec5d7e71bdfb42ebcbfab6d001f64375 100644 --- a/crates/git_ui/src/commit_modal.rs +++ b/crates/git_ui/src/commit_modal.rs @@ -453,6 +453,7 @@ impl CommitModal { CommitOptions { amend: is_amend_pending, signoff: is_signoff_enabled, + allow_empty: false, }, window, cx, diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs index aac1ec1a19ab53913a830738ae528fb2c0c10248..0cb8ec6b78929d216b700b6e21cbf43a538c6f56 100644 --- a/crates/git_ui/src/git_panel.rs +++ b/crates/git_ui/src/git_panel.rs @@ -2155,6 +2155,7 @@ impl GitPanel { CommitOptions { amend: false, signoff: self.signoff_enabled, + allow_empty: false, }, window, cx, @@ -2195,6 +2196,7 @@ impl GitPanel { CommitOptions { amend: true, signoff: self.signoff_enabled, + allow_empty: false, }, window, cx, @@ -4454,7 +4456,11 @@ impl GitPanel { git_panel .update(cx, |git_panel, cx| { git_panel.commit_changes( - CommitOptions { amend, signoff }, + CommitOptions { + amend, + signoff, + allow_empty: false, + }, window, cx, ); diff --git a/crates/gpui/src/elements/list.rs b/crates/gpui/src/elements/list.rs index 5525f5c17d2ad33e1ce9696afded1cea5447020c..5a88d81c18db5e790b7bbed0fb9def23bc973e14 100644 --- a/crates/gpui/src/elements/list.rs +++ b/crates/gpui/src/elements/list.rs @@ -462,6 +462,13 @@ impl ListState { let current_offset = self.logical_scroll_top(); let state = &mut *self.0.borrow_mut(); + + if distance < px(0.) { + if let FollowState::Tail { is_following } = &mut state.follow_state { + *is_following = false; + } + } + let mut cursor = state.items.cursor::(()); cursor.seek(&Count(current_offset.item_ix), Bias::Right); @@ -536,6 +543,12 @@ impl ListState { scroll_top.offset_in_item = px(0.); } + if scroll_top.item_ix < item_count { + if let FollowState::Tail { is_following } = &mut state.follow_state { + *is_following = false; + } + } + state.logical_scroll_top = Some(scroll_top); } diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index ae01084a2657abdc86e7510aa49663cf98aabe70..50037f31facbac446de7ecf38536d1e4a24c7867 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -125,6 +125,7 @@ pub struct FakeLanguageModel { >, forbid_requests: AtomicBool, supports_thinking: AtomicBool, + supports_streaming_tools: AtomicBool, } impl Default for FakeLanguageModel { @@ -137,6 +138,7 @@ impl Default for FakeLanguageModel { current_completion_txs: Mutex::new(Vec::new()), forbid_requests: AtomicBool::new(false), supports_thinking: AtomicBool::new(false), + supports_streaming_tools: AtomicBool::new(false), } } } @@ -169,6 +171,10 @@ impl FakeLanguageModel { self.supports_thinking.store(supports, SeqCst); } + pub fn set_supports_streaming_tools(&self, supports: bool) { + self.supports_streaming_tools.store(supports, SeqCst); + } + pub fn pending_completions(&self) -> Vec { self.current_completion_txs .lock() @@ -282,6 +288,10 @@ impl LanguageModel for FakeLanguageModel { self.supports_thinking.load(SeqCst) } + fn supports_streaming_tools(&self) -> bool { + self.supports_streaming_tools.load(SeqCst) + } + fn telemetry_id(&self) -> String { "fake".to_string() } diff --git a/crates/picker/src/picker.rs b/crates/picker/src/picker.rs index 1e529cd53f2d2527af8525886d11dbcddbf33a34..eba5b3096194fe8a3379efeb9b230a6004cd2e36 100644 --- a/crates/picker/src/picker.rs +++ b/crates/picker/src/picker.rs @@ -121,6 +121,9 @@ pub trait PickerDelegate: Sized + 'static { ) -> bool { true } + fn select_on_hover(&self) -> bool { + true + } // Allows binding some optional effect to when the selection changes. fn selected_index_changed( @@ -788,12 +791,14 @@ impl Picker { this.handle_click(ix, event.modifiers.platform, window, cx) }), ) - .on_hover(cx.listener(move |this, hovered: &bool, window, cx| { - if *hovered { - this.set_selected_index(ix, None, false, window, cx); - cx.notify(); - } - })) + .when(self.delegate.select_on_hover(), |this| { + this.on_hover(cx.listener(move |this, hovered: &bool, window, cx| { + if *hovered { + this.set_selected_index(ix, None, false, window, cx); + cx.notify(); + } + })) + }) .children(self.delegate.render_match( ix, ix == self.delegate.selected_index(), diff --git a/crates/project/src/git_store.rs b/crates/project/src/git_store.rs index 6bc7f1ab52db8665efac7ab5631986b5ec0c8e33..e7e84ffe673881d898a56b64892887b9c8d6c809 100644 --- a/crates/project/src/git_store.rs +++ b/crates/project/src/git_store.rs @@ -329,6 +329,12 @@ pub struct GraphDataResponse<'a> { pub error: Option, } +#[derive(Clone, Debug)] +enum CreateWorktreeStartPoint { + Detached, + Branched { name: String }, +} + pub struct Repository { this: WeakEntity, snapshot: RepositorySnapshot, @@ -588,6 +594,7 @@ impl GitStore { client.add_entity_request_handler(Self::handle_create_worktree); client.add_entity_request_handler(Self::handle_remove_worktree); client.add_entity_request_handler(Self::handle_rename_worktree); + client.add_entity_request_handler(Self::handle_get_head_sha); } pub fn is_local(&self) -> bool { @@ -2340,6 +2347,7 @@ impl GitStore { CommitOptions { amend: options.amend, signoff: options.signoff, + allow_empty: options.allow_empty, }, askpass, cx, @@ -2406,12 +2414,18 @@ impl GitStore { let repository_id = RepositoryId::from_proto(envelope.payload.repository_id); let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?; let directory = PathBuf::from(envelope.payload.directory); - let name = envelope.payload.name; + let start_point = if envelope.payload.name.is_empty() { + CreateWorktreeStartPoint::Detached + } else { + CreateWorktreeStartPoint::Branched { + name: envelope.payload.name, + } + }; let commit = envelope.payload.commit; repository_handle .update(&mut cx, |repository_handle, _| { - repository_handle.create_worktree(name, directory, commit) + repository_handle.create_worktree_with_start_point(start_point, directory, commit) }) .await??; @@ -2456,6 +2470,21 @@ impl GitStore { Ok(proto::Ack {}) } + async fn handle_get_head_sha( + this: Entity, + envelope: TypedEnvelope, + mut cx: AsyncApp, + ) -> Result { + let repository_id = RepositoryId::from_proto(envelope.payload.repository_id); + let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?; + + let head_sha = repository_handle + .update(&mut cx, |repository_handle, _| repository_handle.head_sha()) + .await??; + + Ok(proto::GitGetHeadShaResponse { sha: head_sha }) + } + async fn handle_get_branches( this: Entity, envelope: TypedEnvelope, @@ -5493,6 +5522,7 @@ impl Repository { options: Some(proto::commit::CommitOptions { amend: options.amend, signoff: options.signoff, + allow_empty: options.allow_empty, }), askpass_id, }) @@ -5974,36 +6004,174 @@ impl Repository { }) } + fn create_worktree_with_start_point( + &mut self, + start_point: CreateWorktreeStartPoint, + path: PathBuf, + commit: Option, + ) -> oneshot::Receiver> { + if matches!( + &start_point, + CreateWorktreeStartPoint::Branched { name } if name.is_empty() + ) { + let (sender, receiver) = oneshot::channel(); + sender + .send(Err(anyhow!("branch name cannot be empty"))) + .ok(); + return receiver; + } + + let id = self.id; + let message = match &start_point { + CreateWorktreeStartPoint::Detached => "git worktree add (detached)".into(), + CreateWorktreeStartPoint::Branched { name } => { + format!("git worktree add: {name}").into() + } + }; + + self.send_job(Some(message), move |repo, _cx| async move { + let branch_name = match start_point { + CreateWorktreeStartPoint::Detached => None, + CreateWorktreeStartPoint::Branched { name } => Some(name), + }; + let remote_name = branch_name.clone().unwrap_or_default(); + + match repo { + RepositoryState::Local(LocalRepositoryState { backend, .. }) => { + backend.create_worktree(branch_name, path, commit).await + } + RepositoryState::Remote(RemoteRepositoryState { project_id, client }) => { + client + .request(proto::GitCreateWorktree { + project_id: project_id.0, + repository_id: id.to_proto(), + name: remote_name, + directory: path.to_string_lossy().to_string(), + commit, + }) + .await?; + + Ok(()) + } + } + }) + } + pub fn create_worktree( &mut self, branch_name: String, path: PathBuf, commit: Option, ) -> oneshot::Receiver> { + self.create_worktree_with_start_point( + CreateWorktreeStartPoint::Branched { name: branch_name }, + path, + commit, + ) + } + + pub fn create_worktree_detached( + &mut self, + path: PathBuf, + commit: String, + ) -> oneshot::Receiver> { + self.create_worktree_with_start_point( + CreateWorktreeStartPoint::Detached, + path, + Some(commit), + ) + } + + pub fn head_sha(&mut self) -> oneshot::Receiver>> { let id = self.id; - self.send_job( - Some(format!("git worktree add: {}", branch_name).into()), - move |repo, _cx| async move { - match repo { - RepositoryState::Local(LocalRepositoryState { backend, .. }) => { - backend.create_worktree(branch_name, path, commit).await - } - RepositoryState::Remote(RemoteRepositoryState { project_id, client }) => { - client - .request(proto::GitCreateWorktree { - project_id: project_id.0, - repository_id: id.to_proto(), - name: branch_name, - directory: path.to_string_lossy().to_string(), - commit, - }) - .await?; + self.send_job(None, move |repo, _cx| async move { + match repo { + RepositoryState::Local(LocalRepositoryState { backend, .. }) => { + Ok(backend.head_sha().await) + } + RepositoryState::Remote(RemoteRepositoryState { project_id, client }) => { + let response = client + .request(proto::GitGetHeadSha { + project_id: project_id.0, + repository_id: id.to_proto(), + }) + .await?; - Ok(()) - } + Ok(response.sha) } - }, - ) + } + }) + } + + pub fn update_ref( + &mut self, + ref_name: String, + commit: String, + ) -> oneshot::Receiver> { + self.send_job(None, move |repo, _cx| async move { + match repo { + RepositoryState::Local(LocalRepositoryState { backend, .. }) => { + backend.update_ref(ref_name, commit).await + } + RepositoryState::Remote(_) => { + anyhow::bail!("update_ref is not supported for remote repositories") + } + } + }) + } + + pub fn delete_ref(&mut self, ref_name: String) -> oneshot::Receiver> { + self.send_job(None, move |repo, _cx| async move { + match repo { + RepositoryState::Local(LocalRepositoryState { backend, .. }) => { + backend.delete_ref(ref_name).await + } + RepositoryState::Remote(_) => { + anyhow::bail!("delete_ref is not supported for remote repositories") + } + } + }) + } + + pub fn resolve_commit(&mut self, sha: String) -> oneshot::Receiver> { + self.send_job(None, move |repo, _cx| async move { + match repo { + RepositoryState::Local(LocalRepositoryState { backend, .. }) => { + let results = backend.revparse_batch(vec![sha]).await?; + Ok(results.into_iter().next().flatten().is_some()) + } + RepositoryState::Remote(_) => { + anyhow::bail!("resolve_commit is not supported for remote repositories") + } + } + }) + } + + pub fn repair_worktrees(&mut self) -> oneshot::Receiver> { + self.send_job(None, move |repo, _cx| async move { + match repo { + RepositoryState::Local(LocalRepositoryState { backend, .. }) => { + backend.repair_worktrees().await + } + RepositoryState::Remote(_) => { + anyhow::bail!("repair_worktrees is not supported for remote repositories") + } + } + }) + } + + pub fn commit_exists(&mut self, sha: String) -> oneshot::Receiver> { + self.send_job(None, move |repo, _cx| async move { + match repo { + RepositoryState::Local(LocalRepositoryState { backend, .. }) => { + let results = backend.revparse_batch(vec![sha]).await?; + Ok(results.into_iter().next().flatten().is_some()) + } + RepositoryState::Remote(_) => { + anyhow::bail!("commit_exists is not supported for remote repositories") + } + } + }) } pub fn remove_worktree(&mut self, path: PathBuf, force: bool) -> oneshot::Receiver> { diff --git a/crates/proto/proto/git.proto b/crates/proto/proto/git.proto index 0cbb635d78dddc81aa7c75340f2fbebe83a474e3..9324feb21b1f50ac1041ed0afc8b59cb9b7fe2c6 100644 --- a/crates/proto/proto/git.proto +++ b/crates/proto/proto/git.proto @@ -403,6 +403,7 @@ message Commit { message CommitOptions { bool amend = 1; bool signoff = 2; + bool allow_empty = 3; } } @@ -567,6 +568,15 @@ message GitGetWorktrees { uint64 repository_id = 2; } +message GitGetHeadSha { + uint64 project_id = 1; + uint64 repository_id = 2; +} + +message GitGetHeadShaResponse { + optional string sha = 1; +} + message GitWorktreesResponse { repeated Worktree worktrees = 1; } diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 24e7c5372f2679eab1726487e1967edcef6024ed..8b62754d7af40b7c4f5e1a87ad42899d682ba453 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -474,7 +474,9 @@ message Envelope { GitCompareCheckpoints git_compare_checkpoints = 436; GitCompareCheckpointsResponse git_compare_checkpoints_response = 437; GitDiffCheckpoints git_diff_checkpoints = 438; - GitDiffCheckpointsResponse git_diff_checkpoints_response = 439; // current max + GitDiffCheckpointsResponse git_diff_checkpoints_response = 439; + GitGetHeadSha git_get_head_sha = 440; + GitGetHeadShaResponse git_get_head_sha_response = 441; // current max } reserved 87 to 88; diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index dd77d2a2da8d4dbc2c0f91f63cb59dd1591ee3f4..b77bd02313c13a9b04eb7762a97f9e77ac8cbaf8 100644 --- a/crates/proto/src/proto.rs +++ b/crates/proto/src/proto.rs @@ -351,6 +351,8 @@ messages!( (NewExternalAgentVersionAvailable, Background), (RemoteStarted, Background), (GitGetWorktrees, Background), + (GitGetHeadSha, Background), + (GitGetHeadShaResponse, Background), (GitWorktreesResponse, Background), (GitCreateWorktree, Background), (GitRemoveWorktree, Background), @@ -558,6 +560,7 @@ request_messages!( (GetContextServerCommand, ContextServerCommand), (RemoteStarted, Ack), (GitGetWorktrees, GitWorktreesResponse), + (GitGetHeadSha, GitGetHeadShaResponse), (GitCreateWorktree, Ack), (GitRemoveWorktree, Ack), (GitRenameWorktree, Ack), @@ -749,6 +752,7 @@ entity_messages!( ExternalAgentLoadingStatusUpdated, NewExternalAgentVersionAvailable, GitGetWorktrees, + GitGetHeadSha, GitCreateWorktree, GitRemoveWorktree, GitRenameWorktree, diff --git a/crates/recent_projects/src/recent_projects.rs b/crates/recent_projects/src/recent_projects.rs index 24010017ff9fa4eb62a1787332fed70f740ccc2d..e3bfc0dc08c95c0ce57b818e50965433a6c6bc98 100644 --- a/crates/recent_projects/src/recent_projects.rs +++ b/crates/recent_projects/src/recent_projects.rs @@ -357,7 +357,6 @@ pub fn init(cx: &mut App) { .update(cx, |multi_workspace, window, cx| { let sibling_workspace_ids: HashSet = multi_workspace .workspaces() - .iter() .filter_map(|ws| ws.read(cx).database_id()) .collect(); @@ -1113,7 +1112,6 @@ impl PickerDelegate for RecentProjectsDelegate { .update(cx, |multi_workspace, window, cx| { let workspace = multi_workspace .workspaces() - .iter() .find(|ws| ws.read(cx).database_id() == Some(workspace_id)) .cloned(); if let Some(workspace) = workspace { @@ -1932,7 +1930,6 @@ impl RecentProjectsDelegate { .update(cx, |multi_workspace, window, cx| { let workspace = multi_workspace .workspaces() - .iter() .find(|ws| ws.read(cx).database_id() == Some(workspace_id)) .cloned(); if let Some(workspace) = workspace { @@ -2055,6 +2052,11 @@ mod tests { assert_eq!(cx.update(|cx| cx.windows().len()), 1); let multi_workspace = cx.update(|cx| cx.windows()[0].downcast::().unwrap()); + multi_workspace + .update(cx, |multi_workspace, _, cx| { + multi_workspace.open_sidebar(cx); + }) + .unwrap(); multi_workspace .update(cx, |multi_workspace, _, cx| { assert!(!multi_workspace.workspace().read(cx).is_edited()) @@ -2141,7 +2143,7 @@ mod tests { ); assert!( - multi_workspace.workspaces().contains(&dirty_workspace), + multi_workspace.workspaces().any(|w| w == &dirty_workspace), "The dirty workspace should still be present in multi-workspace mode" ); diff --git a/crates/rules_library/src/rules_library.rs b/crates/rules_library/src/rules_library.rs index 7e5a56f22d48c4d51f60d7d200dc8384582beb23..425f7d2aa3d9e9259fe005a0e15dee10e4e4baf1 100644 --- a/crates/rules_library/src/rules_library.rs +++ b/crates/rules_library/src/rules_library.rs @@ -225,6 +225,10 @@ impl PickerDelegate for RulePickerDelegate { } } + fn select_on_hover(&self) -> bool { + false + } + fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc { "Search…".into() } diff --git a/crates/settings_content/src/settings_content.rs b/crates/settings_content/src/settings_content.rs index 325e86e9e3af0fb888c2691f4be1b0fdeb06dfb4..6c60a7010f7cfc5b4fadf9a8cc386fe6e3267abc 100644 --- a/crates/settings_content/src/settings_content.rs +++ b/crates/settings_content/src/settings_content.rs @@ -763,6 +763,7 @@ pub struct VimSettingsContent { pub toggle_relative_line_numbers: Option, pub use_system_clipboard: Option, pub use_smartcase_find: Option, + pub use_regex_search: Option, /// When enabled, the `:substitute` command replaces all matches in a line /// by default. The 'g' flag then toggles this behavior., pub gdefault: Option, diff --git a/crates/settings_ui/src/page_data.rs b/crates/settings_ui/src/page_data.rs index bacfd227d83933d3ebd9b2d8836bbe19958acf2b..9978832c05bb29c97f118fccbe301214d81fa0c6 100644 --- a/crates/settings_ui/src/page_data.rs +++ b/crates/settings_ui/src/page_data.rs @@ -2447,7 +2447,7 @@ fn editor_page() -> SettingsPage { ] } - fn vim_settings_section() -> [SettingsPageItem; 12] { + fn vim_settings_section() -> [SettingsPageItem; 13] { [ SettingsPageItem::SectionHeader("Vim"), SettingsPageItem::SettingItem(SettingItem { @@ -2556,6 +2556,24 @@ fn editor_page() -> SettingsPage { metadata: None, files: USER, }), + SettingsPageItem::SettingItem(SettingItem { + title: "Regex Search", + description: "Use regex search by default in Vim search.", + field: Box::new(SettingField { + json_path: Some("vim.use_regex_search"), + pick: |settings_content| { + settings_content.vim.as_ref()?.use_regex_search.as_ref() + }, + write: |settings_content, value| { + settings_content + .vim + .get_or_insert_default() + .use_regex_search = value; + }, + }), + metadata: None, + files: USER, + }), SettingsPageItem::SettingItem(SettingItem { title: "Cursor Shape - Normal Mode", description: "Cursor shape for normal mode.", @@ -4433,7 +4451,7 @@ fn window_and_layout_page() -> SettingsPage { } fn panels_page() -> SettingsPage { - fn project_panel_section() -> [SettingsPageItem; 24] { + fn project_panel_section() -> [SettingsPageItem; 28] { [ SettingsPageItem::SectionHeader("Project Panel"), SettingsPageItem::SettingItem(SettingItem { @@ -4914,31 +4932,25 @@ fn panels_page() -> SettingsPage { files: USER, }), SettingsPageItem::SettingItem(SettingItem { - title: "Hidden Files", - description: "Globs to match files that will be considered \"hidden\" and can be hidden from the project panel.", - field: Box::new( - SettingField { - json_path: Some("worktree.hidden_files"), - pick: |settings_content| { - settings_content.project.worktree.hidden_files.as_ref() - }, - write: |settings_content, value| { - settings_content.project.worktree.hidden_files = value; - }, - } - .unimplemented(), - ), + title: "Sort Mode", + description: "Sort order for entries in the project panel.", + field: Box::new(SettingField { + json_path: Some("project_panel.sort_mode"), + pick: |settings_content| { + settings_content.project_panel.as_ref()?.sort_mode.as_ref() + }, + write: |settings_content, value| { + settings_content + .project_panel + .get_or_insert_default() + .sort_mode = value; + }, + }), metadata: None, files: USER, }), - ] - } - - fn auto_open_files_section() -> [SettingsPageItem; 5] { - [ - SettingsPageItem::SectionHeader("Auto Open Files"), SettingsPageItem::SettingItem(SettingItem { - title: "On Create", + title: "Auto Open Files On Create", description: "Whether to automatically open newly created files in the editor.", field: Box::new(SettingField { json_path: Some("project_panel.auto_open.on_create"), @@ -4964,7 +4976,7 @@ fn panels_page() -> SettingsPage { files: USER, }), SettingsPageItem::SettingItem(SettingItem { - title: "On Paste", + title: "Auto Open Files On Paste", description: "Whether to automatically open files after pasting or duplicating them.", field: Box::new(SettingField { json_path: Some("project_panel.auto_open.on_paste"), @@ -4990,7 +5002,7 @@ fn panels_page() -> SettingsPage { files: USER, }), SettingsPageItem::SettingItem(SettingItem { - title: "On Drop", + title: "Auto Open Files On Drop", description: "Whether to automatically open files dropped from external sources.", field: Box::new(SettingField { json_path: Some("project_panel.auto_open.on_drop"), @@ -5016,20 +5028,20 @@ fn panels_page() -> SettingsPage { files: USER, }), SettingsPageItem::SettingItem(SettingItem { - title: "Sort Mode", - description: "Sort order for entries in the project panel.", - field: Box::new(SettingField { - pick: |settings_content| { - settings_content.project_panel.as_ref()?.sort_mode.as_ref() - }, - write: |settings_content, value| { - settings_content - .project_panel - .get_or_insert_default() - .sort_mode = value; - }, - json_path: Some("project_panel.sort_mode"), - }), + title: "Hidden Files", + description: "Globs to match files that will be considered \"hidden\" and can be hidden from the project panel.", + field: Box::new( + SettingField { + json_path: Some("worktree.hidden_files"), + pick: |settings_content| { + settings_content.project.worktree.hidden_files.as_ref() + }, + write: |settings_content, value| { + settings_content.project.worktree.hidden_files = value; + }, + } + .unimplemented(), + ), metadata: None, files: USER, }), @@ -5807,7 +5819,6 @@ fn panels_page() -> SettingsPage { title: "Panels", items: concat_sections![ project_panel_section(), - auto_open_files_section(), terminal_panel_section(), outline_panel_section(), git_panel_section(), diff --git a/crates/settings_ui/src/settings_ui.rs b/crates/settings_ui/src/settings_ui.rs index 634db0e247fdc370c479df0ed4f6d1f84a5284f6..4c7a98f6c0fa94e659a6db4e00aa28e2b4516e13 100644 --- a/crates/settings_ui/src/settings_ui.rs +++ b/crates/settings_ui/src/settings_ui.rs @@ -3753,7 +3753,6 @@ fn all_projects( .flat_map(|multi_workspace| { multi_workspace .workspaces() - .iter() .map(|workspace| workspace.read(cx).project().clone()) .collect::>() }), diff --git a/crates/sidebar/src/sidebar.rs b/crates/sidebar/src/sidebar.rs index 53ae57d1a7c55f66e40e1d704859d689d41045e4..d6589361cd9417c2ac6d9025af92f1e096b341b1 100644 --- a/crates/sidebar/src/sidebar.rs +++ b/crates/sidebar/src/sidebar.rs @@ -434,7 +434,7 @@ impl Sidebar { }) .detach(); - let workspaces = multi_workspace.read(cx).workspaces().to_vec(); + let workspaces: Vec<_> = multi_workspace.read(cx).workspaces().cloned().collect(); cx.defer_in(window, move |this, window, cx| { for workspace in &workspaces { this.subscribe_to_workspace(workspace, window, cx); @@ -673,7 +673,6 @@ impl Sidebar { let mw = self.multi_workspace.upgrade()?; let mw = mw.read(cx); mw.workspaces() - .iter() .find(|ws| ws.read(cx).project_group_key(cx).path_list() == path_list) .cloned() } @@ -716,8 +715,8 @@ impl Sidebar { return; }; let mw = multi_workspace.read(cx); - let workspaces = mw.workspaces().to_vec(); - let active_workspace = mw.workspaces().get(mw.active_workspace_index()).cloned(); + let workspaces: Vec<_> = mw.workspaces().cloned().collect(); + let active_workspace = Some(mw.workspace().clone()); let agent_server_store = workspaces .first() @@ -1769,7 +1768,11 @@ impl Sidebar { dispatch_context.add("ThreadsSidebar"); dispatch_context.add("menu"); - let identifier = if self.filter_editor.focus_handle(cx).is_focused(window) { + let is_archived_search_focused = matches!(&self.view, SidebarView::Archive(archive) if archive.read(cx).is_filter_editor_focused(window, cx)); + + let identifier = if self.filter_editor.focus_handle(cx).is_focused(window) + || is_archived_search_focused + { "searching" } else { "not_searching" @@ -1989,7 +1992,6 @@ impl Sidebar { let workspace = window.read(cx).ok().and_then(|multi_workspace| { multi_workspace .workspaces() - .iter() .find(|workspace| predicate(workspace, cx)) .cloned() })?; @@ -2006,7 +2008,6 @@ impl Sidebar { multi_workspace .read(cx) .workspaces() - .iter() .find(|workspace| predicate(workspace, cx)) .cloned() }) @@ -2199,12 +2200,10 @@ impl Sidebar { return; } - let active_workspace = self.multi_workspace.upgrade().and_then(|w| { - w.read(cx) - .workspaces() - .get(w.read(cx).active_workspace_index()) - .cloned() - }); + let active_workspace = self + .multi_workspace + .upgrade() + .map(|w| w.read(cx).workspace().clone()); if let Some(workspace) = active_workspace { self.activate_thread_locally(&metadata, &workspace, window, cx); @@ -2339,7 +2338,7 @@ impl Sidebar { return; }; - let workspaces = multi_workspace.read(cx).workspaces().to_vec(); + let workspaces: Vec<_> = multi_workspace.read(cx).workspaces().cloned().collect(); for workspace in workspaces { if let Some(agent_panel) = workspace.read(cx).panel::(cx) { let cancelled = @@ -2932,7 +2931,6 @@ impl Sidebar { .map(|mw| { mw.read(cx) .workspaces() - .iter() .filter_map(|ws| ws.read(cx).database_id()) .collect() }) @@ -3400,12 +3398,9 @@ impl Sidebar { } fn active_workspace(&self, cx: &App) -> Option> { - self.multi_workspace.upgrade().and_then(|w| { - w.read(cx) - .workspaces() - .get(w.read(cx).active_workspace_index()) - .cloned() - }) + self.multi_workspace + .upgrade() + .map(|w| w.read(cx).workspace().clone()) } fn show_thread_import_modal(&mut self, window: &mut Window, cx: &mut Context) { @@ -3513,12 +3508,11 @@ impl Sidebar { } fn show_archive(&mut self, window: &mut Window, cx: &mut Context) { - let Some(active_workspace) = self.multi_workspace.upgrade().and_then(|w| { - w.read(cx) - .workspaces() - .get(w.read(cx).active_workspace_index()) - .cloned() - }) else { + let Some(active_workspace) = self + .multi_workspace + .upgrade() + .map(|w| w.read(cx).workspace().clone()) + else { return; }; let Some(agent_panel) = active_workspace.read(cx).panel::(cx) else { @@ -3820,12 +3814,12 @@ pub fn dump_workspace_info( let multi_workspace = workspace.multi_workspace().and_then(|weak| weak.upgrade()); let workspaces: Vec> = match &multi_workspace { - Some(mw) => mw.read(cx).workspaces().to_vec(), + Some(mw) => mw.read(cx).workspaces().cloned().collect(), None => vec![this_entity.clone()], }; - let active_index = multi_workspace + let active_workspace = multi_workspace .as_ref() - .map(|mw| mw.read(cx).active_workspace_index()); + .map(|mw| mw.read(cx).workspace().clone()); writeln!(output, "MultiWorkspace: {} workspace(s)", workspaces.len()).ok(); @@ -3837,13 +3831,10 @@ pub fn dump_workspace_info( } } - if let Some(index) = active_index { - writeln!(output, "Active workspace index: {index}").ok(); - } writeln!(output).ok(); for (index, ws) in workspaces.iter().enumerate() { - let is_active = active_index == Some(index); + let is_active = active_workspace.as_ref() == Some(ws); writeln!( output, "--- Workspace {index}{} ---", diff --git a/crates/sidebar/src/sidebar_tests.rs b/crates/sidebar/src/sidebar_tests.rs index cf1ee8a0f524d9d94edf83c24ecea900f3261fb8..60881acfe9461f7897d6013831970444b7a65544 100644 --- a/crates/sidebar/src/sidebar_tests.rs +++ b/crates/sidebar/src/sidebar_tests.rs @@ -77,6 +77,18 @@ async fn init_test_project( fn setup_sidebar( multi_workspace: &Entity, cx: &mut gpui::VisualTestContext, +) -> Entity { + let sidebar = setup_sidebar_closed(multi_workspace, cx); + multi_workspace.update_in(cx, |mw, window, cx| { + mw.toggle_sidebar(window, cx); + }); + cx.run_until_parked(); + sidebar +} + +fn setup_sidebar_closed( + multi_workspace: &Entity, + cx: &mut gpui::VisualTestContext, ) -> Entity { let multi_workspace = multi_workspace.clone(); let sidebar = @@ -172,16 +184,7 @@ fn save_thread_metadata( cx.run_until_parked(); } -fn open_and_focus_sidebar(sidebar: &Entity, cx: &mut gpui::VisualTestContext) { - let multi_workspace = sidebar.read_with(cx, |s, _| s.multi_workspace.upgrade()); - if let Some(multi_workspace) = multi_workspace { - multi_workspace.update_in(cx, |mw, window, cx| { - if !mw.sidebar_open() { - mw.toggle_sidebar(window, cx); - } - }); - } - cx.run_until_parked(); +fn focus_sidebar(sidebar: &Entity, cx: &mut gpui::VisualTestContext) { sidebar.update_in(cx, |_, window, cx| { cx.focus_self(window); }); @@ -544,7 +547,7 @@ async fn test_workspace_lifecycle(cx: &mut TestAppContext) { // Remove the second workspace multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[1].clone(); + let workspace = mw.workspaces().nth(1).cloned().unwrap(); mw.remove(&workspace, window, cx); }); cx.run_until_parked(); @@ -604,7 +607,7 @@ async fn test_view_more_batched_expansion(cx: &mut TestAppContext) { assert!(entries.iter().any(|e| e.contains("View More"))); // Focus and navigate to View More, then confirm to expand by one batch - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); for _ in 0..7 { cx.dispatch_action(SelectNext); } @@ -915,7 +918,7 @@ async fn test_keyboard_select_next_and_previous(cx: &mut TestAppContext) { // Entries: [header, thread3, thread2, thread1] // Focusing the sidebar does not set a selection; select_next/select_previous // handle None gracefully by starting from the first or last entry. - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); assert_eq!(sidebar.read_with(cx, |s, _| s.selection), None); // First SelectNext from None starts at index 0 @@ -970,7 +973,7 @@ async fn test_keyboard_select_first_and_last(cx: &mut TestAppContext) { multi_workspace.update_in(cx, |_, _window, cx| cx.notify()); cx.run_until_parked(); - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); // SelectLast jumps to the end cx.dispatch_action(SelectLast); @@ -993,7 +996,7 @@ async fn test_keyboard_focus_in_does_not_set_selection(cx: &mut TestAppContext) // Open the sidebar so it's rendered, then focus it to trigger focus_in. // focus_in no longer sets a default selection. - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); assert_eq!(sidebar.read_with(cx, |s, _| s.selection), None); // Manually set a selection, blur, then refocus — selection should be preserved @@ -1030,7 +1033,7 @@ async fn test_keyboard_confirm_on_project_header_toggles_collapse(cx: &mut TestA ); // Focus the sidebar and select the header (index 0) - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); sidebar.update_in(cx, |sidebar, _window, _cx| { sidebar.selection = Some(0); }); @@ -1071,7 +1074,7 @@ async fn test_keyboard_confirm_on_view_more_expands(cx: &mut TestAppContext) { assert!(entries.iter().any(|e| e.contains("View More"))); // Focus sidebar (selection starts at None), then navigate down to the "View More" entry (index 6) - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); for _ in 0..7 { cx.dispatch_action(SelectNext); } @@ -1105,7 +1108,7 @@ async fn test_keyboard_expand_and_collapse_selected_entry(cx: &mut TestAppContex ); // Focus sidebar and manually select the header (index 0). Press left to collapse. - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); sidebar.update_in(cx, |sidebar, _window, _cx| { sidebar.selection = Some(0); }); @@ -1144,7 +1147,7 @@ async fn test_keyboard_collapse_from_child_selects_parent(cx: &mut TestAppContex cx.run_until_parked(); // Focus sidebar (selection starts at None), then navigate down to the thread (child) - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); cx.dispatch_action(SelectNext); cx.dispatch_action(SelectNext); assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(1)); @@ -1179,7 +1182,7 @@ async fn test_keyboard_navigation_on_empty_list(cx: &mut TestAppContext) { ); // Focus sidebar — focus_in does not set a selection - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); assert_eq!(sidebar.read_with(cx, |s, _| s.selection), None); // First SelectNext from None starts at index 0 (header) @@ -1211,7 +1214,7 @@ async fn test_selection_clamps_after_entry_removal(cx: &mut TestAppContext) { cx.run_until_parked(); // Focus sidebar (selection starts at None), navigate down to the thread (index 1) - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); cx.dispatch_action(SelectNext); cx.dispatch_action(SelectNext); assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(1)); @@ -1492,7 +1495,7 @@ async fn test_escape_clears_search_and_restores_full_list(cx: &mut TestAppContex ); // User types a search query to filter down. - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); type_in_search(&sidebar, "alpha", cx); assert_eq!( visible_entries_as_strings(&sidebar, cx), @@ -1540,8 +1543,9 @@ async fn test_search_only_shows_workspace_headers_with_matches(cx: &mut TestAppC }); cx.run_until_parked(); - let project_b = - multi_workspace.read_with(cx, |mw, cx| mw.workspaces()[1].read(cx).project().clone()); + let project_b = multi_workspace.read_with(cx, |mw, cx| { + mw.workspaces().nth(1).unwrap().read(cx).project().clone() + }); for (id, title, hour) in [ ("b1", "Refactor sidebar layout", 3), @@ -1621,8 +1625,9 @@ async fn test_search_matches_workspace_name(cx: &mut TestAppContext) { }); cx.run_until_parked(); - let project_b = - multi_workspace.read_with(cx, |mw, cx| mw.workspaces()[1].read(cx).project().clone()); + let project_b = multi_workspace.read_with(cx, |mw, cx| { + mw.workspaces().nth(1).unwrap().read(cx).project().clone() + }); for (id, title, hour) in [ ("b1", "Refactor sidebar layout", 3), @@ -1764,7 +1769,7 @@ async fn test_search_finds_threads_inside_collapsed_groups(cx: &mut TestAppConte // User focuses the sidebar and collapses the group using keyboard: // manually select the header, then press SelectParent to collapse. - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); sidebar.update_in(cx, |sidebar, _window, _cx| { sidebar.selection = Some(0); }); @@ -1807,7 +1812,7 @@ async fn test_search_then_keyboard_navigate_and_confirm(cx: &mut TestAppContext) } cx.run_until_parked(); - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); // User types "fix" — two threads match. type_in_search(&sidebar, "fix", cx); @@ -1856,6 +1861,13 @@ async fn test_confirm_on_historical_thread_activates_workspace(cx: &mut TestAppC }); cx.run_until_parked(); + let (workspace_0, workspace_1) = multi_workspace.read_with(cx, |mw, _| { + ( + mw.workspaces().next().unwrap().clone(), + mw.workspaces().nth(1).unwrap().clone(), + ) + }); + save_thread_metadata( acp::SessionId::new(Arc::from("hist-1")), "Historical Thread".into(), @@ -1875,13 +1887,13 @@ async fn test_confirm_on_historical_thread_activates_workspace(cx: &mut TestAppC // Switch to workspace 1 so we can verify the confirm switches back. multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[1].clone(); + let workspace = mw.workspaces().nth(1).unwrap().clone(); mw.activate(workspace, window, cx); }); cx.run_until_parked(); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()), - 1 + multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()), + workspace_1 ); // Confirm on the historical (non-live) thread at index 1. @@ -1895,8 +1907,8 @@ async fn test_confirm_on_historical_thread_activates_workspace(cx: &mut TestAppC cx.run_until_parked(); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()), - 0 + multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()), + workspace_0 ); } @@ -2037,7 +2049,8 @@ async fn test_focused_thread_tracks_user_intent(cx: &mut TestAppContext) { let panel_b = add_agent_panel(&workspace_b, cx); cx.run_until_parked(); - let workspace_a = multi_workspace.read_with(cx, |mw, _cx| mw.workspaces()[0].clone()); + let workspace_a = + multi_workspace.read_with(cx, |mw, _cx| mw.workspaces().next().unwrap().clone()); // ── 1. Initial state: focused thread derived from active panel ───── sidebar.read_with(cx, |sidebar, _cx| { @@ -2135,7 +2148,7 @@ async fn test_focused_thread_tracks_user_intent(cx: &mut TestAppContext) { }); multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); + let workspace = mw.workspaces().next().unwrap().clone(); mw.activate(workspace, window, cx); }); cx.run_until_parked(); @@ -2190,8 +2203,8 @@ async fn test_focused_thread_tracks_user_intent(cx: &mut TestAppContext) { // Switching workspaces via the multi_workspace (simulates clicking // a workspace header) should clear focused_thread. multi_workspace.update_in(cx, |mw, window, cx| { - if let Some(index) = mw.workspaces().iter().position(|w| w == &workspace_b) { - let workspace = mw.workspaces()[index].clone(); + let workspace = mw.workspaces().find(|w| *w == &workspace_b).cloned(); + if let Some(workspace) = workspace { mw.activate(workspace, window, cx); } }); @@ -2477,6 +2490,8 @@ async fn test_cmd_n_shows_new_thread_entry_in_absorbed_worktree(cx: &mut TestApp let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx)); + let sidebar = setup_sidebar(&multi_workspace, cx); + let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| { mw.test_add_workspace(worktree_project.clone(), window, cx) }); @@ -2485,12 +2500,10 @@ async fn test_cmd_n_shows_new_thread_entry_in_absorbed_worktree(cx: &mut TestApp // Switch to the worktree workspace. multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[1].clone(); + let workspace = mw.workspaces().nth(1).unwrap().clone(); mw.activate(workspace, window, cx); }); - let sidebar = setup_sidebar(&multi_workspace, cx); - // Create a non-empty thread in the worktree workspace. let connection = StubAgentConnection::new(); connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk( @@ -3027,6 +3040,8 @@ async fn test_absorbed_worktree_running_thread_shows_live_status(cx: &mut TestAp let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx)); + let sidebar = setup_sidebar(&multi_workspace, cx); + let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| { mw.test_add_workspace(worktree_project.clone(), window, cx) }); @@ -3037,12 +3052,10 @@ async fn test_absorbed_worktree_running_thread_shows_live_status(cx: &mut TestAp // Switch back to the main workspace before setting up the sidebar. multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); + let workspace = mw.workspaces().next().unwrap().clone(); mw.activate(workspace, window, cx); }); - let sidebar = setup_sidebar(&multi_workspace, cx); - // Start a thread in the worktree workspace's panel and keep it // generating (don't resolve it). let connection = StubAgentConnection::new(); @@ -3127,6 +3140,8 @@ async fn test_absorbed_worktree_completion_triggers_notification(cx: &mut TestAp let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx)); + let sidebar = setup_sidebar(&multi_workspace, cx); + let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| { mw.test_add_workspace(worktree_project.clone(), window, cx) }); @@ -3134,12 +3149,10 @@ async fn test_absorbed_worktree_completion_triggers_notification(cx: &mut TestAp let worktree_panel = add_agent_panel(&worktree_workspace, cx); multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); + let workspace = mw.workspaces().next().unwrap().clone(); mw.activate(workspace, window, cx); }); - let sidebar = setup_sidebar(&multi_workspace, cx); - let connection = StubAgentConnection::new(); open_thread_with_connection(&worktree_panel, connection.clone(), cx); send_message(&worktree_panel, cx); @@ -3231,12 +3244,12 @@ async fn test_clicking_worktree_thread_opens_workspace_when_none_exists(cx: &mut // Only 1 workspace should exist. assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.workspaces().len()), + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), 1, ); // Focus the sidebar and select the worktree thread. - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); sidebar.update_in(cx, |sidebar, _window, _cx| { sidebar.selection = Some(1); // index 0 is header, 1 is the thread }); @@ -3248,11 +3261,11 @@ async fn test_clicking_worktree_thread_opens_workspace_when_none_exists(cx: &mut // A new workspace should have been created for the worktree path. let new_workspace = multi_workspace.read_with(cx, |mw, _| { assert_eq!( - mw.workspaces().len(), + mw.workspaces().count(), 2, "confirming a worktree thread without a workspace should open one", ); - mw.workspaces()[1].clone() + mw.workspaces().nth(1).unwrap().clone() }); let new_path_list = @@ -3318,7 +3331,7 @@ async fn test_clicking_worktree_thread_does_not_briefly_render_as_separate_proje vec!["v [project]", " WT Thread {wt-feature-a}"], ); - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); sidebar.update_in(cx, |sidebar, _window, _cx| { sidebar.selection = Some(1); // index 0 is header, 1 is the thread }); @@ -3444,18 +3457,19 @@ async fn test_clicking_absorbed_worktree_thread_activates_worktree_workspace( let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx)); + let sidebar = setup_sidebar(&multi_workspace, cx); + let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| { mw.test_add_workspace(worktree_project.clone(), window, cx) }); // Activate the main workspace before setting up the sidebar. - multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); - mw.activate(workspace, window, cx); + let main_workspace = multi_workspace.update_in(cx, |mw, window, cx| { + let workspace = mw.workspaces().next().unwrap().clone(); + mw.activate(workspace.clone(), window, cx); + workspace }); - let sidebar = setup_sidebar(&multi_workspace, cx); - save_named_thread_metadata("thread-main", "Main Thread", &main_project, cx).await; save_named_thread_metadata("thread-wt", "WT Thread", &worktree_project, cx).await; @@ -3475,13 +3489,13 @@ async fn test_clicking_absorbed_worktree_thread_activates_worktree_workspace( .expect("should find the worktree thread entry"); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()), - 0, + multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()), + main_workspace, "main workspace should be active initially" ); // Focus the sidebar and select the absorbed worktree thread. - open_and_focus_sidebar(&sidebar, cx); + focus_sidebar(&sidebar, cx); sidebar.update_in(cx, |sidebar, _window, _cx| { sidebar.selection = Some(wt_thread_index); }); @@ -3491,9 +3505,7 @@ async fn test_clicking_absorbed_worktree_thread_activates_worktree_workspace( cx.run_until_parked(); // The worktree workspace should now be active, not the main one. - let active_workspace = multi_workspace.read_with(cx, |mw, _| { - mw.workspaces()[mw.active_workspace_index()].clone() - }); + let active_workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()); assert_eq!( active_workspace, worktree_workspace, "clicking an absorbed worktree thread should activate the worktree workspace" @@ -3520,25 +3532,27 @@ async fn test_activate_archived_thread_with_saved_paths_activates_matching_works let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a.clone(), window, cx)); - multi_workspace.update_in(cx, |mw, window, cx| { - mw.test_add_workspace(project_b.clone(), window, cx); - }); - let sidebar = setup_sidebar(&multi_workspace, cx); + let workspace_b = multi_workspace.update_in(cx, |mw, window, cx| { + mw.test_add_workspace(project_b.clone(), window, cx) + }); + let workspace_a = + multi_workspace.read_with(cx, |mw, _| mw.workspaces().next().unwrap().clone()); + // Save a thread with path_list pointing to project-b. let session_id = acp::SessionId::new(Arc::from("archived-1")); save_test_thread_metadata(&session_id, &project_b, cx).await; // Ensure workspace A is active. multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); + let workspace = mw.workspaces().next().unwrap().clone(); mw.activate(workspace, window, cx); }); cx.run_until_parked(); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()), - 0 + multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()), + workspace_a ); // Call activate_archived_thread – should resolve saved paths and @@ -3562,8 +3576,8 @@ async fn test_activate_archived_thread_with_saved_paths_activates_matching_works cx.run_until_parked(); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()), - 1, + multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()), + workspace_b, "should have activated the workspace matching the saved path_list" ); } @@ -3588,21 +3602,23 @@ async fn test_activate_archived_thread_cwd_fallback_with_matching_workspace( let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx)); - multi_workspace.update_in(cx, |mw, window, cx| { - mw.test_add_workspace(project_b, window, cx); - }); - let sidebar = setup_sidebar(&multi_workspace, cx); + let workspace_b = multi_workspace.update_in(cx, |mw, window, cx| { + mw.test_add_workspace(project_b, window, cx) + }); + let workspace_a = + multi_workspace.read_with(cx, |mw, _| mw.workspaces().next().unwrap().clone()); + // Start with workspace A active. multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); + let workspace = mw.workspaces().next().unwrap().clone(); mw.activate(workspace, window, cx); }); cx.run_until_parked(); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()), - 0 + multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()), + workspace_a ); // No thread saved to the store – cwd is the only path hint. @@ -3625,8 +3641,8 @@ async fn test_activate_archived_thread_cwd_fallback_with_matching_workspace( cx.run_until_parked(); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()), - 1, + multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()), + workspace_b, "should have activated the workspace matching the cwd" ); } @@ -3651,21 +3667,21 @@ async fn test_activate_archived_thread_no_paths_no_cwd_uses_active_workspace( let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx)); - multi_workspace.update_in(cx, |mw, window, cx| { - mw.test_add_workspace(project_b, window, cx); - }); - let sidebar = setup_sidebar(&multi_workspace, cx); + let workspace_b = multi_workspace.update_in(cx, |mw, window, cx| { + mw.test_add_workspace(project_b, window, cx) + }); + // Activate workspace B (index 1) to make it the active one. multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[1].clone(); + let workspace = mw.workspaces().nth(1).unwrap().clone(); mw.activate(workspace, window, cx); }); cx.run_until_parked(); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()), - 1 + multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()), + workspace_b ); // No saved thread, no cwd – should fall back to the active workspace. @@ -3688,8 +3704,8 @@ async fn test_activate_archived_thread_no_paths_no_cwd_uses_active_workspace( cx.run_until_parked(); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()), - 1, + multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()), + workspace_b, "should have stayed on the active workspace when no path info is available" ); } @@ -3719,7 +3735,7 @@ async fn test_activate_archived_thread_saved_paths_opens_new_workspace(cx: &mut let session_id = acp::SessionId::new(Arc::from("archived-new-ws")); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.workspaces().len()), + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), 1, "should start with one workspace" ); @@ -3743,7 +3759,7 @@ async fn test_activate_archived_thread_saved_paths_opens_new_workspace(cx: &mut cx.run_until_parked(); assert_eq!( - multi_workspace.read_with(cx, |mw, _| mw.workspaces().len()), + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), 2, "should have opened a second workspace for the archived thread's saved paths" ); @@ -3768,6 +3784,10 @@ async fn test_activate_archived_thread_reuses_workspace_in_another_window(cx: &m cx.add_window(|window, cx| MultiWorkspace::test_new(project_b, window, cx)); let multi_workspace_a_entity = multi_workspace_a.root(cx).unwrap(); + let multi_workspace_b_entity = multi_workspace_b.root(cx).unwrap(); + + let cx_b = &mut gpui::VisualTestContext::from_window(multi_workspace_b.into(), cx); + let _sidebar_b = setup_sidebar(&multi_workspace_b_entity, cx_b); let cx_a = &mut gpui::VisualTestContext::from_window(multi_workspace_a.into(), cx); let sidebar = setup_sidebar(&multi_workspace_a_entity, cx_a); @@ -3794,14 +3814,14 @@ async fn test_activate_archived_thread_reuses_workspace_in_another_window(cx: &m assert_eq!( multi_workspace_a - .read_with(cx_a, |mw, _| mw.workspaces().len()) + .read_with(cx_a, |mw, _| mw.workspaces().count()) .unwrap(), 1, "should not add the other window's workspace into the current window" ); assert_eq!( multi_workspace_b - .read_with(cx_a, |mw, _| mw.workspaces().len()) + .read_with(cx_a, |mw, _| mw.workspaces().count()) .unwrap(), 1, "should reuse the existing workspace in the other window" @@ -3871,14 +3891,14 @@ async fn test_activate_archived_thread_reuses_workspace_in_another_window_with_t assert_eq!( multi_workspace_a - .read_with(cx_a, |mw, _| mw.workspaces().len()) + .read_with(cx_a, |mw, _| mw.workspaces().count()) .unwrap(), 1, "should not add the other window's workspace into the current window" ); assert_eq!( multi_workspace_b - .read_with(cx_a, |mw, _| mw.workspaces().len()) + .read_with(cx_a, |mw, _| mw.workspaces().count()) .unwrap(), 1, "should reuse the existing workspace in the other window" @@ -3921,6 +3941,10 @@ async fn test_activate_archived_thread_prefers_current_window_for_matching_paths cx.add_window(|window, cx| MultiWorkspace::test_new(project_a, window, cx)); let multi_workspace_a_entity = multi_workspace_a.root(cx).unwrap(); + let multi_workspace_b_entity = multi_workspace_b.root(cx).unwrap(); + + let cx_b = &mut gpui::VisualTestContext::from_window(multi_workspace_b.into(), cx); + let _sidebar_b = setup_sidebar(&multi_workspace_b_entity, cx_b); let cx_a = &mut gpui::VisualTestContext::from_window(multi_workspace_a.into(), cx); let sidebar_a = setup_sidebar(&multi_workspace_a_entity, cx_a); @@ -3958,14 +3982,14 @@ async fn test_activate_archived_thread_prefers_current_window_for_matching_paths }); assert_eq!( multi_workspace_a - .read_with(cx_a, |mw, _| mw.workspaces().len()) + .read_with(cx_a, |mw, _| mw.workspaces().count()) .unwrap(), 1, "current window should continue reusing its existing workspace" ); assert_eq!( multi_workspace_b - .read_with(cx_a, |mw, _| mw.workspaces().len()) + .read_with(cx_a, |mw, _| mw.workspaces().count()) .unwrap(), 1, "other windows should not be activated just because they also match the saved paths" @@ -4029,19 +4053,20 @@ async fn test_archive_thread_uses_next_threads_own_workspace(cx: &mut TestAppCon let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx)); + let sidebar = setup_sidebar(&multi_workspace, cx); + let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| { mw.test_add_workspace(worktree_project.clone(), window, cx) }); // Activate main workspace so the sidebar tracks the main panel. multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); + let workspace = mw.workspaces().next().unwrap().clone(); mw.activate(workspace, window, cx); }); - let sidebar = setup_sidebar(&multi_workspace, cx); - - let main_workspace = multi_workspace.read_with(cx, |mw, _| mw.workspaces()[0].clone()); + let main_workspace = + multi_workspace.read_with(cx, |mw, _| mw.workspaces().next().unwrap().clone()); let main_panel = add_agent_panel(&main_workspace, cx); let _worktree_panel = add_agent_panel(&worktree_workspace, cx); @@ -4195,10 +4220,10 @@ async fn test_linked_worktree_threads_not_duplicated_across_groups(cx: &mut Test let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_only.clone(), window, cx)); + let sidebar = setup_sidebar(&multi_workspace, cx); multi_workspace.update_in(cx, |mw, window, cx| { mw.test_add_workspace(multi_root.clone(), window, cx); }); - let sidebar = setup_sidebar(&multi_workspace, cx); // Save a thread under the linked worktree path. save_named_thread_metadata("wt-thread", "Worktree Thread", &worktree_project, cx).await; @@ -4313,8 +4338,8 @@ async fn test_thread_switcher_ordering(cx: &mut TestAppContext) { // so all three have last_accessed_at set. // Access order is: A (most recent), B, C (oldest). - // ── 1. Open switcher: threads sorted by last_accessed_at ─────────── - open_and_focus_sidebar(&sidebar, cx); + // ── 1. Open switcher: threads sorted by last_accessed_at ───────────────── + focus_sidebar(&sidebar, cx); sidebar.update_in(cx, |sidebar, window, cx| { sidebar.on_toggle_thread_switcher(&ToggleThreadSwitcher::default(), window, cx); }); @@ -4759,6 +4784,284 @@ async fn test_linked_worktree_workspace_shows_main_worktree_threads(cx: &mut Tes ); } +async fn init_multi_project_test( + paths: &[&str], + cx: &mut TestAppContext, +) -> (Arc, Entity) { + agent_ui::test_support::init_test(cx); + cx.update(|cx| { + cx.update_flags(false, vec!["agent-v2".into()]); + ThreadStore::init_global(cx); + ThreadMetadataStore::init_global(cx); + language_model::LanguageModelRegistry::test(cx); + prompt_store::init(cx); + }); + let fs = FakeFs::new(cx.executor()); + for path in paths { + fs.insert_tree(path, serde_json::json!({ ".git": {}, "src": {} })) + .await; + } + cx.update(|cx| ::set_global(fs.clone(), cx)); + let project = + project::Project::test(fs.clone() as Arc, [paths[0].as_ref()], cx).await; + (fs, project) +} + +async fn add_test_project( + path: &str, + fs: &Arc, + multi_workspace: &Entity, + cx: &mut gpui::VisualTestContext, +) -> Entity { + let project = project::Project::test(fs.clone() as Arc, [path.as_ref()], cx).await; + let workspace = multi_workspace.update_in(cx, |mw, window, cx| { + mw.test_add_workspace(project, window, cx) + }); + cx.run_until_parked(); + workspace +} + +#[gpui::test] +async fn test_transient_workspace_lifecycle(cx: &mut TestAppContext) { + let (fs, project_a) = + init_multi_project_test(&["/project-a", "/project-b", "/project-c"], cx).await; + let (multi_workspace, cx) = + cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx)); + let _sidebar = setup_sidebar_closed(&multi_workspace, cx); + + // Sidebar starts closed. Initial workspace A is transient. + let workspace_a = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()); + assert!(!multi_workspace.read_with(cx, |mw, _| mw.sidebar_open())); + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 1 + ); + assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_a)); + + // Add B — replaces A as the transient workspace. + let workspace_b = add_test_project("/project-b", &fs, &multi_workspace, cx).await; + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 1 + ); + assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_b)); + + // Add C — replaces B as the transient workspace. + let workspace_c = add_test_project("/project-c", &fs, &multi_workspace, cx).await; + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 1 + ); + assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_c)); +} + +#[gpui::test] +async fn test_transient_workspace_retained(cx: &mut TestAppContext) { + let (fs, project_a) = init_multi_project_test( + &["/project-a", "/project-b", "/project-c", "/project-d"], + cx, + ) + .await; + let (multi_workspace, cx) = + cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx)); + let _sidebar = setup_sidebar(&multi_workspace, cx); + assert!(multi_workspace.read_with(cx, |mw, _| mw.sidebar_open())); + + // Add B — retained since sidebar is open. + let workspace_a = add_test_project("/project-b", &fs, &multi_workspace, cx).await; + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 2 + ); + + // Switch to A — B survives. (Switching from one internal workspace, to another) + multi_workspace.update_in(cx, |mw, window, cx| mw.activate(workspace_a, window, cx)); + cx.run_until_parked(); + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 2 + ); + + // Close sidebar — both A and B remain retained. + multi_workspace.update_in(cx, |mw, window, cx| mw.close_sidebar(window, cx)); + cx.run_until_parked(); + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 2 + ); + + // Add C — added as new transient workspace. (switching from retained, to transient) + let workspace_c = add_test_project("/project-c", &fs, &multi_workspace, cx).await; + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 3 + ); + assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_c)); + + // Add D — replaces C as the transient workspace (Have retained and transient workspaces, transient workspace is dropped) + let workspace_d = add_test_project("/project-d", &fs, &multi_workspace, cx).await; + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 3 + ); + assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_d)); +} + +#[gpui::test] +async fn test_transient_workspace_promotion(cx: &mut TestAppContext) { + let (fs, project_a) = + init_multi_project_test(&["/project-a", "/project-b", "/project-c"], cx).await; + let (multi_workspace, cx) = + cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx)); + setup_sidebar_closed(&multi_workspace, cx); + + // Add B — replaces A as the transient workspace (A is discarded). + let workspace_b = add_test_project("/project-b", &fs, &multi_workspace, cx).await; + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 1 + ); + assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_b)); + + // Open sidebar — promotes the transient B to retained. + multi_workspace.update_in(cx, |mw, window, cx| { + mw.toggle_sidebar(window, cx); + }); + cx.run_until_parked(); + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 1 + ); + assert!(multi_workspace.read_with(cx, |mw, _| mw.workspaces().any(|w| w == &workspace_b))); + + // Close sidebar — the retained B remains. + multi_workspace.update_in(cx, |mw, window, cx| { + mw.toggle_sidebar(window, cx); + }); + + // Add C — added as new transient workspace. + let workspace_c = add_test_project("/project-c", &fs, &multi_workspace, cx).await; + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 2 + ); + assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_c)); +} + +#[gpui::test] +async fn test_legacy_thread_with_canonical_path_opens_main_repo_workspace(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + + fs.insert_tree( + "/project", + serde_json::json!({ + ".git": { + "worktrees": { + "feature-a": { + "commondir": "../../", + "HEAD": "ref: refs/heads/feature-a", + }, + }, + }, + "src": {}, + }), + ) + .await; + + fs.insert_tree( + "/wt-feature-a", + serde_json::json!({ + ".git": "gitdir: /project/.git/worktrees/feature-a", + "src": {}, + }), + ) + .await; + + fs.add_linked_worktree_for_repo( + Path::new("/project/.git"), + false, + git::repository::Worktree { + path: PathBuf::from("/wt-feature-a"), + ref_name: Some("refs/heads/feature-a".into()), + sha: "abc".into(), + is_main: false, + }, + ) + .await; + + cx.update(|cx| ::set_global(fs.clone(), cx)); + + // Only a linked worktree workspace is open — no workspace for /project. + let worktree_project = project::Project::test(fs.clone(), ["/wt-feature-a".as_ref()], cx).await; + worktree_project + .update(cx, |p, cx| p.git_scans_complete(cx)) + .await; + + let (multi_workspace, cx) = cx.add_window_view(|window, cx| { + MultiWorkspace::test_new(worktree_project.clone(), window, cx) + }); + let sidebar = setup_sidebar(&multi_workspace, cx); + + // Save a legacy thread: folder_paths = main repo, main_worktree_paths = empty. + let legacy_session = acp::SessionId::new(Arc::from("legacy-main-thread")); + cx.update(|_, cx| { + let metadata = ThreadMetadata { + session_id: legacy_session.clone(), + agent_id: agent::ZED_AGENT_ID.clone(), + title: "Legacy Main Thread".into(), + updated_at: chrono::TimeZone::with_ymd_and_hms(&Utc, 2024, 1, 1, 0, 0, 0).unwrap(), + created_at: None, + folder_paths: PathList::new(&[PathBuf::from("/project")]), + main_worktree_paths: PathList::default(), + archived: false, + }; + ThreadMetadataStore::global(cx).update(cx, |store, cx| store.save_manually(metadata, cx)); + }); + cx.run_until_parked(); + + multi_workspace.update_in(cx, |_, _window, cx| cx.notify()); + cx.run_until_parked(); + + // The legacy thread should appear in the sidebar under the project group. + let entries = visible_entries_as_strings(&sidebar, cx); + assert!( + entries.iter().any(|e| e.contains("Legacy Main Thread")), + "legacy thread should be visible: {entries:?}", + ); + + // Verify only 1 workspace before clicking. + assert_eq!( + multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()), + 1, + ); + + // Focus and select the legacy thread, then confirm. + focus_sidebar(&sidebar, cx); + let thread_index = sidebar.read_with(cx, |sidebar, _| { + sidebar + .contents + .entries + .iter() + .position(|e| e.session_id().is_some_and(|id| id == &legacy_session)) + .expect("legacy thread should be in entries") + }); + sidebar.update_in(cx, |sidebar, _window, _cx| { + sidebar.selection = Some(thread_index); + }); + cx.dispatch_action(Confirm); + cx.run_until_parked(); + + let new_workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()); + let new_path_list = + new_workspace.read_with(cx, |_, cx| workspace_path_list(&new_workspace, cx)); + assert_eq!( + new_path_list, + PathList::new(&[PathBuf::from("/project")]), + "the new workspace should be for the main repo, not the linked worktree", + ); +} + mod property_test { use super::*; @@ -4943,7 +5246,12 @@ mod property_test { match operation { Operation::SaveThread { workspace_index } => { let project = multi_workspace.read_with(cx, |mw, cx| { - mw.workspaces()[workspace_index].read(cx).project().clone() + mw.workspaces() + .nth(workspace_index) + .unwrap() + .read(cx) + .project() + .clone() }); save_thread_to_path(state, &project, cx); } @@ -5030,7 +5338,7 @@ mod property_test { } Operation::RemoveWorkspace { index } => { let removed = multi_workspace.update_in(cx, |mw, window, cx| { - let workspace = mw.workspaces()[index].clone(); + let workspace = mw.workspaces().nth(index).unwrap().clone(); mw.remove(&workspace, window, cx) }); if removed { @@ -5044,8 +5352,8 @@ mod property_test { } } Operation::SwitchWorkspace { index } => { - let workspace = - multi_workspace.read_with(cx, |mw, _| mw.workspaces()[index].clone()); + let workspace = multi_workspace + .read_with(cx, |mw, _| mw.workspaces().nth(index).unwrap().clone()); multi_workspace.update_in(cx, |mw, window, cx| { mw.activate(workspace, window, cx); }); @@ -5095,8 +5403,9 @@ mod property_test { .await; // Re-scan the main workspace's project so it discovers the new worktree. - let main_workspace = - multi_workspace.read_with(cx, |mw, _| mw.workspaces()[workspace_index].clone()); + let main_workspace = multi_workspace.read_with(cx, |mw, _| { + mw.workspaces().nth(workspace_index).unwrap().clone() + }); let main_project = main_workspace.read_with(cx, |ws, _| ws.project().clone()); main_project .update(cx, |p, cx| p.git_scans_complete(cx)) @@ -5183,7 +5492,11 @@ mod property_test { let Some(multi_workspace) = sidebar.multi_workspace.upgrade() else { anyhow::bail!("sidebar should still have an associated multi-workspace"); }; - let workspaces = multi_workspace.read(cx).workspaces().to_vec(); + let workspaces = multi_workspace + .read(cx) + .workspaces() + .cloned() + .collect::>(); let thread_store = ThreadMetadataStore::global(cx); let sidebar_thread_ids: HashSet = sidebar diff --git a/crates/terminal_view/src/terminal_view.rs b/crates/terminal_view/src/terminal_view.rs index 3ecc6c844db834da91e2f24c3f0cf2d460b5f246..acccd6129f75ee2f5213fa359203220a7fee08c0 100644 --- a/crates/terminal_view/src/terminal_view.rs +++ b/crates/terminal_view/src/terminal_view.rs @@ -850,6 +850,7 @@ impl TerminalView { fn send_text(&mut self, text: &SendText, _: &mut Window, cx: &mut Context) { self.clear_bell(cx); + self.blink_manager.update(cx, BlinkManager::pause_blinking); self.terminal.update(cx, |term, _| { term.input(text.0.to_string().into_bytes()); }); @@ -858,6 +859,7 @@ impl TerminalView { fn send_keystroke(&mut self, text: &SendKeystroke, _: &mut Window, cx: &mut Context) { if let Some(keystroke) = Keystroke::parse(&text.0).log_err() { self.clear_bell(cx); + self.blink_manager.update(cx, BlinkManager::pause_blinking); self.process_keystroke(&keystroke, cx); } } diff --git a/crates/title_bar/src/title_bar.rs b/crates/title_bar/src/title_bar.rs index 440249907adb6d29602ad8e950d0fd26a2d1c31d..dfcd933dc20df9a6f6643402719f2ec1143cc7fe 100644 --- a/crates/title_bar/src/title_bar.rs +++ b/crates/title_bar/src/title_bar.rs @@ -740,7 +740,6 @@ impl TitleBar { .map(|mw| { mw.read(cx) .workspaces() - .iter() .filter_map(|ws| ws.read(cx).database_id()) .collect() }) @@ -803,7 +802,6 @@ impl TitleBar { .map(|mw| { mw.read(cx) .workspaces() - .iter() .filter_map(|ws| ws.read(cx).database_id()) .collect() }) diff --git a/crates/vim/src/motion.rs b/crates/vim/src/motion.rs index 6bf2afd09ae07ff8453a481a8d6e6e6a254e670f..6e992704f54bf7aba3cc775d906a90281234dbd0 100644 --- a/crates/vim/src/motion.rs +++ b/crates/vim/src/motion.rs @@ -7,7 +7,7 @@ use editor::{ }, }; use gpui::{Action, Context, Window, actions, px}; -use language::{CharKind, Point, Selection, SelectionGoal}; +use language::{CharKind, Point, Selection, SelectionGoal, TextObject, TreeSitterOptions}; use multi_buffer::MultiBufferRow; use schemars::JsonSchema; use serde::Deserialize; @@ -2451,6 +2451,10 @@ fn find_matching_bracket_text_based( .take_while(|(_, char_offset)| *char_offset < line_range.end) .find_map(|(ch, char_offset)| get_bracket_pair(ch).map(|info| (info, char_offset))); + if bracket_info.is_none() { + return find_matching_c_preprocessor_directive(map, line_range); + } + let (open, close, is_opening) = bracket_info?.0; let bracket_offset = bracket_info?.1; @@ -2482,6 +2486,122 @@ fn find_matching_bracket_text_based( None } +fn find_matching_c_preprocessor_directive( + map: &DisplaySnapshot, + line_range: Range, +) -> Option { + let line_start = map + .buffer_chars_at(line_range.start) + .skip_while(|(c, _)| *c == ' ' || *c == '\t') + .map(|(c, _)| c) + .take(6) + .collect::(); + + if line_start.starts_with("#if") + || line_start.starts_with("#else") + || line_start.starts_with("#elif") + { + let mut depth = 0i32; + for (ch, char_offset) in map.buffer_chars_at(line_range.end) { + if ch != '\n' { + continue; + } + let mut line_offset = char_offset + '\n'.len_utf8(); + + // Skip leading whitespace + map.buffer_chars_at(line_offset) + .take_while(|(c, _)| *c == ' ' || *c == '\t') + .for_each(|(_, _)| line_offset += 1); + + // Check what directive starts the next line + let next_line_start = map + .buffer_chars_at(line_offset) + .map(|(c, _)| c) + .take(6) + .collect::(); + + if next_line_start.starts_with("#if") { + depth += 1; + } else if next_line_start.starts_with("#endif") { + if depth > 0 { + depth -= 1; + } else { + return Some(line_offset); + } + } else if next_line_start.starts_with("#else") || next_line_start.starts_with("#elif") { + if depth == 0 { + return Some(line_offset); + } + } + } + } else if line_start.starts_with("#endif") { + let mut depth = 0i32; + for (ch, char_offset) in + map.reverse_buffer_chars_at(line_range.start.saturating_sub_usize(1)) + { + let mut line_offset = if char_offset == MultiBufferOffset(0) { + MultiBufferOffset(0) + } else if ch != '\n' { + continue; + } else { + char_offset + '\n'.len_utf8() + }; + + // Skip leading whitespace + map.buffer_chars_at(line_offset) + .take_while(|(c, _)| *c == ' ' || *c == '\t') + .for_each(|(_, _)| line_offset += 1); + + // Check what directive starts this line + let line_start = map + .buffer_chars_at(line_offset) + .skip_while(|(c, _)| *c == ' ' || *c == '\t') + .map(|(c, _)| c) + .take(6) + .collect::(); + + if line_start.starts_with("\n\n") { + // empty line + continue; + } else if line_start.starts_with("#endif") { + depth += 1; + } else if line_start.starts_with("#if") { + if depth > 0 { + depth -= 1; + } else { + return Some(line_offset); + } + } + } + } + None +} + +fn comment_delimiter_pair( + map: &DisplaySnapshot, + offset: MultiBufferOffset, +) -> Option<(Range, Range)> { + let snapshot = map.buffer_snapshot(); + snapshot + .text_object_ranges(offset..offset, TreeSitterOptions::default()) + .find_map(|(range, obj)| { + if !matches!(obj, TextObject::InsideComment | TextObject::AroundComment) + || !range.contains(&offset) + { + return None; + } + + let mut chars = snapshot.chars_at(range.start); + if (Some('/'), Some('*')) != (chars.next(), chars.next()) { + return None; + } + + let open_range = range.start..range.start + 2usize; + let close_range = range.end - 2..range.end; + Some((open_range, close_range)) + }) +} + fn matching( map: &DisplaySnapshot, display_point: DisplayPoint, @@ -2609,6 +2729,32 @@ fn matching( continue; } + if let Some((open_range, close_range)) = comment_delimiter_pair(map, offset) { + if open_range.contains(&offset) { + return close_range.start.to_display_point(map); + } + + if close_range.contains(&offset) { + return open_range.start.to_display_point(map); + } + + let open_candidate = (open_range.start >= offset + && line_range.contains(&open_range.start)) + .then_some((open_range.start.saturating_sub(offset), close_range.start)); + + let close_candidate = (close_range.start >= offset + && line_range.contains(&close_range.start)) + .then_some((close_range.start.saturating_sub(offset), open_range.start)); + + if let Some((_, destination)) = [open_candidate, close_candidate] + .into_iter() + .flatten() + .min_by_key(|(distance, _)| *distance) + { + return destination.to_display_point(map); + } + } + closest_pair_destination .map(|destination| destination.to_display_point(map)) .unwrap_or_else(|| { @@ -3497,6 +3643,119 @@ mod test { ); } + #[gpui::test] + async fn test_matching_comments(cx: &mut gpui::TestAppContext) { + let mut cx = NeovimBackedTestContext::new(cx).await; + + cx.set_shared_state(indoc! {r"ˇ/* + this is a comment + */"}) + .await; + cx.simulate_shared_keystrokes("%").await; + cx.shared_state().await.assert_eq(indoc! {r"/* + this is a comment + ˇ*/"}); + cx.simulate_shared_keystrokes("%").await; + cx.shared_state().await.assert_eq(indoc! {r"ˇ/* + this is a comment + */"}); + cx.simulate_shared_keystrokes("%").await; + cx.shared_state().await.assert_eq(indoc! {r"/* + this is a comment + ˇ*/"}); + + cx.set_shared_state("ˇ// comment").await; + cx.simulate_shared_keystrokes("%").await; + cx.shared_state().await.assert_eq("ˇ// comment"); + } + + #[gpui::test] + async fn test_matching_preprocessor_directives(cx: &mut gpui::TestAppContext) { + let mut cx = NeovimBackedTestContext::new(cx).await; + + cx.set_shared_state(indoc! {r"#ˇif + + #else + + #endif + "}) + .await; + cx.simulate_shared_keystrokes("%").await; + cx.shared_state().await.assert_eq(indoc! {r"#if + + ˇ#else + + #endif + "}); + + cx.simulate_shared_keystrokes("%").await; + cx.shared_state().await.assert_eq(indoc! {r"#if + + #else + + ˇ#endif + "}); + + cx.simulate_shared_keystrokes("%").await; + cx.shared_state().await.assert_eq(indoc! {r"ˇ#if + + #else + + #endif + "}); + + cx.set_shared_state(indoc! {r" + #ˇif + #if + + #else + + #endif + + #else + #endif + "}) + .await; + + cx.simulate_shared_keystrokes("%").await; + cx.shared_state().await.assert_eq(indoc! {r" + #if + #if + + #else + + #endif + + ˇ#else + #endif + "}); + + cx.simulate_shared_keystrokes("% %").await; + cx.shared_state().await.assert_eq(indoc! {r" + ˇ#if + #if + + #else + + #endif + + #else + #endif + "}); + cx.simulate_shared_keystrokes("j % % %").await; + cx.shared_state().await.assert_eq(indoc! {r" + #if + ˇ#if + + #else + + #endif + + #else + #endif + "}); + } + #[gpui::test] async fn test_unmatched_forward(cx: &mut gpui::TestAppContext) { let mut cx = NeovimBackedTestContext::new(cx).await; diff --git a/crates/vim/src/normal/search.rs b/crates/vim/src/normal/search.rs index 6a8394f44710b7e241b7ba38f4913899a5afbce6..22c453c877ec89fdbf432d19d89167285b78b12f 100644 --- a/crates/vim/src/normal/search.rs +++ b/crates/vim/src/normal/search.rs @@ -245,7 +245,7 @@ impl Vim { search_bar.set_replacement(None, cx); let mut options = SearchOptions::NONE; - if action.regex { + if action.regex && VimSettings::get_global(cx).use_regex_search { options |= SearchOptions::REGEX; } if action.backwards { @@ -1446,4 +1446,66 @@ mod test { // The cursor should be at the match location on line 3 (row 2). cx.assert_state("hello world\nfoo bar\nhello ˇagain\n", Mode::Normal); } + + #[gpui::test] + async fn test_vim_search_respects_search_settings(cx: &mut gpui::TestAppContext) { + let mut cx = VimTestContext::new(cx, true).await; + + cx.update_global(|store: &mut SettingsStore, cx| { + store.update_user_settings(cx, |settings| { + settings.vim.get_or_insert_default().use_regex_search = Some(false); + }); + }); + + cx.set_state("ˇcontent", Mode::Normal); + cx.simulate_keystrokes("/"); + cx.run_until_parked(); + + // Verify search options are set from settings + let search_bar = cx.workspace(|workspace, _, cx| { + workspace + .active_pane() + .read(cx) + .toolbar() + .read(cx) + .item_of_type::() + .expect("Buffer search bar should be active") + }); + + cx.update_entity(search_bar, |bar, _window, _cx| { + assert!( + !bar.has_search_option(search::SearchOptions::REGEX), + "Vim search open without regex mode" + ); + }); + + cx.simulate_keystrokes("escape"); + cx.run_until_parked(); + + cx.update_global(|store: &mut SettingsStore, cx| { + store.update_user_settings(cx, |settings| { + settings.vim.get_or_insert_default().use_regex_search = Some(true); + }); + }); + + cx.simulate_keystrokes("/"); + cx.run_until_parked(); + + let search_bar = cx.workspace(|workspace, _, cx| { + workspace + .active_pane() + .read(cx) + .toolbar() + .read(cx) + .item_of_type::() + .expect("Buffer search bar should be active") + }); + + cx.update_entity(search_bar, |bar, _window, _cx| { + assert!( + bar.has_search_option(search::SearchOptions::REGEX), + "Vim search opens with regex mode" + ); + }); + } } diff --git a/crates/vim/src/vim.rs b/crates/vim/src/vim.rs index 6e1849340f17b776a34546dd9a118dc55e8dab84..a66111cae1576744c4c51d717984d67c12fc8235 100644 --- a/crates/vim/src/vim.rs +++ b/crates/vim/src/vim.rs @@ -2141,6 +2141,7 @@ struct VimSettings { pub toggle_relative_line_numbers: bool, pub use_system_clipboard: settings::UseSystemClipboard, pub use_smartcase_find: bool, + pub use_regex_search: bool, pub gdefault: bool, pub custom_digraphs: HashMap>, pub highlight_on_yank_duration: u64, @@ -2227,6 +2228,7 @@ impl Settings for VimSettings { toggle_relative_line_numbers: vim.toggle_relative_line_numbers.unwrap(), use_system_clipboard: vim.use_system_clipboard.unwrap(), use_smartcase_find: vim.use_smartcase_find.unwrap(), + use_regex_search: vim.use_regex_search.unwrap(), gdefault: vim.gdefault.unwrap(), custom_digraphs: vim.custom_digraphs.unwrap(), highlight_on_yank_duration: vim.highlight_on_yank_duration.unwrap(), diff --git a/crates/vim/test_data/test_matching_comments.json b/crates/vim/test_data/test_matching_comments.json new file mode 100644 index 0000000000000000000000000000000000000000..7fcf5e46e1ea16f2be794ff76b583242b33aabc0 --- /dev/null +++ b/crates/vim/test_data/test_matching_comments.json @@ -0,0 +1,10 @@ +{"Put":{"state":"ˇ/*\n this is a comment\n*/"}} +{"Key":"%"} +{"Get":{"state":"/*\n this is a comment\nˇ*/","mode":"Normal"}} +{"Key":"%"} +{"Get":{"state":"ˇ/*\n this is a comment\n*/","mode":"Normal"}} +{"Key":"%"} +{"Get":{"state":"/*\n this is a comment\nˇ*/","mode":"Normal"}} +{"Put":{"state":"ˇ// comment"}} +{"Key":"%"} +{"Get":{"state":"ˇ// comment","mode":"Normal"}} diff --git a/crates/vim/test_data/test_matching_preprocessor_directives.json b/crates/vim/test_data/test_matching_preprocessor_directives.json new file mode 100644 index 0000000000000000000000000000000000000000..9f0bd9792ee8dad5029f4ecaf325c231755530e1 --- /dev/null +++ b/crates/vim/test_data/test_matching_preprocessor_directives.json @@ -0,0 +1,18 @@ +{"Put":{"state":"#ˇif\n\n#else\n\n#endif\n"}} +{"Key":"%"} +{"Get":{"state":"#if\n\nˇ#else\n\n#endif\n","mode":"Normal"}} +{"Key":"%"} +{"Get":{"state":"#if\n\n#else\n\nˇ#endif\n","mode":"Normal"}} +{"Key":"%"} +{"Get":{"state":"ˇ#if\n\n#else\n\n#endif\n","mode":"Normal"}} +{"Put":{"state":"#ˇif\n #if\n\n #else\n\n #endif\n\n#else\n#endif\n"}} +{"Key":"%"} +{"Get":{"state":"#if\n #if\n\n #else\n\n #endif\n\nˇ#else\n#endif\n","mode":"Normal"}} +{"Key":"%"} +{"Key":"%"} +{"Get":{"state":"ˇ#if\n #if\n\n #else\n\n #endif\n\n#else\n#endif\n","mode":"Normal"}} +{"Key":"j"} +{"Key":"%"} +{"Key":"%"} +{"Key":"%"} +{"Get":{"state":"#if\n ˇ#if\n\n #else\n\n #endif\n\n#else\n#endif\n","mode":"Normal"}} diff --git a/crates/workspace/src/multi_workspace.rs b/crates/workspace/src/multi_workspace.rs index dc6060b70a0eeeebc1168113c2c9eb1ba2ddd251..a61ad3576c57ecd8b1811363d6b5607ead737821 100644 --- a/crates/workspace/src/multi_workspace.rs +++ b/crates/workspace/src/multi_workspace.rs @@ -6,9 +6,7 @@ use gpui::{ ManagedView, MouseButton, Pixels, Render, Subscription, Task, Tiling, Window, WindowId, actions, deferred, px, }; -#[cfg(any(test, feature = "test-support"))] -use project::Project; -use project::{DirectoryLister, DisableAiSettings, ProjectGroupKey}; +use project::{DirectoryLister, DisableAiSettings, Project, ProjectGroupKey}; use settings::Settings; pub use settings::SidebarSide; use std::future::Future; @@ -42,10 +40,7 @@ actions!( CloseWorkspaceSidebar, /// Moves focus to or from the workspace sidebar without closing it. FocusWorkspaceSidebar, - /// Switches to the next workspace. - NextWorkspace, - /// Switches to the previous workspace. - PreviousWorkspace, + //TODO: Restore next/previous workspace ] ); @@ -223,10 +218,57 @@ impl SidebarHandle for Entity { } } +/// Tracks which workspace the user is currently looking at. +/// +/// `Persistent` workspaces live in the `workspaces` vec and are shown in the +/// sidebar. `Transient` workspaces exist outside the vec and are discarded +/// when the user switches away. +enum ActiveWorkspace { + /// A persistent workspace, identified by index into the `workspaces` vec. + Persistent(usize), + /// A workspace not in the `workspaces` vec that will be discarded on + /// switch or promoted to persistent when the sidebar is opened. + Transient(Entity), +} + +impl ActiveWorkspace { + fn persistent_index(&self) -> Option { + match self { + Self::Persistent(index) => Some(*index), + Self::Transient(_) => None, + } + } + + fn transient_workspace(&self) -> Option<&Entity> { + match self { + Self::Transient(workspace) => Some(workspace), + Self::Persistent(_) => None, + } + } + + /// Sets the active workspace to transient, returning the previous + /// transient workspace (if any). + fn set_transient(&mut self, workspace: Entity) -> Option> { + match std::mem::replace(self, Self::Transient(workspace)) { + Self::Transient(old) => Some(old), + Self::Persistent(_) => None, + } + } + + /// Sets the active workspace to persistent at the given index, + /// returning the previous transient workspace (if any). + fn set_persistent(&mut self, index: usize) -> Option> { + match std::mem::replace(self, Self::Persistent(index)) { + Self::Transient(workspace) => Some(workspace), + Self::Persistent(_) => None, + } + } +} + pub struct MultiWorkspace { window_id: WindowId, workspaces: Vec>, - active_workspace_index: usize, + active_workspace: ActiveWorkspace, project_group_keys: Vec, sidebar: Option>, sidebar_open: bool, @@ -262,12 +304,15 @@ impl MultiWorkspace { } }); let quit_subscription = cx.on_app_quit(Self::app_will_quit); - let settings_subscription = - cx.observe_global_in::(window, |this, window, cx| { - if DisableAiSettings::get_global(cx).disable_ai && this.sidebar_open { - this.close_sidebar(window, cx); + let settings_subscription = cx.observe_global_in::(window, { + let mut previous_disable_ai = DisableAiSettings::get_global(cx).disable_ai; + move |this, window, cx| { + if DisableAiSettings::get_global(cx).disable_ai != previous_disable_ai { + this.collapse_to_single_workspace(window, cx); + previous_disable_ai = DisableAiSettings::get_global(cx).disable_ai; } - }); + } + }); Self::subscribe_to_workspace(&workspace, window, cx); let weak_self = cx.weak_entity(); workspace.update(cx, |workspace, cx| { @@ -275,9 +320,9 @@ impl MultiWorkspace { }); Self { window_id: window.window_handle().window_id(), - project_group_keys: vec![workspace.read(cx).project_group_key(cx)], - workspaces: vec![workspace], - active_workspace_index: 0, + project_group_keys: Vec::new(), + workspaces: Vec::new(), + active_workspace: ActiveWorkspace::Transient(workspace), sidebar: None, sidebar_open: false, sidebar_overlay: None, @@ -339,7 +384,7 @@ impl MultiWorkspace { return; } - if self.sidebar_open { + if self.sidebar_open() { self.close_sidebar(window, cx); } else { self.open_sidebar(cx); @@ -355,7 +400,7 @@ impl MultiWorkspace { return; } - if self.sidebar_open { + if self.sidebar_open() { self.close_sidebar(window, cx); } } @@ -365,7 +410,7 @@ impl MultiWorkspace { return; } - if self.sidebar_open { + if self.sidebar_open() { let sidebar_is_focused = self .sidebar .as_ref() @@ -390,8 +435,13 @@ impl MultiWorkspace { pub fn open_sidebar(&mut self, cx: &mut Context) { self.sidebar_open = true; + if let ActiveWorkspace::Transient(workspace) = &self.active_workspace { + let workspace = workspace.clone(); + let index = self.promote_transient(workspace, cx); + self.active_workspace = ActiveWorkspace::Persistent(index); + } let sidebar_focus_handle = self.sidebar.as_ref().map(|s| s.focus_handle(cx)); - for workspace in &self.workspaces { + for workspace in self.workspaces.iter() { workspace.update(cx, |workspace, _cx| { workspace.set_sidebar_focus_handle(sidebar_focus_handle.clone()); }); @@ -402,7 +452,7 @@ impl MultiWorkspace { pub fn close_sidebar(&mut self, window: &mut Window, cx: &mut Context) { self.sidebar_open = false; - for workspace in &self.workspaces { + for workspace in self.workspaces.iter() { workspace.update(cx, |workspace, _cx| { workspace.set_sidebar_focus_handle(None); }); @@ -417,7 +467,7 @@ impl MultiWorkspace { pub fn close_window(&mut self, _: &CloseWindow, window: &mut Window, cx: &mut Context) { cx.spawn_in(window, async move |this, cx| { let workspaces = this.update(cx, |multi_workspace, _cx| { - multi_workspace.workspaces().to_vec() + multi_workspace.workspaces().cloned().collect::>() })?; for workspace in workspaces { @@ -468,6 +518,9 @@ impl MultiWorkspace { } pub fn add_project_group_key(&mut self, project_group_key: ProjectGroupKey) { + if project_group_key.path_list().paths().is_empty() { + return; + } if self.project_group_keys.contains(&project_group_key) { return; } @@ -649,13 +702,19 @@ impl MultiWorkspace { if let Some(workspace) = self .workspaces .iter() - .find(|ws| ws.read(cx).project_group_key(cx).path_list() == &path_list) + .find(|ws| PathList::new(&ws.read(cx).root_paths(cx)) == path_list) .cloned() { self.activate(workspace.clone(), window, cx); return Task::ready(Ok(workspace)); } + if let Some(transient) = self.active_workspace.transient_workspace() { + if transient.read(cx).project_group_key(cx).path_list() == &path_list { + return Task::ready(Ok(transient.clone())); + } + } + let paths = path_list.paths().to_vec(); let app_state = self.workspace().read(cx).app_state().clone(); let requesting_window = window.window_handle().downcast::(); @@ -679,25 +738,23 @@ impl MultiWorkspace { } pub fn workspace(&self) -> &Entity { - &self.workspaces[self.active_workspace_index] - } - - pub fn workspaces(&self) -> &[Entity] { - &self.workspaces + match &self.active_workspace { + ActiveWorkspace::Persistent(index) => &self.workspaces[*index], + ActiveWorkspace::Transient(workspace) => workspace, + } } - pub fn active_workspace_index(&self) -> usize { - self.active_workspace_index + pub fn workspaces(&self) -> impl Iterator> { + self.workspaces + .iter() + .chain(self.active_workspace.transient_workspace()) } - /// Adds a workspace to this window without changing which workspace is - /// active. + /// Adds a workspace to this window as persistent without changing which + /// workspace is active. Unlike `activate()`, this always inserts into the + /// persistent list regardless of sidebar state — it's used for system- + /// initiated additions like deserialization and worktree discovery. pub fn add(&mut self, workspace: Entity, window: &Window, cx: &mut Context) { - if !self.multi_workspace_enabled(cx) { - self.set_single_workspace(workspace, cx); - return; - } - self.insert_workspace(workspace, window, cx); } @@ -708,26 +765,74 @@ impl MultiWorkspace { window: &mut Window, cx: &mut Context, ) { - if !self.multi_workspace_enabled(cx) { - self.set_single_workspace(workspace, cx); + // Re-activating the current workspace is a no-op. + if self.workspace() == &workspace { + self.focus_active_workspace(window, cx); return; } - let index = self.insert_workspace(workspace, &*window, cx); - let changed = self.active_workspace_index != index; - self.active_workspace_index = index; - if changed { - cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged); - self.serialize(cx); + // Resolve where we're going. + let new_index = if let Some(index) = self.workspaces.iter().position(|w| *w == workspace) { + Some(index) + } else if self.sidebar_open { + Some(self.insert_workspace(workspace.clone(), &*window, cx)) + } else { + None + }; + + // Transition the active workspace. + if let Some(index) = new_index { + if let Some(old) = self.active_workspace.set_persistent(index) { + if self.sidebar_open { + self.promote_transient(old, cx); + } else { + self.detach_workspace(&old, cx); + cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(old.entity_id())); + } + } + } else { + Self::subscribe_to_workspace(&workspace, window, cx); + let weak_self = cx.weak_entity(); + workspace.update(cx, |workspace, cx| { + workspace.set_multi_workspace(weak_self, cx); + }); + if let Some(old) = self.active_workspace.set_transient(workspace) { + self.detach_workspace(&old, cx); + cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(old.entity_id())); + } } + + cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged); + self.serialize(cx); self.focus_active_workspace(window, cx); cx.notify(); } - fn set_single_workspace(&mut self, workspace: Entity, cx: &mut Context) { - self.workspaces[0] = workspace; - self.active_workspace_index = 0; - cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged); + /// Promotes a former transient workspace into the persistent list. + /// Returns the index of the newly inserted workspace. + fn promote_transient(&mut self, workspace: Entity, cx: &mut Context) -> usize { + let project_group_key = workspace.read(cx).project().read(cx).project_group_key(cx); + self.add_project_group_key(project_group_key); + self.workspaces.push(workspace.clone()); + cx.emit(MultiWorkspaceEvent::WorkspaceAdded(workspace)); + self.workspaces.len() - 1 + } + + /// Collapses to a single transient workspace, discarding all persistent + /// workspaces. Used when multi-workspace is disabled (e.g. disable_ai). + fn collapse_to_single_workspace(&mut self, window: &mut Window, cx: &mut Context) { + if self.sidebar_open { + self.close_sidebar(window, cx); + } + let active = self.workspace().clone(); + for workspace in std::mem::take(&mut self.workspaces) { + if workspace != active { + self.detach_workspace(&workspace, cx); + cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(workspace.entity_id())); + } + } + self.project_group_keys.clear(); + self.active_workspace = ActiveWorkspace::Transient(active); cx.notify(); } @@ -783,7 +888,7 @@ impl MultiWorkspace { } fn sync_sidebar_to_workspace(&self, workspace: &Entity, cx: &mut Context) { - if self.sidebar_open { + if self.sidebar_open() { let sidebar_focus_handle = self.sidebar.as_ref().map(|s| s.focus_handle(cx)); workspace.update(cx, |workspace, _| { workspace.set_sidebar_focus_handle(sidebar_focus_handle); @@ -791,30 +896,6 @@ impl MultiWorkspace { } } - fn cycle_workspace(&mut self, delta: isize, window: &mut Window, cx: &mut Context) { - let count = self.workspaces.len() as isize; - if count <= 1 { - return; - } - let current = self.active_workspace_index as isize; - let next = ((current + delta).rem_euclid(count)) as usize; - let workspace = self.workspaces[next].clone(); - self.activate(workspace, window, cx); - } - - fn next_workspace(&mut self, _: &NextWorkspace, window: &mut Window, cx: &mut Context) { - self.cycle_workspace(1, window, cx); - } - - fn previous_workspace( - &mut self, - _: &PreviousWorkspace, - window: &mut Window, - cx: &mut Context, - ) { - self.cycle_workspace(-1, window, cx); - } - pub(crate) fn serialize(&mut self, cx: &mut Context) { self._serialize_task = Some(cx.spawn(async move |this, cx| { let Some((window_id, state)) = this @@ -1040,26 +1121,82 @@ impl MultiWorkspace { let Some(index) = self.workspaces.iter().position(|w| w == workspace) else { return false; }; + + let old_key = workspace.read(cx).project_group_key(cx); + if self.workspaces.len() <= 1 { - return false; - } + let has_worktrees = workspace.read(cx).visible_worktrees(cx).next().is_some(); - let removed_workspace = self.workspaces.remove(index); + if !has_worktrees { + return false; + } - if self.active_workspace_index >= self.workspaces.len() { - self.active_workspace_index = self.workspaces.len() - 1; - } else if self.active_workspace_index > index { - self.active_workspace_index -= 1; + let old_workspace = workspace.clone(); + let old_entity_id = old_workspace.entity_id(); + + let app_state = old_workspace.read(cx).app_state().clone(); + + let project = Project::local( + app_state.client.clone(), + app_state.node_runtime.clone(), + app_state.user_store.clone(), + app_state.languages.clone(), + app_state.fs.clone(), + None, + project::LocalProjectFlags::default(), + cx, + ); + + let new_workspace = cx.new(|cx| Workspace::new(None, project, app_state, window, cx)); + + self.workspaces[0] = new_workspace.clone(); + self.active_workspace = ActiveWorkspace::Persistent(0); + + Self::subscribe_to_workspace(&new_workspace, window, cx); + + self.sync_sidebar_to_workspace(&new_workspace, cx); + + let weak_self = cx.weak_entity(); + + new_workspace.update(cx, |workspace, cx| { + workspace.set_multi_workspace(weak_self, cx); + }); + + self.detach_workspace(&old_workspace, cx); + + cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(old_entity_id)); + cx.emit(MultiWorkspaceEvent::WorkspaceAdded(new_workspace)); + cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged); + } else { + let removed_workspace = self.workspaces.remove(index); + + if let Some(active_index) = self.active_workspace.persistent_index() { + if active_index >= self.workspaces.len() { + self.active_workspace = ActiveWorkspace::Persistent(self.workspaces.len() - 1); + } else if active_index > index { + self.active_workspace = ActiveWorkspace::Persistent(active_index - 1); + } + } + + self.detach_workspace(&removed_workspace, cx); + + cx.emit(MultiWorkspaceEvent::WorkspaceRemoved( + removed_workspace.entity_id(), + )); + cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged); } - self.detach_workspace(&removed_workspace, cx); + let key_still_in_use = self + .workspaces + .iter() + .any(|ws| ws.read(cx).project_group_key(cx) == old_key); + + if !key_still_in_use { + self.project_group_keys.retain(|k| k != &old_key); + } self.serialize(cx); self.focus_active_workspace(window, cx); - cx.emit(MultiWorkspaceEvent::WorkspaceRemoved( - removed_workspace.entity_id(), - )); - cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged); cx.notify(); true @@ -1288,8 +1425,6 @@ impl Render for MultiWorkspace { this.focus_sidebar(window, cx); }, )) - .on_action(cx.listener(Self::next_workspace)) - .on_action(cx.listener(Self::previous_workspace)) .on_action(cx.listener(Self::move_active_workspace_to_new_window)) .on_action(cx.listener( |this: &mut Self, action: &ToggleThreadSwitcher, window, cx| { diff --git a/crates/workspace/src/multi_workspace_tests.rs b/crates/workspace/src/multi_workspace_tests.rs index 3083c23f6e3add91b0389a961567fc88e2043678..ab6ca43d5aff482b637add9083b1ad9d388d7993 100644 --- a/crates/workspace/src/multi_workspace_tests.rs +++ b/crates/workspace/src/multi_workspace_tests.rs @@ -99,6 +99,10 @@ async fn test_project_group_keys_initial(cx: &mut TestAppContext) { let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project, window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + multi_workspace.read_with(cx, |mw, _cx| { let keys: Vec<&ProjectGroupKey> = mw.project_group_keys().collect(); assert_eq!(keys.len(), 1, "should have exactly one key on creation"); @@ -125,6 +129,10 @@ async fn test_project_group_keys_add_workspace(cx: &mut TestAppContext) { let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + multi_workspace.read_with(cx, |mw, _cx| { assert_eq!(mw.project_group_keys().count(), 1); }); @@ -162,6 +170,10 @@ async fn test_project_group_keys_duplicate_not_added(cx: &mut TestAppContext) { let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + multi_workspace.update_in(cx, |mw, window, cx| { mw.test_add_workspace(project_a2, window, cx); }); @@ -189,6 +201,10 @@ async fn test_project_group_keys_on_worktree_added(cx: &mut TestAppContext) { let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + // Add a second worktree to the same project. let (worktree, _) = project .update(cx, |project, cx| { @@ -232,6 +248,10 @@ async fn test_project_group_keys_on_worktree_removed(cx: &mut TestAppContext) { let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + // Remove one worktree. let worktree_b_id = project.read_with(cx, |project, cx| { project @@ -282,6 +302,10 @@ async fn test_project_group_keys_across_multiple_workspaces_and_worktree_changes let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a.clone(), window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + multi_workspace.update_in(cx, |mw, window, cx| { mw.test_add_workspace(project_b, window, cx); }); diff --git a/crates/workspace/src/pane.rs b/crates/workspace/src/pane.rs index 92f0781f82234ce79d47db08785b6592fb53f566..27cc96ae80a010db2dd5357a9a0bc037ca762875 100644 --- a/crates/workspace/src/pane.rs +++ b/crates/workspace/src/pane.rs @@ -3670,6 +3670,11 @@ impl Pane { this.drag_split_direction = None; this.handle_external_paths_drop(paths, window, cx) })) + .on_click(cx.listener(move |this, event: &ClickEvent, window, cx| { + if event.click_count() == 2 { + window.dispatch_action(this.double_click_dispatch_action.boxed_clone(), cx); + } + })) } pub fn render_menu_overlay(menu: &Entity) -> Div { @@ -4917,14 +4922,17 @@ impl Render for DraggedTab { #[cfg(test)] mod tests { - use std::{cell::Cell, iter::zip, num::NonZero}; + use std::{cell::Cell, iter::zip, num::NonZero, rc::Rc}; use super::*; use crate::{ Member, item::test::{TestItem, TestProjectItem}, }; - use gpui::{AppContext, Axis, TestAppContext, VisualTestContext, size}; + use gpui::{ + AppContext, Axis, Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, + TestAppContext, VisualTestContext, size, + }; use project::FakeFs; use settings::SettingsStore; use theme::LoadThemes; @@ -6649,8 +6657,6 @@ mod tests { #[gpui::test] async fn test_drag_tab_to_middle_tab_with_mouse_events(cx: &mut TestAppContext) { - use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent}; - init_test(cx); let fs = FakeFs::new(cx.executor()); @@ -6702,8 +6708,6 @@ mod tests { async fn test_drag_pinned_tab_when_show_pinned_tabs_in_separate_row_enabled( cx: &mut TestAppContext, ) { - use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent}; - init_test(cx); set_pinned_tabs_separate_row(cx, true); let fs = FakeFs::new(cx.executor()); @@ -6779,8 +6783,6 @@ mod tests { async fn test_drag_unpinned_tab_when_show_pinned_tabs_in_separate_row_enabled( cx: &mut TestAppContext, ) { - use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent}; - init_test(cx); set_pinned_tabs_separate_row(cx, true); let fs = FakeFs::new(cx.executor()); @@ -6833,8 +6835,6 @@ mod tests { async fn test_drag_mixed_tabs_when_show_pinned_tabs_in_separate_row_enabled( cx: &mut TestAppContext, ) { - use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent}; - init_test(cx); set_pinned_tabs_separate_row(cx, true); let fs = FakeFs::new(cx.executor()); @@ -6900,8 +6900,6 @@ mod tests { #[gpui::test] async fn test_middle_click_pinned_tab_does_not_close(cx: &mut TestAppContext) { - use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseUpEvent}; - init_test(cx); let fs = FakeFs::new(cx.executor()); @@ -6971,6 +6969,74 @@ mod tests { assert_item_labels(&pane, ["A*!"], cx); } + #[gpui::test] + async fn test_double_click_pinned_tab_bar_empty_space_creates_new_tab(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + + let project = Project::test(fs, None, cx).await; + let (workspace, cx) = + cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let pane = workspace.read_with(cx, |workspace, _| workspace.active_pane().clone()); + + // The real NewFile handler lives in editor::init, which isn't initialized + // in workspace tests. Register a global action handler that sets a flag so + // we can verify the action is dispatched without depending on the editor crate. + // TODO: If editor::init is ever available in workspace tests, remove this + // flag and assert the resulting tab bar state directly instead. + let new_file_dispatched = Rc::new(Cell::new(false)); + cx.update(|_, cx| { + let new_file_dispatched = new_file_dispatched.clone(); + cx.on_action(move |_: &NewFile, _cx| { + new_file_dispatched.set(true); + }); + }); + + set_pinned_tabs_separate_row(cx, true); + + let item_a = add_labeled_item(&pane, "A", false, cx); + add_labeled_item(&pane, "B", false, cx); + + pane.update_in(cx, |pane, window, cx| { + let ix = pane + .index_for_item_id(item_a.item_id()) + .expect("item A should exist"); + pane.pin_tab_at(ix, window, cx); + }); + assert_item_labels(&pane, ["A!", "B*"], cx); + cx.run_until_parked(); + + let pinned_drop_target_bounds = cx + .debug_bounds("pinned_tabs_border") + .expect("pinned_tabs_border should have debug bounds"); + + cx.simulate_event(MouseDownEvent { + position: pinned_drop_target_bounds.center(), + button: MouseButton::Left, + modifiers: Modifiers::default(), + click_count: 2, + first_mouse: false, + }); + + cx.run_until_parked(); + + cx.simulate_event(MouseUpEvent { + position: pinned_drop_target_bounds.center(), + button: MouseButton::Left, + modifiers: Modifiers::default(), + click_count: 2, + }); + + cx.run_until_parked(); + + // TODO: If editor::init is ever available in workspace tests, replace this + // with an assert_item_labels check that verifies a new tab is actually created. + assert!( + new_file_dispatched.get(), + "Double-clicking pinned tab bar empty space should dispatch the new file action" + ); + } + #[gpui::test] async fn test_add_item_with_new_item(cx: &mut TestAppContext) { init_test(cx); diff --git a/crates/workspace/src/persistence.rs b/crates/workspace/src/persistence.rs index 644ff0282df216e79d6be24918d29b802e50a0e8..2994e9d0f67d73a30838f922c9b6a0b01b21ed14 100644 --- a/crates/workspace/src/persistence.rs +++ b/crates/workspace/src/persistence.rs @@ -2535,6 +2535,10 @@ mod tests { let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project1.clone(), window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + multi_workspace.update_in(cx, |mw, _, cx| { mw.set_random_database_id(cx); }); @@ -2564,7 +2568,7 @@ mod tests { // --- Remove the second workspace (index 1) --- multi_workspace.update_in(cx, |mw, window, cx| { - let ws = mw.workspaces()[1].clone(); + let ws = mw.workspaces().nth(1).unwrap().clone(); mw.remove(&ws, window, cx); }); @@ -4191,6 +4195,10 @@ mod tests { let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project1.clone(), window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + multi_workspace.update_in(cx, |mw, _, cx| { mw.set_random_database_id(cx); }); @@ -4233,7 +4241,7 @@ mod tests { // Remove workspace at index 1 (the second workspace). multi_workspace.update_in(cx, |mw, window, cx| { - let ws = mw.workspaces()[1].clone(); + let ws = mw.workspaces().nth(1).unwrap().clone(); mw.remove(&ws, window, cx); }); @@ -4288,6 +4296,10 @@ mod tests { let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project1.clone(), window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + multi_workspace.update_in(cx, |mw, _, cx| { mw.workspace().update(cx, |ws, _cx| { ws.set_database_id(ws1_id); @@ -4339,7 +4351,7 @@ mod tests { // Remove workspace2 (index 1). multi_workspace.update_in(cx, |mw, window, cx| { - let ws = mw.workspaces()[1].clone(); + let ws = mw.workspaces().nth(1).unwrap().clone(); mw.remove(&ws, window, cx); }); @@ -4385,6 +4397,10 @@ mod tests { let (multi_workspace, cx) = cx.add_window_view(|window, cx| MultiWorkspace::test_new(project1.clone(), window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + multi_workspace.update_in(cx, |mw, _, cx| { mw.set_random_database_id(cx); }); @@ -4418,7 +4434,7 @@ mod tests { // Remove workspace2 — this pushes a task to pending_removal_tasks. multi_workspace.update_in(cx, |mw, window, cx| { - let ws = mw.workspaces()[1].clone(); + let ws = mw.workspaces().nth(1).unwrap().clone(); mw.remove(&ws, window, cx); }); @@ -4427,7 +4443,6 @@ mod tests { let all_tasks = multi_workspace.update_in(cx, |mw, window, cx| { let mut tasks: Vec> = mw .workspaces() - .iter() .map(|workspace| { workspace.update(cx, |workspace, cx| { workspace.flush_serialization(window, cx) @@ -4747,6 +4762,10 @@ mod tests { let (multi_workspace, cx) = cx .add_window_view(|window, cx| MultiWorkspace::test_new(project_2.clone(), window, cx)); + multi_workspace.update(cx, |mw, cx| { + mw.open_sidebar(cx); + }); + multi_workspace.update_in(cx, |mw, window, cx| { mw.test_add_workspace(project_1.clone(), window, cx); }); diff --git a/crates/workspace/src/security_modal.rs b/crates/workspace/src/security_modal.rs index 664aa891550cecdd602d54bfca579d04e03f33dc..2130a1d1eca3d33651a057d32a252718270f89f8 100644 --- a/crates/workspace/src/security_modal.rs +++ b/crates/workspace/src/security_modal.rs @@ -7,7 +7,7 @@ use std::{ }; use collections::{HashMap, HashSet}; -use gpui::{DismissEvent, EventEmitter, FocusHandle, Focusable, WeakEntity}; +use gpui::{DismissEvent, EventEmitter, FocusHandle, Focusable, ScrollHandle, WeakEntity}; use project::{ WorktreeId, @@ -17,7 +17,8 @@ use project::{ use smallvec::SmallVec; use theme::ActiveTheme; use ui::{ - AlertModal, Checkbox, FluentBuilder, KeyBinding, ListBulletItem, ToggleState, prelude::*, + AlertModal, Checkbox, FluentBuilder, KeyBinding, ListBulletItem, ToggleState, WithScrollbar, + prelude::*, }; use crate::{DismissDecision, ModalView, ToggleWorktreeSecurity}; @@ -29,6 +30,7 @@ pub struct SecurityModal { worktree_store: WeakEntity, remote_host: Option, focus_handle: FocusHandle, + project_list_scroll_handle: ScrollHandle, trusted: Option, } @@ -63,16 +65,17 @@ impl ModalView for SecurityModal { } impl Render for SecurityModal { - fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { if self.restricted_paths.is_empty() { self.dismiss(cx); return v_flex().into_any_element(); } - let header_label = if self.restricted_paths.len() == 1 { - "Unrecognized Project" + let restricted_count = self.restricted_paths.len(); + let header_label: SharedString = if restricted_count == 1 { + "Unrecognized Project".into() } else { - "Unrecognized Projects" + format!("Unrecognized Projects ({})", restricted_count).into() }; let trust_label = self.build_trust_label(); @@ -102,32 +105,61 @@ impl Render for SecurityModal { .child(Icon::new(IconName::Warning).color(Color::Warning)) .child(Label::new(header_label)), ) - .children(self.restricted_paths.values().filter_map(|restricted_path| { - let abs_path = if restricted_path.is_file { - restricted_path.abs_path.parent() - } else { - Some(restricted_path.abs_path.as_ref()) - }?; - let label = match &restricted_path.host { - Some(remote_host) => match &remote_host.user_name { - Some(user_name) => format!( - "{} ({}@{})", - self.shorten_path(abs_path).display(), - user_name, - remote_host.host_identifier - ), - None => format!( - "{} ({})", - self.shorten_path(abs_path).display(), - remote_host.host_identifier - ), - }, - None => self.shorten_path(abs_path).display().to_string(), - }; - Some(h_flex() - .pl(IconSize::default().rems() + rems(0.5)) - .child(Label::new(label).color(Color::Muted))) - })), + .child( + div() + .size_full() + .vertical_scrollbar_for(&self.project_list_scroll_handle, window, cx) + .child( + v_flex() + .id("paths_container") + .max_h_24() + .overflow_y_scroll() + .track_scroll(&self.project_list_scroll_handle) + .children( + self.restricted_paths.values().filter_map( + |restricted_path| { + let abs_path = if restricted_path.is_file { + restricted_path.abs_path.parent() + } else { + Some(restricted_path.abs_path.as_ref()) + }?; + let label = match &restricted_path.host { + Some(remote_host) => { + match &remote_host.user_name { + Some(user_name) => format!( + "{} ({}@{})", + self.shorten_path(abs_path) + .display(), + user_name, + remote_host.host_identifier + ), + None => format!( + "{} ({})", + self.shorten_path(abs_path) + .display(), + remote_host.host_identifier + ), + } + } + None => self + .shorten_path(abs_path) + .display() + .to_string(), + }; + Some( + h_flex() + .pl( + IconSize::default().rems() + rems(0.5), + ) + .child( + Label::new(label).color(Color::Muted), + ), + ) + }, + ), + ), + ), + ), ) .child( v_flex() @@ -219,6 +251,7 @@ impl SecurityModal { remote_host: remote_host.map(|host| host.into()), restricted_paths: HashMap::default(), focus_handle: cx.focus_handle(), + project_list_scroll_handle: ScrollHandle::new(), trust_parents: false, home_dir: std::env::home_dir(), trusted: None, diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index cc5d1e8635e9194522fea5506fef4084f8133c53..7979ffe828cbf8c4da5a40a29eaa6537f1433c3c 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -32,8 +32,8 @@ pub use crate::notifications::NotificationFrame; pub use dock::Panel; pub use multi_workspace::{ CloseWorkspaceSidebar, DraggedSidebar, FocusWorkspaceSidebar, MultiWorkspace, - MultiWorkspaceEvent, NextWorkspace, PreviousWorkspace, Sidebar, SidebarEvent, SidebarHandle, - SidebarRenderState, SidebarSide, ToggleWorkspaceSidebar, sidebar_side_context_menu, + MultiWorkspaceEvent, Sidebar, SidebarEvent, SidebarHandle, SidebarRenderState, SidebarSide, + ToggleWorkspaceSidebar, sidebar_side_context_menu, }; pub use path_list::{PathList, SerializedPathList}; pub use toast_layer::{ToastAction, ToastLayer, ToastView}; @@ -9079,7 +9079,7 @@ pub fn workspace_windows_for_location( }; multi_workspace.read(cx).is_ok_and(|multi_workspace| { - multi_workspace.workspaces().iter().any(|workspace| { + multi_workspace.workspaces().any(|workspace| { match workspace.read(cx).workspace_location(cx) { WorkspaceLocation::Location(location, _) => { match (&location, serialized_location) { @@ -10741,6 +10741,12 @@ mod tests { cx.add_window(|window, cx| MultiWorkspace::test_new(project_a.clone(), window, cx)); cx.run_until_parked(); + multi_workspace_handle + .update(cx, |mw, _window, cx| { + mw.open_sidebar(cx); + }) + .unwrap(); + let workspace_a = multi_workspace_handle .read_with(cx, |mw, _| mw.workspace().clone()) .unwrap(); @@ -10754,7 +10760,7 @@ mod tests { // Activate workspace A multi_workspace_handle .update(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); + let workspace = mw.workspaces().next().unwrap().clone(); mw.activate(workspace, window, cx); }) .unwrap(); @@ -10776,7 +10782,7 @@ mod tests { // Verify workspace A is active multi_workspace_handle .read_with(cx, |mw, _| { - assert_eq!(mw.active_workspace_index(), 0); + assert_eq!(mw.workspace(), &workspace_a); }) .unwrap(); @@ -10792,8 +10798,8 @@ mod tests { multi_workspace_handle .read_with(cx, |mw, _| { assert_eq!( - mw.active_workspace_index(), - 1, + mw.workspace(), + &workspace_b, "workspace B should be activated when it prompts" ); }) @@ -14511,6 +14517,12 @@ mod tests { cx.add_window(|window, cx| MultiWorkspace::test_new(project_a.clone(), window, cx)); cx.run_until_parked(); + multi_workspace_handle + .update(cx, |mw, _window, cx| { + mw.open_sidebar(cx); + }) + .unwrap(); + let workspace_a = multi_workspace_handle .read_with(cx, |mw, _| mw.workspace().clone()) .unwrap(); @@ -14524,7 +14536,7 @@ mod tests { // Switch to workspace A multi_workspace_handle .update(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); + let workspace = mw.workspaces().next().unwrap().clone(); mw.activate(workspace, window, cx); }) .unwrap(); @@ -14570,7 +14582,7 @@ mod tests { // Switch to workspace B multi_workspace_handle .update(cx, |mw, window, cx| { - let workspace = mw.workspaces()[1].clone(); + let workspace = mw.workspaces().nth(1).unwrap().clone(); mw.activate(workspace, window, cx); }) .unwrap(); @@ -14579,7 +14591,7 @@ mod tests { // Switch back to workspace A multi_workspace_handle .update(cx, |mw, window, cx| { - let workspace = mw.workspaces()[0].clone(); + let workspace = mw.workspaces().next().unwrap().clone(); mw.activate(workspace, window, cx); }) .unwrap(); diff --git a/crates/zed/src/visual_test_runner.rs b/crates/zed/src/visual_test_runner.rs index f1ed73fe89f0980a2705631063dcf4efbbe84bfb..b59123a1a159487f802210f3916e16856daf8e61 100644 --- a/crates/zed/src/visual_test_runner.rs +++ b/crates/zed/src/visual_test_runner.rs @@ -2606,7 +2606,7 @@ fn run_multi_workspace_sidebar_visual_tests( // Add worktree to workspace 1 (index 0) so it shows as "private-test-remote" let add_worktree1_task = multi_workspace_window .update(cx, |multi_workspace, _window, cx| { - let workspace1 = &multi_workspace.workspaces()[0]; + let workspace1 = multi_workspace.workspaces().next().unwrap(); let project = workspace1.read(cx).project().clone(); project.update(cx, |project, cx| { project.find_or_create_worktree(&workspace1_dir, true, cx) @@ -2625,7 +2625,7 @@ fn run_multi_workspace_sidebar_visual_tests( // Add worktree to workspace 2 (index 1) so it shows as "zed" let add_worktree2_task = multi_workspace_window .update(cx, |multi_workspace, _window, cx| { - let workspace2 = &multi_workspace.workspaces()[1]; + let workspace2 = multi_workspace.workspaces().nth(1).unwrap(); let project = workspace2.read(cx).project().clone(); project.update(cx, |project, cx| { project.find_or_create_worktree(&workspace2_dir, true, cx) @@ -2644,7 +2644,7 @@ fn run_multi_workspace_sidebar_visual_tests( // Switch to workspace 1 so it's highlighted as active (index 0) multi_workspace_window .update(cx, |multi_workspace, window, cx| { - let workspace = multi_workspace.workspaces()[0].clone(); + let workspace = multi_workspace.workspaces().next().unwrap().clone(); multi_workspace.activate(workspace, window, cx); }) .context("Failed to activate workspace 1")?; @@ -2672,7 +2672,7 @@ fn run_multi_workspace_sidebar_visual_tests( let save_tasks = multi_workspace_window .update(cx, |multi_workspace, _window, cx| { let thread_store = agent::ThreadStore::global(cx); - let workspaces = multi_workspace.workspaces().to_vec(); + let workspaces: Vec<_> = multi_workspace.workspaces().cloned().collect(); let mut tasks = Vec::new(); for (index, workspace) in workspaces.iter().enumerate() { @@ -3211,7 +3211,7 @@ edition = "2021" // Add the git project as a worktree let add_worktree_task = workspace_window .update(cx, |multi_workspace, _window, cx| { - let workspace = &multi_workspace.workspaces()[0]; + let workspace = multi_workspace.workspaces().next().unwrap(); let project = workspace.read(cx).project().clone(); project.update(cx, |project, cx| { project.find_or_create_worktree(&project_path, true, cx) @@ -3236,7 +3236,7 @@ edition = "2021" // Open the project panel let (weak_workspace, async_window_cx) = workspace_window .update(cx, |multi_workspace, window, cx| { - let workspace = &multi_workspace.workspaces()[0]; + let workspace = multi_workspace.workspaces().next().unwrap(); (workspace.read(cx).weak_handle(), window.to_async(cx)) }) .context("Failed to get workspace handle")?; @@ -3250,7 +3250,7 @@ edition = "2021" workspace_window .update(cx, |multi_workspace, window, cx| { - let workspace = &multi_workspace.workspaces()[0]; + let workspace = multi_workspace.workspaces().next().unwrap(); workspace.update(cx, |workspace, cx| { workspace.add_panel(project_panel, window, cx); workspace.open_panel::(window, cx); @@ -3263,7 +3263,7 @@ edition = "2021" // Open main.rs in the editor let open_file_task = workspace_window .update(cx, |multi_workspace, window, cx| { - let workspace = &multi_workspace.workspaces()[0]; + let workspace = multi_workspace.workspaces().next().unwrap(); workspace.update(cx, |workspace, cx| { let worktree = workspace.project().read(cx).worktrees(cx).next(); if let Some(worktree) = worktree { @@ -3291,7 +3291,7 @@ edition = "2021" // Load the AgentPanel let (weak_workspace, async_window_cx) = workspace_window .update(cx, |multi_workspace, window, cx| { - let workspace = &multi_workspace.workspaces()[0]; + let workspace = multi_workspace.workspaces().next().unwrap(); (workspace.read(cx).weak_handle(), window.to_async(cx)) }) .context("Failed to get workspace handle for agent panel")?; @@ -3335,7 +3335,7 @@ edition = "2021" workspace_window .update(cx, |multi_workspace, window, cx| { - let workspace = &multi_workspace.workspaces()[0]; + let workspace = multi_workspace.workspaces().next().unwrap(); workspace.update(cx, |workspace, cx| { workspace.add_panel(panel.clone(), window, cx); workspace.open_panel::(window, cx); @@ -3512,7 +3512,7 @@ edition = "2021" .is_none() }); let workspace_count = workspace_window.update(cx, |multi_workspace, _window, _cx| { - multi_workspace.workspaces().len() + multi_workspace.workspaces().count() })?; if workspace_count == 2 && status_cleared { creation_complete = true; @@ -3531,7 +3531,7 @@ edition = "2021" // error state by injecting the stub server, and shrink the panel so the // editor content is visible. workspace_window.update(cx, |multi_workspace, window, cx| { - let new_workspace = &multi_workspace.workspaces()[1]; + let new_workspace = multi_workspace.workspaces().nth(1).unwrap(); new_workspace.update(cx, |workspace, cx| { if let Some(new_panel) = workspace.panel::(cx) { new_panel.update(cx, |panel, cx| { @@ -3544,7 +3544,7 @@ edition = "2021" // Type and send a message so the thread target dropdown disappears. let new_panel = workspace_window.update(cx, |multi_workspace, _window, cx| { - let new_workspace = &multi_workspace.workspaces()[1]; + let new_workspace = multi_workspace.workspaces().nth(1).unwrap(); new_workspace.read(cx).panel::(cx) })?; if let Some(new_panel) = new_panel { @@ -3585,7 +3585,7 @@ edition = "2021" workspace_window .update(cx, |multi_workspace, _window, cx| { - let workspace = &multi_workspace.workspaces()[0]; + let workspace = multi_workspace.workspaces().next().unwrap(); let project = workspace.read(cx).project().clone(); project.update(cx, |project, cx| { let worktree_ids: Vec<_> = diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index ed49236a9da6b69f80c8c981eaddaa16ca69face..03e128415e1aa8390d1b95816755d3644064dada 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -1524,7 +1524,7 @@ fn quit(_: &Quit, cx: &mut App) { let window = *window; let workspaces = window .update(cx, |multi_workspace, _, _| { - multi_workspace.workspaces().to_vec() + multi_workspace.workspaces().cloned().collect::>() }) .log_err(); @@ -2458,7 +2458,6 @@ mod tests { .update(cx, |multi_workspace, window, cx| { let mut tasks = multi_workspace .workspaces() - .iter() .map(|workspace| { workspace.update(cx, |workspace, cx| { workspace.flush_serialization(window, cx) @@ -2610,7 +2609,7 @@ mod tests { cx.run_until_parked(); multi_workspace_1 .update(cx, |multi_workspace, _window, cx| { - assert_eq!(multi_workspace.workspaces().len(), 2); + assert_eq!(multi_workspace.workspaces().count(), 2); assert!(multi_workspace.sidebar_open()); let workspace = multi_workspace.workspace().read(cx); assert_eq!( @@ -5512,6 +5511,11 @@ mod tests { let project = project1.clone(); |window, cx| MultiWorkspace::test_new(project, window, cx) }); + window + .update(cx, |multi_workspace, _, cx| { + multi_workspace.open_sidebar(cx); + }) + .unwrap(); cx.run_until_parked(); assert_eq!(cx.windows().len(), 1, "Should start with 1 window"); @@ -5534,7 +5538,7 @@ mod tests { let workspace1 = window .read_with(cx, |multi_workspace, _| { - multi_workspace.workspaces()[0].clone() + multi_workspace.workspaces().next().unwrap().clone() }) .unwrap(); @@ -5543,8 +5547,8 @@ mod tests { multi_workspace.activate(workspace2.clone(), window, cx); multi_workspace.activate(workspace3.clone(), window, cx); // Switch back to workspace1 for test setup - multi_workspace.activate(workspace1, window, cx); - assert_eq!(multi_workspace.active_workspace_index(), 0); + multi_workspace.activate(workspace1.clone(), window, cx); + assert_eq!(multi_workspace.workspace(), &workspace1); }) .unwrap(); @@ -5553,8 +5557,8 @@ mod tests { // Verify setup: 3 workspaces, workspace 0 active, still 1 window window .read_with(cx, |multi_workspace, _| { - assert_eq!(multi_workspace.workspaces().len(), 3); - assert_eq!(multi_workspace.active_workspace_index(), 0); + assert_eq!(multi_workspace.workspaces().count(), 3); + assert_eq!(multi_workspace.workspace(), &workspace1); }) .unwrap(); assert_eq!(cx.windows().len(), 1); @@ -5577,8 +5581,8 @@ mod tests { window .read_with(cx, |multi_workspace, cx| { assert_eq!( - multi_workspace.active_workspace_index(), - 2, + multi_workspace.workspace(), + &workspace3, "Should have switched to workspace 3 which contains /dir3" ); let active_item = multi_workspace @@ -5611,8 +5615,8 @@ mod tests { window .read_with(cx, |multi_workspace, cx| { assert_eq!( - multi_workspace.active_workspace_index(), - 1, + multi_workspace.workspace(), + &workspace2, "Should have switched to workspace 2 which contains /dir2" ); let active_item = multi_workspace @@ -5660,8 +5664,8 @@ mod tests { window .read_with(cx, |multi_workspace, cx| { assert_eq!( - multi_workspace.active_workspace_index(), - 0, + multi_workspace.workspace(), + &workspace1, "Should have switched back to workspace 0 which contains /dir1" ); let active_item = multi_workspace @@ -5711,6 +5715,11 @@ mod tests { let project = project1.clone(); |window, cx| MultiWorkspace::test_new(project, window, cx) }); + window1 + .update(cx, |multi_workspace, _, cx| { + multi_workspace.open_sidebar(cx); + }) + .unwrap(); cx.run_until_parked(); @@ -5737,6 +5746,11 @@ mod tests { let project = project3.clone(); |window, cx| MultiWorkspace::test_new(project, window, cx) }); + window2 + .update(cx, |multi_workspace, _, cx| { + multi_workspace.open_sidebar(cx); + }) + .unwrap(); cx.run_until_parked(); assert_eq!(cx.windows().len(), 2); @@ -5771,7 +5785,7 @@ mod tests { // Verify workspace1_1 is active window1 .read_with(cx, |multi_workspace, _| { - assert_eq!(multi_workspace.active_workspace_index(), 0); + assert_eq!(multi_workspace.workspace(), &workspace1_1); }) .unwrap(); @@ -5837,7 +5851,7 @@ mod tests { // Verify workspace1_1 is still active (not workspace1_2 with dirty item) window1 .read_with(cx, |multi_workspace, _| { - assert_eq!(multi_workspace.active_workspace_index(), 0); + assert_eq!(multi_workspace.workspace(), &workspace1_1); }) .unwrap(); @@ -5848,8 +5862,8 @@ mod tests { window1 .read_with(cx, |multi_workspace, _| { assert_eq!( - multi_workspace.active_workspace_index(), - 1, + multi_workspace.workspace(), + &workspace1_2, "Case 2: Non-active workspace should be activated when it has dirty item" ); }) @@ -6002,6 +6016,12 @@ mod tests { .await .expect("failed to open first workspace"); + window_a + .update(cx, |multi_workspace, _, cx| { + multi_workspace.open_sidebar(cx); + }) + .unwrap(); + window_a .update(cx, |multi_workspace, window, cx| { multi_workspace.open_project(vec![dir2.into()], OpenMode::Activate, window, cx) @@ -6028,13 +6048,19 @@ mod tests { .await .expect("failed to open third workspace"); + window_b + .update(cx, |multi_workspace, _, cx| { + multi_workspace.open_sidebar(cx); + }) + .unwrap(); + // Currently dir2 is active because it was added last. // So, switch window_a's active workspace to dir1 (index 0). // This sets up a non-trivial assertion: after restore, dir1 should // still be active rather than whichever workspace happened to restore last. window_a .update(cx, |multi_workspace, window, cx| { - let workspace = multi_workspace.workspaces()[0].clone(); + let workspace = multi_workspace.workspaces().next().unwrap().clone(); multi_workspace.activate(workspace, window, cx); }) .unwrap(); @@ -6150,7 +6176,7 @@ mod tests { ProjectGroupKey::new(None, PathList::new(&[dir2])), ] ); - assert_eq!(mw.workspaces().len(), 1); + assert_eq!(mw.workspaces().count(), 1); }) .unwrap(); @@ -6161,7 +6187,7 @@ mod tests { mw.project_group_keys().cloned().collect::>(), vec![ProjectGroupKey::new(None, PathList::new(&[dir3]))] ); - assert_eq!(mw.workspaces().len(), 1); + assert_eq!(mw.workspaces().count(), 1); }) .unwrap(); } diff --git a/crates/zeta_prompt/Cargo.toml b/crates/zeta_prompt/Cargo.toml index 21634583d33e13cd9570041f3e8466d05cef9944..8acd91a7a43613fd63f4f46ab73e9485fd64e7d2 100644 --- a/crates/zeta_prompt/Cargo.toml +++ b/crates/zeta_prompt/Cargo.toml @@ -13,6 +13,7 @@ path = "src/zeta_prompt.rs" [dependencies] anyhow.workspace = true +imara-diff.workspace = true serde.workspace = true strum.workspace = true diff --git a/crates/zeta_prompt/src/udiff.rs b/crates/zeta_prompt/src/udiff.rs index 2658da5893ee923dc0f5798554276f5735abb51a..ab0837b9f54ac0bf9ef74038f0c876b751f70200 100644 --- a/crates/zeta_prompt/src/udiff.rs +++ b/crates/zeta_prompt/src/udiff.rs @@ -6,6 +6,10 @@ use std::{ }; use anyhow::{Context as _, Result, anyhow}; +use imara_diff::{ + Algorithm, Sink, diff, + intern::{InternedInput, Interner, Token}, +}; pub fn strip_diff_path_prefix<'a>(diff: &'a str, prefix: &str) -> Cow<'a, str> { if prefix.is_empty() { @@ -221,6 +225,181 @@ pub fn disambiguate_by_line_number( } } +pub fn unified_diff_with_context( + old_text: &str, + new_text: &str, + old_start_line: u32, + new_start_line: u32, + context_lines: u32, +) -> String { + let input = InternedInput::new(old_text, new_text); + diff( + Algorithm::Histogram, + &input, + OffsetUnifiedDiffBuilder::new(&input, old_start_line, new_start_line, context_lines), + ) +} + +struct OffsetUnifiedDiffBuilder<'a> { + before: &'a [Token], + after: &'a [Token], + interner: &'a Interner<&'a str>, + pos: u32, + before_hunk_start: u32, + after_hunk_start: u32, + before_hunk_len: u32, + after_hunk_len: u32, + old_line_offset: u32, + new_line_offset: u32, + context_lines: u32, + buffer: String, + dst: String, +} + +impl<'a> OffsetUnifiedDiffBuilder<'a> { + fn new( + input: &'a InternedInput<&'a str>, + old_line_offset: u32, + new_line_offset: u32, + context_lines: u32, + ) -> Self { + Self { + before_hunk_start: 0, + after_hunk_start: 0, + before_hunk_len: 0, + after_hunk_len: 0, + old_line_offset, + new_line_offset, + context_lines, + buffer: String::with_capacity(8), + dst: String::new(), + interner: &input.interner, + before: &input.before, + after: &input.after, + pos: 0, + } + } + + fn print_tokens(&mut self, tokens: &[Token], prefix: char) { + for &token in tokens { + writeln!(&mut self.buffer, "{prefix}{}", self.interner[token]).unwrap(); + } + } + + fn flush(&mut self) { + if self.before_hunk_len == 0 && self.after_hunk_len == 0 { + return; + } + + let end = (self.pos + self.context_lines).min(self.before.len() as u32); + self.update_pos(end, end); + + writeln!( + &mut self.dst, + "@@ -{},{} +{},{} @@", + self.before_hunk_start + 1 + self.old_line_offset, + self.before_hunk_len, + self.after_hunk_start + 1 + self.new_line_offset, + self.after_hunk_len, + ) + .unwrap(); + write!(&mut self.dst, "{}", &self.buffer).unwrap(); + self.buffer.clear(); + self.before_hunk_len = 0; + self.after_hunk_len = 0; + } + + fn update_pos(&mut self, print_to: u32, move_to: u32) { + self.print_tokens(&self.before[self.pos as usize..print_to as usize], ' '); + let len = print_to - self.pos; + self.before_hunk_len += len; + self.after_hunk_len += len; + self.pos = move_to; + } +} + +impl Sink for OffsetUnifiedDiffBuilder<'_> { + type Out = String; + + fn process_change(&mut self, before: Range, after: Range) { + if before.start - self.pos > self.context_lines * 2 { + self.flush(); + } + if self.before_hunk_len == 0 && self.after_hunk_len == 0 { + self.pos = before.start.saturating_sub(self.context_lines); + self.before_hunk_start = self.pos; + self.after_hunk_start = after.start.saturating_sub(self.context_lines); + } + + self.update_pos(before.start, before.end); + self.before_hunk_len += before.end - before.start; + self.after_hunk_len += after.end - after.start; + self.print_tokens( + &self.before[before.start as usize..before.end as usize], + '-', + ); + self.print_tokens(&self.after[after.start as usize..after.end as usize], '+'); + } + + fn finish(mut self) -> Self::Out { + self.flush(); + self.dst + } +} + +pub fn encode_cursor_in_patch(patch: &str, cursor_offset: Option) -> String { + let Some(cursor_offset) = cursor_offset else { + return patch.to_string(); + }; + + let mut result = String::new(); + let mut line_start_offset = 0usize; + + for line in patch.lines() { + if matches!( + DiffLine::parse(line), + DiffLine::Garbage(content) + if content.starts_with('#') && content.contains(CURSOR_POSITION_MARKER) + ) { + continue; + } + + if !result.is_empty() { + result.push('\n'); + } + result.push_str(line); + + match DiffLine::parse(line) { + DiffLine::Addition(content) => { + let line_end_offset = line_start_offset + content.len(); + + if cursor_offset >= line_start_offset && cursor_offset <= line_end_offset { + let cursor_column = cursor_offset - line_start_offset; + + result.push('\n'); + result.push('#'); + for _ in 0..cursor_column { + result.push(' '); + } + write!(result, "^{}", CURSOR_POSITION_MARKER).unwrap(); + } + + line_start_offset = line_end_offset + 1; + } + DiffLine::Context(content) => { + line_start_offset += content.len() + 1; + } + _ => {} + } + } + + if patch.ends_with('\n') { + result.push('\n'); + } + + result +} + pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result { apply_diff_to_string_with_hunk_offset(diff_str, text).map(|(text, _)| text) } @@ -1203,4 +1382,25 @@ mod tests { // Edit range end should be clamped to 7 (new context length). assert_eq!(hunk.edits[0].range, 4..7); } + + #[test] + fn test_unified_diff_with_context_matches_expected_context_window() { + let old_text = "line1\nline2\nline3\nline4\nline5\nCHANGE_ME\nline7\nline8\n"; + let new_text = "line1\nline2\nline3\nline4\nline5\nCHANGED\nline7\nline8\n"; + + let diff_default = unified_diff_with_context(old_text, new_text, 0, 0, 3); + assert_eq!( + diff_default, + "@@ -3,6 +3,6 @@\n line3\n line4\n line5\n-CHANGE_ME\n+CHANGED\n line7\n line8\n" + ); + + let diff_full_context = unified_diff_with_context(old_text, new_text, 0, 0, 8); + assert_eq!( + diff_full_context, + "@@ -1,8 +1,8 @@\n line1\n line2\n line3\n line4\n line5\n-CHANGE_ME\n+CHANGED\n line7\n line8\n" + ); + + let diff_no_context = unified_diff_with_context(old_text, new_text, 0, 0, 0); + assert_eq!(diff_no_context, "@@ -6,1 +6,1 @@\n-CHANGE_ME\n+CHANGED\n"); + } } diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index 0d72d6cd7a46782aa4b572a4ef564d5fe3dec417..49b86404a8ad49c27e29bb2b887fb3fc8171c35c 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -106,10 +106,19 @@ impl std::fmt::Display for ZetaFormat { impl ZetaFormat { pub fn parse(format_name: &str) -> Result { + let lower = format_name.to_lowercase(); + + // Exact case-insensitive match takes priority, bypassing ambiguity checks. + for variant in ZetaFormat::iter() { + if <&'static str>::from(&variant).to_lowercase() == lower { + return Ok(variant); + } + } + let mut results = ZetaFormat::iter().filter(|version| { <&'static str>::from(version) .to_lowercase() - .contains(&format_name.to_lowercase()) + .contains(&lower) }); let Some(result) = results.next() else { anyhow::bail!( @@ -927,11 +936,39 @@ fn cursor_in_new_text( }) } +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct ParsedOutput { /// Text that should replace the editable region pub new_editable_region: String, /// The byte range within `cursor_excerpt` that this replacement applies to pub range_in_excerpt: Range, + /// Byte offset of the cursor marker within `new_editable_region`, if present + pub cursor_offset_in_new_editable_region: Option, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct CursorPosition { + pub path: String, + pub row: usize, + pub column: usize, + pub offset: usize, + pub editable_region_offset: usize, +} + +pub fn parsed_output_from_editable_region( + range_in_excerpt: Range, + mut new_editable_region: String, +) -> ParsedOutput { + let cursor_offset_in_new_editable_region = new_editable_region.find(CURSOR_MARKER); + if let Some(offset) = cursor_offset_in_new_editable_region { + new_editable_region.replace_range(offset..offset + CURSOR_MARKER.len(), ""); + } + + ParsedOutput { + new_editable_region, + range_in_excerpt, + cursor_offset_in_new_editable_region, + } } /// Parse model output for the given zeta format @@ -999,12 +1036,97 @@ pub fn parse_zeta2_model_output( let range_in_excerpt = range_in_context.start + context_start..range_in_context.end + context_start; - Ok(ParsedOutput { - new_editable_region: output, - range_in_excerpt, + Ok(parsed_output_from_editable_region(range_in_excerpt, output)) +} + +pub fn parse_zeta2_model_output_as_patch( + output: &str, + format: ZetaFormat, + prompt_inputs: &ZetaPromptInput, +) -> Result { + let parsed = parse_zeta2_model_output(output, format, prompt_inputs)?; + parsed_output_to_patch(prompt_inputs, parsed) +} + +pub fn cursor_position_from_parsed_output( + prompt_inputs: &ZetaPromptInput, + parsed: &ParsedOutput, +) -> Option { + let cursor_offset = parsed.cursor_offset_in_new_editable_region?; + let editable_region_offset = parsed.range_in_excerpt.start; + let excerpt = prompt_inputs.cursor_excerpt.as_ref(); + + let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count(); + + let new_editable_region = &parsed.new_editable_region; + let prefix_end = cursor_offset.min(new_editable_region.len()); + let new_region_prefix = &new_editable_region[..prefix_end]; + + let row = editable_region_start_line + new_region_prefix.matches('\n').count(); + + let column = match new_region_prefix.rfind('\n') { + Some(last_newline) => cursor_offset - last_newline - 1, + None => { + let content_prefix = &excerpt[..editable_region_offset]; + let content_column = match content_prefix.rfind('\n') { + Some(last_newline) => editable_region_offset - last_newline - 1, + None => editable_region_offset, + }; + content_column + cursor_offset + } + }; + + Some(CursorPosition { + path: prompt_inputs.cursor_path.to_string_lossy().into_owned(), + row, + column, + offset: editable_region_offset + cursor_offset, + editable_region_offset: cursor_offset, }) } +pub fn parsed_output_to_patch( + prompt_inputs: &ZetaPromptInput, + parsed: ParsedOutput, +) -> Result { + let range_in_excerpt = parsed.range_in_excerpt; + let excerpt = prompt_inputs.cursor_excerpt.as_ref(); + let old_text = excerpt[range_in_excerpt.clone()].to_string(); + let mut new_text = parsed.new_editable_region; + + let mut old_text_normalized = old_text; + if !new_text.is_empty() && !new_text.ends_with('\n') { + new_text.push('\n'); + } + if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') { + old_text_normalized.push('\n'); + } + + let editable_region_offset = range_in_excerpt.start; + let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count() as u32; + let editable_region_lines = old_text_normalized.lines().count() as u32; + + let diff = udiff::unified_diff_with_context( + &old_text_normalized, + &new_text, + editable_region_start_line, + editable_region_start_line, + editable_region_lines, + ); + + let path = prompt_inputs + .cursor_path + .to_string_lossy() + .trim_start_matches('/') + .to_string(); + let formatted_diff = format!("--- a/{path}\n+++ b/{path}\n{diff}"); + + Ok(udiff::encode_cursor_in_patch( + &formatted_diff, + parsed.cursor_offset_in_new_editable_region, + )) +} + pub fn excerpt_range_for_format( format: ZetaFormat, ranges: &ExcerptRanges, @@ -5400,6 +5522,33 @@ mod tests { assert_eq!(apply_edit(excerpt, &output1), "new content\n"); } + #[test] + fn test_parsed_output_to_patch_round_trips_through_udiff_application() { + let excerpt = "before ctx\nctx start\neditable old\nctx end\nafter ctx\n"; + let context_start = excerpt.find("ctx start").unwrap(); + let context_end = excerpt.find("after ctx").unwrap(); + let editable_start = excerpt.find("editable old").unwrap(); + let editable_end = editable_start + "editable old\n".len(); + let input = make_input_with_context_range( + excerpt, + editable_start..editable_end, + context_start..context_end, + editable_start, + ); + + let parsed = parse_zeta2_model_output( + "editable new\n>>>>>>> UPDATED\n", + ZetaFormat::V0131GitMergeMarkersPrefix, + &input, + ) + .unwrap(); + let expected = apply_edit(excerpt, &parsed); + let patch = parsed_output_to_patch(&input, parsed).unwrap(); + let patched = udiff::apply_diff_to_string(&patch, excerpt).unwrap(); + + assert_eq!(patched, expected); + } + #[test] fn test_special_tokens_not_triggered_by_comment_separator() { // Regression test for https://github.com/zed-industries/zed/issues/52489 diff --git a/docs/src/vim.md b/docs/src/vim.md index 1798f16a93244f2694b30ffa70119da1e4498fdc..8e93edff081681a3e094c811e2d76822766ef67e 100644 --- a/docs/src/vim.md +++ b/docs/src/vim.md @@ -562,6 +562,7 @@ You can change the following settings to modify vim mode's behavior: | use_system_clipboard | Determines how system clipboard is used:
  • "always": use for all operations
  • "never": only use when explicitly specified
  • "on_yank": use for yank operations
| "always" | | use_multiline_find | deprecated | | use_smartcase_find | If `true`, `f` and `t` motions are case-insensitive when the target letter is lowercase. | false | +| use_regex_search | If `true`, then vim search will use regex mode | true | | gdefault | If `true`, the `:substitute` command replaces all matches in a line by default (as if `g` flag was given). The `g` flag then toggles this, replacing only the first match. | false | | toggle_relative_line_numbers | If `true`, line numbers are relative in normal mode and absolute in insert mode, giving you the best of both options. | false | | custom_digraphs | An object that allows you to add custom digraphs. Read below for an example. | {} | @@ -587,6 +588,7 @@ Here's an example of these settings changed: "default_mode": "insert", "use_system_clipboard": "never", "use_smartcase_find": true, + "use_regex_search": true, "gdefault": true, "toggle_relative_line_numbers": true, "highlight_on_yank_duration": 50, diff --git a/tooling/compliance/Cargo.toml b/tooling/compliance/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..9b1ade359daa4b7a02beff861c94e01fff071f84 --- /dev/null +++ b/tooling/compliance/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "compliance" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[features] +octo-client = ["dep:octocrab", "dep:jsonwebtoken", "dep:futures", "dep:tokio"] + +[dependencies] +anyhow.workspace = true +async-trait.workspace = true +derive_more.workspace = true +futures = { workspace = true, optional = true } +itertools.workspace = true +jsonwebtoken = { version = "10.2", features = ["use_pem"], optional = true } +octocrab = { version = "0.49", default-features = false, features = [ + "default-client", + "jwt-aws-lc-rs", + "retry", + "rustls", + "rustls-aws-lc-rs", + "stream", + "timeout" +], optional = true } +regex.workspace = true +semver.workspace = true +serde.workspace = true +serde_json.workspace = true +tokio = { workspace = true, optional = true } + +[dev-dependencies] +indoc.workspace = true +tokio = { workspace = true, features = ["rt", "macros"] } diff --git a/tooling/compliance/LICENSE-GPL b/tooling/compliance/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/tooling/compliance/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/tooling/compliance/src/checks.rs b/tooling/compliance/src/checks.rs new file mode 100644 index 0000000000000000000000000000000000000000..a0623fbbbc179edf9f5b6d777b3116ff498f0265 --- /dev/null +++ b/tooling/compliance/src/checks.rs @@ -0,0 +1,647 @@ +use std::{fmt, ops::Not as _}; + +use itertools::Itertools as _; + +use crate::{ + git::{CommitDetails, CommitList}, + github::{ + CommitAuthor, GitHubClient, GitHubUser, GithubLogin, PullRequestComment, PullRequestData, + PullRequestReview, ReviewState, + }, + report::Report, +}; + +const ZED_ZIPPY_COMMENT_APPROVAL_PATTERN: &str = "@zed-zippy approve"; +const ZED_ZIPPY_GROUP_APPROVAL: &str = "@zed-industries/approved"; + +#[derive(Debug)] +pub enum ReviewSuccess { + ApprovingComment(Vec), + CoAuthored(Vec), + ExternalMergedContribution { merged_by: GitHubUser }, + PullRequestReviewed(Vec), +} + +impl ReviewSuccess { + pub(crate) fn reviewers(&self) -> anyhow::Result { + let reviewers = match self { + Self::CoAuthored(authors) => authors.iter().map(ToString::to_string).collect_vec(), + Self::PullRequestReviewed(reviews) => reviews + .iter() + .filter_map(|review| review.user.as_ref()) + .map(|user| format!("@{}", user.login)) + .collect_vec(), + Self::ApprovingComment(comments) => comments + .iter() + .map(|comment| format!("@{}", comment.user.login)) + .collect_vec(), + Self::ExternalMergedContribution { merged_by } => { + vec![format!("@{}", merged_by.login)] + } + }; + + let reviewers = reviewers.into_iter().unique().collect_vec(); + + reviewers + .is_empty() + .not() + .then(|| reviewers.join(", ")) + .ok_or_else(|| anyhow::anyhow!("Expected at least one reviewer")) + } +} + +impl fmt::Display for ReviewSuccess { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::CoAuthored(_) => formatter.write_str("Co-authored by an organization member"), + Self::PullRequestReviewed(_) => { + formatter.write_str("Approved by an organization review") + } + Self::ApprovingComment(_) => { + formatter.write_str("Approved by an organization approval comment") + } + Self::ExternalMergedContribution { .. } => { + formatter.write_str("External merged contribution") + } + } + } +} + +#[derive(Debug)] +pub enum ReviewFailure { + // todo: We could still query the GitHub API here to search for one + NoPullRequestFound, + Unreviewed, + UnableToDetermineReviewer, + Other(anyhow::Error), +} + +impl fmt::Display for ReviewFailure { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::NoPullRequestFound => formatter.write_str("No pull request found"), + Self::Unreviewed => formatter + .write_str("No qualifying organization approval found for the pull request"), + Self::UnableToDetermineReviewer => formatter.write_str("Could not determine reviewer"), + Self::Other(error) => write!(formatter, "Failed to inspect review state: {error}"), + } + } +} + +pub(crate) type ReviewResult = Result; + +impl> From for ReviewFailure { + fn from(err: E) -> Self { + Self::Other(anyhow::anyhow!(err)) + } +} + +pub struct Reporter<'a> { + commits: CommitList, + github_client: &'a GitHubClient, +} + +impl<'a> Reporter<'a> { + pub fn new(commits: CommitList, github_client: &'a GitHubClient) -> Self { + Self { + commits, + github_client, + } + } + + /// Method that checks every commit for compliance + async fn check_commit(&self, commit: &CommitDetails) -> Result { + let Some(pr_number) = commit.pr_number() else { + return Err(ReviewFailure::NoPullRequestFound); + }; + + let pull_request = self.github_client.get_pull_request(pr_number).await?; + + if let Some(approval) = self.check_pull_request_approved(&pull_request).await? { + return Ok(approval); + } + + if let Some(approval) = self + .check_approving_pull_request_comment(&pull_request) + .await? + { + return Ok(approval); + } + + if let Some(approval) = self.check_commit_co_authors(commit).await? { + return Ok(approval); + } + + // if let Some(approval) = self.check_external_merged_pr(pr_number).await? { + // return Ok(approval); + // } + + Err(ReviewFailure::Unreviewed) + } + + async fn check_commit_co_authors( + &self, + commit: &CommitDetails, + ) -> Result, ReviewFailure> { + if commit.co_authors().is_some() + && let Some(commit_authors) = self + .github_client + .get_commit_authors([commit.sha()]) + .await? + .get(commit.sha()) + .and_then(|authors| authors.co_authors()) + { + let mut org_co_authors = Vec::new(); + for co_author in commit_authors { + if let Some(github_login) = co_author.user() + && self + .github_client + .check_org_membership(github_login) + .await? + { + org_co_authors.push(co_author.clone()); + } + } + + Ok(org_co_authors + .is_empty() + .not() + .then_some(ReviewSuccess::CoAuthored(org_co_authors))) + } else { + Ok(None) + } + } + + #[allow(unused)] + async fn check_external_merged_pr( + &self, + pull_request: PullRequestData, + ) -> Result, ReviewFailure> { + if let Some(user) = pull_request.user + && self + .github_client + .check_org_membership(&GithubLogin::new(user.login)) + .await? + .not() + { + pull_request.merged_by.map_or( + Err(ReviewFailure::UnableToDetermineReviewer), + |merged_by| { + Ok(Some(ReviewSuccess::ExternalMergedContribution { + merged_by, + })) + }, + ) + } else { + Ok(None) + } + } + + async fn check_pull_request_approved( + &self, + pull_request: &PullRequestData, + ) -> Result, ReviewFailure> { + let pr_reviews = self + .github_client + .get_pull_request_reviews(pull_request.number) + .await?; + + if !pr_reviews.is_empty() { + let mut org_approving_reviews = Vec::new(); + for review in pr_reviews { + if let Some(github_login) = review.user.as_ref() + && pull_request + .user + .as_ref() + .is_none_or(|pr_user| pr_user.login != github_login.login) + && review + .state + .is_some_and(|state| state == ReviewState::Approved) + && self + .github_client + .check_org_membership(&GithubLogin::new(github_login.login.clone())) + .await? + { + org_approving_reviews.push(review); + } + } + + Ok(org_approving_reviews + .is_empty() + .not() + .then_some(ReviewSuccess::PullRequestReviewed(org_approving_reviews))) + } else { + Ok(None) + } + } + + async fn check_approving_pull_request_comment( + &self, + pull_request: &PullRequestData, + ) -> Result, ReviewFailure> { + let other_comments = self + .github_client + .get_pull_request_comments(pull_request.number) + .await?; + + if !other_comments.is_empty() { + let mut org_approving_comments = Vec::new(); + + for comment in other_comments { + if pull_request + .user + .as_ref() + .is_some_and(|pr_author| pr_author.login != comment.user.login) + && comment.body.as_ref().is_some_and(|body| { + body.contains(ZED_ZIPPY_COMMENT_APPROVAL_PATTERN) + || body.contains(ZED_ZIPPY_GROUP_APPROVAL) + }) + && self + .github_client + .check_org_membership(&GithubLogin::new(comment.user.login.clone())) + .await? + { + org_approving_comments.push(comment); + } + } + + Ok(org_approving_comments + .is_empty() + .not() + .then_some(ReviewSuccess::ApprovingComment(org_approving_comments))) + } else { + Ok(None) + } + } + + pub async fn generate_report(mut self) -> anyhow::Result { + let mut report = Report::new(); + + let commits_to_check = std::mem::take(&mut self.commits); + let total_commits = commits_to_check.len(); + + for (i, commit) in commits_to_check.into_iter().enumerate() { + println!( + "Checking commit {:?} ({current}/{total})", + commit.sha().short(), + current = i + 1, + total = total_commits + ); + + let review_result = self.check_commit(&commit).await; + + if let Err(err) = &review_result { + println!("Commit {:?} failed review: {:?}", commit.sha().short(), err); + } + + report.add(commit, review_result); + } + + Ok(report) + } +} + +#[cfg(test)] +mod tests { + use std::rc::Rc; + use std::str::FromStr; + + use crate::git::{CommitDetails, CommitList, CommitSha}; + use crate::github::{ + AuthorsForCommits, GitHubApiClient, GitHubClient, GitHubUser, GithubLogin, + PullRequestComment, PullRequestData, PullRequestReview, ReviewState, + }; + + use super::{Reporter, ReviewFailure, ReviewSuccess}; + + struct MockGitHubApi { + pull_request: PullRequestData, + reviews: Vec, + comments: Vec, + commit_authors_json: serde_json::Value, + org_members: Vec, + } + + #[async_trait::async_trait(?Send)] + impl GitHubApiClient for MockGitHubApi { + async fn get_pull_request(&self, _pr_number: u64) -> anyhow::Result { + Ok(self.pull_request.clone()) + } + + async fn get_pull_request_reviews( + &self, + _pr_number: u64, + ) -> anyhow::Result> { + Ok(self.reviews.clone()) + } + + async fn get_pull_request_comments( + &self, + _pr_number: u64, + ) -> anyhow::Result> { + Ok(self.comments.clone()) + } + + async fn get_commit_authors( + &self, + _commit_shas: &[&CommitSha], + ) -> anyhow::Result { + serde_json::from_value(self.commit_authors_json.clone()).map_err(Into::into) + } + + async fn check_org_membership(&self, login: &GithubLogin) -> anyhow::Result { + Ok(self + .org_members + .iter() + .any(|member| member == login.as_str())) + } + + async fn ensure_pull_request_has_label( + &self, + _label: &str, + _pr_number: u64, + ) -> anyhow::Result<()> { + Ok(()) + } + } + + fn make_commit( + sha: &str, + author_name: &str, + author_email: &str, + title: &str, + body: &str, + ) -> CommitDetails { + let formatted = format!( + "{sha}|field-delimiter|{author_name}|field-delimiter|{author_email}|field-delimiter|\ + {title}|body-delimiter|{body}|commit-delimiter|" + ); + CommitList::from_str(&formatted) + .expect("test commit should parse") + .into_iter() + .next() + .expect("should have one commit") + } + + fn review(login: &str, state: ReviewState) -> PullRequestReview { + PullRequestReview { + user: Some(GitHubUser { + login: login.to_owned(), + }), + state: Some(state), + } + } + + fn comment(login: &str, body: &str) -> PullRequestComment { + PullRequestComment { + user: GitHubUser { + login: login.to_owned(), + }, + body: Some(body.to_owned()), + } + } + + struct TestScenario { + pull_request: PullRequestData, + reviews: Vec, + comments: Vec, + commit_authors_json: serde_json::Value, + org_members: Vec, + commit: CommitDetails, + } + + impl TestScenario { + fn single_commit() -> Self { + Self { + pull_request: PullRequestData { + number: 1234, + user: Some(GitHubUser { + login: "alice".to_owned(), + }), + merged_by: None, + }, + reviews: vec![], + comments: vec![], + commit_authors_json: serde_json::json!({}), + org_members: vec![], + commit: make_commit( + "abc12345abc12345", + "Alice", + "alice@test.com", + "Fix thing (#1234)", + "", + ), + } + } + + fn with_reviews(mut self, reviews: Vec) -> Self { + self.reviews = reviews; + self + } + + fn with_comments(mut self, comments: Vec) -> Self { + self.comments = comments; + self + } + + fn with_org_members(mut self, members: Vec<&str>) -> Self { + self.org_members = members.into_iter().map(str::to_owned).collect(); + self + } + + fn with_commit_authors_json(mut self, json: serde_json::Value) -> Self { + self.commit_authors_json = json; + self + } + + fn with_commit(mut self, commit: CommitDetails) -> Self { + self.commit = commit; + self + } + + async fn run_scenario(self) -> Result { + let mock = MockGitHubApi { + pull_request: self.pull_request, + reviews: self.reviews, + comments: self.comments, + commit_authors_json: self.commit_authors_json, + org_members: self.org_members, + }; + let client = GitHubClient::new(Rc::new(mock)); + let reporter = Reporter::new(CommitList::default(), &client); + reporter.check_commit(&self.commit).await + } + } + + #[tokio::test] + async fn approved_review_by_org_member_succeeds() { + let result = TestScenario::single_commit() + .with_reviews(vec![review("bob", ReviewState::Approved)]) + .with_org_members(vec!["bob"]) + .run_scenario() + .await; + assert!(matches!(result, Ok(ReviewSuccess::PullRequestReviewed(_)))); + } + + #[tokio::test] + async fn non_approved_review_state_is_not_accepted() { + let result = TestScenario::single_commit() + .with_reviews(vec![review("bob", ReviewState::Other)]) + .with_org_members(vec!["bob"]) + .run_scenario() + .await; + assert!(matches!(result, Err(ReviewFailure::Unreviewed))); + } + + #[tokio::test] + async fn review_by_non_org_member_is_not_accepted() { + let result = TestScenario::single_commit() + .with_reviews(vec![review("bob", ReviewState::Approved)]) + .run_scenario() + .await; + assert!(matches!(result, Err(ReviewFailure::Unreviewed))); + } + + #[tokio::test] + async fn pr_author_own_approval_review_is_rejected() { + let result = TestScenario::single_commit() + .with_reviews(vec![review("alice", ReviewState::Approved)]) + .with_org_members(vec!["alice"]) + .run_scenario() + .await; + assert!(matches!(result, Err(ReviewFailure::Unreviewed))); + } + + #[tokio::test] + async fn pr_author_own_approval_comment_is_rejected() { + let result = TestScenario::single_commit() + .with_comments(vec![comment("alice", "@zed-zippy approve")]) + .with_org_members(vec!["alice"]) + .run_scenario() + .await; + assert!(matches!(result, Err(ReviewFailure::Unreviewed))); + } + + #[tokio::test] + async fn approval_comment_by_org_member_succeeds() { + let result = TestScenario::single_commit() + .with_comments(vec![comment("bob", "@zed-zippy approve")]) + .with_org_members(vec!["bob"]) + .run_scenario() + .await; + assert!(matches!(result, Ok(ReviewSuccess::ApprovingComment(_)))); + } + + #[tokio::test] + async fn group_approval_comment_by_org_member_succeeds() { + let result = TestScenario::single_commit() + .with_comments(vec![comment("bob", "@zed-industries/approved")]) + .with_org_members(vec!["bob"]) + .run_scenario() + .await; + assert!(matches!(result, Ok(ReviewSuccess::ApprovingComment(_)))); + } + + #[tokio::test] + async fn comment_without_approval_pattern_is_not_accepted() { + let result = TestScenario::single_commit() + .with_comments(vec![comment("bob", "looks good")]) + .with_org_members(vec!["bob"]) + .run_scenario() + .await; + assert!(matches!(result, Err(ReviewFailure::Unreviewed))); + } + + #[tokio::test] + async fn commit_without_pr_number_is_no_pr_found() { + let result = TestScenario::single_commit() + .with_commit(make_commit( + "abc12345abc12345", + "Alice", + "alice@test.com", + "Fix thing without PR number", + "", + )) + .run_scenario() + .await; + assert!(matches!(result, Err(ReviewFailure::NoPullRequestFound))); + } + + #[tokio::test] + async fn pr_review_takes_precedence_over_comment() { + let result = TestScenario::single_commit() + .with_reviews(vec![review("bob", ReviewState::Approved)]) + .with_comments(vec![comment("charlie", "@zed-zippy approve")]) + .with_org_members(vec!["bob", "charlie"]) + .run_scenario() + .await; + assert!(matches!(result, Ok(ReviewSuccess::PullRequestReviewed(_)))); + } + + #[tokio::test] + async fn comment_takes_precedence_over_co_author() { + let result = TestScenario::single_commit() + .with_comments(vec![comment("bob", "@zed-zippy approve")]) + .with_commit_authors_json(serde_json::json!({ + "abc12345abc12345": { + "author": { + "name": "Alice", + "email": "alice@test.com", + "user": { "login": "alice" } + }, + "authors": [{ + "name": "Charlie", + "email": "charlie@test.com", + "user": { "login": "charlie" } + }] + } + })) + .with_commit(make_commit( + "abc12345abc12345", + "Alice", + "alice@test.com", + "Fix thing (#1234)", + "Co-authored-by: Charlie ", + )) + .with_org_members(vec!["bob", "charlie"]) + .run_scenario() + .await; + assert!(matches!(result, Ok(ReviewSuccess::ApprovingComment(_)))); + } + + #[tokio::test] + async fn co_author_org_member_succeeds() { + let result = TestScenario::single_commit() + .with_commit_authors_json(serde_json::json!({ + "abc12345abc12345": { + "author": { + "name": "Alice", + "email": "alice@test.com", + "user": { "login": "alice" } + }, + "authors": [{ + "name": "Bob", + "email": "bob@test.com", + "user": { "login": "bob" } + }] + } + })) + .with_commit(make_commit( + "abc12345abc12345", + "Alice", + "alice@test.com", + "Fix thing (#1234)", + "Co-authored-by: Bob ", + )) + .with_org_members(vec!["bob"]) + .run_scenario() + .await; + assert!(matches!(result, Ok(ReviewSuccess::CoAuthored(_)))); + } + + #[tokio::test] + async fn no_reviews_no_comments_no_coauthors_is_unreviewed() { + let result = TestScenario::single_commit().run_scenario().await; + assert!(matches!(result, Err(ReviewFailure::Unreviewed))); + } +} diff --git a/tooling/compliance/src/git.rs b/tooling/compliance/src/git.rs new file mode 100644 index 0000000000000000000000000000000000000000..fa2cb725712de82526d4ce717c2ec3dc97d22885 --- /dev/null +++ b/tooling/compliance/src/git.rs @@ -0,0 +1,591 @@ +#![allow(clippy::disallowed_methods, reason = "This is only used in xtasks")] +use std::{ + fmt::{self, Debug}, + ops::Not, + process::Command, + str::FromStr, + sync::LazyLock, +}; + +use anyhow::{Context, Result, anyhow}; +use derive_more::{Deref, DerefMut, FromStr}; + +use itertools::Itertools; +use regex::Regex; +use semver::Version; +use serde::Deserialize; + +pub trait Subcommand { + type ParsedOutput: FromStr; + + fn args(&self) -> impl IntoIterator; +} + +#[derive(Deref, DerefMut)] +pub struct GitCommand { + #[deref] + #[deref_mut] + subcommand: G, +} + +impl GitCommand { + #[must_use] + pub fn run(subcommand: G) -> Result { + Self { subcommand }.run_impl() + } + + fn run_impl(self) -> Result { + let command_output = Command::new("git") + .args(self.subcommand.args()) + .output() + .context("Failed to spawn command")?; + + if command_output.status.success() { + String::from_utf8(command_output.stdout) + .map_err(|_| anyhow!("Invalid UTF8")) + .and_then(|s| { + G::ParsedOutput::from_str(s.trim()) + .map_err(|e| anyhow!("Failed to parse from string: {e:?}")) + }) + } else { + anyhow::bail!( + "Command failed with exit code {}, stderr: {}", + command_output.status.code().unwrap_or_default(), + String::from_utf8(command_output.stderr).unwrap_or_default() + ) + } + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum ReleaseChannel { + Stable, + Preview, +} + +impl ReleaseChannel { + pub(crate) fn tag_suffix(&self) -> &'static str { + match self { + ReleaseChannel::Stable => "", + ReleaseChannel::Preview => "-pre", + } + } +} + +#[derive(Debug, Clone)] +pub struct VersionTag(Version, ReleaseChannel); + +impl VersionTag { + pub fn parse(input: &str) -> Result { + // Being a bit more lenient for human inputs + let version = input.strip_prefix('v').unwrap_or(input); + + let (version_str, channel) = version + .strip_suffix("-pre") + .map_or((version, ReleaseChannel::Stable), |version_str| { + (version_str, ReleaseChannel::Preview) + }); + + Version::parse(version_str) + .map(|version| Self(version, channel)) + .map_err(|_| anyhow::anyhow!("Failed to parse version from tag!")) + } + + pub fn version(&self) -> &Version { + &self.0 + } +} + +impl ToString for VersionTag { + fn to_string(&self) -> String { + format!( + "v{version}{channel_suffix}", + version = self.0, + channel_suffix = self.1.tag_suffix() + ) + } +} + +#[derive(Debug, Deref, FromStr, PartialEq, Eq, Hash, Deserialize)] +pub struct CommitSha(pub(crate) String); + +impl CommitSha { + pub fn short(&self) -> &str { + self.0.as_str().split_at(8).0 + } +} + +#[derive(Debug)] +pub struct CommitDetails { + sha: CommitSha, + author: Committer, + title: String, + body: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Committer { + name: String, + email: String, +} + +impl Committer { + pub fn new(name: &str, email: &str) -> Self { + Self { + name: name.to_owned(), + email: email.to_owned(), + } + } +} + +impl fmt::Display for Committer { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(formatter, "{} ({})", self.name, self.email) + } +} + +impl CommitDetails { + const BODY_DELIMITER: &str = "|body-delimiter|"; + const COMMIT_DELIMITER: &str = "|commit-delimiter|"; + const FIELD_DELIMITER: &str = "|field-delimiter|"; + const FORMAT_STRING: &str = "%H|field-delimiter|%an|field-delimiter|%ae|field-delimiter|%s|body-delimiter|%b|commit-delimiter|"; + + fn parse(line: &str, body: &str) -> Result { + let Some([sha, author_name, author_email, title]) = + line.splitn(4, Self::FIELD_DELIMITER).collect_array() + else { + return Err(anyhow!("Failed to parse commit fields from input {line}")); + }; + + Ok(CommitDetails { + sha: CommitSha(sha.to_owned()), + author: Committer::new(author_name, author_email), + title: title.to_owned(), + body: body.to_owned(), + }) + } + + pub fn pr_number(&self) -> Option { + // Since we use squash merge, all commit titles end with the '(#12345)' pattern. + // While we could strictly speaking index into this directly, go for a slightly + // less prone approach to errors + const PATTERN: &str = " (#"; + self.title + .rfind(PATTERN) + .and_then(|location| { + self.title[location..] + .find(')') + .map(|relative_end| location + PATTERN.len()..location + relative_end) + }) + .and_then(|range| self.title[range].parse().ok()) + } + + pub(crate) fn co_authors(&self) -> Option> { + static CO_AUTHOR_REGEX: LazyLock = + LazyLock::new(|| Regex::new(r"Co-authored-by: (.+) <(.+)>").unwrap()); + + let mut co_authors = Vec::new(); + + for cap in CO_AUTHOR_REGEX.captures_iter(&self.body.as_ref()) { + let Some((name, email)) = cap + .get(1) + .map(|m| m.as_str()) + .zip(cap.get(2).map(|m| m.as_str())) + else { + continue; + }; + co_authors.push(Committer::new(name, email)); + } + + co_authors.is_empty().not().then_some(co_authors) + } + + pub(crate) fn author(&self) -> &Committer { + &self.author + } + + pub(crate) fn title(&self) -> &str { + &self.title + } + + pub(crate) fn sha(&self) -> &CommitSha { + &self.sha + } +} + +#[derive(Debug, Deref, Default, DerefMut)] +pub struct CommitList(Vec); + +impl CommitList { + pub fn range(&self) -> Option { + self.0 + .first() + .zip(self.0.last()) + .map(|(first, last)| format!("{}..{}", first.sha().0, last.sha().0)) + } +} + +impl IntoIterator for CommitList { + type IntoIter = std::vec::IntoIter; + type Item = CommitDetails; + + fn into_iter(self) -> std::vec::IntoIter { + self.0.into_iter() + } +} + +impl FromStr for CommitList { + type Err = anyhow::Error; + + fn from_str(input: &str) -> Result { + Ok(CommitList( + input + .split(CommitDetails::COMMIT_DELIMITER) + .filter(|commit_details| !commit_details.is_empty()) + .map(|commit_details| { + let (line, body) = commit_details + .trim() + .split_once(CommitDetails::BODY_DELIMITER) + .expect("Missing body delimiter"); + CommitDetails::parse(line, body) + .expect("Parsing from the output should succeed") + }) + .collect(), + )) + } +} + +pub struct GetVersionTags; + +impl Subcommand for GetVersionTags { + type ParsedOutput = VersionTagList; + + fn args(&self) -> impl IntoIterator { + ["tag", "-l", "v*"].map(ToOwned::to_owned) + } +} + +pub struct VersionTagList(Vec); + +impl VersionTagList { + pub fn sorted(mut self) -> Self { + self.0.sort_by(|a, b| a.version().cmp(b.version())); + self + } + + pub fn find_previous_minor_version(&self, version_tag: &VersionTag) -> Option<&VersionTag> { + self.0 + .iter() + .take_while(|tag| tag.version() < version_tag.version()) + .collect_vec() + .into_iter() + .rev() + .find(|tag| { + (tag.version().major < version_tag.version().major + || (tag.version().major == version_tag.version().major + && tag.version().minor < version_tag.version().minor)) + && tag.version().patch == 0 + }) + } +} + +impl FromStr for VersionTagList { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + let version_tags = s.lines().flat_map(VersionTag::parse).collect_vec(); + + version_tags + .is_empty() + .not() + .then_some(Self(version_tags)) + .ok_or_else(|| anyhow::anyhow!("No version tags found")) + } +} + +pub struct CommitsFromVersionToHead { + version_tag: VersionTag, + branch: String, +} + +impl CommitsFromVersionToHead { + pub fn new(version_tag: VersionTag, branch: String) -> Self { + Self { + version_tag, + branch, + } + } +} + +impl Subcommand for CommitsFromVersionToHead { + type ParsedOutput = CommitList; + + fn args(&self) -> impl IntoIterator { + [ + "log".to_string(), + format!("--pretty=format:{}", CommitDetails::FORMAT_STRING), + format!( + "{version}..{branch}", + version = self.version_tag.to_string(), + branch = self.branch + ), + ] + } +} + +pub struct NoOutput; + +impl FromStr for NoOutput { + type Err = anyhow::Error; + + fn from_str(_: &str) -> Result { + Ok(NoOutput) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use indoc::indoc; + + #[test] + fn parse_stable_version_tag() { + let tag = VersionTag::parse("v0.172.8").unwrap(); + assert_eq!(tag.version().major, 0); + assert_eq!(tag.version().minor, 172); + assert_eq!(tag.version().patch, 8); + assert_eq!(tag.1, ReleaseChannel::Stable); + } + + #[test] + fn parse_preview_version_tag() { + let tag = VersionTag::parse("v0.172.1-pre").unwrap(); + assert_eq!(tag.version().major, 0); + assert_eq!(tag.version().minor, 172); + assert_eq!(tag.version().patch, 1); + assert_eq!(tag.1, ReleaseChannel::Preview); + } + + #[test] + fn parse_version_tag_without_v_prefix() { + let tag = VersionTag::parse("0.172.8").unwrap(); + assert_eq!(tag.version().major, 0); + assert_eq!(tag.version().minor, 172); + assert_eq!(tag.version().patch, 8); + } + + #[test] + fn parse_invalid_version_tag() { + let result = VersionTag::parse("vConradTest"); + assert!(result.is_err()); + } + + #[test] + fn version_tag_stable_roundtrip() { + let tag = VersionTag::parse("v0.172.8").unwrap(); + assert_eq!(tag.to_string(), "v0.172.8"); + } + + #[test] + fn version_tag_preview_roundtrip() { + let tag = VersionTag::parse("v0.172.1-pre").unwrap(); + assert_eq!(tag.to_string(), "v0.172.1-pre"); + } + + #[test] + fn sorted_orders_by_semver() { + let input = indoc! {" + v0.172.8 + v0.170.1 + v0.171.4 + v0.170.2 + v0.172.11 + v0.171.3 + v0.172.9 + "}; + let list = VersionTagList::from_str(input).unwrap().sorted(); + for window in list.0.windows(2) { + assert!( + window[0].version() <= window[1].version(), + "{} should come before {}", + window[0].to_string(), + window[1].to_string() + ); + } + assert_eq!(list.0[0].to_string(), "v0.170.1"); + assert_eq!(list.0[list.0.len() - 1].to_string(), "v0.172.11"); + } + + #[test] + fn find_previous_minor_for_173_returns_172() { + let input = indoc! {" + v0.170.1 + v0.170.2 + v0.171.3 + v0.171.4 + v0.172.0 + v0.172.8 + v0.172.9 + v0.172.11 + "}; + let list = VersionTagList::from_str(input).unwrap().sorted(); + let target = VersionTag::parse("v0.173.0").unwrap(); + let previous = list.find_previous_minor_version(&target).unwrap(); + assert_eq!(previous.version().major, 0); + assert_eq!(previous.version().minor, 172); + assert_eq!(previous.version().patch, 0); + } + + #[test] + fn find_previous_minor_skips_same_minor() { + let input = indoc! {" + v0.172.8 + v0.172.9 + v0.172.11 + "}; + let list = VersionTagList::from_str(input).unwrap().sorted(); + let target = VersionTag::parse("v0.172.8").unwrap(); + assert!(list.find_previous_minor_version(&target).is_none()); + } + + #[test] + fn find_previous_minor_with_major_version_gap() { + let input = indoc! {" + v0.172.0 + v0.172.9 + v0.172.11 + "}; + let list = VersionTagList::from_str(input).unwrap().sorted(); + let target = VersionTag::parse("v1.0.0").unwrap(); + let previous = list.find_previous_minor_version(&target).unwrap(); + assert_eq!(previous.to_string(), "v0.172.0"); + } + + #[test] + fn find_previous_minor_requires_zero_patch_version() { + let input = indoc! {" + v0.172.1 + v0.172.9 + v0.172.11 + "}; + let list = VersionTagList::from_str(input).unwrap().sorted(); + let target = VersionTag::parse("v1.0.0").unwrap(); + assert!(list.find_previous_minor_version(&target).is_none()); + } + + #[test] + fn parse_tag_list_from_real_tags() { + let input = indoc! {" + v0.9999-temporary + vConradTest + v0.172.8 + "}; + let list = VersionTagList::from_str(input).unwrap(); + assert_eq!(list.0.len(), 1); + assert_eq!(list.0[0].to_string(), "v0.172.8"); + } + + #[test] + fn parse_empty_tag_list_fails() { + let result = VersionTagList::from_str(""); + assert!(result.is_err()); + } + + #[test] + fn pr_number_from_squash_merge_title() { + let line = format!( + "abc123{d}Author Name{d}author@email.com{d}Add cool feature (#12345)", + d = CommitDetails::FIELD_DELIMITER + ); + let commit = CommitDetails::parse(&line, "").unwrap(); + assert_eq!(commit.pr_number(), Some(12345)); + } + + #[test] + fn pr_number_missing() { + let line = format!( + "abc123{d}Author Name{d}author@email.com{d}Some commit without PR ref", + d = CommitDetails::FIELD_DELIMITER + ); + let commit = CommitDetails::parse(&line, "").unwrap(); + assert_eq!(commit.pr_number(), None); + } + + #[test] + fn pr_number_takes_last_match() { + let line = format!( + "abc123{d}Author Name{d}author@email.com{d}Fix (#123) and refactor (#456)", + d = CommitDetails::FIELD_DELIMITER + ); + let commit = CommitDetails::parse(&line, "").unwrap(); + assert_eq!(commit.pr_number(), Some(456)); + } + + #[test] + fn co_authors_parsed_from_body() { + let line = format!( + "abc123{d}Author Name{d}author@email.com{d}Some title", + d = CommitDetails::FIELD_DELIMITER + ); + let body = indoc! {" + Co-authored-by: Alice Smith + Co-authored-by: Bob Jones + "}; + let commit = CommitDetails::parse(&line, body).unwrap(); + let co_authors = commit.co_authors().unwrap(); + assert_eq!(co_authors.len(), 2); + assert_eq!( + co_authors[0], + Committer::new("Alice Smith", "alice@example.com") + ); + assert_eq!( + co_authors[1], + Committer::new("Bob Jones", "bob@example.com") + ); + } + + #[test] + fn no_co_authors_returns_none() { + let line = format!( + "abc123{d}Author Name{d}author@email.com{d}Some title", + d = CommitDetails::FIELD_DELIMITER + ); + let commit = CommitDetails::parse(&line, "").unwrap(); + assert!(commit.co_authors().is_none()); + } + + #[test] + fn commit_sha_short_returns_first_8_chars() { + let sha = CommitSha("abcdef1234567890abcdef1234567890abcdef12".into()); + assert_eq!(sha.short(), "abcdef12"); + } + + #[test] + fn parse_commit_list_from_git_log_format() { + let fd = CommitDetails::FIELD_DELIMITER; + let bd = CommitDetails::BODY_DELIMITER; + let cd = CommitDetails::COMMIT_DELIMITER; + + let input = format!( + "sha111{fd}Alice{fd}alice@test.com{fd}First commit (#100){bd}First body{cd}sha222{fd}Bob{fd}bob@test.com{fd}Second commit (#200){bd}Second body{cd}" + ); + + let list = CommitList::from_str(&input).unwrap(); + assert_eq!(list.0.len(), 2); + + assert_eq!(list.0[0].sha().0, "sha111"); + assert_eq!( + list.0[0].author(), + &Committer::new("Alice", "alice@test.com") + ); + assert_eq!(list.0[0].title(), "First commit (#100)"); + assert_eq!(list.0[0].pr_number(), Some(100)); + assert_eq!(list.0[0].body, "First body"); + + assert_eq!(list.0[1].sha().0, "sha222"); + assert_eq!(list.0[1].author(), &Committer::new("Bob", "bob@test.com")); + assert_eq!(list.0[1].title(), "Second commit (#200)"); + assert_eq!(list.0[1].pr_number(), Some(200)); + assert_eq!(list.0[1].body, "Second body"); + } +} diff --git a/tooling/compliance/src/github.rs b/tooling/compliance/src/github.rs new file mode 100644 index 0000000000000000000000000000000000000000..ebd2f2c75f5d0083632a8f70e3ea9dd2680d4eb5 --- /dev/null +++ b/tooling/compliance/src/github.rs @@ -0,0 +1,424 @@ +use std::{collections::HashMap, fmt, ops::Not, rc::Rc}; + +use anyhow::Result; +use derive_more::Deref; +use serde::Deserialize; + +use crate::git::CommitSha; + +pub const PR_REVIEW_LABEL: &str = "PR state:needs review"; + +#[derive(Debug, Clone)] +pub struct GitHubUser { + pub login: String, +} + +#[derive(Debug, Clone)] +pub struct PullRequestData { + pub number: u64, + pub user: Option, + pub merged_by: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ReviewState { + Approved, + Other, +} + +#[derive(Debug, Clone)] +pub struct PullRequestReview { + pub user: Option, + pub state: Option, +} + +#[derive(Debug, Clone)] +pub struct PullRequestComment { + pub user: GitHubUser, + pub body: Option, +} + +#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)] +pub struct GithubLogin { + login: String, +} + +impl GithubLogin { + pub(crate) fn new(login: String) -> Self { + Self { login } + } +} + +impl fmt::Display for GithubLogin { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(formatter, "@{}", self.login) + } +} + +#[derive(Debug, Deserialize, Clone)] +pub struct CommitAuthor { + name: String, + email: String, + user: Option, +} + +impl CommitAuthor { + pub(crate) fn user(&self) -> Option<&GithubLogin> { + self.user.as_ref() + } +} + +impl PartialEq for CommitAuthor { + fn eq(&self, other: &Self) -> bool { + self.user.as_ref().zip(other.user.as_ref()).map_or_else( + || self.email == other.email || self.name == other.name, + |(l, r)| l == r, + ) + } +} + +impl fmt::Display for CommitAuthor { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.user.as_ref() { + Some(user) => write!(formatter, "{} ({user})", self.name), + None => write!(formatter, "{} ({})", self.name, self.email), + } + } +} + +#[derive(Debug, Deserialize)] +pub struct CommitAuthors { + #[serde(rename = "author")] + primary_author: CommitAuthor, + #[serde(rename = "authors")] + co_authors: Vec, +} + +impl CommitAuthors { + pub fn co_authors(&self) -> Option> { + self.co_authors.is_empty().not().then(|| { + self.co_authors + .iter() + .filter(|co_author| *co_author != &self.primary_author) + }) + } +} + +#[derive(Debug, Deserialize, Deref)] +pub struct AuthorsForCommits(HashMap); + +#[async_trait::async_trait(?Send)] +pub trait GitHubApiClient { + async fn get_pull_request(&self, pr_number: u64) -> Result; + async fn get_pull_request_reviews(&self, pr_number: u64) -> Result>; + async fn get_pull_request_comments(&self, pr_number: u64) -> Result>; + async fn get_commit_authors(&self, commit_shas: &[&CommitSha]) -> Result; + async fn check_org_membership(&self, login: &GithubLogin) -> Result; + async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()>; +} + +pub struct GitHubClient { + api: Rc, +} + +impl GitHubClient { + pub fn new(api: Rc) -> Self { + Self { api } + } + + #[cfg(feature = "octo-client")] + pub async fn for_app(app_id: u64, app_private_key: &str) -> Result { + let client = OctocrabClient::new(app_id, app_private_key).await?; + Ok(Self::new(Rc::new(client))) + } + + pub async fn get_pull_request(&self, pr_number: u64) -> Result { + self.api.get_pull_request(pr_number).await + } + + pub async fn get_pull_request_reviews(&self, pr_number: u64) -> Result> { + self.api.get_pull_request_reviews(pr_number).await + } + + pub async fn get_pull_request_comments( + &self, + pr_number: u64, + ) -> Result> { + self.api.get_pull_request_comments(pr_number).await + } + + pub async fn get_commit_authors<'a>( + &self, + commit_shas: impl IntoIterator, + ) -> Result { + let shas: Vec<&CommitSha> = commit_shas.into_iter().collect(); + self.api.get_commit_authors(&shas).await + } + + pub async fn check_org_membership(&self, login: &GithubLogin) -> Result { + self.api.check_org_membership(login).await + } + + pub async fn add_label_to_pull_request(&self, label: &str, pr_number: u64) -> Result<()> { + self.api + .ensure_pull_request_has_label(label, pr_number) + .await + } +} + +#[cfg(feature = "octo-client")] +mod octo_client { + use anyhow::{Context, Result}; + use futures::TryStreamExt as _; + use itertools::Itertools; + use jsonwebtoken::EncodingKey; + use octocrab::{ + Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState, + service::middleware::cache::mem::InMemoryCache, + }; + use serde::de::DeserializeOwned; + use tokio::pin; + + use crate::git::CommitSha; + + use super::{ + AuthorsForCommits, GitHubApiClient, GitHubUser, GithubLogin, PullRequestComment, + PullRequestData, PullRequestReview, ReviewState, + }; + + const PAGE_SIZE: u8 = 100; + const ORG: &str = "zed-industries"; + const REPO: &str = "zed"; + + pub struct OctocrabClient { + client: Octocrab, + } + + impl OctocrabClient { + pub async fn new(app_id: u64, app_private_key: &str) -> Result { + let octocrab = Octocrab::builder() + .cache(InMemoryCache::new()) + .app( + app_id.into(), + EncodingKey::from_rsa_pem(app_private_key.as_bytes())?, + ) + .build()?; + + let installations = octocrab + .apps() + .installations() + .send() + .await + .context("Failed to fetch installations")? + .take_items(); + + let installation_id = installations + .into_iter() + .find(|installation| installation.account.login == ORG) + .context("Could not find Zed repository in installations")? + .id; + + let client = octocrab.installation(installation_id)?; + Ok(Self { client }) + } + + fn build_co_authors_query<'a>(shas: impl IntoIterator) -> String { + const FRAGMENT: &str = r#" + ... on Commit { + author { + name + email + user { login } + } + authors(first: 10) { + nodes { + name + email + user { login } + } + } + } + "#; + + let objects: String = shas + .into_iter() + .map(|commit_sha| { + format!( + "commit{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}", + sha = **commit_sha + ) + }) + .join("\n"); + + format!("{{ repository(owner: \"{ORG}\", name: \"{REPO}\") {{ {objects} }} }}") + .replace("\n", "") + } + + async fn graphql( + &self, + query: &serde_json::Value, + ) -> octocrab::Result { + self.client.graphql(query).await + } + + async fn get_all( + &self, + page: Page, + ) -> octocrab::Result> { + self.get_filtered(page, |_| true).await + } + + async fn get_filtered( + &self, + page: Page, + predicate: impl Fn(&T) -> bool, + ) -> octocrab::Result> { + let stream = page.into_stream(&self.client); + pin!(stream); + + let mut results = Vec::new(); + + while let Some(item) = stream.try_next().await? + && predicate(&item) + { + results.push(item); + } + + Ok(results) + } + } + + #[async_trait::async_trait(?Send)] + impl GitHubApiClient for OctocrabClient { + async fn get_pull_request(&self, pr_number: u64) -> Result { + let pr = self.client.pulls(ORG, REPO).get(pr_number).await?; + Ok(PullRequestData { + number: pr.number, + user: pr.user.map(|user| GitHubUser { login: user.login }), + merged_by: pr.merged_by.map(|user| GitHubUser { login: user.login }), + }) + } + + async fn get_pull_request_reviews(&self, pr_number: u64) -> Result> { + let page = self + .client + .pulls(ORG, REPO) + .list_reviews(pr_number) + .per_page(PAGE_SIZE) + .send() + .await?; + + let reviews = self.get_all(page).await?; + + Ok(reviews + .into_iter() + .map(|review| PullRequestReview { + user: review.user.map(|user| GitHubUser { login: user.login }), + state: review.state.map(|state| match state { + OctocrabReviewState::Approved => ReviewState::Approved, + _ => ReviewState::Other, + }), + }) + .collect()) + } + + async fn get_pull_request_comments( + &self, + pr_number: u64, + ) -> Result> { + let page = self + .client + .issues(ORG, REPO) + .list_comments(pr_number) + .per_page(PAGE_SIZE) + .send() + .await?; + + let comments = self.get_all(page).await?; + + Ok(comments + .into_iter() + .map(|comment| PullRequestComment { + user: GitHubUser { + login: comment.user.login, + }, + body: comment.body, + }) + .collect()) + } + + async fn get_commit_authors( + &self, + commit_shas: &[&CommitSha], + ) -> Result { + let query = Self::build_co_authors_query(commit_shas.iter().copied()); + let query = serde_json::json!({ "query": query }); + let mut response = self.graphql::(&query).await?; + + response + .get_mut("data") + .and_then(|data| data.get_mut("repository")) + .and_then(|repo| repo.as_object_mut()) + .ok_or_else(|| anyhow::anyhow!("Unexpected response format!")) + .and_then(|commit_data| { + let mut response_map = serde_json::Map::with_capacity(commit_data.len()); + + for (key, value) in commit_data.iter_mut() { + let key_without_prefix = key.strip_prefix("commit").unwrap_or(key); + if let Some(authors) = value.get_mut("authors") { + if let Some(nodes) = authors.get("nodes") { + *authors = nodes.clone(); + } + } + + response_map.insert(key_without_prefix.to_owned(), value.clone()); + } + + serde_json::from_value(serde_json::Value::Object(response_map)) + .context("Failed to deserialize commit authors") + }) + } + + async fn check_org_membership(&self, login: &GithubLogin) -> Result { + let page = self + .client + .orgs(ORG) + .list_members() + .per_page(PAGE_SIZE) + .send() + .await?; + + let members = self.get_all(page).await?; + + Ok(members + .into_iter() + .any(|member| member.login == login.as_str())) + } + + async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()> { + if self + .get_filtered( + self.client + .issues(ORG, REPO) + .list_labels_for_issue(pr_number) + .per_page(PAGE_SIZE) + .send() + .await?, + |pr_label| pr_label.name == label, + ) + .await + .is_ok_and(|l| l.is_empty()) + { + self.client + .issues(ORG, REPO) + .add_labels(pr_number, &[label.to_owned()]) + .await?; + } + + Ok(()) + } + } +} + +#[cfg(feature = "octo-client")] +pub use octo_client::OctocrabClient; diff --git a/tooling/compliance/src/lib.rs b/tooling/compliance/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..9476412c6d6d1f56b1396bf5d700924549c707da --- /dev/null +++ b/tooling/compliance/src/lib.rs @@ -0,0 +1,4 @@ +pub mod checks; +pub mod git; +pub mod github; +pub mod report; diff --git a/tooling/compliance/src/report.rs b/tooling/compliance/src/report.rs new file mode 100644 index 0000000000000000000000000000000000000000..16df145394726b97382884fbdfdc3164c0029786 --- /dev/null +++ b/tooling/compliance/src/report.rs @@ -0,0 +1,446 @@ +use std::{ + fs::{self, File}, + io::{BufWriter, Write}, + path::Path, +}; + +use anyhow::Context as _; +use derive_more::Display; +use itertools::{Either, Itertools}; + +use crate::{ + checks::{ReviewFailure, ReviewResult, ReviewSuccess}, + git::CommitDetails, +}; + +const PULL_REQUEST_BASE_URL: &str = "https://github.com/zed-industries/zed/pull"; + +#[derive(Debug)] +pub struct ReportEntry { + pub commit: CommitDetails, + reason: R, +} + +impl ReportEntry { + fn commit_cell(&self) -> String { + let title = escape_markdown_link_text(self.commit.title()); + + match self.commit.pr_number() { + Some(pr_number) => format!("[{title}]({PULL_REQUEST_BASE_URL}/{pr_number})"), + None => escape_markdown_table_text(self.commit.title()), + } + } + + fn pull_request_cell(&self) -> String { + self.commit + .pr_number() + .map(|pr_number| format!("#{pr_number}")) + .unwrap_or_else(|| "—".to_owned()) + } + + fn author_cell(&self) -> String { + escape_markdown_table_text(&self.commit.author().to_string()) + } + + fn reason_cell(&self) -> String { + escape_markdown_table_text(&self.reason.to_string()) + } +} + +impl ReportEntry { + fn issue_kind(&self) -> IssueKind { + match self.reason { + ReviewFailure::Other(_) => IssueKind::Error, + _ => IssueKind::NotReviewed, + } + } +} + +impl ReportEntry { + fn reviewers_cell(&self) -> String { + match &self.reason.reviewers() { + Ok(reviewers) => escape_markdown_table_text(&reviewers), + Err(_) => "—".to_owned(), + } + } +} + +#[derive(Debug, Default)] +pub struct ReportSummary { + pub pull_requests: usize, + pub reviewed: usize, + pub not_reviewed: usize, + pub errors: usize, +} + +pub enum ReportReviewSummary { + MissingReviews, + MissingReviewsWithErrors, + NoIssuesFound, +} + +impl ReportSummary { + fn from_entries(entries: &[ReportEntry]) -> Self { + Self { + pull_requests: entries + .iter() + .filter_map(|entry| entry.commit.pr_number()) + .unique() + .count(), + reviewed: entries.iter().filter(|entry| entry.reason.is_ok()).count(), + not_reviewed: entries + .iter() + .filter(|entry| { + matches!( + entry.reason, + Err(ReviewFailure::NoPullRequestFound | ReviewFailure::Unreviewed) + ) + }) + .count(), + errors: entries + .iter() + .filter(|entry| matches!(entry.reason, Err(ReviewFailure::Other(_)))) + .count(), + } + } + + pub fn review_summary(&self) -> ReportReviewSummary { + match self.not_reviewed { + 0 if self.errors == 0 => ReportReviewSummary::NoIssuesFound, + 1.. if self.errors == 0 => ReportReviewSummary::MissingReviews, + _ => ReportReviewSummary::MissingReviewsWithErrors, + } + } + + fn has_errors(&self) -> bool { + self.errors > 0 + } +} + +#[derive(Clone, Copy, Debug, Display, PartialEq, Eq, PartialOrd, Ord)] +enum IssueKind { + #[display("Error")] + Error, + #[display("Not reviewed")] + NotReviewed, +} + +#[derive(Debug, Default)] +pub struct Report { + entries: Vec>, +} + +impl Report { + pub fn new() -> Self { + Self::default() + } + + pub fn add(&mut self, commit: CommitDetails, result: ReviewResult) { + self.entries.push(ReportEntry { + commit, + reason: result, + }); + } + + pub fn errors(&self) -> impl Iterator> { + self.entries.iter().filter(|entry| entry.reason.is_err()) + } + + pub fn summary(&self) -> ReportSummary { + ReportSummary::from_entries(&self.entries) + } + + pub fn write_markdown(self, path: impl AsRef) -> anyhow::Result<()> { + let path = path.as_ref(); + + if let Some(parent) = path + .parent() + .filter(|parent| !parent.as_os_str().is_empty()) + { + fs::create_dir_all(parent).with_context(|| { + format!( + "Failed to create parent directory for markdown report at {}", + path.display() + ) + })?; + } + + let summary = self.summary(); + let (successes, mut issues): (Vec<_>, Vec<_>) = + self.entries + .into_iter() + .partition_map(|entry| match entry.reason { + Ok(success) => Either::Left(ReportEntry { + reason: success, + commit: entry.commit, + }), + Err(fail) => Either::Right(ReportEntry { + reason: fail, + commit: entry.commit, + }), + }); + + issues.sort_by_key(|entry| entry.issue_kind()); + + let file = File::create(path) + .with_context(|| format!("Failed to create markdown report at {}", path.display()))?; + let mut writer = BufWriter::new(file); + + writeln!(writer, "# Compliance report")?; + writeln!(writer)?; + writeln!(writer, "## Overview")?; + writeln!(writer)?; + writeln!(writer, "- PRs: {}", summary.pull_requests)?; + writeln!(writer, "- Reviewed: {}", summary.reviewed)?; + writeln!(writer, "- Not reviewed: {}", summary.not_reviewed)?; + if summary.has_errors() { + writeln!(writer, "- Errors: {}", summary.errors)?; + } + writeln!(writer)?; + + write_issue_table(&mut writer, &issues, &summary)?; + write_success_table(&mut writer, &successes)?; + + writer + .flush() + .with_context(|| format!("Failed to flush markdown report to {}", path.display())) + } +} + +fn write_issue_table( + writer: &mut impl Write, + issues: &[ReportEntry], + summary: &ReportSummary, +) -> std::io::Result<()> { + if summary.has_errors() { + writeln!(writer, "## Errors and unreviewed commits")?; + } else { + writeln!(writer, "## Unreviewed commits")?; + } + writeln!(writer)?; + + if issues.is_empty() { + if summary.has_errors() { + writeln!(writer, "No errors or unreviewed commits found.")?; + } else { + writeln!(writer, "No unreviewed commits found.")?; + } + writeln!(writer)?; + return Ok(()); + } + + writeln!(writer, "| Commit | PR | Author | Outcome | Reason |")?; + writeln!(writer, "| --- | --- | --- | --- | --- |")?; + + for entry in issues { + let issue_kind = entry.issue_kind(); + writeln!( + writer, + "| {} | {} | {} | {} | {} |", + entry.commit_cell(), + entry.pull_request_cell(), + entry.author_cell(), + issue_kind, + entry.reason_cell(), + )?; + } + + writeln!(writer)?; + Ok(()) +} + +fn write_success_table( + writer: &mut impl Write, + successful_entries: &[ReportEntry], +) -> std::io::Result<()> { + writeln!(writer, "## Successful commits")?; + writeln!(writer)?; + + if successful_entries.is_empty() { + writeln!(writer, "No successful commits found.")?; + writeln!(writer)?; + return Ok(()); + } + + writeln!(writer, "| Commit | PR | Author | Reviewers | Reason |")?; + writeln!(writer, "| --- | --- | --- | --- | --- |")?; + + for entry in successful_entries { + writeln!( + writer, + "| {} | {} | {} | {} | {} |", + entry.commit_cell(), + entry.pull_request_cell(), + entry.author_cell(), + entry.reviewers_cell(), + entry.reason_cell(), + )?; + } + + writeln!(writer)?; + Ok(()) +} + +fn escape_markdown_link_text(input: &str) -> String { + escape_markdown_table_text(input) + .replace('[', r"\[") + .replace(']', r"\]") +} + +fn escape_markdown_table_text(input: &str) -> String { + input + .replace('\\', r"\\") + .replace('|', r"\|") + .replace('\r', "") + .replace('\n', "
") +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use crate::{ + checks::{ReviewFailure, ReviewSuccess}, + git::{CommitDetails, CommitList}, + github::{GitHubUser, PullRequestReview, ReviewState}, + }; + + use super::{Report, ReportReviewSummary}; + + fn make_commit( + sha: &str, + author_name: &str, + author_email: &str, + title: &str, + body: &str, + ) -> CommitDetails { + let formatted = format!( + "{sha}|field-delimiter|{author_name}|field-delimiter|{author_email}|field-delimiter|{title}|body-delimiter|{body}|commit-delimiter|" + ); + CommitList::from_str(&formatted) + .expect("test commit should parse") + .into_iter() + .next() + .expect("should have one commit") + } + + fn reviewed() -> ReviewSuccess { + ReviewSuccess::PullRequestReviewed(vec![PullRequestReview { + user: Some(GitHubUser { + login: "reviewer".to_owned(), + }), + state: Some(ReviewState::Approved), + }]) + } + + #[test] + fn report_summary_counts_are_accurate() { + let mut report = Report::new(); + + report.add( + make_commit( + "aaa", + "Alice", + "alice@test.com", + "Reviewed commit (#100)", + "", + ), + Ok(reviewed()), + ); + report.add( + make_commit("bbb", "Bob", "bob@test.com", "Unreviewed commit (#200)", ""), + Err(ReviewFailure::Unreviewed), + ); + report.add( + make_commit("ccc", "Carol", "carol@test.com", "No PR commit", ""), + Err(ReviewFailure::NoPullRequestFound), + ); + report.add( + make_commit("ddd", "Dave", "dave@test.com", "Error commit (#300)", ""), + Err(ReviewFailure::Other(anyhow::anyhow!("some error"))), + ); + + let summary = report.summary(); + assert_eq!(summary.pull_requests, 3); + assert_eq!(summary.reviewed, 1); + assert_eq!(summary.not_reviewed, 2); + assert_eq!(summary.errors, 1); + } + + #[test] + fn report_summary_all_reviewed_is_no_issues() { + let mut report = Report::new(); + + report.add( + make_commit("aaa", "Alice", "alice@test.com", "First (#100)", ""), + Ok(reviewed()), + ); + report.add( + make_commit("bbb", "Bob", "bob@test.com", "Second (#200)", ""), + Ok(reviewed()), + ); + + let summary = report.summary(); + assert!(matches!( + summary.review_summary(), + ReportReviewSummary::NoIssuesFound + )); + } + + #[test] + fn report_summary_missing_reviews_only() { + let mut report = Report::new(); + + report.add( + make_commit("aaa", "Alice", "alice@test.com", "Reviewed (#100)", ""), + Ok(reviewed()), + ); + report.add( + make_commit("bbb", "Bob", "bob@test.com", "Unreviewed (#200)", ""), + Err(ReviewFailure::Unreviewed), + ); + + let summary = report.summary(); + assert!(matches!( + summary.review_summary(), + ReportReviewSummary::MissingReviews + )); + } + + #[test] + fn report_summary_errors_and_missing_reviews() { + let mut report = Report::new(); + + report.add( + make_commit("aaa", "Alice", "alice@test.com", "Unreviewed (#100)", ""), + Err(ReviewFailure::Unreviewed), + ); + report.add( + make_commit("bbb", "Bob", "bob@test.com", "Errored (#200)", ""), + Err(ReviewFailure::Other(anyhow::anyhow!("check failed"))), + ); + + let summary = report.summary(); + assert!(matches!( + summary.review_summary(), + ReportReviewSummary::MissingReviewsWithErrors + )); + } + + #[test] + fn report_summary_deduplicates_pull_requests() { + let mut report = Report::new(); + + report.add( + make_commit("aaa", "Alice", "alice@test.com", "First change (#100)", ""), + Ok(reviewed()), + ); + report.add( + make_commit("bbb", "Bob", "bob@test.com", "Second change (#100)", ""), + Ok(reviewed()), + ); + + let summary = report.summary(); + assert_eq!(summary.pull_requests, 1); + } +} diff --git a/tooling/xtask/Cargo.toml b/tooling/xtask/Cargo.toml index 21090d1304ea0eab9ad70808b91f76789f2fd923..f9628dfa6390872210df9f3cc00b367d9420f522 100644 --- a/tooling/xtask/Cargo.toml +++ b/tooling/xtask/Cargo.toml @@ -15,7 +15,8 @@ backtrace.workspace = true cargo_metadata.workspace = true cargo_toml.workspace = true clap = { workspace = true, features = ["derive"] } -toml.workspace = true +compliance = { workspace = true, features = ["octo-client"] } +gh-workflow.workspace = true indoc.workspace = true indexmap.workspace = true itertools.workspace = true @@ -24,5 +25,6 @@ serde.workspace = true serde_json.workspace = true serde_yaml = "0.9.34" strum.workspace = true +tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } +toml.workspace = true toml_edit.workspace = true -gh-workflow.workspace = true diff --git a/tooling/xtask/src/main.rs b/tooling/xtask/src/main.rs index 05afe3c766829137a7c2ba6e73d57638624d5e6a..c442f1c509e28172b7283c95e518eee743b7730c 100644 --- a/tooling/xtask/src/main.rs +++ b/tooling/xtask/src/main.rs @@ -15,6 +15,7 @@ struct Args { enum CliCommand { /// Runs `cargo clippy`. Clippy(tasks::clippy::ClippyArgs), + Compliance(tasks::compliance::ComplianceArgs), Licenses(tasks::licenses::LicensesArgs), /// Checks that packages conform to a set of standards. PackageConformity(tasks::package_conformity::PackageConformityArgs), @@ -31,6 +32,7 @@ fn main() -> Result<()> { match args.command { CliCommand::Clippy(args) => tasks::clippy::run_clippy(args), + CliCommand::Compliance(args) => tasks::compliance::check_compliance(args), CliCommand::Licenses(args) => tasks::licenses::run_licenses(args), CliCommand::PackageConformity(args) => { tasks::package_conformity::run_package_conformity(args) diff --git a/tooling/xtask/src/tasks.rs b/tooling/xtask/src/tasks.rs index 80f504fa0345de0d5bc71c5b44c71846f04c50bc..ea67d0abc5fcbd8e85f40251a7997bc6fbbbca1f 100644 --- a/tooling/xtask/src/tasks.rs +++ b/tooling/xtask/src/tasks.rs @@ -1,4 +1,5 @@ pub mod clippy; +pub mod compliance; pub mod licenses; pub mod package_conformity; pub mod publish_gpui; diff --git a/tooling/xtask/src/tasks/compliance.rs b/tooling/xtask/src/tasks/compliance.rs new file mode 100644 index 0000000000000000000000000000000000000000..78cc32b23f3160ae950aaa5e374071dd107ec350 --- /dev/null +++ b/tooling/xtask/src/tasks/compliance.rs @@ -0,0 +1,135 @@ +use std::path::PathBuf; + +use anyhow::{Context, Result}; +use clap::Parser; + +use compliance::{ + checks::Reporter, + git::{CommitsFromVersionToHead, GetVersionTags, GitCommand, VersionTag}, + github::GitHubClient, + report::ReportReviewSummary, +}; + +#[derive(Parser)] +pub struct ComplianceArgs { + #[arg(value_parser = VersionTag::parse)] + // The version to be on the lookout for + pub(crate) version_tag: VersionTag, + #[arg(long)] + // The markdown file to write the compliance report to + report_path: PathBuf, + #[arg(long)] + // An optional branch to use instead of the determined version branch + branch: Option, +} + +impl ComplianceArgs { + pub(crate) fn version_tag(&self) -> &VersionTag { + &self.version_tag + } + + fn version_branch(&self) -> String { + self.branch.clone().unwrap_or_else(|| { + format!( + "v{major}.{minor}.x", + major = self.version_tag().version().major, + minor = self.version_tag().version().minor + ) + }) + } +} + +async fn check_compliance_impl(args: ComplianceArgs) -> Result<()> { + let app_id = std::env::var("GITHUB_APP_ID").context("Missing GITHUB_APP_ID")?; + let key = std::env::var("GITHUB_APP_KEY").context("Missing GITHUB_APP_KEY")?; + + let tag = args.version_tag(); + + let previous_version = GitCommand::run(GetVersionTags)? + .sorted() + .find_previous_minor_version(&tag) + .cloned() + .ok_or_else(|| { + anyhow::anyhow!( + "Could not find previous version for tag {tag}", + tag = tag.to_string() + ) + })?; + + println!( + "Checking compliance for version {} with version {} as base", + tag.version(), + previous_version.version() + ); + + let commits = GitCommand::run(CommitsFromVersionToHead::new( + previous_version, + args.version_branch(), + ))?; + + let Some(range) = commits.range() else { + anyhow::bail!("No commits found to check"); + }; + + println!("Checking commit range {range}, {} total", commits.len()); + + let client = GitHubClient::for_app( + app_id.parse().context("Failed to parse app ID as int")?, + key.as_ref(), + ) + .await?; + + println!("Initialized GitHub client for app ID {app_id}"); + + let report = Reporter::new(commits, &client).generate_report().await?; + + println!( + "Generated report for version {}", + args.version_tag().to_string() + ); + + let summary = report.summary(); + + println!( + "Applying compliance labels to {} pull requests", + summary.pull_requests + ); + + for report in report.errors() { + if let Some(pr_number) = report.commit.pr_number() { + println!("Adding review label to PR {}...", pr_number); + + client + .add_label_to_pull_request(compliance::github::PR_REVIEW_LABEL, pr_number) + .await?; + } + } + + let report_path = args.report_path.with_extension("md"); + + report.write_markdown(&report_path)?; + + println!("Wrote compliance report to {}", report_path.display()); + + match summary.review_summary() { + ReportReviewSummary::MissingReviews => Err(anyhow::anyhow!( + "Compliance check failed, found {} commits not reviewed", + summary.not_reviewed + )), + ReportReviewSummary::MissingReviewsWithErrors => Err(anyhow::anyhow!( + "Compliance check failed with {} unreviewed commits and {} other issues", + summary.not_reviewed, + summary.errors + )), + ReportReviewSummary::NoIssuesFound => { + println!("No issues found, compliance check passed."); + Ok(()) + } + } +} + +pub fn check_compliance(args: ComplianceArgs) -> Result<()> { + tokio::runtime::Runtime::new() + .context("Failed to create tokio runtime") + .and_then(|handle| handle.block_on(check_compliance_impl(args))) +} diff --git a/tooling/xtask/src/tasks/workflows.rs b/tooling/xtask/src/tasks/workflows.rs index 414c0b7fd8dc2a99027d8687bcf1d4dbe9c4bb85..387c739a1ac12d4d65d11f33777525c59f05f7f2 100644 --- a/tooling/xtask/src/tasks/workflows.rs +++ b/tooling/xtask/src/tasks/workflows.rs @@ -11,6 +11,7 @@ mod autofix_pr; mod bump_patch_version; mod cherry_pick; mod compare_perf; +mod compliance_check; mod danger; mod deploy_collab; mod extension_auto_bump; @@ -197,6 +198,7 @@ pub fn run_workflows(args: GenerateWorkflowArgs) -> Result<()> { WorkflowFile::zed(bump_patch_version::bump_patch_version), WorkflowFile::zed(cherry_pick::cherry_pick), WorkflowFile::zed(compare_perf::compare_perf), + WorkflowFile::zed(compliance_check::compliance_check), WorkflowFile::zed(danger::danger), WorkflowFile::zed(deploy_collab::deploy_collab), WorkflowFile::zed(extension_bump::extension_bump), diff --git a/tooling/xtask/src/tasks/workflows/compliance_check.rs b/tooling/xtask/src/tasks/workflows/compliance_check.rs new file mode 100644 index 0000000000000000000000000000000000000000..9e2f4ae1e588c545266ec5a8246ac9781c6b668b --- /dev/null +++ b/tooling/xtask/src/tasks/workflows/compliance_check.rs @@ -0,0 +1,66 @@ +use gh_workflow::{Event, Expression, Job, Run, Schedule, Step, Workflow}; + +use crate::tasks::workflows::{ + runners, + steps::{self, CommonJobConditions, named}, + vars::{self, StepOutput}, +}; + +pub fn compliance_check() -> Workflow { + let check = scheduled_compliance_check(); + + named::workflow() + .on(Event::default().schedule([Schedule::new("30 17 * * 2")])) + .add_env(("CARGO_TERM_COLOR", "always")) + .add_job(check.name, check.job) +} + +fn scheduled_compliance_check() -> steps::NamedJob { + let determine_version_step = named::bash(indoc::indoc! {r#" + VERSION=$(sed -n 's/^version = "\(.*\)"/\1/p' crates/zed/Cargo.toml | tr -d '[:space:]') + if [ -z "$VERSION" ]; then + echo "Could not determine version from crates/zed/Cargo.toml" + exit 1 + fi + TAG="v${VERSION}-pre" + echo "Checking compliance for $TAG" + echo "tag=$TAG" >> "$GITHUB_OUTPUT" + "#}) + .id("determine-version"); + + let tag_output = StepOutput::new(&determine_version_step, "tag"); + + fn run_compliance_check(tag: &StepOutput) -> Step { + named::bash( + r#"cargo xtask compliance "$LATEST_TAG" --branch main --report-path target/compliance-report"#, + ) + .id("run-compliance-check") + .add_env(("LATEST_TAG", tag.to_string())) + .add_env(("GITHUB_APP_ID", vars::ZED_ZIPPY_APP_ID)) + .add_env(("GITHUB_APP_KEY", vars::ZED_ZIPPY_APP_PRIVATE_KEY)) + } + + fn send_failure_slack_notification(tag: &StepOutput) -> Step { + named::bash(indoc::indoc! {r#" + MESSAGE="⚠️ Scheduled compliance check failed for upcoming preview release $LATEST_TAG: There are PRs with missing reviews." + + curl -X POST -H 'Content-type: application/json' \ + --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \ + "$SLACK_WEBHOOK" + "#}) + .if_condition(Expression::new("failure()")) + .add_env(("SLACK_WEBHOOK", vars::SLACK_WEBHOOK_WORKFLOW_FAILURES)) + .add_env(("LATEST_TAG", tag.to_string())) + } + + named::job( + Job::default() + .with_repository_owner_guard() + .runs_on(runners::LINUX_SMALL) + .add_step(steps::checkout_repo().with_full_history()) + .add_step(steps::cache_rust_dependencies_namespace()) + .add_step(determine_version_step) + .add_step(run_compliance_check(&tag_output)) + .add_step(send_failure_slack_notification(&tag_output)), + ) +} diff --git a/tooling/xtask/src/tasks/workflows/release.rs b/tooling/xtask/src/tasks/workflows/release.rs index 4d7dc24d5e2d78cae87339877d730d3e3fb945b0..3efe3e7c5c127e8580a9ca22d2d0e1ab4e7c80e9 100644 --- a/tooling/xtask/src/tasks/workflows/release.rs +++ b/tooling/xtask/src/tasks/workflows/release.rs @@ -1,11 +1,13 @@ -use gh_workflow::{Event, Expression, Push, Run, Step, Use, Workflow, ctx::Context}; +use gh_workflow::{Event, Expression, Job, Push, Run, Step, Use, Workflow, ctx::Context}; use indoc::formatdoc; use crate::tasks::workflows::{ run_bundling::{bundle_linux, bundle_mac, bundle_windows}, run_tests, runners::{self, Arch, Platform}, - steps::{self, FluentBuilder, NamedJob, dependant_job, named, release_job}, + steps::{ + self, CommonJobConditions, FluentBuilder, NamedJob, dependant_job, named, release_job, + }, vars::{self, StepOutput, assets}, }; @@ -22,6 +24,7 @@ pub(crate) fn release() -> Workflow { let check_scripts = run_tests::check_scripts(); let create_draft_release = create_draft_release(); + let compliance = compliance_check(); let bundle = ReleaseBundleJobs { linux_aarch64: bundle_linux( @@ -92,6 +95,7 @@ pub(crate) fn release() -> Workflow { .add_job(windows_clippy.name, windows_clippy.job) .add_job(check_scripts.name, check_scripts.job) .add_job(create_draft_release.name, create_draft_release.job) + .add_job(compliance.name, compliance.job) .map(|mut workflow| { for job in bundle.into_jobs() { workflow = workflow.add_job(job.name, job.job); @@ -149,6 +153,59 @@ pub(crate) fn create_sentry_release() -> Step { .add_with(("environment", "production")) } +fn compliance_check() -> NamedJob { + fn run_compliance_check() -> Step { + named::bash( + r#"cargo xtask compliance "$GITHUB_REF_NAME" --report-path "$COMPLIANCE_FILE_OUTPUT""#, + ) + .id("run-compliance-check") + .add_env(("GITHUB_APP_ID", vars::ZED_ZIPPY_APP_ID)) + .add_env(("GITHUB_APP_KEY", vars::ZED_ZIPPY_APP_PRIVATE_KEY)) + } + + fn send_compliance_slack_notification() -> Step { + named::bash(indoc::indoc! {r#" + if [ "$COMPLIANCE_OUTCOME" == "success" ]; then + STATUS="✅ Compliance check passed for $GITHUB_REF_NAME" + else + STATUS="❌ Compliance check failed for $GITHUB_REF_NAME" + fi + + REPORT_CONTENT="" + if [ -f "$COMPLIANCE_FILE_OUTPUT" ]; then + REPORT_CONTENT=$(cat "$REPORT_FILE") + fi + + MESSAGE=$(printf "%s\n\n%s" "$STATUS" "$REPORT_CONTENT") + + curl -X POST -H 'Content-type: application/json' \ + --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \ + "$SLACK_WEBHOOK" + "#}) + .if_condition(Expression::new("always()")) + .add_env(("SLACK_WEBHOOK", vars::SLACK_WEBHOOK_WORKFLOW_FAILURES)) + .add_env(( + "COMPLIANCE_OUTCOME", + "${{ steps.run-compliance-check.outcome }}", + )) + } + + named::job( + Job::default() + .add_env(("COMPLIANCE_FILE_PATH", "compliance.md")) + .with_repository_owner_guard() + .runs_on(runners::LINUX_DEFAULT) + .add_step( + steps::checkout_repo() + .with_full_history() + .with_ref(Context::github().ref_()), + ) + .add_step(steps::cache_rust_dependencies_namespace()) + .add_step(run_compliance_check()) + .add_step(send_compliance_slack_notification()), + ) +} + fn validate_release_assets(deps: &[&NamedJob]) -> NamedJob { let expected_assets: Vec = assets::all().iter().map(|a| format!("\"{a}\"")).collect(); let expected_assets_json = format!("[{}]", expected_assets.join(", ")); @@ -171,10 +228,54 @@ fn validate_release_assets(deps: &[&NamedJob]) -> NamedJob { "#, }; + fn run_post_upload_compliance_check() -> Step { + named::bash( + r#"cargo xtask compliance "$GITHUB_REF_NAME" --report-path target/compliance-report"#, + ) + .id("run-post-upload-compliance-check") + .add_env(("GITHUB_APP_ID", vars::ZED_ZIPPY_APP_ID)) + .add_env(("GITHUB_APP_KEY", vars::ZED_ZIPPY_APP_PRIVATE_KEY)) + } + + fn send_post_upload_compliance_notification() -> Step { + named::bash(indoc::indoc! {r#" + if [ -z "$COMPLIANCE_OUTCOME" ] || [ "$COMPLIANCE_OUTCOME" == "skipped" ]; then + echo "Compliance check was skipped, not sending notification" + exit 0 + fi + + TAG="$GITHUB_REF_NAME" + + if [ "$COMPLIANCE_OUTCOME" == "success" ]; then + MESSAGE="✅ Post-upload compliance re-check passed for $TAG" + else + MESSAGE="❌ Post-upload compliance re-check failed for $TAG" + fi + + curl -X POST -H 'Content-type: application/json' \ + --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \ + "$SLACK_WEBHOOK" + "#}) + .if_condition(Expression::new("always()")) + .add_env(("SLACK_WEBHOOK", vars::SLACK_WEBHOOK_WORKFLOW_FAILURES)) + .add_env(( + "COMPLIANCE_OUTCOME", + "${{ steps.run-post-upload-compliance-check.outcome }}", + )) + } + named::job( - dependant_job(deps).runs_on(runners::LINUX_SMALL).add_step( - named::bash(&validation_script).add_env(("GITHUB_TOKEN", vars::GITHUB_TOKEN)), - ), + dependant_job(deps) + .runs_on(runners::LINUX_SMALL) + .add_step(named::bash(&validation_script).add_env(("GITHUB_TOKEN", vars::GITHUB_TOKEN))) + .add_step( + steps::checkout_repo() + .with_full_history() + .with_ref(Context::github().ref_()), + ) + .add_step(steps::cache_rust_dependencies_namespace()) + .add_step(run_post_upload_compliance_check()) + .add_step(send_post_upload_compliance_notification()), ) } @@ -255,7 +356,7 @@ fn create_draft_release() -> NamedJob { .add_step( steps::checkout_repo() .with_custom_fetch_depth(25) - .with_ref("${{ github.ref }}"), + .with_ref(Context::github().ref_()), ) .add_step(steps::script("script/determine-release-channel")) .add_step(steps::script("mkdir -p target/"))