thread_store.rs

   1use std::cell::{Ref, RefCell};
   2use std::path::{Path, PathBuf};
   3use std::rc::Rc;
   4use std::sync::{Arc, Mutex};
   5
   6use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, CompletionMode};
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
   9use chrono::{DateTime, Utc};
  10use collections::HashMap;
  11use context_server::ContextServerId;
  12use futures::channel::{mpsc, oneshot};
  13use futures::future::{self, BoxFuture, Shared};
  14use futures::{FutureExt as _, StreamExt as _};
  15use gpui::{
  16    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
  17    Subscription, Task, prelude::*,
  18};
  19
  20use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
  21use project::context_server_store::{ContextServerStatus, ContextServerStore};
  22use project::{Project, ProjectItem, ProjectPath, Worktree};
  23use prompt_store::{
  24    ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
  25    UserRulesContext, WorktreeContext,
  26};
  27use serde::{Deserialize, Serialize};
  28use settings::{Settings as _, SettingsStore};
  29use ui::Window;
  30use util::ResultExt as _;
  31
  32use crate::context_server_tool::ContextServerTool;
  33use crate::thread::{
  34    DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
  35};
  36use indoc::indoc;
  37use sqlez::{
  38    bindable::{Bind, Column},
  39    connection::Connection,
  40    statement::Statement,
  41};
  42
  43#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  44pub enum DataType {
  45    #[serde(rename = "json")]
  46    Json,
  47    #[serde(rename = "zstd")]
  48    Zstd,
  49}
  50
  51impl Bind for DataType {
  52    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
  53        let value = match self {
  54            DataType::Json => "json",
  55            DataType::Zstd => "zstd",
  56        };
  57        value.bind(statement, start_index)
  58    }
  59}
  60
  61impl Column for DataType {
  62    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
  63        let (value, next_index) = String::column(statement, start_index)?;
  64        let data_type = match value.as_str() {
  65            "json" => DataType::Json,
  66            "zstd" => DataType::Zstd,
  67            _ => anyhow::bail!("Unknown data type: {}", value),
  68        };
  69        Ok((data_type, next_index))
  70    }
  71}
  72
  73const RULES_FILE_NAMES: [&'static str; 8] = [
  74    ".rules",
  75    ".cursorrules",
  76    ".windsurfrules",
  77    ".clinerules",
  78    ".github/copilot-instructions.md",
  79    "CLAUDE.md",
  80    "AGENT.md",
  81    "AGENTS.md",
  82];
  83
  84pub fn init(cx: &mut App) {
  85    ThreadsDatabase::init(cx);
  86}
  87
/// Project context (worktree info plus default user rules) used to build the system prompt.
/// It is shared by all threads created by a `ThreadStore`.
///
/// Clones share one `Rc<RefCell<..>>`, so an update made through any handle is visible
/// through all of them. The value is `None` (the `Default`) until the store's first
/// system-prompt load has finished.
#[derive(Clone, Default)]
pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
  91
  92impl SharedProjectContext {
  93    pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
  94        self.0.borrow()
  95    }
  96}
  97
/// Store for text threads; an alias for `assistant_context_editor::ContextStore`.
pub type TextThreadStore = assistant_context_editor::ContextStore;
  99
/// Tracks saved-thread metadata and the shared project context for agent threads.
/// It also keeps the tool working set in sync with the active agent profile.
// NOTE: field order is drop order; the task and subscriptions are deliberately last.
pub struct ThreadStore {
    project: Entity<Project>,
    /// Tool registry whose enabled set is driven by the active agent profile.
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    /// Source of default user rules for the system prompt, if any.
    prompt_store: Option<Entity<PromptStore>>,
    /// Tools registered for each running context server, so they can be removed when the
    /// server stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Thread metadata as last returned by the threads database (unsorted).
    threads: Vec<SerializedThreadMetadata>,
    /// Project context shared with every thread this store creates.
    project_context: SharedProjectContext,
    /// Wakes the system-prompt reload loop; written only by `enqueue_system_prompt_reload`.
    reload_system_prompt_tx: mpsc::Sender<()>,
    /// Held to keep the reload loop alive for the store's lifetime (presumably dropping a
    /// gpui `Task` cancels it — confirm).
    _reload_system_prompt_task: Task<()>,
    /// Held to keep the settings/project/prompt-store subscriptions active.
    _subscriptions: Vec<Subscription>,
}
 112
