main.rs

  1mod evaluate;
  2mod example;
  3mod headless;
  4mod paths;
  5mod predict;
  6mod source_location;
  7mod syntax_retrieval_stats;
  8mod util;
  9
 10use crate::evaluate::{EvaluateArguments, run_evaluate};
 11use crate::example::{ExampleFormat, NamedExample};
 12use crate::predict::{PredictArguments, run_zeta2_predict};
 13use crate::syntax_retrieval_stats::retrieval_stats;
 14use ::util::paths::PathStyle;
 15use anyhow::{Result, anyhow};
 16use clap::{Args, Parser, Subcommand};
 17use cloud_llm_client::predict_edits_v3;
 18use edit_prediction_context::{
 19    EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
 20};
 21use gpui::{Application, AsyncApp, Entity, prelude::*};
 22use language::{Bias, Buffer, BufferSnapshot, Point};
 23use project::{Project, Worktree};
 24use reqwest_client::ReqwestClient;
 25use serde_json::json;
 26use std::io::{self};
 27use std::time::Duration;
 28use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
 29use zeta2::ContextMode;
 30
 31use crate::headless::ZetaCliAppState;
 32use crate::source_location::SourceLocation;
 33use crate::util::{open_buffer, open_buffer_with_language_server};
 34
 35#[derive(Parser, Debug)]
 36#[command(name = "zeta")]
 37struct ZetaCliArgs {
 38    #[command(subcommand)]
 39    command: Command,
 40}
 41
 42#[derive(Subcommand, Debug)]
 43enum Command {
 44    Zeta1 {
 45        #[command(subcommand)]
 46        command: Zeta1Command,
 47    },
 48    Zeta2 {
 49        #[command(subcommand)]
 50        command: Zeta2Command,
 51    },
 52    ConvertExample {
 53        path: PathBuf,
 54        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
 55        output_format: ExampleFormat,
 56    },
 57}
 58
 59#[derive(Subcommand, Debug)]
 60enum Zeta1Command {
 61    Context {
 62        #[clap(flatten)]
 63        context_args: ContextArgs,
 64    },
 65}
 66
 67#[derive(Subcommand, Debug)]
 68enum Zeta2Command {
 69    Syntax {
 70        #[clap(flatten)]
 71        args: Zeta2Args,
 72        #[clap(flatten)]
 73        syntax_args: Zeta2SyntaxArgs,
 74        #[command(subcommand)]
 75        command: Zeta2SyntaxCommand,
 76    },
 77    Predict(PredictArguments),
 78    Eval(EvaluateArguments),
 79}
 80
 81#[derive(Subcommand, Debug)]
 82enum Zeta2SyntaxCommand {
 83    Context {
 84        #[clap(flatten)]
 85        context_args: ContextArgs,
 86    },
 87    Stats {
 88        #[arg(long)]
 89        worktree: PathBuf,
 90        #[arg(long)]
 91        extension: Option<String>,
 92        #[arg(long)]
 93        limit: Option<usize>,
 94        #[arg(long)]
 95        skip: Option<usize>,
 96    },
 97}
 98
 99#[derive(Debug, Args)]
