thread_store.rs

  1use std::path::PathBuf;
  2use std::sync::Arc;
  3
  4use anyhow::{anyhow, Result};
  5use assistant_tool::{ToolId, ToolWorkingSet};
  6use chrono::{DateTime, Utc};
  7use collections::HashMap;
  8use context_server::manager::ContextServerManager;
  9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 10use futures::future::{self, BoxFuture, Shared};
 11use futures::FutureExt as _;
 12use gpui::{
 13    prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
 14};
 15use heed::types::{SerdeBincode, SerdeJson};
 16use heed::Database;
 17use language_model::{LanguageModelToolUseId, Role};
 18use project::Project;
 19use prompt_store::PromptBuilder;
 20use serde::{Deserialize, Serialize};
 21use util::ResultExt as _;
 22
 23use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId};
 24
 25pub fn init(cx: &mut App) {
 26    ThreadsDatabase::init(cx);
 27}
 28
/// Tracks the saved threads of a project and owns that project's context server manager.
///
/// Full thread contents live in the global threads database. This store caches only their
/// metadata, and it creates or loads `Thread` entities on demand.
pub struct ThreadStore {
    project: Entity<Project>,
    /// Tools handed to every thread this store creates or loads. Tools exposed by context
    /// servers are inserted into and removed from this set as those servers start and stop.
    tools: Arc<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    context_server_manager: Entity<ContextServerManager>,
    /// IDs of the tools registered for each context server, keyed by server ID. They are kept
    /// so the tools can be removed from `tools` when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Thread metadata cached from the database, in database iteration order (not sorted).
    threads: Vec<SerializedThreadMetadata>,
}
 37
 38impl ThreadStore {
 39    pub fn new(
 40        project: Entity<Project>,
 41        tools: Arc<ToolWorkingSet>,
 42        prompt_builder: Arc<PromptBuilder>,
 43        cx: &mut App,
 44    ) -> Result<Entity<Self>> {
 45        let this = cx.new(|cx| {
 46            let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
 47            let context_server_manager = cx.new(|cx| {
 48                ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 49            });
 50
 51            let this = Self {
 52                project,
 53                tools,
 54                prompt_builder,
 55                context_server_manager,
 56                context_server_tool_ids: HashMap::default(),
 57                threads: Vec::new(),
 58            };
 59            this.register_context_server_handlers(cx);
 60            this.reload(cx).detach_and_log_err(cx);
 61
 62            this
 63        });
 64
 65        Ok(this)
 66    }
 67
 68    pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
 69        self.context_server_manager.clone()
 70    }
 71
 72    pub fn tools(&self) -> Arc<ToolWorkingSet> {
 73        self.tools.clone()
 74    }
 75
 76    /// Returns the number of threads.
 77    pub fn thread_count(&self) -> usize {
 78        self.threads.len()
 79    }
 80
 81    pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
 82        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
 83        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
 84        threads
 85    }
 86
 87    pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
 88        self.threads().into_iter().take(limit).collect()
 89    }
 90
 91    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 92        cx.new(|cx| {
 93            Thread::new(
 94                self.project.clone(),
 95                self.tools.clone(),
 96                self.prompt_builder.clone(),
 97                cx,
 98            )
 99        })
