thread_store.rs

   1use crate::{
   2    context_server_tool::ContextServerTool,
   3    thread::{
   4        DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
   5    },
   6};
   7use agent_settings::{AgentProfileId, CompletionMode};
   8use anyhow::{Context as _, Result, anyhow};
   9use assistant_tool::{Tool, ToolId, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use client::CloudUserStore;
  12use collections::HashMap;
  13use context_server::ContextServerId;
  14use futures::{
  15    FutureExt as _, StreamExt as _,
  16    channel::{mpsc, oneshot},
  17    future::{self, BoxFuture, Shared},
  18};
  19use gpui::{
  20    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
  21    Subscription, Task, Window, prelude::*,
  22};
  23use indoc::indoc;
  24use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
  25use project::context_server_store::{ContextServerStatus, ContextServerStore};
  26use project::{Project, ProjectItem, ProjectPath, Worktree};
  27use prompt_store::{
  28    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
  29    UserRulesContext, WorktreeContext,
  30};
  31use serde::{Deserialize, Serialize};
  32use sqlez::{
  33    bindable::{Bind, Column},
  34    connection::Connection,
  35    statement::Statement,
  36};
  37use std::{
  38    cell::{Ref, RefCell},
  39    path::{Path, PathBuf},
  40    rc::Rc,
  41    sync::{Arc, Mutex},
  42};
  43use util::ResultExt as _;
  44
  45pub static ZED_STATELESS: std::sync::LazyLock<bool> =
  46    std::sync::LazyLock::new(|| std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty()));
  47
  48#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  49pub enum DataType {
  50    #[serde(rename = "json")]
  51    Json,
  52    #[serde(rename = "zstd")]
  53    Zstd,
  54}
  55
  56impl Bind for DataType {
  57    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
  58        let value = match self {
  59            DataType::Json => "json",
  60            DataType::Zstd => "zstd",
  61        };
  62        value.bind(statement, start_index)
  63    }
  64}
  65
  66impl Column for DataType {
  67    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
  68        let (value, next_index) = String::column(statement, start_index)?;
  69        let data_type = match value.as_str() {
  70            "json" => DataType::Json,
  71            "zstd" => DataType::Zstd,
  72            _ => anyhow::bail!("Unknown data type: {}", value),
  73        };
  74        Ok((data_type, next_index))
  75    }
  76}
  77
  78const RULES_FILE_NAMES: [&'static str; 9] = [
  79    ".rules",
  80    ".cursorrules",
  81    ".windsurfrules",
  82    ".clinerules",
  83    ".github/copilot-instructions.md",
  84    "CLAUDE.md",
  85    "AGENT.md",
  86    "AGENTS.md",
  87    "GEMINI.md",
  88];
  89
  90pub fn init(cx: &mut App) {
  91    ThreadsDatabase::init(cx);
  92}
  93
  94/// A system prompt shared by all threads created by this ThreadStore
  95#[derive(Clone, Default)]
  96pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
  97
  98impl SharedProjectContext {
  99    pub fn borrow(&self) -> Ref<'_, Option<ProjectContext>> {
 100        self.0.borrow()
 101    }
 102}
 103
 104pub type TextThreadStore = assistant_context::ContextStore;
 105
 106pub struct ThreadStore {
 107    project: Entity<Project>,
 108    cloud_user_store: Entity<CloudUserStore>,
 109    tools: Entity<ToolWorkingSet>,
 110    prompt_builder: Arc<PromptBuilder>,
 111    prompt_store: Option<Entity<PromptStore>>,
 112    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
 113    threads: Vec<SerializedThreadMetadata>,
 114    project_context: SharedProjectContext,
 115    reload_system_prompt_tx: mpsc::Sender<()>,
 116    _reload_system_prompt_task: Task<()>,
 117    _subscriptions: Vec<Subscription>,
 118}
 119
 120pub struct RulesLoadingError {
 121    pub message: SharedString,
 122}
 123
 124impl EventEmitter<RulesLoadingError> for ThreadStore {}
 125
 126impl ThreadStore {
 127    pub fn load(
 128        project: Entity<Project>,
 129        cloud_user_store: Entity<CloudUserStore>,
 130        tools: Entity<ToolWorkingSet>,
 131        prompt_store: Option<Entity<PromptStore>>,
 132        prompt_builder: Arc<PromptBuilder>,
 133        cx: &mut App,
 134    ) -> Task<Result<Entity<Self>>> {
 135        cx.spawn(async move |cx| {
 136            let (thread_store, ready_rx) = cx.update(|cx| {
 137                let mut option_ready_rx = None;
 138                let thread_store = cx.new(|cx| {
 139                    let (thread_store, ready_rx) = Self::new(
 140                        project,
 141                        cloud_user_store,
 142                        tools,
 143                        prompt_builder,
 144                        prompt_store,
 145                        cx,
 146                    );
 147                    option_ready_rx = Some(ready_rx);
 148                    thread_store
 149                });
 150                (thread_store, option_ready_rx.take().unwrap())
 151            })?;
 152            ready_rx.await?;
 153            Ok(thread_store)
 154        })
 155    }
 156
 157    fn new(
 158        project: Entity<Project>,
 159        cloud_user_store: Entity<CloudUserStore>,
 160        tools: Entity<ToolWorkingSet>,
 161        prompt_builder: Arc<PromptBuilder>,
 162        prompt_store: Option<Entity<PromptStore>>,
 163        cx: &mut Context<Self>,
 164    ) -> (Self, oneshot::Receiver<()>) {
 165        let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
 166
 167        if let Some(prompt_store) = prompt_store.as_ref() {
 168            subscriptions.push(cx.subscribe(
 169                prompt_store,
 170                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
 171                    this.enqueue_system_prompt_reload();
 172                },
 173            ))
 174        }
 175
 176        // This channel and task prevent concurrent and redundant loading of the system prompt.
 177        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
 178        let (ready_tx, ready_rx) = oneshot::channel();
 179        let mut ready_tx = Some(ready_tx);
 180        let reload_system_prompt_task = cx.spawn({
 181            let prompt_store = prompt_store.clone();
 182            async move |thread_store, cx| {
 183                loop {
 184                    let Some(reload_task) = thread_store
 185                        .update(cx, |thread_store, cx| {
 186                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
 187                        })
 188                        .ok()
 189                    else {
 190                        return;
 191                    };
 192                    reload_task.await;
 193                    if let Some(ready_tx) = ready_tx.take() {
 194                        ready_tx.send(()).ok();
 195                    }
 196                    reload_system_prompt_rx.next().await;
 197                }
 198            }
 199        });
 200
 201        let this = Self {
 202            project,
 203            cloud_user_store,
 204            tools,
 205            prompt_builder,
 206            prompt_store,
 207            context_server_tool_ids: HashMap::default(),
 208            threads: Vec::new(),
 209            project_context: SharedProjectContext::default(),
 210            reload_system_prompt_tx,
 211            _reload_system_prompt_task: reload_system_prompt_task,
 212            _subscriptions: subscriptions,
 213        };
 214        this.register_context_server_handlers(cx);
 215        this.reload(cx).detach_and_log_err(cx);
 216        (this, ready_rx)
 217    }
 218
 219    fn handle_project_event(
 220        &mut self,
 221        _project: Entity<Project>,
 222        event: &project::Event,
 223        _cx: &mut Context<Self>,
 224    ) {
 225        match event {
 226            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 227                self.enqueue_system_prompt_reload();
 228            }
 229            project::Event::WorktreeUpdatedEntries(_, items) => {
 230                if items.iter().any(|(path, _, _)| {
 231                    RULES_FILE_NAMES
 232                        .iter()
 233                        .any(|name| path.as_ref() == Path::new(name))
 234                }) {
 235                    self.enqueue_system_prompt_reload();
 236                }
 237            }
 238            _ => {}
 239        }
 240    }
 241
 242    fn enqueue_system_prompt_reload(&mut self) {
 243        self.reload_system_prompt_tx.try_send(()).ok();
 244    }
 245
 246    // Note that this should only be called from `reload_system_prompt_task`.
 247    fn reload_system_prompt(
 248        &self,
 249        prompt_store: Option<Entity<PromptStore>>,
 250        cx: &mut Context<Self>,
 251    ) -> Task<()> {
 252        let worktrees = self
 253            .project
 254            .read(cx)
 255            .visible_worktrees(cx)
 256            .collect::<Vec<_>>();
 257        let worktree_tasks = worktrees
 258            .into_iter()
 259            .map(|worktree| {
 260                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
 261            })
 262            .collect::<Vec<_>>();
 263        let default_user_rules_task = match prompt_store {
 264            None => Task::ready(vec![]),
 265            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
 266                let prompts = prompt_store.default_prompt_metadata();
 267                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 268                    let contents = prompt_store.load(prompt_metadata.id, cx);
 269                    async move { (contents.await, prompt_metadata) }
 270                });
 271                cx.background_spawn(future::join_all(load_tasks))
 272            }),
 273        };
 274
 275        cx.spawn(async move |this, cx| {
 276            let (worktrees, default_user_rules) =
 277                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 278
 279            let worktrees = worktrees
 280                .into_iter()
 281                .map(|(worktree, rules_error)| {
 282                    if let Some(rules_error) = rules_error {
 283                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 284                    }
 285                    worktree
 286                })
 287                .collect::<Vec<_>>();
 288
 289            let default_user_rules = default_user_rules
 290                .into_iter()
 291                .flat_map(|(contents, prompt_metadata)| match contents {
 292                    Ok(contents) => Some(UserRulesContext {
 293                        uuid: match prompt_metadata.id {
 294                            PromptId::User { uuid } => uuid,
 295                            PromptId::EditWorkflow => return None,
 296                        },
 297                        title: prompt_metadata.title.map(|title| title.to_string()),
 298                        contents,
 299                    }),
 300                    Err(err) => {
 301                        this.update(cx, |_, cx| {
 302                            cx.emit(RulesLoadingError {
 303                                message: format!("{err:?}").into(),
 304                            });
 305                        })
 306                        .ok();
 307                        None
 308                    }
 309                })
 310                .collect::<Vec<_>>();
 311
 312            this.update(cx, |this, _cx| {
 313                *this.project_context.0.borrow_mut() =
 314                    Some(ProjectContext::new(worktrees, default_user_rules));
 315            })
 316            .ok();
 317        })
 318    }
 319
 320    fn load_worktree_info_for_system_prompt(
 321        worktree: Entity<Worktree>,
 322        project: Entity<Project>,
 323        cx: &mut App,
 324    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 325        let tree = worktree.read(cx);
 326        let root_name = tree.root_name().into();
 327        let abs_path = tree.abs_path();
 328
 329        let mut context = WorktreeContext {
 330            root_name,
 331            abs_path,
 332            rules_file: None,
 333        };
 334
 335        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 336        let Some(rules_task) = rules_task else {
 337            return Task::ready((context, None));
 338        };
 339
 340        cx.spawn(async move |_| {
 341            let (rules_file, rules_file_error) = match rules_task.await {
 342                Ok(rules_file) => (Some(rules_file), None),
 343                Err(err) => (
 344                    None,
 345                    Some(RulesLoadingError {
 346                        message: format!("{err}").into(),
 347                    }),
 348                ),
 349            };
 350            context.rules_file = rules_file;
 351            (context, rules_file_error)
 352        })
 353    }
 354
 355    fn load_worktree_rules_file(
 356        worktree: Entity<Worktree>,
 357        project: Entity<Project>,
 358        cx: &mut App,
 359    ) -> Option<Task<Result<RulesFileContext>>> {
 360        let worktree = worktree.read(cx);
 361        let worktree_id = worktree.id();
 362        let selected_rules_file = RULES_FILE_NAMES
 363            .into_iter()
 364            .filter_map(|name| {
 365                worktree
 366                    .entry_for_path(name)
 367                    .filter(|entry| entry.is_file())
 368                    .map(|entry| entry.path.clone())
 369            })
 370            .next();
 371
 372        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 373        // supported. This doesn't seem to occur often in GitHub repositories.
 374        selected_rules_file.map(|path_in_worktree| {
 375            let project_path = ProjectPath {
 376                worktree_id,
 377                path: path_in_worktree.clone(),
 378            };
 379            let buffer_task =
 380                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 381            let rope_task = cx.spawn(async move |cx| {
 382                buffer_task.await?.read_with(cx, |buffer, cx| {
 383                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 384                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 385                })?
 386            });
 387            // Build a string from the rope on a background thread.
 388            cx.background_spawn(async move {
 389                let (project_entry_id, rope) = rope_task.await?;
 390                anyhow::Ok(RulesFileContext {
 391                    path_in_worktree,
 392                    text: rope.to_string().trim().to_string(),
 393                    project_entry_id: project_entry_id.to_usize(),
 394                })
 395            })
 396        })
 397    }
 398
 399    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
 400        &self.prompt_store
 401    }
 402
 403    pub fn tools(&self) -> Entity<ToolWorkingSet> {
 404        self.tools.clone()
 405    }
 406
 407    /// Returns the number of threads.
 408    pub fn thread_count(&self) -> usize {
 409        self.threads.len()
 410    }
 411
 412    pub fn reverse_chronological_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
 413        // ordering is from "ORDER BY" in `list_threads`
 414        self.threads.iter()
 415    }
 416
 417    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 418        cx.new(|cx| {
 419            Thread::new(
 420                self.project.clone(),
 421                self.cloud_user_store.clone(),
 422                self.tools.clone(),
 423                self.prompt_builder.clone(),
 424                self.project_context.clone(),
 425                cx,
 426            )
 427        })
 428    }
 429
 430    pub fn create_thread_from_serialized(
 431        &mut self,
 432        serialized: SerializedThread,
 433        cx: &mut Context<Self>,
 434    ) -> Entity<Thread> {
 435        cx.new(|cx| {
 436            Thread::deserialize(
 437                ThreadId::new(),
 438                serialized,
 439                self.project.clone(),
 440                self.cloud_user_store.clone(),
 441                self.tools.clone(),
 442                self.prompt_builder.clone(),
 443                self.project_context.clone(),
 444                None,
 445                cx,
 446            )
 447        })
 448    }
 449
 450    pub fn open_thread(
 451        &self,
 452        id: &ThreadId,
 453        window: &mut Window,
 454        cx: &mut Context<Self>,
 455    ) -> Task<Result<Entity<Thread>>> {
 456        let id = id.clone();
 457        let database_future = ThreadsDatabase::global_future(cx);
 458        let this = cx.weak_entity();
 459        window.spawn(cx, async move |cx| {
 460            let database = database_future.await.map_err(|err| anyhow!(err))?;
 461            let thread = database
 462                .try_find_thread(id.clone())
 463                .await?
 464                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 465
 466            let thread = this.update_in(cx, |this, window, cx| {
 467                cx.new(|cx| {
 468                    Thread::deserialize(
 469                        id.clone(),
 470                        thread,
 471                        this.project.clone(),
 472                        this.cloud_user_store.clone(),
 473                        this.tools.clone(),
 474                        this.prompt_builder.clone(),
 475                        this.project_context.clone(),
 476                        Some(window),
 477                        cx,
 478                    )
 479                })
 480            })?;
 481
 482            Ok(thread)
 483        })
 484    }
 485
 486    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
 487        let (metadata, serialized_thread) =
 488            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
 489
 490        let database_future = ThreadsDatabase::global_future(cx);
 491        cx.spawn(async move |this, cx| {
 492            let serialized_thread = serialized_thread.await?;
 493            let database = database_future.await.map_err(|err| anyhow!(err))?;
 494            database.save_thread(metadata, serialized_thread).await?;
 495
 496            this.update(cx, |this, cx| this.reload(cx))?.await
 497        })
 498    }
 499
 500    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
 501        let id = id.clone();
 502        let database_future = ThreadsDatabase::global_future(cx);
 503        cx.spawn(async move |this, cx| {
 504            let database = database_future.await.map_err(|err| anyhow!(err))?;
 505            database.delete_thread(id.clone()).await?;
 506
 507            this.update(cx, |this, cx| {
 508                this.threads.retain(|thread| thread.id != id);
 509                cx.notify();
 510            })
 511        })
 512    }
 513
 514    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 515        let database_future = ThreadsDatabase::global_future(cx);
 516        cx.spawn(async move |this, cx| {
 517            let threads = database_future
 518                .await
 519                .map_err(|err| anyhow!(err))?
 520                .list_threads()
 521                .await?;
 522
 523            this.update(cx, |this, cx| {
 524                this.threads = threads;
 525                cx.notify();
 526            })
 527        })
 528    }
 529
 530    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
 531        let context_server_store = self.project.read(cx).context_server_store();
 532        cx.subscribe(&context_server_store, Self::handle_context_server_event)
 533            .detach();
 534
 535        // Check for any servers that were already running before the handler was registered
 536        for server in context_server_store.read(cx).running_servers() {
 537            self.load_context_server_tools(server.id(), context_server_store.clone(), cx);
 538        }
 539    }
 540
 541    fn handle_context_server_event(
 542        &mut self,
 543        context_server_store: Entity<ContextServerStore>,
 544        event: &project::context_server_store::Event,
 545        cx: &mut Context<Self>,
 546    ) {
 547        let tool_working_set = self.tools.clone();
 548        match event {
 549            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
 550                match status {
 551                    ContextServerStatus::Starting => {}
 552                    ContextServerStatus::Running => {
 553                        self.load_context_server_tools(server_id.clone(), context_server_store, cx);
 554                    }
 555                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
 556                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
 557                            tool_working_set.update(cx, |tool_working_set, cx| {
 558                                tool_working_set.remove(&tool_ids, cx);
 559                            });
 560                        }
 561                    }
 562                }
 563            }
 564        }
 565    }
 566
 567    fn load_context_server_tools(
 568        &self,
 569        server_id: ContextServerId,
 570        context_server_store: Entity<ContextServerStore>,
 571        cx: &mut Context<Self>,
 572    ) {
 573        let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
 574            return;
 575        };
 576        let tool_working_set = self.tools.clone();
 577        cx.spawn(async move |this, cx| {
 578            let Some(protocol) = server.client() else {
 579                return;
 580            };
 581
 582            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
 583                if let Some(response) = protocol
 584                    .request::<context_server::types::requests::ListTools>(())
 585                    .await
 586                    .log_err()
 587                {
 588                    let tool_ids = tool_working_set
 589                        .update(cx, |tool_working_set, cx| {
 590                            tool_working_set.extend(
 591                                response.tools.into_iter().map(|tool| {
 592                                    Arc::new(ContextServerTool::new(
 593                                        context_server_store.clone(),
 594                                        server.id(),
 595                                        tool,
 596                                    )) as Arc<dyn Tool>
 597                                }),
 598                                cx,
 599                            )
 600                        })
 601                        .log_err();
 602
 603                    if let Some(tool_ids) = tool_ids {
 604                        this.update(cx, |this, _| {
 605                            this.context_server_tool_ids.insert(server_id, tool_ids);
 606                        })
 607                        .log_err();
 608                    }
 609                }
 610            }
 611        })
 612        .detach();
 613    }
 614}
 615
 616#[derive(Debug, Clone, Serialize, Deserialize)]
 617pub struct SerializedThreadMetadata {
 618    pub id: ThreadId,
 619    pub summary: SharedString,
 620    pub updated_at: DateTime<Utc>,
 621}
 622
 623#[derive(Serialize, Deserialize, Debug, PartialEq)]
 624pub struct SerializedThread {
 625    pub version: String,
 626    pub summary: SharedString,
 627    pub updated_at: DateTime<Utc>,
 628    pub messages: Vec<SerializedMessage>,
 629    #[serde(default)]
 630    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
 631    #[serde(default)]
 632    pub cumulative_token_usage: TokenUsage,
 633    #[serde(default)]
 634    pub request_token_usage: Vec<TokenUsage>,
 635    #[serde(default)]
 636    pub detailed_summary_state: DetailedSummaryState,
 637    #[serde(default)]
 638    pub exceeded_window_error: Option<ExceededWindowError>,
 639    #[serde(default)]
 640    pub model: Option<SerializedLanguageModel>,
 641    #[serde(default)]
 642    pub completion_mode: Option<CompletionMode>,
 643    #[serde(default)]
 644    pub tool_use_limit_reached: bool,
 645    #[serde(default)]
 646    pub profile: Option<AgentProfileId>,
 647}
 648
 649#[derive(Serialize, Deserialize, Debug, PartialEq)]
 650pub struct SerializedLanguageModel {
 651    pub provider: String,
 652    pub model: String,
 653}
 654
 655impl SerializedThread {
 656    pub const VERSION: &'static str = "0.2.0";
 657
 658    pub fn from_json(json: &[u8]) -> Result<Self> {
 659        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 660        match saved_thread_json.get("version") {
 661            Some(serde_json::Value::String(version)) => match version.as_str() {
 662                SerializedThreadV0_1_0::VERSION => {
 663                    let saved_thread =
 664                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
 665                    Ok(saved_thread.upgrade())
 666                }
 667                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
 668                    saved_thread_json,
 669                )?),
 670                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 671            },
 672            None => {
 673                let saved_thread =
 674                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
 675                Ok(saved_thread.upgrade())
 676            }
 677            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 678        }
 679    }
 680}
 681
 682#[derive(Serialize, Deserialize, Debug)]
 683pub struct SerializedThreadV0_1_0(
 684    // The structure did not change, so we are reusing the latest SerializedThread.
 685    // When making the next version, make sure this points to SerializedThreadV0_2_0
 686    SerializedThread,
 687);
 688
 689impl SerializedThreadV0_1_0 {
 690    pub const VERSION: &'static str = "0.1.0";
 691
 692    pub fn upgrade(self) -> SerializedThread {
 693        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
 694
 695        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
 696
 697        for message in self.0.messages {
 698            if message.role == Role::User && !message.tool_results.is_empty() {
 699                if let Some(last_message) = messages.last_mut() {
 700                    debug_assert!(last_message.role == Role::Assistant);
 701
 702                    last_message.tool_results = message.tool_results;
 703                    continue;
 704                }
 705            }
 706
 707            messages.push(message);
 708        }
 709
 710        SerializedThread {
 711            messages,
 712            version: SerializedThread::VERSION.to_string(),
 713            ..self.0
 714        }
 715    }
 716}
 717
 718#[derive(Debug, Serialize, Deserialize, PartialEq)]
 719pub struct SerializedMessage {
 720    pub id: MessageId,
 721    pub role: Role,
 722    #[serde(default)]
 723    pub segments: Vec<SerializedMessageSegment>,
 724    #[serde(default)]
 725    pub tool_uses: Vec<SerializedToolUse>,
 726    #[serde(default)]
 727    pub tool_results: Vec<SerializedToolResult>,
 728    #[serde(default)]
 729    pub context: String,
 730    #[serde(default)]
 731    pub creases: Vec<SerializedCrease>,
 732    #[serde(default)]
 733    pub is_hidden: bool,
 734}
 735
 736#[derive(Debug, Serialize, Deserialize, PartialEq)]
 737#[serde(tag = "type")]
 738pub enum SerializedMessageSegment {
 739    #[serde(rename = "text")]
 740    Text {
 741        text: String,
 742    },
 743    #[serde(rename = "thinking")]
 744    Thinking {
 745        text: String,
 746        #[serde(skip_serializing_if = "Option::is_none")]
 747        signature: Option<String>,
 748    },
 749    RedactedThinking {
 750        data: String,
 751    },
 752}
 753
 754#[derive(Debug, Serialize, Deserialize, PartialEq)]
 755pub struct SerializedToolUse {
 756    pub id: LanguageModelToolUseId,
 757    pub name: SharedString,
 758    pub input: serde_json::Value,
 759}
 760
 761#[derive(Debug, Serialize, Deserialize, PartialEq)]
 762pub struct SerializedToolResult {
 763    pub tool_use_id: LanguageModelToolUseId,
 764    pub is_error: bool,
 765    pub content: LanguageModelToolResultContent,
 766    pub output: Option<serde_json::Value>,
 767}
 768
 769#[derive(Serialize, Deserialize)]
 770struct LegacySerializedThread {
 771    pub summary: SharedString,
 772    pub updated_at: DateTime<Utc>,
 773    pub messages: Vec<LegacySerializedMessage>,
 774    #[serde(default)]
 775    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
 776}
 777
 778impl LegacySerializedThread {
 779    pub fn upgrade(self) -> SerializedThread {
 780        SerializedThread {
 781            version: SerializedThread::VERSION.to_string(),
 782            summary: self.summary,
 783            updated_at: self.updated_at,
 784            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
 785            initial_project_snapshot: self.initial_project_snapshot,
 786            cumulative_token_usage: TokenUsage::default(),
 787            request_token_usage: Vec::new(),
 788            detailed_summary_state: DetailedSummaryState::default(),
 789            exceeded_window_error: None,
 790            model: None,
 791            completion_mode: None,
 792            tool_use_limit_reached: false,
 793            profile: None,
 794        }
 795    }
 796}
 797
 798#[derive(Debug, Serialize, Deserialize)]
 799struct LegacySerializedMessage {
 800    pub id: MessageId,
 801    pub role: Role,
 802    pub text: String,
 803    #[serde(default)]
 804    pub tool_uses: Vec<SerializedToolUse>,
 805    #[serde(default)]
 806    pub tool_results: Vec<SerializedToolResult>,
 807}
 808
 809impl LegacySerializedMessage {
 810    fn upgrade(self) -> SerializedMessage {
 811        SerializedMessage {
 812            id: self.id,
 813            role: self.role,
 814            segments: vec![SerializedMessageSegment::Text { text: self.text }],
 815            tool_uses: self.tool_uses,
 816            tool_results: self.tool_results,
 817            context: String::new(),
 818            creases: Vec::new(),
 819            is_hidden: false,
 820        }
 821    }
 822}
 823
 824#[derive(Debug, Serialize, Deserialize, PartialEq)]
 825pub struct SerializedCrease {
 826    pub start: usize,
 827    pub end: usize,
 828    pub icon_path: SharedString,
 829    pub label: SharedString,
 830}
 831
 832struct GlobalThreadsDatabase(
 833    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
 834);
 835
 836impl Global for GlobalThreadsDatabase {}
 837
