thread_store.rs

  1use std::borrow::Cow;
  2use std::cell::{Ref, RefCell};
  3use std::path::{Path, PathBuf};
  4use std::rc::Rc;
  5use std::sync::Arc;
  6
  7use anyhow::{Context as _, Result, anyhow};
  8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings, CompletionMode};
  9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
 10use chrono::{DateTime, Utc};
 11use collections::HashMap;
 12use context_server::ContextServerId;
 13use futures::channel::{mpsc, oneshot};
 14use futures::future::{self, BoxFuture, Shared};
 15use futures::{FutureExt as _, StreamExt as _};
 16use gpui::{
 17    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
 18    Subscription, Task, prelude::*,
 19};
 20use heed::Database;
 21use heed::types::SerdeBincode;
 22use language_model::{LanguageModelToolUseId, Role, TokenUsage};
 23use project::context_server_store::{ContextServerStatus, ContextServerStore};
 24use project::{Project, ProjectItem, ProjectPath, Worktree};
 25use prompt_store::{
 26    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
 27    UserRulesContext, WorktreeContext,
 28};
 29use serde::{Deserialize, Serialize};
 30use settings::{Settings as _, SettingsStore};
 31use util::ResultExt as _;
 32
 33use crate::context_server_tool::ContextServerTool;
 34use crate::thread::{
 35    DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
 36};
 37
 38const RULES_FILE_NAMES: [&'static str; 6] = [
 39    ".rules",
 40    ".cursorrules",
 41    ".windsurfrules",
 42    ".clinerules",
 43    ".github/copilot-instructions.md",
 44    "CLAUDE.md",
 45];
 46
/// Module initialization: delegates to `ThreadsDatabase::init`.
///
/// NOTE(review): presumably this installs the `GlobalThreadsDatabase` global that
/// `ThreadsDatabase::global_future` reads; the body of `ThreadsDatabase::init` is not visible
/// here, so confirm.
pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}
 50
 51/// A system prompt shared by all threads created by this ThreadStore
 52#[derive(Clone, Default)]
 53pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
 54
 55impl SharedProjectContext {
 56    pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
 57        self.0.borrow()
 58    }
 59}
 60
 61pub type TextThreadStore = assistant_context_editor::ContextStore;
 62
/// Holds the metadata of a project's saved threads, plus the state shared by every thread it
/// creates: the project, the tool working set, the prompt builder, and the project context
/// used for the system prompt.
pub struct ThreadStore {
    project: Entity<Project>,
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    prompt_store: Option<Entity<PromptStore>>,
    /// Tool ids registered for each running context server, kept so that the tools can be
    /// removed from `tools` when the server stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Saved-thread metadata, replaced wholesale by `reload` from the threads database.
    threads: Vec<SerializedThreadMetadata>,
    project_context: SharedProjectContext,
    /// Wakes `_reload_system_prompt_task` so it rebuilds `project_context`.
    reload_system_prompt_tx: mpsc::Sender<()>,
    // Held only to keep the task and subscriptions alive for the store's lifetime.
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}

