1use std::cell::{Ref, RefCell};
2use std::path::{Path, PathBuf};
3use std::rc::Rc;
4use std::sync::{Arc, Mutex};
5
6use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, CompletionMode};
7use anyhow::{Context as _, Result, anyhow};
8use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
9use chrono::{DateTime, Utc};
10use collections::HashMap;
11use context_server::ContextServerId;
12use futures::channel::{mpsc, oneshot};
13use futures::future::{self, BoxFuture, Shared};
14use futures::{FutureExt as _, StreamExt as _};
15use gpui::{
16 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
17 Subscription, Task, prelude::*,
18};
19
20use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
21use project::context_server_store::{ContextServerStatus, ContextServerStore};
22use project::{Project, ProjectItem, ProjectPath, Worktree};
23use prompt_store::{
24 ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
25 UserRulesContext, WorktreeContext,
26};
27use serde::{Deserialize, Serialize};
28use settings::{Settings as _, SettingsStore};
29use ui::Window;
30use util::ResultExt as _;
31
32use crate::context_server_tool::ContextServerTool;
33use crate::thread::{
34 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
35};
36use indoc::indoc;
37use sqlez::{
38 bindable::{Bind, Column},
39 connection::Connection,
40 statement::Statement,
41};
42
/// Storage format of the `data` column in the `threads` SQLite table:
/// plain JSON or zstd-compressed JSON.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
    #[serde(rename = "json")]
    Json,
    #[serde(rename = "zstd")]
    Zstd,
}
50
51impl Bind for DataType {
52 fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
53 let value = match self {
54 DataType::Json => "json",
55 DataType::Zstd => "zstd",
56 };
57 value.bind(statement, start_index)
58 }
59}
60
61impl Column for DataType {
62 fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
63 let (value, next_index) = String::column(statement, start_index)?;
64 let data_type = match value.as_str() {
65 "json" => DataType::Json,
66 "zstd" => DataType::Zstd,
67 _ => anyhow::bail!("Unknown data type: {}", value),
68 };
69 Ok((data_type, next_index))
70 }
71}
72
/// File names (relative to a worktree root) that are treated as agent
/// rules files and folded into the system prompt when present. Order is
/// priority order: the first match in a worktree wins.
// `'static` is implied for references in a const item, so the explicit
// lifetime was redundant (clippy::redundant_static_lifetimes).
const RULES_FILE_NAMES: [&str; 8] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
];
83
/// Initializes module-level state; currently just sets up the global
/// `ThreadsDatabase` future.
pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}

/// A system prompt shared by all threads created by this ThreadStore
#[derive(Clone, Default)]
pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);

impl SharedProjectContext {
    /// Immutably borrows the current project context; `None` until the
    /// first system-prompt reload completes.
    pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
        self.0.borrow()
    }
}

/// Alias: text threads are stored by the context editor's `ContextStore`.
pub type TextThreadStore = assistant_context_editor::ContextStore;
99
/// Owns persisted agent threads, the active tool working set, and the
/// shared project context used to build system prompts.
pub struct ThreadStore {
    project: Entity<Project>,
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    prompt_store: Option<Entity<PromptStore>>,
    // Tool ids registered per context server, so they can be removed
    // when that server stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    // Metadata of persisted threads, refreshed from the database by `reload`.
    threads: Vec<SerializedThreadMetadata>,
    // Shared with every `Thread` created by this store; written by the
    // system-prompt reload task.
    project_context: SharedProjectContext,
    // Capacity-1 channel used to coalesce system-prompt reload requests.
    reload_system_prompt_tx: mpsc::Sender<()>,
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}

