Detailed changes
@@ -0,0 +1,55 @@
+# Generated from xtask::workflows::compliance_check
+# Rebuild with `cargo xtask workflows`.
+name: compliance_check
+env:
+ CARGO_TERM_COLOR: always
+on:
+ schedule:
+ - cron: 30 17 * * 2
+jobs:
+ scheduled_compliance_check:
+ if: (github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions')
+ runs-on: namespace-profile-2x4-ubuntu-2404
+ steps:
+ - name: steps::checkout_repo
+ uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd
+ with:
+ clean: false
+ fetch-depth: 0
+ - name: steps::cache_rust_dependencies_namespace
+ uses: namespacelabs/nscloud-cache-action@a90bb5d4b27522ce881c6e98eebd7d7e6d1653f9
+ with:
+ cache: rust
+ path: ~/.rustup
+ - id: determine-version
+ name: compliance_check::scheduled_compliance_check
+ run: |
+ VERSION=$(sed -n 's/^version = "\(.*\)"/\1/p' crates/zed/Cargo.toml | tr -d '[:space:]')
+ if [ -z "$VERSION" ]; then
+ echo "Could not determine version from crates/zed/Cargo.toml"
+ exit 1
+ fi
+ TAG="v${VERSION}-pre"
+ echo "Checking compliance for $TAG"
+ echo "tag=$TAG" >> "$GITHUB_OUTPUT"
+ - id: run-compliance-check
+ name: compliance_check::scheduled_compliance_check::run_compliance_check
+ run: cargo xtask compliance "$LATEST_TAG" --branch main --report-path target/compliance-report
+ env:
+ LATEST_TAG: ${{ steps.determine-version.outputs.tag }}
+ GITHUB_APP_ID: ${{ secrets.ZED_ZIPPY_APP_ID }}
+ GITHUB_APP_KEY: ${{ secrets.ZED_ZIPPY_APP_PRIVATE_KEY }}
+ - name: compliance_check::scheduled_compliance_check::send_failure_slack_notification
+ if: failure()
+ run: |
+ MESSAGE="⚠️ Scheduled compliance check failed for upcoming preview release $LATEST_TAG: There are PRs with missing reviews."
+
+ curl -X POST -H 'Content-type: application/json' \
+ --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \
+ "$SLACK_WEBHOOK"
+ env:
+ SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_WORKFLOW_FAILURES }}
+ LATEST_TAG: ${{ steps.determine-version.outputs.tag }}
+defaults:
+ run:
+ shell: bash -euxo pipefail {0}
@@ -293,6 +293,51 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
timeout-minutes: 60
+ compliance_check:
+ if: (github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions')
+ runs-on: namespace-profile-16x32-ubuntu-2204
+ env:
+ COMPLIANCE_FILE_PATH: compliance.md
+ steps:
+ - name: steps::checkout_repo
+ uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd
+ with:
+ clean: false
+ fetch-depth: 0
+ ref: ${{ github.ref }}
+ - name: steps::cache_rust_dependencies_namespace
+ uses: namespacelabs/nscloud-cache-action@a90bb5d4b27522ce881c6e98eebd7d7e6d1653f9
+ with:
+ cache: rust
+ path: ~/.rustup
+ - id: run-compliance-check
+ name: release::compliance_check::run_compliance_check
+ run: cargo xtask compliance "$GITHUB_REF_NAME" --report-path "$COMPLIANCE_FILE_OUTPUT"
+ env:
+ GITHUB_APP_ID: ${{ secrets.ZED_ZIPPY_APP_ID }}
+ GITHUB_APP_KEY: ${{ secrets.ZED_ZIPPY_APP_PRIVATE_KEY }}
+ - name: release::compliance_check::send_compliance_slack_notification
+ if: always()
+ run: |
+ if [ "$COMPLIANCE_OUTCOME" == "success" ]; then
+ STATUS="✅ Compliance check passed for $GITHUB_REF_NAME"
+ else
+ STATUS="❌ Compliance check failed for $GITHUB_REF_NAME"
+ fi
+
+ REPORT_CONTENT=""
+ if [ -f "$COMPLIANCE_FILE_OUTPUT" ]; then
+ REPORT_CONTENT=$(cat "$REPORT_FILE")
+ fi
+
+ MESSAGE=$(printf "%s\n\n%s" "$STATUS" "$REPORT_CONTENT")
+
+ curl -X POST -H 'Content-type: application/json' \
+ --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \
+ "$SLACK_WEBHOOK"
+ env:
+ SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_WORKFLOW_FAILURES }}
+ COMPLIANCE_OUTCOME: ${{ steps.run-compliance-check.outcome }}
bundle_linux_aarch64:
needs:
- run_tests_linux
@@ -613,6 +658,45 @@ jobs:
echo "All expected assets are present in the release."
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ - name: steps::checkout_repo
+ uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd
+ with:
+ clean: false
+ fetch-depth: 0
+ ref: ${{ github.ref }}
+ - name: steps::cache_rust_dependencies_namespace
+ uses: namespacelabs/nscloud-cache-action@a90bb5d4b27522ce881c6e98eebd7d7e6d1653f9
+ with:
+ cache: rust
+ path: ~/.rustup
+ - id: run-post-upload-compliance-check
+ name: release::validate_release_assets::run_post_upload_compliance_check
+ run: cargo xtask compliance "$GITHUB_REF_NAME" --report-path target/compliance-report
+ env:
+ GITHUB_APP_ID: ${{ secrets.ZED_ZIPPY_APP_ID }}
+ GITHUB_APP_KEY: ${{ secrets.ZED_ZIPPY_APP_PRIVATE_KEY }}
+ - name: release::validate_release_assets::send_post_upload_compliance_notification
+ if: always()
+ run: |
+ if [ -z "$COMPLIANCE_OUTCOME" ] || [ "$COMPLIANCE_OUTCOME" == "skipped" ]; then
+ echo "Compliance check was skipped, not sending notification"
+ exit 0
+ fi
+
+ TAG="$GITHUB_REF_NAME"
+
+ if [ "$COMPLIANCE_OUTCOME" == "success" ]; then
+ MESSAGE="✅ Post-upload compliance re-check passed for $TAG"
+ else
+ MESSAGE="❌ Post-upload compliance re-check failed for $TAG"
+ fi
+
+ curl -X POST -H 'Content-type: application/json' \
+ --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \
+ "$SLACK_WEBHOOK"
+ env:
+ SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK_WORKFLOW_FAILURES }}
+ COMPLIANCE_OUTCOME: ${{ steps.run-post-upload-compliance-check.outcome }}
auto_release_preview:
needs:
- validate_release_assets
@@ -677,6 +677,15 @@ dependencies = [
"derive_arbitrary",
]
+[[package]]
+name = "arc-swap"
+version = "1.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a07d1f37ff60921c83bdfc7407723bdefe89b44b98a9b772f225c8f9d67141a6"
+dependencies = [
+ "rustversion",
+]
+
[[package]]
name = "arg_enum_proc_macro"
version = "0.3.4"
@@ -2530,6 +2539,16 @@ dependencies = [
"serde",
]
+[[package]]
+name = "cargo-platform"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "87a0c0e6148f11f01f32650a2ea02d532b2ad4e81d8bd41e6e565b5adc5e6082"
+dependencies = [
+ "serde",
+ "serde_core",
+]
+
[[package]]
name = "cargo_metadata"
version = "0.19.2"
@@ -2537,7 +2556,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd5eb614ed4c27c5d706420e4320fbe3216ab31fa1c33cd8246ac36dae4479ba"
dependencies = [
"camino",
- "cargo-platform",
+ "cargo-platform 0.1.9",
+ "semver",
+ "serde",
+ "serde_json",
+ "thiserror 2.0.17",
+]
+
+[[package]]
+name = "cargo_metadata"
+version = "0.23.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ef987d17b0a113becdd19d3d0022d04d7ef41f9efe4f3fb63ac44ba61df3ade9"
+dependencies = [
+ "camino",
+ "cargo-platform 0.3.2",
"semver",
"serde",
"serde_json",
@@ -3284,6 +3317,25 @@ dependencies = [
"workspace",
]
+[[package]]
+name = "compliance"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "async-trait",
+ "derive_more",
+ "futures 0.3.32",
+ "indoc",
+ "itertools 0.14.0",
+ "jsonwebtoken",
+ "octocrab",
+ "regex",
+ "semver",
+ "serde",
+ "serde_json",
+ "tokio",
+]
+
[[package]]
name = "component"
version = "0.1.0"
@@ -8324,6 +8376,7 @@ dependencies = [
"http 1.3.1",
"hyper 1.7.0",
"hyper-util",
+ "log",
"rustls 0.23.33",
"rustls-native-certs 0.8.2",
"rustls-pki-types",
@@ -8332,6 +8385,19 @@ dependencies = [
"tower-service",
]
+[[package]]
+name = "hyper-timeout"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0"
+dependencies = [
+ "hyper 1.7.0",
+ "hyper-util",
+ "pin-project-lite",
+ "tokio",
+ "tower-service",
+]
+
[[package]]
name = "hyper-tls"
version = "0.5.0"
@@ -10012,7 +10078,7 @@ dependencies = [
[[package]]
name = "lsp-types"
version = "0.95.1"
-source = "git+https://github.com/zed-industries/lsp-types?rev=a4f410987660bf560d1e617cb78117c6b6b9f599#a4f410987660bf560d1e617cb78117c6b6b9f599"
+source = "git+https://github.com/zed-industries/lsp-types?rev=c7396459fefc7886b4adfa3b596832405ae1e880#c7396459fefc7886b4adfa3b596832405ae1e880"
dependencies = [
"bitflags 1.3.2",
"serde",
@@ -11384,6 +11450,48 @@ dependencies = [
"memchr",
]
+[[package]]
+name = "octocrab"
+version = "0.49.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "63f6687a23731011d0117f9f4c3cdabaa7b5e42ca671f42b5cc0657c492540e3"
+dependencies = [
+ "arc-swap",
+ "async-trait",
+ "base64 0.22.1",
+ "bytes 1.11.1",
+ "cargo_metadata 0.23.1",
+ "cfg-if",
+ "chrono",
+ "either",
+ "futures 0.3.32",
+ "futures-core",
+ "futures-util",
+ "getrandom 0.2.16",
+ "http 1.3.1",
+ "http-body 1.0.1",
+ "http-body-util",
+ "hyper 1.7.0",
+ "hyper-rustls 0.27.7",
+ "hyper-timeout",
+ "hyper-util",
+ "jsonwebtoken",
+ "once_cell",
+ "percent-encoding",
+ "pin-project",
+ "secrecy",
+ "serde",
+ "serde_json",
+ "serde_path_to_error",
+ "serde_urlencoded",
+ "snafu",
+ "tokio",
+ "tower 0.5.2",
+ "tower-http 0.6.6",
+ "url",
+ "web-time",
+]
+
[[package]]
name = "ollama"
version = "0.1.0"
@@ -15385,6 +15493,15 @@ dependencies = [
"zeroize",
]
+[[package]]
+name = "secrecy"
+version = "0.10.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a"
+dependencies = [
+ "zeroize",
+]
+
[[package]]
name = "security-framework"
version = "2.11.1"
@@ -16089,6 +16206,27 @@ version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f7a918bd2a9951d18ee6e48f076843e8e73a9a5d22cf05bcd4b7a81bdd04e17"
+[[package]]
+name = "snafu"
+version = "0.8.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6e84b3f4eacbf3a1ce05eac6763b4d629d60cbc94d632e4092c54ade71f1e1a2"
+dependencies = [
+ "snafu-derive",
+]
+
+[[package]]
+name = "snafu-derive"
+version = "0.8.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451"
+dependencies = [
+ "heck 0.5.0",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.117",
+]
+
[[package]]
name = "snippet"
version = "0.1.0"
@@ -18093,8 +18231,10 @@ dependencies = [
"pin-project-lite",
"sync_wrapper 1.0.2",
"tokio",
+ "tokio-util",
"tower-layer",
"tower-service",
+ "tracing",
]
[[package]]
@@ -18132,6 +18272,7 @@ dependencies = [
"tower 0.5.2",
"tower-layer",
"tower-service",
+ "tracing",
]
[[package]]
@@ -19978,6 +20119,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
dependencies = [
"js-sys",
+ "serde",
"wasm-bindgen",
]
@@ -21715,9 +21857,10 @@ dependencies = [
"annotate-snippets",
"anyhow",
"backtrace",
- "cargo_metadata",
+ "cargo_metadata 0.19.2",
"cargo_toml",
"clap",
+ "compliance",
"gh-workflow",
"indexmap",
"indoc",
@@ -21727,6 +21870,7 @@ dependencies = [
"serde_json",
"serde_yaml",
"strum 0.27.2",
+ "tokio",
"toml 0.8.23",
"toml_edit 0.22.27",
]
@@ -22402,6 +22546,7 @@ name = "zeta_prompt"
version = "0.1.0"
dependencies = [
"anyhow",
+ "imara-diff",
"indoc",
"serde",
"strum 0.27.2",
@@ -242,6 +242,7 @@ members = [
# Tooling
#
+ "tooling/compliance",
"tooling/perf",
"tooling/xtask",
]
@@ -289,6 +290,7 @@ collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections", version = "0.1.0" }
command_palette = { path = "crates/command_palette" }
command_palette_hooks = { path = "crates/command_palette_hooks" }
+compliance = { path = "tooling/compliance" }
component = { path = "crates/component" }
component_preview = { path = "crates/component_preview" }
context_server = { path = "crates/context_server" }
@@ -547,6 +549,7 @@ derive_more = { version = "2.1.1", features = [
"add_assign",
"deref",
"deref_mut",
+ "display",
"from_str",
"mul",
"mul_assign",
@@ -596,7 +599,7 @@ linkify = "0.10.0"
libwebrtc = "0.3.26"
livekit = { version = "0.7.32", features = ["tokio", "rustls-tls-native-roots"] }
log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] }
-lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "a4f410987660bf560d1e617cb78117c6b6b9f599" }
+lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "c7396459fefc7886b4adfa3b596832405ae1e880" }
mach2 = "0.5"
markup5ever_rcdom = "0.3.0"
metal = "0.33"
@@ -2417,6 +2417,7 @@
"toggle_relative_line_numbers": false,
"use_system_clipboard": "always",
"use_smartcase_find": false,
+ "use_regex_search": true,
"gdefault": false,
"highlight_on_yank_duration": 200,
"custom_digraphs": {},
@@ -1032,6 +1032,7 @@ pub struct AcpThread {
connection: Rc<dyn AgentConnection>,
token_usage: Option<TokenUsage>,
prompt_capabilities: acp::PromptCapabilities,
+ available_commands: Vec<acp::AvailableCommand>,
_observe_prompt_capabilities: Task<anyhow::Result<()>>,
terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
pending_terminal_output: HashMap<acp::TerminalId, Vec<Vec<u8>>>,
@@ -1220,6 +1221,7 @@ impl AcpThread {
session_id,
token_usage: None,
prompt_capabilities,
+ available_commands: Vec::new(),
_observe_prompt_capabilities: task,
terminals: HashMap::default(),
pending_terminal_output: HashMap::default(),
@@ -1239,6 +1241,10 @@ impl AcpThread {
self.prompt_capabilities.clone()
}
+ pub fn available_commands(&self) -> &[acp::AvailableCommand] {
+ &self.available_commands
+ }
+
pub fn draft_prompt(&self) -> Option<&[acp::ContentBlock]> {
self.draft_prompt.as_deref()
}
@@ -1419,7 +1425,10 @@ impl AcpThread {
acp::SessionUpdate::AvailableCommandsUpdate(acp::AvailableCommandsUpdate {
available_commands,
..
- }) => cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands)),
+ }) => {
+ self.available_commands = available_commands.clone();
+ cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands));
+ }
acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
current_mode_id,
..
@@ -202,3 +202,214 @@ async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) {
);
});
}
+
+#[gpui::test]
+async fn test_streaming_edit_json_parse_error_does_not_cause_unsaved_changes(
+ cx: &mut TestAppContext,
+) {
+ super::init_test(cx);
+ super::always_allow_tools(cx);
+
+ // Enable the streaming edit file tool feature flag.
+ cx.update(|cx| {
+ cx.update_flags(true, vec!["streaming-edit-file-tool".to_string()]);
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/project"),
+ json!({
+ "src": {
+ "main.rs": "fn main() {\n println!(\"Hello, world!\");\n}\n"
+ }
+ }),
+ )
+ .await;
+
+ let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let project_context = cx.new(|_cx| ProjectContext::default());
+ let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
+ let context_server_registry =
+ cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
+ let model = Arc::new(FakeLanguageModel::default());
+ model.as_fake().set_supports_streaming_tools(true);
+ let fake_model = model.as_fake();
+
+ let thread = cx.new(|cx| {
+ let mut thread = crate::Thread::new(
+ project.clone(),
+ project_context,
+ context_server_registry,
+ crate::Templates::new(),
+ Some(model.clone()),
+ cx,
+ );
+ let language_registry = project.read(cx).languages().clone();
+ thread.add_tool(crate::StreamingEditFileTool::new(
+ project.clone(),
+ cx.weak_entity(),
+ thread.action_log().clone(),
+ language_registry,
+ ));
+ thread
+ });
+
+ let _events = thread
+ .update(cx, |thread, cx| {
+ thread.send(
+ UserMessageId::new(),
+ ["Write new content to src/main.rs"],
+ cx,
+ )
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ let tool_use_id = "edit_1";
+ let partial_1 = LanguageModelToolUse {
+ id: tool_use_id.into(),
+ name: EditFileTool::NAME.into(),
+ raw_input: json!({
+ "display_description": "Rewrite main.rs",
+ "path": "project/src/main.rs",
+ "mode": "write"
+ })
+ .to_string(),
+ input: json!({
+ "display_description": "Rewrite main.rs",
+ "path": "project/src/main.rs",
+ "mode": "write"
+ }),
+ is_input_complete: false,
+ thought_signature: None,
+ };
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_1));
+ cx.run_until_parked();
+
+ let partial_2 = LanguageModelToolUse {
+ id: tool_use_id.into(),
+ name: EditFileTool::NAME.into(),
+ raw_input: json!({
+ "display_description": "Rewrite main.rs",
+ "path": "project/src/main.rs",
+ "mode": "write",
+ "content": "fn main() { /* rewritten */ }"
+ })
+ .to_string(),
+ input: json!({
+ "display_description": "Rewrite main.rs",
+ "path": "project/src/main.rs",
+ "mode": "write",
+ "content": "fn main() { /* rewritten */ }"
+ }),
+ is_input_complete: false,
+ thought_signature: None,
+ };
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_2));
+ cx.run_until_parked();
+
+ // Now send a json parse error. At this point we have started writing content to the buffer.
+ fake_model.send_last_completion_stream_event(
+ LanguageModelCompletionEvent::ToolUseJsonParseError {
+ id: tool_use_id.into(),
+ tool_name: EditFileTool::NAME.into(),
+ raw_input: r#"{"display_description":"Rewrite main.rs","path":"project/src/main.rs","mode":"write","content":"fn main() { /* rewritten "#.into(),
+ json_parse_error: "EOF while parsing a string at line 1 column 95".into(),
+ },
+ );
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ // cx.executor().advance_clock(Duration::from_secs(5));
+ // cx.run_until_parked();
+
+ assert!(
+ !fake_model.pending_completions().is_empty(),
+ "Thread should have retried after the error"
+ );
+
+ // Respond with a new, well-formed, complete edit_file tool use.
+ let tool_use = LanguageModelToolUse {
+ id: "edit_2".into(),
+ name: EditFileTool::NAME.into(),
+ raw_input: json!({
+ "display_description": "Rewrite main.rs",
+ "path": "project/src/main.rs",
+ "mode": "write",
+ "content": "fn main() {\n println!(\"Hello, rewritten!\");\n}\n"
+ })
+ .to_string(),
+ input: json!({
+ "display_description": "Rewrite main.rs",
+ "path": "project/src/main.rs",
+ "mode": "write",
+ "content": "fn main() {\n println!(\"Hello, rewritten!\");\n}\n"
+ }),
+ is_input_complete: true,
+ thought_signature: None,
+ };
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ let pending_completions = fake_model.pending_completions();
+ assert!(
+ pending_completions.len() == 1,
+ "Expected only the follow-up completion containing the successful tool result"
+ );
+
+ let completion = pending_completions
+ .into_iter()
+ .last()
+ .expect("Expected a completion containing the tool result for edit_2");
+
+ let tool_result = completion
+ .messages
+ .iter()
+ .flat_map(|msg| &msg.content)
+ .find_map(|content| match content {
+ language_model::MessageContent::ToolResult(result)
+ if result.tool_use_id == language_model::LanguageModelToolUseId::from("edit_2") =>
+ {
+ Some(result)
+ }
+ _ => None,
+ })
+ .expect("Should have a tool result for edit_2");
+
+ // Ensure that the second tool call completed successfully and edits were applied.
+ assert!(
+ !tool_result.is_error,
+ "Tool result should succeed, got: {:?}",
+ tool_result
+ );
+ let content_text = match &tool_result.content {
+ language_model::LanguageModelToolResultContent::Text(t) => t.to_string(),
+ other => panic!("Expected text content, got: {:?}", other),
+ };
+ assert!(
+ !content_text.contains("file has been modified since you last read it"),
+ "Did not expect a stale last-read error, got: {content_text}"
+ );
+ assert!(
+ !content_text.contains("This file has unsaved changes"),
+ "Did not expect an unsaved-changes error, got: {content_text}"
+ );
+
+ let file_content = fs
+ .load(path!("/project/src/main.rs").as_ref())
+ .await
+ .expect("file should exist");
+ super::assert_eq!(
+ file_content,
+ "fn main() {\n println!(\"Hello, rewritten!\");\n}\n",
+ "The second edit should be applied and saved gracefully"
+ );
+
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+}
@@ -3903,6 +3903,117 @@ async fn test_streaming_tool_completes_when_llm_stream_ends_without_final_input(
});
}
+#[gpui::test]
+async fn test_streaming_tool_json_parse_error_is_forwarded_to_running_tool(
+ cx: &mut TestAppContext,
+) {
+ init_test(cx);
+ always_allow_tools(cx);
+
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ thread.update(cx, |thread, _cx| {
+ thread.add_tool(StreamingJsonErrorContextTool);
+ });
+
+ let _events = thread
+ .update(cx, |thread, cx| {
+ thread.send(
+ UserMessageId::new(),
+ ["Use the streaming_json_error_context tool"],
+ cx,
+ )
+ })
+ .unwrap();
+ cx.run_until_parked();
+
+ let tool_use = LanguageModelToolUse {
+ id: "tool_1".into(),
+ name: StreamingJsonErrorContextTool::NAME.into(),
+ raw_input: r#"{"text": "partial"#.into(),
+ input: json!({"text": "partial"}),
+ is_input_complete: false,
+ thought_signature: None,
+ };
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
+ cx.run_until_parked();
+
+ fake_model.send_last_completion_stream_event(
+ LanguageModelCompletionEvent::ToolUseJsonParseError {
+ id: "tool_1".into(),
+ tool_name: StreamingJsonErrorContextTool::NAME.into(),
+ raw_input: r#"{"text": "partial"#.into(),
+ json_parse_error: "EOF while parsing a string at line 1 column 17".into(),
+ },
+ );
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ cx.executor().advance_clock(Duration::from_secs(5));
+ cx.run_until_parked();
+
+ let completion = fake_model
+ .pending_completions()
+ .pop()
+ .expect("No running turn");
+
+ let tool_results: Vec<_> = completion
+ .messages
+ .iter()
+ .flat_map(|message| &message.content)
+ .filter_map(|content| match content {
+ MessageContent::ToolResult(result)
+ if result.tool_use_id == language_model::LanguageModelToolUseId::from("tool_1") =>
+ {
+ Some(result)
+ }
+ _ => None,
+ })
+ .collect();
+
+ assert_eq!(
+ tool_results.len(),
+ 1,
+ "Expected exactly 1 tool result for tool_1, got {}: {:#?}",
+ tool_results.len(),
+ tool_results
+ );
+
+ let result = tool_results[0];
+ assert!(result.is_error);
+ let content_text = match &result.content {
+ language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
+ other => panic!("Expected text content, got {:?}", other),
+ };
+ assert!(
+ content_text.contains("Saw partial text 'partial' before invalid JSON"),
+ "Expected tool-enriched partial context, got: {content_text}"
+ );
+ assert!(
+ content_text
+ .contains("Error parsing input JSON: EOF while parsing a string at line 1 column 17"),
+ "Expected forwarded JSON parse error, got: {content_text}"
+ );
+ assert!(
+ !content_text.contains("tool input was not fully received"),
+ "Should not contain orphaned sender error, got: {content_text}"
+ );
+
+ fake_model.send_last_completion_stream_text_chunk("Done");
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+
+ thread.read_with(cx, |thread, _cx| {
+ assert!(
+ thread.is_turn_complete(),
+ "Thread should not be stuck; the turn should have completed",
+ );
+ });
+}
+
/// Filters out the stop events for asserting against in tests
fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
result_events
@@ -3959,6 +4070,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
InfiniteTool::NAME: true,
CancellationAwareTool::NAME: true,
StreamingEchoTool::NAME: true,
+ StreamingJsonErrorContextTool::NAME: true,
StreamingFailingEchoTool::NAME: true,
TerminalTool::NAME: true,
UpdatePlanTool::NAME: true,
@@ -56,13 +56,12 @@ impl AgentTool for StreamingEchoTool {
fn run(
self: Arc<Self>,
- mut input: ToolInput<Self::Input>,
+ input: ToolInput<Self::Input>,
_event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String, String>> {
let wait_until_complete_rx = self.wait_until_complete_rx.lock().unwrap().take();
cx.spawn(async move |_cx| {
- while input.recv_partial().await.is_some() {}
let input = input
.recv()
.await
@@ -75,6 +74,68 @@ impl AgentTool for StreamingEchoTool {
}
}
+#[derive(JsonSchema, Serialize, Deserialize)]
+pub struct StreamingJsonErrorContextToolInput {
+ /// The text to echo.
+ pub text: String,
+}
+
+pub struct StreamingJsonErrorContextTool;
+
+impl AgentTool for StreamingJsonErrorContextTool {
+ type Input = StreamingJsonErrorContextToolInput;
+ type Output = String;
+
+ const NAME: &'static str = "streaming_json_error_context";
+
+ fn supports_input_streaming() -> bool {
+ true
+ }
+
+ fn kind() -> acp::ToolKind {
+ acp::ToolKind::Other
+ }
+
+ fn initial_title(
+ &self,
+ _input: Result<Self::Input, serde_json::Value>,
+ _cx: &mut App,
+ ) -> SharedString {
+ "Streaming JSON Error Context".into()
+ }
+
+ fn run(
+ self: Arc<Self>,
+ mut input: ToolInput<Self::Input>,
+ _event_stream: ToolCallEventStream,
+ cx: &mut App,
+ ) -> Task<Result<String, String>> {
+ cx.spawn(async move |_cx| {
+ let mut last_partial_text = None;
+
+ loop {
+ match input.next().await {
+ Ok(ToolInputPayload::Partial(partial)) => {
+ if let Some(text) = partial.get("text").and_then(|value| value.as_str()) {
+ last_partial_text = Some(text.to_string());
+ }
+ }
+ Ok(ToolInputPayload::Full(input)) => return Ok(input.text),
+ Ok(ToolInputPayload::InvalidJson { error_message }) => {
+ let partial_text = last_partial_text.unwrap_or_default();
+ return Err(format!(
+ "Saw partial text '{partial_text}' before invalid JSON: {error_message}"
+ ));
+ }
+ Err(error) => {
+ return Err(format!("Failed to receive tool input: {error}"));
+ }
+ }
+ }
+ })
+ }
+}
+
/// A streaming tool that echoes its input, used to test streaming tool
/// lifecycle (e.g. partial delivery and cleanup when the LLM stream ends
/// before `is_input_complete`).
@@ -119,7 +180,7 @@ impl AgentTool for StreamingFailingEchoTool {
) -> Task<Result<Self::Output, Self::Output>> {
cx.spawn(async move |_cx| {
for _ in 0..self.receive_chunks_until_failure {
- let _ = input.recv_partial().await;
+ let _ = input.next().await;
}
Err("failed".into())
})
@@ -22,13 +22,13 @@ use client::UserStore;
use cloud_api_types::Plan;
use collections::{HashMap, HashSet, IndexMap};
use fs::Fs;
-use futures::stream;
use futures::{
FutureExt,
channel::{mpsc, oneshot},
future::Shared,
stream::FuturesUnordered,
};
+use futures::{StreamExt, stream};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
};
@@ -47,7 +47,6 @@ use schemars::{JsonSchema, Schema};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use settings::{LanguageModelSelection, Settings, ToolPermissionMode, update_settings_file};
-use smol::stream::StreamExt;
use std::{
collections::BTreeMap,
marker::PhantomData,
@@ -2095,7 +2094,7 @@ impl Thread {
this.update(cx, |this, _cx| {
this.pending_message()
.tool_results
- .insert(tool_result.tool_use_id.clone(), tool_result);
+ .insert(tool_result.tool_use_id.clone(), tool_result)
})?;
Ok(())
}
@@ -2195,15 +2194,15 @@ impl Thread {
raw_input,
json_parse_error,
} => {
- return Ok(Some(Task::ready(
- self.handle_tool_use_json_parse_error_event(
- id,
- tool_name,
- raw_input,
- json_parse_error,
- event_stream,
- ),
- )));
+ return Ok(self.handle_tool_use_json_parse_error_event(
+ id,
+ tool_name,
+ raw_input,
+ json_parse_error,
+ event_stream,
+ cancellation_rx,
+ cx,
+ ));
}
UsageUpdate(usage) => {
telemetry::event!(
@@ -2304,12 +2303,12 @@ impl Thread {
if !tool_use.is_input_complete {
if tool.supports_input_streaming() {
let running_turn = self.running_turn.as_mut()?;
- if let Some(sender) = running_turn.streaming_tool_inputs.get(&tool_use.id) {
+ if let Some(sender) = running_turn.streaming_tool_inputs.get_mut(&tool_use.id) {
sender.send_partial(tool_use.input);
return None;
}
- let (sender, tool_input) = ToolInputSender::channel();
+ let (mut sender, tool_input) = ToolInputSender::channel();
sender.send_partial(tool_use.input);
running_turn
.streaming_tool_inputs
@@ -2331,13 +2330,13 @@ impl Thread {
}
}
- if let Some(sender) = self
+ if let Some(mut sender) = self
.running_turn
.as_mut()?
.streaming_tool_inputs
.remove(&tool_use.id)
{
- sender.send_final(tool_use.input);
+ sender.send_full(tool_use.input);
return None;
}
@@ -2410,10 +2409,12 @@ impl Thread {
raw_input: Arc<str>,
json_parse_error: String,
event_stream: &ThreadEventStream,
- ) -> LanguageModelToolResult {
+ cancellation_rx: watch::Receiver<bool>,
+ cx: &mut Context<Self>,
+ ) -> Option<Task<LanguageModelToolResult>> {
let tool_use = LanguageModelToolUse {
- id: tool_use_id.clone(),
- name: tool_name.clone(),
+ id: tool_use_id,
+ name: tool_name,
raw_input: raw_input.to_string(),
input: serde_json::json!({}),
is_input_complete: true,
@@ -2426,14 +2427,43 @@ impl Thread {
event_stream,
);
- let tool_output = format!("Error parsing input JSON: {json_parse_error}");
- LanguageModelToolResult {
- tool_use_id,
- tool_name,
- is_error: true,
- content: LanguageModelToolResultContent::Text(tool_output.into()),
- output: Some(serde_json::Value::String(raw_input.to_string())),
+ let tool = self.tool(tool_use.name.as_ref());
+
+ let Some(tool) = tool else {
+ let content = format!("No tool named {} exists", tool_use.name);
+ return Some(Task::ready(LanguageModelToolResult {
+ content: LanguageModelToolResultContent::Text(Arc::from(content)),
+ tool_use_id: tool_use.id,
+ tool_name: tool_use.name,
+ is_error: true,
+ output: None,
+ }));
+ };
+
+ let error_message = format!("Error parsing input JSON: {json_parse_error}");
+
+ if tool.supports_input_streaming()
+ && let Some(mut sender) = self
+ .running_turn
+ .as_mut()?
+ .streaming_tool_inputs
+ .remove(&tool_use.id)
+ {
+ sender.send_invalid_json(error_message);
+ return None;
}
+
+ log::debug!("Running tool {}. Received invalid JSON", tool_use.name);
+ let tool_input = ToolInput::invalid_json(error_message);
+ Some(self.run_tool(
+ tool,
+ tool_input,
+ tool_use.id,
+ tool_use.name,
+ event_stream,
+ cancellation_rx,
+ cx,
+ ))
}
fn send_or_update_tool_use(
@@ -3114,8 +3144,7 @@ impl EventEmitter<TitleUpdated> for Thread {}
/// For streaming tools, partial JSON snapshots arrive via `.recv_partial()` as the LLM streams
/// them, followed by the final complete input available through `.recv()`.
pub struct ToolInput<T> {
- partial_rx: mpsc::UnboundedReceiver<serde_json::Value>,
- final_rx: oneshot::Receiver<serde_json::Value>,
+ rx: mpsc::UnboundedReceiver<ToolInputPayload<serde_json::Value>>,
_phantom: PhantomData<T>,
}
@@ -3127,13 +3156,20 @@ impl<T: DeserializeOwned> ToolInput<T> {
}
pub fn ready(value: serde_json::Value) -> Self {
- let (partial_tx, partial_rx) = mpsc::unbounded();
- drop(partial_tx);
- let (final_tx, final_rx) = oneshot::channel();
- final_tx.send(value).ok();
+ let (tx, rx) = mpsc::unbounded();
+ tx.unbounded_send(ToolInputPayload::Full(value)).ok();
Self {
- partial_rx,
- final_rx,
+ rx,
+ _phantom: PhantomData,
+ }
+ }
+
+ pub fn invalid_json(error_message: String) -> Self {
+ let (tx, rx) = mpsc::unbounded();
+ tx.unbounded_send(ToolInputPayload::InvalidJson { error_message })
+ .ok();
+ Self {
+ rx,
_phantom: PhantomData,
}
}
@@ -3147,65 +3183,89 @@ impl<T: DeserializeOwned> ToolInput<T> {
/// Wait for the final deserialized input, ignoring all partial updates.
/// Non-streaming tools can use this to wait until the whole input is available.
pub async fn recv(mut self) -> Result<T> {
- // Drain any remaining partials
- while self.partial_rx.next().await.is_some() {}
+ while let Ok(value) = self.next().await {
+ match value {
+ ToolInputPayload::Full(value) => return Ok(value),
+ ToolInputPayload::Partial(_) => {}
+ ToolInputPayload::InvalidJson { error_message } => {
+ return Err(anyhow!(error_message));
+ }
+ }
+ }
+ Err(anyhow!("tool input was not fully received"))
+ }
+
+ pub async fn next(&mut self) -> Result<ToolInputPayload<T>> {
let value = self
- .final_rx
+ .rx
+ .next()
.await
- .map_err(|_| anyhow!("tool input was not fully received"))?;
- serde_json::from_value(value).map_err(Into::into)
- }
+ .ok_or_else(|| anyhow!("tool input was not fully received"))?;
- /// Returns the next partial JSON snapshot, or `None` when input is complete.
- /// Once this returns `None`, call `recv()` to get the final input.
- pub async fn recv_partial(&mut self) -> Option<serde_json::Value> {
- self.partial_rx.next().await
+ Ok(match value {
+ ToolInputPayload::Partial(payload) => ToolInputPayload::Partial(payload),
+ ToolInputPayload::Full(payload) => {
+ ToolInputPayload::Full(serde_json::from_value(payload)?)
+ }
+ ToolInputPayload::InvalidJson { error_message } => {
+ ToolInputPayload::InvalidJson { error_message }
+ }
+ })
}
fn cast<U: DeserializeOwned>(self) -> ToolInput<U> {
ToolInput {
- partial_rx: self.partial_rx,
- final_rx: self.final_rx,
+ rx: self.rx,
_phantom: PhantomData,
}
}
}
+pub enum ToolInputPayload<T> {
+ Partial(serde_json::Value),
+ Full(T),
+ InvalidJson { error_message: String },
+}
+
pub struct ToolInputSender {
- partial_tx: mpsc::UnboundedSender<serde_json::Value>,
- final_tx: Option<oneshot::Sender<serde_json::Value>>,
+ has_received_final: bool,
+ tx: mpsc::UnboundedSender<ToolInputPayload<serde_json::Value>>,
}
impl ToolInputSender {
pub(crate) fn channel() -> (Self, ToolInput<serde_json::Value>) {
- let (partial_tx, partial_rx) = mpsc::unbounded();
- let (final_tx, final_rx) = oneshot::channel();
+ let (tx, rx) = mpsc::unbounded();
let sender = Self {
- partial_tx,
- final_tx: Some(final_tx),
+ tx,
+ has_received_final: false,
};
let input = ToolInput {
- partial_rx,
- final_rx,
+ rx,
_phantom: PhantomData,
};
(sender, input)
}
pub(crate) fn has_received_final(&self) -> bool {
- self.final_tx.is_none()
+ self.has_received_final
}
- pub(crate) fn send_partial(&self, value: serde_json::Value) {
- self.partial_tx.unbounded_send(value).ok();
+ pub fn send_partial(&mut self, payload: serde_json::Value) {
+ self.tx
+ .unbounded_send(ToolInputPayload::Partial(payload))
+ .ok();
}
- pub(crate) fn send_final(mut self, value: serde_json::Value) {
- // Close the partial channel so recv_partial() returns None
- self.partial_tx.close_channel();
- if let Some(final_tx) = self.final_tx.take() {
- final_tx.send(value).ok();
- }
+ pub fn send_full(&mut self, payload: serde_json::Value) {
+ self.has_received_final = true;
+ self.tx.unbounded_send(ToolInputPayload::Full(payload)).ok();
+ }
+
+ pub fn send_invalid_json(&mut self, error_message: String) {
+ self.has_received_final = true;
+ self.tx
+ .unbounded_send(ToolInputPayload::InvalidJson { error_message })
+ .ok();
}
}
@@ -4251,68 +4311,78 @@ mod tests {
) {
let (thread, event_stream) = setup_thread_for_test(cx).await;
- cx.update(|cx| {
- thread.update(cx, |thread, _cx| {
- let tool_use_id = LanguageModelToolUseId::from("test_tool_id");
- let tool_name: Arc<str> = Arc::from("test_tool");
- let raw_input: Arc<str> = Arc::from("{invalid json");
- let json_parse_error = "expected value at line 1 column 1".to_string();
-
- // Call the function under test
- let result = thread.handle_tool_use_json_parse_error_event(
- tool_use_id.clone(),
- tool_name.clone(),
- raw_input.clone(),
- json_parse_error,
- &event_stream,
- );
-
- // Verify the result is an error
- assert!(result.is_error);
- assert_eq!(result.tool_use_id, tool_use_id);
- assert_eq!(result.tool_name, tool_name);
- assert!(matches!(
- result.content,
- LanguageModelToolResultContent::Text(_)
- ));
-
- // Verify the tool use was added to the message content
- {
- let last_message = thread.pending_message();
- assert_eq!(
- last_message.content.len(),
- 1,
- "Should have one tool_use in content"
- );
-
- match &last_message.content[0] {
- AgentMessageContent::ToolUse(tool_use) => {
- assert_eq!(tool_use.id, tool_use_id);
- assert_eq!(tool_use.name, tool_name);
- assert_eq!(tool_use.raw_input, raw_input.to_string());
- assert!(tool_use.is_input_complete);
- // Should fall back to empty object for invalid JSON
- assert_eq!(tool_use.input, json!({}));
- }
- _ => panic!("Expected ToolUse content"),
- }
- }
-
- // Insert the tool result (simulating what the caller does)
- thread
- .pending_message()
- .tool_results
- .insert(result.tool_use_id.clone(), result);
+ let tool_use_id = LanguageModelToolUseId::from("test_tool_id");
+ let tool_name: Arc<str> = Arc::from("test_tool");
+ let raw_input: Arc<str> = Arc::from("{invalid json");
+ let json_parse_error = "expected value at line 1 column 1".to_string();
+
+ let (_cancellation_tx, cancellation_rx) = watch::channel(false);
+
+ let result = cx
+ .update(|cx| {
+ thread.update(cx, |thread, cx| {
+ // Call the function under test
+ thread
+ .handle_tool_use_json_parse_error_event(
+ tool_use_id.clone(),
+ tool_name.clone(),
+ raw_input.clone(),
+ json_parse_error,
+ &event_stream,
+ cancellation_rx,
+ cx,
+ )
+ .unwrap()
+ })
+ })
+ .await;
+
+ // Verify the result is an error
+ assert!(result.is_error);
+ assert_eq!(result.tool_use_id, tool_use_id);
+ assert_eq!(result.tool_name, tool_name);
+ assert!(matches!(
+ result.content,
+ LanguageModelToolResultContent::Text(_)
+ ));
- // Verify the tool result was added
+ thread.update(cx, |thread, _cx| {
+ // Verify the tool use was added to the message content
+ {
let last_message = thread.pending_message();
assert_eq!(
- last_message.tool_results.len(),
+ last_message.content.len(),
1,
- "Should have one tool_result"
+ "Should have one tool_use in content"
);
- assert!(last_message.tool_results.contains_key(&tool_use_id));
- });
- });
+
+ match &last_message.content[0] {
+ AgentMessageContent::ToolUse(tool_use) => {
+ assert_eq!(tool_use.id, tool_use_id);
+ assert_eq!(tool_use.name, tool_name);
+ assert_eq!(tool_use.raw_input, raw_input.to_string());
+ assert!(tool_use.is_input_complete);
+ // Should fall back to empty object for invalid JSON
+ assert_eq!(tool_use.input, json!({}));
+ }
+ _ => panic!("Expected ToolUse content"),
+ }
+ }
+
+ // Insert the tool result (simulating what the caller does)
+ thread
+ .pending_message()
+ .tool_results
+ .insert(result.tool_use_id.clone(), result);
+
+ // Verify the tool result was added
+ let last_message = thread.pending_message();
+ assert_eq!(
+ last_message.tool_results.len(),
+ 1,
+ "Should have one tool_result"
+ );
+ assert!(last_message.tool_results.contains_key(&tool_use_id));
+ })
}
}
@@ -2,6 +2,7 @@ use super::edit_file_tool::EditFileTool;
use super::restore_file_from_disk_tool::RestoreFileFromDiskTool;
use super::save_file_tool::SaveFileTool;
use super::tool_edit_parser::{ToolEditEvent, ToolEditParser};
+use crate::ToolInputPayload;
use crate::{
AgentTool, Thread, ToolCallEventStream, ToolInput,
edit_agent::{
@@ -12,7 +13,7 @@ use crate::{
use acp_thread::Diff;
use action_log::ActionLog;
use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
-use anyhow::{Context as _, Result};
+use anyhow::Result;
use collections::HashSet;
use futures::FutureExt as _;
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
@@ -188,6 +189,10 @@ pub enum StreamingEditFileToolOutput {
},
Error {
error: String,
+ #[serde(default)]
+ input_path: Option<PathBuf>,
+ #[serde(default)]
+ diff: String,
},
}
@@ -195,6 +200,8 @@ impl StreamingEditFileToolOutput {
pub fn error(error: impl Into<String>) -> Self {
Self::Error {
error: error.into(),
+ input_path: None,
+ diff: String::new(),
}
}
}
@@ -215,7 +222,24 @@ impl std::fmt::Display for StreamingEditFileToolOutput {
)
}
}
- StreamingEditFileToolOutput::Error { error } => write!(f, "{error}"),
+ StreamingEditFileToolOutput::Error {
+ error,
+ diff,
+ input_path,
+ } => {
+ write!(f, "{error}\n")?;
+ if let Some(input_path) = input_path
+ && !diff.is_empty()
+ {
+ write!(
+ f,
+ "Edited {}:\n\n```diff\n{diff}\n```",
+ input_path.display()
+ )
+ } else {
+ write!(f, "No edits were made.")
+ }
+ }
}
}
}
@@ -233,6 +257,14 @@ pub struct StreamingEditFileTool {
language_registry: Arc<LanguageRegistry>,
}
+enum EditSessionResult {
+ Completed(EditSession),
+ Failed {
+ error: String,
+ session: Option<EditSession>,
+ },
+}
+
impl StreamingEditFileTool {
pub fn new(
project: Entity<Project>,
@@ -276,6 +308,158 @@ impl StreamingEditFileTool {
});
}
}
+
+ async fn ensure_buffer_saved(&self, buffer: &Entity<Buffer>, cx: &mut AsyncApp) {
+ let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| {
+ let settings = language_settings::LanguageSettings::for_buffer(buffer, cx);
+ settings.format_on_save != FormatOnSave::Off
+ });
+
+ if format_on_save_enabled {
+ self.project
+ .update(cx, |project, cx| {
+ project.format(
+ HashSet::from_iter([buffer.clone()]),
+ LspFormatTarget::Buffers,
+ false,
+ FormatTrigger::Save,
+ cx,
+ )
+ })
+ .await
+ .log_err();
+ }
+
+ self.project
+ .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+ .await
+ .log_err();
+
+ self.action_log.update(cx, |log, cx| {
+ log.buffer_edited(buffer.clone(), cx);
+ });
+ }
+
+ async fn process_streaming_edits(
+ &self,
+ input: &mut ToolInput<StreamingEditFileToolInput>,
+ event_stream: &ToolCallEventStream,
+ cx: &mut AsyncApp,
+ ) -> EditSessionResult {
+ let mut session: Option<EditSession> = None;
+ let mut last_partial: Option<StreamingEditFileToolPartialInput> = None;
+
+ loop {
+ futures::select! {
+ payload = input.next().fuse() => {
+ match payload {
+ Ok(payload) => match payload {
+ ToolInputPayload::Partial(partial) => {
+ if let Ok(parsed) = serde_json::from_value::<StreamingEditFileToolPartialInput>(partial) {
+ let path_complete = parsed.path.is_some()
+ && parsed.path.as_ref() == last_partial.as_ref().and_then(|partial| partial.path.as_ref());
+
+ last_partial = Some(parsed.clone());
+
+ if session.is_none()
+ && path_complete
+ && let StreamingEditFileToolPartialInput {
+ path: Some(path),
+ display_description: Some(display_description),
+ mode: Some(mode),
+ ..
+ } = &parsed
+ {
+ match EditSession::new(
+ PathBuf::from(path),
+ display_description,
+ *mode,
+ self,
+ event_stream,
+ cx,
+ )
+ .await
+ {
+ Ok(created_session) => session = Some(created_session),
+ Err(error) => {
+ log::error!("Failed to create edit session: {}", error);
+ return EditSessionResult::Failed {
+ error,
+ session: None,
+ };
+ }
+ }
+ }
+
+ if let Some(current_session) = &mut session
+ && let Err(error) = current_session.process(parsed, self, event_stream, cx)
+ {
+ log::error!("Failed to process edit: {}", error);
+ return EditSessionResult::Failed { error, session };
+ }
+ }
+ }
+ ToolInputPayload::Full(full_input) => {
+ let mut session = if let Some(session) = session {
+ session
+ } else {
+ match EditSession::new(
+ full_input.path.clone(),
+ &full_input.display_description,
+ full_input.mode,
+ self,
+ event_stream,
+ cx,
+ )
+ .await
+ {
+ Ok(created_session) => created_session,
+ Err(error) => {
+ log::error!("Failed to create edit session: {}", error);
+ return EditSessionResult::Failed {
+ error,
+ session: None,
+ };
+ }
+ }
+ };
+
+ return match session.finalize(full_input, self, event_stream, cx).await {
+ Ok(()) => EditSessionResult::Completed(session),
+ Err(error) => {
+ log::error!("Failed to finalize edit: {}", error);
+ EditSessionResult::Failed {
+ error,
+ session: Some(session),
+ }
+ }
+ };
+ }
+ ToolInputPayload::InvalidJson { error_message } => {
+ log::error!("Received invalid JSON: {error_message}");
+ return EditSessionResult::Failed {
+ error: error_message,
+ session,
+ };
+ }
+ },
+ Err(error) => {
+ return EditSessionResult::Failed {
+ error: format!("Failed to receive tool input: {error}"),
+ session,
+ };
+ }
+ }
+ }
+ _ = event_stream.cancelled_by_user().fuse() => {
+ return EditSessionResult::Failed {
+ error: "Edit cancelled by user".to_string(),
+ session,
+ };
+ }
+ }
+ }
+ }
}
impl AgentTool for StreamingEditFileTool {
@@ -348,94 +532,40 @@ impl AgentTool for StreamingEditFileTool {
cx: &mut App,
) -> Task<Result<Self::Output, Self::Output>> {
cx.spawn(async move |cx: &mut AsyncApp| {
- let mut state: Option<EditSession> = None;
- let mut last_partial: Option<StreamingEditFileToolPartialInput> = None;
- loop {
- futures::select! {
- partial = input.recv_partial().fuse() => {
- let Some(partial_value) = partial else { break };
- if let Ok(parsed) = serde_json::from_value::<StreamingEditFileToolPartialInput>(partial_value) {
- let path_complete = parsed.path.is_some()
- && parsed.path.as_ref() == last_partial.as_ref().and_then(|p| p.path.as_ref());
-
- last_partial = Some(parsed.clone());
-
- if state.is_none()
- && path_complete
- && let StreamingEditFileToolPartialInput {
- path: Some(path),
- display_description: Some(display_description),
- mode: Some(mode),
- ..
- } = &parsed
- {
- match EditSession::new(
- &PathBuf::from(path),
- display_description,
- *mode,
- &self,
- &event_stream,
- cx,
- )
- .await
- {
- Ok(session) => state = Some(session),
- Err(e) => {
- log::error!("Failed to create edit session: {}", e);
- return Err(e);
- }
- }
- }
-
- if let Some(state) = &mut state {
- if let Err(e) = state.process(parsed, &self, &event_stream, cx) {
- log::error!("Failed to process edit: {}", e);
- return Err(e);
- }
- }
- }
- }
- _ = event_stream.cancelled_by_user().fuse() => {
- return Err(StreamingEditFileToolOutput::error("Edit cancelled by user"));
- }
- }
- }
- let full_input =
- input
- .recv()
- .await
- .map_err(|e| {
- let err = StreamingEditFileToolOutput::error(format!("Failed to receive tool input: {e}"));
- log::error!("Failed to receive tool input: {e}");
- err
- })?;
-
- let mut state = if let Some(state) = state {
- state
- } else {
- match EditSession::new(
- &full_input.path,
- &full_input.display_description,
- full_input.mode,
- &self,
- &event_stream,
- cx,
- )
+ match self
+ .process_streaming_edits(&mut input, &event_stream, cx)
.await
- {
- Ok(session) => session,
- Err(e) => {
- log::error!("Failed to create edit session: {}", e);
- return Err(e);
- }
+ {
+ EditSessionResult::Completed(session) => {
+ self.ensure_buffer_saved(&session.buffer, cx).await;
+ let (new_text, diff) = session.compute_new_text_and_diff(cx).await;
+ Ok(StreamingEditFileToolOutput::Success {
+ old_text: session.old_text.clone(),
+ new_text,
+ input_path: session.input_path,
+ diff,
+ })
}
- };
- match state.finalize(full_input, &self, &event_stream, cx).await {
- Ok(output) => Ok(output),
- Err(e) => {
- log::error!("Failed to finalize edit: {}", e);
- Err(e)
+ EditSessionResult::Failed {
+ error,
+ session: Some(session),
+ } => {
+ self.ensure_buffer_saved(&session.buffer, cx).await;
+ let (_new_text, diff) = session.compute_new_text_and_diff(cx).await;
+ Err(StreamingEditFileToolOutput::Error {
+ error,
+ input_path: Some(session.input_path),
+ diff,
+ })
}
+ EditSessionResult::Failed {
+ error,
+ session: None,
+ } => Err(StreamingEditFileToolOutput::Error {
+ error,
+ input_path: None,
+ diff: String::new(),
+ }),
}
})
}
@@ -472,6 +602,7 @@ impl AgentTool for StreamingEditFileTool {
pub struct EditSession {
abs_path: PathBuf,
+ input_path: PathBuf,
buffer: Entity<Buffer>,
old_text: Arc<String>,
diff: Entity<Diff>,
@@ -518,23 +649,21 @@ impl EditPipeline {
impl EditSession {
async fn new(
- path: &PathBuf,
+ path: PathBuf,
display_description: &str,
mode: StreamingEditFileMode,
tool: &StreamingEditFileTool,
event_stream: &ToolCallEventStream,
cx: &mut AsyncApp,
- ) -> Result<Self, StreamingEditFileToolOutput> {
- let project_path = cx
- .update(|cx| resolve_path(mode, &path, &tool.project, cx))
- .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
+ ) -> Result<Self, String> {
+ let project_path = cx.update(|cx| resolve_path(mode, &path, &tool.project, cx))?;
let Some(abs_path) = cx.update(|cx| tool.project.read(cx).absolute_path(&project_path, cx))
else {
- return Err(StreamingEditFileToolOutput::error(format!(
+ return Err(format!(
"Worktree at '{}' does not exist",
path.to_string_lossy()
- )));
+ ));
};
event_stream.update_fields(
@@ -543,13 +672,13 @@ impl EditSession {
cx.update(|cx| tool.authorize(&path, &display_description, event_stream, cx))
.await
- .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
+ .map_err(|e| e.to_string())?;
let buffer = tool
.project
.update(cx, |project, cx| project.open_buffer(project_path, cx))
.await
- .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
+ .map_err(|e| e.to_string())?;
ensure_buffer_saved(&buffer, &abs_path, tool, cx)?;
@@ -578,6 +707,7 @@ impl EditSession {
Ok(Self {
abs_path,
+ input_path: path,
buffer,
old_text,
diff,
@@ -594,22 +724,20 @@ impl EditSession {
tool: &StreamingEditFileTool,
event_stream: &ToolCallEventStream,
cx: &mut AsyncApp,
- ) -> Result<StreamingEditFileToolOutput, StreamingEditFileToolOutput> {
- let old_text = self.old_text.clone();
-
+ ) -> Result<(), String> {
match input.mode {
StreamingEditFileMode::Write => {
- let content = input.content.ok_or_else(|| {
- StreamingEditFileToolOutput::error("'content' field is required for write mode")
- })?;
+ let content = input
+ .content
+ .ok_or_else(|| "'content' field is required for write mode".to_string())?;
let events = self.parser.finalize_content(&content);
self.process_events(&events, tool, event_stream, cx)?;
}
StreamingEditFileMode::Edit => {
- let edits = input.edits.ok_or_else(|| {
- StreamingEditFileToolOutput::error("'edits' field is required for edit mode")
- })?;
+ let edits = input
+ .edits
+ .ok_or_else(|| "'edits' field is required for edit mode".to_string())?;
let events = self.parser.finalize_edits(&edits);
self.process_events(&events, tool, event_stream, cx)?;
@@ -625,53 +753,15 @@ impl EditSession {
}
}
}
+ Ok(())
+ }
- let format_on_save_enabled = self.buffer.read_with(cx, |buffer, cx| {
- let settings = language_settings::LanguageSettings::for_buffer(buffer, cx);
- settings.format_on_save != FormatOnSave::Off
- });
-
- if format_on_save_enabled {
- tool.action_log.update(cx, |log, cx| {
- log.buffer_edited(self.buffer.clone(), cx);
- });
-
- let format_task = tool.project.update(cx, |project, cx| {
- project.format(
- HashSet::from_iter([self.buffer.clone()]),
- LspFormatTarget::Buffers,
- false,
- FormatTrigger::Save,
- cx,
- )
- });
- futures::select! {
- result = format_task.fuse() => { result.log_err(); },
- _ = event_stream.cancelled_by_user().fuse() => {
- return Err(StreamingEditFileToolOutput::error("Edit cancelled by user"));
- }
- };
- }
-
- let save_task = tool.project.update(cx, |project, cx| {
- project.save_buffer(self.buffer.clone(), cx)
- });
- futures::select! {
- result = save_task.fuse() => { result.map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; },
- _ = event_stream.cancelled_by_user().fuse() => {
- return Err(StreamingEditFileToolOutput::error("Edit cancelled by user"));
- }
- };
-
- tool.action_log.update(cx, |log, cx| {
- log.buffer_edited(self.buffer.clone(), cx);
- });
-
+ async fn compute_new_text_and_diff(&self, cx: &mut AsyncApp) -> (String, String) {
let new_snapshot = self.buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let (new_text, unified_diff) = cx
.background_spawn({
let new_snapshot = new_snapshot.clone();
- let old_text = old_text.clone();
+ let old_text = self.old_text.clone();
async move {
let new_text = new_snapshot.text();
let diff = language::unified_diff(&old_text, &new_text);
@@ -679,14 +769,7 @@ impl EditSession {
}
})
.await;
-
- let output = StreamingEditFileToolOutput::Success {
- input_path: input.path,
- new_text,
- old_text: old_text.clone(),
- diff: unified_diff,
- };
- Ok(output)
+ (new_text, unified_diff)
}
fn process(
@@ -695,7 +778,7 @@ impl EditSession {
tool: &StreamingEditFileTool,
event_stream: &ToolCallEventStream,
cx: &mut AsyncApp,
- ) -> Result<(), StreamingEditFileToolOutput> {
+ ) -> Result<(), String> {
match &self.mode {
StreamingEditFileMode::Write => {
if let Some(content) = &partial.content {
@@ -719,7 +802,7 @@ impl EditSession {
tool: &StreamingEditFileTool,
event_stream: &ToolCallEventStream,
cx: &mut AsyncApp,
- ) -> Result<(), StreamingEditFileToolOutput> {
+ ) -> Result<(), String> {
for event in events {
match event {
ToolEditEvent::ContentChunk { chunk } => {
@@ -969,14 +1052,14 @@ fn extract_match(
buffer: &Entity<Buffer>,
edit_index: &usize,
cx: &mut AsyncApp,
-) -> Result<Range<usize>, StreamingEditFileToolOutput> {
+) -> Result<Range<usize>, String> {
match matches.len() {
- 0 => Err(StreamingEditFileToolOutput::error(format!(
+ 0 => Err(format!(
"Could not find matching text for edit at index {}. \
The old_text did not match any content in the file. \
Please read the file again to get the current content.",
edit_index,
- ))),
+ )),
1 => Ok(matches.into_iter().next().unwrap()),
_ => {
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
@@ -985,12 +1068,12 @@ fn extract_match(
.map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string())
.collect::<Vec<_>>()
.join(", ");
- Err(StreamingEditFileToolOutput::error(format!(
+ Err(format!(
"Edit {} matched multiple locations in the file at lines: {}. \
Please provide more context in old_text to uniquely \
identify the location.",
edit_index, lines
- )))
+ ))
}
}
}
@@ -1022,7 +1105,7 @@ fn ensure_buffer_saved(
abs_path: &PathBuf,
tool: &StreamingEditFileTool,
cx: &mut AsyncApp,
-) -> Result<(), StreamingEditFileToolOutput> {
+) -> Result<(), String> {
let last_read_mtime = tool
.action_log
.read_with(cx, |log, _| log.file_read_time(abs_path));
@@ -1063,15 +1146,14 @@ fn ensure_buffer_saved(
then ask them to save or revert the file manually and inform you when it's ok to proceed."
}
};
- return Err(StreamingEditFileToolOutput::error(message));
+ return Err(message.to_string());
}
if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
if current != last_read {
- return Err(StreamingEditFileToolOutput::error(
- "The file has been modified since you last read it. \
- Please read the file again to get the current state before editing it.",
- ));
+ return Err("The file has been modified since you last read it. \
+ Please read the file again to get the current state before editing it."
+ .to_string());
}
}
@@ -1083,56 +1165,63 @@ fn resolve_path(
path: &PathBuf,
project: &Entity<Project>,
cx: &mut App,
-) -> Result<ProjectPath> {
+) -> Result<ProjectPath, String> {
let project = project.read(cx);
match mode {
StreamingEditFileMode::Edit => {
let path = project
.find_project_path(&path, cx)
- .context("Can't edit file: path not found")?;
+ .ok_or_else(|| "Can't edit file: path not found".to_string())?;
let entry = project
.entry_for_path(&path, cx)
- .context("Can't edit file: path not found")?;
+ .ok_or_else(|| "Can't edit file: path not found".to_string())?;
- anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
- Ok(path)
+ if entry.is_file() {
+ Ok(path)
+ } else {
+ Err("Can't edit file: path is a directory".to_string())
+ }
}
StreamingEditFileMode::Write => {
if let Some(path) = project.find_project_path(&path, cx)
&& let Some(entry) = project.entry_for_path(&path, cx)
{
- anyhow::ensure!(entry.is_file(), "Can't write to file: path is a directory");
- return Ok(path);
+ if entry.is_file() {
+ return Ok(path);
+ } else {
+ return Err("Can't write to file: path is a directory".to_string());
+ }
}
- let parent_path = path.parent().context("Can't create file: incorrect path")?;
+ let parent_path = path
+ .parent()
+ .ok_or_else(|| "Can't create file: incorrect path".to_string())?;
let parent_project_path = project.find_project_path(&parent_path, cx);
let parent_entry = parent_project_path
.as_ref()
.and_then(|path| project.entry_for_path(path, cx))
- .context("Can't create file: parent directory doesn't exist")?;
+ .ok_or_else(|| "Can't create file: parent directory doesn't exist")?;
- anyhow::ensure!(
- parent_entry.is_dir(),
- "Can't create file: parent is not a directory"
- );
+ if !parent_entry.is_dir() {
+ return Err("Can't create file: parent is not a directory".to_string());
+ }
let file_name = path
.file_name()
.and_then(|file_name| file_name.to_str())
.and_then(|file_name| RelPath::unix(file_name).ok())
- .context("Can't create file: invalid filename")?;
+ .ok_or_else(|| "Can't create file: invalid filename".to_string())?;
let new_file_path = parent_project_path.map(|parent| ProjectPath {
path: parent.path.join(file_name),
..parent
});
- new_file_path.context("Can't create file")
+ new_file_path.ok_or_else(|| "Can't create file".to_string())
}
}
}
@@ -1382,10 +1471,17 @@ mod tests {
})
.await;
- let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else {
+ let StreamingEditFileToolOutput::Error {
+ error,
+ diff,
+ input_path,
+ } = result.unwrap_err()
+ else {
panic!("expected error");
};
assert_eq!(error, "Can't edit file: path not found");
+ assert!(diff.is_empty());
+ assert_eq!(input_path, None);
}
#[gpui::test]
@@ -1411,7 +1507,7 @@ mod tests {
})
.await;
- let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else {
+ let StreamingEditFileToolOutput::Error { error, .. } = result.unwrap_err() else {
panic!("expected error");
};
assert!(
@@ -1424,7 +1520,7 @@ mod tests {
async fn test_streaming_early_buffer_open(cx: &mut TestAppContext) {
let (tool, _project, _action_log, _fs, _thread) =
setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1447,7 +1543,7 @@ mod tests {
cx.run_until_parked();
// Now send the final complete input
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Edit lines",
"path": "root/file.txt",
"mode": "edit",
@@ -1465,7 +1561,7 @@ mod tests {
async fn test_streaming_path_completeness_heuristic(cx: &mut TestAppContext) {
let (tool, _project, _action_log, _fs, _thread) =
setup_test(cx, json!({"file.txt": "hello world"})).await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1485,7 +1581,7 @@ mod tests {
cx.run_until_parked();
// Send final
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Overwrite file",
"path": "root/file.txt",
"mode": "write",
@@ -1503,7 +1599,7 @@ mod tests {
async fn test_streaming_cancellation_during_partials(cx: &mut TestAppContext) {
let (tool, _project, _action_log, _fs, _thread) =
setup_test(cx, json!({"file.txt": "hello world"})).await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver, mut cancellation_tx) =
ToolCallEventStream::test_with_cancellation();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1521,7 +1617,7 @@ mod tests {
drop(sender);
let result = task.await;
- let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else {
+ let StreamingEditFileToolOutput::Error { error, .. } = result.unwrap_err() else {
panic!("expected error");
};
assert!(
@@ -1537,7 +1633,7 @@ mod tests {
json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}),
)
.await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1578,7 +1674,7 @@ mod tests {
cx.run_until_parked();
// Send final complete input
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Edit multiple lines",
"path": "root/file.txt",
"mode": "edit",
@@ -1601,7 +1697,7 @@ mod tests {
#[gpui::test]
async fn test_streaming_create_file_with_partials(cx: &mut TestAppContext) {
let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1625,7 +1721,7 @@ mod tests {
cx.run_until_parked();
// Final with full content
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Create new file",
"path": "root/dir/new_file.txt",
"mode": "write",
@@ -1643,12 +1739,12 @@ mod tests {
async fn test_streaming_no_partials_direct_final(cx: &mut TestAppContext) {
let (tool, _project, _action_log, _fs, _thread) =
setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
// Send final immediately with no partials (simulates non-streaming path)
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Edit lines",
"path": "root/file.txt",
"mode": "edit",
@@ -1669,7 +1765,7 @@ mod tests {
json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}),
)
.await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1739,7 +1835,7 @@ mod tests {
);
// Send final complete input
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Edit multiple lines",
"path": "root/file.txt",
"mode": "edit",
@@ -1767,7 +1863,7 @@ mod tests {
async fn test_streaming_incremental_three_edits(cx: &mut TestAppContext) {
let (tool, project, _action_log, _fs, _thread) =
setup_test(cx, json!({"file.txt": "aaa\nbbb\nccc\nddd\neee\n"})).await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1835,7 +1931,7 @@ mod tests {
assert_eq!(buffer_text.as_deref(), Some("AAA\nbbb\nCCC\nddd\nEEEeee\n"));
// Send final
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Edit three lines",
"path": "root/file.txt",
"mode": "edit",
@@ -1857,7 +1953,7 @@ mod tests {
async fn test_streaming_edit_failure_mid_stream(cx: &mut TestAppContext) {
let (tool, project, _action_log, _fs, _thread) =
setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1893,16 +1989,17 @@ mod tests {
}));
cx.run_until_parked();
- // Verify edit 1 was applied
- let buffer_text = project.update(cx, |project, cx| {
+ let buffer = project.update(cx, |project, cx| {
let pp = project
.find_project_path(&PathBuf::from("root/file.txt"), cx)
.unwrap();
- project.get_open_buffer(&pp, cx).map(|b| b.read(cx).text())
+ project.get_open_buffer(&pp, cx).unwrap()
});
+
+ // Verify edit 1 was applied
+ let buffer_text = buffer.read_with(cx, |buffer, _cx| buffer.text());
assert_eq!(
- buffer_text.as_deref(),
- Some("MODIFIED\nline 2\nline 3\n"),
+ buffer_text, "MODIFIED\nline 2\nline 3\n",
"First edit should be applied even though second edit will fail"
);
@@ -1925,20 +2022,32 @@ mod tests {
drop(sender);
let result = task.await;
- let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else {
+ let StreamingEditFileToolOutput::Error {
+ error,
+ diff,
+ input_path,
+ } = result.unwrap_err()
+ else {
panic!("expected error");
};
+
assert!(
error.contains("Could not find matching text for edit at index 1"),
"Expected error about edit 1 failing, got: {error}"
);
+ // Ensure that first edit was applied successfully and that we saved the buffer
+ assert_eq!(input_path, Some(PathBuf::from("root/file.txt")));
+ assert_eq!(
+ diff,
+ "@@ -1,3 +1,3 @@\n-line 1\n+MODIFIED\n line 2\n line 3\n"
+ );
}
#[gpui::test]
async fn test_streaming_single_edit_no_incremental(cx: &mut TestAppContext) {
let (tool, project, _action_log, _fs, _thread) =
setup_test(cx, json!({"file.txt": "hello world\n"})).await;
- let (sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
+ let (mut sender, input) = ToolInput::<StreamingEditFileToolInput>::test();
let (event_stream, _receiver) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -1975,7 +2084,7 @@ mod tests {
);
// Send final — the edit is applied during finalization
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Single edit",
"path": "root/file.txt",
"mode": "edit",
@@ -1993,7 +2102,7 @@ mod tests {
async fn test_streaming_input_partials_then_final(cx: &mut TestAppContext) {
let (tool, _project, _action_log, _fs, _thread) =
setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await;
- let (sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
+ let (mut sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
ToolInput::test();
let (event_stream, _event_rx) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -2020,7 +2129,7 @@ mod tests {
cx.run_until_parked();
// Send the final complete input
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Edit lines",
"path": "root/file.txt",
"mode": "edit",
@@ -2038,7 +2147,7 @@ mod tests {
async fn test_streaming_input_sender_dropped_before_final(cx: &mut TestAppContext) {
let (tool, _project, _action_log, _fs, _thread) =
setup_test(cx, json!({"file.txt": "hello world\n"})).await;
- let (sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
+ let (mut sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
ToolInput::test();
let (event_stream, _event_rx) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -2064,7 +2173,7 @@ mod tests {
// Create a channel and send multiple partials before a final, then use
// ToolInput::resolved-style immediate delivery to confirm recv() works
// when partials are already buffered.
- let (sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
+ let (mut sender, input): (ToolInputSender, ToolInput<StreamingEditFileToolInput>) =
ToolInput::test();
let (event_stream, _event_rx) = ToolCallEventStream::test();
let task = cx.update(|cx| tool.clone().run(input, event_stream, cx));
@@ -2077,7 +2186,7 @@ mod tests {
"path": "root/dir/new.txt",
"mode": "write"
}));
- sender.send_final(json!({
+ sender.send_full(json!({
"display_description": "Create",
"path": "root/dir/new.txt",
"mode": "write",
@@ -2109,13 +2218,13 @@ mod tests {
let result = test_resolve_path(&mode, "root/dir/subdir", cx);
assert_eq!(
- result.await.unwrap_err().to_string(),
+ result.await.unwrap_err(),
"Can't write to file: path is a directory"
);
let result = test_resolve_path(&mode, "root/dir/nonexistent_dir/new.txt", cx);
assert_eq!(
- result.await.unwrap_err().to_string(),
+ result.await.unwrap_err(),
"Can't create file: parent directory doesn't exist"
);
}
@@ -2133,14 +2242,11 @@ mod tests {
assert_resolved_path_eq(result.await, rel_path(path_without_root));
let result = test_resolve_path(&mode, "root/nonexistent.txt", cx);
- assert_eq!(
- result.await.unwrap_err().to_string(),
- "Can't edit file: path not found"
- );
+ assert_eq!(result.await.unwrap_err(), "Can't edit file: path not found");
let result = test_resolve_path(&mode, "root/dir", cx);
assert_eq!(
- result.await.unwrap_err().to_string(),
+ result.await.unwrap_err(),
"Can't edit file: path is a directory"
);
}
@@ -5175,7 +5175,7 @@ mod tests {
multi_workspace
.read_with(cx, |multi_workspace, _cx| {
assert_eq!(
- multi_workspace.workspaces().len(),
+ multi_workspace.workspaces().count(),
1,
"LocalProject should not create a new workspace"
);
@@ -5451,6 +5451,11 @@ mod tests {
let multi_workspace =
cx.add_window(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ multi_workspace
+ .update(cx, |multi_workspace, _, cx| {
+ multi_workspace.open_sidebar(cx);
+ })
+ .unwrap();
let workspace = multi_workspace
.read_with(cx, |multi_workspace, _cx| {
@@ -5538,15 +5543,14 @@ mod tests {
.read_with(cx, |multi_workspace, cx| {
// There should be more than one workspace now (the original + the new worktree).
assert!(
- multi_workspace.workspaces().len() > 1,
+ multi_workspace.workspaces().count() > 1,
"expected a new workspace to have been created, found {}",
- multi_workspace.workspaces().len(),
+ multi_workspace.workspaces().count(),
);
// Check the newest workspace's panel for the correct agent.
let new_workspace = multi_workspace
.workspaces()
- .iter()
.find(|ws| ws.entity_id() != workspace.entity_id())
.expect("should find the new workspace");
let new_panel = new_workspace
@@ -812,7 +812,7 @@ impl ConversationView {
let agent_id = self.agent.agent_id();
let session_capabilities = Arc::new(RwLock::new(SessionCapabilities::new(
thread.read(cx).prompt_capabilities(),
- vec![],
+ thread.read(cx).available_commands().to_vec(),
)));
let action_log = thread.read(cx).action_log().clone();
@@ -1448,40 +1448,24 @@ impl ConversationView {
self.emit_token_limit_telemetry_if_needed(thread, cx);
}
AcpThreadEvent::AvailableCommandsUpdated(available_commands) => {
- let mut available_commands = available_commands.clone();
+ if let Some(thread_view) = self.thread_view(&thread_id) {
+ let has_commands = !available_commands.is_empty();
- if thread
- .read(cx)
- .connection()
- .auth_methods()
- .iter()
- .any(|method| method.id().0.as_ref() == "claude-login")
- {
- available_commands.push(acp::AvailableCommand::new("login", "Authenticate"));
- available_commands.push(acp::AvailableCommand::new("logout", "Authenticate"));
- }
-
- let has_commands = !available_commands.is_empty();
- if let Some(active) = self.active_thread() {
- active.update(cx, |active, _cx| {
- active
- .session_capabilities
- .write()
- .set_available_commands(available_commands);
- });
- }
-
- let agent_display_name = self
- .agent_server_store
- .read(cx)
- .agent_display_name(&self.agent.agent_id())
- .unwrap_or_else(|| self.agent.agent_id().0.to_string().into());
+ let agent_display_name = self
+ .agent_server_store
+ .read(cx)
+ .agent_display_name(&self.agent.agent_id())
+ .unwrap_or_else(|| self.agent.agent_id().0.to_string().into());
- if let Some(active) = self.active_thread() {
let new_placeholder =
placeholder_text(agent_display_name.as_ref(), has_commands);
- active.update(cx, |active, cx| {
- active.message_editor.update(cx, |editor, cx| {
+
+ thread_view.update(cx, |thread_view, cx| {
+ thread_view
+ .session_capabilities
+ .write()
+ .set_available_commands(available_commands.clone());
+ thread_view.message_editor.update(cx, |editor, cx| {
editor.set_placeholder_text(&new_placeholder, window, cx);
});
});
@@ -2348,9 +2332,9 @@ impl ConversationView {
}
}
+ #[cfg(feature = "audio")]
fn play_notification_sound(&self, window: &Window, cx: &mut App) {
- let settings = AgentSettings::get_global(cx);
- let _visible = window.is_window_active()
+ let visible = window.is_window_active()
&& if let Some(mw) = window.root::<MultiWorkspace>().flatten() {
self.agent_panel_visible(&mw, cx)
} else {
@@ -2358,8 +2342,8 @@ impl ConversationView {
.upgrade()
.is_some_and(|workspace| AgentPanel::is_visible(&workspace, cx))
};
- #[cfg(feature = "audio")]
- if settings.play_sound_when_agent_done.should_play(_visible) {
+ let settings = AgentSettings::get_global(cx);
+ if settings.play_sound_when_agent_done.should_play(visible) {
Audio::play_sound(Sound::AgentDone, cx);
}
}
@@ -2989,6 +2973,166 @@ pub(crate) mod tests {
});
}
+ #[derive(Clone)]
+ struct RestoredAvailableCommandsConnection;
+
+ impl AgentConnection for RestoredAvailableCommandsConnection {
+ fn agent_id(&self) -> AgentId {
+ AgentId::new("restored-available-commands")
+ }
+
+ fn telemetry_id(&self) -> SharedString {
+ "restored-available-commands".into()
+ }
+
+ fn new_session(
+ self: Rc<Self>,
+ project: Entity<Project>,
+ _work_dirs: PathList,
+ cx: &mut App,
+ ) -> Task<gpui::Result<Entity<AcpThread>>> {
+ let thread = build_test_thread(
+ self,
+ project,
+ "RestoredAvailableCommandsConnection",
+ SessionId::new("new-session"),
+ cx,
+ );
+ Task::ready(Ok(thread))
+ }
+
+ fn supports_load_session(&self) -> bool {
+ true
+ }
+
+ fn load_session(
+ self: Rc<Self>,
+ session_id: acp::SessionId,
+ project: Entity<Project>,
+ _work_dirs: PathList,
+ _title: Option<SharedString>,
+ cx: &mut App,
+ ) -> Task<gpui::Result<Entity<AcpThread>>> {
+ let thread = build_test_thread(
+ self,
+ project,
+ "RestoredAvailableCommandsConnection",
+ session_id,
+ cx,
+ );
+
+ thread
+ .update(cx, |thread, cx| {
+ thread.handle_session_update(
+ acp::SessionUpdate::AvailableCommandsUpdate(
+ acp::AvailableCommandsUpdate::new(vec![acp::AvailableCommand::new(
+ "help", "Get help",
+ )]),
+ ),
+ cx,
+ )
+ })
+ .expect("available commands update should succeed");
+
+ Task::ready(Ok(thread))
+ }
+
+ fn auth_methods(&self) -> &[acp::AuthMethod] {
+ &[]
+ }
+
+ fn authenticate(
+ &self,
+ _method_id: acp::AuthMethodId,
+ _cx: &mut App,
+ ) -> Task<gpui::Result<()>> {
+ Task::ready(Ok(()))
+ }
+
+ fn prompt(
+ &self,
+ _id: Option<acp_thread::UserMessageId>,
+ _params: acp::PromptRequest,
+ _cx: &mut App,
+ ) -> Task<gpui::Result<acp::PromptResponse>> {
+ Task::ready(Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)))
+ }
+
+ fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {}
+
+ fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+ self
+ }
+ }
+
+ #[gpui::test]
+ async fn test_restored_threads_keep_available_commands(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ let project = Project::test(fs, [], cx).await;
+ let (multi_workspace, cx) =
+ cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
+
+ let thread_store = cx.update(|_window, cx| cx.new(|cx| ThreadStore::new(cx)));
+ let connection_store =
+ cx.update(|_window, cx| cx.new(|cx| AgentConnectionStore::new(project.clone(), cx)));
+
+ let conversation_view = cx.update(|window, cx| {
+ cx.new(|cx| {
+ ConversationView::new(
+ Rc::new(StubAgentServer::new(RestoredAvailableCommandsConnection)),
+ connection_store,
+ Agent::Custom { id: "Test".into() },
+ Some(SessionId::new("restored-session")),
+ None,
+ None,
+ None,
+ workspace.downgrade(),
+ project,
+ Some(thread_store),
+ None,
+ window,
+ cx,
+ )
+ })
+ });
+
+ cx.run_until_parked();
+
+ let message_editor = message_editor(&conversation_view, cx);
+ let editor =
+ message_editor.update(cx, |message_editor, _cx| message_editor.editor().clone());
+ let placeholder = editor.update(cx, |editor, cx| editor.placeholder_text(cx));
+
+ active_thread(&conversation_view, cx).read_with(cx, |view, _cx| {
+ let available_commands = view
+ .session_capabilities
+ .read()
+ .available_commands()
+ .to_vec();
+ assert_eq!(available_commands.len(), 1);
+ assert_eq!(available_commands[0].name.as_str(), "help");
+ assert_eq!(available_commands[0].description.as_str(), "Get help");
+ });
+
+ assert_eq!(
+ placeholder,
+ Some("Message Test — @ to include context, / for commands".to_string())
+ );
+
+ message_editor.update_in(cx, |editor, window, cx| {
+ editor.set_text("/help", window, cx);
+ });
+
+ let contents_result = message_editor
+ .update(cx, |editor, cx| editor.contents(false, cx))
+ .await;
+
+ assert!(contents_result.is_ok());
+ }
+
#[gpui::test]
async fn test_resume_thread_uses_session_cwd_when_inside_project(cx: &mut TestAppContext) {
init_test(cx);
@@ -3375,7 +3519,6 @@ pub(crate) mod tests {
// Verify workspace1 is no longer the active workspace
multi_workspace_handle
.read_with(cx, |mw, _cx| {
- assert_eq!(mw.active_workspace_index(), 1);
assert_ne!(mw.workspace(), &workspace1);
})
.unwrap();
@@ -344,7 +344,8 @@ impl ThreadView {
) -> Self {
let id = thread.read(cx).session_id().clone();
- let placeholder = placeholder_text(agent_display_name.as_ref(), false);
+ let has_commands = !session_capabilities.read().available_commands().is_empty();
+ let placeholder = placeholder_text(agent_display_name.as_ref(), has_commands);
let history_subscription = history.as_ref().map(|h| {
cx.observe(h, |this, history, cx| {
@@ -7389,9 +7390,8 @@ impl ThreadView {
.gap_2()
.map(|this| {
if card_layout {
- this.when(context_ix > 0, |this| {
- this.pt_2()
- .border_t_1()
+ this.p_2().when(context_ix > 0, |this| {
+ this.border_t_1()
.border_color(self.tool_card_border_color(cx))
})
} else {
@@ -218,6 +218,13 @@ impl ThreadsArchiveView {
handle.focus(window, cx);
}
+ pub fn is_filter_editor_focused(&self, window: &Window, cx: &App) -> bool {
+ self.filter_editor
+ .read(cx)
+ .focus_handle(cx)
+ .is_focused(window)
+ }
+
fn update_items(&mut self, cx: &mut Context<Self>) {
let sessions = ThreadMetadataStore::global(cx)
.read(cx)
@@ -346,7 +353,6 @@ impl ThreadsArchiveView {
.map(|mw| {
mw.read(cx)
.workspaces()
- .iter()
.filter_map(|ws| ws.read(cx).database_id())
.collect()
})
@@ -435,6 +435,7 @@ impl Server {
.add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
.add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
.add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
+ .add_request_handler(forward_read_only_project_request::<proto::GitGetHeadSha>)
.add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
.add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
.add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
@@ -424,6 +424,58 @@ async fn test_remote_git_worktrees(
);
}
+#[gpui::test]
+async fn test_remote_git_head_sha(
+ executor: BackgroundExecutor,
+ cx_a: &mut TestAppContext,
+ cx_b: &mut TestAppContext,
+) {
+ let mut server = TestServer::start(executor.clone()).await;
+ let client_a = server.create_client(cx_a, "user_a").await;
+ let client_b = server.create_client(cx_b, "user_b").await;
+ server
+ .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)])
+ .await;
+ let active_call_a = cx_a.read(ActiveCall::global);
+
+ client_a
+ .fs()
+ .insert_tree(
+ path!("/project"),
+ json!({ ".git": {}, "file.txt": "content" }),
+ )
+ .await;
+
+ let (project_a, _) = client_a.build_local_project(path!("/project"), cx_a).await;
+ let local_head_sha = cx_a.update(|cx| {
+ project_a
+ .read(cx)
+ .active_repository(cx)
+ .unwrap()
+ .update(cx, |repository, _| repository.head_sha())
+ });
+ let local_head_sha = local_head_sha.await.unwrap().unwrap();
+
+ let project_id = active_call_a
+ .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx))
+ .await
+ .unwrap();
+ let project_b = client_b.join_remote_project(project_id, cx_b).await;
+
+ executor.run_until_parked();
+
+ let remote_head_sha = cx_b.update(|cx| {
+ project_b
+ .read(cx)
+ .active_repository(cx)
+ .unwrap()
+ .update(cx, |repository, _| repository.head_sha())
+ });
+ let remote_head_sha = remote_head_sha.await.unwrap();
+
+ assert_eq!(remote_head_sha.unwrap(), local_head_sha);
+}
+
#[gpui::test]
async fn test_linked_worktrees_sync(
executor: BackgroundExecutor,
@@ -883,7 +883,13 @@ RUN sed -i -E 's/((^|\s)PATH=)([^\$]*)$/\1\${{PATH:-\3}}/g' /etc/profile || true
labels: None,
build: Some(DockerComposeServiceBuild {
context: Some(
- features_build_info.empty_context_dir.display().to_string(),
+ main_service
+ .build
+ .as_ref()
+ .and_then(|b| b.context.clone())
+ .unwrap_or_else(|| {
+ features_build_info.empty_context_dir.display().to_string()
+ }),
),
dockerfile: Some(dockerfile_path.display().to_string()),
args: Some(build_args),
@@ -3546,6 +3552,27 @@ ENV DOCKER_BUILDKIT=1
"#
);
+ let build_override = files
+ .iter()
+ .find(|f| {
+ f.file_name()
+ .is_some_and(|s| s.display().to_string() == "docker_compose_build.json")
+ })
+ .expect("to be found");
+ let build_override = test_dependencies.fs.load(build_override).await.unwrap();
+ let build_config: DockerComposeConfig =
+ serde_json_lenient::from_str(&build_override).unwrap();
+ let build_context = build_config
+ .services
+ .get("app")
+ .and_then(|s| s.build.as_ref())
+ .and_then(|b| b.context.clone())
+ .expect("build override should have a context");
+ assert_eq!(
+ build_context, ".",
+ "build override should preserve the original build context from docker-compose.yml"
+ );
+
let runtime_override = files
.iter()
.find(|f| {
@@ -2707,6 +2707,65 @@ async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppConte
});
}
+#[gpui::test]
+async fn test_v3_prediction_strips_cursor_marker_from_edit_text(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.txt": "hello"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project
+ .find_project_path(path!("root/foo.txt"), cx)
+ .unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(0, 5));
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
+ let excerpt_length = request.input.cursor_excerpt.len();
+ respond_tx
+ .send(PredictEditsV3Response {
+ request_id: Uuid::new_v4().to_string(),
+ output: "hello<|user_cursor|> world".to_string(),
+ editable_range: 0..excerpt_length,
+ model_version: None,
+ })
+ .unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.update(cx, |ep_store, cx| {
+ let prediction = ep_store
+ .prediction_at(&buffer, None, &project, cx)
+ .expect("should have prediction");
+ let snapshot = buffer.read(cx).snapshot();
+ let edits: Vec<_> = prediction
+ .edits
+ .iter()
+ .map(|(range, text)| (range.to_offset(&snapshot), text.clone()))
+ .collect();
+
+ assert_eq!(edits, vec![(5..5, " world".into())]);
+ });
+}
+
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
@@ -1,10 +1,11 @@
-use crate::udiff::DiffLine;
use anyhow::{Context as _, Result};
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, fmt::Write as _, mem, path::Path, sync::Arc};
use telemetry_events::EditPredictionRating;
-pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]";
+pub use zeta_prompt::udiff::{
+ CURSOR_POSITION_MARKER, encode_cursor_in_patch, extract_cursor_from_patch,
+};
pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>";
/// Maximum cursor file size to capture (64KB).
@@ -12,64 +13,6 @@ pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>";
/// falling back to git-based loading.
pub const MAX_CURSOR_FILE_SIZE: usize = 64 * 1024;
-/// Encodes a cursor position into a diff patch by adding a comment line with a caret
-/// pointing to the cursor column.
-///
-/// The cursor offset is relative to the start of the new text content (additions and context lines).
-/// Returns the patch with cursor marker comment lines inserted after the relevant addition line.
-pub fn encode_cursor_in_patch(patch: &str, cursor_offset: Option<usize>) -> String {
- let Some(cursor_offset) = cursor_offset else {
- return patch.to_string();
- };
-
- let mut result = String::new();
- let mut line_start_offset = 0usize;
-
- for line in patch.lines() {
- if matches!(
- DiffLine::parse(line),
- DiffLine::Garbage(content)
- if content.starts_with('#') && content.contains(CURSOR_POSITION_MARKER)
- ) {
- continue;
- }
-
- if !result.is_empty() {
- result.push('\n');
- }
- result.push_str(line);
-
- match DiffLine::parse(line) {
- DiffLine::Addition(content) => {
- let line_end_offset = line_start_offset + content.len();
-
- if cursor_offset >= line_start_offset && cursor_offset <= line_end_offset {
- let cursor_column = cursor_offset - line_start_offset;
-
- result.push('\n');
- result.push('#');
- for _ in 0..cursor_column {
- result.push(' ');
- }
- write!(result, "^{}", CURSOR_POSITION_MARKER).unwrap();
- }
-
- line_start_offset = line_end_offset + 1;
- }
- DiffLine::Context(content) => {
- line_start_offset += content.len() + 1;
- }
- _ => {}
- }
- }
-
- if patch.ends_with('\n') {
- result.push('\n');
- }
-
- result
-}
-
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct ExampleSpec {
#[serde(default)]
@@ -509,53 +452,7 @@ impl ExampleSpec {
pub fn expected_patches_with_cursor_positions(&self) -> Vec<(String, Option<usize>)> {
self.expected_patches
.iter()
- .map(|patch| {
- let mut clean_patch = String::new();
- let mut cursor_offset: Option<usize> = None;
- let mut line_start_offset = 0usize;
- let mut prev_line_start_offset = 0usize;
-
- for line in patch.lines() {
- let diff_line = DiffLine::parse(line);
-
- match &diff_line {
- DiffLine::Garbage(content)
- if content.starts_with('#')
- && content.contains(CURSOR_POSITION_MARKER) =>
- {
- let caret_column = if let Some(caret_pos) = content.find('^') {
- caret_pos
- } else if let Some(_) = content.find('<') {
- 0
- } else {
- continue;
- };
- let cursor_column = caret_column.saturating_sub('#'.len_utf8());
- cursor_offset = Some(prev_line_start_offset + cursor_column);
- }
- _ => {
- if !clean_patch.is_empty() {
- clean_patch.push('\n');
- }
- clean_patch.push_str(line);
-
- match diff_line {
- DiffLine::Addition(content) | DiffLine::Context(content) => {
- prev_line_start_offset = line_start_offset;
- line_start_offset += content.len() + 1;
- }
- _ => {}
- }
- }
- }
- }
-
- if patch.ends_with('\n') && !clean_patch.is_empty() {
- clean_patch.push('\n');
- }
-
- (clean_patch, cursor_offset)
- })
+ .map(|patch| extract_cursor_from_patch(patch))
.collect()
}
@@ -24,8 +24,9 @@ use zeta_prompt::{ParsedOutput, ZetaPromptInput};
use std::{env, ops::Range, path::Path, sync::Arc};
use zeta_prompt::{
- CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
- prompt_input_contains_special_tokens, stop_tokens_for_format,
+ ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
+ parsed_output_from_editable_region, prompt_input_contains_special_tokens,
+ stop_tokens_for_format,
zeta1::{self, EDITABLE_REGION_END_MARKER},
};
@@ -181,6 +182,7 @@ pub fn request_prediction_with_zeta(
let parsed_output = output_text.map(|text| ParsedOutput {
new_editable_region: text,
range_in_excerpt: editable_range_in_excerpt,
+ cursor_offset_in_new_editable_region: None,
});
(request_id, parsed_output, None, None)
@@ -283,10 +285,10 @@ pub fn request_prediction_with_zeta(
let request_id = EditPredictionId(response.request_id.into());
let output_text = Some(response.output).filter(|s| !s.is_empty());
let model_version = response.model_version;
- let parsed_output = ParsedOutput {
- new_editable_region: output_text.unwrap_or_default(),
- range_in_excerpt: response.editable_range,
- };
+ let parsed_output = parsed_output_from_editable_region(
+ response.editable_range,
+ output_text.unwrap_or_default(),
+ );
Some((request_id, Some(parsed_output), model_version, usage))
})
@@ -299,6 +301,7 @@ pub fn request_prediction_with_zeta(
let Some(ParsedOutput {
new_editable_region: mut output_text,
range_in_excerpt: editable_range_in_excerpt,
+ cursor_offset_in_new_editable_region: cursor_offset_in_output,
}) = output
else {
return Ok((Some((request_id, None)), None));
@@ -312,13 +315,6 @@ pub fn request_prediction_with_zeta(
.text_for_range(editable_range_in_buffer.clone())
.collect::<String>();
- // Client-side cursor marker processing (applies to both raw and v3 responses)
- let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
- if let Some(offset) = cursor_offset_in_output {
- log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
- output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
- }
-
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionFinished(
@@ -5,8 +5,7 @@ use crate::{
repair,
};
use anyhow::{Context as _, Result};
-use edit_prediction::example_spec::encode_cursor_in_patch;
-use zeta_prompt::{CURSOR_MARKER, ZetaFormat, parse_zeta2_model_output};
+use zeta_prompt::{ZetaFormat, parse_zeta2_model_output, parsed_output_to_patch};
pub fn run_parse_output(example: &mut Example) -> Result<()> {
example
@@ -65,46 +64,18 @@ fn parse_zeta2_output(
.context("prompt_inputs required")?;
let parsed = parse_zeta2_model_output(actual_output, format, prompt_inputs)?;
- let range_in_excerpt = parsed.range_in_excerpt;
-
+ let range_in_excerpt = parsed.range_in_excerpt.clone();
let excerpt = prompt_inputs.cursor_excerpt.as_ref();
- let old_text = excerpt[range_in_excerpt.clone()].to_string();
- let mut new_text = parsed.new_editable_region;
-
- let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
- new_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
- Some(offset)
- } else {
- None
- };
+ let editable_region_offset = range_in_excerpt.start;
+ let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count();
- // Normalize trailing newlines for diff generation
- let mut old_text_normalized = old_text;
+ let mut new_text = parsed.new_editable_region.clone();
if !new_text.is_empty() && !new_text.ends_with('\n') {
new_text.push('\n');
}
- if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
- old_text_normalized.push('\n');
- }
-
- let editable_region_offset = range_in_excerpt.start;
- let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count();
- let editable_region_lines = old_text_normalized.lines().count() as u32;
-
- let diff = language::unified_diff_with_context(
- &old_text_normalized,
- &new_text,
- editable_region_start_line as u32,
- editable_region_start_line as u32,
- editable_region_lines,
- );
-
- let formatted_diff = format!(
- "--- a/{path}\n+++ b/{path}\n{diff}",
- path = example.spec.cursor_path.to_string_lossy(),
- );
- let formatted_diff = encode_cursor_in_patch(&formatted_diff, cursor_offset);
+ let cursor_offset = parsed.cursor_offset_in_new_editable_region;
+ let formatted_diff = parsed_output_to_patch(prompt_inputs, parsed)?;
let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
ActualCursor::from_editable_region(
@@ -36,8 +36,16 @@ pub struct FakeGitRepository {
pub(crate) is_trusted: Arc<AtomicBool>,
}
+#[derive(Debug, Clone)]
+pub struct FakeCommitSnapshot {
+ pub head_contents: HashMap<RepoPath, String>,
+ pub index_contents: HashMap<RepoPath, String>,
+ pub sha: String,
+}
+
#[derive(Debug, Clone)]
pub struct FakeGitRepositoryState {
+ pub commit_history: Vec<FakeCommitSnapshot>,
pub event_emitter: smol::channel::Sender<PathBuf>,
pub unmerged_paths: HashMap<RepoPath, UnmergedStatus>,
pub head_contents: HashMap<RepoPath, String>,
@@ -74,6 +82,7 @@ impl FakeGitRepositoryState {
oids: Default::default(),
remotes: HashMap::default(),
graph_commits: Vec::new(),
+ commit_history: Vec::new(),
stash_entries: Default::default(),
}
}
@@ -217,11 +226,52 @@ impl GitRepository for FakeGitRepository {
fn reset(
&self,
- _commit: String,
- _mode: ResetMode,
+ commit: String,
+ mode: ResetMode,
_env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>> {
- unimplemented!()
+ self.with_state_async(true, move |state| {
+ let pop_count = if commit == "HEAD~" || commit == "HEAD^" {
+ 1
+ } else if let Some(suffix) = commit.strip_prefix("HEAD~") {
+ suffix
+ .parse::<usize>()
+ .with_context(|| format!("Invalid HEAD~ offset: {commit}"))?
+ } else {
+ match state
+ .commit_history
+ .iter()
+ .rposition(|entry| entry.sha == commit)
+ {
+ Some(index) => state.commit_history.len() - index,
+ None => anyhow::bail!("Unknown commit ref: {commit}"),
+ }
+ };
+
+ if pop_count == 0 || pop_count > state.commit_history.len() {
+ anyhow::bail!(
+ "Cannot reset {pop_count} commit(s): only {} in history",
+ state.commit_history.len()
+ );
+ }
+
+ let target_index = state.commit_history.len() - pop_count;
+ let snapshot = state.commit_history[target_index].clone();
+ state.commit_history.truncate(target_index);
+
+ match mode {
+ ResetMode::Soft => {
+ state.head_contents = snapshot.head_contents;
+ }
+ ResetMode::Mixed => {
+ state.head_contents = snapshot.head_contents;
+ state.index_contents = state.head_contents.clone();
+ }
+ }
+
+ state.refs.insert("HEAD".into(), snapshot.sha);
+ Ok(())
+ })
}
fn checkout_files(
@@ -490,7 +540,7 @@ impl GitRepository for FakeGitRepository {
fn create_worktree(
&self,
- branch_name: String,
+ branch_name: Option<String>,
path: PathBuf,
from_commit: Option<String>,
) -> BoxFuture<'_, Result<()>> {
@@ -505,8 +555,10 @@ impl GitRepository for FakeGitRepository {
if let Some(message) = &state.simulated_create_worktree_error {
anyhow::bail!("{message}");
}
- if state.branches.contains(&branch_name) {
- bail!("a branch named '{}' already exists", branch_name);
+ if let Some(ref name) = branch_name {
+ if state.branches.contains(name) {
+ bail!("a branch named '{}' already exists", name);
+ }
}
Ok(())
})??;
@@ -515,13 +567,22 @@ impl GitRepository for FakeGitRepository {
fs.create_dir(&path).await?;
// Create .git/worktrees/<name>/ directory with HEAD, commondir, gitdir.
- let ref_name = format!("refs/heads/{branch_name}");
- let worktrees_entry_dir = common_dir_path.join("worktrees").join(&branch_name);
+ let worktree_entry_name = branch_name
+ .as_deref()
+ .unwrap_or_else(|| path.file_name().unwrap().to_str().unwrap());
+ let worktrees_entry_dir = common_dir_path.join("worktrees").join(worktree_entry_name);
fs.create_dir(&worktrees_entry_dir).await?;
+ let sha = from_commit.unwrap_or_else(|| "fake-sha".to_string());
+ let head_content = if let Some(ref branch_name) = branch_name {
+ let ref_name = format!("refs/heads/{branch_name}");
+ format!("ref: {ref_name}")
+ } else {
+ sha.clone()
+ };
fs.write_file_internal(
worktrees_entry_dir.join("HEAD"),
- format!("ref: {ref_name}").into_bytes(),
+ head_content.into_bytes(),
false,
)?;
fs.write_file_internal(
@@ -544,10 +605,12 @@ impl GitRepository for FakeGitRepository {
)?;
// Update git state: add ref and branch.
- let sha = from_commit.unwrap_or_else(|| "fake-sha".to_string());
fs.with_git_state(&dot_git_path, true, move |state| {
- state.refs.insert(ref_name, sha);
- state.branches.insert(branch_name);
+ if let Some(branch_name) = branch_name {
+ let ref_name = format!("refs/heads/{branch_name}");
+ state.refs.insert(ref_name, sha);
+ state.branches.insert(branch_name);
+ }
Ok::<(), anyhow::Error>(())
})??;
Ok(())
@@ -822,11 +885,30 @@ impl GitRepository for FakeGitRepository {
&self,
_message: gpui::SharedString,
_name_and_email: Option<(gpui::SharedString, gpui::SharedString)>,
- _options: CommitOptions,
+ options: CommitOptions,
_askpass: AskPassDelegate,
_env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>> {
- async { Ok(()) }.boxed()
+ self.with_state_async(true, move |state| {
+ if !options.allow_empty && !options.amend && state.index_contents == state.head_contents
+ {
+ anyhow::bail!("nothing to commit (use allow_empty to create an empty commit)");
+ }
+
+ let old_sha = state.refs.get("HEAD").cloned().unwrap_or_default();
+ state.commit_history.push(FakeCommitSnapshot {
+ head_contents: state.head_contents.clone(),
+ index_contents: state.index_contents.clone(),
+ sha: old_sha,
+ });
+
+ state.head_contents = state.index_contents.clone();
+
+ let new_sha = format!("fake-commit-{}", state.commit_history.len());
+ state.refs.insert("HEAD".into(), new_sha);
+
+ Ok(())
+ })
}
fn run_hook(
@@ -1210,6 +1292,24 @@ impl GitRepository for FakeGitRepository {
anyhow::bail!("commit_data_reader not supported for FakeGitRepository")
}
+ fn update_ref(&self, ref_name: String, commit: String) -> BoxFuture<'_, Result<()>> {
+ self.with_state_async(true, move |state| {
+ state.refs.insert(ref_name, commit);
+ Ok(())
+ })
+ }
+
+ fn delete_ref(&self, ref_name: String) -> BoxFuture<'_, Result<()>> {
+ self.with_state_async(true, move |state| {
+ state.refs.remove(&ref_name);
+ Ok(())
+ })
+ }
+
+ fn repair_worktrees(&self) -> BoxFuture<'_, Result<()>> {
+ async { Ok(()) }.boxed()
+ }
+
fn set_trusted(&self, trusted: bool) {
self.is_trusted
.store(trusted, std::sync::atomic::Ordering::Release);
@@ -24,7 +24,7 @@ async fn test_fake_worktree_lifecycle(cx: &mut TestAppContext) {
// Create a worktree
let worktree_1_dir = worktrees_dir.join("feature-branch");
repo.create_worktree(
- "feature-branch".to_string(),
+ Some("feature-branch".to_string()),
worktree_1_dir.clone(),
Some("abc123".to_string()),
)
@@ -47,9 +47,13 @@ async fn test_fake_worktree_lifecycle(cx: &mut TestAppContext) {
// Create a second worktree (without explicit commit)
let worktree_2_dir = worktrees_dir.join("bugfix-branch");
- repo.create_worktree("bugfix-branch".to_string(), worktree_2_dir.clone(), None)
- .await
- .unwrap();
+ repo.create_worktree(
+ Some("bugfix-branch".to_string()),
+ worktree_2_dir.clone(),
+ None,
+ )
+ .await
+ .unwrap();
let worktrees = repo.worktrees().await.unwrap();
assert_eq!(worktrees.len(), 3);
@@ -329,6 +329,7 @@ impl Upstream {
pub struct CommitOptions {
pub amend: bool,
pub signoff: bool,
+ pub allow_empty: bool,
}
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
@@ -715,7 +716,7 @@ pub trait GitRepository: Send + Sync {
fn create_worktree(
&self,
- branch_name: String,
+ branch_name: Option<String>,
path: PathBuf,
from_commit: Option<String>,
) -> BoxFuture<'_, Result<()>>;
@@ -916,6 +917,12 @@ pub trait GitRepository: Send + Sync {
fn commit_data_reader(&self) -> Result<CommitDataReader>;
+ fn update_ref(&self, ref_name: String, commit: String) -> BoxFuture<'_, Result<()>>;
+
+ fn delete_ref(&self, ref_name: String) -> BoxFuture<'_, Result<()>>;
+
+ fn repair_worktrees(&self) -> BoxFuture<'_, Result<()>>;
+
fn set_trusted(&self, trusted: bool);
fn is_trusted(&self) -> bool;
}
@@ -1660,19 +1667,20 @@ impl GitRepository for RealGitRepository {
fn create_worktree(
&self,
- branch_name: String,
+ branch_name: Option<String>,
path: PathBuf,
from_commit: Option<String>,
) -> BoxFuture<'_, Result<()>> {
let git_binary = self.git_binary();
- let mut args = vec![
- OsString::from("worktree"),
- OsString::from("add"),
- OsString::from("-b"),
- OsString::from(branch_name.as_str()),
- OsString::from("--"),
- OsString::from(path.as_os_str()),
- ];
+ let mut args = vec![OsString::from("worktree"), OsString::from("add")];
+ if let Some(branch_name) = &branch_name {
+ args.push(OsString::from("-b"));
+ args.push(OsString::from(branch_name.as_str()));
+ } else {
+ args.push(OsString::from("--detach"));
+ }
+ args.push(OsString::from("--"));
+ args.push(OsString::from(path.as_os_str()));
if let Some(from_commit) = from_commit {
args.push(OsString::from(from_commit));
} else {
@@ -2165,6 +2173,10 @@ impl GitRepository for RealGitRepository {
cmd.arg("--signoff");
}
+ if options.allow_empty {
+ cmd.arg("--allow-empty");
+ }
+
if let Some((name, email)) = name_and_email {
cmd.arg("--author").arg(&format!("{name} <{email}>"));
}
@@ -2176,6 +2188,39 @@ impl GitRepository for RealGitRepository {
.boxed()
}
+ fn update_ref(&self, ref_name: String, commit: String) -> BoxFuture<'_, Result<()>> {
+ let git_binary = self.git_binary();
+ self.executor
+ .spawn(async move {
+ let args: Vec<OsString> = vec!["update-ref".into(), ref_name.into(), commit.into()];
+ git_binary?.run(&args).await?;
+ Ok(())
+ })
+ .boxed()
+ }
+
+ fn delete_ref(&self, ref_name: String) -> BoxFuture<'_, Result<()>> {
+ let git_binary = self.git_binary();
+ self.executor
+ .spawn(async move {
+ let args: Vec<OsString> = vec!["update-ref".into(), "-d".into(), ref_name.into()];
+ git_binary?.run(&args).await?;
+ Ok(())
+ })
+ .boxed()
+ }
+
+ fn repair_worktrees(&self) -> BoxFuture<'_, Result<()>> {
+ let git_binary = self.git_binary();
+ self.executor
+ .spawn(async move {
+ let args: Vec<OsString> = vec!["worktree".into(), "repair".into()];
+ git_binary?.run(&args).await?;
+ Ok(())
+ })
+ .boxed()
+ }
+
fn push(
&self,
branch_name: String,
@@ -4009,7 +4054,7 @@ mod tests {
// Create a new worktree
repo.create_worktree(
- "test-branch".to_string(),
+ Some("test-branch".to_string()),
worktree_path.clone(),
Some("HEAD".to_string()),
)
@@ -4068,7 +4113,7 @@ mod tests {
// Create a worktree
let worktree_path = worktrees_dir.join("worktree-to-remove");
repo.create_worktree(
- "to-remove".to_string(),
+ Some("to-remove".to_string()),
worktree_path.clone(),
Some("HEAD".to_string()),
)
@@ -4092,7 +4137,7 @@ mod tests {
// Create a worktree
let worktree_path = worktrees_dir.join("dirty-wt");
repo.create_worktree(
- "dirty-wt".to_string(),
+ Some("dirty-wt".to_string()),
worktree_path.clone(),
Some("HEAD".to_string()),
)
@@ -4162,7 +4207,7 @@ mod tests {
// Create a worktree
let old_path = worktrees_dir.join("old-worktree-name");
repo.create_worktree(
- "old-name".to_string(),
+ Some("old-name".to_string()),
old_path.clone(),
Some("HEAD".to_string()),
)
@@ -2394,9 +2394,8 @@ impl GitGraph {
let local_y = position_y - canvas_bounds.origin.y;
if local_y >= px(0.) && local_y < canvas_bounds.size.height {
- let row_in_viewport = (local_y / self.row_height).floor() as usize;
- let scroll_rows = (scroll_offset_y / self.row_height).floor() as usize;
- let absolute_row = scroll_rows + row_in_viewport;
+ let absolute_y = local_y + scroll_offset_y;
+ let absolute_row = (absolute_y / self.row_height).floor() as usize;
if absolute_row < self.graph_data.commits.len() {
return Some(absolute_row);
@@ -4006,4 +4005,76 @@ mod tests {
});
assert_eq!(reloaded_shas, vec![updated_head, updated_stash]);
}
+
+ #[gpui::test]
+ async fn test_git_graph_row_at_position_rounding(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ Path::new("/project"),
+ serde_json::json!({
+ ".git": {},
+ "file.txt": "content",
+ }),
+ )
+ .await;
+
+ let mut rng = StdRng::seed_from_u64(42);
+ let commits = generate_random_commit_dag(&mut rng, 10, false);
+ fs.set_graph_commits(Path::new("/project/.git"), commits.clone());
+
+ let project = Project::test(fs.clone(), [Path::new("/project")], cx).await;
+ cx.run_until_parked();
+
+ let repository = project.read_with(cx, |project, cx| {
+ project
+ .active_repository(cx)
+ .expect("should have a repository")
+ });
+
+ let (multi_workspace, cx) = cx.add_window_view(|window, cx| {
+ workspace::MultiWorkspace::test_new(project.clone(), window, cx)
+ });
+
+ let workspace_weak =
+ multi_workspace.read_with(&*cx, |multi, _| multi.workspace().downgrade());
+
+ let git_graph = cx.new_window_entity(|window, cx| {
+ GitGraph::new(
+ repository.read(cx).id,
+ project.read(cx).git_store().clone(),
+ workspace_weak,
+ window,
+ cx,
+ )
+ });
+ cx.run_until_parked();
+
+ git_graph.update(cx, |graph, cx| {
+ assert!(
+ graph.graph_data.commits.len() >= 10,
+ "graph should load dummy commits"
+ );
+
+ graph.row_height = px(20.0);
+ let origin_y = px(100.0);
+ graph.graph_canvas_bounds.set(Some(Bounds {
+ origin: point(px(0.0), origin_y),
+ size: gpui::size(px(100.0), px(1000.0)),
+ }));
+
+ graph.table_interaction_state.update(cx, |state, _| {
+ state.set_scroll_offset(point(px(0.0), px(-15.0)))
+ });
+ let pos_y = origin_y + px(10.0);
+ let absolute_calc_row = graph.row_at_position(pos_y, cx);
+
+ assert_eq!(
+ absolute_calc_row,
+ Some(1),
+ "Row calculation should yield absolute row exactly"
+ );
+ });
+ }
}
@@ -453,6 +453,7 @@ impl CommitModal {
CommitOptions {
amend: is_amend_pending,
signoff: is_signoff_enabled,
+ allow_empty: false,
},
window,
cx,
@@ -2155,6 +2155,7 @@ impl GitPanel {
CommitOptions {
amend: false,
signoff: self.signoff_enabled,
+ allow_empty: false,
},
window,
cx,
@@ -2195,6 +2196,7 @@ impl GitPanel {
CommitOptions {
amend: true,
signoff: self.signoff_enabled,
+ allow_empty: false,
},
window,
cx,
@@ -4454,7 +4456,11 @@ impl GitPanel {
git_panel
.update(cx, |git_panel, cx| {
git_panel.commit_changes(
- CommitOptions { amend, signoff },
+ CommitOptions {
+ amend,
+ signoff,
+ allow_empty: false,
+ },
window,
cx,
);
@@ -462,6 +462,13 @@ impl ListState {
let current_offset = self.logical_scroll_top();
let state = &mut *self.0.borrow_mut();
+
+ if distance < px(0.) {
+ if let FollowState::Tail { is_following } = &mut state.follow_state {
+ *is_following = false;
+ }
+ }
+
let mut cursor = state.items.cursor::<ListItemSummary>(());
cursor.seek(&Count(current_offset.item_ix), Bias::Right);
@@ -536,6 +543,12 @@ impl ListState {
scroll_top.offset_in_item = px(0.);
}
+ if scroll_top.item_ix < item_count {
+ if let FollowState::Tail { is_following } = &mut state.follow_state {
+ *is_following = false;
+ }
+ }
+
state.logical_scroll_top = Some(scroll_top);
}
@@ -125,6 +125,7 @@ pub struct FakeLanguageModel {
>,
forbid_requests: AtomicBool,
supports_thinking: AtomicBool,
+ supports_streaming_tools: AtomicBool,
}
impl Default for FakeLanguageModel {
@@ -137,6 +138,7 @@ impl Default for FakeLanguageModel {
current_completion_txs: Mutex::new(Vec::new()),
forbid_requests: AtomicBool::new(false),
supports_thinking: AtomicBool::new(false),
+ supports_streaming_tools: AtomicBool::new(false),
}
}
}
@@ -169,6 +171,10 @@ impl FakeLanguageModel {
self.supports_thinking.store(supports, SeqCst);
}
+ pub fn set_supports_streaming_tools(&self, supports: bool) {
+ self.supports_streaming_tools.store(supports, SeqCst);
+ }
+
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
self.current_completion_txs
.lock()
@@ -282,6 +288,10 @@ impl LanguageModel for FakeLanguageModel {
self.supports_thinking.load(SeqCst)
}
+ fn supports_streaming_tools(&self) -> bool {
+ self.supports_streaming_tools.load(SeqCst)
+ }
+
fn telemetry_id(&self) -> String {
"fake".to_string()
}
@@ -121,6 +121,9 @@ pub trait PickerDelegate: Sized + 'static {
) -> bool {
true
}
+ fn select_on_hover(&self) -> bool {
+ true
+ }
// Allows binding some optional effect to when the selection changes.
fn selected_index_changed(
@@ -788,12 +791,14 @@ impl<D: PickerDelegate> Picker<D> {
this.handle_click(ix, event.modifiers.platform, window, cx)
}),
)
- .on_hover(cx.listener(move |this, hovered: &bool, window, cx| {
- if *hovered {
- this.set_selected_index(ix, None, false, window, cx);
- cx.notify();
- }
- }))
+ .when(self.delegate.select_on_hover(), |this| {
+ this.on_hover(cx.listener(move |this, hovered: &bool, window, cx| {
+ if *hovered {
+ this.set_selected_index(ix, None, false, window, cx);
+ cx.notify();
+ }
+ }))
+ })
.children(self.delegate.render_match(
ix,
ix == self.delegate.selected_index(),
@@ -329,6 +329,12 @@ pub struct GraphDataResponse<'a> {
pub error: Option<SharedString>,
}
+#[derive(Clone, Debug)]
+enum CreateWorktreeStartPoint {
+ Detached,
+ Branched { name: String },
+}
+
pub struct Repository {
this: WeakEntity<Self>,
snapshot: RepositorySnapshot,
@@ -588,6 +594,7 @@ impl GitStore {
client.add_entity_request_handler(Self::handle_create_worktree);
client.add_entity_request_handler(Self::handle_remove_worktree);
client.add_entity_request_handler(Self::handle_rename_worktree);
+ client.add_entity_request_handler(Self::handle_get_head_sha);
}
pub fn is_local(&self) -> bool {
@@ -2340,6 +2347,7 @@ impl GitStore {
CommitOptions {
amend: options.amend,
signoff: options.signoff,
+ allow_empty: options.allow_empty,
},
askpass,
cx,
@@ -2406,12 +2414,18 @@ impl GitStore {
let repository_id = RepositoryId::from_proto(envelope.payload.repository_id);
let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?;
let directory = PathBuf::from(envelope.payload.directory);
- let name = envelope.payload.name;
+ let start_point = if envelope.payload.name.is_empty() {
+ CreateWorktreeStartPoint::Detached
+ } else {
+ CreateWorktreeStartPoint::Branched {
+ name: envelope.payload.name,
+ }
+ };
let commit = envelope.payload.commit;
repository_handle
.update(&mut cx, |repository_handle, _| {
- repository_handle.create_worktree(name, directory, commit)
+ repository_handle.create_worktree_with_start_point(start_point, directory, commit)
})
.await??;
@@ -2456,6 +2470,21 @@ impl GitStore {
Ok(proto::Ack {})
}
+ async fn handle_get_head_sha(
+ this: Entity<Self>,
+ envelope: TypedEnvelope<proto::GitGetHeadSha>,
+ mut cx: AsyncApp,
+ ) -> Result<proto::GitGetHeadShaResponse> {
+ let repository_id = RepositoryId::from_proto(envelope.payload.repository_id);
+ let repository_handle = Self::repository_for_request(&this, repository_id, &mut cx)?;
+
+ let head_sha = repository_handle
+ .update(&mut cx, |repository_handle, _| repository_handle.head_sha())
+ .await??;
+
+ Ok(proto::GitGetHeadShaResponse { sha: head_sha })
+ }
+
async fn handle_get_branches(
this: Entity<Self>,
envelope: TypedEnvelope<proto::GitGetBranches>,
@@ -5493,6 +5522,7 @@ impl Repository {
options: Some(proto::commit::CommitOptions {
amend: options.amend,
signoff: options.signoff,
+ allow_empty: options.allow_empty,
}),
askpass_id,
})
@@ -5974,36 +6004,174 @@ impl Repository {
})
}
+ fn create_worktree_with_start_point(
+ &mut self,
+ start_point: CreateWorktreeStartPoint,
+ path: PathBuf,
+ commit: Option<String>,
+ ) -> oneshot::Receiver<Result<()>> {
+ if matches!(
+ &start_point,
+ CreateWorktreeStartPoint::Branched { name } if name.is_empty()
+ ) {
+ let (sender, receiver) = oneshot::channel();
+ sender
+ .send(Err(anyhow!("branch name cannot be empty")))
+ .ok();
+ return receiver;
+ }
+
+ let id = self.id;
+ let message = match &start_point {
+ CreateWorktreeStartPoint::Detached => "git worktree add (detached)".into(),
+ CreateWorktreeStartPoint::Branched { name } => {
+ format!("git worktree add: {name}").into()
+ }
+ };
+
+ self.send_job(Some(message), move |repo, _cx| async move {
+ let branch_name = match start_point {
+ CreateWorktreeStartPoint::Detached => None,
+ CreateWorktreeStartPoint::Branched { name } => Some(name),
+ };
+ let remote_name = branch_name.clone().unwrap_or_default();
+
+ match repo {
+ RepositoryState::Local(LocalRepositoryState { backend, .. }) => {
+ backend.create_worktree(branch_name, path, commit).await
+ }
+ RepositoryState::Remote(RemoteRepositoryState { project_id, client }) => {
+ client
+ .request(proto::GitCreateWorktree {
+ project_id: project_id.0,
+ repository_id: id.to_proto(),
+ name: remote_name,
+ directory: path.to_string_lossy().to_string(),
+ commit,
+ })
+ .await?;
+
+ Ok(())
+ }
+ }
+ })
+ }
+
pub fn create_worktree(
&mut self,
branch_name: String,
path: PathBuf,
commit: Option<String>,
) -> oneshot::Receiver<Result<()>> {
+ self.create_worktree_with_start_point(
+ CreateWorktreeStartPoint::Branched { name: branch_name },
+ path,
+ commit,
+ )
+ }
+
+ pub fn create_worktree_detached(
+ &mut self,
+ path: PathBuf,
+ commit: String,
+ ) -> oneshot::Receiver<Result<()>> {
+ self.create_worktree_with_start_point(
+ CreateWorktreeStartPoint::Detached,
+ path,
+ Some(commit),
+ )
+ }
+
+ pub fn head_sha(&mut self) -> oneshot::Receiver<Result<Option<String>>> {
let id = self.id;
- self.send_job(
- Some(format!("git worktree add: {}", branch_name).into()),
- move |repo, _cx| async move {
- match repo {
- RepositoryState::Local(LocalRepositoryState { backend, .. }) => {
- backend.create_worktree(branch_name, path, commit).await
- }
- RepositoryState::Remote(RemoteRepositoryState { project_id, client }) => {
- client
- .request(proto::GitCreateWorktree {
- project_id: project_id.0,
- repository_id: id.to_proto(),
- name: branch_name,
- directory: path.to_string_lossy().to_string(),
- commit,
- })
- .await?;
+ self.send_job(None, move |repo, _cx| async move {
+ match repo {
+ RepositoryState::Local(LocalRepositoryState { backend, .. }) => {
+ Ok(backend.head_sha().await)
+ }
+ RepositoryState::Remote(RemoteRepositoryState { project_id, client }) => {
+ let response = client
+ .request(proto::GitGetHeadSha {
+ project_id: project_id.0,
+ repository_id: id.to_proto(),
+ })
+ .await?;
- Ok(())
- }
+ Ok(response.sha)
}
- },
- )
+ }
+ })
+ }
+
+ pub fn update_ref(
+ &mut self,
+ ref_name: String,
+ commit: String,
+ ) -> oneshot::Receiver<Result<()>> {
+ self.send_job(None, move |repo, _cx| async move {
+ match repo {
+ RepositoryState::Local(LocalRepositoryState { backend, .. }) => {
+ backend.update_ref(ref_name, commit).await
+ }
+ RepositoryState::Remote(_) => {
+ anyhow::bail!("update_ref is not supported for remote repositories")
+ }
+ }
+ })
+ }
+
+ pub fn delete_ref(&mut self, ref_name: String) -> oneshot::Receiver<Result<()>> {
+ self.send_job(None, move |repo, _cx| async move {
+ match repo {
+ RepositoryState::Local(LocalRepositoryState { backend, .. }) => {
+ backend.delete_ref(ref_name).await
+ }
+ RepositoryState::Remote(_) => {
+ anyhow::bail!("delete_ref is not supported for remote repositories")
+ }
+ }
+ })
+ }
+
+ pub fn resolve_commit(&mut self, sha: String) -> oneshot::Receiver<Result<bool>> {
+ self.send_job(None, move |repo, _cx| async move {
+ match repo {
+ RepositoryState::Local(LocalRepositoryState { backend, .. }) => {
+ let results = backend.revparse_batch(vec![sha]).await?;
+ Ok(results.into_iter().next().flatten().is_some())
+ }
+ RepositoryState::Remote(_) => {
+ anyhow::bail!("resolve_commit is not supported for remote repositories")
+ }
+ }
+ })
+ }
+
+ pub fn repair_worktrees(&mut self) -> oneshot::Receiver<Result<()>> {
+ self.send_job(None, move |repo, _cx| async move {
+ match repo {
+ RepositoryState::Local(LocalRepositoryState { backend, .. }) => {
+ backend.repair_worktrees().await
+ }
+ RepositoryState::Remote(_) => {
+ anyhow::bail!("repair_worktrees is not supported for remote repositories")
+ }
+ }
+ })
+ }
+
+ pub fn commit_exists(&mut self, sha: String) -> oneshot::Receiver<Result<bool>> {
+ self.send_job(None, move |repo, _cx| async move {
+ match repo {
+ RepositoryState::Local(LocalRepositoryState { backend, .. }) => {
+ let results = backend.revparse_batch(vec![sha]).await?;
+ Ok(results.into_iter().next().flatten().is_some())
+ }
+ RepositoryState::Remote(_) => {
+ anyhow::bail!("commit_exists is not supported for remote repositories")
+ }
+ }
+ })
}
pub fn remove_worktree(&mut self, path: PathBuf, force: bool) -> oneshot::Receiver<Result<()>> {
@@ -403,6 +403,7 @@ message Commit {
message CommitOptions {
bool amend = 1;
bool signoff = 2;
+ bool allow_empty = 3;
}
}
@@ -567,6 +568,15 @@ message GitGetWorktrees {
uint64 repository_id = 2;
}
+message GitGetHeadSha {
+ uint64 project_id = 1;
+ uint64 repository_id = 2;
+}
+
+message GitGetHeadShaResponse {
+ optional string sha = 1;
+}
+
message GitWorktreesResponse {
repeated Worktree worktrees = 1;
}
@@ -474,7 +474,9 @@ message Envelope {
GitCompareCheckpoints git_compare_checkpoints = 436;
GitCompareCheckpointsResponse git_compare_checkpoints_response = 437;
GitDiffCheckpoints git_diff_checkpoints = 438;
- GitDiffCheckpointsResponse git_diff_checkpoints_response = 439; // current max
+ GitDiffCheckpointsResponse git_diff_checkpoints_response = 439;
+ GitGetHeadSha git_get_head_sha = 440;
+ GitGetHeadShaResponse git_get_head_sha_response = 441; // current max
}
reserved 87 to 88;
@@ -351,6 +351,8 @@ messages!(
(NewExternalAgentVersionAvailable, Background),
(RemoteStarted, Background),
(GitGetWorktrees, Background),
+ (GitGetHeadSha, Background),
+ (GitGetHeadShaResponse, Background),
(GitWorktreesResponse, Background),
(GitCreateWorktree, Background),
(GitRemoveWorktree, Background),
@@ -558,6 +560,7 @@ request_messages!(
(GetContextServerCommand, ContextServerCommand),
(RemoteStarted, Ack),
(GitGetWorktrees, GitWorktreesResponse),
+ (GitGetHeadSha, GitGetHeadShaResponse),
(GitCreateWorktree, Ack),
(GitRemoveWorktree, Ack),
(GitRenameWorktree, Ack),
@@ -749,6 +752,7 @@ entity_messages!(
ExternalAgentLoadingStatusUpdated,
NewExternalAgentVersionAvailable,
GitGetWorktrees,
+ GitGetHeadSha,
GitCreateWorktree,
GitRemoveWorktree,
GitRenameWorktree,
@@ -357,7 +357,6 @@ pub fn init(cx: &mut App) {
.update(cx, |multi_workspace, window, cx| {
let sibling_workspace_ids: HashSet<WorkspaceId> = multi_workspace
.workspaces()
- .iter()
.filter_map(|ws| ws.read(cx).database_id())
.collect();
@@ -1113,7 +1112,6 @@ impl PickerDelegate for RecentProjectsDelegate {
.update(cx, |multi_workspace, window, cx| {
let workspace = multi_workspace
.workspaces()
- .iter()
.find(|ws| ws.read(cx).database_id() == Some(workspace_id))
.cloned();
if let Some(workspace) = workspace {
@@ -1932,7 +1930,6 @@ impl RecentProjectsDelegate {
.update(cx, |multi_workspace, window, cx| {
let workspace = multi_workspace
.workspaces()
- .iter()
.find(|ws| ws.read(cx).database_id() == Some(workspace_id))
.cloned();
if let Some(workspace) = workspace {
@@ -2055,6 +2052,11 @@ mod tests {
assert_eq!(cx.update(|cx| cx.windows().len()), 1);
let multi_workspace = cx.update(|cx| cx.windows()[0].downcast::<MultiWorkspace>().unwrap());
+ multi_workspace
+ .update(cx, |multi_workspace, _, cx| {
+ multi_workspace.open_sidebar(cx);
+ })
+ .unwrap();
multi_workspace
.update(cx, |multi_workspace, _, cx| {
assert!(!multi_workspace.workspace().read(cx).is_edited())
@@ -2141,7 +2143,7 @@ mod tests {
);
assert!(
- multi_workspace.workspaces().contains(&dirty_workspace),
+ multi_workspace.workspaces().any(|w| w == &dirty_workspace),
"The dirty workspace should still be present in multi-workspace mode"
);
@@ -225,6 +225,10 @@ impl PickerDelegate for RulePickerDelegate {
}
}
+ fn select_on_hover(&self) -> bool {
+ false
+ }
+
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
"Search…".into()
}
@@ -763,6 +763,7 @@ pub struct VimSettingsContent {
pub toggle_relative_line_numbers: Option<bool>,
pub use_system_clipboard: Option<UseSystemClipboard>,
pub use_smartcase_find: Option<bool>,
+ pub use_regex_search: Option<bool>,
/// When enabled, the `:substitute` command replaces all matches in a line
/// by default. The 'g' flag then toggles this behavior.,
pub gdefault: Option<bool>,
@@ -2447,7 +2447,7 @@ fn editor_page() -> SettingsPage {
]
}
- fn vim_settings_section() -> [SettingsPageItem; 12] {
+ fn vim_settings_section() -> [SettingsPageItem; 13] {
[
SettingsPageItem::SectionHeader("Vim"),
SettingsPageItem::SettingItem(SettingItem {
@@ -2556,6 +2556,24 @@ fn editor_page() -> SettingsPage {
metadata: None,
files: USER,
}),
+ SettingsPageItem::SettingItem(SettingItem {
+ title: "Regex Search",
+ description: "Use regex search by default in Vim search.",
+ field: Box::new(SettingField {
+ json_path: Some("vim.use_regex_search"),
+ pick: |settings_content| {
+ settings_content.vim.as_ref()?.use_regex_search.as_ref()
+ },
+ write: |settings_content, value| {
+ settings_content
+ .vim
+ .get_or_insert_default()
+ .use_regex_search = value;
+ },
+ }),
+ metadata: None,
+ files: USER,
+ }),
SettingsPageItem::SettingItem(SettingItem {
title: "Cursor Shape - Normal Mode",
description: "Cursor shape for normal mode.",
@@ -4433,7 +4451,7 @@ fn window_and_layout_page() -> SettingsPage {
}
fn panels_page() -> SettingsPage {
- fn project_panel_section() -> [SettingsPageItem; 24] {
+ fn project_panel_section() -> [SettingsPageItem; 28] {
[
SettingsPageItem::SectionHeader("Project Panel"),
SettingsPageItem::SettingItem(SettingItem {
@@ -4914,31 +4932,25 @@ fn panels_page() -> SettingsPage {
files: USER,
}),
SettingsPageItem::SettingItem(SettingItem {
- title: "Hidden Files",
- description: "Globs to match files that will be considered \"hidden\" and can be hidden from the project panel.",
- field: Box::new(
- SettingField {
- json_path: Some("worktree.hidden_files"),
- pick: |settings_content| {
- settings_content.project.worktree.hidden_files.as_ref()
- },
- write: |settings_content, value| {
- settings_content.project.worktree.hidden_files = value;
- },
- }
- .unimplemented(),
- ),
+ title: "Sort Mode",
+ description: "Sort order for entries in the project panel.",
+ field: Box::new(SettingField {
+ json_path: Some("project_panel.sort_mode"),
+ pick: |settings_content| {
+ settings_content.project_panel.as_ref()?.sort_mode.as_ref()
+ },
+ write: |settings_content, value| {
+ settings_content
+ .project_panel
+ .get_or_insert_default()
+ .sort_mode = value;
+ },
+ }),
metadata: None,
files: USER,
}),
- ]
- }
-
- fn auto_open_files_section() -> [SettingsPageItem; 5] {
- [
- SettingsPageItem::SectionHeader("Auto Open Files"),
SettingsPageItem::SettingItem(SettingItem {
- title: "On Create",
+ title: "Auto Open Files On Create",
description: "Whether to automatically open newly created files in the editor.",
field: Box::new(SettingField {
json_path: Some("project_panel.auto_open.on_create"),
@@ -4964,7 +4976,7 @@ fn panels_page() -> SettingsPage {
files: USER,
}),
SettingsPageItem::SettingItem(SettingItem {
- title: "On Paste",
+ title: "Auto Open Files On Paste",
description: "Whether to automatically open files after pasting or duplicating them.",
field: Box::new(SettingField {
json_path: Some("project_panel.auto_open.on_paste"),
@@ -4990,7 +5002,7 @@ fn panels_page() -> SettingsPage {
files: USER,
}),
SettingsPageItem::SettingItem(SettingItem {
- title: "On Drop",
+ title: "Auto Open Files On Drop",
description: "Whether to automatically open files dropped from external sources.",
field: Box::new(SettingField {
json_path: Some("project_panel.auto_open.on_drop"),
@@ -5016,20 +5028,20 @@ fn panels_page() -> SettingsPage {
files: USER,
}),
SettingsPageItem::SettingItem(SettingItem {
- title: "Sort Mode",
- description: "Sort order for entries in the project panel.",
- field: Box::new(SettingField {
- pick: |settings_content| {
- settings_content.project_panel.as_ref()?.sort_mode.as_ref()
- },
- write: |settings_content, value| {
- settings_content
- .project_panel
- .get_or_insert_default()
- .sort_mode = value;
- },
- json_path: Some("project_panel.sort_mode"),
- }),
+ title: "Hidden Files",
+ description: "Globs to match files that will be considered \"hidden\" and can be hidden from the project panel.",
+ field: Box::new(
+ SettingField {
+ json_path: Some("worktree.hidden_files"),
+ pick: |settings_content| {
+ settings_content.project.worktree.hidden_files.as_ref()
+ },
+ write: |settings_content, value| {
+ settings_content.project.worktree.hidden_files = value;
+ },
+ }
+ .unimplemented(),
+ ),
metadata: None,
files: USER,
}),
@@ -5807,7 +5819,6 @@ fn panels_page() -> SettingsPage {
title: "Panels",
items: concat_sections![
project_panel_section(),
- auto_open_files_section(),
terminal_panel_section(),
outline_panel_section(),
git_panel_section(),
@@ -3753,7 +3753,6 @@ fn all_projects(
.flat_map(|multi_workspace| {
multi_workspace
.workspaces()
- .iter()
.map(|workspace| workspace.read(cx).project().clone())
.collect::<Vec<_>>()
}),
@@ -434,7 +434,7 @@ impl Sidebar {
})
.detach();
- let workspaces = multi_workspace.read(cx).workspaces().to_vec();
+ let workspaces: Vec<_> = multi_workspace.read(cx).workspaces().cloned().collect();
cx.defer_in(window, move |this, window, cx| {
for workspace in &workspaces {
this.subscribe_to_workspace(workspace, window, cx);
@@ -673,7 +673,6 @@ impl Sidebar {
let mw = self.multi_workspace.upgrade()?;
let mw = mw.read(cx);
mw.workspaces()
- .iter()
.find(|ws| ws.read(cx).project_group_key(cx).path_list() == path_list)
.cloned()
}
@@ -716,8 +715,8 @@ impl Sidebar {
return;
};
let mw = multi_workspace.read(cx);
- let workspaces = mw.workspaces().to_vec();
- let active_workspace = mw.workspaces().get(mw.active_workspace_index()).cloned();
+ let workspaces: Vec<_> = mw.workspaces().cloned().collect();
+ let active_workspace = Some(mw.workspace().clone());
let agent_server_store = workspaces
.first()
@@ -1769,7 +1768,11 @@ impl Sidebar {
dispatch_context.add("ThreadsSidebar");
dispatch_context.add("menu");
- let identifier = if self.filter_editor.focus_handle(cx).is_focused(window) {
+ let is_archived_search_focused = matches!(&self.view, SidebarView::Archive(archive) if archive.read(cx).is_filter_editor_focused(window, cx));
+
+ let identifier = if self.filter_editor.focus_handle(cx).is_focused(window)
+ || is_archived_search_focused
+ {
"searching"
} else {
"not_searching"
@@ -1989,7 +1992,6 @@ impl Sidebar {
let workspace = window.read(cx).ok().and_then(|multi_workspace| {
multi_workspace
.workspaces()
- .iter()
.find(|workspace| predicate(workspace, cx))
.cloned()
})?;
@@ -2006,7 +2008,6 @@ impl Sidebar {
multi_workspace
.read(cx)
.workspaces()
- .iter()
.find(|workspace| predicate(workspace, cx))
.cloned()
})
@@ -2199,12 +2200,10 @@ impl Sidebar {
return;
}
- let active_workspace = self.multi_workspace.upgrade().and_then(|w| {
- w.read(cx)
- .workspaces()
- .get(w.read(cx).active_workspace_index())
- .cloned()
- });
+ let active_workspace = self
+ .multi_workspace
+ .upgrade()
+ .map(|w| w.read(cx).workspace().clone());
if let Some(workspace) = active_workspace {
self.activate_thread_locally(&metadata, &workspace, window, cx);
@@ -2339,7 +2338,7 @@ impl Sidebar {
return;
};
- let workspaces = multi_workspace.read(cx).workspaces().to_vec();
+ let workspaces: Vec<_> = multi_workspace.read(cx).workspaces().cloned().collect();
for workspace in workspaces {
if let Some(agent_panel) = workspace.read(cx).panel::<AgentPanel>(cx) {
let cancelled =
@@ -2932,7 +2931,6 @@ impl Sidebar {
.map(|mw| {
mw.read(cx)
.workspaces()
- .iter()
.filter_map(|ws| ws.read(cx).database_id())
.collect()
})
@@ -3400,12 +3398,9 @@ impl Sidebar {
}
fn active_workspace(&self, cx: &App) -> Option<Entity<Workspace>> {
- self.multi_workspace.upgrade().and_then(|w| {
- w.read(cx)
- .workspaces()
- .get(w.read(cx).active_workspace_index())
- .cloned()
- })
+ self.multi_workspace
+ .upgrade()
+ .map(|w| w.read(cx).workspace().clone())
}
fn show_thread_import_modal(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -3513,12 +3508,11 @@ impl Sidebar {
}
fn show_archive(&mut self, window: &mut Window, cx: &mut Context<Self>) {
- let Some(active_workspace) = self.multi_workspace.upgrade().and_then(|w| {
- w.read(cx)
- .workspaces()
- .get(w.read(cx).active_workspace_index())
- .cloned()
- }) else {
+ let Some(active_workspace) = self
+ .multi_workspace
+ .upgrade()
+ .map(|w| w.read(cx).workspace().clone())
+ else {
return;
};
let Some(agent_panel) = active_workspace.read(cx).panel::<AgentPanel>(cx) else {
@@ -3820,12 +3814,12 @@ pub fn dump_workspace_info(
let multi_workspace = workspace.multi_workspace().and_then(|weak| weak.upgrade());
let workspaces: Vec<gpui::Entity<Workspace>> = match &multi_workspace {
- Some(mw) => mw.read(cx).workspaces().to_vec(),
+ Some(mw) => mw.read(cx).workspaces().cloned().collect(),
None => vec![this_entity.clone()],
};
- let active_index = multi_workspace
+ let active_workspace = multi_workspace
.as_ref()
- .map(|mw| mw.read(cx).active_workspace_index());
+ .map(|mw| mw.read(cx).workspace().clone());
writeln!(output, "MultiWorkspace: {} workspace(s)", workspaces.len()).ok();
@@ -3837,13 +3831,10 @@ pub fn dump_workspace_info(
}
}
- if let Some(index) = active_index {
- writeln!(output, "Active workspace index: {index}").ok();
- }
writeln!(output).ok();
for (index, ws) in workspaces.iter().enumerate() {
- let is_active = active_index == Some(index);
+ let is_active = active_workspace.as_ref() == Some(ws);
writeln!(
output,
"--- Workspace {index}{} ---",
@@ -77,6 +77,18 @@ async fn init_test_project(
fn setup_sidebar(
multi_workspace: &Entity<MultiWorkspace>,
cx: &mut gpui::VisualTestContext,
+) -> Entity<Sidebar> {
+ let sidebar = setup_sidebar_closed(multi_workspace, cx);
+ multi_workspace.update_in(cx, |mw, window, cx| {
+ mw.toggle_sidebar(window, cx);
+ });
+ cx.run_until_parked();
+ sidebar
+}
+
+fn setup_sidebar_closed(
+ multi_workspace: &Entity<MultiWorkspace>,
+ cx: &mut gpui::VisualTestContext,
) -> Entity<Sidebar> {
let multi_workspace = multi_workspace.clone();
let sidebar =
@@ -172,16 +184,7 @@ fn save_thread_metadata(
cx.run_until_parked();
}
-fn open_and_focus_sidebar(sidebar: &Entity<Sidebar>, cx: &mut gpui::VisualTestContext) {
- let multi_workspace = sidebar.read_with(cx, |s, _| s.multi_workspace.upgrade());
- if let Some(multi_workspace) = multi_workspace {
- multi_workspace.update_in(cx, |mw, window, cx| {
- if !mw.sidebar_open() {
- mw.toggle_sidebar(window, cx);
- }
- });
- }
- cx.run_until_parked();
+fn focus_sidebar(sidebar: &Entity<Sidebar>, cx: &mut gpui::VisualTestContext) {
sidebar.update_in(cx, |_, window, cx| {
cx.focus_self(window);
});
@@ -544,7 +547,7 @@ async fn test_workspace_lifecycle(cx: &mut TestAppContext) {
// Remove the second workspace
multi_workspace.update_in(cx, |mw, window, cx| {
- let workspace = mw.workspaces()[1].clone();
+ let workspace = mw.workspaces().nth(1).cloned().unwrap();
mw.remove(&workspace, window, cx);
});
cx.run_until_parked();
@@ -604,7 +607,7 @@ async fn test_view_more_batched_expansion(cx: &mut TestAppContext) {
assert!(entries.iter().any(|e| e.contains("View More")));
// Focus and navigate to View More, then confirm to expand by one batch
- open_and_focus_sidebar(&sidebar, cx);
+ focus_sidebar(&sidebar, cx);
for _ in 0..7 {
cx.dispatch_action(SelectNext);
}
@@ -915,7 +918,7 @@ async fn test_keyboard_select_next_and_previous(cx: &mut TestAppContext) {
// Entries: [header, thread3, thread2, thread1]
// Focusing the sidebar does not set a selection; select_next/select_previous
// handle None gracefully by starting from the first or last entry.
- open_and_focus_sidebar(&sidebar, cx);
+ focus_sidebar(&sidebar, cx);
assert_eq!(sidebar.read_with(cx, |s, _| s.selection), None);
// First SelectNext from None starts at index 0
@@ -970,7 +973,7 @@ async fn test_keyboard_select_first_and_last(cx: &mut TestAppContext) {
multi_workspace.update_in(cx, |_, _window, cx| cx.notify());
cx.run_until_parked();
- open_and_focus_sidebar(&sidebar, cx);
+ focus_sidebar(&sidebar, cx);
// SelectLast jumps to the end
cx.dispatch_action(SelectLast);
@@ -993,7 +996,7 @@ async fn test_keyboard_focus_in_does_not_set_selection(cx: &mut TestAppContext)
// Open the sidebar so it's rendered, then focus it to trigger focus_in.
// focus_in no longer sets a default selection.
- open_and_focus_sidebar(&sidebar, cx);
+ focus_sidebar(&sidebar, cx);
assert_eq!(sidebar.read_with(cx, |s, _| s.selection), None);
// Manually set a selection, blur, then refocus — selection should be preserved
@@ -1030,7 +1033,7 @@ async fn test_keyboard_confirm_on_project_header_toggles_collapse(cx: &mut TestA
);
// Focus the sidebar and select the header (index 0)
- open_and_focus_sidebar(&sidebar, cx);
+ focus_sidebar(&sidebar, cx);
sidebar.update_in(cx, |sidebar, _window, _cx| {
sidebar.selection = Some(0);
});
@@ -1071,7 +1074,7 @@ async fn test_keyboard_confirm_on_view_more_expands(cx: &mut TestAppContext) {
assert!(entries.iter().any(|e| e.contains("View More")));
// Focus sidebar (selection starts at None), then navigate down to the "View More" entry (index 6)
- open_and_focus_sidebar(&sidebar, cx);
+ focus_sidebar(&sidebar, cx);
for _ in 0..7 {
cx.dispatch_action(SelectNext);
}
@@ -1105,7 +1108,7 @@ async fn test_keyboard_expand_and_collapse_selected_entry(cx: &mut TestAppContex
);
// Focus sidebar and manually select the header (index 0). Press left to collapse.
- open_and_focus_sidebar(&sidebar, cx);
+ focus_sidebar(&sidebar, cx);
sidebar.update_in(cx, |sidebar, _window, _cx| {
sidebar.selection = Some(0);
});
@@ -1144,7 +1147,7 @@ async fn test_keyboard_collapse_from_child_selects_parent(cx: &mut TestAppContex
cx.run_until_parked();
// Focus sidebar (selection starts at None), then navigate down to the thread (child)
- open_and_focus_sidebar(&sidebar, cx);
+ focus_sidebar(&sidebar, cx);
cx.dispatch_action(SelectNext);
cx.dispatch_action(SelectNext);
assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(1));
@@ -1179,7 +1182,7 @@ async fn test_keyboard_navigation_on_empty_list(cx: &mut TestAppContext) {
);
// Focus sidebar — focus_in does not set a selection
- open_and_focus_sidebar(&sidebar, cx);
+ focus_sidebar(&sidebar, cx);
assert_eq!(sidebar.read_with(cx, |s, _| s.selection), None);
// First SelectNext from None starts at index 0 (header)
@@ -1211,7 +1214,7 @@ async fn test_selection_clamps_after_entry_removal(cx: &mut TestAppContext) {
cx.run_until_parked();
// Focus sidebar (selection starts at None), navigate down to the thread (index 1)
- open_and_focus_sidebar(&sidebar, cx);
+ focus_sidebar(&sidebar, cx);
cx.dispatch_action(SelectNext);
cx.dispatch_action(SelectNext);
assert_eq!(sidebar.read_with(cx, |s, _| s.selection), Some(1));
@@ -1492,7 +1495,7 @@ async fn test_escape_clears_search_and_restores_full_list(cx: &mut TestAppContex
);
// User types a search query to filter down.
- open_and_focus_sidebar(&sidebar, cx);
+ focus_sidebar(&sidebar, cx);
type_in_search(&sidebar, "alpha", cx);
assert_eq!(
visible_entries_as_strings(&sidebar, cx),
@@ -1540,8 +1543,9 @@ async fn test_search_only_shows_workspace_headers_with_matches(cx: &mut TestAppC
});
cx.run_until_parked();
- let project_b =
- multi_workspace.read_with(cx, |mw, cx| mw.workspaces()[1].read(cx).project().clone());
+ let project_b = multi_workspace.read_with(cx, |mw, cx| {
+ mw.workspaces().nth(1).unwrap().read(cx).project().clone()
+ });
for (id, title, hour) in [
("b1", "Refactor sidebar layout", 3),
@@ -1621,8 +1625,9 @@ async fn test_search_matches_workspace_name(cx: &mut TestAppContext) {
});
cx.run_until_parked();
- let project_b =
- multi_workspace.read_with(cx, |mw, cx| mw.workspaces()[1].read(cx).project().clone());
+ let project_b = multi_workspace.read_with(cx, |mw, cx| {
+ mw.workspaces().nth(1).unwrap().read(cx).project().clone()
+ });
for (id, title, hour) in [
("b1", "Refactor sidebar layout", 3),
@@ -1764,7 +1769,7 @@ async fn test_search_finds_threads_inside_collapsed_groups(cx: &mut TestAppConte
// User focuses the sidebar and collapses the group using keyboard:
// manually select the header, then press SelectParent to collapse.
- open_and_focus_sidebar(&sidebar, cx);
+ focus_sidebar(&sidebar, cx);
sidebar.update_in(cx, |sidebar, _window, _cx| {
sidebar.selection = Some(0);
});
@@ -1807,7 +1812,7 @@ async fn test_search_then_keyboard_navigate_and_confirm(cx: &mut TestAppContext)
}
cx.run_until_parked();
- open_and_focus_sidebar(&sidebar, cx);
+ focus_sidebar(&sidebar, cx);
// User types "fix" — two threads match.
type_in_search(&sidebar, "fix", cx);
@@ -1856,6 +1861,13 @@ async fn test_confirm_on_historical_thread_activates_workspace(cx: &mut TestAppC
});
cx.run_until_parked();
+ let (workspace_0, workspace_1) = multi_workspace.read_with(cx, |mw, _| {
+ (
+ mw.workspaces().next().unwrap().clone(),
+ mw.workspaces().nth(1).unwrap().clone(),
+ )
+ });
+
save_thread_metadata(
acp::SessionId::new(Arc::from("hist-1")),
"Historical Thread".into(),
@@ -1875,13 +1887,13 @@ async fn test_confirm_on_historical_thread_activates_workspace(cx: &mut TestAppC
// Switch to workspace 1 so we can verify the confirm switches back.
multi_workspace.update_in(cx, |mw, window, cx| {
- let workspace = mw.workspaces()[1].clone();
+ let workspace = mw.workspaces().nth(1).unwrap().clone();
mw.activate(workspace, window, cx);
});
cx.run_until_parked();
assert_eq!(
- multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()),
- 1
+ multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()),
+ workspace_1
);
// Confirm on the historical (non-live) thread at index 1.
@@ -1895,8 +1907,8 @@ async fn test_confirm_on_historical_thread_activates_workspace(cx: &mut TestAppC
cx.run_until_parked();
assert_eq!(
- multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()),
- 0
+ multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()),
+ workspace_0
);
}
@@ -2037,7 +2049,8 @@ async fn test_focused_thread_tracks_user_intent(cx: &mut TestAppContext) {
let panel_b = add_agent_panel(&workspace_b, cx);
cx.run_until_parked();
- let workspace_a = multi_workspace.read_with(cx, |mw, _cx| mw.workspaces()[0].clone());
+ let workspace_a =
+ multi_workspace.read_with(cx, |mw, _cx| mw.workspaces().next().unwrap().clone());
// ── 1. Initial state: focused thread derived from active panel ─────
sidebar.read_with(cx, |sidebar, _cx| {
@@ -2135,7 +2148,7 @@ async fn test_focused_thread_tracks_user_intent(cx: &mut TestAppContext) {
});
multi_workspace.update_in(cx, |mw, window, cx| {
- let workspace = mw.workspaces()[0].clone();
+ let workspace = mw.workspaces().next().unwrap().clone();
mw.activate(workspace, window, cx);
});
cx.run_until_parked();
@@ -2190,8 +2203,8 @@ async fn test_focused_thread_tracks_user_intent(cx: &mut TestAppContext) {
// Switching workspaces via the multi_workspace (simulates clicking
// a workspace header) should clear focused_thread.
multi_workspace.update_in(cx, |mw, window, cx| {
- if let Some(index) = mw.workspaces().iter().position(|w| w == &workspace_b) {
- let workspace = mw.workspaces()[index].clone();
+ let workspace = mw.workspaces().find(|w| *w == &workspace_b).cloned();
+ if let Some(workspace) = workspace {
mw.activate(workspace, window, cx);
}
});
@@ -2477,6 +2490,8 @@ async fn test_cmd_n_shows_new_thread_entry_in_absorbed_worktree(cx: &mut TestApp
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx));
+ let sidebar = setup_sidebar(&multi_workspace, cx);
+
let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| {
mw.test_add_workspace(worktree_project.clone(), window, cx)
});
@@ -2485,12 +2500,10 @@ async fn test_cmd_n_shows_new_thread_entry_in_absorbed_worktree(cx: &mut TestApp
// Switch to the worktree workspace.
multi_workspace.update_in(cx, |mw, window, cx| {
- let workspace = mw.workspaces()[1].clone();
+ let workspace = mw.workspaces().nth(1).unwrap().clone();
mw.activate(workspace, window, cx);
});
- let sidebar = setup_sidebar(&multi_workspace, cx);
-
// Create a non-empty thread in the worktree workspace.
let connection = StubAgentConnection::new();
connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk(
@@ -3027,6 +3040,8 @@ async fn test_absorbed_worktree_running_thread_shows_live_status(cx: &mut TestAp
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx));
+ let sidebar = setup_sidebar(&multi_workspace, cx);
+
let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| {
mw.test_add_workspace(worktree_project.clone(), window, cx)
});
@@ -3037,12 +3052,10 @@ async fn test_absorbed_worktree_running_thread_shows_live_status(cx: &mut TestAp
// Switch back to the main workspace before setting up the sidebar.
multi_workspace.update_in(cx, |mw, window, cx| {
- let workspace = mw.workspaces()[0].clone();
+ let workspace = mw.workspaces().next().unwrap().clone();
mw.activate(workspace, window, cx);
});
- let sidebar = setup_sidebar(&multi_workspace, cx);
-
// Start a thread in the worktree workspace's panel and keep it
// generating (don't resolve it).
let connection = StubAgentConnection::new();
@@ -3127,6 +3140,8 @@ async fn test_absorbed_worktree_completion_triggers_notification(cx: &mut TestAp
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx));
+ let sidebar = setup_sidebar(&multi_workspace, cx);
+
let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| {
mw.test_add_workspace(worktree_project.clone(), window, cx)
});
@@ -3134,12 +3149,10 @@ async fn test_absorbed_worktree_completion_triggers_notification(cx: &mut TestAp
let worktree_panel = add_agent_panel(&worktree_workspace, cx);
multi_workspace.update_in(cx, |mw, window, cx| {
- let workspace = mw.workspaces()[0].clone();
+ let workspace = mw.workspaces().next().unwrap().clone();
mw.activate(workspace, window, cx);
});
- let sidebar = setup_sidebar(&multi_workspace, cx);
-
let connection = StubAgentConnection::new();
open_thread_with_connection(&worktree_panel, connection.clone(), cx);
send_message(&worktree_panel, cx);
@@ -3231,12 +3244,12 @@ async fn test_clicking_worktree_thread_opens_workspace_when_none_exists(cx: &mut
// Only 1 workspace should exist.
assert_eq!(
- multi_workspace.read_with(cx, |mw, _| mw.workspaces().len()),
+ multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
1,
);
// Focus the sidebar and select the worktree thread.
- open_and_focus_sidebar(&sidebar, cx);
+ focus_sidebar(&sidebar, cx);
sidebar.update_in(cx, |sidebar, _window, _cx| {
sidebar.selection = Some(1); // index 0 is header, 1 is the thread
});
@@ -3248,11 +3261,11 @@ async fn test_clicking_worktree_thread_opens_workspace_when_none_exists(cx: &mut
// A new workspace should have been created for the worktree path.
let new_workspace = multi_workspace.read_with(cx, |mw, _| {
assert_eq!(
- mw.workspaces().len(),
+ mw.workspaces().count(),
2,
"confirming a worktree thread without a workspace should open one",
);
- mw.workspaces()[1].clone()
+ mw.workspaces().nth(1).unwrap().clone()
});
let new_path_list =
@@ -3318,7 +3331,7 @@ async fn test_clicking_worktree_thread_does_not_briefly_render_as_separate_proje
vec!["v [project]", " WT Thread {wt-feature-a}"],
);
- open_and_focus_sidebar(&sidebar, cx);
+ focus_sidebar(&sidebar, cx);
sidebar.update_in(cx, |sidebar, _window, _cx| {
sidebar.selection = Some(1); // index 0 is header, 1 is the thread
});
@@ -3444,18 +3457,19 @@ async fn test_clicking_absorbed_worktree_thread_activates_worktree_workspace(
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx));
+ let sidebar = setup_sidebar(&multi_workspace, cx);
+
let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| {
mw.test_add_workspace(worktree_project.clone(), window, cx)
});
// Activate the main workspace before setting up the sidebar.
- multi_workspace.update_in(cx, |mw, window, cx| {
- let workspace = mw.workspaces()[0].clone();
- mw.activate(workspace, window, cx);
+ let main_workspace = multi_workspace.update_in(cx, |mw, window, cx| {
+ let workspace = mw.workspaces().next().unwrap().clone();
+ mw.activate(workspace.clone(), window, cx);
+ workspace
});
- let sidebar = setup_sidebar(&multi_workspace, cx);
-
save_named_thread_metadata("thread-main", "Main Thread", &main_project, cx).await;
save_named_thread_metadata("thread-wt", "WT Thread", &worktree_project, cx).await;
@@ -3475,13 +3489,13 @@ async fn test_clicking_absorbed_worktree_thread_activates_worktree_workspace(
.expect("should find the worktree thread entry");
assert_eq!(
- multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()),
- 0,
+ multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()),
+ main_workspace,
"main workspace should be active initially"
);
// Focus the sidebar and select the absorbed worktree thread.
- open_and_focus_sidebar(&sidebar, cx);
+ focus_sidebar(&sidebar, cx);
sidebar.update_in(cx, |sidebar, _window, _cx| {
sidebar.selection = Some(wt_thread_index);
});
@@ -3491,9 +3505,7 @@ async fn test_clicking_absorbed_worktree_thread_activates_worktree_workspace(
cx.run_until_parked();
// The worktree workspace should now be active, not the main one.
- let active_workspace = multi_workspace.read_with(cx, |mw, _| {
- mw.workspaces()[mw.active_workspace_index()].clone()
- });
+ let active_workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
assert_eq!(
active_workspace, worktree_workspace,
"clicking an absorbed worktree thread should activate the worktree workspace"
@@ -3520,25 +3532,27 @@ async fn test_activate_archived_thread_with_saved_paths_activates_matching_works
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a.clone(), window, cx));
- multi_workspace.update_in(cx, |mw, window, cx| {
- mw.test_add_workspace(project_b.clone(), window, cx);
- });
-
let sidebar = setup_sidebar(&multi_workspace, cx);
+ let workspace_b = multi_workspace.update_in(cx, |mw, window, cx| {
+ mw.test_add_workspace(project_b.clone(), window, cx)
+ });
+ let workspace_a =
+ multi_workspace.read_with(cx, |mw, _| mw.workspaces().next().unwrap().clone());
+
// Save a thread with path_list pointing to project-b.
let session_id = acp::SessionId::new(Arc::from("archived-1"));
save_test_thread_metadata(&session_id, &project_b, cx).await;
// Ensure workspace A is active.
multi_workspace.update_in(cx, |mw, window, cx| {
- let workspace = mw.workspaces()[0].clone();
+ let workspace = mw.workspaces().next().unwrap().clone();
mw.activate(workspace, window, cx);
});
cx.run_until_parked();
assert_eq!(
- multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()),
- 0
+ multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()),
+ workspace_a
);
// Call activate_archived_thread – should resolve saved paths and
@@ -3562,8 +3576,8 @@ async fn test_activate_archived_thread_with_saved_paths_activates_matching_works
cx.run_until_parked();
assert_eq!(
- multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()),
- 1,
+ multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()),
+ workspace_b,
"should have activated the workspace matching the saved path_list"
);
}
@@ -3588,21 +3602,23 @@ async fn test_activate_archived_thread_cwd_fallback_with_matching_workspace(
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx));
- multi_workspace.update_in(cx, |mw, window, cx| {
- mw.test_add_workspace(project_b, window, cx);
- });
-
let sidebar = setup_sidebar(&multi_workspace, cx);
+ let workspace_b = multi_workspace.update_in(cx, |mw, window, cx| {
+ mw.test_add_workspace(project_b, window, cx)
+ });
+ let workspace_a =
+ multi_workspace.read_with(cx, |mw, _| mw.workspaces().next().unwrap().clone());
+
// Start with workspace A active.
multi_workspace.update_in(cx, |mw, window, cx| {
- let workspace = mw.workspaces()[0].clone();
+ let workspace = mw.workspaces().next().unwrap().clone();
mw.activate(workspace, window, cx);
});
cx.run_until_parked();
assert_eq!(
- multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()),
- 0
+ multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()),
+ workspace_a
);
// No thread saved to the store – cwd is the only path hint.
@@ -3625,8 +3641,8 @@ async fn test_activate_archived_thread_cwd_fallback_with_matching_workspace(
cx.run_until_parked();
assert_eq!(
- multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()),
- 1,
+ multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()),
+ workspace_b,
"should have activated the workspace matching the cwd"
);
}
@@ -3651,21 +3667,21 @@ async fn test_activate_archived_thread_no_paths_no_cwd_uses_active_workspace(
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx));
- multi_workspace.update_in(cx, |mw, window, cx| {
- mw.test_add_workspace(project_b, window, cx);
- });
-
let sidebar = setup_sidebar(&multi_workspace, cx);
+ let workspace_b = multi_workspace.update_in(cx, |mw, window, cx| {
+ mw.test_add_workspace(project_b, window, cx)
+ });
+
// Activate workspace B (index 1) to make it the active one.
multi_workspace.update_in(cx, |mw, window, cx| {
- let workspace = mw.workspaces()[1].clone();
+ let workspace = mw.workspaces().nth(1).unwrap().clone();
mw.activate(workspace, window, cx);
});
cx.run_until_parked();
assert_eq!(
- multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()),
- 1
+ multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()),
+ workspace_b
);
// No saved thread, no cwd – should fall back to the active workspace.
@@ -3688,8 +3704,8 @@ async fn test_activate_archived_thread_no_paths_no_cwd_uses_active_workspace(
cx.run_until_parked();
assert_eq!(
- multi_workspace.read_with(cx, |mw, _| mw.active_workspace_index()),
- 1,
+ multi_workspace.read_with(cx, |mw, _| mw.workspace().clone()),
+ workspace_b,
"should have stayed on the active workspace when no path info is available"
);
}
@@ -3719,7 +3735,7 @@ async fn test_activate_archived_thread_saved_paths_opens_new_workspace(cx: &mut
let session_id = acp::SessionId::new(Arc::from("archived-new-ws"));
assert_eq!(
- multi_workspace.read_with(cx, |mw, _| mw.workspaces().len()),
+ multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
1,
"should start with one workspace"
);
@@ -3743,7 +3759,7 @@ async fn test_activate_archived_thread_saved_paths_opens_new_workspace(cx: &mut
cx.run_until_parked();
assert_eq!(
- multi_workspace.read_with(cx, |mw, _| mw.workspaces().len()),
+ multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
2,
"should have opened a second workspace for the archived thread's saved paths"
);
@@ -3768,6 +3784,10 @@ async fn test_activate_archived_thread_reuses_workspace_in_another_window(cx: &m
cx.add_window(|window, cx| MultiWorkspace::test_new(project_b, window, cx));
let multi_workspace_a_entity = multi_workspace_a.root(cx).unwrap();
+ let multi_workspace_b_entity = multi_workspace_b.root(cx).unwrap();
+
+ let cx_b = &mut gpui::VisualTestContext::from_window(multi_workspace_b.into(), cx);
+ let _sidebar_b = setup_sidebar(&multi_workspace_b_entity, cx_b);
let cx_a = &mut gpui::VisualTestContext::from_window(multi_workspace_a.into(), cx);
let sidebar = setup_sidebar(&multi_workspace_a_entity, cx_a);
@@ -3794,14 +3814,14 @@ async fn test_activate_archived_thread_reuses_workspace_in_another_window(cx: &m
assert_eq!(
multi_workspace_a
- .read_with(cx_a, |mw, _| mw.workspaces().len())
+ .read_with(cx_a, |mw, _| mw.workspaces().count())
.unwrap(),
1,
"should not add the other window's workspace into the current window"
);
assert_eq!(
multi_workspace_b
- .read_with(cx_a, |mw, _| mw.workspaces().len())
+ .read_with(cx_a, |mw, _| mw.workspaces().count())
.unwrap(),
1,
"should reuse the existing workspace in the other window"
@@ -3871,14 +3891,14 @@ async fn test_activate_archived_thread_reuses_workspace_in_another_window_with_t
assert_eq!(
multi_workspace_a
- .read_with(cx_a, |mw, _| mw.workspaces().len())
+ .read_with(cx_a, |mw, _| mw.workspaces().count())
.unwrap(),
1,
"should not add the other window's workspace into the current window"
);
assert_eq!(
multi_workspace_b
- .read_with(cx_a, |mw, _| mw.workspaces().len())
+ .read_with(cx_a, |mw, _| mw.workspaces().count())
.unwrap(),
1,
"should reuse the existing workspace in the other window"
@@ -3921,6 +3941,10 @@ async fn test_activate_archived_thread_prefers_current_window_for_matching_paths
cx.add_window(|window, cx| MultiWorkspace::test_new(project_a, window, cx));
let multi_workspace_a_entity = multi_workspace_a.root(cx).unwrap();
+ let multi_workspace_b_entity = multi_workspace_b.root(cx).unwrap();
+
+ let cx_b = &mut gpui::VisualTestContext::from_window(multi_workspace_b.into(), cx);
+ let _sidebar_b = setup_sidebar(&multi_workspace_b_entity, cx_b);
let cx_a = &mut gpui::VisualTestContext::from_window(multi_workspace_a.into(), cx);
let sidebar_a = setup_sidebar(&multi_workspace_a_entity, cx_a);
@@ -3958,14 +3982,14 @@ async fn test_activate_archived_thread_prefers_current_window_for_matching_paths
});
assert_eq!(
multi_workspace_a
- .read_with(cx_a, |mw, _| mw.workspaces().len())
+ .read_with(cx_a, |mw, _| mw.workspaces().count())
.unwrap(),
1,
"current window should continue reusing its existing workspace"
);
assert_eq!(
multi_workspace_b
- .read_with(cx_a, |mw, _| mw.workspaces().len())
+ .read_with(cx_a, |mw, _| mw.workspaces().count())
.unwrap(),
1,
"other windows should not be activated just because they also match the saved paths"
@@ -4029,19 +4053,20 @@ async fn test_archive_thread_uses_next_threads_own_workspace(cx: &mut TestAppCon
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(main_project.clone(), window, cx));
+ let sidebar = setup_sidebar(&multi_workspace, cx);
+
let worktree_workspace = multi_workspace.update_in(cx, |mw, window, cx| {
mw.test_add_workspace(worktree_project.clone(), window, cx)
});
// Activate main workspace so the sidebar tracks the main panel.
multi_workspace.update_in(cx, |mw, window, cx| {
- let workspace = mw.workspaces()[0].clone();
+ let workspace = mw.workspaces().next().unwrap().clone();
mw.activate(workspace, window, cx);
});
- let sidebar = setup_sidebar(&multi_workspace, cx);
-
- let main_workspace = multi_workspace.read_with(cx, |mw, _| mw.workspaces()[0].clone());
+ let main_workspace =
+ multi_workspace.read_with(cx, |mw, _| mw.workspaces().next().unwrap().clone());
let main_panel = add_agent_panel(&main_workspace, cx);
let _worktree_panel = add_agent_panel(&worktree_workspace, cx);
@@ -4195,10 +4220,10 @@ async fn test_linked_worktree_threads_not_duplicated_across_groups(cx: &mut Test
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_only.clone(), window, cx));
+ let sidebar = setup_sidebar(&multi_workspace, cx);
multi_workspace.update_in(cx, |mw, window, cx| {
mw.test_add_workspace(multi_root.clone(), window, cx);
});
- let sidebar = setup_sidebar(&multi_workspace, cx);
// Save a thread under the linked worktree path.
save_named_thread_metadata("wt-thread", "Worktree Thread", &worktree_project, cx).await;
@@ -4313,8 +4338,8 @@ async fn test_thread_switcher_ordering(cx: &mut TestAppContext) {
// so all three have last_accessed_at set.
// Access order is: A (most recent), B, C (oldest).
- // ── 1. Open switcher: threads sorted by last_accessed_at ───────────
- open_and_focus_sidebar(&sidebar, cx);
+ // ── 1. Open switcher: threads sorted by last_accessed_at ─────────────────
+ focus_sidebar(&sidebar, cx);
sidebar.update_in(cx, |sidebar, window, cx| {
sidebar.on_toggle_thread_switcher(&ToggleThreadSwitcher::default(), window, cx);
});
@@ -4759,6 +4784,284 @@ async fn test_linked_worktree_workspace_shows_main_worktree_threads(cx: &mut Tes
);
}
+async fn init_multi_project_test(
+ paths: &[&str],
+ cx: &mut TestAppContext,
+) -> (Arc<FakeFs>, Entity<project::Project>) {
+ agent_ui::test_support::init_test(cx);
+ cx.update(|cx| {
+ cx.update_flags(false, vec!["agent-v2".into()]);
+ ThreadStore::init_global(cx);
+ ThreadMetadataStore::init_global(cx);
+ language_model::LanguageModelRegistry::test(cx);
+ prompt_store::init(cx);
+ });
+ let fs = FakeFs::new(cx.executor());
+ for path in paths {
+ fs.insert_tree(path, serde_json::json!({ ".git": {}, "src": {} }))
+ .await;
+ }
+ cx.update(|cx| <dyn fs::Fs>::set_global(fs.clone(), cx));
+ let project =
+ project::Project::test(fs.clone() as Arc<dyn fs::Fs>, [paths[0].as_ref()], cx).await;
+ (fs, project)
+}
+
+async fn add_test_project(
+ path: &str,
+ fs: &Arc<FakeFs>,
+ multi_workspace: &Entity<MultiWorkspace>,
+ cx: &mut gpui::VisualTestContext,
+) -> Entity<Workspace> {
+ let project = project::Project::test(fs.clone() as Arc<dyn fs::Fs>, [path.as_ref()], cx).await;
+ let workspace = multi_workspace.update_in(cx, |mw, window, cx| {
+ mw.test_add_workspace(project, window, cx)
+ });
+ cx.run_until_parked();
+ workspace
+}
+
+#[gpui::test]
+async fn test_transient_workspace_lifecycle(cx: &mut TestAppContext) {
+ let (fs, project_a) =
+ init_multi_project_test(&["/project-a", "/project-b", "/project-c"], cx).await;
+ let (multi_workspace, cx) =
+ cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx));
+ let _sidebar = setup_sidebar_closed(&multi_workspace, cx);
+
+ // Sidebar starts closed. Initial workspace A is transient.
+ let workspace_a = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
+ assert!(!multi_workspace.read_with(cx, |mw, _| mw.sidebar_open()));
+ assert_eq!(
+ multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+ 1
+ );
+ assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_a));
+
+ // Add B — replaces A as the transient workspace.
+ let workspace_b = add_test_project("/project-b", &fs, &multi_workspace, cx).await;
+ assert_eq!(
+ multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+ 1
+ );
+ assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_b));
+
+ // Add C — replaces B as the transient workspace.
+ let workspace_c = add_test_project("/project-c", &fs, &multi_workspace, cx).await;
+ assert_eq!(
+ multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+ 1
+ );
+ assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_c));
+}
+
+#[gpui::test]
+async fn test_transient_workspace_retained(cx: &mut TestAppContext) {
+ let (fs, project_a) = init_multi_project_test(
+ &["/project-a", "/project-b", "/project-c", "/project-d"],
+ cx,
+ )
+ .await;
+ let (multi_workspace, cx) =
+ cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx));
+ let _sidebar = setup_sidebar(&multi_workspace, cx);
+ assert!(multi_workspace.read_with(cx, |mw, _| mw.sidebar_open()));
+
+ // Add B — retained since sidebar is open.
+ let workspace_a = add_test_project("/project-b", &fs, &multi_workspace, cx).await;
+ assert_eq!(
+ multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+ 2
+ );
+
+ // Switch to A — B survives. (Switching from one internal workspace, to another)
+ multi_workspace.update_in(cx, |mw, window, cx| mw.activate(workspace_a, window, cx));
+ cx.run_until_parked();
+ assert_eq!(
+ multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+ 2
+ );
+
+ // Close sidebar — both A and B remain retained.
+ multi_workspace.update_in(cx, |mw, window, cx| mw.close_sidebar(window, cx));
+ cx.run_until_parked();
+ assert_eq!(
+ multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+ 2
+ );
+
+ // Add C — added as new transient workspace. (switching from retained, to transient)
+ let workspace_c = add_test_project("/project-c", &fs, &multi_workspace, cx).await;
+ assert_eq!(
+ multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+ 3
+ );
+ assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_c));
+
+ // Add D — replaces C as the transient workspace (Have retained and transient workspaces, transient workspace is dropped)
+ let workspace_d = add_test_project("/project-d", &fs, &multi_workspace, cx).await;
+ assert_eq!(
+ multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+ 3
+ );
+ assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_d));
+}
+
+#[gpui::test]
+async fn test_transient_workspace_promotion(cx: &mut TestAppContext) {
+ let (fs, project_a) =
+ init_multi_project_test(&["/project-a", "/project-b", "/project-c"], cx).await;
+ let (multi_workspace, cx) =
+ cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx));
+ setup_sidebar_closed(&multi_workspace, cx);
+
+ // Add B — replaces A as the transient workspace (A is discarded).
+ let workspace_b = add_test_project("/project-b", &fs, &multi_workspace, cx).await;
+ assert_eq!(
+ multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+ 1
+ );
+ assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_b));
+
+ // Open sidebar — promotes the transient B to retained.
+ multi_workspace.update_in(cx, |mw, window, cx| {
+ mw.toggle_sidebar(window, cx);
+ });
+ cx.run_until_parked();
+ assert_eq!(
+ multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+ 1
+ );
+ assert!(multi_workspace.read_with(cx, |mw, _| mw.workspaces().any(|w| w == &workspace_b)));
+
+ // Close sidebar — the retained B remains.
+ multi_workspace.update_in(cx, |mw, window, cx| {
+ mw.toggle_sidebar(window, cx);
+ });
+
+ // Add C — added as new transient workspace.
+ let workspace_c = add_test_project("/project-c", &fs, &multi_workspace, cx).await;
+ assert_eq!(
+ multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+ 2
+ );
+ assert!(multi_workspace.read_with(cx, |mw, _| mw.workspace() == &workspace_c));
+}
+
+#[gpui::test]
+async fn test_legacy_thread_with_canonical_path_opens_main_repo_workspace(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+
+ fs.insert_tree(
+ "/project",
+ serde_json::json!({
+ ".git": {
+ "worktrees": {
+ "feature-a": {
+ "commondir": "../../",
+ "HEAD": "ref: refs/heads/feature-a",
+ },
+ },
+ },
+ "src": {},
+ }),
+ )
+ .await;
+
+ fs.insert_tree(
+ "/wt-feature-a",
+ serde_json::json!({
+ ".git": "gitdir: /project/.git/worktrees/feature-a",
+ "src": {},
+ }),
+ )
+ .await;
+
+ fs.add_linked_worktree_for_repo(
+ Path::new("/project/.git"),
+ false,
+ git::repository::Worktree {
+ path: PathBuf::from("/wt-feature-a"),
+ ref_name: Some("refs/heads/feature-a".into()),
+ sha: "abc".into(),
+ is_main: false,
+ },
+ )
+ .await;
+
+ cx.update(|cx| <dyn fs::Fs>::set_global(fs.clone(), cx));
+
+ // Only a linked worktree workspace is open — no workspace for /project.
+ let worktree_project = project::Project::test(fs.clone(), ["/wt-feature-a".as_ref()], cx).await;
+ worktree_project
+ .update(cx, |p, cx| p.git_scans_complete(cx))
+ .await;
+
+ let (multi_workspace, cx) = cx.add_window_view(|window, cx| {
+ MultiWorkspace::test_new(worktree_project.clone(), window, cx)
+ });
+ let sidebar = setup_sidebar(&multi_workspace, cx);
+
+ // Save a legacy thread: folder_paths = main repo, main_worktree_paths = empty.
+ let legacy_session = acp::SessionId::new(Arc::from("legacy-main-thread"));
+ cx.update(|_, cx| {
+ let metadata = ThreadMetadata {
+ session_id: legacy_session.clone(),
+ agent_id: agent::ZED_AGENT_ID.clone(),
+ title: "Legacy Main Thread".into(),
+ updated_at: chrono::TimeZone::with_ymd_and_hms(&Utc, 2024, 1, 1, 0, 0, 0).unwrap(),
+ created_at: None,
+ folder_paths: PathList::new(&[PathBuf::from("/project")]),
+ main_worktree_paths: PathList::default(),
+ archived: false,
+ };
+ ThreadMetadataStore::global(cx).update(cx, |store, cx| store.save_manually(metadata, cx));
+ });
+ cx.run_until_parked();
+
+ multi_workspace.update_in(cx, |_, _window, cx| cx.notify());
+ cx.run_until_parked();
+
+ // The legacy thread should appear in the sidebar under the project group.
+ let entries = visible_entries_as_strings(&sidebar, cx);
+ assert!(
+ entries.iter().any(|e| e.contains("Legacy Main Thread")),
+ "legacy thread should be visible: {entries:?}",
+ );
+
+ // Verify only 1 workspace before clicking.
+ assert_eq!(
+ multi_workspace.read_with(cx, |mw, _| mw.workspaces().count()),
+ 1,
+ );
+
+ // Focus and select the legacy thread, then confirm.
+ focus_sidebar(&sidebar, cx);
+ let thread_index = sidebar.read_with(cx, |sidebar, _| {
+ sidebar
+ .contents
+ .entries
+ .iter()
+ .position(|e| e.session_id().is_some_and(|id| id == &legacy_session))
+ .expect("legacy thread should be in entries")
+ });
+ sidebar.update_in(cx, |sidebar, _window, _cx| {
+ sidebar.selection = Some(thread_index);
+ });
+ cx.dispatch_action(Confirm);
+ cx.run_until_parked();
+
+ let new_workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
+ let new_path_list =
+ new_workspace.read_with(cx, |_, cx| workspace_path_list(&new_workspace, cx));
+ assert_eq!(
+ new_path_list,
+ PathList::new(&[PathBuf::from("/project")]),
+ "the new workspace should be for the main repo, not the linked worktree",
+ );
+}
+
mod property_test {
use super::*;
@@ -4943,7 +5246,12 @@ mod property_test {
match operation {
Operation::SaveThread { workspace_index } => {
let project = multi_workspace.read_with(cx, |mw, cx| {
- mw.workspaces()[workspace_index].read(cx).project().clone()
+ mw.workspaces()
+ .nth(workspace_index)
+ .unwrap()
+ .read(cx)
+ .project()
+ .clone()
});
save_thread_to_path(state, &project, cx);
}
@@ -5030,7 +5338,7 @@ mod property_test {
}
Operation::RemoveWorkspace { index } => {
let removed = multi_workspace.update_in(cx, |mw, window, cx| {
- let workspace = mw.workspaces()[index].clone();
+ let workspace = mw.workspaces().nth(index).unwrap().clone();
mw.remove(&workspace, window, cx)
});
if removed {
@@ -5044,8 +5352,8 @@ mod property_test {
}
}
Operation::SwitchWorkspace { index } => {
- let workspace =
- multi_workspace.read_with(cx, |mw, _| mw.workspaces()[index].clone());
+ let workspace = multi_workspace
+ .read_with(cx, |mw, _| mw.workspaces().nth(index).unwrap().clone());
multi_workspace.update_in(cx, |mw, window, cx| {
mw.activate(workspace, window, cx);
});
@@ -850,6 +850,7 @@ impl TerminalView {
fn send_text(&mut self, text: &SendText, _: &mut Window, cx: &mut Context<Self>) {
self.clear_bell(cx);
+ self.blink_manager.update(cx, BlinkManager::pause_blinking);
self.terminal.update(cx, |term, _| {
term.input(text.0.to_string().into_bytes());
});
@@ -858,6 +859,7 @@ impl TerminalView {
fn send_keystroke(&mut self, text: &SendKeystroke, _: &mut Window, cx: &mut Context<Self>) {
if let Some(keystroke) = Keystroke::parse(&text.0).log_err() {
self.clear_bell(cx);
+ self.blink_manager.update(cx, BlinkManager::pause_blinking);
self.process_keystroke(&keystroke, cx);
}
}
@@ -740,7 +740,6 @@ impl TitleBar {
.map(|mw| {
mw.read(cx)
.workspaces()
- .iter()
.filter_map(|ws| ws.read(cx).database_id())
.collect()
})
@@ -803,7 +802,6 @@ impl TitleBar {
.map(|mw| {
mw.read(cx)
.workspaces()
- .iter()
.filter_map(|ws| ws.read(cx).database_id())
.collect()
})
@@ -7,7 +7,7 @@ use editor::{
},
};
use gpui::{Action, Context, Window, actions, px};
-use language::{CharKind, Point, Selection, SelectionGoal};
+use language::{CharKind, Point, Selection, SelectionGoal, TextObject, TreeSitterOptions};
use multi_buffer::MultiBufferRow;
use schemars::JsonSchema;
use serde::Deserialize;
@@ -2451,6 +2451,10 @@ fn find_matching_bracket_text_based(
.take_while(|(_, char_offset)| *char_offset < line_range.end)
.find_map(|(ch, char_offset)| get_bracket_pair(ch).map(|info| (info, char_offset)));
+ if bracket_info.is_none() {
+ return find_matching_c_preprocessor_directive(map, line_range);
+ }
+
let (open, close, is_opening) = bracket_info?.0;
let bracket_offset = bracket_info?.1;
@@ -2482,6 +2486,122 @@ fn find_matching_bracket_text_based(
None
}
+fn find_matching_c_preprocessor_directive(
+ map: &DisplaySnapshot,
+ line_range: Range<MultiBufferOffset>,
+) -> Option<MultiBufferOffset> {
+ let line_start = map
+ .buffer_chars_at(line_range.start)
+ .skip_while(|(c, _)| *c == ' ' || *c == '\t')
+ .map(|(c, _)| c)
+ .take(6)
+ .collect::<String>();
+
+ if line_start.starts_with("#if")
+ || line_start.starts_with("#else")
+ || line_start.starts_with("#elif")
+ {
+ let mut depth = 0i32;
+ for (ch, char_offset) in map.buffer_chars_at(line_range.end) {
+ if ch != '\n' {
+ continue;
+ }
+ let mut line_offset = char_offset + '\n'.len_utf8();
+
+ // Skip leading whitespace
+ map.buffer_chars_at(line_offset)
+ .take_while(|(c, _)| *c == ' ' || *c == '\t')
+ .for_each(|(_, _)| line_offset += 1);
+
+ // Check what directive starts the next line
+ let next_line_start = map
+ .buffer_chars_at(line_offset)
+ .map(|(c, _)| c)
+ .take(6)
+ .collect::<String>();
+
+ if next_line_start.starts_with("#if") {
+ depth += 1;
+ } else if next_line_start.starts_with("#endif") {
+ if depth > 0 {
+ depth -= 1;
+ } else {
+ return Some(line_offset);
+ }
+ } else if next_line_start.starts_with("#else") || next_line_start.starts_with("#elif") {
+ if depth == 0 {
+ return Some(line_offset);
+ }
+ }
+ }
+ } else if line_start.starts_with("#endif") {
+ let mut depth = 0i32;
+ for (ch, char_offset) in
+ map.reverse_buffer_chars_at(line_range.start.saturating_sub_usize(1))
+ {
+ let mut line_offset = if char_offset == MultiBufferOffset(0) {
+ MultiBufferOffset(0)
+ } else if ch != '\n' {
+ continue;
+ } else {
+ char_offset + '\n'.len_utf8()
+ };
+
+ // Skip leading whitespace
+ map.buffer_chars_at(line_offset)
+ .take_while(|(c, _)| *c == ' ' || *c == '\t')
+ .for_each(|(_, _)| line_offset += 1);
+
+ // Check what directive starts this line
+ let line_start = map
+ .buffer_chars_at(line_offset)
+ .skip_while(|(c, _)| *c == ' ' || *c == '\t')
+ .map(|(c, _)| c)
+ .take(6)
+ .collect::<String>();
+
+ if line_start.starts_with("\n\n") {
+ // empty line
+ continue;
+ } else if line_start.starts_with("#endif") {
+ depth += 1;
+ } else if line_start.starts_with("#if") {
+ if depth > 0 {
+ depth -= 1;
+ } else {
+ return Some(line_offset);
+ }
+ }
+ }
+ }
+ None
+}
+
+fn comment_delimiter_pair(
+ map: &DisplaySnapshot,
+ offset: MultiBufferOffset,
+) -> Option<(Range<MultiBufferOffset>, Range<MultiBufferOffset>)> {
+ let snapshot = map.buffer_snapshot();
+ snapshot
+ .text_object_ranges(offset..offset, TreeSitterOptions::default())
+ .find_map(|(range, obj)| {
+ if !matches!(obj, TextObject::InsideComment | TextObject::AroundComment)
+ || !range.contains(&offset)
+ {
+ return None;
+ }
+
+ let mut chars = snapshot.chars_at(range.start);
+ if (Some('/'), Some('*')) != (chars.next(), chars.next()) {
+ return None;
+ }
+
+ let open_range = range.start..range.start + 2usize;
+ let close_range = range.end - 2..range.end;
+ Some((open_range, close_range))
+ })
+}
+
fn matching(
map: &DisplaySnapshot,
display_point: DisplayPoint,
@@ -2609,6 +2729,32 @@ fn matching(
continue;
}
+ if let Some((open_range, close_range)) = comment_delimiter_pair(map, offset) {
+ if open_range.contains(&offset) {
+ return close_range.start.to_display_point(map);
+ }
+
+ if close_range.contains(&offset) {
+ return open_range.start.to_display_point(map);
+ }
+
+ let open_candidate = (open_range.start >= offset
+ && line_range.contains(&open_range.start))
+ .then_some((open_range.start.saturating_sub(offset), close_range.start));
+
+ let close_candidate = (close_range.start >= offset
+ && line_range.contains(&close_range.start))
+ .then_some((close_range.start.saturating_sub(offset), open_range.start));
+
+ if let Some((_, destination)) = [open_candidate, close_candidate]
+ .into_iter()
+ .flatten()
+ .min_by_key(|(distance, _)| *distance)
+ {
+ return destination.to_display_point(map);
+ }
+ }
+
closest_pair_destination
.map(|destination| destination.to_display_point(map))
.unwrap_or_else(|| {
@@ -3497,6 +3643,119 @@ mod test {
);
}
+ #[gpui::test]
+ async fn test_matching_comments(cx: &mut gpui::TestAppContext) {
+ let mut cx = NeovimBackedTestContext::new(cx).await;
+
+ cx.set_shared_state(indoc! {r"ˇ/*
+ this is a comment
+ */"})
+ .await;
+ cx.simulate_shared_keystrokes("%").await;
+ cx.shared_state().await.assert_eq(indoc! {r"/*
+ this is a comment
+ ˇ*/"});
+ cx.simulate_shared_keystrokes("%").await;
+ cx.shared_state().await.assert_eq(indoc! {r"ˇ/*
+ this is a comment
+ */"});
+ cx.simulate_shared_keystrokes("%").await;
+ cx.shared_state().await.assert_eq(indoc! {r"/*
+ this is a comment
+ ˇ*/"});
+
+ cx.set_shared_state("ˇ// comment").await;
+ cx.simulate_shared_keystrokes("%").await;
+ cx.shared_state().await.assert_eq("ˇ// comment");
+ }
+
+ #[gpui::test]
+ async fn test_matching_preprocessor_directives(cx: &mut gpui::TestAppContext) {
+ let mut cx = NeovimBackedTestContext::new(cx).await;
+
+ cx.set_shared_state(indoc! {r"#ˇif
+
+ #else
+
+ #endif
+ "})
+ .await;
+ cx.simulate_shared_keystrokes("%").await;
+ cx.shared_state().await.assert_eq(indoc! {r"#if
+
+ ˇ#else
+
+ #endif
+ "});
+
+ cx.simulate_shared_keystrokes("%").await;
+ cx.shared_state().await.assert_eq(indoc! {r"#if
+
+ #else
+
+ ˇ#endif
+ "});
+
+ cx.simulate_shared_keystrokes("%").await;
+ cx.shared_state().await.assert_eq(indoc! {r"ˇ#if
+
+ #else
+
+ #endif
+ "});
+
+ cx.set_shared_state(indoc! {r"
+ #ˇif
+ #if
+
+ #else
+
+ #endif
+
+ #else
+ #endif
+ "})
+ .await;
+
+ cx.simulate_shared_keystrokes("%").await;
+ cx.shared_state().await.assert_eq(indoc! {r"
+ #if
+ #if
+
+ #else
+
+ #endif
+
+ ˇ#else
+ #endif
+ "});
+
+ cx.simulate_shared_keystrokes("% %").await;
+ cx.shared_state().await.assert_eq(indoc! {r"
+ ˇ#if
+ #if
+
+ #else
+
+ #endif
+
+ #else
+ #endif
+ "});
+ cx.simulate_shared_keystrokes("j % % %").await;
+ cx.shared_state().await.assert_eq(indoc! {r"
+ #if
+ ˇ#if
+
+ #else
+
+ #endif
+
+ #else
+ #endif
+ "});
+ }
+
#[gpui::test]
async fn test_unmatched_forward(cx: &mut gpui::TestAppContext) {
let mut cx = NeovimBackedTestContext::new(cx).await;
@@ -245,7 +245,7 @@ impl Vim {
search_bar.set_replacement(None, cx);
let mut options = SearchOptions::NONE;
- if action.regex {
+ if action.regex && VimSettings::get_global(cx).use_regex_search {
options |= SearchOptions::REGEX;
}
if action.backwards {
@@ -1446,4 +1446,66 @@ mod test {
// The cursor should be at the match location on line 3 (row 2).
cx.assert_state("hello world\nfoo bar\nhello ˇagain\n", Mode::Normal);
}
+
+ #[gpui::test]
+ async fn test_vim_search_respects_search_settings(cx: &mut gpui::TestAppContext) {
+ let mut cx = VimTestContext::new(cx, true).await;
+
+ cx.update_global(|store: &mut SettingsStore, cx| {
+ store.update_user_settings(cx, |settings| {
+ settings.vim.get_or_insert_default().use_regex_search = Some(false);
+ });
+ });
+
+ cx.set_state("ˇcontent", Mode::Normal);
+ cx.simulate_keystrokes("/");
+ cx.run_until_parked();
+
+ // Verify search options are set from settings
+ let search_bar = cx.workspace(|workspace, _, cx| {
+ workspace
+ .active_pane()
+ .read(cx)
+ .toolbar()
+ .read(cx)
+ .item_of_type::<BufferSearchBar>()
+ .expect("Buffer search bar should be active")
+ });
+
+ cx.update_entity(search_bar, |bar, _window, _cx| {
+ assert!(
+ !bar.has_search_option(search::SearchOptions::REGEX),
+ "Vim search open without regex mode"
+ );
+ });
+
+ cx.simulate_keystrokes("escape");
+ cx.run_until_parked();
+
+ cx.update_global(|store: &mut SettingsStore, cx| {
+ store.update_user_settings(cx, |settings| {
+ settings.vim.get_or_insert_default().use_regex_search = Some(true);
+ });
+ });
+
+ cx.simulate_keystrokes("/");
+ cx.run_until_parked();
+
+ let search_bar = cx.workspace(|workspace, _, cx| {
+ workspace
+ .active_pane()
+ .read(cx)
+ .toolbar()
+ .read(cx)
+ .item_of_type::<BufferSearchBar>()
+ .expect("Buffer search bar should be active")
+ });
+
+ cx.update_entity(search_bar, |bar, _window, _cx| {
+ assert!(
+ bar.has_search_option(search::SearchOptions::REGEX),
+ "Vim search opens with regex mode"
+ );
+ });
+ }
}
@@ -2141,6 +2141,7 @@ struct VimSettings {
pub toggle_relative_line_numbers: bool,
pub use_system_clipboard: settings::UseSystemClipboard,
pub use_smartcase_find: bool,
+ pub use_regex_search: bool,
pub gdefault: bool,
pub custom_digraphs: HashMap<String, Arc<str>>,
pub highlight_on_yank_duration: u64,
@@ -2227,6 +2228,7 @@ impl Settings for VimSettings {
toggle_relative_line_numbers: vim.toggle_relative_line_numbers.unwrap(),
use_system_clipboard: vim.use_system_clipboard.unwrap(),
use_smartcase_find: vim.use_smartcase_find.unwrap(),
+ use_regex_search: vim.use_regex_search.unwrap(),
gdefault: vim.gdefault.unwrap(),
custom_digraphs: vim.custom_digraphs.unwrap(),
highlight_on_yank_duration: vim.highlight_on_yank_duration.unwrap(),
@@ -0,0 +1,10 @@
+{"Put":{"state":"ˇ/*\n this is a comment\n*/"}}
+{"Key":"%"}
+{"Get":{"state":"/*\n this is a comment\nˇ*/","mode":"Normal"}}
+{"Key":"%"}
+{"Get":{"state":"ˇ/*\n this is a comment\n*/","mode":"Normal"}}
+{"Key":"%"}
+{"Get":{"state":"/*\n this is a comment\nˇ*/","mode":"Normal"}}
+{"Put":{"state":"ˇ// comment"}}
+{"Key":"%"}
+{"Get":{"state":"ˇ// comment","mode":"Normal"}}
@@ -0,0 +1,18 @@
+{"Put":{"state":"#ˇif\n\n#else\n\n#endif\n"}}
+{"Key":"%"}
+{"Get":{"state":"#if\n\nˇ#else\n\n#endif\n","mode":"Normal"}}
+{"Key":"%"}
+{"Get":{"state":"#if\n\n#else\n\nˇ#endif\n","mode":"Normal"}}
+{"Key":"%"}
+{"Get":{"state":"ˇ#if\n\n#else\n\n#endif\n","mode":"Normal"}}
+{"Put":{"state":"#ˇif\n #if\n\n #else\n\n #endif\n\n#else\n#endif\n"}}
+{"Key":"%"}
+{"Get":{"state":"#if\n #if\n\n #else\n\n #endif\n\nˇ#else\n#endif\n","mode":"Normal"}}
+{"Key":"%"}
+{"Key":"%"}
+{"Get":{"state":"ˇ#if\n #if\n\n #else\n\n #endif\n\n#else\n#endif\n","mode":"Normal"}}
+{"Key":"j"}
+{"Key":"%"}
+{"Key":"%"}
+{"Key":"%"}
+{"Get":{"state":"#if\n ˇ#if\n\n #else\n\n #endif\n\n#else\n#endif\n","mode":"Normal"}}
@@ -6,9 +6,7 @@ use gpui::{
ManagedView, MouseButton, Pixels, Render, Subscription, Task, Tiling, Window, WindowId,
actions, deferred, px,
};
-#[cfg(any(test, feature = "test-support"))]
-use project::Project;
-use project::{DirectoryLister, DisableAiSettings, ProjectGroupKey};
+use project::{DirectoryLister, DisableAiSettings, Project, ProjectGroupKey};
use settings::Settings;
pub use settings::SidebarSide;
use std::future::Future;
@@ -42,10 +40,7 @@ actions!(
CloseWorkspaceSidebar,
/// Moves focus to or from the workspace sidebar without closing it.
FocusWorkspaceSidebar,
- /// Switches to the next workspace.
- NextWorkspace,
- /// Switches to the previous workspace.
- PreviousWorkspace,
+ //TODO: Restore next/previous workspace
]
);
@@ -223,10 +218,57 @@ impl<T: Sidebar> SidebarHandle for Entity<T> {
}
}
+/// Tracks which workspace the user is currently looking at.
+///
+/// `Persistent` workspaces live in the `workspaces` vec and are shown in the
+/// sidebar. `Transient` workspaces exist outside the vec and are discarded
+/// when the user switches away.
+enum ActiveWorkspace {
+ /// A persistent workspace, identified by index into the `workspaces` vec.
+ Persistent(usize),
+ /// A workspace not in the `workspaces` vec that will be discarded on
+ /// switch or promoted to persistent when the sidebar is opened.
+ Transient(Entity<Workspace>),
+}
+
+impl ActiveWorkspace {
+ fn persistent_index(&self) -> Option<usize> {
+ match self {
+ Self::Persistent(index) => Some(*index),
+ Self::Transient(_) => None,
+ }
+ }
+
+ fn transient_workspace(&self) -> Option<&Entity<Workspace>> {
+ match self {
+ Self::Transient(workspace) => Some(workspace),
+ Self::Persistent(_) => None,
+ }
+ }
+
+ /// Sets the active workspace to transient, returning the previous
+ /// transient workspace (if any).
+ fn set_transient(&mut self, workspace: Entity<Workspace>) -> Option<Entity<Workspace>> {
+ match std::mem::replace(self, Self::Transient(workspace)) {
+ Self::Transient(old) => Some(old),
+ Self::Persistent(_) => None,
+ }
+ }
+
+ /// Sets the active workspace to persistent at the given index,
+ /// returning the previous transient workspace (if any).
+ fn set_persistent(&mut self, index: usize) -> Option<Entity<Workspace>> {
+ match std::mem::replace(self, Self::Persistent(index)) {
+ Self::Transient(workspace) => Some(workspace),
+ Self::Persistent(_) => None,
+ }
+ }
+}
+
pub struct MultiWorkspace {
window_id: WindowId,
workspaces: Vec<Entity<Workspace>>,
- active_workspace_index: usize,
+ active_workspace: ActiveWorkspace,
project_group_keys: Vec<ProjectGroupKey>,
sidebar: Option<Box<dyn SidebarHandle>>,
sidebar_open: bool,
@@ -262,12 +304,15 @@ impl MultiWorkspace {
}
});
let quit_subscription = cx.on_app_quit(Self::app_will_quit);
- let settings_subscription =
- cx.observe_global_in::<settings::SettingsStore>(window, |this, window, cx| {
- if DisableAiSettings::get_global(cx).disable_ai && this.sidebar_open {
- this.close_sidebar(window, cx);
+ let settings_subscription = cx.observe_global_in::<settings::SettingsStore>(window, {
+ let mut previous_disable_ai = DisableAiSettings::get_global(cx).disable_ai;
+ move |this, window, cx| {
+ if DisableAiSettings::get_global(cx).disable_ai != previous_disable_ai {
+ this.collapse_to_single_workspace(window, cx);
+ previous_disable_ai = DisableAiSettings::get_global(cx).disable_ai;
}
- });
+ }
+ });
Self::subscribe_to_workspace(&workspace, window, cx);
let weak_self = cx.weak_entity();
workspace.update(cx, |workspace, cx| {
@@ -275,9 +320,9 @@ impl MultiWorkspace {
});
Self {
window_id: window.window_handle().window_id(),
- project_group_keys: vec![workspace.read(cx).project_group_key(cx)],
- workspaces: vec![workspace],
- active_workspace_index: 0,
+ project_group_keys: Vec::new(),
+ workspaces: Vec::new(),
+ active_workspace: ActiveWorkspace::Transient(workspace),
sidebar: None,
sidebar_open: false,
sidebar_overlay: None,
@@ -339,7 +384,7 @@ impl MultiWorkspace {
return;
}
- if self.sidebar_open {
+ if self.sidebar_open() {
self.close_sidebar(window, cx);
} else {
self.open_sidebar(cx);
@@ -355,7 +400,7 @@ impl MultiWorkspace {
return;
}
- if self.sidebar_open {
+ if self.sidebar_open() {
self.close_sidebar(window, cx);
}
}
@@ -365,7 +410,7 @@ impl MultiWorkspace {
return;
}
- if self.sidebar_open {
+ if self.sidebar_open() {
let sidebar_is_focused = self
.sidebar
.as_ref()
@@ -390,8 +435,13 @@ impl MultiWorkspace {
pub fn open_sidebar(&mut self, cx: &mut Context<Self>) {
self.sidebar_open = true;
+ if let ActiveWorkspace::Transient(workspace) = &self.active_workspace {
+ let workspace = workspace.clone();
+ let index = self.promote_transient(workspace, cx);
+ self.active_workspace = ActiveWorkspace::Persistent(index);
+ }
let sidebar_focus_handle = self.sidebar.as_ref().map(|s| s.focus_handle(cx));
- for workspace in &self.workspaces {
+ for workspace in self.workspaces.iter() {
workspace.update(cx, |workspace, _cx| {
workspace.set_sidebar_focus_handle(sidebar_focus_handle.clone());
});
@@ -402,7 +452,7 @@ impl MultiWorkspace {
pub fn close_sidebar(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.sidebar_open = false;
- for workspace in &self.workspaces {
+ for workspace in self.workspaces.iter() {
workspace.update(cx, |workspace, _cx| {
workspace.set_sidebar_focus_handle(None);
});
@@ -417,7 +467,7 @@ impl MultiWorkspace {
pub fn close_window(&mut self, _: &CloseWindow, window: &mut Window, cx: &mut Context<Self>) {
cx.spawn_in(window, async move |this, cx| {
let workspaces = this.update(cx, |multi_workspace, _cx| {
- multi_workspace.workspaces().to_vec()
+ multi_workspace.workspaces().cloned().collect::<Vec<_>>()
})?;
for workspace in workspaces {
@@ -468,6 +518,9 @@ impl MultiWorkspace {
}
pub fn add_project_group_key(&mut self, project_group_key: ProjectGroupKey) {
+ if project_group_key.path_list().paths().is_empty() {
+ return;
+ }
if self.project_group_keys.contains(&project_group_key) {
return;
}
@@ -649,13 +702,19 @@ impl MultiWorkspace {
if let Some(workspace) = self
.workspaces
.iter()
- .find(|ws| ws.read(cx).project_group_key(cx).path_list() == &path_list)
+ .find(|ws| PathList::new(&ws.read(cx).root_paths(cx)) == path_list)
.cloned()
{
self.activate(workspace.clone(), window, cx);
return Task::ready(Ok(workspace));
}
+ if let Some(transient) = self.active_workspace.transient_workspace() {
+ if transient.read(cx).project_group_key(cx).path_list() == &path_list {
+ return Task::ready(Ok(transient.clone()));
+ }
+ }
+
let paths = path_list.paths().to_vec();
let app_state = self.workspace().read(cx).app_state().clone();
let requesting_window = window.window_handle().downcast::<MultiWorkspace>();
@@ -679,25 +738,23 @@ impl MultiWorkspace {
}
pub fn workspace(&self) -> &Entity<Workspace> {
- &self.workspaces[self.active_workspace_index]
- }
-
- pub fn workspaces(&self) -> &[Entity<Workspace>] {
- &self.workspaces
+ match &self.active_workspace {
+ ActiveWorkspace::Persistent(index) => &self.workspaces[*index],
+ ActiveWorkspace::Transient(workspace) => workspace,
+ }
}
- pub fn active_workspace_index(&self) -> usize {
- self.active_workspace_index
+ pub fn workspaces(&self) -> impl Iterator<Item = &Entity<Workspace>> {
+ self.workspaces
+ .iter()
+ .chain(self.active_workspace.transient_workspace())
}
- /// Adds a workspace to this window without changing which workspace is
- /// active.
+ /// Adds a workspace to this window as persistent without changing which
+ /// workspace is active. Unlike `activate()`, this always inserts into the
+ /// persistent list regardless of sidebar state — it's used for system-
+ /// initiated additions like deserialization and worktree discovery.
pub fn add(&mut self, workspace: Entity<Workspace>, window: &Window, cx: &mut Context<Self>) {
- if !self.multi_workspace_enabled(cx) {
- self.set_single_workspace(workspace, cx);
- return;
- }
-
self.insert_workspace(workspace, window, cx);
}
@@ -708,26 +765,74 @@ impl MultiWorkspace {
window: &mut Window,
cx: &mut Context<Self>,
) {
- if !self.multi_workspace_enabled(cx) {
- self.set_single_workspace(workspace, cx);
+ // Re-activating the current workspace is a no-op.
+ if self.workspace() == &workspace {
+ self.focus_active_workspace(window, cx);
return;
}
- let index = self.insert_workspace(workspace, &*window, cx);
- let changed = self.active_workspace_index != index;
- self.active_workspace_index = index;
- if changed {
- cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged);
- self.serialize(cx);
+ // Resolve where we're going.
+ let new_index = if let Some(index) = self.workspaces.iter().position(|w| *w == workspace) {
+ Some(index)
+ } else if self.sidebar_open {
+ Some(self.insert_workspace(workspace.clone(), &*window, cx))
+ } else {
+ None
+ };
+
+ // Transition the active workspace.
+ if let Some(index) = new_index {
+ if let Some(old) = self.active_workspace.set_persistent(index) {
+ if self.sidebar_open {
+ self.promote_transient(old, cx);
+ } else {
+ self.detach_workspace(&old, cx);
+ cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(old.entity_id()));
+ }
+ }
+ } else {
+ Self::subscribe_to_workspace(&workspace, window, cx);
+ let weak_self = cx.weak_entity();
+ workspace.update(cx, |workspace, cx| {
+ workspace.set_multi_workspace(weak_self, cx);
+ });
+ if let Some(old) = self.active_workspace.set_transient(workspace) {
+ self.detach_workspace(&old, cx);
+ cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(old.entity_id()));
+ }
}
+
+ cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged);
+ self.serialize(cx);
self.focus_active_workspace(window, cx);
cx.notify();
}
- fn set_single_workspace(&mut self, workspace: Entity<Workspace>, cx: &mut Context<Self>) {
- self.workspaces[0] = workspace;
- self.active_workspace_index = 0;
- cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged);
+ /// Promotes a former transient workspace into the persistent list.
+ /// Returns the index of the newly inserted workspace.
+ fn promote_transient(&mut self, workspace: Entity<Workspace>, cx: &mut Context<Self>) -> usize {
+ let project_group_key = workspace.read(cx).project().read(cx).project_group_key(cx);
+ self.add_project_group_key(project_group_key);
+ self.workspaces.push(workspace.clone());
+ cx.emit(MultiWorkspaceEvent::WorkspaceAdded(workspace));
+ self.workspaces.len() - 1
+ }
+
+ /// Collapses to a single transient workspace, discarding all persistent
+ /// workspaces. Used when multi-workspace is disabled (e.g. disable_ai).
+ fn collapse_to_single_workspace(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ if self.sidebar_open {
+ self.close_sidebar(window, cx);
+ }
+ let active = self.workspace().clone();
+ for workspace in std::mem::take(&mut self.workspaces) {
+ if workspace != active {
+ self.detach_workspace(&workspace, cx);
+ cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(workspace.entity_id()));
+ }
+ }
+ self.project_group_keys.clear();
+ self.active_workspace = ActiveWorkspace::Transient(active);
cx.notify();
}
@@ -783,7 +888,7 @@ impl MultiWorkspace {
}
fn sync_sidebar_to_workspace(&self, workspace: &Entity<Workspace>, cx: &mut Context<Self>) {
- if self.sidebar_open {
+ if self.sidebar_open() {
let sidebar_focus_handle = self.sidebar.as_ref().map(|s| s.focus_handle(cx));
workspace.update(cx, |workspace, _| {
workspace.set_sidebar_focus_handle(sidebar_focus_handle);
@@ -791,30 +896,6 @@ impl MultiWorkspace {
}
}
- fn cycle_workspace(&mut self, delta: isize, window: &mut Window, cx: &mut Context<Self>) {
- let count = self.workspaces.len() as isize;
- if count <= 1 {
- return;
- }
- let current = self.active_workspace_index as isize;
- let next = ((current + delta).rem_euclid(count)) as usize;
- let workspace = self.workspaces[next].clone();
- self.activate(workspace, window, cx);
- }
-
- fn next_workspace(&mut self, _: &NextWorkspace, window: &mut Window, cx: &mut Context<Self>) {
- self.cycle_workspace(1, window, cx);
- }
-
- fn previous_workspace(
- &mut self,
- _: &PreviousWorkspace,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- self.cycle_workspace(-1, window, cx);
- }
-
pub(crate) fn serialize(&mut self, cx: &mut Context<Self>) {
self._serialize_task = Some(cx.spawn(async move |this, cx| {
let Some((window_id, state)) = this
@@ -1040,26 +1121,82 @@ impl MultiWorkspace {
let Some(index) = self.workspaces.iter().position(|w| w == workspace) else {
return false;
};
+
+ let old_key = workspace.read(cx).project_group_key(cx);
+
if self.workspaces.len() <= 1 {
- return false;
- }
+ let has_worktrees = workspace.read(cx).visible_worktrees(cx).next().is_some();
- let removed_workspace = self.workspaces.remove(index);
+ if !has_worktrees {
+ return false;
+ }
- if self.active_workspace_index >= self.workspaces.len() {
- self.active_workspace_index = self.workspaces.len() - 1;
- } else if self.active_workspace_index > index {
- self.active_workspace_index -= 1;
+ let old_workspace = workspace.clone();
+ let old_entity_id = old_workspace.entity_id();
+
+ let app_state = old_workspace.read(cx).app_state().clone();
+
+ let project = Project::local(
+ app_state.client.clone(),
+ app_state.node_runtime.clone(),
+ app_state.user_store.clone(),
+ app_state.languages.clone(),
+ app_state.fs.clone(),
+ None,
+ project::LocalProjectFlags::default(),
+ cx,
+ );
+
+ let new_workspace = cx.new(|cx| Workspace::new(None, project, app_state, window, cx));
+
+ self.workspaces[0] = new_workspace.clone();
+ self.active_workspace = ActiveWorkspace::Persistent(0);
+
+ Self::subscribe_to_workspace(&new_workspace, window, cx);
+
+ self.sync_sidebar_to_workspace(&new_workspace, cx);
+
+ let weak_self = cx.weak_entity();
+
+ new_workspace.update(cx, |workspace, cx| {
+ workspace.set_multi_workspace(weak_self, cx);
+ });
+
+ self.detach_workspace(&old_workspace, cx);
+
+ cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(old_entity_id));
+ cx.emit(MultiWorkspaceEvent::WorkspaceAdded(new_workspace));
+ cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged);
+ } else {
+ let removed_workspace = self.workspaces.remove(index);
+
+ if let Some(active_index) = self.active_workspace.persistent_index() {
+ if active_index >= self.workspaces.len() {
+ self.active_workspace = ActiveWorkspace::Persistent(self.workspaces.len() - 1);
+ } else if active_index > index {
+ self.active_workspace = ActiveWorkspace::Persistent(active_index - 1);
+ }
+ }
+
+ self.detach_workspace(&removed_workspace, cx);
+
+ cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(
+ removed_workspace.entity_id(),
+ ));
+ cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged);
}
- self.detach_workspace(&removed_workspace, cx);
+ let key_still_in_use = self
+ .workspaces
+ .iter()
+ .any(|ws| ws.read(cx).project_group_key(cx) == old_key);
+
+ if !key_still_in_use {
+ self.project_group_keys.retain(|k| k != &old_key);
+ }
self.serialize(cx);
self.focus_active_workspace(window, cx);
- cx.emit(MultiWorkspaceEvent::WorkspaceRemoved(
- removed_workspace.entity_id(),
- ));
- cx.emit(MultiWorkspaceEvent::ActiveWorkspaceChanged);
cx.notify();
true
@@ -1288,8 +1425,6 @@ impl Render for MultiWorkspace {
this.focus_sidebar(window, cx);
},
))
- .on_action(cx.listener(Self::next_workspace))
- .on_action(cx.listener(Self::previous_workspace))
.on_action(cx.listener(Self::move_active_workspace_to_new_window))
.on_action(cx.listener(
|this: &mut Self, action: &ToggleThreadSwitcher, window, cx| {
@@ -99,6 +99,10 @@ async fn test_project_group_keys_initial(cx: &mut TestAppContext) {
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project, window, cx));
+ multi_workspace.update(cx, |mw, cx| {
+ mw.open_sidebar(cx);
+ });
+
multi_workspace.read_with(cx, |mw, _cx| {
let keys: Vec<&ProjectGroupKey> = mw.project_group_keys().collect();
assert_eq!(keys.len(), 1, "should have exactly one key on creation");
@@ -125,6 +129,10 @@ async fn test_project_group_keys_add_workspace(cx: &mut TestAppContext) {
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx));
+ multi_workspace.update(cx, |mw, cx| {
+ mw.open_sidebar(cx);
+ });
+
multi_workspace.read_with(cx, |mw, _cx| {
assert_eq!(mw.project_group_keys().count(), 1);
});
@@ -162,6 +170,10 @@ async fn test_project_group_keys_duplicate_not_added(cx: &mut TestAppContext) {
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a, window, cx));
+ multi_workspace.update(cx, |mw, cx| {
+ mw.open_sidebar(cx);
+ });
+
multi_workspace.update_in(cx, |mw, window, cx| {
mw.test_add_workspace(project_a2, window, cx);
});
@@ -189,6 +201,10 @@ async fn test_project_group_keys_on_worktree_added(cx: &mut TestAppContext) {
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ multi_workspace.update(cx, |mw, cx| {
+ mw.open_sidebar(cx);
+ });
+
// Add a second worktree to the same project.
let (worktree, _) = project
.update(cx, |project, cx| {
@@ -232,6 +248,10 @@ async fn test_project_group_keys_on_worktree_removed(cx: &mut TestAppContext) {
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
+ multi_workspace.update(cx, |mw, cx| {
+ mw.open_sidebar(cx);
+ });
+
// Remove one worktree.
let worktree_b_id = project.read_with(cx, |project, cx| {
project
@@ -282,6 +302,10 @@ async fn test_project_group_keys_across_multiple_workspaces_and_worktree_changes
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project_a.clone(), window, cx));
+ multi_workspace.update(cx, |mw, cx| {
+ mw.open_sidebar(cx);
+ });
+
multi_workspace.update_in(cx, |mw, window, cx| {
mw.test_add_workspace(project_b, window, cx);
});
@@ -3670,6 +3670,11 @@ impl Pane {
this.drag_split_direction = None;
this.handle_external_paths_drop(paths, window, cx)
}))
+ .on_click(cx.listener(move |this, event: &ClickEvent, window, cx| {
+ if event.click_count() == 2 {
+ window.dispatch_action(this.double_click_dispatch_action.boxed_clone(), cx);
+ }
+ }))
}
pub fn render_menu_overlay(menu: &Entity<ContextMenu>) -> Div {
@@ -4917,14 +4922,17 @@ impl Render for DraggedTab {
#[cfg(test)]
mod tests {
- use std::{cell::Cell, iter::zip, num::NonZero};
+ use std::{cell::Cell, iter::zip, num::NonZero, rc::Rc};
use super::*;
use crate::{
Member,
item::test::{TestItem, TestProjectItem},
};
- use gpui::{AppContext, Axis, TestAppContext, VisualTestContext, size};
+ use gpui::{
+ AppContext, Axis, Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent,
+ TestAppContext, VisualTestContext, size,
+ };
use project::FakeFs;
use settings::SettingsStore;
use theme::LoadThemes;
@@ -6649,8 +6657,6 @@ mod tests {
#[gpui::test]
async fn test_drag_tab_to_middle_tab_with_mouse_events(cx: &mut TestAppContext) {
- use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent};
-
init_test(cx);
let fs = FakeFs::new(cx.executor());
@@ -6702,8 +6708,6 @@ mod tests {
async fn test_drag_pinned_tab_when_show_pinned_tabs_in_separate_row_enabled(
cx: &mut TestAppContext,
) {
- use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent};
-
init_test(cx);
set_pinned_tabs_separate_row(cx, true);
let fs = FakeFs::new(cx.executor());
@@ -6779,8 +6783,6 @@ mod tests {
async fn test_drag_unpinned_tab_when_show_pinned_tabs_in_separate_row_enabled(
cx: &mut TestAppContext,
) {
- use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent};
-
init_test(cx);
set_pinned_tabs_separate_row(cx, true);
let fs = FakeFs::new(cx.executor());
@@ -6833,8 +6835,6 @@ mod tests {
async fn test_drag_mixed_tabs_when_show_pinned_tabs_in_separate_row_enabled(
cx: &mut TestAppContext,
) {
- use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent};
-
init_test(cx);
set_pinned_tabs_separate_row(cx, true);
let fs = FakeFs::new(cx.executor());
@@ -6900,8 +6900,6 @@ mod tests {
#[gpui::test]
async fn test_middle_click_pinned_tab_does_not_close(cx: &mut TestAppContext) {
- use gpui::{Modifiers, MouseButton, MouseDownEvent, MouseUpEvent};
-
init_test(cx);
let fs = FakeFs::new(cx.executor());
@@ -6971,6 +6969,74 @@ mod tests {
assert_item_labels(&pane, ["A*!"], cx);
}
+ #[gpui::test]
+ async fn test_double_click_pinned_tab_bar_empty_space_creates_new_tab(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+
+ let project = Project::test(fs, None, cx).await;
+ let (workspace, cx) =
+ cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+ let pane = workspace.read_with(cx, |workspace, _| workspace.active_pane().clone());
+
+ // The real NewFile handler lives in editor::init, which isn't initialized
+ // in workspace tests. Register a global action handler that sets a flag so
+ // we can verify the action is dispatched without depending on the editor crate.
+ // TODO: If editor::init is ever available in workspace tests, remove this
+ // flag and assert the resulting tab bar state directly instead.
+ let new_file_dispatched = Rc::new(Cell::new(false));
+ cx.update(|_, cx| {
+ let new_file_dispatched = new_file_dispatched.clone();
+ cx.on_action(move |_: &NewFile, _cx| {
+ new_file_dispatched.set(true);
+ });
+ });
+
+ set_pinned_tabs_separate_row(cx, true);
+
+ let item_a = add_labeled_item(&pane, "A", false, cx);
+ add_labeled_item(&pane, "B", false, cx);
+
+ pane.update_in(cx, |pane, window, cx| {
+ let ix = pane
+ .index_for_item_id(item_a.item_id())
+ .expect("item A should exist");
+ pane.pin_tab_at(ix, window, cx);
+ });
+ assert_item_labels(&pane, ["A!", "B*"], cx);
+ cx.run_until_parked();
+
+ let pinned_drop_target_bounds = cx
+ .debug_bounds("pinned_tabs_border")
+ .expect("pinned_tabs_border should have debug bounds");
+
+ cx.simulate_event(MouseDownEvent {
+ position: pinned_drop_target_bounds.center(),
+ button: MouseButton::Left,
+ modifiers: Modifiers::default(),
+ click_count: 2,
+ first_mouse: false,
+ });
+
+ cx.run_until_parked();
+
+ cx.simulate_event(MouseUpEvent {
+ position: pinned_drop_target_bounds.center(),
+ button: MouseButton::Left,
+ modifiers: Modifiers::default(),
+ click_count: 2,
+ });
+
+ cx.run_until_parked();
+
+ // TODO: If editor::init is ever available in workspace tests, replace this
+ // with an assert_item_labels check that verifies a new tab is actually created.
+ assert!(
+ new_file_dispatched.get(),
+ "Double-clicking pinned tab bar empty space should dispatch the new file action"
+ );
+ }
+
#[gpui::test]
async fn test_add_item_with_new_item(cx: &mut TestAppContext) {
init_test(cx);
@@ -2535,6 +2535,10 @@ mod tests {
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project1.clone(), window, cx));
+ multi_workspace.update(cx, |mw, cx| {
+ mw.open_sidebar(cx);
+ });
+
multi_workspace.update_in(cx, |mw, _, cx| {
mw.set_random_database_id(cx);
});
@@ -2564,7 +2568,7 @@ mod tests {
// --- Remove the second workspace (index 1) ---
multi_workspace.update_in(cx, |mw, window, cx| {
- let ws = mw.workspaces()[1].clone();
+ let ws = mw.workspaces().nth(1).unwrap().clone();
mw.remove(&ws, window, cx);
});
@@ -4191,6 +4195,10 @@ mod tests {
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project1.clone(), window, cx));
+ multi_workspace.update(cx, |mw, cx| {
+ mw.open_sidebar(cx);
+ });
+
multi_workspace.update_in(cx, |mw, _, cx| {
mw.set_random_database_id(cx);
});
@@ -4233,7 +4241,7 @@ mod tests {
// Remove workspace at index 1 (the second workspace).
multi_workspace.update_in(cx, |mw, window, cx| {
- let ws = mw.workspaces()[1].clone();
+ let ws = mw.workspaces().nth(1).unwrap().clone();
mw.remove(&ws, window, cx);
});
@@ -4288,6 +4296,10 @@ mod tests {
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project1.clone(), window, cx));
+ multi_workspace.update(cx, |mw, cx| {
+ mw.open_sidebar(cx);
+ });
+
multi_workspace.update_in(cx, |mw, _, cx| {
mw.workspace().update(cx, |ws, _cx| {
ws.set_database_id(ws1_id);
@@ -4339,7 +4351,7 @@ mod tests {
// Remove workspace2 (index 1).
multi_workspace.update_in(cx, |mw, window, cx| {
- let ws = mw.workspaces()[1].clone();
+ let ws = mw.workspaces().nth(1).unwrap().clone();
mw.remove(&ws, window, cx);
});
@@ -4385,6 +4397,10 @@ mod tests {
let (multi_workspace, cx) =
cx.add_window_view(|window, cx| MultiWorkspace::test_new(project1.clone(), window, cx));
+ multi_workspace.update(cx, |mw, cx| {
+ mw.open_sidebar(cx);
+ });
+
multi_workspace.update_in(cx, |mw, _, cx| {
mw.set_random_database_id(cx);
});
@@ -4418,7 +4434,7 @@ mod tests {
// Remove workspace2 — this pushes a task to pending_removal_tasks.
multi_workspace.update_in(cx, |mw, window, cx| {
- let ws = mw.workspaces()[1].clone();
+ let ws = mw.workspaces().nth(1).unwrap().clone();
mw.remove(&ws, window, cx);
});
@@ -4427,7 +4443,6 @@ mod tests {
let all_tasks = multi_workspace.update_in(cx, |mw, window, cx| {
let mut tasks: Vec<Task<()>> = mw
.workspaces()
- .iter()
.map(|workspace| {
workspace.update(cx, |workspace, cx| {
workspace.flush_serialization(window, cx)
@@ -4747,6 +4762,10 @@ mod tests {
let (multi_workspace, cx) = cx
.add_window_view(|window, cx| MultiWorkspace::test_new(project_2.clone(), window, cx));
+ multi_workspace.update(cx, |mw, cx| {
+ mw.open_sidebar(cx);
+ });
+
multi_workspace.update_in(cx, |mw, window, cx| {
mw.test_add_workspace(project_1.clone(), window, cx);
});
@@ -7,7 +7,7 @@ use std::{
};
use collections::{HashMap, HashSet};
-use gpui::{DismissEvent, EventEmitter, FocusHandle, Focusable, WeakEntity};
+use gpui::{DismissEvent, EventEmitter, FocusHandle, Focusable, ScrollHandle, WeakEntity};
use project::{
WorktreeId,
@@ -17,7 +17,8 @@ use project::{
use smallvec::SmallVec;
use theme::ActiveTheme;
use ui::{
- AlertModal, Checkbox, FluentBuilder, KeyBinding, ListBulletItem, ToggleState, prelude::*,
+ AlertModal, Checkbox, FluentBuilder, KeyBinding, ListBulletItem, ToggleState, WithScrollbar,
+ prelude::*,
};
use crate::{DismissDecision, ModalView, ToggleWorktreeSecurity};
@@ -29,6 +30,7 @@ pub struct SecurityModal {
worktree_store: WeakEntity<WorktreeStore>,
remote_host: Option<RemoteHostLocation>,
focus_handle: FocusHandle,
+ project_list_scroll_handle: ScrollHandle,
trusted: Option<bool>,
}
@@ -63,16 +65,17 @@ impl ModalView for SecurityModal {
}
impl Render for SecurityModal {
- fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
if self.restricted_paths.is_empty() {
self.dismiss(cx);
return v_flex().into_any_element();
}
- let header_label = if self.restricted_paths.len() == 1 {
- "Unrecognized Project"
+ let restricted_count = self.restricted_paths.len();
+ let header_label: SharedString = if restricted_count == 1 {
+ "Unrecognized Project".into()
} else {
- "Unrecognized Projects"
+ format!("Unrecognized Projects ({})", restricted_count).into()
};
let trust_label = self.build_trust_label();
@@ -102,32 +105,61 @@ impl Render for SecurityModal {
.child(Icon::new(IconName::Warning).color(Color::Warning))
.child(Label::new(header_label)),
)
- .children(self.restricted_paths.values().filter_map(|restricted_path| {
- let abs_path = if restricted_path.is_file {
- restricted_path.abs_path.parent()
- } else {
- Some(restricted_path.abs_path.as_ref())
- }?;
- let label = match &restricted_path.host {
- Some(remote_host) => match &remote_host.user_name {
- Some(user_name) => format!(
- "{} ({}@{})",
- self.shorten_path(abs_path).display(),
- user_name,
- remote_host.host_identifier
- ),
- None => format!(
- "{} ({})",
- self.shorten_path(abs_path).display(),
- remote_host.host_identifier
- ),
- },
- None => self.shorten_path(abs_path).display().to_string(),
- };
- Some(h_flex()
- .pl(IconSize::default().rems() + rems(0.5))
- .child(Label::new(label).color(Color::Muted)))
- })),
+ .child(
+ div()
+ .size_full()
+ .vertical_scrollbar_for(&self.project_list_scroll_handle, window, cx)
+ .child(
+ v_flex()
+ .id("paths_container")
+ .max_h_24()
+ .overflow_y_scroll()
+ .track_scroll(&self.project_list_scroll_handle)
+ .children(
+ self.restricted_paths.values().filter_map(
+ |restricted_path| {
+ let abs_path = if restricted_path.is_file {
+ restricted_path.abs_path.parent()
+ } else {
+ Some(restricted_path.abs_path.as_ref())
+ }?;
+ let label = match &restricted_path.host {
+ Some(remote_host) => {
+ match &remote_host.user_name {
+ Some(user_name) => format!(
+ "{} ({}@{})",
+ self.shorten_path(abs_path)
+ .display(),
+ user_name,
+ remote_host.host_identifier
+ ),
+ None => format!(
+ "{} ({})",
+ self.shorten_path(abs_path)
+ .display(),
+ remote_host.host_identifier
+ ),
+ }
+ }
+ None => self
+ .shorten_path(abs_path)
+ .display()
+ .to_string(),
+ };
+ Some(
+ h_flex()
+ .pl(
+ IconSize::default().rems() + rems(0.5),
+ )
+ .child(
+ Label::new(label).color(Color::Muted),
+ ),
+ )
+ },
+ ),
+ ),
+ ),
+ ),
)
.child(
v_flex()
@@ -219,6 +251,7 @@ impl SecurityModal {
remote_host: remote_host.map(|host| host.into()),
restricted_paths: HashMap::default(),
focus_handle: cx.focus_handle(),
+ project_list_scroll_handle: ScrollHandle::new(),
trust_parents: false,
home_dir: std::env::home_dir(),
trusted: None,
@@ -32,8 +32,8 @@ pub use crate::notifications::NotificationFrame;
pub use dock::Panel;
pub use multi_workspace::{
CloseWorkspaceSidebar, DraggedSidebar, FocusWorkspaceSidebar, MultiWorkspace,
- MultiWorkspaceEvent, NextWorkspace, PreviousWorkspace, Sidebar, SidebarEvent, SidebarHandle,
- SidebarRenderState, SidebarSide, ToggleWorkspaceSidebar, sidebar_side_context_menu,
+ MultiWorkspaceEvent, Sidebar, SidebarEvent, SidebarHandle, SidebarRenderState, SidebarSide,
+ ToggleWorkspaceSidebar, sidebar_side_context_menu,
};
pub use path_list::{PathList, SerializedPathList};
pub use toast_layer::{ToastAction, ToastLayer, ToastView};
@@ -9079,7 +9079,7 @@ pub fn workspace_windows_for_location(
};
multi_workspace.read(cx).is_ok_and(|multi_workspace| {
- multi_workspace.workspaces().iter().any(|workspace| {
+ multi_workspace.workspaces().any(|workspace| {
match workspace.read(cx).workspace_location(cx) {
WorkspaceLocation::Location(location, _) => {
match (&location, serialized_location) {
@@ -10741,6 +10741,12 @@ mod tests {
cx.add_window(|window, cx| MultiWorkspace::test_new(project_a.clone(), window, cx));
cx.run_until_parked();
+ multi_workspace_handle
+ .update(cx, |mw, _window, cx| {
+ mw.open_sidebar(cx);
+ })
+ .unwrap();
+
let workspace_a = multi_workspace_handle
.read_with(cx, |mw, _| mw.workspace().clone())
.unwrap();
@@ -10754,7 +10760,7 @@ mod tests {
// Activate workspace A
multi_workspace_handle
.update(cx, |mw, window, cx| {
- let workspace = mw.workspaces()[0].clone();
+ let workspace = mw.workspaces().next().unwrap().clone();
mw.activate(workspace, window, cx);
})
.unwrap();
@@ -10776,7 +10782,7 @@ mod tests {
// Verify workspace A is active
multi_workspace_handle
.read_with(cx, |mw, _| {
- assert_eq!(mw.active_workspace_index(), 0);
+ assert_eq!(mw.workspace(), &workspace_a);
})
.unwrap();
@@ -10792,8 +10798,8 @@ mod tests {
multi_workspace_handle
.read_with(cx, |mw, _| {
assert_eq!(
- mw.active_workspace_index(),
- 1,
+ mw.workspace(),
+ &workspace_b,
"workspace B should be activated when it prompts"
);
})
@@ -14511,6 +14517,12 @@ mod tests {
cx.add_window(|window, cx| MultiWorkspace::test_new(project_a.clone(), window, cx));
cx.run_until_parked();
+ multi_workspace_handle
+ .update(cx, |mw, _window, cx| {
+ mw.open_sidebar(cx);
+ })
+ .unwrap();
+
let workspace_a = multi_workspace_handle
.read_with(cx, |mw, _| mw.workspace().clone())
.unwrap();
@@ -14524,7 +14536,7 @@ mod tests {
// Switch to workspace A
multi_workspace_handle
.update(cx, |mw, window, cx| {
- let workspace = mw.workspaces()[0].clone();
+ let workspace = mw.workspaces().next().unwrap().clone();
mw.activate(workspace, window, cx);
})
.unwrap();
@@ -14570,7 +14582,7 @@ mod tests {
// Switch to workspace B
multi_workspace_handle
.update(cx, |mw, window, cx| {
- let workspace = mw.workspaces()[1].clone();
+ let workspace = mw.workspaces().nth(1).unwrap().clone();
mw.activate(workspace, window, cx);
})
.unwrap();
@@ -14579,7 +14591,7 @@ mod tests {
// Switch back to workspace A
multi_workspace_handle
.update(cx, |mw, window, cx| {
- let workspace = mw.workspaces()[0].clone();
+ let workspace = mw.workspaces().next().unwrap().clone();
mw.activate(workspace, window, cx);
})
.unwrap();
@@ -2606,7 +2606,7 @@ fn run_multi_workspace_sidebar_visual_tests(
// Add worktree to workspace 1 (index 0) so it shows as "private-test-remote"
let add_worktree1_task = multi_workspace_window
.update(cx, |multi_workspace, _window, cx| {
- let workspace1 = &multi_workspace.workspaces()[0];
+ let workspace1 = multi_workspace.workspaces().next().unwrap();
let project = workspace1.read(cx).project().clone();
project.update(cx, |project, cx| {
project.find_or_create_worktree(&workspace1_dir, true, cx)
@@ -2625,7 +2625,7 @@ fn run_multi_workspace_sidebar_visual_tests(
// Add worktree to workspace 2 (index 1) so it shows as "zed"
let add_worktree2_task = multi_workspace_window
.update(cx, |multi_workspace, _window, cx| {
- let workspace2 = &multi_workspace.workspaces()[1];
+ let workspace2 = multi_workspace.workspaces().nth(1).unwrap();
let project = workspace2.read(cx).project().clone();
project.update(cx, |project, cx| {
project.find_or_create_worktree(&workspace2_dir, true, cx)
@@ -2644,7 +2644,7 @@ fn run_multi_workspace_sidebar_visual_tests(
// Switch to workspace 1 so it's highlighted as active (index 0)
multi_workspace_window
.update(cx, |multi_workspace, window, cx| {
- let workspace = multi_workspace.workspaces()[0].clone();
+ let workspace = multi_workspace.workspaces().next().unwrap().clone();
multi_workspace.activate(workspace, window, cx);
})
.context("Failed to activate workspace 1")?;
@@ -2672,7 +2672,7 @@ fn run_multi_workspace_sidebar_visual_tests(
let save_tasks = multi_workspace_window
.update(cx, |multi_workspace, _window, cx| {
let thread_store = agent::ThreadStore::global(cx);
- let workspaces = multi_workspace.workspaces().to_vec();
+ let workspaces: Vec<_> = multi_workspace.workspaces().cloned().collect();
let mut tasks = Vec::new();
for (index, workspace) in workspaces.iter().enumerate() {
@@ -3211,7 +3211,7 @@ edition = "2021"
// Add the git project as a worktree
let add_worktree_task = workspace_window
.update(cx, |multi_workspace, _window, cx| {
- let workspace = &multi_workspace.workspaces()[0];
+ let workspace = multi_workspace.workspaces().next().unwrap();
let project = workspace.read(cx).project().clone();
project.update(cx, |project, cx| {
project.find_or_create_worktree(&project_path, true, cx)
@@ -3236,7 +3236,7 @@ edition = "2021"
// Open the project panel
let (weak_workspace, async_window_cx) = workspace_window
.update(cx, |multi_workspace, window, cx| {
- let workspace = &multi_workspace.workspaces()[0];
+ let workspace = multi_workspace.workspaces().next().unwrap();
(workspace.read(cx).weak_handle(), window.to_async(cx))
})
.context("Failed to get workspace handle")?;
@@ -3250,7 +3250,7 @@ edition = "2021"
workspace_window
.update(cx, |multi_workspace, window, cx| {
- let workspace = &multi_workspace.workspaces()[0];
+ let workspace = multi_workspace.workspaces().next().unwrap();
workspace.update(cx, |workspace, cx| {
workspace.add_panel(project_panel, window, cx);
workspace.open_panel::<ProjectPanel>(window, cx);
@@ -3263,7 +3263,7 @@ edition = "2021"
// Open main.rs in the editor
let open_file_task = workspace_window
.update(cx, |multi_workspace, window, cx| {
- let workspace = &multi_workspace.workspaces()[0];
+ let workspace = multi_workspace.workspaces().next().unwrap();
workspace.update(cx, |workspace, cx| {
let worktree = workspace.project().read(cx).worktrees(cx).next();
if let Some(worktree) = worktree {
@@ -3291,7 +3291,7 @@ edition = "2021"
// Load the AgentPanel
let (weak_workspace, async_window_cx) = workspace_window
.update(cx, |multi_workspace, window, cx| {
- let workspace = &multi_workspace.workspaces()[0];
+ let workspace = multi_workspace.workspaces().next().unwrap();
(workspace.read(cx).weak_handle(), window.to_async(cx))
})
.context("Failed to get workspace handle for agent panel")?;
@@ -3335,7 +3335,7 @@ edition = "2021"
workspace_window
.update(cx, |multi_workspace, window, cx| {
- let workspace = &multi_workspace.workspaces()[0];
+ let workspace = multi_workspace.workspaces().next().unwrap();
workspace.update(cx, |workspace, cx| {
workspace.add_panel(panel.clone(), window, cx);
workspace.open_panel::<AgentPanel>(window, cx);
@@ -3512,7 +3512,7 @@ edition = "2021"
.is_none()
});
let workspace_count = workspace_window.update(cx, |multi_workspace, _window, _cx| {
- multi_workspace.workspaces().len()
+ multi_workspace.workspaces().count()
})?;
if workspace_count == 2 && status_cleared {
creation_complete = true;
@@ -3531,7 +3531,7 @@ edition = "2021"
// error state by injecting the stub server, and shrink the panel so the
// editor content is visible.
workspace_window.update(cx, |multi_workspace, window, cx| {
- let new_workspace = &multi_workspace.workspaces()[1];
+ let new_workspace = multi_workspace.workspaces().nth(1).unwrap();
new_workspace.update(cx, |workspace, cx| {
if let Some(new_panel) = workspace.panel::<AgentPanel>(cx) {
new_panel.update(cx, |panel, cx| {
@@ -3544,7 +3544,7 @@ edition = "2021"
// Type and send a message so the thread target dropdown disappears.
let new_panel = workspace_window.update(cx, |multi_workspace, _window, cx| {
- let new_workspace = &multi_workspace.workspaces()[1];
+ let new_workspace = multi_workspace.workspaces().nth(1).unwrap();
new_workspace.read(cx).panel::<AgentPanel>(cx)
})?;
if let Some(new_panel) = new_panel {
@@ -3585,7 +3585,7 @@ edition = "2021"
workspace_window
.update(cx, |multi_workspace, _window, cx| {
- let workspace = &multi_workspace.workspaces()[0];
+ let workspace = multi_workspace.workspaces().next().unwrap();
let project = workspace.read(cx).project().clone();
project.update(cx, |project, cx| {
let worktree_ids: Vec<_> =
@@ -1524,7 +1524,7 @@ fn quit(_: &Quit, cx: &mut App) {
let window = *window;
let workspaces = window
.update(cx, |multi_workspace, _, _| {
- multi_workspace.workspaces().to_vec()
+ multi_workspace.workspaces().cloned().collect::<Vec<_>>()
})
.log_err();
@@ -2458,7 +2458,6 @@ mod tests {
.update(cx, |multi_workspace, window, cx| {
let mut tasks = multi_workspace
.workspaces()
- .iter()
.map(|workspace| {
workspace.update(cx, |workspace, cx| {
workspace.flush_serialization(window, cx)
@@ -2610,7 +2609,7 @@ mod tests {
cx.run_until_parked();
multi_workspace_1
.update(cx, |multi_workspace, _window, cx| {
- assert_eq!(multi_workspace.workspaces().len(), 2);
+ assert_eq!(multi_workspace.workspaces().count(), 2);
assert!(multi_workspace.sidebar_open());
let workspace = multi_workspace.workspace().read(cx);
assert_eq!(
@@ -5512,6 +5511,11 @@ mod tests {
let project = project1.clone();
|window, cx| MultiWorkspace::test_new(project, window, cx)
});
+ window
+ .update(cx, |multi_workspace, _, cx| {
+ multi_workspace.open_sidebar(cx);
+ })
+ .unwrap();
cx.run_until_parked();
assert_eq!(cx.windows().len(), 1, "Should start with 1 window");
@@ -5534,7 +5538,7 @@ mod tests {
let workspace1 = window
.read_with(cx, |multi_workspace, _| {
- multi_workspace.workspaces()[0].clone()
+ multi_workspace.workspaces().next().unwrap().clone()
})
.unwrap();
@@ -5543,8 +5547,8 @@ mod tests {
multi_workspace.activate(workspace2.clone(), window, cx);
multi_workspace.activate(workspace3.clone(), window, cx);
// Switch back to workspace1 for test setup
- multi_workspace.activate(workspace1, window, cx);
- assert_eq!(multi_workspace.active_workspace_index(), 0);
+ multi_workspace.activate(workspace1.clone(), window, cx);
+ assert_eq!(multi_workspace.workspace(), &workspace1);
})
.unwrap();
@@ -5553,8 +5557,8 @@ mod tests {
// Verify setup: 3 workspaces, workspace 0 active, still 1 window
window
.read_with(cx, |multi_workspace, _| {
- assert_eq!(multi_workspace.workspaces().len(), 3);
- assert_eq!(multi_workspace.active_workspace_index(), 0);
+ assert_eq!(multi_workspace.workspaces().count(), 3);
+ assert_eq!(multi_workspace.workspace(), &workspace1);
})
.unwrap();
assert_eq!(cx.windows().len(), 1);
@@ -5577,8 +5581,8 @@ mod tests {
window
.read_with(cx, |multi_workspace, cx| {
assert_eq!(
- multi_workspace.active_workspace_index(),
- 2,
+ multi_workspace.workspace(),
+ &workspace3,
"Should have switched to workspace 3 which contains /dir3"
);
let active_item = multi_workspace
@@ -5611,8 +5615,8 @@ mod tests {
window
.read_with(cx, |multi_workspace, cx| {
assert_eq!(
- multi_workspace.active_workspace_index(),
- 1,
+ multi_workspace.workspace(),
+ &workspace2,
"Should have switched to workspace 2 which contains /dir2"
);
let active_item = multi_workspace
@@ -5660,8 +5664,8 @@ mod tests {
window
.read_with(cx, |multi_workspace, cx| {
assert_eq!(
- multi_workspace.active_workspace_index(),
- 0,
+ multi_workspace.workspace(),
+ &workspace1,
"Should have switched back to workspace 0 which contains /dir1"
);
let active_item = multi_workspace
@@ -5711,6 +5715,11 @@ mod tests {
let project = project1.clone();
|window, cx| MultiWorkspace::test_new(project, window, cx)
});
+ window1
+ .update(cx, |multi_workspace, _, cx| {
+ multi_workspace.open_sidebar(cx);
+ })
+ .unwrap();
cx.run_until_parked();
@@ -5737,6 +5746,11 @@ mod tests {
let project = project3.clone();
|window, cx| MultiWorkspace::test_new(project, window, cx)
});
+ window2
+ .update(cx, |multi_workspace, _, cx| {
+ multi_workspace.open_sidebar(cx);
+ })
+ .unwrap();
cx.run_until_parked();
assert_eq!(cx.windows().len(), 2);
@@ -5771,7 +5785,7 @@ mod tests {
// Verify workspace1_1 is active
window1
.read_with(cx, |multi_workspace, _| {
- assert_eq!(multi_workspace.active_workspace_index(), 0);
+ assert_eq!(multi_workspace.workspace(), &workspace1_1);
})
.unwrap();
@@ -5837,7 +5851,7 @@ mod tests {
// Verify workspace1_1 is still active (not workspace1_2 with dirty item)
window1
.read_with(cx, |multi_workspace, _| {
- assert_eq!(multi_workspace.active_workspace_index(), 0);
+ assert_eq!(multi_workspace.workspace(), &workspace1_1);
})
.unwrap();
@@ -5848,8 +5862,8 @@ mod tests {
window1
.read_with(cx, |multi_workspace, _| {
assert_eq!(
- multi_workspace.active_workspace_index(),
- 1,
+ multi_workspace.workspace(),
+ &workspace1_2,
"Case 2: Non-active workspace should be activated when it has dirty item"
);
})
@@ -6002,6 +6016,12 @@ mod tests {
.await
.expect("failed to open first workspace");
+ window_a
+ .update(cx, |multi_workspace, _, cx| {
+ multi_workspace.open_sidebar(cx);
+ })
+ .unwrap();
+
window_a
.update(cx, |multi_workspace, window, cx| {
multi_workspace.open_project(vec![dir2.into()], OpenMode::Activate, window, cx)
@@ -6028,13 +6048,19 @@ mod tests {
.await
.expect("failed to open third workspace");
+ window_b
+ .update(cx, |multi_workspace, _, cx| {
+ multi_workspace.open_sidebar(cx);
+ })
+ .unwrap();
+
// Currently dir2 is active because it was added last.
// So, switch window_a's active workspace to dir1 (index 0).
// This sets up a non-trivial assertion: after restore, dir1 should
// still be active rather than whichever workspace happened to restore last.
window_a
.update(cx, |multi_workspace, window, cx| {
- let workspace = multi_workspace.workspaces()[0].clone();
+ let workspace = multi_workspace.workspaces().next().unwrap().clone();
multi_workspace.activate(workspace, window, cx);
})
.unwrap();
@@ -6150,7 +6176,7 @@ mod tests {
ProjectGroupKey::new(None, PathList::new(&[dir2])),
]
);
- assert_eq!(mw.workspaces().len(), 1);
+ assert_eq!(mw.workspaces().count(), 1);
})
.unwrap();
@@ -6161,7 +6187,7 @@ mod tests {
mw.project_group_keys().cloned().collect::<Vec<_>>(),
vec![ProjectGroupKey::new(None, PathList::new(&[dir3]))]
);
- assert_eq!(mw.workspaces().len(), 1);
+ assert_eq!(mw.workspaces().count(), 1);
})
.unwrap();
}
@@ -13,6 +13,7 @@ path = "src/zeta_prompt.rs"
[dependencies]
anyhow.workspace = true
+imara-diff.workspace = true
serde.workspace = true
strum.workspace = true
@@ -6,6 +6,10 @@ use std::{
};
use anyhow::{Context as _, Result, anyhow};
+use imara_diff::{
+ Algorithm, Sink, diff,
+ intern::{InternedInput, Interner, Token},
+};
pub fn strip_diff_path_prefix<'a>(diff: &'a str, prefix: &str) -> Cow<'a, str> {
if prefix.is_empty() {
@@ -221,6 +225,181 @@ pub fn disambiguate_by_line_number(
}
}
+pub fn unified_diff_with_context(
+ old_text: &str,
+ new_text: &str,
+ old_start_line: u32,
+ new_start_line: u32,
+ context_lines: u32,
+) -> String {
+ let input = InternedInput::new(old_text, new_text);
+ diff(
+ Algorithm::Histogram,
+ &input,
+ OffsetUnifiedDiffBuilder::new(&input, old_start_line, new_start_line, context_lines),
+ )
+}
+
+struct OffsetUnifiedDiffBuilder<'a> {
+ before: &'a [Token],
+ after: &'a [Token],
+ interner: &'a Interner<&'a str>,
+ pos: u32,
+ before_hunk_start: u32,
+ after_hunk_start: u32,
+ before_hunk_len: u32,
+ after_hunk_len: u32,
+ old_line_offset: u32,
+ new_line_offset: u32,
+ context_lines: u32,
+ buffer: String,
+ dst: String,
+}
+
+impl<'a> OffsetUnifiedDiffBuilder<'a> {
+ fn new(
+ input: &'a InternedInput<&'a str>,
+ old_line_offset: u32,
+ new_line_offset: u32,
+ context_lines: u32,
+ ) -> Self {
+ Self {
+ before_hunk_start: 0,
+ after_hunk_start: 0,
+ before_hunk_len: 0,
+ after_hunk_len: 0,
+ old_line_offset,
+ new_line_offset,
+ context_lines,
+ buffer: String::with_capacity(8),
+ dst: String::new(),
+ interner: &input.interner,
+ before: &input.before,
+ after: &input.after,
+ pos: 0,
+ }
+ }
+
+ fn print_tokens(&mut self, tokens: &[Token], prefix: char) {
+ for &token in tokens {
+ writeln!(&mut self.buffer, "{prefix}{}", self.interner[token]).unwrap();
+ }
+ }
+
+ fn flush(&mut self) {
+ if self.before_hunk_len == 0 && self.after_hunk_len == 0 {
+ return;
+ }
+
+ let end = (self.pos + self.context_lines).min(self.before.len() as u32);
+ self.update_pos(end, end);
+
+ writeln!(
+ &mut self.dst,
+ "@@ -{},{} +{},{} @@",
+ self.before_hunk_start + 1 + self.old_line_offset,
+ self.before_hunk_len,
+ self.after_hunk_start + 1 + self.new_line_offset,
+ self.after_hunk_len,
+ )
+ .unwrap();
+ write!(&mut self.dst, "{}", &self.buffer).unwrap();
+ self.buffer.clear();
+ self.before_hunk_len = 0;
+ self.after_hunk_len = 0;
+ }
+
+ fn update_pos(&mut self, print_to: u32, move_to: u32) {
+ self.print_tokens(&self.before[self.pos as usize..print_to as usize], ' ');
+ let len = print_to - self.pos;
+ self.before_hunk_len += len;
+ self.after_hunk_len += len;
+ self.pos = move_to;
+ }
+}
+
+impl Sink for OffsetUnifiedDiffBuilder<'_> {
+ type Out = String;
+
+ fn process_change(&mut self, before: Range<u32>, after: Range<u32>) {
+ if before.start - self.pos > self.context_lines * 2 {
+ self.flush();
+ }
+ if self.before_hunk_len == 0 && self.after_hunk_len == 0 {
+ self.pos = before.start.saturating_sub(self.context_lines);
+ self.before_hunk_start = self.pos;
+ self.after_hunk_start = after.start.saturating_sub(self.context_lines);
+ }
+
+ self.update_pos(before.start, before.end);
+ self.before_hunk_len += before.end - before.start;
+ self.after_hunk_len += after.end - after.start;
+ self.print_tokens(
+ &self.before[before.start as usize..before.end as usize],
+ '-',
+ );
+ self.print_tokens(&self.after[after.start as usize..after.end as usize], '+');
+ }
+
+ fn finish(mut self) -> Self::Out {
+ self.flush();
+ self.dst
+ }
+}
+
+pub fn encode_cursor_in_patch(patch: &str, cursor_offset: Option<usize>) -> String {
+ let Some(cursor_offset) = cursor_offset else {
+ return patch.to_string();
+ };
+
+ let mut result = String::new();
+ let mut line_start_offset = 0usize;
+
+ for line in patch.lines() {
+ if matches!(
+ DiffLine::parse(line),
+ DiffLine::Garbage(content)
+ if content.starts_with('#') && content.contains(CURSOR_POSITION_MARKER)
+ ) {
+ continue;
+ }
+
+ if !result.is_empty() {
+ result.push('\n');
+ }
+ result.push_str(line);
+
+ match DiffLine::parse(line) {
+ DiffLine::Addition(content) => {
+ let line_end_offset = line_start_offset + content.len();
+
+ if cursor_offset >= line_start_offset && cursor_offset <= line_end_offset {
+ let cursor_column = cursor_offset - line_start_offset;
+
+ result.push('\n');
+ result.push('#');
+ for _ in 0..cursor_column {
+ result.push(' ');
+ }
+ write!(result, "^{}", CURSOR_POSITION_MARKER).unwrap();
+ }
+
+ line_start_offset = line_end_offset + 1;
+ }
+ DiffLine::Context(content) => {
+ line_start_offset += content.len() + 1;
+ }
+ _ => {}
+ }
+ }
+
+ if patch.ends_with('\n') {
+ result.push('\n');
+ }
+
+ result
+}
+
pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
apply_diff_to_string_with_hunk_offset(diff_str, text).map(|(text, _)| text)
}
@@ -1203,4 +1382,25 @@ mod tests {
// Edit range end should be clamped to 7 (new context length).
assert_eq!(hunk.edits[0].range, 4..7);
}
+
+ #[test]
+ fn test_unified_diff_with_context_matches_expected_context_window() {
+ let old_text = "line1\nline2\nline3\nline4\nline5\nCHANGE_ME\nline7\nline8\n";
+ let new_text = "line1\nline2\nline3\nline4\nline5\nCHANGED\nline7\nline8\n";
+
+ let diff_default = unified_diff_with_context(old_text, new_text, 0, 0, 3);
+ assert_eq!(
+ diff_default,
+ "@@ -3,6 +3,6 @@\n line3\n line4\n line5\n-CHANGE_ME\n+CHANGED\n line7\n line8\n"
+ );
+
+ let diff_full_context = unified_diff_with_context(old_text, new_text, 0, 0, 8);
+ assert_eq!(
+ diff_full_context,
+ "@@ -1,8 +1,8 @@\n line1\n line2\n line3\n line4\n line5\n-CHANGE_ME\n+CHANGED\n line7\n line8\n"
+ );
+
+ let diff_no_context = unified_diff_with_context(old_text, new_text, 0, 0, 0);
+ assert_eq!(diff_no_context, "@@ -6,1 +6,1 @@\n-CHANGE_ME\n+CHANGED\n");
+ }
}
@@ -106,10 +106,19 @@ impl std::fmt::Display for ZetaFormat {
impl ZetaFormat {
pub fn parse(format_name: &str) -> Result<Self> {
+ let lower = format_name.to_lowercase();
+
+ // Exact case-insensitive match takes priority, bypassing ambiguity checks.
+ for variant in ZetaFormat::iter() {
+ if <&'static str>::from(&variant).to_lowercase() == lower {
+ return Ok(variant);
+ }
+ }
+
let mut results = ZetaFormat::iter().filter(|version| {
<&'static str>::from(version)
.to_lowercase()
- .contains(&format_name.to_lowercase())
+ .contains(&lower)
});
let Some(result) = results.next() else {
anyhow::bail!(
@@ -927,11 +936,39 @@ fn cursor_in_new_text(
})
}
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ParsedOutput {
/// Text that should replace the editable region
pub new_editable_region: String,
/// The byte range within `cursor_excerpt` that this replacement applies to
pub range_in_excerpt: Range<usize>,
+ /// Byte offset of the cursor marker within `new_editable_region`, if present
+ pub cursor_offset_in_new_editable_region: Option<usize>,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+pub struct CursorPosition {
+ pub path: String,
+ pub row: usize,
+ pub column: usize,
+ pub offset: usize,
+ pub editable_region_offset: usize,
+}
+
+pub fn parsed_output_from_editable_region(
+ range_in_excerpt: Range<usize>,
+ mut new_editable_region: String,
+) -> ParsedOutput {
+ let cursor_offset_in_new_editable_region = new_editable_region.find(CURSOR_MARKER);
+ if let Some(offset) = cursor_offset_in_new_editable_region {
+ new_editable_region.replace_range(offset..offset + CURSOR_MARKER.len(), "");
+ }
+
+ ParsedOutput {
+ new_editable_region,
+ range_in_excerpt,
+ cursor_offset_in_new_editable_region,
+ }
}
/// Parse model output for the given zeta format
@@ -999,12 +1036,97 @@ pub fn parse_zeta2_model_output(
let range_in_excerpt =
range_in_context.start + context_start..range_in_context.end + context_start;
- Ok(ParsedOutput {
- new_editable_region: output,
- range_in_excerpt,
+ Ok(parsed_output_from_editable_region(range_in_excerpt, output))
+}
+
+pub fn parse_zeta2_model_output_as_patch(
+ output: &str,
+ format: ZetaFormat,
+ prompt_inputs: &ZetaPromptInput,
+) -> Result<String> {
+ let parsed = parse_zeta2_model_output(output, format, prompt_inputs)?;
+ parsed_output_to_patch(prompt_inputs, parsed)
+}
+
+pub fn cursor_position_from_parsed_output(
+ prompt_inputs: &ZetaPromptInput,
+ parsed: &ParsedOutput,
+) -> Option<CursorPosition> {
+ let cursor_offset = parsed.cursor_offset_in_new_editable_region?;
+ let editable_region_offset = parsed.range_in_excerpt.start;
+ let excerpt = prompt_inputs.cursor_excerpt.as_ref();
+
+ let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count();
+
+ let new_editable_region = &parsed.new_editable_region;
+ let prefix_end = cursor_offset.min(new_editable_region.len());
+ let new_region_prefix = &new_editable_region[..prefix_end];
+
+ let row = editable_region_start_line + new_region_prefix.matches('\n').count();
+
+ let column = match new_region_prefix.rfind('\n') {
+ Some(last_newline) => cursor_offset - last_newline - 1,
+ None => {
+ let content_prefix = &excerpt[..editable_region_offset];
+ let content_column = match content_prefix.rfind('\n') {
+ Some(last_newline) => editable_region_offset - last_newline - 1,
+ None => editable_region_offset,
+ };
+ content_column + cursor_offset
+ }
+ };
+
+ Some(CursorPosition {
+ path: prompt_inputs.cursor_path.to_string_lossy().into_owned(),
+ row,
+ column,
+ offset: editable_region_offset + cursor_offset,
+ editable_region_offset: cursor_offset,
})
}
+pub fn parsed_output_to_patch(
+ prompt_inputs: &ZetaPromptInput,
+ parsed: ParsedOutput,
+) -> Result<String> {
+ let range_in_excerpt = parsed.range_in_excerpt;
+ let excerpt = prompt_inputs.cursor_excerpt.as_ref();
+ let old_text = excerpt[range_in_excerpt.clone()].to_string();
+ let mut new_text = parsed.new_editable_region;
+
+ let mut old_text_normalized = old_text;
+ if !new_text.is_empty() && !new_text.ends_with('\n') {
+ new_text.push('\n');
+ }
+ if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
+ old_text_normalized.push('\n');
+ }
+
+ let editable_region_offset = range_in_excerpt.start;
+ let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count() as u32;
+ let editable_region_lines = old_text_normalized.lines().count() as u32;
+
+ let diff = udiff::unified_diff_with_context(
+ &old_text_normalized,
+ &new_text,
+ editable_region_start_line,
+ editable_region_start_line,
+ editable_region_lines,
+ );
+
+ let path = prompt_inputs
+ .cursor_path
+ .to_string_lossy()
+ .trim_start_matches('/')
+ .to_string();
+ let formatted_diff = format!("--- a/{path}\n+++ b/{path}\n{diff}");
+
+ Ok(udiff::encode_cursor_in_patch(
+ &formatted_diff,
+ parsed.cursor_offset_in_new_editable_region,
+ ))
+}
+
pub fn excerpt_range_for_format(
format: ZetaFormat,
ranges: &ExcerptRanges,
@@ -5400,6 +5522,33 @@ mod tests {
assert_eq!(apply_edit(excerpt, &output1), "new content\n");
}
+ #[test]
+ fn test_parsed_output_to_patch_round_trips_through_udiff_application() {
+ let excerpt = "before ctx\nctx start\neditable old\nctx end\nafter ctx\n";
+ let context_start = excerpt.find("ctx start").unwrap();
+ let context_end = excerpt.find("after ctx").unwrap();
+ let editable_start = excerpt.find("editable old").unwrap();
+ let editable_end = editable_start + "editable old\n".len();
+ let input = make_input_with_context_range(
+ excerpt,
+ editable_start..editable_end,
+ context_start..context_end,
+ editable_start,
+ );
+
+ let parsed = parse_zeta2_model_output(
+ "editable new\n>>>>>>> UPDATED\n",
+ ZetaFormat::V0131GitMergeMarkersPrefix,
+ &input,
+ )
+ .unwrap();
+ let expected = apply_edit(excerpt, &parsed);
+ let patch = parsed_output_to_patch(&input, parsed).unwrap();
+ let patched = udiff::apply_diff_to_string(&patch, excerpt).unwrap();
+
+ assert_eq!(patched, expected);
+ }
+
#[test]
fn test_special_tokens_not_triggered_by_comment_separator() {
// Regression test for https://github.com/zed-industries/zed/issues/52489
@@ -562,6 +562,7 @@ You can change the following settings to modify vim mode's behavior:
| use_system_clipboard | Determines how system clipboard is used:<br><ul><li>"always": use for all operations</li><li>"never": only use when explicitly specified</li><li>"on_yank": use for yank operations</li></ul> | "always" |
| use_multiline_find | deprecated |
| use_smartcase_find | If `true`, `f` and `t` motions are case-insensitive when the target letter is lowercase. | false |
+| use_regex_search | If `true`, then vim search will use regex mode | true |
| gdefault | If `true`, the `:substitute` command replaces all matches in a line by default (as if `g` flag was given). The `g` flag then toggles this, replacing only the first match. | false |
| toggle_relative_line_numbers | If `true`, line numbers are relative in normal mode and absolute in insert mode, giving you the best of both options. | false |
| custom_digraphs | An object that allows you to add custom digraphs. Read below for an example. | {} |
@@ -587,6 +588,7 @@ Here's an example of these settings changed:
"default_mode": "insert",
"use_system_clipboard": "never",
"use_smartcase_find": true,
+ "use_regex_search": true,
"gdefault": true,
"toggle_relative_line_numbers": true,
"highlight_on_yank_duration": 50,
@@ -0,0 +1,38 @@
+[package]
+name = "compliance"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[features]
+octo-client = ["dep:octocrab", "dep:jsonwebtoken", "dep:futures", "dep:tokio"]
+
+[dependencies]
+anyhow.workspace = true
+async-trait.workspace = true
+derive_more.workspace = true
+futures = { workspace = true, optional = true }
+itertools.workspace = true
+jsonwebtoken = { version = "10.2", features = ["use_pem"], optional = true }
+octocrab = { version = "0.49", default-features = false, features = [
+ "default-client",
+ "jwt-aws-lc-rs",
+ "retry",
+ "rustls",
+ "rustls-aws-lc-rs",
+ "stream",
+ "timeout"
+], optional = true }
+regex.workspace = true
+semver.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+tokio = { workspace = true, optional = true }
+
+[dev-dependencies]
+indoc.workspace = true
+tokio = { workspace = true, features = ["rt", "macros"] }
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,647 @@
+use std::{fmt, ops::Not as _};
+
+use itertools::Itertools as _;
+
+use crate::{
+ git::{CommitDetails, CommitList},
+ github::{
+ CommitAuthor, GitHubClient, GitHubUser, GithubLogin, PullRequestComment, PullRequestData,
+ PullRequestReview, ReviewState,
+ },
+ report::Report,
+};
+
+const ZED_ZIPPY_COMMENT_APPROVAL_PATTERN: &str = "@zed-zippy approve";
+const ZED_ZIPPY_GROUP_APPROVAL: &str = "@zed-industries/approved";
+
+#[derive(Debug)]
+pub enum ReviewSuccess {
+ ApprovingComment(Vec<PullRequestComment>),
+ CoAuthored(Vec<CommitAuthor>),
+ ExternalMergedContribution { merged_by: GitHubUser },
+ PullRequestReviewed(Vec<PullRequestReview>),
+}
+
+impl ReviewSuccess {
+ pub(crate) fn reviewers(&self) -> anyhow::Result<String> {
+ let reviewers = match self {
+ Self::CoAuthored(authors) => authors.iter().map(ToString::to_string).collect_vec(),
+ Self::PullRequestReviewed(reviews) => reviews
+ .iter()
+ .filter_map(|review| review.user.as_ref())
+ .map(|user| format!("@{}", user.login))
+ .collect_vec(),
+ Self::ApprovingComment(comments) => comments
+ .iter()
+ .map(|comment| format!("@{}", comment.user.login))
+ .collect_vec(),
+ Self::ExternalMergedContribution { merged_by } => {
+ vec![format!("@{}", merged_by.login)]
+ }
+ };
+
+ let reviewers = reviewers.into_iter().unique().collect_vec();
+
+ reviewers
+ .is_empty()
+ .not()
+ .then(|| reviewers.join(", "))
+ .ok_or_else(|| anyhow::anyhow!("Expected at least one reviewer"))
+ }
+}
+
+impl fmt::Display for ReviewSuccess {
+ fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ Self::CoAuthored(_) => formatter.write_str("Co-authored by an organization member"),
+ Self::PullRequestReviewed(_) => {
+ formatter.write_str("Approved by an organization review")
+ }
+ Self::ApprovingComment(_) => {
+ formatter.write_str("Approved by an organization approval comment")
+ }
+ Self::ExternalMergedContribution { .. } => {
+ formatter.write_str("External merged contribution")
+ }
+ }
+ }
+}
+
+#[derive(Debug)]
+pub enum ReviewFailure {
+ // todo: We could still query the GitHub API here to search for one
+ NoPullRequestFound,
+ Unreviewed,
+ UnableToDetermineReviewer,
+ Other(anyhow::Error),
+}
+
+impl fmt::Display for ReviewFailure {
+ fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ Self::NoPullRequestFound => formatter.write_str("No pull request found"),
+ Self::Unreviewed => formatter
+ .write_str("No qualifying organization approval found for the pull request"),
+ Self::UnableToDetermineReviewer => formatter.write_str("Could not determine reviewer"),
+ Self::Other(error) => write!(formatter, "Failed to inspect review state: {error}"),
+ }
+ }
+}
+
+pub(crate) type ReviewResult = Result<ReviewSuccess, ReviewFailure>;
+
+impl<E: Into<anyhow::Error>> From<E> for ReviewFailure {
+ fn from(err: E) -> Self {
+ Self::Other(anyhow::anyhow!(err))
+ }
+}
+
+pub struct Reporter<'a> {
+ commits: CommitList,
+ github_client: &'a GitHubClient,
+}
+
+impl<'a> Reporter<'a> {
+ pub fn new(commits: CommitList, github_client: &'a GitHubClient) -> Self {
+ Self {
+ commits,
+ github_client,
+ }
+ }
+
+ /// Method that checks every commit for compliance
+ async fn check_commit(&self, commit: &CommitDetails) -> Result<ReviewSuccess, ReviewFailure> {
+ let Some(pr_number) = commit.pr_number() else {
+ return Err(ReviewFailure::NoPullRequestFound);
+ };
+
+ let pull_request = self.github_client.get_pull_request(pr_number).await?;
+
+ if let Some(approval) = self.check_pull_request_approved(&pull_request).await? {
+ return Ok(approval);
+ }
+
+ if let Some(approval) = self
+ .check_approving_pull_request_comment(&pull_request)
+ .await?
+ {
+ return Ok(approval);
+ }
+
+ if let Some(approval) = self.check_commit_co_authors(commit).await? {
+ return Ok(approval);
+ }
+
+ // if let Some(approval) = self.check_external_merged_pr(pr_number).await? {
+ // return Ok(approval);
+ // }
+
+ Err(ReviewFailure::Unreviewed)
+ }
+
+ async fn check_commit_co_authors(
+ &self,
+ commit: &CommitDetails,
+ ) -> Result<Option<ReviewSuccess>, ReviewFailure> {
+ if commit.co_authors().is_some()
+ && let Some(commit_authors) = self
+ .github_client
+ .get_commit_authors([commit.sha()])
+ .await?
+ .get(commit.sha())
+ .and_then(|authors| authors.co_authors())
+ {
+ let mut org_co_authors = Vec::new();
+ for co_author in commit_authors {
+ if let Some(github_login) = co_author.user()
+ && self
+ .github_client
+ .check_org_membership(github_login)
+ .await?
+ {
+ org_co_authors.push(co_author.clone());
+ }
+ }
+
+ Ok(org_co_authors
+ .is_empty()
+ .not()
+ .then_some(ReviewSuccess::CoAuthored(org_co_authors)))
+ } else {
+ Ok(None)
+ }
+ }
+
+ #[allow(unused)]
+ async fn check_external_merged_pr(
+ &self,
+ pull_request: PullRequestData,
+ ) -> Result<Option<ReviewSuccess>, ReviewFailure> {
+ if let Some(user) = pull_request.user
+ && self
+ .github_client
+ .check_org_membership(&GithubLogin::new(user.login))
+ .await?
+ .not()
+ {
+ pull_request.merged_by.map_or(
+ Err(ReviewFailure::UnableToDetermineReviewer),
+ |merged_by| {
+ Ok(Some(ReviewSuccess::ExternalMergedContribution {
+ merged_by,
+ }))
+ },
+ )
+ } else {
+ Ok(None)
+ }
+ }
+
+ async fn check_pull_request_approved(
+ &self,
+ pull_request: &PullRequestData,
+ ) -> Result<Option<ReviewSuccess>, ReviewFailure> {
+ let pr_reviews = self
+ .github_client
+ .get_pull_request_reviews(pull_request.number)
+ .await?;
+
+ if !pr_reviews.is_empty() {
+ let mut org_approving_reviews = Vec::new();
+ for review in pr_reviews {
+ if let Some(github_login) = review.user.as_ref()
+ && pull_request
+ .user
+ .as_ref()
+ .is_none_or(|pr_user| pr_user.login != github_login.login)
+ && review
+ .state
+ .is_some_and(|state| state == ReviewState::Approved)
+ && self
+ .github_client
+ .check_org_membership(&GithubLogin::new(github_login.login.clone()))
+ .await?
+ {
+ org_approving_reviews.push(review);
+ }
+ }
+
+ Ok(org_approving_reviews
+ .is_empty()
+ .not()
+ .then_some(ReviewSuccess::PullRequestReviewed(org_approving_reviews)))
+ } else {
+ Ok(None)
+ }
+ }
+
+ async fn check_approving_pull_request_comment(
+ &self,
+ pull_request: &PullRequestData,
+ ) -> Result<Option<ReviewSuccess>, ReviewFailure> {
+ let other_comments = self
+ .github_client
+ .get_pull_request_comments(pull_request.number)
+ .await?;
+
+ if !other_comments.is_empty() {
+ let mut org_approving_comments = Vec::new();
+
+ for comment in other_comments {
+ if pull_request
+ .user
+ .as_ref()
+ .is_some_and(|pr_author| pr_author.login != comment.user.login)
+ && comment.body.as_ref().is_some_and(|body| {
+ body.contains(ZED_ZIPPY_COMMENT_APPROVAL_PATTERN)
+ || body.contains(ZED_ZIPPY_GROUP_APPROVAL)
+ })
+ && self
+ .github_client
+ .check_org_membership(&GithubLogin::new(comment.user.login.clone()))
+ .await?
+ {
+ org_approving_comments.push(comment);
+ }
+ }
+
+ Ok(org_approving_comments
+ .is_empty()
+ .not()
+ .then_some(ReviewSuccess::ApprovingComment(org_approving_comments)))
+ } else {
+ Ok(None)
+ }
+ }
+
+ pub async fn generate_report(mut self) -> anyhow::Result<Report> {
+ let mut report = Report::new();
+
+ let commits_to_check = std::mem::take(&mut self.commits);
+ let total_commits = commits_to_check.len();
+
+ for (i, commit) in commits_to_check.into_iter().enumerate() {
+ println!(
+ "Checking commit {:?} ({current}/{total})",
+ commit.sha().short(),
+ current = i + 1,
+ total = total_commits
+ );
+
+ let review_result = self.check_commit(&commit).await;
+
+ if let Err(err) = &review_result {
+ println!("Commit {:?} failed review: {:?}", commit.sha().short(), err);
+ }
+
+ report.add(commit, review_result);
+ }
+
+ Ok(report)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::rc::Rc;
+ use std::str::FromStr;
+
+ use crate::git::{CommitDetails, CommitList, CommitSha};
+ use crate::github::{
+ AuthorsForCommits, GitHubApiClient, GitHubClient, GitHubUser, GithubLogin,
+ PullRequestComment, PullRequestData, PullRequestReview, ReviewState,
+ };
+
+ use super::{Reporter, ReviewFailure, ReviewSuccess};
+
+ struct MockGitHubApi {
+ pull_request: PullRequestData,
+ reviews: Vec<PullRequestReview>,
+ comments: Vec<PullRequestComment>,
+ commit_authors_json: serde_json::Value,
+ org_members: Vec<String>,
+ }
+
+ #[async_trait::async_trait(?Send)]
+ impl GitHubApiClient for MockGitHubApi {
+ async fn get_pull_request(&self, _pr_number: u64) -> anyhow::Result<PullRequestData> {
+ Ok(self.pull_request.clone())
+ }
+
+ async fn get_pull_request_reviews(
+ &self,
+ _pr_number: u64,
+ ) -> anyhow::Result<Vec<PullRequestReview>> {
+ Ok(self.reviews.clone())
+ }
+
+ async fn get_pull_request_comments(
+ &self,
+ _pr_number: u64,
+ ) -> anyhow::Result<Vec<PullRequestComment>> {
+ Ok(self.comments.clone())
+ }
+
+ async fn get_commit_authors(
+ &self,
+ _commit_shas: &[&CommitSha],
+ ) -> anyhow::Result<AuthorsForCommits> {
+ serde_json::from_value(self.commit_authors_json.clone()).map_err(Into::into)
+ }
+
+ async fn check_org_membership(&self, login: &GithubLogin) -> anyhow::Result<bool> {
+ Ok(self
+ .org_members
+ .iter()
+ .any(|member| member == login.as_str()))
+ }
+
+ async fn ensure_pull_request_has_label(
+ &self,
+ _label: &str,
+ _pr_number: u64,
+ ) -> anyhow::Result<()> {
+ Ok(())
+ }
+ }
+
+ fn make_commit(
+ sha: &str,
+ author_name: &str,
+ author_email: &str,
+ title: &str,
+ body: &str,
+ ) -> CommitDetails {
+ let formatted = format!(
+ "{sha}|field-delimiter|{author_name}|field-delimiter|{author_email}|field-delimiter|\
+ {title}|body-delimiter|{body}|commit-delimiter|"
+ );
+ CommitList::from_str(&formatted)
+ .expect("test commit should parse")
+ .into_iter()
+ .next()
+ .expect("should have one commit")
+ }
+
+ fn review(login: &str, state: ReviewState) -> PullRequestReview {
+ PullRequestReview {
+ user: Some(GitHubUser {
+ login: login.to_owned(),
+ }),
+ state: Some(state),
+ }
+ }
+
+ fn comment(login: &str, body: &str) -> PullRequestComment {
+ PullRequestComment {
+ user: GitHubUser {
+ login: login.to_owned(),
+ },
+ body: Some(body.to_owned()),
+ }
+ }
+
+ struct TestScenario {
+ pull_request: PullRequestData,
+ reviews: Vec<PullRequestReview>,
+ comments: Vec<PullRequestComment>,
+ commit_authors_json: serde_json::Value,
+ org_members: Vec<String>,
+ commit: CommitDetails,
+ }
+
+ impl TestScenario {
+ fn single_commit() -> Self {
+ Self {
+ pull_request: PullRequestData {
+ number: 1234,
+ user: Some(GitHubUser {
+ login: "alice".to_owned(),
+ }),
+ merged_by: None,
+ },
+ reviews: vec![],
+ comments: vec![],
+ commit_authors_json: serde_json::json!({}),
+ org_members: vec![],
+ commit: make_commit(
+ "abc12345abc12345",
+ "Alice",
+ "alice@test.com",
+ "Fix thing (#1234)",
+ "",
+ ),
+ }
+ }
+
+ fn with_reviews(mut self, reviews: Vec<PullRequestReview>) -> Self {
+ self.reviews = reviews;
+ self
+ }
+
+ fn with_comments(mut self, comments: Vec<PullRequestComment>) -> Self {
+ self.comments = comments;
+ self
+ }
+
+ fn with_org_members(mut self, members: Vec<&str>) -> Self {
+ self.org_members = members.into_iter().map(str::to_owned).collect();
+ self
+ }
+
+ fn with_commit_authors_json(mut self, json: serde_json::Value) -> Self {
+ self.commit_authors_json = json;
+ self
+ }
+
+ fn with_commit(mut self, commit: CommitDetails) -> Self {
+ self.commit = commit;
+ self
+ }
+
+ async fn run_scenario(self) -> Result<ReviewSuccess, ReviewFailure> {
+ let mock = MockGitHubApi {
+ pull_request: self.pull_request,
+ reviews: self.reviews,
+ comments: self.comments,
+ commit_authors_json: self.commit_authors_json,
+ org_members: self.org_members,
+ };
+ let client = GitHubClient::new(Rc::new(mock));
+ let reporter = Reporter::new(CommitList::default(), &client);
+ reporter.check_commit(&self.commit).await
+ }
+ }
+
+ #[tokio::test]
+ async fn approved_review_by_org_member_succeeds() {
+ let result = TestScenario::single_commit()
+ .with_reviews(vec![review("bob", ReviewState::Approved)])
+ .with_org_members(vec!["bob"])
+ .run_scenario()
+ .await;
+ assert!(matches!(result, Ok(ReviewSuccess::PullRequestReviewed(_))));
+ }
+
+ #[tokio::test]
+ async fn non_approved_review_state_is_not_accepted() {
+ let result = TestScenario::single_commit()
+ .with_reviews(vec![review("bob", ReviewState::Other)])
+ .with_org_members(vec!["bob"])
+ .run_scenario()
+ .await;
+ assert!(matches!(result, Err(ReviewFailure::Unreviewed)));
+ }
+
+ #[tokio::test]
+ async fn review_by_non_org_member_is_not_accepted() {
+ let result = TestScenario::single_commit()
+ .with_reviews(vec![review("bob", ReviewState::Approved)])
+ .run_scenario()
+ .await;
+ assert!(matches!(result, Err(ReviewFailure::Unreviewed)));
+ }
+
+ #[tokio::test]
+ async fn pr_author_own_approval_review_is_rejected() {
+ let result = TestScenario::single_commit()
+ .with_reviews(vec![review("alice", ReviewState::Approved)])
+ .with_org_members(vec!["alice"])
+ .run_scenario()
+ .await;
+ assert!(matches!(result, Err(ReviewFailure::Unreviewed)));
+ }
+
+ #[tokio::test]
+ async fn pr_author_own_approval_comment_is_rejected() {
+ let result = TestScenario::single_commit()
+ .with_comments(vec![comment("alice", "@zed-zippy approve")])
+ .with_org_members(vec!["alice"])
+ .run_scenario()
+ .await;
+ assert!(matches!(result, Err(ReviewFailure::Unreviewed)));
+ }
+
+ #[tokio::test]
+ async fn approval_comment_by_org_member_succeeds() {
+ let result = TestScenario::single_commit()
+ .with_comments(vec![comment("bob", "@zed-zippy approve")])
+ .with_org_members(vec!["bob"])
+ .run_scenario()
+ .await;
+ assert!(matches!(result, Ok(ReviewSuccess::ApprovingComment(_))));
+ }
+
+ #[tokio::test]
+ async fn group_approval_comment_by_org_member_succeeds() {
+ let result = TestScenario::single_commit()
+ .with_comments(vec![comment("bob", "@zed-industries/approved")])
+ .with_org_members(vec!["bob"])
+ .run_scenario()
+ .await;
+ assert!(matches!(result, Ok(ReviewSuccess::ApprovingComment(_))));
+ }
+
+ #[tokio::test]
+ async fn comment_without_approval_pattern_is_not_accepted() {
+ let result = TestScenario::single_commit()
+ .with_comments(vec![comment("bob", "looks good")])
+ .with_org_members(vec!["bob"])
+ .run_scenario()
+ .await;
+ assert!(matches!(result, Err(ReviewFailure::Unreviewed)));
+ }
+
+ #[tokio::test]
+ async fn commit_without_pr_number_is_no_pr_found() {
+ let result = TestScenario::single_commit()
+ .with_commit(make_commit(
+ "abc12345abc12345",
+ "Alice",
+ "alice@test.com",
+ "Fix thing without PR number",
+ "",
+ ))
+ .run_scenario()
+ .await;
+ assert!(matches!(result, Err(ReviewFailure::NoPullRequestFound)));
+ }
+
+ #[tokio::test]
+ async fn pr_review_takes_precedence_over_comment() {
+ let result = TestScenario::single_commit()
+ .with_reviews(vec![review("bob", ReviewState::Approved)])
+ .with_comments(vec![comment("charlie", "@zed-zippy approve")])
+ .with_org_members(vec!["bob", "charlie"])
+ .run_scenario()
+ .await;
+ assert!(matches!(result, Ok(ReviewSuccess::PullRequestReviewed(_))));
+ }
+
+ #[tokio::test]
+ async fn comment_takes_precedence_over_co_author() {
+ let result = TestScenario::single_commit()
+ .with_comments(vec![comment("bob", "@zed-zippy approve")])
+ .with_commit_authors_json(serde_json::json!({
+ "abc12345abc12345": {
+ "author": {
+ "name": "Alice",
+ "email": "alice@test.com",
+ "user": { "login": "alice" }
+ },
+ "authors": [{
+ "name": "Charlie",
+ "email": "charlie@test.com",
+ "user": { "login": "charlie" }
+ }]
+ }
+ }))
+ .with_commit(make_commit(
+ "abc12345abc12345",
+ "Alice",
+ "alice@test.com",
+ "Fix thing (#1234)",
+ "Co-authored-by: Charlie <charlie@test.com>",
+ ))
+ .with_org_members(vec!["bob", "charlie"])
+ .run_scenario()
+ .await;
+ assert!(matches!(result, Ok(ReviewSuccess::ApprovingComment(_))));
+ }
+
+ #[tokio::test]
+ async fn co_author_org_member_succeeds() {
+ let result = TestScenario::single_commit()
+ .with_commit_authors_json(serde_json::json!({
+ "abc12345abc12345": {
+ "author": {
+ "name": "Alice",
+ "email": "alice@test.com",
+ "user": { "login": "alice" }
+ },
+ "authors": [{
+ "name": "Bob",
+ "email": "bob@test.com",
+ "user": { "login": "bob" }
+ }]
+ }
+ }))
+ .with_commit(make_commit(
+ "abc12345abc12345",
+ "Alice",
+ "alice@test.com",
+ "Fix thing (#1234)",
+ "Co-authored-by: Bob <bob@test.com>",
+ ))
+ .with_org_members(vec!["bob"])
+ .run_scenario()
+ .await;
+ assert!(matches!(result, Ok(ReviewSuccess::CoAuthored(_))));
+ }
+
+ #[tokio::test]
+ async fn no_reviews_no_comments_no_coauthors_is_unreviewed() {
+ let result = TestScenario::single_commit().run_scenario().await;
+ assert!(matches!(result, Err(ReviewFailure::Unreviewed)));
+ }
+}
@@ -0,0 +1,591 @@
+#![allow(clippy::disallowed_methods, reason = "This is only used in xtasks")]
+use std::{
+ fmt::{self, Debug},
+ ops::Not,
+ process::Command,
+ str::FromStr,
+ sync::LazyLock,
+};
+
+use anyhow::{Context, Result, anyhow};
+use derive_more::{Deref, DerefMut, FromStr};
+
+use itertools::Itertools;
+use regex::Regex;
+use semver::Version;
+use serde::Deserialize;
+
+pub trait Subcommand {
+ type ParsedOutput: FromStr<Err = anyhow::Error>;
+
+ fn args(&self) -> impl IntoIterator<Item = String>;
+}
+
+#[derive(Deref, DerefMut)]
+pub struct GitCommand<G: Subcommand> {
+ #[deref]
+ #[deref_mut]
+ subcommand: G,
+}
+
+impl<G: Subcommand> GitCommand<G> {
+ #[must_use]
+ pub fn run(subcommand: G) -> Result<G::ParsedOutput> {
+ Self { subcommand }.run_impl()
+ }
+
+ fn run_impl(self) -> Result<G::ParsedOutput> {
+ let command_output = Command::new("git")
+ .args(self.subcommand.args())
+ .output()
+ .context("Failed to spawn command")?;
+
+ if command_output.status.success() {
+ String::from_utf8(command_output.stdout)
+ .map_err(|_| anyhow!("Invalid UTF8"))
+ .and_then(|s| {
+ G::ParsedOutput::from_str(s.trim())
+ .map_err(|e| anyhow!("Failed to parse from string: {e:?}"))
+ })
+ } else {
+ anyhow::bail!(
+ "Command failed with exit code {}, stderr: {}",
+ command_output.status.code().unwrap_or_default(),
+ String::from_utf8(command_output.stderr).unwrap_or_default()
+ )
+ }
+ }
+}
+
+#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub enum ReleaseChannel {
+ Stable,
+ Preview,
+}
+
+impl ReleaseChannel {
+ pub(crate) fn tag_suffix(&self) -> &'static str {
+ match self {
+ ReleaseChannel::Stable => "",
+ ReleaseChannel::Preview => "-pre",
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct VersionTag(Version, ReleaseChannel);
+
+impl VersionTag {
+ pub fn parse(input: &str) -> Result<Self, anyhow::Error> {
+ // Being a bit more lenient for human inputs
+ let version = input.strip_prefix('v').unwrap_or(input);
+
+ let (version_str, channel) = version
+ .strip_suffix("-pre")
+ .map_or((version, ReleaseChannel::Stable), |version_str| {
+ (version_str, ReleaseChannel::Preview)
+ });
+
+ Version::parse(version_str)
+ .map(|version| Self(version, channel))
+ .map_err(|_| anyhow::anyhow!("Failed to parse version from tag!"))
+ }
+
+ pub fn version(&self) -> &Version {
+ &self.0
+ }
+}
+
+impl ToString for VersionTag {
+ fn to_string(&self) -> String {
+ format!(
+ "v{version}{channel_suffix}",
+ version = self.0,
+ channel_suffix = self.1.tag_suffix()
+ )
+ }
+}
+
+#[derive(Debug, Deref, FromStr, PartialEq, Eq, Hash, Deserialize)]
+pub struct CommitSha(pub(crate) String);
+
+impl CommitSha {
+ pub fn short(&self) -> &str {
+ self.0.as_str().split_at(8).0
+ }
+}
+
+#[derive(Debug)]
+pub struct CommitDetails {
+ sha: CommitSha,
+ author: Committer,
+ title: String,
+ body: String,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct Committer {
+ name: String,
+ email: String,
+}
+
+impl Committer {
+ pub fn new(name: &str, email: &str) -> Self {
+ Self {
+ name: name.to_owned(),
+ email: email.to_owned(),
+ }
+ }
+}
+
+impl fmt::Display for Committer {
+ fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(formatter, "{} ({})", self.name, self.email)
+ }
+}
+
+impl CommitDetails {
+ const BODY_DELIMITER: &str = "|body-delimiter|";
+ const COMMIT_DELIMITER: &str = "|commit-delimiter|";
+ const FIELD_DELIMITER: &str = "|field-delimiter|";
+ const FORMAT_STRING: &str = "%H|field-delimiter|%an|field-delimiter|%ae|field-delimiter|%s|body-delimiter|%b|commit-delimiter|";
+
+ fn parse(line: &str, body: &str) -> Result<Self, anyhow::Error> {
+ let Some([sha, author_name, author_email, title]) =
+ line.splitn(4, Self::FIELD_DELIMITER).collect_array()
+ else {
+ return Err(anyhow!("Failed to parse commit fields from input {line}"));
+ };
+
+ Ok(CommitDetails {
+ sha: CommitSha(sha.to_owned()),
+ author: Committer::new(author_name, author_email),
+ title: title.to_owned(),
+ body: body.to_owned(),
+ })
+ }
+
+ pub fn pr_number(&self) -> Option<u64> {
+ // Since we use squash merge, all commit titles end with the '(#12345)' pattern.
+ // While we could strictly speaking index into this directly, go for a slightly
+ // less prone approach to errors
+ const PATTERN: &str = " (#";
+ self.title
+ .rfind(PATTERN)
+ .and_then(|location| {
+ self.title[location..]
+ .find(')')
+ .map(|relative_end| location + PATTERN.len()..location + relative_end)
+ })
+ .and_then(|range| self.title[range].parse().ok())
+ }
+
+ pub(crate) fn co_authors(&self) -> Option<Vec<Committer>> {
+ static CO_AUTHOR_REGEX: LazyLock<Regex> =
+ LazyLock::new(|| Regex::new(r"Co-authored-by: (.+) <(.+)>").unwrap());
+
+ let mut co_authors = Vec::new();
+
+ for cap in CO_AUTHOR_REGEX.captures_iter(&self.body.as_ref()) {
+ let Some((name, email)) = cap
+ .get(1)
+ .map(|m| m.as_str())
+ .zip(cap.get(2).map(|m| m.as_str()))
+ else {
+ continue;
+ };
+ co_authors.push(Committer::new(name, email));
+ }
+
+ co_authors.is_empty().not().then_some(co_authors)
+ }
+
+ pub(crate) fn author(&self) -> &Committer {
+ &self.author
+ }
+
+ pub(crate) fn title(&self) -> &str {
+ &self.title
+ }
+
+ pub(crate) fn sha(&self) -> &CommitSha {
+ &self.sha
+ }
+}
+
+#[derive(Debug, Deref, Default, DerefMut)]
+pub struct CommitList(Vec<CommitDetails>);
+
+impl CommitList {
+ pub fn range(&self) -> Option<String> {
+ self.0
+ .first()
+ .zip(self.0.last())
+ .map(|(first, last)| format!("{}..{}", first.sha().0, last.sha().0))
+ }
+}
+
+impl IntoIterator for CommitList {
+ type IntoIter = std::vec::IntoIter<CommitDetails>;
+ type Item = CommitDetails;
+
+ fn into_iter(self) -> std::vec::IntoIter<Self::Item> {
+ self.0.into_iter()
+ }
+}
+
+impl FromStr for CommitList {
+ type Err = anyhow::Error;
+
+ fn from_str(input: &str) -> Result<Self, Self::Err> {
+ Ok(CommitList(
+ input
+ .split(CommitDetails::COMMIT_DELIMITER)
+ .filter(|commit_details| !commit_details.is_empty())
+ .map(|commit_details| {
+ let (line, body) = commit_details
+ .trim()
+ .split_once(CommitDetails::BODY_DELIMITER)
+ .expect("Missing body delimiter");
+ CommitDetails::parse(line, body)
+ .expect("Parsing from the output should succeed")
+ })
+ .collect(),
+ ))
+ }
+}
+
+pub struct GetVersionTags;
+
+impl Subcommand for GetVersionTags {
+ type ParsedOutput = VersionTagList;
+
+ fn args(&self) -> impl IntoIterator<Item = String> {
+ ["tag", "-l", "v*"].map(ToOwned::to_owned)
+ }
+}
+
+pub struct VersionTagList(Vec<VersionTag>);
+
+impl VersionTagList {
+ pub fn sorted(mut self) -> Self {
+ self.0.sort_by(|a, b| a.version().cmp(b.version()));
+ self
+ }
+
+ pub fn find_previous_minor_version(&self, version_tag: &VersionTag) -> Option<&VersionTag> {
+ self.0
+ .iter()
+ .take_while(|tag| tag.version() < version_tag.version())
+ .collect_vec()
+ .into_iter()
+ .rev()
+ .find(|tag| {
+ (tag.version().major < version_tag.version().major
+ || (tag.version().major == version_tag.version().major
+ && tag.version().minor < version_tag.version().minor))
+ && tag.version().patch == 0
+ })
+ }
+}
+
+impl FromStr for VersionTagList {
+ type Err = anyhow::Error;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ let version_tags = s.lines().flat_map(VersionTag::parse).collect_vec();
+
+ version_tags
+ .is_empty()
+ .not()
+ .then_some(Self(version_tags))
+ .ok_or_else(|| anyhow::anyhow!("No version tags found"))
+ }
+}
+
+pub struct CommitsFromVersionToHead {
+ version_tag: VersionTag,
+ branch: String,
+}
+
+impl CommitsFromVersionToHead {
+ pub fn new(version_tag: VersionTag, branch: String) -> Self {
+ Self {
+ version_tag,
+ branch,
+ }
+ }
+}
+
+impl Subcommand for CommitsFromVersionToHead {
+ type ParsedOutput = CommitList;
+
+ fn args(&self) -> impl IntoIterator<Item = String> {
+ [
+ "log".to_string(),
+ format!("--pretty=format:{}", CommitDetails::FORMAT_STRING),
+ format!(
+ "{version}..{branch}",
+ version = self.version_tag.to_string(),
+ branch = self.branch
+ ),
+ ]
+ }
+}
+
+pub struct NoOutput;
+
+impl FromStr for NoOutput {
+ type Err = anyhow::Error;
+
+ fn from_str(_: &str) -> Result<Self, Self::Err> {
+ Ok(NoOutput)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use indoc::indoc;
+
+ #[test]
+ fn parse_stable_version_tag() {
+ let tag = VersionTag::parse("v0.172.8").unwrap();
+ assert_eq!(tag.version().major, 0);
+ assert_eq!(tag.version().minor, 172);
+ assert_eq!(tag.version().patch, 8);
+ assert_eq!(tag.1, ReleaseChannel::Stable);
+ }
+
+ #[test]
+ fn parse_preview_version_tag() {
+ let tag = VersionTag::parse("v0.172.1-pre").unwrap();
+ assert_eq!(tag.version().major, 0);
+ assert_eq!(tag.version().minor, 172);
+ assert_eq!(tag.version().patch, 1);
+ assert_eq!(tag.1, ReleaseChannel::Preview);
+ }
+
+ #[test]
+ fn parse_version_tag_without_v_prefix() {
+ let tag = VersionTag::parse("0.172.8").unwrap();
+ assert_eq!(tag.version().major, 0);
+ assert_eq!(tag.version().minor, 172);
+ assert_eq!(tag.version().patch, 8);
+ }
+
+ #[test]
+ fn parse_invalid_version_tag() {
+ let result = VersionTag::parse("vConradTest");
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn version_tag_stable_roundtrip() {
+ let tag = VersionTag::parse("v0.172.8").unwrap();
+ assert_eq!(tag.to_string(), "v0.172.8");
+ }
+
+ #[test]
+ fn version_tag_preview_roundtrip() {
+ let tag = VersionTag::parse("v0.172.1-pre").unwrap();
+ assert_eq!(tag.to_string(), "v0.172.1-pre");
+ }
+
+ #[test]
+ fn sorted_orders_by_semver() {
+ let input = indoc! {"
+ v0.172.8
+ v0.170.1
+ v0.171.4
+ v0.170.2
+ v0.172.11
+ v0.171.3
+ v0.172.9
+ "};
+ let list = VersionTagList::from_str(input).unwrap().sorted();
+ for window in list.0.windows(2) {
+ assert!(
+ window[0].version() <= window[1].version(),
+ "{} should come before {}",
+ window[0].to_string(),
+ window[1].to_string()
+ );
+ }
+ assert_eq!(list.0[0].to_string(), "v0.170.1");
+ assert_eq!(list.0[list.0.len() - 1].to_string(), "v0.172.11");
+ }
+
+ #[test]
+ fn find_previous_minor_for_173_returns_172() {
+ let input = indoc! {"
+ v0.170.1
+ v0.170.2
+ v0.171.3
+ v0.171.4
+ v0.172.0
+ v0.172.8
+ v0.172.9
+ v0.172.11
+ "};
+ let list = VersionTagList::from_str(input).unwrap().sorted();
+ let target = VersionTag::parse("v0.173.0").unwrap();
+ let previous = list.find_previous_minor_version(&target).unwrap();
+ assert_eq!(previous.version().major, 0);
+ assert_eq!(previous.version().minor, 172);
+ assert_eq!(previous.version().patch, 0);
+ }
+
+ #[test]
+ fn find_previous_minor_skips_same_minor() {
+ let input = indoc! {"
+ v0.172.8
+ v0.172.9
+ v0.172.11
+ "};
+ let list = VersionTagList::from_str(input).unwrap().sorted();
+ let target = VersionTag::parse("v0.172.8").unwrap();
+ assert!(list.find_previous_minor_version(&target).is_none());
+ }
+
+ #[test]
+ fn find_previous_minor_with_major_version_gap() {
+ let input = indoc! {"
+ v0.172.0
+ v0.172.9
+ v0.172.11
+ "};
+ let list = VersionTagList::from_str(input).unwrap().sorted();
+ let target = VersionTag::parse("v1.0.0").unwrap();
+ let previous = list.find_previous_minor_version(&target).unwrap();
+ assert_eq!(previous.to_string(), "v0.172.0");
+ }
+
+ #[test]
+ fn find_previous_minor_requires_zero_patch_version() {
+ let input = indoc! {"
+ v0.172.1
+ v0.172.9
+ v0.172.11
+ "};
+ let list = VersionTagList::from_str(input).unwrap().sorted();
+ let target = VersionTag::parse("v1.0.0").unwrap();
+ assert!(list.find_previous_minor_version(&target).is_none());
+ }
+
+ #[test]
+ fn parse_tag_list_from_real_tags() {
+ let input = indoc! {"
+ v0.9999-temporary
+ vConradTest
+ v0.172.8
+ "};
+ let list = VersionTagList::from_str(input).unwrap();
+ assert_eq!(list.0.len(), 1);
+ assert_eq!(list.0[0].to_string(), "v0.172.8");
+ }
+
+ #[test]
+ fn parse_empty_tag_list_fails() {
+ let result = VersionTagList::from_str("");
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn pr_number_from_squash_merge_title() {
+ let line = format!(
+ "abc123{d}Author Name{d}author@email.com{d}Add cool feature (#12345)",
+ d = CommitDetails::FIELD_DELIMITER
+ );
+ let commit = CommitDetails::parse(&line, "").unwrap();
+ assert_eq!(commit.pr_number(), Some(12345));
+ }
+
+ #[test]
+ fn pr_number_missing() {
+ let line = format!(
+ "abc123{d}Author Name{d}author@email.com{d}Some commit without PR ref",
+ d = CommitDetails::FIELD_DELIMITER
+ );
+ let commit = CommitDetails::parse(&line, "").unwrap();
+ assert_eq!(commit.pr_number(), None);
+ }
+
+ #[test]
+ fn pr_number_takes_last_match() {
+ let line = format!(
+ "abc123{d}Author Name{d}author@email.com{d}Fix (#123) and refactor (#456)",
+ d = CommitDetails::FIELD_DELIMITER
+ );
+ let commit = CommitDetails::parse(&line, "").unwrap();
+ assert_eq!(commit.pr_number(), Some(456));
+ }
+
+ #[test]
+ fn co_authors_parsed_from_body() {
+ let line = format!(
+ "abc123{d}Author Name{d}author@email.com{d}Some title",
+ d = CommitDetails::FIELD_DELIMITER
+ );
+ let body = indoc! {"
+ Co-authored-by: Alice Smith <alice@example.com>
+ Co-authored-by: Bob Jones <bob@example.com>
+ "};
+ let commit = CommitDetails::parse(&line, body).unwrap();
+ let co_authors = commit.co_authors().unwrap();
+ assert_eq!(co_authors.len(), 2);
+ assert_eq!(
+ co_authors[0],
+ Committer::new("Alice Smith", "alice@example.com")
+ );
+ assert_eq!(
+ co_authors[1],
+ Committer::new("Bob Jones", "bob@example.com")
+ );
+ }
+
+ #[test]
+ fn no_co_authors_returns_none() {
+ let line = format!(
+ "abc123{d}Author Name{d}author@email.com{d}Some title",
+ d = CommitDetails::FIELD_DELIMITER
+ );
+ let commit = CommitDetails::parse(&line, "").unwrap();
+ assert!(commit.co_authors().is_none());
+ }
+
+ #[test]
+ fn commit_sha_short_returns_first_8_chars() {
+ let sha = CommitSha("abcdef1234567890abcdef1234567890abcdef12".into());
+ assert_eq!(sha.short(), "abcdef12");
+ }
+
+ #[test]
+ fn parse_commit_list_from_git_log_format() {
+ let fd = CommitDetails::FIELD_DELIMITER;
+ let bd = CommitDetails::BODY_DELIMITER;
+ let cd = CommitDetails::COMMIT_DELIMITER;
+
+ let input = format!(
+ "sha111{fd}Alice{fd}alice@test.com{fd}First commit (#100){bd}First body{cd}sha222{fd}Bob{fd}bob@test.com{fd}Second commit (#200){bd}Second body{cd}"
+ );
+
+ let list = CommitList::from_str(&input).unwrap();
+ assert_eq!(list.0.len(), 2);
+
+ assert_eq!(list.0[0].sha().0, "sha111");
+ assert_eq!(
+ list.0[0].author(),
+ &Committer::new("Alice", "alice@test.com")
+ );
+ assert_eq!(list.0[0].title(), "First commit (#100)");
+ assert_eq!(list.0[0].pr_number(), Some(100));
+ assert_eq!(list.0[0].body, "First body");
+
+ assert_eq!(list.0[1].sha().0, "sha222");
+ assert_eq!(list.0[1].author(), &Committer::new("Bob", "bob@test.com"));
+ assert_eq!(list.0[1].title(), "Second commit (#200)");
+ assert_eq!(list.0[1].pr_number(), Some(200));
+ assert_eq!(list.0[1].body, "Second body");
+ }
+}
@@ -0,0 +1,424 @@
+use std::{collections::HashMap, fmt, ops::Not, rc::Rc};
+
+use anyhow::Result;
+use derive_more::Deref;
+use serde::Deserialize;
+
+use crate::git::CommitSha;
+
+pub const PR_REVIEW_LABEL: &str = "PR state:needs review";
+
+#[derive(Debug, Clone)]
+pub struct GitHubUser {
+ pub login: String,
+}
+
+#[derive(Debug, Clone)]
+pub struct PullRequestData {
+ pub number: u64,
+ pub user: Option<GitHubUser>,
+ pub merged_by: Option<GitHubUser>,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ReviewState {
+ Approved,
+ Other,
+}
+
+#[derive(Debug, Clone)]
+pub struct PullRequestReview {
+ pub user: Option<GitHubUser>,
+ pub state: Option<ReviewState>,
+}
+
+#[derive(Debug, Clone)]
+pub struct PullRequestComment {
+ pub user: GitHubUser,
+ pub body: Option<String>,
+}
+
+#[derive(Debug, Deserialize, Clone, Deref, PartialEq, Eq)]
+pub struct GithubLogin {
+ login: String,
+}
+
+impl GithubLogin {
+ pub(crate) fn new(login: String) -> Self {
+ Self { login }
+ }
+}
+
+impl fmt::Display for GithubLogin {
+ fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(formatter, "@{}", self.login)
+ }
+}
+
+#[derive(Debug, Deserialize, Clone)]
+pub struct CommitAuthor {
+ name: String,
+ email: String,
+ user: Option<GithubLogin>,
+}
+
+impl CommitAuthor {
+ pub(crate) fn user(&self) -> Option<&GithubLogin> {
+ self.user.as_ref()
+ }
+}
+
+impl PartialEq for CommitAuthor {
+ fn eq(&self, other: &Self) -> bool {
+ self.user.as_ref().zip(other.user.as_ref()).map_or_else(
+ || self.email == other.email || self.name == other.name,
+ |(l, r)| l == r,
+ )
+ }
+}
+
+impl fmt::Display for CommitAuthor {
+ fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self.user.as_ref() {
+ Some(user) => write!(formatter, "{} ({user})", self.name),
+ None => write!(formatter, "{} ({})", self.name, self.email),
+ }
+ }
+}
+
+#[derive(Debug, Deserialize)]
+pub struct CommitAuthors {
+ #[serde(rename = "author")]
+ primary_author: CommitAuthor,
+ #[serde(rename = "authors")]
+ co_authors: Vec<CommitAuthor>,
+}
+
+impl CommitAuthors {
+ pub fn co_authors(&self) -> Option<impl Iterator<Item = &CommitAuthor>> {
+ self.co_authors.is_empty().not().then(|| {
+ self.co_authors
+ .iter()
+ .filter(|co_author| *co_author != &self.primary_author)
+ })
+ }
+}
+
+#[derive(Debug, Deserialize, Deref)]
+pub struct AuthorsForCommits(HashMap<CommitSha, CommitAuthors>);
+
+#[async_trait::async_trait(?Send)]
+pub trait GitHubApiClient {
+ async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData>;
+ async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>>;
+ async fn get_pull_request_comments(&self, pr_number: u64) -> Result<Vec<PullRequestComment>>;
+ async fn get_commit_authors(&self, commit_shas: &[&CommitSha]) -> Result<AuthorsForCommits>;
+ async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool>;
+ async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()>;
+}
+
+pub struct GitHubClient {
+ api: Rc<dyn GitHubApiClient>,
+}
+
+impl GitHubClient {
+ pub fn new(api: Rc<dyn GitHubApiClient>) -> Self {
+ Self { api }
+ }
+
+ #[cfg(feature = "octo-client")]
+ pub async fn for_app(app_id: u64, app_private_key: &str) -> Result<Self> {
+ let client = OctocrabClient::new(app_id, app_private_key).await?;
+ Ok(Self::new(Rc::new(client)))
+ }
+
+ pub async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData> {
+ self.api.get_pull_request(pr_number).await
+ }
+
+ pub async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>> {
+ self.api.get_pull_request_reviews(pr_number).await
+ }
+
+ pub async fn get_pull_request_comments(
+ &self,
+ pr_number: u64,
+ ) -> Result<Vec<PullRequestComment>> {
+ self.api.get_pull_request_comments(pr_number).await
+ }
+
+ pub async fn get_commit_authors<'a>(
+ &self,
+ commit_shas: impl IntoIterator<Item = &'a CommitSha>,
+ ) -> Result<AuthorsForCommits> {
+ let shas: Vec<&CommitSha> = commit_shas.into_iter().collect();
+ self.api.get_commit_authors(&shas).await
+ }
+
+ pub async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool> {
+ self.api.check_org_membership(login).await
+ }
+
+ pub async fn add_label_to_pull_request(&self, label: &str, pr_number: u64) -> Result<()> {
+ self.api
+ .ensure_pull_request_has_label(label, pr_number)
+ .await
+ }
+}
+
+#[cfg(feature = "octo-client")]
+mod octo_client {
+ use anyhow::{Context, Result};
+ use futures::TryStreamExt as _;
+ use itertools::Itertools;
+ use jsonwebtoken::EncodingKey;
+ use octocrab::{
+ Octocrab, Page, models::pulls::ReviewState as OctocrabReviewState,
+ service::middleware::cache::mem::InMemoryCache,
+ };
+ use serde::de::DeserializeOwned;
+ use tokio::pin;
+
+ use crate::git::CommitSha;
+
+ use super::{
+ AuthorsForCommits, GitHubApiClient, GitHubUser, GithubLogin, PullRequestComment,
+ PullRequestData, PullRequestReview, ReviewState,
+ };
+
+ const PAGE_SIZE: u8 = 100;
+ const ORG: &str = "zed-industries";
+ const REPO: &str = "zed";
+
+ pub struct OctocrabClient {
+ client: Octocrab,
+ }
+
+ impl OctocrabClient {
+ pub async fn new(app_id: u64, app_private_key: &str) -> Result<Self> {
+ let octocrab = Octocrab::builder()
+ .cache(InMemoryCache::new())
+ .app(
+ app_id.into(),
+ EncodingKey::from_rsa_pem(app_private_key.as_bytes())?,
+ )
+ .build()?;
+
+ let installations = octocrab
+ .apps()
+ .installations()
+ .send()
+ .await
+ .context("Failed to fetch installations")?
+ .take_items();
+
+ let installation_id = installations
+ .into_iter()
+ .find(|installation| installation.account.login == ORG)
+ .context("Could not find Zed repository in installations")?
+ .id;
+
+ let client = octocrab.installation(installation_id)?;
+ Ok(Self { client })
+ }
+
+ fn build_co_authors_query<'a>(shas: impl IntoIterator<Item = &'a CommitSha>) -> String {
+ const FRAGMENT: &str = r#"
+ ... on Commit {
+ author {
+ name
+ email
+ user { login }
+ }
+ authors(first: 10) {
+ nodes {
+ name
+ email
+ user { login }
+ }
+ }
+ }
+ "#;
+
+ let objects: String = shas
+ .into_iter()
+ .map(|commit_sha| {
+ format!(
+ "commit{sha}: object(oid: \"{sha}\") {{ {FRAGMENT} }}",
+ sha = **commit_sha
+ )
+ })
+ .join("\n");
+
+ format!("{{ repository(owner: \"{ORG}\", name: \"{REPO}\") {{ {objects} }} }}")
+ .replace("\n", "")
+ }
+
+ async fn graphql<R: octocrab::FromResponse>(
+ &self,
+ query: &serde_json::Value,
+ ) -> octocrab::Result<R> {
+ self.client.graphql(query).await
+ }
+
+ async fn get_all<T: DeserializeOwned + 'static>(
+ &self,
+ page: Page<T>,
+ ) -> octocrab::Result<Vec<T>> {
+ self.get_filtered(page, |_| true).await
+ }
+
+ async fn get_filtered<T: DeserializeOwned + 'static>(
+ &self,
+ page: Page<T>,
+ predicate: impl Fn(&T) -> bool,
+ ) -> octocrab::Result<Vec<T>> {
+ let stream = page.into_stream(&self.client);
+ pin!(stream);
+
+ let mut results = Vec::new();
+
+ while let Some(item) = stream.try_next().await?
+ && predicate(&item)
+ {
+ results.push(item);
+ }
+
+ Ok(results)
+ }
+ }
+
+ #[async_trait::async_trait(?Send)]
+ impl GitHubApiClient for OctocrabClient {
+ async fn get_pull_request(&self, pr_number: u64) -> Result<PullRequestData> {
+ let pr = self.client.pulls(ORG, REPO).get(pr_number).await?;
+ Ok(PullRequestData {
+ number: pr.number,
+ user: pr.user.map(|user| GitHubUser { login: user.login }),
+ merged_by: pr.merged_by.map(|user| GitHubUser { login: user.login }),
+ })
+ }
+
+ async fn get_pull_request_reviews(&self, pr_number: u64) -> Result<Vec<PullRequestReview>> {
+ let page = self
+ .client
+ .pulls(ORG, REPO)
+ .list_reviews(pr_number)
+ .per_page(PAGE_SIZE)
+ .send()
+ .await?;
+
+ let reviews = self.get_all(page).await?;
+
+ Ok(reviews
+ .into_iter()
+ .map(|review| PullRequestReview {
+ user: review.user.map(|user| GitHubUser { login: user.login }),
+ state: review.state.map(|state| match state {
+ OctocrabReviewState::Approved => ReviewState::Approved,
+ _ => ReviewState::Other,
+ }),
+ })
+ .collect())
+ }
+
+ async fn get_pull_request_comments(
+ &self,
+ pr_number: u64,
+ ) -> Result<Vec<PullRequestComment>> {
+ let page = self
+ .client
+ .issues(ORG, REPO)
+ .list_comments(pr_number)
+ .per_page(PAGE_SIZE)
+ .send()
+ .await?;
+
+ let comments = self.get_all(page).await?;
+
+ Ok(comments
+ .into_iter()
+ .map(|comment| PullRequestComment {
+ user: GitHubUser {
+ login: comment.user.login,
+ },
+ body: comment.body,
+ })
+ .collect())
+ }
+
+ async fn get_commit_authors(
+ &self,
+ commit_shas: &[&CommitSha],
+ ) -> Result<AuthorsForCommits> {
+ let query = Self::build_co_authors_query(commit_shas.iter().copied());
+ let query = serde_json::json!({ "query": query });
+ let mut response = self.graphql::<serde_json::Value>(&query).await?;
+
+ response
+ .get_mut("data")
+ .and_then(|data| data.get_mut("repository"))
+ .and_then(|repo| repo.as_object_mut())
+ .ok_or_else(|| anyhow::anyhow!("Unexpected response format!"))
+ .and_then(|commit_data| {
+ let mut response_map = serde_json::Map::with_capacity(commit_data.len());
+
+ for (key, value) in commit_data.iter_mut() {
+ let key_without_prefix = key.strip_prefix("commit").unwrap_or(key);
+ if let Some(authors) = value.get_mut("authors") {
+ if let Some(nodes) = authors.get("nodes") {
+ *authors = nodes.clone();
+ }
+ }
+
+ response_map.insert(key_without_prefix.to_owned(), value.clone());
+ }
+
+ serde_json::from_value(serde_json::Value::Object(response_map))
+ .context("Failed to deserialize commit authors")
+ })
+ }
+
+ async fn check_org_membership(&self, login: &GithubLogin) -> Result<bool> {
+ let page = self
+ .client
+ .orgs(ORG)
+ .list_members()
+ .per_page(PAGE_SIZE)
+ .send()
+ .await?;
+
+ let members = self.get_all(page).await?;
+
+ Ok(members
+ .into_iter()
+ .any(|member| member.login == login.as_str()))
+ }
+
+ async fn ensure_pull_request_has_label(&self, label: &str, pr_number: u64) -> Result<()> {
+ if self
+ .get_filtered(
+ self.client
+ .issues(ORG, REPO)
+ .list_labels_for_issue(pr_number)
+ .per_page(PAGE_SIZE)
+ .send()
+ .await?,
+ |pr_label| pr_label.name == label,
+ )
+ .await
+ .is_ok_and(|l| l.is_empty())
+ {
+ self.client
+ .issues(ORG, REPO)
+ .add_labels(pr_number, &[label.to_owned()])
+ .await?;
+ }
+
+ Ok(())
+ }
+ }
+}
+
+#[cfg(feature = "octo-client")]
+pub use octo_client::OctocrabClient;
@@ -0,0 +1,4 @@
+pub mod checks;
+pub mod git;
+pub mod github;
+pub mod report;
@@ -0,0 +1,446 @@
+use std::{
+ fs::{self, File},
+ io::{BufWriter, Write},
+ path::Path,
+};
+
+use anyhow::Context as _;
+use derive_more::Display;
+use itertools::{Either, Itertools};
+
+use crate::{
+ checks::{ReviewFailure, ReviewResult, ReviewSuccess},
+ git::CommitDetails,
+};
+
+const PULL_REQUEST_BASE_URL: &str = "https://github.com/zed-industries/zed/pull";
+
+#[derive(Debug)]
+pub struct ReportEntry<R> {
+ pub commit: CommitDetails,
+ reason: R,
+}
+
+impl<R: ToString> ReportEntry<R> {
+ fn commit_cell(&self) -> String {
+ let title = escape_markdown_link_text(self.commit.title());
+
+ match self.commit.pr_number() {
+ Some(pr_number) => format!("[{title}]({PULL_REQUEST_BASE_URL}/{pr_number})"),
+ None => escape_markdown_table_text(self.commit.title()),
+ }
+ }
+
+ fn pull_request_cell(&self) -> String {
+ self.commit
+ .pr_number()
+ .map(|pr_number| format!("#{pr_number}"))
+ .unwrap_or_else(|| "—".to_owned())
+ }
+
+ fn author_cell(&self) -> String {
+ escape_markdown_table_text(&self.commit.author().to_string())
+ }
+
+ fn reason_cell(&self) -> String {
+ escape_markdown_table_text(&self.reason.to_string())
+ }
+}
+
+impl ReportEntry<ReviewFailure> {
+ fn issue_kind(&self) -> IssueKind {
+ match self.reason {
+ ReviewFailure::Other(_) => IssueKind::Error,
+ _ => IssueKind::NotReviewed,
+ }
+ }
+}
+
+impl ReportEntry<ReviewSuccess> {
+ fn reviewers_cell(&self) -> String {
+ match &self.reason.reviewers() {
+ Ok(reviewers) => escape_markdown_table_text(&reviewers),
+ Err(_) => "—".to_owned(),
+ }
+ }
+}
+
+#[derive(Debug, Default)]
+pub struct ReportSummary {
+ pub pull_requests: usize,
+ pub reviewed: usize,
+ pub not_reviewed: usize,
+ pub errors: usize,
+}
+
+pub enum ReportReviewSummary {
+ MissingReviews,
+ MissingReviewsWithErrors,
+ NoIssuesFound,
+}
+
+impl ReportSummary {
+ fn from_entries(entries: &[ReportEntry<ReviewResult>]) -> Self {
+ Self {
+ pull_requests: entries
+ .iter()
+ .filter_map(|entry| entry.commit.pr_number())
+ .unique()
+ .count(),
+ reviewed: entries.iter().filter(|entry| entry.reason.is_ok()).count(),
+ not_reviewed: entries
+ .iter()
+ .filter(|entry| {
+ matches!(
+ entry.reason,
+ Err(ReviewFailure::NoPullRequestFound | ReviewFailure::Unreviewed)
+ )
+ })
+ .count(),
+ errors: entries
+ .iter()
+ .filter(|entry| matches!(entry.reason, Err(ReviewFailure::Other(_))))
+ .count(),
+ }
+ }
+
+ pub fn review_summary(&self) -> ReportReviewSummary {
+ match self.not_reviewed {
+ 0 if self.errors == 0 => ReportReviewSummary::NoIssuesFound,
+ 1.. if self.errors == 0 => ReportReviewSummary::MissingReviews,
+ _ => ReportReviewSummary::MissingReviewsWithErrors,
+ }
+ }
+
+ fn has_errors(&self) -> bool {
+ self.errors > 0
+ }
+}
+
+#[derive(Clone, Copy, Debug, Display, PartialEq, Eq, PartialOrd, Ord)]
+enum IssueKind {
+ #[display("Error")]
+ Error,
+ #[display("Not reviewed")]
+ NotReviewed,
+}
+
+#[derive(Debug, Default)]
+pub struct Report {
+ entries: Vec<ReportEntry<ReviewResult>>,
+}
+
+impl Report {
+ pub fn new() -> Self {
+ Self::default()
+ }
+
+ pub fn add(&mut self, commit: CommitDetails, result: ReviewResult) {
+ self.entries.push(ReportEntry {
+ commit,
+ reason: result,
+ });
+ }
+
+ pub fn errors(&self) -> impl Iterator<Item = &ReportEntry<ReviewResult>> {
+ self.entries.iter().filter(|entry| entry.reason.is_err())
+ }
+
+ pub fn summary(&self) -> ReportSummary {
+ ReportSummary::from_entries(&self.entries)
+ }
+
+ pub fn write_markdown(self, path: impl AsRef<Path>) -> anyhow::Result<()> {
+ let path = path.as_ref();
+
+ if let Some(parent) = path
+ .parent()
+ .filter(|parent| !parent.as_os_str().is_empty())
+ {
+ fs::create_dir_all(parent).with_context(|| {
+ format!(
+ "Failed to create parent directory for markdown report at {}",
+ path.display()
+ )
+ })?;
+ }
+
+ let summary = self.summary();
+ let (successes, mut issues): (Vec<_>, Vec<_>) =
+ self.entries
+ .into_iter()
+ .partition_map(|entry| match entry.reason {
+ Ok(success) => Either::Left(ReportEntry {
+ reason: success,
+ commit: entry.commit,
+ }),
+ Err(fail) => Either::Right(ReportEntry {
+ reason: fail,
+ commit: entry.commit,
+ }),
+ });
+
+ issues.sort_by_key(|entry| entry.issue_kind());
+
+ let file = File::create(path)
+ .with_context(|| format!("Failed to create markdown report at {}", path.display()))?;
+ let mut writer = BufWriter::new(file);
+
+ writeln!(writer, "# Compliance report")?;
+ writeln!(writer)?;
+ writeln!(writer, "## Overview")?;
+ writeln!(writer)?;
+ writeln!(writer, "- PRs: {}", summary.pull_requests)?;
+ writeln!(writer, "- Reviewed: {}", summary.reviewed)?;
+ writeln!(writer, "- Not reviewed: {}", summary.not_reviewed)?;
+ if summary.has_errors() {
+ writeln!(writer, "- Errors: {}", summary.errors)?;
+ }
+ writeln!(writer)?;
+
+ write_issue_table(&mut writer, &issues, &summary)?;
+ write_success_table(&mut writer, &successes)?;
+
+ writer
+ .flush()
+ .with_context(|| format!("Failed to flush markdown report to {}", path.display()))
+ }
+}
+
+fn write_issue_table(
+ writer: &mut impl Write,
+ issues: &[ReportEntry<ReviewFailure>],
+ summary: &ReportSummary,
+) -> std::io::Result<()> {
+ if summary.has_errors() {
+ writeln!(writer, "## Errors and unreviewed commits")?;
+ } else {
+ writeln!(writer, "## Unreviewed commits")?;
+ }
+ writeln!(writer)?;
+
+ if issues.is_empty() {
+ if summary.has_errors() {
+ writeln!(writer, "No errors or unreviewed commits found.")?;
+ } else {
+ writeln!(writer, "No unreviewed commits found.")?;
+ }
+ writeln!(writer)?;
+ return Ok(());
+ }
+
+ writeln!(writer, "| Commit | PR | Author | Outcome | Reason |")?;
+ writeln!(writer, "| --- | --- | --- | --- | --- |")?;
+
+ for entry in issues {
+ let issue_kind = entry.issue_kind();
+ writeln!(
+ writer,
+ "| {} | {} | {} | {} | {} |",
+ entry.commit_cell(),
+ entry.pull_request_cell(),
+ entry.author_cell(),
+ issue_kind,
+ entry.reason_cell(),
+ )?;
+ }
+
+ writeln!(writer)?;
+ Ok(())
+}
+
+fn write_success_table(
+ writer: &mut impl Write,
+ successful_entries: &[ReportEntry<ReviewSuccess>],
+) -> std::io::Result<()> {
+ writeln!(writer, "## Successful commits")?;
+ writeln!(writer)?;
+
+ if successful_entries.is_empty() {
+ writeln!(writer, "No successful commits found.")?;
+ writeln!(writer)?;
+ return Ok(());
+ }
+
+ writeln!(writer, "| Commit | PR | Author | Reviewers | Reason |")?;
+ writeln!(writer, "| --- | --- | --- | --- | --- |")?;
+
+ for entry in successful_entries {
+ writeln!(
+ writer,
+ "| {} | {} | {} | {} | {} |",
+ entry.commit_cell(),
+ entry.pull_request_cell(),
+ entry.author_cell(),
+ entry.reviewers_cell(),
+ entry.reason_cell(),
+ )?;
+ }
+
+ writeln!(writer)?;
+ Ok(())
+}
+
+fn escape_markdown_link_text(input: &str) -> String {
+ escape_markdown_table_text(input)
+ .replace('[', r"\[")
+ .replace(']', r"\]")
+}
+
+fn escape_markdown_table_text(input: &str) -> String {
+ input
+ .replace('\\', r"\\")
+ .replace('|', r"\|")
+ .replace('\r', "")
+ .replace('\n', "<br>")
+}
+
+#[cfg(test)]
+mod tests {
+ use std::str::FromStr;
+
+ use crate::{
+ checks::{ReviewFailure, ReviewSuccess},
+ git::{CommitDetails, CommitList},
+ github::{GitHubUser, PullRequestReview, ReviewState},
+ };
+
+ use super::{Report, ReportReviewSummary};
+
+ fn make_commit(
+ sha: &str,
+ author_name: &str,
+ author_email: &str,
+ title: &str,
+ body: &str,
+ ) -> CommitDetails {
+ let formatted = format!(
+ "{sha}|field-delimiter|{author_name}|field-delimiter|{author_email}|field-delimiter|{title}|body-delimiter|{body}|commit-delimiter|"
+ );
+ CommitList::from_str(&formatted)
+ .expect("test commit should parse")
+ .into_iter()
+ .next()
+ .expect("should have one commit")
+ }
+
+ fn reviewed() -> ReviewSuccess {
+ ReviewSuccess::PullRequestReviewed(vec![PullRequestReview {
+ user: Some(GitHubUser {
+ login: "reviewer".to_owned(),
+ }),
+ state: Some(ReviewState::Approved),
+ }])
+ }
+
+ #[test]
+ fn report_summary_counts_are_accurate() {
+ let mut report = Report::new();
+
+ report.add(
+ make_commit(
+ "aaa",
+ "Alice",
+ "alice@test.com",
+ "Reviewed commit (#100)",
+ "",
+ ),
+ Ok(reviewed()),
+ );
+ report.add(
+ make_commit("bbb", "Bob", "bob@test.com", "Unreviewed commit (#200)", ""),
+ Err(ReviewFailure::Unreviewed),
+ );
+ report.add(
+ make_commit("ccc", "Carol", "carol@test.com", "No PR commit", ""),
+ Err(ReviewFailure::NoPullRequestFound),
+ );
+ report.add(
+ make_commit("ddd", "Dave", "dave@test.com", "Error commit (#300)", ""),
+ Err(ReviewFailure::Other(anyhow::anyhow!("some error"))),
+ );
+
+ let summary = report.summary();
+ assert_eq!(summary.pull_requests, 3);
+ assert_eq!(summary.reviewed, 1);
+ assert_eq!(summary.not_reviewed, 2);
+ assert_eq!(summary.errors, 1);
+ }
+
+ #[test]
+ fn report_summary_all_reviewed_is_no_issues() {
+ let mut report = Report::new();
+
+ report.add(
+ make_commit("aaa", "Alice", "alice@test.com", "First (#100)", ""),
+ Ok(reviewed()),
+ );
+ report.add(
+ make_commit("bbb", "Bob", "bob@test.com", "Second (#200)", ""),
+ Ok(reviewed()),
+ );
+
+ let summary = report.summary();
+ assert!(matches!(
+ summary.review_summary(),
+ ReportReviewSummary::NoIssuesFound
+ ));
+ }
+
+ #[test]
+ fn report_summary_missing_reviews_only() {
+ let mut report = Report::new();
+
+ report.add(
+ make_commit("aaa", "Alice", "alice@test.com", "Reviewed (#100)", ""),
+ Ok(reviewed()),
+ );
+ report.add(
+ make_commit("bbb", "Bob", "bob@test.com", "Unreviewed (#200)", ""),
+ Err(ReviewFailure::Unreviewed),
+ );
+
+ let summary = report.summary();
+ assert!(matches!(
+ summary.review_summary(),
+ ReportReviewSummary::MissingReviews
+ ));
+ }
+
+ #[test]
+ fn report_summary_errors_and_missing_reviews() {
+ let mut report = Report::new();
+
+ report.add(
+ make_commit("aaa", "Alice", "alice@test.com", "Unreviewed (#100)", ""),
+ Err(ReviewFailure::Unreviewed),
+ );
+ report.add(
+ make_commit("bbb", "Bob", "bob@test.com", "Errored (#200)", ""),
+ Err(ReviewFailure::Other(anyhow::anyhow!("check failed"))),
+ );
+
+ let summary = report.summary();
+ assert!(matches!(
+ summary.review_summary(),
+ ReportReviewSummary::MissingReviewsWithErrors
+ ));
+ }
+
+ #[test]
+ fn report_summary_deduplicates_pull_requests() {
+ let mut report = Report::new();
+
+ report.add(
+ make_commit("aaa", "Alice", "alice@test.com", "First change (#100)", ""),
+ Ok(reviewed()),
+ );
+ report.add(
+ make_commit("bbb", "Bob", "bob@test.com", "Second change (#100)", ""),
+ Ok(reviewed()),
+ );
+
+ let summary = report.summary();
+ assert_eq!(summary.pull_requests, 1);
+ }
+}
@@ -15,7 +15,8 @@ backtrace.workspace = true
cargo_metadata.workspace = true
cargo_toml.workspace = true
clap = { workspace = true, features = ["derive"] }
-toml.workspace = true
+compliance = { workspace = true, features = ["octo-client"] }
+gh-workflow.workspace = true
indoc.workspace = true
indexmap.workspace = true
itertools.workspace = true
@@ -24,5 +25,6 @@ serde.workspace = true
serde_json.workspace = true
serde_yaml = "0.9.34"
strum.workspace = true
+tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
+toml.workspace = true
toml_edit.workspace = true
-gh-workflow.workspace = true
@@ -15,6 +15,7 @@ struct Args {
enum CliCommand {
/// Runs `cargo clippy`.
Clippy(tasks::clippy::ClippyArgs),
+ Compliance(tasks::compliance::ComplianceArgs),
Licenses(tasks::licenses::LicensesArgs),
/// Checks that packages conform to a set of standards.
PackageConformity(tasks::package_conformity::PackageConformityArgs),
@@ -31,6 +32,7 @@ fn main() -> Result<()> {
match args.command {
CliCommand::Clippy(args) => tasks::clippy::run_clippy(args),
+ CliCommand::Compliance(args) => tasks::compliance::check_compliance(args),
CliCommand::Licenses(args) => tasks::licenses::run_licenses(args),
CliCommand::PackageConformity(args) => {
tasks::package_conformity::run_package_conformity(args)
@@ -1,4 +1,5 @@
pub mod clippy;
+pub mod compliance;
pub mod licenses;
pub mod package_conformity;
pub mod publish_gpui;
@@ -0,0 +1,135 @@
+use std::path::PathBuf;
+
+use anyhow::{Context, Result};
+use clap::Parser;
+
+use compliance::{
+ checks::Reporter,
+ git::{CommitsFromVersionToHead, GetVersionTags, GitCommand, VersionTag},
+ github::GitHubClient,
+ report::ReportReviewSummary,
+};
+
+#[derive(Parser)]
+pub struct ComplianceArgs {
+ #[arg(value_parser = VersionTag::parse)]
+ // The version to be on the lookout for
+ pub(crate) version_tag: VersionTag,
+ #[arg(long)]
+ // The markdown file to write the compliance report to
+ report_path: PathBuf,
+ #[arg(long)]
+ // An optional branch to use instead of the determined version branch
+ branch: Option<String>,
+}
+
+impl ComplianceArgs {
+ pub(crate) fn version_tag(&self) -> &VersionTag {
+ &self.version_tag
+ }
+
+ fn version_branch(&self) -> String {
+ self.branch.clone().unwrap_or_else(|| {
+ format!(
+ "v{major}.{minor}.x",
+ major = self.version_tag().version().major,
+ minor = self.version_tag().version().minor
+ )
+ })
+ }
+}
+
+async fn check_compliance_impl(args: ComplianceArgs) -> Result<()> {
+ let app_id = std::env::var("GITHUB_APP_ID").context("Missing GITHUB_APP_ID")?;
+ let key = std::env::var("GITHUB_APP_KEY").context("Missing GITHUB_APP_KEY")?;
+
+ let tag = args.version_tag();
+
+ let previous_version = GitCommand::run(GetVersionTags)?
+ .sorted()
+ .find_previous_minor_version(&tag)
+ .cloned()
+ .ok_or_else(|| {
+ anyhow::anyhow!(
+ "Could not find previous version for tag {tag}",
+ tag = tag.to_string()
+ )
+ })?;
+
+ println!(
+ "Checking compliance for version {} with version {} as base",
+ tag.version(),
+ previous_version.version()
+ );
+
+ let commits = GitCommand::run(CommitsFromVersionToHead::new(
+ previous_version,
+ args.version_branch(),
+ ))?;
+
+ let Some(range) = commits.range() else {
+ anyhow::bail!("No commits found to check");
+ };
+
+ println!("Checking commit range {range}, {} total", commits.len());
+
+ let client = GitHubClient::for_app(
+ app_id.parse().context("Failed to parse app ID as int")?,
+ key.as_ref(),
+ )
+ .await?;
+
+ println!("Initialized GitHub client for app ID {app_id}");
+
+ let report = Reporter::new(commits, &client).generate_report().await?;
+
+ println!(
+ "Generated report for version {}",
+ args.version_tag().to_string()
+ );
+
+ let summary = report.summary();
+
+ println!(
+ "Applying compliance labels to {} pull requests",
+ summary.pull_requests
+ );
+
+ for report in report.errors() {
+ if let Some(pr_number) = report.commit.pr_number() {
+ println!("Adding review label to PR {}...", pr_number);
+
+ client
+ .add_label_to_pull_request(compliance::github::PR_REVIEW_LABEL, pr_number)
+ .await?;
+ }
+ }
+
+ let report_path = args.report_path.with_extension("md");
+
+ report.write_markdown(&report_path)?;
+
+ println!("Wrote compliance report to {}", report_path.display());
+
+ match summary.review_summary() {
+ ReportReviewSummary::MissingReviews => Err(anyhow::anyhow!(
+ "Compliance check failed, found {} commits not reviewed",
+ summary.not_reviewed
+ )),
+ ReportReviewSummary::MissingReviewsWithErrors => Err(anyhow::anyhow!(
+ "Compliance check failed with {} unreviewed commits and {} other issues",
+ summary.not_reviewed,
+ summary.errors
+ )),
+ ReportReviewSummary::NoIssuesFound => {
+ println!("No issues found, compliance check passed.");
+ Ok(())
+ }
+ }
+}
+
+pub fn check_compliance(args: ComplianceArgs) -> Result<()> {
+ tokio::runtime::Runtime::new()
+ .context("Failed to create tokio runtime")
+ .and_then(|handle| handle.block_on(check_compliance_impl(args)))
+}
@@ -11,6 +11,7 @@ mod autofix_pr;
mod bump_patch_version;
mod cherry_pick;
mod compare_perf;
+mod compliance_check;
mod danger;
mod deploy_collab;
mod extension_auto_bump;
@@ -197,6 +198,7 @@ pub fn run_workflows(args: GenerateWorkflowArgs) -> Result<()> {
WorkflowFile::zed(bump_patch_version::bump_patch_version),
WorkflowFile::zed(cherry_pick::cherry_pick),
WorkflowFile::zed(compare_perf::compare_perf),
+ WorkflowFile::zed(compliance_check::compliance_check),
WorkflowFile::zed(danger::danger),
WorkflowFile::zed(deploy_collab::deploy_collab),
WorkflowFile::zed(extension_bump::extension_bump),
@@ -0,0 +1,66 @@
+use gh_workflow::{Event, Expression, Job, Run, Schedule, Step, Workflow};
+
+use crate::tasks::workflows::{
+ runners,
+ steps::{self, CommonJobConditions, named},
+ vars::{self, StepOutput},
+};
+
+pub fn compliance_check() -> Workflow {
+ let check = scheduled_compliance_check();
+
+ named::workflow()
+ .on(Event::default().schedule([Schedule::new("30 17 * * 2")]))
+ .add_env(("CARGO_TERM_COLOR", "always"))
+ .add_job(check.name, check.job)
+}
+
+fn scheduled_compliance_check() -> steps::NamedJob {
+ let determine_version_step = named::bash(indoc::indoc! {r#"
+ VERSION=$(sed -n 's/^version = "\(.*\)"/\1/p' crates/zed/Cargo.toml | tr -d '[:space:]')
+ if [ -z "$VERSION" ]; then
+ echo "Could not determine version from crates/zed/Cargo.toml"
+ exit 1
+ fi
+ TAG="v${VERSION}-pre"
+ echo "Checking compliance for $TAG"
+ echo "tag=$TAG" >> "$GITHUB_OUTPUT"
+ "#})
+ .id("determine-version");
+
+ let tag_output = StepOutput::new(&determine_version_step, "tag");
+
+ fn run_compliance_check(tag: &StepOutput) -> Step<Run> {
+ named::bash(
+ r#"cargo xtask compliance "$LATEST_TAG" --branch main --report-path target/compliance-report"#,
+ )
+ .id("run-compliance-check")
+ .add_env(("LATEST_TAG", tag.to_string()))
+ .add_env(("GITHUB_APP_ID", vars::ZED_ZIPPY_APP_ID))
+ .add_env(("GITHUB_APP_KEY", vars::ZED_ZIPPY_APP_PRIVATE_KEY))
+ }
+
+ fn send_failure_slack_notification(tag: &StepOutput) -> Step<Run> {
+ named::bash(indoc::indoc! {r#"
+ MESSAGE="⚠️ Scheduled compliance check failed for upcoming preview release $LATEST_TAG: There are PRs with missing reviews."
+
+ curl -X POST -H 'Content-type: application/json' \
+ --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \
+ "$SLACK_WEBHOOK"
+ "#})
+ .if_condition(Expression::new("failure()"))
+ .add_env(("SLACK_WEBHOOK", vars::SLACK_WEBHOOK_WORKFLOW_FAILURES))
+ .add_env(("LATEST_TAG", tag.to_string()))
+ }
+
+ named::job(
+ Job::default()
+ .with_repository_owner_guard()
+ .runs_on(runners::LINUX_SMALL)
+ .add_step(steps::checkout_repo().with_full_history())
+ .add_step(steps::cache_rust_dependencies_namespace())
+ .add_step(determine_version_step)
+ .add_step(run_compliance_check(&tag_output))
+ .add_step(send_failure_slack_notification(&tag_output)),
+ )
+}
@@ -1,11 +1,13 @@
-use gh_workflow::{Event, Expression, Push, Run, Step, Use, Workflow, ctx::Context};
+use gh_workflow::{Event, Expression, Job, Push, Run, Step, Use, Workflow, ctx::Context};
use indoc::formatdoc;
use crate::tasks::workflows::{
run_bundling::{bundle_linux, bundle_mac, bundle_windows},
run_tests,
runners::{self, Arch, Platform},
- steps::{self, FluentBuilder, NamedJob, dependant_job, named, release_job},
+ steps::{
+ self, CommonJobConditions, FluentBuilder, NamedJob, dependant_job, named, release_job,
+ },
vars::{self, StepOutput, assets},
};
@@ -22,6 +24,7 @@ pub(crate) fn release() -> Workflow {
let check_scripts = run_tests::check_scripts();
let create_draft_release = create_draft_release();
+ let compliance = compliance_check();
let bundle = ReleaseBundleJobs {
linux_aarch64: bundle_linux(
@@ -92,6 +95,7 @@ pub(crate) fn release() -> Workflow {
.add_job(windows_clippy.name, windows_clippy.job)
.add_job(check_scripts.name, check_scripts.job)
.add_job(create_draft_release.name, create_draft_release.job)
+ .add_job(compliance.name, compliance.job)
.map(|mut workflow| {
for job in bundle.into_jobs() {
workflow = workflow.add_job(job.name, job.job);
@@ -149,6 +153,59 @@ pub(crate) fn create_sentry_release() -> Step<Use> {
.add_with(("environment", "production"))
}
+fn compliance_check() -> NamedJob {
+ fn run_compliance_check() -> Step<Run> {
+ named::bash(
+ r#"cargo xtask compliance "$GITHUB_REF_NAME" --report-path "$COMPLIANCE_FILE_OUTPUT""#,
+ )
+ .id("run-compliance-check")
+ .add_env(("GITHUB_APP_ID", vars::ZED_ZIPPY_APP_ID))
+ .add_env(("GITHUB_APP_KEY", vars::ZED_ZIPPY_APP_PRIVATE_KEY))
+ }
+
+ fn send_compliance_slack_notification() -> Step<Run> {
+ named::bash(indoc::indoc! {r#"
+ if [ "$COMPLIANCE_OUTCOME" == "success" ]; then
+ STATUS="✅ Compliance check passed for $GITHUB_REF_NAME"
+ else
+ STATUS="❌ Compliance check failed for $GITHUB_REF_NAME"
+ fi
+
+ REPORT_CONTENT=""
+ if [ -f "$COMPLIANCE_FILE_OUTPUT" ]; then
+ REPORT_CONTENT=$(cat "$REPORT_FILE")
+ fi
+
+ MESSAGE=$(printf "%s\n\n%s" "$STATUS" "$REPORT_CONTENT")
+
+ curl -X POST -H 'Content-type: application/json' \
+ --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \
+ "$SLACK_WEBHOOK"
+ "#})
+ .if_condition(Expression::new("always()"))
+ .add_env(("SLACK_WEBHOOK", vars::SLACK_WEBHOOK_WORKFLOW_FAILURES))
+ .add_env((
+ "COMPLIANCE_OUTCOME",
+ "${{ steps.run-compliance-check.outcome }}",
+ ))
+ }
+
+ named::job(
+ Job::default()
+ .add_env(("COMPLIANCE_FILE_PATH", "compliance.md"))
+ .with_repository_owner_guard()
+ .runs_on(runners::LINUX_DEFAULT)
+ .add_step(
+ steps::checkout_repo()
+ .with_full_history()
+ .with_ref(Context::github().ref_()),
+ )
+ .add_step(steps::cache_rust_dependencies_namespace())
+ .add_step(run_compliance_check())
+ .add_step(send_compliance_slack_notification()),
+ )
+}
+
fn validate_release_assets(deps: &[&NamedJob]) -> NamedJob {
let expected_assets: Vec<String> = assets::all().iter().map(|a| format!("\"{a}\"")).collect();
let expected_assets_json = format!("[{}]", expected_assets.join(", "));
@@ -171,10 +228,54 @@ fn validate_release_assets(deps: &[&NamedJob]) -> NamedJob {
"#,
};
+ fn run_post_upload_compliance_check() -> Step<Run> {
+ named::bash(
+ r#"cargo xtask compliance "$GITHUB_REF_NAME" --report-path target/compliance-report"#,
+ )
+ .id("run-post-upload-compliance-check")
+ .add_env(("GITHUB_APP_ID", vars::ZED_ZIPPY_APP_ID))
+ .add_env(("GITHUB_APP_KEY", vars::ZED_ZIPPY_APP_PRIVATE_KEY))
+ }
+
+ fn send_post_upload_compliance_notification() -> Step<Run> {
+ named::bash(indoc::indoc! {r#"
+ if [ -z "$COMPLIANCE_OUTCOME" ] || [ "$COMPLIANCE_OUTCOME" == "skipped" ]; then
+ echo "Compliance check was skipped, not sending notification"
+ exit 0
+ fi
+
+ TAG="$GITHUB_REF_NAME"
+
+ if [ "$COMPLIANCE_OUTCOME" == "success" ]; then
+ MESSAGE="✅ Post-upload compliance re-check passed for $TAG"
+ else
+ MESSAGE="❌ Post-upload compliance re-check failed for $TAG"
+ fi
+
+ curl -X POST -H 'Content-type: application/json' \
+ --data "$(jq -n --arg text "$MESSAGE" '{"text": $text}')" \
+ "$SLACK_WEBHOOK"
+ "#})
+ .if_condition(Expression::new("always()"))
+ .add_env(("SLACK_WEBHOOK", vars::SLACK_WEBHOOK_WORKFLOW_FAILURES))
+ .add_env((
+ "COMPLIANCE_OUTCOME",
+ "${{ steps.run-post-upload-compliance-check.outcome }}",
+ ))
+ }
+
named::job(
- dependant_job(deps).runs_on(runners::LINUX_SMALL).add_step(
- named::bash(&validation_script).add_env(("GITHUB_TOKEN", vars::GITHUB_TOKEN)),
- ),
+ dependant_job(deps)
+ .runs_on(runners::LINUX_SMALL)
+ .add_step(named::bash(&validation_script).add_env(("GITHUB_TOKEN", vars::GITHUB_TOKEN)))
+ .add_step(
+ steps::checkout_repo()
+ .with_full_history()
+ .with_ref(Context::github().ref_()),
+ )
+ .add_step(steps::cache_rust_dependencies_namespace())
+ .add_step(run_post_upload_compliance_check())
+ .add_step(send_post_upload_compliance_notification()),
)
}
@@ -255,7 +356,7 @@ fn create_draft_release() -> NamedJob {
.add_step(
steps::checkout_repo()
.with_custom_fetch_depth(25)
- .with_ref("${{ github.ref }}"),
+ .with_ref(Context::github().ref_()),
)
.add_step(steps::script("script/determine-release-channel"))
.add_step(steps::script("mkdir -p target/"))