thread_store.rs

   1use crate::{
   2    context_server_tool::ContextServerTool,
   3    thread::{
   4        DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
   5    },
   6};
   7use agent_settings::{AgentProfileId, CompletionMode};
   8use anyhow::{Context as _, Result, anyhow};
   9use assistant_tool::{Tool, ToolId, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::HashMap;
  12use context_server::ContextServerId;
  13use fs::{Fs, RemoveOptions};
  14use futures::{
  15    FutureExt as _, StreamExt as _,
  16    channel::{mpsc, oneshot},
  17    future::{self, BoxFuture, Shared},
  18};
  19use gpui::{
  20    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
  21    Subscription, Task, Window, prelude::*,
  22};
  23use indoc::indoc;
  24use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
  25use project::context_server_store::{ContextServerStatus, ContextServerStore};
  26use project::{Project, ProjectItem, ProjectPath, Worktree};
  27use prompt_store::{
  28    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
  29    UserRulesContext, WorktreeContext,
  30};
  31use serde::{Deserialize, Serialize};
  32use sqlez::{
  33    bindable::{Bind, Column},
  34    connection::Connection,
  35    statement::Statement,
  36};
  37use std::{
  38    cell::{Ref, RefCell},
  39    path::{Path, PathBuf},
  40    rc::Rc,
  41    sync::{Arc, Mutex},
  42};
  43use util::{ResultExt as _, rel_path::RelPath};
  44
  45use zed_env_vars::ZED_STATELESS;
  46
  47#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  48pub enum DataType {
  49    #[serde(rename = "json")]
  50    Json,
  51    #[serde(rename = "zstd")]
  52    Zstd,
  53}
  54
  55impl Bind for DataType {
  56    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
  57        let value = match self {
  58            DataType::Json => "json",
  59            DataType::Zstd => "zstd",
  60        };
  61        value.bind(statement, start_index)
  62    }
  63}
  64
  65impl Column for DataType {
  66    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
  67        let (value, next_index) = String::column(statement, start_index)?;
  68        let data_type = match value.as_str() {
  69            "json" => DataType::Json,
  70            "zstd" => DataType::Zstd,
  71            _ => anyhow::bail!("Unknown data type: {}", value),
  72        };
  73        Ok((data_type, next_index))
  74    }
  75}
  76
/// Worktree-root-relative paths of files that may hold project rules for the agent.
///
/// Order matters: `ThreadStore::load_worktree_rules_file` uses the first entry that exists as a
/// file in a worktree. An update to any of these paths triggers a system-prompt reload.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
  88
/// Registers the global threads database used by every [`ThreadStore`].
///
/// The database is opened on the background executor. Callers reach it through the shared
/// future stored as a global, rather than blocking here.
pub fn init(fs: Arc<dyn Fs>, cx: &mut App) {
    ThreadsDatabase::init(fs, cx);
}
  92
  93/// A system prompt shared by all threads created by this ThreadStore
  94#[derive(Clone, Default)]
  95pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
  96
  97impl SharedProjectContext {
  98    pub fn borrow(&self) -> Ref<'_, Option<ProjectContext>> {
  99        self.0.borrow()
 100    }
 101}
 102
/// Alias for `assistant_context::ContextStore`, named to match this crate's text-thread
/// terminology.
pub type TextThreadStore = assistant_context::ContextStore;
 104
/// Tracks the saved threads of a project and creates, opens, saves, and deletes [`Thread`]s.
///
/// The store also maintains the shared project context used for system prompts, and keeps
/// context-server tools registered in the tool working set.
pub struct ThreadStore {
    project: Entity<Project>,
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    prompt_store: Option<Entity<PromptStore>>,
    /// Tool ids that were added to `tools` for each context server, kept so they can be
    /// removed when that server stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Metadata for saved threads, in the order the database listing returned them.
    threads: Vec<SerializedThreadMetadata>,
    /// Context handed to every thread this store creates or opens.
    project_context: SharedProjectContext,
    /// Wakes `_reload_system_prompt_task` to rebuild the project context.
    reload_system_prompt_tx: mpsc::Sender<()>,
    // NOTE: fields drop in declaration order, so the reload task is dropped before the
    // subscriptions.
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}
 117