/// Event emitted when loading a rules file or default user rules fails.
pub struct RulesLoadingError {
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}
118
119impl ThreadStore {
    /// Creates a `ThreadStore` entity, resolving only after the initial
    /// system prompt (project context) has been loaded.
    pub fn load(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_store: Option<Entity<PromptStore>>,
        prompt_builder: Arc<PromptBuilder>,
        cx: &mut App,
    ) -> Task<Result<Entity<Self>>> {
        cx.spawn(async move |cx| {
            let (thread_store, ready_rx) = cx.update(|cx| {
                // `cx.new` only returns the entity, so smuggle the
                // readiness receiver out of the closure via this Option.
                let mut option_ready_rx = None;
                let thread_store = cx.new(|cx| {
                    let (thread_store, ready_rx) =
                        Self::new(project, tools, prompt_builder, prompt_store, cx);
                    option_ready_rx = Some(ready_rx);
                    thread_store
                });
                (thread_store, option_ready_rx.take().unwrap())
            })?;
            // Block until the first system-prompt load completes.
            ready_rx.await?;
            Ok(thread_store)
        })
    }
142
    /// Builds the store, wires up settings/project/prompt-store
    /// subscriptions, and starts the background task that keeps the
    /// system prompt current. The returned receiver fires once after the
    /// first system-prompt load.
    fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> (Self, oneshot::Receiver<()>) {
        let mut subscriptions = vec![
            // Re-apply the default profile whenever settings change.
            cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
                this.load_default_profile(cx);
            }),
            cx.subscribe(&project, Self::handle_project_event),
        ];

        if let Some(prompt_store) = prompt_store.as_ref() {
            // Default user rules live in the prompt store; reload the
            // system prompt when they change.
            subscriptions.push(cx.subscribe(
                prompt_store,
                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
                    this.enqueue_system_prompt_reload();
                },
            ))
        }

        // This channel and task prevent concurrent and redundant loading of the system prompt.
        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
        let (ready_tx, ready_rx) = oneshot::channel();
        let mut ready_tx = Some(ready_tx);
        let reload_system_prompt_task = cx.spawn({
            let prompt_store = prompt_store.clone();
            async move |thread_store, cx| {
                loop {
                    // Exit the loop once the store entity is dropped.
                    let Some(reload_task) = thread_store
                        .update(cx, |thread_store, cx| {
                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
                        })
                        .ok()
                    else {
                        return;
                    };
                    reload_task.await;
                    // Signal readiness after the first load only.
                    if let Some(ready_tx) = ready_tx.take() {
                        ready_tx.send(()).ok();
                    }
                    // Park until the next reload request arrives.
                    reload_system_prompt_rx.next().await;
                }
            }
        });

        let this = Self {
            project,
            tools,
            prompt_builder,
            prompt_store,
            context_server_tool_ids: HashMap::default(),
            threads: Vec::new(),
            project_context: SharedProjectContext::default(),
            reload_system_prompt_tx,
            _reload_system_prompt_task: reload_system_prompt_task,
            _subscriptions: subscriptions,
        };
        this.load_default_profile(cx);
        this.register_context_server_handlers(cx);
        this.reload(cx).detach_and_log_err(cx);
        (this, ready_rx)
    }
208
209 fn handle_project_event(
210 &mut self,
211 _project: Entity<Project>,
212 event: &project::Event,
213 _cx: &mut Context<Self>,
214 ) {
215 match event {
216 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
217 self.enqueue_system_prompt_reload();
218 }
219 project::Event::WorktreeUpdatedEntries(_, items) => {
220 if items.iter().any(|(path, _, _)| {
221 RULES_FILE_NAMES
222 .iter()
223 .any(|name| path.as_ref() == Path::new(name))
224 }) {
225 self.enqueue_system_prompt_reload();
226 }
227 }
228 _ => {}
229 }
230 }
231
    /// Requests a system-prompt reload. The channel has capacity 1, so a
    /// failed send just means a reload is already pending — which is why
    /// the error is deliberately ignored.
    fn enqueue_system_prompt_reload(&mut self) {
        self.reload_system_prompt_tx.try_send(()).ok();
    }
