1use std::borrow::Cow;
2use std::cell::{Ref, RefCell};
3use std::path::{Path, PathBuf};
4use std::rc::Rc;
5use std::sync::Arc;
6
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings, CompletionMode};
9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use context_server::ContextServerId;
13use futures::channel::{mpsc, oneshot};
14use futures::future::{self, BoxFuture, Shared};
15use futures::{FutureExt as _, StreamExt as _};
16use gpui::{
17 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
18 Subscription, Task, prelude::*,
19};
20use heed::Database;
21use heed::types::SerdeBincode;
22use language_model::{LanguageModelToolUseId, Role, TokenUsage};
23use project::context_server_store::{ContextServerStatus, ContextServerStore};
24use project::{Project, ProjectItem, ProjectPath, Worktree};
25use prompt_store::{
26 ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
27 UserRulesContext, WorktreeContext,
28};
29use serde::{Deserialize, Serialize};
30use settings::{Settings as _, SettingsStore};
31use ui::Window;
32use util::ResultExt as _;
33
34use crate::context_server_tool::ContextServerTool;
35use crate::thread::{
36 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
37};
38
/// File names (relative to a worktree root) that are treated as rules files.
///
/// Order is significant: `load_worktree_rules_file` uses the first entry that
/// exists as a file in the worktree.
const RULES_FILE_NAMES: [&str; 6] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
];
47
/// Initializes this module's global state by starting the threads database
/// (see `ThreadsDatabase::init`).
pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}
51
/// A system prompt shared by all threads created by this ThreadStore
///
/// The inner value is `None` until the first system prompt reload has stored a
/// [`ProjectContext`]; clones share the same underlying cell.
#[derive(Clone, Default)]
pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
55
56impl SharedProjectContext {
57 pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
58 self.0.borrow()
59 }
60}
61
/// Alias for the context store from `assistant_context_editor`.
pub type TextThreadStore = assistant_context_editor::ContextStore;
63
/// Creates, opens, saves and deletes [`Thread`]s for a project, and keeps the
/// state they share (tools, prompt builder, project context) up to date.
pub struct ThreadStore {
    /// Project whose worktrees and context server store this store observes.
    project: Entity<Project>,
    /// Tool working set handed to every thread; profiles enable/disable tools in it.
    tools: Entity<ToolWorkingSet>,
    /// Handed to every thread on creation and deserialization.
    prompt_builder: Arc<PromptBuilder>,
    /// Optional source of the default user rules included in the project context.
    prompt_store: Option<Entity<PromptStore>>,
    /// Tool ids registered for each running context server, kept so they can be
    /// removed again when that server stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Metadata of the stored threads as of the last `reload`.
    threads: Vec<SerializedThreadMetadata>,
    /// Project context shared with every thread; rewritten by each system prompt reload.
    project_context: SharedProjectContext,
    /// Signals the reload task that the system prompt should be rebuilt.
    reload_system_prompt_tx: mpsc::Sender<()>,
    /// Task that serializes system prompt reloads; stored so it lives as long as the store.
    _reload_system_prompt_task: Task<()>,
    /// Subscriptions to settings, project and prompt store events, stored so
    /// they live as long as the store.
    _subscriptions: Vec<Subscription>,
}
76
/// Event emitted by [`ThreadStore`] when a rules file or a default user rule
/// could not be loaded while rebuilding the system prompt.
pub struct RulesLoadingError {
    /// Formatted description of the underlying error.
    pub message: SharedString,
}
80
// Lets `ThreadStore` emit `RulesLoadingError` events via `cx.emit`.
impl EventEmitter<RulesLoadingError> for ThreadStore {}
82
83impl ThreadStore {
84 pub fn load(
85 project: Entity<Project>,
86 tools: Entity<ToolWorkingSet>,
87 prompt_store: Option<Entity<PromptStore>>,
88 prompt_builder: Arc<PromptBuilder>,
89 cx: &mut App,
90 ) -> Task<Result<Entity<Self>>> {
91 cx.spawn(async move |cx| {
92 let (thread_store, ready_rx) = cx.update(|cx| {
93 let mut option_ready_rx = None;
94 let thread_store = cx.new(|cx| {
95 let (thread_store, ready_rx) =
96 Self::new(project, tools, prompt_builder, prompt_store, cx);
97 option_ready_rx = Some(ready_rx);
98 thread_store
99 });
100 (thread_store, option_ready_rx.take().unwrap())
101 })?;
102 ready_rx.await?;
103 Ok(thread_store)
104 })
105 }
106
107 fn new(
108 project: Entity<Project>,
109 tools: Entity<ToolWorkingSet>,
110 prompt_builder: Arc<PromptBuilder>,
111 prompt_store: Option<Entity<PromptStore>>,
112 cx: &mut Context<Self>,
113 ) -> (Self, oneshot::Receiver<()>) {
114 let mut subscriptions = vec![
115 cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
116 this.load_default_profile(cx);
117 }),
118 cx.subscribe(&project, Self::handle_project_event),
119 ];
120
121 if let Some(prompt_store) = prompt_store.as_ref() {
122 subscriptions.push(cx.subscribe(
123 prompt_store,
124 |this, _prompt_store, PromptsUpdatedEvent, _cx| {
125 this.enqueue_system_prompt_reload();
126 },
127 ))
128 }
129
130 // This channel and task prevent concurrent and redundant loading of the system prompt.
131 let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
132 let (ready_tx, ready_rx) = oneshot::channel();
133 let mut ready_tx = Some(ready_tx);
134 let reload_system_prompt_task = cx.spawn({
135 let prompt_store = prompt_store.clone();
136 async move |thread_store, cx| {
137 loop {
138 let Some(reload_task) = thread_store
139 .update(cx, |thread_store, cx| {
140 thread_store.reload_system_prompt(prompt_store.clone(), cx)
141 })
142 .ok()
143 else {
144 return;
145 };
146 reload_task.await;
147 if let Some(ready_tx) = ready_tx.take() {
148 ready_tx.send(()).ok();
149 }
150 reload_system_prompt_rx.next().await;
151 }
152 }
153 });
154
155 let this = Self {
156 project,
157 tools,
158 prompt_builder,
159 prompt_store,
160 context_server_tool_ids: HashMap::default(),
161 threads: Vec::new(),
162 project_context: SharedProjectContext::default(),
163 reload_system_prompt_tx,
164 _reload_system_prompt_task: reload_system_prompt_task,
165 _subscriptions: subscriptions,
166 };
167 this.load_default_profile(cx);
168 this.register_context_server_handlers(cx);
169 this.reload(cx).detach_and_log_err(cx);
170 (this, ready_rx)
171 }
172
173 fn handle_project_event(
174 &mut self,
175 _project: Entity<Project>,
176 event: &project::Event,
177 _cx: &mut Context<Self>,
178 ) {
179 match event {
180 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
181 self.enqueue_system_prompt_reload();
182 }
183 project::Event::WorktreeUpdatedEntries(_, items) => {
184 if items.iter().any(|(path, _, _)| {
185 RULES_FILE_NAMES
186 .iter()
187 .any(|name| path.as_ref() == Path::new(name))
188 }) {
189 self.enqueue_system_prompt_reload();
190 }
191 }
192 _ => {}
193 }
194 }
195
    /// Asks the reload task to rebuild the system prompt.
    ///
    /// Uses a non-blocking `try_send` and deliberately ignores the result: if
    /// the send fails the request is simply dropped.
    fn enqueue_system_prompt_reload(&mut self) {
        self.reload_system_prompt_tx.try_send(()).ok();
    }
