thread_store.rs

   1use crate::{
   2    context_server_tool::ContextServerTool,
   3    thread::{
   4        DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
   5    },
   6};
   7use agent_settings::{AgentProfileId, CompletionMode};
   8use anyhow::{Context as _, Result, anyhow};
   9use assistant_tool::{ToolId, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::HashMap;
  12use context_server::ContextServerId;
  13use futures::{
  14    FutureExt as _, StreamExt as _,
  15    channel::{mpsc, oneshot},
  16    future::{self, BoxFuture, Shared},
  17};
  18use gpui::{
  19    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
  20    Subscription, Task, Window, prelude::*,
  21};
  22use indoc::indoc;
  23use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
  24use project::context_server_store::{ContextServerStatus, ContextServerStore};
  25use project::{Project, ProjectItem, ProjectPath, Worktree};
  26use prompt_store::{
  27    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
  28    UserRulesContext, WorktreeContext,
  29};
  30use serde::{Deserialize, Serialize};
  31use sqlez::{
  32    bindable::{Bind, Column},
  33    connection::Connection,
  34    statement::Statement,
  35};
  36use std::{
  37    cell::{Ref, RefCell},
  38    path::{Path, PathBuf},
  39    rc::Rc,
  40    sync::{Arc, Mutex},
  41};
  42use util::ResultExt as _;
  43
  44#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  45pub enum DataType {
  46    #[serde(rename = "json")]
  47    Json,
  48    #[serde(rename = "zstd")]
  49    Zstd,
  50}
  51
  52impl Bind for DataType {
  53    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
  54        let value = match self {
  55            DataType::Json => "json",
  56            DataType::Zstd => "zstd",
  57        };
  58        value.bind(statement, start_index)
  59    }
  60}
  61
  62impl Column for DataType {
  63    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
  64        let (value, next_index) = String::column(statement, start_index)?;
  65        let data_type = match value.as_str() {
  66            "json" => DataType::Json,
  67            "zstd" => DataType::Zstd,
  68            _ => anyhow::bail!("Unknown data type: {}", value),
  69        };
  70        Ok((data_type, next_index))
  71    }
  72}
  73
  74const RULES_FILE_NAMES: [&'static str; 9] = [
  75    ".rules",
  76    ".cursorrules",
  77    ".windsurfrules",
  78    ".clinerules",
  79    ".github/copilot-instructions.md",
  80    "CLAUDE.md",
  81    "AGENT.md",
  82    "AGENTS.md",
  83    "GEMINI.md",
  84];
  85
  86pub fn init(cx: &mut App) {
  87    ThreadsDatabase::init(cx);
  88}
  89
  90/// A system prompt shared by all threads created by this ThreadStore
  91#[derive(Clone, Default)]
  92pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
  93
  94impl SharedProjectContext {
  95    pub fn borrow(&self) -> Ref<'_, Option<ProjectContext>> {
  96        self.0.borrow()
  97    }
  98}
  99
/// Alias exposing `assistant_context::ContextStore` under the name `TextThreadStore`.
pub type TextThreadStore = assistant_context::ContextStore;
 101
/// Tracks saved-thread metadata, owns the shared project context used when building
/// thread system prompts, and keeps context-server tools registered in the tool set.
pub struct ThreadStore {
    project: Entity<Project>,
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    prompt_store: Option<Entity<PromptStore>>,
    /// Tool IDs registered for each context server, so they can be removed from
    /// `tools` when that server stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Saved-thread metadata in the order returned by the database's `list_threads`.
    threads: Vec<SerializedThreadMetadata>,
    project_context: SharedProjectContext,
    /// Wakes `_reload_system_prompt_task`. Sends use `try_send` and failures are
    /// ignored, so requests arriving while the bounded channel is full are dropped.
    reload_system_prompt_tx: mpsc::Sender<()>,
    // Held only to keep the reload loop alive for the lifetime of the store.
    _reload_system_prompt_task: Task<()>,
    // Held only to keep the project / prompt-store subscriptions alive.
    _subscriptions: Vec<Subscription>,
}
 114