/// Event emitted by [`ThreadStore`] when a worktree rules file or a default user rule fails
/// to load.
pub struct RulesLoadingError {
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}
 81
 82impl ThreadStore {
 83    pub fn load(
 84        project: Entity<Project>,
 85        tools: Entity<ToolWorkingSet>,
 86        prompt_store: Option<Entity<PromptStore>>,
 87        prompt_builder: Arc<PromptBuilder>,
 88        cx: &mut App,
 89    ) -> Task<Result<Entity<Self>>> {
 90        cx.spawn(async move |cx| {
 91            let (thread_store, ready_rx) = cx.update(|cx| {
 92                let mut option_ready_rx = None;
 93                let thread_store = cx.new(|cx| {
 94                    let (thread_store, ready_rx) =
 95                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 96                    option_ready_rx = Some(ready_rx);
 97                    thread_store
 98                });
 99                (thread_store, option_ready_rx.take().unwrap())
100            })?;
101            ready_rx.await?;
102            Ok(thread_store)
103        })
104    }
105
106    fn new(
107        project: Entity<Project>,
108        tools: Entity<ToolWorkingSet>,
109        prompt_builder: Arc<PromptBuilder>,
110        prompt_store: Option<Entity<PromptStore>>,
111        cx: &mut Context<Self>,
112    ) -> (Self, oneshot::Receiver<()>) {
113        let mut subscriptions = vec![
114            cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
115                this.load_default_profile(cx);
116            }),
117            cx.subscribe(&project, Self::handle_project_event),
118        ];
119
120        if let Some(prompt_store) = prompt_store.as_ref() {
121            subscriptions.push(cx.subscribe(
122                prompt_store,
123                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
124                    this.enqueue_system_prompt_reload();
125                },
126            ))
127        }
128
129        // This channel and task prevent concurrent and redundant loading of the system prompt.
130        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
131        let (ready_tx, ready_rx) = oneshot::channel();
132        let mut ready_tx = Some(ready_tx);
133        let reload_system_prompt_task = cx.spawn({
134            let prompt_store = prompt_store.clone();
135            async move |thread_store, cx| {
136                loop {
137                    let Some(reload_task) = thread_store
138                        .update(cx, |thread_store, cx| {
139                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
140                        })
141                        .ok()
142                    else {
143                        return;
144                    };
145                    reload_task.await;
146                    if let Some(ready_tx) = ready_tx.take() {
147                        ready_tx.send(()).ok();
148                    }
149                    reload_system_prompt_rx.next().await;
150                }
151            }
152        });
153
154        let this = Self {
155            project,
156            tools,
157            prompt_builder,
158            prompt_store,
159            context_server_tool_ids: HashMap::default(),
160            threads: Vec::new(),
161            project_context: SharedProjectContext::default(),
162            reload_system_prompt_tx,
163            _reload_system_prompt_task: reload_system_prompt_task,
164            _subscriptions: subscriptions,
165        };
166        this.load_default_profile(cx);
167        this.register_context_server_handlers(cx);
168        this.reload(cx).detach_and_log_err(cx);
169        (this, ready_rx)
170    }
171
172    fn handle_project_event(
173        &mut self,
174        _project: Entity<Project>,
175        event: &project::Event,
176        _cx: &mut Context<Self>,
177    ) {
178        match event {
179            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
180                self.enqueue_system_prompt_reload();
181            }
182            project::Event::WorktreeUpdatedEntries(_, items) => {
183                if items.iter().any(|(path, _, _)| {
184                    RULES_FILE_NAMES
185                        .iter()
186                        .any(|name| path.as_ref() == Path::new(name))
187                }) {
188                    self.enqueue_system_prompt_reload();
189                }
190            }
191            _ => {}
192        }
193    }
194
195    fn enqueue_system_prompt_reload(&mut self) {
196        self.reload_system_prompt_tx.try_send(()).ok();
197    }
198
199    // Note that this should only be called from `reload_system_prompt_task`.
200    fn reload_system_prompt(
201        &self,
202        prompt_store: Option<Entity<PromptStore>>,
203        cx: &mut Context<Self>,
204    ) -> Task<()> {
205        let worktrees = self
206            .project
207            .read(cx)
208            .visible_worktrees(cx)
209            .collect::<Vec<_>>();
210        let worktree_tasks = worktrees
211            .into_iter()
212            .map(|worktree| {
213                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
214            })
215            .collect::<Vec<_>>();
216        let default_user_rules_task = match prompt_store {
217            None => Task::ready(vec![]),
218            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
219                let prompts = prompt_store.default_prompt_metadata();
220                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
221                    let contents = prompt_store.load(prompt_metadata.id, cx);
222                    async move { (contents.await, prompt_metadata) }
223                });
224                cx.background_spawn(future::join_all(load_tasks))
225            }),
226        };
227
228        cx.spawn(async move |this, cx| {
229            let (worktrees, default_user_rules) =
230                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
231
232            let worktrees = worktrees
233                .into_iter()
234                .map(|(worktree, rules_error)| {
235                    if let Some(rules_error) = rules_error {
236                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
237                    }
238                    worktree
239                })
240                .collect::<Vec<_>>();
241
242            let default_user_rules = default_user_rules
243                .into_iter()
244                .flat_map(|(contents, prompt_metadata)| match contents {
245                    Ok(contents) => Some(UserRulesContext {
246                        uuid: match prompt_metadata.id {
247                            PromptId::User { uuid } => uuid,
248                            PromptId::EditWorkflow => return None,
249                        },
250                        title: prompt_metadata.title.map(|title| title.to_string()),
251                        contents,
252                    }),
253                    Err(err) => {
254                        this.update(cx, |_, cx| {
255                            cx.emit(RulesLoadingError {
256                                message: format!("{err:?}").into(),
257                            });
258                        })
259                        .ok();
260                        None
261                    }
262                })
263                .collect::<Vec<_>>();
264
265            this.update(cx, |this, _cx| {
266                *this.project_context.0.borrow_mut() =
267                    Some(ProjectContext::new(worktrees, default_user_rules));
268            })
269            .ok();
270        })
271    }
272
273    fn load_worktree_info_for_system_prompt(
274        worktree: Entity<Worktree>,
275        project: Entity<Project>,
276        cx: &mut App,
277    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
278        let root_name = worktree.read(cx).root_name().into();
279
280        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
281        let Some(rules_task) = rules_task else {
282            return Task::ready((
283                WorktreeContext {
284                    root_name,
285                    rules_file: None,
286                },
287                None,
288            ));
289        };
290
291        cx.spawn(async move |_| {
292            let (rules_file, rules_file_error) = match rules_task.await {
293                Ok(rules_file) => (Some(rules_file), None),
294                Err(err) => (
295                    None,
296                    Some(RulesLoadingError {
297                        message: format!("{err}").into(),
298                    }),
299                ),
300            };
301            let worktree_info = WorktreeContext {
302                root_name,
303                rules_file,
304            };
305            (worktree_info, rules_file_error)
306        })
307    }
308
309    fn load_worktree_rules_file(
310        worktree: Entity<Worktree>,
311        project: Entity<Project>,
312        cx: &mut App,
313    ) -> Option<Task<Result<RulesFileContext>>> {
314        let worktree_ref = worktree.read(cx);
315        let worktree_id = worktree_ref.id();
316        let selected_rules_file = RULES_FILE_NAMES
317            .into_iter()
318            .filter_map(|name| {
319                worktree_ref
320                    .entry_for_path(name)
321                    .filter(|entry| entry.is_file())
322                    .map(|entry| entry.path.clone())
323            })
324            .next();
325
326        // Note that Cline supports `.clinerules` being a directory, but that is not currently
327        // supported. This doesn't seem to occur often in GitHub repositories.
328        selected_rules_file.map(|path_in_worktree| {
329            let project_path = ProjectPath {
330                worktree_id,
331                path: path_in_worktree.clone(),
332            };
333            let buffer_task =
334                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
335            let rope_task = cx.spawn(async move |cx| {
336                buffer_task.await?.read_with(cx, |buffer, cx| {
337                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
338                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
339                })?
340            });
341            // Build a string from the rope on a background thread.
342            cx.background_spawn(async move {
343                let (project_entry_id, rope) = rope_task.await?;
344                anyhow::Ok(RulesFileContext {
345                    path_in_worktree,
346                    text: rope.to_string().trim().to_string(),
347                    project_entry_id: project_entry_id.to_usize(),
348                })
349            })
350        })
351    }
352
353    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
354        &self.prompt_store
355    }
356
357    pub fn tools(&self) -> Entity<ToolWorkingSet> {
358        self.tools.clone()
359    }
360
361    /// Returns the number of threads.
362    pub fn thread_count(&self) -> usize {
363        self.threads.len()
364    }
365
366    pub fn unordered_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
367        self.threads.iter()
368    }
369
370    pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
371        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
372        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
373        threads
374    }
375
376    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
377        cx.new(|cx| {
378            Thread::new(
379                self.project.clone(),
380                self.tools.clone(),
381                self.prompt_builder.clone(),
382                self.project_context.clone(),
383                cx,
384            )
385        })
386    }
387
388    pub fn open_thread(
389        &self,
390        id: &ThreadId,
391        cx: &mut Context<Self>,
392    ) -> Task<Result<Entity<Thread>>> {
393        let id = id.clone();
394        let database_future = ThreadsDatabase::global_future(cx);
395        cx.spawn(async move |this, cx| {
396            let database = database_future.await.map_err(|err| anyhow!(err))?;
397            let thread = database
398                .try_find_thread(id.clone())
399                .await?
400                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
401
402            let thread = this.update(cx, |this, cx| {
403                cx.new(|cx| {
404                    Thread::deserialize(
405                        id.clone(),
406                        thread,
407                        this.project.clone(),
408                        this.tools.clone(),
409                        this.prompt_builder.clone(),
410                        this.project_context.clone(),
411                        cx,
412                    )
413                })
414            })?;
415
416            Ok(thread)
417        })
418    }
419
420    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
421        let (metadata, serialized_thread) =
422            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
423
424        let database_future = ThreadsDatabase::global_future(cx);
425        cx.spawn(async move |this, cx| {
426            let serialized_thread = serialized_thread.await?;
427            let database = database_future.await.map_err(|err| anyhow!(err))?;
428            database.save_thread(metadata, serialized_thread).await?;
429
430            this.update(cx, |this, cx| this.reload(cx))?.await
431        })
432    }
433
434    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
435        let id = id.clone();
436        let database_future = ThreadsDatabase::global_future(cx);
437        cx.spawn(async move |this, cx| {
438            let database = database_future.await.map_err(|err| anyhow!(err))?;
439            database.delete_thread(id.clone()).await?;
440
441            this.update(cx, |this, cx| {
442                this.threads.retain(|thread| thread.id != id);
443                cx.notify();
444            })
445        })
446    }
447
448    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
449        let database_future = ThreadsDatabase::global_future(cx);
450        cx.spawn(async move |this, cx| {
451            let threads = database_future
452                .await
453                .map_err(|err| anyhow!(err))?
454                .list_threads()
455                .await?;
456
457            this.update(cx, |this, cx| {
458                this.threads = threads;
459                cx.notify();
460            })
461        })
462    }
463
464    fn load_default_profile(&self, cx: &mut Context<Self>) {
465        let assistant_settings = AssistantSettings::get_global(cx);
466
467        self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
468    }
469
470    pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
471        let assistant_settings = AssistantSettings::get_global(cx);
472
473        if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
474            self.load_profile(profile.clone(), cx);
475        }
476    }
477
478    pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
479        self.tools.update(cx, |tools, cx| {
480            tools.disable_all_tools(cx);
481            tools.enable(
482                ToolSource::Native,
483                &profile
484                    .tools
485                    .iter()
486                    .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
487                    .collect::<Vec<_>>(),
488                cx,
489            );
490        });
491
492        if profile.enable_all_context_servers {
493            for context_server_id in self
494                .project
495                .read(cx)
496                .context_server_store()
497                .read(cx)
498                .all_server_ids()
499            {
500                self.tools.update(cx, |tools, cx| {
501                    tools.enable_source(
502                        ToolSource::ContextServer {
503                            id: context_server_id.0.into(),
504                        },
505                        cx,
506                    );
507                });
508            }
509            // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
510            for (context_server_id, preset) in &profile.context_servers {
511                self.tools.update(cx, |tools, cx| {
512                    tools.disable(
513                        ToolSource::ContextServer {
514                            id: context_server_id.clone().into(),
515                        },
516                        &preset
517                            .tools
518                            .iter()
519                            .filter_map(|(tool, enabled)| (!enabled).then(|| tool.clone()))
520                            .collect::<Vec<_>>(),
521                        cx,
522                    )
523                })
524            }
525        } else {
526            for (context_server_id, preset) in &profile.context_servers {
527                self.tools.update(cx, |tools, cx| {
528                    tools.enable(
529                        ToolSource::ContextServer {
530                            id: context_server_id.clone().into(),
531                        },
532                        &preset
533                            .tools
534                            .iter()
535                            .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
536                            .collect::<Vec<_>>(),
537                        cx,
538                    )
539                })
540            }
541        }
542    }
543
544    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
545        cx.subscribe(
546            &self.project.read(cx).context_server_store(),
547            Self::handle_context_server_event,
548        )
549        .detach();
550    }
551
552    fn handle_context_server_event(
553        &mut self,
554        context_server_store: Entity<ContextServerStore>,
555        event: &project::context_server_store::Event,
556        cx: &mut Context<Self>,
557    ) {
558        let tool_working_set = self.tools.clone();
559        match event {
560            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
561                match status {
562                    ContextServerStatus::Running => {
563                        if let Some(server) =
564                            context_server_store.read(cx).get_running_server(server_id)
565                        {
566                            let context_server_manager = context_server_store.clone();
567                            cx.spawn({
568                                let server = server.clone();
569                                let server_id = server_id.clone();
570                                async move |this, cx| {
571                                    let Some(protocol) = server.client() else {
572                                        return;
573                                    };
574
575                                    if protocol.capable(context_server::protocol::ServerCapability::Tools) {
576                                        if let Some(tools) = protocol.list_tools().await.log_err() {
577                                            let tool_ids = tool_working_set
578                                                .update(cx, |tool_working_set, _| {
579                                                    tools
580                                                        .tools
581                                                        .into_iter()
582                                                        .map(|tool| {
583                                                            log::info!(
584                                                                "registering context server tool: {:?}",
585                                                                tool.name
586                                                            );
587                                                            tool_working_set.insert(Arc::new(
588                                                                ContextServerTool::new(
589                                                                    context_server_manager.clone(),
590                                                                    server.id(),
591                                                                    tool,
592                                                                ),
593                                                            ))
594                                                        })
595                                                        .collect::<Vec<_>>()
596                                                })
597                                                .log_err();
598
599                                            if let Some(tool_ids) = tool_ids {
600                                                this.update(cx, |this, cx| {
601                                                    this.context_server_tool_ids
602                                                        .insert(server_id, tool_ids);
603                                                    this.load_default_profile(cx);
604                                                })
605                                                .log_err();
606                                            }
607                                        }
608                                    }
609                                }
610                            })
611                            .detach();
612                        }
613                    }
614                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
615                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
616                            tool_working_set.update(cx, |tool_working_set, _| {
617                                tool_working_set.remove(&tool_ids);
618                            });
619                            self.load_default_profile(cx);
620                        }
621                    }
622                    _ => {}
623                }
624            }
625        }
626    }
627}
628
/// Summary information about a saved thread, as listed by `ThreadStore`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
}

