thread_store.rs

   1use crate::{
   2    MessageId, ThreadId,
   3    context_server_tool::ContextServerTool,
   4    thread::{DetailedSummaryState, ExceededWindowError, ProjectSnapshot, Thread},
   5};
   6use agent_settings::{AgentProfileId, CompletionMode};
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_tool::{ToolId, ToolWorkingSet};
   9use chrono::{DateTime, Utc};
  10use collections::HashMap;
  11use context_server::ContextServerId;
  12use futures::{
  13    FutureExt as _, StreamExt as _,
  14    channel::{mpsc, oneshot},
  15    future::{self, BoxFuture, Shared},
  16};
  17use gpui::{
  18    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
  19    Subscription, Task, Window, prelude::*,
  20};
  21use indoc::indoc;
  22use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
  23use project::context_server_store::{ContextServerStatus, ContextServerStore};
  24use project::{Project, ProjectItem, ProjectPath, Worktree};
  25use prompt_store::{
  26    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
  27    UserRulesContext, WorktreeContext,
  28};
  29use serde::{Deserialize, Serialize};
  30use sqlez::{
  31    bindable::{Bind, Column},
  32    connection::Connection,
  33    statement::Statement,
  34};
  35use std::{
  36    cell::{Ref, RefCell},
  37    path::{Path, PathBuf},
  38    rc::Rc,
  39    sync::{Arc, Mutex},
  40};
  41use util::ResultExt as _;
  42
  43#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  44pub enum DataType {
  45    #[serde(rename = "json")]
  46    Json,
  47    #[serde(rename = "zstd")]
  48    Zstd,
  49}
  50
  51impl Bind for DataType {
  52    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
  53        let value = match self {
  54            DataType::Json => "json",
  55            DataType::Zstd => "zstd",
  56        };
  57        value.bind(statement, start_index)
  58    }
  59}
  60
  61impl Column for DataType {
  62    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
  63        let (value, next_index) = String::column(statement, start_index)?;
  64        let data_type = match value.as_str() {
  65            "json" => DataType::Json,
  66            "zstd" => DataType::Zstd,
  67            _ => anyhow::bail!("Unknown data type: {}", value),
  68        };
  69        Ok((data_type, next_index))
  70    }
  71}
  72
  73const RULES_FILE_NAMES: [&'static str; 9] = [
  74    ".rules",
  75    ".cursorrules",
  76    ".windsurfrules",
  77    ".clinerules",
  78    ".github/copilot-instructions.md",
  79    "CLAUDE.md",
  80    "AGENT.md",
  81    "AGENTS.md",
  82    "GEMINI.md",
  83];
  84
  85pub fn init(cx: &mut App) {
  86    ThreadsDatabase::init(cx);
  87}
  88
  89/// A system prompt shared by all threads created by this ThreadStore
  90#[derive(Clone, Default)]
  91pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
  92
  93impl SharedProjectContext {
  94    pub fn borrow(&self) -> Ref<'_, Option<ProjectContext>> {
  95        self.0.borrow()
  96    }
  97}
  98
/// Alias for [`assistant_context::ContextStore`] under the name `TextThreadStore`.
pub type TextThreadStore = assistant_context::ContextStore;
 100
/// Holds the list of saved thread metadata for a project, keeps the shared project context
/// (system prompt input) up to date, and registers tools exposed by running context servers.
pub struct ThreadStore {
    /// Project whose worktrees and context servers this store observes.
    project: Entity<Project>,
    /// Tool working set; context server tools are inserted into and removed from it.
    tools: Entity<ToolWorkingSet>,
    /// Prompt builder handed in at construction.
    prompt_builder: Arc<PromptBuilder>,
    /// Optional prompt store; when present, its default prompts become user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Tool ids registered per context server, so they can be removed when it stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Thread metadata as last loaded by `reload`.
    threads: Vec<SerializedThreadMetadata>,
    /// Shared project context, rewritten by each system prompt reload.
    project_context: SharedProjectContext,
    /// Capacity-1 channel used to request a system prompt reload.
    reload_system_prompt_tx: mpsc::Sender<()>,
    /// Task running the reload loop; stored so it lives as long as the store.
    _reload_system_prompt_task: Task<()>,
    /// Event subscriptions; stored so they live as long as the store.
    _subscriptions: Vec<Subscription>,
}
 113
/// Event emitted by [`ThreadStore`] when a rules file or a default user prompt fails to load.
pub struct RulesLoadingError {
    /// Text of the underlying error.
    pub message: SharedString,
}
 117