/// Event emitted by `ThreadStore` when a worktree rules file or a default user rule
/// fails to load while rebuilding the project context.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}
 120
 121impl ThreadStore {
 122    pub fn load(
 123        project: Entity<Project>,
 124        tools: Entity<ToolWorkingSet>,
 125        prompt_store: Option<Entity<PromptStore>>,
 126        prompt_builder: Arc<PromptBuilder>,
 127        cx: &mut App,
 128    ) -> Task<Result<Entity<Self>>> {
 129        cx.spawn(async move |cx| {
 130            let (thread_store, ready_rx) = cx.update(|cx| {
 131                let mut option_ready_rx = None;
 132                let thread_store = cx.new(|cx| {
 133                    let (thread_store, ready_rx) =
 134                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 135                    option_ready_rx = Some(ready_rx);
 136                    thread_store
 137                });
 138                (thread_store, option_ready_rx.take().unwrap())
 139            })?;
 140            ready_rx.await?;
 141            Ok(thread_store)
 142        })
 143    }
 144
 145    fn new(
 146        project: Entity<Project>,
 147        tools: Entity<ToolWorkingSet>,
 148        prompt_builder: Arc<PromptBuilder>,
 149        prompt_store: Option<Entity<PromptStore>>,
 150        cx: &mut Context<Self>,
 151    ) -> (Self, oneshot::Receiver<()>) {
 152        let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
 153
 154        if let Some(prompt_store) = prompt_store.as_ref() {
 155            subscriptions.push(cx.subscribe(
 156                prompt_store,
 157                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
 158                    this.enqueue_system_prompt_reload();
 159                },
 160            ))
 161        }
 162
 163        // This channel and task prevent concurrent and redundant loading of the system prompt.
 164        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
 165        let (ready_tx, ready_rx) = oneshot::channel();
 166        let mut ready_tx = Some(ready_tx);
 167        let reload_system_prompt_task = cx.spawn({
 168            let prompt_store = prompt_store.clone();
 169            async move |thread_store, cx| {
 170                loop {
 171                    let Some(reload_task) = thread_store
 172                        .update(cx, |thread_store, cx| {
 173                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
 174                        })
 175                        .ok()
 176                    else {
 177                        return;
 178                    };
 179                    reload_task.await;
 180                    if let Some(ready_tx) = ready_tx.take() {
 181                        ready_tx.send(()).ok();
 182                    }
 183                    reload_system_prompt_rx.next().await;
 184                }
 185            }
 186        });
 187
 188        let this = Self {
 189            project,
 190            tools,
 191            prompt_builder,
 192            prompt_store,
 193            context_server_tool_ids: HashMap::default(),
 194            threads: Vec::new(),
 195            project_context: SharedProjectContext::default(),
 196            reload_system_prompt_tx,
 197            _reload_system_prompt_task: reload_system_prompt_task,
 198            _subscriptions: subscriptions,
 199        };
 200        this.register_context_server_handlers(cx);
 201        this.reload(cx).detach_and_log_err(cx);
 202        (this, ready_rx)
 203    }
 204
 205    fn handle_project_event(
 206        &mut self,
 207        _project: Entity<Project>,
 208        event: &project::Event,
 209        _cx: &mut Context<Self>,
 210    ) {
 211        match event {
 212            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 213                self.enqueue_system_prompt_reload();
 214            }
 215            project::Event::WorktreeUpdatedEntries(_, items) => {
 216                if items.iter().any(|(path, _, _)| {
 217                    RULES_FILE_NAMES
 218                        .iter()
 219                        .any(|name| path.as_ref() == Path::new(name))
 220                }) {
 221                    self.enqueue_system_prompt_reload();
 222                }
 223            }
 224            _ => {}
 225        }
 226    }
 227
 228    fn enqueue_system_prompt_reload(&mut self) {
 229        self.reload_system_prompt_tx.try_send(()).ok();
 230    }
 231
 232    // Note that this should only be called from `reload_system_prompt_task`.
 233    fn reload_system_prompt(
 234        &self,
 235        prompt_store: Option<Entity<PromptStore>>,
 236        cx: &mut Context<Self>,
 237    ) -> Task<()> {
 238        let worktrees = self
 239            .project
 240            .read(cx)
 241            .visible_worktrees(cx)
 242            .collect::<Vec<_>>();
 243        let worktree_tasks = worktrees
 244            .into_iter()
 245            .map(|worktree| {
 246                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
 247            })
 248            .collect::<Vec<_>>();
 249        let default_user_rules_task = match prompt_store {
 250            None => Task::ready(vec![]),
 251            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
 252                let prompts = prompt_store.default_prompt_metadata();
 253                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 254                    let contents = prompt_store.load(prompt_metadata.id, cx);
 255                    async move { (contents.await, prompt_metadata) }
 256                });
 257                cx.background_spawn(future::join_all(load_tasks))
 258            }),
 259        };
 260
 261        cx.spawn(async move |this, cx| {
 262            let (worktrees, default_user_rules) =
 263                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 264
 265            let worktrees = worktrees
 266                .into_iter()
 267                .map(|(worktree, rules_error)| {
 268                    if let Some(rules_error) = rules_error {
 269                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 270                    }
 271                    worktree
 272                })
 273                .collect::<Vec<_>>();
 274
 275            let default_user_rules = default_user_rules
 276                .into_iter()
 277                .flat_map(|(contents, prompt_metadata)| match contents {
 278                    Ok(contents) => Some(UserRulesContext {
 279                        uuid: match prompt_metadata.id {
 280                            PromptId::User { uuid } => uuid,
 281                            PromptId::EditWorkflow => return None,
 282                        },
 283                        title: prompt_metadata.title.map(|title| title.to_string()),
 284                        contents,
 285                    }),
 286                    Err(err) => {
 287                        this.update(cx, |_, cx| {
 288                            cx.emit(RulesLoadingError {
 289                                message: format!("{err:?}").into(),
 290                            });
 291                        })
 292                        .ok();
 293                        None
 294                    }
 295                })
 296                .collect::<Vec<_>>();
 297
 298            this.update(cx, |this, _cx| {
 299                *this.project_context.0.borrow_mut() =
 300                    Some(ProjectContext::new(worktrees, default_user_rules));
 301            })
 302            .ok();
 303        })
 304    }
 305
 306    fn load_worktree_info_for_system_prompt(
 307        worktree: Entity<Worktree>,
 308        project: Entity<Project>,
 309        cx: &mut App,
 310    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 311        let tree = worktree.read(cx);
 312        let root_name = tree.root_name().into();
 313        let abs_path = tree.abs_path();
 314
 315        let mut context = WorktreeContext {
 316            root_name,
 317            abs_path,
 318            rules_file: None,
 319        };
 320
 321        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 322        let Some(rules_task) = rules_task else {
 323            return Task::ready((context, None));
 324        };
 325
 326        cx.spawn(async move |_| {
 327            let (rules_file, rules_file_error) = match rules_task.await {
 328                Ok(rules_file) => (Some(rules_file), None),
 329                Err(err) => (
 330                    None,
 331                    Some(RulesLoadingError {
 332                        message: format!("{err}").into(),
 333                    }),
 334                ),
 335            };
 336            context.rules_file = rules_file;
 337            (context, rules_file_error)
 338        })
 339    }
 340
 341    fn load_worktree_rules_file(
 342        worktree: Entity<Worktree>,
 343        project: Entity<Project>,
 344        cx: &mut App,
 345    ) -> Option<Task<Result<RulesFileContext>>> {
 346        let worktree = worktree.read(cx);
 347        let worktree_id = worktree.id();
 348        let selected_rules_file = RULES_FILE_NAMES
 349            .into_iter()
 350            .filter_map(|name| {
 351                worktree
 352                    .entry_for_path(name)
 353                    .filter(|entry| entry.is_file())
 354                    .map(|entry| entry.path.clone())
 355            })
 356            .next();
 357
 358        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 359        // supported. This doesn't seem to occur often in GitHub repositories.
 360        selected_rules_file.map(|path_in_worktree| {
 361            let project_path = ProjectPath {
 362                worktree_id,
 363                path: path_in_worktree.clone(),
 364            };
 365            let buffer_task =
 366                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 367            let rope_task = cx.spawn(async move |cx| {
 368                buffer_task.await?.read_with(cx, |buffer, cx| {
 369                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 370                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 371                })?
 372            });
 373            // Build a string from the rope on a background thread.
 374            cx.background_spawn(async move {
 375                let (project_entry_id, rope) = rope_task.await?;
 376                anyhow::Ok(RulesFileContext {
 377                    path_in_worktree,
 378                    text: rope.to_string().trim().to_string(),
 379                    project_entry_id: project_entry_id.to_usize(),
 380                })
 381            })
 382        })
 383    }
 384
 385    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
 386        &self.prompt_store
 387    }
 388
 389    pub fn tools(&self) -> Entity<ToolWorkingSet> {
 390        self.tools.clone()
 391    }
 392
 393    /// Returns the number of threads.
 394    pub fn thread_count(&self) -> usize {
 395        self.threads.len()
 396    }
 397
 398    pub fn reverse_chronological_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
 399        // ordering is from "ORDER BY" in `list_threads`
 400        self.threads.iter()
 401    }
 402
 403    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 404        cx.new(|cx| {
 405            Thread::new(
 406                self.project.clone(),
 407                self.tools.clone(),
 408                self.prompt_builder.clone(),
 409                self.project_context.clone(),
 410                cx,
 411            )
 412        })
 413    }
 414
 415    pub fn create_thread_from_serialized(
 416        &mut self,
 417        serialized: SerializedThread,
 418        cx: &mut Context<Self>,
 419    ) -> Entity<Thread> {
 420        cx.new(|cx| {
 421            Thread::deserialize(
 422                ThreadId::new(),
 423                serialized,
 424                self.project.clone(),
 425                self.tools.clone(),
 426                self.prompt_builder.clone(),
 427                self.project_context.clone(),
 428                None,
 429                cx,
 430            )
 431        })
 432    }
 433
 434    pub fn open_thread(
 435        &self,
 436        id: &ThreadId,
 437        window: &mut Window,
 438        cx: &mut Context<Self>,
 439    ) -> Task<Result<Entity<Thread>>> {
 440        let id = id.clone();
 441        let database_future = ThreadsDatabase::global_future(cx);
 442        let this = cx.weak_entity();
 443        window.spawn(cx, async move |cx| {
 444            let database = database_future.await.map_err(|err| anyhow!(err))?;
 445            let thread = database
 446                .try_find_thread(id.clone())
 447                .await?
 448                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 449
 450            let thread = this.update_in(cx, |this, window, cx| {
 451                cx.new(|cx| {
 452                    Thread::deserialize(
 453                        id.clone(),
 454                        thread,
 455                        this.project.clone(),
 456                        this.tools.clone(),
 457                        this.prompt_builder.clone(),
 458                        this.project_context.clone(),
 459                        Some(window),
 460                        cx,
 461                    )
 462                })
 463            })?;
 464
 465            Ok(thread)
 466        })
 467    }
 468
 469    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
 470        let (metadata, serialized_thread) =
 471            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
 472
 473        let database_future = ThreadsDatabase::global_future(cx);
 474        cx.spawn(async move |this, cx| {
 475            let serialized_thread = serialized_thread.await?;
 476            let database = database_future.await.map_err(|err| anyhow!(err))?;
 477            database.save_thread(metadata, serialized_thread).await?;
 478
 479            this.update(cx, |this, cx| this.reload(cx))?.await
 480        })
 481    }
 482
 483    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
 484        let id = id.clone();
 485        let database_future = ThreadsDatabase::global_future(cx);
 486        cx.spawn(async move |this, cx| {
 487            let database = database_future.await.map_err(|err| anyhow!(err))?;
 488            database.delete_thread(id.clone()).await?;
 489
 490            this.update(cx, |this, cx| {
 491                this.threads.retain(|thread| thread.id != id);
 492                cx.notify();
 493            })
 494        })
 495    }
 496
 497    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 498        let database_future = ThreadsDatabase::global_future(cx);
 499        cx.spawn(async move |this, cx| {
 500            let threads = database_future
 501                .await
 502                .map_err(|err| anyhow!(err))?
 503                .list_threads()
 504                .await?;
 505
 506            this.update(cx, |this, cx| {
 507                this.threads = threads;
 508                cx.notify();
 509            })
 510        })
 511    }
 512
 513    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
 514        let context_server_store = self.project.read(cx).context_server_store();
 515        cx.subscribe(&context_server_store, Self::handle_context_server_event)
 516            .detach();
 517
 518        // Check for any servers that were already running before the handler was registered
 519        for server in context_server_store.read(cx).running_servers() {
 520            self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
 521        }
 522    }
 523
 524    fn handle_context_server_event(
 525        &mut self,
 526        context_server_store: Entity<ContextServerStore>,
 527        event: &project::context_server_store::Event,
 528        cx: &mut Context<Self>,
 529    ) {
 530        let tool_working_set = self.tools.clone();
 531        match event {
 532            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
 533                match status {
 534                    ContextServerStatus::Starting => {}
 535                    ContextServerStatus::Running => {
 536                        self.load_context_server_tools(server_id.clone(), context_server_store, cx);
 537                    }
 538                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
 539                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
 540                            tool_working_set.update(cx, |tool_working_set, _| {
 541                                tool_working_set.remove(&tool_ids);
 542                            });
 543                        }
 544                    }
 545                }
 546            }
 547        }
 548    }
 549
 550    fn load_context_server_tools(
 551        &self,
 552        server_id: ContextServerId,
 553        context_server_store: Entity<ContextServerStore>,
 554        cx: &mut Context<Self>,
 555    ) {
 556        let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
 557            return;
 558        };
 559        let tool_working_set = self.tools.clone();
 560        cx.spawn(async move |this, cx| {
 561            let Some(protocol) = server.client() else {
 562                return;
 563            };
 564
 565            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
 566                if let Some(response) = protocol
 567                    .request::<context_server::types::requests::ListTools>(())
 568                    .await
 569                    .log_err()
 570                {
 571                    let tool_ids = tool_working_set
 572                        .update(cx, |tool_working_set, _| {
 573                            response
 574                                .tools
 575                                .into_iter()
 576                                .map(|tool| {
 577                                    log::info!("registering context server tool: {:?}", tool.name);
 578                                    tool_working_set.insert(Arc::new(ContextServerTool::new(
 579                                        context_server_store.clone(),
 580                                        server.id(),
 581                                        tool,
 582                                    )))
 583                                })
 584                                .collect::<Vec<_>>()
 585                        })
 586                        .log_err();
 587
 588                    if let Some(tool_ids) = tool_ids {
 589                        this.update(cx, |this, _| {
 590                            this.context_server_tool_ids.insert(server_id, tool_ids);
 591                        })
 592                        .log_err();
 593                    }
 594                }
 595            }
 596        })
 597        .detach();
 598    }
 599}
 600
