thread_store.rs

  1use std::path::PathBuf;
  2use std::sync::Arc;
  3
  4use anyhow::{anyhow, Result};
  5use assistant_tool::{ToolId, ToolWorkingSet};
  6use chrono::{DateTime, Utc};
  7use collections::HashMap;
  8use context_server::manager::ContextServerManager;
  9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 10use futures::future::{self, BoxFuture, Shared};
 11use futures::FutureExt as _;
 12use gpui::{
 13    prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
 14};
 15use heed::types::{SerdeBincode, SerdeJson};
 16use heed::Database;
 17use language_model::{LanguageModelToolUseId, Role};
 18use project::Project;
 19use serde::{Deserialize, Serialize};
 20use util::ResultExt as _;
 21
 22use crate::thread::{MessageId, Thread, ThreadId};
 23
 24pub fn init(cx: &mut App) {
 25    ThreadsDatabase::init(cx);
 26}
 27
/// Tracks the saved threads of a project and the tools contributed by its context servers.
pub struct ThreadStore {
    /// Project handed to every thread this store creates or opens.
    project: Entity<Project>,
    /// Tool set shared with the threads this store creates or opens; context-server tools are
    /// inserted into and removed from it as servers start and stop.
    tools: Arc<ToolWorkingSet>,
    /// Manager whose start/stop events drive context-server tool registration.
    context_server_manager: Entity<ContextServerManager>,
    /// Tool ids registered for each context server id, kept so they can be removed on stop.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Cached metadata of the saved threads, replaced wholesale by `reload`.
    threads: Vec<SavedThreadMetadata>,
}
 35
 36impl ThreadStore {
 37    pub fn new(
 38        project: Entity<Project>,
 39        tools: Arc<ToolWorkingSet>,
 40        cx: &mut App,
 41    ) -> Result<Entity<Self>> {
 42        let this = cx.new(|cx| {
 43            let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
 44            let context_server_manager = cx.new(|cx| {
 45                ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 46            });
 47
 48            let this = Self {
 49                project,
 50                tools,
 51                context_server_manager,
 52                context_server_tool_ids: HashMap::default(),
 53                threads: Vec::new(),
 54            };
 55            this.register_context_server_handlers(cx);
 56            this.reload(cx).detach_and_log_err(cx);
 57
 58            this
 59        });
 60
 61        Ok(this)
 62    }
 63
 64    /// Returns the number of threads.
 65    pub fn thread_count(&self) -> usize {
 66        self.threads.len()
 67    }
 68
 69    pub fn threads(&self) -> Vec<SavedThreadMetadata> {
 70        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
 71        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
 72        threads
 73    }
 74
 75    pub fn recent_threads(&self, limit: usize) -> Vec<SavedThreadMetadata> {
 76        self.threads().into_iter().take(limit).collect()
 77    }
 78
 79    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 80        cx.new(|cx| Thread::new(self.project.clone(), self.tools.clone(), cx))
 81    }
 82
 83    pub fn open_thread(
 84        &self,
 85        id: &ThreadId,
 86        cx: &mut Context<Self>,
 87    ) -> Task<Result<Entity<Thread>>> {
 88        let id = id.clone();
 89        let database_future = ThreadsDatabase::global_future(cx);
 90        cx.spawn(|this, mut cx| async move {
 91            let database = database_future.await.map_err(|err| anyhow!(err))?;
 92            let thread = database
 93                .try_find_thread(id.clone())
 94                .await?
 95                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
 96
 97            this.update(&mut cx, |this, cx| {
 98                cx.new(|cx| {
 99                    Thread::from_saved(
100                        id.clone(),
101                        thread,
102                        this.project.clone(),
103                        this.tools.clone(),
104                        cx,
105                    )
106                })
107            })
108        })
109    }
110
111    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
112        let (metadata, thread) = thread.update(cx, |thread, _cx| {
113            let id = thread.id().clone();
114            let thread = SavedThread {
115                summary: thread.summary_or_default(),
116                updated_at: thread.updated_at(),
117                messages: thread
118                    .messages()
119                    .map(|message| {
120                        let all_tool_uses = thread
121                            .tool_uses_for_message(message.id)
122                            .into_iter()
123                            .chain(thread.scripting_tool_uses_for_message(message.id))
124                            .map(|tool_use| SavedToolUse {
125                                id: tool_use.id,
126                                name: tool_use.name,
127                                input: tool_use.input,
128                            })
129                            .collect();
130                        let all_tool_results = thread
131                            .tool_results_for_message(message.id)
132                            .into_iter()
133                            .chain(thread.scripting_tool_results_for_message(message.id))
134                            .map(|tool_result| SavedToolResult {
135                                tool_use_id: tool_result.tool_use_id.clone(),
136                                is_error: tool_result.is_error,
137                                content: tool_result.content.clone(),
138                            })
139                            .collect();
140
141                        SavedMessage {
142                            id: message.id,
143                            role: message.role,
144                            text: message.text.clone(),
145                            tool_uses: all_tool_uses,
146                            tool_results: all_tool_results,
147                        }
148                    })
149                    .collect(),
150            };
151
152            (id, thread)
153        });
154
155        let database_future = ThreadsDatabase::global_future(cx);
156        cx.spawn(|this, mut cx| async move {
157            let database = database_future.await.map_err(|err| anyhow!(err))?;
158            database.save_thread(metadata, thread).await?;
159
160            this.update(&mut cx, |this, cx| this.reload(cx))?.await
161        })
162    }
163
164    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
165        let id = id.clone();
166        let database_future = ThreadsDatabase::global_future(cx);
167        cx.spawn(|this, mut cx| async move {
168            let database = database_future.await.map_err(|err| anyhow!(err))?;
169            database.delete_thread(id.clone()).await?;
170
171            this.update(&mut cx, |this, _cx| {
172                this.threads.retain(|thread| thread.id != id)
173            })
174        })
175    }
176
177    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
178        let database_future = ThreadsDatabase::global_future(cx);
179        cx.spawn(|this, mut cx| async move {
180            let threads = database_future
181                .await
182                .map_err(|err| anyhow!(err))?
183                .list_threads()
184                .await?;
185
186            this.update(&mut cx, |this, cx| {
187                this.threads = threads;
188                cx.notify();
189            })
190        })
191    }
192
193    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
194        cx.subscribe(
195            &self.context_server_manager.clone(),
196            Self::handle_context_server_event,
197        )
198        .detach();
199    }
200
201    fn handle_context_server_event(
202        &mut self,
203        context_server_manager: Entity<ContextServerManager>,
204        event: &context_server::manager::Event,
205        cx: &mut Context<Self>,
206    ) {
207        let tool_working_set = self.tools.clone();
208        match event {
209            context_server::manager::Event::ServerStarted { server_id } => {
210                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
211                    let context_server_manager = context_server_manager.clone();
212                    cx.spawn({
213                        let server = server.clone();
214                        let server_id = server_id.clone();
215                        |this, mut cx| async move {
216                            let Some(protocol) = server.client() else {
217                                return;
218                            };
219
220                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
221                                if let Some(tools) = protocol.list_tools().await.log_err() {
222                                    let tool_ids = tools
223                                        .tools
224                                        .into_iter()
225                                        .map(|tool| {
226                                            log::info!(
227                                                "registering context server tool: {:?}",
228                                                tool.name
229                                            );
230                                            tool_working_set.insert(Arc::new(
231                                                ContextServerTool::new(
232                                                    context_server_manager.clone(),
233                                                    server.id(),
234                                                    tool,
235                                                ),
236                                            ))
237                                        })
238                                        .collect::<Vec<_>>();
239
240                                    this.update(&mut cx, |this, _cx| {
241                                        this.context_server_tool_ids.insert(server_id, tool_ids);
242                                    })
243                                    .log_err();
244                                }
245                            }
246                        }
247                    })
248                    .detach();
249                }
250            }
251            context_server::manager::Event::ServerStopped { server_id } => {
252                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
253                    tool_working_set.remove(&tool_ids);
254                }
255            }
256        }
257    }
258}
259
/// Lightweight listing entry for a saved thread: its id, summary, and last-update time.
///
/// Built by `ThreadsDatabase::list_threads` from the stored [`SavedThread`] without keeping
/// its messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SavedThreadMetadata {
    /// Key of the thread in the threads database.
    pub id: ThreadId,
    /// Summary copied from the saved thread.
    pub summary: SharedString,
    /// Last-updated timestamp. `ThreadStore::threads` sorts by this, newest first.
    pub updated_at: DateTime<Utc>,
}
266
267#[derive(Serialize, Deserialize)]
268pub struct SavedThread {
269    pub summary: SharedString,
270    pub updated_at: DateTime<Utc>,
271    pub messages: Vec<SavedMessage>,
272}
273
/// One persisted message of a [`SavedThread`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedMessage {
    /// Id of the message within its thread.
    pub id: MessageId,
    /// Role of the message author.
    pub role: Role,
    /// Message text.
    pub text: String,
    /// Tool uses issued in this message. When saved, regular and scripting tool uses are
    /// combined into this one list.
    ///
    /// `#[serde(default)]` lets records without this field deserialize with an empty list
    /// (presumably data written before tool uses were persisted — TODO confirm).
    #[serde(default)]
    pub tool_uses: Vec<SavedToolUse>,
    /// Tool results attached to this message, combined like `tool_uses`. Defaults to empty
    /// when the field is absent.
    #[serde(default)]
    pub tool_results: Vec<SavedToolResult>,
}
284
/// A persisted tool use from a [`SavedMessage`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedToolUse {
    /// Id of this tool use (presumably what `SavedToolResult::tool_use_id` refers to).
    pub id: LanguageModelToolUseId,
    /// Name of the invoked tool.
    pub name: SharedString,
    /// Tool input as an arbitrary JSON value.
    pub input: serde_json::Value,
}
291
/// A persisted tool result from a [`SavedMessage`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedToolResult {
    /// Id of the tool use this result belongs to (presumably a `SavedToolUse::id`).
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the result was flagged as an error.
    pub is_error: bool,
    /// Result content as text.
    pub content: Arc<str>,
}
298
/// gpui global holding the future that resolves to the process-wide [`ThreadsDatabase`].
///
/// The future is set by `ThreadsDatabase::init`. Because it is `Shared`, every caller awaits
/// the same open attempt. The error is wrapped in `Arc` because a `Shared` future's output must
/// be `Clone`.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

