// thread_store.rs

  1use std::borrow::Cow;
  2use std::cell::{Ref, RefCell};
  3use std::path::{Path, PathBuf};
  4use std::rc::Rc;
  5use std::sync::Arc;
  6
  7use anyhow::{Context as _, Result, anyhow};
  8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
  9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
 10use chrono::{DateTime, Utc};
 11use collections::HashMap;
 12use context_server::manager::ContextServerManager;
 13use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 14use fs::Fs;
 15use futures::channel::{mpsc, oneshot};
 16use futures::future::{self, BoxFuture, Shared};
 17use futures::{FutureExt as _, StreamExt as _};
 18use gpui::{
 19    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
 20    Subscription, Task, prelude::*,
 21};
 22use heed::Database;
 23use heed::types::SerdeBincode;
 24use language_model::{LanguageModelToolUseId, Role, TokenUsage};
 25use project::{Project, Worktree};
 26use prompt_store::{
 27    ProjectContext, PromptBuilder, PromptId, PromptMetadata, PromptStore, PromptsUpdatedEvent,
 28    RulesFileContext, UserPromptId, UserRulesContext, WorktreeContext,
 29};
 30use serde::{Deserialize, Serialize};
 31use settings::{Settings as _, SettingsStore};
 32use util::ResultExt as _;
 33
 34use crate::thread::{
 35    DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
 36};
 37
 38const RULES_FILE_NAMES: [&'static str; 6] = [
 39    ".rules",
 40    ".cursorrules",
 41    ".windsurfrules",
 42    ".clinerules",
 43    ".github/copilot-instructions.md",
 44    "CLAUDE.md",
 45];
 46
 47pub fn init(cx: &mut App) {
 48    ThreadsDatabase::init(cx);
 49}
 50
 51/// A system prompt shared by all threads created by this ThreadStore
 52#[derive(Clone, Default)]
 53pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
 54
 55impl SharedProjectContext {
 56    pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
 57        self.0.borrow()
 58    }
 59}
 60
 61pub struct ThreadStore {
 62    project: Entity<Project>,
 63    tools: Entity<ToolWorkingSet>,
 64    prompt_builder: Arc<PromptBuilder>,
 65    prompt_store: Option<Entity<PromptStore>>,
 66    context_server_manager: Entity<ContextServerManager>,
 67    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
 68    threads: Vec<SerializedThreadMetadata>,
 69    project_context: SharedProjectContext,
 70    reload_system_prompt_tx: mpsc::Sender<()>,
 71    _reload_system_prompt_task: Task<()>,
 72    _subscriptions: Vec<Subscription>,
 73}
 74
 75pub struct RulesLoadingError {
 76    pub message: SharedString,
 77}
 78
 79impl EventEmitter<RulesLoadingError> for ThreadStore {}
 80
 81impl ThreadStore {
 82    pub fn load(
 83        project: Entity<Project>,
 84        tools: Entity<ToolWorkingSet>,
 85        prompt_builder: Arc<PromptBuilder>,
 86        cx: &mut App,
 87    ) -> Task<Result<Entity<Self>>> {
 88        let prompt_store = PromptStore::global(cx);
 89        cx.spawn(async move |cx| {
 90            let prompt_store = prompt_store.await.ok();
 91            let (thread_store, ready_rx) = cx.update(|cx| {
 92                let mut option_ready_rx = None;
 93                let thread_store = cx.new(|cx| {
 94                    let (thread_store, ready_rx) =
 95                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 96                    option_ready_rx = Some(ready_rx);
 97                    thread_store
 98                });
 99                (thread_store, option_ready_rx.take().unwrap())
100            })?;
101            ready_rx.await?;
102            Ok(thread_store)
103        })
104    }
105
106    fn new(
107        project: Entity<Project>,
108        tools: Entity<ToolWorkingSet>,
109        prompt_builder: Arc<PromptBuilder>,
110        prompt_store: Option<Entity<PromptStore>>,
111        cx: &mut Context<Self>,
112    ) -> (Self, oneshot::Receiver<()>) {
113        let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
114        let context_server_manager = cx.new(|cx| {
115            ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
116        });
117
118        let mut subscriptions = vec![
119            cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
120                this.load_default_profile(cx);
121            }),
122            cx.subscribe(&project, Self::handle_project_event),
123        ];
124
125        if let Some(prompt_store) = prompt_store.as_ref() {
126            subscriptions.push(cx.subscribe(
127                prompt_store,
128                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
129                    this.enqueue_system_prompt_reload();
130                },
131            ))
132        }
133
134        // This channel and task prevent concurrent and redundant loading of the system prompt.
135        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
136        let (ready_tx, ready_rx) = oneshot::channel();
137        let mut ready_tx = Some(ready_tx);
138        let reload_system_prompt_task = cx.spawn({
139            let prompt_store = prompt_store.clone();
140            async move |thread_store, cx| {
141                loop {
142                    let Some(reload_task) = thread_store
143                        .update(cx, |thread_store, cx| {
144                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
145                        })
146                        .ok()
147                    else {
148                        return;
149                    };
150                    reload_task.await;
151                    if let Some(ready_tx) = ready_tx.take() {
152                        ready_tx.send(()).ok();
153                    }
154                    reload_system_prompt_rx.next().await;
155                }
156            }
157        });
158
159        let this = Self {
160            project,
161            tools,
162            prompt_builder,
163            prompt_store,
164            context_server_manager,
165            context_server_tool_ids: HashMap::default(),
166            threads: Vec::new(),
167            project_context: SharedProjectContext::default(),
168            reload_system_prompt_tx,
169            _reload_system_prompt_task: reload_system_prompt_task,
170            _subscriptions: subscriptions,
171        };
172        this.load_default_profile(cx);
173        this.register_context_server_handlers(cx);
174        this.reload(cx).detach_and_log_err(cx);
175        (this, ready_rx)
176    }
177
178    fn handle_project_event(
179        &mut self,
180        _project: Entity<Project>,
181        event: &project::Event,
182        _cx: &mut Context<Self>,
183    ) {
184        match event {
185            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
186                self.enqueue_system_prompt_reload();
187            }
188            project::Event::WorktreeUpdatedEntries(_, items) => {
189                if items.iter().any(|(path, _, _)| {
190                    RULES_FILE_NAMES
191                        .iter()
192                        .any(|name| path.as_ref() == Path::new(name))
193                }) {
194                    self.enqueue_system_prompt_reload();
195                }
196            }
197            _ => {}
198        }
199    }
200
201    fn enqueue_system_prompt_reload(&mut self) {
202        self.reload_system_prompt_tx.try_send(()).ok();
203    }
204
205    // Note that this should only be called from `reload_system_prompt_task`.
206    fn reload_system_prompt(
207        &self,
208        prompt_store: Option<Entity<PromptStore>>,
209        cx: &mut Context<Self>,
210    ) -> Task<()> {
211        let project = self.project.read(cx);
212        let worktree_tasks = project
213            .visible_worktrees(cx)
214            .map(|worktree| {
215                Self::load_worktree_info_for_system_prompt(
216                    project.fs().clone(),
217                    worktree.read(cx),
218                    cx,
219                )
220            })
221            .collect::<Vec<_>>();
222        let default_user_rules_task = match prompt_store {
223            None => Task::ready(vec![]),
224            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
225                let prompts = prompt_store.default_prompt_metadata();
226                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
227                    let contents = prompt_store.load(prompt_metadata.id, cx);
228                    async move { (contents.await, prompt_metadata) }
229                });
230                cx.background_spawn(future::join_all(load_tasks))
231            }),
232        };
233
234        cx.spawn(async move |this, cx| {
235            let (worktrees, default_user_rules) =
236                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
237
238            let worktrees = worktrees
239                .into_iter()
240                .map(|(worktree, rules_error)| {
241                    if let Some(rules_error) = rules_error {
242                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
243                    }
244                    worktree
245                })
246                .collect::<Vec<_>>();
247
248            let default_user_rules = default_user_rules
249                .into_iter()
250                .flat_map(|(contents, prompt_metadata)| match contents {
251                    Ok(contents) => Some(UserRulesContext {
252                        uuid: match prompt_metadata.id {
253                            PromptId::User { uuid } => uuid,
254                            PromptId::EditWorkflow => return None,
255                        },
256                        title: prompt_metadata.title.map(|title| title.to_string()),
257                        contents,
258                    }),
259                    Err(err) => {
260                        this.update(cx, |_, cx| {
261                            cx.emit(RulesLoadingError {
262                                message: format!("{err:?}").into(),
263                            });
264                        })
265                        .ok();
266                        None
267                    }
268                })
269                .collect::<Vec<_>>();
270
271            this.update(cx, |this, _cx| {
272                *this.project_context.0.borrow_mut() =
273                    Some(ProjectContext::new(worktrees, default_user_rules));
274            })
275            .ok();
276        })
277    }
278
279    fn load_worktree_info_for_system_prompt(
280        fs: Arc<dyn Fs>,
281        worktree: &Worktree,
282        cx: &App,
283    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
284        let root_name = worktree.root_name().into();
285
286        let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
287        let Some(rules_task) = rules_task else {
288            return Task::ready((
289                WorktreeContext {
290                    root_name,
291                    rules_file: None,
292                },
293                None,
294            ));
295        };
296
297        cx.spawn(async move |_| {
298            let (rules_file, rules_file_error) = match rules_task.await {
299                Ok(rules_file) => (Some(rules_file), None),
300                Err(err) => (
301                    None,
302                    Some(RulesLoadingError {
303                        message: format!("{err}").into(),
304                    }),
305                ),
306            };
307            let worktree_info = WorktreeContext {
308                root_name,
309                rules_file,
310            };
311            (worktree_info, rules_file_error)
312        })
313    }
314
315    fn load_worktree_rules_file(
316        fs: Arc<dyn Fs>,
317        worktree: &Worktree,
318        cx: &App,
319    ) -> Option<Task<Result<RulesFileContext>>> {
320        let selected_rules_file = RULES_FILE_NAMES
321            .into_iter()
322            .filter_map(|name| {
323                worktree
324                    .entry_for_path(name)
325                    .filter(|entry| entry.is_file())
326                    .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
327            })
328            .next();
329
330        // Note that Cline supports `.clinerules` being a directory, but that is not currently
331        // supported. This doesn't seem to occur often in GitHub repositories.
332        selected_rules_file.map(|(path_in_worktree, abs_path)| {
333            let fs = fs.clone();
334            cx.background_spawn(async move {
335                let abs_path = abs_path?;
336                let text = fs.load(&abs_path).await.with_context(|| {
337                    format!("Failed to load assistant rules file {:?}", abs_path)
338                })?;
339                anyhow::Ok(RulesFileContext {
340                    path_in_worktree,
341                    abs_path: abs_path.into(),
342                    text: text.trim().to_string(),
343                })
344            })
345        })
346    }
347
348    pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
349        self.context_server_manager.clone()
350    }
351
352    pub fn prompt_store(&self) -> Option<Entity<PromptStore>> {
353        self.prompt_store.clone()
354    }
355
356    pub fn load_rules(
357        &self,
358        prompt_id: UserPromptId,
359        cx: &App,
360    ) -> Task<Result<(PromptMetadata, String)>> {
361        let prompt_id = PromptId::User { uuid: prompt_id };
362        let Some(prompt_store) = self.prompt_store.as_ref() else {
363            return Task::ready(Err(anyhow!("Prompt store unexpectedly missing.")));
364        };
365        let prompt_store = prompt_store.read(cx);
366        let Some(metadata) = prompt_store.metadata(prompt_id) else {
367            return Task::ready(Err(anyhow!("User rules not found in library.")));
368        };
369        let text_task = prompt_store.load(prompt_id, cx);
370        cx.background_spawn(async move { Ok((metadata, text_task.await?)) })
371    }
372
373    pub fn tools(&self) -> Entity<ToolWorkingSet> {
374        self.tools.clone()
375    }
376
377    /// Returns the number of threads.
378    pub fn thread_count(&self) -> usize {
379        self.threads.len()
380    }
381
382    pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
383        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
384        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
385        threads
386    }
387
388    pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
389        self.threads().into_iter().take(limit).collect()
390    }
391
392    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
393        cx.new(|cx| {
394            Thread::new(
395                self.project.clone(),
396                self.tools.clone(),
397                self.prompt_builder.clone(),
398                self.project_context.clone(),
399                cx,
400            )
401        })
402    }
403
404    pub fn open_thread(
405        &self,
406        id: &ThreadId,
407        cx: &mut Context<Self>,
408    ) -> Task<Result<Entity<Thread>>> {
409        let id = id.clone();
410        let database_future = ThreadsDatabase::global_future(cx);
411        cx.spawn(async move |this, cx| {
412            let database = database_future.await.map_err(|err| anyhow!(err))?;
413            let thread = database
414                .try_find_thread(id.clone())
415                .await?
416                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
417
418            let thread = this.update(cx, |this, cx| {
419                cx.new(|cx| {
420                    Thread::deserialize(
421                        id.clone(),
422                        thread,
423                        this.project.clone(),
424                        this.tools.clone(),
425                        this.prompt_builder.clone(),
426                        this.project_context.clone(),
427                        cx,
428                    )
429                })
430            })?;
431
432            Ok(thread)
433        })
434    }
435
436    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
437        let (metadata, serialized_thread) =
438            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
439
440        let database_future = ThreadsDatabase::global_future(cx);
441        cx.spawn(async move |this, cx| {
442            let serialized_thread = serialized_thread.await?;
443            let database = database_future.await.map_err(|err| anyhow!(err))?;
444            database.save_thread(metadata, serialized_thread).await?;
445
446            this.update(cx, |this, cx| this.reload(cx))?.await
447        })
448    }
449
450    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
451        let id = id.clone();
452        let database_future = ThreadsDatabase::global_future(cx);
453        cx.spawn(async move |this, cx| {
454            let database = database_future.await.map_err(|err| anyhow!(err))?;
455            database.delete_thread(id.clone()).await?;
456
457            this.update(cx, |this, cx| {
458                this.threads.retain(|thread| thread.id != id);
459                cx.notify();
460            })
461        })
462    }
463
464    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
465        let database_future = ThreadsDatabase::global_future(cx);
466        cx.spawn(async move |this, cx| {
467            let threads = database_future
468                .await
469                .map_err(|err| anyhow!(err))?
470                .list_threads()
471                .await?;
472
473            this.update(cx, |this, cx| {
474                this.threads = threads;
475                cx.notify();
476            })
477        })
478    }
479
480    fn load_default_profile(&self, cx: &mut Context<Self>) {
481        let assistant_settings = AssistantSettings::get_global(cx);
482
483        self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
484    }
485
486    pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
487        let assistant_settings = AssistantSettings::get_global(cx);
488
489        if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
490            self.load_profile(profile.clone(), cx);
491        }
492    }
493
494    pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
495        self.tools.update(cx, |tools, cx| {
496            tools.disable_all_tools(cx);
497            tools.enable(
498                ToolSource::Native,
499                &profile
500                    .tools
501                    .iter()
502                    .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
503                    .collect::<Vec<_>>(),
504                cx,
505            );
506        });
507
508        if profile.enable_all_context_servers {
509            for context_server in self.context_server_manager.read(cx).all_servers() {
510                self.tools.update(cx, |tools, cx| {
511                    tools.enable_source(
512                        ToolSource::ContextServer {
513                            id: context_server.id().into(),
514                        },
515                        cx,
516                    );
517                });
518            }
519        } else {
520            for (context_server_id, preset) in &profile.context_servers {
521                self.tools.update(cx, |tools, cx| {
522                    tools.enable(
523                        ToolSource::ContextServer {
524                            id: context_server_id.clone().into(),
525                        },
526                        &preset
527                            .tools
528                            .iter()
529                            .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
530                            .collect::<Vec<_>>(),
531                        cx,
532                    )
533                })
534            }
535        }
536    }
537
538    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
539        cx.subscribe(
540            &self.context_server_manager.clone(),
541            Self::handle_context_server_event,
542        )
543        .detach();
544    }
545
546    fn handle_context_server_event(
547        &mut self,
548        context_server_manager: Entity<ContextServerManager>,
549        event: &context_server::manager::Event,
550        cx: &mut Context<Self>,
551    ) {
552        let tool_working_set = self.tools.clone();
553        match event {
554            context_server::manager::Event::ServerStarted { server_id } => {
555                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
556                    let context_server_manager = context_server_manager.clone();
557                    cx.spawn({
558                        let server = server.clone();
559                        let server_id = server_id.clone();
560                        async move |this, cx| {
561                            let Some(protocol) = server.client() else {
562                                return;
563                            };
564
565                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
566                                if let Some(tools) = protocol.list_tools().await.log_err() {
567                                    let tool_ids = tool_working_set
568                                        .update(cx, |tool_working_set, _| {
569                                            tools
570                                                .tools
571                                                .into_iter()
572                                                .map(|tool| {
573                                                    log::info!(
574                                                        "registering context server tool: {:?}",
575                                                        tool.name
576                                                    );
577                                                    tool_working_set.insert(Arc::new(
578                                                        ContextServerTool::new(
579                                                            context_server_manager.clone(),
580                                                            server.id(),
581                                                            tool,
582                                                        ),
583                                                    ))
584                                                })
585                                                .collect::<Vec<_>>()
586                                        })
587                                        .log_err();
588
589                                    if let Some(tool_ids) = tool_ids {
590                                        this.update(cx, |this, cx| {
591                                            this.context_server_tool_ids
592                                                .insert(server_id, tool_ids);
593                                            this.load_default_profile(cx);
594                                        })
595                                        .log_err();
596                                    }
597                                }
598                            }
599                        }
600                    })
601                    .detach();
602                }
603            }
604            context_server::manager::Event::ServerStopped { server_id } => {
605                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
606                    tool_working_set.update(cx, |tool_working_set, _| {
607                        tool_working_set.remove(&tool_ids);
608                    });
609                    self.load_default_profile(cx);
610                }
611            }
612        }
613    }
614}
615
616#[derive(Debug, Clone, Serialize, Deserialize)]
617pub struct SerializedThreadMetadata {
618    pub id: ThreadId,
619    pub summary: SharedString,
620    pub updated_at: DateTime<Utc>,
621}
622
623#[derive(Serialize, Deserialize, Debug)]
624pub struct SerializedThread {
625    pub version: String,
626    pub summary: SharedString,
627    pub updated_at: DateTime<Utc>,
628    pub messages: Vec<SerializedMessage>,
629    #[serde(default)]
630    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
631    #[serde(default)]
632    pub cumulative_token_usage: TokenUsage,
633    #[serde(default)]
634    pub request_token_usage: Vec<TokenUsage>,
635    #[serde(default)]
636    pub detailed_summary_state: DetailedSummaryState,
637    #[serde(default)]
638    pub exceeded_window_error: Option<ExceededWindowError>,
639}
640
641impl SerializedThread {
642    pub const VERSION: &'static str = "0.1.0";
643
644    pub fn from_json(json: &[u8]) -> Result<Self> {
645        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
646        match saved_thread_json.get("version") {
647            Some(serde_json::Value::String(version)) => match version.as_str() {
648                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
649                    saved_thread_json,
650                )?),
651                _ => Err(anyhow!(
652                    "unrecognized serialized thread version: {}",
653                    version
654                )),
655            },
656            None => {
657                let saved_thread =
658                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
659                Ok(saved_thread.upgrade())
660            }
661            version => Err(anyhow!(
662                "unrecognized serialized thread version: {:?}",
663                version
664            )),
665        }
666    }
667}
668
669#[derive(Debug, Serialize, Deserialize)]
670pub struct SerializedMessage {
671    pub id: MessageId,
672    pub role: Role,
673    #[serde(default)]
674    pub segments: Vec<SerializedMessageSegment>,
675    #[serde(default)]
676    pub tool_uses: Vec<SerializedToolUse>,
677    #[serde(default)]
678    pub tool_results: Vec<SerializedToolResult>,
679    #[serde(default)]
680    pub context: String,
681}
682
683#[derive(Debug, Serialize, Deserialize)]
684#[serde(tag = "type")]
685pub enum SerializedMessageSegment {
686    #[serde(rename = "text")]
687    Text {
688        text: String,
689    },
690    #[serde(rename = "thinking")]
691    Thinking {
692        text: String,
693        #[serde(skip_serializing_if = "Option::is_none")]
694        signature: Option<String>,
695    },
696    RedactedThinking {
697        data: Vec<u8>,
698    },
699}
700
701#[derive(Debug, Serialize, Deserialize)]
702pub struct SerializedToolUse {
703    pub id: LanguageModelToolUseId,
704    pub name: SharedString,
705    pub input: serde_json::Value,
706}
707
708#[derive(Debug, Serialize, Deserialize)]
709pub struct SerializedToolResult {
710    pub tool_use_id: LanguageModelToolUseId,
711    pub is_error: bool,
712    pub content: Arc<str>,
713}
714
715#[derive(Serialize, Deserialize)]
716struct LegacySerializedThread {
717    pub summary: SharedString,
718    pub updated_at: DateTime<Utc>,
719    pub messages: Vec<LegacySerializedMessage>,
720    #[serde(default)]
721    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
722}
723
724impl LegacySerializedThread {
725    pub fn upgrade(self) -> SerializedThread {
726        SerializedThread {
727            version: SerializedThread::VERSION.to_string(),
728            summary: self.summary,
729            updated_at: self.updated_at,
730            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
731            initial_project_snapshot: self.initial_project_snapshot,
732            cumulative_token_usage: TokenUsage::default(),
733            request_token_usage: Vec::new(),
734            detailed_summary_state: DetailedSummaryState::default(),
735            exceeded_window_error: None,
736        }
737    }
738}
739
740#[derive(Debug, Serialize, Deserialize)]
741struct LegacySerializedMessage {
742    pub id: MessageId,
743    pub role: Role,
744    pub text: String,
745    #[serde(default)]
746    pub tool_uses: Vec<SerializedToolUse>,
747    #[serde(default)]
748    pub tool_results: Vec<SerializedToolResult>,
749}
750
751impl LegacySerializedMessage {
752    fn upgrade(self) -> SerializedMessage {
753        SerializedMessage {
754            id: self.id,
755            role: self.role,
756            segments: vec![SerializedMessageSegment::Text { text: self.text }],
757            tool_uses: self.tool_uses,
758            tool_results: self.tool_results,
759            context: String::new(),
760        }
761    }
762}
763
764struct GlobalThreadsDatabase(
765    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
766);
767
768impl Global for GlobalThreadsDatabase {}
769
770pub(crate) struct ThreadsDatabase {
771    executor: BackgroundExecutor,
772    env: heed::Env,
773    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
774}
775
776impl heed::BytesEncode<'_> for SerializedThread {
777    type EItem = SerializedThread;
778
779    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
780        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
781    }
782}
783
784impl<'a> heed::BytesDecode<'a> for SerializedThread {
785    type DItem = SerializedThread;
786
787    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
788        // We implement this type manually because we want to call `SerializedThread::from_json`,
789        // instead of the Deserialize trait implementation for `SerializedThread`.
790        SerializedThread::from_json(bytes).map_err(Into::into)
791    }
792}
793
794impl ThreadsDatabase {
795    fn global_future(
796        cx: &mut App,
797    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
798        GlobalThreadsDatabase::global(cx).0.clone()
799    }
800
801    fn init(cx: &mut App) {
802        let executor = cx.background_executor().clone();
803        let database_future = executor
804            .spawn({
805                let executor = executor.clone();
806                let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
807                async move { ThreadsDatabase::new(database_path, executor) }
808            })
809            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
810            .boxed()
811            .shared();
812
813        cx.set_global(GlobalThreadsDatabase(database_future));
814    }
815
816    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
817        std::fs::create_dir_all(&path)?;
818
819        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
820        let env = unsafe {
821            heed::EnvOpenOptions::new()
822                .map_size(ONE_GB_IN_BYTES)
823                .max_dbs(1)
824                .open(path)?
825        };
826
827        let mut txn = env.write_txn()?;
828        let threads = env.create_database(&mut txn, Some("threads"))?;
829        txn.commit()?;
830
831        Ok(Self {
832            executor,
833            env,
834            threads,
835        })
836    }
837
838    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
839        let env = self.env.clone();
840        let threads = self.threads;
841
842        self.executor.spawn(async move {
843            let txn = env.read_txn()?;
844            let mut iter = threads.iter(&txn)?;
845            let mut threads = Vec::new();
846            while let Some((key, value)) = iter.next().transpose()? {
847                threads.push(SerializedThreadMetadata {
848                    id: key,
849                    summary: value.summary,
850                    updated_at: value.updated_at,
851                });
852            }
853
854            Ok(threads)
855        })
856    }
857
858    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
859        let env = self.env.clone();
860        let threads = self.threads;
861
862        self.executor.spawn(async move {
863            let txn = env.read_txn()?;
864            let thread = threads.get(&txn, &id)?;
865            Ok(thread)
866        })
867    }
868
869    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
870        let env = self.env.clone();
871        let threads = self.threads;
872
873        self.executor.spawn(async move {
874            let mut txn = env.write_txn()?;
875            threads.put(&mut txn, &id, &thread)?;
876            txn.commit()?;
877            Ok(())
878        })
879    }
880
881    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
882        let env = self.env.clone();
883        let threads = self.threads;
884
885        self.executor.spawn(async move {
886            let mut txn = env.write_txn()?;
887            threads.delete(&mut txn, &id)?;
888            txn.commit()?;
889            Ok(())
890        })
891    }
892}