thread_store.rs

   1use crate::{
   2    context_server_tool::ContextServerTool,
   3    thread::{
   4        DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
   5    },
   6};
   7use agent_settings::{AgentProfileId, CompletionMode};
   8use anyhow::{Context as _, Result, anyhow};
   9use assistant_tool::{Tool, ToolId, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::HashMap;
  12use context_server::ContextServerId;
  13use futures::{
  14    FutureExt as _, StreamExt as _,
  15    channel::{mpsc, oneshot},
  16    future::{self, BoxFuture, Shared},
  17};
  18use gpui::{
  19    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
  20    Subscription, Task, Window, prelude::*,
  21};
  22use indoc::indoc;
  23use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
  24use project::context_server_store::{ContextServerStatus, ContextServerStore};
  25use project::{Project, ProjectItem, ProjectPath, Worktree};
  26use prompt_store::{
  27    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
  28    UserRulesContext, WorktreeContext,
  29};
  30use serde::{Deserialize, Serialize};
  31use sqlez::{
  32    bindable::{Bind, Column},
  33    connection::Connection,
  34    statement::Statement,
  35};
  36use std::{
  37    cell::{Ref, RefCell},
  38    path::{Path, PathBuf},
  39    rc::Rc,
  40    sync::{Arc, Mutex},
  41};
  42use util::ResultExt as _;
  43
  44use zed_env_vars::ZED_STATELESS;
  45
  46#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  47pub enum DataType {
  48    #[serde(rename = "json")]
  49    Json,
  50    #[serde(rename = "zstd")]
  51    Zstd,
  52}
  53
  54impl Bind for DataType {
  55    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
  56        let value = match self {
  57            DataType::Json => "json",
  58            DataType::Zstd => "zstd",
  59        };
  60        value.bind(statement, start_index)
  61    }
  62}
  63
  64impl Column for DataType {
  65    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
  66        let (value, next_index) = String::column(statement, start_index)?;
  67        let data_type = match value.as_str() {
  68            "json" => DataType::Json,
  69            "zstd" => DataType::Zstd,
  70            _ => anyhow::bail!("Unknown data type: {}", value),
  71        };
  72        Ok((data_type, next_index))
  73    }
  74}
  75
/// Candidate rules-file paths, relative to a worktree root, in priority order.
///
/// When loading rules for a worktree, only the first of these that exists as a file is used.
/// A worktree update touching any of these paths triggers a system-prompt reload.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
  87
  88pub fn init(cx: &mut App) {
  89    ThreadsDatabase::init(cx);
  90}
  91
  92/// A system prompt shared by all threads created by this ThreadStore
  93#[derive(Clone, Default)]
  94pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
  95
  96impl SharedProjectContext {
  97    pub fn borrow(&self) -> Ref<'_, Option<ProjectContext>> {
  98        self.0.borrow()
  99    }
 100}
 101
/// Store for text threads; an alias for [`assistant_context::ContextStore`].
pub type TextThreadStore = assistant_context::ContextStore;
 103
/// Tracks the saved threads of a project and the shared project context (worktrees, rules
/// files, default user rules) that threads created from it use for their system prompt.
pub struct ThreadStore {
    project: Entity<Project>,
    /// Tools handed to threads; context-server tools are added and removed here as servers
    /// start and stop.
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    prompt_store: Option<Entity<PromptStore>>,
    /// Tool ids registered for each context server, kept so they can be removed when the
    /// server stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Metadata of saved threads, as last loaded from the threads database.
    threads: Vec<SerializedThreadMetadata>,
    /// Context shared with every thread created by this store.
    project_context: SharedProjectContext,
    /// Used to ask the reload task to rebuild the system prompt.
    reload_system_prompt_tx: mpsc::Sender<()>,
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}
 116
