thread_store.rs

  1use std::borrow::Cow;
  2use std::cell::{Ref, RefCell};
  3use std::path::{Path, PathBuf};
  4use std::rc::Rc;
  5use std::sync::Arc;
  6
  7use anyhow::{Context as _, Result, anyhow};
  8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
  9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
 10use chrono::{DateTime, Utc};
 11use collections::HashMap;
 12use context_server::manager::ContextServerManager;
 13use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 14use fs::Fs;
 15use futures::channel::{mpsc, oneshot};
 16use futures::future::{self, BoxFuture, Shared};
 17use futures::{FutureExt as _, StreamExt as _};
 18use gpui::{
 19    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
 20    Subscription, Task, prelude::*,
 21};
 22use heed::Database;
 23use heed::types::SerdeBincode;
 24use language_model::{LanguageModelToolUseId, Role, TokenUsage};
 25use project::{Project, Worktree};
 26use prompt_store::{
 27    DefaultUserRulesContext, ProjectContext, PromptBuilder, PromptStore, PromptsUpdatedEvent,
 28    RulesFileContext, WorktreeContext,
 29};
 30use serde::{Deserialize, Serialize};
 31use settings::{Settings as _, SettingsStore};
 32use util::ResultExt as _;
 33
 34use crate::thread::{
 35    DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
 36};
 37
 38const RULES_FILE_NAMES: [&'static str; 6] = [
 39    ".rules",
 40    ".cursorrules",
 41    ".windsurfrules",
 42    ".clinerules",
 43    ".github/copilot-instructions.md",
 44    "CLAUDE.md",
 45];
 46
 47pub fn init(cx: &mut App) {
 48    ThreadsDatabase::init(cx);
 49}
 50
 51/// A system prompt shared by all threads created by this ThreadStore
 52#[derive(Clone, Default)]
 53pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
 54
 55impl SharedProjectContext {
 56    pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
 57        self.0.borrow()
 58    }
 59}
 60
 61pub struct ThreadStore {
 62    project: Entity<Project>,
 63    tools: Entity<ToolWorkingSet>,
 64    prompt_builder: Arc<PromptBuilder>,
 65    context_server_manager: Entity<ContextServerManager>,
 66    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
 67    threads: Vec<SerializedThreadMetadata>,
 68    project_context: SharedProjectContext,
 69    reload_system_prompt_tx: mpsc::Sender<()>,
 70    _reload_system_prompt_task: Task<()>,
 71    _subscriptions: Vec<Subscription>,
 72}
 73
 74pub struct RulesLoadingError {
 75    pub message: SharedString,
 76}
 77
 78impl EventEmitter<RulesLoadingError> for ThreadStore {}
 79
 80impl ThreadStore {
 81    pub fn load(
 82        project: Entity<Project>,
 83        tools: Entity<ToolWorkingSet>,
 84        prompt_builder: Arc<PromptBuilder>,
 85        cx: &mut App,
 86    ) -> Task<Result<Entity<Self>>> {
 87        let prompt_store = PromptStore::global(cx);
 88        cx.spawn(async move |cx| {
 89            let prompt_store = prompt_store.await.ok();
 90            let (thread_store, ready_rx) = cx.update(|cx| {
 91                let mut option_ready_rx = None;
 92                let thread_store = cx.new(|cx| {
 93                    let (thread_store, ready_rx) =
 94                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 95                    option_ready_rx = Some(ready_rx);
 96                    thread_store
 97                });
 98                (thread_store, option_ready_rx.take().unwrap())
 99            })?;
100            ready_rx.await?;
101            Ok(thread_store)
102        })
103    }
104
105    fn new(
106        project: Entity<Project>,
107        tools: Entity<ToolWorkingSet>,
108        prompt_builder: Arc<PromptBuilder>,
109        prompt_store: Option<Entity<PromptStore>>,
110        cx: &mut Context<Self>,
111    ) -> (Self, oneshot::Receiver<()>) {
112        let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
113        let context_server_manager = cx.new(|cx| {
114            ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
115        });
116
117        let mut subscriptions = vec![
118            cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
119                this.load_default_profile(cx);
120            }),
121            cx.subscribe(&project, Self::handle_project_event),
122        ];
123
124        if let Some(prompt_store) = prompt_store.as_ref() {
125            subscriptions.push(cx.subscribe(
126                prompt_store,
127                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
128                    this.enqueue_system_prompt_reload();
129                },
130            ))
131        }
132
133        // This channel and task prevent concurrent and redundant loading of the system prompt.
134        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
135        let (ready_tx, ready_rx) = oneshot::channel();
136        let mut ready_tx = Some(ready_tx);
137        let reload_system_prompt_task = cx.spawn({
138            async move |thread_store, cx| {
139                loop {
140                    let Some(reload_task) = thread_store
141                        .update(cx, |thread_store, cx| {
142                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
143                        })
144                        .ok()
145                    else {
146                        return;
147                    };
148                    reload_task.await;
149                    if let Some(ready_tx) = ready_tx.take() {
150                        ready_tx.send(()).ok();
151                    }
152                    reload_system_prompt_rx.next().await;
153                }
154            }
155        });
156
157        let this = Self {
158            project,
159            tools,
160            prompt_builder,
161            context_server_manager,
162            context_server_tool_ids: HashMap::default(),
163            threads: Vec::new(),
164            project_context: SharedProjectContext::default(),
165            reload_system_prompt_tx,
166            _reload_system_prompt_task: reload_system_prompt_task,
167            _subscriptions: subscriptions,
168        };
169        this.load_default_profile(cx);
170        this.register_context_server_handlers(cx);
171        this.reload(cx).detach_and_log_err(cx);
172        (this, ready_rx)
173    }
174
175    fn handle_project_event(
176        &mut self,
177        _project: Entity<Project>,
178        event: &project::Event,
179        _cx: &mut Context<Self>,
180    ) {
181        match event {
182            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
183                self.enqueue_system_prompt_reload();
184            }
185            project::Event::WorktreeUpdatedEntries(_, items) => {
186                if items.iter().any(|(path, _, _)| {
187                    RULES_FILE_NAMES
188                        .iter()
189                        .any(|name| path.as_ref() == Path::new(name))
190                }) {
191                    self.enqueue_system_prompt_reload();
192                }
193            }
194            _ => {}
195        }
196    }
197
198    fn enqueue_system_prompt_reload(&mut self) {
199        self.reload_system_prompt_tx.try_send(()).ok();
200    }
201
202    // Note that this should only be called from `reload_system_prompt_task`.
203    fn reload_system_prompt(
204        &self,
205        prompt_store: Option<Entity<PromptStore>>,
206        cx: &mut Context<Self>,
207    ) -> Task<()> {
208        let project = self.project.read(cx);
209        let worktree_tasks = project
210            .visible_worktrees(cx)
211            .map(|worktree| {
212                Self::load_worktree_info_for_system_prompt(
213                    project.fs().clone(),
214                    worktree.read(cx),
215                    cx,
216                )
217            })
218            .collect::<Vec<_>>();
219        let default_user_rules_task = match prompt_store {
220            None => Task::ready(vec![]),
221            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
222                let prompts = prompt_store.default_prompt_metadata();
223                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
224                    let contents = prompt_store.load(prompt_metadata.id, cx);
225                    async move { (contents.await, prompt_metadata) }
226                });
227                cx.background_spawn(future::join_all(load_tasks))
228            }),
229        };
230
231        cx.spawn(async move |this, cx| {
232            let (worktrees, default_user_rules) =
233                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
234
235            let worktrees = worktrees
236                .into_iter()
237                .map(|(worktree, rules_error)| {
238                    if let Some(rules_error) = rules_error {
239                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
240                    }
241                    worktree
242                })
243                .collect::<Vec<_>>();
244
245            let default_user_rules = default_user_rules
246                .into_iter()
247                .flat_map(|(contents, prompt_metadata)| match contents {
248                    Ok(contents) => Some(DefaultUserRulesContext {
249                        title: prompt_metadata.title.map(|title| title.to_string()),
250                        contents,
251                    }),
252                    Err(err) => {
253                        this.update(cx, |_, cx| {
254                            cx.emit(RulesLoadingError {
255                                message: format!("{err:?}").into(),
256                            });
257                        })
258                        .ok();
259                        None
260                    }
261                })
262                .collect::<Vec<_>>();
263
264            this.update(cx, |this, _cx| {
265                *this.project_context.0.borrow_mut() =
266                    Some(ProjectContext::new(worktrees, default_user_rules));
267            })
268            .ok();
269        })
270    }
271
272    fn load_worktree_info_for_system_prompt(
273        fs: Arc<dyn Fs>,
274        worktree: &Worktree,
275        cx: &App,
276    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
277        let root_name = worktree.root_name().into();
278        let abs_path = worktree.abs_path();
279
280        let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
281        let Some(rules_task) = rules_task else {
282            return Task::ready((
283                WorktreeContext {
284                    root_name,
285                    abs_path,
286                    rules_file: None,
287                },
288                None,
289            ));
290        };
291
292        cx.spawn(async move |_| {
293            let (rules_file, rules_file_error) = match rules_task.await {
294                Ok(rules_file) => (Some(rules_file), None),
295                Err(err) => (
296                    None,
297                    Some(RulesLoadingError {
298                        message: format!("{err}").into(),
299                    }),
300                ),
301            };
302            let worktree_info = WorktreeContext {
303                root_name,
304                abs_path,
305                rules_file,
306            };
307            (worktree_info, rules_file_error)
308        })
309    }
310
311    fn load_worktree_rules_file(
312        fs: Arc<dyn Fs>,
313        worktree: &Worktree,
314        cx: &App,
315    ) -> Option<Task<Result<RulesFileContext>>> {
316        let selected_rules_file = RULES_FILE_NAMES
317            .into_iter()
318            .filter_map(|name| {
319                worktree
320                    .entry_for_path(name)
321                    .filter(|entry| entry.is_file())
322                    .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
323            })
324            .next();
325
326        // Note that Cline supports `.clinerules` being a directory, but that is not currently
327        // supported. This doesn't seem to occur often in GitHub repositories.
328        selected_rules_file.map(|(path_in_worktree, abs_path)| {
329            let fs = fs.clone();
330            cx.background_spawn(async move {
331                let abs_path = abs_path?;
332                let text = fs.load(&abs_path).await.with_context(|| {
333                    format!("Failed to load assistant rules file {:?}", abs_path)
334                })?;
335                anyhow::Ok(RulesFileContext {
336                    path_in_worktree,
337                    abs_path: abs_path.into(),
338                    text: text.trim().to_string(),
339                })
340            })
341        })
342    }
343
344    pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
345        self.context_server_manager.clone()
346    }
347
348    pub fn tools(&self) -> Entity<ToolWorkingSet> {
349        self.tools.clone()
350    }
351
352    /// Returns the number of threads.
353    pub fn thread_count(&self) -> usize {
354        self.threads.len()
355    }
356
357    pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
358        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
359        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
360        threads
361    }
362
363    pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
364        self.threads().into_iter().take(limit).collect()
365    }
366
367    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
368        cx.new(|cx| {
369            Thread::new(
370                self.project.clone(),
371                self.tools.clone(),
372                self.prompt_builder.clone(),
373                self.project_context.clone(),
374                cx,
375            )
376        })
377    }
378
379    pub fn open_thread(
380        &self,
381        id: &ThreadId,
382        cx: &mut Context<Self>,
383    ) -> Task<Result<Entity<Thread>>> {
384        let id = id.clone();
385        let database_future = ThreadsDatabase::global_future(cx);
386        cx.spawn(async move |this, cx| {
387            let database = database_future.await.map_err(|err| anyhow!(err))?;
388            let thread = database
389                .try_find_thread(id.clone())
390                .await?
391                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
392
393            let thread = this.update(cx, |this, cx| {
394                cx.new(|cx| {
395                    Thread::deserialize(
396                        id.clone(),
397                        thread,
398                        this.project.clone(),
399                        this.tools.clone(),
400                        this.prompt_builder.clone(),
401                        this.project_context.clone(),
402                        cx,
403                    )
404                })
405            })?;
406
407            Ok(thread)
408        })
409    }
410
411    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
412        let (metadata, serialized_thread) =
413            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
414
415        let database_future = ThreadsDatabase::global_future(cx);
416        cx.spawn(async move |this, cx| {
417            let serialized_thread = serialized_thread.await?;
418            let database = database_future.await.map_err(|err| anyhow!(err))?;
419            database.save_thread(metadata, serialized_thread).await?;
420
421            this.update(cx, |this, cx| this.reload(cx))?.await
422        })
423    }
424
425    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
426        let id = id.clone();
427        let database_future = ThreadsDatabase::global_future(cx);
428        cx.spawn(async move |this, cx| {
429            let database = database_future.await.map_err(|err| anyhow!(err))?;
430            database.delete_thread(id.clone()).await?;
431
432            this.update(cx, |this, cx| {
433                this.threads.retain(|thread| thread.id != id);
434                cx.notify();
435            })
436        })
437    }
438
439    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
440        let database_future = ThreadsDatabase::global_future(cx);
441        cx.spawn(async move |this, cx| {
442            let threads = database_future
443                .await
444                .map_err(|err| anyhow!(err))?
445                .list_threads()
446                .await?;
447
448            this.update(cx, |this, cx| {
449                this.threads = threads;
450                cx.notify();
451            })
452        })
453    }
454
455    fn load_default_profile(&self, cx: &mut Context<Self>) {
456        let assistant_settings = AssistantSettings::get_global(cx);
457
458        self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
459    }
460
461    pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
462        let assistant_settings = AssistantSettings::get_global(cx);
463
464        if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
465            self.load_profile(profile.clone(), cx);
466        }
467    }
468
469    pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
470        self.tools.update(cx, |tools, cx| {
471            tools.disable_all_tools(cx);
472            tools.enable(
473                ToolSource::Native,
474                &profile
475                    .tools
476                    .iter()
477                    .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
478                    .collect::<Vec<_>>(),
479                cx,
480            );
481        });
482
483        if profile.enable_all_context_servers {
484            for context_server in self.context_server_manager.read(cx).all_servers() {
485                self.tools.update(cx, |tools, cx| {
486                    tools.enable_source(
487                        ToolSource::ContextServer {
488                            id: context_server.id().into(),
489                        },
490                        cx,
491                    );
492                });
493            }
494        } else {
495            for (context_server_id, preset) in &profile.context_servers {
496                self.tools.update(cx, |tools, cx| {
497                    tools.enable(
498                        ToolSource::ContextServer {
499                            id: context_server_id.clone().into(),
500                        },
501                        &preset
502                            .tools
503                            .iter()
504                            .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
505                            .collect::<Vec<_>>(),
506                        cx,
507                    )
508                })
509            }
510        }
511    }
512
513    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
514        cx.subscribe(
515            &self.context_server_manager.clone(),
516            Self::handle_context_server_event,
517        )
518        .detach();
519    }
520
521    fn handle_context_server_event(
522        &mut self,
523        context_server_manager: Entity<ContextServerManager>,
524        event: &context_server::manager::Event,
525        cx: &mut Context<Self>,
526    ) {
527        let tool_working_set = self.tools.clone();
528        match event {
529            context_server::manager::Event::ServerStarted { server_id } => {
530                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
531                    let context_server_manager = context_server_manager.clone();
532                    cx.spawn({
533                        let server = server.clone();
534                        let server_id = server_id.clone();
535                        async move |this, cx| {
536                            let Some(protocol) = server.client() else {
537                                return;
538                            };
539
540                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
541                                if let Some(tools) = protocol.list_tools().await.log_err() {
542                                    let tool_ids = tool_working_set
543                                        .update(cx, |tool_working_set, _| {
544                                            tools
545                                                .tools
546                                                .into_iter()
547                                                .map(|tool| {
548                                                    log::info!(
549                                                        "registering context server tool: {:?}",
550                                                        tool.name
551                                                    );
552                                                    tool_working_set.insert(Arc::new(
553                                                        ContextServerTool::new(
554                                                            context_server_manager.clone(),
555                                                            server.id(),
556                                                            tool,
557                                                        ),
558                                                    ))
559                                                })
560                                                .collect::<Vec<_>>()
561                                        })
562                                        .log_err();
563
564                                    if let Some(tool_ids) = tool_ids {
565                                        this.update(cx, |this, cx| {
566                                            this.context_server_tool_ids
567                                                .insert(server_id, tool_ids);
568                                            this.load_default_profile(cx);
569                                        })
570                                        .log_err();
571                                    }
572                                }
573                            }
574                        }
575                    })
576                    .detach();
577                }
578            }
579            context_server::manager::Event::ServerStopped { server_id } => {
580                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
581                    tool_working_set.update(cx, |tool_working_set, _| {
582                        tool_working_set.remove(&tool_ids);
583                    });
584                    self.load_default_profile(cx);
585                }
586            }
587        }
588    }
589}
590
591#[derive(Debug, Clone, Serialize, Deserialize)]
592pub struct SerializedThreadMetadata {
593    pub id: ThreadId,
594    pub summary: SharedString,
595    pub updated_at: DateTime<Utc>,
596}
597
598#[derive(Serialize, Deserialize, Debug)]
599pub struct SerializedThread {
600    pub version: String,
601    pub summary: SharedString,
602    pub updated_at: DateTime<Utc>,
603    pub messages: Vec<SerializedMessage>,
604    #[serde(default)]
605    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
606    #[serde(default)]
607    pub cumulative_token_usage: TokenUsage,
608    #[serde(default)]
609    pub request_token_usage: Vec<TokenUsage>,
610    #[serde(default)]
611    pub detailed_summary_state: DetailedSummaryState,
612    #[serde(default)]
613    pub exceeded_window_error: Option<ExceededWindowError>,
614}
615
616impl SerializedThread {
617    pub const VERSION: &'static str = "0.1.0";
618
619    pub fn from_json(json: &[u8]) -> Result<Self> {
620        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
621        match saved_thread_json.get("version") {
622            Some(serde_json::Value::String(version)) => match version.as_str() {
623                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
624                    saved_thread_json,
625                )?),
626                _ => Err(anyhow!(
627                    "unrecognized serialized thread version: {}",
628                    version
629                )),
630            },
631            None => {
632                let saved_thread =
633                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
634                Ok(saved_thread.upgrade())
635            }
636            version => Err(anyhow!(
637                "unrecognized serialized thread version: {:?}",
638                version
639            )),
640        }
641    }
642}
643
644#[derive(Debug, Serialize, Deserialize)]
645pub struct SerializedMessage {
646    pub id: MessageId,
647    pub role: Role,
648    #[serde(default)]
649    pub segments: Vec<SerializedMessageSegment>,
650    #[serde(default)]
651    pub tool_uses: Vec<SerializedToolUse>,
652    #[serde(default)]
653    pub tool_results: Vec<SerializedToolResult>,
654    #[serde(default)]
655    pub context: String,
656}
657
658#[derive(Debug, Serialize, Deserialize)]
659#[serde(tag = "type")]
660pub enum SerializedMessageSegment {
661    #[serde(rename = "text")]
662    Text { text: String },
663    #[serde(rename = "thinking")]
664    Thinking { text: String },
665}
666
667#[derive(Debug, Serialize, Deserialize)]
668pub struct SerializedToolUse {
669    pub id: LanguageModelToolUseId,
670    pub name: SharedString,
671    pub input: serde_json::Value,
672}
673
674#[derive(Debug, Serialize, Deserialize)]
675pub struct SerializedToolResult {
676    pub tool_use_id: LanguageModelToolUseId,
677    pub is_error: bool,
678    pub content: Arc<str>,
679}
680
681#[derive(Serialize, Deserialize)]
682struct LegacySerializedThread {
683    pub summary: SharedString,
684    pub updated_at: DateTime<Utc>,
685    pub messages: Vec<LegacySerializedMessage>,
686    #[serde(default)]
687    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
688}
689
690impl LegacySerializedThread {
691    pub fn upgrade(self) -> SerializedThread {
692        SerializedThread {
693            version: SerializedThread::VERSION.to_string(),
694            summary: self.summary,
695            updated_at: self.updated_at,
696            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
697            initial_project_snapshot: self.initial_project_snapshot,
698            cumulative_token_usage: TokenUsage::default(),
699            request_token_usage: Vec::new(),
700            detailed_summary_state: DetailedSummaryState::default(),
701            exceeded_window_error: None,
702        }
703    }
704}
705
706#[derive(Debug, Serialize, Deserialize)]
707struct LegacySerializedMessage {
708    pub id: MessageId,
709    pub role: Role,
710    pub text: String,
711    #[serde(default)]
712    pub tool_uses: Vec<SerializedToolUse>,
713    #[serde(default)]
714    pub tool_results: Vec<SerializedToolResult>,
715}
716
717impl LegacySerializedMessage {
718    fn upgrade(self) -> SerializedMessage {
719        SerializedMessage {
720            id: self.id,
721            role: self.role,
722            segments: vec![SerializedMessageSegment::Text { text: self.text }],
723            tool_uses: self.tool_uses,
724            tool_results: self.tool_results,
725            context: String::new(),
726        }
727    }
728}
729
730struct GlobalThreadsDatabase(
731    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
732);
733
734impl Global for GlobalThreadsDatabase {}
735
736pub(crate) struct ThreadsDatabase {
737    executor: BackgroundExecutor,
738    env: heed::Env,
739    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
740}
741
742impl heed::BytesEncode<'_> for SerializedThread {
743    type EItem = SerializedThread;
744
745    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
746        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
747    }
748}
749
750impl<'a> heed::BytesDecode<'a> for SerializedThread {
751    type DItem = SerializedThread;
752
753    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
754        // We implement this type manually because we want to call `SerializedThread::from_json`,
755        // instead of the Deserialize trait implementation for `SerializedThread`.
756        SerializedThread::from_json(bytes).map_err(Into::into)
757    }
758}
759
760impl ThreadsDatabase {
761    fn global_future(
762        cx: &mut App,
763    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
764        GlobalThreadsDatabase::global(cx).0.clone()
765    }
766
767    fn init(cx: &mut App) {
768        let executor = cx.background_executor().clone();
769        let database_future = executor
770            .spawn({
771                let executor = executor.clone();
772                let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
773                async move { ThreadsDatabase::new(database_path, executor) }
774            })
775            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
776            .boxed()
777            .shared();
778
779        cx.set_global(GlobalThreadsDatabase(database_future));
780    }
781
782    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
783        std::fs::create_dir_all(&path)?;
784
785        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
786        let env = unsafe {
787            heed::EnvOpenOptions::new()
788                .map_size(ONE_GB_IN_BYTES)
789                .max_dbs(1)
790                .open(path)?
791        };
792
793        let mut txn = env.write_txn()?;
794        let threads = env.create_database(&mut txn, Some("threads"))?;
795        txn.commit()?;
796
797        Ok(Self {
798            executor,
799            env,
800            threads,
801        })
802    }
803
804    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
805        let env = self.env.clone();
806        let threads = self.threads;
807
808        self.executor.spawn(async move {
809            let txn = env.read_txn()?;
810            let mut iter = threads.iter(&txn)?;
811            let mut threads = Vec::new();
812            while let Some((key, value)) = iter.next().transpose()? {
813                threads.push(SerializedThreadMetadata {
814                    id: key,
815                    summary: value.summary,
816                    updated_at: value.updated_at,
817                });
818            }
819
820            Ok(threads)
821        })
822    }
823
824    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
825        let env = self.env.clone();
826        let threads = self.threads;
827
828        self.executor.spawn(async move {
829            let txn = env.read_txn()?;
830            let thread = threads.get(&txn, &id)?;
831            Ok(thread)
832        })
833    }
834
835    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
836        let env = self.env.clone();
837        let threads = self.threads;
838
839        self.executor.spawn(async move {
840            let mut txn = env.write_txn()?;
841            threads.put(&mut txn, &id, &thread)?;
842            txn.commit()?;
843            Ok(())
844        })
845    }
846
847    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
848        let env = self.env.clone();
849        let threads = self.threads;
850
851        self.executor.spawn(async move {
852            let mut txn = env.write_txn()?;
853            threads.delete(&mut txn, &id)?;
854            txn.commit()?;
855            Ok(())
856        })
857    }
858}