thread_store.rs

  1use std::borrow::Cow;
  2use std::cell::{Ref, RefCell};
  3use std::path::{Path, PathBuf};
  4use std::rc::Rc;
  5use std::sync::Arc;
  6
  7use anyhow::{Context as _, Result, anyhow};
  8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
  9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
 10use chrono::{DateTime, Utc};
 11use collections::HashMap;
 12use context_server::manager::{ContextServerManager, ContextServerStatus};
 13use context_server::{ContextServerDescriptorRegistry, ContextServerTool};
 14use futures::channel::{mpsc, oneshot};
 15use futures::future::{self, BoxFuture, Shared};
 16use futures::{FutureExt as _, StreamExt as _};
 17use gpui::{
 18    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
 19    Subscription, Task, prelude::*,
 20};
 21use heed::Database;
 22use heed::types::SerdeBincode;
 23use language_model::{LanguageModelToolUseId, Role, TokenUsage};
 24use project::{Project, ProjectItem, ProjectPath, Worktree};
 25use prompt_store::{
 26    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
 27    UserRulesContext, WorktreeContext,
 28};
 29use serde::{Deserialize, Serialize};
 30use settings::{Settings as _, SettingsStore};
 31use util::ResultExt as _;
 32
 33use crate::thread::{
 34    DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
 35};
 36
 37const RULES_FILE_NAMES: [&'static str; 6] = [
 38    ".rules",
 39    ".cursorrules",
 40    ".windsurfrules",
 41    ".clinerules",
 42    ".github/copilot-instructions.md",
 43    "CLAUDE.md",
 44];
 45
 46pub fn init(cx: &mut App) {
 47    ThreadsDatabase::init(cx);
 48}
 49
 50/// A system prompt shared by all threads created by this ThreadStore
 51#[derive(Clone, Default)]
 52pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
 53
 54impl SharedProjectContext {
 55    pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
 56        self.0.borrow()
 57    }
 58}
 59
/// Owns the metadata list of saved threads and the shared state (tools, prompt builder,
/// project context, context servers) used to create and restore [`Thread`]s for a project.
pub struct ThreadStore {
    project: Entity<Project>,
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    prompt_store: Option<Entity<PromptStore>>,
    context_server_manager: Entity<ContextServerManager>,
    /// Tool IDs registered for each running context server, keyed by server ID, so they
    /// can be unregistered when the server is removed.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Metadata for every saved thread; replaced wholesale by `reload`.
    threads: Vec<SerializedThreadMetadata>,
    /// System-prompt context handed to every thread this store creates or opens.
    project_context: SharedProjectContext,
    /// Signals `_reload_system_prompt_task` to rebuild `project_context`.
    reload_system_prompt_tx: mpsc::Sender<()>,
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}
 73
