thread_store.rs

  1use std::path::PathBuf;
  2use std::sync::Arc;
  3
  4use anyhow::{anyhow, Result};
  5use assistant_tool::{ToolId, ToolWorkingSet};
  6use chrono::{DateTime, Utc};
  7use collections::HashMap;
  8use context_server::manager::ContextServerManager;
  9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 10use futures::future::{self, BoxFuture, Shared};
 11use futures::FutureExt as _;
 12use gpui::{
 13    prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
 14};
 15use heed::types::{SerdeBincode, SerdeJson};
 16use heed::Database;
 17use language_model::{LanguageModelToolUseId, Role};
 18use project::Project;
 19use prompt_store::PromptBuilder;
 20use serde::{Deserialize, Serialize};
 21use util::ResultExt as _;
 22
 23use crate::thread::{MessageId, Thread, ThreadId};
 24
 25pub fn init(cx: &mut App) {
 26    ThreadsDatabase::init(cx);
 27}
 28
/// Per-project store that creates, opens, saves and deletes [`Thread`]s, and caches the
/// metadata of the threads saved in the global threads database.
pub struct ThreadStore {
    /// Project passed to every [`Thread`] this store creates or opens.
    project: Entity<Project>,
    /// Tool set shared with threads; context-server tools are inserted into / removed from it.
    tools: Arc<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    /// Manager created for `project`; its start/stop events drive tool (un)registration.
    context_server_manager: Entity<ContextServerManager>,
    /// Tool IDs inserted into `tools` for each started context server, keyed by server ID,
    /// so they can be removed again when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Cached saved-thread metadata: replaced by `reload`, pruned by `delete_thread`.
    threads: Vec<SavedThreadMetadata>,
}
 37
 38impl ThreadStore {
 39    pub fn new(
 40        project: Entity<Project>,
 41        tools: Arc<ToolWorkingSet>,
 42        prompt_builder: Arc<PromptBuilder>,
 43        cx: &mut App,
 44    ) -> Result<Entity<Self>> {
 45        let this = cx.new(|cx| {
 46            let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
 47            let context_server_manager = cx.new(|cx| {
 48                ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 49            });
 50
 51            let this = Self {
 52                project,
 53                tools,
 54                prompt_builder,
 55                context_server_manager,
 56                context_server_tool_ids: HashMap::default(),
 57                threads: Vec::new(),
 58            };
 59            this.register_context_server_handlers(cx);
 60            this.reload(cx).detach_and_log_err(cx);
 61
 62            this
 63        });
 64
 65        Ok(this)
 66    }
 67
 68    /// Returns the number of threads.
 69    pub fn thread_count(&self) -> usize {
 70        self.threads.len()
 71    }
 72
 73    pub fn threads(&self) -> Vec<SavedThreadMetadata> {
 74        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
 75        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
 76        threads
 77    }
 78
 79    pub fn recent_threads(&self, limit: usize) -> Vec<SavedThreadMetadata> {
 80        self.threads().into_iter().take(limit).collect()
 81    }
 82
 83    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 84        cx.new(|cx| {
 85            Thread::new(
 86                self.project.clone(),
 87                self.tools.clone(),
 88                self.prompt_builder.clone(),
 89                cx,
 90            )
 91        })
 92    }
 93
 94    pub fn open_thread(
 95        &self,
 96        id: &ThreadId,
 97        cx: &mut Context<Self>,
 98    ) -> Task<Result<Entity<Thread>>> {
 99        let id = id.clone();
100        let database_future = ThreadsDatabase::global_future(cx);
101        cx.spawn(|this, mut cx| async move {
102            let database = database_future.await.map_err(|err| anyhow!(err))?;
103            let thread = database
104                .try_find_thread(id.clone())
105                .await?
106                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
107
108            this.update(&mut cx, |this, cx| {
109                cx.new(|cx| {
110                    Thread::from_saved(
111                        id.clone(),
112                        thread,
113                        this.project.clone(),
114                        this.tools.clone(),
115                        this.prompt_builder.clone(),
116                        cx,
117                    )
118                })
119            })
120        })
121    }
122
123    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
124        let (metadata, thread) = thread.update(cx, |thread, _cx| {
125            let id = thread.id().clone();
126            let thread = SavedThread {
127                summary: thread.summary_or_default(),
128                updated_at: thread.updated_at(),
129                messages: thread
130                    .messages()
131                    .map(|message| {
132                        let all_tool_uses = thread
133                            .tool_uses_for_message(message.id)
134                            .into_iter()
135                            .chain(thread.scripting_tool_uses_for_message(message.id))
136                            .map(|tool_use| SavedToolUse {
137                                id: tool_use.id,
138                                name: tool_use.name,
139                                input: tool_use.input,
140                            })
141                            .collect();
142                        let all_tool_results = thread
143                            .tool_results_for_message(message.id)
144                            .into_iter()
145                            .chain(thread.scripting_tool_results_for_message(message.id))
146                            .map(|tool_result| SavedToolResult {
147                                tool_use_id: tool_result.tool_use_id.clone(),
148                                is_error: tool_result.is_error,
149                                content: tool_result.content.clone(),
150                            })
151                            .collect();
152
153                        SavedMessage {
154                            id: message.id,
155                            role: message.role,
156                            text: message.text.clone(),
157                            tool_uses: all_tool_uses,
158                            tool_results: all_tool_results,
159                        }
160                    })
161                    .collect(),
162            };
163
164            (id, thread)
165        });
166
167        let database_future = ThreadsDatabase::global_future(cx);
168        cx.spawn(|this, mut cx| async move {
169            let database = database_future.await.map_err(|err| anyhow!(err))?;
170            database.save_thread(metadata, thread).await?;
171
172            this.update(&mut cx, |this, cx| this.reload(cx))?.await
173        })
174    }
175
176    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
177        let id = id.clone();
178        let database_future = ThreadsDatabase::global_future(cx);
179        cx.spawn(|this, mut cx| async move {
180            let database = database_future.await.map_err(|err| anyhow!(err))?;
181            database.delete_thread(id.clone()).await?;
182
183            this.update(&mut cx, |this, _cx| {
184                this.threads.retain(|thread| thread.id != id)
185            })
186        })
187    }
188
189    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
190        let database_future = ThreadsDatabase::global_future(cx);
191        cx.spawn(|this, mut cx| async move {
192            let threads = database_future
193                .await
194                .map_err(|err| anyhow!(err))?
195                .list_threads()
196                .await?;
197
198            this.update(&mut cx, |this, cx| {
199                this.threads = threads;
200                cx.notify();
201            })
202        })
203    }
204
205    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
206        cx.subscribe(
207            &self.context_server_manager.clone(),
208            Self::handle_context_server_event,
209        )
210        .detach();
211    }
212
213    fn handle_context_server_event(
214        &mut self,
215        context_server_manager: Entity<ContextServerManager>,
216        event: &context_server::manager::Event,
217        cx: &mut Context<Self>,
218    ) {
219        let tool_working_set = self.tools.clone();
220        match event {
221            context_server::manager::Event::ServerStarted { server_id } => {
222                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
223                    let context_server_manager = context_server_manager.clone();
224                    cx.spawn({
225                        let server = server.clone();
226                        let server_id = server_id.clone();
227                        |this, mut cx| async move {
228                            let Some(protocol) = server.client() else {
229                                return;
230                            };
231
232                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
233                                if let Some(tools) = protocol.list_tools().await.log_err() {
234                                    let tool_ids = tools
235                                        .tools
236                                        .into_iter()
237                                        .map(|tool| {
238                                            log::info!(
239                                                "registering context server tool: {:?}",
240                                                tool.name
241                                            );
242                                            tool_working_set.insert(Arc::new(
243                                                ContextServerTool::new(
244                                                    context_server_manager.clone(),
245                                                    server.id(),
246                                                    tool,
247                                                ),
248                                            ))
249                                        })
250                                        .collect::<Vec<_>>();
251
252                                    this.update(&mut cx, |this, _cx| {
253                                        this.context_server_tool_ids.insert(server_id, tool_ids);
254                                    })
255                                    .log_err();
256                                }
257                            }
258                        }
259                    })
260                    .detach();
261                }
262            }
263            context_server::manager::Event::ServerStopped { server_id } => {
264                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
265                    tool_working_set.remove(&tool_ids);
266                }
267            }
268        }
269    }
270}
271
/// Lightweight summary of a saved thread, as produced by `ThreadsDatabase::list_threads`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SavedThreadMetadata {
    /// Database key of the thread.
    pub id: ThreadId,
    /// Copied from the stored thread's `summary`.
    pub summary: SharedString,
    /// Copied from the stored thread's `updated_at`; `ThreadStore::threads` sorts by it,
    /// newest first.
    pub updated_at: DateTime<Utc>,
}
278
279#[derive(Serialize, Deserialize)]
280pub struct SavedThread {
281    pub summary: SharedString,
282    pub updated_at: DateTime<Utc>,
283    pub messages: Vec<SavedMessage>,
284}
285
/// Persisted form of a single message within a [`SavedThread`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    /// Regular and scripting tool uses attached to this message. `#[serde(default)]` makes a
    /// missing field deserialize as an empty list (presumably for threads saved before tool
    /// uses were recorded — confirm).
    #[serde(default)]
    pub tool_uses: Vec<SavedToolUse>,
    /// Regular and scripting tool results attached to this message; missing deserializes
    /// as an empty list, as above.
    #[serde(default)]
    pub tool_results: Vec<SavedToolResult>,
}
296
/// Persisted record of a tool invocation requested in a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedToolUse {
    pub id: LanguageModelToolUseId,
    /// Name of the invoked tool.
    pub name: SharedString,
    /// Tool input as arbitrary JSON.
    pub input: serde_json::Value,
}
303
/// Persisted record of a tool's result.
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedToolResult {
    /// ID of the tool use this result belongs to (presumably matches a `SavedToolUse::id`
    /// in the same thread — verify).
    pub tool_use_id: LanguageModelToolUseId,
    /// Copied from the live tool result's `is_error` flag.
    pub is_error: bool,
    pub content: Arc<str>,
}
310
311struct GlobalThreadsDatabase(
312    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
313);
314
315impl Global for GlobalThreadsDatabase {}
316
/// heed (LMDB) store of [`SavedThread`]s keyed by [`ThreadId`].
pub(crate) struct ThreadsDatabase {
    /// Executor on which every read/write transaction is spawned.
    executor: BackgroundExecutor,
    env: heed::Env,
    /// Keys are bincode-encoded [`ThreadId`]s; values are JSON-encoded [`SavedThread`]s.
    threads: Database<SerdeBincode<ThreadId>, SerdeJson<SavedThread>>,
}
322
323impl ThreadsDatabase {
324    fn global_future(
325        cx: &mut App,
326    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
327        GlobalThreadsDatabase::global(cx).0.clone()
328    }
329
330    fn init(cx: &mut App) {
331        let executor = cx.background_executor().clone();
332        let database_future = executor
333            .spawn({
334                let executor = executor.clone();
335                let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
336                async move { ThreadsDatabase::new(database_path, executor) }
337            })
338            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
339            .boxed()
340            .shared();
341
342        cx.set_global(GlobalThreadsDatabase(database_future));
343    }
344
345    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
346        std::fs::create_dir_all(&path)?;
347
348        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
349        let env = unsafe {
350            heed::EnvOpenOptions::new()
351                .map_size(ONE_GB_IN_BYTES)
352                .max_dbs(1)
353                .open(path)?
354        };
355
356        let mut txn = env.write_txn()?;
357        let threads = env.create_database(&mut txn, Some("threads"))?;
358        txn.commit()?;
359
360        Ok(Self {
361            executor,
362            env,
363            threads,
364        })
365    }
366
367    pub fn list_threads(&self) -> Task<Result<Vec<SavedThreadMetadata>>> {
368        let env = self.env.clone();
369        let threads = self.threads;
370
371        self.executor.spawn(async move {
372            let txn = env.read_txn()?;
373            let mut iter = threads.iter(&txn)?;
374            let mut threads = Vec::new();
375            while let Some((key, value)) = iter.next().transpose()? {
376                threads.push(SavedThreadMetadata {
377                    id: key,
378                    summary: value.summary,
379                    updated_at: value.updated_at,
380                });
381            }
382
383            Ok(threads)
384        })
385    }
386
387    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SavedThread>>> {
388        let env = self.env.clone();
389        let threads = self.threads;
390
391        self.executor.spawn(async move {
392            let txn = env.read_txn()?;
393            let thread = threads.get(&txn, &id)?;
394            Ok(thread)
395        })
396    }
397
398    pub fn save_thread(&self, id: ThreadId, thread: SavedThread) -> Task<Result<()>> {
399        let env = self.env.clone();
400        let threads = self.threads;
401
402        self.executor.spawn(async move {
403            let mut txn = env.write_txn()?;
404            threads.put(&mut txn, &id, &thread)?;
405            txn.commit()?;
406            Ok(())
407        })
408    }
409
410    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
411        let env = self.env.clone();
412        let threads = self.threads;
413
414        self.executor.spawn(async move {
415            let mut txn = env.write_txn()?;
416            threads.delete(&mut txn, &id)?;
417            txn.commit()?;
418            Ok(())
419        })
420    }
421}