thread_store.rs

  1use std::borrow::Cow;
  2use std::path::PathBuf;
  3use std::sync::Arc;
  4
  5use anyhow::{anyhow, Result};
  6use assistant_tool::{ToolId, ToolWorkingSet};
  7use chrono::{DateTime, Utc};
  8use collections::HashMap;
  9use context_server::manager::ContextServerManager;
 10use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 11use futures::future::{self, BoxFuture, Shared};
 12use futures::FutureExt as _;
 13use gpui::{
 14    prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
 15};
 16use heed::types::SerdeBincode;
 17use heed::Database;
 18use language_model::{LanguageModelToolUseId, Role};
 19use project::Project;
 20use prompt_store::PromptBuilder;
 21use serde::{Deserialize, Serialize};
 22use util::ResultExt as _;
 23
 24use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId};
 25
 26pub fn init(cx: &mut App) {
 27    ThreadsDatabase::init(cx);
 28}
 29
/// Owns the list of persisted conversation threads for a project, creates/opens/saves
/// threads, and keeps context-server tools registered in the shared tool set.
pub struct ThreadStore {
    project: Entity<Project>,
    /// Tool set handed to every thread; context-server tools are inserted into and
    /// removed from it as servers start and stop.
    tools: Arc<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    context_server_manager: Entity<ContextServerManager>,
    /// Tool IDs registered for each started context server, keyed by server ID, so they
    /// can be removed from `tools` when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Cached metadata of the threads in the database, replaced by `reload`.
    threads: Vec<SerializedThreadMetadata>,
}
 38
 39impl ThreadStore {
 40    pub fn new(
 41        project: Entity<Project>,
 42        tools: Arc<ToolWorkingSet>,
 43        prompt_builder: Arc<PromptBuilder>,
 44        cx: &mut App,
 45    ) -> Result<Entity<Self>> {
 46        let this = cx.new(|cx| {
 47            let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
 48            let context_server_manager = cx.new(|cx| {
 49                ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 50            });
 51
 52            let this = Self {
 53                project,
 54                tools,
 55                prompt_builder,
 56                context_server_manager,
 57                context_server_tool_ids: HashMap::default(),
 58                threads: Vec::new(),
 59            };
 60            this.register_context_server_handlers(cx);
 61            this.reload(cx).detach_and_log_err(cx);
 62
 63            this
 64        });
 65
 66        Ok(this)
 67    }
 68
 69    pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
 70        self.context_server_manager.clone()
 71    }
 72
 73    pub fn tools(&self) -> Arc<ToolWorkingSet> {
 74        self.tools.clone()
 75    }
 76
 77    /// Returns the number of threads.
 78    pub fn thread_count(&self) -> usize {
 79        self.threads.len()
 80    }
 81
 82    pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
 83        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
 84        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
 85        threads
 86    }
 87
 88    pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
 89        self.threads().into_iter().take(limit).collect()
 90    }
 91
 92    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 93        cx.new(|cx| {
 94            Thread::new(
 95                self.project.clone(),
 96                self.tools.clone(),
 97                self.prompt_builder.clone(),
 98                cx,
 99            )
100        })
101    }
102
103    pub fn open_thread(
104        &self,
105        id: &ThreadId,
106        cx: &mut Context<Self>,
107    ) -> Task<Result<Entity<Thread>>> {
108        let id = id.clone();
109        let database_future = ThreadsDatabase::global_future(cx);
110        cx.spawn(async move |this, cx| {
111            let database = database_future.await.map_err(|err| anyhow!(err))?;
112            let thread = database
113                .try_find_thread(id.clone())
114                .await?
115                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
116
117            let thread = this.update(cx, |this, cx| {
118                cx.new(|cx| {
119                    Thread::deserialize(
120                        id.clone(),
121                        thread,
122                        this.project.clone(),
123                        this.tools.clone(),
124                        this.prompt_builder.clone(),
125                        cx,
126                    )
127                })
128            })?;
129
130            let (system_prompt_context, load_error) = thread
131                .update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
132                .await;
133            thread.update(cx, |thread, cx| {
134                thread.set_system_prompt_context(system_prompt_context);
135                if let Some(load_error) = load_error {
136                    cx.emit(ThreadEvent::ShowError(load_error));
137                }
138            })?;
139
140            Ok(thread)
141        })
142    }
143
144    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
145        let (metadata, serialized_thread) =
146            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
147
148        let database_future = ThreadsDatabase::global_future(cx);
149        cx.spawn(async move |this, cx| {
150            let serialized_thread = serialized_thread.await?;
151            let database = database_future.await.map_err(|err| anyhow!(err))?;
152            database.save_thread(metadata, serialized_thread).await?;
153
154            this.update(cx, |this, cx| this.reload(cx))?.await
155        })
156    }
157
158    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
159        let id = id.clone();
160        let database_future = ThreadsDatabase::global_future(cx);
161        cx.spawn(async move |this, cx| {
162            let database = database_future.await.map_err(|err| anyhow!(err))?;
163            database.delete_thread(id.clone()).await?;
164
165            this.update(cx, |this, _cx| {
166                this.threads.retain(|thread| thread.id != id)
167            })
168        })
169    }
170
171    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
172        let database_future = ThreadsDatabase::global_future(cx);
173        cx.spawn(async move |this, cx| {
174            let threads = database_future
175                .await
176                .map_err(|err| anyhow!(err))?
177                .list_threads()
178                .await?;
179
180            this.update(cx, |this, cx| {
181                this.threads = threads;
182                cx.notify();
183            })
184        })
185    }
186
187    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
188        cx.subscribe(
189            &self.context_server_manager.clone(),
190            Self::handle_context_server_event,
191        )
192        .detach();
193    }
194
195    fn handle_context_server_event(
196        &mut self,
197        context_server_manager: Entity<ContextServerManager>,
198        event: &context_server::manager::Event,
199        cx: &mut Context<Self>,
200    ) {
201        let tool_working_set = self.tools.clone();
202        match event {
203            context_server::manager::Event::ServerStarted { server_id } => {
204                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
205                    let context_server_manager = context_server_manager.clone();
206                    cx.spawn({
207                        let server = server.clone();
208                        let server_id = server_id.clone();
209                        async move |this, cx| {
210                            let Some(protocol) = server.client() else {
211                                return;
212                            };
213
214                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
215                                if let Some(tools) = protocol.list_tools().await.log_err() {
216                                    let tool_ids = tools
217                                        .tools
218                                        .into_iter()
219                                        .map(|tool| {
220                                            log::info!(
221                                                "registering context server tool: {:?}",
222                                                tool.name
223                                            );
224                                            tool_working_set.insert(Arc::new(
225                                                ContextServerTool::new(
226                                                    context_server_manager.clone(),
227                                                    server.id(),
228                                                    tool,
229                                                ),
230                                            ))
231                                        })
232                                        .collect::<Vec<_>>();
233
234                                    this.update(cx, |this, _cx| {
235                                        this.context_server_tool_ids.insert(server_id, tool_ids);
236                                    })
237                                    .log_err();
238                                }
239                            }
240                        }
241                    })
242                    .detach();
243                }
244            }
245            context_server::manager::Event::ServerStopped { server_id } => {
246                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
247                    tool_working_set.remove(&tool_ids);
248                }
249            }
250        }
251    }
252}
253
/// Lightweight summary of a stored thread, built by `ThreadsDatabase::list_threads` from
/// each [`SerializedThread`] and cached by [`ThreadStore`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Database key of the thread.
    pub id: ThreadId,
    /// Copied from [`SerializedThread::summary`].
    pub summary: SharedString,
    /// Copied from [`SerializedThread::updated_at`]; `ThreadStore::threads` sorts by this,
    /// newest first.
    pub updated_at: DateTime<Utc>,
}
260
/// Persisted form of a thread, stored as JSON in the threads database.
///
/// Reads should go through [`SerializedThread::from_json`], which also accepts and
/// upgrades the legacy (unversioned) format.
#[derive(Serialize, Deserialize)]
pub struct SerializedThread {
    /// Format version; this code writes [`SerializedThread::VERSION`].
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    /// Optional project snapshot; a missing key deserializes as `None`.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
270
271impl SerializedThread {
272    pub const VERSION: &'static str = "0.1.0";
273
274    pub fn from_json(json: &[u8]) -> Result<Self> {
275        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
276        match saved_thread_json.get("version") {
277            Some(serde_json::Value::String(version)) => match version.as_str() {
278                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
279                    saved_thread_json,
280                )?),
281                _ => Err(anyhow!(
282                    "unrecognized serialized thread version: {}",
283                    version
284                )),
285            },
286            None => {
287                let saved_thread =
288                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
289                Ok(saved_thread.upgrade())
290            }
291            version => Err(anyhow!(
292                "unrecognized serialized thread version: {:?}",
293                version
294            )),
295        }
296    }
297}
298
/// A single message within a [`SerializedThread`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Message content; a missing key deserializes as an empty list.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    /// Tool invocations in this message; a missing key deserializes as an empty list.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Tool results in this message; a missing key deserializes as an empty list.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
310
/// One piece of a message's content, serialized as an internally tagged object such as
/// `{"type": "text", "text": "..."}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Regular message text (`"type": "text"`).
    #[serde(rename = "text")]
    Text { text: String },
    /// Text tagged `"type": "thinking"`; presumably the model's reasoning output — confirm
    /// against `Thread::serialize`.
    #[serde(rename = "thinking")]
    Thinking { text: String },
}
319
/// A tool invocation recorded in a [`SerializedMessage`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    /// Name of the invoked tool.
    pub name: SharedString,
    /// Tool input, stored as arbitrary JSON.
    pub input: serde_json::Value,
}
326
/// A tool result recorded in a [`SerializedMessage`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    /// Presumably the `id` of the [`SerializedToolUse`] this result answers — confirm.
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the result represents an error.
    pub is_error: bool,
    /// Result content as text.
    pub content: Arc<str>,
}
333
/// Thread format that predates the `version` field. `SerializedThread::from_json` parses
/// JSON without a `version` key as this type and upgrades it.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    /// A missing key deserializes as `None`.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
342
343impl LegacySerializedThread {
344    pub fn upgrade(self) -> SerializedThread {
345        SerializedThread {
346            version: SerializedThread::VERSION.to_string(),
347            summary: self.summary,
348            updated_at: self.updated_at,
349            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
350            initial_project_snapshot: self.initial_project_snapshot,
351        }
352    }
353}
354
/// Message format used by [`LegacySerializedThread`]: a single plain `text` field instead
/// of a list of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    /// A missing key deserializes as an empty list.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// A missing key deserializes as an empty list.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
365
366impl LegacySerializedMessage {
367    fn upgrade(self) -> SerializedMessage {
368        SerializedMessage {
369            id: self.id,
370            role: self.role,
371            segments: vec![SerializedMessageSegment::Text { text: self.text }],
372            tool_uses: self.tool_uses,
373            tool_results: self.tool_results,
374        }
375    }
376}
377
/// gpui global holding the shared future that opens the threads database; installed by
/// `ThreadsDatabase::init`.
///
/// The error is wrapped in `Arc` because a `Shared` future's output must be `Clone`,
/// which `anyhow::Error` is not.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
383
/// heed (LMDB) store of serialized threads keyed by [`ThreadId`].
pub(crate) struct ThreadsDatabase {
    /// Executor on which every database transaction runs.
    executor: BackgroundExecutor,
    env: heed::Env,
    /// Table mapping bincode-encoded thread IDs to JSON-encoded [`SerializedThread`]s.
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
389
390impl heed::BytesEncode<'_> for SerializedThread {
391    type EItem = SerializedThread;
392
393    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
394        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
395    }
396}
397
398impl<'a> heed::BytesDecode<'a> for SerializedThread {
399    type DItem = SerializedThread;
400
401    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
402        // We implement this type manually because we want to call `SerializedThread::from_json`,
403        // instead of the Deserialize trait implementation for `SerializedThread`.
404        SerializedThread::from_json(bytes).map_err(Into::into)
405    }
406}
407
408impl ThreadsDatabase {
409    fn global_future(
410        cx: &mut App,
411    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
412        GlobalThreadsDatabase::global(cx).0.clone()
413    }
414
415    fn init(cx: &mut App) {
416        let executor = cx.background_executor().clone();
417        let database_future = executor
418            .spawn({
419                let executor = executor.clone();
420                let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
421                async move { ThreadsDatabase::new(database_path, executor) }
422            })
423            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
424            .boxed()
425            .shared();
426
427        cx.set_global(GlobalThreadsDatabase(database_future));
428    }
429
430    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
431        std::fs::create_dir_all(&path)?;
432
433        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
434        let env = unsafe {
435            heed::EnvOpenOptions::new()
436                .map_size(ONE_GB_IN_BYTES)
437                .max_dbs(1)
438                .open(path)?
439        };
440
441        let mut txn = env.write_txn()?;
442        let threads = env.create_database(&mut txn, Some("threads"))?;
443        txn.commit()?;
444
445        Ok(Self {
446            executor,
447            env,
448            threads,
449        })
450    }
451
452    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
453        let env = self.env.clone();
454        let threads = self.threads;
455
456        self.executor.spawn(async move {
457            let txn = env.read_txn()?;
458            let mut iter = threads.iter(&txn)?;
459            let mut threads = Vec::new();
460            while let Some((key, value)) = iter.next().transpose()? {
461                threads.push(SerializedThreadMetadata {
462                    id: key,
463                    summary: value.summary,
464                    updated_at: value.updated_at,
465                });
466            }
467
468            Ok(threads)
469        })
470    }
471
472    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
473        let env = self.env.clone();
474        let threads = self.threads;
475
476        self.executor.spawn(async move {
477            let txn = env.read_txn()?;
478            let thread = threads.get(&txn, &id)?;
479            Ok(thread)
480        })
481    }
482
483    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
484        let env = self.env.clone();
485        let threads = self.threads;
486
487        self.executor.spawn(async move {
488            let mut txn = env.write_txn()?;
489            threads.put(&mut txn, &id, &thread)?;
490            txn.commit()?;
491            Ok(())
492        })
493    }
494
495    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
496        let env = self.env.clone();
497        let threads = self.threads;
498
499        self.executor.spawn(async move {
500            let mut txn = env.write_txn()?;
501            threads.delete(&mut txn, &id)?;
502            txn.commit()?;
503            Ok(())
504        })
505    }
506}