From 60c546196a37d684a03bc19e830121c089a5f858 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Thu, 30 Oct 2025 18:41:09 -0300 Subject: [PATCH] zeta2: Expose llm-based context retrieval via zeta_cli (#41584) Release Notes: - N/A --------- Co-authored-by: Max Brunsfeld Co-authored-by: Oleksiy Syvokon --- .../src/cloud_zeta2_prompt.rs | 27 +- crates/zeta2/src/merge_excerpts.rs | 26 +- crates/zeta2/src/related_excerpts.rs | 66 +- crates/zeta2/src/zeta2.rs | 34 +- crates/zeta2_tools/src/zeta2_context_view.rs | 7 +- crates/zeta_cli/src/main.rs | 605 ++++++++++++------ ...val_stats.rs => syntax_retrieval_stats.rs} | 0 7 files changed, 509 insertions(+), 256 deletions(-) rename crates/zeta_cli/src/{retrieval_stats.rs => syntax_retrieval_stats.rs} (100%) diff --git a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs index 1c8b1caf80db28ef936aa9a747b4a163e183134f..a0df39b50eb6753397f5afd37aa30b71b853b9c5 100644 --- a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs +++ b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs @@ -182,8 +182,8 @@ pub fn build_prompt( } for related_file in &request.included_files { - writeln!(&mut prompt, "`````filename={}", related_file.path.display()).unwrap(); - write_excerpts( + write_codeblock( + &related_file.path, &related_file.excerpts, if related_file.path == request.excerpt_path { &insertions @@ -194,7 +194,6 @@ pub fn build_prompt( request.prompt_format == PromptFormat::NumLinesUniDiff, &mut prompt, ); - write!(&mut prompt, "`````\n\n").unwrap(); } } @@ -205,6 +204,25 @@ pub fn build_prompt( Ok((prompt, section_labels)) } +pub fn write_codeblock<'a>( + path: &Path, + excerpts: impl IntoIterator, + sorted_insertions: &[(Point, &str)], + file_line_count: Line, + include_line_numbers: bool, + output: &'a mut String, +) { + writeln!(output, "`````path={}", path.display()).unwrap(); + write_excerpts( + excerpts, + sorted_insertions, + file_line_count, + include_line_numbers, + 
output, + ); + write!(output, "`````\n\n").unwrap(); +} + pub fn write_excerpts<'a>( excerpts: impl IntoIterator, sorted_insertions: &[(Point, &str)], @@ -597,8 +615,7 @@ impl<'a> SyntaxBasedPrompt<'a> { disjoint_snippets.push(current_snippet); } - // TODO: remove filename=? - writeln!(output, "`````filename={}", file_path.display()).ok(); + writeln!(output, "`````path={}", file_path.display()).ok(); let mut skipped_last_snippet = false; for (snippet, range) in disjoint_snippets { let section_index = section_ranges.len(); diff --git a/crates/zeta2/src/merge_excerpts.rs b/crates/zeta2/src/merge_excerpts.rs index 4cb7ab6cf4d3b63e641087f0c22cf0f900f56adc..846d8034a8c2e88b8552dc8c9d48af6ccdc5efcf 100644 --- a/crates/zeta2/src/merge_excerpts.rs +++ b/crates/zeta2/src/merge_excerpts.rs @@ -1,4 +1,4 @@ -use cloud_llm_client::predict_edits_v3::{self, Excerpt}; +use cloud_llm_client::predict_edits_v3::Excerpt; use edit_prediction_context::Line; use language::{BufferSnapshot, Point}; use std::ops::Range; @@ -58,26 +58,12 @@ pub fn merge_excerpts( output } -pub fn write_merged_excerpts( - buffer: &BufferSnapshot, - sorted_line_ranges: impl IntoIterator>, - sorted_insertions: &[(predict_edits_v3::Point, &str)], - output: &mut String, -) { - cloud_zeta2_prompt::write_excerpts( - merge_excerpts(buffer, sorted_line_ranges).iter(), - sorted_insertions, - Line(buffer.max_point().row), - true, - output, - ); -} - #[cfg(test)] mod tests { use std::sync::Arc; use super::*; + use cloud_llm_client::predict_edits_v3; use gpui::{TestAppContext, prelude::*}; use indoc::indoc; use language::{Buffer, Language, LanguageConfig, LanguageMatcher, OffsetRangeExt}; @@ -168,7 +154,13 @@ mod tests { .collect(); let mut output = String::new(); - write_merged_excerpts(&buffer.snapshot(), ranges, &insertions, &mut output); + cloud_zeta2_prompt::write_excerpts( + merge_excerpts(&buffer.snapshot(), ranges).iter(), + &insertions, + Line(buffer.max_point().row), + true, + &mut output, + ); 
assert_eq!(output, expected_output); }); } diff --git a/crates/zeta2/src/related_excerpts.rs b/crates/zeta2/src/related_excerpts.rs index d8fff7e0201716be45451c302c4f83b667727bc2..dd27992274ae2b25ec07e2a47dc8a60b46f5f3f2 100644 --- a/crates/zeta2/src/related_excerpts.rs +++ b/crates/zeta2/src/related_excerpts.rs @@ -1,13 +1,13 @@ use std::{ - cmp::Reverse, collections::hash_map::Entry, fmt::Write, ops::Range, path::PathBuf, sync::Arc, - time::Instant, + cmp::Reverse, collections::hash_map::Entry, ops::Range, path::PathBuf, sync::Arc, time::Instant, }; use crate::{ - ZetaContextRetrievalDebugInfo, ZetaDebugInfo, ZetaSearchQueryDebugInfo, - merge_excerpts::write_merged_excerpts, + ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo, + ZetaSearchQueryDebugInfo, merge_excerpts::merge_excerpts, }; use anyhow::{Result, anyhow}; +use cloud_zeta2_prompt::write_codeblock; use collections::HashMap; use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions, Line}; use futures::{ @@ -22,8 +22,9 @@ use language::{ }; use language_model::{ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, - LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, + LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest, + LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, + LanguageModelToolUse, MessageContent, Role, }; use project::{ Project, WorktreeSettings, @@ -63,7 +64,7 @@ const SEARCH_PROMPT: &str = indoc! 
{r#" ## Current cursor context - `````filename={current_file_path} + `````path={current_file_path} {cursor_excerpt} ````` @@ -130,11 +131,13 @@ pub struct LlmContextOptions { pub excerpt: EditPredictionExcerptOptions, } -pub fn find_related_excerpts<'a>( +pub const MODEL_PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID; + +pub fn find_related_excerpts( buffer: Entity, cursor_position: Anchor, project: &Entity, - events: impl Iterator, + mut edit_history_unified_diff: String, options: &LlmContextOptions, debug_tx: Option>, cx: &App, @@ -144,23 +147,15 @@ pub fn find_related_excerpts<'a>( .read(cx) .available_models(cx) .find(|model| { - model.provider_id() == language_model::ANTHROPIC_PROVIDER_ID + model.provider_id() == MODEL_PROVIDER_ID && model.id() == LanguageModelId("claude-haiku-4-5-latest".into()) }) else { - return Task::ready(Err(anyhow!("could not find claude model"))); + return Task::ready(Err(anyhow!("could not find context model"))); }; - let mut edits_string = String::new(); - - for event in events { - if let Some(event) = event.to_request_event(cx) { - writeln!(&mut edits_string, "{event}").ok(); - } - } - - if edits_string.is_empty() { - edits_string.push_str("(No user edits yet)"); + if edit_history_unified_diff.is_empty() { + edit_history_unified_diff.push_str("(No user edits yet)"); } // TODO [zeta2] include breadcrumbs? 
@@ -178,10 +173,22 @@ pub fn find_related_excerpts<'a>( .unwrap_or_else(|| "untitled".to_string()); let prompt = SEARCH_PROMPT - .replace("{edits}", &edits_string) + .replace("{edits}", &edit_history_unified_diff) .replace("{current_file_path}", &current_file_path) .replace("{cursor_excerpt}", &cursor_excerpt.text(&snapshot).body); + if let Some(debug_tx) = &debug_tx { + debug_tx + .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted( + ZetaContextRetrievalStartedDebugInfo { + project: project.clone(), + timestamp: Instant::now(), + search_prompt: prompt.clone(), + }, + )) + .ok(); + } + let path_style = project.read(cx).path_style(cx); let exclude_matcher = { @@ -428,19 +435,14 @@ pub fn find_related_excerpts<'a>( .line_ranges .sort_unstable_by_key(|range| (range.start, Reverse(range.end))); - writeln!( - &mut merged_result, - "`````filename={}", - matched.full_path.display() - ) - .unwrap(); - write_merged_excerpts( - &matched.snapshot, - matched.line_ranges, + write_codeblock( + &matched.full_path, + merge_excerpts(&matched.snapshot, matched.line_ranges).iter(), &[], + Line(matched.snapshot.max_point().row), + true, &mut merged_result, ); - merged_result.push_str("`````\n\n"); result_buffers_by_path.insert( matched.full_path, diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index b6311f9d25dfc91c078f6614b344eb91cabd51eb..bff091b6f0cd5a37c19ee015f8a0383c8b138b40 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -28,6 +28,7 @@ use project::Project; use release_channel::AppVersion; use serde::de::DeserializeOwned; use std::collections::{VecDeque, hash_map}; +use std::fmt::Write; use std::ops::Range; use std::path::Path; use std::str::FromStr as _; @@ -38,10 +39,10 @@ use util::ResultExt as _; use util::rel_path::RelPathBuf; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; -mod merge_excerpts; +pub mod merge_excerpts; mod prediction; mod provider; -mod related_excerpts; +pub mod 
related_excerpts; use crate::merge_excerpts::merge_excerpts; use crate::prediction::EditPrediction; @@ -135,7 +136,7 @@ impl ContextMode { } pub enum ZetaDebugInfo { - ContextRetrievalStarted(ZetaContextRetrievalDebugInfo), + ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo), SearchQueriesGenerated(ZetaSearchQueryDebugInfo), SearchQueriesExecuted(ZetaContextRetrievalDebugInfo), SearchResultsFiltered(ZetaContextRetrievalDebugInfo), @@ -143,6 +144,12 @@ pub enum ZetaDebugInfo { EditPredicted(ZetaEditPredictionDebugInfo), } +pub struct ZetaContextRetrievalStartedDebugInfo { + pub project: Entity, + pub timestamp: Instant, + pub search_prompt: String, +} + pub struct ZetaContextRetrievalDebugInfo { pub project: Entity, pub timestamp: Instant, @@ -1086,17 +1093,6 @@ impl Zeta { zeta_project .refresh_context_task .get_or_insert(cx.spawn(async move |this, cx| { - if let Some(debug_tx) = &debug_tx { - debug_tx - .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted( - ZetaContextRetrievalDebugInfo { - project: project.clone(), - timestamp: Instant::now(), - }, - )) - .ok(); - } - let related_excerpts = this .update(cx, |this, cx| { let Some(zeta_project) = this.projects.get(&project.entity_id()) else { @@ -1107,11 +1103,19 @@ impl Zeta { return Task::ready(anyhow::Ok(HashMap::default())); }; + let mut edit_history_unified_diff = String::new(); + + for event in zeta_project.events.iter() { + if let Some(event) = event.to_request_event(cx) { + writeln!(&mut edit_history_unified_diff, "{event}").ok(); + } + } + find_related_excerpts( buffer.clone(), cursor_position, &project, - zeta_project.events.iter(), + edit_history_unified_diff, options, debug_tx, cx, diff --git a/crates/zeta2_tools/src/zeta2_context_view.rs b/crates/zeta2_tools/src/zeta2_context_view.rs index 0abca0fbf451955c285fe3a9df482c507dc4ff10..9532d77622645f80696d69ed92b0190e48f838c7 100644 --- a/crates/zeta2_tools/src/zeta2_context_view.rs +++ b/crates/zeta2_tools/src/zeta2_context_view.rs @@ -24,7 
+24,10 @@ use ui::{ v_flex, }; use workspace::{Item, ItemHandle as _}; -use zeta2::{Zeta, ZetaContextRetrievalDebugInfo, ZetaDebugInfo, ZetaSearchQueryDebugInfo}; +use zeta2::{ + Zeta, ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo, + ZetaSearchQueryDebugInfo, +}; pub struct Zeta2ContextView { empty_focus_handle: FocusHandle, @@ -130,7 +133,7 @@ impl Zeta2ContextView { fn handle_context_retrieval_started( &mut self, - info: ZetaContextRetrievalDebugInfo, + info: ZetaContextRetrievalStartedDebugInfo, window: &mut Window, cx: &mut Context, ) { diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index eea80898870d68a8ad361de43d4556438ed25444..7a6d4b26dc87cd9db7d40fe2745520ee5f574ea6 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -1,26 +1,29 @@ mod headless; -mod retrieval_stats; mod source_location; +mod syntax_retrieval_stats; mod util; -use crate::retrieval_stats::retrieval_stats; +use crate::syntax_retrieval_stats::retrieval_stats; +use ::serde::Serialize; use ::util::paths::PathStyle; -use anyhow::{Result, anyhow}; +use anyhow::{Context as _, Result, anyhow}; use clap::{Args, Parser, Subcommand}; -use cloud_llm_client::predict_edits_v3::{self}; +use cloud_llm_client::predict_edits_v3::{self, Excerpt}; +use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock}; use edit_prediction_context::{ - EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions, + EditPredictionContextOptions, EditPredictionExcerpt, EditPredictionExcerptOptions, + EditPredictionScoreOptions, Line, }; -use gpui::{Application, AsyncApp, prelude::*}; -use language::Bias; -use language_model::LlmApiToken; -use project::Project; -use release_channel::AppVersion; +use futures::StreamExt as _; +use futures::channel::mpsc; +use gpui::{Application, AsyncApp, Entity, prelude::*}; +use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point}; +use 
language_model::LanguageModelRegistry; +use project::{Project, Worktree}; use reqwest_client::ReqwestClient; use serde_json::json; use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc}; -use zeta::{PerformPredictEditsParams, Zeta}; -use zeta2::ContextMode; +use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery}; use crate::headless::ZetaCliAppState; use crate::source_location::SourceLocation; @@ -30,27 +33,52 @@ use crate::util::{open_buffer, open_buffer_with_language_server}; #[command(name = "zeta")] struct ZetaCliArgs { #[command(subcommand)] - command: Commands, + command: Command, } #[derive(Subcommand, Debug)] -enum Commands { - Context(ContextArgs), - Zeta2Context { +enum Command { + Zeta1 { + #[command(subcommand)] + command: Zeta1Command, + }, + Zeta2 { #[clap(flatten)] - zeta2_args: Zeta2Args, + args: Zeta2Args, + #[command(subcommand)] + command: Zeta2Command, + }, +} + +#[derive(Subcommand, Debug)] +enum Zeta1Command { + Context { #[clap(flatten)] context_args: ContextArgs, }, - Predict { - #[arg(long)] - predict_edits_body: Option, +} + +#[derive(Subcommand, Debug)] +enum Zeta2Command { + Syntax { #[clap(flatten)] - context_args: Option, + syntax_args: Zeta2SyntaxArgs, + #[command(subcommand)] + command: Zeta2SyntaxCommand, + }, + Llm { + #[command(subcommand)] + command: Zeta2LlmCommand, }, - RetrievalStats { +} + +#[derive(Subcommand, Debug)] +enum Zeta2SyntaxCommand { + Context { #[clap(flatten)] - zeta2_args: Zeta2Args, + context_args: ContextArgs, + }, + Stats { #[arg(long)] worktree: PathBuf, #[arg(long)] @@ -62,6 +90,14 @@ enum Commands { }, } +#[derive(Subcommand, Debug)] +enum Zeta2LlmCommand { + Context { + #[clap(flatten)] + context_args: ContextArgs, + }, +} + #[derive(Debug, Args)] #[group(requires = "worktree")] struct ContextArgs { @@ -72,7 +108,7 @@ struct ContextArgs { #[arg(long)] use_language_server: bool, #[arg(long)] - events: Option, + edit_history: Option, } #[derive(Debug, Args)] @@ -93,12 
+129,42 @@ struct Zeta2Args { output_format: OutputFormat, #[arg(long, default_value_t = 42)] file_indexing_parallelism: usize, +} + +#[derive(Debug, Args)] +struct Zeta2SyntaxArgs { #[arg(long, default_value_t = false)] disable_imports_gathering: bool, #[arg(long, default_value_t = u8::MAX)] max_retrieved_definitions: u8, } +fn syntax_args_to_options( + zeta2_args: &Zeta2Args, + syntax_args: &Zeta2SyntaxArgs, + omit_excerpt_overlaps: bool, +) -> zeta2::ZetaOptions { + zeta2::ZetaOptions { + context: ContextMode::Syntax(EditPredictionContextOptions { + max_retrieved_declarations: syntax_args.max_retrieved_definitions, + use_imports: !syntax_args.disable_imports_gathering, + excerpt: EditPredictionExcerptOptions { + max_bytes: zeta2_args.max_excerpt_bytes, + min_bytes: zeta2_args.min_excerpt_bytes, + target_before_cursor_over_total_bytes: zeta2_args + .target_before_cursor_over_total_bytes, + }, + score: EditPredictionScoreOptions { + omit_excerpt_overlaps, + }, + }), + max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes, + max_prompt_bytes: zeta2_args.max_prompt_bytes, + prompt_format: zeta2_args.prompt_format.clone().into(), + file_indexing_parallelism: zeta2_args.file_indexing_parallelism, + } +} + #[derive(clap::ValueEnum, Default, Debug, Clone)] enum PromptFormat { MarkedExcerpt, @@ -153,22 +219,25 @@ impl FromStr for FileOrStdin { } } -enum GetContextOutput { - Zeta1(zeta::GatherContextOutput), - Zeta2(String), +struct LoadedContext { + full_path_str: String, + snapshot: BufferSnapshot, + clipped_cursor: Point, + worktree: Entity, + project: Entity, + buffer: Entity, } -async fn get_context( - zeta2_args: Option, - args: ContextArgs, +async fn load_context( + args: &ContextArgs, app_state: &Arc, cx: &mut AsyncApp, -) -> Result { +) -> Result { let ContextArgs { worktree: worktree_path, cursor, use_language_server, - events, + .. 
} = args; let worktree_path = worktree_path.canonicalize()?; @@ -192,7 +261,7 @@ async fn get_context( .await?; let mut ready_languages = HashSet::default(); - let (_lsp_open_handle, buffer) = if use_language_server { + let (_lsp_open_handle, buffer) = if *use_language_server { let (lsp_open_handle, _, buffer) = open_buffer_with_language_server( project.clone(), worktree.clone(), @@ -232,95 +301,294 @@ async fn get_context( } } - let events = match events { + Ok(LoadedContext { + full_path_str, + snapshot, + clipped_cursor, + worktree, + project, + buffer, + }) +} + +async fn zeta2_syntax_context( + zeta2_args: Zeta2Args, + syntax_args: Zeta2SyntaxArgs, + args: ContextArgs, + app_state: &Arc, + cx: &mut AsyncApp, +) -> Result { + let LoadedContext { + worktree, + project, + buffer, + clipped_cursor, + .. + } = load_context(&args, app_state, cx).await?; + + // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for + // the whole worktree. + worktree + .read_with(cx, |worktree, _cx| { + worktree.as_local().unwrap().scan_complete() + })? + .await; + let output = cx + .update(|cx| { + let zeta = cx.new(|cx| { + zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx) + }); + let indexing_done_task = zeta.update(cx, |zeta, cx| { + zeta.set_options(syntax_args_to_options(&zeta2_args, &syntax_args, true)); + zeta.register_buffer(&buffer, &project, cx); + zeta.wait_for_initial_indexing(&project, cx) + }); + cx.spawn(async move |cx| { + indexing_done_task.await?; + let request = zeta + .update(cx, |zeta, cx| { + let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor); + zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx) + })? 
+ .await?; + + let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?; + + match zeta2_args.output_format { + OutputFormat::Prompt => anyhow::Ok(prompt_string), + OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?), + OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({ + "request": request, + "prompt": prompt_string, + "section_labels": section_labels, + }))?), + } + }) + })? + .await?; + + Ok(output) +} + +async fn zeta2_llm_context( + zeta2_args: Zeta2Args, + context_args: ContextArgs, + app_state: &Arc, + cx: &mut AsyncApp, +) -> Result { + let LoadedContext { + buffer, + clipped_cursor, + snapshot: cursor_snapshot, + project, + .. + } = load_context(&context_args, app_state, cx).await?; + + let cursor_position = cursor_snapshot.anchor_after(clipped_cursor); + + cx.update(|cx| { + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry + .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID) + .unwrap() + .authenticate(cx) + }) + })? + .await?; + + let edit_history_unified_diff = match context_args.edit_history { Some(events) => events.read_to_string().await?, None => String::new(), }; - if let Some(zeta2_args) = zeta2_args { - // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for - // the whole worktree. - worktree - .read_with(cx, |worktree, _cx| { - worktree.as_local().unwrap().scan_complete() - })? 
- .await; - let output = cx - .update(|cx| { - let zeta = cx.new(|cx| { - zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx) - }); - let indexing_done_task = zeta.update(cx, |zeta, cx| { - zeta.set_options(zeta2_args.to_options(true)); - zeta.register_buffer(&buffer, &project, cx); - zeta.wait_for_initial_indexing(&project, cx) - }); - cx.spawn(async move |cx| { - indexing_done_task.await?; - let request = zeta - .update(cx, |zeta, cx| { - let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor); - zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx) - })? - .await?; - - let (prompt_string, section_labels) = - cloud_zeta2_prompt::build_prompt(&request)?; - - match zeta2_args.output_format { - OutputFormat::Prompt => anyhow::Ok(prompt_string), - OutputFormat::Request => { - anyhow::Ok(serde_json::to_string_pretty(&request)?) - } - OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({ - "request": request, - "prompt": prompt_string, - "section_labels": section_labels, - }))?), - } - }) - })? - .await?; - Ok(GetContextOutput::Zeta2(output)) - } else { - let prompt_for_events = move || (events, 0); - Ok(GetContextOutput::Zeta1( - cx.update(|cx| { - zeta::gather_context( - full_path_str, - &snapshot, - clipped_cursor, - prompt_for_events, - cx, - ) - })? 
- .await?, - )) - } -} + let (debug_tx, mut debug_rx) = mpsc::unbounded(); -impl Zeta2Args { - fn to_options(&self, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions { - zeta2::ZetaOptions { - context: ContextMode::Syntax(EditPredictionContextOptions { - max_retrieved_declarations: self.max_retrieved_definitions, - use_imports: !self.disable_imports_gathering, - excerpt: EditPredictionExcerptOptions { - max_bytes: self.max_excerpt_bytes, - min_bytes: self.min_excerpt_bytes, - target_before_cursor_over_total_bytes: self - .target_before_cursor_over_total_bytes, - }, - score: EditPredictionScoreOptions { - omit_excerpt_overlaps, + let excerpt_options = EditPredictionExcerptOptions { + max_bytes: zeta2_args.max_excerpt_bytes, + min_bytes: zeta2_args.min_excerpt_bytes, + target_before_cursor_over_total_bytes: zeta2_args.target_before_cursor_over_total_bytes, + }; + + let related_excerpts = cx + .update(|cx| { + zeta2::related_excerpts::find_related_excerpts( + buffer, + cursor_position, + &project, + edit_history_unified_diff, + &LlmContextOptions { + excerpt: excerpt_options.clone(), }, - }), - max_diagnostic_bytes: self.max_diagnostic_bytes, - max_prompt_bytes: self.max_prompt_bytes, - prompt_format: self.prompt_format.clone().into(), - file_indexing_parallelism: self.file_indexing_parallelism, + Some(debug_tx), + cx, + ) + })? 
+ .await?; + + let cursor_excerpt = EditPredictionExcerpt::select_from_buffer( + clipped_cursor, + &cursor_snapshot, + &excerpt_options, + None, + ) + .context("line didn't fit")?; + + #[derive(Serialize)] + struct Output { + excerpts: Vec, + formatted_excerpts: String, + meta: OutputMeta, + } + + #[derive(Default, Serialize)] + struct OutputMeta { + search_prompt: String, + search_queries: Vec, + } + + #[derive(Serialize)] + struct OutputExcerpt { + path: PathBuf, + #[serde(flatten)] + excerpt: Excerpt, + } + + let mut meta = OutputMeta::default(); + + while let Some(debug_info) = debug_rx.next().await { + match debug_info { + zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => { + meta.search_prompt = info.search_prompt; + } + zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => { + meta.search_queries = info.queries + } + _ => {} } } + + cx.update(|cx| { + let mut excerpts = Vec::new(); + let mut formatted_excerpts = String::new(); + + let cursor_insertions = [( + predict_edits_v3::Point { + line: Line(clipped_cursor.row), + column: clipped_cursor.column, + }, + CURSOR_MARKER, + )]; + + let mut cursor_excerpt_added = false; + + for (buffer, ranges) in related_excerpts { + let excerpt_snapshot = buffer.read(cx).snapshot(); + + let mut line_ranges = ranges + .into_iter() + .map(|range| { + let point_range = range.to_point(&excerpt_snapshot); + Line(point_range.start.row)..Line(point_range.end.row) + }) + .collect::>(); + + let Some(file) = excerpt_snapshot.file() else { + continue; + }; + let path = file.full_path(cx); + + let is_cursor_file = path == cursor_snapshot.file().unwrap().full_path(cx); + if is_cursor_file { + let insertion_ix = line_ranges + .binary_search_by(|probe| { + probe + .start + .cmp(&cursor_excerpt.line_range.start) + .then(cursor_excerpt.line_range.end.cmp(&probe.end)) + }) + .unwrap_or_else(|ix| ix); + line_ranges.insert(insertion_ix, cursor_excerpt.line_range.clone()); + cursor_excerpt_added = true; + } + + let merged_excerpts = + 
zeta2::merge_excerpts::merge_excerpts(&excerpt_snapshot, line_ranges) + .into_iter() + .map(|excerpt| OutputExcerpt { + path: path.clone(), + excerpt, + }); + + let excerpt_start_ix = excerpts.len(); + excerpts.extend(merged_excerpts); + + write_codeblock( + &path, + excerpts[excerpt_start_ix..].iter().map(|e| &e.excerpt), + if is_cursor_file { + &cursor_insertions + } else { + &[] + }, + Line(excerpt_snapshot.max_point().row), + true, + &mut formatted_excerpts, + ); + } + + if !cursor_excerpt_added { + write_codeblock( + &cursor_snapshot.file().unwrap().full_path(cx), + &[Excerpt { + start_line: cursor_excerpt.line_range.start, + text: cursor_excerpt.text(&cursor_snapshot).body.into(), + }], + &cursor_insertions, + Line(cursor_snapshot.max_point().row), + true, + &mut formatted_excerpts, + ); + } + + let output = Output { + excerpts, + formatted_excerpts, + meta, + }; + + Ok(serde_json::to_string_pretty(&output)?) + }) + .unwrap() +} + +async fn zeta1_context( + args: ContextArgs, + app_state: &Arc, + cx: &mut AsyncApp, +) -> Result { + let LoadedContext { + full_path_str, + snapshot, + clipped_cursor, + .. + } = load_context(&args, app_state, cx).await?; + + let events = match args.edit_history { + Some(events) => events.read_to_string().await?, + None => String::new(), + }; + + let prompt_for_events = move || (events, 0); + cx.update(|cx| { + zeta::gather_context( + full_path_str, + &snapshot, + clipped_cursor, + prompt_for_events, + cx, + ) + })? + .await } fn main() { @@ -334,80 +602,47 @@ fn main() { let app_state = Arc::new(headless::init(cx)); cx.spawn(async move |cx| { let result = match args.command { - Commands::Zeta2Context { - zeta2_args, - context_args, - } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await { - Ok(GetContextOutput::Zeta1 { .. 
}) => unreachable!(), - Ok(GetContextOutput::Zeta2(output)) => Ok(output), - Err(err) => Err(err), - }, - Commands::Context(context_args) => { - match get_context(None, context_args, &app_state, cx).await { - Ok(GetContextOutput::Zeta1(output)) => { - Ok(serde_json::to_string_pretty(&output.body).unwrap()) - } - Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(), - Err(err) => Err(err), - } - } - Commands::Predict { - predict_edits_body, - context_args, - } => { - cx.spawn(async move |cx| { - let app_version = cx.update(|cx| AppVersion::global(cx))?; - app_state.client.sign_in(true, cx).await?; - let llm_token = LlmApiToken::default(); - llm_token.refresh(&app_state.client).await?; - - let predict_edits_body = - if let Some(predict_edits_body) = predict_edits_body { - serde_json::from_str(&predict_edits_body.read_to_string().await?)? - } else if let Some(context_args) = context_args { - match get_context(None, context_args, &app_state, cx).await? { - GetContextOutput::Zeta1(output) => output.body, - GetContextOutput::Zeta2 { .. } => unreachable!(), - } - } else { - return Err(anyhow!( - "Expected either --predict-edits-body-file \ - or the required args of the `context` command." 
- )); - }; - - let (response, _usage) = - Zeta::perform_predict_edits(PerformPredictEditsParams { - client: app_state.client.clone(), - llm_token, - app_version, - body: predict_edits_body, - }) - .await?; - - Ok(response.output_excerpt) - }) - .await - } - Commands::RetrievalStats { - zeta2_args, - worktree, - extension, - limit, - skip, + Command::Zeta1 { + command: Zeta1Command::Context { context_args }, } => { - retrieval_stats( - worktree, - app_state, - extension, - limit, - skip, - (&zeta2_args).to_options(false), - cx, - ) - .await + let context = zeta1_context(context_args, &app_state, cx).await.unwrap(); + serde_json::to_string_pretty(&context.body).map_err(|err| anyhow::anyhow!(err)) } + Command::Zeta2 { args, command } => match command { + Zeta2Command::Syntax { + syntax_args, + command, + } => match command { + Zeta2SyntaxCommand::Context { context_args } => { + zeta2_syntax_context(args, syntax_args, context_args, &app_state, cx) + .await + } + Zeta2SyntaxCommand::Stats { + worktree, + extension, + limit, + skip, + } => { + retrieval_stats( + worktree, + app_state, + extension, + limit, + skip, + syntax_args_to_options(&args, &syntax_args, false), + cx, + ) + .await + } + }, + Zeta2Command::Llm { command } => match command { + Zeta2LlmCommand::Context { context_args } => { + zeta2_llm_context(args, context_args, &app_state, cx).await + } + }, + }, }; + match result { Ok(output) => { println!("{}", output); diff --git a/crates/zeta_cli/src/retrieval_stats.rs b/crates/zeta_cli/src/syntax_retrieval_stats.rs similarity index 100% rename from crates/zeta_cli/src/retrieval_stats.rs rename to crates/zeta_cli/src/syntax_retrieval_stats.rs