main.rs

  1mod evaluate;
  2mod example;
  3mod headless;
  4mod paths;
  5mod predict;
  6mod source_location;
  7mod syntax_retrieval_stats;
  8mod util;
  9
 10use crate::evaluate::{EvaluateArguments, run_evaluate};
 11use crate::example::{ExampleFormat, NamedExample};
 12use crate::predict::{PredictArguments, run_zeta2_predict};
 13use crate::syntax_retrieval_stats::retrieval_stats;
 14use ::util::paths::PathStyle;
 15use anyhow::{Result, anyhow};
 16use clap::{Args, Parser, Subcommand};
 17use cloud_llm_client::predict_edits_v3;
 18use edit_prediction_context::{
 19    EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
 20};
 21use gpui::{Application, AsyncApp, Entity, prelude::*};
 22use language::{Bias, Buffer, BufferSnapshot, Point};
 23use project::{Project, Worktree};
 24use reqwest_client::ReqwestClient;
 25use serde_json::json;
 26use std::io::{self};
 27use std::time::Duration;
 28use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
 29use zeta2::ContextMode;
 30
 31use crate::headless::ZetaCliAppState;
 32use crate::source_location::SourceLocation;
 33use crate::util::{open_buffer, open_buffer_with_language_server};
 34
// Root command-line arguments for the `zeta` binary.
// Comments on clap-derived items are plain `//` on purpose: clap turns `///`
// doc comments into `--help` text.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // The subcommand to run; dispatched in `main`.
    #[command(subcommand)]
    command: Command,
}
 41
// Top-level subcommands of the `zeta` CLI.
#[derive(Subcommand, Debug)]
enum Command {
    // zeta1 tooling (see `Zeta1Command`).
    Zeta1 {
        #[command(subcommand)]
        command: Zeta1Command,
    },
    // zeta2 tooling (see `Zeta2Command`).
    Zeta2 {
        #[command(subcommand)]
        command: Zeta2Command,
    },
    // Load the example at `path` (via `NamedExample::load`) and write it to
    // stdout in `output_format` (`ExampleFormat::Md` by default).
    ConvertExample {
        path: PathBuf,
        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
        output_format: ExampleFormat,
    },
    // Recursively delete `paths::TARGET_ZETA_DIR`.
    Clean,
}
 59
 60#[derive(Subcommand, Debug)]
 61enum Zeta1Command {
 62    Context {
 63        #[clap(flatten)]
 64        context_args: ContextArgs,
 65    },
 66}
 67
 68#[derive(Subcommand, Debug)]
 69enum Zeta2Command {
 70    Syntax {
 71        #[clap(flatten)]
 72        args: Zeta2Args,
 73        #[clap(flatten)]
 74        syntax_args: Zeta2SyntaxArgs,
 75        #[command(subcommand)]
 76        command: Zeta2SyntaxCommand,
 77    },
 78    Predict(PredictArguments),
 79    Eval(EvaluateArguments),
 80}
 81
 82#[derive(Subcommand, Debug)]
 83enum Zeta2SyntaxCommand {
 84    Context {
 85        #[clap(flatten)]
 86        context_args: ContextArgs,
 87    },
 88    Stats {
 89        #[arg(long)]
 90        worktree: PathBuf,
 91        #[arg(long)]
 92        extension: Option<String>,
 93        #[arg(long)]
 94        limit: Option<usize>,
 95        #[arg(long)]
 96        skip: Option<usize>,
 97    },
 98}
 99
