1use std::borrow::Cow;
2use std::cell::{Ref, RefCell};
3use std::path::{Path, PathBuf};
4use std::rc::Rc;
5use std::sync::Arc;
6
7use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, CompletionMode};
8use anyhow::{Context as _, Result, anyhow};
9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use context_server::ContextServerId;
13use futures::channel::{mpsc, oneshot};
14use futures::future::{self, BoxFuture, Shared};
15use futures::{FutureExt as _, StreamExt as _};
16use gpui::{
17 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
18 Subscription, Task, prelude::*,
19};
20use heed::Database;
21use heed::types::SerdeBincode;
22use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
23use project::context_server_store::{ContextServerStatus, ContextServerStore};
24use project::{Project, ProjectItem, ProjectPath, Worktree};
25use prompt_store::{
26 ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
27 UserRulesContext, WorktreeContext,
28};
29use serde::{Deserialize, Serialize};
30use settings::{Settings as _, SettingsStore};
31use ui::Window;
32use util::ResultExt as _;
33
34use crate::context_server_tool::ContextServerTool;
35use crate::thread::{
36 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
37};
38
/// File names (relative to a worktree root) that are recognized as rules files.
///
/// Order matters: when a worktree contains more than one of these, the first entry in this
/// list that exists as a file is the one that gets loaded.
const RULES_FILE_NAMES: [&str; 6] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
];
47
/// Module entry point: installs the global threads database by delegating to
/// `ThreadsDatabase::init`.
pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}
51
52/// A system prompt shared by all threads created by this ThreadStore
53#[derive(Clone, Default)]
54pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
55
56impl SharedProjectContext {
57 pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
58 self.0.borrow()
59 }
60}
61
/// Alias for the context store provided by `assistant_context_editor`.
pub type TextThreadStore = assistant_context_editor::ContextStore;
63
/// Creates, loads, saves and deletes agent threads for a single project, and keeps the
/// project context (system prompt inputs) shared by those threads up to date.
pub struct ThreadStore {
    /// Project whose worktrees and context servers feed the threads.
    project: Entity<Project>,
    /// Tool working set handed to every thread and reconfigured when a profile is loaded.
    tools: Entity<ToolWorkingSet>,
    /// Prompt builder handed to every thread.
    prompt_builder: Arc<PromptBuilder>,
    /// Optional prompt store; its default prompts are loaded as user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Tool ids registered for each context server, so they can be removed when the
    /// server stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Metadata of the saved threads, replaced wholesale by `reload`.
    threads: Vec<SerializedThreadMetadata>,
    /// Project context shared with every thread created by this store.
    project_context: SharedProjectContext,
    /// Sender used to request a system prompt reload from the background task.
    reload_system_prompt_tx: mpsc::Sender<()>,
    /// Task that performs the system prompt reloads one at a time.
    /// NOTE(review): presumably stored so the task lives as long as the store — confirm.
    _reload_system_prompt_task: Task<()>,
    /// Subscriptions created in `new`.
    /// NOTE(review): presumably stored so they stay active for the store's lifetime — confirm.
    _subscriptions: Vec<Subscription>,
}
76
/// Event emitted by [`ThreadStore`] when a rules file or a default user rule fails to load.
pub struct RulesLoadingError {
    /// Description of the failure, produced by formatting the underlying error.
    pub message: SharedString,
}

