thread_store.rs

  1use std::path::PathBuf;
  2use std::sync::Arc;
  3
  4use anyhow::{anyhow, Result};
  5use assistant_tool::{ToolId, ToolWorkingSet};
  6use chrono::{DateTime, Utc};
  7use collections::HashMap;
  8use context_server::manager::ContextServerManager;
  9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 10use futures::future::{self, BoxFuture, Shared};
 11use futures::FutureExt as _;
 12use gpui::{prelude::*, App, BackgroundExecutor, Context, Entity, SharedString, Task};
 13use heed::types::SerdeBincode;
 14use heed::Database;
 15use language_model::Role;
 16use project::Project;
 17use serde::{Deserialize, Serialize};
 18use util::ResultExt as _;
 19
 20use crate::thread::{MessageId, Thread, ThreadId};
 21
 /// Loads, saves, and deletes assistant [`Thread`]s backed by an on-disk database, and keeps
/// tools exposed by running context servers registered in the shared tool working set.
pub struct ThreadStore {
    // Not read by any method in this file: `new` only clones it into the
    // `ContextServerManager` before storing it here.
    // NOTE(review): presumably retained for future use or to keep the project alive — confirm.
    #[allow(unused)]
    project: Entity<Project>,
    /// Tool working set handed to every thread this store creates or opens, and into which
    /// context-server tools are inserted.
    tools: Arc<ToolWorkingSet>,
    /// Manager whose start/stop events drive context-server tool (un)registration.
    context_server_manager: Entity<ContextServerManager>,
    /// IDs of the tools registered for each running context server, keyed by server ID.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Metadata for all saved threads, as last loaded from (or updated against) the database.
    threads: Vec<SavedThreadMetadata>,
    /// Database opened on the background executor, shared by every operation. Errors are
    /// wrapped in `Arc` so the shared future's output can be cloned.
    database_future: Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
}
 31
 32impl ThreadStore {
 33    pub fn new(
 34        project: Entity<Project>,
 35        tools: Arc<ToolWorkingSet>,
 36        cx: &mut App,
 37    ) -> Task<Result<Entity<Self>>> {
 38        cx.spawn(|mut cx| async move {
 39            let this = cx.new(|cx: &mut Context<Self>| {
 40                let context_server_factory_registry =
 41                    ContextServerFactoryRegistry::default_global(cx);
 42                let context_server_manager = cx.new(|cx| {
 43                    ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 44                });
 45
 46                let executor = cx.background_executor().clone();
 47                let database_future = executor
 48                    .spawn({
 49                        let executor = executor.clone();
 50                        let database_path = paths::support_dir().join("threads/threads-db.0.mdb");
 51                        async move { ThreadsDatabase::new(database_path, executor) }
 52                    })
 53                    .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
 54                    .boxed()
 55                    .shared();
 56
 57                let this = Self {
 58                    project,
 59                    tools,
 60                    context_server_manager,
 61                    context_server_tool_ids: HashMap::default(),
 62                    threads: Vec::new(),
 63                    database_future,
 64                };
 65                this.register_context_server_handlers(cx);
 66
 67                this
 68            })?;
 69
 70            this.update(&mut cx, |this, cx| this.reload(cx))?.await?;
 71
 72            Ok(this)
 73        })
 74    }
 75
 76    /// Returns the number of threads.
 77    pub fn thread_count(&self) -> usize {
 78        self.threads.len()
 79    }
 80
 81    pub fn threads(&self) -> Vec<SavedThreadMetadata> {
 82        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
 83        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
 84        threads
 85    }
 86
 87    pub fn recent_threads(&self, limit: usize) -> Vec<SavedThreadMetadata> {
 88        self.threads().into_iter().take(limit).collect()
 89    }
 90
 91    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 92        cx.new(|cx| Thread::new(self.tools.clone(), cx))
 93    }
 94
 95    pub fn open_thread(
 96        &self,
 97        id: &ThreadId,
 98        cx: &mut Context<Self>,
 99    ) -> Task<Result<Entity<Thread>>> {
100        let id = id.clone();
101        let database_future = self.database_future.clone();
102        cx.spawn(|this, mut cx| async move {
103            let database = database_future.await.map_err(|err| anyhow!(err))?;
104            let thread = database
105                .try_find_thread(id.clone())
106                .await?
107                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
108
109            this.update(&mut cx, |this, cx| {
110                cx.new(|cx| Thread::from_saved(id.clone(), thread, this.tools.clone(), cx))
111            })
112        })
113    }
114
115    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
116        let (metadata, thread) = thread.update(cx, |thread, _cx| {
117            let id = thread.id().clone();
118            let thread = SavedThread {
119                summary: thread.summary_or_default(),
120                updated_at: thread.updated_at(),
121                messages: thread
122                    .messages()
123                    .map(|message| SavedMessage {
124                        id: message.id,
125                        role: message.role,
126                        text: message.text.clone(),
127                    })
128                    .collect(),
129            };
130
131            (id, thread)
132        });
133
134        let database_future = self.database_future.clone();
135        cx.spawn(|this, mut cx| async move {
136            let database = database_future.await.map_err(|err| anyhow!(err))?;
137            database.save_thread(metadata, thread).await?;
138
139            this.update(&mut cx, |this, cx| this.reload(cx))?.await
140        })
141    }
142
143    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
144        let id = id.clone();
145        let database_future = self.database_future.clone();
146        cx.spawn(|this, mut cx| async move {
147            let database = database_future.await.map_err(|err| anyhow!(err))?;
148            database.delete_thread(id.clone()).await?;
149
150            this.update(&mut cx, |this, _cx| {
151                this.threads.retain(|thread| thread.id != id)
152            })
153        })
154    }
155
156    fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
157        let database_future = self.database_future.clone();
158        cx.spawn(|this, mut cx| async move {
159            let threads = database_future
160                .await
161                .map_err(|err| anyhow!(err))?
162                .list_threads()
163                .await?;
164
165            this.update(&mut cx, |this, cx| {
166                this.threads = threads;
167                cx.notify();
168            })
169        })
170    }
171
172    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
173        cx.subscribe(
174            &self.context_server_manager.clone(),
175            Self::handle_context_server_event,
176        )
177        .detach();
178    }
179
180    fn handle_context_server_event(
181        &mut self,
182        context_server_manager: Entity<ContextServerManager>,
183        event: &context_server::manager::Event,
184        cx: &mut Context<Self>,
185    ) {
186        let tool_working_set = self.tools.clone();
187        match event {
188            context_server::manager::Event::ServerStarted { server_id } => {
189                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
190                    let context_server_manager = context_server_manager.clone();
191                    cx.spawn({
192                        let server = server.clone();
193                        let server_id = server_id.clone();
194                        |this, mut cx| async move {
195                            let Some(protocol) = server.client() else {
196                                return;
197                            };
198
199                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
200                                if let Some(tools) = protocol.list_tools().await.log_err() {
201                                    let tool_ids = tools
202                                        .tools
203                                        .into_iter()
204                                        .map(|tool| {
205                                            log::info!(
206                                                "registering context server tool: {:?}",
207                                                tool.name
208                                            );
209                                            tool_working_set.insert(Arc::new(
210                                                ContextServerTool::new(
211                                                    context_server_manager.clone(),
212                                                    server.id(),
213                                                    tool,
214                                                ),
215                                            ))
216                                        })
217                                        .collect::<Vec<_>>();
218
219                                    this.update(&mut cx, |this, _cx| {
220                                        this.context_server_tool_ids.insert(server_id, tool_ids);
221                                    })
222                                    .log_err();
223                                }
224                            }
225                        }
226                    })
227                    .detach();
228                }
229            }
230            context_server::manager::Event::ServerStopped { server_id } => {
231                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
232                    tool_working_set.remove(&tool_ids);
233                }
234            }
235        }
236    }
237}
238
/// Listing entry for a saved thread: its ID, summary, and last-update time, without messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SavedThreadMetadata {
    /// Key under which the thread is stored.
    pub id: ThreadId,
    /// Summary recorded when the thread was last saved.
    pub summary: SharedString,
    /// Last-update timestamp; `ThreadStore::threads` sorts by it, newest first.
    pub updated_at: DateTime<Utc>,
}
245
/// Persisted form of a thread, stored bincode-encoded in the threads database.
///
/// Bincode encodes fields positionally, so field order is part of the on-disk format;
/// reordering or inserting fields makes previously saved threads undecodable.
#[derive(Serialize, Deserialize)]
pub struct SavedThread {
    /// Thread summary at save time.
    pub summary: SharedString,
    /// When the thread was last updated.
    pub updated_at: DateTime<Utc>,
    /// The thread's messages, in the order the thread yielded them when saved.
    pub messages: Vec<SavedMessage>,
}
252
/// Persisted form of one message within a [`SavedThread`].
///
/// Stored bincode-encoded as part of its thread, so field order is part of the on-disk format.
#[derive(Serialize, Deserialize)]
pub struct SavedMessage {
    /// The message's ID within its thread.
    pub id: MessageId,
    /// Who authored the message.
    pub role: Role,
    /// The message's text.
    pub text: String,
}
259
/// Thread storage built on a `heed` (LMDB) environment, keyed by [`ThreadId`].
struct ThreadsDatabase {
    /// Executor on which every read and write transaction is run.
    executor: BackgroundExecutor,
    /// The opened environment; cloned into each spawned operation.
    env: heed::Env,
    /// Maps thread IDs to bincode-encoded saved threads.
    threads: Database<SerdeBincode<ThreadId>, SerdeBincode<SavedThread>>,
}
265
266impl ThreadsDatabase {
267    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
268        std::fs::create_dir_all(&path)?;
269
270        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
271        let env = unsafe {
272            heed::EnvOpenOptions::new()
273                .map_size(ONE_GB_IN_BYTES)
274                .max_dbs(1)
275                .open(path)?
276        };
277
278        let mut txn = env.write_txn()?;
279        let threads = env.create_database(&mut txn, Some("threads"))?;
280        txn.commit()?;
281
282        Ok(Self {
283            executor,
284            env,
285            threads,
286        })
287    }
288
289    pub fn list_threads(&self) -> Task<Result<Vec<SavedThreadMetadata>>> {
290        let env = self.env.clone();
291        let threads = self.threads;
292
293        self.executor.spawn(async move {
294            let txn = env.read_txn()?;
295            let mut iter = threads.iter(&txn)?;
296            let mut threads = Vec::new();
297            while let Some((key, value)) = iter.next().transpose()? {
298                threads.push(SavedThreadMetadata {
299                    id: key,
300                    summary: value.summary,
301                    updated_at: value.updated_at,
302                });
303            }
304
305            Ok(threads)
306        })
307    }
308
309    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SavedThread>>> {
310        let env = self.env.clone();
311        let threads = self.threads;
312
313        self.executor.spawn(async move {
314            let txn = env.read_txn()?;
315            let thread = threads.get(&txn, &id)?;
316            Ok(thread)
317        })
318    }
319
320    pub fn save_thread(&self, id: ThreadId, thread: SavedThread) -> Task<Result<()>> {
321        let env = self.env.clone();
322        let threads = self.threads;
323
324        self.executor.spawn(async move {
325            let mut txn = env.write_txn()?;
326            threads.put(&mut txn, &id, &thread)?;
327            txn.commit()?;
328            Ok(())
329        })
330    }
331
332    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
333        let env = self.env.clone();
334        let threads = self.threads;
335
336        self.executor.spawn(async move {
337            let mut txn = env.write_txn()?;
338            threads.delete(&mut txn, &id)?;
339            txn.commit()?;
340            Ok(())
341        })
342    }
343}