/// Stored representation of a thread, in the current format ([`SerializedThread::VERSION`]).
///
/// Values are encoded as JSON in the threads database, and decoding goes through
/// `SerializedThread::from_json` so that older formats are upgraded. Each field marked
/// `#[serde(default)]` may be missing from the JSON (presumably data written before the field
/// existed), in which case it takes its `Default` value.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
}

/// The language model a thread was using, stored as two strings.
// NOTE(review): presumably `provider` / `model` hold the provider id and model id; confirm
// against the code that writes them.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
663
664impl SerializedThread {
665    pub const VERSION: &'static str = "0.2.0";
666
667    pub fn from_json(json: &[u8]) -> Result<Self> {
668        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
669        match saved_thread_json.get("version") {
670            Some(serde_json::Value::String(version)) => match version.as_str() {
671                SerializedThreadV0_1_0::VERSION => {
672                    let saved_thread =
673                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
674                    Ok(saved_thread.upgrade())
675                }
676                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
677                    saved_thread_json,
678                )?),
679                _ => Err(anyhow!(
680                    "unrecognized serialized thread version: {}",
681                    version
682                )),
683            },
684            None => {
685                let saved_thread =
686                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
687                Ok(saved_thread.upgrade())
688            }
689            version => Err(anyhow!(
690                "unrecognized serialized thread version: {:?}",
691                version
692            )),
693        }
694    }
695}
696
697#[derive(Serialize, Deserialize, Debug)]
698pub struct SerializedThreadV0_1_0(
699    // The structure did not change, so we are reusing the latest SerializedThread.
700    // When making the next version, make sure this points to SerializedThreadV0_2_0
701    SerializedThread,
702);
703
704impl SerializedThreadV0_1_0 {
705    pub const VERSION: &'static str = "0.1.0";
706
707    pub fn upgrade(self) -> SerializedThread {
708        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
709
710        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
711
712        for message in self.0.messages {
713            if message.role == Role::User && !message.tool_results.is_empty() {
714                if let Some(last_message) = messages.last_mut() {
715                    debug_assert!(last_message.role == Role::Assistant);
716
717                    last_message.tool_results = message.tool_results;
718                    continue;
719                }
720            }
721
722            messages.push(message);
723        }
724
725        SerializedThread { messages, ..self.0 }
726    }
727}
728
/// A single message within a [`SerializedThread`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    // Each `#[serde(default)]` field may be absent from the stored JSON, in which case it
    // deserializes to its `Default` value (an empty collection or string).
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
}

