1use std::borrow::Cow;
2use std::cell::{Ref, RefCell};
3use std::path::{Path, PathBuf};
4use std::rc::Rc;
5use std::sync::Arc;
6
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::HashMap;
12use context_server::manager::ContextServerManager;
13use context_server::{ContextServerFactoryRegistry, ContextServerTool};
14use fs::Fs;
15use futures::FutureExt as _;
16use futures::future::{self, BoxFuture, Shared};
17use gpui::{
18 App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
19 Subscription, Task, prelude::*,
20};
21use heed::Database;
22use heed::types::SerdeBincode;
23use language_model::{LanguageModelToolUseId, Role, TokenUsage};
24use project::{Project, Worktree};
25use prompt_store::{ProjectContext, PromptBuilder, RulesFileContext, WorktreeContext};
26use serde::{Deserialize, Serialize};
27use settings::{Settings as _, SettingsStore};
28use util::ResultExt as _;
29
30use crate::thread::{DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadId};
31
/// Worktree-relative paths of rules files whose contents are added to the system prompt.
///
/// They are checked in this order and the first one that exists as a file in a worktree is
/// used. A change to any of these paths in a worktree triggers a reload of the project context.
const RULES_FILE_NAMES: [&str; 6] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
];
40
41pub fn init(cx: &mut App) {
42 ThreadsDatabase::init(cx);
43}
44
45/// A system prompt shared by all threads created by this ThreadStore
46#[derive(Clone, Default)]
47pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
48
49impl SharedProjectContext {
50 pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
51 self.0.borrow()
52 }
53}
54
/// Owns the list of saved threads for a project, plus the state shared by the threads it
/// creates: the tool working set, prompt builder, context servers, and project context.
pub struct ThreadStore {
    /// Project whose visible worktrees feed the shared project context.
    project: Entity<Project>,
    /// Tool working set. Profiles enable and disable tools in it, and context-server tools are
    /// inserted into it.
    tools: Arc<ToolWorkingSet>,
    /// Prompt builder handed to every thread this store creates or opens.
    prompt_builder: Arc<PromptBuilder>,
    /// Context server manager; its start/stop events drive context-server tool registration.
    context_server_manager: Entity<ContextServerManager>,
    /// Tool ids registered for each started context server, keyed by server id, so they can be
    /// removed from `tools` when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Cached metadata of saved threads, replaced wholesale on each `reload`.
    threads: Vec<SerializedThreadMetadata>,
    /// Project context shared (by reference) with every thread this store creates or opens.
    project_context: SharedProjectContext,
    /// Settings and project subscriptions, held so they stay active for the store's lifetime.
    _subscriptions: Vec<Subscription>,
}
65
66pub struct RulesLoadingError {
67 pub message: SharedString,
68}
69
70impl EventEmitter<RulesLoadingError> for ThreadStore {}
71
72impl ThreadStore {
73 pub fn load(
74 project: Entity<Project>,
75 tools: Arc<ToolWorkingSet>,
76 prompt_builder: Arc<PromptBuilder>,
77 cx: &mut App,
78 ) -> Task<Entity<Self>> {
79 let thread_store = cx.new(|cx| Self::new(project, tools, prompt_builder, cx));
80 let reload = thread_store.update(cx, |store, cx| store.reload_system_prompt(cx));
81 cx.foreground_executor().spawn(async move {
82 reload.await;
83 thread_store
84 })
85 }
86
87 fn new(
88 project: Entity<Project>,
89 tools: Arc<ToolWorkingSet>,
90 prompt_builder: Arc<PromptBuilder>,
91 cx: &mut Context<Self>,
92 ) -> Self {
93 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
94 let context_server_manager = cx.new(|cx| {
95 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
96 });
97 let settings_subscription =
98 cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
99 this.load_default_profile(cx);
100 });
101 let project_subscription = cx.subscribe(&project, Self::handle_project_event);
102
103 let this = Self {
104 project,
105 tools,
106 prompt_builder,
107 context_server_manager,
108 context_server_tool_ids: HashMap::default(),
109 threads: Vec::new(),
110 project_context: SharedProjectContext::default(),
111 _subscriptions: vec![settings_subscription, project_subscription],
112 };
113 this.load_default_profile(cx);
114 this.register_context_server_handlers(cx);
115 this.reload(cx).detach_and_log_err(cx);
116 this
117 }
118
119 fn handle_project_event(
120 &mut self,
121 _project: Entity<Project>,
122 event: &project::Event,
123 cx: &mut Context<Self>,
124 ) {
125 match event {
126 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
127 self.reload_system_prompt(cx).detach();
128 }
129 project::Event::WorktreeUpdatedEntries(_, items) => {
130 if items.iter().any(|(path, _, _)| {
131 RULES_FILE_NAMES
132 .iter()
133 .any(|name| path.as_ref() == Path::new(name))
134 }) {
135 self.reload_system_prompt(cx).detach();
136 }
137 }
138 _ => {}
139 }
140 }
141
142 pub fn reload_system_prompt(&self, cx: &mut Context<Self>) -> Task<()> {
143 let project = self.project.read(cx);
144 let tasks = project
145 .visible_worktrees(cx)
146 .map(|worktree| {
147 Self::load_worktree_info_for_system_prompt(
148 project.fs().clone(),
149 worktree.read(cx),
150 cx,
151 )
152 })
153 .collect::<Vec<_>>();
154
155 cx.spawn(async move |this, cx| {
156 let results = futures::future::join_all(tasks).await;
157 let worktrees = results
158 .into_iter()
159 .map(|(worktree, rules_error)| {
160 if let Some(rules_error) = rules_error {
161 this.update(cx, |_, cx| cx.emit(rules_error)).ok();
162 }
163 worktree
164 })
165 .collect::<Vec<_>>();
166 this.update(cx, |this, _cx| {
167 *this.project_context.0.borrow_mut() = Some(ProjectContext::new(worktrees));
168 })
169 .ok();
170 })
171 }
172
173 fn load_worktree_info_for_system_prompt(
174 fs: Arc<dyn Fs>,
175 worktree: &Worktree,
176 cx: &App,
177 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
178 let root_name = worktree.root_name().into();
179 let abs_path = worktree.abs_path();
180
181 let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
182 let Some(rules_task) = rules_task else {
183 return Task::ready((
184 WorktreeContext {
185 root_name,
186 abs_path,
187 rules_file: None,
188 },
189 None,
190 ));
191 };
192
193 cx.spawn(async move |_| {
194 let (rules_file, rules_file_error) = match rules_task.await {
195 Ok(rules_file) => (Some(rules_file), None),
196 Err(err) => (
197 None,
198 Some(RulesLoadingError {
199 message: format!("{err}").into(),
200 }),
201 ),
202 };
203 let worktree_info = WorktreeContext {
204 root_name,
205 abs_path,
206 rules_file,
207 };
208 (worktree_info, rules_file_error)
209 })
210 }
211
212 fn load_worktree_rules_file(
213 fs: Arc<dyn Fs>,
214 worktree: &Worktree,
215 cx: &App,
216 ) -> Option<Task<Result<RulesFileContext>>> {
217 let selected_rules_file = RULES_FILE_NAMES
218 .into_iter()
219 .filter_map(|name| {
220 worktree
221 .entry_for_path(name)
222 .filter(|entry| entry.is_file())
223 .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
224 })
225 .next();
226
227 // Note that Cline supports `.clinerules` being a directory, but that is not currently
228 // supported. This doesn't seem to occur often in GitHub repositories.
229 selected_rules_file.map(|(path_in_worktree, abs_path)| {
230 let fs = fs.clone();
231 cx.background_spawn(async move {
232 let abs_path = abs_path?;
233 let text = fs.load(&abs_path).await.with_context(|| {
234 format!("Failed to load assistant rules file {:?}", abs_path)
235 })?;
236 anyhow::Ok(RulesFileContext {
237 path_in_worktree,
238 abs_path: abs_path.into(),
239 text: text.trim().to_string(),
240 })
241 })
242 })
243 }
244
245 pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
246 self.context_server_manager.clone()
247 }
248
249 pub fn tools(&self) -> Arc<ToolWorkingSet> {
250 self.tools.clone()
251 }
252
253 /// Returns the number of threads.
254 pub fn thread_count(&self) -> usize {
255 self.threads.len()
256 }
257
258 pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
259 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
260 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
261 threads
262 }
263
264 pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
265 self.threads().into_iter().take(limit).collect()
266 }
267
268 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
269 cx.new(|cx| {
270 Thread::new(
271 self.project.clone(),
272 self.tools.clone(),
273 self.prompt_builder.clone(),
274 self.project_context.clone(),
275 cx,
276 )
277 })
278 }
279
280 pub fn open_thread(
281 &self,
282 id: &ThreadId,
283 cx: &mut Context<Self>,
284 ) -> Task<Result<Entity<Thread>>> {
285 let id = id.clone();
286 let database_future = ThreadsDatabase::global_future(cx);
287 cx.spawn(async move |this, cx| {
288 let database = database_future.await.map_err(|err| anyhow!(err))?;
289 let thread = database
290 .try_find_thread(id.clone())
291 .await?
292 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
293
294 let thread = this.update(cx, |this, cx| {
295 cx.new(|cx| {
296 Thread::deserialize(
297 id.clone(),
298 thread,
299 this.project.clone(),
300 this.tools.clone(),
301 this.prompt_builder.clone(),
302 this.project_context.clone(),
303 cx,
304 )
305 })
306 })?;
307
308 Ok(thread)
309 })
310 }
311
312 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
313 let (metadata, serialized_thread) =
314 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
315
316 let database_future = ThreadsDatabase::global_future(cx);
317 cx.spawn(async move |this, cx| {
318 let serialized_thread = serialized_thread.await?;
319 let database = database_future.await.map_err(|err| anyhow!(err))?;
320 database.save_thread(metadata, serialized_thread).await?;
321
322 this.update(cx, |this, cx| this.reload(cx))?.await
323 })
324 }
325
326 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
327 let id = id.clone();
328 let database_future = ThreadsDatabase::global_future(cx);
329 cx.spawn(async move |this, cx| {
330 let database = database_future.await.map_err(|err| anyhow!(err))?;
331 database.delete_thread(id.clone()).await?;
332
333 this.update(cx, |this, cx| {
334 this.threads.retain(|thread| thread.id != id);
335 cx.notify();
336 })
337 })
338 }
339
340 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
341 let database_future = ThreadsDatabase::global_future(cx);
342 cx.spawn(async move |this, cx| {
343 let threads = database_future
344 .await
345 .map_err(|err| anyhow!(err))?
346 .list_threads()
347 .await?;
348
349 this.update(cx, |this, cx| {
350 this.threads = threads;
351 cx.notify();
352 })
353 })
354 }
355
356 fn load_default_profile(&self, cx: &Context<Self>) {
357 let assistant_settings = AssistantSettings::get_global(cx);
358
359 self.load_profile_by_id(&assistant_settings.default_profile, cx);
360 }
361
362 pub fn load_profile_by_id(&self, profile_id: &AgentProfileId, cx: &Context<Self>) {
363 let assistant_settings = AssistantSettings::get_global(cx);
364
365 if let Some(profile) = assistant_settings.profiles.get(profile_id) {
366 self.load_profile(profile, cx);
367 }
368 }
369
370 pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
371 self.tools.disable_all_tools();
372 self.tools.enable(
373 ToolSource::Native,
374 &profile
375 .tools
376 .iter()
377 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
378 .collect::<Vec<_>>(),
379 );
380
381 if profile.enable_all_context_servers {
382 for context_server in self.context_server_manager.read(cx).all_servers() {
383 self.tools.enable_source(
384 ToolSource::ContextServer {
385 id: context_server.id().into(),
386 },
387 cx,
388 );
389 }
390 } else {
391 for (context_server_id, preset) in &profile.context_servers {
392 self.tools.enable(
393 ToolSource::ContextServer {
394 id: context_server_id.clone().into(),
395 },
396 &preset
397 .tools
398 .iter()
399 .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
400 .collect::<Vec<_>>(),
401 )
402 }
403 }
404 }
405
406 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
407 cx.subscribe(
408 &self.context_server_manager.clone(),
409 Self::handle_context_server_event,
410 )
411 .detach();
412 }
413
414 fn handle_context_server_event(
415 &mut self,
416 context_server_manager: Entity<ContextServerManager>,
417 event: &context_server::manager::Event,
418 cx: &mut Context<Self>,
419 ) {
420 let tool_working_set = self.tools.clone();
421 match event {
422 context_server::manager::Event::ServerStarted { server_id } => {
423 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
424 let context_server_manager = context_server_manager.clone();
425 cx.spawn({
426 let server = server.clone();
427 let server_id = server_id.clone();
428 async move |this, cx| {
429 let Some(protocol) = server.client() else {
430 return;
431 };
432
433 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
434 if let Some(tools) = protocol.list_tools().await.log_err() {
435 let tool_ids = tools
436 .tools
437 .into_iter()
438 .map(|tool| {
439 log::info!(
440 "registering context server tool: {:?}",
441 tool.name
442 );
443 tool_working_set.insert(Arc::new(
444 ContextServerTool::new(
445 context_server_manager.clone(),
446 server.id(),
447 tool,
448 ),
449 ))
450 })
451 .collect::<Vec<_>>();
452
453 this.update(cx, |this, cx| {
454 this.context_server_tool_ids.insert(server_id, tool_ids);
455 this.load_default_profile(cx);
456 })
457 .log_err();
458 }
459 }
460 }
461 })
462 .detach();
463 }
464 }
465 context_server::manager::Event::ServerStopped { server_id } => {
466 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
467 tool_working_set.remove(&tool_ids);
468 self.load_default_profile(cx);
469 }
470 }
471 }
472 }
473}
474
/// Summary information about a saved thread, without its messages.
///
/// `ThreadStore` caches these and lists them most recently updated first.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Id under which the thread is stored in the threads database.
    pub id: ThreadId,
    /// The thread's summary text.
    pub summary: SharedString,
    /// When the thread was last updated; used to order thread lists.
    pub updated_at: DateTime<Utc>,
}
481
482#[derive(Serialize, Deserialize, Debug)]
483pub struct SerializedThread {
484 pub version: String,
485 pub summary: SharedString,
486 pub updated_at: DateTime<Utc>,
487 pub messages: Vec<SerializedMessage>,
488 #[serde(default)]
489 pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
490 #[serde(default)]
491 pub cumulative_token_usage: TokenUsage,
492 #[serde(default)]
493 pub detailed_summary_state: DetailedSummaryState,
494}
495
496impl SerializedThread {
497 pub const VERSION: &'static str = "0.1.0";
498
499 pub fn from_json(json: &[u8]) -> Result<Self> {
500 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
501 match saved_thread_json.get("version") {
502 Some(serde_json::Value::String(version)) => match version.as_str() {
503 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
504 saved_thread_json,
505 )?),
506 _ => Err(anyhow!(
507 "unrecognized serialized thread version: {}",
508 version
509 )),
510 },
511 None => {
512 let saved_thread =
513 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
514 Ok(saved_thread.upgrade())
515 }
516 version => Err(anyhow!(
517 "unrecognized serialized thread version: {:?}",
518 version
519 )),
520 }
521 }
522}
523
/// A single message within a `SerializedThread`.
///
/// Every field except `id` and `role` is `#[serde(default)]`. A missing field deserializes as
/// empty.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content pieces (text and thinking) of the message.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    /// Tool invocations recorded on this message.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Tool results recorded on this message.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    /// Context text attached to the message; empty when none (and for upgraded legacy messages).
    #[serde(default)]
    pub context: String,
}
537
/// A piece of message content.
///
/// It is serialized as an object tagged by a `"type"` field whose value is either `"text"` or
/// `"thinking"`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Regular message text.
    #[serde(rename = "text")]
    Text { text: String },
    /// Text of a thinking segment.
    #[serde(rename = "thinking")]
    Thinking { text: String },
}
546
/// A tool invocation recorded on a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    /// Id of this tool use (presumably what `SerializedToolResult::tool_use_id` refers to).
    pub id: LanguageModelToolUseId,
    /// Name of the invoked tool.
    pub name: SharedString,
    /// Tool input, kept as arbitrary JSON.
    pub input: serde_json::Value,
}
553
/// The result of a tool use, recorded on a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    /// Id of the tool use this result belongs to.
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether this result is flagged as an error.
    pub is_error: bool,
    /// Result content text.
    pub content: Arc<str>,
}
560
561#[derive(Serialize, Deserialize)]
562struct LegacySerializedThread {
563 pub summary: SharedString,
564 pub updated_at: DateTime<Utc>,
565 pub messages: Vec<LegacySerializedMessage>,
566 #[serde(default)]
567 pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
568}
569
570impl LegacySerializedThread {
571 pub fn upgrade(self) -> SerializedThread {
572 SerializedThread {
573 version: SerializedThread::VERSION.to_string(),
574 summary: self.summary,
575 updated_at: self.updated_at,
576 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
577 initial_project_snapshot: self.initial_project_snapshot,
578 cumulative_token_usage: TokenUsage::default(),
579 detailed_summary_state: DetailedSummaryState::default(),
580 }
581 }
582}
583
584#[derive(Debug, Serialize, Deserialize)]
585struct LegacySerializedMessage {
586 pub id: MessageId,
587 pub role: Role,
588 pub text: String,
589 #[serde(default)]
590 pub tool_uses: Vec<SerializedToolUse>,
591 #[serde(default)]
592 pub tool_results: Vec<SerializedToolResult>,
593}
594
595impl LegacySerializedMessage {
596 fn upgrade(self) -> SerializedMessage {
597 SerializedMessage {
598 id: self.id,
599 role: self.role,
600 segments: vec![SerializedMessageSegment::Text { text: self.text }],
601 tool_uses: self.tool_uses,
602 tool_results: self.tool_results,
603 context: String::new(),
604 }
605 }
606}
607
/// gpui global holding the shared future for the threads database, which
/// `ThreadsDatabase::init` opens on the background executor.
///
/// Both the success and error values are wrapped in `Arc` so the output is cloneable, as
/// `Shared` requires.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
613
/// On-disk store of serialized threads, backed by a heed environment.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every database transaction is run.
    executor: BackgroundExecutor,
    /// The opened heed environment.
    env: heed::Env,
    /// The `threads` database: bincode-encoded thread ids mapped to JSON-encoded threads.
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
619
620impl heed::BytesEncode<'_> for SerializedThread {
621 type EItem = SerializedThread;
622
623 fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
624 serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
625 }
626}
627
628impl<'a> heed::BytesDecode<'a> for SerializedThread {
629 type DItem = SerializedThread;
630
631 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
632 // We implement this type manually because we want to call `SerializedThread::from_json`,
633 // instead of the Deserialize trait implementation for `SerializedThread`.
634 SerializedThread::from_json(bytes).map_err(Into::into)
635 }
636}
637
638impl ThreadsDatabase {
639 fn global_future(
640 cx: &mut App,
641 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
642 GlobalThreadsDatabase::global(cx).0.clone()
643 }
644
645 fn init(cx: &mut App) {
646 let executor = cx.background_executor().clone();
647 let database_future = executor
648 .spawn({
649 let executor = executor.clone();
650 let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
651 async move { ThreadsDatabase::new(database_path, executor) }
652 })
653 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
654 .boxed()
655 .shared();
656
657 cx.set_global(GlobalThreadsDatabase(database_future));
658 }
659
660 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
661 std::fs::create_dir_all(&path)?;
662
663 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
664 let env = unsafe {
665 heed::EnvOpenOptions::new()
666 .map_size(ONE_GB_IN_BYTES)
667 .max_dbs(1)
668 .open(path)?
669 };
670
671 let mut txn = env.write_txn()?;
672 let threads = env.create_database(&mut txn, Some("threads"))?;
673 txn.commit()?;
674
675 Ok(Self {
676 executor,
677 env,
678 threads,
679 })
680 }
681
682 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
683 let env = self.env.clone();
684 let threads = self.threads;
685
686 self.executor.spawn(async move {
687 let txn = env.read_txn()?;
688 let mut iter = threads.iter(&txn)?;
689 let mut threads = Vec::new();
690 while let Some((key, value)) = iter.next().transpose()? {
691 threads.push(SerializedThreadMetadata {
692 id: key,
693 summary: value.summary,
694 updated_at: value.updated_at,
695 });
696 }
697
698 Ok(threads)
699 })
700 }
701
702 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
703 let env = self.env.clone();
704 let threads = self.threads;
705
706 self.executor.spawn(async move {
707 let txn = env.read_txn()?;
708 let thread = threads.get(&txn, &id)?;
709 Ok(thread)
710 })
711 }
712
713 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
714 let env = self.env.clone();
715 let threads = self.threads;
716
717 self.executor.spawn(async move {
718 let mut txn = env.write_txn()?;
719 threads.put(&mut txn, &id, &thread)?;
720 txn.commit()?;
721 Ok(())
722 })
723 }
724
725 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
726 let env = self.env.clone();
727 let threads = self.threads;
728
729 self.executor.spawn(async move {
730 let mut txn = env.write_txn()?;
731 threads.delete(&mut txn, &id)?;
732 txn.commit()?;
733 Ok(())
734 })
735 }
736}