@@ -1,8 +1,7 @@
-use std::borrow::Cow;
use std::cell::{Ref, RefCell};
use std::path::{Path, PathBuf};
use std::rc::Rc;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, CompletionMode};
use anyhow::{Context as _, Result, anyhow};
@@ -17,8 +16,7 @@ use gpui::{
App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
Subscription, Task, prelude::*,
};
-use heed::Database;
-use heed::types::SerdeBincode;
+
use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
use project::context_server_store::{ContextServerStatus, ContextServerStore};
use project::{Project, ProjectItem, ProjectPath, Worktree};
@@ -35,6 +33,42 @@ use crate::context_server_tool::ContextServerTool;
use crate::thread::{
DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
};
+use indoc::indoc;
+use sqlez::{
+ bindable::{Bind, Column},
+ connection::Connection,
+ statement::Statement,
+};
+
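+/// Encoding of the `data` BLOB in the `threads` table: plain JSON, or
+/// zstd-compressed JSON. New rows are always written compressed; the `json`
+/// variant keeps uncompressed rows readable.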
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum DataType {
+ #[serde(rename = "json")]
+ Json,
+ #[serde(rename = "zstd")]
+ Zstd,
+}
+
+impl Bind for DataType {
+ fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+ let value = match self {
+ DataType::Json => "json",
+ DataType::Zstd => "zstd",
+ };
+ value.bind(statement, start_index)
+ }
+}
+
+impl Column for DataType {
+ fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
+ let (value, next_index) = String::column(statement, start_index)?;
+ let data_type = match value.as_str() {
+ "json" => DataType::Json,
+ "zstd" => DataType::Zstd,
+ _ => anyhow::bail!("Unknown data type: {}", value),
+ };
+ Ok((data_type, next_index))
+ }
+}
const RULES_FILE_NAMES: [&'static str; 6] = [
".rules",
@@ -866,25 +900,27 @@ impl Global for GlobalThreadsDatabase {}
pub(crate) struct ThreadsDatabase {
executor: BackgroundExecutor,
- env: heed::Env,
- threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
+ connection: Arc<Mutex<Connection>>,
}
-impl heed::BytesEncode<'_> for SerializedThread {
- type EItem = SerializedThread;
-
- fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
- serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
+impl ThreadsDatabase {
+ fn connection(&self) -> Arc<Mutex<Connection>> {
+ self.connection.clone()
}
+
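+    /// zstd compression level for thread data; 3 is zstd's default and trades
+    /// a little compression ratio for fast writes.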
+ const COMPRESSION_LEVEL: i32 = 3;
}
-impl<'a> heed::BytesDecode<'a> for SerializedThread {
- type DItem = SerializedThread;
+impl Bind for ThreadId {
+ fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+ self.to_string().bind(statement, start_index)
+ }
+}
- fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
- // We implement this type manually because we want to call `SerializedThread::from_json`,
- // instead of the Deserialize trait implementation for `SerializedThread`.
- SerializedThread::from_json(bytes).map_err(Into::into)
+impl Column for ThreadId {
+ fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
+ let (id_str, next_index) = String::column(statement, start_index)?;
+ Ok((ThreadId::from(id_str.as_str()), next_index))
}
}
@@ -900,8 +936,8 @@ impl ThreadsDatabase {
let database_future = executor
.spawn({
let executor = executor.clone();
- let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
- async move { ThreadsDatabase::new(database_path, executor) }
+ let threads_dir = paths::data_dir().join("threads");
+ async move { ThreadsDatabase::new(threads_dir, executor) }
})
.then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
.boxed()
@@ -910,41 +946,144 @@ impl ThreadsDatabase {
cx.set_global(GlobalThreadsDatabase(database_future));
}
- pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
- std::fs::create_dir_all(&path)?;
+ pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
+ std::fs::create_dir_all(&threads_dir)?;
+
+ let sqlite_path = threads_dir.join("threads.db");
+ let mdb_path = threads_dir.join("threads-db.1.mdb");
+
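+        // A leftover LMDB directory means thread data still needs migrating from heed.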
+ let needs_migration_from_heed = mdb_path.exists();
+
+ let connection = Connection::open_file(&sqlite_path.to_string_lossy());
+
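+        // One row per thread: `data` holds the serialized thread, and `data_type`
+        // records whether it is plain JSON or zstd-compressed JSON.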
+ connection.exec(indoc! {"
+ CREATE TABLE IF NOT EXISTS threads (
+ id TEXT PRIMARY KEY,
+ summary TEXT NOT NULL,
+ updated_at TEXT NOT NULL,
+ data_type TEXT NOT NULL,
+ data BLOB NOT NULL
+ )
+ "})?()
+ .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
+
+ let db = Self {
+ executor: executor.clone(),
+ connection: Arc::new(Mutex::new(connection)),
+ };
+
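+        // Migrate in the background. The heed directory is removed only after every
+        // thread has been copied, so a failed migration is retried on the next launch.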
+ if needs_migration_from_heed {
+ let db_connection = db.connection();
+ let executor_clone = executor.clone();
+ executor
+ .spawn(async move {
+ log::info!("Starting threads.db migration");
+ Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
+ std::fs::remove_dir_all(mdb_path)?;
+ log::info!("threads.db migrated to sqlite");
+ Ok::<(), anyhow::Error>(())
+ })
+ .detach();
+ }
+
+ Ok(db)
+ }
+
+ // Remove this migration after 2025-09-01
+ fn migrate_from_heed(
+ mdb_path: &Path,
+ connection: Arc<Mutex<Connection>>,
+ _executor: BackgroundExecutor,
+ ) -> Result<()> {
+ use heed::types::SerdeBincode;
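+        // Local wrapper so the heed encode/decode impls stay scoped to this
+        // migration and can be deleted along with it.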
+ struct SerializedThreadHeed(SerializedThread);
+
+ impl heed::BytesEncode<'_> for SerializedThreadHeed {
+ type EItem = SerializedThreadHeed;
+
+ fn bytes_encode(
+ item: &Self::EItem,
+ ) -> Result<std::borrow::Cow<[u8]>, heed::BoxedError> {
+ serde_json::to_vec(&item.0)
+ .map(std::borrow::Cow::Owned)
+ .map_err(Into::into)
+ }
+ }
+
+ impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
+ type DItem = SerializedThreadHeed;
+
+ fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
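+                // Call `SerializedThread::from_json` instead of the `Deserialize`
+                // impl for `SerializedThread`, as the old heed decoder did.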
+ SerializedThread::from_json(bytes)
+ .map(SerializedThreadHeed)
+ .map_err(Into::into)
+ }
+ }
const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
+
let env = unsafe {
heed::EnvOpenOptions::new()
.map_size(ONE_GB_IN_BYTES)
.max_dbs(1)
- .open(path)?
+ .open(mdb_path)?
};
- let mut txn = env.write_txn()?;
- let threads = env.create_database(&mut txn, Some("threads"))?;
- txn.commit()?;
+        let txn = env.read_txn()?; // read-only access suffices; the migration never writes to LMDB
+ let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
+ .open_database(&txn, Some("threads"))?
+ .ok_or_else(|| anyhow!("threads database not found"))?;
- Ok(Self {
- executor,
- env,
- threads,
- })
+ for result in threads.iter(&txn)? {
+ let (thread_id, thread_heed) = result?;
+ Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
+ }
+
+ Ok(())
+ }
+
+ fn save_thread_sync(
+ connection: &Arc<Mutex<Connection>>,
+ id: ThreadId,
+ thread: SerializedThread,
+ ) -> Result<()> {
+ let json_data = serde_json::to_string(&thread)?;
+ let summary = thread.summary.to_string();
+ let updated_at = thread.updated_at.to_rfc3339();
+
+ let connection = connection.lock().unwrap();
+
+ let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
+ let data_type = DataType::Zstd;
+ let data = compressed;
+
+ let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
+ INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
+ "})?;
+
+ insert((id, summary, updated_at, data_type, data))?;
+
+ Ok(())
}
pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
- let env = self.env.clone();
- let threads = self.threads;
+ let connection = self.connection.clone();
self.executor.spawn(async move {
- let txn = env.read_txn()?;
- let mut iter = threads.iter(&txn)?;
+ let connection = connection.lock().unwrap();
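+            // `updated_at` is stored as RFC3339 in UTC, so lexicographic order on the
+            // TEXT column matches chronological order for the ORDER BY below.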
+ let mut select =
+ connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
+ SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
+ "})?;
+
+ let rows = select(())?;
let mut threads = Vec::new();
- while let Some((key, value)) = iter.next().transpose()? {
+
+ for (id, summary, updated_at) in rows {
threads.push(SerializedThreadMetadata {
- id: key,
- summary: value.summary,
- updated_at: value.updated_at,
+ id,
+ summary: summary.into(),
+ updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
});
}
@@ -953,36 +1092,51 @@ impl ThreadsDatabase {
}
pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
- let env = self.env.clone();
- let threads = self.threads;
+ let connection = self.connection.clone();
self.executor.spawn(async move {
- let txn = env.read_txn()?;
- let thread = threads.get(&txn, &id)?;
- Ok(thread)
+ let connection = connection.lock().unwrap();
+ let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
+ SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
+ "})?;
+
+ let rows = select(id)?;
+ if let Some((data_type, data)) = rows.into_iter().next() {
+ let json_data = match data_type {
+ DataType::Zstd => {
+ let decompressed = zstd::decode_all(&data[..])?;
+ String::from_utf8(decompressed)?
+ }
+ DataType::Json => String::from_utf8(data)?,
+ };
+
+ let thread = SerializedThread::from_json(json_data.as_bytes())?;
+ Ok(Some(thread))
+ } else {
+ Ok(None)
+ }
})
}
pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
- let env = self.env.clone();
- let threads = self.threads;
+ let connection = self.connection.clone();
- self.executor.spawn(async move {
- let mut txn = env.write_txn()?;
- threads.put(&mut txn, &id, &thread)?;
- txn.commit()?;
- Ok(())
- })
+ self.executor
+ .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
}
pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
- let env = self.env.clone();
- let threads = self.threads;
+ let connection = self.connection.clone();
self.executor.spawn(async move {
- let mut txn = env.write_txn()?;
- threads.delete(&mut txn, &id)?;
- txn.commit()?;
+ let connection = connection.lock().unwrap();
+
+ let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
+ DELETE FROM threads WHERE id = ?
+ "})?;
+
+ delete(id)?;
+
Ok(())
})
}