thread_store.rs

   1use std::cell::{Ref, RefCell};
   2use std::path::{Path, PathBuf};
   3use std::rc::Rc;
   4use std::sync::{Arc, Mutex};
   5
   6use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, CompletionMode};
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
   9use chrono::{DateTime, Utc};
  10use collections::HashMap;
  11use context_server::ContextServerId;
  12use futures::channel::{mpsc, oneshot};
  13use futures::future::{self, BoxFuture, Shared};
  14use futures::{FutureExt as _, StreamExt as _};
  15use gpui::{
  16    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
  17    Subscription, Task, prelude::*,
  18};
  19
  20use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
  21use project::context_server_store::{ContextServerStatus, ContextServerStore};
  22use project::{Project, ProjectItem, ProjectPath, Worktree};
  23use prompt_store::{
  24    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
  25    UserRulesContext, WorktreeContext,
  26};
  27use serde::{Deserialize, Serialize};
  28use settings::{Settings as _, SettingsStore};
  29use ui::Window;
  30use util::ResultExt as _;
  31
  32use crate::context_server_tool::ContextServerTool;
  33use crate::thread::{
  34    DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
  35};
  36use indoc::indoc;
  37use sqlez::{
  38    bindable::{Bind, Column},
  39    connection::Connection,
  40    statement::Statement,
  41};
  42
  43#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  44pub enum DataType {
  45    #[serde(rename = "json")]
  46    Json,
  47    #[serde(rename = "zstd")]
  48    Zstd,
  49}
  50
  51impl Bind for DataType {
  52    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
  53        let value = match self {
  54            DataType::Json => "json",
  55            DataType::Zstd => "zstd",
  56        };
  57        value.bind(statement, start_index)
  58    }
  59}
  60
  61impl Column for DataType {
  62    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
  63        let (value, next_index) = String::column(statement, start_index)?;
  64        let data_type = match value.as_str() {
  65            "json" => DataType::Json,
  66            "zstd" => DataType::Zstd,
  67            _ => anyhow::bail!("Unknown data type: {}", value),
  68        };
  69        Ok((data_type, next_index))
  70    }
  71}
  72
  73const RULES_FILE_NAMES: [&'static str; 6] = [
  74    ".rules",
  75    ".cursorrules",
  76    ".windsurfrules",
  77    ".clinerules",
  78    ".github/copilot-instructions.md",
  79    "CLAUDE.md",
  80];
  81
  82pub fn init(cx: &mut App) {
  83    ThreadsDatabase::init(cx);
  84}
  85
  86/// A system prompt shared by all threads created by this ThreadStore
  87#[derive(Clone, Default)]
  88pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
  89
  90impl SharedProjectContext {
  91    pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
  92        self.0.borrow()
  93    }
  94}
  95
  96pub type TextThreadStore = assistant_context_editor::ContextStore;
  97
  98pub struct ThreadStore {
  99    project: Entity<Project>,
 100    tools: Entity<ToolWorkingSet>,
 101    prompt_builder: Arc<PromptBuilder>,
 102    prompt_store: Option<Entity<PromptStore>>,
 103    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
 104    threads: Vec<SerializedThreadMetadata>,
 105    project_context: SharedProjectContext,
 106    reload_system_prompt_tx: mpsc::Sender<()>,
 107    _reload_system_prompt_task: Task<()>,
 108    _subscriptions: Vec<Subscription>,
 109}
 110
 111pub struct RulesLoadingError {
 112    pub message: SharedString,
 113}
 114
 115impl EventEmitter<RulesLoadingError> for ThreadStore {}
 116
 117impl ThreadStore {
 118    pub fn load(
 119        project: Entity<Project>,
 120        tools: Entity<ToolWorkingSet>,
 121        prompt_store: Option<Entity<PromptStore>>,
 122        prompt_builder: Arc<PromptBuilder>,
 123        cx: &mut App,
 124    ) -> Task<Result<Entity<Self>>> {
 125        cx.spawn(async move |cx| {
 126            let (thread_store, ready_rx) = cx.update(|cx| {
 127                let mut option_ready_rx = None;
 128                let thread_store = cx.new(|cx| {
 129                    let (thread_store, ready_rx) =
 130                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 131                    option_ready_rx = Some(ready_rx);
 132                    thread_store
 133                });
 134                (thread_store, option_ready_rx.take().unwrap())
 135            })?;
 136            ready_rx.await?;
 137            Ok(thread_store)
 138        })
 139    }
 140
 141    fn new(
 142        project: Entity<Project>,
 143        tools: Entity<ToolWorkingSet>,
 144        prompt_builder: Arc<PromptBuilder>,
 145        prompt_store: Option<Entity<PromptStore>>,
 146        cx: &mut Context<Self>,
 147    ) -> (Self, oneshot::Receiver<()>) {
 148        let mut subscriptions = vec![
 149            cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
 150                this.load_default_profile(cx);
 151            }),
 152            cx.subscribe(&project, Self::handle_project_event),
 153        ];
 154
 155        if let Some(prompt_store) = prompt_store.as_ref() {
 156            subscriptions.push(cx.subscribe(
 157                prompt_store,
 158                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
 159                    this.enqueue_system_prompt_reload();
 160                },
 161            ))
 162        }
 163
 164        // This channel and task prevent concurrent and redundant loading of the system prompt.
 165        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
 166        let (ready_tx, ready_rx) = oneshot::channel();
 167        let mut ready_tx = Some(ready_tx);
 168        let reload_system_prompt_task = cx.spawn({
 169            let prompt_store = prompt_store.clone();
 170            async move |thread_store, cx| {
 171                loop {
 172                    let Some(reload_task) = thread_store
 173                        .update(cx, |thread_store, cx| {
 174                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
 175                        })
 176                        .ok()
 177                    else {
 178                        return;
 179                    };
 180                    reload_task.await;
 181                    if let Some(ready_tx) = ready_tx.take() {
 182                        ready_tx.send(()).ok();
 183                    }
 184                    reload_system_prompt_rx.next().await;
 185                }
 186            }
 187        });
 188
 189        let this = Self {
 190            project,
 191            tools,
 192            prompt_builder,
 193            prompt_store,
 194            context_server_tool_ids: HashMap::default(),
 195            threads: Vec::new(),
 196            project_context: SharedProjectContext::default(),
 197            reload_system_prompt_tx,
 198            _reload_system_prompt_task: reload_system_prompt_task,
 199            _subscriptions: subscriptions,
 200        };
 201        this.load_default_profile(cx);
 202        this.register_context_server_handlers(cx);
 203        this.reload(cx).detach_and_log_err(cx);
 204        (this, ready_rx)
 205    }
 206
 207    fn handle_project_event(
 208        &mut self,
 209        _project: Entity<Project>,
 210        event: &project::Event,
 211        _cx: &mut Context<Self>,
 212    ) {
 213        match event {
 214            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 215                self.enqueue_system_prompt_reload();
 216            }
 217            project::Event::WorktreeUpdatedEntries(_, items) => {
 218                if items.iter().any(|(path, _, _)| {
 219                    RULES_FILE_NAMES
 220                        .iter()
 221                        .any(|name| path.as_ref() == Path::new(name))
 222                }) {
 223                    self.enqueue_system_prompt_reload();
 224                }
 225            }
 226            _ => {}
 227        }
 228    }
 229
 230    fn enqueue_system_prompt_reload(&mut self) {
 231        self.reload_system_prompt_tx.try_send(()).ok();
 232    }
 233
 234    // Note that this should only be called from `reload_system_prompt_task`.
 235    fn reload_system_prompt(
 236        &self,
 237        prompt_store: Option<Entity<PromptStore>>,
 238        cx: &mut Context<Self>,
 239    ) -> Task<()> {
 240        let worktrees = self
 241            .project
 242            .read(cx)
 243            .visible_worktrees(cx)
 244            .collect::<Vec<_>>();
 245        let worktree_tasks = worktrees
 246            .into_iter()
 247            .map(|worktree| {
 248                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
 249            })
 250            .collect::<Vec<_>>();
 251        let default_user_rules_task = match prompt_store {
 252            None => Task::ready(vec![]),
 253            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
 254                let prompts = prompt_store.default_prompt_metadata();
 255                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 256                    let contents = prompt_store.load(prompt_metadata.id, cx);
 257                    async move { (contents.await, prompt_metadata) }
 258                });
 259                cx.background_spawn(future::join_all(load_tasks))
 260            }),
 261        };
 262
 263        cx.spawn(async move |this, cx| {
 264            let (worktrees, default_user_rules) =
 265                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 266
 267            let worktrees = worktrees
 268                .into_iter()
 269                .map(|(worktree, rules_error)| {
 270                    if let Some(rules_error) = rules_error {
 271                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 272                    }
 273                    worktree
 274                })
 275                .collect::<Vec<_>>();
 276
 277            let default_user_rules = default_user_rules
 278                .into_iter()
 279                .flat_map(|(contents, prompt_metadata)| match contents {
 280                    Ok(contents) => Some(UserRulesContext {
 281                        uuid: match prompt_metadata.id {
 282                            PromptId::User { uuid } => uuid,
 283                            PromptId::EditWorkflow => return None,
 284                        },
 285                        title: prompt_metadata.title.map(|title| title.to_string()),
 286                        contents,
 287                    }),
 288                    Err(err) => {
 289                        this.update(cx, |_, cx| {
 290                            cx.emit(RulesLoadingError {
 291                                message: format!("{err:?}").into(),
 292                            });
 293                        })
 294                        .ok();
 295                        None
 296                    }
 297                })
 298                .collect::<Vec<_>>();
 299
 300            this.update(cx, |this, _cx| {
 301                *this.project_context.0.borrow_mut() =
 302                    Some(ProjectContext::new(worktrees, default_user_rules));
 303            })
 304            .ok();
 305        })
 306    }
 307
 308    fn load_worktree_info_for_system_prompt(
 309        worktree: Entity<Worktree>,
 310        project: Entity<Project>,
 311        cx: &mut App,
 312    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 313        let root_name = worktree.read(cx).root_name().into();
 314
 315        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 316        let Some(rules_task) = rules_task else {
 317            return Task::ready((
 318                WorktreeContext {
 319                    root_name,
 320                    rules_file: None,
 321                },
 322                None,
 323            ));
 324        };
 325
 326        cx.spawn(async move |_| {
 327            let (rules_file, rules_file_error) = match rules_task.await {
 328                Ok(rules_file) => (Some(rules_file), None),
 329                Err(err) => (
 330                    None,
 331                    Some(RulesLoadingError {
 332                        message: format!("{err}").into(),
 333                    }),
 334                ),
 335            };
 336            let worktree_info = WorktreeContext {
 337                root_name,
 338                rules_file,
 339            };
 340            (worktree_info, rules_file_error)
 341        })
 342    }
 343
 344    fn load_worktree_rules_file(
 345        worktree: Entity<Worktree>,
 346        project: Entity<Project>,
 347        cx: &mut App,
 348    ) -> Option<Task<Result<RulesFileContext>>> {
 349        let worktree_ref = worktree.read(cx);
 350        let worktree_id = worktree_ref.id();
 351        let selected_rules_file = RULES_FILE_NAMES
 352            .into_iter()
 353            .filter_map(|name| {
 354                worktree_ref
 355                    .entry_for_path(name)
 356                    .filter(|entry| entry.is_file())
 357                    .map(|entry| entry.path.clone())
 358            })
 359            .next();
 360
 361        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 362        // supported. This doesn't seem to occur often in GitHub repositories.
 363        selected_rules_file.map(|path_in_worktree| {
 364            let project_path = ProjectPath {
 365                worktree_id,
 366                path: path_in_worktree.clone(),
 367            };
 368            let buffer_task =
 369                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 370            let rope_task = cx.spawn(async move |cx| {
 371                buffer_task.await?.read_with(cx, |buffer, cx| {
 372                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 373                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 374                })?
 375            });
 376            // Build a string from the rope on a background thread.
 377            cx.background_spawn(async move {
 378                let (project_entry_id, rope) = rope_task.await?;
 379                anyhow::Ok(RulesFileContext {
 380                    path_in_worktree,
 381                    text: rope.to_string().trim().to_string(),
 382                    project_entry_id: project_entry_id.to_usize(),
 383                })
 384            })
 385        })
 386    }
 387
 388    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
 389        &self.prompt_store
 390    }
 391
 392    pub fn tools(&self) -> Entity<ToolWorkingSet> {
 393        self.tools.clone()
 394    }
 395
 396    /// Returns the number of threads.
 397    pub fn thread_count(&self) -> usize {
 398        self.threads.len()
 399    }
 400
 401    pub fn unordered_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
 402        self.threads.iter()
 403    }
 404
 405    pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
 406        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
 407        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
 408        threads
 409    }
 410
 411    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 412        cx.new(|cx| {
 413            Thread::new(
 414                self.project.clone(),
 415                self.tools.clone(),
 416                self.prompt_builder.clone(),
 417                self.project_context.clone(),
 418                cx,
 419            )
 420        })
 421    }
 422
 423    pub fn create_thread_from_serialized(
 424        &mut self,
 425        serialized: SerializedThread,
 426        cx: &mut Context<Self>,
 427    ) -> Entity<Thread> {
 428        cx.new(|cx| {
 429            Thread::deserialize(
 430                ThreadId::new(),
 431                serialized,
 432                self.project.clone(),
 433                self.tools.clone(),
 434                self.prompt_builder.clone(),
 435                self.project_context.clone(),
 436                None,
 437                cx,
 438            )
 439        })
 440    }
 441
 442    pub fn open_thread(
 443        &self,
 444        id: &ThreadId,
 445        window: &mut Window,
 446        cx: &mut Context<Self>,
 447    ) -> Task<Result<Entity<Thread>>> {
 448        let id = id.clone();
 449        let database_future = ThreadsDatabase::global_future(cx);
 450        let this = cx.weak_entity();
 451        window.spawn(cx, async move |cx| {
 452            let database = database_future.await.map_err(|err| anyhow!(err))?;
 453            let thread = database
 454                .try_find_thread(id.clone())
 455                .await?
 456                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 457
 458            let thread = this.update_in(cx, |this, window, cx| {
 459                cx.new(|cx| {
 460                    Thread::deserialize(
 461                        id.clone(),
 462                        thread,
 463                        this.project.clone(),
 464                        this.tools.clone(),
 465                        this.prompt_builder.clone(),
 466                        this.project_context.clone(),
 467                        Some(window),
 468                        cx,
 469                    )
 470                })
 471            })?;
 472
 473            Ok(thread)
 474        })
 475    }
 476
 477    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
 478        let (metadata, serialized_thread) =
 479            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
 480
 481        let database_future = ThreadsDatabase::global_future(cx);
 482        cx.spawn(async move |this, cx| {
 483            let serialized_thread = serialized_thread.await?;
 484            let database = database_future.await.map_err(|err| anyhow!(err))?;
 485            database.save_thread(metadata, serialized_thread).await?;
 486
 487            this.update(cx, |this, cx| this.reload(cx))?.await
 488        })
 489    }
 490
 491    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
 492        let id = id.clone();
 493        let database_future = ThreadsDatabase::global_future(cx);
 494        cx.spawn(async move |this, cx| {
 495            let database = database_future.await.map_err(|err| anyhow!(err))?;
 496            database.delete_thread(id.clone()).await?;
 497
 498            this.update(cx, |this, cx| {
 499                this.threads.retain(|thread| thread.id != id);
 500                cx.notify();
 501            })
 502        })
 503    }
 504
 505    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 506        let database_future = ThreadsDatabase::global_future(cx);
 507        cx.spawn(async move |this, cx| {
 508            let threads = database_future
 509                .await
 510                .map_err(|err| anyhow!(err))?
 511                .list_threads()
 512                .await?;
 513
 514            this.update(cx, |this, cx| {
 515                this.threads = threads;
 516                cx.notify();
 517            })
 518        })
 519    }
 520
 521    fn load_default_profile(&self, cx: &mut Context<Self>) {
 522        let assistant_settings = AgentSettings::get_global(cx);
 523
 524        self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
 525    }
 526
 527    pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
 528        let assistant_settings = AgentSettings::get_global(cx);
 529
 530        if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
 531            self.load_profile(profile.clone(), cx);
 532        }
 533    }
 534
 535    pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
 536        self.tools.update(cx, |tools, cx| {
 537            tools.disable_all_tools(cx);
 538            tools.enable(
 539                ToolSource::Native,
 540                &profile
 541                    .tools
 542                    .into_iter()
 543                    .filter_map(|(tool, enabled)| enabled.then(|| tool))
 544                    .collect::<Vec<_>>(),
 545                cx,
 546            );
 547        });
 548
 549        if profile.enable_all_context_servers {
 550            for context_server_id in self
 551                .project
 552                .read(cx)
 553                .context_server_store()
 554                .read(cx)
 555                .all_server_ids()
 556            {
 557                self.tools.update(cx, |tools, cx| {
 558                    tools.enable_source(
 559                        ToolSource::ContextServer {
 560                            id: context_server_id.0.into(),
 561                        },
 562                        cx,
 563                    );
 564                });
 565            }
 566            // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
 567            for (context_server_id, preset) in profile.context_servers {
 568                self.tools.update(cx, |tools, cx| {
 569                    tools.disable(
 570                        ToolSource::ContextServer {
 571                            id: context_server_id.into(),
 572                        },
 573                        &preset
 574                            .tools
 575                            .into_iter()
 576                            .filter_map(|(tool, enabled)| (!enabled).then(|| tool))
 577                            .collect::<Vec<_>>(),
 578                        cx,
 579                    )
 580                })
 581            }
 582        } else {
 583            for (context_server_id, preset) in profile.context_servers {
 584                self.tools.update(cx, |tools, cx| {
 585                    tools.enable(
 586                        ToolSource::ContextServer {
 587                            id: context_server_id.into(),
 588                        },
 589                        &preset
 590                            .tools
 591                            .into_iter()
 592                            .filter_map(|(tool, enabled)| enabled.then(|| tool))
 593                            .collect::<Vec<_>>(),
 594                        cx,
 595                    )
 596                })
 597            }
 598        }
 599    }
 600
 601    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
 602        cx.subscribe(
 603            &self.project.read(cx).context_server_store(),
 604            Self::handle_context_server_event,
 605        )
 606        .detach();
 607    }
 608
 609    fn handle_context_server_event(
 610        &mut self,
 611        context_server_store: Entity<ContextServerStore>,
 612        event: &project::context_server_store::Event,
 613        cx: &mut Context<Self>,
 614    ) {
 615        let tool_working_set = self.tools.clone();
 616        match event {
 617            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
 618                match status {
 619                    ContextServerStatus::Running => {
 620                        if let Some(server) =
 621                            context_server_store.read(cx).get_running_server(server_id)
 622                        {
 623                            let context_server_manager = context_server_store.clone();
 624                            cx.spawn({
 625                                let server = server.clone();
 626                                let server_id = server_id.clone();
 627                                async move |this, cx| {
 628                                    let Some(protocol) = server.client() else {
 629                                        return;
 630                                    };
 631
 632                                    if protocol.capable(context_server::protocol::ServerCapability::Tools) {
 633                                        if let Some(tools) = protocol.list_tools().await.log_err() {
 634                                            let tool_ids = tool_working_set
 635                                                .update(cx, |tool_working_set, _| {
 636                                                    tools
 637                                                        .tools
 638                                                        .into_iter()
 639                                                        .map(|tool| {
 640                                                            log::info!(
 641                                                                "registering context server tool: {:?}",
 642                                                                tool.name
 643                                                            );
 644                                                            tool_working_set.insert(Arc::new(
 645                                                                ContextServerTool::new(
 646                                                                    context_server_manager.clone(),
 647                                                                    server.id(),
 648                                                                    tool,
 649                                                                ),
 650                                                            ))
 651                                                        })
 652                                                        .collect::<Vec<_>>()
 653                                                })
 654                                                .log_err();
 655
 656                                            if let Some(tool_ids) = tool_ids {
 657                                                this.update(cx, |this, cx| {
 658                                                    this.context_server_tool_ids
 659                                                        .insert(server_id, tool_ids);
 660                                                    this.load_default_profile(cx);
 661                                                })
 662                                                .log_err();
 663                                            }
 664                                        }
 665                                    }
 666                                }
 667                            })
 668                            .detach();
 669                        }
 670                    }
 671                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
 672                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
 673                            tool_working_set.update(cx, |tool_working_set, _| {
 674                                tool_working_set.remove(&tool_ids);
 675                            });
 676                            self.load_default_profile(cx);
 677                        }
 678                    }
 679                    _ => {}
 680                }
 681            }
 682        }
 683    }
 684}
 685
 686#[derive(Debug, Clone, Serialize, Deserialize)]
 687pub struct SerializedThreadMetadata {
 688    pub id: ThreadId,
 689    pub summary: SharedString,
 690    pub updated_at: DateTime<Utc>,
 691}
 692
 693#[derive(Serialize, Deserialize, Debug)]
 694pub struct SerializedThread {
 695    pub version: String,
 696    pub summary: SharedString,
 697    pub updated_at: DateTime<Utc>,
 698    pub messages: Vec<SerializedMessage>,
 699    #[serde(default)]
 700    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
 701    #[serde(default)]
 702    pub cumulative_token_usage: TokenUsage,
 703    #[serde(default)]
 704    pub request_token_usage: Vec<TokenUsage>,
 705    #[serde(default)]
 706    pub detailed_summary_state: DetailedSummaryState,
 707    #[serde(default)]
 708    pub exceeded_window_error: Option<ExceededWindowError>,
 709    #[serde(default)]
 710    pub model: Option<SerializedLanguageModel>,
 711    #[serde(default)]
 712    pub completion_mode: Option<CompletionMode>,
 713    #[serde(default)]
 714    pub tool_use_limit_reached: bool,
 715}
 716
 717#[derive(Serialize, Deserialize, Debug)]
 718pub struct SerializedLanguageModel {
 719    pub provider: String,
 720    pub model: String,
 721}
 722
 723impl SerializedThread {
 724    pub const VERSION: &'static str = "0.2.0";
 725
 726    pub fn from_json(json: &[u8]) -> Result<Self> {
 727        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 728        match saved_thread_json.get("version") {
 729            Some(serde_json::Value::String(version)) => match version.as_str() {
 730                SerializedThreadV0_1_0::VERSION => {
 731                    let saved_thread =
 732                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
 733                    Ok(saved_thread.upgrade())
 734                }
 735                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
 736                    saved_thread_json,
 737                )?),
 738                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 739            },
 740            None => {
 741                let saved_thread =
 742                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
 743                Ok(saved_thread.upgrade())
 744            }
 745            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 746        }
 747    }
 748}
 749
 750#[derive(Serialize, Deserialize, Debug)]
 751pub struct SerializedThreadV0_1_0(
 752    // The structure did not change, so we are reusing the latest SerializedThread.
 753    // When making the next version, make sure this points to SerializedThreadV0_2_0
 754    SerializedThread,
 755);
 756
 757impl SerializedThreadV0_1_0 {
 758    pub const VERSION: &'static str = "0.1.0";
 759
 760    pub fn upgrade(self) -> SerializedThread {
 761        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
 762
 763        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
 764
 765        for message in self.0.messages {
 766            if message.role == Role::User && !message.tool_results.is_empty() {
 767                if let Some(last_message) = messages.last_mut() {
 768                    debug_assert!(last_message.role == Role::Assistant);
 769
 770                    last_message.tool_results = message.tool_results;
 771                    continue;
 772                }
 773            }
 774
 775            messages.push(message);
 776        }
 777
 778        SerializedThread { messages, ..self.0 }
 779    }
 780}
 781
