main.rs

  1mod evaluate;
  2mod example;
  3mod headless;
  4mod paths;
  5mod predict;
  6mod source_location;
  7mod syntax_retrieval_stats;
  8mod util;
  9
 10use crate::{
 11    evaluate::run_evaluate,
 12    example::{ExampleFormat, NamedExample},
 13    headless::ZetaCliAppState,
 14    predict::run_predict,
 15    source_location::SourceLocation,
 16    syntax_retrieval_stats::retrieval_stats,
 17    util::{open_buffer, open_buffer_with_language_server},
 18};
 19use ::util::paths::PathStyle;
 20use anyhow::{Result, anyhow};
 21use clap::{Args, Parser, Subcommand, ValueEnum};
 22use cloud_llm_client::predict_edits_v3;
 23use edit_prediction_context::{
 24    EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
 25};
 26use gpui::{Application, AsyncApp, Entity, prelude::*};
 27use language::{Bias, Buffer, BufferSnapshot, Point};
 28use project::{Project, Worktree};
 29use reqwest_client::ReqwestClient;
 30use serde_json::json;
 31use std::io::{self};
 32use std::time::Duration;
 33use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
 34use zeta::ContextMode;
 35
// Top-level command-line arguments for the `zeta` CLI.
//
// Plain `//` comments are used on clap-derived types in this file because clap derive turns
// `///` doc comments into `--help` text.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // Only consulted when no subcommand is given: `main` then calls
    // `::util::shell_env::print_env()` instead of panicking.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // `None` is handled in `main` (see `printenv`).
    #[command(subcommand)]
    command: Option<Command>,
}

// Subcommands of the CLI; each variant is dispatched in `main`.
#[derive(Subcommand, Debug)]
enum Command {
    Context(ContextArgs),
    ContextStats(ContextStatsArgs),
    Predict(PredictArguments),
    Eval(EvaluateArguments),
    // Loads the example at `path` and writes it to stdout in `output_format`.
    ConvertExample {
        path: PathBuf,
        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
        output_format: ExampleFormat,
    },
    // Removes `crate::paths::TARGET_ZETA_DIR`.
    Clean,
}
 58
 59#[derive(Debug, Args)]
 60struct ContextStatsArgs {
 61    #[arg(long)]
 62    worktree: PathBuf,
 63    #[arg(long)]
 64    extension: Option<String>,
 65    #[arg(long)]
 66    limit: Option<usize>,
 67    #[arg(long)]
 68    skip: Option<usize>,
 69    #[clap(flatten)]
 70    zeta2_args: Zeta2Args,
 71}
 72
 73#[derive(Debug, Args)]
 74struct ContextArgs {
 75    #[arg(long)]
 76    provider: ContextProvider,
 77    #[arg(long)]
 78    worktree: PathBuf,
 79    #[arg(long)]
 80    cursor: SourceLocation,
 81    #[arg(long)]
 82    use_language_server: bool,
 83    #[arg(long)]
 84    edit_history: Option<FileOrStdin>,
 85    #[clap(flatten)]
 86    zeta2_args: Zeta2Args,
 87}
 88
 89#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
 90enum ContextProvider {
 91    Zeta1,
 92    #[default]
 93    Syntax,
 94}
 95
