1use std::borrow::Cow;
2use std::cell::{Ref, RefCell};
3use std::path::{Path, PathBuf};
4use std::rc::Rc;
5use std::sync::Arc;
6
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use context_server::manager::{ContextServerManager, ContextServerStatus};
13use context_server::{ContextServerDescriptorRegistry, ContextServerTool};
14use futures::channel::{mpsc, oneshot};
15use futures::future::{self, BoxFuture, Shared};
16use futures::{FutureExt as _, StreamExt as _};
17use gpui::{
18 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
19 Subscription, Task, prelude::*,
20};
21use heed::Database;
22use heed::types::SerdeBincode;
23use language_model::{LanguageModelToolUseId, Role, TokenUsage};
24use project::{Project, ProjectItem, ProjectPath, Worktree};
25use prompt_store::{
26 ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
27 UserRulesContext, WorktreeContext,
28};
29use serde::{Deserialize, Serialize};
30use settings::{Settings as _, SettingsStore};
31use util::ResultExt as _;
32
33use crate::thread::{
34 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
35};
36
/// Candidate rules-file paths, relative to a worktree root, in priority order.
///
/// `ThreadStore::load_worktree_rules_file` uses the first entry that exists as a file in a
/// worktree, and `ThreadStore::handle_project_event` queues a system-prompt reload whenever an
/// entry with one of these paths is updated.
const RULES_FILE_NAMES: [&str; 6] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
];
45
46pub fn init(cx: &mut App) {
47 ThreadsDatabase::init(cx);
48}
49
50/// A system prompt shared by all threads created by this ThreadStore
51#[derive(Clone, Default)]
52pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
53
54impl SharedProjectContext {
55 pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
56 self.0.borrow()
57 }
58}
59
/// Owns the list of saved agent threads and the state shared by the threads it creates: the
/// tool working set, the prompt builder, and the system-prompt project context.
pub struct ThreadStore {
    project: Entity<Project>,
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    /// Source of the default user rules included in the system prompt, if any.
    prompt_store: Option<Entity<PromptStore>>,
    context_server_manager: Entity<ContextServerManager>,
    /// Tool ids registered for each running context server, keyed by server id, so they can be
    /// removed from the working set when the server goes away.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Metadata of saved threads, refreshed from the database by `reload`.
    threads: Vec<SerializedThreadMetadata>,
    project_context: SharedProjectContext,
    /// Requests a system-prompt reload. Created with `mpsc::channel(1)`; `try_send` failures are
    /// ignored, so surplus requests are dropped rather than queued.
    reload_system_prompt_tx: mpsc::Sender<()>,
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}
73
/// Emitted by `ThreadStore` when a worktree rules file or a default user rule fails to load
/// while the system prompt is being rebuilt. The reload continues without that item.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}
79
80impl ThreadStore {
81 pub fn load(
82 project: Entity<Project>,
83 tools: Entity<ToolWorkingSet>,
84 prompt_store: Option<Entity<PromptStore>>,
85 prompt_builder: Arc<PromptBuilder>,
86 cx: &mut App,
87 ) -> Task<Result<Entity<Self>>> {
88 cx.spawn(async move |cx| {
89 let (thread_store, ready_rx) = cx.update(|cx| {
90 let mut option_ready_rx = None;
91 let thread_store = cx.new(|cx| {
92 let (thread_store, ready_rx) =
93 Self::new(project, tools, prompt_builder, prompt_store, cx);
94 option_ready_rx = Some(ready_rx);
95 thread_store
96 });
97 (thread_store, option_ready_rx.take().unwrap())
98 })?;
99 ready_rx.await?;
100 Ok(thread_store)
101 })
102 }
103
    /// Builds the store state and starts its background work.
    ///
    /// Returns the store together with a receiver that is signalled after the first
    /// system-prompt load completes (awaited by `load`). Along the way this:
    /// - re-applies the default profile whenever the global `SettingsStore` changes,
    /// - subscribes to project events (`handle_project_event`),
    /// - when a prompt store is given, queues a system-prompt reload on every prompt update,
    /// - applies the default profile, hooks up context-server events, and starts loading the
    ///   saved-thread list.
    fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        prompt_store: Option<Entity<PromptStore>>,
        cx: &mut Context<Self>,
    ) -> (Self, oneshot::Receiver<()>) {
        let context_server_factory_registry = ContextServerDescriptorRegistry::default_global(cx);
        let context_server_manager = cx.new(|cx| {
            ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
        });

        let mut subscriptions = vec![
            // A settings change may change which profile is the default, so re-apply it.
            cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
                this.load_default_profile(cx);
            }),
            cx.subscribe(&project, Self::handle_project_event),
        ];

        // Default user rules are read from the prompt store, so any prompt update may change the
        // system prompt.
        if let Some(prompt_store) = prompt_store.as_ref() {
            subscriptions.push(cx.subscribe(
                prompt_store,
                |this, _prompt_store, PromptsUpdatedEvent, _cx| {
                    this.enqueue_system_prompt_reload();
                },
            ))
        }

        // This channel and task prevent concurrent and redundant loading of the system prompt.
        // The task loads once right away, signals `ready_tx` after that first load, and then waits
        // for a request on the channel before each later load. It exits as soon as the store
        // entity can no longer be updated.
        let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
        let (ready_tx, ready_rx) = oneshot::channel();
        let mut ready_tx = Some(ready_tx);
        let reload_system_prompt_task = cx.spawn({
            let prompt_store = prompt_store.clone();
            async move |thread_store, cx| {
                loop {
                    let Some(reload_task) = thread_store
                        .update(cx, |thread_store, cx| {
                            thread_store.reload_system_prompt(prompt_store.clone(), cx)
                        })
                        .ok()
                    else {
                        return;
                    };
                    reload_task.await;
                    // Only the first completed load is reported; later loads find `None` here.
                    if let Some(ready_tx) = ready_tx.take() {
                        ready_tx.send(()).ok();
                    }
                    reload_system_prompt_rx.next().await;
                }
            }
        });

        let this = Self {
            project,
            tools,
            prompt_builder,
            prompt_store,
            context_server_manager,
            context_server_tool_ids: HashMap::default(),
            threads: Vec::new(),
            project_context: SharedProjectContext::default(),
            reload_system_prompt_tx,
            _reload_system_prompt_task: reload_system_prompt_task,
            _subscriptions: subscriptions,
        };
        this.load_default_profile(cx);
        this.register_context_server_handlers(cx);
        this.reload(cx).detach_and_log_err(cx);
        (this, ready_rx)
    }
