Checkpoint: Adding predict command

Created by Agus Zubiaga and Max Brunsfeld

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>

Change summary

crates/zeta2/src/related_excerpts.rs |   3 
crates/zeta2/src/zeta2.rs            |   2 
crates/zeta_cli/src/main.rs          | 159 +++++++++++++++++++++++++++++
3 files changed, 158 insertions(+), 6 deletions(-)

Detailed changes

crates/zeta2/src/related_excerpts.rs 🔗

@@ -149,6 +149,9 @@ pub fn find_related_excerpts(
         .find(|model| {
             model.provider_id() == MODEL_PROVIDER_ID
                 && model.id() == LanguageModelId("claude-haiku-4-5-latest".into())
+            // model.provider_id() == LanguageModelProviderId::new("zeta-ctx-qwen-30b")
+            // model.provider_id() == LanguageModelProviderId::new("ollama")
+            //     && model.id() == LanguageModelId("gpt-oss:20b".into())
         })
     else {
         return Task::ready(Err(anyhow!("could not find context model")));

crates/zeta2/src/zeta2.rs 🔗

@@ -624,7 +624,7 @@ impl Zeta {
         })
     }
 
-    fn request_prediction(
+    pub fn request_prediction(
         &mut self,
         project: &Entity<Project>,
         buffer: &Entity<Buffer>,

crates/zeta_cli/src/main.rs 🔗

@@ -8,6 +8,7 @@ use crate::example::{ExampleFormat, NamedExample};
 use crate::syntax_retrieval_stats::retrieval_stats;
 use ::serde::Serialize;
 use ::util::paths::PathStyle;
+use ::util::rel_path::RelPath;
 use anyhow::{Context as _, Result, anyhow};
 use clap::{Args, Parser, Subcommand};
 use cloud_llm_client::predict_edits_v3::{self, Excerpt};
@@ -21,7 +22,7 @@ use futures::channel::mpsc;
 use gpui::{Application, AsyncApp, Entity, prelude::*};
 use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point};
 use language_model::LanguageModelRegistry;
-use project::{Project, Worktree};
+use project::{Project, ProjectPath, Worktree};
 use reqwest_client::ReqwestClient;
 use serde_json::json;
 use std::io;
@@ -46,8 +47,6 @@ enum Command {
         command: Zeta1Command,
     },
     Zeta2 {
-        #[clap(flatten)]
-        args: Zeta2Args,
         #[command(subcommand)]
         command: Zeta2Command,
     },
@@ -69,15 +68,22 @@ enum Zeta1Command {
 #[derive(Subcommand, Debug)]
 enum Zeta2Command { // Subcommands nested under `Command::Zeta2`.
     Syntax {
+        #[clap(flatten)]
+        args: Zeta2Args, // Flattened per variant; `Command::Zeta2` no longer carries these.
         #[clap(flatten)]
         syntax_args: Zeta2SyntaxArgs,
         #[command(subcommand)]
         command: Zeta2SyntaxCommand,
     },
     Llm {
+        #[clap(flatten)]
+        args: Zeta2Args, // Same shared arguments as `Syntax`; forwarded to `zeta2_llm_context`.
         #[command(subcommand)]
         command: Zeta2LlmCommand,
     },
+    Predict {
+        example_path: PathBuf, // Handed to `NamedExample::load` in `main`.
+    },
 }
 
 #[derive(Subcommand, Debug)]
@@ -319,6 +325,143 @@ async fn load_context(
     })
 }
 