// Tuning flags for Zeta2 context retrieval and prompt construction, flattened into several
// subcommands' arguments. `zeta2_args_to_options` converts them to `zeta::ZetaOptions`; in this
// file `output_format` is read only by `zeta2_syntax_context`.
#[derive(Clone, Debug, Args)]
struct Zeta2Args {
    // Forwarded as `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Upper/lower bounds forwarded as `EditPredictionExcerptOptions::{max_bytes, min_bytes}`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Presumably the target fraction of excerpt bytes placed before the cursor — confirm in
    // `edit_prediction_context`.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
    // Inverted into `EditPredictionContextOptions::use_imports`.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    // Forwarded as `max_retrieved_declarations`; the default `u8::MAX` is the largest value the
    // type allows.
    #[arg(long, default_value_t = u8::MAX)]
    max_retrieved_definitions: u8,
}
119
120#[derive(Debug, Args)]
121pub struct PredictArguments {
122    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
123    format: PredictionsOutputFormat,
124    example_path: PathBuf,
125    #[clap(flatten)]
126    options: PredictionOptions,
127}
128
129#[derive(Clone, Debug, Args)]
130pub struct PredictionOptions {
131    #[arg(long)]
132    use_expected_context: bool,
133    #[clap(flatten)]
134    zeta2: Zeta2Args,
135    #[clap(long)]
136    provider: PredictionProvider,
137    #[clap(long, value_enum, default_value_t = CacheMode::default())]
138    cache: CacheMode,
139}
140
// How cached LLM requests/responses and search results are reused (`--cache`).
// The `///` lines on the variants are clap help text shown to users. `Auto` must be resolved
// to a concrete mode before the `use_cached_*` helpers are called; they assert this.
#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
pub enum CacheMode {
    /// Use cached LLM requests and responses, except when multiple repetitions are requested
    #[default]
    Auto,
    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
    #[value(alias = "request")]
    Requests,
    /// Ignore existing cache entries for both LLM and search.
    Skip,
    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
    /// Useful for reproducing results and fixing bugs outside of search queries
    Force,
}
155
156impl CacheMode {
157    fn use_cached_llm_responses(&self) -> bool {
158        self.assert_not_auto();
159        matches!(self, CacheMode::Requests | CacheMode::Force)
160    }
161
162    fn use_cached_search_results(&self) -> bool {
163        self.assert_not_auto();
164        matches!(self, CacheMode::Force)
165    }
166
167    fn assert_not_auto(&self) {
168        assert_ne!(
169            *self,
170            CacheMode::Auto,
171            "Cache mode should not be auto at this point!"
172        );
173    }
174}
175
176#[derive(clap::ValueEnum, Debug, Clone)]
177pub enum PredictionsOutputFormat {
178    Json,
179    Md,
180    Diff,
181}
182
183#[derive(Debug, Args)]
184pub struct EvaluateArguments {
185    example_paths: Vec<PathBuf>,
186    #[clap(flatten)]
187    options: PredictionOptions,
188    #[clap(short, long, default_value_t = 1, alias = "repeat")]
189    repetitions: u16,
190    #[arg(long)]
191    skip_prediction: bool,
192}
193
194#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
195enum PredictionProvider {
196    Zeta1,
197    #[default]
198    Zeta2,
199    Sweep,
200}
201
202fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta::ZetaOptions {
203    zeta::ZetaOptions {
204        context: ContextMode::Syntax(EditPredictionContextOptions {
205            max_retrieved_declarations: args.max_retrieved_definitions,
206            use_imports: !args.disable_imports_gathering,
207            excerpt: EditPredictionExcerptOptions {
208                max_bytes: args.max_excerpt_bytes,
209                min_bytes: args.min_excerpt_bytes,
210                target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
211            },
212            score: EditPredictionScoreOptions {
213                omit_excerpt_overlaps,
214            },
215        }),
216        max_diagnostic_bytes: args.max_diagnostic_bytes,
217        max_prompt_bytes: args.max_prompt_bytes,
218        prompt_format: args.prompt_format.into(),
219        file_indexing_parallelism: args.file_indexing_parallelism,
220        buffer_change_grouping_interval: Duration::ZERO,
221    }
222}
223
// Prompt formats selectable with `--prompt-format`. Each variant maps onto a
// `predict_edits_v3::PromptFormat`; `NumberedLines` maps to `NumLinesUniDiff`, and every other
// variant maps to the variant of the same name.
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
    #[default]
    NumberedLines,
    OldTextNewText,
    Minimal,
    MinimalQwen,
    SeedCoder1120,
}
236
237impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
238    fn into(self) -> predict_edits_v3::PromptFormat {
239        match self {
240            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
241            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
242            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
243            Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
244            Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
245            Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
246            Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
247            Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
248        }
249    }
250}
251
// What `zeta2_syntax_context` returns: the prompt string, the request as pretty-printed JSON,
// or a JSON object with the request, prompt, and section labels.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
    Full,
}
259
260#[derive(Debug, Clone)]
261enum FileOrStdin {
262    File(PathBuf),
263    Stdin,
264}
265
266impl FileOrStdin {
267    async fn read_to_string(&self) -> Result<String, std::io::Error> {
268        match self {
269            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
270            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
271        }
272    }
273}
274
275impl FromStr for FileOrStdin {
276    type Err = <PathBuf as FromStr>::Err;
277
278    fn from_str(s: &str) -> Result<Self, Self::Err> {
279        match s {
280            "-" => Ok(Self::Stdin),
281            _ => Ok(Self::File(PathBuf::from_str(s)?)),
282        }
283    }
284}
285
/// Everything `load_context` produces for a single cursor location.
struct LoadedContext {
    /// Worktree root name joined with the cursor's path, rendered with the local path style.
    full_path_str: String,
    /// Snapshot of `buffer` taken right after it was opened.
    snapshot: BufferSnapshot,
    /// The requested cursor; `load_context` only returns when clipping left it unchanged.
    clipped_cursor: Point,
    worktree: Entity<Worktree>,
    project: Entity<Project>,
    buffer: Entity<Buffer>,
}
294
295async fn load_context(
296    args: &ContextArgs,
297    app_state: &Arc<ZetaCliAppState>,
298    cx: &mut AsyncApp,
299) -> Result<LoadedContext> {
300    let ContextArgs {
301        worktree: worktree_path,
302        cursor,
303        use_language_server,
304        ..
305    } = args;
306
307    let worktree_path = worktree_path.canonicalize()?;
308
309    let project = cx.update(|cx| {
310        Project::local(
311            app_state.client.clone(),
312            app_state.node_runtime.clone(),
313            app_state.user_store.clone(),
314            app_state.languages.clone(),
315            app_state.fs.clone(),
316            None,
317            cx,
318        )
319    })?;
320
321    let worktree = project
322        .update(cx, |project, cx| {
323            project.create_worktree(&worktree_path, true, cx)
324        })?
325        .await?;
326
327    let mut ready_languages = HashSet::default();
328    let (_lsp_open_handle, buffer) = if *use_language_server {
329        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
330            project.clone(),
331            worktree.clone(),
332            cursor.path.clone(),
333            &mut ready_languages,
334            cx,
335        )
336        .await?;
337        (Some(lsp_open_handle), buffer)
338    } else {
339        let buffer =
340            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
341        (None, buffer)
342    };
343
344    let full_path_str = worktree
345        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
346        .display(PathStyle::local())
347        .to_string();
348
349    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
350    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
351    if clipped_cursor != cursor.point {
352        let max_row = snapshot.max_point().row;
353        if cursor.point.row < max_row {
354            return Err(anyhow!(
355                "Cursor position {:?} is out of bounds (line length is {})",
356                cursor.point,
357                snapshot.line_len(cursor.point.row)
358            ));
359        } else {
360            return Err(anyhow!(
361                "Cursor position {:?} is out of bounds (max row is {})",
362                cursor.point,
363                max_row
364            ));
365        }
366    }
367
368    Ok(LoadedContext {
369        full_path_str,
370        snapshot,
371        clipped_cursor,
372        worktree,
373        project,
374        buffer,
375    })
376}
377
/// Builds a Zeta2 syntax-context request for the cursor in `args` and renders it according to
/// `args.zeta2_args.output_format`.
///
/// Returns the prompt string (`Prompt`), the pretty-printed request JSON (`Request`), or a
/// pretty-printed JSON object with `request`, `prompt`, and `section_labels` (`Full`).
///
/// # Errors
///
/// Propagates errors from `load_context`, initial indexing, building the cloud request,
/// `cloud_zeta2_prompt::build_prompt`, JSON serialization, and `cx.update`/`read_with`.
///
/// # Panics
///
/// Panics if the worktree is not local (`as_local().unwrap()`).
async fn zeta2_syntax_context(
    args: ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<String> {
    let LoadedContext {
        worktree,
        project,
        buffer,
        clipped_cursor,
        ..
    } = load_context(&args, app_state, cx).await?;

    // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
    // the whole worktree.
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;
    let output = cx
        .update(|cx| {
            let zeta = cx.new(|cx| {
                zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
            });
            // Configure zeta (with `omit_excerpt_overlaps = true`), register the cursor's buffer,
            // and obtain the task that completes once initial indexing finishes.
            let indexing_done_task = zeta.update(cx, |zeta, cx| {
                zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true));
                zeta.register_buffer(&buffer, &project, cx);
                zeta.wait_for_initial_indexing(&project, cx)
            });
            cx.spawn(async move |cx| {
                indexing_done_task.await?;
                let request = zeta
                    .update(cx, |zeta, cx| {
                        // The anchor is taken from a fresh snapshot read at this point.
                        let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
                        zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
                    })?
                    .await?;

                let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;

                // `anyhow::Ok` pins the async block's error type to `anyhow::Error`.
                match args.zeta2_args.output_format {
                    OutputFormat::Prompt => anyhow::Ok(prompt_string),
                    OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
                    OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
                        "request": request,
                        "prompt": prompt_string,
                        "section_labels": section_labels,
                    }))?),
                }
            })
        })?
        .await?;

    Ok(output)
}
434
435async fn zeta1_context(
436    args: ContextArgs,
437    app_state: &Arc<ZetaCliAppState>,
438    cx: &mut AsyncApp,
439) -> Result<zeta::zeta1::GatherContextOutput> {
440    let LoadedContext {
441        full_path_str,
442        snapshot,
443        clipped_cursor,
444        ..
445    } = load_context(&args, app_state, cx).await?;
446
447    let events = match args.edit_history {
448        Some(events) => events.read_to_string().await?,
449        None => String::new(),
450    };
451
452    let prompt_for_events = move || (events, 0);
453    cx.update(|cx| {
454        zeta::zeta1::gather_context(
455            full_path_str,
456            &snapshot,
457            clipped_cursor,
458            prompt_for_events,
459            cx,
460        )
461    })?
462    .await
463}
464
465fn main() {
466    zlog::init();
467    zlog::init_output_stderr();
468    let args = ZetaCliArgs::parse();
469    let http_client = Arc::new(ReqwestClient::new());
470    let app = Application::headless().with_http_client(http_client);
471
472    app.run(move |cx| {
473        let app_state = Arc::new(headless::init(cx));
474        cx.spawn(async move |cx| {
475            match args.command {
476                None => {
477                    if args.printenv {
478                        ::util::shell_env::print_env();
479                        return;
480                    } else {
481                        panic!("Expected a command");
482                    }
483                }
484                Some(Command::ContextStats(arguments)) => {
485                    let result = retrieval_stats(
486                        arguments.worktree,
487                        app_state,
488                        arguments.extension,
489                        arguments.limit,
490                        arguments.skip,
491                        zeta2_args_to_options(&arguments.zeta2_args, false),
492                        cx,
493                    )
494                    .await;
495                    println!("{}", result.unwrap());
496                }
497                Some(Command::Context(context_args)) => {
498                    let result = match context_args.provider {
499                        ContextProvider::Zeta1 => {
500                            let context =
501                                zeta1_context(context_args, &app_state, cx).await.unwrap();
502                            serde_json::to_string_pretty(&context.body).unwrap()
503                        }
504                        ContextProvider::Syntax => {
505                            zeta2_syntax_context(context_args, &app_state, cx)
506                                .await
507                                .unwrap()
508                        }
509                    };
510                    println!("{}", result);
511                }
512                Some(Command::Predict(arguments)) => {
513                    run_predict(arguments, &app_state, cx).await;
514                }
515                Some(Command::Eval(arguments)) => {
516                    run_evaluate(arguments, &app_state, cx).await;
517                }
518                Some(Command::ConvertExample {
519                    path,
520                    output_format,
521                }) => {
522                    let example = NamedExample::load(path).unwrap();
523                    example.write(output_format, io::stdout()).unwrap();
524                }
525                Some(Command::Clean) => {
526                    std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
527                }
528            };
529
530            let _ = cx.update(|cx| cx.quit());
531        })
532        .detach();
533    });
534}