1use std::borrow::Cow;
2use std::cell::{Ref, RefCell};
3use std::path::{Path, PathBuf};
4use std::rc::Rc;
5use std::sync::Arc;
6
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use context_server::manager::ContextServerManager;
13use context_server::{ContextServerFactoryRegistry, ContextServerTool};
14use fs::Fs;
15use futures::channel::{mpsc, oneshot};
16use futures::future::{self, BoxFuture, Shared};
17use futures::{FutureExt as _, StreamExt as _};
18use gpui::{
19 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
20 Subscription, Task, prelude::*,
21};
22use heed::Database;
23use heed::types::SerdeBincode;
24use language_model::{LanguageModelToolUseId, Role, TokenUsage};
25use project::{Project, Worktree};
26use prompt_store::{
27 ProjectContext, PromptBuilder, PromptId, PromptMetadata, PromptStore, PromptsUpdatedEvent,
28 RulesFileContext, UserPromptId, UserRulesContext, WorktreeContext,
29};
30use serde::{Deserialize, Serialize};
31use settings::{Settings as _, SettingsStore};
32use util::ResultExt as _;
33
34use crate::thread::{
35 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
36};
37
/// Candidate rules-file paths, relative to a worktree root, in priority order.
///
/// When building the system prompt, the first of these that exists as a file in a worktree is
/// loaded; changes to any of these paths trigger a system-prompt reload.
const RULES_FILE_NAMES: [&str; 6] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
];
46
47pub fn init(cx: &mut App) {
48 ThreadsDatabase::init(cx);
49}
50
51/// A system prompt shared by all threads created by this ThreadStore
52#[derive(Clone, Default)]
53pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
54
55impl SharedProjectContext {
56 pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
57 self.0.borrow()
58 }
59}
60
/// Tracks saved thread metadata and owns the state shared by the threads it creates: the
/// project, the tool working set, the prompt builder, and the shared project context used to
/// build system prompts.
pub struct ThreadStore {
    project: Entity<Project>,
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    /// Source of default user rules; `None` if the global prompt store failed to load.
    prompt_store: Option<Entity<PromptStore>>,
    context_server_manager: Entity<ContextServerManager>,
    /// Tool ids registered for each started context server, keyed by server id, so they can be
    /// removed again when the server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Metadata for saved threads, as last loaded from the database (not sorted here;
    /// `threads()` returns a sorted copy).
    threads: Vec<SerializedThreadMetadata>,
    project_context: SharedProjectContext,
    /// Requests a reload from the background system-prompt task.
    reload_system_prompt_tx: mpsc::Sender<()>,
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}
74
/// Event emitted by [`ThreadStore`] when a worktree rules file or a default user rule fails to
/// load while building the system prompt.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}
80
81impl ThreadStore {
82 pub fn load(
83 project: Entity<Project>,
84 tools: Entity<ToolWorkingSet>,
85 prompt_builder: Arc<PromptBuilder>,
86 cx: &mut App,
87 ) -> Task<Result<Entity<Self>>> {
88 let prompt_store = PromptStore::global(cx);
89 cx.spawn(async move |cx| {
90 let prompt_store = prompt_store.await.ok();
91 let (thread_store, ready_rx) = cx.update(|cx| {
92 let mut option_ready_rx = None;
93 let thread_store = cx.new(|cx| {
94 let (thread_store, ready_rx) =
95 Self::new(project, tools, prompt_builder, prompt_store, cx);
96 option_ready_rx = Some(ready_rx);
97 thread_store
98 });
99 (thread_store, option_ready_rx.take().unwrap())
100 })?;
101 ready_rx.await?;
102 Ok(thread_store)
103 })
104 }
105
106 fn new(
107 project: Entity<Project>,
108 tools: Entity<ToolWorkingSet>,
109 prompt_builder: Arc<PromptBuilder>,
110 prompt_store: Option<Entity<PromptStore>>,
111 cx: &mut Context<Self>,
112 ) -> (Self, oneshot::Receiver<()>) {
113 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
114 let context_server_manager = cx.new(|cx| {
115 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
116 });
117
118 let mut subscriptions = vec![
119 cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
120 this.load_default_profile(cx);
121 }),
122 cx.subscribe(&project, Self::handle_project_event),
123 ];
124
125 if let Some(prompt_store) = prompt_store.as_ref() {
126 subscriptions.push(cx.subscribe(
127 prompt_store,
128 |this, _prompt_store, PromptsUpdatedEvent, _cx| {
129 this.enqueue_system_prompt_reload();
130 },
131 ))
132 }
133
134 // This channel and task prevent concurrent and redundant loading of the system prompt.
135 let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
136 let (ready_tx, ready_rx) = oneshot::channel();
137 let mut ready_tx = Some(ready_tx);
138 let reload_system_prompt_task = cx.spawn({
139 let prompt_store = prompt_store.clone();
140 async move |thread_store, cx| {
141 loop {
142 let Some(reload_task) = thread_store
143 .update(cx, |thread_store, cx| {
144 thread_store.reload_system_prompt(prompt_store.clone(), cx)
145 })
146 .ok()
147 else {
148 return;
149 };
150 reload_task.await;
151 if let Some(ready_tx) = ready_tx.take() {
152 ready_tx.send(()).ok();
153 }
154 reload_system_prompt_rx.next().await;
155 }
156 }
157 });
158
159 let this = Self {
160 project,
161 tools,
162 prompt_builder,
163 prompt_store,
164 context_server_manager,
165 context_server_tool_ids: HashMap::default(),
166 threads: Vec::new(),
167 project_context: SharedProjectContext::default(),
168 reload_system_prompt_tx,
169 _reload_system_prompt_task: reload_system_prompt_task,
170 _subscriptions: subscriptions,
171 };
172 this.load_default_profile(cx);
173 this.register_context_server_handlers(cx);
174 this.reload(cx).detach_and_log_err(cx);
175 (this, ready_rx)
176 }
177
178 fn handle_project_event(
179 &mut self,
180 _project: Entity<Project>,
181 event: &project::Event,
182 _cx: &mut Context<Self>,
183 ) {
184 match event {
185 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
186 self.enqueue_system_prompt_reload();
187 }
188 project::Event::WorktreeUpdatedEntries(_, items) => {
189 if items.iter().any(|(path, _, _)| {
190 RULES_FILE_NAMES
191 .iter()
192 .any(|name| path.as_ref() == Path::new(name))
193 }) {
194 self.enqueue_system_prompt_reload();
195 }
196 }
197 _ => {}
198 }
199 }
200
201 fn enqueue_system_prompt_reload(&mut self) {
202 self.reload_system_prompt_tx.try_send(()).ok();
203 }
204
205 // Note that this should only be called from `reload_system_prompt_task`.
206 fn reload_system_prompt(
207 &self,
208 prompt_store: Option<Entity<PromptStore>>,
209 cx: &mut Context<Self>,
210 ) -> Task<()> {
211 let project = self.project.read(cx);
212 let worktree_tasks = project
213 .visible_worktrees(cx)
214 .map(|worktree| {
215 Self::load_worktree_info_for_system_prompt(
216 project.fs().clone(),
217 worktree.read(cx),
218 cx,
219 )
220 })
221 .collect::<Vec<_>>();
222 let default_user_rules_task = match prompt_store {
223 None => Task::ready(vec![]),
224 Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
225 let prompts = prompt_store.default_prompt_metadata();
226 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
227 let contents = prompt_store.load(prompt_metadata.id, cx);
228 async move { (contents.await, prompt_metadata) }
229 });
230 cx.background_spawn(future::join_all(load_tasks))
231 }),
232 };
233
234 cx.spawn(async move |this, cx| {
235 let (worktrees, default_user_rules) =
236 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
237
238 let worktrees = worktrees
239 .into_iter()
240 .map(|(worktree, rules_error)| {
241 if let Some(rules_error) = rules_error {
242 this.update(cx, |_, cx| cx.emit(rules_error)).ok();
243 }
244 worktree
245 })
246 .collect::<Vec<_>>();
247
248 let default_user_rules = default_user_rules
249 .into_iter()
250 .flat_map(|(contents, prompt_metadata)| match contents {
251 Ok(contents) => Some(UserRulesContext {
252 uuid: match prompt_metadata.id {
253 PromptId::User { uuid } => uuid,
254 PromptId::EditWorkflow => return None,
255 },
256 title: prompt_metadata.title.map(|title| title.to_string()),
257 contents,
258 }),
259 Err(err) => {
260 this.update(cx, |_, cx| {
261 cx.emit(RulesLoadingError {
262 message: format!("{err:?}").into(),
263 });
264 })
265 .ok();
266 None
267 }
268 })
269 .collect::<Vec<_>>();
270
271 this.update(cx, |this, _cx| {
272 *this.project_context.0.borrow_mut() =
273 Some(ProjectContext::new(worktrees, default_user_rules));
274 })
275 .ok();
276 })
277 }
278
279 fn load_worktree_info_for_system_prompt(
280 fs: Arc<dyn Fs>,
281 worktree: &Worktree,
282 cx: &App,
283 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
284 let root_name = worktree.root_name().into();
285
286 let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
287 let Some(rules_task) = rules_task else {
288 return Task::ready((
289 WorktreeContext {
290 root_name,
291 rules_file: None,
292 },
293 None,
294 ));
295 };
296
297 cx.spawn(async move |_| {
298 let (rules_file, rules_file_error) = match rules_task.await {
299 Ok(rules_file) => (Some(rules_file), None),
300 Err(err) => (
301 None,
302 Some(RulesLoadingError {
303 message: format!("{err}").into(),
304 }),
305 ),
306 };
307 let worktree_info = WorktreeContext {
308 root_name,
309 rules_file,
310 };
311 (worktree_info, rules_file_error)
312 })
313 }
314
315 fn load_worktree_rules_file(
316 fs: Arc<dyn Fs>,
317 worktree: &Worktree,
318 cx: &App,
319 ) -> Option<Task<Result<RulesFileContext>>> {
320 let selected_rules_file = RULES_FILE_NAMES
321 .into_iter()
322 .filter_map(|name| {
323 worktree
324 .entry_for_path(name)
325 .filter(|entry| entry.is_file())
326 .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
327 })
328 .next();
329
330 // Note that Cline supports `.clinerules` being a directory, but that is not currently
331 // supported. This doesn't seem to occur often in GitHub repositories.
332 selected_rules_file.map(|(path_in_worktree, abs_path)| {
333 let fs = fs.clone();
334 cx.background_spawn(async move {
335 let abs_path = abs_path?;
336 let text = fs.load(&abs_path).await.with_context(|| {
337 format!("Failed to load assistant rules file {:?}", abs_path)
338 })?;
339 anyhow::Ok(RulesFileContext {
340 path_in_worktree,
341 abs_path: abs_path.into(),
342 text: text.trim().to_string(),
343 })
344 })
345 })
346 }
347
348 pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
349 self.context_server_manager.clone()
350 }
351
352 pub fn prompt_store(&self) -> Option<Entity<PromptStore>> {
353 self.prompt_store.clone()
354 }
355
356 pub fn load_rules(
357 &self,
358 prompt_id: UserPromptId,
359 cx: &App,
360 ) -> Task<Result<(PromptMetadata, String)>> {
361 let prompt_id = PromptId::User { uuid: prompt_id };
362 let Some(prompt_store) = self.prompt_store.as_ref() else {
363 return Task::ready(Err(anyhow!("Prompt store unexpectedly missing.")));
364 };
365 let prompt_store = prompt_store.read(cx);
366 let Some(metadata) = prompt_store.metadata(prompt_id) else {
367 return Task::ready(Err(anyhow!("User rules not found in library.")));
368 };
369 let text_task = prompt_store.load(prompt_id, cx);
370 cx.background_spawn(async move { Ok((metadata, text_task.await?)) })
371 }
372
373 pub fn tools(&self) -> Entity<ToolWorkingSet> {
374 self.tools.clone()
375 }
376
377 /// Returns the number of threads.
378 pub fn thread_count(&self) -> usize {
379 self.threads.len()
380 }
381
382 pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
383 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
384 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
385 threads
386 }
387
388 pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
389 self.threads().into_iter().take(limit).collect()
390 }
391
392 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
393 cx.new(|cx| {
394 Thread::new(
395 self.project.clone(),
396 self.tools.clone(),
397 self.prompt_builder.clone(),
398 self.project_context.clone(),
399 cx,
400 )
401 })
402 }
403
404 pub fn open_thread(
405 &self,
406 id: &ThreadId,
407 cx: &mut Context<Self>,
408 ) -> Task<Result<Entity<Thread>>> {
409 let id = id.clone();
410 let database_future = ThreadsDatabase::global_future(cx);
411 cx.spawn(async move |this, cx| {
412 let database = database_future.await.map_err(|err| anyhow!(err))?;
413 let thread = database
414 .try_find_thread(id.clone())
415 .await?
416 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
417
418 let thread = this.update(cx, |this, cx| {
419 cx.new(|cx| {
420 Thread::deserialize(
421 id.clone(),
422 thread,
423 this.project.clone(),
424 this.tools.clone(),
425 this.prompt_builder.clone(),
426 this.project_context.clone(),
427 cx,
428 )
429 })
430 })?;
431
432 Ok(thread)
433 })
434 }
435
436 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
437 let (metadata, serialized_thread) =
438 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
439
440 let database_future = ThreadsDatabase::global_future(cx);
441 cx.spawn(async move |this, cx| {
442 let serialized_thread = serialized_thread.await?;
443 let database = database_future.await.map_err(|err| anyhow!(err))?;
444 database.save_thread(metadata, serialized_thread).await?;
445
446 this.update(cx, |this, cx| this.reload(cx))?.await
447 })
448 }
449
450 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
451 let id = id.clone();
452 let database_future = ThreadsDatabase::global_future(cx);
453 cx.spawn(async move |this, cx| {
454 let database = database_future.await.map_err(|err| anyhow!(err))?;
455 database.delete_thread(id.clone()).await?;
456
457 this.update(cx, |this, cx| {
458 this.threads.retain(|thread| thread.id != id);
459 cx.notify();
460 })
461 })
462 }
463
464 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
465 let database_future = ThreadsDatabase::global_future(cx);
466 cx.spawn(async move |this, cx| {
467 let threads = database_future
468 .await
469 .map_err(|err| anyhow!(err))?
470 .list_threads()
471 .await?;
472
473 this.update(cx, |this, cx| {
474 this.threads = threads;
475 cx.notify();
476 })
477 })
478 }
479
480 fn load_default_profile(&self, cx: &mut Context<Self>) {
481 let assistant_settings = AssistantSettings::get_global(cx);
482
483 self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
484 }
485
486 pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
487 let assistant_settings = AssistantSettings::get_global(cx);
488
489 if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
490 self.load_profile(profile.clone(), cx);
491 }
492 }
493
494 pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
495 self.tools.update(cx, |tools, cx| {
496 tools.disable_all_tools(cx);
497 tools.enable(
498 ToolSource::Native,
499 &profile
500 .tools
501 .iter()
502 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
503 .collect::<Vec<_>>(),
504 cx,
505 );
506 });
507
508 if profile.enable_all_context_servers {
509 for context_server in self.context_server_manager.read(cx).all_servers() {
510 self.tools.update(cx, |tools, cx| {
511 tools.enable_source(
512 ToolSource::ContextServer {
513 id: context_server.id().into(),
514 },
515 cx,
516 );
517 });
518 }
519 } else {
520 for (context_server_id, preset) in &profile.context_servers {
521 self.tools.update(cx, |tools, cx| {
522 tools.enable(
523 ToolSource::ContextServer {
524 id: context_server_id.clone().into(),
525 },
526 &preset
527 .tools
528 .iter()
529 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
530 .collect::<Vec<_>>(),
531 cx,
532 )
533 })
534 }
535 }
536 }
537
538 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
539 cx.subscribe(
540 &self.context_server_manager.clone(),
541 Self::handle_context_server_event,
542 )
543 .detach();
544 }
545
546 fn handle_context_server_event(
547 &mut self,
548 context_server_manager: Entity<ContextServerManager>,
549 event: &context_server::manager::Event,
550 cx: &mut Context<Self>,
551 ) {
552 let tool_working_set = self.tools.clone();
553 match event {
554 context_server::manager::Event::ServerStarted { server_id } => {
555 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
556 let context_server_manager = context_server_manager.clone();
557 cx.spawn({
558 let server = server.clone();
559 let server_id = server_id.clone();
560 async move |this, cx| {
561 let Some(protocol) = server.client() else {
562 return;
563 };
564
565 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
566 if let Some(tools) = protocol.list_tools().await.log_err() {
567 let tool_ids = tool_working_set
568 .update(cx, |tool_working_set, _| {
569 tools
570 .tools
571 .into_iter()
572 .map(|tool| {
573 log::info!(
574 "registering context server tool: {:?}",
575 tool.name
576 );
577 tool_working_set.insert(Arc::new(
578 ContextServerTool::new(
579 context_server_manager.clone(),
580 server.id(),
581 tool,
582 ),
583 ))
584 })
585 .collect::<Vec<_>>()
586 })
587 .log_err();
588
589 if let Some(tool_ids) = tool_ids {
590 this.update(cx, |this, cx| {
591 this.context_server_tool_ids
592 .insert(server_id, tool_ids);
593 this.load_default_profile(cx);
594 })
595 .log_err();
596 }
597 }
598 }
599 }
600 })
601 .detach();
602 }
603 }
604 context_server::manager::Event::ServerStopped { server_id } => {
605 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
606 tool_working_set.update(cx, |tool_working_set, _| {
607 tool_working_set.remove(&tool_ids);
608 });
609 self.load_default_profile(cx);
610 }
611 }
612 }
613 }
614}
615
/// Summary information about a saved thread, as listed from the threads database.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Database key of the thread.
    pub id: ThreadId,
    pub summary: SharedString,
    /// Last update time; `ThreadStore::threads` orders threads by this, newest first.
    pub updated_at: DateTime<Utc>,
}
622
/// Stored representation of a thread; encoded as JSON in the threads database.
///
/// Fields marked `#[serde(default)]` take their default value when absent from the stored
/// JSON (presumably because they were added after older threads were written).
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    /// Format version string; the current one is `SerializedThread::VERSION`.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
}
640
641impl SerializedThread {
642 pub const VERSION: &'static str = "0.2.0";
643
644 pub fn from_json(json: &[u8]) -> Result<Self> {
645 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
646 match saved_thread_json.get("version") {
647 Some(serde_json::Value::String(version)) => match version.as_str() {
648 SerializedThreadV0_1_0::VERSION => {
649 let saved_thread =
650 serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
651 Ok(saved_thread.upgrade())
652 }
653 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
654 saved_thread_json,
655 )?),
656 _ => Err(anyhow!(
657 "unrecognized serialized thread version: {}",
658 version
659 )),
660 },
661 None => {
662 let saved_thread =
663 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
664 Ok(saved_thread.upgrade())
665 }
666 version => Err(anyhow!(
667 "unrecognized serialized thread version: {:?}",
668 version
669 )),
670 }
671 }
672}
673
674#[derive(Serialize, Deserialize, Debug)]
675pub struct SerializedThreadV0_1_0(
676 // The structure did not change, so we are reusing the latest SerializedThread.
677 // When making the next version, make sure this points to SerializedThreadV0_2_0
678 SerializedThread,
679);
680
681impl SerializedThreadV0_1_0 {
682 pub const VERSION: &'static str = "0.1.0";
683
684 pub fn upgrade(self) -> SerializedThread {
685 debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
686
687 let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
688
689 for message in self.0.messages {
690 if message.role == Role::User && !message.tool_results.is_empty() {
691 if let Some(last_message) = messages.last_mut() {
692 debug_assert!(last_message.role == Role::Assistant);
693
694 last_message.tool_results = message.tool_results;
695 continue;
696 }
697 }
698
699 messages.push(message);
700 }
701
702 SerializedThread { messages, ..self.0 }
703 }
704}
705
/// A single message within a [`SerializedThread`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content of the message.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Results for tool uses; after the `0.1.0` upgrade these live on the message that made
    /// the tool uses.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
}
719
/// One piece of message content, serialized as an object tagged by its `type` field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Plain text, tagged `"text"`.
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    /// Model thinking output, tagged `"thinking"`; `signature` is omitted when `None`.
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // NOTE(review): unlike the other variants this has no `rename`, so its tag is serialized
    // as "RedactedThinking". Renaming it now would break deserialization of saved threads.
    RedactedThinking {
        data: Vec<u8>,
    },
}
737
/// A tool invocation requested within a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    /// Name of the invoked tool.
    pub name: SharedString,
    /// Tool input as arbitrary JSON.
    pub input: serde_json::Value,
}
744
/// The outcome of a tool use, linked to it by `tool_use_id`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: Arc<str>,
}
751
/// Thread format used before the `version` field existed; messages hold plain text instead of
/// segments.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
760
761impl LegacySerializedThread {
762 pub fn upgrade(self) -> SerializedThread {
763 SerializedThread {
764 version: SerializedThread::VERSION.to_string(),
765 summary: self.summary,
766 updated_at: self.updated_at,
767 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
768 initial_project_snapshot: self.initial_project_snapshot,
769 cumulative_token_usage: TokenUsage::default(),
770 request_token_usage: Vec::new(),
771 detailed_summary_state: DetailedSummaryState::default(),
772 exceeded_window_error: None,
773 }
774 }
775}
776
/// A message in the legacy (pre-versioning) thread format.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Entire message content as plain text.
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
787
788impl LegacySerializedMessage {
789 fn upgrade(self) -> SerializedMessage {
790 SerializedMessage {
791 id: self.id,
792 role: self.role,
793 segments: vec![SerializedMessageSegment::Text { text: self.text }],
794 tool_uses: self.tool_uses,
795 tool_results: self.tool_results,
796 context: String::new(),
797 }
798 }
799}
800
/// Global holding the shared future that opens the thread database.
///
/// Both the database and the error are wrapped in `Arc` so the output is `Clone`, which
/// `Shared` requires of its future's output.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
806
/// Persistent thread storage backed by a heed (LMDB) environment.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every database transaction is run.
    executor: BackgroundExecutor,
    env: heed::Env,
    /// Threads keyed by id. Keys are bincode-encoded; values use the JSON codec implemented
    /// on `SerializedThread`.
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
812
813impl heed::BytesEncode<'_> for SerializedThread {
814 type EItem = SerializedThread;
815
816 fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
817 serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
818 }
819}
820
821impl<'a> heed::BytesDecode<'a> for SerializedThread {
822 type DItem = SerializedThread;
823
824 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
825 // We implement this type manually because we want to call `SerializedThread::from_json`,
826 // instead of the Deserialize trait implementation for `SerializedThread`.
827 SerializedThread::from_json(bytes).map_err(Into::into)
828 }
829}
830
831impl ThreadsDatabase {
832 fn global_future(
833 cx: &mut App,
834 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
835 GlobalThreadsDatabase::global(cx).0.clone()
836 }
837
838 fn init(cx: &mut App) {
839 let executor = cx.background_executor().clone();
840 let database_future = executor
841 .spawn({
842 let executor = executor.clone();
843 let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
844 async move { ThreadsDatabase::new(database_path, executor) }
845 })
846 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
847 .boxed()
848 .shared();
849
850 cx.set_global(GlobalThreadsDatabase(database_future));
851 }
852
853 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
854 std::fs::create_dir_all(&path)?;
855
856 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
857 let env = unsafe {
858 heed::EnvOpenOptions::new()
859 .map_size(ONE_GB_IN_BYTES)
860 .max_dbs(1)
861 .open(path)?
862 };
863
864 let mut txn = env.write_txn()?;
865 let threads = env.create_database(&mut txn, Some("threads"))?;
866 txn.commit()?;
867
868 Ok(Self {
869 executor,
870 env,
871 threads,
872 })
873 }
874
875 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
876 let env = self.env.clone();
877 let threads = self.threads;
878
879 self.executor.spawn(async move {
880 let txn = env.read_txn()?;
881 let mut iter = threads.iter(&txn)?;
882 let mut threads = Vec::new();
883 while let Some((key, value)) = iter.next().transpose()? {
884 threads.push(SerializedThreadMetadata {
885 id: key,
886 summary: value.summary,
887 updated_at: value.updated_at,
888 });
889 }
890
891 Ok(threads)
892 })
893 }
894
895 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
896 let env = self.env.clone();
897 let threads = self.threads;
898
899 self.executor.spawn(async move {
900 let txn = env.read_txn()?;
901 let thread = threads.get(&txn, &id)?;
902 Ok(thread)
903 })
904 }
905
906 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
907 let env = self.env.clone();
908 let threads = self.threads;
909
910 self.executor.spawn(async move {
911 let mut txn = env.write_txn()?;
912 threads.put(&mut txn, &id, &thread)?;
913 txn.commit()?;
914 Ok(())
915 })
916 }
917
918 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
919 let env = self.env.clone();
920 let threads = self.threads;
921
922 self.executor.spawn(async move {
923 let mut txn = env.write_txn()?;
924 threads.delete(&mut txn, &id)?;
925 txn.commit()?;
926 Ok(())
927 })
928 }
929}