235
236 // Note that this should only be called from `reload_system_prompt_task`.
    // Note that this should only be called from `reload_system_prompt_task`.
    /// Rebuilds the shared `ProjectContext` from every visible worktree's
    /// rules file plus the prompt store's default user rules. Load errors
    /// are surfaced as `RulesLoadingError` events rather than aborting.
    fn reload_system_prompt(
        &self,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> Task<()> {
        let worktrees = self
            .project
            .read(cx)
            .visible_worktrees(cx)
            .collect::<Vec<_>>();
        let worktree_tasks = worktrees
            .into_iter()
            .map(|worktree| {
                Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
            })
            .collect::<Vec<_>>();
        let default_user_rules_task = match prompt_store {
            None => Task::ready(vec![]),
            Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
                let prompts = prompt_store.default_prompt_metadata();
                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
                    let contents = prompt_store.load(prompt_metadata.id, cx);
                    async move { (contents.await, prompt_metadata) }
                });
                cx.background_spawn(future::join_all(load_tasks))
            }),
        };

        cx.spawn(async move |this, cx| {
            let (worktrees, default_user_rules) =
                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;

            let worktrees = worktrees
                .into_iter()
                .map(|(worktree, rules_error)| {
                    // Report rules-file errors but keep the worktree context.
                    if let Some(rules_error) = rules_error {
                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
                    }
                    worktree
                })
                .collect::<Vec<_>>();

            let default_user_rules = default_user_rules
                .into_iter()
                .flat_map(|(contents, prompt_metadata)| match contents {
                    Ok(contents) => Some(UserRulesContext {
                        uuid: match prompt_metadata.id {
                            PromptId::User { uuid } => uuid,
                            // Workflow prompts are not user rules; skip them.
                            PromptId::EditWorkflow => return None,
                        },
                        title: prompt_metadata.title.map(|title| title.to_string()),
                        contents,
                    }),
                    Err(err) => {
                        this.update(cx, |_, cx| {
                            cx.emit(RulesLoadingError {
                                message: format!("{err:?}").into(),
                            });
                        })
                        .ok();
                        None
                    }
                })
                .collect::<Vec<_>>();

            // Publish the rebuilt context to all threads sharing it.
            this.update(cx, |this, _cx| {
                *this.project_context.0.borrow_mut() =
                    Some(ProjectContext::new(worktrees, default_user_rules));
            })
            .ok();
        })
    }
309
    /// Collects a worktree's root name and (if present) its rules file,
    /// returning any rules-load failure alongside the context instead of
    /// failing the whole worktree.
    fn load_worktree_info_for_system_prompt(
        worktree: Entity<Worktree>,
        project: Entity<Project>,
        cx: &mut App,
    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
        let root_name = worktree.read(cx).root_name().into();

        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
        let Some(rules_task) = rules_task else {
            // No rules file in this worktree: return context without rules.
            return Task::ready((
                WorktreeContext {
                    root_name,
                    rules_file: None,
                },
                None,
            ));
        };

        cx.spawn(async move |_| {
            let (rules_file, rules_file_error) = match rules_task.await {
                Ok(rules_file) => (Some(rules_file), None),
                Err(err) => (
                    None,
                    Some(RulesLoadingError {
                        message: format!("{err}").into(),
                    }),
                ),
            };
            let worktree_info = WorktreeContext {
                root_name,
                rules_file,
            };
            (worktree_info, rules_file_error)
        })
    }
345
    /// Finds the first rules file in `worktree` (in `RULES_FILE_NAMES`
    /// priority order) and returns a task loading its trimmed contents.
    /// Returns `None` when the worktree has no rules file.
    fn load_worktree_rules_file(
        worktree: Entity<Worktree>,
        project: Entity<Project>,
        cx: &mut App,
    ) -> Option<Task<Result<RulesFileContext>>> {
        let worktree_ref = worktree.read(cx);
        let worktree_id = worktree_ref.id();
        let selected_rules_file = RULES_FILE_NAMES
            .into_iter()
            .filter_map(|name| {
                worktree_ref
                    .entry_for_path(name)
                    // Only regular files count as rules files.
                    .filter(|entry| entry.is_file())
                    .map(|entry| entry.path.clone())
            })
            .next();

        // Note that Cline supports `.clinerules` being a directory, but that is not currently
        // supported. This doesn't seem to occur often in GitHub repositories.
        selected_rules_file.map(|path_in_worktree| {
            let project_path = ProjectPath {
                worktree_id,
                path: path_in_worktree.clone(),
            };
            let buffer_task =
                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
            let rope_task = cx.spawn(async move |cx| {
                buffer_task.await?.read_with(cx, |buffer, cx| {
                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
                })?
            });
            // Build a string from the rope on a background thread.
            cx.background_spawn(async move {
                let (project_entry_id, rope) = rope_task.await?;
                anyhow::Ok(RulesFileContext {
                    path_in_worktree,
                    text: rope.to_string().trim().to_string(),
                    project_entry_id: project_entry_id.to_usize(),
                })
            })
        })
    }