/// One piece of a message's content. It is internally tagged in JSON by a `"type"` field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        // Omitted from the serialized JSON when there is no signature.
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // NOTE(review): unlike the other variants this one has no `rename`, so its `"type"` tag is
    // the variant name `"RedactedThinking"`. Renaming it now would break reading threads that
    // are already saved, unless an `alias` is added as well.
    RedactedThinking {
        data: Vec<u8>,
    },
}

/// A tool invocation recorded on a message, with its JSON input.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}

/// The outcome of a tool invocation.
// NOTE(review): presumably `tool_use_id` matches the `id` of a `SerializedToolUse` (same id
// type); confirm.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: Arc<str>,
}
776
777#[derive(Serialize, Deserialize)]
778struct LegacySerializedThread {
779    pub summary: SharedString,
780    pub updated_at: DateTime<Utc>,
781    pub messages: Vec<LegacySerializedMessage>,
782    #[serde(default)]
783    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
784}
785
786impl LegacySerializedThread {
787    pub fn upgrade(self) -> SerializedThread {
788        SerializedThread {
789            version: SerializedThread::VERSION.to_string(),
790            summary: self.summary,
791            updated_at: self.updated_at,
792            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
793            initial_project_snapshot: self.initial_project_snapshot,
794            cumulative_token_usage: TokenUsage::default(),
795            request_token_usage: Vec::new(),
796            detailed_summary_state: DetailedSummaryState::default(),
797            exceeded_window_error: None,
798            model: None,
799            completion_mode: None,
800        }
801    }
802}
803
/// Message format used by [`LegacySerializedThread`]: a single plain-text body instead
/// of a list of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// The whole message text; becomes a single text segment when upgraded.
    pub text: String,
    // Both tool lists fall back to empty when absent from stored data.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
