thread_store.rs

   1use crate::{
   2    context_server_tool::ContextServerTool,
   3    thread::{
   4        DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
   5    },
   6};
   7use agent_settings::{AgentProfileId, CompletionMode};
   8use anyhow::{Context as _, Result, anyhow};
   9use assistant_tool::{Tool, ToolId, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::HashMap;
  12use context_server::ContextServerId;
  13use fs::{Fs, RemoveOptions};
  14use futures::{
  15    FutureExt as _, StreamExt as _,
  16    channel::{mpsc, oneshot},
  17    future::{self, BoxFuture, Shared},
  18};
  19use gpui::{
  20    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
  21    Subscription, Task, Window, prelude::*,
  22};
  23use indoc::indoc;
  24use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
  25use project::context_server_store::{ContextServerStatus, ContextServerStore};
  26use project::{Project, ProjectItem, ProjectPath, Worktree};
  27use prompt_store::{
  28    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
  29    UserRulesContext, WorktreeContext,
  30};
  31use serde::{Deserialize, Serialize};
  32use sqlez::{
  33    bindable::{Bind, Column},
  34    connection::Connection,
  35    statement::Statement,
  36};
  37use std::{
  38    cell::{Ref, RefCell},
  39    path::{Path, PathBuf},
  40    rc::Rc,
  41    sync::{Arc, LazyLock, Mutex},
  42};
  43use util::{ResultExt as _, rel_path::RelPath};
  44
  45use zed_env_vars::ZED_STATELESS;
  46
  47#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  48pub enum DataType {
  49    #[serde(rename = "json")]
  50    Json,
  51    #[serde(rename = "zstd")]
  52    Zstd,
  53}
  54
  55impl Bind for DataType {
  56    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
  57        let value = match self {
  58            DataType::Json => "json",
  59            DataType::Zstd => "zstd",
  60        };
  61        value.bind(statement, start_index)
  62    }
  63}
  64
  65impl Column for DataType {
  66    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
  67        let (value, next_index) = String::column(statement, start_index)?;
  68        let data_type = match value.as_str() {
  69            "json" => DataType::Json,
  70            "zstd" => DataType::Zstd,
  71            _ => anyhow::bail!("Unknown data type: {}", value),
  72        };
  73        Ok((data_type, next_index))
  74    }
  75}
  76
  77static RULES_FILE_NAMES: LazyLock<[&RelPath; 9]> = LazyLock::new(|| {
  78    [
  79        RelPath::unix(".rules").unwrap(),
  80        RelPath::unix(".cursorrules").unwrap(),
  81        RelPath::unix(".windsurfrules").unwrap(),
  82        RelPath::unix(".clinerules").unwrap(),
  83        RelPath::unix(".github/copilot-instructions.md").unwrap(),
  84        RelPath::unix("CLAUDE.md").unwrap(),
  85        RelPath::unix("AGENT.md").unwrap(),
  86        RelPath::unix("AGENTS.md").unwrap(),
  87        RelPath::unix("GEMINI.md").unwrap(),
  88    ]
  89});
  90
  91pub fn init(fs: Arc<dyn Fs>, cx: &mut App) {
  92    ThreadsDatabase::init(fs, cx);
  93}
  94
  95/// A system prompt shared by all threads created by this ThreadStore
  96#[derive(Clone, Default)]
  97pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
  98
  99impl SharedProjectContext {
 100    pub fn borrow(&self) -> Ref<'_, Option<ProjectContext>> {
 101        self.0.borrow()
 102    }
 103}
 104
/// Alias for the text-thread store provided by `assistant_context`.
pub type TextThreadStore = assistant_context::ContextStore;

/// Per-project store of saved-thread metadata and shared system-prompt inputs.
///
/// It also keeps context-server tools registered in the tool working set while those servers
/// are running.
pub struct ThreadStore {
    project: Entity<Project>,
    /// Tool set handed to every thread created here; context-server tools are added to and
    /// removed from it.
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    /// Source of the default user rules included in the system prompt, if any.
    prompt_store: Option<Entity<PromptStore>>,
    /// Tool ids registered for each context server, so they can be removed when it stops.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Thread metadata as last listed from the database, kept in the database's order.
    threads: Vec<SerializedThreadMetadata>,
    /// System-prompt inputs shared (by handle) with every thread this store creates.
    project_context: SharedProjectContext,
    /// Requests a system-prompt reload; drained by the loop in `_reload_system_prompt_task`.
    reload_system_prompt_tx: mpsc::Sender<()>,
    /// Held so the reload loop lives as long as the store (presumably cancelled on drop).
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}

