main.rs

  1mod evaluate;
  2mod example;
  3mod headless;
  4mod metrics;
  5mod paths;
  6mod predict;
  7mod source_location;
  8mod syntax_retrieval_stats;
  9mod util;
 10
 11use crate::{
 12    evaluate::run_evaluate,
 13    example::{ExampleFormat, NamedExample},
 14    headless::ZetaCliAppState,
 15    predict::run_predict,
 16    source_location::SourceLocation,
 17    syntax_retrieval_stats::retrieval_stats,
 18    util::{open_buffer, open_buffer_with_language_server},
 19};
 20use ::util::paths::PathStyle;
 21use anyhow::{Result, anyhow};
 22use clap::{Args, Parser, Subcommand, ValueEnum};
 23use cloud_llm_client::predict_edits_v3;
 24use edit_prediction_context::EditPredictionExcerptOptions;
 25use gpui::{Application, AsyncApp, Entity, prelude::*};
 26use language::{Bias, Buffer, BufferSnapshot, Point};
 27use metrics::delta_chr_f;
 28use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
 29use reqwest_client::ReqwestClient;
 30use std::io::{self};
 31use std::time::Duration;
 32use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
 33use zeta::ContextMode;
 34use zeta::udiff::DiffLine;
 35
 36#[derive(Parser, Debug)]
 37#[command(name = "zeta")]
 38struct ZetaCliArgs {
 39    #[arg(long, default_value_t = false)]
 40    printenv: bool,
 41    #[command(subcommand)]
 42    command: Option<Command>,
 43}
 44
 45#[derive(Subcommand, Debug)]
 46enum Command {
 47    Context(ContextArgs),
 48    ContextStats(ContextStatsArgs),
 49    Predict(PredictArguments),
 50    Eval(EvaluateArguments),
 51    ConvertExample {
 52        path: PathBuf,
 53        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
 54        output_format: ExampleFormat,
 55    },
 56    Score {
 57        golden_patch: PathBuf,
 58        actual_patch: PathBuf,
 59    },
 60    Clean,
 61}
 62
 63#[derive(Debug, Args)]
 64struct ContextStatsArgs {
 65    #[arg(long)]
 66    worktree: PathBuf,
 67    #[arg(long)]
 68    extension: Option<String>,
 69    #[arg(long)]
 70    limit: Option<usize>,
 71    #[arg(long)]
 72    skip: Option<usize>,
 73    #[clap(flatten)]
 74    zeta2_args: Zeta2Args,
 75}
 76
 77#[derive(Debug, Args)]
 78struct ContextArgs {
 79    #[arg(long)]
 80    provider: ContextProvider,
 81    #[arg(long)]
 82    worktree: PathBuf,
 83    #[arg(long)]
 84    cursor: SourceLocation,
 85    #[arg(long)]
 86    use_language_server: bool,
 87    #[arg(long)]
 88    edit_history: Option<FileOrStdin>,
 89    #[clap(flatten)]
 90    zeta2_args: Zeta2Args,
 91}
 92
 93#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
 94enum ContextProvider {
 95    Zeta1,
 96    #[default]
 97    Zeta2,
 98}
 99
