From 989ff500d912e8d19df4d23c4028788f7dc97fd1 Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Mon, 22 Sep 2025 18:40:16 +0200 Subject: [PATCH] Track edit events --- .../zed/src/zed/edit_prediction_registry.rs | 10 + crates/zeta2/src/zeta2.rs | 192 +++++++++++++++++- 2 files changed, 195 insertions(+), 7 deletions(-) diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index d13b93d28f2b1d43b30abd825952fb9a3548bc17..d0e8e26074296e5b54bccaa73de7e06e4aacf205 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -204,6 +204,7 @@ fn assign_edit_prediction_provider( } if std::env::var("ZED_ZETA2").is_ok() { + let zeta = zeta2::Zeta::global(client, &user_store, cx); let provider = cx.new(|cx| { zeta2::ZetaEditPredictionProvider::new( editor.project(), @@ -213,6 +214,15 @@ fn assign_edit_prediction_provider( ) }); + if let Some(buffer) = &singleton_buffer + && buffer.read(cx).file().is_some() + && let Some(project) = editor.project() + { + zeta.update(cx, |zeta, cx| { + zeta.register_buffer(buffer, project, cx); + }); + } + editor.set_edit_prediction_provider(Some(provider), window, cx); } else { let zeta = zeta::Zeta::register(worktree, client.clone(), user_store, cx); diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 0d78345bbfe52d81dcd3593717d2352811ec0faf..5bcb750f4c4a73eac489a22ef3ca8d2198589172 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -22,8 +22,9 @@ use language_model::{LlmApiToken, RefreshLlmTokenListener}; use project::Project; use release_channel::AppVersion; use std::cmp; -use std::collections::HashMap; -use std::path::PathBuf; +use std::collections::{HashMap, VecDeque, hash_map}; +use std::fmt::Write; +use std::path::{Path, PathBuf}; use std::str::FromStr as _; use std::time::{Duration, Instant}; use std::{ops::Range, sync::Arc}; @@ -32,6 +33,11 @@ use util::ResultExt as _; use uuid::Uuid; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; +const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1); + +/// Maximum number of events to track. +const MAX_EVENT_COUNT: usize = 16; + #[derive(Clone)] struct ZetaGlobal(Entity); @@ -42,13 +48,68 @@ pub struct Zeta { user_store: Entity, llm_token: LlmApiToken, _llm_token_subscription: Subscription, - projects: HashMap, + projects: HashMap, excerpt_options: EditPredictionExcerptOptions, update_required: bool, } -struct RegisteredProject { +struct ZetaProject { syntax_index: Entity, + events: VecDeque, + registered_buffers: HashMap, +} + +struct RegisteredBuffer { + snapshot: BufferSnapshot, + _subscriptions: [gpui::Subscription; 2], +} + +#[derive(Clone)] +pub enum Event { + BufferChange { + old_snapshot: BufferSnapshot, + new_snapshot: BufferSnapshot, + timestamp: Instant, + }, +} + +impl Event { + //TODO: Actually use the events this in the prompt + fn to_prompt(&self) -> String { + match self { + Event::BufferChange { + old_snapshot, + new_snapshot, + .. + } => { + let mut prompt = String::new(); + + let old_path = old_snapshot + .file() + .map(|f| f.path().as_ref()) + .unwrap_or(Path::new("untitled")); + let new_path = new_snapshot + .file() + .map(|f| f.path().as_ref()) + .unwrap_or(Path::new("untitled")); + if old_path != new_path { + writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap(); + } + + let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text()); + if !diff.is_empty() { + write!( + prompt, + "User edited {:?}:\n```diff\n{}\n```", + new_path, diff + ) + .unwrap(); + } + + prompt + } + } + } } impl Zeta { @@ -100,11 +161,129 @@ impl Zeta { } pub fn register_project(&mut self, project: &Entity, cx: &mut App) { + self.get_or_init_zeta_project(project, cx); + } + + pub fn register_buffer( + &mut self, + buffer: &Entity, + project: &Entity, + cx: &mut Context, + ) { + let zeta_project = self.get_or_init_zeta_project(project, cx); + Self::register_buffer_impl(zeta_project, buffer, project, cx); + } + + fn get_or_init_zeta_project( + &mut self, + project: &Entity, + cx: &mut App, + ) -> &mut ZetaProject { self.projects .entry(project.entity_id()) - .or_insert_with(|| RegisteredProject { + .or_insert_with(|| ZetaProject { syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)), - }); + events: VecDeque::new(), + registered_buffers: HashMap::new(), + }) + } + + fn register_buffer_impl<'a>( + zeta_project: &'a mut ZetaProject, + buffer: &Entity, + project: &Entity, + cx: &mut Context, + ) -> &'a mut RegisteredBuffer { + let buffer_id = buffer.entity_id(); + match zeta_project.registered_buffers.entry(buffer_id) { + hash_map::Entry::Occupied(entry) => entry.into_mut(), + hash_map::Entry::Vacant(entry) => { + let snapshot = buffer.read(cx).snapshot(); + let project_entity_id = project.entity_id(); + entry.insert(RegisteredBuffer { + snapshot, + _subscriptions: [ + cx.subscribe(buffer, { + let project = project.downgrade(); + move |this, buffer, event, cx| { + if let language::BufferEvent::Edited = event + && let Some(project) = project.upgrade() + { + this.report_changes_for_buffer(&buffer, &project, cx); + } + } + }), + cx.observe_release(buffer, move |this, _buffer, _cx| { + let Some(zeta_project) = this.projects.get_mut(&project_entity_id) + else { + return; + }; + zeta_project.registered_buffers.remove(&buffer_id); + }), + ], + }) + } + } + } + + fn report_changes_for_buffer( + &mut self, + buffer: &Entity, + project: &Entity, + cx: &mut Context, + ) -> BufferSnapshot { + let zeta_project = self.get_or_init_zeta_project(project, cx); + let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx); + + let new_snapshot = buffer.read(cx).snapshot(); + if new_snapshot.version != registered_buffer.snapshot.version { + let old_snapshot = + std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone()); + Self::push_event( + zeta_project, + Event::BufferChange { + old_snapshot, + new_snapshot: new_snapshot.clone(), + timestamp: Instant::now(), + }, + ); + } + + new_snapshot + } + + fn push_event(zeta_project: &mut ZetaProject, event: Event) { + let events = &mut zeta_project.events; + + if let Some(Event::BufferChange { + new_snapshot: last_new_snapshot, + timestamp: last_timestamp, + .. + }) = events.back_mut() + { + // Coalesce edits for the same buffer when they happen one after the other. + let Event::BufferChange { + old_snapshot, + new_snapshot, + timestamp, + } = &event; + + if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL + && old_snapshot.remote_id() == last_new_snapshot.remote_id() + && old_snapshot.version == last_new_snapshot.version + { + *last_new_snapshot = new_snapshot.clone(); + *last_timestamp = *timestamp; + return; + } + } + + if events.len() >= MAX_EVENT_COUNT { + // These are halved instead of popping to improve prompt caching. + events.drain(..MAX_EVENT_COUNT / 2); + } + + events.push_back(event); } pub fn request_prediction( @@ -448,7 +627,6 @@ struct PendingPrediction { impl EditPredictionProvider for ZetaEditPredictionProvider { fn name() -> &'static str { - // TODO [zeta2] "zed-predict2" }