/// Event emitted by `ThreadStore` when a worktree rules file or a default user rule fails
/// to load while rebuilding the system prompt.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}
 116
// Lets subscribers observe rules-loading failures reported while reloading the system prompt.
impl EventEmitter<RulesLoadingError> for ThreadStore {}
 118
 119impl ThreadStore {
 120    pub fn load(
 121        project: Entity<Project>,
 122        tools: Entity<ToolWorkingSet>,
 123        prompt_store: Option<Entity<PromptStore>>,
 124        prompt_builder: Arc<PromptBuilder>,
 125        cx: &mut App,
 126    ) -> Task<Result<Entity<Self>>> {
 127        cx.spawn(async move |cx| {
 128            let (thread_store, ready_rx) = cx.update(|cx| {
 129                let mut option_ready_rx = None;
 130                let thread_store = cx.new(|cx| {
 131                    let (thread_store, ready_rx) =
 132                        Self::new(project, tools, prompt_builder, prompt_store, cx);
 133                    option_ready_rx = Some(ready_rx);
 134                    thread_store
 135                });
 136                (thread_store, option_ready_rx.take().unwrap())
 137            })?;
 138            ready_rx.await?;
 139            Ok(thread_store)
 140        })
 141    }
 142
 143    fn new(
 144        project: Entity<Project>,
 145        tools: Entity<ToolWorkingSet>,
 146        prompt_builder: Arc<PromptBuilder>,
 147        prompt_store: Option<Entity<PromptStore>>,
 148        cx: &mut Context<Self>,
 149    ) -> (Self, oneshot::Receiver<()>) {
 150        let mut subscriptions = vec![
 151            cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
 152                this.load_default_profile(cx);
 153            }),
 154            cx.subscribe(&project, Self::handle_project_event),
 155        ];
 156
 157        if let Some(prompt_store) = prompt_store.as_ref() {
 158            subscriptions.push(cx.subscribe(
 159                prompt_store,
 160                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
 161                    this.enqueue_system_prompt_reload();
 162                },
 163            ))
 164        }
 165
 166        // This channel and task prevent concurrent and redundant loading of the system prompt.
 167        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
 168        let (ready_tx, ready_rx) = oneshot::channel();
 169        let mut ready_tx = Some(ready_tx);
 170        let reload_system_prompt_task = cx.spawn({
 171            let prompt_store = prompt_store.clone();
 172            async move |thread_store, cx| {
 173                loop {
 174                    let Some(reload_task) = thread_store
 175                        .update(cx, |thread_store, cx| {
 176                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
 177                        })
 178                        .ok()
 179                    else {
 180                        return;
 181                    };
 182                    reload_task.await;
 183                    if let Some(ready_tx) = ready_tx.take() {
 184                        ready_tx.send(()).ok();
 185                    }
 186                    reload_system_prompt_rx.next().await;
 187                }
 188            }
 189        });
 190
 191        let this = Self {
 192            project,
 193            tools,
 194            prompt_builder,
 195            prompt_store,
 196            context_server_tool_ids: HashMap::default(),
 197            threads: Vec::new(),
 198            project_context: SharedProjectContext::default(),
 199            reload_system_prompt_tx,
 200            _reload_system_prompt_task: reload_system_prompt_task,
 201            _subscriptions: subscriptions,
 202        };
 203        this.load_default_profile(cx);
 204        this.register_context_server_handlers(cx);
 205        this.reload(cx).detach_and_log_err(cx);
 206        (this, ready_rx)
 207    }
 208
 209    fn handle_project_event(
 210        &mut self,
 211        _project: Entity<Project>,
 212        event: &project::Event,
 213        _cx: &mut Context<Self>,
 214    ) {
 215        match event {
 216            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 217                self.enqueue_system_prompt_reload();
 218            }
 219            project::Event::WorktreeUpdatedEntries(_, items) => {
 220                if items.iter().any(|(path, _, _)| {
 221                    RULES_FILE_NAMES
 222                        .iter()
 223                        .any(|name| path.as_ref() == Path::new(name))
 224                }) {
 225                    self.enqueue_system_prompt_reload();
 226                }
 227            }
 228            _ => {}
 229        }
 230    }
 231
 232    fn enqueue_system_prompt_reload(&mut self) {
 233        self.reload_system_prompt_tx.try_send(()).ok();
 234    }
 235
 236    // Note that this should only be called from `reload_system_prompt_task`.
 237    fn reload_system_prompt(
 238        &self,
 239        prompt_store: Option<Entity<PromptStore>>,
 240        cx: &mut Context<Self>,
 241    ) -> Task<()> {
 242        let worktrees = self
 243            .project
 244            .read(cx)
 245            .visible_worktrees(cx)
 246            .collect::<Vec<_>>();
 247        let worktree_tasks = worktrees
 248            .into_iter()
 249            .map(|worktree| {
 250                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
 251            })
 252            .collect::<Vec<_>>();
 253        let default_user_rules_task = match prompt_store {
 254            None => Task::ready(vec![]),
 255            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
 256                let prompts = prompt_store.default_prompt_metadata();
 257                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 258                    let contents = prompt_store.load(prompt_metadata.id, cx);
 259                    async move { (contents.await, prompt_metadata) }
 260                });
 261                cx.background_spawn(future::join_all(load_tasks))
 262            }),
 263        };
 264
 265        cx.spawn(async move |this, cx| {
 266            let (worktrees, default_user_rules) =
 267                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 268
 269            let worktrees = worktrees
 270                .into_iter()
 271                .map(|(worktree, rules_error)| {
 272                    if let Some(rules_error) = rules_error {
 273                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 274                    }
 275                    worktree
 276                })
 277                .collect::<Vec<_>>();
 278
 279            let default_user_rules = default_user_rules
 280                .into_iter()
 281                .flat_map(|(contents, prompt_metadata)| match contents {
 282                    Ok(contents) => Some(UserRulesContext {
 283                        uuid: match prompt_metadata.id {
 284                            PromptId::User { uuid } => uuid,
 285                            PromptId::EditWorkflow => return None,
 286                        },
 287                        title: prompt_metadata.title.map(|title| title.to_string()),
 288                        contents,
 289                    }),
 290                    Err(err) => {
 291                        this.update(cx, |_, cx| {
 292                            cx.emit(RulesLoadingError {
 293                                message: format!("{err:?}").into(),
 294                            });
 295                        })
 296                        .ok();
 297                        None
 298                    }
 299                })
 300                .collect::<Vec<_>>();
 301
 302            this.update(cx, |this, _cx| {
 303                *this.project_context.0.borrow_mut() =
 304                    Some(ProjectContext::new(worktrees, default_user_rules));
 305            })
 306            .ok();
 307        })
 308    }
 309
 310    fn load_worktree_info_for_system_prompt(
 311        worktree: Entity<Worktree>,
 312        project: Entity<Project>,
 313        cx: &mut App,
 314    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 315        let root_name = worktree.read(cx).root_name().into();
 316
 317        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 318        let Some(rules_task) = rules_task else {
 319            return Task::ready((
 320                WorktreeContext {
 321                    root_name,
 322                    rules_file: None,
 323                },
 324                None,
 325            ));
 326        };
 327
 328        cx.spawn(async move |_| {
 329            let (rules_file, rules_file_error) = match rules_task.await {
 330                Ok(rules_file) => (Some(rules_file), None),
 331                Err(err) => (
 332                    None,
 333                    Some(RulesLoadingError {
 334                        message: format!("{err}").into(),
 335                    }),
 336                ),
 337            };
 338            let worktree_info = WorktreeContext {
 339                root_name,
 340                rules_file,
 341            };
 342            (worktree_info, rules_file_error)
 343        })
 344    }
 345
 346    fn load_worktree_rules_file(
 347        worktree: Entity<Worktree>,
 348        project: Entity<Project>,
 349        cx: &mut App,
 350    ) -> Option<Task<Result<RulesFileContext>>> {
 351        let worktree_ref = worktree.read(cx);
 352        let worktree_id = worktree_ref.id();
 353        let selected_rules_file = RULES_FILE_NAMES
 354            .into_iter()
 355            .filter_map(|name| {
 356                worktree_ref
 357                    .entry_for_path(name)
 358                    .filter(|entry| entry.is_file())
 359                    .map(|entry| entry.path.clone())
 360            })
 361            .next();
 362
 363        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 364        // supported. This doesn't seem to occur often in GitHub repositories.
 365        selected_rules_file.map(|path_in_worktree| {
 366            let project_path = ProjectPath {
 367                worktree_id,
 368                path: path_in_worktree.clone(),
 369            };
 370            let buffer_task =
 371                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 372            let rope_task = cx.spawn(async move |cx| {
 373                buffer_task.await?.read_with(cx, |buffer, cx| {
 374                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 375                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 376                })?
 377            });
 378            // Build a string from the rope on a background thread.
 379            cx.background_spawn(async move {
 380                let (project_entry_id, rope) = rope_task.await?;
 381                anyhow::Ok(RulesFileContext {
 382                    path_in_worktree,
 383                    text: rope.to_string().trim().to_string(),
 384                    project_entry_id: project_entry_id.to_usize(),
 385                })
 386            })
 387        })
 388    }
 389
 390    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
 391        &self.prompt_store
 392    }
 393
 394    pub fn tools(&self) -> Entity<ToolWorkingSet> {
 395        self.tools.clone()
 396    }
 397
 398    /// Returns the number of threads.
 399    pub fn thread_count(&self) -> usize {
 400        self.threads.len()
 401    }
 402
 403    pub fn unordered_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
 404        self.threads.iter()
 405    }
 406
 407    pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
 408        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
 409        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
 410        threads
 411    }
 412
 413    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 414        cx.new(|cx| {
 415            Thread::new(
 416                self.project.clone(),
 417                self.tools.clone(),
 418                self.prompt_builder.clone(),
 419                self.project_context.clone(),
 420                cx,
 421            )
 422        })
 423    }
 424
 425    pub fn create_thread_from_serialized(
 426        &mut self,
 427        serialized: SerializedThread,
 428        cx: &mut Context<Self>,
 429    ) -> Entity<Thread> {
 430        cx.new(|cx| {
 431            Thread::deserialize(
 432                ThreadId::new(),
 433                serialized,
 434                self.project.clone(),
 435                self.tools.clone(),
 436                self.prompt_builder.clone(),
 437                self.project_context.clone(),
 438                None,
 439                cx,
 440            )
 441        })
 442    }
 443
 444    pub fn open_thread(
 445        &self,
 446        id: &ThreadId,
 447        window: &mut Window,
 448        cx: &mut Context<Self>,
 449    ) -> Task<Result<Entity<Thread>>> {
 450        let id = id.clone();
 451        let database_future = ThreadsDatabase::global_future(cx);
 452        let this = cx.weak_entity();
 453        window.spawn(cx, async move |cx| {
 454            let database = database_future.await.map_err(|err| anyhow!(err))?;
 455            let thread = database
 456                .try_find_thread(id.clone())
 457                .await?
 458                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 459
 460            let thread = this.update_in(cx, |this, window, cx| {
 461                cx.new(|cx| {
 462                    Thread::deserialize(
 463                        id.clone(),
 464                        thread,
 465                        this.project.clone(),
 466                        this.tools.clone(),
 467                        this.prompt_builder.clone(),
 468                        this.project_context.clone(),
 469                        Some(window),
 470                        cx,
 471                    )
 472                })
 473            })?;
 474
 475            Ok(thread)
 476        })
 477    }
 478
 479    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
 480        let (metadata, serialized_thread) =
 481            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
 482
 483        let database_future = ThreadsDatabase::global_future(cx);
 484        cx.spawn(async move |this, cx| {
 485            let serialized_thread = serialized_thread.await?;
 486            let database = database_future.await.map_err(|err| anyhow!(err))?;
 487            database.save_thread(metadata, serialized_thread).await?;
 488
 489            this.update(cx, |this, cx| this.reload(cx))?.await
 490        })
 491    }
 492
 493    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
 494        let id = id.clone();
 495        let database_future = ThreadsDatabase::global_future(cx);
 496        cx.spawn(async move |this, cx| {
 497            let database = database_future.await.map_err(|err| anyhow!(err))?;
 498            database.delete_thread(id.clone()).await?;
 499
 500            this.update(cx, |this, cx| {
 501                this.threads.retain(|thread| thread.id != id);
 502                cx.notify();
 503            })
 504        })
 505    }
 506
 507    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 508        let database_future = ThreadsDatabase::global_future(cx);
 509        cx.spawn(async move |this, cx| {
 510            let threads = database_future
 511                .await
 512                .map_err(|err| anyhow!(err))?
 513                .list_threads()
 514                .await?;
 515
 516            this.update(cx, |this, cx| {
 517                this.threads = threads;
 518                cx.notify();
 519            })
 520        })
 521    }
 522
 523    fn load_default_profile(&self, cx: &mut Context<Self>) {
 524        let assistant_settings = AgentSettings::get_global(cx);
 525
 526        self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
 527    }
 528
 529    pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
 530        let assistant_settings = AgentSettings::get_global(cx);
 531
 532        if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
 533            self.load_profile(profile.clone(), cx);
 534        }
 535    }
 536
 537    pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
 538        self.tools.update(cx, |tools, cx| {
 539            tools.disable_all_tools(cx);
 540            tools.enable(
 541                ToolSource::Native,
 542                &profile
 543                    .tools
 544                    .into_iter()
 545                    .filter_map(|(tool, enabled)| enabled.then(|| tool))
 546                    .collect::<Vec<_>>(),
 547                cx,
 548            );
 549        });
 550
 551        if profile.enable_all_context_servers {
 552            for context_server_id in self
 553                .project
 554                .read(cx)
 555                .context_server_store()
 556                .read(cx)
 557                .all_server_ids()
 558            {
 559                self.tools.update(cx, |tools, cx| {
 560                    tools.enable_source(
 561                        ToolSource::ContextServer {
 562                            id: context_server_id.0.into(),
 563                        },
 564                        cx,
 565                    );
 566                });
 567            }
 568            // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
 569            for (context_server_id, preset) in profile.context_servers {
 570                self.tools.update(cx, |tools, cx| {
 571                    tools.disable(
 572                        ToolSource::ContextServer {
 573                            id: context_server_id.into(),
 574                        },
 575                        &preset
 576                            .tools
 577                            .into_iter()
 578                            .filter_map(|(tool, enabled)| (!enabled).then(|| tool))
 579                            .collect::<Vec<_>>(),
 580                        cx,
 581                    )
 582                })
 583            }
 584        } else {
 585            for (context_server_id, preset) in profile.context_servers {
 586                self.tools.update(cx, |tools, cx| {
 587                    tools.enable(
 588                        ToolSource::ContextServer {
 589                            id: context_server_id.into(),
 590                        },
 591                        &preset
 592                            .tools
 593                            .into_iter()
 594                            .filter_map(|(tool, enabled)| enabled.then(|| tool))
 595                            .collect::<Vec<_>>(),
 596                        cx,
 597                    )
 598                })
 599            }
 600        }
 601    }
 602
 603    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
 604        cx.subscribe(
 605            &self.project.read(cx).context_server_store(),
 606            Self::handle_context_server_event,
 607        )
 608        .detach();
 609    }
 610
 611    fn handle_context_server_event(
 612        &mut self,
 613        context_server_store: Entity<ContextServerStore>,
 614        event: &project::context_server_store::Event,
 615        cx: &mut Context<Self>,
 616    ) {
 617        let tool_working_set = self.tools.clone();
 618        match event {
 619            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
 620                match status {
 621                    ContextServerStatus::Running => {
 622                        if let Some(server) =
 623                            context_server_store.read(cx).get_running_server(server_id)
 624                        {
 625                            let context_server_manager = context_server_store.clone();
 626                            cx.spawn({
 627                                let server = server.clone();
 628                                let server_id = server_id.clone();
 629                                async move |this, cx| {
 630                                    let Some(protocol) = server.client() else {
 631                                        return;
 632                                    };
 633
 634                                    if protocol.capable(context_server::protocol::ServerCapability::Tools) {
 635                                        if let Some(tools) = protocol.list_tools().await.log_err() {
 636                                            let tool_ids = tool_working_set
 637                                                .update(cx, |tool_working_set, _| {
 638                                                    tools
 639                                                        .tools
 640                                                        .into_iter()
 641                                                        .map(|tool| {
 642                                                            log::info!(
 643                                                                "registering context server tool: {:?}",
 644                                                                tool.name
 645                                                            );
 646                                                            tool_working_set.insert(Arc::new(
 647                                                                ContextServerTool::new(
 648                                                                    context_server_manager.clone(),
 649                                                                    server.id(),
 650                                                                    tool,
 651                                                                ),
 652                                                            ))
 653                                                        })
 654                                                        .collect::<Vec<_>>()
 655                                                })
 656                                                .log_err();
 657
 658                                            if let Some(tool_ids) = tool_ids {
 659                                                this.update(cx, |this, cx| {
 660                                                    this.context_server_tool_ids
 661                                                        .insert(server_id, tool_ids);
 662                                                    this.load_default_profile(cx);
 663                                                })
 664                                                .log_err();
 665                                            }
 666                                        }
 667                                    }
 668                                }
 669                            })
 670                            .detach();
 671                        }
 672                    }
 673                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
 674                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
 675                            tool_working_set.update(cx, |tool_working_set, _| {
 676                                tool_working_set.remove(&tool_ids);
 677                            });
 678                            self.load_default_profile(cx);
 679                        }
 680                    }
 681                    _ => {}
 682                }
 683            }
 684        }
 685    }
 686}
 687
