Port `semantic_index` to gpui2

Created by Antonio Scandurra and Julia Risley

Co-Authored-By: Julia Risley <julia@zed.dev>

Change summary

Cargo.lock                                            |   51 
Cargo.toml                                            |    3 
crates/ai2/src/auth.rs                                |    2 
crates/ai2/src/providers/open_ai/embedding.rs         |    4 
crates/gpui2/src/app/entity_map.rs                    |   19 
crates/semantic_index2/Cargo.toml                     |   69 
crates/semantic_index2/README.md                      |   20 
crates/semantic_index2/eval/gpt-engineer.json         |  114 
crates/semantic_index2/eval/tree-sitter.json          |  104 
crates/semantic_index2/src/db.rs                      |  603 ++++
crates/semantic_index2/src/embedding_queue.rs         |  169 +
crates/semantic_index2/src/parsing.rs                 |  414 +++
crates/semantic_index2/src/semantic_index.rs          | 1280 +++++++++
crates/semantic_index2/src/semantic_index_settings.rs |   28 
crates/semantic_index2/src/semantic_index_tests.rs    | 1697 +++++++++++++
crates/workspace2/src/workspace2.rs                   |    2 
16 files changed, 4569 insertions(+), 10 deletions(-)

Detailed changes

Cargo.lock

@@ -8232,6 +8232,57 @@ dependencies = [
  "workspace",
 ]
 
+[[package]]
+name = "semantic_index2"
+version = "0.1.0"
+dependencies = [
+ "ai2",
+ "anyhow",
+ "async-trait",
+ "client2",
+ "collections",
+ "ctor",
+ "env_logger 0.9.3",
+ "futures 0.3.28",
+ "globset",
+ "gpui2",
+ "language2",
+ "lazy_static",
+ "log",
+ "ndarray",
+ "node_runtime",
+ "ordered-float 2.10.0",
+ "parking_lot 0.11.2",
+ "postage",
+ "pretty_assertions",
+ "project2",
+ "rand 0.8.5",
+ "rpc2",
+ "rusqlite",
+ "rust-embed",
+ "schemars",
+ "serde",
+ "serde_json",
+ "settings2",
+ "sha1",
+ "smol",
+ "tempdir",
+ "tiktoken-rs",
+ "tree-sitter",
+ "tree-sitter-cpp",
+ "tree-sitter-elixir",
+ "tree-sitter-json 0.20.0",
+ "tree-sitter-lua",
+ "tree-sitter-php",
+ "tree-sitter-ruby",
+ "tree-sitter-rust",
+ "tree-sitter-toml",
+ "tree-sitter-typescript",
+ "unindent",
+ "util",
+ "workspace2",
+]
+
 [[package]]
 name = "semver"
 version = "1.0.18"

Cargo.toml

@@ -95,6 +95,8 @@ members = [
     "crates/rpc2",
     "crates/search",
     "crates/search2",
+    "crates/semantic_index",
+    "crates/semantic_index2",
     "crates/settings",
     "crates/settings2",
     "crates/snippet",
@@ -114,7 +116,6 @@ members = [
     "crates/theme_selector2",
     "crates/ui2",
     "crates/util",
-    "crates/semantic_index",
     "crates/story",
     "crates/vim",
     "crates/vcs_menu",

crates/ai2/src/auth.rs

@@ -7,7 +7,7 @@ pub enum ProviderCredential {
     NotNeeded,
 }
 
-pub trait CredentialProvider {
+pub trait CredentialProvider: Send + Sync {
     fn has_credentials(&self) -> bool;
     fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential;
     fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential);

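Making `CredentialProvider` require `Send + Sync` is what allows a boxed provider to be shared with background tasks. A minimal sketch of what the bound buys (the trimmed trait and function are illustrative, not the crate's API):

```rust
use std::sync::Arc;

// Trimmed-down stand-in for the real trait, for illustration only.
trait CredentialProvider: Send + Sync {
    fn has_credentials(&self) -> bool;
}

// `Arc<dyn CredentialProvider>` is only `Send` because the trait object is
// `Send + Sync`, so a provider can be consulted from another thread.
fn check_in_background(provider: Arc<dyn CredentialProvider>) {
    std::thread::spawn(move || {
        let _ = provider.has_credentials();
    });
}
```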
crates/ai2/src/providers/open_ai/embedding.rs

@@ -35,7 +35,7 @@ pub struct OpenAIEmbeddingProvider {
     model: OpenAILanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
     pub client: Arc<dyn HttpClient>,
-    pub executor: Arc<BackgroundExecutor>,
+    pub executor: BackgroundExecutor,
     rate_limit_count_rx: watch::Receiver<Option<Instant>>,
     rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
 }
@@ -66,7 +66,7 @@ struct OpenAIEmbeddingUsage {
 }
 
 impl OpenAIEmbeddingProvider {
-    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<BackgroundExecutor>) -> Self {
+    pub fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 

crates/gpui2/src/app/entity_map.rs

@@ -482,10 +482,6 @@ impl<T: 'static> WeakModel<T> {
     /// Update the entity referenced by this model with the given function if
     /// the referenced entity still exists. Returns an error if the entity has
     /// been released.
-    ///
-    /// The update function receives a context appropriate for its environment.
-    /// When updating in an `AppContext`, it receives a `ModelContext`.
-    /// When updating an a `WindowContext`, it receives a `ViewContext`.
     pub fn update<C, R>(
         &self,
         cx: &mut C,
@@ -501,6 +497,21 @@ impl<T: 'static> WeakModel<T> {
                 .map(|this| cx.update_model(&this, update)),
         )
     }
+
+    /// Reads the entity referenced by this model with the given function if
+    /// the referenced entity still exists. Returns an error if the entity has
+    /// been released.
+    pub fn read_with<C, R>(&self, cx: &C, read: impl FnOnce(&T, &AppContext) -> R) -> Result<R>
+    where
+        C: Context,
+        Result<C::Result<R>>: crate::Flatten<R>,
+    {
+        crate::Flatten::flatten(
+            self.upgrade()
+                .ok_or_else(|| anyhow!("entity release"))
+                .map(|this| cx.read_model(&this, read)),
+        )
+    }
 }
 
 impl<T> Hash for WeakModel<T> {

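The new `WeakModel::read_with` mirrors `update`, but borrows the context immutably and hands the closure a shared reference, returning an error if the entity was released. A hypothetical usage sketch (the `Counter` model is invented for illustration):

```rust
use gpui::{AppContext, WeakModel};

struct Counter {
    count: usize,
}

// Returns the current count, or `None` if the model has been released.
fn current_count(weak: &WeakModel<Counter>, cx: &AppContext) -> Option<usize> {
    weak.read_with(cx, |counter, _cx| counter.count).ok()
}
```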
crates/semantic_index2/Cargo.toml

@@ -0,0 +1,69 @@
+[package]
+name = "semantic_index2"
+version = "0.1.0"
+edition = "2021"
+publish = false
+
+[lib]
+path = "src/semantic_index.rs"
+doctest = false
+
+[dependencies]
+ai = { package = "ai2", path = "../ai2" }
+collections = { path = "../collections" }
+gpui = { package = "gpui2", path = "../gpui2" }
+language = { package = "language2", path = "../language2" }
+project = { package = "project2", path = "../project2" }
+workspace = { package = "workspace2", path = "../workspace2" }
+util = { path = "../util" }
+rpc = { package = "rpc2", path = "../rpc2" }
+settings = { package = "settings2", path = "../settings2" }
+anyhow.workspace = true
+postage.workspace = true
+futures.workspace = true
+ordered-float.workspace = true
+smol.workspace = true
+rusqlite.workspace = true
+log.workspace = true
+tree-sitter.workspace = true
+lazy_static.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+async-trait.workspace = true
+tiktoken-rs.workspace = true
+parking_lot.workspace = true
+rand.workspace = true
+schemars.workspace = true
+globset.workspace = true
+sha1 = "0.10.5"
+ndarray = { version = "0.15.0" }
+
+[dev-dependencies]
+ai = { package = "ai2", path = "../ai2", features = ["test-support"] }
+collections = { path = "../collections", features = ["test-support"] }
+gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
+language = { package = "language2", path = "../language2", features = ["test-support"] }
+project = { package = "project2", path = "../project2", features = ["test-support"] }
+rpc = { package = "rpc2", path = "../rpc2", features = ["test-support"] }
+workspace = { package = "workspace2", path = "../workspace2", features = ["test-support"] }
+settings = { package = "settings2", path = "../settings2", features = ["test-support"]}
+rust-embed = { version = "8.0", features = ["include-exclude"] }
+client = { package = "client2", path = "../client2" }
+node_runtime = { path = "../node_runtime"}
+
+pretty_assertions.workspace = true
+rand.workspace = true
+unindent.workspace = true
+tempdir.workspace = true
+ctor.workspace = true
+env_logger.workspace = true
+
+tree-sitter-typescript.workspace = true
+tree-sitter-json.workspace = true
+tree-sitter-rust.workspace = true
+tree-sitter-toml.workspace = true
+tree-sitter-cpp.workspace = true
+tree-sitter-elixir.workspace = true
+tree-sitter-lua.workspace = true
+tree-sitter-ruby.workspace = true
+tree-sitter-php.workspace = true

crates/semantic_index2/README.md

@@ -0,0 +1,20 @@
+
+# Semantic Index
+
+## Evaluation
+
+### Metrics
+
+nDCG@k:
+- "The value of NDCG is determined by comparing the relevance of the items returned by the search engine to the relevance of the item that a hypothetical "ideal" search engine would return.
+- "The relevance of result is represented by a score (also known as a 'grade') that is assigned to the search query. The scores of these results are then discounted based on their position in the search results -- did they get recommended first or last?"
+
+MRR@k:
+- "Mean reciprocal rank quantifies the rank of the first relevant item found in teh recommendation list."
+
+MAP@k:
+- "Mean average precision averages the precision@k metric at each relevant item position in the recommendation list.
+
+Resources:
+- [Evaluating recommendation metrics](https://www.shaped.ai/blog/evaluating-recommendation-systems-map-mmr-ndcg)
+- [Math Walkthrough](https://towardsdatascience.com/demystifying-ndcg-bee3be58cfe0)

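The quoted descriptions correspond to the standard formulas, where $rel_i \in \{0, 1\}$ marks whether the result at rank $i$ is relevant and $Q$ is the set of eval queries:

```latex
\mathrm{DCG@}k = \sum_{i=1}^{k} \frac{rel_i}{\log_2(i + 1)},
\qquad
\mathrm{nDCG@}k = \frac{\mathrm{DCG@}k}{\mathrm{IDCG@}k}

\mathrm{MRR@}k = \frac{1}{|Q|} \sum_{q \in Q} \frac{1}{\mathrm{rank}_q},
\qquad
\mathrm{MAP@}k = \frac{1}{|Q|} \sum_{q \in Q}
  \frac{\sum_{i=1}^{k} P_q(i)\, rel_q(i)}{\min(R_q, k)}
```

Here $\mathrm{IDCG@}k$ is the DCG of an ideal ordering, $\mathrm{rank}_q$ is the position of the first relevant result (a query contributes 0 when none appears in the top $k$), $P_q(i)$ is precision at cutoff $i$, and $R_q$ is the number of relevant items for query $q$.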
crates/semantic_index2/eval/gpt-engineer.json

@@ -0,0 +1,114 @@
+{
+  "repo": "https://github.com/AntonOsika/gpt-engineer.git",
+  "commit": "7735a6445bae3611c62f521e6464c67c957f87c2",
+  "assertions": [
+    {
+      "query": "How do I contribute to this project?",
+      "matches": [
+        ".github/CONTRIBUTING.md:1",
+        "ROADMAP.md:48"
+      ]
+    },
+    {
+      "query": "What version of the openai package is active?",
+      "matches": [
+        "pyproject.toml:14"
+      ]
+    },
+    {
+      "query": "Ask user for clarification",
+      "matches": [
+        "gpt_engineer/steps.py:69"
+      ]
+    },
+    {
+      "query": "generate tests for python code",
+      "matches": [
+        "gpt_engineer/steps.py:153"
+      ]
+    },
+    {
+      "query": "get item from database based on key",
+      "matches": [
+        "gpt_engineer/db.py:42",
+        "gpt_engineer/db.py:68"
+      ]
+    },
+    {
+      "query": "prompt user to select files",
+      "matches": [
+        "gpt_engineer/file_selector.py:171",
+        "gpt_engineer/file_selector.py:306",
+        "gpt_engineer/file_selector.py:289",
+        "gpt_engineer/file_selector.py:234"
+      ]
+    },
+    {
+      "query": "send to rudderstack",
+      "matches": [
+        "gpt_engineer/collect.py:11",
+        "gpt_engineer/collect.py:38"
+      ]
+    },
+    {
+      "query": "parse code blocks from chat messages",
+      "matches": [
+        "gpt_engineer/chat_to_files.py:10",
+        "docs/intro/chat_parsing.md:1"
+      ]
+    },
+    {
+      "query": "how do I use the docker cli?",
+      "matches": [
+        "docker/README.md:1"
+      ]
+    },
+    {
+      "query": "ask the user if the code ran successfully?",
+      "matches": [
+        "gpt_engineer/learning.py:54"
+      ]
+    },
+    {
+      "query": "how is consent granted by the user?",
+      "matches": [
+        "gpt_engineer/learning.py:107",
+        "gpt_engineer/learning.py:130",
+        "gpt_engineer/learning.py:152"
+      ]
+    },
+    {
+      "query": "what are all the different steps the agent can take?",
+      "matches": [
+        "docs/intro/steps_module.md:1",
+        "gpt_engineer/steps.py:391"
+      ]
+    },
+    {
+      "query": "ask the user for clarification?",
+      "matches": [
+        "gpt_engineer/steps.py:69"
+      ]
+    },
+    {
+      "query": "what models are available?",
+      "matches": [
+        "gpt_engineer/ai.py:315",
+        "gpt_engineer/ai.py:341",
+        "docs/open-models.md:1"
+      ]
+    },
+    {
+      "query": "what is the current focus of the project?",
+      "matches": [
+        "ROADMAP.md:11"
+      ]
+    },
+    {
+      "query": "does the agent know how to fix code?",
+      "matches": [
+        "gpt_engineer/steps.py:367"
+      ]
+    }
+  ]
+}

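Each eval file pins a repository at a commit and lists query→expected-match assertions. A hedged sketch of how the format could be deserialized (struct names are invented; the crate may consume these files differently):

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct EvalProject {
    repo: String,
    commit: String,
    assertions: Vec<EvalAssertion>,
}

#[derive(Debug, Deserialize)]
struct EvalAssertion {
    query: String,
    /// Expected matches, encoded as "relative/path.ext:line".
    matches: Vec<String>,
}

fn load_eval(json: &str) -> serde_json::Result<EvalProject> {
    serde_json::from_str(json)
}
```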
crates/semantic_index2/eval/tree-sitter.json

@@ -0,0 +1,104 @@
+{
+  "repo": "https://github.com/tree-sitter/tree-sitter.git",
+  "commit": "46af27796a76c72d8466627d499f2bca4af958ee",
+  "assertions": [
+    {
+      "query": "What attributes are available for the tags configuration struct?",
+      "matches": [
+        "tags/src/lib.rs:24"
+      ]
+    },
+    {
+      "query": "create a new tag configuration",
+      "matches": [
+        "tags/src/lib.rs:119"
+      ]
+    },
+    {
+      "query": "generate tags based on config",
+      "matches": [
+        "tags/src/lib.rs:261"
+      ]
+    },
+    {
+      "query": "match on ts quantifier in rust",
+      "matches": [
+        "lib/binding_rust/lib.rs:139"
+      ]
+    },
+    {
+      "query": "cli command to generate tags",
+      "matches": [
+        "cli/src/tags.rs:10"
+      ]
+    },
+    {
+      "query": "what version of the tree-sitter-tags package is active?",
+      "matches": [
+        "tags/Cargo.toml:4"
+      ]
+    },
+    {
+      "query": "Insert a new parse state",
+      "matches": [
+        "cli/src/generate/build_tables/build_parse_table.rs:153"
+      ]
+    },
+    {
+      "query": "Handle conflict when numerous actions occur on the same symbol",
+      "matches": [
+        "cli/src/generate/build_tables/build_parse_table.rs:363",
+        "cli/src/generate/build_tables/build_parse_table.rs:442"
+      ]
+    },
+    {
+      "query": "Match based on associativity of actions",
+      "matches": [
+        "cri/src/generate/build_tables/build_parse_table.rs:542"
+      ]
+    },
+    {
+      "query": "Format token set display",
+      "matches": [
+        "cli/src/generate/build_tables/item.rs:246"
+      ]
+    },
+    {
+      "query": "extract choices from rule",
+      "matches": [
+        "cli/src/generate/prepare_grammar/flatten_grammar.rs:124"
+      ]
+    },
+    {
+      "query": "How do we identify if a symbol is being used?",
+      "matches": [
+        "cli/src/generate/prepare_grammar/flatten_grammar.rs:175"
+      ]
+    },
+    {
+      "query": "How do we launch the playground?",
+      "matches": [
+        "cli/src/playground.rs:46"
+      ]
+    },
+    {
+      "query": "How do we test treesitter query matches in rust?",
+      "matches": [
+        "cli/src/query_testing.rs:152",
+        "cli/src/tests/query_test.rs:781",
+        "cli/src/tests/query_test.rs:2163",
+        "cli/src/tests/query_test.rs:3781",
+        "cli/src/tests/query_test.rs:887"
+      ]
+    },
+    {
+      "query": "What does the CLI do?",
+      "matches": [
+        "cli/README.md:10",
+        "cli/loader/README.md:3",
+        "docs/section-5-implementation.md:14",
+        "docs/section-5-implementation.md:18"
+      ]
+    }
+  ]
+}

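The `matches` entries pack a relative path and a line number into one string. A small helper one might use when scoring retrieved spans against these assertions (illustrative only; it assumes the line number is everything after the last `:`):

```rust
// Splits "tags/src/lib.rs:24" into ("tags/src/lib.rs", 24).
fn parse_match(entry: &str) -> Option<(&str, u32)> {
    let (path, line) = entry.rsplit_once(':')?;
    Some((path, line.parse().ok()?))
}

#[test]
fn parses_eval_match() {
    assert_eq!(parse_match("tags/src/lib.rs:24"), Some(("tags/src/lib.rs", 24)));
}
```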
crates/semantic_index2/src/db.rs

@@ -0,0 +1,603 @@
+use crate::{
+    parsing::{Span, SpanDigest},
+    SEMANTIC_INDEX_VERSION,
+};
+use ai::embedding::Embedding;
+use anyhow::{anyhow, Context, Result};
+use collections::HashMap;
+use futures::channel::oneshot;
+use gpui::BackgroundExecutor;
+use ndarray::{Array1, Array2};
+use ordered_float::OrderedFloat;
+use project::Fs;
+use rpc::proto::Timestamp;
+use rusqlite::params;
+use rusqlite::types::Value;
+use std::{
+    future::Future,
+    ops::Range,
+    path::{Path, PathBuf},
+    rc::Rc,
+    sync::Arc,
+    time::SystemTime,
+};
+use util::{paths::PathMatcher, TryFutureExt};
+
+pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
+    let mut indices = (0..data.len()).collect::<Vec<_>>();
+    indices.sort_by_key(|&i| &data[i]);
+    indices.reverse();
+    indices
+}
+
+#[derive(Debug)]
+pub struct FileRecord {
+    pub id: usize,
+    pub relative_path: String,
+    pub mtime: Timestamp,
+}
+
+#[derive(Clone)]
+pub struct VectorDatabase {
+    path: Arc<Path>,
+    transactions:
+        smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
+}
+
+impl VectorDatabase {
+    pub async fn new(
+        fs: Arc<dyn Fs>,
+        path: Arc<Path>,
+        executor: BackgroundExecutor,
+    ) -> Result<Self> {
+        if let Some(db_directory) = path.parent() {
+            fs.create_dir(db_directory).await?;
+        }
+
+        let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
+            Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
+        >();
+        executor
+            .spawn({
+                let path = path.clone();
+                async move {
+                    let mut connection = rusqlite::Connection::open(&path)?;
+
+                    connection.pragma_update(None, "journal_mode", "wal")?;
+                    connection.pragma_update(None, "synchronous", "normal")?;
+                    connection.pragma_update(None, "cache_size", 1000000)?;
+                    connection.pragma_update(None, "temp_store", "MEMORY")?;
+
+                    while let Ok(transaction) = transactions_rx.recv().await {
+                        transaction(&mut connection);
+                    }
+
+                    anyhow::Ok(())
+                }
+                .log_err()
+            })
+            .detach();
+        let this = Self {
+            transactions: transactions_tx,
+            path,
+        };
+        this.initialize_database().await?;
+        Ok(this)
+    }
+
+    pub fn path(&self) -> &Arc<Path> {
+        &self.path
+    }
+
+    fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
+    where
+        F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
+        T: 'static + Send,
+    {
+        let (tx, rx) = oneshot::channel();
+        let transactions = self.transactions.clone();
+        async move {
+            if transactions
+                .send(Box::new(|connection| {
+                    let result = connection
+                        .transaction()
+                        .map_err(|err| anyhow!(err))
+                        .and_then(|transaction| {
+                            let result = f(&transaction)?;
+                            transaction.commit()?;
+                            Ok(result)
+                        });
+                    let _ = tx.send(result);
+                }))
+                .await
+                .is_err()
+            {
+                return Err(anyhow!("connection was dropped"));
+            }
+            rx.await?
+        }
+    }
+
+    fn initialize_database(&self) -> impl Future<Output = Result<()>> {
+        self.transact(|db| {
+            rusqlite::vtab::array::load_module(&db)?;
+
+            // Delete the existing tables if SEMANTIC_INDEX_VERSION has been bumped
+            let version_query = db.prepare("SELECT version from semantic_index_config");
+            let version = version_query
+                .and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?)));
+            if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
+                log::trace!("vector database schema up to date");
+                return Ok(());
+            }
+
+            log::trace!("vector database schema out of date. updating...");
+            // We renamed the `documents` table to `spans`, so we want to drop
+            // `documents` without recreating it if it exists.
+            db.execute("DROP TABLE IF EXISTS documents", [])
+                .context("failed to drop 'documents' table")?;
+            db.execute("DROP TABLE IF EXISTS spans", [])
+                .context("failed to drop 'spans' table")?;
+            db.execute("DROP TABLE IF EXISTS files", [])
+                .context("failed to drop 'files' table")?;
+            db.execute("DROP TABLE IF EXISTS worktrees", [])
+                .context("failed to drop 'worktrees' table")?;
+            db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
+                .context("failed to drop 'semantic_index_config' table")?;
+
+            // Initialize the vector database tables
+            db.execute(
+                "CREATE TABLE semantic_index_config (
+                    version INTEGER NOT NULL
+                )",
+                [],
+            )?;
+
+            db.execute(
+                "INSERT INTO semantic_index_config (version) VALUES (?1)",
+                params![SEMANTIC_INDEX_VERSION],
+            )?;
+
+            db.execute(
+                "CREATE TABLE worktrees (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    absolute_path VARCHAR NOT NULL
+                );
+                CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
+                ",
+                [],
+            )?;
+
+            db.execute(
+                "CREATE TABLE files (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    worktree_id INTEGER NOT NULL,
+                    relative_path VARCHAR NOT NULL,
+                    mtime_seconds INTEGER NOT NULL,
+                    mtime_nanos INTEGER NOT NULL,
+                    FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
+                )",
+                [],
+            )?;
+
+            db.execute(
+                "CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
+                [],
+            )?;
+
+            db.execute(
+                "CREATE TABLE spans (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    file_id INTEGER NOT NULL,
+                    start_byte INTEGER NOT NULL,
+                    end_byte INTEGER NOT NULL,
+                    name VARCHAR NOT NULL,
+                    embedding BLOB NOT NULL,
+                    digest BLOB NOT NULL,
+                    FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
+                )",
+                [],
+            )?;
+            db.execute(
+                "CREATE INDEX spans_digest ON spans (digest)",
+                [],
+            )?;
+
+            log::trace!("vector database initialized with updated schema.");
+            Ok(())
+        })
+    }
+
+    pub fn delete_file(
+        &self,
+        worktree_id: i64,
+        delete_path: Arc<Path>,
+    ) -> impl Future<Output = Result<()>> {
+        self.transact(move |db| {
+            db.execute(
+                "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
+                params![worktree_id, delete_path.to_str()],
+            )?;
+            Ok(())
+        })
+    }
+
+    pub fn insert_file(
+        &self,
+        worktree_id: i64,
+        path: Arc<Path>,
+        mtime: SystemTime,
+        spans: Vec<Span>,
+    ) -> impl Future<Output = Result<()>> {
+        self.transact(move |db| {
+            // Return the existing ID if both the file and mtime match
+            let mtime = Timestamp::from(mtime);
+
+            db.execute(
+                "
+                REPLACE INTO files
+                (worktree_id, relative_path, mtime_seconds, mtime_nanos)
+                VALUES (?1, ?2, ?3, ?4)
+                ",
+                params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
+            )?;
+
+            let file_id = db.last_insert_rowid();
+
+            let mut query = db.prepare(
+                "
+                INSERT INTO spans
+                (file_id, start_byte, end_byte, name, embedding, digest)
+                VALUES (?1, ?2, ?3, ?4, ?5, ?6)
+                ",
+            )?;
+
+            for span in spans {
+                query.execute(params![
+                    file_id,
+                    span.range.start.to_string(),
+                    span.range.end.to_string(),
+                    span.name,
+                    span.embedding,
+                    span.digest
+                ])?;
+            }
+
+            Ok(())
+        })
+    }
+
+    pub fn worktree_previously_indexed(
+        &self,
+        worktree_root_path: &Path,
+    ) -> impl Future<Output = Result<bool>> {
+        let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
+        self.transact(move |db| {
+            let mut worktree_query =
+                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
+            let worktree_id = worktree_query
+                .query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?));
+
+            Ok(worktree_id.is_ok())
+        })
+    }
+
+    pub fn embeddings_for_digests(
+        &self,
+        digests: Vec<SpanDigest>,
+    ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
+        self.transact(move |db| {
+            let mut query = db.prepare(
+                "
+                SELECT digest, embedding
+                FROM spans
+                WHERE digest IN rarray(?)
+                ",
+            )?;
+            let mut embeddings_by_digest = HashMap::default();
+            let digests = Rc::new(
+                digests
+                    .into_iter()
+                    .map(|p| Value::Blob(p.0.to_vec()))
+                    .collect::<Vec<_>>(),
+            );
+            let rows = query.query_map(params![digests], |row| {
+                Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
+            })?;
+
+            for row in rows {
+                if let Ok(row) = row {
+                    embeddings_by_digest.insert(row.0, row.1);
+                }
+            }
+
+            Ok(embeddings_by_digest)
+        })
+    }
+
+    pub fn embeddings_for_files(
+        &self,
+        worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
+    ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
+        self.transact(move |db| {
+            let mut query = db.prepare(
+                "
+                SELECT digest, embedding
+                FROM spans
+                LEFT JOIN files ON files.id = spans.file_id
+                WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
+            ",
+            )?;
+            let mut embeddings_by_digest = HashMap::default();
+            for (worktree_id, file_paths) in worktree_id_file_paths {
+                let file_paths = Rc::new(
+                    file_paths
+                        .into_iter()
+                        .map(|p| Value::Text(p.to_string_lossy().into_owned()))
+                        .collect::<Vec<_>>(),
+                );
+                let rows = query.query_map(params![worktree_id, file_paths], |row| {
+                    Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
+                })?;
+
+                for row in rows {
+                    if let Ok(row) = row {
+                        embeddings_by_digest.insert(row.0, row.1);
+                    }
+                }
+            }
+
+            Ok(embeddings_by_digest)
+        })
+    }
+
+    pub fn find_or_create_worktree(
+        &self,
+        worktree_root_path: Arc<Path>,
+    ) -> impl Future<Output = Result<i64>> {
+        self.transact(move |db| {
+            let mut worktree_query =
+                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
+            let worktree_id = worktree_query
+                .query_row(params![worktree_root_path.to_string_lossy()], |row| {
+                    Ok(row.get::<_, i64>(0)?)
+                });
+
+            if worktree_id.is_ok() {
+                return Ok(worktree_id?);
+            }
+
+            // If worktree_id is Err, insert new worktree
+            db.execute(
+                "INSERT into worktrees (absolute_path) VALUES (?1)",
+                params![worktree_root_path.to_string_lossy()],
+            )?;
+            Ok(db.last_insert_rowid())
+        })
+    }
+
+    pub fn get_file_mtimes(
+        &self,
+        worktree_id: i64,
+    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
+        self.transact(move |db| {
+            let mut statement = db.prepare(
+                "
+                SELECT relative_path, mtime_seconds, mtime_nanos
+                FROM files
+                WHERE worktree_id = ?1
+                ORDER BY relative_path",
+            )?;
+            let mut result: HashMap<PathBuf, SystemTime> = HashMap::default();
+            for row in statement.query_map(params![worktree_id], |row| {
+                Ok((
+                    row.get::<_, String>(0)?.into(),
+                    Timestamp {
+                        seconds: row.get(1)?,
+                        nanos: row.get(2)?,
+                    }
+                    .into(),
+                ))
+            })? {
+                let row = row?;
+                result.insert(row.0, row.1);
+            }
+            Ok(result)
+        })
+    }
+
+    pub fn top_k_search(
+        &self,
+        query_embedding: &Embedding,
+        limit: usize,
+        file_ids: &[i64],
+    ) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
+        let file_ids = file_ids.to_vec();
+        let query = query_embedding.clone().0;
+        let query = Array1::from_vec(query);
+        self.transact(move |db| {
+            let mut query_statement = db.prepare(
+                "
+                    SELECT
+                        id, embedding
+                    FROM
+                        spans
+                    WHERE
+                        file_id IN rarray(?)
+                    ",
+            )?;
+
+            let deserialized_rows = query_statement
+                .query_map(params![ids_to_sql(&file_ids)], |row| {
+                    Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
+                })?
+                .filter_map(|row| row.ok())
+                .collect::<Vec<(usize, Embedding)>>();
+
+            if deserialized_rows.is_empty() {
+                return Ok(Vec::new());
+            }
+
+            // Get Length of Embeddings Returned
+            let embedding_len = deserialized_rows[0].1 .0.len();
+
+            let batch_n = 1000;
+            let mut batches = Vec::new();
+            let mut batch_ids = Vec::new();
+            let mut batch_embeddings: Vec<f32> = Vec::new();
+            deserialized_rows.iter().for_each(|(id, embedding)| {
+                batch_ids.push(id);
+                batch_embeddings.extend(&embedding.0);
+
+                if batch_ids.len() == batch_n {
+                    let embeddings = std::mem::take(&mut batch_embeddings);
+                    let ids = std::mem::take(&mut batch_ids);
+                    let array = Array2::from_shape_vec((ids.len(), embedding_len), embeddings);
+                    match array {
+                        Ok(array) => {
+                            batches.push((ids, array));
+                        }
+                        Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
+                    }
+                }
+            });
+
+            if !batch_ids.is_empty() {
+                let array = Array2::from_shape_vec(
+                    (batch_ids.len(), embedding_len),
+                    batch_embeddings.clone(),
+                );
+                match array {
+                    Ok(array) => {
+                        batches.push((batch_ids.clone(), array));
+                    }
+                    Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
+                }
+            }
+
+            let mut ids: Vec<usize> = Vec::new();
+            let mut results = Vec::new();
+            for (batch_ids, array) in batches {
+                let scores = array
+                    .dot(&query.t())
+                    .to_vec()
+                    .iter()
+                    .map(|score| OrderedFloat(*score))
+                    .collect::<Vec<OrderedFloat<f32>>>();
+                results.extend(scores);
+                ids.extend(batch_ids);
+            }
+
+            let sorted_idx = argsort(&results);
+            let mut sorted_results = Vec::new();
+            let last_idx = limit.min(sorted_idx.len());
+            for idx in &sorted_idx[0..last_idx] {
+                sorted_results.push((ids[*idx] as i64, results[*idx]))
+            }
+
+            Ok(sorted_results)
+        })
+    }
+
+    pub fn retrieve_included_file_ids(
+        &self,
+        worktree_ids: &[i64],
+        includes: &[PathMatcher],
+        excludes: &[PathMatcher],
+    ) -> impl Future<Output = Result<Vec<i64>>> {
+        let worktree_ids = worktree_ids.to_vec();
+        let includes = includes.to_vec();
+        let excludes = excludes.to_vec();
+        self.transact(move |db| {
+            let mut file_query = db.prepare(
+                "
+                SELECT
+                    id, relative_path
+                FROM
+                    files
+                WHERE
+                    worktree_id IN rarray(?)
+                ",
+            )?;
+
+            let mut file_ids = Vec::<i64>::new();
+            let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
+
+            while let Some(row) = rows.next()? {
+                let file_id = row.get(0)?;
+                let relative_path = row.get_ref(1)?.as_str()?;
+                let included =
+                    includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
+                let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
+                if included && !excluded {
+                    file_ids.push(file_id);
+                }
+            }
+
+            anyhow::Ok(file_ids)
+        })
+    }
+
+    pub fn spans_for_ids(
+        &self,
+        ids: &[i64],
+    ) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
+        let ids = ids.to_vec();
+        self.transact(move |db| {
+            let mut statement = db.prepare(
+                "
+                    SELECT
+                        spans.id,
+                        files.worktree_id,
+                        files.relative_path,
+                        spans.start_byte,
+                        spans.end_byte
+                    FROM
+                        spans, files
+                    WHERE
+                        spans.file_id = files.id AND
+                        spans.id in rarray(?)
+                ",
+            )?;
+
+            let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
+                Ok((
+                    row.get::<_, i64>(0)?,
+                    row.get::<_, i64>(1)?,
+                    row.get::<_, String>(2)?.into(),
+                    row.get(3)?..row.get(4)?,
+                ))
+            })?;
+
+            let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
+            for row in result_iter {
+                let (id, worktree_id, path, range) = row?;
+                values_by_id.insert(id, (worktree_id, path, range));
+            }
+
+            let mut results = Vec::with_capacity(ids.len());
+            for id in &ids {
+                let value = values_by_id
+                    .remove(id)
+                    .ok_or(anyhow!("missing span id {}", id))?;
+                results.push(value);
+            }
+
+            Ok(results)
+        })
+    }
+}
+
+fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
+    Rc::new(
+        ids.iter()
+            .copied()
+            .map(|v| rusqlite::types::Value::from(v))
+            .collect::<Vec<_>>(),
+    )
+}

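Two details of `db.rs` are worth calling out. All database access is funneled through the `transactions` channel onto a single background task that owns the `rusqlite::Connection`, so callers get `Future`s while SQLite sees strictly serialized transactions. And `top_k_search` scores candidates in bulk: up to 1000 embeddings at a time are packed into an `Array2` and scored with one matrix-vector product (for unit-length embeddings, a dot product equals cosine similarity). A minimal sketch of that scoring step, with invented names:

```rust
use ndarray::{Array1, Array2};
use ordered_float::OrderedFloat;

// Scores one batch of flattened embeddings against a query vector.
// `flat` must hold `ids.len() * dim` values in row-major order.
fn score_batch(
    ids: &[usize],
    flat: Vec<f32>,
    dim: usize,
    query: &Array1<f32>,
) -> Vec<(usize, OrderedFloat<f32>)> {
    let batch = Array2::from_shape_vec((ids.len(), dim), flat)
        .expect("flat.len() must equal ids.len() * dim");
    // One matrix-vector product yields every similarity in the batch.
    let scores = batch.dot(query);
    ids.iter()
        .copied()
        .zip(scores.into_iter().map(OrderedFloat))
        .collect()
}
```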
crates/semantic_index2/src/embedding_queue.rs

@@ -0,0 +1,169 @@
+use crate::{parsing::Span, JobHandle};
+use ai::embedding::EmbeddingProvider;
+use gpui::BackgroundExecutor;
+use parking_lot::Mutex;
+use smol::channel;
+use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime};
+
+#[derive(Clone)]
+pub struct FileToEmbed {
+    pub worktree_id: i64,
+    pub path: Arc<Path>,
+    pub mtime: SystemTime,
+    pub spans: Vec<Span>,
+    pub job_handle: JobHandle,
+}
+
+impl std::fmt::Debug for FileToEmbed {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("FileToEmbed")
+            .field("worktree_id", &self.worktree_id)
+            .field("path", &self.path)
+            .field("mtime", &self.mtime)
+            .field("spans", &self.spans)
+            .finish_non_exhaustive()
+    }
+}
+
+impl PartialEq for FileToEmbed {
+    fn eq(&self, other: &Self) -> bool {
+        self.worktree_id == other.worktree_id
+            && self.path == other.path
+            && self.mtime == other.mtime
+            && self.spans == other.spans
+    }
+}
+
+pub struct EmbeddingQueue {
+    embedding_provider: Arc<dyn EmbeddingProvider>,
+    pending_batch: Vec<FileFragmentToEmbed>,
+    executor: BackgroundExecutor,
+    pending_batch_token_count: usize,
+    finished_files_tx: channel::Sender<FileToEmbed>,
+    finished_files_rx: channel::Receiver<FileToEmbed>,
+}
+
+#[derive(Clone)]
+pub struct FileFragmentToEmbed {
+    file: Arc<Mutex<FileToEmbed>>,
+    span_range: Range<usize>,
+}
+
+impl EmbeddingQueue {
+    pub fn new(
+        embedding_provider: Arc<dyn EmbeddingProvider>,
+        executor: BackgroundExecutor,
+    ) -> Self {
+        let (finished_files_tx, finished_files_rx) = channel::unbounded();
+        Self {
+            embedding_provider,
+            executor,
+            pending_batch: Vec::new(),
+            pending_batch_token_count: 0,
+            finished_files_tx,
+            finished_files_rx,
+        }
+    }
+
+    pub fn push(&mut self, file: FileToEmbed) {
+        if file.spans.is_empty() {
+            self.finished_files_tx.try_send(file).unwrap();
+            return;
+        }
+
+        let file = Arc::new(Mutex::new(file));
+
+        self.pending_batch.push(FileFragmentToEmbed {
+            file: file.clone(),
+            span_range: 0..0,
+        });
+
+        let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
+        for (ix, span) in file.lock().spans.iter().enumerate() {
+            let span_token_count = if span.embedding.is_none() {
+                span.token_count
+            } else {
+                0
+            };
+
+            let next_token_count = self.pending_batch_token_count + span_token_count;
+            if next_token_count > self.embedding_provider.max_tokens_per_batch() {
+                let range_end = fragment_range.end;
+                self.flush();
+                self.pending_batch.push(FileFragmentToEmbed {
+                    file: file.clone(),
+                    span_range: range_end..range_end,
+                });
+                fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
+            }
+
+            fragment_range.end = ix + 1;
+            self.pending_batch_token_count += span_token_count;
+        }
+    }
+
+    pub fn flush(&mut self) {
+        let batch = mem::take(&mut self.pending_batch);
+        self.pending_batch_token_count = 0;
+        if batch.is_empty() {
+            return;
+        }
+
+        let finished_files_tx = self.finished_files_tx.clone();
+        let embedding_provider = self.embedding_provider.clone();
+
+        self.executor
+            .spawn(async move {
+                let mut spans = Vec::new();
+                for fragment in &batch {
+                    let file = fragment.file.lock();
+                    spans.extend(
+                        file.spans[fragment.span_range.clone()]
+                            .iter()
+                            .filter(|d| d.embedding.is_none())
+                            .map(|d| d.content.clone()),
+                    );
+                }
+
+                // If there are no spans to embed, just send each file to the finished channel once we drop its last fragment.
+                if spans.is_empty() {
+                    for fragment in batch.clone() {
+                        if let Some(file) = Arc::into_inner(fragment.file) {
+                            finished_files_tx.try_send(file.into_inner()).unwrap();
+                        }
+                    }
+                    return;
+                };
+
+                match embedding_provider.embed_batch(spans).await {
+                    Ok(embeddings) => {
+                        let mut embeddings = embeddings.into_iter();
+                        for fragment in batch {
+                            for span in &mut fragment.file.lock().spans[fragment.span_range.clone()]
+                                .iter_mut()
+                                .filter(|d| d.embedding.is_none())
+                            {
+                                if let Some(embedding) = embeddings.next() {
+                                    span.embedding = Some(embedding);
+                                } else {
+                                    log::error!("number of embeddings != number of documents");
+                                }
+                            }
+
+                            if let Some(file) = Arc::into_inner(fragment.file) {
+                                finished_files_tx.try_send(file.into_inner()).unwrap();
+                            }
+                        }
+                    }
+                    Err(error) => {
+                        log::error!("{:?}", error);
+                    }
+                }
+            })
+            .detach();
+    }
+
+    pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
+        self.finished_files_rx.clone()
+    }
+}

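`EmbeddingQueue::push` walks a file's spans and cuts a new fragment whenever the accumulated token count would exceed `max_tokens_per_batch`, so a large file can span several provider calls while spans that already have embeddings cost nothing. The cut rule, restated as a standalone sketch (simplified: the real queue also flushes the in-flight batch at each cut and tracks fragments per file):

```rust
use std::ops::Range;

// Groups spans (represented by their token counts) into contiguous batches
// that each stay within the provider's per-batch token budget.
fn plan_batches(token_counts: &[usize], max_tokens_per_batch: usize) -> Vec<Range<usize>> {
    let mut batches = Vec::new();
    let mut start = 0;
    let mut budget = 0;
    for (ix, &count) in token_counts.iter().enumerate() {
        if budget + count > max_tokens_per_batch && ix > start {
            batches.push(start..ix);
            start = ix;
            budget = 0;
        }
        budget += count;
    }
    if start < token_counts.len() {
        batches.push(start..token_counts.len());
    }
    batches
}
```

For example, `plan_batches(&[400, 300, 500], 1000)` yields `[0..2, 2..3]`.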
crates/semantic_index2/src/parsing.rs

@@ -0,0 +1,414 @@
+use ai::{
+    embedding::{Embedding, EmbeddingProvider},
+    models::TruncationDirection,
+};
+use anyhow::{anyhow, Result};
+use language::{Grammar, Language};
+use rusqlite::{
+    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
+    ToSql,
+};
+use sha1::{Digest, Sha1};
+use std::{
+    borrow::Cow,
+    cmp::{self, Reverse},
+    collections::HashSet,
+    ops::Range,
+    path::Path,
+    sync::Arc,
+};
+use tree_sitter::{Parser, QueryCursor};
+
+#[derive(Debug, PartialEq, Eq, Clone, Hash)]
+pub struct SpanDigest(pub [u8; 20]);
+
+impl FromSql for SpanDigest {
+    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+        let blob = value.as_blob()?;
+        let bytes =
+            blob.try_into()
+                .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
+                    expected_size: 20,
+                    blob_size: blob.len(),
+                })?;
+        Ok(SpanDigest(bytes))
+    }
+}
+
+impl ToSql for SpanDigest {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
+        self.0.to_sql()
+    }
+}
+
+impl From<&'_ str> for SpanDigest {
+    fn from(value: &'_ str) -> Self {
+        let mut sha1 = Sha1::new();
+        sha1.update(value);
+        Self(sha1.finalize().into())
+    }
+}
+
+#[derive(Debug, PartialEq, Clone)]
+pub struct Span {
+    pub name: String,
+    pub range: Range<usize>,
+    pub content: String,
+    pub embedding: Option<Embedding>,
+    pub digest: SpanDigest,
+    pub token_count: usize,
+}
+
+const CODE_CONTEXT_TEMPLATE: &str =
+    "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
+const ENTIRE_FILE_TEMPLATE: &str =
+    "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
+const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
+pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &[
+    "TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML", "Scheme",
+];
+
+pub struct CodeContextRetriever {
+    pub parser: Parser,
+    pub cursor: QueryCursor,
+    pub embedding_provider: Arc<dyn EmbeddingProvider>,
+}
+
+// Every match has an item; this represents the fundamental tree-sitter symbol and anchors the search.
+// Every match has one or more 'name' captures. These indicate the display range of the item for deduplication.
+// If there are preceding comments, we track them with a context capture.
+// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture.
+// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture.
+#[derive(Debug, Clone)]
+pub struct CodeContextMatch {
+    pub start_col: usize,
+    pub item_range: Option<Range<usize>>,
+    pub name_range: Option<Range<usize>>,
+    pub context_ranges: Vec<Range<usize>>,
+    pub collapse_ranges: Vec<Range<usize>>,
+}
+
+impl CodeContextRetriever {
+    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
+        Self {
+            parser: Parser::new(),
+            cursor: QueryCursor::new(),
+            embedding_provider,
+        }
+    }
+
+    fn parse_entire_file(
+        &self,
+        relative_path: Option<&Path>,
+        language_name: Arc<str>,
+        content: &str,
+    ) -> Result<Vec<Span>> {
+        let document_span = ENTIRE_FILE_TEMPLATE
+            .replace(
+                "<path>",
+                &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
+            )
+            .replace("<language>", language_name.as_ref())
+            .replace("<item>", &content);
+        let digest = SpanDigest::from(document_span.as_str());
+        let model = self.embedding_provider.base_model();
+        let document_span = model.truncate(
+            &document_span,
+            model.capacity()?,
+            ai::models::TruncationDirection::End,
+        )?;
+        let token_count = model.count_tokens(&document_span)?;
+
+        Ok(vec![Span {
+            range: 0..content.len(),
+            content: document_span,
+            embedding: Default::default(),
+            name: language_name.to_string(),
+            digest,
+            token_count,
+        }])
+    }
+
+    fn parse_markdown_file(
+        &self,
+        relative_path: Option<&Path>,
+        content: &str,
+    ) -> Result<Vec<Span>> {
+        let document_span = MARKDOWN_CONTEXT_TEMPLATE
+            .replace(
+                "<path>",
+                &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
+            )
+            .replace("<item>", &content);
+        let digest = SpanDigest::from(document_span.as_str());
+
+        let model = self.embedding_provider.base_model();
+        let document_span = model.truncate(
+            &document_span,
+            model.capacity()?,
+            ai::models::TruncationDirection::End,
+        )?;
+        let token_count = model.count_tokens(&document_span)?;
+
+        Ok(vec![Span {
+            range: 0..content.len(),
+            content: document_span,
+            embedding: None,
+            name: "Markdown".to_string(),
+            digest,
+            token_count,
+        }])
+    }
+
+    fn get_matches_in_file(
+        &mut self,
+        content: &str,
+        grammar: &Arc<Grammar>,
+    ) -> Result<Vec<CodeContextMatch>> {
+        let embedding_config = grammar
+            .embedding_config
+            .as_ref()
+            .ok_or_else(|| anyhow!("no embedding queries"))?;
+        self.parser.set_language(grammar.ts_language).unwrap();
+
+        let tree = self
+            .parser
+            .parse(&content, None)
+            .ok_or_else(|| anyhow!("parsing failed"))?;
+
+        let mut captures: Vec<CodeContextMatch> = Vec::new();
+        let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
+        let mut keep_ranges: Vec<Range<usize>> = Vec::new();
+        for mat in self.cursor.matches(
+            &embedding_config.query,
+            tree.root_node(),
+            content.as_bytes(),
+        ) {
+            let mut start_col = 0;
+            let mut item_range: Option<Range<usize>> = None;
+            let mut name_range: Option<Range<usize>> = None;
+            let mut context_ranges: Vec<Range<usize>> = Vec::new();
+            collapse_ranges.clear();
+            keep_ranges.clear();
+            for capture in mat.captures {
+                if capture.index == embedding_config.item_capture_ix {
+                    item_range = Some(capture.node.byte_range());
+                    start_col = capture.node.start_position().column;
+                } else if Some(capture.index) == embedding_config.name_capture_ix {
+                    name_range = Some(capture.node.byte_range());
+                } else if Some(capture.index) == embedding_config.context_capture_ix {
+                    context_ranges.push(capture.node.byte_range());
+                } else if Some(capture.index) == embedding_config.collapse_capture_ix {
+                    collapse_ranges.push(capture.node.byte_range());
+                } else if Some(capture.index) == embedding_config.keep_capture_ix {
+                    keep_ranges.push(capture.node.byte_range());
+                }
+            }
+
+            captures.push(CodeContextMatch {
+                start_col,
+                item_range,
+                name_range,
+                context_ranges,
+                collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
+            });
+        }
+        Ok(captures)
+    }
+
+    pub fn parse_file_with_template(
+        &mut self,
+        relative_path: Option<&Path>,
+        content: &str,
+        language: Arc<Language>,
+    ) -> Result<Vec<Span>> {
+        let language_name = language.name();
+
+        if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
+            return self.parse_entire_file(relative_path, language_name, &content);
+        } else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
+            return self.parse_markdown_file(relative_path, &content);
+        }
+
+        let mut spans = self.parse_file(content, language)?;
+        for span in &mut spans {
+            let document_content = CODE_CONTEXT_TEMPLATE
+                .replace(
+                    "<path>",
+                    &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
+                )
+                .replace("<language>", language_name.as_ref())
+                .replace("item", &span.content);
+
+            let model = self.embedding_provider.base_model();
+            let document_content = model.truncate(
+                &document_content,
+                model.capacity()?,
+                TruncationDirection::End,
+            )?;
+            let token_count = model.count_tokens(&document_content)?;
+
+            span.content = document_content;
+            span.token_count = token_count;
+        }
+        Ok(spans)
+    }
+
+    pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
+        let grammar = language
+            .grammar()
+            .ok_or_else(|| anyhow!("no grammar for language"))?;
+
+        // Iterate through query matches
+        let matches = self.get_matches_in_file(content, grammar)?;
+
+        let language_scope = language.default_scope();
+        let placeholder = language_scope.collapsed_placeholder();
+
+        let mut spans = Vec::new();
+        let mut collapsed_ranges_within = Vec::new();
+        let mut parsed_name_ranges = HashSet::new();
+        for (i, context_match) in matches.iter().enumerate() {
+            // Items which are collapsible but not embeddable have no item range
+            let item_range = if let Some(item_range) = context_match.item_range.clone() {
+                item_range
+            } else {
+                continue;
+            };
+
+            // Checks for deduplication
+            let name;
+            if let Some(name_range) = context_match.name_range.clone() {
+                name = content
+                    .get(name_range.clone())
+                    .map_or(String::new(), |s| s.to_string());
+                if parsed_name_ranges.contains(&name_range) {
+                    continue;
+                }
+                parsed_name_ranges.insert(name_range);
+            } else {
+                name = String::new();
+            }
+
+            collapsed_ranges_within.clear();
+            'outer: for remaining_match in &matches[(i + 1)..] {
+                for collapsed_range in &remaining_match.collapse_ranges {
+                    if item_range.start <= collapsed_range.start
+                        && item_range.end >= collapsed_range.end
+                    {
+                        collapsed_ranges_within.push(collapsed_range.clone());
+                    } else {
+                        break 'outer;
+                    }
+                }
+            }
+
+            collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
+
+            let mut span_content = String::new();
+            for context_range in &context_match.context_ranges {
+                add_content_from_range(
+                    &mut span_content,
+                    content,
+                    context_range.clone(),
+                    context_match.start_col,
+                );
+                span_content.push_str("\n");
+            }
+
+            let mut offset = item_range.start;
+            for collapsed_range in &collapsed_ranges_within {
+                if collapsed_range.start > offset {
+                    add_content_from_range(
+                        &mut span_content,
+                        content,
+                        offset..collapsed_range.start,
+                        context_match.start_col,
+                    );
+                    offset = collapsed_range.start;
+                }
+
+                if collapsed_range.end > offset {
+                    span_content.push_str(placeholder);
+                    offset = collapsed_range.end;
+                }
+            }
+
+            if offset < item_range.end {
+                add_content_from_range(
+                    &mut span_content,
+                    content,
+                    offset..item_range.end,
+                    context_match.start_col,
+                );
+            }
+
+            let sha1 = SpanDigest::from(span_content.as_str());
+            spans.push(Span {
+                name,
+                content: span_content,
+                range: item_range.clone(),
+                embedding: None,
+                digest: sha1,
+                token_count: 0,
+            })
+        }
+
+        Ok(spans)
+    }
+}
+
+pub(crate) fn subtract_ranges(
+    ranges: &[Range<usize>],
+    ranges_to_subtract: &[Range<usize>],
+) -> Vec<Range<usize>> {
+    let mut result = Vec::new();
+
+    let mut ranges_to_subtract = ranges_to_subtract.iter().peekable();
+
+    for range in ranges {
+        let mut offset = range.start;
+
+        while offset < range.end {
+            if let Some(range_to_subtract) = ranges_to_subtract.peek() {
+                if offset < range_to_subtract.start {
+                    let next_offset = cmp::min(range_to_subtract.start, range.end);
+                    result.push(offset..next_offset);
+                    offset = next_offset;
+                } else {
+                    let next_offset = cmp::min(range_to_subtract.end, range.end);
+                    offset = next_offset;
+                }
+
+                if offset >= range_to_subtract.end {
+                    ranges_to_subtract.next();
+                }
+            } else {
+                result.push(offset..range.end);
+                offset = range.end;
+            }
+        }
+    }
+
+    result
+}
+
+fn add_content_from_range(
+    output: &mut String,
+    content: &str,
+    range: Range<usize>,
+    start_col: usize,
+) {
+    for mut line in content.get(range.clone()).unwrap_or("").lines() {
+        for _ in 0..start_col {
+            if line.starts_with(' ') {
+                line = &line[1..];
+            } else {
+                break;
+            }
+        }
+        output.push_str(line);
+        output.push('\n');
+    }
+    output.pop();
+}

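`subtract_ranges` is a merge-style walk over two sorted, non-overlapping range lists, emitting the pieces of `ranges` not covered by `ranges_to_subtract`; it is how `keep` captures punch holes in `collapse` captures. A concrete check of the behavior, derived by tracing the loop above:

```rust
assert_eq!(
    subtract_ranges(&[0..10, 20..30], &[3..5, 25..35]),
    vec![0..3, 5..10, 20..25],
);
```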
crates/semantic_index2/src/semantic_index.rs

@@ -0,0 +1,1280 @@
+mod db;
+mod embedding_queue;
+mod parsing;
+pub mod semantic_index_settings;
+
+#[cfg(test)]
+mod semantic_index_tests;
+
+use crate::semantic_index_settings::SemanticIndexSettings;
+use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
+use anyhow::{anyhow, Context as _, Result};
+use collections::{BTreeMap, HashMap, HashSet};
+use db::VectorDatabase;
+use embedding_queue::{EmbeddingQueue, FileToEmbed};
+use futures::{future, FutureExt, StreamExt};
+use gpui::{
+    AppContext, AsyncAppContext, BorrowWindow, Context, Model, ModelContext, Task, ViewContext,
+    WeakModel,
+};
+use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
+use lazy_static::lazy_static;
+use ordered_float::OrderedFloat;
+use parking_lot::Mutex;
+use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
+use postage::watch;
+use project::{Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
+use settings::Settings;
+use smol::channel;
+use std::{
+    cmp::Reverse,
+    env,
+    future::Future,
+    mem,
+    ops::Range,
+    path::{Path, PathBuf},
+    sync::{Arc, Weak},
+    time::{Duration, Instant, SystemTime},
+};
+use util::paths::PathMatcher;
+use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
+use workspace::Workspace;
+
+const SEMANTIC_INDEX_VERSION: usize = 11;
+const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
+const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
+
+lazy_static! {
+    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
+}
+
+pub fn init(
+    fs: Arc<dyn Fs>,
+    http_client: Arc<dyn HttpClient>,
+    language_registry: Arc<LanguageRegistry>,
+    cx: &mut AppContext,
+) {
+    SemanticIndexSettings::register(cx);
+
+    let db_file_path = EMBEDDINGS_DIR
+        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
+        .join("embeddings_db");
+
+    cx.observe_new_views(
+        |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
+            let Some(semantic_index) = SemanticIndex::global(cx) else {
+                return;
+            };
+            let project = workspace.project().clone();
+
+            if project.read(cx).is_local() {
+                cx.app_mut()
+                    .spawn(|mut cx| async move {
+                        let previously_indexed = semantic_index
+                            .update(&mut cx, |index, cx| {
+                                index.project_previously_indexed(&project, cx)
+                            })?
+                            .await?;
+                        if previously_indexed {
+                            semantic_index
+                                .update(&mut cx, |index, cx| index.index_project(project, cx))?
+                                .await?;
+                        }
+                        anyhow::Ok(())
+                    })
+                    .detach_and_log_err(cx);
+            }
+        },
+    )
+    .detach();
+
+    cx.spawn(move |cx| async move {
+        let semantic_index = SemanticIndex::new(
+            fs,
+            db_file_path,
+            Arc::new(OpenAIEmbeddingProvider::new(
+                http_client,
+                cx.background_executor().clone(),
+            )),
+            language_registry,
+            cx.clone(),
+        )
+        .await?;
+
+        cx.update(|cx| cx.set_global(semantic_index.clone()))?;
+
+        anyhow::Ok(())
+    })
+    .detach();
+}
+
+#[derive(Copy, Clone, Debug)]
+pub enum SemanticIndexStatus {
+    NotAuthenticated,
+    NotIndexed,
+    Indexed,
+    Indexing {
+        remaining_files: usize,
+        rate_limit_expiry: Option<Instant>,
+    },
+}
+
+pub struct SemanticIndex {
+    fs: Arc<dyn Fs>,
+    db: VectorDatabase,
+    embedding_provider: Arc<dyn EmbeddingProvider>,
+    language_registry: Arc<LanguageRegistry>,
+    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
+    _embedding_task: Task<()>,
+    _parsing_files_tasks: Vec<Task<()>>,
+    projects: HashMap<WeakModel<Project>, ProjectState>,
+}
+
+struct ProjectState {
+    worktrees: HashMap<WorktreeId, WorktreeState>,
+    pending_file_count_rx: watch::Receiver<usize>,
+    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
+    pending_index: usize,
+    _subscription: gpui::Subscription,
+    _observe_pending_file_count: Task<()>,
+}
+
+enum WorktreeState {
+    Registering(RegisteringWorktreeState),
+    Registered(RegisteredWorktreeState),
+}
+
+impl WorktreeState {
+    fn is_registered(&self) -> bool {
+        matches!(self, Self::Registered(_))
+    }
+
+    fn paths_changed(
+        &mut self,
+        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
+        worktree: &Worktree,
+    ) {
+        let changed_paths = match self {
+            Self::Registering(state) => &mut state.changed_paths,
+            Self::Registered(state) => &mut state.changed_paths,
+        };
+
+        for (path, entry_id, change) in changes.iter() {
+            let Some(entry) = worktree.entry_for_id(*entry_id) else {
+                continue;
+            };
+            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
+                continue;
+            }
+            changed_paths.insert(
+                path.clone(),
+                ChangedPathInfo {
+                    mtime: entry.mtime,
+                    is_deleted: *change == PathChange::Removed,
+                },
+            );
+        }
+    }
+}
+
+struct RegisteringWorktreeState {
+    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
+    done_rx: watch::Receiver<Option<()>>,
+    _registration: Task<()>,
+}
+
+impl RegisteringWorktreeState {
+    fn done(&self) -> impl Future<Output = ()> {
+        let mut done_rx = self.done_rx.clone();
+        async move {
+            while let Some(result) = done_rx.next().await {
+                if result.is_some() {
+                    break;
+                }
+            }
+        }
+    }
+}
+
+struct RegisteredWorktreeState {
+    db_id: i64,
+    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
+}
+
+struct ChangedPathInfo {
+    mtime: SystemTime,
+    is_deleted: bool,
+}
+
+#[derive(Clone)]
+pub struct JobHandle {
+    /// The outer Arc is here to count the clones of a JobHandle instance;
+    /// when the last handle to a given job is dropped, we decrement a counter (just once).
+    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
+}
+
+impl JobHandle {
+    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
+        *tx.lock().borrow_mut() += 1;
+        Self {
+            tx: Arc::new(Arc::downgrade(&tx)),
+        }
+    }
+}
+
+impl ProjectState {
+    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
+        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
+        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
+        Self {
+            worktrees: Default::default(),
+            pending_file_count_rx: pending_file_count_rx.clone(),
+            pending_file_count_tx,
+            pending_index: 0,
+            _subscription: subscription,
+            _observe_pending_file_count: cx.spawn({
+                let mut pending_file_count_rx = pending_file_count_rx.clone();
+                |this, mut cx| async move {
+                    while let Some(_) = pending_file_count_rx.next().await {
+                        if this.update(&mut cx, |_, cx| cx.notify()).is_err() {
+                            break;
+                        }
+                    }
+                }
+            }),
+        }
+    }
+
+    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
+        self.worktrees
+            .iter()
+            .find_map(|(worktree_id, worktree_state)| match worktree_state {
+                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
+                _ => None,
+            })
+    }
+}
+
+#[derive(Clone)]
+pub struct PendingFile {
+    worktree_db_id: i64,
+    relative_path: Arc<Path>,
+    absolute_path: PathBuf,
+    language: Option<Arc<Language>>,
+    modified_time: SystemTime,
+    job_handle: JobHandle,
+}
+
+#[derive(Clone)]
+pub struct SearchResult {
+    pub buffer: Model<Buffer>,
+    pub range: Range<Anchor>,
+    pub similarity: OrderedFloat<f32>,
+}
+
+impl SemanticIndex {
+    pub fn global(cx: &mut AppContext) -> Option<Model<SemanticIndex>> {
+        if cx.has_global::<Model<Self>>() {
+            Some(cx.global::<Model<SemanticIndex>>().clone())
+        } else {
+            None
+        }
+    }
+
+    pub fn authenticate(&mut self, cx: &mut AppContext) -> bool {
+        if !self.embedding_provider.has_credentials() {
+            self.embedding_provider.retrieve_credentials(cx);
+        } else {
+            return true;
+        }
+
+        self.embedding_provider.has_credentials()
+    }
+
+    pub fn is_authenticated(&self) -> bool {
+        self.embedding_provider.has_credentials()
+    }
+
+    pub fn enabled(cx: &AppContext) -> bool {
+        SemanticIndexSettings::get_global(cx).enabled
+    }
+
+    pub fn status(&self, project: &Model<Project>) -> SemanticIndexStatus {
+        if !self.is_authenticated() {
+            return SemanticIndexStatus::NotAuthenticated;
+        }
+
+        if let Some(project_state) = self.projects.get(&project.downgrade()) {
+            if project_state
+                .worktrees
+                .values()
+                .all(|worktree| worktree.is_registered())
+                && project_state.pending_index == 0
+            {
+                SemanticIndexStatus::Indexed
+            } else {
+                SemanticIndexStatus::Indexing {
+                    remaining_files: *project_state.pending_file_count_rx.borrow(),
+                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
+                }
+            }
+        } else {
+            SemanticIndexStatus::NotIndexed
+        }
+    }
+
+    pub async fn new(
+        fs: Arc<dyn Fs>,
+        database_path: PathBuf,
+        embedding_provider: Arc<dyn EmbeddingProvider>,
+        language_registry: Arc<LanguageRegistry>,
+        mut cx: AsyncAppContext,
+    ) -> Result<Model<Self>> {
+        let t0 = Instant::now();
+        let database_path = Arc::from(database_path);
+        let db = VectorDatabase::new(fs.clone(), database_path, cx.background_executor().clone())
+            .await?;
+
+        log::trace!(
+            "db initialization took {:?} milliseconds",
+            t0.elapsed().as_millis()
+        );
+
+        cx.build_model(|cx| {
+            let t0 = Instant::now();
+            let embedding_queue =
+                EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor().clone());
+            let _embedding_task = cx.background_executor().spawn({
+                let embedded_files = embedding_queue.finished_files();
+                let db = db.clone();
+                async move {
+                    while let Ok(file) = embedded_files.recv().await {
+                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
+                            .await
+                            .log_err();
+                    }
+                }
+            });
+
+            // Parse files into embeddable spans.
+            let (parsing_files_tx, parsing_files_rx) =
+                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
+            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
+            let mut _parsing_files_tasks = Vec::new();
+            for _ in 0..cx.background_executor().num_cpus() {
+                let fs = fs.clone();
+                let mut parsing_files_rx = parsing_files_rx.clone();
+                let embedding_provider = embedding_provider.clone();
+                let embedding_queue = embedding_queue.clone();
+                let background = cx.background_executor().clone();
+                _parsing_files_tasks.push(cx.background_executor().spawn(async move {
+                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
+                    loop {
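+                        // If no file arrives within the flush timeout, flush the
+                        // embedding queue so partially filled batches still get embedded.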
+                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
+                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
+                        futures::select_biased! {
+                            next_file_to_parse = next_file_to_parse => {
+                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
+                                    Self::parse_file(
+                                        &fs,
+                                        pending_file,
+                                        &mut retriever,
+                                        &embedding_queue,
+                                        &embeddings_for_digest,
+                                    )
+                                    .await
+                                } else {
+                                    break;
+                                }
+                            },
+                            _ = timer => {
+                                embedding_queue.lock().flush();
+                            }
+                        }
+                    }
+                }));
+            }
+
+            log::trace!(
+                "semantic index task initialization took {:?} milliseconds",
+                t0.elapsed().as_millis()
+            );
+            Self {
+                fs,
+                db,
+                embedding_provider,
+                language_registry,
+                parsing_files_tx,
+                _embedding_task,
+                _parsing_files_tasks,
+                projects: Default::default(),
+            }
+        })
+    }
+
+    async fn parse_file(
+        fs: &Arc<dyn Fs>,
+        pending_file: PendingFile,
+        retriever: &mut CodeContextRetriever,
+        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
+        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
+    ) {
+        let Some(language) = pending_file.language else {
+            return;
+        };
+
+        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
+            if let Some(mut spans) = retriever
+                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
+                .log_err()
+            {
+                log::trace!(
+                    "parsed path {:?}: {} spans",
+                    pending_file.relative_path,
+                    spans.len()
+                );
+
+                for span in &mut spans {
+                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
+                        span.embedding = Some(embedding.to_owned());
+                    }
+                }
+
+                embedding_queue.lock().push(FileToEmbed {
+                    worktree_id: pending_file.worktree_db_id,
+                    path: pending_file.relative_path,
+                    mtime: pending_file.modified_time,
+                    job_handle: pending_file.job_handle,
+                    spans,
+                });
+            }
+        }
+    }
+
+    pub fn project_previously_indexed(
+        &mut self,
+        project: &Model<Project>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<bool>> {
+        let worktrees_indexed_previously = project
+            .read(cx)
+            .worktrees()
+            .map(|worktree| {
+                self.db
+                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
+            })
+            .collect::<Vec<_>>();
+        cx.spawn(|_, _cx| async move {
+            let worktree_indexed_previously =
+                futures::future::join_all(worktrees_indexed_previously).await;
+
+            Ok(worktree_indexed_previously
+                .iter()
+                .filter(|worktree| worktree.is_ok())
+                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
+        })
+    }
+
+    fn project_entries_changed(
+        &mut self,
+        project: Model<Project>,
+        worktree_id: WorktreeId,
+        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
+        cx: &mut ModelContext<Self>,
+    ) {
+        let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) else {
+            return;
+        };
+        let project = project.downgrade();
+        let Some(project_state) = self.projects.get_mut(&project) else {
+            return;
+        };
+
+        let worktree = worktree.read(cx);
+        let worktree_state =
+            if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
+                worktree_state
+            } else {
+                return;
+            };
+        worktree_state.paths_changed(changes, worktree);
+        if let WorktreeState::Registered(_) = worktree_state {
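+            // Debounce: wait out the background indexing delay before re-indexing the project.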
+            cx.spawn(|this, mut cx| async move {
+                cx.background_executor()
+                    .timer(BACKGROUND_INDEXING_DELAY)
+                    .await;
+                if let Some((this, project)) = this.upgrade().zip(project.upgrade()) {
+                    this.update(&mut cx, |this, cx| {
+                        this.index_project(project, cx).detach_and_log_err(cx)
+                    })?;
+                }
+                anyhow::Ok(())
+            })
+            .detach_and_log_err(cx);
+        }
+    }
+
+    fn register_worktree(
+        &mut self,
+        project: Model<Project>,
+        worktree: Model<Worktree>,
+        cx: &mut ModelContext<Self>,
+    ) {
+        let project = project.downgrade();
+        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
+            project_state
+        } else {
+            return;
+        };
+        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
+            worktree
+        } else {
+            return;
+        };
+        let worktree_abs_path = worktree.abs_path().clone();
+        let scan_complete = worktree.scan_complete();
+        let worktree_id = worktree.id();
+        let db = self.db.clone();
+        let language_registry = self.language_registry.clone();
+        let (mut done_tx, done_rx) = watch::channel();
+        let registration = cx.spawn(|this, mut cx| {
+            async move {
+                let register = async {
+                    scan_complete.await;
+                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
+                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
+                    let worktree = if let Some(project) = project.upgrade() {
+                        project
+                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
+                            .ok()
+                            .flatten()
+                            .context("worktree not found")?
+                    } else {
+                        return anyhow::Ok(());
+                    };
+                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot())?;
+                    let mut changed_paths = cx
+                        .background_executor()
+                        .spawn(async move {
+                            let mut changed_paths = BTreeMap::new();
+                            for file in worktree.files(false, 0) {
+                                let absolute_path = worktree.absolutize(&file.path);
+
+                                if file.is_external || file.is_ignored || file.is_symlink {
+                                    continue;
+                                }
+
+                                if let Ok(language) = language_registry
+                                    .language_for_file(&absolute_path, None)
+                                    .await
+                                {
+                                    // Skip files that can't be parsed into embeddable spans.
+                                    if !PARSEABLE_ENTIRE_FILE_TYPES
+                                        .contains(&language.name().as_ref())
+                                        && language.name().as_ref() != "Markdown"
+                                        && language
+                                            .grammar()
+                                            .and_then(|grammar| grammar.embedding_config.as_ref())
+                                            .is_none()
+                                    {
+                                        continue;
+                                    }
+
+                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
+                                    let already_stored = stored_mtime
+                                        .map_or(false, |existing_mtime| {
+                                            existing_mtime == file.mtime
+                                        });
+
+                                    if !already_stored {
+                                        changed_paths.insert(
+                                            file.path.clone(),
+                                            ChangedPathInfo {
+                                                mtime: file.mtime,
+                                                is_deleted: false,
+                                            },
+                                        );
+                                    }
+                                }
+                            }
+
+                            // Clean up entries from database that are no longer in the worktree.
+                            for (path, mtime) in file_mtimes {
+                                changed_paths.insert(
+                                    path.into(),
+                                    ChangedPathInfo {
+                                        mtime,
+                                        is_deleted: true,
+                                    },
+                                );
+                            }
+
+                            anyhow::Ok(changed_paths)
+                        })
+                        .await?;
+                    this.update(&mut cx, |this, cx| {
+                        let project_state = this
+                            .projects
+                            .get_mut(&project)
+                            .context("project not registered")?;
+                        let project = project.upgrade().context("project was dropped")?;
+
+                        if let Some(WorktreeState::Registering(state)) =
+                            project_state.worktrees.remove(&worktree_id)
+                        {
+                            changed_paths.extend(state.changed_paths);
+                        }
+                        project_state.worktrees.insert(
+                            worktree_id,
+                            WorktreeState::Registered(RegisteredWorktreeState {
+                                db_id,
+                                changed_paths,
+                            }),
+                        );
+                        this.index_project(project, cx).detach_and_log_err(cx);
+
+                        anyhow::Ok(())
+                    })??;
+
+                    anyhow::Ok(())
+                };
+
+                if register.await.log_err().is_none() {
+                    // Stop tracking this worktree if the registration failed.
+                    this.update(&mut cx, |this, _| {
+                        if let Some(project_state) = this.projects.get_mut(&project) {
+                            project_state.worktrees.remove(&worktree_id);
+                        }
+                    })
+                    .ok();
+                }
+
+                *done_tx.borrow_mut() = Some(());
+            }
+        });
+        project_state.worktrees.insert(
+            worktree_id,
+            WorktreeState::Registering(RegisteringWorktreeState {
+                changed_paths: Default::default(),
+                done_rx,
+                _registration: registration,
+            }),
+        );
+    }
+
+    fn project_worktrees_changed(&mut self, project: Model<Project>, cx: &mut ModelContext<Self>) {
+        let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
+        {
+            project_state
+        } else {
+            return;
+        };
+
+        let mut worktrees = project
+            .read(cx)
+            .worktrees()
+            .filter(|worktree| worktree.read(cx).is_local())
+            .collect::<Vec<_>>();
+        let worktree_ids = worktrees
+            .iter()
+            .map(|worktree| worktree.read(cx).id())
+            .collect::<HashSet<_>>();
+
+        // Remove worktrees that are no longer present
+        project_state
+            .worktrees
+            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
+
+        // Register new worktrees
+        worktrees.retain(|worktree| {
+            let worktree_id = worktree.read(cx).id();
+            !project_state.worktrees.contains_key(&worktree_id)
+        });
+        for worktree in worktrees {
+            self.register_worktree(project.clone(), worktree, cx);
+        }
+    }
+
+    pub fn pending_file_count(&self, project: &Model<Project>) -> Option<watch::Receiver<usize>> {
+        Some(
+            self.projects
+                .get(&project.downgrade())?
+                .pending_file_count_rx
+                .clone(),
+        )
+    }
+
+    pub fn search_project(
+        &mut self,
+        project: Model<Project>,
+        query: String,
+        limit: usize,
+        includes: Vec<PathMatcher>,
+        excludes: Vec<PathMatcher>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<Vec<SearchResult>>> {
+        if query.is_empty() {
+            return Task::ready(Ok(Vec::new()));
+        }
+
+        let index = self.index_project(project.clone(), cx);
+        let embedding_provider = self.embedding_provider.clone();
+
+        cx.spawn(|this, mut cx| async move {
+            index.await?;
+            let t0 = Instant::now();
+
+            let query = embedding_provider
+                .embed_batch(vec![query])
+                .await?
+                .pop()
+                .context("could not embed query")?;
+            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());
+
+            let search_start = Instant::now();
+            let modified_buffer_results = this.update(&mut cx, |this, cx| {
+                this.search_modified_buffers(
+                    &project,
+                    query.clone(),
+                    limit,
+                    &includes,
+                    &excludes,
+                    cx,
+                )
+            })?;
+            let file_results = this.update(&mut cx, |this, cx| {
+                this.search_files(project, query, limit, includes, excludes, cx)
+            })?;
+            let (modified_buffer_results, file_results) =
+                futures::join!(modified_buffer_results, file_results);
+
+            // Weave together the results from modified buffers and files.
+            let mut results = Vec::new();
+            let mut modified_buffers = HashSet::default();
+            for result in modified_buffer_results.log_err().unwrap_or_default() {
+                modified_buffers.insert(result.buffer.clone());
+                results.push(result);
+            }
+            for result in file_results.log_err().unwrap_or_default() {
+                if !modified_buffers.contains(&result.buffer) {
+                    results.push(result);
+                }
+            }
+            results.sort_by_key(|result| Reverse(result.similarity));
+            results.truncate(limit);
+            log::trace!("Semantic search took {:?}", search_start.elapsed());
+            Ok(results)
+        })
+    }
+
+    pub fn search_files(
+        &mut self,
+        project: Model<Project>,
+        query: Embedding,
+        limit: usize,
+        includes: Vec<PathMatcher>,
+        excludes: Vec<PathMatcher>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<Vec<SearchResult>>> {
+        let db_path = self.db.path().clone();
+        let fs = self.fs.clone();
+        cx.spawn(|this, mut cx| async move {
+            let database = VectorDatabase::new(
+                fs.clone(),
+                db_path.clone(),
+                cx.background_executor().clone(),
+            )
+            .await?;
+
+            let worktree_db_ids = this.read_with(&cx, |this, _| {
+                let project_state = this
+                    .projects
+                    .get(&project.downgrade())
+                    .context("project was not indexed")?;
+                let worktree_db_ids = project_state
+                    .worktrees
+                    .values()
+                    .filter_map(|worktree| {
+                        if let WorktreeState::Registered(worktree) = worktree {
+                            Some(worktree.db_id)
+                        } else {
+                            None
+                        }
+                    })
+                    .collect::<Vec<i64>>();
+                anyhow::Ok(worktree_db_ids)
+            })??;
+
+            let file_ids = database
+                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
+                .await?;
+
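+            // Fan the similarity search out over batches of file ids, roughly one
+            // batch per CPU, with a floor on the batch size.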
+            let batch_n = cx.background_executor().num_cpus();
+            let ids_len = file_ids.len();
+            let minimum_batch_size = 50;
+
+            let batch_size = (ids_len / batch_n).max(minimum_batch_size);
+
+            let mut batch_results = Vec::new();
+            for batch in file_ids.chunks(batch_size) {
+                let batch = batch.to_vec();
+                let fs = fs.clone();
+                let db_path = db_path.clone();
+                let query = query.clone();
+                if let Some(db) =
+                    VectorDatabase::new(fs, db_path.clone(), cx.background_executor().clone())
+                        .await
+                        .log_err()
+                {
+                    batch_results.push(async move {
+                        db.top_k_search(&query, limit, batch.as_slice()).await
+                    });
+                }
+            }
+
+            let batch_results = futures::future::join_all(batch_results).await;
+
+            let mut results = Vec::new();
+            for batch_result in batch_results {
+                if let Ok(batch_result) = batch_result {
+                    for (id, similarity) in batch_result {
+                        let ix = match results
+                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
+                        {
+                            Ok(ix) => ix,
+                            Err(ix) => ix,
+                        };
+
+                        results.insert(ix, (id, similarity));
+                        results.truncate(limit);
+                    }
+                }
+            }
+
+            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
+            let scores = results
+                .into_iter()
+                .map(|(_, score)| score)
+                .collect::<Vec<_>>();
+            let spans = database.spans_for_ids(ids.as_slice()).await?;
+
+            let mut tasks = Vec::new();
+            let mut ranges = Vec::new();
+            let weak_project = project.downgrade();
+            project.update(&mut cx, |project, cx| {
+                let this = this.upgrade().context("index was dropped")?;
+                for (worktree_db_id, file_path, byte_range) in spans {
+                    let project_state =
+                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
+                            state
+                        } else {
+                            return Err(anyhow!("project not added"));
+                        };
+                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
+                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
+                        ranges.push(byte_range);
+                    }
+                }
+
+                Ok(())
+            })??;
+
+            let buffers = futures::future::join_all(tasks).await;
+            Ok(buffers
+                .into_iter()
+                .zip(ranges)
+                .zip(scores)
+                .filter_map(|((buffer, range), similarity)| {
+                    let buffer = buffer.log_err()?;
+                    let range = buffer
+                        .read_with(&cx, |buffer, _| {
+                            let start = buffer.clip_offset(range.start, Bias::Left);
+                            let end = buffer.clip_offset(range.end, Bias::Right);
+                            buffer.anchor_before(start)..buffer.anchor_after(end)
+                        })
+                        .log_err()?;
+                    Some(SearchResult {
+                        buffer,
+                        range,
+                        similarity,
+                    })
+                })
+                .collect())
+        })
+    }
+
+    fn search_modified_buffers(
+        &self,
+        project: &Model<Project>,
+        query: Embedding,
+        limit: usize,
+        includes: &[PathMatcher],
+        excludes: &[PathMatcher],
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<Vec<SearchResult>>> {
+        let modified_buffers = project
+            .read(cx)
+            .opened_buffers()
+            .into_iter()
+            .filter_map(|buffer_handle| {
+                let buffer = buffer_handle.read(cx);
+                let snapshot = buffer.snapshot();
+                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
+                    excludes.iter().any(|matcher| matcher.is_match(&path))
+                });
+
+                let included = if includes.is_empty() {
+                    true
+                } else {
+                    snapshot.resolve_file_path(cx, false).map_or(false, |path| {
+                        includes.iter().any(|matcher| matcher.is_match(&path))
+                    })
+                };
+
+                if buffer.is_dirty() && !excluded && included {
+                    Some((buffer_handle, snapshot))
+                } else {
+                    None
+                }
+            })
+            .collect::<HashMap<_, _>>();
+
+        let embedding_provider = self.embedding_provider.clone();
+        let fs = self.fs.clone();
+        let db_path = self.db.path().clone();
+        let background = cx.background_executor().clone();
+        cx.background_executor().spawn(async move {
+            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
+            let mut results = Vec::<SearchResult>::new();
+
+            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
+            for (buffer, snapshot) in modified_buffers {
+                let language = snapshot
+                    .language_at(0)
+                    .cloned()
+                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
+                let mut spans = retriever
+                    .parse_file_with_template(None, &snapshot.text(), language)
+                    .log_err()
+                    .unwrap_or_default();
+                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
+                    .await
+                    .log_err()
+                    .is_some()
+                {
+                    for span in spans {
+                        let similarity = span.embedding.unwrap().similarity(&query);
+                        let ix = match results
+                            .binary_search_by_key(&Reverse(similarity), |result| {
+                                Reverse(result.similarity)
+                            }) {
+                            Ok(ix) => ix,
+                            Err(ix) => ix,
+                        };
+
+                        let range = {
+                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
+                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
+                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
+                        };
+
+                        results.insert(
+                            ix,
+                            SearchResult {
+                                buffer: buffer.clone(),
+                                range,
+                                similarity,
+                            },
+                        );
+                        results.truncate(limit);
+                    }
+                }
+            }
+
+            Ok(results)
+        })
+    }
+
+    pub fn index_project(
+        &mut self,
+        project: Model<Project>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<()>> {
+        if !self.is_authenticated() {
+            if !self.authenticate(cx) {
+                return Task::ready(Err(anyhow!("user is not authenticated")));
+            }
+        }
+
+        if !self.projects.contains_key(&project.downgrade()) {
+            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
+                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
+                    this.project_worktrees_changed(project.clone(), cx);
+                }
+                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
+                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
+                }
+                _ => {}
+            });
+            let project_state = ProjectState::new(subscription, cx);
+            self.projects.insert(project.downgrade(), project_state);
+            self.project_worktrees_changed(project.clone(), cx);
+        }
+        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
+        project_state.pending_index += 1;
+        cx.notify();
+
+        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
+        let db = self.db.clone();
+        let language_registry = self.language_registry.clone();
+        let parsing_files_tx = self.parsing_files_tx.clone();
+        let worktree_registration = self.wait_for_worktree_registration(&project, cx);
+
+        cx.spawn(|this, mut cx| async move {
+            worktree_registration.await?;
+
+            let mut pending_files = Vec::new();
+            let mut files_to_delete = Vec::new();
+            this.update(&mut cx, |this, cx| {
+                let project_state = this
+                    .projects
+                    .get_mut(&project.downgrade())
+                    .context("project was dropped")?;
+                let pending_file_count_tx = &project_state.pending_file_count_tx;
+
+                project_state
+                    .worktrees
+                    .retain(|worktree_id, worktree_state| {
+                        let worktree = if let Some(worktree) =
+                            project.read(cx).worktree_for_id(*worktree_id, cx)
+                        {
+                            worktree
+                        } else {
+                            return false;
+                        };
+                        let worktree_state =
+                            if let WorktreeState::Registered(worktree_state) = worktree_state {
+                                worktree_state
+                            } else {
+                                return true;
+                            };
+
+                        worktree_state.changed_paths.retain(|path, info| {
+                            if info.is_deleted {
+                                files_to_delete.push((worktree_state.db_id, path.clone()));
+                            } else {
+                                let absolute_path = worktree.read(cx).absolutize(path);
+                                let job_handle = JobHandle::new(pending_file_count_tx);
+                                pending_files.push(PendingFile {
+                                    absolute_path,
+                                    relative_path: path.clone(),
+                                    language: None,
+                                    job_handle,
+                                    modified_time: info.mtime,
+                                    worktree_db_id: worktree_state.db_id,
+                                });
+                            }
+
+                            false
+                        });
+                        true
+                    });
+
+                anyhow::Ok(())
+            })??;
+
+            cx.background_executor()
+                .spawn(async move {
+                    for (worktree_db_id, path) in files_to_delete {
+                        db.delete_file(worktree_db_id, path).await.log_err();
+                    }
+
+                    let embeddings_for_digest = {
+                        let mut files = HashMap::default();
+                        for pending_file in &pending_files {
+                            files
+                                .entry(pending_file.worktree_db_id)
+                                .or_insert(Vec::new())
+                                .push(pending_file.relative_path.clone());
+                        }
+                        Arc::new(
+                            db.embeddings_for_files(files)
+                                .await
+                                .log_err()
+                                .unwrap_or_default(),
+                        )
+                    };
+
+                    for mut pending_file in pending_files {
+                        if let Ok(language) = language_registry
+                            .language_for_file(&pending_file.relative_path, None)
+                            .await
+                        {
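+                            // Skip files that can't be parsed into embeddable spans.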
+                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
+                                && language.name().as_ref() != "Markdown"
+                                && language
+                                    .grammar()
+                                    .and_then(|grammar| grammar.embedding_config.as_ref())
+                                    .is_none()
+                            {
+                                continue;
+                            }
+                            pending_file.language = Some(language);
+                        }
+                        parsing_files_tx
+                            .try_send((embeddings_for_digest.clone(), pending_file))
+                            .ok();
+                    }
+
+                    // Wait until we're done indexing.
+                    while let Some(count) = pending_file_count_rx.next().await {
+                        if count == 0 {
+                            break;
+                        }
+                    }
+                })
+                .await;
+
+            this.update(&mut cx, |this, cx| {
+                let project_state = this
+                    .projects
+                    .get_mut(&project.downgrade())
+                    .context("project was dropped")?;
+                project_state.pending_index -= 1;
+                cx.notify();
+                anyhow::Ok(())
+            })??;
+
+            Ok(())
+        })
+    }
+
+    fn wait_for_worktree_registration(
+        &self,
+        project: &Model<Project>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<()>> {
+        let project = project.downgrade();
+        cx.spawn(|this, cx| async move {
+            loop {
+                let mut pending_worktrees = Vec::new();
+                this.upgrade()
+                    .context("semantic index dropped")?
+                    .read_with(&cx, |this, _| {
+                        if let Some(project) = this.projects.get(&project) {
+                            for worktree in project.worktrees.values() {
+                                if let WorktreeState::Registering(worktree) = worktree {
+                                    pending_worktrees.push(worktree.done());
+                                }
+                            }
+                        }
+                    })?;
+
+                if pending_worktrees.is_empty() {
+                    break;
+                } else {
+                    future::join_all(pending_worktrees).await;
+                }
+            }
+            Ok(())
+        })
+    }
+
+    async fn embed_spans(
+        spans: &mut [Span],
+        embedding_provider: &dyn EmbeddingProvider,
+        db: &VectorDatabase,
+    ) -> Result<()> {
+        let mut batch = Vec::new();
+        let mut batch_tokens = 0;
+        let mut embeddings = Vec::new();
+
+        let digests = spans
+            .iter()
+            .map(|span| span.digest.clone())
+            .collect::<Vec<_>>();
+        let embeddings_for_digests = db
+            .embeddings_for_digests(digests)
+            .await
+            .log_err()
+            .unwrap_or_default();
+
+        for span in &*spans {
+            if embeddings_for_digests.contains_key(&span.digest) {
+                continue;
+            };
+
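+            // Flush the pending batch before it would exceed the provider's token budget.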
+            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
+                let batch_embeddings = embedding_provider
+                    .embed_batch(mem::take(&mut batch))
+                    .await?;
+                embeddings.extend(batch_embeddings);
+                batch_tokens = 0;
+            }
+
+            batch_tokens += span.token_count;
+            batch.push(span.content.clone());
+        }
+
+        if !batch.is_empty() {
+            let batch_embeddings = embedding_provider
+                .embed_batch(mem::take(&mut batch))
+                .await?;
+
+            embeddings.extend(batch_embeddings);
+        }
+
+        let mut embeddings = embeddings.into_iter();
+        for span in spans {
+            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
+                Some(embedding.clone())
+            } else {
+                embeddings.next()
+            };
+            let embedding = embedding.context("failed to embed spans")?;
+            span.embedding = Some(embedding);
+        }
+        Ok(())
+    }
+}
+
+impl Drop for JobHandle {
+    fn drop(&mut self) {
+        if let Some(inner) = Arc::get_mut(&mut self.tx) {
+            // This is the last instance of the JobHandle (regardless of its origin, cloned or not).
+            if let Some(tx) = inner.upgrade() {
+                let mut tx = tx.lock();
+                *tx.borrow_mut() -= 1;
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+    #[test]
+    fn test_job_handle() {
+        let (job_count_tx, job_count_rx) = watch::channel_with(0);
+        let tx = Arc::new(Mutex::new(job_count_tx));
+        let job_handle = JobHandle::new(&tx);
+
+        assert_eq!(1, *job_count_rx.borrow());
+        let new_job_handle = job_handle.clone();
+        assert_eq!(1, *job_count_rx.borrow());
+        drop(job_handle);
+        assert_eq!(1, *job_count_rx.borrow());
+        drop(new_job_handle);
+        assert_eq!(0, *job_count_rx.borrow());
+    }
+}
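
For orientation, here is a hedged sketch of how a caller might drive a search once `init` has stored the index as a global; the function name and query string are illustrative, and the error handling mirrors the patterns used above:

    fn search_example(project: Model<Project>, cx: &mut AppContext) {
        let Some(semantic_index) = SemanticIndex::global(cx) else {
            return;
        };
        cx.spawn(|mut cx| async move {
            let results = semantic_index
                .update(&mut cx, |index, cx| {
                    index.search_project(project, "open a file".into(), 10, vec![], vec![], cx)
                })?
                .await?;
            for result in results {
                // Each SearchResult carries a buffer, an anchor range, and a similarity score.
                log::info!("match with similarity {}", result.similarity);
            }
            anyhow::Ok(())
        })
        .detach();
    }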

crates/semantic_index2/src/semantic_index_settings.rs 🔗

@@ -0,0 +1,28 @@
+use anyhow;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use settings::Settings;
+
+#[derive(Deserialize, Debug)]
+pub struct SemanticIndexSettings {
+    pub enabled: bool,
+}
+
+#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
+pub struct SemanticIndexSettingsContent {
+    pub enabled: Option<bool>,
+}
+
+impl Settings for SemanticIndexSettings {
+    const KEY: Option<&'static str> = Some("semantic_index");
+
+    type FileContent = SemanticIndexSettingsContent;
+
+    fn load(
+        default_value: &Self::FileContent,
+        user_values: &[&Self::FileContent],
+        _: &mut gpui::AppContext,
+    ) -> anyhow::Result<Self> {
+        Self::load_via_json_merge(default_value, user_values)
+    }
+}
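
These settings live under the "semantic_index" key, so a user's settings file would enable the feature with { "semantic_index": { "enabled": true } }. A minimal sketch of reading the merged value (the function name is illustrative):

    fn semantic_index_enabled(cx: &gpui::AppContext) -> bool {
        use settings::Settings as _;
        SemanticIndexSettings::get_global(cx).enabled
    }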

crates/semantic_index2/src/semantic_index_tests.rs 🔗

@@ -0,0 +1,1697 @@
+use crate::{
+    embedding_queue::EmbeddingQueue,
+    parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
+    semantic_index_settings::SemanticIndexSettings,
+    FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
+};
+use ai::test::FakeEmbeddingProvider;
+
+use gpui::{Task, TestAppContext};
+use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
+use parking_lot::Mutex;
+use pretty_assertions::assert_eq;
+use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
+use rand::{rngs::StdRng, Rng};
+use serde_json::json;
+use settings::{Settings, SettingsStore};
+use std::{path::Path, sync::Arc, time::SystemTime};
+use unindent::Unindent;
+use util::{paths::PathMatcher, RandomCharIter};
+
+#[ctor::ctor]
+fn init_logger() {
+    if std::env::var("RUST_LOG").is_ok() {
+        env_logger::init();
+    }
+}
+
+#[gpui::test]
+async fn test_semantic_index(cx: &mut TestAppContext) {
+    init_test(cx);
+
+    let fs = FakeFs::new(cx.background_executor.clone());
+    fs.insert_tree(
+        "/the-root",
+        json!({
+            "src": {
+                "file1.rs": "
+                    fn aaa() {
+                        println!(\"aaaaaaaaaaaa!\");
+                    }
+
+                    fn zzzzz() {
+                        println!(\"SLEEPING\");
+                    }
+                ".unindent(),
+                "file2.rs": "
+                    fn bbb() {
+                        println!(\"bbbbbbbbbbbbb!\");
+                    }
+                    struct pqpqpqp {}
+                ".unindent(),
+                "file3.toml": "
+                    ZZZZZZZZZZZZZZZZZZ = 5
+                ".unindent(),
+            }
+        }),
+    )
+    .await;
+
+    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
+    let rust_language = rust_lang();
+    let toml_language = toml_lang();
+    languages.add(rust_language);
+    languages.add(toml_language);
+
+    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
+    let db_path = db_dir.path().join("db.sqlite");
+
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
+    let semantic_index = SemanticIndex::new(
+        fs.clone(),
+        db_path,
+        embedding_provider.clone(),
+        languages,
+        cx.to_async(),
+    )
+    .await
+    .unwrap();
+
+    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
+
+    let search_results = semantic_index.update(cx, |store, cx| {
+        store.search_project(
+            project.clone(),
+            "aaaaaabbbbzz".to_string(),
+            5,
+            vec![],
+            vec![],
+            cx,
+        )
+    });
+    let pending_file_count =
+        semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
+    cx.background_executor.run_until_parked();
+    assert_eq!(*pending_file_count.borrow(), 3);
+    cx.background_executor
+        .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
+    assert_eq!(*pending_file_count.borrow(), 0);
+
+    let search_results = search_results.await.unwrap();
+    assert_search_results(
+        &search_results,
+        &[
+            (Path::new("src/file1.rs").into(), 0),
+            (Path::new("src/file2.rs").into(), 0),
+            (Path::new("src/file3.toml").into(), 0),
+            (Path::new("src/file1.rs").into(), 45),
+            (Path::new("src/file2.rs").into(), 45),
+        ],
+        cx,
+    );
+
+    // Test include-files functionality.
+    let include_files = vec![PathMatcher::new("*.rs").unwrap()];
+    let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
+    let rust_only_search_results = semantic_index
+        .update(cx, |store, cx| {
+            store.search_project(
+                project.clone(),
+                "aaaaaabbbbzz".to_string(),
+                5,
+                include_files,
+                vec![],
+                cx,
+            )
+        })
+        .await
+        .unwrap();
+
+    assert_search_results(
+        &rust_only_search_results,
+        &[
+            (Path::new("src/file1.rs").into(), 0),
+            (Path::new("src/file2.rs").into(), 0),
+            (Path::new("src/file1.rs").into(), 45),
+            (Path::new("src/file2.rs").into(), 45),
+        ],
+        cx,
+    );
+
+    let no_rust_search_results = semantic_index
+        .update(cx, |store, cx| {
+            store.search_project(
+                project.clone(),
+                "aaaaaabbbbzz".to_string(),
+                5,
+                vec![],
+                exclude_files,
+                cx,
+            )
+        })
+        .await
+        .unwrap();
+
+    assert_search_results(
+        &no_rust_search_results,
+        &[(Path::new("src/file3.toml").into(), 0)],
+        cx,
+    );
+
+    fs.save(
+        "/the-root/src/file2.rs".as_ref(),
+        &"
+            fn dddd() { println!(\"ddddd!\"); }
+            struct pqpqpqp {}
+        "
+        .unindent()
+        .into(),
+        Default::default(),
+    )
+    .await
+    .unwrap();
+
+    cx.background_executor
+        .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
+
+    let prev_embedding_count = embedding_provider.embedding_count();
+    let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
+    cx.background_executor.run_until_parked();
+    assert_eq!(*pending_file_count.borrow(), 1);
+    cx.background_executor
+        .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
+    assert_eq!(*pending_file_count.borrow(), 0);
+    index.await.unwrap();
+
+    assert_eq!(
+        embedding_provider.embedding_count() - prev_embedding_count,
+        1
+    );
+}
+
+#[gpui::test(iterations = 10)]
+async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
+    let (outstanding_job_count, _) = postage::watch::channel_with(0);
+    let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
+
+    let files = (1..=3)
+        .map(|file_ix| FileToEmbed {
+            worktree_id: 5,
+            path: Path::new(&format!("path-{file_ix}")).into(),
+            mtime: SystemTime::now(),
+            spans: (0..rng.gen_range(4..22))
+                .map(|document_ix| {
+                    let content_len = rng.gen_range(10..100);
+                    let content = RandomCharIter::new(&mut rng)
+                        .with_simple_text()
+                        .take(content_len)
+                        .collect::<String>();
+                    let digest = SpanDigest::from(content.as_str());
+                    Span {
+                        range: 0..10,
+                        embedding: None,
+                        name: format!("document {document_ix}"),
+                        content,
+                        digest,
+                        token_count: rng.gen_range(10..30),
+                    }
+                })
+                .collect(),
+            job_handle: JobHandle::new(&outstanding_job_count),
+        })
+        .collect::<Vec<_>>();
+
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
+
+    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor.clone());
+    for file in &files {
+        queue.push(file.clone());
+    }
+    queue.flush();
+
+    cx.background_executor.run_until_parked();
+    let finished_files = queue.finished_files();
+    let mut embedded_files: Vec<_> = files
+        .iter()
+        .map(|_| finished_files.try_recv().expect("no finished file"))
+        .collect();
+
+    let expected_files: Vec<_> = files
+        .iter()
+        .map(|file| {
+            let mut file = file.clone();
+            for doc in &mut file.spans {
+                doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
+            }
+            file
+        })
+        .collect();
+
+    embedded_files.sort_by_key(|f| f.path.clone());
+
+    assert_eq!(embedded_files, expected_files);
+}
+
+#[track_caller]
+fn assert_search_results(
+    actual: &[SearchResult],
+    expected: &[(Arc<Path>, usize)],
+    cx: &TestAppContext,
+) {
+    let actual = actual
+        .iter()
+        .map(|search_result| {
+            search_result.buffer.read_with(cx, |buffer, _cx| {
+                (
+                    buffer.file().unwrap().path().clone(),
+                    search_result.range.start.to_offset(buffer),
+                )
+            })
+        })
+        .collect::<Vec<_>>();
+    assert_eq!(actual, expected);
+}
+
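+// The retrieval tests below feed a source snippet to CodeContextRetriever
+// and assert the exact span contents and start offsets it extracts with each
+// language's embedding query (see the language fixtures at the bottom of
+// this file). For Rust, function bodies are collapsed to the configured
+// ` /* ... */ ` placeholder when rendered inside an enclosing `impl` span,
+// while each function also gets a full span of its own.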
+#[gpui::test]
+async fn test_code_context_retrieval_rust() {
+    let language = rust_lang();
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
+
+    let text = "
+        /// A doc comment
+        /// that spans multiple lines
+        #[gpui::test]
+        fn a() {
+            b
+        }
+
+        impl C for D {
+        }
+
+        impl E {
+            // This is also a preceding comment
+            pub fn function_1() -> Option<()> {
+                unimplemented!();
+            }
+
+            // This is a preceding comment
+            fn function_2() -> Result<()> {
+                unimplemented!();
+            }
+        }
+
+        #[derive(Clone)]
+        struct D {
+            name: String
+        }
+    "
+    .unindent();
+
+    let documents = retriever.parse_file(&text, language).unwrap();
+
+    assert_documents_eq(
+        &documents,
+        &[
+            (
+                "
+                /// A doc comment
+                /// that spans multiple lines
+                #[gpui::test]
+                fn a() {
+                    b
+                }"
+                .unindent(),
+                text.find("fn a").unwrap(),
+            ),
+            (
+                "
+                impl C for D {
+                }"
+                .unindent(),
+                text.find("impl C").unwrap(),
+            ),
+            (
+                "
+                impl E {
+                    // This is also a preceding comment
+                    pub fn function_1() -> Option<()> { /* ... */ }
+
+                    // This is a preceding comment
+                    fn function_2() -> Result<()> { /* ... */ }
+                }"
+                .unindent(),
+                text.find("impl E").unwrap(),
+            ),
+            (
+                "
+                // This is also a preceding comment
+                pub fn function_1() -> Option<()> {
+                    unimplemented!();
+                }"
+                .unindent(),
+                text.find("pub fn function_1").unwrap(),
+            ),
+            (
+                "
+                // This is a preceding comment
+                fn function_2() -> Result<()> {
+                    unimplemented!();
+                }"
+                .unindent(),
+                text.find("fn function_2").unwrap(),
+            ),
+            (
+                "
+                #[derive(Clone)]
+                struct D {
+                    name: String
+                }"
+                .unindent(),
+                text.find("struct D").unwrap(),
+            ),
+        ],
+    );
+}
+
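+// The JSON embedding query indexes the whole document, collapsing arrays to
+// at most their first object and string values to empty strings.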
+#[gpui::test]
+async fn test_code_context_retrieval_json() {
+    let language = json_lang();
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
+
+    let text = r#"
+        {
+            "array": [1, 2, 3, 4],
+            "string": "abcdefg",
+            "nested_object": {
+                "array_2": [5, 6, 7, 8],
+                "string_2": "hijklmnop",
+                "boolean": true,
+                "none": null
+            }
+        }
+    "#
+    .unindent();
+
+    let documents = retriever.parse_file(&text, language.clone()).unwrap();
+
+    assert_documents_eq(
+        &documents,
+        &[(
+            r#"
+                {
+                    "array": [],
+                    "string": "",
+                    "nested_object": {
+                        "array_2": [],
+                        "string_2": "",
+                        "boolean": true,
+                        "none": null
+                    }
+                }"#
+            .unindent(),
+            text.find("{").unwrap(),
+        )],
+    );
+
+    let text = r#"
+        [
+            {
+                "name": "somebody",
+                "age": 42
+            },
+            {
+                "name": "somebody else",
+                "age": 43
+            }
+        ]
+    "#
+    .unindent();
+
+    let documents = retriever.parse_file(&text, language.clone()).unwrap();
+
+    assert_documents_eq(
+        &documents,
+        &[(
+            r#"
+            [{
+                    "name": "",
+                    "age": 42
+                }]"#
+            .unindent(),
+            text.find("[").unwrap(),
+        )],
+    );
+}
+
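+// Asserts that the retriever produced exactly the expected span contents and
+// start offsets, in order.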
+fn assert_documents_eq(
+    documents: &[Span],
+    expected_contents_and_start_offsets: &[(String, usize)],
+) {
+    assert_eq!(
+        documents
+            .iter()
+            .map(|document| (document.content.clone(), document.range.start))
+            .collect::<Vec<_>>(),
+        expected_contents_and_start_offsets
+    );
+}
+
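+// Verifies that the JavaScript query pairs each declaration with its
+// preceding comments and extracts functions, classes, methods, and
+// interfaces (the fixture uses the TSX grammar, so TypeScript constructs
+// parse as well).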
+#[gpui::test]
+async fn test_code_context_retrieval_javascript() {
+    let language = js_lang();
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
+
+    let text = "
+        /* globals importScripts, backend */
+        function _authorize() {}
+
+        /**
+         * Sometimes the frontend build is way faster than backend.
+         */
+        export async function authorizeBank() {
+            _authorize(pushModal, upgradingAccountId, {});
+        }
+
+        export class SettingsPage {
+            /* This is a test setting */
+            constructor(page) {
+                this.page = page;
+            }
+        }
+
+        /* This is a test comment */
+        class TestClass {}
+
+        /* Schema for editor_events in Clickhouse. */
+        export interface ClickhouseEditorEvent {
+            installation_id: string
+            operation: string
+        }
+        "
+    .unindent();
+
+    let documents = retriever.parse_file(&text, language.clone()).unwrap();
+
+    assert_documents_eq(
+        &documents,
+        &[
+            (
+                "
+            /* globals importScripts, backend */
+            function _authorize() {}"
+                    .unindent(),
+                37,
+            ),
+            (
+                "
+            /**
+             * Sometimes the frontend build is way faster than backend.
+             */
+            export async function authorizeBank() {
+                _authorize(pushModal, upgradingAccountId, {});
+            }"
+                .unindent(),
+                131,
+            ),
+            (
+                "
+                export class SettingsPage {
+                    /* This is a test setting */
+                    constructor(page) {
+                        this.page = page;
+                    }
+                }"
+                .unindent(),
+                225,
+            ),
+            (
+                "
+                /* This is a test setting */
+                constructor(page) {
+                    this.page = page;
+                }"
+                .unindent(),
+                290,
+            ),
+            (
+                "
+                /* This is a test comment */
+                class TestClass {}"
+                    .unindent(),
+                374,
+            ),
+            (
+                "
+                /* Schema for editor_events in Clickhouse. */
+                export interface ClickhouseEditorEvent {
+                    installation_id: string
+                    operation: string
+                }"
+                .unindent(),
+                440,
+            ),
+        ],
+    )
+}
+
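+// Verifies that nested Lua function bodies are collapsed to the
+// `--[ ... ]--` placeholder inside the enclosing span, while the nested
+// function also gets a full span of its own.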
+#[gpui::test]
+async fn test_code_context_retrieval_lua() {
+    let language = lua_lang();
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
+
+    let text = r#"
+        -- Creates a new class
+        -- @param baseclass The Baseclass of this class, or nil.
+        -- @return A new class reference.
+        function classes.class(baseclass)
+            -- Create the class definition and metatable.
+            local classdef = {}
+            -- Find the super class, either Object or user-defined.
+            baseclass = baseclass or classes.Object
+            -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
+            setmetatable(classdef, { __index = baseclass })
+            -- All class instances have a reference to the class object.
+            classdef.class = classdef
+            --- Recursivly allocates the inheritance tree of the instance.
+            -- @param mastertable The 'root' of the inheritance tree.
+            -- @return Returns the instance with the allocated inheritance tree.
+            function classdef.alloc(mastertable)
+                -- All class instances have a reference to a superclass object.
+                local instance = { super = baseclass.alloc(mastertable) }
+                -- Any functions this instance does not know of will 'look up' to the superclass definition.
+                setmetatable(instance, { __index = classdef, __newindex = mastertable })
+                return instance
+            end
+        end
+        "#.unindent();
+
+    let documents = retriever.parse_file(&text, language.clone()).unwrap();
+
+    assert_documents_eq(
+        &documents,
+        &[
+            (r#"
+                -- Creates a new class
+                -- @param baseclass The Baseclass of this class, or nil.
+                -- @return A new class reference.
+                function classes.class(baseclass)
+                    -- Create the class definition and metatable.
+                    local classdef = {}
+                    -- Find the super class, either Object or user-defined.
+                    baseclass = baseclass or classes.Object
+                    -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
+                    setmetatable(classdef, { __index = baseclass })
+                    -- All class instances have a reference to the class object.
+                    classdef.class = classdef
+                    --- Recursivly allocates the inheritance tree of the instance.
+                    -- @param mastertable The 'root' of the inheritance tree.
+                    -- @return Returns the instance with the allocated inheritance tree.
+                    function classdef.alloc(mastertable)
+                        --[ ... ]--
+                        --[ ... ]--
+                    end
+                end"#.unindent(),
+            114),
+            (r#"
+            --- Recursivly allocates the inheritance tree of the instance.
+            -- @param mastertable The 'root' of the inheritance tree.
+            -- @return Returns the instance with the allocated inheritance tree.
+            function classdef.alloc(mastertable)
+                -- All class instances have a reference to a superclass object.
+                local instance = { super = baseclass.alloc(mastertable) }
+                -- Any functions this instance does not know of will 'look up' to the superclass definition.
+                setmetatable(instance, { __index = classdef, __newindex = mastertable })
+                return instance
+            end"#.unindent(), 809),
+        ]
+    );
+}
+
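+// Verifies that the Elixir query captures preceding `@doc` attributes as
+// context for function definitions and extracts `defmodule` items whole (no
+// collapsed placeholder is configured for Elixir).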
+#[gpui::test]
+async fn test_code_context_retrieval_elixir() {
+    let language = elixir_lang();
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
+
+    let text = r#"
+        defmodule File.Stream do
+            @moduledoc """
+            Defines a `File.Stream` struct returned by `File.stream!/3`.
+
+            The following fields are public:
+
+            * `path`          - the file path
+            * `modes`         - the file modes
+            * `raw`           - a boolean indicating if bin functions should be used
+            * `line_or_bytes` - if reading should read lines or a given number of bytes
+            * `node`          - the node the file belongs to
+
+            """
+
+            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
+
+            @type t :: %__MODULE__{}
+
+            @doc false
+            def __build__(path, modes, line_or_bytes) do
+            raw = :lists.keyfind(:encoding, 1, modes) == false
+
+            modes =
+                case raw do
+                true ->
+                    case :lists.keyfind(:read_ahead, 1, modes) do
+                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
+                    {:read_ahead, _} -> [:raw | modes]
+                    false -> [:raw, :read_ahead | modes]
+                    end
+
+                false ->
+                    modes
+                end
+
+            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
+
+            end"#
+    .unindent();
+
+    let documents = retriever.parse_file(&text, language.clone()).unwrap();
+
+    assert_documents_eq(
+        &documents,
+        &[(
+            r#"
+        defmodule File.Stream do
+            @moduledoc """
+            Defines a `File.Stream` struct returned by `File.stream!/3`.
+
+            The following fields are public:
+
+            * `path`          - the file path
+            * `modes`         - the file modes
+            * `raw`           - a boolean indicating if bin functions should be used
+            * `line_or_bytes` - if reading should read lines or a given number of bytes
+            * `node`          - the node the file belongs to
+
+            """
+
+            defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
+
+            @type t :: %__MODULE__{}
+
+            @doc false
+            def __build__(path, modes, line_or_bytes) do
+            raw = :lists.keyfind(:encoding, 1, modes) == false
+
+            modes =
+                case raw do
+                true ->
+                    case :lists.keyfind(:read_ahead, 1, modes) do
+                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
+                    {:read_ahead, _} -> [:raw | modes]
+                    false -> [:raw, :read_ahead | modes]
+                    end
+
+                false ->
+                    modes
+                end
+
+            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
+
+            end"#
+                .unindent(),
+            0,
+        ),(r#"
+            @doc false
+            def __build__(path, modes, line_or_bytes) do
+            raw = :lists.keyfind(:encoding, 1, modes) == false
+
+            modes =
+                case raw do
+                true ->
+                    case :lists.keyfind(:read_ahead, 1, modes) do
+                    {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
+                    {:read_ahead, _} -> [:raw | modes]
+                    false -> [:raw, :read_ahead | modes]
+                    end
+
+                false ->
+                    modes
+                end
+
+            %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
+
+            end"#.unindent(), 574)],
+    );
+}
+
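+// Verifies that the C++ query extracts functions, classes, template classes,
+// enums, and struct declarations together with their preceding block
+// comments, including members nested inside class bodies.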
+#[gpui::test]
+async fn test_code_context_retrieval_cpp() {
+    let language = cpp_lang();
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
+
+    let text = "
+    /**
+     * @brief Main function
+     * @returns 0 on exit
+     */
+    int main() { return 0; }
+
+    /**
+    * This is a test comment
+    */
+    class MyClass {       // The class
+        public:           // Access specifier
+        int myNum;        // Attribute (int variable)
+        string myString;  // Attribute (string variable)
+    };
+
+    // This is a test comment
+    enum Color { red, green, blue };
+
+    /** This is a preceding block comment
+     * This is the second line
+     */
+    struct {           // Structure declaration
+        int myNum;       // Member (int variable)
+        string myString; // Member (string variable)
+    } myStructure;
+
+    /**
+     * @brief Matrix class.
+     */
+    template <typename T,
+              typename = typename std::enable_if<
+                std::is_integral<T>::value || std::is_floating_point<T>::value,
+                bool>::type>
+    class Matrix2 {
+        std::vector<std::vector<T>> _mat;
+
+        public:
+            /**
+            * @brief Constructor
+            * @tparam Integer ensuring integers are being evaluated and not other
+            * data types.
+            * @param size denoting the size of Matrix as size x size
+            */
+            template <typename Integer,
+                    typename = typename std::enable_if<std::is_integral<Integer>::value,
+                    Integer>::type>
+            explicit Matrix(const Integer size) {
+                for (size_t i = 0; i < size; ++i) {
+                    _mat.emplace_back(std::vector<T>(size, 0));
+                }
+            }
+    }"
+    .unindent();
+
+    let documents = retriever.parse_file(&text, language.clone()).unwrap();
+
+    assert_documents_eq(
+        &documents,
+        &[
+            (
+                "
+        /**
+         * @brief Main function
+         * @returns 0 on exit
+         */
+        int main() { return 0; }"
+                    .unindent(),
+                54,
+            ),
+            (
+                "
+                /**
+                * This is a test comment
+                */
+                class MyClass {       // The class
+                    public:           // Access specifier
+                    int myNum;        // Attribute (int variable)
+                    string myString;  // Attribute (string variable)
+                }"
+                .unindent(),
+                112,
+            ),
+            (
+                "
+                // This is a test comment
+                enum Color { red, green, blue }"
+                    .unindent(),
+                322,
+            ),
+            (
+                "
+                /** This is a preceding block comment
+                 * This is the second line
+                 */
+                struct {           // Structure declaration
+                    int myNum;       // Member (int variable)
+                    string myString; // Member (string variable)
+                } myStructure;"
+                    .unindent(),
+                425,
+            ),
+            (
+                "
+                /**
+                 * @brief Matrix class.
+                 */
+                template <typename T,
+                          typename = typename std::enable_if<
+                            std::is_integral<T>::value || std::is_floating_point<T>::value,
+                            bool>::type>
+                class Matrix2 {
+                    std::vector<std::vector<T>> _mat;
+
+                    public:
+                        /**
+                        * @brief Constructor
+                        * @tparam Integer ensuring integers are being evaluated and not other
+                        * data types.
+                        * @param size denoting the size of Matrix as size x size
+                        */
+                        template <typename Integer,
+                                typename = typename std::enable_if<std::is_integral<Integer>::value,
+                                Integer>::type>
+                        explicit Matrix(const Integer size) {
+                            for (size_t i = 0; i < size; ++i) {
+                                _mat.emplace_back(std::vector<T>(size, 0));
+                            }
+                        }
+                }"
+                .unindent(),
+                612,
+            ),
+            (
+                "
+                explicit Matrix(const Integer size) {
+                    for (size_t i = 0; i < size; ++i) {
+                        _mat.emplace_back(std::vector<T>(size, 0));
+                    }
+                }"
+                .unindent(),
+                1226,
+            ),
+        ],
+    );
+}
+
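+// Verifies that the Ruby query collapses method bodies to `# ...` inside
+// module and class spans, while each method, including singleton methods,
+// also gets a full span of its own.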
+#[gpui::test]
+async fn test_code_context_retrieval_ruby() {
+    let language = ruby_lang();
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
+
+    let text = r#"
+        # This concern is inspired by "sudo mode" on GitHub. It
+        # is a way to re-authenticate a user before allowing them
+        # to see or perform an action.
+        #
+        # Add `before_action :require_challenge!` to actions you
+        # want to protect.
+        #
+        # The user will be shown a page to enter the challenge (which
+        # is either the password, or just the username when no
+        # password exists). Upon passing, there is a grace period
+        # during which no challenge will be asked from the user.
+        #
+        # Accessing challenge-protected resources during the grace
+        # period will refresh the grace period.
+        module ChallengableConcern
+            extend ActiveSupport::Concern
+
+            CHALLENGE_TIMEOUT = 1.hour.freeze
+
+            def require_challenge!
+                return if skip_challenge?
+
+                if challenge_passed_recently?
+                    session[:challenge_passed_at] = Time.now.utc
+                    return
+                end
+
+                @challenge = Form::Challenge.new(return_to: request.url)
+
+                if params.key?(:form_challenge)
+                    if challenge_passed?
+                        session[:challenge_passed_at] = Time.now.utc
+                    else
+                        flash.now[:alert] = I18n.t('challenge.invalid_password')
+                        render_challenge
+                    end
+                else
+                    render_challenge
+                end
+            end
+
+            def challenge_passed?
+                current_user.valid_password?(challenge_params[:current_password])
+            end
+        end
+
+        class Animal
+            include Comparable
+
+            attr_reader :legs
+
+            def initialize(name, legs)
+                @name, @legs = name, legs
+            end
+
+            def <=>(other)
+                legs <=> other.legs
+            end
+        end
+
+        # Singleton method for car object
+        def car.wheels
+            puts "There are four wheels"
+        end"#
+        .unindent();
+
+    let documents = retriever.parse_file(&text, language.clone()).unwrap();
+
+    assert_documents_eq(
+        &documents,
+        &[
+            (
+                r#"
+        # This concern is inspired by "sudo mode" on GitHub. It
+        # is a way to re-authenticate a user before allowing them
+        # to see or perform an action.
+        #
+        # Add `before_action :require_challenge!` to actions you
+        # want to protect.
+        #
+        # The user will be shown a page to enter the challenge (which
+        # is either the password, or just the username when no
+        # password exists). Upon passing, there is a grace period
+        # during which no challenge will be asked from the user.
+        #
+        # Accessing challenge-protected resources during the grace
+        # period will refresh the grace period.
+        module ChallengableConcern
+            extend ActiveSupport::Concern
+
+            CHALLENGE_TIMEOUT = 1.hour.freeze
+
+            def require_challenge!
+                # ...
+            end
+
+            def challenge_passed?
+                # ...
+            end
+        end"#
+                    .unindent(),
+                558,
+            ),
+            (
+                r#"
+            def require_challenge!
+                return if skip_challenge?
+
+                if challenge_passed_recently?
+                    session[:challenge_passed_at] = Time.now.utc
+                    return
+                end
+
+                @challenge = Form::Challenge.new(return_to: request.url)
+
+                if params.key?(:form_challenge)
+                    if challenge_passed?
+                        session[:challenge_passed_at] = Time.now.utc
+                    else
+                        flash.now[:alert] = I18n.t('challenge.invalid_password')
+                        render_challenge
+                    end
+                else
+                    render_challenge
+                end
+            end"#
+                    .unindent(),
+                663,
+            ),
+            (
+                r#"
+                def challenge_passed?
+                    current_user.valid_password?(challenge_params[:current_password])
+                end"#
+                    .unindent(),
+                1254,
+            ),
+            (
+                r#"
+                class Animal
+                    include Comparable
+
+                    attr_reader :legs
+
+                    def initialize(name, legs)
+                        # ...
+                    end
+
+                    def <=>(other)
+                        # ...
+                    end
+                end"#
+                    .unindent(),
+                1363,
+            ),
+            (
+                r#"
+                def initialize(name, legs)
+                    @name, @legs = name, legs
+                end"#
+                    .unindent(),
+                1427,
+            ),
+            (
+                r#"
+                def <=>(other)
+                    legs <=> other.legs
+                end"#
+                    .unindent(),
+                1501,
+            ),
+            (
+                r#"
+                # Singleton method for car object
+                def car.wheels
+                    puts "There are four wheels"
+                end"#
+                    .unindent(),
+                1591,
+            ),
+        ],
+    );
+}
+
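+// Verifies that the PHP query collapses function and method bodies to
+// `/* ... */` between their kept braces inside trait spans, and extracts
+// traits, interfaces, and enums with their preceding comments.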
+#[gpui::test]
+async fn test_code_context_retrieval_php() {
+    let language = php_lang();
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
+
+    let text = r#"
+        <?php
+
+        namespace LevelUp\Experience\Concerns;
+
+        /*
+        This is a multiple-lines comment block
+        that spans over multiple
+        lines
+        */
+        function functionName() {
+            echo "Hello world!";
+        }
+
+        trait HasAchievements
+        {
+            /**
+            * @throws \Exception
+            */
+            public function grantAchievement(Achievement $achievement, $progress = null): void
+            {
+                if ($progress > 100) {
+                    throw new Exception(message: 'Progress cannot be greater than 100');
+                }
+
+                if ($this->achievements()->find($achievement->id)) {
+                    throw new Exception(message: 'User already has this Achievement');
+                }
+
+                $this->achievements()->attach($achievement, [
+                    'progress' => $progress ?? null,
+                ]);
+
+                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
+            }
+
+            public function achievements(): BelongsToMany
+            {
+                return $this->belongsToMany(related: Achievement::class)
+                ->withPivot(columns: 'progress')
+                ->where('is_secret', false)
+                ->using(AchievementUser::class);
+            }
+        }
+
+        interface Multiplier
+        {
+            public function qualifies(array $data): bool;
+
+            public function setMultiplier(): int;
+        }
+
+        enum AuditType: string
+        {
+            case Add = 'add';
+            case Remove = 'remove';
+            case Reset = 'reset';
+            case LevelUp = 'level_up';
+        }
+
+        ?>"#
+    .unindent();
+
+    let documents = retriever.parse_file(&text, language.clone()).unwrap();
+
+    assert_documents_eq(
+        &documents,
+        &[
+            (
+                r#"
+        /*
+        This is a multiple-lines comment block
+        that spans over multiple
+        lines
+        */
+        function functionName() {
+            echo "Hello world!";
+        }"#
+                .unindent(),
+                123,
+            ),
+            (
+                r#"
+        trait HasAchievements
+        {
+            /**
+            * @throws \Exception
+            */
+            public function grantAchievement(Achievement $achievement, $progress = null): void
+            {/* ... */}
+
+            public function achievements(): BelongsToMany
+            {/* ... */}
+        }"#
+                .unindent(),
+                177,
+            ),
+            (r#"
+            /**
+            * @throws \Exception
+            */
+            public function grantAchievement(Achievement $achievement, $progress = null): void
+            {
+                if ($progress > 100) {
+                    throw new Exception(message: 'Progress cannot be greater than 100');
+                }
+
+                if ($this->achievements()->find($achievement->id)) {
+                    throw new Exception(message: 'User already has this Achievement');
+                }
+
+                $this->achievements()->attach($achievement, [
+                    'progress' => $progress ?? null,
+                ]);
+
+                $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
+            }"#.unindent(), 245),
+            (r#"
+                public function achievements(): BelongsToMany
+                {
+                    return $this->belongsToMany(related: Achievement::class)
+                    ->withPivot(columns: 'progress')
+                    ->where('is_secret', false)
+                    ->using(AchievementUser::class);
+                }"#.unindent(), 902),
+            (r#"
+                interface Multiplier
+                {
+                    public function qualifies(array $data): bool;
+
+                    public function setMultiplier(): int;
+                }"#.unindent(),
+                1146),
+            (r#"
+                enum AuditType: string
+                {
+                    case Add = 'add';
+                    case Remove = 'remove';
+                    case Reset = 'reset';
+                    case LevelUp = 'level_up';
+                }"#.unindent(), 1265)
+        ],
+    );
+}
+
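+// Language fixtures for the retrieval tests above. Each embedding query
+// follows the same capture conventions: `@context` marks preceding comments,
+// `@item` the node to extract, `@name` the tokens that make up the span's
+// name, and `@collapse`/`@keep` the ranges that are replaced by the
+// language's `collapsed_placeholder` when an item is rendered inside an
+// enclosing span.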
+fn js_lang() -> Arc<Language> {
+    Arc::new(
+        Language::new(
+            LanguageConfig {
+                name: "Javascript".into(),
+                path_suffixes: vec!["js".into()],
+                ..Default::default()
+            },
+            Some(tree_sitter_typescript::language_tsx()),
+        )
+        .with_embedding_query(
+            &r#"
+
+            (
+                (comment)* @context
+                .
+                [
+                (export_statement
+                    (function_declaration
+                        "async"? @name
+                        "function" @name
+                        name: (_) @name))
+                (function_declaration
+                    "async"? @name
+                    "function" @name
+                    name: (_) @name)
+                ] @item
+            )
+
+            (
+                (comment)* @context
+                .
+                [
+                (export_statement
+                    (class_declaration
+                        "class" @name
+                        name: (_) @name))
+                (class_declaration
+                    "class" @name
+                    name: (_) @name)
+                ] @item
+            )
+
+            (
+                (comment)* @context
+                .
+                [
+                (export_statement
+                    (interface_declaration
+                        "interface" @name
+                        name: (_) @name))
+                (interface_declaration
+                    "interface" @name
+                    name: (_) @name)
+                ] @item
+            )
+
+            (
+                (comment)* @context
+                .
+                [
+                (export_statement
+                    (enum_declaration
+                        "enum" @name
+                        name: (_) @name))
+                (enum_declaration
+                    "enum" @name
+                    name: (_) @name)
+                ] @item
+            )
+
+            (
+                (comment)* @context
+                .
+                (method_definition
+                    [
+                        "get"
+                        "set"
+                        "async"
+                        "*"
+                        "static"
+                    ]* @name
+                    name: (_) @name) @item
+            )
+
+                    "#
+            .unindent(),
+        )
+        .unwrap(),
+    )
+}
+
+fn rust_lang() -> Arc<Language> {
+    Arc::new(
+        Language::new(
+            LanguageConfig {
+                name: "Rust".into(),
+                path_suffixes: vec!["rs".into()],
+                collapsed_placeholder: " /* ... */ ".to_string(),
+                ..Default::default()
+            },
+            Some(tree_sitter_rust::language()),
+        )
+        .with_embedding_query(
+            r#"
+            (
+                [(line_comment) (attribute_item)]* @context
+                .
+                [
+                    (struct_item
+                        name: (_) @name)
+
+                    (enum_item
+                        name: (_) @name)
+
+                    (impl_item
+                        trait: (_)? @name
+                        "for"? @name
+                        type: (_) @name)
+
+                    (trait_item
+                        name: (_) @name)
+
+                    (function_item
+                        name: (_) @name
+                        body: (block
+                            "{" @keep
+                            "}" @keep) @collapse)
+
+                    (macro_definition
+                        name: (_) @name)
+                ] @item
+            )
+
+            (attribute_item) @collapse
+            (use_declaration) @collapse
+            "#,
+        )
+        .unwrap(),
+    )
+}
+
+fn json_lang() -> Arc<Language> {
+    Arc::new(
+        Language::new(
+            LanguageConfig {
+                name: "JSON".into(),
+                path_suffixes: vec!["json".into()],
+                ..Default::default()
+            },
+            Some(tree_sitter_json::language()),
+        )
+        .with_embedding_query(
+            r#"
+            (document) @item
+
+            (array
+                "[" @keep
+                .
+                (object)? @keep
+                "]" @keep) @collapse
+
+            (pair value: (string
+                "\"" @keep
+                "\"" @keep) @collapse)
+            "#,
+        )
+        .unwrap(),
+    )
+}
+
+fn toml_lang() -> Arc<Language> {
+    Arc::new(Language::new(
+        LanguageConfig {
+            name: "TOML".into(),
+            path_suffixes: vec!["toml".into()],
+            ..Default::default()
+        },
+        Some(tree_sitter_toml::language()),
+    ))
+}
+
+fn cpp_lang() -> Arc<Language> {
+    Arc::new(
+        Language::new(
+            LanguageConfig {
+                name: "CPP".into(),
+                path_suffixes: vec!["cpp".into()],
+                ..Default::default()
+            },
+            Some(tree_sitter_cpp::language()),
+        )
+        .with_embedding_query(
+            r#"
+            (
+                (comment)* @context
+                .
+                (function_definition
+                    (type_qualifier)? @name
+                    type: (_)? @name
+                    declarator: [
+                        (function_declarator
+                            declarator: (_) @name)
+                        (pointer_declarator
+                            "*" @name
+                            declarator: (function_declarator
+                            declarator: (_) @name))
+                        (pointer_declarator
+                            "*" @name
+                            declarator: (pointer_declarator
+                                "*" @name
+                            declarator: (function_declarator
+                                declarator: (_) @name)))
+                        (reference_declarator
+                            ["&" "&&"] @name
+                            (function_declarator
+                            declarator: (_) @name))
+                    ]
+                    (type_qualifier)? @name) @item
+                )
+
+            (
+                (comment)* @context
+                .
+                (template_declaration
+                    (class_specifier
+                        "class" @name
+                        name: (_) @name)
+                        ) @item
+            )
+
+            (
+                (comment)* @context
+                .
+                (class_specifier
+                    "class" @name
+                    name: (_) @name) @item
+                )
+
+            (
+                (comment)* @context
+                .
+                (enum_specifier
+                    "enum" @name
+                    name: (_) @name) @item
+                )
+
+            (
+                (comment)* @context
+                .
+                (declaration
+                    type: (struct_specifier
+                    "struct" @name)
+                    declarator: (_) @name) @item
+            )
+
+            "#,
+        )
+        .unwrap(),
+    )
+}
+
+fn lua_lang() -> Arc<Language> {
+    Arc::new(
+        Language::new(
+            LanguageConfig {
+                name: "Lua".into(),
+                path_suffixes: vec!["lua".into()],
+                collapsed_placeholder: "--[ ... ]--".to_string(),
+                ..Default::default()
+            },
+            Some(tree_sitter_lua::language()),
+        )
+        .with_embedding_query(
+            r#"
+            (
+                (comment)* @context
+                .
+                (function_declaration
+                    "function" @name
+                    name: (_) @name
+                    (comment)* @collapse
+                    body: (block) @collapse
+                ) @item
+            )
+        "#,
+        )
+        .unwrap(),
+    )
+}
+
+fn php_lang() -> Arc<Language> {
+    Arc::new(
+        Language::new(
+            LanguageConfig {
+                name: "PHP".into(),
+                path_suffixes: vec!["php".into()],
+                collapsed_placeholder: "/* ... */".into(),
+                ..Default::default()
+            },
+            Some(tree_sitter_php::language()),
+        )
+        .with_embedding_query(
+            r#"
+            (
+                (comment)* @context
+                .
+                [
+                    (function_definition
+                        "function" @name
+                        name: (_) @name
+                        body: (_
+                            "{" @keep
+                            "}" @keep) @collapse
+                        )
+
+                    (trait_declaration
+                        "trait" @name
+                        name: (_) @name)
+
+                    (method_declaration
+                        "function" @name
+                        name: (_) @name
+                        body: (_
+                            "{" @keep
+                            "}" @keep) @collapse
+                        )
+
+                    (interface_declaration
+                        "interface" @name
+                        name: (_) @name
+                        )
+
+                    (enum_declaration
+                        "enum" @name
+                        name: (_) @name
+                        )
+
+                ] @item
+            )
+            "#,
+        )
+        .unwrap(),
+    )
+}
+
+fn ruby_lang() -> Arc<Language> {
+    Arc::new(
+        Language::new(
+            LanguageConfig {
+                name: "Ruby".into(),
+                path_suffixes: vec!["rb".into()],
+                collapsed_placeholder: "# ...".to_string(),
+                ..Default::default()
+            },
+            Some(tree_sitter_ruby::language()),
+        )
+        .with_embedding_query(
+            r#"
+            (
+                (comment)* @context
+                .
+                [
+                (module
+                    "module" @name
+                    name: (_) @name)
+                (method
+                    "def" @name
+                    name: (_) @name
+                    body: (body_statement) @collapse)
+                (class
+                    "class" @name
+                    name: (_) @name)
+                (singleton_method
+                    "def" @name
+                    object: (_) @name
+                    "." @name
+                    name: (_) @name
+                    body: (body_statement) @collapse)
+                ] @item
+            )
+            "#,
+        )
+        .unwrap(),
+    )
+}
+
+fn elixir_lang() -> Arc<Language> {
+    Arc::new(
+        Language::new(
+            LanguageConfig {
+                name: "Elixir".into(),
+                path_suffixes: vec!["rs".into()],
+                ..Default::default()
+            },
+            Some(tree_sitter_elixir::language()),
+        )
+        .with_embedding_query(
+            r#"
+            (
+                (unary_operator
+                    operator: "@"
+                    operand: (call
+                        target: (identifier) @unary
+                        (#match? @unary "^(doc)$"))
+                    ) @context
+                .
+                (call
+                target: (identifier) @name
+                (arguments
+                [
+                (identifier) @name
+                (call
+                target: (identifier) @name)
+                (binary_operator
+                left: (call
+                target: (identifier) @name)
+                operator: "when")
+                ])
+                (#any-match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
+                )
+
+            (call
+                target: (identifier) @name
+                (arguments (alias) @name)
+                (#any-match? @name "^(defmodule|defprotocol)$")) @item
+            "#,
+        )
+        .unwrap(),
+    )
+}
+
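+// Exercises the range arithmetic used when collapsing spans: `@keep` ranges
+// are subtracted from `@collapse` ranges to produce the final set of ranges
+// replaced by the placeholder.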
+#[gpui::test]
+fn test_subtract_ranges() {
+    // The first argument holds the collapsed ranges, the second the ranges to keep.
+
+    assert_eq!(
+        subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
+        vec![1..4, 10..21]
+    );
+
+    assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
+}
+
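+// Registers the settings required by the semantic index and project
+// machinery in a fresh test app context.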
+fn init_test(cx: &mut TestAppContext) {
+    cx.update(|cx| {
+        let settings_store = SettingsStore::test(cx);
+        cx.set_global(settings_store);
+        SemanticIndexSettings::register(cx);
+        ProjectSettings::register(cx);
+    });
+}

crates/workspace2/src/workspace2.rs 🔗

@@ -3942,8 +3942,6 @@ impl std::fmt::Debug for OpenPaths {
     }
 }
 
-pub struct WorkspaceCreated(pub WeakView<Workspace>);
-
 pub fn activate_workspace_for_project(
     cx: &mut AppContext,
     predicate: impl Fn(&Project, &AppContext) -> bool + Send + 'static,