// Lets `ThreadStore` emit `RulesLoadingError` events to its subscribers.
impl EventEmitter<RulesLoadingError> for ThreadStore {}
82
83impl ThreadStore {
84 pub fn load(
85 project: Entity<Project>,
86 tools: Entity<ToolWorkingSet>,
87 prompt_store: Option<Entity<PromptStore>>,
88 prompt_builder: Arc<PromptBuilder>,
89 cx: &mut App,
90 ) -> Task<Result<Entity<Self>>> {
91 cx.spawn(async move |cx| {
92 let (thread_store, ready_rx) = cx.update(|cx| {
93 let mut option_ready_rx = None;
94 let thread_store = cx.new(|cx| {
95 let (thread_store, ready_rx) =
96 Self::new(project, tools, prompt_builder, prompt_store, cx);
97 option_ready_rx = Some(ready_rx);
98 thread_store
99 });
100 (thread_store, option_ready_rx.take().unwrap())
101 })?;
102 ready_rx.await?;
103 Ok(thread_store)
104 })
105 }
106
107 fn new(
108 project: Entity<Project>,
109 tools: Entity<ToolWorkingSet>,
110 prompt_builder: Arc<PromptBuilder>,
111 prompt_store: Option<Entity<PromptStore>>,
112 cx: &mut Context<Self>,
113 ) -> (Self, oneshot::Receiver<()>) {
114 let mut subscriptions = vec![
115 cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
116 this.load_default_profile(cx);
117 }),
118 cx.subscribe(&project, Self::handle_project_event),
119 ];
120
121 if let Some(prompt_store) = prompt_store.as_ref() {
122 subscriptions.push(cx.subscribe(
123 prompt_store,
124 |this, _prompt_store, PromptsUpdatedEvent, _cx| {
125 this.enqueue_system_prompt_reload();
126 },
127 ))
128 }
129
130 // This channel and task prevent concurrent and redundant loading of the system prompt.
131 let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
132 let (ready_tx, ready_rx) = oneshot::channel();
133 let mut ready_tx = Some(ready_tx);
134 let reload_system_prompt_task = cx.spawn({
135 let prompt_store = prompt_store.clone();
136 async move |thread_store, cx| {
137 loop {
138 let Some(reload_task) = thread_store
139 .update(cx, |thread_store, cx| {
140 thread_store.reload_system_prompt(prompt_store.clone(), cx)
141 })
142 .ok()
143 else {
144 return;
145 };
146 reload_task.await;
147 if let Some(ready_tx) = ready_tx.take() {
148 ready_tx.send(()).ok();
149 }
150 reload_system_prompt_rx.next().await;
151 }
152 }
153 });
154
155 let this = Self {
156 project,
157 tools,
158 prompt_builder,
159 prompt_store,
160 context_server_tool_ids: HashMap::default(),
161 threads: Vec::new(),
162 project_context: SharedProjectContext::default(),
163 reload_system_prompt_tx,
164 _reload_system_prompt_task: reload_system_prompt_task,
165 _subscriptions: subscriptions,
166 };
167 this.load_default_profile(cx);
168 this.register_context_server_handlers(cx);
169 this.reload(cx).detach_and_log_err(cx);
170 (this, ready_rx)
171 }
172
173 fn handle_project_event(
174 &mut self,
175 _project: Entity<Project>,
176 event: &project::Event,
177 _cx: &mut Context<Self>,
178 ) {
179 match event {
180 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
181 self.enqueue_system_prompt_reload();
182 }
183 project::Event::WorktreeUpdatedEntries(_, items) => {
184 if items.iter().any(|(path, _, _)| {
185 RULES_FILE_NAMES
186 .iter()
187 .any(|name| path.as_ref() == Path::new(name))
188 }) {
189 self.enqueue_system_prompt_reload();
190 }
191 }
192 _ => {}
193 }
194 }
195
196 fn enqueue_system_prompt_reload(&mut self) {
197 self.reload_system_prompt_tx.try_send(()).ok();
198 }
199
200 // Note that this should only be called from `reload_system_prompt_task`.
201 fn reload_system_prompt(
202 &self,
203 prompt_store: Option<Entity<PromptStore>>,
204 cx: &mut Context<Self>,
205 ) -> Task<()> {
206 let worktrees = self
207 .project
208 .read(cx)
209 .visible_worktrees(cx)
210 .collect::<Vec<_>>();
211 let worktree_tasks = worktrees
212 .into_iter()
213 .map(|worktree| {
214 Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
215 })
216 .collect::<Vec<_>>();
217 let default_user_rules_task = match prompt_store {
218 None => Task::ready(vec![]),
219 Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
220 let prompts = prompt_store.default_prompt_metadata();
221 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
222 let contents = prompt_store.load(prompt_metadata.id, cx);
223 async move { (contents.await, prompt_metadata) }
224 });
225 cx.background_spawn(future::join_all(load_tasks))
226 }),
227 };
228
229 cx.spawn(async move |this, cx| {
230 let (worktrees, default_user_rules) =
231 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
232
233 let worktrees = worktrees
234 .into_iter()
235 .map(|(worktree, rules_error)| {
236 if let Some(rules_error) = rules_error {
237 this.update(cx, |_, cx| cx.emit(rules_error)).ok();
238 }
239 worktree
240 })
241 .collect::<Vec<_>>();
242
243 let default_user_rules = default_user_rules
244 .into_iter()
245 .flat_map(|(contents, prompt_metadata)| match contents {
246 Ok(contents) => Some(UserRulesContext {
247 uuid: match prompt_metadata.id {
248 PromptId::User { uuid } => uuid,
249 PromptId::EditWorkflow => return None,
250 },
251 title: prompt_metadata.title.map(|title| title.to_string()),
252 contents,
253 }),
254 Err(err) => {
255 this.update(cx, |_, cx| {
256 cx.emit(RulesLoadingError {
257 message: format!("{err:?}").into(),
258 });
259 })
260 .ok();
261 None
262 }
263 })
264 .collect::<Vec<_>>();
265
266 this.update(cx, |this, _cx| {
267 *this.project_context.0.borrow_mut() =
268 Some(ProjectContext::new(worktrees, default_user_rules));
269 })
270 .ok();
271 })
272 }
273
274 fn load_worktree_info_for_system_prompt(
275 worktree: Entity<Worktree>,
276 project: Entity<Project>,
277 cx: &mut App,
278 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
279 let root_name = worktree.read(cx).root_name().into();
280
281 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
282 let Some(rules_task) = rules_task else {
283 return Task::ready((
284 WorktreeContext {
285 root_name,
286 rules_file: None,
287 },
288 None,
289 ));
290 };
291
292 cx.spawn(async move |_| {
293 let (rules_file, rules_file_error) = match rules_task.await {
294 Ok(rules_file) => (Some(rules_file), None),
295 Err(err) => (
296 None,
297 Some(RulesLoadingError {
298 message: format!("{err}").into(),
299 }),
300 ),
301 };
302 let worktree_info = WorktreeContext {
303 root_name,
304 rules_file,
305 };
306 (worktree_info, rules_file_error)
307 })
308 }
309
310 fn load_worktree_rules_file(
311 worktree: Entity<Worktree>,
312 project: Entity<Project>,
313 cx: &mut App,
314 ) -> Option<Task<Result<RulesFileContext>>> {
315 let worktree_ref = worktree.read(cx);
316 let worktree_id = worktree_ref.id();
317 let selected_rules_file = RULES_FILE_NAMES
318 .into_iter()
319 .filter_map(|name| {
320 worktree_ref
321 .entry_for_path(name)
322 .filter(|entry| entry.is_file())
323 .map(|entry| entry.path.clone())
324 })
325 .next();
326
327 // Note that Cline supports `.clinerules` being a directory, but that is not currently
328 // supported. This doesn't seem to occur often in GitHub repositories.
329 selected_rules_file.map(|path_in_worktree| {
330 let project_path = ProjectPath {
331 worktree_id,
332 path: path_in_worktree.clone(),
333 };
334 let buffer_task =
335 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
336 let rope_task = cx.spawn(async move |cx| {
337 buffer_task.await?.read_with(cx, |buffer, cx| {
338 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
339 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
340 })?
341 });
342 // Build a string from the rope on a background thread.
343 cx.background_spawn(async move {
344 let (project_entry_id, rope) = rope_task.await?;
345 anyhow::Ok(RulesFileContext {
346 path_in_worktree,
347 text: rope.to_string().trim().to_string(),
348 project_entry_id: project_entry_id.to_usize(),
349 })
350 })
351 })
352 }
353
    /// Returns the prompt store this thread store was created with, if any.
    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
        &self.prompt_store
    }