/// Lightweight summary of a saved thread, as produced by the threads database listing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    /// Last-modified time; `ThreadStore::reverse_chronological_threads` sorts on this.
    pub updated_at: DateTime<Utc>,
}
 694
/// Saved representation of a thread, at format version `SerializedThread::VERSION`.
///
/// Fields marked `#[serde(default)]` may be missing from older saved data; when absent they
/// deserialize to their `Default` value.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    /// Format version string; `SerializedThread::from_json` dispatches on it.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    /// Model the thread was using, if recorded.
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
}
 718
/// Identifies a language model by provider and model strings.
// NOTE(review): presumably these are the provider id and model id used to look the model up
// again — confirm against `Thread::deserialize`.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
 724
 725impl SerializedThread {
 726    pub const VERSION: &'static str = "0.2.0";
 727
 728    pub fn from_json(json: &[u8]) -> Result<Self> {
 729        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
 730        match saved_thread_json.get("version") {
 731            Some(serde_json::Value::String(version)) => match version.as_str() {
 732                SerializedThreadV0_1_0::VERSION => {
 733                    let saved_thread =
 734                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
 735                    Ok(saved_thread.upgrade())
 736                }
 737                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
 738                    saved_thread_json,
 739                )?),
 740                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 741            },
 742            None => {
 743                let saved_thread =
 744                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
 745                Ok(saved_thread.upgrade())
 746            }
 747            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
 748        }
 749    }
 750}
 751
