main.rs

  1mod evaluate;
  2mod example;
  3mod headless;
  4mod paths;
  5mod predict;
  6mod source_location;
  7mod syntax_retrieval_stats;
  8mod util;
  9
 10use crate::evaluate::{EvaluateArguments, run_evaluate};
 11use crate::example::{ExampleFormat, NamedExample};
 12use crate::predict::{PredictArguments, run_zeta2_predict};
 13use crate::syntax_retrieval_stats::retrieval_stats;
 14use ::util::paths::PathStyle;
 15use anyhow::{Result, anyhow};
 16use clap::{Args, Parser, Subcommand};
 17use cloud_llm_client::predict_edits_v3;
 18use edit_prediction_context::{
 19    EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
 20};
 21use gpui::{Application, AsyncApp, Entity, prelude::*};
 22use language::{Bias, Buffer, BufferSnapshot, Point};
 23use project::{Project, Worktree};
 24use reqwest_client::ReqwestClient;
 25use serde_json::json;
 26use std::io::{self};
 27use std::time::Duration;
 28use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
 29use zeta2::ContextMode;
 30
 31use crate::headless::ZetaCliAppState;
 32use crate::source_location::SourceLocation;
 33use crate::util::{open_buffer, open_buffer_with_language_server};
 34
// Top-level command-line arguments of the `zeta` binary.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive macros turn
// `///` doc comments into `--help` text, which would change the CLI's output.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // When set and no subcommand is given, `main` calls
    // `util::shell_env::print_env` instead of running a command.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // The subcommand to run. `main` panics with "Expected a command" when this
    // is absent and `--printenv` was not passed.
    #[command(subcommand)]
    command: Option<Command>,
}
 43
// Top-level subcommands, dispatched in `main`.
//
// NOTE: plain `//` comments are deliberate; clap derive would surface `///`
// text in the generated help.
#[derive(Subcommand, Debug)]
enum Command {
    // Commands served by the `zeta` crate (see `zeta1_context`).
    Zeta1 {
        #[command(subcommand)]
        command: Zeta1Command,
    },
    // Commands served by the `zeta2` crate (syntax context, predict, eval).
    Zeta2 {
        #[command(subcommand)]
        command: Zeta2Command,
    },
    // Loads the example at `path` with `NamedExample::load` and writes it to
    // stdout in `output_format`.
    ConvertExample {
        path: PathBuf,
        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
        output_format: ExampleFormat,
    },
    // Recursively removes the directory named by `paths::TARGET_ZETA_DIR`.
    Clean,
}
 61
// Subcommands available under `zeta1`.
// (Plain `//` comments: clap derive would turn `///` into help text.)
#[derive(Subcommand, Debug)]
enum Zeta1Command {
    // Gathers context through `zeta1_context`; `main` prints the resulting
    // `body` as pretty-printed JSON.
    Context {
        #[clap(flatten)]
        context_args: ContextArgs,
    },
}
 69
// Subcommands available under `zeta2`.
// (Plain `//` comments: clap derive would turn `///` into help text.)
#[derive(Subcommand, Debug)]
enum Zeta2Command {
    // Syntax-based context commands. `args` and `syntax_args` are shared by
    // every nested `Zeta2SyntaxCommand` and are turned into `zeta2::ZetaOptions`
    // by `syntax_args_to_options`.
    Syntax {
        #[clap(flatten)]
        args: Zeta2Args,
        #[clap(flatten)]
        syntax_args: Zeta2SyntaxArgs,
        #[command(subcommand)]
        command: Zeta2SyntaxCommand,
    },
    // Handled by `predict::run_zeta2_predict`.
    Predict(PredictArguments),
    // Handled by `evaluate::run_evaluate`.
    Eval(EvaluateArguments),
}
 83
