thread_store.rs

  1use std::path::PathBuf;
  2use std::sync::Arc;
  3
  4use anyhow::{anyhow, Result};
  5use assistant_tool::{ToolId, ToolWorkingSet};
  6use chrono::{DateTime, Utc};
  7use collections::HashMap;
  8use context_server::manager::ContextServerManager;
  9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 10use futures::future::{self, BoxFuture, Shared};
 11use futures::FutureExt as _;
 12use gpui::{
 13    prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
 14};
 15use heed::types::{SerdeBincode, SerdeJson};
 16use heed::Database;
 17use language_model::{LanguageModelToolUseId, Role};
 18use project::Project;
 19use prompt_store::PromptBuilder;
 20use serde::{Deserialize, Serialize};
 21use util::ResultExt as _;
 22
 23use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadId};
 24
 25pub fn init(cx: &mut App) {
 26    ThreadsDatabase::init(cx);
 27}
 28
/// Per-project store of assistant threads.
///
/// Holds the shared dependencies used to create and open `Thread`s, plus an
/// in-memory copy of the saved-thread metadata that is refreshed from the threads
/// database by `reload`.
pub struct ThreadStore {
    /// Project that created and opened threads are bound to.
    project: Entity<Project>,
    /// Tool set handed to threads; context-server tools are inserted into and removed
    /// from it as servers start and stop.
    tools: Arc<ToolWorkingSet>,
    /// Prompt builder handed to every created or opened thread.
    prompt_builder: Arc<PromptBuilder>,
    /// Manager for this project's context servers; its events drive tool registration.
    context_server_manager: Entity<ContextServerManager>,
    /// Tool IDs registered in `tools` for each context server, keyed by server ID, so
    /// they can be removed again when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Metadata of saved threads, in database iteration order (not sorted here).
    threads: Vec<SerializedThreadMetadata>,
}
 37
 38impl ThreadStore {
 39    pub fn new(
 40        project: Entity<Project>,
 41        tools: Arc<ToolWorkingSet>,
 42        prompt_builder: Arc<PromptBuilder>,
 43        cx: &mut App,
 44    ) -> Result<Entity<Self>> {
 45        let this = cx.new(|cx| {
 46            let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
 47            let context_server_manager = cx.new(|cx| {
 48                ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 49            });
 50
 51            let this = Self {
 52                project,
 53                tools,
 54                prompt_builder,
 55                context_server_manager,
 56                context_server_tool_ids: HashMap::default(),
 57                threads: Vec::new(),
 58            };
 59            this.register_context_server_handlers(cx);
 60            this.reload(cx).detach_and_log_err(cx);
 61
 62            this
 63        });
 64
 65        Ok(this)
 66    }
 67
 68    pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
 69        self.context_server_manager.clone()
 70    }
 71
 72    /// Returns the number of threads.
 73    pub fn thread_count(&self) -> usize {
 74        self.threads.len()
 75    }
 76
 77    pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
 78        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
 79        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
 80        threads
 81    }
 82
 83    pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
 84        self.threads().into_iter().take(limit).collect()
 85    }
 86
 87    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 88        cx.new(|cx| {
 89            Thread::new(
 90                self.project.clone(),
 91                self.tools.clone(),
 92                self.prompt_builder.clone(),
 93                cx,
 94            )
 95        })
 96    }
 97
 98    pub fn open_thread(
 99        &self,
100        id: &ThreadId,
101        cx: &mut Context<Self>,
102    ) -> Task<Result<Entity<Thread>>> {
103        let id = id.clone();
104        let database_future = ThreadsDatabase::global_future(cx);
105        cx.spawn(|this, mut cx| async move {
106            let database = database_future.await.map_err(|err| anyhow!(err))?;
107            let thread = database
108                .try_find_thread(id.clone())
109                .await?
110                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
111
112            this.update(&mut cx, |this, cx| {
113                cx.new(|cx| {
114                    Thread::deserialize(
115                        id.clone(),
116                        thread,
117                        this.project.clone(),
118                        this.tools.clone(),
119                        this.prompt_builder.clone(),
120                        cx,
121                    )
122                })
123            })
124        })
125    }
126
127    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
128        let (metadata, serialized_thread) =
129            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
130
131        let database_future = ThreadsDatabase::global_future(cx);
132        cx.spawn(|this, mut cx| async move {
133            let serialized_thread = serialized_thread.await?;
134            let database = database_future.await.map_err(|err| anyhow!(err))?;
135            database.save_thread(metadata, serialized_thread).await?;
136
137            this.update(&mut cx, |this, cx| this.reload(cx))?.await
138        })
139    }
140
141    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
142        let id = id.clone();
143        let database_future = ThreadsDatabase::global_future(cx);
144        cx.spawn(|this, mut cx| async move {
145            let database = database_future.await.map_err(|err| anyhow!(err))?;
146            database.delete_thread(id.clone()).await?;
147
148            this.update(&mut cx, |this, _cx| {
149                this.threads.retain(|thread| thread.id != id)
150            })
151        })
152    }
153
154    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
155        let database_future = ThreadsDatabase::global_future(cx);
156        cx.spawn(|this, mut cx| async move {
157            let threads = database_future
158                .await
159                .map_err(|err| anyhow!(err))?
160                .list_threads()
161                .await?;
162
163            this.update(&mut cx, |this, cx| {
164                this.threads = threads;
165                cx.notify();
166            })
167        })
168    }
169
170    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
171        cx.subscribe(
172            &self.context_server_manager.clone(),
173            Self::handle_context_server_event,
174        )
175        .detach();
176    }
177
178    fn handle_context_server_event(
179        &mut self,
180        context_server_manager: Entity<ContextServerManager>,
181        event: &context_server::manager::Event,
182        cx: &mut Context<Self>,
183    ) {
184        let tool_working_set = self.tools.clone();
185        match event {
186            context_server::manager::Event::ServerStarted { server_id } => {
187                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
188                    let context_server_manager = context_server_manager.clone();
189                    cx.spawn({
190                        let server = server.clone();
191                        let server_id = server_id.clone();
192                        |this, mut cx| async move {
193                            let Some(protocol) = server.client() else {
194                                return;
195                            };
196
197                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
198                                if let Some(tools) = protocol.list_tools().await.log_err() {
199                                    let tool_ids = tools
200                                        .tools
201                                        .into_iter()
202                                        .map(|tool| {
203                                            log::info!(
204                                                "registering context server tool: {:?}",
205                                                tool.name
206                                            );
207                                            tool_working_set.insert(Arc::new(
208                                                ContextServerTool::new(
209                                                    context_server_manager.clone(),
210                                                    server.id(),
211                                                    tool,
212                                                ),
213                                            ))
214                                        })
215                                        .collect::<Vec<_>>();
216
217                                    this.update(&mut cx, |this, _cx| {
218                                        this.context_server_tool_ids.insert(server_id, tool_ids);
219                                    })
220                                    .log_err();
221                                }
222                            }
223                        }
224                    })
225                    .detach();
226                }
227            }
228            context_server::manager::Event::ServerStopped { server_id } => {
229                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
230                    tool_working_set.remove(&tool_ids);
231                }
232            }
233        }
234    }
235}
236
/// Summary of a saved thread.
///
/// Returned by `ThreadsDatabase::list_threads` and by `ThreadStore::threads`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Database key of the thread.
    pub id: ThreadId,
    /// Copied from the stored thread's `summary`.
    pub summary: SharedString,
    /// Copied from the stored thread's `updated_at`; `ThreadStore::threads` sorts by it.
    pub updated_at: DateTime<Utc>,
}
243
/// Persisted form of a thread.
///
/// Stored as JSON in the threads database, keyed by `ThreadId`.
#[derive(Serialize, Deserialize)]
pub struct SerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    // `serde(default)` yields `None` when the field is absent; presumably this keeps
    // threads saved before the field existed loadable — confirm.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
