1use std::borrow::Cow;
2use std::cell::{Ref, RefCell};
3use std::path::{Path, PathBuf};
4use std::rc::Rc;
5use std::sync::Arc;
6
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use context_server::ContextServerId;
13use futures::channel::{mpsc, oneshot};
14use futures::future::{self, BoxFuture, Shared};
15use futures::{FutureExt as _, StreamExt as _};
16use gpui::{
17 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
18 Subscription, Task, prelude::*,
19};
20use heed::Database;
21use heed::types::SerdeBincode;
22use language_model::{LanguageModelToolUseId, Role, TokenUsage};
23use project::context_server_store::{ContextServerStatus, ContextServerStore};
24use project::{Project, ProjectItem, ProjectPath, Worktree};
25use prompt_store::{
26 ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
27 UserRulesContext, WorktreeContext,
28};
29use serde::{Deserialize, Serialize};
30use settings::{Settings as _, SettingsStore};
31use util::ResultExt as _;
32
33use crate::context_server_tool::ContextServerTool;
34use crate::thread::{
35 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
36};
37
/// Rules files looked up at the root of each visible worktree, in priority order: the first
/// entry that exists as a file is the one loaded into the project context. Worktree updates
/// that touch any of these paths trigger a project-context reload.
const RULES_FILE_NAMES: [&str; 6] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
];
46
47pub fn init(cx: &mut App) {
48 ThreadsDatabase::init(cx);
49}
50
51/// A system prompt shared by all threads created by this ThreadStore
52#[derive(Clone, Default)]
53pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
54
55impl SharedProjectContext {
56 pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
57 self.0.borrow()
58 }
59}
60
/// Owns the list of saved threads for a project and the state shared by the threads it
/// creates: the tool working set, the prompt builder, and the project context.
pub struct ThreadStore {
    // Project whose visible worktrees and context servers feed the project context and tools.
    project: Entity<Project>,
    // Tool set that agent profiles enable/disable and context servers register into.
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    // Optional source of default user rules; when present, its updates trigger a reload.
    prompt_store: Option<Entity<PromptStore>>,
    // Tool ids registered for each running context server, removed when it stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    // Cached metadata of saved threads, refreshed by `reload`.
    threads: Vec<SerializedThreadMetadata>,
    // Shared with every thread created or opened by this store.
    project_context: SharedProjectContext,
    // Wakes the reload task; sends are best-effort.
    reload_system_prompt_tx: mpsc::Sender<()>,
    // Kept alive so the reload loop runs for as long as the store exists.
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}
73
/// Event emitted by `ThreadStore` when a worktree rules file or a default user rule fails to
/// load while the project context is being rebuilt.
pub struct RulesLoadingError {
    // Human-readable description of the failure.
    pub message: SharedString,
}