// Subcommands available under `zeta2 syntax`.
// (Plain `//` comments: clap derive would turn `///` into help text.)
#[derive(Subcommand, Debug)]
enum Zeta2SyntaxCommand {
    // Runs `zeta2_syntax_context` for a cursor position and prints its output.
    Context {
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    // Runs `syntax_retrieval_stats::retrieval_stats` and prints its output.
    // All four fields are forwarded to that function unchanged; their exact
    // meaning is defined there.
    Stats {
        #[arg(long)]
        worktree: PathBuf,
        #[arg(long)]
        extension: Option<String>,
        #[arg(long)]
        limit: Option<usize>,
        #[arg(long)]
        skip: Option<usize>,
    },
}
101
// Arguments identifying the buffer position to gather context for.
// Consumed by `load_context` (and `zeta1_context` for `edit_history`).
// (Plain `//` comments: clap derive would turn `///` into help text.)
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Worktree root; canonicalized and opened as a project worktree.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor location; its `path` selects the buffer to open and its `point`
    // is validated against the buffer contents.
    #[arg(long)]
    cursor: SourceLocation,
    // When set, the buffer is opened via `open_buffer_with_language_server`
    // instead of `open_buffer`.
    #[arg(long)]
    use_language_server: bool,
    // Source of the edit-history text read by `zeta1_context`; the value "-"
    // selects stdin (see `FileOrStdin::from_str`). Not read by
    // `zeta2_syntax_context` in this file.
    #[arg(long)]
    edit_history: Option<FileOrStdin>,
}
114
// Options shared by the `zeta2 syntax` subcommands. Except for
// `output_format`, every field is copied into `zeta2::ZetaOptions` by
// `syntax_args_to_options`.
// (Plain `//` comments: clap derive would turn `///` into help text.)
#[derive(Debug, Args)]
struct Zeta2Args {
    // Copied to `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Copied to `EditPredictionExcerptOptions::max_bytes`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // Copied to `EditPredictionExcerptOptions::min_bytes`.
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Copied to
    // `EditPredictionExcerptOptions::target_before_cursor_over_total_bytes`.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Copied to `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // Converted into `predict_edits_v3::PromptFormat` for
    // `ZetaOptions::prompt_format`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // Selects what `zeta2_syntax_context` returns (prompt, request, or both).
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // Copied to `ZetaOptions::file_indexing_parallelism`.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
}
134
// Options specific to syntax-based context retrieval; consumed by
// `syntax_args_to_options`.
// (Plain `//` comments: clap derive would turn `///` into help text.)
#[derive(Debug, Args)]
struct Zeta2SyntaxArgs {
    // Negated to produce `EditPredictionContextOptions::use_imports`.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    // Copied to `EditPredictionContextOptions::max_retrieved_declarations`.
    #[arg(long, default_value_t = u8::MAX)]
    max_retrieved_definitions: u8,
}
142
143fn syntax_args_to_options(
144    zeta2_args: &Zeta2Args,
145    syntax_args: &Zeta2SyntaxArgs,
146    omit_excerpt_overlaps: bool,
147) -> zeta2::ZetaOptions {
148    zeta2::ZetaOptions {
149        context: ContextMode::Syntax(EditPredictionContextOptions {
150            max_retrieved_declarations: syntax_args.max_retrieved_definitions,
151            use_imports: !syntax_args.disable_imports_gathering,
152            excerpt: EditPredictionExcerptOptions {
153                max_bytes: zeta2_args.max_excerpt_bytes,
154                min_bytes: zeta2_args.min_excerpt_bytes,
155                target_before_cursor_over_total_bytes: zeta2_args
156                    .target_before_cursor_over_total_bytes,
157            },
158            score: EditPredictionScoreOptions {
159                omit_excerpt_overlaps,
160            },
161        }),
162        max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
163        max_prompt_bytes: zeta2_args.max_prompt_bytes,
164        prompt_format: zeta2_args.prompt_format.into(),
165        file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
166        buffer_change_grouping_interval: Duration::ZERO,
167    }
168}
169
// CLI-facing choice of prompt format. Each variant is converted to a
// `predict_edits_v3::PromptFormat` variant before being stored in
// `zeta2::ZetaOptions`.
// (Plain `//` comments: `clap::ValueEnum` would turn `///` on variants into
// help text for the possible values.)
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
    // Default choice. Note the name differs from its wire counterpart: it is
    // converted to `predict_edits_v3::PromptFormat::NumLinesUniDiff`.
    #[default]
    NumberedLines,
    OldTextNewText,
    Minimal,
    MinimalQwen,
}
181
182impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
183    fn into(self) -> predict_edits_v3::PromptFormat {
184        match self {
185            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
186            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
187            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
188            Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
189            Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
190            Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
191            Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
192        }
193    }
194}
195
// Selects what `zeta2_syntax_context` returns as its output string.
// (Plain `//` comments: `clap::ValueEnum` would turn `///` on variants into
// help text for the possible values.)
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    // Only the prompt string built by `cloud_zeta2_prompt::build_prompt`.
    #[default]
    Prompt,
    // The request, serialized as pretty-printed JSON.
    Request,
    // A pretty-printed JSON object with the keys "request", "prompt" and
    // "section_labels".
    Full,
}
203
/// An input source given on the command line: either a file path or stdin.
///
/// Parsed through its `FromStr` impl, where the literal `-` selects stdin and
/// any other value is treated as a path.
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read the input from the file at this path.
    File(PathBuf),
    /// Read the input from the process's standard input.
    Stdin,
}
209
210impl FileOrStdin {
211    async fn read_to_string(&self) -> Result<String, std::io::Error> {
212        match self {
213            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
214            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
215        }
216    }
217}
218
219impl FromStr for FileOrStdin {
220    type Err = <PathBuf as FromStr>::Err;
221
222    fn from_str(s: &str) -> Result<Self, Self::Err> {
223        match s {
224            "-" => Ok(Self::Stdin),
225            _ => Ok(Self::File(PathBuf::from_str(s)?)),
226        }
227    }
228}
229
/// Everything `load_context` produces for a cursor position: the opened
/// project entities plus the resolved buffer state.
struct LoadedContext {
    /// Worktree root name joined with the cursor's path, rendered with
    /// `PathStyle::local()`.
    full_path_str: String,
    /// Snapshot of the buffer taken right after it was opened.
    snapshot: BufferSnapshot,
    /// The cursor point after `clip_point(.., Bias::Left)`. `load_context`
    /// fails when clipping changes the point, so on success this equals the
    /// requested cursor point.
    clipped_cursor: Point,
    /// The worktree created for the requested worktree path.
    worktree: Entity<Worktree>,
    /// The local project that owns `worktree`.
    project: Entity<Project>,
    /// The buffer opened at the cursor's path.
    buffer: Entity<Buffer>,
}
238
239async fn load_context(
240    args: &ContextArgs,
241    app_state: &Arc<ZetaCliAppState>,
242    cx: &mut AsyncApp,
243) -> Result<LoadedContext> {
244    let ContextArgs {
245        worktree: worktree_path,
246        cursor,
247        use_language_server,
248        ..
249    } = args;
250
251    let worktree_path = worktree_path.canonicalize()?;
252
253    let project = cx.update(|cx| {
254        Project::local(
255            app_state.client.clone(),
256            app_state.node_runtime.clone(),
257            app_state.user_store.clone(),
258            app_state.languages.clone(),
259            app_state.fs.clone(),
260            None,
261            cx,
262        )
263    })?;
264
265    let worktree = project
266        .update(cx, |project, cx| {
267            project.create_worktree(&worktree_path, true, cx)
268        })?
269        .await?;
270
271    let mut ready_languages = HashSet::default();
272    let (_lsp_open_handle, buffer) = if *use_language_server {
273        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
274            project.clone(),
275            worktree.clone(),
276            cursor.path.clone(),
277            &mut ready_languages,
278            cx,
279        )
280        .await?;
281        (Some(lsp_open_handle), buffer)
282    } else {
283        let buffer =
284            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
285        (None, buffer)
286    };
287
288    let full_path_str = worktree
289        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
290        .display(PathStyle::local())
291        .to_string();
292
293    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
294    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
295    if clipped_cursor != cursor.point {
296        let max_row = snapshot.max_point().row;
297        if cursor.point.row < max_row {
298            return Err(anyhow!(
299                "Cursor position {:?} is out of bounds (line length is {})",
300                cursor.point,
301                snapshot.line_len(cursor.point.row)
302            ));
303        } else {
304            return Err(anyhow!(
305                "Cursor position {:?} is out of bounds (max row is {})",
306                cursor.point,
307                max_row
308            ));
309        }
310    }
311
312    Ok(LoadedContext {
313        full_path_str,
314        snapshot,
315        clipped_cursor,
316        worktree,
317        project,
318        buffer,
319    })
320}
321
322async fn zeta2_syntax_context(
323    zeta2_args: Zeta2Args,
324    syntax_args: Zeta2SyntaxArgs,
325    args: ContextArgs,
326    app_state: &Arc<ZetaCliAppState>,
327    cx: &mut AsyncApp,
328) -> Result<String> {
329    let LoadedContext {
330        worktree,
331        project,
332        buffer,
333        clipped_cursor,
334        ..
335    } = load_context(&args, app_state, cx).await?;
336
337    // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
338    // the whole worktree.
339    worktree
340        .read_with(cx, |worktree, _cx| {
341            worktree.as_local().unwrap().scan_complete()
342        })?
343        .await;
344    let output = cx
345        .update(|cx| {
346            let zeta = cx.new(|cx| {
347                zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
348            });
349            let indexing_done_task = zeta.update(cx, |zeta, cx| {
350                zeta.set_options(syntax_args_to_options(&zeta2_args, &syntax_args, true));
351                zeta.register_buffer(&buffer, &project, cx);
352                zeta.wait_for_initial_indexing(&project, cx)
353            });
354            cx.spawn(async move |cx| {
355                indexing_done_task.await?;
356                let request = zeta
357                    .update(cx, |zeta, cx| {
358                        let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
359                        zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
360                    })?
361                    .await?;
362
363                let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;
364
365                match zeta2_args.output_format {
366                    OutputFormat::Prompt => anyhow::Ok(prompt_string),
367                    OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
368                    OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
369                        "request": request,
370                        "prompt": prompt_string,
371                        "section_labels": section_labels,
372                    }))?),
373                }
374            })
375        })?
376        .await?;
377
378    Ok(output)
379}
380
381async fn zeta1_context(
382    args: ContextArgs,
383    app_state: &Arc<ZetaCliAppState>,
384    cx: &mut AsyncApp,
385) -> Result<zeta::GatherContextOutput> {
386    let LoadedContext {
387        full_path_str,
388        snapshot,
389        clipped_cursor,
390        ..
391    } = load_context(&args, app_state, cx).await?;
392
393    let events = match args.edit_history {
394        Some(events) => events.read_to_string().await?,
395        None => String::new(),
396    };
397
398    let prompt_for_events = move || (events, 0);
399    cx.update(|cx| {
400        zeta::gather_context(
401            full_path_str,
402            &snapshot,
403            clipped_cursor,
404            prompt_for_events,
405            cx,
406        )
407    })?
408    .await
409}
410
411fn main() {
412    zlog::init();
413    zlog::init_output_stderr();
414    let args = ZetaCliArgs::parse();
415    let http_client = Arc::new(ReqwestClient::new());
416    let app = Application::headless().with_http_client(http_client);
417
418    app.run(move |cx| {
419        let app_state = Arc::new(headless::init(cx));
420        cx.spawn(async move |cx| {
421            match args.command {
422                None => {
423                    if args.printenv {
424                        ::util::shell_env::print_env();
425                        return;
426                    } else {
427                        panic!("Expected a command");
428                    }
429                }
430                Some(Command::Zeta1 {
431                    command: Zeta1Command::Context { context_args },
432                }) => {
433                    let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
434                    let result = serde_json::to_string_pretty(&context.body).unwrap();
435                    println!("{}", result);
436                }
437                Some(Command::Zeta2 { command }) => match command {
438                    Zeta2Command::Predict(arguments) => {
439                        run_zeta2_predict(arguments, &app_state, cx).await;
440                    }
441                    Zeta2Command::Eval(arguments) => {
442                        run_evaluate(arguments, &app_state, cx).await;
443                    }
444                    Zeta2Command::Syntax {
445                        args,
446                        syntax_args,
447                        command,
448                    } => {
449                        let result = match command {
450                            Zeta2SyntaxCommand::Context { context_args } => {
451                                zeta2_syntax_context(
452                                    args,
453                                    syntax_args,
454                                    context_args,
455                                    &app_state,
456                                    cx,
457                                )
458                                .await
459                            }
460                            Zeta2SyntaxCommand::Stats {
461                                worktree,
462                                extension,
463                                limit,
464                                skip,
465                            } => {
466                                retrieval_stats(
467                                    worktree,
468                                    app_state,
469                                    extension,
470                                    limit,
471                                    skip,
472                                    syntax_args_to_options(&args, &syntax_args, false),
473                                    cx,
474                                )
475                                .await
476                            }
477                        };
478                        println!("{}", result.unwrap());
479                    }
480                },
481                Some(Command::ConvertExample {
482                    path,
483                    output_format,
484                }) => {
485                    let example = NamedExample::load(path).unwrap();
486                    example.write(output_format, io::stdout()).unwrap();
487                }
488                Some(Command::Clean) => {
489                    std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
490                }
491            };
492
493            let _ = cx.update(|cx| cx.quit());
494        })
495        .detach();
496    });
497}