/// Event emitted when a worktree rules file or a default user rule fails to load.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}
 125
 126impl ThreadStore {
 127    pub fn load(
 128        project: Entity<Project>,
 129        tools: Entity<ToolWorkingSet>,
 130        prompt_store: Option<Entity<PromptStore>>,
 131        prompt_builder: Arc<PromptBuilder>,
 132        cx: &mut App,
 133    ) -> Task<Result<Entity<Self>>> {
 134        cx.spawn(async move |cx| {
 135            let (thread_store, ready_rx) = cx.update(|cx| {
 136                let mut option_ready_rx = None;
 137                let thread_store = cx.new(|cx| {
 138                    let (thread_store, ready_rx) =
 139                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 140                    option_ready_rx = Some(ready_rx);
 141                    thread_store
 142                });
 143                (thread_store, option_ready_rx.take().unwrap())
 144            })?;
 145            ready_rx.await?;
 146            Ok(thread_store)
 147        })
 148    }
 149
    /// Builds a `ThreadStore` and starts its background system-prompt reload loop.
    ///
    /// Returns the store together with a receiver that is signalled once the first
    /// system-prompt load has completed. `ThreadStore::load` awaits that receiver before
    /// returning the entity.
    fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> (Self, oneshot::Receiver<()>) {
        // Worktree additions/removals and rules-file changes queue a system-prompt reload.
        let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];

        // Changes to the user's prompts also invalidate the system prompt.
        if let Some(prompt_store) = prompt_store.as_ref() {
            subscriptions.push(cx.subscribe(
                prompt_store,
                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
                    this.enqueue_system_prompt_reload();
                },
            ))
        }

        // This channel and task prevent concurrent and redundant loading of the system prompt.
        // Producers only `try_send` into this bounded channel and ignore failures (see
        // `enqueue_system_prompt_reload`), so a burst of requests collapses into few reloads.
        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
        // Signalled once, after the first reload; wrapped in `Option` so the loop can `take` it.
        let (ready_tx, ready_rx) = oneshot::channel();
        let mut ready_tx = Some(ready_tx);
        let reload_system_prompt_task = cx.spawn({
            let prompt_store = prompt_store.clone();
            async move |thread_store, cx| {
                // Reload once immediately, then once per received request.
                loop {
                    // Stop once the store can no longer be updated (presumably it was released).
                    let Some(reload_task) = thread_store
                        .update(cx, |thread_store, cx| {
                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
                        })
                        .ok()
                    else {
                        return;
                    };
                    reload_task.await;
                    if let Some(ready_tx) = ready_tx.take() {
                        // The receiver may already be gone; that is not an error here.
                        ready_tx.send(()).ok();
                    }
                    // NOTE(review): a closed channel yields `None` immediately and the loop
                    // reloads again. The `update` above is what ends the task, so confirm the
                    // sender cannot be dropped while the store is alive.
                    reload_system_prompt_rx.next().await;
                }
            }
        });

        let this = Self {
            project,
            tools,
            prompt_builder,
            prompt_store,
            context_server_tool_ids: HashMap::default(),
            threads: Vec::new(),
            project_context: SharedProjectContext::default(),
            reload_system_prompt_tx,
            _reload_system_prompt_task: reload_system_prompt_task,
            _subscriptions: subscriptions,
        };
        // Context-server tools and the thread list load independently of the readiness signal.
        this.register_context_server_handlers(cx);
        this.reload(cx).detach_and_log_err(cx);
        (this, ready_rx)
    }
 209
 210    #[cfg(any(test, feature = "test-support"))]
 211    pub fn fake(project: Entity<Project>, cx: &mut App) -> Self {
 212        Self {
 213            project,
 214            tools: cx.new(|_| ToolWorkingSet::default()),
 215            prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
 216            prompt_store: None,
 217            context_server_tool_ids: HashMap::default(),
 218            threads: Vec::new(),
 219            project_context: SharedProjectContext::default(),
 220            reload_system_prompt_tx: mpsc::channel(0).0,
 221            _reload_system_prompt_task: Task::ready(()),
 222            _subscriptions: vec![],
 223        }
 224    }
 225
 226    fn handle_project_event(
 227        &mut self,
 228        _project: Entity<Project>,
 229        event: &project::Event,
 230        _cx: &mut Context<Self>,
 231    ) {
 232        match event {
 233            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 234                self.enqueue_system_prompt_reload();
 235            }
 236            project::Event::WorktreeUpdatedEntries(_, items) => {
 237                if items
 238                    .iter()
 239                    .any(|(path, _, _)| RULES_FILE_NAMES.iter().any(|name| path.as_ref() == *name))
 240                {
 241                    self.enqueue_system_prompt_reload();
 242                }
 243            }
 244            _ => {}
 245        }
 246    }
 247
 248    fn enqueue_system_prompt_reload(&mut self) {
 249        self.reload_system_prompt_tx.try_send(()).ok();
 250    }
 251
    // Note that this should only be called from `reload_system_prompt_task`.
    /// Rebuilds the shared `ProjectContext` used for the system prompt.
    ///
    /// It collects info (including rules files) for every visible worktree, plus the prompt
    /// store's default user rules, and then replaces `project_context` with the result.
    /// Individual load failures are emitted as `RulesLoadingError` events instead of failing
    /// the task.
    fn reload_system_prompt(
        &self,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> Task<()> {
        // Only visible worktrees contribute to the system prompt.
        let worktrees = self
            .project
            .read(cx)
            .visible_worktrees(cx)
            .collect::<Vec<_>>();
        // Start every per-worktree load now; they are awaited together below.
        let worktree_tasks = worktrees
            .into_iter()
            .map(|worktree| {
                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
            })
            .collect::<Vec<_>>();
        // Without a prompt store there are no default user rules to load.
        let default_user_rules_task = match prompt_store {
            None => Task::ready(vec![]),
            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
                let prompts = prompt_store.default_prompt_metadata();
                // Pair each load result with its metadata so failures can be handled per prompt.
                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
                    let contents = prompt_store.load(prompt_metadata.id, cx);
                    async move { (contents.await, prompt_metadata) }
                });
                cx.background_spawn(future::join_all(load_tasks))
            }),
        };

        cx.spawn(async move |this, cx| {
            let (worktrees, default_user_rules) =
                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;

            // Surface rules-file errors as events; the worktree context is kept either way.
            let worktrees = worktrees
                .into_iter()
                .map(|(worktree, rules_error)| {
                    if let Some(rules_error) = rules_error {
                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
                    }
                    worktree
                })
                .collect::<Vec<_>>();

            // Keep only successfully loaded user-authored rules. `EditWorkflow` prompts are
            // skipped, and load failures are reported as events.
            // NOTE(review): this formats errors with `{err:?}`, while worktree rules-file errors
            // use `{err}`. Confirm which form is intended for user-facing messages.
            let default_user_rules = default_user_rules
                .into_iter()
                .flat_map(|(contents, prompt_metadata)| match contents {
                    Ok(contents) => Some(UserRulesContext {
                        uuid: match prompt_metadata.id {
                            PromptId::User { uuid } => uuid,
                            PromptId::EditWorkflow => return None,
                        },
                        title: prompt_metadata.title.map(|title| title.to_string()),
                        contents,
                    }),
                    Err(err) => {
                        this.update(cx, |_, cx| {
                            cx.emit(RulesLoadingError {
                                message: format!("{err:?}").into(),
                            });
                        })
                        .ok();
                        None
                    }
                })
                .collect::<Vec<_>>();

            // Publish the new context; the error is ignored (presumably the store was released).
            this.update(cx, |this, _cx| {
                *this.project_context.0.borrow_mut() =
                    Some(ProjectContext::new(worktrees, default_user_rules));
            })
            .ok();
        })
    }
 325
 326    fn load_worktree_info_for_system_prompt(
 327        worktree: Entity<Worktree>,
 328        project: Entity<Project>,
 329        cx: &mut App,
 330    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 331        let tree = worktree.read(cx);
 332        let root_name = tree.root_name_str().into();
 333        let abs_path = tree.abs_path();
 334
 335        let mut context = WorktreeContext {
 336            root_name,
 337            abs_path,
 338            rules_file: None,
 339        };
 340
 341        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 342        let Some(rules_task) = rules_task else {
 343            return Task::ready((context, None));
 344        };
 345
 346        cx.spawn(async move |_| {
 347            let (rules_file, rules_file_error) = match rules_task.await {
 348                Ok(rules_file) => (Some(rules_file), None),
 349                Err(err) => (
 350                    None,
 351                    Some(RulesLoadingError {
 352                        message: format!("{err}").into(),
 353                    }),
 354                ),
 355            };
 356            context.rules_file = rules_file;
 357            (context, rules_file_error)
 358        })
 359    }
 360
 361    fn load_worktree_rules_file(
 362        worktree: Entity<Worktree>,
 363        project: Entity<Project>,
 364        cx: &mut App,
 365    ) -> Option<Task<Result<RulesFileContext>>> {
 366        let worktree = worktree.read(cx);
 367        let worktree_id = worktree.id();
 368        let selected_rules_file = RULES_FILE_NAMES
 369            .into_iter()
 370            .filter_map(|name| {
 371                worktree
 372                    .entry_for_path(name)
 373                    .filter(|entry| entry.is_file())
 374                    .map(|entry| entry.path.clone())
 375            })
 376            .next();
 377
 378        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 379        // supported. This doesn't seem to occur often in GitHub repositories.
 380        selected_rules_file.map(|path_in_worktree| {
 381            let project_path = ProjectPath {
 382                worktree_id,
 383                path: path_in_worktree.clone(),
 384            };
 385            let buffer_task =
 386                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 387            let rope_task = cx.spawn(async move |cx| {
 388                buffer_task.await?.read_with(cx, |buffer, cx| {
 389                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 390                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 391                })?
 392            });
 393            // Build a string from the rope on a background thread.
 394            cx.background_spawn(async move {
 395                let (project_entry_id, rope) = rope_task.await?;
 396                anyhow::Ok(RulesFileContext {
 397                    path_in_worktree,
 398                    text: rope.to_string().trim().to_string(),
 399                    project_entry_id: project_entry_id.to_usize(),
 400                })
 401            })
 402        })
 403    }
 404
 405    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
 406        &self.prompt_store
 407    }
 408
 409    pub fn tools(&self) -> Entity<ToolWorkingSet> {
 410        self.tools.clone()
 411    }
 412
 413    /// Returns the number of threads.
 414    pub fn thread_count(&self) -> usize {
 415        self.threads.len()
 416    }
 417
 418    pub fn reverse_chronological_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
 419        // ordering is from "ORDER BY" in `list_threads`
 420        self.threads.iter()
 421    }
 422
 423    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 424        cx.new(|cx| {
 425            Thread::new(
 426                self.project.clone(),
 427                self.tools.clone(),
 428                self.prompt_builder.clone(),
 429                self.project_context.clone(),
 430                cx,
 431            )
 432        })
 433    }
 434
 435    pub fn create_thread_from_serialized(
 436        &mut self,
 437        serialized: SerializedThread,
 438        cx: &mut Context<Self>,
 439    ) -> Entity<Thread> {
 440        cx.new(|cx| {
 441            Thread::deserialize(
 442                ThreadId::new(),
 443                serialized,
 444                self.project.clone(),
 445                self.tools.clone(),
 446                self.prompt_builder.clone(),
 447                self.project_context.clone(),
 448                None,
 449                cx,
 450            )
 451        })
 452    }
 453
 454    pub fn open_thread(
 455        &self,
 456        id: &ThreadId,
 457        window: &mut Window,
 458        cx: &mut Context<Self>,
 459    ) -> Task<Result<Entity<Thread>>> {
 460        let id = id.clone();
 461        let database_future = ThreadsDatabase::global_future(cx);
 462        let this = cx.weak_entity();
 463        window.spawn(cx, async move |cx| {
 464            let database = database_future.await.map_err(|err| anyhow!(err))?;
 465            let thread = database
 466                .try_find_thread(id.clone())
 467                .await?
 468                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 469
 470            let thread = this.update_in(cx, |this, window, cx| {
 471                cx.new(|cx| {
 472                    Thread::deserialize(
 473                        id.clone(),
 474                        thread,
 475                        this.project.clone(),
 476                        this.tools.clone(),
 477                        this.prompt_builder.clone(),
 478                        this.project_context.clone(),
 479                        Some(window),
 480                        cx,
 481                    )
 482                })
 483            })?;
 484
 485            Ok(thread)
 486        })
 487    }
 488
 489    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
 490        let (metadata, serialized_thread) =
 491            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
 492
 493        let database_future = ThreadsDatabase::global_future(cx);
 494        cx.spawn(async move |this, cx| {
 495            let serialized_thread = serialized_thread.await?;
 496            let database = database_future.await.map_err(|err| anyhow!(err))?;
 497            database.save_thread(metadata, serialized_thread).await?;
 498
 499            this.update(cx, |this, cx| this.reload(cx))?.await
 500        })
 501    }
 502
 503    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
 504        let id = id.clone();
 505        let database_future = ThreadsDatabase::global_future(cx);
 506        cx.spawn(async move |this, cx| {
 507            let database = database_future.await.map_err(|err| anyhow!(err))?;
 508            database.delete_thread(id.clone()).await?;
 509
 510            this.update(cx, |this, cx| {
 511                this.threads.retain(|thread| thread.id != id);
 512                cx.notify();
 513            })
 514        })
 515    }
 516
 517    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 518        let database_future = ThreadsDatabase::global_future(cx);
 519        cx.spawn(async move |this, cx| {
 520            let threads = database_future
 521                .await
 522                .map_err(|err| anyhow!(err))?
 523                .list_threads()
 524                .await?;
 525
 526            this.update(cx, |this, cx| {
 527                this.threads = threads;
 528                cx.notify();
 529            })
 530        })
 531    }
 532
 533    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
 534        let context_server_store = self.project.read(cx).context_server_store();
 535        cx.subscribe(&context_server_store, Self::handle_context_server_event)
 536            .detach();
 537
 538        // Check for any servers that were already running before the handler was registered
 539        for server in context_server_store.read(cx).running_servers() {
 540            self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
 541        }
 542    }
 543
 544    fn handle_context_server_event(
 545        &mut self,
 546        context_server_store: Entity<ContextServerStore>,
 547        event: &project::context_server_store::Event,
 548        cx: &mut Context<Self>,
 549    ) {
 550        let tool_working_set = self.tools.clone();
 551        match event {
 552            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
 553                match status {
 554                    ContextServerStatus::Starting => {}
 555                    ContextServerStatus::Running => {
 556                        self.load_context_server_tools(server_id.clone(), context_server_store, cx);
 557                    }
 558                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
 559                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
 560                            tool_working_set.update(cx, |tool_working_set, cx| {
 561                                tool_working_set.remove(&tool_ids, cx);
 562                            });
 563                        }
 564                    }
 565                }
 566            }
 567        }
 568    }
 569
 570    fn load_context_server_tools(
 571        &self,
 572        server_id: ContextServerId,
 573        context_server_store: Entity<ContextServerStore>,
 574        cx: &mut Context<Self>,
 575    ) {
 576        let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
 577            return;
 578        };
 579        let tool_working_set = self.tools.clone();
 580        cx.spawn(async move |this, cx| {
 581            let Some(protocol) = server.client() else {
 582                return;
 583            };
 584
 585            if protocol.capable(context_server::protocol::ServerCapability::Tools)
 586                && let Some(response) = protocol
 587                    .request::<context_server::types::requests::ListTools>(())
 588                    .await
 589                    .log_err()
 590            {
 591                let tool_ids = tool_working_set
 592                    .update(cx, |tool_working_set, cx| {
 593                        tool_working_set.extend(
 594                            response.tools.into_iter().map(|tool| {
 595                                Arc::new(ContextServerTool::new(
 596                                    context_server_store.clone(),
 597                                    server.id(),
 598                                    tool,
 599                                )) as Arc<dyn Tool>
 600                            }),
 601                            cx,
 602                        )
 603                    })
 604                    .log_err();
 605
 606                if let Some(tool_ids) = tool_ids {
 607                    this.update(cx, |this, _| {
 608                        this.context_server_tool_ids.insert(server_id, tool_ids);
 609                    })
 610                    .log_err();
 611                }
 612            }
 613        })
 614        .detach();
 615    }
 616}
 617