100#[derive(Clone, Debug, Args)]
101struct Zeta2Args {
102    #[arg(long, default_value_t = 8192)]
103    max_prompt_bytes: usize,
104    #[arg(long, default_value_t = 2048)]
105    max_excerpt_bytes: usize,
106    #[arg(long, default_value_t = 1024)]
107    min_excerpt_bytes: usize,
108    #[arg(long, default_value_t = 0.66)]
109    target_before_cursor_over_total_bytes: f32,
110    #[arg(long, default_value_t = 1024)]
111    max_diagnostic_bytes: usize,
112    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
113    prompt_format: PromptFormat,
114    #[arg(long, value_enum, default_value_t = Default::default())]
115    output_format: OutputFormat,
116    #[arg(long, default_value_t = 42)]
117    file_indexing_parallelism: usize,
118    #[arg(long, default_value_t = false)]
119    disable_imports_gathering: bool,
120    #[arg(long, default_value_t = u8::MAX)]
121    max_retrieved_definitions: u8,
122}
123
124#[derive(Debug, Args)]
125pub struct PredictArguments {
126    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
127    format: PredictionsOutputFormat,
128    example_path: PathBuf,
129    #[clap(flatten)]
130    options: PredictionOptions,
131}
132
133#[derive(Clone, Debug, Args)]
134pub struct PredictionOptions {
135    #[clap(flatten)]
136    zeta2: Zeta2Args,
137    #[clap(long)]
138    provider: PredictionProvider,
139    #[clap(long, value_enum, default_value_t = CacheMode::default())]
140    cache: CacheMode,
141}
142
143#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
144pub enum CacheMode {
145    /// Use cached LLM requests and responses, except when multiple repetitions are requested
146    #[default]
147    Auto,
148    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
149    #[value(alias = "request")]
150    Requests,
151    /// Ignore existing cache entries for both LLM and search.
152    Skip,
153    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
154    /// Useful for reproducing results and fixing bugs outside of search queries
155    Force,
156}
157
158impl CacheMode {
159    fn use_cached_llm_responses(&self) -> bool {
160        self.assert_not_auto();
161        matches!(self, CacheMode::Requests | CacheMode::Force)
162    }
163
164    fn use_cached_search_results(&self) -> bool {
165        self.assert_not_auto();
166        matches!(self, CacheMode::Force)
167    }
168
169    fn assert_not_auto(&self) {
170        assert_ne!(
171            *self,
172            CacheMode::Auto,
173            "Cache mode should not be auto at this point!"
174        );
175    }
176}
177
178#[derive(clap::ValueEnum, Debug, Clone)]
179pub enum PredictionsOutputFormat {
180    Json,
181    Md,
182    Diff,
183}
184
185#[derive(Debug, Args)]
186pub struct EvaluateArguments {
187    example_paths: Vec<PathBuf>,
188    #[clap(flatten)]
189    options: PredictionOptions,
190    #[clap(short, long, default_value_t = 1, alias = "repeat")]
191    repetitions: u16,
192    #[arg(long)]
193    skip_prediction: bool,
194}
195
196#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
197enum PredictionProvider {
198    Zeta1,
199    #[default]
200    Zeta2,
201    Sweep,
202}
203
204fn zeta2_args_to_options(args: &Zeta2Args) -> zeta::ZetaOptions {
205    zeta::ZetaOptions {
206        context: ContextMode::Lsp(EditPredictionExcerptOptions {
207            max_bytes: args.max_excerpt_bytes,
208            min_bytes: args.min_excerpt_bytes,
209            target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
210        }),
211        max_diagnostic_bytes: args.max_diagnostic_bytes,
212        max_prompt_bytes: args.max_prompt_bytes,
213        prompt_format: args.prompt_format.into(),
214        file_indexing_parallelism: args.file_indexing_parallelism,
215        buffer_change_grouping_interval: Duration::ZERO,
216    }
217}
218
219#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
220enum PromptFormat {
221    MarkedExcerpt,
222    LabeledSections,
223    OnlySnippets,
224    #[default]
225    NumberedLines,
226    OldTextNewText,
227    Minimal,
228    MinimalQwen,
229    SeedCoder1120,
230}
231
232impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
233    fn into(self) -> predict_edits_v3::PromptFormat {
234        match self {
235            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
236            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
237            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
238            Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
239            Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
240            Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
241            Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
242            Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
243        }
244    }
245}
246
247#[derive(clap::ValueEnum, Default, Debug, Clone)]
248enum OutputFormat {
249    #[default]
250    Prompt,
251    Request,
252    Full,
253}
254
/// An input source given on the command line: a file path, or standard input
/// when the argument is `-`.
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read from the file at this path.
    File(PathBuf),
    /// Read from the process's standard input.
    Stdin,
}
260
261impl FileOrStdin {
262    async fn read_to_string(&self) -> Result<String, std::io::Error> {
263        match self {
264            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
265            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
266        }
267    }
268}
269
270impl FromStr for FileOrStdin {
271    type Err = <PathBuf as FromStr>::Err;
272
273    fn from_str(s: &str) -> Result<Self, Self::Err> {
274        match s {
275            "-" => Ok(Self::Stdin),
276            _ => Ok(Self::File(PathBuf::from_str(s)?)),
277        }
278    }
279}
280
281struct LoadedContext {
282    full_path_str: String,
283    snapshot: BufferSnapshot,
284    clipped_cursor: Point,
285    worktree: Entity<Worktree>,
286    project: Entity<Project>,
287    buffer: Entity<Buffer>,
288    lsp_open_handle: Option<OpenLspBufferHandle>,
289}
290
291async fn load_context(
292    args: &ContextArgs,
293    app_state: &Arc<ZetaCliAppState>,
294    cx: &mut AsyncApp,
295) -> Result<LoadedContext> {
296    let ContextArgs {
297        worktree: worktree_path,
298        cursor,
299        use_language_server,
300        ..
301    } = args;
302
303    let worktree_path = worktree_path.canonicalize()?;
304
305    let project = cx.update(|cx| {
306        Project::local(
307            app_state.client.clone(),
308            app_state.node_runtime.clone(),
309            app_state.user_store.clone(),
310            app_state.languages.clone(),
311            app_state.fs.clone(),
312            None,
313            cx,
314        )
315    })?;
316
317    let worktree = project
318        .update(cx, |project, cx| {
319            project.create_worktree(&worktree_path, true, cx)
320        })?
321        .await?;
322
323    let mut ready_languages = HashSet::default();
324    let (lsp_open_handle, buffer) = if *use_language_server {
325        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
326            project.clone(),
327            worktree.clone(),
328            cursor.path.clone(),
329            &mut ready_languages,
330            cx,
331        )
332        .await?;
333        (Some(lsp_open_handle), buffer)
334    } else {
335        let buffer =
336            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
337        (None, buffer)
338    };
339
340    let full_path_str = worktree
341        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
342        .display(PathStyle::local())
343        .to_string();
344
345    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
346    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
347    if clipped_cursor != cursor.point {
348        let max_row = snapshot.max_point().row;
349        if cursor.point.row < max_row {
350            return Err(anyhow!(
351                "Cursor position {:?} is out of bounds (line length is {})",
352                cursor.point,
353                snapshot.line_len(cursor.point.row)
354            ));
355        } else {
356            return Err(anyhow!(
357                "Cursor position {:?} is out of bounds (max row is {})",
358                cursor.point,
359                max_row
360            ));
361        }
362    }
363
364    Ok(LoadedContext {
365        full_path_str,
366        snapshot,
367        clipped_cursor,
368        worktree,
369        project,
370        buffer,
371        lsp_open_handle,
372    })
373}
374
375async fn zeta2_context(
376    args: ContextArgs,
377    app_state: &Arc<ZetaCliAppState>,
378    cx: &mut AsyncApp,
379) -> Result<String> {
380    let LoadedContext {
381        worktree,
382        project,
383        buffer,
384        clipped_cursor,
385        lsp_open_handle: _handle,
386        ..
387    } = load_context(&args, app_state, cx).await?;
388
389    // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
390    // the whole worktree.
391    worktree
392        .read_with(cx, |worktree, _cx| {
393            worktree.as_local().unwrap().scan_complete()
394        })?
395        .await;
396    let output = cx
397        .update(|cx| {
398            let zeta = cx.new(|cx| {
399                zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
400            });
401            let indexing_done_task = zeta.update(cx, |zeta, cx| {
402                zeta.set_options(zeta2_args_to_options(&args.zeta2_args));
403                zeta.register_buffer(&buffer, &project, cx);
404                zeta.wait_for_initial_indexing(&project, cx)
405            });
406            cx.spawn(async move |cx| {
407                indexing_done_task.await?;
408                let updates_rx = zeta.update(cx, |zeta, cx| {
409                    let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
410                    zeta.set_use_context(true);
411                    zeta.refresh_context_if_needed(&project, &buffer, cursor, cx);
412                    zeta.project_context_updates(&project).unwrap()
413                })?;
414
415                updates_rx.recv().await.ok();
416
417                let context = zeta.update(cx, |zeta, cx| {
418                    zeta.context_for_project(&project, cx).to_vec()
419                })?;
420
421                anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
422            })
423        })?
424        .await?;
425
426    Ok(output)
427}
428
429async fn zeta1_context(
430    args: ContextArgs,
431    app_state: &Arc<ZetaCliAppState>,
432    cx: &mut AsyncApp,
433) -> Result<zeta::zeta1::GatherContextOutput> {
434    let LoadedContext {
435        full_path_str,
436        snapshot,
437        clipped_cursor,
438        ..
439    } = load_context(&args, app_state, cx).await?;
440
441    let events = match args.edit_history {
442        Some(events) => events.read_to_string().await?,
443        None => String::new(),
444    };
445
446    let prompt_for_events = move || (events, 0);
447    cx.update(|cx| {
448        zeta::zeta1::gather_context(
449            full_path_str,
450            &snapshot,
451            clipped_cursor,
452            prompt_for_events,
453            cloud_llm_client::PredictEditsRequestTrigger::Cli,
454            cx,
455        )
456    })?
457    .await
458}
459
460fn main() {
461    zlog::init();
462    zlog::init_output_stderr();
463    let args = ZetaCliArgs::parse();
464    let http_client = Arc::new(ReqwestClient::new());
465    let app = Application::headless().with_http_client(http_client);
466
467    app.run(move |cx| {
468        let app_state = Arc::new(headless::init(cx));
469        cx.spawn(async move |cx| {
470            match args.command {
471                None => {
472                    if args.printenv {
473                        ::util::shell_env::print_env();
474                    } else {
475                        panic!("Expected a command");
476                    }
477                }
478                Some(Command::ContextStats(arguments)) => {
479                    let result = retrieval_stats(
480                        arguments.worktree,
481                        app_state,
482                        arguments.extension,
483                        arguments.limit,
484                        arguments.skip,
485                        zeta2_args_to_options(&arguments.zeta2_args),
486                        cx,
487                    )
488                    .await;
489                    println!("{}", result.unwrap());
490                }
491                Some(Command::Context(context_args)) => {
492                    let result = match context_args.provider {
493                        ContextProvider::Zeta1 => {
494                            let context =
495                                zeta1_context(context_args, &app_state, cx).await.unwrap();
496                            serde_json::to_string_pretty(&context.body).unwrap()
497                        }
498                        ContextProvider::Zeta2 => {
499                            zeta2_context(context_args, &app_state, cx).await.unwrap()
500                        }
501                    };
502                    println!("{}", result);
503                }
504                Some(Command::Predict(arguments)) => {
505                    run_predict(arguments, &app_state, cx).await;
506                }
507                Some(Command::Eval(arguments)) => {
508                    run_evaluate(arguments, &app_state, cx).await;
509                }
510                Some(Command::ConvertExample {
511                    path,
512                    output_format,
513                }) => {
514                    let example = NamedExample::load(path).unwrap();
515                    example.write(output_format, io::stdout()).unwrap();
516                }
517                Some(Command::Score {
518                    golden_patch,
519                    actual_patch,
520                }) => {
521                    let golden_content = std::fs::read_to_string(golden_patch).unwrap();
522                    let actual_content = std::fs::read_to_string(actual_patch).unwrap();
523
524                    let golden_diff: Vec<DiffLine> = golden_content
525                        .lines()
526                        .map(|line| DiffLine::parse(line))
527                        .collect();
528
529                    let actual_diff: Vec<DiffLine> = actual_content
530                        .lines()
531                        .map(|line| DiffLine::parse(line))
532                        .collect();
533
534                    let score = delta_chr_f(&golden_diff, &actual_diff);
535                    println!("{:.2}", score);
536                }
537                Some(Command::Clean) => {
538                    std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
539                }
540            };
541
542            let _ = cx.update(|cx| cx.quit());
543        })
544        .detach();
545    });
546}