thread_store.rs

   1use std::cell::{Ref, RefCell};
   2use std::path::{Path, PathBuf};
   3use std::rc::Rc;
   4use std::sync::{Arc, Mutex};
   5
   6use agent_settings::{AgentProfileId, CompletionMode};
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_tool::{ToolId, ToolWorkingSet};
   9use chrono::{DateTime, Utc};
  10use collections::HashMap;
  11use context_server::ContextServerId;
  12use futures::channel::{mpsc, oneshot};
  13use futures::future::{self, BoxFuture, Shared};
  14use futures::{FutureExt as _, StreamExt as _};
  15use gpui::{
  16    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
  17    Subscription, Task, prelude::*,
  18};
  19
  20use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
  21use project::context_server_store::{ContextServerStatus, ContextServerStore};
  22use project::{Project, ProjectItem, ProjectPath, Worktree};
  23use prompt_store::{
  24    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
  25    UserRulesContext, WorktreeContext,
  26};
  27use serde::{Deserialize, Serialize};
  28use ui::Window;
  29use util::ResultExt as _;
  30
  31use crate::context_server_tool::ContextServerTool;
  32use crate::thread::{
  33    DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
  34};
  35use indoc::indoc;
  36use sqlez::{
  37    bindable::{Bind, Column},
  38    connection::Connection,
  39    statement::Statement,
  40};
  41
  42#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  43pub enum DataType {
  44    #[serde(rename = "json")]
  45    Json,
  46    #[serde(rename = "zstd")]
  47    Zstd,
  48}
  49
  50impl Bind for DataType {
  51    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
  52        let value = match self {
  53            DataType::Json => "json",
  54            DataType::Zstd => "zstd",
  55        };
  56        value.bind(statement, start_index)
  57    }
  58}
  59
  60impl Column for DataType {
  61    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
  62        let (value, next_index) = String::column(statement, start_index)?;
  63        let data_type = match value.as_str() {
  64            "json" => DataType::Json,
  65            "zstd" => DataType::Zstd,
  66            _ => anyhow::bail!("Unknown data type: {}", value),
  67        };
  68        Ok((data_type, next_index))
  69    }
  70}
  71
  72const RULES_FILE_NAMES: [&'static str; 8] = [
  73    ".rules",
  74    ".cursorrules",
  75    ".windsurfrules",
  76    ".clinerules",
  77    ".github/copilot-instructions.md",
  78    "CLAUDE.md",
  79    "AGENT.md",
  80    "AGENTS.md",
  81];
  82
  83pub fn init(cx: &mut App) {
  84    ThreadsDatabase::init(cx);
  85}
  86
  87/// A system prompt shared by all threads created by this ThreadStore
  88#[derive(Clone, Default)]
  89pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
  90
  91impl SharedProjectContext {
  92    pub fn borrow(&self) -> Ref<'_, Option<ProjectContext>> {
  93        self.0.borrow()
  94    }
  95}
  96
/// Alias exposing `assistant_context_editor::ContextStore` under the name `TextThreadStore`.
pub type TextThreadStore = assistant_context_editor::ContextStore;
  98
/// Creates, opens, saves and deletes agent threads, caching their metadata from the threads
/// database.
///
/// It also maintains the system-prompt context shared by the threads it creates, and keeps
/// context-server tools in the shared [`ToolWorkingSet`] in sync with server status.
pub struct ThreadStore {
    project: Entity<Project>,
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    /// Optional source of default user rules. When present, its `PromptsUpdatedEvent`s
    /// enqueue a system-prompt reload.
    prompt_store: Option<Entity<PromptStore>>,
    /// Tool ids registered for each context server, so they can be removed from the working
    /// set when that server stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Cached thread metadata. `reload` replaces it wholesale; `delete_thread` prunes it.
    threads: Vec<SerializedThreadMetadata>,
    /// System-prompt context handed to every thread this store creates.
    project_context: SharedProjectContext,
    /// Wakes the reload loop; see `enqueue_system_prompt_reload`.
    reload_system_prompt_tx: mpsc::Sender<()>,
    /// Loop that performs system-prompt reloads one at a time; held here to keep it alive.
    /// Note that field order determines drop order.
    _reload_system_prompt_task: Task<()>,
    /// Subscriptions to the project and, if present, the prompt store.
    _subscriptions: Vec<Subscription>,
}
 111
/// Event emitted by `ThreadStore` when a worktree rules file or a default user rule fails
/// to load while the system prompt is being rebuilt.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}

