thread_store.rs

   1use crate::{
   2    context_server_tool::ContextServerTool,
   3    thread::{
   4        DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
   5    },
   6};
   7use agent_settings::{AgentProfileId, CompletionMode};
   8use anyhow::{Context as _, Result, anyhow};
   9use assistant_tool::{Tool, ToolId, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::HashMap;
  12use context_server::ContextServerId;
  13use futures::{
  14    FutureExt as _, StreamExt as _,
  15    channel::{mpsc, oneshot},
  16    future::{self, BoxFuture, Shared},
  17};
  18use gpui::{
  19    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
  20    Subscription, Task, Window, prelude::*,
  21};
  22use indoc::indoc;
  23use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
  24use project::context_server_store::{ContextServerStatus, ContextServerStore};
  25use project::{Project, ProjectItem, ProjectPath, Worktree};
  26use prompt_store::{
  27    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
  28    UserRulesContext, WorktreeContext,
  29};
  30use serde::{Deserialize, Serialize};
  31use sqlez::{
  32    bindable::{Bind, Column},
  33    connection::Connection,
  34    statement::Statement,
  35};
  36use std::{
  37    cell::{Ref, RefCell},
  38    path::{Path, PathBuf},
  39    rc::Rc,
  40    sync::{Arc, Mutex},
  41};
  42use util::ResultExt as _;
  43
  44pub static ZED_STATELESS: std::sync::LazyLock<bool> =
  45    std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty()));
  46
  47#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  48pub enum DataType {
  49    #[serde(rename = "json")]
  50    Json,
  51    #[serde(rename = "zstd")]
  52    Zstd,
  53}
  54
  55impl Bind for DataType {
  56    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
  57        let value = match self {
  58            DataType::Json => "json",
  59            DataType::Zstd => "zstd",
  60        };
  61        value.bind(statement, start_index)
  62    }
  63}
  64
  65impl Column for DataType {
  66    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
  67        let (value, next_index) = String::column(statement, start_index)?;
  68        let data_type = match value.as_str() {
  69            "json" => DataType::Json,
  70            "zstd" => DataType::Zstd,
  71            _ => anyhow::bail!("Unknown data type: {}", value),
  72        };
  73        Ok((data_type, next_index))
  74    }
  75}
  76
  77const RULES_FILE_NAMES: [&'static str; 9] = [
  78    ".rules",
  79    ".cursorrules",
  80    ".windsurfrules",
  81    ".clinerules",
  82    ".github/copilot-instructions.md",
  83    "CLAUDE.md",
  84    "AGENT.md",
  85    "AGENTS.md",
  86    "GEMINI.md",
  87];
  88
/// Module initialization: registers the global threads database via `ThreadsDatabase::init`.
pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}
  92
  93/// A system prompt shared by all threads created by this ThreadStore
  94#[derive(Clone, Default)]
  95pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
  96
  97impl SharedProjectContext {
  98    pub fn borrow(&self) -> Ref<'_, Option<ProjectContext>> {
  99        self.0.borrow()
 100    }
 101}
 102
/// Re-export of `assistant_context::ContextStore` under the name `TextThreadStore`
/// (presumably the store backing text threads — confirm with `assistant_context`).
pub type TextThreadStore = assistant_context::ContextStore;
 104
/// Keeps the list of saved thread metadata for a project and creates, opens, saves, and
/// deletes [`Thread`]s. All threads it creates share one [`SharedProjectContext`].
pub struct ThreadStore {
    project: Entity<Project>,
    /// Tool set handed to new threads; tools from running context servers are added to it.
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    /// Source of default user rules; its update events trigger a project-context reload.
    prompt_store: Option<Entity<PromptStore>>,
    /// Tool IDs registered for each context server, so they can be removed when it stops.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Thread metadata in the order returned by `list_threads` (reverse chronological).
    threads: Vec<SerializedThreadMetadata>,
    /// Context shared with every thread created by this store; refreshed by the reload task.
    project_context: SharedProjectContext,
    /// Wakes `_reload_system_prompt_task` to reload `project_context`.
    reload_system_prompt_tx: mpsc::Sender<()>,
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}
 117
