thread_store.rs

  1use std::borrow::Cow;
  2use std::cell::{Ref, RefCell};
  3use std::path::{Path, PathBuf};
  4use std::rc::Rc;
  5use std::sync::Arc;
  6
  7use anyhow::{Context as _, Result, anyhow};
  8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
  9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
 10use chrono::{DateTime, Utc};
 11use collections::HashMap;
 12use context_server::manager::ContextServerManager;
 13use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 14use fs::Fs;
 15use futures::channel::{mpsc, oneshot};
 16use futures::future::{self, BoxFuture, Shared};
 17use futures::{FutureExt as _, StreamExt as _};
 18use gpui::{
 19    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
 20    Subscription, Task, prelude::*,
 21};
 22use heed::Database;
 23use heed::types::SerdeBincode;
 24use language_model::{LanguageModelToolUseId, Role, TokenUsage};
 25use project::{Project, Worktree};
 26use prompt_store::{
 27    DefaultUserRulesContext, ProjectContext, PromptBuilder, PromptId, PromptStore,
 28    PromptsUpdatedEvent, RulesFileContext, WorktreeContext,
 29};
 30use serde::{Deserialize, Serialize};
 31use settings::{Settings as _, SettingsStore};
 32use util::ResultExt as _;
 33
 34use crate::thread::{
 35    DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
 36};
 37
 38const RULES_FILE_NAMES: [&'static str; 6] = [
 39    ".rules",
 40    ".cursorrules",
 41    ".windsurfrules",
 42    ".clinerules",
 43    ".github/copilot-instructions.md",
 44    "CLAUDE.md",
 45];
 46
 47pub fn init(cx: &mut App) {
 48    ThreadsDatabase::init(cx);
 49}
 50
 51/// A system prompt shared by all threads created by this ThreadStore
 52#[derive(Clone, Default)]
 53pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
 54
 55impl SharedProjectContext {
 56    pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
 57        self.0.borrow()
 58    }
 59}
 60