199
200 // Note that this should only be called from `reload_system_prompt_task`.
201 fn reload_system_prompt(
202 &self,
203 prompt_store: Option<Entity<PromptStore>>,
204 cx: &mut Context<Self>,
205 ) -> Task<()> {
206 let worktrees = self
207 .project
208 .read(cx)
209 .visible_worktrees(cx)
210 .collect::<Vec<_>>();
211 let worktree_tasks = worktrees
212 .into_iter()
213 .map(|worktree| {
214 Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
215 })
216 .collect::<Vec<_>>();
217 let default_user_rules_task = match prompt_store {
218 None => Task::ready(vec![]),
219 Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
220 let prompts = prompt_store.default_prompt_metadata();
221 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
222 let contents = prompt_store.load(prompt_metadata.id, cx);
223 async move { (contents.await, prompt_metadata) }
224 });
225 cx.background_spawn(future::join_all(load_tasks))
226 }),
227 };
228
229 cx.spawn(async move |this, cx| {
230 let (worktrees, default_user_rules) =
231 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
232
233 let worktrees = worktrees
234 .into_iter()
235 .map(|(worktree, rules_error)| {
236 if let Some(rules_error) = rules_error {
237 this.update(cx, |_, cx| cx.emit(rules_error)).ok();
238 }
239 worktree
240 })
241 .collect::<Vec<_>>();
242
243 let default_user_rules = default_user_rules
244 .into_iter()
245 .flat_map(|(contents, prompt_metadata)| match contents {
246 Ok(contents) => Some(UserRulesContext {
247 uuid: match prompt_metadata.id {
248 PromptId::User { uuid } => uuid,
249 PromptId::EditWorkflow => return None,
250 },
251 title: prompt_metadata.title.map(|title| title.to_string()),
252 contents,
253 }),
254 Err(err) => {
255 this.update(cx, |_, cx| {
256 cx.emit(RulesLoadingError {
257 message: format!("{err:?}").into(),
258 });
259 })
260 .ok();
261 None
262 }
263 })
264 .collect::<Vec<_>>();
265
266 this.update(cx, |this, _cx| {
267 *this.project_context.0.borrow_mut() =
268 Some(ProjectContext::new(worktrees, default_user_rules));
269 })
270 .ok();
271 })
272 }
273
274 fn load_worktree_info_for_system_prompt(
275 worktree: Entity<Worktree>,
276 project: Entity<Project>,
277 cx: &mut App,
278 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
279 let root_name = worktree.read(cx).root_name().into();
280
281 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
282 let Some(rules_task) = rules_task else {
283 return Task::ready((
284 WorktreeContext {
285 root_name,
286 rules_file: None,
287 },
288 None,
289 ));
290 };
291
292 cx.spawn(async move |_| {
293 let (rules_file, rules_file_error) = match rules_task.await {
294 Ok(rules_file) => (Some(rules_file), None),
295 Err(err) => (
296 None,
297 Some(RulesLoadingError {
298 message: format!("{err}").into(),
299 }),
300 ),
301 };
302 let worktree_info = WorktreeContext {
303 root_name,
304 rules_file,
305 };
306 (worktree_info, rules_file_error)
307 })
308 }
309
310 fn load_worktree_rules_file(
311 worktree: Entity<Worktree>,
312 project: Entity<Project>,
313 cx: &mut App,
314 ) -> Option<Task<Result<RulesFileContext>>> {
315 let worktree_ref = worktree.read(cx);
316 let worktree_id = worktree_ref.id();
317 let selected_rules_file = RULES_FILE_NAMES
318 .into_iter()
319 .filter_map(|name| {
320 worktree_ref
321 .entry_for_path(name)
322 .filter(|entry| entry.is_file())
323 .map(|entry| entry.path.clone())
324 })
325 .next();
326
327 // Note that Cline supports `.clinerules` being a directory, but that is not currently
328 // supported. This doesn't seem to occur often in GitHub repositories.
329 selected_rules_file.map(|path_in_worktree| {
330 let project_path = ProjectPath {
331 worktree_id,
332 path: path_in_worktree.clone(),
333 };
334 let buffer_task =
335 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
336 let rope_task = cx.spawn(async move |cx| {
337 buffer_task.await?.read_with(cx, |buffer, cx| {
338 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
339 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
340 })?
341 });
342 // Build a string from the rope on a background thread.
343 cx.background_spawn(async move {
344 let (project_entry_id, rope) = rope_task.await?;
345 anyhow::Ok(RulesFileContext {
346 path_in_worktree,
347 text: rope.to_string().trim().to_string(),
348 project_entry_id: project_entry_id.to_usize(),
349 })
350 })
351 })
352 }
353
    /// Returns the prompt store this thread store was created with, if any.
    pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
        &self.prompt_store
    }