175
176 fn handle_project_event(
177 &mut self,
178 _project: Entity<Project>,
179 event: &project::Event,
180 _cx: &mut Context<Self>,
181 ) {
182 match event {
183 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
184 self.enqueue_system_prompt_reload();
185 }
186 project::Event::WorktreeUpdatedEntries(_, items) => {
187 if items.iter().any(|(path, _, _)| {
188 RULES_FILE_NAMES
189 .iter()
190 .any(|name| path.as_ref() == Path::new(name))
191 }) {
192 self.enqueue_system_prompt_reload();
193 }
194 }
195 _ => {}
196 }
197 }
198
199 fn enqueue_system_prompt_reload(&mut self) {
200 self.reload_system_prompt_tx.try_send(()).ok();
201 }
202
203 // Note that this should only be called from `reload_system_prompt_task`.
204 fn reload_system_prompt(
205 &self,
206 prompt_store: Option<Entity<PromptStore>>,
207 cx: &mut Context<Self>,
208 ) -> Task<()> {
209 let worktrees = self
210 .project
211 .read(cx)
212 .visible_worktrees(cx)
213 .collect::<Vec<_>>();
214 let worktree_tasks = worktrees
215 .into_iter()
216 .map(|worktree| {
217 Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
218 })
219 .collect::<Vec<_>>();
220 let default_user_rules_task = match prompt_store {
221 None => Task::ready(vec![]),
222 Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
223 let prompts = prompt_store.default_prompt_metadata();
224 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
225 let contents = prompt_store.load(prompt_metadata.id, cx);
226 async move { (contents.await, prompt_metadata) }
227 });
228 cx.background_spawn(future::join_all(load_tasks))
229 }),
230 };
231
232 cx.spawn(async move |this, cx| {
233 let (worktrees, default_user_rules) =
234 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
235
236 let worktrees = worktrees
237 .into_iter()
238 .map(|(worktree, rules_error)| {
239 if let Some(rules_error) = rules_error {
240 this.update(cx, |_, cx| cx.emit(rules_error)).ok();
241 }
242 worktree
243 })
244 .collect::<Vec<_>>();
245
246 let default_user_rules = default_user_rules
247 .into_iter()
248 .flat_map(|(contents, prompt_metadata)| match contents {
249 Ok(contents) => Some(UserRulesContext {
250 uuid: match prompt_metadata.id {
251 PromptId::User { uuid } => uuid,
252 PromptId::EditWorkflow => return None,
253 },
254 title: prompt_metadata.title.map(|title| title.to_string()),
255 contents,
256 }),
257 Err(err) => {
258 this.update(cx, |_, cx| {
259 cx.emit(RulesLoadingError {
260 message: format!("{err:?}").into(),
261 });
262 })
263 .ok();
264 None
265 }
266 })
267 .collect::<Vec<_>>();
268
269 this.update(cx, |this, _cx| {
270 *this.project_context.0.borrow_mut() =
271 Some(ProjectContext::new(worktrees, default_user_rules));
272 })
273 .ok();
274 })
275 }
276
277 fn load_worktree_info_for_system_prompt(
278 worktree: Entity<Worktree>,
279 project: Entity<Project>,
280 cx: &mut App,
281 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
282 let root_name = worktree.read(cx).root_name().into();
283
284 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
285 let Some(rules_task) = rules_task else {
286 return Task::ready((
287 WorktreeContext {
288 root_name,
289 rules_file: None,
290 },
291 None,
292 ));
293 };
294
295 cx.spawn(async move |_| {
296 let (rules_file, rules_file_error) = match rules_task.await {
297 Ok(rules_file) => (Some(rules_file), None),
298 Err(err) => (
299 None,
300 Some(RulesLoadingError {
301 message: format!("{err}").into(),
302 }),
303 ),
304 };
305 let worktree_info = WorktreeContext {
306 root_name,
307 rules_file,
308 };
309 (worktree_info, rules_file_error)
310 })
311 }
312
313 fn load_worktree_rules_file(
314 worktree: Entity<Worktree>,
315 project: Entity<Project>,
316 cx: &mut App,
317 ) -> Option<Task<Result<RulesFileContext>>> {
318 let worktree_ref = worktree.read(cx);
319 let worktree_id = worktree_ref.id();
320 let selected_rules_file = RULES_FILE_NAMES
321 .into_iter()
322 .filter_map(|name| {
323 worktree_ref
324 .entry_for_path(name)
325 .filter(|entry| entry.is_file())
326 .map(|entry| entry.path.clone())
327 })
328 .next();
329
330 // Note that Cline supports `.clinerules` being a directory, but that is not currently
331 // supported. This doesn't seem to occur often in GitHub repositories.
332 selected_rules_file.map(|path_in_worktree| {
333 let project_path = ProjectPath {
334 worktree_id,
335 path: path_in_worktree.clone(),
336 };
337 let buffer_task =
338 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
339 let rope_task = cx.spawn(async move |cx| {
340 buffer_task.await?.read_with(cx, |buffer, cx| {
341 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
342 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
343 })?
344 });
345 // Build a string from the rope on a background thread.
346 cx.background_spawn(async move {
347 let (project_entry_id, rope) = rope_task.await?;
348 anyhow::Ok(RulesFileContext {
349 path_in_worktree,
350 text: rope.to_string().trim().to_string(),
351 project_entry_id: project_entry_id.to_usize(),
352 })
353 })
354 })
355 }
356
357 pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
358 self.context_server_manager.clone()
359 }
360
361 pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
362 &self.prompt_store
363 }
364
365 pub fn tools(&self) -> Entity<ToolWorkingSet> {
366 self.tools.clone()
367 }
368
369 /// Returns the number of threads.
370 pub fn thread_count(&self) -> usize {
371 self.threads.len()
372 }
373
374 pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
375 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
376 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
377 threads
378 }
379
380 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
381 cx.new(|cx| {
382 Thread::new(
383 self.project.clone(),
384 self.tools.clone(),
385 self.prompt_builder.clone(),
386 self.project_context.clone(),
387 cx,
388 )
389 })
390 }
391
392 pub fn open_thread(
393 &self,
394 id: &ThreadId,
395 cx: &mut Context<Self>,
396 ) -> Task<Result<Entity<Thread>>> {
397 let id = id.clone();
398 let database_future = ThreadsDatabase::global_future(cx);
399 cx.spawn(async move |this, cx| {
400 let database = database_future.await.map_err(|err| anyhow!(err))?;
401 let thread = database
402 .try_find_thread(id.clone())
403 .await?
404 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
405
406 let thread = this.update(cx, |this, cx| {
407 cx.new(|cx| {
408 Thread::deserialize(
409 id.clone(),
410 thread,
411 this.project.clone(),
412 this.tools.clone(),
413 this.prompt_builder.clone(),
414 this.project_context.clone(),
415 cx,
416 )
417 })
418 })?;
419
420 Ok(thread)
421 })
422 }
423
424 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
425 let (metadata, serialized_thread) =
426 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
427
428 let database_future = ThreadsDatabase::global_future(cx);
429 cx.spawn(async move |this, cx| {
430 let serialized_thread = serialized_thread.await?;
431 let database = database_future.await.map_err(|err| anyhow!(err))?;
432 database.save_thread(metadata, serialized_thread).await?;
433
434 this.update(cx, |this, cx| this.reload(cx))?.await
435 })
436 }
437
438 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
439 let id = id.clone();
440 let database_future = ThreadsDatabase::global_future(cx);
441 cx.spawn(async move |this, cx| {
442 let database = database_future.await.map_err(|err| anyhow!(err))?;
443 database.delete_thread(id.clone()).await?;
444
445 this.update(cx, |this, cx| {
446 this.threads.retain(|thread| thread.id != id);
447 cx.notify();
448 })
449 })
450 }
451
452 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
453 let database_future = ThreadsDatabase::global_future(cx);
454 cx.spawn(async move |this, cx| {
455 let threads = database_future
456 .await
457 .map_err(|err| anyhow!(err))?
458 .list_threads()
459 .await?;
460
461 this.update(cx, |this, cx| {
462 this.threads = threads;
463 cx.notify();
464 })
465 })
466 }
467
468 fn load_default_profile(&self, cx: &mut Context<Self>) {
469 let assistant_settings = AssistantSettings::get_global(cx);
470
471 self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
472 }
473
474 pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
475 let assistant_settings = AssistantSettings::get_global(cx);
476
477 if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
478 self.load_profile(profile.clone(), cx);
479 }
480 }
481
482 pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
483 self.tools.update(cx, |tools, cx| {
484 tools.disable_all_tools(cx);
485 tools.enable(
486 ToolSource::Native,
487 &profile
488 .tools
489 .iter()
490 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
491 .collect::<Vec<_>>(),
492 cx,
493 );
494 });
495
496 if profile.enable_all_context_servers {
497 for context_server in self.context_server_manager.read(cx).all_servers() {
498 self.tools.update(cx, |tools, cx| {
499 tools.enable_source(
500 ToolSource::ContextServer {
501 id: context_server.id().into(),
502 },
503 cx,
504 );
505 });
506 }
507 // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
508 for (context_server_id, preset) in &profile.context_servers {
509 self.tools.update(cx, |tools, cx| {
510 tools.disable(
511 ToolSource::ContextServer {
512 id: context_server_id.clone().into(),
513 },
514 &preset
515 .tools
516 .iter()
517 .filter_map(|(tool, enabled)| (!enabled).then(|| tool.clone()))
518 .collect::<Vec<_>>(),
519 cx,
520 )
521 })
522 }
523 } else {
524 for (context_server_id, preset) in &profile.context_servers {
525 self.tools.update(cx, |tools, cx| {
526 tools.enable(
527 ToolSource::ContextServer {
528 id: context_server_id.clone().into(),
529 },
530 &preset
531 .tools
532 .iter()
533 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
534 .collect::<Vec<_>>(),
535 cx,
536 )
537 })
538 }
539 }
540 }
541
542 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
543 cx.subscribe(
544 &self.context_server_manager.clone(),
545 Self::handle_context_server_event,
546 )
547 .detach();
548 }
549
550 fn handle_context_server_event(
551 &mut self,
552 context_server_manager: Entity<ContextServerManager>,
553 event: &context_server::manager::Event,
554 cx: &mut Context<Self>,
555 ) {
556 let tool_working_set = self.tools.clone();
557 match event {
558 context_server::manager::Event::ServerStatusChanged { server_id, status } => {
559 match status {
560 Some(ContextServerStatus::Running) => {
561 if let Some(server) = context_server_manager.read(cx).get_server(server_id)
562 {
563 let context_server_manager = context_server_manager.clone();
564 cx.spawn({
565 let server = server.clone();
566 let server_id = server_id.clone();
567 async move |this, cx| {
568 let Some(protocol) = server.client() else {
569 return;
570 };
571
572 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
573 if let Some(tools) = protocol.list_tools().await.log_err() {
574 let tool_ids = tool_working_set
575 .update(cx, |tool_working_set, _| {
576 tools
577 .tools
578 .into_iter()
579 .map(|tool| {
580 log::info!(
581 "registering context server tool: {:?}",
582 tool.name
583 );
584 tool_working_set.insert(Arc::new(
585 ContextServerTool::new(
586 context_server_manager.clone(),
587 server.id(),
588 tool,
589 ),
590 ))
591 })
592 .collect::<Vec<_>>()
593 })
594 .log_err();
595
596 if let Some(tool_ids) = tool_ids {
597 this.update(cx, |this, cx| {
598 this.context_server_tool_ids
599 .insert(server_id, tool_ids);
600 this.load_default_profile(cx);
601 })
602 .log_err();
603 }
604 }
605 }
606 }
607 })
608 .detach();
609 }
610 }
611 None => {
612 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
613 tool_working_set.update(cx, |tool_working_set, _| {
614 tool_working_set.remove(&tool_ids);
615 });
616 self.load_default_profile(cx);
617 }
618 }
619 _ => {}
620 }
621 }
622 }
623 }
624}
625
/// Summary of a saved thread, built by `ThreadsDatabase::list_threads` from the stored thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    /// Used by `ThreadStore::reverse_chronological_threads` to order threads newest first.
    pub updated_at: DateTime<Utc>,
}
632
/// The current (`SerializedThread::VERSION`) stored representation of a thread.
///
/// Fields marked `#[serde(default)]` may be missing from older stored threads and then fall back
/// to their `Default` values. Stored bytes should be decoded with `SerializedThread::from_json`,
/// which also accepts and upgrades older formats.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    /// Format version string, checked by `from_json`.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
}
652
/// Records which language model a thread used, as a provider string and a model string.
/// NOTE(review): presumably these are the provider id and model id — confirm against the writer.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
658
659impl SerializedThread {
660 pub const VERSION: &'static str = "0.2.0";
661
662 pub fn from_json(json: &[u8]) -> Result<Self> {
663 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
664 match saved_thread_json.get("version") {
665 Some(serde_json::Value::String(version)) => match version.as_str() {
666 SerializedThreadV0_1_0::VERSION => {
667 let saved_thread =
668 serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
669 Ok(saved_thread.upgrade())
670 }
671 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
672 saved_thread_json,
673 )?),
674 _ => Err(anyhow!(
675 "unrecognized serialized thread version: {}",
676 version
677 )),
678 },
679 None => {
680 let saved_thread =
681 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
682 Ok(saved_thread.upgrade())
683 }
684 version => Err(anyhow!(
685 "unrecognized serialized thread version: {:?}",
686 version
687 )),
688 }
689 }
690}
691
/// Version 0.1.0 of the serialized thread format.
///
/// `upgrade` treats tool results found on a user message as belonging to the preceding
/// (assistant) message, which is where the current format keeps them.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);
698
699impl SerializedThreadV0_1_0 {
700 pub const VERSION: &'static str = "0.1.0";
701
702 pub fn upgrade(self) -> SerializedThread {
703 debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
704
705 let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
706
707 for message in self.0.messages {
708 if message.role == Role::User && !message.tool_results.is_empty() {
709 if let Some(last_message) = messages.last_mut() {
710 debug_assert!(last_message.role == Role::Assistant);
711
712 last_message.tool_results = message.tool_results;
713 continue;
714 }
715 }
716
717 messages.push(message);
718 }
719
720 SerializedThread { messages, ..self.0 }
721 }
722}
723
/// A single message of a serialized thread. All fields except `id` and `role` default to empty
/// when missing from the stored data.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Message body as an ordered list of text / thinking segments.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Each result refers to a tool use through `tool_use_id`.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    /// NOTE(review): presumably the textual context attached to the message — confirm.
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
}
739
/// One segment of a message body, internally tagged by a `"type"` field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        /// Omitted from the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // NOTE(review): unlike the other variants this one has no `#[serde(rename)]`, so its tag is
    // serialized as "RedactedThinking". Renaming it now would need a `#[serde(alias)]` to keep
    // already-stored threads readable.
    RedactedThinking {
        data: Vec<u8>,
    },
}
757
/// A stored tool invocation: the tool's name and its JSON input.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
764
/// A stored tool result, linked to its `SerializedToolUse` by `tool_use_id`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: Arc<str>,
}
771
/// Thread format from before versioning: JSON with no `version` field, whose messages hold
/// plain text. `SerializedThread::from_json` parses this when `version` is absent.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
780
781impl LegacySerializedThread {
782 pub fn upgrade(self) -> SerializedThread {
783 SerializedThread {
784 version: SerializedThread::VERSION.to_string(),
785 summary: self.summary,
786 updated_at: self.updated_at,
787 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
788 initial_project_snapshot: self.initial_project_snapshot,
789 cumulative_token_usage: TokenUsage::default(),
790 request_token_usage: Vec::new(),
791 detailed_summary_state: DetailedSummaryState::default(),
792 exceeded_window_error: None,
793 model: None,
794 }
795 }
796}
797
/// A message in the legacy (unversioned) thread format, whose body is a single plain-text string.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
808
809impl LegacySerializedMessage {
810 fn upgrade(self) -> SerializedMessage {
811 SerializedMessage {
812 id: self.id,
813 role: self.role,
814 segments: vec![SerializedMessageSegment::Text { text: self.text }],
815 tool_uses: self.tool_uses,
816 tool_results: self.tool_results,
817 context: String::new(),
818 creases: Vec::new(),
819 }
820 }
821}
822
/// A stored crease of a message, described by a `start..end` range plus an icon and a label.
/// NOTE(review): presumably a collapsed region of the message text, with offsets into that
/// text — confirm the unit (bytes vs. chars) against the code that writes it.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedCrease {
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}
830
/// Global holding the shared future that opens the threads database (set by
/// `ThreadsDatabase::init`). Both the database and the error are wrapped in `Arc` so the `Shared`
/// future's output can be cloned to every awaiting caller.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
836
/// heed (LMDB) store of serialized threads. Keys are `ThreadId`s encoded with bincode. Values
/// are JSON, via the `BytesEncode`/`BytesDecode` impls on `SerializedThread`. Every operation
/// runs on `executor`.
pub(crate) struct ThreadsDatabase {
    executor: BackgroundExecutor,
    env: heed::Env,
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
842
843impl heed::BytesEncode<'_> for SerializedThread {
844 type EItem = SerializedThread;
845
846 fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
847 serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
848 }
849}
850
851impl<'a> heed::BytesDecode<'a> for SerializedThread {
852 type DItem = SerializedThread;
853
854 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
855 // We implement this type manually because we want to call `SerializedThread::from_json`,
856 // instead of the Deserialize trait implementation for `SerializedThread`.
857 SerializedThread::from_json(bytes).map_err(Into::into)
858 }
859}
860
861impl ThreadsDatabase {
862 fn global_future(
863 cx: &mut App,
864 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
865 GlobalThreadsDatabase::global(cx).0.clone()
866 }
867
868 fn init(cx: &mut App) {
869 let executor = cx.background_executor().clone();
870 let database_future = executor
871 .spawn({
872 let executor = executor.clone();
873 let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
874 async move { ThreadsDatabase::new(database_path, executor) }
875 })
876 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
877 .boxed()
878 .shared();
879
880 cx.set_global(GlobalThreadsDatabase(database_future));
881 }
882
    /// Opens (creating the directory if needed) the LMDB environment at `path` and its `threads`
    /// database.
    ///
    /// # Errors
    ///
    /// Fails if the directory cannot be created, the environment cannot be opened, or the
    /// `threads` database cannot be created.
    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
        std::fs::create_dir_all(&path)?;

        // LMDB map size: the upper bound on how large the database can grow.
        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
        // SAFETY: heed marks `open` unsafe because the memory-mapped environment must not be
        // opened more than once in a process or modified outside LMDB while mapped.
        // NOTE(review): this appears to rely on `ThreadsDatabase::init` being the only caller,
        // opening the path once through a single shared global future. Confirm nothing else opens
        // this path.
        let env = unsafe {
            heed::EnvOpenOptions::new()
                .map_size(ONE_GB_IN_BYTES)
                .max_dbs(1)
                .open(path)?
        };

        let mut txn = env.write_txn()?;
        let threads = env.create_database(&mut txn, Some("threads"))?;
        txn.commit()?;

        Ok(Self {
            executor,
            env,
            threads,
        })
    }
