main.rs

   1mod headless;
   2
   3use anyhow::{Context as _, Result, anyhow};
   4use clap::{Args, Parser, Subcommand};
   5use cloud_llm_client::predict_edits_v3::{self, DeclarationScoreComponents};
   6use edit_prediction_context::{
   7    Declaration, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
   8    EditPredictionExcerptOptions, EditPredictionScoreOptions, Identifier, Imports, Reference,
   9    ReferenceRegion, SyntaxIndex, SyntaxIndexState, references_in_range,
  10};
  11use futures::channel::mpsc;
  12use futures::{FutureExt as _, StreamExt as _};
  13use gpui::{AppContext, Application, AsyncApp};
  14use gpui::{Entity, Task};
  15use language::{Bias, BufferSnapshot, LanguageServerId, Point};
  16use language::{Buffer, OffsetRangeExt};
  17use language::{LanguageId, ParseStatus};
  18use language_model::LlmApiToken;
  19use ordered_float::OrderedFloat;
  20use project::{Project, ProjectEntryId, ProjectPath, Worktree};
  21use release_channel::AppVersion;
  22use reqwest_client::ReqwestClient;
  23use serde::{Deserialize, Deserializer, Serialize, Serializer};
  24use serde_json::json;
  25use std::cmp::Reverse;
  26use std::collections::{HashMap, HashSet};
  27use std::fmt::{self, Display};
  28use std::fs::File;
  29use std::hash::Hash;
  30use std::hash::Hasher;
  31use std::io::Write as _;
  32use std::ops::Range;
  33use std::path::{Path, PathBuf};
  34use std::process::exit;
  35use std::str::FromStr;
  36use std::sync::atomic::AtomicUsize;
  37use std::sync::{Arc, atomic};
  38use std::time::Duration;
  39use util::paths::PathStyle;
  40use util::rel_path::RelPath;
  41use util::{RangeExt, ResultExt as _};
  42use zeta::{PerformPredictEditsParams, Zeta};
  43
  44use crate::headless::ZetaCliAppState;
  45
// Top-level command-line arguments for the `zeta` command.
// Plain `//` comments are used on purpose: on a clap `Parser`, `///` doc comments would become
// the command's about/help text.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // The subcommand to run.
    #[command(subcommand)]
    command: Commands,
}
  52
  53#[derive(Subcommand, Debug)]
  54enum Commands {
  55    Context(ContextArgs),
  56    Zeta2Context {
  57        #[clap(flatten)]
  58        zeta2_args: Zeta2Args,
  59        #[clap(flatten)]
  60        context_args: ContextArgs,
  61    },
  62    Predict {
  63        #[arg(long)]
  64        predict_edits_body: Option<FileOrStdin>,
  65        #[clap(flatten)]
  66        context_args: Option<ContextArgs>,
  67    },
  68    RetrievalStats {
  69        #[clap(flatten)]
  70        zeta2_args: Zeta2Args,
  71        #[arg(long)]
  72        worktree: PathBuf,
  73        #[arg(long)]
  74        extension: Option<String>,
  75        #[arg(long)]
  76        limit: Option<usize>,
  77        #[arg(long)]
  78        skip: Option<usize>,
  79    },
  80}
  81
// Arguments identifying the worktree and cursor location to gather context for.
// Plain `//` comments are used on purpose: `///` doc comments would become clap help text.
#[derive(Debug, Args)]
// Clap `ArgGroup::requires`: once any argument of this group is given, `--worktree` must be too.
// This matters where the group is flattened as `Option<ContextArgs>` (the `predict` subcommand).
#[group(requires = "worktree")]
struct ContextArgs {
    // Root directory of the worktree; `get_context` canonicalizes it and opens it as a local project.
    #[arg(long)]
    worktree: PathBuf,
    // Cursor as `path:line:column` (1-based), with `path` relative to the worktree root.
    #[arg(long)]
    cursor: SourceLocation,
    // Open the cursor's buffer via `open_buffer_with_language_server` instead of `open_buffer`.
    #[arg(long)]
    use_language_server: bool,
    // Events text (a file path, or `-` for stdin). `get_context` reads it; only its zeta1 path
    // uses it. Treated as empty when omitted.
    #[arg(long)]
    events: Option<FileOrStdin>,
}
  94
// Tuning knobs for the zeta2 pipeline. `Zeta2Args::to_options` maps them onto `zeta2::ZetaOptions`.
// Plain `//` comments are used on purpose: `///` doc comments would become clap help text.
#[derive(Debug, Args)]
struct Zeta2Args {
    // -> `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // -> excerpt options `max_bytes`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // -> excerpt options `min_bytes`.
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // -> excerpt options `target_before_cursor_over_total_bytes`; presumably the fraction of the
    // excerpt placed before the cursor (TODO confirm against edit_prediction_context).
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // -> `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // -> `ZetaOptions::prompt_format` (converted to `predict_edits_v3::PromptFormat`).
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // How `get_context` renders zeta2 output: prompt text, request JSON, or both.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // -> `ZetaOptions::file_indexing_parallelism`.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
    // When set, import gathering is disabled (`use_imports` is the negation of this flag).
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
}
 116
// Prompt layout selectable on the command line. It is converted variant-for-variant into
// `predict_edits_v3::PromptFormat`.
// Plain `//` comments: on a `ValueEnum`, `///` doc comments would become value help text.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    #[default]
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
}
 124
 125impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
 126    fn into(self) -> predict_edits_v3::PromptFormat {
 127        match self {
 128            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
 129            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
 130            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
 131        }
 132    }
 133}
 134
// How `get_context` renders zeta2 output.
// Plain `//` comments: on a `ValueEnum`, `///` doc comments would become value help text.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    // The rendered prompt string only.
    #[default]
    Prompt,
    // The cloud request, as pretty-printed JSON.
    Request,
    // Pretty-printed JSON object with `request`, `prompt` and `section_labels` keys.
    Full,
}
 142
 143#[derive(Debug, Clone)]
 144enum FileOrStdin {
 145    File(PathBuf),
 146    Stdin,
 147}
 148
 149impl FileOrStdin {
 150    async fn read_to_string(&self) -> Result<String, std::io::Error> {
 151        match self {
 152            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
 153            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
 154        }
 155    }
 156}
 157
 158impl FromStr for FileOrStdin {
 159    type Err = <PathBuf as FromStr>::Err;
 160
 161    fn from_str(s: &str) -> Result<Self, Self::Err> {
 162        match s {
 163            "-" => Ok(Self::Stdin),
 164            _ => Ok(Self::File(PathBuf::from_str(s)?)),
 165        }
 166    }
 167}
 168
