Detailed changes
@@ -101,7 +101,7 @@ jobs:
timeout-minutes: 60
name: (Linux) Run Clippy and tests
runs-on:
- - hosted-linux-x86-1
+ - buildjet-16vcpu-ubuntu-2204
steps:
- name: Add Rust to the PATH
run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
@@ -111,6 +111,11 @@ jobs:
with:
clean: false
+ - name: Cache dependencies
+ uses: swatinem/rust-cache@23bce251a8cd2ffc3c1075eaa2367cf899916d84 # v2
+ with:
+ save-if: ${{ github.ref == 'refs/heads/main' }}
+
- name: Install Linux dependencies
run: ./script/linux
@@ -264,7 +269,7 @@ jobs:
timeout-minutes: 60
name: Create a Linux bundle
runs-on:
- - hosted-linux-x86-1
+ - buildjet-16vcpu-ubuntu-2204
if: ${{ startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
needs: [linux_tests]
env:
@@ -279,9 +284,6 @@ jobs:
- name: Install Linux dependencies
run: ./script/linux
- - name: Limit target directory size
- run: script/clear-target-dir-if-larger-than 100
-
- name: Determine version and release channel
if: ${{ startsWith(github.ref, 'refs/tags/v') }}
run: |
@@ -335,7 +337,7 @@ jobs:
timeout-minutes: 60
name: Create arm64 Linux bundle
runs-on:
- - hosted-linux-arm-1
+ - buildjet-16vcpu-ubuntu-2204-arm
if: ${{ startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
needs: [linux_tests]
env:
@@ -350,9 +352,6 @@ jobs:
- name: Install Linux dependencies
run: ./script/linux
- - name: Limit target directory size
- run: script/clear-target-dir-if-larger-than 100
-
- name: Determine version and release channel
if: ${{ startsWith(github.ref, 'refs/tags/v') }}
run: |
@@ -4000,6 +4000,33 @@ dependencies = [
"num-traits",
]
+[[package]]
+name = "evals"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "clap",
+ "client",
+ "clock",
+ "collections",
+ "env_logger",
+ "feature_flags",
+ "fs",
+ "git",
+ "gpui",
+ "http_client",
+ "language",
+ "languages",
+ "node_runtime",
+ "open_ai",
+ "project",
+ "semantic_index",
+ "serde",
+ "serde_json",
+ "settings",
+ "smol",
+]
+
[[package]]
name = "event-listener"
version = "2.5.3"
@@ -27,6 +27,7 @@ members = [
"crates/diagnostics",
"crates/docs_preprocessor",
"crates/editor",
+ "crates/evals",
"crates/extension",
"crates/extension_api",
"crates/extension_cli",
@@ -3282,7 +3282,7 @@ impl ContextEditor {
let fence = codeblock_fence_for_path(
filename.as_deref(),
- Some(selection.start.row..selection.end.row),
+ Some(selection.start.row..=selection.end.row),
);
if let Some((line_comment_prefix, outline_text)) =
@@ -8,7 +8,7 @@ use project::{PathMatchCandidateSet, Project};
use serde::{Deserialize, Serialize};
use std::{
fmt::Write,
- ops::Range,
+ ops::{Range, RangeInclusive},
path::{Path, PathBuf},
sync::{atomic::AtomicBool, Arc},
};
@@ -342,7 +342,10 @@ fn collect_files(
})
}
-pub fn codeblock_fence_for_path(path: Option<&Path>, row_range: Option<Range<u32>>) -> String {
+pub fn codeblock_fence_for_path(
+ path: Option<&Path>,
+ row_range: Option<RangeInclusive<u32>>,
+) -> String {
let mut text = String::new();
write!(text, "```").unwrap();
@@ -357,7 +360,7 @@ pub fn codeblock_fence_for_path(path: Option<&Path>, row_range: Option<Range<u32
}
if let Some(row_range) = row_range {
- write!(text, ":{}-{}", row_range.start + 1, row_range.end + 1).unwrap();
+ write!(text, ":{}-{}", row_range.start() + 1, row_range.end() + 1).unwrap();
}
text.push('\n');
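
Taken together, the two hunks above switch `codeblock_fence_for_path` from an exclusive `Range<u32>` to an inclusive `RangeInclusive<u32>` of 0-based rows, converting to 1-based line numbers only when the fence header is written. A minimal standalone sketch of that conversion (`fence_suffix` is a hypothetical helper, not part of the change):

```rust
use std::ops::RangeInclusive;

// 0-based, inclusive buffer rows become the 1-based line numbers shown in the
// fence header, e.g. ":10-12" for rows 9..=11.
fn fence_suffix(row_range: RangeInclusive<u32>) -> String {
    format!(":{}-{}", row_range.start() + 1, row_range.end() + 1)
}

fn main() {
    assert_eq!(fence_suffix(9..=11), ":10-12");
    // A single-row selection stays a single line.
    assert_eq!(fence_suffix(4..=4), ":5-5");
}
```

The inclusive form matters for callers that treat the value as a set of rows: in Rust, `4..4` is empty while `4..=4` contains one row.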
@@ -8,14 +8,12 @@ use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
use feature_flags::FeatureFlag;
use gpui::{AppContext, Task, WeakView};
use language::{CodeLabel, LineEnding, LspAdapterDelegate};
-use semantic_index::SemanticDb;
+use semantic_index::{LoadedSearchResult, SemanticDb};
use std::{
fmt::Write,
- path::PathBuf,
sync::{atomic::AtomicBool, Arc},
};
use ui::{prelude::*, IconName};
-use util::ResultExt;
use workspace::Workspace;
pub(crate) struct SearchSlashCommandFeatureFlag;
@@ -107,52 +105,28 @@ impl SlashCommand for SearchSlashCommand {
})?
.await?;
- let mut loaded_results = Vec::new();
- for result in results {
- let (full_path, file_content) =
- result.worktree.read_with(&cx, |worktree, _cx| {
- let entry_abs_path = worktree.abs_path().join(&result.path);
- let mut entry_full_path = PathBuf::from(worktree.root_name());
- entry_full_path.push(&result.path);
- let file_content = async {
- let entry_abs_path = entry_abs_path;
- fs.load(&entry_abs_path).await
- };
- (entry_full_path, file_content)
- })?;
- if let Some(file_content) = file_content.await.log_err() {
- loaded_results.push((result, full_path, file_content));
- }
- }
+ let loaded_results = SemanticDb::load_results(results, &fs, &cx).await?;
let output = cx
.background_executor()
.spawn(async move {
let mut text = format!("Search results for {query}:\n");
let mut sections = Vec::new();
- for (result, full_path, file_content) in loaded_results {
- let range_start = result.range.start.min(file_content.len());
- let range_end = result.range.end.min(file_content.len());
-
- let start_row = file_content[0..range_start].matches('\n').count() as u32;
- let end_row = file_content[0..range_end].matches('\n').count() as u32;
- let start_line_byte_offset = file_content[0..range_start]
- .rfind('\n')
- .map(|pos| pos + 1)
- .unwrap_or_default();
- let end_line_byte_offset = file_content[range_end..]
- .find('\n')
- .map(|pos| range_end + pos)
- .unwrap_or_else(|| file_content.len());
-
+ for LoadedSearchResult {
+ path,
+ range,
+ full_path,
+ file_content,
+ row_range,
+ } in loaded_results
+ {
let section_start_ix = text.len();
text.push_str(&codeblock_fence_for_path(
- Some(&result.path),
- Some(start_row..end_row),
+ Some(&path),
+ Some(row_range.clone()),
));
- let mut excerpt =
- file_content[start_line_byte_offset..end_line_byte_offset].to_string();
+ let mut excerpt = file_content[range].to_string();
LineEnding::normalize(&mut excerpt);
text.push_str(&excerpt);
writeln!(text, "\n```\n").unwrap();
@@ -161,7 +135,7 @@ impl SlashCommand for SearchSlashCommand {
section_start_ix..section_end_ix,
Some(&full_path),
false,
- Some(start_row + 1..end_row + 1),
+ Some(row_range.start() + 1..row_range.end() + 1),
));
}
@@ -0,0 +1,37 @@
+[package]
+name = "evals"
+description = "Evaluations for Zed's AI features"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[[bin]]
+name = "eval"
+path = "src/eval.rs"
+
+[dependencies]
+clap.workspace = true
+anyhow.workspace = true
+client.workspace = true
+clock.workspace = true
+collections.workspace = true
+env_logger.workspace = true
+feature_flags.workspace = true
+fs.workspace = true
+git.workspace = true
+gpui.workspace = true
+language.workspace = true
+languages.workspace = true
+http_client.workspace = true
+open_ai.workspace = true
+project.workspace = true
+settings.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+smol.workspace = true
+semantic_index.workspace = true
+node_runtime.workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,14 @@
+fn main() {
+ if cfg!(target_os = "macos") {
+ println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7");
+
+ println!("cargo:rerun-if-env-changed=ZED_BUNDLE");
+ if std::env::var("ZED_BUNDLE").ok().as_deref() == Some("true") {
+ // Find WebRTC.framework in the Frameworks folder when running as part of an application bundle.
+ println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path/../Frameworks");
+ } else {
+ // Find WebRTC.framework as a sibling of the executable when running outside of an application bundle.
+ println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path");
+ }
+ }
+}
@@ -0,0 +1,631 @@
+use ::fs::{Fs, RealFs};
+use anyhow::Result;
+use clap::Parser;
+use client::{Client, UserStore};
+use clock::RealSystemClock;
+use collections::BTreeMap;
+use feature_flags::FeatureFlagAppExt as _;
+use git::GitHostingProviderRegistry;
+use gpui::{AsyncAppContext, BackgroundExecutor, Context, Model};
+use http_client::{HttpClient, Method};
+use language::LanguageRegistry;
+use node_runtime::FakeNodeRuntime;
+use open_ai::OpenAiEmbeddingModel;
+use project::Project;
+use semantic_index::{OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status};
+use serde::{Deserialize, Serialize};
+use settings::SettingsStore;
+use smol::channel::bounded;
+use smol::io::AsyncReadExt;
+use smol::Timer;
+use std::ops::RangeInclusive;
+use std::time::Duration;
+use std::{
+ fs,
+ path::Path,
+ process::{exit, Command, Stdio},
+ sync::{
+ atomic::{AtomicUsize, Ordering::SeqCst},
+ Arc,
+ },
+};
+
+const CODESEARCH_NET_DIR: &'static str = "target/datasets/code-search-net";
+const EVAL_REPOS_DIR: &'static str = "target/datasets/eval-repos";
+const EVAL_DB_PATH: &'static str = "target/eval_db";
+const SEARCH_RESULT_LIMIT: usize = 8;
+const SKIP_EVAL_PATH: &'static str = ".skip_eval";
+
+#[derive(clap::Parser)]
+#[command(author, version, about, long_about = None)]
+struct Cli {
+ #[command(subcommand)]
+ command: Commands,
+}
+
+#[derive(clap::Subcommand)]
+enum Commands {
+ Fetch {},
+ Run {
+ #[arg(long)]
+ repo: Option<String>,
+ },
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+struct EvaluationProject {
+ repo: String,
+ sha: String,
+ queries: Vec<EvaluationQuery>,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+struct EvaluationQuery {
+ query: String,
+ expected_results: Vec<EvaluationSearchResult>,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
+struct EvaluationSearchResult {
+ file: String,
+ lines: RangeInclusive<u32>,
+}
+
+#[derive(Clone, Deserialize, Serialize)]
+struct EvaluationProjectOutcome {
+ repo: String,
+ sha: String,
+ queries: Vec<EvaluationQueryOutcome>,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+struct EvaluationQueryOutcome {
+ repo: String,
+ query: String,
+ expected_results: Vec<EvaluationSearchResult>,
+ actual_results: Vec<EvaluationSearchResult>,
+ covered_file_count: usize,
+ overlapped_result_count: usize,
+ covered_result_count: usize,
+ total_result_count: usize,
+ covered_result_indices: Vec<usize>,
+}
+
+fn main() -> Result<()> {
+ let cli = Cli::parse();
+ env_logger::init();
+
+ gpui::App::headless().run(move |cx| {
+ let executor = cx.background_executor().clone();
+
+ match cli.command {
+ Commands::Fetch {} => {
+ executor
+ .clone()
+ .spawn(async move {
+ if let Err(err) = fetch_evaluation_resources(&executor).await {
+ eprintln!("Error: {}", err);
+ exit(1);
+ }
+ exit(0);
+ })
+ .detach();
+ }
+ Commands::Run { repo } => {
+ cx.spawn(|mut cx| async move {
+ if let Err(err) = run_evaluation(repo, &executor, &mut cx).await {
+ eprintln!("Error: {}", err);
+ exit(1);
+ }
+ exit(0);
+ })
+ .detach();
+ }
+ }
+ });
+
+ Ok(())
+}
+
+async fn fetch_evaluation_resources(executor: &BackgroundExecutor) -> Result<()> {
+ let http_client = http_client::HttpClientWithProxy::new(None, None);
+ fetch_code_search_net_resources(&http_client).await?;
+ fetch_eval_repos(executor, &http_client).await?;
+ Ok(())
+}
+
+async fn fetch_code_search_net_resources(http_client: &dyn HttpClient) -> Result<()> {
+ eprintln!("Fetching CodeSearchNet evaluations...");
+
+ let annotations_url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/annotationStore.csv";
+
+ let dataset_dir = Path::new(CODESEARCH_NET_DIR);
+ fs::create_dir_all(&dataset_dir).expect("failed to create CodeSearchNet directory");
+
+ // Fetch the annotations CSV, which contains the human-annotated search relevances
+ let annotations_path = dataset_dir.join("annotations.csv");
+ let annotations_csv_content = if annotations_path.exists() {
+ fs::read_to_string(&annotations_path).expect("failed to read annotations")
+ } else {
+ let response = http_client
+ .get(annotations_url, Default::default(), true)
+ .await
+ .expect("failed to fetch annotations csv");
+ let mut body = String::new();
+ response
+ .into_body()
+ .read_to_string(&mut body)
+ .await
+ .expect("failed to read annotations.csv response");
+ fs::write(annotations_path, &body).expect("failed to write annotations.csv");
+ body
+ };
+
+ // Parse the annotations CSV. Skip over queries with zero relevance.
+ let rows = annotations_csv_content.lines().filter_map(|line| {
+ let mut values = line.split(',');
+ let _language = values.next()?;
+ let query = values.next()?;
+ let github_url = values.next()?;
+ let score = values.next()?;
+
+ if score == "0" {
+ return None;
+ }
+
+ let url_path = github_url.strip_prefix("https://github.com/")?;
+ let (url_path, hash) = url_path.split_once('#')?;
+ let (repo_name, url_path) = url_path.split_once("/blob/")?;
+ let (sha, file_path) = url_path.split_once('/')?;
+ let line_range = if let Some((start, end)) = hash.split_once('-') {
+ start.strip_prefix("L")?.parse::<u32>().ok()?..=end.strip_prefix("L")?.parse().ok()?
+ } else {
+ let row = hash.strip_prefix("L")?.parse().ok()?;
+ row..=row
+ };
+ Some((repo_name, sha, query, file_path, line_range))
+ });
+
+ // Group the annotations by repo and sha.
+ let mut evaluations_by_repo = BTreeMap::new();
+ for (repo_name, sha, query, file_path, lines) in rows {
+ let evaluation_project = evaluations_by_repo
+ .entry((repo_name, sha))
+ .or_insert_with(|| EvaluationProject {
+ repo: repo_name.to_string(),
+ sha: sha.to_string(),
+ queries: Vec::new(),
+ });
+
+ let ix = evaluation_project
+ .queries
+ .iter()
+ .position(|entry| entry.query == query)
+ .unwrap_or_else(|| {
+ evaluation_project.queries.push(EvaluationQuery {
+ query: query.to_string(),
+ expected_results: Vec::new(),
+ });
+ evaluation_project.queries.len() - 1
+ });
+ let results = &mut evaluation_project.queries[ix].expected_results;
+ let result = EvaluationSearchResult {
+ file: file_path.to_string(),
+ lines,
+ };
+ if !results.contains(&result) {
+ results.push(result);
+ }
+ }
+
+ let evaluations = evaluations_by_repo.into_values().collect::<Vec<_>>();
+ let evaluations_path = dataset_dir.join("evaluations.json");
+ fs::write(
+ &evaluations_path,
+ serde_json::to_vec_pretty(&evaluations).unwrap(),
+ )
+ .unwrap();
+
+ eprintln!(
+ "Fetched CodeSearchNet evaluations into {}",
+ evaluations_path.display()
+ );
+
+ Ok(())
+}
+
+async fn run_evaluation(
+ only_repo: Option<String>,
+ executor: &BackgroundExecutor,
+ cx: &mut AsyncAppContext,
+) -> Result<()> {
+ cx.update(|cx| {
+ let mut store = SettingsStore::new(cx);
+ store
+ .set_default_settings(settings::default_settings().as_ref(), cx)
+ .unwrap();
+ cx.set_global(store);
+ client::init_settings(cx);
+ language::init(cx);
+ Project::init_settings(cx);
+ cx.update_flags(false, vec![]);
+ })
+ .unwrap();
+
+ let dataset_dir = Path::new(CODESEARCH_NET_DIR);
+ let evaluations_path = dataset_dir.join("evaluations.json");
+ let repos_dir = Path::new(EVAL_REPOS_DIR);
+ let db_path = Path::new(EVAL_DB_PATH);
+ let http_client = http_client::HttpClientWithProxy::new(None, None);
+ let api_key = std::env::var("OPENAI_API_KEY").unwrap();
+ let git_hosting_provider_registry = Arc::new(GitHostingProviderRegistry::new());
+ let fs = Arc::new(RealFs::new(git_hosting_provider_registry, None)) as Arc<dyn Fs>;
+ let clock = Arc::new(RealSystemClock);
+ let client = cx
+ .update(|cx| {
+ Client::new(
+ clock,
+ Arc::new(http_client::HttpClientWithUrl::new(
+ "https://zed.dev",
+ None,
+ None,
+ )),
+ cx,
+ )
+ })
+ .unwrap();
+ let user_store = cx
+ .new_model(|cx| UserStore::new(client.clone(), cx))
+ .unwrap();
+ let node_runtime = Arc::new(FakeNodeRuntime {});
+
+ let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
+ let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
+
+ let embedding_provider = Arc::new(OpenAiEmbeddingProvider::new(
+ http_client.clone(),
+ OpenAiEmbeddingModel::TextEmbedding3Small,
+ open_ai::OPEN_AI_API_URL.to_string(),
+ api_key,
+ ));
+
+ let language_registry = Arc::new(LanguageRegistry::new(executor.clone()));
+ cx.update(|cx| languages::init(language_registry.clone(), node_runtime.clone(), cx))
+ .unwrap();
+
+ let mut covered_result_count = 0;
+ let mut overlapped_result_count = 0;
+ let mut covered_file_count = 0;
+ let mut total_result_count = 0;
+ eprint!("Running evals.");
+
+ for evaluation_project in evaluations {
+ if only_repo
+ .as_ref()
+ .map_or(false, |only_repo| only_repo != &evaluation_project.repo)
+ {
+ continue;
+ }
+
+ eprint!("\r\x1B[2K");
+ eprint!(
+ "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured. Project: {}...",
+ covered_result_count,
+ total_result_count,
+ overlapped_result_count,
+ total_result_count,
+ covered_file_count,
+ total_result_count,
+ evaluation_project.repo
+ );
+
+ let repo_db_path =
+ db_path.join(format!("{}.db", evaluation_project.repo.replace('/', "_")));
+ let mut semantic_index = SemanticDb::new(repo_db_path, embedding_provider.clone(), cx)
+ .await
+ .unwrap();
+
+ let repo_dir = repos_dir.join(&evaluation_project.repo);
+ if !repo_dir.exists() || repo_dir.join(SKIP_EVAL_PATH).exists() {
+ eprintln!("Skipping {}: directory not found", evaluation_project.repo);
+ continue;
+ }
+
+ let project = cx
+ .update(|cx| {
+ Project::local(
+ client.clone(),
+ node_runtime.clone(),
+ user_store.clone(),
+ language_registry.clone(),
+ fs.clone(),
+ None,
+ cx,
+ )
+ })
+ .unwrap();
+
+ let (worktree, _) = project
+ .update(cx, |project, cx| {
+ project.find_or_create_worktree(repo_dir, true, cx)
+ })?
+ .await?;
+
+ worktree
+ .update(cx, |worktree, _| {
+ worktree.as_local().unwrap().scan_complete()
+ })
+ .unwrap()
+ .await;
+
+ let project_index = cx
+ .update(|cx| semantic_index.create_project_index(project.clone(), cx))
+ .unwrap();
+ wait_for_indexing_complete(&project_index, cx, Some(Duration::from_secs(120))).await;
+
+ for query in evaluation_project.queries {
+ let results = cx
+ .update(|cx| {
+ let project_index = project_index.read(cx);
+ project_index.search(query.query.clone(), SEARCH_RESULT_LIMIT, cx)
+ })
+ .unwrap()
+ .await
+ .unwrap();
+
+ let results = SemanticDb::load_results(results, &fs.clone(), &cx)
+ .await
+ .unwrap();
+
+ let mut project_covered_result_count = 0;
+ let mut project_overlapped_result_count = 0;
+ let mut project_covered_file_count = 0;
+ let mut covered_result_indices = Vec::new();
+ for expected_result in &query.expected_results {
+ let mut file_matched = false;
+ let mut range_overlapped = false;
+ let mut range_covered = false;
+
+ for (ix, result) in results.iter().enumerate() {
+ if result.path.as_ref() == Path::new(&expected_result.file) {
+ file_matched = true;
+ let start_matched =
+ result.row_range.contains(expected_result.lines.start());
+ let end_matched = result.row_range.contains(expected_result.lines.end());
+
+ if start_matched || end_matched {
+ range_overlapped = true;
+ }
+
+ if start_matched && end_matched {
+ range_covered = true;
+ covered_result_indices.push(ix);
+ break;
+ }
+ }
+ }
+
+ if range_covered {
+ project_covered_result_count += 1
+ };
+ if range_overlapped {
+ project_overlapped_result_count += 1
+ };
+ if file_matched {
+ project_covered_file_count += 1
+ };
+ }
+ let outcome_repo = evaluation_project.repo.clone();
+
+ let query_results = EvaluationQueryOutcome {
+ repo: outcome_repo,
+ query: query.query,
+ total_result_count: query.expected_results.len(),
+ covered_result_count: project_covered_result_count,
+ overlapped_result_count: project_overlapped_result_count,
+ covered_file_count: project_covered_file_count,
+ expected_results: query.expected_results,
+ actual_results: results
+ .iter()
+ .map(|result| EvaluationSearchResult {
+ file: result.path.to_string_lossy().to_string(),
+ lines: result.row_range.clone(),
+ })
+ .collect(),
+ covered_result_indices,
+ };
+
+ overlapped_result_count += query_results.overlapped_result_count;
+ covered_result_count += query_results.covered_result_count;
+ covered_file_count += query_results.covered_file_count;
+ total_result_count += query_results.total_result_count;
+
+ println!("{}", serde_json::to_string(&query_results).unwrap());
+ }
+ }
+
+ eprint!(
+ "Running evals. {}/{} covered. {}/{} overlapped. {}/{} files captured.",
+ covered_result_count,
+ total_result_count,
+ overlapped_result_count,
+ total_result_count,
+ covered_file_count,
+ total_result_count,
+ );
+
+ Ok(())
+}
+
+async fn wait_for_indexing_complete(
+ project_index: &Model<ProjectIndex>,
+ cx: &mut AsyncAppContext,
+ timeout: Option<Duration>,
+) {
+ let (tx, rx) = bounded(1);
+ let subscription = cx.update(|cx| {
+ cx.subscribe(project_index, move |_, event, _| {
+ if let Status::Idle = event {
+ let _ = tx.try_send(*event);
+ }
+ })
+ });
+
+ let result = match timeout {
+ Some(timeout_duration) => {
+ smol::future::or(
+ async {
+ rx.recv().await.map_err(|_| ())?;
+ Ok(())
+ },
+ async {
+ Timer::after(timeout_duration).await;
+ Err(())
+ },
+ )
+ .await
+ }
+ None => rx.recv().await.map(|_| ()).map_err(|_| ()),
+ };
+
+ match result {
+ Ok(_) => (),
+ Err(_) => {
+ if let Some(timeout) = timeout {
+ eprintln!("Timeout: Indexing did not complete within {:?}", timeout);
+ }
+ }
+ }
+
+ drop(subscription);
+}
+
+async fn fetch_eval_repos(
+ executor: &BackgroundExecutor,
+ http_client: &dyn HttpClient,
+) -> Result<()> {
+ let dataset_dir = Path::new(CODESEARCH_NET_DIR);
+ let evaluations_path = dataset_dir.join("evaluations.json");
+ let repos_dir = Path::new(EVAL_REPOS_DIR);
+
+ let evaluations = fs::read(&evaluations_path).expect("failed to read evaluations.json");
+ let evaluations: Vec<EvaluationProject> = serde_json::from_slice(&evaluations).unwrap();
+
+ eprint!("Fetching evaluation repositories...");
+
+ executor
+ .scoped(move |scope| {
+ let done_count = Arc::new(AtomicUsize::new(0));
+ let len = evaluations.len();
+ for chunk in evaluations.chunks(evaluations.len() / 8) {
+ let chunk = chunk.to_vec();
+ let done_count = done_count.clone();
+ scope.spawn(async move {
+ for EvaluationProject { repo, sha, .. } in chunk {
+ eprint!(
+ "\rFetching evaluation repositories ({}/{})...",
+ done_count.load(SeqCst),
+ len,
+ );
+
+ fetch_eval_repo(repo, sha, repos_dir, http_client).await;
+ done_count.fetch_add(1, SeqCst);
+ }
+ });
+ }
+ })
+ .await;
+
+ Ok(())
+}
+
+async fn fetch_eval_repo(
+ repo: String,
+ sha: String,
+ repos_dir: &Path,
+ http_client: &dyn HttpClient,
+) {
+ let Some((owner, repo_name)) = repo.split_once('/') else {
+ return;
+ };
+ let repo_dir = repos_dir.join(owner).join(repo_name);
+ fs::create_dir_all(&repo_dir).unwrap();
+ let skip_eval_path = repo_dir.join(SKIP_EVAL_PATH);
+ if skip_eval_path.exists() {
+ return;
+ }
+ if let Ok(head_content) = fs::read_to_string(&repo_dir.join(".git").join("HEAD")) {
+ if head_content.trim() == sha {
+ return;
+ }
+ }
+ let repo_response = http_client
+ .send(
+ http_client::Request::builder()
+ .method(Method::HEAD)
+ .uri(format!("https://github.com/{}", repo))
+ .body(Default::default())
+ .expect(""),
+ )
+ .await
+ .expect("failed to check github repo");
+ if !repo_response.status().is_success() && !repo_response.status().is_redirection() {
+ fs::write(&skip_eval_path, "").unwrap();
+ eprintln!(
+ "Repo {repo} is no longer public ({:?}). Skipping",
+ repo_response.status()
+ );
+ return;
+ }
+ if !repo_dir.join(".git").exists() {
+ let init_output = Command::new("git")
+ .current_dir(&repo_dir)
+ .args(&["init"])
+ .output()
+ .unwrap();
+ if !init_output.status.success() {
+ eprintln!(
+ "Failed to initialize git repository for {}: {}",
+ repo,
+ String::from_utf8_lossy(&init_output.stderr)
+ );
+ return;
+ }
+ }
+ let url = format!("https://github.com/{}.git", repo);
+ Command::new("git")
+ .current_dir(&repo_dir)
+ .args(&["remote", "add", "-f", "origin", &url])
+ .stdin(Stdio::null())
+ .output()
+ .unwrap();
+ let fetch_output = Command::new("git")
+ .current_dir(&repo_dir)
+ .args(&["fetch", "--depth", "1", "origin", &sha])
+ .stdin(Stdio::null())
+ .output()
+ .unwrap();
+ if !fetch_output.status.success() {
+ eprintln!(
+ "Failed to fetch {} for {}: {}",
+ sha,
+ repo,
+ String::from_utf8_lossy(&fetch_output.stderr)
+ );
+ return;
+ }
+ let checkout_output = Command::new("git")
+ .current_dir(&repo_dir)
+ .args(&["checkout", &sha])
+ .output()
+ .unwrap();
+
+ if !checkout_output.status.success() {
+ eprintln!(
+ "Failed to checkout {} for {}: {}",
+ sha,
+ repo,
+ String::from_utf8_lossy(&checkout_output.stderr)
+ );
+ }
+}
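
For reference, the scoring loop in `run_evaluation` distinguishes an overlapped result from a covered one. A standalone sketch of that rule on plain row ranges (`classify` is a hypothetical helper, not part of the crate):

```rust
use std::ops::RangeInclusive;

// An expected result is "overlapped" when either endpoint of its line range
// falls inside the returned row range, and "covered" when both endpoints do.
fn classify(result: &RangeInclusive<u32>, expected: &RangeInclusive<u32>) -> (bool, bool) {
    let start_matched = result.contains(expected.start());
    let end_matched = result.contains(expected.end());
    (start_matched || end_matched, start_matched && end_matched)
}

fn main() {
    assert_eq!(classify(&(10..=40), &(12..=20)), (true, true));   // covered
    assert_eq!(classify(&(10..=15), &(12..=20)), (true, false));  // overlap only
    assert_eq!(classify(&(21..=30), &(12..=20)), (false, false)); // miss
}
```

Note that the rule only checks the expected endpoints, so an expected range that strictly contains a result is not counted as overlapped.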
@@ -5,6 +5,7 @@ use derive_more::Deref;
use futures::future::BoxFuture;
use futures_lite::FutureExt;
use isahc::config::{Configurable, RedirectPolicy};
+pub use isahc::http;
pub use isahc::{
http::{Method, StatusCode, Uri},
AsyncBody, Error, HttpClient as IsahcHttpClient, Request, Response,
@@ -226,7 +227,7 @@ pub fn client(user_agent: Option<String>, proxy: Option<Uri>) -> Arc<dyn HttpCli
// those requests use a different http client, because global timeouts
// of 50 and 60 seconds, respectively, would be very high!
.connect_timeout(Duration::from_secs(5))
- .low_speed_timeout(100, Duration::from_secs(5))
+ .low_speed_timeout(100, Duration::from_secs(30))
.proxy(proxy.clone());
if let Some(user_agent) = user_agent {
builder = builder.default_header("User-Agent", user_agent);
@@ -234,30 +234,25 @@ impl EmbeddingIndex {
cx.spawn(async {
while let Ok((entry, handle)) = entries.recv().await {
let entry_abs_path = worktree_abs_path.join(&entry.path);
- match fs.load(&entry_abs_path).await {
- Ok(text) => {
- let language = language_registry
- .language_for_file_path(&entry.path)
- .await
- .ok();
- let chunked_file = ChunkedFile {
- chunks: chunking::chunk_text(
- &text,
- language.as_ref(),
- &entry.path,
- ),
- handle,
- path: entry.path,
- mtime: entry.mtime,
- text,
- };
-
- if chunked_files_tx.send(chunked_file).await.is_err() {
- return;
- }
- }
- Err(_)=> {
- log::error!("Failed to read contents into a UTF-8 string: {entry_abs_path:?}");
+ if let Some(text) = fs.load(&entry_abs_path).await.ok() {
+ let language = language_registry
+ .language_for_file_path(&entry.path)
+ .await
+ .ok();
+ let chunked_file = ChunkedFile {
+ chunks: chunking::chunk_text(
+ &text,
+ language.as_ref(),
+ &entry.path,
+ ),
+ handle,
+ path: entry.path,
+ mtime: entry.mtime,
+ text,
+ };
+
+ if chunked_files_tx.send(chunked_file).await.is_err() {
+ return;
}
}
}
@@ -358,33 +353,37 @@ impl EmbeddingIndex {
fn persist_embeddings(
&self,
mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
- embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
+ mut embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
cx: &AppContext,
) -> Task<Result<()>> {
let db_connection = self.db_connection.clone();
let db = self.db;
- cx.background_executor().spawn(async move {
- while let Some(deletion_range) = deleted_entry_ranges.next().await {
- let mut txn = db_connection.write_txn()?;
- let start = deletion_range.0.as_ref().map(|start| start.as_str());
- let end = deletion_range.1.as_ref().map(|end| end.as_str());
- log::debug!("deleting embeddings in range {:?}", &(start, end));
- db.delete_range(&mut txn, &(start, end))?;
- txn.commit()?;
- }
- let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
- while let Some(embedded_files) = embedded_files.next().await {
- let mut txn = db_connection.write_txn()?;
- for (file, _) in &embedded_files {
- log::debug!("saving embedding for file {:?}", file.path);
- let key = db_key_for_path(&file.path);
- db.put(&mut txn, &key, file)?;
+ cx.background_executor().spawn(async move {
+ loop {
+ // Interleave deletions and persists of embedded files
+ futures::select_biased! {
+ deletion_range = deleted_entry_ranges.next() => {
+ if let Some(deletion_range) = deletion_range {
+ let mut txn = db_connection.write_txn()?;
+ let start = deletion_range.0.as_ref().map(|start| start.as_str());
+ let end = deletion_range.1.as_ref().map(|end| end.as_str());
+ log::debug!("deleting embeddings in range {:?}", &(start, end));
+ db.delete_range(&mut txn, &(start, end))?;
+ txn.commit()?;
+ }
+ },
+ file = embedded_files.next() => {
+ if let Some((file, _)) = file {
+ let mut txn = db_connection.write_txn()?;
+ log::debug!("saving embedding for file {:?}", file.path);
+ let key = db_key_for_path(&file.path);
+ db.put(&mut txn, &key, &file)?;
+ txn.commit()?;
+ }
+ },
+ complete => break,
}
- txn.commit()?;
-
- drop(embedded_files);
- log::debug!("committed");
}
Ok(())
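
The rewritten `persist_embeddings` interleaves the two channels with `futures::select_biased!` instead of draining all deletions before any embedded files. A standalone sketch of the same control flow, assuming two plain unbounded channels rather than the real index types:

```rust
use futures::{channel::mpsc, executor::block_on, StreamExt};

// `select_biased!` polls the deletion stream first whenever both arms are
// ready, and the `complete` arm breaks the loop once both streams terminate.
async fn drain(
    mut deletions: mpsc::UnboundedReceiver<&'static str>,
    mut writes: mpsc::UnboundedReceiver<&'static str>,
) -> Vec<&'static str> {
    let mut order = Vec::new();
    loop {
        futures::select_biased! {
            item = deletions.next() => if let Some(item) = item { order.push(item) },
            item = writes.next() => if let Some(item) = item { order.push(item) },
            complete => break,
        }
    }
    order
}

fn main() {
    let (del_tx, del_rx) = mpsc::unbounded();
    let (write_tx, write_rx) = mpsc::unbounded();
    del_tx.unbounded_send("delete range a..b").unwrap();
    write_tx.unbounded_send("persist embedding").unwrap();
    drop((del_tx, write_tx)); // closing both senders lets `complete` fire
    assert_eq!(
        block_on(drain(del_rx, write_rx)),
        ["delete range a..b", "persist embedding"]
    );
}
```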
@@ -15,7 +15,14 @@ use log;
use project::{Project, Worktree, WorktreeId};
use serde::{Deserialize, Serialize};
use smol::channel;
-use std::{cmp::Ordering, future::Future, num::NonZeroUsize, ops::Range, path::Path, sync::Arc};
+use std::{
+ cmp::Ordering,
+ future::Future,
+ num::NonZeroUsize,
+ ops::{Range, RangeInclusive},
+ path::{Path, PathBuf},
+ sync::Arc,
+};
use util::ResultExt;
#[derive(Debug)]
@@ -26,6 +33,14 @@ pub struct SearchResult {
pub score: f32,
}
+pub struct LoadedSearchResult {
+ pub path: Arc<Path>,
+ pub range: Range<usize>,
+ pub full_path: PathBuf,
+ pub file_content: String,
+ pub row_range: RangeInclusive<u32>,
+}
+
pub struct WorktreeSearchResult {
pub worktree_id: WorktreeId,
pub path: Arc<Path>,
@@ -10,14 +10,16 @@ mod worktree_index;
use anyhow::{Context as _, Result};
use collections::HashMap;
+use fs::Fs;
use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Context, Global, Model, WeakModel};
use project::Project;
-use project_index::ProjectIndex;
use std::{path::PathBuf, sync::Arc};
use ui::ViewContext;
+use util::ResultExt as _;
use workspace::Workspace;
pub use embedding::*;
+pub use project_index::{LoadedSearchResult, ProjectIndex, SearchResult, Status};
pub use project_index_debug_view::ProjectIndexDebugView;
pub use summary_index::FileSummary;
@@ -56,27 +58,7 @@ impl SemanticDb {
if cx.has_global::<SemanticDb>() {
cx.update_global::<SemanticDb, _>(|this, cx| {
- let project_index = cx.new_model(|cx| {
- ProjectIndex::new(
- project.clone(),
- this.db_connection.clone(),
- this.embedding_provider.clone(),
- cx,
- )
- });
-
- let project_weak = project.downgrade();
- this.project_indices
- .insert(project_weak.clone(), project_index);
-
- cx.on_release(move |_, _, cx| {
- if cx.has_global::<SemanticDb>() {
- cx.update_global::<SemanticDb, _>(|this, _| {
- this.project_indices.remove(&project_weak);
- })
- }
- })
- .detach();
+ this.create_project_index(project, cx);
})
} else {
log::info!("No SemanticDb, skipping project index")
@@ -94,6 +76,50 @@ impl SemanticDb {
})
}
+ pub async fn load_results(
+ results: Vec<SearchResult>,
+ fs: &Arc<dyn Fs>,
+ cx: &AsyncAppContext,
+ ) -> Result<Vec<LoadedSearchResult>> {
+ let mut loaded_results = Vec::new();
+ for result in results {
+ let (full_path, file_content) = result.worktree.read_with(cx, |worktree, _cx| {
+ let entry_abs_path = worktree.abs_path().join(&result.path);
+ let mut entry_full_path = PathBuf::from(worktree.root_name());
+ entry_full_path.push(&result.path);
+ let file_content = async {
+ let entry_abs_path = entry_abs_path;
+ fs.load(&entry_abs_path).await
+ };
+ (entry_full_path, file_content)
+ })?;
+ if let Some(file_content) = file_content.await.log_err() {
+ let range_start = result.range.start.min(file_content.len());
+ let range_end = result.range.end.min(file_content.len());
+
+ let start_row = file_content[0..range_start].matches('\n').count() as u32;
+ let end_row = file_content[0..range_end].matches('\n').count() as u32;
+ let start_line_byte_offset = file_content[0..range_start]
+ .rfind('\n')
+ .map(|pos| pos + 1)
+ .unwrap_or_default();
+ let end_line_byte_offset = file_content[range_end..]
+ .find('\n')
+ .map(|pos| range_end + pos)
+ .unwrap_or_else(|| file_content.len());
+
+ loaded_results.push(LoadedSearchResult {
+ path: result.path,
+ range: start_line_byte_offset..end_line_byte_offset,
+ full_path,
+ file_content,
+ row_range: start_row..=end_row,
+ });
+ }
+ }
+ Ok(loaded_results)
+ }
+
pub fn project_index(
&mut self,
project: Model<Project>,
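
`load_results` recomputes the same line math the search command previously did inline: each match's byte range is clamped to the file, widened to whole lines, and the corresponding 0-based row range is recorded. A small standalone sketch of that computation (`widen_to_lines` is a hypothetical helper mirroring the logic above):

```rust
use std::ops::{Range, RangeInclusive};

// Clamp the match to the file, expand it to full lines, and count newlines to
// get the 0-based rows it spans.
fn widen_to_lines(content: &str, range: Range<usize>) -> (Range<usize>, RangeInclusive<u32>) {
    let start = range.start.min(content.len());
    let end = range.end.min(content.len());
    let start_row = content[..start].matches('\n').count() as u32;
    let end_row = content[..end].matches('\n').count() as u32;
    let line_start = content[..start].rfind('\n').map(|pos| pos + 1).unwrap_or(0);
    let line_end = content[end..].find('\n').map(|pos| end + pos).unwrap_or(content.len());
    (line_start..line_end, start_row..=end_row)
}

fn main() {
    let text = "alpha\nbeta\ngamma\n";
    // A match inside "beta" expands to the whole second line (row 1).
    let (bytes, rows) = widen_to_lines(text, 7..9);
    assert_eq!(&text[bytes], "beta");
    assert_eq!(rows, 1..=1);
}
```

The search slash command then slices `file_content[range]` directly and uses `row_range` for the fence header and section label.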
@@ -113,6 +139,36 @@ impl SemanticDb {
})
})
}
+
+ pub fn create_project_index(
+ &mut self,
+ project: Model<Project>,
+ cx: &mut AppContext,
+ ) -> Model<ProjectIndex> {
+ let project_index = cx.new_model(|cx| {
+ ProjectIndex::new(
+ project.clone(),
+ self.db_connection.clone(),
+ self.embedding_provider.clone(),
+ cx,
+ )
+ });
+
+ let project_weak = project.downgrade();
+ self.project_indices
+ .insert(project_weak.clone(), project_index.clone());
+
+ cx.observe_release(&project, move |_, cx| {
+ if cx.has_global::<SemanticDb>() {
+ cx.update_global::<SemanticDb, _>(|this, _| {
+ this.project_indices.remove(&project_weak);
+ })
+ }
+ })
+ .detach();
+
+ project_index
+ }
}
#[cfg(test)]
@@ -230,34 +286,13 @@ mod tests {
let project = Project::test(fs, [project_path], cx).await;
- cx.update(|cx| {
+ let project_index = cx.update(|cx| {
let language_registry = project.read(cx).languages().clone();
let node_runtime = project.read(cx).node_runtime().unwrap().clone();
languages::init(language_registry, node_runtime, cx);
-
- // Manually create and insert the ProjectIndex
- let project_index = cx.new_model(|cx| {
- ProjectIndex::new(
- project.clone(),
- semantic_index.db_connection.clone(),
- semantic_index.embedding_provider.clone(),
- cx,
- )
- });
- semantic_index
- .project_indices
- .insert(project.downgrade(), project_index);
+ semantic_index.create_project_index(project.clone(), cx)
});
- let project_index = cx
- .update(|_cx| {
- semantic_index
- .project_indices
- .get(&project.downgrade())
- .cloned()
- })
- .unwrap();
-
cx.run_until_parked();
while cx
.update(|cx| semantic_index.remaining_summaries(&project.downgrade(), cx))