Detailed changes
@@ -304,6 +304,9 @@ name = "arrayvec"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
+dependencies = [
+ "serde",
+]

[[package]]
name = "as-raw-xcb-connection"
@@ -1709,6 +1712,19 @@ dependencies = [
"profiling",
]

+[[package]]
+name = "blake3"
+version = "1.5.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d82033247fd8e890df8f740e407ad4d038debb9eb1f40533fffb32e7d17dc6f7"
+dependencies = [
+ "arrayref",
+ "arrayvec",
+ "cc",
+ "cfg-if",
+ "constant_time_eq",
+]
+
[[package]]
name = "block"
version = "0.1.6"
@@ -2752,6 +2768,12 @@ dependencies = [
"tiny-keccak",
]

+[[package]]
+name = "constant_time_eq"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6"
+
[[package]]
name = "context_servers"
version = "0.1.0"
@@ -4187,6 +4209,7 @@ dependencies = [
name = "feature_flags"
version = "0.1.0"
dependencies = [
+ "futures 0.3.30",
"gpui",
]
@@ -9814,10 +9837,13 @@ name = "semantic_index"
version = "0.1.0"
dependencies = [
"anyhow",
+ "arrayvec",
+ "blake3",
"client",
"clock",
"collections",
"env_logger",
+ "feature_flags",
"fs",
"futures 0.3.30",
"futures-batch",
@@ -9825,6 +9851,7 @@ dependencies = [
"heed",
"http_client",
"language",
+ "language_model",
"languages",
"log",
"open_ai",
@@ -309,6 +309,7 @@ aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/alacritty/alacritty", rev = "91d034ff8b53867143c005acfaa14609147c9a2c" }
any_vec = "0.14"
anyhow = "1.0.86"
+arrayvec = { version = "0.7.4", features = ["serde"] }
ashpd = "0.9.1"
async-compression = { version = "0.4", features = ["gzip", "futures-io"] }
async-dispatcher = "0.1"
@@ -325,6 +326,7 @@ bitflags = "2.6.0"
blade-graphics = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" }
blade-macros = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" }
blade-util = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" }
+blake3 = "1.5.3"
cargo_metadata = "0.18"
cargo_toml = "0.20"
chrono = { version = "0.4", features = ["serde"] }
@@ -37,13 +37,13 @@ use language_model::{
pub(crate) use model_selector::*;
pub use prompts::PromptBuilder;
use prompts::PromptLoadingParams;
-use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
+use semantic_index::{CloudEmbeddingProvider, SemanticDb};
use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsStore};
use slash_command::{
- context_server_command, default_command, diagnostics_command, docs_command, fetch_command,
- file_command, now_command, project_command, prompt_command, search_command, symbols_command,
- tab_command, terminal_command, workflow_command,
+ auto_command, context_server_command, default_command, diagnostics_command, docs_command,
+ fetch_command, file_command, now_command, project_command, prompt_command, search_command,
+ symbols_command, tab_command, terminal_command, workflow_command,
};
use std::path::PathBuf;
use std::sync::Arc;
@@ -210,12 +210,13 @@ pub fn init(
let client = client.clone();
async move {
let embedding_provider = CloudEmbeddingProvider::new(client.clone());
- let semantic_index = SemanticIndex::new(
+ let semantic_index = SemanticDb::new(
paths::embeddings_dir().join("semantic-index-db.0.mdb"),
Arc::new(embedding_provider),
&mut cx,
)
.await?;
+
cx.update(|cx| cx.set_global(semantic_index))
}
})
@@ -364,6 +365,7 @@ fn update_active_language_model_from_settings(cx: &mut AppContext) {
fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
let slash_command_registry = SlashCommandRegistry::global(cx);
+
slash_command_registry.register_command(file_command::FileSlashCommand, true);
slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
slash_command_registry.register_command(tab_command::TabSlashCommand, true);
@@ -382,6 +384,17 @@ fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut
}
slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
+ cx.observe_flag::<auto_command::AutoSlashCommandFeatureFlag, _>({
+ let slash_command_registry = slash_command_registry.clone();
+ move |is_enabled, _cx| {
+ if is_enabled {
+ // [#auto-staff-ship] TODO remove this when /auto is no longer staff-shipped
+ slash_command_registry.register_command(auto_command::AutoCommand, true);
+ }
+ }
+ })
+ .detach();
+
update_slash_commands_from_settings(cx);
cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
.detach();
@@ -4723,6 +4723,20 @@ impl Render for ContextEditorToolbarItem {
let weak_self = cx.view().downgrade();
let right_side = h_flex()
.gap_2()
+ // TODO display this in a nicer way, once we have a design for it.
+ // .children({
+ // let project = self
+ // .workspace
+ // .upgrade()
+ // .map(|workspace| workspace.read(cx).project().downgrade());
+ //
+ // let scan_items_remaining = cx.update_global(|db: &mut SemanticDb, cx| {
+ // project.and_then(|project| db.remaining_summaries(&project, cx))
+ // });
+
+ // scan_items_remaining
+ // .map(|remaining_items| format!("Files to scan: {}", remaining_items))
+ // })
.child(
ModelSelector::new(
self.fs.clone(),
@@ -519,6 +519,7 @@ impl Settings for AssistantSettings {
&mut settings.default_model,
value.default_model.map(Into::into),
);
+ // merge(&mut settings.infer_context, value.infer_context); TODO re-enable this once we ship context inference
}
Ok(settings)
@@ -19,6 +19,7 @@ use std::{
use ui::ActiveTheme;
use workspace::Workspace;
+pub mod auto_command;
pub mod context_server_command;
pub mod default_command;
pub mod diagnostics_command;
@@ -0,0 +1,360 @@
+use super::create_label_for_command;
+use super::{SlashCommand, SlashCommandOutput};
+use anyhow::{anyhow, Result};
+use assistant_slash_command::ArgumentCompletion;
+use feature_flags::FeatureFlag;
+use futures::StreamExt;
+use gpui::{AppContext, AsyncAppContext, Task, WeakView};
+use language::{CodeLabel, LspAdapterDelegate};
+use language_model::{
+ LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
+ LanguageModelRequestMessage, Role,
+};
+use semantic_index::{FileSummary, SemanticDb};
+use smol::channel;
+use std::sync::{atomic::AtomicBool, Arc};
+use ui::{BorrowAppContext, WindowContext};
+use util::ResultExt;
+use workspace::Workspace;
+
+pub struct AutoSlashCommandFeatureFlag;
+
+impl FeatureFlag for AutoSlashCommandFeatureFlag {
+ const NAME: &'static str = "auto-slash-command";
+}
+
+pub(crate) struct AutoCommand;
+
+impl SlashCommand for AutoCommand {
+ fn name(&self) -> String {
+ "auto".into()
+ }
+
+ fn description(&self) -> String {
+ "Automatically infer what context to add, based on your prompt".into()
+ }
+
+ fn menu_text(&self) -> String {
+ "Automatically Infer Context".into()
+ }
+
+ fn label(&self, cx: &AppContext) -> CodeLabel {
+ create_label_for_command("auto", &["--prompt"], cx)
+ }
+
+ fn complete_argument(
+ self: Arc<Self>,
+ _arguments: &[String],
+ _cancel: Arc<AtomicBool>,
+ workspace: Option<WeakView<Workspace>>,
+ cx: &mut WindowContext,
+ ) -> Task<Result<Vec<ArgumentCompletion>>> {
+ // There's no autocomplete for a prompt, since it's arbitrary text.
+ // However, we can use this opportunity to kick off a drain of the backlog.
+ // That way, the backlog can hopefully be fully resummarized by the time we've
+ // finished typing out our prompt. This re-runs on every keystroke during autocomplete,
+ // but in the future, we could instead do it only once, when /auto is first entered.
+ let Some(workspace) = workspace.and_then(|ws| ws.upgrade()) else {
+ log::warn!("workspace was dropped or unavailable during /auto autocomplete");
+
+ return Task::ready(Ok(Vec::new()));
+ };
+
+ let project = workspace.read(cx).project().clone();
+ let Some(project_index) =
+ cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
+ else {
+ return Task::ready(Err(anyhow!("No project indexer, cannot use /auto")));
+ };
+
+ let cx: &mut AppContext = cx;
+
+ cx.spawn(|cx: gpui::AsyncAppContext| async move {
+ let task = project_index.read_with(&cx, |project_index, cx| {
+ project_index.flush_summary_backlogs(cx)
+ })?;
+
+ cx.background_executor().spawn(task).await;
+
+ anyhow::Ok(Vec::new())
+ })
+ }
+
+ fn requires_argument(&self) -> bool {
+ true
+ }
+
+ fn run(
+ self: Arc<Self>,
+ arguments: &[String],
+ workspace: WeakView<Workspace>,
+ _delegate: Option<Arc<dyn LspAdapterDelegate>>,
+ cx: &mut WindowContext,
+ ) -> Task<Result<SlashCommandOutput>> {
+ let Some(workspace) = workspace.upgrade() else {
+ return Task::ready(Err(anyhow::anyhow!("workspace was dropped")));
+ };
+ if arguments.is_empty() {
+ return Task::ready(Err(anyhow!("missing prompt")));
+ };
+ let argument = arguments.join(" ");
+ let original_prompt = argument.to_string();
+ let project = workspace.read(cx).project().clone();
+ let Some(project_index) =
+ cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
+ else {
+ return Task::ready(Err(anyhow!("no project indexer")));
+ };
+
+ let task = cx.spawn(|cx: gpui::AsyncWindowContext| async move {
+ let summaries = project_index
+ .read_with(&cx, |project_index, cx| project_index.all_summaries(cx))?
+ .await?;
+
+ commands_for_summaries(&summaries, &original_prompt, &cx).await
+ });
+
+ // As a convenience, append /auto's argument to the end of the prompt
+ // so you don't have to write it again.
+ let original_prompt = argument.to_string();
+
+ cx.background_executor().spawn(async move {
+ let commands = task.await?;
+ let mut prompt = String::new();
+
+ log::info!(
+ "Translating this response into slash-commands: {:?}",
+ commands
+ );
+
+ for command in commands {
+ prompt.push('/');
+ prompt.push_str(&command.name);
+ prompt.push(' ');
+ prompt.push_str(&command.arg);
+ prompt.push('\n');
+ }
+
+ prompt.push('\n');
+ prompt.push_str(&original_prompt);
+
+ Ok(SlashCommandOutput {
+ text: prompt,
+ sections: Vec::new(),
+ run_commands_in_text: true,
+ })
+ })
+ }
+}
+
+const PROMPT_INSTRUCTIONS_BEFORE_SUMMARY: &str = include_str!("prompt_before_summary.txt");
+const PROMPT_INSTRUCTIONS_AFTER_SUMMARY: &str = include_str!("prompt_after_summary.txt");
+
+fn summaries_prompt(summaries: &[FileSummary], original_prompt: &str) -> String {
+ let json_summaries = serde_json::to_string(summaries).unwrap();
+
+ format!("{PROMPT_INSTRUCTIONS_BEFORE_SUMMARY}\n{json_summaries}\n{PROMPT_INSTRUCTIONS_AFTER_SUMMARY}\n{original_prompt}")
+}
+
+/// The slash commands that the model is told about, and which we look for in the inference response.
+const SUPPORTED_SLASH_COMMANDS: &[&str] = &["search", "file"];
+
+#[derive(Debug, Clone)]
+struct CommandToRun {
+ name: String,
+ arg: String,
+}
+
+/// Given the pre-indexed file summaries for this project, as well as the original prompt
+/// string passed to `/auto`, get a list of slash commands to run, along with their arguments.
+///
+/// The model's response does not include the slashes (to reduce the chance that it makes a mistake),
+/// so turning one of these returned commands into a real slash-command-with-argument
+/// involves prepending a slash to its name.
+///
+/// This function will validate that each of the returned lines begins with one of SUPPORTED_SLASH_COMMANDS.
+/// Any other lines it encounters will be discarded, with a warning logged.
+async fn commands_for_summaries(
+ summaries: &[FileSummary],
+ original_prompt: &str,
+ cx: &AsyncAppContext,
+) -> Result<Vec<CommandToRun>> {
+ if summaries.is_empty() {
+ log::warn!("Inferring no context because there were no summaries available.");
+ return Ok(Vec::new());
+ }
+
+ // Use the globally configured model to translate the summaries into slash-commands,
+ // because Qwen2-7B-Instruct has not done a good job at that task.
+ let Some(model) = cx.update(|cx| LanguageModelRegistry::read_global(cx).active_model())? else {
+ log::warn!("Can't infer context because there's no active model.");
+ return Ok(Vec::new());
+ };
+ // Only go up to 90% of the actual max token count, to reduce chances of
+ // exceeding the token count due to inaccuracies in the token counting heuristic.
+ let max_token_count = (model.max_token_count() * 9) / 10;
+
+ // Rather than recursing (which would require this async function use a pinned box),
+ // we use an explicit stack of arguments and answers for when we need to "recurse."
+ let mut stack = vec![summaries];
+ let mut final_response = Vec::new();
+ let mut prompts = Vec::new();
+
+ // TODO We only need to create multiple Requests because we currently
+ // don't have the ability to tell if a CompletionProvider::complete response
+ // was a "too many tokens in this request" error. If we had that, then
+ // we could try the request once, instead of having to make separate requests
+ // to check the token count and then afterwards to run the actual prompt.
+ let make_request = |prompt: String| LanguageModelRequest {
+ messages: vec![LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![prompt.into()],
+ // Nothing in here will benefit from caching
+ cache: false,
+ }],
+ tools: Vec::new(),
+ stop: Vec::new(),
+ temperature: 1.0,
+ };
+
+ while let Some(current_summaries) = stack.pop() {
+ // The split can result in one slice being empty and the other having one element.
+ // Whenever that happens, skip the empty one.
+ if current_summaries.is_empty() {
+ continue;
+ }
+
+ log::info!(
+ "Inferring prompt context using {} file summaries",
+ current_summaries.len()
+ );
+
+ let prompt = summaries_prompt(&current_summaries, original_prompt);
+ let start = std::time::Instant::now();
+ // Per OpenAI, 1 token ~= 4 chars in English (we go with 4.5 to overestimate a bit, because failed API requests cost a lot of perf)
+ // Verifying this against an actual model.count_tokens() confirms that it's usually within ~5% of the correct answer, whereas
+ // getting the correct answer from tiktoken takes hundreds of milliseconds (compared to this arithmetic being ~free).
+ // source: https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
+ let token_estimate = prompt.len() * 2 / 9;
+ let duration = start.elapsed();
+ log::info!(
+ "Time taken to count tokens for prompt of length {:?}B: {:?}",
+ prompt.len(),
+ duration
+ );
+
+ if token_estimate < max_token_count {
+ prompts.push(prompt);
+ } else if current_summaries.len() == 1 {
+ log::warn!("Inferring context for a single file's summary failed because the prompt's token length exceeded the model's token limit.");
+ } else {
+ log::info!(
+ "Context inference using file summaries resulted in a prompt containing {token_estimate} tokens, which exceeded the model's max of {max_token_count}. Retrying as two separate prompts, each including half the number of summaries.",
+ );
+ let (left, right) = current_summaries.split_at(current_summaries.len() / 2);
+ stack.push(right);
+ stack.push(left);
+ }
+ }
+
+ let all_start = std::time::Instant::now();
+
+ let (tx, rx) = channel::bounded(1024);
+
+ let completion_streams = prompts
+ .into_iter()
+ .map(|prompt| {
+ let request = make_request(prompt.clone());
+ let model = model.clone();
+ let tx = tx.clone();
+ let stream = model.stream_completion(request, &cx);
+
+ (stream, tx)
+ })
+ .collect::<Vec<_>>();
+
+ cx.background_executor()
+ .spawn(async move {
+ let futures = completion_streams
+ .into_iter()
+ .enumerate()
+ .map(|(ix, (stream, tx))| async move {
+ let start = std::time::Instant::now();
+ let events = stream.await?;
+ log::info!("Time taken for awaiting /await chunk stream #{ix}: {:?}", start.elapsed());
+
+ let completion: String = events
+ .filter_map(|event| async {
+ if let Ok(LanguageModelCompletionEvent::Text(text)) = event {
+ Some(text)
+ } else {
+ None
+ }
+ })
+ .collect()
+ .await;
+
+ log::info!("Time taken for all /auto chunks to come back for #{ix}: {:?}", start.elapsed());
+
+ for line in completion.split('\n') {
+ if let Some(first_space) = line.find(' ') {
+ let command = &line[..first_space].trim();
+ let arg = &line[first_space..].trim();
+
+ tx.send(CommandToRun {
+ name: command.to_string(),
+ arg: arg.to_string(),
+ })
+ .await?;
+ } else if !line.trim().is_empty() {
+ // All slash-commands currently supported in context inference need a space for the argument.
+ log::warn!(
+ "Context inference returned a non-blank line that contained no spaces (meaning no argument for the slash command): {:?}",
+ line
+ );
+ }
+ }
+
+ anyhow::Ok(())
+ })
+ .collect::<Vec<_>>();
+
+ let _ = futures::future::try_join_all(futures).await.log_err();
+
+ let duration = all_start.elapsed();
+ eprintln!("All futures completed in {:?}", duration);
+ })
+ .await;
+
+ drop(tx); // Close the channel so that rx.collect() won't hang. This is safe because all futures have completed.
+ let results = rx.collect::<Vec<_>>().await;
+ eprintln!(
+ "Finished collecting from the channel with {} results",
+ results.len()
+ );
+ for command in results {
+ // Don't return empty or duplicate commands
+ if !command.name.is_empty()
+ && !final_response
+ .iter()
+ .any(|cmd: &CommandToRun| cmd.name == command.name && cmd.arg == command.arg)
+ {
+ if SUPPORTED_SLASH_COMMANDS
+ .iter()
+ .any(|supported| &command.name == supported)
+ {
+ final_response.push(command);
+ } else {
+ log::warn!(
+ "Context inference returned an unrecognized slash command: {:?}",
+ command
+ );
+ }
+ }
+ }
+
+ // Sort the commands by name (reversed just so that /search appears before /file)
+ final_response.sort_by(|cmd1, cmd2| cmd1.name.cmp(&cmd2.name).reverse());
+
+ Ok(final_response)
+}
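
The splitting strategy in `commands_for_summaries` is the core of the change: estimate tokens cheaply, and halve the summary set whenever a prompt would exceed the budget. A minimal standalone sketch of that loop (the function name and `String` summaries are illustrative, not from this diff):

```rust
/// Sketch of the prompt-splitting loop: ~4.5 chars/token, halve on overflow.
fn split_into_prompts(summaries: &[String], max_tokens: usize) -> Vec<String> {
    let mut stack = vec![summaries];
    let mut prompts = Vec::new();
    while let Some(current) = stack.pop() {
        if current.is_empty() {
            continue; // splitting a 1-element slice produces an empty half
        }
        let prompt = current.join("\n");
        let token_estimate = prompt.len() * 2 / 9; // len / 4.5
        if token_estimate < max_tokens {
            prompts.push(prompt);
        } else if current.len() == 1 {
            eprintln!("single summary exceeds the token budget; skipping");
        } else {
            let (left, right) = current.split_at(current.len() / 2);
            stack.push(right);
            stack.push(left);
        }
    }
    prompts
}
```

For example, a 9,000-character prompt estimates to 2,000 tokens, so it fits comfortably in one request for a large-context model but would be halved repeatedly for a small one.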
@@ -0,0 +1,24 @@
+Actions have a cost, so only include actions that you think
+will be helpful to you in doing a great job answering the
+prompt in the future.
+
+You must respond ONLY with a list of actions you would like to
+perform. Each action should be on its own line, and followed by a space and then its parameter.
+
+Actions can be performed more than once with different parameters.
+Here is an example valid response:
+
+```
+file path/to/my/file.txt
+file path/to/another/file.txt
+search something to search for
+search something else to search for
+```
+
+Once again, do not forget: you must respond ONLY in the format of
+one action per line, and the action name should be followed by
+its parameter. Your response must not include anything other
+than a list of actions, with one action per line, in this format.
+It is extremely important that you do not deviate from this format even slightly!
+
+This is the end of my instructions for how to respond. The rest is the prompt:
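
Given that contract, the response parsing in `auto_command.rs` reduces to splitting on newlines and on the first space of each line. A self-contained sketch of that parse step (the function name is illustrative):

```rust
// Parse an action-per-line response into (action, parameter) pairs.
// Lines without a space carry no parameter and are skipped, mirroring
// the warning path in commands_for_summaries.
fn parse_actions(response: &str) -> Vec<(String, String)> {
    response
        .lines()
        .filter_map(|line| {
            let first_space = line.find(' ')?;
            let name = line[..first_space].trim().to_string();
            let arg = line[first_space..].trim().to_string();
            Some((name, arg))
        })
        .collect()
}

#[test]
fn parses_example_response() {
    let response = "file path/to/my/file.txt\nsearch something to search for";
    assert_eq!(parse_actions(response).len(), 2);
}
```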
@@ -0,0 +1,31 @@
+I'm going to give you a prompt. I don't want you to respond
+to the prompt itself. I want you to figure out which of the following
+actions on my project, if any, would help you answer the prompt.
+
+Here are the actions:
+
+## file
+
+This action's parameter is a file path to one of the files
+in the project. If you ask for this action, I will tell you
+the full contents of the file, so you can learn all the
+details of the file.
+
+## search
+
+This action's parameter is a string to do a semantic search for
+across the files in the project. (You will have a JSON summary
+of all the files in the project.) It will tell you which files this string
+(or similar strings; it is a semantic search) appear in,
+as well as some context of the lines surrounding each result.
+It's very important that you only use this action when you think
+that searching across the specific files in this project for the query
+in question will be useful. For example, don't use this command to search
+for queries you might put into a general Web search engine, because those
+will be too general to give useful results in this project-specific search.
+
+---
+
+That was the end of the list of actions.
+
+Here is a JSON summary of each of the files in my project:
@@ -8,7 +8,7 @@ use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
use feature_flags::FeatureFlag;
use gpui::{AppContext, Task, WeakView};
use language::{CodeLabel, LineEnding, LspAdapterDelegate};
-use semantic_index::SemanticIndex;
+use semantic_index::SemanticDb;
use std::{
fmt::Write,
path::PathBuf,
@@ -92,8 +92,11 @@ impl SlashCommand for SearchSlashCommand {
let project = workspace.read(cx).project().clone();
let fs = project.read(cx).fs().clone();
- let project_index =
- cx.update_global(|index: &mut SemanticIndex, cx| index.project_index(project, cx));
+ let Some(project_index) =
+ cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
+ else {
+ return Task::ready(Err(anyhow::anyhow!("no project indexer")));
+ };
cx.spawn(|cx| async move {
let results = project_index
@@ -149,16 +149,16 @@ spec:
secretKeyRef:
name: google-ai
key: api_key
- - name: QWEN2_7B_API_KEY
+ - name: RUNPOD_API_KEY
valueFrom:
secretKeyRef:
- name: hugging-face
+ name: runpod
key: api_key
- - name: QWEN2_7B_API_URL
+ - name: RUNPOD_API_SUMMARY_URL
valueFrom:
secretKeyRef:
- name: hugging-face
- key: qwen2_api_url
+ name: runpod
+ key: summary
- name: BLOB_STORE_ACCESS_KEY
valueFrom:
secretKeyRef:
@@ -728,6 +728,11 @@ impl Database {
is_ignored: db_entry.is_ignored,
is_external: db_entry.is_external,
git_status: db_entry.git_status.map(|status| status as i32),
+ // This is only used in the summarization backlog, so if it's None,
+ // that just means we won't be able to detect when to resummarize
+ // based on total number of backlogged bytes - instead, we'd go
+ // on number of files only. That shouldn't be a huge deal in practice.
+ size: None,
is_fifo: db_entry.is_fifo,
});
}
@@ -663,6 +663,11 @@ impl Database {
is_ignored: db_entry.is_ignored,
is_external: db_entry.is_external,
git_status: db_entry.git_status.map(|status| status as i32),
+ // This is only used in the summarization backlog, so if it's None,
+ // that just means we won't be able to detect when to resummarize
+ // based on total number of backlogged bytes - instead, we'd go
+ // on number of files only. That shouldn't be a huge deal in practice.
+ size: None,
is_fifo: db_entry.is_fifo,
});
}
@@ -170,8 +170,8 @@ pub struct Config {
pub anthropic_api_key: Option<Arc<str>>,
pub anthropic_staff_api_key: Option<Arc<str>>,
pub llm_closed_beta_model_name: Option<Arc<str>>,
- pub qwen2_7b_api_key: Option<Arc<str>>,
- pub qwen2_7b_api_url: Option<Arc<str>>,
+ pub runpod_api_key: Option<Arc<str>>,
+ pub runpod_api_summary_url: Option<Arc<str>>,
pub zed_client_checksum_seed: Option<String>,
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
@@ -235,8 +235,8 @@ impl Config {
stripe_api_key: None,
stripe_price_id: None,
supermaven_admin_api_key: None,
- qwen2_7b_api_key: None,
- qwen2_7b_api_url: None,
+ runpod_api_key: None,
+ runpod_api_summary_url: None,
user_backfiller_github_access_token: None,
}
}
@@ -402,12 +402,12 @@ async fn perform_completion(
LanguageModelProvider::Zed => {
let api_key = state
.config
- .qwen2_7b_api_key
+ .runpod_api_key
.as_ref()
.context("no Qwen2-7B API key configured on the server")?;
let api_url = state
.config
- .qwen2_7b_api_url
+ .runpod_api_summary_url
.as_ref()
.context("no Qwen2-7B URL configured on the server")?;
let chunks = open_ai::stream_completion(
@@ -1,5 +1,5 @@
use super::*;
-use sea_orm::QueryOrder;
+use sea_orm::{sea_query::OnConflict, QueryOrder};
use std::str::FromStr;
use strum::IntoEnumIterator as _;
@@ -99,6 +99,17 @@ impl LlmDatabase {
..Default::default()
}
}))
+ .on_conflict(
+ OnConflict::columns([model::Column::ProviderId, model::Column::Name])
+ .update_columns([
+ model::Column::MaxRequestsPerMinute,
+ model::Column::MaxTokensPerMinute,
+ model::Column::MaxTokensPerDay,
+ model::Column::PricePerMillionInputTokens,
+ model::Column::PricePerMillionOutputTokens,
+ ])
+ .to_owned(),
+ )
.exec_without_returning(&*tx)
.await?;
Ok(())
@@ -40,6 +40,15 @@ pub async fn seed_database(_config: &Config, db: &mut LlmDatabase, _force: bool)
price_per_million_input_tokens: 25, // $0.25/MTok
price_per_million_output_tokens: 125, // $1.25/MTok
},
+ ModelParams {
+ provider: LanguageModelProvider::Zed,
+ name: "Qwen/Qwen2-7B-Instruct".into(),
+ max_requests_per_minute: 5,
+ max_tokens_per_minute: 25_000, // These are arbitrary limits we've set to cap costs; we control this number
+ max_tokens_per_day: 300_000,
+ price_per_million_input_tokens: 25,
+ price_per_million_output_tokens: 125,
+ },
])
.await
}
@@ -679,8 +679,8 @@ impl TestServer {
stripe_api_key: None,
stripe_price_id: None,
supermaven_admin_api_key: None,
- qwen2_7b_api_key: None,
- qwen2_7b_api_url: None,
+ runpod_api_key: None,
+ runpod_api_summary_url: None,
user_backfiller_github_access_token: None,
},
})
@@ -13,3 +13,4 @@ path = "src/feature_flags.rs"
[dependencies]
gpui.workspace = true
+futures.workspace = true
@@ -1,4 +1,10 @@
+use futures::{channel::oneshot, FutureExt as _};
use gpui::{AppContext, Global, Subscription, ViewContext};
+use std::{
+ future::Future,
+ pin::Pin,
+ task::{Context, Poll},
+};
#[derive(Default)]
struct FeatureFlags {
@@ -53,6 +59,15 @@ impl FeatureFlag for ZedPro {
const NAME: &'static str = "zed-pro";
}
+pub struct AutoCommand {}
+impl FeatureFlag for AutoCommand {
+ const NAME: &'static str = "auto-command";
+
+ fn enabled_for_staff() -> bool {
+ false
+ }
+}
+
pub trait FeatureFlagViewExt<V: 'static> {
fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
where
@@ -75,6 +90,7 @@ where
}
pub trait FeatureFlagAppExt {
+ fn wait_for_flag<T: FeatureFlag>(&mut self) -> WaitForFlag;
fn update_flags(&mut self, staff: bool, flags: Vec<String>);
fn set_staff(&mut self, staff: bool);
fn has_flag<T: FeatureFlag>(&self) -> bool;
@@ -82,7 +98,7 @@ pub trait FeatureFlagAppExt {
fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
where
- F: Fn(bool, &mut AppContext) + 'static;
+ F: FnMut(bool, &mut AppContext) + 'static;
}
impl FeatureFlagAppExt for AppContext {
@@ -109,13 +125,49 @@ impl FeatureFlagAppExt for AppContext {
.unwrap_or(false)
}
- fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
+ fn observe_flag<T: FeatureFlag, F>(&mut self, mut callback: F) -> Subscription
where
- F: Fn(bool, &mut AppContext) + 'static,
+ F: FnMut(bool, &mut AppContext) + 'static,
{
self.observe_global::<FeatureFlags>(move |cx| {
let feature_flags = cx.global::<FeatureFlags>();
callback(feature_flags.has_flag::<T>(), cx);
})
}
+
+ fn wait_for_flag<T: FeatureFlag>(&mut self) -> WaitForFlag {
+ let (tx, rx) = oneshot::channel::<bool>();
+ let mut tx = Some(tx);
+ let subscription: Option<Subscription>;
+
+ match self.try_global::<FeatureFlags>() {
+ Some(feature_flags) => {
+ subscription = None;
+ tx.take().unwrap().send(feature_flags.has_flag::<T>()).ok();
+ }
+ None => {
+ subscription = Some(self.observe_global::<FeatureFlags>(move |cx| {
+ let feature_flags = cx.global::<FeatureFlags>();
+ if let Some(tx) = tx.take() {
+ tx.send(feature_flags.has_flag::<T>()).ok();
+ }
+ }));
+ }
+ }
+
+ WaitForFlag(rx, subscription)
+ }
+}
+
+pub struct WaitForFlag(oneshot::Receiver<bool>, Option<Subscription>);
+
+impl Future for WaitForFlag {
+ type Output = bool;
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ self.0.poll_unpin(cx).map(|result| {
+ self.1.take();
+ result.unwrap_or(false)
+ })
+ }
}
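
`wait_for_flag` turns the flag state into a one-shot future: it resolves immediately if the `FeatureFlags` global already exists, and otherwise subscribes until the flags first arrive. A hedged sketch of a consumer (`MyFlag` and the call site are hypothetical, not part of this diff):

```rust
struct MyFlag;

impl FeatureFlag for MyFlag {
    const NAME: &'static str = "my-flag";
}

fn spawn_when_flagged(cx: &mut AppContext) {
    // WaitForFlag is 'static, so it can be created here and awaited later;
    // its Subscription is dropped as soon as the future resolves.
    let flag = cx.wait_for_flag::<MyFlag>();
    cx.spawn(|_cx| async move {
        if flag.await {
            log::info!("my-flag enabled; doing flag-gated setup");
        }
    })
    .detach();
}
```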
@@ -171,6 +171,7 @@ pub struct Metadata {
pub mtime: SystemTime,
pub is_symlink: bool,
pub is_dir: bool,
+ pub len: u64,
pub is_fifo: bool,
}
@@ -497,6 +498,7 @@ impl Fs for RealFs {
Ok(Some(Metadata {
inode,
mtime: metadata.modified().unwrap(),
+ len: metadata.len(),
is_symlink,
is_dir: metadata.file_type().is_dir(),
is_fifo,
@@ -800,11 +802,13 @@ enum FakeFsEntry {
File {
inode: u64,
mtime: SystemTime,
+ len: u64,
content: Vec<u8>,
},
Dir {
inode: u64,
mtime: SystemTime,
+ len: u64,
entries: BTreeMap<String, Arc<Mutex<FakeFsEntry>>>,
git_repo_state: Option<Arc<Mutex<git::repository::FakeGitRepositoryState>>>,
},
@@ -935,6 +939,7 @@ impl FakeFs {
root: Arc::new(Mutex::new(FakeFsEntry::Dir {
inode: 0,
mtime: SystemTime::UNIX_EPOCH,
+ len: 0,
entries: Default::default(),
git_repo_state: None,
})),
@@ -969,6 +974,7 @@ impl FakeFs {
inode: new_inode,
mtime: new_mtime,
content: Vec::new(),
+ len: 0,
})));
}
btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut().lock() {
@@ -1016,6 +1022,7 @@ impl FakeFs {
let file = Arc::new(Mutex::new(FakeFsEntry::File {
inode,
mtime,
+ len: content.len() as u64,
content,
}));
let mut kind = None;
@@ -1369,6 +1376,7 @@ impl Fs for FakeFs {
Arc::new(Mutex::new(FakeFsEntry::Dir {
inode,
mtime,
+ len: 0,
entries: Default::default(),
git_repo_state: None,
}))
@@ -1391,6 +1399,7 @@ impl Fs for FakeFs {
let file = Arc::new(Mutex::new(FakeFsEntry::File {
inode,
mtime,
+ len: 0,
content: Vec::new(),
}));
let mut kind = Some(PathEventKind::Created);
@@ -1539,6 +1548,7 @@ impl Fs for FakeFs {
e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
inode,
mtime,
+ len: content.len() as u64,
content: Vec::new(),
})))
.clone(),
@@ -1694,16 +1704,22 @@ impl Fs for FakeFs {
let entry = entry.lock();
Ok(Some(match &*entry {
- FakeFsEntry::File { inode, mtime, .. } => Metadata {
+ FakeFsEntry::File {
+ inode, mtime, len, ..
+ } => Metadata {
inode: *inode,
mtime: *mtime,
+ len: *len,
is_dir: false,
is_symlink,
is_fifo: false,
},
- FakeFsEntry::Dir { inode, mtime, .. } => Metadata {
+ FakeFsEntry::Dir {
+ inode, mtime, len, ..
+ } => Metadata {
inode: *inode,
mtime: *mtime,
+ len: *len,
is_dir: true,
is_symlink,
is_fifo: false,
@@ -57,7 +57,6 @@ impl GitStatus {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(anyhow!("git status process failed: {}", stderr));
}
-
let stdout = String::from_utf8_lossy(&output.stdout);
let mut entries = stdout
.split('\0')
@@ -221,6 +221,10 @@ impl HttpClient for HttpClientWithUrl {
pub fn client(user_agent: Option<String>, proxy: Option<Uri>) -> Arc<dyn HttpClient> {
let mut builder = isahc::HttpClient::builder()
+ // Some requests to Qwen2 models on Runpod can take 32+ seconds,
+ // especially if there's a cold boot involved. We may need to have
+ // those requests use a different http client, because global timeouts
+ // of 50 and 60 seconds, respectively, would be very high!
.connect_timeout(Duration::from_secs(5))
.low_speed_timeout(100, Duration::from_secs(5))
.proxy(proxy.clone());
@@ -17,14 +17,14 @@ pub enum CloudModel {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
pub enum ZedModel {
- #[serde(rename = "qwen2-7b-instruct")]
+ #[serde(rename = "Qwen/Qwen2-7B-Instruct")]
Qwen2_7bInstruct,
}
impl ZedModel {
pub fn id(&self) -> &str {
match self {
- ZedModel::Qwen2_7bInstruct => "qwen2-7b-instruct",
+ ZedModel::Qwen2_7bInstruct => "Qwen/Qwen2-7B-Instruct",
}
}
@@ -319,7 +319,7 @@ impl AnthropicModel {
};
async move {
- let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+ let api_key = api_key.ok_or_else(|| anyhow!("Missing Anthropic API Key"))?;
let request = anthropic::stream_completion(
http_client.as_ref(),
&api_url,
@@ -265,7 +265,7 @@ impl LanguageModel for GoogleLanguageModel {
let low_speed_timeout = settings.low_speed_timeout;
async move {
- let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+ let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API key"))?;
let response = google_ai::count_tokens(
http_client.as_ref(),
&api_url,
@@ -304,7 +304,7 @@ impl LanguageModel for GoogleLanguageModel {
};
let future = self.rate_limiter.stream(async move {
- let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+ let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API Key"))?;
let response = stream_generate_content(
http_client.as_ref(),
&api_url,
@@ -239,7 +239,7 @@ impl OpenAiLanguageModel {
};
let future = self.request_limiter.stream(async move {
- let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+ let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenAI API Key"))?;
let request = stream_completion(
http_client.as_ref(),
&api_url,
@@ -159,11 +159,13 @@ impl LanguageModelRegistry {
providers
}
- pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
+ pub fn available_models<'a>(
+ &'a self,
+ cx: &'a AppContext,
+ ) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
self.providers
.values()
.flat_map(|provider| provider.provided_models(cx))
- .collect()
}
pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
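
`available_models` now returns a lazy iterator instead of an allocated `Vec`, so call sites either consume it directly or collect explicitly. A hypothetical call site under that assumption:

```rust
// Counting models no longer allocates an intermediate Vec.
fn model_count(cx: &AppContext) -> usize {
    LanguageModelRegistry::read_global(cx)
        .available_models(cx)
        .count()
}
```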
@@ -1823,6 +1823,7 @@ impl ProjectPanel {
path: entry.path.join("\0").into(),
inode: 0,
mtime: entry.mtime,
+ size: entry.size,
is_ignored: entry.is_ignored,
is_external: false,
is_private: false,
@@ -1855,6 +1855,7 @@ message Entry {
bool is_external = 8;
optional GitStatus git_status = 9;
bool is_fifo = 10;
+ optional uint64 size = 11;
}
message RepositoryEntry {
@@ -19,14 +19,18 @@ crate-type = ["bin"]
[dependencies]
anyhow.workspace = true
+arrayvec.workspace = true
+blake3.workspace = true
client.workspace = true
clock.workspace = true
collections.workspace = true
+feature_flags.workspace = true
fs.workspace = true
futures.workspace = true
futures-batch.workspace = true
gpui.workspace = true
language.workspace = true
+language_model.workspace = true
log.workspace = true
heed.workspace = true
http_client.workspace = true
@@ -4,7 +4,7 @@ use gpui::App;
use http_client::HttpClientWithUrl;
use language::language_settings::AllLanguageSettings;
use project::Project;
-use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticIndex};
+use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticDb};
use settings::SettingsStore;
use std::{
path::{Path, PathBuf},
@@ -50,7 +50,7 @@ fn main() {
));
cx.spawn(|mut cx| async move {
- let semantic_index = SemanticIndex::new(
+ let semantic_index = SemanticDb::new(
PathBuf::from("/tmp/semantic-index-db.mdb"),
embedding_provider,
&mut cx,
@@ -71,6 +71,7 @@ fn main() {
let project_index = cx
.update(|cx| semantic_index.project_index(project.clone(), cx))
+ .unwrap()
.unwrap();
let (tx, rx) = oneshot::channel();
@@ -12,6 +12,12 @@ use futures::{future::BoxFuture, FutureExt};
use serde::{Deserialize, Serialize};
use std::{fmt, future};
+/// Trait for embedding providers. Texts in, vectors out.
+pub trait EmbeddingProvider: Sync + Send {
+ fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
+ fn batch_size(&self) -> usize;
+}
+
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct Embedding(Vec<f32>);
@@ -68,12 +74,6 @@ impl fmt::Display for Embedding {
}
}
-/// Trait for embedding providers. Texts in, vectors out.
-pub trait EmbeddingProvider: Sync + Send {
- fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
- fn batch_size(&self) -> usize;
-}
-
#[derive(Debug)]
pub struct TextToEmbed<'a> {
pub text: &'a str,
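
With the trait hoisted to the top of the module, a minimal in-crate implementation looks like the test-style sketch below (not part of this diff); it relies on the module's existing `BoxFuture`, `FutureExt`, and `std::future` imports:

```rust
// Returns one default (empty) embedding per input text.
struct FakeEmbeddingProvider;

impl EmbeddingProvider for FakeEmbeddingProvider {
    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
        future::ready(Ok(texts.iter().map(|_| Embedding::default()).collect())).boxed()
    }

    fn batch_size(&self) -> usize {
        16
    }
}
```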
@@ -0,0 +1,469 @@
+use crate::{
+ chunking::{self, Chunk},
+ embedding::{Embedding, EmbeddingProvider, TextToEmbed},
+ indexing::{IndexingEntryHandle, IndexingEntrySet},
+};
+use anyhow::{anyhow, Context as _, Result};
+use collections::Bound;
+use fs::Fs;
+use futures::stream::StreamExt;
+use futures_batch::ChunksTimeoutStreamExt;
+use gpui::{AppContext, Model, Task};
+use heed::types::{SerdeBincode, Str};
+use language::LanguageRegistry;
+use log;
+use project::{Entry, UpdatedEntriesSet, Worktree};
+use serde::{Deserialize, Serialize};
+use smol::channel;
+use std::{
+ cmp::Ordering,
+ future::Future,
+ iter,
+ path::Path,
+ sync::Arc,
+ time::{Duration, SystemTime},
+};
+use util::ResultExt;
+use worktree::Snapshot;
+
+pub struct EmbeddingIndex {
+ worktree: Model<Worktree>,
+ db_connection: heed::Env,
+ db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
+ fs: Arc<dyn Fs>,
+ language_registry: Arc<LanguageRegistry>,
+ embedding_provider: Arc<dyn EmbeddingProvider>,
+ entry_ids_being_indexed: Arc<IndexingEntrySet>,
+}
+
+impl EmbeddingIndex {
+ pub fn new(
+ worktree: Model<Worktree>,
+ fs: Arc<dyn Fs>,
+ db_connection: heed::Env,
+ embedding_db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
+ language_registry: Arc<LanguageRegistry>,
+ embedding_provider: Arc<dyn EmbeddingProvider>,
+ entry_ids_being_indexed: Arc<IndexingEntrySet>,
+ ) -> Self {
+ Self {
+ worktree,
+ fs,
+ db_connection,
+ db: embedding_db,
+ language_registry,
+ embedding_provider,
+ entry_ids_being_indexed,
+ }
+ }
+
+ pub fn db(&self) -> &heed::Database<Str, SerdeBincode<EmbeddedFile>> {
+ &self.db
+ }
+
+ pub fn index_entries_changed_on_disk(
+ &self,
+ cx: &AppContext,
+ ) -> impl Future<Output = Result<()>> {
+ let worktree = self.worktree.read(cx).snapshot();
+ let worktree_abs_path = worktree.abs_path().clone();
+ let scan = self.scan_entries(worktree, cx);
+ let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
+ let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
+ let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
+ async move {
+ futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
+ Ok(())
+ }
+ }
+
+ pub fn index_updated_entries(
+ &self,
+ updated_entries: UpdatedEntriesSet,
+ cx: &AppContext,
+ ) -> impl Future<Output = Result<()>> {
+ let worktree = self.worktree.read(cx).snapshot();
+ let worktree_abs_path = worktree.abs_path().clone();
+ let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
+ let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
+ let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
+ let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
+ async move {
+ futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
+ Ok(())
+ }
+ }
+
+ fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
+ let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
+ let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
+ let db_connection = self.db_connection.clone();
+ let db = self.db;
+ let entries_being_indexed = self.entry_ids_being_indexed.clone();
+ let task = cx.background_executor().spawn(async move {
+ let txn = db_connection
+ .read_txn()
+ .context("failed to create read transaction")?;
+ let mut db_entries = db
+ .iter(&txn)
+ .context("failed to create iterator")?
+ .move_between_keys()
+ .peekable();
+
+ let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
+ for entry in worktree.files(false, 0) {
+ log::trace!("scanning for embedding index: {:?}", &entry.path);
+
+ let entry_db_key = db_key_for_path(&entry.path);
+
+ let mut saved_mtime = None;
+ while let Some(db_entry) = db_entries.peek() {
+ match db_entry {
+ Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
+ Ordering::Less => {
+ if let Some(deletion_range) = deletion_range.as_mut() {
+ deletion_range.1 = Bound::Included(db_path);
+ } else {
+ deletion_range =
+ Some((Bound::Included(db_path), Bound::Included(db_path)));
+ }
+
+ db_entries.next();
+ }
+ Ordering::Equal => {
+ if let Some(deletion_range) = deletion_range.take() {
+ deleted_entry_ranges_tx
+ .send((
+ deletion_range.0.map(ToString::to_string),
+ deletion_range.1.map(ToString::to_string),
+ ))
+ .await?;
+ }
+ saved_mtime = db_embedded_file.mtime;
+ db_entries.next();
+ break;
+ }
+ Ordering::Greater => {
+ break;
+ }
+ },
+ Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
+ }
+ }
+
+ if entry.mtime != saved_mtime {
+ let handle = entries_being_indexed.insert(entry.id);
+ updated_entries_tx.send((entry.clone(), handle)).await?;
+ }
+ }
+
+ if let Some(db_entry) = db_entries.next() {
+ let (db_path, _) = db_entry?;
+ deleted_entry_ranges_tx
+ .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
+ .await?;
+ }
+
+ Ok(())
+ });
+
+ ScanEntries {
+ updated_entries: updated_entries_rx,
+ deleted_entry_ranges: deleted_entry_ranges_rx,
+ task,
+ }
+ }
+
+ fn scan_updated_entries(
+ &self,
+ worktree: Snapshot,
+ updated_entries: UpdatedEntriesSet,
+ cx: &AppContext,
+ ) -> ScanEntries {
+ let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
+ let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
+ let entries_being_indexed = self.entry_ids_being_indexed.clone();
+ let task = cx.background_executor().spawn(async move {
+ for (path, entry_id, status) in updated_entries.iter() {
+ match status {
+ project::PathChange::Added
+ | project::PathChange::Updated
+ | project::PathChange::AddedOrUpdated => {
+ if let Some(entry) = worktree.entry_for_id(*entry_id) {
+ if entry.is_file() {
+ let handle = entries_being_indexed.insert(entry.id);
+ updated_entries_tx.send((entry.clone(), handle)).await?;
+ }
+ }
+ }
+ project::PathChange::Removed => {
+ let db_path = db_key_for_path(path);
+ deleted_entry_ranges_tx
+ .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
+ .await?;
+ }
+ project::PathChange::Loaded => {
+ // Do nothing.
+ }
+ }
+ }
+
+ Ok(())
+ });
+
+ ScanEntries {
+ updated_entries: updated_entries_rx,
+ deleted_entry_ranges: deleted_entry_ranges_rx,
+ task,
+ }
+ }
+
+ fn chunk_files(
+ &self,
+ worktree_abs_path: Arc<Path>,
+ entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
+ cx: &AppContext,
+ ) -> ChunkFiles {
+ let language_registry = self.language_registry.clone();
+ let fs = self.fs.clone();
+ let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
+ let task = cx.spawn(|cx| async move {
+ cx.background_executor()
+ .scoped(|cx| {
+ for _ in 0..cx.num_cpus() {
+ cx.spawn(async {
+ while let Ok((entry, handle)) = entries.recv().await {
+ let entry_abs_path = worktree_abs_path.join(&entry.path);
+ match fs.load(&entry_abs_path).await {
+ Ok(text) => {
+ let language = language_registry
+ .language_for_file_path(&entry.path)
+ .await
+ .ok();
+ let chunked_file = ChunkedFile {
+ chunks: chunking::chunk_text(
+ &text,
+ language.as_ref(),
+ &entry.path,
+ ),
+ handle,
+ path: entry.path,
+ mtime: entry.mtime,
+ text,
+ };
+
+ if chunked_files_tx.send(chunked_file).await.is_err() {
+ return;
+ }
+ }
+ Err(_)=> {
+ log::error!("Failed to read contents into a UTF-8 string: {entry_abs_path:?}");
+ }
+ }
+ }
+ });
+ }
+ })
+ .await;
+ Ok(())
+ });
+
+ ChunkFiles {
+ files: chunked_files_rx,
+ task,
+ }
+ }
+
+ pub fn embed_files(
+ embedding_provider: Arc<dyn EmbeddingProvider>,
+ chunked_files: channel::Receiver<ChunkedFile>,
+ cx: &AppContext,
+ ) -> EmbedFiles {
+ let embedding_provider = embedding_provider.clone();
+ let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
+ let task = cx.background_executor().spawn(async move {
+ let mut chunked_file_batches =
+ chunked_files.chunks_timeout(512, Duration::from_secs(2));
+ while let Some(chunked_files) = chunked_file_batches.next().await {
+ // Flatten the batch of files into a single vec of chunks, which we then
+ // subdivide into batch-sized pieces for the embedding provider. Once those
+ // are embedded, reassemble them back into the files they belong to. If any
+ // chunk of a file fails to embed, the entire file is discarded.
+
+ let chunks: Vec<TextToEmbed> = chunked_files
+ .iter()
+ .flat_map(|file| {
+ file.chunks.iter().map(|chunk| TextToEmbed {
+ text: &file.text[chunk.range.clone()],
+ digest: chunk.digest,
+ })
+ })
+ .collect::<Vec<_>>();
+
+ let mut embeddings: Vec<Option<Embedding>> = Vec::new();
+ for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
+ if let Some(batch_embeddings) =
+ embedding_provider.embed(embedding_batch).await.log_err()
+ {
+ if batch_embeddings.len() == embedding_batch.len() {
+ embeddings.extend(batch_embeddings.into_iter().map(Some));
+ continue;
+ }
+ log::error!(
+ "embedding provider returned unexpected embedding count {}, expected {}",
+ batch_embeddings.len(), embedding_batch.len()
+ );
+ }
+
+ embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
+ }
+
+ let mut embeddings = embeddings.into_iter();
+ for chunked_file in chunked_files {
+ let mut embedded_file = EmbeddedFile {
+ path: chunked_file.path,
+ mtime: chunked_file.mtime,
+ chunks: Vec::new(),
+ };
+
+ let mut embedded_all_chunks = true;
+ for (chunk, embedding) in
+ chunked_file.chunks.into_iter().zip(embeddings.by_ref())
+ {
+ if let Some(embedding) = embedding {
+ embedded_file
+ .chunks
+ .push(EmbeddedChunk { chunk, embedding });
+ } else {
+ embedded_all_chunks = false;
+ }
+ }
+
+ if embedded_all_chunks {
+ embedded_files_tx
+ .send((embedded_file, chunked_file.handle))
+ .await?;
+ }
+ }
+ }
+ Ok(())
+ });
+
+ EmbedFiles {
+ files: embedded_files_rx,
+ task,
+ }
+ }
+
+ fn persist_embeddings(
+ &self,
+ mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
+ embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
+ cx: &AppContext,
+ ) -> Task<Result<()>> {
+ let db_connection = self.db_connection.clone();
+ let db = self.db;
+ cx.background_executor().spawn(async move {
+ while let Some(deletion_range) = deleted_entry_ranges.next().await {
+ let mut txn = db_connection.write_txn()?;
+ let start = deletion_range.0.as_ref().map(|start| start.as_str());
+ let end = deletion_range.1.as_ref().map(|end| end.as_str());
+ log::debug!("deleting embeddings in range {:?}", &(start, end));
+ db.delete_range(&mut txn, &(start, end))?;
+ txn.commit()?;
+ }
+
+ let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
+ while let Some(embedded_files) = embedded_files.next().await {
+ let mut txn = db_connection.write_txn()?;
+ for (file, _) in &embedded_files {
+ log::debug!("saving embedding for file {:?}", file.path);
+ let key = db_key_for_path(&file.path);
+ db.put(&mut txn, &key, file)?;
+ }
+ txn.commit()?;
+
+ drop(embedded_files);
+ log::debug!("committed");
+ }
+
+ Ok(())
+ })
+ }
+
+ pub fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
+ let connection = self.db_connection.clone();
+ let db = self.db;
+ cx.background_executor().spawn(async move {
+ let tx = connection
+ .read_txn()
+ .context("failed to create read transaction")?;
+ let result = db
+ .iter(&tx)?
+ .map(|entry| Ok(entry?.1.path.clone()))
+ .collect::<Result<Vec<Arc<Path>>>>();
+ drop(tx);
+ result
+ })
+ }
+
+ pub fn chunks_for_path(
+ &self,
+ path: Arc<Path>,
+ cx: &AppContext,
+ ) -> Task<Result<Vec<EmbeddedChunk>>> {
+ let connection = self.db_connection.clone();
+ let db = self.db;
+ cx.background_executor().spawn(async move {
+ let tx = connection
+ .read_txn()
+ .context("failed to create read transaction")?;
+ Ok(db
+ .get(&tx, &db_key_for_path(&path))?
+ .ok_or_else(|| anyhow!("no such path"))?
+ .chunks
+ .clone())
+ })
+ }
+}
+
+struct ScanEntries {
+ updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
+ deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
+ task: Task<Result<()>>,
+}
+
+struct ChunkFiles {
+ files: channel::Receiver<ChunkedFile>,
+ task: Task<Result<()>>,
+}
+
+pub struct ChunkedFile {
+ pub path: Arc<Path>,
+ pub mtime: Option<SystemTime>,
+ pub handle: IndexingEntryHandle,
+ pub text: String,
+ pub chunks: Vec<Chunk>,
+}
+
+pub struct EmbedFiles {
+ pub files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
+ pub task: Task<Result<()>>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct EmbeddedFile {
+ pub path: Arc<Path>,
+ pub mtime: Option<SystemTime>,
+ pub chunks: Vec<EmbeddedChunk>,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct EmbeddedChunk {
+ pub chunk: Chunk,
+ pub embedding: Embedding,
+}
+
+fn db_key_for_path(path: &Arc<Path>) -> String {
+ path.to_string_lossy().replace('/', "\0")
+}
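
The `db_key_for_path` scheme swaps `/` for `\0` so that a directory's descendants stay contiguous in heed's lexicographic key order (`\0` sorts below every printable byte, while `/` does not), which is what lets `persist_embeddings` and `scan_entries` delete whole subtrees as single key ranges. A standalone sketch of that ordering property:

```rust
fn db_key_for_path(path: &std::path::Path) -> String {
    path.to_string_lossy().replace('/', "\0")
}

fn main() {
    assert_eq!(db_key_for_path("src/main.rs".as_ref()), "src\0main.rs");
    // With raw paths, "src/z.rs" > "src-old/a.rs" ('/' = 0x2F, '-' = 0x2D),
    // so src/'s files would interleave with src-old/'s. The \0 keys fix that:
    assert!(db_key_for_path("src/z.rs".as_ref()) < db_key_for_path("src-old/a.rs".as_ref()));
}
```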
@@ -0,0 +1,49 @@
+use collections::HashSet;
+use parking_lot::Mutex;
+use project::ProjectEntryId;
+use smol::channel;
+use std::sync::{Arc, Weak};
+
+/// The set of entries that are currently being indexed.
+pub struct IndexingEntrySet {
+ entry_ids: Mutex<HashSet<ProjectEntryId>>,
+ tx: channel::Sender<()>,
+}
+
+/// When dropped, removes the entry from the set of entries that are being indexed.
+#[derive(Clone)]
+pub(crate) struct IndexingEntryHandle {
+ entry_id: ProjectEntryId,
+ set: Weak<IndexingEntrySet>,
+}
+
+impl IndexingEntrySet {
+ pub fn new(tx: channel::Sender<()>) -> Self {
+ Self {
+ entry_ids: Default::default(),
+ tx,
+ }
+ }
+
+ pub fn insert(self: &Arc<Self>, entry_id: ProjectEntryId) -> IndexingEntryHandle {
+ self.entry_ids.lock().insert(entry_id);
+ self.tx.send_blocking(()).ok();
+ IndexingEntryHandle {
+ entry_id,
+ set: Arc::downgrade(self),
+ }
+ }
+
+ pub fn len(&self) -> usize {
+ self.entry_ids.lock().len()
+ }
+}
+
+impl Drop for IndexingEntryHandle {
+ fn drop(&mut self) {
+ if let Some(set) = self.set.upgrade() {
+ set.tx.send_blocking(()).ok();
+ set.entry_ids.lock().remove(&self.entry_id);
+ }
+ }
+}
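
`IndexingEntryHandle` is an RAII guard: inserting an entry pings the status channel, and dropping the handle removes the entry and pings it again, so `ProjectIndex` can recompute its remaining count on every transition. A within-crate, test-style sketch (the `ProjectEntryId` constructor is assumed; `send_blocking` on an unbounded channel never blocks):

```rust
let (tx, rx) = smol::channel::unbounded::<()>();
let set = std::sync::Arc::new(IndexingEntrySet::new(tx));

let handle = set.insert(project::ProjectEntryId::from_proto(1));
assert_eq!(set.len(), 1);

drop(handle); // removes the entry and notifies observers again
assert_eq!(set.len(), 0);
assert_eq!(rx.len(), 2); // one () notification per state change
```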
@@ -0,0 +1,523 @@
+use crate::{
+ embedding::{EmbeddingProvider, TextToEmbed},
+ summary_index::FileSummary,
+ worktree_index::{WorktreeIndex, WorktreeIndexHandle},
+};
+use anyhow::{anyhow, Context, Result};
+use collections::HashMap;
+use fs::Fs;
+use futures::{stream::StreamExt, FutureExt};
+use gpui::{
+ AppContext, Entity, EntityId, EventEmitter, Model, ModelContext, Subscription, Task, WeakModel,
+};
+use language::LanguageRegistry;
+use log;
+use project::{Project, Worktree, WorktreeId};
+use serde::{Deserialize, Serialize};
+use smol::channel;
+use std::{cmp::Ordering, future::Future, num::NonZeroUsize, ops::Range, path::Path, sync::Arc};
+use util::ResultExt;
+
+#[derive(Debug)]
+pub struct SearchResult {
+ pub worktree: Model<Worktree>,
+ pub path: Arc<Path>,
+ pub range: Range<usize>,
+ pub score: f32,
+}
+
+pub struct WorktreeSearchResult {
+ pub worktree_id: WorktreeId,
+ pub path: Arc<Path>,
+ pub range: Range<usize>,
+ pub score: f32,
+}
+
+#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
+pub enum Status {
+ Idle,
+ Loading,
+ Scanning { remaining_count: NonZeroUsize },
+}
+
+pub struct ProjectIndex {
+ db_connection: heed::Env,
+ project: WeakModel<Project>,
+ worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
+ language_registry: Arc<LanguageRegistry>,
+ fs: Arc<dyn Fs>,
+ last_status: Status,
+ status_tx: channel::Sender<()>,
+ embedding_provider: Arc<dyn EmbeddingProvider>,
+ _maintain_status: Task<()>,
+ _subscription: Subscription,
+}
+
+impl ProjectIndex {
+ pub fn new(
+ project: Model<Project>,
+ db_connection: heed::Env,
+ embedding_provider: Arc<dyn EmbeddingProvider>,
+ cx: &mut ModelContext<Self>,
+ ) -> Self {
+ let language_registry = project.read(cx).languages().clone();
+ let fs = project.read(cx).fs().clone();
+ let (status_tx, mut status_rx) = channel::unbounded();
+ let mut this = ProjectIndex {
+ db_connection,
+ project: project.downgrade(),
+ worktree_indices: HashMap::default(),
+ language_registry,
+ fs,
+ status_tx,
+ last_status: Status::Idle,
+ embedding_provider,
+ _subscription: cx.subscribe(&project, Self::handle_project_event),
+ _maintain_status: cx.spawn(|this, mut cx| async move {
+ while status_rx.next().await.is_some() {
+ if this
+ .update(&mut cx, |this, cx| this.update_status(cx))
+ .is_err()
+ {
+ break;
+ }
+ }
+ }),
+ };
+ this.update_worktree_indices(cx);
+ this
+ }
+
+ pub fn status(&self) -> Status {
+ self.last_status
+ }
+
+ pub fn project(&self) -> WeakModel<Project> {
+ self.project.clone()
+ }
+
+ pub fn fs(&self) -> Arc<dyn Fs> {
+ self.fs.clone()
+ }
+
+ fn handle_project_event(
+ &mut self,
+ _: Model<Project>,
+ event: &project::Event,
+ cx: &mut ModelContext<Self>,
+ ) {
+ match event {
+ project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
+ self.update_worktree_indices(cx);
+ }
+ _ => {}
+ }
+ }
+
+ fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
+ let Some(project) = self.project.upgrade() else {
+ return;
+ };
+
+ let worktrees = project
+ .read(cx)
+ .visible_worktrees(cx)
+ .filter_map(|worktree| {
+ if worktree.read(cx).is_local() {
+ Some((worktree.entity_id(), worktree))
+ } else {
+ None
+ }
+ })
+ .collect::<HashMap<_, _>>();
+
+ self.worktree_indices
+ .retain(|worktree_id, _| worktrees.contains_key(worktree_id));
+ for (worktree_id, worktree) in worktrees {
+ self.worktree_indices.entry(worktree_id).or_insert_with(|| {
+ let worktree_index = WorktreeIndex::load(
+ worktree.clone(),
+ self.db_connection.clone(),
+ self.language_registry.clone(),
+ self.fs.clone(),
+ self.status_tx.clone(),
+ self.embedding_provider.clone(),
+ cx,
+ );
+
+ let load_worktree = cx.spawn(|this, mut cx| async move {
+ let result = match worktree_index.await {
+ Ok(worktree_index) => {
+ this.update(&mut cx, |this, _| {
+ this.worktree_indices.insert(
+ worktree_id,
+ WorktreeIndexHandle::Loaded {
+ index: worktree_index.clone(),
+ },
+ );
+ })?;
+ Ok(worktree_index)
+ }
+ Err(error) => {
+ this.update(&mut cx, |this, _cx| {
+ this.worktree_indices.remove(&worktree_id)
+ })?;
+ Err(Arc::new(error))
+ }
+ };
+
+ this.update(&mut cx, |this, cx| this.update_status(cx))?;
+
+ result
+ });
+
+ WorktreeIndexHandle::Loading {
+ index: load_worktree.shared(),
+ }
+ });
+ }
+
+ self.update_status(cx);
+ }
+
+ fn update_status(&mut self, cx: &mut ModelContext<Self>) {
+ let mut indexing_count = 0;
+ let mut any_loading = false;
+
+ for index in self.worktree_indices.values_mut() {
+ match index {
+ WorktreeIndexHandle::Loading { .. } => {
+ any_loading = true;
+ break;
+ }
+ WorktreeIndexHandle::Loaded { index, .. } => {
+ indexing_count += index.read(cx).entry_ids_being_indexed().len();
+ }
+ }
+ }
+
+ let status = if any_loading {
+ Status::Loading
+ } else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) {
+ Status::Scanning { remaining_count }
+ } else {
+ Status::Idle
+ };
+
+ if status != self.last_status {
+ self.last_status = status;
+ cx.emit(status);
+ }
+ }
+
+ pub fn search(
+ &self,
+ query: String,
+ limit: usize,
+ cx: &AppContext,
+ ) -> Task<Result<Vec<SearchResult>>> {
+ let (chunks_tx, chunks_rx) = channel::bounded(1024);
+ let mut worktree_scan_tasks = Vec::new();
+ for worktree_index in self.worktree_indices.values() {
+ let worktree_index = worktree_index.clone();
+ let chunks_tx = chunks_tx.clone();
+ worktree_scan_tasks.push(cx.spawn(|cx| async move {
+ let index = match worktree_index {
+ WorktreeIndexHandle::Loading { index } => {
+ index.clone().await.map_err(|error| anyhow!(error))?
+ }
+ WorktreeIndexHandle::Loaded { index } => index.clone(),
+ };
+
+ index
+ .read_with(&cx, |index, cx| {
+ let worktree_id = index.worktree().read(cx).id();
+ let db_connection = index.db_connection().clone();
+ let db = *index.embedding_index().db();
+ cx.background_executor().spawn(async move {
+ let txn = db_connection
+ .read_txn()
+ .context("failed to create read transaction")?;
+ let db_entries = db.iter(&txn).context("failed to iterate database")?;
+ for db_entry in db_entries {
+ let (_key, db_embedded_file) = db_entry?;
+ for chunk in db_embedded_file.chunks {
+ chunks_tx
+ .send((worktree_id, db_embedded_file.path.clone(), chunk))
+ .await?;
+ }
+ }
+ anyhow::Ok(())
+ })
+ })?
+ .await
+ }));
+ }
+ drop(chunks_tx);
+
+ let project = self.project.clone();
+ let embedding_provider = self.embedding_provider.clone();
+ cx.spawn(|cx| async move {
+ #[cfg(debug_assertions)]
+ let embedding_query_start = std::time::Instant::now();
+ log::info!("Searching for {query}");
+
+ let query_embeddings = embedding_provider
+ .embed(&[TextToEmbed::new(&query)])
+ .await?;
+ let query_embedding = query_embeddings
+ .into_iter()
+ .next()
+ .ok_or_else(|| anyhow!("no embedding for query"))?;
+
+ let mut results_by_worker = Vec::new();
+ for _ in 0..cx.background_executor().num_cpus() {
+ results_by_worker.push(Vec::<WorktreeSearchResult>::new());
+ }
+
+ #[cfg(debug_assertions)]
+ let search_start = std::time::Instant::now();
+
+ cx.background_executor()
+ .scoped(|cx| {
+ for results in results_by_worker.iter_mut() {
+ cx.spawn(async {
+ while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
+ let score = chunk.embedding.similarity(&query_embedding);
+ let ix = match results.binary_search_by(|probe| {
+ score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
+ }) {
+ Ok(ix) | Err(ix) => ix,
+ };
+ results.insert(
+ ix,
+ WorktreeSearchResult {
+ worktree_id,
+ path: path.clone(),
+ range: chunk.chunk.range.clone(),
+ score,
+ },
+ );
+ results.truncate(limit);
+ }
+ });
+ }
+ })
+ .await;
+
+ for scan_task in futures::future::join_all(worktree_scan_tasks).await {
+ scan_task.log_err();
+ }
+
+ project.read_with(&cx, |project, cx| {
+ let mut search_results = Vec::with_capacity(results_by_worker.len() * limit);
+ for worker_results in results_by_worker {
+ search_results.extend(worker_results.into_iter().filter_map(|result| {
+ Some(SearchResult {
+ worktree: project.worktree_for_id(result.worktree_id, cx)?,
+ path: result.path,
+ range: result.range,
+ score: result.score,
+ })
+ }));
+ }
+ search_results.sort_unstable_by(|a, b| {
+ b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
+ });
+ search_results.truncate(limit);
+
+ #[cfg(debug_assertions)]
+ {
+ let search_elapsed = search_start.elapsed();
+ log::debug!(
+ "searched {} entries in {:?}",
+ search_results.len(),
+ search_elapsed
+ );
+ let embedding_query_elapsed = embedding_query_start.elapsed();
+ log::debug!("embedding query took {:?}", embedding_query_elapsed);
+ }
+
+ search_results
+ })
+ })
+ }
+
+ #[cfg(test)]
+ pub fn path_count(&self, cx: &AppContext) -> Result<u64> {
+ let mut result = 0;
+ for worktree_index in self.worktree_indices.values() {
+ if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
+ result += index.read(cx).path_count()?;
+ }
+ }
+ Ok(result)
+ }
+
+ pub(crate) fn worktree_index(
+ &self,
+ worktree_id: WorktreeId,
+ cx: &AppContext,
+ ) -> Option<Model<WorktreeIndex>> {
+ for index in self.worktree_indices.values() {
+ if let WorktreeIndexHandle::Loaded { index, .. } = index {
+ if index.read(cx).worktree().read(cx).id() == worktree_id {
+ return Some(index.clone());
+ }
+ }
+ }
+ None
+ }
+
+ pub(crate) fn worktree_indices(&self, cx: &AppContext) -> Vec<Model<WorktreeIndex>> {
+ let mut result = self
+ .worktree_indices
+ .values()
+ .filter_map(|index| {
+ if let WorktreeIndexHandle::Loaded { index, .. } = index {
+ Some(index.clone())
+ } else {
+ None
+ }
+ })
+ .collect::<Vec<_>>();
+ result.sort_by_key(|index| index.read(cx).worktree().read(cx).id());
+ result
+ }
+
+ pub fn all_summaries(&self, cx: &AppContext) -> Task<Result<Vec<FileSummary>>> {
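+ // Same fan-out/fan-in shape as `search`: each worktree streams its cached
+ // summaries into a bounded channel, and one collector per CPU drains it.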
+ let (summaries_tx, summaries_rx) = channel::bounded(1024);
+ let mut worktree_scan_tasks = Vec::new();
+ for worktree_index in self.worktree_indices.values() {
+ let worktree_index = worktree_index.clone();
+ let summaries_tx: channel::Sender<(String, String)> = summaries_tx.clone();
+ worktree_scan_tasks.push(cx.spawn(|cx| async move {
+ let index = match worktree_index {
+ WorktreeIndexHandle::Loading { index } => {
+ index.clone().await.map_err(|error| anyhow!(error))?
+ }
+ WorktreeIndexHandle::Loaded { index } => index.clone(),
+ };
+
+ index
+ .read_with(&cx, |index, cx| {
+ let db_connection = index.db_connection().clone();
+ let summary_index = index.summary_index();
+ let file_digest_db = summary_index.file_digest_db();
+ let summary_db = summary_index.summary_db();
+
+ cx.background_executor().spawn(async move {
+ let txn = db_connection
+ .read_txn()
+ .context("failed to create db read transaction")?;
+ let db_entries = file_digest_db
+ .iter(&txn)
+ .context("failed to iterate database")?;
+ for db_entry in db_entries {
+ let (file_path, db_file) = db_entry?;
+
+ match summary_db.get(&txn, &db_file.digest) {
+ Ok(opt_summary) => {
+ // Currently, we only use summaries we already have. If the file hasn't been
+ // summarized yet, then we skip it and don't include it in the inferred context.
+ // If we want to do just-in-time summarization, this would be the place to do it!
+ if let Some(summary) = opt_summary {
+ summaries_tx
+ .send((file_path.to_string(), summary.to_string()))
+ .await?;
+ } else {
+ log::warn!("No summary found for {:?}", &db_file);
+ }
+ }
+ Err(err) => {
+ log::error!(
+ "Error reading from summary database: {:?}",
+ err
+ );
+ }
+ }
+ }
+ anyhow::Ok(())
+ })
+ })?
+ .await
+ }));
+ }
+ drop(summaries_tx);
+
+ let project = self.project.clone();
+ cx.spawn(|cx| async move {
+ let mut results_by_worker = Vec::new();
+ for _ in 0..cx.background_executor().num_cpus() {
+ results_by_worker.push(Vec::<FileSummary>::new());
+ }
+
+ cx.background_executor()
+ .scoped(|cx| {
+ for results in results_by_worker.iter_mut() {
+ cx.spawn(async {
+ while let Ok((filename, summary)) = summaries_rx.recv().await {
+ results.push(FileSummary { filename, summary });
+ }
+ });
+ }
+ })
+ .await;
+
+ for scan_task in futures::future::join_all(worktree_scan_tasks).await {
+ scan_task.log_err();
+ }
+
+ project.read_with(&cx, |_project, _cx| {
+ results_by_worker.into_iter().flatten().collect()
+ })
+ })
+ }
+
+ /// Empty out the backlogs of all the worktrees in the project
+ pub fn flush_summary_backlogs(&self, cx: &AppContext) -> impl Future<Output = ()> {
+ let flush_start = std::time::Instant::now();
+
+ futures::future::join_all(self.worktree_indices.values().map(|worktree_index| {
+ let worktree_index = worktree_index.clone();
+
+ cx.spawn(|cx| async move {
+ let index = match worktree_index {
+ WorktreeIndexHandle::Loading { index } => {
+ index.clone().await.map_err(|error| anyhow!(error))?
+ }
+ WorktreeIndexHandle::Loaded { index } => index.clone(),
+ };
+ let worktree_abs_path =
+ cx.update(|cx| index.read(cx).worktree().read(cx).abs_path())?;
+
+ index
+ .read_with(&cx, |index, cx| {
+ cx.background_executor()
+ .spawn(index.summary_index().flush_backlog(worktree_abs_path, cx))
+ })?
+ .await
+ })
+ }))
+ .map(move |results| {
+ // Log any errors, but don't block the user. These summaries are supposed to
+ // improve quality by providing extra context, but they aren't hard requirements!
+ for result in results {
+ if let Err(err) = result {
+ log::error!("Error flushing summary backlog: {:?}", err);
+ }
+ }
+
+ log::info!("Summary backlog flushed in {:?}", flush_start.elapsed());
+ })
+ }
+
+ pub fn remaining_summaries(&self, cx: &mut ModelContext<Self>) -> usize {
+ self.worktree_indices(cx)
+ .iter()
+ .map(|index| index.read(cx).summary_index().backlog_len())
+ .sum()
+ }
+}
+
+impl EventEmitter<Status> for ProjectIndex {}
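+
+// A hypothetical consumer (not part of this change) could observe indexing
+// progress by subscribing to these Status events, along the lines of:
+//
+//     cx.subscribe(&project_index, |_this, _index, status: &Status, _cx| {
+//         log::info!("semantic index status: {status:?}");
+//     })
+//     .detach();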
@@ -55,8 +55,12 @@ impl ProjectIndexDebugView {
for index in worktree_indices {
let (root_path, worktree_id, worktree_paths) =
index.read_with(&cx, |index, cx| {
- let worktree = index.worktree.read(cx);
- (worktree.abs_path(), worktree.id(), index.paths(cx))
+ let worktree = index.worktree().read(cx);
+ (
+ worktree.abs_path(),
+ worktree.id(),
+ index.embedding_index().paths(cx),
+ )
})?;
rows.push(Row::Worktree(root_path));
rows.extend(
@@ -82,10 +86,12 @@ impl ProjectIndexDebugView {
cx: &mut ViewContext<Self>,
) -> Option<()> {
let project_index = self.index.read(cx);
- let fs = project_index.fs.clone();
+ let fs = project_index.fs().clone();
let worktree_index = project_index.worktree_index(worktree_id, cx)?.read(cx);
- let root_path = worktree_index.worktree.read(cx).abs_path();
- let chunks = worktree_index.chunks_for_path(file_path.clone(), cx);
+ let root_path = worktree_index.worktree().read(cx).abs_path();
+ let chunks = worktree_index
+ .embedding_index()
+ .chunks_for_path(file_path.clone(), cx);
cx.spawn(|this, mut cx| async move {
let chunks = chunks.await?;
@@ -1,48 +1,35 @@
mod chunking;
mod embedding;
+mod embedding_index;
+mod indexing;
+mod project_index;
mod project_index_debug_view;
+mod summary_backlog;
+mod summary_index;
+mod worktree_index;
+
+use anyhow::{Context as _, Result};
+use collections::HashMap;
+use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Context, Global, Model, WeakModel};
+use project::Project;
+use project_index::ProjectIndex;
+use std::{path::PathBuf, sync::Arc};
+use ui::ViewContext;
+use workspace::Workspace;
-use anyhow::{anyhow, Context as _, Result};
-use chunking::{chunk_text, Chunk};
-use collections::{Bound, HashMap, HashSet};
pub use embedding::*;
-use fs::Fs;
-use futures::{future::Shared, stream::StreamExt, FutureExt};
-use futures_batch::ChunksTimeoutStreamExt;
-use gpui::{
- AppContext, AsyncAppContext, BorrowAppContext, Context, Entity, EntityId, EventEmitter, Global,
- Model, ModelContext, Subscription, Task, WeakModel,
-};
-use heed::types::{SerdeBincode, Str};
-use language::LanguageRegistry;
-use parking_lot::Mutex;
-use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree, WorktreeId};
-use serde::{Deserialize, Serialize};
-use smol::channel;
-use std::{
- cmp::Ordering,
- future::Future,
- iter,
- num::NonZeroUsize,
- ops::Range,
- path::{Path, PathBuf},
- sync::{Arc, Weak},
- time::{Duration, SystemTime},
-};
-use util::ResultExt;
-use worktree::Snapshot;
-
pub use project_index_debug_view::ProjectIndexDebugView;
+pub use summary_index::FileSummary;
-pub struct SemanticIndex {
+pub struct SemanticDb {
embedding_provider: Arc<dyn EmbeddingProvider>,
db_connection: heed::Env,
project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}
-impl Global for SemanticIndex {}
+impl Global for SemanticDb {}
-impl SemanticIndex {
+impl SemanticDb {
pub async fn new(
db_path: PathBuf,
embedding_provider: Arc<dyn EmbeddingProvider>,
@@ -62,7 +49,45 @@ impl SemanticIndex {
.await
.context("opening database connection")?;
- Ok(SemanticIndex {
+ cx.update(|cx| {
+ cx.observe_new_views(
+ |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
+ let project = workspace.project().clone();
+
+ if cx.has_global::<SemanticDb>() {
+ cx.update_global::<SemanticDb, _>(|this, cx| {
+ let project_index = cx.new_model(|cx| {
+ ProjectIndex::new(
+ project.clone(),
+ this.db_connection.clone(),
+ this.embedding_provider.clone(),
+ cx,
+ )
+ });
+
+ let project_weak = project.downgrade();
+ this.project_indices
+ .insert(project_weak.clone(), project_index);
+
+ cx.on_release(move |_, _, cx| {
+ if cx.has_global::<SemanticDb>() {
+ cx.update_global::<SemanticDb, _>(|this, _| {
+ this.project_indices.remove(&project_weak);
+ })
+ }
+ })
+ .detach();
+ })
+ } else {
+ log::info!("No SemanticDb, skipping project index")
+ }
+ },
+ )
+ .detach();
+ })
+ .ok();
+
+ Ok(SemanticDb {
db_connection,
embedding_provider,
project_indices: HashMap::default(),
@@ -72,985 +97,50 @@ impl SemanticIndex {
pub fn project_index(
&mut self,
project: Model<Project>,
- cx: &mut AppContext,
- ) -> Model<ProjectIndex> {
- let project_weak = project.downgrade();
- project.update(cx, move |_, cx| {
- cx.on_release(move |_, cx| {
- if cx.has_global::<SemanticIndex>() {
- cx.update_global::<SemanticIndex, _>(|this, _| {
- this.project_indices.remove(&project_weak);
- })
- }
- })
- .detach();
- });
-
- self.project_indices
- .entry(project.downgrade())
- .or_insert_with(|| {
- cx.new_model(|cx| {
- ProjectIndex::new(
- project,
- self.db_connection.clone(),
- self.embedding_provider.clone(),
- cx,
- )
- })
- })
- .clone()
- }
-}
-
-pub struct ProjectIndex {
- db_connection: heed::Env,
- project: WeakModel<Project>,
- worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
- language_registry: Arc<LanguageRegistry>,
- fs: Arc<dyn Fs>,
- last_status: Status,
- status_tx: channel::Sender<()>,
- embedding_provider: Arc<dyn EmbeddingProvider>,
- _maintain_status: Task<()>,
- _subscription: Subscription,
-}
-
-#[derive(Clone)]
-enum WorktreeIndexHandle {
- Loading {
- index: Shared<Task<Result<Model<WorktreeIndex>, Arc<anyhow::Error>>>>,
- },
- Loaded {
- index: Model<WorktreeIndex>,
- },
-}
-
-impl ProjectIndex {
- fn new(
- project: Model<Project>,
- db_connection: heed::Env,
- embedding_provider: Arc<dyn EmbeddingProvider>,
- cx: &mut ModelContext<Self>,
- ) -> Self {
- let language_registry = project.read(cx).languages().clone();
- let fs = project.read(cx).fs().clone();
- let (status_tx, mut status_rx) = channel::unbounded();
- let mut this = ProjectIndex {
- db_connection,
- project: project.downgrade(),
- worktree_indices: HashMap::default(),
- language_registry,
- fs,
- status_tx,
- last_status: Status::Idle,
- embedding_provider,
- _subscription: cx.subscribe(&project, Self::handle_project_event),
- _maintain_status: cx.spawn(|this, mut cx| async move {
- while status_rx.next().await.is_some() {
- if this
- .update(&mut cx, |this, cx| this.update_status(cx))
- .is_err()
- {
- break;
- }
- }
- }),
- };
- this.update_worktree_indices(cx);
- this
- }
-
- pub fn status(&self) -> Status {
- self.last_status
- }
-
- pub fn project(&self) -> WeakModel<Project> {
- self.project.clone()
- }
-
- pub fn fs(&self) -> Arc<dyn Fs> {
- self.fs.clone()
- }
-
- fn handle_project_event(
- &mut self,
- _: Model<Project>,
- event: &project::Event,
- cx: &mut ModelContext<Self>,
- ) {
- match event {
- project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
- self.update_worktree_indices(cx);
- }
- _ => {}
- }
- }
-
- fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
- let Some(project) = self.project.upgrade() else {
- return;
- };
-
- let worktrees = project
- .read(cx)
- .visible_worktrees(cx)
- .filter_map(|worktree| {
- if worktree.read(cx).is_local() {
- Some((worktree.entity_id(), worktree))
- } else {
- None
- }
- })
- .collect::<HashMap<_, _>>();
-
- self.worktree_indices
- .retain(|worktree_id, _| worktrees.contains_key(worktree_id));
- for (worktree_id, worktree) in worktrees {
- self.worktree_indices.entry(worktree_id).or_insert_with(|| {
- let worktree_index = WorktreeIndex::load(
- worktree.clone(),
- self.db_connection.clone(),
- self.language_registry.clone(),
- self.fs.clone(),
- self.status_tx.clone(),
- self.embedding_provider.clone(),
- cx,
- );
-
- let load_worktree = cx.spawn(|this, mut cx| async move {
- let result = match worktree_index.await {
- Ok(worktree_index) => {
- this.update(&mut cx, |this, _| {
- this.worktree_indices.insert(
- worktree_id,
- WorktreeIndexHandle::Loaded {
- index: worktree_index.clone(),
- },
- );
- })?;
- Ok(worktree_index)
- }
- Err(error) => {
- this.update(&mut cx, |this, _cx| {
- this.worktree_indices.remove(&worktree_id)
- })?;
- Err(Arc::new(error))
- }
- };
-
- this.update(&mut cx, |this, cx| this.update_status(cx))?;
-
- result
- });
-
- WorktreeIndexHandle::Loading {
- index: load_worktree.shared(),
- }
- });
- }
-
- self.update_status(cx);
+ _cx: &mut AppContext,
+ ) -> Option<Model<ProjectIndex>> {
+ self.project_indices.get(&project.downgrade()).cloned()
}
- fn update_status(&mut self, cx: &mut ModelContext<Self>) {
- let mut indexing_count = 0;
- let mut any_loading = false;
-
- for index in self.worktree_indices.values_mut() {
- match index {
- WorktreeIndexHandle::Loading { .. } => {
- any_loading = true;
- break;
- }
- WorktreeIndexHandle::Loaded { index, .. } => {
- indexing_count += index.read(cx).entry_ids_being_indexed.len();
- }
- }
- }
-
- let status = if any_loading {
- Status::Loading
- } else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) {
- Status::Scanning { remaining_count }
- } else {
- Status::Idle
- };
-
- if status != self.last_status {
- self.last_status = status;
- cx.emit(status);
- }
- }
-
- pub fn search(
- &self,
- query: String,
- limit: usize,
- cx: &AppContext,
- ) -> Task<Result<Vec<SearchResult>>> {
- let (chunks_tx, chunks_rx) = channel::bounded(1024);
- let mut worktree_scan_tasks = Vec::new();
- for worktree_index in self.worktree_indices.values() {
- let worktree_index = worktree_index.clone();
- let chunks_tx = chunks_tx.clone();
- worktree_scan_tasks.push(cx.spawn(|cx| async move {
- let index = match worktree_index {
- WorktreeIndexHandle::Loading { index } => {
- index.clone().await.map_err(|error| anyhow!(error))?
- }
- WorktreeIndexHandle::Loaded { index } => index.clone(),
- };
-
- index
- .read_with(&cx, |index, cx| {
- let worktree_id = index.worktree.read(cx).id();
- let db_connection = index.db_connection.clone();
- let db = index.db;
- cx.background_executor().spawn(async move {
- let txn = db_connection
- .read_txn()
- .context("failed to create read transaction")?;
- let db_entries = db.iter(&txn).context("failed to iterate database")?;
- for db_entry in db_entries {
- let (_key, db_embedded_file) = db_entry?;
- for chunk in db_embedded_file.chunks {
- chunks_tx
- .send((worktree_id, db_embedded_file.path.clone(), chunk))
- .await?;
- }
- }
- anyhow::Ok(())
- })
- })?
- .await
- }));
- }
- drop(chunks_tx);
-
- let project = self.project.clone();
- let embedding_provider = self.embedding_provider.clone();
- cx.spawn(|cx| async move {
- #[cfg(debug_assertions)]
- let embedding_query_start = std::time::Instant::now();
- log::info!("Searching for {query}");
-
- let query_embeddings = embedding_provider
- .embed(&[TextToEmbed::new(&query)])
- .await?;
- let query_embedding = query_embeddings
- .into_iter()
- .next()
- .ok_or_else(|| anyhow!("no embedding for query"))?;
-
- let mut results_by_worker = Vec::new();
- for _ in 0..cx.background_executor().num_cpus() {
- results_by_worker.push(Vec::<WorktreeSearchResult>::new());
- }
-
- #[cfg(debug_assertions)]
- let search_start = std::time::Instant::now();
-
- cx.background_executor()
- .scoped(|cx| {
- for results in results_by_worker.iter_mut() {
- cx.spawn(async {
- while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
- let score = chunk.embedding.similarity(&query_embedding);
- let ix = match results.binary_search_by(|probe| {
- score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
- }) {
- Ok(ix) | Err(ix) => ix,
- };
- results.insert(
- ix,
- WorktreeSearchResult {
- worktree_id,
- path: path.clone(),
- range: chunk.chunk.range.clone(),
- score,
- },
- );
- results.truncate(limit);
- }
- });
- }
- })
- .await;
-
- for scan_task in futures::future::join_all(worktree_scan_tasks).await {
- scan_task.log_err();
- }
-
- project.read_with(&cx, |project, cx| {
- let mut search_results = Vec::with_capacity(results_by_worker.len() * limit);
- for worker_results in results_by_worker {
- search_results.extend(worker_results.into_iter().filter_map(|result| {
- Some(SearchResult {
- worktree: project.worktree_for_id(result.worktree_id, cx)?,
- path: result.path,
- range: result.range,
- score: result.score,
- })
- }));
- }
- search_results.sort_unstable_by(|a, b| {
- b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
- });
- search_results.truncate(limit);
-
- #[cfg(debug_assertions)]
- {
- let search_elapsed = search_start.elapsed();
- log::debug!(
- "searched {} entries in {:?}",
- search_results.len(),
- search_elapsed
- );
- let embedding_query_elapsed = embedding_query_start.elapsed();
- log::debug!("embedding query took {:?}", embedding_query_elapsed);
- }
-
- search_results
- })
- })
- }
-
- #[cfg(test)]
- pub fn path_count(&self, cx: &AppContext) -> Result<u64> {
- let mut result = 0;
- for worktree_index in self.worktree_indices.values() {
- if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
- result += index.read(cx).path_count()?;
- }
- }
- Ok(result)
- }
-
- pub(crate) fn worktree_index(
+ pub fn remaining_summaries(
&self,
- worktree_id: WorktreeId,
- cx: &AppContext,
- ) -> Option<Model<WorktreeIndex>> {
- for index in self.worktree_indices.values() {
- if let WorktreeIndexHandle::Loaded { index, .. } = index {
- if index.read(cx).worktree.read(cx).id() == worktree_id {
- return Some(index.clone());
- }
- }
- }
- None
- }
-
- pub(crate) fn worktree_indices(&self, cx: &AppContext) -> Vec<Model<WorktreeIndex>> {
- let mut result = self
- .worktree_indices
- .values()
- .filter_map(|index| {
- if let WorktreeIndexHandle::Loaded { index, .. } = index {
- Some(index.clone())
- } else {
- None
- }
- })
- .collect::<Vec<_>>();
- result.sort_by_key(|index| index.read(cx).worktree.read(cx).id());
- result
- }
-}
-
-pub struct SearchResult {
- pub worktree: Model<Worktree>,
- pub path: Arc<Path>,
- pub range: Range<usize>,
- pub score: f32,
-}
-
-pub struct WorktreeSearchResult {
- pub worktree_id: WorktreeId,
- pub path: Arc<Path>,
- pub range: Range<usize>,
- pub score: f32,
-}
-
-#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
-pub enum Status {
- Idle,
- Loading,
- Scanning { remaining_count: NonZeroUsize },
-}
-
-impl EventEmitter<Status> for ProjectIndex {}
-
-struct WorktreeIndex {
- worktree: Model<Worktree>,
- db_connection: heed::Env,
- db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
- language_registry: Arc<LanguageRegistry>,
- fs: Arc<dyn Fs>,
- embedding_provider: Arc<dyn EmbeddingProvider>,
- entry_ids_being_indexed: Arc<IndexingEntrySet>,
- _index_entries: Task<Result<()>>,
- _subscription: Subscription,
-}
-
-impl WorktreeIndex {
- pub fn load(
- worktree: Model<Worktree>,
- db_connection: heed::Env,
- language_registry: Arc<LanguageRegistry>,
- fs: Arc<dyn Fs>,
- status_tx: channel::Sender<()>,
- embedding_provider: Arc<dyn EmbeddingProvider>,
+ project: &WeakModel<Project>,
cx: &mut AppContext,
- ) -> Task<Result<Model<Self>>> {
- let worktree_abs_path = worktree.read(cx).abs_path();
- cx.spawn(|mut cx| async move {
- let db = cx
- .background_executor()
- .spawn({
- let db_connection = db_connection.clone();
- async move {
- let mut txn = db_connection.write_txn()?;
- let db_name = worktree_abs_path.to_string_lossy();
- let db = db_connection.create_database(&mut txn, Some(&db_name))?;
- txn.commit()?;
- anyhow::Ok(db)
- }
- })
- .await?;
- cx.new_model(|cx| {
- Self::new(
- worktree,
- db_connection,
- db,
- status_tx,
- language_registry,
- fs,
- embedding_provider,
- cx,
- )
+ ) -> Option<usize> {
+ self.project_indices.get(project).map(|project_index| {
+ project_index.update(cx, |project_index, cx| {
+ project_index.remaining_summaries(cx)
})
})
}
-
- #[allow(clippy::too_many_arguments)]
- fn new(
- worktree: Model<Worktree>,
- db_connection: heed::Env,
- db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
- status: channel::Sender<()>,
- language_registry: Arc<LanguageRegistry>,
- fs: Arc<dyn Fs>,
- embedding_provider: Arc<dyn EmbeddingProvider>,
- cx: &mut ModelContext<Self>,
- ) -> Self {
- let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
- let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
- if let worktree::Event::UpdatedEntries(update) = event {
- _ = updated_entries_tx.try_send(update.clone());
- }
- });
-
- Self {
- db_connection,
- db,
- worktree,
- language_registry,
- fs,
- embedding_provider,
- entry_ids_being_indexed: Arc::new(IndexingEntrySet::new(status)),
- _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
- _subscription,
- }
- }
-
- async fn index_entries(
- this: WeakModel<Self>,
- updated_entries: channel::Receiver<UpdatedEntriesSet>,
- mut cx: AsyncAppContext,
- ) -> Result<()> {
- let index = this.update(&mut cx, |this, cx| this.index_entries_changed_on_disk(cx))?;
- index.await.log_err();
-
- while let Ok(updated_entries) = updated_entries.recv().await {
- let index = this.update(&mut cx, |this, cx| {
- this.index_updated_entries(updated_entries, cx)
- })?;
- index.await.log_err();
- }
-
- Ok(())
- }
-
- fn index_entries_changed_on_disk(&self, cx: &AppContext) -> impl Future<Output = Result<()>> {
- let worktree = self.worktree.read(cx).snapshot();
- let worktree_abs_path = worktree.abs_path().clone();
- let scan = self.scan_entries(worktree, cx);
- let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
- let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
- let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
- async move {
- futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
- Ok(())
- }
- }
-
- fn index_updated_entries(
- &self,
- updated_entries: UpdatedEntriesSet,
- cx: &AppContext,
- ) -> impl Future<Output = Result<()>> {
- let worktree = self.worktree.read(cx).snapshot();
- let worktree_abs_path = worktree.abs_path().clone();
- let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
- let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
- let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
- let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
- async move {
- futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
- Ok(())
- }
- }
-
- fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> ScanEntries {
- let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
- let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
- let db_connection = self.db_connection.clone();
- let db = self.db;
- let entries_being_indexed = self.entry_ids_being_indexed.clone();
- let task = cx.background_executor().spawn(async move {
- let txn = db_connection
- .read_txn()
- .context("failed to create read transaction")?;
- let mut db_entries = db
- .iter(&txn)
- .context("failed to create iterator")?
- .move_between_keys()
- .peekable();
-
- let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
- for entry in worktree.files(false, 0) {
- let entry_db_key = db_key_for_path(&entry.path);
-
- let mut saved_mtime = None;
- while let Some(db_entry) = db_entries.peek() {
- match db_entry {
- Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
- Ordering::Less => {
- if let Some(deletion_range) = deletion_range.as_mut() {
- deletion_range.1 = Bound::Included(db_path);
- } else {
- deletion_range =
- Some((Bound::Included(db_path), Bound::Included(db_path)));
- }
-
- db_entries.next();
- }
- Ordering::Equal => {
- if let Some(deletion_range) = deletion_range.take() {
- deleted_entry_ranges_tx
- .send((
- deletion_range.0.map(ToString::to_string),
- deletion_range.1.map(ToString::to_string),
- ))
- .await?;
- }
- saved_mtime = db_embedded_file.mtime;
- db_entries.next();
- break;
- }
- Ordering::Greater => {
- break;
- }
- },
- Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
- }
- }
-
- if entry.mtime != saved_mtime {
- let handle = entries_being_indexed.insert(entry.id);
- updated_entries_tx.send((entry.clone(), handle)).await?;
- }
- }
-
- if let Some(db_entry) = db_entries.next() {
- let (db_path, _) = db_entry?;
- deleted_entry_ranges_tx
- .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
- .await?;
- }
-
- Ok(())
- });
-
- ScanEntries {
- updated_entries: updated_entries_rx,
- deleted_entry_ranges: deleted_entry_ranges_rx,
- task,
- }
- }
-
- fn scan_updated_entries(
- &self,
- worktree: Snapshot,
- updated_entries: UpdatedEntriesSet,
- cx: &AppContext,
- ) -> ScanEntries {
- let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
- let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
- let entries_being_indexed = self.entry_ids_being_indexed.clone();
- let task = cx.background_executor().spawn(async move {
- for (path, entry_id, status) in updated_entries.iter() {
- match status {
- project::PathChange::Added
- | project::PathChange::Updated
- | project::PathChange::AddedOrUpdated => {
- if let Some(entry) = worktree.entry_for_id(*entry_id) {
- if entry.is_file() {
- let handle = entries_being_indexed.insert(entry.id);
- updated_entries_tx.send((entry.clone(), handle)).await?;
- }
- }
- }
- project::PathChange::Removed => {
- let db_path = db_key_for_path(path);
- deleted_entry_ranges_tx
- .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
- .await?;
- }
- project::PathChange::Loaded => {
- // Do nothing.
- }
- }
- }
-
- Ok(())
- });
-
- ScanEntries {
- updated_entries: updated_entries_rx,
- deleted_entry_ranges: deleted_entry_ranges_rx,
- task,
- }
- }
-
- fn chunk_files(
- &self,
- worktree_abs_path: Arc<Path>,
- entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
- cx: &AppContext,
- ) -> ChunkFiles {
- let language_registry = self.language_registry.clone();
- let fs = self.fs.clone();
- let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
- let task = cx.spawn(|cx| async move {
- cx.background_executor()
- .scoped(|cx| {
- for _ in 0..cx.num_cpus() {
- cx.spawn(async {
- while let Ok((entry, handle)) = entries.recv().await {
- let entry_abs_path = worktree_abs_path.join(&entry.path);
- let Some(text) = fs
- .load(&entry_abs_path)
- .await
- .with_context(|| {
- format!("failed to read path {entry_abs_path:?}")
- })
- .log_err()
- else {
- continue;
- };
- let language = language_registry
- .language_for_file_path(&entry.path)
- .await
- .ok();
- let chunked_file = ChunkedFile {
- chunks: chunk_text(&text, language.as_ref(), &entry.path),
- handle,
- path: entry.path,
- mtime: entry.mtime,
- text,
- };
-
- if chunked_files_tx.send(chunked_file).await.is_err() {
- return;
- }
- }
- });
- }
- })
- .await;
- Ok(())
- });
-
- ChunkFiles {
- files: chunked_files_rx,
- task,
- }
- }
-
- fn embed_files(
- embedding_provider: Arc<dyn EmbeddingProvider>,
- chunked_files: channel::Receiver<ChunkedFile>,
- cx: &AppContext,
- ) -> EmbedFiles {
- let embedding_provider = embedding_provider.clone();
- let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
- let task = cx.background_executor().spawn(async move {
- let mut chunked_file_batches =
- chunked_files.chunks_timeout(512, Duration::from_secs(2));
- while let Some(chunked_files) = chunked_file_batches.next().await {
- // View the batch of files as a vec of chunks
- // Flatten out to a vec of chunks that we can subdivide into batch sized pieces
- // Once those are done, reassemble them back into the files in which they belong
- // If any embeddings fail for a file, the entire file is discarded
-
- let chunks: Vec<TextToEmbed> = chunked_files
- .iter()
- .flat_map(|file| {
- file.chunks.iter().map(|chunk| TextToEmbed {
- text: &file.text[chunk.range.clone()],
- digest: chunk.digest,
- })
- })
- .collect::<Vec<_>>();
-
- let mut embeddings: Vec<Option<Embedding>> = Vec::new();
- for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
- if let Some(batch_embeddings) =
- embedding_provider.embed(embedding_batch).await.log_err()
- {
- if batch_embeddings.len() == embedding_batch.len() {
- embeddings.extend(batch_embeddings.into_iter().map(Some));
- continue;
- }
- log::error!(
- "embedding provider returned unexpected embedding count {}, expected {}",
- batch_embeddings.len(), embedding_batch.len()
- );
- }
-
- embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
- }
-
- let mut embeddings = embeddings.into_iter();
- for chunked_file in chunked_files {
- let mut embedded_file = EmbeddedFile {
- path: chunked_file.path,
- mtime: chunked_file.mtime,
- chunks: Vec::new(),
- };
-
- let mut embedded_all_chunks = true;
- for (chunk, embedding) in
- chunked_file.chunks.into_iter().zip(embeddings.by_ref())
- {
- if let Some(embedding) = embedding {
- embedded_file
- .chunks
- .push(EmbeddedChunk { chunk, embedding });
- } else {
- embedded_all_chunks = false;
- }
- }
-
- if embedded_all_chunks {
- embedded_files_tx
- .send((embedded_file, chunked_file.handle))
- .await?;
- }
- }
- }
- Ok(())
- });
-
- EmbedFiles {
- files: embedded_files_rx,
- task,
- }
- }
-
- fn persist_embeddings(
- &self,
- mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
- embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
- cx: &AppContext,
- ) -> Task<Result<()>> {
- let db_connection = self.db_connection.clone();
- let db = self.db;
- cx.background_executor().spawn(async move {
- while let Some(deletion_range) = deleted_entry_ranges.next().await {
- let mut txn = db_connection.write_txn()?;
- let start = deletion_range.0.as_ref().map(|start| start.as_str());
- let end = deletion_range.1.as_ref().map(|end| end.as_str());
- log::debug!("deleting embeddings in range {:?}", &(start, end));
- db.delete_range(&mut txn, &(start, end))?;
- txn.commit()?;
- }
-
- let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
- while let Some(embedded_files) = embedded_files.next().await {
- let mut txn = db_connection.write_txn()?;
- for (file, _) in &embedded_files {
- log::debug!("saving embedding for file {:?}", file.path);
- let key = db_key_for_path(&file.path);
- db.put(&mut txn, &key, file)?;
- }
- txn.commit()?;
-
- drop(embedded_files);
- log::debug!("committed");
- }
-
- Ok(())
- })
- }
-
- fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
- let connection = self.db_connection.clone();
- let db = self.db;
- cx.background_executor().spawn(async move {
- let tx = connection
- .read_txn()
- .context("failed to create read transaction")?;
- let result = db
- .iter(&tx)?
- .map(|entry| Ok(entry?.1.path.clone()))
- .collect::<Result<Vec<Arc<Path>>>>();
- drop(tx);
- result
- })
- }
-
- fn chunks_for_path(
- &self,
- path: Arc<Path>,
- cx: &AppContext,
- ) -> Task<Result<Vec<EmbeddedChunk>>> {
- let connection = self.db_connection.clone();
- let db = self.db;
- cx.background_executor().spawn(async move {
- let tx = connection
- .read_txn()
- .context("failed to create read transaction")?;
- Ok(db
- .get(&tx, &db_key_for_path(&path))?
- .ok_or_else(|| anyhow!("no such path"))?
- .chunks
- .clone())
- })
- }
-
- #[cfg(test)]
- fn path_count(&self) -> Result<u64> {
- let txn = self
- .db_connection
- .read_txn()
- .context("failed to create read transaction")?;
- Ok(self.db.len(&txn)?)
- }
-}
-
-struct ScanEntries {
- updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
- deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
- task: Task<Result<()>>,
-}
-
-struct ChunkFiles {
- files: channel::Receiver<ChunkedFile>,
- task: Task<Result<()>>,
-}
-
-struct ChunkedFile {
- pub path: Arc<Path>,
- pub mtime: Option<SystemTime>,
- pub handle: IndexingEntryHandle,
- pub text: String,
- pub chunks: Vec<Chunk>,
-}
-
-struct EmbedFiles {
- files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
- task: Task<Result<()>>,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-struct EmbeddedFile {
- path: Arc<Path>,
- mtime: Option<SystemTime>,
- chunks: Vec<EmbeddedChunk>,
-}
-
-#[derive(Clone, Debug, Serialize, Deserialize)]
-struct EmbeddedChunk {
- chunk: Chunk,
- embedding: Embedding,
-}
-
-/// The set of entries that are currently being indexed.
-struct IndexingEntrySet {
- entry_ids: Mutex<HashSet<ProjectEntryId>>,
- tx: channel::Sender<()>,
-}
-
-/// When dropped, removes the entry from the set of entries that are being indexed.
-#[derive(Clone)]
-struct IndexingEntryHandle {
- entry_id: ProjectEntryId,
- set: Weak<IndexingEntrySet>,
-}
-
-impl IndexingEntrySet {
- fn new(tx: channel::Sender<()>) -> Self {
- Self {
- entry_ids: Default::default(),
- tx,
- }
- }
-
- fn insert(self: &Arc<Self>, entry_id: ProjectEntryId) -> IndexingEntryHandle {
- self.entry_ids.lock().insert(entry_id);
- self.tx.send_blocking(()).ok();
- IndexingEntryHandle {
- entry_id,
- set: Arc::downgrade(self),
- }
- }
-
- pub fn len(&self) -> usize {
- self.entry_ids.lock().len()
- }
-}
-
-impl Drop for IndexingEntryHandle {
- fn drop(&mut self) {
- if let Some(set) = self.set.upgrade() {
- set.tx.send_blocking(()).ok();
- set.entry_ids.lock().remove(&self.entry_id);
- }
- }
-}
-
-fn db_key_for_path(path: &Arc<Path>) -> String {
- path.to_string_lossy().replace('/', "\0")
}
#[cfg(test)]
mod tests {
use super::*;
+ use anyhow::anyhow;
+ use chunking::Chunk;
+ use embedding_index::{ChunkedFile, EmbeddingIndex};
+ use feature_flags::FeatureFlagAppExt;
+ use fs::FakeFs;
use futures::{future::BoxFuture, FutureExt};
use gpui::TestAppContext;
+ use indexing::IndexingEntrySet;
use language::language_settings::AllLanguageSettings;
- use project::Project;
+ use project::{Project, ProjectEntryId};
+ use serde_json::json;
use settings::SettingsStore;
+ use smol::{channel, stream::StreamExt};
use std::{future, path::Path, sync::Arc};
fn init_test(cx: &mut TestAppContext) {
+ env_logger::try_init().ok();
+
cx.update(|cx| {
let store = SettingsStore::test(cx);
cx.set_global(store);
language::init(cx);
+ cx.update_flags(false, vec![]);
Project::init_settings(cx);
SettingsStore::update(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
@@ -0,0 +1,48 @@
+use collections::HashMap;
+use std::{path::Path, sync::Arc, time::SystemTime};
+
+const MAX_FILES_BEFORE_RESUMMARIZE: usize = 4;
+const MAX_BYTES_BEFORE_RESUMMARIZE: u64 = 1_000_000; // 1 MB
+
+#[derive(Default, Debug)]
+pub struct SummaryBacklog {
+ /// Key: path to a file that needs summarization, but that we haven't summarized yet. Value: that file's size on disk, in bytes, and its mtime.
+ files: HashMap<Arc<Path>, (u64, Option<SystemTime>)>,
+ /// Cache of the sum of all values in `files`, so we don't have to traverse the whole map to check if we're over the byte limit.
+ total_bytes: u64,
+}
+
+impl SummaryBacklog {
+ /// Store the given path in the backlog, along with how many bytes are in it.
+ pub fn insert(&mut self, path: Arc<Path>, bytes_on_disk: u64, mtime: Option<SystemTime>) {
+ let (prev_bytes, _) = self
+ .files
+ .insert(path, (bytes_on_disk, mtime))
+ .unwrap_or_default(); // Default to 0 prev_bytes
+
+ // Update the cached total by subtracting out the old amount and adding the new one.
+ self.total_bytes = self.total_bytes - prev_bytes + bytes_on_disk;
+ }
+
+ /// Returns true if the total number of bytes in the backlog exceeds a predefined threshold.
+ pub fn needs_drain(&self) -> bool {
+ self.files.len() > MAX_FILES_BEFORE_RESUMMARIZE ||
+ // The whole purpose of the cached total_bytes is to make this comparison cheap.
+ // Otherwise we'd have to traverse the entire map every time we wanted this answer.
+ self.total_bytes > MAX_BYTES_BEFORE_RESUMMARIZE
+ }
+
+ /// Remove all the entries in the backlog and return the file paths as an iterator.
+ #[allow(clippy::needless_lifetimes)] // Clippy thinks this 'a can be elided, but eliding it gives a compile error
+ pub fn drain<'a>(&'a mut self) -> impl Iterator<Item = (Arc<Path>, Option<SystemTime>)> + 'a {
+ self.total_bytes = 0;
+
+ self.files
+ .drain()
+ .map(|(path, (_size, mtime))| (path, mtime))
+ }
+
+ pub fn len(&self) -> usize {
+ self.files.len()
+ }
+}
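+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::path::PathBuf;
+
+ // Illustrative sketch, not part of the original change: exceeding the file
+ // threshold makes the backlog report that it needs draining, and draining
+ // empties it.
+ #[test]
+ fn drains_once_file_threshold_is_exceeded() {
+ let mut backlog = SummaryBacklog::default();
+ for i in 0..=MAX_FILES_BEFORE_RESUMMARIZE {
+ let path: Arc<Path> = PathBuf::from(format!("file{i}.rs")).into();
+ backlog.insert(path, 10, None);
+ }
+ assert!(backlog.needs_drain());
+ assert_eq!(backlog.drain().count(), MAX_FILES_BEFORE_RESUMMARIZE + 1);
+ assert_eq!(backlog.len(), 0);
+ }
+}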
@@ -0,0 +1,693 @@
+use anyhow::{anyhow, Context as _, Result};
+use arrayvec::ArrayString;
+use fs::Fs;
+use futures::{stream::StreamExt, TryFutureExt};
+use futures_batch::ChunksTimeoutStreamExt;
+use gpui::{AppContext, Model, Task};
+use heed::{
+ types::{SerdeBincode, Str},
+ RoTxn,
+};
+use language_model::{
+ LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelRequest,
+ LanguageModelRequestMessage, Role,
+};
+use log;
+use parking_lot::Mutex;
+use project::{Entry, UpdatedEntriesSet, Worktree};
+use serde::{Deserialize, Serialize};
+use smol::channel;
+use std::{
+ future::Future,
+ path::Path,
+ sync::Arc,
+ time::{Duration, Instant, SystemTime},
+};
+use util::ResultExt;
+use worktree::Snapshot;
+
+use crate::{indexing::IndexingEntrySet, summary_backlog::SummaryBacklog};
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct FileSummary {
+ pub filename: String,
+ pub summary: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct UnsummarizedFile {
+ // Path to the file on disk
+ path: Arc<Path>,
+ // The mtime of the file on disk
+ mtime: Option<SystemTime>,
+ // BLAKE3 hash of the source file's contents
+ digest: Blake3Digest,
+ // The source file's contents
+ contents: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct SummarizedFile {
+ // Path to the file on disk
+ path: String,
+ // The mtime of the file on disk
+ mtime: Option<SystemTime>,
+ // BLAKE3 hash of the source file's contents
+ digest: Blake3Digest,
+ // The LLM's summary of the file's contents
+ summary: String,
+}
+
+/// This is what blake3's to_hex() method returns - see https://docs.rs/blake3/1.5.3/src/blake3/lib.rs.html#246
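+/// (`blake3::OUT_LEN` is 32 bytes, so the hex form is 64 ASCII characters, stored
+/// inline in an `ArrayString` with no per-digest heap allocation.)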
+pub type Blake3Digest = ArrayString<{ blake3::OUT_LEN * 2 }>;
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct FileDigest {
+ pub mtime: Option<SystemTime>,
+ pub digest: Blake3Digest,
+}
+
+struct NeedsSummary {
+ files: channel::Receiver<UnsummarizedFile>,
+ task: Task<Result<()>>,
+}
+
+struct SummarizeFiles {
+ files: channel::Receiver<SummarizedFile>,
+ task: Task<Result<()>>,
+}
+
+pub struct SummaryIndex {
+ worktree: Model<Worktree>,
+ fs: Arc<dyn Fs>,
+ db_connection: heed::Env,
+ file_digest_db: heed::Database<Str, SerdeBincode<FileDigest>>, // Key: file path. Val: BLAKE3 digest of its contents.
+ summary_db: heed::Database<SerdeBincode<Blake3Digest>, Str>, // Key: BLAKE3 digest of a file's contents. Val: LLM summary of those contents.
+ backlog: Arc<Mutex<SummaryBacklog>>,
+ _entry_ids_being_indexed: Arc<IndexingEntrySet>, // TODO can this be removed?
+}
+
+struct Backlogged {
+ paths_to_digest: channel::Receiver<Vec<(Arc<Path>, Option<SystemTime>)>>,
+ task: Task<Result<()>>,
+}
+
+struct MightNeedSummaryFiles {
+ files: channel::Receiver<UnsummarizedFile>,
+ task: Task<Result<()>>,
+}
+
+impl SummaryIndex {
+ pub fn new(
+ worktree: Model<Worktree>,
+ fs: Arc<dyn Fs>,
+ db_connection: heed::Env,
+ file_digest_db: heed::Database<Str, SerdeBincode<FileDigest>>,
+ summary_db: heed::Database<SerdeBincode<Blake3Digest>, Str>,
+ _entry_ids_being_indexed: Arc<IndexingEntrySet>,
+ ) -> Self {
+ Self {
+ worktree,
+ fs,
+ db_connection,
+ file_digest_db,
+ summary_db,
+ _entry_ids_being_indexed,
+ backlog: Default::default(),
+ }
+ }
+
+ pub fn file_digest_db(&self) -> heed::Database<Str, SerdeBincode<FileDigest>> {
+ self.file_digest_db
+ }
+
+ pub fn summary_db(&self) -> heed::Database<SerdeBincode<Blake3Digest>, Str> {
+ self.summary_db
+ }
+
+ pub fn index_entries_changed_on_disk(
+ &self,
+ is_auto_available: bool,
+ cx: &AppContext,
+ ) -> impl Future<Output = Result<()>> {
+ let start = Instant::now();
+ let backlogged;
+ let digest;
+ let needs_summary;
+ let summaries;
+ let persist;
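+
+ // Five stages connected by bounded channels: scan the worktree, digest the
+ // changed files, check the summary cache, summarize cache misses with the
+ // LLM, and persist the new digests and summaries.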
+
+ if is_auto_available {
+ let worktree = self.worktree.read(cx).snapshot();
+ let worktree_abs_path = worktree.abs_path().clone();
+
+ backlogged = self.scan_entries(worktree, cx);
+ digest = self.digest_files(backlogged.paths_to_digest, worktree_abs_path, cx);
+ needs_summary = self.check_summary_cache(digest.files, cx);
+ summaries = self.summarize_files(needs_summary.files, cx);
+ persist = self.persist_summaries(summaries.files, cx);
+ } else {
+ // This feature is only staff-shipped, so make the rest of these no-ops.
+ backlogged = Backlogged {
+ paths_to_digest: channel::unbounded().1,
+ task: Task::ready(Ok(())),
+ };
+ digest = MightNeedSummaryFiles {
+ files: channel::unbounded().1,
+ task: Task::ready(Ok(())),
+ };
+ needs_summary = NeedsSummary {
+ files: channel::unbounded().1,
+ task: Task::ready(Ok(())),
+ };
+ summaries = SummarizeFiles {
+ files: channel::unbounded().1,
+ task: Task::ready(Ok(())),
+ };
+ persist = Task::ready(Ok(()));
+ }
+
+ async move {
+ futures::try_join!(
+ backlogged.task,
+ digest.task,
+ needs_summary.task,
+ summaries.task,
+ persist
+ )?;
+
+ if is_auto_available {
+ log::info!(
+ "Summarizing everything that changed on disk took {:?}",
+ start.elapsed()
+ );
+ }
+
+ Ok(())
+ }
+ }
+
+ pub fn index_updated_entries(
+ &mut self,
+ updated_entries: UpdatedEntriesSet,
+ is_auto_available: bool,
+ cx: &AppContext,
+ ) -> impl Future<Output = Result<()>> {
+ let start = Instant::now();
+ let backlogged;
+ let digest;
+ let needs_summary;
+ let summaries;
+ let persist;
+
+ if is_auto_available {
+ let worktree = self.worktree.read(cx).snapshot();
+ let worktree_abs_path = worktree.abs_path().clone();
+
+ backlogged = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
+ digest = self.digest_files(backlogged.paths_to_digest, worktree_abs_path, cx);
+ needs_summary = self.check_summary_cache(digest.files, cx);
+ summaries = self.summarize_files(needs_summary.files, cx);
+ persist = self.persist_summaries(summaries.files, cx);
+ } else {
+ // This feature is only staff-shipped, so make the rest of these no-ops.
+ backlogged = Backlogged {
+ paths_to_digest: channel::unbounded().1,
+ task: Task::ready(Ok(())),
+ };
+ digest = MightNeedSummaryFiles {
+ files: channel::unbounded().1,
+ task: Task::ready(Ok(())),
+ };
+ needs_summary = NeedsSummary {
+ files: channel::unbounded().1,
+ task: Task::ready(Ok(())),
+ };
+ summaries = SummarizeFiles {
+ files: channel::unbounded().1,
+ task: Task::ready(Ok(())),
+ };
+ persist = Task::ready(Ok(()));
+ }
+
+ async move {
+ futures::try_join!(
+ backlogged.task,
+ digest.task,
+ needs_summary.task,
+ summaries.task,
+ persist
+ )?;
+
+ log::info!("Summarizing updated entries took {:?}", start.elapsed());
+
+ Ok(())
+ }
+ }
+
+ fn check_summary_cache(
+ &self,
+ mut might_need_summary: channel::Receiver<UnsummarizedFile>,
+ cx: &AppContext,
+ ) -> NeedsSummary {
+ let db_connection = self.db_connection.clone();
+ let db = self.summary_db;
+ let (needs_summary_tx, needs_summary_rx) = channel::bounded(512);
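+ // Cached summaries are keyed by BLAKE3 digest (computed over both path and
+ // contents), so a file is only re-sent to the model when that digest changes.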
+ let task = cx.background_executor().spawn(async move {
+ while let Some(file) = might_need_summary.next().await {
+ let tx = db_connection
+ .read_txn()
+ .context("Failed to create read transaction for checking which hashes are in summary cache")?;
+
+ match db.get(&tx, &file.digest) {
+ Ok(opt_answer) => {
+ if opt_answer.is_none() {
+ // It's not in the summary cache db, so we need to summarize it.
+ log::debug!("File {:?} (digest {:?}) was NOT in the db cache and needs to be resummarized.", file.path.display(), &file.digest);
+ needs_summary_tx.send(file).await?;
+ } else {
+ log::debug!("File {:?} (digest {:?}) was in the db cache and does not need to be resummarized.", file.path.display(), &file.digest);
+ }
+ }
+ Err(err) => {
+ log::error!("Reading from the summaries database failed: {:?}", err);
+ }
+ }
+ }
+
+ Ok(())
+ });
+
+ NeedsSummary {
+ files: needs_summary_rx,
+ task,
+ }
+ }
+
+ fn scan_entries(&self, worktree: Snapshot, cx: &AppContext) -> Backlogged {
+ let (tx, rx) = channel::bounded(512);
+ let db_connection = self.db_connection.clone();
+ let digest_db = self.file_digest_db;
+ let backlog = Arc::clone(&self.backlog);
+ let task = cx.background_executor().spawn(async move {
+ let txn = db_connection
+ .read_txn()
+ .context("failed to create read transaction")?;
+
+ for entry in worktree.files(false, 0) {
+ let needs_summary =
+ Self::add_to_backlog(Arc::clone(&backlog), digest_db, &txn, entry);
+
+ if !needs_summary.is_empty() {
+ tx.send(needs_summary).await?;
+ }
+ }
+
+ // TODO delete db entries for deleted files
+
+ Ok(())
+ });
+
+ Backlogged {
+ paths_to_digest: rx,
+ task,
+ }
+ }
+
+ fn add_to_backlog(
+ backlog: Arc<Mutex<SummaryBacklog>>,
+ digest_db: heed::Database<Str, SerdeBincode<FileDigest>>,
+ txn: &RoTxn<'_>,
+ entry: &Entry,
+ ) -> Vec<(Arc<Path>, Option<SystemTime>)> {
+ let entry_db_key = db_key_for_path(&entry.path);
+
+ match digest_db.get(&txn, &entry_db_key) {
+ Ok(opt_saved_digest) => {
+ // Either we have no saved digest for this path, or the saved mtime doesn't
+ // match the file's current mtime. Either way it needs (re)summarizing, so add
+ // it to the backlog; and if the backlog is now full, drain it and return the
+ // drained paths for summarization.
+ if entry.mtime != opt_saved_digest.and_then(|digest| digest.mtime) {
+ let mut backlog = backlog.lock();
+
+ log::info!(
+ "Inserting {:?} ({:?} bytes) into backlog",
+ &entry.path,
+ entry.size,
+ );
+ backlog.insert(Arc::clone(&entry.path), entry.size, entry.mtime);
+
+ if backlog.needs_drain() {
+ log::info!("Draining summary backlog...");
+ return backlog.drain().collect();
+ }
+ }
+ }
+ Err(err) => {
+ log::error!(
+ "Error trying to get file digest db entry {:?}: {:?}",
+ &entry_db_key,
+ err
+ );
+ }
+ }
+
+ Vec::new()
+ }
+
+ fn scan_updated_entries(
+ &self,
+ worktree: Snapshot,
+ updated_entries: UpdatedEntriesSet,
+ cx: &AppContext,
+ ) -> Backlogged {
+ log::info!("Scanning for updated entries that might need summarization...");
+ let (tx, rx) = channel::bounded(512);
+ // let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
+ let db_connection = self.db_connection.clone();
+ let digest_db = self.file_digest_db;
+ let backlog = Arc::clone(&self.backlog);
+ let task = cx.background_executor().spawn(async move {
+ let txn = db_connection
+ .read_txn()
+ .context("failed to create read transaction")?;
+
+ for (path, entry_id, status) in updated_entries.iter() {
+ match status {
+ project::PathChange::Loaded
+ | project::PathChange::Added
+ | project::PathChange::Updated
+ | project::PathChange::AddedOrUpdated => {
+ if let Some(entry) = worktree.entry_for_id(*entry_id) {
+ if entry.is_file() {
+ let needs_summary = Self::add_to_backlog(
+ Arc::clone(&backlog),
+ digest_db,
+ &txn,
+ entry,
+ );
+
+ if !needs_summary.is_empty() {
+ tx.send(needs_summary).await?;
+ }
+ }
+ }
+ }
+ project::PathChange::Removed => {
+ let _db_path = db_key_for_path(path);
+ // TODO delete db entries for deleted files
+ // deleted_entry_ranges_tx
+ // .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
+ // .await?;
+ }
+ }
+ }
+
+ Ok(())
+ });
+
+ Backlogged {
+ paths_to_digest: rx,
+ // deleted_entry_ranges: deleted_entry_ranges_rx,
+ task,
+ }
+ }
+
+ fn digest_files(
+ &self,
+ paths: channel::Receiver<Vec<(Arc<Path>, Option<SystemTime>)>>,
+ worktree_abs_path: Arc<Path>,
+ cx: &AppContext,
+ ) -> MightNeedSummaryFiles {
+ let fs = self.fs.clone();
+ let (tx, rx) = channel::bounded(2048);
+ let task = cx.spawn(|cx| async move {
+ cx.background_executor()
+ .scoped(|cx| {
+ for _ in 0..cx.num_cpus() {
+ cx.spawn(async {
+ while let Ok(pairs) = paths.recv().await {
+ // Note: we could process all these files concurrently if desired. Might or might not speed things up.
+ for (path, mtime) in pairs {
+ let entry_abs_path = worktree_abs_path.join(&path);
+
+ // Load the file's contents and compute its hash digest.
+ let unsummarized_file = {
+ let Some(contents) = fs
+ .load(&entry_abs_path)
+ .await
+ .with_context(|| {
+ format!("failed to read path {entry_abs_path:?}")
+ })
+ .log_err()
+ else {
+ continue;
+ };
+
+ let digest = {
+ let mut hasher = blake3::Hasher::new();
+ // Incorporate both the (relative) file path as well as the contents of the file into the hash.
+ // This is because in some languages and frameworks, identical files can do different things
+ // depending on their paths (e.g. Rails controllers). It's also why we send the path to the model.
+ hasher.update(path.display().to_string().as_bytes());
+ hasher.update(contents.as_bytes());
+ hasher.finalize().to_hex()
+ };
+
+ UnsummarizedFile {
+ digest,
+ contents,
+ path,
+ mtime,
+ }
+ };
+
+ if let Err(err) = tx
+ .send(unsummarized_file)
+ .map_err(|error| anyhow!(error))
+ .await
+ {
+ log::error!("Error: {:?}", err);
+
+ return;
+ }
+ }
+ }
+ });
+ }
+ })
+ .await;
+ Ok(())
+ });
+
+ MightNeedSummaryFiles { files: rx, task }
+ }
+
+ fn summarize_files(
+ &self,
+ mut unsummarized_files: channel::Receiver<UnsummarizedFile>,
+ cx: &AppContext,
+ ) -> SummarizeFiles {
+ let (summarized_tx, summarized_rx) = channel::bounded(512);
+ let task = cx.spawn(|cx| async move {
+ while let Some(file) = unsummarized_files.next().await {
+ log::debug!("Summarizing {:?}", file);
+ let summary = cx
+ .update(|cx| Self::summarize_code(&file.contents, &file.path, cx))?
+ .await
+ .unwrap_or_else(|err| {
+ // Log a warning because we'll continue anyway.
+ // In the future, we may want to try splitting it up into multiple requests and concatenating the summaries,
+ // but this might give bad summaries due to cutting off source code files in the middle.
+ log::warn!("Failed to summarize {} - {:?}", file.path.display(), err);
+
+ String::new()
+ });
+
+ // The summary is empty when summarization failed above (e.g. a provider
+ // error, or the model's context limit was exceeded), so skip persisting it.
+ if !summary.is_empty() {
+ summarized_tx
+ .send(SummarizedFile {
+ path: file.path.display().to_string(),
+ digest: file.digest,
+ summary,
+ mtime: file.mtime,
+ })
+ .await?
+ }
+ }
+
+ Ok(())
+ });
+
+ SummarizeFiles {
+ files: summarized_rx,
+ task,
+ }
+ }
+
+ fn summarize_code(
+ code: &str,
+ path: &Path,
+ cx: &AppContext,
+ ) -> impl Future<Output = Result<String>> {
+ let start = Instant::now();
+ let (summary_model_id, use_cache): (LanguageModelId, bool) = (
+ "Qwen/Qwen2-7B-Instruct".to_string().into(), // TODO read this from the user's settings.
+ false, // qwen2 doesn't have a cache, but we should probably infer this from the model
+ );
+ let Some(model) = LanguageModelRegistry::read_global(cx)
+ .available_models(cx)
+ .find(|model| model.id() == summary_model_id)
+ else {
+ return cx.background_executor().spawn(async move {
+ Err(anyhow!("Couldn't find the preferred summarization model ({:?}) in the language registry's available models", summary_model_id))
+ });
+ };
+ let utf8_path = path.to_string_lossy();
+ const PROMPT_BEFORE_CODE: &str = "Summarize what the code in this file does in 3 sentences, using no newlines or bullet points in the summary:";
+ let prompt = format!("{PROMPT_BEFORE_CODE}\n{utf8_path}:\n{code}");
+
+ log::debug!(
+ "Summarizing code by sending this prompt to {:?}: {:?}",
+ model.name(),
+ &prompt
+ );
+
+ let request = LanguageModelRequest {
+ messages: vec![LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![prompt.into()],
+ cache: use_cache,
+ }],
+ tools: Vec::new(),
+ stop: Vec::new(),
+ temperature: 1.0,
+ };
+
+ let code_len = code.len();
+ cx.spawn(|cx| async move {
+ let stream = model.stream_completion(request, &cx);
+ cx.background_executor()
+ .spawn(async move {
+ let answer: String = stream
+ .await?
+ .filter_map(|event| async {
+ if let Ok(LanguageModelCompletionEvent::Text(text)) = event {
+ Some(text)
+ } else {
+ None
+ }
+ })
+ .collect()
+ .await;
+
+ log::info!(
+ "It took {:?} to summarize {:?} bytes of code.",
+ start.elapsed(),
+ code_len
+ );
+
+ log::debug!("Summary was: {:?}", &answer);
+
+ Ok(answer)
+ })
+ .await
+
+ // TODO if summarization failed, put it back in the backlog!
+ })
+ }
+
+ fn persist_summaries(
+ &self,
+ summaries: channel::Receiver<SummarizedFile>,
+ cx: &AppContext,
+ ) -> Task<Result<()>> {
+ let db_connection = self.db_connection.clone();
+ let digest_db = self.file_digest_db;
+ let summary_db = self.summary_db;
+ cx.background_executor().spawn(async move {
+ let mut summaries = summaries.chunks_timeout(4096, Duration::from_secs(2));
+ while let Some(summaries) = summaries.next().await {
+ let mut txn = db_connection.write_txn()?;
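+ // Digests and summaries are written in a single transaction, so a crash
+ // can't leave a file digest pointing at a summary that was never persisted.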
+ for file in &summaries {
+ log::debug!(
+ "Saving summary of {:?} - which is {} bytes of summary for content digest {:?}",
+ &file.path,
+ file.summary.len(),
+ file.digest
+ );
+ digest_db.put(
+ &mut txn,
+ &file.path,
+ &FileDigest {
+ mtime: file.mtime,
+ digest: file.digest,
+ },
+ )?;
+ summary_db.put(&mut txn, &file.digest, &file.summary)?;
+ }
+ txn.commit()?;
+
+ drop(summaries);
+ log::debug!("committed summaries");
+ }
+
+ Ok(())
+ })
+ }
+
+ /// Empty out the backlog of files that haven't been resummarized, and resummarize them immediately.
+ pub(crate) fn flush_backlog(
+ &self,
+ worktree_abs_path: Arc<Path>,
+ cx: &AppContext,
+ ) -> impl Future<Output = Result<()>> {
+ let start = Instant::now();
+ let backlogged = {
+ let (tx, rx) = channel::bounded(512);
+ let needs_summary: Vec<(Arc<Path>, Option<SystemTime>)> = {
+ let mut backlog = self.backlog.lock();
+
+ backlog.drain().collect()
+ };
+
+ let task = cx.background_executor().spawn(async move {
+ tx.send(needs_summary).await?;
+ Ok(())
+ });
+
+ Backlogged {
+ paths_to_digest: rx,
+ task,
+ }
+ };
+
+ let digest = self.digest_files(backlogged.paths_to_digest, worktree_abs_path, cx);
+ let needs_summary = self.check_summary_cache(digest.files, cx);
+ let summaries = self.summarize_files(needs_summary.files, cx);
+ let persist = self.persist_summaries(summaries.files, cx);
+
+ async move {
+ futures::try_join!(
+ backlogged.task,
+ digest.task,
+ needs_summary.task,
+ summaries.task,
+ persist
+ )?;
+
+ log::info!("Summarizing backlogged entries took {:?}", start.elapsed());
+
+ Ok(())
+ }
+ }
+
+ pub(crate) fn backlog_len(&self) -> usize {
+ self.backlog.lock().len()
+ }
+}
+
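+// Flatten a path into a single flat database key by swapping '/' for NUL, a byte
+// that can't appear in a valid path, so the mapping is unambiguous.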
+fn db_key_for_path(path: &Arc<Path>) -> String {
+ path.to_string_lossy().replace('/', "\0")
+}
@@ -0,0 +1,217 @@
+use crate::embedding::EmbeddingProvider;
+use crate::embedding_index::EmbeddingIndex;
+use crate::indexing::IndexingEntrySet;
+use crate::summary_index::SummaryIndex;
+use anyhow::Result;
+use feature_flags::{AutoCommand, FeatureFlagAppExt};
+use fs::Fs;
+use futures::future::Shared;
+use gpui::{
+ AppContext, AsyncAppContext, Context, Model, ModelContext, Subscription, Task, WeakModel,
+};
+use language::LanguageRegistry;
+use log;
+use project::{UpdatedEntriesSet, Worktree};
+use smol::channel;
+use std::sync::Arc;
+use util::ResultExt;
+
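+/// Handle to a worktree's index, which may still be loading. `Loading` holds a
+/// shared future so multiple callers can await the same load; once it resolves,
+/// the handle is replaced with `Loaded`.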
+#[derive(Clone)]
+pub enum WorktreeIndexHandle {
+ Loading {
+ index: Shared<Task<Result<Model<WorktreeIndex>, Arc<anyhow::Error>>>>,
+ },
+ Loaded {
+ index: Model<WorktreeIndex>,
+ },
+}
+
+pub struct WorktreeIndex {
+ worktree: Model<Worktree>,
+ db_connection: heed::Env,
+ embedding_index: EmbeddingIndex,
+ summary_index: SummaryIndex,
+ entry_ids_being_indexed: Arc<IndexingEntrySet>,
+ _index_entries: Task<Result<()>>,
+ _subscription: Subscription,
+}
+
+impl WorktreeIndex {
+ pub fn load(
+ worktree: Model<Worktree>,
+ db_connection: heed::Env,
+ language_registry: Arc<LanguageRegistry>,
+ fs: Arc<dyn Fs>,
+ status_tx: channel::Sender<()>,
+ embedding_provider: Arc<dyn EmbeddingProvider>,
+ cx: &mut AppContext,
+ ) -> Task<Result<Model<Self>>> {
+ let worktree_for_index = worktree.clone();
+ let worktree_for_summary = worktree.clone();
+ let worktree_abs_path = worktree.read(cx).abs_path();
+ let embedding_fs = Arc::clone(&fs);
+ let summary_fs = fs;
+ cx.spawn(|mut cx| async move {
+ let entries_being_indexed = Arc::new(IndexingEntrySet::new(status_tx));
+ let (embedding_index, summary_index) = cx
+ .background_executor()
+ .spawn({
+ let entries_being_indexed = Arc::clone(&entries_being_indexed);
+ let db_connection = db_connection.clone();
+ async move {
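+ // Create the embedding, digest, and summary databases in a single write transaction.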
+ let mut txn = db_connection.write_txn()?;
+ let embedding_index = {
+ let db_name = worktree_abs_path.to_string_lossy();
+ let db = db_connection.create_database(&mut txn, Some(&db_name))?;
+
+ EmbeddingIndex::new(
+ worktree_for_index,
+ embedding_fs,
+ db_connection.clone(),
+ db,
+ language_registry,
+ embedding_provider,
+ Arc::clone(&entries_being_indexed),
+ )
+ };
+ let summary_index = {
+ let file_digest_db = {
+ // Prepend something that wouldn't be found at the beginning of an
+ // absolute path, so we don't get db key namespace conflicts with
+ // embeddings, which use the abs path as a key.
+ let db_name = format!("digests-{}", worktree_abs_path.to_string_lossy());
+ db_connection.create_database(&mut txn, Some(&db_name))?
+ };
+ let summary_db = {
+ // Same key-namespacing scheme as the digests database above.
+ let db_name = format!("summaries-{}", worktree_abs_path.to_string_lossy());
+ db_connection.create_database(&mut txn, Some(&db_name))?
+ };
+ SummaryIndex::new(
+ worktree_for_summary,
+ summary_fs,
+ db_connection.clone(),
+ file_digest_db,
+ summary_db,
+ Arc::clone(&entries_being_indexed),
+ )
+ };
+ txn.commit()?;
+ anyhow::Ok((embedding_index, summary_index))
+ }
+ })
+ .await?;
+
+ cx.new_model(|cx| {
+ Self::new(
+ worktree,
+ db_connection,
+ embedding_index,
+ summary_index,
+ entries_being_indexed,
+ cx,
+ )
+ })
+ })
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ pub fn new(
+ worktree: Model<Worktree>,
+ db_connection: heed::Env,
+ embedding_index: EmbeddingIndex,
+ summary_index: SummaryIndex,
+ entry_ids_being_indexed: Arc<IndexingEntrySet>,
+ cx: &mut ModelContext<Self>,
+ ) -> Self {
+ let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
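+ // Forward worktree entry updates into a channel drained by `index_entries` below.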
+ let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
+ if let worktree::Event::UpdatedEntries(update) = event {
+ log::debug!("Updating entries...");
+ _ = updated_entries_tx.try_send(update.clone());
+ }
+ });
+
+ Self {
+ db_connection,
+ embedding_index,
+ summary_index,
+ worktree,
+ entry_ids_being_indexed,
+ _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
+ _subscription,
+ }
+ }
+
+ pub fn entry_ids_being_indexed(&self) -> &IndexingEntrySet {
+ self.entry_ids_being_indexed.as_ref()
+ }
+
+ pub fn worktree(&self) -> &Model<Worktree> {
+ &self.worktree
+ }
+
+ pub fn db_connection(&self) -> &heed::Env {
+ &self.db_connection
+ }
+
+ pub fn embedding_index(&self) -> &EmbeddingIndex {
+ &self.embedding_index
+ }
+
+ pub fn summary_index(&self) -> &SummaryIndex {
+ &self.summary_index
+ }
+
+ async fn index_entries(
+ this: WeakModel<Self>,
+ updated_entries: channel::Receiver<UpdatedEntriesSet>,
+ mut cx: AsyncAppContext,
+ ) -> Result<()> {
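+ // Index everything that changed on disk first, then handle incremental
+ // updates as they arrive. Summary indexing only runs when the
+ // `AutoCommand` feature flag is enabled.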
+ let is_auto_available = cx.update(|cx| cx.wait_for_flag::<AutoCommand>())?.await;
+ let index = this.update(&mut cx, |this, cx| {
+ futures::future::try_join(
+ this.embedding_index.index_entries_changed_on_disk(cx),
+ this.summary_index
+ .index_entries_changed_on_disk(is_auto_available, cx),
+ )
+ })?;
+ index.await.log_err();
+
+ while let Ok(updated_entries) = updated_entries.recv().await {
+ let is_auto_available = cx
+ .update(|cx| cx.has_flag::<AutoCommand>())
+ .unwrap_or(false);
+
+ let index = this.update(&mut cx, |this, cx| {
+ futures::future::try_join(
+ this.embedding_index
+ .index_updated_entries(updated_entries.clone(), cx),
+ this.summary_index.index_updated_entries(
+ updated_entries,
+ is_auto_available,
+ cx,
+ ),
+ )
+ })?;
+ index.await.log_err();
+ }
+
+ Ok(())
+ }
+
+ #[cfg(test)]
+ pub fn path_count(&self) -> Result<u64> {
+ use anyhow::Context;
+
+ let txn = self
+ .db_connection
+ .read_txn()
+ .context("failed to create read transaction")?;
+ Ok(self.embedding_index().db().len(&txn)?)
+ }
+}
@@ -3227,6 +3227,8 @@ pub struct Entry {
pub git_status: Option<GitFileStatus>,
/// Whether this entry is considered to be a `.env` file.
pub is_private: bool,
+ /// The entry's size on disk, in bytes.
+ pub size: u64,
pub char_bag: CharBag,
pub is_fifo: bool,
}
@@ -3282,6 +3284,7 @@ impl Entry {
path,
inode: metadata.inode,
mtime: Some(metadata.mtime),
+ size: metadata.len,
canonical_path,
is_symlink: metadata.is_symlink,
is_ignored: false,
@@ -5210,6 +5213,7 @@ impl<'a> From<&'a Entry> for proto::Entry {
is_external: entry.is_external,
git_status: entry.git_status.map(git_status_to_proto),
is_fifo: entry.is_fifo,
+ size: Some(entry.size),
}
}
}
@@ -5231,6 +5235,8 @@ impl<'a> TryFrom<(&'a CharBag, proto::Entry)> for Entry {
path,
inode: entry.inode,
mtime: entry.mtime.map(|time| time.into()),
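+ // Fall back to 0 when the peer didn't send a size.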
+ size: entry.size.unwrap_or(0),
canonical_path: None,
is_ignored: entry.is_ignored,
is_external: entry.is_external,