Track edit events

Bennet Bo Fenner created

Change summary

crates/zed/src/zed/edit_prediction_registry.rs |  10 +
crates/zeta2/src/zeta2.rs                      | 192 +++++++++++++++++++
2 files changed, 195 insertions(+), 7 deletions(-)

Detailed changes

crates/zed/src/zed/edit_prediction_registry.rs 🔗

@@ -204,6 +204,7 @@ fn assign_edit_prediction_provider(
                 }
 
                 if std::env::var("ZED_ZETA2").is_ok() {
+                    let zeta = zeta2::Zeta::global(client, &user_store, cx);
                     let provider = cx.new(|cx| {
                         zeta2::ZetaEditPredictionProvider::new(
                             editor.project(),
@@ -213,6 +214,15 @@ fn assign_edit_prediction_provider(
                         )
                     });
 
+                    if let Some(buffer) = &singleton_buffer
+                        && buffer.read(cx).file().is_some()
+                        && let Some(project) = editor.project()
+                    {
+                        zeta.update(cx, |zeta, cx| {
+                            zeta.register_buffer(buffer, project, cx);
+                        });
+                    }
+
                     editor.set_edit_prediction_provider(Some(provider), window, cx);
                 } else {
                     let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx);

crates/zeta2/src/zeta2.rs 🔗

@@ -22,8 +22,9 @@ use language_model::{LlmApiToken, RefreshLlmTokenListener};
 use project::Project;
 use release_channel::AppVersion;
 use std::cmp;
-use std::collections::HashMap;
-use std::path::PathBuf;
+use std::collections::{HashMap, VecDeque, hash_map};
+use std::fmt::Write;
+use std::path::{Path, PathBuf};
 use std::str::FromStr as _;
 use std::time::{Duration, Instant};
 use std::{ops::Range, sync::Arc};
@@ -32,6 +33,11 @@ use util::ResultExt as _;
 use uuid::Uuid;
 use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 
+const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
+
+/// Maximum number of events to track.
+const MAX_EVENT_COUNT: usize = 16;
+
 #[derive(Clone)]
 struct ZetaGlobal(Entity<Zeta>);
 
@@ -42,13 +48,68 @@ pub struct Zeta {
     user_store: Entity<UserStore>,
     llm_token: LlmApiToken,
     _llm_token_subscription: Subscription,
-    projects: HashMap<EntityId, RegisteredProject>,
+    projects: HashMap<EntityId, ZetaProject>,
     excerpt_options: EditPredictionExcerptOptions,
     update_required: bool,
 }
 
-struct RegisteredProject {
+struct ZetaProject {
     syntax_index: Entity<SyntaxIndex>,
+    events: VecDeque<Event>,
+    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
+}
+
+struct RegisteredBuffer {
+    snapshot: BufferSnapshot,
+    _subscriptions: [gpui::Subscription; 2],
+}
+
+#[derive(Clone)]
+pub enum Event {
+    BufferChange {
+        old_snapshot: BufferSnapshot,
+        new_snapshot: BufferSnapshot,
+        timestamp: Instant,
+    },
+}
+
+impl Event {
+    //TODO: Actually use these events in the prompt
+    fn to_prompt(&self) -> String {
+        match self {
+            Event::BufferChange {
+                old_snapshot,
+                new_snapshot,
+                ..
+            } => {
+                let mut prompt = String::new();
+
+                let old_path = old_snapshot
+                    .file()
+                    .map(|f| f.path().as_ref())
+                    .unwrap_or(Path::new("untitled"));
+                let new_path = new_snapshot
+                    .file()
+                    .map(|f| f.path().as_ref())
+                    .unwrap_or(Path::new("untitled"));
+                if old_path != new_path {
+                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
+                }
+
+                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
+                if !diff.is_empty() {
+                    write!(
+                        prompt,
+                        "User edited {:?}:\n```diff\n{}\n```",
+                        new_path, diff
+                    )
+                    .unwrap();
+                }
+
+                prompt
+            }
+        }
+    }
 }
 
 impl Zeta {
@@ -100,11 +161,129 @@ impl Zeta {
     }
 
     pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
+        self.get_or_init_zeta_project(project, cx);
+    }
+
+    pub fn register_buffer(
+        &mut self,
+        buffer: &Entity<Buffer>,
+        project: &Entity<Project>,
+        cx: &mut Context<Self>,
+    ) {
+        let zeta_project = self.get_or_init_zeta_project(project, cx);
+        Self::register_buffer_impl(zeta_project, buffer, project, cx);
+    }
+
+    fn get_or_init_zeta_project(
+        &mut self,
+        project: &Entity<Project>,
+        cx: &mut App,
+    ) -> &mut ZetaProject {
         self.projects
             .entry(project.entity_id())
-            .or_insert_with(|| RegisteredProject {
+            .or_insert_with(|| ZetaProject {
                 syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
-            });
+                events: VecDeque::new(),
+                registered_buffers: HashMap::new(),
+            })
+    }
+
+    fn register_buffer_impl<'a>(
+        zeta_project: &'a mut ZetaProject,
+        buffer: &Entity<Buffer>,
+        project: &Entity<Project>,
+        cx: &mut Context<Self>,
+    ) -> &'a mut RegisteredBuffer {
+        let buffer_id = buffer.entity_id();
+        match zeta_project.registered_buffers.entry(buffer_id) {
+            hash_map::Entry::Occupied(entry) => entry.into_mut(),
+            hash_map::Entry::Vacant(entry) => {
+                let snapshot = buffer.read(cx).snapshot();
+                let project_entity_id = project.entity_id();
+                entry.insert(RegisteredBuffer {
+                    snapshot,
+                    _subscriptions: [
+                        cx.subscribe(buffer, {
+                            let project = project.downgrade();
+                            move |this, buffer, event, cx| {
+                                if let language::BufferEvent::Edited = event
+                                    && let Some(project) = project.upgrade()
+                                {
+                                    this.report_changes_for_buffer(&buffer, &project, cx);
+                                }
+                            }
+                        }),
+                        cx.observe_release(buffer, move |this, _buffer, _cx| {
+                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
+                            else {
+                                return;
+                            };
+                            zeta_project.registered_buffers.remove(&buffer_id);
+                        }),
+                    ],
+                })
+            }
+        }
+    }
+
+    fn report_changes_for_buffer(
+        &mut self,
+        buffer: &Entity<Buffer>,
+        project: &Entity<Project>,
+        cx: &mut Context<Self>,
+    ) -> BufferSnapshot {
+        let zeta_project = self.get_or_init_zeta_project(project, cx);
+        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
+
+        let new_snapshot = buffer.read(cx).snapshot();
+        if new_snapshot.version != registered_buffer.snapshot.version {
+            let old_snapshot =
+                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
+            Self::push_event(
+                zeta_project,
+                Event::BufferChange {
+                    old_snapshot,
+                    new_snapshot: new_snapshot.clone(),
+                    timestamp: Instant::now(),
+                },
+            );
+        }
+
+        new_snapshot
+    }
+
+    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
+        let events = &mut zeta_project.events;
+
+        if let Some(Event::BufferChange {
+            new_snapshot: last_new_snapshot,
+            timestamp: last_timestamp,
+            ..
+        }) = events.back_mut()
+        {
+            // Coalesce edits for the same buffer when they happen one after the other.
+            let Event::BufferChange {
+                old_snapshot,
+                new_snapshot,
+                timestamp,
+            } = &event;
+
+            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
+                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
+                && old_snapshot.version == last_new_snapshot.version
+            {
+                *last_new_snapshot = new_snapshot.clone();
+                *last_timestamp = *timestamp;
+                return;
+            }
+        }
+
+        if events.len() >= MAX_EVENT_COUNT {
+            // These are halved instead of popping to improve prompt caching.
+            events.drain(..MAX_EVENT_COUNT / 2);
+        }
+
+        events.push_back(event);
     }
 
     pub fn request_prediction(
@@ -448,7 +627,6 @@ struct PendingPrediction {
 
 impl EditPredictionProvider for ZetaEditPredictionProvider {
     fn name() -> &'static str {
-        // TODO [zeta2]
         "zed-predict2"
     }