357
    /// Returns a handle to the tool working set shared with the threads of this store.
    pub fn tools(&self) -> Entity<ToolWorkingSet> {
        self.tools.clone()
    }
361
    /// Returns the number of thread metadata entries currently held by this store.
    pub fn thread_count(&self) -> usize {
        self.threads.len()
    }
366
    /// Iterates over the held thread metadata in storage order, without sorting.
    pub fn unordered_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
        self.threads.iter()
    }
370
371 pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
372 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
373 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
374 threads
375 }
376
377 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
378 cx.new(|cx| {
379 Thread::new(
380 self.project.clone(),
381 self.tools.clone(),
382 self.prompt_builder.clone(),
383 self.project_context.clone(),
384 cx,
385 )
386 })
387 }
388
389 pub fn create_thread_from_serialized(
390 &mut self,
391 serialized: SerializedThread,
392 cx: &mut Context<Self>,
393 ) -> Entity<Thread> {
394 cx.new(|cx| {
395 Thread::deserialize(
396 ThreadId::new(),
397 serialized,
398 self.project.clone(),
399 self.tools.clone(),
400 self.prompt_builder.clone(),
401 self.project_context.clone(),
402 None,
403 cx,
404 )
405 })
406 }
407
408 pub fn open_thread(
409 &self,
410 id: &ThreadId,
411 window: &mut Window,
412 cx: &mut Context<Self>,
413 ) -> Task<Result<Entity<Thread>>> {
414 let id = id.clone();
415 let database_future = ThreadsDatabase::global_future(cx);
416 let this = cx.weak_entity();
417 window.spawn(cx, async move |cx| {
418 let database = database_future.await.map_err(|err| anyhow!(err))?;
419 let thread = database
420 .try_find_thread(id.clone())
421 .await?
422 .with_context(|| format!("no thread found with ID: {id:?}"))?;
423
424 let thread = this.update_in(cx, |this, window, cx| {
425 cx.new(|cx| {
426 Thread::deserialize(
427 id.clone(),
428 thread,
429 this.project.clone(),
430 this.tools.clone(),
431 this.prompt_builder.clone(),
432 this.project_context.clone(),
433 Some(window),
434 cx,
435 )
436 })
437 })?;
438
439 Ok(thread)
440 })
441 }
442
443 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
444 let (metadata, serialized_thread) =
445 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
446
447 let database_future = ThreadsDatabase::global_future(cx);
448 cx.spawn(async move |this, cx| {
449 let serialized_thread = serialized_thread.await?;
450 let database = database_future.await.map_err(|err| anyhow!(err))?;
451 database.save_thread(metadata, serialized_thread).await?;
452
453 this.update(cx, |this, cx| this.reload(cx))?.await
454 })
455 }
456
457 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
458 let id = id.clone();
459 let database_future = ThreadsDatabase::global_future(cx);
460 cx.spawn(async move |this, cx| {
461 let database = database_future.await.map_err(|err| anyhow!(err))?;
462 database.delete_thread(id.clone()).await?;
463
464 this.update(cx, |this, cx| {
465 this.threads.retain(|thread| thread.id != id);
466 cx.notify();
467 })
468 })
469 }
470
471 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
472 let database_future = ThreadsDatabase::global_future(cx);
473 cx.spawn(async move |this, cx| {
474 let threads = database_future
475 .await
476 .map_err(|err| anyhow!(err))?
477 .list_threads()
478 .await?;
479
480 this.update(cx, |this, cx| {
481 this.threads = threads;
482 cx.notify();
483 })
484 })
485 }
486
487 fn load_default_profile(&self, cx: &mut Context<Self>) {
488 let assistant_settings = AgentSettings::get_global(cx);
489
490 self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
491 }
492
493 pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
494 let assistant_settings = AgentSettings::get_global(cx);
495
496 if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
497 self.load_profile(profile.clone(), cx);
498 }
499 }
500
501 pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
502 self.tools.update(cx, |tools, cx| {
503 tools.disable_all_tools(cx);
504 tools.enable(
505 ToolSource::Native,
506 &profile
507 .tools
508 .into_iter()
509 .filter_map(|(tool, enabled)| enabled.then(|| tool))
510 .collect::<Vec<_>>(),
511 cx,
512 );
513 });
514
515 if profile.enable_all_context_servers {
516 for context_server_id in self
517 .project
518 .read(cx)
519 .context_server_store()
520 .read(cx)
521 .all_server_ids()
522 {
523 self.tools.update(cx, |tools, cx| {
524 tools.enable_source(
525 ToolSource::ContextServer {
526 id: context_server_id.0.into(),
527 },
528 cx,
529 );
530 });
531 }
532 // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
533 for (context_server_id, preset) in profile.context_servers {
534 self.tools.update(cx, |tools, cx| {
535 tools.disable(
536 ToolSource::ContextServer {
537 id: context_server_id.into(),
538 },
539 &preset
540 .tools
541 .into_iter()
542 .filter_map(|(tool, enabled)| (!enabled).then(|| tool))
543 .collect::<Vec<_>>(),
544 cx,
545 )
546 })
547 }
548 } else {
549 for (context_server_id, preset) in profile.context_servers {
550 self.tools.update(cx, |tools, cx| {
551 tools.enable(
552 ToolSource::ContextServer {
553 id: context_server_id.into(),
554 },
555 &preset
556 .tools
557 .into_iter()
558 .filter_map(|(tool, enabled)| enabled.then(|| tool))
559 .collect::<Vec<_>>(),
560 cx,
561 )
562 })
563 }
564 }
565 }
566
567 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
568 cx.subscribe(
569 &self.project.read(cx).context_server_store(),
570 Self::handle_context_server_event,
571 )
572 .detach();
573 }
574
575 fn handle_context_server_event(
576 &mut self,
577 context_server_store: Entity<ContextServerStore>,
578 event: &project::context_server_store::Event,
579 cx: &mut Context<Self>,
580 ) {
581 let tool_working_set = self.tools.clone();
582 match event {
583 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
584 match status {
585 ContextServerStatus::Running => {
586 if let Some(server) =
587 context_server_store.read(cx).get_running_server(server_id)
588 {
589 let context_server_manager = context_server_store.clone();
590 cx.spawn({
591 let server = server.clone();
592 let server_id = server_id.clone();
593 async move |this, cx| {
594 let Some(protocol) = server.client() else {
595 return;
596 };
597
598 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
599 if let Some(tools) = protocol.list_tools().await.log_err() {
600 let tool_ids = tool_working_set
601 .update(cx, |tool_working_set, _| {
602 tools
603 .tools
604 .into_iter()
605 .map(|tool| {
606 log::info!(
607 "registering context server tool: {:?}",
608 tool.name
609 );
610 tool_working_set.insert(Arc::new(
611 ContextServerTool::new(
612 context_server_manager.clone(),
613 server.id(),
614 tool,
615 ),
616 ))
617 })
618 .collect::<Vec<_>>()
619 })
620 .log_err();
621
622 if let Some(tool_ids) = tool_ids {
623 this.update(cx, |this, cx| {
624 this.context_server_tool_ids
625 .insert(server_id, tool_ids);
626 this.load_default_profile(cx);
627 })
628 .log_err();
629 }
630 }
631 }
632 }
633 })
634 .detach();
635 }
636 }
637 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
638 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
639 tool_working_set.update(cx, |tool_working_set, _| {
640 tool_working_set.remove(&tool_ids);
641 });
642 self.load_default_profile(cx);
643 }
644 }
645 _ => {}
646 }
647 }
648 }
649 }
650}
651
/// Lightweight summary of a saved thread, as produced by `ThreadsDatabase::list_threads`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Database key of the thread.
    pub id: ThreadId,
    /// Summary copied from the saved thread.
    pub summary: SharedString,
    /// Last-update timestamp copied from the saved thread.
    pub updated_at: DateTime<Utc>,
}
658
/// The current on-disk representation of a thread.
///
/// Fields marked `#[serde(default)]` fall back to their default value when absent from the
/// stored data.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    /// Format version string; `from_json` dispatches on it.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub tool_use_limit_reached: bool,
}
682
/// Serialized reference to a language model, stored as plain provider and model strings.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
688
689impl SerializedThread {
690 pub const VERSION: &'static str = "0.2.0";
691
692 pub fn from_json(json: &[u8]) -> Result<Self> {
693 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
694 match saved_thread_json.get("version") {
695 Some(serde_json::Value::String(version)) => match version.as_str() {
696 SerializedThreadV0_1_0::VERSION => {
697 let saved_thread =
698 serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
699 Ok(saved_thread.upgrade())
700 }
701 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
702 saved_thread_json,
703 )?),
704 _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
705 },
706 None => {
707 let saved_thread =
708 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
709 Ok(saved_thread.upgrade())
710 }
711 version => anyhow::bail!("unrecognized serialized thread version: {version:?}"),
712 }
713 }
714}
715
/// A thread saved in format version "0.1.0".
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);
722
723impl SerializedThreadV0_1_0 {
724 pub const VERSION: &'static str = "0.1.0";
725
726 pub fn upgrade(self) -> SerializedThread {
727 debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
728
729 let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
730
731 for message in self.0.messages {
732 if message.role == Role::User && !message.tool_results.is_empty() {
733 if let Some(last_message) = messages.last_mut() {
734 debug_assert!(last_message.role == Role::Assistant);
735
736 last_message.tool_results = message.tool_results;
737 continue;
738 }
739 }
740
741 messages.push(message);
742 }
743
744 SerializedThread { messages, ..self.0 }
745 }
746}
747
/// One message of a serialized thread.
///
/// Every field except `id` and `role` falls back to its default value when absent from the
/// stored data.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
    #[serde(default)]
    pub is_hidden: bool,
}
765
/// A segment of a message body, serialized with a `"type"` tag that selects the variant.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Tagged as `"text"`.
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    /// Tagged as `"thinking"`; `signature` is left out of the output when it is `None`.
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // No `rename` attribute here, so this variant is tagged with its Rust name.
    RedactedThinking {
        data: Vec<u8>,
    },
}
783
/// A serialized tool invocation: its id, the tool name and the JSON input.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
790
/// The serialized result of a tool invocation.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    /// Id of the tool use this result belongs to.
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: LanguageModelToolResultContent,
    pub output: Option<serde_json::Value>,
}
798
/// The thread format used before a `version` field existed; `SerializedThread::from_json`
/// selects it when the stored JSON has no `version` key.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
807
808impl LegacySerializedThread {
809 pub fn upgrade(self) -> SerializedThread {
810 SerializedThread {
811 version: SerializedThread::VERSION.to_string(),
812 summary: self.summary,
813 updated_at: self.updated_at,
814 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
815 initial_project_snapshot: self.initial_project_snapshot,
816 cumulative_token_usage: TokenUsage::default(),
817 request_token_usage: Vec::new(),
818 detailed_summary_state: DetailedSummaryState::default(),
819 exceeded_window_error: None,
820 model: None,
821 completion_mode: None,
822 tool_use_limit_reached: false,
823 }
824 }
825}
826
/// A message in the legacy thread format, whose body is a single `text` string.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
837
838impl LegacySerializedMessage {
839 fn upgrade(self) -> SerializedMessage {
840 SerializedMessage {
841 id: self.id,
842 role: self.role,
843 segments: vec![SerializedMessageSegment::Text { text: self.text }],
844 tool_uses: self.tool_uses,
845 tool_results: self.tool_results,
846 context: String::new(),
847 creases: Vec::new(),
848 is_hidden: false,
849 }
850 }
851}
852
/// A serialized crease attached to a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedCrease {
    // NOTE(review): `start`/`end` presumably delimit a range in the message text — confirm.
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}
860
/// App-global holder for the shared future that resolves to the opened threads database,
/// or to the error that occurred while opening it.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

