1use std::borrow::Cow;
2use std::cell::{Ref, RefCell};
3use std::path::{Path, PathBuf};
4use std::rc::Rc;
5use std::sync::Arc;
6
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use context_server::manager::ContextServerManager;
13use context_server::{ContextServerFactoryRegistry, ContextServerTool};
14use fs::Fs;
15use futures::channel::{mpsc, oneshot};
16use futures::future::{self, BoxFuture, Shared};
17use futures::{FutureExt as _, StreamExt as _};
18use gpui::{
19 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
20 Subscription, Task, prelude::*,
21};
22use heed::Database;
23use heed::types::SerdeBincode;
24use language_model::{LanguageModelToolUseId, Role, TokenUsage};
25use project::{Project, Worktree};
26use prompt_store::{
27 DefaultUserRulesContext, ProjectContext, PromptBuilder, PromptId, PromptStore,
28 PromptsUpdatedEvent, RulesFileContext, WorktreeContext,
29};
30use serde::{Deserialize, Serialize};
31use settings::{Settings as _, SettingsStore};
32use util::ResultExt as _;
33
34use crate::thread::{
35 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
36};
37
/// Candidate rules-file paths, relative to a worktree root, in priority order.
///
/// Only the first entry that exists as a file in a worktree is loaded into the system prompt.
/// Changes to any of these paths trigger a system-prompt reload.
const RULES_FILE_NAMES: [&str; 6] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
];
46
47pub fn init(cx: &mut App) {
48 ThreadsDatabase::init(cx);
49}
50
51/// A system prompt shared by all threads created by this ThreadStore
52#[derive(Clone, Default)]
53pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
54
55impl SharedProjectContext {
56 pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
57 self.0.borrow()
58 }
59}
60
61pub struct ThreadStore {
62 project: Entity<Project>,
63 tools: Entity<ToolWorkingSet>,
64 prompt_builder: Arc<PromptBuilder>,
65 context_server_manager: Entity<ContextServerManager>,
66 context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
67 threads: Vec<SerializedThreadMetadata>,
68 project_context: SharedProjectContext,
69 reload_system_prompt_tx: mpsc::Sender<()>,
70 _reload_system_prompt_task: Task<()>,
71 _subscriptions: Vec<Subscription>,
72}
73
74pub struct RulesLoadingError {
75 pub message: SharedString,
76}
77
78impl EventEmitter<RulesLoadingError> for ThreadStore {}
79
80impl ThreadStore {
81 pub fn load(
82 project: Entity<Project>,
83 tools: Entity<ToolWorkingSet>,
84 prompt_builder: Arc<PromptBuilder>,
85 cx: &mut App,
86 ) -> Task<Result<Entity<Self>>> {
87 let prompt_store = PromptStore::global(cx);
88 cx.spawn(async move |cx| {
89 let prompt_store = prompt_store.await.ok();
90 let (thread_store, ready_rx) = cx.update(|cx| {
91 let mut option_ready_rx = None;
92 let thread_store = cx.new(|cx| {
93 let (thread_store, ready_rx) =
94 Self::new(project, tools, prompt_builder, prompt_store, cx);
95 option_ready_rx = Some(ready_rx);
96 thread_store
97 });
98 (thread_store, option_ready_rx.take().unwrap())
99 })?;
100 ready_rx.await?;
101 Ok(thread_store)
102 })
103 }
104
105 fn new(
106 project: Entity<Project>,
107 tools: Entity<ToolWorkingSet>,
108 prompt_builder: Arc<PromptBuilder>,
109 prompt_store: Option<Entity<PromptStore>>,
110 cx: &mut Context<Self>,
111 ) -> (Self, oneshot::Receiver<()>) {
112 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
113 let context_server_manager = cx.new(|cx| {
114 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
115 });
116
117 let mut subscriptions = vec![
118 cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
119 this.load_default_profile(cx);
120 }),
121 cx.subscribe(&project, Self::handle_project_event),
122 ];
123
124 if let Some(prompt_store) = prompt_store.as_ref() {
125 subscriptions.push(cx.subscribe(
126 prompt_store,
127 |this, _prompt_store, PromptsUpdatedEvent, _cx| {
128 this.enqueue_system_prompt_reload();
129 },
130 ))
131 }
132
133 // This channel and task prevent concurrent and redundant loading of the system prompt.
134 let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
135 let (ready_tx, ready_rx) = oneshot::channel();
136 let mut ready_tx = Some(ready_tx);
137 let reload_system_prompt_task = cx.spawn({
138 async move |thread_store, cx| {
139 loop {
140 let Some(reload_task) = thread_store
141 .update(cx, |thread_store, cx| {
142 thread_store.reload_system_prompt(prompt_store.clone(), cx)
143 })
144 .ok()
145 else {
146 return;
147 };
148 reload_task.await;
149 if let Some(ready_tx) = ready_tx.take() {
150 ready_tx.send(()).ok();
151 }
152 reload_system_prompt_rx.next().await;
153 }
154 }
155 });
156
157 let this = Self {
158 project,
159 tools,
160 prompt_builder,
161 context_server_manager,
162 context_server_tool_ids: HashMap::default(),
163 threads: Vec::new(),
164 project_context: SharedProjectContext::default(),
165 reload_system_prompt_tx,
166 _reload_system_prompt_task: reload_system_prompt_task,
167 _subscriptions: subscriptions,
168 };
169 this.load_default_profile(cx);
170 this.register_context_server_handlers(cx);
171 this.reload(cx).detach_and_log_err(cx);
172 (this, ready_rx)
173 }
174
175 fn handle_project_event(
176 &mut self,
177 _project: Entity<Project>,
178 event: &project::Event,
179 _cx: &mut Context<Self>,
180 ) {
181 match event {
182 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
183 self.enqueue_system_prompt_reload();
184 }
185 project::Event::WorktreeUpdatedEntries(_, items) => {
186 if items.iter().any(|(path, _, _)| {
187 RULES_FILE_NAMES
188 .iter()
189 .any(|name| path.as_ref() == Path::new(name))
190 }) {
191 self.enqueue_system_prompt_reload();
192 }
193 }
194 _ => {}
195 }
196 }
197
198 fn enqueue_system_prompt_reload(&mut self) {
199 self.reload_system_prompt_tx.try_send(()).ok();
200 }
201
202 // Note that this should only be called from `reload_system_prompt_task`.
203 fn reload_system_prompt(
204 &self,
205 prompt_store: Option<Entity<PromptStore>>,
206 cx: &mut Context<Self>,
207 ) -> Task<()> {
208 let project = self.project.read(cx);
209 let worktree_tasks = project
210 .visible_worktrees(cx)
211 .map(|worktree| {
212 Self::load_worktree_info_for_system_prompt(
213 project.fs().clone(),
214 worktree.read(cx),
215 cx,
216 )
217 })
218 .collect::<Vec<_>>();
219 let default_user_rules_task = match prompt_store {
220 None => Task::ready(vec![]),
221 Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
222 let prompts = prompt_store.default_prompt_metadata();
223 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
224 let contents = prompt_store.load(prompt_metadata.id, cx);
225 async move { (contents.await, prompt_metadata) }
226 });
227 cx.background_spawn(future::join_all(load_tasks))
228 }),
229 };
230
231 cx.spawn(async move |this, cx| {
232 let (worktrees, default_user_rules) =
233 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
234
235 let worktrees = worktrees
236 .into_iter()
237 .map(|(worktree, rules_error)| {
238 if let Some(rules_error) = rules_error {
239 this.update(cx, |_, cx| cx.emit(rules_error)).ok();
240 }
241 worktree
242 })
243 .collect::<Vec<_>>();
244
245 let default_user_rules = default_user_rules
246 .into_iter()
247 .flat_map(|(contents, prompt_metadata)| match contents {
248 Ok(contents) => Some(DefaultUserRulesContext {
249 uuid: match prompt_metadata.id {
250 PromptId::User { uuid } => uuid,
251 PromptId::EditWorkflow => return None,
252 },
253 title: prompt_metadata.title.map(|title| title.to_string()),
254 contents,
255 }),
256 Err(err) => {
257 this.update(cx, |_, cx| {
258 cx.emit(RulesLoadingError {
259 message: format!("{err:?}").into(),
260 });
261 })
262 .ok();
263 None
264 }
265 })
266 .collect::<Vec<_>>();
267
268 this.update(cx, |this, _cx| {
269 *this.project_context.0.borrow_mut() =
270 Some(ProjectContext::new(worktrees, default_user_rules));
271 })
272 .ok();
273 })
274 }
275
276 fn load_worktree_info_for_system_prompt(
277 fs: Arc<dyn Fs>,
278 worktree: &Worktree,
279 cx: &App,
280 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
281 let root_name = worktree.root_name().into();
282
283 let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
284 let Some(rules_task) = rules_task else {
285 return Task::ready((
286 WorktreeContext {
287 root_name,
288 rules_file: None,
289 },
290 None,
291 ));
292 };
293
294 cx.spawn(async move |_| {
295 let (rules_file, rules_file_error) = match rules_task.await {
296 Ok(rules_file) => (Some(rules_file), None),
297 Err(err) => (
298 None,
299 Some(RulesLoadingError {
300 message: format!("{err}").into(),
301 }),
302 ),
303 };
304 let worktree_info = WorktreeContext {
305 root_name,
306 rules_file,
307 };
308 (worktree_info, rules_file_error)
309 })
310 }
311
312 fn load_worktree_rules_file(
313 fs: Arc<dyn Fs>,
314 worktree: &Worktree,
315 cx: &App,
316 ) -> Option<Task<Result<RulesFileContext>>> {
317 let selected_rules_file = RULES_FILE_NAMES
318 .into_iter()
319 .filter_map(|name| {
320 worktree
321 .entry_for_path(name)
322 .filter(|entry| entry.is_file())
323 .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
324 })
325 .next();
326
327 // Note that Cline supports `.clinerules` being a directory, but that is not currently
328 // supported. This doesn't seem to occur often in GitHub repositories.
329 selected_rules_file.map(|(path_in_worktree, abs_path)| {
330 let fs = fs.clone();
331 cx.background_spawn(async move {
332 let abs_path = abs_path?;
333 let text = fs.load(&abs_path).await.with_context(|| {
334 format!("Failed to load assistant rules file {:?}", abs_path)
335 })?;
336 anyhow::Ok(RulesFileContext {
337 path_in_worktree,
338 abs_path: abs_path.into(),
339 text: text.trim().to_string(),
340 })
341 })
342 })
343 }
344
345 pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
346 self.context_server_manager.clone()
347 }
348
349 pub fn tools(&self) -> Entity<ToolWorkingSet> {
350 self.tools.clone()
351 }
352
353 /// Returns the number of threads.
354 pub fn thread_count(&self) -> usize {
355 self.threads.len()
356 }
357
358 pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
359 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
360 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
361 threads
362 }
363
364 pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
365 self.threads().into_iter().take(limit).collect()
366 }
367
368 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
369 cx.new(|cx| {
370 Thread::new(
371 self.project.clone(),
372 self.tools.clone(),
373 self.prompt_builder.clone(),
374 self.project_context.clone(),
375 cx,
376 )
377 })
378 }
379
380 pub fn open_thread(
381 &self,
382 id: &ThreadId,
383 cx: &mut Context<Self>,
384 ) -> Task<Result<Entity<Thread>>> {
385 let id = id.clone();
386 let database_future = ThreadsDatabase::global_future(cx);
387 cx.spawn(async move |this, cx| {
388 let database = database_future.await.map_err(|err| anyhow!(err))?;
389 let thread = database
390 .try_find_thread(id.clone())
391 .await?
392 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
393
394 let thread = this.update(cx, |this, cx| {
395 cx.new(|cx| {
396 Thread::deserialize(
397 id.clone(),
398 thread,
399 this.project.clone(),
400 this.tools.clone(),
401 this.prompt_builder.clone(),
402 this.project_context.clone(),
403 cx,
404 )
405 })
406 })?;
407
408 Ok(thread)
409 })
410 }
411
412 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
413 let (metadata, serialized_thread) =
414 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
415
416 let database_future = ThreadsDatabase::global_future(cx);
417 cx.spawn(async move |this, cx| {
418 let serialized_thread = serialized_thread.await?;
419 let database = database_future.await.map_err(|err| anyhow!(err))?;
420 database.save_thread(metadata, serialized_thread).await?;
421
422 this.update(cx, |this, cx| this.reload(cx))?.await
423 })
424 }
425
426 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
427 let id = id.clone();
428 let database_future = ThreadsDatabase::global_future(cx);
429 cx.spawn(async move |this, cx| {
430 let database = database_future.await.map_err(|err| anyhow!(err))?;
431 database.delete_thread(id.clone()).await?;
432
433 this.update(cx, |this, cx| {
434 this.threads.retain(|thread| thread.id != id);
435 cx.notify();
436 })
437 })
438 }
439
440 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
441 let database_future = ThreadsDatabase::global_future(cx);
442 cx.spawn(async move |this, cx| {
443 let threads = database_future
444 .await
445 .map_err(|err| anyhow!(err))?
446 .list_threads()
447 .await?;
448
449 this.update(cx, |this, cx| {
450 this.threads = threads;
451 cx.notify();
452 })
453 })
454 }
455
456 fn load_default_profile(&self, cx: &mut Context<Self>) {
457 let assistant_settings = AssistantSettings::get_global(cx);
458
459 self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
460 }
461
462 pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
463 let assistant_settings = AssistantSettings::get_global(cx);
464
465 if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
466 self.load_profile(profile.clone(), cx);
467 }
468 }
469
470 pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
471 self.tools.update(cx, |tools, cx| {
472 tools.disable_all_tools(cx);
473 tools.enable(
474 ToolSource::Native,
475 &profile
476 .tools
477 .iter()
478 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
479 .collect::<Vec<_>>(),
480 cx,
481 );
482 });
483
484 if profile.enable_all_context_servers {
485 for context_server in self.context_server_manager.read(cx).all_servers() {
486 self.tools.update(cx, |tools, cx| {
487 tools.enable_source(
488 ToolSource::ContextServer {
489 id: context_server.id().into(),
490 },
491 cx,
492 );
493 });
494 }
495 } else {
496 for (context_server_id, preset) in &profile.context_servers {
497 self.tools.update(cx, |tools, cx| {
498 tools.enable(
499 ToolSource::ContextServer {
500 id: context_server_id.clone().into(),
501 },
502 &preset
503 .tools
504 .iter()
505 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
506 .collect::<Vec<_>>(),
507 cx,
508 )
509 })
510 }
511 }
512 }
513
514 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
515 cx.subscribe(
516 &self.context_server_manager.clone(),
517 Self::handle_context_server_event,
518 )
519 .detach();
520 }
521
522 fn handle_context_server_event(
523 &mut self,
524 context_server_manager: Entity<ContextServerManager>,
525 event: &context_server::manager::Event,
526 cx: &mut Context<Self>,
527 ) {
528 let tool_working_set = self.tools.clone();
529 match event {
530 context_server::manager::Event::ServerStarted { server_id } => {
531 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
532 let context_server_manager = context_server_manager.clone();
533 cx.spawn({
534 let server = server.clone();
535 let server_id = server_id.clone();
536 async move |this, cx| {
537 let Some(protocol) = server.client() else {
538 return;
539 };
540
541 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
542 if let Some(tools) = protocol.list_tools().await.log_err() {
543 let tool_ids = tool_working_set
544 .update(cx, |tool_working_set, _| {
545 tools
546 .tools
547 .into_iter()
548 .map(|tool| {
549 log::info!(
550 "registering context server tool: {:?}",
551 tool.name
552 );
553 tool_working_set.insert(Arc::new(
554 ContextServerTool::new(
555 context_server_manager.clone(),
556 server.id(),
557 tool,
558 ),
559 ))
560 })
561 .collect::<Vec<_>>()
562 })
563 .log_err();
564
565 if let Some(tool_ids) = tool_ids {
566 this.update(cx, |this, cx| {
567 this.context_server_tool_ids
568 .insert(server_id, tool_ids);
569 this.load_default_profile(cx);
570 })
571 .log_err();
572 }
573 }
574 }
575 }
576 })
577 .detach();
578 }
579 }
580 context_server::manager::Event::ServerStopped { server_id } => {
581 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
582 tool_working_set.update(cx, |tool_working_set, _| {
583 tool_working_set.remove(&tool_ids);
584 });
585 self.load_default_profile(cx);
586 }
587 }
588 }
589 }
590}
591
592#[derive(Debug, Clone, Serialize, Deserialize)]
593pub struct SerializedThreadMetadata {
594 pub id: ThreadId,
595 pub summary: SharedString,
596 pub updated_at: DateTime<Utc>,
597}
598
599#[derive(Serialize, Deserialize, Debug)]
600pub struct SerializedThread {
601 pub version: String,
602 pub summary: SharedString,
603 pub updated_at: DateTime<Utc>,
604 pub messages: Vec<SerializedMessage>,
605 #[serde(default)]
606 pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
607 #[serde(default)]
608 pub cumulative_token_usage: TokenUsage,
609 #[serde(default)]
610 pub request_token_usage: Vec<TokenUsage>,
611 #[serde(default)]
612 pub detailed_summary_state: DetailedSummaryState,
613 #[serde(default)]
614 pub exceeded_window_error: Option<ExceededWindowError>,
615}
616
617impl SerializedThread {
618 pub const VERSION: &'static str = "0.1.0";
619
620 pub fn from_json(json: &[u8]) -> Result<Self> {
621 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
622 match saved_thread_json.get("version") {
623 Some(serde_json::Value::String(version)) => match version.as_str() {
624 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
625 saved_thread_json,
626 )?),
627 _ => Err(anyhow!(
628 "unrecognized serialized thread version: {}",
629 version
630 )),
631 },
632 None => {
633 let saved_thread =
634 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
635 Ok(saved_thread.upgrade())
636 }
637 version => Err(anyhow!(
638 "unrecognized serialized thread version: {:?}",
639 version
640 )),
641 }
642 }
643}
644
645#[derive(Debug, Serialize, Deserialize)]
646pub struct SerializedMessage {
647 pub id: MessageId,
648 pub role: Role,
649 #[serde(default)]
650 pub segments: Vec<SerializedMessageSegment>,
651 #[serde(default)]
652 pub tool_uses: Vec<SerializedToolUse>,
653 #[serde(default)]
654 pub tool_results: Vec<SerializedToolResult>,
655 #[serde(default)]
656 pub context: String,
657}
658
659#[derive(Debug, Serialize, Deserialize)]
660#[serde(tag = "type")]
661pub enum SerializedMessageSegment {
662 #[serde(rename = "text")]
663 Text { text: String },
664 #[serde(rename = "thinking")]
665 Thinking { text: String },
666}
667
668#[derive(Debug, Serialize, Deserialize)]
669pub struct SerializedToolUse {
670 pub id: LanguageModelToolUseId,
671 pub name: SharedString,
672 pub input: serde_json::Value,
673}
674
675#[derive(Debug, Serialize, Deserialize)]
676pub struct SerializedToolResult {
677 pub tool_use_id: LanguageModelToolUseId,
678 pub is_error: bool,
679 pub content: Arc<str>,
680}
681
682#[derive(Serialize, Deserialize)]
683struct LegacySerializedThread {
684 pub summary: SharedString,
685 pub updated_at: DateTime<Utc>,
686 pub messages: Vec<LegacySerializedMessage>,
687 #[serde(default)]
688 pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
689}
690
691impl LegacySerializedThread {
692 pub fn upgrade(self) -> SerializedThread {
693 SerializedThread {
694 version: SerializedThread::VERSION.to_string(),
695 summary: self.summary,
696 updated_at: self.updated_at,
697 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
698 initial_project_snapshot: self.initial_project_snapshot,
699 cumulative_token_usage: TokenUsage::default(),
700 request_token_usage: Vec::new(),
701 detailed_summary_state: DetailedSummaryState::default(),
702 exceeded_window_error: None,
703 }
704 }
705}
706
707#[derive(Debug, Serialize, Deserialize)]
708struct LegacySerializedMessage {
709 pub id: MessageId,
710 pub role: Role,
711 pub text: String,
712 #[serde(default)]
713 pub tool_uses: Vec<SerializedToolUse>,
714 #[serde(default)]
715 pub tool_results: Vec<SerializedToolResult>,
716}
717
718impl LegacySerializedMessage {
719 fn upgrade(self) -> SerializedMessage {
720 SerializedMessage {
721 id: self.id,
722 role: self.role,
723 segments: vec![SerializedMessageSegment::Text { text: self.text }],
724 tool_uses: self.tool_uses,
725 tool_results: self.tool_results,
726 context: String::new(),
727 }
728 }
729}
730
731struct GlobalThreadsDatabase(
732 Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
733);
734
735impl Global for GlobalThreadsDatabase {}
736
737pub(crate) struct ThreadsDatabase {
738 executor: BackgroundExecutor,
739 env: heed::Env,
740 threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
741}
742
743impl heed::BytesEncode<'_> for SerializedThread {
744 type EItem = SerializedThread;
745
746 fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
747 serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
748 }
749}
750
751impl<'a> heed::BytesDecode<'a> for SerializedThread {
752 type DItem = SerializedThread;
753
754 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
755 // We implement this type manually because we want to call `SerializedThread::from_json`,
756 // instead of the Deserialize trait implementation for `SerializedThread`.
757 SerializedThread::from_json(bytes).map_err(Into::into)
758 }
759}
760
761impl ThreadsDatabase {
762 fn global_future(
763 cx: &mut App,
764 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
765 GlobalThreadsDatabase::global(cx).0.clone()
766 }
767
768 fn init(cx: &mut App) {
769 let executor = cx.background_executor().clone();
770 let database_future = executor
771 .spawn({
772 let executor = executor.clone();
773 let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
774 async move { ThreadsDatabase::new(database_path, executor) }
775 })
776 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
777 .boxed()
778 .shared();
779
780 cx.set_global(GlobalThreadsDatabase(database_future));
781 }
782
783 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
784 std::fs::create_dir_all(&path)?;
785
786 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
787 let env = unsafe {
788 heed::EnvOpenOptions::new()
789 .map_size(ONE_GB_IN_BYTES)
790 .max_dbs(1)
791 .open(path)?
792 };
793
794 let mut txn = env.write_txn()?;
795 let threads = env.create_database(&mut txn, Some("threads"))?;
796 txn.commit()?;
797
798 Ok(Self {
799 executor,
800 env,
801 threads,
802 })
803 }
804
805 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
806 let env = self.env.clone();
807 let threads = self.threads;
808
809 self.executor.spawn(async move {
810 let txn = env.read_txn()?;
811 let mut iter = threads.iter(&txn)?;
812 let mut threads = Vec::new();
813 while let Some((key, value)) = iter.next().transpose()? {
814 threads.push(SerializedThreadMetadata {
815 id: key,
816 summary: value.summary,
817 updated_at: value.updated_at,
818 });
819 }
820
821 Ok(threads)
822 })
823 }
824
825 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
826 let env = self.env.clone();
827 let threads = self.threads;
828
829 self.executor.spawn(async move {
830 let txn = env.read_txn()?;
831 let thread = threads.get(&txn, &id)?;
832 Ok(thread)
833 })
834 }
835
836 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
837 let env = self.env.clone();
838 let threads = self.threads;
839
840 self.executor.spawn(async move {
841 let mut txn = env.write_txn()?;
842 threads.put(&mut txn, &id, &thread)?;
843 txn.commit()?;
844 Ok(())
845 })
846 }
847
848 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
849 let env = self.env.clone();
850 let threads = self.threads;
851
852 self.executor.spawn(async move {
853 let mut txn = env.write_txn()?;
854 threads.delete(&mut txn, &id)?;
855 txn.commit()?;
856 Ok(())
857 })
858 }
859}