/// Listing entry for a saved thread, as returned by the database's thread listing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
}

/// Current (`VERSION`) persisted form of a thread.
///
/// Fields marked `#[serde(default)]` may be missing from older saved data and then take their
/// default values. Field order determines the serialized key order, so keep it stable.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedThread {
    /// Format version string; see `SerializedThread::VERSION`.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}

/// Model a thread was using, stored as plain strings (presumably provider id and model id).
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
 656
 657impl SerializedThread {
 658    pub const VERSION: &'static str = "0.2.0";
 659
 660    pub fn from_json(json: &[u8]) -> Result<Self> {
 661        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 662        match saved_thread_json.get("version") {
 663            Some(serde_json::Value::String(version)) => match version.as_str() {
 664                SerializedThreadV0_1_0::VERSION => {
 665                    let saved_thread =
 666                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
 667                    Ok(saved_thread.upgrade())
 668                }
 669                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
 670                    saved_thread_json,
 671                )?),
 672                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 673            },
 674            None => {
 675                let saved_thread =
 676                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
 677                Ok(saved_thread.upgrade())
 678            }
 679            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 680        }
 681    }
 682}
 683
/// Version 0.1.0 of the saved thread format.
///
/// Its structure is the same as the current `SerializedThread`. Judging by `upgrade`, the
/// difference is where tool results live: in 0.1.0 they sat on the user message that follows
/// the assistant message which issued the tool calls.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we reuse the latest `SerializedThread`.
    // When introducing the next version, make sure this points to `SerializedThreadV0_2_0`.
    SerializedThread,
);
 690
 691impl SerializedThreadV0_1_0 {
 692    pub const VERSION: &'static str = "0.1.0";
 693
 694    pub fn upgrade(self) -> SerializedThread {
 695        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
 696
 697        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
 698
 699        for message in self.0.messages {
 700            if message.role == Role::User
 701                && !message.tool_results.is_empty()
 702                && let Some(last_message) = messages.last_mut()
 703            {
 704                debug_assert!(last_message.role == Role::Assistant);
 705
 706                last_message.tool_results = message.tool_results;
 707                continue;
 708            }
 709
 710            messages.push(message);
 711        }
 712
 713        SerializedThread {
 714            messages,
 715            version: SerializedThread::VERSION.to_string(),
 716            ..self.0
 717        }
 718    }
 719}
 720