/// Lightweight per-thread summary kept in memory by `ThreadStore` for thread lists.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
}
 607
/// Current (version `SerializedThread::VERSION`) on-disk representation of a thread.
///
/// Fields marked `#[serde(default)]` may be absent in data written before they were
/// introduced; they then take their `Default` value.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedThread {
    /// Format version string; see `SerializedThread::from_json` for accepted values.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
 633
/// Identifies the language model a thread used, as a provider name plus model name.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
 639
 640impl SerializedThread {
 641    pub const VERSION: &'static str = "0.2.0";
 642
 643    pub fn from_json(json: &[u8]) -> Result<Self> {
 644        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 645        match saved_thread_json.get("version") {
 646            Some(serde_json::Value::String(version)) => match version.as_str() {
 647                SerializedThreadV0_1_0::VERSION => {
 648                    let saved_thread =
 649                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
 650                    Ok(saved_thread.upgrade())
 651                }
 652                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
 653                    saved_thread_json,
 654                )?),
 655                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 656            },
 657            None => {
 658                let saved_thread =
 659                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
 660                Ok(saved_thread.upgrade())
 661            }
 662            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 663        }
 664    }
 665}
 666
/// Thread format version 0.1.0.
///
/// Judging by `upgrade`, tool results in this version were stored on a separate user
/// message following the assistant message that made the tool calls.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);
 673
 674impl SerializedThreadV0_1_0 {
 675    pub const VERSION: &'static str = "0.1.0";
 676
 677    pub fn upgrade(self) -> SerializedThread {
 678        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
 679
 680        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
 681
 682        for message in self.0.messages {
 683            if message.role == Role::User && !message.tool_results.is_empty() {
 684                if let Some(last_message) = messages.last_mut() {
 685                    debug_assert!(last_message.role == Role::Assistant);
 686
 687                    last_message.tool_results = message.tool_results;
 688                    continue;
 689                }
 690            }
 691
 692            messages.push(message);
 693        }
 694
 695        SerializedThread {
 696            messages,
 697            version: SerializedThread::VERSION.to_string(),
 698            ..self.0
 699        }
 700    }
 701}
 702
