1use std::borrow::Cow;
2use std::cell::{Ref, RefCell};
3use std::path::{Path, PathBuf};
4use std::rc::Rc;
5use std::sync::Arc;
6
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use context_server::manager::ContextServerManager;
13use context_server::{ContextServerFactoryRegistry, ContextServerTool};
14use fs::Fs;
15use futures::channel::{mpsc, oneshot};
16use futures::future::{self, BoxFuture, Shared};
17use futures::{FutureExt as _, StreamExt as _};
18use gpui::{
19 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
20 Subscription, Task, prelude::*,
21};
22use heed::Database;
23use heed::types::SerdeBincode;
24use language_model::{LanguageModelToolUseId, Role, TokenUsage};
25use project::{Project, Worktree};
26use prompt_store::{
27 DefaultUserRulesContext, ProjectContext, PromptBuilder, PromptId, PromptStore,
28 PromptsUpdatedEvent, RulesFileContext, WorktreeContext,
29};
30use serde::{Deserialize, Serialize};
31use settings::{Settings as _, SettingsStore};
32use util::ResultExt as _;
33
34use crate::thread::{
35 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
36};
37
/// Worktree-relative paths of rules files whose contents are added to the system prompt.
///
/// Order is priority: only the first name that exists as a file in a worktree is used.
const RULES_FILE_NAMES: [&str; 6] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
];
46
47pub fn init(cx: &mut App) {
48 ThreadsDatabase::init(cx);
49}
50
51/// A system prompt shared by all threads created by this ThreadStore
52#[derive(Clone, Default)]
53pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
54
55impl SharedProjectContext {
56 pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
57 self.0.borrow()
58 }
59}
60
/// Owns the metadata list of saved threads for a project, plus the state shared by every
/// thread it creates: the tool working set, the context server manager, the prompt builder,
/// and the system-prompt [`ProjectContext`].
pub struct ThreadStore {
    project: Entity<Project>,
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    context_server_manager: Entity<ContextServerManager>,
    /// Tool ids registered for each started context server, keyed by server id, so they can
    /// be removed from `tools` when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Metadata of every stored thread, replaced wholesale by `reload`.
    threads: Vec<SerializedThreadMetadata>,
    /// System prompt context handed to each thread created or opened by this store.
    project_context: SharedProjectContext,
    /// Wakes `_reload_system_prompt_task` so it rebuilds `project_context`.
    reload_system_prompt_tx: mpsc::Sender<()>,
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}
73
74pub struct RulesLoadingError {
75 pub message: SharedString,
76}
77
78impl EventEmitter<RulesLoadingError> for ThreadStore {}
79
80impl ThreadStore {
81 pub fn load(
82 project: Entity<Project>,
83 tools: Entity<ToolWorkingSet>,
84 prompt_builder: Arc<PromptBuilder>,
85 cx: &mut App,
86 ) -> Task<Result<Entity<Self>>> {
87 let prompt_store = PromptStore::global(cx);
88 cx.spawn(async move |cx| {
89 let prompt_store = prompt_store.await.ok();
90 let (thread_store, ready_rx) = cx.update(|cx| {
91 let mut option_ready_rx = None;
92 let thread_store = cx.new(|cx| {
93 let (thread_store, ready_rx) =
94 Self::new(project, tools, prompt_builder, prompt_store, cx);
95 option_ready_rx = Some(ready_rx);
96 thread_store
97 });
98 (thread_store, option_ready_rx.take().unwrap())
99 })?;
100 ready_rx.await?;
101 Ok(thread_store)
102 })
103 }
104
105 fn new(
106 project: Entity<Project>,
107 tools: Entity<ToolWorkingSet>,
108 prompt_builder: Arc<PromptBuilder>,
109 prompt_store: Option<Entity<PromptStore>>,
110 cx: &mut Context<Self>,
111 ) -> (Self, oneshot::Receiver<()>) {
112 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
113 let context_server_manager = cx.new(|cx| {
114 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
115 });
116
117 let mut subscriptions = vec![
118 cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
119 this.load_default_profile(cx);
120 }),
121 cx.subscribe(&project, Self::handle_project_event),
122 ];
123
124 if let Some(prompt_store) = prompt_store.as_ref() {
125 subscriptions.push(cx.subscribe(
126 prompt_store,
127 |this, _prompt_store, PromptsUpdatedEvent, _cx| {
128 this.enqueue_system_prompt_reload();
129 },
130 ))
131 }
132
133 // This channel and task prevent concurrent and redundant loading of the system prompt.
134 let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
135 let (ready_tx, ready_rx) = oneshot::channel();
136 let mut ready_tx = Some(ready_tx);
137 let reload_system_prompt_task = cx.spawn({
138 async move |thread_store, cx| {
139 loop {
140 let Some(reload_task) = thread_store
141 .update(cx, |thread_store, cx| {
142 thread_store.reload_system_prompt(prompt_store.clone(), cx)
143 })
144 .ok()
145 else {
146 return;
147 };
148 reload_task.await;
149 if let Some(ready_tx) = ready_tx.take() {
150 ready_tx.send(()).ok();
151 }
152 reload_system_prompt_rx.next().await;
153 }
154 }
155 });
156
157 let this = Self {
158 project,
159 tools,
160 prompt_builder,
161 context_server_manager,
162 context_server_tool_ids: HashMap::default(),
163 threads: Vec::new(),
164 project_context: SharedProjectContext::default(),
165 reload_system_prompt_tx,
166 _reload_system_prompt_task: reload_system_prompt_task,
167 _subscriptions: subscriptions,
168 };
169 this.load_default_profile(cx);
170 this.register_context_server_handlers(cx);
171 this.reload(cx).detach_and_log_err(cx);
172 (this, ready_rx)
173 }
174
175 fn handle_project_event(
176 &mut self,
177 _project: Entity<Project>,
178 event: &project::Event,
179 _cx: &mut Context<Self>,
180 ) {
181 match event {
182 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
183 self.enqueue_system_prompt_reload();
184 }
185 project::Event::WorktreeUpdatedEntries(_, items) => {
186 if items.iter().any(|(path, _, _)| {
187 RULES_FILE_NAMES
188 .iter()
189 .any(|name| path.as_ref() == Path::new(name))
190 }) {
191 self.enqueue_system_prompt_reload();
192 }
193 }
194 _ => {}
195 }
196 }
197
198 fn enqueue_system_prompt_reload(&mut self) {
199 self.reload_system_prompt_tx.try_send(()).ok();
200 }
201
202 // Note that this should only be called from `reload_system_prompt_task`.
203 fn reload_system_prompt(
204 &self,
205 prompt_store: Option<Entity<PromptStore>>,
206 cx: &mut Context<Self>,
207 ) -> Task<()> {
208 let project = self.project.read(cx);
209 let worktree_tasks = project
210 .visible_worktrees(cx)
211 .map(|worktree| {
212 Self::load_worktree_info_for_system_prompt(
213 project.fs().clone(),
214 worktree.read(cx),
215 cx,
216 )
217 })
218 .collect::<Vec<_>>();
219 let default_user_rules_task = match prompt_store {
220 None => Task::ready(vec![]),
221 Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
222 let prompts = prompt_store.default_prompt_metadata();
223 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
224 let contents = prompt_store.load(prompt_metadata.id, cx);
225 async move { (contents.await, prompt_metadata) }
226 });
227 cx.background_spawn(future::join_all(load_tasks))
228 }),
229 };
230
231 cx.spawn(async move |this, cx| {
232 let (worktrees, default_user_rules) =
233 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
234
235 let worktrees = worktrees
236 .into_iter()
237 .map(|(worktree, rules_error)| {
238 if let Some(rules_error) = rules_error {
239 this.update(cx, |_, cx| cx.emit(rules_error)).ok();
240 }
241 worktree
242 })
243 .collect::<Vec<_>>();
244
245 let default_user_rules = default_user_rules
246 .into_iter()
247 .flat_map(|(contents, prompt_metadata)| match contents {
248 Ok(contents) => Some(DefaultUserRulesContext {
249 uuid: match prompt_metadata.id {
250 PromptId::User { uuid } => uuid,
251 PromptId::EditWorkflow => return None,
252 },
253 title: prompt_metadata.title.map(|title| title.to_string()),
254 contents,
255 }),
256 Err(err) => {
257 this.update(cx, |_, cx| {
258 cx.emit(RulesLoadingError {
259 message: format!("{err:?}").into(),
260 });
261 })
262 .ok();
263 None
264 }
265 })
266 .collect::<Vec<_>>();
267
268 this.update(cx, |this, _cx| {
269 *this.project_context.0.borrow_mut() =
270 Some(ProjectContext::new(worktrees, default_user_rules));
271 })
272 .ok();
273 })
274 }
275
276 fn load_worktree_info_for_system_prompt(
277 fs: Arc<dyn Fs>,
278 worktree: &Worktree,
279 cx: &App,
280 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
281 let root_name = worktree.root_name().into();
282 let abs_path = worktree.abs_path();
283
284 let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
285 let Some(rules_task) = rules_task else {
286 return Task::ready((
287 WorktreeContext {
288 root_name,
289 abs_path,
290 rules_file: None,
291 },
292 None,
293 ));
294 };
295
296 cx.spawn(async move |_| {
297 let (rules_file, rules_file_error) = match rules_task.await {
298 Ok(rules_file) => (Some(rules_file), None),
299 Err(err) => (
300 None,
301 Some(RulesLoadingError {
302 message: format!("{err}").into(),
303 }),
304 ),
305 };
306 let worktree_info = WorktreeContext {
307 root_name,
308 abs_path,
309 rules_file,
310 };
311 (worktree_info, rules_file_error)
312 })
313 }
314
315 fn load_worktree_rules_file(
316 fs: Arc<dyn Fs>,
317 worktree: &Worktree,
318 cx: &App,
319 ) -> Option<Task<Result<RulesFileContext>>> {
320 let selected_rules_file = RULES_FILE_NAMES
321 .into_iter()
322 .filter_map(|name| {
323 worktree
324 .entry_for_path(name)
325 .filter(|entry| entry.is_file())
326 .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
327 })
328 .next();
329
330 // Note that Cline supports `.clinerules` being a directory, but that is not currently
331 // supported. This doesn't seem to occur often in GitHub repositories.
332 selected_rules_file.map(|(path_in_worktree, abs_path)| {
333 let fs = fs.clone();
334 cx.background_spawn(async move {
335 let abs_path = abs_path?;
336 let text = fs.load(&abs_path).await.with_context(|| {
337 format!("Failed to load assistant rules file {:?}", abs_path)
338 })?;
339 anyhow::Ok(RulesFileContext {
340 path_in_worktree,
341 abs_path: abs_path.into(),
342 text: text.trim().to_string(),
343 })
344 })
345 })
346 }
347
348 pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
349 self.context_server_manager.clone()
350 }
351
352 pub fn tools(&self) -> Entity<ToolWorkingSet> {
353 self.tools.clone()
354 }
355
356 /// Returns the number of threads.
357 pub fn thread_count(&self) -> usize {
358 self.threads.len()
359 }
360
361 pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
362 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
363 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
364 threads
365 }
366
367 pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
368 self.threads().into_iter().take(limit).collect()
369 }
370
371 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
372 cx.new(|cx| {
373 Thread::new(
374 self.project.clone(),
375 self.tools.clone(),
376 self.prompt_builder.clone(),
377 self.project_context.clone(),
378 cx,
379 )
380 })
381 }
382
383 pub fn open_thread(
384 &self,
385 id: &ThreadId,
386 cx: &mut Context<Self>,
387 ) -> Task<Result<Entity<Thread>>> {
388 let id = id.clone();
389 let database_future = ThreadsDatabase::global_future(cx);
390 cx.spawn(async move |this, cx| {
391 let database = database_future.await.map_err(|err| anyhow!(err))?;
392 let thread = database
393 .try_find_thread(id.clone())
394 .await?
395 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
396
397 let thread = this.update(cx, |this, cx| {
398 cx.new(|cx| {
399 Thread::deserialize(
400 id.clone(),
401 thread,
402 this.project.clone(),
403 this.tools.clone(),
404 this.prompt_builder.clone(),
405 this.project_context.clone(),
406 cx,
407 )
408 })
409 })?;
410
411 Ok(thread)
412 })
413 }
414
415 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
416 let (metadata, serialized_thread) =
417 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
418
419 let database_future = ThreadsDatabase::global_future(cx);
420 cx.spawn(async move |this, cx| {
421 let serialized_thread = serialized_thread.await?;
422 let database = database_future.await.map_err(|err| anyhow!(err))?;
423 database.save_thread(metadata, serialized_thread).await?;
424
425 this.update(cx, |this, cx| this.reload(cx))?.await
426 })
427 }
428
429 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
430 let id = id.clone();
431 let database_future = ThreadsDatabase::global_future(cx);
432 cx.spawn(async move |this, cx| {
433 let database = database_future.await.map_err(|err| anyhow!(err))?;
434 database.delete_thread(id.clone()).await?;
435
436 this.update(cx, |this, cx| {
437 this.threads.retain(|thread| thread.id != id);
438 cx.notify();
439 })
440 })
441 }
442
443 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
444 let database_future = ThreadsDatabase::global_future(cx);
445 cx.spawn(async move |this, cx| {
446 let threads = database_future
447 .await
448 .map_err(|err| anyhow!(err))?
449 .list_threads()
450 .await?;
451
452 this.update(cx, |this, cx| {
453 this.threads = threads;
454 cx.notify();
455 })
456 })
457 }
458
459 fn load_default_profile(&self, cx: &mut Context<Self>) {
460 let assistant_settings = AssistantSettings::get_global(cx);
461
462 self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
463 }
464
465 pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
466 let assistant_settings = AssistantSettings::get_global(cx);
467
468 if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
469 self.load_profile(profile.clone(), cx);
470 }
471 }
472
473 pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
474 self.tools.update(cx, |tools, cx| {
475 tools.disable_all_tools(cx);
476 tools.enable(
477 ToolSource::Native,
478 &profile
479 .tools
480 .iter()
481 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
482 .collect::<Vec<_>>(),
483 cx,
484 );
485 });
486
487 if profile.enable_all_context_servers {
488 for context_server in self.context_server_manager.read(cx).all_servers() {
489 self.tools.update(cx, |tools, cx| {
490 tools.enable_source(
491 ToolSource::ContextServer {
492 id: context_server.id().into(),
493 },
494 cx,
495 );
496 });
497 }
498 } else {
499 for (context_server_id, preset) in &profile.context_servers {
500 self.tools.update(cx, |tools, cx| {
501 tools.enable(
502 ToolSource::ContextServer {
503 id: context_server_id.clone().into(),
504 },
505 &preset
506 .tools
507 .iter()
508 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
509 .collect::<Vec<_>>(),
510 cx,
511 )
512 })
513 }
514 }
515 }
516
517 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
518 cx.subscribe(
519 &self.context_server_manager.clone(),
520 Self::handle_context_server_event,
521 )
522 .detach();
523 }
524
525 fn handle_context_server_event(
526 &mut self,
527 context_server_manager: Entity<ContextServerManager>,
528 event: &context_server::manager::Event,
529 cx: &mut Context<Self>,
530 ) {
531 let tool_working_set = self.tools.clone();
532 match event {
533 context_server::manager::Event::ServerStarted { server_id } => {
534 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
535 let context_server_manager = context_server_manager.clone();
536 cx.spawn({
537 let server = server.clone();
538 let server_id = server_id.clone();
539 async move |this, cx| {
540 let Some(protocol) = server.client() else {
541 return;
542 };
543
544 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
545 if let Some(tools) = protocol.list_tools().await.log_err() {
546 let tool_ids = tool_working_set
547 .update(cx, |tool_working_set, _| {
548 tools
549 .tools
550 .into_iter()
551 .map(|tool| {
552 log::info!(
553 "registering context server tool: {:?}",
554 tool.name
555 );
556 tool_working_set.insert(Arc::new(
557 ContextServerTool::new(
558 context_server_manager.clone(),
559 server.id(),
560 tool,
561 ),
562 ))
563 })
564 .collect::<Vec<_>>()
565 })
566 .log_err();
567
568 if let Some(tool_ids) = tool_ids {
569 this.update(cx, |this, cx| {
570 this.context_server_tool_ids
571 .insert(server_id, tool_ids);
572 this.load_default_profile(cx);
573 })
574 .log_err();
575 }
576 }
577 }
578 }
579 })
580 .detach();
581 }
582 }
583 context_server::manager::Event::ServerStopped { server_id } => {
584 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
585 tool_working_set.update(cx, |tool_working_set, _| {
586 tool_working_set.remove(&tool_ids);
587 });
588 self.load_default_profile(cx);
589 }
590 }
591 }
592 }
593}
594
/// Lightweight summary of a stored thread: its key plus the summary and timestamp fields
/// copied out of the full [`SerializedThread`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Database key of the thread.
    pub id: ThreadId,
    /// Copy of the thread's `summary`.
    pub summary: SharedString,
    /// Copy of the thread's `updated_at`; `ThreadStore::threads` sorts newest first by it.
    pub updated_at: DateTime<Utc>,
}
601
/// A complete thread as persisted in the threads database (encoded as JSON).
///
/// Fields marked `#[serde(default)]` may be missing from records written by older versions;
/// they fall back to their `Default` values when read.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    /// Format version; `from_json` only accepts `SerializedThread::VERSION`.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    /// Optional project snapshot stored alongside the thread.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    /// Token usage entries; presumably one per model request — confirm against `Thread`.
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
}
619
620impl SerializedThread {
621 pub const VERSION: &'static str = "0.1.0";
622
623 pub fn from_json(json: &[u8]) -> Result<Self> {
624 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
625 match saved_thread_json.get("version") {
626 Some(serde_json::Value::String(version)) => match version.as_str() {
627 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
628 saved_thread_json,
629 )?),
630 _ => Err(anyhow!(
631 "unrecognized serialized thread version: {}",
632 version
633 )),
634 },
635 None => {
636 let saved_thread =
637 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
638 Ok(saved_thread.upgrade())
639 }
640 version => Err(anyhow!(
641 "unrecognized serialized thread version: {:?}",
642 version
643 )),
644 }
645 }
646}
647
/// One message of a [`SerializedThread`]. Every field except `id` and `role` defaults to
/// empty when absent from the stored record.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments (plain text or thinking).
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
}
661
/// A piece of message content, stored as an internally tagged object such as
/// `{"type": "text", "text": "..."}` or `{"type": "thinking", "text": "..."}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "thinking")]
    Thinking { text: String },
}
670
/// A persisted tool invocation: the tool's name and its JSON input, keyed by tool-use id.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
677
/// A persisted tool result, linked to its invocation through `tool_use_id`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: Arc<str>,
}
684
/// Thread format written before the `version` field existed. `SerializedThread::from_json`
/// reads records without a `version` key as this type and converts them with `upgrade`.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
693
694impl LegacySerializedThread {
695 pub fn upgrade(self) -> SerializedThread {
696 SerializedThread {
697 version: SerializedThread::VERSION.to_string(),
698 summary: self.summary,
699 updated_at: self.updated_at,
700 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
701 initial_project_snapshot: self.initial_project_snapshot,
702 cumulative_token_usage: TokenUsage::default(),
703 request_token_usage: Vec::new(),
704 detailed_summary_state: DetailedSummaryState::default(),
705 exceeded_window_error: None,
706 }
707 }
708}
709
/// Message format used by [`LegacySerializedThread`]: the body is one `text` string rather
/// than a list of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
720
721impl LegacySerializedMessage {
722 fn upgrade(self) -> SerializedMessage {
723 SerializedMessage {
724 id: self.id,
725 role: self.role,
726 segments: vec![SerializedMessageSegment::Text { text: self.text }],
727 tool_uses: self.tool_uses,
728 tool_results: self.tool_results,
729 context: String::new(),
730 }
731 }
732}
733
/// App-global handle to the threads database as it is being opened.
///
/// The future is `Shared` so every caller awaits the same open; the error is wrapped in
/// `Arc` because a shared future's output must be cloneable.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
739
/// `heed` (LMDB) store of serialized threads keyed by [`ThreadId`].
pub(crate) struct ThreadsDatabase {
    /// Runs every database operation off the calling thread.
    executor: BackgroundExecutor,
    env: heed::Env,
    /// Keys are bincode-encoded; values use `SerializedThread`'s `heed` codec impls.
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
745
746impl heed::BytesEncode<'_> for SerializedThread {
747 type EItem = SerializedThread;
748
749 fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
750 serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
751 }
752}
753
754impl<'a> heed::BytesDecode<'a> for SerializedThread {
755 type DItem = SerializedThread;
756
757 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
758 // We implement this type manually because we want to call `SerializedThread::from_json`,
759 // instead of the Deserialize trait implementation for `SerializedThread`.
760 SerializedThread::from_json(bytes).map_err(Into::into)
761 }
762}
763
764impl ThreadsDatabase {
765 fn global_future(
766 cx: &mut App,
767 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
768 GlobalThreadsDatabase::global(cx).0.clone()
769 }
770
771 fn init(cx: &mut App) {
772 let executor = cx.background_executor().clone();
773 let database_future = executor
774 .spawn({
775 let executor = executor.clone();
776 let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
777 async move { ThreadsDatabase::new(database_path, executor) }
778 })
779 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
780 .boxed()
781 .shared();
782
783 cx.set_global(GlobalThreadsDatabase(database_future));
784 }
785
786 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
787 std::fs::create_dir_all(&path)?;
788
789 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
790 let env = unsafe {
791 heed::EnvOpenOptions::new()
792 .map_size(ONE_GB_IN_BYTES)
793 .max_dbs(1)
794 .open(path)?
795 };
796
797 let mut txn = env.write_txn()?;
798 let threads = env.create_database(&mut txn, Some("threads"))?;
799 txn.commit()?;
800
801 Ok(Self {
802 executor,
803 env,
804 threads,
805 })
806 }
807
808 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
809 let env = self.env.clone();
810 let threads = self.threads;
811
812 self.executor.spawn(async move {
813 let txn = env.read_txn()?;
814 let mut iter = threads.iter(&txn)?;
815 let mut threads = Vec::new();
816 while let Some((key, value)) = iter.next().transpose()? {
817 threads.push(SerializedThreadMetadata {
818 id: key,
819 summary: value.summary,
820 updated_at: value.updated_at,
821 });
822 }
823
824 Ok(threads)
825 })
826 }
827
828 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
829 let env = self.env.clone();
830 let threads = self.threads;
831
832 self.executor.spawn(async move {
833 let txn = env.read_txn()?;
834 let thread = threads.get(&txn, &id)?;
835 Ok(thread)
836 })
837 }
838
839 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
840 let env = self.env.clone();
841 let threads = self.threads;
842
843 self.executor.spawn(async move {
844 let mut txn = env.write_txn()?;
845 threads.put(&mut txn, &id, &thread)?;
846 txn.commit()?;
847 Ok(())
848 })
849 }
850
851 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
852 let env = self.env.clone();
853 let threads = self.threads;
854
855 self.executor.spawn(async move {
856 let mut txn = env.write_txn()?;
857 threads.delete(&mut txn, &id)?;
858 txn.commit()?;
859 Ok(())
860 })
861 }
862}