thread_store.rs

  1use std::borrow::Cow;
  2use std::cell::{Ref, RefCell};
  3use std::path::{Path, PathBuf};
  4use std::rc::Rc;
  5use std::sync::Arc;
  6
  7use anyhow::{Context as _, Result, anyhow};
  8use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
  9use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
 10use chrono::{DateTime, Utc};
 11use collections::HashMap;
 12use context_server::manager::ContextServerManager;
 13use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 14use fs::Fs;
 15use futures::FutureExt as _;
 16use futures::future::{self, BoxFuture, Shared};
 17use gpui::{
 18    App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
 19    Subscription, Task, prelude::*,
 20};
 21use heed::Database;
 22use heed::types::SerdeBincode;
 23use language_model::{LanguageModelToolUseId, Role, TokenUsage};
 24use project::{Project, Worktree};
 25use prompt_store::{ProjectContext, PromptBuilder, RulesFileContext, WorktreeContext};
 26use serde::{Deserialize, Serialize};
 27use settings::{Settings as _, SettingsStore};
 28use util::ResultExt as _;
 29
 30use crate::thread::{DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadId};
 31
 32const RULES_FILE_NAMES: [&'static str; 6] = [
 33    ".rules",
 34    ".cursorrules",
 35    ".windsurfrules",
 36    ".clinerules",
 37    ".github/copilot-instructions.md",
 38    "CLAUDE.md",
 39];
 40
 41pub fn init(cx: &mut App) {
 42    ThreadsDatabase::init(cx);
 43}
 44
 45/// A system prompt shared by all threads created by this ThreadStore
 46#[derive(Clone, Default)]
 47pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
 48
 49impl SharedProjectContext {
 50    pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
 51        self.0.borrow()
 52    }
 53}
 54
