diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index fd924de4406497ab8ab8f7c520a63b28cced98d6..03bd5359cd01048c2edb5c3b8743916ddc3b4f2d 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -48,6 +48,8 @@ pub enum PromptFormat { #[default] MarkedExcerpt, LabeledSections, + /// Prompt format intended for use via zeta_cli + OnlySnippets, } impl PromptFormat { @@ -61,6 +63,7 @@ impl std::fmt::Display for PromptFormat { match self { PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"), PromptFormat::LabeledSections => write!(f, "Labeled Sections"), + PromptFormat::OnlySnippets => write!(f, "Only Snippets"), } } } diff --git a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs index cc5c8cb8b287e620e38910a6bc4408f67a5722aa..9c1b64013abd8ade6a951838ded00c36aafba347 100644 --- a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs +++ b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs @@ -54,6 +54,8 @@ pub fn system_prompt(format: PromptFormat) -> &'static str { match format { PromptFormat::MarkedExcerpt => MARKED_EXCERPT_SYSTEM_PROMPT, PromptFormat::LabeledSections => LABELED_SECTIONS_SYSTEM_PROMPT, + // only intended for use via zeta_cli + PromptFormat::OnlySnippets => "", } } @@ -343,6 +345,7 @@ impl<'a> PlannedPrompt<'a> { self.request.excerpt_range.start + self.request.cursor_offset, CURSOR_MARKER, )], + PromptFormat::OnlySnippets => vec![], }; let mut prompt = String::new(); @@ -432,12 +435,13 @@ impl<'a> PlannedPrompt<'a> { } writeln!(output, "```{}", file_path.display()).ok(); + let mut skipped_last_snippet = false; for (snippet, range) in disjoint_snippets { let section_index = section_ranges.len(); match self.request.prompt_format { - PromptFormat::MarkedExcerpt => { - if range.start > 0 { + PromptFormat::MarkedExcerpt | PromptFormat::OnlySnippets => { + if range.start > 0 && !skipped_last_snippet { output.push_str("…\n"); } } @@ -454,25 +458,38 @@ impl<'a> PlannedPrompt<'a> { } if is_excerpt_file { - excerpt_index = Some(section_index); - let mut last_offset = range.start; - let mut i = 0; - while i < excerpt_file_insertions.len() { - let (offset, insertion) = &excerpt_file_insertions[i]; - let found = *offset >= range.start && *offset <= range.end; - if found { - output.push_str( - &snippet.text[last_offset - range.start..offset - range.start], - ); - output.push_str(insertion); - last_offset = *offset; - excerpt_file_insertions.remove(i); - continue; + if self.request.prompt_format == PromptFormat::OnlySnippets { + if range.start >= self.request.excerpt_range.start + && range.end <= self.request.excerpt_range.end + { + skipped_last_snippet = true; + } else { + skipped_last_snippet = false; + output.push_str(snippet.text); } - i += 1; + } else { + let mut last_offset = range.start; + let mut i = 0; + while i < excerpt_file_insertions.len() { + let (offset, insertion) = &excerpt_file_insertions[i]; + let found = *offset >= range.start && *offset <= range.end; + if found { + excerpt_index = Some(section_index); + output.push_str( + &snippet.text[last_offset - range.start..offset - range.start], + ); + output.push_str(insertion); + last_offset = *offset; + excerpt_file_insertions.remove(i); + continue; + } + i += 1; + } + skipped_last_snippet = false; + output.push_str(&snippet.text[last_offset - range.start..]); } - output.push_str(&snippet.text[last_offset - range.start..]); } else { + skipped_last_snippet = false; output.push_str(snippet.text); } @@ -483,7 +500,11 @@ impl<'a> PlannedPrompt<'a> { } Ok(SectionLabels { - excerpt_index: excerpt_index.context("bug: no snippet found for excerpt")?, + // TODO: Clean this up + excerpt_index: match self.request.prompt_format { + PromptFormat::OnlySnippets => 0, + _ => excerpt_index.context("bug: no snippet found for excerpt")?, + }, section_ranges, }) } diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 47cbbd29d5edb9a0f4a36ec8a3cd5f4c3d7857b4..8895603f2984f4c411611980d60145d2bfebcbf9 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -15,6 +15,7 @@ use language_model::LlmApiToken; use project::{Project, ProjectPath, Worktree}; use release_channel::AppVersion; use reqwest_client::ReqwestClient; +use serde_json::json; use std::path::{Path, PathBuf}; use std::process::exit; use std::str::FromStr; @@ -86,6 +87,7 @@ enum PromptFormat { #[default] MarkedExcerpt, LabeledSections, + OnlySnippets, } impl Into for PromptFormat { @@ -93,6 +95,7 @@ impl Into for PromptFormat { match self { Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt, Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections, + Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets, } } } @@ -102,6 +105,7 @@ enum OutputFormat { #[default] Prompt, Request, + Both, } #[derive(Debug, Clone)] @@ -269,16 +273,18 @@ async fn get_context( zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx) })? .await?; + + let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?; + let prompt_string = planned_prompt.to_prompt_string()?.0; match zeta2_args.output_format { - OutputFormat::Prompt => { - let planned_prompt = - cloud_zeta2_prompt::PlannedPrompt::populate(&request)?; - // TODO: Output the section label ranges - anyhow::Ok(planned_prompt.to_prompt_string()?.0) - } + OutputFormat::Prompt => anyhow::Ok(prompt_string), OutputFormat::Request => { anyhow::Ok(serde_json::to_string_pretty(&request)?) } + OutputFormat::Both => anyhow::Ok(serde_json::to_string_pretty(&json!({ + "request": request, + "prompt": prompt_string, + }))?), } }) })?