/// Event emitted by [`ThreadStore`] when a worktree rules file or a default user rule
/// fails to load.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}
 123
 124impl ThreadStore {
 125    pub fn load(
 126        project: Entity<Project>,
 127        tools: Entity<ToolWorkingSet>,
 128        prompt_store: Option<Entity<PromptStore>>,
 129        prompt_builder: Arc<PromptBuilder>,
 130        cx: &mut App,
 131    ) -> Task<Result<Entity<Self>>> {
 132        cx.spawn(async move |cx| {
 133            let (thread_store, ready_rx) = cx.update(|cx| {
 134                let mut option_ready_rx = None;
 135                let thread_store = cx.new(|cx| {
 136                    let (thread_store, ready_rx) =
 137                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 138                    option_ready_rx = Some(ready_rx);
 139                    thread_store
 140                });
 141                (thread_store, option_ready_rx.take().unwrap())
 142            })?;
 143            ready_rx.await?;
 144            Ok(thread_store)
 145        })
 146    }
 147
 148    fn new(
 149        project: Entity<Project>,
 150        tools: Entity<ToolWorkingSet>,
 151        prompt_builder: Arc<PromptBuilder>,
 152        prompt_store: Option<Entity<PromptStore>>,
 153        cx: &mut Context<Self>,
 154    ) -> (Self, oneshot::Receiver<()>) {
 155        let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
 156
 157        if let Some(prompt_store) = prompt_store.as_ref() {
 158            subscriptions.push(cx.subscribe(
 159                prompt_store,
 160                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
 161                    this.enqueue_system_prompt_reload();
 162                },
 163            ))
 164        }
 165
 166        // This channel and task prevent concurrent and redundant loading of the system prompt.
 167        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
 168        let (ready_tx, ready_rx) = oneshot::channel();
 169        let mut ready_tx = Some(ready_tx);
 170        let reload_system_prompt_task = cx.spawn({
 171            let prompt_store = prompt_store.clone();
 172            async move |thread_store, cx| {
 173                loop {
 174                    let Some(reload_task) = thread_store
 175                        .update(cx, |thread_store, cx| {
 176                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
 177                        })
 178                        .ok()
 179                    else {
 180                        return;
 181                    };
 182                    reload_task.await;
 183                    if let Some(ready_tx) = ready_tx.take() {
 184                        ready_tx.send(()).ok();
 185                    }
 186                    reload_system_prompt_rx.next().await;
 187                }
 188            }
 189        });
 190
 191        let this = Self {
 192            project,
 193            tools,
 194            prompt_builder,
 195            prompt_store,
 196            context_server_tool_ids: HashMap::default(),
 197            threads: Vec::new(),
 198            project_context: SharedProjectContext::default(),
 199            reload_system_prompt_tx,
 200            _reload_system_prompt_task: reload_system_prompt_task,
 201            _subscriptions: subscriptions,
 202        };
 203        this.register_context_server_handlers(cx);
 204        this.reload(cx).detach_and_log_err(cx);
 205        (this, ready_rx)
 206    }
 207
 208    fn handle_project_event(
 209        &mut self,
 210        _project: Entity<Project>,
 211        event: &project::Event,
 212        _cx: &mut Context<Self>,
 213    ) {
 214        match event {
 215            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 216                self.enqueue_system_prompt_reload();
 217            }
 218            project::Event::WorktreeUpdatedEntries(_, items) => {
 219                if items.iter().any(|(path, _, _)| {
 220                    RULES_FILE_NAMES
 221                        .iter()
 222                        .any(|name| path.as_ref() == Path::new(name))
 223                }) {
 224                    self.enqueue_system_prompt_reload();
 225                }
 226            }
 227            _ => {}
 228        }
 229    }
 230
 231    fn enqueue_system_prompt_reload(&mut self) {
 232        self.reload_system_prompt_tx.try_send(()).ok();
 233    }
 234
 235    // Note that this should only be called from `reload_system_prompt_task`.
 236    fn reload_system_prompt(
 237        &self,
 238        prompt_store: Option<Entity<PromptStore>>,
 239        cx: &mut Context<Self>,
 240    ) -> Task<()> {
 241        let worktrees = self
 242            .project
 243            .read(cx)
 244            .visible_worktrees(cx)
 245            .collect::<Vec<_>>();
 246        let worktree_tasks = worktrees
 247            .into_iter()
 248            .map(|worktree| {
 249                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
 250            })
 251            .collect::<Vec<_>>();
 252        let default_user_rules_task = match prompt_store {
 253            None => Task::ready(vec![]),
 254            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
 255                let prompts = prompt_store.default_prompt_metadata();
 256                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 257                    let contents = prompt_store.load(prompt_metadata.id, cx);
 258                    async move { (contents.await, prompt_metadata) }
 259                });
 260                cx.background_spawn(future::join_all(load_tasks))
 261            }),
 262        };
 263
 264        cx.spawn(async move |this, cx| {
 265            let (worktrees, default_user_rules) =
 266                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 267
 268            let worktrees = worktrees
 269                .into_iter()
 270                .map(|(worktree, rules_error)| {
 271                    if let Some(rules_error) = rules_error {
 272                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 273                    }
 274                    worktree
 275                })
 276                .collect::<Vec<_>>();
 277
 278            let default_user_rules = default_user_rules
 279                .into_iter()
 280                .flat_map(|(contents, prompt_metadata)| match contents {
 281                    Ok(contents) => Some(UserRulesContext {
 282                        uuid: match prompt_metadata.id {
 283                            PromptId::User { uuid } => uuid,
 284                            PromptId::EditWorkflow => return None,
 285                        },
 286                        title: prompt_metadata.title.map(|title| title.to_string()),
 287                        contents,
 288                    }),
 289                    Err(err) => {
 290                        this.update(cx, |_, cx| {
 291                            cx.emit(RulesLoadingError {
 292                                message: format!("{err:?}").into(),
 293                            });
 294                        })
 295                        .ok();
 296                        None
 297                    }
 298                })
 299                .collect::<Vec<_>>();
 300
 301            this.update(cx, |this, _cx| {
 302                *this.project_context.0.borrow_mut() =
 303                    Some(ProjectContext::new(worktrees, default_user_rules));
 304            })
 305            .ok();
 306        })
 307    }
 308
 309    fn load_worktree_info_for_system_prompt(
 310        worktree: Entity<Worktree>,
 311        project: Entity<Project>,
 312        cx: &mut App,
 313    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 314        let tree = worktree.read(cx);
 315        let root_name = tree.root_name().into();
 316        let abs_path = tree.abs_path();
 317
 318        let mut context = WorktreeContext {
 319            root_name,
 320            abs_path,
 321            rules_file: None,
 322        };
 323
 324        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 325        let Some(rules_task) = rules_task else {
 326            return Task::ready((context, None));
 327        };
 328
 329        cx.spawn(async move |_| {
 330            let (rules_file, rules_file_error) = match rules_task.await {
 331                Ok(rules_file) => (Some(rules_file), None),
 332                Err(err) => (
 333                    None,
 334                    Some(RulesLoadingError {
 335                        message: format!("{err}").into(),
 336                    }),
 337                ),
 338            };
 339            context.rules_file = rules_file;
 340            (context, rules_file_error)
 341        })
 342    }
 343
 344    fn load_worktree_rules_file(
 345        worktree: Entity<Worktree>,
 346        project: Entity<Project>,
 347        cx: &mut App,
 348    ) -> Option<Task<Result<RulesFileContext>>> {
 349        let worktree = worktree.read(cx);
 350        let worktree_id = worktree.id();
 351        let selected_rules_file = RULES_FILE_NAMES
 352            .into_iter()
 353            .filter_map(|name| {
 354                worktree
 355                    .entry_for_path(name)
 356                    .filter(|entry| entry.is_file())
 357                    .map(|entry| entry.path.clone())
 358            })
 359            .next();
 360
 361        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 362        // supported. This doesn't seem to occur often in GitHub repositories.
 363        selected_rules_file.map(|path_in_worktree| {
 364            let project_path = ProjectPath {
 365                worktree_id,
 366                path: path_in_worktree.clone(),
 367            };
 368            let buffer_task =
 369                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 370            let rope_task = cx.spawn(async move |cx| {
 371                buffer_task.await?.read_with(cx, |buffer, cx| {
 372                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 373                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 374                })?
 375            });
 376            // Build a string from the rope on a background thread.
 377            cx.background_spawn(async move {
 378                let (project_entry_id, rope) = rope_task.await?;
 379                anyhow::Ok(RulesFileContext {
 380                    path_in_worktree,
 381                    text: rope.to_string().trim().to_string(),
 382                    project_entry_id: project_entry_id.to_usize(),
 383                })
 384            })
 385        })
 386    }
 387
 388    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
 389        &self.prompt_store
 390    }
 391
 392    pub fn tools(&self) -> Entity<ToolWorkingSet> {
 393        self.tools.clone()
 394    }
 395
 396    /// Returns the number of threads.
 397    pub fn thread_count(&self) -> usize {
 398        self.threads.len()
 399    }
 400
 401    pub fn reverse_chronological_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
 402        // ordering is from "ORDER BY" in `list_threads`
 403        self.threads.iter()
 404    }
 405
 406    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 407        cx.new(|cx| {
 408            Thread::new(
 409                self.project.clone(),
 410                self.tools.clone(),
 411                self.prompt_builder.clone(),
 412                self.project_context.clone(),
 413                cx,
 414            )
 415        })
 416    }
 417
 418    pub fn create_thread_from_serialized(
 419        &mut self,
 420        serialized: SerializedThread,
 421        cx: &mut Context<Self>,
 422    ) -> Entity<Thread> {
 423        cx.new(|cx| {
 424            Thread::deserialize(
 425                ThreadId::new(),
 426                serialized,
 427                self.project.clone(),
 428                self.tools.clone(),
 429                self.prompt_builder.clone(),
 430                self.project_context.clone(),
 431                None,
 432                cx,
 433            )
 434        })
 435    }
 436
 437    pub fn open_thread(
 438        &self,
 439        id: &ThreadId,
 440        window: &mut Window,
 441        cx: &mut Context<Self>,
 442    ) -> Task<Result<Entity<Thread>>> {
 443        let id = id.clone();
 444        let database_future = ThreadsDatabase::global_future(cx);
 445        let this = cx.weak_entity();
 446        window.spawn(cx, async move |cx| {
 447            let database = database_future.await.map_err(|err| anyhow!(err))?;
 448            let thread = database
 449                .try_find_thread(id.clone())
 450                .await?
 451                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 452
 453            let thread = this.update_in(cx, |this, window, cx| {
 454                cx.new(|cx| {
 455                    Thread::deserialize(
 456                        id.clone(),
 457                        thread,
 458                        this.project.clone(),
 459                        this.tools.clone(),
 460                        this.prompt_builder.clone(),
 461                        this.project_context.clone(),
 462                        Some(window),
 463                        cx,
 464                    )
 465                })
 466            })?;
 467
 468            Ok(thread)
 469        })
 470    }
 471
 472    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
 473        let (metadata, serialized_thread) =
 474            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
 475
 476        let database_future = ThreadsDatabase::global_future(cx);
 477        cx.spawn(async move |this, cx| {
 478            let serialized_thread = serialized_thread.await?;
 479            let database = database_future.await.map_err(|err| anyhow!(err))?;
 480            database.save_thread(metadata, serialized_thread).await?;
 481
 482            this.update(cx, |this, cx| this.reload(cx))?.await
 483        })
 484    }
 485
 486    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
 487        let id = id.clone();
 488        let database_future = ThreadsDatabase::global_future(cx);
 489        cx.spawn(async move |this, cx| {
 490            let database = database_future.await.map_err(|err| anyhow!(err))?;
 491            database.delete_thread(id.clone()).await?;
 492
 493            this.update(cx, |this, cx| {
 494                this.threads.retain(|thread| thread.id != id);
 495                cx.notify();
 496            })
 497        })
 498    }
 499
 500    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 501        let database_future = ThreadsDatabase::global_future(cx);
 502        cx.spawn(async move |this, cx| {
 503            let threads = database_future
 504                .await
 505                .map_err(|err| anyhow!(err))?
 506                .list_threads()
 507                .await?;
 508
 509            this.update(cx, |this, cx| {
 510                this.threads = threads;
 511                cx.notify();
 512            })
 513        })
 514    }
 515
 516    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
 517        let context_server_store = self.project.read(cx).context_server_store();
 518        cx.subscribe(&context_server_store, Self::handle_context_server_event)
 519            .detach();
 520
 521        // Check for any servers that were already running before the handler was registered
 522        for server in context_server_store.read(cx).running_servers() {
 523            self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
 524        }
 525    }
 526
 527    fn handle_context_server_event(
 528        &mut self,
 529        context_server_store: Entity<ContextServerStore>,
 530        event: &project::context_server_store::Event,
 531        cx: &mut Context<Self>,
 532    ) {
 533        let tool_working_set = self.tools.clone();
 534        match event {
 535            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
 536                match status {
 537                    ContextServerStatus::Starting => {}
 538                    ContextServerStatus::Running => {
 539                        self.load_context_server_tools(server_id.clone(), context_server_store, cx);
 540                    }
 541                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
 542                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
 543                            tool_working_set.update(cx, |tool_working_set, cx| {
 544                                tool_working_set.remove(&tool_ids, cx);
 545                            });
 546                        }
 547                    }
 548                }
 549            }
 550        }
 551    }
 552
 553    fn load_context_server_tools(
 554        &self,
 555        server_id: ContextServerId,
 556        context_server_store: Entity<ContextServerStore>,
 557        cx: &mut Context<Self>,
 558    ) {
 559        let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
 560            return;
 561        };
 562        let tool_working_set = self.tools.clone();
 563        cx.spawn(async move |this, cx| {
 564            let Some(protocol) = server.client() else {
 565                return;
 566            };
 567
 568            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
 569                if let Some(response) = protocol
 570                    .request::<context_server::types::requests::ListTools>(())
 571                    .await
 572                    .log_err()
 573                {
 574                    let tool_ids = tool_working_set
 575                        .update(cx, |tool_working_set, cx| {
 576                            tool_working_set.extend(
 577                                response.tools.into_iter().map(|tool| {
 578                                    Arc::new(ContextServerTool::new(
 579                                        context_server_store.clone(),
 580                                        server.id(),
 581                                        tool,
 582                                    )) as Arc<dyn Tool>
 583                                }),
 584                                cx,
 585                            )
 586                        })
 587                        .log_err();
 588
 589                    if let Some(tool_ids) = tool_ids {
 590                        this.update(cx, |this, _| {
 591                            this.context_server_tool_ids.insert(server_id, tool_ids);
 592                        })
 593                        .log_err();
 594                    }
 595                }
 596            }
 597        })
 598        .detach();
 599    }
 600}
 601