// Arguments shared by the `context` subcommands: which worktree to open and
// where the cursor is. Consumed by `load_context`.
#[derive(Debug, Args)]
// NOTE(review): `worktree` below is already a required (non-`Option`)
// argument, so this group constraint looks redundant; confirm before removing.
#[group(requires = "worktree")]
struct ContextArgs {
    // Worktree root directory; canonicalized before the project is created.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor location. Its `path` is opened inside the worktree (and joined
    // onto the worktree root name for display); its `point` must lie inside
    // the buffer or `load_context` fails.
    #[arg(long)]
    cursor: SourceLocation,
    // Open the buffer via `open_buffer_with_language_server` instead of
    // `open_buffer`.
    #[arg(long)]
    use_language_server: bool,
    // Edit-history source: a file path, or `-` for stdin. Only read by the
    // zeta1 `context` command.
    #[arg(long)]
    edit_history: Option<FileOrStdin>,
}
112
113#[derive(Debug, Args)]
114struct Zeta2Args {
115    #[arg(long, default_value_t = 8192)]
116    max_prompt_bytes: usize,
117    #[arg(long, default_value_t = 2048)]
118    max_excerpt_bytes: usize,
119    #[arg(long, default_value_t = 1024)]
120    min_excerpt_bytes: usize,
121    #[arg(long, default_value_t = 0.66)]
122    target_before_cursor_over_total_bytes: f32,
123    #[arg(long, default_value_t = 1024)]
124    max_diagnostic_bytes: usize,
125    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
126    prompt_format: PromptFormat,
127    #[arg(long, value_enum, default_value_t = Default::default())]
128    output_format: OutputFormat,
129    #[arg(long, default_value_t = 42)]
130    file_indexing_parallelism: usize,
131}
132
// Syntax-retrieval options for the `zeta2 syntax` subcommands.
#[derive(Debug, Args)]
struct Zeta2SyntaxArgs {
    // When set, `EditPredictionContextOptions::use_imports` is false.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    // Forwarded as `EditPredictionContextOptions::max_retrieved_declarations`.
    #[arg(long, default_value_t = u8::MAX)]
    max_retrieved_definitions: u8,
}
140
141fn syntax_args_to_options(
142    zeta2_args: &Zeta2Args,
143    syntax_args: &Zeta2SyntaxArgs,
144    omit_excerpt_overlaps: bool,
145) -> zeta2::ZetaOptions {
146    zeta2::ZetaOptions {
147        context: ContextMode::Syntax(EditPredictionContextOptions {
148            max_retrieved_declarations: syntax_args.max_retrieved_definitions,
149            use_imports: !syntax_args.disable_imports_gathering,
150            excerpt: EditPredictionExcerptOptions {
151                max_bytes: zeta2_args.max_excerpt_bytes,
152                min_bytes: zeta2_args.min_excerpt_bytes,
153                target_before_cursor_over_total_bytes: zeta2_args
154                    .target_before_cursor_over_total_bytes,
155            },
156            score: EditPredictionScoreOptions {
157                omit_excerpt_overlaps,
158            },
159        }),
160        max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
161        max_prompt_bytes: zeta2_args.max_prompt_bytes,
162        prompt_format: zeta2_args.prompt_format.into(),
163        file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
164        buffer_change_grouping_interval: Duration::ZERO,
165    }
166}
167
// CLI-facing prompt format choices (`--prompt-format`). Each variant maps onto
// a `predict_edits_v3::PromptFormat` with the same name, except
// `NumberedLines`, which maps to `NumLinesUniDiff`.
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum PromptFormat {
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
    #[default]
    NumberedLines,
    OldTextNewText,
}
177
178impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
179    fn into(self) -> predict_edits_v3::PromptFormat {
180        match self {
181            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
182            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
183            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
184            Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
185            Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
186        }
187    }
188}
189
// What `zeta2 syntax context` prints:
// - `Prompt`: the built prompt text (default);
// - `Request`: the cloud request as pretty-printed JSON;
// - `Full`: a pretty JSON object with `request`, `prompt`, and `section_labels`.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
    Full,
}
197
/// An input source given on the command line: a file path, or stdin when the
/// argument is `-` (see the `FromStr` impl).
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read from this file path.
    File(PathBuf),
    /// Read from the process's standard input.
    Stdin,
}
203
204impl FileOrStdin {
205    async fn read_to_string(&self) -> Result<String, std::io::Error> {
206        match self {
207            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
208            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
209        }
210    }
211}
212
213impl FromStr for FileOrStdin {
214    type Err = <PathBuf as FromStr>::Err;
215
216    fn from_str(s: &str) -> Result<Self, Self::Err> {
217        match s {
218            "-" => Ok(Self::Stdin),
219            _ => Ok(Self::File(PathBuf::from_str(s)?)),
220        }
221    }
222}
223
/// Everything `load_context` produces for a single cursor position.
struct LoadedContext {
    /// The cursor's path joined onto the worktree root name, rendered with the
    /// local path style.
    full_path_str: String,
    /// Snapshot of `buffer`, taken after it was opened.
    snapshot: BufferSnapshot,
    /// The cursor point, already validated to lie inside `snapshot`.
    clipped_cursor: Point,
    /// The worktree created for the `--worktree` directory.
    worktree: Entity<Worktree>,
    /// The local project that owns `worktree`.
    project: Entity<Project>,
    /// The buffer opened at the cursor's path.
    buffer: Entity<Buffer>,
}
232
233async fn load_context(
234    args: &ContextArgs,
235    app_state: &Arc<ZetaCliAppState>,
236    cx: &mut AsyncApp,
237) -> Result<LoadedContext> {
238    let ContextArgs {
239        worktree: worktree_path,
240        cursor,
241        use_language_server,
242        ..
243    } = args;
244
245    let worktree_path = worktree_path.canonicalize()?;
246
247    let project = cx.update(|cx| {
248        Project::local(
249            app_state.client.clone(),
250            app_state.node_runtime.clone(),
251            app_state.user_store.clone(),
252            app_state.languages.clone(),
253            app_state.fs.clone(),
254            None,
255            cx,
256        )
257    })?;
258
259    let worktree = project
260        .update(cx, |project, cx| {
261            project.create_worktree(&worktree_path, true, cx)
262        })?
263        .await?;
264
265    let mut ready_languages = HashSet::default();
266    let (_lsp_open_handle, buffer) = if *use_language_server {
267        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
268            project.clone(),
269            worktree.clone(),
270            cursor.path.clone(),
271            &mut ready_languages,
272            cx,
273        )
274        .await?;
275        (Some(lsp_open_handle), buffer)
276    } else {
277        let buffer =
278            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
279        (None, buffer)
280    };
281
282    let full_path_str = worktree
283        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
284        .display(PathStyle::local())
285        .to_string();
286
287    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
288    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
289    if clipped_cursor != cursor.point {
290        let max_row = snapshot.max_point().row;
291        if cursor.point.row < max_row {
292            return Err(anyhow!(
293                "Cursor position {:?} is out of bounds (line length is {})",
294                cursor.point,
295                snapshot.line_len(cursor.point.row)
296            ));
297        } else {
298            return Err(anyhow!(
299                "Cursor position {:?} is out of bounds (max row is {})",
300                cursor.point,
301                max_row
302            ));
303        }
304    }
305
306    Ok(LoadedContext {
307        full_path_str,
308        snapshot,
309        clipped_cursor,
310        worktree,
311        project,
312        buffer,
313    })
314}
315
/// Runs zeta2 syntax-based context retrieval for the cursor described by
/// `args` and renders the resulting cloud request.
///
/// The steps are:
/// 1. Load the project and buffer via `load_context`, then wait for the
///    worktree scan to finish.
/// 2. Create a `zeta2::Zeta` configured from `zeta2_args`/`syntax_args` (with
///    `omit_excerpt_overlaps = true`), register the buffer, and wait for
///    initial indexing.
/// 3. Build the cloud request at the validated cursor, then build the prompt
///    from it.
///
/// Returns, depending on `zeta2_args.output_format`, one of:
/// - the prompt text (`Prompt`);
/// - the pretty-printed request JSON (`Request`);
/// - a pretty JSON object with `request`, `prompt`, and `section_labels` (`Full`).
async fn zeta2_syntax_context(
    zeta2_args: Zeta2Args,
    syntax_args: Zeta2SyntaxArgs,
    args: ContextArgs,
    app_state: &Arc<ZetaCliAppState>,
    cx: &mut AsyncApp,
) -> Result<String> {
    let LoadedContext {
        worktree,
        project,
        buffer,
        clipped_cursor,
        ..
    } = load_context(&args, app_state, cx).await?;

    // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
    // the whole worktree.
    // `as_local()` is unwrapped on the assumption that the worktree is local,
    // since `load_context` creates the project with `Project::local`.
    worktree
        .read_with(cx, |worktree, _cx| {
            worktree.as_local().unwrap().scan_complete()
        })?
        .await;
    let output = cx
        .update(|cx| {
            let zeta = cx.new(|cx| {
                zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
            });
            let indexing_done_task = zeta.update(cx, |zeta, cx| {
                // `true` is `omit_excerpt_overlaps`; the `stats` subcommand passes `false`.
                zeta.set_options(syntax_args_to_options(&zeta2_args, &syntax_args, true));
                zeta.register_buffer(&buffer, &project, cx);
                zeta.wait_for_initial_indexing(&project, cx)
            });
            // The spawned task is returned from `update` and awaited below; it
            // resolves to the rendered output string.
            cx.spawn(async move |cx| {
                indexing_done_task.await?;
                let request = zeta
                    .update(cx, |zeta, cx| {
                        // Anchor the validated cursor point in the buffer's current snapshot.
                        let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
                        zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
                    })?
                    .await?;

                let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;

                match zeta2_args.output_format {
                    OutputFormat::Prompt => anyhow::Ok(prompt_string),
                    OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
                    OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
                        "request": request,
                        "prompt": prompt_string,
                        "section_labels": section_labels,
                    }))?),
                }
            })
        })?
        .await?;

    Ok(output)
}
374
375async fn zeta1_context(
376    args: ContextArgs,
377    app_state: &Arc<ZetaCliAppState>,
378    cx: &mut AsyncApp,
379) -> Result<zeta::GatherContextOutput> {
380    let LoadedContext {
381        full_path_str,
382        snapshot,
383        clipped_cursor,
384        ..
385    } = load_context(&args, app_state, cx).await?;
386
387    let events = match args.edit_history {
388        Some(events) => events.read_to_string().await?,
389        None => String::new(),
390    };
391
392    let prompt_for_events = move || (events, 0);
393    cx.update(|cx| {
394        zeta::gather_context(
395            full_path_str,
396            &snapshot,
397            clipped_cursor,
398            prompt_for_events,
399            cx,
400        )
401    })?
402    .await
403}
404
405fn main() {
406    zlog::init();
407    zlog::init_output_stderr();
408    let args = ZetaCliArgs::parse();
409    let http_client = Arc::new(ReqwestClient::new());
410    let app = Application::headless().with_http_client(http_client);
411
412    app.run(move |cx| {
413        let app_state = Arc::new(headless::init(cx));
414        cx.spawn(async move |cx| {
415            match args.command {
416                Command::Zeta1 {
417                    command: Zeta1Command::Context { context_args },
418                } => {
419                    let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
420                    let result = serde_json::to_string_pretty(&context.body).unwrap();
421                    println!("{}", result);
422                }
423                Command::Zeta2 { command } => match command {
424                    Zeta2Command::Predict(arguments) => {
425                        run_zeta2_predict(arguments, &app_state, cx).await;
426                    }
427                    Zeta2Command::Eval(arguments) => {
428                        run_evaluate(arguments, &app_state, cx).await;
429                    }
430                    Zeta2Command::Syntax {
431                        args,
432                        syntax_args,
433                        command,
434                    } => {
435                        let result = match command {
436                            Zeta2SyntaxCommand::Context { context_args } => {
437                                zeta2_syntax_context(
438                                    args,
439                                    syntax_args,
440                                    context_args,
441                                    &app_state,
442                                    cx,
443                                )
444                                .await
445                            }
446                            Zeta2SyntaxCommand::Stats {
447                                worktree,
448                                extension,
449                                limit,
450                                skip,
451                            } => {
452                                retrieval_stats(
453                                    worktree,
454                                    app_state,
455                                    extension,
456                                    limit,
457                                    skip,
458                                    syntax_args_to_options(&args, &syntax_args, false),
459                                    cx,
460                                )
461                                .await
462                            }
463                        };
464                        println!("{}", result.unwrap());
465                    }
466                },
467                Command::ConvertExample {
468                    path,
469                    output_format,
470                } => {
471                    let example = NamedExample::load(path).unwrap();
472                    example.write(output_format, io::stdout()).unwrap();
473                }
474                Command::Clean => std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap(),
475            };
476
477            let _ = cx.update(|cx| cx.quit());
478        })
479        .detach();
480    });
481}