/// Lets [`ThreadStore`] emit [`RulesLoadingError`] events.
impl EventEmitter<RulesLoadingError> for ThreadStore {}
 119
 120impl ThreadStore {
 121    pub fn load(
 122        project: Entity<Project>,
 123        tools: Entity<ToolWorkingSet>,
 124        prompt_store: Option<Entity<PromptStore>>,
 125        prompt_builder: Arc<PromptBuilder>,
 126        cx: &mut App,
 127    ) -> Task<Result<Entity<Self>>> {
 128        cx.spawn(async move |cx| {
 129            let (thread_store, ready_rx) = cx.update(|cx| {
 130                let mut option_ready_rx = None;
 131                let thread_store = cx.new(|cx| {
 132                    let (thread_store, ready_rx) =
 133                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 134                    option_ready_rx = Some(ready_rx);
 135                    thread_store
 136                });
 137                (thread_store, option_ready_rx.take().unwrap())
 138            })?;
 139            ready_rx.await?;
 140            Ok(thread_store)
 141        })
 142    }
 143
 144    fn new(
 145        project: Entity<Project>,
 146        tools: Entity<ToolWorkingSet>,
 147        prompt_builder: Arc<PromptBuilder>,
 148        prompt_store: Option<Entity<PromptStore>>,
 149        cx: &mut Context<Self>,
 150    ) -> (Self, oneshot::Receiver<()>) {
 151        let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
 152
 153        if let Some(prompt_store) = prompt_store.as_ref() {
 154            subscriptions.push(cx.subscribe(
 155                prompt_store,
 156                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
 157                    this.enqueue_system_prompt_reload();
 158                },
 159            ))
 160        }
 161
 162        // This channel and task prevent concurrent and redundant loading of the system prompt.
 163        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
 164        let (ready_tx, ready_rx) = oneshot::channel();
 165        let mut ready_tx = Some(ready_tx);
 166        let reload_system_prompt_task = cx.spawn({
 167            let prompt_store = prompt_store.clone();
 168            async move |thread_store, cx| {
 169                loop {
 170                    let Some(reload_task) = thread_store
 171                        .update(cx, |thread_store, cx| {
 172                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
 173                        })
 174                        .ok()
 175                    else {
 176                        return;
 177                    };
 178                    reload_task.await;
 179                    if let Some(ready_tx) = ready_tx.take() {
 180                        ready_tx.send(()).ok();
 181                    }
 182                    reload_system_prompt_rx.next().await;
 183                }
 184            }
 185        });
 186
 187        let this = Self {
 188            project,
 189            tools,
 190            prompt_builder,
 191            prompt_store,
 192            context_server_tool_ids: HashMap::default(),
 193            threads: Vec::new(),
 194            project_context: SharedProjectContext::default(),
 195            reload_system_prompt_tx,
 196            _reload_system_prompt_task: reload_system_prompt_task,
 197            _subscriptions: subscriptions,
 198        };
 199        this.register_context_server_handlers(cx);
 200        this.reload(cx).detach_and_log_err(cx);
 201        (this, ready_rx)
 202    }
 203
 204    fn handle_project_event(
 205        &mut self,
 206        _project: Entity<Project>,
 207        event: &project::Event,
 208        _cx: &mut Context<Self>,
 209    ) {
 210        match event {
 211            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 212                self.enqueue_system_prompt_reload();
 213            }
 214            project::Event::WorktreeUpdatedEntries(_, items) => {
 215                if items.iter().any(|(path, _, _)| {
 216                    RULES_FILE_NAMES
 217                        .iter()
 218                        .any(|name| path.as_ref() == Path::new(name))
 219                }) {
 220                    self.enqueue_system_prompt_reload();
 221                }
 222            }
 223            _ => {}
 224        }
 225    }
 226
 227    fn enqueue_system_prompt_reload(&mut self) {
 228        self.reload_system_prompt_tx.try_send(()).ok();
 229    }
 230
 231    // Note that this should only be called from `reload_system_prompt_task`.
 232    fn reload_system_prompt(
 233        &self,
 234        prompt_store: Option<Entity<PromptStore>>,
 235        cx: &mut Context<Self>,
 236    ) -> Task<()> {
 237        let worktrees = self
 238            .project
 239            .read(cx)
 240            .visible_worktrees(cx)
 241            .collect::<Vec<_>>();
 242        let worktree_tasks = worktrees
 243            .into_iter()
 244            .map(|worktree| {
 245                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
 246            })
 247            .collect::<Vec<_>>();
 248        let default_user_rules_task = match prompt_store {
 249            None => Task::ready(vec![]),
 250            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
 251                let prompts = prompt_store.default_prompt_metadata();
 252                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 253                    let contents = prompt_store.load(prompt_metadata.id, cx);
 254                    async move { (contents.await, prompt_metadata) }
 255                });
 256                cx.background_spawn(future::join_all(load_tasks))
 257            }),
 258        };
 259
 260        cx.spawn(async move |this, cx| {
 261            let (worktrees, default_user_rules) =
 262                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 263
 264            let worktrees = worktrees
 265                .into_iter()
 266                .map(|(worktree, rules_error)| {
 267                    if let Some(rules_error) = rules_error {
 268                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 269                    }
 270                    worktree
 271                })
 272                .collect::<Vec<_>>();
 273
 274            let default_user_rules = default_user_rules
 275                .into_iter()
 276                .flat_map(|(contents, prompt_metadata)| match contents {
 277                    Ok(contents) => Some(UserRulesContext {
 278                        uuid: match prompt_metadata.id {
 279                            PromptId::User { uuid } => uuid,
 280                            PromptId::EditWorkflow => return None,
 281                        },
 282                        title: prompt_metadata.title.map(|title| title.to_string()),
 283                        contents,
 284                    }),
 285                    Err(err) => {
 286                        this.update(cx, |_, cx| {
 287                            cx.emit(RulesLoadingError {
 288                                message: format!("{err:?}").into(),
 289                            });
 290                        })
 291                        .ok();
 292                        None
 293                    }
 294                })
 295                .collect::<Vec<_>>();
 296
 297            this.update(cx, |this, _cx| {
 298                *this.project_context.0.borrow_mut() =
 299                    Some(ProjectContext::new(worktrees, default_user_rules));
 300            })
 301            .ok();
 302        })
 303    }
 304
 305    fn load_worktree_info_for_system_prompt(
 306        worktree: Entity<Worktree>,
 307        project: Entity<Project>,
 308        cx: &mut App,
 309    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 310        let tree = worktree.read(cx);
 311        let root_name = tree.root_name().into();
 312        let abs_path = tree.abs_path();
 313
 314        let mut context = WorktreeContext {
 315            root_name,
 316            abs_path,
 317            rules_file: None,
 318        };
 319
 320        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 321        let Some(rules_task) = rules_task else {
 322            return Task::ready((context, None));
 323        };
 324
 325        cx.spawn(async move |_| {
 326            let (rules_file, rules_file_error) = match rules_task.await {
 327                Ok(rules_file) => (Some(rules_file), None),
 328                Err(err) => (
 329                    None,
 330                    Some(RulesLoadingError {
 331                        message: format!("{err}").into(),
 332                    }),
 333                ),
 334            };
 335            context.rules_file = rules_file;
 336            (context, rules_file_error)
 337        })
 338    }
 339
 340    fn load_worktree_rules_file(
 341        worktree: Entity<Worktree>,
 342        project: Entity<Project>,
 343        cx: &mut App,
 344    ) -> Option<Task<Result<RulesFileContext>>> {
 345        let worktree = worktree.read(cx);
 346        let worktree_id = worktree.id();
 347        let selected_rules_file = RULES_FILE_NAMES
 348            .into_iter()
 349            .filter_map(|name| {
 350                worktree
 351                    .entry_for_path(name)
 352                    .filter(|entry| entry.is_file())
 353                    .map(|entry| entry.path.clone())
 354            })
 355            .next();
 356
 357        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 358        // supported. This doesn't seem to occur often in GitHub repositories.
 359        selected_rules_file.map(|path_in_worktree| {
 360            let project_path = ProjectPath {
 361                worktree_id,
 362                path: path_in_worktree.clone(),
 363            };
 364            let buffer_task =
 365                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 366            let rope_task = cx.spawn(async move |cx| {
 367                buffer_task.await?.read_with(cx, |buffer, cx| {
 368                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 369                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 370                })?
 371            });
 372            // Build a string from the rope on a background thread.
 373            cx.background_spawn(async move {
 374                let (project_entry_id, rope) = rope_task.await?;
 375                anyhow::Ok(RulesFileContext {
 376                    path_in_worktree,
 377                    text: rope.to_string().trim().to_string(),
 378                    project_entry_id: project_entry_id.to_usize(),
 379                })
 380            })
 381        })
 382    }
 383
 384    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
 385        &self.prompt_store
 386    }
 387
 388    pub fn tools(&self) -> Entity<ToolWorkingSet> {
 389        self.tools.clone()
 390    }
 391
 392    /// Returns the number of threads.
 393    pub fn thread_count(&self) -> usize {
 394        self.threads.len()
 395    }
 396
 397    pub fn reverse_chronological_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
 398        // ordering is from "ORDER BY" in `list_threads`
 399        self.threads.iter()
 400    }
 401
    /// Intended to create a new thread for this store.
    ///
    /// # Panics
    ///
    /// Not implemented yet: the body is a bare `todo!()`, so every call panics.
    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Task<Result<Entity<Thread>>> {
        todo!()
        // Disabled sketch of an implementation:
        // cx.new(|cx| {
        //     Thread::new(
        //         self.project.clone(),
        //         self.tools.clone(),
        //         self.prompt_builder.clone(),
        //         self.project_context.clone(),
        //         cx,
        //     )
        // })
    }
 414
 415    pub fn open_thread(
 416        &self,
 417        id: &ThreadId,
 418        window: &mut Window,
 419        cx: &mut Context<Self>,
 420    ) -> Task<Result<Entity<Thread>>> {
 421        let id = id.clone();
 422        let database_future = ThreadsDatabase::global_future(cx);
 423        let this = cx.weak_entity();
 424        window.spawn(cx, async move |cx| {
 425            let database = database_future.await.map_err(|err| anyhow!(err))?;
 426            let thread = database
 427                .try_find_thread(id.clone())
 428                .await?
 429                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 430
 431            todo!();
 432            // let thread = this.update_in(cx, |this, window, cx| {
 433            //     cx.new(|cx| {
 434            //         Thread::deserialize(
 435            //             id.clone(),
 436            //             thread,
 437            //             this.project.clone(),
 438            //             this.tools.clone(),
 439            //             this.prompt_builder.clone(),
 440            //             this.project_context.clone(),
 441            //             Some(window),
 442            //             cx,
 443            //         )
 444            //     })
 445            // })?;
 446            // Ok(thread)
 447        })
 448    }
 449
    /// Intended to persist `thread` to the threads database.
    ///
    /// # Panics
    ///
    /// Not implemented yet: the body is a bare `todo!()`, so every call panics.
    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
        todo!()
        // Disabled sketch of an implementation:
        // let (metadata, serialized_thread) =
        //     thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));

        // let database_future = ThreadsDatabase::global_future(cx);
        // cx.spawn(async move |this, cx| {
        //     let serialized_thread = serialized_thread.await?;
        //     let database = database_future.await.map_err(|err| anyhow!(err))?;
        //     database.save_thread(metadata, serialized_thread).await?;

        //     this.update(cx, |this, cx| this.reload(cx))?.await
        // })
    }
 464
    /// Intended to delete the thread with the given id from the threads database and from the
    /// in-memory list.
    ///
    /// # Panics
    ///
    /// Not implemented yet: the body is a bare `todo!()`, so every call panics.
    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
        todo!()
        // Disabled sketch of an implementation:
        // let id = id.clone();
        // let database_future = ThreadsDatabase::global_future(cx);
        // cx.spawn(async move |this, cx| {
        //     let database = database_future.await.map_err(|err| anyhow!(err))?;
        //     database.delete_thread(id.clone()).await?;

        //     this.update(cx, |this, cx| {
        //         this.threads.retain(|thread| thread.id != id);
        //         cx.notify();
        //     })
        // })
    }
 479
 480    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 481        let database_future = ThreadsDatabase::global_future(cx);
 482        cx.spawn(async move |this, cx| {
 483            let threads = database_future
 484                .await
 485                .map_err(|err| anyhow!(err))?
 486                .list_threads()
 487                .await?;
 488
 489            this.update(cx, |this, cx| {
 490                this.threads = threads;
 491                cx.notify();
 492            })
 493        })
 494    }
 495
 496    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
 497        let context_server_store = self.project.read(cx).context_server_store();
 498        cx.subscribe(&context_server_store, Self::handle_context_server_event)
 499            .detach();
 500
 501        // Check for any servers that were already running before the handler was registered
 502        for server in context_server_store.read(cx).running_servers() {
 503            self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
 504        }
 505    }
 506
 507    fn handle_context_server_event(
 508        &mut self,
 509        context_server_store: Entity<ContextServerStore>,
 510        event: &project::context_server_store::Event,
 511        cx: &mut Context<Self>,
 512    ) {
 513        let tool_working_set = self.tools.clone();
 514        match event {
 515            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
 516                match status {
 517                    ContextServerStatus::Starting => {}
 518                    ContextServerStatus::Running => {
 519                        self.load_context_server_tools(server_id.clone(), context_server_store, cx);
 520                    }
 521                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
 522                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
 523                            tool_working_set.update(cx, |tool_working_set, _| {
 524                                tool_working_set.remove(&tool_ids);
 525                            });
 526                        }
 527                    }
 528                }
 529            }
 530        }
 531    }
 532
 533    fn load_context_server_tools(
 534        &self,
 535        server_id: ContextServerId,
 536        context_server_store: Entity<ContextServerStore>,
 537        cx: &mut Context<Self>,
 538    ) {
 539        let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
 540            return;
 541        };
 542        let tool_working_set = self.tools.clone();
 543        cx.spawn(async move |this, cx| {
 544            let Some(protocol) = server.client() else {
 545                return;
 546            };
 547
 548            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
 549                if let Some(response) = protocol
 550                    .request::<context_server::types::requests::ListTools>(())
 551                    .await
 552                    .log_err()
 553                {
 554                    let tool_ids = tool_working_set
 555                        .update(cx, |tool_working_set, _| {
 556                            response
 557                                .tools
 558                                .into_iter()
 559                                .map(|tool| {
 560                                    log::info!("registering context server tool: {:?}", tool.name);
 561                                    tool_working_set.insert(Arc::new(ContextServerTool::new(
 562                                        context_server_store.clone(),
 563                                        server.id(),
 564                                        tool,
 565                                    )))
 566                                })
 567                                .collect::<Vec<_>>()
 568                        })
 569                        .log_err();
 570
 571                    if let Some(tool_ids) = tool_ids {
 572                        this.update(cx, |this, _| {
 573                            this.context_server_tool_ids.insert(server_id, tool_ids);
 574                        })
 575                        .log_err();
 576                    }
 577                }
 578            }
 579        })
 580        .detach();
 581    }
 582}
 583
 584#[derive(Debug, Clone, Serialize, Deserialize)]
 585pub struct SerializedThreadMetadata {
 586    pub id: ThreadId,
 587    pub summary: SharedString,
 588    pub updated_at: DateTime<Utc>,
 589}
 590
 591#[derive(Serialize, Deserialize, Debug, PartialEq)]
 592pub struct SerializedThread {
 593    pub version: String,
 594    pub summary: SharedString,
 595    pub updated_at: DateTime<Utc>,
 596    pub messages: Vec<SerializedMessage>,
 597    #[serde(default)]
 598    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
 599    #[serde(default)]
 600    pub cumulative_token_usage: TokenUsage,
 601    #[serde(default)]
 602    pub request_token_usage: Vec<TokenUsage>,
 603    #[serde(default)]
 604    pub detailed_summary_state: DetailedSummaryState,
 605    #[serde(default)]
 606    pub exceeded_window_error: Option<ExceededWindowError>,
 607    #[serde(default)]
 608    pub model: Option<SerializedLanguageModel>,
 609    #[serde(default)]
 610    pub completion_mode: Option<CompletionMode>,
 611    #[serde(default)]
 612    pub tool_use_limit_reached: bool,
 613    #[serde(default)]
 614    pub profile: Option<AgentProfileId>,
 615}
 616
 617#[derive(Serialize, Deserialize, Debug, PartialEq)]
 618pub struct SerializedLanguageModel {
 619    pub provider: String,
 620    pub model: String,
 621}
 622
 623impl SerializedThread {
 624    pub const VERSION: &'static str = "0.2.0";
 625
 626    pub fn from_json(json: &[u8]) -> Result<Self> {
 627        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 628        match saved_thread_json.get("version") {
 629            Some(serde_json::Value::String(version)) => match version.as_str() {
 630                SerializedThreadV0_1_0::VERSION => {
 631                    let saved_thread =
 632                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
 633                    Ok(saved_thread.upgrade())
 634                }
 635                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
 636                    saved_thread_json,
 637                )?),
 638                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 639            },
 640            None => {
 641                let saved_thread =
 642                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
 643                Ok(saved_thread.upgrade())
 644            }
 645            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 646        }
 647    }
 648}
 649
 650#[derive(Serialize, Deserialize, Debug)]
 651pub struct SerializedThreadV0_1_0(
 652    // The structure did not change, so we are reusing the latest SerializedThread.
 653    // When making the next version, make sure this points to SerializedThreadV0_2_0
 654    SerializedThread,
 655);
 656
 657impl SerializedThreadV0_1_0 {
 658    pub const VERSION: &'static str = "0.1.0";
 659
 660    pub fn upgrade(self) -> SerializedThread {
 661        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
 662
 663        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
 664
 665        for message in self.0.messages {
 666            if message.role == Role::User && !message.tool_results.is_empty() {
 667                if let Some(last_message) = messages.last_mut() {
 668                    debug_assert!(last_message.role == Role::Assistant);
 669
 670                    last_message.tool_results = message.tool_results;
 671                    continue;
 672                }
 673            }
 674
 675            messages.push(message);
 676        }
 677
 678        SerializedThread {
 679            messages,
 680            version: SerializedThread::VERSION.to_string(),
 681            ..self.0
 682        }
 683    }
 684}
 685
 686#[derive(Debug, Serialize, Deserialize, PartialEq)]
 687pub struct SerializedMessage {
 688    pub id: MessageId,
 689    pub role: Role,
 690    #[serde(default)]
 691    pub segments: Vec<SerializedMessageSegment>,
 692    #[serde(default)]
 693    pub tool_uses: Vec<SerializedToolUse>,
 694    #[serde(default)]
 695    pub tool_results: Vec<SerializedToolResult>,
 696    #[serde(default)]
 697    pub context: String,
 698    #[serde(default)]
 699    pub creases: Vec<SerializedCrease>,
 700    #[serde(default)]
 701    pub is_hidden: bool,
 702}
 703
 704#[derive(Debug, Serialize, Deserialize, PartialEq)]
 705#[serde(tag = "type")]
 706pub enum SerializedMessageSegment {
 707    #[serde(rename = "text")]
 708    Text {
 709        text: String,
 710    },
 711    #[serde(rename = "thinking")]
 712    Thinking {
 713        text: String,
 714        #[serde(skip_serializing_if = "Option::is_none")]
 715        signature: Option<String>,
 716    },
 717    RedactedThinking {
 718        data: String,
 719    },
 720}
 721
 722#[derive(Debug, Serialize, Deserialize, PartialEq)]
 723pub struct SerializedToolUse {
 724    pub id: LanguageModelToolUseId,
 725    pub name: SharedString,
 726    pub input: serde_json::Value,
 727}
 728
 729#[derive(Debug, Serialize, Deserialize, PartialEq)]
 730pub struct SerializedToolResult {
 731    pub tool_use_id: LanguageModelToolUseId,
 732    pub is_error: bool,
 733    pub content: LanguageModelToolResultContent,
 734    pub output: Option<serde_json::Value>,
 735}
 736
 737#[derive(Serialize, Deserialize)]
 738struct LegacySerializedThread {
 739    pub summary: SharedString,
 740    pub updated_at: DateTime<Utc>,
 741    pub messages: Vec<LegacySerializedMessage>,
 742    #[serde(default)]
 743    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
 744}
 745
 746impl LegacySerializedThread {
 747    pub fn upgrade(self) -> SerializedThread {
 748        SerializedThread {
 749            version: SerializedThread::VERSION.to_string(),
 750            summary: self.summary,
 751            updated_at: self.updated_at,
 752            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
 753            initial_project_snapshot: self.initial_project_snapshot,
 754            cumulative_token_usage: TokenUsage::default(),
 755            request_token_usage: Vec::new(),
 756            detailed_summary_state: DetailedSummaryState::default(),
 757            exceeded_window_error: None,
 758            model: None,
 759            completion_mode: None,
 760            tool_use_limit_reached: false,
 761            profile: None,
 762        }
 763    }
 764}
 765
 766#[derive(Debug, Serialize, Deserialize)]
 767struct LegacySerializedMessage {
 768    pub id: MessageId,
 769    pub role: Role,
 770    pub text: String,
 771    #[serde(default)]
 772    pub tool_uses: Vec<SerializedToolUse>,
 773    #[serde(default)]
 774    pub tool_results: Vec<SerializedToolResult>,
 775}
 776
 777impl LegacySerializedMessage {
 778    fn upgrade(self) -> SerializedMessage {
 779        SerializedMessage {
 780            id: self.id,
 781            role: self.role,
 782            segments: vec![SerializedMessageSegment::Text { text: self.text }],
 783            tool_uses: self.tool_uses,
 784            tool_results: self.tool_results,
 785            context: String::new(),
 786            creases: Vec::new(),
 787            is_hidden: false,
 788        }
 789    }
 790}
 791
 792#[derive(Debug, Serialize, Deserialize, PartialEq)]
 793pub struct SerializedCrease {
 794    pub start: usize,
 795    pub end: usize,
 796    pub icon_path: SharedString,
 797    pub label: SharedString,
 798}
 799