/// SQLite-backed storage for serialized threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which the database's queries are spawned.
    executor: BackgroundExecutor,
    /// The single SQLite connection, shared with spawned tasks and locked for each operation.
    connection: Arc<Mutex<Connection>>,
}
 842
 843impl ThreadsDatabase {
 844    fn connection(&self) -> Arc<Mutex<Connection>> {
 845        self.connection.clone()
 846    }
 847
 848    const COMPRESSION_LEVEL: i32 = 3;
 849}
 850
 851impl Bind for ThreadId {
 852    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
 853        self.to_string().bind(statement, start_index)
 854    }
 855}
 856
 857impl Column for ThreadId {
 858    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
 859        let (id_str, next_index) = String::column(statement, start_index)?;
 860        Ok((ThreadId::from(id_str.as_str()), next_index))
 861    }
 862}
 863
 864impl ThreadsDatabase {
 865    fn global_future(
 866        cx: &mut App,
 867    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
 868        GlobalThreadsDatabase::global(cx).0.clone()
 869    }
 870
 871    fn init(cx: &mut App) {
 872        let executor = cx.background_executor().clone();
 873        let database_future = executor
 874            .spawn({
 875                let executor = executor.clone();
 876                let threads_dir = paths::data_dir().join("threads");
 877                async move { ThreadsDatabase::new(threads_dir, executor) }
 878            })
 879            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
 880            .boxed()
 881            .shared();
 882
 883        cx.set_global(GlobalThreadsDatabase(database_future));
 884    }
 885
 886    pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
 887        std::fs::create_dir_all(&threads_dir)?;
 888
 889        let sqlite_path = threads_dir.join("threads.db");
 890        let mdb_path = threads_dir.join("threads-db.1.mdb");
 891
 892        let needs_migration_from_heed = mdb_path.exists();
 893
 894        let connection = if *ZED_STATELESS {
 895            Connection::open_memory(Some("THREAD_FALLBACK_DB"))
 896        } else {
 897            Connection::open_file(&sqlite_path.to_string_lossy())
 898        };
 899
 900        connection.exec(indoc! {"
 901                CREATE TABLE IF NOT EXISTS threads (
 902                    id TEXT PRIMARY KEY,
 903                    summary TEXT NOT NULL,
 904                    updated_at TEXT NOT NULL,
 905                    data_type TEXT NOT NULL,
 906                    data BLOB NOT NULL
 907                )
 908            "})?()
 909        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
 910
 911        let db = Self {
 912            executor: executor.clone(),
 913            connection: Arc::new(Mutex::new(connection)),
 914        };
 915
 916        if needs_migration_from_heed {
 917            let db_connection = db.connection();
 918            let executor_clone = executor.clone();
 919            executor
 920                .spawn(async move {
 921                    log::info!("Starting threads.db migration");
 922                    Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
 923                    std::fs::remove_dir_all(mdb_path)?;
 924                    log::info!("threads.db migrated to sqlite");
 925                    Ok::<(), anyhow::Error>(())
 926                })
 927                .detach();
 928        }
 929
 930        Ok(db)
 931    }
 932
 933    // Remove this migration after 2025-09-01
 934    fn migrate_from_heed(
 935        mdb_path: &Path,
 936        connection: Arc<Mutex<Connection>>,
 937        _executor: BackgroundExecutor,
 938    ) -> Result<()> {
 939        use heed::types::SerdeBincode;
 940        struct SerializedThreadHeed(SerializedThread);
 941
 942        impl heed::BytesEncode<'_> for SerializedThreadHeed {
 943            type EItem = SerializedThreadHeed;
 944
 945            fn bytes_encode(
 946                item: &Self::EItem,
 947            ) -> Result<std::borrow::Cow<'_, [u8]>, heed::BoxedError> {
 948                serde_json::to_vec(&item.0)
 949                    .map(std::borrow::Cow::Owned)
 950                    .map_err(Into::into)
 951            }
 952        }
 953
 954        impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
 955            type DItem = SerializedThreadHeed;
 956
 957            fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
 958                SerializedThread::from_json(bytes)
 959                    .map(SerializedThreadHeed)
 960                    .map_err(Into::into)
 961            }
 962        }
 963
 964        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
 965
 966        let env = unsafe {
 967            heed::EnvOpenOptions::new()
 968                .map_size(ONE_GB_IN_BYTES)
 969                .max_dbs(1)
 970                .open(mdb_path)?
 971        };
 972
 973        let txn = env.write_txn()?;
 974        let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
 975            .open_database(&txn, Some("threads"))?
 976            .ok_or_else(|| anyhow!("threads database not found"))?;
 977
 978        for result in threads.iter(&txn)? {
 979            let (thread_id, thread_heed) = result?;
 980            Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
 981        }
 982
 983        Ok(())
 984    }
 985
 986    fn save_thread_sync(
 987        connection: &Arc<Mutex<Connection>>,
 988        id: ThreadId,
 989        thread: SerializedThread,
 990    ) -> Result<()> {
 991        let json_data = serde_json::to_string(&thread)?;
 992        let summary = thread.summary.to_string();
 993        let updated_at = thread.updated_at.to_rfc3339();
 994
 995        let connection = connection.lock().unwrap();
 996
 997        let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
 998        let data_type = DataType::Zstd;
 999        let data = compressed;
1000
1001        let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
1002            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
1003        "})?;
1004
1005        insert((id, summary, updated_at, data_type, data))?;
1006
1007        Ok(())
1008    }
1009
1010    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
1011        let connection = self.connection.clone();
1012
1013        self.executor.spawn(async move {
1014            let connection = connection.lock().unwrap();
1015            let mut select =
1016                connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
1017                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
1018            "})?;
1019
1020            let rows = select(())?;
1021            let mut threads = Vec::new();
1022
1023            for (id, summary, updated_at) in rows {
1024                threads.push(SerializedThreadMetadata {
1025                    id,
1026                    summary: summary.into(),
1027                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
1028                });
1029            }
1030
1031            Ok(threads)
1032        })
1033    }
1034
1035    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
1036        let connection = self.connection.clone();
1037
1038        self.executor.spawn(async move {
1039            let connection = connection.lock().unwrap();
1040            let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
1041                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
1042            "})?;
1043
1044            let rows = select(id)?;
1045            if let Some((data_type, data)) = rows.into_iter().next() {
1046                let json_data = match data_type {
1047                    DataType::Zstd => {
1048                        let decompressed = zstd::decode_all(&data[..])?;
1049                        String::from_utf8(decompressed)?
1050                    }
1051                    DataType::Json => String::from_utf8(data)?,
1052                };
1053
1054                let thread = SerializedThread::from_json(json_data.as_bytes())?;
1055                Ok(Some(thread))
1056            } else {
1057                Ok(None)
1058            }
1059        })
1060    }
1061
1062    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
1063        let connection = self.connection.clone();
1064
1065        self.executor
1066            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
1067    }
1068
1069    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
1070        let connection = self.connection.clone();
1071
1072        self.executor.spawn(async move {
1073            let connection = connection.lock().unwrap();
1074
1075            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
1076                DELETE FROM threads WHERE id = ?
1077            "})?;
1078
1079            delete(id)?;
1080
1081            Ok(())
1082        })
1083    }
1084}
1085
#[cfg(test)]
mod tests {
    use super::*;
    use crate::thread::{DetailedSummaryState, MessageId};
    use chrono::Utc;
    use language_model::{Role, TokenUsage};
    use pretty_assertions::assert_eq;

