thread_store.rs

   1use std::cell::{Ref, RefCell};
   2use std::path::{Path, PathBuf};
   3use std::rc::Rc;
   4use std::sync::{Arc, Mutex};
   5
   6use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, CompletionMode};
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
   9use chrono::{DateTime, Utc};
  10use collections::HashMap;
  11use context_server::ContextServerId;
  12use futures::channel::{mpsc, oneshot};
  13use futures::future::{self, BoxFuture, Shared};
  14use futures::{FutureExt as _, StreamExt as _};
  15use gpui::{
  16    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
  17    Subscription, Task, prelude::*,
  18};
  19
  20use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
  21use project::context_server_store::{ContextServerStatus, ContextServerStore};
  22use project::{Project, ProjectItem, ProjectPath, Worktree};
  23use prompt_store::{
  24    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
  25    UserRulesContext, WorktreeContext,
  26};
  27use serde::{Deserialize, Serialize};
  28use settings::{Settings as _, SettingsStore};
  29use ui::Window;
  30use util::ResultExt as _;
  31
  32use crate::context_server_tool::ContextServerTool;
  33use crate::thread::{
  34    DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
  35};
  36use indoc::indoc;
  37use sqlez::{
  38    bindable::{Bind, Column},
  39    connection::Connection,
  40    statement::Statement,
  41};
  42
  43#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  44pub enum DataType {
  45    #[serde(rename = "json")]
  46    Json,
  47    #[serde(rename = "zstd")]
  48    Zstd,
  49}
  50
  51impl Bind for DataType {
  52    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
  53        let value = match self {
  54            DataType::Json => "json",
  55            DataType::Zstd => "zstd",
  56        };
  57        value.bind(statement, start_index)
  58    }
  59}
  60
  61impl Column for DataType {
  62    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
  63        let (value, next_index) = String::column(statement, start_index)?;
  64        let data_type = match value.as_str() {
  65            "json" => DataType::Json,
  66            "zstd" => DataType::Zstd,
  67            _ => anyhow::bail!("Unknown data type: {}", value),
  68        };
  69        Ok((data_type, next_index))
  70    }
  71}
  72
  73const RULES_FILE_NAMES: [&'static str; 8] = [
  74    ".rules",
  75    ".cursorrules",
  76    ".windsurfrules",
  77    ".clinerules",
  78    ".github/copilot-instructions.md",
  79    "CLAUDE.md",
  80    "AGENT.md",
  81    "AGENTS.md",
  82];
  83
/// Initializes thread persistence for this module.
///
/// Delegates to `ThreadsDatabase::init`. That presumably sets up the database later obtained
/// through `ThreadsDatabase::global_future` (TODO confirm against its definition).
pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}
  87
  88/// A system prompt shared by all threads created by this ThreadStore
  89#[derive(Clone, Default)]
  90pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
  91
  92impl SharedProjectContext {
  93    pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
  94        self.0.borrow()
  95    }
  96}
  97
/// Store for text (context-editor) threads; an alias of `assistant_context_editor::ContextStore`.
pub type TextThreadStore = assistant_context_editor::ContextStore;
  99
/// Lists saved agent threads and creates, opens, saves and deletes them.
///
/// It also owns the state that new threads are built from: the project, the tool working set,
/// the prompt builder, and the shared project context used for the system prompt.
pub struct ThreadStore {
    /// Project whose visible worktrees feed the system prompt; threads are created for it.
    project: Entity<Project>,
    /// Tool working set shared with threads; agent profiles enable/disable tools in it.
    tools: Entity<ToolWorkingSet>,
    /// Passed to every thread created or deserialized by this store.
    prompt_builder: Arc<PromptBuilder>,
    /// When present, its default prompts become user rules in the system prompt, and its
    /// update events trigger a system prompt reload.
    prompt_store: Option<Entity<PromptStore>>,
    /// Tool ids registered for each running context server, so they can be removed when the
    /// server stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Saved thread metadata as last loaded from the database, in `list_threads` order.
    threads: Vec<SerializedThreadMetadata>,
    /// System prompt inputs shared by all threads created by this store.
    project_context: SharedProjectContext,
    /// Wakes the reload task; requests that do not fit in the bounded channel are dropped.
    reload_system_prompt_tx: mpsc::Sender<()>,
    /// Loop that serializes system prompt reloads (see `ThreadStore::new`).
    _reload_system_prompt_task: Task<()>,
    /// Kept alive for the lifetime of the store.
    _subscriptions: Vec<Subscription>,
}
 112