/// App-global wrapper around a shared future that resolves to the [`ThreadsDatabase`], or to
/// the error produced while opening it. The future is clonable, so every caller can await
/// the same result.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);
 803
/// Marks [`GlobalThreadsDatabase`] as a value that can be stored as an app global.
impl Global for GlobalThreadsDatabase {}
 805
/// Handle to the SQLite-backed storage for threads.
pub(crate) struct ThreadsDatabase {
    /// Background executor stored at construction.
    executor: BackgroundExecutor,
    /// Database connection, shared behind a mutex.
    connection: Arc<Mutex<Connection>>,
}
 810
 811impl ThreadsDatabase {
 812    fn connection(&self) -> Arc<Mutex<Connection>> {
 813        self.connection.clone()
 814    }
 815
 816    const COMPRESSION_LEVEL: i32 = 3;
 817}
 818
 819impl Bind for ThreadId {
 820    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
 821        self.to_string().bind(statement, start_index)
 822    }
 823}
 824
 825impl Column for ThreadId {
 826    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
 827        let (id_str, next_index) = String::column(statement, start_index)?;
 828        Ok((ThreadId::from(id_str.as_str()), next_index))
 829    }
 830}
 831
 832impl ThreadsDatabase {
 833    fn global_future(
 834        cx: &mut App,
 835    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
 836        GlobalThreadsDatabase::global(cx).0.clone()
 837    }
 838
 839    fn init(cx: &mut App) {
 840        let executor = cx.background_executor().clone();
 841        let database_future = executor
 842            .spawn({
 843                let executor = executor.clone();
 844                let threads_dir = paths::data_dir().join("threads");
 845                async move { ThreadsDatabase::new(threads_dir, executor) }
 846            })
 847            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
 848            .boxed()
 849            .shared();
 850
 851        cx.set_global(GlobalThreadsDatabase(database_future));
 852    }
 853
 854    pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
 855        std::fs::create_dir_all(&threads_dir)?;
 856
 857        let sqlite_path = threads_dir.join("threads.db");
 858        let mdb_path = threads_dir.join("threads-db.1.mdb");
 859
 860        let needs_migration_from_heed = mdb_path.exists();
 861
 862        let connection = Connection::open_file(&sqlite_path.to_string_lossy());
 863
 864        connection.exec(indoc! {"
 865                CREATE TABLE IF NOT EXISTS threads (
 866                    id TEXT PRIMARY KEY,
 867                    summary TEXT NOT NULL,
 868                    updated_at TEXT NOT NULL,
 869                    data_type TEXT NOT NULL,
 870                    data BLOB NOT NULL
 871                )
 872            "})?()
 873        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
 874
 875        let db = Self {
 876            executor: executor.clone(),
 877            connection: Arc::new(Mutex::new(connection)),
 878        };
 879
 880        if needs_migration_from_heed {
 881            let db_connection = db.connection();
 882            let executor_clone = executor.clone();
 883            executor
 884                .spawn(async move {
 885                    log::info!("Starting threads.db migration");
 886                    Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
 887                    std::fs::remove_dir_all(mdb_path)?;
 888                    log::info!("threads.db migrated to sqlite");
 889                    Ok::<(), anyhow::Error>(())
 890                })
 891                .detach();
 892        }
 893
 894        Ok(db)
 895    }
 896
 897    // Remove this migration after 2025-09-01
 898    fn migrate_from_heed(
 899        mdb_path: &Path,
 900        connection: Arc<Mutex<Connection>>,
 901        _executor: BackgroundExecutor,
 902    ) -> Result<()> {
 903        use heed::types::SerdeBincode;
 904        struct SerializedThreadHeed(SerializedThread);
 905
 906        impl heed::BytesEncode<'_> for SerializedThreadHeed {
 907            type EItem = SerializedThreadHeed;
 908
 909            fn bytes_encode(
 910                item: &Self::EItem,
 911            ) -> Result<std::borrow::Cow<'_, [u8]>, heed::BoxedError> {
 912                serde_json::to_vec(&item.0)
 913                    .map(std::borrow::Cow::Owned)
 914                    .map_err(Into::into)
 915            }
 916        }
 917
 918        impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
 919            type DItem = SerializedThreadHeed;
 920
 921            fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
 922                SerializedThread::from_json(bytes)
 923                    .map(SerializedThreadHeed)
 924                    .map_err(Into::into)
 925            }
 926        }
 927
 928        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
 929
 930        let env = unsafe {
 931            heed::EnvOpenOptions::new()
 932                .map_size(ONE_GB_IN_BYTES)
 933                .max_dbs(1)
 934                .open(mdb_path)?
 935        };
 936
 937        let txn = env.write_txn()?;
 938        let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
 939            .open_database(&txn, Some("threads"))?
 940            .ok_or_else(|| anyhow!("threads database not found"))?;
 941
 942        for result in threads.iter(&txn)? {
 943            let (thread_id, thread_heed) = result?;
 944            Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
 945        }
 946
 947        Ok(())
 948    }
 949
 950    fn save_thread_sync(
 951        connection: &Arc<Mutex<Connection>>,
 952        id: ThreadId,
 953        thread: SerializedThread,
 954    ) -> Result<()> {
 955        let json_data = serde_json::to_string(&thread)?;
 956        let summary = thread.summary.to_string();
 957        let updated_at = thread.updated_at.to_rfc3339();
 958
 959        let connection = connection.lock().unwrap();
 960
 961        let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
 962        let data_type = DataType::Zstd;
 963        let data = compressed;
 964
 965        let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
 966            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
 967        "})?;
 968
 969        insert((id, summary, updated_at, data_type, data))?;
 970
 971        Ok(())
 972    }
 973
 974    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
 975        let connection = self.connection.clone();
 976
 977        self.executor.spawn(async move {
 978            let connection = connection.lock().unwrap();
 979            let mut select =
 980                connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
 981                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
 982            "})?;
 983
 984            let rows = select(())?;
 985            let mut threads = Vec::new();
 986
 987            for (id, summary, updated_at) in rows {
 988                threads.push(SerializedThreadMetadata {
 989                    id,
 990                    summary: summary.into(),
 991                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
 992                });
 993            }
 994
 995            Ok(threads)
 996        })
 997    }
 998
 999    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
