1use std::borrow::Cow;
2use std::cell::{Ref, RefCell};
3use std::path::{Path, PathBuf};
4use std::rc::Rc;
5use std::sync::Arc;
6
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use context_server::manager::ContextServerManager;
13use context_server::{ContextServerFactoryRegistry, ContextServerTool};
14use fs::Fs;
15use futures::channel::{mpsc, oneshot};
16use futures::future::{self, BoxFuture, Shared};
17use futures::{FutureExt as _, StreamExt as _};
18use gpui::{
19 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
20 Subscription, Task, prelude::*,
21};
22use heed::Database;
23use heed::types::SerdeBincode;
24use language_model::{LanguageModelToolUseId, Role, TokenUsage};
25use project::{Project, Worktree};
26use prompt_store::{
27 ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
28 UserRulesContext, WorktreeContext,
29};
30use serde::{Deserialize, Serialize};
31use settings::{Settings as _, SettingsStore};
32use util::ResultExt as _;
33
34use crate::thread::{
35 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
36};
37
/// Worktree-relative paths of files whose contents are loaded as assistant rules.
///
/// Order matters: `ThreadStore::load_worktree_rules_file` picks the first entry that exists
/// as a file in a worktree. A change to any of these paths also triggers a system prompt
/// reload (see `ThreadStore::handle_project_event`).
const RULES_FILE_NAMES: [&str; 6] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
];
46
47pub fn init(cx: &mut App) {
48 ThreadsDatabase::init(cx);
49}
50
51/// A system prompt shared by all threads created by this ThreadStore
52#[derive(Clone, Default)]
53pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
54
55impl SharedProjectContext {
56 pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
57 self.0.borrow()
58 }
59}
60
/// Owns the list of saved threads, creates/opens/saves/deletes them, and maintains the
/// shared system-prompt context and the tool configuration they use.
pub struct ThreadStore {
    // Project whose visible worktrees feed the system prompt; new threads are created for it.
    project: Entity<Project>,
    // Tool working set whose enabled tools are driven by the active agent profile.
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    // Optional source of default user rules included in the system prompt.
    prompt_store: Option<Entity<PromptStore>>,
    context_server_manager: Entity<ContextServerManager>,
    // Tool ids registered for each context server, so they can be removed when it stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    // Cached thread listing from the database, refreshed by `reload`.
    threads: Vec<SerializedThreadMetadata>,
    // Shared with every thread this store creates or opens.
    project_context: SharedProjectContext,
    // Wakes the reload task; see `enqueue_system_prompt_reload`.
    reload_system_prompt_tx: mpsc::Sender<()>,
    // Held only to keep the reload task alive for the lifetime of the store.
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}
74
75pub struct RulesLoadingError {
76 pub message: SharedString,
77}
78
79impl EventEmitter<RulesLoadingError> for ThreadStore {}
80
81impl ThreadStore {
82 pub fn load(
83 project: Entity<Project>,
84 tools: Entity<ToolWorkingSet>,
85 prompt_store: Option<Entity<PromptStore>>,
86 prompt_builder: Arc<PromptBuilder>,
87 cx: &mut App,
88 ) -> Task<Result<Entity<Self>>> {
89 cx.spawn(async move |cx| {
90 let (thread_store, ready_rx) = cx.update(|cx| {
91 let mut option_ready_rx = None;
92 let thread_store = cx.new(|cx| {
93 let (thread_store, ready_rx) =
94 Self::new(project, tools, prompt_builder, prompt_store, cx);
95 option_ready_rx = Some(ready_rx);
96 thread_store
97 });
98 (thread_store, option_ready_rx.take().unwrap())
99 })?;
100 ready_rx.await?;
101 Ok(thread_store)
102 })
103 }
104
105 fn new(
106 project: Entity<Project>,
107 tools: Entity<ToolWorkingSet>,
108 prompt_builder: Arc<PromptBuilder>,
109 prompt_store: Option<Entity<PromptStore>>,
110 cx: &mut Context<Self>,
111 ) -> (Self, oneshot::Receiver<()>) {
112 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
113 let context_server_manager = cx.new(|cx| {
114 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
115 });
116
117 let mut subscriptions = vec![
118 cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
119 this.load_default_profile(cx);
120 }),
121 cx.subscribe(&project, Self::handle_project_event),
122 ];
123
124 if let Some(prompt_store) = prompt_store.as_ref() {
125 subscriptions.push(cx.subscribe(
126 prompt_store,
127 |this, _prompt_store, PromptsUpdatedEvent, _cx| {
128 this.enqueue_system_prompt_reload();
129 },
130 ))
131 }
132
133 // This channel and task prevent concurrent and redundant loading of the system prompt.
134 let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
135 let (ready_tx, ready_rx) = oneshot::channel();
136 let mut ready_tx = Some(ready_tx);
137 let reload_system_prompt_task = cx.spawn({
138 let prompt_store = prompt_store.clone();
139 async move |thread_store, cx| {
140 loop {
141 let Some(reload_task) = thread_store
142 .update(cx, |thread_store, cx| {
143 thread_store.reload_system_prompt(prompt_store.clone(), cx)
144 })
145 .ok()
146 else {
147 return;
148 };
149 reload_task.await;
150 if let Some(ready_tx) = ready_tx.take() {
151 ready_tx.send(()).ok();
152 }
153 reload_system_prompt_rx.next().await;
154 }
155 }
156 });
157
158 let this = Self {
159 project,
160 tools,
161 prompt_builder,
162 prompt_store,
163 context_server_manager,
164 context_server_tool_ids: HashMap::default(),
165 threads: Vec::new(),
166 project_context: SharedProjectContext::default(),
167 reload_system_prompt_tx,
168 _reload_system_prompt_task: reload_system_prompt_task,
169 _subscriptions: subscriptions,
170 };
171 this.load_default_profile(cx);
172 this.register_context_server_handlers(cx);
173 this.reload(cx).detach_and_log_err(cx);
174 (this, ready_rx)
175 }
176
177 fn handle_project_event(
178 &mut self,
179 _project: Entity<Project>,
180 event: &project::Event,
181 _cx: &mut Context<Self>,
182 ) {
183 match event {
184 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
185 self.enqueue_system_prompt_reload();
186 }
187 project::Event::WorktreeUpdatedEntries(_, items) => {
188 if items.iter().any(|(path, _, _)| {
189 RULES_FILE_NAMES
190 .iter()
191 .any(|name| path.as_ref() == Path::new(name))
192 }) {
193 self.enqueue_system_prompt_reload();
194 }
195 }
196 _ => {}
197 }
198 }
199
200 fn enqueue_system_prompt_reload(&mut self) {
201 self.reload_system_prompt_tx.try_send(()).ok();
202 }
203
204 // Note that this should only be called from `reload_system_prompt_task`.
205 fn reload_system_prompt(
206 &self,
207 prompt_store: Option<Entity<PromptStore>>,
208 cx: &mut Context<Self>,
209 ) -> Task<()> {
210 let project = self.project.read(cx);
211 let worktree_tasks = project
212 .visible_worktrees(cx)
213 .map(|worktree| {
214 Self::load_worktree_info_for_system_prompt(
215 project.fs().clone(),
216 worktree.read(cx),
217 cx,
218 )
219 })
220 .collect::<Vec<_>>();
221 let default_user_rules_task = match prompt_store {
222 None => Task::ready(vec![]),
223 Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
224 let prompts = prompt_store.default_prompt_metadata();
225 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
226 let contents = prompt_store.load(prompt_metadata.id, cx);
227 async move { (contents.await, prompt_metadata) }
228 });
229 cx.background_spawn(future::join_all(load_tasks))
230 }),
231 };
232
233 cx.spawn(async move |this, cx| {
234 let (worktrees, default_user_rules) =
235 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
236
237 let worktrees = worktrees
238 .into_iter()
239 .map(|(worktree, rules_error)| {
240 if let Some(rules_error) = rules_error {
241 this.update(cx, |_, cx| cx.emit(rules_error)).ok();
242 }
243 worktree
244 })
245 .collect::<Vec<_>>();
246
247 let default_user_rules = default_user_rules
248 .into_iter()
249 .flat_map(|(contents, prompt_metadata)| match contents {
250 Ok(contents) => Some(UserRulesContext {
251 uuid: match prompt_metadata.id {
252 PromptId::User { uuid } => uuid,
253 PromptId::EditWorkflow => return None,
254 },
255 title: prompt_metadata.title.map(|title| title.to_string()),
256 contents,
257 }),
258 Err(err) => {
259 this.update(cx, |_, cx| {
260 cx.emit(RulesLoadingError {
261 message: format!("{err:?}").into(),
262 });
263 })
264 .ok();
265 None
266 }
267 })
268 .collect::<Vec<_>>();
269
270 this.update(cx, |this, _cx| {
271 *this.project_context.0.borrow_mut() =
272 Some(ProjectContext::new(worktrees, default_user_rules));
273 })
274 .ok();
275 })
276 }
277
278 fn load_worktree_info_for_system_prompt(
279 fs: Arc<dyn Fs>,
280 worktree: &Worktree,
281 cx: &App,
282 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
283 let root_name = worktree.root_name().into();
284
285 let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
286 let Some(rules_task) = rules_task else {
287 return Task::ready((
288 WorktreeContext {
289 root_name,
290 rules_file: None,
291 },
292 None,
293 ));
294 };
295
296 cx.spawn(async move |_| {
297 let (rules_file, rules_file_error) = match rules_task.await {
298 Ok(rules_file) => (Some(rules_file), None),
299 Err(err) => (
300 None,
301 Some(RulesLoadingError {
302 message: format!("{err}").into(),
303 }),
304 ),
305 };
306 let worktree_info = WorktreeContext {
307 root_name,
308 rules_file,
309 };
310 (worktree_info, rules_file_error)
311 })
312 }
313
314 fn load_worktree_rules_file(
315 fs: Arc<dyn Fs>,
316 worktree: &Worktree,
317 cx: &App,
318 ) -> Option<Task<Result<RulesFileContext>>> {
319 let selected_rules_file = RULES_FILE_NAMES
320 .into_iter()
321 .filter_map(|name| {
322 worktree
323 .entry_for_path(name)
324 .filter(|entry| entry.is_file())
325 .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
326 })
327 .next();
328
329 // Note that Cline supports `.clinerules` being a directory, but that is not currently
330 // supported. This doesn't seem to occur often in GitHub repositories.
331 selected_rules_file.map(|(path_in_worktree, abs_path)| {
332 let fs = fs.clone();
333 cx.background_spawn(async move {
334 let abs_path = abs_path?;
335 let text = fs.load(&abs_path).await.with_context(|| {
336 format!("Failed to load assistant rules file {:?}", abs_path)
337 })?;
338 anyhow::Ok(RulesFileContext {
339 path_in_worktree,
340 abs_path: abs_path.into(),
341 text: text.trim().to_string(),
342 })
343 })
344 })
345 }
346
347 pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
348 self.context_server_manager.clone()
349 }
350
351 pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
352 &self.prompt_store
353 }
354
355 pub fn tools(&self) -> Entity<ToolWorkingSet> {
356 self.tools.clone()
357 }
358
359 /// Returns the number of threads.
360 pub fn thread_count(&self) -> usize {
361 self.threads.len()
362 }
363
364 pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
365 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
366 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
367 threads
368 }
369
370 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
371 cx.new(|cx| {
372 Thread::new(
373 self.project.clone(),
374 self.tools.clone(),
375 self.prompt_builder.clone(),
376 self.project_context.clone(),
377 cx,
378 )
379 })
380 }
381
382 pub fn open_thread(
383 &self,
384 id: &ThreadId,
385 cx: &mut Context<Self>,
386 ) -> Task<Result<Entity<Thread>>> {
387 let id = id.clone();
388 let database_future = ThreadsDatabase::global_future(cx);
389 cx.spawn(async move |this, cx| {
390 let database = database_future.await.map_err(|err| anyhow!(err))?;
391 let thread = database
392 .try_find_thread(id.clone())
393 .await?
394 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
395
396 let thread = this.update(cx, |this, cx| {
397 cx.new(|cx| {
398 Thread::deserialize(
399 id.clone(),
400 thread,
401 this.project.clone(),
402 this.tools.clone(),
403 this.prompt_builder.clone(),
404 this.project_context.clone(),
405 cx,
406 )
407 })
408 })?;
409
410 Ok(thread)
411 })
412 }
413
414 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
415 let (metadata, serialized_thread) =
416 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
417
418 let database_future = ThreadsDatabase::global_future(cx);
419 cx.spawn(async move |this, cx| {
420 let serialized_thread = serialized_thread.await?;
421 let database = database_future.await.map_err(|err| anyhow!(err))?;
422 database.save_thread(metadata, serialized_thread).await?;
423
424 this.update(cx, |this, cx| this.reload(cx))?.await
425 })
426 }
427
428 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
429 let id = id.clone();
430 let database_future = ThreadsDatabase::global_future(cx);
431 cx.spawn(async move |this, cx| {
432 let database = database_future.await.map_err(|err| anyhow!(err))?;
433 database.delete_thread(id.clone()).await?;
434
435 this.update(cx, |this, cx| {
436 this.threads.retain(|thread| thread.id != id);
437 cx.notify();
438 })
439 })
440 }
441
442 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
443 let database_future = ThreadsDatabase::global_future(cx);
444 cx.spawn(async move |this, cx| {
445 let threads = database_future
446 .await
447 .map_err(|err| anyhow!(err))?
448 .list_threads()
449 .await?;
450
451 this.update(cx, |this, cx| {
452 this.threads = threads;
453 cx.notify();
454 })
455 })
456 }
457
458 fn load_default_profile(&self, cx: &mut Context<Self>) {
459 let assistant_settings = AssistantSettings::get_global(cx);
460
461 self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
462 }
463
464 pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
465 let assistant_settings = AssistantSettings::get_global(cx);
466
467 if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
468 self.load_profile(profile.clone(), cx);
469 }
470 }
471
472 pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
473 self.tools.update(cx, |tools, cx| {
474 tools.disable_all_tools(cx);
475 tools.enable(
476 ToolSource::Native,
477 &profile
478 .tools
479 .iter()
480 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
481 .collect::<Vec<_>>(),
482 cx,
483 );
484 });
485
486 if profile.enable_all_context_servers {
487 for context_server in self.context_server_manager.read(cx).all_servers() {
488 self.tools.update(cx, |tools, cx| {
489 tools.enable_source(
490 ToolSource::ContextServer {
491 id: context_server.id().into(),
492 },
493 cx,
494 );
495 });
496 }
497 } else {
498 for (context_server_id, preset) in &profile.context_servers {
499 self.tools.update(cx, |tools, cx| {
500 tools.enable(
501 ToolSource::ContextServer {
502 id: context_server_id.clone().into(),
503 },
504 &preset
505 .tools
506 .iter()
507 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
508 .collect::<Vec<_>>(),
509 cx,
510 )
511 })
512 }
513 }
514 }
515
516 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
517 cx.subscribe(
518 &self.context_server_manager.clone(),
519 Self::handle_context_server_event,
520 )
521 .detach();
522 }
523
524 fn handle_context_server_event(
525 &mut self,
526 context_server_manager: Entity<ContextServerManager>,
527 event: &context_server::manager::Event,
528 cx: &mut Context<Self>,
529 ) {
530 let tool_working_set = self.tools.clone();
531 match event {
532 context_server::manager::Event::ServerStarted { server_id } => {
533 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
534 let context_server_manager = context_server_manager.clone();
535 cx.spawn({
536 let server = server.clone();
537 let server_id = server_id.clone();
538 async move |this, cx| {
539 let Some(protocol) = server.client() else {
540 return;
541 };
542
543 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
544 if let Some(tools) = protocol.list_tools().await.log_err() {
545 let tool_ids = tool_working_set
546 .update(cx, |tool_working_set, _| {
547 tools
548 .tools
549 .into_iter()
550 .map(|tool| {
551 log::info!(
552 "registering context server tool: {:?}",
553 tool.name
554 );
555 tool_working_set.insert(Arc::new(
556 ContextServerTool::new(
557 context_server_manager.clone(),
558 server.id(),
559 tool,
560 ),
561 ))
562 })
563 .collect::<Vec<_>>()
564 })
565 .log_err();
566
567 if let Some(tool_ids) = tool_ids {
568 this.update(cx, |this, cx| {
569 this.context_server_tool_ids
570 .insert(server_id, tool_ids);
571 this.load_default_profile(cx);
572 })
573 .log_err();
574 }
575 }
576 }
577 }
578 })
579 .detach();
580 }
581 }
582 context_server::manager::Event::ServerStopped { server_id } => {
583 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
584 tool_working_set.update(cx, |tool_working_set, _| {
585 tool_working_set.remove(&tool_ids);
586 });
587 self.load_default_profile(cx);
588 }
589 }
590 }
591 }
592}
593
/// Summary entry for one stored thread, as produced by `ThreadsDatabase::list_threads`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    // Used to order threads in `ThreadStore::reverse_chronological_threads`.
    pub updated_at: DateTime<Utc>,
}
600
/// On-disk representation of a thread at format version [`SerializedThread::VERSION`].
///
/// Stored as JSON (see the `heed::BytesEncode` impl). Read it through
/// [`SerializedThread::from_json`], which also upgrades older formats.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    // The `#[serde(default)]` fields may be missing from older stored threads.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
}
618
619impl SerializedThread {
620 pub const VERSION: &'static str = "0.2.0";
621
622 pub fn from_json(json: &[u8]) -> Result<Self> {
623 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
624 match saved_thread_json.get("version") {
625 Some(serde_json::Value::String(version)) => match version.as_str() {
626 SerializedThreadV0_1_0::VERSION => {
627 let saved_thread =
628 serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
629 Ok(saved_thread.upgrade())
630 }
631 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
632 saved_thread_json,
633 )?),
634 _ => Err(anyhow!(
635 "unrecognized serialized thread version: {}",
636 version
637 )),
638 },
639 None => {
640 let saved_thread =
641 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
642 Ok(saved_thread.upgrade())
643 }
644 version => Err(anyhow!(
645 "unrecognized serialized thread version: {:?}",
646 version
647 )),
648 }
649 }
650}
651
/// Thread format 0.1.0. Same fields as the current format; only where tool results are
/// attached differs (see `SerializedThreadV0_1_0::upgrade`).
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);
658
659impl SerializedThreadV0_1_0 {
660 pub const VERSION: &'static str = "0.1.0";
661
662 pub fn upgrade(self) -> SerializedThread {
663 debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
664
665 let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
666
667 for message in self.0.messages {
668 if message.role == Role::User && !message.tool_results.is_empty() {
669 if let Some(last_message) = messages.last_mut() {
670 debug_assert!(last_message.role == Role::Assistant);
671
672 last_message.tool_results = message.tool_results;
673 continue;
674 }
675 }
676
677 messages.push(message);
678 }
679
680 SerializedThread { messages, ..self.0 }
681 }
682}
683
/// One stored message in a [`SerializedThread`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    // The `#[serde(default)]` fields fall back to empty values when absent from stored JSON.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    // Since format 0.2.0 these live on the assistant message (see `SerializedThreadV0_1_0::upgrade`).
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
}
697
/// A piece of message content, internally tagged by a `"type"` field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // NOTE(review): unlike the other variants this has no `rename`, so its tag is the variant
    // name "RedactedThinking". Changing it now would break reading already-stored threads.
    RedactedThinking {
        data: Vec<u8>,
    },
}
715
/// A stored tool invocation: the tool's name and its JSON input.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
722
/// A stored tool result, linked to its [`SerializedToolUse`] by `tool_use_id`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: Arc<str>,
}
729
/// Unversioned thread format (JSON with no `"version"` key). Messages hold flat text rather
/// than segments.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
738
739impl LegacySerializedThread {
740 pub fn upgrade(self) -> SerializedThread {
741 SerializedThread {
742 version: SerializedThread::VERSION.to_string(),
743 summary: self.summary,
744 updated_at: self.updated_at,
745 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
746 initial_project_snapshot: self.initial_project_snapshot,
747 cumulative_token_usage: TokenUsage::default(),
748 request_token_usage: Vec::new(),
749 detailed_summary_state: DetailedSummaryState::default(),
750 exceeded_window_error: None,
751 }
752 }
753}
754
/// A message in the unversioned thread format, with its content as one flat `text` string.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
765
766impl LegacySerializedMessage {
767 fn upgrade(self) -> SerializedMessage {
768 SerializedMessage {
769 id: self.id,
770 role: self.role,
771 segments: vec![SerializedMessageSegment::Text { text: self.text }],
772 tool_uses: self.tool_uses,
773 tool_results: self.tool_results,
774 context: String::new(),
775 }
776 }
777}
778
/// gpui global holding the (possibly still pending) result of opening the threads database.
///
/// The future is `Shared` so every caller awaits the same open. The error is wrapped in an
/// `Arc` so the shared output can be cloned.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
784
/// LMDB-backed (via heed) store of serialized threads.
pub(crate) struct ThreadsDatabase {
    // Runs every database operation off the foreground thread.
    executor: BackgroundExecutor,
    env: heed::Env,
    // Keys are bincode-encoded `ThreadId`s; values use the JSON codec implemented below.
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
790
791impl heed::BytesEncode<'_> for SerializedThread {
792 type EItem = SerializedThread;
793
794 fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
795 serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
796 }
797}
798
799impl<'a> heed::BytesDecode<'a> for SerializedThread {
800 type DItem = SerializedThread;
801
802 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
803 // We implement this type manually because we want to call `SerializedThread::from_json`,
804 // instead of the Deserialize trait implementation for `SerializedThread`.
805 SerializedThread::from_json(bytes).map_err(Into::into)
806 }
807}
808
809impl ThreadsDatabase {
810 fn global_future(
811 cx: &mut App,
812 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
813 GlobalThreadsDatabase::global(cx).0.clone()
814 }
815
816 fn init(cx: &mut App) {
817 let executor = cx.background_executor().clone();
818 let database_future = executor
819 .spawn({
820 let executor = executor.clone();
821 let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
822 async move { ThreadsDatabase::new(database_path, executor) }
823 })
824 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
825 .boxed()
826 .shared();
827
828 cx.set_global(GlobalThreadsDatabase(database_future));
829 }
830
831 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
832 std::fs::create_dir_all(&path)?;
833
834 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
835 let env = unsafe {
836 heed::EnvOpenOptions::new()
837 .map_size(ONE_GB_IN_BYTES)
838 .max_dbs(1)
839 .open(path)?
840 };
841
842 let mut txn = env.write_txn()?;
843 let threads = env.create_database(&mut txn, Some("threads"))?;
844 txn.commit()?;
845
846 Ok(Self {
847 executor,
848 env,
849 threads,
850 })
851 }
852
853 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
854 let env = self.env.clone();
855 let threads = self.threads;
856
857 self.executor.spawn(async move {
858 let txn = env.read_txn()?;
859 let mut iter = threads.iter(&txn)?;
860 let mut threads = Vec::new();
861 while let Some((key, value)) = iter.next().transpose()? {
862 threads.push(SerializedThreadMetadata {
863 id: key,
864 summary: value.summary,
865 updated_at: value.updated_at,
866 });
867 }
868
869 Ok(threads)
870 })
871 }
872
873 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
874 let env = self.env.clone();
875 let threads = self.threads;
876
877 self.executor.spawn(async move {
878 let txn = env.read_txn()?;
879 let thread = threads.get(&txn, &id)?;
880 Ok(thread)
881 })
882 }
883
884 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
885 let env = self.env.clone();
886 let threads = self.threads;
887
888 self.executor.spawn(async move {
889 let mut txn = env.write_txn()?;
890 threads.put(&mut txn, &id, &thread)?;
891 txn.commit()?;
892 Ok(())
893 })
894 }
895
896 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
897 let env = self.env.clone();
898 let threads = self.threads;
899
900 self.executor.spawn(async move {
901 let mut txn = env.write_txn()?;
902 threads.delete(&mut txn, &id)?;
903 txn.commit()?;
904 Ok(())
905 })
906 }
907}