thread_store.rs

  1use std::path::PathBuf;
  2use std::sync::Arc;
  3
  4use anyhow::{anyhow, Result};
  5use assistant_tool::{ToolId, ToolWorkingSet};
  6use chrono::{DateTime, Utc};
  7use collections::HashMap;
  8use context_server::manager::ContextServerManager;
  9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 10use futures::future::{self, BoxFuture, Shared};
 11use futures::FutureExt as _;
 12use gpui::{prelude::*, App, BackgroundExecutor, Context, Entity, SharedString, Task};
 13use heed::types::SerdeBincode;
 14use heed::Database;
 15use language_model::Role;
 16use project::Project;
 17use serde::{Deserialize, Serialize};
 18use util::ResultExt as _;
 19
 20use crate::thread::{MessageId, Thread, ThreadId};
 21
/// Tracks saved assistant threads and the tools contributed by running context servers.
pub struct ThreadStore {
    // NOTE(review): `project` is never read through `self` in this file; the allow presumably
    // keeps the entity alive for future use — confirm before removing.
    #[allow(unused)]
    project: Entity<Project>,
    /// Tool set handed to threads created by this store; context-server tools are inserted here.
    tools: Arc<ToolWorkingSet>,
    /// Manager whose server start/stop events drive context-server tool registration.
    context_server_manager: Entity<ContextServerManager>,
    /// Tool IDs registered for each context server, so they can be removed when it stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Cached metadata for saved threads (unsorted); refreshed by `reload`, pruned by `delete_thread`.
    threads: Vec<SavedThreadMetadata>,
    /// Shared handle to the opened threads database. The error is wrapped in `Arc` so the
    /// shared output can be cloned to every awaiter.
    database_future: Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
}
 31
 32impl ThreadStore {
 33    pub fn new(
 34        project: Entity<Project>,
 35        tools: Arc<ToolWorkingSet>,
 36        cx: &mut App,
 37    ) -> Task<Result<Entity<Self>>> {
 38        cx.spawn(|mut cx| async move {
 39            let this = cx.new(|cx: &mut Context<Self>| {
 40                let context_server_factory_registry =
 41                    ContextServerFactoryRegistry::default_global(cx);
 42                let context_server_manager = cx.new(|cx| {
 43                    ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 44                });
 45
 46                let executor = cx.background_executor().clone();
 47                let database_future = executor
 48                    .spawn({
 49                        let executor = executor.clone();
 50                        let database_path = paths::support_dir().join("threads/threads-db.0.mdb");
 51                        async move { ThreadsDatabase::new(database_path, executor) }
 52                    })
 53                    .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
 54                    .boxed()
 55                    .shared();
 56
 57                let this = Self {
 58                    project,
 59                    tools,
 60                    context_server_manager,
 61                    context_server_tool_ids: HashMap::default(),
 62                    threads: Vec::new(),
 63                    database_future,
 64                };
 65                this.register_context_server_handlers(cx);
 66
 67                this
 68            })?;
 69
 70            log::info!("[assistant2-debug] reloading threads");
 71            this.update(&mut cx, |this, cx| this.reload(cx))?.await?;
 72            log::info!("[assistant2-debug] finished reloading threads");
 73
 74            Ok(this)
 75        })
 76    }
 77
 78    /// Returns the number of threads.
 79    pub fn thread_count(&self) -> usize {
 80        self.threads.len()
 81    }
 82
 83    pub fn threads(&self) -> Vec<SavedThreadMetadata> {
 84        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
 85        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
 86        threads
 87    }
 88
 89    pub fn recent_threads(&self, limit: usize) -> Vec<SavedThreadMetadata> {
 90        self.threads().into_iter().take(limit).collect()
 91    }
 92
 93    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 94        cx.new(|cx| Thread::new(self.tools.clone(), cx))
 95    }
 96
 97    pub fn open_thread(
 98        &self,
 99        id: &ThreadId,
100        cx: &mut Context<Self>,
101    ) -> Task<Result<Entity<Thread>>> {
102        let id = id.clone();
103        let database_future = self.database_future.clone();
104        cx.spawn(|this, mut cx| async move {
105            let database = database_future.await.map_err(|err| anyhow!(err))?;
106            let thread = database
107                .try_find_thread(id.clone())
108                .await?
109                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
110
111            this.update(&mut cx, |this, cx| {
112                cx.new(|cx| Thread::from_saved(id.clone(), thread, this.tools.clone(), cx))
113            })
114        })
115    }
116
117    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
118        let (metadata, thread) = thread.update(cx, |thread, _cx| {
119            let id = thread.id().clone();
120            let thread = SavedThread {
121                summary: thread.summary_or_default(),
122                updated_at: thread.updated_at(),
123                messages: thread
124                    .messages()
125                    .map(|message| SavedMessage {
126                        id: message.id,
127                        role: message.role,
128                        text: message.text.clone(),
129                    })
130                    .collect(),
131            };
132
133            (id, thread)
134        });
135
136        let database_future = self.database_future.clone();
137        cx.spawn(|this, mut cx| async move {
138            let database = database_future.await.map_err(|err| anyhow!(err))?;
139            database.save_thread(metadata, thread).await?;
140
141            this.update(&mut cx, |this, cx| this.reload(cx))?.await
142        })
143    }
144
145    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
146        let id = id.clone();
147        let database_future = self.database_future.clone();
148        cx.spawn(|this, mut cx| async move {
149            let database = database_future.await.map_err(|err| anyhow!(err))?;
150            database.delete_thread(id.clone()).await?;
151
152            this.update(&mut cx, |this, _cx| {
153                this.threads.retain(|thread| thread.id != id)
154            })
155        })
156    }
157
158    fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
159        let database_future = self.database_future.clone();
160        cx.spawn(|this, mut cx| async move {
161            let threads = database_future
162                .await
163                .map_err(|err| anyhow!(err))?
164                .list_threads()
165                .await?;
166
167            this.update(&mut cx, |this, cx| {
168                this.threads = threads;
169                cx.notify();
170            })
171        })
172    }
173
174    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
175        cx.subscribe(
176            &self.context_server_manager.clone(),
177            Self::handle_context_server_event,
178        )
179        .detach();
180    }
181
182    fn handle_context_server_event(
183        &mut self,
184        context_server_manager: Entity<ContextServerManager>,
185        event: &context_server::manager::Event,
186        cx: &mut Context<Self>,
187    ) {
188        let tool_working_set = self.tools.clone();
189        match event {
190            context_server::manager::Event::ServerStarted { server_id } => {
191                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
192                    let context_server_manager = context_server_manager.clone();
193                    cx.spawn({
194                        let server = server.clone();
195                        let server_id = server_id.clone();
196                        |this, mut cx| async move {
197                            let Some(protocol) = server.client() else {
198                                return;
199                            };
200
201                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
202                                if let Some(tools) = protocol.list_tools().await.log_err() {
203                                    let tool_ids = tools
204                                        .tools
205                                        .into_iter()
206                                        .map(|tool| {
207                                            log::info!(
208                                                "registering context server tool: {:?}",
209                                                tool.name
210                                            );
211                                            tool_working_set.insert(Arc::new(
212                                                ContextServerTool::new(
213                                                    context_server_manager.clone(),
214                                                    server.id(),
215                                                    tool,
216                                                ),
217                                            ))
218                                        })
219                                        .collect::<Vec<_>>();
220
221                                    this.update(&mut cx, |this, _cx| {
222                                        this.context_server_tool_ids.insert(server_id, tool_ids);
223                                    })
224                                    .log_err();
225                                }
226                            }
227                        }
228                    })
229                    .detach();
230                }
231            }
232            context_server::manager::Event::ServerStopped { server_id } => {
233                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
234                    tool_working_set.remove(&tool_ids);
235                }
236            }
237        }
238    }
239}
240
/// Lightweight description of a saved thread, as cached and listed by `ThreadStore`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SavedThreadMetadata {
    /// Key of the thread in the threads database.
    pub id: ThreadId,
    /// Thread summary copied from the stored `SavedThread`.
    pub summary: SharedString,
    /// Last-update timestamp; `ThreadStore::threads` sorts by this, newest first.
    pub updated_at: DateTime<Utc>,
}
247
/// A thread as persisted in the threads database.
///
/// NOTE(review): values are stored via `SerdeBincode`, which encodes fields positionally;
/// reordering, inserting, or removing fields will likely make existing records unreadable.
#[derive(Serialize, Deserialize)]
pub struct SavedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SavedMessage>,
}
254
/// A single message inside a `SavedThread`.
///
/// Serialized as part of `SavedThread` with bincode, so field order is part of the on-disk format.
#[derive(Serialize, Deserialize)]
pub struct SavedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
}
261
/// Handle to the heed (LMDB) environment that stores saved threads.
struct ThreadsDatabase {
    /// Executor on which every read and write transaction is run.
    executor: BackgroundExecutor,
    env: heed::Env,
    /// Saved threads keyed by ID; keys and values are both bincode-encoded.
    threads: Database<SerdeBincode<ThreadId>, SerdeBincode<SavedThread>>,
}
267
268impl ThreadsDatabase {
269    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
270        std::fs::create_dir_all(&path)?;
271
272        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
273        let env = unsafe {
274            heed::EnvOpenOptions::new()
275                .map_size(ONE_GB_IN_BYTES)
276                .max_dbs(1)
277                .open(path)?
278        };
279
280        let mut txn = env.write_txn()?;
281        let threads = env.create_database(&mut txn, Some("threads"))?;
282        txn.commit()?;
283
284        Ok(Self {
285            executor,
286            env,
287            threads,
288        })
289    }
290
291    pub fn list_threads(&self) -> Task<Result<Vec<SavedThreadMetadata>>> {
292        let env = self.env.clone();
293        let threads = self.threads;
294
295        self.executor.spawn(async move {
296            let txn = env.read_txn()?;
297            let mut iter = threads.iter(&txn)?;
298            let mut threads = Vec::new();
299            while let Some((key, value)) = iter.next().transpose()? {
300                threads.push(SavedThreadMetadata {
301                    id: key,
302                    summary: value.summary,
303                    updated_at: value.updated_at,
304                });
305            }
306
307            Ok(threads)
308        })
309    }
310
311    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SavedThread>>> {
312        let env = self.env.clone();
313        let threads = self.threads;
314
315        self.executor.spawn(async move {
316            let txn = env.read_txn()?;
317            let thread = threads.get(&txn, &id)?;
318            Ok(thread)
319        })
320    }
321
322    pub fn save_thread(&self, id: ThreadId, thread: SavedThread) -> Task<Result<()>> {
323        let env = self.env.clone();
324        let threads = self.threads;
325
326        self.executor.spawn(async move {
327            let mut txn = env.write_txn()?;
328            threads.put(&mut txn, &id, &thread)?;
329            txn.commit()?;
330            Ok(())
331        })
332    }
333
334    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
335        let env = self.env.clone();
336        let threads = self.threads;
337
338        self.executor.spawn(async move {
339            let mut txn = env.write_txn()?;
340            threads.delete(&mut txn, &id)?;
341            txn.commit()?;
342            Ok(())
343        })
344    }
345}