thread_store.rs

  1use std::borrow::Cow;
  2use std::cell::{Ref, RefCell};
  3use std::path::{Path, PathBuf};
  4use std::rc::Rc;
  5use std::sync::Arc;
  6
  7use anyhow::{Context as _, Result, anyhow};
  8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings, CompletionMode};
  9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
 10use chrono::{DateTime, Utc};
 11use collections::HashMap;
 12use context_server::ContextServerId;
 13use futures::channel::{mpsc, oneshot};
 14use futures::future::{self, BoxFuture, Shared};
 15use futures::{FutureExt as _, StreamExt as _};
 16use gpui::{
 17    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
 18    Subscription, Task, prelude::*,
 19};
 20use heed::Database;
 21use heed::types::SerdeBincode;
 22use language_model::{LanguageModelToolUseId, Role, TokenUsage};
 23use project::context_server_store::{ContextServerStatus, ContextServerStore};
 24use project::{Project, ProjectItem, ProjectPath, Worktree};
 25use prompt_store::{
 26    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
 27    UserRulesContext, WorktreeContext,
 28};
 29use serde::{Deserialize, Serialize};
 30use settings::{Settings as _, SettingsStore};
 31use ui::Window;
 32use util::ResultExt as _;
 33
 34use crate::context_server_tool::ContextServerTool;
 35use crate::thread::{
 36    DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
 37};
 38
 39const RULES_FILE_NAMES: [&'static str; 6] = [
 40    ".rules",
 41    ".cursorrules",
 42    ".windsurfrules",
 43    ".clinerules",
 44    ".github/copilot-instructions.md",
 45    "CLAUDE.md",
 46];
 47
/// Initialization hook for this module: sets up the threads database by delegating to
/// `ThreadsDatabase::init`.
pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}
 51
 52/// A system prompt shared by all threads created by this ThreadStore
 53#[derive(Clone, Default)]
 54pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
 55
 56impl SharedProjectContext {
 57    pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
 58        self.0.borrow()
 59    }
 60}
 61
/// Alias exposing `assistant_context_editor::ContextStore` as the store for text threads.
pub type TextThreadStore = assistant_context_editor::ContextStore;
 63
/// Owns the list of saved agent threads for a project and the shared system-prompt context
/// those threads use. It also keeps the tool working set in sync with the active agent
/// profile and with running context servers.
pub struct ThreadStore {
    project: Entity<Project>,
    /// Tool set whose enabled tools are driven by `load_profile`.
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    /// Optional source of default user rules included in the system prompt.
    prompt_store: Option<Entity<PromptStore>>,
    /// Tool ids registered for each running context server, so they can be removed again
    /// when that server stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Cached thread metadata, replaced by `reload`.
    threads: Vec<SerializedThreadMetadata>,
    /// Context handed to every thread; overwritten on each system-prompt reload.
    project_context: SharedProjectContext,
    /// Wakes the system-prompt reload loop.
    reload_system_prompt_tx: mpsc::Sender<()>,
    /// The reload loop itself; owned here so it is tied to the store's lifetime.
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}
 76
