Clean up + only record git info if the current file may be in a repo

Michael Sloan created

Change summary

crates/cloud_llm_client/src/cloud_llm_client.rs |   4 
crates/zeta/src/zeta.rs                         | 141 ++++++++++--------
crates/zeta_cli/src/main.rs                     |   7 
3 files changed, 80 insertions(+), 72 deletions(-)

Detailed changes

crates/cloud_llm_client/src/cloud_llm_client.rs 🔗

@@ -152,8 +152,6 @@ pub struct PredictEditsBody {
     /// Info about the git repository state, only present when can_collect_data is true.
     #[serde(skip_serializing_if = "Option::is_none", default)]
     pub git_info: Option<PredictEditsGitInfo>,
-    #[serde(skip_serializing_if = "Option::is_none", default)]
-    pub recent_files: Option<Vec<PredictEditsRecentFile>>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -167,6 +165,8 @@ pub struct PredictEditsGitInfo {
     /// URL of the remote called `upstream`.
     #[serde(skip_serializing_if = "Option::is_none", default)]
     pub remote_upstream_url: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub recent_files: Option<Vec<PredictEditsRecentFile>>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]

crates/zeta/src/zeta.rs 🔗

@@ -404,7 +404,7 @@ impl Zeta {
         project: Option<&Entity<Project>>,
         buffer: &Entity<Buffer>,
         cursor: language::Anchor,
-        can_collect_data: bool,
+        can_collect_data: CanCollectData,
         cx: &mut Context<Self>,
         perform_predict_edits: F,
     ) -> Task<Result<Option<EditPrediction>>>
@@ -423,17 +423,10 @@ impl Zeta {
         let llm_token = self.llm_token.clone();
         let app_version = AppVersion::global(cx);
 
-        let (git_info, recent_files) = if let (true, Some(project), Some(file)) =
-            (can_collect_data, project, snapshot.file())
-            && let Some(repository) =
-                git_repository_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
-        {
-            let repository = repository.read(cx);
-            let git_info = make_predict_edits_git_info(repository);
-            let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
-            (git_info, Some(recent_files))
+        let git_info = if matches!(can_collect_data, CanCollectData(true)) {
+            self.gather_git_info(project.clone(), &buffer_snapshotted_at, &snapshot, cx)
         } else {
-            (None, None)
+            None
         };
 
         let full_path: Arc<Path> = snapshot
@@ -452,7 +445,6 @@ impl Zeta {
             make_events_prompt,
             can_collect_data,
             git_info,
-            recent_files,
             cx,
         );
 
@@ -725,9 +717,15 @@ and then another
     ) -> Task<Result<Option<EditPrediction>>> {
         use std::future::ready;
 
-        self.request_completion_impl(None, project, buffer, position, false, cx, |_params| {
-            ready(Ok((response, None)))
-        })
+        self.request_completion_impl(
+            None,
+            project,
+            buffer,
+            position,
+            CanCollectData(false),
+            cx,
+            |_params| ready(Ok((response, None))),
+        )
     }
 
     pub fn request_completion(
@@ -735,7 +733,7 @@ and then another
         project: Option<&Entity<Project>>,
         buffer: &Entity<Buffer>,
         position: language::Anchor,
-        can_collect_data: bool,
+        can_collect_data: CanCollectData,
         cx: &mut Context<Self>,
     ) -> Task<Result<Option<EditPrediction>>> {
         self.request_completion_impl(
@@ -1132,6 +1130,46 @@ and then another
         }
     }
 
+    fn gather_git_info(
+        &mut self,
+        project: Option<&Entity<Project>>,
+        buffer_snapshotted_at: &Instant,
+        snapshot: &BufferSnapshot,
+        cx: &Context<Self>,
+    ) -> Option<PredictEditsGitInfo> {
+        let project = project?.read(cx);
+        let file = snapshot.file()?;
+        let project_path = ProjectPath::from_file(file.as_ref(), cx);
+        let entry = project.entry_for_path(&project_path, cx)?;
+        if !worktree_entry_eligible_for_collection(&entry) {
+            return None;
+        }
+
+        let git_store = project.git_store().read(cx);
+        let (repository, _repo_path) =
+            git_store.repository_and_path_for_project_path(&project_path, cx)?;
+
+        let repository = repository.read(cx);
+        let head_sha = repository
+            .head_commit
+            .as_ref()
+            .map(|head_commit| head_commit.sha.to_string());
+        let remote_origin_url = repository.remote_origin_url.clone();
+        let remote_upstream_url = repository.remote_upstream_url.clone();
+        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
+            return None;
+        }
+
+        let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
+
+        Some(PredictEditsGitInfo {
+            head_sha,
+            remote_origin_url,
+            remote_upstream_url,
+            recent_files: Some(recent_files),
+        })
+    }
+
     fn push_recent_project_entry(&mut self, project_entry_id: ProjectEntryId) {
         let now = Instant::now();
         if let Some(existing_ix) = self
@@ -1166,12 +1204,7 @@ and then another
             if let Some(worktree) = project.read(cx).worktree_for_entry(*entry_id, cx)
                 && let worktree = worktree.read(cx)
                 && let Some(entry) = worktree.entry_for_id(*entry_id)
-                && entry.is_file()
-                && entry.is_created()
-                && !entry.is_ignored
-                && !entry.is_private
-                && !entry.is_external
-                && !entry.is_fifo
+                && worktree_entry_eligible_for_collection(entry)
             {
                 let project_path = ProjectPath {
                     worktree_id: worktree.id(),
@@ -1191,12 +1224,6 @@ and then another
                     self.recent_project_entries.remove(ix);
                     continue;
                 }
-                if let Some(file_status) = repository.status_for_path(&repo_path) {
-                    if file_status.is_ignored() || file_status.is_untracked() {
-                        // entry not removed because it may belong to a nested repository
-                        continue;
-                    }
-                }
                 let Ok(active_to_now_ms) =
                     now.duration_since(*last_active_at).as_millis().try_into()
                 else {
@@ -1215,6 +1242,15 @@ and then another
     }
 }
 
+fn worktree_entry_eligible_for_collection(entry: &worktree::Entry) -> bool {
+    entry.is_file()
+        && entry.is_created()
+        && !entry.is_ignored
+        && !entry.is_private
+        && !entry.is_external
+        && !entry.is_fifo
+}
+
 pub struct PerformPredictEditsParams {
     pub client: Arc<Client>,
     pub llm_token: LlmApiToken,
@@ -1237,34 +1273,6 @@ fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b:
         .sum()
 }
 
-fn git_repository_for_file(
-    project: &Entity<Project>,
-    project_path: &ProjectPath,
-    cx: &App,
-) -> Option<Entity<Repository>> {
-    let git_store = project.read(cx).git_store().read(cx);
-    git_store
-        .repository_and_path_for_project_path(project_path, cx)
-        .map(|(repo, _repo_path)| repo)
-}
-
-fn make_predict_edits_git_info(repository: &Repository) -> Option<PredictEditsGitInfo> {
-    let head_sha = repository
-        .head_commit
-        .as_ref()
-        .map(|head_commit| head_commit.sha.to_string());
-    let remote_origin_url = repository.remote_origin_url.clone();
-    let remote_upstream_url = repository.remote_upstream_url.clone();
-    if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
-        return None;
-    }
-    Some(PredictEditsGitInfo {
-        head_sha,
-        remote_origin_url,
-        remote_upstream_url,
-    })
-}
-
 pub struct GatherContextOutput {
     pub body: PredictEditsBody,
     pub editable_range: Range<usize>,
@@ -1276,15 +1284,16 @@ pub fn gather_context(
     snapshot: &BufferSnapshot,
     cursor_point: language::Point,
     make_events_prompt: impl FnOnce() -> String + Send + 'static,
-    can_collect_data: bool,
+    can_collect_data: CanCollectData,
     git_info: Option<PredictEditsGitInfo>,
-    recent_files: Option<Vec<PredictEditsRecentFile>>,
     cx: &App,
 ) -> Task<Result<GatherContextOutput>> {
     let local_lsp_store =
         project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
     let diagnostic_groups: Vec<(String, serde_json::Value)> =
-        if can_collect_data && let Some(local_lsp_store) = local_lsp_store {
+        if matches!(can_collect_data, CanCollectData(true))
+            && let Some(local_lsp_store) = local_lsp_store
+        {
             snapshot
                 .diagnostic_groups(None)
                 .into_iter()
@@ -1325,12 +1334,11 @@ pub fn gather_context(
             let body = PredictEditsBody {
                 input_events,
                 input_excerpt: input_excerpt.prompt,
-                can_collect_data,
+                can_collect_data: can_collect_data.0,
                 diagnostic_groups,
                 git_info,
                 outline: None,
                 speculated_output: None,
-                recent_files,
             };
 
             Ok(GatherContextOutput {
@@ -1491,6 +1499,9 @@ pub struct ProviderDataCollection {
     license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
 }
 
+#[derive(Debug, Clone, Copy)]
+pub struct CanCollectData(pub bool);
+
 impl ProviderDataCollection {
     pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
         let choice_and_watcher = buffer.and_then(|buffer| {
@@ -1524,8 +1535,8 @@ impl ProviderDataCollection {
         }
     }
 
-    pub fn can_collect_data(&self, cx: &App) -> bool {
-        self.is_data_collection_enabled(cx) && self.is_project_open_source()
+    pub fn can_collect_data(&self, cx: &App) -> CanCollectData {
+        CanCollectData(self.is_data_collection_enabled(cx) && self.is_project_open_source())
     }
 
     pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
@@ -2149,7 +2160,7 @@ mod tests {
         let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
         let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
         let completion_task = zeta.update(cx, |zeta, cx| {
-            zeta.request_completion(None, &buffer, cursor, false, cx)
+            zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
         });
 
         let completion = completion_task.await.unwrap().unwrap();
@@ -2214,7 +2225,7 @@ mod tests {
         let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
         let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
         let completion_task = zeta.update(cx, |zeta, cx| {
-            zeta.request_completion(None, &buffer, cursor, false, cx)
+            zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
         });
 
         let completion = completion_task.await.unwrap().unwrap();

crates/zeta_cli/src/main.rs 🔗

@@ -18,7 +18,7 @@ use std::process::exit;
 use std::str::FromStr;
 use std::sync::Arc;
 use std::time::Duration;
-use zeta::{GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context};
+use zeta::{CanCollectData, GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context};
 
 use crate::headless::ZetaCliAppState;
 
@@ -172,9 +172,7 @@ async fn get_context(
         None => String::new(),
     };
     // Enable gathering extra data not currently needed for edit predictions
-    let can_collect_data = true;
     let git_info = None;
-    let recent_files = None;
     let mut gather_context_output = cx
         .update(|cx| {
             gather_context(
@@ -183,9 +181,8 @@ async fn get_context(
                 &snapshot,
                 clipped_cursor,
                 move || events,
-                can_collect_data,
+                CanCollectData(true),
                 git_info,
-                recent_files,
                 cx,
             )
         })?