1use std::cell::{Ref, RefCell};
2use std::path::{Path, PathBuf};
3use std::rc::Rc;
4use std::sync::{Arc, Mutex};
5
6use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, CompletionMode};
7use anyhow::{Context as _, Result, anyhow};
8use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
9use chrono::{DateTime, Utc};
10use collections::HashMap;
11use context_server::ContextServerId;
12use futures::channel::{mpsc, oneshot};
13use futures::future::{self, BoxFuture, Shared};
14use futures::{FutureExt as _, StreamExt as _};
15use gpui::{
16 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
17 Subscription, Task, prelude::*,
18};
19
20use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
21use project::context_server_store::{ContextServerStatus, ContextServerStore};
22use project::{Project, ProjectItem, ProjectPath, Worktree};
23use prompt_store::{
24 ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
25 UserRulesContext, WorktreeContext,
26};
27use serde::{Deserialize, Serialize};
28use settings::{Settings as _, SettingsStore};
29use ui::Window;
30use util::ResultExt as _;
31
32use crate::context_server_tool::ContextServerTool;
33use crate::thread::{
34 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
35};
36use indoc::indoc;
37use sqlez::{
38 bindable::{Bind, Column},
39 connection::Connection,
40 statement::Statement,
41};
42
43#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
44pub enum DataType {
45 #[serde(rename = "json")]
46 Json,
47 #[serde(rename = "zstd")]
48 Zstd,
49}
50
51impl Bind for DataType {
52 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
53 let value = match self {
54 DataType::Json => "json",
55 DataType::Zstd => "zstd",
56 };
57 value.bind(statement, start_index)
58 }
59}
60
61impl Column for DataType {
62 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
63 let (value, next_index) = String::column(statement, start_index)?;
64 let data_type = match value.as_str() {
65 "json" => DataType::Json,
66 "zstd" => DataType::Zstd,
67 _ => anyhow::bail!("Unknown data type: {}", value),
68 };
69 Ok((data_type, next_index))
70 }
71}
72
/// Rules-file paths (relative to a worktree root) that may supply project rules for the
/// system prompt, in priority order: the first one that exists as a file is the one used.
///
/// `const` items already have a `'static` lifetime, so `&str` is sufficient here.
const RULES_FILE_NAMES: [&str; 8] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
];
83
84pub fn init(cx: &mut App) {
85 ThreadsDatabase::init(cx);
86}
87
88/// A system prompt shared by all threads created by this ThreadStore
89#[derive(Clone, Default)]
90pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
91
92impl SharedProjectContext {
93 pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
94 self.0.borrow()
95 }
96}
97
/// Store for text threads; an alias for `assistant_context_editor::ContextStore`.
pub type TextThreadStore = assistant_context_editor::ContextStore;
99
/// Owns the list of saved agent threads and the shared system-prompt context handed to
/// every thread it creates or opens.
pub struct ThreadStore {
    project: Entity<Project>,
    /// Tool set given to every thread; enabled/disabled according to the active agent profile.
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    /// Source of default user rules; `None` means no user rules are loaded.
    prompt_store: Option<Entity<PromptStore>>,
    /// Tools registered per running context server, so they can be removed when it stops.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Thread metadata as listed by the database, most recently updated first.
    threads: Vec<SerializedThreadMetadata>,
    project_context: SharedProjectContext,
    /// Wakes the system-prompt reload loop (see `enqueue_system_prompt_reload`).
    reload_system_prompt_tx: mpsc::Sender<()>,
    // Held so the reload loop lives as long as the store (presumably cancelled on drop).
    // NOTE(review): field order is also drop order; keep these last.
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}
112
/// Event emitted by `ThreadStore` when a rules file or a default user rule fails to load
/// during a system-prompt reload.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}
118
119impl ThreadStore {
120 pub fn load(
121 project: Entity<Project>,
122 tools: Entity<ToolWorkingSet>,
123 prompt_store: Option<Entity<PromptStore>>,
124 prompt_builder: Arc<PromptBuilder>,
125 cx: &mut App,
126 ) -> Task<Result<Entity<Self>>> {
127 cx.spawn(async move |cx| {
128 let (thread_store, ready_rx) = cx.update(|cx| {
129 let mut option_ready_rx = None;
130 let thread_store = cx.new(|cx| {
131 let (thread_store, ready_rx) =
132 Self::new(project, tools, prompt_builder, prompt_store, cx);
133 option_ready_rx = Some(ready_rx);
134 thread_store
135 });
136 (thread_store, option_ready_rx.take().unwrap())
137 })?;
138 ready_rx.await?;
139 Ok(thread_store)
140 })
141 }
142
143 fn new(
144 project: Entity<Project>,
145 tools: Entity<ToolWorkingSet>,
146 prompt_builder: Arc<PromptBuilder>,
147 prompt_store: Option<Entity<PromptStore>>,
148 cx: &mut Context<Self>,
149 ) -> (Self, oneshot::Receiver<()>) {
150 let mut subscriptions = vec![
151 cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
152 this.load_default_profile(cx);
153 }),
154 cx.subscribe(&project, Self::handle_project_event),
155 ];
156
157 if let Some(prompt_store) = prompt_store.as_ref() {
158 subscriptions.push(cx.subscribe(
159 prompt_store,
160 |this, _prompt_store, PromptsUpdatedEvent, _cx| {
161 this.enqueue_system_prompt_reload();
162 },
163 ))
164 }
165
166 // This channel and task prevent concurrent and redundant loading of the system prompt.
167 let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
168 let (ready_tx, ready_rx) = oneshot::channel();
169 let mut ready_tx = Some(ready_tx);
170 let reload_system_prompt_task = cx.spawn({
171 let prompt_store = prompt_store.clone();
172 async move |thread_store, cx| {
173 loop {
174 let Some(reload_task) = thread_store
175 .update(cx, |thread_store, cx| {
176 thread_store.reload_system_prompt(prompt_store.clone(), cx)
177 })
178 .ok()
179 else {
180 return;
181 };
182 reload_task.await;
183 if let Some(ready_tx) = ready_tx.take() {
184 ready_tx.send(()).ok();
185 }
186 reload_system_prompt_rx.next().await;
187 }
188 }
189 });
190
191 let this = Self {
192 project,
193 tools,
194 prompt_builder,
195 prompt_store,
196 context_server_tool_ids: HashMap::default(),
197 threads: Vec::new(),
198 project_context: SharedProjectContext::default(),
199 reload_system_prompt_tx,
200 _reload_system_prompt_task: reload_system_prompt_task,
201 _subscriptions: subscriptions,
202 };
203 this.load_default_profile(cx);
204 this.register_context_server_handlers(cx);
205 this.reload(cx).detach_and_log_err(cx);
206 (this, ready_rx)
207 }
208
209 fn handle_project_event(
210 &mut self,
211 _project: Entity<Project>,
212 event: &project::Event,
213 _cx: &mut Context<Self>,
214 ) {
215 match event {
216 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
217 self.enqueue_system_prompt_reload();
218 }
219 project::Event::WorktreeUpdatedEntries(_, items) => {
220 if items.iter().any(|(path, _, _)| {
221 RULES_FILE_NAMES
222 .iter()
223 .any(|name| path.as_ref() == Path::new(name))
224 }) {
225 self.enqueue_system_prompt_reload();
226 }
227 }
228 _ => {}
229 }
230 }
231
232 fn enqueue_system_prompt_reload(&mut self) {
233 self.reload_system_prompt_tx.try_send(()).ok();
234 }
235
236 // Note that this should only be called from `reload_system_prompt_task`.
237 fn reload_system_prompt(
238 &self,
239 prompt_store: Option<Entity<PromptStore>>,
240 cx: &mut Context<Self>,
241 ) -> Task<()> {
242 let worktrees = self
243 .project
244 .read(cx)
245 .visible_worktrees(cx)
246 .collect::<Vec<_>>();
247 let worktree_tasks = worktrees
248 .into_iter()
249 .map(|worktree| {
250 Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
251 })
252 .collect::<Vec<_>>();
253 let default_user_rules_task = match prompt_store {
254 None => Task::ready(vec![]),
255 Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
256 let prompts = prompt_store.default_prompt_metadata();
257 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
258 let contents = prompt_store.load(prompt_metadata.id, cx);
259 async move { (contents.await, prompt_metadata) }
260 });
261 cx.background_spawn(future::join_all(load_tasks))
262 }),
263 };
264
265 cx.spawn(async move |this, cx| {
266 let (worktrees, default_user_rules) =
267 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
268
269 let worktrees = worktrees
270 .into_iter()
271 .map(|(worktree, rules_error)| {
272 if let Some(rules_error) = rules_error {
273 this.update(cx, |_, cx| cx.emit(rules_error)).ok();
274 }
275 worktree
276 })
277 .collect::<Vec<_>>();
278
279 let default_user_rules = default_user_rules
280 .into_iter()
281 .flat_map(|(contents, prompt_metadata)| match contents {
282 Ok(contents) => Some(UserRulesContext {
283 uuid: match prompt_metadata.id {
284 PromptId::User { uuid } => uuid,
285 PromptId::EditWorkflow => return None,
286 },
287 title: prompt_metadata.title.map(|title| title.to_string()),
288 contents,
289 }),
290 Err(err) => {
291 this.update(cx, |_, cx| {
292 cx.emit(RulesLoadingError {
293 message: format!("{err:?}").into(),
294 });
295 })
296 .ok();
297 None
298 }
299 })
300 .collect::<Vec<_>>();
301
302 this.update(cx, |this, _cx| {
303 *this.project_context.0.borrow_mut() =
304 Some(ProjectContext::new(worktrees, default_user_rules));
305 })
306 .ok();
307 })
308 }
309
310 fn load_worktree_info_for_system_prompt(
311 worktree: Entity<Worktree>,
312 project: Entity<Project>,
313 cx: &mut App,
314 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
315 let root_name = worktree.read(cx).root_name().into();
316
317 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
318 let Some(rules_task) = rules_task else {
319 return Task::ready((
320 WorktreeContext {
321 root_name,
322 rules_file: None,
323 },
324 None,
325 ));
326 };
327
328 cx.spawn(async move |_| {
329 let (rules_file, rules_file_error) = match rules_task.await {
330 Ok(rules_file) => (Some(rules_file), None),
331 Err(err) => (
332 None,
333 Some(RulesLoadingError {
334 message: format!("{err}").into(),
335 }),
336 ),
337 };
338 let worktree_info = WorktreeContext {
339 root_name,
340 rules_file,
341 };
342 (worktree_info, rules_file_error)
343 })
344 }
345
346 fn load_worktree_rules_file(
347 worktree: Entity<Worktree>,
348 project: Entity<Project>,
349 cx: &mut App,
350 ) -> Option<Task<Result<RulesFileContext>>> {
351 let worktree_ref = worktree.read(cx);
352 let worktree_id = worktree_ref.id();
353 let selected_rules_file = RULES_FILE_NAMES
354 .into_iter()
355 .filter_map(|name| {
356 worktree_ref
357 .entry_for_path(name)
358 .filter(|entry| entry.is_file())
359 .map(|entry| entry.path.clone())
360 })
361 .next();
362
363 // Note that Cline supports `.clinerules` being a directory, but that is not currently
364 // supported. This doesn't seem to occur often in GitHub repositories.
365 selected_rules_file.map(|path_in_worktree| {
366 let project_path = ProjectPath {
367 worktree_id,
368 path: path_in_worktree.clone(),
369 };
370 let buffer_task =
371 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
372 let rope_task = cx.spawn(async move |cx| {
373 buffer_task.await?.read_with(cx, |buffer, cx| {
374 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
375 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
376 })?
377 });
378 // Build a string from the rope on a background thread.
379 cx.background_spawn(async move {
380 let (project_entry_id, rope) = rope_task.await?;
381 anyhow::Ok(RulesFileContext {
382 path_in_worktree,
383 text: rope.to_string().trim().to_string(),
384 project_entry_id: project_entry_id.to_usize(),
385 })
386 })
387 })
388 }
389
390 pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
391 &self.prompt_store
392 }
393
394 pub fn tools(&self) -> Entity<ToolWorkingSet> {
395 self.tools.clone()
396 }
397
398 /// Returns the number of threads.
399 pub fn thread_count(&self) -> usize {
400 self.threads.len()
401 }
402
403 pub fn reverse_chronological_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
404 // ordering is from "ORDER BY" in `list_threads`
405 self.threads.iter()
406 }
407
408 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
409 cx.new(|cx| {
410 Thread::new(
411 self.project.clone(),
412 self.tools.clone(),
413 self.prompt_builder.clone(),
414 self.project_context.clone(),
415 cx,
416 )
417 })
418 }
419
420 pub fn create_thread_from_serialized(
421 &mut self,
422 serialized: SerializedThread,
423 cx: &mut Context<Self>,
424 ) -> Entity<Thread> {
425 cx.new(|cx| {
426 Thread::deserialize(
427 ThreadId::new(),
428 serialized,
429 self.project.clone(),
430 self.tools.clone(),
431 self.prompt_builder.clone(),
432 self.project_context.clone(),
433 None,
434 cx,
435 )
436 })
437 }
438
439 pub fn open_thread(
440 &self,
441 id: &ThreadId,
442 window: &mut Window,
443 cx: &mut Context<Self>,
444 ) -> Task<Result<Entity<Thread>>> {
445 let id = id.clone();
446 let database_future = ThreadsDatabase::global_future(cx);
447 let this = cx.weak_entity();
448 window.spawn(cx, async move |cx| {
449 let database = database_future.await.map_err(|err| anyhow!(err))?;
450 let thread = database
451 .try_find_thread(id.clone())
452 .await?
453 .with_context(|| format!("no thread found with ID: {id:?}"))?;
454
455 let thread = this.update_in(cx, |this, window, cx| {
456 cx.new(|cx| {
457 Thread::deserialize(
458 id.clone(),
459 thread,
460 this.project.clone(),
461 this.tools.clone(),
462 this.prompt_builder.clone(),
463 this.project_context.clone(),
464 Some(window),
465 cx,
466 )
467 })
468 })?;
469
470 Ok(thread)
471 })
472 }
473
474 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
475 let (metadata, serialized_thread) =
476 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
477
478 let database_future = ThreadsDatabase::global_future(cx);
479 cx.spawn(async move |this, cx| {
480 let serialized_thread = serialized_thread.await?;
481 let database = database_future.await.map_err(|err| anyhow!(err))?;
482 database.save_thread(metadata, serialized_thread).await?;
483
484 this.update(cx, |this, cx| this.reload(cx))?.await
485 })
486 }
487
488 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
489 let id = id.clone();
490 let database_future = ThreadsDatabase::global_future(cx);
491 cx.spawn(async move |this, cx| {
492 let database = database_future.await.map_err(|err| anyhow!(err))?;
493 database.delete_thread(id.clone()).await?;
494
495 this.update(cx, |this, cx| {
496 this.threads.retain(|thread| thread.id != id);
497 cx.notify();
498 })
499 })
500 }
501
502 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
503 let database_future = ThreadsDatabase::global_future(cx);
504 cx.spawn(async move |this, cx| {
505 let threads = database_future
506 .await
507 .map_err(|err| anyhow!(err))?
508 .list_threads()
509 .await?;
510
511 this.update(cx, |this, cx| {
512 this.threads = threads;
513 cx.notify();
514 })
515 })
516 }
517
518 fn load_default_profile(&self, cx: &mut Context<Self>) {
519 let assistant_settings = AgentSettings::get_global(cx);
520
521 self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
522 }
523
524 pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
525 let assistant_settings = AgentSettings::get_global(cx);
526
527 if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
528 self.load_profile(profile.clone(), cx);
529 }
530 }
531
532 pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
533 self.tools.update(cx, |tools, cx| {
534 tools.disable_all_tools(cx);
535 tools.enable(
536 ToolSource::Native,
537 &profile
538 .tools
539 .into_iter()
540 .filter_map(|(tool, enabled)| enabled.then(|| tool))
541 .collect::<Vec<_>>(),
542 cx,
543 );
544 });
545
546 if profile.enable_all_context_servers {
547 for context_server_id in self
548 .project
549 .read(cx)
550 .context_server_store()
551 .read(cx)
552 .all_server_ids()
553 {
554 self.tools.update(cx, |tools, cx| {
555 tools.enable_source(
556 ToolSource::ContextServer {
557 id: context_server_id.0.into(),
558 },
559 cx,
560 );
561 });
562 }
563 // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
564 for (context_server_id, preset) in profile.context_servers {
565 self.tools.update(cx, |tools, cx| {
566 tools.disable(
567 ToolSource::ContextServer {
568 id: context_server_id.into(),
569 },
570 &preset
571 .tools
572 .into_iter()
573 .filter_map(|(tool, enabled)| (!enabled).then(|| tool))
574 .collect::<Vec<_>>(),
575 cx,
576 )
577 })
578 }
579 } else {
580 for (context_server_id, preset) in profile.context_servers {
581 self.tools.update(cx, |tools, cx| {
582 tools.enable(
583 ToolSource::ContextServer {
584 id: context_server_id.into(),
585 },
586 &preset
587 .tools
588 .into_iter()
589 .filter_map(|(tool, enabled)| enabled.then(|| tool))
590 .collect::<Vec<_>>(),
591 cx,
592 )
593 })
594 }
595 }
596 }
597
598 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
599 cx.subscribe(
600 &self.project.read(cx).context_server_store(),
601 Self::handle_context_server_event,
602 )
603 .detach();
604 }
605
606 fn handle_context_server_event(
607 &mut self,
608 context_server_store: Entity<ContextServerStore>,
609 event: &project::context_server_store::Event,
610 cx: &mut Context<Self>,
611 ) {
612 let tool_working_set = self.tools.clone();
613 match event {
614 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
615 match status {
616 ContextServerStatus::Running => {
617 if let Some(server) =
618 context_server_store.read(cx).get_running_server(server_id)
619 {
620 let context_server_manager = context_server_store.clone();
621 cx.spawn({
622 let server = server.clone();
623 let server_id = server_id.clone();
624 async move |this, cx| {
625 let Some(protocol) = server.client() else {
626 return;
627 };
628
629 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
630 if let Some(tools) = protocol.list_tools().await.log_err() {
631 let tool_ids = tool_working_set
632 .update(cx, |tool_working_set, _| {
633 tools
634 .tools
635 .into_iter()
636 .map(|tool| {
637 log::info!(
638 "registering context server tool: {:?}",
639 tool.name
640 );
641 tool_working_set.insert(Arc::new(
642 ContextServerTool::new(
643 context_server_manager.clone(),
644 server.id(),
645 tool,
646 ),
647 ))
648 })
649 .collect::<Vec<_>>()
650 })
651 .log_err();
652
653 if let Some(tool_ids) = tool_ids {
654 this.update(cx, |this, cx| {
655 this.context_server_tool_ids
656 .insert(server_id, tool_ids);
657 this.load_default_profile(cx);
658 })
659 .log_err();
660 }
661 }
662 }
663 }
664 })
665 .detach();
666 }
667 }
668 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
669 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
670 tool_working_set.update(cx, |tool_working_set, _| {
671 tool_working_set.remove(&tool_ids);
672 });
673 self.load_default_profile(cx);
674 }
675 }
676 _ => {}
677 }
678 }
679 }
680 }
681}
682
/// Lightweight per-thread listing data (id, summary, last update), as read by
/// `ThreadsDatabase::list_threads`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    /// Stored in the database as an RFC 3339 string.
    pub updated_at: DateTime<Utc>,
}
689
/// A thread in the current persisted format (version `SerializedThread::VERSION`).
///
/// Fields marked `#[serde(default)]` may be missing from older saved threads and are
/// filled with their `Default` value when deserialized.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    /// Format version string; `SerializedThread::from_json` dispatches on it.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
}
713
/// The language model a thread was using, stored as provider and model identifiers
/// (presumably ids rather than display names — confirm against the writer in `Thread`).
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
719
720impl SerializedThread {
721 pub const VERSION: &'static str = "0.2.0";
722
723 pub fn from_json(json: &[u8]) -> Result<Self> {
724 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
725 match saved_thread_json.get("version") {
726 Some(serde_json::Value::String(version)) => match version.as_str() {
727 SerializedThreadV0_1_0::VERSION => {
728 let saved_thread =
729 serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
730 Ok(saved_thread.upgrade())
731 }
732 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
733 saved_thread_json,
734 )?),
735 _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
736 },
737 None => {
738 let saved_thread =
739 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
740 Ok(saved_thread.upgrade())
741 }
742 version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
743 }
744 }
745}
746
/// A thread saved with format version 0.1.0.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The 0.1.0 field layout matches the current `SerializedThread`, so it is reused here;
    // only the placement of tool results differs (see `upgrade`).
    // When introducing the next version, point this at a frozen `SerializedThreadV0_2_0`
    // instead of the live `SerializedThread`.
    SerializedThread,
);
753
754impl SerializedThreadV0_1_0 {
755 pub const VERSION: &'static str = "0.1.0";
756
757 pub fn upgrade(self) -> SerializedThread {
758 debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
759
760 let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
761
762 for message in self.0.messages {
763 if message.role == Role::User && !message.tool_results.is_empty() {
764 if let Some(last_message) = messages.last_mut() {
765 debug_assert!(last_message.role == Role::Assistant);
766
767 last_message.tool_results = message.tool_results;
768 continue;
769 }
770 }
771
772 messages.push(message);
773 }
774
775 SerializedThread { messages, ..self.0 }
776 }
777}
778
/// One persisted message within a [`SerializedThread`].
///
/// Every field except `id` and `role` has `#[serde(default)]`, so it may be absent in
/// older saves.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    #[serde(default)]
    pub is_hidden: bool,
}
796
/// One piece of a message's content, serialized as an internally tagged object whose
/// `type` key names the variant.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Plain text; tagged `"text"`.
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    /// Thinking content; tagged `"thinking"`. `signature` is omitted from the JSON when `None`.
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // NOTE: unlike the other variants this has no `rename`, so its tag is the variant
    // name "RedactedThinking". Renaming it now would stop previously saved threads from
    // deserializing.
    RedactedThinking {
        data: Vec<u8>,
    },
}
814
/// A persisted tool invocation: the tool's name and its JSON input.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
821
/// A persisted tool result.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    // presumably matches the `id` of the corresponding `SerializedToolUse` — confirm
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: LanguageModelToolResultContent,
    pub output: Option<serde_json::Value>,
}
829
/// Thread format used before the `version` key existed; `SerializedThread::from_json`
/// selects it when the JSON has no `version`.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
838
839impl LegacySerializedThread {
840 pub fn upgrade(self) -> SerializedThread {
841 SerializedThread {
842 version: SerializedThread::VERSION.to_string(),
843 summary: self.summary,
844 updated_at: self.updated_at,
845 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
846 initial_project_snapshot: self.initial_project_snapshot,
847 cumulative_token_usage: TokenUsage::default(),
848 request_token_usage: Vec::new(),
849 detailed_summary_state: DetailedSummaryState::default(),
850 exceeded_window_error: None,
851 model: None,
852 completion_mode: None,
853 tool_use_limit_reached: false,
854 }
855 }
856}
857
/// Message format of [`LegacySerializedThread`]: a single plain-text body instead of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
868
869impl LegacySerializedMessage {
870 fn upgrade(self) -> SerializedMessage {
871 SerializedMessage {
872 id: self.id,
873 role: self.role,
874 segments: vec![SerializedMessageSegment::Text { text: self.text }],
875 tool_uses: self.tool_uses,
876 tool_results: self.tool_results,
877 context: String::new(),
878 creases: Vec::new(),
879 is_hidden: false,
880 }
881 }
882}
883
/// A persisted crease within a message: a `start..end` range (presumably byte offsets
/// into the message text — confirm) with the icon and label used to display it.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedCrease {
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}
891
/// App-global holding the shared future that opens the threads database; installed by
/// `ThreadsDatabase::init`. Errors are wrapped in `Arc` so the shared future is cloneable.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
897
/// SQLite-backed storage for serialized threads (the `threads` table in `threads.db`).
pub(crate) struct ThreadsDatabase {
    /// Executor on which every query is spawned.
    executor: BackgroundExecutor,
    /// Single shared connection; the mutex serializes access from concurrent tasks.
    connection: Arc<Mutex<Connection>>,
}
902
903impl ThreadsDatabase {
904 fn connection(&self) -> Arc<Mutex<Connection>> {
905 self.connection.clone()
906 }
907
908 const COMPRESSION_LEVEL: i32 = 3;
909}
910
911impl Bind for ThreadId {
912 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
913 self.to_string().bind(statement, start_index)
914 }
915}
916
917impl Column for ThreadId {
918 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
919 let (id_str, next_index) = String::column(statement, start_index)?;
920 Ok((ThreadId::from(id_str.as_str()), next_index))
921 }
922}
923
924impl ThreadsDatabase {
925 fn global_future(
926 cx: &mut App,
927 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
928 GlobalThreadsDatabase::global(cx).0.clone()
929 }
930
931 fn init(cx: &mut App) {
932 let executor = cx.background_executor().clone();
933 let database_future = executor
934 .spawn({
935 let executor = executor.clone();
936 let threads_dir = paths::data_dir().join("threads");
937 async move { ThreadsDatabase::new(threads_dir, executor) }
938 })
939 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
940 .boxed()
941 .shared();
942
943 cx.set_global(GlobalThreadsDatabase(database_future));
944 }
945
946 pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
947 std::fs::create_dir_all(&threads_dir)?;
948
949 let sqlite_path = threads_dir.join("threads.db");
950 let mdb_path = threads_dir.join("threads-db.1.mdb");
951
952 let needs_migration_from_heed = mdb_path.exists();
953
954 let connection = Connection::open_file(&sqlite_path.to_string_lossy());
955
956 connection.exec(indoc! {"
957 CREATE TABLE IF NOT EXISTS threads (
958 id TEXT PRIMARY KEY,
959 summary TEXT NOT NULL,
960 updated_at TEXT NOT NULL,
961 data_type TEXT NOT NULL,
962 data BLOB NOT NULL
963 )
964 "})?()
965 .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;
966
967 let db = Self {
968 executor: executor.clone(),
969 connection: Arc::new(Mutex::new(connection)),
970 };
971
972 if needs_migration_from_heed {
973 let db_connection = db.connection();
974 let executor_clone = executor.clone();
975 executor
976 .spawn(async move {
977 log::info!("Starting threads.db migration");
978 Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
979 std::fs::remove_dir_all(mdb_path)?;
980 log::info!("threads.db migrated to sqlite");
981 Ok::<(), anyhow::Error>(())
982 })
983 .detach();
984 }
985
986 Ok(db)
987 }
988
989 // Remove this migration after 2025-09-01
990 fn migrate_from_heed(
991 mdb_path: &Path,
992 connection: Arc<Mutex<Connection>>,
993 _executor: BackgroundExecutor,
994 ) -> Result<()> {
995 use heed::types::SerdeBincode;
996 struct SerializedThreadHeed(SerializedThread);
997
998 impl heed::BytesEncode<'_> for SerializedThreadHeed {
999 type EItem = SerializedThreadHeed;
1000
1001 fn bytes_encode(
1002 item: &Self::EItem,
1003 ) -> Result<std::borrow::Cow<[u8]>, heed::BoxedError> {
1004 serde_json::to_vec(&item.0)
1005 .map(std::borrow::Cow::Owned)
1006 .map_err(Into::into)
1007 }
1008 }
1009
1010 impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
1011 type DItem = SerializedThreadHeed;
1012
1013 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
1014 SerializedThread::from_json(bytes)
1015 .map(SerializedThreadHeed)
1016 .map_err(Into::into)
1017 }
1018 }
1019
1020 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
1021
1022 let env = unsafe {
1023 heed::EnvOpenOptions::new()
1024 .map_size(ONE_GB_IN_BYTES)
1025 .max_dbs(1)
1026 .open(mdb_path)?
1027 };
1028
1029 let txn = env.write_txn()?;
1030 let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
1031 .open_database(&txn, Some("threads"))?
1032 .ok_or_else(|| anyhow!("threads database not found"))?;
1033
1034 for result in threads.iter(&txn)? {
1035 let (thread_id, thread_heed) = result?;
1036 Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
1037 }
1038
1039 Ok(())
1040 }
1041
1042 fn save_thread_sync(
1043 connection: &Arc<Mutex<Connection>>,
1044 id: ThreadId,
1045 thread: SerializedThread,
1046 ) -> Result<()> {
1047 let json_data = serde_json::to_string(&thread)?;
1048 let summary = thread.summary.to_string();
1049 let updated_at = thread.updated_at.to_rfc3339();
1050
1051 let connection = connection.lock().unwrap();
1052
1053 let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
1054 let data_type = DataType::Zstd;
1055 let data = compressed;
1056
1057 let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
1058 INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
1059 "})?;
1060
1061 insert((id, summary, updated_at, data_type, data))?;
1062
1063 Ok(())
1064 }
1065
1066 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
1067 let connection = self.connection.clone();
1068
1069 self.executor.spawn(async move {
1070 let connection = connection.lock().unwrap();
1071 let mut select =
1072 connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
1073 SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
1074 "})?;
1075
1076 let rows = select(())?;
1077 let mut threads = Vec::new();
1078
1079 for (id, summary, updated_at) in rows {
1080 threads.push(SerializedThreadMetadata {
1081 id,
1082 summary: summary.into(),
1083 updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
1084 });
1085 }
1086
1087 Ok(threads)
1088 })
1089 }
1090
1091 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
1092 let connection = self.connection.clone();
1093
1094 self.executor.spawn(async move {
1095 let connection = connection.lock().unwrap();
1096 let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
1097 SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
1098 "})?;
1099
1100 let rows = select(id)?;
1101 if let Some((data_type, data)) = rows.into_iter().next() {
1102 let json_data = match data_type {
1103 DataType::Zstd => {
1104 let decompressed = zstd::decode_all(&data[..])?;
1105 String::from_utf8(decompressed)?
1106 }
1107 DataType::Json => String::from_utf8(data)?,
1108 };
1109
1110 let thread = SerializedThread::from_json(json_data.as_bytes())?;
1111 Ok(Some(thread))
1112 } else {
1113 Ok(None)
1114 }
1115 })
1116 }
1117
1118 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
1119 let connection = self.connection.clone();
1120
1121 self.executor
1122 .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
1123 }
1124
1125 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
1126 let connection = self.connection.clone();
1127
1128 self.executor.spawn(async move {
1129 let connection = connection.lock().unwrap();
1130
1131 let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
1132 DELETE FROM threads WHERE id = ?
1133 "})?;
1134
1135 delete(id)?;
1136
1137 Ok(())
1138 })
1139 }
1140}