// Allows the holder to be stored as a gpui global.
impl Global for GlobalThreadsDatabase {}
866
/// Persistent storage for serialized threads, backed by a heed environment.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every database operation is spawned.
    executor: BackgroundExecutor,
    /// The heed environment that owns the database.
    env: heed::Env,
    /// Thread id to thread mapping: keys are bincode-encoded, values are stored as JSON.
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
872
873impl heed::BytesEncode<'_> for SerializedThread {
874 type EItem = SerializedThread;
875
876 fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
877 serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
878 }
879}
880
881impl<'a> heed::BytesDecode<'a> for SerializedThread {
882 type DItem = SerializedThread;
883
884 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
885 // We implement this type manually because we want to call `SerializedThread::from_json`,
886 // instead of the Deserialize trait implementation for `SerializedThread`.
887 SerializedThread::from_json(bytes).map_err(Into::into)
888 }
889}
890
891impl ThreadsDatabase {
892 fn global_future(
893 cx: &mut App,
894 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
895 GlobalThreadsDatabase::global(cx).0.clone()
896 }
897
898 fn init(cx: &mut App) {
899 let executor = cx.background_executor().clone();
900 let database_future = executor
901 .spawn({
902 let executor = executor.clone();
903 let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
904 async move { ThreadsDatabase::new(database_path, executor) }
905 })
906 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
907 .boxed()
908 .shared();
909
910 cx.set_global(GlobalThreadsDatabase(database_future));
911 }
912
913 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
914 std::fs::create_dir_all(&path)?;
915
916 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
917 let env = unsafe {
918 heed::EnvOpenOptions::new()
919 .map_size(ONE_GB_IN_BYTES)
920 .max_dbs(1)
921 .open(path)?
922 };
923
924 let mut txn = env.write_txn()?;
925 let threads = env.create_database(&mut txn, Some("threads"))?;
926 txn.commit()?;
927
928 Ok(Self {
929 executor,
930 env,
931 threads,
932 })
933 }
934
935 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
936 let env = self.env.clone();
937 let threads = self.threads;
938
939 self.executor.spawn(async move {
940 let txn = env.read_txn()?;
941 let mut iter = threads.iter(&txn)?;
942 let mut threads = Vec::new();
943 while let Some((key, value)) = iter.next().transpose()? {
944 threads.push(SerializedThreadMetadata {
945 id: key,
946 summary: value.summary,
947 updated_at: value.updated_at,
948 });
949 }
950
951 Ok(threads)
952 })
953 }
954
955 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
956 let env = self.env.clone();
957 let threads = self.threads;
958
959 self.executor.spawn(async move {
960 let txn = env.read_txn()?;
961 let thread = threads.get(&txn, &id)?;
962 Ok(thread)
963 })
964 }
965
966 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
967 let env = self.env.clone();
968 let threads = self.threads;
969
970 self.executor.spawn(async move {
971 let mut txn = env.write_txn()?;
972 threads.put(&mut txn, &id, &thread)?;
973 txn.commit()?;
974 Ok(())
975 })
976 }
977
978 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
979 let env = self.env.clone();
980 let threads = self.threads;
981
982 self.executor.spawn(async move {
983 let mut txn = env.write_txn()?;
984 threads.delete(&mut txn, &id)?;
985 txn.commit()?;
986 Ok(())
987 })
988 }
989}