389
    /// The prompt store backing default user rules, if one was provided.
    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
        &self.prompt_store
    }

    /// The tool working set shared by threads created by this store.
    pub fn tools(&self) -> Entity<ToolWorkingSet> {
        self.tools.clone()
    }

    /// Returns the number of threads.
    pub fn thread_count(&self) -> usize {
        self.threads.len()
    }

    /// Iterates thread metadata in unspecified order.
    pub fn unordered_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
        self.threads.iter()
    }
406
407 pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
408 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
409 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
410 threads
411 }
412
    /// Creates a fresh, empty thread backed by this store's project,
    /// tools, and shared project context.
    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
        cx.new(|cx| {
            Thread::new(
                self.project.clone(),
                self.tools.clone(),
                self.prompt_builder.clone(),
                self.project_context.clone(),
                cx,
            )
        })
    }
424
    /// Instantiates a thread from serialized data under a brand-new
    /// `ThreadId` — a copy, not a re-open of the stored thread.
    pub fn create_thread_from_serialized(
        &mut self,
        serialized: SerializedThread,
        cx: &mut Context<Self>,
    ) -> Entity<Thread> {
        cx.new(|cx| {
            Thread::deserialize(
                ThreadId::new(),
                serialized,
                self.project.clone(),
                self.tools.clone(),
                self.prompt_builder.clone(),
                self.project_context.clone(),
                None,
                cx,
            )
        })
    }
443
    /// Loads the thread with `id` from the database and deserializes it
    /// into a live `Thread` entity. Fails when no such thread exists.
    pub fn open_thread(
        &self,
        id: &ThreadId,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Thread>>> {
        let id = id.clone();
        let database_future = ThreadsDatabase::global_future(cx);
        let this = cx.weak_entity();
        window.spawn(cx, async move |cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            let thread = database
                .try_find_thread(id.clone())
                .await?
                .with_context(|| format!("no thread found with ID: {id:?}"))?;

            let thread = this.update_in(cx, |this, window, cx| {
                cx.new(|cx| {
                    Thread::deserialize(
                        id.clone(),
                        thread,
                        this.project.clone(),
                        this.tools.clone(),
                        this.prompt_builder.clone(),
                        this.project_context.clone(),
                        Some(window),
                        cx,
                    )
                })
            })?;

            Ok(thread)
        })
    }
478
479 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
480 let (metadata, serialized_thread) =
481 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
482
483 let database_future = ThreadsDatabase::global_future(cx);
484 cx.spawn(async move |this, cx| {
485 let serialized_thread = serialized_thread.await?;
486 let database = database_future.await.map_err(|err| anyhow!(err))?;
487 database.save_thread(metadata, serialized_thread).await?;
488
489 this.update(cx, |this, cx| this.reload(cx))?.await
490 })
491 }
492
    /// Removes the thread with `id` from the database, then drops its
    /// metadata from the in-memory list and notifies observers.
    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let id = id.clone();
        let database_future = ThreadsDatabase::global_future(cx);
        cx.spawn(async move |this, cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            database.delete_thread(id.clone()).await?;

            this.update(cx, |this, cx| {
                this.threads.retain(|thread| thread.id != id);
                cx.notify();
            })
        })
    }
506
    /// Re-reads the full thread metadata list from the database and
    /// notifies observers.
    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let database_future = ThreadsDatabase::global_future(cx);
        cx.spawn(async move |this, cx| {
            let threads = database_future
                .await
                .map_err(|err| anyhow!(err))?
                .list_threads()
                .await?;

            this.update(cx, |this, cx| {
                this.threads = threads;
                cx.notify();
            })
        })
    }
522
    /// Applies the default profile from the current agent settings.
    fn load_default_profile(&self, cx: &mut Context<Self>) {
        let assistant_settings = AgentSettings::get_global(cx);

        self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
    }