1000        let connection = self.connection.clone();
1001
1002        self.executor.spawn(async move {
1003            let connection = connection.lock().unwrap();
1004            let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
1005                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
1006            "})?;
1007
1008            let rows = select(id)?;
1009            if let Some((data_type, data)) = rows.into_iter().next() {
1010                let json_data = match data_type {
1011                    DataType::Zstd => {
1012                        let decompressed = zstd::decode_all(&data[..])?;
1013                        String::from_utf8(decompressed)?
1014                    }
1015                    DataType::Json => String::from_utf8(data)?,
1016                };
1017
1018                let thread = SerializedThread::from_json(json_data.as_bytes())?;
1019                Ok(Some(thread))
1020            } else {
1021                Ok(None)
1022            }
1023        })
1024    }
1025
1026    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
1027        let connection = self.connection.clone();
1028
1029        self.executor
1030            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
1031    }
1032
1033    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
1034        let connection = self.connection.clone();
1035
1036        self.executor.spawn(async move {
1037            let connection = connection.lock().unwrap();
1038
1039            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
1040                DELETE FROM threads WHERE id = ?
1041            "})?;
1042
1043            delete(id)?;
1044
1045            Ok(())
1046        })
1047    }
1048}
1049
1050#[cfg(test)]
1051mod tests {
1052    use super::*;
1053    use crate::{MessageId, thread::DetailedSummaryState};
1054    use chrono::Utc;
1055    use language_model::{Role, TokenUsage};
1056    use pretty_assertions::assert_eq;
1057
1058    #[test]
1059    fn test_legacy_serialized_thread_upgrade() {
1060        let updated_at = Utc::now();
1061        let legacy_thread = LegacySerializedThread {
1062            summary: "Test conversation".into(),
1063            updated_at,
1064            messages: vec![LegacySerializedMessage {
1065                id: MessageId(1),
1066                role: Role::User,
1067                text: "Hello, world!".to_string(),
1068                tool_uses: vec![],
1069                tool_results: vec![],
1070            }],
1071            initial_project_snapshot: None,
1072        };
1073
1074        let upgraded = legacy_thread.upgrade();
1075
1076        assert_eq!(
1077            upgraded,
1078            SerializedThread {
1079                summary: "Test conversation".into(),
1080                updated_at,
1081                messages: vec![SerializedMessage {
1082                    id: MessageId(1),
1083                    role: Role::User,
1084                    segments: vec![SerializedMessageSegment::Text {
1085                        text: "Hello, world!".to_string()
1086                    }],
1087                    tool_uses: vec![],
1088                    tool_results: vec![],
1089                    context: "".to_string(),
1090                    creases: vec![],
1091                    is_hidden: false
1092                }],
1093                version: SerializedThread::VERSION.to_string(),
1094                initial_project_snapshot: None,
1095                cumulative_token_usage: TokenUsage::default(),
1096                request_token_usage: vec![],
1097                detailed_summary_state: DetailedSummaryState::default(),
1098                exceeded_window_error: None,
1099                model: None,
1100                completion_mode: None,
1101                tool_use_limit_reached: false,
1102                profile: None
1103            }
1104        )
1105    }
1106
1107    #[test]
1108    fn test_serialized_threadv0_1_0_upgrade() {
1109        let updated_at = Utc::now();
1110        let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread {
1111            summary: "Test conversation".into(),
1112            updated_at,
1113            messages: vec![
1114                SerializedMessage {
1115                    id: MessageId(1),
1116                    role: Role::User,
1117                    segments: vec![SerializedMessageSegment::Text {
1118                        text: "Use tool_1".to_string(),
1119                    }],
1120                    tool_uses: vec![],
1121                    tool_results: vec![],
1122                    context: "".to_string(),
1123                    creases: vec![],
1124                    is_hidden: false,
1125                },
1126                SerializedMessage {
1127                    id: MessageId(2),
1128                    role: Role::Assistant,
1129                    segments: vec![SerializedMessageSegment::Text {
1130                        text: "I want to use a tool".to_string(),
1131                    }],
1132                    tool_uses: vec![SerializedToolUse {
1133                        id: "abc".into(),
1134                        name: "tool_1".into(),
1135                        input: serde_json::Value::Null,
1136                    }],
1137                    tool_results: vec![],
1138                    context: "".to_string(),
1139                    creases: vec![],
1140                    is_hidden: false,
1141                },
1142                SerializedMessage {
1143                    id: MessageId(1),
1144                    role: Role::User,
1145                    segments: vec![SerializedMessageSegment::Text {
1146                        text: "Here is the tool result".to_string(),
1147                    }],
1148                    tool_uses: vec![],
1149                    tool_results: vec![SerializedToolResult {
1150                        tool_use_id: "abc".into(),
1151                        is_error: false,
1152                        content: LanguageModelToolResultContent::Text("abcdef".into()),
1153                        output: Some(serde_json::Value::Null),
1154                    }],
1155                    context: "".to_string(),
1156                    creases: vec![],
1157                    is_hidden: false,
1158                },
1159            ],
1160            version: SerializedThreadV0_1_0::VERSION.to_string(),
1161            initial_project_snapshot: None,
1162            cumulative_token_usage: TokenUsage::default(),
1163            request_token_usage: vec![],
1164            detailed_summary_state: DetailedSummaryState::default(),
1165            exceeded_window_error: None,
1166            model: None,
1167            completion_mode: None,
1168            tool_use_limit_reached: false,
1169            profile: None,
1170        });
1171        let upgraded = thread_v0_1_0.upgrade();
1172
1173        assert_eq!(
1174            upgraded,
1175            SerializedThread {
1176                summary: "Test conversation".into(),
1177                updated_at,
1178                messages: vec![
1179                    SerializedMessage {
1180                        id: MessageId(1),
1181                        role: Role::User,
1182                        segments: vec![SerializedMessageSegment::Text {
1183                            text: "Use tool_1".to_string()
1184                        }],
1185                        tool_uses: vec![],
1186                        tool_results: vec![],
1187                        context: "".to_string(),
1188                        creases: vec![],
1189                        is_hidden: false
1190                    },
1191                    SerializedMessage {
1192                        id: MessageId(2),
1193                        role: Role::Assistant,
1194                        segments: vec![SerializedMessageSegment::Text {
1195                            text: "I want to use a tool".to_string(),
1196                        }],
1197                        tool_uses: vec![SerializedToolUse {
1198                            id: "abc".into(),
1199                            name: "tool_1".into(),
1200                            input: serde_json::Value::Null,
1201                        }],
1202                        tool_results: vec![SerializedToolResult {
1203                            tool_use_id: "abc".into(),
1204                            is_error: false,
1205                            content: LanguageModelToolResultContent::Text("abcdef".into()),
1206                            output: Some(serde_json::Value::Null),
1207                        }],
1208                        context: "".to_string(),
1209                        creases: vec![],
1210                        is_hidden: false,
1211                    },
1212                ],
1213                version: SerializedThread::VERSION.to_string(),
1214                initial_project_snapshot: None,
1215                cumulative_token_usage: TokenUsage::default(),
1216                request_token_usage: vec![],
1217                detailed_summary_state: DetailedSummaryState::default(),
1218                exceeded_window_error: None,
1219                model: None,
1220                completion_mode: None,
1221                tool_use_limit_reached: false,
1222                profile: None
1223            }
1224        )
1225    }
1226}