1use std::path::PathBuf;
2use std::sync::Arc;
3
4use anyhow::{anyhow, Result};
5use assistant_tool::{ToolId, ToolWorkingSet};
6use chrono::{DateTime, Utc};
7use collections::HashMap;
8use context_server::manager::ContextServerManager;
9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
10use futures::future::{self, BoxFuture, Shared};
11use futures::FutureExt as _;
12use gpui::{
13 prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
14};
15use heed::types::{SerdeBincode, SerdeJson};
16use heed::Database;
17use language_model::{LanguageModelToolUseId, Role};
18use project::Project;
19use serde::{Deserialize, Serialize};
20use util::ResultExt as _;
21
22use crate::thread::{MessageId, Thread, ThreadId};
23
24pub fn init(cx: &mut App) {
25 ThreadsDatabase::init(cx);
26}
27
28pub struct ThreadStore {
29 project: Entity<Project>,
30 tools: Arc<ToolWorkingSet>,
31 context_server_manager: Entity<ContextServerManager>,
32 context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
33 threads: Vec<SavedThreadMetadata>,
34}
35
36impl ThreadStore {
37 pub fn new(
38 project: Entity<Project>,
39 tools: Arc<ToolWorkingSet>,
40 cx: &mut App,
41 ) -> Result<Entity<Self>> {
42 let this = cx.new(|cx| {
43 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
44 let context_server_manager = cx.new(|cx| {
45 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
46 });
47
48 let this = Self {
49 project,
50 tools,
51 context_server_manager,
52 context_server_tool_ids: HashMap::default(),
53 threads: Vec::new(),
54 };
55 this.register_context_server_handlers(cx);
56 this.reload(cx).detach_and_log_err(cx);
57
58 this
59 });
60
61 Ok(this)
62 }
63
64 /// Returns the number of threads.
65 pub fn thread_count(&self) -> usize {
66 self.threads.len()
67 }
68
69 pub fn threads(&self) -> Vec<SavedThreadMetadata> {
70 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
71 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
72 threads
73 }
74
75 pub fn recent_threads(&self, limit: usize) -> Vec<SavedThreadMetadata> {
76 self.threads().into_iter().take(limit).collect()
77 }
78
79 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
80 cx.new(|cx| Thread::new(self.project.clone(), self.tools.clone(), cx))
81 }
82
83 pub fn open_thread(
84 &self,
85 id: &ThreadId,
86 cx: &mut Context<Self>,
87 ) -> Task<Result<Entity<Thread>>> {
88 let id = id.clone();
89 let database_future = ThreadsDatabase::global_future(cx);
90 cx.spawn(|this, mut cx| async move {
91 let database = database_future.await.map_err(|err| anyhow!(err))?;
92 let thread = database
93 .try_find_thread(id.clone())
94 .await?
95 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
96
97 this.update(&mut cx, |this, cx| {
98 cx.new(|cx| {
99 Thread::from_saved(
100 id.clone(),
101 thread,
102 this.project.clone(),
103 this.tools.clone(),
104 cx,
105 )
106 })
107 })
108 })
109 }
110
111 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
112 let (metadata, thread) = thread.update(cx, |thread, _cx| {
113 let id = thread.id().clone();
114 let thread = SavedThread {
115 summary: thread.summary_or_default(),
116 updated_at: thread.updated_at(),
117 messages: thread
118 .messages()
119 .map(|message| {
120 let all_tool_uses = thread
121 .tool_uses_for_message(message.id)
122 .into_iter()
123 .chain(thread.scripting_tool_uses_for_message(message.id))
124 .map(|tool_use| SavedToolUse {
125 id: tool_use.id,
126 name: tool_use.name,
127 input: tool_use.input,
128 })
129 .collect();
130 let all_tool_results = thread
131 .tool_results_for_message(message.id)
132 .into_iter()
133 .chain(thread.scripting_tool_results_for_message(message.id))
134 .map(|tool_result| SavedToolResult {
135 tool_use_id: tool_result.tool_use_id.clone(),
136 is_error: tool_result.is_error,
137 content: tool_result.content.clone(),
138 })
139 .collect();
140
141 SavedMessage {
142 id: message.id,
143 role: message.role,
144 text: message.text.clone(),
145 tool_uses: all_tool_uses,
146 tool_results: all_tool_results,
147 }
148 })
149 .collect(),
150 };
151
152 (id, thread)
153 });
154
155 let database_future = ThreadsDatabase::global_future(cx);
156 cx.spawn(|this, mut cx| async move {
157 let database = database_future.await.map_err(|err| anyhow!(err))?;
158 database.save_thread(metadata, thread).await?;
159
160 this.update(&mut cx, |this, cx| this.reload(cx))?.await
161 })
162 }
163
164 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
165 let id = id.clone();
166 let database_future = ThreadsDatabase::global_future(cx);
167 cx.spawn(|this, mut cx| async move {
168 let database = database_future.await.map_err(|err| anyhow!(err))?;
169 database.delete_thread(id.clone()).await?;
170
171 this.update(&mut cx, |this, _cx| {
172 this.threads.retain(|thread| thread.id != id)
173 })
174 })
175 }
176
177 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
178 let database_future = ThreadsDatabase::global_future(cx);
179 cx.spawn(|this, mut cx| async move {
180 let threads = database_future
181 .await
182 .map_err(|err| anyhow!(err))?
183 .list_threads()
184 .await?;
185
186 this.update(&mut cx, |this, cx| {
187 this.threads = threads;
188 cx.notify();
189 })
190 })
191 }
192
193 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
194 cx.subscribe(
195 &self.context_server_manager.clone(),
196 Self::handle_context_server_event,
197 )
198 .detach();
199 }
200
201 fn handle_context_server_event(
202 &mut self,
203 context_server_manager: Entity<ContextServerManager>,
204 event: &context_server::manager::Event,
205 cx: &mut Context<Self>,
206 ) {
207 let tool_working_set = self.tools.clone();
208 match event {
209 context_server::manager::Event::ServerStarted { server_id } => {
210 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
211 let context_server_manager = context_server_manager.clone();
212 cx.spawn({
213 let server = server.clone();
214 let server_id = server_id.clone();
215 |this, mut cx| async move {
216 let Some(protocol) = server.client() else {
217 return;
218 };
219
220 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
221 if let Some(tools) = protocol.list_tools().await.log_err() {
222 let tool_ids = tools
223 .tools
224 .into_iter()
225 .map(|tool| {
226 log::info!(
227 "registering context server tool: {:?}",
228 tool.name
229 );
230 tool_working_set.insert(Arc::new(
231 ContextServerTool::new(
232 context_server_manager.clone(),
233 server.id(),
234 tool,
235 ),
236 ))
237 })
238 .collect::<Vec<_>>();
239
240 this.update(&mut cx, |this, _cx| {
241 this.context_server_tool_ids.insert(server_id, tool_ids);
242 })
243 .log_err();
244 }
245 }
246 }
247 })
248 .detach();
249 }
250 }
251 context_server::manager::Event::ServerStopped { server_id } => {
252 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
253 tool_working_set.remove(&tool_ids);
254 }
255 }
256 }
257 }
258}
259
260#[derive(Debug, Clone, Serialize, Deserialize)]
261pub struct SavedThreadMetadata {
262 pub id: ThreadId,
263 pub summary: SharedString,
264 pub updated_at: DateTime<Utc>,
265}
266
267#[derive(Serialize, Deserialize)]
268pub struct SavedThread {
269 pub summary: SharedString,
270 pub updated_at: DateTime<Utc>,
271 pub messages: Vec<SavedMessage>,
272}
273
274#[derive(Debug, Serialize, Deserialize)]
275pub struct SavedMessage {
276 pub id: MessageId,
277 pub role: Role,
278 pub text: String,
279 #[serde(default)]
280 pub tool_uses: Vec<SavedToolUse>,
281 #[serde(default)]
282 pub tool_results: Vec<SavedToolResult>,
283}
284
285#[derive(Debug, Serialize, Deserialize)]
286pub struct SavedToolUse {
287 pub id: LanguageModelToolUseId,
288 pub name: SharedString,
289 pub input: serde_json::Value,
290}
291
292#[derive(Debug, Serialize, Deserialize)]
293pub struct SavedToolResult {
294 pub tool_use_id: LanguageModelToolUseId,
295 pub is_error: bool,
296 pub content: Arc<str>,
297}
298
299struct GlobalThreadsDatabase(
300 Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
301);
302
303impl Global for GlobalThreadsDatabase {}
304
305pub(crate) struct ThreadsDatabase {
306 executor: BackgroundExecutor,
307 env: heed::Env,
308 threads: Database<SerdeBincode<ThreadId>, SerdeJson<SavedThread>>,
309}
310
311impl ThreadsDatabase {
312 fn global_future(
313 cx: &mut App,
314 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
315 GlobalThreadsDatabase::global(cx).0.clone()
316 }
317
318 fn init(cx: &mut App) {
319 let executor = cx.background_executor().clone();
320 let database_future = executor
321 .spawn({
322 let executor = executor.clone();
323 let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
324 async move { ThreadsDatabase::new(database_path, executor) }
325 })
326 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
327 .boxed()
328 .shared();
329
330 cx.set_global(GlobalThreadsDatabase(database_future));
331 }
332
333 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
334 std::fs::create_dir_all(&path)?;
335
336 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
337 let env = unsafe {
338 heed::EnvOpenOptions::new()
339 .map_size(ONE_GB_IN_BYTES)
340 .max_dbs(1)
341 .open(path)?
342 };
343
344 let mut txn = env.write_txn()?;
345 let threads = env.create_database(&mut txn, Some("threads"))?;
346 txn.commit()?;
347
348 Ok(Self {
349 executor,
350 env,
351 threads,
352 })
353 }
354
355 pub fn list_threads(&self) -> Task<Result<Vec<SavedThreadMetadata>>> {
356 let env = self.env.clone();
357 let threads = self.threads;
358
359 self.executor.spawn(async move {
360 let txn = env.read_txn()?;
361 let mut iter = threads.iter(&txn)?;
362 let mut threads = Vec::new();
363 while let Some((key, value)) = iter.next().transpose()? {
364 threads.push(SavedThreadMetadata {
365 id: key,
366 summary: value.summary,
367 updated_at: value.updated_at,
368 });
369 }
370
371 Ok(threads)
372 })
373 }
374
375 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SavedThread>>> {
376 let env = self.env.clone();
377 let threads = self.threads;
378
379 self.executor.spawn(async move {
380 let txn = env.read_txn()?;
381 let thread = threads.get(&txn, &id)?;
382 Ok(thread)
383 })
384 }
385
386 pub fn save_thread(&self, id: ThreadId, thread: SavedThread) -> Task<Result<()>> {
387 let env = self.env.clone();
388 let threads = self.threads;
389
390 self.executor.spawn(async move {
391 let mut txn = env.write_txn()?;
392 threads.put(&mut txn, &id, &thread)?;
393 txn.commit()?;
394 Ok(())
395 })
396 }
397
398 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
399 let env = self.env.clone();
400 let threads = self.threads;
401
402 self.executor.spawn(async move {
403 let mut txn = env.write_txn()?;
404 threads.delete(&mut txn, &id)?;
405 txn.commit()?;
406 Ok(())
407 })
408 }
409}