/// Summary row for a saved thread, as kept in `ThreadStore`'s thread list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// ID of the saved thread.
    pub id: ThreadId,
    /// The thread's summary text.
    pub summary: SharedString,
    /// Time the thread was last updated.
    pub updated_at: DateTime<Utc>,
}
 608
/// A thread in the current serialized format ([`SerializedThread::VERSION`]).
///
/// Fields marked `#[serde(default)]` may be absent in older data and then take their
/// `Default` value on deserialization. Read stored data through
/// [`SerializedThread::from_json`] so that older formats are upgraded.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedThread {
    /// Format version tag, dispatched on by `from_json`.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    // Presumably one entry per model request — confirm against `Thread::serialize`.
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
 634
/// Language model recorded with a serialized thread, identified by provider and model name.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
 640
 641impl SerializedThread {
 642    pub const VERSION: &'static str = "0.2.0";
 643
 644    pub fn from_json(json: &[u8]) -> Result<Self> {
 645        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 646        match saved_thread_json.get("version") {
 647            Some(serde_json::Value::String(version)) => match version.as_str() {
 648                SerializedThreadV0_1_0::VERSION => {
 649                    let saved_thread =
 650                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
 651                    Ok(saved_thread.upgrade())
 652                }
 653                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
 654                    saved_thread_json,
 655                )?),
 656                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 657            },
 658            None => {
 659                let saved_thread =
 660                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
 661                Ok(saved_thread.upgrade())
 662            }
 663            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 664        }
 665    }
 666}
 667
