@@ -8,6 +8,7 @@ use crate::example::{ExampleFormat, NamedExample};
use crate::syntax_retrieval_stats::retrieval_stats;
use ::serde::Serialize;
use ::util::paths::PathStyle;
+use ::util::rel_path::RelPath;
use anyhow::{Context as _, Result, anyhow};
use clap::{Args, Parser, Subcommand};
use cloud_llm_client::predict_edits_v3::{self, Excerpt};
@@ -21,7 +22,7 @@ use futures::channel::mpsc;
use gpui::{Application, AsyncApp, Entity, prelude::*};
use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point};
use language_model::LanguageModelRegistry;
-use project::{Project, Worktree};
+use project::{Project, ProjectPath, Worktree};
use reqwest_client::ReqwestClient;
use serde_json::json;
use std::io;
@@ -46,8 +47,6 @@ enum Command {
command: Zeta1Command,
},
Zeta2 {
- #[clap(flatten)]
- args: Zeta2Args,
#[command(subcommand)]
command: Zeta2Command,
},
@@ -69,15 +68,22 @@ enum Zeta1Command {
#[derive(Subcommand, Debug)]
enum Zeta2Command {
Syntax {
+ #[clap(flatten)]
+ args: Zeta2Args,
#[clap(flatten)]
syntax_args: Zeta2SyntaxArgs,
#[command(subcommand)]
command: Zeta2SyntaxCommand,
},
Llm {
+ #[clap(flatten)]
+ args: Zeta2Args,
#[command(subcommand)]
command: Zeta2LlmCommand,
},
+ Predict {
+ example_path: PathBuf,
+ },
}
#[derive(Subcommand, Debug)]
@@ -319,6 +325,143 @@ async fn load_context(
})
}
+async fn zeta2_predict(
+ example: NamedExample,
+ app_state: &Arc<ZetaCliAppState>,
+ cx: &mut AsyncApp,
+) -> Result<()> {
+ dbg!(&example);
+ let worktree_path = example.setup_worktree().await?;
+
+ cx.update(|cx| {
+ LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+ registry
+ .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
+ .unwrap()
+ .authenticate(cx)
+ })
+ })?
+ .await?;
+
+ app_state
+ .client
+ .sign_in_with_optional_connect(true, cx)
+ .await?;
+
+ let project = cx.update(|cx| {
+ Project::local(
+ app_state.client.clone(),
+ app_state.node_runtime.clone(),
+ app_state.user_store.clone(),
+ app_state.languages.clone(),
+ app_state.fs.clone(),
+ None,
+ cx,
+ )
+ })?;
+
+ let worktree = project
+ .update(cx, |project, cx| {
+ project.create_worktree(&worktree_path, true, cx)
+ })?
+ .await?;
+ worktree
+ .read_with(cx, |worktree, _cx| {
+ worktree.as_local().unwrap().scan_complete()
+ })?
+ .await;
+
+ let cursor_path = RelPath::new(&example.example.cursor_path, PathStyle::Posix)?.into_arc();
+
+ let cursor_buffer = project
+ .update(cx, |project, cx| {
+ project.open_buffer(
+ ProjectPath {
+ worktree_id: worktree.read(cx).id(),
+ path: cursor_path,
+ },
+ cx,
+ )
+ })?
+ .await?;
+
+ let cursor_offset_within_excerpt = example
+ .example
+ .cursor_position
+ .find(CURSOR_MARKER)
+ .ok_or_else(|| anyhow!("missing cursor marker"))?;
+ let mut cursor_excerpt = example.example.cursor_position.clone();
+ cursor_excerpt.replace_range(
+ cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
+ "",
+ );
+ let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
+ let text = buffer.text();
+
+ let mut matches = text.match_indices(&cursor_excerpt);
+ let Some((excerpt_offset, _)) = matches.next() else {
+ anyhow::bail!(
+ "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n"
+ );
+ };
+ assert!(matches.next().is_none());
+
+ Ok(excerpt_offset)
+ })??;
+ let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
+
+ let cursor_anchor =
+ cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
+
+ let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
+
+ let (prediction_task, mut debug_rx) = zeta.update(cx, |zeta, cx| {
+ let receiver = zeta.debug_info();
+ let prediction_task = zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx);
+ (prediction_task, receiver)
+ })?;
+
+ prediction_task.await.context("No prediction")?;
+ let mut response = None;
+
+ let mut excerpts_text = String::new();
+ while let Some(event) = debug_rx.next().await {
+ match event {
+ zeta2::ZetaDebugInfo::EditPredicted(request) => {
+ response = Some(request.response_rx.await?);
+ for included_file in request.request.included_files {
+ let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
+ write_codeblock(
+ &included_file.path,
+ included_file.excerpts.iter(),
+ if included_file.path == request.request.excerpt_path {
+ &insertions
+ } else {
+ &[]
+ },
+ included_file.max_row,
+ false,
+ &mut excerpts_text,
+ );
+ }
+ }
+ _ => {}
+ }
+ }
+
+ println!("## Excerpts\n");
+ println!("{excerpts_text}");
+
+ println!("## Prediction\n");
+ let response = response
+ .unwrap()
+ .map(|r| r.debug_info.unwrap().model_response.clone())
+ .unwrap_or_else(|s| s);
+ println!("{response}");
+
+ anyhow::Ok(())
+}
+
async fn zeta2_syntax_context(
zeta2_args: Zeta2Args,
syntax_args: Zeta2SyntaxArgs,
@@ -616,8 +759,14 @@ fn main() {
let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
serde_json::to_string_pretty(&context.body).map_err(|err| anyhow::anyhow!(err))
}
- Command::Zeta2 { args, command } => match command {
+ Command::Zeta2 { command } => match command {
+ Zeta2Command::Predict { example_path } => {
+ let example = NamedExample::load(example_path).unwrap();
+ zeta2_predict(example, &app_state, cx).await.unwrap();
+ return;
+ }
Zeta2Command::Syntax {
+ args,
syntax_args,
command,
} => match command {
@@ -643,7 +792,7 @@ fn main() {
.await
}
},
- Zeta2Command::Llm { command } => match command {
+ Zeta2Command::Llm { args, command } => match command {
Zeta2LlmCommand::Context { context_args } => {
zeta2_llm_context(args, context_args, &app_state, cx).await
}