/// Event emitted by [`ThreadStore`] when a worktree rules file or a default user rule fails to
/// load while rebuilding the system prompt.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}
 123
 124impl ThreadStore {
 125    pub fn load(
 126        project: Entity<Project>,
 127        tools: Entity<ToolWorkingSet>,
 128        prompt_store: Option<Entity<PromptStore>>,
 129        prompt_builder: Arc<PromptBuilder>,
 130        cx: &mut App,
 131    ) -> Task<Result<Entity<Self>>> {
 132        cx.spawn(async move |cx| {
 133            let (thread_store, ready_rx) = cx.update(|cx| {
 134                let mut option_ready_rx = None;
 135                let thread_store = cx.new(|cx| {
 136                    let (thread_store, ready_rx) =
 137                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 138                    option_ready_rx = Some(ready_rx);
 139                    thread_store
 140                });
 141                (thread_store, option_ready_rx.take().unwrap())
 142            })?;
 143            ready_rx.await?;
 144            Ok(thread_store)
 145        })
 146    }
 147
 148    fn new(
 149        project: Entity<Project>,
 150        tools: Entity<ToolWorkingSet>,
 151        prompt_builder: Arc<PromptBuilder>,
 152        prompt_store: Option<Entity<PromptStore>>,
 153        cx: &mut Context<Self>,
 154    ) -> (Self, oneshot::Receiver<()>) {
 155        let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
 156
 157        if let Some(prompt_store) = prompt_store.as_ref() {
 158            subscriptions.push(cx.subscribe(
 159                prompt_store,
 160                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
 161                    this.enqueue_system_prompt_reload();
 162                },
 163            ))
 164        }
 165
 166        // This channel and task prevent concurrent and redundant loading of the system prompt.
 167        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
 168        let (ready_tx, ready_rx) = oneshot::channel();
 169        let mut ready_tx = Some(ready_tx);
 170        let reload_system_prompt_task = cx.spawn({
 171            let prompt_store = prompt_store.clone();
 172            async move |thread_store, cx| {
 173                loop {
 174                    let Some(reload_task) = thread_store
 175                        .update(cx, |thread_store, cx| {
 176                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
 177                        })
 178                        .ok()
 179                    else {
 180                        return;
 181                    };
 182                    reload_task.await;
 183                    if let Some(ready_tx) = ready_tx.take() {
 184                        ready_tx.send(()).ok();
 185                    }
 186                    reload_system_prompt_rx.next().await;
 187                }
 188            }
 189        });
 190
 191        let this = Self {
 192            project,
 193            tools,
 194            prompt_builder,
 195            prompt_store,
 196            context_server_tool_ids: HashMap::default(),
 197            threads: Vec::new(),
 198            project_context: SharedProjectContext::default(),
 199            reload_system_prompt_tx,
 200            _reload_system_prompt_task: reload_system_prompt_task,
 201            _subscriptions: subscriptions,
 202        };
 203        this.register_context_server_handlers(cx);
 204        this.reload(cx).detach_and_log_err(cx);
 205        (this, ready_rx)
 206    }
 207
 208    #[cfg(any(test, feature = "test-support"))]
 209    pub fn fake(project: Entity<Project>, cx: &mut App) -> Self {
 210        Self {
 211            project,
 212            tools: cx.new(|_| ToolWorkingSet::default()),
 213            prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
 214            prompt_store: None,
 215            context_server_tool_ids: HashMap::default(),
 216            threads: Vec::new(),
 217            project_context: SharedProjectContext::default(),
 218            reload_system_prompt_tx: mpsc::channel(0).0,
 219            _reload_system_prompt_task: Task::ready(()),
 220            _subscriptions: vec![],
 221        }
 222    }
 223
 224    fn handle_project_event(
 225        &mut self,
 226        _project: Entity<Project>,
 227        event: &project::Event,
 228        _cx: &mut Context<Self>,
 229    ) {
 230        match event {
 231            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 232                self.enqueue_system_prompt_reload();
 233            }
 234            project::Event::WorktreeUpdatedEntries(_, items) => {
 235                if items.iter().any(|(path, _, _)| {
 236                    RULES_FILE_NAMES
 237                        .iter()
 238                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 239                }) {
 240                    self.enqueue_system_prompt_reload();
 241                }
 242            }
 243            _ => {}
 244        }
 245    }
 246
 247    fn enqueue_system_prompt_reload(&mut self) {
 248        self.reload_system_prompt_tx.try_send(()).ok();
 249    }
 250
 251    // Note that this should only be called from `reload_system_prompt_task`.
 252    fn reload_system_prompt(
 253        &self,
 254        prompt_store: Option<Entity<PromptStore>>,
 255        cx: &mut Context<Self>,
 256    ) -> Task<()> {
 257        let worktrees = self
 258            .project
 259            .read(cx)
 260            .visible_worktrees(cx)
 261            .collect::<Vec<_>>();
 262        let worktree_tasks = worktrees
 263            .into_iter()
 264            .map(|worktree| {
 265                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
 266            })
 267            .collect::<Vec<_>>();
 268        let default_user_rules_task = match prompt_store {
 269            None => Task::ready(vec![]),
 270            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
 271                let prompts = prompt_store.default_prompt_metadata();
 272                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 273                    let contents = prompt_store.load(prompt_metadata.id, cx);
 274                    async move { (contents.await, prompt_metadata) }
 275                });
 276                cx.background_spawn(future::join_all(load_tasks))
 277            }),
 278        };
 279
 280        cx.spawn(async move |this, cx| {
 281            let (worktrees, default_user_rules) =
 282                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 283
 284            let worktrees = worktrees
 285                .into_iter()
 286                .map(|(worktree, rules_error)| {
 287                    if let Some(rules_error) = rules_error {
 288                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 289                    }
 290                    worktree
 291                })
 292                .collect::<Vec<_>>();
 293
 294            let default_user_rules = default_user_rules
 295                .into_iter()
 296                .flat_map(|(contents, prompt_metadata)| match contents {
 297                    Ok(contents) => Some(UserRulesContext {
 298                        uuid: match prompt_metadata.id {
 299                            PromptId::User { uuid } => uuid,
 300                            PromptId::EditWorkflow => return None,
 301                        },
 302                        title: prompt_metadata.title.map(|title| title.to_string()),
 303                        contents,
 304                    }),
 305                    Err(err) => {
 306                        this.update(cx, |_, cx| {
 307                            cx.emit(RulesLoadingError {
 308                                message: format!("{err:?}").into(),
 309                            });
 310                        })
 311                        .ok();
 312                        None
 313                    }
 314                })
 315                .collect::<Vec<_>>();
 316
 317            this.update(cx, |this, _cx| {
 318                *this.project_context.0.borrow_mut() =
 319                    Some(ProjectContext::new(worktrees, default_user_rules));
 320            })
 321            .ok();
 322        })
 323    }
 324
 325    fn load_worktree_info_for_system_prompt(
 326        worktree: Entity<Worktree>,
 327        project: Entity<Project>,
 328        cx: &mut App,
 329    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 330        let tree = worktree.read(cx);
 331        let root_name = tree.root_name_str().into();
 332        let abs_path = tree.abs_path();
 333
 334        let mut context = WorktreeContext {
 335            root_name,
 336            abs_path,
 337            rules_file: None,
 338        };
 339
 340        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 341        let Some(rules_task) = rules_task else {
 342            return Task::ready((context, None));
 343        };
 344
 345        cx.spawn(async move |_| {
 346            let (rules_file, rules_file_error) = match rules_task.await {
 347                Ok(rules_file) => (Some(rules_file), None),
 348                Err(err) => (
 349                    None,
 350                    Some(RulesLoadingError {
 351                        message: format!("{err}").into(),
 352                    }),
 353                ),
 354            };
 355            context.rules_file = rules_file;
 356            (context, rules_file_error)
 357        })
 358    }
 359
 360    fn load_worktree_rules_file(
 361        worktree: Entity<Worktree>,
 362        project: Entity<Project>,
 363        cx: &mut App,
 364    ) -> Option<Task<Result<RulesFileContext>>> {
 365        let worktree = worktree.read(cx);
 366        let worktree_id = worktree.id();
 367        let selected_rules_file = RULES_FILE_NAMES
 368            .into_iter()
 369            .filter_map(|name| {
 370                worktree
 371                    .entry_for_path(RelPath::unix(name).unwrap())
 372                    .filter(|entry| entry.is_file())
 373                    .map(|entry| entry.path.clone())
 374            })
 375            .next();
 376
 377        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 378        // supported. This doesn't seem to occur often in GitHub repositories.
 379        selected_rules_file.map(|path_in_worktree| {
 380            let project_path = ProjectPath {
 381                worktree_id,
 382                path: path_in_worktree.clone(),
 383            };
 384            let buffer_task =
 385                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 386            let rope_task = cx.spawn(async move |cx| {
 387                buffer_task.await?.read_with(cx, |buffer, cx| {
 388                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 389                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 390                })?
 391            });
 392            // Build a string from the rope on a background thread.
 393            cx.background_spawn(async move {
 394                let (project_entry_id, rope) = rope_task.await?;
 395                anyhow::Ok(RulesFileContext {
 396                    path_in_worktree,
 397                    text: rope.to_string().trim().to_string(),
 398                    project_entry_id: project_entry_id.to_usize(),
 399                })
 400            })
 401        })
 402    }
 403
 404    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
 405        &self.prompt_store
 406    }
 407
 408    pub fn tools(&self) -> Entity<ToolWorkingSet> {
 409        self.tools.clone()
 410    }
 411
 412    /// Returns the number of threads.
 413    pub fn thread_count(&self) -> usize {
 414        self.threads.len()
 415    }
 416
 417    pub fn reverse_chronological_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
 418        // ordering is from "ORDER BY" in `list_threads`
 419        self.threads.iter()
 420    }
 421
 422    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 423        cx.new(|cx| {
 424            Thread::new(
 425                self.project.clone(),
 426                self.tools.clone(),
 427                self.prompt_builder.clone(),
 428                self.project_context.clone(),
 429                cx,
 430            )
 431        })
 432    }
 433
 434    pub fn create_thread_from_serialized(
 435        &mut self,
 436        serialized: SerializedThread,
 437        cx: &mut Context<Self>,
 438    ) -> Entity<Thread> {
 439        cx.new(|cx| {
 440            Thread::deserialize(
 441                ThreadId::new(),
 442                serialized,
 443                self.project.clone(),
 444                self.tools.clone(),
 445                self.prompt_builder.clone(),
 446                self.project_context.clone(),
 447                None,
 448                cx,
 449            )
 450        })
 451    }
 452
 453    pub fn open_thread(
 454        &self,
 455        id: &ThreadId,
 456        window: &mut Window,
 457        cx: &mut Context<Self>,
 458    ) -> Task<Result<Entity<Thread>>> {
 459        let id = id.clone();
 460        let database_future = ThreadsDatabase::global_future(cx);
 461        let this = cx.weak_entity();
 462        window.spawn(cx, async move |cx| {
 463            let database = database_future.await.map_err(|err| anyhow!(err))?;
 464            let thread = database
 465                .try_find_thread(id.clone())
 466                .await?
 467                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 468
 469            let thread = this.update_in(cx, |this, window, cx| {
 470                cx.new(|cx| {
 471                    Thread::deserialize(
 472                        id.clone(),
 473                        thread,
 474                        this.project.clone(),
 475                        this.tools.clone(),
 476                        this.prompt_builder.clone(),
 477                        this.project_context.clone(),
 478                        Some(window),
 479                        cx,
 480                    )
 481                })
 482            })?;
 483
 484            Ok(thread)
 485        })
 486    }
 487
 488    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
 489        let (metadata, serialized_thread) =
 490            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
 491
 492        let database_future = ThreadsDatabase::global_future(cx);
 493        cx.spawn(async move |this, cx| {
 494            let serialized_thread = serialized_thread.await?;
 495            let database = database_future.await.map_err(|err| anyhow!(err))?;
 496            database.save_thread(metadata, serialized_thread).await?;
 497
 498            this.update(cx, |this, cx| this.reload(cx))?.await
 499        })
 500    }
 501
 502    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
 503        let id = id.clone();
 504        let database_future = ThreadsDatabase::global_future(cx);
 505        cx.spawn(async move |this, cx| {
 506            let database = database_future.await.map_err(|err| anyhow!(err))?;
 507            database.delete_thread(id.clone()).await?;
 508
 509            this.update(cx, |this, cx| {
 510                this.threads.retain(|thread| thread.id != id);
 511                cx.notify();
 512            })
 513        })
 514    }
 515
 516    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 517        let database_future = ThreadsDatabase::global_future(cx);
 518        cx.spawn(async move |this, cx| {
 519            let threads = database_future
 520                .await
 521                .map_err(|err| anyhow!(err))?
 522                .list_threads()
 523                .await?;
 524
 525            this.update(cx, |this, cx| {
 526                this.threads = threads;
 527                cx.notify();
 528            })
 529        })
 530    }
 531
 532    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
 533        let context_server_store = self.project.read(cx).context_server_store();
 534        cx.subscribe(&context_server_store, Self::handle_context_server_event)
 535            .detach();
 536
 537        // Check for any servers that were already running before the handler was registered
 538        for server in context_server_store.read(cx).running_servers() {
 539            self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
 540        }
 541    }
 542
 543    fn handle_context_server_event(
 544        &mut self,
 545        context_server_store: Entity<ContextServerStore>,
 546        event: &project::context_server_store::Event,
 547        cx: &mut Context<Self>,
 548    ) {
 549        let tool_working_set = self.tools.clone();
 550        match event {
 551            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
 552                match status {
 553                    ContextServerStatus::Starting => {}
 554                    ContextServerStatus::Running => {
 555                        self.load_context_server_tools(server_id.clone(), context_server_store, cx);
 556                    }
 557                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
 558                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
 559                            tool_working_set.update(cx, |tool_working_set, cx| {
 560                                tool_working_set.remove(&tool_ids, cx);
 561                            });
 562                        }
 563                    }
 564                }
 565            }
 566        }
 567    }
 568
 569    fn load_context_server_tools(
 570        &self,
 571        server_id: ContextServerId,
 572        context_server_store: Entity<ContextServerStore>,
 573        cx: &mut Context<Self>,
 574    ) {
 575        let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
 576            return;
 577        };
 578        let tool_working_set = self.tools.clone();
 579        cx.spawn(async move |this, cx| {
 580            let Some(protocol) = server.client() else {
 581                return;
 582            };
 583
 584            if protocol.capable(context_server::protocol::ServerCapability::Tools)
 585                && let Some(response) = protocol
 586                    .request::<context_server::types::requests::ListTools>(())
 587                    .await
 588                    .log_err()
 589            {
 590                let tool_ids = tool_working_set
 591                    .update(cx, |tool_working_set, cx| {
 592                        tool_working_set.extend(
 593                            response.tools.into_iter().map(|tool| {
 594                                Arc::new(ContextServerTool::new(
 595                                    context_server_store.clone(),
 596                                    server.id(),
 597                                    tool,
 598                                )) as Arc<dyn Tool>
 599                            }),
 600                            cx,
 601                        )
 602                    })
 603                    .log_err();
 604
 605                if let Some(tool_ids) = tool_ids {
 606                    this.update(cx, |this, _| {
 607                        this.context_server_tool_ids.insert(server_id, tool_ids);
 608                    })
 609                    .log_err();
 610                }
 611            }
 612        })
 613        .detach();
 614    }
 615}
 616