impl EventEmitter<RulesLoadingError> for ThreadStore {}
79
80impl ThreadStore {
81 pub fn load(
82 project: Entity<Project>,
83 tools: Entity<ToolWorkingSet>,
84 prompt_store: Option<Entity<PromptStore>>,
85 prompt_builder: Arc<PromptBuilder>,
86 cx: &mut App,
87 ) -> Task<Result<Entity<Self>>> {
88 cx.spawn(async move |cx| {
89 let (thread_store, ready_rx) = cx.update(|cx| {
90 let mut option_ready_rx = None;
91 let thread_store = cx.new(|cx| {
92 let (thread_store, ready_rx) =
93 Self::new(project, tools, prompt_builder, prompt_store, cx);
94 option_ready_rx = Some(ready_rx);
95 thread_store
96 });
97 (thread_store, option_ready_rx.take().unwrap())
98 })?;
99 ready_rx.await?;
100 Ok(thread_store)
101 })
102 }
103
104 fn new(
105 project: Entity<Project>,
106 tools: Entity<ToolWorkingSet>,
107 prompt_builder: Arc<PromptBuilder>,
108 prompt_store: Option<Entity<PromptStore>>,
109 cx: &mut Context<Self>,
110 ) -> (Self, oneshot::Receiver<()>) {
111 let mut subscriptions = vec![
112 cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
113 this.load_default_profile(cx);
114 }),
115 cx.subscribe(&project, Self::handle_project_event),
116 ];
117
118 if let Some(prompt_store) = prompt_store.as_ref() {
119 subscriptions.push(cx.subscribe(
120 prompt_store,
121 |this, _prompt_store, PromptsUpdatedEvent, _cx| {
122 this.enqueue_system_prompt_reload();
123 },
124 ))
125 }
126
127 // This channel and task prevent concurrent and redundant loading of the system prompt.
128 let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
129 let (ready_tx, ready_rx) = oneshot::channel();
130 let mut ready_tx = Some(ready_tx);
131 let reload_system_prompt_task = cx.spawn({
132 let prompt_store = prompt_store.clone();
133 async move |thread_store, cx| {
134 loop {
135 let Some(reload_task) = thread_store
136 .update(cx, |thread_store, cx| {
137 thread_store.reload_system_prompt(prompt_store.clone(), cx)
138 })
139 .ok()
140 else {
141 return;
142 };
143 reload_task.await;
144 if let Some(ready_tx) = ready_tx.take() {
145 ready_tx.send(()).ok();
146 }
147 reload_system_prompt_rx.next().await;
148 }
149 }
150 });
151
152 let this = Self {
153 project,
154 tools,
155 prompt_builder,
156 prompt_store,
157 context_server_tool_ids: HashMap::default(),
158 threads: Vec::new(),
159 project_context: SharedProjectContext::default(),
160 reload_system_prompt_tx,
161 _reload_system_prompt_task: reload_system_prompt_task,
162 _subscriptions: subscriptions,
163 };
164 this.load_default_profile(cx);
165 this.register_context_server_handlers(cx);
166 this.reload(cx).detach_and_log_err(cx);
167 (this, ready_rx)
168 }
169
170 fn handle_project_event(
171 &mut self,
172 _project: Entity<Project>,
173 event: &project::Event,
174 _cx: &mut Context<Self>,
175 ) {
176 match event {
177 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
178 self.enqueue_system_prompt_reload();
179 }
180 project::Event::WorktreeUpdatedEntries(_, items) => {
181 if items.iter().any(|(path, _, _)| {
182 RULES_FILE_NAMES
183 .iter()
184 .any(|name| path.as_ref() == Path::new(name))
185 }) {
186 self.enqueue_system_prompt_reload();
187 }
188 }
189 _ => {}
190 }
191 }
192
193 fn enqueue_system_prompt_reload(&mut self) {
194 self.reload_system_prompt_tx.try_send(()).ok();
195 }
196
197 // Note that this should only be called from `reload_system_prompt_task`.
198 fn reload_system_prompt(
199 &self,
200 prompt_store: Option<Entity<PromptStore>>,
201 cx: &mut Context<Self>,
202 ) -> Task<()> {
203 let worktrees = self
204 .project
205 .read(cx)
206 .visible_worktrees(cx)
207 .collect::<Vec<_>>();
208 let worktree_tasks = worktrees
209 .into_iter()
210 .map(|worktree| {
211 Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
212 })
213 .collect::<Vec<_>>();
214 let default_user_rules_task = match prompt_store {
215 None => Task::ready(vec![]),
216 Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
217 let prompts = prompt_store.default_prompt_metadata();
218 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
219 let contents = prompt_store.load(prompt_metadata.id, cx);
220 async move { (contents.await, prompt_metadata) }
221 });
222 cx.background_spawn(future::join_all(load_tasks))
223 }),
224 };
225
226 cx.spawn(async move |this, cx| {
227 let (worktrees, default_user_rules) =
228 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
229
230 let worktrees = worktrees
231 .into_iter()
232 .map(|(worktree, rules_error)| {
233 if let Some(rules_error) = rules_error {
234 this.update(cx, |_, cx| cx.emit(rules_error)).ok();
235 }
236 worktree
237 })
238 .collect::<Vec<_>>();
239
240 let default_user_rules = default_user_rules
241 .into_iter()
242 .flat_map(|(contents, prompt_metadata)| match contents {
243 Ok(contents) => Some(UserRulesContext {
244 uuid: match prompt_metadata.id {
245 PromptId::User { uuid } => uuid,
246 PromptId::EditWorkflow => return None,
247 },
248 title: prompt_metadata.title.map(|title| title.to_string()),
249 contents,
250 }),
251 Err(err) => {
252 this.update(cx, |_, cx| {
253 cx.emit(RulesLoadingError {
254 message: format!("{err:?}").into(),
255 });
256 })
257 .ok();
258 None
259 }
260 })
261 .collect::<Vec<_>>();
262
263 this.update(cx, |this, _cx| {
264 *this.project_context.0.borrow_mut() =
265 Some(ProjectContext::new(worktrees, default_user_rules));
266 })
267 .ok();
268 })
269 }
270
271 fn load_worktree_info_for_system_prompt(
272 worktree: Entity<Worktree>,
273 project: Entity<Project>,
274 cx: &mut App,
275 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
276 let root_name = worktree.read(cx).root_name().into();
277
278 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
279 let Some(rules_task) = rules_task else {
280 return Task::ready((
281 WorktreeContext {
282 root_name,
283 rules_file: None,
284 },
285 None,
286 ));
287 };
288
289 cx.spawn(async move |_| {
290 let (rules_file, rules_file_error) = match rules_task.await {
291 Ok(rules_file) => (Some(rules_file), None),
292 Err(err) => (
293 None,
294 Some(RulesLoadingError {
295 message: format!("{err}").into(),
296 }),
297 ),
298 };
299 let worktree_info = WorktreeContext {
300 root_name,
301 rules_file,
302 };
303 (worktree_info, rules_file_error)
304 })
305 }
306
307 fn load_worktree_rules_file(
308 worktree: Entity<Worktree>,
309 project: Entity<Project>,
310 cx: &mut App,
311 ) -> Option<Task<Result<RulesFileContext>>> {
312 let worktree_ref = worktree.read(cx);
313 let worktree_id = worktree_ref.id();
314 let selected_rules_file = RULES_FILE_NAMES
315 .into_iter()
316 .filter_map(|name| {
317 worktree_ref
318 .entry_for_path(name)
319 .filter(|entry| entry.is_file())
320 .map(|entry| entry.path.clone())
321 })
322 .next();
323
324 // Note that Cline supports `.clinerules` being a directory, but that is not currently
325 // supported. This doesn't seem to occur often in GitHub repositories.
326 selected_rules_file.map(|path_in_worktree| {
327 let project_path = ProjectPath {
328 worktree_id,
329 path: path_in_worktree.clone(),
330 };
331 let buffer_task =
332 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
333 let rope_task = cx.spawn(async move |cx| {
334 buffer_task.await?.read_with(cx, |buffer, cx| {
335 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
336 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
337 })?
338 });
339 // Build a string from the rope on a background thread.
340 cx.background_spawn(async move {
341 let (project_entry_id, rope) = rope_task.await?;
342 anyhow::Ok(RulesFileContext {
343 path_in_worktree,
344 text: rope.to_string().trim().to_string(),
345 project_entry_id: project_entry_id.to_usize(),
346 })
347 })
348 })
349 }
350
351 pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
352 &self.prompt_store
353 }
354
355 pub fn tools(&self) -> Entity<ToolWorkingSet> {
356 self.tools.clone()
357 }
358
359 /// Returns the number of threads.
360 pub fn thread_count(&self) -> usize {
361 self.threads.len()
362 }
363
364 pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
365 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
366 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
367 threads
368 }
369
370 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
371 cx.new(|cx| {
372 Thread::new(
373 self.project.clone(),
374 self.tools.clone(),
375 self.prompt_builder.clone(),
376 self.project_context.clone(),
377 cx,
378 )
379 })
380 }
381
382 pub fn open_thread(
383 &self,
384 id: &ThreadId,
385 cx: &mut Context<Self>,
386 ) -> Task<Result<Entity<Thread>>> {
387 let id = id.clone();
388 let database_future = ThreadsDatabase::global_future(cx);
389 cx.spawn(async move |this, cx| {
390 let database = database_future.await.map_err(|err| anyhow!(err))?;
391 let thread = database
392 .try_find_thread(id.clone())
393 .await?
394 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
395
396 let thread = this.update(cx, |this, cx| {
397 cx.new(|cx| {
398 Thread::deserialize(
399 id.clone(),
400 thread,
401 this.project.clone(),
402 this.tools.clone(),
403 this.prompt_builder.clone(),
404 this.project_context.clone(),
405 cx,
406 )
407 })
408 })?;
409
410 Ok(thread)
411 })
412 }
413
414 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
415 let (metadata, serialized_thread) =
416 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
417
418 let database_future = ThreadsDatabase::global_future(cx);
419 cx.spawn(async move |this, cx| {
420 let serialized_thread = serialized_thread.await?;
421 let database = database_future.await.map_err(|err| anyhow!(err))?;
422 database.save_thread(metadata, serialized_thread).await?;
423
424 this.update(cx, |this, cx| this.reload(cx))?.await
425 })
426 }
427
428 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
429 let id = id.clone();
430 let database_future = ThreadsDatabase::global_future(cx);
431 cx.spawn(async move |this, cx| {
432 let database = database_future.await.map_err(|err| anyhow!(err))?;
433 database.delete_thread(id.clone()).await?;
434
435 this.update(cx, |this, cx| {
436 this.threads.retain(|thread| thread.id != id);
437 cx.notify();
438 })
439 })
440 }
441
442 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
443 let database_future = ThreadsDatabase::global_future(cx);
444 cx.spawn(async move |this, cx| {
445 let threads = database_future
446 .await
447 .map_err(|err| anyhow!(err))?
448 .list_threads()
449 .await?;
450
451 this.update(cx, |this, cx| {
452 this.threads = threads;
453 cx.notify();
454 })
455 })
456 }
457
458 fn load_default_profile(&self, cx: &mut Context<Self>) {
459 let assistant_settings = AssistantSettings::get_global(cx);
460
461 self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
462 }
463
464 pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
465 let assistant_settings = AssistantSettings::get_global(cx);
466
467 if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
468 self.load_profile(profile.clone(), cx);
469 }
470 }
471
472 pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
473 self.tools.update(cx, |tools, cx| {
474 tools.disable_all_tools(cx);
475 tools.enable(
476 ToolSource::Native,
477 &profile
478 .tools
479 .iter()
480 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
481 .collect::<Vec<_>>(),
482 cx,
483 );
484 });
485
486 if profile.enable_all_context_servers {
487 for context_server_id in self
488 .project
489 .read(cx)
490 .context_server_store()
491 .read(cx)
492 .all_server_ids()
493 {
494 self.tools.update(cx, |tools, cx| {
495 tools.enable_source(
496 ToolSource::ContextServer {
497 id: context_server_id.0.into(),
498 },
499 cx,
500 );
501 });
502 }
503 // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
504 for (context_server_id, preset) in &profile.context_servers {
505 self.tools.update(cx, |tools, cx| {
506 tools.disable(
507 ToolSource::ContextServer {
508 id: context_server_id.clone().into(),
509 },
510 &preset
511 .tools
512 .iter()
513 .filter_map(|(tool, enabled)| (!enabled).then(|| tool.clone()))
514 .collect::<Vec<_>>(),
515 cx,
516 )
517 })
518 }
519 } else {
520 for (context_server_id, preset) in &profile.context_servers {
521 self.tools.update(cx, |tools, cx| {
522 tools.enable(
523 ToolSource::ContextServer {
524 id: context_server_id.clone().into(),
525 },
526 &preset
527 .tools
528 .iter()
529 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
530 .collect::<Vec<_>>(),
531 cx,
532 )
533 })
534 }
535 }
536 }
537
538 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
539 cx.subscribe(
540 &self.project.read(cx).context_server_store(),
541 Self::handle_context_server_event,
542 )
543 .detach();
544 }
545
546 fn handle_context_server_event(
547 &mut self,
548 context_server_store: Entity<ContextServerStore>,
549 event: &project::context_server_store::Event,
550 cx: &mut Context<Self>,
551 ) {
552 let tool_working_set = self.tools.clone();
553 match event {
554 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
555 match status {
556 ContextServerStatus::Running => {
557 if let Some(server) =
558 context_server_store.read(cx).get_running_server(server_id)
559 {
560 let context_server_manager = context_server_store.clone();
561 cx.spawn({
562 let server = server.clone();
563 let server_id = server_id.clone();
564 async move |this, cx| {
565 let Some(protocol) = server.client() else {
566 return;
567 };
568
569 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
570 if let Some(tools) = protocol.list_tools().await.log_err() {
571 let tool_ids = tool_working_set
572 .update(cx, |tool_working_set, _| {
573 tools
574 .tools
575 .into_iter()
576 .map(|tool| {
577 log::info!(
578 "registering context server tool: {:?}",
579 tool.name
580 );
581 tool_working_set.insert(Arc::new(
582 ContextServerTool::new(
583 context_server_manager.clone(),
584 server.id(),
585 tool,
586 ),
587 ))
588 })
589 .collect::<Vec<_>>()
590 })
591 .log_err();
592
593 if let Some(tool_ids) = tool_ids {
594 this.update(cx, |this, cx| {
595 this.context_server_tool_ids
596 .insert(server_id, tool_ids);
597 this.load_default_profile(cx);
598 })
599 .log_err();
600 }
601 }
602 }
603 }
604 })
605 .detach();
606 }
607 }
608 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
609 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
610 tool_working_set.update(cx, |tool_working_set, _| {
611 tool_working_set.remove(&tool_ids);
612 });
613 self.load_default_profile(cx);
614 }
615 }
616 _ => {}
617 }
618 }
619 }
620 }
621}
622
/// Lightweight summary of a saved thread, as listed by `ThreadsDatabase::list_threads`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    // Used to order threads in `ThreadStore::reverse_chronological_threads`.
    pub updated_at: DateTime<Utc>,
}
629
/// On-disk representation of a thread at format version `SerializedThread::VERSION`.
///
/// Fields marked `#[serde(default)]` may be absent in stored data and fall back to their
/// defaults when deserialized.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    // Format version string; `from_json` dispatches on it.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
}
649
/// Identifies the language model a thread was using, stored as plain provider and model
/// strings.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
655
656impl SerializedThread {
657 pub const VERSION: &'static str = "0.2.0";
658
659 pub fn from_json(json: &[u8]) -> Result<Self> {
660 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
661 match saved_thread_json.get("version") {
662 Some(serde_json::Value::String(version)) => match version.as_str() {
663 SerializedThreadV0_1_0::VERSION => {
664 let saved_thread =
665 serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
666 Ok(saved_thread.upgrade())
667 }
668 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
669 saved_thread_json,
670 )?),
671 _ => Err(anyhow!(
672 "unrecognized serialized thread version: {}",
673 version
674 )),
675 },
676 None => {
677 let saved_thread =
678 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
679 Ok(saved_thread.upgrade())
680 }
681 version => Err(anyhow!(
682 "unrecognized serialized thread version: {:?}",
683 version
684 )),
685 }
686 }
687}
688
/// Thread format version 0.1.0, kept so older saved threads can be read and upgraded.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);
695
696impl SerializedThreadV0_1_0 {
697 pub const VERSION: &'static str = "0.1.0";
698
699 pub fn upgrade(self) -> SerializedThread {
700 debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
701
702 let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
703
704 for message in self.0.messages {
705 if message.role == Role::User && !message.tool_results.is_empty() {
706 if let Some(last_message) = messages.last_mut() {
707 debug_assert!(last_message.role == Role::Assistant);
708
709 last_message.tool_results = message.tool_results;
710 continue;
711 }
712 }
713
714 messages.push(message);
715 }
716
717 SerializedThread { messages, ..self.0 }
718 }
719}
720
/// A single stored thread message. Fields marked `#[serde(default)]` may be missing in older
/// data and default to empty.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
}
736
/// One piece of message content, internally tagged by a `"type"` field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        // Omitted from the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // No `rename`, so this variant's tag is "RedactedThinking" (not lowercase like the others).
    // Changing it would break reading threads that were already saved with this tag.
    RedactedThinking {
        data: Vec<u8>,
    },
}
754
/// A stored tool invocation: its id, the tool name, and the JSON input it was called with.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
761
/// A stored tool result, linked to its invocation by `tool_use_id`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: Arc<str>,
}
768
/// Thread format written before the `"version"` field existed; `SerializedThread::from_json`
/// uses it when no version is present.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
777
778impl LegacySerializedThread {
779 pub fn upgrade(self) -> SerializedThread {
780 SerializedThread {
781 version: SerializedThread::VERSION.to_string(),
782 summary: self.summary,
783 updated_at: self.updated_at,
784 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
785 initial_project_snapshot: self.initial_project_snapshot,
786 cumulative_token_usage: TokenUsage::default(),
787 request_token_usage: Vec::new(),
788 detailed_summary_state: DetailedSummaryState::default(),
789 exceeded_window_error: None,
790 model: None,
791 }
792 }
793}
794
/// Message format used by `LegacySerializedThread`: a single plain-text body instead of
/// segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
805
806impl LegacySerializedMessage {
807 fn upgrade(self) -> SerializedMessage {
808 SerializedMessage {
809 id: self.id,
810 role: self.role,
811 segments: vec![SerializedMessageSegment::Text { text: self.text }],
812 tool_uses: self.tool_uses,
813 tool_results: self.tool_results,
814 context: String::new(),
815 creases: Vec::new(),
816 }
817 }
818}
819
/// A stored crease: a `start..end` range with the icon and label shown for it.
// NOTE(review): presumably byte offsets into the message text — confirm against the writer.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedCrease {
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}
827
/// Global holding the shared future that opens the threads database. Every caller awaits the
/// same attempt. The error is wrapped in `Arc` so the shared output can be cloned.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
833
/// LMDB-backed (via heed) store of serialized threads keyed by [`ThreadId`].
pub(crate) struct ThreadsDatabase {
    // Runs the blocking database transactions off the calling thread.
    executor: BackgroundExecutor,
    env: heed::Env,
    // Keys are bincode-encoded ids; values are JSON (see the `BytesEncode`/`BytesDecode` impls).
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
839
840impl heed::BytesEncode<'_> for SerializedThread {
841 type EItem = SerializedThread;
842
843 fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
844 serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
845 }
846}
847
848impl<'a> heed::BytesDecode<'a> for SerializedThread {
849 type DItem = SerializedThread;
850
851 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
852 // We implement this type manually because we want to call `SerializedThread::from_json`,
853 // instead of the Deserialize trait implementation for `SerializedThread`.
854 SerializedThread::from_json(bytes).map_err(Into::into)
855 }
856}
857
858impl ThreadsDatabase {
859 fn global_future(
860 cx: &mut App,
861 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
862 GlobalThreadsDatabase::global(cx).0.clone()
863 }
864
865 fn init(cx: &mut App) {
866 let executor = cx.background_executor().clone();
867 let database_future = executor
868 .spawn({
869 let executor = executor.clone();
870 let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
871 async move { ThreadsDatabase::new(database_path, executor) }
872 })
873 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
874 .boxed()
875 .shared();
876
877 cx.set_global(GlobalThreadsDatabase(database_future));
878 }
879
880 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
881 std::fs::create_dir_all(&path)?;
882
883 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
884 let env = unsafe {
885 heed::EnvOpenOptions::new()
886 .map_size(ONE_GB_IN_BYTES)
887 .max_dbs(1)
888 .open(path)?
889 };
890
891 let mut txn = env.write_txn()?;
892 let threads = env.create_database(&mut txn, Some("threads"))?;
893 txn.commit()?;
894
895 Ok(Self {
896 executor,
897 env,
898 threads,
899 })
900 }
901
902 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
903 let env = self.env.clone();
904 let threads = self.threads;
905
906 self.executor.spawn(async move {
907 let txn = env.read_txn()?;
908 let mut iter = threads.iter(&txn)?;
909 let mut threads = Vec::new();
910 while let Some((key, value)) = iter.next().transpose()? {
911 threads.push(SerializedThreadMetadata {
912 id: key,
913 summary: value.summary,
914 updated_at: value.updated_at,
915 });
916 }
917
918 Ok(threads)
919 })
920 }
921
922 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
923 let env = self.env.clone();
924 let threads = self.threads;
925
926 self.executor.spawn(async move {
927 let txn = env.read_txn()?;
928 let thread = threads.get(&txn, &id)?;
929 Ok(thread)
930 })
931 }
932
933 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
934 let env = self.env.clone();
935 let threads = self.threads;
936
937 self.executor.spawn(async move {
938 let mut txn = env.write_txn()?;
939 threads.put(&mut txn, &id, &thread)?;
940 txn.commit()?;
941 Ok(())
942 })
943 }
944
945 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
946 let env = self.env.clone();
947 let threads = self.threads;
948
949 self.executor.spawn(async move {
950 let mut txn = env.write_txn()?;
951 threads.delete(&mut txn, &id)?;
952 txn.commit()?;
953 Ok(())
954 })
955 }
956}