/// Event emitted by [`ThreadStore`] when a worktree rules file or a default user rule could
/// not be loaded.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}
 122
 123impl ThreadStore {
 124    pub fn load(
 125        project: Entity<Project>,
 126        tools: Entity<ToolWorkingSet>,
 127        prompt_store: Option<Entity<PromptStore>>,
 128        prompt_builder: Arc<PromptBuilder>,
 129        cx: &mut App,
 130    ) -> Task<Result<Entity<Self>>> {
 131        cx.spawn(async move |cx| {
 132            let (thread_store, ready_rx) = cx.update(|cx| {
 133                let mut option_ready_rx = None;
 134                let thread_store = cx.new(|cx| {
 135                    let (thread_store, ready_rx) =
 136                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 137                    option_ready_rx = Some(ready_rx);
 138                    thread_store
 139                });
 140                (thread_store, option_ready_rx.take().unwrap())
 141            })?;
 142            ready_rx.await?;
 143            Ok(thread_store)
 144        })
 145    }
 146
 147    fn new(
 148        project: Entity<Project>,
 149        tools: Entity<ToolWorkingSet>,
 150        prompt_builder: Arc<PromptBuilder>,
 151        prompt_store: Option<Entity<PromptStore>>,
 152        cx: &mut Context<Self>,
 153    ) -> (Self, oneshot::Receiver<()>) {
 154        let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
 155
 156        if let Some(prompt_store) = prompt_store.as_ref() {
 157            subscriptions.push(cx.subscribe(
 158                prompt_store,
 159                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
 160                    this.enqueue_system_prompt_reload();
 161                },
 162            ))
 163        }
 164
 165        // This channel and task prevent concurrent and redundant loading of the system prompt.
 166        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
 167        let (ready_tx, ready_rx) = oneshot::channel();
 168        let mut ready_tx = Some(ready_tx);
 169        let reload_system_prompt_task = cx.spawn({
 170            let prompt_store = prompt_store.clone();
 171            async move |thread_store, cx| {
 172                loop {
 173                    let Some(reload_task) = thread_store
 174                        .update(cx, |thread_store, cx| {
 175                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
 176                        })
 177                        .ok()
 178                    else {
 179                        return;
 180                    };
 181                    reload_task.await;
 182                    if let Some(ready_tx) = ready_tx.take() {
 183                        ready_tx.send(()).ok();
 184                    }
 185                    reload_system_prompt_rx.next().await;
 186                }
 187            }
 188        });
 189
 190        let this = Self {
 191            project,
 192            tools,
 193            prompt_builder,
 194            prompt_store,
 195            context_server_tool_ids: HashMap::default(),
 196            threads: Vec::new(),
 197            project_context: SharedProjectContext::default(),
 198            reload_system_prompt_tx,
 199            _reload_system_prompt_task: reload_system_prompt_task,
 200            _subscriptions: subscriptions,
 201        };
 202        this.register_context_server_handlers(cx);
 203        this.reload(cx).detach_and_log_err(cx);
 204        (this, ready_rx)
 205    }
 206
 207    #[cfg(any(test, feature = "test-support"))]
 208    pub fn fake(project: Entity<Project>, cx: &mut App) -> Self {
 209        Self {
 210            project,
 211            tools: cx.new(|_| ToolWorkingSet::default()),
 212            prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
 213            prompt_store: None,
 214            context_server_tool_ids: HashMap::default(),
 215            threads: Vec::new(),
 216            project_context: SharedProjectContext::default(),
 217            reload_system_prompt_tx: mpsc::channel(0).0,
 218            _reload_system_prompt_task: Task::ready(()),
 219            _subscriptions: vec![],
 220        }
 221    }
 222
 223    fn handle_project_event(
 224        &mut self,
 225        _project: Entity<Project>,
 226        event: &project::Event,
 227        _cx: &mut Context<Self>,
 228    ) {
 229        match event {
 230            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 231                self.enqueue_system_prompt_reload();
 232            }
 233            project::Event::WorktreeUpdatedEntries(_, items) => {
 234                if items.iter().any(|(path, _, _)| {
 235                    RULES_FILE_NAMES
 236                        .iter()
 237                        .any(|name| path.as_ref() == Path::new(name))
 238                }) {
 239                    self.enqueue_system_prompt_reload();
 240                }
 241            }
 242            _ => {}
 243        }
 244    }
 245
 246    fn enqueue_system_prompt_reload(&mut self) {
 247        self.reload_system_prompt_tx.try_send(()).ok();
 248    }
 249
 250    // Note that this should only be called from `reload_system_prompt_task`.
 251    fn reload_system_prompt(
 252        &self,
 253        prompt_store: Option<Entity<PromptStore>>,
 254        cx: &mut Context<Self>,
 255    ) -> Task<()> {
 256        let worktrees = self
 257            .project
 258            .read(cx)
 259            .visible_worktrees(cx)
 260            .collect::<Vec<_>>();
 261        let worktree_tasks = worktrees
 262            .into_iter()
 263            .map(|worktree| {
 264                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
 265            })
 266            .collect::<Vec<_>>();
 267        let default_user_rules_task = match prompt_store {
 268            None => Task::ready(vec![]),
 269            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
 270                let prompts = prompt_store.default_prompt_metadata();
 271                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 272                    let contents = prompt_store.load(prompt_metadata.id, cx);
 273                    async move { (contents.await, prompt_metadata) }
 274                });
 275                cx.background_spawn(future::join_all(load_tasks))
 276            }),
 277        };
 278
 279        cx.spawn(async move |this, cx| {
 280            let (worktrees, default_user_rules) =
 281                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 282
 283            let worktrees = worktrees
 284                .into_iter()
 285                .map(|(worktree, rules_error)| {
 286                    if let Some(rules_error) = rules_error {
 287                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 288                    }
 289                    worktree
 290                })
 291                .collect::<Vec<_>>();
 292
 293            let default_user_rules = default_user_rules
 294                .into_iter()
 295                .flat_map(|(contents, prompt_metadata)| match contents {
 296                    Ok(contents) => Some(UserRulesContext {
 297                        uuid: match prompt_metadata.id {
 298                            PromptId::User { uuid } => uuid,
 299                            PromptId::EditWorkflow => return None,
 300                        },
 301                        title: prompt_metadata.title.map(|title| title.to_string()),
 302                        contents,
 303                    }),
 304                    Err(err) => {
 305                        this.update(cx, |_, cx| {
 306                            cx.emit(RulesLoadingError {
 307                                message: format!("{err:?}").into(),
 308                            });
 309                        })
 310                        .ok();
 311                        None
 312                    }
 313                })
 314                .collect::<Vec<_>>();
 315
 316            this.update(cx, |this, _cx| {
 317                *this.project_context.0.borrow_mut() =
 318                    Some(ProjectContext::new(worktrees, default_user_rules));
 319            })
 320            .ok();
 321        })
 322    }
 323
 324    fn load_worktree_info_for_system_prompt(
 325        worktree: Entity<Worktree>,
 326        project: Entity<Project>,
 327        cx: &mut App,
 328    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 329        let tree = worktree.read(cx);
 330        let root_name = tree.root_name().into();
 331        let abs_path = tree.abs_path();
 332
 333        let mut context = WorktreeContext {
 334            root_name,
 335            abs_path,
 336            rules_file: None,
 337        };
 338
 339        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 340        let Some(rules_task) = rules_task else {
 341            return Task::ready((context, None));
 342        };
 343
 344        cx.spawn(async move |_| {
 345            let (rules_file, rules_file_error) = match rules_task.await {
 346                Ok(rules_file) => (Some(rules_file), None),
 347                Err(err) => (
 348                    None,
 349                    Some(RulesLoadingError {
 350                        message: format!("{err}").into(),
 351                    }),
 352                ),
 353            };
 354            context.rules_file = rules_file;
 355            (context, rules_file_error)
 356        })
 357    }
 358
 359    fn load_worktree_rules_file(
 360        worktree: Entity<Worktree>,
 361        project: Entity<Project>,
 362        cx: &mut App,
 363    ) -> Option<Task<Result<RulesFileContext>>> {
 364        let worktree = worktree.read(cx);
 365        let worktree_id = worktree.id();
 366        let selected_rules_file = RULES_FILE_NAMES
 367            .into_iter()
 368            .filter_map(|name| {
 369                worktree
 370                    .entry_for_path(name)
 371                    .filter(|entry| entry.is_file())
 372                    .map(|entry| entry.path.clone())
 373            })
 374            .next();
 375
 376        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 377        // supported. This doesn't seem to occur often in GitHub repositories.
 378        selected_rules_file.map(|path_in_worktree| {
 379            let project_path = ProjectPath {
 380                worktree_id,
 381                path: path_in_worktree.clone(),
 382            };
 383            let buffer_task =
 384                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 385            let rope_task = cx.spawn(async move |cx| {
 386                buffer_task.await?.read_with(cx, |buffer, cx| {
 387                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 388                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 389                })?
 390            });
 391            // Build a string from the rope on a background thread.
 392            cx.background_spawn(async move {
 393                let (project_entry_id, rope) = rope_task.await?;
 394                anyhow::Ok(RulesFileContext {
 395                    path_in_worktree,
 396                    text: rope.to_string().trim().to_string(),
 397                    project_entry_id: project_entry_id.to_usize(),
 398                })
 399            })
 400        })
 401    }
 402
 403    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
 404        &self.prompt_store
 405    }
 406
 407    pub fn tools(&self) -> Entity<ToolWorkingSet> {
 408        self.tools.clone()
 409    }
 410
 411    /// Returns the number of threads.
 412    pub fn thread_count(&self) -> usize {
 413        self.threads.len()
 414    }
 415
 416    pub fn reverse_chronological_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
 417        // ordering is from "ORDER BY" in `list_threads`
 418        self.threads.iter()
 419    }
 420
 421    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 422        cx.new(|cx| {
 423            Thread::new(
 424                self.project.clone(),
 425                self.tools.clone(),
 426                self.prompt_builder.clone(),
 427                self.project_context.clone(),
 428                cx,
 429            )
 430        })
 431    }
 432
 433    pub fn create_thread_from_serialized(
 434        &mut self,
 435        serialized: SerializedThread,
 436        cx: &mut Context<Self>,
 437    ) -> Entity<Thread> {
 438        cx.new(|cx| {
 439            Thread::deserialize(
 440                ThreadId::new(),
 441                serialized,
 442                self.project.clone(),
 443                self.tools.clone(),
 444                self.prompt_builder.clone(),
 445                self.project_context.clone(),
 446                None,
 447                cx,
 448            )
 449        })
 450    }
 451
 452    pub fn open_thread(
 453        &self,
 454        id: &ThreadId,
 455        window: &mut Window,
 456        cx: &mut Context<Self>,
 457    ) -> Task<Result<Entity<Thread>>> {
 458        let id = id.clone();
 459        let database_future = ThreadsDatabase::global_future(cx);
 460        let this = cx.weak_entity();
 461        window.spawn(cx, async move |cx| {
 462            let database = database_future.await.map_err(|err| anyhow!(err))?;
 463            let thread = database
 464                .try_find_thread(id.clone())
 465                .await?
 466                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 467
 468            let thread = this.update_in(cx, |this, window, cx| {
 469                cx.new(|cx| {
 470                    Thread::deserialize(
 471                        id.clone(),
 472                        thread,
 473                        this.project.clone(),
 474                        this.tools.clone(),
 475                        this.prompt_builder.clone(),
 476                        this.project_context.clone(),
 477                        Some(window),
 478                        cx,
 479                    )
 480                })
 481            })?;
 482
 483            Ok(thread)
 484        })
 485    }
 486
 487    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
 488        let (metadata, serialized_thread) =
 489            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
 490
 491        let database_future = ThreadsDatabase::global_future(cx);
 492        cx.spawn(async move |this, cx| {
 493            let serialized_thread = serialized_thread.await?;
 494            let database = database_future.await.map_err(|err| anyhow!(err))?;
 495            database.save_thread(metadata, serialized_thread).await?;
 496
 497            this.update(cx, |this, cx| this.reload(cx))?.await
 498        })
 499    }
 500
 501    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
 502        let id = id.clone();
 503        let database_future = ThreadsDatabase::global_future(cx);
 504        cx.spawn(async move |this, cx| {
 505            let database = database_future.await.map_err(|err| anyhow!(err))?;
 506            database.delete_thread(id.clone()).await?;
 507
 508            this.update(cx, |this, cx| {
 509                this.threads.retain(|thread| thread.id != id);
 510                cx.notify();
 511            })
 512        })
 513    }
 514
 515    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 516        let database_future = ThreadsDatabase::global_future(cx);
 517        cx.spawn(async move |this, cx| {
 518            let threads = database_future
 519                .await
 520                .map_err(|err| anyhow!(err))?
 521                .list_threads()
 522                .await?;
 523
 524            this.update(cx, |this, cx| {
 525                this.threads = threads;
 526                cx.notify();
 527            })
 528        })
 529    }
 530
 531    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
 532        let context_server_store = self.project.read(cx).context_server_store();
 533        cx.subscribe(&context_server_store, Self::handle_context_server_event)
 534            .detach();
 535
 536        // Check for any servers that were already running before the handler was registered
 537        for server in context_server_store.read(cx).running_servers() {
 538            self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
 539        }
 540    }
 541
 542    fn handle_context_server_event(
 543        &mut self,
 544        context_server_store: Entity<ContextServerStore>,
 545        event: &project::context_server_store::Event,
 546        cx: &mut Context<Self>,
 547    ) {
 548        let tool_working_set = self.tools.clone();
 549        match event {
 550            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
 551                match status {
 552                    ContextServerStatus::Starting => {}
 553                    ContextServerStatus::Running => {
 554                        self.load_context_server_tools(server_id.clone(), context_server_store, cx);
 555                    }
 556                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
 557                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
 558                            tool_working_set.update(cx, |tool_working_set, cx| {
 559                                tool_working_set.remove(&tool_ids, cx);
 560                            });
 561                        }
 562                    }
 563                }
 564            }
 565        }
 566    }
 567
 568    fn load_context_server_tools(
 569        &self,
 570        server_id: ContextServerId,
 571        context_server_store: Entity<ContextServerStore>,
 572        cx: &mut Context<Self>,
 573    ) {
 574        let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
 575            return;
 576        };
 577        let tool_working_set = self.tools.clone();
 578        cx.spawn(async move |this, cx| {
 579            let Some(protocol) = server.client() else {
 580                return;
 581            };
 582
 583            if protocol.capable(context_server::protocol::ServerCapability::Tools)
 584                && let Some(response) = protocol
 585                    .request::<context_server::types::requests::ListTools>(())
 586                    .await
 587                    .log_err()
 588            {
 589                let tool_ids = tool_working_set
 590                    .update(cx, |tool_working_set, cx| {
 591                        tool_working_set.extend(
 592                            response.tools.into_iter().map(|tool| {
 593                                Arc::new(ContextServerTool::new(
 594                                    context_server_store.clone(),
 595                                    server.id(),
 596                                    tool,
 597                                )) as Arc<dyn Tool>
 598                            }),
 599                            cx,
 600                        )
 601                    })
 602                    .log_err();
 603
 604                if let Some(tool_ids) = tool_ids {
 605                    this.update(cx, |this, _| {
 606                        this.context_server_tool_ids.insert(server_id, tool_ids);
 607                    })
 608                    .log_err();
 609                }
 610            }
 611        })
 612        .detach();
 613    }
 614}
 615