// `ThreadStore` emits `RulesLoadingError` events during system-prompt reloads.
impl EventEmitter<RulesLoadingError> for ThreadStore {}
 117
 118impl ThreadStore {
 119    pub fn load(
 120        project: Entity<Project>,
 121        tools: Entity<ToolWorkingSet>,
 122        prompt_store: Option<Entity<PromptStore>>,
 123        prompt_builder: Arc<PromptBuilder>,
 124        cx: &mut App,
 125    ) -> Task<Result<Entity<Self>>> {
 126        cx.spawn(async move |cx| {
 127            let (thread_store, ready_rx) = cx.update(|cx| {
 128                let mut option_ready_rx = None;
 129                let thread_store = cx.new(|cx| {
 130                    let (thread_store, ready_rx) =
 131                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 132                    option_ready_rx = Some(ready_rx);
 133                    thread_store
 134                });
 135                (thread_store, option_ready_rx.take().unwrap())
 136            })?;
 137            ready_rx.await?;
 138            Ok(thread_store)
 139        })
 140    }
 141
 142    fn new(
 143        project: Entity<Project>,
 144        tools: Entity<ToolWorkingSet>,
 145        prompt_builder: Arc<PromptBuilder>,
 146        prompt_store: Option<Entity<PromptStore>>,
 147        cx: &mut Context<Self>,
 148    ) -> (Self, oneshot::Receiver<()>) {
 149        let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
 150
 151        if let Some(prompt_store) = prompt_store.as_ref() {
 152            subscriptions.push(cx.subscribe(
 153                prompt_store,
 154                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
 155                    this.enqueue_system_prompt_reload();
 156                },
 157            ))
 158        }
 159
 160        // This channel and task prevent concurrent and redundant loading of the system prompt.
 161        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
 162        let (ready_tx, ready_rx) = oneshot::channel();
 163        let mut ready_tx = Some(ready_tx);
 164        let reload_system_prompt_task = cx.spawn({
 165            let prompt_store = prompt_store.clone();
 166            async move |thread_store, cx| {
 167                loop {
 168                    let Some(reload_task) = thread_store
 169                        .update(cx, |thread_store, cx| {
 170                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
 171                        })
 172                        .ok()
 173                    else {
 174                        return;
 175                    };
 176                    reload_task.await;
 177                    if let Some(ready_tx) = ready_tx.take() {
 178                        ready_tx.send(()).ok();
 179                    }
 180                    reload_system_prompt_rx.next().await;
 181                }
 182            }
 183        });
 184
 185        let this = Self {
 186            project,
 187            tools,
 188            prompt_builder,
 189            prompt_store,
 190            context_server_tool_ids: HashMap::default(),
 191            threads: Vec::new(),
 192            project_context: SharedProjectContext::default(),
 193            reload_system_prompt_tx,
 194            _reload_system_prompt_task: reload_system_prompt_task,
 195            _subscriptions: subscriptions,
 196        };
 197        this.register_context_server_handlers(cx);
 198        this.reload(cx).detach_and_log_err(cx);
 199        (this, ready_rx)
 200    }
 201
 202    fn handle_project_event(
 203        &mut self,
 204        _project: Entity<Project>,
 205        event: &project::Event,
 206        _cx: &mut Context<Self>,
 207    ) {
 208        match event {
 209            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 210                self.enqueue_system_prompt_reload();
 211            }
 212            project::Event::WorktreeUpdatedEntries(_, items) => {
 213                if items.iter().any(|(path, _, _)| {
 214                    RULES_FILE_NAMES
 215                        .iter()
 216                        .any(|name| path.as_ref() == Path::new(name))
 217                }) {
 218                    self.enqueue_system_prompt_reload();
 219                }
 220            }
 221            _ => {}
 222        }
 223    }
 224
 225    fn enqueue_system_prompt_reload(&mut self) {
 226        self.reload_system_prompt_tx.try_send(()).ok();
 227    }
 228
 229    // Note that this should only be called from `reload_system_prompt_task`.
 230    fn reload_system_prompt(
 231        &self,
 232        prompt_store: Option<Entity<PromptStore>>,
 233        cx: &mut Context<Self>,
 234    ) -> Task<()> {
 235        let worktrees = self
 236            .project
 237            .read(cx)
 238            .visible_worktrees(cx)
 239            .collect::<Vec<_>>();
 240        let worktree_tasks = worktrees
 241            .into_iter()
 242            .map(|worktree| {
 243                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
 244            })
 245            .collect::<Vec<_>>();
 246        let default_user_rules_task = match prompt_store {
 247            None => Task::ready(vec![]),
 248            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
 249                let prompts = prompt_store.default_prompt_metadata();
 250                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 251                    let contents = prompt_store.load(prompt_metadata.id, cx);
 252                    async move { (contents.await, prompt_metadata) }
 253                });
 254                cx.background_spawn(future::join_all(load_tasks))
 255            }),
 256        };
 257
 258        cx.spawn(async move |this, cx| {
 259            let (worktrees, default_user_rules) =
 260                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 261
 262            let worktrees = worktrees
 263                .into_iter()
 264                .map(|(worktree, rules_error)| {
 265                    if let Some(rules_error) = rules_error {
 266                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 267                    }
 268                    worktree
 269                })
 270                .collect::<Vec<_>>();
 271
 272            let default_user_rules = default_user_rules
 273                .into_iter()
 274                .flat_map(|(contents, prompt_metadata)| match contents {
 275                    Ok(contents) => Some(UserRulesContext {
 276                        uuid: match prompt_metadata.id {
 277                            PromptId::User { uuid } => uuid,
 278                            PromptId::EditWorkflow => return None,
 279                        },
 280                        title: prompt_metadata.title.map(|title| title.to_string()),
 281                        contents,
 282                    }),
 283                    Err(err) => {
 284                        this.update(cx, |_, cx| {
 285                            cx.emit(RulesLoadingError {
 286                                message: format!("{err:?}").into(),
 287                            });
 288                        })
 289                        .ok();
 290                        None
 291                    }
 292                })
 293                .collect::<Vec<_>>();
 294
 295            this.update(cx, |this, _cx| {
 296                *this.project_context.0.borrow_mut() =
 297                    Some(ProjectContext::new(worktrees, default_user_rules));
 298            })
 299            .ok();
 300        })
 301    }
 302
 303    fn load_worktree_info_for_system_prompt(
 304        worktree: Entity<Worktree>,
 305        project: Entity<Project>,
 306        cx: &mut App,
 307    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 308        let tree = worktree.read(cx);
 309        let root_name = tree.root_name().into();
 310        let abs_path = tree.abs_path();
 311
 312        let mut context = WorktreeContext {
 313            root_name,
 314            abs_path,
 315            rules_file: None,
 316        };
 317
 318        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 319        let Some(rules_task) = rules_task else {
 320            return Task::ready((context, None));
 321        };
 322
 323        cx.spawn(async move |_| {
 324            let (rules_file, rules_file_error) = match rules_task.await {
 325                Ok(rules_file) => (Some(rules_file), None),
 326                Err(err) => (
 327                    None,
 328                    Some(RulesLoadingError {
 329                        message: format!("{err}").into(),
 330                    }),
 331                ),
 332            };
 333            context.rules_file = rules_file;
 334            (context, rules_file_error)
 335        })
 336    }
 337
 338    fn load_worktree_rules_file(
 339        worktree: Entity<Worktree>,
 340        project: Entity<Project>,
 341        cx: &mut App,
 342    ) -> Option<Task<Result<RulesFileContext>>> {
 343        let worktree = worktree.read(cx);
 344        let worktree_id = worktree.id();
 345        let selected_rules_file = RULES_FILE_NAMES
 346            .into_iter()
 347            .filter_map(|name| {
 348                worktree
 349                    .entry_for_path(name)
 350                    .filter(|entry| entry.is_file())
 351                    .map(|entry| entry.path.clone())
 352            })
 353            .next();
 354
 355        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 356        // supported. This doesn't seem to occur often in GitHub repositories.
 357        selected_rules_file.map(|path_in_worktree| {
 358            let project_path = ProjectPath {
 359                worktree_id,
 360                path: path_in_worktree.clone(),
 361            };
 362            let buffer_task =
 363                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 364            let rope_task = cx.spawn(async move |cx| {
 365                buffer_task.await?.read_with(cx, |buffer, cx| {
 366                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 367                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 368                })?
 369            });
 370            // Build a string from the rope on a background thread.
 371            cx.background_spawn(async move {
 372                let (project_entry_id, rope) = rope_task.await?;
 373                anyhow::Ok(RulesFileContext {
 374                    path_in_worktree,
 375                    text: rope.to_string().trim().to_string(),
 376                    project_entry_id: project_entry_id.to_usize(),
 377                })
 378            })
 379        })
 380    }
 381
 382    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
 383        &self.prompt_store
 384    }
 385
 386    pub fn tools(&self) -> Entity<ToolWorkingSet> {
 387        self.tools.clone()
 388    }
 389
 390    /// Returns the number of threads.
 391    pub fn thread_count(&self) -> usize {
 392        self.threads.len()
 393    }
 394
 395    pub fn reverse_chronological_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
 396        // ordering is from "ORDER BY" in `list_threads`
 397        self.threads.iter()
 398    }
 399
 400    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 401        cx.new(|cx| {
 402            Thread::new(
 403                self.project.clone(),
 404                self.tools.clone(),
 405                self.prompt_builder.clone(),
 406                self.project_context.clone(),
 407                cx,
 408            )
 409        })
 410    }
 411
 412    pub fn create_thread_from_serialized(
 413        &mut self,
 414        serialized: SerializedThread,
 415        cx: &mut Context<Self>,
 416    ) -> Entity<Thread> {
 417        cx.new(|cx| {
 418            Thread::deserialize(
 419                ThreadId::new(),
 420                serialized,
 421                self.project.clone(),
 422                self.tools.clone(),
 423                self.prompt_builder.clone(),
 424                self.project_context.clone(),
 425                None,
 426                cx,
 427            )
 428        })
 429    }
 430
 431    pub fn open_thread(
 432        &self,
 433        id: &ThreadId,
 434        window: &mut Window,
 435        cx: &mut Context<Self>,
 436    ) -> Task<Result<Entity<Thread>>> {
 437        let id = id.clone();
 438        let database_future = ThreadsDatabase::global_future(cx);
 439        let this = cx.weak_entity();
 440        window.spawn(cx, async move |cx| {
 441            let database = database_future.await.map_err(|err| anyhow!(err))?;
 442            let thread = database
 443                .try_find_thread(id.clone())
 444                .await?
 445                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 446
 447            let thread = this.update_in(cx, |this, window, cx| {
 448                cx.new(|cx| {
 449                    Thread::deserialize(
 450                        id.clone(),
 451                        thread,
 452                        this.project.clone(),
 453                        this.tools.clone(),
 454                        this.prompt_builder.clone(),
 455                        this.project_context.clone(),
 456                        Some(window),
 457                        cx,
 458                    )
 459                })
 460            })?;
 461
 462            Ok(thread)
 463        })
 464    }
 465
 466    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
 467        let (metadata, serialized_thread) =
 468            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
 469
 470        let database_future = ThreadsDatabase::global_future(cx);
 471        cx.spawn(async move |this, cx| {
 472            let serialized_thread = serialized_thread.await?;
 473            let database = database_future.await.map_err(|err| anyhow!(err))?;
 474            database.save_thread(metadata, serialized_thread).await?;
 475
 476            this.update(cx, |this, cx| this.reload(cx))?.await
 477        })
 478    }
 479
 480    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
 481        let id = id.clone();
 482        let database_future = ThreadsDatabase::global_future(cx);
 483        cx.spawn(async move |this, cx| {
 484            let database = database_future.await.map_err(|err| anyhow!(err))?;
 485            database.delete_thread(id.clone()).await?;
 486
 487            this.update(cx, |this, cx| {
 488                this.threads.retain(|thread| thread.id != id);
 489                cx.notify();
 490            })
 491        })
 492    }
 493
 494    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 495        let database_future = ThreadsDatabase::global_future(cx);
 496        cx.spawn(async move |this, cx| {
 497            let threads = database_future
 498                .await
 499                .map_err(|err| anyhow!(err))?
 500                .list_threads()
 501                .await?;
 502
 503            this.update(cx, |this, cx| {
 504                this.threads = threads;
 505                cx.notify();
 506            })
 507        })
 508    }
 509
 510    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
 511        let context_server_store = self.project.read(cx).context_server_store();
 512        cx.subscribe(&context_server_store, Self::handle_context_server_event)
 513            .detach();
 514
 515        // Check for any servers that were already running before the handler was registered
 516        for server in context_server_store.read(cx).running_servers() {
 517            self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
 518        }
 519    }
 520
 521    fn handle_context_server_event(
 522        &mut self,
 523        context_server_store: Entity<ContextServerStore>,
 524        event: &project::context_server_store::Event,
 525        cx: &mut Context<Self>,
 526    ) {
 527        let tool_working_set = self.tools.clone();
 528        match event {
 529            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
 530                match status {
 531                    ContextServerStatus::Starting => {}
 532                    ContextServerStatus::Running => {
 533                        self.load_context_server_tools(server_id.clone(), context_server_store, cx);
 534                    }
 535                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
 536                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
 537                            tool_working_set.update(cx, |tool_working_set, _| {
 538                                tool_working_set.remove(&tool_ids);
 539                            });
 540                        }
 541                    }
 542                }
 543            }
 544        }
 545    }
 546
 547    fn load_context_server_tools(
 548        &self,
 549        server_id: ContextServerId,
 550        context_server_store: Entity<ContextServerStore>,
 551        cx: &mut Context<Self>,
 552    ) {
 553        let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
 554            return;
 555        };
 556        let tool_working_set = self.tools.clone();
 557        cx.spawn(async move |this, cx| {
 558            let Some(protocol) = server.client() else {
 559                return;
 560            };
 561
 562            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
 563                if let Some(response) = protocol
 564                    .request::<context_server::types::requests::ListTools>(())
 565                    .await
 566                    .log_err()
 567                {
 568                    let tool_ids = tool_working_set
 569                        .update(cx, |tool_working_set, _| {
 570                            response
 571                                .tools
 572                                .into_iter()
 573                                .map(|tool| {
 574                                    log::info!("registering context server tool: {:?}", tool.name);
 575                                    tool_working_set.insert(Arc::new(ContextServerTool::new(
 576                                        context_server_store.clone(),
 577                                        server.id(),
 578                                        tool,
 579                                    )))
 580                                })
 581                                .collect::<Vec<_>>()
 582                        })
 583                        .log_err();
 584
 585                    if let Some(tool_ids) = tool_ids {
 586                        this.update(cx, |this, _| {
 587                            this.context_server_tool_ids.insert(server_id, tool_ids);
 588                        })
 589                        .log_err();
 590                    }
 591                }
 592            }
 593        })
 594        .detach();
 595    }
 596}
 597