100    }
101
102    pub fn open_thread(
103        &self,
104        id: &ThreadId,
105        cx: &mut Context<Self>,
106    ) -> Task<Result<Entity<Thread>>> {
107        let id = id.clone();
108        let database_future = ThreadsDatabase::global_future(cx);
109        cx.spawn(async move |this, cx| {
110            let database = database_future.await.map_err(|err| anyhow!(err))?;
111            let thread = database
112                .try_find_thread(id.clone())
113                .await?
114                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
115
116            let thread = this.update(cx, |this, cx| {
117                cx.new(|cx| {
118                    Thread::deserialize(
119                        id.clone(),
120                        thread,
121                        this.project.clone(),
122                        this.tools.clone(),
123                        this.prompt_builder.clone(),
124                        cx,
125                    )
126                })
127            })?;
128
129            let (system_prompt_context, load_error) = thread
130                .update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
131                .await;
132            thread.update(cx, |thread, cx| {
133                thread.set_system_prompt_context(system_prompt_context);
134                if let Some(load_error) = load_error {
135                    cx.emit(ThreadEvent::ShowError(load_error));
136                }
137            })?;
138
139            Ok(thread)
140        })
141    }
142
143    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
144        let (metadata, serialized_thread) =
145            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
146
147        let database_future = ThreadsDatabase::global_future(cx);
148        cx.spawn(async move |this, cx| {
149            let serialized_thread = serialized_thread.await?;
150            let database = database_future.await.map_err(|err| anyhow!(err))?;
151            database.save_thread(metadata, serialized_thread).await?;
152
153            this.update(cx, |this, cx| this.reload(cx))?.await
154        })
155    }
156
157    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
158        let id = id.clone();
159        let database_future = ThreadsDatabase::global_future(cx);
160        cx.spawn(async move |this, cx| {
161            let database = database_future.await.map_err(|err| anyhow!(err))?;
162            database.delete_thread(id.clone()).await?;
163
164            this.update(cx, |this, _cx| {
165                this.threads.retain(|thread| thread.id != id)
166            })
167        })
168    }
169
170    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
171        let database_future = ThreadsDatabase::global_future(cx);
172        cx.spawn(async move |this, cx| {
173            let threads = database_future
174                .await
175                .map_err(|err| anyhow!(err))?
176                .list_threads()
177                .await?;
178
179            this.update(cx, |this, cx| {
180                this.threads = threads;
181                cx.notify();
182            })
183        })
184    }
185
186    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
187        cx.subscribe(
188            &self.context_server_manager.clone(),
189            Self::handle_context_server_event,
190        )
191        .detach();
192    }
193
194    fn handle_context_server_event(
195        &mut self,
196        context_server_manager: Entity<ContextServerManager>,
197        event: &context_server::manager::Event,
198        cx: &mut Context<Self>,
199    ) {
200        let tool_working_set = self.tools.clone();
201        match event {
202            context_server::manager::Event::ServerStarted { server_id } => {
203                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
204                    let context_server_manager = context_server_manager.clone();
205                    cx.spawn({
206                        let server = server.clone();
207                        let server_id = server_id.clone();
208                        async move |this, cx| {
209                            let Some(protocol) = server.client() else {
210                                return;
211                            };
212
213                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
214                                if let Some(tools) = protocol.list_tools().await.log_err() {
215                                    let tool_ids = tools
216                                        .tools
217                                        .into_iter()
218                                        .map(|tool| {
219                                            log::info!(
220                                                "registering context server tool: {:?}",
221                                                tool.name
222                                            );
223                                            tool_working_set.insert(Arc::new(
224                                                ContextServerTool::new(
225                                                    context_server_manager.clone(),
226                                                    server.id(),
227                                                    tool,
228                                                ),
229                                            ))
230                                        })
231                                        .collect::<Vec<_>>();
232
233                                    this.update(cx, |this, _cx| {
234                                        this.context_server_tool_ids.insert(server_id, tool_ids);
235                                    })
236                                    .log_err();
237                                }
238                            }
239                        }
240                    })
241                    .detach();
242                }
243            }
244            context_server::manager::Event::ServerStopped { server_id } => {
245                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
246                    tool_working_set.remove(&tool_ids);
247                }
248            }
249        }
250    }
251}
252
/// Summary of a stored thread: its ID, summary text, and last update time.
///
/// `ThreadsDatabase::list_threads` builds these from stored `SerializedThread`s.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
}
259
/// Persisted form of a thread, stored as JSON in the threads database.
#[derive(Serialize, Deserialize)]
pub struct SerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    /// Becomes `None` when the field is absent from the stored JSON. Presumably this covers
    /// threads saved before the field existed (TODO confirm).
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
268
/// Persisted form of a single message within a `SerializedThread`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    /// Becomes empty when absent from the stored data.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Becomes empty when absent from the stored data.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
279
/// Persisted record of a tool invocation: the tool's name and its input as arbitrary JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
286
/// Persisted result of a tool invocation.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    /// ID of the tool use this result belongs to. Presumably it matches a
    /// `SerializedToolUse::id` in the same thread (verify against `Thread::serialize`).
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: Arc<str>,
}
293
/// Global holding the shared future that opens the threads database.
///
/// The future is [`Shared`], so every clone resolves to the same `Arc<ThreadsDatabase>`. If
/// opening failed, every clone resolves to the same `Arc`-wrapped error instead.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
299
/// Handle to the heed (LMDB) environment that stores serialized threads.
///
/// Reads and writes are spawned on `executor` and returned as `Task`s.
pub(crate) struct ThreadsDatabase {
    executor: BackgroundExecutor,
    env: heed::Env,
    /// Threads keyed by bincode-encoded `ThreadId`, with JSON-encoded `SerializedThread` values.
    threads: Database<SerdeBincode<ThreadId>, SerdeJson<SerializedThread>>,
}
305
306impl ThreadsDatabase {
307    fn global_future(
308        cx: &mut App,
309    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
310        GlobalThreadsDatabase::global(cx).0.clone()
311    }
312
313    fn init(cx: &mut App) {
314        let executor = cx.background_executor().clone();
315        let database_future = executor
316            .spawn({
317                let executor = executor.clone();
318                let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
319                async move { ThreadsDatabase::new(database_path, executor) }
320            })
321            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
322            .boxed()
323            .shared();
324
325        cx.set_global(GlobalThreadsDatabase(database_future));
326    }
327
328    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
329        std::fs::create_dir_all(&path)?;
330
331        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
332        let env = unsafe {
333            heed::EnvOpenOptions::new()
334                .map_size(ONE_GB_IN_BYTES)
335                .max_dbs(1)
336                .open(path)?
337        };
338
339        let mut txn = env.write_txn()?;
340        let threads = env.create_database(&mut txn, Some("threads"))?;
341        txn.commit()?;
342
343        Ok(Self {
344            executor,
345            env,
346            threads,
347        })
348    }
349
350    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
351        let env = self.env.clone();
352        let threads = self.threads;
353
354        self.executor.spawn(async move {
355            let txn = env.read_txn()?;
356            let mut iter = threads.iter(&txn)?;
357            let mut threads = Vec::new();
358            while let Some((key, value)) = iter.next().transpose()? {
359                threads.push(SerializedThreadMetadata {
360                    id: key,
361                    summary: value.summary,
362                    updated_at: value.updated_at,
363                });
364            }
365
366            Ok(threads)
367        })
368    }
369
370    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
371        let env = self.env.clone();
372        let threads = self.threads;
373
374        self.executor.spawn(async move {
375            let txn = env.read_txn()?;
376            let thread = threads.get(&txn, &id)?;
377            Ok(thread)
378        })
379    }
380
381    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
382        let env = self.env.clone();
383        let threads = self.threads;
384
385        self.executor.spawn(async move {
386            let mut txn = env.write_txn()?;
387            threads.put(&mut txn, &id, &thread)?;
388            txn.commit()?;
389            Ok(())
390        })
391    }
392
393    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
394        let env = self.env.clone();
395        let threads = self.threads;
396
397        self.executor.spawn(async move {
398            let mut txn = env.write_txn()?;
399            threads.delete(&mut txn, &id)?;
400            txn.commit()?;
401            Ok(())
402        })
403    }
404}