main.rs

  1mod evaluate;
  2mod example;
  3mod headless;
  4mod metrics;
  5mod paths;
  6mod predict;
  7mod source_location;
  8mod training;
  9mod util;
 10
 11use crate::{
 12    evaluate::run_evaluate,
 13    example::{ExampleFormat, NamedExample},
 14    headless::ZetaCliAppState,
 15    predict::run_predict,
 16    source_location::SourceLocation,
 17    training::{context::ContextType, distill::run_distill},
 18    util::{open_buffer, open_buffer_with_language_server},
 19};
 20use ::util::{ResultExt, paths::PathStyle};
 21use anyhow::{Result, anyhow};
 22use clap::{Args, Parser, Subcommand, ValueEnum};
 23use cloud_llm_client::predict_edits_v3;
 24use edit_prediction::udiff::DiffLine;
 25use edit_prediction_context::EditPredictionExcerptOptions;
 26use gpui::{Application, AsyncApp, Entity, prelude::*};
 27use language::{Bias, Buffer, BufferSnapshot, Point};
 28use metrics::delta_chr_f;
 29use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
 30use reqwest_client::ReqwestClient;
 31use std::io::{self};
 32use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
 33
 34#[derive(Parser, Debug)]
 35#[command(name = "zeta")]
 36struct ZetaCliArgs {
 37    #[arg(long, default_value_t = false)]
 38    printenv: bool,
 39    #[command(subcommand)]
 40    command: Option<Command>,
 41}
 42
 43#[derive(Subcommand, Debug)]
 44enum Command {
 45    Context(ContextArgs),
 46    Predict(PredictArguments),
 47    Eval(EvaluateArguments),
 48    Distill(DistillArguments),
 49    ConvertExample {
 50        path: PathBuf,
 51        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
 52        output_format: ExampleFormat,
 53    },
 54    Score {
 55        golden_patch: PathBuf,
 56        actual_patch: PathBuf,
 57    },
 58    Clean,
 59}
 60
 61#[derive(Debug, Args)]
 62struct ContextArgs {
 63    #[arg(long)]
 64    provider: ContextProvider,
 65    #[arg(long)]
 66    worktree: PathBuf,
 67    #[arg(long)]
 68    cursor: SourceLocation,
 69    #[arg(long)]
 70    use_language_server: bool,
 71    #[arg(long)]
 72    edit_history: Option<FileOrStdin>,
 73    #[clap(flatten)]
 74    zeta2_args: Zeta2Args,
 75}
 76
 77#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
 78enum ContextProvider {
 79    Zeta1,
 80    #[default]
 81    Zeta2,
 82}
 83
 84#[derive(Clone, Debug, Args)]
 85struct Zeta2Args {
 86    #[arg(long, default_value_t = 8192)]
 87    max_prompt_bytes: usize,
 88    #[arg(long, default_value_t = 2048)]
 89    max_excerpt_bytes: usize,
 90    #[arg(long, default_value_t = 1024)]
 91    min_excerpt_bytes: usize,
 92    #[arg(long, default_value_t = 0.66)]
 93    target_before_cursor_over_total_bytes: f32,
 94    #[arg(long, default_value_t = 1024)]
 95    max_diagnostic_bytes: usize,
 96    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
 97    prompt_format: PromptFormat,
 98    #[arg(long, value_enum, default_value_t = Default::default())]
 99    output_format: OutputFormat,
100    #[arg(long, default_value_t = 42)]
101    file_indexing_parallelism: usize,
102    #[arg(long, default_value_t = false)]
103    disable_imports_gathering: bool,
104    #[arg(long, default_value_t = u8::MAX)]
105    max_retrieved_definitions: u8,
106}
107
108#[derive(Debug, Args)]
109pub struct PredictArguments {
110    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
111    format: PredictionsOutputFormat,
112    example_path: PathBuf,
113    #[clap(flatten)]
114    options: PredictionOptions,
115}
116
117#[derive(Debug, Args)]
118pub struct DistillArguments {
119    split_commit_dataset: PathBuf,
120    #[clap(long, value_enum, default_value_t = ContextType::CurrentFile)]
121    context_type: ContextType,
122    #[clap(long)]
123    batch: Option<String>,
124}
125
126#[derive(Clone, Debug, Args)]
127pub struct PredictionOptions {
128    #[clap(flatten)]
129    zeta2: Zeta2Args,
130    #[clap(long)]
131    provider: PredictionProvider,
132    #[clap(long, value_enum, default_value_t = CacheMode::default())]
133    cache: CacheMode,
134}
135
136#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
137pub enum CacheMode {
138    /// Use cached LLM requests and responses, except when multiple repetitions are requested
139    #[default]
140    Auto,
141    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
142    #[value(alias = "request")]
143    Requests,
144    /// Ignore existing cache entries for both LLM and search.
145    Skip,
146    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
147    /// Useful for reproducing results and fixing bugs outside of search queries
148    Force,
149}
150
151impl CacheMode {
152    fn use_cached_llm_responses(&self) -> bool {
153        self.assert_not_auto();
154        matches!(self, CacheMode::Requests | CacheMode::Force)
155    }
156
157    fn use_cached_search_results(&self) -> bool {
158        self.assert_not_auto();
159        matches!(self, CacheMode::Force)
160    }
161
162    fn assert_not_auto(&self) {
163        assert_ne!(
164            *self,
165            CacheMode::Auto,
166            "Cache mode should not be auto at this point!"
167        );
168    }
169}
170
171#[derive(clap::ValueEnum, Debug, Clone)]
172pub enum PredictionsOutputFormat {
173    Json,
174    Md,
175    Diff,
176}
177
178#[derive(Debug, Args)]
179pub struct EvaluateArguments {
180    example_paths: Vec<PathBuf>,
181    #[clap(flatten)]
182    options: PredictionOptions,
183    #[clap(short, long, default_value_t = 1, alias = "repeat")]
184    repetitions: u16,
185    #[arg(long)]
186    skip_prediction: bool,
187}
188
189#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
190enum PredictionProvider {
191    Zeta1,
192    #[default]
193    Zeta2,
194    Sweep,
195}
196
197fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions {
198    edit_prediction::ZetaOptions {
199        context: EditPredictionExcerptOptions {
200            max_bytes: args.max_excerpt_bytes,
201            min_bytes: args.min_excerpt_bytes,
202            target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
203        },
204        max_prompt_bytes: args.max_prompt_bytes,
205        prompt_format: args.prompt_format.into(),
206    }
207}
208
209#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
210enum PromptFormat {
211    OnlySnippets,
212    #[default]
213    OldTextNewText,
214    Minimal,
215    MinimalQwen,
216    SeedCoder1120,
217}
218
219impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
220    fn into(self) -> predict_edits_v3::PromptFormat {
221        match self {
222            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
223            Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
224            Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
225            Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
226            Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
227        }
228    }
229}
230
231#[derive(clap::ValueEnum, Default, Debug, Clone)]
232enum OutputFormat {
233    #[default]
234    Prompt,
235    Request,
236    Full,
237}
238
/// A text input source named on the command line: either a file path or standard input.
/// The `FromStr` impl maps the argument `-` to `Stdin` and anything else to `File`.
#[derive(Clone, Debug)]
enum FileOrStdin {
    File(PathBuf),
    Stdin,
}
244
245impl FileOrStdin {
246    async fn read_to_string(&self) -> Result<String, std::io::Error> {
247        match self {
248            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
249            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
250        }
251    }
252}
253
254impl FromStr for FileOrStdin {
255    type Err = <PathBuf as FromStr>::Err;
256
257    fn from_str(s: &str) -> Result<Self, Self::Err> {
258        match s {
259            "-" => Ok(Self::Stdin),
260            _ => Ok(Self::File(PathBuf::from_str(s)?)),
261        }
262    }
263}
264
/// Everything `load_context` produces for a requested cursor location.
///
/// Field order is left untouched: struct fields drop in declaration order.
struct LoadedContext {
    /// Worktree root name joined with the cursor's path, rendered with the local path style.
    full_path_str: String,
    /// Snapshot of the opened buffer, taken right after it was opened.
    snapshot: BufferSnapshot,
    /// The requested cursor point after clipping; `load_context` fails unless clipping left the
    /// point unchanged, so this equals the requested point.
    clipped_cursor: Point,
    /// Worktree created for the requested directory.
    worktree: Entity<Worktree>,
    /// Local project that owns `worktree`.
    project: Entity<Project>,
    /// Buffer opened at the cursor's path.
    buffer: Entity<Buffer>,
    /// `Some` only when the buffer was opened through the language-server path.
    lsp_open_handle: Option<OpenLspBufferHandle>,
}
274
275async fn load_context(
276    args: &ContextArgs,
277    app_state: &Arc<ZetaCliAppState>,
278    cx: &mut AsyncApp,
279) -> Result<LoadedContext> {
280    let ContextArgs {
281        worktree: worktree_path,
282        cursor,
283        use_language_server,
284        ..
285    } = args;
286
287    let worktree_path = worktree_path.canonicalize()?;
288
289    let project = cx.update(|cx| {
290        Project::local(
291            app_state.client.clone(),
292            app_state.node_runtime.clone(),
293            app_state.user_store.clone(),
294            app_state.languages.clone(),
295            app_state.fs.clone(),
296            None,
297            cx,
298        )
299    })?;
300
301    let worktree = project
302        .update(cx, |project, cx| {
303            project.create_worktree(&worktree_path, true, cx)
304        })?
305        .await?;
306
307    let mut ready_languages = HashSet::default();
308    let (lsp_open_handle, buffer) = if *use_language_server {
309        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
310            project.clone(),
311            worktree.clone(),
312            cursor.path.clone(),
313            &mut ready_languages,
314            cx,
315        )
316        .await?;
317        (Some(lsp_open_handle), buffer)
318    } else {
319        let buffer =
320            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
321        (None, buffer)
322    };
323
324    let full_path_str = worktree
325        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
326        .display(PathStyle::local())
327        .to_string();
328
329    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
330    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
331    if clipped_cursor != cursor.point {
332        let max_row = snapshot.max_point().row;
333        if cursor.point.row < max_row {
334            return Err(anyhow!(
335                "Cursor position {:?} is out of bounds (line length is {})",
336                cursor.point,
337                snapshot.line_len(cursor.point.row)
338            ));
339        } else {
340            return Err(anyhow!(
341                "Cursor position {:?} is out of bounds (max row is {})",
342                cursor.point,
343                max_row
344            ));
345        }
346    }
347
348    Ok(LoadedContext {
349        full_path_str,
350        snapshot,
351        clipped_cursor,
352        worktree,
353        project,
354        buffer,
355        lsp_open_handle,
356    })
357}
358
359async fn zeta2_context(
360    args: ContextArgs,
361    app_state: &Arc<ZetaCliAppState>,
362    cx: &mut AsyncApp,
363) -> Result<String> {
364    let LoadedContext {
365        worktree,
366        project,
367        buffer,
368        clipped_cursor,
369        lsp_open_handle: _handle,
370        ..
371    } = load_context(&args, app_state, cx).await?;
372
373    // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
374    // the whole worktree.
375    worktree
376        .read_with(cx, |worktree, _cx| {
377            worktree.as_local().unwrap().scan_complete()
378        })?
379        .await;
380    let output = cx
381        .update(|cx| {
382            let store = cx.new(|cx| {
383                edit_prediction::EditPredictionStore::new(
384                    app_state.client.clone(),
385                    app_state.user_store.clone(),
386                    cx,
387                )
388            });
389            store.update(cx, |store, cx| {
390                store.set_options(zeta2_args_to_options(&args.zeta2_args));
391                store.register_buffer(&buffer, &project, cx);
392            });
393            cx.spawn(async move |cx| {
394                let updates_rx = store.update(cx, |store, cx| {
395                    let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
396                    store.set_use_context(true);
397                    store.refresh_context(&project, &buffer, cursor, cx);
398                    store.project_context_updates(&project).unwrap()
399                })?;
400
401                updates_rx.recv().await.ok();
402
403                let context = store.update(cx, |store, cx| {
404                    store.context_for_project(&project, cx).to_vec()
405                })?;
406
407                anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
408            })
409        })?
410        .await?;
411
412    Ok(output)
413}
414
415async fn zeta1_context(
416    args: ContextArgs,
417    app_state: &Arc<ZetaCliAppState>,
418    cx: &mut AsyncApp,
419) -> Result<edit_prediction::zeta1::GatherContextOutput> {
420    let LoadedContext {
421        full_path_str,
422        snapshot,
423        clipped_cursor,
424        ..
425    } = load_context(&args, app_state, cx).await?;
426
427    let events = match args.edit_history {
428        Some(events) => events.read_to_string().await?,
429        None => String::new(),
430    };
431
432    let prompt_for_events = move || (events, 0);
433    cx.update(|cx| {
434        edit_prediction::zeta1::gather_context(
435            full_path_str,
436            &snapshot,
437            clipped_cursor,
438            prompt_for_events,
439            cloud_llm_client::PredictEditsRequestTrigger::Cli,
440            cx,
441        )
442    })?
443    .await
444}
445
446fn main() {
447    zlog::init();
448    zlog::init_output_stderr();
449    let args = ZetaCliArgs::parse();
450    let http_client = Arc::new(ReqwestClient::new());
451    let app = Application::headless().with_http_client(http_client);
452
453    app.run(move |cx| {
454        let app_state = Arc::new(headless::init(cx));
455        cx.spawn(async move |cx| {
456            match args.command {
457                None => {
458                    if args.printenv {
459                        ::util::shell_env::print_env();
460                    } else {
461                        panic!("Expected a command");
462                    }
463                }
464                Some(Command::Context(context_args)) => {
465                    let result = match context_args.provider {
466                        ContextProvider::Zeta1 => {
467                            let context =
468                                zeta1_context(context_args, &app_state, cx).await.unwrap();
469                            serde_json::to_string_pretty(&context.body).unwrap()
470                        }
471                        ContextProvider::Zeta2 => {
472                            zeta2_context(context_args, &app_state, cx).await.unwrap()
473                        }
474                    };
475                    println!("{}", result);
476                }
477                Some(Command::Predict(arguments)) => {
478                    run_predict(arguments, &app_state, cx).await;
479                }
480                Some(Command::Eval(arguments)) => {
481                    run_evaluate(arguments, &app_state, cx).await;
482                }
483                Some(Command::Distill(arguments)) => {
484                    let _guard = cx
485                        .update(|cx| gpui_tokio::Tokio::handle(cx))
486                        .unwrap()
487                        .enter();
488                    run_distill(arguments).await.log_err();
489                }
490                Some(Command::ConvertExample {
491                    path,
492                    output_format,
493                }) => {
494                    let example = NamedExample::load(path).unwrap();
495                    example.write(output_format, io::stdout()).unwrap();
496                }
497                Some(Command::Score {
498                    golden_patch,
499                    actual_patch,
500                }) => {
501                    let golden_content = std::fs::read_to_string(golden_patch).unwrap();
502                    let actual_content = std::fs::read_to_string(actual_patch).unwrap();
503
504                    let golden_diff: Vec<DiffLine> = golden_content
505                        .lines()
506                        .map(|line| DiffLine::parse(line))
507                        .collect();
508
509                    let actual_diff: Vec<DiffLine> = actual_content
510                        .lines()
511                        .map(|line| DiffLine::parse(line))
512                        .collect();
513
514                    let score = delta_chr_f(&golden_diff, &actual_diff);
515                    println!("{:.2}", score);
516                }
517                Some(Command::Clean) => {
518                    std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
519                }
520            };
521
522            let _ = cx.update(|cx| cx.quit());
523        })
524        .detach();
525    });
526}