Semantic index eval: index the cloned repositories properly

KCaverly created

Change summary

Cargo.lock                                  |   4 
crates/semantic_index/Cargo.toml            |   4 
crates/semantic_index/eval/tree-sitter.json |   6 
crates/semantic_index/examples/eval.rs      | 202 ++++++++++++++++++----
crates/semantic_index/src/semantic_index.rs |   6 
5 files changed, 172 insertions(+), 50 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -6744,6 +6744,7 @@ dependencies = [
  "anyhow",
  "async-trait",
  "bincode",
+ "client",
  "collections",
  "ctor",
  "editor",
@@ -6757,6 +6758,7 @@ dependencies = [
  "lazy_static",
  "log",
  "matrixmultiply",
+ "node_runtime",
  "parking_lot 0.11.2",
  "parse_duration",
  "picker",
@@ -6766,6 +6768,7 @@ dependencies = [
  "rand 0.8.5",
  "rpc",
  "rusqlite",
+ "rust-embed",
  "schemars",
  "serde",
  "serde_json",
@@ -6788,6 +6791,7 @@ dependencies = [
  "unindent",
  "util",
  "workspace",
+ "zed",
 ]
 
 [[package]]

crates/semantic_index/Cargo.toml 🔗

@@ -51,6 +51,10 @@ rpc = { path = "../rpc", features = ["test-support"] }
 workspace = { path = "../workspace", features = ["test-support"] }
 settings = { path = "../settings", features = ["test-support"]}
 git2 = { version = "0.15"}
+rust-embed = { version = "8.0", features = ["include-exclude"] }
+client = { path = "../client" }
+zed = { path = "../zed"}
+node_runtime = { path = "../node_runtime"}
 
 pretty_assertions.workspace = true
 rand.workspace = true

crates/semantic_index/eval/tree-sitter.json 🔗

@@ -17,7 +17,7 @@
     {
       "query": "generate tags based on config",
       "matches": [
-        "tags/src/lib.rs:261",
+        "tags/src/lib.rs:261"
       ]
     },
     {
@@ -54,13 +54,13 @@
     {
       "query": "Match based on associativity of actions",
       "matches": [
-        "cri/src/generate/build_tables/build_parse_table.rs:542",
+        "cli/src/generate/build_tables/build_parse_table.rs:542"
       ]
     },
     {
       "query": "Format token set display",
       "matches": [
-        "cli/src/generate/build_tables/item.rs:246",
+        "cli/src/generate/build_tables/item.rs:246"
       ]
     },
     {

crates/semantic_index/examples/eval.rs 🔗

@@ -1,8 +1,46 @@
+use anyhow::{anyhow, Result};
+use client::{self, UserStore};
 use git2::{Object, Oid, Repository};
-use semantic_index::SearchResult;
+use gpui::{AppContext, AssetSource, ModelHandle, Task};
+use language::LanguageRegistry;
+use node_runtime::RealNodeRuntime;
+use project::{Fs, Project, RealFs};
+use rust_embed::RustEmbed;
+use semantic_index::embedding::OpenAIEmbeddings;
+use semantic_index::semantic_index_settings::SemanticIndexSettings;
+use semantic_index::{SearchResult, SemanticIndex};
 use serde::Deserialize;
-use std::path::{Path, PathBuf};
-use std::{env, fs};
+use settings::{default_settings, handle_settings_file_changes, watch_config_file, SettingsStore};
+use std::path::{self, Path, PathBuf};
+use std::sync::Arc;
+use std::time::Duration;
+use std::{cmp, env, fs};
+use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
+use util::http::{self, HttpClient};
+use util::paths::{self, EMBEDDINGS_DIR};
+use zed::languages;
+
+#[derive(RustEmbed)]
+#[folder = "../../assets"]
+#[include = "fonts/**/*"]
+#[include = "icons/**/*"]
+#[include = "themes/**/*"]
+#[include = "sounds/**/*"]
+#[include = "*.md"]
+#[exclude = "*.DS_Store"]
+pub struct Assets;
+
+impl AssetSource for Assets {
+    fn load(&self, path: &str) -> Result<std::borrow::Cow<[u8]>> {
+        Self::get(path)
+            .map(|f| f.data)
+            .ok_or_else(|| anyhow!("could not find asset at path \"{}\"", path))
+    }
+
+    fn list(&self, path: &str) -> Vec<std::borrow::Cow<'static, str>> {
+        Self::iter().filter(|p| p.starts_with(path)).collect()
+    }
+}
 
 #[derive(Deserialize, Clone)]
 struct EvaluationQuery {
@@ -13,15 +51,18 @@ struct EvaluationQuery {
 impl EvaluationQuery {
     fn match_pairs(&self) -> Vec<(PathBuf, usize)> {
         let mut pairs = Vec::new();
-        for match_identifier in self.matches {
-            let match_parts = match_identifier.split(":");
+        for match_identifier in self.matches.iter() {
+            let mut match_parts = match_identifier.split(":");
 
             if let Some(file_path) = match_parts.next() {
                 if let Some(row_number) = match_parts.next() {
-                    pairs.push((PathBuf::from(file_path), from_str::<usize>(row_number)));
+                    pairs.push((
+                        PathBuf::from(file_path),
+                        row_number.parse::<usize>().unwrap(),
+                    ));
                 }
             }
-
+        }
         pairs
     }
 }
@@ -33,7 +74,7 @@ struct RepoEval {
     assertions: Vec<EvaluationQuery>,
 }
 
-const TMP_REPO_PATH: &str = "./target/eval_repos";
+const TMP_REPO_PATH: &str = "eval_repos";
 
 fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
     let eval_folder = env::current_dir()?
@@ -74,7 +115,12 @@ fn clone_repo(repo_eval: RepoEval) -> anyhow::Result<PathBuf> {
         .unwrap()
         .to_owned()
         .replace(".git", "");
-    let clone_path = Path::new(TMP_REPO_PATH).join(&repo_name).to_path_buf();
+
+    let clone_path = fs::canonicalize(env::current_dir()?)?
+        .parent()
+        .ok_or(anyhow!("path canonicalization failed"))?
+        .join(TMP_REPO_PATH)
+        .join(&repo_name);
 
     // Delete Clone Path if already exists
     let _ = fs::remove_dir_all(&clone_path);
@@ -105,7 +151,6 @@ fn dcg(hits: Vec<usize>) -> f32 {
 }
 
 fn evaluate_ndcg(eval_query: EvaluationQuery, search_results: Vec<SearchResult>, k: usize) -> f32 {
-
     // NDCG or Normalized Discounted Cumulative Gain, is determined by comparing the relevance of
     // items returned by the search engine relative to the hypothetical ideal.
     // Relevance is represented as a series of booleans, in which each search result returned
@@ -125,47 +170,118 @@ fn evaluate_ndcg(eval_query: EvaluationQuery, search_results: Vec<SearchResult>,
     // very high quality, whereas rank results quickly drop off after the first result.
 
     let ideal = vec![1; cmp::min(eval_query.matches.len(), k)];
+    let hits = vec![1];
 
     return dcg(hits) / dcg(ideal);
 }
 
-fn evaluate_map(eval_query: EvaluationQuery, search_results: Vec<SearchResult>, k: usize) -> f32 {
-
-}
-
-fn evaluate_repo(repo_eval: RepoEval, clone_path: PathBuf) {
-
-    // Launch new repo as a new Zed workspace/project
-    // Index the project
-    // Search each eval_query
-    // Calculate Statistics
+// fn evaluate_map(eval_query: EvaluationQuery, search_results: Vec<SearchResult>, k: usize) -> f32 {}
 
+fn init_logger() {
+    env_logger::init();
 }
 
 fn main() {
-
-    // zed/main.rs
-    // creating an app and running it, gives you the context.
-    // create a project, find_or_create_local_worktree.
-
-    if let Ok(repo_evals) = parse_eval() {
-        for repo in repo_evals {
-            let cloned = clone_repo(repo.clone());
-            match cloned {
-                Ok(clone_path) => {
-                    println!(
-                        "Cloned {:?} @ {:?} into {:?}",
-                        repo.repo, repo.commit, &clone_path
-                    );
-
-                    // Evaluate Repo
-                    evaluate_repo(repo, clone_path);
-
-                }
-                Err(err) => {
-                    println!("Error Cloning: {:?}", err);
+    // Launch new repo as a new Zed workspace/project
+    let app = gpui::App::new(Assets).unwrap();
+    let fs = Arc::new(RealFs);
+    let http = http::client();
+    let user_settings_file_rx =
+        watch_config_file(app.background(), fs.clone(), paths::SETTINGS.clone());
+    let http_client = http::client();
+    init_logger();
+
+    app.run(move |cx| {
+        cx.set_global(*RELEASE_CHANNEL);
+
+        let client = client::Client::new(http.clone(), cx);
+        let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client.clone(), cx));
+
+        // Initialize Settings
+        let mut store = SettingsStore::default();
+        store
+            .set_default_settings(default_settings().as_ref(), cx)
+            .unwrap();
+        cx.set_global(store);
+        handle_settings_file_changes(user_settings_file_rx, cx);
+
+        // Initialize Languages
+        let login_shell_env_loaded = Task::ready(());
+        let mut languages = LanguageRegistry::new(login_shell_env_loaded);
+        languages.set_executor(cx.background().clone());
+        let languages = Arc::new(languages);
+
+        let node_runtime = RealNodeRuntime::new(http.clone());
+        languages::init(languages.clone(), node_runtime.clone());
+
+        project::Project::init(&client, cx);
+        semantic_index::init(fs.clone(), http.clone(), languages.clone(), cx);
+
+        settings::register::<SemanticIndexSettings>(cx);
+
+        let db_file_path = EMBEDDINGS_DIR
+            .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
+            .join("embeddings_db");
+
+        let languages = languages.clone();
+        let fs = fs.clone();
+        cx.spawn(|mut cx| async move {
+            let semantic_index = SemanticIndex::new(
+                fs.clone(),
+                db_file_path,
+                Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+                languages.clone(),
+                cx.clone(),
+            )
+            .await?;
+
+            if let Ok(repo_evals) = parse_eval() {
+                for repo in repo_evals {
+                    let cloned = clone_repo(repo.clone());
+                    match cloned {
+                        Ok(clone_path) => {
+                            log::trace!(
+                                "Cloned {:?} @ {:?} into {:?}",
+                                repo.repo,
+                                repo.commit,
+                                &clone_path
+                            );
+
+                            // Create Project
+                            let project = cx.update(|cx| {
+                                Project::local(
+                                    client.clone(),
+                                    user_store.clone(),
+                                    languages.clone(),
+                                    fs.clone(),
+                                    cx,
+                                )
+                            });
+
+                            // Register Worktree
+                            let _ = project
+                                .update(&mut cx, |project, cx| {
+                                    println!(
+                                        "Creating worktree in project: {:?}",
+                                        clone_path.clone()
+                                    );
+                                    project.find_or_create_local_worktree(clone_path, true, cx)
+                                })
+                                .await;
+
+                            let _ = semantic_index
+                                .update(&mut cx, |index, cx| index.index_project(project, cx))
+                                .await;
+                        }
+                        Err(err) => {
+                            log::trace!("Error cloning: {:?}", err);
+                        }
+                    }
                 }
             }
-        }
-    }
+
+            anyhow::Ok(())
+        })
+        .detach();
+    });
 }

crates/semantic_index/src/semantic_index.rs 🔗

@@ -1,5 +1,5 @@
 mod db;
-mod embedding;
+pub mod embedding;
 mod embedding_queue;
 mod parsing;
 pub mod semantic_index_settings;
@@ -301,7 +301,7 @@ impl SemanticIndex {
         }
     }
 
-    async fn new(
+    pub async fn new(
         fs: Arc<dyn Fs>,
         database_path: PathBuf,
         embedding_provider: Arc<dyn EmbeddingProvider>,
@@ -837,8 +837,6 @@ impl SemanticIndex {
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<()>> {
         if !self.projects.contains_key(&project.downgrade()) {
-            log::trace!("Registering Project for Semantic Index");
-
             let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
                 project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                     this.project_worktrees_changed(project.clone(), cx);