1use std::borrow::Cow;
2use std::cell::{Ref, RefCell};
3use std::path::{Path, PathBuf};
4use std::rc::Rc;
5use std::sync::Arc;
6
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use context_server::manager::ContextServerManager;
13use context_server::{ContextServerFactoryRegistry, ContextServerTool};
14use fs::Fs;
15use futures::channel::{mpsc, oneshot};
16use futures::future::{self, BoxFuture, Shared};
17use futures::{FutureExt as _, StreamExt as _};
18use gpui::{
19 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
20 Subscription, Task, prelude::*,
21};
22use heed::Database;
23use heed::types::SerdeBincode;
24use language_model::{LanguageModelToolUseId, Role, TokenUsage};
25use project::{Project, Worktree};
26use prompt_store::{
27 DefaultUserRulesContext, ProjectContext, PromptBuilder, PromptStore, PromptsUpdatedEvent,
28 RulesFileContext, WorktreeContext,
29};
30use serde::{Deserialize, Serialize};
31use settings::{Settings as _, SettingsStore};
32use util::ResultExt as _;
33
34use crate::thread::{
35 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
36};
37
/// Worktree-relative paths of rules files understood by the agent.
///
/// The order is the lookup priority: when a worktree contains several of these files, only the
/// first one (in this order) that exists as a regular file is loaded into the system prompt.
const RULES_FILE_NAMES: [&str; 6] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
];
46
47pub fn init(cx: &mut App) {
48 ThreadsDatabase::init(cx);
49}
50
51/// A system prompt shared by all threads created by this ThreadStore
52#[derive(Clone, Default)]
53pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
54
55impl SharedProjectContext {
56 pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
57 self.0.borrow()
58 }
59}
60
/// Owns the metadata of a project's saved agent threads, the tool working set those threads use,
/// and the project context from which their system prompt is built.
pub struct ThreadStore {
    /// Project whose visible worktrees feed the system prompt and which new threads are bound to.
    project: Entity<Project>,
    /// Tools available to threads; enabled/disabled according to the active agent profile.
    tools: Entity<ToolWorkingSet>,
    /// Passed to every thread this store creates or deserializes.
    prompt_builder: Arc<PromptBuilder>,
    /// Context-server manager created by this store; its events register/unregister tools.
    context_server_manager: Entity<ContextServerManager>,
    /// Tool IDs registered for each started context server, so they can be removed on stop.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Cached metadata of saved threads, refreshed by `reload`.
    threads: Vec<SerializedThreadMetadata>,
    /// Project context shared with every thread; replaced on each system-prompt reload.
    project_context: SharedProjectContext,
    /// Bounded channel used to request a system-prompt reload; full-channel sends are dropped.
    reload_system_prompt_tx: mpsc::Sender<()>,
    /// Background loop that performs system-prompt reloads; held for the store's lifetime.
    _reload_system_prompt_task: Task<()>,
    /// Settings, project, and prompt-store subscriptions; held for the store's lifetime.
    _subscriptions: Vec<Subscription>,
}
73
74pub struct RulesLoadingError {
75 pub message: SharedString,
76}
77
78impl EventEmitter<RulesLoadingError> for ThreadStore {}
79
80impl ThreadStore {
81 pub fn load(
82 project: Entity<Project>,
83 tools: Entity<ToolWorkingSet>,
84 prompt_builder: Arc<PromptBuilder>,
85 cx: &mut App,
86 ) -> Task<Result<Entity<Self>>> {
87 let prompt_store = PromptStore::global(cx);
88 cx.spawn(async move |cx| {
89 let prompt_store = prompt_store.await.ok();
90 let (thread_store, ready_rx) = cx.update(|cx| {
91 let mut option_ready_rx = None;
92 let thread_store = cx.new(|cx| {
93 let (thread_store, ready_rx) =
94 Self::new(project, tools, prompt_builder, prompt_store, cx);
95 option_ready_rx = Some(ready_rx);
96 thread_store
97 });
98 (thread_store, option_ready_rx.take().unwrap())
99 })?;
100 ready_rx.await?;
101 Ok(thread_store)
102 })
103 }
104
105 fn new(
106 project: Entity<Project>,
107 tools: Entity<ToolWorkingSet>,
108 prompt_builder: Arc<PromptBuilder>,
109 prompt_store: Option<Entity<PromptStore>>,
110 cx: &mut Context<Self>,
111 ) -> (Self, oneshot::Receiver<()>) {
112 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
113 let context_server_manager = cx.new(|cx| {
114 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
115 });
116
117 let mut subscriptions = vec![
118 cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
119 this.load_default_profile(cx);
120 }),
121 cx.subscribe(&project, Self::handle_project_event),
122 ];
123
124 if let Some(prompt_store) = prompt_store.as_ref() {
125 subscriptions.push(cx.subscribe(
126 prompt_store,
127 |this, _prompt_store, PromptsUpdatedEvent, _cx| {
128 this.enqueue_system_prompt_reload();
129 },
130 ))
131 }
132
133 // This channel and task prevent concurrent and redundant loading of the system prompt.
134 let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
135 let (ready_tx, ready_rx) = oneshot::channel();
136 let mut ready_tx = Some(ready_tx);
137 let reload_system_prompt_task = cx.spawn({
138 async move |thread_store, cx| {
139 loop {
140 let Some(reload_task) = thread_store
141 .update(cx, |thread_store, cx| {
142 thread_store.reload_system_prompt(prompt_store.clone(), cx)
143 })
144 .ok()
145 else {
146 return;
147 };
148 reload_task.await;
149 if let Some(ready_tx) = ready_tx.take() {
150 ready_tx.send(()).ok();
151 }
152 reload_system_prompt_rx.next().await;
153 }
154 }
155 });
156
157 let this = Self {
158 project,
159 tools,
160 prompt_builder,
161 context_server_manager,
162 context_server_tool_ids: HashMap::default(),
163 threads: Vec::new(),
164 project_context: SharedProjectContext::default(),
165 reload_system_prompt_tx,
166 _reload_system_prompt_task: reload_system_prompt_task,
167 _subscriptions: subscriptions,
168 };
169 this.load_default_profile(cx);
170 this.register_context_server_handlers(cx);
171 this.reload(cx).detach_and_log_err(cx);
172 (this, ready_rx)
173 }
174
175 fn handle_project_event(
176 &mut self,
177 _project: Entity<Project>,
178 event: &project::Event,
179 _cx: &mut Context<Self>,
180 ) {
181 match event {
182 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
183 self.enqueue_system_prompt_reload();
184 }
185 project::Event::WorktreeUpdatedEntries(_, items) => {
186 if items.iter().any(|(path, _, _)| {
187 RULES_FILE_NAMES
188 .iter()
189 .any(|name| path.as_ref() == Path::new(name))
190 }) {
191 self.enqueue_system_prompt_reload();
192 }
193 }
194 _ => {}
195 }
196 }
197
198 fn enqueue_system_prompt_reload(&mut self) {
199 self.reload_system_prompt_tx.try_send(()).ok();
200 }
201
202 // Note that this should only be called from `reload_system_prompt_task`.
203 fn reload_system_prompt(
204 &self,
205 prompt_store: Option<Entity<PromptStore>>,
206 cx: &mut Context<Self>,
207 ) -> Task<()> {
208 let project = self.project.read(cx);
209 let worktree_tasks = project
210 .visible_worktrees(cx)
211 .map(|worktree| {
212 Self::load_worktree_info_for_system_prompt(
213 project.fs().clone(),
214 worktree.read(cx),
215 cx,
216 )
217 })
218 .collect::<Vec<_>>();
219 let default_user_rules_task = match prompt_store {
220 None => Task::ready(vec![]),
221 Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
222 let prompts = prompt_store.default_prompt_metadata();
223 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
224 let contents = prompt_store.load(prompt_metadata.id, cx);
225 async move { (contents.await, prompt_metadata) }
226 });
227 cx.background_spawn(future::join_all(load_tasks))
228 }),
229 };
230
231 cx.spawn(async move |this, cx| {
232 let (worktrees, default_user_rules) =
233 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
234
235 let worktrees = worktrees
236 .into_iter()
237 .map(|(worktree, rules_error)| {
238 if let Some(rules_error) = rules_error {
239 this.update(cx, |_, cx| cx.emit(rules_error)).ok();
240 }
241 worktree
242 })
243 .collect::<Vec<_>>();
244
245 let default_user_rules = default_user_rules
246 .into_iter()
247 .flat_map(|(contents, prompt_metadata)| match contents {
248 Ok(contents) => Some(DefaultUserRulesContext {
249 title: prompt_metadata.title.map(|title| title.to_string()),
250 contents,
251 }),
252 Err(err) => {
253 this.update(cx, |_, cx| {
254 cx.emit(RulesLoadingError {
255 message: format!("{err:?}").into(),
256 });
257 })
258 .ok();
259 None
260 }
261 })
262 .collect::<Vec<_>>();
263
264 this.update(cx, |this, _cx| {
265 *this.project_context.0.borrow_mut() =
266 Some(ProjectContext::new(worktrees, default_user_rules));
267 })
268 .ok();
269 })
270 }
271
272 fn load_worktree_info_for_system_prompt(
273 fs: Arc<dyn Fs>,
274 worktree: &Worktree,
275 cx: &App,
276 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
277 let root_name = worktree.root_name().into();
278 let abs_path = worktree.abs_path();
279
280 let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
281 let Some(rules_task) = rules_task else {
282 return Task::ready((
283 WorktreeContext {
284 root_name,
285 abs_path,
286 rules_file: None,
287 },
288 None,
289 ));
290 };
291
292 cx.spawn(async move |_| {
293 let (rules_file, rules_file_error) = match rules_task.await {
294 Ok(rules_file) => (Some(rules_file), None),
295 Err(err) => (
296 None,
297 Some(RulesLoadingError {
298 message: format!("{err}").into(),
299 }),
300 ),
301 };
302 let worktree_info = WorktreeContext {
303 root_name,
304 abs_path,
305 rules_file,
306 };
307 (worktree_info, rules_file_error)
308 })
309 }
310
311 fn load_worktree_rules_file(
312 fs: Arc<dyn Fs>,
313 worktree: &Worktree,
314 cx: &App,
315 ) -> Option<Task<Result<RulesFileContext>>> {
316 let selected_rules_file = RULES_FILE_NAMES
317 .into_iter()
318 .filter_map(|name| {
319 worktree
320 .entry_for_path(name)
321 .filter(|entry| entry.is_file())
322 .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
323 })
324 .next();
325
326 // Note that Cline supports `.clinerules` being a directory, but that is not currently
327 // supported. This doesn't seem to occur often in GitHub repositories.
328 selected_rules_file.map(|(path_in_worktree, abs_path)| {
329 let fs = fs.clone();
330 cx.background_spawn(async move {
331 let abs_path = abs_path?;
332 let text = fs.load(&abs_path).await.with_context(|| {
333 format!("Failed to load assistant rules file {:?}", abs_path)
334 })?;
335 anyhow::Ok(RulesFileContext {
336 path_in_worktree,
337 abs_path: abs_path.into(),
338 text: text.trim().to_string(),
339 })
340 })
341 })
342 }
343
344 pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
345 self.context_server_manager.clone()
346 }
347
348 pub fn tools(&self) -> Entity<ToolWorkingSet> {
349 self.tools.clone()
350 }
351
352 /// Returns the number of threads.
353 pub fn thread_count(&self) -> usize {
354 self.threads.len()
355 }
356
357 pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
358 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
359 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
360 threads
361 }
362
363 pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
364 self.threads().into_iter().take(limit).collect()
365 }
366
367 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
368 cx.new(|cx| {
369 Thread::new(
370 self.project.clone(),
371 self.tools.clone(),
372 self.prompt_builder.clone(),
373 self.project_context.clone(),
374 cx,
375 )
376 })
377 }
378
379 pub fn open_thread(
380 &self,
381 id: &ThreadId,
382 cx: &mut Context<Self>,
383 ) -> Task<Result<Entity<Thread>>> {
384 let id = id.clone();
385 let database_future = ThreadsDatabase::global_future(cx);
386 cx.spawn(async move |this, cx| {
387 let database = database_future.await.map_err(|err| anyhow!(err))?;
388 let thread = database
389 .try_find_thread(id.clone())
390 .await?
391 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
392
393 let thread = this.update(cx, |this, cx| {
394 cx.new(|cx| {
395 Thread::deserialize(
396 id.clone(),
397 thread,
398 this.project.clone(),
399 this.tools.clone(),
400 this.prompt_builder.clone(),
401 this.project_context.clone(),
402 cx,
403 )
404 })
405 })?;
406
407 Ok(thread)
408 })
409 }
410
411 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
412 let (metadata, serialized_thread) =
413 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
414
415 let database_future = ThreadsDatabase::global_future(cx);
416 cx.spawn(async move |this, cx| {
417 let serialized_thread = serialized_thread.await?;
418 let database = database_future.await.map_err(|err| anyhow!(err))?;
419 database.save_thread(metadata, serialized_thread).await?;
420
421 this.update(cx, |this, cx| this.reload(cx))?.await
422 })
423 }
424
425 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
426 let id = id.clone();
427 let database_future = ThreadsDatabase::global_future(cx);
428 cx.spawn(async move |this, cx| {
429 let database = database_future.await.map_err(|err| anyhow!(err))?;
430 database.delete_thread(id.clone()).await?;
431
432 this.update(cx, |this, cx| {
433 this.threads.retain(|thread| thread.id != id);
434 cx.notify();
435 })
436 })
437 }
438
439 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
440 let database_future = ThreadsDatabase::global_future(cx);
441 cx.spawn(async move |this, cx| {
442 let threads = database_future
443 .await
444 .map_err(|err| anyhow!(err))?
445 .list_threads()
446 .await?;
447
448 this.update(cx, |this, cx| {
449 this.threads = threads;
450 cx.notify();
451 })
452 })
453 }
454
455 fn load_default_profile(&self, cx: &mut Context<Self>) {
456 let assistant_settings = AssistantSettings::get_global(cx);
457
458 self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
459 }
460
461 pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
462 let assistant_settings = AssistantSettings::get_global(cx);
463
464 if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
465 self.load_profile(profile.clone(), cx);
466 }
467 }
468
469 pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
470 self.tools.update(cx, |tools, cx| {
471 tools.disable_all_tools(cx);
472 tools.enable(
473 ToolSource::Native,
474 &profile
475 .tools
476 .iter()
477 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
478 .collect::<Vec<_>>(),
479 cx,
480 );
481 });
482
483 if profile.enable_all_context_servers {
484 for context_server in self.context_server_manager.read(cx).all_servers() {
485 self.tools.update(cx, |tools, cx| {
486 tools.enable_source(
487 ToolSource::ContextServer {
488 id: context_server.id().into(),
489 },
490 cx,
491 );
492 });
493 }
494 } else {
495 for (context_server_id, preset) in &profile.context_servers {
496 self.tools.update(cx, |tools, cx| {
497 tools.enable(
498 ToolSource::ContextServer {
499 id: context_server_id.clone().into(),
500 },
501 &preset
502 .tools
503 .iter()
504 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
505 .collect::<Vec<_>>(),
506 cx,
507 )
508 })
509 }
510 }
511 }
512
513 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
514 cx.subscribe(
515 &self.context_server_manager.clone(),
516 Self::handle_context_server_event,
517 )
518 .detach();
519 }
520
521 fn handle_context_server_event(
522 &mut self,
523 context_server_manager: Entity<ContextServerManager>,
524 event: &context_server::manager::Event,
525 cx: &mut Context<Self>,
526 ) {
527 let tool_working_set = self.tools.clone();
528 match event {
529 context_server::manager::Event::ServerStarted { server_id } => {
530 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
531 let context_server_manager = context_server_manager.clone();
532 cx.spawn({
533 let server = server.clone();
534 let server_id = server_id.clone();
535 async move |this, cx| {
536 let Some(protocol) = server.client() else {
537 return;
538 };
539
540 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
541 if let Some(tools) = protocol.list_tools().await.log_err() {
542 let tool_ids = tool_working_set
543 .update(cx, |tool_working_set, _| {
544 tools
545 .tools
546 .into_iter()
547 .map(|tool| {
548 log::info!(
549 "registering context server tool: {:?}",
550 tool.name
551 );
552 tool_working_set.insert(Arc::new(
553 ContextServerTool::new(
554 context_server_manager.clone(),
555 server.id(),
556 tool,
557 ),
558 ))
559 })
560 .collect::<Vec<_>>()
561 })
562 .log_err();
563
564 if let Some(tool_ids) = tool_ids {
565 this.update(cx, |this, cx| {
566 this.context_server_tool_ids
567 .insert(server_id, tool_ids);
568 this.load_default_profile(cx);
569 })
570 .log_err();
571 }
572 }
573 }
574 }
575 })
576 .detach();
577 }
578 }
579 context_server::manager::Event::ServerStopped { server_id } => {
580 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
581 tool_working_set.update(cx, |tool_working_set, _| {
582 tool_working_set.remove(&tool_ids);
583 });
584 self.load_default_profile(cx);
585 }
586 }
587 }
588 }
589}
590
/// Lightweight summary of a saved thread, built from each database entry by
/// `ThreadsDatabase::list_threads` and used to list threads without loading their messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Database key of the thread.
    pub id: ThreadId,
    /// Thread summary (title) as stored in the serialized thread.
    pub summary: SharedString,
    /// Last-updated timestamp; `ThreadStore::threads` sorts by this, newest first.
    pub updated_at: DateTime<Utc>,
}
597
/// Persisted form of a thread, stored as JSON in the threads database.
///
/// Fields marked `#[serde(default)]` may be absent from the stored JSON; missing values fall
/// back to their `Default`.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    /// Format version; see [`SerializedThread::VERSION`].
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
}
615
616impl SerializedThread {
617 pub const VERSION: &'static str = "0.1.0";
618
619 pub fn from_json(json: &[u8]) -> Result<Self> {
620 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
621 match saved_thread_json.get("version") {
622 Some(serde_json::Value::String(version)) => match version.as_str() {
623 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
624 saved_thread_json,
625 )?),
626 _ => Err(anyhow!(
627 "unrecognized serialized thread version: {}",
628 version
629 )),
630 },
631 None => {
632 let saved_thread =
633 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
634 Ok(saved_thread.upgrade())
635 }
636 version => Err(anyhow!(
637 "unrecognized serialized thread version: {:?}",
638 version
639 )),
640 }
641 }
642}
643
/// Persisted form of a single thread message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (text and thinking).
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    /// Extra context text attached to the message; empty for upgraded legacy messages.
    #[serde(default)]
    pub context: String,
}
657
658#[derive(Debug, Serialize, Deserialize)]
659#[serde(tag = "type")]
660pub enum SerializedMessageSegment {
661 #[serde(rename = "text")]
662 Text { text: String },
663 #[serde(rename = "thinking")]
664 Thinking { text: String },
665}
666
/// Persisted record of a tool invocation requested in a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    /// Name of the tool that was invoked.
    pub name: SharedString,
    /// Tool input as arbitrary JSON.
    pub input: serde_json::Value,
}
673
/// Persisted result of a tool invocation, linked to its request by `tool_use_id`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the tool reported an error.
    pub is_error: bool,
    pub content: Arc<str>,
}
680
/// Thread layout used before `SerializedThread` gained a `version` field; `from_json` parses
/// payloads without a `version` key into this and upgrades them.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
689
690impl LegacySerializedThread {
691 pub fn upgrade(self) -> SerializedThread {
692 SerializedThread {
693 version: SerializedThread::VERSION.to_string(),
694 summary: self.summary,
695 updated_at: self.updated_at,
696 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
697 initial_project_snapshot: self.initial_project_snapshot,
698 cumulative_token_usage: TokenUsage::default(),
699 request_token_usage: Vec::new(),
700 detailed_summary_state: DetailedSummaryState::default(),
701 exceeded_window_error: None,
702 }
703 }
704}
705
/// Message layout of `LegacySerializedThread`: the content is a single text blob rather than
/// a list of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
716
717impl LegacySerializedMessage {
718 fn upgrade(self) -> SerializedMessage {
719 SerializedMessage {
720 id: self.id,
721 role: self.role,
722 segments: vec![SerializedMessageSegment::Text { text: self.text }],
723 tool_uses: self.tool_uses,
724 tool_results: self.tool_results,
725 context: String::new(),
726 }
727 }
728}
729
/// gpui global holding a shared future that resolves to the process-wide threads database, or
/// to the error from opening it.
///
/// The future is `Shared` so every caller of `ThreadsDatabase::global_future` awaits the same
/// open attempt; both the database and the error are wrapped in `Arc` so the output is `Clone`.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
735
/// heed (LMDB) store of serialized threads, keyed by [`ThreadId`].
pub(crate) struct ThreadsDatabase {
    /// Executor on which every transaction is run.
    executor: BackgroundExecutor,
    env: heed::Env,
    /// `ThreadId` keys are bincode-encoded; values go through the JSON
    /// `BytesEncode`/`BytesDecode` impls for `SerializedThread`.
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
741
742impl heed::BytesEncode<'_> for SerializedThread {
743 type EItem = SerializedThread;
744
745 fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
746 serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
747 }
748}
749
750impl<'a> heed::BytesDecode<'a> for SerializedThread {
751 type DItem = SerializedThread;
752
753 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
754 // We implement this type manually because we want to call `SerializedThread::from_json`,
755 // instead of the Deserialize trait implementation for `SerializedThread`.
756 SerializedThread::from_json(bytes).map_err(Into::into)
757 }
758}
759
760impl ThreadsDatabase {
761 fn global_future(
762 cx: &mut App,
763 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
764 GlobalThreadsDatabase::global(cx).0.clone()
765 }
766
767 fn init(cx: &mut App) {
768 let executor = cx.background_executor().clone();
769 let database_future = executor
770 .spawn({
771 let executor = executor.clone();
772 let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
773 async move { ThreadsDatabase::new(database_path, executor) }
774 })
775 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
776 .boxed()
777 .shared();
778
779 cx.set_global(GlobalThreadsDatabase(database_future));
780 }
781
782 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
783 std::fs::create_dir_all(&path)?;
784
785 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
786 let env = unsafe {
787 heed::EnvOpenOptions::new()
788 .map_size(ONE_GB_IN_BYTES)
789 .max_dbs(1)
790 .open(path)?
791 };
792
793 let mut txn = env.write_txn()?;
794 let threads = env.create_database(&mut txn, Some("threads"))?;
795 txn.commit()?;
796
797 Ok(Self {
798 executor,
799 env,
800 threads,
801 })
802 }
803
804 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
805 let env = self.env.clone();
806 let threads = self.threads;
807
808 self.executor.spawn(async move {
809 let txn = env.read_txn()?;
810 let mut iter = threads.iter(&txn)?;
811 let mut threads = Vec::new();
812 while let Some((key, value)) = iter.next().transpose()? {
813 threads.push(SerializedThreadMetadata {
814 id: key,
815 summary: value.summary,
816 updated_at: value.updated_at,
817 });
818 }
819
820 Ok(threads)
821 })
822 }
823
824 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
825 let env = self.env.clone();
826 let threads = self.threads;
827
828 self.executor.spawn(async move {
829 let txn = env.read_txn()?;
830 let thread = threads.get(&txn, &id)?;
831 Ok(thread)
832 })
833 }
834
835 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
836 let env = self.env.clone();
837 let threads = self.threads;
838
839 self.executor.spawn(async move {
840 let mut txn = env.write_txn()?;
841 threads.put(&mut txn, &id, &thread)?;
842 txn.commit()?;
843 Ok(())
844 })
845 }
846
847 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
848 let env = self.env.clone();
849 let threads = self.threads;
850
851 self.executor.spawn(async move {
852 let mut txn = env.write_txn()?;
853 threads.delete(&mut txn, &id)?;
854 txn.commit()?;
855 Ok(())
856 })
857 }
858}