// thread_store.rs

  1use std::path::PathBuf;
  2use std::sync::Arc;
  3
  4use anyhow::{anyhow, Result};
  5use assistant_tool::{ToolId, ToolWorkingSet};
  6use chrono::{DateTime, Utc};
  7use collections::HashMap;
  8use context_server::manager::ContextServerManager;
  9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 10use futures::future::{self, BoxFuture, Shared};
 11use futures::FutureExt as _;
 12use gpui::{prelude::*, AppContext, BackgroundExecutor, Model, ModelContext, SharedString, Task};
 13use heed::types::SerdeBincode;
 14use heed::Database;
 15use language_model::Role;
 16use project::Project;
 17use serde::{Deserialize, Serialize};
 18use util::ResultExt as _;
 19
 20use crate::thread::{MessageId, Thread, ThreadId};
 21
 22pub struct ThreadStore {
 23    #[allow(unused)]
 24    project: Model<Project>,
 25    tools: Arc<ToolWorkingSet>,
 26    context_server_manager: Model<ContextServerManager>,
 27    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
 28    threads: Vec<SavedThreadMetadata>,
 29    database_future: Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
 30}
 31
 32impl ThreadStore {
 33    pub fn new(
 34        project: Model<Project>,
 35        tools: Arc<ToolWorkingSet>,
 36        cx: &mut AppContext,
 37    ) -> Task<Result<Model<Self>>> {
 38        cx.spawn(|mut cx| async move {
 39            let this = cx.new_model(|cx: &mut ModelContext<Self>| {
 40                let context_server_factory_registry =
 41                    ContextServerFactoryRegistry::default_global(cx);
 42                let context_server_manager = cx.new_model(|cx| {
 43                    ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 44                });
 45
 46                let executor = cx.background_executor().clone();
 47                let database_future = executor
 48                    .spawn({
 49                        let executor = executor.clone();
 50                        let database_path = paths::support_dir().join("threads/threads-db.0.mdb");
 51                        async move { ThreadsDatabase::new(database_path, executor) }
 52                    })
 53                    .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
 54                    .boxed()
 55                    .shared();
 56
 57                let this = Self {
 58                    project,
 59                    tools,
 60                    context_server_manager,
 61                    context_server_tool_ids: HashMap::default(),
 62                    threads: Vec::new(),
 63                    database_future,
 64                };
 65                this.register_context_server_handlers(cx);
 66
 67                this
 68            })?;
 69
 70            this.update(&mut cx, |this, cx| this.reload(cx))?.await?;
 71
 72            Ok(this)
 73        })
 74    }
 75
 76    /// Returns the number of threads.
 77    pub fn thread_count(&self) -> usize {
 78        self.threads.len()
 79    }
 80
 81    pub fn threads(&self) -> Vec<SavedThreadMetadata> {
 82        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
 83        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
 84        threads
 85    }
 86
 87    pub fn recent_threads(&self, limit: usize) -> Vec<SavedThreadMetadata> {
 88        self.threads().into_iter().take(limit).collect()
 89    }
 90
 91    pub fn create_thread(&mut self, cx: &mut ModelContext<Self>) -> Model<Thread> {
 92        cx.new_model(|cx| Thread::new(self.tools.clone(), cx))
 93    }
 94
 95    pub fn open_thread(
 96        &self,
 97        id: &ThreadId,
 98        cx: &mut ModelContext<Self>,
 99    ) -> Task<Result<Model<Thread>>> {
100        let id = id.clone();
101        let database_future = self.database_future.clone();
102        cx.spawn(|this, mut cx| async move {
103            let database = database_future.await.map_err(|err| anyhow!(err))?;
104            let thread = database
105                .try_find_thread(id.clone())
106                .await?
107                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
108
109            this.update(&mut cx, |this, cx| {
110                cx.new_model(|cx| Thread::from_saved(id.clone(), thread, this.tools.clone(), cx))
111            })
112        })
113    }
114
115    pub fn save_thread(
116        &self,
117        thread: &Model<Thread>,
118        cx: &mut ModelContext<Self>,
119    ) -> Task<Result<()>> {
120        let (metadata, thread) = thread.update(cx, |thread, _cx| {
121            let id = thread.id().clone();
122            let thread = SavedThread {
123                summary: thread.summary_or_default(),
124                updated_at: thread.updated_at(),
125                messages: thread
126                    .messages()
127                    .map(|message| SavedMessage {
128                        id: message.id,
129                        role: message.role,
130                        text: message.text.clone(),
131                    })
132                    .collect(),
133            };
134
135            (id, thread)
136        });
137
138        let database_future = self.database_future.clone();
139        cx.spawn(|this, mut cx| async move {
140            let database = database_future.await.map_err(|err| anyhow!(err))?;
141            database.save_thread(metadata, thread).await?;
142
143            this.update(&mut cx, |this, cx| this.reload(cx))?.await
144        })
145    }
146
147    pub fn delete_thread(
148        &mut self,
149        id: &ThreadId,
150        cx: &mut ModelContext<Self>,
151    ) -> Task<Result<()>> {
152        let id = id.clone();
153        let database_future = self.database_future.clone();
154        cx.spawn(|this, mut cx| async move {
155            let database = database_future.await.map_err(|err| anyhow!(err))?;
156            database.delete_thread(id.clone()).await?;
157
158            this.update(&mut cx, |this, _cx| {
159                this.threads.retain(|thread| thread.id != id)
160            })
161        })
162    }
163
164    fn reload(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
165        let database_future = self.database_future.clone();
166        cx.spawn(|this, mut cx| async move {
167            let threads = database_future
168                .await
169                .map_err(|err| anyhow!(err))?
170                .list_threads()
171                .await?;
172
173            this.update(&mut cx, |this, cx| {
174                this.threads = threads;
175                cx.notify();
176            })
177        })
178    }
179
180    fn register_context_server_handlers(&self, cx: &mut ModelContext<Self>) {
181        cx.subscribe(
182            &self.context_server_manager.clone(),
183            Self::handle_context_server_event,
184        )
185        .detach();
186    }
187
188    fn handle_context_server_event(
189        &mut self,
190        context_server_manager: Model<ContextServerManager>,
191        event: &context_server::manager::Event,
192        cx: &mut ModelContext<Self>,
193    ) {
194        let tool_working_set = self.tools.clone();
195        match event {
196            context_server::manager::Event::ServerStarted { server_id } => {
197                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
198                    let context_server_manager = context_server_manager.clone();
199                    cx.spawn({
200                        let server = server.clone();
201                        let server_id = server_id.clone();
202                        |this, mut cx| async move {
203                            let Some(protocol) = server.client() else {
204                                return;
205                            };
206
207                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
208                                if let Some(tools) = protocol.list_tools().await.log_err() {
209                                    let tool_ids = tools
210                                        .tools
211                                        .into_iter()
212                                        .map(|tool| {
213                                            log::info!(
214                                                "registering context server tool: {:?}",
215                                                tool.name
216                                            );
217                                            tool_working_set.insert(Arc::new(
218                                                ContextServerTool::new(
219                                                    context_server_manager.clone(),
220                                                    server.id(),
221                                                    tool,
222                                                ),
223                                            ))
224                                        })
225                                        .collect::<Vec<_>>();
226
227                                    this.update(&mut cx, |this, _cx| {
228                                        this.context_server_tool_ids.insert(server_id, tool_ids);
229                                    })
230                                    .log_err();
231                                }
232                            }
233                        }
234                    })
235                    .detach();
236                }
237            }
238            context_server::manager::Event::ServerStopped { server_id } => {
239                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
240                    tool_working_set.remove(&tool_ids);
241                }
242            }
243        }
244    }
245}
246
247#[derive(Debug, Clone, Serialize, Deserialize)]
248pub struct SavedThreadMetadata {
249    pub id: ThreadId,
250    pub summary: SharedString,
251    pub updated_at: DateTime<Utc>,
252}
253
254#[derive(Serialize, Deserialize)]
255pub struct SavedThread {
256    pub summary: SharedString,
257    pub updated_at: DateTime<Utc>,
258    pub messages: Vec<SavedMessage>,
259}
260
261#[derive(Serialize, Deserialize)]
262pub struct SavedMessage {
263    pub id: MessageId,
264    pub role: Role,
265    pub text: String,
266}
267
268struct ThreadsDatabase {
269    executor: BackgroundExecutor,
270    env: heed::Env,
271    threads: Database<SerdeBincode<ThreadId>, SerdeBincode<SavedThread>>,
272}
273
274impl ThreadsDatabase {
275    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
276        std::fs::create_dir_all(&path)?;
277
278        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
279        let env = unsafe {
280            heed::EnvOpenOptions::new()
281                .map_size(ONE_GB_IN_BYTES)
282                .max_dbs(1)
283                .open(path)?
284        };
285
286        let mut txn = env.write_txn()?;
287        let threads = env.create_database(&mut txn, Some("threads"))?;
288        txn.commit()?;
289
290        Ok(Self {
291            executor,
292            env,
293            threads,
294        })
295    }
296
297    pub fn list_threads(&self) -> Task<Result<Vec<SavedThreadMetadata>>> {
298        let env = self.env.clone();
299        let threads = self.threads;
300
301        self.executor.spawn(async move {
302            let txn = env.read_txn()?;
303            let mut iter = threads.iter(&txn)?;
304            let mut threads = Vec::new();
305            while let Some((key, value)) = iter.next().transpose()? {
306                threads.push(SavedThreadMetadata {
307                    id: key,
308                    summary: value.summary,
309                    updated_at: value.updated_at,
310                });
311            }
312
313            Ok(threads)
314        })
315    }
316
317    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SavedThread>>> {
318        let env = self.env.clone();
319        let threads = self.threads;
320
321        self.executor.spawn(async move {
322            let txn = env.read_txn()?;
323            let thread = threads.get(&txn, &id)?;
324            Ok(thread)
325        })
326    }
327
328    pub fn save_thread(&self, id: ThreadId, thread: SavedThread) -> Task<Result<()>> {
329        let env = self.env.clone();
330        let threads = self.threads;
331
332        self.executor.spawn(async move {
333            let mut txn = env.write_txn()?;
334            threads.put(&mut txn, &id, &thread)?;
335            txn.commit()?;
336            Ok(())
337        })
338    }
339
340    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
341        let env = self.env.clone();
342        let threads = self.threads;
343
344        self.executor.spawn(async move {
345            let mut txn = env.write_txn()?;
346            threads.delete(&mut txn, &id)?;
347            txn.commit()?;
348            Ok(())
349        })
350    }
351}