@@ -39,12 +39,13 @@ use multi_buffer::MultiBufferPoint;
use project::{Project, ProjectPath};
use release_channel::AppVersion;
use settings::WorktreeId;
+use std::collections::hash_map;
+use std::mem;
use std::str::FromStr;
use std::{
cmp,
fmt::Write,
future::Future,
- mem,
ops::Range,
path::Path,
rc::Rc,
@@ -55,6 +56,7 @@ use telemetry_events::EditPredictionRating;
use thiserror::Error;
use util::{ResultExt, maybe};
use uuid::Uuid;
+use workspace::Workspace;
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
use worktree::Worktree;
@@ -235,8 +237,6 @@ impl std::fmt::Debug for EditPrediction {
pub struct Zeta {
client: Arc<Client>,
- events: VecDeque<Event>,
- registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
shown_completions: VecDeque<EditPrediction>,
rated_completions: HashSet<EditPredictionId>,
data_collection_choice: Entity<DataCollectionChoice>,
@@ -246,6 +246,12 @@ pub struct Zeta {
update_required: bool,
user_store: Entity<UserStore>,
license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
+ projects: HashMap<EntityId, ZetaProject>,
+}
+
+struct ZetaProject {
+ events: VecDeque<Event>,
+ registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
recent_editors: VecDeque<RecentEditor>,
last_activity_state: Option<ActivityState>,
_activity_poll_task: Option<Task<Result<()>>>,
@@ -296,7 +302,9 @@ impl Zeta {
}
pub fn clear_history(&mut self) {
- self.events.clear();
+ for zeta_project in self.projects.values_mut() {
+ zeta_project.events.clear();
+ }
}
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
@@ -309,6 +317,7 @@ impl Zeta {
let data_collection_choice = Self::load_data_collection_choices();
let data_collection_choice = cx.new(|_| data_collection_choice);
+ /* todo!
let mut activity_poll_task = None;
if let Some(workspace) = &workspace {
@@ -344,13 +353,12 @@ impl Zeta {
}
}));
}
+ */
Self {
client,
- events: VecDeque::with_capacity(MAX_EVENT_COUNT),
shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
rated_completions: HashSet::default(),
- registered_buffers: HashMap::default(),
data_collection_choice,
llm_token: LlmApiToken::default(),
_llm_token_subscription: cx.subscribe(
@@ -368,18 +376,42 @@ impl Zeta {
update_required: false,
license_detection_watchers: HashMap::default(),
user_store,
- recent_editors: VecDeque::new(),
- last_activity_state: None,
- _activity_poll_task: activity_poll_task,
+ projects: HashMap::default(),
}
}
- fn push_event(&mut self, event: Event) {
+ fn get_mut_or_init_zeta_project(
+ &mut self,
+ project: &Entity<Project>,
+ cx: &mut Context<Self>,
+ ) -> &mut ZetaProject {
+ let project_id = project.entity_id();
+ match self.projects.entry(project_id) {
+ hash_map::Entry::Occupied(entry) => entry.into_mut(),
+ hash_map::Entry::Vacant(entry) => {
+ cx.observe_release(project, move |this, _, _cx| {
+ this.projects.remove(&project_id);
+ });
+ entry.insert(ZetaProject {
+ events: VecDeque::with_capacity(MAX_EVENT_COUNT),
+ registered_buffers: HashMap::default(),
+ recent_editors: VecDeque::new(),
+ last_activity_state: None,
+ // todo!
+ _activity_poll_task: None,
+ })
+ }
+ }
+ }
+
+ fn push_event(zeta_project: &mut ZetaProject, event: Event) {
+ let events = &mut zeta_project.events;
+
if let Some(Event::BufferChange {
new_snapshot: last_new_snapshot,
timestamp: last_timestamp,
..
- }) = self.events.back_mut()
+ }) = events.back_mut()
{
// Coalesce edits for the same buffer when they happen one after the other.
let Event::BufferChange {
@@ -398,51 +430,65 @@ impl Zeta {
}
}
- if self.events.len() >= MAX_EVENT_COUNT {
+ if events.len() >= MAX_EVENT_COUNT {
// These are halved instead of popping to improve prompt caching.
- self.events.drain(..MAX_EVENT_COUNT / 2);
+ events.drain(..MAX_EVENT_COUNT / 2);
}
- self.events.push_back(event);
+ events.push_back(event);
}
- pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
- let buffer_id = buffer.entity_id();
- let weak_buffer = buffer.downgrade();
-
- if let std::collections::hash_map::Entry::Vacant(entry) =
- self.registered_buffers.entry(buffer_id)
- {
- let snapshot = buffer.read(cx).snapshot();
-
- entry.insert(RegisteredBuffer {
- snapshot,
- _subscriptions: [
- cx.subscribe(buffer, move |this, buffer, event, cx| {
- this.handle_buffer_event(buffer, event, cx);
- }),
- cx.observe_release(buffer, move |this, _buffer, _cx| {
- this.registered_buffers.remove(&weak_buffer.entity_id());
- }),
- ],
- });
- };
- }
-
- fn handle_buffer_event(
+ pub fn register_buffer(
&mut self,
- buffer: Entity<Buffer>,
- event: &language::BufferEvent,
+ buffer: &Entity<Buffer>,
+ project: &Entity<Project>,
cx: &mut Context<Self>,
) {
- if let language::BufferEvent::Edited = event {
- self.report_changes_for_buffer(&buffer, cx);
+ let zeta_project = self.get_mut_or_init_zeta_project(project, cx);
+ Self::register_buffer_impl(zeta_project, buffer, project, cx);
+ }
+
+ fn register_buffer_impl<'a>(
+ zeta_project: &'a mut ZetaProject,
+ buffer: &Entity<Buffer>,
+ project: &Entity<Project>,
+ cx: &mut Context<Self>,
+ ) -> &'a mut RegisteredBuffer {
+ let buffer_id = buffer.entity_id();
+ match zeta_project.registered_buffers.entry(buffer_id) {
+ hash_map::Entry::Occupied(entry) => entry.into_mut(),
+ hash_map::Entry::Vacant(entry) => {
+ let snapshot = buffer.read(cx).snapshot();
+ let project_entity_id = project.entity_id();
+ entry.insert(RegisteredBuffer {
+ snapshot,
+ _subscriptions: [
+ cx.subscribe(buffer, {
+ let project = project.downgrade();
+ move |this, buffer, event, cx| {
+ if let language::BufferEvent::Edited = event
+ && let Some(project) = project.upgrade()
+ {
+ this.report_changes_for_buffer(&buffer, &project, cx);
+ }
+ }
+ }),
+ cx.observe_release(buffer, move |this, _buffer, _cx| {
+ let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
+ else {
+ return;
+ };
+ zeta_project.registered_buffers.remove(&buffer_id);
+ }),
+ ],
+ })
+ }
}
}
fn request_completion_impl<F, R>(
&mut self,
- project: Option<Entity<Project>>,
+ project: &Entity<Project>,
buffer: &Entity<Buffer>,
cursor: language::Anchor,
can_collect_data: CanCollectData,
@@ -457,9 +503,12 @@ impl Zeta {
{
let buffer = buffer.clone();
let buffer_snapshotted_at = Instant::now();
- let snapshot = self.report_changes_for_buffer(&buffer, cx);
+ let snapshot = self.report_changes_for_buffer(&buffer, project, cx);
let zeta = cx.entity();
- let events = self.events.clone();
+ let events = self
+ .get_mut_or_init_zeta_project(project, cx)
+ .events
+ .clone();
let client = self.client.clone();
let llm_token = self.llm_token.clone();
let app_version = AppVersion::global(cx);
@@ -487,6 +536,8 @@ impl Zeta {
} = gather_task.await?;
let done_gathering_context_at = Instant::now();
+ let additional_context_task: Option<Task<PredictEditsAdditionalContext>> = None;
+ /* todo!
let additional_context_task = if matches!(can_collect_data, CanCollectData(true))
&& let Some(file) = snapshot.file()
&& let Ok(project_path) = cx.update(|cx| ProjectPath::from_file(file.as_ref(), cx))
@@ -503,7 +554,7 @@ impl Zeta {
snapshot,
&buffer_snapshotted_at,
project_path,
- project.as_ref(),
+ &project,
cx,
)
}) {
@@ -515,6 +566,7 @@ impl Zeta {
} else {
None
};
+ */
log::debug!(
"Events:\n{}\nExcerpt:\n{:?}",
@@ -606,6 +658,7 @@ impl Zeta {
);
}
+ /* todo!
if let Some(additional_context_task) = additional_context_task {
cx.background_spawn(async move {
if let Some(additional_context) = additional_context_task.await {
@@ -618,6 +671,7 @@ impl Zeta {
})
.detach();
}
+ */
edit_prediction
})
@@ -626,6 +680,7 @@ impl Zeta {
// Generates several example completions of various states to fill the Zeta completion modal
#[cfg(any(test, feature = "test-support"))]
pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
+ /*
use language::Point;
let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
@@ -773,12 +828,14 @@ and then another
})
.ok();
})
+ */
+ todo!()
}
#[cfg(any(test, feature = "test-support"))]
pub fn fake_completion(
&mut self,
- project: Option<Entity<Project>>,
+ project: &Entity<Project>,
buffer: &Entity<Buffer>,
position: language::Anchor,
response: PredictEditsResponse,
@@ -798,7 +855,7 @@ and then another
pub fn request_completion(
&mut self,
- project: Option<Entity<Project>>,
+ project: &Entity<Project>,
buffer: &Entity<Buffer>,
position: language::Anchor,
can_collect_data: CanCollectData,
@@ -1155,23 +1212,23 @@ and then another
fn report_changes_for_buffer(
&mut self,
buffer: &Entity<Buffer>,
+ project: &Entity<Project>,
cx: &mut Context<Self>,
) -> BufferSnapshot {
- self.register_buffer(buffer, cx);
+ let zeta_project = self.get_mut_or_init_zeta_project(project, cx);
+ let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
- let registered_buffer = self
- .registered_buffers
- .get_mut(&buffer.entity_id())
- .unwrap();
let new_snapshot = buffer.read(cx).snapshot();
-
if new_snapshot.version != registered_buffer.snapshot.version {
let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
- self.push_event(Event::BufferChange {
- old_snapshot,
- new_snapshot: new_snapshot.clone(),
- timestamp: Instant::now(),
- });
+ Self::push_event(
+ zeta_project,
+ Event::BufferChange {
+ old_snapshot,
+ new_snapshot: new_snapshot.clone(),
+ timestamp: Instant::now(),
+ },
+ );
}
new_snapshot
@@ -1194,6 +1251,7 @@ and then another
}
}
+ /*
fn gather_additional_context(
&mut self,
cursor_point: language::Point,
@@ -1201,10 +1259,10 @@ and then another
snapshot: BufferSnapshot,
buffer_snapshotted_at: &Instant,
project_path: ProjectPath,
- project: Option<&Entity<Project>>,
+ project: &WeakEntity<Project>,
cx: &mut Context<Self>,
) -> Option<Task<PredictEditsAdditionalContext>> {
- let project = project?.read(cx);
+ let project = project.upgrade()?.read(cx);
let entry = project.entry_for_path(&project_path, cx)?;
if !worktree_entry_is_eligible_for_collection(&entry) {
return None;
@@ -1511,6 +1569,7 @@ and then another
}
results
}
+ */
}
fn to_cloud_llm_client_point(point: language::Point) -> cloud_llm_client::Point {
@@ -1926,6 +1985,10 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
if self.zeta.read(cx).update_required {
return;
}
+ // todo! Don't require a project
+ let Some(project) = project else {
+ return;
+ };
if self
.zeta
@@ -1964,7 +2027,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
let completion_request = this.update(cx, |this, cx| {
this.last_request_timestamp = Instant::now();
this.zeta.update(cx, |zeta, cx| {
- zeta.request_completion(project, &buffer, position, can_collect_data, cx)
+ zeta.request_completion(&project, &buffer, position, can_collect_data, cx)
})
});
@@ -2140,6 +2203,7 @@ fn tokens_for_bytes(bytes: usize) -> usize {
bytes / BYTES_PER_TOKEN_GUESS
}
+/* todo!
#[cfg(test)]
mod tests {
use client::UserStore;
@@ -2510,3 +2574,4 @@ mod tests {
zlog::init_test();
}
}
+*/