/// Lightweight summary of a saved thread, as held in `ThreadStore::threads`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
}
 623
/// Persisted form of a thread in the current format (see [`SerializedThread::VERSION`]).
///
/// Fields marked `#[serde(default)]` may be missing from threads saved by older versions.
/// When missing, they deserialize to their `Default` value.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedThread {
    /// Format version string; `from_json` uses it to pick the right decoder.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
 649
/// Identifies the language model a thread was using, as a provider/model name pair.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
 655
 656impl SerializedThread {
 657    pub const VERSION: &'static str = "0.2.0";
 658
 659    pub fn from_json(json: &[u8]) -> Result<Self> {
 660        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 661        match saved_thread_json.get("version") {
 662            Some(serde_json::Value::String(version)) => match version.as_str() {
 663                SerializedThreadV0_1_0::VERSION => {
 664                    let saved_thread =
 665                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
 666                    Ok(saved_thread.upgrade())
 667                }
 668                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
 669                    saved_thread_json,
 670                )?),
 671                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 672            },
 673            None => {
 674                let saved_thread =
 675                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
 676                Ok(saved_thread.upgrade())
 677            }
 678            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 679        }
 680    }
 681}
 682
/// Thread format v0.1.0.
///
/// Its fields are the same as the current format; `upgrade` only relocates tool results and
/// bumps the version.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we reuse the latest `SerializedThread`.
    // When introducing the next version, point this at `SerializedThreadV0_2_0` instead.
    SerializedThread,
);
 689
 690impl SerializedThreadV0_1_0 {
 691    pub const VERSION: &'static str = "0.1.0";
 692
 693    pub fn upgrade(self) -> SerializedThread {
 694        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
 695
 696        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
 697
 698        for message in self.0.messages {
 699            if message.role == Role::User
 700                && !message.tool_results.is_empty()
 701                && let Some(last_message) = messages.last_mut()
 702            {
 703                debug_assert!(last_message.role == Role::Assistant);
 704
 705                last_message.tool_results = message.tool_results;
 706                continue;
 707            }
 708
 709            messages.push(message);
 710        }
 711
 712        SerializedThread {
 713            messages,
 714            version: SerializedThread::VERSION.to_string(),
 715            ..self.0
 716        }
 717    }
 718}
 719
