thread_store.rs

  1use std::borrow::Cow;
  2use std::cell::{Ref, RefCell};
  3use std::path::{Path, PathBuf};
  4use std::rc::Rc;
  5use std::sync::Arc;
  6
  7use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, CompletionMode};
  8use anyhow::{Context as _, Result, anyhow};
  9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
 10use chrono::{DateTime, Utc};
 11use collections::HashMap;
 12use context_server::ContextServerId;
 13use futures::channel::{mpsc, oneshot};
 14use futures::future::{self, BoxFuture, Shared};
 15use futures::{FutureExt as _, StreamExt as _};
 16use gpui::{
 17    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
 18    Subscription, Task, prelude::*,
 19};
 20use heed::Database;
 21use heed::types::SerdeBincode;
 22use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
 23use project::context_server_store::{ContextServerStatus, ContextServerStore};
 24use project::{Project, ProjectItem, ProjectPath, Worktree};
 25use prompt_store::{
 26    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
 27    UserRulesContext, WorktreeContext,
 28};
 29use serde::{Deserialize, Serialize};
 30use settings::{Settings as _, SettingsStore};
 31use ui::Window;
 32use util::ResultExt as _;
 33
 34use crate::context_server_tool::ContextServerTool;
 35use crate::thread::{
 36    DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
 37};
 38
 39const RULES_FILE_NAMES: [&'static str; 6] = [
 40    ".rules",
 41    ".cursorrules",
 42    ".windsurfrules",
 43    ".clinerules",
 44    ".github/copilot-instructions.md",
 45    "CLAUDE.md",
 46];
 47
/// Module initialization hook: forwards to [`ThreadsDatabase::init`].
pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}
 51
 52/// A system prompt shared by all threads created by this ThreadStore
 53#[derive(Clone, Default)]
 54pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
 55
 56impl SharedProjectContext {
 57    pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
 58        self.0.borrow()
 59    }
 60}
 61
/// Alias re-exporting `assistant_context_editor::ContextStore` under the name used alongside
/// [`ThreadStore`].
pub type TextThreadStore = assistant_context_editor::ContextStore;
 63
/// Owns the list of saved agent threads and the state shared by the threads it creates:
/// the tool working set, the prompt builder, and the system-prompt project context.
pub struct ThreadStore {
    /// Project whose worktrees and context servers feed the system prompt and tool set.
    project: Entity<Project>,
    /// Tool working set that profiles enable/disable and context-server tools are added to.
    tools: Entity<ToolWorkingSet>,
    /// Passed to every thread this store creates or deserializes.
    prompt_builder: Arc<PromptBuilder>,
    /// Source of default user rules; its update events trigger a system-prompt reload.
    prompt_store: Option<Entity<PromptStore>>,
    /// Tools registered for each running context server, so they can be removed when the
    /// server stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Metadata of the threads currently stored in the database, refreshed by `reload`.
    threads: Vec<SerializedThreadMetadata>,
    /// System-prompt context shared with every thread created by this store.
    project_context: SharedProjectContext,
    /// Sending on this wakes the system-prompt reload loop.
    reload_system_prompt_tx: mpsc::Sender<()>,
    /// The reload loop itself; held so it lives as long as the store (presumably dropping a gpui
    /// `Task` cancels it — confirm).
    _reload_system_prompt_task: Task<()>,
    /// Settings, project and prompt-store subscriptions kept alive for the store's lifetime.
    _subscriptions: Vec<Subscription>,
}
 76