/// One message of a [`SerializedThread`], in the current serialization format.
///
/// Every field except `id` and `role` is `#[serde(default)]`, so it may be absent in
/// stored JSON and is then filled with its empty/default value.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (text, thinking, redacted thinking).
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    /// Tool calls recorded on this message.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Tool results recorded on this message.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    // NOTE(review): presumably the textual context attached to the message — confirm.
    #[serde(default)]
    pub context: String,
    /// Crease (collapsed range) metadata; see [`SerializedCrease`].
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    // NOTE(review): presumably hides the message from the UI — confirm.
    #[serde(default)]
    pub is_hidden: bool,
}
 799
/// A piece of message content, internally tagged by a `"type"` field in JSON.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Plain text; tagged `"text"`.
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    /// Model thinking output; tagged `"thinking"`. `signature` is omitted when `None`.
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // NOTE(review): unlike the other variants this has no `rename`, so its tag is the
    // literal `"RedactedThinking"`. Existing saved threads depend on that spelling, so do
    // not rename it without adding an `alias`. serde_json encodes `data` (Vec<u8>) as an
    // array of numbers.
    RedactedThinking {
        data: Vec<u8>,
    },
}
 817
/// A tool call stored on a [`SerializedMessage`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    /// Name of the invoked tool.
    pub name: SharedString,
    /// Arbitrary JSON arguments passed to the tool.
    pub input: serde_json::Value,
}
 824