/// Thread format version 0.1.0.
///
/// It has the same fields as `SerializedThread`. The difference is where tool results live:
/// `upgrade` moves them from user messages onto the preceding message.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);
 758
 759impl SerializedThreadV0_1_0 {
 760    pub const VERSION: &'static str = "0.1.0";
 761
 762    pub fn upgrade(self) -> SerializedThread {
 763        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
 764
 765        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
 766
 767        for message in self.0.messages {
 768            if message.role == Role::User && !message.tool_results.is_empty() {
 769                if let Some(last_message) = messages.last_mut() {
 770                    debug_assert!(last_message.role == Role::Assistant);
 771
 772                    last_message.tool_results = message.tool_results;
 773                    continue;
 774                }
 775            }
 776
 777            messages.push(message);
 778        }
 779
 780        SerializedThread { messages, ..self.0 }
 781    }
 782}
 783
/// One saved message of a thread. Fields marked `#[serde(default)]` fall back to empty or
/// `false` when absent.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Content pieces of the message, in order.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    // NOTE(review): presumably the rendered context attached to the message — confirm.
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    // NOTE(review): presumably hides the message from the UI — confirm against `Thread`.
    #[serde(default)]
    pub is_hidden: bool,
}
 801
/// One piece of a [`SerializedMessage`]'s content.
///
/// The enum is internally tagged. Each segment serializes as an object whose `"type"`
/// field names the variant.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Plain message text; tagged `"text"`.
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    /// Thinking text; tagged `"thinking"`. `signature` is left out of the output when
    /// it is `None`.
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    /// Opaque redacted-thinking payload.
    ///
    /// NOTE(review): unlike the other variants this one has no `rename`, so its tag is the
    /// variant name `"RedactedThinking"`. Threads already saved use that tag, so renaming it
    /// would need a `#[serde(alias = ...)]` to keep them readable.
    RedactedThinking {
        data: Vec<u8>,
    },
}
 819
