main.rs

  1mod evaluate;
  2mod example;
  3mod headless;
  4mod paths;
  5mod predict;
  6mod source_location;
  7mod syntax_retrieval_stats;
  8mod util;
  9
 10use crate::evaluate::{EvaluateArguments, run_evaluate};
 11use crate::example::{ExampleFormat, NamedExample};
 12use crate::predict::{PredictArguments, run_zeta2_predict};
 13use crate::syntax_retrieval_stats::retrieval_stats;
 14use ::util::paths::PathStyle;
 15use anyhow::{Result, anyhow};
 16use clap::{Args, Parser, Subcommand};
 17use cloud_llm_client::predict_edits_v3;
 18use edit_prediction_context::{
 19    EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
 20};
 21use gpui::{Application, AsyncApp, Entity, prelude::*};
 22use language::{Bias, Buffer, BufferSnapshot, Point};
 23use project::{Project, Worktree};
 24use reqwest_client::ReqwestClient;
 25use serde_json::json;
 26use std::io::{self};
 27use std::time::Duration;
 28use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
 29use zeta2::ContextMode;
 30
 31use crate::headless::ZetaCliAppState;
 32use crate::source_location::SourceLocation;
 33use crate::util::{open_buffer, open_buffer_with_language_server};
 34
 35#[derive(Parser, Debug)]
 36#[command(name = "zeta")]
 37struct ZetaCliArgs {
 38    #[arg(long, default_value_t = false)]
 39    printenv: bool,
 40    #[command(subcommand)]
 41    command: Option<Command>,
 42}
 43
 44#[derive(Subcommand, Debug)]
 45enum Command {
 46    Zeta1 {
 47        #[command(subcommand)]
 48        command: Zeta1Command,
 49    },
 50    Zeta2 {
 51        #[command(subcommand)]
 52        command: Zeta2Command,
 53    },
 54    ConvertExample {
 55        path: PathBuf,
 56        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
 57        output_format: ExampleFormat,
 58    },
 59    Clean,
 60}
 61
 62#[derive(Subcommand, Debug)]
 63enum Zeta1Command {
 64    Context {
 65        #[clap(flatten)]
 66        context_args: ContextArgs,
 67    },
 68}
 69
 70#[derive(Subcommand, Debug)]
 71enum Zeta2Command {
 72    Syntax {
 73        #[clap(flatten)]
 74        args: Zeta2Args,
 75        #[clap(flatten)]
 76        syntax_args: Zeta2SyntaxArgs,
 77        #[command(subcommand)]
 78        command: Zeta2SyntaxCommand,
 79    },
 80    Predict(PredictArguments),
 81    Eval(EvaluateArguments),
 82}
 83
 84#[derive(Subcommand, Debug)]
 85enum Zeta2SyntaxCommand {
 86    Context {
 87        #[clap(flatten)]
 88        context_args: ContextArgs,
 89    },
 90    Stats {
 91        #[arg(long)]
 92        worktree: PathBuf,
 93        #[arg(long)]
 94        extension: Option<String>,
 95        #[arg(long)]
 96        limit: Option<usize>,
 97        #[arg(long)]
 98        skip: Option<usize>,
 99    },
100}
101
102#[derive(Debug, Args)]
103#[group(requires = "worktree")]
104struct ContextArgs {
105    #[arg(long)]
106    worktree: PathBuf,
107    #[arg(long)]
108    cursor: SourceLocation,
109    #[arg(long)]
110    use_language_server: bool,
111    #[arg(long)]
112    edit_history: Option<FileOrStdin>,
113}
114
115#[derive(Debug, Args)]
116struct Zeta2Args {
117    #[arg(long, default_value_t = 8192)]
118    max_prompt_bytes: usize,
119    #[arg(long, default_value_t = 2048)]
120    max_excerpt_bytes: usize,
121    #[arg(long, default_value_t = 1024)]
122    min_excerpt_bytes: usize,
123    #[arg(long, default_value_t = 0.66)]
124    target_before_cursor_over_total_bytes: f32,
125    #[arg(long, default_value_t = 1024)]
126    max_diagnostic_bytes: usize,
127    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
128    prompt_format: PromptFormat,
129    #[arg(long, value_enum, default_value_t = Default::default())]
130    output_format: OutputFormat,
131    #[arg(long, default_value_t = 42)]
132    file_indexing_parallelism: usize,
133}
134
135#[derive(Debug, Args)]
136struct Zeta2SyntaxArgs {
137    #[arg(long, default_value_t = false)]
138    disable_imports_gathering: bool,
139    #[arg(long, default_value_t = u8::MAX)]
140    max_retrieved_definitions: u8,
141}
142
143fn syntax_args_to_options(
144    zeta2_args: &Zeta2Args,
145    syntax_args: &Zeta2SyntaxArgs,
146    omit_excerpt_overlaps: bool,
147) -> zeta2::ZetaOptions {
148    zeta2::ZetaOptions {
149        context: ContextMode::Syntax(EditPredictionContextOptions {
150            max_retrieved_declarations: syntax_args.max_retrieved_definitions,
151            use_imports: !syntax_args.disable_imports_gathering,
152            excerpt: EditPredictionExcerptOptions {
153                max_bytes: zeta2_args.max_excerpt_bytes,
154                min_bytes: zeta2_args.min_excerpt_bytes,
155                target_before_cursor_over_total_bytes: zeta2_args
156                    .target_before_cursor_over_total_bytes,
157            },
158            score: EditPredictionScoreOptions {
159                omit_excerpt_overlaps,
160            },
161        }),
162        max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
163        max_prompt_bytes: zeta2_args.max_prompt_bytes,
164        prompt_format: zeta2_args.prompt_format.into(),
165        file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
166        buffer_change_grouping_interval: Duration::ZERO,
167    }
168}
169
170#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
171enum PromptFormat {
172    MarkedExcerpt,
173    LabeledSections,
174    OnlySnippets,
175    #[default]
176    NumberedLines,
177    OldTextNewText,
178    Minimal,
179}
180
181impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
182    fn into(self) -> predict_edits_v3::PromptFormat {
183        match self {
184            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
185            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
186            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
187            Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
188            Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
189            Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
190        }
191    }
192}
193
194#[derive(clap::ValueEnum, Default, Debug, Clone)]
195enum OutputFormat {
196    #[default]
197    Prompt,
198    Request,
199    Full,
200}
201
/// Input source given on the command line: a file path, or `-` for standard input.
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read from this file.
    File(PathBuf),
    /// Read from standard input.
    Stdin,
}
207
208impl FileOrStdin {
209    async fn read_to_string(&self) -> Result<String, std::io::Error> {
210        match self {
211            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
212            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
213        }
214    }
215}
216
217impl FromStr for FileOrStdin {
218    type Err = <PathBuf as FromStr>::Err;
219
220    fn from_str(s: &str) -> Result<Self, Self::Err> {
221        match s {
222            "-" => Ok(Self::Stdin),
223            _ => Ok(Self::File(PathBuf::from_str(s)?)),
224        }
225    }
226}
227
228struct LoadedContext {
229    full_path_str: String,
230    snapshot: BufferSnapshot,
231    clipped_cursor: Point,
232    worktree: Entity<Worktree>,
233    project: Entity<Project>,
234    buffer: Entity<Buffer>,
235}
236
237async fn load_context(
238    args: &ContextArgs,
239    app_state: &Arc<ZetaCliAppState>,
240    cx: &mut AsyncApp,
241) -> Result<LoadedContext> {
242    let ContextArgs {
243        worktree: worktree_path,
244        cursor,
245        use_language_server,
246        ..
247    } = args;
248
249    let worktree_path = worktree_path.canonicalize()?;
250
251    let project = cx.update(|cx| {
252        Project::local(
253            app_state.client.clone(),
254            app_state.node_runtime.clone(),
255            app_state.user_store.clone(),
256            app_state.languages.clone(),
257            app_state.fs.clone(),
258            None,
259            cx,
260        )
261    })?;
262
263    let worktree = project
264        .update(cx, |project, cx| {
265            project.create_worktree(&worktree_path, true, cx)
266        })?
267        .await?;
268
269    let mut ready_languages = HashSet::default();
270    let (_lsp_open_handle, buffer) = if *use_language_server {
271        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
272            project.clone(),
273            worktree.clone(),
274            cursor.path.clone(),
275            &mut ready_languages,
276            cx,
277        )
278        .await?;
279        (Some(lsp_open_handle), buffer)
280    } else {
281        let buffer =
282            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
283        (None, buffer)
284    };
285
286    let full_path_str = worktree
287        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
288        .display(PathStyle::local())
289        .to_string();
290
291    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
292    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
293    if clipped_cursor != cursor.point {
294        let max_row = snapshot.max_point().row;
295        if cursor.point.row < max_row {
296            return Err(anyhow!(
297                "Cursor position {:?} is out of bounds (line length is {})",
298                cursor.point,
299                snapshot.line_len(cursor.point.row)
300            ));
301        } else {
302            return Err(anyhow!(
303                "Cursor position {:?} is out of bounds (max row is {})",
304                cursor.point,
305                max_row
306            ));
307        }
308    }
309
310    Ok(LoadedContext {
311        full_path_str,
312        snapshot,
313        clipped_cursor,
314        worktree,
315        project,
316        buffer,
317    })
318}
319
320async fn zeta2_syntax_context(
321    zeta2_args: Zeta2Args,
322    syntax_args: Zeta2SyntaxArgs,
323    args: ContextArgs,
324    app_state: &Arc<ZetaCliAppState>,
325    cx: &mut AsyncApp,
326) -> Result<String> {
327    let LoadedContext {
328        worktree,
329        project,
330        buffer,
331        clipped_cursor,
332        ..
333    } = load_context(&args, app_state, cx).await?;
334
335    // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
336    // the whole worktree.
337    worktree
338        .read_with(cx, |worktree, _cx| {
339            worktree.as_local().unwrap().scan_complete()
340        })?
341        .await;
342    let output = cx
343        .update(|cx| {
344            let zeta = cx.new(|cx| {
345                zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
346            });
347            let indexing_done_task = zeta.update(cx, |zeta, cx| {
348                zeta.set_options(syntax_args_to_options(&zeta2_args, &syntax_args, true));
349                zeta.register_buffer(&buffer, &project, cx);
350                zeta.wait_for_initial_indexing(&project, cx)
351            });
352            cx.spawn(async move |cx| {
353                indexing_done_task.await?;
354                let request = zeta
355                    .update(cx, |zeta, cx| {
356                        let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
357                        zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
358                    })?
359                    .await?;
360
361                let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;
362
363                match zeta2_args.output_format {
364                    OutputFormat::Prompt => anyhow::Ok(prompt_string),
365                    OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
366                    OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
367                        "request": request,
368                        "prompt": prompt_string,
369                        "section_labels": section_labels,
370                    }))?),
371                }
372            })
373        })?
374        .await?;
375
376    Ok(output)
377}
378
379async fn zeta1_context(
380    args: ContextArgs,
381    app_state: &Arc<ZetaCliAppState>,
382    cx: &mut AsyncApp,
383) -> Result<zeta::GatherContextOutput> {
384    let LoadedContext {
385        full_path_str,
386        snapshot,
387        clipped_cursor,
388        ..
389    } = load_context(&args, app_state, cx).await?;
390
391    let events = match args.edit_history {
392        Some(events) => events.read_to_string().await?,
393        None => String::new(),
394    };
395
396    let prompt_for_events = move || (events, 0);
397    cx.update(|cx| {
398        zeta::gather_context(
399            full_path_str,
400            &snapshot,
401            clipped_cursor,
402            prompt_for_events,
403            cx,
404        )
405    })?
406    .await
407}
408
409fn main() {
410    zlog::init();
411    zlog::init_output_stderr();
412    let args = ZetaCliArgs::parse();
413    let http_client = Arc::new(ReqwestClient::new());
414    let app = Application::headless().with_http_client(http_client);
415
416    app.run(move |cx| {
417        let app_state = Arc::new(headless::init(cx));
418        cx.spawn(async move |cx| {
419            match args.command {
420                None => {
421                    if args.printenv {
422                        ::util::shell_env::print_env();
423                        return;
424                    } else {
425                        panic!("Expected a command");
426                    }
427                }
428                Some(Command::Zeta1 {
429                    command: Zeta1Command::Context { context_args },
430                }) => {
431                    let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
432                    let result = serde_json::to_string_pretty(&context.body).unwrap();
433                    println!("{}", result);
434                }
435                Some(Command::Zeta2 { command }) => match command {
436                    Zeta2Command::Predict(arguments) => {
437                        run_zeta2_predict(arguments, &app_state, cx).await;
438                    }
439                    Zeta2Command::Eval(arguments) => {
440                        run_evaluate(arguments, &app_state, cx).await;
441                    }
442                    Zeta2Command::Syntax {
443                        args,
444                        syntax_args,
445                        command,
446                    } => {
447                        let result = match command {
448                            Zeta2SyntaxCommand::Context { context_args } => {
449                                zeta2_syntax_context(
450                                    args,
451                                    syntax_args,
452                                    context_args,
453                                    &app_state,
454                                    cx,
455                                )
456                                .await
457                            }
458                            Zeta2SyntaxCommand::Stats {
459                                worktree,
460                                extension,
461                                limit,
462                                skip,
463                            } => {
464                                retrieval_stats(
465                                    worktree,
466                                    app_state,
467                                    extension,
468                                    limit,
469                                    skip,
470                                    syntax_args_to_options(&args, &syntax_args, false),
471                                    cx,
472                                )
473                                .await
474                            }
475                        };
476                        println!("{}", result.unwrap());
477                    }
478                },
479                Some(Command::ConvertExample {
480                    path,
481                    output_format,
482                }) => {
483                    let example = NamedExample::load(path).unwrap();
484                    example.write(output_format, io::stdout()).unwrap();
485                }
486                Some(Command::Clean) => {
487                    std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
488                }
489            };
490
491            let _ = cx.update(|cx| cx.quit());
492        })
493        .detach();
494    });
495}