814
815impl LegacySerializedMessage {
816    fn upgrade(self) -> SerializedMessage {
817        SerializedMessage {
818            id: self.id,
819            role: self.role,
820            segments: vec![SerializedMessageSegment::Text { text: self.text }],
821            tool_uses: self.tool_uses,
822            tool_results: self.tool_results,
823            context: String::new(),
824            creases: Vec::new(),
825        }
826    }
827}
828
/// A persisted crease: a labelled range with an associated icon, stored on a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedCrease {
    /// Start of the range. NOTE(review): presumably an offset into the message text —
    /// confirm the units against the code that creates creases.
    pub start: usize,
    /// End of the range, in the same units as `start`.
    pub end: usize,
    /// Icon path associated with the crease.
    pub icon_path: SharedString,
    /// Label associated with the crease.
    pub label: SharedString,
}
836
/// Global holding the shared future that opens the threads database.
///
/// Both the database and the error are wrapped in `Arc` so the future's output is
/// `Clone`, which [`Shared`] requires.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
842
/// heed (LMDB) backed store of serialized threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every database read and write is spawned.
    executor: BackgroundExecutor,
    /// The opened heed environment.
    env: heed::Env,
    /// Threads keyed by bincode-encoded [`ThreadId`]. Values are encoded and decoded by
    /// the `BytesEncode`/`BytesDecode` impls on [`SerializedThread`] (JSON).
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
848
849impl heed::BytesEncode<'_> for SerializedThread {
850    type EItem = SerializedThread;
851
852    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
853        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
854    }
855}
856
857impl<'a> heed::BytesDecode<'a> for SerializedThread {
858    type DItem = SerializedThread;
859
860    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
861        // We implement this type manually because we want to call `SerializedThread::from_json`,
862        // instead of the Deserialize trait implementation for `SerializedThread`.
863        SerializedThread::from_json(bytes).map_err(Into::into)
864    }
865}
866
867impl ThreadsDatabase {
868    fn global_future(
869        cx: &mut App,
870    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
871        GlobalThreadsDatabase::global(cx).0.clone()
872    }
873
874    fn init(cx: &mut App) {
875        let executor = cx.background_executor().clone();
876        let database_future = executor
877            .spawn({
878                let executor = executor.clone();
879                let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
880                async move { ThreadsDatabase::new(database_path, executor) }
881            })
882            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
883            .boxed()
884            .shared();
885
886        cx.set_global(GlobalThreadsDatabase(database_future));
887    }
888
889    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
890        std::fs::create_dir_all(&path)?;
891
892        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
893        let env = unsafe {
894            heed::EnvOpenOptions::new()
895                .map_size(ONE_GB_IN_BYTES)
896                .max_dbs(1)
897                .open(path)?
898        };
899
900        let mut txn = env.write_txn()?;
901        let threads = env.create_database(&mut txn, Some("threads"))?;
902        txn.commit()?;
903
904        Ok(Self {
905            executor,
906            env,
907            threads,
908        })
909    }
910
911    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
912        let env = self.env.clone();
913        let threads = self.threads;
914
915        self.executor.spawn(async move {
916            let txn = env.read_txn()?;
917            let mut iter = threads.iter(&txn)?;
918            let mut threads = Vec::new();
919            while let Some((key, value)) = iter.next().transpose()? {
920                threads.push(SerializedThreadMetadata {
921                    id: key,
922                    summary: value.summary,
923                    updated_at: value.updated_at,
924                });
925            }
926
927            Ok(threads)
928        })
929    }
930
931    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
932        let env = self.env.clone();
933        let threads = self.threads;
934
935        self.executor.spawn(async move {
936            let txn = env.read_txn()?;
937            let thread = threads.get(&txn, &id)?;
938            Ok(thread)
939        })
940    }
941
942    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
943        let env = self.env.clone();
944        let threads = self.threads;
945
946        self.executor.spawn(async move {
947            let mut txn = env.write_txn()?;
948            threads.put(&mut txn, &id, &thread)?;
949            txn.commit()?;
950            Ok(())
951        })
952    }
953
954    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
955        let env = self.env.clone();
956        let threads = self.threads;
957
958        self.executor.spawn(async move {
959            let mut txn = env.write_txn()?;
960            threads.delete(&mut txn, &id)?;
961            txn.commit()?;
962            Ok(())
963        })
964    }
965}