main.rs

   1mod headless;
   2
   3use anyhow::{Context as _, Result, anyhow};
   4use clap::{Args, Parser, Subcommand};
   5use cloud_llm_client::predict_edits_v3::{self, DeclarationScoreComponents};
   6use edit_prediction_context::{
   7    Declaration, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
   8    EditPredictionExcerptOptions, EditPredictionScoreOptions, Identifier, Imports, Reference,
   9    ReferenceRegion, SyntaxIndex, SyntaxIndexState, references_in_range,
  10};
  11use futures::channel::mpsc;
  12use futures::{FutureExt as _, StreamExt as _};
  13use gpui::{AppContext, Application, AsyncApp};
  14use gpui::{Entity, Task};
  15use language::{Bias, BufferSnapshot, LanguageServerId, Point};
  16use language::{Buffer, OffsetRangeExt};
  17use language::{LanguageId, ParseStatus};
  18use language_model::LlmApiToken;
  19use ordered_float::OrderedFloat;
  20use project::{Project, ProjectEntryId, ProjectPath, Worktree};
  21use release_channel::AppVersion;
  22use reqwest_client::ReqwestClient;
  23use serde::{Deserialize, Deserializer, Serialize, Serializer};
  24use serde_json::json;
  25use std::cmp::Reverse;
  26use std::collections::{HashMap, HashSet};
  27use std::fmt::{self, Display};
  28use std::fs::File;
  29use std::hash::Hash;
  30use std::hash::Hasher;
  31use std::io::Write as _;
  32use std::ops::Range;
  33use std::path::{Path, PathBuf};
  34use std::process::exit;
  35use std::str::FromStr;
  36use std::sync::atomic::AtomicUsize;
  37use std::sync::{Arc, atomic};
  38use std::time::Duration;
  39use util::paths::PathStyle;
  40use util::rel_path::RelPath;
  41use util::{RangeExt, ResultExt as _};
  42use zeta::{PerformPredictEditsParams, Zeta};
  43
  44use crate::headless::ZetaCliAppState;
  45
// Top-level command-line arguments of the `zeta` CLI.
//
// NOTE: plain `//` comments are used on purpose; `///` doc comments on clap-derived items
// become part of the generated `--help` output.
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
    // The subcommand selected on the command line.
    #[command(subcommand)]
    command: Commands,
}
  52
// Subcommands accepted by the CLI. This enum only declares the arguments of each
// subcommand; dispatching on the selected variant happens elsewhere.
//
// NOTE: plain `//` comments are used on purpose; `///` doc comments on clap-derived items
// become part of the generated `--help` output.
#[derive(Subcommand, Debug)]
enum Commands {
    // Takes only the shared context arguments (no zeta2 options).
    Context(ContextArgs),
    // Takes both the zeta2 tuning options and the shared context arguments.
    Zeta2Context {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[clap(flatten)]
        context_args: ContextArgs,
    },
    // Both inputs are optional: a request body read from a file or stdin, and/or the
    // context arguments.
    Predict {
        #[arg(long)]
        predict_edits_body: Option<FileOrStdin>,
        #[clap(flatten)]
        context_args: Option<ContextArgs>,
    },
    // Arguments mirror the parameters of `retrieval_stats` (worktree, extension filter,
    // file limit, number of files to skip) — presumably forwarded there; verify at the
    // dispatch site.
    RetrievalStats {
        #[clap(flatten)]
        zeta2_args: Zeta2Args,
        #[arg(long)]
        worktree: PathBuf,
        #[arg(long)]
        extension: Option<String>,
        #[arg(long)]
        limit: Option<usize>,
        #[arg(long)]
        skip: Option<usize>,
    },
}
  81
// Arguments describing where context should be gathered; consumed by `get_context`.
//
// NOTE: plain `//` comments are used on purpose; `///` doc comments on clap-derived items
// become part of the generated `--help` output.
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
    // Directory to open as the project worktree (canonicalized before use).
    #[arg(long)]
    worktree: PathBuf,
    // Cursor position, parsed via `SourceLocation::from_str` from `path:line:column`
    // with 1-based line and column.
    #[arg(long)]
    cursor: SourceLocation,
    // When set, the buffer is opened through `open_buffer_with_language_server`
    // instead of plain `open_buffer`.
    #[arg(long)]
    use_language_server: bool,
    // Optional events text, read from a file or from stdin (`-`).
    #[arg(long)]
    events: Option<FileOrStdin>,
}
  94
// Tuning options for zeta2; converted into `zeta2::ZetaOptions` by `Zeta2Args::to_options`.
//
// NOTE: plain `//` comments are used on purpose; `///` doc comments on clap-derived items
// become part of the generated `--help` output.
#[derive(Debug, Args)]
struct Zeta2Args {
    // Forwarded to `ZetaOptions::max_prompt_bytes`.
    #[arg(long, default_value_t = 8192)]
    max_prompt_bytes: usize,
    // Forwarded to the excerpt options as `max_bytes`.
    #[arg(long, default_value_t = 2048)]
    max_excerpt_bytes: usize,
    // Forwarded to the excerpt options as `min_bytes`.
    #[arg(long, default_value_t = 1024)]
    min_excerpt_bytes: usize,
    // Forwarded unchanged to the excerpt options.
    #[arg(long, default_value_t = 0.66)]
    target_before_cursor_over_total_bytes: f32,
    // Forwarded to `ZetaOptions::max_diagnostic_bytes`.
    #[arg(long, default_value_t = 1024)]
    max_diagnostic_bytes: usize,
    // Converted into `predict_edits_v3::PromptFormat` when building the options.
    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
    prompt_format: PromptFormat,
    // Selects what `get_context` returns for zeta2: the prompt, the request, or both.
    #[arg(long, value_enum, default_value_t = Default::default())]
    output_format: OutputFormat,
    // Forwarded to `ZetaOptions::file_indexing_parallelism`.
    #[arg(long, default_value_t = 42)]
    file_indexing_parallelism: usize,
    // Inverted into `use_imports` on the context options.
    #[arg(long, default_value_t = false)]
    disable_imports_gathering: bool,
    // Forwarded to the score options.
    #[arg(long, default_value_t = 0.5)]
    prefilter_score_ratio: f32,
}
 118