/// The result of a tool call stored on a [`SerializedMessage`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    // Presumably equals the `id` of the matching `SerializedToolUse` (same type) — confirm.
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: LanguageModelToolResultContent,
    /// Optional raw JSON output.
    pub output: Option<serde_json::Value>,
}
 832
/// A thread stored before threads carried a `version` field.
///
/// `SerializedThread::from_json` falls back to this type when no version is present,
/// and converts it via [`LegacySerializedThread::upgrade`].
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
 841
 842impl LegacySerializedThread {
 843    pub fn upgrade(self) -> SerializedThread {
 844        SerializedThread {
 845            version: SerializedThread::VERSION.to_string(),
 846            summary: self.summary,
 847            updated_at: self.updated_at,
 848            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
 849            initial_project_snapshot: self.initial_project_snapshot,
 850            cumulative_token_usage: TokenUsage::default(),
 851            request_token_usage: Vec::new(),
 852            detailed_summary_state: DetailedSummaryState::default(),
 853            exceeded_window_error: None,
 854            model: None,
 855            completion_mode: None,
 856            tool_use_limit_reached: false,
 857        }
 858    }
 859}
 860
/// A message from an unversioned thread, holding its content as a single `text` string
/// rather than a list of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
 871
 872impl LegacySerializedMessage {
 873    fn upgrade(self) -> SerializedMessage {
 874        SerializedMessage {
 875            id: self.id,
 876            role: self.role,
 877            segments: vec![SerializedMessageSegment::Text { text: self.text }],
 878            tool_uses: self.tool_uses,
 879            tool_results: self.tool_results,
 880            context: String::new(),
 881            creases: Vec::new(),
 882            is_hidden: false,
 883        }
 884    }
 885}
 886