/// Event emitted by `ThreadStore` when a rules source fails to load during a system prompt
/// reload.
///
/// The source is either a worktree rules file or a default user rule from the prompt store.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}
 118
 119impl ThreadStore {
 120    pub fn load(
 121        project: Entity<Project>,
 122        tools: Entity<ToolWorkingSet>,
 123        prompt_store: Option<Entity<PromptStore>>,
 124        prompt_builder: Arc<PromptBuilder>,
 125        cx: &mut App,
 126    ) -> Task<Result<Entity<Self>>> {
 127        cx.spawn(async move |cx| {
 128            let (thread_store, ready_rx) = cx.update(|cx| {
 129                let mut option_ready_rx = None;
 130                let thread_store = cx.new(|cx| {
 131                    let (thread_store, ready_rx) =
 132                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 133                    option_ready_rx = Some(ready_rx);
 134                    thread_store
 135                });
 136                (thread_store, option_ready_rx.take().unwrap())
 137            })?;
 138            ready_rx.await?;
 139            Ok(thread_store)
 140        })
 141    }
 142
    /// Builds the store and starts the machinery that keeps it up to date.
    ///
    /// Returns the store together with a receiver that fires once the first system prompt load
    /// has finished; `ThreadStore::load` awaits it before handing the store out.
    ///
    /// Set up here:
    /// * the default profile is re-applied whenever the global `SettingsStore` changes;
    /// * project events are observed (worktree changes may enqueue a system prompt reload);
    /// * when a prompt store is given, `PromptsUpdatedEvent` enqueues a system prompt reload;
    /// * the default profile is applied, context server events are subscribed to, and the
    ///   thread list starts loading from the database.
    fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> (Self, oneshot::Receiver<()>) {
        let mut subscriptions = vec![
            cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
                this.load_default_profile(cx);
            }),
            cx.subscribe(&project, Self::handle_project_event),
        ];

        if let Some(prompt_store) = prompt_store.as_ref() {
            subscriptions.push(cx.subscribe(
                prompt_store,
                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
                    this.enqueue_system_prompt_reload();
                },
            ))
        }

        // This channel and task prevent concurrent and redundant loading of the system prompt.
        // The task loads once right away, then waits for a reload request between loads, so only
        // one load is in flight at a time. Requests that do not fit in the bounded channel are
        // dropped by `enqueue_system_prompt_reload`.
        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
        // Fired once, after the first load, to let `load` resolve.
        let (ready_tx, ready_rx) = oneshot::channel();
        let mut ready_tx = Some(ready_tx);
        let reload_system_prompt_task = cx.spawn({
            let prompt_store = prompt_store.clone();
            async move |thread_store, cx| {
                loop {
                    // Exit once the store can no longer be updated (e.g. it has been released).
                    let Some(reload_task) = thread_store
                        .update(cx, |thread_store, cx| {
                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
                        })
                        .ok()
                    else {
                        return;
                    };
                    reload_task.await;
                    // Only the first completed load sends the readiness signal.
                    if let Some(ready_tx) = ready_tx.take() {
                        ready_tx.send(()).ok();
                    }
                    reload_system_prompt_rx.next().await;
                }
            }
        });

        let this = Self {
            project,
            tools,
            prompt_builder,
            prompt_store,
            context_server_tool_ids: HashMap::default(),
            threads: Vec::new(),
            project_context: SharedProjectContext::default(),
            reload_system_prompt_tx,
            _reload_system_prompt_task: reload_system_prompt_task,
            _subscriptions: subscriptions,
        };
        this.load_default_profile(cx);
        this.register_context_server_handlers(cx);
        this.reload(cx).detach_and_log_err(cx);
        (this, ready_rx)
    }
 208
 209    fn handle_project_event(
 210        &mut self,
 211        _project: Entity<Project>,
 212        event: &project::Event,
 213        _cx: &mut Context<Self>,
 214    ) {
 215        match event {
 216            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 217                self.enqueue_system_prompt_reload();
 218            }
 219            project::Event::WorktreeUpdatedEntries(_, items) => {
 220                if items.iter().any(|(path, _, _)| {
 221                    RULES_FILE_NAMES
 222                        .iter()
 223                        .any(|name| path.as_ref() == Path::new(name))
 224                }) {
 225                    self.enqueue_system_prompt_reload();
 226                }
 227            }
 228            _ => {}
 229        }
 230    }
 231
 232    fn enqueue_system_prompt_reload(&mut self) {
 233        self.reload_system_prompt_tx.try_send(()).ok();
 234    }
 235
 236    // Note that this should only be called from `reload_system_prompt_task`.
 237    fn reload_system_prompt(
 238        &self,
 239        prompt_store: Option<Entity<PromptStore>>,
 240        cx: &mut Context<Self>,
 241    ) -> Task<()> {
 242        let worktrees = self
 243            .project
 244            .read(cx)
 245            .visible_worktrees(cx)
 246            .collect::<Vec<_>>();
 247        let worktree_tasks = worktrees
 248            .into_iter()
 249            .map(|worktree| {
 250                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
 251            })
 252            .collect::<Vec<_>>();
 253        let default_user_rules_task = match prompt_store {
 254            None => Task::ready(vec![]),
 255            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
 256                let prompts = prompt_store.default_prompt_metadata();
 257                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 258                    let contents = prompt_store.load(prompt_metadata.id, cx);
 259                    async move { (contents.await, prompt_metadata) }
 260                });
 261                cx.background_spawn(future::join_all(load_tasks))
 262            }),
 263        };
 264
 265        cx.spawn(async move |this, cx| {
 266            let (worktrees, default_user_rules) =
 267                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 268
 269            let worktrees = worktrees
 270                .into_iter()
 271                .map(|(worktree, rules_error)| {
 272                    if let Some(rules_error) = rules_error {
 273                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 274                    }
 275                    worktree
 276                })
 277                .collect::<Vec<_>>();
 278
 279            let default_user_rules = default_user_rules
 280                .into_iter()
 281                .flat_map(|(contents, prompt_metadata)| match contents {
 282                    Ok(contents) => Some(UserRulesContext {
 283                        uuid: match prompt_metadata.id {
 284                            PromptId::User { uuid } => uuid,
 285                            PromptId::EditWorkflow => return None,
 286                        },
 287                        title: prompt_metadata.title.map(|title| title.to_string()),
 288                        contents,
 289                    }),
 290                    Err(err) => {
 291                        this.update(cx, |_, cx| {
 292                            cx.emit(RulesLoadingError {
 293                                message: format!("{err:?}").into(),
 294                            });
 295                        })
 296                        .ok();
 297                        None
 298                    }
 299                })
 300                .collect::<Vec<_>>();
 301
 302            this.update(cx, |this, _cx| {
 303                *this.project_context.0.borrow_mut() =
 304                    Some(ProjectContext::new(worktrees, default_user_rules));
 305            })
 306            .ok();
 307        })
 308    }
 309
 310    fn load_worktree_info_for_system_prompt(
 311        worktree: Entity<Worktree>,
 312        project: Entity<Project>,
 313        cx: &mut App,
 314    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 315        let root_name = worktree.read(cx).root_name().into();
 316
 317        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 318        let Some(rules_task) = rules_task else {
 319            return Task::ready((
 320                WorktreeContext {
 321                    root_name,
 322                    rules_file: None,
 323                },
 324                None,
 325            ));
 326        };
 327
 328        cx.spawn(async move |_| {
 329            let (rules_file, rules_file_error) = match rules_task.await {
 330                Ok(rules_file) => (Some(rules_file), None),
 331                Err(err) => (
 332                    None,
 333                    Some(RulesLoadingError {
 334                        message: format!("{err}").into(),
 335                    }),
 336                ),
 337            };
 338            let worktree_info = WorktreeContext {
 339                root_name,
 340                rules_file,
 341            };
 342            (worktree_info, rules_file_error)
 343        })
 344    }
 345
 346    fn load_worktree_rules_file(
 347        worktree: Entity<Worktree>,
 348        project: Entity<Project>,
 349        cx: &mut App,
 350    ) -> Option<Task<Result<RulesFileContext>>> {
 351        let worktree_ref = worktree.read(cx);
 352        let worktree_id = worktree_ref.id();
 353        let selected_rules_file = RULES_FILE_NAMES
 354            .into_iter()
 355            .filter_map(|name| {
 356                worktree_ref
 357                    .entry_for_path(name)
 358                    .filter(|entry| entry.is_file())
 359                    .map(|entry| entry.path.clone())
 360            })
 361            .next();
 362
 363        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 364        // supported. This doesn't seem to occur often in GitHub repositories.
 365        selected_rules_file.map(|path_in_worktree| {
 366            let project_path = ProjectPath {
 367                worktree_id,
 368                path: path_in_worktree.clone(),
 369            };
 370            let buffer_task =
 371                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 372            let rope_task = cx.spawn(async move |cx| {
 373                buffer_task.await?.read_with(cx, |buffer, cx| {
 374                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 375                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 376                })?
 377            });
 378            // Build a string from the rope on a background thread.
 379            cx.background_spawn(async move {
 380                let (project_entry_id, rope) = rope_task.await?;
 381                anyhow::Ok(RulesFileContext {
 382                    path_in_worktree,
 383                    text: rope.to_string().trim().to_string(),
 384                    project_entry_id: project_entry_id.to_usize(),
 385                })
 386            })
 387        })
 388    }
 389
    /// Returns the prompt store this thread store was created with, if any.
    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
        &self.prompt_store
    }
 393
 394    pub fn tools(&self) -> Entity<ToolWorkingSet> {
 395        self.tools.clone()
 396    }
 397
    /// Returns the number of saved threads currently listed.
    ///
    /// This is the list as of the last `reload`, minus any threads deleted since.
    pub fn thread_count(&self) -> usize {
        self.threads.len()
    }
 402
    /// Iterates over saved thread metadata in the order stored by the last `reload`.
    ///
    /// That order is expected to be newest first (per this method's name); TODO confirm
    /// against the query in `list_threads`.
    pub fn reverse_chronological_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
        // ordering is from "ORDER BY" in `list_threads`
        self.threads.iter()
    }
 407
 408    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 409        cx.new(|cx| {
 410            Thread::new(
 411                self.project.clone(),
 412                self.tools.clone(),
 413                self.prompt_builder.clone(),
 414                self.project_context.clone(),
 415                cx,
 416            )
 417        })
 418    }
 419
 420    pub fn create_thread_from_serialized(
 421        &mut self,
 422        serialized: SerializedThread,
 423        cx: &mut Context<Self>,
 424    ) -> Entity<Thread> {
 425        cx.new(|cx| {
 426            Thread::deserialize(
 427                ThreadId::new(),
 428                serialized,
 429                self.project.clone(),
 430                self.tools.clone(),
 431                self.prompt_builder.clone(),
 432                self.project_context.clone(),
 433                None,
 434                cx,
 435            )
 436        })
 437    }
 438
 439    pub fn open_thread(
 440        &self,
 441        id: &ThreadId,
 442        window: &mut Window,
 443        cx: &mut Context<Self>,
 444    ) -> Task<Result<Entity<Thread>>> {
 445        let id = id.clone();
 446        let database_future = ThreadsDatabase::global_future(cx);
 447        let this = cx.weak_entity();
 448        window.spawn(cx, async move |cx| {
 449            let database = database_future.await.map_err(|err| anyhow!(err))?;
 450            let thread = database
 451                .try_find_thread(id.clone())
 452                .await?
 453                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 454
 455            let thread = this.update_in(cx, |this, window, cx| {
 456                cx.new(|cx| {
 457                    Thread::deserialize(
 458                        id.clone(),
 459                        thread,
 460                        this.project.clone(),
 461                        this.tools.clone(),
 462                        this.prompt_builder.clone(),
 463                        this.project_context.clone(),
 464                        Some(window),
 465                        cx,
 466                    )
 467                })
 468            })?;
 469
 470            Ok(thread)
 471        })
 472    }
 473
 474    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
 475        let (metadata, serialized_thread) =
 476            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
 477
 478        let database_future = ThreadsDatabase::global_future(cx);
 479        cx.spawn(async move |this, cx| {
 480            let serialized_thread = serialized_thread.await?;
 481            let database = database_future.await.map_err(|err| anyhow!(err))?;
 482            database.save_thread(metadata, serialized_thread).await?;
 483
 484            this.update(cx, |this, cx| this.reload(cx))?.await
 485        })
 486    }
 487
 488    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
 489        let id = id.clone();
 490        let database_future = ThreadsDatabase::global_future(cx);
 491        cx.spawn(async move |this, cx| {
 492            let database = database_future.await.map_err(|err| anyhow!(err))?;
 493            database.delete_thread(id.clone()).await?;
 494
 495            this.update(cx, |this, cx| {
 496                this.threads.retain(|thread| thread.id != id);
 497                cx.notify();
 498            })
 499        })
 500    }
 501
 502    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 503        let database_future = ThreadsDatabase::global_future(cx);
 504        cx.spawn(async move |this, cx| {
 505            let threads = database_future
 506                .await
 507                .map_err(|err| anyhow!(err))?
 508                .list_threads()
 509                .await?;
 510
 511            this.update(cx, |this, cx| {
 512                this.threads = threads;
 513                cx.notify();
 514            })
 515        })
 516    }
 517
 518    fn load_default_profile(&self, cx: &mut Context<Self>) {
 519        let assistant_settings = AgentSettings::get_global(cx);
 520
 521        self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
 522    }
 523
 524    pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
 525        let assistant_settings = AgentSettings::get_global(cx);
 526
 527        if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
 528            self.load_profile(profile.clone(), cx);
 529        }
 530    }
 531
 532    pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
 533        self.tools.update(cx, |tools, cx| {
 534            tools.disable_all_tools(cx);
 535            tools.enable(
 536                ToolSource::Native,
 537                &profile
 538                    .tools
 539                    .into_iter()
 540                    .filter_map(|(tool, enabled)| enabled.then(|| tool))
 541                    .collect::<Vec<_>>(),
 542                cx,
 543            );
 544        });
 545
 546        if profile.enable_all_context_servers {
 547            for context_server_id in self
 548                .project
 549                .read(cx)
 550                .context_server_store()
 551                .read(cx)
 552                .all_server_ids()
 553            {
 554                self.tools.update(cx, |tools, cx| {
 555                    tools.enable_source(
 556                        ToolSource::ContextServer {
 557                            id: context_server_id.0.into(),
 558                        },
 559                        cx,
 560                    );
 561                });
 562            }
 563            // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
 564            for (context_server_id, preset) in profile.context_servers {
 565                self.tools.update(cx, |tools, cx| {
 566                    tools.disable(
 567                        ToolSource::ContextServer {
 568                            id: context_server_id.into(),
 569                        },
 570                        &preset
 571                            .tools
 572                            .into_iter()
 573                            .filter_map(|(tool, enabled)| (!enabled).then(|| tool))
 574                            .collect::<Vec<_>>(),
 575                        cx,
 576                    )
 577                })
 578            }
 579        } else {
 580            for (context_server_id, preset) in profile.context_servers {
 581                self.tools.update(cx, |tools, cx| {
 582                    tools.enable(
 583                        ToolSource::ContextServer {
 584                            id: context_server_id.into(),
 585                        },
 586                        &preset
 587                            .tools
 588                            .into_iter()
 589                            .filter_map(|(tool, enabled)| enabled.then(|| tool))
 590                            .collect::<Vec<_>>(),
 591                        cx,
 592                    )
 593                })
 594            }
 595        }
 596    }
 597
 598    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
 599        cx.subscribe(
 600            &self.project.read(cx).context_server_store(),
 601            Self::handle_context_server_event,
 602        )
 603        .detach();
 604    }
 605
 606    fn handle_context_server_event(
 607        &mut self,
 608        context_server_store: Entity<ContextServerStore>,
 609        event: &project::context_server_store::Event,
 610        cx: &mut Context<Self>,
 611    ) {
 612        let tool_working_set = self.tools.clone();
 613        match event {
 614            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
 615                match status {
 616                    ContextServerStatus::Running => {
 617                        if let Some(server) =
 618                            context_server_store.read(cx).get_running_server(server_id)
 619                        {
 620                            let context_server_manager = context_server_store.clone();
 621                            cx.spawn({
 622                                let server = server.clone();
 623                                let server_id = server_id.clone();
 624                                async move |this, cx| {
 625                                    let Some(protocol) = server.client() else {
 626                                        return;
 627                                    };
 628
 629                                    if protocol.capable(context_server::protocol::ServerCapability::Tools) {
 630                                        if let Some(tools) = protocol.list_tools().await.log_err() {
 631                                            let tool_ids = tool_working_set
 632                                                .update(cx, |tool_working_set, _| {
 633                                                    tools
 634                                                        .tools
 635                                                        .into_iter()
 636                                                        .map(|tool| {
 637                                                            log::info!(
 638                                                                "registering context server tool: {:?}",
 639                                                                tool.name
 640                                                            );
 641                                                            tool_working_set.insert(Arc::new(
 642                                                                ContextServerTool::new(
 643                                                                    context_server_manager.clone(),
 644                                                                    server.id(),
 645                                                                    tool,
 646                                                                ),
 647                                                            ))
 648                                                        })
 649                                                        .collect::<Vec<_>>()
 650                                                })
 651                                                .log_err();
 652
 653                                            if let Some(tool_ids) = tool_ids {
 654                                                this.update(cx, |this, cx| {
 655                                                    this.context_server_tool_ids
 656                                                        .insert(server_id, tool_ids);
 657                                                    this.load_default_profile(cx);
 658                                                })
 659                                                .log_err();
 660                                            }
 661                                        }
 662                                    }
 663                                }
 664                            })
 665                            .detach();
 666                        }
 667                    }
 668                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
 669                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
 670                            tool_working_set.update(cx, |tool_working_set, _| {
 671                                tool_working_set.remove(&tool_ids);
 672                            });
 673                            self.load_default_profile(cx);
 674                        }
 675                    }
 676                    _ => {}
 677                }
 678            }
 679        }
 680    }
 681}
 682