528
529 pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
530 let assistant_settings = AgentSettings::get_global(cx);
531
532 if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
533 self.load_profile(profile.clone(), cx);
534 }
535 }
536
537 pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
538 self.tools.update(cx, |tools, cx| {
539 tools.disable_all_tools(cx);
540 tools.enable(
541 ToolSource::Native,
542 &profile
543 .tools
544 .into_iter()
545 .filter_map(|(tool, enabled)| enabled.then(|| tool))
546 .collect::<Vec<_>>(),
547 cx,
548 );
549 });
550
551 if profile.enable_all_context_servers {
552 for context_server_id in self
553 .project
554 .read(cx)
555 .context_server_store()
556 .read(cx)
557 .all_server_ids()
558 {
559 self.tools.update(cx, |tools, cx| {
560 tools.enable_source(
561 ToolSource::ContextServer {
562 id: context_server_id.0.into(),
563 },
564 cx,
565 );
566 });
567 }
568 // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
569 for (context_server_id, preset) in profile.context_servers {
570 self.tools.update(cx, |tools, cx| {
571 tools.disable(
572 ToolSource::ContextServer {
573 id: context_server_id.into(),
574 },
575 &preset
576 .tools
577 .into_iter()
578 .filter_map(|(tool, enabled)| (!enabled).then(|| tool))
579 .collect::<Vec<_>>(),
580 cx,
581 )
582 })
583 }
584 } else {
585 for (context_server_id, preset) in profile.context_servers {
586 self.tools.update(cx, |tools, cx| {
587 tools.enable(
588 ToolSource::ContextServer {
589 id: context_server_id.into(),
590 },
591 &preset
592 .tools
593 .into_iter()
594 .filter_map(|(tool, enabled)| enabled.then(|| tool))
595 .collect::<Vec<_>>(),
596 cx,
597 )
598 })
599 }
600 }
601 }
602
    /// Subscribes to context-server lifecycle events so server tools can
    /// be registered and unregistered as servers start and stop.
    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
        cx.subscribe(
            &self.project.read(cx).context_server_store(),
            Self::handle_context_server_event,
        )
        .detach();
    }
610
    /// Reacts to context-server status changes: registers a server's
    /// tools when it becomes `Running` and removes them when it stops or
    /// errors, re-applying the default profile in both cases.
    fn handle_context_server_event(
        &mut self,
        context_server_store: Entity<ContextServerStore>,
        event: &project::context_server_store::Event,
        cx: &mut Context<Self>,
    ) {
        let tool_working_set = self.tools.clone();
        match event {
            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
                match status {
                    ContextServerStatus::Running => {
                        if let Some(server) =
                            context_server_store.read(cx).get_running_server(server_id)
                        {
                            let context_server_manager = context_server_store.clone();
                            cx.spawn({
                                let server = server.clone();
                                let server_id = server_id.clone();
                                async move |this, cx| {
                                    let Some(protocol) = server.client() else {
                                        return;
                                    };

                                    // Only query servers that advertise tool support.
                                    if protocol.capable(context_server::protocol::ServerCapability::Tools) {
                                        if let Some(tools) = protocol.list_tools().await.log_err() {
                                            let tool_ids = tool_working_set
                                                .update(cx, |tool_working_set, _| {
                                                    tools
                                                        .tools
                                                        .into_iter()
                                                        .map(|tool| {
                                                            log::info!(
                                                                "registering context server tool: {:?}",
                                                                tool.name
                                                            );
                                                            tool_working_set.insert(Arc::new(
                                                                ContextServerTool::new(
                                                                    context_server_manager.clone(),
                                                                    server.id(),
                                                                    tool,
                                                                ),
                                                            ))
                                                        })
                                                        .collect::<Vec<_>>()
                                                })
                                                .log_err();

                                            // Remember the ids so they can be
                                            // removed when the server stops.
                                            if let Some(tool_ids) = tool_ids {
                                                this.update(cx, |this, cx| {
                                                    this.context_server_tool_ids
                                                        .insert(server_id, tool_ids);
                                                    this.load_default_profile(cx);
                                                })
                                                .log_err();
                                            }
                                        }
                                    }
                                }
                            })
                            .detach();
                        }
                    }
                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
                        if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
                            tool_working_set.update(cx, |tool_working_set, _| {
                                tool_working_set.remove(&tool_ids);
                            });
                            self.load_default_profile(cx);
                        }
                    }
                    _ => {}
                }
            }
        }
    }
686}
687
/// Lightweight per-thread row used for listing threads without
/// deserializing their full message history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
}
694
/// On-disk representation of a thread (current version; see `VERSION`).
/// Fields added after the initial release are `#[serde(default)]` so
/// older rows still deserialize.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
}