// Marker impl that allows `GlobalThreadsDatabase` to be stored with `cx.set_global`.
impl Global for GlobalThreadsDatabase {}
304
/// Persistent storage for saved threads, backed by a `heed` (LMDB) environment.
pub(crate) struct ThreadsDatabase {
    /// Background executor that the query methods spawn their transactions on.
    executor: BackgroundExecutor,
    /// The opened heed environment.
    env: heed::Env,
    /// The `threads` database: bincode-encoded [`ThreadId`] keys, JSON-encoded [`SavedThread`]
    /// values.
    threads: Database<SerdeBincode<ThreadId>, SerdeJson<SavedThread>>,
}
310
311impl ThreadsDatabase {
312    fn global_future(
313        cx: &mut App,
314    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
315        GlobalThreadsDatabase::global(cx).0.clone()
316    }
317
318    fn init(cx: &mut App) {
319        let executor = cx.background_executor().clone();
320        let database_future = executor
321            .spawn({
322                let executor = executor.clone();
323                let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
324                async move { ThreadsDatabase::new(database_path, executor) }
325            })
326            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
327            .boxed()
328            .shared();
329
330        cx.set_global(GlobalThreadsDatabase(database_future));
331    }
332
333    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
334        std::fs::create_dir_all(&path)?;
335
336        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
337        let env = unsafe {
338            heed::EnvOpenOptions::new()
339                .map_size(ONE_GB_IN_BYTES)
340                .max_dbs(1)
341                .open(path)?
342        };
343
344        let mut txn = env.write_txn()?;
345        let threads = env.create_database(&mut txn, Some("threads"))?;
346        txn.commit()?;
347
348        Ok(Self {
349            executor,
350            env,
351            threads,
352        })
353    }
354
355    pub fn list_threads(&self) -> Task<Result<Vec<SavedThreadMetadata>>> {
356        let env = self.env.clone();
357        let threads = self.threads;
358
359        self.executor.spawn(async move {
360            let txn = env.read_txn()?;
361            let mut iter = threads.iter(&txn)?;
362            let mut threads = Vec::new();
363            while let Some((key, value)) = iter.next().transpose()? {
364                threads.push(SavedThreadMetadata {
365                    id: key,
366                    summary: value.summary,
367                    updated_at: value.updated_at,
368                });
369            }
370
371            Ok(threads)
372        })
373    }
374
375    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SavedThread>>> {
376        let env = self.env.clone();
377        let threads = self.threads;
378
379        self.executor.spawn(async move {
380            let txn = env.read_txn()?;
381            let thread = threads.get(&txn, &id)?;
382            Ok(thread)
383        })
384    }
385
386    pub fn save_thread(&self, id: ThreadId, thread: SavedThread) -> Task<Result<()>> {
387        let env = self.env.clone();
388        let threads = self.threads;
389
390        self.executor.spawn(async move {
391            let mut txn = env.write_txn()?;
392            threads.put(&mut txn, &id, &thread)?;
393            txn.commit()?;
394            Ok(())
395        })
396    }
397
398    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
399        let env = self.env.clone();
400        let threads = self.threads;
401
402        self.executor.spawn(async move {
403            let mut txn = env.write_txn()?;
404            threads.delete(&mut txn, &id)?;
405            txn.commit()?;
406            Ok(())
407        })
408    }
409}