thread_store.rs

  1use std::borrow::Cow;
  2use std::cell::{Ref, RefCell};
  3use std::path::{Path, PathBuf};
  4use std::rc::Rc;
  5use std::sync::Arc;
  6
  7use anyhow::{Context as _, Result, anyhow};
  8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
  9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
 10use chrono::{DateTime, Utc};
 11use collections::HashMap;
 12use context_server::ContextServerId;
 13use futures::channel::{mpsc, oneshot};
 14use futures::future::{self, BoxFuture, Shared};
 15use futures::{FutureExt as _, StreamExt as _};
 16use gpui::{
 17    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
 18    Subscription, Task, prelude::*,
 19};
 20use heed::Database;
 21use heed::types::SerdeBincode;
 22use language_model::{LanguageModelToolUseId, Role, TokenUsage};
 23use project::context_server_store::{ContextServerStatus, ContextServerStore};
 24use project::{Project, ProjectItem, ProjectPath, Worktree};
 25use prompt_store::{
 26    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
 27    UserRulesContext, WorktreeContext,
 28};
 29use serde::{Deserialize, Serialize};
 30use settings::{Settings as _, SettingsStore};
 31use util::ResultExt as _;
 32
 33use crate::context_server_tool::ContextServerTool;
 34use crate::thread::{
 35    DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
 36};
 37
 38const RULES_FILE_NAMES: [&'static str; 6] = [
 39    ".rules",
 40    ".cursorrules",
 41    ".windsurfrules",
 42    ".clinerules",
 43    ".github/copilot-instructions.md",
 44    "CLAUDE.md",
 45];
 46
 47pub fn init(cx: &mut App) {
 48    ThreadsDatabase::init(cx);
 49}
 50
 51/// A system prompt shared by all threads created by this ThreadStore
 52#[derive(Clone, Default)]
 53pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
 54
 55impl SharedProjectContext {
 56    pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
 57        self.0.borrow()
 58    }
 59}
 60
 61pub type TextThreadStore = assistant_context_editor::ContextStore;
 62
 63pub struct ThreadStore {
 64    project: Entity<Project>,
 65    tools: Entity<ToolWorkingSet>,
 66    prompt_builder: Arc<PromptBuilder>,
 67    prompt_store: Option<Entity<PromptStore>>,
 68    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
 69    threads: Vec<SerializedThreadMetadata>,
 70    project_context: SharedProjectContext,
 71    reload_system_prompt_tx: mpsc::Sender<()>,
 72    _reload_system_prompt_task: Task<()>,
 73    _subscriptions: Vec<Subscription>,
 74}
 75
 76pub struct RulesLoadingError {
 77    pub message: SharedString,
 78}
 79
 80impl EventEmitter<RulesLoadingError> for ThreadStore {}
 81
 82impl ThreadStore {
 83    pub fn load(
 84        project: Entity<Project>,
 85        tools: Entity<ToolWorkingSet>,
 86        prompt_store: Option<Entity<PromptStore>>,
 87        prompt_builder: Arc<PromptBuilder>,
 88        cx: &mut App,
 89    ) -> Task<Result<Entity<Self>>> {
 90        cx.spawn(async move |cx| {
 91            let (thread_store, ready_rx) = cx.update(|cx| {
 92                let mut option_ready_rx = None;
 93                let thread_store = cx.new(|cx| {
 94                    let (thread_store, ready_rx) =
 95                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 96                    option_ready_rx = Some(ready_rx);
 97                    thread_store
 98                });
 99                (thread_store, option_ready_rx.take().unwrap())
100            })?;
101            ready_rx.await?;
102            Ok(thread_store)
103        })
104    }
105
106    fn new(
107        project: Entity<Project>,
108        tools: Entity<ToolWorkingSet>,
109        prompt_builder: Arc<PromptBuilder>,
110        prompt_store: Option<Entity<PromptStore>>,
111        cx: &mut Context<Self>,
112    ) -> (Self, oneshot::Receiver<()>) {
113        let mut subscriptions = vec![
114            cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
115                this.load_default_profile(cx);
116            }),
117            cx.subscribe(&project, Self::handle_project_event),
118        ];
119
120        if let Some(prompt_store) = prompt_store.as_ref() {
121            subscriptions.push(cx.subscribe(
122                prompt_store,
123                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
124                    this.enqueue_system_prompt_reload();
125                },
126            ))
127        }
128
129        // This channel and task prevent concurrent and redundant loading of the system prompt.
130        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
131        let (ready_tx, ready_rx) = oneshot::channel();
132        let mut ready_tx = Some(ready_tx);
133        let reload_system_prompt_task = cx.spawn({
134            let prompt_store = prompt_store.clone();
135            async move |thread_store, cx| {
136                loop {
137                    let Some(reload_task) = thread_store
138                        .update(cx, |thread_store, cx| {
139                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
140                        })
141                        .ok()
142                    else {
143                        return;
144                    };
145                    reload_task.await;
146                    if let Some(ready_tx) = ready_tx.take() {
147                        ready_tx.send(()).ok();
148                    }
149                    reload_system_prompt_rx.next().await;
150                }
151            }
152        });
153
154        let this = Self {
155            project,
156            tools,
157            prompt_builder,
158            prompt_store,
159            context_server_tool_ids: HashMap::default(),
160            threads: Vec::new(),
161            project_context: SharedProjectContext::default(),
162            reload_system_prompt_tx,
163            _reload_system_prompt_task: reload_system_prompt_task,
164            _subscriptions: subscriptions,
165        };
166        this.load_default_profile(cx);
167        this.register_context_server_handlers(cx);
168        this.reload(cx).detach_and_log_err(cx);
169        (this, ready_rx)
170    }
171
172    fn handle_project_event(
173        &mut self,
174        _project: Entity<Project>,
175        event: &project::Event,
176        _cx: &mut Context<Self>,
177    ) {
178        match event {
179            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
180                self.enqueue_system_prompt_reload();
181            }
182            project::Event::WorktreeUpdatedEntries(_, items) => {
183                if items.iter().any(|(path, _, _)| {
184                    RULES_FILE_NAMES
185                        .iter()
186                        .any(|name| path.as_ref() == Path::new(name))
187                }) {
188                    self.enqueue_system_prompt_reload();
189                }
190            }
191            _ => {}
192        }
193    }
194
195    fn enqueue_system_prompt_reload(&mut self) {
196        self.reload_system_prompt_tx.try_send(()).ok();
197    }
198
199    // Note that this should only be called from `reload_system_prompt_task`.
200    fn reload_system_prompt(
201        &self,
202        prompt_store: Option<Entity<PromptStore>>,
203        cx: &mut Context<Self>,
204    ) -> Task<()> {
205        let worktrees = self
206            .project
207            .read(cx)
208            .visible_worktrees(cx)
209            .collect::<Vec<_>>();
210        let worktree_tasks = worktrees
211            .into_iter()
212            .map(|worktree| {
213                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
214            })
215            .collect::<Vec<_>>();
216        let default_user_rules_task = match prompt_store {
217            None => Task::ready(vec![]),
218            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
219                let prompts = prompt_store.default_prompt_metadata();
220                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
221                    let contents = prompt_store.load(prompt_metadata.id, cx);
222                    async move { (contents.await, prompt_metadata) }
223                });
224                cx.background_spawn(future::join_all(load_tasks))
225            }),
226        };
227
228        cx.spawn(async move |this, cx| {
229            let (worktrees, default_user_rules) =
230                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
231
232            let worktrees = worktrees
233                .into_iter()
234                .map(|(worktree, rules_error)| {
235                    if let Some(rules_error) = rules_error {
236                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
237                    }
238                    worktree
239                })
240                .collect::<Vec<_>>();
241
242            let default_user_rules = default_user_rules
243                .into_iter()
244                .flat_map(|(contents, prompt_metadata)| match contents {
245                    Ok(contents) => Some(UserRulesContext {
246                        uuid: match prompt_metadata.id {
247                            PromptId::User { uuid } => uuid,
248                            PromptId::EditWorkflow => return None,
249                        },
250                        title: prompt_metadata.title.map(|title| title.to_string()),
251                        contents,
252                    }),
253                    Err(err) => {
254                        this.update(cx, |_, cx| {
255                            cx.emit(RulesLoadingError {
256                                message: format!("{err:?}").into(),
257                            });
258                        })
259                        .ok();
260                        None
261                    }
262                })
263                .collect::<Vec<_>>();
264
265            this.update(cx, |this, _cx| {
266                *this.project_context.0.borrow_mut() =
267                    Some(ProjectContext::new(worktrees, default_user_rules));
268            })
269            .ok();
270        })
271    }
272
273    fn load_worktree_info_for_system_prompt(
274        worktree: Entity<Worktree>,
275        project: Entity<Project>,
276        cx: &mut App,
277    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
278        let root_name = worktree.read(cx).root_name().into();
279
280        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
281        let Some(rules_task) = rules_task else {
282            return Task::ready((
283                WorktreeContext {
284                    root_name,
285                    rules_file: None,
286                },
287                None,
288            ));
289        };
290
291        cx.spawn(async move |_| {
292            let (rules_file, rules_file_error) = match rules_task.await {
293                Ok(rules_file) => (Some(rules_file), None),
294                Err(err) => (
295                    None,
296                    Some(RulesLoadingError {
297                        message: format!("{err}").into(),
298                    }),
299                ),
300            };
301            let worktree_info = WorktreeContext {
302                root_name,
303                rules_file,
304            };
305            (worktree_info, rules_file_error)
306        })
307    }
308
309    fn load_worktree_rules_file(
310        worktree: Entity<Worktree>,
311        project: Entity<Project>,
312        cx: &mut App,
313    ) -> Option<Task<Result<RulesFileContext>>> {
314        let worktree_ref = worktree.read(cx);
315        let worktree_id = worktree_ref.id();
316        let selected_rules_file = RULES_FILE_NAMES
317            .into_iter()
318            .filter_map(|name| {
319                worktree_ref
320                    .entry_for_path(name)
321                    .filter(|entry| entry.is_file())
322                    .map(|entry| entry.path.clone())
323            })
324            .next();
325
326        // Note that Cline supports `.clinerules` being a directory, but that is not currently
327        // supported. This doesn't seem to occur often in GitHub repositories.
328        selected_rules_file.map(|path_in_worktree| {
329            let project_path = ProjectPath {
330                worktree_id,
331                path: path_in_worktree.clone(),
332            };
333            let buffer_task =
334                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
335            let rope_task = cx.spawn(async move |cx| {
336                buffer_task.await?.read_with(cx, |buffer, cx| {
337                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
338                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
339                })?
340            });
341            // Build a string from the rope on a background thread.
342            cx.background_spawn(async move {
343                let (project_entry_id, rope) = rope_task.await?;
344                anyhow::Ok(RulesFileContext {
345                    path_in_worktree,
346                    text: rope.to_string().trim().to_string(),
347                    project_entry_id: project_entry_id.to_usize(),
348                })
349            })
350        })
351    }
352
353    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
354        &self.prompt_store
355    }
356
357    pub fn tools(&self) -> Entity<ToolWorkingSet> {
358        self.tools.clone()
359    }
360
361    /// Returns the number of threads.
362    pub fn thread_count(&self) -> usize {
363        self.threads.len()
364    }
365
366    pub fn unordered_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
367        self.threads.iter()
368    }
369
370    pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
371        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
372        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
373        threads
374    }
375
376    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
377        cx.new(|cx| {
378            Thread::new(
379                self.project.clone(),
380                self.tools.clone(),
381                self.prompt_builder.clone(),
382                self.project_context.clone(),
383                cx,
384            )
385        })
386    }
387
388    pub fn open_thread(
389        &self,
390        id: &ThreadId,
391        cx: &mut Context<Self>,
392    ) -> Task<Result<Entity<Thread>>> {
393        let id = id.clone();
394        let database_future = ThreadsDatabase::global_future(cx);
395        cx.spawn(async move |this, cx| {
396            let database = database_future.await.map_err(|err| anyhow!(err))?;
397            let thread = database
398                .try_find_thread(id.clone())
399                .await?
400                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
401
402            let thread = this.update(cx, |this, cx| {
403                cx.new(|cx| {
404                    Thread::deserialize(
405                        id.clone(),
406                        thread,
407                        this.project.clone(),
408                        this.tools.clone(),
409                        this.prompt_builder.clone(),
410                        this.project_context.clone(),
411                        cx,
412                    )
413                })
414            })?;
415
416            Ok(thread)
417        })
418    }
419
420    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
421        let (metadata, serialized_thread) =
422            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
423
424        let database_future = ThreadsDatabase::global_future(cx);
425        cx.spawn(async move |this, cx| {
426            let serialized_thread = serialized_thread.await?;
427            let database = database_future.await.map_err(|err| anyhow!(err))?;
428            database.save_thread(metadata, serialized_thread).await?;
429
430            this.update(cx, |this, cx| this.reload(cx))?.await
431        })
432    }
433
434    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
435        let id = id.clone();
436        let database_future = ThreadsDatabase::global_future(cx);
437        cx.spawn(async move |this, cx| {
438            let database = database_future.await.map_err(|err| anyhow!(err))?;
439            database.delete_thread(id.clone()).await?;
440
441            this.update(cx, |this, cx| {
442                this.threads.retain(|thread| thread.id != id);
443                cx.notify();
444            })
445        })
446    }
447
448    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
449        let database_future = ThreadsDatabase::global_future(cx);
450        cx.spawn(async move |this, cx| {
451            let threads = database_future
452                .await
453                .map_err(|err| anyhow!(err))?
454                .list_threads()
455                .await?;
456
457            this.update(cx, |this, cx| {
458                this.threads = threads;
459                cx.notify();
460            })
461        })
462    }
463
464    fn load_default_profile(&self, cx: &mut Context<Self>) {
465        let assistant_settings = AssistantSettings::get_global(cx);
466
467        self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
468    }
469
470    pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
471        let assistant_settings = AssistantSettings::get_global(cx);
472
473        if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
474            self.load_profile(profile.clone(), cx);
475        }
476    }
477
478    pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
479        self.tools.update(cx, |tools, cx| {
480            tools.disable_all_tools(cx);
481            tools.enable(
482                ToolSource::Native,
483                &profile
484                    .tools
485                    .iter()
486                    .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
487                    .collect::<Vec<_>>(),
488                cx,
489            );
490        });
491
492        if profile.enable_all_context_servers {
493            for context_server_id in self
494                .project
495                .read(cx)
496                .context_server_store()
497                .read(cx)
498                .all_server_ids()
499            {
500                self.tools.update(cx, |tools, cx| {
501                    tools.enable_source(
502                        ToolSource::ContextServer {
503                            id: context_server_id.0.into(),
504                        },
505                        cx,
506                    );
507                });
508            }
509            // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
510            for (context_server_id, preset) in &profile.context_servers {
511                self.tools.update(cx, |tools, cx| {
512                    tools.disable(
513                        ToolSource::ContextServer {
514                            id: context_server_id.clone().into(),
515                        },
516                        &preset
517                            .tools
518                            .iter()
519                            .filter_map(|(tool, enabled)| (!enabled).then(|| tool.clone()))
520                            .collect::<Vec<_>>(),
521                        cx,
522                    )
523                })
524            }
525        } else {
526            for (context_server_id, preset) in &profile.context_servers {
527                self.tools.update(cx, |tools, cx| {
528                    tools.enable(
529                        ToolSource::ContextServer {
530                            id: context_server_id.clone().into(),
531                        },
532                        &preset
533                            .tools
534                            .iter()
535                            .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
536                            .collect::<Vec<_>>(),
537                        cx,
538                    )
539                })
540            }
541        }
542    }
543
544    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
545        cx.subscribe(
546            &self.project.read(cx).context_server_store(),
547            Self::handle_context_server_event,
548        )
549        .detach();
550    }
551
552    fn handle_context_server_event(
553        &mut self,
554        context_server_store: Entity<ContextServerStore>,
555        event: &project::context_server_store::Event,
556        cx: &mut Context<Self>,
557    ) {
558        let tool_working_set = self.tools.clone();
559        match event {
560            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
561                match status {
562                    ContextServerStatus::Running => {
563                        if let Some(server) =
564                            context_server_store.read(cx).get_running_server(server_id)
565                        {
566                            let context_server_manager = context_server_store.clone();
567                            cx.spawn({
568                                let server = server.clone();
569                                let server_id = server_id.clone();
570                                async move |this, cx| {
571                                    let Some(protocol) = server.client() else {
572                                        return;
573                                    };
574
575                                    if protocol.capable(context_server::protocol::ServerCapability::Tools) {
576                                        if let Some(tools) = protocol.list_tools().await.log_err() {
577                                            let tool_ids = tool_working_set
578                                                .update(cx, |tool_working_set, _| {
579                                                    tools
580                                                        .tools
581                                                        .into_iter()
582                                                        .map(|tool| {
583                                                            log::info!(
584                                                                "registering context server tool: {:?}",
585                                                                tool.name
586                                                            );
587                                                            tool_working_set.insert(Arc::new(
588                                                                ContextServerTool::new(
589                                                                    context_server_manager.clone(),
590                                                                    server.id(),
591                                                                    tool,
592                                                                ),
593                                                            ))
594                                                        })
595                                                        .collect::<Vec<_>>()
596                                                })
597                                                .log_err();
598
599                                            if let Some(tool_ids) = tool_ids {
600                                                this.update(cx, |this, cx| {
601                                                    this.context_server_tool_ids
602                                                        .insert(server_id, tool_ids);
603                                                    this.load_default_profile(cx);
604                                                })
605                                                .log_err();
606                                            }
607                                        }
608                                    }
609                                }
610                            })
611                            .detach();
612                        }
613                    }
614                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
615                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
616                            tool_working_set.update(cx, |tool_working_set, _| {
617                                tool_working_set.remove(&tool_ids);
618                            });
619                            self.load_default_profile(cx);
620                        }
621                    }
622                    _ => {}
623                }
624            }
625        }
626    }
627}
628
629#[derive(Debug, Clone, Serialize, Deserialize)]
630pub struct SerializedThreadMetadata {
631    pub id: ThreadId,
632    pub summary: SharedString,
633    pub updated_at: DateTime<Utc>,
634}
635
636#[derive(Serialize, Deserialize, Debug)]
637pub struct SerializedThread {
638    pub version: String,
639    pub summary: SharedString,
640    pub updated_at: DateTime<Utc>,
641    pub messages: Vec<SerializedMessage>,
642    #[serde(default)]
643    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
644    #[serde(default)]
645    pub cumulative_token_usage: TokenUsage,
646    #[serde(default)]
647    pub request_token_usage: Vec<TokenUsage>,
648    #[serde(default)]
649    pub detailed_summary_state: DetailedSummaryState,
650    #[serde(default)]
651    pub exceeded_window_error: Option<ExceededWindowError>,
652    #[serde(default)]
653    pub model: Option<SerializedLanguageModel>,
654}
655
656#[derive(Serialize, Deserialize, Debug)]
657pub struct SerializedLanguageModel {
658    pub provider: String,
659    pub model: String,
660}
661
662impl SerializedThread {
663    pub const VERSION: &'static str = "0.2.0";
664
665    pub fn from_json(json: &[u8]) -> Result<Self> {
666        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
667        match saved_thread_json.get("version") {
668            Some(serde_json::Value::String(version)) => match version.as_str() {
669                SerializedThreadV0_1_0::VERSION => {
670                    let saved_thread =
671                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
672                    Ok(saved_thread.upgrade())
673                }
674                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
675                    saved_thread_json,
676                )?),
677                _ => Err(anyhow!(
678                    "unrecognized serialized thread version: {}",
679                    version
680                )),
681            },
682            None => {
683                let saved_thread =
684                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
685                Ok(saved_thread.upgrade())
686            }
687            version => Err(anyhow!(
688                "unrecognized serialized thread version: {:?}",
689                version
690            )),
691        }
692    }
693}
694
695#[derive(Serialize, Deserialize, Debug)]
696pub struct SerializedThreadV0_1_0(
697    // The structure did not change, so we are reusing the latest SerializedThread.
698    // When making the next version, make sure this points to SerializedThreadV0_2_0
699    SerializedThread,
700);
701
702impl SerializedThreadV0_1_0 {
703    pub const VERSION: &'static str = "0.1.0";
704
705    pub fn upgrade(self) -> SerializedThread {
706        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
707
708        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
709
710        for message in self.0.messages {
711            if message.role == Role::User && !message.tool_results.is_empty() {
712                if let Some(last_message) = messages.last_mut() {
713                    debug_assert!(last_message.role == Role::Assistant);
714
715                    last_message.tool_results = message.tool_results;
716                    continue;
717                }
718            }
719
720            messages.push(message);
721        }
722
723        SerializedThread { messages, ..self.0 }
724    }
725}
726
727#[derive(Debug, Serialize, Deserialize)]
728pub struct SerializedMessage {
729    pub id: MessageId,
730    pub role: Role,
731    #[serde(default)]
732    pub segments: Vec<SerializedMessageSegment>,
733    #[serde(default)]
734    pub tool_uses: Vec<SerializedToolUse>,
735    #[serde(default)]
736    pub tool_results: Vec<SerializedToolResult>,
737    #[serde(default)]
738    pub context: String,
739    #[serde(default)]
740    pub creases: Vec<SerializedCrease>,
741}
742
743#[derive(Debug, Serialize, Deserialize)]
744#[serde(tag = "type")]
745pub enum SerializedMessageSegment {
746    #[serde(rename = "text")]
747    Text {
748        text: String,
749    },
750    #[serde(rename = "thinking")]
751    Thinking {
752        text: String,
753        #[serde(skip_serializing_if = "Option::is_none")]
754        signature: Option<String>,
755    },
756    RedactedThinking {
757        data: Vec<u8>,
758    },
759}
760
761#[derive(Debug, Serialize, Deserialize)]
762pub struct SerializedToolUse {
763    pub id: LanguageModelToolUseId,
764    pub name: SharedString,
765    pub input: serde_json::Value,
766}
767
768#[derive(Debug, Serialize, Deserialize)]
769pub struct SerializedToolResult {
770    pub tool_use_id: LanguageModelToolUseId,
771    pub is_error: bool,
772    pub content: Arc<str>,
773}
774
775#[derive(Serialize, Deserialize)]
776struct LegacySerializedThread {
777    pub summary: SharedString,
778    pub updated_at: DateTime<Utc>,
779    pub messages: Vec<LegacySerializedMessage>,
780    #[serde(default)]
781    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
782}
783
784impl LegacySerializedThread {
785    pub fn upgrade(self) -> SerializedThread {
786        SerializedThread {
787            version: SerializedThread::VERSION.to_string(),
788            summary: self.summary,
789            updated_at: self.updated_at,
790            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
791            initial_project_snapshot: self.initial_project_snapshot,
792            cumulative_token_usage: TokenUsage::default(),
793            request_token_usage: Vec::new(),
794            detailed_summary_state: DetailedSummaryState::default(),
795            exceeded_window_error: None,
796            model: None,
797        }
798    }
799}
800
801#[derive(Debug, Serialize, Deserialize)]
802struct LegacySerializedMessage {
803    pub id: MessageId,
804    pub role: Role,
805    pub text: String,
806    #[serde(default)]
807    pub tool_uses: Vec<SerializedToolUse>,
808    #[serde(default)]
809    pub tool_results: Vec<SerializedToolResult>,
810}
811
812impl LegacySerializedMessage {
813    fn upgrade(self) -> SerializedMessage {
814        SerializedMessage {
815            id: self.id,
816            role: self.role,
817            segments: vec![SerializedMessageSegment::Text { text: self.text }],
818            tool_uses: self.tool_uses,
819            tool_results: self.tool_results,
820            context: String::new(),
821            creases: Vec::new(),
822        }
823    }
824}
825
/// Serialized crease (a collapsible, labelled span) attached to a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedCrease {
    /// Start of the span. NOTE(review): presumably an offset into the message text — confirm.
    pub start: usize,
    /// End of the span. NOTE(review): presumably exclusive, same coordinate space as `start` — confirm.
    pub end: usize,
    /// Path of the icon shown for the crease.
    pub icon_path: SharedString,
    /// Label shown for the crease.
    pub label: SharedString,
}
833
/// App-global holder for the future that opens the threads database.
///
/// The future is `Shared` so that every awaiter observes the same outcome. Its error is
/// wrapped in `Arc` because a `Shared` future's output must be `Clone`.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
839
/// LMDB-backed (via `heed`) store of serialized threads, keyed by [`ThreadId`].
pub(crate) struct ThreadsDatabase {
    /// Executor on which all database reads and writes are spawned.
    executor: BackgroundExecutor,
    /// The opened LMDB environment.
    env: heed::Env,
    /// The `threads` database. Keys are encoded with bincode; values use the
    /// `BytesEncode`/`BytesDecode` impls for `SerializedThread` (JSON).
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
845
846impl heed::BytesEncode<'_> for SerializedThread {
847    type EItem = SerializedThread;
848
849    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
850        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
851    }
852}
853
854impl<'a> heed::BytesDecode<'a> for SerializedThread {
855    type DItem = SerializedThread;
856
857    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
858        // We implement this type manually because we want to call `SerializedThread::from_json`,
859        // instead of the Deserialize trait implementation for `SerializedThread`.
860        SerializedThread::from_json(bytes).map_err(Into::into)
861    }
862}
863
864impl ThreadsDatabase {
865    fn global_future(
866        cx: &mut App,
867    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
868        GlobalThreadsDatabase::global(cx).0.clone()
869    }
870
871    fn init(cx: &mut App) {
872        let executor = cx.background_executor().clone();
873        let database_future = executor
874            .spawn({
875                let executor = executor.clone();
876                let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
877                async move { ThreadsDatabase::new(database_path, executor) }
878            })
879            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
880            .boxed()
881            .shared();
882
883        cx.set_global(GlobalThreadsDatabase(database_future));
884    }
885
886    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
887        std::fs::create_dir_all(&path)?;
888
889        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
890        let env = unsafe {
891            heed::EnvOpenOptions::new()
892                .map_size(ONE_GB_IN_BYTES)
893                .max_dbs(1)
894                .open(path)?
895        };
896
897        let mut txn = env.write_txn()?;
898        let threads = env.create_database(&mut txn, Some("threads"))?;
899        txn.commit()?;
900
901        Ok(Self {
902            executor,
903            env,
904            threads,
905        })
906    }
907
908    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
909        let env = self.env.clone();
910        let threads = self.threads;
911
912        self.executor.spawn(async move {
913            let txn = env.read_txn()?;
914            let mut iter = threads.iter(&txn)?;
915            let mut threads = Vec::new();
916            while let Some((key, value)) = iter.next().transpose()? {
917                threads.push(SerializedThreadMetadata {
918                    id: key,
919                    summary: value.summary,
920                    updated_at: value.updated_at,
921                });
922            }
923
924            Ok(threads)
925        })
926    }
927
928    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
929        let env = self.env.clone();
930        let threads = self.threads;
931
932        self.executor.spawn(async move {
933            let txn = env.read_txn()?;
934            let thread = threads.get(&txn, &id)?;
935            Ok(thread)
936        })
937    }
938
939    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
940        let env = self.env.clone();
941        let threads = self.threads;
942
943        self.executor.spawn(async move {
944            let mut txn = env.write_txn()?;
945            threads.put(&mut txn, &id, &thread)?;
946            txn.commit()?;
947            Ok(())
948        })
949    }
950
951    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
952        let env = self.env.clone();
953        let threads = self.threads;
954
955        self.executor.spawn(async move {
956            let mut txn = env.write_txn()?;
957            threads.delete(&mut txn, &id)?;
958            txn.commit()?;
959            Ok(())
960        })
961    }
962}