/// Owns the list of saved agent threads and the state shared by the threads it creates:
/// the tool working set, the prompt builder, context-server tools, and the project context
/// used for the system prompt.
pub struct ThreadStore {
    /// Project whose visible worktrees supply the rules files for the system prompt.
    project: Entity<Project>,
    /// Tool working set shared with every thread; enabled/disabled according to the active
    /// agent profile.
    tools: Arc<ToolWorkingSet>,
    /// Prompt builder handed to each created or deserialized thread.
    prompt_builder: Arc<PromptBuilder>,
    /// Manager for this project's context servers; its events drive tool (un)registration.
    context_server_manager: Entity<ContextServerManager>,
    /// Tool ids registered for each running context server, keyed by server id, so they can
    /// be removed from `tools` when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Cached metadata for saved threads in database order, refreshed by `reload`;
    /// `threads()` returns a sorted copy.
    threads: Vec<SerializedThreadMetadata>,
    /// Project context shared (through an `Rc`) with every thread this store creates.
    project_context: SharedProjectContext,
    /// Settings and project subscriptions, held so they live as long as the store.
    _subscriptions: Vec<Subscription>,
}
 65
 66pub struct RulesLoadingError {
 67    pub message: SharedString,
 68}
 69
 70impl EventEmitter<RulesLoadingError> for ThreadStore {}
 71
 72impl ThreadStore {
 73    pub fn load(
 74        project: Entity<Project>,
 75        tools: Arc<ToolWorkingSet>,
 76        prompt_builder: Arc<PromptBuilder>,
 77        cx: &mut App,
 78    ) -> Task<Entity<Self>> {
 79        let thread_store = cx.new(|cx| Self::new(project, tools, prompt_builder, cx));
 80        let reload = thread_store.update(cx, |store, cx| store.reload_system_prompt(cx));
 81        cx.foreground_executor().spawn(async move {
 82            reload.await;
 83            thread_store
 84        })
 85    }
 86
 87    fn new(
 88        project: Entity<Project>,
 89        tools: Arc<ToolWorkingSet>,
 90        prompt_builder: Arc<PromptBuilder>,
 91        cx: &mut Context<Self>,
 92    ) -> Self {
 93        let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
 94        let context_server_manager = cx.new(|cx| {
 95            ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 96        });
 97        let settings_subscription =
 98            cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
 99                this.load_default_profile(cx);
100            });
101        let project_subscription = cx.subscribe(&project, Self::handle_project_event);
102
103        let this = Self {
104            project,
105            tools,
106            prompt_builder,
107            context_server_manager,
108            context_server_tool_ids: HashMap::default(),
109            threads: Vec::new(),
110            project_context: SharedProjectContext::default(),
111            _subscriptions: vec![settings_subscription, project_subscription],
112        };
113        this.load_default_profile(cx);
114        this.register_context_server_handlers(cx);
115        this.reload(cx).detach_and_log_err(cx);
116        this
117    }
118
119    fn handle_project_event(
120        &mut self,
121        _project: Entity<Project>,
122        event: &project::Event,
123        cx: &mut Context<Self>,
124    ) {
125        match event {
126            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
127                self.reload_system_prompt(cx).detach();
128            }
129            project::Event::WorktreeUpdatedEntries(_, items) => {
130                if items.iter().any(|(path, _, _)| {
131                    RULES_FILE_NAMES
132                        .iter()
133                        .any(|name| path.as_ref() == Path::new(name))
134                }) {
135                    self.reload_system_prompt(cx).detach();
136                }
137            }
138            _ => {}
139        }
140    }
141
142    pub fn reload_system_prompt(&self, cx: &mut Context<Self>) -> Task<()> {
143        let project = self.project.read(cx);
144        let tasks = project
145            .visible_worktrees(cx)
146            .map(|worktree| {
147                Self::load_worktree_info_for_system_prompt(
148                    project.fs().clone(),
149                    worktree.read(cx),
150                    cx,
151                )
152            })
153            .collect::<Vec<_>>();
154
155        cx.spawn(async move |this, cx| {
156            let results = futures::future::join_all(tasks).await;
157            let worktrees = results
158                .into_iter()
159                .map(|(worktree, rules_error)| {
160                    if let Some(rules_error) = rules_error {
161                        this.update(cx, |_, cx| cx.emit(rules_error)).ok();
162                    }
163                    worktree
164                })
165                .collect::<Vec<_>>();
166            this.update(cx, |this, _cx| {
167                *this.project_context.0.borrow_mut() = Some(ProjectContext::new(worktrees));
168            })
169            .ok();
170        })
171    }
172
173    fn load_worktree_info_for_system_prompt(
174        fs: Arc<dyn Fs>,
175        worktree: &Worktree,
176        cx: &App,
177    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
178        let root_name = worktree.root_name().into();
179        let abs_path = worktree.abs_path();
180
181        let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
182        let Some(rules_task) = rules_task else {
183            return Task::ready((
184                WorktreeContext {
185                    root_name,
186                    abs_path,
187                    rules_file: None,
188                },
189                None,
190            ));
191        };
192
193        cx.spawn(async move |_| {
194            let (rules_file, rules_file_error) = match rules_task.await {
195                Ok(rules_file) => (Some(rules_file), None),
196                Err(err) => (
197                    None,
198                    Some(RulesLoadingError {
199                        message: format!("{err}").into(),
200                    }),
201                ),
202            };
203            let worktree_info = WorktreeContext {
204                root_name,
205                abs_path,
206                rules_file,
207            };
208            (worktree_info, rules_file_error)
209        })
210    }
211
212    fn load_worktree_rules_file(
213        fs: Arc<dyn Fs>,
214        worktree: &Worktree,
215        cx: &App,
216    ) -> Option<Task<Result<RulesFileContext>>> {
217        let selected_rules_file = RULES_FILE_NAMES
218            .into_iter()
219            .filter_map(|name| {
220                worktree
221                    .entry_for_path(name)
222                    .filter(|entry| entry.is_file())
223                    .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
224            })
225            .next();
226
227        // Note that Cline supports `.clinerules` being a directory, but that is not currently
228        // supported. This doesn't seem to occur often in GitHub repositories.
229        selected_rules_file.map(|(path_in_worktree, abs_path)| {
230            let fs = fs.clone();
231            cx.background_spawn(async move {
232                let abs_path = abs_path?;
233                let text = fs.load(&abs_path).await.with_context(|| {
234                    format!("Failed to load assistant rules file {:?}", abs_path)
235                })?;
236                anyhow::Ok(RulesFileContext {
237                    path_in_worktree,
238                    abs_path: abs_path.into(),
239                    text: text.trim().to_string(),
240                })
241            })
242        })
243    }
244
245    pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
246        self.context_server_manager.clone()
247    }
248
249    pub fn tools(&self) -> Arc<ToolWorkingSet> {
250        self.tools.clone()
251    }
252
253    /// Returns the number of threads.
254    pub fn thread_count(&self) -> usize {
255        self.threads.len()
256    }
257
258    pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
259        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
260        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
261        threads
262    }
263
264    pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
265        self.threads().into_iter().take(limit).collect()
266    }
267
268    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
269        cx.new(|cx| {
270            Thread::new(
271                self.project.clone(),
272                self.tools.clone(),
273                self.prompt_builder.clone(),
274                self.project_context.clone(),
275                cx,
276            )
277        })
278    }
279
280    pub fn open_thread(
281        &self,
282        id: &ThreadId,
283        cx: &mut Context<Self>,
284    ) -> Task<Result<Entity<Thread>>> {
285        let id = id.clone();
286        let database_future = ThreadsDatabase::global_future(cx);
287        cx.spawn(async move |this, cx| {
288            let database = database_future.await.map_err(|err| anyhow!(err))?;
289            let thread = database
290                .try_find_thread(id.clone())
291                .await?
292                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
293
294            let thread = this.update(cx, |this, cx| {
295                cx.new(|cx| {
296                    Thread::deserialize(
297                        id.clone(),
298                        thread,
299                        this.project.clone(),
300                        this.tools.clone(),
301                        this.prompt_builder.clone(),
302                        this.project_context.clone(),
303                        cx,
304                    )
305                })
306            })?;
307
308            Ok(thread)
309        })
310    }
311
312    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
313        let (metadata, serialized_thread) =
314            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
315
316        let database_future = ThreadsDatabase::global_future(cx);
317        cx.spawn(async move |this, cx| {
318            let serialized_thread = serialized_thread.await?;
319            let database = database_future.await.map_err(|err| anyhow!(err))?;
320            database.save_thread(metadata, serialized_thread).await?;
321
322            this.update(cx, |this, cx| this.reload(cx))?.await
323        })
324    }
325
326    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
327        let id = id.clone();
328        let database_future = ThreadsDatabase::global_future(cx);
329        cx.spawn(async move |this, cx| {
330            let database = database_future.await.map_err(|err| anyhow!(err))?;
331            database.delete_thread(id.clone()).await?;
332
333            this.update(cx, |this, cx| {
334                this.threads.retain(|thread| thread.id != id);
335                cx.notify();
336            })
337        })
338    }
339
340    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
341        let database_future = ThreadsDatabase::global_future(cx);
342        cx.spawn(async move |this, cx| {
343            let threads = database_future
344                .await
345                .map_err(|err| anyhow!(err))?
346                .list_threads()
347                .await?;
348
349            this.update(cx, |this, cx| {
350                this.threads = threads;
351                cx.notify();
352            })
353        })
354    }
355
356    fn load_default_profile(&self, cx: &Context<Self>) {
357        let assistant_settings = AssistantSettings::get_global(cx);
358
359        self.load_profile_by_id(&assistant_settings.default_profile, cx);
360    }
361
362    pub fn load_profile_by_id(&self, profile_id: &AgentProfileId, cx: &Context<Self>) {
363        let assistant_settings = AssistantSettings::get_global(cx);
364
365        if let Some(profile) = assistant_settings.profiles.get(profile_id) {
366            self.load_profile(profile, cx);
367        }
368    }
369
370    pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
371        self.tools.disable_all_tools();
372        self.tools.enable(
373            ToolSource::Native,
374            &profile
375                .tools
376                .iter()
377                .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
378                .collect::<Vec<_>>(),
379        );
380
381        if profile.enable_all_context_servers {
382            for context_server in self.context_server_manager.read(cx).all_servers() {
383                self.tools.enable_source(
384                    ToolSource::ContextServer {
385                        id: context_server.id().into(),
386                    },
387                    cx,
388                );
389            }
390        } else {
391            for (context_server_id, preset) in &profile.context_servers {
392                self.tools.enable(
393                    ToolSource::ContextServer {
394                        id: context_server_id.clone().into(),
395                    },
396                    &preset
397                        .tools
398                        .iter()
399                        .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
400                        .collect::<Vec<_>>(),
401                )
402            }
403        }
404    }
405
406    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
407        cx.subscribe(
408            &self.context_server_manager.clone(),
409            Self::handle_context_server_event,
410        )
411        .detach();
412    }
413
414    fn handle_context_server_event(
415        &mut self,
416        context_server_manager: Entity<ContextServerManager>,
417        event: &context_server::manager::Event,
418        cx: &mut Context<Self>,
419    ) {
420        let tool_working_set = self.tools.clone();
421        match event {
422            context_server::manager::Event::ServerStarted { server_id } => {
423                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
424                    let context_server_manager = context_server_manager.clone();
425                    cx.spawn({
426                        let server = server.clone();
427                        let server_id = server_id.clone();
428                        async move |this, cx| {
429                            let Some(protocol) = server.client() else {
430                                return;
431                            };
432
433                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
434                                if let Some(tools) = protocol.list_tools().await.log_err() {
435                                    let tool_ids = tools
436                                        .tools
437                                        .into_iter()
438                                        .map(|tool| {
439                                            log::info!(
440                                                "registering context server tool: {:?}",
441                                                tool.name
442                                            );
443                                            tool_working_set.insert(Arc::new(
444                                                ContextServerTool::new(
445                                                    context_server_manager.clone(),
446                                                    server.id(),
447                                                    tool,
448                                                ),
449                                            ))
450                                        })
451                                        .collect::<Vec<_>>();
452
453                                    this.update(cx, |this, cx| {
454                                        this.context_server_tool_ids.insert(server_id, tool_ids);
455                                        this.load_default_profile(cx);
456                                    })
457                                    .log_err();
458                                }
459                            }
460                        }
461                    })
462                    .detach();
463                }
464            }
465            context_server::manager::Event::ServerStopped { server_id } => {
466                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
467                    tool_working_set.remove(&tool_ids);
468                    self.load_default_profile(cx);
469                }
470            }
471        }
472    }
473}
474
/// Lightweight summary of a saved thread, used to list threads without their messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Database key of the thread.
    pub id: ThreadId,
    /// Summary copied from the stored [`SerializedThread`].
    pub summary: SharedString,
    /// Last-updated timestamp; `ThreadStore::threads` sorts by this, newest first.
    pub updated_at: DateTime<Utc>,
}
481
/// Persisted form of a thread in the current format (see [`SerializedThread::VERSION`]).
///
/// Stored as JSON by the `heed::BytesEncode` impl below. Fields marked `#[serde(default)]`
/// fall back to their defaults when absent from older records.
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThread {
    /// Format version; `from_json` only accepts [`SerializedThread::VERSION`].
    pub version: String,
    /// Thread summary.
    pub summary: SharedString,
    /// When the thread was last updated.
    pub updated_at: DateTime<Utc>,
    /// The thread's messages.
    pub messages: Vec<SerializedMessage>,
    /// Optional initial project snapshot.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    /// Cumulative token usage; defaults when absent.
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
    /// State of the detailed summary; defaults when absent.
    #[serde(default)]
    pub detailed_summary_state: DetailedSummaryState,
}
495
496impl SerializedThread {
497    pub const VERSION: &'static str = "0.1.0";
498
499    pub fn from_json(json: &[u8]) -> Result<Self> {
500        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
501        match saved_thread_json.get("version") {
502            Some(serde_json::Value::String(version)) => match version.as_str() {
503                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
504                    saved_thread_json,
505                )?),
506                _ => Err(anyhow!(
507                    "unrecognized serialized thread version: {}",
508                    version
509                )),
510            },
511            None => {
512                let saved_thread =
513                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
514                Ok(saved_thread.upgrade())
515            }
516            version => Err(anyhow!(
517                "unrecognized serialized thread version: {:?}",
518                version
519            )),
520        }
521    }
522}
523
/// Persisted form of a single message within a [`SerializedThread`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    /// Message id within its thread.
    pub id: MessageId,
    /// Author role of the message.
    pub role: Role,
    /// Content segments (text / thinking); empty when absent.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    /// Tool invocations made in this message; empty when absent.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Tool results carried by this message; empty when absent.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
    /// Extra context text attached to the message; empty for upgraded legacy messages.
    #[serde(default)]
    pub context: String,
}
537
/// One piece of a message's content, serialized as an internally tagged object:
/// `{"type": "text", "text": ...}` or `{"type": "thinking", "text": ...}`.
/// The tag strings are part of the on-disk format.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "thinking")]
    Thinking { text: String },
}
546
/// Persisted record of a tool invocation.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    /// Id linking this use to its [`SerializedToolResult`].
    pub id: LanguageModelToolUseId,
    /// Name of the invoked tool.
    pub name: SharedString,
    /// Tool input as arbitrary JSON.
    pub input: serde_json::Value,
}
553
/// Persisted result of a tool invocation.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    /// Id of the [`SerializedToolUse`] this result answers.
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the tool reported an error.
    pub is_error: bool,
    /// Result content.
    pub content: Arc<str>,
}
560
/// Pre-versioning thread format (records with no `"version"` key).
/// `SerializedThread::from_json` falls back to this and converts it with `upgrade`.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
569
570impl LegacySerializedThread {
571    pub fn upgrade(self) -> SerializedThread {
572        SerializedThread {
573            version: SerializedThread::VERSION.to_string(),
574            summary: self.summary,
575            updated_at: self.updated_at,
576            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
577            initial_project_snapshot: self.initial_project_snapshot,
578            cumulative_token_usage: TokenUsage::default(),
579            detailed_summary_state: DetailedSummaryState::default(),
580        }
581    }
582}
583
/// Message in the legacy format: a single `text` body instead of segments, and no context.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
594
595impl LegacySerializedMessage {
596    fn upgrade(self) -> SerializedMessage {
597        SerializedMessage {
598            id: self.id,
599            role: self.role,
600            segments: vec![SerializedMessageSegment::Text { text: self.text }],
601            tool_uses: self.tool_uses,
602            tool_results: self.tool_results,
603            context: String::new(),
604        }
605    }
606}
607
/// GPUI global holding the shared future that opens the [`ThreadsDatabase`].
///
/// Because the future is `Shared`, every caller awaits the same open attempt. The error is
/// wrapped in an `Arc` because a shared future's output must be `Clone`.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
613
/// LMDB-backed (via `heed`) store of serialized threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every database transaction is run.
    executor: BackgroundExecutor,
    /// The LMDB environment; cloned into each spawned task.
    env: heed::Env,
    /// Threads keyed by bincode-encoded [`ThreadId`]. Values are JSON-encoded through the
    /// `heed::BytesEncode` / `heed::BytesDecode` impls for `SerializedThread`.
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
619
620impl heed::BytesEncode<'_> for SerializedThread {
621    type EItem = SerializedThread;
622
623    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
624        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
625    }
626}
627
628impl<'a> heed::BytesDecode<'a> for SerializedThread {
629    type DItem = SerializedThread;
630
631    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
632        // We implement this type manually because we want to call `SerializedThread::from_json`,
633        // instead of the Deserialize trait implementation for `SerializedThread`.
634        SerializedThread::from_json(bytes).map_err(Into::into)
635    }
636}
637
638impl ThreadsDatabase {
639    fn global_future(
640        cx: &mut App,
641    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
642        GlobalThreadsDatabase::global(cx).0.clone()
643    }
644
645    fn init(cx: &mut App) {
646        let executor = cx.background_executor().clone();
647        let database_future = executor
648            .spawn({
649                let executor = executor.clone();
650                let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
651                async move { ThreadsDatabase::new(database_path, executor) }
652            })
653            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
654            .boxed()
655            .shared();
656
657        cx.set_global(GlobalThreadsDatabase(database_future));
658    }
659
660    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
661        std::fs::create_dir_all(&path)?;
662
663        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
664        let env = unsafe {
665            heed::EnvOpenOptions::new()
666                .map_size(ONE_GB_IN_BYTES)
667                .max_dbs(1)
668                .open(path)?
669        };
670
671        let mut txn = env.write_txn()?;
672        let threads = env.create_database(&mut txn, Some("threads"))?;
673        txn.commit()?;
674
675        Ok(Self {
676            executor,
677            env,
678            threads,
679        })
680    }
681
682    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
683        let env = self.env.clone();
684        let threads = self.threads;
685
686        self.executor.spawn(async move {
687            let txn = env.read_txn()?;
688            let mut iter = threads.iter(&txn)?;
689            let mut threads = Vec::new();
690            while let Some((key, value)) = iter.next().transpose()? {
691                threads.push(SerializedThreadMetadata {
692                    id: key,
693                    summary: value.summary,
694                    updated_at: value.updated_at,
695                });
696            }
697
698            Ok(threads)
699        })
700    }
701
702    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
703        let env = self.env.clone();
704        let threads = self.threads;
705
706        self.executor.spawn(async move {
707            let txn = env.read_txn()?;
708            let thread = threads.get(&txn, &id)?;
709            Ok(thread)
710        })
711    }
712
713    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
714        let env = self.env.clone();
715        let threads = self.threads;
716
717        self.executor.spawn(async move {
718            let mut txn = env.write_txn()?;
719            threads.put(&mut txn, &id, &thread)?;
720            txn.commit()?;
721            Ok(())
722        })
723    }
724
725    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
726        let env = self.env.clone();
727        let threads = self.threads;
728
729        self.executor.spawn(async move {
730            let mut txn = env.write_txn()?;
731            threads.delete(&mut txn, &id)?;
732            txn.commit()?;
733            Ok(())
734        })
735    }
736}