1use std::path::PathBuf;
2use std::sync::Arc;
3
4use anyhow::{anyhow, Result};
5use assistant_tool::{ToolId, ToolWorkingSet};
6use chrono::{DateTime, Utc};
7use collections::HashMap;
8use context_server::manager::ContextServerManager;
9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
10use futures::future::{self, BoxFuture, Shared};
11use futures::FutureExt as _;
12use gpui::{
13 prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
14};
15use heed::types::{SerdeBincode, SerdeJson};
16use heed::Database;
17use language_model::{LanguageModelToolUseId, Role};
18use project::Project;
19use prompt_store::PromptBuilder;
20use serde::{Deserialize, Serialize};
21use util::ResultExt as _;
22
23use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadId};
24
25pub fn init(cx: &mut App) {
26 ThreadsDatabase::init(cx);
27}
28
/// Keeps the metadata of all saved threads for a project and creates, opens,
/// saves and deletes [`Thread`]s backed by the global threads database.
pub struct ThreadStore {
    /// Project that created and opened threads are attached to.
    project: Entity<Project>,
    /// Tool set handed to each thread; context-server tools are registered
    /// into and removed from this set as servers start and stop.
    tools: Arc<ToolWorkingSet>,
    /// Prompt builder handed to each thread.
    prompt_builder: Arc<PromptBuilder>,
    /// Manager whose start/stop events drive context-server tool registration.
    context_server_manager: Entity<ContextServerManager>,
    /// Tool IDs registered per context-server ID, kept so the tools can be
    /// unregistered when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Cached metadata of saved threads; replaced by `reload`, pruned by
    /// `delete_thread`.
    threads: Vec<SerializedThreadMetadata>,
}
37
38impl ThreadStore {
39 pub fn new(
40 project: Entity<Project>,
41 tools: Arc<ToolWorkingSet>,
42 prompt_builder: Arc<PromptBuilder>,
43 cx: &mut App,
44 ) -> Result<Entity<Self>> {
45 let this = cx.new(|cx| {
46 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
47 let context_server_manager = cx.new(|cx| {
48 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
49 });
50
51 let this = Self {
52 project,
53 tools,
54 prompt_builder,
55 context_server_manager,
56 context_server_tool_ids: HashMap::default(),
57 threads: Vec::new(),
58 };
59 this.register_context_server_handlers(cx);
60 this.reload(cx).detach_and_log_err(cx);
61
62 this
63 });
64
65 Ok(this)
66 }
67
68 /// Returns the number of threads.
69 pub fn thread_count(&self) -> usize {
70 self.threads.len()
71 }
72
73 pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
74 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
75 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
76 threads
77 }
78
79 pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
80 self.threads().into_iter().take(limit).collect()
81 }
82
83 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
84 cx.new(|cx| {
85 Thread::new(
86 self.project.clone(),
87 self.tools.clone(),
88 self.prompt_builder.clone(),
89 cx,
90 )
91 })
92 }
93
94 pub fn open_thread(
95 &self,
96 id: &ThreadId,
97 cx: &mut Context<Self>,
98 ) -> Task<Result<Entity<Thread>>> {
99 let id = id.clone();
100 let database_future = ThreadsDatabase::global_future(cx);
101 cx.spawn(|this, mut cx| async move {
102 let database = database_future.await.map_err(|err| anyhow!(err))?;
103 let thread = database
104 .try_find_thread(id.clone())
105 .await?
106 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
107
108 this.update(&mut cx, |this, cx| {
109 cx.new(|cx| {
110 Thread::deserialize(
111 id.clone(),
112 thread,
113 this.project.clone(),
114 this.tools.clone(),
115 this.prompt_builder.clone(),
116 cx,
117 )
118 })
119 })
120 })
121 }
122
123 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
124 let (metadata, serialized_thread) =
125 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
126
127 let database_future = ThreadsDatabase::global_future(cx);
128 cx.spawn(|this, mut cx| async move {
129 let serialized_thread = serialized_thread.await?;
130 let database = database_future.await.map_err(|err| anyhow!(err))?;
131 database.save_thread(metadata, serialized_thread).await?;
132
133 this.update(&mut cx, |this, cx| this.reload(cx))?.await
134 })
135 }
136
137 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
138 let id = id.clone();
139 let database_future = ThreadsDatabase::global_future(cx);
140 cx.spawn(|this, mut cx| async move {
141 let database = database_future.await.map_err(|err| anyhow!(err))?;
142 database.delete_thread(id.clone()).await?;
143
144 this.update(&mut cx, |this, _cx| {
145 this.threads.retain(|thread| thread.id != id)
146 })
147 })
148 }
149
150 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
151 let database_future = ThreadsDatabase::global_future(cx);
152 cx.spawn(|this, mut cx| async move {
153 let threads = database_future
154 .await
155 .map_err(|err| anyhow!(err))?
156 .list_threads()
157 .await?;
158
159 this.update(&mut cx, |this, cx| {
160 this.threads = threads;
161 cx.notify();
162 })
163 })
164 }
165
166 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
167 cx.subscribe(
168 &self.context_server_manager.clone(),
169 Self::handle_context_server_event,
170 )
171 .detach();
172 }
173
174 fn handle_context_server_event(
175 &mut self,
176 context_server_manager: Entity<ContextServerManager>,
177 event: &context_server::manager::Event,
178 cx: &mut Context<Self>,
179 ) {
180 let tool_working_set = self.tools.clone();
181 match event {
182 context_server::manager::Event::ServerStarted { server_id } => {
183 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
184 let context_server_manager = context_server_manager.clone();
185 cx.spawn({
186 let server = server.clone();
187 let server_id = server_id.clone();
188 |this, mut cx| async move {
189 let Some(protocol) = server.client() else {
190 return;
191 };
192
193 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
194 if let Some(tools) = protocol.list_tools().await.log_err() {
195 let tool_ids = tools
196 .tools
197 .into_iter()
198 .map(|tool| {
199 log::info!(
200 "registering context server tool: {:?}",
201 tool.name
202 );
203 tool_working_set.insert(Arc::new(
204 ContextServerTool::new(
205 context_server_manager.clone(),
206 server.id(),
207 tool,
208 ),
209 ))
210 })
211 .collect::<Vec<_>>();
212
213 this.update(&mut cx, |this, _cx| {
214 this.context_server_tool_ids.insert(server_id, tool_ids);
215 })
216 .log_err();
217 }
218 }
219 }
220 })
221 .detach();
222 }
223 }
224 context_server::manager::Event::ServerStopped { server_id } => {
225 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
226 tool_working_set.remove(&tool_ids);
227 }
228 }
229 }
230 }
231}
232
/// Summary of a saved thread used to list threads; built from each stored
/// [`SerializedThread`] in `ThreadsDatabase::list_threads`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Key under which the thread is stored in the database.
    pub id: ThreadId,
    /// The thread's summary, copied from the stored thread.
    pub summary: SharedString,
    /// Last update time; `ThreadStore::threads` sorts newest first by this.
    pub updated_at: DateTime<Utc>,
}
239
/// Persisted form of a thread, stored as JSON in the threads database
/// (see the `SerdeJson<SerializedThread>` value codec on `ThreadsDatabase`).
#[derive(Serialize, Deserialize)]
pub struct SerializedThread {
    /// The thread's summary text.
    pub summary: SharedString,
    /// When the thread was last updated.
    pub updated_at: DateTime<Utc>,
    /// The thread's messages, in stored order.
    pub messages: Vec<SerializedMessage>,
    /// Optional snapshot of the project; `#[serde(default)]` lets records
    /// without this field deserialize with `None`.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
248
/// Persisted form of a single message within a [`SerializedThread`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    /// Identifier of the message within its thread.
    pub id: MessageId,
    /// Role of the message's author.
    pub role: Role,
    /// Message text.
    pub text: String,
    /// Tool invocations attached to this message; defaults to empty when the
    /// field is absent from the stored JSON.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Tool results attached to this message; defaults to empty when the
    /// field is absent from the stored JSON.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
259
/// Persisted form of a tool invocation made within a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    /// Identifier of this tool use (presumably referenced by the matching
    /// `SerializedToolResult::tool_use_id` — TODO confirm).
    pub id: LanguageModelToolUseId,
    /// Name of the invoked tool.
    pub name: SharedString,
    /// Input passed to the tool, kept as an arbitrary JSON value.
    pub input: serde_json::Value,
}
266
/// Persisted form of a tool's result within a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    /// ID of the tool use this result answers.
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the tool reported an error.
    pub is_error: bool,
    /// Result content as text.
    pub content: Arc<str>,
}
273
/// gpui global holding the shared future that opens the threads database.
///
/// The future is `Shared` so every caller awaits the same open. `Shared`
/// requires a `Clone` output, hence both the database and the error are
/// wrapped in `Arc`.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
279
/// LMDB-backed (via heed) storage for serialized threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every database transaction is spawned.
    executor: BackgroundExecutor,
    /// The opened heed environment.
    env: heed::Env,
    /// The `threads` table: bincode-encoded [`ThreadId`] keys mapping to
    /// JSON-encoded [`SerializedThread`] values.
    threads: Database<SerdeBincode<ThreadId>, SerdeJson<SerializedThread>>,
}
285
286impl ThreadsDatabase {
287 fn global_future(
288 cx: &mut App,
289 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
290 GlobalThreadsDatabase::global(cx).0.clone()
291 }
292
293 fn init(cx: &mut App) {
294 let executor = cx.background_executor().clone();
295 let database_future = executor
296 .spawn({
297 let executor = executor.clone();
298 let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
299 async move { ThreadsDatabase::new(database_path, executor) }
300 })
301 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
302 .boxed()
303 .shared();
304
305 cx.set_global(GlobalThreadsDatabase(database_future));
306 }
307
308 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
309 std::fs::create_dir_all(&path)?;
310
311 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
312 let env = unsafe {
313 heed::EnvOpenOptions::new()
314 .map_size(ONE_GB_IN_BYTES)
315 .max_dbs(1)
316 .open(path)?
317 };
318
319 let mut txn = env.write_txn()?;
320 let threads = env.create_database(&mut txn, Some("threads"))?;
321 txn.commit()?;
322
323 Ok(Self {
324 executor,
325 env,
326 threads,
327 })
328 }
329
330 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
331 let env = self.env.clone();
332 let threads = self.threads;
333
334 self.executor.spawn(async move {
335 let txn = env.read_txn()?;
336 let mut iter = threads.iter(&txn)?;
337 let mut threads = Vec::new();
338 while let Some((key, value)) = iter.next().transpose()? {
339 threads.push(SerializedThreadMetadata {
340 id: key,
341 summary: value.summary,
342 updated_at: value.updated_at,
343 });
344 }
345
346 Ok(threads)
347 })
348 }
349
350 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
351 let env = self.env.clone();
352 let threads = self.threads;
353
354 self.executor.spawn(async move {
355 let txn = env.read_txn()?;
356 let thread = threads.get(&txn, &id)?;
357 Ok(thread)
358 })
359 }
360
361 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
362 let env = self.env.clone();
363 let threads = self.threads;
364
365 self.executor.spawn(async move {
366 let mut txn = env.write_txn()?;
367 threads.put(&mut txn, &id, &thread)?;
368 txn.commit()?;
369 Ok(())
370 })
371 }
372
373 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
374 let env = self.env.clone();
375 let threads = self.threads;
376
377 self.executor.spawn(async move {
378 let mut txn = env.write_txn()?;
379 threads.delete(&mut txn, &id)?;
380 txn.commit()?;
381 Ok(())
382 })
383 }
384}