/// Persisted form of one thread message; `#[serde(default)]` fields may be absent in older data.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    #[serde(default)]
    pub is_hidden: bool,
}

/// One piece of message content, internally tagged by a `"type"` field.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // NOTE(review): unlike the other variants this has no `rename`, so its tag is presumably
    // the variant name `"RedactedThinking"`. Changing it would break already-saved threads.
    RedactedThinking {
        data: String,
    },
}

/// A tool call made by the model: call id, tool name and JSON input.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}

/// The result of the tool call identified by `tool_use_id`.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: LanguageModelToolResultContent,
    pub output: Option<serde_json::Value>,
}

/// Thread format from before the `version` field existed; read by `SerializedThread::from_json`.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
 780
 781impl LegacySerializedThread {
 782    pub fn upgrade(self) -> SerializedThread {
 783        SerializedThread {
 784            version: SerializedThread::VERSION.to_string(),
 785            summary: self.summary,
 786            updated_at: self.updated_at,
 787            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
 788            initial_project_snapshot: self.initial_project_snapshot,
 789            cumulative_token_usage: TokenUsage::default(),
 790            request_token_usage: Vec::new(),
 791            detailed_summary_state: DetailedSummaryState::default(),
 792            exceeded_window_error: None,
 793            model: None,
 794            completion_mode: None,
 795            tool_use_limit_reached: false,
 796            profile: None,
 797        }
 798    }
 799}
 800