/// Event emitted by `ThreadStore` when a worktree rules file or a default user rule fails to
/// load while rebuilding the system prompt.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}
 82
 83impl ThreadStore {
 84    pub fn load(
 85        project: Entity<Project>,
 86        tools: Entity<ToolWorkingSet>,
 87        prompt_store: Option<Entity<PromptStore>>,
 88        prompt_builder: Arc<PromptBuilder>,
 89        cx: &mut App,
 90    ) -> Task<Result<Entity<Self>>> {
 91        cx.spawn(async move |cx| {
 92            let (thread_store, ready_rx) = cx.update(|cx| {
 93                let mut option_ready_rx = None;
 94                let thread_store = cx.new(|cx| {
 95                    let (thread_store, ready_rx) =
 96                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 97                    option_ready_rx = Some(ready_rx);
 98                    thread_store
 99                });
100                (thread_store, option_ready_rx.take().unwrap())
101            })?;
102            ready_rx.await?;
103            Ok(thread_store)
104        })
105    }
106
107    fn new(
108        project: Entity<Project>,
109        tools: Entity<ToolWorkingSet>,
110        prompt_builder: Arc<PromptBuilder>,
111        prompt_store: Option<Entity<PromptStore>>,
112        cx: &mut Context<Self>,
113    ) -> (Self, oneshot::Receiver<()>) {
114        let mut subscriptions = vec![
115            cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
116                this.load_default_profile(cx);
117            }),
118            cx.subscribe(&project, Self::handle_project_event),
119        ];
120
121        if let Some(prompt_store) = prompt_store.as_ref() {
122            subscriptions.push(cx.subscribe(
123                prompt_store,
124                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
125                    this.enqueue_system_prompt_reload();
126                },
127            ))
128        }
129
130        // This channel and task prevent concurrent and redundant loading of the system prompt.
131        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
132        let (ready_tx, ready_rx) = oneshot::channel();
133        let mut ready_tx = Some(ready_tx);
134        let reload_system_prompt_task = cx.spawn({
135            let prompt_store = prompt_store.clone();
136            async move |thread_store, cx| {
137                loop {
138                    let Some(reload_task) = thread_store
139                        .update(cx, |thread_store, cx| {
140                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
141                        })
142                        .ok()
143                    else {
144                        return;
145                    };
146                    reload_task.await;
147                    if let Some(ready_tx) = ready_tx.take() {
148                        ready_tx.send(()).ok();
149                    }
150                    reload_system_prompt_rx.next().await;
151                }
152            }
153        });
154
155        let this = Self {
156            project,
157            tools,
158            prompt_builder,
159            prompt_store,
160            context_server_tool_ids: HashMap::default(),
161            threads: Vec::new(),
162            project_context: SharedProjectContext::default(),
163            reload_system_prompt_tx,
164            _reload_system_prompt_task: reload_system_prompt_task,
165            _subscriptions: subscriptions,
166        };
167        this.load_default_profile(cx);
168        this.register_context_server_handlers(cx);
169        this.reload(cx).detach_and_log_err(cx);
170        (this, ready_rx)
171    }
172
173    fn handle_project_event(
174        &mut self,
175        _project: Entity<Project>,
176        event: &project::Event,
177        _cx: &mut Context<Self>,
178    ) {
179        match event {
180            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
181                self.enqueue_system_prompt_reload();
182            }
183            project::Event::WorktreeUpdatedEntries(_, items) => {
184                if items.iter().any(|(path, _, _)| {
185                    RULES_FILE_NAMES
186                        .iter()
187                        .any(|name| path.as_ref() == Path::new(name))
188                }) {
189                    self.enqueue_system_prompt_reload();
190                }
191            }
192            _ => {}
193        }
194    }
195
196    fn enqueue_system_prompt_reload(&mut self) {
197        self.reload_system_prompt_tx.try_send(()).ok();
198    }
199
200    // Note that this should only be called from `reload_system_prompt_task`.
201    fn reload_system_prompt(
202        &self,
203        prompt_store: Option<Entity<PromptStore>>,
204        cx: &mut Context<Self>,
205    ) -> Task<()> {
206        let worktrees = self
207            .project
208            .read(cx)
209            .visible_worktrees(cx)
210            .collect::<Vec<_>>();
211        let worktree_tasks = worktrees
212            .into_iter()
213            .map(|worktree| {
214                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
215            })
216            .collect::<Vec<_>>();
217        let default_user_rules_task = match prompt_store {
218            None => Task::ready(vec![]),
219            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
220                let prompts = prompt_store.default_prompt_metadata();
221                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
222                    let contents = prompt_store.load(prompt_metadata.id, cx);
223                    async move { (contents.await, prompt_metadata) }
224                });
225                cx.background_spawn(future::join_all(load_tasks))
226            }),
227        };
228
229        cx.spawn(async move |this, cx| {
230            let (worktrees, default_user_rules) =
231                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
232
233            let worktrees = worktrees
234                .into_iter()
235                .map(|(worktree, rules_error)| {
236                    if let Some(rules_error) = rules_error {
237                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
238                    }
239                    worktree
240                })
241                .collect::<Vec<_>>();
242
243            let default_user_rules = default_user_rules
244                .into_iter()
245                .flat_map(|(contents, prompt_metadata)| match contents {
246                    Ok(contents) => Some(UserRulesContext {
247                        uuid: match prompt_metadata.id {
248                            PromptId::User { uuid } => uuid,
249                            PromptId::EditWorkflow => return None,
250                        },
251                        title: prompt_metadata.title.map(|title| title.to_string()),
252                        contents,
253                    }),
254                    Err(err) => {
255                        this.update(cx, |_, cx| {
256                            cx.emit(RulesLoadingError {
257                                message: format!("{err:?}").into(),
258                            });
259                        })
260                        .ok();
261                        None
262                    }
263                })
264                .collect::<Vec<_>>();
265
266            this.update(cx, |this, _cx| {
267                *this.project_context.0.borrow_mut() =
268                    Some(ProjectContext::new(worktrees, default_user_rules));
269            })
270            .ok();
271        })
272    }
273
274    fn load_worktree_info_for_system_prompt(
275        worktree: Entity<Worktree>,
276        project: Entity<Project>,
277        cx: &mut App,
278    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
279        let root_name = worktree.read(cx).root_name().into();
280
281        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
282        let Some(rules_task) = rules_task else {
283            return Task::ready((
284                WorktreeContext {
285                    root_name,
286                    rules_file: None,
287                },
288                None,
289            ));
290        };
291
292        cx.spawn(async move |_| {
293            let (rules_file, rules_file_error) = match rules_task.await {
294                Ok(rules_file) => (Some(rules_file), None),
295                Err(err) => (
296                    None,
297                    Some(RulesLoadingError {
298                        message: format!("{err}").into(),
299                    }),
300                ),
301            };
302            let worktree_info = WorktreeContext {
303                root_name,
304                rules_file,
305            };
306            (worktree_info, rules_file_error)
307        })
308    }
309
310    fn load_worktree_rules_file(
311        worktree: Entity<Worktree>,
312        project: Entity<Project>,
313        cx: &mut App,
314    ) -> Option<Task<Result<RulesFileContext>>> {
315        let worktree_ref = worktree.read(cx);
316        let worktree_id = worktree_ref.id();
317        let selected_rules_file = RULES_FILE_NAMES
318            .into_iter()
319            .filter_map(|name| {
320                worktree_ref
321                    .entry_for_path(name)
322                    .filter(|entry| entry.is_file())
323                    .map(|entry| entry.path.clone())
324            })
325            .next();
326
327        // Note that Cline supports `.clinerules` being a directory, but that is not currently
328        // supported. This doesn't seem to occur often in GitHub repositories.
329        selected_rules_file.map(|path_in_worktree| {
330            let project_path = ProjectPath {
331                worktree_id,
332                path: path_in_worktree.clone(),
333            };
334            let buffer_task =
335                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
336            let rope_task = cx.spawn(async move |cx| {
337                buffer_task.await?.read_with(cx, |buffer, cx| {
338                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
339                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
340                })?
341            });
342            // Build a string from the rope on a background thread.
343            cx.background_spawn(async move {
344                let (project_entry_id, rope) = rope_task.await?;
345                anyhow::Ok(RulesFileContext {
346                    path_in_worktree,
347                    text: rope.to_string().trim().to_string(),
348                    project_entry_id: project_entry_id.to_usize(),
349                })
350            })
351        })
352    }
353
354    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
355        &self.prompt_store
356    }
357
358    pub fn tools(&self) -> Entity<ToolWorkingSet> {
359        self.tools.clone()
360    }
361
362    /// Returns the number of threads.
363    pub fn thread_count(&self) -> usize {
364        self.threads.len()
365    }
366
367    pub fn unordered_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
368        self.threads.iter()
369    }
370
371    pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
372        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
373        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
374        threads
375    }
376
377    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
378        cx.new(|cx| {
379            Thread::new(
380                self.project.clone(),
381                self.tools.clone(),
382                self.prompt_builder.clone(),
383                self.project_context.clone(),
384                cx,
385            )
386        })
387    }
388
389    pub fn open_thread(
390        &self,
391        id: &ThreadId,
392        window: &mut Window,
393        cx: &mut Context<Self>,
394    ) -> Task<Result<Entity<Thread>>> {
395        let id = id.clone();
396        let database_future = ThreadsDatabase::global_future(cx);
397        let this = cx.weak_entity();
398        window.spawn(cx, async move |cx| {
399            let database = database_future.await.map_err(|err| anyhow!(err))?;
400            let thread = database
401                .try_find_thread(id.clone())
402                .await?
403                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
404
405            let thread = this.update_in(cx, |this, window, cx| {
406                cx.new(|cx| {
407                    Thread::deserialize(
408                        id.clone(),
409                        thread,
410                        this.project.clone(),
411                        this.tools.clone(),
412                        this.prompt_builder.clone(),
413                        this.project_context.clone(),
414                        window,
415                        cx,
416                    )
417                })
418            })?;
419
420            Ok(thread)
421        })
422    }
423
424    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
425        let (metadata, serialized_thread) =
426            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
427
428        let database_future = ThreadsDatabase::global_future(cx);
429        cx.spawn(async move |this, cx| {
430            let serialized_thread = serialized_thread.await?;
431            let database = database_future.await.map_err(|err| anyhow!(err))?;
432            database.save_thread(metadata, serialized_thread).await?;
433
434            this.update(cx, |this, cx| this.reload(cx))?.await
435        })
436    }
437
438    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
439        let id = id.clone();
440        let database_future = ThreadsDatabase::global_future(cx);
441        cx.spawn(async move |this, cx| {
442            let database = database_future.await.map_err(|err| anyhow!(err))?;
443            database.delete_thread(id.clone()).await?;
444
445            this.update(cx, |this, cx| {
446                this.threads.retain(|thread| thread.id != id);
447                cx.notify();
448            })
449        })
450    }
451
452    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
453        let database_future = ThreadsDatabase::global_future(cx);
454        cx.spawn(async move |this, cx| {
455            let threads = database_future
456                .await
457                .map_err(|err| anyhow!(err))?
458                .list_threads()
459                .await?;
460
461            this.update(cx, |this, cx| {
462                this.threads = threads;
463                cx.notify();
464            })
465        })
466    }
467
468    fn load_default_profile(&self, cx: &mut Context<Self>) {
469        let assistant_settings = AssistantSettings::get_global(cx);
470
471        self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
472    }
473
474    pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
475        let assistant_settings = AssistantSettings::get_global(cx);
476
477        if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
478            self.load_profile(profile.clone(), cx);
479        }
480    }
481
482    pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
483        self.tools.update(cx, |tools, cx| {
484            tools.disable_all_tools(cx);
485            tools.enable(
486                ToolSource::Native,
487                &profile
488                    .tools
489                    .iter()
490                    .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
491                    .collect::<Vec<_>>(),
492                cx,
493            );
494        });
495
496        if profile.enable_all_context_servers {
497            for context_server_id in self
498                .project
499                .read(cx)
500                .context_server_store()
501                .read(cx)
502                .all_server_ids()
503            {
504                self.tools.update(cx, |tools, cx| {
505                    tools.enable_source(
506                        ToolSource::ContextServer {
507                            id: context_server_id.0.into(),
508                        },
509                        cx,
510                    );
511                });
512            }
513            // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
514            for (context_server_id, preset) in &profile.context_servers {
515                self.tools.update(cx, |tools, cx| {
516                    tools.disable(
517                        ToolSource::ContextServer {
518                            id: context_server_id.clone().into(),
519                        },
520                        &preset
521                            .tools
522                            .iter()
523                            .filter_map(|(tool, enabled)| (!enabled).then(|| tool.clone()))
524                            .collect::<Vec<_>>(),
525                        cx,
526                    )
527                })
528            }
529        } else {
530            for (context_server_id, preset) in &profile.context_servers {
531                self.tools.update(cx, |tools, cx| {
532                    tools.enable(
533                        ToolSource::ContextServer {
534                            id: context_server_id.clone().into(),
535                        },
536                        &preset
537                            .tools
538                            .iter()
539                            .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
540                            .collect::<Vec<_>>(),
541                        cx,
542                    )
543                })
544            }
545        }
546    }
547
548    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
549        cx.subscribe(
550            &self.project.read(cx).context_server_store(),
551            Self::handle_context_server_event,
552        )
553        .detach();
554    }
555
556    fn handle_context_server_event(
557        &mut self,
558        context_server_store: Entity<ContextServerStore>,
559        event: &project::context_server_store::Event,
560        cx: &mut Context<Self>,
561    ) {
562        let tool_working_set = self.tools.clone();
563        match event {
564            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
565                match status {
566                    ContextServerStatus::Running => {
567                        if let Some(server) =
568                            context_server_store.read(cx).get_running_server(server_id)
569                        {
570                            let context_server_manager = context_server_store.clone();
571                            cx.spawn({
572                                let server = server.clone();
573                                let server_id = server_id.clone();
574                                async move |this, cx| {
575                                    let Some(protocol) = server.client() else {
576                                        return;
577                                    };
578
579                                    if protocol.capable(context_server::protocol::ServerCapability::Tools) {
580                                        if let Some(tools) = protocol.list_tools().await.log_err() {
581                                            let tool_ids = tool_working_set
582                                                .update(cx, |tool_working_set, _| {
583                                                    tools
584                                                        .tools
585                                                        .into_iter()
586                                                        .map(|tool| {
587                                                            log::info!(
588                                                                "registering context server tool: {:?}",
589                                                                tool.name
590                                                            );
591                                                            tool_working_set.insert(Arc::new(
592                                                                ContextServerTool::new(
593                                                                    context_server_manager.clone(),
594                                                                    server.id(),
595                                                                    tool,
596                                                                ),
597                                                            ))
598                                                        })
599                                                        .collect::<Vec<_>>()
600                                                })
601                                                .log_err();
602
603                                            if let Some(tool_ids) = tool_ids {
604                                                this.update(cx, |this, cx| {
605                                                    this.context_server_tool_ids
606                                                        .insert(server_id, tool_ids);
607                                                    this.load_default_profile(cx);
608                                                })
609                                                .log_err();
610                                            }
611                                        }
612                                    }
613                                }
614                            })
615                            .detach();
616                        }
617                    }
618                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
619                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
620                            tool_working_set.update(cx, |tool_working_set, _| {
621                                tool_working_set.remove(&tool_ids);
622                            });
623                            self.load_default_profile(cx);
624                        }
625                    }
626                    _ => {}
627                }
628            }
629        }
630    }
631}
632
/// Lightweight per-thread summary that `ThreadStore` caches (its `threads` field) for
/// listing threads without loading their full contents.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Identifier used to open or delete the full thread.
    pub id: ThreadId,
    pub summary: SharedString,
    /// Sort key for `ThreadStore::reverse_chronological_threads` (newest first).
    pub updated_at: DateTime<Utc>,
}
639
/// Full persisted representation of a thread in the current schema
/// ([`SerializedThread::VERSION`]).
///
/// The threads database stores it as JSON (see the `heed::BytesEncode` impl). Reads go
/// through [`SerializedThread::from_json`], which also accepts older schemas. Fields marked
/// `#[serde(default)]` take their `Default` value when missing from stored JSON, so they can
/// be absent from older threads. Field order is the serialized field order.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    /// Schema version string; `from_json` dispatches on it.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
}
661
/// Persisted identification of the language model a thread used, stored as two plain
/// strings.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
667
668impl SerializedThread {
669    pub const VERSION: &'static str = "0.2.0";
670
671    pub fn from_json(json: &[u8]) -> Result<Self> {
672        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
673        match saved_thread_json.get("version") {
674            Some(serde_json::Value::String(version)) => match version.as_str() {
675                SerializedThreadV0_1_0::VERSION => {
676                    let saved_thread =
677                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
678                    Ok(saved_thread.upgrade())
679                }
680                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
681                    saved_thread_json,
682                )?),
683                _ => Err(anyhow!(
684                    "unrecognized serialized thread version: {}",
685                    version
686                )),
687            },
688            None => {
689                let saved_thread =
690                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
691                Ok(saved_thread.upgrade())
692            }
693            version => Err(anyhow!(
694                "unrecognized serialized thread version: {:?}",
695                version
696            )),
697        }
698    }
699}
700
/// A thread stored with schema version "0.1.0".
///
/// The field layout matches the current `SerializedThread`, so this is a newtype around it.
/// The difference is semantic: tool results lived on separate user messages (see
/// `upgrade`).
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);
707
708impl SerializedThreadV0_1_0 {
709    pub const VERSION: &'static str = "0.1.0";
710
711    pub fn upgrade(self) -> SerializedThread {
712        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
713
714        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
715
716        for message in self.0.messages {
717            if message.role == Role::User && !message.tool_results.is_empty() {
718                if let Some(last_message) = messages.last_mut() {
719                    debug_assert!(last_message.role == Role::Assistant);
720
721                    last_message.tool_results = message.tool_results;
722                    continue;
723                }
724            }
725
726            messages.push(message);
727        }
728
729        SerializedThread { messages, ..self.0 }
730    }
731}
732
/// A persisted thread message. Every field except `id` and `role` defaults to empty when
/// absent from stored JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Message body as an ordered list of text/thinking segments.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
}
748
/// One segment of a persisted message, internally tagged by a `"type"` field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Tag `"text"`.
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    /// Tag `"thinking"`; `signature` is omitted from the output when `None`.
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // NOTE(review): unlike the other variants this has no `rename`, so its tag is the literal
    // "RedactedThinking". Renaming it now would break threads that are already saved.
    RedactedThinking {
        data: Vec<u8>,
    },
}
766
/// A persisted tool invocation: the tool's name and its JSON input.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    /// Id that matching `SerializedToolResult::tool_use_id` values refer to.
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
773
/// A persisted tool result, linked to its invocation by `tool_use_id`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: Arc<str>,
    /// Optional structured output. serde's derive treats a missing `Option` field as `None`.
    pub output: Option<serde_json::Value>,
}
781
/// Thread format from before schema versioning existed. It has no `version` field, and each
/// message holds plain text.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
790
791impl LegacySerializedThread {
792    pub fn upgrade(self) -> SerializedThread {
793        SerializedThread {
794            version: SerializedThread::VERSION.to_string(),
795            summary: self.summary,
796            updated_at: self.updated_at,
797            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
798            initial_project_snapshot: self.initial_project_snapshot,
799            cumulative_token_usage: TokenUsage::default(),
800            request_token_usage: Vec::new(),
801            detailed_summary_state: DetailedSummaryState::default(),
802            exceeded_window_error: None,
803            model: None,
804            completion_mode: None,
805        }
806    }
807}
808
/// Message format used by [`LegacySerializedThread`]. The body is a single plain `text`
/// string rather than segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
819
820impl LegacySerializedMessage {
821    fn upgrade(self) -> SerializedMessage {
822        SerializedMessage {
823            id: self.id,
824            role: self.role,
825            segments: vec![SerializedMessageSegment::Text { text: self.text }],
826            tool_uses: self.tool_uses,
827            tool_results: self.tool_results,
828            context: String::new(),
829            creases: Vec::new(),
830        }
831    }
832}
833
/// A crease stored alongside a serialized message.
///
/// NOTE(review): presumably a labeled, icon-decorated range of the message text spanning
/// `start..end`; confirm whether the offsets are bytes or characters.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedCrease {
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}
841
/// GPUI global holding the shared future that opens the threads database.
///
/// Callers clone the inner [`Shared`] future so they all await the same open. Both the
/// database and the error are wrapped in `Arc` because `Shared` requires a `Clone` output.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
847
/// LMDB-backed (via `heed`) store of serialized threads, keyed by [`ThreadId`].
pub(crate) struct ThreadsDatabase {
    /// Executor on which every read and write task is spawned.
    executor: BackgroundExecutor,
    /// The opened LMDB environment.
    env: heed::Env,
    /// Thread table: bincode-encoded ids mapped to [`SerializedThread`] values encoded by
    /// its `BytesEncode`/`BytesDecode` impls (JSON).
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
853
854impl heed::BytesEncode<'_> for SerializedThread {
855    type EItem = SerializedThread;
856
857    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
858        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
859    }
860}
861
862impl<'a> heed::BytesDecode<'a> for SerializedThread {
863    type DItem = SerializedThread;
864
865    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
866        // We implement this type manually because we want to call `SerializedThread::from_json`,
867        // instead of the Deserialize trait implementation for `SerializedThread`.
868        SerializedThread::from_json(bytes).map_err(Into::into)
869    }
870}
871
872impl ThreadsDatabase {
873    fn global_future(
874        cx: &mut App,
875    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
876        GlobalThreadsDatabase::global(cx).0.clone()
877    }
878
879    fn init(cx: &mut App) {
880        let executor = cx.background_executor().clone();
881        let database_future = executor
882            .spawn({
883                let executor = executor.clone();
884                let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
885                async move { ThreadsDatabase::new(database_path, executor) }
886            })
887            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
888            .boxed()
889            .shared();
890
891        cx.set_global(GlobalThreadsDatabase(database_future));
892    }
893
894    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
895        std::fs::create_dir_all(&path)?;
896
897        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
898        let env = unsafe {
899            heed::EnvOpenOptions::new()
900                .map_size(ONE_GB_IN_BYTES)
901                .max_dbs(1)
902                .open(path)?
903        };
904
905        let mut txn = env.write_txn()?;
906        let threads = env.create_database(&mut txn, Some("threads"))?;
907        txn.commit()?;
908
909        Ok(Self {
910            executor,
911            env,
912            threads,
913        })
914    }
915
916    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
917        let env = self.env.clone();
918        let threads = self.threads;
919
920        self.executor.spawn(async move {
921            let txn = env.read_txn()?;
922            let mut iter = threads.iter(&txn)?;
923            let mut threads = Vec::new();
924            while let Some((key, value)) = iter.next().transpose()? {
925                threads.push(SerializedThreadMetadata {
926                    id: key,
927                    summary: value.summary,
928                    updated_at: value.updated_at,
929                });
930            }
931
932            Ok(threads)
933        })
934    }
935
936    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
937        let env = self.env.clone();
938        let threads = self.threads;
939
940        self.executor.spawn(async move {
941            let txn = env.read_txn()?;
942            let thread = threads.get(&txn, &id)?;
943            Ok(thread)
944        })
945    }
946
947    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
948        let env = self.env.clone();
949        let threads = self.threads;
950
951        self.executor.spawn(async move {
952            let mut txn = env.write_txn()?;
953            threads.put(&mut txn, &id, &thread)?;
954            txn.commit()?;
955            Ok(())
956        })
957    }
958
959    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
960        let env = self.env.clone();
961        let threads = self.threads;
962
963        self.executor.spawn(async move {
964            let mut txn = env.write_txn()?;
965            threads.delete(&mut txn, &id)?;
966            txn.commit()?;
967            Ok(())
968        })
969    }
970}