/// Event emitted by [`ThreadStore`] when a rules source for the system prompt (a worktree rules
/// file or a default user-rules prompt) fails to load.
pub struct RulesLoadingError {
    /// Description of the failure, formatted from the underlying error.
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}
 82
 83impl ThreadStore {
 84    pub fn load(
 85        project: Entity<Project>,
 86        tools: Entity<ToolWorkingSet>,
 87        prompt_store: Option<Entity<PromptStore>>,
 88        prompt_builder: Arc<PromptBuilder>,
 89        cx: &mut App,
 90    ) -> Task<Result<Entity<Self>>> {
 91        cx.spawn(async move |cx| {
 92            let (thread_store, ready_rx) = cx.update(|cx| {
 93                let mut option_ready_rx = None;
 94                let thread_store = cx.new(|cx| {
 95                    let (thread_store, ready_rx) =
 96                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 97                    option_ready_rx = Some(ready_rx);
 98                    thread_store
 99                });
100                (thread_store, option_ready_rx.take().unwrap())
101            })?;
102            ready_rx.await?;
103            Ok(thread_store)
104        })
105    }
106
107    fn new(
108        project: Entity<Project>,
109        tools: Entity<ToolWorkingSet>,
110        prompt_builder: Arc<PromptBuilder>,
111        prompt_store: Option<Entity<PromptStore>>,
112        cx: &mut Context<Self>,
113    ) -> (Self, oneshot::Receiver<()>) {
114        let mut subscriptions = vec![
115            cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
116                this.load_default_profile(cx);
117            }),
118            cx.subscribe(&project, Self::handle_project_event),
119        ];
120
121        if let Some(prompt_store) = prompt_store.as_ref() {
122            subscriptions.push(cx.subscribe(
123                prompt_store,
124                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
125                    this.enqueue_system_prompt_reload();
126                },
127            ))
128        }
129
130        // This channel and task prevent concurrent and redundant loading of the system prompt.
131        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
132        let (ready_tx, ready_rx) = oneshot::channel();
133        let mut ready_tx = Some(ready_tx);
134        let reload_system_prompt_task = cx.spawn({
135            let prompt_store = prompt_store.clone();
136            async move |thread_store, cx| {
137                loop {
138                    let Some(reload_task) = thread_store
139                        .update(cx, |thread_store, cx| {
140                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
141                        })
142                        .ok()
143                    else {
144                        return;
145                    };
146                    reload_task.await;
147                    if let Some(ready_tx) = ready_tx.take() {
148                        ready_tx.send(()).ok();
149                    }
150                    reload_system_prompt_rx.next().await;
151                }
152            }
153        });
154
155        let this = Self {
156            project,
157            tools,
158            prompt_builder,
159            prompt_store,
160            context_server_tool_ids: HashMap::default(),
161            threads: Vec::new(),
162            project_context: SharedProjectContext::default(),
163            reload_system_prompt_tx,
164            _reload_system_prompt_task: reload_system_prompt_task,
165            _subscriptions: subscriptions,
166        };
167        this.load_default_profile(cx);
168        this.register_context_server_handlers(cx);
169        this.reload(cx).detach_and_log_err(cx);
170        (this, ready_rx)
171    }
172
173    fn handle_project_event(
174        &mut self,
175        _project: Entity<Project>,
176        event: &project::Event,
177        _cx: &mut Context<Self>,
178    ) {
179        match event {
180            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
181                self.enqueue_system_prompt_reload();
182            }
183            project::Event::WorktreeUpdatedEntries(_, items) => {
184                if items.iter().any(|(path, _, _)| {
185                    RULES_FILE_NAMES
186                        .iter()
187                        .any(|name| path.as_ref() == Path::new(name))
188                }) {
189                    self.enqueue_system_prompt_reload();
190                }
191            }
192            _ => {}
193        }
194    }
195
196    fn enqueue_system_prompt_reload(&mut self) {
197        self.reload_system_prompt_tx.try_send(()).ok();
198    }
199
200    // Note that this should only be called from `reload_system_prompt_task`.
201    fn reload_system_prompt(
202        &self,
203        prompt_store: Option<Entity<PromptStore>>,
204        cx: &mut Context<Self>,
205    ) -> Task<()> {
206        let worktrees = self
207            .project
208            .read(cx)
209            .visible_worktrees(cx)
210            .collect::<Vec<_>>();
211        let worktree_tasks = worktrees
212            .into_iter()
213            .map(|worktree| {
214                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
215            })
216            .collect::<Vec<_>>();
217        let default_user_rules_task = match prompt_store {
218            None => Task::ready(vec![]),
219            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
220                let prompts = prompt_store.default_prompt_metadata();
221                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
222                    let contents = prompt_store.load(prompt_metadata.id, cx);
223                    async move { (contents.await, prompt_metadata) }
224                });
225                cx.background_spawn(future::join_all(load_tasks))
226            }),
227        };
228
229        cx.spawn(async move |this, cx| {
230            let (worktrees, default_user_rules) =
231                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
232
233            let worktrees = worktrees
234                .into_iter()
235                .map(|(worktree, rules_error)| {
236                    if let Some(rules_error) = rules_error {
237                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
238                    }
239                    worktree
240                })
241                .collect::<Vec<_>>();
242
243            let default_user_rules = default_user_rules
244                .into_iter()
245                .flat_map(|(contents, prompt_metadata)| match contents {
246                    Ok(contents) => Some(UserRulesContext {
247                        uuid: match prompt_metadata.id {
248                            PromptId::User { uuid } => uuid,
249                            PromptId::EditWorkflow => return None,
250                        },
251                        title: prompt_metadata.title.map(|title| title.to_string()),
252                        contents,
253                    }),
254                    Err(err) => {
255                        this.update(cx, |_, cx| {
256                            cx.emit(RulesLoadingError {
257                                message: format!("{err:?}").into(),
258                            });
259                        })
260                        .ok();
261                        None
262                    }
263                })
264                .collect::<Vec<_>>();
265
266            this.update(cx, |this, _cx| {
267                *this.project_context.0.borrow_mut() =
268                    Some(ProjectContext::new(worktrees, default_user_rules));
269            })
270            .ok();
271        })
272    }
273
274    fn load_worktree_info_for_system_prompt(
275        worktree: Entity<Worktree>,
276        project: Entity<Project>,
277        cx: &mut App,
278    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
279        let root_name = worktree.read(cx).root_name().into();
280
281        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
282        let Some(rules_task) = rules_task else {
283            return Task::ready((
284                WorktreeContext {
285                    root_name,
286                    rules_file: None,
287                },
288                None,
289            ));
290        };
291
292        cx.spawn(async move |_| {
293            let (rules_file, rules_file_error) = match rules_task.await {
294                Ok(rules_file) => (Some(rules_file), None),
295                Err(err) => (
296                    None,
297                    Some(RulesLoadingError {
298                        message: format!("{err}").into(),
299                    }),
300                ),
301            };
302            let worktree_info = WorktreeContext {
303                root_name,
304                rules_file,
305            };
306            (worktree_info, rules_file_error)
307        })
308    }
309
310    fn load_worktree_rules_file(
311        worktree: Entity<Worktree>,
312        project: Entity<Project>,
313        cx: &mut App,
314    ) -> Option<Task<Result<RulesFileContext>>> {
315        let worktree_ref = worktree.read(cx);
316        let worktree_id = worktree_ref.id();
317        let selected_rules_file = RULES_FILE_NAMES
318            .into_iter()
319            .filter_map(|name| {
320                worktree_ref
321                    .entry_for_path(name)
322                    .filter(|entry| entry.is_file())
323                    .map(|entry| entry.path.clone())
324            })
325            .next();
326
327        // Note that Cline supports `.clinerules` being a directory, but that is not currently
328        // supported. This doesn't seem to occur often in GitHub repositories.
329        selected_rules_file.map(|path_in_worktree| {
330            let project_path = ProjectPath {
331                worktree_id,
332                path: path_in_worktree.clone(),
333            };
334            let buffer_task =
335                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
336            let rope_task = cx.spawn(async move |cx| {
337                buffer_task.await?.read_with(cx, |buffer, cx| {
338                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
339                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
340                })?
341            });
342            // Build a string from the rope on a background thread.
343            cx.background_spawn(async move {
344                let (project_entry_id, rope) = rope_task.await?;
345                anyhow::Ok(RulesFileContext {
346                    path_in_worktree,
347                    text: rope.to_string().trim().to_string(),
348                    project_entry_id: project_entry_id.to_usize(),
349                })
350            })
351        })
352    }
353
354    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
355        &self.prompt_store
356    }
357
358    pub fn tools(&self) -> Entity<ToolWorkingSet> {
359        self.tools.clone()
360    }
361
362    /// Returns the number of threads.
363    pub fn thread_count(&self) -> usize {
364        self.threads.len()
365    }
366
367    pub fn unordered_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
368        self.threads.iter()
369    }
370
371    pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
372        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
373        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
374        threads
375    }
376
377    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
378        cx.new(|cx| {
379            Thread::new(
380                self.project.clone(),
381                self.tools.clone(),
382                self.prompt_builder.clone(),
383                self.project_context.clone(),
384                cx,
385            )
386        })
387    }
388
389    pub fn create_thread_from_serialized(
390        &mut self,
391        serialized: SerializedThread,
392        cx: &mut Context<Self>,
393    ) -> Entity<Thread> {
394        cx.new(|cx| {
395            Thread::deserialize(
396                ThreadId::new(),
397                serialized,
398                self.project.clone(),
399                self.tools.clone(),
400                self.prompt_builder.clone(),
401                self.project_context.clone(),
402                None,
403                cx,
404            )
405        })
406    }
407
408    pub fn open_thread(
409        &self,
410        id: &ThreadId,
411        window: &mut Window,
412        cx: &mut Context<Self>,
413    ) -> Task<Result<Entity<Thread>>> {
414        let id = id.clone();
415        let database_future = ThreadsDatabase::global_future(cx);
416        let this = cx.weak_entity();
417        window.spawn(cx, async move |cx| {
418            let database = database_future.await.map_err(|err| anyhow!(err))?;
419            let thread = database
420                .try_find_thread(id.clone())
421                .await?
422                .with_context(|| format!("no thread found with ID: {id:?}"))?;
423
424            let thread = this.update_in(cx, |this, window, cx| {
425                cx.new(|cx| {
426                    Thread::deserialize(
427                        id.clone(),
428                        thread,
429                        this.project.clone(),
430                        this.tools.clone(),
431                        this.prompt_builder.clone(),
432                        this.project_context.clone(),
433                        Some(window),
434                        cx,
435                    )
436                })
437            })?;
438
439            Ok(thread)
440        })
441    }
442
443    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
444        let (metadata, serialized_thread) =
445            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
446
447        let database_future = ThreadsDatabase::global_future(cx);
448        cx.spawn(async move |this, cx| {
449            let serialized_thread = serialized_thread.await?;
450            let database = database_future.await.map_err(|err| anyhow!(err))?;
451            database.save_thread(metadata, serialized_thread).await?;
452
453            this.update(cx, |this, cx| this.reload(cx))?.await
454        })
455    }
456
457    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
458        let id = id.clone();
459        let database_future = ThreadsDatabase::global_future(cx);
460        cx.spawn(async move |this, cx| {
461            let database = database_future.await.map_err(|err| anyhow!(err))?;
462            database.delete_thread(id.clone()).await?;
463
464            this.update(cx, |this, cx| {
465                this.threads.retain(|thread| thread.id != id);
466                cx.notify();
467            })
468        })
469    }
470
471    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
472        let database_future = ThreadsDatabase::global_future(cx);
473        cx.spawn(async move |this, cx| {
474            let threads = database_future
475                .await
476                .map_err(|err| anyhow!(err))?
477                .list_threads()
478                .await?;
479
480            this.update(cx, |this, cx| {
481                this.threads = threads;
482                cx.notify();
483            })
484        })
485    }
486
487    fn load_default_profile(&self, cx: &mut Context<Self>) {
488        let assistant_settings = AgentSettings::get_global(cx);
489
490        self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
491    }
492
493    pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
494        let assistant_settings = AgentSettings::get_global(cx);
495
496        if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
497            self.load_profile(profile.clone(), cx);
498        }
499    }
500
501    pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
502        self.tools.update(cx, |tools, cx| {
503            tools.disable_all_tools(cx);
504            tools.enable(
505                ToolSource::Native,
506                &profile
507                    .tools
508                    .into_iter()
509                    .filter_map(|(tool, enabled)| enabled.then(|| tool))
510                    .collect::<Vec<_>>(),
511                cx,
512            );
513        });
514
515        if profile.enable_all_context_servers {
516            for context_server_id in self
517                .project
518                .read(cx)
519                .context_server_store()
520                .read(cx)
521                .all_server_ids()
522            {
523                self.tools.update(cx, |tools, cx| {
524                    tools.enable_source(
525                        ToolSource::ContextServer {
526                            id: context_server_id.0.into(),
527                        },
528                        cx,
529                    );
530                });
531            }
532            // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
533            for (context_server_id, preset) in profile.context_servers {
534                self.tools.update(cx, |tools, cx| {
535                    tools.disable(
536                        ToolSource::ContextServer {
537                            id: context_server_id.into(),
538                        },
539                        &preset
540                            .tools
541                            .into_iter()
542                            .filter_map(|(tool, enabled)| (!enabled).then(|| tool))
543                            .collect::<Vec<_>>(),
544                        cx,
545                    )
546                })
547            }
548        } else {
549            for (context_server_id, preset) in profile.context_servers {
550                self.tools.update(cx, |tools, cx| {
551                    tools.enable(
552                        ToolSource::ContextServer {
553                            id: context_server_id.into(),
554                        },
555                        &preset
556                            .tools
557                            .into_iter()
558                            .filter_map(|(tool, enabled)| enabled.then(|| tool))
559                            .collect::<Vec<_>>(),
560                        cx,
561                    )
562                })
563            }
564        }
565    }
566
567    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
568        cx.subscribe(
569            &self.project.read(cx).context_server_store(),
570            Self::handle_context_server_event,
571        )
572        .detach();
573    }
574
575    fn handle_context_server_event(
576        &mut self,
577        context_server_store: Entity<ContextServerStore>,
578        event: &project::context_server_store::Event,
579        cx: &mut Context<Self>,
580    ) {
581        let tool_working_set = self.tools.clone();
582        match event {
583            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
584                match status {
585                    ContextServerStatus::Running => {
586                        if let Some(server) =
587                            context_server_store.read(cx).get_running_server(server_id)
588                        {
589                            let context_server_manager = context_server_store.clone();
590                            cx.spawn({
591                                let server = server.clone();
592                                let server_id = server_id.clone();
593                                async move |this, cx| {
594                                    let Some(protocol) = server.client() else {
595                                        return;
596                                    };
597
598                                    if protocol.capable(context_server::protocol::ServerCapability::Tools) {
599                                        if let Some(tools) = protocol.list_tools().await.log_err() {
600                                            let tool_ids = tool_working_set
601                                                .update(cx, |tool_working_set, _| {
602                                                    tools
603                                                        .tools
604                                                        .into_iter()
605                                                        .map(|tool| {
606                                                            log::info!(
607                                                                "registering context server tool: {:?}",
608                                                                tool.name
609                                                            );
610                                                            tool_working_set.insert(Arc::new(
611                                                                ContextServerTool::new(
612                                                                    context_server_manager.clone(),
613                                                                    server.id(),
614                                                                    tool,
615                                                                ),
616                                                            ))
617                                                        })
618                                                        .collect::<Vec<_>>()
619                                                })
620                                                .log_err();
621
622                                            if let Some(tool_ids) = tool_ids {
623                                                this.update(cx, |this, cx| {
624                                                    this.context_server_tool_ids
625                                                        .insert(server_id, tool_ids);
626                                                    this.load_default_profile(cx);
627                                                })
628                                                .log_err();
629                                            }
630                                        }
631                                    }
632                                }
633                            })
634                            .detach();
635                        }
636                    }
637                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
638                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
639                            tool_working_set.update(cx, |tool_working_set, _| {
640                                tool_working_set.remove(&tool_ids);
641                            });
642                            self.load_default_profile(cx);
643                        }
644                    }
645                    _ => {}
646                }
647            }
648        }
649    }
650}
651
/// Summary information about a saved thread, as held in `ThreadStore`'s thread list.
///
/// `ThreadStore::reverse_chronological_threads` orders these by `updated_at`, newest first.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Id of the saved thread; used to delete it from the in-memory list.
    pub id: ThreadId,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
}
658
/// A thread in the current persisted format ([`SerializedThread::VERSION`]).
///
/// This is the value type of the threads database. Fields marked `#[serde(default)]` may be
/// missing from previously saved data and then take their `Default` value on deserialization.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    /// Format version string; [`SerializedThread::from_json`] dispatches on it.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
}
682
/// Provider and model identifiers, stored as plain strings, recorded with a serialized thread.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
688
689impl SerializedThread {
690    pub const VERSION: &'static str = "0.2.0";
691
692    pub fn from_json(json: &[u8]) -> Result<Self> {
693        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
694        match saved_thread_json.get("version") {
695            Some(serde_json::Value::String(version)) => match version.as_str() {
696                SerializedThreadV0_1_0::VERSION => {
697                    let saved_thread =
698                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
699                    Ok(saved_thread.upgrade())
700                }
701                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
702                    saved_thread_json,
703                )?),
704                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
705            },
706            None => {
707                let saved_thread =
708                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
709                Ok(saved_thread.upgrade())
710            }
711            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
712        }
713    }
714}
715
/// A thread in the `0.1.0` format, which has the same fields as the current
/// [`SerializedThread`] and differs only in where tool results are attached to messages.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);
722
723impl SerializedThreadV0_1_0 {
724    pub const VERSION: &'static str = "0.1.0";
725
726    pub fn upgrade(self) -> SerializedThread {
727        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
728
729        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
730
731        for message in self.0.messages {
732            if message.role == Role::User && !message.tool_results.is_empty() {
733                if let Some(last_message) = messages.last_mut() {
734                    debug_assert!(last_message.role == Role::Assistant);
735
736                    last_message.tool_results = message.tool_results;
737                    continue;
738                }
739            }
740
741            messages.push(message);
742        }
743
744        SerializedThread { messages, ..self.0 }
745    }
746}
747
/// A single persisted message of a thread.
///
/// Fields marked `#[serde(default)]` may be absent in stored data and default to empty/false.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces (text, thinking, redacted thinking).
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    #[serde(default)]
    pub is_hidden: bool,
}
765
/// One content piece of a [`SerializedMessage`], stored as an object whose `type` field names
/// the variant (serde internal tagging).
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        /// Omitted from the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // NOTE(review): unlike its siblings this variant has no `rename`, so its tag is the literal
    // `"RedactedThinking"`. Renaming it now would break reading previously saved threads.
    RedactedThinking {
        data: Vec<u8>,
    },
}
783
/// A persisted tool invocation: the tool-use id, the tool name and its JSON input.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
790
/// A persisted tool result, linked to its invocation by `tool_use_id`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: LanguageModelToolResultContent,
    /// Optional structured output; presumably the tool's raw output beside `content` — confirm.
    pub output: Option<serde_json::Value>,
}
798
/// The original, unversioned thread format (JSON documents without a `version` field).
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
807
808impl LegacySerializedThread {
809    pub fn upgrade(self) -> SerializedThread {
810        SerializedThread {
811            version: SerializedThread::VERSION.to_string(),
812            summary: self.summary,
813            updated_at: self.updated_at,
814            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
815            initial_project_snapshot: self.initial_project_snapshot,
816            cumulative_token_usage: TokenUsage::default(),
817            request_token_usage: Vec::new(),
818            detailed_summary_state: DetailedSummaryState::default(),
819            exceeded_window_error: None,
820            model: None,
821            completion_mode: None,
822            tool_use_limit_reached: false,
823        }
824    }
825}
826
/// A message in the legacy thread layout: a single plain-text body plus tool calls/results.
///
/// Converted to a [`SerializedMessage`] with `LegacySerializedMessage::upgrade`.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// The whole message body; becomes a single `Text` segment on upgrade.
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
837
838impl LegacySerializedMessage {
839    fn upgrade(self) -> SerializedMessage {
840        SerializedMessage {
841            id: self.id,
842            role: self.role,
843            segments: vec![SerializedMessageSegment::Text { text: self.text }],
844            tool_uses: self.tool_uses,
845            tool_results: self.tool_results,
846            context: String::new(),
847            creases: Vec::new(),
848            is_hidden: false,
849        }
850    }
851}
852
/// A persisted crease attached to a [`SerializedMessage`], with an icon and label.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedCrease {
    /// Start offset of the crease; presumably into the message text — TODO confirm unit
    /// (bytes vs. chars).
    pub start: usize,
    /// End offset of the crease (same unit as `start`).
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}
860
/// gpui global holding the shared future that opens the [`ThreadsDatabase`].
///
/// It is installed by `ThreadsDatabase::init` and cloned out by
/// `ThreadsDatabase::global_future`. `Shared` requires a `Clone` output, and
/// `anyhow::Error` is not `Clone`, so both the database and the error are wrapped in `Arc`.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