/// Summary information about a saved thread, used to list threads without loading their
/// messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    /// Thread summary (title) text.
    pub summary: SharedString,
    /// When the thread was last saved.
    pub updated_at: DateTime<Utc>,
}
 622
/// A thread in the current on-disk format (see [`SerializedThread::VERSION`]).
///
/// Fields marked `#[serde(default)]` may be absent from threads saved by older versions.
/// When absent, they are filled with their default values on load.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedThread {
    /// Format version the thread was written with.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    /// Language model the thread was using, if recorded.
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
 648
/// Identifies a language model by provider and model name.
/// NOTE(review): presumably these are provider/model ids as used by the model registry;
/// confirm.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
 654
 655impl SerializedThread {
 656    pub const VERSION: &'static str = "0.2.0";
 657
 658    pub fn from_json(json: &[u8]) -> Result<Self> {
 659        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 660        match saved_thread_json.get("version") {
 661            Some(serde_json::Value::String(version)) => match version.as_str() {
 662                SerializedThreadV0_1_0::VERSION => {
 663                    let saved_thread =
 664                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
 665                    Ok(saved_thread.upgrade())
 666                }
 667                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
 668                    saved_thread_json,
 669                )?),
 670                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 671            },
 672            None => {
 673                let saved_thread =
 674                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
 675                Ok(saved_thread.upgrade())
 676            }
 677            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 678        }
 679    }
 680}
 681
