diff --git a/crates/zeta2/src/related_excerpts.rs b/crates/zeta2/src/related_excerpts.rs index 44388251e32678ff8d1b3ce594ab35996b235759..f1721020d000ec9b7ec308eaa3bac4951c45c3f8 100644 --- a/crates/zeta2/src/related_excerpts.rs +++ b/crates/zeta2/src/related_excerpts.rs @@ -149,6 +149,9 @@ pub fn find_related_excerpts( .find(|model| { model.provider_id() == MODEL_PROVIDER_ID && model.id() == LanguageModelId("claude-haiku-4-5-latest".into()) + // model.provider_id() == LanguageModelProviderId::new("zeta-ctx-qwen-30b") + // model.provider_id() == LanguageModelProviderId::new("ollama") + // && model.id() == LanguageModelId("gpt-oss:20b".into()) }) else { return Task::ready(Err(anyhow!("could not find context model"))); diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index bff091b6f0cd5a37c19ee015f8a0383c8b138b40..65bd16ef598bafc4f92329a4699cb513d2220bc0 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -624,7 +624,7 @@ impl Zeta { }) } - fn request_prediction( + pub fn request_prediction( &mut self, project: &Entity, buffer: &Entity, diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 8f19287744697e9f0d2ffd520be8a814790b8345..00963e3bf57b569826a46e596e0606d879d56703 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -8,6 +8,7 @@ use crate::example::{ExampleFormat, NamedExample}; use crate::syntax_retrieval_stats::retrieval_stats; use ::serde::Serialize; use ::util::paths::PathStyle; +use ::util::rel_path::RelPath; use anyhow::{Context as _, Result, anyhow}; use clap::{Args, Parser, Subcommand}; use cloud_llm_client::predict_edits_v3::{self, Excerpt}; @@ -21,7 +22,7 @@ use futures::channel::mpsc; use gpui::{Application, AsyncApp, Entity, prelude::*}; use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point}; use language_model::LanguageModelRegistry; -use project::{Project, Worktree}; +use project::{Project, ProjectPath, Worktree}; use reqwest_client::ReqwestClient; use serde_json::json; use std::io; @@ -46,8 +47,6 @@ enum Command { command: Zeta1Command, }, Zeta2 { - #[clap(flatten)] - args: Zeta2Args, #[command(subcommand)] command: Zeta2Command, }, @@ -69,15 +68,22 @@ enum Zeta1Command { #[derive(Subcommand, Debug)] enum Zeta2Command { Syntax { + #[clap(flatten)] + args: Zeta2Args, #[clap(flatten)] syntax_args: Zeta2SyntaxArgs, #[command(subcommand)] command: Zeta2SyntaxCommand, }, Llm { + #[clap(flatten)] + args: Zeta2Args, #[command(subcommand)] command: Zeta2LlmCommand, }, + Predict { + example_path: PathBuf, + }, } #[derive(Subcommand, Debug)] @@ -319,6 +325,143 @@ async fn load_context( }) } +async fn zeta2_predict( + example: NamedExample, + app_state: &Arc, + cx: &mut AsyncApp, +) -> Result<()> { + dbg!(&example); + let worktree_path = example.setup_worktree().await?; + + cx.update(|cx| { + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry + .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID) + .unwrap() + .authenticate(cx) + }) + })? + .await?; + + app_state + .client + .sign_in_with_optional_connect(true, cx) + .await?; + + let project = cx.update(|cx| { + Project::local( + app_state.client.clone(), + app_state.node_runtime.clone(), + app_state.user_store.clone(), + app_state.languages.clone(), + app_state.fs.clone(), + None, + cx, + ) + })?; + + let worktree = project + .update(cx, |project, cx| { + project.create_worktree(&worktree_path, true, cx) + })? + .await?; + worktree + .read_with(cx, |worktree, _cx| { + worktree.as_local().unwrap().scan_complete() + })? + .await; + + let cursor_path = RelPath::new(&example.example.cursor_path, PathStyle::Posix)?.into_arc(); + + let cursor_buffer = project + .update(cx, |project, cx| { + project.open_buffer( + ProjectPath { + worktree_id: worktree.read(cx).id(), + path: cursor_path, + }, + cx, + ) + })? + .await?; + + let cursor_offset_within_excerpt = example + .example + .cursor_position + .find(CURSOR_MARKER) + .ok_or_else(|| anyhow!("missing cursor marker"))?; + let mut cursor_excerpt = example.example.cursor_position.clone(); + cursor_excerpt.replace_range( + cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()), + "", + ); + let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| { + let text = buffer.text(); + + let mut matches = text.match_indices(&cursor_excerpt); + let Some((excerpt_offset, _)) = matches.next() else { + anyhow::bail!( + "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n" + ); + }; + assert!(matches.next().is_none()); + + Ok(excerpt_offset) + })??; + let cursor_offset = excerpt_offset + cursor_offset_within_excerpt; + + let cursor_anchor = + cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?; + + let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?; + + let (prediction_task, mut debug_rx) = zeta.update(cx, |zeta, cx| { + let receiver = zeta.debug_info(); + let prediction_task = zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx); + (prediction_task, receiver) + })?; + + prediction_task.await.context("No prediction")?; + let mut response = None; + + let mut excerpts_text = String::new(); + while let Some(event) = debug_rx.next().await { + match event { + zeta2::ZetaDebugInfo::EditPredicted(request) => { + response = Some(request.response_rx.await?); + for included_file in request.request.included_files { + let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)]; + write_codeblock( + &included_file.path, + included_file.excerpts.iter(), + if included_file.path == request.request.excerpt_path { + &insertions + } else { + &[] + }, + included_file.max_row, + false, + &mut excerpts_text, + ); + } + } + _ => {} + } + } + + println!("## Excerpts\n"); + println!("{excerpts_text}"); + + println!("## Prediction\n"); + let response = response + .unwrap() + .map(|r| r.debug_info.unwrap().model_response.clone()) + .unwrap_or_else(|s| s); + println!("{response}"); + + anyhow::Ok(()) +} + async fn zeta2_syntax_context( zeta2_args: Zeta2Args, syntax_args: Zeta2SyntaxArgs, @@ -616,8 +759,14 @@ fn main() { let context = zeta1_context(context_args, &app_state, cx).await.unwrap(); serde_json::to_string_pretty(&context.body).map_err(|err| anyhow::anyhow!(err)) } - Command::Zeta2 { args, command } => match command { + Command::Zeta2 { command } => match command { + Zeta2Command::Predict { example_path } => { + let example = NamedExample::load(example_path).unwrap(); + zeta2_predict(example, &app_state, cx).await.unwrap(); + return; + } Zeta2Command::Syntax { + args, syntax_args, command, } => match command { @@ -643,7 +792,7 @@ fn main() { .await } }, - Zeta2Command::Llm { command } => match command { + Zeta2Command::Llm { args, command } => match command { Zeta2LlmCommand::Context { context_args } => { zeta2_llm_context(args, context_args, &app_state, cx).await }