// Lets the type be stored with `cx.set_global` and read back via `ReadGlobal::global`.
impl Global for GlobalThreadsDatabase {}
866
/// LMDB-backed (via `heed`) store of serialized threads.
///
/// Every operation is spawned on `executor` and returned as a `Task`, so callers do not
/// perform database I/O on their own thread.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every database operation is spawned.
    executor: BackgroundExecutor,
    /// The open LMDB environment; cloned into each spawned operation.
    env: heed::Env,
    /// The `"threads"` table. Keys are `ThreadId` encoded with bincode. Values are
    /// `SerializedThread`, encoded as JSON by the `heed::BytesEncode`/`BytesDecode` impls
    /// for `SerializedThread`.
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
872
873impl heed::BytesEncode<'_> for SerializedThread {
874    type EItem = SerializedThread;
875
876    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
877        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
878    }
879}
880
881impl<'a> heed::BytesDecode<'a> for SerializedThread {
882    type DItem = SerializedThread;
883
884    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
885        // We implement this type manually because we want to call `SerializedThread::from_json`,
886        // instead of the Deserialize trait implementation for `SerializedThread`.
887        SerializedThread::from_json(bytes).map_err(Into::into)
888    }
889}
890
891impl ThreadsDatabase {
892    fn global_future(
893        cx: &mut App,
894    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
895        GlobalThreadsDatabase::global(cx).0.clone()
896    }
897
898    fn init(cx: &mut App) {
899        let executor = cx.background_executor().clone();
900        let database_future = executor
901            .spawn({
902                let executor = executor.clone();
903                let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
904                async move { ThreadsDatabase::new(database_path, executor) }
905            })
906            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
907            .boxed()
908            .shared();
909
910        cx.set_global(GlobalThreadsDatabase(database_future));
911    }
912
913    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
914        std::fs::create_dir_all(&path)?;
915
916        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
917        let env = unsafe {
918            heed::EnvOpenOptions::new()
919                .map_size(ONE_GB_IN_BYTES)
920                .max_dbs(1)
921                .open(path)?
922        };
923
924        let mut txn = env.write_txn()?;
925        let threads = env.create_database(&mut txn, Some("threads"))?;
926        txn.commit()?;
927
928        Ok(Self {
929            executor,
930            env,
931            threads,
932        })
933    }
934
935    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
936        let env = self.env.clone();
937        let threads = self.threads;
938
939        self.executor.spawn(async move {
940            let txn = env.read_txn()?;
941            let mut iter = threads.iter(&txn)?;
942            let mut threads = Vec::new();
943            while let Some((key, value)) = iter.next().transpose()? {
944                threads.push(SerializedThreadMetadata {
945                    id: key,
946                    summary: value.summary,
947                    updated_at: value.updated_at,
948                });
949            }
950
951            Ok(threads)
952        })
953    }
954
955    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
956        let env = self.env.clone();
957        let threads = self.threads;
958
959        self.executor.spawn(async move {
960            let txn = env.read_txn()?;
961            let thread = threads.get(&txn, &id)?;
962            Ok(thread)
963        })
964    }
965
966    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
967        let env = self.env.clone();
968        let threads = self.threads;
969
970        self.executor.spawn(async move {
971            let mut txn = env.write_txn()?;
972            threads.put(&mut txn, &id, &thread)?;
973            txn.commit()?;
974            Ok(())
975        })
976    }
977
978    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
979        let env = self.env.clone();
980        let threads = self.threads;
981
982        self.executor.spawn(async move {
983            let mut txn = env.write_txn()?;
984            threads.delete(&mut txn, &id)?;
985            txn.commit()?;
986            Ok(())
987        })
988    }
989}