/// Message in the legacy thread format: a single plain `text` body instead of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
 811
 812impl LegacySerializedMessage {
 813    fn upgrade(self) -> SerializedMessage {
 814        SerializedMessage {
 815            id: self.id,
 816            role: self.role,
 817            segments: vec![SerializedMessageSegment::Text { text: self.text }],
 818            tool_uses: self.tool_uses,
 819            tool_results: self.tool_results,
 820            context: String::new(),
 821            creases: Vec::new(),
 822            is_hidden: false,
 823        }
 824    }
 825}
 826
 827#[derive(Debug, Serialize, Deserialize, PartialEq)]
 828pub struct SerializedCrease {
 829    pub start: usize,
 830    pub end: usize,
 831    pub icon_path: SharedString,
 832    pub label: SharedString,
 833}
 834
/// gpui global holding the pending (or already resolved) threads database.
///
/// Installed by `ThreadsDatabase::init`. The future is `Shared` so many callers can
/// await it, which is why both the database and the error are wrapped in `Arc`
/// (the output of a `Shared` future must be cloneable).
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
 840
/// SQLite-backed store of serialized agent threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which the query methods in this file spawn their work.
    executor: BackgroundExecutor,
    /// Single SQLite connection shared by all of those tasks; the mutex serializes access.
    connection: Arc<Mutex<Connection>>,
}
 845
 846impl ThreadsDatabase {
 847    fn connection(&self) -> Arc<Mutex<Connection>> {
 848        self.connection.clone()
 849    }
 850
 851    const COMPRESSION_LEVEL: i32 = 3;
 852}
 853
 854impl Bind for ThreadId {
 855    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
 856        self.to_string().bind(statement, start_index)
 857    }
 858}
 859
 860impl Column for ThreadId {
 861    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
 862        let (id_str, next_index) = String::column(statement, start_index)?;
 863        Ok((ThreadId::from(id_str.as_str()), next_index))
 864    }
 865}
 866
 867impl ThreadsDatabase {
 868    fn global_future(
 869        cx: &mut App,
 870    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
 871        GlobalThreadsDatabase::global(cx).0.clone()
 872    }
 873
 874    fn init(fs: Arc<dyn Fs>, cx: &mut App) {
 875        let executor = cx.background_executor().clone();
 876        let database_future = executor
 877            .spawn({
 878                let executor = executor.clone();
 879                let threads_dir = paths::data_dir().join("threads");
 880                async move { ThreadsDatabase::new(fs, threads_dir, executor).await }
 881            })
 882            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
 883            .boxed()
 884            .shared();
 885
 886        cx.set_global(GlobalThreadsDatabase(database_future));
 887    }
 888
 889    pub async fn new(
 890        fs: Arc<dyn Fs>,
 891        threads_dir: PathBuf,
 892        executor: BackgroundExecutor,
 893    ) -> Result<Self> {
 894        fs.create_dir(&threads_dir).await?;
 895
 896        let sqlite_path = threads_dir.join("threads.db");
 897        let mdb_path = threads_dir.join("threads-db.1.mdb");
 898
 899        let needs_migration_from_heed = fs.is_file(&mdb_path).await;
 900
 901        let connection = if *ZED_STATELESS {
 902            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
 903        } else if cfg!(any(feature = "test-support", test)) {
 904            // rust stores the name of the test on the current thread.
 905            // We use this to automatically create a database that will
 906            // be shared within the test (for the test_retrieve_old_thread)
 907            // but not with concurrent tests.
 908            let thread = std::thread::current();
 909            let test_name = thread.name();
 910            Connection::open_memory(Some(&format!(
 911                "THREAD_FALLBACK_{}",
 912                test_name.unwrap_or_default()
 913            )))
 914        } else {
 915            Connection::open_file(&sqlite_path.to_string_lossy())
 916        };
 917
 918        connection.exec(indoc! {"
 919                CREATE TABLE IF NOT EXISTS threads (
 920                    id TEXT PRIMARY KEY,
 921                    summary TEXT NOT NULL,
 922                    updated_at TEXT NOT NULL,
 923                    data_type TEXT NOT NULL,
 924                    data BLOB NOT NULL
 925                )
 926            "})?()
 927        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
 928
 929        let db = Self {
 930            executor: executor.clone(),
 931            connection: Arc::new(Mutex::new(connection)),
 932        };
 933
 934        if needs_migration_from_heed {
 935            let db_connection = db.connection();
 936            let executor_clone = executor.clone();
 937            executor
 938                .spawn(async move {
 939                    log::info!("Starting threads.db migration");
 940                    Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
 941                    fs.remove_dir(
 942                        &mdb_path,
 943                        RemoveOptions {
 944                            recursive: true,
 945                            ignore_if_not_exists: true,
 946                        },
 947                    )
 948                    .await?;
 949                    log::info!("threads.db migrated to sqlite");
 950                    Ok::<(), anyhow::Error>(())
 951                })
 952                .detach();
 953        }
 954
 955        Ok(db)
 956    }
 957
 958    // Remove this migration after 2025-09-01
 959    fn migrate_from_heed(
 960        mdb_path: &Path,
 961        connection: Arc<Mutex<Connection>>,
 962        _executor: BackgroundExecutor,
 963    ) -> Result<()> {
 964        use heed::types::SerdeBincode;
 965        struct SerializedThreadHeed(SerializedThread);
 966
 967        impl heed::BytesEncode<'_> for SerializedThreadHeed {
 968            type EItem = SerializedThreadHeed;
 969
 970            fn bytes_encode(
 971                item: &Self::EItem,
 972            ) -> Result<std::borrow::Cow<'_, [u8]>, heed::BoxedError> {
 973                serde_json::to_vec(&item.0)
 974                    .map(std::borrow::Cow::Owned)
 975                    .map_err(Into::into)
 976            }
 977        }
 978
 979        impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
 980            type DItem = SerializedThreadHeed;
 981
 982            fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
 983                SerializedThread::from_json(bytes)
 984                    .map(SerializedThreadHeed)
 985                    .map_err(Into::into)
 986            }
 987        }
 988
 989        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
 990
 991        let env = unsafe {
 992            heed::EnvOpenOptions::new()
 993                .map_size(ONE_GB_IN_BYTES)
 994                .max_dbs(1)
 995                .open(mdb_path)?
 996        };
 997
 998        let txn = env.write_txn()?;
 999        let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