/// Lightweight per-thread summary, as cached in `ThreadStore`'s thread list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Id of the thread this entry describes.
    pub id: ThreadId,
    /// The thread's summary (presumably shown as its title — confirm).
    pub summary: SharedString,
    /// Last-updated timestamp, in UTC.
    pub updated_at: DateTime<Utc>,
}
 604
/// A thread as persisted, in format [`SerializedThread::VERSION`].
///
/// Fields marked `#[serde(default)]` may be missing from older saved data; in that case they
/// deserialize to their `Default` value. Field order is also the serialized field order.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedThread {
    /// Format version string; `SerializedThread::from_json` dispatches on it.
    pub version: String,
    /// Thread summary.
    pub summary: SharedString,
    /// Last-updated timestamp, in UTC.
    pub updated_at: DateTime<Utc>,
    /// The thread's messages.
    pub messages: Vec<SerializedMessage>,
    /// Project snapshot attached to the thread (presumably taken when it started — confirm).
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    /// Total token usage accumulated by the thread.
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    /// Token usage per request (presumably one entry per model request — confirm).
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    /// Context-window overflow error recorded for the thread, if any.
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    /// Language model the thread was using, if recorded.
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
    /// Agent profile the thread was using, if recorded.
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
 630
/// Identifies a language model by provider and model name, as plain strings.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedLanguageModel {
    /// Provider identifier.
    pub provider: String,
    /// Model identifier within that provider.
    pub model: String,
}
 636
 637impl SerializedThread {
 638    pub const VERSION: &'static str = "0.2.0";
 639
 640    pub fn from_json(json: &[u8]) -> Result<Self> {
 641        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 642        match saved_thread_json.get("version") {
 643            Some(serde_json::Value::String(version)) => match version.as_str() {
 644                SerializedThreadV0_1_0::VERSION => {
 645                    let saved_thread =
 646                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
 647                    Ok(saved_thread.upgrade())
 648                }
 649                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
 650                    saved_thread_json,
 651                )?),
 652                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 653            },
 654            None => {
 655                let saved_thread =
 656                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
 657                Ok(saved_thread.upgrade())
 658            }
 659            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 660        }
 661    }
 662}
 663
