1mod headless;
   2
   3use anyhow::{Context as _, Result, anyhow};
   4use clap::{Args, Parser, Subcommand};
   5use cloud_llm_client::predict_edits_v3::{self, DeclarationScoreComponents};
   6use edit_prediction_context::{
   7    Declaration, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
   8    EditPredictionExcerptOptions, EditPredictionScoreOptions, Identifier, Imports, Reference,
   9    ReferenceRegion, SyntaxIndex, SyntaxIndexState, references_in_range,
  10};
  11use futures::channel::mpsc;
  12use futures::{FutureExt as _, StreamExt as _};
  13use gpui::{AppContext, Application, AsyncApp};
  14use gpui::{Entity, Task};
  15use language::{Bias, BufferSnapshot, LanguageServerId, Point};
  16use language::{Buffer, OffsetRangeExt};
  17use language::{LanguageId, ParseStatus};
  18use language_model::LlmApiToken;
  19use ordered_float::OrderedFloat;
  20use project::{Project, ProjectEntryId, ProjectPath, Worktree};
  21use release_channel::AppVersion;
  22use reqwest_client::ReqwestClient;
  23use serde::{Deserialize, Deserializer, Serialize, Serializer};
  24use serde_json::json;
  25use std::cmp::Reverse;
  26use std::collections::{HashMap, HashSet};
  27use std::fmt::{self, Display};
  28use std::fs::File;
  29use std::hash::Hash;
  30use std::hash::Hasher;
  31use std::io::{BufRead, BufReader, BufWriter, Write as _};
  32use std::ops::Range;
  33use std::path::{Path, PathBuf};
  34use std::process::exit;
  35use std::str::FromStr;
  36use std::sync::atomic::AtomicUsize;
  37use std::sync::{Arc, atomic};
  38use std::time::Duration;
  39use util::paths::PathStyle;
  40use util::rel_path::RelPath;
  41use util::{RangeExt, ResultExt as _};
  42use zeta::{PerformPredictEditsParams, Zeta};
  43
  44use crate::headless::ZetaCliAppState;
  45
// Top-level command-line arguments for the `zeta` CLI: the only argument is the required
// subcommand. Plain `//` comments are used on purpose, because clap turns `///` doc comments
// into `--help` text.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    #[command(subcommand)]
    command: Commands,
}
  52
// Subcommands of the `zeta` CLI. `//` comments are used deliberately: clap turns `///` doc
// comments on variants and fields into help text. Dispatch to handlers is not visible in
// this section.
#[derive(Subcommand, Debug)]
enum Commands {
    // Context gathering driven only by the cursor/worktree arguments.
    Context(ContextArgs),
    // Context gathering with zeta2 options; both argument groups are flattened into one
    // command line.
    Zeta2Context {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    // Prediction request. `--predict-edits-body` accepts a path or `-` for stdin, and the
    // context arguments are optional as a group.
    Predict {
        #[arg(long)]
        predict_edits_body: Option<FileOrStdin>,
        #[clap(flatten)]
        context_args: Option<ContextArgs>,
    },
    // Retrieval statistics over a worktree. `--extension` restricts files to one extension;
    // `--skip` / `--limit` select a window of the (path-sorted) files.
    RetrievalStats {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[arg(long)]
        worktree: PathBuf,
        #[arg(long)]
        extension: Option<String>,
        #[arg(long)]
        limit: Option<usize>,
        #[arg(long)]
        skip: Option<usize>,
    },
}
  81
// Arguments locating a cursor inside a worktree (`//` comments so clap help text is
// unchanged). `#[group(requires = "worktree")]` makes `--worktree` mandatory whenever any
// argument of this group is given. This matters when the group is flattened as
// `Option<ContextArgs>`, as in `Commands::Predict`.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Directory opened as a local project worktree (canonicalized by `get_context`).
    #[arg(long)]
    worktree: PathBuf,
    // `path:line:column`, 1-based, with `path` relative to the worktree root.
    #[arg(long)]
    cursor: SourceLocation,
    // Open the cursor buffer through a language server instead of a plain buffer open.
    #[arg(long)]
    use_language_server: bool,
    // Optional events input (path or `-` for stdin); `get_context` reads it to a string and
    // passes it to zeta1 context gathering.
    #[arg(long)]
    events: Option<FileOrStdin>,
}
  94
// zeta2 tuning options. Most fields are copied into `zeta2::ZetaOptions` by
// `Zeta2Args::to_options`; `output_format` is only read by `get_context`.
// `//` comments are used because clap turns `///` into help text.
#[derive(Debug, Args)]
struct Zeta2Args {
    // Copied to `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Copied to `EditPredictionExcerptOptions::max_bytes`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // Copied to `EditPredictionExcerptOptions::min_bytes`.
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Copied to `EditPredictionExcerptOptions::target_before_cursor_over_total_bytes`;
    // presumably the share of excerpt bytes aimed before the cursor (confirm in
    // `edit_prediction_context`).
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Copied to `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // Converted into `predict_edits_v3::PromptFormat`.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // Selects what `get_context` returns in zeta2 mode.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // Copied to `ZetaOptions::file_indexing_parallelism`, which `retrieval_stats` passes to
    // `SyntaxIndex::new`.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
    // When set, `EditPredictionContextOptions::use_imports` is false.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
}
 116
// CLI choice of prompt layout, mapped one-to-one onto `predict_edits_v3::PromptFormat`.
// clap derives the accepted values from the variant names, so `//` comments are used to keep
// the generated help text unchanged.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    #[default]
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
}
 124
 125impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
 126    fn into(self) -> predict_edits_v3::PromptFormat {
 127        match self {
 128            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
 129            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
 130            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
 131        }
 132    }
 133}
 134
// What `get_context` returns in zeta2 mode (`//` comments keep clap help text unchanged):
// `Prompt` gives the rendered prompt string, `Request` gives the request as pretty JSON, and
// `Full` gives a pretty JSON object with `request`, `prompt` and `section_labels`.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    #[default]
    Prompt,
    Request,
    Full,
}
 142
/// An input source named on the command line: a file path, or `-` for standard input
/// (see the `FromStr` impl).
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read the input from this path.
    File(PathBuf),
    /// Read the input from the process's standard input.
    Stdin,
}
 148
 149impl FileOrStdin {
 150    async fn read_to_string(&self) -> Result<String, std::io::Error> {
 151        match self {
 152            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
 153            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
 154        }
 155    }
 156}
 157
 158impl FromStr for FileOrStdin {
 159    type Err = <PathBuf as FromStr>::Err;
 160
 161    fn from_str(s: &str) -> Result<Self, Self::Err> {
 162        match s {
 163            "-" => Ok(Self::Stdin),
 164            _ => Ok(Self::File(PathBuf::from_str(s)?)),
 165        }
 166    }
 167}
 168