/// Listing entry for a saved thread, as held in `ThreadStore`'s thread list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Id used to open or delete the thread.
    pub id: ThreadId,
    /// Summary text stored for the thread.
    pub summary: SharedString,
    /// Last update time of the thread.
    pub updated_at: DateTime<Utc>,
}
 689
/// Saved form of a thread in the current format (`SerializedThread::VERSION`).
///
/// Fields marked `#[serde(default)]` may be missing from older saved data; they then fall back
/// to their `Default` value.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    /// Format version string; `from_json` dispatches on it.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    /// Presumably one entry per model request — TODO confirm against `Thread::serialize`.
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    /// Provider and model the thread was using, if recorded.
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
}
 713
/// Identifies the language model a saved thread used.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedLanguageModel {
    /// Provider identifier (stored as a plain string).
    pub provider: String,
    /// Model identifier within the provider (stored as a plain string).
    pub model: String,
}
 719
 720impl SerializedThread {
 721    pub const VERSION: &'static str = "0.2.0";
 722
 723    pub fn from_json(json: &[u8]) -> Result<Self> {
 724        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 725        match saved_thread_json.get("version") {
 726            Some(serde_json::Value::String(version)) => match version.as_str() {
 727                SerializedThreadV0_1_0::VERSION => {
 728                    let saved_thread =
 729                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
 730                    Ok(saved_thread.upgrade())
 731                }
 732                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
 733                    saved_thread_json,
 734                )?),
 735                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 736            },
 737            None => {
 738                let saved_thread =
 739                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
 740                Ok(saved_thread.upgrade())
 741            }
 742            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 743        }
 744    }
 745}
 746