/// Thread format `0.1.0`.
///
/// It has the same shape as [`SerializedThread`]. The difference is that tool results were
/// stored on the user message that follows an assistant message, rather than on the
/// assistant message itself; `upgrade` moves them.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);
 670
 671impl SerializedThreadV0_1_0 {
 672    pub const VERSION: &'static str = "0.1.0";
 673
 674    pub fn upgrade(self) -> SerializedThread {
 675        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
 676
 677        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
 678
 679        for message in self.0.messages {
 680            if message.role == Role::User && !message.tool_results.is_empty() {
 681                if let Some(last_message) = messages.last_mut() {
 682                    debug_assert!(last_message.role == Role::Assistant);
 683
 684                    last_message.tool_results = message.tool_results;
 685                    continue;
 686                }
 687            }
 688
 689            messages.push(message);
 690        }
 691
 692        SerializedThread {
 693            messages,
 694            version: SerializedThread::VERSION.to_string(),
 695            ..self.0
 696        }
 697    }
 698}
 699
/// One message of a [`SerializedThread`].
///
/// Every field except `id` and `role` may be missing in saved data; missing fields default
/// to empty/false.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedMessage {
    pub id: MessageId,
    /// Author role of the message.
    pub role: Role,
    /// Message content as a sequence of text / thinking segments.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    /// Tool calls made in this message.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Tool results. In the current format they sit on the assistant message that made the
    /// calls (see `SerializedThreadV0_1_0::upgrade`).
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    /// Extra context attached to the message, stored as a string (format not visible here).
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    /// Whether the message is hidden (presumably from the UI — confirm).
    #[serde(default)]
    pub is_hidden: bool,
}
 717
