Add an `eval` binary that evaluates our semantic index against CodeSearchNet (#17375)

Max Brunsfeld , Jason , Jason Mancuso , Nathan , and Richard created

This PR is the beginning of an evaluation framework for our AI features.
Right now, we're evaluating our semantic search feature against the
[CodeSearchNet](https://github.com/github/CodeSearchNet) code search
dataset. This dataset is very limited (for the most part, only 1 known
good search result per repo) but it has surfaced some problems with our
search already.

Release Notes:

- N/A

---------

Co-authored-by: Jason <jason@zed.dev>
Co-authored-by: Jason Mancuso <7891333+jvmncs@users.noreply.github.com>
Co-authored-by: Nathan <nathan@zed.dev>
Co-authored-by: Richard <richard@zed.dev>

Change summary

.github/workflows/ci.yml                             |  17 
Cargo.lock                                           |  27 
Cargo.toml                                           |   1 
crates/assistant/src/assistant_panel.rs              |   2 
crates/assistant/src/slash_command/file_command.rs   |   9 
crates/assistant/src/slash_command/search_command.rs |  54 
crates/evals/Cargo.toml                              |  37 
crates/evals/LICENSE-GPL                             |   1 
crates/evals/build.rs                                |  14 
crates/evals/src/eval.rs                             | 631 ++++++++++++++
crates/http_client/src/http_client.rs                |   3 
crates/semantic_index/src/embedding_index.rs         |  89 
crates/semantic_index/src/project_index.rs           |  17 
crates/semantic_index/src/semantic_index.rs          | 125 +
14 files changed, 882 insertions(+), 145 deletions(-)

Detailed changes

.github/workflows/ci.yml 🔗

@@ -101,7 +101,7 @@ jobs:
     timeout-minutes: 60
     name: (Linux) Run Clippy and tests
     runs-on:
-      - hosted-linux-x86-1
+      - buildjet-16vcpu-ubuntu-2204
     steps:
       - name: Add Rust to the PATH
         run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
@@ -111,6 +111,11 @@ jobs:
         with:
           clean: false
 
+      - name: Cache dependencies
+        uses: swatinem/rust-cache@23bce251a8cd2ffc3c1075eaa2367cf899916d84 # v2
+        with:
+          save-if: ${{ github.ref == 'refs/heads/main' }}
+
       - name: Install Linux dependencies
         run: ./script/linux
 
@@ -264,7 +269,7 @@ jobs:
     timeout-minutes: 60
     name: Create a Linux bundle
     runs-on:
-      - hosted-linux-x86-1
+      - buildjet-16vcpu-ubuntu-2204
     if: ${{ startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
     needs: [linux_tests]
     env:
@@ -279,9 +284,6 @@ jobs:
       - name: Install Linux dependencies
         run: ./script/linux
 
-      - name: Limit target directory size
-        run: script/clear-target-dir-if-larger-than 100
-
       - name: Determine version and release channel
         if: ${{ startsWith(github.ref, 'refs/tags/v') }}
         run: |
@@ -335,7 +337,7 @@ jobs:
     timeout-minutes: 60
     name: Create arm64 Linux bundle
     runs-on:
-      - hosted-linux-arm-1
+      - buildjet-16vcpu-ubuntu-2204-arm
     if: ${{ startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
     needs: [linux_tests]
     env:
@@ -350,9 +352,6 @@ jobs:
       - name: Install Linux dependencies
         run: ./script/linux
 
-      - name: Limit target directory size
-        run: script/clear-target-dir-if-larger-than 100
-
       - name: Determine version and release channel
         if: ${{ startsWith(github.ref, 'refs/tags/v') }}
         run: |

Cargo.lock 🔗

@@ -4000,6 +4000,33 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "evals"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "clap",
+ "client",
+ "clock",
+ "collections",
+ "env_logger",
+ "feature_flags",
+ "fs",
+ "git",
+ "gpui",
+ "http_client",
+ "language",
+ "languages",
+ "node_runtime",
+ "open_ai",
+ "project",
+ "semantic_index",
+ "serde",
+ "serde_json",
+ "settings",
+ "smol",
+]
+
 [[package]]
 name = "event-listener"
 version = "2.5.3"

Cargo.toml 🔗

@@ -27,6 +27,7 @@ members = [
     "crates/diagnostics",
     "crates/docs_preprocessor",
     "crates/editor",
+    "crates/evals",
     "crates/extension",
     "crates/extension_api",
     "crates/extension_cli",

crates/assistant/src/assistant_panel.rs 🔗

@@ -3282,7 +3282,7 @@ impl ContextEditor {
 
                     let fence = codeblock_fence_for_path(
                         filename.as_deref(),
-                        Some(selection.start.row..selection.end.row),
+                        Some(selection.start.row..=selection.end.row),
                     );
 
                     if let Some((line_comment_prefix, outline_text)) =

crates/assistant/src/slash_command/file_command.rs 🔗

@@ -8,7 +8,7 @@ use project::{PathMatchCandidateSet, Project};
 use serde::{Deserialize, Serialize};
 use std::{
     fmt::Write,
-    ops::Range,
+    ops::{Range, RangeInclusive},
     path::{Path, PathBuf},
     sync::{atomic::AtomicBool, Arc},
 };
@@ -342,7 +342,10 @@ fn collect_files(
     })
 }
 
-pub fn codeblock_fence_for_path(path: Option<&Path>, row_range: Option<Range<u32>>) -> String {
+pub fn codeblock_fence_for_path(
+    path: Option<&Path>,
+    row_range: Option<RangeInclusive<u32>>,
+) -> String {
     let mut text = String::new();
     write!(text, "```").unwrap();
 
@@ -357,7 +360,7 @@ pub fn codeblock_fence_for_path(path: Option<&Path>, row_range: Option<Range<u32
     }
 
     if let Some(row_range) = row_range {
-        write!(text, ":{}-{}", row_range.start + 1, row_range.end + 1).unwrap();
+        write!(text, ":{}-{}", row_range.start() + 1, row_range.end() + 1).unwrap();
     }
 
     text.push('\n');

crates/assistant/src/slash_command/search_command.rs 🔗

@@ -8,14 +8,12 @@ use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
 use feature_flags::FeatureFlag;
 use gpui::{AppContext, Task, WeakView};
 use language::{CodeLabel, LineEnding, LspAdapterDelegate};
-use semantic_index::SemanticDb;
+use semantic_index::{LoadedSearchResult, SemanticDb};
 use std::{
     fmt::Write,
-    path::PathBuf,
     sync::{atomic::AtomicBool, Arc},
 };
 use ui::{prelude::*, IconName};
-use util::ResultExt;
 use workspace::Workspace;
 
 pub(crate) struct SearchSlashCommandFeatureFlag;
@@ -107,52 +105,28 @@ impl SlashCommand for SearchSlashCommand {
                 })?
                 .await?;
 
-            let mut loaded_results = Vec::new();
-            for result in results {
-                let (full_path, file_content) =
-                    result.worktree.read_with(&cx, |worktree, _cx| {
-                        let entry_abs_path = worktree.abs_path().join(&result.path);
-                        let mut entry_full_path = PathBuf::from(worktree.root_name());
-                        entry_full_path.push(&result.path);
-                        let file_content = async {
-                            let entry_abs_path = entry_abs_path;
-                            fs.load(&entry_abs_path).await
-                        };
-                        (entry_full_path, file_content)
-                    })?;
-                if let Some(file_content) = file_content.await.log_err() {
-                    loaded_results.push((result, full_path, file_content));
-                }
-            }
+            let loaded_results = SemanticDb::load_results(results, &fs, &cx).await?;
 
             let output = cx
                 .background_executor()
                 .spawn(async move {
                     let mut text = format!("Search results for {query}:\n");
                     let mut sections = Vec::new();
-                    for (result, full_path, file_content) in loaded_results {
-                        let range_start = result.range.start.min(file_content.len());
-                        let range_end = result.range.end.min(file_content.len());
-
-                        let start_row = file_content[0..range_start].matches('\n').count() as u32;
-                        let end_row = file_content[0..range_end].matches('\n').count() as u32;
-                        let start_line_byte_offset = file_content[0..range_start]
-                            .rfind('\n')
-                            .map(|pos| pos + 1)
-                            .unwrap_or_default();
-                        let end_line_byte_offset = file_content[range_end..]
-                            .find('\n')
-                            .map(|pos| range_end + pos)
-                            .unwrap_or_else(|| file_content.len());
-
+                    for LoadedSearchResult {
+                        path,
+                        range,
+                        full_path,
+                        file_content,
+                        row_range,
+                    } in loaded_results
+                    {
                         let section_start_ix = text.len();
                         text.push_str(&codeblock_fence_for_path(
-                            Some(&result.path),
-                            Some(start_row..end_row),
+                            Some(&path),
+                            Some(row_range.clone()),
                         ));
 
-                        let mut excerpt =
-                            file_content[start_line_byte_offset..end_line_byte_offset].to_string();
+                        let mut excerpt = file_content[range].to_string();
                         LineEnding::normalize(&mut excerpt);
                         text.push_str(&excerpt);
                         writeln!(text, "\n```\n").unwrap();
@@ -161,7 +135,7 @@ impl SlashCommand for SearchSlashCommand {
                             section_start_ix..section_end_ix,
                             Some(&full_path),
                             false,
-                            Some(start_row + 1..end_row + 1),
+                            Some(row_range.start() + 1..row_range.end() + 1),
                         ));
                     }
 

crates/evals/Cargo.toml 🔗

@@ -0,0 +1,37 @@
+[package]
+name = "evals"
+description = "Evaluations for Zed's AI features"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[[bin]]
+name = "eval"
+path = "src/eval.rs"
+
+[dependencies]
+clap.workspace = true
+anyhow.workspace = true
+client.workspace = true
+clock.workspace = true
+collections.workspace = true
+env_logger.workspace = true
+feature_flags.workspace = true
+fs.workspace = true
+git.workspace = true
+gpui.workspace = true
+language.workspace = true
+languages.workspace = true
+http_client.workspace = true
+open_ai.workspace = true
+project.workspace = true
+settings.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+smol.workspace = true
+semantic_index.workspace = true
+node_runtime.workspace = true

crates/evals/build.rs 🔗

@@ -0,0 +1,14 @@
+fn main() {
+    if cfg!(target_os = "macos") {
+        println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7");
+
+        println!("cargo:rerun-if-env-changed=ZED_BUNDLE");
+        if std::env::var("ZED_BUNDLE").ok().as_deref() == Some("true") {
+            // Find WebRTC.framework in the Frameworks folder when running as part of an application bundle.
+            println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path/../Frameworks");
+        } else {
+            // Find WebRTC.framework as a sibling of the executable when running outside of an application bundle.
+            println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path");
+        }
+    }
+}

crates/evals/src/eval.rs 🔗

@@ -0,0 +1,631 @@
+use ::fs::{Fs, RealFs};
+use anyhow::Result;
+use clap::Parser;
+use client::{Client, UserStore};
+use clock::RealSystemClock;
+use collections::BTreeMap;
+use feature_flags::FeatureFlagAppExt as _;
+use git::GitHostingProviderRegistry;
+use gpui::{AsyncAppContext, BackgroundExecutor, Context, Model};
+use http_client::{HttpClient, Method};
+use language::LanguageRegistry;
+use node_runtime::FakeNodeRuntime;
+use open_ai::OpenAiEmbeddingModel;
+use project::Project;
+use semantic_index::{OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status};
+use serde::{Deserialize, Serialize};
+use settings::SettingsStore;
+use smol::channel::bounded;
+use smol::io::AsyncReadExt;
+use smol::Timer;
+use std::ops::RangeInclusive;
+use std::time::Duration;
+use std::{
+    fs,
+    path::Path,
+    process::{exit, Command, Stdio},
+    sync::{
+        atomic::{AtomicUsize, Ordering::SeqCst},
+        Arc,
+    },
+};
+
+const CODESEARCH_NET_DIR: &'static str = "target/datasets/code-search-net";
+const EVAL_REPOS_DIR: &'static str = "target/datasets/eval-repos";
+const EVAL_DB_PATH: &'static str = "target/eval_db";
+const SEARCH_RESULT_LIMIT: usize = 8;
+const SKIP_EVAL_PATH: &'static str = ".skip_eval";
+
+#[derive(clap::Parser)]
+#[command(author, version, about, long_about = None)]
+struct Cli {
+    #[command(subcommand)]
+    command: Commands,
+}
+
+#[derive(clap::Subcommand)]
+enum Commands {
+    Fetch {},
+    Run {
+        #[arg(long)]
+        repo: Option<String>,
+    },
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+struct EvaluationProject {
+    repo: String,
+    sha: String,
+    queries: Vec<EvaluationQuery>,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+struct EvaluationQuery {
+    query: String,
+    expected_results: Vec<EvaluationSearchResult>,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
+struct EvaluationSearchResult {
+    file: String,
+    lines: RangeInclusive<u32>,
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+struct EvaluationProjectOutcome {
+    repo: String,
+    sha: String,
+    queries: Vec<EvaluationQueryOutcome>,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+struct EvaluationQueryOutcome {
+    repo: String,
+    query: String,
+    expected_results: Vec<EvaluationSearchResult>,
+    actual_results: Vec<EvaluationSearchResult>,
+    covered_file_count: usize,
+    overlapped_result_count: usize,
+    covered_result_count: usize,
+    total_result_count: usize,
+    covered_result_indices: Vec<usize>,
+}
+
+fn main() -> Result<()> {
+    let cli = Cli::parse();
+    env_logger::init();
+
+    gpui::App::headless().run(move |cx| {
+        let executor = cx.background_executor().clone();
+
+        match cli.command {
+            Commands::Fetch {} => {
+                executor
+                    .clone()
+                    .spawn(async move {
+                        if let Err(err) = fetch_evaluation_resources(&executor).await {
+                            eprintln!("Error: {}", err);
+                            exit(1);
+                        }
+                        exit(0);
+                    })
+                    .detach();
+            }
+            Commands::Run { repo } => {
+                cx.spawn(|mut cx| async move {
+                    if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
+                        eprintln!("Error: {}", err);
+                        exit(1);
+                    }
+                    exit(0);
+                })
+                .detach();
+            }
+        }
+    });
+
+    Ok(())
+}
+
+async fn fetch_evaluation_resources(executor: &BackgroundExecutor) -> Result<()> {
+    let http_client = http_client::HttpClientWithProxy::new(None, None);
+    fetch_code_search_net_resources(&http_client).await?;
+    fetch_eval_repos(executor, &http_client).await?;
+    Ok(())
+}
+
+async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
+    eprintln!("Fetching CodeSearchNet evaluations...");
+
+    let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";
+
+    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
+    fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");
+
+    // Fetch the annotations CSV, which contains the human-annotated search relevances
+    let annotations_path = dataset_dir.join("annotations.csv");
+    let annotations_csv_content = if annotations_path.exists() {
+        fs::read_to_string(&annotations_path).expect("failed to read annotations")
+    } else {
+        let response = http_client
+            .get(annotations_url, Default::default(), true)
+            .await
+            .expect("failed to fetch annotations csv");
+        let mut body = String::new();
+        response
+            .into_body()
+            .read_to_string(&mut body)
+            .await
+            .expect("failed to read annotations.csv response");
+        fs::write(annotations_path, &body).expect("failed to write annotations.csv");
+        body
+    };
+
+    // Parse the annotations CSV. Skip over queries with zero relevance.
+    let rows = annotations_csv_content.lines().filter_map(|line| {
+        let mut values = line.split(',');
+        let _language = values.next()?;
+        let query = values.next()?;
+        let github_url = values.next()?;
+        let score = values.next()?;
+
+        if score == "0" {
+            return None;
+        }
+
+        let url_path = github_url.strip_prefix("https://github.com/")?;
+        let (url_path, hash) = url_path.split_once('#')?;
+        let (repo_name, url_path) = url_path.split_once("/blob/")?;
+        let (sha, file_path) = url_path.split_once('/')?;
+        let line_range = if let Some((start, end)) = hash.split_once('-') {
+            start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
+        } else {
+            let row = hash.strip_prefix("L")?.parse().ok()?;
+            row..=row
+        };
+        Some((repo_name, sha, query, file_path, line_range))
+    });
+
+    // Group the annotations by repo and sha.
+    let mut evaluations_by_repo = BTreeMap::new();
+    for (repo_name, sha, query, file_path, lines) in rows {
+        let evaluation_project = evaluations_by_repo
+            .entry((repo_name, sha))
+            .or_insert_with(|| EvaluationProject {
+                repo: repo_name.to_string(),
+                sha: sha.to_string(),
+                queries: Vec::new(),
+            });
+
+        let ix = evaluation_project
+            .queries
+            .iter()
+            .position(|entry| entry.query == query)
+            .unwrap_or_else(|| {
+                evaluation_project.queries.push(EvaluationQuery {
+                    query: query.to_string(),
+                    expected_results: Vec::new(),
+                });
+                evaluation_project.queries.len() - 1
+            });
+        let results = &mut evaluation_project.queries[ix].expected_results;
+        let result = EvaluationSearchResult {
+            file: file_path.to_string(),
+            lines,
+        };
+        if !results.contains(&result) {
+            results.push(result);
+        }
+    }
+
+    let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
+    let evaluations_path = dataset_dir.join("evaluations.json");
+    fs::write(
+        &evaluations_path,
+        serde_json::to_vec_pretty(&evaluations).unwrap(),
+    )
+    .unwrap();
+
+    eprintln!(
+        "Fetched CodeSearchNet evaluations into {}",
+        evaluations_path.display()
+    );
+
+    Ok(())
+}
+
+async fn run_evaluation(
+    only_repo: Option<String>,
+    executor: &BackgroundExecutor,
+    cx: &mut AsyncAppContext,
+) -> Result<()> {
+    cx.update(|cx| {
+        let mut store = SettingsStore::new(cx);
+        store
+            .set_default_settings(settings::default_settings().as_ref(), cx)
+            .unwrap();
+        cx.set_global(store);
+        client::init_settings(cx);
+        language::init(cx);
+        Project::init_settings(cx);
+        cx.update_flags(false, vec![]);
+    })
+    .unwrap();
+
+    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
+    let evaluations_path = dataset_dir.join("evaluations.json");
+    let repos_dir = Path::new(EVAL_REPOS_DIR);
+    let db_path = Path::new(EVAL_DB_PATH);
+    let http_client = http_client::HttpClientWithProxy::new(None, None);
+    let api_key = std::env::var("OPENAI_API_KEY").unwrap();
+    let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new());
+    let fs = Arc::new(RealFs::new(git_hosting_provider_registry, None)) as Arc<dyn Fs>;
+    let clock = Arc::new(RealSystemClock);
+    let client = cx
+        .update(|cx| {
+            Client::new(
+                clock,
+                Arc::new(http_client::HttpClientWithUrl::new(
+                    "https://zed.dev",
+                    None,
+                    None,
+                )),
+                cx,
+            )
+        })
+        .unwrap();
+    let user_store = cx
+        .new_model(|cx| UserStore::new(client.clone(), cx))
+        .unwrap();
+    let node_runtime = Arc::new(FakeNodeRuntime {});
+
+    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
+    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
+
+    let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
+        http_client.clone(),
+        OpenAiEmbeddingModel::TextEmbedding3Small,
+        open_ai::OPEN_AI_API_URL.to_string(),
+        api_key,
+    ));
+
+    let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
+    cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
+        .unwrap();
+
+    let mut covered_result_count = 0;
+    let mut overlapped_result_count = 0;
+    let mut covered_file_count = 0;
+    let mut total_result_count = 0;
+    eprint!("Running evals.");
+
+    for evaluation_project in evaluations {
+        if only_repo
+            .as_ref()
+            .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
+        {
+            continue;
+        }
+
+        eprint!("\r\x1B[2K");
+        eprint!(
+            "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
+            covered_result_count,
+            total_result_count,
+            overlapped_result_count,
+            total_result_count,
+            covered_file_count,
+            total_result_count,
+            evaluation_project.repo
+        );
+
+        let repo_db_path =
+            db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));
+        let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider.clone(), cx)
+            .await
+            .unwrap();
+
+        let repo_dir = repos_dir.join(&evaluation_project.repo);
+        if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
+            eprintln!("Skipping {}: directory not found", evaluation_project.repo);
+            continue;
+        }
+
+        let project = cx
+            .update(|cx| {
+                Project::local(
+                    client.clone(),
+                    node_runtime.clone(),
+                    user_store.clone(),
+                    language_registry.clone(),
+                    fs.clone(),
+                    None,
+                    cx,
+                )
+            })
+            .unwrap();
+
+        let (worktree, _) = project
+            .update(cx, |project, cx| {
+                project.find_or_create_worktree(repo_dir, true, cx)
+            })?
+            .await?;
+
+        worktree
+            .update(cx, |worktree, _| {
+                worktree.as_local().unwrap().scan_complete()
+            })
+            .unwrap()
+            .await;
+
+        let project_index = cx
+            .update(|cx| semantic_index.create_project_index(project.clone(), cx))
+            .unwrap();
+        wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;
+
+        for query in evaluation_project.queries {
+            let results = cx
+                .update(|cx| {
+                    let project_index = project_index.read(cx);
+                    project_index.search(query.query.clone(), SEARCH_RESULT_LIMIT, cx)
+                })
+                .unwrap()
+                .await
+                .unwrap();
+
+            let results = SemanticDb::load_results(results, &fs.clone(), &cx)
+                .await
+                .unwrap();
+
+            let mut project_covered_result_count = 0;
+            let mut project_overlapped_result_count = 0;
+            let mut project_covered_file_count = 0;
+            let mut covered_result_indices = Vec::new();
+            for expected_result in &query.expected_results {
+                let mut file_matched = false;
+                let mut range_overlapped = false;
+                let mut range_covered = false;
+
+                for (ix, result) in results.iter().enumerate() {
+                    if result.path.as_ref() == Path::new(&expected_result.file) {
+                        file_matched = true;
+                        let start_matched =
+                            result.row_range.contains(&expected_result.lines.start());
+                        let end_matched = result.row_range.contains(&expected_result.lines.end());
+
+                        if start_matched || end_matched {
+                            range_overlapped = true;
+                        }
+
+                        if start_matched && end_matched {
+                            range_covered = true;
+                            covered_result_indices.push(ix);
+                            break;
+                        }
+                    }
+                }
+
+                if range_covered {
+                    project_covered_result_count += 1
+                };
+                if range_overlapped {
+                    project_overlapped_result_count += 1
+                };
+                if file_matched {
+                    project_covered_file_count += 1
+                };
+            }
+            let outcome_repo = evaluation_project.repo.clone();
+
+            let query_results = EvaluationQueryOutcome {
+                repo: outcome_repo,
+                query: query.query,
+                total_result_count: query.expected_results.len(),
+                covered_result_count: project_covered_result_count,
+                overlapped_result_count: project_overlapped_result_count,
+                covered_file_count: project_covered_file_count,
+                expected_results: query.expected_results,
+                actual_results: results
+                    .iter()
+                    .map(|result| EvaluationSearchResult {
+                        file: result.path.to_string_lossy().to_string(),
+                        lines: result.row_range.clone(),
+                    })
+                    .collect(),
+                covered_result_indices,
+            };
+
+            overlapped_result_count += query_results.overlapped_result_count;
+            covered_result_count += query_results.covered_result_count;
+            covered_file_count += query_results.covered_file_count;
+            total_result_count += query_results.total_result_count;
+
+            println!("{}", serde_json::to_string(&query_results).unwrap());
+        }
+    }
+
+    eprint!(
+        "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured.",
+        covered_result_count,
+        total_result_count,
+        overlapped_result_count,
+        total_result_count,
+        covered_file_count,
+        total_result_count,
+    );
+
+    Ok(())
+}
+
+async fn wait_for_indexing_complete(
+    project_index: &Model<ProjectIndex>,
+    cx: &mut AsyncAppContext,
+    timeout: Option<Duration>,
+) {
+    let (tx, rx) = bounded(1);
+    let subscription = cx.update(|cx| {
+        cx.subscribe(project_index, move |_, event, _| {
+            if let Status::Idle = event {
+                let _ = tx.try_send(*event);
+            }
+        })
+    });
+
+    let result = match timeout {
+        Some(timeout_duration) => {
+            smol::future::or(
+                async {
+                    rx.recv().await.map_err(|_| ())?;
+                    Ok(())
+                },
+                async {
+                    Timer::after(timeout_duration).await;
+                    Err(())
+                },
+            )
+            .await
+        }
+        None => rx.recv().await.map(|_| ()).map_err(|_| ()),
+    };
+
+    match result {
+        Ok(_) => (),
+        Err(_) => {
+            if let Some(timeout) = timeout {
+                eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
+            }
+        }
+    }
+
+    drop(subscription);
+}
+
+async fn fetch_eval_repos(
+    executor: &BackgroundExecutor,
+    http_client: &dyn HttpClient,
+) -> Result<()> {
+    let dataset_dir = Path::new(CODESEARCH_NET_DIR);
+    let evaluations_path = dataset_dir.join("evaluations.json");
+    let repos_dir = Path::new(EVAL_REPOS_DIR);
+
+    let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
+    let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
+
+    eprint!("Fetching evaluation repositories...");
+
+    executor
+        .scoped(move |scope| {
+            let done_count = Arc::new(AtomicUsize::new(0));
+            let len = evaluations.len();
+            for chunk in evaluations.chunks(evaluations.len() / 8) {
+                let chunk = chunk.to_vec();
+                let done_count = done_count.clone();
+                scope.spawn(async move {
+                    for EvaluationProject { repo, sha, .. } in chunk {
+                        eprint!(
+                            "\rFetching evaluation repositories ({}/{})...",
+                            done_count.load(SeqCst),
+                            len,
+                        );
+
+                        fetch_eval_repo(repo, sha, repos_dir, http_client).await;
+                        done_count.fetch_add(1, SeqCst);
+                    }
+                });
+            }
+        })
+        .await;
+
+    Ok(())
+}
+
+async fn fetch_eval_repo(
+    repo: String,
+    sha: String,
+    repos_dir: &Path,
+    http_client: &dyn HttpClient,
+) {
+    let Some((owner, repo_name)) = repo.split_once('/') else {
+        return;
+    };
+    let repo_dir = repos_dir.join(owner).join(repo_name);
+    fs::create_dir_all(&repo_dir).unwrap();
+    let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
+    if skip_eval_path.exists() {
+        return;
+    }
+    if let Ok(head_content) = fs::read_to_string(&repo_dir.join(".git").join("HEAD")) {
+        if head_content.trim() == sha {
+            return;
+        }
+    }
+    let repo_response = http_client
+        .send(
+            http_client::Request::builder()
+                .method(Method::HEAD)
+                .uri(format!("https://github.com/{}", repo))
+                .body(Default::default())
+                .expect(""),
+        )
+        .await
+        .expect("failed to check github repo");
+    if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
+        fs::write(&skip_eval_path, "").unwrap();
+        eprintln!(
+            "Repo {repo} is no longer public ({:?}). Skipping",
+            repo_response.status()
+        );
+        return;
+    }
+    if !repo_dir.join(".git").exists() {
+        let init_output = Command::new("git")
+            .current_dir(&repo_dir)
+            .args(&["init"])
+            .output()
+            .unwrap();
+        if !init_output.status.success() {
+            eprintln!(
+                "Failed to initialize git repository for {}: {}",
+                repo,
+                String::from_utf8_lossy(&init_output.stderr)
+            );
+            return;
+        }
+    }
+    let url = format!("https://github.com/{}.git", repo);
+    Command::new("git")
+        .current_dir(&repo_dir)
+        .args(&["remote", "add", "-f", "origin", &url])
+        .stdin(Stdio::null())
+        .output()
+        .unwrap();
+    let fetch_output = Command::new("git")
+        .current_dir(&repo_dir)
+        .args(&["fetch", "--depth", "1", "origin", &sha])
+        .stdin(Stdio::null())
+        .output()
+        .unwrap();
+    if !fetch_output.status.success() {
+        eprintln!(
+            "Failed to fetch {} for {}: {}",
+            sha,
+            repo,
+            String::from_utf8_lossy(&fetch_output.stderr)
+        );
+        return;
+    }
+    let checkout_output = Command::new("git")
+        .current_dir(&repo_dir)
+        .args(&["checkout", &sha])
+        .output()
+        .unwrap();
+
+    if !checkout_output.status.success() {
+        eprintln!(
+            "Failed to checkout {} for {}: {}",
+            sha,
+            repo,
+            String::from_utf8_lossy(&checkout_output.stderr)
+        );
+    }
+}

crates/http_client/src/http_client.rs 🔗

@@ -5,6 +5,7 @@ use derive_more::Deref;
 use futures::future::BoxFuture;
 use futures_lite::FutureExt;
 use isahc::config::{Configurable, RedirectPolicy};
+pub use isahc::http;
 pub use isahc::{
     http::{Method, StatusCode, Uri},
     AsyncBody, Error, HttpClient as IsahcHttpClient, Request, Response,
@@ -226,7 +227,7 @@ pub fn client(user_agent: Option<String>, proxy: Option<Uri>) -> Arc<dyn HttpCli
         // those requests use a different http client, because global timeouts
         // of 50 and 60 seconds, respectively, would be very high!
         .connect_timeout(Duration::from_secs(5))
-        .low_speed_timeout(100, Duration::from_secs(5))
+        .low_speed_timeout(100, Duration::from_secs(30))
         .proxy(proxy.clone());
     if let Some(user_agent) = user_agent {
         builder = builder.default_header("User-Agent", user_agent);

crates/semantic_index/src/embedding_index.rs 🔗

@@ -234,30 +234,25 @@ impl EmbeddingIndex {
                         cx.spawn(async {
                             while let Ok((entry, handle)) = entries.recv().await {
                                 let entry_abs_path = worktree_abs_path.join(&entry.path);
-                                match fs.load(&entry_abs_path).await {
-                                    Ok(text) => {
-                                        let language = language_registry
-                                            .language_for_file_path(&entry.path)
-                                            .await
-                                            .ok();
-                                        let chunked_file = ChunkedFile {
-                                            chunks: chunking::chunk_text(
-                                                &text,
-                                                language.as_ref(),
-                                                &entry.path,
-                                            ),
-                                            handle,
-                                            path: entry.path,
-                                            mtime: entry.mtime,
-                                            text,
-                                        };
-
-                                        if chunked_files_tx.send(chunked_file).await.is_err() {
-                                            return;
-                                        }
-                                    }
-                                    Err(_)=> {
-                                        log::error!("Failed to read contents into a UTF-8 string: {entry_abs_path:?}");
+                                if let Some(text) = fs.load(&entry_abs_path).await.ok() {
+                                    let language = language_registry
+                                        .language_for_file_path(&entry.path)
+                                        .await
+                                        .ok();
+                                    let chunked_file = ChunkedFile {
+                                        chunks: chunking::chunk_text(
+                                            &text,
+                                            language.as_ref(),
+                                            &entry.path,
+                                        ),
+                                        handle,
+                                        path: entry.path,
+                                        mtime: entry.mtime,
+                                        text,
+                                    };
+
+                                    if chunked_files_tx.send(chunked_file).await.is_err() {
+                                        return;
                                     }
                                 }
                             }
@@ -358,33 +353,37 @@ impl EmbeddingIndex {
     fn persist_embeddings(
         &self,
         mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
-        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
+        mut embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
         cx: &AppContext,
     ) -> Task<Result<()>> {
         let db_connection = self.db_connection.clone();
         let db = self.db;
-        cx.background_executor().spawn(async move {
-            while let Some(deletion_range) = deleted_entry_ranges.next().await {
-                let mut txn = db_connection.write_txn()?;
-                let start = deletion_range.0.as_ref().map(|start| start.as_str());
-                let end = deletion_range.1.as_ref().map(|end| end.as_str());
-                log::debug!("deleting embeddings in range {:?}", &(start, end));
-                db.delete_range(&mut txn, &(start, end))?;
-                txn.commit()?;
-            }
 
-            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
-            while let Some(embedded_files) = embedded_files.next().await {
-                let mut txn = db_connection.write_txn()?;
-                for (file, _) in &embedded_files {
-                    log::debug!("saving embedding for file {:?}", file.path);
-                    let key = db_key_for_path(&file.path);
-                    db.put(&mut txn, &key, file)?;
+        cx.background_executor().spawn(async move {
+            loop {
+                // Interleave deletions and persists of embedded files
+                futures::select_biased! {
+                    deletion_range = deleted_entry_ranges.next() => {
+                        if let Some(deletion_range) = deletion_range {
+                            let mut txn = db_connection.write_txn()?;
+                            let start = deletion_range.0.as_ref().map(|start| start.as_str());
+                            let end = deletion_range.1.as_ref().map(|end| end.as_str());
+                            log::debug!("deleting embeddings in range {:?}", &(start, end));
+                            db.delete_range(&mut txn, &(start, end))?;
+                            txn.commit()?;
+                        }
+                    },
+                    file = embedded_files.next() => {
+                        if let Some((file, _)) = file {
+                            let mut txn = db_connection.write_txn()?;
+                            log::debug!("saving embedding for file {:?}", file.path);
+                            let key = db_key_for_path(&file.path);
+                            db.put(&mut txn, &key, &file)?;
+                            txn.commit()?;
+                        }
+                    },
+                    complete => break,
                 }
-                txn.commit()?;
-
-                drop(embedded_files);
-                log::debug!("committed");
             }
 
             Ok(())

crates/semantic_index/src/project_index.rs 🔗

@@ -15,7 +15,14 @@ use log;
 use project::{Project, Worktree, WorktreeId};
 use serde::{Deserialize, Serialize};
 use smol::channel;
-use std::{cmp::Ordering, future::Future, num::NonZeroUsize, ops::Range, path::Path, sync::Arc};
+use std::{
+    cmp::Ordering,
+    future::Future,
+    num::NonZeroUsize,
+    ops::{Range, RangeInclusive},
+    path::{Path, PathBuf},
+    sync::Arc,
+};
 use util::ResultExt;
 
 #[derive(Debug)]
@@ -26,6 +33,14 @@ pub struct SearchResult {
     pub score: f32,
 }
 
+pub struct LoadedSearchResult {
+    pub path: Arc<Path>,
+    pub range: Range<usize>,
+    pub full_path: PathBuf,
+    pub file_content: String,
+    pub row_range: RangeInclusive<u32>,
+}
+
 pub struct WorktreeSearchResult {
     pub worktree_id: WorktreeId,
     pub path: Arc<Path>,

crates/semantic_index/src/semantic_index.rs 🔗

@@ -10,14 +10,16 @@ mod worktree_index;
 
 use anyhow::{Context as _, Result};
 use collections::HashMap;
+use fs::Fs;
 use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Context, Global, Model, WeakModel};
 use project::Project;
-use project_index::ProjectIndex;
 use std::{path::PathBuf, sync::Arc};
 use ui::ViewContext;
+use util::ResultExt as _;
 use workspace::Workspace;
 
 pub use embedding::*;
+pub use project_index::{LoadedSearchResult, ProjectIndex, SearchResult, Status};
 pub use project_index_debug_view::ProjectIndexDebugView;
 pub use summary_index::FileSummary;
 
@@ -56,27 +58,7 @@ impl SemanticDb {
 
                     if cx.has_global::<SemanticDb>() {
                         cx.update_global::<SemanticDb, _>(|this, cx| {
-                            let project_index = cx.new_model(|cx| {
-                                ProjectIndex::new(
-                                    project.clone(),
-                                    this.db_connection.clone(),
-                                    this.embedding_provider.clone(),
-                                    cx,
-                                )
-                            });
-
-                            let project_weak = project.downgrade();
-                            this.project_indices
-                                .insert(project_weak.clone(), project_index);
-
-                            cx.on_release(move |_, _, cx| {
-                                if cx.has_global::<SemanticDb>() {
-                                    cx.update_global::<SemanticDb, _>(|this, _| {
-                                        this.project_indices.remove(&project_weak);
-                                    })
-                                }
-                            })
-                            .detach();
+                            this.create_project_index(project, cx);
                         })
                     } else {
                         log::info!("No SemanticDb, skipping project index")
@@ -94,6 +76,50 @@ impl SemanticDb {
         })
     }
 
+    pub async fn load_results(
+        results: Vec<SearchResult>,
+        fs: &Arc<dyn Fs>,
+        cx: &AsyncAppContext,
+    ) -> Result<Vec<LoadedSearchResult>> {
+        let mut loaded_results = Vec::new();
+        for result in results {
+            let (full_path, file_content) = result.worktree.read_with(cx, |worktree, _cx| {
+                let entry_abs_path = worktree.abs_path().join(&result.path);
+                let mut entry_full_path = PathBuf::from(worktree.root_name());
+                entry_full_path.push(&result.path);
+                let file_content = async {
+                    let entry_abs_path = entry_abs_path;
+                    fs.load(&entry_abs_path).await
+                };
+                (entry_full_path, file_content)
+            })?;
+            if let Some(file_content) = file_content.await.log_err() {
+                let range_start = result.range.start.min(file_content.len());
+                let range_end = result.range.end.min(file_content.len());
+
+                let start_row = file_content[0..range_start].matches('\n').count() as u32;
+                let end_row = file_content[0..range_end].matches('\n').count() as u32;
+                let start_line_byte_offset = file_content[0..range_start]
+                    .rfind('\n')
+                    .map(|pos| pos + 1)
+                    .unwrap_or_default();
+                let end_line_byte_offset = file_content[range_end..]
+                    .find('\n')
+                    .map(|pos| range_end + pos)
+                    .unwrap_or_else(|| file_content.len());
+
+                loaded_results.push(LoadedSearchResult {
+                    path: result.path,
+                    range: start_line_byte_offset..end_line_byte_offset,
+                    full_path,
+                    file_content,
+                    row_range: start_row..=end_row,
+                });
+            }
+        }
+        Ok(loaded_results)
+    }
+
     pub fn project_index(
         &mut self,
         project: Model<Project>,
@@ -113,6 +139,36 @@ impl SemanticDb {
             })
         })
     }
+
+    pub fn create_project_index(
+        &mut self,
+        project: Model<Project>,
+        cx: &mut AppContext,
+    ) -> Model<ProjectIndex> {
+        let project_index = cx.new_model(|cx| {
+            ProjectIndex::new(
+                project.clone(),
+                self.db_connection.clone(),
+                self.embedding_provider.clone(),
+                cx,
+            )
+        });
+
+        let project_weak = project.downgrade();
+        self.project_indices
+            .insert(project_weak.clone(), project_index.clone());
+
+        cx.observe_release(&project, move |_, cx| {
+            if cx.has_global::<SemanticDb>() {
+                cx.update_global::<SemanticDb, _>(|this, _| {
+                    this.project_indices.remove(&project_weak);
+                })
+            }
+        })
+        .detach();
+
+        project_index
+    }
 }
 
 #[cfg(test)]
@@ -230,34 +286,13 @@ mod tests {
 
         let project = Project::test(fs, [project_path], cx).await;
 
-        cx.update(|cx| {
+        let project_index = cx.update(|cx| {
             let language_registry = project.read(cx).languages().clone();
             let node_runtime = project.read(cx).node_runtime().unwrap().clone();
             languages::init(language_registry, node_runtime, cx);
-
-            // Manually create and insert the ProjectIndex
-            let project_index = cx.new_model(|cx| {
-                ProjectIndex::new(
-                    project.clone(),
-                    semantic_index.db_connection.clone(),
-                    semantic_index.embedding_provider.clone(),
-                    cx,
-                )
-            });
-            semantic_index
-                .project_indices
-                .insert(project.downgrade(), project_index);
+            semantic_index.create_project_index(project.clone(), cx)
         });
 
-        let project_index = cx
-            .update(|_cx| {
-                semantic_index
-                    .project_indices
-                    .get(&project.downgrade())
-                    .cloned()
-            })
-            .unwrap();
-
         cx.run_until_parked();
         while cx
             .update(|cx| semantic_index.remaining_summaries(&project.downgrade(), cx))