/// A position inside a worktree file.
///
/// `path` is relative to the worktree root. `point` is 0-based: `FromStr` converts from the
/// 1-based `path:line:column` text form, and `Display` converts back. The type is also used
/// as a hash-map key for cached LSP definitions, hence the `Hash`/`Eq` derives.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct SourceLocation {
    path: Arc<RelPath>,
    point: Point,
}
 174
 175impl Serialize for SourceLocation {
 176    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 177    where
 178        S: Serializer,
 179    {
 180        serializer.serialize_str(&self.to_string())
 181    }
 182}
 183
 184impl<'de> Deserialize<'de> for SourceLocation {
 185    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 186    where
 187        D: Deserializer<'de>,
 188    {
 189        let s = String::deserialize(deserializer)?;
 190        s.parse().map_err(serde::de::Error::custom)
 191    }
 192}
 193
 194impl Display for SourceLocation {
 195    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 196        write!(
 197            f,
 198            "{}:{}:{}",
 199            self.path.display(PathStyle::Posix),
 200            self.point.row + 1,
 201            self.point.column + 1
 202        )
 203    }
 204}
 205
 206impl FromStr for SourceLocation {
 207    type Err = anyhow::Error;
 208
 209    fn from_str(s: &str) -> Result<Self> {
 210        let parts: Vec<&str> = s.split(':').collect();
 211        if parts.len() != 3 {
 212            return Err(anyhow!(
 213                "Invalid source location. Expected 'file.rs:line:column', got '{}'",
 214                s
 215            ));
 216        }
 217
 218        let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
 219        let line: u32 = parts[1]
 220            .parse()
 221            .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
 222        let column: u32 = parts[2]
 223            .parse()
 224            .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
 225
 226        // Convert from 1-based to 0-based indexing
 227        let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
 228
 229        Ok(SourceLocation { path, point })
 230    }
 231}
 232