/// A single message of a serialized thread.
///
/// Fields marked `#[serde(default)]` may be missing in stored data.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    #[serde(default)]
    pub is_hidden: bool,
}
 720
/// One piece of message content, internally tagged by a `"type"` field.
///
/// Tags are `"text"` and `"thinking"`. `RedactedThinking` has no rename, so its tag is
/// the variant name `"RedactedThinking"`. Renaming it now would break reading
/// previously saved threads.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        // Omitted from the output entirely when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
}
 738
/// A tool invocation recorded on a message: the tool name and its JSON input.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
 745
/// The result of a tool invocation, linked to it by `tool_use_id`.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: LanguageModelToolResultContent,
    pub output: Option<serde_json::Value>,
}
 753
/// Unversioned thread format, recognized by the absence of a `version` key.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
 762
 763impl LegacySerializedThread {
 764    pub fn upgrade(self) -> SerializedThread {
 765        SerializedThread {
 766            version: SerializedThread::VERSION.to_string(),
 767            summary: self.summary,
 768            updated_at: self.updated_at,
 769            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
 770            initial_project_snapshot: self.initial_project_snapshot,
 771            cumulative_token_usage: TokenUsage::default(),
 772            request_token_usage: Vec::new(),
 773            detailed_summary_state: DetailedSummaryState::default(),
 774            exceeded_window_error: None,
 775            model: None,
 776            completion_mode: None,
 777            tool_use_limit_reached: false,
 778            profile: None,
 779        }
 780    }
 781}
 782