/// Persisted form of a single thread message.
///
/// All fields except `id` and `role` fall back to their `Default` value when absent.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    // NOTE(review): appears to hold the rendered context attached to the message — confirm.
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    #[serde(default)]
    pub is_hidden: bool,
}
 737
/// One piece of a message's content.
///
/// Serialized as an internally tagged object whose `"type"` field names the variant.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        /// Omitted from the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // No `rename` here: the tag is the literal variant name "RedactedThinking". Changing it
    // would break reading threads that were already saved.
    RedactedThinking {
        data: String,
    },
}
 755
/// A tool invocation requested by the model, with its raw JSON input.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
 762
/// The outcome of a tool use, linked to it by `tool_use_id`.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: LanguageModelToolResultContent,
    pub output: Option<serde_json::Value>,
}
 770
/// Thread format that predates the `"version"` field; `from_json` selects it when that key is
/// missing.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
 779
 780impl LegacySerializedThread {
 781    pub fn upgrade(self) -> SerializedThread {
 782        SerializedThread {
 783            version: SerializedThread::VERSION.to_string(),
 784            summary: self.summary,
 785            updated_at: self.updated_at,
 786            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
 787            initial_project_snapshot: self.initial_project_snapshot,
 788            cumulative_token_usage: TokenUsage::default(),
 789            request_token_usage: Vec::new(),
 790            detailed_summary_state: DetailedSummaryState::default(),
 791            exceeded_window_error: None,
 792            model: None,
 793            completion_mode: None,
 794            tool_use_limit_reached: false,
 795            profile: None,
 796        }
 797    }
 798}
 799