+/// Runs one Zeta2 edit prediction for `example` and prints the result to stdout.
+///
+/// The cursor is found by stripping `CURSOR_MARKER` from the example's `cursor_position` text
+/// and searching the opened buffer for the remainder. Output is an `## Excerpts` section built
+/// from the files included in the request, then a `## Prediction` section with the response.
+///
+/// # Errors
+///
+/// Returns an error when setup, authentication, or the prediction request fails, when the cursor
+/// excerpt is absent from or ambiguous in the buffer, or when no prediction response arrives.
+async fn zeta2_predict(
+    example: NamedExample,
+    app_state: &Arc<ZetaCliAppState>,
+    cx: &mut AsyncApp,
+) -> Result<()> {
+    let worktree_path = example.setup_worktree().await?;
+
+    let authenticate = cx.update(|cx| {
+        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+            let provider = registry
+                .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
+                .context("context model provider is not registered")?;
+            anyhow::Ok(provider.authenticate(cx))
+        })
+    })??;
+    authenticate.await?;
+
+    app_state
+        .client
+        .sign_in_with_optional_connect(true, cx)
+        .await?;
+
+    let project = cx.update(|cx| {
+        Project::local(
+            app_state.client.clone(),
+            app_state.node_runtime.clone(),
+            app_state.user_store.clone(),
+            app_state.languages.clone(),
+            app_state.fs.clone(),
+            None,
+            cx,
+        )
+    })?;
+
+    let worktree = project
+        .update(cx, |project, cx| {
+            project.create_worktree(&worktree_path, true, cx)
+        })?
+        .await?;
+    worktree
+        .read_with(cx, |worktree, _cx| {
+            worktree.as_local().unwrap().scan_complete()
+        })?
+        .await;
+
+    let cursor_path = RelPath::new(&example.example.cursor_path, PathStyle::Posix)?.into_arc();
+    let cursor_buffer = project
+        .update(cx, |project, cx| {
+            project.open_buffer(
+                ProjectPath {
+                    worktree_id: worktree.read(cx).id(),
+                    path: cursor_path,
+                },
+                cx,
+            )
+        })?
+        .await?;
+
+    let cursor_offset_within_excerpt = example
+        .example
+        .cursor_position
+        .find(CURSOR_MARKER)
+        .ok_or_else(|| anyhow!("missing cursor marker"))?;
+    let mut cursor_excerpt = example.example.cursor_position.clone();
+    cursor_excerpt.replace_range(
+        cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
+        "",
+    );
+    let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
+        let text = buffer.text();
+        let mut matches = text.match_indices(&cursor_excerpt);
+        let Some((excerpt_offset, _)) = matches.next() else {
+            anyhow::bail!(
+                "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n"
+            );
+        };
+        anyhow::ensure!(matches.next().is_none(), "Cursor excerpt matched more than once");
+        Ok(excerpt_offset)
+    })??;
+    let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
+
+    let cursor_anchor =
+        cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
+    let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
+
+    let (prediction_task, mut debug_rx) = zeta.update(cx, |zeta, cx| {
+        let receiver = zeta.debug_info();
+        let prediction_task = zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx);
+        (prediction_task, receiver)
+    })?;
+
+    prediction_task.await.context("No prediction")?;
+    let mut response = None;
+    let mut excerpts_text = String::new();
+    // NOTE(review): loop ends only when the debug channel closes; confirm the sender gets dropped.
+    while let Some(event) = debug_rx.next().await {
+        if let zeta2::ZetaDebugInfo::EditPredicted(request) = event {
+            response = Some(request.response_rx.await?);
+            let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
+            for included_file in request.request.included_files {
+                write_codeblock(
+                    &included_file.path,
+                    included_file.excerpts.iter(),
+                    if included_file.path == request.request.excerpt_path {
+                        &insertions
+                    } else {
+                        &[]
+                    },
+                    included_file.max_row,
+                    false,
+                    &mut excerpts_text,
+                );
+            }
+        }
+    }
+
+    println!("## Excerpts\n");
+    println!("{excerpts_text}");
+    println!("## Prediction\n");
+    let response = match response.context("no edit prediction event was received")? {
+        Ok(response) => response.debug_info.context("response has no debug info")?.model_response,
+        Err(error) => error,
+    };
+    println!("{response}");
+    anyhow::Ok(())
+}
+
 async fn zeta2_syntax_context(
     zeta2_args: Zeta2Args,
     syntax_args: Zeta2SyntaxArgs,
@@ -616,8 +759,14 @@ fn main() {
                     let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
                     serde_json::to_string_pretty(&context.body).map_err(|err| anyhow::anyhow!(err))
                 }
-                Command::Zeta2 { args, command } => match command {
+                Command::Zeta2 { command } => match command {
+                    Zeta2Command::Predict { example_path } => {
+                        let example = NamedExample::load(example_path).unwrap();
+                        zeta2_predict(example, &app_state, cx).await.unwrap();
+                        return;
+                    }
                     Zeta2Command::Syntax {
+                        args,
                         syntax_args,
                         command,
                     } => match command {
@@ -643,7 +792,7 @@ fn main() {
                             .await
                         }
                     },
-                    Zeta2Command::Llm { command } => match command {
+                    Zeta2Command::Llm { args, command } => match command {
                         Zeta2LlmCommand::Context { context_args } => {
                             zeta2_llm_context(args, context_args, &app_state, cx).await
                         }