thread_store.rs

  1use std::path::PathBuf;
  2use std::sync::Arc;
  3
  4use anyhow::{anyhow, Result};
  5use assistant_tool::{ToolId, ToolWorkingSet};
  6use chrono::{DateTime, Utc};
  7use collections::HashMap;
  8use context_server::manager::ContextServerManager;
  9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 10use futures::future::{self, BoxFuture, Shared};
 11use futures::FutureExt as _;
 12use gpui::{prelude::*, App, BackgroundExecutor, Context, Entity, SharedString, Task};
 13use heed::types::SerdeBincode;
 14use heed::Database;
 15use language_model::Role;
 16use project::Project;
 17use serde::{Deserialize, Serialize};
 18use util::ResultExt as _;
 19
 20use crate::thread::{MessageId, Thread, ThreadId};
 21
 22pub struct ThreadStore {
 23    #[allow(unused)]
 24    project: Entity<Project>,
 25    tools: Arc<ToolWorkingSet>,
 26    context_server_manager: Entity<ContextServerManager>,
 27    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
 28    threads: Vec<SavedThreadMetadata>,
 29    database_future: Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
 30}
 31
 32impl ThreadStore {
 33    pub fn new(
 34        project: Entity<Project>,
 35        tools: Arc<ToolWorkingSet>,
 36        cx: &mut App,
 37    ) -> Result<Entity<Self>> {
 38        let this = cx.new(|cx| {
 39            let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
 40            let context_server_manager = cx.new(|cx| {
 41                ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 42            });
 43
 44            let executor = cx.background_executor().clone();
 45            let database_future = executor
 46                .spawn({
 47                    let executor = executor.clone();
 48                    let database_path = paths::support_dir().join("threads/threads-db.0.mdb");
 49                    async move { ThreadsDatabase::new(database_path, executor) }
 50                })
 51                .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
 52                .boxed()
 53                .shared();
 54
 55            let this = Self {
 56                project,
 57                tools,
 58                context_server_manager,
 59                context_server_tool_ids: HashMap::default(),
 60                threads: Vec::new(),
 61                database_future,
 62            };
 63            this.register_context_server_handlers(cx);
 64            this.reload(cx).detach_and_log_err(cx);
 65
 66            this
 67        });
 68
 69        Ok(this)
 70    }
 71
 72    /// Returns the number of threads.
 73    pub fn thread_count(&self) -> usize {
 74        self.threads.len()
 75    }
 76
 77    pub fn threads(&self) -> Vec<SavedThreadMetadata> {
 78        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
 79        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
 80        threads
 81    }
 82
 83    pub fn recent_threads(&self, limit: usize) -> Vec<SavedThreadMetadata> {
 84        self.threads().into_iter().take(limit).collect()
 85    }
 86
 87    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 88        cx.new(|cx| Thread::new(self.tools.clone(), cx))
 89    }
 90
 91    pub fn open_thread(
 92        &self,
 93        id: &ThreadId,
 94        cx: &mut Context<Self>,
 95    ) -> Task<Result<Entity<Thread>>> {
 96        let id = id.clone();
 97        let database_future = self.database_future.clone();
 98        cx.spawn(|this, mut cx| async move {
 99            let database = database_future.await.map_err(|err| anyhow!(err))?;
100            let thread = database
101                .try_find_thread(id.clone())
102                .await?
103                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
104
105            this.update(&mut cx, |this, cx| {
106                cx.new(|cx| Thread::from_saved(id.clone(), thread, this.tools.clone(), cx))
107            })
108        })
109    }
110
111    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
112        let (metadata, thread) = thread.update(cx, |thread, _cx| {
113            let id = thread.id().clone();
114            let thread = SavedThread {
115                summary: thread.summary_or_default(),
116                updated_at: thread.updated_at(),
117                messages: thread
118                    .messages()
119                    .map(|message| SavedMessage {
120                        id: message.id,
121                        role: message.role,
122                        text: message.text.clone(),
123                    })
124                    .collect(),
125            };
126
127            (id, thread)
128        });
129
130        let database_future = self.database_future.clone();
131        cx.spawn(|this, mut cx| async move {
132            let database = database_future.await.map_err(|err| anyhow!(err))?;
133            database.save_thread(metadata, thread).await?;
134
135            this.update(&mut cx, |this, cx| this.reload(cx))?.await
136        })
137    }
138
139    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
140        let id = id.clone();
141        let database_future = self.database_future.clone();
142        cx.spawn(|this, mut cx| async move {
143            let database = database_future.await.map_err(|err| anyhow!(err))?;
144            database.delete_thread(id.clone()).await?;
145
146            this.update(&mut cx, |this, _cx| {
147                this.threads.retain(|thread| thread.id != id)
148            })
149        })
150    }
151
152    fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
153        let database_future = self.database_future.clone();
154        cx.spawn(|this, mut cx| async move {
155            let threads = database_future
156                .await
157                .map_err(|err| anyhow!(err))?
158                .list_threads()
159                .await?;
160
161            this.update(&mut cx, |this, cx| {
162                this.threads = threads;
163                cx.notify();
164            })
165        })
166    }
167
168    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
169        cx.subscribe(
170            &self.context_server_manager.clone(),
171            Self::handle_context_server_event,
172        )
173        .detach();
174    }
175
176    fn handle_context_server_event(
177        &mut self,
178        context_server_manager: Entity<ContextServerManager>,
179        event: &context_server::manager::Event,
180        cx: &mut Context<Self>,
181    ) {
182        let tool_working_set = self.tools.clone();
183        match event {
184            context_server::manager::Event::ServerStarted { server_id } => {
185                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
186                    let context_server_manager = context_server_manager.clone();
187                    cx.spawn({
188                        let server = server.clone();
189                        let server_id = server_id.clone();
190                        |this, mut cx| async move {
191                            let Some(protocol) = server.client() else {
192                                return;
193                            };
194
195                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
196                                if let Some(tools) = protocol.list_tools().await.log_err() {
197                                    let tool_ids = tools
198                                        .tools
199                                        .into_iter()
200                                        .map(|tool| {
201                                            log::info!(
202                                                "registering context server tool: {:?}",
203                                                tool.name
204                                            );
205                                            tool_working_set.insert(Arc::new(
206                                                ContextServerTool::new(
207                                                    context_server_manager.clone(),
208                                                    server.id(),
209                                                    tool,
210                                                ),
211                                            ))
212                                        })
213                                        .collect::<Vec<_>>();
214
215                                    this.update(&mut cx, |this, _cx| {
216                                        this.context_server_tool_ids.insert(server_id, tool_ids);
217                                    })
218                                    .log_err();
219                                }
220                            }
221                        }
222                    })
223                    .detach();
224                }
225            }
226            context_server::manager::Event::ServerStopped { server_id } => {
227                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
228                    tool_working_set.remove(&tool_ids);
229                }
230            }
231        }
232    }
233}
234
235#[derive(Debug, Clone, Serialize, Deserialize)]
236pub struct SavedThreadMetadata {
237    pub id: ThreadId,
238    pub summary: SharedString,
239    pub updated_at: DateTime<Utc>,
240}
241
242#[derive(Serialize, Deserialize)]
243pub struct SavedThread {
244    pub summary: SharedString,
245    pub updated_at: DateTime<Utc>,
246    pub messages: Vec<SavedMessage>,
247}
248
249#[derive(Serialize, Deserialize)]
250pub struct SavedMessage {
251    pub id: MessageId,
252    pub role: Role,
253    pub text: String,
254}
255
256struct ThreadsDatabase {
257    executor: BackgroundExecutor,
258    env: heed::Env,
259    threads: Database<SerdeBincode<ThreadId>, SerdeBincode<SavedThread>>,
260}
261
262impl ThreadsDatabase {
263    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
264        std::fs::create_dir_all(&path)?;
265
266        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
267        let env = unsafe {
268            heed::EnvOpenOptions::new()
269                .map_size(ONE_GB_IN_BYTES)
270                .max_dbs(1)
271                .open(path)?
272        };
273
274        let mut txn = env.write_txn()?;
275        let threads = env.create_database(&mut txn, Some("threads"))?;
276        txn.commit()?;
277
278        Ok(Self {
279            executor,
280            env,
281            threads,
282        })
283    }
284
285    pub fn list_threads(&self) -> Task<Result<Vec<SavedThreadMetadata>>> {
286        let env = self.env.clone();
287        let threads = self.threads;
288
289        self.executor.spawn(async move {
290            let txn = env.read_txn()?;
291            let mut iter = threads.iter(&txn)?;
292            let mut threads = Vec::new();
293            while let Some((key, value)) = iter.next().transpose()? {
294                threads.push(SavedThreadMetadata {
295                    id: key,
296                    summary: value.summary,
297                    updated_at: value.updated_at,
298                });
299            }
300
301            Ok(threads)
302        })
303    }
304
305    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SavedThread>>> {
306        let env = self.env.clone();
307        let threads = self.threads;
308
309        self.executor.spawn(async move {
310            let txn = env.read_txn()?;
311            let thread = threads.get(&txn, &id)?;
312            Ok(thread)
313        })
314    }
315
316    pub fn save_thread(&self, id: ThreadId, thread: SavedThread) -> Task<Result<()>> {
317        let env = self.env.clone();
318        let threads = self.threads;
319
320        self.executor.spawn(async move {
321            let mut txn = env.write_txn()?;
322            threads.put(&mut txn, &id, &thread)?;
323            txn.commit()?;
324            Ok(())
325        })
326    }
327
328    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
329        let env = self.env.clone();
330        let threads = self.threads;
331
332        self.executor.spawn(async move {
333            let mut txn = env.write_txn()?;
334            threads.delete(&mut txn, &id)?;
335            txn.commit()?;
336            Ok(())
337        })
338    }
339}