zeta2 cli: Output raw request (#38876)

Agus Zubiaga , Bennet Bo Fenner , and Oleksiy Syvokon created

Release Notes:

- N/A

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>

Change summary

crates/zeta_cli/src/main.rs | 51 +++++++++++++++++++++++++++++---------
1 file changed, 38 insertions(+), 13 deletions(-)

Detailed changes

crates/zeta_cli/src/main.rs 🔗

@@ -2,7 +2,7 @@ mod headless;
 
 use anyhow::{Result, anyhow};
 use clap::{Args, Parser, Subcommand};
-use cloud_llm_client::predict_edits_v3::PromptFormat;
+use cloud_llm_client::predict_edits_v3;
 use edit_prediction_context::EditPredictionExcerptOptions;
 use futures::channel::mpsc;
 use futures::{FutureExt as _, StreamExt as _};
@@ -75,18 +75,35 @@ struct Zeta2Args {
     target_before_cursor_over_total_bytes: f32,
     #[arg(long, default_value_t = 1024)]
     max_diagnostic_bytes: usize,
-    #[arg(long, value_parser = parse_format)]
-    format: PromptFormat,
+    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
+    prompt_format: PromptFormat,
+    #[arg(long, value_enum, default_value_t = Default::default())]
+    output_format: OutputFormat,
 }
 
-fn parse_format(s: &str) -> Result<PromptFormat> {
-    match s {
-        "marked_excerpt" => Ok(PromptFormat::MarkedExcerpt),
-        "labeled_sections" => Ok(PromptFormat::LabeledSections),
-        _ => Err(anyhow!("Invalid format: {}", s)),
+#[derive(clap::ValueEnum, Default, Debug, Clone)]
+enum PromptFormat {
+    #[default]
+    MarkedExcerpt,
+    LabeledSections,
+}
+
+impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
+    fn into(self) -> predict_edits_v3::PromptFormat {
+        match self {
+            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
+            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
+        }
     }
 }
 
+#[derive(clap::ValueEnum, Default, Debug, Clone)]
+enum OutputFormat {
+    #[default]
+    Prompt,
+    Request,
+}
+
 #[derive(Debug, Clone)]
 enum FileOrStdin {
     File(PathBuf),
@@ -239,7 +256,7 @@ async fn get_context(
                         },
                         max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
                         max_prompt_bytes: zeta2_args.max_prompt_bytes,
-                        prompt_format: zeta2_args.format,
+                        prompt_format: zeta2_args.prompt_format.into(),
                     })
                 });
                 // TODO: Actually wait for indexing.
@@ -252,9 +269,17 @@ async fn get_context(
                             zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
                         })?
                         .await?;
-                    let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
-                    // TODO: Output the section label ranges
-                    anyhow::Ok(planned_prompt.to_prompt_string()?.0)
+                    match zeta2_args.output_format {
+                        OutputFormat::Prompt => {
+                            let planned_prompt =
+                                cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
+                            // TODO: Output the section label ranges
+                            anyhow::Ok(planned_prompt.to_prompt_string()?.0)
+                        }
+                        OutputFormat::Request => {
+                            anyhow::Ok(serde_json::to_string_pretty(&request)?)
+                        }
+                    }
                 })
             })?
             .await?,
@@ -469,7 +494,7 @@ fn main() {
                     println!("{}", output);
                     // TODO: Remove this once the 5 second delay is properly replaced.
                     if is_zeta2_context_command {
-                        eprintln!("Note that zeta2-context doesn't yet wait for indexing, instead waits 5 seconds.");
+                        eprintln!("Note that zeta_cli doesn't yet wait for indexing, instead waits 5 seconds.");
                     }
                     let _ = cx.update(|cx| cx.quit());
                 }