/// Message in the unversioned format, where content is a single plain-text string.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
 793
 794impl LegacySerializedMessage {
 795    fn upgrade(self) -> SerializedMessage {
 796        SerializedMessage {
 797            id: self.id,
 798            role: self.role,
 799            segments: vec![SerializedMessageSegment::Text { text: self.text }],
 800            tool_uses: self.tool_uses,
 801            tool_results: self.tool_results,
 802            context: String::new(),
 803            creases: Vec::new(),
 804            is_hidden: false,
 805        }
 806    }
 807}
 808
/// A labelled, icon-decorated range stored on a message.
// NOTE(review): `start`/`end` presumably index into the message text — confirm
// whether they are byte or char offsets against the code that creates creases.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedCrease {
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}
 816
/// App-global handle to the threads database. It is opened once in the background,
/// and the shared future lets every caller await the same result, whether success
/// or error.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
 822
/// SQLite-backed storage for serialized threads (`threads.db` in the threads data dir).
pub(crate) struct ThreadsDatabase {
    executor: BackgroundExecutor,
    // Shared so work on the connection can be moved to background tasks.
    connection: Arc<Mutex<Connection>>,
}
 827
 828impl ThreadsDatabase {
 829    fn connection(&self) -> Arc<Mutex<Connection>> {
 830        self.connection.clone()
 831    }
 832
 833    const COMPRESSION_LEVEL: i32 = 3;
 834}
 835
 836impl Bind for ThreadId {
 837    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
 838        self.to_string().bind(statement, start_index)
 839    }
 840}
 841
 842impl Column for ThreadId {
 843    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
 844        let (id_str, next_index) = String::column(statement, start_index)?;
 845        Ok((ThreadId::from(id_str.as_str()), next_index))
 846    }
 847}
 848
 849impl ThreadsDatabase {
 850    fn global_future(
 851        cx: &mut App,
 852    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
 853        GlobalThreadsDatabase::global(cx).0.clone()
 854    }
 855
 856    fn init(cx: &mut App) {
 857        let executor = cx.background_executor().clone();
 858        let database_future = executor
 859            .spawn({
 860                let executor = executor.clone();
 861                let threads_dir = paths::data_dir().join("threads");
 862                async move { ThreadsDatabase::new(threads_dir, executor) }
 863            })
 864            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
 865            .boxed()
 866            .shared();
 867
 868        cx.set_global(GlobalThreadsDatabase(database_future));
 869    }
 870
 871    pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
 872        std::fs::create_dir_all(&threads_dir)?;
 873
 874        let sqlite_path = threads_dir.join("threads.db");
 875        let mdb_path = threads_dir.join("threads-db.1.mdb");
 876
 877        let needs_migration_from_heed = mdb_path.exists();
 878
 879        let connection = Connection::open_file(&sqlite_path.to_string_lossy());
 880
 881        connection.exec(indoc! {"
 882                CREATE TABLE IF NOT EXISTS threads (
 883                    id TEXT PRIMARY KEY,
 884                    summary TEXT NOT NULL,
 885                    updated_at TEXT NOT NULL,
 886                    data_type TEXT NOT NULL,
 887                    data BLOB NOT NULL
 888                )
 889            "})?()
 890        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
 891
 892        let db = Self {
 893            executor: executor.clone(),
 894            connection: Arc::new(Mutex::new(connection)),
 895        };
 896
 897        if needs_migration_from_heed {
 898            let db_connection = db.connection();
 899            let executor_clone = executor.clone();
 900            executor
 901                .spawn(async move {
 902                    log::info!("Starting threads.db migration");
 903                    Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
 904                    std::fs::remove_dir_all(mdb_path)?;
 905                    log::info!("threads.db migrated to sqlite");
 906                    Ok::<(), anyhow::Error>(())
 907                })
 908                .detach();
 909        }
 910
 911        Ok(db)
 912    }
 913
 914    // Remove this migration after 2025-09-01
 915    fn migrate_from_heed(
 916        mdb_path: &Path,
 917        connection: Arc<Mutex<Connection>>,
 918        _executor: BackgroundExecutor,
 919    ) -> Result<()> {
 920        use heed::types::SerdeBincode;
 921        struct SerializedThreadHeed(SerializedThread);
 922
 923        impl heed::BytesEncode<'_> for SerializedThreadHeed {
 924            type EItem = SerializedThreadHeed;
 925
 926            fn bytes_encode(
 927                item: &Self::EItem,
 928            ) -> Result<std::borrow::Cow<'_, [u8]>, heed::BoxedError> {
 929                serde_json::to_vec(&item.0)
 930                    .map(std::borrow::Cow::Owned)
 931                    .map_err(Into::into)
 932            }
 933        }
 934
 935        impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
 936            type DItem = SerializedThreadHeed;
 937
 938            fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
 939                SerializedThread::from_json(bytes)
 940                    .map(SerializedThreadHeed)
 941                    .map_err(Into::into)
 942            }
 943        }
 944
 945        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
 946
 947        let env = unsafe {
 948            heed::EnvOpenOptions::new()
 949                .map_size(ONE_GB_IN_BYTES)
 950                .max_dbs(1)
 951                .open(mdb_path)?
 952        };
 953
 954        let txn = env.write_txn()?;
 955        let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
 956            .open_database(&txn, Some("threads"))?
 957            .ok_or_else(|| anyhow!("threads database not found"))?;
 958
 959        for result in threads.iter(&txn)? {
 960            let (thread_id, thread_heed) = result?;
 961            Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
 962        }
 963
 964        Ok(())
 965    }
 966
 967    fn save_thread_sync(
 968        connection: &Arc<Mutex<Connection>>,
 969        id: ThreadId,
 970        thread: SerializedThread,
 971    ) -> Result<()> {
 972        let json_data = serde_json::to_string(&thread)?;
 973        let summary = thread.summary.to_string();
 974        let updated_at = thread.updated_at.to_rfc3339();
 975
 976        let connection = connection.lock().unwrap();
 977
 978        let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
 979        let data_type = DataType::Zstd;
 980        let data = compressed;
 981
 982        let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
 983            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
 984        "})?;
 985
 986        insert((id, summary, updated_at, data_type, data))?;
 987
 988        Ok(())
 989    }
 990
 991    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
 992        let connection = self.connection.clone();
 993
 994        self.executor.spawn(async move {
 995            let connection = connection.lock().unwrap();
 996            let mut select =
 997                connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
 998                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
 999            "})?;
1000
1001            let rows = select(())?;
1002            let mut threads = Vec::new();
1003
1004            for (id, summary, updated_at) in rows {
1005                threads.push(SerializedThreadMetadata {
1006                    id,
1007                    summary: summary.into(),
1008                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
1009                });
1010            }
1011
1012            Ok(threads)
1013        })
1014    }
1015
1016    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
1017        let connection = self.connection.clone();
1018
1019        self.executor.spawn(async move {
1020            let connection = connection.lock().unwrap();
1021            let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
1022                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
1023            "})?;
1024
1025            let rows = select(id)?;
1026            if let Some((data_type, data)) = rows.into_iter().next() {
1027                let json_data = match data_type {
1028                    DataType::Zstd => {
1029                        let decompressed = zstd::decode_all(&data[..])?;
1030                        String::from_utf8(decompressed)?
1031                    }
1032                    DataType::Json => String::from_utf8(data)?,
1033                };
1034
1035                let thread = SerializedThread::from_json(json_data.as_bytes())?;
1036                Ok(Some(thread))
1037            } else {
1038                Ok(None)
1039            }
1040        })
1041    }
1042
1043    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
1044        let connection = self.connection.clone();
1045
1046        self.executor
1047            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
1048    }
1049
1050    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
1051        let connection = self.connection.clone();
1052
1053        self.executor.spawn(async move {
1054            let connection = connection.lock().unwrap();
1055
1056            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
1057                DELETE FROM threads WHERE id = ?
1058            "})?;
1059
1060            delete(id)?;
1061
1062            Ok(())
1063        })
1064    }
1065}
1066
#[cfg(test)]
mod tests {
    use super::*;
    use crate::thread::{DetailedSummaryState, MessageId};
    use chrono::{DateTime, Utc};
    use language_model::{Role, TokenUsage};
    use pretty_assertions::assert_eq;