/// Holds the metadata of a project's saved agent threads plus the state shared by every thread
/// the store creates or opens: the tool working set, the prompt builder and the project context.
pub struct ThreadStore {
    project: Entity<Project>,
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    context_server_manager: Entity<ContextServerManager>,
    /// Tool ids registered for each started context server, so they can be removed again when
    /// that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Metadata of the saved threads, refreshed from the threads database by `reload`.
    threads: Vec<SerializedThreadMetadata>,
    /// Project context (system prompt) handed to every thread created or opened by this store.
    project_context: SharedProjectContext,
    /// Sends reload requests to the background system-prompt reload task.
    reload_system_prompt_tx: mpsc::Sender<()>,
    // Held so the reload task lives as long as the store.
    _reload_system_prompt_task: Task<()>,
    // Held so the settings/project/prompt-store subscriptions live as long as the store.
    _subscriptions: Vec<Subscription>,
}
 73
 74pub struct RulesLoadingError {
 75    pub message: SharedString,
 76}
 77
 78impl EventEmitter<RulesLoadingError> for ThreadStore {}
 79
 80impl ThreadStore {
 81    pub fn load(
 82        project: Entity<Project>,
 83        tools: Entity<ToolWorkingSet>,
 84        prompt_builder: Arc<PromptBuilder>,
 85        cx: &mut App,
 86    ) -> Task<Result<Entity<Self>>> {
 87        let prompt_store = PromptStore::global(cx);
 88        cx.spawn(async move |cx| {
 89            let prompt_store = prompt_store.await.ok();
 90            let (thread_store, ready_rx) = cx.update(|cx| {
 91                let mut option_ready_rx = None;
 92                let thread_store = cx.new(|cx| {
 93                    let (thread_store, ready_rx) =
 94                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 95                    option_ready_rx = Some(ready_rx);
 96                    thread_store
 97                });
 98                (thread_store, option_ready_rx.take().unwrap())
 99            })?;
100            ready_rx.await?;
101            Ok(thread_store)
102        })
103    }
104
105    fn new(
106        project: Entity<Project>,
107        tools: Entity<ToolWorkingSet>,
108        prompt_builder: Arc<PromptBuilder>,
109        prompt_store: Option<Entity<PromptStore>>,
110        cx: &mut Context<Self>,
111    ) -> (Self, oneshot::Receiver<()>) {
112        let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
113        let context_server_manager = cx.new(|cx| {
114            ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
115        });
116
117        let mut subscriptions = vec![
118            cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
119                this.load_default_profile(cx);
120            }),
121            cx.subscribe(&project, Self::handle_project_event),
122        ];
123
124        if let Some(prompt_store) = prompt_store.as_ref() {
125            subscriptions.push(cx.subscribe(
126                prompt_store,
127                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
128                    this.enqueue_system_prompt_reload();
129                },
130            ))
131        }
132
133        // This channel and task prevent concurrent and redundant loading of the system prompt.
134        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
135        let (ready_tx, ready_rx) = oneshot::channel();
136        let mut ready_tx = Some(ready_tx);
137        let reload_system_prompt_task = cx.spawn({
138            async move |thread_store, cx| {
139                loop {
140                    let Some(reload_task) = thread_store
141                        .update(cx, |thread_store, cx| {
142                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
143                        })
144                        .ok()
145                    else {
146                        return;
147                    };
148                    reload_task.await;
149                    if let Some(ready_tx) = ready_tx.take() {
150                        ready_tx.send(()).ok();
151                    }
152                    reload_system_prompt_rx.next().await;
153                }
154            }
155        });
156
157        let this = Self {
158            project,
159            tools,
160            prompt_builder,
161            context_server_manager,
162            context_server_tool_ids: HashMap::default(),
163            threads: Vec::new(),
164            project_context: SharedProjectContext::default(),
165            reload_system_prompt_tx,
166            _reload_system_prompt_task: reload_system_prompt_task,
167            _subscriptions: subscriptions,
168        };
169        this.load_default_profile(cx);
170        this.register_context_server_handlers(cx);
171        this.reload(cx).detach_and_log_err(cx);
172        (this, ready_rx)
173    }
174
175    fn handle_project_event(
176        &mut self,
177        _project: Entity<Project>,
178        event: &project::Event,
179        _cx: &mut Context<Self>,
180    ) {
181        match event {
182            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
183                self.enqueue_system_prompt_reload();
184            }
185            project::Event::WorktreeUpdatedEntries(_, items) => {
186                if items.iter().any(|(path, _, _)| {
187                    RULES_FILE_NAMES
188                        .iter()
189                        .any(|name| path.as_ref() == Path::new(name))
190                }) {
191                    self.enqueue_system_prompt_reload();
192                }
193            }
194            _ => {}
195        }
196    }
197
198    fn enqueue_system_prompt_reload(&mut self) {
199        self.reload_system_prompt_tx.try_send(()).ok();
200    }
201
202    // Note that this should only be called from `reload_system_prompt_task`.
203    fn reload_system_prompt(
204        &self,
205        prompt_store: Option<Entity<PromptStore>>,
206        cx: &mut Context<Self>,
207    ) -> Task<()> {
208        let project = self.project.read(cx);
209        let worktree_tasks = project
210            .visible_worktrees(cx)
211            .map(|worktree| {
212                Self::load_worktree_info_for_system_prompt(
213                    project.fs().clone(),
214                    worktree.read(cx),
215                    cx,
216                )
217            })
218            .collect::<Vec<_>>();
219        let default_user_rules_task = match prompt_store {
220            None => Task::ready(vec![]),
221            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
222                let prompts = prompt_store.default_prompt_metadata();
223                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
224                    let contents = prompt_store.load(prompt_metadata.id, cx);
225                    async move { (contents.await, prompt_metadata) }
226                });
227                cx.background_spawn(future::join_all(load_tasks))
228            }),
229        };
230
231        cx.spawn(async move |this, cx| {
232            let (worktrees, default_user_rules) =
233                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
234
235            let worktrees = worktrees
236                .into_iter()
237                .map(|(worktree, rules_error)| {
238                    if let Some(rules_error) = rules_error {
239                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
240                    }
241                    worktree
242                })
243                .collect::<Vec<_>>();
244
245            let default_user_rules = default_user_rules
246                .into_iter()
247                .flat_map(|(contents, prompt_metadata)| match contents {
248                    Ok(contents) => Some(DefaultUserRulesContext {
249                        uuid: match prompt_metadata.id {
250                            PromptId::User { uuid } => uuid,
251                            PromptId::EditWorkflow => return None,
252                        },
253                        title: prompt_metadata.title.map(|title| title.to_string()),
254                        contents,
255                    }),
256                    Err(err) => {
257                        this.update(cx, |_, cx| {
258                            cx.emit(RulesLoadingError {
259                                message: format!("{err:?}").into(),
260                            });
261                        })
262                        .ok();
263                        None
264                    }
265                })
266                .collect::<Vec<_>>();
267
268            this.update(cx, |this, _cx| {
269                *this.project_context.0.borrow_mut() =
270                    Some(ProjectContext::new(worktrees, default_user_rules));
271            })
272            .ok();
273        })
274    }
275
276    fn load_worktree_info_for_system_prompt(
277        fs: Arc<dyn Fs>,
278        worktree: &Worktree,
279        cx: &App,
280    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
281        let root_name = worktree.root_name().into();
282        let abs_path = worktree.abs_path();
283
284        let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
285        let Some(rules_task) = rules_task else {
286            return Task::ready((
287                WorktreeContext {
288                    root_name,
289                    abs_path,
290                    rules_file: None,
291                },
292                None,
293            ));
294        };
295
296        cx.spawn(async move |_| {
297            let (rules_file, rules_file_error) = match rules_task.await {
298                Ok(rules_file) => (Some(rules_file), None),
299                Err(err) => (
300                    None,
301                    Some(RulesLoadingError {
302                        message: format!("{err}").into(),
303                    }),
304                ),
305            };
306            let worktree_info = WorktreeContext {
307                root_name,
308                abs_path,
309                rules_file,
310            };
311            (worktree_info, rules_file_error)
312        })
313    }
314
315    fn load_worktree_rules_file(
316        fs: Arc<dyn Fs>,
317        worktree: &Worktree,
318        cx: &App,
319    ) -> Option<Task<Result<RulesFileContext>>> {
320        let selected_rules_file = RULES_FILE_NAMES
321            .into_iter()
322            .filter_map(|name| {
323                worktree
324                    .entry_for_path(name)
325                    .filter(|entry| entry.is_file())
326                    .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
327            })
328            .next();
329
330        // Note that Cline supports `.clinerules` being a directory, but that is not currently
331        // supported. This doesn't seem to occur often in GitHub repositories.
332        selected_rules_file.map(|(path_in_worktree, abs_path)| {
333            let fs = fs.clone();
334            cx.background_spawn(async move {
335                let abs_path = abs_path?;
336                let text = fs.load(&abs_path).await.with_context(|| {
337                    format!("Failed to load assistant rules file {:?}", abs_path)
338                })?;
339                anyhow::Ok(RulesFileContext {
340                    path_in_worktree,
341                    abs_path: abs_path.into(),
342                    text: text.trim().to_string(),
343                })
344            })
345        })
346    }
347
348    pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
349        self.context_server_manager.clone()
350    }
351
352    pub fn tools(&self) -> Entity<ToolWorkingSet> {
353        self.tools.clone()
354    }
355
356    /// Returns the number of threads.
357    pub fn thread_count(&self) -> usize {
358        self.threads.len()
359    }
360
361    pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
362        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
363        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
364        threads
365    }
366
367    pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
368        self.threads().into_iter().take(limit).collect()
369    }
370
371    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
372        cx.new(|cx| {
373            Thread::new(
374                self.project.clone(),
375                self.tools.clone(),
376                self.prompt_builder.clone(),
377                self.project_context.clone(),
378                cx,
379            )
380        })
381    }
382
383    pub fn open_thread(
384        &self,
385        id: &ThreadId,
386        cx: &mut Context<Self>,
387    ) -> Task<Result<Entity<Thread>>> {
388        let id = id.clone();
389        let database_future = ThreadsDatabase::global_future(cx);
390        cx.spawn(async move |this, cx| {
391            let database = database_future.await.map_err(|err| anyhow!(err))?;
392            let thread = database
393                .try_find_thread(id.clone())
394                .await?
395                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
396
397            let thread = this.update(cx, |this, cx| {
398                cx.new(|cx| {
399                    Thread::deserialize(
400                        id.clone(),
401                        thread,
402                        this.project.clone(),
403                        this.tools.clone(),
404                        this.prompt_builder.clone(),
405                        this.project_context.clone(),
406                        cx,
407                    )
408                })
409            })?;
410
411            Ok(thread)
412        })
413    }
414
415    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
416        let (metadata, serialized_thread) =
417            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
418
419        let database_future = ThreadsDatabase::global_future(cx);
420        cx.spawn(async move |this, cx| {
421            let serialized_thread = serialized_thread.await?;
422            let database = database_future.await.map_err(|err| anyhow!(err))?;
423            database.save_thread(metadata, serialized_thread).await?;
424
425            this.update(cx, |this, cx| this.reload(cx))?.await
426        })
427    }
428
429    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
430        let id = id.clone();
431        let database_future = ThreadsDatabase::global_future(cx);
432        cx.spawn(async move |this, cx| {
433            let database = database_future.await.map_err(|err| anyhow!(err))?;
434            database.delete_thread(id.clone()).await?;
435
436            this.update(cx, |this, cx| {
437                this.threads.retain(|thread| thread.id != id);
438                cx.notify();
439            })
440        })
441    }
442
443    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
444        let database_future = ThreadsDatabase::global_future(cx);
445        cx.spawn(async move |this, cx| {
446            let threads = database_future
447                .await
448                .map_err(|err| anyhow!(err))?
449                .list_threads()
450                .await?;
451
452            this.update(cx, |this, cx| {
453                this.threads = threads;
454                cx.notify();
455            })
456        })
457    }
458
459    fn load_default_profile(&self, cx: &mut Context<Self>) {
460        let assistant_settings = AssistantSettings::get_global(cx);
461
462        self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
463    }
464
465    pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
466        let assistant_settings = AssistantSettings::get_global(cx);
467
468        if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
469            self.load_profile(profile.clone(), cx);
470        }
471    }
472
473    pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
474        self.tools.update(cx, |tools, cx| {
475            tools.disable_all_tools(cx);
476            tools.enable(
477                ToolSource::Native,
478                &profile
479                    .tools
480                    .iter()
481                    .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
482                    .collect::<Vec<_>>(),
483                cx,
484            );
485        });
486
487        if profile.enable_all_context_servers {
488            for context_server in self.context_server_manager.read(cx).all_servers() {
489                self.tools.update(cx, |tools, cx| {
490                    tools.enable_source(
491                        ToolSource::ContextServer {
492                            id: context_server.id().into(),
493                        },
494                        cx,
495                    );
496                });
497            }
498        } else {
499            for (context_server_id, preset) in &profile.context_servers {
500                self.tools.update(cx, |tools, cx| {
501                    tools.enable(
502                        ToolSource::ContextServer {
503                            id: context_server_id.clone().into(),
504                        },
505                        &preset
506                            .tools
507                            .iter()
508                            .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
509                            .collect::<Vec<_>>(),
510                        cx,
511                    )
512                })
513            }
514        }
515    }
516
517    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
518        cx.subscribe(
519            &self.context_server_manager.clone(),
520            Self::handle_context_server_event,
521        )
522        .detach();
523    }
524
525    fn handle_context_server_event(
526        &mut self,
527        context_server_manager: Entity<ContextServerManager>,
528        event: &context_server::manager::Event,
529        cx: &mut Context<Self>,
530    ) {
531        let tool_working_set = self.tools.clone();
532        match event {
533            context_server::manager::Event::ServerStarted { server_id } => {
534                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
535                    let context_server_manager = context_server_manager.clone();
536                    cx.spawn({
537                        let server = server.clone();
538                        let server_id = server_id.clone();
539                        async move |this, cx| {
540                            let Some(protocol) = server.client() else {
541                                return;
542                            };
543
544                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
545                                if let Some(tools) = protocol.list_tools().await.log_err() {
546                                    let tool_ids = tool_working_set
547                                        .update(cx, |tool_working_set, _| {
548                                            tools
549                                                .tools
550                                                .into_iter()
551                                                .map(|tool| {
552                                                    log::info!(
553                                                        "registering context server tool: {:?}",
554                                                        tool.name
555                                                    );
556                                                    tool_working_set.insert(Arc::new(
557                                                        ContextServerTool::new(
558                                                            context_server_manager.clone(),
559                                                            server.id(),
560                                                            tool,
561                                                        ),
562                                                    ))
563                                                })
564                                                .collect::<Vec<_>>()
565                                        })
566                                        .log_err();
567
568                                    if let Some(tool_ids) = tool_ids {
569                                        this.update(cx, |this, cx| {
570                                            this.context_server_tool_ids
571                                                .insert(server_id, tool_ids);
572                                            this.load_default_profile(cx);
573                                        })
574                                        .log_err();
575                                    }
576                                }
577                            }
578                        }
579                    })
580                    .detach();
581                }
582            }
583            context_server::manager::Event::ServerStopped { server_id } => {
584                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
585                    tool_working_set.update(cx, |tool_working_set, _| {
586                        tool_working_set.remove(&tool_ids);
587                    });
588                    self.load_default_profile(cx);
589                }
590            }
591        }
592    }
593}
594
/// Summary of a saved thread, used to list threads without keeping their messages around.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    /// When the thread was last updated; thread listings are sorted on this, newest first.
    pub updated_at: DateTime<Utc>,
}
601
/// Persisted representation of a thread, stored as JSON in the threads database.
///
/// Fields marked `#[serde(default)]` may be absent from older saved threads and then fall back
/// to their default values.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    /// Format version; `SerializedThread::from_json` only accepts `SerializedThread::VERSION`.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
}
619
620impl SerializedThread {
621    pub const VERSION: &'static str = "0.1.0";
622
623    pub fn from_json(json: &[u8]) -> Result<Self> {
624        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
625        match saved_thread_json.get("version") {
626            Some(serde_json::Value::String(version)) => match version.as_str() {
627                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
628                    saved_thread_json,
629                )?),
630                _ => Err(anyhow!(
631                    "unrecognized serialized thread version: {}",
632                    version
633                )),
634            },
635            None => {
636                let saved_thread =
637                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
638                Ok(saved_thread.upgrade())
639            }
640            version => Err(anyhow!(
641                "unrecognized serialized thread version: {:?}",
642                version
643            )),
644        }
645    }
646}
647
/// A single message of a [`SerializedThread`].
///
/// Every field after `role` is `#[serde(default)]`, so it may be missing from older data.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces of the message.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    // NOTE(review): presumably the rendered context attached to the message — confirm.
    #[serde(default)]
    pub context: String,
}

