main.rs

  1mod evaluate;
  2mod example;
  3mod headless;
  4mod metrics;
  5mod paths;
  6mod predict;
  7mod source_location;
  8mod syntax_retrieval_stats;
  9mod util;
 10
 11use crate::{
 12    evaluate::run_evaluate,
 13    example::{ExampleFormat, NamedExample},
 14    headless::ZetaCliAppState,
 15    predict::run_predict,
 16    source_location::SourceLocation,
 17    syntax_retrieval_stats::retrieval_stats,
 18    util::{open_buffer, open_buffer_with_language_server},
 19};
 20use ::util::paths::PathStyle;
 21use anyhow::{Result, anyhow};
 22use clap::{Args, Parser, Subcommand, ValueEnum};
 23use cloud_llm_client::predict_edits_v3;
 24use edit_prediction_context::{
 25    EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
 26};
 27use gpui::{Application, AsyncApp, Entity, prelude::*};
 28use language::{Bias, Buffer, BufferSnapshot, Point};
 29use metrics::delta_chr_f;
 30use project::{Project, Worktree};
 31use reqwest_client::ReqwestClient;
 32use serde_json::json;
 33use std::io::{self};
 34use std::time::Duration;
 35use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
 36use zeta::ContextMode;
 37use zeta::udiff::DiffLine;
 38
 39#[derive(Parser, Debug)]
 40#[command(name = "zeta")]
 41struct ZetaCliArgs {
 42    #[arg(long, default_value_t = false)]
 43    printenv: bool,
 44    #[command(subcommand)]
 45    command: Option<Command>,
 46}
 47
 48#[derive(Subcommand, Debug)]
 49enum Command {
 50    Context(ContextArgs),
 51    ContextStats(ContextStatsArgs),
 52    Predict(PredictArguments),
 53    Eval(EvaluateArguments),
 54    ConvertExample {
 55        path: PathBuf,
 56        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
 57        output_format: ExampleFormat,
 58    },
 59    Score {
 60        golden_patch: PathBuf,
 61        actual_patch: PathBuf,
 62    },
 63    Clean,
 64}
 65
 66#[derive(Debug, Args)]
 67struct ContextStatsArgs {
 68    #[arg(long)]
 69    worktree: PathBuf,
 70    #[arg(long)]
 71    extension: Option<String>,
 72    #[arg(long)]
 73    limit: Option<usize>,
 74    #[arg(long)]
 75    skip: Option<usize>,
 76    #[clap(flatten)]
 77    zeta2_args: Zeta2Args,
 78}
 79
 80#[derive(Debug, Args)]
 81struct ContextArgs {
 82    #[arg(long)]
 83    provider: ContextProvider,
 84    #[arg(long)]
 85    worktree: PathBuf,
 86    #[arg(long)]
 87    cursor: SourceLocation,
 88    #[arg(long)]
 89    use_language_server: bool,
 90    #[arg(long)]
 91    edit_history: Option<FileOrStdin>,
 92    #[clap(flatten)]
 93    zeta2_args: Zeta2Args,
 94}
 95
 96#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
 97enum ContextProvider {
 98    Zeta1,
 99    #[default]
100    Syntax,
101}
102
103#[derive(Clone, Debug, Args)]
104struct Zeta2Args {
105    #[arg(long, default_value_t = 8192)]
106    max_prompt_bytes: usize,
107    #[arg(long, default_value_t = 2048)]
108    max_excerpt_bytes: usize,
109    #[arg(long, default_value_t = 1024)]
110    min_excerpt_bytes: usize,
111    #[arg(long, default_value_t = 0.66)]
112    target_before_cursor_over_total_bytes: f32,
113    #[arg(long, default_value_t = 1024)]
114    max_diagnostic_bytes: usize,
115    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
116    prompt_format: PromptFormat,
117    #[arg(long, value_enum, default_value_t = Default::default())]
118    output_format: OutputFormat,
119    #[arg(long, default_value_t = 42)]
120    file_indexing_parallelism: usize,
121    #[arg(long, default_value_t = false)]
122    disable_imports_gathering: bool,
123    #[arg(long, default_value_t = u8::MAX)]
124    max_retrieved_definitions: u8,
125}
126
127#[derive(Debug, Args)]
128pub struct PredictArguments {
129    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
130    format: PredictionsOutputFormat,
131    example_path: PathBuf,
132    #[clap(flatten)]
133    options: PredictionOptions,
134}
135
136#[derive(Clone, Debug, Args)]
137pub struct PredictionOptions {
138    #[clap(flatten)]
139    zeta2: Zeta2Args,
140    #[clap(long)]
141    provider: PredictionProvider,
142    #[clap(long, value_enum, default_value_t = CacheMode::default())]
143    cache: CacheMode,
144}
145
146#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
147pub enum CacheMode {
148    /// Use cached LLM requests and responses, except when multiple repetitions are requested
149    #[default]
150    Auto,
151    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
152    #[value(alias = "request")]
153    Requests,
154    /// Ignore existing cache entries for both LLM and search.
155    Skip,
156    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
157    /// Useful for reproducing results and fixing bugs outside of search queries
158    Force,
159}
160
161impl CacheMode {
162    fn use_cached_llm_responses(&self) -> bool {
163        self.assert_not_auto();
164        matches!(self, CacheMode::Requests | CacheMode::Force)
165    }
166
167    fn use_cached_search_results(&self) -> bool {
168        self.assert_not_auto();
169        matches!(self, CacheMode::Force)
170    }
171
172    fn assert_not_auto(&self) {
173        assert_ne!(
174            *self,
175            CacheMode::Auto,
176            "Cache mode should not be auto at this point!"
177        );
178    }
179}
180
181#[derive(clap::ValueEnum, Debug, Clone)]
182pub enum PredictionsOutputFormat {
183    Json,
184    Md,
185    Diff,
186}
187
188#[derive(Debug, Args)]
189pub struct EvaluateArguments {
190    example_paths: Vec<PathBuf>,
191    #[clap(flatten)]
192    options: PredictionOptions,
193    #[clap(short, long, default_value_t = 1, alias = "repeat")]
194    repetitions: u16,
195    #[arg(long)]
196    skip_prediction: bool,
197}
198
199#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
200enum PredictionProvider {
201    Zeta1,
202    #[default]
203    Zeta2,
204    Sweep,
205}
206
207fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta::ZetaOptions {
208    zeta::ZetaOptions {
209        context: ContextMode::Syntax(EditPredictionContextOptions {
210            max_retrieved_declarations: args.max_retrieved_definitions,
211            use_imports: !args.disable_imports_gathering,
212            excerpt: EditPredictionExcerptOptions {
213                max_bytes: args.max_excerpt_bytes,
214                min_bytes: args.min_excerpt_bytes,
215                target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
216            },
217            score: EditPredictionScoreOptions {
218                omit_excerpt_overlaps,
219            },
220        }),
221        max_diagnostic_bytes: args.max_diagnostic_bytes,
222        max_prompt_bytes: args.max_prompt_bytes,
223        prompt_format: args.prompt_format.into(),
224        file_indexing_parallelism: args.file_indexing_parallelism,
225        buffer_change_grouping_interval: Duration::ZERO,
226    }
227}
228
229#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
230enum PromptFormat {
231    MarkedExcerpt,
232    LabeledSections,
233    OnlySnippets,
234    #[default]
235    NumberedLines,
236    OldTextNewText,
237    Minimal,
238    MinimalQwen,
239    SeedCoder1120,
240}
241
242impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
243    fn into(self) -> predict_edits_v3::PromptFormat {
244        match self {
245            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
246            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
247            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
248            Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
249            Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
250            Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
251            Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
252            Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
253        }
254    }
255}
256
257#[derive(clap::ValueEnum, Default, Debug, Clone)]
258enum OutputFormat {
259    #[default]
260    Prompt,
261    Request,
262    Full,
263}
264
/// A text input given on the command line: either a file path or standard input.
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read the text from the file at this path.
    File(PathBuf),
    /// Read the text from the process's standard input.
    Stdin,
}
270
271impl FileOrStdin {
272    async fn read_to_string(&self) -> Result<String, std::io::Error> {
273        match self {
274            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
275            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
276        }
277    }
278}
279
280impl FromStr for FileOrStdin {
281    type Err = <PathBuf as FromStr>::Err;
282
283    fn from_str(s: &str) -> Result<Self, Self::Err> {
284        match s {
285            "-" => Ok(Self::Stdin),
286            _ => Ok(Self::File(PathBuf::from_str(s)?)),
287        }
288    }
289}
290
291struct LoadedContext {
292    full_path_str: String,
293    snapshot: BufferSnapshot,
294    clipped_cursor: Point,
295    worktree: Entity<Worktree>,
296    project: Entity<Project>,
297    buffer: Entity<Buffer>,
298}
299
300async fn load_context(
301    args: &ContextArgs,
302    app_state: &Arc<ZetaCliAppState>,
303    cx: &mut AsyncApp,
304) -> Result<LoadedContext> {
305    let ContextArgs {
306        worktree: worktree_path,
307        cursor,
308        use_language_server,
309        ..
310    } = args;
311
312    let worktree_path = worktree_path.canonicalize()?;
313
314    let project = cx.update(|cx| {
315        Project::local(
316            app_state.client.clone(),
317            app_state.node_runtime.clone(),
318            app_state.user_store.clone(),
319            app_state.languages.clone(),
320            app_state.fs.clone(),
321            None,
322            cx,
323        )
324    })?;
325
326    let worktree = project
327        .update(cx, |project, cx| {
328            project.create_worktree(&worktree_path, true, cx)
329        })?
330        .await?;
331
332    let mut ready_languages = HashSet::default();
333    let (_lsp_open_handle, buffer) = if *use_language_server {
334        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
335            project.clone(),
336            worktree.clone(),
337            cursor.path.clone(),
338            &mut ready_languages,
339            cx,
340        )
341        .await?;
342        (Some(lsp_open_handle), buffer)
343    } else {
344        let buffer =
345            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
346        (None, buffer)
347    };
348
349    let full_path_str = worktree
350        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
351        .display(PathStyle::local())
352        .to_string();
353
354    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
355    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
356    if clipped_cursor != cursor.point {
357        let max_row = snapshot.max_point().row;
358        if cursor.point.row < max_row {
359            return Err(anyhow!(
360                "Cursor position {:?} is out of bounds (line length is {})",
361                cursor.point,
362                snapshot.line_len(cursor.point.row)
363            ));
364        } else {
365            return Err(anyhow!(
366                "Cursor position {:?} is out of bounds (max row is {})",
367                cursor.point,
368                max_row
369            ));
370        }
371    }
372
373    Ok(LoadedContext {
374        full_path_str,
375        snapshot,
376        clipped_cursor,
377        worktree,
378        project,
379        buffer,
380    })
381}
382
383async fn zeta2_syntax_context(
384    args: ContextArgs,
385    app_state: &Arc<ZetaCliAppState>,
386    cx: &mut AsyncApp,
387) -> Result<String> {
388    let LoadedContext {
389        worktree,
390        project,
391        buffer,
392        clipped_cursor,
393        ..
394    } = load_context(&args, app_state, cx).await?;
395
396    // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
397    // the whole worktree.
398    worktree
399        .read_with(cx, |worktree, _cx| {
400            worktree.as_local().unwrap().scan_complete()
401        })?
402        .await;
403    let output = cx
404        .update(|cx| {
405            let zeta = cx.new(|cx| {
406                zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
407            });
408            let indexing_done_task = zeta.update(cx, |zeta, cx| {
409                zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true));
410                zeta.register_buffer(&buffer, &project, cx);
411                zeta.wait_for_initial_indexing(&project, cx)
412            });
413            cx.spawn(async move |cx| {
414                indexing_done_task.await?;
415                let request = zeta
416                    .update(cx, |zeta, cx| {
417                        let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
418                        zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
419                    })?
420                    .await?;
421
422                let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;
423
424                match args.zeta2_args.output_format {
425                    OutputFormat::Prompt => anyhow::Ok(prompt_string),
426                    OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
427                    OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
428                        "request": request,
429                        "prompt": prompt_string,
430                        "section_labels": section_labels,
431                    }))?),
432                }
433            })
434        })?
435        .await?;
436
437    Ok(output)
438}
439
440async fn zeta1_context(
441    args: ContextArgs,
442    app_state: &Arc<ZetaCliAppState>,
443    cx: &mut AsyncApp,
444) -> Result<zeta::zeta1::GatherContextOutput> {
445    let LoadedContext {
446        full_path_str,
447        snapshot,
448        clipped_cursor,
449        ..
450    } = load_context(&args, app_state, cx).await?;
451
452    let events = match args.edit_history {
453        Some(events) => events.read_to_string().await?,
454        None => String::new(),
455    };
456
457    let prompt_for_events = move || (events, 0);
458    cx.update(|cx| {
459        zeta::zeta1::gather_context(
460            full_path_str,
461            &snapshot,
462            clipped_cursor,
463            prompt_for_events,
464            cloud_llm_client::PredictEditsRequestTrigger::Cli,
465            cx,
466        )
467    })?
468    .await
469}
470
471fn main() {
472    zlog::init();
473    zlog::init_output_stderr();
474    let args = ZetaCliArgs::parse();
475    let http_client = Arc::new(ReqwestClient::new());
476    let app = Application::headless().with_http_client(http_client);
477
478    app.run(move |cx| {
479        let app_state = Arc::new(headless::init(cx));
480        cx.spawn(async move |cx| {
481            match args.command {
482                None => {
483                    if args.printenv {
484                        ::util::shell_env::print_env();
485                        return;
486                    } else {
487                        panic!("Expected a command");
488                    }
489                }
490                Some(Command::ContextStats(arguments)) => {
491                    let result = retrieval_stats(
492                        arguments.worktree,
493                        app_state,
494                        arguments.extension,
495                        arguments.limit,
496                        arguments.skip,
497                        zeta2_args_to_options(&arguments.zeta2_args, false),
498                        cx,
499                    )
500                    .await;
501                    println!("{}", result.unwrap());
502                }
503                Some(Command::Context(context_args)) => {
504                    let result = match context_args.provider {
505                        ContextProvider::Zeta1 => {
506                            let context =
507                                zeta1_context(context_args, &app_state, cx).await.unwrap();
508                            serde_json::to_string_pretty(&context.body).unwrap()
509                        }
510                        ContextProvider::Syntax => {
511                            zeta2_syntax_context(context_args, &app_state, cx)
512                                .await
513                                .unwrap()
514                        }
515                    };
516                    println!("{}", result);
517                }
518                Some(Command::Predict(arguments)) => {
519                    run_predict(arguments, &app_state, cx).await;
520                }
521                Some(Command::Eval(arguments)) => {
522                    run_evaluate(arguments, &app_state, cx).await;
523                }
524                Some(Command::ConvertExample {
525                    path,
526                    output_format,
527                }) => {
528                    let example = NamedExample::load(path).unwrap();
529                    example.write(output_format, io::stdout()).unwrap();
530                }
531                Some(Command::Score {
532                    golden_patch,
533                    actual_patch,
534                }) => {
535                    let golden_content = std::fs::read_to_string(golden_patch).unwrap();
536                    let actual_content = std::fs::read_to_string(actual_patch).unwrap();
537
538                    let golden_diff: Vec<DiffLine> = golden_content
539                        .lines()
540                        .map(|line| DiffLine::parse(line))
541                        .collect();
542
543                    let actual_diff: Vec<DiffLine> = actual_content
544                        .lines()
545                        .map(|line| DiffLine::parse(line))
546                        .collect();
547
548                    let score = delta_chr_f(&golden_diff, &actual_diff);
549                    println!("{:.2}", score);
550                }
551                Some(Command::Clean) => {
552                    std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
553                }
554            };
555
556            let _ = cx.update(|cx| cx.quit());
557        })
558        .detach();
559    });
560}