main.rs

  1mod evaluate;
  2mod example;
  3mod headless;
  4mod metrics;
  5mod paths;
  6mod predict;
  7mod source_location;
  8mod util;
  9
 10use crate::{
 11    evaluate::run_evaluate,
 12    example::{ExampleFormat, NamedExample},
 13    headless::ZetaCliAppState,
 14    predict::run_predict,
 15    source_location::SourceLocation,
 16    util::{open_buffer, open_buffer_with_language_server},
 17};
 18use ::util::paths::PathStyle;
 19use anyhow::{Result, anyhow};
 20use clap::{Args, Parser, Subcommand, ValueEnum};
 21use cloud_llm_client::predict_edits_v3;
 22use edit_prediction::udiff::DiffLine;
 23use edit_prediction_context::EditPredictionExcerptOptions;
 24use gpui::{Application, AsyncApp, Entity, prelude::*};
 25use language::{Bias, Buffer, BufferSnapshot, Point};
 26use metrics::delta_chr_f;
 27use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
 28use reqwest_client::ReqwestClient;
 29use std::io::{self};
 30use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
 31
 32#[derive(Parser, Debug)]
 33#[command(name = "zeta")]
 34struct ZetaCliArgs {
 35    #[arg(long, default_value_t = false)]
 36    printenv: bool,
 37    #[command(subcommand)]
 38    command: Option<Command>,
 39}
 40
 41#[derive(Subcommand, Debug)]
 42enum Command {
 43    Context(ContextArgs),
 44    Predict(PredictArguments),
 45    Eval(EvaluateArguments),
 46    ConvertExample {
 47        path: PathBuf,
 48        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
 49        output_format: ExampleFormat,
 50    },
 51    Score {
 52        golden_patch: PathBuf,
 53        actual_patch: PathBuf,
 54    },
 55    Clean,
 56}
 57
 58#[derive(Debug, Args)]
 59struct ContextArgs {
 60    #[arg(long)]
 61    provider: ContextProvider,
 62    #[arg(long)]
 63    worktree: PathBuf,
 64    #[arg(long)]
 65    cursor: SourceLocation,
 66    #[arg(long)]
 67    use_language_server: bool,
 68    #[arg(long)]
 69    edit_history: Option<FileOrStdin>,
 70    #[clap(flatten)]
 71    zeta2_args: Zeta2Args,
 72}
 73
 74#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
 75enum ContextProvider {
 76    Zeta1,
 77    #[default]
 78    Zeta2,
 79}
 80
 81#[derive(Clone, Debug, Args)]
 82struct Zeta2Args {
 83    #[arg(long, default_value_t = 8192)]
 84    max_prompt_bytes: usize,
 85    #[arg(long, default_value_t = 2048)]
 86    max_excerpt_bytes: usize,
 87    #[arg(long, default_value_t = 1024)]
 88    min_excerpt_bytes: usize,
 89    #[arg(long, default_value_t = 0.66)]
 90    target_before_cursor_over_total_bytes: f32,
 91    #[arg(long, default_value_t = 1024)]
 92    max_diagnostic_bytes: usize,
 93    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
 94    prompt_format: PromptFormat,
 95    #[arg(long, value_enum, default_value_t = Default::default())]
 96    output_format: OutputFormat,
 97    #[arg(long, default_value_t = 42)]
 98    file_indexing_parallelism: usize,
 99    #[arg(long, default_value_t = false)]
100    disable_imports_gathering: bool,
101    #[arg(long, default_value_t = u8::MAX)]
102    max_retrieved_definitions: u8,
103}
104
105#[derive(Debug, Args)]
106pub struct PredictArguments {
107    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
108    format: PredictionsOutputFormat,
109    example_path: PathBuf,
110    #[clap(flatten)]
111    options: PredictionOptions,
112}
113
114#[derive(Clone, Debug, Args)]
115pub struct PredictionOptions {
116    #[clap(flatten)]
117    zeta2: Zeta2Args,
118    #[clap(long)]
119    provider: PredictionProvider,
120    #[clap(long, value_enum, default_value_t = CacheMode::default())]
121    cache: CacheMode,
122}
123
124#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
125pub enum CacheMode {
126    /// Use cached LLM requests and responses, except when multiple repetitions are requested
127    #[default]
128    Auto,
129    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
130    #[value(alias = "request")]
131    Requests,
132    /// Ignore existing cache entries for both LLM and search.
133    Skip,
134    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
135    /// Useful for reproducing results and fixing bugs outside of search queries
136    Force,
137}
138
139impl CacheMode {
140    fn use_cached_llm_responses(&self) -> bool {
141        self.assert_not_auto();
142        matches!(self, CacheMode::Requests | CacheMode::Force)
143    }
144
145    fn use_cached_search_results(&self) -> bool {
146        self.assert_not_auto();
147        matches!(self, CacheMode::Force)
148    }
149
150    fn assert_not_auto(&self) {
151        assert_ne!(
152            *self,
153            CacheMode::Auto,
154            "Cache mode should not be auto at this point!"
155        );
156    }
157}
158
159#[derive(clap::ValueEnum, Debug, Clone)]
160pub enum PredictionsOutputFormat {
161    Json,
162    Md,
163    Diff,
164}
165
166#[derive(Debug, Args)]
167pub struct EvaluateArguments {
168    example_paths: Vec<PathBuf>,
169    #[clap(flatten)]
170    options: PredictionOptions,
171    #[clap(short, long, default_value_t = 1, alias = "repeat")]
172    repetitions: u16,
173    #[arg(long)]
174    skip_prediction: bool,
175}
176
177#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
178enum PredictionProvider {
179    Zeta1,
180    #[default]
181    Zeta2,
182    Sweep,
183}
184
185fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions {
186    edit_prediction::ZetaOptions {
187        context: EditPredictionExcerptOptions {
188            max_bytes: args.max_excerpt_bytes,
189            min_bytes: args.min_excerpt_bytes,
190            target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
191        },
192        max_prompt_bytes: args.max_prompt_bytes,
193        prompt_format: args.prompt_format.into(),
194    }
195}
196
197#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
198enum PromptFormat {
199    OnlySnippets,
200    #[default]
201    OldTextNewText,
202    Minimal,
203    MinimalQwen,
204    SeedCoder1120,
205}
206
207impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
208    fn into(self) -> predict_edits_v3::PromptFormat {
209        match self {
210            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
211            Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
212            Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
213            Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
214            Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
215        }
216    }
217}
218
219#[derive(clap::ValueEnum, Default, Debug, Clone)]
220enum OutputFormat {
221    #[default]
222    Prompt,
223    Request,
224    Full,
225}
226
/// A text input that comes either from a file on disk or from standard input.
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read the contents of the file at this path.
    File(PathBuf),
    /// Read everything available on standard input.
    Stdin,
}
232
233impl FileOrStdin {
234    async fn read_to_string(&self) -> Result<String, std::io::Error> {
235        match self {
236            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
237            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
238        }
239    }
240}
241
242impl FromStr for FileOrStdin {
243    type Err = <PathBuf as FromStr>::Err;
244
245    fn from_str(s: &str) -> Result<Self, Self::Err> {
246        match s {
247            "-" => Ok(Self::Stdin),
248            _ => Ok(Self::File(PathBuf::from_str(s)?)),
249        }
250    }
251}
252
253struct LoadedContext {
254    full_path_str: String,
255    snapshot: BufferSnapshot,
256    clipped_cursor: Point,
257    worktree: Entity<Worktree>,
258    project: Entity<Project>,
259    buffer: Entity<Buffer>,
260    lsp_open_handle: Option<OpenLspBufferHandle>,
261}
262
263async fn load_context(
264    args: &ContextArgs,
265    app_state: &Arc<ZetaCliAppState>,
266    cx: &mut AsyncApp,
267) -> Result<LoadedContext> {
268    let ContextArgs {
269        worktree: worktree_path,
270        cursor,
271        use_language_server,
272        ..
273    } = args;
274
275    let worktree_path = worktree_path.canonicalize()?;
276
277    let project = cx.update(|cx| {
278        Project::local(
279            app_state.client.clone(),
280            app_state.node_runtime.clone(),
281            app_state.user_store.clone(),
282            app_state.languages.clone(),
283            app_state.fs.clone(),
284            None,
285            cx,
286        )
287    })?;
288
289    let worktree = project
290        .update(cx, |project, cx| {
291            project.create_worktree(&worktree_path, true, cx)
292        })?
293        .await?;
294
295    let mut ready_languages = HashSet::default();
296    let (lsp_open_handle, buffer) = if *use_language_server {
297        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
298            project.clone(),
299            worktree.clone(),
300            cursor.path.clone(),
301            &mut ready_languages,
302            cx,
303        )
304        .await?;
305        (Some(lsp_open_handle), buffer)
306    } else {
307        let buffer =
308            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
309        (None, buffer)
310    };
311
312    let full_path_str = worktree
313        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
314        .display(PathStyle::local())
315        .to_string();
316
317    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
318    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
319    if clipped_cursor != cursor.point {
320        let max_row = snapshot.max_point().row;
321        if cursor.point.row < max_row {
322            return Err(anyhow!(
323                "Cursor position {:?} is out of bounds (line length is {})",
324                cursor.point,
325                snapshot.line_len(cursor.point.row)
326            ));
327        } else {
328            return Err(anyhow!(
329                "Cursor position {:?} is out of bounds (max row is {})",
330                cursor.point,
331                max_row
332            ));
333        }
334    }
335
336    Ok(LoadedContext {
337        full_path_str,
338        snapshot,
339        clipped_cursor,
340        worktree,
341        project,
342        buffer,
343        lsp_open_handle,
344    })
345}
346
347async fn zeta2_context(
348    args: ContextArgs,
349    app_state: &Arc<ZetaCliAppState>,
350    cx: &mut AsyncApp,
351) -> Result<String> {
352    let LoadedContext {
353        worktree,
354        project,
355        buffer,
356        clipped_cursor,
357        lsp_open_handle: _handle,
358        ..
359    } = load_context(&args, app_state, cx).await?;
360
361    // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
362    // the whole worktree.
363    worktree
364        .read_with(cx, |worktree, _cx| {
365            worktree.as_local().unwrap().scan_complete()
366        })?
367        .await;
368    let output = cx
369        .update(|cx| {
370            let store = cx.new(|cx| {
371                edit_prediction::EditPredictionStore::new(
372                    app_state.client.clone(),
373                    app_state.user_store.clone(),
374                    cx,
375                )
376            });
377            store.update(cx, |store, cx| {
378                store.set_options(zeta2_args_to_options(&args.zeta2_args));
379                store.register_buffer(&buffer, &project, cx);
380            });
381            cx.spawn(async move |cx| {
382                let updates_rx = store.update(cx, |store, cx| {
383                    let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
384                    store.set_use_context(true);
385                    store.refresh_context(&project, &buffer, cursor, cx);
386                    store.project_context_updates(&project).unwrap()
387                })?;
388
389                updates_rx.recv().await.ok();
390
391                let context = store.update(cx, |store, cx| {
392                    store.context_for_project(&project, cx).to_vec()
393                })?;
394
395                anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
396            })
397        })?
398        .await?;
399
400    Ok(output)
401}
402
403async fn zeta1_context(
404    args: ContextArgs,
405    app_state: &Arc<ZetaCliAppState>,
406    cx: &mut AsyncApp,
407) -> Result<edit_prediction::zeta1::GatherContextOutput> {
408    let LoadedContext {
409        full_path_str,
410        snapshot,
411        clipped_cursor,
412        ..
413    } = load_context(&args, app_state, cx).await?;
414
415    let events = match args.edit_history {
416        Some(events) => events.read_to_string().await?,
417        None => String::new(),
418    };
419
420    let prompt_for_events = move || (events, 0);
421    cx.update(|cx| {
422        edit_prediction::zeta1::gather_context(
423            full_path_str,
424            &snapshot,
425            clipped_cursor,
426            prompt_for_events,
427            cloud_llm_client::PredictEditsRequestTrigger::Cli,
428            cx,
429        )
430    })?
431    .await
432}
433
434fn main() {
435    zlog::init();
436    zlog::init_output_stderr();
437    let args = ZetaCliArgs::parse();
438    let http_client = Arc::new(ReqwestClient::new());
439    let app = Application::headless().with_http_client(http_client);
440
441    app.run(move |cx| {
442        let app_state = Arc::new(headless::init(cx));
443        cx.spawn(async move |cx| {
444            match args.command {
445                None => {
446                    if args.printenv {
447                        ::util::shell_env::print_env();
448                    } else {
449                        panic!("Expected a command");
450                    }
451                }
452                Some(Command::Context(context_args)) => {
453                    let result = match context_args.provider {
454                        ContextProvider::Zeta1 => {
455                            let context =
456                                zeta1_context(context_args, &app_state, cx).await.unwrap();
457                            serde_json::to_string_pretty(&context.body).unwrap()
458                        }
459                        ContextProvider::Zeta2 => {
460                            zeta2_context(context_args, &app_state, cx).await.unwrap()
461                        }
462                    };
463                    println!("{}", result);
464                }
465                Some(Command::Predict(arguments)) => {
466                    run_predict(arguments, &app_state, cx).await;
467                }
468                Some(Command::Eval(arguments)) => {
469                    run_evaluate(arguments, &app_state, cx).await;
470                }
471                Some(Command::ConvertExample {
472                    path,
473                    output_format,
474                }) => {
475                    let example = NamedExample::load(path).unwrap();
476                    example.write(output_format, io::stdout()).unwrap();
477                }
478                Some(Command::Score {
479                    golden_patch,
480                    actual_patch,
481                }) => {
482                    let golden_content = std::fs::read_to_string(golden_patch).unwrap();
483                    let actual_content = std::fs::read_to_string(actual_patch).unwrap();
484
485                    let golden_diff: Vec<DiffLine> = golden_content
486                        .lines()
487                        .map(|line| DiffLine::parse(line))
488                        .collect();
489
490                    let actual_diff: Vec<DiffLine> = actual_content
491                        .lines()
492                        .map(|line| DiffLine::parse(line))
493                        .collect();
494
495                    let score = delta_chr_f(&golden_diff, &actual_diff);
496                    println!("{:.2}", score);
497                }
498                Some(Command::Clean) => {
499                    std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
500                }
501            };
502
503            let _ = cx.update(|cx| cx.quit());
504        })
505        .detach();
506    });
507}