diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 4460d660055bb5647ea5ef8f87d049d9c115b308..6c36927cd17ed21ed84989d5d9a3ef13828dc142 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -2,7 +2,7 @@ mod headless; use anyhow::{Result, anyhow}; use clap::{Args, Parser, Subcommand}; -use cloud_llm_client::predict_edits_v3::PromptFormat; +use cloud_llm_client::predict_edits_v3; use edit_prediction_context::EditPredictionExcerptOptions; use futures::channel::mpsc; use futures::{FutureExt as _, StreamExt as _}; @@ -75,18 +75,35 @@ struct Zeta2Args { target_before_cursor_over_total_bytes: f32, #[arg(long, default_value_t = 1024)] max_diagnostic_bytes: usize, - #[arg(long, value_parser = parse_format)] - format: PromptFormat, + #[arg(long, value_enum, default_value_t = PromptFormat::default())] + prompt_format: PromptFormat, + #[arg(long, value_enum, default_value_t = Default::default())] + output_format: OutputFormat, } -fn parse_format(s: &str) -> Result { - match s { - "marked_excerpt" => Ok(PromptFormat::MarkedExcerpt), - "labeled_sections" => Ok(PromptFormat::LabeledSections), - _ => Err(anyhow!("Invalid format: {}", s)), +#[derive(clap::ValueEnum, Default, Debug, Clone)] +enum PromptFormat { + #[default] + MarkedExcerpt, + LabeledSections, +} + +impl Into for PromptFormat { + fn into(self) -> predict_edits_v3::PromptFormat { + match self { + Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt, + Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections, + } } } +#[derive(clap::ValueEnum, Default, Debug, Clone)] +enum OutputFormat { + #[default] + Prompt, + Request, +} + #[derive(Debug, Clone)] enum FileOrStdin { File(PathBuf), @@ -239,7 +256,7 @@ async fn get_context( }, max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes, max_prompt_bytes: zeta2_args.max_prompt_bytes, - prompt_format: zeta2_args.format, + prompt_format: zeta2_args.prompt_format.into(), }) }); // TODO: Actually wait for indexing. @@ -252,9 +269,17 @@ async fn get_context( zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx) })? .await?; - let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?; - // TODO: Output the section label ranges - anyhow::Ok(planned_prompt.to_prompt_string()?.0) + match zeta2_args.output_format { + OutputFormat::Prompt => { + let planned_prompt = + cloud_zeta2_prompt::PlannedPrompt::populate(&request)?; + // TODO: Output the section label ranges + anyhow::Ok(planned_prompt.to_prompt_string()?.0) + } + OutputFormat::Request => { + anyhow::Ok(serde_json::to_string_pretty(&request)?) + } + } }) })? .await?, @@ -469,7 +494,7 @@ fn main() { println!("{}", output); // TODO: Remove this once the 5 second delay is properly replaced. if is_zeta2_context_command { - eprintln!("Note that zeta2-context doesn't yet wait for indexing, instead waits 5 seconds."); + eprintln!("Note that zeta_cli doesn't yet wait for indexing, instead waits 5 seconds."); } let _ = cx.update(|cx| cx.quit()); }