1000            .open_database(&txn, Some("threads"))?
1001            .ok_or_else(|| anyhow!("threads database not found"))?;
1002
1003        for result in threads.iter(&txn)? {
1004            let (thread_id, thread_heed) = result?;
1005            Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
1006        }
1007
1008        Ok(())
1009    }
1010
1011    fn save_thread_sync(
1012        connection: &Arc<Mutex<Connection>>,
1013        id: ThreadId,
1014        thread: SerializedThread,
1015    ) -> Result<()> {
1016        let json_data = serde_json::to_string(&thread)?;
1017        let summary = thread.summary.to_string();
1018        let updated_at = thread.updated_at.to_rfc3339();
1019
1020        let connection = connection.lock().unwrap();
1021
1022        let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
1023        let data_type = DataType::Zstd;
1024        let data = compressed;
1025
1026        let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
1027            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
1028        "})?;
1029
1030        insert((id, summary, updated_at, data_type, data))?;
1031
1032        Ok(())
1033    }
1034
1035    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
1036        let connection = self.connection.clone();
1037
1038        self.executor.spawn(async move {
1039            let connection = connection.lock().unwrap();
1040            let mut select =
1041                connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
1042                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
1043            "})?;
1044
1045            let rows = select(())?;
1046            let mut threads = Vec::new();
1047
1048            for (id, summary, updated_at) in rows {
1049                threads.push(SerializedThreadMetadata {
1050                    id,
1051                    summary: summary.into(),
1052                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
1053                });
1054            }
1055
1056            Ok(threads)
1057        })
1058    }
1059
1060    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
1061        let connection = self.connection.clone();
1062
1063        self.executor.spawn(async move {
1064            let connection = connection.lock().unwrap();
1065            let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
1066                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
1067            "})?;
1068
1069            let rows = select(id)?;
1070            if let Some((data_type, data)) = rows.into_iter().next() {
1071                let json_data = match data_type {
1072                    DataType::Zstd => {
1073                        let decompressed = zstd::decode_all(&data[..])?;
1074                        String::from_utf8(decompressed)?
1075                    }
1076                    DataType::Json => String::from_utf8(data)?,
1077                };
1078
1079                let thread = SerializedThread::from_json(json_data.as_bytes())?;
1080                Ok(Some(thread))
1081            } else {
1082                Ok(None)
1083            }
1084        })
1085    }
1086
1087    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
1088        let connection = self.connection.clone();
1089
1090        self.executor
1091            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
1092    }
1093
1094    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
1095        let connection = self.connection.clone();
1096
1097        self.executor.spawn(async move {
1098            let connection = connection.lock().unwrap();
1099
1100            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
1101                DELETE FROM threads WHERE id = ?
1102            "})?;
1103
1104            delete(id)?;
1105
1106            Ok(())
1107        })
1108    }
1109}
1110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::thread::{DetailedSummaryState, MessageId};
    use chrono::Utc;
    use language_model::{Role, TokenUsage};
    use pretty_assertions::assert_eq;

    /// Builds a visible message with a single text segment and no tool traffic,
    /// context, or creases.
    fn text_message(id: MessageId, role: Role, text: &str) -> SerializedMessage {
        SerializedMessage {
            id,
            role,
            segments: vec![SerializedMessageSegment::Text {
                text: text.to_string(),
            }],
            tool_uses: Vec::new(),
            tool_results: Vec::new(),
            context: String::new(),
            creases: Vec::new(),
            is_hidden: false,
        }
    }

    /// Builds a thread summarized as "Test conversation" whose remaining metadata is
    /// empty or default, tagged with the given `version`.
    fn test_thread(
        version: String,
        updated_at: DateTime<Utc>,
        messages: Vec<SerializedMessage>,
    ) -> SerializedThread {
        SerializedThread {
            version,
            summary: "Test conversation".into(),
            updated_at,
            messages,
            initial_project_snapshot: None,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: Vec::new(),
            detailed_summary_state: DetailedSummaryState::default(),
            exceeded_window_error: None,
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
            profile: None,
        }
    }

    #[test]
    fn test_legacy_serialized_thread_upgrade() {
        let updated_at = Utc::now();
        let legacy_thread = LegacySerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![LegacySerializedMessage {
                id: MessageId(1),
                role: Role::User,
                text: "Hello, world!".to_string(),
                tool_uses: vec![],
                tool_results: vec![],
            }],
            initial_project_snapshot: None,
        };

        let expected = test_thread(
            SerializedThread::VERSION.to_string(),
            updated_at,
            vec![text_message(MessageId(1), Role::User, "Hello, world!")],
        );

        assert_eq!(legacy_thread.upgrade(), expected)
    }

    #[test]
    fn test_serialized_threadv0_1_0_upgrade() {
        let updated_at = Utc::now();
        let tool_use = || SerializedToolUse {
            id: "abc".into(),
            name: "tool_1".into(),
            input: serde_json::Value::Null,
        };
        let tool_result = || SerializedToolResult {
            tool_use_id: "abc".into(),
            is_error: false,
            content: LanguageModelToolResultContent::Text("abcdef".into()),
            output: Some(serde_json::Value::Null),
        };

        let thread_v0_1_0 = SerializedThreadV0_1_0(test_thread(
            SerializedThreadV0_1_0::VERSION.to_string(),
            updated_at,
            vec![
                text_message(MessageId(1), Role::User, "Use tool_1"),
                SerializedMessage {
                    tool_uses: vec![tool_use()],
                    ..text_message(MessageId(2), Role::Assistant, "I want to use a tool")
                },
                SerializedMessage {
                    tool_results: vec![tool_result()],
                    ..text_message(MessageId(1), Role::User, "Here is the tool result")
                },
            ],
        ));

        // The tool result is folded into the assistant message that issued the tool use.
        let expected = test_thread(
            SerializedThread::VERSION.to_string(),
            updated_at,
            vec![
                text_message(MessageId(1), Role::User, "Use tool_1"),
                SerializedMessage {
                    tool_uses: vec![tool_use()],
                    tool_results: vec![tool_result()],
                    ..text_message(MessageId(2), Role::Assistant, "I want to use a tool")
                },
            ],
        );

        assert_eq!(thread_v0_1_0.upgrade(), expected)
    }
}