/// One piece of message content, serialized with a `type` tag of `"text"` or `"thinking"`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "thinking")]
    Thinking { text: String },
}
670
/// A tool invocation recorded in a [`SerializedMessage`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    /// Tool input, kept as arbitrary JSON.
    pub input: serde_json::Value,
}

/// The outcome of a tool invocation recorded in a [`SerializedMessage`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    // Presumably matches the `id` of the corresponding `SerializedToolUse` — confirm.
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: Arc<str>,
}
684
685#[derive(Serialize, Deserialize)]
686struct LegacySerializedThread {
687    pub summary: SharedString,
688    pub updated_at: DateTime<Utc>,
689    pub messages: Vec<LegacySerializedMessage>,
690    #[serde(default)]
691    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
692}
693
694impl LegacySerializedThread {
695    pub fn upgrade(self) -> SerializedThread {
696        SerializedThread {
697            version: SerializedThread::VERSION.to_string(),
698            summary: self.summary,
699            updated_at: self.updated_at,
700            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
701            initial_project_snapshot: self.initial_project_snapshot,
702            cumulative_token_usage: TokenUsage::default(),
703            request_token_usage: Vec::new(),
704            detailed_summary_state: DetailedSummaryState::default(),
705            exceeded_window_error: None,
706        }
707    }
708}
709
/// Message format used by [`LegacySerializedThread`]: a single `text` body instead of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
720
721impl LegacySerializedMessage {
722    fn upgrade(self) -> SerializedMessage {
723        SerializedMessage {
724            id: self.id,
725            role: self.role,
726            segments: vec![SerializedMessageSegment::Text { text: self.text }],
727            tool_uses: self.tool_uses,
728            tool_results: self.tool_results,
729            context: String::new(),
730        }
731    }
732}
733
734struct GlobalThreadsDatabase(
735    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
736);
737
738impl Global for GlobalThreadsDatabase {}
739
/// Persistent store of serialized threads, backed by a heed (LMDB) environment.
pub(crate) struct ThreadsDatabase {
    /// Executor on which all database reads and writes run.
    executor: BackgroundExecutor,
    env: heed::Env,
    /// Threads keyed by `ThreadId` (bincode-encoded key); values use the JSON codec implemented
    /// for `SerializedThread` in this file.
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
745
746impl heed::BytesEncode<'_> for SerializedThread {
747    type EItem = SerializedThread;
748
749    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
750        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
751    }
752}
753
754impl<'a> heed::BytesDecode<'a> for SerializedThread {
755    type DItem = SerializedThread;
756
757    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
758        // We implement this type manually because we want to call `SerializedThread::from_json`,
759        // instead of the Deserialize trait implementation for `SerializedThread`.
760        SerializedThread::from_json(bytes).map_err(Into::into)
761    }
762}
763
764impl ThreadsDatabase {
765    fn global_future(
766        cx: &mut App,
767    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
768        GlobalThreadsDatabase::global(cx).0.clone()
769    }
770
771    fn init(cx: &mut App) {
772        let executor = cx.background_executor().clone();
773        let database_future = executor
774            .spawn({
775                let executor = executor.clone();
776                let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
777                async move { ThreadsDatabase::new(database_path, executor) }
778            })
779            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
780            .boxed()
781            .shared();
782
783        cx.set_global(GlobalThreadsDatabase(database_future));
784    }
785
786    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
787        std::fs::create_dir_all(&path)?;
788
789        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
790        let env = unsafe {
791            heed::EnvOpenOptions::new()
792                .map_size(ONE_GB_IN_BYTES)
793                .max_dbs(1)
794                .open(path)?
795        };
796
797        let mut txn = env.write_txn()?;
798        let threads = env.create_database(&mut txn, Some("threads"))?;
799        txn.commit()?;
800
801        Ok(Self {
802            executor,
803            env,
804            threads,
805        })
806    }
807
808    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
809        let env = self.env.clone();
810        let threads = self.threads;
811
812        self.executor.spawn(async move {
813            let txn = env.read_txn()?;
814            let mut iter = threads.iter(&txn)?;
815            let mut threads = Vec::new();
816            while let Some((key, value)) = iter.next().transpose()? {
817                threads.push(SerializedThreadMetadata {
818                    id: key,
819                    summary: value.summary,
820                    updated_at: value.updated_at,
821                });
822            }
823
824            Ok(threads)
825        })
826    }
827
828    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
829        let env = self.env.clone();
830        let threads = self.threads;
831
832        self.executor.spawn(async move {
833            let txn = env.read_txn()?;
834            let thread = threads.get(&txn, &id)?;
835            Ok(thread)
836        })
837    }
838
839    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
840        let env = self.env.clone();
841        let threads = self.threads;
842
843        self.executor.spawn(async move {
844            let mut txn = env.write_txn()?;
845            threads.put(&mut txn, &id, &thread)?;
846            txn.commit()?;
847            Ok(())
848        })
849    }
850
851    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
852        let env = self.env.clone();
853        let threads = self.threads;
854
855        self.executor.spawn(async move {
856            let mut txn = env.write_txn()?;
857            threads.delete(&mut txn, &id)?;
858            txn.commit()?;
859            Ok(())
860        })
861    }
862}