/// A position inside a worktree-relative file.
///
/// Its textual form (see the `Display` / `FromStr` impls) is `path:line:column` with 1-based line
/// and column numbers, while `point` stores 0-based coordinates. It is also used as the lookup key
/// for cached LSP definitions in `retrieval_stats`.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct SourceLocation {
    /// Path relative to the worktree root.
    path: Arc<RelPath>,
    /// 0-based row/column within the file.
    point: Point,
}
 174
 175impl Serialize for SourceLocation {
 176    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 177    where
 178        S: Serializer,
 179    {
 180        serializer.serialize_str(&self.to_string())
 181    }
 182}
 183
 184impl<'de> Deserialize<'de> for SourceLocation {
 185    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 186    where
 187        D: Deserializer<'de>,
 188    {
 189        let s = String::deserialize(deserializer)?;
 190        s.parse().map_err(serde::de::Error::custom)
 191    }
 192}
 193
 194impl Display for SourceLocation {
 195    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 196        write!(
 197            f,
 198            "{}:{}:{}",
 199            self.path.display(PathStyle::Posix),
 200            self.point.row + 1,
 201            self.point.column + 1
 202        )
 203    }
 204}
 205
 206impl FromStr for SourceLocation {
 207    type Err = anyhow::Error;
 208
 209    fn from_str(s: &str) -> Result<Self> {
 210        let parts: Vec<&str> = s.split(':').collect();
 211        if parts.len() != 3 {
 212            return Err(anyhow!(
 213                "Invalid source location. Expected 'file.rs:line:column', got '{}'",
 214                s
 215            ));
 216        }
 217
 218        let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
 219        let line: u32 = parts[1]
 220            .parse()
 221            .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
 222        let column: u32 = parts[2]
 223            .parse()
 224            .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
 225
 226        // Convert from 1-based to 0-based indexing
 227        let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
 228
 229        Ok(SourceLocation { path, point })
 230    }
 231}
 232
 233enum GetContextOutput {
 234    Zeta1(zeta::GatherContextOutput),
 235    Zeta2(String),
 236}
 237
 238async fn get_context(
 239    zeta2_args: Option<Zeta2Args>,
 240    args: ContextArgs,
 241    app_state: &Arc<ZetaCliAppState>,
 242    cx: &mut AsyncApp,
 243) -> Result<GetContextOutput> {
 244    let ContextArgs {
 245        worktree: worktree_path,
 246        cursor,
 247        use_language_server,
 248        events,
 249    } = args;
 250
 251    let worktree_path = worktree_path.canonicalize()?;
 252
 253    let project = cx.update(|cx| {
 254        Project::local(
 255            app_state.client.clone(),
 256            app_state.node_runtime.clone(),
 257            app_state.user_store.clone(),
 258            app_state.languages.clone(),
 259            app_state.fs.clone(),
 260            None,
 261            cx,
 262        )
 263    })?;
 264
 265    let worktree = project
 266        .update(cx, |project, cx| {
 267            project.create_worktree(&worktree_path, true, cx)
 268        })?
 269        .await?;
 270
 271    let mut ready_languages = HashSet::default();
 272    let (_lsp_open_handle, buffer) = if use_language_server {
 273        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
 274            project.clone(),
 275            worktree.clone(),
 276            cursor.path.clone(),
 277            &mut ready_languages,
 278            cx,
 279        )
 280        .await?;
 281        (Some(lsp_open_handle), buffer)
 282    } else {
 283        let buffer =
 284            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
 285        (None, buffer)
 286    };
 287
 288    let full_path_str = worktree
 289        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
 290        .display(PathStyle::local())
 291        .to_string();
 292
 293    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
 294    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
 295    if clipped_cursor != cursor.point {
 296        let max_row = snapshot.max_point().row;
 297        if cursor.point.row < max_row {
 298            return Err(anyhow!(
 299                "Cursor position {:?} is out of bounds (line length is {})",
 300                cursor.point,
 301                snapshot.line_len(cursor.point.row)
 302            ));
 303        } else {
 304            return Err(anyhow!(
 305                "Cursor position {:?} is out of bounds (max row is {})",
 306                cursor.point,
 307                max_row
 308            ));
 309        }
 310    }
 311
 312    let events = match events {
 313        Some(events) => events.read_to_string().await?,
 314        None => String::new(),
 315    };
 316
 317    if let Some(zeta2_args) = zeta2_args {
 318        // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
 319        // the whole worktree.
 320        worktree
 321            .read_with(cx, |worktree, _cx| {
 322                worktree.as_local().unwrap().scan_complete()
 323            })?
 324            .await;
 325        let output = cx
 326            .update(|cx| {
 327                let zeta = cx.new(|cx| {
 328                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
 329                });
 330                let indexing_done_task = zeta.update(cx, |zeta, cx| {
 331                    zeta.set_options(zeta2_args.to_options(true));
 332                    zeta.register_buffer(&buffer, &project, cx);
 333                    zeta.wait_for_initial_indexing(&project, cx)
 334                });
 335                cx.spawn(async move |cx| {
 336                    indexing_done_task.await?;
 337                    let request = zeta
 338                        .update(cx, |zeta, cx| {
 339                            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
 340                            zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
 341                        })?
 342                        .await?;
 343
 344                    let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
 345                    let (prompt_string, section_labels) = planned_prompt.to_prompt_string()?;
 346
 347                    match zeta2_args.output_format {
 348                        OutputFormat::Prompt => anyhow::Ok(prompt_string),
 349                        OutputFormat::Request => {
 350                            anyhow::Ok(serde_json::to_string_pretty(&request)?)
 351                        }
 352                        OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
 353                            "request": request,
 354                            "prompt": prompt_string,
 355                            "section_labels": section_labels,
 356                        }))?),
 357                    }
 358                })
 359            })?
 360            .await?;
 361        Ok(GetContextOutput::Zeta2(output))
 362    } else {
 363        let prompt_for_events = move || (events, 0);
 364        Ok(GetContextOutput::Zeta1(
 365            cx.update(|cx| {
 366                zeta::gather_context(
 367                    full_path_str,
 368                    &snapshot,
 369                    clipped_cursor,
 370                    prompt_for_events,
 371                    cx,
 372                )
 373            })?
 374            .await?,
 375        ))
 376    }
 377}
 378
 379impl Zeta2Args {
 380    fn to_options(&self, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions {
 381        zeta2::ZetaOptions {
 382            context: EditPredictionContextOptions {
 383                use_imports: !self.disable_imports_gathering,
 384                excerpt: EditPredictionExcerptOptions {
 385                    max_bytes: self.max_excerpt_bytes,
 386                    min_bytes: self.min_excerpt_bytes,
 387                    target_before_cursor_over_total_bytes: self
 388                        .target_before_cursor_over_total_bytes,
 389                },
 390                score: EditPredictionScoreOptions {
 391                    omit_excerpt_overlaps,
 392                },
 393            },
 394            max_diagnostic_bytes: self.max_diagnostic_bytes,
 395            max_prompt_bytes: self.max_prompt_bytes,
 396            prompt_format: self.prompt_format.clone().into(),
 397            file_indexing_parallelism: self.file_indexing_parallelism,
 398        }
 399    }
 400}
 401
 402pub async fn retrieval_stats(
 403    worktree: PathBuf,
 404    app_state: Arc<ZetaCliAppState>,
 405    only_extension: Option<String>,
 406    file_limit: Option<usize>,
 407    skip_files: Option<usize>,
 408    options: zeta2::ZetaOptions,
 409    cx: &mut AsyncApp,
 410) -> Result<String> {
 411    let options = Arc::new(options);
 412    let worktree_path = worktree.canonicalize()?;
 413
 414    let project = cx.update(|cx| {
 415        Project::local(
 416            app_state.client.clone(),
 417            app_state.node_runtime.clone(),
 418            app_state.user_store.clone(),
 419            app_state.languages.clone(),
 420            app_state.fs.clone(),
 421            None,
 422            cx,
 423        )
 424    })?;
 425
 426    let worktree = project
 427        .update(cx, |project, cx| {
 428            project.create_worktree(&worktree_path, true, cx)
 429        })?
 430        .await?;
 431
 432    // wait for worktree scan so that wait_for_initial_file_indexing waits for the whole worktree.
 433    worktree
 434        .read_with(cx, |worktree, _cx| {
 435            worktree.as_local().unwrap().scan_complete()
 436        })?
 437        .await;
 438
 439    let index = cx.new(|cx| SyntaxIndex::new(&project, options.file_indexing_parallelism, cx))?;
 440    index
 441        .read_with(cx, |index, cx| index.wait_for_initial_file_indexing(cx))?
 442        .await?;
 443    let indexed_files = index
 444        .read_with(cx, |index, cx| index.indexed_file_paths(cx))?
 445        .await;
 446    let mut filtered_files = indexed_files
 447        .into_iter()
 448        .filter(|project_path| {
 449            let file_extension = project_path.path.extension();
 450            if let Some(only_extension) = only_extension.as_ref() {
 451                file_extension.is_some_and(|extension| extension == only_extension)
 452            } else {
 453                file_extension
 454                    .is_some_and(|extension| !["md", "json", "sh", "diff"].contains(&extension))
 455            }
 456        })
 457        .collect::<Vec<_>>();
 458    filtered_files.sort_by(|a, b| a.path.cmp(&b.path));
 459
 460    let index_state = index.read_with(cx, |index, _cx| index.state().clone())?;
 461    cx.update(|_| {
 462        drop(index);
 463    })?;
 464    let index_state = Arc::new(
 465        Arc::into_inner(index_state)
 466            .context("Index state had more than 1 reference")?
 467            .into_inner(),
 468    );
 469
 470    struct FileSnapshot {
 471        project_entry_id: ProjectEntryId,
 472        snapshot: BufferSnapshot,
 473        hash: u64,
 474        parent_abs_path: Arc<Path>,
 475    }
 476
 477    let files: Vec<FileSnapshot> = futures::future::try_join_all({
 478        filtered_files
 479            .iter()
 480            .map(|file| {
 481                let buffer_task =
 482                    open_buffer(project.clone(), worktree.clone(), file.path.clone(), cx);
 483                cx.spawn(async move |cx| {
 484                    let buffer = buffer_task.await?;
 485                    let (project_entry_id, parent_abs_path, snapshot) =
 486                        buffer.read_with(cx, |buffer, cx| {
 487                            let file = project::File::from_dyn(buffer.file()).unwrap();
 488                            let project_entry_id = file.project_entry_id().unwrap();
 489                            let mut parent_abs_path = file.worktree.read(cx).absolutize(&file.path);
 490                            if !parent_abs_path.pop() {
 491                                panic!("Invalid worktree path");
 492                            }
 493
 494                            (project_entry_id, parent_abs_path, buffer.snapshot())
 495                        })?;
 496
 497                    anyhow::Ok(
 498                        cx.background_spawn(async move {
 499                            let mut hasher = collections::FxHasher::default();
 500                            snapshot.text().hash(&mut hasher);
 501                            FileSnapshot {
 502                                project_entry_id,
 503                                snapshot,
 504                                hash: hasher.finish(),
 505                                parent_abs_path: parent_abs_path.into(),
 506                            }
 507                        })
 508                        .await,
 509                    )
 510                })
 511            })
 512            .collect::<Vec<_>>()
 513    })
 514    .await?;
 515
 516    let mut file_snapshots = HashMap::default();
 517    let mut hasher = collections::FxHasher::default();
 518    for FileSnapshot {
 519        project_entry_id,
 520        snapshot,
 521        hash,
 522        ..
 523    } in &files
 524    {
 525        file_snapshots.insert(*project_entry_id, snapshot.clone());
 526        hash.hash(&mut hasher);
 527    }
 528    let files_hash = hasher.finish();
 529    let file_snapshots = Arc::new(file_snapshots);
 530
 531    let lsp_definitions_path = std::env::current_dir()?.join(format!(
 532        "target/zeta2-lsp-definitions-{:x}.json",
 533        files_hash
 534    ));
 535
 536    let lsp_definitions: Arc<_> = if std::fs::exists(&lsp_definitions_path)? {
 537        log::info!(
 538            "Using cached LSP definitions from {}",
 539            lsp_definitions_path.display()
 540        );
 541        serde_json::from_reader(File::open(&lsp_definitions_path)?)?
 542    } else {
 543        log::warn!(
 544            "No LSP definitions found populating {}",
 545            lsp_definitions_path.display()
 546        );
 547        let lsp_definitions =
 548            gather_lsp_definitions(&filtered_files, &worktree, &project, cx).await?;
 549        serde_json::to_writer_pretty(File::create(&lsp_definitions_path)?, &lsp_definitions)?;
 550        lsp_definitions
 551    }
 552    .into();
 553
 554    let files_len = files.len().min(file_limit.unwrap_or(usize::MAX));
 555    let done_count = Arc::new(AtomicUsize::new(0));
 556
 557    let (output_tx, mut output_rx) = mpsc::unbounded::<RetrievalStatsResult>();
 558    let mut output = std::fs::File::create("target/zeta-retrieval-stats.txt")?;
 559
 560    let tasks = files
 561        .into_iter()
 562        .skip(skip_files.unwrap_or(0))
 563        .take(file_limit.unwrap_or(usize::MAX))
 564        .map(|project_file| {
 565            let index_state = index_state.clone();
 566            let lsp_definitions = lsp_definitions.clone();
 567            let options = options.clone();
 568            let output_tx = output_tx.clone();
 569            let done_count = done_count.clone();
 570            let file_snapshots = file_snapshots.clone();
 571            cx.background_spawn(async move {
 572                let snapshot = project_file.snapshot;
 573
 574                let full_range = 0..snapshot.len();
 575                let references = references_in_range(
 576                    full_range,
 577                    &snapshot.text(),
 578                    ReferenceRegion::Nearby,
 579                    &snapshot,
 580                );
 581
 582                println!("references: {}", references.len(),);
 583
 584                let imports = if options.context.use_imports {
 585                    Imports::gather(&snapshot, Some(&project_file.parent_abs_path))
 586                } else {
 587                    Imports::default()
 588                };
 589
 590                let path = snapshot.file().unwrap().path();
 591
 592                for reference in references {
 593                    let query_point = snapshot.offset_to_point(reference.range.start);
 594                    let source_location = SourceLocation {
 595                        path: path.clone(),
 596                        point: query_point,
 597                    };
 598                    let lsp_definitions = lsp_definitions
 599                        .definitions
 600                        .get(&source_location)
 601                        .cloned()
 602                        .unwrap_or_else(|| {
 603                            log::warn!(
 604                                "No definitions found for source location: {:?}",
 605                                source_location
 606                            );
 607                            Vec::new()
 608                        });
 609
 610                    let retrieve_result = retrieve_definitions(
 611                        &reference,
 612                        &imports,
 613                        query_point,
 614                        &snapshot,
 615                        &index_state,
 616                        &file_snapshots,
 617                        &options,
 618                    )
 619                    .await?;
 620
 621                    // TODO: LSP returns things like locals, this filters out some of those, but potentially
 622                    // hides some retrieval issues.
 623                    if retrieve_result.definitions.is_empty() {
 624                        continue;
 625                    }
 626
 627                    let mut best_match = None;
 628                    let mut has_external_definition = false;
 629                    let mut in_excerpt = false;
 630                    for (index, retrieved_definition) in
 631                        retrieve_result.definitions.iter().enumerate()
 632                    {
 633                        for lsp_definition in &lsp_definitions {
 634                            let SourceRange {
 635                                path,
 636                                point_range,
 637                                offset_range,
 638                            } = lsp_definition;
 639                            let lsp_point_range =
 640                                SerializablePoint::into_language_point_range(point_range.clone());
 641                            has_external_definition = has_external_definition
 642                                || path.is_absolute()
 643                                || path
 644                                    .components()
 645                                    .any(|component| component.as_os_str() == "node_modules");
 646                            let is_match = path.as_path()
 647                                == retrieved_definition.path.as_std_path()
 648                                && retrieved_definition
 649                                    .range
 650                                    .contains_inclusive(&lsp_point_range);
 651                            if is_match {
 652                                if best_match.is_none() {
 653                                    best_match = Some(index);
 654                                }
 655                            }
 656                            in_excerpt = in_excerpt
 657                                || retrieve_result.excerpt_range.as_ref().is_some_and(
 658                                    |excerpt_range| excerpt_range.contains_inclusive(&offset_range),
 659                                );
 660                        }
 661                    }
 662
 663                    let outcome = if let Some(best_match) = best_match {
 664                        RetrievalOutcome::Match { best_match }
 665                    } else if has_external_definition {
 666                        RetrievalOutcome::NoMatchDueToExternalLspDefinitions
 667                    } else if in_excerpt {
 668                        RetrievalOutcome::ProbablyLocal
 669                    } else {
 670                        RetrievalOutcome::NoMatch
 671                    };
 672
 673                    let result = RetrievalStatsResult {
 674                        outcome,
 675                        path: path.clone(),
 676                        identifier: reference.identifier,
 677                        point: query_point,
 678                        lsp_definitions,
 679                        retrieved_definitions: retrieve_result.definitions,
 680                    };
 681
 682                    output_tx.unbounded_send(result).ok();
 683                }
 684
 685                println!(
 686                    "{:02}/{:02} done",
 687                    done_count.fetch_add(1, atomic::Ordering::Relaxed) + 1,
 688                    files_len,
 689                );
 690
 691                anyhow::Ok(())
 692            })
 693        })
 694        .collect::<Vec<_>>();
 695
 696    drop(output_tx);
 697
 698    let results_task = cx.background_spawn(async move {
 699        let mut results = Vec::new();
 700        while let Some(result) = output_rx.next().await {
 701            output
 702                .write_all(format!("{:#?}\n", result).as_bytes())
 703                .log_err();
 704            results.push(result)
 705        }
 706        results
 707    });
 708
 709    futures::future::try_join_all(tasks).await?;
 710    println!("Tasks completed");
 711    let results = results_task.await;
 712    println!("Results received");
 713
 714    let mut references_count = 0;
 715
 716    let mut included_count = 0;
 717    let mut both_absent_count = 0;
 718
 719    let mut retrieved_count = 0;
 720    let mut top_match_count = 0;
 721    let mut non_top_match_count = 0;
 722    let mut ranking_involved_top_match_count = 0;
 723
 724    let mut no_match_count = 0;
 725    let mut no_match_none_retrieved = 0;
 726    let mut no_match_wrong_retrieval = 0;
 727
 728    let mut expected_no_match_count = 0;
 729    let mut in_excerpt_count = 0;
 730    let mut external_definition_count = 0;
 731
 732    for result in results {
 733        references_count += 1;
 734        match &result.outcome {
 735            RetrievalOutcome::Match { best_match } => {
 736                included_count += 1;
 737                retrieved_count += 1;
 738                let multiple = result.retrieved_definitions.len() > 1;
 739                if *best_match == 0 {
 740                    top_match_count += 1;
 741                    if multiple {
 742                        ranking_involved_top_match_count += 1;
 743                    }
 744                } else {
 745                    non_top_match_count += 1;
 746                }
 747            }
 748            RetrievalOutcome::NoMatch => {
 749                if result.lsp_definitions.is_empty() {
 750                    included_count += 1;
 751                    both_absent_count += 1;
 752                } else {
 753                    no_match_count += 1;
 754                    if result.retrieved_definitions.is_empty() {
 755                        no_match_none_retrieved += 1;
 756                    } else {
 757                        no_match_wrong_retrieval += 1;
 758                    }
 759                }
 760            }
 761            RetrievalOutcome::NoMatchDueToExternalLspDefinitions => {
 762                expected_no_match_count += 1;
 763                external_definition_count += 1;
 764            }
 765            RetrievalOutcome::ProbablyLocal => {
 766                included_count += 1;
 767                in_excerpt_count += 1;
 768            }
 769        }
 770    }
 771
 772    fn count_and_percentage(part: usize, total: usize) -> String {
 773        format!("{} ({:.2}%)", part, (part as f64 / total as f64) * 100.0)
 774    }
 775
 776    println!("");
 777    println!("╮ references: {}", references_count);
 778    println!(
 779        "├─╮ included: {}",
 780        count_and_percentage(included_count, references_count),
 781    );
 782    println!(
 783        "│ ├─╮ retrieved: {}",
 784        count_and_percentage(retrieved_count, references_count)
 785    );
 786    println!(
 787        "│ │ ├─╮ top match : {}",
 788        count_and_percentage(top_match_count, retrieved_count)
 789    );
 790    println!(
 791        "│ │ │ ╰─╴ involving ranking: {}",
 792        count_and_percentage(ranking_involved_top_match_count, top_match_count)
 793    );
 794    println!(
 795        "│ │ ╰─╴ non-top match: {}",
 796        count_and_percentage(non_top_match_count, retrieved_count)
 797    );
 798    println!(
 799        "│ ├─╴ both absent: {}",
 800        count_and_percentage(both_absent_count, included_count)
 801    );
 802    println!(
 803        "│ ╰─╴ in excerpt: {}",
 804        count_and_percentage(in_excerpt_count, included_count)
 805    );
 806    println!(
 807        "├─╮ no match: {}",
 808        count_and_percentage(no_match_count, references_count)
 809    );
 810    println!(
 811        "│ ├─╴ none retrieved: {}",
 812        count_and_percentage(no_match_none_retrieved, no_match_count)
 813    );
 814    println!(
 815        "│ ╰─╴ wrong retrieval: {}",
 816        count_and_percentage(no_match_wrong_retrieval, no_match_count)
 817    );
 818    println!(
 819        "╰─╮ expected no match: {}",
 820        count_and_percentage(expected_no_match_count, references_count)
 821    );
 822    println!(
 823        "  ╰─╴ external definition: {}",
 824        count_and_percentage(external_definition_count, expected_no_match_count)
 825    );
 826
 827    println!("");
 828    println!("LSP definition cache at {}", lsp_definitions_path.display());
 829
 830    Ok("".to_string())
 831}
 832
 833struct RetrieveResult {
 834    definitions: Vec<RetrievedDefinition>,
 835    excerpt_range: Option<Range<usize>>,
 836}
 837
 838async fn retrieve_definitions(
 839    reference: &Reference,
 840    imports: &Imports,
 841    query_point: Point,
 842    snapshot: &BufferSnapshot,
 843    index: &Arc<SyntaxIndexState>,
 844    file_snapshots: &Arc<HashMap<ProjectEntryId, BufferSnapshot>>,
 845    options: &Arc<zeta2::ZetaOptions>,
 846) -> Result<RetrieveResult> {
 847    let mut single_reference_map = HashMap::default();
 848    single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]);
 849    let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn(
 850        query_point,
 851        snapshot,
 852        imports,
 853        &options.context,
 854        Some(&index),
 855        |_, _, _| single_reference_map,
 856    );
 857
 858    let Some(edit_prediction_context) = edit_prediction_context else {
 859        return Ok(RetrieveResult {
 860            definitions: Vec::new(),
 861            excerpt_range: None,
 862        });
 863    };
 864
 865    let mut retrieved_definitions = Vec::new();
 866    for scored_declaration in edit_prediction_context.declarations {
 867        match &scored_declaration.declaration {
 868            Declaration::File {
 869                project_entry_id,
 870                declaration,
 871                ..
 872            } => {
 873                let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
 874                    log::error!("bug: file project entry not found");
 875                    continue;
 876                };
 877                let path = snapshot.file().unwrap().path().clone();
 878                retrieved_definitions.push(RetrievedDefinition {
 879                    path,
 880                    range: snapshot.offset_to_point(declaration.item_range.start)
 881                        ..snapshot.offset_to_point(declaration.item_range.end),
 882                    score: scored_declaration.score(DeclarationStyle::Declaration),
 883                    retrieval_score: scored_declaration.retrieval_score(),
 884                    components: scored_declaration.components,
 885                });
 886            }
 887            Declaration::Buffer {
 888                project_entry_id,
 889                rope,
 890                declaration,
 891                ..
 892            } => {
 893                let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
 894                    // This case happens when dependency buffers have been opened by
 895                    // go-to-definition, resulting in single-file worktrees.
 896                    continue;
 897                };
 898                let path = snapshot.file().unwrap().path().clone();
 899                retrieved_definitions.push(RetrievedDefinition {
 900                    path,
 901                    range: rope.offset_to_point(declaration.item_range.start)
 902                        ..rope.offset_to_point(declaration.item_range.end),
 903                    score: scored_declaration.score(DeclarationStyle::Declaration),
 904                    retrieval_score: scored_declaration.retrieval_score(),
 905                    components: scored_declaration.components,
 906                });
 907            }
 908        }
 909    }
 910    retrieved_definitions.sort_by_key(|definition| Reverse(OrderedFloat(definition.score)));
 911
 912    Ok(RetrieveResult {
 913        definitions: retrieved_definitions,
 914        excerpt_range: Some(edit_prediction_context.excerpt.range),
 915    })
 916}
 917
 918async fn gather_lsp_definitions(
 919    files: &[ProjectPath],
 920    worktree: &Entity<Worktree>,
 921    project: &Entity<Project>,
 922    cx: &mut AsyncApp,
 923) -> Result<LspResults> {
 924    let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?;
 925
 926    let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
 927    cx.subscribe(&lsp_store, {
 928        move |_, event, _| {
 929            if let project::LspStoreEvent::LanguageServerUpdate {
 930                message:
 931                    client::proto::update_language_server::Variant::WorkProgress(
 932                        client::proto::LspWorkProgress {
 933                            message: Some(message),
 934                            ..
 935                        },
 936                    ),
 937                ..
 938            } = event
 939            {
 940                println!("{message}")
 941            }
 942        }
 943    })?
 944    .detach();
 945
 946    let mut definitions = HashMap::default();
 947    let mut error_count = 0;
 948    let mut lsp_open_handles = Vec::new();
 949    let mut ready_languages = HashSet::default();
 950    for (file_index, project_path) in files.iter().enumerate() {
 951        println!(
 952            "Processing file {} of {}: {}",
 953            file_index + 1,
 954            files.len(),
 955            project_path.path.display(PathStyle::Posix)
 956        );
 957
 958        let Some((lsp_open_handle, language_server_id, buffer)) = open_buffer_with_language_server(
 959            project.clone(),
 960            worktree.clone(),
 961            project_path.path.clone(),
 962            &mut ready_languages,
 963            cx,
 964        )
 965        .await
 966        .log_err() else {
 967            continue;
 968        };
 969        lsp_open_handles.push(lsp_open_handle);
 970
 971        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
 972        let full_range = 0..snapshot.len();
 973        let references = references_in_range(
 974            full_range,
 975            &snapshot.text(),
 976            ReferenceRegion::Nearby,
 977            &snapshot,
 978        );
 979
 980        loop {
 981            let is_ready = lsp_store
 982                .read_with(cx, |lsp_store, _cx| {
 983                    lsp_store
 984                        .language_server_statuses
 985                        .get(&language_server_id)
 986                        .is_some_and(|status| status.pending_work.is_empty())
 987                })
 988                .unwrap();
 989            if is_ready {
 990                break;
 991            }
 992            cx.background_executor()
 993                .timer(Duration::from_millis(10))
 994                .await;
 995        }
 996
 997        for reference in references {
 998            // TODO: Rename declaration to definition in edit_prediction_context?
 999            let lsp_result = project
1000                .update(cx, |project, cx| {
1001                    project.definitions(&buffer, reference.range.start, cx)
1002                })?
1003                .await;
1004
1005            match lsp_result {
1006                Ok(lsp_definitions) => {
1007                    let mut targets = Vec::new();
1008                    for target in lsp_definitions.unwrap_or_default() {
1009                        let buffer = target.target.buffer;
1010                        let anchor_range = target.target.range;
1011                        buffer.read_with(cx, |buffer, cx| {
1012                            let Some(file) = project::File::from_dyn(buffer.file()) else {
1013                                return;
1014                            };
1015                            let file_worktree = file.worktree.read(cx);
1016                            let file_worktree_id = file_worktree.id();
1017                            // Relative paths for worktree files, absolute for all others
1018                            let path = if worktree_id != file_worktree_id {
1019                                file.worktree.read(cx).absolutize(&file.path)
1020                            } else {
1021                                file.path.as_std_path().to_path_buf()
1022                            };
1023                            let offset_range = anchor_range.to_offset(&buffer);
1024                            let point_range = SerializablePoint::from_language_point_range(
1025                                offset_range.to_point(&buffer),
1026                            );
1027                            targets.push(SourceRange {
1028                                path,
1029                                offset_range,
1030                                point_range,
1031                            });
1032                        })?;
1033                    }
1034
1035                    definitions.insert(
1036                        SourceLocation {
1037                            path: project_path.path.clone(),
1038                            point: snapshot.offset_to_point(reference.range.start),
1039                        },
1040                        targets,
1041                    );
1042                }
1043                Err(err) => {
1044                    log::error!("Language server error: {err}");
1045                    error_count += 1;
1046                }
1047            }
1048        }
1049    }
1050
1051    log::error!("Encountered {} language server errors", error_count);
1052
1053    Ok(LspResults { definitions })
1054}
1055
/// Go-to-definition results from the language server, keyed by reference location.
///
/// `serde(transparent)` makes this serialize as the bare map, without a wrapping object.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(transparent)]
struct LspResults {
    /// For each reference location, the definition targets the server returned. A list may be
    /// empty when the server answered with no definitions.
    definitions: HashMap<SourceLocation, Vec<SourceRange>>,
}
1061
/// Location of a definition target, given both as a row/column range and as an offset range.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SourceRange {
    /// Worktree-relative when the target lies in the analyzed worktree, absolute otherwise.
    path: PathBuf,
    /// Row/column range using 1-based indices (see `SerializablePoint`).
    point_range: Range<SerializablePoint>,
    /// Offset range within the target's buffer.
    offset_range: Range<usize>,
}
1068
/// A row/column position stored with 1-based indices, for serialized output.
///
/// Converting from a 0-based [`Point`] adds 1 to each component. Converting back subtracts 1,
/// saturating at 0.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializablePoint {
    /// 1-based row.
    pub row: u32,
    /// 1-based column.
    pub column: u32,
}
1075
1076impl SerializablePoint {
1077    pub fn into_language_point_range(range: Range<Self>) -> Range<Point> {
1078        range.start.into()..range.end.into()
1079    }
1080
1081    pub fn from_language_point_range(range: Range<Point>) -> Range<Self> {
1082        range.start.into()..range.end.into()
1083    }
1084}
1085
1086impl From<Point> for SerializablePoint {
1087    fn from(point: Point) -> Self {
1088        SerializablePoint {
1089            row: point.row + 1,
1090            column: point.column + 1,
1091        }
1092    }
1093}
1094
1095impl From<SerializablePoint> for Point {
1096    fn from(serializable: SerializablePoint) -> Self {
1097        Point {
1098            row: serializable.row.saturating_sub(1),
1099            column: serializable.column.saturating_sub(1),
1100        }
1101    }
1102}
1103
/// Per-reference result of comparing retrieved definitions against LSP definitions.
/// NOTE(review): constructed outside this view; the field meanings below are inferred from
/// their names and types.
#[derive(Debug)]
struct RetrievalStatsResult {
    outcome: RetrievalOutcome,
    // The fields marked `allow(dead_code)` are not read directly. The dead-code lint ignores
    // the derived `Debug` impl, so without the attribute they would warn. Presumably they are
    // kept for debug-printing individual results.
    #[allow(dead_code)]
    path: Arc<RelPath>,
    #[allow(dead_code)]
    identifier: Identifier,
    #[allow(dead_code)]
    point: Point,
    #[allow(dead_code)]
    lsp_definitions: Vec<SourceRange>,
    retrieved_definitions: Vec<RetrievedDefinition>,
}
1117
/// How retrieval fared for one reference when compared with the language server's answer.
#[derive(Debug)]
enum RetrievalOutcome {
    /// At least one retrieved definition matches an LSP definition.
    Match {
        /// Lowest index within retrieved_definitions that matches an LSP definition.
        best_match: usize,
    },
    /// NOTE(review): assigned outside this view. Presumably the reference resolves to
    /// something local, so no retrieved definition was expected — confirm.
    ProbablyLocal,
    /// Presumably: no retrieved definition matches any LSP definition.
    NoMatch,
    /// No match, presumably because the LSP definitions lie outside the analyzed worktree
    /// (e.g. external dependencies), so a miss is expected. Inferred from the name — confirm.
    NoMatchDueToExternalLspDefinitions,
}
1128
/// A definition candidate produced by retrieval for a single reference.
#[derive(Debug)]
struct RetrievedDefinition {
    /// Path of the file containing the candidate, as reported by that file's buffer.
    path: Arc<RelPath>,
    /// Point range of the candidate declaration's item.
    range: Range<Point>,
    /// Score under `DeclarationStyle::Declaration`; candidates are ranked by this.
    score: f32,
    // Not read directly; presumably kept for `Debug` output. The dead-code lint ignores
    // derived impls, hence the `allow`.
    #[allow(dead_code)]
    retrieval_score: f32,
    #[allow(dead_code)]
    components: DeclarationScoreComponents,
}
1139
1140pub fn open_buffer(
1141    project: Entity<Project>,
1142    worktree: Entity<Worktree>,
1143    path: Arc<RelPath>,
1144    cx: &AsyncApp,
1145) -> Task<Result<Entity<Buffer>>> {
1146    cx.spawn(async move |cx| {
1147        let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
1148            worktree_id: worktree.id(),
1149            path,
1150        })?;
1151
1152        let buffer = project
1153            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
1154            .await?;
1155
1156        let mut parse_status = buffer.read_with(cx, |buffer, _cx| buffer.parse_status())?;
1157        while *parse_status.borrow() != ParseStatus::Idle {
1158            parse_status.changed().await?;
1159        }
1160
1161        Ok(buffer)
1162    })
1163}
1164
1165pub async fn open_buffer_with_language_server(
1166    project: Entity<Project>,
1167    worktree: Entity<Worktree>,
1168    path: Arc<RelPath>,
1169    ready_languages: &mut HashSet<LanguageId>,
1170    cx: &mut AsyncApp,
1171) -> Result<(Entity<Entity<Buffer>>, LanguageServerId, Entity<Buffer>)> {
1172    let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;
1173
1174    let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
1175        (
1176            project.register_buffer_with_language_servers(&buffer, cx),
1177            project.path_style(cx),
1178        )
1179    })?;
1180
1181    let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
1182        buffer.language().map(|language| language.id())
1183    })?
1184    else {
1185        return Err(anyhow!("No language for {}", path.display(path_style)));
1186    };
1187
1188    let log_prefix = path.display(path_style);
1189    if !ready_languages.contains(&language_id) {
1190        wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
1191        ready_languages.insert(language_id);
1192    }
1193
1194    let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
1195
1196    // hacky wait for buffer to be registered with the language server
1197    for _ in 0..100 {
1198        let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
1199            buffer.update(cx, |buffer, cx| {
1200                lsp_store
1201                    .language_servers_for_local_buffer(&buffer, cx)
1202                    .next()
1203                    .map(|(_, language_server)| language_server.server_id())
1204            })
1205        })?
1206        else {
1207            cx.background_executor()
1208                .timer(Duration::from_millis(10))
1209                .await;
1210            continue;
1211        };
1212
1213        return Ok((lsp_open_handle, language_server_id, buffer));
1214    }
1215
1216    return Err(anyhow!("No language server found for buffer"));
1217}
1218
// TODO: Dedupe with similar function in crates/eval/src/instance.rs
/// Returns a task that waits for a disk-based diagnostics pass to finish. Language server
/// progress messages are printed with `log_prefix` while it waits.
///
/// If no language server is attached to `buffer` when this is called, the task first waits up
/// to 5 seconds for the project to add one, and fails if none appears. It then waits up to
/// 5 minutes for a `DiskBasedDiagnosticsFinished` project event, and fails on timeout. That
/// event is not filtered by server or buffer, so any server finishing satisfies the wait.
///
/// The buffer is saved if a server is already attached, and again whenever a server is added.
/// This is presumably done to prompt the server to run its diagnostics.
///
/// # Panics
///
/// Panics if the project can no longer be read or updated (the `unwrap`s below).
pub fn wait_for_lang_server(
    project: &Entity<Project>,
    buffer: &Entity<Buffer>,
    log_prefix: String,
    cx: &mut AsyncApp,
) -> Task<Result<()>> {
    println!("{}⏵ Waiting for language server", log_prefix);

    // Signalled when disk-based diagnostics finish.
    let (mut tx, mut rx) = mpsc::channel(1);

    let lsp_store = project
        .read_with(cx, |project, _| project.lsp_store())
        .unwrap();

    // An update failure is treated the same as "no server attached".
    let has_lang_server = buffer
        .update(cx, |buffer, cx| {
            lsp_store.update(cx, |lsp_store, cx| {
                lsp_store
                    .language_servers_for_local_buffer(buffer, cx)
                    .next()
                    .is_some()
            })
        })
        .unwrap_or(false);

    if has_lang_server {
        project
            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
            .unwrap()
            .detach();
    }
    // Signalled when the project adds a language server.
    let (mut added_tx, mut added_rx) = mpsc::channel(1);

    // Held by the spawned task below and dropped once the wait completes or fails.
    let subscriptions = [
        // Print work-progress messages from any language server.
        cx.subscribe(&lsp_store, {
            let log_prefix = log_prefix.clone();
            move |_, event, _| {
                if let project::LspStoreEvent::LanguageServerUpdate {
                    message:
                        client::proto::update_language_server::Variant::WorkProgress(
                            client::proto::LspWorkProgress {
                                message: Some(message),
                                ..
                            },
                        ),
                    ..
                } = event
                {
                    println!("{}{message}", log_prefix)
                }
            }
        }),
        // Forward project events into the channels awaited below. Failed sends (e.g. a full
        // channel) are ignored.
        cx.subscribe(project, {
            let buffer = buffer.clone();
            move |project, event, cx| match event {
                project::Event::LanguageServerAdded(_, _, _) => {
                    let buffer = buffer.clone();
                    project
                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
                        .detach();
                    added_tx.try_send(()).ok();
                }
                project::Event::DiskBasedDiagnosticsFinished { .. } => {
                    tx.try_send(()).ok();
                }
                _ => {}
            }
        }),
    ];

    cx.spawn(async move |cx| {
        if !has_lang_server {
            // some buffers never have a language server, so this aborts quickly in that case.
            let timeout = cx.background_executor().timer(Duration::from_secs(5));
            futures::select! {
                _ = added_rx.next() => {},
                _ = timeout.fuse() => {
                    anyhow::bail!("Waiting for language server add timed out after 5 seconds");
                }
            };
        }
        let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
        let result = futures::select! {
            _ = rx.next() => {
                println!("{}⚑ Language server idle", log_prefix);
                anyhow::Ok(())
            },
            _ = timeout.fuse() => {
                anyhow::bail!("LSP wait timed out after 5 minutes");
            }
        };
        drop(subscriptions);
        result
    })
}
1315
/// CLI entry point.
///
/// Parses the arguments, starts a headless GPUI application with a Reqwest HTTP client, and
/// runs the selected subcommand in a spawned task.
/// - On success, the command's output string is printed to stdout and the app quits.
/// - On failure, the error is printed to stderr and the process exits with status 1.
fn main() {
    zlog::init();
    // Route log output to stderr.
    zlog::init_output_stderr();
    let args = ZetaCliArgs::parse();
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        cx.spawn(async move |cx| {
            // Every arm produces a `Result` whose `Ok` value is the text printed to stdout.
            let result = match args.command {
                // `get_context` is expected to yield `Zeta2` output when given zeta2 args, so
                // the `Zeta1` case is treated as impossible.
                Commands::Zeta2Context {
                    zeta2_args,
                    context_args,
                } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
                    Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
                    Ok(GetContextOutput::Zeta2(output)) => Ok(output),
                    Err(err) => Err(err),
                },
                // Zeta1 context: print the request body as pretty JSON.
                // NOTE(review): a serialization failure panics here instead of being reported.
                Commands::Context(context_args) => {
                    match get_context(None, context_args, &app_state, cx).await {
                        Ok(GetContextOutput::Zeta1(output)) => {
                            Ok(serde_json::to_string_pretty(&output.body).unwrap())
                        }
                        Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
                        Err(err) => Err(err),
                    }
                }
                // Sign in, obtain an LLM token, build the request body (from a file or from
                // the context args), and send it to the edit-prediction endpoint.
                Commands::Predict {
                    predict_edits_body,
                    context_args,
                } => {
                    cx.spawn(async move |cx| {
                        let app_version = cx.update(|cx| AppVersion::global(cx))?;
                        app_state.client.sign_in(true, cx).await?;
                        let llm_token = LlmApiToken::default();
                        llm_token.refresh(&app_state.client).await?;

                        let predict_edits_body =
                            if let Some(predict_edits_body) = predict_edits_body {
                                serde_json::from_str(&predict_edits_body.read_to_string().await?)?
                            } else if let Some(context_args) = context_args {
                                match get_context(None, context_args, &app_state, cx).await? {
                                    GetContextOutput::Zeta1(output) => output.body,
                                    GetContextOutput::Zeta2 { .. } => unreachable!(),
                                }
                            } else {
                                return Err(anyhow!(
                                    "Expected either --predict-edits-body-file \
                                    or the required args of the `context` command."
                                ));
                            };

                        let (response, _usage) =
                            Zeta::perform_predict_edits(PerformPredictEditsParams {
                                client: app_state.client.clone(),
                                llm_token,
                                app_version,
                                body: predict_edits_body,
                            })
                            .await?;

                        Ok(response.output_excerpt)
                    })
                    .await
                }
                // Compare retrieval against LSP definitions across a worktree.
                Commands::RetrievalStats {
                    zeta2_args,
                    worktree,
                    extension,
                    limit,
                    skip,
                } => {
                    retrieval_stats(
                        worktree,
                        app_state,
                        extension,
                        limit,
                        skip,
                        (&zeta2_args).to_options(false),
                        cx,
                    )
                    .await
                }
            };
            match result {
                Ok(output) => {
                    println!("{}", output);
                    // Ignore failure: the app may already be shutting down.
                    let _ = cx.update(|cx| cx.quit());
                }
                Err(e) => {
                    eprintln!("Failed: {:?}", e);
                    exit(1);
                }
            }
        })
        .detach();
    });
}