/// A piece of message content, internally tagged by a `"type"` field.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Plain text.
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    /// Model reasoning text, with an optional signature that is omitted from the output
    /// when `None`.
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // NOTE(review): unlike the other variants this one has no `rename`, so its tag is
    // serialized as "RedactedThinking". Changing that now would break reading previously
    // saved threads.
    RedactedThinking {
        data: Vec<u8>,
    },
}
 735
/// A tool call recorded in a message.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolUse {
    /// Id linking this call to its result.
    pub id: LanguageModelToolUseId,
    /// Name of the tool that was called.
    pub name: SharedString,
    /// Arguments passed to the tool, as raw JSON.
    pub input: serde_json::Value,
}

/// The result of a tool call, linked to it by `tool_use_id`.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the tool reported an error.
    pub is_error: bool,
    /// Content of the result.
    pub content: LanguageModelToolResultContent,
    /// Optional additional output, as raw JSON.
    pub output: Option<serde_json::Value>,
}
 750
/// Thread format from before the `"version"` field existed; see `upgrade`.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
 759
 760impl LegacySerializedThread {
 761    pub fn upgrade(self) -> SerializedThread {
 762        SerializedThread {
 763            version: SerializedThread::VERSION.to_string(),
 764            summary: self.summary,
 765            updated_at: self.updated_at,
 766            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
 767            initial_project_snapshot: self.initial_project_snapshot,
 768            cumulative_token_usage: TokenUsage::default(),
 769            request_token_usage: Vec::new(),
 770            detailed_summary_state: DetailedSummaryState::default(),
 771            exceeded_window_error: None,
 772            model: None,
 773            completion_mode: None,
 774            tool_use_limit_reached: false,
 775            profile: None,
 776        }
 777    }
 778}
 779