/// Thread format version `0.1.0`.
///
/// Its field layout matches the current format, so it wraps [`SerializedThread`]. What
/// changes on upgrade is where tool results live: in this version they sit on a user message
/// rather than on the preceding assistant message.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0.
    SerializedThread,
);
 688
 689impl SerializedThreadV0_1_0 {
 690    pub const VERSION: &'static str = "0.1.0";
 691
 692    pub fn upgrade(self) -> SerializedThread {
 693        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
 694
 695        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
 696
 697        for message in self.0.messages {
 698            if message.role == Role::User
 699                && !message.tool_results.is_empty()
 700                && let Some(last_message) = messages.last_mut()
 701            {
 702                debug_assert!(last_message.role == Role::Assistant);
 703
 704                last_message.tool_results = message.tool_results;
 705                continue;
 706            }
 707
 708            messages.push(message);
 709        }
 710
 711        SerializedThread {
 712            messages,
 713            version: SerializedThread::VERSION.to_string(),
 714            ..self.0
 715        }
 716    }
 717}
 718
/// A single message of a serialized thread.
///
/// Fields marked `#[serde(default)]` may be missing from older saved data.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Message content, in order.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    #[serde(default)]
    pub is_hidden: bool,
}
 736
/// One piece of a message's content, internally tagged by a `"type"` field in JSON.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Plain text; tagged `"text"`.
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    /// Thinking text, presumably the model's reasoning output; tagged `"thinking"`.
    /// `signature` is omitted from the output when `None`.
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    /// Redacted thinking payload. It has no `rename`, so its tag is `"RedactedThinking"`.
    /// NOTE(review): renaming it now would break previously saved threads unless an
    /// `alias` is added.
    RedactedThinking {
        data: String,
    },
}
 754
