1use std::borrow::Cow;
2use std::cell::{Ref, RefCell};
3use std::path::{Path, PathBuf};
4use std::rc::Rc;
5use std::sync::Arc;
6
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings, CompletionMode};
9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use context_server::ContextServerId;
13use futures::channel::{mpsc, oneshot};
14use futures::future::{self, BoxFuture, Shared};
15use futures::{FutureExt as _, StreamExt as _};
16use gpui::{
17 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
18 Subscription, Task, prelude::*,
19};
20use heed::Database;
21use heed::types::SerdeBincode;
22use language_model::{LanguageModelToolUseId, Role, TokenUsage};
23use project::context_server_store::{ContextServerStatus, ContextServerStore};
24use project::{Project, ProjectItem, ProjectPath, Worktree};
25use prompt_store::{
26 ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
27 UserRulesContext, WorktreeContext,
28};
29use serde::{Deserialize, Serialize};
30use settings::{Settings as _, SettingsStore};
31use util::ResultExt as _;
32
33use crate::context_server_tool::ContextServerTool;
34use crate::thread::{
35 DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
36};
37
/// File names, relative to a worktree root, that are treated as rules files.
///
/// The order is significant: when a worktree's rules file is looked up, the
/// names are probed in this order and the first one that is a file is used.
const RULES_FILE_NAMES: [&str; 6] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
];
46
47pub fn init(cx: &mut App) {
48 ThreadsDatabase::init(cx);
49}
50
51/// A system prompt shared by all threads created by this ThreadStore
52#[derive(Clone, Default)]
53pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
54
55impl SharedProjectContext {
56 pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
57 self.0.borrow()
58 }
59}
60
61pub type TextThreadStore = assistant_context_editor::ContextStore;
62
63pub struct ThreadStore {
64 project: Entity<Project>,
65 tools: Entity<ToolWorkingSet>,
66 prompt_builder: Arc<PromptBuilder>,
67 prompt_store: Option<Entity<PromptStore>>,
68 context_server_tool_ids: HashMap<ContextServerId, Vec<ToolId>>,
69 threads: Vec<SerializedThreadMetadata>,
70 project_context: SharedProjectContext,
71 reload_system_prompt_tx: mpsc::Sender<()>,
72 _reload_system_prompt_task: Task<()>,
73 _subscriptions: Vec<Subscription>,
74}
75
76pub struct RulesLoadingError {
77 pub message: SharedString,
78}
79
80impl EventEmitter<RulesLoadingError> for ThreadStore {}
81
82impl ThreadStore {
83 pub fn load(
84 project: Entity<Project>,
85 tools: Entity<ToolWorkingSet>,
86 prompt_store: Option<Entity<PromptStore>>,
87 prompt_builder: Arc<PromptBuilder>,
88 cx: &mut App,
89 ) -> Task<Result<Entity<Self>>> {
90 cx.spawn(async move |cx| {
91 let (thread_store, ready_rx) = cx.update(|cx| {
92 let mut option_ready_rx = None;
93 let thread_store = cx.new(|cx| {
94 let (thread_store, ready_rx) =
95 Self::new(project, tools, prompt_builder, prompt_store, cx);
96 option_ready_rx = Some(ready_rx);
97 thread_store
98 });
99 (thread_store, option_ready_rx.take().unwrap())
100 })?;
101 ready_rx.await?;
102 Ok(thread_store)
103 })
104 }
105
106 fn new(
107 project: Entity<Project>,
108 tools: Entity<ToolWorkingSet>,
109 prompt_builder: Arc<PromptBuilder>,
110 prompt_store: Option<Entity<PromptStore>>,
111 cx: &mut Context<Self>,
112 ) -> (Self, oneshot::Receiver<()>) {
113 let mut subscriptions = vec![
114 cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
115 this.load_default_profile(cx);
116 }),
117 cx.subscribe(&project, Self::handle_project_event),
118 ];
119
120 if let Some(prompt_store) = prompt_store.as_ref() {
121 subscriptions.push(cx.subscribe(
122 prompt_store,
123 |this, _prompt_store, PromptsUpdatedEvent, _cx| {
124 this.enqueue_system_prompt_reload();
125 },
126 ))
127 }
128
129 // This channel and task prevent concurrent and redundant loading of the system prompt.
130 let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1);
131 let (ready_tx, ready_rx) = oneshot::channel();
132 let mut ready_tx = Some(ready_tx);
133 let reload_system_prompt_task = cx.spawn({
134 let prompt_store = prompt_store.clone();
135 async move |thread_store, cx| {
136 loop {
137 let Some(reload_task) = thread_store
138 .update(cx, |thread_store, cx| {
139 thread_store.reload_system_prompt(prompt_store.clone(), cx)
140 })
141 .ok()
142 else {
143 return;
144 };
145 reload_task.await;
146 if let Some(ready_tx) = ready_tx.take() {
147 ready_tx.send(()).ok();
148 }
149 reload_system_prompt_rx.next().await;
150 }
151 }
152 });
153
154 let this = Self {
155 project,
156 tools,
157 prompt_builder,
158 prompt_store,
159 context_server_tool_ids: HashMap::default(),
160 threads: Vec::new(),
161 project_context: SharedProjectContext::default(),
162 reload_system_prompt_tx,
163 _reload_system_prompt_task: reload_system_prompt_task,
164 _subscriptions: subscriptions,
165 };
166 this.load_default_profile(cx);
167 this.register_context_server_handlers(cx);
168 this.reload(cx).detach_and_log_err(cx);
169 (this, ready_rx)
170 }
171
172 fn handle_project_event(
173 &mut self,
174 _project: Entity<Project>,
175 event: &project::Event,
176 _cx: &mut Context<Self>,
177 ) {
178 match event {
179 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
180 self.enqueue_system_prompt_reload();
181 }
182 project::Event::WorktreeUpdatedEntries(_, items) => {
183 if items.iter().any(|(path, _, _)| {
184 RULES_FILE_NAMES
185 .iter()
186 .any(|name| path.as_ref() == Path::new(name))
187 }) {
188 self.enqueue_system_prompt_reload();
189 }
190 }
191 _ => {}
192 }
193 }
194
195 fn enqueue_system_prompt_reload(&mut self) {
196 self.reload_system_prompt_tx.try_send(()).ok();
197 }
198
199 // Note that this should only be called from `reload_system_prompt_task`.
200 fn reload_system_prompt(
201 &self,
202 prompt_store: Option<Entity<PromptStore>>,
203 cx: &mut Context<Self>,
204 ) -> Task<()> {
205 let worktrees = self
206 .project
207 .read(cx)
208 .visible_worktrees(cx)
209 .collect::<Vec<_>>();
210 let worktree_tasks = worktrees
211 .into_iter()
212 .map(|worktree| {
213 Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
214 })
215 .collect::<Vec<_>>();
216 let default_user_rules_task = match prompt_store {
217 None => Task::ready(vec![]),
218 Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| {
219 let prompts = prompt_store.default_prompt_metadata();
220 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
221 let contents = prompt_store.load(prompt_metadata.id, cx);
222 async move { (contents.await, prompt_metadata) }
223 });
224 cx.background_spawn(future::join_all(load_tasks))
225 }),
226 };
227
228 cx.spawn(async move |this, cx| {
229 let (worktrees, default_user_rules) =
230 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
231
232 let worktrees = worktrees
233 .into_iter()
234 .map(|(worktree, rules_error)| {
235 if let Some(rules_error) = rules_error {
236 this.update(cx, |_, cx| cx.emit(rules_error)).ok();
237 }
238 worktree
239 })
240 .collect::<Vec<_>>();
241
242 let default_user_rules = default_user_rules
243 .into_iter()
244 .flat_map(|(contents, prompt_metadata)| match contents {
245 Ok(contents) => Some(UserRulesContext {
246 uuid: match prompt_metadata.id {
247 PromptId::User { uuid } => uuid,
248 PromptId::EditWorkflow => return None,
249 },
250 title: prompt_metadata.title.map(|title| title.to_string()),
251 contents,
252 }),
253 Err(err) => {
254 this.update(cx, |_, cx| {
255 cx.emit(RulesLoadingError {
256 message: format!("{err:?}").into(),
257 });
258 })
259 .ok();
260 None
261 }
262 })
263 .collect::<Vec<_>>();
264
265 this.update(cx, |this, _cx| {
266 *this.project_context.0.borrow_mut() =
267 Some(ProjectContext::new(worktrees, default_user_rules));
268 })
269 .ok();
270 })
271 }
272
273 fn load_worktree_info_for_system_prompt(
274 worktree: Entity<Worktree>,
275 project: Entity<Project>,
276 cx: &mut App,
277 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
278 let root_name = worktree.read(cx).root_name().into();
279
280 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
281 let Some(rules_task) = rules_task else {
282 return Task::ready((
283 WorktreeContext {
284 root_name,
285 rules_file: None,
286 },
287 None,
288 ));
289 };
290
291 cx.spawn(async move |_| {
292 let (rules_file, rules_file_error) = match rules_task.await {
293 Ok(rules_file) => (Some(rules_file), None),
294 Err(err) => (
295 None,
296 Some(RulesLoadingError {
297 message: format!("{err}").into(),
298 }),
299 ),
300 };
301 let worktree_info = WorktreeContext {
302 root_name,
303 rules_file,
304 };
305 (worktree_info, rules_file_error)
306 })
307 }
308
309 fn load_worktree_rules_file(
310 worktree: Entity<Worktree>,
311 project: Entity<Project>,
312 cx: &mut App,
313 ) -> Option<Task<Result<RulesFileContext>>> {
314 let worktree_ref = worktree.read(cx);
315 let worktree_id = worktree_ref.id();
316 let selected_rules_file = RULES_FILE_NAMES
317 .into_iter()
318 .filter_map(|name| {
319 worktree_ref
320 .entry_for_path(name)
321 .filter(|entry| entry.is_file())
322 .map(|entry| entry.path.clone())
323 })
324 .next();
325
326 // Note that Cline supports `.clinerules` being a directory, but that is not currently
327 // supported. This doesn't seem to occur often in GitHub repositories.
328 selected_rules_file.map(|path_in_worktree| {
329 let project_path = ProjectPath {
330 worktree_id,
331 path: path_in_worktree.clone(),
332 };
333 let buffer_task =
334 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
335 let rope_task = cx.spawn(async move |cx| {
336 buffer_task.await?.read_with(cx, |buffer, cx| {
337 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
338 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
339 })?
340 });
341 // Build a string from the rope on a background thread.
342 cx.background_spawn(async move {
343 let (project_entry_id, rope) = rope_task.await?;
344 anyhow::Ok(RulesFileContext {
345 path_in_worktree,
346 text: rope.to_string().trim().to_string(),
347 project_entry_id: project_entry_id.to_usize(),
348 })
349 })
350 })
351 }
352
353 pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
354 &self.prompt_store
355 }
356
357 pub fn tools(&self) -> Entity<ToolWorkingSet> {
358 self.tools.clone()
359 }
360
361 /// Returns the number of threads.
362 pub fn thread_count(&self) -> usize {
363 self.threads.len()
364 }
365
366 pub fn unordered_threads(&self) -> impl Iterator<Item = &SerializedThreadMetadata> {
367 self.threads.iter()
368 }
369
370 pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
371 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
372 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
373 threads
374 }
375
376 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
377 cx.new(|cx| {
378 Thread::new(
379 self.project.clone(),
380 self.tools.clone(),
381 self.prompt_builder.clone(),
382 self.project_context.clone(),
383 cx,
384 )
385 })
386 }
387
388 pub fn open_thread(
389 &self,
390 id: &ThreadId,
391 cx: &mut Context<Self>,
392 ) -> Task<Result<Entity<Thread>>> {
393 let id = id.clone();
394 let database_future = ThreadsDatabase::global_future(cx);
395 cx.spawn(async move |this, cx| {
396 let database = database_future.await.map_err(|err| anyhow!(err))?;
397 let thread = database
398 .try_find_thread(id.clone())
399 .await?
400 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
401
402 let thread = this.update(cx, |this, cx| {
403 cx.new(|cx| {
404 Thread::deserialize(
405 id.clone(),
406 thread,
407 this.project.clone(),
408 this.tools.clone(),
409 this.prompt_builder.clone(),
410 this.project_context.clone(),
411 cx,
412 )
413 })
414 })?;
415
416 Ok(thread)
417 })
418 }
419
420 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
421 let (metadata, serialized_thread) =
422 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
423
424 let database_future = ThreadsDatabase::global_future(cx);
425 cx.spawn(async move |this, cx| {
426 let serialized_thread = serialized_thread.await?;
427 let database = database_future.await.map_err(|err| anyhow!(err))?;
428 database.save_thread(metadata, serialized_thread).await?;
429
430 this.update(cx, |this, cx| this.reload(cx))?.await
431 })
432 }
433
434 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
435 let id = id.clone();
436 let database_future = ThreadsDatabase::global_future(cx);
437 cx.spawn(async move |this, cx| {
438 let database = database_future.await.map_err(|err| anyhow!(err))?;
439 database.delete_thread(id.clone()).await?;
440
441 this.update(cx, |this, cx| {
442 this.threads.retain(|thread| thread.id != id);
443 cx.notify();
444 })
445 })
446 }
447
448 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
449 let database_future = ThreadsDatabase::global_future(cx);
450 cx.spawn(async move |this, cx| {
451 let threads = database_future
452 .await
453 .map_err(|err| anyhow!(err))?
454 .list_threads()
455 .await?;
456
457 this.update(cx, |this, cx| {
458 this.threads = threads;
459 cx.notify();
460 })
461 })
462 }
463
464 fn load_default_profile(&self, cx: &mut Context<Self>) {
465 let assistant_settings = AssistantSettings::get_global(cx);
466
467 self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
468 }
469
470 pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
471 let assistant_settings = AssistantSettings::get_global(cx);
472
473 if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
474 self.load_profile(profile.clone(), cx);
475 }
476 }
477
478 pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
479 self.tools.update(cx, |tools, cx| {
480 tools.disable_all_tools(cx);
481 tools.enable(
482 ToolSource::Native,
483 &profile
484 .tools
485 .iter()
486 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
487 .collect::<Vec<_>>(),
488 cx,
489 );
490 });
491
492 if profile.enable_all_context_servers {
493 for context_server_id in self
494 .project
495 .read(cx)
496 .context_server_store()
497 .read(cx)
498 .all_server_ids()
499 {
500 self.tools.update(cx, |tools, cx| {
501 tools.enable_source(
502 ToolSource::ContextServer {
503 id: context_server_id.0.into(),
504 },
505 cx,
506 );
507 });
508 }
509 // Enable all the tools from all context servers, but disable the ones that are explicitly disabled
510 for (context_server_id, preset) in &profile.context_servers {
511 self.tools.update(cx, |tools, cx| {
512 tools.disable(
513 ToolSource::ContextServer {
514 id: context_server_id.clone().into(),
515 },
516 &preset
517 .tools
518 .iter()
519 .filter_map(|(tool, enabled)| (!enabled).then(|| tool.clone()))
520 .collect::<Vec<_>>(),
521 cx,
522 )
523 })
524 }
525 } else {
526 for (context_server_id, preset) in &profile.context_servers {
527 self.tools.update(cx, |tools, cx| {
528 tools.enable(
529 ToolSource::ContextServer {
530 id: context_server_id.clone().into(),
531 },
532 &preset
533 .tools
534 .iter()
535 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
536 .collect::<Vec<_>>(),
537 cx,
538 )
539 })
540 }
541 }
542 }
543
544 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
545 cx.subscribe(
546 &self.project.read(cx).context_server_store(),
547 Self::handle_context_server_event,
548 )
549 .detach();
550 }
551
552 fn handle_context_server_event(
553 &mut self,
554 context_server_store: Entity<ContextServerStore>,
555 event: &project::context_server_store::Event,
556 cx: &mut Context<Self>,
557 ) {
558 let tool_working_set = self.tools.clone();
559 match event {
560 project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
561 match status {
562 ContextServerStatus::Running => {
563 if let Some(server) =
564 context_server_store.read(cx).get_running_server(server_id)
565 {
566 let context_server_manager = context_server_store.clone();
567 cx.spawn({
568 let server = server.clone();
569 let server_id = server_id.clone();
570 async move |this, cx| {
571 let Some(protocol) = server.client() else {
572 return;
573 };
574
575 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
576 if let Some(tools) = protocol.list_tools().await.log_err() {
577 let tool_ids = tool_working_set
578 .update(cx, |tool_working_set, _| {
579 tools
580 .tools
581 .into_iter()
582 .map(|tool| {
583 log::info!(
584 "registering context server tool: {:?}",
585 tool.name
586 );
587 tool_working_set.insert(Arc::new(
588 ContextServerTool::new(
589 context_server_manager.clone(),
590 server.id(),
591 tool,
592 ),
593 ))
594 })
595 .collect::<Vec<_>>()
596 })
597 .log_err();
598
599 if let Some(tool_ids) = tool_ids {
600 this.update(cx, |this, cx| {
601 this.context_server_tool_ids
602 .insert(server_id, tool_ids);
603 this.load_default_profile(cx);
604 })
605 .log_err();
606 }
607 }
608 }
609 }
610 })
611 .detach();
612 }
613 }
614 ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
615 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
616 tool_working_set.update(cx, |tool_working_set, _| {
617 tool_working_set.remove(&tool_ids);
618 });
619 self.load_default_profile(cx);
620 }
621 }
622 _ => {}
623 }
624 }
625 }
626 }
627}
628
629#[derive(Debug, Clone, Serialize, Deserialize)]
630pub struct SerializedThreadMetadata {
631 pub id: ThreadId,
632 pub summary: SharedString,
633 pub updated_at: DateTime<Utc>,
634}
635
636#[derive(Serialize, Deserialize, Debug)]
637pub struct SerializedThread {
638 pub version: String,
639 pub summary: SharedString,
640 pub updated_at: DateTime<Utc>,
641 pub messages: Vec<SerializedMessage>,
642 #[serde(default)]
643 pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
644 #[serde(default)]
645 pub cumulative_token_usage: TokenUsage,
646 #[serde(default)]
647 pub request_token_usage: Vec<TokenUsage>,
648 #[serde(default)]
649 pub detailed_summary_state: DetailedSummaryState,
650 #[serde(default)]
651 pub exceeded_window_error: Option<ExceededWindowError>,
652 #[serde(default)]
653 pub model: Option<SerializedLanguageModel>,
654 #[serde(default)]
655 pub completion_mode: Option<CompletionMode>,
656}
657
658#[derive(Serialize, Deserialize, Debug)]
659pub struct SerializedLanguageModel {
660 pub provider: String,
661 pub model: String,
662}
663
664impl SerializedThread {
665 pub const VERSION: &'static str = "0.2.0";
666
667 pub fn from_json(json: &[u8]) -> Result<Self> {
668 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
669 match saved_thread_json.get("version") {
670 Some(serde_json::Value::String(version)) => match version.as_str() {
671 SerializedThreadV0_1_0::VERSION => {
672 let saved_thread =
673 serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
674 Ok(saved_thread.upgrade())
675 }
676 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
677 saved_thread_json,
678 )?),
679 _ => Err(anyhow!(
680 "unrecognized serialized thread version: {}",
681 version
682 )),
683 },
684 None => {
685 let saved_thread =
686 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
687 Ok(saved_thread.upgrade())
688 }
689 version => Err(anyhow!(
690 "unrecognized serialized thread version: {:?}",
691 version
692 )),
693 }
694 }
695}
696
697#[derive(Serialize, Deserialize, Debug)]
698pub struct SerializedThreadV0_1_0(
699 // The structure did not change, so we are reusing the latest SerializedThread.
700 // When making the next version, make sure this points to SerializedThreadV0_2_0
701 SerializedThread,
702);
703
704impl SerializedThreadV0_1_0 {
705 pub const VERSION: &'static str = "0.1.0";
706
707 pub fn upgrade(self) -> SerializedThread {
708 debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
709
710 let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
711
712 for message in self.0.messages {
713 if message.role == Role::User && !message.tool_results.is_empty() {
714 if let Some(last_message) = messages.last_mut() {
715 debug_assert!(last_message.role == Role::Assistant);
716
717 last_message.tool_results = message.tool_results;
718 continue;
719 }
720 }
721
722 messages.push(message);
723 }
724
725 SerializedThread { messages, ..self.0 }
726 }
727}
728
729#[derive(Debug, Serialize, Deserialize)]
730pub struct SerializedMessage {
731 pub id: MessageId,
732 pub role: Role,
733 #[serde(default)]
734 pub segments: Vec<SerializedMessageSegment>,
735 #[serde(default)]
736 pub tool_uses: Vec<SerializedToolUse>,
737 #[serde(default)]
738 pub tool_results: Vec<SerializedToolResult>,
739 #[serde(default)]
740 pub context: String,
741 #[serde(default)]
742 pub creases: Vec<SerializedCrease>,
743}
744
745#[derive(Debug, Serialize, Deserialize)]
746#[serde(tag = "type")]
747pub enum SerializedMessageSegment {
748 #[serde(rename = "text")]
749 Text {
750 text: String,
751 },
752 #[serde(rename = "thinking")]
753 Thinking {
754 text: String,
755 #[serde(skip_serializing_if = "Option::is_none")]
756 signature: Option<String>,
757 },
758 RedactedThinking {
759 data: Vec<u8>,
760 },
761}
762
763#[derive(Debug, Serialize, Deserialize)]
764pub struct SerializedToolUse {
765 pub id: LanguageModelToolUseId,
766 pub name: SharedString,
767 pub input: serde_json::Value,
768}
769
770#[derive(Debug, Serialize, Deserialize)]
771pub struct SerializedToolResult {
772 pub tool_use_id: LanguageModelToolUseId,
773 pub is_error: bool,
774 pub content: Arc<str>,
775}
776
777#[derive(Serialize, Deserialize)]
778struct LegacySerializedThread {
779 pub summary: SharedString,
780 pub updated_at: DateTime<Utc>,
781 pub messages: Vec<LegacySerializedMessage>,
782 #[serde(default)]
783 pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
784}
785
786impl LegacySerializedThread {
787 pub fn upgrade(self) -> SerializedThread {
788 SerializedThread {
789 version: SerializedThread::VERSION.to_string(),
790 summary: self.summary,
791 updated_at: self.updated_at,
792 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
793 initial_project_snapshot: self.initial_project_snapshot,
794 cumulative_token_usage: TokenUsage::default(),
795 request_token_usage: Vec::new(),
796 detailed_summary_state: DetailedSummaryState::default(),
797 exceeded_window_error: None,
798 model: None,
799 completion_mode: None,
800 }
801 }
802}
803
804#[derive(Debug, Serialize, Deserialize)]
805struct LegacySerializedMessage {
806 pub id: MessageId,
807 pub role: Role,
808 pub text: String,
809 #[serde(default)]
810 pub tool_uses: Vec<SerializedToolUse>,
811 #[serde(default)]
812 pub tool_results: Vec<SerializedToolResult>,
813}
814
815impl LegacySerializedMessage {
816 fn upgrade(self) -> SerializedMessage {
817 SerializedMessage {
818 id: self.id,
819 role: self.role,
820 segments: vec![SerializedMessageSegment::Text { text: self.text }],
821 tool_uses: self.tool_uses,
822 tool_results: self.tool_results,
823 context: String::new(),
824 creases: Vec::new(),
825 }
826 }
827}
828
829#[derive(Debug, Serialize, Deserialize)]
830pub struct SerializedCrease {
831 pub start: usize,
832 pub end: usize,
833 pub icon_path: SharedString,
834 pub label: SharedString,
835}
836
837struct GlobalThreadsDatabase(
838 Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
839);
840
841impl Global for GlobalThreadsDatabase {}
842
843pub(crate) struct ThreadsDatabase {
844 executor: BackgroundExecutor,
845 env: heed::Env,
846 threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
847}
848
849impl heed::BytesEncode<'_> for SerializedThread {
850 type EItem = SerializedThread;
851
852 fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
853 serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
854 }
855}
856
857impl<'a> heed::BytesDecode<'a> for SerializedThread {
858 type DItem = SerializedThread;
859
860 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
861 // We implement this type manually because we want to call `SerializedThread::from_json`,
862 // instead of the Deserialize trait implementation for `SerializedThread`.
863 SerializedThread::from_json(bytes).map_err(Into::into)
864 }
865}
866
867impl ThreadsDatabase {
868 fn global_future(
869 cx: &mut App,
870 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
871 GlobalThreadsDatabase::global(cx).0.clone()
872 }
873
874 fn init(cx: &mut App) {
875 let executor = cx.background_executor().clone();
876 let database_future = executor
877 .spawn({
878 let executor = executor.clone();
879 let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
880 async move { ThreadsDatabase::new(database_path, executor) }
881 })
882 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
883 .boxed()
884 .shared();
885
886 cx.set_global(GlobalThreadsDatabase(database_future));
887 }
888
889 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
890 std::fs::create_dir_all(&path)?;
891
892 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
893 let env = unsafe {
894 heed::EnvOpenOptions::new()
895 .map_size(ONE_GB_IN_BYTES)
896 .max_dbs(1)
897 .open(path)?
898 };
899
900 let mut txn = env.write_txn()?;
901 let threads = env.create_database(&mut txn, Some("threads"))?;
902 txn.commit()?;
903
904 Ok(Self {
905 executor,
906 env,
907 threads,
908 })
909 }
910
911 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
912 let env = self.env.clone();
913 let threads = self.threads;
914
915 self.executor.spawn(async move {
916 let txn = env.read_txn()?;
917 let mut iter = threads.iter(&txn)?;
918 let mut threads = Vec::new();
919 while let Some((key, value)) = iter.next().transpose()? {
920 threads.push(SerializedThreadMetadata {
921 id: key,
922 summary: value.summary,
923 updated_at: value.updated_at,
924 });
925 }
926
927 Ok(threads)
928 })
929 }
930
931 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
932 let env = self.env.clone();
933 let threads = self.threads;
934
935 self.executor.spawn(async move {
936 let txn = env.read_txn()?;
937 let thread = threads.get(&txn, &id)?;
938 Ok(thread)
939 })
940 }
941
942 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
943 let env = self.env.clone();
944 let threads = self.threads;
945
946 self.executor.spawn(async move {
947 let mut txn = env.write_txn()?;
948 threads.put(&mut txn, &id, &thread)?;
949 txn.commit()?;
950 Ok(())
951 })
952 }
953
954 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
955 let env = self.env.clone();
956 let threads = self.threads;
957
958 self.executor.spawn(async move {
959 let mut txn = env.write_txn()?;
960 threads.delete(&mut txn, &id)?;
961 txn.commit()?;
962 Ok(())
963 })
964 }
965}