thread_store.rs

  1use std::path::PathBuf;
  2use std::sync::Arc;
  3
  4use anyhow::{anyhow, Result};
  5use assistant_tool::{ToolId, ToolWorkingSet};
  6use chrono::{DateTime, Utc};
  7use collections::HashMap;
  8use context_server::manager::ContextServerManager;
  9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 10use futures::future::{self, BoxFuture, Shared};
 11use futures::FutureExt as _;
 12use gpui::{
 13    prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
 14};
 15use heed::types::{SerdeBincode, SerdeJson};
 16use heed::Database;
 17use language_model::{LanguageModelToolUseId, Role};
 18use project::Project;
 19use serde::{Deserialize, Serialize};
 20use util::ResultExt as _;
 21
 22use crate::thread::{MessageId, Thread, ThreadId};
 23
/// Initializes this module's globals.
///
/// Currently this only starts opening the threads database (see
/// `ThreadsDatabase::init`), which is stored as a gpui global so later
/// `ThreadStore` operations can await it.
pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}
 27
/// Per-project store of saved thread metadata.
///
/// It also watches a `ContextServerManager` and registers the tools exposed by
/// started context servers into the shared [`ToolWorkingSet`].
pub struct ThreadStore {
    // The field itself is never read by this impl (it is only used during
    // construction), so the lint is silenced rather than dropping the handle.
    #[allow(unused)]
    project: Entity<Project>,
    /// Tool set handed to every thread created or opened by this store;
    /// context-server tools are inserted into / removed from it.
    tools: Arc<ToolWorkingSet>,
    /// Manager whose `ServerStarted` / `ServerStopped` events drive tool registration.
    context_server_manager: Entity<ContextServerManager>,
    /// Tool ids registered for each context server id, so they can be removed
    /// from `tools` when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Cached metadata of saved threads, replaced wholesale by `reload`.
    threads: Vec<SavedThreadMetadata>,
}
 36
 37impl ThreadStore {
 38    pub fn new(
 39        project: Entity<Project>,
 40        tools: Arc<ToolWorkingSet>,
 41        cx: &mut App,
 42    ) -> Result<Entity<Self>> {
 43        let this = cx.new(|cx| {
 44            let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
 45            let context_server_manager = cx.new(|cx| {
 46                ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 47            });
 48
 49            let this = Self {
 50                project,
 51                tools,
 52                context_server_manager,
 53                context_server_tool_ids: HashMap::default(),
 54                threads: Vec::new(),
 55            };
 56            this.register_context_server_handlers(cx);
 57            this.reload(cx).detach_and_log_err(cx);
 58
 59            this
 60        });
 61
 62        Ok(this)
 63    }
 64
 65    /// Returns the number of threads.
 66    pub fn thread_count(&self) -> usize {
 67        self.threads.len()
 68    }
 69
 70    pub fn threads(&self) -> Vec<SavedThreadMetadata> {
 71        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
 72        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
 73        threads
 74    }
 75
 76    pub fn recent_threads(&self, limit: usize) -> Vec<SavedThreadMetadata> {
 77        self.threads().into_iter().take(limit).collect()
 78    }
 79
 80    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 81        cx.new(|cx| Thread::new(self.tools.clone(), cx))
 82    }
 83
 84    pub fn open_thread(
 85        &self,
 86        id: &ThreadId,
 87        cx: &mut Context<Self>,
 88    ) -> Task<Result<Entity<Thread>>> {
 89        let id = id.clone();
 90        let database_future = ThreadsDatabase::global_future(cx);
 91        cx.spawn(|this, mut cx| async move {
 92            let database = database_future.await.map_err(|err| anyhow!(err))?;
 93            let thread = database
 94                .try_find_thread(id.clone())
 95                .await?
 96                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
 97
 98            this.update(&mut cx, |this, cx| {
 99                cx.new(|cx| Thread::from_saved(id.clone(), thread, this.tools.clone(), cx))
100            })
101        })
102    }
103
104    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
105        let (metadata, thread) = thread.update(cx, |thread, _cx| {
106            let id = thread.id().clone();
107            let thread = SavedThread {
108                summary: thread.summary_or_default(),
109                updated_at: thread.updated_at(),
110                messages: thread
111                    .messages()
112                    .map(|message| SavedMessage {
113                        id: message.id,
114                        role: message.role,
115                        text: message.text.clone(),
116                        tool_uses: thread
117                            .tool_uses_for_message(message.id)
118                            .into_iter()
119                            .map(|tool_use| SavedToolUse {
120                                id: tool_use.id,
121                                name: tool_use.name,
122                                input: tool_use.input,
123                            })
124                            .collect(),
125                        tool_results: thread
126                            .tool_results_for_message(message.id)
127                            .into_iter()
128                            .map(|tool_result| SavedToolResult {
129                                tool_use_id: tool_result.tool_use_id.clone(),
130                                is_error: tool_result.is_error,
131                                content: tool_result.content.clone(),
132                            })
133                            .collect(),
134                    })
135                    .collect(),
136            };
137
138            (id, thread)
139        });
140
141        let database_future = ThreadsDatabase::global_future(cx);
142        cx.spawn(|this, mut cx| async move {
143            let database = database_future.await.map_err(|err| anyhow!(err))?;
144            database.save_thread(metadata, thread).await?;
145
146            this.update(&mut cx, |this, cx| this.reload(cx))?.await
147        })
148    }
149
150    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
151        let id = id.clone();
152        let database_future = ThreadsDatabase::global_future(cx);
153        cx.spawn(|this, mut cx| async move {
154            let database = database_future.await.map_err(|err| anyhow!(err))?;
155            database.delete_thread(id.clone()).await?;
156
157            this.update(&mut cx, |this, _cx| {
158                this.threads.retain(|thread| thread.id != id)
159            })
160        })
161    }
162
163    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
164        let database_future = ThreadsDatabase::global_future(cx);
165        cx.spawn(|this, mut cx| async move {
166            let threads = database_future
167                .await
168                .map_err(|err| anyhow!(err))?
169                .list_threads()
170                .await?;
171
172            this.update(&mut cx, |this, cx| {
173                this.threads = threads;
174                cx.notify();
175            })
176        })
177    }
178
179    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
180        cx.subscribe(
181            &self.context_server_manager.clone(),
182            Self::handle_context_server_event,
183        )
184        .detach();
185    }
186
187    fn handle_context_server_event(
188        &mut self,
189        context_server_manager: Entity<ContextServerManager>,
190        event: &context_server::manager::Event,
191        cx: &mut Context<Self>,
192    ) {
193        let tool_working_set = self.tools.clone();
194        match event {
195            context_server::manager::Event::ServerStarted { server_id } => {
196                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
197                    let context_server_manager = context_server_manager.clone();
198                    cx.spawn({
199                        let server = server.clone();
200                        let server_id = server_id.clone();
201                        |this, mut cx| async move {
202                            let Some(protocol) = server.client() else {
203                                return;
204                            };
205
206                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
207                                if let Some(tools) = protocol.list_tools().await.log_err() {
208                                    let tool_ids = tools
209                                        .tools
210                                        .into_iter()
211                                        .map(|tool| {
212                                            log::info!(
213                                                "registering context server tool: {:?}",
214                                                tool.name
215                                            );
216                                            tool_working_set.insert(Arc::new(
217                                                ContextServerTool::new(
218                                                    context_server_manager.clone(),
219                                                    server.id(),
220                                                    tool,
221                                                ),
222                                            ))
223                                        })
224                                        .collect::<Vec<_>>();
225
226                                    this.update(&mut cx, |this, _cx| {
227                                        this.context_server_tool_ids.insert(server_id, tool_ids);
228                                    })
229                                    .log_err();
230                                }
231                            }
232                        }
233                    })
234                    .detach();
235                }
236            }
237            context_server::manager::Event::ServerStopped { server_id } => {
238                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
239                    tool_working_set.remove(&tool_ids);
240                }
241            }
242        }
243    }
244}
245
/// Lightweight summary of a saved thread.
///
/// Built by `ThreadsDatabase::list_threads` and cached in `ThreadStore`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SavedThreadMetadata {
    /// Database key of the thread.
    pub id: ThreadId,
    /// The thread summary that was stored when the thread was saved.
    pub summary: SharedString,
    /// Last update time; `ThreadStore::threads` sorts on this, newest first.
    pub updated_at: DateTime<Utc>,
}
252
253#[derive(Serialize, Deserialize)]
254pub struct SavedThread {
255    pub summary: SharedString,
256    pub updated_at: DateTime<Utc>,
257    pub messages: Vec<SavedMessage>,
258}
259
/// A single persisted message of a [`SavedThread`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedMessage {
    /// Id of the message within its thread.
    pub id: MessageId,
    /// Who authored the message.
    pub role: Role,
    /// Message text.
    pub text: String,
    /// Tool invocations issued in this message.
    /// `serde(default)` lets records stored without this field deserialize with an empty list.
    #[serde(default)]
    pub tool_uses: Vec<SavedToolUse>,
    /// Tool results attached to this message.
    /// `serde(default)` lets records stored without this field deserialize with an empty list.
    #[serde(default)]
    pub tool_results: Vec<SavedToolResult>,
}
270
/// A persisted tool invocation made by the model.
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedToolUse {
    /// Id linking this use to its `SavedToolResult::tool_use_id`.
    pub id: LanguageModelToolUseId,
    /// Name of the invoked tool.
    pub name: SharedString,
    /// Arguments passed to the tool, as arbitrary JSON.
    pub input: serde_json::Value,
}
277
/// A persisted result of a tool invocation.
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedToolResult {
    /// Id of the `SavedToolUse` this result answers.
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the tool reported an error.
    pub is_error: bool,
    /// The tool's output (or error text when `is_error` is set).
    pub content: Arc<str>,
}
284
/// gpui global holding the pending (or completed) opening of the threads database.
///
/// The future is `Shared` so every caller awaits the same open. Its output is `Clone`
/// only because both the database and the error are wrapped in `Arc`.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
290
/// LMDB-backed (via heed) persistent storage of saved threads.
pub(crate) struct ThreadsDatabase {
    /// Executor that runs every database transaction off the caller's thread.
    executor: BackgroundExecutor,
    /// The opened heed environment.
    env: heed::Env,
    /// `threads` table: bincode-encoded `ThreadId` keys mapped to JSON-encoded `SavedThread` values.
    threads: Database<SerdeBincode<ThreadId>, SerdeJson<SavedThread>>,
}
296
297impl ThreadsDatabase {
298    fn global_future(
299        cx: &mut App,
300    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
301        GlobalThreadsDatabase::global(cx).0.clone()
302    }
303
304    fn init(cx: &mut App) {
305        let executor = cx.background_executor().clone();
306        let database_future = executor
307            .spawn({
308                let executor = executor.clone();
309                let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
310                async move { ThreadsDatabase::new(database_path, executor) }
311            })
312            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
313            .boxed()
314            .shared();
315
316        cx.set_global(GlobalThreadsDatabase(database_future));
317    }
318
319    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
320        std::fs::create_dir_all(&path)?;
321
322        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
323        let env = unsafe {
324            heed::EnvOpenOptions::new()
325                .map_size(ONE_GB_IN_BYTES)
326                .max_dbs(1)
327                .open(path)?
328        };
329
330        let mut txn = env.write_txn()?;
331        let threads = env.create_database(&mut txn, Some("threads"))?;
332        txn.commit()?;
333
334        Ok(Self {
335            executor,
336            env,
337            threads,
338        })
339    }
340
341    pub fn list_threads(&self) -> Task<Result<Vec<SavedThreadMetadata>>> {
342        let env = self.env.clone();
343        let threads = self.threads;
344
345        self.executor.spawn(async move {
346            let txn = env.read_txn()?;
347            let mut iter = threads.iter(&txn)?;
348            let mut threads = Vec::new();
349            while let Some((key, value)) = iter.next().transpose()? {
350                threads.push(SavedThreadMetadata {
351                    id: key,
352                    summary: value.summary,
353                    updated_at: value.updated_at,
354                });
355            }
356
357            Ok(threads)
358        })
359    }
360
361    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SavedThread>>> {
362        let env = self.env.clone();
363        let threads = self.threads;
364
365        self.executor.spawn(async move {
366            let txn = env.read_txn()?;
367            let thread = threads.get(&txn, &id)?;
368            Ok(thread)
369        })
370    }
371
372    pub fn save_thread(&self, id: ThreadId, thread: SavedThread) -> Task<Result<()>> {
373        let env = self.env.clone();
374        let threads = self.threads;
375
376        self.executor.spawn(async move {
377            let mut txn = env.write_txn()?;
378            threads.put(&mut txn, &id, &thread)?;
379            txn.commit()?;
380            Ok(())
381        })
382    }
383
384    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
385        let env = self.env.clone();
386        let threads = self.threads;
387
388        self.executor.spawn(async move {
389            let mut txn = env.write_txn()?;
390            threads.delete(&mut txn, &id)?;
391            txn.commit()?;
392            Ok(())
393        })
394    }
395}