    /// Summary shared by every thread built in these tests.
    const SUMMARY: &str = "Test conversation";

    /// A visible message with a single text segment and no tool activity.
    fn text_message(id: MessageId, role: Role, text: &str) -> SerializedMessage {
        SerializedMessage {
            id,
            role,
            segments: vec![SerializedMessageSegment::Text {
                text: text.to_string(),
            }],
            tool_uses: Vec::new(),
            tool_results: Vec::new(),
            context: String::new(),
            creases: Vec::new(),
            is_hidden: false,
        }
    }

    /// The single tool invocation used by the v0.1.0 upgrade test.
    fn tool_use() -> SerializedToolUse {
        SerializedToolUse {
            id: "abc".into(),
            name: "tool_1".into(),
            input: serde_json::Value::Null,
        }
    }

    /// The result answering [`tool_use`].
    fn tool_result() -> SerializedToolResult {
        SerializedToolResult {
            tool_use_id: "abc".into(),
            is_error: false,
            content: LanguageModelToolResultContent::Text("abcdef".into()),
            output: Some(serde_json::Value::Null),
        }
    }

    /// A thread with the given version and messages; every other field is empty/default.
    fn build_thread(
        version: String,
        updated_at: DateTime<Utc>,
        messages: Vec<SerializedMessage>,
    ) -> SerializedThread {
        SerializedThread {
            summary: SUMMARY.into(),
            updated_at,
            messages,
            version,
            initial_project_snapshot: None,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: Vec::new(),
            detailed_summary_state: DetailedSummaryState::default(),
            exceeded_window_error: None,
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
            profile: None,
        }
    }

