Include repo_url in edit prediction requests when data collection is enabled (#50745)

Max Brunsfeld created

Release Notes:

- N/A

Change summary

crates/edit_prediction/src/edit_prediction_tests.rs |  1 +
crates/edit_prediction/src/fim.rs                   |  1 +
crates/edit_prediction/src/mercury.rs               |  1 +
crates/edit_prediction/src/prediction.rs            |  1 +
crates/edit_prediction/src/sweep_ai.rs              |  1 +
crates/edit_prediction/src/zeta.rs                  | 15 +++++++++++++++
crates/edit_prediction_cli/src/load_project.rs      |  1 +
crates/edit_prediction_cli/src/reversal_tracking.rs |  1 +
crates/zeta_prompt/src/zeta_prompt.rs               |  6 ++++++
9 files changed, 28 insertions(+)

Detailed changes

crates/edit_prediction/src/edit_prediction_tests.rs 🔗

@@ -1848,6 +1848,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
             experiment: None,
             in_open_source_repo: false,
             can_collect_data: false,
+            repo_url: None,
         },
         buffer_snapshotted_at: Instant::now(),
         response_received_at: Instant::now(),

crates/edit_prediction/src/fim.rs 🔗

@@ -85,6 +85,7 @@ pub fn request_prediction(
             experiment: None,
             in_open_source_repo: false,
             can_collect_data: false,
+            repo_url: None,
         };
 
         let prefix = inputs.cursor_excerpt[..inputs.cursor_offset_in_excerpt].to_string();

crates/edit_prediction/src/prediction.rs 🔗

@@ -165,6 +165,7 @@ mod tests {
                 experiment: None,
                 in_open_source_repo: false,
                 can_collect_data: false,
+                repo_url: None,
             },
             buffer_snapshotted_at: Instant::now(),
             response_received_at: Instant::now(),

crates/edit_prediction/src/zeta.rs 🔗

@@ -64,6 +64,18 @@ pub fn request_prediction_with_zeta(
         .map(|file| -> Arc<Path> { file.full_path(cx).into() })
         .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 
+    let repo_url = if can_collect_data {
+        let buffer_id = buffer.read(cx).remote_id();
+        project
+            .read(cx)
+            .git_store()
+            .read(cx)
+            .repository_and_path_for_buffer_id(buffer_id, cx)
+            .and_then(|(repo, _)| repo.read(cx).default_remote_url())
+    } else {
+        None
+    };
+
     let client = store.client.clone();
     let llm_token = store.llm_token.clone();
     let organization_id = store
@@ -91,6 +103,7 @@ pub fn request_prediction_with_zeta(
                 preferred_experiment,
                 is_open_source,
                 can_collect_data,
+                repo_url,
             );
 
             if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
@@ -391,6 +404,7 @@ pub fn zeta2_prompt_input(
     preferred_experiment: Option<String>,
     is_open_source: bool,
     can_collect_data: bool,
+    repo_url: Option<String>,
 ) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
     let cursor_point = cursor_offset.to_point(snapshot);
 
@@ -422,6 +436,7 @@ pub fn zeta2_prompt_input(
         experiment: preferred_experiment,
         in_open_source_repo: is_open_source,
         can_collect_data,
+        repo_url,
     };
     (full_context_offset_range, prompt_input)
 }

crates/zeta_prompt/src/zeta_prompt.rs 🔗

@@ -61,6 +61,8 @@ pub struct ZetaPromptInput {
     pub in_open_source_repo: bool,
     #[serde(default)]
     pub can_collect_data: bool,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub repo_url: Option<String>,
 }
 
 #[derive(
@@ -2715,6 +2717,7 @@ mod tests {
             experiment: None,
             in_open_source_repo: false,
             can_collect_data: false,
+            repo_url: None,
         }
     }
 
@@ -3312,6 +3315,7 @@ mod tests {
             experiment: None,
             in_open_source_repo: false,
             can_collect_data: false,
+            repo_url: None,
         };
 
         let prompt = zeta1::format_zeta1_from_input(&input, 15..41, 0..excerpt.len());
@@ -3374,6 +3378,7 @@ mod tests {
             experiment: None,
             in_open_source_repo: false,
             can_collect_data: false,
+            repo_url: None,
         };
 
         let prompt = zeta1::format_zeta1_from_input(&input, 0..28, 0..28);
@@ -3431,6 +3436,7 @@ mod tests {
             experiment: None,
             in_open_source_repo: false,
             can_collect_data: false,
+            repo_url: None,
         };
 
         let prompt = zeta1::format_zeta1_from_input(&input, editable_range, context_range);