/// Message in the legacy thread format, whose content is a single plain-text string.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Entire message content; becomes a single `Text` segment on upgrade.
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
 790
 791impl LegacySerializedMessage {
 792    fn upgrade(self) -> SerializedMessage {
 793        SerializedMessage {
 794            id: self.id,
 795            role: self.role,
 796            segments: vec![SerializedMessageSegment::Text { text: self.text }],
 797            tool_uses: self.tool_uses,
 798            tool_results: self.tool_results,
 799            context: String::new(),
 800            creases: Vec::new(),
 801            is_hidden: false,
 802        }
 803    }
 804}
 805
/// A labelled range within a message, persisted with its icon and label.
///
/// NOTE(review): `start`/`end` are presumably offsets into the message text — confirm the
/// unit (bytes vs. chars) against the code that creates creases.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedCrease {
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}
 813
/// GPUI global holding the shared future that opens the threads database.
///
/// `ThreadsDatabase::init` stores it once, and `ThreadsDatabase::global_future` hands out
/// clones, so all callers await the same open attempt. The error is wrapped in `Arc` so the
/// shared output can be cloned to every awaiting caller.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
 819
/// SQLite-backed storage for serialized threads (the `threads.db` file in the threads
/// directory).
pub(crate) struct ThreadsDatabase {
    /// Background executor (presumably used to run queries off the main thread — confirm).
    executor: BackgroundExecutor,
    /// Single SQLite connection, shared behind a mutex.
    connection: Arc<Mutex<Connection>>,
}
 824
 825impl ThreadsDatabase {
 826    fn connection(&self) -> Arc<Mutex<Connection>> {
 827        self.connection.clone()
 828    }
 829
 830    const COMPRESSION_LEVEL: i32 = 3;
 831}
 832
 833impl Bind for ThreadId {
 834    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
 835        self.to_string().bind(statement, start_index)
 836    }
 837}
 838
 839impl Column for ThreadId {
 840    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
 841        let (id_str, next_index) = String::column(statement, start_index)?;
 842        Ok((ThreadId::from(id_str.as_str()), next_index))
 843    }
 844}
 845
 846impl ThreadsDatabase {
 847    fn global_future(
 848        cx: &mut App,
 849    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
 850        GlobalThreadsDatabase::global(cx).0.clone()
 851    }
 852
 853    fn init(cx: &mut App) {
 854        let executor = cx.background_executor().clone();
 855        let database_future = executor
 856            .spawn({
 857                let executor = executor.clone();
 858                let threads_dir = paths::data_dir().join("threads");
 859                async move { ThreadsDatabase::new(threads_dir, executor) }
 860            })
 861            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
 862            .boxed()
 863            .shared();
 864
 865        cx.set_global(GlobalThreadsDatabase(database_future));
 866    }
 867
 868    pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
 869        std::fs::create_dir_all(&threads_dir)?;
 870
 871        let sqlite_path = threads_dir.join("threads.db");
 872        let mdb_path = threads_dir.join("threads-db.1.mdb");
 873
 874        let needs_migration_from_heed = mdb_path.exists();
 875
 876        let connection = Connection::open_file(&sqlite_path.to_string_lossy());
 877
 878        connection.exec(indoc! {"
 879                CREATE TABLE IF NOT EXISTS threads (
 880                    id TEXT PRIMARY KEY,
 881                    summary TEXT NOT NULL,
 882                    updated_at TEXT NOT NULL,
 883                    data_type TEXT NOT NULL,
 884                    data BLOB NOT NULL
 885                )
 886            "})?()
 887        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
 888
 889        let db = Self {
 890            executor: executor.clone(),
 891            connection: Arc::new(Mutex::new(connection)),
 892        };
 893
 894        if needs_migration_from_heed {
 895            let db_connection = db.connection();
 896            let executor_clone = executor.clone();
 897            executor
 898                .spawn(async move {
 899                    log::info!("Starting threads.db migration");
 900                    Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
 901                    std::fs::remove_dir_all(mdb_path)?;
 902                    log::info!("threads.db migrated to sqlite");
 903                    Ok::<(), anyhow::Error>(())
 904                })
 905                .detach();
 906        }
 907
 908        Ok(db)
 909    }
 910
 911    // Remove this migration after 2025-09-01
 912    fn migrate_from_heed(
 913        mdb_path: &Path,
 914        connection: Arc<Mutex<Connection>>,
 915        _executor: BackgroundExecutor,
 916    ) -> Result<()> {
 917        use heed::types::SerdeBincode;
 918        struct SerializedThreadHeed(SerializedThread);
 919
 920        impl heed::BytesEncode<'_> for SerializedThreadHeed {
 921            type EItem = SerializedThreadHeed;
 922
 923            fn bytes_encode(
 924                item: &Self::EItem,
 925            ) -> Result<std::borrow::Cow<'_, [u8]>, heed::BoxedError> {
 926                serde_json::to_vec(&item.0)
 927                    .map(std::borrow::Cow::Owned)
 928                    .map_err(Into::into)
 929            }
 930        }
 931
 932        impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
 933            type DItem = SerializedThreadHeed;
 934
 935            fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
 936                SerializedThread::from_json(bytes)
 937                    .map(SerializedThreadHeed)
 938                    .map_err(Into::into)
 939            }
 940        }
 941
 942        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
 943
 944        let env = unsafe {
 945            heed::EnvOpenOptions::new()
 946                .map_size(ONE_GB_IN_BYTES)
 947                .max_dbs(1)
 948                .open(mdb_path)?
 949        };
 950
 951        let txn = env.write_txn()?;
 952        let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
 953            .open_database(&txn, Some("threads"))?
 954            .ok_or_else(|| anyhow!("threads database not found"))?;
 955
 956        for result in threads.iter(&txn)? {
 957            let (thread_id, thread_heed) = result?;
 958            Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
 959        }
 960
 961        Ok(())
 962    }
 963
 964    fn save_thread_sync(
 965        connection: &Arc<Mutex<Connection>>,
 966        id: ThreadId,
 967        thread: SerializedThread,
 968    ) -> Result<()> {
 969        let json_data = serde_json::to_string(&thread)?;
 970        let summary = thread.summary.to_string();
 971        let updated_at = thread.updated_at.to_rfc3339();
 972
 973        let connection = connection.lock().unwrap();
 974
 975        let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
 976        let data_type = DataType::Zstd;
 977        let data = compressed;
 978
 979        let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
 980            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
 981        "})?;
 982
 983        insert((id, summary, updated_at, data_type, data))?;
 984
 985        Ok(())
 986    }
 987
 988    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
 989        let connection = self.connection.clone();
 990
 991        self.executor.spawn(async move {
 992            let connection = connection.lock().unwrap();
 993            let mut select =
 994                connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
 995                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
 996            "})?;
 997
 998            let rows = select(())?;
 999            let mut threads = Vec::new();
1000
1001            for (id, summary, updated_at) in rows {
1002                threads.push(SerializedThreadMetadata {
1003                    id,
1004                    summary: summary.into(),
1005                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
1006                });
1007            }
1008
1009            Ok(threads)
1010        })
1011    }
1012
1013    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
1014        let connection = self.connection.clone();
1015
1016        self.executor.spawn(async move {
1017            let connection = connection.lock().unwrap();
1018            let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
1019                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
1020            "})?;
1021
1022            let rows = select(id)?;
1023            if let Some((data_type, data)) = rows.into_iter().next() {
1024                let json_data = match data_type {
1025                    DataType::Zstd => {
1026                        let decompressed = zstd::decode_all(&data[..])?;
1027                        String::from_utf8(decompressed)?
1028                    }
1029                    DataType::Json => String::from_utf8(data)?,
1030                };
1031
1032                let thread = SerializedThread::from_json(json_data.as_bytes())?;
1033                Ok(Some(thread))
1034            } else {
1035                Ok(None)
1036            }
1037        })
1038    }
1039
1040    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
1041        let connection = self.connection.clone();
1042
1043        self.executor
1044            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
1045    }
1046
1047    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
1048        let connection = self.connection.clone();
1049
1050        self.executor.spawn(async move {
1051            let connection = connection.lock().unwrap();
1052
1053            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
1054                DELETE FROM threads WHERE id = ?
1055            "})?;
1056
1057            delete(id)?;
1058
1059            Ok(())
1060        })
1061    }
1062}
1063
// Tests for upgrading older serialized thread formats to the current `SerializedThread`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::thread::{DetailedSummaryState, MessageId};
    use chrono::Utc;
    use language_model::{Role, TokenUsage};
    use pretty_assertions::assert_eq;

    /// A legacy thread upgrades to the current format:
    /// - each message's `text` becomes a single `Text` segment;
    /// - fields absent from the legacy format take empty/default values;
    /// - `version` is set to `SerializedThread::VERSION`.
    #[test]
    fn test_legacy_serialized_thread_upgrade() {
        let updated_at = Utc::now();
        let legacy_thread = LegacySerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![LegacySerializedMessage {
                id: MessageId(1),
                role: Role::User,
                text: "Hello, world!".to_string(),
                tool_uses: vec![],
                tool_results: vec![],
            }],
            initial_project_snapshot: None,
        };

        let upgraded = legacy_thread.upgrade();

        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Hello, world!".to_string()
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false
                }],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                exceeded_window_error: None,
                model: None,
                completion_mode: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }

    /// Upgrading a v0.1.0 thread moves the tool results from the user message that
    /// follows a tool-using assistant message onto that assistant message. That user
    /// message does not appear in the upgraded output, and `version` becomes
    /// `SerializedThread::VERSION`.
    #[test]
    fn test_serialized_threadv0_1_0_upgrade() {
        let updated_at = Utc::now();
        let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![
                SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Use tool_1".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                SerializedMessage {
                    id: MessageId(2),
                    role: Role::Assistant,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "I want to use a tool".to_string(),
                    }],
                    tool_uses: vec![SerializedToolUse {
                        id: "abc".into(),
                        name: "tool_1".into(),
                        input: serde_json::Value::Null,
                    }],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                // Tool-result carrier message. NOTE(review): it reuses MessageId(1);
                // presumably harmless because the upgrade drops this message.
                SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Here is the tool result".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![SerializedToolResult {
                        tool_use_id: "abc".into(),
                        is_error: false,
                        content: LanguageModelToolResultContent::Text("abcdef".into()),
                        output: Some(serde_json::Value::Null),
                    }],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
            ],
            version: SerializedThreadV0_1_0::VERSION.to_string(),
            initial_project_snapshot: None,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: vec![],
            detailed_summary_state: DetailedSummaryState::default(),
            exceeded_window_error: None,
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
            profile: None,
        });
        let upgraded = thread_v0_1_0.upgrade();

        // Expected: two messages, with the tool result attached to the assistant message.
        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![
                    SerializedMessage {
                        id: MessageId(1),
                        role: Role::User,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "Use tool_1".to_string()
                        }],
                        tool_uses: vec![],
                        tool_results: vec![],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false
                    },
                    SerializedMessage {
                        id: MessageId(2),
                        role: Role::Assistant,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "I want to use a tool".to_string(),
                        }],
                        tool_uses: vec![SerializedToolUse {
                            id: "abc".into(),
                            name: "tool_1".into(),
                            input: serde_json::Value::Null,
                        }],
                        tool_results: vec![SerializedToolResult {
                            tool_use_id: "abc".into(),
                            is_error: false,
                            content: LanguageModelToolResultContent::Text("abcdef".into()),
                            output: Some(serde_json::Value::Null),
                        }],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false,
                    },
                ],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                exceeded_window_error: None,
                model: None,
                completion_mode: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }
}