/// A thread in serialized format version 0.1.0.
///
/// It has the same field layout as the current format; `upgrade` converts it by moving tool
/// results off the following user message onto the preceding message.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);
 674
 675impl SerializedThreadV0_1_0 {
 676    pub const VERSION: &'static str = "0.1.0";
 677
 678    pub fn upgrade(self) -> SerializedThread {
 679        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
 680
 681        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
 682
 683        for message in self.0.messages {
 684            if message.role == Role::User && !message.tool_results.is_empty() {
 685                if let Some(last_message) = messages.last_mut() {
 686                    debug_assert!(last_message.role == Role::Assistant);
 687
 688                    last_message.tool_results = message.tool_results;
 689                    continue;
 690                }
 691            }
 692
 693            messages.push(message);
 694        }
 695
 696        SerializedThread {
 697            messages,
 698            version: SerializedThread::VERSION.to_string(),
 699            ..self.0
 700        }
 701    }
 702}
 703
/// One message of a serialized thread. Every field except `id` and `role` falls back to its
/// `Default` value when missing.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (text, thinking, redacted thinking).
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    #[serde(default)]
    pub is_hidden: bool,
}
 721
/// One piece of a message's content, serialized as an internally tagged enum whose variant
/// is stored in a `"type"` field.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Plain text.
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    /// Thinking text with an optional signature; `signature` is omitted when `None`.
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // No `#[serde(rename)]`, so the tag is the variant name as written ("RedactedThinking"),
    // unlike the lowercase tags above. Renaming it would presumably break reading threads
    // already saved with this tag.
    RedactedThinking {
        data: String,
    },
}
 739