/// Provider/model pair recorded so a reopened thread can restore the
/// model it was using.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
724
impl SerializedThread {
    /// Current serialization format version.
    pub const VERSION: &'static str = "0.2.0";

    /// Deserializes a thread from JSON, upgrading from any older format
    /// version (including the unversioned legacy format).
    pub fn from_json(json: &[u8]) -> Result<Self> {
        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
        match saved_thread_json.get("version") {
            Some(serde_json::Value::String(version)) => match version.as_str() {
                SerializedThreadV0_1_0::VERSION => {
                    let saved_thread =
                        serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
                    Ok(saved_thread.upgrade())
                }
                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
                    saved_thread_json,
                )?),
                _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
            },
            None => {
                // No `version` field at all: pre-versioning legacy format.
                let saved_thread =
                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
                Ok(saved_thread.upgrade())
            }
            // A non-string `version` value is also unrecognized.
            version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
        }
    }
}
751
/// Serialized thread format v0.1.0.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);
758
impl SerializedThreadV0_1_0 {
    pub const VERSION: &'static str = "0.1.0";

    /// Upgrades to v0.2.0: v0.1.0 stored tool results on a trailing user
    /// message, whereas v0.2.0 stores them on the assistant message that
    /// issued the tool uses — fold them back onto that message.
    pub fn upgrade(self) -> SerializedThread {
        // Guard against this upgrade silently becoming stale when the
        // current version changes.
        debug_assert_eq!(SerializedThread::VERSION, "0.2.0");

        let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());

        for message in self.0.messages {
            if message.role == Role::User && !message.tool_results.is_empty() {
                if let Some(last_message) = messages.last_mut() {
                    debug_assert!(last_message.role == Role::Assistant);

                    // Move the tool results onto the preceding assistant
                    // message and drop the carrier user message.
                    last_message.tool_results = message.tool_results;
                    continue;
                }
            }

            messages.push(message);
        }

        SerializedThread { messages, ..self.0 }
    }
}
783
/// One message in a serialized thread, including its content segments,
/// tool activity, attached context, and editor creases.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    #[serde(default)]
    pub is_hidden: bool,
}
801
/// A chunk of message content, tagged by `type` in the serialized JSON.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // NOTE(review): unlike the other variants this one has no
    // `#[serde(rename)]`, so its tag serializes as "RedactedThinking".
    // Renaming it now would break already-persisted data — confirm
    // before changing.
    RedactedThinking {
        data: Vec<u8>,
    },
}
819
/// A tool invocation recorded on an assistant message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}

/// The outcome of a tool invocation, keyed back to its tool use.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: LanguageModelToolResultContent,
    pub output: Option<serde_json::Value>,
}

/// Pre-versioning thread format (no `version` field); upgraded on read.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
843
impl LegacySerializedThread {
    /// Upgrades the unversioned legacy format to the current
    /// `SerializedThread`, defaulting every field the legacy format did
    /// not record.
    pub fn upgrade(self) -> SerializedThread {
        SerializedThread {
            version: SerializedThread::VERSION.to_string(),
            summary: self.summary,
            updated_at: self.updated_at,
            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
            initial_project_snapshot: self.initial_project_snapshot,
            cumulative_token_usage: TokenUsage::default(),
            request_token_usage: Vec::new(),
            detailed_summary_state: DetailedSummaryState::default(),
            exceeded_window_error: None,
            model: None,
            completion_mode: None,
            tool_use_limit_reached: false,
        }
    }
}
862
/// A message in the pre-versioning format: flat `text` instead of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
873
impl LegacySerializedMessage {
    /// Converts the legacy flat `text` field into a single text segment;
    /// fields the legacy format lacked are defaulted.
    fn upgrade(self) -> SerializedMessage {
        SerializedMessage {
            id: self.id,
            role: self.role,
            segments: vec![SerializedMessageSegment::Text { text: self.text }],
            tool_uses: self.tool_uses,
            tool_results: self.tool_results,
            context: String::new(),
            creases: Vec::new(),
            is_hidden: false,
        }
    }
}
888
/// A folded (creased) region of a message in the editor, persisted so it
/// can be restored on reopen.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedCrease {
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}