/// A tool call made in a message.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    /// Name of the tool that was invoked.
    pub name: SharedString,
    /// Arguments passed to the tool, as JSON.
    pub input: serde_json::Value,
}
 761
/// The result of a tool call, linked to it by `tool_use_id`.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    /// NOTE(review): presumably the content reported back to the model; confirm.
    pub content: LanguageModelToolResultContent,
    /// Optional extra tool output, as JSON.
    pub output: Option<serde_json::Value>,
}
 769
/// Thread format from before the `version` field existed; JSON with no `version` is parsed
/// as this type.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
 778
 779impl LegacySerializedThread {
 780    pub fn upgrade(self) -> SerializedThread {
 781        SerializedThread {
 782            version: SerializedThread::VERSION.to_string(),
 783            summary: self.summary,
 784            updated_at: self.updated_at,
 785            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
 786            initial_project_snapshot: self.initial_project_snapshot,
 787            cumulative_token_usage: TokenUsage::default(),
 788            request_token_usage: Vec::new(),
 789            detailed_summary_state: DetailedSummaryState::default(),
 790            exceeded_window_error: None,
 791            model: None,
 792            completion_mode: None,
 793            tool_use_limit_reached: false,
 794            profile: None,
 795        }
 796    }
 797}
 798
/// Message in the legacy (unversioned) thread format; its content is a single text body.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
 809
 810impl LegacySerializedMessage {
 811    fn upgrade(self) -> SerializedMessage {
 812        SerializedMessage {
 813            id: self.id,
 814            role: self.role,
 815            segments: vec![SerializedMessageSegment::Text { text: self.text }],
 816            tool_uses: self.tool_uses,
 817            tool_results: self.tool_results,
 818            context: String::new(),
 819            creases: Vec::new(),
 820            is_hidden: false,
 821        }
 822    }
 823}
 824