/// A tool invocation recorded on a message: tool name and JSON input.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
 746
/// Result of a tool invocation, linked to its [`SerializedToolUse`] by `tool_use_id`.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: LanguageModelToolResultContent,
    pub output: Option<serde_json::Value>,
}
 754
/// Thread format from before a `version` field existed; converted by `upgrade`.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
 763
 764impl LegacySerializedThread {
 765    pub fn upgrade(self) -> SerializedThread {
 766        SerializedThread {
 767            version: SerializedThread::VERSION.to_string(),
 768            summary: self.summary,
 769            updated_at: self.updated_at,
 770            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
 771            initial_project_snapshot: self.initial_project_snapshot,
 772            cumulative_token_usage: TokenUsage::default(),
 773            request_token_usage: Vec::new(),
 774            detailed_summary_state: DetailedSummaryState::default(),
 775            exceeded_window_error: None,
 776            model: None,
 777            completion_mode: None,
 778            tool_use_limit_reached: false,
 779            profile: None,
 780        }
 781    }
 782}
 783
/// Message in the pre-versioning format, holding a single `text` string instead of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
 794
 795impl LegacySerializedMessage {
 796    fn upgrade(self) -> SerializedMessage {
 797        SerializedMessage {
 798            id: self.id,
 799            role: self.role,
 800            segments: vec![SerializedMessageSegment::Text { text: self.text }],
 801            tool_uses: self.tool_uses,
 802            tool_results: self.tool_results,
 803            context: String::new(),
 804            creases: Vec::new(),
 805            is_hidden: false,
 806        }
 807    }
 808}
 809