357
    /// Returns a handle to the tool working set shared by this store's threads.
    pub fn tools(&self) -> Entity<ToolWorkingSet> {
        self.tools.clone()
    }
361
    /// Returns the number of threads in the in-memory metadata list, i.e. as
    /// of the last `reload` or local deletion.
    pub fn thread_count(&self) -> usize {
        self.threads.len()
    }
366
    /// Iterates over the cached thread metadata in stored order, without sorting.
    pub fn unordered_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
        self.threads.iter()
    }
370
371 pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
372 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
373 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
374 threads
375 }
376
377 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
378 cx.new(|cx| {
379 Thread::new(
380 self.project.clone(),
381 self.tools.clone(),
382 self.prompt_builder.clone(),
383 self.project_context.clone(),
384 cx,
385 )
386 })
387 }
388
389 pub fn open_thread(
390 &self,
391 id: &ThreadId,
392 window: &mut Window,
393 cx: &mut Context<Self>,
394 ) -> Task<Result<Entity<Thread>>> {
395 let id = id.clone();
396 let database_future = ThreadsDatabase::global_future(cx);
397 let this = cx.weak_entity();
398 window.spawn(cx, async move |cx| {
399 let database = database_future.await.map_err(|err| anyhow!(err))?;
400 let thread = database
401 .try_find_thread(id.clone())
402 .await?
403 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
404
405 let thread = this.update_in(cx, |this, window, cx| {
406 cx.new(|cx| {
407 Thread::deserialize(
408 id.clone(),
409 thread,
410 this.project.clone(),
411 this.tools.clone(),
412 this.prompt_builder.clone(),
413 this.project_context.clone(),
414 window,
415 cx,
416 )
417 })
418 })?;
419
420 Ok(thread)
421 })
422 }
423
424 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
425 let (metadata, serialized_thread) =
426 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
427
428 let database_future = ThreadsDatabase::global_future(cx);
429 cx.spawn(async move |this, cx| {
430 let serialized_thread = serialized_thread.await?;
431 let database = database_future.await.map_err(|err| anyhow!(err))?;
432 database.save_thread(metadata, serialized_thread).await?;
433
434 this.update(cx, |this, cx| this.reload(cx))?.await
435 })
436 }
437
438 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
439 let id = id.clone();
440 let database_future = ThreadsDatabase::global_future(cx);
441 cx.spawn(async move |this, cx| {
442 let database = database_future.await.map_err(|err| anyhow!(err))?;
443 database.delete_thread(id.clone()).await?;
444
445 this.update(cx, |this, cx| {
446 this.threads.retain(|thread| thread.id != id);
447 cx.notify();
448 })
449 })
450 }
451
452 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
453 let database_future = ThreadsDatabase::global_future(cx);
454 cx.spawn(async move |this, cx| {
455 let threads = database_future
456 .await
457 .map_err(|err| anyhow!(err))?
458 .list_threads()
459 .await?;
460
461 this.update(cx, |this, cx| {
462 this.threads = threads;
463 cx.notify();
464 })
465 })
466 }
467
468 fn load_default_profile(&self, cx: &mut Context<Self>) {
469 let assistant_settings = AssistantSettings::get_global(cx);
470
471 self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
472 }
473
474 pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
475 let assistant_settings = AssistantSettings::get_global(cx);
476
477 if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
478 self.load_profile(profile.clone(), cx);
479 }
480 }
481
482 pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
483 self.tools.update(cx, |tools, cx| {
484 tools.disable_all_tools(cx);
485 tools.enable(
486 ToolSource::Native,
487 &profile
488 .tools
489 .iter()
490 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
491 .collect::<Vec<_>>(),
492 cx,
493 );
494 });
495
496 if profile.enable_all_context_servers {
497 for context_server_id in self
498 .project
499 .read(cx)
500 .context_server_store()
501 .read(cx)
502 .all_server_ids()
503 {
504 self.tools.update(cx, |tools, cx| {
505 tools.enable_source(
506 ToolSource::ContextServer {
507 id: context_server_id.0.into(),
508 },
509 cx,
510 );
511 });
512 }
513 // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
514 for (context_server_id, preset) in &profile.context_servers {
515 self.tools.update(cx, |tools, cx| {
516 tools.disable(
517 ToolSource::ContextServer {
518 id: context_server_id.clone().into(),
519 },
520 &preset
521 .tools
522 .iter()
523 .filter_map(|(tool, enabled)| (!enabled).then(|| tool.clone()))
524 .collect::<Vec<_>>(),
525 cx,
526 )
527 })
528 }
529 } else {
530 for (context_server_id, preset) in &profile.context_servers {
531 self.tools.update(cx, |tools, cx| {
532 tools.enable(
533 ToolSource::ContextServer {
534 id: context_server_id.clone().into(),
535 },
536 &preset
537 .tools
538 .iter()
539 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
540 .collect::<Vec<_>>(),
541 cx,
542 )
543 })
544 }
545 }
546 }
547
548 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
549 cx.subscribe(
550 &self.project.read(cx).context_server_store(),
551 Self::handle_context_server_event,
552 )
553 .detach();
554 }
555
556 fn handle_context_server_event(
557 &mut self,
558 context_server_store: Entity<ContextServerStore>,
559 event: &project::context_server_store::Event,
560 cx: &mut Context<Self>,
561 ) {
562 let tool_working_set = self.tools.clone();
563 match event {
564 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
565 match status {
566 ContextServerStatus::Running => {
567 if let Some(server) =
568 context_server_store.read(cx).get_running_server(server_id)
569 {
570 let context_server_manager = context_server_store.clone();
571 cx.spawn({
572 let server = server.clone();
573 let server_id = server_id.clone();
574 async move |this, cx| {
575 let Some(protocol) = server.client() else {
576 return;
577 };
578
579 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
580 if let Some(tools) = protocol.list_tools().await.log_err() {
581 let tool_ids = tool_working_set
582 .update(cx, |tool_working_set, _| {
583 tools
584 .tools
585 .into_iter()
586 .map(|tool| {
587 log::info!(
588 "registering context server tool: {:?}",
589 tool.name
590 );
591 tool_working_set.insert(Arc::new(
592 ContextServerTool::new(
593 context_server_manager.clone(),
594 server.id(),
595 tool,
596 ),
597 ))
598 })
599 .collect::<Vec<_>>()
600 })
601 .log_err();
602
603 if let Some(tool_ids) = tool_ids {
604 this.update(cx, |this, cx| {
605 this.context_server_tool_ids
606 .insert(server_id, tool_ids);
607 this.load_default_profile(cx);
608 })
609 .log_err();
610 }
611 }
612 }
613 }
614 })
615 .detach();
616 }
617 }
618 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
619 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
620 tool_working_set.update(cx, |tool_working_set, _| {
621 tool_working_set.remove(&tool_ids);
622 });
623 self.load_default_profile(cx);
624 }
625 }
626 _ => {}
627 }
628 }
629 }
630 }
631}
632
/// Lightweight summary of a stored thread, as returned by
/// `ThreadsDatabase::list_threads`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Database key of the thread.
    pub id: ThreadId,
    /// Summary copied from the stored thread.
    pub summary: SharedString,
    /// Last-updated timestamp copied from the stored thread.
    pub updated_at: DateTime<Utc>,
}
639
/// The current on-disk representation of a thread.
///
/// Fields marked `#[serde(default)]` fall back to their default value when
/// they are missing from the stored JSON.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    /// Format version; `SerializedThread::from_json` dispatches on this value.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
    #[serde(default)]
    pub profile: Option<AgentProfileId>,
}
663
/// Identifies a language model by provider and model name, both stored as
/// plain strings.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
669
670impl SerializedThread {
671 pub const VERSION: &'static str = "0.2.0";
672
673 pub fn from_json(json: &[u8]) -> Result<Self> {
674 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
675 match saved_thread_json.get("version") {
676 Some(serde_json::Value::String(version)) => match version.as_str() {
677 SerializedThreadV0_1_0::VERSION => {
678 let saved_thread =
679 serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
680 Ok(saved_thread.upgrade())
681 }
682 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
683 saved_thread_json,
684 )?),
685 _ => Err(anyhow!(
686 "unrecognized serialized thread version: {}",
687 version
688 )),
689 },
690 None => {
691 let saved_thread =
692 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
693 Ok(saved_thread.upgrade())
694 }
695 version => Err(anyhow!(
696 "unrecognized serialized thread version: {:?}",
697 version
698 )),
699 }
700 }
701}
702
/// A thread stored in format version `0.1.0`.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);
709
710impl SerializedThreadV0_1_0 {
711 pub const VERSION: &'static str = "0.1.0";
712
713 pub fn upgrade(self) -> SerializedThread {
714 debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
715
716 let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
717
718 for message in self.0.messages {
719 if message.role == Role::User && !message.tool_results.is_empty() {
720 if let Some(last_message) = messages.last_mut() {
721 debug_assert!(last_message.role == Role::Assistant);
722
723 last_message.tool_results = message.tool_results;
724 continue;
725 }
726 }
727
728 messages.push(message);
729 }
730
731 SerializedThread { messages, ..self.0 }
732 }
733}
734
/// One message of a stored thread. Every field except `id` and `role`
/// defaults to empty when missing from the stored JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
}
750
/// A segment of a message, internally tagged with a `"type"` field.
///
/// `Text` and `Thinking` are tagged `"text"` and `"thinking"`.
/// `RedactedThinking` has no rename, so it is tagged with its variant name.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        // Omitted from the output entirely when there is no signature.
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    RedactedThinking {
        data: Vec<u8>,
    },
}
768
/// A stored tool invocation: its id, the tool's name, and the JSON input.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
775
/// A stored tool result, linked to its tool use through `tool_use_id`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: Arc<str>,
    pub output: Option<serde_json::Value>,
}
783
/// The oldest thread format: stored JSON that has no `version` field.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
792
793impl LegacySerializedThread {
794 pub fn upgrade(self) -> SerializedThread {
795 SerializedThread {
796 version: SerializedThread::VERSION.to_string(),
797 summary: self.summary,
798 updated_at: self.updated_at,
799 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
800 initial_project_snapshot: self.initial_project_snapshot,
801 cumulative_token_usage: TokenUsage::default(),
802 request_token_usage: Vec::new(),
803 detailed_summary_state: DetailedSummaryState::default(),
804 exceeded_window_error: None,
805 model: None,
806 completion_mode: None,
807 profile: None,
808 }
809 }
810}
811
/// A message in the legacy format, which stores its content as a single
/// `text` string instead of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
822
823impl LegacySerializedMessage {
824 fn upgrade(self) -> SerializedMessage {
825 SerializedMessage {
826 id: self.id,
827 role: self.role,
828 segments: vec![SerializedMessageSegment::Text { text: self.text }],
829 tool_uses: self.tool_uses,
830 tool_results: self.tool_results,
831 context: String::new(),
832 creases: Vec::new(),
833 }
834 }
835}
836
/// A stored crease attached to a message.
// NOTE(review): `start`/`end` are presumably offsets into the message text;
// the unit is not visible from this file — confirm against the caller.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedCrease {
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}
844
/// App-global wrapper around the shared future that resolves to the opened
/// [`ThreadsDatabase`], or to the error from opening it.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);
848
// Allows `GlobalThreadsDatabase` to be stored with `cx.set_global`.
impl Global for GlobalThreadsDatabase {}
850
/// Persistent store of serialized threads, backed by a heed environment.
pub(crate) struct ThreadsDatabase {
    /// Executor every database operation is spawned on.
    executor: BackgroundExecutor,
    /// The heed environment holding the database.
    env: heed::Env,
    /// Threads keyed by bincode-encoded `ThreadId`; values use the codec
    /// implemented on `SerializedThread`.
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
856
857impl heed::BytesEncode<'_> for SerializedThread {
858 type EItem = SerializedThread;
859
860 fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
861 serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
862 }
863}
864
865impl<'a> heed::BytesDecode<'a> for SerializedThread {
866 type DItem = SerializedThread;
867
868 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
869 // We implement this type manually because we want to call `SerializedThread::from_json`,
870 // instead of the Deserialize trait implementation for `SerializedThread`.
871 SerializedThread::from_json(bytes).map_err(Into::into)
872 }
873}
874
875impl ThreadsDatabase {
876 fn global_future(
877 cx: &mut App,
878 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
879 GlobalThreadsDatabase::global(cx).0.clone()
880 }
881
882 fn init(cx: &mut App) {
883 let executor = cx.background_executor().clone();
884 let database_future = executor
885 .spawn({
886 let executor = executor.clone();
887 let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
888 async move { ThreadsDatabase::new(database_path, executor) }
889 })
890 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
891 .boxed()
892 .shared();
893
894 cx.set_global(GlobalThreadsDatabase(database_future));
895 }
896
897 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
898 std::fs::create_dir_all(&path)?;
899
900 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
901 let env = unsafe {
902 heed::EnvOpenOptions::new()
903 .map_size(ONE_GB_IN_BYTES)
904 .max_dbs(1)
905 .open(path)?
906 };
907
908 let mut txn = env.write_txn()?;
909 let threads = env.create_database(&mut txn, Some("threads"))?;
910 txn.commit()?;
911
912 Ok(Self {
913 executor,
914 env,
915 threads,
916 })
917 }
918
919 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
920 let env = self.env.clone();
921 let threads = self.threads;
922
923 self.executor.spawn(async move {
924 let txn = env.read_txn()?;
925 let mut iter = threads.iter(&txn)?;
926 let mut threads = Vec::new();
927 while let Some((key, value)) = iter.next().transpose()? {
928 threads.push(SerializedThreadMetadata {
929 id: key,
930 summary: value.summary,
931 updated_at: value.updated_at,
932 });
933 }
934
935 Ok(threads)
936 })
937 }
938
939 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
940 let env = self.env.clone();
941 let threads = self.threads;
942
943 self.executor.spawn(async move {
944 let txn = env.read_txn()?;
945 let thread = threads.get(&txn, &id)?;
946 Ok(thread)
947 })
948 }
949
950 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
951 let env = self.env.clone();
952 let threads = self.threads;
953
954 self.executor.spawn(async move {
955 let mut txn = env.write_txn()?;
956 threads.put(&mut txn, &id, &thread)?;
957 txn.commit()?;
958 Ok(())
959 })
960 }
961
962 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
963 let env = self.env.clone();
964 let threads = self.threads;
965
966 self.executor.spawn(async move {
967 let mut txn = env.write_txn()?;
968 threads.delete(&mut txn, &id)?;
969 txn.commit()?;
970 Ok(())
971 })
972 }
973}