    /// A legacy thread upgrades to the current version: each plain-text message becomes a
    /// single `Text` segment, and fields the legacy format lacks are filled with defaults.
    #[test]
    fn test_legacy_serialized_thread_upgrade() {
        let updated_at = Utc::now();
        let legacy_thread = LegacySerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![LegacySerializedMessage {
                id: MessageId(1),
                role: Role::User,
                text: "Hello, world!".to_string(),
                tool_uses: vec![],
                tool_results: vec![],
            }],
            initial_project_snapshot: None,
        };

        let upgraded = legacy_thread.upgrade();

        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Hello, world!".to_string()
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false
                }],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                exceeded_window_error: None,
                model: None,
                completion_mode: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }

    /// Upgrading a v0.1.0 thread moves the tool results carried by a trailing user message
    /// onto the preceding assistant message. The expected output has two messages, not three,
    /// and the version is bumped to the current one.
    #[test]
    fn test_serialized_threadv0_1_0_upgrade() {
        let updated_at = Utc::now();
        let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread {
            summary: "Test conversation".into(),
            updated_at,
            messages: vec![
                SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Use tool_1".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                SerializedMessage {
                    id: MessageId(2),
                    role: Role::Assistant,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "I want to use a tool".to_string(),
                    }],
                    tool_uses: vec![SerializedToolUse {
                        id: "abc".into(),
                        name: "tool_1".into(),
                        input: serde_json::Value::Null,
                    }],
                    tool_results: vec![],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
                // NOTE(review): this reuses MessageId(1); MessageId(3) was presumably intended.
                // The expected output drops this message, so the assertion is unaffected as
                // long as the upgrade does not key on message ids — confirm.
                SerializedMessage {
                    id: MessageId(1),
                    role: Role::User,
                    segments: vec![SerializedMessageSegment::Text {
                        text: "Here is the tool result".to_string(),
                    }],
                    tool_uses: vec![],
                    tool_results: vec![SerializedToolResult {
                        tool_use_id: "abc".into(),
                        is_error: false,
                        content: LanguageModelToolResultContent::Text("abcdef".into()),
                        output: Some(serde_json::Value::Null),
                    }],
                    context: "".to_string(),
                    creases: vec![],
                    is_hidden: false,
                },
            ],
            version: SerializedThreadV0_1_0::VERSION.to_string(),
            initial_project_snapshot: None,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: vec![],
            detailed_summary_state: DetailedSummaryState::default(),
            exceeded_window_error: None,
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
            profile: None,
        });
        let upgraded = thread_v0_1_0.upgrade();

        // The assistant message now carries the tool result; the standalone user message
        // that held it is gone.
        assert_eq!(
            upgraded,
            SerializedThread {
                summary: "Test conversation".into(),
                updated_at,
                messages: vec![
                    SerializedMessage {
                        id: MessageId(1),
                        role: Role::User,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "Use tool_1".to_string()
                        }],
                        tool_uses: vec![],
                        tool_results: vec![],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false
                    },
                    SerializedMessage {
                        id: MessageId(2),
                        role: Role::Assistant,
                        segments: vec![SerializedMessageSegment::Text {
                            text: "I want to use a tool".to_string(),
                        }],
                        tool_uses: vec![SerializedToolUse {
                            id: "abc".into(),
                            name: "tool_1".into(),
                            input: serde_json::Value::Null,
                        }],
                        tool_results: vec![SerializedToolResult {
                            tool_use_id: "abc".into(),
                            is_error: false,
                            content: LanguageModelToolResultContent::Text("abcdef".into()),
                            output: Some(serde_json::Value::Null),
                        }],
                        context: "".to_string(),
                        creases: vec![],
                        is_hidden: false,
                    },
                ],
                version: SerializedThread::VERSION.to_string(),
                initial_project_snapshot: None,
                cumulative_token_usage: TokenUsage::default(),
                request_token_usage: vec![],
                detailed_summary_state: DetailedSummaryState::default(),
                exceeded_window_error: None,
                model: None,
                completion_mode: None,
                tool_use_limit_reached: false,
                profile: None
            }
        )
    }
}