@@ -35,12 +35,13 @@ use language_model::{LlmApiToken, RefreshLlmTokenListener};
use project::{Project, ProjectPath};
use release_channel::AppVersion;
use settings::WorktreeId;
+use std::collections::hash_map;
+use std::mem;
use std::str::FromStr;
use std::{
cmp,
fmt::Write,
future::Future,
- mem,
ops::Range,
path::Path,
rc::Rc,
@@ -211,9 +212,8 @@ impl std::fmt::Debug for EditPrediction {
}
pub struct Zeta {
+ projects: HashMap<EntityId, ZetaProject>,
client: Arc<Client>,
- events: VecDeque<Event>,
- registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
shown_completions: VecDeque<EditPrediction>,
rated_completions: HashSet<EditPredictionId>,
data_collection_choice: Entity<DataCollectionChoice>,
@@ -225,6 +225,11 @@ pub struct Zeta {
license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
}
+struct ZetaProject {
+ events: VecDeque<Event>,
+ registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
+}
+
impl Zeta {
pub fn global(cx: &mut App) -> Option<Entity<Self>> {
cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
@@ -255,7 +260,9 @@ impl Zeta {
}
pub fn clear_history(&mut self) {
- self.events.clear();
+ for zeta_project in self.projects.values_mut() {
+ zeta_project.events.clear();
+ }
}
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
@@ -269,11 +276,10 @@ impl Zeta {
let data_collection_choice = cx.new(|_| data_collection_choice);
Self {
+ projects: HashMap::default(),
client,
- events: VecDeque::new(),
shown_completions: VecDeque::new(),
rated_completions: HashSet::default(),
- registered_buffers: HashMap::default(),
data_collection_choice,
llm_token: LlmApiToken::default(),
_llm_token_subscription: cx.subscribe(
@@ -294,12 +300,35 @@ impl Zeta {
}
}
- fn push_event(&mut self, event: Event) {
+ fn get_or_init_zeta_project(
+ &mut self,
+ project: &Entity<Project>,
+ cx: &mut Context<Self>,
+ ) -> &mut ZetaProject {
+ let project_id = project.entity_id();
+ match self.projects.entry(project_id) {
+ hash_map::Entry::Occupied(entry) => entry.into_mut(),
+ hash_map::Entry::Vacant(entry) => {
+ cx.observe_release(project, move |this, _, _cx| {
+ this.projects.remove(&project_id);
+ })
+ .detach();
+ entry.insert(ZetaProject {
+ events: VecDeque::with_capacity(MAX_EVENT_COUNT),
+ registered_buffers: HashMap::default(),
+ })
+ }
+ }
+ }
+
+ fn push_event(zeta_project: &mut ZetaProject, event: Event) {
+ let events = &mut zeta_project.events;
+
if let Some(Event::BufferChange {
new_snapshot: last_new_snapshot,
timestamp: last_timestamp,
..
- }) = self.events.back_mut()
+ }) = events.back_mut()
{
// Coalesce edits for the same buffer when they happen one after the other.
let Event::BufferChange {
@@ -318,50 +347,65 @@ impl Zeta {
}
}
- self.events.push_back(event);
- if self.events.len() >= MAX_EVENT_COUNT {
+ if events.len() >= MAX_EVENT_COUNT {
// These are halved instead of popping to improve prompt caching.
- self.events.drain(..MAX_EVENT_COUNT / 2);
+ events.drain(..MAX_EVENT_COUNT / 2);
}
- }
-
- pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
- let buffer_id = buffer.entity_id();
- let weak_buffer = buffer.downgrade();
-
- if let std::collections::hash_map::Entry::Vacant(entry) =
- self.registered_buffers.entry(buffer_id)
- {
- let snapshot = buffer.read(cx).snapshot();
- entry.insert(RegisteredBuffer {
- snapshot,
- _subscriptions: [
- cx.subscribe(buffer, move |this, buffer, event, cx| {
- this.handle_buffer_event(buffer, event, cx);
- }),
- cx.observe_release(buffer, move |this, _buffer, _cx| {
- this.registered_buffers.remove(&weak_buffer.entity_id());
- }),
- ],
- });
- };
+ events.push_back(event);
}
- fn handle_buffer_event(
+ pub fn register_buffer(
&mut self,
- buffer: Entity<Buffer>,
- event: &language::BufferEvent,
+ buffer: &Entity<Buffer>,
+ project: &Entity<Project>,
cx: &mut Context<Self>,
) {
- if let language::BufferEvent::Edited = event {
- self.report_changes_for_buffer(&buffer, cx);
+ let zeta_project = self.get_or_init_zeta_project(project, cx);
+ Self::register_buffer_impl(zeta_project, buffer, project, cx);
+ }
+
+ fn register_buffer_impl<'a>(
+ zeta_project: &'a mut ZetaProject,
+ buffer: &Entity<Buffer>,
+ project: &Entity<Project>,
+ cx: &mut Context<Self>,
+ ) -> &'a mut RegisteredBuffer {
+ let buffer_id = buffer.entity_id();
+ match zeta_project.registered_buffers.entry(buffer_id) {
+ hash_map::Entry::Occupied(entry) => entry.into_mut(),
+ hash_map::Entry::Vacant(entry) => {
+ let snapshot = buffer.read(cx).snapshot();
+ let project_entity_id = project.entity_id();
+ entry.insert(RegisteredBuffer {
+ snapshot,
+ _subscriptions: [
+ cx.subscribe(buffer, {
+ let project = project.downgrade();
+ move |this, buffer, event, cx| {
+ if let language::BufferEvent::Edited = event
+ && let Some(project) = project.upgrade()
+ {
+ this.report_changes_for_buffer(&buffer, &project, cx);
+ }
+ }
+ }),
+ cx.observe_release(buffer, move |this, _buffer, _cx| {
+ let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
+ else {
+ return;
+ };
+ zeta_project.registered_buffers.remove(&buffer_id);
+ }),
+ ],
+ })
+ }
}
}
fn request_completion_impl<F, R>(
&mut self,
- project: Option<&Entity<Project>>,
+ project: &Entity<Project>,
buffer: &Entity<Buffer>,
cursor: language::Anchor,
can_collect_data: bool,
@@ -376,16 +420,14 @@ impl Zeta {
{
let buffer = buffer.clone();
let buffer_snapshotted_at = Instant::now();
- let snapshot = self.report_changes_for_buffer(&buffer, cx);
+ let snapshot = self.report_changes_for_buffer(&buffer, project, cx);
let zeta = cx.entity();
- let events = self.events.clone();
+ let events = self.get_or_init_zeta_project(project, cx).events.clone();
let client = self.client.clone();
let llm_token = self.llm_token.clone();
let app_version = AppVersion::global(cx);
- let git_info = if let (true, Some(project), Some(file)) =
- (can_collect_data, project, snapshot.file())
- {
+ let git_info = if let (true, Some(file)) = (can_collect_data, snapshot.file()) {
git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
} else {
None
@@ -512,163 +554,10 @@ impl Zeta {
})
}
- // Generates several example completions of various states to fill the Zeta completion modal
- #[cfg(any(test, feature = "test-support"))]
- pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
- use language::Point;
-
- let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
- And maybe a short line
-
- Then a few lines
-
- and then another
- "#};
-
- let project = None;
- let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
- let position = buffer.read(cx).anchor_before(Point::new(1, 0));
-
- let completion_tasks = vec![
- self.fake_completion(
- project,
- &buffer,
- position,
- PredictEditsResponse {
- request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(),
- output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
-a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
-[here's an edit]
-And maybe a short line
-Then a few lines
-and then another
-{EDITABLE_REGION_END_MARKER}
- ", ),
- },
- cx,
- ),
- self.fake_completion(
- project,
- &buffer,
- position,
- PredictEditsResponse {
- request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(),
- output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
-a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
-And maybe a short line
-[and another edit]
-Then a few lines
-and then another
-{EDITABLE_REGION_END_MARKER}
- "#),
- },
- cx,
- ),
- self.fake_completion(
- project,
- &buffer,
- position,
- PredictEditsResponse {
- request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(),
- output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
-a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
-And maybe a short line
-
-Then a few lines
-
-and then another
-{EDITABLE_REGION_END_MARKER}
- "#),
- },
- cx,
- ),
- self.fake_completion(
- project,
- &buffer,
- position,
- PredictEditsResponse {
- request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(),
- output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
-a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
-And maybe a short line
-
-Then a few lines
-
-and then another
-{EDITABLE_REGION_END_MARKER}
- "#),
- },
- cx,
- ),
- self.fake_completion(
- project,
- &buffer,
- position,
- PredictEditsResponse {
- request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(),
- output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
-a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
-And maybe a short line
-Then a few lines
-[a third completion]
-and then another
-{EDITABLE_REGION_END_MARKER}
- "#),
- },
- cx,
- ),
- self.fake_completion(
- project,
- &buffer,
- position,
- PredictEditsResponse {
- request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(),
- output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
-a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
-And maybe a short line
-and then another
-[fourth completion example]
-{EDITABLE_REGION_END_MARKER}
- "#),
- },
- cx,
- ),
- self.fake_completion(
- project,
- &buffer,
- position,
- PredictEditsResponse {
- request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(),
- output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
-a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
-And maybe a short line
-Then a few lines
-and then another
-[fifth and final completion]
-{EDITABLE_REGION_END_MARKER}
- "#),
- },
- cx,
- ),
- ];
-
- cx.spawn(async move |zeta, cx| {
- for task in completion_tasks {
- task.await.unwrap();
- }
-
- zeta.update(cx, |zeta, _cx| {
- zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]);
- zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]);
- })
- .ok();
- })
- }
-
#[cfg(any(test, feature = "test-support"))]
pub fn fake_completion(
&mut self,
- project: Option<&Entity<Project>>,
+ project: &Entity<Project>,
buffer: &Entity<Buffer>,
position: language::Anchor,
response: PredictEditsResponse,
@@ -683,7 +572,7 @@ and then another
pub fn request_completion(
&mut self,
- project: Option<&Entity<Project>>,
+ project: &Entity<Project>,
buffer: &Entity<Buffer>,
position: language::Anchor,
can_collect_data: bool,
@@ -1043,23 +932,23 @@ and then another
fn report_changes_for_buffer(
&mut self,
buffer: &Entity<Buffer>,
+ project: &Entity<Project>,
cx: &mut Context<Self>,
) -> BufferSnapshot {
- self.register_buffer(buffer, cx);
+ let zeta_project = self.get_or_init_zeta_project(project, cx);
+ let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
- let registered_buffer = self
- .registered_buffers
- .get_mut(&buffer.entity_id())
- .unwrap();
let new_snapshot = buffer.read(cx).snapshot();
-
if new_snapshot.version != registered_buffer.snapshot.version {
let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
- self.push_event(Event::BufferChange {
- old_snapshot,
- new_snapshot: new_snapshot.clone(),
- timestamp: Instant::now(),
- });
+ Self::push_event(
+ zeta_project,
+ Event::BufferChange {
+ old_snapshot,
+ new_snapshot: new_snapshot.clone(),
+ timestamp: Instant::now(),
+ },
+ );
}
new_snapshot
@@ -1140,7 +1029,7 @@ pub struct GatherContextOutput {
}
pub fn gather_context(
- project: Option<&Entity<Project>>,
+ project: &Entity<Project>,
full_path_str: String,
snapshot: &BufferSnapshot,
cursor_point: language::Point,
@@ -1149,8 +1038,7 @@ pub fn gather_context(
git_info: Option<PredictEditsGitInfo>,
cx: &App,
) -> Task<Result<GatherContextOutput>> {
- let local_lsp_store =
- project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
+ let local_lsp_store = project.read(cx).lsp_store().read(cx).as_local();
let diagnostic_groups: Vec<(String, serde_json::Value)> =
if can_collect_data && let Some(local_lsp_store) = local_lsp_store {
snapshot
@@ -1540,6 +1428,9 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
if self.zeta.read(cx).update_required {
return;
}
+ let Some(project) = project else {
+ return;
+ };
if self
.zeta
@@ -1578,13 +1469,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
let completion_request = this.update(cx, |this, cx| {
this.last_request_timestamp = Instant::now();
this.zeta.update(cx, |zeta, cx| {
- zeta.request_completion(
- project.as_ref(),
- &buffer,
- position,
- can_collect_data,
- cx,
- )
+ zeta.request_completion(&project, &buffer, position, can_collect_data, cx)
})
});
@@ -1762,7 +1647,6 @@ fn tokens_for_bytes(bytes: usize) -> usize {
#[cfg(test)]
mod tests {
- use client::UserStore;
use client::test::FakeServer;
use clock::FakeSystemClock;
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
@@ -1771,6 +1655,7 @@ mod tests {
use indoc::indoc;
use language::Point;
use settings::SettingsStore;
+ use util::path;
use super::*;
@@ -1897,6 +1782,7 @@ mod tests {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
client::init_settings(cx);
+ Project::init_settings(cx);
});
let edits = edits_for_prediction(
@@ -1961,6 +1847,7 @@ mod tests {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
client::init_settings(cx);
+ Project::init_settings(cx);
});
let buffer_content = "lorem\n";
@@ -2010,13 +1897,14 @@ mod tests {
});
// Construct the fake server to authenticate.
let _server = FakeServer::for_client(42, &client, cx).await;
- let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- let zeta = cx.new(|cx| Zeta::new(client, user_store.clone(), cx));
-
+ let fs = project::FakeFs::new(cx.executor());
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
+
+ let zeta = cx.new(|cx| Zeta::new(client, project.read(cx).user_store(), cx));
let completion_task = zeta.update(cx, |zeta, cx| {
- zeta.request_completion(None, &buffer, cursor, false, cx)
+ zeta.request_completion(&project, &buffer, cursor, false, cx)
});
let completion = completion_task.await.unwrap().unwrap();
@@ -2074,14 +1962,15 @@ mod tests {
});
// Construct the fake server to authenticate.
let _server = FakeServer::for_client(42, &client, cx).await;
- let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- let zeta = cx.new(|cx| Zeta::new(client, user_store.clone(), cx));
-
+ let fs = project::FakeFs::new(cx.executor());
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
+
+ let zeta = cx.new(|cx| Zeta::new(client, project.read(cx).user_store(), cx));
let completion_task = zeta.update(cx, |zeta, cx| {
- zeta.request_completion(None, &buffer, cursor, false, cx)
+ zeta.request_completion(&project, &buffer, cursor, false, cx)
});
let completion = completion_task.await.unwrap().unwrap();
@@ -10,7 +10,7 @@ use language::Bias;
use language::Buffer;
use language::Point;
use language_model::LlmApiToken;
-use project::{Project, ProjectPath};
+use project::{Project, ProjectPath, Worktree};
use release_channel::AppVersion;
use reqwest_client::ReqwestClient;
use std::path::{Path, PathBuf};
@@ -129,15 +129,33 @@ async fn get_context(
return Err(anyhow!("Absolute paths are not supported in --cursor"));
}
- let (project, _lsp_open_handle, buffer) = if use_language_server {
- let (project, lsp_open_handle, buffer) =
- open_buffer_with_language_server(&worktree_path, &cursor.path, app_state, cx).await?;
- (Some(project), Some(lsp_open_handle), buffer)
+ let project = cx.update(|cx| {
+ Project::local(
+ app_state.client.clone(),
+ app_state.node_runtime.clone(),
+ app_state.user_store.clone(),
+ app_state.languages.clone(),
+ app_state.fs.clone(),
+ None,
+ cx,
+ )
+ })?;
+
+ let worktree = project
+ .update(cx, |project, cx| {
+ project.create_worktree(&worktree_path, true, cx)
+ })?
+ .await?;
+
+ let (_lsp_open_handle, buffer) = if use_language_server {
+ let (lsp_open_handle, buffer) =
+ open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
+ (Some(lsp_open_handle), buffer)
} else {
let abs_path = worktree_path.join(&cursor.path);
let content = smol::fs::read_to_string(&abs_path).await?;
let buffer = cx.new(|cx| Buffer::local(content, cx))?;
- (None, None, buffer)
+ (None, buffer)
};
let worktree_name = worktree_path
@@ -177,7 +195,7 @@ async fn get_context(
let mut gather_context_output = cx
.update(|cx| {
gather_context(
- project.as_ref(),
+ &project,
full_path_str,
&snapshot,
clipped_cursor,
@@ -198,29 +216,11 @@ async fn get_context(
}
pub async fn open_buffer_with_language_server(
- worktree_path: &Path,
+ project: &Entity<Project>,
+ worktree: &Entity<Worktree>,
path: &Path,
- app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
-) -> Result<(Entity<Project>, Entity<Entity<Buffer>>, Entity<Buffer>)> {
- let project = cx.update(|cx| {
- Project::local(
- app_state.client.clone(),
- app_state.node_runtime.clone(),
- app_state.user_store.clone(),
- app_state.languages.clone(),
- app_state.fs.clone(),
- None,
- cx,
- )
- })?;
-
- let worktree = project
- .update(cx, |project, cx| {
- project.create_worktree(worktree_path, true, cx)
- })?
- .await?;
-
+) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
worktree_id: worktree.id(),
path: path.to_path_buf().into(),
@@ -237,7 +237,7 @@ pub async fn open_buffer_with_language_server(
let log_prefix = path.to_string_lossy().to_string();
wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
- Ok((project, lsp_open_handle, buffer))
+ Ok((lsp_open_handle, buffer))
}
// TODO: Dedupe with similar function in crates/eval/src/instance.rs