/// Event emitted by [`ThreadStore`] when a worktree rules file or a default user-rules
/// prompt fails to load while the system prompt is being rebuilt.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}
 79
 80impl ThreadStore {
 81    pub fn load(
 82        project: Entity<Project>,
 83        tools: Entity<ToolWorkingSet>,
 84        prompt_store: Option<Entity<PromptStore>>,
 85        prompt_builder: Arc<PromptBuilder>,
 86        cx: &mut App,
 87    ) -> Task<Result<Entity<Self>>> {
 88        cx.spawn(async move |cx| {
 89            let (thread_store, ready_rx) = cx.update(|cx| {
 90                let mut option_ready_rx = None;
 91                let thread_store = cx.new(|cx| {
 92                    let (thread_store, ready_rx) =
 93                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 94                    option_ready_rx = Some(ready_rx);
 95                    thread_store
 96                });
 97                (thread_store, option_ready_rx.take().unwrap())
 98            })?;
 99            ready_rx.await?;
100            Ok(thread_store)
101        })
102    }
103
104    fn new(
105        project: Entity<Project>,
106        tools: Entity<ToolWorkingSet>,
107        prompt_builder: Arc<PromptBuilder>,
108        prompt_store: Option<Entity<PromptStore>>,
109        cx: &mut Context<Self>,
110    ) -> (Self, oneshot::Receiver<()>) {
111        let context_server_factory_registry = ContextServerDescriptorRegistry::default_global(cx);
112        let context_server_manager = cx.new(|cx| {
113            ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
114        });
115
116        let mut subscriptions = vec![
117            cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
118                this.load_default_profile(cx);
119            }),
120            cx.subscribe(&project, Self::handle_project_event),
121        ];
122
123        if let Some(prompt_store) = prompt_store.as_ref() {
124            subscriptions.push(cx.subscribe(
125                prompt_store,
126                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
127                    this.enqueue_system_prompt_reload();
128                },
129            ))
130        }
131
132        // This channel and task prevent concurrent and redundant loading of the system prompt.
133        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
134        let (ready_tx, ready_rx) = oneshot::channel();
135        let mut ready_tx = Some(ready_tx);
136        let reload_system_prompt_task = cx.spawn({
137            let prompt_store = prompt_store.clone();
138            async move |thread_store, cx| {
139                loop {
140                    let Some(reload_task) = thread_store
141                        .update(cx, |thread_store, cx| {
142                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
143                        })
144                        .ok()
145                    else {
146                        return;
147                    };
148                    reload_task.await;
149                    if let Some(ready_tx) = ready_tx.take() {
150                        ready_tx.send(()).ok();
151                    }
152                    reload_system_prompt_rx.next().await;
153                }
154            }
155        });
156
157        let this = Self {
158            project,
159            tools,
160            prompt_builder,
161            prompt_store,
162            context_server_manager,
163            context_server_tool_ids: HashMap::default(),
164            threads: Vec::new(),
165            project_context: SharedProjectContext::default(),
166            reload_system_prompt_tx,
167            _reload_system_prompt_task: reload_system_prompt_task,
168            _subscriptions: subscriptions,
169        };
170        this.load_default_profile(cx);
171        this.register_context_server_handlers(cx);
172        this.reload(cx).detach_and_log_err(cx);
173        (this, ready_rx)
174    }
175
176    fn handle_project_event(
177        &mut self,
178        _project: Entity<Project>,
179        event: &project::Event,
180        _cx: &mut Context<Self>,
181    ) {
182        match event {
183            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
184                self.enqueue_system_prompt_reload();
185            }
186            project::Event::WorktreeUpdatedEntries(_, items) => {
187                if items.iter().any(|(path, _, _)| {
188                    RULES_FILE_NAMES
189                        .iter()
190                        .any(|name| path.as_ref() == Path::new(name))
191                }) {
192                    self.enqueue_system_prompt_reload();
193                }
194            }
195            _ => {}
196        }
197    }
198
199    fn enqueue_system_prompt_reload(&mut self) {
200        self.reload_system_prompt_tx.try_send(()).ok();
201    }
202
203    // Note that this should only be called from `reload_system_prompt_task`.
204    fn reload_system_prompt(
205        &self,
206        prompt_store: Option<Entity<PromptStore>>,
207        cx: &mut Context<Self>,
208    ) -> Task<()> {
209        let worktrees = self
210            .project
211            .read(cx)
212            .visible_worktrees(cx)
213            .collect::<Vec<_>>();
214        let worktree_tasks = worktrees
215            .into_iter()
216            .map(|worktree| {
217                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
218            })
219            .collect::<Vec<_>>();
220        let default_user_rules_task = match prompt_store {
221            None => Task::ready(vec![]),
222            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
223                let prompts = prompt_store.default_prompt_metadata();
224                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
225                    let contents = prompt_store.load(prompt_metadata.id, cx);
226                    async move { (contents.await, prompt_metadata) }
227                });
228                cx.background_spawn(future::join_all(load_tasks))
229            }),
230        };
231
232        cx.spawn(async move |this, cx| {
233            let (worktrees, default_user_rules) =
234                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
235
236            let worktrees = worktrees
237                .into_iter()
238                .map(|(worktree, rules_error)| {
239                    if let Some(rules_error) = rules_error {
240                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
241                    }
242                    worktree
243                })
244                .collect::<Vec<_>>();
245
246            let default_user_rules = default_user_rules
247                .into_iter()
248                .flat_map(|(contents, prompt_metadata)| match contents {
249                    Ok(contents) => Some(UserRulesContext {
250                        uuid: match prompt_metadata.id {
251                            PromptId::User { uuid } => uuid,
252                            PromptId::EditWorkflow => return None,
253                        },
254                        title: prompt_metadata.title.map(|title| title.to_string()),
255                        contents,
256                    }),
257                    Err(err) => {
258                        this.update(cx, |_, cx| {
259                            cx.emit(RulesLoadingError {
260                                message: format!("{err:?}").into(),
261                            });
262                        })
263                        .ok();
264                        None
265                    }
266                })
267                .collect::<Vec<_>>();
268
269            this.update(cx, |this, _cx| {
270                *this.project_context.0.borrow_mut() =
271                    Some(ProjectContext::new(worktrees, default_user_rules));
272            })
273            .ok();
274        })
275    }
276
277    fn load_worktree_info_for_system_prompt(
278        worktree: Entity<Worktree>,
279        project: Entity<Project>,
280        cx: &mut App,
281    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
282        let root_name = worktree.read(cx).root_name().into();
283
284        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
285        let Some(rules_task) = rules_task else {
286            return Task::ready((
287                WorktreeContext {
288                    root_name,
289                    rules_file: None,
290                },
291                None,
292            ));
293        };
294
295        cx.spawn(async move |_| {
296            let (rules_file, rules_file_error) = match rules_task.await {
297                Ok(rules_file) => (Some(rules_file), None),
298                Err(err) => (
299                    None,
300                    Some(RulesLoadingError {
301                        message: format!("{err}").into(),
302                    }),
303                ),
304            };
305            let worktree_info = WorktreeContext {
306                root_name,
307                rules_file,
308            };
309            (worktree_info, rules_file_error)
310        })
311    }
312
313    fn load_worktree_rules_file(
314        worktree: Entity<Worktree>,
315        project: Entity<Project>,
316        cx: &mut App,
317    ) -> Option<Task<Result<RulesFileContext>>> {
318        let worktree_ref = worktree.read(cx);
319        let worktree_id = worktree_ref.id();
320        let selected_rules_file = RULES_FILE_NAMES
321            .into_iter()
322            .filter_map(|name| {
323                worktree_ref
324                    .entry_for_path(name)
325                    .filter(|entry| entry.is_file())
326                    .map(|entry| entry.path.clone())
327            })
328            .next();
329
330        // Note that Cline supports `.clinerules` being a directory, but that is not currently
331        // supported. This doesn't seem to occur often in GitHub repositories.
332        selected_rules_file.map(|path_in_worktree| {
333            let project_path = ProjectPath {
334                worktree_id,
335                path: path_in_worktree.clone(),
336            };
337            let buffer_task =
338                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
339            let rope_task = cx.spawn(async move |cx| {
340                buffer_task.await?.read_with(cx, |buffer, cx| {
341                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
342                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
343                })?
344            });
345            // Build a string from the rope on a background thread.
346            cx.background_spawn(async move {
347                let (project_entry_id, rope) = rope_task.await?;
348                anyhow::Ok(RulesFileContext {
349                    path_in_worktree,
350                    text: rope.to_string().trim().to_string(),
351                    project_entry_id: project_entry_id.to_usize(),
352                })
353            })
354        })
355    }
356
357    pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
358        self.context_server_manager.clone()
359    }
360
361    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
362        &self.prompt_store
363    }
364
365    pub fn tools(&self) -> Entity<ToolWorkingSet> {
366        self.tools.clone()
367    }
368
369    /// Returns the number of threads.
370    pub fn thread_count(&self) -> usize {
371        self.threads.len()
372    }
373
374    pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
375        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
376        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
377        threads
378    }
379
380    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
381        cx.new(|cx| {
382            Thread::new(
383                self.project.clone(),
384                self.tools.clone(),
385                self.prompt_builder.clone(),
386                self.project_context.clone(),
387                cx,
388            )
389        })
390    }
391
392    pub fn open_thread(
393        &self,
394        id: &ThreadId,
395        cx: &mut Context<Self>,
396    ) -> Task<Result<Entity<Thread>>> {
397        let id = id.clone();
398        let database_future = ThreadsDatabase::global_future(cx);
399        cx.spawn(async move |this, cx| {
400            let database = database_future.await.map_err(|err| anyhow!(err))?;
401            let thread = database
402                .try_find_thread(id.clone())
403                .await?
404                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
405
406            let thread = this.update(cx, |this, cx| {
407                cx.new(|cx| {
408                    Thread::deserialize(
409                        id.clone(),
410                        thread,
411                        this.project.clone(),
412                        this.tools.clone(),
413                        this.prompt_builder.clone(),
414                        this.project_context.clone(),
415                        cx,
416                    )
417                })
418            })?;
419
420            Ok(thread)
421        })
422    }
423
424    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
425        let (metadata, serialized_thread) =
426            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
427
428        let database_future = ThreadsDatabase::global_future(cx);
429        cx.spawn(async move |this, cx| {
430            let serialized_thread = serialized_thread.await?;
431            let database = database_future.await.map_err(|err| anyhow!(err))?;
432            database.save_thread(metadata, serialized_thread).await?;
433
434            this.update(cx, |this, cx| this.reload(cx))?.await
435        })
436    }
437
438    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
439        let id = id.clone();
440        let database_future = ThreadsDatabase::global_future(cx);
441        cx.spawn(async move |this, cx| {
442            let database = database_future.await.map_err(|err| anyhow!(err))?;
443            database.delete_thread(id.clone()).await?;
444
445            this.update(cx, |this, cx| {
446                this.threads.retain(|thread| thread.id != id);
447                cx.notify();
448            })
449        })
450    }
451
452    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
453        let database_future = ThreadsDatabase::global_future(cx);
454        cx.spawn(async move |this, cx| {
455            let threads = database_future
456                .await
457                .map_err(|err| anyhow!(err))?
458                .list_threads()
459                .await?;
460
461            this.update(cx, |this, cx| {
462                this.threads = threads;
463                cx.notify();
464            })
465        })
466    }
467
468    fn load_default_profile(&self, cx: &mut Context<Self>) {
469        let assistant_settings = AssistantSettings::get_global(cx);
470
471        self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
472    }
473
474    pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
475        let assistant_settings = AssistantSettings::get_global(cx);
476
477        if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
478            self.load_profile(profile.clone(), cx);
479        }
480    }
481
482    pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
483        self.tools.update(cx, |tools, cx| {
484            tools.disable_all_tools(cx);
485            tools.enable(
486                ToolSource::Native,
487                &profile
488                    .tools
489                    .iter()
490                    .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
491                    .collect::<Vec<_>>(),
492                cx,
493            );
494        });
495
496        if profile.enable_all_context_servers {
497            for context_server in self.context_server_manager.read(cx).all_servers() {
498                self.tools.update(cx, |tools, cx| {
499                    tools.enable_source(
500                        ToolSource::ContextServer {
501                            id: context_server.id().into(),
502                        },
503                        cx,
504                    );
505                });
506            }
507            // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
508            for (context_server_id, preset) in &profile.context_servers {
509                self.tools.update(cx, |tools, cx| {
510                    tools.disable(
511                        ToolSource::ContextServer {
512                            id: context_server_id.clone().into(),
513                        },
514                        &preset
515                            .tools
516                            .iter()
517                            .filter_map(|(tool, enabled)| (!enabled).then(|| tool.clone()))
518                            .collect::<Vec<_>>(),
519                        cx,
520                    )
521                })
522            }
523        } else {
524            for (context_server_id, preset) in &profile.context_servers {
525                self.tools.update(cx, |tools, cx| {
526                    tools.enable(
527                        ToolSource::ContextServer {
528                            id: context_server_id.clone().into(),
529                        },
530                        &preset
531                            .tools
532                            .iter()
533                            .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
534                            .collect::<Vec<_>>(),
535                        cx,
536                    )
537                })
538            }
539        }
540    }
541
542    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
543        cx.subscribe(
544            &self.context_server_manager.clone(),
545            Self::handle_context_server_event,
546        )
547        .detach();
548    }
549
550    fn handle_context_server_event(
551        &mut self,
552        context_server_manager: Entity<ContextServerManager>,
553        event: &context_server::manager::Event,
554        cx: &mut Context<Self>,
555    ) {
556        let tool_working_set = self.tools.clone();
557        match event {
558            context_server::manager::Event::ServerStatusChanged { server_id, status } => {
559                match status {
560                    Some(ContextServerStatus::Running) => {
561                        if let Some(server) = context_server_manager.read(cx).get_server(server_id)
562                        {
563                            let context_server_manager = context_server_manager.clone();
564                            cx.spawn({
565                                let server = server.clone();
566                                let server_id = server_id.clone();
567                                async move |this, cx| {
568                                    let Some(protocol) = server.client() else {
569                                        return;
570                                    };
571
572                                    if protocol.capable(context_server::protocol::ServerCapability::Tools) {
573                                        if let Some(tools) = protocol.list_tools().await.log_err() {
574                                            let tool_ids = tool_working_set
575                                                .update(cx, |tool_working_set, _| {
576                                                    tools
577                                                        .tools
578                                                        .into_iter()
579                                                        .map(|tool| {
580                                                            log::info!(
581                                                                "registering context server tool: {:?}",
582                                                                tool.name
583                                                            );
584                                                            tool_working_set.insert(Arc::new(
585                                                                ContextServerTool::new(
586                                                                    context_server_manager.clone(),
587                                                                    server.id(),
588                                                                    tool,
589                                                                ),
590                                                            ))
591                                                        })
592                                                        .collect::<Vec<_>>()
593                                                })
594                                                .log_err();
595
596                                            if let Some(tool_ids) = tool_ids {
597                                                this.update(cx, |this, cx| {
598                                                    this.context_server_tool_ids
599                                                        .insert(server_id, tool_ids);
600                                                    this.load_default_profile(cx);
601                                                })
602                                                .log_err();
603                                            }
604                                        }
605                                    }
606                                }
607                            })
608                            .detach();
609                        }
610                    }
611                    None => {
612                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
613                            tool_working_set.update(cx, |tool_working_set, _| {
614                                tool_working_set.remove(&tool_ids);
615                            });
616                            self.load_default_profile(cx);
617                        }
618                    }
619                    _ => {}
620                }
621            }
622        }
623    }
624}
625
/// Summary of a saved thread, as held in `ThreadStore::threads` (populated from
/// `list_threads`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    /// Last-modified time; `reverse_chronological_threads` sorts on this.
    pub updated_at: DateTime<Utc>,
}
632
/// Persisted form of a thread in the current ([`SerializedThread::VERSION`]) layout.
///
/// Stored JSON should be parsed with `SerializedThread::from_json`, which upgrades older
/// layouts. Fields marked `#[serde(default)]` take their `Default` value when absent.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    /// Layout version string, e.g. `SerializedThread::VERSION`.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
}
652
/// Identifies a language model by provider and model name, as recorded on a thread.
/// NOTE(review): presumably the model last used by the thread — confirm with `Thread::serialize`.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
658
659impl SerializedThread {
660    pub const VERSION: &'static str = "0.2.0";
661
662    pub fn from_json(json: &[u8]) -> Result<Self> {
663        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
664        match saved_thread_json.get("version") {
665            Some(serde_json::Value::String(version)) => match version.as_str() {
666                SerializedThreadV0_1_0::VERSION => {
667                    let saved_thread =
668                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
669                    Ok(saved_thread.upgrade())
670                }
671                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
672                    saved_thread_json,
673                )?),
674                _ => Err(anyhow!(
675                    "unrecognized serialized thread version: {}",
676                    version
677                )),
678            },
679            None => {
680                let saved_thread =
681                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
682                Ok(saved_thread.upgrade())
683            }
684            version => Err(anyhow!(
685                "unrecognized serialized thread version: {:?}",
686                version
687            )),
688        }
689    }
690}
691
692#[derive(Serialize, Deserialize, Debug)]
693pub struct SerializedThreadV0_1_0(
694    // The structure did not change, so we are reusing the latest SerializedThread.
695    // When making the next version, make sure this points to SerializedThreadV0_2_0
696    SerializedThread,
697);
698
699impl SerializedThreadV0_1_0 {
700    pub const VERSION: &'static str = "0.1.0";
701
702    pub fn upgrade(self) -> SerializedThread {
703        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
704
705        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
706
707        for message in self.0.messages {
708            if message.role == Role::User && !message.tool_results.is_empty() {
709                if let Some(last_message) = messages.last_mut() {
710                    debug_assert!(last_message.role == Role::Assistant);
711
712                    last_message.tool_results = message.tool_results;
713                    continue;
714                }
715            }
716
717            messages.push(message);
718        }
719
720        SerializedThread { messages, ..self.0 }
721    }
722}
723
/// Persisted form of a single message within a [`SerializedThread`].
///
/// All fields except `id` and `role` default to empty when absent from the stored JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
}
739
/// One piece of a message's content, internally tagged by a `"type"` field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Tagged `"type": "text"`.
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    /// Tagged `"type": "thinking"`; `signature` is omitted from output when `None`.
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    /// Tagged `"type": "RedactedThinking"` (no rename, unlike the other variants).
    /// NOTE(review): renaming the tag would presumably break reading threads already on
    /// disk, so leave as-is unless an alias is added.
    RedactedThinking {
        data: Vec<u8>,
    },
}
757
/// A tool invocation recorded on a message: which tool, and the JSON input it was given.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
764
/// The outcome of a tool use, linked back to it by `tool_use_id`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: Arc<str>,
}
771
/// Thread layout from before the `version` field existed. `SerializedThread::from_json`
/// selects it when the stored JSON has no `"version"` key.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
780
781impl LegacySerializedThread {
782    pub fn upgrade(self) -> SerializedThread {
783        SerializedThread {
784            version: SerializedThread::VERSION.to_string(),
785            summary: self.summary,
786            updated_at: self.updated_at,
787            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
788            initial_project_snapshot: self.initial_project_snapshot,
789            cumulative_token_usage: TokenUsage::default(),
790            request_token_usage: Vec::new(),
791            detailed_summary_state: DetailedSummaryState::default(),
792            exceeded_window_error: None,
793            model: None,
794        }
795    }
796}
797
/// Message layout used by [`LegacySerializedThread`]: a single plain-text body instead of
/// segments, with no context or creases.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
808
809impl LegacySerializedMessage {
810    fn upgrade(self) -> SerializedMessage {
811        SerializedMessage {
812            id: self.id,
813            role: self.role,
814            segments: vec![SerializedMessageSegment::Text { text: self.text }],
815            tool_uses: self.tool_uses,
816            tool_results: self.tool_results,
817            context: String::new(),
818            creases: Vec::new(),
819        }
820    }
821}
822
/// A serialized crease attached to a message.
///
/// NOTE(review): presumably this describes a range of the message text shown as a
/// placeholder with `icon_path` and `label`. Confirm against the code that creates creases.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedCrease {
    /// Start of the creased range (units/offset basis not visible here; TODO confirm).
    pub start: usize,
    /// End of the creased range (same basis as `start`).
    pub end: usize,
    /// Path of the icon to display for the crease.
    pub icon_path: SharedString,
    /// Label to display for the crease.
    pub label: SharedString,
}
830
/// gpui global holding the shared future that opens the [`ThreadsDatabase`].
///
/// It is set once by `ThreadsDatabase::init`, and every reader clones the same `Shared`
/// future. The error side is wrapped in `Arc` because a `Shared` future's output must be
/// `Clone`, and `anyhow::Error` is not.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
836
/// Persistent store of serialized threads, backed by an LMDB environment opened through `heed`.
pub(crate) struct ThreadsDatabase {
    /// Executor on which each database operation is spawned.
    executor: BackgroundExecutor,
    /// The open LMDB environment. It is cloned into each spawned operation.
    env: heed::Env,
    /// The `threads` database. Keys are bincode-encoded `ThreadId`s. Values are
    /// `SerializedThread`s encoded through that type's own `heed` codec impls (JSON).
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
842
843impl heed::BytesEncode<'_> for SerializedThread {
844    type EItem = SerializedThread;
845
846    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
847        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
848    }
849}
850
851impl<'a> heed::BytesDecode<'a> for SerializedThread {
852    type DItem = SerializedThread;
853
854    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
855        // We implement this type manually because we want to call `SerializedThread::from_json`,
856        // instead of the Deserialize trait implementation for `SerializedThread`.
857        SerializedThread::from_json(bytes).map_err(Into::into)
858    }
859}
860
861impl ThreadsDatabase {
862    fn global_future(
863        cx: &mut App,
864    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
865        GlobalThreadsDatabase::global(cx).0.clone()
866    }
867
868    fn init(cx: &mut App) {
869        let executor = cx.background_executor().clone();
870        let database_future = executor
871            .spawn({
872                let executor = executor.clone();
873                let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
874                async move { ThreadsDatabase::new(database_path, executor) }
875            })
876            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
877            .boxed()
878            .shared();
879
880        cx.set_global(GlobalThreadsDatabase(database_future));
881    }
882
883    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
884        std::fs::create_dir_all(&path)?;
885
886        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
887        let env = unsafe {
888            heed::EnvOpenOptions::new()
889                .map_size(ONE_GB_IN_BYTES)
890                .max_dbs(1)
891                .open(path)?
892        };
893
894        let mut txn = env.write_txn()?;
895        let threads = env.create_database(&mut txn, Some("threads"))?;
896        txn.commit()?;
897
898        Ok(Self {
899            executor,
900            env,
901            threads,
902        })
903    }
904
905    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
906        let env = self.env.clone();
907        let threads = self.threads;
908
909        self.executor.spawn(async move {
910            let txn = env.read_txn()?;
911            let mut iter = threads.iter(&txn)?;
912            let mut threads = Vec::new();
913            while let Some((key, value)) = iter.next().transpose()? {
914                threads.push(SerializedThreadMetadata {
915                    id: key,
916                    summary: value.summary,
917                    updated_at: value.updated_at,
918                });
919            }
920
921            Ok(threads)
922        })
923    }
924
925    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
926        let env = self.env.clone();
927        let threads = self.threads;
928
929        self.executor.spawn(async move {
930            let txn = env.read_txn()?;
931            let thread = threads.get(&txn, &id)?;
932            Ok(thread)
933        })
934    }
935
936    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
937        let env = self.env.clone();
938        let threads = self.threads;
939
940        self.executor.spawn(async move {
941            let mut txn = env.write_txn()?;
942            threads.put(&mut txn, &id, &thread)?;
943            txn.commit()?;
944            Ok(())
945        })
946    }
947
948    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
949        let env = self.env.clone();
950        let threads = self.threads;
951
952        self.executor.spawn(async move {
953            let mut txn = env.write_txn()?;
954            threads.delete(&mut txn, &id)?;
955            txn.commit()?;
956            Ok(())
957        })
958    }
959}