/// A crease (collapsed, labelled range) within a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedCrease {
    // NOTE(review): presumably byte offsets into the message text — confirm units.
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}
 894
/// gpui global holding the shared future that opens the [`ThreadsDatabase`].
///
/// The error is wrapped in `Arc` because a `Shared` future's output must be `Clone`,
/// and `anyhow::Error` is not.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
 900
/// SQLite-backed store for serialized threads.
///
/// All access goes through a single mutex-guarded connection; queries run on the
/// background executor.
pub(crate) struct ThreadsDatabase {
    executor: BackgroundExecutor,
    connection: Arc<Mutex<Connection>>,
}
 905
 906impl ThreadsDatabase {
 907    fn connection(&self) -> Arc<Mutex<Connection>> {
 908        self.connection.clone()
 909    }
 910
 911    const COMPRESSION_LEVEL: i32 = 3;
 912}
 913
 914impl Bind for ThreadId {
 915    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
 916        self.to_string().bind(statement, start_index)
 917    }
 918}
 919
 920impl Column for ThreadId {
 921    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
 922        let (id_str, next_index) = String::column(statement, start_index)?;
 923        Ok((ThreadId::from(id_str.as_str()), next_index))
 924    }
 925}
 926
 927impl ThreadsDatabase {
 928    fn global_future(
 929        cx: &mut App,
 930    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
 931        GlobalThreadsDatabase::global(cx).0.clone()
 932    }
 933
 934    fn init(cx: &mut App) {
 935        let executor = cx.background_executor().clone();
 936        let database_future = executor
 937            .spawn({
 938                let executor = executor.clone();
 939                let threads_dir = paths::data_dir().join("threads");
 940                async move { ThreadsDatabase::new(threads_dir, executor) }
 941            })
 942            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
 943            .boxed()
 944            .shared();
 945
 946        cx.set_global(GlobalThreadsDatabase(database_future));
 947    }
 948
 949    pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
 950        std::fs::create_dir_all(&threads_dir)?;
 951
 952        let sqlite_path = threads_dir.join("threads.db");
 953        let mdb_path = threads_dir.join("threads-db.1.mdb");
 954
 955        let needs_migration_from_heed = mdb_path.exists();
 956
 957        let connection = Connection::open_file(&sqlite_path.to_string_lossy());
 958
 959        connection.exec(indoc! {"
 960                CREATE TABLE IF NOT EXISTS threads (
 961                    id TEXT PRIMARY KEY,
 962                    summary TEXT NOT NULL,
 963                    updated_at TEXT NOT NULL,
 964                    data_type TEXT NOT NULL,
 965                    data BLOB NOT NULL
 966                )
 967            "})?()
 968        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
 969
 970        let db = Self {
 971            executor: executor.clone(),
 972            connection: Arc::new(Mutex::new(connection)),
 973        };
 974
 975        if needs_migration_from_heed {
 976            let db_connection = db.connection();
 977            let executor_clone = executor.clone();
 978            executor
 979                .spawn(async move {
 980                    log::info!("Starting threads.db migration");
 981                    Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
 982                    std::fs::remove_dir_all(mdb_path)?;
 983                    log::info!("threads.db migrated to sqlite");
 984                    Ok::<(), anyhow::Error>(())
 985                })
 986                .detach();
 987        }
 988
 989        Ok(db)
 990    }
 991
 992    // Remove this migration after 2025-09-01
 993    fn migrate_from_heed(
 994        mdb_path: &Path,
 995        connection: Arc<Mutex<Connection>>,
 996        _executor: BackgroundExecutor,
 997    ) -> Result<()> {
 998        use heed::types::SerdeBincode;
 999        struct SerializedThreadHeed(SerializedThread);
1000
1001        impl heed::BytesEncode<'_> for SerializedThreadHeed {
1002            type EItem = SerializedThreadHeed;
1003
1004            fn bytes_encode(
1005                item: &Self::EItem,
1006            ) -> Result<std::borrow::Cow<[u8]>, heed::BoxedError> {
1007                serde_json::to_vec(&item.0)
1008                    .map(std::borrow::Cow::Owned)
1009                    .map_err(Into::into)
1010            }
1011        }
1012
1013        impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
1014            type DItem = SerializedThreadHeed;
1015
1016            fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
1017                SerializedThread::from_json(bytes)
1018                    .map(SerializedThreadHeed)
1019                    .map_err(Into::into)
1020            }
1021        }
1022
1023        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
1024
1025        let env = unsafe {
1026            heed::EnvOpenOptions::new()
1027                .map_size(ONE_GB_IN_BYTES)
1028                .max_dbs(1)
1029                .open(mdb_path)?
1030        };
1031
1032        let txn = env.write_txn()?;
1033        let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
1034            .open_database(&txn, Some("threads"))?
1035            .ok_or_else(|| anyhow!("threads database not found"))?;
1036
1037        for result in threads.iter(&txn)? {
1038            let (thread_id, thread_heed) = result?;
1039            Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
1040        }
1041
1042        Ok(())
1043    }
1044
1045    fn save_thread_sync(
1046        connection: &Arc<Mutex<Connection>>,
1047        id: ThreadId,
1048        thread: SerializedThread,
1049    ) -> Result<()> {
1050        let json_data = serde_json::to_string(&thread)?;
1051        let summary = thread.summary.to_string();
1052        let updated_at = thread.updated_at.to_rfc3339();
1053
1054        let connection = connection.lock().unwrap();
1055
1056        let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
1057        let data_type = DataType::Zstd;
1058        let data = compressed;
1059
1060        let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
1061            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
1062        "})?;
1063
1064        insert((id, summary, updated_at, data_type, data))?;
1065
1066        Ok(())
1067    }
1068
1069    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
1070        let connection = self.connection.clone();
1071
1072        self.executor.spawn(async move {
1073            let connection = connection.lock().unwrap();
1074            let mut select =
1075                connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
1076                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
1077            "})?;
1078
1079            let rows = select(())?;
1080            let mut threads = Vec::new();
1081
1082            for (id, summary, updated_at) in rows {
1083                threads.push(SerializedThreadMetadata {
1084                    id,
1085                    summary: summary.into(),
1086                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
1087                });
1088            }
1089
1090            Ok(threads)
1091        })
1092    }
1093
1094    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
1095        let connection = self.connection.clone();
1096
1097        self.executor.spawn(async move {
1098            let connection = connection.lock().unwrap();
1099            let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
1100                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
1101            "})?;
1102
1103            let rows = select(id)?;
1104            if let Some((data_type, data)) = rows.into_iter().next() {
1105                let json_data = match data_type {
1106                    DataType::Zstd => {
1107                        let decompressed = zstd::decode_all(&data[..])?;
1108                        String::from_utf8(decompressed)?
1109                    }
1110                    DataType::Json => String::from_utf8(data)?,
1111                };
1112
1113                let thread = SerializedThread::from_json(json_data.as_bytes())?;
1114                Ok(Some(thread))
1115            } else {
1116                Ok(None)
1117            }
1118        })
1119    }
1120
1121    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
1122        let connection = self.connection.clone();
1123
1124        self.executor
1125            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
1126    }
1127
1128    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
1129        let connection = self.connection.clone();
1130
1131        self.executor.spawn(async move {
1132            let connection = connection.lock().unwrap();
1133
1134            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
1135                DELETE FROM threads WHERE id = ?
1136            "})?;
1137
1138            delete(id)?;
1139
1140            Ok(())
1141        })
1142    }
1143}