/// Result of [`get_context`]: which variant is produced depends on whether zeta2 arguments
/// were supplied.
enum GetContextOutput {
    /// Output of `zeta::gather_context` (no zeta2 arguments).
    Zeta1(zeta::GatherContextOutput),
    /// zeta2 request/prompt rendered as text according to `Zeta2Args::output_format`.
    Zeta2(String),
}
 237
 238async fn get_context(
 239    zeta2_args: Option<Zeta2Args>,
 240    args: ContextArgs,
 241    app_state: &Arc<ZetaCliAppState>,
 242    cx: &mut AsyncApp,
 243) -> Result<GetContextOutput> {
 244    let ContextArgs {
 245        worktree: worktree_path,
 246        cursor,
 247        use_language_server,
 248        events,
 249    } = args;
 250
 251    let worktree_path = worktree_path.canonicalize()?;
 252
 253    let project = cx.update(|cx| {
 254        Project::local(
 255            app_state.client.clone(),
 256            app_state.node_runtime.clone(),
 257            app_state.user_store.clone(),
 258            app_state.languages.clone(),
 259            app_state.fs.clone(),
 260            None,
 261            cx,
 262        )
 263    })?;
 264
 265    let worktree = project
 266        .update(cx, |project, cx| {
 267            project.create_worktree(&worktree_path, true, cx)
 268        })?
 269        .await?;
 270
 271    let mut ready_languages = HashSet::default();
 272    let (_lsp_open_handle, buffer) = if use_language_server {
 273        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
 274            project.clone(),
 275            worktree.clone(),
 276            cursor.path.clone(),
 277            &mut ready_languages,
 278            cx,
 279        )
 280        .await?;
 281        (Some(lsp_open_handle), buffer)
 282    } else {
 283        let buffer =
 284            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
 285        (None, buffer)
 286    };
 287
 288    let full_path_str = worktree
 289        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
 290        .display(PathStyle::local())
 291        .to_string();
 292
 293    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
 294    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
 295    if clipped_cursor != cursor.point {
 296        let max_row = snapshot.max_point().row;
 297        if cursor.point.row < max_row {
 298            return Err(anyhow!(
 299                "Cursor position {:?} is out of bounds (line length is {})",
 300                cursor.point,
 301                snapshot.line_len(cursor.point.row)
 302            ));
 303        } else {
 304            return Err(anyhow!(
 305                "Cursor position {:?} is out of bounds (max row is {})",
 306                cursor.point,
 307                max_row
 308            ));
 309        }
 310    }
 311
 312    let events = match events {
 313        Some(events) => events.read_to_string().await?,
 314        None => String::new(),
 315    };
 316
 317    if let Some(zeta2_args) = zeta2_args {
 318        // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
 319        // the whole worktree.
 320        worktree
 321            .read_with(cx, |worktree, _cx| {
 322                worktree.as_local().unwrap().scan_complete()
 323            })?
 324            .await;
 325        let output = cx
 326            .update(|cx| {
 327                let zeta = cx.new(|cx| {
 328                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
 329                });
 330                let indexing_done_task = zeta.update(cx, |zeta, cx| {
 331                    zeta.set_options(zeta2_args.to_options(true));
 332                    zeta.register_buffer(&buffer, &project, cx);
 333                    zeta.wait_for_initial_indexing(&project, cx)
 334                });
 335                cx.spawn(async move |cx| {
 336                    indexing_done_task.await?;
 337                    let request = zeta
 338                        .update(cx, |zeta, cx| {
 339                            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
 340                            zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
 341                        })?
 342                        .await?;
 343
 344                    let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
 345                    let (prompt_string, section_labels) = planned_prompt.to_prompt_string()?;
 346
 347                    match zeta2_args.output_format {
 348                        OutputFormat::Prompt => anyhow::Ok(prompt_string),
 349                        OutputFormat::Request => {
 350                            anyhow::Ok(serde_json::to_string_pretty(&request)?)
 351                        }
 352                        OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
 353                            "request": request,
 354                            "prompt": prompt_string,
 355                            "section_labels": section_labels,
 356                        }))?),
 357                    }
 358                })
 359            })?
 360            .await?;
 361        Ok(GetContextOutput::Zeta2(output))
 362    } else {
 363        let prompt_for_events = move || (events, 0);
 364        Ok(GetContextOutput::Zeta1(
 365            cx.update(|cx| {
 366                zeta::gather_context(
 367                    full_path_str,
 368                    &snapshot,
 369                    clipped_cursor,
 370                    prompt_for_events,
 371                    cx,
 372                )
 373            })?
 374            .await?,
 375        ))
 376    }
 377}
 378
 379impl Zeta2Args {
 380    fn to_options(&self, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions {
 381        zeta2::ZetaOptions {
 382            context: EditPredictionContextOptions {
 383                use_imports: !self.disable_imports_gathering,
 384                excerpt: EditPredictionExcerptOptions {
 385                    max_bytes: self.max_excerpt_bytes,
 386                    min_bytes: self.min_excerpt_bytes,
 387                    target_before_cursor_over_total_bytes: self
 388                        .target_before_cursor_over_total_bytes,
 389                },
 390                score: EditPredictionScoreOptions {
 391                    omit_excerpt_overlaps,
 392                },
 393            },
 394            max_diagnostic_bytes: self.max_diagnostic_bytes,
 395            max_prompt_bytes: self.max_prompt_bytes,
 396            prompt_format: self.prompt_format.clone().into(),
 397            file_indexing_parallelism: self.file_indexing_parallelism,
 398        }
 399    }
 400}
 401
 402pub async fn retrieval_stats(
 403    worktree: PathBuf,
 404    app_state: Arc<ZetaCliAppState>,
 405    only_extension: Option<String>,
 406    file_limit: Option<usize>,
 407    skip_files: Option<usize>,
 408    options: zeta2::ZetaOptions,
 409    cx: &mut AsyncApp,
 410) -> Result<String> {
 411    let options = Arc::new(options);
 412    let worktree_path = worktree.canonicalize()?;
 413
 414    let project = cx.update(|cx| {
 415        Project::local(
 416            app_state.client.clone(),
 417            app_state.node_runtime.clone(),
 418            app_state.user_store.clone(),
 419            app_state.languages.clone(),
 420            app_state.fs.clone(),
 421            None,
 422            cx,
 423        )
 424    })?;
 425
 426    let worktree = project
 427        .update(cx, |project, cx| {
 428            project.create_worktree(&worktree_path, true, cx)
 429        })?
 430        .await?;
 431
 432    // wait for worktree scan so that wait_for_initial_file_indexing waits for the whole worktree.
 433    worktree
 434        .read_with(cx, |worktree, _cx| {
 435            worktree.as_local().unwrap().scan_complete()
 436        })?
 437        .await;
 438
 439    let index = cx.new(|cx| SyntaxIndex::new(&project, options.file_indexing_parallelism, cx))?;
 440    index
 441        .read_with(cx, |index, cx| index.wait_for_initial_file_indexing(cx))?
 442        .await?;
 443    let indexed_files = index
 444        .read_with(cx, |index, cx| index.indexed_file_paths(cx))?
 445        .await;
 446    let mut filtered_files = indexed_files
 447        .into_iter()
 448        .filter(|project_path| {
 449            let file_extension = project_path.path.extension();
 450            if let Some(only_extension) = only_extension.as_ref() {
 451                file_extension.is_some_and(|extension| extension == only_extension)
 452            } else {
 453                file_extension
 454                    .is_some_and(|extension| !["md", "json", "sh", "diff"].contains(&extension))
 455            }
 456        })
 457        .collect::<Vec<_>>();
 458    filtered_files.sort_by(|a, b| a.path.cmp(&b.path));
 459
 460    let index_state = index.read_with(cx, |index, _cx| index.state().clone())?;
 461    cx.update(|_| {
 462        drop(index);
 463    })?;
 464    let index_state = Arc::new(
 465        Arc::into_inner(index_state)
 466            .context("Index state had more than 1 reference")?
 467            .into_inner(),
 468    );
 469
 470    struct FileSnapshot {
 471        project_entry_id: ProjectEntryId,
 472        snapshot: BufferSnapshot,
 473        hash: u64,
 474        parent_abs_path: Arc<Path>,
 475    }
 476
 477    let files: Vec<FileSnapshot> = futures::future::try_join_all({
 478        filtered_files
 479            .iter()
 480            .map(|file| {
 481                let buffer_task =
 482                    open_buffer(project.clone(), worktree.clone(), file.path.clone(), cx);
 483                cx.spawn(async move |cx| {
 484                    let buffer = buffer_task.await?;
 485                    let (project_entry_id, parent_abs_path, snapshot) =
 486                        buffer.read_with(cx, |buffer, cx| {
 487                            let file = project::File::from_dyn(buffer.file()).unwrap();
 488                            let project_entry_id = file.project_entry_id().unwrap();
 489                            let mut parent_abs_path = file.worktree.read(cx).absolutize(&file.path);
 490                            if !parent_abs_path.pop() {
 491                                panic!("Invalid worktree path");
 492                            }
 493
 494                            (project_entry_id, parent_abs_path, buffer.snapshot())
 495                        })?;
 496
 497                    anyhow::Ok(
 498                        cx.background_spawn(async move {
 499                            let mut hasher = collections::FxHasher::default();
 500                            snapshot.text().hash(&mut hasher);
 501                            FileSnapshot {
 502                                project_entry_id,
 503                                snapshot,
 504                                hash: hasher.finish(),
 505                                parent_abs_path: parent_abs_path.into(),
 506                            }
 507                        })
 508                        .await,
 509                    )
 510                })
 511            })
 512            .collect::<Vec<_>>()
 513    })
 514    .await?;
 515
 516    let mut file_snapshots = HashMap::default();
 517    let mut hasher = collections::FxHasher::default();
 518    for FileSnapshot {
 519        project_entry_id,
 520        snapshot,
 521        hash,
 522        ..
 523    } in &files
 524    {
 525        file_snapshots.insert(*project_entry_id, snapshot.clone());
 526        hash.hash(&mut hasher);
 527    }
 528    let files_hash = hasher.finish();
 529    let file_snapshots = Arc::new(file_snapshots);
 530
 531    let lsp_definitions_path = std::env::current_dir()?.join(format!(
 532        "target/zeta2-lsp-definitions-{:x}.jsonl",
 533        files_hash
 534    ));
 535
 536    let mut lsp_definitions = HashMap::default();
 537    let mut lsp_files = 0;
 538
 539    if std::fs::exists(&lsp_definitions_path)? {
 540        log::info!(
 541            "Using cached LSP definitions from {}",
 542            lsp_definitions_path.display()
 543        );
 544
 545        let file = File::options()
 546            .read(true)
 547            .write(true)
 548            .open(&lsp_definitions_path)?;
 549        let lines = BufReader::new(&file).lines();
 550        let mut valid_len: usize = 0;
 551
 552        for (line, expected_file) in lines.zip(files.iter()) {
 553            let line = line?;
 554            let FileLspDefinitions { path, references } = match serde_json::from_str(&line) {
 555                Ok(ok) => ok,
 556                Err(_) => {
 557                    log::error!("Found invalid cache line. Truncating to #{lsp_files}.",);
 558                    file.set_len(valid_len as u64)?;
 559                    break;
 560                }
 561            };
 562            let expected_path = expected_file.snapshot.file().unwrap().path().as_unix_str();
 563            if expected_path != path.as_ref() {
 564                log::error!(
 565                    "Expected file #{} to be {expected_path}, but found {path}. Truncating to #{lsp_files}.",
 566                    lsp_files + 1
 567                );
 568                file.set_len(valid_len as u64)?;
 569                break;
 570            }
 571            for (point, ranges) in references {
 572                let Ok(path) = RelPath::new(Path::new(path.as_ref()), PathStyle::Posix) else {
 573                    log::warn!("Invalid path: {}", path);
 574                    continue;
 575                };
 576                lsp_definitions.insert(
 577                    SourceLocation {
 578                        path: path.into_arc(),
 579                        point: point.into(),
 580                    },
 581                    ranges,
 582                );
 583            }
 584            lsp_files += 1;
 585            valid_len += line.len() + 1
 586        }
 587    }
 588
 589    if lsp_files < files.len() {
 590        if lsp_files == 0 {
 591            log::warn!(
 592                "No LSP definitions found, populating {}",
 593                lsp_definitions_path.display()
 594            );
 595        } else {
 596            log::warn!("{} files missing from LSP cache", files.len() - lsp_files);
 597        }
 598
 599        gather_lsp_definitions(
 600            &lsp_definitions_path,
 601            lsp_files,
 602            &filtered_files,
 603            &worktree,
 604            &project,
 605            &mut lsp_definitions,
 606            cx,
 607        )
 608        .await?;
 609    }
 610    let files_len = files.len().min(file_limit.unwrap_or(usize::MAX));
 611    let done_count = Arc::new(AtomicUsize::new(0));
 612
 613    let (output_tx, mut output_rx) = mpsc::unbounded::<RetrievalStatsResult>();
 614    let mut output = std::fs::File::create("target/zeta-retrieval-stats.txt")?;
 615
 616    let tasks = files
 617        .into_iter()
 618        .skip(skip_files.unwrap_or(0))
 619        .take(file_limit.unwrap_or(usize::MAX))
 620        .map(|project_file| {
 621            let index_state = index_state.clone();
 622            let lsp_definitions = lsp_definitions.clone();
 623            let options = options.clone();
 624            let output_tx = output_tx.clone();
 625            let done_count = done_count.clone();
 626            let file_snapshots = file_snapshots.clone();
 627            cx.background_spawn(async move {
 628                let snapshot = project_file.snapshot;
 629
 630                let full_range = 0..snapshot.len();
 631                let references = references_in_range(
 632                    full_range,
 633                    &snapshot.text(),
 634                    ReferenceRegion::Nearby,
 635                    &snapshot,
 636                );
 637
 638                println!("references: {}", references.len(),);
 639
 640                let imports = if options.context.use_imports {
 641                    Imports::gather(&snapshot, Some(&project_file.parent_abs_path))
 642                } else {
 643                    Imports::default()
 644                };
 645
 646                let path = snapshot.file().unwrap().path();
 647
 648                for reference in references {
 649                    let query_point = snapshot.offset_to_point(reference.range.start);
 650                    let source_location = SourceLocation {
 651                        path: path.clone(),
 652                        point: query_point,
 653                    };
 654                    let lsp_definitions = lsp_definitions
 655                        .get(&source_location)
 656                        .cloned()
 657                        .unwrap_or_else(|| {
 658                            log::warn!(
 659                                "No definitions found for source location: {:?}",
 660                                source_location
 661                            );
 662                            Vec::new()
 663                        });
 664
 665                    let retrieve_result = retrieve_definitions(
 666                        &reference,
 667                        &imports,
 668                        query_point,
 669                        &snapshot,
 670                        &index_state,
 671                        &file_snapshots,
 672                        &options,
 673                    )
 674                    .await?;
 675
 676                    // TODO: LSP returns things like locals, this filters out some of those, but potentially
 677                    // hides some retrieval issues.
 678                    if retrieve_result.definitions.is_empty() {
 679                        continue;
 680                    }
 681
 682                    let mut best_match = None;
 683                    let mut has_external_definition = false;
 684                    let mut in_excerpt = false;
 685                    for (index, retrieved_definition) in
 686                        retrieve_result.definitions.iter().enumerate()
 687                    {
 688                        for lsp_definition in &lsp_definitions {
 689                            let SourceRange {
 690                                path,
 691                                point_range,
 692                                offset_range,
 693                            } = lsp_definition;
 694                            let lsp_point_range =
 695                                SerializablePoint::into_language_point_range(point_range.clone());
 696                            has_external_definition = has_external_definition
 697                                || path.is_absolute()
 698                                || path
 699                                    .components()
 700                                    .any(|component| component.as_os_str() == "node_modules");
 701                            let is_match = path.as_path()
 702                                == retrieved_definition.path.as_std_path()
 703                                && retrieved_definition
 704                                    .range
 705                                    .contains_inclusive(&lsp_point_range);
 706                            if is_match {
 707                                if best_match.is_none() {
 708                                    best_match = Some(index);
 709                                }
 710                            }
 711                            in_excerpt = in_excerpt
 712                                || retrieve_result.excerpt_range.as_ref().is_some_and(
 713                                    |excerpt_range| excerpt_range.contains_inclusive(&offset_range),
 714                                );
 715                        }
 716                    }
 717
 718                    let outcome = if let Some(best_match) = best_match {
 719                        RetrievalOutcome::Match { best_match }
 720                    } else if has_external_definition {
 721                        RetrievalOutcome::NoMatchDueToExternalLspDefinitions
 722                    } else if in_excerpt {
 723                        RetrievalOutcome::ProbablyLocal
 724                    } else {
 725                        RetrievalOutcome::NoMatch
 726                    };
 727
 728                    let result = RetrievalStatsResult {
 729                        outcome,
 730                        path: path.clone(),
 731                        identifier: reference.identifier,
 732                        point: query_point,
 733                        lsp_definitions,
 734                        retrieved_definitions: retrieve_result.definitions,
 735                    };
 736
 737                    output_tx.unbounded_send(result).ok();
 738                }
 739
 740                println!(
 741                    "{:02}/{:02} done",
 742                    done_count.fetch_add(1, atomic::Ordering::Relaxed) + 1,
 743                    files_len,
 744                );
 745
 746                anyhow::Ok(())
 747            })
 748        })
 749        .collect::<Vec<_>>();
 750
 751    drop(output_tx);
 752
 753    let results_task = cx.background_spawn(async move {
 754        let mut results = Vec::new();
 755        while let Some(result) = output_rx.next().await {
 756            output
 757                .write_all(format!("{:#?}\n", result).as_bytes())
 758                .log_err();
 759            results.push(result)
 760        }
 761        results
 762    });
 763
 764    futures::future::try_join_all(tasks).await?;
 765    println!("Tasks completed");
 766    let results = results_task.await;
 767    println!("Results received");
 768
 769    let mut references_count = 0;
 770
 771    let mut included_count = 0;
 772    let mut both_absent_count = 0;
 773
 774    let mut retrieved_count = 0;
 775    let mut top_match_count = 0;
 776    let mut non_top_match_count = 0;
 777    let mut ranking_involved_top_match_count = 0;
 778
 779    let mut no_match_count = 0;
 780    let mut no_match_none_retrieved = 0;
 781    let mut no_match_wrong_retrieval = 0;
 782
 783    let mut expected_no_match_count = 0;
 784    let mut in_excerpt_count = 0;
 785    let mut external_definition_count = 0;
 786
 787    for result in results {
 788        references_count += 1;
 789        match &result.outcome {
 790            RetrievalOutcome::Match { best_match } => {
 791                included_count += 1;
 792                retrieved_count += 1;
 793                let multiple = result.retrieved_definitions.len() > 1;
 794                if *best_match == 0 {
 795                    top_match_count += 1;
 796                    if multiple {
 797                        ranking_involved_top_match_count += 1;
 798                    }
 799                } else {
 800                    non_top_match_count += 1;
 801                }
 802            }
 803            RetrievalOutcome::NoMatch => {
 804                if result.lsp_definitions.is_empty() {
 805                    included_count += 1;
 806                    both_absent_count += 1;
 807                } else {
 808                    no_match_count += 1;
 809                    if result.retrieved_definitions.is_empty() {
 810                        no_match_none_retrieved += 1;
 811                    } else {
 812                        no_match_wrong_retrieval += 1;
 813                    }
 814                }
 815            }
 816            RetrievalOutcome::NoMatchDueToExternalLspDefinitions => {
 817                expected_no_match_count += 1;
 818                external_definition_count += 1;
 819            }
 820            RetrievalOutcome::ProbablyLocal => {
 821                included_count += 1;
 822                in_excerpt_count += 1;
 823            }
 824        }
 825    }
 826
 827    fn count_and_percentage(part: usize, total: usize) -> String {
 828        format!("{} ({:.2}%)", part, (part as f64 / total as f64) * 100.0)
 829    }
 830
 831    println!("");
 832    println!("╮ references: {}", references_count);
 833    println!(
 834        "├─╮ included: {}",
 835        count_and_percentage(included_count, references_count),
 836    );
 837    println!(
 838        "│ ├─╮ retrieved: {}",
 839        count_and_percentage(retrieved_count, references_count)
 840    );
 841    println!(
 842        "│ │ ├─╮ top match : {}",
 843        count_and_percentage(top_match_count, retrieved_count)
 844    );
 845    println!(
 846        "│ │ │ ╰─╴ involving ranking: {}",
 847        count_and_percentage(ranking_involved_top_match_count, top_match_count)
 848    );
 849    println!(
 850        "│ │ ╰─╴ non-top match: {}",
 851        count_and_percentage(non_top_match_count, retrieved_count)
 852    );
 853    println!(
 854        "│ ├─╴ both absent: {}",
 855        count_and_percentage(both_absent_count, included_count)
 856    );
 857    println!(
 858        "│ ╰─╴ in excerpt: {}",
 859        count_and_percentage(in_excerpt_count, included_count)
 860    );
 861    println!(
 862        "├─╮ no match: {}",
 863        count_and_percentage(no_match_count, references_count)
 864    );
 865    println!(
 866        "│ ├─╴ none retrieved: {}",
 867        count_and_percentage(no_match_none_retrieved, no_match_count)
 868    );
 869    println!(
 870        "│ ╰─╴ wrong retrieval: {}",
 871        count_and_percentage(no_match_wrong_retrieval, no_match_count)
 872    );
 873    println!(
 874        "╰─╮ expected no match: {}",
 875        count_and_percentage(expected_no_match_count, references_count)
 876    );
 877    println!(
 878        "  ╰─╴ external definition: {}",
 879        count_and_percentage(external_definition_count, expected_no_match_count)
 880    );
 881
 882    println!("");
 883    println!("LSP definition cache at {}", lsp_definitions_path.display());
 884
 885    Ok("".to_string())
 886}
 887
/// Outcome of `retrieve_definitions` for a single reference.
struct RetrieveResult {
    /// Retrieved definitions, sorted by descending declaration score.
    definitions: Vec<RetrievedDefinition>,
    /// Offset range of the context excerpt that was gathered, or `None` when no context could
    /// be gathered for the query point.
    excerpt_range: Option<Range<usize>>,
}
 892
 893async fn retrieve_definitions(
 894    reference: &Reference,
 895    imports: &Imports,
 896    query_point: Point,
 897    snapshot: &BufferSnapshot,
 898    index: &Arc<SyntaxIndexState>,
 899    file_snapshots: &Arc<HashMap<ProjectEntryId, BufferSnapshot>>,
 900    options: &Arc<zeta2::ZetaOptions>,
 901) -> Result<RetrieveResult> {
 902    let mut single_reference_map = HashMap::default();
 903    single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]);
 904    let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn(
 905        query_point,
 906        snapshot,
 907        imports,
 908        &options.context,
 909        Some(&index),
 910        |_, _, _| single_reference_map,
 911    );
 912
 913    let Some(edit_prediction_context) = edit_prediction_context else {
 914        return Ok(RetrieveResult {
 915            definitions: Vec::new(),
 916            excerpt_range: None,
 917        });
 918    };
 919
 920    let mut retrieved_definitions = Vec::new();
 921    for scored_declaration in edit_prediction_context.declarations {
 922        match &scored_declaration.declaration {
 923            Declaration::File {
 924                project_entry_id,
 925                declaration,
 926                ..
 927            } => {
 928                let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
 929                    log::error!("bug: file project entry not found");
 930                    continue;
 931                };
 932                let path = snapshot.file().unwrap().path().clone();
 933                retrieved_definitions.push(RetrievedDefinition {
 934                    path,
 935                    range: snapshot.offset_to_point(declaration.item_range.start)
 936                        ..snapshot.offset_to_point(declaration.item_range.end),
 937                    score: scored_declaration.score(DeclarationStyle::Declaration),
 938                    retrieval_score: scored_declaration.retrieval_score(),
 939                    components: scored_declaration.components,
 940                });
 941            }
 942            Declaration::Buffer {
 943                project_entry_id,
 944                rope,
 945                declaration,
 946                ..
 947            } => {
 948                let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
 949                    // This case happens when dependency buffers have been opened by
 950                    // go-to-definition, resulting in single-file worktrees.
 951                    continue;
 952                };
 953                let path = snapshot.file().unwrap().path().clone();
 954                retrieved_definitions.push(RetrievedDefinition {
 955                    path,
 956                    range: rope.offset_to_point(declaration.item_range.start)
 957                        ..rope.offset_to_point(declaration.item_range.end),
 958                    score: scored_declaration.score(DeclarationStyle::Declaration),
 959                    retrieval_score: scored_declaration.retrieval_score(),
 960                    components: scored_declaration.components,
 961                });
 962            }
 963        }
 964    }
 965    retrieved_definitions.sort_by_key(|definition| Reverse(OrderedFloat(definition.score)));
 966
 967    Ok(RetrieveResult {
 968        definitions: retrieved_definitions,
 969        excerpt_range: Some(edit_prediction_context.excerpt.range),
 970    })
 971}
 972
 973async fn gather_lsp_definitions(
 974    lsp_definitions_path: &Path,
 975    start_index: usize,
 976    files: &[ProjectPath],
 977    worktree: &Entity<Worktree>,
 978    project: &Entity<Project>,
 979    definitions: &mut HashMap<SourceLocation, Vec<SourceRange>>,
 980    cx: &mut AsyncApp,
 981) -> Result<()> {
 982    let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?;
 983
 984    let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
 985    cx.subscribe(&lsp_store, {
 986        move |_, event, _| {
 987            if let project::LspStoreEvent::LanguageServerUpdate {
 988                message:
 989                    client::proto::update_language_server::Variant::WorkProgress(
 990                        client::proto::LspWorkProgress {
 991                            message: Some(message),
 992                            ..
 993                        },
 994                    ),
 995                ..
 996            } = event
 997            {
 998                println!("⟲ {message}")
 999            }
1000        }
1001    })?
1002    .detach();
1003
1004    let (cache_line_tx, mut cache_line_rx) = mpsc::unbounded::<FileLspDefinitions>();
1005
1006    let cache_file = File::options()
1007        .append(true)
1008        .create(true)
1009        .open(lsp_definitions_path)
1010        .unwrap();
1011
1012    let cache_task = cx.background_spawn(async move {
1013        let mut writer = BufWriter::new(cache_file);
1014        while let Some(line) = cache_line_rx.next().await {
1015            serde_json::to_writer(&mut writer, &line).unwrap();
1016            writer.write_all(&[b'\n']).unwrap();
1017        }
1018        writer.flush().unwrap();
1019    });
1020
1021    let mut error_count = 0;
1022    let mut lsp_open_handles = Vec::new();
1023    let mut ready_languages = HashSet::default();
1024    for (file_index, project_path) in files[start_index..].iter().enumerate() {
1025        println!(
1026            "Processing file {} of {}: {}",
1027            start_index + file_index + 1,
1028            files.len(),
1029            project_path.path.display(PathStyle::Posix)
1030        );
1031
1032        let Some((lsp_open_handle, language_server_id, buffer)) = open_buffer_with_language_server(
1033            project.clone(),
1034            worktree.clone(),
1035            project_path.path.clone(),
1036            &mut ready_languages,
1037            cx,
1038        )
1039        .await
1040        .log_err() else {
1041            continue;
1042        };
1043        lsp_open_handles.push(lsp_open_handle);
1044
1045        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
1046        let full_range = 0..snapshot.len();
1047        let references = references_in_range(
1048            full_range,
1049            &snapshot.text(),
1050            ReferenceRegion::Nearby,
1051            &snapshot,
1052        );
1053
1054        loop {
1055            let is_ready = lsp_store
1056                .read_with(cx, |lsp_store, _cx| {
1057                    lsp_store
1058                        .language_server_statuses
1059                        .get(&language_server_id)
1060                        .is_some_and(|status| status.pending_work.is_empty())
1061                })
1062                .unwrap();
1063            if is_ready {
1064                break;
1065            }
1066            cx.background_executor()
1067                .timer(Duration::from_millis(10))
1068                .await;
1069        }
1070
1071        let mut cache_line_references = Vec::with_capacity(references.len());
1072
1073        for reference in references {
1074            // TODO: Rename declaration to definition in edit_prediction_context?
1075            let lsp_result = project
1076                .update(cx, |project, cx| {
1077                    project.definitions(&buffer, reference.range.start, cx)
1078                })?
1079                .await;
1080
1081            match lsp_result {
1082                Ok(lsp_definitions) => {
1083                    let mut targets = Vec::new();
1084                    for target in lsp_definitions.unwrap_or_default() {
1085                        let buffer = target.target.buffer;
1086                        let anchor_range = target.target.range;
1087                        buffer.read_with(cx, |buffer, cx| {
1088                            let Some(file) = project::File::from_dyn(buffer.file()) else {
1089                                return;
1090                            };
1091                            let file_worktree = file.worktree.read(cx);
1092                            let file_worktree_id = file_worktree.id();
1093                            // Relative paths for worktree files, absolute for all others
1094                            let path = if worktree_id != file_worktree_id {
1095                                file.worktree.read(cx).absolutize(&file.path)
1096                            } else {
1097                                file.path.as_std_path().to_path_buf()
1098                            };
1099                            let offset_range = anchor_range.to_offset(&buffer);
1100                            let point_range = SerializablePoint::from_language_point_range(
1101                                offset_range.to_point(&buffer),
1102                            );
1103                            targets.push(SourceRange {
1104                                path,
1105                                offset_range,
1106                                point_range,
1107                            });
1108                        })?;
1109                    }
1110
1111                    let point = snapshot.offset_to_point(reference.range.start);
1112
1113                    cache_line_references.push((point.into(), targets.clone()));
1114                    definitions.insert(
1115                        SourceLocation {
1116                            path: project_path.path.clone(),
1117                            point,
1118                        },
1119                        targets,
1120                    );
1121                }
1122                Err(err) => {
1123                    log::error!("Language server error: {err}");
1124                    error_count += 1;
1125                }
1126            }
1127        }
1128
1129        cache_line_tx
1130            .unbounded_send(FileLspDefinitions {
1131                path: project_path.path.as_unix_str().into(),
1132                references: cache_line_references,
1133            })
1134            .log_err();
1135    }
1136
1137    drop(cache_line_tx);
1138
1139    if error_count > 0 {
1140        log::error!("Encountered {} language server errors", error_count);
1141    }
1142
1143    cache_task.await;
1144
1145    Ok(())
1146}
1147
/// One line of the LSP definition cache file: the references found in a single file, each
/// paired with the definition targets the language server returned for it.
#[derive(Serialize, Deserialize)]
struct FileLspDefinitions {
    /// Worktree-relative path of the file, in Unix (`/`-separated) form.
    path: Arc<str>,
    /// Start position of each reference (1-based, see `SerializablePoint`) with its targets.
    references: Vec<(SerializablePoint, Vec<SourceRange>)>,
}
1153
/// A definition target reported by the language server.
// Field order determines the key order in the serialized JSON cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SourceRange {
    /// Worktree-relative path for targets in the queried worktree; absolute path otherwise.
    path: PathBuf,
    /// Target range as 1-based row/column positions.
    point_range: Range<SerializablePoint>,
    /// Target range as offsets into the target buffer.
    offset_range: Range<usize>,
}
1160
/// A row/column position holding 1-based indices, as written to the LSP definition cache.
///
/// The struct itself stores whatever it is given. The shift from and to zero-based
/// [`Point`]s happens in the `From` conversions between the two types, which add or subtract 1
/// from each index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializablePoint {
    /// 1-based row index.
    pub row: u32,
    /// 1-based column index.
    pub column: u32,
}
1167
1168impl SerializablePoint {
1169    pub fn into_language_point_range(range: Range<Self>) -> Range<Point> {
1170        range.start.into()..range.end.into()
1171    }
1172
1173    pub fn from_language_point_range(range: Range<Point>) -> Range<Self> {
1174        range.start.into()..range.end.into()
1175    }
1176}
1177
1178impl From<Point> for SerializablePoint {
1179    fn from(point: Point) -> Self {
1180        SerializablePoint {
1181            row: point.row + 1,
1182            column: point.column + 1,
1183        }
1184    }
1185}
1186
1187impl From<SerializablePoint> for Point {
1188    fn from(serializable: SerializablePoint) -> Self {
1189        Point {
1190            row: serializable.row.saturating_sub(1),
1191            column: serializable.column.saturating_sub(1),
1192        }
1193    }
1194}
1195
/// Result of evaluating retrieval for a single reference, used when tallying retrieval stats.
#[derive(Debug)]
struct RetrievalStatsResult {
    /// How the retrieved definitions compare with the LSP definitions.
    outcome: RetrievalOutcome,
    // NOTE(review): the fields below are marked `allow(dead_code)`; `lsp_definitions` appears to
    // be read when tallying stats, so its allow may be stale. Verify and remove if so.
    /// Path of the file containing the reference (presumably; set where the result is built).
    #[allow(dead_code)]
    path: Arc<RelPath>,
    /// Identifier of the reference.
    #[allow(dead_code)]
    identifier: Identifier,
    /// Position of the reference (presumably its start; confirm at construction site).
    #[allow(dead_code)]
    point: Point,
    /// Definition targets the language server returned for the reference.
    #[allow(dead_code)]
    lsp_definitions: Vec<SourceRange>,
    /// Definitions produced by retrieval for the reference.
    retrieved_definitions: Vec<RetrievedDefinition>,
}
1209
/// Classification of a single reference's retrieval result against its LSP definitions.
#[derive(Debug)]
enum RetrievalOutcome {
    /// At least one retrieved definition matches an LSP definition.
    Match {
        /// Lowest index within retrieved_definitions that matches an LSP definition.
        best_match: usize,
    },
    /// Counted as "in excerpt" in the stats; presumably the definition lies inside the
    /// context excerpt itself.
    ProbablyLocal,
    /// No retrieved definition matches. If the LSP also returned no definitions this is
    /// counted as "both absent", otherwise as a miss.
    NoMatch,
    /// No match, attributed to the LSP definitions being external (presumably outside the
    /// worktree); counted as "expected no match".
    NoMatchDueToExternalLspDefinitions,
}
1220
/// A definition produced by context retrieval, expressed as a path plus point range so it can
/// be compared against LSP definitions.
#[derive(Debug)]
struct RetrievedDefinition {
    /// Path of the file containing the definition, as reported by the snapshot's file.
    path: Arc<RelPath>,
    /// Range of the declaration's item, as zero-based [`Point`]s.
    range: Range<Point>,
    /// Declaration score (`DeclarationStyle::Declaration`); definitions are ranked by it,
    /// highest first.
    score: f32,
    /// Retrieval score of the scored declaration (kept for debugging output).
    #[allow(dead_code)]
    retrieval_score: f32,
    /// Score components of the scored declaration (kept for debugging output).
    #[allow(dead_code)]
    components: DeclarationScoreComponents,
}
1231
1232pub fn open_buffer(
1233    project: Entity<Project>,
1234    worktree: Entity<Worktree>,
1235    path: Arc<RelPath>,
1236    cx: &AsyncApp,
1237) -> Task<Result<Entity<Buffer>>> {
1238    cx.spawn(async move |cx| {
1239        let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
1240            worktree_id: worktree.id(),
1241            path,
1242        })?;
1243
1244        let buffer = project
1245            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
1246            .await?;
1247
1248        let mut parse_status = buffer.read_with(cx, |buffer, _cx| buffer.parse_status())?;
1249        while *parse_status.borrow() != ParseStatus::Idle {
1250            parse_status.changed().await?;
1251        }
1252
1253        Ok(buffer)
1254    })
1255}
1256
1257pub async fn open_buffer_with_language_server(
1258    project: Entity<Project>,
1259    worktree: Entity<Worktree>,
1260    path: Arc<RelPath>,
1261    ready_languages: &mut HashSet<LanguageId>,
1262    cx: &mut AsyncApp,
1263) -> Result<(Entity<Entity<Buffer>>, LanguageServerId, Entity<Buffer>)> {
1264    let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;
1265
1266    let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
1267        (
1268            project.register_buffer_with_language_servers(&buffer, cx),
1269            project.path_style(cx),
1270        )
1271    })?;
1272
1273    let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
1274        buffer.language().map(|language| language.id())
1275    })?
1276    else {
1277        return Err(anyhow!("No language for {}", path.display(path_style)));
1278    };
1279
1280    let log_prefix = path.display(path_style);
1281    if !ready_languages.contains(&language_id) {
1282        wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
1283        ready_languages.insert(language_id);
1284    }
1285
1286    let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
1287
1288    // hacky wait for buffer to be registered with the language server
1289    for _ in 0..100 {
1290        let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
1291            buffer.update(cx, |buffer, cx| {
1292                lsp_store
1293                    .language_servers_for_local_buffer(&buffer, cx)
1294                    .next()
1295                    .map(|(_, language_server)| language_server.server_id())
1296            })
1297        })?
1298        else {
1299            cx.background_executor()
1300                .timer(Duration::from_millis(10))
1301                .await;
1302            continue;
1303        };
1304
1305        return Ok((lsp_open_handle, language_server_id, buffer));
1306    }
1307
1308    return Err(anyhow!("No language server found for buffer"));
1309}
1310
1311// TODO: Dedupe with similar function in crates/eval/src/instance.rs
1312pub fn wait_for_lang_server(
1313    project: &Entity<Project>,
1314    buffer: &Entity<Buffer>,
1315    log_prefix: String,
1316    cx: &mut AsyncApp,
1317) -> Task<Result<()>> {
1318    println!("{}⏵ Waiting for language server", log_prefix);
1319
1320    let (mut tx, mut rx) = mpsc::channel(1);
1321
1322    let lsp_store = project
1323        .read_with(cx, |project, _| project.lsp_store())
1324        .unwrap();
1325
1326    let has_lang_server = buffer
1327        .update(cx, |buffer, cx| {
1328            lsp_store.update(cx, |lsp_store, cx| {
1329                lsp_store
1330                    .language_servers_for_local_buffer(buffer, cx)
1331                    .next()
1332                    .is_some()
1333            })
1334        })
1335        .unwrap_or(false);
1336
1337    if has_lang_server {
1338        project
1339            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
1340            .unwrap()
1341            .detach();
1342    }
1343    let (mut added_tx, mut added_rx) = mpsc::channel(1);
1344
1345    let subscriptions = [
1346        cx.subscribe(&lsp_store, {
1347            let log_prefix = log_prefix.clone();
1348            move |_, event, _| {
1349                if let project::LspStoreEvent::LanguageServerUpdate {
1350                    message:
1351                        client::proto::update_language_server::Variant::WorkProgress(
1352                            client::proto::LspWorkProgress {
1353                                message: Some(message),
1354                                ..
1355                            },
1356                        ),
1357                    ..
1358                } = event
1359                {
1360                    println!("{}⟲ {message}", log_prefix)
1361                }
1362            }
1363        }),
1364        cx.subscribe(project, {
1365            let buffer = buffer.clone();
1366            move |project, event, cx| match event {
1367                project::Event::LanguageServerAdded(_, _, _) => {
1368                    let buffer = buffer.clone();
1369                    project
1370                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
1371                        .detach();
1372                    added_tx.try_send(()).ok();
1373                }
1374                project::Event::DiskBasedDiagnosticsFinished { .. } => {
1375                    tx.try_send(()).ok();
1376                }
1377                _ => {}
1378            }
1379        }),
1380    ];
1381
1382    cx.spawn(async move |cx| {
1383        if !has_lang_server {
1384            // some buffers never have a language server, so this aborts quickly in that case.
1385            let timeout = cx.background_executor().timer(Duration::from_secs(5));
1386            futures::select! {
1387                _ = added_rx.next() => {},
1388                _ = timeout.fuse() => {
1389                    anyhow::bail!("Waiting for language server add timed out after 5 seconds");
1390                }
1391            };
1392        }
1393        let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
1394        let result = futures::select! {
1395            _ = rx.next() => {
1396                println!("{}⚑ Language server idle", log_prefix);
1397                anyhow::Ok(())
1398            },
1399            _ = timeout.fuse() => {
1400                anyhow::bail!("LSP wait timed out after 5 minutes");
1401            }
1402        };
1403        drop(subscriptions);
1404        result
1405    })
1406}
1407
1408fn main() {
1409    zlog::init();
1410    zlog::init_output_stderr();
1411    let args = ZetaCliArgs::parse();
1412    let http_client = Arc::new(ReqwestClient::new());
1413    let app = Application::headless().with_http_client(http_client);
1414
1415    app.run(move |cx| {
1416        let app_state = Arc::new(headless::init(cx));
1417        cx.spawn(async move |cx| {
1418            let result = match args.command {
1419                Commands::Zeta2Context {
1420                    zeta2_args,
1421                    context_args,
1422                } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
1423                    Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
1424                    Ok(GetContextOutput::Zeta2(output)) => Ok(output),
1425                    Err(err) => Err(err),
1426                },
1427                Commands::Context(context_args) => {
1428                    match get_context(None, context_args, &app_state, cx).await {
1429                        Ok(GetContextOutput::Zeta1(output)) => {
1430                            Ok(serde_json::to_string_pretty(&output.body).unwrap())
1431                        }
1432                        Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
1433                        Err(err) => Err(err),
1434                    }
1435                }
1436                Commands::Predict {
1437                    predict_edits_body,
1438                    context_args,
1439                } => {
1440                    cx.spawn(async move |cx| {
1441                        let app_version = cx.update(|cx| AppVersion::global(cx))?;
1442                        app_state.client.sign_in(true, cx).await?;
1443                        let llm_token = LlmApiToken::default();
1444                        llm_token.refresh(&app_state.client).await?;
1445
1446                        let predict_edits_body =
1447                            if let Some(predict_edits_body) = predict_edits_body {
1448                                serde_json::from_str(&predict_edits_body.read_to_string().await?)?
1449                            } else if let Some(context_args) = context_args {
1450                                match get_context(None, context_args, &app_state, cx).await? {
1451                                    GetContextOutput::Zeta1(output) => output.body,
1452                                    GetContextOutput::Zeta2 { .. } => unreachable!(),
1453                                }
1454                            } else {
1455                                return Err(anyhow!(
1456                                    "Expected either --predict-edits-body-file \
1457                                    or the required args of the `context` command."
1458                                ));
1459                            };
1460
1461                        let (response, _usage) =
1462                            Zeta::perform_predict_edits(PerformPredictEditsParams {
1463                                client: app_state.client.clone(),
1464                                llm_token,
1465                                app_version,
1466                                body: predict_edits_body,
1467                            })
1468                            .await?;
1469
1470                        Ok(response.output_excerpt)
1471                    })
1472                    .await
1473                }
1474                Commands::RetrievalStats {
1475                    zeta2_args,
1476                    worktree,
1477                    extension,
1478                    limit,
1479                    skip,
1480                } => {
1481                    retrieval_stats(
1482                        worktree,
1483                        app_state,
1484                        extension,
1485                        limit,
1486                        skip,
1487                        (&zeta2_args).to_options(false),
1488                        cx,
1489                    )
1490                    .await
1491                }
1492            };
1493            match result {
1494                Ok(output) => {
1495                    println!("{}", output);
1496                    let _ = cx.update(|cx| cx.quit());
1497                }
1498                Err(e) => {
1499                    eprintln!("Failed: {:?}", e);
1500                    exit(1);
1501                }
1502            }
1503        })
1504        .detach();
1505    });
1506}