// CLI-facing mirror of `predict_edits_v3::PromptFormat`; each variant maps to the
// variant of the same name there. `MarkedExcerpt` is the default.
//
// NOTE: plain `//` comments are used on purpose; `///` doc comments on `ValueEnum`
// variants become help text for the possible values.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
    #[default]
    MarkedExcerpt,
    LabeledSections,
    OnlySnippets,
}
 126
 127impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
 128    fn into(self) -> predict_edits_v3::PromptFormat {
 129        match self {
 130            Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
 131            Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
 132            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
 133        }
 134    }
 135}
 136
// What `get_context` emits for zeta2. `Prompt` is the default.
//
// NOTE: plain `//` comments are used on purpose; `///` doc comments on `ValueEnum`
// variants become help text for the possible values.
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
    // Only the prompt string.
    #[default]
    Prompt,
    // The request, pretty-printed as JSON.
    Request,
    // A JSON object with `request`, `prompt` and `section_labels` keys.
    Full,
}
 144
/// An input source given on the command line: either a file path or standard input.
///
/// Parsed via `FromStr`, where the literal `-` selects `Stdin`.
#[derive(Debug, Clone)]
enum FileOrStdin {
    /// Read from the file at this path.
    File(PathBuf),
    /// Read from the process's standard input.
    Stdin,
}
 150
 151impl FileOrStdin {
 152    async fn read_to_string(&self) -> Result<String, std::io::Error> {
 153        match self {
 154            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
 155            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
 156        }
 157    }
 158}
 159
 160impl FromStr for FileOrStdin {
 161    type Err = <PathBuf as FromStr>::Err;
 162
 163    fn from_str(s: &str) -> Result<Self, Self::Err> {
 164        match s {
 165            "-" => Ok(Self::Stdin),
 166            _ => Ok(Self::File(PathBuf::from_str(s)?)),
 167        }
 168    }
 169}
 170
/// A position inside a file, identified by a relative path and a point.
///
/// The textual form (see the `Display` and `FromStr` impls) is `path:line:column` with
/// 1-based line and column, while `point` itself is stored 0-based.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct SourceLocation {
    /// Relative path of the file.
    path: Arc<RelPath>,
    /// 0-based row/column within the file.
    point: Point,
}
 176
 177impl Serialize for SourceLocation {
 178    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 179    where
 180        S: Serializer,
 181    {
 182        serializer.serialize_str(&self.to_string())
 183    }
 184}
 185
 186impl<'de> Deserialize<'de> for SourceLocation {
 187    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 188    where
 189        D: Deserializer<'de>,
 190    {
 191        let s = String::deserialize(deserializer)?;
 192        s.parse().map_err(serde::de::Error::custom)
 193    }
 194}
 195
 196impl Display for SourceLocation {
 197    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 198        write!(
 199            f,
 200            "{}:{}:{}",
 201            self.path.display(PathStyle::Posix),
 202            self.point.row + 1,
 203            self.point.column + 1
 204        )
 205    }
 206}
 207
 208impl FromStr for SourceLocation {
 209    type Err = anyhow::Error;
 210
 211    fn from_str(s: &str) -> Result<Self> {
 212        let parts: Vec<&str> = s.split(':').collect();
 213        if parts.len() != 3 {
 214            return Err(anyhow!(
 215                "Invalid source location. Expected 'file.rs:line:column', got '{}'",
 216                s
 217            ));
 218        }
 219
 220        let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
 221        let line: u32 = parts[1]
 222            .parse()
 223            .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
 224        let column: u32 = parts[2]
 225            .parse()
 226            .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
 227
 228        // Convert from 1-based to 0-based indexing
 229        let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
 230
 231        Ok(SourceLocation { path, point })
 232    }
 233}
 234
