Initial support for AI assistant rules files (#27168)

Michael Sloan , Danilo , Nathan , and Thomas created

Release Notes:

- N/A

---------

Co-authored-by: Danilo <danilo@zed.dev>
Co-authored-by: Nathan <nathan@zed.dev>
Co-authored-by: Thomas <thomas@zed.dev>

Change summary

assets/prompts/assistant_system_prompt.hbs |  16 ++
crates/assistant2/src/active_thread.rs     |  88 +++++++++++
crates/assistant2/src/assistant_panel.rs   |  17 +-
crates/assistant2/src/message_editor.rs    |  13 +
crates/assistant2/src/thread.rs            | 176 +++++++++++++++++++----
crates/assistant2/src/thread_store.rs      |  18 ++
crates/assistant_eval/src/eval.rs          |  15 ++
crates/prompt_store/src/prompts.rs         |  30 +++
8 files changed, 322 insertions(+), 51 deletions(-)

Detailed changes

assets/prompts/assistant_system_prompt.hbs 🔗

@@ -14,5 +14,19 @@ Be concise and direct in your responses.
 The user has opened a project that contains the following root directories/files:
 
 {{#each worktrees}}
-- {{root_name}} (absolute path: {{abs_path}})
+- `{{root_name}}` (absolute path: `{{abs_path}}`)
 {{/each}}
+{{#if has_rules}}
+
+There are rules that apply to these root directories:
+{{#each worktrees}}
+{{#if rules_file}}
+
+`{{root_name}}/{{rules_file.rel_path}}`:
+
+``````
+{{{rules_file.text}}}
+``````
+{{/if}}
+{{/each}}
+{{/if}}

crates/assistant2/src/active_thread.rs 🔗

@@ -8,7 +8,7 @@ use gpui::{
     list, percentage, AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent,
     DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Length, ListAlignment, ListOffset,
     ListState, StyleRefinement, Subscription, Task, TextStyleRefinement, Transformation,
-    UnderlineStyle,
+    UnderlineStyle, WeakEntity,
 };
 use language::{Buffer, LanguageRegistry};
 use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
@@ -18,9 +18,9 @@ use settings::Settings as _;
 use std::sync::Arc;
 use std::time::Duration;
 use theme::ThemeSettings;
-use ui::Color;
 use ui::{prelude::*, Disclosure, KeyBinding};
 use util::ResultExt as _;
+use workspace::{OpenOptions, Workspace};
 
 use crate::context_store::{refresh_context_store_text, ContextStore};
 
@@ -29,6 +29,7 @@ pub struct ActiveThread {
     thread_store: Entity<ThreadStore>,
     thread: Entity<Thread>,
     context_store: Entity<ContextStore>,
+    workspace: WeakEntity<Workspace>,
     save_thread_task: Option<Task<()>>,
     messages: Vec<MessageId>,
     list_state: ListState,
@@ -50,6 +51,7 @@ impl ActiveThread {
         thread_store: Entity<ThreadStore>,
         language_registry: Arc<LanguageRegistry>,
         context_store: Entity<ContextStore>,
+        workspace: WeakEntity<Workspace>,
         window: &mut Window,
         cx: &mut Context<Self>,
     ) -> Self {
@@ -63,6 +65,7 @@ impl ActiveThread {
             thread_store,
             thread: thread.clone(),
             context_store,
+            workspace,
             save_thread_task: None,
             messages: Vec::new(),
             rendered_messages_by_id: HashMap::default(),
@@ -736,6 +739,7 @@ impl ActiveThread {
         };
 
         v_flex()
+            .when(ix == 0, |parent| parent.child(self.render_rules_item(cx)))
             .when_some(checkpoint, |parent, checkpoint| {
                 parent.child(
                     h_flex().pl_2().child(
@@ -1042,6 +1046,86 @@ impl ActiveThread {
                 }),
         )
     }
+
+    fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {
+        let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
+        else {
+            return div().into_any();
+        };
+
+        let rules_files = system_prompt_context
+            .worktrees
+            .iter()
+            .filter_map(|worktree| worktree.rules_file.as_ref())
+            .collect::<Vec<_>>();
+
+        let label_text = match rules_files.as_slice() {
+            &[] => return div().into_any(),
+            &[rules_file] => {
+                format!("Using {:?} file", rules_file.rel_path)
+            }
+            rules_files => {
+                format!("Using {} rules files", rules_files.len())
+            }
+        };
+
+        div()
+            .pt_1()
+            .px_2p5()
+            .child(
+                h_flex()
+                    .group("rules-item")
+                    .w_full()
+                    .gap_2()
+                    .justify_between()
+                    .child(
+                        h_flex()
+                            .gap_1p5()
+                            .child(
+                                Icon::new(IconName::File)
+                                    .size(IconSize::XSmall)
+                                    .color(Color::Disabled),
+                            )
+                            .child(
+                                Label::new(label_text)
+                                    .size(LabelSize::XSmall)
+                                    .color(Color::Muted)
+                                    .buffer_font(cx),
+                            ),
+                    )
+                    .child(
+                        div().visible_on_hover("rules-item").child(
+                            Button::new("open-rules", "Open Rules")
+                                .label_size(LabelSize::XSmall)
+                                .on_click(cx.listener(Self::handle_open_rules)),
+                        ),
+                    ),
+            )
+            .into_any()
+    }
+
+    fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
+        let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
+        else {
+            return;
+        };
+
+        let abs_paths = system_prompt_context
+            .worktrees
+            .iter()
+            .flat_map(|worktree| worktree.rules_file.as_ref())
+            .map(|rules_file| rules_file.abs_path.to_path_buf())
+            .collect::<Vec<_>>();
+
+        if let Ok(task) = self.workspace.update(cx, move |workspace, cx| {
+            // TODO: Open a multibuffer instead? In some cases this doesn't make the set of rules
+            // files clear. For example, if rules file 1 is already open but rules file 2 is not,
+            // this would open and focus rules file 2 in a tab that is not next to rules file 1.
+            workspace.open_paths(abs_paths, OpenOptions::default(), None, window, cx)
+        }) {
+            task.detach();
+        }
+    }
 }
 
 impl Render for ActiveThread {

crates/assistant2/src/assistant_panel.rs 🔗

@@ -174,6 +174,7 @@ impl AssistantPanel {
                 thread_store.clone(),
                 language_registry.clone(),
                 message_editor_context_store.clone(),
+                workspace.clone(),
                 window,
                 cx,
             )
@@ -252,6 +253,7 @@ impl AssistantPanel {
                 self.thread_store.clone(),
                 self.language_registry.clone(),
                 message_editor_context_store.clone(),
+                self.workspace.clone(),
                 window,
                 cx,
             )
@@ -389,6 +391,7 @@ impl AssistantPanel {
                         this.thread_store.clone(),
                         this.language_registry.clone(),
                         message_editor_context_store.clone(),
+                        this.workspace.clone(),
                         window,
                         cx,
                     )
@@ -922,8 +925,8 @@ impl AssistantPanel {
                     ThreadError::MaxMonthlySpendReached => {
                         self.render_max_monthly_spend_reached_error(cx)
                     }
-                    ThreadError::Message(error_message) => {
-                        self.render_error_message(&error_message, cx)
+                    ThreadError::Message { header, message } => {
+                        self.render_error_message(header, message, cx)
                     }
                 })
                 .into_any(),
@@ -1026,7 +1029,8 @@ impl AssistantPanel {
 
     fn render_error_message(
         &self,
-        error_message: &SharedString,
+        header: SharedString,
+        message: SharedString,
         cx: &mut Context<Self>,
     ) -> AnyElement {
         v_flex()
@@ -1036,17 +1040,14 @@ impl AssistantPanel {
                     .gap_1p5()
                     .items_center()
                     .child(Icon::new(IconName::XCircle).color(Color::Error))
-                    .child(
-                        Label::new("Error interacting with language model")
-                            .weight(FontWeight::MEDIUM),
-                    ),
+                    .child(Label::new(header).weight(FontWeight::MEDIUM)),
             )
             .child(
                 div()
                     .id("error-message")
                     .max_h_32()
                     .overflow_y_scroll()
-                    .child(Label::new(error_message.clone())),
+                    .child(Label::new(message)),
             )
             .child(
                 h_flex()

crates/assistant2/src/message_editor.rs 🔗

@@ -33,7 +33,7 @@ use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
 use crate::thread::{RequestKind, Thread};
 use crate::thread_store::ThreadStore;
 use crate::tool_selector::ToolSelector;
-use crate::{Chat, ChatMode, RemoveAllContext, ToggleContextPicker};
+use crate::{Chat, ChatMode, RemoveAllContext, ThreadEvent, ToggleContextPicker};
 
 pub struct MessageEditor {
     thread: Entity<Thread>,
@@ -206,12 +206,23 @@ impl MessageEditor {
         let refresh_task =
             refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
 
+        let system_prompt_context_task = self.thread.read(cx).load_system_prompt_context(cx);
+
         let thread = self.thread.clone();
         let context_store = self.context_store.clone();
         let git_store = self.project.read(cx).git_store();
         let checkpoint = git_store.read(cx).checkpoint(cx);
         cx.spawn(async move |_, cx| {
             refresh_task.await;
+            let (system_prompt_context, load_error) = system_prompt_context_task.await;
+            thread
+                .update(cx, |thread, cx| {
+                    thread.set_system_prompt_context(system_prompt_context);
+                    if let Some(load_error) = load_error {
+                        cx.emit(ThreadEvent::ShowError(load_error));
+                    }
+                })
+                .ok();
             let checkpoint = checkpoint.await.log_err();
             thread
                 .update(cx, |thread, cx| {

crates/assistant2/src/thread.rs 🔗

@@ -6,6 +6,7 @@ use anyhow::{Context as _, Result};
 use assistant_tool::{ActionLog, ToolWorkingSet};
 use chrono::{DateTime, Utc};
 use collections::{BTreeMap, HashMap, HashSet};
+use fs::Fs;
 use futures::future::Shared;
 use futures::{FutureExt, StreamExt as _};
 use git;
@@ -17,11 +18,13 @@ use language_model::{
     Role, StopReason, TokenUsage,
 };
 use project::git::GitStoreCheckpoint;
-use project::Project;
-use prompt_store::{AssistantSystemPromptWorktree, PromptBuilder};
+use project::{Project, Worktree};
+use prompt_store::{
+    AssistantSystemPromptContext, PromptBuilder, RulesFile, WorktreeInfoForSystemPrompt,
+};
 use scripting_tool::{ScriptingSession, ScriptingTool};
 use serde::{Deserialize, Serialize};
-use util::{post_inc, ResultExt, TryFutureExt as _};
+use util::{maybe, post_inc, ResultExt as _, TryFutureExt as _};
 use uuid::Uuid;
 
 use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
@@ -106,6 +109,7 @@ pub struct Thread {
     next_message_id: MessageId,
     context: BTreeMap<ContextId, ContextSnapshot>,
     context_by_message: HashMap<MessageId, Vec<ContextId>>,
+    system_prompt_context: Option<AssistantSystemPromptContext>,
     checkpoints_by_message: HashMap<MessageId, GitStoreCheckpoint>,
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
@@ -136,6 +140,7 @@ impl Thread {
             next_message_id: MessageId(0),
             context: BTreeMap::default(),
             context_by_message: HashMap::default(),
+            system_prompt_context: None,
             checkpoints_by_message: HashMap::default(),
             completion_count: 0,
             pending_completions: Vec::new(),
@@ -197,6 +202,7 @@ impl Thread {
             next_message_id,
             context: BTreeMap::default(),
             context_by_message: HashMap::default(),
+            system_prompt_context: None,
             checkpoints_by_message: HashMap::default(),
             completion_count: 0,
             pending_completions: Vec::new(),
@@ -478,6 +484,116 @@ impl Thread {
         })
     }
 
+    pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
+        self.system_prompt_context = Some(context);
+    }
+
+    pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
+        &self.system_prompt_context
+    }
+
+    pub fn load_system_prompt_context(
+        &self,
+        cx: &App,
+    ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
+        let project = self.project.read(cx);
+        let tasks = project
+            .visible_worktrees(cx)
+            .map(|worktree| {
+                Self::load_worktree_info_for_system_prompt(
+                    project.fs().clone(),
+                    worktree.read(cx),
+                    cx,
+                )
+            })
+            .collect::<Vec<_>>();
+
+        cx.spawn(async |_cx| {
+            let results = futures::future::join_all(tasks).await;
+            let mut first_err = None;
+            let worktrees = results
+                .into_iter()
+                .map(|(worktree, err)| {
+                    if first_err.is_none() && err.is_some() {
+                        first_err = err;
+                    }
+                    worktree
+                })
+                .collect::<Vec<_>>();
+            (AssistantSystemPromptContext::new(worktrees), first_err)
+        })
+    }
+
+    fn load_worktree_info_for_system_prompt(
+        fs: Arc<dyn Fs>,
+        worktree: &Worktree,
+        cx: &App,
+    ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
+        let root_name = worktree.root_name().into();
+        let abs_path = worktree.abs_path();
+
+        // Note that Cline supports `.clinerules` being a directory, but that is not currently
+        // supported. This doesn't seem to occur often in GitHub repositories.
+        const RULES_FILE_NAMES: [&'static str; 5] = [
+            ".rules",
+            ".cursorrules",
+            ".windsurfrules",
+            ".clinerules",
+            "CLAUDE.md",
+        ];
+        let selected_rules_file = RULES_FILE_NAMES
+            .into_iter()
+            .filter_map(|name| {
+                worktree
+                    .entry_for_path(name)
+                    .filter(|entry| entry.is_file())
+                    .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
+            })
+            .next();
+
+        if let Some((rel_rules_path, abs_rules_path)) = selected_rules_file {
+            cx.spawn(async move |_| {
+                let rules_file_result = maybe!(async move {
+                    let abs_rules_path = abs_rules_path?;
+                    let text = fs.load(&abs_rules_path).await.with_context(|| {
+                        format!("Failed to load assistant rules file {:?}", abs_rules_path)
+                    })?;
+                    anyhow::Ok(RulesFile {
+                        rel_path: rel_rules_path,
+                        abs_path: abs_rules_path.into(),
+                        text: text.trim().to_string(),
+                    })
+                })
+                .await;
+                let (rules_file, rules_file_error) = match rules_file_result {
+                    Ok(rules_file) => (Some(rules_file), None),
+                    Err(err) => (
+                        None,
+                        Some(ThreadError::Message {
+                            header: "Error loading rules file".into(),
+                            message: format!("{err}").into(),
+                        }),
+                    ),
+                };
+                let worktree_info = WorktreeInfoForSystemPrompt {
+                    root_name,
+                    abs_path,
+                    rules_file,
+                };
+                (worktree_info, rules_file_error)
+            })
+        } else {
+            Task::ready((
+                WorktreeInfoForSystemPrompt {
+                    root_name,
+                    abs_path,
+                    rules_file: None,
+                },
+                None,
+            ))
+        }
+    }
+
     pub fn send_to_model(
         &mut self,
         model: Arc<dyn LanguageModel>,
@@ -515,36 +631,30 @@ impl Thread {
         request_kind: RequestKind,
         cx: &App,
     ) -> LanguageModelRequest {
-        let worktree_root_names = self
-            .project
-            .read(cx)
-            .visible_worktrees(cx)
-            .map(|worktree| {
-                let worktree = worktree.read(cx);
-                AssistantSystemPromptWorktree {
-                    root_name: worktree.root_name().into(),
-                    abs_path: worktree.abs_path(),
-                }
-            })
-            .collect::<Vec<_>>();
-        let system_prompt = self
-            .prompt_builder
-            .generate_assistant_system_prompt(worktree_root_names)
-            .context("failed to generate assistant system prompt")
-            .log_err()
-            .unwrap_or_default();
-
         let mut request = LanguageModelRequest {
-            messages: vec![LanguageModelRequestMessage {
-                role: Role::System,
-                content: vec![MessageContent::Text(system_prompt)],
-                cache: true,
-            }],
+            messages: vec![],
             tools: Vec::new(),
             stop: Vec::new(),
             temperature: None,
         };
 
+        if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
+            if let Some(system_prompt) = self
+                .prompt_builder
+                .generate_assistant_system_prompt(system_prompt_context)
+                .context("failed to generate assistant system prompt")
+                .log_err()
+            {
+                request.messages.push(LanguageModelRequestMessage {
+                    role: Role::System,
+                    content: vec![MessageContent::Text(system_prompt)],
+                    cache: true,
+                });
+            }
+        } else {
+            log::error!("system_prompt_context not set.")
+        }
+
         let mut referenced_context_ids = HashSet::default();
 
         for message in &self.messages {
@@ -757,9 +867,10 @@ impl Thread {
                                     .map(|err| err.to_string())
                                     .collect::<Vec<_>>()
                                     .join("\n");
-                                cx.emit(ThreadEvent::ShowError(ThreadError::Message(
-                                    SharedString::from(error_message.clone()),
-                                )));
+                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
+                                    header: "Error interacting with language model".into(),
+                                    message: SharedString::from(error_message.clone()),
+                                }));
                             }
 
                             thread.cancel_last_completion(cx);
@@ -1204,7 +1315,10 @@ impl Thread {
 pub enum ThreadError {
     PaymentRequired,
     MaxMonthlySpendReached,
-    Message(SharedString),
+    Message {
+        header: SharedString,
+        message: SharedString,
+    },
 }
 
 #[derive(Debug, Clone)]

crates/assistant2/src/thread_store.rs 🔗

@@ -20,7 +20,7 @@ use prompt_store::PromptBuilder;
 use serde::{Deserialize, Serialize};
 use util::ResultExt as _;
 
-use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadId};
+use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId};
 
 pub fn init(cx: &mut App) {
     ThreadsDatabase::init(cx);
@@ -113,7 +113,7 @@ impl ThreadStore {
                 .await?
                 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
 
-            this.update(cx, |this, cx| {
+            let thread = this.update(cx, |this, cx| {
                 cx.new(|cx| {
                     Thread::deserialize(
                         id.clone(),
@@ -124,7 +124,19 @@ impl ThreadStore {
                         cx,
                     )
                 })
-            })
+            })?;
+
+            let (system_prompt_context, load_error) = thread
+                .update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
+                .await;
+            thread.update(cx, |thread, cx| {
+                thread.set_system_prompt_context(system_prompt_context);
+                if let Some(load_error) = load_error {
+                    cx.emit(ThreadEvent::ShowError(load_error));
+                }
+            })?;
+
+            Ok(thread)
         })
     }
 

crates/assistant_eval/src/eval.rs 🔗

@@ -79,10 +79,25 @@ impl Eval {
 
             let start_time = std::time::SystemTime::now();
 
+            let (system_prompt_context, load_error) = cx
+                .update(|cx| {
+                    assistant
+                        .read(cx)
+                        .thread
+                        .read(cx)
+                        .load_system_prompt_context(cx)
+                })?
+                .await;
+
+            if let Some(load_error) = load_error {
+                return Err(anyhow!("{:?}", load_error));
+            };
+
             assistant.update(cx, |assistant, cx| {
                 assistant.thread.update(cx, |thread, cx| {
                     let context = vec![];
                     thread.insert_user_message(self.user_prompt.clone(), context, None, cx);
+                    thread.set_system_prompt_context(system_prompt_context);
                     thread.send_to_model(model, RequestKind::Chat, cx);
                 });
             })?;

crates/prompt_store/src/prompts.rs 🔗

@@ -18,13 +18,34 @@ use util::ResultExt;
 
 #[derive(Serialize)]
 pub struct AssistantSystemPromptContext {
-    pub worktrees: Vec<AssistantSystemPromptWorktree>,
+    pub worktrees: Vec<WorktreeInfoForSystemPrompt>,
+    pub has_rules: bool,
+}
+
+impl AssistantSystemPromptContext {
+    pub fn new(worktrees: Vec<WorktreeInfoForSystemPrompt>) -> Self {
+        let has_rules = worktrees
+            .iter()
+            .any(|worktree| worktree.rules_file.is_some());
+        Self {
+            worktrees,
+            has_rules,
+        }
+    }
 }
 
 #[derive(Serialize)]
-pub struct AssistantSystemPromptWorktree {
+pub struct WorktreeInfoForSystemPrompt {
     pub root_name: String,
     pub abs_path: Arc<Path>,
+    pub rules_file: Option<RulesFile>,
+}
+
+#[derive(Serialize)]
+pub struct RulesFile {
+    pub rel_path: Arc<Path>,
+    pub abs_path: Arc<Path>,
+    pub text: String,
 }
 
 #[derive(Serialize)]
@@ -234,12 +255,11 @@ impl PromptBuilder {
 
     pub fn generate_assistant_system_prompt(
         &self,
-        worktrees: Vec<AssistantSystemPromptWorktree>,
+        context: &AssistantSystemPromptContext,
     ) -> Result<String, RenderError> {
-        let prompt = AssistantSystemPromptContext { worktrees };
         self.handlebars
             .lock()
-            .render("assistant_system_prompt", &prompt)
+            .render("assistant_system_prompt", context)
     }
 
     pub fn generate_inline_transformation_prompt(