/// Global holding the shared future that resolves to the opened (and, if
/// needed, migrated) threads database.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
902
/// SQLite-backed storage for serialized threads. Queries run on the
/// background executor; the connection is shared behind a mutex.
pub(crate) struct ThreadsDatabase {
    executor: BackgroundExecutor,
    connection: Arc<Mutex<Connection>>,
}

impl ThreadsDatabase {
    fn connection(&self) -> Arc<Mutex<Connection>> {
        self.connection.clone()
    }

    // zstd compression level used when writing thread data.
    const COMPRESSION_LEVEL: i32 = 3;
}
915
// Thread ids are stored in SQLite as their string form.
impl Bind for ThreadId {
    fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
        self.to_string().bind(statement, start_index)
    }
}

impl Column for ThreadId {
    fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
        let (id_str, next_index) = String::column(statement, start_index)?;
        Ok((ThreadId::from(id_str.as_str()), next_index))
    }
}
928
929impl ThreadsDatabase {
    /// Returns the shared future that resolves to the global database
    /// handle (set up by `init`).
    fn global_future(
        cx: &mut App,
    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
        GlobalThreadsDatabase::global(cx).0.clone()
    }
935
    /// Opens the database on the background executor and stores the
    /// shareable future as a global so any caller can await the handle.
    fn init(cx: &mut App) {
        let executor = cx.background_executor().clone();
        let database_future = executor
            .spawn({
                let executor = executor.clone();
                let threads_dir = paths::data_dir().join("threads");
                async move { ThreadsDatabase::new(threads_dir, executor) }
            })
            // Wrap both sides in `Arc` so the result can be cloned by
            // every awaiter of the shared future.
            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
            .boxed()
            .shared();

        cx.set_global(GlobalThreadsDatabase(database_future));
    }
950
    /// Opens (creating if needed) the SQLite database at
    /// `threads_dir/threads.db` and, if an old heed (LMDB) database is
    /// present, kicks off a one-shot background migration from it.
    pub fn new(threads_dir: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
        std::fs::create_dir_all(&threads_dir)?;

        let sqlite_path = threads_dir.join("threads.db");
        let mdb_path = threads_dir.join("threads-db.1.mdb");

        let needs_migration_from_heed = mdb_path.exists();

        let connection = Connection::open_file(&sqlite_path.to_string_lossy());

        connection.exec(indoc! {"
            CREATE TABLE IF NOT EXISTS threads (
                id TEXT PRIMARY KEY,
                summary TEXT NOT NULL,
                updated_at TEXT NOT NULL,
                data_type TEXT NOT NULL,
                data BLOB NOT NULL
            )
        "})?()
        .map_err(|e| anyhow!("Failed to create threads table: {}", e))?;

        let db = Self {
            executor: executor.clone(),
            connection: Arc::new(Mutex::new(connection)),
        };

        if needs_migration_from_heed {
            let db_connection = db.connection();
            let executor_clone = executor.clone();
            executor
                .spawn(async move {
                    log::info!("Starting threads.db migration");
                    Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?;
                    // The old database is removed only after a successful
                    // migration; on error the `?` above leaves it intact.
                    std::fs::remove_dir_all(mdb_path)?;
                    log::info!("threads.db migrated to sqlite");
                    Ok::<(), anyhow::Error>(())
                })
                .detach();
        }

        Ok(db)
    }
993
994 // Remove this migration after 2025-09-01
    // Remove this migration after 2025-09-01
    /// Copies every thread from the old heed (LMDB) database into SQLite.
    /// Keys are bincode-encoded `ThreadId`s; values are JSON thread
    /// payloads decoded through `SerializedThread::from_json`.
    fn migrate_from_heed(
        mdb_path: &Path,
        connection: Arc<Mutex<Connection>>,
        _executor: BackgroundExecutor,
    ) -> Result<()> {
        use heed::types::SerdeBincode;
        // Local wrapper giving heed JSON-based encode/decode for threads.
        struct SerializedThreadHeed(SerializedThread);

        impl heed::BytesEncode<'_> for SerializedThreadHeed {
            type EItem = SerializedThreadHeed;

            fn bytes_encode(
                item: &Self::EItem,
            ) -> Result<std::borrow::Cow<[u8]>, heed::BoxedError> {
                serde_json::to_vec(&item.0)
                    .map(std::borrow::Cow::Owned)
                    .map_err(Into::into)
            }
        }

        impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed {
            type DItem = SerializedThreadHeed;

            fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
                // Goes through `from_json` so version upgrades apply here too.
                SerializedThread::from_json(bytes)
                    .map(SerializedThreadHeed)
                    .map_err(Into::into)
            }
        }

        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;

        // heed requires `unsafe` to open an environment; see heed's docs
        // for the invariants the caller must uphold.
        let env = unsafe {
            heed::EnvOpenOptions::new()
                .map_size(ONE_GB_IN_BYTES)
                .max_dbs(1)
                .open(mdb_path)?
        };

        let txn = env.write_txn()?;
        let threads: heed::Database<SerdeBincode<ThreadId>, SerializedThreadHeed> = env
            .open_database(&txn, Some("threads"))?
            .ok_or_else(|| anyhow!("threads database not found"))?;

        for result in threads.iter(&txn)? {
            let (thread_id, thread_heed) = result?;
            Self::save_thread_sync(&connection, thread_id, thread_heed.0)?;
        }

        Ok(())
    }
1046
1047 fn save_thread_sync(
1048 connection: &Arc<Mutex<Connection>>,
1049 id: ThreadId,
1050 thread: SerializedThread,
1051 ) -> Result<()> {
1052 let json_data = serde_json::to_string(&thread)?;
1053 let summary = thread.summary.to_string();
1054 let updated_at = thread.updated_at.to_rfc3339();
1055
1056 let connection = connection.lock().unwrap();
1057
1058 let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?;
1059 let data_type = DataType::Zstd;
1060 let data = compressed;
1061
1062 let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec<u8>)>(indoc! {"
1063 INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?)
1064 "})?;
1065
1066 insert((id, summary, updated_at, data_type, data))?;
1067
1068 Ok(())
1069 }
1070
    /// Lists metadata for all stored threads, newest first (by the
    /// RFC 3339 `updated_at` column).
    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
        let connection = self.connection.clone();

        self.executor.spawn(async move {
            let connection = connection.lock().unwrap();
            let mut select =
                connection.select_bound::<(), (ThreadId, String, String)>(indoc! {"
                    SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC
                "})?;

            let rows = select(())?;
            let mut threads = Vec::new();

            for (id, summary, updated_at) in rows {
                threads.push(SerializedThreadMetadata {
                    id,
                    summary: summary.into(),
                    updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc),
                });
            }

            Ok(threads)
        })
    }
1095
    /// Fetches and deserializes the thread with `id`, decompressing
    /// zstd-stored rows; returns `Ok(None)` when no row matches.
    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
        let connection = self.connection.clone();

        self.executor.spawn(async move {
            let connection = connection.lock().unwrap();
            let mut select = connection.select_bound::<ThreadId, (DataType, Vec<u8>)>(indoc! {"
                SELECT data_type, data FROM threads WHERE id = ? LIMIT 1
            "})?;

            let rows = select(id)?;
            if let Some((data_type, data)) = rows.into_iter().next() {
                let json_data = match data_type {
                    DataType::Zstd => {
                        let decompressed = zstd::decode_all(&data[..])?;
                        String::from_utf8(decompressed)?
                    }
                    // Uncompressed rows are read as-is.
                    DataType::Json => String::from_utf8(data)?,
                };

                // `from_json` handles version upgrades transparently.
                let thread = SerializedThread::from_json(json_data.as_bytes())?;
                Ok(Some(thread))
            } else {
                Ok(None)
            }
        })
    }
1122
    /// Persists `thread` under `id` on the background executor.
    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
        let connection = self.connection.clone();

        self.executor
            .spawn(async move { Self::save_thread_sync(&connection, id, thread) })
    }
1129
    /// Deletes the row for `id`, if any, on the background executor.
    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
        let connection = self.connection.clone();

        self.executor.spawn(async move {
            let connection = connection.lock().unwrap();

            let mut delete = connection.exec_bound::<ThreadId>(indoc! {"
                DELETE FROM threads WHERE id = ?
            "})?;

            delete(id)?;

            Ok(())
        })
    }
1145}