@@ -204,6 +204,7 @@ fn assign_edit_prediction_provider(
}
if std::env::var("ZED_ZETA2").is_ok() {
+ let zeta = zeta2::Zeta::global(client, &user_store, cx);
let provider = cx.new(|cx| {
zeta2::ZetaEditPredictionProvider::new(
editor.project(),
@@ -213,6 +214,15 @@ fn assign_edit_prediction_provider(
)
});
+ if let Some(buffer) = &singleton_buffer
+ && buffer.read(cx).file().is_some()
+ && let Some(project) = editor.project()
+ {
+ zeta.update(cx, |zeta, cx| {
+ zeta.register_buffer(buffer, project, cx);
+ });
+ }
+
editor.set_edit_prediction_provider(Some(provider), window, cx);
} else {
let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx);
@@ -22,8 +22,9 @@ use language_model::{LlmApiToken, RefreshLlmTokenListener};
use project::Project;
use release_channel::AppVersion;
use std::cmp;
-use std::collections::HashMap;
-use std::path::PathBuf;
+use std::collections::{HashMap, VecDeque, hash_map};
+use std::fmt::Write;
+use std::path::{Path, PathBuf};
use std::str::FromStr as _;
use std::time::{Duration, Instant};
use std::{ops::Range, sync::Arc};
@@ -32,6 +33,11 @@ use util::ResultExt as _;
use uuid::Uuid;
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
+const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
+
+/// Maximum number of events to track.
+const MAX_EVENT_COUNT: usize = 16;
+
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);
@@ -42,13 +48,68 @@ pub struct Zeta {
user_store: Entity<UserStore>,
llm_token: LlmApiToken,
_llm_token_subscription: Subscription,
- projects: HashMap<EntityId, RegisteredProject>,
+ projects: HashMap<EntityId, ZetaProject>,
excerpt_options: EditPredictionExcerptOptions,
update_required: bool,
}
-struct RegisteredProject {
+struct ZetaProject {
syntax_index: Entity<SyntaxIndex>,
+ events: VecDeque<Event>,
+ registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
+}
+
+struct RegisteredBuffer {
+ snapshot: BufferSnapshot,
+ _subscriptions: [gpui::Subscription; 2],
+}
+
+#[derive(Clone)]
+pub enum Event {
+ BufferChange {
+ old_snapshot: BufferSnapshot,
+ new_snapshot: BufferSnapshot,
+ timestamp: Instant,
+ },
+}
+
+impl Event {
+ //TODO: Actually use the events this in the prompt
+ fn to_prompt(&self) -> String {
+ match self {
+ Event::BufferChange {
+ old_snapshot,
+ new_snapshot,
+ ..
+ } => {
+ let mut prompt = String::new();
+
+ let old_path = old_snapshot
+ .file()
+ .map(|f| f.path().as_ref())
+ .unwrap_or(Path::new("untitled"));
+ let new_path = new_snapshot
+ .file()
+ .map(|f| f.path().as_ref())
+ .unwrap_or(Path::new("untitled"));
+ if old_path != new_path {
+ writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
+ }
+
+ let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
+ if !diff.is_empty() {
+ write!(
+ prompt,
+ "User edited {:?}:\n```diff\n{}\n```",
+ new_path, diff
+ )
+ .unwrap();
+ }
+
+ prompt
+ }
+ }
+ }
}
impl Zeta {
@@ -100,11 +161,129 @@ impl Zeta {
}
pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
+ self.get_or_init_zeta_project(project, cx);
+ }
+
+ pub fn register_buffer(
+ &mut self,
+ buffer: &Entity<Buffer>,
+ project: &Entity<Project>,
+ cx: &mut Context<Self>,
+ ) {
+ let zeta_project = self.get_or_init_zeta_project(project, cx);
+ Self::register_buffer_impl(zeta_project, buffer, project, cx);
+ }
+
+ fn get_or_init_zeta_project(
+ &mut self,
+ project: &Entity<Project>,
+ cx: &mut App,
+ ) -> &mut ZetaProject {
self.projects
.entry(project.entity_id())
- .or_insert_with(|| RegisteredProject {
+ .or_insert_with(|| ZetaProject {
syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
- });
+ events: VecDeque::new(),
+ registered_buffers: HashMap::new(),
+ })
+ }
+
+ fn register_buffer_impl<'a>(
+ zeta_project: &'a mut ZetaProject,
+ buffer: &Entity<Buffer>,
+ project: &Entity<Project>,
+ cx: &mut Context<Self>,
+ ) -> &'a mut RegisteredBuffer {
+ let buffer_id = buffer.entity_id();
+ match zeta_project.registered_buffers.entry(buffer_id) {
+ hash_map::Entry::Occupied(entry) => entry.into_mut(),
+ hash_map::Entry::Vacant(entry) => {
+ let snapshot = buffer.read(cx).snapshot();
+ let project_entity_id = project.entity_id();
+ entry.insert(RegisteredBuffer {
+ snapshot,
+ _subscriptions: [
+ cx.subscribe(buffer, {
+ let project = project.downgrade();
+ move |this, buffer, event, cx| {
+ if let language::BufferEvent::Edited = event
+ && let Some(project) = project.upgrade()
+ {
+ this.report_changes_for_buffer(&buffer, &project, cx);
+ }
+ }
+ }),
+ cx.observe_release(buffer, move |this, _buffer, _cx| {
+ let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
+ else {
+ return;
+ };
+ zeta_project.registered_buffers.remove(&buffer_id);
+ }),
+ ],
+ })
+ }
+ }
+ }
+
+ fn report_changes_for_buffer(
+ &mut self,
+ buffer: &Entity<Buffer>,
+ project: &Entity<Project>,
+ cx: &mut Context<Self>,
+ ) -> BufferSnapshot {
+ let zeta_project = self.get_or_init_zeta_project(project, cx);
+ let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
+
+ let new_snapshot = buffer.read(cx).snapshot();
+ if new_snapshot.version != registered_buffer.snapshot.version {
+ let old_snapshot =
+ std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
+ Self::push_event(
+ zeta_project,
+ Event::BufferChange {
+ old_snapshot,
+ new_snapshot: new_snapshot.clone(),
+ timestamp: Instant::now(),
+ },
+ );
+ }
+
+ new_snapshot
+ }
+
+ fn push_event(zeta_project: &mut ZetaProject, event: Event) {
+ let events = &mut zeta_project.events;
+
+ if let Some(Event::BufferChange {
+ new_snapshot: last_new_snapshot,
+ timestamp: last_timestamp,
+ ..
+ }) = events.back_mut()
+ {
+ // Coalesce edits for the same buffer when they happen one after the other.
+ let Event::BufferChange {
+ old_snapshot,
+ new_snapshot,
+ timestamp,
+ } = &event;
+
+ if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
+ && old_snapshot.remote_id() == last_new_snapshot.remote_id()
+ && old_snapshot.version == last_new_snapshot.version
+ {
+ *last_new_snapshot = new_snapshot.clone();
+ *last_timestamp = *timestamp;
+ return;
+ }
+ }
+
+ if events.len() >= MAX_EVENT_COUNT {
+ // These are halved instead of popping to improve prompt caching.
+ events.drain(..MAX_EVENT_COUNT / 2);
+ }
+
+ events.push_back(event);
}
pub fn request_prediction(
@@ -448,7 +627,6 @@ struct PendingPrediction {
impl EditPredictionProvider for ZetaEditPredictionProvider {
fn name() -> &'static str {
- // TODO [zeta2]
"zed-predict2"
}