/// A persisted tool invocation.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    /// Identifier that the matching [`SerializedToolResult::tool_use_id`] refers back to.
    pub id: LanguageModelToolUseId,
    /// Name of the invoked tool.
    pub name: SharedString,
    /// Tool input, stored as arbitrary JSON.
    pub input: serde_json::Value,
}
 826
/// A persisted tool result.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    /// The [`SerializedToolUse::id`] this result answers.
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the tool reported an error.
    pub is_error: bool,
    /// Result content.
    pub content: LanguageModelToolResultContent,
    /// Optional additional tool output as JSON; purpose not visible here — TODO confirm.
    pub output: Option<serde_json::Value>,
}
 834
/// Thread format from before versioning was introduced. These payloads have no
/// `version` field; they are upgraded with [`LegacySerializedThread::upgrade`].
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
 843
 844impl LegacySerializedThread {
 845    pub fn upgrade(self) -> SerializedThread {
 846        SerializedThread {
 847            version: SerializedThread::VERSION.to_string(),
 848            summary: self.summary,
 849            updated_at: self.updated_at,
 850            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
 851            initial_project_snapshot: self.initial_project_snapshot,
 852            cumulative_token_usage: TokenUsage::default(),
 853            request_token_usage: Vec::new(),
 854            detailed_summary_state: DetailedSummaryState::default(),
 855            exceeded_window_error: None,
 856            model: None,
 857            completion_mode: None,
 858            tool_use_limit_reached: false,
 859        }
 860    }
 861}
 862