    #[test]
    fn test_legacy_serialized_thread_upgrade() {
        let updated_at = Utc::now();
        let legacy_thread = LegacySerializedThread {
            summary: SUMMARY.into(),
            updated_at,
            messages: vec![LegacySerializedMessage {
                id: MessageId(1),
                role: Role::User,
                text: "Hello, world!".to_string(),
                tool_uses: Vec::new(),
                tool_results: Vec::new(),
            }],
            initial_project_snapshot: None,
        };

        let expected = build_thread(
            SerializedThread::VERSION.to_string(),
            updated_at,
            vec![text_message(MessageId(1), Role::User, "Hello, world!")],
        );

        assert_eq!(legacy_thread.upgrade(), expected)
    }

    #[test]
    fn test_serialized_threadv0_1_0_upgrade() {
        let updated_at = Utc::now();

        // v0.1.0 stored tool results in the user message that follows the tool use.
        let thread_v0_1_0 = SerializedThreadV0_1_0(build_thread(
            SerializedThreadV0_1_0::VERSION.to_string(),
            updated_at,
            vec![
                text_message(MessageId(1), Role::User, "Use tool_1"),
                SerializedMessage {
                    tool_uses: vec![tool_use()],
                    ..text_message(MessageId(2), Role::Assistant, "I want to use a tool")
                },
                SerializedMessage {
                    tool_results: vec![tool_result()],
                    ..text_message(MessageId(1), Role::User, "Here is the tool result")
                },
            ],
        ));

        // After upgrading, the result lives alongside the use in the assistant message.
        let expected = build_thread(
            SerializedThread::VERSION.to_string(),
            updated_at,
            vec![
                text_message(MessageId(1), Role::User, "Use tool_1"),
                SerializedMessage {
                    tool_uses: vec![tool_use()],
                    tool_results: vec![tool_result()],
                    ..text_message(MessageId(2), Role::Assistant, "I want to use a tool")
                },
            ],
        );

        assert_eq!(thread_v0_1_0.upgrade(), expected)
    }
}