/// A collapsed region of a message, stored with its icon and label.
/// NOTE(review): `start`/`end` are presumably offsets into the message text; confirm the
/// unit (bytes vs chars).
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct SerializedCrease {
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}
 832
/// Global slot holding the shared future that opens the threads database.
///
/// The future is `Shared`, so every caller awaits the same initialization. The error is
/// wrapped in `Arc` because `Shared` needs a cloneable output.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
 838
/// Persistent storage for serialized threads.
pub(crate) struct ThreadsDatabase {
    /// Background executor; presumably used to run database work off the main thread.
    executor: BackgroundExecutor,
    /// Database connection, shared behind a mutex.
    connection: Arc<Mutex<Connection>>,
}
 843
 844impl ThreadsDatabase {
 845    fn connection(&self) -> Arc<Mutex<Connection>> {
 846        self.connection.clone()
 847    }
 848
 849    const COMPRESSION_LEVEL: i32 = 3;
 850}
 851
 852impl Bind for ThreadId {
 853    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
 854        self.to_string().bind(statement, start_index)
 855    }
 856}
 857
 858impl Column for ThreadId {
 859    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
 860        let (id_str, next_index) = String::column(statement, start_index)?;
 861        Ok((ThreadId::from(id_str.as_str()), next_index))
 862    }
 863}
 864
 865impl ThreadsDatabase {
 866    fn global_future(
 867        cx: &mut App,
 868    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
 869        GlobalThreadsDatabase::global(cx).0.clone()
 870    }
 871
 872    fn init(cx: &mut App) {
 873        let executor = cx.background_executor().clone();
 874        let database_future = executor
 875            .spawn({
 876                let executor = executor.clone();
 877                let threads_dir = paths::data_dir().join("threads");
 878                async move { ThreadsDatabase::new(threads_dir, executor) }
 879            })
 880            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
 881            .boxed()
 882            .shared();
 883
 884        cx.set_global(GlobalThreadsDatabase(database_future));
 885    }
 886
 887    pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
 888        std::fs::create_dir_all(&threads_dir)?;
 889
 890        let sqlite_path = threads_dir.join("threads.db");
 891        let mdb_path = threads_dir.join("threads-db.1.mdb");
 892
 893        let needs_migration_from_heed = mdb_path.exists();
 894
 895        let connection = if *ZED_STATELESS {
 896            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
 897        } else if cfg!(any(feature = "test-support", test)) {
 898            // rust stores the name of the test on the current thread.
 899            // We use this to automatically create a database that will
 900            // be shared within the test (for the test_retrieve_old_thread)
 901            // but not with concurrent tests.
 902            let thread = std::thread::current();
 903            let test_name = thread.name();
 904            Connection::open_memory(Some(&format!(
 905                "THREAD_FALLBACK_{}",
 906                test_name.unwrap_or_default()
 907            )))
 908        } else {
 909            Connection::open_file(&sqlite_path.to_string_lossy())
 910        };
 911
 912        connection.exec(indoc! {"
 913                CREATE TABLE IF NOT EXISTS threads (
 914                    id TEXT PRIMARY KEY,
 915                    summary TEXT NOT NULL,
 916                    updated_at TEXT NOT NULL,
 917                    data_type TEXT NOT NULL,
 918                    data BLOB NOT NULL
 919                )
 920            "})?()
 921        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
 922
 923        let db = Self {
 924            executor: executor.clone(),
 925            connection: Arc::new(Mutex::new(connection)),
 926        };
 927
 928        if needs_migration_from_heed {
 929            let db_connection = db.connection();
 930            let executor_clone = executor.clone();
 931            executor
 932                .spawn(async move {
 933                    log::info!("Starting threads.db migration");
 934                    Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
 935                    std::fs::remove_dir_all(mdb_path)?;
 936                    log::info!("threads.db migrated to sqlite");
 937                    Ok::<(), anyhow::Error>(())
 938                })
 939                .detach();
 940        }
 941
 942        Ok(db)
 943    }
 944
 945    // Remove this migration after 2025-09-01
 946    fn migrate_from_heed(
 947        mdb_path: &Path,
 948        connection: Arc<Mutex<Connection>>,
 949        _executor: BackgroundExecutor,
 950    ) -> Result<()> {
 951        use heed::types::SerdeBincode;
 952        struct SerializedThreadHeed(SerializedThread);
 953
 954        impl heed::BytesEncode<'_> for SerializedThreadHeed {
 955            type EItem = SerializedThreadHeed;
 956
 957            fn bytes_encode(
 958                item: &Self::EItem,
 959            ) -> Result<std::borrow::Cow<'_, [u8]>, heed::BoxedError> {
 960                serde_json::to_vec(&item.0)
 961                    .map(std::borrow::Cow::Owned)
 962                    .map_err(Into::into)
 963            }
 964        }
 965
 966        impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
 967            type DItem = SerializedThreadHeed;
 968
 969            fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
 970                SerializedThread::from_json(bytes)
 971                    .map(SerializedThreadHeed)
 972                    .map_err(Into::into)
 973            }
 974        }
 975
 976        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
 977
 978        let env = unsafe {
 979            heed::EnvOpenOptions::new()
 980                .map_size(ONE_GB_IN_BYTES)
 981                .max_dbs(1)
 982                .open(mdb_path)?
 983        };
 984
 985        let txn = env.write_txn()?;
 986        let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
 987            .open_database(&txn, Some("threads"))?
 988            .ok_or_else(|| anyhow!("threads database not found"))?;
 989
 990        for result in threads.iter(&txn)? {
 991            let (thread_id, thread_heed) = result?;
 992            Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
 993        }
 994
 995        Ok(())
 996    }
 997
 998    fn save_thread_sync(
 999        connection: &Arc<Mutex<Connection>>,
1000        id: ThreadId,
1001        thread: SerializedThread,
1002    ) -> Result<()> {
1003        let json_data = serde_json::to_string(&thread)?;
1004        let summary = thread.summary.to_string();
1005        let updated_at = thread.updated_at.to_rfc3339();
1006
1007        let connection = connection.lock().unwrap();
1008
1009        let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
1010        let data_type = DataType::Zstd;
1011        let data = compressed;
1012
1013        let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
1014            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
1015        "})?;
1016
1017        insert((id, summary, updated_at, data_type, data))?;
1018
1019        Ok(())
1020    }
1021
1022    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
1023        let connection = self.connection.clone();
1024
1025        self.executor.spawn(async move {
1026            let connection = connection.lock().unwrap();
1027            let mut select =
1028                connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
1029                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
1030            "})?;
1031
1032            let rows = select(())?;
1033            let mut threads = Vec::new();
1034
1035            for (id, summary, updated_at) in rows {
1036                threads.push(SerializedThreadMetadata {
1037                    id,
1038                    summary: summary.into(),
1039                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
1040                });
1041            }
1042
1043            Ok(threads)
1044        })
1045    }
1046
1047    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
1048        let connection = self.connection.clone();
1049
1050        self.executor.spawn(async move {
1051            let connection = connection.lock().unwrap();
1052            let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
1053                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
1054            "})?;
1055
1056            let rows = select(id)?;
1057            if let Some((data_type, data)) = rows.into_iter().next() {
1058                let json_data = match data_type {
1059                    DataType::Zstd => {
1060                        let decompressed = zstd::decode_all(&data[..])?;
1061                        String::from_utf8(decompressed)?
1062                    }
1063                    DataType::Json => String::from_utf8(data)?,
1064                };
1065
1066                let thread = SerializedThread::from_json(json_data.as_bytes())?;
1067                Ok(Some(thread))
1068            } else {
1069                Ok(None)
1070            }
1071        })
1072    }
1073
1074    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
1075        let connection = self.connection.clone();
1076
1077        self.executor
1078            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
1079    }
1080
1081    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
1082        let connection = self.connection.clone();
1083
1084        self.executor.spawn(async move {
1085            let connection = connection.lock().unwrap();
1086
1087            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
1088                DELETE FROM threads WHERE id = ?
1089            "})?;
1090
1091            delete(id)?;
1092
1093            Ok(())
1094        })
1095    }
1096}
1097
#[cfg(test)]
mod tests {
    use super::*;
    use crate::thread::{DetailedSummaryState, MessageId};
    use chrono::{DateTime, Utc};
    use language_model::{Role, TokenUsage};
    use pretty_assertions::assert_eq;