904
905 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
906 let env = self.env.clone();
907 let threads = self.threads;
908
909 self.executor.spawn(async move {
910 let txn = env.read_txn()?;
911 let mut iter = threads.iter(&txn)?;
912 let mut threads = Vec::new();
913 while let Some((key, value)) = iter.next().transpose()? {
914 threads.push(SerializedThreadMetadata {
915 id: key,
916 summary: value.summary,
917 updated_at: value.updated_at,
918 });
919 }
920
921 Ok(threads)
922 })
923 }
924
925 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
926 let env = self.env.clone();
927 let threads = self.threads;
928
929 self.executor.spawn(async move {
930 let txn = env.read_txn()?;
931 let thread = threads.get(&txn, &id)?;
932 Ok(thread)
933 })
934 }
935
936 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
937 let env = self.env.clone();
938 let threads = self.threads;
939
940 self.executor.spawn(async move {
941 let mut txn = env.write_txn()?;
942 threads.put(&mut txn, &id, &thread)?;
943 txn.commit()?;
944 Ok(())
945 })
946 }
947
948 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
949 let env = self.env.clone();
950 let threads = self.threads;
951
952 self.executor.spawn(async move {
953 let mut txn = env.write_txn()?;
954 threads.delete(&mut txn, &id)?;
955 txn.commit()?;
956 Ok(())
957 })
958 }
959}