/// Message in the unversioned legacy thread format. The whole message body is a single
/// `text` string rather than a list of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
 873
 874impl LegacySerializedMessage {
 875    fn upgrade(self) -> SerializedMessage {
 876        SerializedMessage {
 877            id: self.id,
 878            role: self.role,
 879            segments: vec![SerializedMessageSegment::Text { text: self.text }],
 880            tool_uses: self.tool_uses,
 881            tool_results: self.tool_results,
 882            context: String::new(),
 883            creases: Vec::new(),
 884            is_hidden: false,
 885        }
 886    }
 887}
 888
/// Persisted crease (labelled, iconified range) belonging to a [`SerializedMessage`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedCrease {
    /// Start of the range; presumably an offset into the message text — TODO confirm units.
    pub start: usize,
    /// End of the range; same units as `start`.
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}
 896
/// gpui global holding the shared future that opens the threads database.
///
/// It is installed by `ThreadsDatabase::init`. The future can be cloned and awaited by
/// many callers, and both success and error are behind an `Arc`.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
 902
/// SQLite-backed store of serialized threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every query task is spawned.
    executor: BackgroundExecutor,
    /// Single SQLite connection, serialized behind a mutex and shared with spawned tasks.
    connection: Arc<Mutex<Connection>>,
}
 907
 908impl ThreadsDatabase {
 909    fn connection(&self) -> Arc<Mutex<Connection>> {
 910        self.connection.clone()
 911    }
 912
 913    const COMPRESSION_LEVEL: i32 = 3;
 914}
 915
 916impl Bind for ThreadId {
 917    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
 918        self.to_string().bind(statement, start_index)
 919    }
 920}
 921
 922impl Column for ThreadId {
 923    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
 924        let (id_str, next_index) = String::column(statement, start_index)?;
 925        Ok((ThreadId::from(id_str.as_str()), next_index))
 926    }
 927}
 928
 929impl ThreadsDatabase {
 930    fn global_future(
 931        cx: &mut App,
 932    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
 933        GlobalThreadsDatabase::global(cx).0.clone()
 934    }
 935
 936    fn init(cx: &mut App) {
 937        let executor = cx.background_executor().clone();
 938        let database_future = executor
 939            .spawn({
 940                let executor = executor.clone();
 941                let threads_dir = paths::data_dir().join("threads");
 942                async move { ThreadsDatabase::new(threads_dir, executor) }
 943            })
 944            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
 945            .boxed()
 946            .shared();
 947
 948        cx.set_global(GlobalThreadsDatabase(database_future));
 949    }
 950
 951    pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
 952        std::fs::create_dir_all(&threads_dir)?;
 953
 954        let sqlite_path = threads_dir.join("threads.db");
 955        let mdb_path = threads_dir.join("threads-db.1.mdb");
 956
 957        let needs_migration_from_heed = mdb_path.exists();
 958
 959        let connection = Connection::open_file(&sqlite_path.to_string_lossy());
 960
 961        connection.exec(indoc! {"
 962                CREATE TABLE IF NOT EXISTS threads (
 963                    id TEXT PRIMARY KEY,
 964                    summary TEXT NOT NULL,
 965                    updated_at TEXT NOT NULL,
 966                    data_type TEXT NOT NULL,
 967                    data BLOB NOT NULL
 968                )
 969            "})?()
 970        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
 971
 972        let db = Self {
 973            executor: executor.clone(),
 974            connection: Arc::new(Mutex::new(connection)),
 975        };
 976
 977        if needs_migration_from_heed {
 978            let db_connection = db.connection();
 979            let executor_clone = executor.clone();
 980            executor
 981                .spawn(async move {
 982                    log::info!("Starting threads.db migration");
 983                    Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
 984                    std::fs::remove_dir_all(mdb_path)?;
 985                    log::info!("threads.db migrated to sqlite");
 986                    Ok::<(), anyhow::Error>(())
 987                })
 988                .detach();
 989        }
 990
 991        Ok(db)
 992    }
 993
 994    // Remove this migration after 2025-09-01
 995    fn migrate_from_heed(
 996        mdb_path: &Path,
 997        connection: Arc<Mutex<Connection>>,
 998        _executor: BackgroundExecutor,
 999    ) -> Result<()> {
1000        use heed::types::SerdeBincode;
1001        struct SerializedThreadHeed(SerializedThread);
1002
1003        impl heed::BytesEncode<'_> for SerializedThreadHeed {
1004            type EItem = SerializedThreadHeed;
1005
1006            fn bytes_encode(
1007                item: &Self::EItem,
1008            ) -> Result<std::borrow::Cow<[u8]>, heed::BoxedError> {
1009                serde_json::to_vec(&item.0)
1010                    .map(std::borrow::Cow::Owned)
1011                    .map_err(Into::into)
1012            }
1013        }
1014
1015        impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
1016            type DItem = SerializedThreadHeed;
1017
1018            fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
1019                SerializedThread::from_json(bytes)
1020                    .map(SerializedThreadHeed)
1021                    .map_err(Into::into)
1022            }
1023        }
1024
1025        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
1026
1027        let env = unsafe {
1028            heed::EnvOpenOptions::new()
1029                .map_size(ONE_GB_IN_BYTES)
1030                .max_dbs(1)
1031                .open(mdb_path)?
1032        };
1033
1034        let txn = env.write_txn()?;
1035        let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
1036            .open_database(&txn, Some("threads"))?
1037            .ok_or_else(|| anyhow!("threads database not found"))?;
1038
1039        for result in threads.iter(&txn)? {
1040            let (thread_id, thread_heed) = result?;
1041            Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
1042        }
1043
1044        Ok(())
1045    }
1046
1047    fn save_thread_sync(
1048        connection: &Arc<Mutex<Connection>>,
1049        id: ThreadId,
1050        thread: SerializedThread,
1051    ) -> Result<()> {
1052        let json_data = serde_json::to_string(&thread)?;
1053        let summary = thread.summary.to_string();
1054        let updated_at = thread.updated_at.to_rfc3339();
1055
1056        let connection = connection.lock().unwrap();
1057
1058        let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
1059        let data_type = DataType::Zstd;
1060        let data = compressed;
1061
1062        let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
1063            INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
1064        "})?;
1065
1066        insert((id, summary, updated_at, data_type, data))?;
1067
1068        Ok(())
1069    }
1070
1071    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
1072        let connection = self.connection.clone();
1073
1074        self.executor.spawn(async move {
1075            let connection = connection.lock().unwrap();
1076            let mut select =
1077                connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
1078                SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
1079            "})?;
1080
1081            let rows = select(())?;
1082            let mut threads = Vec::new();
1083
1084            for (id, summary, updated_at) in rows {
1085                threads.push(SerializedThreadMetadata {
1086                    id,
1087                    summary: summary.into(),
1088                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
1089                });
1090            }
1091
1092            Ok(threads)
1093        })
1094    }
1095
1096    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
1097        let connection = self.connection.clone();
1098
1099        self.executor.spawn(async move {
1100            let connection = connection.lock().unwrap();
1101            let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
1102                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
1103            "})?;
1104
1105            let rows = select(id)?;
1106            if let Some((data_type, data)) = rows.into_iter().next() {
1107                let json_data = match data_type {
1108                    DataType::Zstd => {
1109                        let decompressed = zstd::decode_all(&data[..])?;
1110                        String::from_utf8(decompressed)?
1111                    }
1112                    DataType::Json => String::from_utf8(data)?,
1113                };
1114
1115                let thread = SerializedThread::from_json(json_data.as_bytes())?;
1116                Ok(Some(thread))
1117            } else {
1118                Ok(None)
1119            }
1120        })
1121    }
1122
1123    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
1124        let connection = self.connection.clone();
1125
1126        self.executor
1127            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
1128    }
1129
1130    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
1131        let connection = self.connection.clone();
1132
1133        self.executor.spawn(async move {
1134            let connection = connection.lock().unwrap();
1135
1136            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
1137                DELETE FROM threads WHERE id = ?
1138            "})?;
1139
1140            delete(id)?;
1141
1142            Ok(())
1143        })
1144    }
1145}