252
/// Persisted form of a single message within a `SerializedThread`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    // Deserializes to an empty list when the field is absent.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    // Deserializes to an empty list when the field is absent.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
263
/// Persisted form of a tool use recorded on a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    /// Name of the tool that was used.
    pub name: SharedString,
    /// Tool input, kept as an arbitrary JSON value.
    pub input: serde_json::Value,
}
270
/// Persisted form of a tool result recorded on a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    /// Presumably the `id` of the `SerializedToolUse` this result answers — confirm.
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: Arc<str>,
}
277
/// gpui global holding the shared future that opens the threads database.
///
/// Every clone of the `Shared` future resolves to the same result. The error is
/// wrapped in `Arc` so the output is cloneable, which `Shared` requires.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);
281
// Allows storing the database future with `set_global` and reading it back with `global`.
impl Global for GlobalThreadsDatabase {}
283
/// heed (LMDB) store of serialized threads.
///
/// Each operation runs on `executor` and returns a `Task`.
pub(crate) struct ThreadsDatabase {
    /// Executor that every database operation is spawned on.
    executor: BackgroundExecutor,
    /// The opened LMDB environment; cloned into each spawned operation.
    env: heed::Env,
    /// Table mapping bincode-encoded `ThreadId` keys to JSON-encoded `SerializedThread`s.
    threads: Database<SerdeBincode<ThreadId>, SerdeJson<SerializedThread>>,
}
289
290impl ThreadsDatabase {
291    fn global_future(
292        cx: &mut App,
293    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
294        GlobalThreadsDatabase::global(cx).0.clone()
295    }
296
297    fn init(cx: &mut App) {
298        let executor = cx.background_executor().clone();
299        let database_future = executor
300            .spawn({
301                let executor = executor.clone();
302                let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
303                async move { ThreadsDatabase::new(database_path, executor) }
304            })
305            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
306            .boxed()
307            .shared();
308
309        cx.set_global(GlobalThreadsDatabase(database_future));
310    }
311
312    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
313        std::fs::create_dir_all(&path)?;
314
315        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
316        let env = unsafe {
317            heed::EnvOpenOptions::new()
318                .map_size(ONE_GB_IN_BYTES)
319                .max_dbs(1)
320                .open(path)?
321        };
322
323        let mut txn = env.write_txn()?;
324        let threads = env.create_database(&mut txn, Some("threads"))?;
325        txn.commit()?;
326
327        Ok(Self {
328            executor,
329            env,
330            threads,
331        })
332    }
333
334    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
335        let env = self.env.clone();
336        let threads = self.threads;
337
338        self.executor.spawn(async move {
339            let txn = env.read_txn()?;
340            let mut iter = threads.iter(&txn)?;
341            let mut threads = Vec::new();
342            while let Some((key, value)) = iter.next().transpose()? {
343                threads.push(SerializedThreadMetadata {
344                    id: key,
345                    summary: value.summary,
346                    updated_at: value.updated_at,
347                });
348            }
349
350            Ok(threads)
351        })
352    }
353
354    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
355        let env = self.env.clone();
356        let threads = self.threads;
357
358        self.executor.spawn(async move {
359            let txn = env.read_txn()?;
360            let thread = threads.get(&txn, &id)?;
361            Ok(thread)
362        })
363    }
364
365    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
366        let env = self.env.clone();
367        let threads = self.threads;
368
369        self.executor.spawn(async move {
370            let mut txn = env.write_txn()?;
371            threads.put(&mut txn, &id, &thread)?;
372            txn.commit()?;
373            Ok(())
374        })
375    }
376
377    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
378        let env = self.env.clone();
379        let threads = self.threads;
380
381        self.executor.spawn(async move {
382            let mut txn = env.write_txn()?;
383            threads.delete(&mut txn, &id)?;
384            txn.commit()?;
385            Ok(())
386        })
387    }
388}