/// Saved thread in format 0.1.0.
///
/// Its fields match the current [`SerializedThread`]. What differs is where tool results live
/// (see `SerializedThreadV0_1_0::upgrade`).
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);
 753
 754impl SerializedThreadV0_1_0 {
 755    pub const VERSION: &'static str = "0.1.0";
 756
 757    pub fn upgrade(self) -> SerializedThread {
 758        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
 759
 760        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
 761
 762        for message in self.0.messages {
 763            if message.role == Role::User && !message.tool_results.is_empty() {
 764                if let Some(last_message) = messages.last_mut() {
 765                    debug_assert!(last_message.role == Role::Assistant);
 766
 767                    last_message.tool_results = message.tool_results;
 768                    continue;
 769                }
 770            }
 771
 772            messages.push(message);
 773        }
 774
 775        SerializedThread { messages, ..self.0 }
 776    }
 777}
 778
/// Saved form of one thread message.
///
/// Fields marked `#[serde(default)]` may be absent from saved data.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (text, thinking, redacted thinking).
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    #[serde(default)]
    pub is_hidden: bool,
}
 796
/// One piece of a message's content.
///
/// Internally tagged: serde writes the variant name into a `"type"` key next to
/// the variant's fields.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Plain message text.
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    /// Model reasoning text, with an optional signature that is omitted from the
    /// output when `None`.
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // NOTE: unlike the variants above this one has no `rename`, so its tag is
    // written as "RedactedThinking". Changing that would break previously saved
    // threads unless a serde `alias` for the old tag is kept.
    /// Opaque redacted reasoning payload.
    RedactedThinking {
        data: Vec<u8>,
    },
}
 814
