// thread_store.rs

  1use std::borrow::Cow;
  2use std::cell::{Ref, RefCell};
  3use std::path::{Path, PathBuf};
  4use std::rc::Rc;
  5use std::sync::Arc;
  6
  7use anyhow::{Context as _, Result, anyhow};
  8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
  9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
 10use chrono::{DateTime, Utc};
 11use collections::HashMap;
 12use context_server::manager::ContextServerManager;
 13use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 14use fs::Fs;
 15use futures::channel::{mpsc, oneshot};
 16use futures::future::{self, BoxFuture, Shared};
 17use futures::{FutureExt as _, StreamExt as _};
 18use gpui::{
 19    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
 20    Subscription, Task, prelude::*,
 21};
 22use heed::Database;
 23use heed::types::SerdeBincode;
 24use language_model::{LanguageModelToolUseId, Role, TokenUsage};
 25use project::{Project, Worktree};
 26use prompt_store::{
 27    DefaultUserRulesContext, ProjectContext, PromptBuilder, PromptId, PromptStore,
 28    PromptsUpdatedEvent, RulesFileContext, WorktreeContext,
 29};
 30use serde::{Deserialize, Serialize};
 31use settings::{Settings as _, SettingsStore};
 32use util::ResultExt as _;
 33
 34use crate::thread::{
 35    DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
 36};
 37
 38const RULES_FILE_NAMES: [&'static str; 6] = [
 39    ".rules",
 40    ".cursorrules",
 41    ".windsurfrules",
 42    ".clinerules",
 43    ".github/copilot-instructions.md",
 44    "CLAUDE.md",
 45];
 46
 47pub fn init(cx: &mut App) {
 48    ThreadsDatabase::init(cx);
 49}
 50
 51/// A system prompt shared by all threads created by this ThreadStore
 52#[derive(Clone, Default)]
 53pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
 54
 55impl SharedProjectContext {
 56    pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
 57        self.0.borrow()
 58    }
 59}
 60
 61pub struct ThreadStore {
 62    project: Entity<Project>,
 63    tools: Entity<ToolWorkingSet>,
 64    prompt_builder: Arc<PromptBuilder>,
 65    context_server_manager: Entity<ContextServerManager>,
 66    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
 67    threads: Vec<SerializedThreadMetadata>,
 68    project_context: SharedProjectContext,
 69    reload_system_prompt_tx: mpsc::Sender<()>,
 70    _reload_system_prompt_task: Task<()>,
 71    _subscriptions: Vec<Subscription>,
 72}
 73
 74pub struct RulesLoadingError {
 75    pub message: SharedString,
 76}
 77
 78impl EventEmitter<RulesLoadingError> for ThreadStore {}
 79
 80impl ThreadStore {
 81    pub fn load(
 82        project: Entity<Project>,
 83        tools: Entity<ToolWorkingSet>,
 84        prompt_builder: Arc<PromptBuilder>,
 85        cx: &mut App,
 86    ) -> Task<Result<Entity<Self>>> {
 87        let prompt_store = PromptStore::global(cx);
 88        cx.spawn(async move |cx| {
 89            let prompt_store = prompt_store.await.ok();
 90            let (thread_store, ready_rx) = cx.update(|cx| {
 91                let mut option_ready_rx = None;
 92                let thread_store = cx.new(|cx| {
 93                    let (thread_store, ready_rx) =
 94                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 95                    option_ready_rx = Some(ready_rx);
 96                    thread_store
 97                });
 98                (thread_store, option_ready_rx.take().unwrap())
 99            })?;
100            ready_rx.await?;
101            Ok(thread_store)
102        })
103    }
104
105    fn new(
106        project: Entity<Project>,
107        tools: Entity<ToolWorkingSet>,
108        prompt_builder: Arc<PromptBuilder>,
109        prompt_store: Option<Entity<PromptStore>>,
110        cx: &mut Context<Self>,
111    ) -> (Self, oneshot::Receiver<()>) {
112        let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
113        let context_server_manager = cx.new(|cx| {
114            ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
115        });
116
117        let mut subscriptions = vec![
118            cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
119                this.load_default_profile(cx);
120            }),
121            cx.subscribe(&project, Self::handle_project_event),
122        ];
123
124        if let Some(prompt_store) = prompt_store.as_ref() {
125            subscriptions.push(cx.subscribe(
126                prompt_store,
127                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
128                    this.enqueue_system_prompt_reload();
129                },
130            ))
131        }
132
133        // This channel and task prevent concurrent and redundant loading of the system prompt.
134        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
135        let (ready_tx, ready_rx) = oneshot::channel();
136        let mut ready_tx = Some(ready_tx);
137        let reload_system_prompt_task = cx.spawn({
138            async move |thread_store, cx| {
139                loop {
140                    let Some(reload_task) = thread_store
141                        .update(cx, |thread_store, cx| {
142                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
143                        })
144                        .ok()
145                    else {
146                        return;
147                    };
148                    reload_task.await;
149                    if let Some(ready_tx) = ready_tx.take() {
150                        ready_tx.send(()).ok();
151                    }
152                    reload_system_prompt_rx.next().await;
153                }
154            }
155        });
156
157        let this = Self {
158            project,
159            tools,
160            prompt_builder,
161            context_server_manager,
162            context_server_tool_ids: HashMap::default(),
163            threads: Vec::new(),
164            project_context: SharedProjectContext::default(),
165            reload_system_prompt_tx,
166            _reload_system_prompt_task: reload_system_prompt_task,
167            _subscriptions: subscriptions,
168        };
169        this.load_default_profile(cx);
170        this.register_context_server_handlers(cx);
171        this.reload(cx).detach_and_log_err(cx);
172        (this, ready_rx)
173    }
174
175    fn handle_project_event(
176        &mut self,
177        _project: Entity<Project>,
178        event: &project::Event,
179        _cx: &mut Context<Self>,
180    ) {
181        match event {
182            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
183                self.enqueue_system_prompt_reload();
184            }
185            project::Event::WorktreeUpdatedEntries(_, items) => {
186                if items.iter().any(|(path, _, _)| {
187                    RULES_FILE_NAMES
188                        .iter()
189                        .any(|name| path.as_ref() == Path::new(name))
190                }) {
191                    self.enqueue_system_prompt_reload();
192                }
193            }
194            _ => {}
195        }
196    }
197
198    fn enqueue_system_prompt_reload(&mut self) {
199        self.reload_system_prompt_tx.try_send(()).ok();
200    }
201
202    // Note that this should only be called from `reload_system_prompt_task`.
203    fn reload_system_prompt(
204        &self,
205        prompt_store: Option<Entity<PromptStore>>,
206        cx: &mut Context<Self>,
207    ) -> Task<()> {
208        let project = self.project.read(cx);
209        let worktree_tasks = project
210            .visible_worktrees(cx)
211            .map(|worktree| {
212                Self::load_worktree_info_for_system_prompt(
213                    project.fs().clone(),
214                    worktree.read(cx),
215                    cx,
216                )
217            })
218            .collect::<Vec<_>>();
219        let default_user_rules_task = match prompt_store {
220            None => Task::ready(vec![]),
221            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
222                let prompts = prompt_store.default_prompt_metadata();
223                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
224                    let contents = prompt_store.load(prompt_metadata.id, cx);
225                    async move { (contents.await, prompt_metadata) }
226                });
227                cx.background_spawn(future::join_all(load_tasks))
228            }),
229        };
230
231        cx.spawn(async move |this, cx| {
232            let (worktrees, default_user_rules) =
233                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
234
235            let worktrees = worktrees
236                .into_iter()
237                .map(|(worktree, rules_error)| {
238                    if let Some(rules_error) = rules_error {
239                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
240                    }
241                    worktree
242                })
243                .collect::<Vec<_>>();
244
245            let default_user_rules = default_user_rules
246                .into_iter()
247                .flat_map(|(contents, prompt_metadata)| match contents {
248                    Ok(contents) => Some(DefaultUserRulesContext {
249                        uuid: match prompt_metadata.id {
250                            PromptId::User { uuid } => uuid,
251                            PromptId::EditWorkflow => return None,
252                        },
253                        title: prompt_metadata.title.map(|title| title.to_string()),
254                        contents,
255                    }),
256                    Err(err) => {
257                        this.update(cx, |_, cx| {
258                            cx.emit(RulesLoadingError {
259                                message: format!("{err:?}").into(),
260                            });
261                        })
262                        .ok();
263                        None
264                    }
265                })
266                .collect::<Vec<_>>();
267
268            this.update(cx, |this, _cx| {
269                *this.project_context.0.borrow_mut() =
270                    Some(ProjectContext::new(worktrees, default_user_rules));
271            })
272            .ok();
273        })
274    }
275
276    fn load_worktree_info_for_system_prompt(
277        fs: Arc<dyn Fs>,
278        worktree: &Worktree,
279        cx: &App,
280    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
281        let root_name = worktree.root_name().into();
282
283        let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
284        let Some(rules_task) = rules_task else {
285            return Task::ready((
286                WorktreeContext {
287                    root_name,
288                    rules_file: None,
289                },
290                None,
291            ));
292        };
293
294        cx.spawn(async move |_| {
295            let (rules_file, rules_file_error) = match rules_task.await {
296                Ok(rules_file) => (Some(rules_file), None),
297                Err(err) => (
298                    None,
299                    Some(RulesLoadingError {
300                        message: format!("{err}").into(),
301                    }),
302                ),
303            };
304            let worktree_info = WorktreeContext {
305                root_name,
306                rules_file,
307            };
308            (worktree_info, rules_file_error)
309        })
310    }
311
312    fn load_worktree_rules_file(
313        fs: Arc<dyn Fs>,
314        worktree: &Worktree,
315        cx: &App,
316    ) -> Option<Task<Result<RulesFileContext>>> {
317        let selected_rules_file = RULES_FILE_NAMES
318            .into_iter()
319            .filter_map(|name| {
320                worktree
321                    .entry_for_path(name)
322                    .filter(|entry| entry.is_file())
323                    .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
324            })
325            .next();
326
327        // Note that Cline supports `.clinerules` being a directory, but that is not currently
328        // supported. This doesn't seem to occur often in GitHub repositories.
329        selected_rules_file.map(|(path_in_worktree, abs_path)| {
330            let fs = fs.clone();
331            cx.background_spawn(async move {
332                let abs_path = abs_path?;
333                let text = fs.load(&abs_path).await.with_context(|| {
334                    format!("Failed to load assistant rules file {:?}", abs_path)
335                })?;
336                anyhow::Ok(RulesFileContext {
337                    path_in_worktree,
338                    abs_path: abs_path.into(),
339                    text: text.trim().to_string(),
340                })
341            })
342        })
343    }
344
345    pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
346        self.context_server_manager.clone()
347    }
348
349    pub fn tools(&self) -> Entity<ToolWorkingSet> {
350        self.tools.clone()
351    }
352
353    /// Returns the number of threads.
354    pub fn thread_count(&self) -> usize {
355        self.threads.len()
356    }
357
358    pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
359        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
360        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
361        threads
362    }
363
364    pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
365        self.threads().into_iter().take(limit).collect()
366    }
367
368    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
369        cx.new(|cx| {
370            Thread::new(
371                self.project.clone(),
372                self.tools.clone(),
373                self.prompt_builder.clone(),
374                self.project_context.clone(),
375                cx,
376            )
377        })
378    }
379
380    pub fn open_thread(
381        &self,
382        id: &ThreadId,
383        cx: &mut Context<Self>,
384    ) -> Task<Result<Entity<Thread>>> {
385        let id = id.clone();
386        let database_future = ThreadsDatabase::global_future(cx);
387        cx.spawn(async move |this, cx| {
388            let database = database_future.await.map_err(|err| anyhow!(err))?;
389            let thread = database
390                .try_find_thread(id.clone())
391                .await?
392                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
393
394            let thread = this.update(cx, |this, cx| {
395                cx.new(|cx| {
396                    Thread::deserialize(
397                        id.clone(),
398                        thread,
399                        this.project.clone(),
400                        this.tools.clone(),
401                        this.prompt_builder.clone(),
402                        this.project_context.clone(),
403                        cx,
404                    )
405                })
406            })?;
407
408            Ok(thread)
409        })
410    }
411
412    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
413        let (metadata, serialized_thread) =
414            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
415
416        let database_future = ThreadsDatabase::global_future(cx);
417        cx.spawn(async move |this, cx| {
418            let serialized_thread = serialized_thread.await?;
419            let database = database_future.await.map_err(|err| anyhow!(err))?;
420            database.save_thread(metadata, serialized_thread).await?;
421
422            this.update(cx, |this, cx| this.reload(cx))?.await
423        })
424    }
425
426    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
427        let id = id.clone();
428        let database_future = ThreadsDatabase::global_future(cx);
429        cx.spawn(async move |this, cx| {
430            let database = database_future.await.map_err(|err| anyhow!(err))?;
431            database.delete_thread(id.clone()).await?;
432
433            this.update(cx, |this, cx| {
434                this.threads.retain(|thread| thread.id != id);
435                cx.notify();
436            })
437        })
438    }
439
440    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
441        let database_future = ThreadsDatabase::global_future(cx);
442        cx.spawn(async move |this, cx| {
443            let threads = database_future
444                .await
445                .map_err(|err| anyhow!(err))?
446                .list_threads()
447                .await?;
448
449            this.update(cx, |this, cx| {
450                this.threads = threads;
451                cx.notify();
452            })
453        })
454    }
455
456    fn load_default_profile(&self, cx: &mut Context<Self>) {
457        let assistant_settings = AssistantSettings::get_global(cx);
458
459        self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
460    }
461
462    pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
463        let assistant_settings = AssistantSettings::get_global(cx);
464
465        if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
466            self.load_profile(profile.clone(), cx);
467        }
468    }
469
470    pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
471        self.tools.update(cx, |tools, cx| {
472            tools.disable_all_tools(cx);
473            tools.enable(
474                ToolSource::Native,
475                &profile
476                    .tools
477                    .iter()
478                    .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
479                    .collect::<Vec<_>>(),
480                cx,
481            );
482        });
483
484        if profile.enable_all_context_servers {
485            for context_server in self.context_server_manager.read(cx).all_servers() {
486                self.tools.update(cx, |tools, cx| {
487                    tools.enable_source(
488                        ToolSource::ContextServer {
489                            id: context_server.id().into(),
490                        },
491                        cx,
492                    );
493                });
494            }
495        } else {
496            for (context_server_id, preset) in &profile.context_servers {
497                self.tools.update(cx, |tools, cx| {
498                    tools.enable(
499                        ToolSource::ContextServer {
500                            id: context_server_id.clone().into(),
501                        },
502                        &preset
503                            .tools
504                            .iter()
505                            .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
506                            .collect::<Vec<_>>(),
507                        cx,
508                    )
509                })
510            }
511        }
512    }
513
514    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
515        cx.subscribe(
516            &self.context_server_manager.clone(),
517            Self::handle_context_server_event,
518        )
519        .detach();
520    }
521
522    fn handle_context_server_event(
523        &mut self,
524        context_server_manager: Entity<ContextServerManager>,
525        event: &context_server::manager::Event,
526        cx: &mut Context<Self>,
527    ) {
528        let tool_working_set = self.tools.clone();
529        match event {
530            context_server::manager::Event::ServerStarted { server_id } => {
531                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
532                    let context_server_manager = context_server_manager.clone();
533                    cx.spawn({
534                        let server = server.clone();
535                        let server_id = server_id.clone();
536                        async move |this, cx| {
537                            let Some(protocol) = server.client() else {
538                                return;
539                            };
540
541                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
542                                if let Some(tools) = protocol.list_tools().await.log_err() {
543                                    let tool_ids = tool_working_set
544                                        .update(cx, |tool_working_set, _| {
545                                            tools
546                                                .tools
547                                                .into_iter()
548                                                .map(|tool| {
549                                                    log::info!(
550                                                        "registering context server tool: {:?}",
551                                                        tool.name
552                                                    );
553                                                    tool_working_set.insert(Arc::new(
554                                                        ContextServerTool::new(
555                                                            context_server_manager.clone(),
556                                                            server.id(),
557                                                            tool,
558                                                        ),
559                                                    ))
560                                                })
561                                                .collect::<Vec<_>>()
562                                        })
563                                        .log_err();
564
565                                    if let Some(tool_ids) = tool_ids {
566                                        this.update(cx, |this, cx| {
567                                            this.context_server_tool_ids
568                                                .insert(server_id, tool_ids);
569                                            this.load_default_profile(cx);
570                                        })
571                                        .log_err();
572                                    }
573                                }
574                            }
575                        }
576                    })
577                    .detach();
578                }
579            }
580            context_server::manager::Event::ServerStopped { server_id } => {
581                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
582                    tool_working_set.update(cx, |tool_working_set, _| {
583                        tool_working_set.remove(&tool_ids);
584                    });
585                    self.load_default_profile(cx);
586                }
587            }
588        }
589    }
590}
591
592#[derive(Debug, Clone, Serialize, Deserialize)]
593pub struct SerializedThreadMetadata {
594    pub id: ThreadId,
595    pub summary: SharedString,
596    pub updated_at: DateTime<Utc>,
597}
598
599#[derive(Serialize, Deserialize, Debug)]
600pub struct SerializedThread {
601    pub version: String,
602    pub summary: SharedString,
603    pub updated_at: DateTime<Utc>,
604    pub messages: Vec<SerializedMessage>,
605    #[serde(default)]
606    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
607    #[serde(default)]
608    pub cumulative_token_usage: TokenUsage,
609    #[serde(default)]
610    pub request_token_usage: Vec<TokenUsage>,
611    #[serde(default)]
612    pub detailed_summary_state: DetailedSummaryState,
613    #[serde(default)]
614    pub exceeded_window_error: Option<ExceededWindowError>,
615}
616
617impl SerializedThread {
618    pub const VERSION: &'static str = "0.1.0";
619
620    pub fn from_json(json: &[u8]) -> Result<Self> {
621        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
622        match saved_thread_json.get("version") {
623            Some(serde_json::Value::String(version)) => match version.as_str() {
624                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
625                    saved_thread_json,
626                )?),
627                _ => Err(anyhow!(
628                    "unrecognized serialized thread version: {}",
629                    version
630                )),
631            },
632            None => {
633                let saved_thread =
634                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
635                Ok(saved_thread.upgrade())
636            }
637            version => Err(anyhow!(
638                "unrecognized serialized thread version: {:?}",
639                version
640            )),
641        }
642    }
643}
644
645#[derive(Debug, Serialize, Deserialize)]
646pub struct SerializedMessage {
647    pub id: MessageId,
648    pub role: Role,
649    #[serde(default)]
650    pub segments: Vec<SerializedMessageSegment>,
651    #[serde(default)]
652    pub tool_uses: Vec<SerializedToolUse>,
653    #[serde(default)]
654    pub tool_results: Vec<SerializedToolResult>,
655    #[serde(default)]
656    pub context: String,
657}
658
659#[derive(Debug, Serialize, Deserialize)]
660#[serde(tag = "type")]
661pub enum SerializedMessageSegment {
662    #[serde(rename = "text")]
663    Text { text: String },
664    #[serde(rename = "thinking")]
665    Thinking { text: String },
666}
667
668#[derive(Debug, Serialize, Deserialize)]
669pub struct SerializedToolUse {
670    pub id: LanguageModelToolUseId,
671    pub name: SharedString,
672    pub input: serde_json::Value,
673}
674
675#[derive(Debug, Serialize, Deserialize)]
676pub struct SerializedToolResult {
677    pub tool_use_id: LanguageModelToolUseId,
678    pub is_error: bool,
679    pub content: Arc<str>,
680}
681
682#[derive(Serialize, Deserialize)]
683struct LegacySerializedThread {
684    pub summary: SharedString,
685    pub updated_at: DateTime<Utc>,
686    pub messages: Vec<LegacySerializedMessage>,
687    #[serde(default)]
688    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
689}
690
691impl LegacySerializedThread {
692    pub fn upgrade(self) -> SerializedThread {
693        SerializedThread {
694            version: SerializedThread::VERSION.to_string(),
695            summary: self.summary,
696            updated_at: self.updated_at,
697            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
698            initial_project_snapshot: self.initial_project_snapshot,
699            cumulative_token_usage: TokenUsage::default(),
700            request_token_usage: Vec::new(),
701            detailed_summary_state: DetailedSummaryState::default(),
702            exceeded_window_error: None,
703        }
704    }
705}
706
707#[derive(Debug, Serialize, Deserialize)]
708struct LegacySerializedMessage {
709    pub id: MessageId,
710    pub role: Role,
711    pub text: String,
712    #[serde(default)]
713    pub tool_uses: Vec<SerializedToolUse>,
714    #[serde(default)]
715    pub tool_results: Vec<SerializedToolResult>,
716}
717
718impl LegacySerializedMessage {
719    fn upgrade(self) -> SerializedMessage {
720        SerializedMessage {
721            id: self.id,
722            role: self.role,
723            segments: vec![SerializedMessageSegment::Text { text: self.text }],
724            tool_uses: self.tool_uses,
725            tool_results: self.tool_results,
726            context: String::new(),
727        }
728    }
729}
730
731struct GlobalThreadsDatabase(
732    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
733);
734
735impl Global for GlobalThreadsDatabase {}
736
737pub(crate) struct ThreadsDatabase {
738    executor: BackgroundExecutor,
739    env: heed::Env,
740    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
741}
742
743impl heed::BytesEncode<'_> for SerializedThread {
744    type EItem = SerializedThread;
745
746    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
747        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
748    }
749}
750
751impl<'a> heed::BytesDecode<'a> for SerializedThread {
752    type DItem = SerializedThread;
753
754    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
755        // We implement this type manually because we want to call `SerializedThread::from_json`,
756        // instead of the Deserialize trait implementation for `SerializedThread`.
757        SerializedThread::from_json(bytes).map_err(Into::into)
758    }
759}
760
761impl ThreadsDatabase {
762    fn global_future(
763        cx: &mut App,
764    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
765        GlobalThreadsDatabase::global(cx).0.clone()
766    }
767
768    fn init(cx: &mut App) {
769        let executor = cx.background_executor().clone();
770        let database_future = executor
771            .spawn({
772                let executor = executor.clone();
773                let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
774                async move { ThreadsDatabase::new(database_path, executor) }
775            })
776            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
777            .boxed()
778            .shared();
779
780        cx.set_global(GlobalThreadsDatabase(database_future));
781    }
782
783    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
784        std::fs::create_dir_all(&path)?;
785
786        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
787        let env = unsafe {
788            heed::EnvOpenOptions::new()
789                .map_size(ONE_GB_IN_BYTES)
790                .max_dbs(1)
791                .open(path)?
792        };
793
794        let mut txn = env.write_txn()?;
795        let threads = env.create_database(&mut txn, Some("threads"))?;
796        txn.commit()?;
797
798        Ok(Self {
799            executor,
800            env,
801            threads,
802        })
803    }
804
805    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
806        let env = self.env.clone();
807        let threads = self.threads;
808
809        self.executor.spawn(async move {
810            let txn = env.read_txn()?;
811            let mut iter = threads.iter(&txn)?;
812            let mut threads = Vec::new();
813            while let Some((key, value)) = iter.next().transpose()? {
814                threads.push(SerializedThreadMetadata {
815                    id: key,
816                    summary: value.summary,
817                    updated_at: value.updated_at,
818                });
819            }
820
821            Ok(threads)
822        })
823    }
824
825    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
826        let env = self.env.clone();
827        let threads = self.threads;
828
829        self.executor.spawn(async move {
830            let txn = env.read_txn()?;
831            let thread = threads.get(&txn, &id)?;
832            Ok(thread)
833        })
834    }
835
836    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
837        let env = self.env.clone();
838        let threads = self.threads;
839
840        self.executor.spawn(async move {
841            let mut txn = env.write_txn()?;
842            threads.put(&mut txn, &id, &thread)?;
843            txn.commit()?;
844            Ok(())
845        })
846    }
847
848    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
849        let env = self.env.clone();
850        let threads = self.threads;
851
852        self.executor.spawn(async move {
853            let mut txn = env.write_txn()?;
854            threads.delete(&mut txn, &id)?;
855            txn.commit()?;
856            Ok(())
857        })
858    }
859}