1use std::path::PathBuf;
2use std::sync::Arc;
3
4use anyhow::{anyhow, Result};
5use assistant_tool::{ToolId, ToolWorkingSet};
6use chrono::{DateTime, Utc};
7use collections::HashMap;
8use context_server::manager::ContextServerManager;
9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
10use futures::future::{self, BoxFuture, Shared};
11use futures::FutureExt as _;
12use gpui::{
13 prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
14};
15use heed::types::{SerdeBincode, SerdeJson};
16use heed::Database;
17use language_model::{LanguageModelToolUseId, Role};
18use project::Project;
19use prompt_store::PromptBuilder;
20use serde::{Deserialize, Serialize};
21use util::ResultExt as _;
22
23use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId};
24
25pub fn init(cx: &mut App) {
26 ThreadsDatabase::init(cx);
27}
28
29pub struct ThreadStore {
30 project: Entity<Project>,
31 tools: Arc<ToolWorkingSet>,
32 prompt_builder: Arc<PromptBuilder>,
33 context_server_manager: Entity<ContextServerManager>,
34 context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
35 threads: Vec<SerializedThreadMetadata>,
36}
37
38impl ThreadStore {
39 pub fn new(
40 project: Entity<Project>,
41 tools: Arc<ToolWorkingSet>,
42 prompt_builder: Arc<PromptBuilder>,
43 cx: &mut App,
44 ) -> Result<Entity<Self>> {
45 let this = cx.new(|cx| {
46 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
47 let context_server_manager = cx.new(|cx| {
48 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
49 });
50
51 let this = Self {
52 project,
53 tools,
54 prompt_builder,
55 context_server_manager,
56 context_server_tool_ids: HashMap::default(),
57 threads: Vec::new(),
58 };
59 this.register_context_server_handlers(cx);
60 this.reload(cx).detach_and_log_err(cx);
61
62 this
63 });
64
65 Ok(this)
66 }
67
68 pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
69 self.context_server_manager.clone()
70 }
71
72 pub fn tools(&self) -> Arc<ToolWorkingSet> {
73 self.tools.clone()
74 }
75
76 /// Returns the number of threads.
77 pub fn thread_count(&self) -> usize {
78 self.threads.len()
79 }
80
81 pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
82 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
83 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
84 threads
85 }
86
87 pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
88 self.threads().into_iter().take(limit).collect()
89 }
90
91 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
92 cx.new(|cx| {
93 Thread::new(
94 self.project.clone(),
95 self.tools.clone(),
96 self.prompt_builder.clone(),
97 cx,
98 )
99 })
100 }
101
102 pub fn open_thread(
103 &self,
104 id: &ThreadId,
105 cx: &mut Context<Self>,
106 ) -> Task<Result<Entity<Thread>>> {
107 let id = id.clone();
108 let database_future = ThreadsDatabase::global_future(cx);
109 cx.spawn(async move |this, cx| {
110 let database = database_future.await.map_err(|err| anyhow!(err))?;
111 let thread = database
112 .try_find_thread(id.clone())
113 .await?
114 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
115
116 let thread = this.update(cx, |this, cx| {
117 cx.new(|cx| {
118 Thread::deserialize(
119 id.clone(),
120 thread,
121 this.project.clone(),
122 this.tools.clone(),
123 this.prompt_builder.clone(),
124 cx,
125 )
126 })
127 })?;
128
129 let (system_prompt_context, load_error) = thread
130 .update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
131 .await;
132 thread.update(cx, |thread, cx| {
133 thread.set_system_prompt_context(system_prompt_context);
134 if let Some(load_error) = load_error {
135 cx.emit(ThreadEvent::ShowError(load_error));
136 }
137 })?;
138
139 Ok(thread)
140 })
141 }
142
143 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
144 let (metadata, serialized_thread) =
145 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
146
147 let database_future = ThreadsDatabase::global_future(cx);
148 cx.spawn(async move |this, cx| {
149 let serialized_thread = serialized_thread.await?;
150 let database = database_future.await.map_err(|err| anyhow!(err))?;
151 database.save_thread(metadata, serialized_thread).await?;
152
153 this.update(cx, |this, cx| this.reload(cx))?.await
154 })
155 }
156
157 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
158 let id = id.clone();
159 let database_future = ThreadsDatabase::global_future(cx);
160 cx.spawn(async move |this, cx| {
161 let database = database_future.await.map_err(|err| anyhow!(err))?;
162 database.delete_thread(id.clone()).await?;
163
164 this.update(cx, |this, _cx| {
165 this.threads.retain(|thread| thread.id != id)
166 })
167 })
168 }
169
170 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
171 let database_future = ThreadsDatabase::global_future(cx);
172 cx.spawn(async move |this, cx| {
173 let threads = database_future
174 .await
175 .map_err(|err| anyhow!(err))?
176 .list_threads()
177 .await?;
178
179 this.update(cx, |this, cx| {
180 this.threads = threads;
181 cx.notify();
182 })
183 })
184 }
185
186 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
187 cx.subscribe(
188 &self.context_server_manager.clone(),
189 Self::handle_context_server_event,
190 )
191 .detach();
192 }
193
194 fn handle_context_server_event(
195 &mut self,
196 context_server_manager: Entity<ContextServerManager>,
197 event: &context_server::manager::Event,
198 cx: &mut Context<Self>,
199 ) {
200 let tool_working_set = self.tools.clone();
201 match event {
202 context_server::manager::Event::ServerStarted { server_id } => {
203 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
204 let context_server_manager = context_server_manager.clone();
205 cx.spawn({
206 let server = server.clone();
207 let server_id = server_id.clone();
208 async move |this, cx| {
209 let Some(protocol) = server.client() else {
210 return;
211 };
212
213 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
214 if let Some(tools) = protocol.list_tools().await.log_err() {
215 let tool_ids = tools
216 .tools
217 .into_iter()
218 .map(|tool| {
219 log::info!(
220 "registering context server tool: {:?}",
221 tool.name
222 );
223 tool_working_set.insert(Arc::new(
224 ContextServerTool::new(
225 context_server_manager.clone(),
226 server.id(),
227 tool,
228 ),
229 ))
230 })
231 .collect::<Vec<_>>();
232
233 this.update(cx, |this, _cx| {
234 this.context_server_tool_ids.insert(server_id, tool_ids);
235 })
236 .log_err();
237 }
238 }
239 }
240 })
241 .detach();
242 }
243 }
244 context_server::manager::Event::ServerStopped { server_id } => {
245 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
246 tool_working_set.remove(&tool_ids);
247 }
248 }
249 }
250 }
251}
252
/// Lightweight summary of a saved thread, used for listing threads without
/// carrying their messages.
///
/// Built by `ThreadsDatabase::list_threads` from each database key and the
/// stored thread's `summary` / `updated_at`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Id the thread is stored under in the database.
    pub id: ThreadId,
    /// Thread summary, copied from the stored `SerializedThread`.
    pub summary: SharedString,
    /// Last-update time; `ThreadStore::threads` sorts by this, newest first.
    pub updated_at: DateTime<Utc>,
}
259
/// Persisted form of a thread; stored as JSON (`SerdeJson`) in the threads database.
#[derive(Serialize, Deserialize)]
pub struct SerializedThread {
    /// Thread summary, surfaced through `SerializedThreadMetadata`.
    pub summary: SharedString,
    /// Time of the last update, surfaced through `SerializedThreadMetadata`.
    pub updated_at: DateTime<Utc>,
    /// Messages in the thread.
    pub messages: Vec<SerializedMessage>,
    /// Snapshot of the project taken when the thread started, if any.
    /// `#[serde(default)]` lets records without this field deserialize as `None`.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
268
/// Persisted form of a single message within a `SerializedThread`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    /// Message id within its thread.
    pub id: MessageId,
    /// Author role of the message.
    pub role: Role,
    /// Message text.
    pub text: String,
    /// Tool invocations attached to this message; empty when absent from the record.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Tool results attached to this message; empty when absent from the record.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
279
/// Persisted record of a tool invocation.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    /// Id of this tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the invoked tool.
    pub name: SharedString,
    /// Tool input, kept as arbitrary JSON.
    pub input: serde_json::Value,
}
286
/// Persisted result of a tool invocation.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    /// Id of the tool use this result answers (presumably matches a `SerializedToolUse::id`).
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the tool reported an error.
    pub is_error: bool,
    /// Result content.
    pub content: Arc<str>,
}
293
/// Global holding a shared future that resolves to the opened threads database.
///
/// The error is wrapped in `Arc` because `Shared` requires a `Clone` output and
/// `anyhow::Error` is not `Clone`.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);
297
// Marker impl so the database future can be stored via `cx.set_global` and read via `global`.
impl Global for GlobalThreadsDatabase {}
299
/// heed (LMDB) backed storage for serialized threads.
pub(crate) struct ThreadsDatabase {
    /// Executor every read/write transaction is spawned on.
    executor: BackgroundExecutor,
    /// Open heed environment.
    env: heed::Env,
    /// Thread table: bincode-encoded `ThreadId` keys, JSON-encoded `SerializedThread` values.
    threads: Database<SerdeBincode<ThreadId>, SerdeJson<SerializedThread>>,
}
305
306impl ThreadsDatabase {
307 fn global_future(
308 cx: &mut App,
309 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
310 GlobalThreadsDatabase::global(cx).0.clone()
311 }
312
313 fn init(cx: &mut App) {
314 let executor = cx.background_executor().clone();
315 let database_future = executor
316 .spawn({
317 let executor = executor.clone();
318 let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
319 async move { ThreadsDatabase::new(database_path, executor) }
320 })
321 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
322 .boxed()
323 .shared();
324
325 cx.set_global(GlobalThreadsDatabase(database_future));
326 }
327
328 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
329 std::fs::create_dir_all(&path)?;
330
331 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
332 let env = unsafe {
333 heed::EnvOpenOptions::new()
334 .map_size(ONE_GB_IN_BYTES)
335 .max_dbs(1)
336 .open(path)?
337 };
338
339 let mut txn = env.write_txn()?;
340 let threads = env.create_database(&mut txn, Some("threads"))?;
341 txn.commit()?;
342
343 Ok(Self {
344 executor,
345 env,
346 threads,
347 })
348 }
349
350 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
351 let env = self.env.clone();
352 let threads = self.threads;
353
354 self.executor.spawn(async move {
355 let txn = env.read_txn()?;
356 let mut iter = threads.iter(&txn)?;
357 let mut threads = Vec::new();
358 while let Some((key, value)) = iter.next().transpose()? {
359 threads.push(SerializedThreadMetadata {
360 id: key,
361 summary: value.summary,
362 updated_at: value.updated_at,
363 });
364 }
365
366 Ok(threads)
367 })
368 }
369
370 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
371 let env = self.env.clone();
372 let threads = self.threads;
373
374 self.executor.spawn(async move {
375 let txn = env.read_txn()?;
376 let thread = threads.get(&txn, &id)?;
377 Ok(thread)
378 })
379 }
380
381 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
382 let env = self.env.clone();
383 let threads = self.threads;
384
385 self.executor.spawn(async move {
386 let mut txn = env.write_txn()?;
387 threads.put(&mut txn, &id, &thread)?;
388 txn.commit()?;
389 Ok(())
390 })
391 }
392
393 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
394 let env = self.env.clone();
395 let threads = self.threads;
396
397 self.executor.spawn(async move {
398 let mut txn = env.write_txn()?;
399 threads.delete(&mut txn, &id)?;
400 txn.commit()?;
401 Ok(())
402 })
403 }
404}