/// A tool call recorded on a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    /// Identifier of this tool use. Presumably matched by `SerializedToolResult::tool_use_id`.
    pub id: LanguageModelToolUseId,
    /// Name of the tool that was invoked.
    pub name: SharedString,
    /// Arguments passed to the tool, as arbitrary JSON.
    pub input: serde_json::Value,
}
 821
/// The recorded outcome of a tool call.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    /// The tool use this result answers.
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the tool reported a failure.
    pub is_error: bool,
    pub content: LanguageModelToolResultContent,
    // No `#[serde(default)]` is needed: serde's derive treats a missing `Option` field as `None`.
    pub output: Option<serde_json::Value>,
}
 829
/// Thread format that predates the `version` field.
///
/// `SerializedThread::from_json` parses a document through this type when it
/// finds no version, then calls [`LegacySerializedThread::upgrade`].
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
 838
 839impl LegacySerializedThread {
 840    pub fn upgrade(self) -> SerializedThread {
 841        SerializedThread {
 842            version: SerializedThread::VERSION.to_string(),
 843            summary: self.summary,
 844            updated_at: self.updated_at,
 845            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
 846            initial_project_snapshot: self.initial_project_snapshot,
 847            cumulative_token_usage: TokenUsage::default(),
 848            request_token_usage: Vec::new(),
 849            detailed_summary_state: DetailedSummaryState::default(),
 850            exceeded_window_error: None,
 851            model: None,
 852            completion_mode: None,
 853            tool_use_limit_reached: false,
 854        }
 855    }
 856}
 857
