zeta: Record recently active files when data collection is enabled

Michael Sloan created

Change summary

crates/cloud_llm_client/src/cloud_llm_client.rs |   8 
crates/zed/src/zed/edit_prediction_registry.rs  |   5 
crates/zeta/src/zeta.rs                         | 171 ++++++++++++++----
crates/zeta_cli/src/main.rs                     |   2 
4 files changed, 140 insertions(+), 46 deletions(-)

Detailed changes

crates/cloud_llm_client/src/cloud_llm_client.rs 🔗

@@ -152,6 +152,8 @@ pub struct PredictEditsBody {
     /// Info about the git repository state, only present when can_collect_data is true.
     #[serde(skip_serializing_if = "Option::is_none", default)]
     pub git_info: Option<PredictEditsGitInfo>,
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub recent_files: Option<Vec<PredictEditsRecentFile>>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -167,6 +169,12 @@ pub struct PredictEditsGitInfo {
     pub remote_upstream_url: Option<String>,
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct PredictEditsRecentFile {
+    pub repo_path: String,
+    pub active_to_now_ms: u32,
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct PredictEditsResponse {
     pub request_id: Uuid,

crates/zed/src/zed/edit_prediction_registry.rs 🔗

@@ -204,10 +204,7 @@ fn assign_edit_prediction_provider(
                     }
                 }
 
-                let workspace = window
-                    .root::<Workspace>()
-                    .flatten()
-                    .map(|workspace| workspace.downgrade());
+                let workspace = window.root::<Workspace>().flatten();
 
                 let zeta =
                     zeta::Zeta::register(workspace, worktree, client.clone(), user_store, cx);

crates/zeta/src/zeta.rs 🔗

@@ -6,19 +6,21 @@ mod onboarding_modal;
 mod onboarding_telemetry;
 mod rate_completion_modal;
 
+use arrayvec::ArrayVec;
 pub(crate) use completion_diff_element::*;
 use db::kvp::{Dismissable, KEY_VALUE_STORE};
 use edit_prediction::DataCollectionState;
 pub use init::*;
 use license_detection::LicenseDetectionWatcher;
+use project::git_store::Repository;
 pub use rate_completion_modal::*;
 
 use anyhow::{Context as _, Result, anyhow};
-use arrayvec::ArrayVec;
 use client::{Client, EditPredictionUsage, UserStore};
 use cloud_llm_client::{
     AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
-    PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
+    PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile, PredictEditsResponse,
+    ZED_VERSION_HEADER_NAME,
 };
 use collections::{HashMap, HashSet, VecDeque};
 use futures::AsyncReadExt;
@@ -32,7 +34,7 @@ use language::{
     Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
 };
 use language_model::{LlmApiToken, RefreshLlmTokenListener};
-use project::{Project, ProjectPath};
+use project::{Project, ProjectEntryId, ProjectPath};
 use release_channel::AppVersion;
 use settings::WorktreeId;
 use std::str::FromStr;
@@ -70,6 +72,12 @@ const MAX_DIAGNOSTIC_GROUPS: usize = 10;
 /// Maximum number of events to track.
 const MAX_EVENT_COUNT: usize = 16;
 
+/// Maximum number of recent files to track.
+const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;
+
+/// Maximum number of edit predictions to store for feedback.
+const MAX_SHOWN_COMPLETION_COUNT: usize = 50;
+
 actions!(
     edit_prediction,
     [
@@ -212,7 +220,7 @@ impl std::fmt::Debug for EditPrediction {
 }
 
 pub struct Zeta {
-    workspace: Option<WeakEntity<Workspace>>,
+    workspace: WeakEntity<Workspace>,
     client: Arc<Client>,
     events: VecDeque<Event>,
     registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
@@ -225,6 +233,7 @@ pub struct Zeta {
     update_required: bool,
     user_store: Entity<UserStore>,
     license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
+    recent_project_entries: VecDeque<(ProjectEntryId, Instant)>,
 }
 
 impl Zeta {
@@ -233,7 +242,7 @@ impl Zeta {
     }
 
     pub fn register(
-        workspace: Option<WeakEntity<Workspace>>,
+        workspace: Option<Entity<Workspace>>,
         worktree: Option<Entity<Worktree>>,
         client: Arc<Client>,
         user_store: Entity<UserStore>,
@@ -266,7 +275,7 @@ impl Zeta {
     }
 
     fn new(
-        workspace: Option<WeakEntity<Workspace>>,
+        workspace: Option<Entity<Workspace>>,
         client: Arc<Client>,
         user_store: Entity<UserStore>,
         cx: &mut Context<Self>,
@@ -276,11 +285,27 @@ impl Zeta {
         let data_collection_choice = Self::load_data_collection_choices();
         let data_collection_choice = cx.new(|_| data_collection_choice);
 
+        if let Some(workspace) = &workspace {
+            cx.subscribe(
+                &workspace.read(cx).project().clone(),
+                |this, _workspace, event, _cx| match event {
+                    project::Event::ActiveEntryChanged(Some(project_entry_id)) => {
+                        this.push_recent_project_entry(*project_entry_id)
+                    }
+                    _ => {}
+                },
+            )
+            .detach();
+        }
+
         Self {
-            workspace,
+            workspace: workspace.map_or_else(
+                || WeakEntity::new_invalid(),
+                |workspace| workspace.downgrade(),
+            ),
             client,
-            events: VecDeque::new(),
-            shown_completions: VecDeque::new(),
+            events: VecDeque::with_capacity(MAX_EVENT_COUNT),
+            shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
             rated_completions: HashSet::default(),
             registered_buffers: HashMap::default(),
             data_collection_choice,
@@ -300,6 +325,7 @@ impl Zeta {
             update_required: false,
             license_detection_watchers: HashMap::default(),
             user_store,
+            recent_project_entries: VecDeque::with_capacity(MAX_RECENT_PROJECT_ENTRIES_COUNT),
         }
     }
 
@@ -327,11 +353,12 @@ impl Zeta {
             }
         }
 
-        self.events.push_back(event);
         if self.events.len() >= MAX_EVENT_COUNT {
             // These are halved instead of popping to improve prompt caching.
             self.events.drain(..MAX_EVENT_COUNT / 2);
         }
+
+        self.events.push_back(event);
     }
 
     pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
@@ -393,12 +420,17 @@ impl Zeta {
         let llm_token = self.llm_token.clone();
         let app_version = AppVersion::global(cx);
 
-        let git_info = if let (true, Some(project), Some(file)) =
+        let (git_info, recent_files) = if let (true, Some(project), Some(file)) =
             (can_collect_data, project, snapshot.file())
+            && let Some(repository) =
+                git_repository_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
         {
-            git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
+            let repository = repository.read(cx);
+            let git_info = make_predict_edits_git_info(repository);
+            let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
+            (git_info, Some(recent_files))
         } else {
-            None
+            (None, None)
         };
 
         let full_path: Arc<Path> = snapshot
@@ -417,6 +449,7 @@ impl Zeta {
             make_events_prompt,
             can_collect_data,
             git_info,
+            recent_files,
             cx,
         );
 
@@ -702,12 +735,8 @@ and then another
         can_collect_data: bool,
         cx: &mut Context<Self>,
     ) -> Task<Result<Option<EditPrediction>>> {
-        let workspace = self
-            .workspace
-            .as_ref()
-            .and_then(|workspace| workspace.upgrade());
         self.request_completion_impl(
-            workspace,
+            self.workspace.upgrade(),
             project,
             buffer,
             position,
@@ -1021,11 +1050,11 @@ and then another
     }
 
     pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
-        self.shown_completions.push_front(completion.clone());
-        if self.shown_completions.len() > 50 {
+        if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
             let completion = self.shown_completions.pop_back().unwrap();
             self.rated_completions.remove(&completion.id);
         }
+        self.shown_completions.push_front(completion.clone());
         cx.notify();
     }
 
@@ -1099,6 +1128,63 @@ and then another
             None => DataCollectionChoice::NotAnswered,
         }
     }
+
+    fn push_recent_project_entry(&mut self, project_entry_id: ProjectEntryId) {
+        let now = Instant::now();
+        if let Some(existing_ix) = self
+            .recent_project_entries
+            .iter()
+            .rposition(|(id, _)| *id == project_entry_id)
+        {
+            self.recent_project_entries.remove(existing_ix);
+        }
+        if self.recent_project_entries.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
+            self.recent_project_entries.pop_front();
+        }
+        self.recent_project_entries
+            .push_back((project_entry_id, now));
+    }
+
+    fn recent_files(
+        &mut self,
+        now: &Instant,
+        repository: &Repository,
+        cx: &Context<Self>,
+    ) -> Vec<PredictEditsRecentFile> {
+        let Ok(project) = self
+            .workspace
+            .read_with(cx, |workspace, _cx| workspace.project().clone())
+        else {
+            return Vec::new();
+        };
+        let mut results = Vec::new();
+        for ix in (0..self.recent_project_entries.len()).rev() {
+            let (id, last_active_at) = &self.recent_project_entries[ix];
+            let Some(project_path) = project.read(cx).path_for_entry(*id, cx) else {
+                self.recent_project_entries.remove(ix);
+                continue;
+            };
+            let Some(repo_path) = repository.project_path_to_repo_path(&project_path, cx) else {
+                // entry not removed since queries involving other repositories might occur later
+                continue;
+            };
+            let Some(repo_path) = repo_path.to_str() else {
+                // paths may not be valid UTF-8
+                self.recent_project_entries.remove(ix);
+                continue;
+            };
+            let Ok(active_to_now_ms) = now.duration_since(*last_active_at).as_millis().try_into()
+            else {
+                self.recent_project_entries.remove(ix);
+                continue;
+            };
+            results.push(PredictEditsRecentFile {
+                repo_path: repo_path.to_string(),
+                active_to_now_ms,
+            });
+        }
+        results
+    }
 }
 
 pub struct PerformPredictEditsParams {
@@ -1123,33 +1209,32 @@ fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b:
         .sum()
 }
 
-fn git_info_for_file(
+fn git_repository_for_file(
     project: &Entity<Project>,
     project_path: &ProjectPath,
     cx: &App,
-) -> Option<PredictEditsGitInfo> {
+) -> Option<Entity<Repository>> {
     let git_store = project.read(cx).git_store().read(cx);
-    if let Some((repository, _repo_path)) =
-        git_store.repository_and_path_for_project_path(project_path, cx)
-    {
-        let repository = repository.read(cx);
-        let head_sha = repository
-            .head_commit
-            .as_ref()
-            .map(|head_commit| head_commit.sha.to_string());
-        let remote_origin_url = repository.remote_origin_url.clone();
-        let remote_upstream_url = repository.remote_upstream_url.clone();
-        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
-            return None;
-        }
-        Some(PredictEditsGitInfo {
-            head_sha,
-            remote_origin_url,
-            remote_upstream_url,
-        })
-    } else {
-        None
+    git_store
+        .repository_and_path_for_project_path(project_path, cx)
+        .map(|(repo, _repo_path)| repo)
+}
+
+fn make_predict_edits_git_info(repository: &Repository) -> Option<PredictEditsGitInfo> {
+    let head_sha = repository
+        .head_commit
+        .as_ref()
+        .map(|head_commit| head_commit.sha.to_string());
+    let remote_origin_url = repository.remote_origin_url.clone();
+    let remote_upstream_url = repository.remote_upstream_url.clone();
+    if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
+        return None;
     }
+    Some(PredictEditsGitInfo {
+        head_sha,
+        remote_origin_url,
+        remote_upstream_url,
+    })
 }
 
 pub struct GatherContextOutput {
@@ -1165,6 +1250,7 @@ pub fn gather_context(
     make_events_prompt: impl FnOnce() -> String + Send + 'static,
     can_collect_data: bool,
     git_info: Option<PredictEditsGitInfo>,
+    recent_files: Option<Vec<PredictEditsRecentFile>>,
     cx: &App,
 ) -> Task<Result<GatherContextOutput>> {
     let local_lsp_store =
@@ -1216,6 +1302,7 @@ pub fn gather_context(
                 git_info,
                 outline: None,
                 speculated_output: None,
+                recent_files,
             };
 
             Ok(GatherContextOutput {

crates/zeta_cli/src/main.rs 🔗

@@ -174,6 +174,7 @@ async fn get_context(
     // Enable gathering extra data not currently needed for edit predictions
     let can_collect_data = true;
     let git_info = None;
+    let recent_files = None;
     let mut gather_context_output = cx
         .update(|cx| {
             gather_context(
@@ -184,6 +185,7 @@ async fn get_context(
                 move || events,
                 can_collect_data,
                 git_info,
+                recent_files,
                 cx,
             )
         })?