/// Result of `get_context`, depending on which pipeline produced it.
enum GetContextOutput {
    /// Output of `zeta::gather_context`.
    Zeta1(zeta::GatherContextOutput),
    /// Already-rendered zeta2 output (prompt and/or request JSON, per `OutputFormat`).
    Zeta2(String),
}
 239
 240async fn get_context(
 241    zeta2_args: Option<Zeta2Args>,
 242    args: ContextArgs,
 243    app_state: &Arc<ZetaCliAppState>,
 244    cx: &mut AsyncApp,
 245) -> Result<GetContextOutput> {
 246    let ContextArgs {
 247        worktree: worktree_path,
 248        cursor,
 249        use_language_server,
 250        events,
 251    } = args;
 252
 253    let worktree_path = worktree_path.canonicalize()?;
 254
 255    let project = cx.update(|cx| {
 256        Project::local(
 257            app_state.client.clone(),
 258            app_state.node_runtime.clone(),
 259            app_state.user_store.clone(),
 260            app_state.languages.clone(),
 261            app_state.fs.clone(),
 262            None,
 263            cx,
 264        )
 265    })?;
 266
 267    let worktree = project
 268        .update(cx, |project, cx| {
 269            project.create_worktree(&worktree_path, true, cx)
 270        })?
 271        .await?;
 272
 273    let mut ready_languages = HashSet::default();
 274    let (_lsp_open_handle, buffer) = if use_language_server {
 275        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
 276            project.clone(),
 277            worktree.clone(),
 278            cursor.path.clone(),
 279            &mut ready_languages,
 280            cx,
 281        )
 282        .await?;
 283        (Some(lsp_open_handle), buffer)
 284    } else {
 285        let buffer =
 286            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
 287        (None, buffer)
 288    };
 289
 290    let full_path_str = worktree
 291        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
 292        .display(PathStyle::local())
 293        .to_string();
 294
 295    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
 296    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
 297    if clipped_cursor != cursor.point {
 298        let max_row = snapshot.max_point().row;
 299        if cursor.point.row < max_row {
 300            return Err(anyhow!(
 301                "Cursor position {:?} is out of bounds (line length is {})",
 302                cursor.point,
 303                snapshot.line_len(cursor.point.row)
 304            ));
 305        } else {
 306            return Err(anyhow!(
 307                "Cursor position {:?} is out of bounds (max row is {})",
 308                cursor.point,
 309                max_row
 310            ));
 311        }
 312    }
 313
 314    let events = match events {
 315        Some(events) => events.read_to_string().await?,
 316        None => String::new(),
 317    };
 318
 319    if let Some(zeta2_args) = zeta2_args {
 320        // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
 321        // the whole worktree.
 322        worktree
 323            .read_with(cx, |worktree, _cx| {
 324                worktree.as_local().unwrap().scan_complete()
 325            })?
 326            .await;
 327        let output = cx
 328            .update(|cx| {
 329                let zeta = cx.new(|cx| {
 330                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
 331                });
 332                let indexing_done_task = zeta.update(cx, |zeta, cx| {
 333                    zeta.set_options(zeta2_args.to_options(true));
 334                    zeta.register_buffer(&buffer, &project, cx);
 335                    zeta.wait_for_initial_indexing(&project, cx)
 336                });
 337                cx.spawn(async move |cx| {
 338                    indexing_done_task.await?;
 339                    let request = zeta
 340                        .update(cx, |zeta, cx| {
 341                            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
 342                            zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
 343                        })?
 344                        .await?;
 345
 346                    let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
 347                    let (prompt_string, section_labels) = planned_prompt.to_prompt_string()?;
 348
 349                    match zeta2_args.output_format {
 350                        OutputFormat::Prompt => anyhow::Ok(prompt_string),
 351                        OutputFormat::Request => {
 352                            anyhow::Ok(serde_json::to_string_pretty(&request)?)
 353                        }
 354                        OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
 355                            "request": request,
 356                            "prompt": prompt_string,
 357                            "section_labels": section_labels,
 358                        }))?),
 359                    }
 360                })
 361            })?
 362            .await?;
 363        Ok(GetContextOutput::Zeta2(output))
 364    } else {
 365        let prompt_for_events = move || (events, 0);
 366        Ok(GetContextOutput::Zeta1(
 367            cx.update(|cx| {
 368                zeta::gather_context(
 369                    full_path_str,
 370                    &snapshot,
 371                    clipped_cursor,
 372                    prompt_for_events,
 373                    cx,
 374                )
 375            })?
 376            .await?,
 377        ))
 378    }
 379}
 380
 381impl Zeta2Args {
 382    fn to_options(&self, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions {
 383        zeta2::ZetaOptions {
 384            context: EditPredictionContextOptions {
 385                use_imports: !self.disable_imports_gathering,
 386                excerpt: EditPredictionExcerptOptions {
 387                    max_bytes: self.max_excerpt_bytes,
 388                    min_bytes: self.min_excerpt_bytes,
 389                    target_before_cursor_over_total_bytes: self
 390                        .target_before_cursor_over_total_bytes,
 391                },
 392                score: EditPredictionScoreOptions {
 393                    omit_excerpt_overlaps,
 394                    prefilter_score_ratio: self.prefilter_score_ratio,
 395                },
 396            },
 397            max_diagnostic_bytes: self.max_diagnostic_bytes,
 398            max_prompt_bytes: self.max_prompt_bytes,
 399            prompt_format: self.prompt_format.clone().into(),
 400            file_indexing_parallelism: self.file_indexing_parallelism,
 401        }
 402    }
 403}
 404
 405pub async fn retrieval_stats(
 406    worktree: PathBuf,
 407    app_state: Arc<ZetaCliAppState>,
 408    only_extension: Option<String>,
 409    file_limit: Option<usize>,
 410    skip_files: Option<usize>,
 411    options: zeta2::ZetaOptions,
 412    cx: &mut AsyncApp,
 413) -> Result<String> {
 414    let options = Arc::new(options);
 415    let worktree_path = worktree.canonicalize()?;
 416
 417    let project = cx.update(|cx| {
 418        Project::local(
 419            app_state.client.clone(),
 420            app_state.node_runtime.clone(),
 421            app_state.user_store.clone(),
 422            app_state.languages.clone(),
 423            app_state.fs.clone(),
 424            None,
 425            cx,
 426        )
 427    })?;
 428
 429    let worktree = project
 430        .update(cx, |project, cx| {
 431            project.create_worktree(&worktree_path, true, cx)
 432        })?
 433        .await?;
 434
 435    // wait for worktree scan so that wait_for_initial_file_indexing waits for the whole worktree.
 436    worktree
 437        .read_with(cx, |worktree, _cx| {
 438            worktree.as_local().unwrap().scan_complete()
 439        })?
 440        .await;
 441
 442    let index = cx.new(|cx| SyntaxIndex::new(&project, options.file_indexing_parallelism, cx))?;
 443    index
 444        .read_with(cx, |index, cx| index.wait_for_initial_file_indexing(cx))?
 445        .await?;
 446    let indexed_files = index
 447        .read_with(cx, |index, cx| index.indexed_file_paths(cx))?
 448        .await;
 449    let mut filtered_files = indexed_files
 450        .into_iter()
 451        .filter(|project_path| {
 452            let file_extension = project_path.path.extension();
 453            if let Some(only_extension) = only_extension.as_ref() {
 454                file_extension.is_some_and(|extension| extension == only_extension)
 455            } else {
 456                file_extension
 457                    .is_some_and(|extension| !["md", "json", "sh", "diff"].contains(&extension))
 458            }
 459        })
 460        .collect::<Vec<_>>();
 461    filtered_files.sort_by(|a, b| a.path.cmp(&b.path));
 462
 463    let index_state = index.read_with(cx, |index, _cx| index.state().clone())?;
 464    cx.update(|_| {
 465        drop(index);
 466    })?;
 467    let index_state = Arc::new(
 468        Arc::into_inner(index_state)
 469            .context("Index state had more than 1 reference")?
 470            .into_inner(),
 471    );
 472
 473    struct FileSnapshot {
 474        project_entry_id: ProjectEntryId,
 475        snapshot: BufferSnapshot,
 476        hash: u64,
 477        parent_abs_path: Arc<Path>,
 478    }
 479
 480    let files: Vec<FileSnapshot> = futures::future::try_join_all({
 481        filtered_files
 482            .iter()
 483            .map(|file| {
 484                let buffer_task =
 485                    open_buffer(project.clone(), worktree.clone(), file.path.clone(), cx);
 486                cx.spawn(async move |cx| {
 487                    let buffer = buffer_task.await?;
 488                    let (project_entry_id, parent_abs_path, snapshot) =
 489                        buffer.read_with(cx, |buffer, cx| {
 490                            let file = project::File::from_dyn(buffer.file()).unwrap();
 491                            let project_entry_id = file.project_entry_id().unwrap();
 492                            let mut parent_abs_path = file.worktree.read(cx).absolutize(&file.path);
 493                            if !parent_abs_path.pop() {
 494                                panic!("Invalid worktree path");
 495                            }
 496
 497                            (project_entry_id, parent_abs_path, buffer.snapshot())
 498                        })?;
 499
 500                    anyhow::Ok(
 501                        cx.background_spawn(async move {
 502                            let mut hasher = collections::FxHasher::default();
 503                            snapshot.text().hash(&mut hasher);
 504                            FileSnapshot {
 505                                project_entry_id,
 506                                snapshot,
 507                                hash: hasher.finish(),
 508                                parent_abs_path: parent_abs_path.into(),
 509                            }
 510                        })
 511                        .await,
 512                    )
 513                })
 514            })
 515            .collect::<Vec<_>>()
 516    })
 517    .await?;
 518
 519    let mut file_snapshots = HashMap::default();
 520    let mut hasher = collections::FxHasher::default();
 521    for FileSnapshot {
 522        project_entry_id,
 523        snapshot,
 524        hash,
 525        ..
 526    } in &files
 527    {
 528        file_snapshots.insert(*project_entry_id, snapshot.clone());
 529        hash.hash(&mut hasher);
 530    }
 531    let files_hash = hasher.finish();
 532    let file_snapshots = Arc::new(file_snapshots);
 533
 534    let lsp_definitions_path = std::env::current_dir()?.join(format!(
 535        "target/zeta2-lsp-definitions-{:x}.json",
 536        files_hash
 537    ));
 538
 539    let lsp_definitions: Arc<_> = if std::fs::exists(&lsp_definitions_path)? {
 540        log::info!(
 541            "Using cached LSP definitions from {}",
 542            lsp_definitions_path.display()
 543        );
 544        serde_json::from_reader(File::open(&lsp_definitions_path)?)?
 545    } else {
 546        log::warn!(
 547            "No LSP definitions found populating {}",
 548            lsp_definitions_path.display()
 549        );
 550        let lsp_definitions =
 551            gather_lsp_definitions(&filtered_files, &worktree, &project, cx).await?;
 552        serde_json::to_writer_pretty(File::create(&lsp_definitions_path)?, &lsp_definitions)?;
 553        lsp_definitions
 554    }
 555    .into();
 556
 557    let files_len = files.len().min(file_limit.unwrap_or(usize::MAX));
 558    let done_count = Arc::new(AtomicUsize::new(0));
 559
 560    let (output_tx, mut output_rx) = mpsc::unbounded::<RetrievalStatsResult>();
 561    let mut output = std::fs::File::create("target/zeta-retrieval-stats.txt")?;
 562
 563    let tasks = files
 564        .into_iter()
 565        .skip(skip_files.unwrap_or(0))
 566        .take(file_limit.unwrap_or(usize::MAX))
 567        .map(|project_file| {
 568            let index_state = index_state.clone();
 569            let lsp_definitions = lsp_definitions.clone();
 570            let options = options.clone();
 571            let output_tx = output_tx.clone();
 572            let done_count = done_count.clone();
 573            let file_snapshots = file_snapshots.clone();
 574            cx.background_spawn(async move {
 575                let snapshot = project_file.snapshot;
 576
 577                let full_range = 0..snapshot.len();
 578                let references = references_in_range(
 579                    full_range,
 580                    &snapshot.text(),
 581                    ReferenceRegion::Nearby,
 582                    &snapshot,
 583                );
 584
 585                println!("references: {}", references.len(),);
 586
 587                let imports = if options.context.use_imports {
 588                    Imports::gather(&snapshot, Some(&project_file.parent_abs_path))
 589                } else {
 590                    Imports::default()
 591                };
 592
 593                let path = snapshot.file().unwrap().path();
 594
 595                for reference in references {
 596                    let query_point = snapshot.offset_to_point(reference.range.start);
 597                    let source_location = SourceLocation {
 598                        path: path.clone(),
 599                        point: query_point,
 600                    };
 601                    let lsp_definitions = lsp_definitions
 602                        .definitions
 603                        .get(&source_location)
 604                        .cloned()
 605                        .unwrap_or_else(|| {
 606                            log::warn!(
 607                                "No definitions found for source location: {:?}",
 608                                source_location
 609                            );
 610                            Vec::new()
 611                        });
 612
 613                    let retrieve_result = retrieve_definitions(
 614                        &reference,
 615                        &imports,
 616                        query_point,
 617                        &snapshot,
 618                        &index_state,
 619                        &file_snapshots,
 620                        &options,
 621                    )
 622                    .await?;
 623
 624                    // TODO: LSP returns things like locals, this filters out some of those, but potentially
 625                    // hides some retrieval issues.
 626                    if retrieve_result.definitions.is_empty() {
 627                        continue;
 628                    }
 629
 630                    let mut best_match = None;
 631                    let mut has_external_definition = false;
 632                    let mut in_excerpt = false;
 633                    for (index, retrieved_definition) in
 634                        retrieve_result.definitions.iter().enumerate()
 635                    {
 636                        for lsp_definition in &lsp_definitions {
 637                            let SourceRange {
 638                                path,
 639                                point_range,
 640                                offset_range,
 641                            } = lsp_definition;
 642                            let lsp_point_range =
 643                                SerializablePoint::into_language_point_range(point_range.clone());
 644                            has_external_definition = has_external_definition
 645                                || path.is_absolute()
 646                                || path
 647                                    .components()
 648                                    .any(|component| component.as_os_str() == "node_modules");
 649                            let is_match = path.as_path()
 650                                == retrieved_definition.path.as_std_path()
 651                                && retrieved_definition
 652                                    .range
 653                                    .contains_inclusive(&lsp_point_range);
 654                            if is_match {
 655                                if best_match.is_none() {
 656                                    best_match = Some(index);
 657                                }
 658                            }
 659                            in_excerpt = in_excerpt
 660                                || retrieve_result.excerpt_range.as_ref().is_some_and(
 661                                    |excerpt_range| excerpt_range.contains_inclusive(&offset_range),
 662                                );
 663                        }
 664                    }
 665
 666                    let outcome = if let Some(best_match) = best_match {
 667                        RetrievalOutcome::Match { best_match }
 668                    } else if has_external_definition {
 669                        RetrievalOutcome::NoMatchDueToExternalLspDefinitions
 670                    } else if in_excerpt {
 671                        RetrievalOutcome::ProbablyLocal
 672                    } else {
 673                        RetrievalOutcome::NoMatch
 674                    };
 675
 676                    let result = RetrievalStatsResult {
 677                        outcome,
 678                        path: path.clone(),
 679                        identifier: reference.identifier,
 680                        point: query_point,
 681                        lsp_definitions,
 682                        retrieved_definitions: retrieve_result.definitions,
 683                    };
 684
 685                    output_tx.unbounded_send(result).ok();
 686                }
 687
 688                println!(
 689                    "{:02}/{:02} done",
 690                    done_count.fetch_add(1, atomic::Ordering::Relaxed) + 1,
 691                    files_len,
 692                );
 693
 694                anyhow::Ok(())
 695            })
 696        })
 697        .collect::<Vec<_>>();
 698
 699    drop(output_tx);
 700
 701    let results_task = cx.background_spawn(async move {
 702        let mut results = Vec::new();
 703        while let Some(result) = output_rx.next().await {
 704            output
 705                .write_all(format!("{:#?}\n", result).as_bytes())
 706                .log_err();
 707            results.push(result)
 708        }
 709        results
 710    });
 711
 712    futures::future::try_join_all(tasks).await?;
 713    println!("Tasks completed");
 714    let results = results_task.await;
 715    println!("Results received");
 716
 717    let mut references_count = 0;
 718
 719    let mut included_count = 0;
 720    let mut both_absent_count = 0;
 721
 722    let mut retrieved_count = 0;
 723    let mut top_match_count = 0;
 724    let mut non_top_match_count = 0;
 725    let mut ranking_involved_top_match_count = 0;
 726
 727    let mut no_match_count = 0;
 728    let mut no_match_none_retrieved = 0;
 729    let mut no_match_wrong_retrieval = 0;
 730
 731    let mut expected_no_match_count = 0;
 732    let mut in_excerpt_count = 0;
 733    let mut external_definition_count = 0;
 734
 735    for result in results {
 736        references_count += 1;
 737        match &result.outcome {
 738            RetrievalOutcome::Match { best_match } => {
 739                included_count += 1;
 740                retrieved_count += 1;
 741                let multiple = result.retrieved_definitions.len() > 1;
 742                if *best_match == 0 {
 743                    top_match_count += 1;
 744                    if multiple {
 745                        ranking_involved_top_match_count += 1;
 746                    }
 747                } else {
 748                    non_top_match_count += 1;
 749                }
 750            }
 751            RetrievalOutcome::NoMatch => {
 752                if result.lsp_definitions.is_empty() {
 753                    included_count += 1;
 754                    both_absent_count += 1;
 755                } else {
 756                    no_match_count += 1;
 757                    if result.retrieved_definitions.is_empty() {
 758                        no_match_none_retrieved += 1;
 759                    } else {
 760                        no_match_wrong_retrieval += 1;
 761                    }
 762                }
 763            }
 764            RetrievalOutcome::NoMatchDueToExternalLspDefinitions => {
 765                expected_no_match_count += 1;
 766                external_definition_count += 1;
 767            }
 768            RetrievalOutcome::ProbablyLocal => {
 769                included_count += 1;
 770                in_excerpt_count += 1;
 771            }
 772        }
 773    }
 774
 775    fn count_and_percentage(part: usize, total: usize) -> String {
 776        format!("{} ({:.2}%)", part, (part as f64 / total as f64) * 100.0)
 777    }
 778
 779    println!("");
 780    println!("╮ references: {}", references_count);
 781    println!(
 782        "├─╮ included: {}",
 783        count_and_percentage(included_count, references_count),
 784    );
 785    println!(
 786        "│ ├─╮ retrieved: {}",
 787        count_and_percentage(retrieved_count, references_count)
 788    );
 789    println!(
 790        "│ │ ├─╮ top match : {}",
 791        count_and_percentage(top_match_count, retrieved_count)
 792    );
 793    println!(
 794        "│ │ │ ╰─╴ involving ranking: {}",
 795        count_and_percentage(ranking_involved_top_match_count, top_match_count)
 796    );
 797    println!(
 798        "│ │ ╰─╴ non-top match: {}",
 799        count_and_percentage(non_top_match_count, retrieved_count)
 800    );
 801    println!(
 802        "│ ├─╴ both absent: {}",
 803        count_and_percentage(both_absent_count, included_count)
 804    );
 805    println!(
 806        "│ ╰─╴ in excerpt: {}",
 807        count_and_percentage(in_excerpt_count, included_count)
 808    );
 809    println!(
 810        "├─╮ no match: {}",
 811        count_and_percentage(no_match_count, references_count)
 812    );
 813    println!(
 814        "│ ├─╴ none retrieved: {}",
 815        count_and_percentage(no_match_none_retrieved, no_match_count)
 816    );
 817    println!(
 818        "│ ╰─╴ wrong retrieval: {}",
 819        count_and_percentage(no_match_wrong_retrieval, no_match_count)
 820    );
 821    println!(
 822        "╰─╮ expected no match: {}",
 823        count_and_percentage(expected_no_match_count, references_count)
 824    );
 825    println!(
 826        "  ╰─╴ external definition: {}",
 827        count_and_percentage(external_definition_count, expected_no_match_count)
 828    );
 829
 830    println!("");
 831    println!("LSP definition cache at {}", lsp_definitions_path.display());
 832
 833    Ok("".to_string())
 834}
 835
/// Output of `retrieve_definitions` for a single reference.
struct RetrieveResult {
    /// Definitions retrieved for the reference; empty when no context could be gathered.
    definitions: Vec<RetrievedDefinition>,
    /// Offset range of the excerpt, when one is available. `retrieval_stats` uses it to
    /// check whether an LSP definition falls inside the excerpt.
    excerpt_range: Option<Range<usize>>,
}
 840
 841async fn retrieve_definitions(
 842    reference: &Reference,
 843    imports: &Imports,
 844    query_point: Point,
 845    snapshot: &BufferSnapshot,
 846    index: &Arc<SyntaxIndexState>,
 847    file_snapshots: &Arc<HashMap<ProjectEntryId, BufferSnapshot>>,
 848    options: &Arc<zeta2::ZetaOptions>,
 849) -> Result<RetrieveResult> {
 850    let mut single_reference_map = HashMap::default();
 851    single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]);
 852    let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn(
 853        query_point,
 854        snapshot,
 855        imports,
 856        &options.context,
 857        Some(&index),
 858        |_, _, _| single_reference_map,
 859    );
 860
 861    let Some(edit_prediction_context) = edit_prediction_context else {
 862        return Ok(RetrieveResult {
 863            definitions: Vec::new(),
 864            excerpt_range: None,
 865        });
 866    };
 867
 868    let mut retrieved_definitions = Vec::new();
 869    for scored_declaration in edit_prediction_context.declarations {
 870        match &scored_declaration.declaration {
 871            Declaration::File {
 872                project_entry_id,
 873                declaration,
 874                ..
 875            } => {
 876                let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
 877                    log::error!("bug: file project entry not found");
 878                    continue;
 879                };
 880                let path = snapshot.file().unwrap().path().clone();
 881                retrieved_definitions.push(RetrievedDefinition {
 882                    path,
 883                    range: snapshot.offset_to_point(declaration.item_range.start)
 884                        ..snapshot.offset_to_point(declaration.item_range.end),
 885                    score: scored_declaration.score(DeclarationStyle::Declaration),
 886                    retrieval_score: scored_declaration.retrieval_score(),
 887                    components: scored_declaration.components,
 888                });
 889            }
 890            Declaration::Buffer {
 891                project_entry_id,
 892                rope,
 893                declaration,
 894                ..
 895            } => {
 896                let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
 897                    // This case happens when dependency buffers have been opened by
 898                    // go-to-definition, resulting in single-file worktrees.
 899                    continue;
 900                };
 901                let path = snapshot.file().unwrap().path().clone();
 902                retrieved_definitions.push(RetrievedDefinition {
 903                    path,
 904                    range: rope.offset_to_point(declaration.item_range.start)
 905                        ..rope.offset_to_point(declaration.item_range.end),
 906                    score: scored_declaration.score(DeclarationStyle::Declaration),
 907                    retrieval_score: scored_declaration.retrieval_score(),
 908                    components: scored_declaration.components,
 909                });
 910            }
 911        }
 912    }
 913    retrieved_definitions.sort_by_key(|definition| Reverse(OrderedFloat(definition.score)));
 914
 915    Ok(RetrieveResult {
 916        definitions: retrieved_definitions,
 917        excerpt_range: Some(edit_prediction_context.excerpt.range),
 918    })
 919}
 920
 921async fn gather_lsp_definitions(
 922    files: &[ProjectPath],
 923    worktree: &Entity<Worktree>,
 924    project: &Entity<Project>,
 925    cx: &mut AsyncApp,
 926) -> Result<LspResults> {
 927    let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?;
 928
 929    let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
 930    cx.subscribe(&lsp_store, {
 931        move |_, event, _| {
 932            if let project::LspStoreEvent::LanguageServerUpdate {
 933                message:
 934                    client::proto::update_language_server::Variant::WorkProgress(
 935                        client::proto::LspWorkProgress {
 936                            message: Some(message),
 937                            ..
 938                        },
 939                    ),
 940                ..
 941            } = event
 942            {
 943                println!("{message}")
 944            }
 945        }
 946    })?
 947    .detach();
 948
 949    let mut definitions = HashMap::default();
 950    let mut error_count = 0;
 951    let mut lsp_open_handles = Vec::new();
 952    let mut ready_languages = HashSet::default();
 953    for (file_index, project_path) in files.iter().enumerate() {
 954        println!(
 955            "Processing file {} of {}: {}",
 956            file_index + 1,
 957            files.len(),
 958            project_path.path.display(PathStyle::Posix)
 959        );
 960
 961        let Some((lsp_open_handle, language_server_id, buffer)) = open_buffer_with_language_server(
 962            project.clone(),
 963            worktree.clone(),
 964            project_path.path.clone(),
 965            &mut ready_languages,
 966            cx,
 967        )
 968        .await
 969        .log_err() else {
 970            continue;
 971        };
 972        lsp_open_handles.push(lsp_open_handle);
 973
 974        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
 975        let full_range = 0..snapshot.len();
 976        let references = references_in_range(
 977            full_range,
 978            &snapshot.text(),
 979            ReferenceRegion::Nearby,
 980            &snapshot,
 981        );
 982
 983        loop {
 984            let is_ready = lsp_store
 985                .read_with(cx, |lsp_store, _cx| {
 986                    lsp_store
 987                        .language_server_statuses
 988                        .get(&language_server_id)
 989                        .is_some_and(|status| status.pending_work.is_empty())
 990                })
 991                .unwrap();
 992            if is_ready {
 993                break;
 994            }
 995            cx.background_executor()
 996                .timer(Duration::from_millis(10))
 997                .await;
 998        }
 999
1000        for reference in references {
1001            // TODO: Rename declaration to definition in edit_prediction_context?
1002            let lsp_result = project
1003                .update(cx, |project, cx| {
1004                    project.definitions(&buffer, reference.range.start, cx)
1005                })?
1006                .await;
1007
1008            match lsp_result {
1009                Ok(lsp_definitions) => {
1010                    let mut targets = Vec::new();
1011                    for target in lsp_definitions.unwrap_or_default() {
1012                        let buffer = target.target.buffer;
1013                        let anchor_range = target.target.range;
1014                        buffer.read_with(cx, |buffer, cx| {
1015                            let Some(file) = project::File::from_dyn(buffer.file()) else {
1016                                return;
1017                            };
1018                            let file_worktree = file.worktree.read(cx);
1019                            let file_worktree_id = file_worktree.id();
1020                            // Relative paths for worktree files, absolute for all others
1021                            let path = if worktree_id != file_worktree_id {
1022                                file.worktree.read(cx).absolutize(&file.path)
1023                            } else {
1024                                file.path.as_std_path().to_path_buf()
1025                            };
1026                            let offset_range = anchor_range.to_offset(&buffer);
1027                            let point_range = SerializablePoint::from_language_point_range(
1028                                offset_range.to_point(&buffer),
1029                            );
1030                            targets.push(SourceRange {
1031                                path,
1032                                offset_range,
1033                                point_range,
1034                            });
1035                        })?;
1036                    }
1037
1038                    definitions.insert(
1039                        SourceLocation {
1040                            path: project_path.path.clone(),
1041                            point: snapshot.offset_to_point(reference.range.start),
1042                        },
1043                        targets,
1044                    );
1045                }
1046                Err(err) => {
1047                    log::error!("Language server error: {err}");
1048                    error_count += 1;
1049                }
1050            }
1051        }
1052    }
1053
1054    log::error!("Encountered {} language server errors", error_count);
1055
1056    Ok(LspResults { definitions })
1057}
1058
1059#[derive(Debug, Clone, Serialize, Deserialize)]
1060#[serde(transparent)]
1061struct LspResults {
1062    definitions: HashMap<SourceLocation, Vec<SourceRange>>,
1063}
1064
1065#[derive(Debug, Clone, Serialize, Deserialize)]
1066struct SourceRange {
1067    path: PathBuf,
1068    point_range: Range<SerializablePoint>,
1069    offset_range: Range<usize>,
1070}
1071
1072/// Serializes to 1-based row and column indices.
1073#[derive(Debug, Clone, Serialize, Deserialize)]
1074pub struct SerializablePoint {
1075    pub row: u32,
1076    pub column: u32,
1077}
1078
1079impl SerializablePoint {
1080    pub fn into_language_point_range(range: Range<Self>) -> Range<Point> {
1081        range.start.into()..range.end.into()
1082    }
1083
1084    pub fn from_language_point_range(range: Range<Point>) -> Range<Self> {
1085        range.start.into()..range.end.into()
1086    }
1087}
1088
1089impl From<Point> for SerializablePoint {
1090    fn from(point: Point) -> Self {
1091        SerializablePoint {
1092            row: point.row + 1,
1093            column: point.column + 1,
1094        }
1095    }
1096}
1097
1098impl From<SerializablePoint> for Point {
1099    fn from(serializable: SerializablePoint) -> Self {
1100        Point {
1101            row: serializable.row.saturating_sub(1),
1102            column: serializable.column.saturating_sub(1),
1103        }
1104    }
1105}
1106
1107#[derive(Debug)]
1108struct RetrievalStatsResult {
1109    outcome: RetrievalOutcome,
1110    #[allow(dead_code)]
1111    path: Arc<RelPath>,
1112    #[allow(dead_code)]
1113    identifier: Identifier,
1114    #[allow(dead_code)]
1115    point: Point,
1116    #[allow(dead_code)]
1117    lsp_definitions: Vec<SourceRange>,
1118    retrieved_definitions: Vec<RetrievedDefinition>,
1119}
1120
/// Classification of a retrieval attempt for one reference.
///
/// NOTE(review): the code that assigns these variants lies outside this section; the variant
/// notes below are inferred from the names and should be confirmed there.
#[derive(Debug)]
enum RetrievalOutcome {
    /// At least one retrieved definition matches an LSP definition.
    Match {
        /// Lowest index within retrieved_definitions that matches an LSP definition.
        best_match: usize,
    },
    /// Presumably: the definition is local to the reference, so no retrieval match is expected.
    ProbablyLocal,
    /// Presumably: none of the retrieved definitions match an LSP definition.
    NoMatch,
    /// Presumably: no match was possible because the LSP definitions are outside the worktree.
    NoMatchDueToExternalLspDefinitions,
}
1131
1132#[derive(Debug)]
1133struct RetrievedDefinition {
1134    path: Arc<RelPath>,
1135    range: Range<Point>,
1136    score: f32,
1137    #[allow(dead_code)]
1138    retrieval_score: f32,
1139    #[allow(dead_code)]
1140    components: DeclarationScoreComponents,
1141}
1142
1143pub fn open_buffer(
1144    project: Entity<Project>,
1145    worktree: Entity<Worktree>,
1146    path: Arc<RelPath>,
1147    cx: &AsyncApp,
1148) -> Task<Result<Entity<Buffer>>> {
1149    cx.spawn(async move |cx| {
1150        let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
1151            worktree_id: worktree.id(),
1152            path,
1153        })?;
1154
1155        let buffer = project
1156            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
1157            .await?;
1158
1159        let mut parse_status = buffer.read_with(cx, |buffer, _cx| buffer.parse_status())?;
1160        while *parse_status.borrow() != ParseStatus::Idle {
1161            parse_status.changed().await?;
1162        }
1163
1164        Ok(buffer)
1165    })
1166}
1167
1168pub async fn open_buffer_with_language_server(
1169    project: Entity<Project>,
1170    worktree: Entity<Worktree>,
1171    path: Arc<RelPath>,
1172    ready_languages: &mut HashSet<LanguageId>,
1173    cx: &mut AsyncApp,
1174) -> Result<(Entity<Entity<Buffer>>, LanguageServerId, Entity<Buffer>)> {
1175    let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;
1176
1177    let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
1178        (
1179            project.register_buffer_with_language_servers(&buffer, cx),
1180            project.path_style(cx),
1181        )
1182    })?;
1183
1184    let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
1185        buffer.language().map(|language| language.id())
1186    })?
1187    else {
1188        return Err(anyhow!("No language for {}", path.display(path_style)));
1189    };
1190
1191    let log_prefix = path.display(path_style);
1192    if !ready_languages.contains(&language_id) {
1193        wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
1194        ready_languages.insert(language_id);
1195    }
1196
1197    let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
1198
1199    // hacky wait for buffer to be registered with the language server
1200    for _ in 0..100 {
1201        let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
1202            buffer.update(cx, |buffer, cx| {
1203                lsp_store
1204                    .language_servers_for_local_buffer(&buffer, cx)
1205                    .next()
1206                    .map(|(_, language_server)| language_server.server_id())
1207            })
1208        })?
1209        else {
1210            cx.background_executor()
1211                .timer(Duration::from_millis(10))
1212                .await;
1213            continue;
1214        };
1215
1216        return Ok((lsp_open_handle, language_server_id, buffer));
1217    }
1218
1219    return Err(anyhow!("No language server found for buffer"));
1220}
1221
1222// TODO: Dedupe with similar function in crates/eval/src/instance.rs
1223pub fn wait_for_lang_server(
1224    project: &Entity<Project>,
1225    buffer: &Entity<Buffer>,
1226    log_prefix: String,
1227    cx: &mut AsyncApp,
1228) -> Task<Result<()>> {
1229    println!("{}⏵ Waiting for language server", log_prefix);
1230
1231    let (mut tx, mut rx) = mpsc::channel(1);
1232
1233    let lsp_store = project
1234        .read_with(cx, |project, _| project.lsp_store())
1235        .unwrap();
1236
1237    let has_lang_server = buffer
1238        .update(cx, |buffer, cx| {
1239            lsp_store.update(cx, |lsp_store, cx| {
1240                lsp_store
1241                    .language_servers_for_local_buffer(buffer, cx)
1242                    .next()
1243                    .is_some()
1244            })
1245        })
1246        .unwrap_or(false);
1247
1248    if has_lang_server {
1249        project
1250            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
1251            .unwrap()
1252            .detach();
1253    }
1254    let (mut added_tx, mut added_rx) = mpsc::channel(1);
1255
1256    let subscriptions = [
1257        cx.subscribe(&lsp_store, {
1258            let log_prefix = log_prefix.clone();
1259            move |_, event, _| {
1260                if let project::LspStoreEvent::LanguageServerUpdate {
1261                    message:
1262                        client::proto::update_language_server::Variant::WorkProgress(
1263                            client::proto::LspWorkProgress {
1264                                message: Some(message),
1265                                ..
1266                            },
1267                        ),
1268                    ..
1269                } = event
1270                {
1271                    println!("{}{message}", log_prefix)
1272                }
1273            }
1274        }),
1275        cx.subscribe(project, {
1276            let buffer = buffer.clone();
1277            move |project, event, cx| match event {
1278                project::Event::LanguageServerAdded(_, _, _) => {
1279                    let buffer = buffer.clone();
1280                    project
1281                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
1282                        .detach();
1283                    added_tx.try_send(()).ok();
1284                }
1285                project::Event::DiskBasedDiagnosticsFinished { .. } => {
1286                    tx.try_send(()).ok();
1287                }
1288                _ => {}
1289            }
1290        }),
1291    ];
1292
1293    cx.spawn(async move |cx| {
1294        if !has_lang_server {
1295            // some buffers never have a language server, so this aborts quickly in that case.
1296            let timeout = cx.background_executor().timer(Duration::from_secs(5));
1297            futures::select! {
1298                _ = added_rx.next() => {},
1299                _ = timeout.fuse() => {
1300                    anyhow::bail!("Waiting for language server add timed out after 5 seconds");
1301                }
1302            };
1303        }
1304        let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
1305        let result = futures::select! {
1306            _ = rx.next() => {
1307                println!("{}⚑ Language server idle", log_prefix);
1308                anyhow::Ok(())
1309            },
1310            _ = timeout.fuse() => {
1311                anyhow::bail!("LSP wait timed out after 5 minutes");
1312            }
1313        };
1314        drop(subscriptions);
1315        result
1316    })
1317}
1318
1319fn main() {
1320    zlog::init();
1321    zlog::init_output_stderr();
1322    let args = ZetaCliArgs::parse();
1323    let http_client = Arc::new(ReqwestClient::new());
1324    let app = Application::headless().with_http_client(http_client);
1325
1326    app.run(move |cx| {
1327        let app_state = Arc::new(headless::init(cx));
1328        cx.spawn(async move |cx| {
1329            let result = match args.command {
1330                Commands::Zeta2Context {
1331                    zeta2_args,
1332                    context_args,
1333                } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
1334                    Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
1335                    Ok(GetContextOutput::Zeta2(output)) => Ok(output),
1336                    Err(err) => Err(err),
1337                },
1338                Commands::Context(context_args) => {
1339                    match get_context(None, context_args, &app_state, cx).await {
1340                        Ok(GetContextOutput::Zeta1(output)) => {
1341                            Ok(serde_json::to_string_pretty(&output.body).unwrap())
1342                        }
1343                        Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
1344                        Err(err) => Err(err),
1345                    }
1346                }
1347                Commands::Predict {
1348                    predict_edits_body,
1349                    context_args,
1350                } => {
1351                    cx.spawn(async move |cx| {
1352                        let app_version = cx.update(|cx| AppVersion::global(cx))?;
1353                        app_state.client.sign_in(true, cx).await?;
1354                        let llm_token = LlmApiToken::default();
1355                        llm_token.refresh(&app_state.client).await?;
1356
1357                        let predict_edits_body =
1358                            if let Some(predict_edits_body) = predict_edits_body {
1359                                serde_json::from_str(&predict_edits_body.read_to_string().await?)?
1360                            } else if let Some(context_args) = context_args {
1361                                match get_context(None, context_args, &app_state, cx).await? {
1362                                    GetContextOutput::Zeta1(output) => output.body,
1363                                    GetContextOutput::Zeta2 { .. } => unreachable!(),
1364                                }
1365                            } else {
1366                                return Err(anyhow!(
1367                                    "Expected either --predict-edits-body-file \
1368                                    or the required args of the `context` command."
1369                                ));
1370                            };
1371
1372                        let (response, _usage) =
1373                            Zeta::perform_predict_edits(PerformPredictEditsParams {
1374                                client: app_state.client.clone(),
1375                                llm_token,
1376                                app_version,
1377                                body: predict_edits_body,
1378                            })
1379                            .await?;
1380
1381                        Ok(response.output_excerpt)
1382                    })
1383                    .await
1384                }
1385                Commands::RetrievalStats {
1386                    zeta2_args,
1387                    worktree,
1388                    extension,
1389                    limit,
1390                    skip,
1391                } => {
1392                    retrieval_stats(
1393                        worktree,
1394                        app_state,
1395                        extension,
1396                        limit,
1397                        skip,
1398                        (&zeta2_args).to_options(false),
1399                        cx,
1400                    )
1401                    .await
1402                }
1403            };
1404            match result {
1405                Ok(output) => {
1406                    println!("{}", output);
1407                    let _ = cx.update(|cx| cx.quit());
1408                }
1409                Err(e) => {
1410                    eprintln!("Failed: {:?}", e);
1411                    exit(1);
1412                }
1413            }
1414        })
1415        .detach();
1416    });
1417}