1use std::borrow::Cow;
2use std::cell::{Ref, RefCell};
3use std::path::{Path, PathBuf};
4use std::rc::Rc;
5use std::sync::Arc;
6
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings, CompletionMode};
9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use context_server::ContextServerId;
13use futures::channel::{mpsc, oneshot};
14use futures::future::{self, BoxFuture, Shared};
15use futures::{FutureExt as _, StreamExt as _};
16use gpui::{
17 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
18 Subscription, Task, prelude::*,
19};
20use heed::Database;
21use heed::types::SerdeBincode;
22use language_model::{LanguageModelToolUseId, Role, TokenUsage};
23use project::context_server_store::{ContextServerStatus, ContextServerStore};
24use project::{Project, ProjectItem, ProjectPath, Worktree};
25use prompt_store::{
26 ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
27 UserRulesContext, WorktreeContext,
28};
29use serde::{Deserialize, Serialize};
30use settings::{Settings as _, SettingsStore};
31use ui::Window;
32use util::ResultExt as _;
33
34use crate::context_server_tool::ContextServerTool;
35use crate::thread::{
36 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
37};
38
/// Candidate rules-file paths, relative to a worktree root, in priority order.
///
/// For each worktree only the first entry that exists as a regular file is loaded into the
/// system prompt (see `ThreadStore::load_worktree_rules_file`). Changes to any of these
/// paths trigger a system-prompt reload (see `ThreadStore::handle_project_event`).
const RULES_FILE_NAMES: [&str; 6] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
];
47
48pub fn init(cx: &mut App) {
49 ThreadsDatabase::init(cx);
50}
51
52/// A system prompt shared by all threads created by this ThreadStore
53#[derive(Clone, Default)]
54pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
55
56impl SharedProjectContext {
57 pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
58 self.0.borrow()
59 }
60}
61
/// Alias for the text-thread store, which is `assistant_context_editor::ContextStore`.
pub type TextThreadStore = assistant_context_editor::ContextStore;
63
/// Tracks the saved threads of a project and maintains the shared project context that
/// new and reopened threads use to build their system prompt.
pub struct ThreadStore {
    project: Entity<Project>,
    /// Tool registry; which tools are enabled is driven by the active agent profile.
    tools: Entity<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    /// Source of default user rules, when available.
    prompt_store: Option<Entity<PromptStore>>,
    /// Tools registered for each running context server, so they can be removed when the
    /// server stops or errors.
    context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
    /// Cached metadata of saved threads, refreshed by `reload`.
    threads: Vec<SerializedThreadMetadata>,
    project_context: SharedProjectContext,
    /// Signals `_reload_system_prompt_task` to rebuild `project_context`.
    reload_system_prompt_tx: mpsc::Sender<()>,
    _reload_system_prompt_task: Task<()>,
    _subscriptions: Vec<Subscription>,
}
76
77pub struct RulesLoadingError {
78 pub message: SharedString,
79}
80
// `ThreadStore` emits `RulesLoadingError` when a worktree rules file or a default user rule
// fails to load during a system-prompt reload.
impl EventEmitter<RulesLoadingError> for ThreadStore {}
82
83impl ThreadStore {
84 pub fn load(
85 project: Entity<Project>,
86 tools: Entity<ToolWorkingSet>,
87 prompt_store: Option<Entity<PromptStore>>,
88 prompt_builder: Arc<PromptBuilder>,
89 cx: &mut App,
90 ) -> Task<Result<Entity<Self>>> {
91 cx.spawn(async move |cx| {
92 let (thread_store, ready_rx) = cx.update(|cx| {
93 let mut option_ready_rx = None;
94 let thread_store = cx.new(|cx| {
95 let (thread_store, ready_rx) =
96 Self::new(project, tools, prompt_builder, prompt_store, cx);
97 option_ready_rx = Some(ready_rx);
98 thread_store
99 });
100 (thread_store, option_ready_rx.take().unwrap())
101 })?;
102 ready_rx.await?;
103 Ok(thread_store)
104 })
105 }
106
107 fn new(
108 project: Entity<Project>,
109 tools: Entity<ToolWorkingSet>,
110 prompt_builder: Arc<PromptBuilder>,
111 prompt_store: Option<Entity<PromptStore>>,
112 cx: &mut Context<Self>,
113 ) -> (Self, oneshot::Receiver<()>) {
114 let mut subscriptions = vec![
115 cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
116 this.load_default_profile(cx);
117 }),
118 cx.subscribe(&project, Self::handle_project_event),
119 ];
120
121 if let Some(prompt_store) = prompt_store.as_ref() {
122 subscriptions.push(cx.subscribe(
123 prompt_store,
124 |this, _prompt_store, PromptsUpdatedEvent, _cx| {
125 this.enqueue_system_prompt_reload();
126 },
127 ))
128 }
129
130 // This channel and task prevent concurrent and redundant loading of the system prompt.
131 let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
132 let (ready_tx, ready_rx) = oneshot::channel();
133 let mut ready_tx = Some(ready_tx);
134 let reload_system_prompt_task = cx.spawn({
135 let prompt_store = prompt_store.clone();
136 async move |thread_store, cx| {
137 loop {
138 let Some(reload_task) = thread_store
139 .update(cx, |thread_store, cx| {
140 thread_store.reload_system_prompt(prompt_store.clone(), cx)
141 })
142 .ok()
143 else {
144 return;
145 };
146 reload_task.await;
147 if let Some(ready_tx) = ready_tx.take() {
148 ready_tx.send(()).ok();
149 }
150 reload_system_prompt_rx.next().await;
151 }
152 }
153 });
154
155 let this = Self {
156 project,
157 tools,
158 prompt_builder,
159 prompt_store,
160 context_server_tool_ids: HashMap::default(),
161 threads: Vec::new(),
162 project_context: SharedProjectContext::default(),
163 reload_system_prompt_tx,
164 _reload_system_prompt_task: reload_system_prompt_task,
165 _subscriptions: subscriptions,
166 };
167 this.load_default_profile(cx);
168 this.register_context_server_handlers(cx);
169 this.reload(cx).detach_and_log_err(cx);
170 (this, ready_rx)
171 }
172
173 fn handle_project_event(
174 &mut self,
175 _project: Entity<Project>,
176 event: &project::Event,
177 _cx: &mut Context<Self>,
178 ) {
179 match event {
180 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
181 self.enqueue_system_prompt_reload();
182 }
183 project::Event::WorktreeUpdatedEntries(_, items) => {
184 if items.iter().any(|(path, _, _)| {
185 RULES_FILE_NAMES
186 .iter()
187 .any(|name| path.as_ref() == Path::new(name))
188 }) {
189 self.enqueue_system_prompt_reload();
190 }
191 }
192 _ => {}
193 }
194 }
195
196 fn enqueue_system_prompt_reload(&mut self) {
197 self.reload_system_prompt_tx.try_send(()).ok();
198 }
199
200 // Note that this should only be called from `reload_system_prompt_task`.
201 fn reload_system_prompt(
202 &self,
203 prompt_store: Option<Entity<PromptStore>>,
204 cx: &mut Context<Self>,
205 ) -> Task<()> {
206 let worktrees = self
207 .project
208 .read(cx)
209 .visible_worktrees(cx)
210 .collect::<Vec<_>>();
211 let worktree_tasks = worktrees
212 .into_iter()
213 .map(|worktree| {
214 Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
215 })
216 .collect::<Vec<_>>();
217 let default_user_rules_task = match prompt_store {
218 None => Task::ready(vec![]),
219 Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
220 let prompts = prompt_store.default_prompt_metadata();
221 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
222 let contents = prompt_store.load(prompt_metadata.id, cx);
223 async move { (contents.await, prompt_metadata) }
224 });
225 cx.background_spawn(future::join_all(load_tasks))
226 }),
227 };
228
229 cx.spawn(async move |this, cx| {
230 let (worktrees, default_user_rules) =
231 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
232
233 let worktrees = worktrees
234 .into_iter()
235 .map(|(worktree, rules_error)| {
236 if let Some(rules_error) = rules_error {
237 this.update(cx, |_, cx| cx.emit(rules_error)).ok();
238 }
239 worktree
240 })
241 .collect::<Vec<_>>();
242
243 let default_user_rules = default_user_rules
244 .into_iter()
245 .flat_map(|(contents, prompt_metadata)| match contents {
246 Ok(contents) => Some(UserRulesContext {
247 uuid: match prompt_metadata.id {
248 PromptId::User { uuid } => uuid,
249 PromptId::EditWorkflow => return None,
250 },
251 title: prompt_metadata.title.map(|title| title.to_string()),
252 contents,
253 }),
254 Err(err) => {
255 this.update(cx, |_, cx| {
256 cx.emit(RulesLoadingError {
257 message: format!("{err:?}").into(),
258 });
259 })
260 .ok();
261 None
262 }
263 })
264 .collect::<Vec<_>>();
265
266 this.update(cx, |this, _cx| {
267 *this.project_context.0.borrow_mut() =
268 Some(ProjectContext::new(worktrees, default_user_rules));
269 })
270 .ok();
271 })
272 }
273
274 fn load_worktree_info_for_system_prompt(
275 worktree: Entity<Worktree>,
276 project: Entity<Project>,
277 cx: &mut App,
278 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
279 let root_name = worktree.read(cx).root_name().into();
280
281 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
282 let Some(rules_task) = rules_task else {
283 return Task::ready((
284 WorktreeContext {
285 root_name,
286 rules_file: None,
287 },
288 None,
289 ));
290 };
291
292 cx.spawn(async move |_| {
293 let (rules_file, rules_file_error) = match rules_task.await {
294 Ok(rules_file) => (Some(rules_file), None),
295 Err(err) => (
296 None,
297 Some(RulesLoadingError {
298 message: format!("{err}").into(),
299 }),
300 ),
301 };
302 let worktree_info = WorktreeContext {
303 root_name,
304 rules_file,
305 };
306 (worktree_info, rules_file_error)
307 })
308 }
309
310 fn load_worktree_rules_file(
311 worktree: Entity<Worktree>,
312 project: Entity<Project>,
313 cx: &mut App,
314 ) -> Option<Task<Result<RulesFileContext>>> {
315 let worktree_ref = worktree.read(cx);
316 let worktree_id = worktree_ref.id();
317 let selected_rules_file = RULES_FILE_NAMES
318 .into_iter()
319 .filter_map(|name| {
320 worktree_ref
321 .entry_for_path(name)
322 .filter(|entry| entry.is_file())
323 .map(|entry| entry.path.clone())
324 })
325 .next();
326
327 // Note that Cline supports `.clinerules` being a directory, but that is not currently
328 // supported. This doesn't seem to occur often in GitHub repositories.
329 selected_rules_file.map(|path_in_worktree| {
330 let project_path = ProjectPath {
331 worktree_id,
332 path: path_in_worktree.clone(),
333 };
334 let buffer_task =
335 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
336 let rope_task = cx.spawn(async move |cx| {
337 buffer_task.await?.read_with(cx, |buffer, cx| {
338 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
339 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
340 })?
341 });
342 // Build a string from the rope on a background thread.
343 cx.background_spawn(async move {
344 let (project_entry_id, rope) = rope_task.await?;
345 anyhow::Ok(RulesFileContext {
346 path_in_worktree,
347 text: rope.to_string().trim().to_string(),
348 project_entry_id: project_entry_id.to_usize(),
349 })
350 })
351 })
352 }
353
354 pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
355 &self.prompt_store
356 }
357
358 pub fn tools(&self) -> Entity<ToolWorkingSet> {
359 self.tools.clone()
360 }
361
362 /// Returns the number of threads.
363 pub fn thread_count(&self) -> usize {
364 self.threads.len()
365 }
366
367 pub fn unordered_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
368 self.threads.iter()
369 }
370
371 pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
372 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
373 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
374 threads
375 }
376
377 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
378 cx.new(|cx| {
379 Thread::new(
380 self.project.clone(),
381 self.tools.clone(),
382 self.prompt_builder.clone(),
383 self.project_context.clone(),
384 cx,
385 )
386 })
387 }
388
389 pub fn open_thread(
390 &self,
391 id: &ThreadId,
392 window: &mut Window,
393 cx: &mut Context<Self>,
394 ) -> Task<Result<Entity<Thread>>> {
395 let id = id.clone();
396 let database_future = ThreadsDatabase::global_future(cx);
397 let this = cx.weak_entity();
398 window.spawn(cx, async move |cx| {
399 let database = database_future.await.map_err(|err| anyhow!(err))?;
400 let thread = database
401 .try_find_thread(id.clone())
402 .await?
403 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
404
405 let thread = this.update_in(cx, |this, window, cx| {
406 cx.new(|cx| {
407 Thread::deserialize(
408 id.clone(),
409 thread,
410 this.project.clone(),
411 this.tools.clone(),
412 this.prompt_builder.clone(),
413 this.project_context.clone(),
414 window,
415 cx,
416 )
417 })
418 })?;
419
420 Ok(thread)
421 })
422 }
423
424 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
425 let (metadata, serialized_thread) =
426 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
427
428 let database_future = ThreadsDatabase::global_future(cx);
429 cx.spawn(async move |this, cx| {
430 let serialized_thread = serialized_thread.await?;
431 let database = database_future.await.map_err(|err| anyhow!(err))?;
432 database.save_thread(metadata, serialized_thread).await?;
433
434 this.update(cx, |this, cx| this.reload(cx))?.await
435 })
436 }
437
438 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
439 let id = id.clone();
440 let database_future = ThreadsDatabase::global_future(cx);
441 cx.spawn(async move |this, cx| {
442 let database = database_future.await.map_err(|err| anyhow!(err))?;
443 database.delete_thread(id.clone()).await?;
444
445 this.update(cx, |this, cx| {
446 this.threads.retain(|thread| thread.id != id);
447 cx.notify();
448 })
449 })
450 }
451
452 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
453 let database_future = ThreadsDatabase::global_future(cx);
454 cx.spawn(async move |this, cx| {
455 let threads = database_future
456 .await
457 .map_err(|err| anyhow!(err))?
458 .list_threads()
459 .await?;
460
461 this.update(cx, |this, cx| {
462 this.threads = threads;
463 cx.notify();
464 })
465 })
466 }
467
468 fn load_default_profile(&self, cx: &mut Context<Self>) {
469 let assistant_settings = AssistantSettings::get_global(cx);
470
471 self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
472 }
473
474 pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
475 let assistant_settings = AssistantSettings::get_global(cx);
476
477 if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
478 self.load_profile(profile.clone(), cx);
479 }
480 }
481
482 pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
483 self.tools.update(cx, |tools, cx| {
484 tools.disable_all_tools(cx);
485 tools.enable(
486 ToolSource::Native,
487 &profile
488 .tools
489 .into_iter()
490 .filter_map(|(tool, enabled)| enabled.then(|| tool))
491 .collect::<Vec<_>>(),
492 cx,
493 );
494 });
495
496 if profile.enable_all_context_servers {
497 for context_server_id in self
498 .project
499 .read(cx)
500 .context_server_store()
501 .read(cx)
502 .all_server_ids()
503 {
504 self.tools.update(cx, |tools, cx| {
505 tools.enable_source(
506 ToolSource::ContextServer {
507 id: context_server_id.0.into(),
508 },
509 cx,
510 );
511 });
512 }
513 // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
514 for (context_server_id, preset) in profile.context_servers {
515 self.tools.update(cx, |tools, cx| {
516 tools.disable(
517 ToolSource::ContextServer {
518 id: context_server_id.into(),
519 },
520 &preset
521 .tools
522 .into_iter()
523 .filter_map(|(tool, enabled)| (!enabled).then(|| tool))
524 .collect::<Vec<_>>(),
525 cx,
526 )
527 })
528 }
529 } else {
530 for (context_server_id, preset) in profile.context_servers {
531 self.tools.update(cx, |tools, cx| {
532 tools.enable(
533 ToolSource::ContextServer {
534 id: context_server_id.into(),
535 },
536 &preset
537 .tools
538 .into_iter()
539 .filter_map(|(tool, enabled)| enabled.then(|| tool))
540 .collect::<Vec<_>>(),
541 cx,
542 )
543 })
544 }
545 }
546 }
547
548 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
549 cx.subscribe(
550 &self.project.read(cx).context_server_store(),
551 Self::handle_context_server_event,
552 )
553 .detach();
554 }
555
556 fn handle_context_server_event(
557 &mut self,
558 context_server_store: Entity<ContextServerStore>,
559 event: &project::context_server_store::Event,
560 cx: &mut Context<Self>,
561 ) {
562 let tool_working_set = self.tools.clone();
563 match event {
564 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
565 match status {
566 ContextServerStatus::Running => {
567 if let Some(server) =
568 context_server_store.read(cx).get_running_server(server_id)
569 {
570 let context_server_manager = context_server_store.clone();
571 cx.spawn({
572 let server = server.clone();
573 let server_id = server_id.clone();
574 async move |this, cx| {
575 let Some(protocol) = server.client() else {
576 return;
577 };
578
579 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
580 if let Some(tools) = protocol.list_tools().await.log_err() {
581 let tool_ids = tool_working_set
582 .update(cx, |tool_working_set, _| {
583 tools
584 .tools
585 .into_iter()
586 .map(|tool| {
587 log::info!(
588 "registering context server tool: {:?}",
589 tool.name
590 );
591 tool_working_set.insert(Arc::new(
592 ContextServerTool::new(
593 context_server_manager.clone(),
594 server.id(),
595 tool,
596 ),
597 ))
598 })
599 .collect::<Vec<_>>()
600 })
601 .log_err();
602
603 if let Some(tool_ids) = tool_ids {
604 this.update(cx, |this, cx| {
605 this.context_server_tool_ids
606 .insert(server_id, tool_ids);
607 this.load_default_profile(cx);
608 })
609 .log_err();
610 }
611 }
612 }
613 }
614 })
615 .detach();
616 }
617 }
618 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
619 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
620 tool_working_set.update(cx, |tool_working_set, _| {
621 tool_working_set.remove(&tool_ids);
622 });
623 self.load_default_profile(cx);
624 }
625 }
626 _ => {}
627 }
628 }
629 }
630 }
631}
632
/// Per-thread metadata kept in `ThreadStore`'s cached thread list, built from stored threads
/// by `ThreadsDatabase::list_threads`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    /// Used to order threads in `ThreadStore::reverse_chronological_threads`.
    pub updated_at: DateTime<Utc>,
}
639
/// A fully serialized thread in the current on-disk format (`SerializedThread::VERSION`).
///
/// Stored as JSON in the threads database. Fields marked `#[serde(default)]` fall back to
/// their default value when they are absent from the stored JSON.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    /// Format version string, used by `SerializedThread::from_json` to pick a parser.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    #[serde(default)]
    pub request_token_usage: Vec<TokenUsage>,
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
    #[serde(default)]
    pub exceeded_window_error: Option<ExceededWindowError>,
    #[serde(default)]
    pub model: Option<SerializedLanguageModel>,
    #[serde(default)]
    pub completion_mode: Option<CompletionMode>,
}
661
/// The language model recorded for a thread.
// presumably provider and model identifiers as used by the language model registry — confirm
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedLanguageModel {
    pub provider: String,
    pub model: String,
}
667
668impl SerializedThread {
669 pub const VERSION: &'static str = "0.2.0";
670
671 pub fn from_json(json: &[u8]) -> Result<Self> {
672 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
673 match saved_thread_json.get("version") {
674 Some(serde_json::Value::String(version)) => match version.as_str() {
675 SerializedThreadV0_1_0::VERSION => {
676 let saved_thread =
677 serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
678 Ok(saved_thread.upgrade())
679 }
680 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
681 saved_thread_json,
682 )?),
683 _ => Err(anyhow!(
684 "unrecognized serialized thread version: {}",
685 version
686 )),
687 },
688 None => {
689 let saved_thread =
690 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
691 Ok(saved_thread.upgrade())
692 }
693 version => Err(anyhow!(
694 "unrecognized serialized thread version: {:?}",
695 version
696 )),
697 }
698 }
699}
700
/// A thread stored in the 0.1.0 format, which only differs from the current format in
/// where tool results are attached (see `SerializedThreadV0_1_0::upgrade`).
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
    // The structure did not change, so we are reusing the latest SerializedThread.
    // When making the next version, make sure this points to SerializedThreadV0_2_0
    SerializedThread,
);
707
708impl SerializedThreadV0_1_0 {
709 pub const VERSION: &'static str = "0.1.0";
710
711 pub fn upgrade(self) -> SerializedThread {
712 debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
713
714 let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
715
716 for message in self.0.messages {
717 if message.role == Role::User && !message.tool_results.is_empty() {
718 if let Some(last_message) = messages.last_mut() {
719 debug_assert!(last_message.role == Role::Assistant);
720
721 last_message.tool_results = message.tool_results;
722 continue;
723 }
724 }
725
726 messages.push(message);
727 }
728
729 SerializedThread { messages, ..self.0 }
730 }
731}
732
/// A single message of a serialized thread.
///
/// Fields marked `#[serde(default)]` are empty when absent from the stored JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    #[serde(default)]
    pub context: String,
    #[serde(default)]
    pub creases: Vec<SerializedCrease>,
}
748
/// One piece of message content, internally tagged by a `"type"` field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text {
        text: String,
    },
    #[serde(rename = "thinking")]
    Thinking {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    // NOTE(review): unlike the other variants this one has no `rename`, so its tag is
    // "RedactedThinking". Renaming it now would break already-stored threads unless an
    // alias for the old tag is kept.
    RedactedThinking {
        data: Vec<u8>,
    },
}
766
/// A tool call issued in a message: the tool's name and its JSON input.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
773
/// The result of a tool call, linked to its `SerializedToolUse` by `tool_use_id`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    pub is_error: bool,
    pub content: Arc<str>,
    pub output: Option<serde_json::Value>,
}
781
/// A thread stored before the `version` field existed; see `SerializedThread::from_json`.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
790
791impl LegacySerializedThread {
792 pub fn upgrade(self) -> SerializedThread {
793 SerializedThread {
794 version: SerializedThread::VERSION.to_string(),
795 summary: self.summary,
796 updated_at: self.updated_at,
797 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
798 initial_project_snapshot: self.initial_project_snapshot,
799 cumulative_token_usage: TokenUsage::default(),
800 request_token_usage: Vec::new(),
801 detailed_summary_state: DetailedSummaryState::default(),
802 exceeded_window_error: None,
803 model: None,
804 completion_mode: None,
805 }
806 }
807}
808
/// A message in the pre-versioning format, whose content is a single plain-text string.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
819
820impl LegacySerializedMessage {
821 fn upgrade(self) -> SerializedMessage {
822 SerializedMessage {
823 id: self.id,
824 role: self.role,
825 segments: vec![SerializedMessageSegment::Text { text: self.text }],
826 tool_uses: self.tool_uses,
827 tool_results: self.tool_results,
828 context: String::new(),
829 creases: Vec::new(),
830 }
831 }
832}
833
/// A crease (an icon + label placeholder) recorded for a message.
// NOTE(review): `start`/`end` presumably delimit the crease's range in the message text —
// confirm the unit (bytes vs. chars) against the code that produces them.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedCrease {
    pub start: usize,
    pub end: usize,
    pub icon_path: SharedString,
    pub label: SharedString,
}
841
/// gpui global holding the shared future that opens the threads database.
///
/// The future is `Shared`, so every awaiter receives the same `Arc`'d database or error.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
847
/// LMDB-backed (via heed) store mapping thread IDs to serialized threads.
///
/// All reads and writes run as tasks on `executor`.
pub(crate) struct ThreadsDatabase {
    executor: BackgroundExecutor,
    env: heed::Env,
    /// Keys are bincode-encoded `ThreadId`s; values are JSON-encoded `SerializedThread`s.
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
853
854impl heed::BytesEncode<'_> for SerializedThread {
855 type EItem = SerializedThread;
856
857 fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
858 serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
859 }
860}
861
862impl<'a> heed::BytesDecode<'a> for SerializedThread {
863 type DItem = SerializedThread;
864
865 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
866 // We implement this type manually because we want to call `SerializedThread::from_json`,
867 // instead of the Deserialize trait implementation for `SerializedThread`.
868 SerializedThread::from_json(bytes).map_err(Into::into)
869 }
870}
871
872impl ThreadsDatabase {
873 fn global_future(
874 cx: &mut App,
875 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
876 GlobalThreadsDatabase::global(cx).0.clone()
877 }
878
879 fn init(cx: &mut App) {
880 let executor = cx.background_executor().clone();
881 let database_future = executor
882 .spawn({
883 let executor = executor.clone();
884 let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
885 async move { ThreadsDatabase::new(database_path, executor) }
886 })
887 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
888 .boxed()
889 .shared();
890
891 cx.set_global(GlobalThreadsDatabase(database_future));
892 }
893
894 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
895 std::fs::create_dir_all(&path)?;
896
897 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
898 let env = unsafe {
899 heed::EnvOpenOptions::new()
900 .map_size(ONE_GB_IN_BYTES)
901 .max_dbs(1)
902 .open(path)?
903 };
904
905 let mut txn = env.write_txn()?;
906 let threads = env.create_database(&mut txn, Some("threads"))?;
907 txn.commit()?;
908
909 Ok(Self {
910 executor,
911 env,
912 threads,
913 })
914 }
915
916 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
917 let env = self.env.clone();
918 let threads = self.threads;
919
920 self.executor.spawn(async move {
921 let txn = env.read_txn()?;
922 let mut iter = threads.iter(&txn)?;
923 let mut threads = Vec::new();
924 while let Some((key, value)) = iter.next().transpose()? {
925 threads.push(SerializedThreadMetadata {
926 id: key,
927 summary: value.summary,
928 updated_at: value.updated_at,
929 });
930 }
931
932 Ok(threads)
933 })
934 }
935
936 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
937 let env = self.env.clone();
938 let threads = self.threads;
939
940 self.executor.spawn(async move {
941 let txn = env.read_txn()?;
942 let thread = threads.get(&txn, &id)?;
943 Ok(thread)
944 })
945 }
946
947 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
948 let env = self.env.clone();
949 let threads = self.threads;
950
951 self.executor.spawn(async move {
952 let mut txn = env.write_txn()?;
953 threads.put(&mut txn, &id, &thread)?;
954 txn.commit()?;
955 Ok(())
956 })
957 }
958
959 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
960 let env = self.env.clone();
961 let threads = self.threads;
962
963 self.executor.spawn(async move {
964 let mut txn = env.write_txn()?;
965 threads.delete(&mut txn, &id)?;
966 txn.commit()?;
967 Ok(())
968 })
969 }
970}