thread_store.rs

  1use std::path::PathBuf;
  2use std::sync::Arc;
  3
  4use anyhow::{anyhow, Result};
  5use assistant_tool::{ToolId, ToolWorkingSet};
  6use chrono::{DateTime, Utc};
  7use collections::HashMap;
  8use context_server::manager::ContextServerManager;
  9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 10use futures::future::{self, BoxFuture, Shared};
 11use futures::FutureExt as _;
 12use gpui::{
 13    prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
 14};
 15use heed::types::{SerdeBincode, SerdeJson};
 16use heed::Database;
 17use language_model::Role;
 18use project::Project;
 19use serde::{Deserialize, Serialize};
 20use util::ResultExt as _;
 21
 22use crate::thread::{MessageId, Thread, ThreadId};
 23
 24pub fn init(cx: &mut App) {
 25    ThreadsDatabase::init(cx);
 26}
 27
 28pub struct ThreadStore {
 29    #[allow(unused)]
 30    project: Entity<Project>,
 31    tools: Arc<ToolWorkingSet>,
 32    context_server_manager: Entity<ContextServerManager>,
 33    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
 34    threads: Vec<SavedThreadMetadata>,
 35}
 36
 37impl ThreadStore {
 38    pub fn new(
 39        project: Entity<Project>,
 40        tools: Arc<ToolWorkingSet>,
 41        cx: &mut App,
 42    ) -> Result<Entity<Self>> {
 43        let this = cx.new(|cx| {
 44            let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
 45            let context_server_manager = cx.new(|cx| {
 46                ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 47            });
 48
 49            let this = Self {
 50                project,
 51                tools,
 52                context_server_manager,
 53                context_server_tool_ids: HashMap::default(),
 54                threads: Vec::new(),
 55            };
 56            this.register_context_server_handlers(cx);
 57            this.reload(cx).detach_and_log_err(cx);
 58
 59            this
 60        });
 61
 62        Ok(this)
 63    }
 64
 65    /// Returns the number of threads.
 66    pub fn thread_count(&self) -> usize {
 67        self.threads.len()
 68    }
 69
 70    pub fn threads(&self) -> Vec<SavedThreadMetadata> {
 71        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
 72        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
 73        threads
 74    }
 75
 76    pub fn recent_threads(&self, limit: usize) -> Vec<SavedThreadMetadata> {
 77        self.threads().into_iter().take(limit).collect()
 78    }
 79
 80    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 81        cx.new(|cx| Thread::new(self.tools.clone(), cx))
 82    }
 83
 84    pub fn open_thread(
 85        &self,
 86        id: &ThreadId,
 87        cx: &mut Context<Self>,
 88    ) -> Task<Result<Entity<Thread>>> {
 89        let id = id.clone();
 90        let database_future = ThreadsDatabase::global_future(cx);
 91        cx.spawn(|this, mut cx| async move {
 92            let database = database_future.await.map_err(|err| anyhow!(err))?;
 93            let thread = database
 94                .try_find_thread(id.clone())
 95                .await?
 96                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
 97
 98            this.update(&mut cx, |this, cx| {
 99                cx.new(|cx| Thread::from_saved(id.clone(), thread, this.tools.clone(), cx))
100            })
101        })
102    }
103
104    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
105        let (metadata, thread) = thread.update(cx, |thread, _cx| {
106            let id = thread.id().clone();
107            let thread = SavedThread {
108                summary: thread.summary_or_default(),
109                updated_at: thread.updated_at(),
110                messages: thread
111                    .messages()
112                    .map(|message| SavedMessage {
113                        id: message.id,
114                        role: message.role,
115                        text: message.text.clone(),
116                    })
117                    .collect(),
118            };
119
120            (id, thread)
121        });
122
123        let database_future = ThreadsDatabase::global_future(cx);
124        cx.spawn(|this, mut cx| async move {
125            let database = database_future.await.map_err(|err| anyhow!(err))?;
126            database.save_thread(metadata, thread).await?;
127
128            this.update(&mut cx, |this, cx| this.reload(cx))?.await
129        })
130    }
131
132    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
133        let id = id.clone();
134        let database_future = ThreadsDatabase::global_future(cx);
135        cx.spawn(|this, mut cx| async move {
136            let database = database_future.await.map_err(|err| anyhow!(err))?;
137            database.delete_thread(id.clone()).await?;
138
139            this.update(&mut cx, |this, _cx| {
140                this.threads.retain(|thread| thread.id != id)
141            })
142        })
143    }
144
145    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
146        let database_future = ThreadsDatabase::global_future(cx);
147        cx.spawn(|this, mut cx| async move {
148            let threads = database_future
149                .await
150                .map_err(|err| anyhow!(err))?
151                .list_threads()
152                .await?;
153
154            this.update(&mut cx, |this, cx| {
155                this.threads = threads;
156                cx.notify();
157            })
158        })
159    }
160
161    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
162        cx.subscribe(
163            &self.context_server_manager.clone(),
164            Self::handle_context_server_event,
165        )
166        .detach();
167    }
168
169    fn handle_context_server_event(
170        &mut self,
171        context_server_manager: Entity<ContextServerManager>,
172        event: &context_server::manager::Event,
173        cx: &mut Context<Self>,
174    ) {
175        let tool_working_set = self.tools.clone();
176        match event {
177            context_server::manager::Event::ServerStarted { server_id } => {
178                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
179                    let context_server_manager = context_server_manager.clone();
180                    cx.spawn({
181                        let server = server.clone();
182                        let server_id = server_id.clone();
183                        |this, mut cx| async move {
184                            let Some(protocol) = server.client() else {
185                                return;
186                            };
187
188                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
189                                if let Some(tools) = protocol.list_tools().await.log_err() {
190                                    let tool_ids = tools
191                                        .tools
192                                        .into_iter()
193                                        .map(|tool| {
194                                            log::info!(
195                                                "registering context server tool: {:?}",
196                                                tool.name
197                                            );
198                                            tool_working_set.insert(Arc::new(
199                                                ContextServerTool::new(
200                                                    context_server_manager.clone(),
201                                                    server.id(),
202                                                    tool,
203                                                ),
204                                            ))
205                                        })
206                                        .collect::<Vec<_>>();
207
208                                    this.update(&mut cx, |this, _cx| {
209                                        this.context_server_tool_ids.insert(server_id, tool_ids);
210                                    })
211                                    .log_err();
212                                }
213                            }
214                        }
215                    })
216                    .detach();
217                }
218            }
219            context_server::manager::Event::ServerStopped { server_id } => {
220                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
221                    tool_working_set.remove(&tool_ids);
222                }
223            }
224        }
225    }
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
229pub struct SavedThreadMetadata {
230    pub id: ThreadId,
231    pub summary: SharedString,
232    pub updated_at: DateTime<Utc>,
233}
234
235#[derive(Serialize, Deserialize)]
236pub struct SavedThread {
237    pub summary: SharedString,
238    pub updated_at: DateTime<Utc>,
239    pub messages: Vec<SavedMessage>,
240}
241
242#[derive(Serialize, Deserialize)]
243pub struct SavedMessage {
244    pub id: MessageId,
245    pub role: Role,
246    pub text: String,
247}
248
249struct GlobalThreadsDatabase(
250    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
251);
252
253impl Global for GlobalThreadsDatabase {}
254
255pub(crate) struct ThreadsDatabase {
256    executor: BackgroundExecutor,
257    env: heed::Env,
258    threads: Database<SerdeBincode<ThreadId>, SerdeJson<SavedThread>>,
259}
260
261impl ThreadsDatabase {
262    fn global_future(
263        cx: &mut App,
264    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
265        GlobalThreadsDatabase::global(cx).0.clone()
266    }
267
268    fn init(cx: &mut App) {
269        let executor = cx.background_executor().clone();
270        let database_future = executor
271            .spawn({
272                let executor = executor.clone();
273                let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
274                async move { ThreadsDatabase::new(database_path, executor) }
275            })
276            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
277            .boxed()
278            .shared();
279
280        cx.set_global(GlobalThreadsDatabase(database_future));
281    }
282
283    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
284        std::fs::create_dir_all(&path)?;
285
286        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
287        let env = unsafe {
288            heed::EnvOpenOptions::new()
289                .map_size(ONE_GB_IN_BYTES)
290                .max_dbs(1)
291                .open(path)?
292        };
293
294        let mut txn = env.write_txn()?;
295        let threads = env.create_database(&mut txn, Some("threads"))?;
296        txn.commit()?;
297
298        Ok(Self {
299            executor,
300            env,
301            threads,
302        })
303    }
304
305    pub fn list_threads(&self) -> Task<Result<Vec<SavedThreadMetadata>>> {
306        let env = self.env.clone();
307        let threads = self.threads;
308
309        self.executor.spawn(async move {
310            let txn = env.read_txn()?;
311            let mut iter = threads.iter(&txn)?;
312            let mut threads = Vec::new();
313            while let Some((key, value)) = iter.next().transpose()? {
314                threads.push(SavedThreadMetadata {
315                    id: key,
316                    summary: value.summary,
317                    updated_at: value.updated_at,
318                });
319            }
320
321            Ok(threads)
322        })
323    }
324
325    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SavedThread>>> {
326        let env = self.env.clone();
327        let threads = self.threads;
328
329        self.executor.spawn(async move {
330            let txn = env.read_txn()?;
331            let thread = threads.get(&txn, &id)?;
332            Ok(thread)
333        })
334    }
335
336    pub fn save_thread(&self, id: ThreadId, thread: SavedThread) -> Task<Result<()>> {
337        let env = self.env.clone();
338        let threads = self.threads;
339
340        self.executor.spawn(async move {
341            let mut txn = env.write_txn()?;
342            threads.put(&mut txn, &id, &thread)?;
343            txn.commit()?;
344            Ok(())
345        })
346    }
347
348    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
349        let env = self.env.clone();
350        let threads = self.threads;
351
352        self.executor.spawn(async move {
353            let mut txn = env.write_txn()?;
354            threads.delete(&mut txn, &id)?;
355            txn.commit()?;
356            Ok(())
357        })
358    }
359}