/// Message shape used by [`LegacySerializedThread`]: a single `text` string
/// rather than a list of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
 868
 869impl LegacySerializedMessage {
 870    fn upgrade(self) -> SerializedMessage {
 871        SerializedMessage {
 872            id: self.id,
 873            role: self.role,
 874            segments: vec![SerializedMessageSegment::Text { text: self.text }],
 875            tool_uses: self.tool_uses,
 876            tool_results: self.tool_results,
 877            context: String::new(),
 878            creases: Vec::new(),
 879            is_hidden: false,
 880        }
 881    }
 882}
 883
/// A crease (collapsed, labelled range) stored with a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedCrease {
    /// Start of the range. NOTE(review): presumably an offset into the message text — confirm.
    pub start: usize,
    /// End of the range, in the same units as `start`.
    pub end: usize,
    /// Icon shown for the crease.
    pub icon_path: SharedString,
    /// Label shown for the crease.
    pub label: SharedString,
}
 891
/// gpui global holding the shared future that opens the threads database.
///
/// Installed by `ThreadsDatabase::init`. Both the success and error values are
/// wrapped in `Arc` because a `Shared` future requires a `Clone` output.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
 897
/// SQLite-backed storage for serialized threads.
pub(crate) struct ThreadsDatabase {
    /// Executor that the blocking database operations are spawned onto.
    executor: BackgroundExecutor,
    /// Single connection shared by all operations; each operation locks it for its query.
    connection: Arc<Mutex<Connection>>,
}
 902
 903impl ThreadsDatabase {
 904    fn connection(&self) -> Arc<Mutex<Connection>> {
 905        self.connection.clone()
 906    }
 907
 908    const COMPRESSION_LEVEL: i32 = 3;
 909}
 910
 911impl Bind for ThreadId {
 912    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
 913        self.to_string().bind(statement, start_index)
 914    }
 915}
 916
 917impl Column for ThreadId {
 918    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
 919        let (id_str, next_index) = String::column(statement, start_index)?;
 920        Ok((ThreadId::from(id_str.as_str()), next_index))
 921    }
 922}
 923
 924impl ThreadsDatabase {
 925    fn global_future(
 926        cx: &mut App,
 927    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
 928        GlobalThreadsDatabase::global(cx).0.clone()
 929    }
 930
 931    fn init(cx: &mut App) {
 932        let executor = cx.background_executor().clone();
 933        let database_future = executor
 934            .spawn({
 935                let executor = executor.clone();
 936                let threads_dir = paths::data_dir().join("threads");
 937                async move { ThreadsDatabase::new(threads_dir, executor) }
 938            })
 939            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
 940            .boxed()
 941            .shared();
 942
 943        cx.set_global(GlobalThreadsDatabase(database_future));
 944    }
 945
 946    pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
 947        std::fs::create_dir_all(&threads_dir)?;
 948
 949        let sqlite_path = threads_dir.join("threads.db");
 950        let mdb_path = threads_dir.join("threads-db.1.mdb");
 951
 952        let needs_migration_from_heed = mdb_path.exists();
 953
 954        let connection = Connection::open_file(&sqlite_path.to_string_lossy());
 955
 956        connection.exec(indoc! {"
 957                CREATE TABLE IF NOT EXISTS threads (
 958                    id TEXT PRIMARY KEY,
 959                    summary TEXT NOT NULL,
 960                    updated_at TEXT NOT NULL,
 961                    data_type TEXT NOT NULL,
 962                    data BLOB NOT NULL
 963                )
 964            "})?()
 965        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
 966
 967        let db = Self {
 968            executor: executor.clone(),
 969            connection: Arc::new(Mutex::new(connection)),
 970        };
 971
 972        if needs_migration_from_heed {
 973            let db_connection = db.connection();
 974            let executor_clone = executor.clone();
 975            executor
 976                .spawn(async move {
 977                    log::info!("Starting threads.db migration");
 978                    Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
 979                    std::fs::remove_dir_all(mdb_path)?;
 980                    log::info!("threads.db migrated to sqlite");
 981                    Ok::<(), anyhow::Error>(())
 982                })
 983                .detach();
 984        }
 985
 986        Ok(db)
 987    }
 988
 989    // Remove this migration after 2025-09-01
 990    fn migrate_from_heed(
 991        mdb_path: &Path,
 992        connection: Arc<Mutex<Connection>>,
 993        _executor: BackgroundExecutor,
 994    ) -> Result<()> {
 995        use heed::types::SerdeBincode;
 996        struct SerializedThreadHeed(SerializedThread);
 997
 998        impl heed::BytesEncode<'_> for SerializedThreadHeed {
 999            type EItem = SerializedThreadHeed;
1000
1001            fn bytes_encode(
1002                item: &Self::EItem,
1003            ) -> Result<std::borrow::Cow<[u8]>, heed::BoxedError> {
1004                serde_json::to_vec(&item.0)
1005                    .map(std::borrow::Cow::Owned)
1006                    .map_err(Into::into)
1007            }
1008        }
1009
1010        impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
1011            type DItem = SerializedThreadHeed;
1012
1013            fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
1014                SerializedThread::from_json(bytes)
1015                    .map(SerializedThreadHeed)
1016                    .map_err(Into::into)
1017            }
1018        }
1019
1020        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
1021
1022        let env = unsafe {
1023            heed::EnvOpenOptions::new()
1024                .map_size(ONE_GB_IN_BYTES)
1025                .max_dbs(1)
1026                .open(mdb_path)?
1027        };
1028
1029        let txn = env.write_txn()?;
1030        let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
1031            .open_database(&txn, Some("threads"))?
1032            .ok_or_else(|| anyhow!("threads database not found"))?;
1033
1034        for result in threads.iter(&txn)? {
1035            let (thread_id, thread_heed) = result?;
1036            Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
1037        }
1038
1039        Ok(())
1040    }
1041
1042    fn save_thread_sync(
1043        connection: &Arc<Mutex<Connection>>,
1044        id: ThreadId,
1045        thread: SerializedThread,
1046    ) -> Result<()> {
1047        let json_data = serde_json::to_string(&thread)?;
1048        let summary = thread.summary.to_string();
1049        let updated_at = thread.updated_at.to_rfc3339();
1050
1051        let connection = connection.lock().unwrap();
1052
1053        let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
1054        let data_type = DataType::Zstd;
1055        let data = compressed;
1056
1057        let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
1058            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
1059        "})?;
1060
1061        insert((id, summary, updated_at, data_type, data))?;
1062
1063        Ok(())
1064    }
1065
1066    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
1067        let connection = self.connection.clone();
1068
1069        self.executor.spawn(async move {
1070            let connection = connection.lock().unwrap();
1071            let mut select =
1072                connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
1073                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
1074            "})?;
1075
1076            let rows = select(())?;
1077            let mut threads = Vec::new();
1078
1079            for (id, summary, updated_at) in rows {
1080                threads.push(SerializedThreadMetadata {
1081                    id,
1082                    summary: summary.into(),
1083                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
1084                });
1085            }
1086
1087            Ok(threads)
1088        })
1089    }
1090
1091    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
1092        let connection = self.connection.clone();
1093
1094        self.executor.spawn(async move {
1095            let connection = connection.lock().unwrap();
1096            let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
1097                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
1098            "})?;
1099
1100            let rows = select(id)?;
1101            if let Some((data_type, data)) = rows.into_iter().next() {
1102                let json_data = match data_type {
1103                    DataType::Zstd => {
1104                        let decompressed = zstd::decode_all(&data[..])?;
1105                        String::from_utf8(decompressed)?
1106                    }
1107                    DataType::Json => String::from_utf8(data)?,
1108                };
1109
1110                let thread = SerializedThread::from_json(json_data.as_bytes())?;
1111                Ok(Some(thread))
1112            } else {
1113                Ok(None)
1114            }
1115        })
1116    }
1117
1118    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
1119        let connection = self.connection.clone();
1120
1121        self.executor
1122            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
1123    }
1124
1125    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
1126        let connection = self.connection.clone();
1127
1128        self.executor.spawn(async move {
1129            let connection = connection.lock().unwrap();
1130
1131            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
1132                DELETE FROM threads WHERE id = ?
1133            "})?;
1134
1135            delete(id)?;
1136
1137            Ok(())
1138        })
1139    }
1140}