100#[group(requires = "worktree")]
101struct ContextArgs {
102    #[arg(long)]
103    worktree: PathBuf,
104    #[arg(long)]
105    cursor: SourceLocation,
106    #[arg(long)]
107    use_language_server: bool,
108    #[arg(long)]
109    edit_history: Option<FileOrStdin>,
110}
111
112#[derive(Debug, Args)]
113struct Zeta2Args {
114    #[arg(long, default_value_t = 8192)]
115    max_prompt_bytes: usize,
116    #[arg(long, default_value_t = 2048)]
117    max_excerpt_bytes: usize,
118    #[arg(long, default_value_t = 1024)]
119    min_excerpt_bytes: usize,
120    #[arg(long, default_value_t = 0.66)]
121    target_before_cursor_over_total_bytes: f32,
122    #[arg(long, default_value_t = 1024)]
123    max_diagnostic_bytes: usize,
124    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
125    prompt_format: PromptFormat,
126    #[arg(long, value_enum, default_value_t = Default::default())]
127    output_format: OutputFormat,
128    #[arg(long, default_value_t = 42)]
129    file_indexing_parallelism: usize,
130}
131
132#[derive(Debug, Args)]
133struct Zeta2SyntaxArgs {
134    #[arg(long, default_value_t = false)]
135    disable_imports_gathering: bool,
136    #[arg(long, default_value_t = u8::MAX)]
137    max_retrieved_definitions: u8,
138}
139
140fn syntax_args_to_options(
141    zeta2_args: &Zeta2Args,
142    syntax_args: &Zeta2SyntaxArgs,
143    omit_excerpt_overlaps: bool,
144) -> zeta2::ZetaOptions {
145    zeta2::ZetaOptions {
146        context: ContextMode::Syntax(EditPredictionContextOptions {
147            max_retrieved_declarations: syntax_args.max_retrieved_definitions,
148            use_imports: !syntax_args.disable_imports_gathering,
149            excerpt: EditPredictionExcerptOptions {
150                max_bytes: zeta2_args.max_excerpt_bytes,
151                min_bytes: zeta2_args.min_excerpt_bytes,
152                target_before_cursor_over_total_bytes: zeta2_args
153                    .target_before_cursor_over_total_bytes,
154            },
155            score: EditPredictionScoreOptions {
156                omit_excerpt_overlaps,
157            },
158        }),
159        max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
160        max_prompt_bytes: zeta2_args.max_prompt_bytes,
161        prompt_format: zeta2_args.prompt_format.into(),
162        file_indexing_parallelism: zeta2_args.file_indexing_parallelism,
163        buffer_change_grouping_interval: Duration::ZERO,
164    }
165}
166
167#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
168enum PromptFormat {
169    MarkedExcerpt,
170    LabeledSections,
171    OnlySnippets,
172    #[default]
173    NumberedLines,
174    OldTextNewText,
175}
176
177impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
178    fn into(self) -> predict_edits_v3::PromptFormat {
179        match self {
180            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
181            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
182            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
183            Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
184            Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
185        }
186    }
187}
188
189#[derive(clap::ValueEnum, Default, Debug, Clone)]
190enum OutputFormat {
191    #[default]
192    Prompt,
193    Request,
194    Full,
195}
196
/// A text input source: either a file on disk or the process's standard input.
///
/// Parsed from a command-line value, where `-` selects standard input.
#[derive(Clone, Debug)]
enum FileOrStdin {
    /// Read from the file at this path.
    File(PathBuf),
    /// Read from standard input.
    Stdin,
}
202
203impl FileOrStdin {
204    async fn read_to_string(&self) -> Result<String, std::io::Error> {
205        match self {
206            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
207            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
208        }
209    }
210}
211
212impl FromStr for FileOrStdin {
213    type Err = <PathBuf as FromStr>::Err;
214
215    fn from_str(s: &str) -> Result<Self, Self::Err> {
216        match s {
217            "-" => Ok(Self::Stdin),
218            _ => Ok(Self::File(PathBuf::from_str(s)?)),
219        }
220    }
221}
222
223struct LoadedContext {
224    full_path_str: String,
225    snapshot: BufferSnapshot,
226    clipped_cursor: Point,
227    worktree: Entity<Worktree>,
228    project: Entity<Project>,
229    buffer: Entity<Buffer>,
230}
231
232async fn load_context(
233    args: &ContextArgs,
234    app_state: &Arc<ZetaCliAppState>,
235    cx: &mut AsyncApp,
236) -> Result<LoadedContext> {
237    let ContextArgs {
238        worktree: worktree_path,
239        cursor,
240        use_language_server,
241        ..
242    } = args;
243
244    let worktree_path = worktree_path.canonicalize()?;
245
246    let project = cx.update(|cx| {
247        Project::local(
248            app_state.client.clone(),
249            app_state.node_runtime.clone(),
250            app_state.user_store.clone(),
251            app_state.languages.clone(),
252            app_state.fs.clone(),
253            None,
254            cx,
255        )
256    })?;
257
258    let worktree = project
259        .update(cx, |project, cx| {
260            project.create_worktree(&worktree_path, true, cx)
261        })?
262        .await?;
263
264    let mut ready_languages = HashSet::default();
265    let (_lsp_open_handle, buffer) = if *use_language_server {
266        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
267            project.clone(),
268            worktree.clone(),
269            cursor.path.clone(),
270            &mut ready_languages,
271            cx,
272        )
273        .await?;
274        (Some(lsp_open_handle), buffer)
275    } else {
276        let buffer =
277            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
278        (None, buffer)
279    };
280
281    let full_path_str = worktree
282        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
283        .display(PathStyle::local())
284        .to_string();
285
286    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
287    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
288    if clipped_cursor != cursor.point {
289        let max_row = snapshot.max_point().row;
290        if cursor.point.row < max_row {
291            return Err(anyhow!(
292                "Cursor position {:?} is out of bounds (line length is {})",
293                cursor.point,
294                snapshot.line_len(cursor.point.row)
295            ));
296        } else {
297            return Err(anyhow!(
298                "Cursor position {:?} is out of bounds (max row is {})",
299                cursor.point,
300                max_row
301            ));
302        }
303    }
304
305    Ok(LoadedContext {
306        full_path_str,
307        snapshot,
308        clipped_cursor,
309        worktree,
310        project,
311        buffer,
312    })
313}
314
315async fn zeta2_syntax_context(
316    zeta2_args: Zeta2Args,
317    syntax_args: Zeta2SyntaxArgs,
318    args: ContextArgs,
319    app_state: &Arc<ZetaCliAppState>,
320    cx: &mut AsyncApp,
321) -> Result<String> {
322    let LoadedContext {
323        worktree,
324        project,
325        buffer,
326        clipped_cursor,
327        ..
328    } = load_context(&args, app_state, cx).await?;
329
330    // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
331    // the whole worktree.
332    worktree
333        .read_with(cx, |worktree, _cx| {
334            worktree.as_local().unwrap().scan_complete()
335        })?
336        .await;
337    let output = cx
338        .update(|cx| {
339            let zeta = cx.new(|cx| {
340                zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
341            });
342            let indexing_done_task = zeta.update(cx, |zeta, cx| {
343                zeta.set_options(syntax_args_to_options(&zeta2_args, &syntax_args, true));
344                zeta.register_buffer(&buffer, &project, cx);
345                zeta.wait_for_initial_indexing(&project, cx)
346            });
347            cx.spawn(async move |cx| {
348                indexing_done_task.await?;
349                let request = zeta
350                    .update(cx, |zeta, cx| {
351                        let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
352                        zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
353                    })?
354                    .await?;
355
356                let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;
357
358                match zeta2_args.output_format {
359                    OutputFormat::Prompt => anyhow::Ok(prompt_string),
360                    OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
361                    OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
362                        "request": request,
363                        "prompt": prompt_string,
364                        "section_labels": section_labels,
365                    }))?),
366                }
367            })
368        })?
369        .await?;
370
371    Ok(output)
372}
373
374async fn zeta1_context(
375    args: ContextArgs,
376    app_state: &Arc<ZetaCliAppState>,
377    cx: &mut AsyncApp,
378) -> Result<zeta::GatherContextOutput> {
379    let LoadedContext {
380        full_path_str,
381        snapshot,
382        clipped_cursor,
383        ..
384    } = load_context(&args, app_state, cx).await?;
385
386    let events = match args.edit_history {
387        Some(events) => events.read_to_string().await?,
388        None => String::new(),
389    };
390
391    let prompt_for_events = move || (events, 0);
392    cx.update(|cx| {
393        zeta::gather_context(
394            full_path_str,
395            &snapshot,
396            clipped_cursor,
397            prompt_for_events,
398            cx,
399        )
400    })?
401    .await
402}
403
404fn main() {
405    zlog::init();
406    zlog::init_output_stderr();
407    let args = ZetaCliArgs::parse();
408    let http_client = Arc::new(ReqwestClient::new());
409    let app = Application::headless().with_http_client(http_client);
410
411    app.run(move |cx| {
412        let app_state = Arc::new(headless::init(cx));
413        cx.spawn(async move |cx| {
414            match args.command {
415                Command::Zeta1 {
416                    command: Zeta1Command::Context { context_args },
417                } => {
418                    let context = zeta1_context(context_args, &app_state, cx).await.unwrap();
419                    let result = serde_json::to_string_pretty(&context.body).unwrap();
420                    println!("{}", result);
421                }
422                Command::Zeta2 { command } => match command {
423                    Zeta2Command::Predict(arguments) => {
424                        run_zeta2_predict(arguments, &app_state, cx).await;
425                    }
426                    Zeta2Command::Eval(arguments) => {
427                        run_evaluate(arguments, &app_state, cx).await;
428                    }
429                    Zeta2Command::Syntax {
430                        args,
431                        syntax_args,
432                        command,
433                    } => {
434                        let result = match command {
435                            Zeta2SyntaxCommand::Context { context_args } => {
436                                zeta2_syntax_context(
437                                    args,
438                                    syntax_args,
439                                    context_args,
440                                    &app_state,
441                                    cx,
442                                )
443                                .await
444                            }
445                            Zeta2SyntaxCommand::Stats {
446                                worktree,
447                                extension,
448                                limit,
449                                skip,
450                            } => {
451                                retrieval_stats(
452                                    worktree,
453                                    app_state,
454                                    extension,
455                                    limit,
456                                    skip,
457                                    syntax_args_to_options(&args, &syntax_args, false),
458                                    cx,
459                                )
460                                .await
461                            }
462                        };
463                        println!("{}", result.unwrap());
464                    }
465                },
466                Command::ConvertExample {
467                    path,
468                    output_format,
469                } => {
470                    let example = NamedExample::load(path).unwrap();
471                    example.write(output_format, io::stdout()).unwrap();
472                }
473            };
474
475            let _ = cx.update(|cx| cx.quit());
476        })
477        .detach();
478    });
479}