Detailed changes
@@ -26,6 +26,7 @@ jobs:
CC: clang
CXX: clang++
DOCS_AMPLITUDE_API_KEY: ${{ secrets.DOCS_AMPLITUDE_API_KEY }}
+ DOCS_CONSENT_IO_INSTANCE: ${{ secrets.DOCS_CONSENT_IO_INSTANCE }}
- name: Deploy Docs
uses: cloudflare/wrangler-action@da0e0dfe58b7a431659754fdf3f186c529afbe65 # v3
@@ -76,6 +76,7 @@ dependencies = [
"clock",
"collections",
"ctor",
+ "fs",
"futures 0.3.31",
"gpui",
"indoc",
@@ -1352,6 +1353,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"log",
+ "scopeguard",
"simplelog",
"tempfile",
"windows 0.61.3",
@@ -7597,7 +7599,7 @@ dependencies = [
"mach2 0.5.0",
"media",
"metal",
- "naga",
+ "naga 28.0.0",
"num_cpus",
"objc",
"objc2",
@@ -10701,6 +10703,30 @@ name = "naga"
version = "28.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "618f667225063219ddfc61251087db8a9aec3c3f0950c916b614e403486f1135"
+dependencies = [
+ "arrayvec",
+ "bit-set",
+ "bitflags 2.10.0",
+ "cfg-if",
+ "cfg_aliases 0.2.1",
+ "codespan-reporting 0.12.0",
+ "half",
+ "hashbrown 0.16.1",
+ "hexf-parse",
+ "indexmap",
+ "libm",
+ "log",
+ "num-traits",
+ "once_cell",
+ "rustc-hash 1.1.0",
+ "thiserror 2.0.17",
+ "unicode-ident",
+]
+
+[[package]]
+name = "naga"
+version = "28.0.1"
+source = "git+https://github.com/zed-industries/wgpu?rev=6e0c2546d99dad72ce6ffb5b04349e6a4ce96e6d#6e0c2546d99dad72ce6ffb5b04349e6a4ce96e6d"
dependencies = [
"arrayvec",
"bit-set",
@@ -19825,6 +19851,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"client",
+ "cloud_api_types",
"cloud_llm_client",
"futures 0.3.31",
"gpui",
@@ -19889,9 +19916,8 @@ checksum = "a751b3277700db47d3e574514de2eced5e54dc8a5436a3bf7a0b248b2cee16f3"
[[package]]
name = "wgpu"
-version = "28.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f9cb534d5ffd109c7d1135f34cdae29e60eab94855a625dcfe1705f8bc7ad79f"
+version = "28.0.1"
+source = "git+https://github.com/zed-industries/wgpu?rev=6e0c2546d99dad72ce6ffb5b04349e6a4ce96e6d#6e0c2546d99dad72ce6ffb5b04349e6a4ce96e6d"
dependencies = [
"arrayvec",
"bitflags 2.10.0",
@@ -19902,7 +19928,7 @@ dependencies = [
"hashbrown 0.16.1",
"js-sys",
"log",
- "naga",
+ "naga 28.0.1",
"parking_lot",
"portable-atomic",
"profiling",
@@ -19919,9 +19945,8 @@ dependencies = [
[[package]]
name = "wgpu-core"
-version = "28.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8bb4c8b5db5f00e56f1f08869d870a0dff7c8bc7ebc01091fec140b0cf0211a9"
+version = "28.0.1"
+source = "git+https://github.com/zed-industries/wgpu?rev=6e0c2546d99dad72ce6ffb5b04349e6a4ce96e6d#6e0c2546d99dad72ce6ffb5b04349e6a4ce96e6d"
dependencies = [
"arrayvec",
"bit-set",
@@ -19933,7 +19958,7 @@ dependencies = [
"hashbrown 0.16.1",
"indexmap",
"log",
- "naga",
+ "naga 28.0.1",
"once_cell",
"parking_lot",
"portable-atomic",
@@ -19951,36 +19976,32 @@ dependencies = [
[[package]]
name = "wgpu-core-deps-apple"
-version = "28.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "87b7b696b918f337c486bf93142454080a32a37832ba8a31e4f48221890047da"
+version = "28.0.1"
+source = "git+https://github.com/zed-industries/wgpu?rev=6e0c2546d99dad72ce6ffb5b04349e6a4ce96e6d#6e0c2546d99dad72ce6ffb5b04349e6a4ce96e6d"
dependencies = [
"wgpu-hal",
]
[[package]]
name = "wgpu-core-deps-emscripten"
-version = "28.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "34b251c331f84feac147de3c4aa3aa45112622a95dd7ee1b74384fa0458dbd79"
+version = "28.0.1"
+source = "git+https://github.com/zed-industries/wgpu?rev=6e0c2546d99dad72ce6ffb5b04349e6a4ce96e6d#6e0c2546d99dad72ce6ffb5b04349e6a4ce96e6d"
dependencies = [
"wgpu-hal",
]
[[package]]
name = "wgpu-core-deps-windows-linux-android"
-version = "28.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "68ca976e72b2c9964eb243e281f6ce7f14a514e409920920dcda12ae40febaae"
+version = "28.0.1"
+source = "git+https://github.com/zed-industries/wgpu?rev=6e0c2546d99dad72ce6ffb5b04349e6a4ce96e6d#6e0c2546d99dad72ce6ffb5b04349e6a4ce96e6d"
dependencies = [
"wgpu-hal",
]
[[package]]
name = "wgpu-hal"
-version = "28.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "293080d77fdd14d6b08a67c5487dfddbf874534bb7921526db56a7b75d7e3bef"
+version = "28.0.1"
+source = "git+https://github.com/zed-industries/wgpu?rev=6e0c2546d99dad72ce6ffb5b04349e6a4ce96e6d#6e0c2546d99dad72ce6ffb5b04349e6a4ce96e6d"
dependencies = [
"android_system_properties",
"arrayvec",
@@ -20003,7 +20024,7 @@ dependencies = [
"libloading",
"log",
"metal",
- "naga",
+ "naga 28.0.1",
"ndk-sys",
"objc",
"once_cell",
@@ -20026,9 +20047,8 @@ dependencies = [
[[package]]
name = "wgpu-types"
-version = "28.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e18308757e594ed2cd27dddbb16a139c42a683819d32a2e0b1b0167552f5840c"
+version = "28.0.1"
+source = "git+https://github.com/zed-industries/wgpu?rev=6e0c2546d99dad72ce6ffb5b04349e6a4ce96e6d#6e0c2546d99dad72ce6ffb5b04349e6a4ce96e6d"
dependencies = [
"bitflags 2.10.0",
"bytemuck",
@@ -770,7 +770,7 @@ wax = "0.7"
which = "6.0.0"
wasm-bindgen = "0.2.113"
web-time = "1.1.0"
-wgpu = "28.0"
+wgpu = { git = "https://github.com/zed-industries/wgpu", rev = "6e0c2546d99dad72ce6ffb5b04349e6a4ce96e6d" }
windows-core = "0.61"
yawc = "0.2.5"
zeroize = "1.8"
@@ -815,6 +815,7 @@ features = [
"Win32_System_Ole",
"Win32_System_Performance",
"Win32_System_Pipes",
+ "Win32_System_RestartManager",
"Win32_System_SystemInformation",
"Win32_System_SystemServices",
"Win32_System_Threading",
@@ -204,6 +204,7 @@
{
"context": "Editor && editor_agent_diff",
"bindings": {
+ "alt-y": "agent::Keep",
"ctrl-alt-y": "agent::Keep",
"ctrl-alt-z": "agent::Reject",
"shift-alt-y": "agent::KeepAll",
@@ -214,6 +215,7 @@
{
"context": "AgentDiff",
"bindings": {
+ "alt-y": "agent::Keep",
"ctrl-alt-y": "agent::Keep",
"ctrl-alt-z": "agent::Reject",
"shift-alt-y": "agent::KeepAll",
@@ -242,6 +242,7 @@
"context": "AgentDiff",
"use_key_equivalents": true,
"bindings": {
+ "cmd-y": "agent::Keep",
"cmd-alt-y": "agent::Keep",
"cmd-alt-z": "agent::Reject",
"shift-alt-y": "agent::KeepAll",
@@ -252,6 +253,7 @@
"context": "Editor && editor_agent_diff",
"use_key_equivalents": true,
"bindings": {
+ "cmd-y": "agent::Keep",
"cmd-alt-y": "agent::Keep",
"cmd-alt-z": "agent::Reject",
"shift-alt-y": "agent::KeepAll",
@@ -448,6 +450,13 @@
"down": "search::NextHistoryQuery",
},
},
+ {
+ "context": "BufferSearchBar || ProjectSearchBar",
+ "use_key_equivalents": true,
+ "bindings": {
+ "ctrl-enter": "editor::Newline",
+ },
+ },
{
"context": "ProjectSearchBar",
"use_key_equivalents": true,
@@ -203,6 +203,7 @@
"context": "Editor && editor_agent_diff",
"use_key_equivalents": true,
"bindings": {
+ "alt-y": "agent::Keep",
"ctrl-alt-y": "agent::Keep",
"ctrl-alt-z": "agent::Reject",
"shift-alt-y": "agent::KeepAll",
@@ -214,6 +215,7 @@
"context": "AgentDiff",
"use_key_equivalents": true,
"bindings": {
+ "alt-y": "agent::Keep",
"ctrl-alt-y": "agent::Keep",
"ctrl-alt-z": "agent::Reject",
"shift-alt-y": "agent::KeepAll",
@@ -972,6 +972,8 @@ pub struct AcpThread {
had_error: bool,
/// The user's unsent prompt text, persisted so it can be restored when reloading the thread.
draft_prompt: Option<Vec<acp::ContentBlock>>,
+ /// The initial scroll position for the thread view, set during session registration.
+ ui_scroll_position: Option<gpui::ListOffset>,
}
impl From<&AcpThread> for ActionLogTelemetry {
@@ -1210,6 +1212,7 @@ impl AcpThread {
pending_terminal_exit: HashMap::default(),
had_error: false,
draft_prompt: None,
+ ui_scroll_position: None,
}
}
@@ -1229,6 +1232,14 @@ impl AcpThread {
self.draft_prompt = prompt;
}
+ pub fn ui_scroll_position(&self) -> Option<gpui::ListOffset> {
+ self.ui_scroll_position
+ }
+
+ pub fn set_ui_scroll_position(&mut self, position: Option<gpui::ListOffset>) {
+ self.ui_scroll_position = position;
+ }
+
pub fn connection(&self) -> &Rc<dyn AgentConnection> {
&self.connection
}
@@ -20,6 +20,7 @@ buffer_diff.workspace = true
log.workspace = true
clock.workspace = true
collections.workspace = true
+fs.workspace = true
futures.workspace = true
gpui.workspace = true
language.workspace = true
@@ -1,14 +1,20 @@
use anyhow::{Context as _, Result};
use buffer_diff::BufferDiff;
use clock;
-use collections::BTreeMap;
+use collections::{BTreeMap, HashMap};
+use fs::MTime;
use futures::{FutureExt, StreamExt, channel::mpsc};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
};
use language::{Anchor, Buffer, BufferEvent, Point, ToOffset, ToPoint};
use project::{Project, ProjectItem, lsp_store::OpenLspBufferHandle};
-use std::{cmp, ops::Range, sync::Arc};
+use std::{
+ cmp,
+ ops::Range,
+ path::{Path, PathBuf},
+ sync::Arc,
+};
use text::{Edit, Patch, Rope};
use util::{RangeExt, ResultExt as _};
@@ -54,6 +60,8 @@ pub struct ActionLog {
linked_action_log: Option<Entity<ActionLog>>,
/// Stores undo information for the most recent reject operation
last_reject_undo: Option<LastRejectUndo>,
+ /// Tracks the last time files were read by the agent, to detect external modifications
+ file_read_times: HashMap<PathBuf, MTime>,
}
impl ActionLog {
@@ -64,6 +72,7 @@ impl ActionLog {
project,
linked_action_log: None,
last_reject_undo: None,
+ file_read_times: HashMap::default(),
}
}
@@ -76,6 +85,32 @@ impl ActionLog {
&self.project
}
+ pub fn file_read_time(&self, path: &Path) -> Option<MTime> {
+ self.file_read_times.get(path).copied()
+ }
+
+ fn update_file_read_time(&mut self, buffer: &Entity<Buffer>, cx: &App) {
+ let buffer = buffer.read(cx);
+ if let Some(file) = buffer.file() {
+ if let Some(local_file) = file.as_local() {
+ if let Some(mtime) = file.disk_state().mtime() {
+ let abs_path = local_file.abs_path(cx);
+ self.file_read_times.insert(abs_path, mtime);
+ }
+ }
+ }
+ }
+
+ fn remove_file_read_time(&mut self, buffer: &Entity<Buffer>, cx: &App) {
+ let buffer = buffer.read(cx);
+ if let Some(file) = buffer.file() {
+ if let Some(local_file) = file.as_local() {
+ let abs_path = local_file.abs_path(cx);
+ self.file_read_times.remove(&abs_path);
+ }
+ }
+ }
+
fn track_buffer_internal(
&mut self,
buffer: Entity<Buffer>,
@@ -506,24 +541,69 @@ impl ActionLog {
/// Track a buffer as read by agent, so we can notify the model about user edits.
pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
- if let Some(linked_action_log) = &mut self.linked_action_log {
- linked_action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+ self.buffer_read_impl(buffer, true, cx);
+ }
+
+ fn buffer_read_impl(
+ &mut self,
+ buffer: Entity<Buffer>,
+ record_file_read_time: bool,
+ cx: &mut Context<Self>,
+ ) {
+ if let Some(linked_action_log) = &self.linked_action_log {
+ // We don't want to share read times since the other agent hasn't read it necessarily
+ linked_action_log.update(cx, |log, cx| {
+ log.buffer_read_impl(buffer.clone(), false, cx);
+ });
+ }
+ if record_file_read_time {
+ self.update_file_read_time(&buffer, cx);
}
self.track_buffer_internal(buffer, false, cx);
}
/// Mark a buffer as created by agent, so we can refresh it in the context
pub fn buffer_created(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
- if let Some(linked_action_log) = &mut self.linked_action_log {
- linked_action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx));
+ self.buffer_created_impl(buffer, true, cx);
+ }
+
+ fn buffer_created_impl(
+ &mut self,
+ buffer: Entity<Buffer>,
+ record_file_read_time: bool,
+ cx: &mut Context<Self>,
+ ) {
+ if let Some(linked_action_log) = &self.linked_action_log {
+ // We don't want to share read times since the other agent hasn't read it necessarily
+ linked_action_log.update(cx, |log, cx| {
+ log.buffer_created_impl(buffer.clone(), false, cx);
+ });
+ }
+ if record_file_read_time {
+ self.update_file_read_time(&buffer, cx);
}
self.track_buffer_internal(buffer, true, cx);
}
/// Mark a buffer as edited by agent, so we can refresh it in the context
pub fn buffer_edited(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
- if let Some(linked_action_log) = &mut self.linked_action_log {
- linked_action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
+ self.buffer_edited_impl(buffer, true, cx);
+ }
+
+ fn buffer_edited_impl(
+ &mut self,
+ buffer: Entity<Buffer>,
+ record_file_read_time: bool,
+ cx: &mut Context<Self>,
+ ) {
+ if let Some(linked_action_log) = &self.linked_action_log {
+ // We don't want to share read times since the other agent hasn't read it necessarily
+ linked_action_log.update(cx, |log, cx| {
+ log.buffer_edited_impl(buffer.clone(), false, cx);
+ });
+ }
+ if record_file_read_time {
+ self.update_file_read_time(&buffer, cx);
}
let new_version = buffer.read(cx).version();
let tracked_buffer = self.track_buffer_internal(buffer, false, cx);
@@ -536,6 +616,8 @@ impl ActionLog {
}
pub fn will_delete_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
+ // Ok to propagate file read time removal to linked action log
+ self.remove_file_read_time(&buffer, cx);
let has_linked_action_log = self.linked_action_log.is_some();
let tracked_buffer = self.track_buffer_internal(buffer.clone(), false, cx);
match tracked_buffer.status {
@@ -2976,6 +3058,196 @@ mod tests {
);
}
+ #[gpui::test]
+ async fn test_file_read_time_recorded_on_buffer_read(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/dir"), json!({"file": "hello world"}))
+ .await;
+ let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+
+ let file_path = project
+ .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+ .unwrap();
+ let buffer = project
+ .update(cx, |project, cx| project.open_buffer(file_path, cx))
+ .await
+ .unwrap();
+
+ let abs_path = PathBuf::from(path!("/dir/file"));
+ assert!(
+ action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()),
+ "file_read_time should be None before buffer_read"
+ );
+
+ cx.update(|cx| {
+ action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+ });
+
+ assert!(
+ action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()),
+ "file_read_time should be recorded after buffer_read"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_file_read_time_recorded_on_buffer_edited(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/dir"), json!({"file": "hello world"}))
+ .await;
+ let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+
+ let file_path = project
+ .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+ .unwrap();
+ let buffer = project
+ .update(cx, |project, cx| project.open_buffer(file_path, cx))
+ .await
+ .unwrap();
+
+ let abs_path = PathBuf::from(path!("/dir/file"));
+ assert!(
+ action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()),
+ "file_read_time should be None before buffer_edited"
+ );
+
+ cx.update(|cx| {
+ action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
+ });
+
+ assert!(
+ action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()),
+ "file_read_time should be recorded after buffer_edited"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_file_read_time_recorded_on_buffer_created(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/dir"), json!({"file": "existing content"}))
+ .await;
+ let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+
+ let file_path = project
+ .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+ .unwrap();
+ let buffer = project
+ .update(cx, |project, cx| project.open_buffer(file_path, cx))
+ .await
+ .unwrap();
+
+ let abs_path = PathBuf::from(path!("/dir/file"));
+ assert!(
+ action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()),
+ "file_read_time should be None before buffer_created"
+ );
+
+ cx.update(|cx| {
+ action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx));
+ });
+
+ assert!(
+ action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()),
+ "file_read_time should be recorded after buffer_created"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_file_read_time_removed_on_delete(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/dir"), json!({"file": "hello world"}))
+ .await;
+ let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+
+ let file_path = project
+ .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+ .unwrap();
+ let buffer = project
+ .update(cx, |project, cx| project.open_buffer(file_path, cx))
+ .await
+ .unwrap();
+
+ let abs_path = PathBuf::from(path!("/dir/file"));
+
+ cx.update(|cx| {
+ action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+ });
+ assert!(
+ action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()),
+ "file_read_time should exist after buffer_read"
+ );
+
+ cx.update(|cx| {
+ action_log.update(cx, |log, cx| log.will_delete_buffer(buffer.clone(), cx));
+ });
+ assert!(
+ action_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()),
+ "file_read_time should be removed after will_delete_buffer"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_file_read_time_not_forwarded_to_linked_action_log(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/dir"), json!({"file": "hello world"}))
+ .await;
+ let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+ let parent_log = cx.new(|_| ActionLog::new(project.clone()));
+ let child_log =
+ cx.new(|_| ActionLog::new(project.clone()).with_linked_action_log(parent_log.clone()));
+
+ let file_path = project
+ .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+ .unwrap();
+ let buffer = project
+ .update(cx, |project, cx| project.open_buffer(file_path, cx))
+ .await
+ .unwrap();
+
+ let abs_path = PathBuf::from(path!("/dir/file"));
+
+ cx.update(|cx| {
+ child_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+ });
+ assert!(
+ child_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_some()),
+ "child should record file_read_time on buffer_read"
+ );
+ assert!(
+ parent_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()),
+ "parent should NOT get file_read_time from child's buffer_read"
+ );
+
+ cx.update(|cx| {
+ child_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
+ });
+ assert!(
+ parent_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()),
+ "parent should NOT get file_read_time from child's buffer_edited"
+ );
+
+ cx.update(|cx| {
+ child_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx));
+ });
+ assert!(
+ parent_log.read_with(cx, |log, _| log.file_read_time(&abs_path).is_none()),
+ "parent should NOT get file_read_time from child's buffer_created"
+ );
+ }
+
#[derive(Debug, PartialEq)]
struct HunkStatus {
range: Range<Point>,
@@ -352,6 +352,8 @@ impl NativeAgent {
let parent_session_id = thread.parent_thread_id();
let title = thread.title();
let draft_prompt = thread.draft_prompt().map(Vec::from);
+ let scroll_position = thread.ui_scroll_position();
+ let token_usage = thread.latest_token_usage();
let project = thread.project.clone();
let action_log = thread.action_log.clone();
let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
@@ -367,6 +369,8 @@ impl NativeAgent {
cx,
);
acp_thread.set_draft_prompt(draft_prompt);
+ acp_thread.set_ui_scroll_position(scroll_position);
+ acp_thread.update_token_usage(token_usage, cx);
acp_thread
});
@@ -1917,7 +1921,9 @@ mod internal_tests {
use gpui::TestAppContext;
use indoc::formatdoc;
use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
- use language_model::{LanguageModelProviderId, LanguageModelProviderName};
+ use language_model::{
+ LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
+ };
use serde_json::json;
use settings::SettingsStore;
use util::{path, rel_path::rel_path};
@@ -2549,6 +2555,13 @@ mod internal_tests {
cx.run_until_parked();
model.send_last_completion_stream_text_chunk("Lorem.");
+ model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+ language_model::TokenUsage {
+ input_tokens: 150,
+ output_tokens: 75,
+ ..Default::default()
+ },
+ ));
model.end_last_completion_stream();
cx.run_until_parked();
summary_model
@@ -2587,6 +2600,12 @@ mod internal_tests {
acp_thread.update(cx, |thread, _cx| {
thread.set_draft_prompt(Some(draft_blocks.clone()));
});
+ thread.update(cx, |thread, _cx| {
+ thread.set_ui_scroll_position(Some(gpui::ListOffset {
+ item_ix: 5,
+ offset_in_item: gpui::px(12.5),
+ }));
+ });
thread.update(cx, |_thread, cx| cx.notify());
cx.run_until_parked();
@@ -2632,6 +2651,24 @@ mod internal_tests {
acp_thread.read_with(cx, |thread, _| {
assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
});
+
+ // Ensure token usage survived the round-trip.
+ acp_thread.read_with(cx, |thread, _| {
+ let usage = thread
+ .token_usage()
+ .expect("token usage should be restored after reload");
+ assert_eq!(usage.input_tokens, 150);
+ assert_eq!(usage.output_tokens, 75);
+ });
+
+ // Ensure scroll position survived the round-trip.
+ acp_thread.read_with(cx, |thread, _| {
+ let scroll = thread
+ .ui_scroll_position()
+ .expect("scroll position should be restored after reload");
+ assert_eq!(scroll.item_ix, 5);
+ assert_eq!(scroll.offset_in_item, gpui::px(12.5));
+ });
}
fn thread_entries(
@@ -66,6 +66,14 @@ pub struct DbThread {
pub thinking_effort: Option<String>,
#[serde(default)]
pub draft_prompt: Option<Vec<acp::ContentBlock>>,
+ #[serde(default)]
+ pub ui_scroll_position: Option<SerializedScrollPosition>,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
+pub struct SerializedScrollPosition {
+ pub item_ix: usize,
+ pub offset_in_item: f32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -108,6 +116,7 @@ impl SharedThread {
thinking_enabled: false,
thinking_effort: None,
draft_prompt: None,
+ ui_scroll_position: None,
}
}
@@ -286,6 +295,7 @@ impl DbThread {
thinking_enabled: false,
thinking_effort: None,
draft_prompt: None,
+ ui_scroll_position: None,
})
}
}
@@ -637,6 +647,7 @@ mod tests {
thinking_enabled: false,
thinking_effort: None,
draft_prompt: None,
+ ui_scroll_position: None,
}
}
@@ -841,4 +852,53 @@ mod tests {
assert_eq!(threads.len(), 1);
assert!(threads[0].folder_paths.is_empty());
}
+
+ #[test]
+ fn test_scroll_position_defaults_to_none() {
+ let json = r#"{
+ "title": "Old Thread",
+ "messages": [],
+ "updated_at": "2024-01-01T00:00:00Z"
+ }"#;
+
+ let db_thread: DbThread = serde_json::from_str(json).expect("Failed to deserialize");
+
+ assert!(
+ db_thread.ui_scroll_position.is_none(),
+ "Legacy threads without scroll_position field should default to None"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_scroll_position_roundtrips_through_save_load(cx: &mut TestAppContext) {
+ let database = ThreadsDatabase::new(cx.executor()).unwrap();
+
+ let thread_id = session_id("thread-with-scroll");
+
+ let mut thread = make_thread(
+ "Thread With Scroll",
+ Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(),
+ );
+ thread.ui_scroll_position = Some(SerializedScrollPosition {
+ item_ix: 42,
+ offset_in_item: 13.5,
+ });
+
+ database
+ .save_thread(thread_id.clone(), thread, PathList::default())
+ .await
+ .unwrap();
+
+ let loaded = database
+ .load_thread(thread_id)
+ .await
+ .unwrap()
+ .expect("thread should exist");
+
+ let scroll = loaded
+ .ui_scroll_position
+ .expect("scroll_position should be restored");
+ assert_eq!(scroll.item_ix, 42);
+ assert!((scroll.offset_in_item - 13.5).abs() < f32::EPSILON);
+ }
}
@@ -50,9 +50,9 @@ async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) {
// Add just the tools we need for this test
let language_registry = project.read(cx).languages().clone();
thread.add_tool(crate::ReadFileTool::new(
- cx.weak_entity(),
project.clone(),
thread.action_log().clone(),
+ true,
));
thread.add_tool(crate::EditFileTool::new(
project.clone(),
@@ -893,14 +893,13 @@ pub struct Thread {
pub(crate) prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
pub(crate) project: Entity<Project>,
pub(crate) action_log: Entity<ActionLog>,
- /// Tracks the last time files were read by the agent, to detect external modifications
- pub(crate) file_read_times: HashMap<PathBuf, fs::MTime>,
/// True if this thread was imported from a shared thread and can be synced.
imported: bool,
/// If this is a subagent thread, contains context about the parent
subagent_context: Option<SubagentContext>,
/// The user's unsent prompt text, persisted so it can be restored when reloading the thread.
draft_prompt: Option<Vec<acp::ContentBlock>>,
+ ui_scroll_position: Option<gpui::ListOffset>,
/// Weak references to running subagent threads for cancellation propagation
running_subagents: Vec<WeakEntity<Thread>>,
}
@@ -1013,10 +1012,10 @@ impl Thread {
prompt_capabilities_rx,
project,
action_log,
- file_read_times: HashMap::default(),
imported: false,
subagent_context: None,
draft_prompt: None,
+ ui_scroll_position: None,
running_subagents: Vec::new(),
}
}
@@ -1229,10 +1228,13 @@ impl Thread {
updated_at: db_thread.updated_at,
prompt_capabilities_tx,
prompt_capabilities_rx,
- file_read_times: HashMap::default(),
imported: db_thread.imported,
subagent_context: db_thread.subagent_context,
draft_prompt: db_thread.draft_prompt,
+ ui_scroll_position: db_thread.ui_scroll_position.map(|sp| gpui::ListOffset {
+ item_ix: sp.item_ix,
+ offset_in_item: gpui::px(sp.offset_in_item),
+ }),
running_subagents: Vec::new(),
}
}
@@ -1258,6 +1260,12 @@ impl Thread {
thinking_enabled: self.thinking_enabled,
thinking_effort: self.thinking_effort.clone(),
draft_prompt: self.draft_prompt.clone(),
+ ui_scroll_position: self.ui_scroll_position.map(|lo| {
+ crate::db::SerializedScrollPosition {
+ item_ix: lo.item_ix,
+ offset_in_item: lo.offset_in_item.as_f32(),
+ }
+ }),
};
cx.background_spawn(async move {
@@ -1307,6 +1315,14 @@ impl Thread {
self.draft_prompt = prompt;
}
+ pub fn ui_scroll_position(&self) -> Option<gpui::ListOffset> {
+ self.ui_scroll_position
+ }
+
+ pub fn set_ui_scroll_position(&mut self, position: Option<gpui::ListOffset>) {
+ self.ui_scroll_position = position;
+ }
+
pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
self.model.as_ref()
}
@@ -1416,6 +1432,9 @@ impl Thread {
environment: Rc<dyn ThreadEnvironment>,
cx: &mut Context<Self>,
) {
+ // Only update the agent location for the root thread, not for subagents.
+ let update_agent_location = self.parent_thread_id().is_none();
+
let language_registry = self.project.read(cx).languages().clone();
self.add_tool(CopyPathTool::new(self.project.clone()));
self.add_tool(CreateDirectoryTool::new(self.project.clone()));
@@ -1443,9 +1462,9 @@ impl Thread {
self.add_tool(NowTool);
self.add_tool(OpenTool::new(self.project.clone()));
self.add_tool(ReadFileTool::new(
- cx.weak_entity(),
self.project.clone(),
self.action_log.clone(),
+ update_agent_location,
));
self.add_tool(SaveFileTool::new(self.project.clone()));
self.add_tool(RestoreFileFromDiskTool::new(self.project.clone()));
@@ -2617,7 +2636,8 @@ impl Thread {
}
}
- let use_streaming_edit_tool = cx.has_flag::<StreamingEditFileToolFeatureFlag>();
+ let use_streaming_edit_tool =
+ cx.has_flag::<StreamingEditFileToolFeatureFlag>() && model.supports_streaming_tools();
let mut tools = self
.tools
@@ -146,6 +146,7 @@ mod tests {
thinking_enabled: false,
thinking_effort: None,
draft_prompt: None,
+ ui_scroll_position: None,
}
}
@@ -305,13 +305,13 @@ impl AgentTool for EditFileTool {
// Check if the file has been modified since the agent last read it
if let Some(abs_path) = abs_path.as_ref() {
- let (last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool) = self.thread.update(cx, |thread, cx| {
- let last_read = thread.file_read_times.get(abs_path).copied();
+ let last_read_mtime = action_log.read_with(cx, |log, _| log.file_read_time(abs_path));
+ let (current_mtime, is_dirty, has_save_tool, has_restore_tool) = self.thread.read_with(cx, |thread, cx| {
let current = buffer.read(cx).file().and_then(|file| file.disk_state().mtime());
let dirty = buffer.read(cx).is_dirty();
let has_save = thread.has_tool(SaveFileTool::NAME);
let has_restore = thread.has_tool(RestoreFileFromDiskTool::NAME);
- (last_read, current, dirty, has_save, has_restore)
+ (current, dirty, has_save, has_restore)
})?;
// Check for unsaved changes first - these indicate modifications we don't know about
@@ -470,17 +470,6 @@ impl AgentTool for EditFileTool {
log.buffer_edited(buffer.clone(), cx);
});
- // Update the recorded read time after a successful edit so consecutive edits work
- if let Some(abs_path) = abs_path.as_ref() {
- if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
- buffer.file().and_then(|file| file.disk_state().mtime())
- }) {
- self.thread.update(cx, |thread, _| {
- thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
- })?;
- }
- }
-
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let (new_text, unified_diff) = cx
.background_spawn({
@@ -2212,14 +2201,18 @@ mod tests {
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
// Initially, file_read_times should be empty
- let is_empty = thread.read_with(cx, |thread, _| thread.file_read_times.is_empty());
+ let is_empty = action_log.read_with(cx, |action_log, _| {
+ action_log
+ .file_read_time(path!("/root/test.txt").as_ref())
+ .is_none()
+ });
assert!(is_empty, "file_read_times should start empty");
// Create read tool
let read_tool = Arc::new(crate::ReadFileTool::new(
- thread.downgrade(),
project.clone(),
- action_log,
+ action_log.clone(),
+ true,
));
// Read the file to record the read time
@@ -2238,12 +2231,9 @@ mod tests {
.unwrap();
// Verify that file_read_times now contains an entry for the file
- let has_entry = thread.read_with(cx, |thread, _| {
- thread.file_read_times.len() == 1
- && thread
- .file_read_times
- .keys()
- .any(|path| path.ends_with("test.txt"))
+ let has_entry = action_log.read_with(cx, |log, _| {
+ log.file_read_time(path!("/root/test.txt").as_ref())
+ .is_some()
});
assert!(
has_entry,
@@ -2265,11 +2255,14 @@ mod tests {
.await
.unwrap();
- // Should still have exactly one entry
- let has_one_entry = thread.read_with(cx, |thread, _| thread.file_read_times.len() == 1);
+ // Should still have an entry after re-reading
+ let has_entry = action_log.read_with(cx, |log, _| {
+ log.file_read_time(path!("/root/test.txt").as_ref())
+ .is_some()
+ });
assert!(
- has_one_entry,
- "file_read_times should still have one entry after re-reading"
+ has_entry,
+ "file_read_times should still have an entry after re-reading"
);
}
@@ -2309,11 +2302,7 @@ mod tests {
let languages = project.read_with(cx, |project, _| project.languages().clone());
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
- let read_tool = Arc::new(crate::ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true));
let edit_tool = Arc::new(EditFileTool::new(
project.clone(),
thread.downgrade(),
@@ -2423,11 +2412,7 @@ mod tests {
let languages = project.read_with(cx, |project, _| project.languages().clone());
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
- let read_tool = Arc::new(crate::ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true));
let edit_tool = Arc::new(EditFileTool::new(
project.clone(),
thread.downgrade(),
@@ -2534,11 +2519,7 @@ mod tests {
let languages = project.read_with(cx, |project, _| project.languages().clone());
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
- let read_tool = Arc::new(crate::ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true));
let edit_tool = Arc::new(EditFileTool::new(
project.clone(),
thread.downgrade(),
@@ -2,7 +2,7 @@ use action_log::ActionLog;
use agent_client_protocol::{self as acp, ToolCallUpdateFields};
use anyhow::{Context as _, Result, anyhow};
use futures::FutureExt as _;
-use gpui::{App, Entity, SharedString, Task, WeakEntity};
+use gpui::{App, Entity, SharedString, Task};
use indoc::formatdoc;
use language::Point;
use language_model::{LanguageModelImage, LanguageModelToolResultContent};
@@ -21,7 +21,7 @@ use super::tool_permissions::{
ResolvedProjectPath, authorize_symlink_access, canonicalize_worktree_roots,
resolve_project_path,
};
-use crate::{AgentTool, Thread, ToolCallEventStream, ToolInput, outline};
+use crate::{AgentTool, ToolCallEventStream, ToolInput, outline};
/// Reads the content of the given file in the project.
///
@@ -56,21 +56,21 @@ pub struct ReadFileToolInput {
}
pub struct ReadFileTool {
- thread: WeakEntity<Thread>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
+ update_agent_location: bool,
}
impl ReadFileTool {
pub fn new(
- thread: WeakEntity<Thread>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
+ update_agent_location: bool,
) -> Self {
Self {
- thread,
project,
action_log,
+ update_agent_location,
}
}
}
@@ -119,7 +119,6 @@ impl AgentTool for ReadFileTool {
cx: &mut App,
) -> Task<Result<LanguageModelToolResultContent, LanguageModelToolResultContent>> {
let project = self.project.clone();
- let thread = self.thread.clone();
let action_log = self.action_log.clone();
cx.spawn(async move |cx| {
let input = input
@@ -257,20 +256,6 @@ impl AgentTool for ReadFileTool {
return Err(tool_content_err(format!("{file_path} not found")));
}
- // Record the file read time and mtime
- if let Some(mtime) = buffer.read_with(cx, |buffer, _| {
- buffer.file().and_then(|file| file.disk_state().mtime())
- }) {
- thread
- .update(cx, |thread, _| {
- thread.file_read_times.insert(abs_path.to_path_buf(), mtime);
- })
- .ok();
- }
-
-
- let update_agent_location = self.thread.read_with(cx, |thread, _cx| !thread.is_subagent()).unwrap_or_default();
-
let mut anchor = None;
// Check if specific line ranges are provided
@@ -330,7 +315,7 @@ impl AgentTool for ReadFileTool {
};
project.update(cx, |project, cx| {
- if update_agent_location {
+ if self.update_agent_location {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
@@ -362,13 +347,10 @@ impl AgentTool for ReadFileTool {
#[cfg(test)]
mod test {
use super::*;
- use crate::{ContextServerRegistry, Templates, Thread};
use agent_client_protocol as acp;
use fs::Fs as _;
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
- use language_model::fake_provider::FakeLanguageModel;
use project::{FakeFs, Project};
- use prompt_store::ProjectContext;
use serde_json::json;
use settings::SettingsStore;
use std::path::PathBuf;
@@ -383,20 +365,7 @@ mod test {
fs.insert_tree(path!("/root"), json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log));
+ let tool = Arc::new(ReadFileTool::new(project, action_log, true));
let (event_stream, _) = ToolCallEventStream::test();
let result = cx
@@ -429,20 +398,7 @@ mod test {
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log));
+ let tool = Arc::new(ReadFileTool::new(project, action_log, true));
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
@@ -476,20 +432,7 @@ mod test {
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
language_registry.add(language::rust_lang());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log));
+ let tool = Arc::new(ReadFileTool::new(project, action_log, true));
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
@@ -569,20 +512,7 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log));
+ let tool = Arc::new(ReadFileTool::new(project, action_log, true));
let result = cx
.update(|cx| {
let input = ReadFileToolInput {
@@ -614,20 +544,7 @@ mod test {
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log));
+ let tool = Arc::new(ReadFileTool::new(project, action_log, true));
// start_line of 0 should be treated as 1
let result = cx
@@ -757,20 +674,7 @@ mod test {
let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log));
+ let tool = Arc::new(ReadFileTool::new(project, action_log, true));
// Reading a file outside the project worktree should fail
let result = cx
@@ -965,20 +869,7 @@ mod test {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log));
+ let tool = Arc::new(ReadFileTool::new(project, action_log, true));
let (event_stream, mut event_rx) = ToolCallEventStream::test();
let read_task = cx.update(|cx| {
@@ -1084,24 +975,7 @@ mod test {
.await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log.clone(),
- ));
+ let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone(), true));
// Test reading allowed files in worktree1
let result = cx
@@ -1288,24 +1162,7 @@ mod test {
cx.executor().run_until_parked();
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let tool = Arc::new(ReadFileTool::new(project.clone(), action_log, true));
let (event_stream, mut event_rx) = ToolCallEventStream::test();
let task = cx.update(|cx| {
@@ -1364,24 +1221,7 @@ mod test {
cx.executor().run_until_parked();
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let tool = Arc::new(ReadFileTool::new(project.clone(), action_log, true));
let (event_stream, mut event_rx) = ToolCallEventStream::test();
let task = cx.update(|cx| {
@@ -1444,24 +1284,7 @@ mod test {
cx.executor().run_until_parked();
let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let context_server_registry =
- cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
- let tool = Arc::new(ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let tool = Arc::new(ReadFileTool::new(project.clone(), action_log, true));
let (event_stream, mut event_rx) = ToolCallEventStream::test();
let result = cx
@@ -161,29 +161,42 @@ impl AgentTool for SpawnAgentTool {
Ok((subagent, session_info))
})?;
- match subagent.send(input.message, cx).await {
- Ok(output) => {
- session_info.message_end_index =
- cx.update(|cx| Some(subagent.num_entries(cx).saturating_sub(1)));
- event_stream.update_fields_with_meta(
- acp::ToolCallUpdateFields::new().content(vec![output.clone().into()]),
- Some(acp::Meta::from_iter([(
- SUBAGENT_SESSION_INFO_META_KEY.into(),
- serde_json::json!(&session_info),
- )])),
- );
+ let send_result = subagent.send(input.message, cx).await;
+
+ session_info.message_end_index =
+ cx.update(|cx| Some(subagent.num_entries(cx).saturating_sub(1)));
+
+ let meta = Some(acp::Meta::from_iter([(
+ SUBAGENT_SESSION_INFO_META_KEY.into(),
+ serde_json::json!(&session_info),
+ )]));
+
+ let (output, result) = match send_result {
+ Ok(output) => (
+ output.clone(),
Ok(SpawnAgentToolOutput::Success {
session_id: session_info.session_id.clone(),
session_info,
output,
- })
+ }),
+ ),
+ Err(e) => {
+ let error = e.to_string();
+ (
+ error.clone(),
+ Err(SpawnAgentToolOutput::Error {
+ session_id: Some(session_info.session_id.clone()),
+ error,
+ session_info: Some(session_info),
+ }),
+ )
}
- Err(e) => Err(SpawnAgentToolOutput::Error {
- session_id: Some(session_info.session_id.clone()),
- error: e.to_string(),
- session_info: Some(session_info),
- }),
- }
+ };
+ event_stream.update_fields_with_meta(
+ acp::ToolCallUpdateFields::new().content(vec![output.into()]),
+ meta,
+ );
+ result
})
}
@@ -483,7 +483,12 @@ impl EditSession {
.await
.map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
- ensure_buffer_saved(&buffer, &abs_path, tool, cx)?;
+ let action_log = tool
+ .thread
+ .read_with(cx, |thread, _cx| thread.action_log().clone())
+ .ok();
+
+ ensure_buffer_saved(&buffer, &abs_path, tool, action_log.as_ref(), cx)?;
let diff = cx.new(|cx| Diff::new(buffer.clone(), cx));
event_stream.update_diff(diff.clone());
@@ -495,13 +500,9 @@ impl EditSession {
}
}) as Box<dyn FnOnce()>);
- tool.thread
- .update(cx, |thread, cx| {
- thread
- .action_log()
- .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx))
- })
- .ok();
+ if let Some(action_log) = &action_log {
+ action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+ }
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let old_text = cx
@@ -637,18 +638,6 @@ impl EditSession {
log.buffer_edited(buffer.clone(), cx);
});
- if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
- buffer.file().and_then(|file| file.disk_state().mtime())
- }) {
- tool.thread
- .update(cx, |thread, _| {
- thread
- .file_read_times
- .insert(abs_path.to_path_buf(), new_mtime);
- })
- .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
- }
-
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let (new_text, unified_diff) = cx
.background_spawn({
@@ -1018,10 +1007,12 @@ fn ensure_buffer_saved(
buffer: &Entity<Buffer>,
abs_path: &PathBuf,
tool: &StreamingEditFileTool,
+ action_log: Option<&Entity<ActionLog>>,
cx: &mut AsyncApp,
) -> Result<(), StreamingEditFileToolOutput> {
- let check_result = tool.thread.update(cx, |thread, cx| {
- let last_read = thread.file_read_times.get(abs_path).copied();
+ let last_read_mtime =
+ action_log.and_then(|log| log.read_with(cx, |log, _| log.file_read_time(abs_path)));
+ let check_result = tool.thread.read_with(cx, |thread, cx| {
let current = buffer
.read(cx)
.file()
@@ -1029,12 +1020,10 @@ fn ensure_buffer_saved(
let dirty = buffer.read(cx).is_dirty();
let has_save = thread.has_tool(SaveFileTool::NAME);
let has_restore = thread.has_tool(RestoreFileFromDiskTool::NAME);
- (last_read, current, dirty, has_save, has_restore)
+ (current, dirty, has_save, has_restore)
});
- let Ok((last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool)) =
- check_result
- else {
+ let Ok((current_mtime, is_dirty, has_save_tool, has_restore_tool)) = check_result else {
return Ok(());
};
@@ -4006,11 +3995,7 @@ mod tests {
let languages = project.read_with(cx, |project, _| project.languages().clone());
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
- let read_tool = Arc::new(crate::ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true));
let edit_tool = Arc::new(StreamingEditFileTool::new(
project.clone(),
thread.downgrade(),
@@ -4112,11 +4097,7 @@ mod tests {
let languages = project.read_with(cx, |project, _| project.languages().clone());
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
- let read_tool = Arc::new(crate::ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true));
let edit_tool = Arc::new(StreamingEditFileTool::new(
project.clone(),
thread.downgrade(),
@@ -4225,11 +4206,7 @@ mod tests {
let languages = project.read_with(cx, |project, _| project.languages().clone());
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
- let read_tool = Arc::new(crate::ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true));
let edit_tool = Arc::new(StreamingEditFileTool::new(
project.clone(),
thread.downgrade(),
@@ -845,6 +845,10 @@ impl ConnectionView {
);
});
+ if let Some(scroll_position) = thread.read(cx).ui_scroll_position() {
+ list_state.scroll_to(scroll_position);
+ }
+
AgentDiff::set_active_thread(&self.workspace, thread.clone(), window, cx);
let connection = thread.read(cx).connection().clone();
@@ -248,7 +248,8 @@ pub struct ThreadView {
pub resumed_without_history: bool,
pub resume_thread_metadata: Option<AgentSessionInfo>,
pub _cancel_task: Option<Task<()>>,
- _draft_save_task: Option<Task<()>>,
+ _save_task: Option<Task<()>>,
+ _draft_resolve_task: Option<Task<()>>,
pub skip_queue_processing_count: usize,
pub user_interrupted_generation: bool,
pub can_fast_track_queue: bool,
@@ -396,7 +397,7 @@ impl ThreadView {
} else {
Some(editor.update(cx, |editor, cx| editor.draft_contents(cx)))
};
- this._draft_save_task = Some(cx.spawn(async move |this, cx| {
+ this._draft_resolve_task = Some(cx.spawn(async move |this, cx| {
let draft = if let Some(task) = draft_contents_task {
let blocks = task.await.ok().filter(|b| !b.is_empty());
blocks
@@ -407,15 +408,7 @@ impl ThreadView {
this.thread.update(cx, |thread, _cx| {
thread.set_draft_prompt(draft);
});
- })
- .ok();
- cx.background_executor()
- .timer(SERIALIZATION_THROTTLE_TIME)
- .await;
- this.update(cx, |this, cx| {
- if let Some(thread) = this.as_native_thread(cx) {
- thread.update(cx, |_thread, cx| cx.notify());
- }
+ this.schedule_save(cx);
})
.ok();
}));
@@ -471,7 +464,8 @@ impl ThreadView {
is_loading_contents: false,
new_server_version_available: None,
_cancel_task: None,
- _draft_save_task: None,
+ _save_task: None,
+ _draft_resolve_task: None,
skip_queue_processing_count: 0,
user_interrupted_generation: false,
can_fast_track_queue: false,
@@ -487,12 +481,50 @@ impl ThreadView {
_history_subscription: history_subscription,
show_codex_windows_warning,
};
+ let list_state_for_scroll = this.list_state.clone();
+ let thread_view = cx.entity().downgrade();
+ this.list_state
+ .set_scroll_handler(move |_event, _window, cx| {
+ let list_state = list_state_for_scroll.clone();
+ let thread_view = thread_view.clone();
+ // N.B. We must defer because the scroll handler is called while the
+ // ListState's RefCell is mutably borrowed. Reading logical_scroll_top()
+ // directly would panic from a double borrow.
+ cx.defer(move |cx| {
+ let scroll_top = list_state.logical_scroll_top();
+ let _ = thread_view.update(cx, |this, cx| {
+ if let Some(thread) = this.as_native_thread(cx) {
+ thread.update(cx, |thread, _cx| {
+ thread.set_ui_scroll_position(Some(scroll_top));
+ });
+ }
+ this.schedule_save(cx);
+ });
+ });
+ });
+
if should_auto_submit {
this.send(window, cx);
}
this
}
+ /// Schedule a throttled save of the thread state (draft prompt, scroll position, etc.).
+ /// Multiple calls within `SERIALIZATION_THROTTLE_TIME` are coalesced into a single save.
+ fn schedule_save(&mut self, cx: &mut Context<Self>) {
+ self._save_task = Some(cx.spawn(async move |this, cx| {
+ cx.background_executor()
+ .timer(SERIALIZATION_THROTTLE_TIME)
+ .await;
+ this.update(cx, |this, cx| {
+ if let Some(thread) = this.as_native_thread(cx) {
+ thread.update(cx, |_thread, cx| cx.notify());
+ }
+ })
+ .ok();
+ }));
+ }
+
pub fn handle_message_editor_event(
&mut self,
_editor: &Entity<MessageEditor>,
@@ -6736,6 +6768,31 @@ impl ThreadView {
.read(cx)
.pending_tool_call(thread.read(cx).session_id(), cx);
+ let session_id = thread.read(cx).session_id().clone();
+
+ let fullscreen_toggle = h_flex()
+ .id(entry_ix)
+ .py_1()
+ .w_full()
+ .justify_center()
+ .border_t_1()
+ .when(is_failed, |this| this.border_dashed())
+ .border_color(self.tool_card_border_color(cx))
+ .hover(|s| s.bg(cx.theme().colors().element_hover))
+ .child(
+ Icon::new(IconName::Maximize)
+ .color(Color::Muted)
+ .size(IconSize::Small),
+ )
+ .tooltip(Tooltip::text("Make Subagent Full Screen"))
+ .on_click(cx.listener(move |this, _event, window, cx| {
+ this.server_view
+ .update(cx, |this, cx| {
+ this.navigate_to_session(session_id.clone(), window, cx);
+ })
+ .ok();
+ }));
+
if is_running && let Some((_, subagent_tool_call_id, _)) = pending_tool_call {
if let Some((entry_ix, tool_call)) =
thread.read(cx).tool_call(&subagent_tool_call_id)
@@ -6750,11 +6807,11 @@ impl ThreadView {
window,
cx,
))
+ .child(fullscreen_toggle)
} else {
this
}
} else {
- let session_id = thread.read(cx).session_id().clone();
this.when(is_expanded, |this| {
this.child(self.render_subagent_expanded_content(
thread_view,
@@ -6771,34 +6828,7 @@ impl ThreadView {
.title(message),
)
})
- .child(
- h_flex()
- .id(entry_ix)
- .py_1()
- .w_full()
- .justify_center()
- .border_t_1()
- .when(is_failed, |this| this.border_dashed())
- .border_color(self.tool_card_border_color(cx))
- .hover(|s| s.bg(cx.theme().colors().element_hover))
- .child(
- Icon::new(IconName::Maximize)
- .color(Color::Muted)
- .size(IconSize::Small),
- )
- .tooltip(Tooltip::text("Make Subagent Full Screen"))
- .on_click(cx.listener(move |this, _event, window, cx| {
- this.server_view
- .update(cx, |this, cx| {
- this.navigate_to_session(
- session_id.clone(),
- window,
- cx,
- );
- })
- .ok();
- })),
- )
+ .child(fullscreen_toggle)
})
}
})
@@ -19,6 +19,7 @@ log.workspace = true
simplelog.workspace = true
[target.'cfg(target_os = "windows")'.dependencies]
+scopeguard = "1.2"
windows.workspace = true
[target.'cfg(target_os = "windows")'.dev-dependencies]
@@ -1,13 +1,22 @@
use std::{
+ ffi::OsStr,
+ os::windows::ffi::OsStrExt,
path::Path,
sync::LazyLock,
time::{Duration, Instant},
};
use anyhow::{Context as _, Result};
-use windows::Win32::{
- Foundation::{HWND, LPARAM, WPARAM},
- UI::WindowsAndMessaging::PostMessageW,
+use windows::{
+ Win32::{
+ Foundation::{HWND, LPARAM, WPARAM},
+ System::RestartManager::{
+ CCH_RM_SESSION_KEY, RmEndSession, RmGetList, RmRegisterResources, RmShutdown,
+ RmStartSession,
+ },
+ UI::WindowsAndMessaging::PostMessageW,
+ },
+ core::{PCWSTR, PWSTR},
};
use crate::windows_impl::WM_JOB_UPDATED;
@@ -262,9 +271,106 @@ pub(crate) static JOBS: LazyLock<[Job; 9]> = LazyLock::new(|| {
]
});
+/// Attempts to use Windows Restart Manager to release file handles held by other processes
+/// (e.g., Explorer.exe) on the files we need to move during the update.
+///
+/// This is a best-effort operation - if it fails, we'll still try the update and rely on
+/// the retry logic.
+fn release_file_handles(app_dir: &Path) -> Result<()> {
+ // Files that commonly get locked by Explorer or other processes
+ let files_to_release = [
+ app_dir.join("Zed.exe"),
+ app_dir.join("bin\\Zed.exe"),
+ app_dir.join("bin\\zed"),
+ app_dir.join("conpty.dll"),
+ ];
+
+ log::info!("Attempting to release file handles using Restart Manager...");
+
+ let mut session: u32 = 0;
+ let mut session_key = [0u16; CCH_RM_SESSION_KEY as usize + 1];
+
+ // Start a Restart Manager session
+ let err = unsafe {
+ RmStartSession(
+ &mut session,
+ Some(0),
+ PWSTR::from_raw(session_key.as_mut_ptr()),
+ )
+ };
+ if err.is_err() {
+ anyhow::bail!("RmStartSession failed: {err:?}");
+ }
+
+ // Ensure we end the session when done
+ let _session_guard = scopeguard::guard(session, |s| {
+ let _ = unsafe { RmEndSession(s) };
+ });
+
+ // Convert paths to wide strings for Windows API
+ let wide_paths: Vec<Vec<u16>> = files_to_release
+ .iter()
+ .filter(|p| p.exists())
+ .map(|p| {
+ OsStr::new(p)
+ .encode_wide()
+ .chain(std::iter::once(0))
+ .collect()
+ })
+ .collect();
+
+ if wide_paths.is_empty() {
+ log::info!("No files to release handles for");
+ return Ok(());
+ }
+
+ let pcwstr_paths: Vec<PCWSTR> = wide_paths
+ .iter()
+ .map(|p| PCWSTR::from_raw(p.as_ptr()))
+ .collect();
+
+ // Register the files we want to modify
+ let err = unsafe { RmRegisterResources(session, Some(&pcwstr_paths), None, None) };
+ if err.is_err() {
+ anyhow::bail!("RmRegisterResources failed: {err:?}");
+ }
+
+ // Check if any processes are using these files
+ let mut needed: u32 = 0;
+ let mut count: u32 = 0;
+ let mut reboot_reasons: u32 = 0;
+ let _ = unsafe { RmGetList(session, &mut needed, &mut count, None, &mut reboot_reasons) };
+
+ if needed == 0 {
+ log::info!("No processes are holding handles to the files");
+ return Ok(());
+ }
+
+ log::info!(
+ "{} process(es) are holding handles to the files, requesting release...",
+ needed
+ );
+
+ // Request processes to release their handles
+ // RmShutdown with flags=0 asks applications to release handles gracefully
+ // For Explorer, this typically releases icon cache handles without closing Explorer
+ let err = unsafe { RmShutdown(session, 0, None) };
+ if err.is_err() {
+ anyhow::bail!("RmShutdown failed: {:?}", err);
+ }
+
+ log::info!("Successfully requested handle release");
+ Ok(())
+}
+
pub(crate) fn perform_update(app_dir: &Path, hwnd: Option<isize>, launch: bool) -> Result<()> {
let hwnd = hwnd.map(|ptr| HWND(ptr as _));
+ // Try to release file handles before starting the update
+ if let Err(e) = release_file_handles(app_dir) {
+ log::warn!("Restart Manager failed (will continue anyway): {}", e);
+ }
+
let mut last_successful_job = None;
'outer: for (i, job) in JOBS.iter().enumerate() {
let start = Instant::now();
@@ -279,19 +385,22 @@ pub(crate) fn perform_update(app_dir: &Path, hwnd: Option<isize>, launch: bool)
unsafe { PostMessageW(hwnd, WM_JOB_UPDATED, WPARAM(0), LPARAM(0))? };
break;
}
- Err(err) => {
- // Check if it's a "not found" error
- let io_err = err.downcast_ref::<std::io::Error>().unwrap();
- if io_err.kind() == std::io::ErrorKind::NotFound {
- log::warn!("File or folder not found.");
- last_successful_job = Some(i);
- unsafe { PostMessageW(hwnd, WM_JOB_UPDATED, WPARAM(0), LPARAM(0))? };
- break;
+ Err(err) => match err.downcast_ref::<std::io::Error>() {
+ Some(io_err) => match io_err.kind() {
+ std::io::ErrorKind::NotFound => {
+ log::error!("Operation failed with file not found, aborting: {}", err);
+ break 'outer;
+ }
+ _ => {
+ log::error!("Operation failed (retrying): {}", err);
+ std::thread::sleep(Duration::from_millis(50));
+ }
+ },
+ None => {
+ log::error!("Operation failed with unexpected error, aborting: {}", err);
+ break 'outer;
}
-
- log::error!("Operation failed: {} ({:?})", err, io_err.kind());
- std::thread::sleep(Duration::from_millis(50));
- }
+ },
}
}
}
@@ -9,7 +9,9 @@ use futures::AsyncReadExt as _;
use gpui::{App, Task};
use gpui_tokio::Tokio;
use http_client::http::request;
-use http_client::{AsyncBody, HttpClientWithUrl, HttpRequestExt, Method, Request, StatusCode};
+use http_client::{
+ AsyncBody, HttpClientWithUrl, HttpRequestExt, Json, Method, Request, StatusCode,
+};
use parking_lot::RwLock;
use thiserror::Error;
use yawc::WebSocket;
@@ -141,6 +143,7 @@ impl CloudApiClient {
pub async fn create_llm_token(
&self,
system_id: Option<String>,
+ organization_id: Option<OrganizationId>,
) -> Result<CreateLlmTokenResponse, ClientApiError> {
let request_builder = Request::builder()
.method(Method::POST)
@@ -153,7 +156,10 @@ impl CloudApiClient {
builder.header(ZED_SYSTEM_ID_HEADER_NAME, system_id)
});
- let request = self.build_request(request_builder, AsyncBody::default())?;
+ let request = self.build_request(
+ request_builder,
+ Json(CreateLlmTokenBody { organization_id }),
+ )?;
let mut response = self.http_client.send(request).await?;
@@ -52,6 +52,12 @@ pub struct AcceptTermsOfServiceResponse {
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct LlmToken(pub String);
+#[derive(Debug, Default, PartialEq, Clone, Serialize, Deserialize)]
+pub struct CreateLlmTokenBody {
+ #[serde(default)]
+ pub organization_id: Option<OrganizationId>,
+}
+
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct CreateLlmTokenResponse {
pub token: LlmToken,
@@ -437,6 +437,8 @@ impl Server {
.add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
.add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
.add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
+ .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
+ .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
.add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
.add_message_handler(update_context)
@@ -1,15 +1,14 @@
-use std::path::Path;
+use std::path::{Path, PathBuf};
use call::ActiveCall;
use git::status::{FileStatus, StatusCode, TrackedStatus};
use git_ui::project_diff::ProjectDiff;
-use gpui::{AppContext as _, TestAppContext, VisualTestContext};
+use gpui::{AppContext as _, BackgroundExecutor, TestAppContext, VisualTestContext};
use project::ProjectPath;
use serde_json::json;
use util::{path, rel_path::rel_path};
use workspace::{MultiWorkspace, Workspace};
-//
use crate::TestServer;
#[gpui::test]
@@ -141,3 +140,142 @@ async fn test_project_diff(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext)
);
});
}
+
+#[gpui::test]
+async fn test_remote_git_worktrees(
+ executor: BackgroundExecutor,
+ cx_a: &mut TestAppContext,
+ cx_b: &mut TestAppContext,
+) {
+ let mut server = TestServer::start(executor.clone()).await;
+ let client_a = server.create_client(cx_a, "user_a").await;
+ let client_b = server.create_client(cx_b, "user_b").await;
+ server
+ .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)])
+ .await;
+ let active_call_a = cx_a.read(ActiveCall::global);
+
+ client_a
+ .fs()
+ .insert_tree(
+ path!("/project"),
+ json!({ ".git": {}, "file.txt": "content" }),
+ )
+ .await;
+
+ let (project_a, _) = client_a.build_local_project(path!("/project"), cx_a).await;
+
+ let project_id = active_call_a
+ .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx))
+ .await
+ .unwrap();
+ let project_b = client_b.join_remote_project(project_id, cx_b).await;
+
+ executor.run_until_parked();
+
+ let repo_b = cx_b.update(|cx| project_b.read(cx).active_repository(cx).unwrap());
+
+ // Initially only the main worktree (the repo itself) should be present
+ let worktrees = cx_b
+ .update(|cx| repo_b.update(cx, |repository, _| repository.worktrees()))
+ .await
+ .unwrap()
+ .unwrap();
+ assert_eq!(worktrees.len(), 1);
+ assert_eq!(worktrees[0].path, PathBuf::from(path!("/project")));
+
+ // Client B creates a git worktree via the remote project
+ let worktree_directory = PathBuf::from(path!("/project"));
+ cx_b.update(|cx| {
+ repo_b.update(cx, |repository, _| {
+ repository.create_worktree(
+ "feature-branch".to_string(),
+ worktree_directory.clone(),
+ Some("abc123".to_string()),
+ )
+ })
+ })
+ .await
+ .unwrap()
+ .unwrap();
+
+ executor.run_until_parked();
+
+ // Client B lists worktrees — should see main + the one just created
+ let worktrees = cx_b
+ .update(|cx| repo_b.update(cx, |repository, _| repository.worktrees()))
+ .await
+ .unwrap()
+ .unwrap();
+ assert_eq!(worktrees.len(), 2);
+ assert_eq!(worktrees[0].path, PathBuf::from(path!("/project")));
+ assert_eq!(worktrees[1].path, worktree_directory.join("feature-branch"));
+ assert_eq!(worktrees[1].ref_name.as_ref(), "refs/heads/feature-branch");
+ assert_eq!(worktrees[1].sha.as_ref(), "abc123");
+
+ // Verify from the host side that the worktree was actually created
+ let host_worktrees = {
+ let repo_a = cx_a.update(|cx| {
+ project_a
+ .read(cx)
+ .repositories(cx)
+ .values()
+ .next()
+ .unwrap()
+ .clone()
+ });
+ cx_a.update(|cx| repo_a.update(cx, |repository, _| repository.worktrees()))
+ .await
+ .unwrap()
+ .unwrap()
+ };
+ assert_eq!(host_worktrees.len(), 2);
+ assert_eq!(host_worktrees[0].path, PathBuf::from(path!("/project")));
+ assert_eq!(
+ host_worktrees[1].path,
+ worktree_directory.join("feature-branch")
+ );
+
+ // Client B creates a second git worktree without an explicit commit
+ cx_b.update(|cx| {
+ repo_b.update(cx, |repository, _| {
+ repository.create_worktree(
+ "bugfix-branch".to_string(),
+ worktree_directory.clone(),
+ None,
+ )
+ })
+ })
+ .await
+ .unwrap()
+ .unwrap();
+
+ executor.run_until_parked();
+
+ // Client B lists worktrees — should now have main + two created
+ let worktrees = cx_b
+ .update(|cx| repo_b.update(cx, |repository, _| repository.worktrees()))
+ .await
+ .unwrap()
+ .unwrap();
+ assert_eq!(worktrees.len(), 3);
+
+ let feature_worktree = worktrees
+ .iter()
+ .find(|worktree| worktree.ref_name.as_ref() == "refs/heads/feature-branch")
+ .expect("should find feature-branch worktree");
+ assert_eq!(
+ feature_worktree.path,
+ worktree_directory.join("feature-branch")
+ );
+
+ let bugfix_worktree = worktrees
+ .iter()
+ .find(|worktree| worktree.ref_name.as_ref() == "refs/heads/bugfix-branch")
+ .expect("should find bugfix-branch worktree");
+ assert_eq!(
+ bugfix_worktree.path,
+ worktree_directory.join("bugfix-branch")
+ );
+ assert_eq!(bugfix_worktree.sha.as_ref(), "fake-sha");
+}
@@ -7205,3 +7205,89 @@ async fn test_remote_git_branches(
assert_eq!(host_branch.name(), "totally-new-branch");
}
+
+#[gpui::test]
+async fn test_guest_can_rejoin_shared_project_after_leaving_call(
+ executor: BackgroundExecutor,
+ cx_a: &mut TestAppContext,
+ cx_b: &mut TestAppContext,
+ cx_c: &mut TestAppContext,
+) {
+ let mut server = TestServer::start(executor.clone()).await;
+ let client_a = server.create_client(cx_a, "user_a").await;
+ let client_b = server.create_client(cx_b, "user_b").await;
+ let client_c = server.create_client(cx_c, "user_c").await;
+
+ server
+ .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b), (&client_c, cx_c)])
+ .await;
+
+ client_a
+ .fs()
+ .insert_tree(
+ path!("/project"),
+ json!({
+ "file.txt": "hello\n",
+ }),
+ )
+ .await;
+
+ let (project_a, _worktree_id) = client_a.build_local_project(path!("/project"), cx_a).await;
+ let active_call_a = cx_a.read(ActiveCall::global);
+ let project_id = active_call_a
+ .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx))
+ .await
+ .unwrap();
+
+ let _project_b = client_b.join_remote_project(project_id, cx_b).await;
+ executor.run_until_parked();
+
+ // third client joins call to prevent room from being torn down
+ let _project_c = client_c.join_remote_project(project_id, cx_c).await;
+ executor.run_until_parked();
+
+ let active_call_b = cx_b.read(ActiveCall::global);
+ active_call_b
+ .update(cx_b, |call, cx| call.hang_up(cx))
+ .await
+ .unwrap();
+ executor.run_until_parked();
+
+ let user_id_b = client_b.current_user_id(cx_b).to_proto();
+ let active_call_a = cx_a.read(ActiveCall::global);
+ active_call_a
+ .update(cx_a, |call, cx| call.invite(user_id_b, None, cx))
+ .await
+ .unwrap();
+ executor.run_until_parked();
+ let active_call_b = cx_b.read(ActiveCall::global);
+ active_call_b
+ .update(cx_b, |call, cx| call.accept_incoming(cx))
+ .await
+ .unwrap();
+ executor.run_until_parked();
+
+ let _project_b2 = client_b.join_remote_project(project_id, cx_b).await;
+ executor.run_until_parked();
+
+ project_a.read_with(cx_a, |project, _| {
+ let guest_count = project
+ .collaborators()
+ .values()
+ .filter(|c| !c.is_host)
+ .count();
+
+        assert_eq!(
+            guest_count, 2,
+            "host should have exactly two guest collaborators after rejoin"
+        );
+ });
+
+ _project_b.read_with(cx_b, |project, _| {
+ assert_eq!(
+ project.client_subscriptions().len(),
+ 0,
+ "We should clear all host subscriptions after leaving the project"
+ );
+ })
+}
@@ -33,7 +33,7 @@ use settings::{
SettingsStore,
};
use std::{
- path::Path,
+ path::{Path, PathBuf},
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
@@ -396,6 +396,130 @@ async fn test_ssh_collaboration_git_branches(
});
}
+#[gpui::test]
+async fn test_ssh_collaboration_git_worktrees(
+ executor: BackgroundExecutor,
+ cx_a: &mut TestAppContext,
+ cx_b: &mut TestAppContext,
+ server_cx: &mut TestAppContext,
+) {
+ cx_a.set_name("a");
+ cx_b.set_name("b");
+ server_cx.set_name("server");
+
+ cx_a.update(|cx| {
+ release_channel::init(semver::Version::new(0, 0, 0), cx);
+ });
+ server_cx.update(|cx| {
+ release_channel::init(semver::Version::new(0, 0, 0), cx);
+ });
+
+ let mut server = TestServer::start(executor.clone()).await;
+ let client_a = server.create_client(cx_a, "user_a").await;
+ let client_b = server.create_client(cx_b, "user_b").await;
+ server
+ .create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)])
+ .await;
+
+ let (opts, server_ssh, _) = RemoteClient::fake_server(cx_a, server_cx);
+ let remote_fs = FakeFs::new(server_cx.executor());
+ remote_fs
+ .insert_tree("/project", json!({ ".git": {}, "file.txt": "content" }))
+ .await;
+
+ server_cx.update(HeadlessProject::init);
+ let languages = Arc::new(LanguageRegistry::new(server_cx.executor()));
+ let headless_project = server_cx.new(|cx| {
+ HeadlessProject::new(
+ HeadlessAppState {
+ session: server_ssh,
+ fs: remote_fs.clone(),
+ http_client: Arc::new(BlockedHttpClient),
+ node_runtime: NodeRuntime::unavailable(),
+ languages,
+ extension_host_proxy: Arc::new(ExtensionHostProxy::new()),
+ startup_time: std::time::Instant::now(),
+ },
+ false,
+ cx,
+ )
+ });
+
+ let client_ssh = RemoteClient::connect_mock(opts, cx_a).await;
+ let (project_a, _) = client_a
+ .build_ssh_project("/project", client_ssh, false, cx_a)
+ .await;
+
+ let active_call_a = cx_a.read(ActiveCall::global);
+ let project_id = active_call_a
+ .update(cx_a, |call, cx| call.share_project(project_a.clone(), cx))
+ .await
+ .unwrap();
+ let project_b = client_b.join_remote_project(project_id, cx_b).await;
+
+ executor.run_until_parked();
+
+ let repo_b = cx_b.update(|cx| project_b.read(cx).active_repository(cx).unwrap());
+
+ let worktrees = cx_b
+ .update(|cx| repo_b.update(cx, |repo, _| repo.worktrees()))
+ .await
+ .unwrap()
+ .unwrap();
+ assert_eq!(worktrees.len(), 1);
+
+ let worktree_directory = PathBuf::from("/project");
+ cx_b.update(|cx| {
+ repo_b.update(cx, |repo, _| {
+ repo.create_worktree(
+ "feature-branch".to_string(),
+ worktree_directory.clone(),
+ Some("abc123".to_string()),
+ )
+ })
+ })
+ .await
+ .unwrap()
+ .unwrap();
+
+ executor.run_until_parked();
+
+ let worktrees = cx_b
+ .update(|cx| repo_b.update(cx, |repo, _| repo.worktrees()))
+ .await
+ .unwrap()
+ .unwrap();
+ assert_eq!(worktrees.len(), 2);
+ assert_eq!(worktrees[1].path, worktree_directory.join("feature-branch"));
+ assert_eq!(worktrees[1].ref_name.as_ref(), "refs/heads/feature-branch");
+ assert_eq!(worktrees[1].sha.as_ref(), "abc123");
+
+ let server_worktrees = {
+ let server_repo = server_cx.update(|cx| {
+ headless_project.update(cx, |headless_project, cx| {
+ headless_project
+ .git_store
+ .read(cx)
+ .repositories()
+ .values()
+ .next()
+ .unwrap()
+ .clone()
+ })
+ });
+ server_cx
+ .update(|cx| server_repo.update(cx, |repo, _| repo.worktrees()))
+ .await
+ .unwrap()
+ .unwrap()
+ };
+ assert_eq!(server_worktrees.len(), 2);
+ assert_eq!(
+ server_worktrees[1].path,
+ worktree_directory.join("feature-branch")
+ );
+}
+
#[gpui::test]
async fn test_ssh_collaboration_formatting_with_prettier(
executor: BackgroundExecutor,
@@ -578,6 +578,7 @@ fn handle_postprocessing() -> Result<()> {
.expect("Default title not a string")
.to_string();
let amplitude_key = std::env::var("DOCS_AMPLITUDE_API_KEY").unwrap_or_default();
+ let consent_io_instance = std::env::var("DOCS_CONSENT_IO_INSTANCE").unwrap_or_default();
output.insert("html".to_string(), zed_html);
mdbook::Renderer::render(&mdbook::renderer::HtmlHandlebars::new(), &ctx)?;
@@ -647,6 +648,7 @@ fn handle_postprocessing() -> Result<()> {
zlog::trace!(logger => "Updating {:?}", pretty_path(&file, &root_dir));
let contents = contents.replace("#description#", meta_description);
    let contents = contents.replace("#amplitude_key#", &amplitude_key);
+ let contents = contents.replace("#consent_io_instance#", &consent_io_instance);
let contents = title_regex()
            .replace(&contents, |_: &regex::Captures| {
format!("<title>{}</title>", meta_title)
@@ -1,7 +1,7 @@
use anyhow::Result;
use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
-use cloud_api_types::SubmitEditPredictionFeedbackBody;
+use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
use cloud_llm_client::predict_edits_v3::{
PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
};
@@ -69,6 +69,7 @@ pub mod sweep_ai;
pub mod udiff;
mod capture_example;
+pub mod open_ai_compatible;
mod zed_edit_prediction_delegate;
pub mod zeta;
@@ -107,13 +108,8 @@ const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
-pub struct Zeta2FeatureFlag;
pub struct EditPredictionJumpsFeatureFlag;
-impl FeatureFlag for Zeta2FeatureFlag {
- const NAME: &'static str = "zeta2";
-}
-
impl FeatureFlag for EditPredictionJumpsFeatureFlag {
const NAME: &'static str = "edit_prediction_jumps";
}
@@ -129,6 +125,7 @@ impl Global for EditPredictionStoreGlobal {}
#[derive(Clone)]
pub struct Zeta2RawConfig {
pub model_id: Option<String>,
+ pub environment: Option<String>,
pub format: ZetaFormat,
}
@@ -147,7 +144,7 @@ pub struct EditPredictionStore {
pub sweep_ai: SweepAi,
pub mercury: Mercury,
data_collection_choice: DataCollectionChoice,
- reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
+ reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
settled_predictions_tx: mpsc::UnboundedSender<Instant>,
shown_predictions: VecDeque<EditPrediction>,
rated_predictions: HashSet<EditPredictionId>,
@@ -155,6 +152,11 @@ pub struct EditPredictionStore {
settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
+pub(crate) struct EditPredictionRejectionPayload {
+ rejection: EditPredictionRejection,
+ organization_id: Option<OrganizationId>,
+}
+
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
Zeta,
@@ -723,8 +725,13 @@ impl EditPredictionStore {
|this, _listener, _event, cx| {
let client = this.client.clone();
let llm_token = this.llm_token.clone();
+ let organization_id = this
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|organization| organization.id.clone());
cx.spawn(async move |_this, _cx| {
- llm_token.refresh(&client).await?;
+ llm_token.refresh(&client, organization_id).await?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
@@ -754,7 +761,12 @@ impl EditPredictionStore {
let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
let format = ZetaFormat::parse(&version_str).ok()?;
let model_id = env::var("ZED_ZETA_MODEL").ok();
- Some(Zeta2RawConfig { model_id, format })
+ let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
+ Some(Zeta2RawConfig {
+ model_id,
+ environment,
+ format,
+ })
}
pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
@@ -785,11 +797,17 @@ impl EditPredictionStore {
let client = self.client.clone();
let llm_token = self.llm_token.clone();
let app_version = AppVersion::global(cx);
+ let organization_id = self
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|organization| organization.id.clone());
+
cx.spawn(async move |this, cx| {
let experiments = cx
.background_spawn(async move {
let http_client = client.http_client();
- let token = llm_token.acquire(&client).await?;
+ let token = llm_token.acquire(&client, organization_id).await?;
let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
let request = http_client::Request::builder()
.method(Method::GET)
@@ -1428,7 +1446,7 @@ impl EditPredictionStore {
}
async fn handle_rejected_predictions(
- rx: UnboundedReceiver<EditPredictionRejection>,
+ rx: UnboundedReceiver<EditPredictionRejectionPayload>,
client: Arc<Client>,
llm_token: LlmApiToken,
app_version: Version,
@@ -1437,7 +1455,11 @@ impl EditPredictionStore {
let mut rx = std::pin::pin!(rx.peekable());
let mut batched = Vec::new();
- while let Some(rejection) = rx.next().await {
+ while let Some(EditPredictionRejectionPayload {
+ rejection,
+ organization_id,
+ }) = rx.next().await
+ {
batched.push(rejection);
if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
@@ -1475,6 +1497,7 @@ impl EditPredictionStore {
},
client.clone(),
llm_token.clone(),
+ organization_id,
app_version.clone(),
true,
)
@@ -1680,13 +1703,23 @@ impl EditPredictionStore {
all_language_settings(None, cx).edit_predictions.provider,
EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
);
+
if is_cloud {
+ let organization_id = self
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|organization| organization.id.clone());
+
self.reject_predictions_tx
- .unbounded_send(EditPredictionRejection {
- request_id: prediction_id.to_string(),
- reason,
- was_shown,
- model_version,
+ .unbounded_send(EditPredictionRejectionPayload {
+ rejection: EditPredictionRejection {
+ request_id: prediction_id.to_string(),
+ reason,
+ was_shown,
+ model_version,
+ },
+ organization_id,
})
.log_err();
}
@@ -2108,7 +2141,7 @@ impl EditPredictionStore {
active_buffer.clone(),
position,
trigger,
- cx.has_flag::<Zeta2FeatureFlag>(),
+ cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
cx,
)
}
@@ -2341,6 +2374,7 @@ impl EditPredictionStore {
client: Arc<Client>,
custom_url: Option<Arc<Url>>,
llm_token: LlmApiToken,
+ organization_id: Option<OrganizationId>,
app_version: Version,
) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
let url = if let Some(custom_url) = custom_url {
@@ -2360,6 +2394,7 @@ impl EditPredictionStore {
},
client,
llm_token,
+ organization_id,
app_version,
true,
)
@@ -2370,6 +2405,7 @@ impl EditPredictionStore {
input: ZetaPromptInput,
client: Arc<Client>,
llm_token: LlmApiToken,
+ organization_id: Option<OrganizationId>,
app_version: Version,
trigger: PredictEditsRequestTrigger,
) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
@@ -2392,6 +2428,7 @@ impl EditPredictionStore {
},
client,
llm_token,
+ organization_id,
app_version,
true,
)
@@ -2445,6 +2482,7 @@ impl EditPredictionStore {
build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
client: Arc<Client>,
llm_token: LlmApiToken,
+ organization_id: Option<OrganizationId>,
app_version: Version,
require_auth: bool,
) -> Result<(Res, Option<EditPredictionUsage>)>
@@ -2454,9 +2492,12 @@ impl EditPredictionStore {
let http_client = client.http_client();
let mut token = if require_auth {
- Some(llm_token.acquire(&client).await?)
+ Some(llm_token.acquire(&client, organization_id.clone()).await?)
} else {
- llm_token.acquire(&client).await.ok()
+ llm_token
+ .acquire(&client, organization_id.clone())
+ .await
+ .ok()
};
let mut did_retry = false;
@@ -2498,7 +2539,7 @@ impl EditPredictionStore {
return Ok((serde_json::from_slice(&body)?, usage));
} else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
did_retry = true;
- token = Some(llm_token.refresh(&client).await?);
+ token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
@@ -1,6 +1,7 @@
use crate::{
- EditPredictionId, EditPredictionModelInput, cursor_excerpt, prediction::EditPredictionResult,
- zeta,
+ EditPredictionId, EditPredictionModelInput, cursor_excerpt,
+ open_ai_compatible::{self, load_open_ai_compatible_api_key_if_needed},
+ prediction::EditPredictionResult,
};
use anyhow::{Context as _, Result, anyhow};
use gpui::{App, AppContext as _, Entity, Task};
@@ -58,6 +59,8 @@ pub fn request_prediction(
return Task::ready(Err(anyhow!("Unsupported edit prediction provider for FIM")));
};
+ let api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);
+
let result = cx.background_spawn(async move {
let (excerpt_range, _) = cursor_excerpt::editable_and_context_ranges_for_cursor_position(
cursor_point,
@@ -90,12 +93,14 @@ pub fn request_prediction(
let stop_tokens = get_fim_stop_tokens();
let max_tokens = settings.max_output_tokens;
- let (response_text, request_id) = zeta::send_custom_server_request(
+
+ let (response_text, request_id) = open_ai_compatible::send_custom_server_request(
provider,
&settings,
prompt,
max_tokens,
stop_tokens,
+ api_key,
&http_client,
)
.await?;
@@ -0,0 +1,133 @@
+use anyhow::{Context as _, Result};
+use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
+use futures::AsyncReadExt as _;
+use gpui::{App, AppContext as _, Entity, Global, SharedString, Task, http_client};
+use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings};
+use language_model::{ApiKeyState, EnvVar, env_var};
+use std::sync::Arc;
+
+pub fn open_ai_compatible_api_url(cx: &App) -> SharedString {
+ all_language_settings(None, cx)
+ .edit_predictions
+ .open_ai_compatible_api
+ .as_ref()
+ .map(|settings| settings.api_url.clone())
+ .unwrap_or_default()
+ .into()
+}
+
+pub const OPEN_AI_COMPATIBLE_CREDENTIALS_USERNAME: &str = "openai-compatible-api-token";
+pub static OPEN_AI_COMPATIBLE_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> =
+ env_var!("ZED_OPEN_AI_COMPATIBLE_EDIT_PREDICTION_API_KEY");
+
+struct GlobalOpenAiCompatibleApiKey(Entity<ApiKeyState>);
+
+impl Global for GlobalOpenAiCompatibleApiKey {}
+
+pub fn open_ai_compatible_api_token(cx: &mut App) -> Entity<ApiKeyState> {
+ if let Some(global) = cx.try_global::<GlobalOpenAiCompatibleApiKey>() {
+ return global.0.clone();
+ }
+
+ let entity = cx.new(|cx| {
+ ApiKeyState::new(
+ open_ai_compatible_api_url(cx),
+ OPEN_AI_COMPATIBLE_TOKEN_ENV_VAR.clone(),
+ )
+ });
+ cx.set_global(GlobalOpenAiCompatibleApiKey(entity.clone()));
+ entity
+}
+
+pub fn load_open_ai_compatible_api_token(
+ cx: &mut App,
+) -> Task<Result<(), language_model::AuthenticateError>> {
+ let api_url = open_ai_compatible_api_url(cx);
+ open_ai_compatible_api_token(cx).update(cx, |key_state, cx| {
+ key_state.load_if_needed(api_url, |s| s, cx)
+ })
+}
+
+pub fn load_open_ai_compatible_api_key_if_needed(
+ provider: settings::EditPredictionProvider,
+ cx: &mut App,
+) -> Option<Arc<str>> {
+ if provider != settings::EditPredictionProvider::OpenAiCompatibleApi {
+ return None;
+ }
+ _ = load_open_ai_compatible_api_token(cx);
+ let url = open_ai_compatible_api_url(cx);
+ return open_ai_compatible_api_token(cx).read(cx).key(&url);
+}
+
+pub(crate) async fn send_custom_server_request(
+ provider: settings::EditPredictionProvider,
+ settings: &OpenAiCompatibleEditPredictionSettings,
+ prompt: String,
+ max_tokens: u32,
+ stop_tokens: Vec<String>,
+ api_key: Option<Arc<str>>,
+ http_client: &Arc<dyn http_client::HttpClient>,
+) -> Result<(String, String)> {
+ match provider {
+ settings::EditPredictionProvider::Ollama => {
+ let response = crate::ollama::make_request(
+ settings.clone(),
+ prompt,
+ stop_tokens,
+ http_client.clone(),
+ )
+ .await?;
+ Ok((response.response, response.created_at))
+ }
+ _ => {
+ let request = RawCompletionRequest {
+ model: settings.model.clone(),
+ prompt,
+ max_tokens: Some(max_tokens),
+ temperature: None,
+ stop: stop_tokens
+ .into_iter()
+ .map(std::borrow::Cow::Owned)
+ .collect(),
+ environment: None,
+ };
+
+ let request_body = serde_json::to_string(&request)?;
+ let mut http_request_builder = http_client::Request::builder()
+ .method(http_client::Method::POST)
+ .uri(settings.api_url.as_ref())
+ .header("Content-Type", "application/json");
+
+ if let Some(api_key) = api_key {
+ http_request_builder =
+ http_request_builder.header("Authorization", format!("Bearer {}", api_key));
+ }
+
+ let http_request =
+ http_request_builder.body(http_client::AsyncBody::from(request_body))?;
+
+ let mut response = http_client.send(http_request).await?;
+ let status = response.status();
+
+ if !status.is_success() {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ anyhow::bail!("custom server error: {} - {}", status, body);
+ }
+
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ let parsed: RawCompletionResponse =
+ serde_json::from_str(&body).context("Failed to parse completion response")?;
+ let text = parsed
+ .choices
+ .into_iter()
+ .next()
+ .map(|choice| choice.text)
+ .unwrap_or_default();
+ Ok((text, parsed.id))
+ }
+ }
+}
@@ -2,29 +2,30 @@ use crate::cursor_excerpt::compute_excerpt_ranges;
use crate::prediction::EditPredictionResult;
use crate::{
CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
- EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, ollama,
+ EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
};
-use anyhow::{Context as _, Result};
-use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
+use anyhow::Result;
+use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
use edit_prediction_types::PredictedCursorPosition;
-use futures::AsyncReadExt as _;
-use gpui::{App, AppContext as _, Task, http_client, prelude::*};
-use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings};
+use gpui::{App, AppContext as _, Task, prelude::*};
+use language::language_settings::all_language_settings;
use language::{BufferSnapshot, ToOffset as _, ToPoint, text_diff};
use release_channel::AppVersion;
use settings::EditPredictionPromptFormat;
use text::{Anchor, Bias};
-use std::env;
-use std::ops::Range;
-use std::{path::Path, sync::Arc, time::Instant};
+use std::{env, ops::Range, path::Path, sync::Arc, time::Instant};
use zeta_prompt::{
CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt, get_prefill,
- prompt_input_contains_special_tokens,
+ output_with_context_for_format, prompt_input_contains_special_tokens,
zeta1::{self, EDITABLE_REGION_END_MARKER},
};
+use crate::open_ai_compatible::{
+ load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
+};
+
pub fn request_prediction_with_zeta(
store: &mut EditPredictionStore,
EditPredictionModelInput {
@@ -56,6 +57,7 @@ pub fn request_prediction_with_zeta(
let buffer_snapshotted_at = Instant::now();
let raw_config = store.zeta2_raw_config().cloned();
let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
+ let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);
let excerpt_path: Arc<Path> = snapshot
.file()
@@ -64,6 +66,11 @@ pub fn request_prediction_with_zeta(
let client = store.client.clone();
let llm_token = store.llm_token.clone();
+ let organization_id = store
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|organization| organization.id.clone());
let app_version = AppVersion::global(cx);
let request_task = cx.background_spawn({
@@ -131,6 +138,7 @@ pub fn request_prediction_with_zeta(
prompt,
max_tokens,
stop_tokens,
+ open_ai_compatible_api_key.clone(),
&http_client,
)
.await?;
@@ -157,6 +165,7 @@ pub fn request_prediction_with_zeta(
prompt,
max_tokens,
vec![],
+ open_ai_compatible_api_key.clone(),
&http_client,
)
.await?;
@@ -177,13 +186,17 @@ pub fn request_prediction_with_zeta(
let prompt = format_zeta_prompt(&prompt_input, config.format);
let prefill = get_prefill(&prompt_input, config.format);
let prompt = format!("{prompt}{prefill}");
+ let environment = config
+ .environment
+ .clone()
+ .or_else(|| Some(config.format.to_string().to_lowercase()));
let request = RawCompletionRequest {
model: config.model_id.clone().unwrap_or_default(),
prompt,
temperature: None,
stop: vec![],
max_tokens: Some(2048),
- environment: Some(config.format.to_string().to_lowercase()),
+ environment,
};
editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format(
@@ -197,6 +210,7 @@ pub fn request_prediction_with_zeta(
client,
None,
llm_token,
+ organization_id,
app_version,
)
.await?;
@@ -215,6 +229,7 @@ pub fn request_prediction_with_zeta(
prompt_input.clone(),
client,
llm_token,
+ organization_id,
app_version,
trigger,
)
@@ -240,6 +255,25 @@ pub fn request_prediction_with_zeta(
return Ok((Some((request_id, None, model_version)), usage));
};
+ let editable_range_in_buffer = editable_range_in_excerpt.start
+ + full_context_offset_range.start
+ ..editable_range_in_excerpt.end + full_context_offset_range.start;
+
+ let mut old_text = snapshot
+ .text_for_range(editable_range_in_buffer.clone())
+ .collect::<String>();
+
+ // For the hashline format, the model may return <|set|>/<|insert|>
+ // edit commands instead of a full replacement. Apply them against
+ // the original editable region to produce the full replacement text.
+ // This must happen before cursor marker stripping because the cursor
+ // marker is embedded inside edit command content.
+ if let Some(rewritten_output) =
+ output_with_context_for_format(zeta_version, &old_text, &output_text)?
+ {
+ output_text = rewritten_output;
+ }
+
// Client-side cursor marker processing (applies to both raw and v3 responses)
let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
if let Some(offset) = cursor_offset_in_output {
@@ -259,14 +293,6 @@ pub fn request_prediction_with_zeta(
.ok();
}
- let editable_range_in_buffer = editable_range_in_excerpt.start
- + full_context_offset_range.start
- ..editable_range_in_excerpt.end + full_context_offset_range.start;
-
- let mut old_text = snapshot
- .text_for_range(editable_range_in_buffer.clone())
- .collect::<String>();
-
if !output_text.is_empty() && !output_text.ends_with('\n') {
output_text.push('\n');
}
@@ -400,66 +426,6 @@ pub fn zeta2_prompt_input(
(full_context_offset_range, prompt_input)
}
-pub(crate) async fn send_custom_server_request(
- provider: settings::EditPredictionProvider,
- settings: &OpenAiCompatibleEditPredictionSettings,
- prompt: String,
- max_tokens: u32,
- stop_tokens: Vec<String>,
- http_client: &Arc<dyn http_client::HttpClient>,
-) -> Result<(String, String)> {
- match provider {
- settings::EditPredictionProvider::Ollama => {
- let response =
- ollama::make_request(settings.clone(), prompt, stop_tokens, http_client.clone())
- .await?;
- Ok((response.response, response.created_at))
- }
- _ => {
- let request = RawCompletionRequest {
- model: settings.model.clone(),
- prompt,
- max_tokens: Some(max_tokens),
- temperature: None,
- stop: stop_tokens
- .into_iter()
- .map(std::borrow::Cow::Owned)
- .collect(),
- environment: None,
- };
-
- let request_body = serde_json::to_string(&request)?;
- let http_request = http_client::Request::builder()
- .method(http_client::Method::POST)
- .uri(settings.api_url.as_ref())
- .header("Content-Type", "application/json")
- .body(http_client::AsyncBody::from(request_body))?;
-
- let mut response = http_client.send(http_request).await?;
- let status = response.status();
-
- if !status.is_success() {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- anyhow::bail!("custom server error: {} - {}", status, body);
- }
-
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
-
- let parsed: RawCompletionResponse =
- serde_json::from_str(&body).context("Failed to parse completion response")?;
- let text = parsed
- .choices
- .into_iter()
- .next()
- .map(|choice| choice.text)
- .unwrap_or_default();
- Ok((text, parsed.id))
- }
- }
-}
-
pub(crate) fn edit_prediction_accepted(
store: &EditPredictionStore,
current_prediction: CurrentEditPrediction,
@@ -475,6 +441,11 @@ pub(crate) fn edit_prediction_accepted(
let require_auth = custom_accept_url.is_none();
let client = store.client.clone();
let llm_token = store.llm_token.clone();
+ let organization_id = store
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|organization| organization.id.clone());
let app_version = AppVersion::global(cx);
cx.background_spawn(async move {
@@ -499,6 +470,7 @@ pub(crate) fn edit_prediction_accepted(
},
client,
llm_token,
+ organization_id,
app_version,
require_auth,
)
@@ -12,7 +12,8 @@ use similar::DiffableStr;
use std::ops::Range;
use std::sync::Arc;
use zeta_prompt::{
- ZetaFormat, excerpt_range_for_format, format_zeta_prompt, resolve_cursor_region,
+ ZetaFormat, encode_patch_as_output_for_format, excerpt_range_for_format, format_zeta_prompt,
+ output_end_marker_for_format, resolve_cursor_region,
};
pub async fn run_format_prompt(
@@ -53,18 +54,22 @@ pub async fn run_format_prompt(
let prompt = format_zeta_prompt(prompt_inputs, zeta_format);
let prefill = zeta_prompt::get_prefill(prompt_inputs, zeta_format);
- let (expected_patch, expected_cursor_offset) = example
+ let expected_output = example
.spec
.expected_patches_with_cursor_positions()
.into_iter()
.next()
- .context("expected patches is empty")?;
- let expected_output = zeta2_output_for_patch(
- prompt_inputs,
- &expected_patch,
- expected_cursor_offset,
- zeta_format,
- )?;
+ .and_then(|(expected_patch, expected_cursor_offset)| {
+ zeta2_output_for_patch(
+ prompt_inputs,
+ &expected_patch,
+ expected_cursor_offset,
+ zeta_format,
+ )
+ .ok()
+ })
+ .unwrap_or_default();
+
let rejected_output = example.spec.rejected_patch.as_ref().and_then(|patch| {
zeta2_output_for_patch(prompt_inputs, patch, None, zeta_format).ok()
});
@@ -97,6 +102,12 @@ pub fn zeta2_output_for_patch(
old_editable_region.push('\n');
}
+ if let Some(encoded_output) =
+ encode_patch_as_output_for_format(version, &old_editable_region, patch, cursor_offset)?
+ {
+ return Ok(encoded_output);
+ }
+
let (mut result, first_hunk_offset) =
udiff::apply_diff_to_string_with_hunk_offset(patch, &old_editable_region).with_context(
|| {
@@ -116,16 +127,11 @@ pub fn zeta2_output_for_patch(
result.insert_str(offset, zeta_prompt::CURSOR_MARKER);
}
- match version {
- ZetaFormat::V0120GitMergeMarkers
- | ZetaFormat::V0131GitMergeMarkersPrefix
- | ZetaFormat::V0211SeedCoder => {
- if !result.ends_with('\n') {
- result.push('\n');
- }
- result.push_str(zeta_prompt::v0120_git_merge_markers::END_MARKER);
+ if let Some(end_marker) = output_end_marker_for_format(version) {
+ if !result.ends_with('\n') {
+ result.push('\n');
}
- _ => (),
+ result.push_str(end_marker);
}
Ok(result)
@@ -358,6 +358,7 @@ enum PredictionProvider {
Mercury,
Zeta1,
Zeta2(ZetaFormat),
+ Baseten(ZetaFormat),
Teacher(TeacherBackend),
TeacherNonBatching(TeacherBackend),
Repair,
@@ -376,6 +377,7 @@ impl std::fmt::Display for PredictionProvider {
PredictionProvider::Mercury => write!(f, "mercury"),
PredictionProvider::Zeta1 => write!(f, "zeta1"),
PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
+ PredictionProvider::Baseten(format) => write!(f, "baseten:{format}"),
PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
PredictionProvider::TeacherNonBatching(backend) => {
write!(f, "teacher-non-batching:{backend}")
@@ -415,6 +417,13 @@ impl std::str::FromStr for PredictionProvider {
Ok(PredictionProvider::TeacherNonBatching(backend))
}
"repair" => Ok(PredictionProvider::Repair),
+ "baseten" => {
+ let format = arg
+ .map(ZetaFormat::parse)
+ .transpose()?
+ .unwrap_or(ZetaFormat::default());
+ Ok(PredictionProvider::Baseten(format))
+ }
_ => {
anyhow::bail!(
"unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
@@ -6,7 +6,11 @@ use crate::{
};
use anyhow::{Context as _, Result};
use edit_prediction::example_spec::encode_cursor_in_patch;
-use zeta_prompt::{CURSOR_MARKER, ZetaFormat};
+use zeta_prompt::{
+ CURSOR_MARKER, ZetaFormat, clean_extracted_region_for_format,
+ current_region_markers_for_format, output_end_marker_for_format,
+ output_with_context_for_format,
+};
pub fn run_parse_output(example: &mut Example) -> Result<()> {
example
@@ -51,22 +55,7 @@ pub fn parse_prediction_output(
}
fn extract_zeta2_current_region(prompt: &str, format: ZetaFormat) -> Result<String> {
- let (current_marker, end_marker) = match format {
- ZetaFormat::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"),
- ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => {
- ("<|fim_middle|>current\n", "<|fim_suffix|>")
- }
- ZetaFormat::V0120GitMergeMarkers
- | ZetaFormat::V0131GitMergeMarkersPrefix
- | ZetaFormat::V0211Prefill => (
- zeta_prompt::v0120_git_merge_markers::START_MARKER,
- zeta_prompt::v0120_git_merge_markers::SEPARATOR,
- ),
- ZetaFormat::V0211SeedCoder => (
- zeta_prompt::seed_coder::START_MARKER,
- zeta_prompt::seed_coder::SEPARATOR,
- ),
- };
+ let (current_marker, end_marker) = current_region_markers_for_format(format);
let start = prompt.find(current_marker).with_context(|| {
format!(
@@ -82,8 +71,7 @@ fn extract_zeta2_current_region(prompt: &str, format: ZetaFormat) -> Result<Stri
let region = &prompt[start..end];
let region = region.replace(CURSOR_MARKER, "");
-
- Ok(region)
+ Ok(clean_extracted_region_for_format(format, ®ion))
}
fn parse_zeta2_output(
@@ -100,6 +88,9 @@ fn parse_zeta2_output(
let old_text = extract_zeta2_current_region(prompt, format)?;
let mut new_text = actual_output.to_string();
+ if let Some(transformed) = output_with_context_for_format(format, &old_text, &new_text)? {
+ new_text = transformed;
+ }
let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
new_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
Some(offset)
@@ -107,19 +98,9 @@ fn parse_zeta2_output(
None
};
- let suffix = match format {
- ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill => {
- zeta_prompt::v0131_git_merge_markers_prefix::END_MARKER
- }
- ZetaFormat::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER,
- ZetaFormat::V0112MiddleAtEnd
- | ZetaFormat::V0113Ordered
- | ZetaFormat::V0114180EditableRegion => "",
- ZetaFormat::V0211SeedCoder => zeta_prompt::seed_coder::END_MARKER,
- };
- if !suffix.is_empty() {
+ if let Some(marker) = output_end_marker_for_format(format) {
new_text = new_text
- .strip_suffix(suffix)
+ .strip_suffix(marker)
.unwrap_or(&new_text)
.to_string();
}
@@ -6,14 +6,18 @@ use crate::{
headless::EpAppState,
load_project::run_load_project,
openai_client::OpenAiClient,
+ parse_output::parse_prediction_output,
paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
- progress::{ExampleProgress, InfoStyle, Step},
+ progress::{ExampleProgress, InfoStyle, Step, StepProgress},
retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
+use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
-use futures::{FutureExt as _, StreamExt as _, future::Shared};
+use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
+use http_client::{AsyncBody, HttpClient, Method};
+use reqwest_client::ReqwestClient;
use std::{
fs,
sync::{
@@ -79,6 +83,22 @@ pub async fn run_prediction(
.await;
}
+ if let PredictionProvider::Baseten(format) = provider {
+ run_format_prompt(
+ example,
+ &FormatPromptArgs {
+ provider: PredictionProvider::Zeta2(format),
+ },
+ app_state.clone(),
+ example_progress,
+ cx,
+ )
+ .await?;
+
+ let step_progress = example_progress.start(Step::Predict);
+ return predict_baseten(example, format, &step_progress).await;
+ }
+
run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
@@ -116,7 +136,8 @@ pub async fn run_prediction(
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
PredictionProvider::Teacher(..)
| PredictionProvider::TeacherNonBatching(..)
- | PredictionProvider::Repair => {
+ | PredictionProvider::Repair
+ | PredictionProvider::Baseten(_) => {
unreachable!()
}
};
@@ -127,7 +148,12 @@ pub async fn run_prediction(
if let PredictionProvider::Zeta2(format) = provider {
if format != ZetaFormat::default() {
let model_id = std::env::var("ZED_ZETA_MODEL").ok();
- store.set_zeta2_raw_config(Zeta2RawConfig { model_id, format });
+ let environment = std::env::var("ZED_ZETA_ENVIRONMENT").ok();
+ store.set_zeta2_raw_config(Zeta2RawConfig {
+ model_id,
+ environment,
+ format,
+ });
}
}
});
@@ -480,6 +506,89 @@ async fn predict_openai(
Ok(())
}
+pub async fn predict_baseten(
+ example: &mut Example,
+ format: ZetaFormat,
+ step_progress: &StepProgress,
+) -> anyhow::Result<()> {
+ let model_id =
+ std::env::var("ZED_ZETA_MODEL").context("ZED_ZETA_MODEL environment variable required")?;
+
+ let api_key =
+ std::env::var("BASETEN_API_KEY").context("BASETEN_API_KEY environment variable not set")?;
+
+ let prompt = example.prompt.as_ref().context("Prompt is required")?;
+ let prompt_text = prompt.input.clone();
+ let prefill = prompt.prefill.clone().unwrap_or_default();
+
+ step_progress.set_substatus("running prediction via baseten");
+
+ let environment: String = <&'static str>::from(&format).to_lowercase();
+ let url = format!(
+ "https://model-{model_id}.api.baseten.co/environments/{environment}/sync/v1/completions"
+ );
+
+ let request_body = RawCompletionRequest {
+ model: model_id,
+ prompt: prompt_text.clone(),
+ max_tokens: Some(2048),
+ temperature: Some(0.),
+ stop: vec![],
+ environment: None,
+ };
+
+ let body_bytes =
+ serde_json::to_vec(&request_body).context("Failed to serialize request body")?;
+
+ let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
+ let request = http_client::Request::builder()
+ .method(Method::POST)
+ .uri(&url)
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Api-Key {api_key}"))
+ .body(AsyncBody::from(body_bytes))?;
+
+ let mut response = http_client.send(request).await?;
+ let status = response.status();
+
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .context("Failed to read Baseten response body")?;
+
+ if !status.is_success() {
+ anyhow::bail!("Baseten API returned {status}: {body}");
+ }
+
+ let completion: RawCompletionResponse =
+ serde_json::from_str(&body).context("Failed to parse Baseten response")?;
+
+ let actual_output = completion
+ .choices
+ .into_iter()
+ .next()
+ .map(|choice| choice.text)
+ .unwrap_or_default();
+
+ let actual_output = format!("{prefill}{actual_output}");
+
+ let (actual_patch, actual_cursor) =
+ parse_prediction_output(example, &actual_output, PredictionProvider::Zeta2(format))?;
+
+ let prediction = ExamplePrediction {
+ actual_patch: Some(actual_patch),
+ actual_output,
+ actual_cursor,
+ error: None,
+ provider: PredictionProvider::Baseten(format),
+ };
+
+ example.predictions.push(prediction);
+ Ok(())
+}
+
pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
match provider {
Some(PredictionProvider::Teacher(backend)) => match backend {
@@ -34,7 +34,7 @@ pub struct MinCaptureVersion {
pub patch: u32,
}
-const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
+const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 240;
const SETTLED_STATEMENT_TIMEOUT_SECONDS: u64 = 240;
pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
@@ -715,7 +715,7 @@ pub async fn fetch_rated_examples_after(
AND rated.event_properties:inputs IS NOT NULL
AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL
AND rated.event_properties:output IS NOT NULL
- AND rated.event_properties:can_collect_data = true
+ AND rated.event_properties:inputs:can_collect_data = true
ORDER BY rated.time ASC
LIMIT ?
OFFSET ?
@@ -823,11 +823,11 @@ fn rated_examples_from_response<'a>(
let environment = get_string("environment");
let zed_version = get_string("zed_version");
- match (inputs, output.clone(), rating.clone(), device_id.clone(), time.clone()) {
- (Some(inputs), Some(output), Some(rating), Some(device_id), Some(time)) => {
+ match (inputs, output.clone(), rating.clone(), time.clone()) {
+ (Some(inputs), Some(output), Some(rating), Some(time)) => {
Some(build_rated_example(
request_id,
- device_id,
+ device_id.unwrap_or_default(),
time,
inputs,
output,
@@ -840,11 +840,10 @@ fn rated_examples_from_response<'a>(
}
_ => {
log::warn!(
- "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} device_id={:?} time={:?}",
+ "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} time={:?}",
inputs_json.is_some(),
output.is_some(),
rating.is_some(),
- device_id.is_some(),
time.is_some(),
);
None
@@ -3,7 +3,7 @@ use client::{Client, UserStore, zed_urls};
use cloud_llm_client::UsageLimit;
use codestral::{self, CodestralEditPredictionDelegate};
use copilot::Status;
-use edit_prediction::{EditPredictionStore, Zeta2FeatureFlag};
+use edit_prediction::EditPredictionStore;
use edit_prediction_types::EditPredictionDelegateHandle;
use editor::{
Editor, MultiBufferOffset, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll,
@@ -22,9 +22,7 @@ use language::{
};
use project::{DisableAiSettings, Project};
use regex::Regex;
-use settings::{
- EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, Settings, SettingsStore, update_settings_file,
-};
+use settings::{Settings, SettingsStore, update_settings_file};
use std::{
rc::Rc,
sync::{Arc, LazyLock},
@@ -539,9 +537,15 @@ impl EditPredictionButton {
edit_prediction::ollama::ensure_authenticated(cx);
let sweep_api_token_task = edit_prediction::sweep_ai::load_sweep_api_token(cx);
let mercury_api_token_task = edit_prediction::mercury::load_mercury_api_token(cx);
+ let open_ai_compatible_api_token_task =
+ edit_prediction::open_ai_compatible::load_open_ai_compatible_api_token(cx);
cx.spawn(async move |this, cx| {
- _ = futures::join!(sweep_api_token_task, mercury_api_token_task);
+ _ = futures::join!(
+ sweep_api_token_task,
+ mercury_api_token_task,
+ open_ai_compatible_api_token_task
+ );
this.update(cx, |_, cx| {
cx.notify();
})
@@ -770,13 +774,7 @@ impl EditPredictionButton {
menu = menu.separator().header("Privacy");
- if matches!(
- provider,
- EditPredictionProvider::Zed
- | EditPredictionProvider::Experimental(
- EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
- )
- ) {
+ if matches!(provider, EditPredictionProvider::Zed) {
if let Some(provider) = &self.edit_prediction_provider {
let data_collection = provider.data_collection_state(cx);
@@ -1399,12 +1397,6 @@ pub fn get_available_providers(cx: &mut App) -> Vec<EditPredictionProvider> {
providers.push(EditPredictionProvider::Zed);
- if cx.has_flag::<Zeta2FeatureFlag>() {
- providers.push(EditPredictionProvider::Experimental(
- EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
- ));
- }
-
if let Some(app_state) = workspace::AppState::global(cx).upgrade()
&& copilot::GlobalCopilotAuth::try_get_or_init(app_state, cx)
.is_some_and(|copilot| copilot.0.read(cx).is_authenticated())
@@ -1006,11 +1006,6 @@ impl DisplayMap {
&self.block_map.folded_buffers
}
- #[instrument(skip_all)]
- pub(super) fn clear_folded_buffer(&mut self, buffer_id: language::BufferId) {
- self.block_map.folded_buffers.remove(&buffer_id);
- }
-
#[instrument(skip_all)]
pub fn insert_creases(
&mut self,
@@ -1924,6 +1919,9 @@ impl DisplaySnapshot {
color
}
}),
+ underline: chunk_highlight
+ .underline
+ .filter(|_| editor_style.show_underlines),
..chunk_highlight
}
});
@@ -24147,9 +24147,13 @@ impl Editor {
self.display_map.update(cx, |display_map, cx| {
display_map.invalidate_semantic_highlights(*buffer_id);
display_map.clear_lsp_folding_ranges(*buffer_id, cx);
- display_map.clear_folded_buffer(*buffer_id);
});
}
+
+ self.display_map.update(cx, |display_map, cx| {
+ display_map.unfold_buffers(removed_buffer_ids.iter().copied(), cx);
+ });
+
jsx_tag_auto_close::refresh_enabled_in_any_buffer(self, multibuffer, cx);
cx.emit(EditorEvent::ExcerptsRemoved {
ids: ids.clone(),
@@ -58,4 +58,4 @@ gpui = { workspace = true, features = ["test-support"] }
git = { workspace = true, features = ["test-support"] }
[features]
-test-support = ["gpui/test-support", "git/test-support"]
+test-support = ["gpui/test-support", "git/test-support", "util/test-support"]
@@ -20,7 +20,7 @@ use ignore::gitignore::GitignoreBuilder;
use parking_lot::Mutex;
use rope::Rope;
use smol::{channel::Sender, future::FutureExt as _};
-use std::{path::PathBuf, sync::Arc};
+use std::{path::PathBuf, sync::Arc, sync::atomic::AtomicBool};
use text::LineEnding;
use util::{paths::PathStyle, rel_path::RelPath};
@@ -32,6 +32,7 @@ pub struct FakeGitRepository {
pub(crate) dot_git_path: PathBuf,
pub(crate) repository_dir_path: PathBuf,
pub(crate) common_dir_path: PathBuf,
+ pub(crate) is_trusted: Arc<AtomicBool>,
}
#[derive(Debug, Clone)]
@@ -406,7 +407,31 @@ impl GitRepository for FakeGitRepository {
}
fn worktrees(&self) -> BoxFuture<'_, Result<Vec<Worktree>>> {
- self.with_state_async(false, |state| Ok(state.worktrees.clone()))
+ let dot_git_path = self.dot_git_path.clone();
+ self.with_state_async(false, move |state| {
+ let work_dir = dot_git_path
+ .parent()
+ .map(PathBuf::from)
+ .unwrap_or(dot_git_path);
+ let head_sha = state
+ .refs
+ .get("HEAD")
+ .cloned()
+ .unwrap_or_else(|| "0000000".to_string());
+ let branch_ref = state
+ .current_branch_name
+ .as_ref()
+ .map(|name| format!("refs/heads/{name}"))
+ .unwrap_or_else(|| "refs/heads/main".to_string());
+ let main_worktree = Worktree {
+ path: work_dir,
+ ref_name: branch_ref.into(),
+ sha: head_sha.into(),
+ };
+ let mut all = vec![main_worktree];
+ all.extend(state.worktrees.iter().cloned());
+ Ok(all)
+ })
}
fn create_worktree(
@@ -1011,146 +1036,13 @@ impl GitRepository for FakeGitRepository {
fn commit_data_reader(&self) -> Result<CommitDataReader> {
anyhow::bail!("commit_data_reader not supported for FakeGitRepository")
}
-}
-#[cfg(test)]
-mod tests {
- use super::*;
- use crate::{FakeFs, Fs};
- use gpui::TestAppContext;
- use serde_json::json;
- use std::path::Path;
-
- #[gpui::test]
- async fn test_fake_worktree_lifecycle(cx: &mut TestAppContext) {
- let worktree_dir_settings = &["../worktrees", ".git/zed-worktrees", "my-worktrees/"];
-
- for worktree_dir_setting in worktree_dir_settings {
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree("/project", json!({".git": {}, "file.txt": "content"}))
- .await;
- let repo = fs
- .open_repo(Path::new("/project/.git"), None)
- .expect("should open fake repo");
-
- // Initially no worktrees
- let worktrees = repo.worktrees().await.unwrap();
- assert!(worktrees.is_empty());
-
- let expected_dir = git::repository::resolve_worktree_directory(
- Path::new("/project"),
- worktree_dir_setting,
- );
-
- // Create a worktree
- repo.create_worktree(
- "feature-branch".to_string(),
- expected_dir.clone(),
- Some("abc123".to_string()),
- )
- .await
- .unwrap();
-
- // List worktrees — should have one
- let worktrees = repo.worktrees().await.unwrap();
- assert_eq!(worktrees.len(), 1);
- assert_eq!(
- worktrees[0].path,
- expected_dir.join("feature-branch"),
- "failed for worktree_directory setting: {worktree_dir_setting:?}"
- );
- assert_eq!(worktrees[0].ref_name.as_ref(), "refs/heads/feature-branch");
- assert_eq!(worktrees[0].sha.as_ref(), "abc123");
-
- // Directory should exist in FakeFs after create
- assert!(
- fs.is_dir(&expected_dir.join("feature-branch")).await,
- "worktree directory should be created in FakeFs for setting {worktree_dir_setting:?}"
- );
-
- // Create a second worktree (without explicit commit)
- repo.create_worktree("bugfix-branch".to_string(), expected_dir.clone(), None)
- .await
- .unwrap();
-
- let worktrees = repo.worktrees().await.unwrap();
- assert_eq!(worktrees.len(), 2);
- assert!(
- fs.is_dir(&expected_dir.join("bugfix-branch")).await,
- "second worktree directory should be created in FakeFs for setting {worktree_dir_setting:?}"
- );
-
- // Rename the first worktree
- repo.rename_worktree(
- expected_dir.join("feature-branch"),
- expected_dir.join("renamed-branch"),
- )
- .await
- .unwrap();
+ fn set_trusted(&self, trusted: bool) {
+ self.is_trusted
+ .store(trusted, std::sync::atomic::Ordering::Release);
+ }
- let worktrees = repo.worktrees().await.unwrap();
- assert_eq!(worktrees.len(), 2);
- assert!(
- worktrees
- .iter()
- .any(|w| w.path == expected_dir.join("renamed-branch")),
- "renamed worktree should exist at new path for setting {worktree_dir_setting:?}"
- );
- assert!(
- worktrees
- .iter()
- .all(|w| w.path != expected_dir.join("feature-branch")),
- "old path should no longer exist for setting {worktree_dir_setting:?}"
- );
-
- // Directory should be moved in FakeFs after rename
- assert!(
- !fs.is_dir(&expected_dir.join("feature-branch")).await,
- "old worktree directory should not exist after rename for setting {worktree_dir_setting:?}"
- );
- assert!(
- fs.is_dir(&expected_dir.join("renamed-branch")).await,
- "new worktree directory should exist after rename for setting {worktree_dir_setting:?}"
- );
-
- // Rename a nonexistent worktree should fail
- let result = repo
- .rename_worktree(PathBuf::from("/nonexistent"), PathBuf::from("/somewhere"))
- .await;
- assert!(result.is_err());
-
- // Remove a worktree
- repo.remove_worktree(expected_dir.join("renamed-branch"), false)
- .await
- .unwrap();
-
- let worktrees = repo.worktrees().await.unwrap();
- assert_eq!(worktrees.len(), 1);
- assert_eq!(worktrees[0].path, expected_dir.join("bugfix-branch"));
-
- // Directory should be removed from FakeFs after remove
- assert!(
- !fs.is_dir(&expected_dir.join("renamed-branch")).await,
- "worktree directory should be removed from FakeFs for setting {worktree_dir_setting:?}"
- );
-
- // Remove a nonexistent worktree should fail
- let result = repo
- .remove_worktree(PathBuf::from("/nonexistent"), false)
- .await;
- assert!(result.is_err());
-
- // Remove the last worktree
- repo.remove_worktree(expected_dir.join("bugfix-branch"), false)
- .await
- .unwrap();
-
- let worktrees = repo.worktrees().await.unwrap();
- assert!(worktrees.is_empty());
- assert!(
- !fs.is_dir(&expected_dir.join("bugfix-branch")).await,
- "last worktree directory should be removed from FakeFs for setting {worktree_dir_setting:?}"
- );
- }
+ fn is_trusted(&self) -> bool {
+ self.is_trusted.load(std::sync::atomic::Ordering::Acquire)
}
}
@@ -2776,6 +2776,7 @@ impl Fs for FakeFs {
repository_dir_path: repository_dir_path.to_owned(),
common_dir_path: common_dir_path.to_owned(),
checkpoints: Arc::default(),
+ is_trusted: Arc::default(),
}) as _
},
)
@@ -1,9 +1,146 @@
use fs::{FakeFs, Fs};
-use gpui::BackgroundExecutor;
+use gpui::{BackgroundExecutor, TestAppContext};
use serde_json::json;
-use std::path::Path;
+use std::path::{Path, PathBuf};
use util::path;
+#[gpui::test]
+async fn test_fake_worktree_lifecycle(cx: &mut TestAppContext) {
+ let worktree_dir_settings = &["../worktrees", ".git/zed-worktrees", "my-worktrees/"];
+
+ for worktree_dir_setting in worktree_dir_settings {
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree("/project", json!({".git": {}, "file.txt": "content"}))
+ .await;
+ let repo = fs
+ .open_repo(Path::new("/project/.git"), None)
+ .expect("should open fake repo");
+
+ // Initially only the main worktree exists
+ let worktrees = repo.worktrees().await.unwrap();
+ assert_eq!(worktrees.len(), 1);
+ assert_eq!(worktrees[0].path, PathBuf::from("/project"));
+
+ let expected_dir = git::repository::resolve_worktree_directory(
+ Path::new("/project"),
+ worktree_dir_setting,
+ );
+
+ // Create a worktree
+ repo.create_worktree(
+ "feature-branch".to_string(),
+ expected_dir.clone(),
+ Some("abc123".to_string()),
+ )
+ .await
+ .unwrap();
+
+ // List worktrees — should have main + one created
+ let worktrees = repo.worktrees().await.unwrap();
+ assert_eq!(worktrees.len(), 2);
+ assert_eq!(worktrees[0].path, PathBuf::from("/project"));
+ assert_eq!(
+ worktrees[1].path,
+ expected_dir.join("feature-branch"),
+ "failed for worktree_directory setting: {worktree_dir_setting:?}"
+ );
+ assert_eq!(worktrees[1].ref_name.as_ref(), "refs/heads/feature-branch");
+ assert_eq!(worktrees[1].sha.as_ref(), "abc123");
+
+ // Directory should exist in FakeFs after create
+ assert!(
+ fs.is_dir(&expected_dir.join("feature-branch")).await,
+ "worktree directory should be created in FakeFs for setting {worktree_dir_setting:?}"
+ );
+
+ // Create a second worktree (without explicit commit)
+ repo.create_worktree("bugfix-branch".to_string(), expected_dir.clone(), None)
+ .await
+ .unwrap();
+
+ let worktrees = repo.worktrees().await.unwrap();
+ assert_eq!(worktrees.len(), 3);
+ assert!(
+ fs.is_dir(&expected_dir.join("bugfix-branch")).await,
+ "second worktree directory should be created in FakeFs for setting {worktree_dir_setting:?}"
+ );
+
+ // Rename the first worktree
+ repo.rename_worktree(
+ expected_dir.join("feature-branch"),
+ expected_dir.join("renamed-branch"),
+ )
+ .await
+ .unwrap();
+
+ let worktrees = repo.worktrees().await.unwrap();
+ assert_eq!(worktrees.len(), 3);
+ assert!(
+ worktrees
+ .iter()
+ .any(|w| w.path == expected_dir.join("renamed-branch")),
+ "renamed worktree should exist at new path for setting {worktree_dir_setting:?}"
+ );
+ assert!(
+ worktrees
+ .iter()
+ .all(|w| w.path != expected_dir.join("feature-branch")),
+ "old path should no longer exist for setting {worktree_dir_setting:?}"
+ );
+
+ // Directory should be moved in FakeFs after rename
+ assert!(
+ !fs.is_dir(&expected_dir.join("feature-branch")).await,
+ "old worktree directory should not exist after rename for setting {worktree_dir_setting:?}"
+ );
+ assert!(
+ fs.is_dir(&expected_dir.join("renamed-branch")).await,
+ "new worktree directory should exist after rename for setting {worktree_dir_setting:?}"
+ );
+
+ // Rename a nonexistent worktree should fail
+ let result = repo
+ .rename_worktree(PathBuf::from("/nonexistent"), PathBuf::from("/somewhere"))
+ .await;
+ assert!(result.is_err());
+
+ // Remove a worktree
+ repo.remove_worktree(expected_dir.join("renamed-branch"), false)
+ .await
+ .unwrap();
+
+ let worktrees = repo.worktrees().await.unwrap();
+ assert_eq!(worktrees.len(), 2);
+ assert_eq!(worktrees[0].path, PathBuf::from("/project"));
+ assert_eq!(worktrees[1].path, expected_dir.join("bugfix-branch"));
+
+ // Directory should be removed from FakeFs after remove
+ assert!(
+ !fs.is_dir(&expected_dir.join("renamed-branch")).await,
+ "worktree directory should be removed from FakeFs for setting {worktree_dir_setting:?}"
+ );
+
+ // Remove a nonexistent worktree should fail
+ let result = repo
+ .remove_worktree(PathBuf::from("/nonexistent"), false)
+ .await;
+ assert!(result.is_err());
+
+ // Remove the last worktree
+ repo.remove_worktree(expected_dir.join("bugfix-branch"), false)
+ .await
+ .unwrap();
+
+ let worktrees = repo.worktrees().await.unwrap();
+ assert_eq!(worktrees.len(), 1);
+ assert_eq!(worktrees[0].path, PathBuf::from("/project"));
+ assert!(
+ !fs.is_dir(&expected_dir.join("bugfix-branch")).await,
+ "last worktree directory should be removed from FakeFs for setting {worktree_dir_setting:?}"
+ );
+ }
+}
+
#[gpui::test]
async fn test_checkpoints(executor: BackgroundExecutor) {
let fs = FakeFs::new(executor);
@@ -0,0 +1,28 @@
+allow-private-module-inception = true
+avoid-breaking-exported-api = false
+ignore-interior-mutability = [
+ # Suppresses clippy::mutable_key_type, which is a false positive as the Eq
+ # and Hash impls do not use fields with interior mutability.
+ "agent_ui::context::AgentContextKey"
+]
+disallowed-methods = [
+ { path = "std::process::Command::spawn", reason = "Spawning `std::process::Command` can block the current thread for an unknown duration", replacement = "smol::process::Command::spawn" },
+ { path = "std::process::Command::output", reason = "Spawning `std::process::Command` can block the current thread for an unknown duration", replacement = "smol::process::Command::output" },
+ { path = "std::process::Command::status", reason = "Spawning `std::process::Command` can block the current thread for an unknown duration", replacement = "smol::process::Command::status" },
+ { path = "std::process::Command::stdin", reason = "`smol::process::Command::from()` does not preserve stdio configuration", replacement = "smol::process::Command::stdin" },
+ { path = "std::process::Command::stdout", reason = "`smol::process::Command::from()` does not preserve stdio configuration", replacement = "smol::process::Command::stdout" },
+ { path = "std::process::Command::stderr", reason = "`smol::process::Command::from()` does not preserve stdio configuration", replacement = "smol::process::Command::stderr" },
+ { path = "smol::Timer::after", reason = "smol::Timer introduces non-determinism in tests", replacement = "gpui::BackgroundExecutor::timer" },
+ { path = "serde_json::from_reader", reason = "Parsing from a buffer is much slower than first reading the buffer into a Vec/String, see https://github.com/serde-rs/json/issues/160#issuecomment-253446892. Use `serde_json::from_slice` instead." },
+ { path = "serde_json_lenient::from_reader", reason = "Parsing from a buffer is much slower than first reading the buffer into a Vec/String, see https://github.com/serde-rs/json/issues/160#issuecomment-253446892, Use `serde_json_lenient::from_slice` instead." },
+ { path = "cocoa::foundation::NSString::alloc", reason = "NSString must be autoreleased to avoid memory leaks. Use `ns_string()` helper instead." },
+ { path = "smol::process::Command::new", reason = "Git commands must go through `GitBinary::build_command` to ensure security flags like `-c core.fsmonitor=false` are always applied.", replacement = "GitBinary::build_command" },
+ { path = "util::command::new_command", reason = "Git commands must go through `GitBinary::build_command` to ensure security flags like `-c core.fsmonitor=false` are always applied.", replacement = "GitBinary::build_command" },
+ { path = "util::command::Command::new", reason = "Git commands must go through `GitBinary::build_command` to ensure security flags like `-c core.fsmonitor=false` are always applied.", replacement = "GitBinary::build_command" },
+]
+disallowed-types = [
+ # { path = "std::collections::HashMap", replacement = "collections::HashMap" },
+ # { path = "std::collections::HashSet", replacement = "collections::HashSet" },
+ # { path = "indexmap::IndexSet", replacement = "collections::IndexSet" },
+ # { path = "indexmap::IndexMap", replacement = "collections::IndexMap" },
+]
@@ -1,11 +1,11 @@
use crate::Oid;
use crate::commit::get_messages;
-use crate::repository::RepoPath;
+use crate::repository::{GitBinary, RepoPath};
use anyhow::{Context as _, Result};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use serde::{Deserialize, Serialize};
-use std::{ops::Range, path::Path};
+use std::ops::Range;
use text::{LineEnding, Rope};
use time::OffsetDateTime;
use time::UtcOffset;
@@ -21,15 +21,13 @@ pub struct Blame {
}
impl Blame {
- pub async fn for_path(
- git_binary: &Path,
- working_directory: &Path,
+ pub(crate) async fn for_path(
+ git: &GitBinary,
path: &RepoPath,
content: &Rope,
line_ending: LineEnding,
) -> Result<Self> {
- let output =
- run_git_blame(git_binary, working_directory, path, content, line_ending).await?;
+ let output = run_git_blame(git, path, content, line_ending).await?;
let mut entries = parse_git_blame(&output)?;
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
@@ -40,7 +38,7 @@ impl Blame {
}
let shas = unique_shas.into_iter().collect::<Vec<_>>();
- let messages = get_messages(working_directory, &shas)
+ let messages = get_messages(git, &shas)
.await
.context("failed to get commit messages")?;
@@ -52,8 +50,7 @@ const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
async fn run_git_blame(
- git_binary: &Path,
- working_directory: &Path,
+ git: &GitBinary,
path: &RepoPath,
contents: &Rope,
line_ending: LineEnding,
@@ -61,12 +58,7 @@ async fn run_git_blame(
let mut child = {
let span = ztracing::debug_span!("spawning git-blame command", path = path.as_unix_str());
let _enter = span.enter();
- util::command::new_command(git_binary)
- .current_dir(working_directory)
- .arg("blame")
- .arg("--incremental")
- .arg("--contents")
- .arg("-")
+ git.build_command(["blame", "--incremental", "--contents", "-"])
.arg(path.as_unix_str())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
@@ -1,11 +1,11 @@
use crate::{
BuildCommitPermalinkParams, GitHostingProviderRegistry, GitRemote, Oid, parse_git_remote_url,
- status::StatusCode,
+ repository::GitBinary, status::StatusCode,
};
use anyhow::{Context as _, Result};
use collections::HashMap;
use gpui::SharedString;
-use std::{path::Path, sync::Arc};
+use std::sync::Arc;
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
@@ -48,7 +48,7 @@ impl ParsedCommitMessage {
}
}
-pub async fn get_messages(working_directory: &Path, shas: &[Oid]) -> Result<HashMap<Oid, String>> {
+pub(crate) async fn get_messages(git: &GitBinary, shas: &[Oid]) -> Result<HashMap<Oid, String>> {
if shas.is_empty() {
return Ok(HashMap::default());
}
@@ -63,12 +63,12 @@ pub async fn get_messages(working_directory: &Path, shas: &[Oid]) -> Result<Hash
let mut result = vec![];
for shas in shas.chunks(MAX_ENTRIES_PER_INVOCATION) {
- let partial = get_messages_impl(working_directory, shas).await?;
+ let partial = get_messages_impl(git, shas).await?;
result.extend(partial);
}
result
} else {
- get_messages_impl(working_directory, shas).await?
+ get_messages_impl(git, shas).await?
};
Ok(shas
@@ -78,11 +78,10 @@ pub async fn get_messages(working_directory: &Path, shas: &[Oid]) -> Result<Hash
.collect::<HashMap<Oid, String>>())
}
-async fn get_messages_impl(working_directory: &Path, shas: &[Oid]) -> Result<Vec<String>> {
+async fn get_messages_impl(git: &GitBinary, shas: &[Oid]) -> Result<Vec<String>> {
const MARKER: &str = "<MARKER>";
- let output = util::command::new_command("git")
- .current_dir(working_directory)
- .arg("show")
+ let output = git
+ .build_command(["show"])
.arg("-s")
.arg(format!("--format=%B{}", MARKER))
.args(shas.iter().map(ToString::to_string))
@@ -21,6 +21,7 @@ use text::LineEnding;
use std::collections::HashSet;
use std::ffi::{OsStr, OsString};
+use std::sync::atomic::AtomicBool;
use std::process::ExitStatus;
use std::str::FromStr;
@@ -303,6 +304,7 @@ impl Branch {
pub struct Worktree {
pub path: PathBuf,
pub ref_name: SharedString,
+ // todo(git_worktree) This type should be a Oid
pub sha: SharedString,
}
@@ -340,6 +342,8 @@ pub fn parse_worktrees_from_str<T: AsRef<str>>(raw_worktrees: T) -> Vec<Worktree
// Ignore other lines: detached, bare, locked, prunable, etc.
}
+ // todo(git_worktree) We should add a test for detach head state
+ // a detach head will have ref_name as none so we would skip it
if let (Some(path), Some(sha), Some(ref_name)) = (path, sha, ref_name) {
worktrees.push(Worktree {
path: PathBuf::from(path),
@@ -958,6 +962,9 @@ pub trait GitRepository: Send + Sync {
) -> BoxFuture<'_, Result<()>>;
fn commit_data_reader(&self) -> Result<CommitDataReader>;
+
+ fn set_trusted(&self, trusted: bool);
+ fn is_trusted(&self) -> bool;
}
pub enum DiffType {
@@ -984,6 +991,7 @@ pub struct RealGitRepository {
pub any_git_binary_path: PathBuf,
any_git_binary_help_output: Arc<Mutex<Option<SharedString>>>,
executor: BackgroundExecutor,
+ is_trusted: Arc<AtomicBool>,
}
impl RealGitRepository {
@@ -1002,6 +1010,7 @@ impl RealGitRepository {
any_git_binary_path,
executor,
any_git_binary_help_output: Arc::new(Mutex::new(None)),
+ is_trusted: Arc::new(AtomicBool::new(false)),
})
}
@@ -1013,20 +1022,24 @@ impl RealGitRepository {
.map(Path::to_path_buf)
}
+ fn git_binary(&self) -> Result<GitBinary> {
+ Ok(GitBinary::new(
+ self.any_git_binary_path.clone(),
+ self.working_directory()
+ .with_context(|| "Can't run git commands without a working directory")?,
+ self.executor.clone(),
+ self.is_trusted(),
+ ))
+ }
+
async fn any_git_binary_help_output(&self) -> SharedString {
if let Some(output) = self.any_git_binary_help_output.lock().clone() {
return output;
}
- let git_binary_path = self.any_git_binary_path.clone();
- let executor = self.executor.clone();
- let working_directory = self.working_directory();
+ let git_binary = self.git_binary();
let output: SharedString = self
.executor
- .spawn(async move {
- GitBinary::new(git_binary_path, working_directory?, executor)
- .run(["help", "-a"])
- .await
- })
+ .spawn(async move { git_binary?.run(["help", "-a"]).await })
.await
.unwrap_or_default()
.into();
@@ -1069,6 +1082,7 @@ pub async fn get_git_committer(cx: &AsyncApp) -> GitCommitter {
git_binary_path.unwrap_or(PathBuf::from("git")),
paths::home_dir().clone(),
cx.background_executor().clone(),
+ true,
);
cx.background_spawn(async move {
@@ -1100,14 +1114,12 @@ impl GitRepository for RealGitRepository {
}
fn show(&self, commit: String) -> BoxFuture<'_, Result<CommitDetails>> {
- let git_binary_path = self.any_git_binary_path.clone();
- let working_directory = self.working_directory();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- let working_directory = working_directory?;
- let output = new_command(git_binary_path)
- .current_dir(&working_directory)
- .args([
+ let git = git_binary?;
+ let output = git
+ .build_command([
"--no-optional-locks",
"show",
"--no-patch",
@@ -1138,15 +1150,14 @@ impl GitRepository for RealGitRepository {
}
fn load_commit(&self, commit: String, cx: AsyncApp) -> BoxFuture<'_, Result<CommitDiff>> {
- let Some(working_directory) = self.repository.lock().workdir().map(ToOwned::to_owned)
- else {
+ if self.repository.lock().workdir().is_none() {
return future::ready(Err(anyhow!("no working directory"))).boxed();
- };
- let git_binary_path = self.any_git_binary_path.clone();
+ }
+ let git_binary = self.git_binary();
cx.background_spawn(async move {
- let show_output = util::command::new_command(&git_binary_path)
- .current_dir(&working_directory)
- .args([
+ let git = git_binary?;
+ let show_output = git
+ .build_command([
"--no-optional-locks",
"show",
"--format=",
@@ -1167,9 +1178,8 @@ impl GitRepository for RealGitRepository {
let changes = parse_git_diff_name_status(&show_stdout);
let parent_sha = format!("{}^", commit);
- let mut cat_file_process = util::command::new_command(&git_binary_path)
- .current_dir(&working_directory)
- .args(["--no-optional-locks", "cat-file", "--batch=%(objectsize)"])
+ let mut cat_file_process = git
+ .build_command(["--no-optional-locks", "cat-file", "--batch=%(objectsize)"])
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
@@ -1276,18 +1286,17 @@ impl GitRepository for RealGitRepository {
mode: ResetMode,
env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>> {
+ let git_binary = self.git_binary();
async move {
- let working_directory = self.working_directory();
-
let mode_flag = match mode {
ResetMode::Mixed => "--mixed",
ResetMode::Soft => "--soft",
};
- let output = new_command(&self.any_git_binary_path)
+ let git = git_binary?;
+ let output = git
+ .build_command(["reset", mode_flag, &commit])
.envs(env.iter())
- .current_dir(&working_directory?)
- .args(["reset", mode_flag, &commit])
.output()
.await?;
anyhow::ensure!(
@@ -1306,17 +1315,16 @@ impl GitRepository for RealGitRepository {
paths: Vec<RepoPath>,
env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
+ let git_binary = self.git_binary();
async move {
if paths.is_empty() {
return Ok(());
}
- let output = new_command(&git_binary_path)
- .current_dir(&working_directory?)
+ let git = git_binary?;
+ let output = git
+ .build_command(["checkout", &commit, "--"])
.envs(env.iter())
- .args(["checkout", &commit, "--"])
.args(paths.iter().map(|path| path.as_unix_str()))
.output()
.await?;
@@ -1411,18 +1419,16 @@ impl GitRepository for RealGitRepository {
env: Arc<HashMap<String, String>>,
is_executable: bool,
) -> BoxFuture<'_, anyhow::Result<()>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- let working_directory = working_directory?;
+ let git = git_binary?;
let mode = if is_executable { "100755" } else { "100644" };
if let Some(content) = content {
- let mut child = new_command(&git_binary_path)
- .current_dir(&working_directory)
+ let mut child = git
+ .build_command(["hash-object", "-w", "--stdin"])
.envs(env.iter())
- .args(["hash-object", "-w", "--stdin"])
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.spawn()?;
@@ -1435,10 +1441,9 @@ impl GitRepository for RealGitRepository {
log::debug!("indexing SHA: {sha}, path {path:?}");
- let output = new_command(&git_binary_path)
- .current_dir(&working_directory)
+ let output = git
+ .build_command(["update-index", "--add", "--cacheinfo", mode, sha])
.envs(env.iter())
- .args(["update-index", "--add", "--cacheinfo", mode, sha])
.arg(path.as_unix_str())
.output()
.await?;
@@ -1450,10 +1455,9 @@ impl GitRepository for RealGitRepository {
);
} else {
log::debug!("removing path {path:?} from the index");
- let output = new_command(&git_binary_path)
- .current_dir(&working_directory)
+ let output = git
+ .build_command(["update-index", "--force-remove"])
.envs(env.iter())
- .args(["update-index", "--force-remove"])
.arg(path.as_unix_str())
.output()
.await?;
@@ -1482,14 +1486,12 @@ impl GitRepository for RealGitRepository {
}
fn revparse_batch(&self, revs: Vec<String>) -> BoxFuture<'_, Result<Vec<Option<String>>>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- let working_directory = working_directory?;
- let mut process = new_command(&git_binary_path)
- .current_dir(&working_directory)
- .args([
+ let git = git_binary?;
+ let mut process = git
+ .build_command([
"--no-optional-locks",
"cat-file",
"--batch-check=%(objectname)",
@@ -1542,19 +1544,14 @@ impl GitRepository for RealGitRepository {
}
fn status(&self, path_prefixes: &[RepoPath]) -> Task<Result<GitStatus>> {
- let git_binary_path = self.any_git_binary_path.clone();
- let working_directory = match self.working_directory() {
- Ok(working_directory) => working_directory,
+ let git = match self.git_binary() {
+ Ok(git) => git,
Err(e) => return Task::ready(Err(e)),
};
let args = git_status_args(path_prefixes);
log::debug!("Checking for git status in {path_prefixes:?}");
self.executor.spawn(async move {
- let output = new_command(&git_binary_path)
- .current_dir(working_directory)
- .args(args)
- .output()
- .await?;
+ let output = git.build_command(args).output().await?;
if output.status.success() {
let stdout = String::from_utf8_lossy(&output.stdout);
stdout.parse()
@@ -1566,9 +1563,8 @@ impl GitRepository for RealGitRepository {
}
fn diff_tree(&self, request: DiffTreeType) -> BoxFuture<'_, Result<TreeDiff>> {
- let git_binary_path = self.any_git_binary_path.clone();
- let working_directory = match self.working_directory() {
- Ok(working_directory) => working_directory,
+ let git = match self.git_binary() {
+ Ok(git) => git,
Err(e) => return Task::ready(Err(e)).boxed(),
};
@@ -1593,11 +1589,7 @@ impl GitRepository for RealGitRepository {
self.executor
.spawn(async move {
- let output = new_command(&git_binary_path)
- .current_dir(working_directory)
- .args(args)
- .output()
- .await?;
+ let output = git.build_command(args).output().await?;
if output.status.success() {
let stdout = String::from_utf8_lossy(&output.stdout);
stdout.parse()
@@ -1610,13 +1602,12 @@ impl GitRepository for RealGitRepository {
}
fn stash_entries(&self) -> BoxFuture<'_, Result<GitStash>> {
- let git_binary_path = self.any_git_binary_path.clone();
- let working_directory = self.working_directory();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- let output = new_command(&git_binary_path)
- .current_dir(working_directory?)
- .args(&["stash", "list", "--pretty=format:%gd%x00%H%x00%ct%x00%s"])
+ let git = git_binary?;
+ let output = git
+ .build_command(&["stash", "list", "--pretty=format:%gd%x00%H%x00%ct%x00%s"])
.output()
.await?;
if output.status.success() {
@@ -1631,8 +1622,7 @@ impl GitRepository for RealGitRepository {
}
fn branches(&self) -> BoxFuture<'_, Result<Vec<Branch>>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
let fields = [
@@ -1654,12 +1644,8 @@ impl GitRepository for RealGitRepository {
"--format",
&fields,
];
- let working_directory = working_directory?;
- let output = new_command(&git_binary_path)
- .current_dir(&working_directory)
- .args(args)
- .output()
- .await?;
+ let git = git_binary?;
+ let output = git.build_command(args).output().await?;
anyhow::ensure!(
output.status.success(),
@@ -1673,11 +1659,7 @@ impl GitRepository for RealGitRepository {
if branches.is_empty() {
let args = vec!["symbolic-ref", "--quiet", "HEAD"];
- let output = new_command(&git_binary_path)
- .current_dir(&working_directory)
- .args(args)
- .output()
- .await?;
+ let output = git.build_command(args).output().await?;
// git symbolic-ref returns a non-0 exit code if HEAD points
// to something other than a branch
@@ -1699,13 +1681,12 @@ impl GitRepository for RealGitRepository {
}
fn worktrees(&self) -> BoxFuture<'_, Result<Vec<Worktree>>> {
- let git_binary_path = self.any_git_binary_path.clone();
- let working_directory = self.working_directory();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- let output = new_command(&git_binary_path)
- .current_dir(working_directory?)
- .args(&["--no-optional-locks", "worktree", "list", "--porcelain"])
+ let git = git_binary?;
+ let output = git
+ .build_command(&["--no-optional-locks", "worktree", "list", "--porcelain"])
.output()
.await?;
if output.status.success() {
@@ -1725,8 +1706,7 @@ impl GitRepository for RealGitRepository {
directory: PathBuf,
from_commit: Option<String>,
) -> BoxFuture<'_, Result<()>> {
- let git_binary_path = self.any_git_binary_path.clone();
- let working_directory = self.working_directory();
+ let git_binary = self.git_binary();
let final_path = directory.join(&name);
let mut args = vec![
OsString::from("--no-optional-locks"),
@@ -1746,11 +1726,8 @@ impl GitRepository for RealGitRepository {
self.executor
.spawn(async move {
std::fs::create_dir_all(final_path.parent().unwrap_or(&final_path))?;
- let output = new_command(&git_binary_path)
- .current_dir(working_directory?)
- .args(args)
- .output()
- .await?;
+ let git = git_binary?;
+ let output = git.build_command(args).output().await?;
if output.status.success() {
Ok(())
} else {
@@ -1762,9 +1739,7 @@ impl GitRepository for RealGitRepository {
}
fn remove_worktree(&self, path: PathBuf, force: bool) -> BoxFuture<'_, Result<()>> {
- let git_binary_path = self.any_git_binary_path.clone();
- let working_directory = self.working_directory();
- let executor = self.executor.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
@@ -1778,18 +1753,14 @@ impl GitRepository for RealGitRepository {
}
args.push("--".into());
args.push(path.as_os_str().into());
- GitBinary::new(git_binary_path, working_directory?, executor)
- .run(args)
- .await?;
+ git_binary?.run(args).await?;
anyhow::Ok(())
})
.boxed()
}
fn rename_worktree(&self, old_path: PathBuf, new_path: PathBuf) -> BoxFuture<'_, Result<()>> {
- let git_binary_path = self.any_git_binary_path.clone();
- let working_directory = self.working_directory();
- let executor = self.executor.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
@@ -1801,9 +1772,7 @@ impl GitRepository for RealGitRepository {
old_path.as_os_str().into(),
new_path.as_os_str().into(),
];
- GitBinary::new(git_binary_path, working_directory?, executor)
- .run(args)
- .await?;
+ git_binary?.run(args).await?;
anyhow::Ok(())
})
.boxed()
@@ -1811,9 +1780,7 @@ impl GitRepository for RealGitRepository {
fn change_branch(&self, name: String) -> BoxFuture<'_, Result<()>> {
let repo = self.repository.clone();
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
- let executor = self.executor.clone();
+ let git_binary = self.git_binary();
let branch = self.executor.spawn(async move {
let repo = repo.lock();
let branch = if let Ok(branch) = repo.find_branch(&name, BranchType::Local) {
@@ -1848,9 +1815,7 @@ impl GitRepository for RealGitRepository {
self.executor
.spawn(async move {
let branch = branch.await?;
- GitBinary::new(git_binary_path, working_directory?, executor)
- .run(&["checkout", &branch])
- .await?;
+ git_binary?.run(&["checkout", &branch]).await?;
anyhow::Ok(())
})
.boxed()
@@ -1861,9 +1826,7 @@ impl GitRepository for RealGitRepository {
name: String,
base_branch: Option<String>,
) -> BoxFuture<'_, Result<()>> {
- let git_binary_path = self.any_git_binary_path.clone();
- let working_directory = self.working_directory();
- let executor = self.executor.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
@@ -1874,22 +1837,18 @@ impl GitRepository for RealGitRepository {
args.push(&base_branch_str);
}
- GitBinary::new(git_binary_path, working_directory?, executor)
- .run(&args)
- .await?;
+ git_binary?.run(&args).await?;
anyhow::Ok(())
})
.boxed()
}
fn rename_branch(&self, branch: String, new_name: String) -> BoxFuture<'_, Result<()>> {
- let git_binary_path = self.any_git_binary_path.clone();
- let working_directory = self.working_directory();
- let executor = self.executor.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- GitBinary::new(git_binary_path, working_directory?, executor)
+ git_binary?
.run(&["branch", "-m", &branch, &new_name])
.await?;
anyhow::Ok(())
@@ -1898,15 +1857,11 @@ impl GitRepository for RealGitRepository {
}
fn delete_branch(&self, name: String) -> BoxFuture<'_, Result<()>> {
- let git_binary_path = self.any_git_binary_path.clone();
- let working_directory = self.working_directory();
- let executor = self.executor.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- GitBinary::new(git_binary_path, working_directory?, executor)
- .run(&["branch", "-d", &name])
- .await?;
+ git_binary?.run(&["branch", "-d", &name]).await?;
anyhow::Ok(())
})
.boxed()
@@ -1918,20 +1873,11 @@ impl GitRepository for RealGitRepository {
content: Rope,
line_ending: LineEnding,
) -> BoxFuture<'_, Result<crate::blame::Blame>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
- let executor = self.executor.clone();
+ let git = self.git_binary();
- executor
+ self.executor
.spawn(async move {
- crate::blame::Blame::for_path(
- &git_binary_path,
- &working_directory?,
- &path,
- &content,
- line_ending,
- )
- .await
+ crate::blame::Blame::for_path(&git?, &path, &content, line_ending).await
})
.boxed()
}
@@ -1946,11 +1892,10 @@ impl GitRepository for RealGitRepository {
skip: usize,
limit: Option<usize>,
) -> BoxFuture<'_, Result<FileHistory>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- let working_directory = working_directory?;
+ let git = git_binary?;
// Use a unique delimiter with a hardcoded UUID to separate commits
// This essentially eliminates any chance of encountering the delimiter in actual commit data
let commit_delimiter =
@@ -1978,9 +1923,8 @@ impl GitRepository for RealGitRepository {
args.push("--");
- let output = new_command(&git_binary_path)
- .current_dir(&working_directory)
- .args(&args)
+ let output = git
+ .build_command(&args)
.arg(path.as_unix_str())
.output()
.await?;
@@ -2025,30 +1969,17 @@ impl GitRepository for RealGitRepository {
}
fn diff(&self, diff: DiffType) -> BoxFuture<'_, Result<String>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- let working_directory = working_directory?;
+ let git = git_binary?;
let output = match diff {
DiffType::HeadToIndex => {
- new_command(&git_binary_path)
- .current_dir(&working_directory)
- .args(["diff", "--staged"])
- .output()
- .await?
- }
- DiffType::HeadToWorktree => {
- new_command(&git_binary_path)
- .current_dir(&working_directory)
- .args(["diff"])
- .output()
- .await?
+ git.build_command(["diff", "--staged"]).output().await?
}
+ DiffType::HeadToWorktree => git.build_command(["diff"]).output().await?,
DiffType::MergeBase { base_ref } => {
- new_command(&git_binary_path)
- .current_dir(&working_directory)
- .args(["diff", "--merge-base", base_ref.as_ref()])
+ git.build_command(["diff", "--merge-base", base_ref.as_ref()])
.output()
.await?
}
@@ -2068,38 +1999,29 @@ impl GitRepository for RealGitRepository {
&self,
diff: DiffType,
) -> BoxFuture<'_, Result<HashMap<RepoPath, crate::status::DiffStat>>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- let working_directory = working_directory?;
+ let git = git_binary?;
let output = match diff {
DiffType::HeadToIndex => {
- new_command(&git_binary_path)
- .current_dir(&working_directory)
- .args(["diff", "--numstat", "--staged"])
+ git.build_command(["diff", "--numstat", "--staged"])
.output()
.await?
}
DiffType::HeadToWorktree => {
- new_command(&git_binary_path)
- .current_dir(&working_directory)
- .args(["diff", "--numstat"])
- .output()
- .await?
+ git.build_command(["diff", "--numstat"]).output().await?
}
DiffType::MergeBase { base_ref } => {
- new_command(&git_binary_path)
- .current_dir(&working_directory)
- .args([
- "diff",
- "--numstat",
- "--merge-base",
- base_ref.as_ref(),
- "HEAD",
- ])
- .output()
- .await?
+ git.build_command([
+ "diff",
+ "--numstat",
+ "--merge-base",
+ base_ref.as_ref(),
+ "HEAD",
+ ])
+ .output()
+ .await?
}
};
@@ -2120,15 +2042,14 @@ impl GitRepository for RealGitRepository {
paths: Vec<RepoPath>,
env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
if !paths.is_empty() {
- let output = new_command(&git_binary_path)
- .current_dir(&working_directory?)
+ let git = git_binary?;
+ let output = git
+ .build_command(["update-index", "--add", "--remove", "--"])
.envs(env.iter())
- .args(["update-index", "--add", "--remove", "--"])
.args(paths.iter().map(|p| p.as_unix_str()))
.output()
.await?;
@@ -2148,16 +2069,15 @@ impl GitRepository for RealGitRepository {
paths: Vec<RepoPath>,
env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
if !paths.is_empty() {
- let output = new_command(&git_binary_path)
- .current_dir(&working_directory?)
+ let git = git_binary?;
+ let output = git
+ .build_command(["reset", "--quiet", "--"])
.envs(env.iter())
- .args(["reset", "--quiet", "--"])
.args(paths.iter().map(|p| p.as_std_path()))
.output()
.await?;
@@ -2178,19 +2098,16 @@ impl GitRepository for RealGitRepository {
paths: Vec<RepoPath>,
env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- let mut cmd = new_command(&git_binary_path);
- cmd.current_dir(&working_directory?)
+ let git = git_binary?;
+ let output = git
+ .build_command(["stash", "push", "--quiet", "--include-untracked"])
.envs(env.iter())
- .args(["stash", "push", "--quiet"])
- .arg("--include-untracked");
-
- cmd.args(paths.iter().map(|p| p.as_unix_str()));
-
- let output = cmd.output().await?;
+ .args(paths.iter().map(|p| p.as_unix_str()))
+ .output()
+ .await?;
anyhow::ensure!(
output.status.success(),
@@ -2207,20 +2124,15 @@ impl GitRepository for RealGitRepository {
index: Option<usize>,
env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- let mut cmd = new_command(git_binary_path);
+ let git = git_binary?;
let mut args = vec!["stash".to_string(), "pop".to_string()];
if let Some(index) = index {
args.push(format!("stash@{{{}}}", index));
}
- cmd.current_dir(&working_directory?)
- .envs(env.iter())
- .args(args);
-
- let output = cmd.output().await?;
+ let output = git.build_command(&args).envs(env.iter()).output().await?;
anyhow::ensure!(
output.status.success(),
@@ -2237,20 +2149,15 @@ impl GitRepository for RealGitRepository {
index: Option<usize>,
env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- let mut cmd = new_command(git_binary_path);
+ let git = git_binary?;
let mut args = vec!["stash".to_string(), "apply".to_string()];
if let Some(index) = index {
args.push(format!("stash@{{{}}}", index));
}
- cmd.current_dir(&working_directory?)
- .envs(env.iter())
- .args(args);
-
- let output = cmd.output().await?;
+ let output = git.build_command(&args).envs(env.iter()).output().await?;
anyhow::ensure!(
output.status.success(),
@@ -2267,20 +2174,15 @@ impl GitRepository for RealGitRepository {
index: Option<usize>,
env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- let mut cmd = new_command(git_binary_path);
+ let git = git_binary?;
let mut args = vec!["stash".to_string(), "drop".to_string()];
if let Some(index) = index {
args.push(format!("stash@{{{}}}", index));
}
- cmd.current_dir(&working_directory?)
- .envs(env.iter())
- .args(args);
-
- let output = cmd.output().await?;
+ let output = git.build_command(&args).envs(env.iter()).output().await?;
anyhow::ensure!(
output.status.success(),
@@ -2300,16 +2202,14 @@ impl GitRepository for RealGitRepository {
ask_pass: AskPassDelegate,
env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
+ let git_binary = self.git_binary();
let executor = self.executor.clone();
// Note: Do not spawn this command on the background thread, it might pop open the credential helper
// which we want to block on.
async move {
- let mut cmd = new_command(git_binary_path);
- cmd.current_dir(&working_directory?)
- .envs(env.iter())
- .args(["commit", "--quiet", "-m"])
+ let git = git_binary?;
+ let mut cmd = git.build_command(["commit", "--quiet", "-m"]);
+ cmd.envs(env.iter())
.arg(&message.to_string())
.arg("--cleanup=strip")
.arg("--no-verify")
@@ -2348,16 +2248,21 @@ impl GitRepository for RealGitRepository {
let working_directory = self.working_directory();
let executor = cx.background_executor().clone();
let git_binary_path = self.system_git_binary_path.clone();
+ let is_trusted = self.is_trusted();
// Note: Do not spawn this command on the background thread, it might pop open the credential helper
// which we want to block on.
async move {
let git_binary_path = git_binary_path.context("git not found on $PATH, can't push")?;
let working_directory = working_directory?;
- let mut command = new_command(git_binary_path);
+ let git = GitBinary::new(
+ git_binary_path,
+ working_directory,
+ executor.clone(),
+ is_trusted,
+ );
+ let mut command = git.build_command(["push"]);
command
.envs(env.iter())
- .current_dir(&working_directory)
- .args(["push"])
.args(options.map(|option| match option {
PushOptions::SetUpstream => "--set-upstream",
PushOptions::Force => "--force-with-lease",
@@ -2385,15 +2290,20 @@ impl GitRepository for RealGitRepository {
let working_directory = self.working_directory();
let executor = cx.background_executor().clone();
let git_binary_path = self.system_git_binary_path.clone();
+ let is_trusted = self.is_trusted();
// Note: Do not spawn this command on the background thread, it might pop open the credential helper
// which we want to block on.
async move {
let git_binary_path = git_binary_path.context("git not found on $PATH, can't pull")?;
- let mut command = new_command(git_binary_path);
- command
- .envs(env.iter())
- .current_dir(&working_directory?)
- .arg("pull");
+ let working_directory = working_directory?;
+ let git = GitBinary::new(
+ git_binary_path,
+ working_directory,
+ executor.clone(),
+ is_trusted,
+ );
+ let mut command = git.build_command(["pull"]);
+ command.envs(env.iter());
if rebase {
command.arg("--rebase");
@@ -2421,15 +2331,21 @@ impl GitRepository for RealGitRepository {
let remote_name = format!("{}", fetch_options);
let git_binary_path = self.system_git_binary_path.clone();
let executor = cx.background_executor().clone();
+ let is_trusted = self.is_trusted();
// Note: Do not spawn this command on the background thread, it might pop open the credential helper
// which we want to block on.
async move {
let git_binary_path = git_binary_path.context("git not found on $PATH, can't fetch")?;
- let mut command = new_command(git_binary_path);
+ let working_directory = working_directory?;
+ let git = GitBinary::new(
+ git_binary_path,
+ working_directory,
+ executor.clone(),
+ is_trusted,
+ );
+ let mut command = git.build_command(["fetch", &remote_name]);
command
.envs(env.iter())
- .current_dir(&working_directory?)
- .args(["fetch", &remote_name])
.stdout(Stdio::piped())
.stderr(Stdio::piped());
@@ -2439,14 +2355,12 @@ impl GitRepository for RealGitRepository {
}
fn get_push_remote(&self, branch: String) -> BoxFuture<'_, Result<Option<Remote>>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- let working_directory = working_directory?;
- let output = new_command(&git_binary_path)
- .current_dir(&working_directory)
- .args(["rev-parse", "--abbrev-ref"])
+ let git = git_binary?;
+ let output = git
+ .build_command(["rev-parse", "--abbrev-ref"])
.arg(format!("{branch}@{{push}}"))
.output()
.await?;
@@ -2466,14 +2380,12 @@ impl GitRepository for RealGitRepository {
}
fn get_branch_remote(&self, branch: String) -> BoxFuture<'_, Result<Option<Remote>>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- let working_directory = working_directory?;
- let output = new_command(&git_binary_path)
- .current_dir(&working_directory)
- .args(["config", "--get"])
+ let git = git_binary?;
+ let output = git
+ .build_command(["config", "--get"])
.arg(format!("branch.{branch}.remote"))
.output()
.await?;
@@ -2490,16 +2402,11 @@ impl GitRepository for RealGitRepository {
}
fn get_all_remotes(&self) -> BoxFuture<'_, Result<Vec<Remote>>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- let working_directory = working_directory?;
- let output = new_command(&git_binary_path)
- .current_dir(&working_directory)
- .args(["remote", "-v"])
- .output()
- .await?;
+ let git = git_binary?;
+ let output = git.build_command(["remote", "-v"]).output().await?;
anyhow::ensure!(
output.status.success(),
@@ -2548,17 +2455,12 @@ impl GitRepository for RealGitRepository {
}
fn check_for_pushed_commit(&self) -> BoxFuture<'_, Result<Vec<SharedString>>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- let working_directory = working_directory?;
+ let git = git_binary?;
let git_cmd = async |args: &[&str]| -> Result<String> {
- let output = new_command(&git_binary_path)
- .current_dir(&working_directory)
- .args(args)
- .output()
- .await?;
+ let output = git.build_command(args).output().await?;
anyhow::ensure!(
output.status.success(),
String::from_utf8_lossy(&output.stderr).to_string()
@@ -2607,14 +2509,10 @@ impl GitRepository for RealGitRepository {
}
fn checkpoint(&self) -> BoxFuture<'static, Result<GitRepositoryCheckpoint>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
- let executor = self.executor.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- let working_directory = working_directory?;
- let mut git = GitBinary::new(git_binary_path, working_directory.clone(), executor)
- .envs(checkpoint_author_envs());
+ let mut git = git_binary?.envs(checkpoint_author_envs());
git.with_temp_index(async |git| {
let head_sha = git.run(&["rev-parse", "HEAD"]).await.ok();
let mut excludes = exclude_files(git).await?;
@@ -2640,15 +2538,10 @@ impl GitRepository for RealGitRepository {
}
fn restore_checkpoint(&self, checkpoint: GitRepositoryCheckpoint) -> BoxFuture<'_, Result<()>> {
- let working_directory = self.working_directory();
- let git_binary_path = self.any_git_binary_path.clone();
-
- let executor = self.executor.clone();
+ let git_binary = self.git_binary();
self.executor
.spawn(async move {
- let working_directory = working_directory?;
-
- let git = GitBinary::new(git_binary_path, working_directory, executor);
+ let git = git_binary?;
git.run(&[
"restore",
"--source",
@@ -98,7 +98,6 @@ pub struct WgpuRenderer {
queue: Arc<wgpu::Queue>,
surface: wgpu::Surface<'static>,
surface_config: wgpu::SurfaceConfiguration,
- surface_configured: bool,
pipelines: WgpuPipelines,
bind_group_layouts: WgpuBindGroupLayouts,
atlas: Arc<WgpuAtlas>,
@@ -381,7 +380,6 @@ impl WgpuRenderer {
queue,
surface,
surface_config,
- surface_configured: true,
pipelines,
bind_group_layouts,
atlas,
@@ -875,9 +873,7 @@ impl WgpuRenderer {
self.surface_config.width = clamped_width.max(1);
self.surface_config.height = clamped_height.max(1);
- if self.surface_configured {
- self.surface.configure(&self.device, &self.surface_config);
- }
+ self.surface.configure(&self.device, &self.surface_config);
// Invalidate intermediate textures - they will be lazily recreated
// in draw() after we confirm the surface is healthy. This avoids
@@ -928,9 +924,7 @@ impl WgpuRenderer {
if new_alpha_mode != self.surface_config.alpha_mode {
self.surface_config.alpha_mode = new_alpha_mode;
- if self.surface_configured {
- self.surface.configure(&self.device, &self.surface_config);
- }
+ self.surface.configure(&self.device, &self.surface_config);
self.pipelines = Self::create_pipelines(
&self.device,
&self.bind_group_layouts,
@@ -991,7 +985,7 @@ impl WgpuRenderer {
let frame = match self.surface.get_current_texture() {
Ok(frame) => frame,
Err(wgpu::SurfaceError::Lost | wgpu::SurfaceError::Outdated) => {
- self.surface_configured = false;
+ self.surface.configure(&self.device, &self.surface_config);
return;
}
Err(e) => {
@@ -7,6 +7,7 @@ use std::{
use bytes::Bytes;
use futures::AsyncRead;
use http_body::{Body, Frame};
+use serde::Serialize;
/// Based on the implementation of AsyncBody in
/// <https://github.com/sagebind/isahc/blob/5c533f1ef4d6bdf1fd291b5103c22110f41d0bf0/src/body/mod.rs>.
@@ -88,6 +89,19 @@ impl From<&'static str> for AsyncBody {
}
}
+/// Newtype wrapper that serializes a value as JSON into an `AsyncBody`.
+pub struct Json<T: Serialize>(pub T);
+
+impl<T: Serialize> From<Json<T>> for AsyncBody {
+ fn from(json: Json<T>) -> Self {
+ Self::from_bytes(
+ serde_json::to_vec(&json.0)
+ .expect("failed to serialize JSON")
+ .into(),
+ )
+ }
+}
+
impl<T: Into<Self>> From<Option<T>> for AsyncBody {
fn from(body: Option<T>) -> Self {
match body {
@@ -5,7 +5,7 @@ pub mod github;
pub mod github_download;
pub use anyhow::{Result, anyhow};
-pub use async_body::{AsyncBody, Inner};
+pub use async_body::{AsyncBody, Inner, Json};
use derive_more::Deref;
use http::HeaderValue;
pub use http::{self, Method, Request, Response, StatusCode, Uri, request::Builder};
@@ -4,6 +4,7 @@ use std::sync::Arc;
use anyhow::{Context as _, Result};
use client::Client;
use cloud_api_client::ClientApiError;
+use cloud_api_types::OrganizationId;
use cloud_api_types::websocket_protocol::MessageToClient;
use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
@@ -26,29 +27,46 @@ impl fmt::Display for PaymentRequiredError {
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
impl LlmApiToken {
- pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
+ pub async fn acquire(
+ &self,
+ client: &Arc<Client>,
+ organization_id: Option<OrganizationId>,
+ ) -> Result<String> {
let lock = self.0.upgradable_read().await;
if let Some(token) = lock.as_ref() {
Ok(token.to_string())
} else {
- Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
+ Self::fetch(
+ RwLockUpgradableReadGuard::upgrade(lock).await,
+ client,
+ organization_id,
+ )
+ .await
}
}
- pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
- Self::fetch(self.0.write().await, client).await
+ pub async fn refresh(
+ &self,
+ client: &Arc<Client>,
+ organization_id: Option<OrganizationId>,
+ ) -> Result<String> {
+ Self::fetch(self.0.write().await, client, organization_id).await
}
async fn fetch(
mut lock: RwLockWriteGuard<'_, Option<String>>,
client: &Arc<Client>,
+ organization_id: Option<OrganizationId>,
) -> Result<String> {
let system_id = client
.telemetry()
.system_id()
.map(|system_id| system_id.to_string());
- let result = client.cloud_client().create_llm_token(system_id).await;
+ let result = client
+ .cloud_client()
+ .create_llm_token(system_id, organization_id)
+ .await;
match result {
Ok(response) => {
*lock = Some(response.token.0.clone());
@@ -3,7 +3,7 @@ use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{Client, UserStore, zed_urls};
-use cloud_api_types::Plan;
+use cloud_api_types::{OrganizationId, Plan};
use cloud_llm_client::{
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
@@ -122,15 +122,25 @@ impl State {
recommended_models: Vec::new(),
_fetch_models_task: cx.spawn(async move |this, cx| {
maybe!(async move {
- let (client, llm_api_token) = this
- .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;
+ let (client, llm_api_token, organization_id) =
+ this.read_with(cx, |this, cx| {
+ (
+ client.clone(),
+ this.llm_api_token.clone(),
+ this.user_store
+ .read(cx)
+ .current_organization()
+ .map(|o| o.id.clone()),
+ )
+ })?;
while current_user.borrow().is_none() {
current_user.next().await;
}
let response =
- Self::fetch_models(client.clone(), llm_api_token.clone()).await?;
+ Self::fetch_models(client.clone(), llm_api_token.clone(), organization_id)
+ .await?;
this.update(cx, |this, cx| this.update_models(response, cx))?;
anyhow::Ok(())
})
@@ -146,9 +156,17 @@ impl State {
move |this, _listener, _event, cx| {
let client = this.client.clone();
let llm_api_token = this.llm_api_token.clone();
+ let organization_id = this
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|o| o.id.clone());
cx.spawn(async move |this, cx| {
- llm_api_token.refresh(&client).await?;
- let response = Self::fetch_models(client, llm_api_token).await?;
+ llm_api_token
+ .refresh(&client, organization_id.clone())
+ .await?;
+ let response =
+ Self::fetch_models(client, llm_api_token, organization_id).await?;
this.update(cx, |this, cx| {
this.update_models(response, cx);
})
@@ -209,9 +227,10 @@ impl State {
async fn fetch_models(
client: Arc<Client>,
llm_api_token: LlmApiToken,
+ organization_id: Option<OrganizationId>,
) -> Result<ListModelsResponse> {
let http_client = &client.http_client();
- let token = llm_api_token.acquire(&client).await?;
+ let token = llm_api_token.acquire(&client, organization_id).await?;
let request = http_client::Request::builder()
.method(Method::GET)
@@ -273,11 +292,13 @@ impl CloudLanguageModelProvider {
&self,
model: Arc<cloud_llm_client::LanguageModel>,
llm_api_token: LlmApiToken,
+ user_store: Entity<UserStore>,
) -> Arc<dyn LanguageModel> {
Arc::new(CloudLanguageModel {
id: LanguageModelId(SharedString::from(model.id.0.clone())),
model,
llm_api_token,
+ user_store,
client: self.client.clone(),
request_limiter: RateLimiter::new(4),
})
@@ -306,36 +327,46 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
}
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
- let default_model = self.state.read(cx).default_model.clone()?;
- let llm_api_token = self.state.read(cx).llm_api_token.clone();
- Some(self.create_language_model(default_model, llm_api_token))
+ let state = self.state.read(cx);
+ let default_model = state.default_model.clone()?;
+ let llm_api_token = state.llm_api_token.clone();
+ let user_store = state.user_store.clone();
+ Some(self.create_language_model(default_model, llm_api_token, user_store))
}
fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
- let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
- let llm_api_token = self.state.read(cx).llm_api_token.clone();
- Some(self.create_language_model(default_fast_model, llm_api_token))
+ let state = self.state.read(cx);
+ let default_fast_model = state.default_fast_model.clone()?;
+ let llm_api_token = state.llm_api_token.clone();
+ let user_store = state.user_store.clone();
+ Some(self.create_language_model(default_fast_model, llm_api_token, user_store))
}
fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
- let llm_api_token = self.state.read(cx).llm_api_token.clone();
- self.state
- .read(cx)
+ let state = self.state.read(cx);
+ let llm_api_token = state.llm_api_token.clone();
+ let user_store = state.user_store.clone();
+ state
.recommended_models
.iter()
.cloned()
- .map(|model| self.create_language_model(model, llm_api_token.clone()))
+ .map(|model| {
+ self.create_language_model(model, llm_api_token.clone(), user_store.clone())
+ })
.collect()
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
- let llm_api_token = self.state.read(cx).llm_api_token.clone();
- self.state
- .read(cx)
+ let state = self.state.read(cx);
+ let llm_api_token = state.llm_api_token.clone();
+ let user_store = state.user_store.clone();
+ state
.models
.iter()
.cloned()
- .map(|model| self.create_language_model(model, llm_api_token.clone()))
+ .map(|model| {
+ self.create_language_model(model, llm_api_token.clone(), user_store.clone())
+ })
.collect()
}
@@ -367,6 +398,7 @@ pub struct CloudLanguageModel {
id: LanguageModelId,
model: Arc<cloud_llm_client::LanguageModel>,
llm_api_token: LlmApiToken,
+ user_store: Entity<UserStore>,
client: Arc<Client>,
request_limiter: RateLimiter,
}
@@ -380,12 +412,15 @@ impl CloudLanguageModel {
async fn perform_llm_completion(
client: Arc<Client>,
llm_api_token: LlmApiToken,
+ organization_id: Option<OrganizationId>,
app_version: Option<Version>,
body: CompletionBody,
) -> Result<PerformLlmCompletionResponse> {
let http_client = &client.http_client();
- let mut token = llm_api_token.acquire(&client).await?;
+ let mut token = llm_api_token
+ .acquire(&client, organization_id.clone())
+ .await?;
let mut refreshed_token = false;
loop {
@@ -416,7 +451,9 @@ impl CloudLanguageModel {
}
if !refreshed_token && response.needs_llm_token_refresh() {
- token = llm_api_token.refresh(&client).await?;
+ token = llm_api_token
+ .refresh(&client, organization_id.clone())
+ .await?;
refreshed_token = true;
continue;
}
@@ -670,12 +707,17 @@ impl LanguageModel for CloudLanguageModel {
cloud_llm_client::LanguageModelProvider::Google => {
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
+ let organization_id = self
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|o| o.id.clone());
let model_id = self.model.id.to_string();
let generate_content_request =
into_google(request, model_id.clone(), GoogleModelMode::Default);
async move {
let http_client = &client.http_client();
- let token = llm_api_token.acquire(&client).await?;
+ let token = llm_api_token.acquire(&client, organization_id).await?;
let request_body = CountTokensBody {
provider: cloud_llm_client::LanguageModelProvider::Google,
@@ -736,6 +778,13 @@ impl LanguageModel for CloudLanguageModel {
let prompt_id = request.prompt_id.clone();
let intent = request.intent;
let app_version = Some(cx.update(|cx| AppVersion::global(cx)));
+ let user_store = self.user_store.clone();
+ let organization_id = cx.update(|cx| {
+ user_store
+ .read(cx)
+ .current_organization()
+ .map(|o| o.id.clone())
+ });
let thinking_allowed = request.thinking_allowed;
let enable_thinking = thinking_allowed && self.model.supports_thinking;
let provider_name = provider_name(&self.model.provider);
@@ -767,6 +816,7 @@ impl LanguageModel for CloudLanguageModel {
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
+ let organization_id = organization_id.clone();
let future = self.request_limiter.stream(async move {
let PerformLlmCompletionResponse {
response,
@@ -774,6 +824,7 @@ impl LanguageModel for CloudLanguageModel {
} = Self::perform_llm_completion(
client.clone(),
llm_api_token,
+ organization_id,
app_version,
CompletionBody {
thread_id,
@@ -803,6 +854,7 @@ impl LanguageModel for CloudLanguageModel {
cloud_llm_client::LanguageModelProvider::OpenAi => {
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
+ let organization_id = organization_id.clone();
let effort = request
.thinking_effort
.as_ref()
@@ -828,6 +880,7 @@ impl LanguageModel for CloudLanguageModel {
} = Self::perform_llm_completion(
client.clone(),
llm_api_token,
+ organization_id,
app_version,
CompletionBody {
thread_id,
@@ -861,6 +914,7 @@ impl LanguageModel for CloudLanguageModel {
None,
);
let llm_api_token = self.llm_api_token.clone();
+ let organization_id = organization_id.clone();
let future = self.request_limiter.stream(async move {
let PerformLlmCompletionResponse {
response,
@@ -868,6 +922,7 @@ impl LanguageModel for CloudLanguageModel {
} = Self::perform_llm_completion(
client.clone(),
llm_api_token,
+ organization_id,
app_version,
CompletionBody {
thread_id,
@@ -902,6 +957,7 @@ impl LanguageModel for CloudLanguageModel {
} = Self::perform_llm_completion(
client.clone(),
llm_api_token,
+ organization_id,
app_version,
CompletionBody {
thread_id,
@@ -1,6 +1,6 @@
name = "C++"
grammar = "cpp"
-path_suffixes = ["cc", "hh", "cpp", "h", "hpp", "cxx", "hxx", "c++", "h++", "ipp", "inl", "ino", "ixx", "cu", "cuh", "C", "H"]
+path_suffixes = ["cc", "hh", "cpp", "cppm", "h", "hpp", "cxx", "hxx", "c++", "h++", "ipp", "inl", "ino", "ixx", "cu", "cuh", "C", "H"]
line_comments = ["// ", "/// ", "//! "]
first_line_pattern = '^//.*-\*-\s*C\+\+\s*-\*-'
decrease_indent_patterns = [
@@ -179,7 +179,13 @@ pub fn init(languages: Arc<LanguageRegistry>, fs: Arc<dyn Fs>, node: NodeRuntime
},
LanguageInfo {
name: "python",
- adapters: vec![basedpyright_lsp_adapter, ruff_lsp_adapter],
+ adapters: vec![
+ basedpyright_lsp_adapter,
+ ruff_lsp_adapter,
+ ty_lsp_adapter,
+ py_lsp_adapter,
+ python_lsp_adapter,
+ ],
context: Some(python_context_provider),
toolchain: Some(python_toolchain_provider),
manifest_name: Some(SharedString::new_static("pyproject.toml").into()),
@@ -281,9 +287,6 @@ pub fn init(languages: Arc<LanguageRegistry>, fs: Arc<dyn Fs>, node: NodeRuntime
typescript_lsp_adapter,
);
- languages.register_available_lsp_adapter(python_lsp_adapter.name(), python_lsp_adapter);
- languages.register_available_lsp_adapter(py_lsp_adapter.name(), py_lsp_adapter);
- languages.register_available_lsp_adapter(ty_lsp_adapter.name(), ty_lsp_adapter);
// Register Tailwind for the existing languages that should have it by default.
//
// This can be driven by the `language_servers` setting once we have a way for
@@ -544,7 +544,7 @@ impl Render for ProfilerWindow {
let path = cx.prompt_for_new_path(
&active_path,
- Some("performance_profile.miniprof"),
+ Some("performance_profile.miniprof.json"),
);
cx.background_spawn(async move {
@@ -6,6 +6,9 @@ pub mod pending_op;
use crate::{
ProjectEnvironment, ProjectItem, ProjectPath,
buffer_store::{BufferStore, BufferStoreEvent},
+ trusted_worktrees::{
+ PathTrust, TrustedWorktrees, TrustedWorktreesEvent, TrustedWorktreesStore,
+ },
worktree_store::{WorktreeStore, WorktreeStoreEvent},
};
use anyhow::{Context as _, Result, anyhow, bail};
@@ -354,6 +357,7 @@ impl LocalRepositoryState {
dot_git_abs_path: Arc<Path>,
project_environment: WeakEntity<ProjectEnvironment>,
fs: Arc<dyn Fs>,
+ is_trusted: bool,
cx: &mut AsyncApp,
) -> anyhow::Result<Self> {
let environment = project_environment
@@ -381,6 +385,7 @@ impl LocalRepositoryState {
}
})
.await?;
+ backend.set_trusted(is_trusted);
Ok(LocalRepositoryState {
backend,
environment: Arc::new(environment),
@@ -495,11 +500,15 @@ impl GitStore {
state: GitStoreState,
cx: &mut Context<Self>,
) -> Self {
- let _subscriptions = vec![
+ let mut _subscriptions = vec![
cx.subscribe(&worktree_store, Self::on_worktree_store_event),
cx.subscribe(&buffer_store, Self::on_buffer_store_event),
];
+ if let Some(trusted_worktrees) = TrustedWorktrees::try_get_global(cx) {
+ _subscriptions.push(cx.subscribe(&trusted_worktrees, Self::on_trusted_worktrees_event));
+ }
+
GitStore {
state,
buffer_store,
@@ -1517,6 +1526,13 @@ impl GitStore {
let original_repo_abs_path: Arc<Path> =
git::repository::original_repo_path_from_common_dir(common_dir_abs_path).into();
let id = RepositoryId(next_repository_id.fetch_add(1, atomic::Ordering::Release));
+ let is_trusted = TrustedWorktrees::try_get_global(cx)
+ .map(|trusted_worktrees| {
+ trusted_worktrees.update(cx, |trusted_worktrees, cx| {
+ trusted_worktrees.can_trust(&self.worktree_store, worktree_id, cx)
+ })
+ })
+ .unwrap_or(false);
let git_store = cx.weak_entity();
let repo = cx.new(|cx| {
let mut repo = Repository::local(
@@ -1526,6 +1542,7 @@ impl GitStore {
dot_git_abs_path.clone(),
project_environment.downgrade(),
fs.clone(),
+ is_trusted,
git_store,
cx,
);
@@ -1566,6 +1583,39 @@ impl GitStore {
}
}
+ fn on_trusted_worktrees_event(
+ &mut self,
+ _: Entity<TrustedWorktreesStore>,
+ event: &TrustedWorktreesEvent,
+ cx: &mut Context<Self>,
+ ) {
+ if !matches!(self.state, GitStoreState::Local { .. }) {
+ return;
+ }
+
+ let (is_trusted, event_paths) = match event {
+ TrustedWorktreesEvent::Trusted(_, trusted_paths) => (true, trusted_paths),
+ TrustedWorktreesEvent::Restricted(_, restricted_paths) => (false, restricted_paths),
+ };
+
+ for (repo_id, worktree_ids) in &self.worktree_ids {
+ if worktree_ids
+ .iter()
+ .any(|worktree_id| event_paths.contains(&PathTrust::Worktree(*worktree_id)))
+ {
+ if let Some(repo) = self.repositories.get(repo_id) {
+ let repository_state = repo.read(cx).repository_state.clone();
+ cx.background_spawn(async move {
+ if let Ok(RepositoryState::Local(state)) = repository_state.await {
+ state.backend.set_trusted(is_trusted);
+ }
+ })
+ .detach();
+ }
+ }
+ }
+ }
+
fn on_buffer_store_event(
&mut self,
_: Entity<BufferStore>,
@@ -3763,6 +3813,13 @@ impl MergeDetails {
}
impl Repository {
+ pub fn is_trusted(&self) -> bool {
+ match self.repository_state.peek() {
+ Some(Ok(RepositoryState::Local(state))) => state.backend.is_trusted(),
+ _ => false,
+ }
+ }
+
pub fn snapshot(&self) -> RepositorySnapshot {
self.snapshot.clone()
}
@@ -3788,6 +3845,7 @@ impl Repository {
dot_git_abs_path: Arc<Path>,
project_environment: WeakEntity<ProjectEnvironment>,
fs: Arc<dyn Fs>,
+ is_trusted: bool,
git_store: WeakEntity<GitStore>,
cx: &mut Context<Self>,
) -> Self {
@@ -3804,6 +3862,7 @@ impl Repository {
dot_git_abs_path,
project_environment,
fs,
+ is_trusted,
cx,
)
.await
@@ -1942,6 +1942,11 @@ impl Project {
}
}
+ #[cfg(feature = "test-support")]
+ pub fn client_subscriptions(&self) -> &Vec<client::Subscription> {
+ &self.client_subscriptions
+ }
+
#[cfg(feature = "test-support")]
pub async fn example(
root_paths: impl IntoIterator<Item = &Path>,
@@ -2741,6 +2746,7 @@ impl Project {
} = &mut self.client_state
{
*sharing_has_stopped = true;
+ self.client_subscriptions.clear();
self.collaborators.clear();
self.worktree_store.update(cx, |store, cx| {
store.disconnected_from_host(cx);
@@ -1174,3 +1174,327 @@ mod git_traversal {
pretty_assertions::assert_eq!(found_statuses, expected_statuses);
}
}
+
+mod git_worktrees {
+ use std::path::PathBuf;
+
+ use fs::FakeFs;
+ use gpui::TestAppContext;
+ use serde_json::json;
+ use settings::SettingsStore;
+ use util::path;
+
+ fn init_test(cx: &mut gpui::TestAppContext) {
+ zlog::init_test();
+
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ });
+ }
+
+ #[gpui::test]
+ async fn test_git_worktrees_list_and_create(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.background_executor.clone());
+ fs.insert_tree(
+ path!("/root"),
+ json!({
+ ".git": {},
+ "file.txt": "content",
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ cx.executor().run_until_parked();
+
+ let repository = project.read_with(cx, |project, cx| {
+ project.repositories(cx).values().next().unwrap().clone()
+ });
+
+ let worktrees = cx
+ .update(|cx| repository.update(cx, |repository, _| repository.worktrees()))
+ .await
+ .unwrap()
+ .unwrap();
+ assert_eq!(worktrees.len(), 1);
+ assert_eq!(worktrees[0].path, PathBuf::from(path!("/root")));
+
+ let worktree_directory = PathBuf::from(path!("/root"));
+ cx.update(|cx| {
+ repository.update(cx, |repository, _| {
+ repository.create_worktree(
+ "feature-branch".to_string(),
+ worktree_directory.clone(),
+ Some("abc123".to_string()),
+ )
+ })
+ })
+ .await
+ .unwrap()
+ .unwrap();
+
+ cx.executor().run_until_parked();
+
+ let worktrees = cx
+ .update(|cx| repository.update(cx, |repository, _| repository.worktrees()))
+ .await
+ .unwrap()
+ .unwrap();
+ assert_eq!(worktrees.len(), 2);
+ assert_eq!(worktrees[0].path, PathBuf::from(path!("/root")));
+ assert_eq!(worktrees[1].path, worktree_directory.join("feature-branch"));
+ assert_eq!(worktrees[1].ref_name.as_ref(), "refs/heads/feature-branch");
+ assert_eq!(worktrees[1].sha.as_ref(), "abc123");
+
+ cx.update(|cx| {
+ repository.update(cx, |repository, _| {
+ repository.create_worktree(
+ "bugfix-branch".to_string(),
+ worktree_directory.clone(),
+ None,
+ )
+ })
+ })
+ .await
+ .unwrap()
+ .unwrap();
+
+ cx.executor().run_until_parked();
+
+ // List worktrees — should now have main + two created
+ let worktrees = cx
+ .update(|cx| repository.update(cx, |repository, _| repository.worktrees()))
+ .await
+ .unwrap()
+ .unwrap();
+ assert_eq!(worktrees.len(), 3);
+
+ let feature_worktree = worktrees
+ .iter()
+ .find(|worktree| worktree.ref_name.as_ref() == "refs/heads/feature-branch")
+ .expect("should find feature-branch worktree");
+ assert_eq!(
+ feature_worktree.path,
+ worktree_directory.join("feature-branch")
+ );
+
+ let bugfix_worktree = worktrees
+ .iter()
+ .find(|worktree| worktree.ref_name.as_ref() == "refs/heads/bugfix-branch")
+ .expect("should find bugfix-branch worktree");
+ assert_eq!(
+ bugfix_worktree.path,
+ worktree_directory.join("bugfix-branch")
+ );
+ assert_eq!(bugfix_worktree.sha.as_ref(), "fake-sha");
+ }
+
+ use crate::Project;
+}
+
+mod trust_tests {
+ use collections::HashSet;
+ use fs::FakeFs;
+ use gpui::TestAppContext;
+ use project::trusted_worktrees::*;
+
+ use serde_json::json;
+ use settings::SettingsStore;
+ use util::path;
+
+ use crate::Project;
+
+ fn init_test(cx: &mut TestAppContext) {
+ zlog::init_test();
+
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ });
+ }
+
+ #[gpui::test]
+ async fn test_repository_defaults_to_untrusted_without_trust_system(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.background_executor.clone());
+ fs.insert_tree(
+ path!("/project"),
+ json!({
+ ".git": {},
+ "a.txt": "hello",
+ }),
+ )
+ .await;
+
+ // Create project without trust system — repos should default to untrusted.
+ let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
+ cx.executor().run_until_parked();
+
+ let repository = project.read_with(cx, |project, cx| {
+ project.repositories(cx).values().next().unwrap().clone()
+ });
+
+ repository.read_with(cx, |repo, _| {
+ assert!(
+ !repo.is_trusted(),
+ "repository should default to untrusted when no trust system is initialized"
+ );
+ });
+ }
+
+ #[gpui::test]
+ async fn test_multiple_repos_trust_with_single_worktree(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.background_executor.clone());
+ fs.insert_tree(
+ path!("/project"),
+ json!({
+ ".git": {},
+ "a.txt": "hello",
+ "sub": {
+ ".git": {},
+ "b.txt": "world",
+ },
+ }),
+ )
+ .await;
+
+ cx.update(|cx| {
+ init(DbTrustedPaths::default(), cx);
+ });
+
+ let project =
+ Project::test_with_worktree_trust(fs.clone(), [path!("/project").as_ref()], cx).await;
+ cx.executor().run_until_parked();
+
+ let worktree_store = project.read_with(cx, |project, _| project.worktree_store());
+ let worktree_id = worktree_store.read_with(cx, |store, cx| {
+ store.worktrees().next().unwrap().read(cx).id()
+ });
+
+ let repos = project.read_with(cx, |project, cx| {
+ project
+ .repositories(cx)
+ .values()
+ .cloned()
+ .collect::<Vec<_>>()
+ });
+ assert_eq!(repos.len(), 2, "should have two repositories");
+ for repo in &repos {
+ repo.read_with(cx, |repo, _| {
+ assert!(
+ !repo.is_trusted(),
+ "all repos should be untrusted initially"
+ );
+ });
+ }
+
+ let trusted_worktrees = cx
+ .update(|cx| TrustedWorktrees::try_get_global(cx).expect("trust global should be set"));
+ trusted_worktrees.update(cx, |store, cx| {
+ store.trust(
+ &worktree_store,
+ HashSet::from_iter([PathTrust::Worktree(worktree_id)]),
+ cx,
+ );
+ });
+ cx.executor().run_until_parked();
+
+ for repo in &repos {
+ repo.read_with(cx, |repo, _| {
+ assert!(
+ repo.is_trusted(),
+ "all repos should be trusted after worktree is trusted"
+ );
+ });
+ }
+ }
+
+ #[gpui::test]
+ async fn test_repository_trust_restrict_trust_cycle(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.background_executor.clone());
+ fs.insert_tree(
+ path!("/project"),
+ json!({
+ ".git": {},
+ "a.txt": "hello",
+ }),
+ )
+ .await;
+
+ cx.update(|cx| {
+ project::trusted_worktrees::init(DbTrustedPaths::default(), cx);
+ });
+
+ let project =
+ Project::test_with_worktree_trust(fs.clone(), [path!("/project").as_ref()], cx).await;
+ cx.executor().run_until_parked();
+
+ let worktree_store = project.read_with(cx, |project, _| project.worktree_store());
+ let worktree_id = worktree_store.read_with(cx, |store, cx| {
+ store.worktrees().next().unwrap().read(cx).id()
+ });
+
+ let repository = project.read_with(cx, |project, cx| {
+ project.repositories(cx).values().next().unwrap().clone()
+ });
+
+ repository.read_with(cx, |repo, _| {
+ assert!(!repo.is_trusted(), "repository should start untrusted");
+ });
+
+ let trusted_worktrees = cx
+ .update(|cx| TrustedWorktrees::try_get_global(cx).expect("trust global should be set"));
+
+ trusted_worktrees.update(cx, |store, cx| {
+ store.trust(
+ &worktree_store,
+ HashSet::from_iter([PathTrust::Worktree(worktree_id)]),
+ cx,
+ );
+ });
+ cx.executor().run_until_parked();
+
+ repository.read_with(cx, |repo, _| {
+ assert!(
+ repo.is_trusted(),
+ "repository should be trusted after worktree is trusted"
+ );
+ });
+
+ trusted_worktrees.update(cx, |store, cx| {
+ store.restrict(
+ worktree_store.downgrade(),
+ HashSet::from_iter([PathTrust::Worktree(worktree_id)]),
+ cx,
+ );
+ });
+ cx.executor().run_until_parked();
+
+ repository.read_with(cx, |repo, _| {
+ assert!(
+ !repo.is_trusted(),
+ "repository should be untrusted after worktree is restricted"
+ );
+ });
+
+ trusted_worktrees.update(cx, |store, cx| {
+ store.trust(
+ &worktree_store,
+ HashSet::from_iter([PathTrust::Worktree(worktree_id)]),
+ cx,
+ );
+ });
+ cx.executor().run_until_parked();
+
+ repository.read_with(cx, |repo, _| {
+ assert!(
+ repo.is_trusted(),
+ "repository should be trusted again after second trust"
+ );
+ });
+ }
+}
@@ -5359,6 +5359,52 @@ async fn test_rescan_and_remote_updates(cx: &mut gpui::TestAppContext) {
});
}
+#[cfg(target_os = "linux")]
+#[gpui::test(retries = 5)]
+async fn test_recreated_directory_receives_child_events(cx: &mut gpui::TestAppContext) {
+ init_test(cx);
+ cx.executor().allow_parking();
+
+ let dir = TempTree::new(json!({}));
+ let project = Project::test(Arc::new(RealFs::new(None, cx.executor())), [dir.path()], cx).await;
+ let tree = project.update(cx, |project, cx| project.worktrees(cx).next().unwrap());
+
+ tree.flush_fs_events(cx).await;
+
+ let repro_dir = dir.path().join("repro");
+ std::fs::create_dir(&repro_dir).unwrap();
+ tree.flush_fs_events(cx).await;
+
+ cx.update(|cx| {
+ assert!(tree.read(cx).entry_for_path(rel_path("repro")).is_some());
+ });
+
+ std::fs::remove_dir_all(&repro_dir).unwrap();
+ tree.flush_fs_events(cx).await;
+
+ cx.update(|cx| {
+ assert!(tree.read(cx).entry_for_path(rel_path("repro")).is_none());
+ });
+
+ std::fs::create_dir(&repro_dir).unwrap();
+ tree.flush_fs_events(cx).await;
+
+ cx.update(|cx| {
+ assert!(tree.read(cx).entry_for_path(rel_path("repro")).is_some());
+ });
+
+ std::fs::write(repro_dir.join("repro-marker"), "").unwrap();
+ tree.flush_fs_events(cx).await;
+
+ cx.update(|cx| {
+ assert!(
+ tree.read(cx)
+ .entry_for_path(rel_path("repro/repro-marker"))
+ .is_some()
+ );
+ });
+}
+
#[gpui::test(iterations = 10)]
async fn test_buffer_identity_across_renames(cx: &mut gpui::TestAppContext) {
init_test(cx);
@@ -6457,11 +6457,14 @@ impl Render for ProjectPanel {
el.on_action(cx.listener(Self::trash))
})
})
- .when(project.is_local(), |el| {
- el.on_action(cx.listener(Self::reveal_in_finder))
- .on_action(cx.listener(Self::open_system))
- .on_action(cx.listener(Self::open_in_terminal))
- })
+ .when(
+ project.is_local() || project.is_via_wsl_with_host_interop(cx),
+ |el| {
+ el.on_action(cx.listener(Self::reveal_in_finder))
+ .on_action(cx.listener(Self::open_system))
+ .on_action(cx.listener(Self::open_in_terminal))
+ },
+ )
.when(project.is_via_remote_server(), |el| {
el.on_action(cx.listener(Self::open_in_terminal))
.on_action(cx.listener(Self::download_from_remote))
@@ -1161,12 +1161,11 @@ impl RemoteServerProjects {
workspace.toggle_modal(window, cx, |window, cx| {
RemoteConnectionModal::new(&connection_options, Vec::new(), window, cx)
});
- let prompt = workspace
- .active_modal::<RemoteConnectionModal>(cx)
- .unwrap()
- .read(cx)
- .prompt
- .clone();
+            // can be None if another copy of this modal was opened in the meantime
+ let Some(modal) = workspace.active_modal::<RemoteConnectionModal>(cx) else {
+ return;
+ };
+ let prompt = modal.read(cx).prompt.clone();
let connect = connect(
ConnectionIdentifier::setup(),
@@ -2,15 +2,12 @@
/// The tests in this file assume that server_cx is running on Windows too.
/// We need to find a way to test Windows-Non-Windows interactions.
use crate::headless_project::HeadlessProject;
-use agent::{
- AgentTool, ReadFileTool, ReadFileToolInput, Templates, Thread, ToolCallEventStream, ToolInput,
-};
+use agent::{AgentTool, ReadFileTool, ReadFileToolInput, ToolCallEventStream, ToolInput};
use client::{Client, UserStore};
use clock::FakeSystemClock;
use collections::{HashMap, HashSet};
use git::repository::DiffType;
-use language_model::{LanguageModelToolResultContent, fake_provider::FakeLanguageModel};
-use prompt_store::ProjectContext;
+use language_model::LanguageModelToolResultContent;
use extension::ExtensionHostProxy;
use fs::{FakeFs, Fs};
@@ -2065,27 +2062,12 @@ async fn test_remote_agent_fs_tool_calls(cx: &mut TestAppContext, server_cx: &mu
let action_log = cx.new(|_| action_log::ActionLog::new(project.clone()));
- // Create a minimal thread for the ReadFileTool
- let context_server_registry =
- cx.new(|cx| agent::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
- let model = Arc::new(FakeLanguageModel::default());
- let thread = cx.new(|cx| {
- Thread::new(
- project.clone(),
- cx.new(|_cx| ProjectContext::default()),
- context_server_registry,
- Templates::new(),
- Some(model),
- cx,
- )
- });
-
let input = ReadFileToolInput {
path: "project/b.txt".into(),
start_line: None,
end_line: None,
};
- let read_tool = Arc::new(ReadFileTool::new(thread.downgrade(), project, action_log));
+ let read_tool = Arc::new(ReadFileTool::new(project, action_log, true));
let (event_stream, _) = ToolCallEventStream::test();
let exists_result = cx.update(|cx| {
@@ -1,13 +1,11 @@
-#![allow(unused, dead_code)]
use std::sync::Arc;
use std::time::{Duration, Instant};
use editor::{Editor, EditorMode, MultiBuffer, SizingBehavior};
use futures::future::Shared;
use gpui::{
- App, Entity, EventEmitter, Focusable, Hsla, InteractiveElement, KeyContext,
- RetainAllImageCache, StatefulInteractiveElement, Task, TextStyleRefinement, image_cache,
- prelude::*,
+ App, Entity, EventEmitter, Focusable, Hsla, InteractiveElement, RetainAllImageCache,
+ StatefulInteractiveElement, Task, TextStyleRefinement, prelude::*,
};
use language::{Buffer, Language, LanguageRegistry};
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
@@ -236,7 +234,7 @@ pub trait RenderableCell: Render {
fn source(&self) -> &String;
fn selected(&self) -> bool;
fn set_selected(&mut self, selected: bool) -> &mut Self;
- fn selected_bg_color(&self, window: &mut Window, cx: &mut Context<Self>) -> Hsla {
+ fn selected_bg_color(&self, _window: &mut Window, cx: &mut Context<Self>) -> Hsla {
if self.selected() {
let mut color = cx.theme().colors().element_hover;
color.fade_out(0.5);
@@ -253,7 +251,7 @@ pub trait RenderableCell: Render {
fn cell_position_spacer(
&self,
is_first: bool,
- window: &mut Window,
+ _window: &mut Window,
cx: &mut Context<Self>,
) -> Option<impl IntoElement> {
let cell_position = self.cell_position();
@@ -328,7 +326,6 @@ pub struct MarkdownCell {
editing: bool,
selected: bool,
cell_position: Option<CellPosition>,
- languages: Arc<LanguageRegistry>,
_editor_subscription: gpui::Subscription,
}
@@ -386,7 +383,6 @@ impl MarkdownCell {
let markdown = cx.new(|cx| Markdown::new(source.clone().into(), None, None, cx));
- let cell_id = id.clone();
let editor_subscription =
cx.subscribe(&editor, move |this, _editor, event, cx| match event {
editor::EditorEvent::Blurred => {
@@ -410,7 +406,6 @@ impl MarkdownCell {
editing: start_editing,
selected: false,
cell_position: None,
- languages,
_editor_subscription: editor_subscription,
}
}
@@ -461,8 +456,6 @@ impl MarkdownCell {
.unwrap_or_default();
self.source = source.clone();
- let languages = self.languages.clone();
-
self.markdown.update(cx, |markdown, cx| {
markdown.reset(source.into(), cx);
});
@@ -606,7 +599,7 @@ pub struct CodeCell {
outputs: Vec<Output>,
selected: bool,
cell_position: Option<CellPosition>,
- language_task: Task<()>,
+ _language_task: Task<()>,
execution_start_time: Option<Instant>,
execution_duration: Option<Duration>,
is_executing: bool,
@@ -670,10 +663,10 @@ impl CodeCell {
outputs: Vec::new(),
selected: false,
cell_position: None,
- language_task,
execution_start_time: None,
execution_duration: None,
is_executing: false,
+ _language_task: language_task,
}
}
@@ -748,10 +741,10 @@ impl CodeCell {
outputs,
selected: false,
cell_position: None,
- language_task,
execution_start_time: None,
execution_duration: None,
is_executing: false,
+ _language_task: language_task,
}
}
@@ -879,15 +872,7 @@ impl CodeCell {
cx.notify();
}
- fn output_control(&self) -> Option<CellControlType> {
- if self.has_outputs() {
- Some(CellControlType::ClearCell)
- } else {
- None
- }
- }
-
- pub fn gutter_output(&self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ pub fn gutter_output(&self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let is_selected = self.selected();
div()
@@ -948,7 +933,7 @@ impl RenderableCell for CodeCell {
&self.source
}
- fn control(&self, window: &mut Window, cx: &mut Context<Self>) -> Option<CellControl> {
+ fn control(&self, _window: &mut Window, cx: &mut Context<Self>) -> Option<CellControl> {
let control_type = if self.has_outputs() {
CellControlType::RerunCell
} else {
@@ -1038,8 +1023,7 @@ impl RenderableCell for CodeCell {
}
impl RunnableCell for CodeCell {
- fn run(&mut self, window: &mut Window, cx: &mut Context<Self>) {
- println!("Running code cell: {}", self.id);
+ fn run(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
cx.emit(CellEvent::Run(self.id.clone()));
}
@@ -1062,11 +1046,8 @@ impl Render for CodeCell {
} else {
None
};
- let output_max_width = plain::max_width_for_columns(
- ReplSettings::get_global(cx).output_max_width_columns,
- window,
- cx,
- );
+ let output_max_width =
+ plain::max_width_for_columns(ReplSettings::get_global(cx).max_columns, window, cx);
// get the language from the editor's buffer
let language_name = self
.editor
@@ -1198,41 +1179,23 @@ impl Render for CodeCell {
},
)
// output at bottom
- .child(div().w_full().children(self.outputs.iter().map(
- |output| {
- let content = match output {
- Output::Plain { content, .. } => {
- Some(content.clone().into_any_element())
- }
- Output::Markdown { content, .. } => {
- Some(content.clone().into_any_element())
- }
- Output::Stream { content, .. } => {
- Some(content.clone().into_any_element())
- }
- Output::Image { content, .. } => {
- Some(content.clone().into_any_element())
- }
- Output::Message(message) => Some(
- div()
- .child(message.clone())
- .into_any_element(),
- ),
- Output::Table { content, .. } => {
- Some(content.clone().into_any_element())
- }
- Output::Json { content, .. } => {
- Some(content.clone().into_any_element())
- }
- Output::ErrorOutput(error_view) => {
- error_view.render(window, cx)
- }
- Output::ClearOutputWaitMarker => None,
- };
-
- div().children(content)
- },
- ))),
+ .child(
+ div()
+ .id((
+ ElementId::from(self.id.to_string()),
+ "output-scroll",
+ ))
+ .w_full()
+ .when_some(output_max_width, |div, max_width| {
+ div.max_w(max_width).overflow_x_scroll()
+ })
+ .when_some(output_max_height, |div, max_height| {
+ div.max_h(max_height).overflow_y_scroll()
+ })
+ .children(self.outputs.iter().map(|output| {
+ div().children(output.content(window, cx))
+ })),
+ ),
),
),
)
@@ -253,18 +253,8 @@ impl Output {
)
}
- pub fn render(
- &self,
- workspace: WeakEntity<Workspace>,
- window: &mut Window,
- cx: &mut Context<ExecutionView>,
- ) -> impl IntoElement + use<> {
- let max_width = plain::max_width_for_columns(
- ReplSettings::get_global(cx).output_max_width_columns,
- window,
- cx,
- );
- let content = match self {
+ pub fn content(&self, window: &mut Window, cx: &mut App) -> Option<AnyElement> {
+ match self {
Self::Plain { content, .. } => Some(content.clone().into_any_element()),
Self::Markdown { content, .. } => Some(content.clone().into_any_element()),
Self::Stream { content, .. } => Some(content.clone().into_any_element()),
@@ -274,21 +264,36 @@ impl Output {
Self::Json { content, .. } => Some(content.clone().into_any_element()),
Self::ErrorOutput(error_view) => error_view.render(window, cx),
Self::ClearOutputWaitMarker => None,
- };
+ }
+ }
- let needs_horizontal_scroll = matches!(self, Self::Table { .. } | Self::Image { .. });
+ pub fn render(
+ &self,
+ workspace: WeakEntity<Workspace>,
+ window: &mut Window,
+ cx: &mut Context<ExecutionView>,
+ ) -> impl IntoElement + use<> {
+ let max_width =
+ plain::max_width_for_columns(ReplSettings::get_global(cx).max_columns, window, cx);
+ let content = self.content(window, cx);
+
+ let needs_horizontal_scroll = matches!(self, Self::Table { .. });
h_flex()
.id("output-content")
.w_full()
- .when_some(max_width, |this, max_w| this.max_w(max_w))
- .overflow_x_scroll()
+ .when_else(
+ needs_horizontal_scroll,
+ |this| this.overflow_x_scroll(),
+ |this| this.overflow_x_hidden(),
+ )
.items_start()
.child(
div()
.when(!needs_horizontal_scroll, |el| {
el.flex_1().w_full().overflow_x_hidden()
})
+ .when_some(max_width, |el, max_width| el.max_w(max_width))
.children(content),
)
.children(match self {
@@ -3,10 +3,10 @@ use base64::{
Engine as _, alphabet,
engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig},
};
-use gpui::{App, ClipboardItem, Image, ImageFormat, RenderImage, Window, img};
+use gpui::{App, ClipboardItem, Image, ImageFormat, Pixels, RenderImage, Window, img};
use settings::Settings as _;
use std::sync::Arc;
-use ui::{IntoElement, Styled, div, prelude::*};
+use ui::{IntoElement, Styled, prelude::*};
use crate::outputs::{OutputContent, plain};
use crate::repl_settings::ReplSettings;
@@ -113,7 +113,7 @@ impl Render for ImageView {
let settings = ReplSettings::get_global(cx);
let line_height = window.line_height();
- let max_width = plain::max_width_for_columns(settings.output_max_width_columns, window, cx);
+ let max_width = plain::max_width_for_columns(settings.max_columns, window, cx);
let max_height = if settings.output_max_height_lines > 0 {
Some(line_height * settings.output_max_height_lines as f32)
@@ -125,7 +125,7 @@ impl Render for ImageView {
let image = self.image.clone();
- div().h(height).w(width).child(img(image))
+ img(image).w(width).h(height)
}
}
@@ -22,7 +22,7 @@ use alacritty_terminal::{
term::Config,
vte::ansi::Processor,
};
-use gpui::{Bounds, ClipboardItem, Entity, FontStyle, TextStyle, WhiteSpace, canvas, size};
+use gpui::{Bounds, ClipboardItem, Entity, FontStyle, Pixels, TextStyle, WhiteSpace, canvas, size};
use language::Buffer;
use settings::Settings as _;
use terminal::terminal_settings::TerminalSettings;
@@ -27,11 +27,6 @@ pub struct ReplSettings {
///
/// Default: 0
pub output_max_height_lines: usize,
- /// Maximum number of columns of output to display before scaling images.
- /// Set to 0 to disable output width limits.
- ///
- /// Default: 0
- pub output_max_width_columns: usize,
}
impl Settings for ReplSettings {
@@ -44,7 +39,6 @@ impl Settings for ReplSettings {
inline_output: repl.inline_output.unwrap_or(true),
inline_output_max_length: repl.inline_output_max_length.unwrap_or(50),
output_max_height_lines: repl.output_max_height_lines.unwrap_or(0),
- output_max_width_columns: repl.output_max_width_columns.unwrap_or(0),
}
}
}
@@ -90,7 +90,7 @@ pub enum EditPredictionProvider {
Experimental(&'static str),
}
-pub const EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME: &str = "zeta2";
+const EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME: &str = "zeta2";
impl<'de> Deserialize<'de> for EditPredictionProvider {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
@@ -157,10 +157,7 @@ impl EditPredictionProvider {
EditPredictionProvider::Codestral => Some("Codestral"),
EditPredictionProvider::Sweep => Some("Sweep"),
EditPredictionProvider::Mercury => Some("Mercury"),
- EditPredictionProvider::Experimental(
- EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
- ) => Some("Zeta2"),
- EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => None,
+ EditPredictionProvider::Experimental(_) | EditPredictionProvider::None => None,
EditPredictionProvider::Ollama => Some("Ollama"),
EditPredictionProvider::OpenAiCompatibleApi => Some("OpenAI-Compatible API"),
}
@@ -1148,11 +1148,6 @@ pub struct ReplSettingsContent {
///
/// Default: 0
pub output_max_height_lines: Option<usize>,
- /// Maximum number of columns of output to display before scaling images.
- /// Set to 0 to disable output width limits.
- ///
- /// Default: 0
- pub output_max_width_columns: Option<usize>,
}
/// Settings for configuring the which-key popup behaviour.
@@ -2,6 +2,7 @@ use codestral::{CODESTRAL_API_URL, codestral_api_key_state, codestral_api_url};
use edit_prediction::{
ApiKeyState,
mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
+ open_ai_compatible::{open_ai_compatible_api_token, open_ai_compatible_api_url},
sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token},
};
use edit_prediction_ui::{get_available_providers, set_completion_provider};
@@ -33,7 +34,9 @@ pub(crate) fn render_edit_prediction_setup_page(
render_api_key_provider(
IconName::Inception,
"Mercury",
- "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
+ ApiKeyDocs::Link {
+ dashboard_url: "https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
+ },
mercury_api_token(cx),
|_cx| MERCURY_CREDENTIALS_URL,
None,
@@ -46,7 +49,9 @@ pub(crate) fn render_edit_prediction_setup_page(
render_api_key_provider(
IconName::SweepAi,
"Sweep",
- "https://app.sweep.dev/".into(),
+ ApiKeyDocs::Link {
+ dashboard_url: "https://app.sweep.dev/".into(),
+ },
sweep_api_token(cx),
|_cx| SWEEP_CREDENTIALS_URL,
Some(
@@ -68,7 +73,9 @@ pub(crate) fn render_edit_prediction_setup_page(
render_api_key_provider(
IconName::AiMistral,
"Codestral",
- "https://console.mistral.ai/codestral".into(),
+ ApiKeyDocs::Link {
+ dashboard_url: "https://console.mistral.ai/codestral".into(),
+ },
codestral_api_key_state(cx),
|cx| codestral_api_url(cx),
Some(
@@ -87,7 +94,31 @@ pub(crate) fn render_edit_prediction_setup_page(
.into_any_element(),
),
Some(render_ollama_provider(settings_window, window, cx).into_any_element()),
- Some(render_open_ai_compatible_provider(settings_window, window, cx).into_any_element()),
+ Some(
+ render_api_key_provider(
+ IconName::AiOpenAiCompat,
+ "OpenAI Compatible API",
+ ApiKeyDocs::Custom {
+ message: "Set an API key here. It will be sent as Authorization: Bearer {key}."
+ .into(),
+ },
+ open_ai_compatible_api_token(cx),
+ |cx| open_ai_compatible_api_url(cx),
+ Some(
+ settings_window
+ .render_sub_page_items_section(
+ open_ai_compatible_settings().iter().enumerate(),
+ true,
+ window,
+ cx,
+ )
+ .into_any_element(),
+ ),
+ window,
+ cx,
+ )
+ .into_any_element(),
+ ),
];
div()
@@ -162,10 +193,15 @@ fn render_provider_dropdown(window: &mut Window, cx: &mut App) -> AnyElement {
.into_any_element()
}
+enum ApiKeyDocs {
+ Link { dashboard_url: SharedString },
+ Custom { message: SharedString },
+}
+
fn render_api_key_provider(
icon: IconName,
title: &'static str,
- link: SharedString,
+ docs: ApiKeyDocs,
api_key_state: Entity<ApiKeyState>,
current_url: fn(&mut App) -> SharedString,
additional_fields: Option<AnyElement>,
@@ -209,25 +245,32 @@ fn render_api_key_provider(
.icon(icon)
.no_padding(true);
let button_link_label = format!("{} dashboard", title);
- let description = h_flex()
- .min_w_0()
- .gap_0p5()
- .child(
- Label::new("Visit the")
+ let description = match docs {
+ ApiKeyDocs::Custom { message } => h_flex().min_w_0().gap_0p5().child(
+ Label::new(message)
.size(LabelSize::Small)
.color(Color::Muted),
- )
- .child(
- ButtonLink::new(button_link_label, link)
- .no_icon(true)
- .label_size(LabelSize::Small)
- .label_color(Color::Muted),
- )
- .child(
- Label::new("to generate an API key.")
- .size(LabelSize::Small)
- .color(Color::Muted),
- );
+ ),
+ ApiKeyDocs::Link { dashboard_url } => h_flex()
+ .min_w_0()
+ .gap_0p5()
+ .child(
+ Label::new("Visit the")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .child(
+ ButtonLink::new(button_link_label, dashboard_url)
+ .no_icon(true)
+ .label_size(LabelSize::Small)
+ .label_color(Color::Muted),
+ )
+ .child(
+ Label::new("to generate an API key.")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ ),
+ };
let configured_card_label = if is_from_env_var {
"API Key Set in Environment Variable"
} else {
@@ -484,34 +527,6 @@ fn ollama_settings() -> Box<[SettingsPageItem]> {
])
}
-fn render_open_ai_compatible_provider(
- settings_window: &SettingsWindow,
- window: &mut Window,
- cx: &mut Context<SettingsWindow>,
-) -> impl IntoElement {
- let open_ai_compatible_settings = open_ai_compatible_settings();
- let additional_fields = settings_window
- .render_sub_page_items_section(
- open_ai_compatible_settings.iter().enumerate(),
- true,
- window,
- cx,
- )
- .into_any_element();
-
- v_flex()
- .id("open-ai-compatible")
- .min_w_0()
- .pt_8()
- .gap_1p5()
- .child(
- SettingsSectionHeader::new("OpenAI Compatible API")
- .icon(IconName::AiOpenAiCompat)
- .no_padding(true),
- )
- .child(div().px_neg_8().child(additional_fields))
-}
-
fn open_ai_compatible_settings() -> Box<[SettingsPageItem]> {
Box::new([
SettingsPageItem::SettingItem(SettingItem {
@@ -89,7 +89,7 @@ const FILE_SUFFIXES_BY_ICON_KEY: &[(&str, &[&str])] = &[
(
"cpp",
&[
- "c++", "h++", "cc", "cpp", "cxx", "hh", "hpp", "hxx", "inl", "ixx",
+ "c++", "h++", "cc", "cpp", "cppm", "cxx", "hh", "hpp", "hxx", "inl", "ixx",
],
),
("crystal", &["cr", "ecr"]),
@@ -14,6 +14,7 @@ path = "src/web_search_providers.rs"
[dependencies]
anyhow.workspace = true
client.workspace = true
+cloud_api_types.workspace = true
cloud_llm_client.workspace = true
futures.workspace = true
gpui.workspace = true
@@ -1,7 +1,8 @@
use std::sync::Arc;
use anyhow::{Context as _, Result};
-use client::Client;
+use client::{Client, UserStore};
+use cloud_api_types::OrganizationId;
use cloud_llm_client::{WebSearchBody, WebSearchResponse};
use futures::AsyncReadExt as _;
use gpui::{App, AppContext, Context, Entity, Subscription, Task};
@@ -14,8 +15,8 @@ pub struct CloudWebSearchProvider {
}
impl CloudWebSearchProvider {
- pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
- let state = cx.new(|cx| State::new(client, cx));
+ pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) -> Self {
+ let state = cx.new(|cx| State::new(client, user_store, cx));
Self { state }
}
@@ -23,24 +24,31 @@ impl CloudWebSearchProvider {
pub struct State {
client: Arc<Client>,
+ user_store: Entity<UserStore>,
llm_api_token: LlmApiToken,
_llm_token_subscription: Subscription,
}
impl State {
- pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
+ pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
Self {
client,
+ user_store,
llm_api_token: LlmApiToken::default(),
_llm_token_subscription: cx.subscribe(
&refresh_llm_token_listener,
|this, _, _event, cx| {
let client = this.client.clone();
let llm_api_token = this.llm_api_token.clone();
+ let organization_id = this
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|o| o.id.clone());
cx.spawn(async move |_this, _cx| {
- llm_api_token.refresh(&client).await?;
+ llm_api_token.refresh(&client, organization_id).await?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
@@ -61,21 +69,31 @@ impl WebSearchProvider for CloudWebSearchProvider {
let state = self.state.read(cx);
let client = state.client.clone();
let llm_api_token = state.llm_api_token.clone();
+ let organization_id = state
+ .user_store
+ .read(cx)
+ .current_organization()
+ .map(|o| o.id.clone());
let body = WebSearchBody { query };
- cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
+ cx.background_spawn(async move {
+ perform_web_search(client, llm_api_token, organization_id, body).await
+ })
}
}
async fn perform_web_search(
client: Arc<Client>,
llm_api_token: LlmApiToken,
+ organization_id: Option<OrganizationId>,
body: WebSearchBody,
) -> Result<WebSearchResponse> {
const MAX_RETRIES: usize = 3;
let http_client = &client.http_client();
let mut retries_remaining = MAX_RETRIES;
- let mut token = llm_api_token.acquire(&client).await?;
+ let mut token = llm_api_token
+ .acquire(&client, organization_id.clone())
+ .await?;
loop {
if retries_remaining == 0 {
@@ -100,7 +118,9 @@ async fn perform_web_search(
response.body_mut().read_to_string(&mut body).await?;
return Ok(serde_json::from_str(&body)?);
} else if response.needs_llm_token_refresh() {
- token = llm_api_token.refresh(&client).await?;
+ token = llm_api_token
+ .refresh(&client, organization_id.clone())
+ .await?;
retries_remaining -= 1;
} else {
// For now we will only retry if the LLM token is expired,
@@ -1,26 +1,28 @@
mod cloud;
-use client::Client;
+use client::{Client, UserStore};
use gpui::{App, Context, Entity};
use language_model::LanguageModelRegistry;
use std::sync::Arc;
use web_search::{WebSearchProviderId, WebSearchRegistry};
-pub fn init(client: Arc<Client>, cx: &mut App) {
+pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
let registry = WebSearchRegistry::global(cx);
registry.update(cx, |registry, cx| {
- register_web_search_providers(registry, client, cx);
+ register_web_search_providers(registry, client, user_store, cx);
});
}
fn register_web_search_providers(
registry: &mut WebSearchRegistry,
client: Arc<Client>,
+ user_store: Entity<UserStore>,
cx: &mut Context<WebSearchRegistry>,
) {
register_zed_web_search_provider(
registry,
client.clone(),
+ user_store.clone(),
&LanguageModelRegistry::global(cx),
cx,
);
@@ -29,7 +31,13 @@ fn register_web_search_providers(
&LanguageModelRegistry::global(cx),
move |this, registry, event, cx| {
if let language_model::Event::DefaultModelChanged = event {
- register_zed_web_search_provider(this, client.clone(), ®istry, cx)
+ register_zed_web_search_provider(
+ this,
+ client.clone(),
+ user_store.clone(),
+ ®istry,
+ cx,
+ )
}
},
)
@@ -39,6 +47,7 @@ fn register_web_search_providers(
fn register_zed_web_search_provider(
registry: &mut WebSearchRegistry,
client: Arc<Client>,
+ user_store: Entity<UserStore>,
language_model_registry: &Entity<LanguageModelRegistry>,
cx: &mut Context<WebSearchRegistry>,
) {
@@ -47,7 +56,10 @@ fn register_zed_web_search_provider(
.default_model()
.is_some_and(|default| default.is_provided_by_zed());
if using_zed_provider {
- registry.register_provider(cloud::CloudWebSearchProvider::new(client, cx), cx)
+ registry.register_provider(
+ cloud::CloudWebSearchProvider::new(client, user_store, cx),
+ cx,
+ )
} else {
registry.unregister_provider(WebSearchProviderId(
cloud::ZED_WEB_SEARCH_PROVIDER_ID.into(),
@@ -2945,7 +2945,7 @@ impl BackgroundScannerState {
self.snapshot.check_invariants(false);
}
- fn remove_path(&mut self, path: &RelPath) {
+ fn remove_path(&mut self, path: &RelPath, watcher: &dyn Watcher) {
log::trace!("background scanner removing path {path:?}");
let mut new_entries;
let removed_entries;
@@ -2961,7 +2961,12 @@ impl BackgroundScannerState {
self.snapshot.entries_by_path = new_entries;
let mut removed_ids = Vec::with_capacity(removed_entries.summary().count);
+ let mut removed_dir_abs_paths = Vec::new();
for entry in removed_entries.cursor::<()>(()) {
+ if entry.is_dir() {
+ removed_dir_abs_paths.push(self.snapshot.absolutize(&entry.path));
+ }
+
match self.removed_entries.entry(entry.inode) {
hash_map::Entry::Occupied(mut e) => {
let prev_removed_entry = e.get_mut();
@@ -2997,6 +3002,10 @@ impl BackgroundScannerState {
.git_repositories
.retain(|id, _| removed_ids.binary_search(id).is_err());
+ for removed_dir_abs_path in removed_dir_abs_paths {
+ watcher.remove(&removed_dir_abs_path).log_err();
+ }
+
#[cfg(feature = "test-support")]
self.snapshot.check_invariants(false);
}
@@ -4461,7 +4470,10 @@ impl BackgroundScanner {
if self.settings.is_path_excluded(&child_path) {
log::debug!("skipping excluded child entry {child_path:?}");
- self.state.lock().await.remove_path(&child_path);
+ self.state
+ .lock()
+ .await
+ .remove_path(&child_path, self.watcher.as_ref());
continue;
}
@@ -4651,7 +4663,7 @@ impl BackgroundScanner {
// detected regardless of the order of the paths.
for (path, metadata) in relative_paths.iter().zip(metadata.iter()) {
if matches!(metadata, Ok(None)) || doing_recursive_update {
- state.remove_path(path);
+ state.remove_path(path, self.watcher.as_ref());
}
}
@@ -645,7 +645,7 @@ fn main() {
zed::remote_debug::init(cx);
edit_prediction_ui::init(cx);
web_search::init(cx);
- web_search_providers::init(app_state.client.clone(), cx);
+ web_search_providers::init(app_state.client.clone(), app_state.user_store.clone(), cx);
snippet_provider::init(cx);
edit_prediction_registry::init(app_state.client.clone(), app_state.user_store.clone(), cx);
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), stdout_is_a_pty(), cx);
@@ -144,7 +144,7 @@ fn cleanup_old_hang_traces() {
entry
.path()
.extension()
- .is_some_and(|ext| ext == "miniprof")
+ .is_some_and(|ext| ext == "json" || ext == "miniprof")
})
.collect();
@@ -175,7 +175,7 @@ fn save_hang_trace(
.collect::<Vec<_>>();
let trace_path = paths::hang_traces_dir().join(&format!(
- "hang-{}.miniprof",
+ "hang-{}.miniprof.json",
hang_time.format("%Y-%m-%d_%H-%M-%S")
));
@@ -193,7 +193,7 @@ fn save_hang_trace(
entry
.path()
.extension()
- .is_some_and(|ext| ext == "miniprof")
+ .is_some_and(|ext| ext == "json" || ext == "miniprof")
})
.collect();
@@ -2032,32 +2032,9 @@ fn run_agent_thread_view_test(
// Create the necessary entities for the ReadFileTool
let action_log = cx.update(|cx| cx.new(|_| action_log::ActionLog::new(project.clone())));
- let context_server_registry = cx.update(|cx| {
- cx.new(|cx| agent::ContextServerRegistry::new(project.read(cx).context_server_store(), cx))
- });
- let fake_model = Arc::new(language_model::fake_provider::FakeLanguageModel::default());
- let project_context = cx.update(|cx| cx.new(|_| prompt_store::ProjectContext::default()));
-
- // Create the agent Thread
- let thread = cx.update(|cx| {
- cx.new(|cx| {
- agent::Thread::new(
- project.clone(),
- project_context,
- context_server_registry,
- agent::Templates::new(),
- Some(fake_model),
- cx,
- )
- })
- });
// Create the ReadFileTool
- let tool = Arc::new(agent::ReadFileTool::new(
- thread.downgrade(),
- project.clone(),
- action_log,
- ));
+ let tool = Arc::new(agent::ReadFileTool::new(project.clone(), action_log, true));
// Create a test event stream to capture tool output
let (event_stream, mut event_receiver) = agent::ToolCallEventStream::test();
@@ -5021,7 +5021,7 @@ mod tests {
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
web_search::init(cx);
git_graph::init(cx);
- web_search_providers::init(app_state.client.clone(), cx);
+ web_search_providers::init(app_state.client.clone(), app_state.user_store.clone(), cx);
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx);
project::AgentRegistryStore::init_global(
cx,
@@ -2,15 +2,12 @@ use client::{Client, UserStore};
use codestral::{CodestralEditPredictionDelegate, load_codestral_api_key};
use collections::HashMap;
use copilot::CopilotEditPredictionDelegate;
-use edit_prediction::{EditPredictionModel, ZedEditPredictionDelegate, Zeta2FeatureFlag};
+use edit_prediction::{EditPredictionModel, ZedEditPredictionDelegate};
use editor::Editor;
-use feature_flags::FeatureFlagAppExt;
use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
use language::language_settings::{EditPredictionProvider, all_language_settings};
-use settings::{
- EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, EditPredictionPromptFormat, SettingsStore,
-};
+use settings::{EditPredictionPromptFormat, SettingsStore};
use std::{cell::RefCell, rc::Rc, sync::Arc};
use ui::Window;
@@ -81,9 +78,6 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
.detach();
cx.observe_global::<SettingsStore>({
- let editors = editors.clone();
- let client = client.clone();
- let user_store = user_store.clone();
let mut previous_config = edit_prediction_provider_config_for_settings(cx);
move |cx| {
let new_provider_config = edit_prediction_provider_config_for_settings(cx);
@@ -107,24 +101,6 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
}
})
.detach();
-
- cx.observe_flag::<Zeta2FeatureFlag, _>({
- let mut previous_config = edit_prediction_provider_config_for_settings(cx);
- move |_is_enabled, cx| {
- let new_provider_config = edit_prediction_provider_config_for_settings(cx);
- if new_provider_config != previous_config {
- previous_config = new_provider_config;
- assign_edit_prediction_providers(
- &editors,
- new_provider_config,
- &client,
- user_store.clone(),
- cx,
- );
- }
- }
- })
- .detach();
}
fn edit_prediction_provider_config_for_settings(cx: &App) -> Option<EditPredictionProviderConfig> {
@@ -154,7 +130,10 @@ fn edit_prediction_provider_config_for_settings(cx: &App) -> Option<EditPredicti
}
}
- if format == EditPredictionPromptFormat::Zeta {
+ if matches!(
+ format,
+ EditPredictionPromptFormat::Zeta | EditPredictionPromptFormat::Zeta2
+ ) {
Some(EditPredictionProviderConfig::Zed(EditPredictionModel::Zeta))
} else {
Some(EditPredictionProviderConfig::Zed(
@@ -168,15 +147,7 @@ fn edit_prediction_provider_config_for_settings(cx: &App) -> Option<EditPredicti
EditPredictionProvider::Mercury => Some(EditPredictionProviderConfig::Zed(
EditPredictionModel::Mercury,
)),
- EditPredictionProvider::Experimental(name) => {
- if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME
- && cx.has_flag::<Zeta2FeatureFlag>()
- {
- Some(EditPredictionProviderConfig::Zed(EditPredictionModel::Zeta))
- } else {
- None
- }
- }
+ EditPredictionProvider::Experimental(_) => None,
}
}
@@ -86,6 +86,7 @@ pub enum ZetaFormat {
V0131GitMergeMarkersPrefix,
V0211Prefill,
V0211SeedCoder,
+ v0226Hashline,
}
impl std::fmt::Display for ZetaFormat {
@@ -122,25 +123,6 @@ impl ZetaFormat {
.collect::<Vec<_>>()
.concat()
}
-
- pub fn special_tokens(&self) -> &'static [&'static str] {
- match self {
- ZetaFormat::V0112MiddleAtEnd
- | ZetaFormat::V0113Ordered
- | ZetaFormat::V0114180EditableRegion => &[
- "<|fim_prefix|>",
- "<|fim_suffix|>",
- "<|fim_middle|>",
- "<|file_sep|>",
- CURSOR_MARKER,
- ],
- ZetaFormat::V0120GitMergeMarkers => v0120_git_merge_markers::special_tokens(),
- ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill => {
- v0131_git_merge_markers_prefix::special_tokens()
- }
- ZetaFormat::V0211SeedCoder => seed_coder::special_tokens(),
- }
- }
}
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
@@ -212,33 +194,29 @@ pub struct RelatedExcerpt {
}
pub fn prompt_input_contains_special_tokens(input: &ZetaPromptInput, format: ZetaFormat) -> bool {
- format
- .special_tokens()
+ special_tokens_for_format(format)
.iter()
.any(|token| input.cursor_excerpt.contains(token))
}
pub fn format_zeta_prompt(input: &ZetaPromptInput, format: ZetaFormat) -> String {
- format_zeta_prompt_with_budget(input, format, MAX_PROMPT_TOKENS)
+ format_prompt_with_budget_for_format(input, format, MAX_PROMPT_TOKENS)
}
-/// Post-processes model output for the given zeta format by stripping format-specific suffixes.
-pub fn clean_zeta2_model_output(output: &str, format: ZetaFormat) -> &str {
+pub fn special_tokens_for_format(format: ZetaFormat) -> &'static [&'static str] {
match format {
- ZetaFormat::V0120GitMergeMarkers => output
- .strip_suffix(v0120_git_merge_markers::END_MARKER)
- .unwrap_or(output),
- ZetaFormat::V0131GitMergeMarkersPrefix => output
- .strip_suffix(v0131_git_merge_markers_prefix::END_MARKER)
- .unwrap_or(output),
- ZetaFormat::V0211SeedCoder => output
- .strip_suffix(seed_coder::END_MARKER)
- .unwrap_or(output),
- _ => output,
+ ZetaFormat::V0112MiddleAtEnd => v0112_middle_at_end::special_tokens(),
+ ZetaFormat::V0113Ordered => v0113_ordered::special_tokens(),
+ ZetaFormat::V0114180EditableRegion => v0114180_editable_region::special_tokens(),
+ ZetaFormat::V0120GitMergeMarkers => v0120_git_merge_markers::special_tokens(),
+ ZetaFormat::V0131GitMergeMarkersPrefix => v0131_git_merge_markers_prefix::special_tokens(),
+ ZetaFormat::V0211Prefill => v0211_prefill::special_tokens(),
+ ZetaFormat::V0211SeedCoder => seed_coder::special_tokens(),
+ ZetaFormat::v0226Hashline => hashline::special_tokens(),
}
}
-pub fn excerpt_range_for_format(
+pub fn excerpt_ranges_for_format(
format: ZetaFormat,
ranges: &ExcerptRanges,
) -> (Range<usize>, Range<usize>) {
@@ -247,129 +225,257 @@ pub fn excerpt_range_for_format(
ranges.editable_150.clone(),
ranges.editable_150_context_350.clone(),
),
- ZetaFormat::V0114180EditableRegion
- | ZetaFormat::V0120GitMergeMarkers
+ ZetaFormat::V0114180EditableRegion => (
+ ranges.editable_180.clone(),
+ ranges.editable_180_context_350.clone(),
+ ),
+ ZetaFormat::V0120GitMergeMarkers
| ZetaFormat::V0131GitMergeMarkersPrefix
| ZetaFormat::V0211Prefill
- | ZetaFormat::V0211SeedCoder => (
+ | ZetaFormat::V0211SeedCoder
+ | ZetaFormat::v0226Hashline => (
ranges.editable_350.clone(),
ranges.editable_350_context_150.clone(),
),
}
}
-pub fn resolve_cursor_region(
- input: &ZetaPromptInput,
- format: ZetaFormat,
-) -> (&str, Range<usize>, usize) {
- let (editable_range, context_range) = excerpt_range_for_format(format, &input.excerpt_ranges);
- let context_start = context_range.start;
- let context_text = &input.cursor_excerpt[context_range];
- let adjusted_editable =
- (editable_range.start - context_start)..(editable_range.end - context_start);
- let adjusted_cursor = input.cursor_offset_in_excerpt - context_start;
-
- (context_text, adjusted_editable, adjusted_cursor)
-}
-
-fn format_zeta_prompt_with_budget(
- input: &ZetaPromptInput,
+pub fn write_cursor_excerpt_section_for_format(
format: ZetaFormat,
- max_tokens: usize,
-) -> String {
- let (context, editable_range, cursor_offset) = resolve_cursor_region(input, format);
- let path = &*input.cursor_path;
-
- let mut cursor_section = String::new();
+ prompt: &mut String,
+ path: &Path,
+ context: &str,
+ editable_range: &Range<usize>,
+ cursor_offset: usize,
+) {
match format {
- ZetaFormat::V0112MiddleAtEnd => {
- v0112_middle_at_end::write_cursor_excerpt_section(
- &mut cursor_section,
- path,
- context,
- &editable_range,
- cursor_offset,
- );
- }
+ ZetaFormat::V0112MiddleAtEnd => v0112_middle_at_end::write_cursor_excerpt_section(
+ prompt,
+ path,
+ context,
+ editable_range,
+ cursor_offset,
+ ),
ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => {
v0113_ordered::write_cursor_excerpt_section(
- &mut cursor_section,
+ prompt,
path,
context,
- &editable_range,
+ editable_range,
cursor_offset,
)
}
ZetaFormat::V0120GitMergeMarkers => v0120_git_merge_markers::write_cursor_excerpt_section(
- &mut cursor_section,
+ prompt,
path,
context,
- &editable_range,
+ editable_range,
cursor_offset,
),
ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill => {
v0131_git_merge_markers_prefix::write_cursor_excerpt_section(
- &mut cursor_section,
+ prompt,
path,
context,
- &editable_range,
+ editable_range,
cursor_offset,
)
}
- ZetaFormat::V0211SeedCoder => {
- return seed_coder::format_prompt_with_budget(
+ ZetaFormat::V0211SeedCoder => seed_coder::write_cursor_excerpt_section(
+ prompt,
+ path,
+ context,
+ editable_range,
+ cursor_offset,
+ ),
+ ZetaFormat::v0226Hashline => hashline::write_cursor_excerpt_section(
+ prompt,
+ path,
+ context,
+ editable_range,
+ cursor_offset,
+ ),
+ }
+}
+
+pub fn format_prompt_with_budget_for_format(
+ input: &ZetaPromptInput,
+ format: ZetaFormat,
+ max_tokens: usize,
+) -> String {
+ let (context, editable_range, cursor_offset) = resolve_cursor_region(input, format);
+ let path = &*input.cursor_path;
+
+ match format {
+ ZetaFormat::V0211SeedCoder => seed_coder::format_prompt_with_budget(
+ path,
+ context,
+ &editable_range,
+ cursor_offset,
+ &input.events,
+ &input.related_files,
+ max_tokens,
+ ),
+ _ => {
+ let mut cursor_section = String::new();
+ write_cursor_excerpt_section_for_format(
+ format,
+ &mut cursor_section,
path,
context,
&editable_range,
cursor_offset,
+ );
+
+ let cursor_tokens = estimate_tokens(cursor_section.len());
+ let budget_after_cursor = max_tokens.saturating_sub(cursor_tokens);
+
+ let edit_history_section = format_edit_history_within_budget(
&input.events,
+ "<|file_sep|>",
+ "edit history",
+ budget_after_cursor,
+ );
+ let edit_history_tokens = estimate_tokens(edit_history_section.len());
+ let budget_after_edit_history = budget_after_cursor.saturating_sub(edit_history_tokens);
+
+ let related_files_section = format_related_files_within_budget(
&input.related_files,
- max_tokens,
+ "<|file_sep|>",
+ "",
+ budget_after_edit_history,
);
+
+ let mut prompt = String::new();
+ prompt.push_str(&related_files_section);
+ prompt.push_str(&edit_history_section);
+ prompt.push_str(&cursor_section);
+ prompt
}
}
-
- let cursor_tokens = estimate_tokens(cursor_section.len());
- let budget_after_cursor = max_tokens.saturating_sub(cursor_tokens);
-
- let edit_history_section = format_edit_history_within_budget(
- &input.events,
- "<|file_sep|>",
- "edit history",
- budget_after_cursor,
- );
- let edit_history_tokens = estimate_tokens(edit_history_section.len());
- let budget_after_edit_history = budget_after_cursor.saturating_sub(edit_history_tokens);
-
- let related_files_section = format_related_files_within_budget(
- &input.related_files,
- "<|file_sep|>",
- "",
- budget_after_edit_history,
- );
-
- let mut prompt = String::new();
- prompt.push_str(&related_files_section);
- prompt.push_str(&edit_history_section);
- prompt.push_str(&cursor_section);
- prompt
}
-pub fn get_prefill(input: &ZetaPromptInput, format: ZetaFormat) -> String {
+pub fn get_prefill_for_format(
+ format: ZetaFormat,
+ context: &str,
+ editable_range: &Range<usize>,
+) -> String {
match format {
+ ZetaFormat::V0211Prefill => v0211_prefill::get_prefill(context, editable_range),
ZetaFormat::V0112MiddleAtEnd
| ZetaFormat::V0113Ordered
| ZetaFormat::V0114180EditableRegion
| ZetaFormat::V0120GitMergeMarkers
| ZetaFormat::V0131GitMergeMarkersPrefix
- | ZetaFormat::V0211SeedCoder => String::new(),
- ZetaFormat::V0211Prefill => {
- let (context, editable_range, _) = resolve_cursor_region(input, format);
- v0211_prefill::get_prefill(context, &editable_range)
+ | ZetaFormat::V0211SeedCoder
+ | ZetaFormat::v0226Hashline => String::new(),
+ }
+}
+
+pub fn output_end_marker_for_format(format: ZetaFormat) -> Option<&'static str> {
+ match format {
+ ZetaFormat::V0120GitMergeMarkers => Some(v0120_git_merge_markers::END_MARKER),
+ ZetaFormat::V0131GitMergeMarkersPrefix => Some(v0131_git_merge_markers_prefix::END_MARKER),
+ ZetaFormat::V0211Prefill => Some(v0131_git_merge_markers_prefix::END_MARKER),
+ ZetaFormat::V0211SeedCoder => Some(seed_coder::END_MARKER),
+ ZetaFormat::V0112MiddleAtEnd
+ | ZetaFormat::V0113Ordered
+ | ZetaFormat::V0114180EditableRegion
+ | ZetaFormat::v0226Hashline => None,
+ }
+}
+
+pub fn current_region_markers_for_format(format: ZetaFormat) -> (&'static str, &'static str) {
+ match format {
+ ZetaFormat::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"),
+ ZetaFormat::V0113Ordered
+ | ZetaFormat::V0114180EditableRegion
+ | ZetaFormat::v0226Hashline => ("<|fim_middle|>current\n", "<|fim_suffix|>"),
+ ZetaFormat::V0120GitMergeMarkers
+ | ZetaFormat::V0131GitMergeMarkersPrefix
+ | ZetaFormat::V0211Prefill => (
+ v0120_git_merge_markers::START_MARKER,
+ v0120_git_merge_markers::SEPARATOR,
+ ),
+ ZetaFormat::V0211SeedCoder => (seed_coder::START_MARKER, seed_coder::SEPARATOR),
+ }
+}
+
+pub fn clean_extracted_region_for_format(format: ZetaFormat, region: &str) -> String {
+ match format {
+ ZetaFormat::v0226Hashline => hashline::strip_hashline_prefixes(region),
+ _ => region.to_string(),
+ }
+}
+
+pub fn encode_patch_as_output_for_format(
+ format: ZetaFormat,
+ old_editable_region: &str,
+ patch: &str,
+ cursor_offset: Option<usize>,
+) -> Result<Option<String>> {
+ match format {
+ ZetaFormat::v0226Hashline => {
+ hashline::patch_to_edit_commands(old_editable_region, patch, cursor_offset).map(Some)
+ }
+ _ => Ok(None),
+ }
+}
+
+pub fn output_with_context_for_format(
+ format: ZetaFormat,
+ old_editable_region: &str,
+ output: &str,
+) -> Result<Option<String>> {
+ match format {
+ ZetaFormat::v0226Hashline => {
+ if hashline::output_has_edit_commands(output) {
+ Ok(Some(hashline::apply_edit_commands(
+ old_editable_region,
+ output,
+ )))
+ } else {
+ Ok(None)
+ }
}
+ _ => Ok(None),
}
}
+/// Post-processes model output for the given zeta format by stripping format-specific suffixes.
+pub fn clean_zeta2_model_output(output: &str, format: ZetaFormat) -> &str {
+ match output_end_marker_for_format(format) {
+ Some(marker) => output.strip_suffix(marker).unwrap_or(output),
+ None => output,
+ }
+}
+
+pub fn excerpt_range_for_format(
+ format: ZetaFormat,
+ ranges: &ExcerptRanges,
+) -> (Range<usize>, Range<usize>) {
+ excerpt_ranges_for_format(format, ranges)
+}
+
+pub fn resolve_cursor_region(
+ input: &ZetaPromptInput,
+ format: ZetaFormat,
+) -> (&str, Range<usize>, usize) {
+ let (editable_range, context_range) = excerpt_range_for_format(format, &input.excerpt_ranges);
+ let context_start = context_range.start;
+ let context_text = &input.cursor_excerpt[context_range];
+ let adjusted_editable =
+ (editable_range.start - context_start)..(editable_range.end - context_start);
+ let adjusted_cursor = input.cursor_offset_in_excerpt - context_start;
+
+ (context_text, adjusted_editable, adjusted_cursor)
+}
+
+pub fn get_prefill(input: &ZetaPromptInput, format: ZetaFormat) -> String {
+ let (context, editable_range, _) = resolve_cursor_region(input, format);
+ get_prefill_for_format(format, context, &editable_range)
+}
+
fn format_edit_history_within_budget(
events: &[Arc<Event>],
file_marker: &str,
@@ -533,6 +639,16 @@ pub fn write_related_files(
mod v0112_middle_at_end {
use super::*;
+ pub fn special_tokens() -> &'static [&'static str] {
+ &[
+ "<|fim_prefix|>",
+ "<|fim_suffix|>",
+ "<|fim_middle|>",
+ "<|file_sep|>",
+ CURSOR_MARKER,
+ ]
+ }
+
pub fn write_cursor_excerpt_section(
prompt: &mut String,
path: &Path,
@@ -567,6 +683,16 @@ mod v0112_middle_at_end {
mod v0113_ordered {
use super::*;
+ pub fn special_tokens() -> &'static [&'static str] {
+ &[
+ "<|fim_prefix|>",
+ "<|fim_suffix|>",
+ "<|fim_middle|>",
+ "<|file_sep|>",
+ CURSOR_MARKER,
+ ]
+ }
+
pub fn write_cursor_excerpt_section(
prompt: &mut String,
path: &Path,
@@ -601,6 +727,14 @@ mod v0113_ordered {
}
}
+mod v0114180_editable_region {
+ use super::*;
+
+ pub fn special_tokens() -> &'static [&'static str] {
+ v0113_ordered::special_tokens()
+ }
+}
+
pub mod v0120_git_merge_markers {
//! A prompt that uses git-style merge conflict markers to represent the editable region.
//!
@@ -752,6 +886,10 @@ pub mod v0131_git_merge_markers_prefix {
pub mod v0211_prefill {
use super::*;
+ pub fn special_tokens() -> &'static [&'static str] {
+ v0131_git_merge_markers_prefix::special_tokens()
+ }
+
pub fn get_prefill(context: &str, editable_range: &Range<usize>) -> String {
let editable_region = &context[editable_range.start..editable_range.end];
@@ -783,6 +921,1413 @@ pub mod v0211_prefill {
}
}
+pub mod hashline {
+
+ use std::fmt::Display;
+
+ pub const END_MARKER: &str = "<|fim_middle|>updated";
+ pub const START_MARKER: &str = "<|fim_middle|>current";
+
+ use super::*;
+
+ const SET_COMMAND_MARKER: &str = "<|set|>";
+ const INSERT_COMMAND_MARKER: &str = "<|insert|>";
+
+ pub fn special_tokens() -> &'static [&'static str] {
+ return &[
+ SET_COMMAND_MARKER,
+ "<|set_range|>",
+ INSERT_COMMAND_MARKER,
+ CURSOR_MARKER,
+ "<|file_sep|>",
+ "<|fim_prefix|>",
+ "<|fim_suffix|>",
+ "<|fim_middle|>",
+ ];
+ }
+
+ /// A parsed line reference like `3:c3` (line index 3 with hash 0xc3).
+ #[derive(Debug, Clone, PartialEq, Eq)]
+ struct LineRef {
+ index: usize,
+ hash: u8,
+ }
+
+ impl Display for LineRef {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}:{:02x}", self.index, self.hash)
+ }
+ }
+
+ pub fn hash_line(line: &[u8]) -> u8 {
+ let mut h: u8 = 0;
+ for &byte in line {
+ h = h.wrapping_add(byte);
+ }
+ return h;
+ }
+
+    /// Write the hashline-encoded editable region into `out`. Each line of
+    /// `editable_text` is prefixed with `{line_index}:{hash}|`, and the cursor
+    /// marker is inserted at `cursor_offset_in_editable` (a byte offset from the
+    /// start of `editable_text`) only when it falls strictly inside a line.
+ pub fn write_hashline_editable_region(
+ out: &mut String,
+ editable_text: &str,
+ cursor_offset_in_editable: usize,
+ ) {
+ let mut offset = 0;
+ for (i, line) in editable_text.lines().enumerate() {
+ let (head, cursor, tail) = if cursor_offset_in_editable > offset
+ && cursor_offset_in_editable < offset + line.len()
+ {
+ (
+ &line[..cursor_offset_in_editable - offset],
+ CURSOR_MARKER,
+ &line[cursor_offset_in_editable - offset..],
+ )
+ } else {
+ (line, "", "")
+ };
+ write!(
+ out,
+ "\n{}|{head}{cursor}{tail}",
+ LineRef {
+ index: i,
+ hash: hash_line(line.as_bytes())
+ }
+ )
+ .unwrap();
+ offset += line.len() + 1;
+ }
+ }
+
+ pub fn write_cursor_excerpt_section(
+ prompt: &mut String,
+ path: &Path,
+ context: &str,
+ editable_range: &Range<usize>,
+ cursor_offset: usize,
+ ) {
+ let path_str = path.to_string_lossy();
+ write!(prompt, "<|file_sep|>{}\n", path_str).ok();
+
+ prompt.push_str("<|fim_prefix|>\n");
+ prompt.push_str(&context[..editable_range.start]);
+ prompt.push_str(START_MARKER);
+
+ let cursor_offset_in_editable = cursor_offset.saturating_sub(editable_range.start);
+ let editable_region = &context[editable_range.clone()];
+ write_hashline_editable_region(prompt, editable_region, cursor_offset_in_editable);
+
+ if !prompt.ends_with('\n') {
+ prompt.push('\n');
+ }
+
+ prompt.push_str("<|fim_suffix|>\n");
+ prompt.push_str(&context[editable_range.end..]);
+ if !prompt.ends_with('\n') {
+ prompt.push('\n');
+ }
+
+ prompt.push_str(END_MARKER);
+ }
+
+ /// A single edit command parsed from the model output.
+ #[derive(Debug)]
+ enum EditCommand<'a> {
+ /// Replace a range of lines (inclusive on both ends). Single-line set is
+ /// represented by `start == end`.
+ Set {
+ start: LineRef,
+ end: LineRef,
+ content: &'a str,
+ },
+ /// Insert new lines after the given line, or before the first line if
+ /// `after` is `None`.
+ Insert {
+ after: Option<LineRef>,
+ content: &'a str,
+ },
+ }
+
+ /// Parse a line reference like `3:c3` into a `LineRef`.
+ fn parse_line_ref(s: &str) -> Option<LineRef> {
+ let (idx_str, hash_str) = s.split_once(':')?;
+ let index = idx_str.parse::<usize>().ok()?;
+ let hash = u8::from_str_radix(hash_str, 16).ok()?;
+ Some(LineRef { index, hash })
+ }
+
+ /// Parse the model output into a list of `EditCommand`s.
+ fn parse_edit_commands(model_output: &str) -> Vec<EditCommand<'_>> {
+ let mut commands = Vec::new();
+ let mut offset = 0usize;
+
+ while offset < model_output.len() {
+ let next_nl = model_output[offset..]
+ .find('\n')
+ .map(|i| offset + i)
+ .unwrap_or(model_output.len());
+ let line = &model_output[offset..next_nl];
+ let line_end = if next_nl < model_output.len() {
+ next_nl + 1
+ } else {
+ next_nl
+ };
+
+ let trimmed = line.trim();
+ let (is_set, specifier) = if let Some(spec) = trimmed.strip_prefix(SET_COMMAND_MARKER) {
+ (true, spec)
+ } else if let Some(spec) = trimmed.strip_prefix(INSERT_COMMAND_MARKER) {
+ (false, spec)
+ } else {
+ offset = line_end;
+ continue;
+ };
+
+ let mut content_end = line_end;
+ let mut scan = line_end;
+
+ while scan < model_output.len() {
+ let body_nl = model_output[scan..]
+ .find('\n')
+ .map(|i| scan + i)
+ .unwrap_or(model_output.len());
+ let body_line = &model_output[scan..body_nl];
+ if body_line.trim().starts_with(SET_COMMAND_MARKER)
+ || body_line.trim().starts_with(INSERT_COMMAND_MARKER)
+ {
+ break;
+ }
+ scan = if body_nl < model_output.len() {
+ body_nl + 1
+ } else {
+ body_nl
+ };
+ content_end = scan;
+ }
+
+ let content = &model_output[line_end..content_end];
+
+ if is_set {
+ if let Some((start_str, end_str)) = specifier.split_once('-') {
+ if let (Some(start), Some(end)) =
+ (parse_line_ref(start_str), parse_line_ref(end_str))
+ {
+ commands.push(EditCommand::Set {
+ start,
+ end,
+ content,
+ });
+ }
+ } else if let Some(target) = parse_line_ref(specifier) {
+ commands.push(EditCommand::Set {
+ start: target.clone(),
+ end: target,
+ content,
+ });
+ }
+ } else {
+ let after = parse_line_ref(specifier);
+ commands.push(EditCommand::Insert { after, content });
+ }
+
+ offset = scan;
+ }
+
+ commands
+ }
+
+    /// Strip the `{line_num}:{hash}|` prefixes from each line of a hashline-encoded
+    /// editable region, returning the plain text content. Lines with no `|`
+    /// separator are kept unchanged, and a trailing newline in the input is
+    /// preserved in the output.
+ pub fn strip_hashline_prefixes(region: &str) -> String {
+ let mut decoded: String = region
+ .lines()
+ .map(|line| line.find('|').map_or(line, |pos| &line[pos + 1..]))
+ .collect::<Vec<_>>()
+ .join("\n");
+ if region.ends_with('\n') {
+ decoded.push('\n');
+ }
+ decoded
+ }
+
+ pub fn output_has_edit_commands(model_output: &str) -> bool {
+ model_output.contains(SET_COMMAND_MARKER) || model_output.contains(INSERT_COMMAND_MARKER)
+ }
+
+ /// Apply `<|set|>` and `<|insert|>` edit commands from the model output to the
+ /// original editable region text.
+ ///
+ /// `editable_region` is the original text of the editable region (without hash
+ /// prefixes). `model_output` is the raw model response containing edit commands.
+ ///
+ /// Returns the full replacement text for the editable region.
+ pub fn apply_edit_commands(editable_region: &str, model_output: &str) -> String {
+ let original_lines: Vec<&str> = editable_region.lines().collect();
+ let old_hashes: Vec<u8> = original_lines
+ .iter()
+ .map(|line| hash_line(line.as_bytes()))
+ .collect();
+
+ let commands = parse_edit_commands(model_output);
+
+ // For set operations: indexed by start line → Some((end line index, content))
+ // For insert operations: indexed by line index → vec of content to insert after
+ // Insert-before-first is tracked separately.
+ let mut set_ops: Vec<Option<(usize, &str)>> = vec![None; original_lines.len()];
+ let mut insert_before_first: Vec<&str> = Vec::new();
+ let mut insert_after: Vec<Vec<&str>> = vec![Vec::new(); original_lines.len()];
+
+ for command in &commands {
+ match command {
+ EditCommand::Set {
+ start,
+ end,
+ content,
+ } => {
+ if start.index < old_hashes.len()
+ && end.index < old_hashes.len()
+ && start.index <= end.index
+ && old_hashes[start.index] == start.hash
+ && old_hashes[end.index] == end.hash
+ {
+ set_ops[start.index] = Some((end.index, *content));
+ }
+ }
+ EditCommand::Insert { after, content } => match after {
+ None => insert_before_first.push(*content),
+ Some(line_ref) => {
+ if line_ref.index < old_hashes.len()
+ && old_hashes[line_ref.index] == line_ref.hash
+ {
+ insert_after[line_ref.index].push(*content);
+ }
+ }
+ },
+ }
+ }
+
+ let mut result = String::new();
+
+ // Emit any insertions before the first line
+ for content in &insert_before_first {
+ result.push_str(content);
+ if !content.ends_with('\n') {
+ result.push('\n');
+ }
+ }
+
+ let mut i = 0;
+ while i < original_lines.len() {
+ if let Some((end_index, replacement)) = set_ops[i].as_ref() {
+ // Replace lines i..=end_index with the replacement content
+ result.push_str(replacement);
+ if !replacement.is_empty() && !replacement.ends_with('\n') {
+ result.push('\n');
+ }
+ // Emit any insertions after the end of this set range
+ if *end_index < insert_after.len() {
+ for content in &insert_after[*end_index] {
+ result.push_str(content);
+ if !content.ends_with('\n') {
+ result.push('\n');
+ }
+ }
+ }
+ i = end_index + 1;
+ } else {
+ // Keep the original line
+ result.push_str(original_lines[i]);
+ result.push('\n');
+ // Emit any insertions after this line
+ for content in &insert_after[i] {
+ result.push_str(content);
+ if !content.ends_with('\n') {
+ result.push('\n');
+ }
+ }
+ i += 1;
+ }
+ }
+
+ // Preserve trailing newline behavior: if the original ended with a
+ // newline the result already has one; if it didn't, trim the extra one
+ // we added.
+ if !editable_region.ends_with('\n') && result.ends_with('\n') {
+ result.pop();
+ }
+
+ result
+ }
+
+ /// Convert a unified diff patch into hashline edit commands.
+ ///
+ /// Parses the unified diff `patch` directly to determine which lines of
+ /// `old_text` are deleted/replaced and what new lines are added, then emits
+ /// `<|set|>` and `<|insert|>` edit commands referencing old lines by their
+ /// `{index}:{hash}` identifiers.
+ ///
+ /// `cursor_offset` is an optional byte offset into the first hunk's new
+ /// text (context + additions) where the cursor marker should be placed.
+ pub fn patch_to_edit_commands(
+ old_text: &str,
+ patch: &str,
+ cursor_offset: Option<usize>,
+ ) -> Result<String> {
+ let old_lines: Vec<&str> = old_text.lines().collect();
+ let old_hashes: Vec<u8> = old_lines
+ .iter()
+ .map(|line| hash_line(line.as_bytes()))
+ .collect();
+
+ let mut result = String::new();
+ let mut first_hunk = true;
+
+ struct Hunk<'a> {
+ line_range: Range<usize>,
+ new_text_lines: Vec<&'a str>,
+ cursor_line_offset_in_new_text: Option<(usize, usize)>,
+ }
+
+ // Parse the patch line by line. We only care about hunk headers,
+ // context, deletions, and additions.
+ let mut old_line_index: usize = 0;
+ let mut current_hunk: Option<Hunk> = None;
+ // Byte offset tracking within the hunk's new text for cursor placement.
+ let mut new_text_byte_offset: usize = 0;
+ // The line index of the last old line seen before/in the current hunk
+ // (used for insert-after reference).
+ let mut last_old_line_before_hunk: Option<usize> = None;
+
+ fn flush_hunk(
+ hunk: Hunk,
+ last_old_line: Option<usize>,
+ result: &mut String,
+ old_hashes: &[u8],
+ ) {
+ if hunk.line_range.is_empty() {
+ // Pure insertion — reference the old line to insert after when in bounds.
+ if let Some(after) = last_old_line
+ && let Some(&hash) = old_hashes.get(after)
+ {
+ write!(
+ result,
+ "{INSERT_COMMAND_MARKER}{}\n",
+ LineRef { index: after, hash }
+ )
+ .unwrap();
+ } else {
+ result.push_str(INSERT_COMMAND_MARKER);
+ result.push('\n');
+ }
+ } else {
+ let start = hunk.line_range.start;
+ let end_exclusive = hunk.line_range.end;
+ let deleted_line_count = end_exclusive.saturating_sub(start);
+
+ if deleted_line_count == 1 {
+ if let Some(&hash) = old_hashes.get(start) {
+ write!(
+ result,
+ "{SET_COMMAND_MARKER}{}\n",
+ LineRef { index: start, hash }
+ )
+ .unwrap();
+ } else {
+ result.push_str(SET_COMMAND_MARKER);
+ result.push('\n');
+ }
+ } else {
+ let end_inclusive = end_exclusive - 1;
+ match (
+ old_hashes.get(start).copied(),
+ old_hashes.get(end_inclusive).copied(),
+ ) {
+ (Some(start_hash), Some(end_hash)) => {
+ write!(
+ result,
+ "{SET_COMMAND_MARKER}{}-{}\n",
+ LineRef {
+ index: start,
+ hash: start_hash
+ },
+ LineRef {
+ index: end_inclusive,
+ hash: end_hash
+ }
+ )
+ .unwrap();
+ }
+ _ => {
+ result.push_str(SET_COMMAND_MARKER);
+ result.push('\n');
+ }
+ }
+ }
+ }
+ for (line_offset, line) in hunk.new_text_lines.iter().enumerate() {
+ if let Some((cursor_line_offset, char_offset)) = hunk.cursor_line_offset_in_new_text
+ && line_offset == cursor_line_offset
+ {
+ result.push_str(&line[..char_offset]);
+ result.push_str(CURSOR_MARKER);
+ result.push_str(&line[char_offset..]);
+ continue;
+ }
+
+ result.push_str(line);
+ }
+ }
+
+ for raw_line in patch.split_inclusive('\n') {
+ if raw_line.starts_with("@@") {
+ // Flush any pending change hunk from a previous patch hunk.
+ if let Some(hunk) = current_hunk.take() {
+ flush_hunk(hunk, last_old_line_before_hunk, &mut result, &old_hashes);
+ }
+
+ // Parse hunk header: @@ -old_start[,old_count] +new_start[,new_count] @@
+ // We intentionally do not trust old_start as a direct local index into `old_text`,
+ // because some patches are produced against a larger file region and carry
+ // non-local line numbers. We keep indexing local by advancing from parsed patch lines.
+ if first_hunk {
+ new_text_byte_offset = 0;
+ first_hunk = false;
+ }
+ continue;
+ }
+
+ if raw_line.starts_with("---") || raw_line.starts_with("+++") {
+ continue;
+ }
+ if raw_line.starts_with("\\ No newline") {
+ continue;
+ }
+
+ if raw_line.starts_with('-') {
+ // Extend or start a change hunk with this deleted old line.
+ match &mut current_hunk {
+ Some(Hunk {
+ line_range: range, ..
+ }) => range.end = old_line_index + 1,
+ None => {
+ current_hunk = Some(Hunk {
+ line_range: old_line_index..old_line_index + 1,
+ new_text_lines: Vec::new(),
+ cursor_line_offset_in_new_text: None,
+ });
+ }
+ }
+ old_line_index += 1;
+ } else if let Some(added_content) = raw_line.strip_prefix('+') {
+ // Place cursor marker if cursor_offset falls within this line.
+ let mut cursor_line_offset = None;
+ if let Some(cursor_off) = cursor_offset
+ && (first_hunk
+ || cursor_off >= new_text_byte_offset
+ && cursor_off <= new_text_byte_offset + added_content.len())
+ {
+ let line_offset = added_content.floor_char_boundary(
+ cursor_off
+ .saturating_sub(new_text_byte_offset)
+ .min(added_content.len()),
+ );
+ cursor_line_offset = Some(line_offset);
+ }
+
+ new_text_byte_offset += added_content.len();
+
+ let hunk = current_hunk.get_or_insert(Hunk {
+ line_range: old_line_index..old_line_index,
+ new_text_lines: vec![],
+ cursor_line_offset_in_new_text: None,
+ });
+ hunk.new_text_lines.push(added_content);
+ hunk.cursor_line_offset_in_new_text = cursor_line_offset
+ .map(|offset_in_line| (hunk.new_text_lines.len() - 1, offset_in_line));
+ } else {
+ // Context line (starts with ' ' or is empty).
+ if let Some(hunk) = current_hunk.take() {
+ flush_hunk(hunk, last_old_line_before_hunk, &mut result, &old_hashes);
+ }
+ last_old_line_before_hunk = Some(old_line_index);
+ old_line_index += 1;
+ let content = raw_line.strip_prefix(' ').unwrap_or(raw_line);
+ new_text_byte_offset += content.len();
+ }
+ }
+
+ // Flush final group.
+ if let Some(hunk) = current_hunk.take() {
+ flush_hunk(hunk, last_old_line_before_hunk, &mut result, &old_hashes);
+ }
+
+ // Trim a single trailing newline.
+ if result.ends_with('\n') {
+ result.pop();
+ }
+
+ Ok(result)
+ }
+
+ #[cfg(test)]
+ mod tests {
+ use super::*;
+ use indoc::indoc;
+
+ #[test]
+ fn test_format_cursor_region() {
+ struct Case {
+ name: &'static str,
+ context: &'static str,
+ editable_range: Range<usize>,
+ cursor_offset: usize,
+ expected: &'static str,
+ }
+
+ let cases = [
+ Case {
+ name: "basic_cursor_placement",
+ context: "hello world\n",
+ editable_range: 0..12,
+ cursor_offset: 5,
+ expected: indoc! {"
+ <|file_sep|>test.rs
+ <|fim_prefix|>
+ <|fim_middle|>current
+ 0:5c|hello<|user_cursor|> world
+ <|fim_suffix|>
+ <|fim_middle|>updated"},
+ },
+ Case {
+ name: "multiline_cursor_on_second_line",
+ context: "aaa\nbbb\nccc\n",
+ editable_range: 0..12,
+ cursor_offset: 5, // byte 5 → 1 byte into "bbb"
+ expected: indoc! {"
+ <|file_sep|>test.rs
+ <|fim_prefix|>
+ <|fim_middle|>current
+ 0:23|aaa
+ 1:26|b<|user_cursor|>bb
+ 2:29|ccc
+ <|fim_suffix|>
+ <|fim_middle|>updated"},
+ },
+ Case {
+ name: "no_trailing_newline_in_context",
+ context: "line1\nline2",
+ editable_range: 0..11,
+ cursor_offset: 3,
+ expected: indoc! {"
+ <|file_sep|>test.rs
+ <|fim_prefix|>
+ <|fim_middle|>current
+ 0:d9|lin<|user_cursor|>e1
+ 1:da|line2
+ <|fim_suffix|>
+ <|fim_middle|>updated"},
+ },
+ Case {
+ name: "leading_newline_in_editable_region",
+ context: "\nabc\n",
+ editable_range: 0..5,
+                    cursor_offset: 2, // byte 2 = after 'a' in "abc" (the leading \n is byte 0)
+ expected: indoc! {"
+ <|file_sep|>test.rs
+ <|fim_prefix|>
+ <|fim_middle|>current
+ 0:00|
+ 1:26|a<|user_cursor|>bc
+ <|fim_suffix|>
+ <|fim_middle|>updated"},
+ },
+ Case {
+ name: "with_suffix",
+ context: "abc\ndef",
+ editable_range: 0..4, // editable region = "abc\n", suffix = "def"
+ cursor_offset: 2,
+ expected: indoc! {"
+ <|file_sep|>test.rs
+ <|fim_prefix|>
+ <|fim_middle|>current
+ 0:26|ab<|user_cursor|>c
+ <|fim_suffix|>
+ def
+ <|fim_middle|>updated"},
+ },
+ Case {
+ name: "unicode_two_byte_chars",
+ context: "héllo\n",
+ editable_range: 0..7,
+ cursor_offset: 3, // byte 3 = after "hé" (h=1 byte, é=2 bytes), before "llo"
+ expected: indoc! {"
+ <|file_sep|>test.rs
+ <|fim_prefix|>
+ <|fim_middle|>current
+ 0:1b|hé<|user_cursor|>llo
+ <|fim_suffix|>
+ <|fim_middle|>updated"},
+ },
+ Case {
+ name: "unicode_three_byte_chars",
+ context: "日本語\n",
+ editable_range: 0..10,
+ cursor_offset: 6, // byte 6 = after "日本" (3+3 bytes), before "語"
+ expected: indoc! {"
+ <|file_sep|>test.rs
+ <|fim_prefix|>
+ <|fim_middle|>current
+ 0:80|日本<|user_cursor|>語
+ <|fim_suffix|>
+ <|fim_middle|>updated"},
+ },
+ Case {
+ name: "unicode_four_byte_chars",
+ context: "a🌍b\n",
+ editable_range: 0..7,
+ cursor_offset: 5, // byte 5 = after "a🌍" (1+4 bytes), before "b"
+ expected: indoc! {"
+ <|file_sep|>test.rs
+ <|fim_prefix|>
+ <|fim_middle|>current
+ 0:6b|a🌍<|user_cursor|>b
+ <|fim_suffix|>
+ <|fim_middle|>updated"},
+ },
+ Case {
+ name: "cursor_at_start_of_region_not_placed",
+ context: "abc\n",
+ editable_range: 0..4,
+ cursor_offset: 0, // cursor_offset(0) > offset(0) is false → cursor not placed
+ expected: indoc! {"
+ <|file_sep|>test.rs
+ <|fim_prefix|>
+ <|fim_middle|>current
+ 0:26|abc
+ <|fim_suffix|>
+ <|fim_middle|>updated"},
+ },
+ Case {
+ name: "cursor_at_end_of_line_not_placed",
+ context: "abc\ndef\n",
+ editable_range: 0..8,
+ cursor_offset: 3, // byte 3 = the \n after "abc" → falls between lines, not placed
+ expected: indoc! {"
+ <|file_sep|>test.rs
+ <|fim_prefix|>
+ <|fim_middle|>current
+ 0:26|abc
+ 1:2f|def
+ <|fim_suffix|>
+ <|fim_middle|>updated"},
+ },
+ Case {
+ name: "cursor_offset_relative_to_context_not_editable_region",
+ // cursor_offset is relative to `context`, so when editable_range.start > 0,
+ // write_cursor_excerpt_section must subtract it before comparing against
+ // per-line offsets within the editable region.
+ context: "pre\naaa\nbbb\nsuf\n",
+ editable_range: 4..12, // editable region = "aaa\nbbb\n"
+ cursor_offset: 9, // byte 9 in context = second 'b' in "bbb"
+ expected: indoc! {"
+ <|file_sep|>test.rs
+ <|fim_prefix|>
+ pre
+ <|fim_middle|>current
+ 0:23|aaa
+ 1:26|b<|user_cursor|>bb
+ <|fim_suffix|>
+ suf
+ <|fim_middle|>updated"},
+ },
+ ];
+
+ for case in &cases {
+ let mut prompt = String::new();
+ hashline::write_cursor_excerpt_section(
+ &mut prompt,
+ Path::new("test.rs"),
+ case.context,
+ &case.editable_range,
+ case.cursor_offset,
+ );
+ assert_eq!(prompt, case.expected, "failed case: {}", case.name);
+ }
+ }
+
+ #[test]
+ fn test_apply_edit_commands() {
+ struct Case {
+ name: &'static str,
+ original: &'static str,
+ model_output: &'static str,
+ expected: &'static str,
+ }
+
+ let cases = vec![
+ Case {
+ name: "set_single_line",
+ original: indoc! {"
+ let mut total = 0;
+ for product in products {
+ total += ;
+ }
+ total
+ "},
+ model_output: indoc! {"
+ <|set|>2:87
+ total += product.price;
+ "},
+ expected: indoc! {"
+ let mut total = 0;
+ for product in products {
+ total += product.price;
+ }
+ total
+ "},
+ },
+ Case {
+ name: "set_range",
+ original: indoc! {"
+ fn foo() {
+ let x = 1;
+ let y = 2;
+ let z = 3;
+ }
+ "},
+ model_output: indoc! {"
+ <|set|>1:46-3:4a
+ let sum = 6;
+ "},
+ expected: indoc! {"
+ fn foo() {
+ let sum = 6;
+ }
+ "},
+ },
+ Case {
+ name: "insert_after_line",
+ original: indoc! {"
+ fn main() {
+ let x = 1;
+ }
+ "},
+ model_output: indoc! {"
+ <|insert|>1:46
+ let y = 2;
+ "},
+ expected: indoc! {"
+ fn main() {
+ let x = 1;
+ let y = 2;
+ }
+ "},
+ },
+ Case {
+ name: "insert_before_first",
+ original: indoc! {"
+ let x = 1;
+ let y = 2;
+ "},
+ model_output: indoc! {"
+ <|insert|>
+ use std::io;
+ "},
+ expected: indoc! {"
+ use std::io;
+ let x = 1;
+ let y = 2;
+ "},
+ },
+ Case {
+ name: "set_with_cursor_marker",
+ original: indoc! {"
+ fn main() {
+ println!();
+ }
+ "},
+ model_output: indoc! {"
+ <|set|>1:34
+ eprintln!(\"<|user_cursor|>\");
+ "},
+ expected: indoc! {"
+ fn main() {
+ eprintln!(\"<|user_cursor|>\");
+ }
+ "},
+ },
+ Case {
+ name: "multiple_set_commands",
+ original: indoc! {"
+ aaa
+ bbb
+ ccc
+ ddd
+ "},
+ model_output: indoc! {"
+ <|set|>0:23
+ AAA
+ <|set|>2:29
+ CCC
+ "},
+ expected: indoc! {"
+ AAA
+ bbb
+ CCC
+ ddd
+ "},
+ },
+ Case {
+ name: "set_range_multiline_replacement",
+ original: indoc! {"
+ fn handle_submit() {
+ }
+
+ fn handle_keystroke() {
+ "},
+ model_output: indoc! {"
+ <|set|>0:3f-1:7d
+ fn handle_submit(modal_state: &mut ModalState) {
+ <|user_cursor|>
+ }
+ "},
+ expected: indoc! {"
+ fn handle_submit(modal_state: &mut ModalState) {
+ <|user_cursor|>
+ }
+
+ fn handle_keystroke() {
+ "},
+ },
+ Case {
+ name: "no_edit_commands_returns_original",
+ original: indoc! {"
+ hello
+ world
+ "},
+ model_output: "some random text with no commands",
+ expected: indoc! {"
+ hello
+ world
+ "},
+ },
+ Case {
+ name: "wrong_hash_set_ignored",
+ original: indoc! {"
+ aaa
+ bbb
+ "},
+ model_output: indoc! {"
+ <|set|>0:ff
+ ZZZ
+ "},
+ expected: indoc! {"
+ aaa
+ bbb
+ "},
+ },
+ Case {
+ name: "insert_and_set_combined",
+ original: indoc! {"
+ alpha
+ beta
+ gamma
+ "},
+ model_output: indoc! {"
+ <|set|>0:06
+ ALPHA
+ <|insert|>1:9c
+ beta_extra
+ "},
+ expected: indoc! {"
+ ALPHA
+ beta
+ beta_extra
+ gamma
+ "},
+ },
+ Case {
+ name: "no_trailing_newline_preserved",
+ original: "hello\nworld",
+ model_output: indoc! {"
+ <|set|>0:14
+ HELLO
+ "},
+ expected: "HELLO\nworld",
+ },
+ Case {
+ name: "set_range_hash_mismatch_in_end_bound",
+ original: indoc! {"
+ one
+ two
+ three
+ "},
+ model_output: indoc! {"
+ <|set|>0:42-2:ff
+ ONE_TWO_THREE
+ "},
+ expected: indoc! {"
+ one
+ two
+ three
+ "},
+ },
+ Case {
+ name: "set_range_start_greater_than_end_ignored",
+ original: indoc! {"
+ a
+ b
+ c
+ "},
+ model_output: indoc! {"
+ <|set|>2:63-1:62
+ X
+ "},
+ expected: indoc! {"
+ a
+ b
+ c
+ "},
+ },
+ Case {
+ name: "insert_out_of_bounds_ignored",
+ original: indoc! {"
+ x
+ y
+ "},
+ model_output: indoc! {"
+ <|insert|>99:aa
+ z
+ "},
+ expected: indoc! {"
+ x
+ y
+ "},
+ },
+ Case {
+ name: "set_out_of_bounds_ignored",
+ original: indoc! {"
+ x
+ y
+ "},
+ model_output: indoc! {"
+ <|set|>99:aa
+ z
+ "},
+ expected: indoc! {"
+ x
+ y
+ "},
+ },
+ Case {
+ name: "malformed_set_command_ignored",
+ original: indoc! {"
+ alpha
+ beta
+ "},
+ model_output: indoc! {"
+ <|set|>not-a-line-ref
+ UPDATED
+ "},
+ expected: indoc! {"
+ alpha
+ beta
+ "},
+ },
+ Case {
+ name: "malformed_insert_hash_treated_as_before_first",
+ original: indoc! {"
+ alpha
+ beta
+ "},
+ model_output: indoc! {"
+ <|insert|>1:nothex
+ preamble
+ "},
+ expected: indoc! {"
+ preamble
+ alpha
+ beta
+ "},
+ },
+ Case {
+ name: "set_then_insert_same_target_orders_insert_after_replacement",
+ original: indoc! {"
+ cat
+ dog
+ "},
+ model_output: indoc! {"
+ <|set|>0:38
+ CAT
+ <|insert|>0:38
+ TAIL
+ "},
+ expected: indoc! {"
+ CAT
+ TAIL
+ dog
+ "},
+ },
+ Case {
+ name: "overlapping_set_ranges_last_wins",
+ original: indoc! {"
+ a
+ b
+ c
+ d
+ "},
+ model_output: indoc! {"
+ <|set|>0:61-2:63
+ FIRST
+ <|set|>1:62-3:64
+ SECOND
+ "},
+ expected: indoc! {"
+ FIRST
+ d
+ "},
+ },
+ Case {
+ name: "insert_before_first_and_after_line",
+ original: indoc! {"
+ a
+ b
+ "},
+ model_output: indoc! {"
+ <|insert|>
+ HEAD
+ <|insert|>0:61
+ MID
+ "},
+ expected: indoc! {"
+ HEAD
+ a
+ MID
+ b
+ "},
+ },
+ ];
+
+ for case in &cases {
+ let result = hashline::apply_edit_commands(case.original, &case.model_output);
+ assert_eq!(result, case.expected, "failed case: {}", case.name);
+ }
+ }
+
+ #[test]
+ fn test_output_has_edit_commands() {
+ assert!(hashline::output_has_edit_commands(&format!(
+ "{}0:ab\nnew",
+ SET_COMMAND_MARKER
+ )));
+ assert!(hashline::output_has_edit_commands(&format!(
+ "{}0:ab\nnew",
+ INSERT_COMMAND_MARKER
+ )));
+ assert!(hashline::output_has_edit_commands(&format!(
+ "some text\n{}1:cd\nstuff",
+ SET_COMMAND_MARKER
+ )));
+ assert!(!hashline::output_has_edit_commands("just plain text"));
+ assert!(!hashline::output_has_edit_commands("NO_EDITS"));
+ }
+
+ // ---- hashline::patch_to_edit_commands round-trip tests ----
+
+ #[test]
+ fn test_patch_to_edit_commands() {
+ struct Case {
+ name: &'static str,
+ old: &'static str,
+ patch: &'static str,
+ expected_new: &'static str,
+ }
+
+ let cases = [
+ Case {
+ name: "single_line_replacement",
+ old: indoc! {"
+ let mut total = 0;
+ for product in products {
+ total += ;
+ }
+ total
+ "},
+ patch: indoc! {"
+ @@ -1,5 +1,5 @@
+ let mut total = 0;
+ for product in products {
+ - total += ;
+ + total += product.price;
+ }
+ total
+ "},
+ expected_new: indoc! {"
+ let mut total = 0;
+ for product in products {
+ total += product.price;
+ }
+ total
+ "},
+ },
+ Case {
+ name: "multiline_replacement",
+ old: indoc! {"
+ fn foo() {
+ let x = 1;
+ let y = 2;
+ let z = 3;
+ }
+ "},
+ patch: indoc! {"
+ @@ -1,5 +1,3 @@
+ fn foo() {
+ - let x = 1;
+ - let y = 2;
+ - let z = 3;
+ + let sum = 1 + 2 + 3;
+ }
+ "},
+ expected_new: indoc! {"
+ fn foo() {
+ let sum = 1 + 2 + 3;
+ }
+ "},
+ },
+ Case {
+ name: "insertion",
+ old: indoc! {"
+ fn main() {
+ let x = 1;
+ }
+ "},
+ patch: indoc! {"
+ @@ -1,3 +1,4 @@
+ fn main() {
+ let x = 1;
+ + let y = 2;
+ }
+ "},
+ expected_new: indoc! {"
+ fn main() {
+ let x = 1;
+ let y = 2;
+ }
+ "},
+ },
+ Case {
+ name: "insertion_before_first",
+ old: indoc! {"
+ let x = 1;
+ let y = 2;
+ "},
+ patch: indoc! {"
+ @@ -1,2 +1,3 @@
+ +use std::io;
+ let x = 1;
+ let y = 2;
+ "},
+ expected_new: indoc! {"
+ use std::io;
+ let x = 1;
+ let y = 2;
+ "},
+ },
+ Case {
+ name: "deletion",
+ old: indoc! {"
+ aaa
+ bbb
+ ccc
+ ddd
+ "},
+ patch: indoc! {"
+ @@ -1,4 +1,2 @@
+ aaa
+ -bbb
+ -ccc
+ ddd
+ "},
+ expected_new: indoc! {"
+ aaa
+ ddd
+ "},
+ },
+ Case {
+ name: "multiple_changes",
+ old: indoc! {"
+ alpha
+ beta
+ gamma
+ delta
+ epsilon
+ "},
+ patch: indoc! {"
+ @@ -1,5 +1,5 @@
+ -alpha
+ +ALPHA
+ beta
+ gamma
+ -delta
+ +DELTA
+ epsilon
+ "},
+ expected_new: indoc! {"
+ ALPHA
+ beta
+ gamma
+ DELTA
+ epsilon
+ "},
+ },
+ Case {
+ name: "replace_with_insertion",
+ old: indoc! {r#"
+ fn handle() {
+ modal_state.close();
+ modal_state.dismiss();
+ "#},
+ patch: indoc! {r#"
+ @@ -1,3 +1,4 @@
+ fn handle() {
+ modal_state.close();
+ + eprintln!("");
+ modal_state.dismiss();
+ "#},
+ expected_new: indoc! {r#"
+ fn handle() {
+ modal_state.close();
+ eprintln!("");
+ modal_state.dismiss();
+ "#},
+ },
+ Case {
+ name: "complete_replacement",
+ old: indoc! {"
+ aaa
+ bbb
+ ccc
+ "},
+ patch: indoc! {"
+ @@ -1,3 +1,3 @@
+ -aaa
+ -bbb
+ -ccc
+ +xxx
+ +yyy
+ +zzz
+ "},
+ expected_new: indoc! {"
+ xxx
+ yyy
+ zzz
+ "},
+ },
+ Case {
+ name: "add_function_body",
+ old: indoc! {"
+ fn foo() {
+ modal_state.dismiss();
+ }
+
+ fn
+
+ fn handle_keystroke() {
+ "},
+ patch: indoc! {"
+ @@ -1,6 +1,8 @@
+ fn foo() {
+ modal_state.dismiss();
+ }
+
+ -fn
+ +fn handle_submit() {
+ + todo()
+ +}
+
+ fn handle_keystroke() {
+ "},
+ expected_new: indoc! {"
+ fn foo() {
+ modal_state.dismiss();
+ }
+
+ fn handle_submit() {
+ todo()
+ }
+
+ fn handle_keystroke() {
+ "},
+ },
+ Case {
+ name: "with_cursor_offset",
+ old: indoc! {r#"
+ fn main() {
+ println!();
+ }
+ "#},
+ patch: indoc! {r#"
+ @@ -1,3 +1,3 @@
+ fn main() {
+ - println!();
+ + eprintln!("");
+ }
+ "#},
+ expected_new: indoc! {r#"
+ fn main() {
+ eprintln!("<|user_cursor|>");
+ }
+ "#},
+ },
+ Case {
+ name: "non_local_hunk_header_pure_insertion_repro",
+ old: indoc! {"
+ aaa
+ bbb
+ "},
+ patch: indoc! {"
+ @@ -20,2 +20,3 @@
+ aaa
+ +xxx
+ bbb
+ "},
+ expected_new: indoc! {"
+ aaa
+ xxx
+ bbb
+ "},
+ },
+ ];
+
+ for case in &cases {
+ // The cursor_offset for patch_to_edit_commands is relative to
+ // the first hunk's new text (context + additions). We compute
+ // it by finding where the marker sits in the expected output
+ // (which mirrors the new text of the hunk).
+ let cursor_offset = case.expected_new.find(CURSOR_MARKER);
+
+ let commands =
+ hashline::patch_to_edit_commands(case.old, case.patch, cursor_offset)
+ .unwrap_or_else(|e| panic!("failed case {}: {e}", case.name));
+
+ assert!(
+ hashline::output_has_edit_commands(&commands),
+ "case {}: expected edit commands, got: {commands:?}",
+ case.name,
+ );
+
+ let applied = hashline::apply_edit_commands(case.old, &commands);
+ assert_eq!(applied, case.expected_new, "case {}", case.name);
+ }
+ }
+ }
+}
+
pub mod seed_coder {
//! Seed-Coder prompt format using SPM (Suffix-Prefix-Middle) FIM mode.
//!
@@ -1,2 +1,5 @@
# Handlebars partials are not supported by Prettier.
*.hbs
+
+# Automatically generated
+theme/c15t@*.js
@@ -64,6 +64,22 @@ This will render a human-readable version of the action name, e.g., "zed: open s
Templates are functions that modify the source of the docs pages (usually with a regex match and replace).
You can see how the actions and keybindings are templated in `crates/docs_preprocessor/src/main.rs` for reference on how to create new templates.
+## Consent Banner
+
+We pre-bundle the `c15t` package because the docs pipeline does not include a JS bundler. If you need to update `c15t` and rebuild the bundle, use:
+
+```
+mkdir c15t-bundle && cd c15t-bundle
+npm init -y
+npm install c15t@<version> esbuild
+echo "import { getOrCreateConsentRuntime } from 'c15t'; window.c15t = { getOrCreateConsentRuntime };" > entry.js
+npx esbuild entry.js --bundle --format=iife --minify --outfile=c15t@<version>.js
+cp c15t@<version>.js ../theme/c15t@<version>.js
+cd .. && rm -rf c15t-bundle
+```
+
+Replace `<version>` with the new version of `c15t` you are installing. Then update `book.toml` to reference the new bundle filename.
+
### References
- Template Trait: `crates/docs_preprocessor/src/templates.rs`
@@ -23,8 +23,8 @@ default-description = "Learn how to use and customize Zed, the fast, collaborati
default-title = "Zed Code Editor Documentation"
no-section-label = true
preferred-dark-theme = "dark"
-additional-css = ["theme/page-toc.css", "theme/plugins.css", "theme/highlight.css"]
-additional-js = ["theme/page-toc.js", "theme/plugins.js"]
+additional-css = ["theme/page-toc.css", "theme/plugins.css", "theme/highlight.css", "theme/consent-banner.css"]
+additional-js = ["theme/page-toc.js", "theme/plugins.js", "theme/c15t@2.0.0-rc.3.js", "theme/analytics.js"]
[output.zed-html.print]
enable = false
@@ -88,7 +88,7 @@ With that done, choose one of the three authentication methods:
While it's possible to configure through the Agent Panel settings UI by entering your AWS access key and secret directly, we recommend using named profiles instead for better security practices.
To do this:
-1. Create an IAM User that you can assume in the [IAM Console](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users).
+1. Create an IAM User in the [IAM Console](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users).
2. Create security credentials for that User, save them and keep them secure.
3. Open the Agent Configuration with (`agent: open settings`) and go to the Amazon Bedrock section
4. Copy the credentials from Step 2 into the respective **Access Key ID**, **Secret Access Key**, and **Region** fields.
@@ -91,6 +91,6 @@ Executes shell commands and returns the combined output, creating a new shell pr
## Other Tools
-### `subagent`
+### `spawn_agent`
-Spawns a subagent with its own context window to perform a delegated task. Useful for running parallel investigations, completing self-contained tasks, or performing research where only the outcome matters. Each subagent has access to the same tools as the parent agent.
+Spawns a subagent with its own context window to perform a delegated task. Each subagent has access to the same tools as the parent agent.
@@ -122,11 +122,40 @@ You can specify your preference using the `language_servers` setting:
In this example:
-- `intelephense` is set as the primary language server
-- `phpactor` is disabled (note the `!` prefix)
-- `...` expands to the rest of the language servers that are registered for PHP
+- `intelephense` is set as the primary language server.
+- `phpactor` and `phptools` are disabled (note the `!` prefix).
+- `"..."` expands to the rest of the language servers registered for PHP that are not already listed.
-This configuration allows you to tailor the language server setup to your specific needs, ensuring that you get the most suitable functionality for your development workflow.
+The `"..."` entry acts as a wildcard that includes any registered language server you haven't explicitly mentioned. Servers you list by name keep their position, and `"..."` fills in the remaining ones at that point in the list. Servers prefixed with `!` are excluded entirely. This means that if a new language server extension is installed or a new server is registered for a language, `"..."` will automatically include it. If you want full control over which servers are enabled, omit `"..."` — only the servers you list by name will be used.
+
+#### Examples
+
+Suppose you're working with Ruby. The default configuration is:
+
+```json [settings]
+{
+ "language_servers": [
+ "solargraph",
+ "!ruby-lsp",
+ "!rubocop",
+ "!sorbet",
+ "!steep",
+ "!kanayago",
+ "..."
+ ]
+}
+```
+
+When you override `language_servers` in your settings, your list **replaces** the default entirely. This means default-disabled servers like `kanayago` will be re-enabled by `"..."` unless you explicitly disable them again.
+
+| Configuration | Result |
+| ------------------------------------------------- | ------------------------------------------------------------------ |
+| `["..."]` | `solargraph`, `ruby-lsp`, `rubocop`, `sorbet`, `steep`, `kanayago` |
+| `["ruby-lsp", "..."]` | `ruby-lsp`, `solargraph`, `rubocop`, `sorbet`, `steep`, `kanayago` |
+| `["ruby-lsp", "!solargraph", "!kanayago", "..."]` | `ruby-lsp`, `rubocop`, `sorbet`, `steep` |
+| `["ruby-lsp", "solargraph"]` | `ruby-lsp`, `solargraph` |
+
+> Note: In the first example, `"..."` includes `kanayago` even though it is disabled by default. The override replaced the default list, so the `"!kanayago"` entry is no longer present. To keep it disabled, you must include `"!kanayago"` in your configuration.
### Toolchains
@@ -89,8 +89,8 @@ Configure language servers in Settings ({#kb zed::OpenSettings}) under Languages
"languages": {
"Python": {
"language_servers": [
- // Disable basedpyright and enable ty, and otherwise
- // use the default configuration.
+ // Disable basedpyright and enable ty, and include all
+ // other registered language servers (ruff, pylsp, pyright).
"ty",
"!basedpyright",
"..."
@@ -78,7 +78,7 @@ Download the importer
- `cd import && mkdir build && cd build`
- Run cmake to generate build files: `cmake -G Ninja -DCMAKE_BUILD_TYPE=Release ..`
- Build the importer: `ninja`
-- Run the importer on the trace file: `./tracy-import-miniprofiler /path/to/trace.miniprof /path/to/output.tracy`
+- Run the importer on the trace file: `./tracy-import-miniprofiler /path/to/trace.miniprof.json /path/to/output.tracy`
- Open the trace in tracy:
- If you're on windows download the v0.12.2 version from the releases on the upstream repo
- If you're on other platforms open it on the website: https://tracy.nereid.pl/ (the version might mismatch so your luck might vary, we need to host our own ideally..)
@@ -87,7 +87,7 @@ Download the importer
- Run the action: `zed open performance profiler`
- Hit the save button. This opens a save dialog or if that fails to open the trace gets saved in your working directory.
-- Convert the profile so it can be imported in tracy using the importer: `./tracy-import-miniprofiler <path to performance_profile.miniprof> output.tracy`
+- Convert the profile so it can be imported in tracy using the importer: `./tracy-import-miniprofiler <path to performance_profile.miniprof.json> output.tracy`
- Go to <https://tracy.nereid.pl/> hit the 'power button' in the top left and then open saved trace.
- Now zoom in to see the tasks and how long they took
@@ -0,0 +1,93 @@
+const amplitudeKey = document.querySelector(
+ 'meta[name="amplitude-key"]',
+)?.content;
+const consentInstance = document.querySelector(
+ 'meta[name="consent-io-instance"]',
+)?.content;
+
+document.addEventListener("DOMContentLoaded", () => {
+ if (!consentInstance || consentInstance.length === 0) return;
+ const { getOrCreateConsentRuntime } = window.c15t;
+
+ const { consentStore } = getOrCreateConsentRuntime({
+ mode: "c15t",
+ backendURL: consentInstance,
+ consentCategories: ["necessary", "measurement", "marketing"],
+ storageConfig: {
+ crossSubdomain: true,
+ },
+ scripts: [
+ {
+ id: "amplitude",
+ src: `https://cdn.amplitude.com/script/${amplitudeKey}.js`,
+ category: "measurement",
+ onLoad: () => {
+ window.amplitude.init(amplitudeKey, {
+ fetchRemoteConfig: true,
+ autocapture: true,
+ });
+ },
+ },
+ ],
+ });
+
+ let previousActiveUI = consentStore.getState().activeUI;
+ const banner = document.getElementById("c15t-banner");
+ const configureSection = document.getElementById("c15t-configure-section");
+ const configureBtn = document.getElementById("c15t-configure-btn");
+ const measurementToggle = document.getElementById("c15t-toggle-measurement");
+ const marketingToggle = document.getElementById("c15t-toggle-marketing");
+
+ const toggleConfigureMode = () => {
+ const currentConsents = consentStore.getState().consents;
+ measurementToggle.checked = currentConsents
+ ? (currentConsents.measurement ?? false)
+ : false;
+ marketingToggle.checked = currentConsents
+ ? (currentConsents.marketing ?? false)
+ : false;
+ configureSection.style.display = "flex";
+ configureBtn.innerHTML = "Save";
+ configureBtn.className = "c15t-button secondary";
+ configureBtn.title = "";
+ };
+
+ consentStore.subscribe((state) => {
+ const hideBanner =
+ state.activeUI === "none" ||
+ (state.activeUI === "banner" && state.mode === "opt-out");
+ banner.style.display = hideBanner ? "none" : "block";
+
+ if (state.activeUI === "dialog" && previousActiveUI !== "dialog") {
+ toggleConfigureMode();
+ }
+
+ previousActiveUI = state.activeUI;
+ });
+
+ configureBtn.addEventListener("click", () => {
+ if (consentStore.getState().activeUI === "dialog") {
+ consentStore
+ .getState()
+ .setConsent("measurement", measurementToggle.checked);
+ consentStore.getState().setConsent("marketing", marketingToggle.checked);
+ consentStore.getState().saveConsents("custom");
+ } else {
+ consentStore.getState().setActiveUI("dialog");
+ }
+ });
+
+ document.getElementById("c15t-accept").addEventListener("click", () => {
+ consentStore.getState().saveConsents("all");
+ });
+
+ document.getElementById("c15t-decline").addEventListener("click", () => {
+ consentStore.getState().saveConsents("necessary");
+ });
+
+ document
+ .getElementById("c15t-manage-consent-btn")
+ .addEventListener("click", () => {
+ consentStore.getState().setActiveUI("dialog");
+ });
+});
@@ -0,0 +1 @@
@@ -0,0 +1,292 @@
+#c15t-banner {
+ --color-offgray-50: hsl(218, 12%, 95%);
+ --color-offgray-100: hsl(218, 12%, 88%);
+ --color-offgray-200: hsl(218, 12%, 80%);
+ --color-offgray-300: hsl(218, 12%, 75%);
+ --color-offgray-400: hsl(218, 12%, 64%);
+ --color-offgray-500: hsl(218, 12%, 56%);
+ --color-offgray-600: hsl(218, 12%, 48%);
+ --color-offgray-700: hsl(218, 12%, 40%);
+ --color-offgray-800: hsl(218, 12%, 34%);
+ --color-offgray-900: hsl(218, 12%, 24%);
+ --color-offgray-950: hsl(218, 12%, 15%);
+ --color-offgray-1000: hsl(218, 12%, 5%);
+
+ --color-blue-50: oklch(97% 0.014 254.604);
+ --color-blue-100: oklch(93.2% 0.032 255.585);
+ --color-blue-200: oklch(88.2% 0.059 254.128);
+ --color-blue-300: oklch(80.9% 0.105 251.813);
+ --color-blue-400: oklch(70.7% 0.165 254.624);
+ --color-blue-500: oklch(62.3% 0.214 259.815);
+ --color-blue-600: oklch(54.6% 0.245 262.881);
+ --color-blue-700: oklch(48.8% 0.243 264.376);
+ --color-blue-800: oklch(42.4% 0.199 265.638);
+ --color-blue-900: oklch(37.9% 0.146 265.522);
+ --color-blue-950: oklch(28.2% 0.091 267.935);
+
+ --color-accent-blue: hsla(218, 93%, 42%, 1);
+
+ position: fixed;
+ z-index: 9999;
+ bottom: 16px;
+ right: 16px;
+ border-radius: 4px;
+ max-width: 300px;
+ background: white;
+ border: 1px solid
+ color-mix(in oklab, var(--color-offgray-200) 50%, transparent);
+ box-shadow: 6px 6px 0
+ color-mix(in oklab, var(--color-accent-blue) 6%, transparent);
+}
+
+.dark #c15t-banner {
+ border-color: color-mix(in oklab, var(--color-offgray-600) 14%, transparent);
+ background: var(--color-offgray-1000);
+ box-shadow: 5px 5px 0
+ color-mix(in oklab, var(--color-accent-blue) 8%, transparent);
+}
+
+#c15t-banner > div:first-child {
+ padding: 12px;
+ display: flex;
+ flex-direction: column;
+}
+
+#c15t-banner a {
+ color: var(--links);
+ text-decoration: underline;
+ text-decoration-color: var(--link-line-decoration);
+}
+
+#c15t-banner a:hover {
+ text-decoration-color: var(--link-line-decoration-hover);
+}
+
+#c15t-description {
+ font-size: 12px;
+ margin: 0;
+ margin-top: 4px;
+}
+
+#c15t-configure-section {
+ display: flex;
+ flex-direction: column;
+ gap: 8px;
+ border-top: 1px solid var(--divider);
+ padding: 12px;
+}
+
+#c15t-configure-section > div {
+ display: flex;
+ align-items: center;
+ justify-content: space-between;
+}
+
+#c15t-configure-section label {
+ text-transform: uppercase;
+ font-size: 11px;
+}
+
+#c15t-footer {
+ padding: 12px;
+ display: flex;
+ justify-content: space-between;
+ border-top: 1px solid var(--divider);
+ background-color: color-mix(
+ in oklab,
+ var(--color-offgray-50) 50%,
+ transparent
+ );
+}
+
+.dark #c15t-footer {
+ background-color: color-mix(
+ in oklab,
+ var(--color-offgray-600) 4%,
+ transparent
+ );
+}
+
+.c15t-button {
+ display: inline-flex;
+ align-items: center;
+ justify-content: center;
+ max-height: 28px;
+ color: black;
+ padding: 4px 8px;
+ font-size: 14px;
+ border-radius: 4px;
+ background: transparent;
+ border: 1px solid transparent;
+ transition: 100ms;
+ transition-property: box-shadow, border-color, background-color;
+}
+
+.c15t-button:hover {
+ background: color-mix(in oklab, var(--color-offgray-100) 50%, transparent);
+}
+
+.dark .c15t-button {
+ color: var(--color-offgray-50);
+}
+
+.dark .c15t-button:hover {
+ background: color-mix(in oklab, var(--color-offgray-500) 10%, transparent);
+}
+
+.c15t-button.icon {
+ padding: 0;
+ width: 24px;
+ height: 24px;
+}
+
+.c15t-button.primary {
+ color: var(--color-blue-700);
+ background: color-mix(in oklab, var(--color-blue-50) 60%, transparent);
+ border-color: color-mix(in oklab, var(--color-blue-500) 20%, transparent);
+ box-shadow: color-mix(in oklab, var(--color-blue-400) 10%, transparent) 0 -2px
+ 0 0 inset;
+}
+
+.c15t-button.primary:hover {
+ background: color-mix(in oklab, var(--color-blue-100) 50%, transparent);
+ box-shadow: none;
+}
+
+.dark .c15t-button.primary {
+ color: var(--color-blue-50);
+ background: color-mix(in oklab, var(--color-blue-500) 10%, transparent);
+ border-color: color-mix(in oklab, var(--color-blue-300) 10%, transparent);
+ box-shadow: color-mix(in oklab, var(--color-blue-300) 8%, transparent) 0 -2px
+ 0 0 inset;
+}
+
+.dark .c15t-button.primary:hover {
+ background: color-mix(in oklab, var(--color-blue-500) 20%, transparent);
+ box-shadow: none;
+}
+
+.c15t-button.secondary {
+ background: color-mix(in oklab, var(--color-offgray-50) 60%, transparent);
+ border-color: color-mix(in oklab, var(--color-offgray-200) 50%, transparent);
+ box-shadow: color-mix(in oklab, var(--color-offgray-500) 10%, transparent)
+ 0 -2px 0 0 inset;
+}
+
+.c15t-button.secondary:hover {
+ background: color-mix(in oklab, var(--color-offgray-100) 50%, transparent);
+ box-shadow: none;
+}
+
+.dark .c15t-button.secondary {
+ background: color-mix(in oklab, var(--color-offgray-300) 5%, transparent);
+ border-color: color-mix(in oklab, var(--color-offgray-400) 20%, transparent);
+ box-shadow: color-mix(in oklab, var(--color-offgray-300) 8%, transparent)
+ 0 -2px 0 0 inset;
+}
+
+.dark .c15t-button.secondary:hover {
+ background: color-mix(in oklab, var(--color-offgray-200) 10%, transparent);
+ box-shadow: none;
+}
+
+.c15t-switch {
+ position: relative;
+ display: inline-block;
+ width: 32px;
+ height: 20px;
+ flex-shrink: 0;
+}
+
+.c15t-switch input {
+ opacity: 0;
+ width: 0;
+ height: 0;
+ position: absolute;
+}
+
+.c15t-slider {
+ position: absolute;
+ cursor: pointer;
+ inset: 0;
+ background-color: color-mix(
+ in oklab,
+ var(--color-offgray-100) 80%,
+ transparent
+ );
+ border-radius: 20px;
+ box-shadow: inset 0 0 0 1px color-mix(in oklab, #000 5%, transparent);
+ transition: background-color 0.2s;
+}
+
+.c15t-slider:hover {
+ background-color: var(--color-offgray-100);
+}
+
+.dark .c15t-slider {
+ background-color: color-mix(in oklab, #fff 5%, transparent);
+ box-shadow: inset 0 0 0 1px color-mix(in oklab, #fff 15%, transparent);
+}
+
+.dark .c15t-slider:hover {
+ background-color: color-mix(in oklab, #fff 10%, transparent);
+}
+
+.c15t-slider:before {
+ position: absolute;
+ content: "";
+ height: 14px;
+ width: 14px;
+ left: 3px;
+ bottom: 3px;
+ background-color: white;
+ border-radius: 50%;
+ box-shadow:
+ 0 1px 3px 0 rgb(0 0 0 / 0.1),
+ 0 1px 2px -1px rgb(0 0 0 / 0.1);
+ transition: transform 0.2s;
+}
+
+.c15t-switch input:checked + .c15t-slider {
+ background-color: var(--color-accent-blue);
+ box-shadow: inset 0 0 0 1px color-mix(in oklab, #000 5%, transparent);
+}
+
+.c15t-switch input:checked + .c15t-slider:hover {
+ background-color: var(--color-accent-blue);
+}
+
+.dark .c15t-switch input:checked + .c15t-slider {
+ background-color: var(--color-accent-blue);
+ box-shadow: inset 0 0 0 1px color-mix(in oklab, #fff 15%, transparent);
+}
+
+.c15t-switch input:checked + .c15t-slider:before {
+ transform: translateX(12px);
+}
+
+.c15t-switch input:disabled + .c15t-slider {
+ opacity: 0.5;
+ cursor: default;
+ pointer-events: none;
+}
+
+.c15t-switch input:disabled + .c15t-slider:hover {
+ background-color: color-mix(
+ in oklab,
+ var(--color-offgray-100) 80%,
+ transparent
+ );
+}
+
+#c15t-manage-consent-btn {
+ appearance: none;
+ background: none;
+ border: none;
+ padding: 0;
+ cursor: pointer;
+}
+
+#c15t-manage-consent-btn:hover {
+ text-decoration-color: var(--link-line-decoration-hover);
+}
@@ -70,6 +70,8 @@
<!-- MathJax -->
<script async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
{{/if}}
+ <meta name="amplitude-key" content="#amplitude_key#" />
+ <meta name="consent-io-instance" content="#consent_io_instance#" />
</head>
<body class="no-js">
<div id="body-container">
@@ -343,6 +345,13 @@
href="https://zed.dev/blog"
>Blog</a
>
+ <span class="footer-separator">•</span>
+ <button
+ id="c15t-manage-consent-btn"
+ class="footer-link"
+ >
+ Manage Site Cookies
+ </button>
</footer>
</main>
<div class="toc-container">
@@ -444,23 +453,82 @@
{{/if}}
{{/if}}
- <!-- Amplitude Analytics -->
- <script>
- (function() {
- var amplitudeKey = '#amplitude_key#';
- if (amplitudeKey && amplitudeKey.indexOf('#') === -1) {
- var script = document.createElement('script');
- script.src = 'https://cdn.amplitude.com/script/' + amplitudeKey + '.js';
- script.onload = function() {
- window.amplitude.init(amplitudeKey, {
- fetchRemoteConfig: true,
- autocapture: true
- });
- };
- document.head.appendChild(script);
- }
- })();
- </script>
+ <!-- c15t Consent Banner -->
+ <div id="c15t-banner" style="display: none;">
+ <div>
+ <p id="c15t-description">
+ Zed uses cookies to improve your experience and for marketing. Read <a href="https://zed.dev/cookie-policy">our cookie policy</a> for more details.
+ </p>
+ </div>
+ <div id="c15t-configure-section" style="display: none">
+ <div>
+ <label for="c15t-toggle-necessary"
+ >Strictly Necessary</label
+ >
+ <label class="c15t-switch">
+ <input
+ type="checkbox"
+ id="c15t-toggle-necessary"
+ checked
+ disabled
+ />
+ <span class="c15t-slider"></span>
+ </label>
+ </div>
+ <div>
+ <label for="c15t-toggle-measurement">Analytics</label>
+ <label class="c15t-switch">
+ <input
+ type="checkbox"
+ id="c15t-toggle-measurement"
+ />
+ <span class="c15t-slider"></span>
+ </label>
+ </div>
+ <div>
+ <label for="c15t-toggle-marketing">Marketing</label>
+ <label class="c15t-switch">
+ <input
+ type="checkbox"
+ id="c15t-toggle-marketing"
+ />
+ <span class="c15t-slider"></span>
+ </label>
+ </div>
+ </div>
+ <div id="c15t-footer">
+ <button
+ id="c15t-configure-btn"
+ class="c15t-button icon"
+ title="Configure"
+ >
+ <svg
+ xmlns="http://www.w3.org/2000/svg"
+ width="14"
+ height="14"
+ viewBox="0 0 24 24"
+ fill="none"
+ stroke="currentColor"
+ stroke-width="2"
+ stroke-linecap="round"
+ stroke-linejoin="round"
+ >
+ <path d="M20 7h-9" />
+ <path d="M14 17H5" />
+ <circle cx="17" cy="17" r="3" />
+ <circle cx="7" cy="7" r="3" />
+ </svg>
+ </button>
+ <div>
+ <button id="c15t-decline" class="c15t-button">
+ Reject all
+ </button>
+ <button id="c15t-accept" class="c15t-button primary">
+ Accept all
+ </button>
+ </div>
+ </div>
+ </div>
</div>
</body>
</html>
@@ -42,6 +42,8 @@ extend-exclude = [
"crates/gpui_windows/src/window.rs",
# Some typos in the base mdBook CSS.
"docs/theme/css/",
+ # Automatically generated JS.
+ "docs/theme/c15t@*.js",
# Spellcheck triggers on `|Fixe[sd]|` regex part.
"script/danger/dangerfile.ts",
# Eval examples for prompts and criteria