    /// A visible message consisting of one text segment, with no tool activity,
    /// context, or creases.
    fn text_message(id: MessageId, role: Role, text: &str) -> SerializedMessage {
        SerializedMessage {
            id,
            role,
            segments: vec![SerializedMessageSegment::Text {
                text: text.to_string(),
            }],
            tool_uses: vec![],
            tool_results: vec![],
            context: "".to_string(),
            creases: vec![],
            is_hidden: false,
        }
    }

    /// The single tool call ("abc" invoking `tool_1`) used by the upgrade fixtures.
    fn tool_use() -> SerializedToolUse {
        SerializedToolUse {
            id: "abc".into(),
            name: "tool_1".into(),
            input: serde_json::Value::Null,
        }
    }

    /// The successful result answering [`tool_use`].
    fn tool_result() -> SerializedToolResult {
        SerializedToolResult {
            tool_use_id: "abc".into(),
            is_error: false,
            content: LanguageModelToolResultContent::Text("abcdef".into()),
            output: Some(serde_json::Value::Null),
        }
    }

    /// A thread titled "Test conversation" with the given timestamp, version, and
    /// messages; all remaining fields hold their empty/default values.
    fn test_thread(
        updated_at: DateTime<Utc>,
        version: String,
        messages: Vec<SerializedMessage>,
    ) -> SerializedThread {
        SerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages,
            version,
            initial_project_snapshot: None,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: vec![],
            detailed_summary_state: DetailedSummaryState::default(),
            exceeded_window_error: None,
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
            profile: None,
        }
    }

    #[test]
    fn test_legacy_serialized_thread_upgrade() {
        let updated_at = Utc::now();
        let legacy_thread = LegacySerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![LegacySerializedMessage {
                id: MessageId(1),
                role: Role::User,
                text: "Hello, world!".to_string(),
                tool_uses: vec![],
                tool_results: vec![],
            }],
            initial_project_snapshot: None,
        };

        let upgraded = legacy_thread.upgrade();
        let expected = test_thread(
            updated_at,
            SerializedThread::VERSION.to_string(),
            vec![text_message(MessageId(1), Role::User, "Hello, world!")],
        );

        assert_eq!(upgraded, expected)
    }

    #[test]
    fn test_serialized_threadv0_1_0_upgrade() {
        let updated_at = Utc::now();
        let thread_v0_1_0 = SerializedThreadV0_1_0(test_thread(
            updated_at,
            SerializedThreadV0_1_0::VERSION.to_string(),
            vec![
                text_message(MessageId(1), Role::User, "Use tool_1"),
                SerializedMessage {
                    tool_uses: vec![tool_use()],
                    ..text_message(MessageId(2), Role::Assistant, "I want to use a tool")
                },
                SerializedMessage {
                    tool_results: vec![tool_result()],
                    ..text_message(MessageId(1), Role::User, "Here is the tool result")
                },
            ],
        ));

        let upgraded = thread_v0_1_0.upgrade();
        // The user message carrying only the tool result is folded into the assistant
        // message that issued the tool call.
        let expected = test_thread(
            updated_at,
            SerializedThread::VERSION.to_string(),
            vec![
                text_message(MessageId(1), Role::User, "Use tool_1"),
                SerializedMessage {
                    tool_uses: vec![tool_use()],
                    tool_results: vec![tool_result()],
                    ..text_message(MessageId(2), Role::Assistant, "I want to use a tool")
                },
            ],
        );

        assert_eq!(upgraded, expected)
    }
}