/// Message in the legacy (pre-versioning) thread format, whose content was a single text field.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
 810
 811impl LegacySerializedMessage {
 812    fn upgrade(self) -> SerializedMessage {
 813        SerializedMessage {
 814            id: self.id,
 815            role: self.role,
 816            segments: vec![SerializedMessageSegment::Text { text: self.text }],
 817            tool_uses: self.tool_uses,
 818            tool_results: self.tool_results,
 819            context: String::new(),
 820            creases: Vec::new(),
 821            is_hidden: false,
 822        }
 823    }
 824}
 825
/// A labeled span within a message, stored with its icon and label.
// NOTE(review): `start`/`end` look like offsets into the message text — confirm the unit.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedCrease {
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}
 833
/// App-global handle to the threads database, registered by `ThreadsDatabase::init`.
///
/// The database is opened asynchronously, so the handle is a shared future that every caller
/// can await. The error is wrapped in `Arc` so the shared output can be cloned.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
 839
/// Persistent storage for serialized threads.
pub(crate) struct ThreadsDatabase {
    /// Executor used for database work off the main thread.
    executor: BackgroundExecutor,
    /// Connection to the threads database (`threads.db`), guarded by a mutex.
    connection: Arc<Mutex<Connection>>,
}
 844
 845impl ThreadsDatabase {
 846    fn connection(&self) -> Arc<Mutex<Connection>> {
 847        self.connection.clone()
 848    }
 849
 850    const COMPRESSION_LEVEL: i32 = 3;
 851}
 852
 853impl Bind for ThreadId {
 854    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
 855        self.to_string().bind(statement, start_index)
 856    }
 857}
 858
 859impl Column for ThreadId {
 860    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
 861        let (id_str, next_index) = String::column(statement, start_index)?;
 862        Ok((ThreadId::from(id_str.as_str()), next_index))
 863    }
 864}
 865
 866impl ThreadsDatabase {
 867    fn global_future(
 868        cx: &mut App,
 869    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
 870        GlobalThreadsDatabase::global(cx).0.clone()
 871    }
 872
 873    fn init(fs: Arc<dyn Fs>, cx: &mut App) {
 874        let executor = cx.background_executor().clone();
 875        let database_future = executor
 876            .spawn({
 877                let executor = executor.clone();
 878                let threads_dir = paths::data_dir().join("threads");
 879                async move { ThreadsDatabase::new(fs, threads_dir, executor).await }
 880            })
 881            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
 882            .boxed()
 883            .shared();
 884
 885        cx.set_global(GlobalThreadsDatabase(database_future));
 886    }
 887
 888    pub async fn new(
 889        fs: Arc<dyn Fs>,
 890        threads_dir: PathBuf,
 891        executor: BackgroundExecutor,
 892    ) -> Result<Self> {
 893        fs.create_dir(&threads_dir).await?;
 894
 895        let sqlite_path = threads_dir.join("threads.db");
 896        let mdb_path = threads_dir.join("threads-db.1.mdb");
 897
 898        let needs_migration_from_heed = fs.is_file(&mdb_path).await;
 899
 900        let connection = if *ZED_STATELESS {
 901            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
 902        } else if cfg!(any(feature = "test-support", test)) {
 903            // rust stores the name of the test on the current thread.
 904            // We use this to automatically create a database that will
 905            // be shared within the test (for the test_retrieve_old_thread)
 906            // but not with concurrent tests.
 907            let thread = std::thread::current();
 908            let test_name = thread.name();
 909            Connection::open_memory(Some(&format!(
 910                "THREAD_FALLBACK_{}",
 911                test_name.unwrap_or_default()
 912            )))
 913        } else {
 914            Connection::open_file(&sqlite_path.to_string_lossy())
 915        };
 916
 917        connection.exec(indoc! {"
 918                CREATE TABLE IF NOT EXISTS threads (
 919                    id TEXT PRIMARY KEY,
 920                    summary TEXT NOT NULL,
 921                    updated_at TEXT NOT NULL,
 922                    data_type TEXT NOT NULL,
 923                    data BLOB NOT NULL
 924                )
 925            "})?()
 926        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
 927
 928        let db = Self {
 929            executor: executor.clone(),
 930            connection: Arc::new(Mutex::new(connection)),
 931        };
 932
 933        if needs_migration_from_heed {
 934            let db_connection = db.connection();
 935            let executor_clone = executor.clone();
 936            executor
 937                .spawn(async move {
 938                    log::info!("Starting threads.db migration");
 939                    Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
 940                    fs.remove_dir(
 941                        &mdb_path,
 942                        RemoveOptions {
 943                            recursive: true,
 944                            ignore_if_not_exists: true,
 945                        },
 946                    )
 947                    .await?;
 948                    log::info!("threads.db migrated to sqlite");
 949                    Ok::<(), anyhow::Error>(())
 950                })
 951                .detach();
 952        }
 953
 954        Ok(db)
 955    }
 956
 957    // Remove this migration after 2025-09-01
 958    fn migrate_from_heed(
 959        mdb_path: &Path,
 960        connection: Arc<Mutex<Connection>>,
 961        _executor: BackgroundExecutor,
 962    ) -> Result<()> {
 963        use heed::types::SerdeBincode;
 964        struct SerializedThreadHeed(SerializedThread);
 965
 966        impl heed::BytesEncode<'_> for SerializedThreadHeed {
 967            type EItem = SerializedThreadHeed;
 968
 969            fn bytes_encode(
 970                item: &Self::EItem,
 971            ) -> Result<std::borrow::Cow<'_, [u8]>, heed::BoxedError> {
 972                serde_json::to_vec(&item.0)
 973                    .map(std::borrow::Cow::Owned)
 974                    .map_err(Into::into)
 975            }
 976        }
 977
 978        impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
 979            type DItem = SerializedThreadHeed;
 980
 981            fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
 982                SerializedThread::from_json(bytes)
 983                    .map(SerializedThreadHeed)
 984                    .map_err(Into::into)
 985            }
 986        }
 987
 988        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
 989
 990        let env = unsafe {
 991            heed::EnvOpenOptions::new()
 992                .map_size(ONE_GB_IN_BYTES)
 993                .max_dbs(1)
 994                .open(mdb_path)?
 995        };
 996
 997        let txn = env.write_txn()?;
 998        let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
 999            .open_database(&txn, Some("threads"))?