/// A crease stored with a message: a `start..end` range plus the icon and label used to
/// display it (presumably a collapsed range in the message text — confirm with the renderer).
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedCrease {
    // NOTE(review): units (byte vs. char offsets) are not established here — confirm.
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}
 817
/// gpui global holding the shared future that opens the threads database.
///
/// The error is wrapped in `Arc` because `anyhow::Error` is not `Clone`, and `Shared` needs a
/// cloneable output to hand to every awaiter.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
 823
/// SQLite-backed storage for serialized threads (`threads.db` in the threads data dir).
pub(crate) struct ThreadsDatabase {
    /// Background executor passed in at construction (presumably for running queries off
    /// the main thread — confirm in the query methods).
    executor: BackgroundExecutor,
    /// Single connection shared behind a mutex.
    connection: Arc<Mutex<Connection>>,
}
 828
 829impl ThreadsDatabase {
 830    fn connection(&self) -> Arc<Mutex<Connection>> {
 831        self.connection.clone()
 832    }
 833
 834    const COMPRESSION_LEVEL: i32 = 3;
 835}
 836
 837impl Bind for ThreadId {
 838    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
 839        self.to_string().bind(statement, start_index)
 840    }
 841}
 842
 843impl Column for ThreadId {
 844    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
 845        let (id_str, next_index) = String::column(statement, start_index)?;
 846        Ok((ThreadId::from(id_str.as_str()), next_index))
 847    }
 848}
 849
 850impl ThreadsDatabase {
 851    fn global_future(
 852        cx: &mut App,
 853    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
 854        GlobalThreadsDatabase::global(cx).0.clone()
 855    }
 856
 857    fn init(cx: &mut App) {
 858        let executor = cx.background_executor().clone();
 859        let database_future = executor
 860            .spawn({
 861                let executor = executor.clone();
 862                let threads_dir = paths::data_dir().join("threads");
 863                async move { ThreadsDatabase::new(threads_dir, executor) }
 864            })
 865            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
 866            .boxed()
 867            .shared();
 868
 869        cx.set_global(GlobalThreadsDatabase(database_future));
 870    }
 871
 872    pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
 873        std::fs::create_dir_all(&threads_dir)?;
 874
 875        let sqlite_path = threads_dir.join("threads.db");
 876        let mdb_path = threads_dir.join("threads-db.1.mdb");
 877
 878        let needs_migration_from_heed = mdb_path.exists();
 879
 880        let connection = if *ZED_STATELESS {
 881            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
 882        } else {
 883            Connection::open_file(&sqlite_path.to_string_lossy())
 884        };
 885
 886        connection.exec(indoc! {"
 887                CREATE TABLE IF NOT EXISTS threads (
 888                    id TEXT PRIMARY KEY,
 889                    summary TEXT NOT NULL,
 890                    updated_at TEXT NOT NULL,
 891                    data_type TEXT NOT NULL,
 892                    data BLOB NOT NULL
 893                )
 894            "})?()
 895        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
 896
 897        let db = Self {
 898            executor: executor.clone(),
 899            connection: Arc::new(Mutex::new(connection)),
 900        };
 901
 902        if needs_migration_from_heed {
 903            let db_connection = db.connection();
 904            let executor_clone = executor.clone();
 905            executor
 906                .spawn(async move {
 907                    log::info!("Starting threads.db migration");
 908                    Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
 909                    std::fs::remove_dir_all(mdb_path)?;
 910                    log::info!("threads.db migrated to sqlite");
 911                    Ok::<(), anyhow::Error>(())
 912                })
 913                .detach();
 914        }
 915
 916        Ok(db)
 917    }
 918
 919    // Remove this migration after 2025-09-01
 920    fn migrate_from_heed(
 921        mdb_path: &Path,
 922        connection: Arc<Mutex<Connection>>,
 923        _executor: BackgroundExecutor,
 924    ) -> Result<()> {
 925        use heed::types::SerdeBincode;
 926        struct SerializedThreadHeed(SerializedThread);
 927
 928        impl heed::BytesEncode<'_> for SerializedThreadHeed {
 929            type EItem = SerializedThreadHeed;
 930
 931            fn bytes_encode(
 932                item: &Self::EItem,
 933            ) -> Result<std::borrow::Cow<'_, [u8]>, heed::BoxedError> {
 934                serde_json::to_vec(&item.0)
 935                    .map(std::borrow::Cow::Owned)
 936                    .map_err(Into::into)
 937            }
 938        }
 939
 940        impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
 941            type DItem = SerializedThreadHeed;
 942
 943            fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
 944                SerializedThread::from_json(bytes)
 945                    .map(SerializedThreadHeed)
 946                    .map_err(Into::into)
 947            }
 948        }
 949
 950        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
 951
 952        let env = unsafe {
 953            heed::EnvOpenOptions::new()
 954                .map_size(ONE_GB_IN_BYTES)
 955                .max_dbs(1)
 956                .open(mdb_path)?
 957        };
 958
 959        let txn = env.write_txn()?;
 960        let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
 961            .open_database(&txn, Some("threads"))?
 962            .ok_or_else(|| anyhow!("threads database not found"))?;
 963
 964        for result in threads.iter(&txn)? {
 965            let (thread_id, thread_heed) = result?;
 966            Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
 967        }
 968
 969        Ok(())
 970    }
 971
 972    fn save_thread_sync(
 973        connection: &Arc<Mutex<Connection>>,
 974        id: ThreadId,
 975        thread: SerializedThread,
 976    ) -> Result<()> {
 977        let json_data = serde_json::to_string(&thread)?;
 978        let summary = thread.summary.to_string();
 979        let updated_at = thread.updated_at.to_rfc3339();
 980
 981        let connection = connection.lock().unwrap();
 982
 983        let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
 984        let data_type = DataType::Zstd;
 985        let data = compressed;
 986
 987        let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
 988            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
 989        "})?;
 990
 991        insert((id, summary, updated_at, data_type, data))?;
 992
 993        Ok(())
 994    }
 995
 996    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
 997        let connection = self.connection.clone();
 998
 999        self.executor.spawn(async move {
1000            let connection = connection.lock().unwrap();
1001            let mut select =
1002                connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
1003                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
1004            "})?;
1005
1006            let rows = select(())?;
1007            let mut threads = Vec::new();
1008
1009            for (id, summary, updated_at) in rows {
1010                threads.push(SerializedThreadMetadata {
1011                    id,
1012                    summary: summary.into(),
1013                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
1014                });
1015            }
1016
1017            Ok(threads)
1018        })
1019    }
1020
1021    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
1022        let connection = self.connection.clone();
1023
1024        self.executor.spawn(async move {
1025            let connection = connection.lock().unwrap();
1026            let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
1027                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
1028            "})?;
1029
1030            let rows = select(id)?;
1031            if let Some((data_type, data)) = rows.into_iter().next() {
1032                let json_data = match data_type {
1033                    DataType::Zstd => {
1034                        let decompressed = zstd::decode_all(&data[..])?;
1035                        String::from_utf8(decompressed)?
1036                    }
1037                    DataType::Json => String::from_utf8(data)?,
1038                };
1039
1040                let thread = SerializedThread::from_json(json_data.as_bytes())?;
1041                Ok(Some(thread))
1042            } else {
1043                Ok(None)
1044            }
1045        })
1046    }
1047
1048    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
1049        let connection = self.connection.clone();
1050
1051        self.executor
1052            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
1053    }
1054
1055    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
1056        let connection = self.connection.clone();
1057
1058        self.executor.spawn(async move {
1059            let connection = connection.lock().unwrap();
1060
1061            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
1062                DELETE FROM threads WHERE id = ?
1063            "})?;
1064
1065            delete(id)?;
1066
1067            Ok(())
1068        })
1069    }
1070}
1071
// Unit tests for upgrading older serialized thread formats to `SerializedThread`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::thread::{DetailedSummaryState, MessageId};
    use chrono::Utc;
    use language_model::{Role, TokenUsage};
    use pretty_assertions::assert_eq;

    // Upgrading a `LegacySerializedThread` should wrap each message's `text` in a single
    // `Text` segment. Every field the legacy format lacks gets an empty/default value, and
    // the version is set to `SerializedThread::VERSION`.
    #[test]
    fn test_legacy_serialized_thread_upgrade() {
        let updated_at = Utc::now();
        let legacy_thread = LegacySerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![LegacySerializedMessage {
                id: MessageId(1),
                role: Role::User,
                text: "Hello, world!".to_string(),
                tool_uses: vec![],
                tool_results: vec![],
            }],
            initial_project_snapshot: None,
        };

        let upgraded = legacy_thread.upgrade();

        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Hello, world!".to_string()
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false
                }],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                exceeded_window_error: None,
                model: None,
                completion_mode: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }

    // In this fixture, the v0.1.0 upgrade moves the tool results from the user message
    // that follows the tool-using assistant message onto that assistant message. The
    // upgraded thread then no longer contains the user message. The version is bumped to
    // `SerializedThread::VERSION`.
    #[test]
    fn test_serialized_threadv0_1_0_upgrade() {
        let updated_at = Utc::now();
        let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![
                SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Use tool_1".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                SerializedMessage {
                    id: MessageId(2),
                    role: Role::Assistant,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "I want to use a tool".to_string(),
                    }],
                    tool_uses: vec![SerializedToolUse {
                        id: "abc".into(),
                        name: "tool_1".into(),
                        input: serde_json::Value::Null,
                    }],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                // NOTE(review): this message reuses `MessageId(1)`; it is dropped by the
                // upgrade, so the assertion is unaffected, but a distinct id (e.g. 3) would
                // make the fixture more realistic.
                SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Here is the tool result".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![SerializedToolResult {
                        tool_use_id: "abc".into(),
                        is_error: false,
                        content: LanguageModelToolResultContent::Text("abcdef".into()),
                        output: Some(serde_json::Value::Null),
                    }],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
            ],
            version: SerializedThreadV0_1_0::VERSION.to_string(),
            initial_project_snapshot: None,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: vec![],
            detailed_summary_state: DetailedSummaryState::default(),
            exceeded_window_error: None,
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
            profile: None,
        });
        let upgraded = thread_v0_1_0.upgrade();

        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![
                    SerializedMessage {
                        id: MessageId(1),
                        role: Role::User,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "Use tool_1".to_string()
                        }],
                        tool_uses: vec![],
                        tool_results: vec![],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false
                    },
                    SerializedMessage {
                        id: MessageId(2),
                        role: Role::Assistant,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "I want to use a tool".to_string(),
                        }],
                        tool_uses: vec![SerializedToolUse {
                            id: "abc".into(),
                            name: "tool_1".into(),
                            input: serde_json::Value::Null,
                        }],
                        tool_results: vec![SerializedToolResult {
                            tool_use_id: "abc".into(),
                            is_error: false,
                            content: LanguageModelToolResultContent::Text("abcdef".into()),
                            output: Some(serde_json::Value::Null),
                        }],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false,
                    },
                ],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                exceeded_window_error: None,
                model: None,
                completion_mode: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }
}