1use std::path::PathBuf;
2use std::sync::Arc;
3
4use anyhow::{anyhow, Result};
5use assistant_tool::{ToolId, ToolWorkingSet};
6use chrono::{DateTime, Utc};
7use collections::HashMap;
8use context_server::manager::ContextServerManager;
9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
10use futures::future::{self, BoxFuture, Shared};
11use futures::FutureExt as _;
12use gpui::{
13 prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
14};
15use heed::types::{SerdeBincode, SerdeJson};
16use heed::Database;
17use language_model::{LanguageModelToolUseId, Role};
18use project::Project;
19use prompt_store::PromptBuilder;
20use serde::{Deserialize, Serialize};
21use util::ResultExt as _;
22
23use crate::thread::{MessageId, Thread, ThreadId};
24
25pub fn init(cx: &mut App) {
26 ThreadsDatabase::init(cx);
27}
28
29pub struct ThreadStore {
30 project: Entity<Project>,
31 tools: Arc<ToolWorkingSet>,
32 prompt_builder: Arc<PromptBuilder>,
33 context_server_manager: Entity<ContextServerManager>,
34 context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
35 threads: Vec<SavedThreadMetadata>,
36}
37
38impl ThreadStore {
39 pub fn new(
40 project: Entity<Project>,
41 tools: Arc<ToolWorkingSet>,
42 prompt_builder: Arc<PromptBuilder>,
43 cx: &mut App,
44 ) -> Result<Entity<Self>> {
45 let this = cx.new(|cx| {
46 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
47 let context_server_manager = cx.new(|cx| {
48 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
49 });
50
51 let this = Self {
52 project,
53 tools,
54 prompt_builder,
55 context_server_manager,
56 context_server_tool_ids: HashMap::default(),
57 threads: Vec::new(),
58 };
59 this.register_context_server_handlers(cx);
60 this.reload(cx).detach_and_log_err(cx);
61
62 this
63 });
64
65 Ok(this)
66 }
67
68 /// Returns the number of threads.
69 pub fn thread_count(&self) -> usize {
70 self.threads.len()
71 }
72
73 pub fn threads(&self) -> Vec<SavedThreadMetadata> {
74 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
75 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
76 threads
77 }
78
79 pub fn recent_threads(&self, limit: usize) -> Vec<SavedThreadMetadata> {
80 self.threads().into_iter().take(limit).collect()
81 }
82
83 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
84 cx.new(|cx| {
85 Thread::new(
86 self.project.clone(),
87 self.tools.clone(),
88 self.prompt_builder.clone(),
89 cx,
90 )
91 })
92 }
93
94 pub fn open_thread(
95 &self,
96 id: &ThreadId,
97 cx: &mut Context<Self>,
98 ) -> Task<Result<Entity<Thread>>> {
99 let id = id.clone();
100 let database_future = ThreadsDatabase::global_future(cx);
101 cx.spawn(|this, mut cx| async move {
102 let database = database_future.await.map_err(|err| anyhow!(err))?;
103 let thread = database
104 .try_find_thread(id.clone())
105 .await?
106 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
107
108 this.update(&mut cx, |this, cx| {
109 cx.new(|cx| {
110 Thread::from_saved(
111 id.clone(),
112 thread,
113 this.project.clone(),
114 this.tools.clone(),
115 this.prompt_builder.clone(),
116 cx,
117 )
118 })
119 })
120 })
121 }
122
123 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
124 let (metadata, thread) = thread.update(cx, |thread, _cx| {
125 let id = thread.id().clone();
126 let thread = SavedThread {
127 summary: thread.summary_or_default(),
128 updated_at: thread.updated_at(),
129 messages: thread
130 .messages()
131 .map(|message| {
132 let all_tool_uses = thread
133 .tool_uses_for_message(message.id)
134 .into_iter()
135 .chain(thread.scripting_tool_uses_for_message(message.id))
136 .map(|tool_use| SavedToolUse {
137 id: tool_use.id,
138 name: tool_use.name,
139 input: tool_use.input,
140 })
141 .collect();
142 let all_tool_results = thread
143 .tool_results_for_message(message.id)
144 .into_iter()
145 .chain(thread.scripting_tool_results_for_message(message.id))
146 .map(|tool_result| SavedToolResult {
147 tool_use_id: tool_result.tool_use_id.clone(),
148 is_error: tool_result.is_error,
149 content: tool_result.content.clone(),
150 })
151 .collect();
152
153 SavedMessage {
154 id: message.id,
155 role: message.role,
156 text: message.text.clone(),
157 tool_uses: all_tool_uses,
158 tool_results: all_tool_results,
159 }
160 })
161 .collect(),
162 };
163
164 (id, thread)
165 });
166
167 let database_future = ThreadsDatabase::global_future(cx);
168 cx.spawn(|this, mut cx| async move {
169 let database = database_future.await.map_err(|err| anyhow!(err))?;
170 database.save_thread(metadata, thread).await?;
171
172 this.update(&mut cx, |this, cx| this.reload(cx))?.await
173 })
174 }
175
176 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
177 let id = id.clone();
178 let database_future = ThreadsDatabase::global_future(cx);
179 cx.spawn(|this, mut cx| async move {
180 let database = database_future.await.map_err(|err| anyhow!(err))?;
181 database.delete_thread(id.clone()).await?;
182
183 this.update(&mut cx, |this, _cx| {
184 this.threads.retain(|thread| thread.id != id)
185 })
186 })
187 }
188
189 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
190 let database_future = ThreadsDatabase::global_future(cx);
191 cx.spawn(|this, mut cx| async move {
192 let threads = database_future
193 .await
194 .map_err(|err| anyhow!(err))?
195 .list_threads()
196 .await?;
197
198 this.update(&mut cx, |this, cx| {
199 this.threads = threads;
200 cx.notify();
201 })
202 })
203 }
204
205 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
206 cx.subscribe(
207 &self.context_server_manager.clone(),
208 Self::handle_context_server_event,
209 )
210 .detach();
211 }
212
213 fn handle_context_server_event(
214 &mut self,
215 context_server_manager: Entity<ContextServerManager>,
216 event: &context_server::manager::Event,
217 cx: &mut Context<Self>,
218 ) {
219 let tool_working_set = self.tools.clone();
220 match event {
221 context_server::manager::Event::ServerStarted { server_id } => {
222 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
223 let context_server_manager = context_server_manager.clone();
224 cx.spawn({
225 let server = server.clone();
226 let server_id = server_id.clone();
227 |this, mut cx| async move {
228 let Some(protocol) = server.client() else {
229 return;
230 };
231
232 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
233 if let Some(tools) = protocol.list_tools().await.log_err() {
234 let tool_ids = tools
235 .tools
236 .into_iter()
237 .map(|tool| {
238 log::info!(
239 "registering context server tool: {:?}",
240 tool.name
241 );
242 tool_working_set.insert(Arc::new(
243 ContextServerTool::new(
244 context_server_manager.clone(),
245 server.id(),
246 tool,
247 ),
248 ))
249 })
250 .collect::<Vec<_>>();
251
252 this.update(&mut cx, |this, _cx| {
253 this.context_server_tool_ids.insert(server_id, tool_ids);
254 })
255 .log_err();
256 }
257 }
258 }
259 })
260 .detach();
261 }
262 }
263 context_server::manager::Event::ServerStopped { server_id } => {
264 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
265 tool_working_set.remove(&tool_ids);
266 }
267 }
268 }
269 }
270}
271
272#[derive(Debug, Clone, Serialize, Deserialize)]
273pub struct SavedThreadMetadata {
274 pub id: ThreadId,
275 pub summary: SharedString,
276 pub updated_at: DateTime<Utc>,
277}
278
279#[derive(Serialize, Deserialize)]
280pub struct SavedThread {
281 pub summary: SharedString,
282 pub updated_at: DateTime<Utc>,
283 pub messages: Vec<SavedMessage>,
284}
285
286#[derive(Debug, Serialize, Deserialize)]
287pub struct SavedMessage {
288 pub id: MessageId,
289 pub role: Role,
290 pub text: String,
291 #[serde(default)]
292 pub tool_uses: Vec<SavedToolUse>,
293 #[serde(default)]
294 pub tool_results: Vec<SavedToolResult>,
295}
296
297#[derive(Debug, Serialize, Deserialize)]
298pub struct SavedToolUse {
299 pub id: LanguageModelToolUseId,
300 pub name: SharedString,
301 pub input: serde_json::Value,
302}
303
304#[derive(Debug, Serialize, Deserialize)]
305pub struct SavedToolResult {
306 pub tool_use_id: LanguageModelToolUseId,
307 pub is_error: bool,
308 pub content: Arc<str>,
309}
310
311struct GlobalThreadsDatabase(
312 Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
313);
314
315impl Global for GlobalThreadsDatabase {}
316
317pub(crate) struct ThreadsDatabase {
318 executor: BackgroundExecutor,
319 env: heed::Env,
320 threads: Database<SerdeBincode<ThreadId>, SerdeJson<SavedThread>>,
321}
322
323impl ThreadsDatabase {
324 fn global_future(
325 cx: &mut App,
326 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
327 GlobalThreadsDatabase::global(cx).0.clone()
328 }
329
330 fn init(cx: &mut App) {
331 let executor = cx.background_executor().clone();
332 let database_future = executor
333 .spawn({
334 let executor = executor.clone();
335 let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
336 async move { ThreadsDatabase::new(database_path, executor) }
337 })
338 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
339 .boxed()
340 .shared();
341
342 cx.set_global(GlobalThreadsDatabase(database_future));
343 }
344
345 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
346 std::fs::create_dir_all(&path)?;
347
348 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
349 let env = unsafe {
350 heed::EnvOpenOptions::new()
351 .map_size(ONE_GB_IN_BYTES)
352 .max_dbs(1)
353 .open(path)?
354 };
355
356 let mut txn = env.write_txn()?;
357 let threads = env.create_database(&mut txn, Some("threads"))?;
358 txn.commit()?;
359
360 Ok(Self {
361 executor,
362 env,
363 threads,
364 })
365 }
366
367 pub fn list_threads(&self) -> Task<Result<Vec<SavedThreadMetadata>>> {
368 let env = self.env.clone();
369 let threads = self.threads;
370
371 self.executor.spawn(async move {
372 let txn = env.read_txn()?;
373 let mut iter = threads.iter(&txn)?;
374 let mut threads = Vec::new();
375 while let Some((key, value)) = iter.next().transpose()? {
376 threads.push(SavedThreadMetadata {
377 id: key,
378 summary: value.summary,
379 updated_at: value.updated_at,
380 });
381 }
382
383 Ok(threads)
384 })
385 }
386
387 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SavedThread>>> {
388 let env = self.env.clone();
389 let threads = self.threads;
390
391 self.executor.spawn(async move {
392 let txn = env.read_txn()?;
393 let thread = threads.get(&txn, &id)?;
394 Ok(thread)
395 })
396 }
397
398 pub fn save_thread(&self, id: ThreadId, thread: SavedThread) -> Task<Result<()>> {
399 let env = self.env.clone();
400 let threads = self.threads;
401
402 self.executor.spawn(async move {
403 let mut txn = env.write_txn()?;
404 threads.put(&mut txn, &id, &thread)?;
405 txn.commit()?;
406 Ok(())
407 })
408 }
409
410 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
411 let env = self.env.clone();
412 let threads = self.threads;
413
414 self.executor.spawn(async move {
415 let mut txn = env.write_txn()?;
416 threads.delete(&mut txn, &id)?;
417 txn.commit()?;
418 Ok(())
419 })
420 }
421}