1use std::borrow::Cow;
2use std::cell::{Ref, RefCell};
3use std::path::{Path, PathBuf};
4use std::rc::Rc;
5use std::sync::Arc;
6
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use context_server::manager::ContextServerManager;
13use context_server::{ContextServerFactoryRegistry, ContextServerTool};
14use fs::Fs;
15use futures::channel::{mpsc, oneshot};
16use futures::future::{self, BoxFuture, Shared};
17use futures::{FutureExt as _, StreamExt as _};
18use gpui::{
19 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
20 Subscription, Task, prelude::*,
21};
22use heed::Database;
23use heed::types::SerdeBincode;
24use language_model::{LanguageModelToolUseId, Role, TokenUsage};
25use project::{Project, Worktree};
26use prompt_store::{
27 ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
28 UserRulesContext, WorktreeContext,
29};
30use serde::{Deserialize, Serialize};
31use settings::{Settings as _, SettingsStore};
32use util::ResultExt as _;
33
34use crate::thread::{
35 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
36};
37
/// Candidate rules-file names, looked up relative to each visible worktree's root.
///
/// Order matters: the first name that exists as a *file* in a worktree is the one loaded
/// into the system prompt; any later matches are ignored.
const RULES_FILE_NAMES: [&str; 6] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
];
46
47pub fn init(cx: &mut App) {
48 ThreadsDatabase::init(cx);
49}
50
51/// A system prompt shared by all threads created by this ThreadStore
52#[derive(Clone, Default)]
53pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
54
55impl SharedProjectContext {
56 pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
57 self.0.borrow()
58 }
59}
60
/// Tracks saved threads and the system-prompt context shared by the threads it creates.
///
/// It also keeps the tool working set in sync with the active agent profile and with
/// context servers as they start and stop.
pub struct ThreadStore {
    project: Entity<Project>,
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    prompt_store: Option<Entity<PromptStore>>,
    context_server_manager: Entity<ContextServerManager>,
    /// Tool IDs registered for each context server, removed again when the server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Metadata of all saved threads, as last loaded from the database by `reload`.
    threads: Vec<SerializedThreadMetadata>,
    /// System-prompt context handed to every thread this store creates or opens.
    project_context: SharedProjectContext,
    /// Wakes `_reload_system_prompt_task` so it rebuilds `project_context`.
    reload_system_prompt_tx: mpsc::Sender<()>,
    // Held (never read) to tie the reload loop to the store's lifetime — presumably dropping
    // a gpui `Task` cancels it; confirm before reordering fields (field order = drop order).
    _reload_system_prompt_task: Task<()>,
    // Held (never read) so these subscriptions stay active while the store exists.
    _subscriptions: Vec<Subscription>,
}
74
75pub struct RulesLoadingError {
76 pub message: SharedString,
77}
78
79impl EventEmitter<RulesLoadingError> for ThreadStore {}
80
81impl ThreadStore {
82 pub fn load(
83 project: Entity<Project>,
84 tools: Entity<ToolWorkingSet>,
85 prompt_store: Option<Entity<PromptStore>>,
86 prompt_builder: Arc<PromptBuilder>,
87 cx: &mut App,
88 ) -> Task<Result<Entity<Self>>> {
89 cx.spawn(async move |cx| {
90 let (thread_store, ready_rx) = cx.update(|cx| {
91 let mut option_ready_rx = None;
92 let thread_store = cx.new(|cx| {
93 let (thread_store, ready_rx) =
94 Self::new(project, tools, prompt_builder, prompt_store, cx);
95 option_ready_rx = Some(ready_rx);
96 thread_store
97 });
98 (thread_store, option_ready_rx.take().unwrap())
99 })?;
100 ready_rx.await?;
101 Ok(thread_store)
102 })
103 }
104
105 fn new(
106 project: Entity<Project>,
107 tools: Entity<ToolWorkingSet>,
108 prompt_builder: Arc<PromptBuilder>,
109 prompt_store: Option<Entity<PromptStore>>,
110 cx: &mut Context<Self>,
111 ) -> (Self, oneshot::Receiver<()>) {
112 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
113 let context_server_manager = cx.new(|cx| {
114 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
115 });
116
117 let mut subscriptions = vec![
118 cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
119 this.load_default_profile(cx);
120 }),
121 cx.subscribe(&project, Self::handle_project_event),
122 ];
123
124 if let Some(prompt_store) = prompt_store.as_ref() {
125 subscriptions.push(cx.subscribe(
126 prompt_store,
127 |this, _prompt_store, PromptsUpdatedEvent, _cx| {
128 this.enqueue_system_prompt_reload();
129 },
130 ))
131 }
132
133 // This channel and task prevent concurrent and redundant loading of the system prompt.
134 let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
135 let (ready_tx, ready_rx) = oneshot::channel();
136 let mut ready_tx = Some(ready_tx);
137 let reload_system_prompt_task = cx.spawn({
138 let prompt_store = prompt_store.clone();
139 async move |thread_store, cx| {
140 loop {
141 let Some(reload_task) = thread_store
142 .update(cx, |thread_store, cx| {
143 thread_store.reload_system_prompt(prompt_store.clone(), cx)
144 })
145 .ok()
146 else {
147 return;
148 };
149 reload_task.await;
150 if let Some(ready_tx) = ready_tx.take() {
151 ready_tx.send(()).ok();
152 }
153 reload_system_prompt_rx.next().await;
154 }
155 }
156 });
157
158 let this = Self {
159 project,
160 tools,
161 prompt_builder,
162 prompt_store,
163 context_server_manager,
164 context_server_tool_ids: HashMap::default(),
165 threads: Vec::new(),
166 project_context: SharedProjectContext::default(),
167 reload_system_prompt_tx,
168 _reload_system_prompt_task: reload_system_prompt_task,
169 _subscriptions: subscriptions,
170 };
171 this.load_default_profile(cx);
172 this.register_context_server_handlers(cx);
173 this.reload(cx).detach_and_log_err(cx);
174 (this, ready_rx)
175 }
176
177 fn handle_project_event(
178 &mut self,
179 _project: Entity<Project>,
180 event: &project::Event,
181 _cx: &mut Context<Self>,
182 ) {
183 match event {
184 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
185 self.enqueue_system_prompt_reload();
186 }
187 project::Event::WorktreeUpdatedEntries(_, items) => {
188 if items.iter().any(|(path, _, _)| {
189 RULES_FILE_NAMES
190 .iter()
191 .any(|name| path.as_ref() == Path::new(name))
192 }) {
193 self.enqueue_system_prompt_reload();
194 }
195 }
196 _ => {}
197 }
198 }
199
200 fn enqueue_system_prompt_reload(&mut self) {
201 self.reload_system_prompt_tx.try_send(()).ok();
202 }
203
204 // Note that this should only be called from `reload_system_prompt_task`.
205 fn reload_system_prompt(
206 &self,
207 prompt_store: Option<Entity<PromptStore>>,
208 cx: &mut Context<Self>,
209 ) -> Task<()> {
210 let project = self.project.read(cx);
211 let worktree_tasks = project
212 .visible_worktrees(cx)
213 .map(|worktree| {
214 Self::load_worktree_info_for_system_prompt(
215 project.fs().clone(),
216 worktree.read(cx),
217 cx,
218 )
219 })
220 .collect::<Vec<_>>();
221 let default_user_rules_task = match prompt_store {
222 None => Task::ready(vec![]),
223 Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
224 let prompts = prompt_store.default_prompt_metadata();
225 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
226 let contents = prompt_store.load(prompt_metadata.id, cx);
227 async move { (contents.await, prompt_metadata) }
228 });
229 cx.background_spawn(future::join_all(load_tasks))
230 }),
231 };
232
233 cx.spawn(async move |this, cx| {
234 let (worktrees, default_user_rules) =
235 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
236
237 let worktrees = worktrees
238 .into_iter()
239 .map(|(worktree, rules_error)| {
240 if let Some(rules_error) = rules_error {
241 this.update(cx, |_, cx| cx.emit(rules_error)).ok();
242 }
243 worktree
244 })
245 .collect::<Vec<_>>();
246
247 let default_user_rules = default_user_rules
248 .into_iter()
249 .flat_map(|(contents, prompt_metadata)| match contents {
250 Ok(contents) => Some(UserRulesContext {
251 uuid: match prompt_metadata.id {
252 PromptId::User { uuid } => uuid,
253 PromptId::EditWorkflow => return None,
254 },
255 title: prompt_metadata.title.map(|title| title.to_string()),
256 contents,
257 }),
258 Err(err) => {
259 this.update(cx, |_, cx| {
260 cx.emit(RulesLoadingError {
261 message: format!("{err:?}").into(),
262 });
263 })
264 .ok();
265 None
266 }
267 })
268 .collect::<Vec<_>>();
269
270 this.update(cx, |this, _cx| {
271 *this.project_context.0.borrow_mut() =
272 Some(ProjectContext::new(worktrees, default_user_rules));
273 })
274 .ok();
275 })
276 }
277
278 fn load_worktree_info_for_system_prompt(
279 fs: Arc<dyn Fs>,
280 worktree: &Worktree,
281 cx: &App,
282 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
283 let root_name = worktree.root_name().into();
284
285 let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
286 let Some(rules_task) = rules_task else {
287 return Task::ready((
288 WorktreeContext {
289 root_name,
290 rules_file: None,
291 },
292 None,
293 ));
294 };
295
296 cx.spawn(async move |_| {
297 let (rules_file, rules_file_error) = match rules_task.await {
298 Ok(rules_file) => (Some(rules_file), None),
299 Err(err) => (
300 None,
301 Some(RulesLoadingError {
302 message: format!("{err}").into(),
303 }),
304 ),
305 };
306 let worktree_info = WorktreeContext {
307 root_name,
308 rules_file,
309 };
310 (worktree_info, rules_file_error)
311 })
312 }
313
314 fn load_worktree_rules_file(
315 fs: Arc<dyn Fs>,
316 worktree: &Worktree,
317 cx: &App,
318 ) -> Option<Task<Result<RulesFileContext>>> {
319 let selected_rules_file = RULES_FILE_NAMES
320 .into_iter()
321 .filter_map(|name| {
322 worktree
323 .entry_for_path(name)
324 .filter(|entry| entry.is_file())
325 .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
326 })
327 .next();
328
329 // Note that Cline supports `.clinerules` being a directory, but that is not currently
330 // supported. This doesn't seem to occur often in GitHub repositories.
331 selected_rules_file.map(|(path_in_worktree, abs_path)| {
332 let fs = fs.clone();
333 cx.background_spawn(async move {
334 let abs_path = abs_path?;
335 let text = fs.load(&abs_path).await.with_context(|| {
336 format!("Failed to load assistant rules file {:?}", abs_path)
337 })?;
338 anyhow::Ok(RulesFileContext {
339 path_in_worktree,
340 abs_path: abs_path.into(),
341 text: text.trim().to_string(),
342 })
343 })
344 })
345 }
346
347 pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
348 self.context_server_manager.clone()
349 }
350
351 pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
352 &self.prompt_store
353 }
354
355 pub fn tools(&self) -> Entity<ToolWorkingSet> {
356 self.tools.clone()
357 }
358
359 /// Returns the number of threads.
360 pub fn thread_count(&self) -> usize {
361 self.threads.len()
362 }
363
364 pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
365 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
366 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
367 threads
368 }
369
370 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
371 cx.new(|cx| {
372 Thread::new(
373 self.project.clone(),
374 self.tools.clone(),
375 self.prompt_builder.clone(),
376 self.project_context.clone(),
377 cx,
378 )
379 })
380 }
381
382 pub fn open_thread(
383 &self,
384 id: &ThreadId,
385 cx: &mut Context<Self>,
386 ) -> Task<Result<Entity<Thread>>> {
387 let id = id.clone();
388 let database_future = ThreadsDatabase::global_future(cx);
389 cx.spawn(async move |this, cx| {
390 let database = database_future.await.map_err(|err| anyhow!(err))?;
391 let thread = database
392 .try_find_thread(id.clone())
393 .await?
394 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
395
396 let thread = this.update(cx, |this, cx| {
397 cx.new(|cx| {
398 Thread::deserialize(
399 id.clone(),
400 thread,
401 this.project.clone(),
402 this.tools.clone(),
403 this.prompt_builder.clone(),
404 this.project_context.clone(),
405 cx,
406 )
407 })
408 })?;
409
410 Ok(thread)
411 })
412 }
413
414 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
415 let (metadata, serialized_thread) =
416 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
417
418 let database_future = ThreadsDatabase::global_future(cx);
419 cx.spawn(async move |this, cx| {
420 let serialized_thread = serialized_thread.await?;
421 let database = database_future.await.map_err(|err| anyhow!(err))?;
422 database.save_thread(metadata, serialized_thread).await?;
423
424 this.update(cx, |this, cx| this.reload(cx))?.await
425 })
426 }
427
428 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
429 let id = id.clone();
430 let database_future = ThreadsDatabase::global_future(cx);
431 cx.spawn(async move |this, cx| {
432 let database = database_future.await.map_err(|err| anyhow!(err))?;
433 database.delete_thread(id.clone()).await?;
434
435 this.update(cx, |this, cx| {
436 this.threads.retain(|thread| thread.id != id);
437 cx.notify();
438 })
439 })
440 }
441
442 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
443 let database_future = ThreadsDatabase::global_future(cx);
444 cx.spawn(async move |this, cx| {
445 let threads = database_future
446 .await
447 .map_err(|err| anyhow!(err))?
448 .list_threads()
449 .await?;
450
451 this.update(cx, |this, cx| {
452 this.threads = threads;
453 cx.notify();
454 })
455 })
456 }
457
458 fn load_default_profile(&self, cx: &mut Context<Self>) {
459 let assistant_settings = AssistantSettings::get_global(cx);
460
461 self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
462 }
463
464 pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
465 let assistant_settings = AssistantSettings::get_global(cx);
466
467 if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
468 self.load_profile(profile.clone(), cx);
469 }
470 }
471
472 pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
473 self.tools.update(cx, |tools, cx| {
474 tools.disable_all_tools(cx);
475 tools.enable(
476 ToolSource::Native,
477 &profile
478 .tools
479 .iter()
480 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
481 .collect::<Vec<_>>(),
482 cx,
483 );
484 });
485
486 if profile.enable_all_context_servers {
487 for context_server in self.context_server_manager.read(cx).all_servers() {
488 self.tools.update(cx, |tools, cx| {
489 tools.enable_source(
490 ToolSource::ContextServer {
491 id: context_server.id().into(),
492 },
493 cx,
494 );
495 });
496 }
497 // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
498 for (context_server_id, preset) in &profile.context_servers {
499 self.tools.update(cx, |tools, cx| {
500 tools.disable(
501 ToolSource::ContextServer {
502 id: context_server_id.clone().into(),
503 },
504 &preset
505 .tools
506 .iter()
507 .filter_map(|(tool, enabled)| (!enabled).then(|| tool.clone()))
508 .collect::<Vec<_>>(),
509 cx,
510 )
511 })
512 }
513 } else {
514 for (context_server_id, preset) in &profile.context_servers {
515 self.tools.update(cx, |tools, cx| {
516 tools.enable(
517 ToolSource::ContextServer {
518 id: context_server_id.clone().into(),
519 },
520 &preset
521 .tools
522 .iter()
523 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
524 .collect::<Vec<_>>(),
525 cx,
526 )
527 })
528 }
529 }
530 }
531
532 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
533 cx.subscribe(
534 &self.context_server_manager.clone(),
535 Self::handle_context_server_event,
536 )
537 .detach();
538 }
539
540 fn handle_context_server_event(
541 &mut self,
542 context_server_manager: Entity<ContextServerManager>,
543 event: &context_server::manager::Event,
544 cx: &mut Context<Self>,
545 ) {
546 let tool_working_set = self.tools.clone();
547 match event {
548 context_server::manager::Event::ServerStarted { server_id } => {
549 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
550 let context_server_manager = context_server_manager.clone();
551 cx.spawn({
552 let server = server.clone();
553 let server_id = server_id.clone();
554 async move |this, cx| {
555 let Some(protocol) = server.client() else {
556 return;
557 };
558
559 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
560 if let Some(tools) = protocol.list_tools().await.log_err() {
561 let tool_ids = tool_working_set
562 .update(cx, |tool_working_set, _| {
563 tools
564 .tools
565 .into_iter()
566 .map(|tool| {
567 log::info!(
568 "registering context server tool: {:?}",
569 tool.name
570 );
571 tool_working_set.insert(Arc::new(
572 ContextServerTool::new(
573 context_server_manager.clone(),
574 server.id(),
575 tool,
576 ),
577 ))
578 })
579 .collect::<Vec<_>>()
580 })
581 .log_err();
582
583 if let Some(tool_ids) = tool_ids {
584 this.update(cx, |this, cx| {
585 this.context_server_tool_ids
586 .insert(server_id, tool_ids);
587 this.load_default_profile(cx);
588 })
589 .log_err();
590 }
591 }
592 }
593 }
594 })
595 .detach();
596 }
597 }
598 context_server::manager::Event::ServerStopped { server_id } => {
599 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
600 tool_working_set.update(cx, |tool_working_set, _| {
601 tool_working_set.remove(&tool_ids);
602 });
603 self.load_default_profile(cx);
604 }
605 }
606 }
607 }
608}
609
/// Lightweight per-thread listing entry, built from the database by `list_threads`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Database key of the thread.
    pub id: ThreadId,
    /// The thread's stored summary text.
    pub summary: SharedString,
    /// Last-updated timestamp; `reverse_chronological_threads` sorts on it, newest first.
    pub updated_at: DateTime<Utc>,
}
616
/// Current persisted representation of a thread (format [`SerializedThread::VERSION`]).
///
/// Reads from the database go through [`SerializedThread::from_json`], which also accepts
/// older formats. Fields marked `#[serde(default)]` may be missing from older records and
/// then take their `Default` value.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    /// Format version string stored with the thread.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    // Presumably a snapshot of the project taken when the thread began — confirm in `Thread`.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    // NOTE(review): looks like one entry per model request — confirm against `Thread::serialize`.
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
}
634
635impl SerializedThread {
636 pub const VERSION: &'static str = "0.2.0";
637
638 pub fn from_json(json: &[u8]) -> Result<Self> {
639 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
640 match saved_thread_json.get("version") {
641 Some(serde_json::Value::String(version)) => match version.as_str() {
642 SerializedThreadV0_1_0::VERSION => {
643 let saved_thread =
644 serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
645 Ok(saved_thread.upgrade())
646 }
647 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
648 saved_thread_json,
649 )?),
650 _ => Err(anyhow!(
651 "unrecognized serialized thread version: {}",
652 version
653 )),
654 },
655 None => {
656 let saved_thread =
657 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
658 Ok(saved_thread.upgrade())
659 }
660 version => Err(anyhow!(
661 "unrecognized serialized thread version: {:?}",
662 version
663 )),
664 }
665 }
666}
667
/// Thread format 0.1.0. It has the same JSON shape as the current [`SerializedThread`]; what
/// differs is where tool results live (see [`SerializedThreadV0_1_0::upgrade`]).
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);
674
675impl SerializedThreadV0_1_0 {
676 pub const VERSION: &'static str = "0.1.0";
677
678 pub fn upgrade(self) -> SerializedThread {
679 debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
680
681 let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
682
683 for message in self.0.messages {
684 if message.role == Role::User && !message.tool_results.is_empty() {
685 if let Some(last_message) = messages.last_mut() {
686 debug_assert!(last_message.role == Role::Assistant);
687
688 last_message.tool_results = message.tool_results;
689 continue;
690 }
691 }
692
693 messages.push(message);
694 }
695
696 SerializedThread { messages, ..self.0 }
697 }
698}
699
/// A persisted message within a [`SerializedThread`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces; empty if missing from the stored data.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    // NOTE(review): presumably rendered context attached to the message — confirm in `Thread`.
    #[serde(default)]
    pub context: String,
}
713
/// One piece of a message's content, stored as an internally tagged object whose `"type"`
/// field selects the variant.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Plain text; tagged `"text"`.
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    /// Thinking text; tagged `"thinking"`. `signature` is left out of the output when `None`.
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // NOTE(review): unlike the variants above this one has no `rename`, so its tag is the
    // variant name "RedactedThinking". Renaming it now would break threads already on disk.
    /// Redacted thinking, kept as raw bytes.
    RedactedThinking {
        data: Vec<u8>,
    },
}
731
/// A persisted tool invocation: tool `name` plus its JSON `input`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
738
/// A persisted tool result, linked to its invocation by `tool_use_id`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether `content` describes a failure rather than a normal result.
    pub is_error: bool,
    pub content: Arc<str>,
}
745
/// Thread format that predates the `"version"` field. `SerializedThread::from_json` picks
/// it when that field is missing.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
754
755impl LegacySerializedThread {
756 pub fn upgrade(self) -> SerializedThread {
757 SerializedThread {
758 version: SerializedThread::VERSION.to_string(),
759 summary: self.summary,
760 updated_at: self.updated_at,
761 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
762 initial_project_snapshot: self.initial_project_snapshot,
763 cumulative_token_usage: TokenUsage::default(),
764 request_token_usage: Vec::new(),
765 detailed_summary_state: DetailedSummaryState::default(),
766 exceeded_window_error: None,
767 }
768 }
769}
770
/// Message shape used by [`LegacySerializedThread`]: one flat `text` field instead of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
781
782impl LegacySerializedMessage {
783 fn upgrade(self) -> SerializedMessage {
784 SerializedMessage {
785 id: self.id,
786 role: self.role,
787 segments: vec![SerializedMessageSegment::Text { text: self.text }],
788 tool_uses: self.tool_uses,
789 tool_results: self.tool_results,
790 context: String::new(),
791 }
792 }
793}
794
795struct GlobalThreadsDatabase(
796 Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
797);
798
799impl Global for GlobalThreadsDatabase {}
800
/// LMDB-backed storage of serialized threads.
pub(crate) struct ThreadsDatabase {
    /// Runs every transaction off the calling thread.
    executor: BackgroundExecutor,
    env: heed::Env,
    /// Keys are bincode-encoded `ThreadId`s; values are JSON via `SerializedThread`'s codec.
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
806
807impl heed::BytesEncode<'_> for SerializedThread {
808 type EItem = SerializedThread;
809
810 fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
811 serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
812 }
813}
814
815impl<'a> heed::BytesDecode<'a> for SerializedThread {
816 type DItem = SerializedThread;
817
818 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
819 // We implement this type manually because we want to call `SerializedThread::from_json`,
820 // instead of the Deserialize trait implementation for `SerializedThread`.
821 SerializedThread::from_json(bytes).map_err(Into::into)
822 }
823}
824
825impl ThreadsDatabase {
826 fn global_future(
827 cx: &mut App,
828 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
829 GlobalThreadsDatabase::global(cx).0.clone()
830 }
831
832 fn init(cx: &mut App) {
833 let executor = cx.background_executor().clone();
834 let database_future = executor
835 .spawn({
836 let executor = executor.clone();
837 let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
838 async move { ThreadsDatabase::new(database_path, executor) }
839 })
840 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
841 .boxed()
842 .shared();
843
844 cx.set_global(GlobalThreadsDatabase(database_future));
845 }
846
847 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
848 std::fs::create_dir_all(&path)?;
849
850 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
851 let env = unsafe {
852 heed::EnvOpenOptions::new()
853 .map_size(ONE_GB_IN_BYTES)
854 .max_dbs(1)
855 .open(path)?
856 };
857
858 let mut txn = env.write_txn()?;
859 let threads = env.create_database(&mut txn, Some("threads"))?;
860 txn.commit()?;
861
862 Ok(Self {
863 executor,
864 env,
865 threads,
866 })
867 }
868
869 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
870 let env = self.env.clone();
871 let threads = self.threads;
872
873 self.executor.spawn(async move {
874 let txn = env.read_txn()?;
875 let mut iter = threads.iter(&txn)?;
876 let mut threads = Vec::new();
877 while let Some((key, value)) = iter.next().transpose()? {
878 threads.push(SerializedThreadMetadata {
879 id: key,
880 summary: value.summary,
881 updated_at: value.updated_at,
882 });
883 }
884
885 Ok(threads)
886 })
887 }
888
889 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
890 let env = self.env.clone();
891 let threads = self.threads;
892
893 self.executor.spawn(async move {
894 let txn = env.read_txn()?;
895 let thread = threads.get(&txn, &id)?;
896 Ok(thread)
897 })
898 }
899
900 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
901 let env = self.env.clone();
902 let threads = self.threads;
903
904 self.executor.spawn(async move {
905 let mut txn = env.write_txn()?;
906 threads.put(&mut txn, &id, &thread)?;
907 txn.commit()?;
908 Ok(())
909 })
910 }
911
912 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
913 let env = self.env.clone();
914 let threads = self.threads;
915
916 self.executor.spawn(async move {
917 let mut txn = env.write_txn()?;
918 threads.delete(&mut txn, &id)?;
919 txn.commit()?;
920 Ok(())
921 })
922 }
923}