Include outline when predicting edits with Zeta (#22895)

Antonio Scandurra and Thorsten created

Release Notes:

- N/A

Co-authored-by: Thorsten <thorsten@zed.dev>

Change summary

crates/collab/src/llm.rs                   |  8 ++++++
crates/collab/src/llm/prediction_prompt.md |  1 
crates/rpc/src/llm.rs                      |  1 
crates/zeta/src/zeta.rs                    | 30 ++++++++++++++---------
4 files changed, 28 insertions(+), 12 deletions(-)

Detailed changes

crates/collab/src/llm.rs 🔗

@@ -459,7 +459,15 @@ async fn predict_edits(
         .prediction_model
         .as_ref()
         .context("no PREDICTION_MODEL configured on the server")?;
+
+    let outline_prefix = params
+        .outline
+        .as_ref()
+        .map(|outline| format!("### Outline for current file:\n{}\n", outline))
+        .unwrap_or_default();
+
     let prompt = include_str!("./llm/prediction_prompt.md")
+        .replace("<outline>", &outline_prefix)
         .replace("<events>", &params.input_events)
         .replace("<excerpt>", &params.input_excerpt);
     let mut response = open_ai::complete_text(

crates/collab/src/llm/prediction_prompt.md 🔗

@@ -1,3 +1,4 @@
+<outline>## Task
 Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
 
 ### Instruction:

crates/rpc/src/llm.rs 🔗

@@ -36,6 +36,7 @@ pub struct PerformCompletionParams {
 
 #[derive(Debug, Serialize, Deserialize)]
 pub struct PredictEditsParams {
+    pub outline: Option<String>,
     pub input_events: String,
     pub input_excerpt: String,
 }

crates/zeta/src/zeta.rs 🔗

@@ -300,29 +300,35 @@ impl Zeta {
         cx.spawn(|this, mut cx| async move {
             let request_sent_at = Instant::now();
 
-            let input_events = cx
+            let (input_events, input_excerpt, input_outline) = cx
                 .background_executor()
-                .spawn(async move {
-                    let mut input_events = String::new();
-                    for event in events {
-                        if !input_events.is_empty() {
-                            input_events.push('\n');
-                            input_events.push('\n');
+                .spawn({
+                    let snapshot = snapshot.clone();
+                    let excerpt_range = excerpt_range.clone();
+                    async move {
+                        let mut input_events = String::new();
+                        for event in events {
+                            if !input_events.is_empty() {
+                                input_events.push('\n');
+                                input_events.push('\n');
+                            }
+                            input_events.push_str(&event.to_prompt());
                         }
-                        input_events.push_str(&event.to_prompt());
+
+                        let input_excerpt = prompt_for_excerpt(&snapshot, &excerpt_range, offset);
+                        let input_outline = prompt_for_outline(&snapshot);
+
+                        (input_events, input_excerpt, input_outline)
                     }
-                    input_events
                 })
                 .await;
 
-            let input_excerpt = prompt_for_excerpt(&snapshot, &excerpt_range, offset);
-            let input_outline = prompt_for_outline(&snapshot);
-
             log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);
 
             let body = PredictEditsParams {
                 input_events: input_events.clone(),
                 input_excerpt: input_excerpt.clone(),
+                outline: Some(input_outline.clone()),
             };
 
             let response = perform_predict_edits(client, llm_token, body).await?;