1000            .ok_or_else(|| anyhow!("threads database not found"))?;
1001
1002        for result in threads.iter(&txn)? {
1003            let (thread_id, thread_heed) = result?;
1004            Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
1005        }
1006
1007        Ok(())
1008    }
1009
1010    fn save_thread_sync(
1011        connection: &Arc<Mutex<Connection>>,
1012        id: ThreadId,
1013        thread: SerializedThread,
1014    ) -> Result<()> {
1015        let json_data = serde_json::to_string(&thread)?;
1016        let summary = thread.summary.to_string();
1017        let updated_at = thread.updated_at.to_rfc3339();
1018
1019        let connection = connection.lock().unwrap();
1020
1021        let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
1022        let data_type = DataType::Zstd;
1023        let data = compressed;
1024
1025        let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
1026            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
1027        "})?;
1028
1029        insert((id, summary, updated_at, data_type, data))?;
1030
1031        Ok(())
1032    }
1033
1034    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
1035        let connection = self.connection.clone();
1036
1037        self.executor.spawn(async move {
1038            let connection = connection.lock().unwrap();
1039            let mut select =
1040                connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
1041                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
1042            "})?;
1043
1044            let rows = select(())?;
1045            let mut threads = Vec::new();
1046
1047            for (id, summary, updated_at) in rows {
1048                threads.push(SerializedThreadMetadata {
1049                    id,
1050                    summary: summary.into(),
1051                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
1052                });
1053            }
1054
1055            Ok(threads)
1056        })
1057    }
1058
1059    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
1060        let connection = self.connection.clone();
1061
1062        self.executor.spawn(async move {
1063            let connection = connection.lock().unwrap();
1064            let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
1065                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
1066            "})?;
1067
1068            let rows = select(id)?;
1069            if let Some((data_type, data)) = rows.into_iter().next() {
1070                let json_data = match data_type {
1071                    DataType::Zstd => {
1072                        let decompressed = zstd::decode_all(&data[..])?;
1073                        String::from_utf8(decompressed)?
1074                    }
1075                    DataType::Json => String::from_utf8(data)?,
1076                };
1077
1078                let thread = SerializedThread::from_json(json_data.as_bytes())?;
1079                Ok(Some(thread))
1080            } else {
1081                Ok(None)
1082            }
1083        })
1084    }
1085
1086    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
1087        let connection = self.connection.clone();
1088
1089        self.executor
1090            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
1091    }
1092
1093    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
1094        let connection = self.connection.clone();
1095
1096        self.executor.spawn(async move {
1097            let connection = connection.lock().unwrap();
1098
1099            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
1100                DELETE FROM threads WHERE id = ?
1101            "})?;
1102
1103            delete(id)?;
1104
1105            Ok(())
1106        })
1107    }
1108}
1109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::thread::{DetailedSummaryState, MessageId};
    use chrono::Utc;
    use language_model::{Role, TokenUsage};
    use pretty_assertions::assert_eq;

    /// A legacy thread upgrades to the current format: the message text becomes
    /// a single text segment, and the fields the legacy format lacked get defaults.
    #[test]
    fn test_legacy_serialized_thread_upgrade() {
        let updated_at = Utc::now();
        let legacy_thread = LegacySerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![LegacySerializedMessage {
                id: MessageId(1),
                role: Role::User,
                text: "Hello, world!".to_string(),
                tool_uses: vec![],
                tool_results: vec![],
            }],
            initial_project_snapshot: None,
        };

        let upgraded = legacy_thread.upgrade();

        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Hello, world!".to_string()
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false
                }],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                exceeded_window_error: None,
                model: None,
                completion_mode: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }

    /// Upgrading a v0.1.0 thread moves tool results from the user message that
    /// follows a tool use onto the assistant message that made the call. That
    /// user message is dropped from the output.
    #[test]
    fn test_serialized_threadv0_1_0_upgrade() {
        let updated_at = Utc::now();
        let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![
                SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Use tool_1".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                SerializedMessage {
                    id: MessageId(2),
                    role: Role::Assistant,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "I want to use a tool".to_string(),
                    }],
                    tool_uses: vec![SerializedToolUse {
                        id: "abc".into(),
                        name: "tool_1".into(),
                        input: serde_json::Value::Null,
                    }],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                // Each message in the fixture gets its own id; this one is
                // merged away by the upgrade, so it does not appear in the
                // expected output.
                SerializedMessage {
                    id: MessageId(3),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Here is the tool result".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![SerializedToolResult {
                        tool_use_id: "abc".into(),
                        is_error: false,
                        content: LanguageModelToolResultContent::Text("abcdef".into()),
                        output: Some(serde_json::Value::Null),
                    }],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
            ],
            version: SerializedThreadV0_1_0::VERSION.to_string(),
            initial_project_snapshot: None,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: vec![],
            detailed_summary_state: DetailedSummaryState::default(),
            exceeded_window_error: None,
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
            profile: None,
        });
        let upgraded = thread_v0_1_0.upgrade();

        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![
                    SerializedMessage {
                        id: MessageId(1),
                        role: Role::User,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "Use tool_1".to_string()
                        }],
                        tool_uses: vec![],
                        tool_results: vec![],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false
                    },
                    SerializedMessage {
                        id: MessageId(2),
                        role: Role::Assistant,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "I want to use a tool".to_string(),
                        }],
                        tool_uses: vec![SerializedToolUse {
                            id: "abc".into(),
                            name: "tool_1".into(),
                            input: serde_json::Value::Null,
                        }],
                        tool_results: vec![SerializedToolResult {
                            tool_use_id: "abc".into(),
                            is_error: false,
                            content: LanguageModelToolResultContent::Text("abcdef".into()),
                            output: Some(serde_json::Value::Null),
                        }],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false,
                    },
                ],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                exceeded_window_error: None,
                model: None,
                completion_mode: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }
}