agent.rs

  1use crate::{AgentResponseEvent, Thread, templates::Templates};
  2use crate::{
  3    CopyPathTool, CreateDirectoryTool, EditFileTool, FindPathTool, ListDirectoryTool, MovePathTool,
  4    OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization,
  5};
  6use acp_thread::ModelSelector;
  7use agent_client_protocol as acp;
  8use anyhow::{Context as _, Result, anyhow};
  9use futures::{StreamExt, future};
 10use gpui::{
 11    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
 12};
 13use language_model::{LanguageModel, LanguageModelRegistry};
 14use project::{Project, ProjectItem, ProjectPath, Worktree};
 15use prompt_store::{
 16    ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
 17};
 18use std::cell::RefCell;
 19use std::collections::HashMap;
 20use std::path::Path;
 21use std::rc::Rc;
 22use std::sync::Arc;
 23use util::ResultExt;
 24
 25const RULES_FILE_NAMES: [&'static str; 9] = [
 26    ".rules",
 27    ".cursorrules",
 28    ".windsurfrules",
 29    ".clinerules",
 30    ".github/copilot-instructions.md",
 31    "CLAUDE.md",
 32    "AGENT.md",
 33    "AGENTS.md",
 34    "GEMINI.md",
 35];
 36
 37pub struct RulesLoadingError {
 38    pub message: SharedString,
 39}
 40
 41/// Holds both the internal Thread and the AcpThread for a session
 42struct Session {
 43    /// The internal thread that processes messages
 44    thread: Entity<Thread>,
 45    /// The ACP thread that handles protocol communication
 46    acp_thread: WeakEntity<acp_thread::AcpThread>,
 47    _subscription: Subscription,
 48}
 49
 50pub struct NativeAgent {
 51    /// Session ID -> Session mapping
 52    sessions: HashMap<acp::SessionId, Session>,
 53    /// Shared project context for all threads
 54    project_context: Rc<RefCell<ProjectContext>>,
 55    project_context_needs_refresh: watch::Sender<()>,
 56    _maintain_project_context: Task<Result<()>>,
 57    /// Shared templates for all threads
 58    templates: Arc<Templates>,
 59    project: Entity<Project>,
 60    prompt_store: Option<Entity<PromptStore>>,
 61    _subscriptions: Vec<Subscription>,
 62}
 63
 64impl NativeAgent {
 65    pub async fn new(
 66        project: Entity<Project>,
 67        templates: Arc<Templates>,
 68        prompt_store: Option<Entity<PromptStore>>,
 69        cx: &mut AsyncApp,
 70    ) -> Result<Entity<NativeAgent>> {
 71        log::info!("Creating new NativeAgent");
 72
 73        let project_context = cx
 74            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 75            .await;
 76
 77        cx.new(|cx| {
 78            let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
 79            if let Some(prompt_store) = prompt_store.as_ref() {
 80                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 81            }
 82
 83            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 84                watch::channel(());
 85            Self {
 86                sessions: HashMap::new(),
 87                project_context: Rc::new(RefCell::new(project_context)),
 88                project_context_needs_refresh: project_context_needs_refresh_tx,
 89                _maintain_project_context: cx.spawn(async move |this, cx| {
 90                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 91                }),
 92                templates,
 93                project,
 94                prompt_store,
 95                _subscriptions: subscriptions,
 96            }
 97        })
 98    }
 99
100    async fn maintain_project_context(
101        this: WeakEntity<Self>,
102        mut needs_refresh: watch::Receiver<()>,
103        cx: &mut AsyncApp,
104    ) -> Result<()> {
105        while needs_refresh.changed().await.is_ok() {
106            let project_context = this
107                .update(cx, |this, cx| {
108                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
109                })?
110                .await;
111            this.update(cx, |this, _| this.project_context.replace(project_context))?;
112        }
113
114        Ok(())
115    }
116
117    fn build_project_context(
118        project: &Entity<Project>,
119        prompt_store: Option<&Entity<PromptStore>>,
120        cx: &mut App,
121    ) -> Task<ProjectContext> {
122        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
123        let worktree_tasks = worktrees
124            .into_iter()
125            .map(|worktree| {
126                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
127            })
128            .collect::<Vec<_>>();
129        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
130            prompt_store.read_with(cx, |prompt_store, cx| {
131                let prompts = prompt_store.default_prompt_metadata();
132                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
133                    let contents = prompt_store.load(prompt_metadata.id, cx);
134                    async move { (contents.await, prompt_metadata) }
135                });
136                cx.background_spawn(future::join_all(load_tasks))
137            })
138        } else {
139            Task::ready(vec![])
140        };
141
142        cx.spawn(async move |_cx| {
143            let (worktrees, default_user_rules) =
144                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
145
146            let worktrees = worktrees
147                .into_iter()
148                .map(|(worktree, _rules_error)| {
149                    // TODO: show error message
150                    // if let Some(rules_error) = rules_error {
151                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
152                    // }
153                    worktree
154                })
155                .collect::<Vec<_>>();
156
157            let default_user_rules = default_user_rules
158                .into_iter()
159                .flat_map(|(contents, prompt_metadata)| match contents {
160                    Ok(contents) => Some(UserRulesContext {
161                        uuid: match prompt_metadata.id {
162                            PromptId::User { uuid } => uuid,
163                            PromptId::EditWorkflow => return None,
164                        },
165                        title: prompt_metadata.title.map(|title| title.to_string()),
166                        contents,
167                    }),
168                    Err(_err) => {
169                        // TODO: show error message
170                        // this.update(cx, |_, cx| {
171                        //     cx.emit(RulesLoadingError {
172                        //         message: format!("{err:?}").into(),
173                        //     });
174                        // })
175                        // .ok();
176                        None
177                    }
178                })
179                .collect::<Vec<_>>();
180
181            ProjectContext::new(worktrees, default_user_rules)
182        })
183    }
184
185    fn load_worktree_info_for_system_prompt(
186        worktree: Entity<Worktree>,
187        project: Entity<Project>,
188        cx: &mut App,
189    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
190        let tree = worktree.read(cx);
191        let root_name = tree.root_name().into();
192        let abs_path = tree.abs_path();
193
194        let mut context = WorktreeContext {
195            root_name,
196            abs_path,
197            rules_file: None,
198        };
199
200        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
201        let Some(rules_task) = rules_task else {
202            return Task::ready((context, None));
203        };
204
205        cx.spawn(async move |_| {
206            let (rules_file, rules_file_error) = match rules_task.await {
207                Ok(rules_file) => (Some(rules_file), None),
208                Err(err) => (
209                    None,
210                    Some(RulesLoadingError {
211                        message: format!("{err}").into(),
212                    }),
213                ),
214            };
215            context.rules_file = rules_file;
216            (context, rules_file_error)
217        })
218    }
219
220    fn load_worktree_rules_file(
221        worktree: Entity<Worktree>,
222        project: Entity<Project>,
223        cx: &mut App,
224    ) -> Option<Task<Result<RulesFileContext>>> {
225        let worktree = worktree.read(cx);
226        let worktree_id = worktree.id();
227        let selected_rules_file = RULES_FILE_NAMES
228            .into_iter()
229            .filter_map(|name| {
230                worktree
231                    .entry_for_path(name)
232                    .filter(|entry| entry.is_file())
233                    .map(|entry| entry.path.clone())
234            })
235            .next();
236
237        // Note that Cline supports `.clinerules` being a directory, but that is not currently
238        // supported. This doesn't seem to occur often in GitHub repositories.
239        selected_rules_file.map(|path_in_worktree| {
240            let project_path = ProjectPath {
241                worktree_id,
242                path: path_in_worktree.clone(),
243            };
244            let buffer_task =
245                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
246            let rope_task = cx.spawn(async move |cx| {
247                buffer_task.await?.read_with(cx, |buffer, cx| {
248                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
249                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
250                })?
251            });
252            // Build a string from the rope on a background thread.
253            cx.background_spawn(async move {
254                let (project_entry_id, rope) = rope_task.await?;
255                anyhow::Ok(RulesFileContext {
256                    path_in_worktree,
257                    text: rope.to_string().trim().to_string(),
258                    project_entry_id: project_entry_id.to_usize(),
259                })
260            })
261        })
262    }
263
264    fn handle_project_event(
265        &mut self,
266        _project: Entity<Project>,
267        event: &project::Event,
268        _cx: &mut Context<Self>,
269    ) {
270        match event {
271            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
272                self.project_context_needs_refresh.send(()).ok();
273            }
274            project::Event::WorktreeUpdatedEntries(_, items) => {
275                if items.iter().any(|(path, _, _)| {
276                    RULES_FILE_NAMES
277                        .iter()
278                        .any(|name| path.as_ref() == Path::new(name))
279                }) {
280                    self.project_context_needs_refresh.send(()).ok();
281                }
282            }
283            _ => {}
284        }
285    }
286
287    fn handle_prompts_updated_event(
288        &mut self,
289        _prompt_store: Entity<PromptStore>,
290        _event: &prompt_store::PromptsUpdatedEvent,
291        _cx: &mut Context<Self>,
292    ) {
293        self.project_context_needs_refresh.send(()).ok();
294    }
295}
296
297/// Wrapper struct that implements the AgentConnection trait
298#[derive(Clone)]
299pub struct NativeAgentConnection(pub Entity<NativeAgent>);
300
301impl ModelSelector for NativeAgentConnection {
302    fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
303        log::debug!("NativeAgentConnection::list_models called");
304        cx.spawn(async move |cx| {
305            cx.update(|cx| {
306                let registry = LanguageModelRegistry::read_global(cx);
307                let models = registry.available_models(cx).collect::<Vec<_>>();
308                log::info!("Found {} available models", models.len());
309                if models.is_empty() {
310                    Err(anyhow::anyhow!("No models available"))
311                } else {
312                    Ok(models)
313                }
314            })?
315        })
316    }
317
318    fn select_model(
319        &self,
320        session_id: acp::SessionId,
321        model: Arc<dyn LanguageModel>,
322        cx: &mut AsyncApp,
323    ) -> Task<Result<()>> {
324        log::info!(
325            "Setting model for session {}: {:?}",
326            session_id,
327            model.name()
328        );
329        let agent = self.0.clone();
330
331        cx.spawn(async move |cx| {
332            agent.update(cx, |agent, cx| {
333                if let Some(session) = agent.sessions.get(&session_id) {
334                    session.thread.update(cx, |thread, _cx| {
335                        thread.selected_model = model;
336                    });
337                    Ok(())
338                } else {
339                    Err(anyhow!("Session not found"))
340                }
341            })?
342        })
343    }
344
345    fn selected_model(
346        &self,
347        session_id: &acp::SessionId,
348        cx: &mut AsyncApp,
349    ) -> Task<Result<Arc<dyn LanguageModel>>> {
350        let agent = self.0.clone();
351        let session_id = session_id.clone();
352        cx.spawn(async move |cx| {
353            let thread = agent
354                .read_with(cx, |agent, _| {
355                    agent
356                        .sessions
357                        .get(&session_id)
358                        .map(|session| session.thread.clone())
359                })?
360                .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
361            let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
362            Ok(selected)
363        })
364    }
365}
366
367impl acp_thread::AgentConnection for NativeAgentConnection {
368    fn new_thread(
369        self: Rc<Self>,
370        project: Entity<Project>,
371        cwd: &Path,
372        cx: &mut AsyncApp,
373    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
374        let agent = self.0.clone();
375        log::info!("Creating new thread for project at: {:?}", cwd);
376
377        cx.spawn(async move |cx| {
378            log::debug!("Starting thread creation in async context");
379
380            // Generate session ID
381            let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
382            log::info!("Created session with ID: {}", session_id);
383
384            // Create AcpThread
385            let acp_thread = cx.update(|cx| {
386                cx.new(|cx| {
387                    acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
388                })
389            })?;
390            let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
391
392            // Create Thread
393            let thread = agent.update(
394                cx,
395                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
396                    // Fetch default model from registry settings
397                    let registry = LanguageModelRegistry::read_global(cx);
398
399                    // Log available models for debugging
400                    let available_count = registry.available_models(cx).count();
401                    log::debug!("Total available models: {}", available_count);
402
403                    let default_model = registry
404                        .default_model()
405                        .map(|configured| {
406                            log::info!(
407                                "Using configured default model: {:?} from provider: {:?}",
408                                configured.model.name(),
409                                configured.provider.name()
410                            );
411                            configured.model
412                        })
413                        .ok_or_else(|| {
414                            log::warn!("No default model configured in settings");
415                            anyhow!("No default model configured. Please configure a default model in settings.")
416                        })?;
417
418                    let thread = cx.new(|cx| {
419                        let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
420                        thread.add_tool(CreateDirectoryTool::new(project.clone()));
421                        thread.add_tool(CopyPathTool::new(project.clone()));
422                        thread.add_tool(MovePathTool::new(project.clone()));
423                        thread.add_tool(ListDirectoryTool::new(project.clone()));
424                        thread.add_tool(OpenTool::new(project.clone()));
425                        thread.add_tool(ThinkingTool);
426                        thread.add_tool(FindPathTool::new(project.clone()));
427                        thread.add_tool(ReadFileTool::new(project.clone(), action_log));
428                        thread.add_tool(EditFileTool::new(cx.entity()));
429                        thread.add_tool(TerminalTool::new(project.clone(), cx));
430                        thread
431                    });
432
433                    Ok(thread)
434                },
435            )??;
436
437            // Store the session
438            agent.update(cx, |agent, cx| {
439                agent.sessions.insert(
440                    session_id,
441                    Session {
442                        thread,
443                        acp_thread: acp_thread.downgrade(),
444                        _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
445                            this.sessions.remove(acp_thread.session_id());
446                        })
447                    },
448                );
449            })?;
450
451            Ok(acp_thread)
452        })
453    }
454
455    fn auth_methods(&self) -> &[acp::AuthMethod] {
456        &[] // No auth for in-process
457    }
458
459    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
460        Task::ready(Ok(()))
461    }
462
463    fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
464        Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
465    }
466
467    fn prompt(
468        &self,
469        params: acp::PromptRequest,
470        cx: &mut App,
471    ) -> Task<Result<acp::PromptResponse>> {
472        let session_id = params.session_id.clone();
473        let agent = self.0.clone();
474        log::info!("Received prompt request for session: {}", session_id);
475        log::debug!("Prompt blocks count: {}", params.prompt.len());
476
477        cx.spawn(async move |cx| {
478            // Get session
479            let (thread, acp_thread) = agent
480                .update(cx, |agent, _| {
481                    agent
482                        .sessions
483                        .get_mut(&session_id)
484                        .map(|s| (s.thread.clone(), s.acp_thread.clone()))
485                })?
486                .ok_or_else(|| {
487                    log::error!("Session not found: {}", session_id);
488                    anyhow::anyhow!("Session not found")
489                })?;
490            log::debug!("Found session for: {}", session_id);
491
492            // Convert prompt to message
493            let message = convert_prompt_to_message(params.prompt);
494            log::info!("Converted prompt to message: {} chars", message.len());
495            log::debug!("Message content: {}", message);
496
497            // Get model using the ModelSelector capability (always available for agent2)
498            // Get the selected model from the thread directly
499            let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
500
501            // Send to thread
502            log::info!("Sending message to thread with model: {:?}", model.name());
503            let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?;
504
505            // Handle response stream and forward to session.acp_thread
506            while let Some(result) = response_stream.next().await {
507                match result {
508                    Ok(event) => {
509                        log::trace!("Received completion event: {:?}", event);
510
511                        match event {
512                            AgentResponseEvent::Text(text) => {
513                                acp_thread.update(cx, |thread, cx| {
514                                    thread.push_assistant_content_block(
515                                        acp::ContentBlock::Text(acp::TextContent {
516                                            text,
517                                            annotations: None,
518                                        }),
519                                        false,
520                                        cx,
521                                    )
522                                })?;
523                            }
524                            AgentResponseEvent::Thinking(text) => {
525                                acp_thread.update(cx, |thread, cx| {
526                                    thread.push_assistant_content_block(
527                                        acp::ContentBlock::Text(acp::TextContent {
528                                            text,
529                                            annotations: None,
530                                        }),
531                                        true,
532                                        cx,
533                                    )
534                                })?;
535                            }
536                            AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
537                                tool_call,
538                                options,
539                                response,
540                            }) => {
541                                let recv = acp_thread.update(cx, |thread, cx| {
542                                    thread.request_tool_call_authorization(tool_call, options, cx)
543                                })?;
544                                cx.background_spawn(async move {
545                                    if let Some(option) = recv
546                                        .await
547                                        .context("authorization sender was dropped")
548                                        .log_err()
549                                    {
550                                        response
551                                            .send(option)
552                                            .map(|_| anyhow!("authorization receiver was dropped"))
553                                            .log_err();
554                                    }
555                                })
556                                .detach();
557                            }
558                            AgentResponseEvent::ToolCall(tool_call) => {
559                                acp_thread.update(cx, |thread, cx| {
560                                    thread.upsert_tool_call(tool_call, cx)
561                                })?;
562                            }
563                            AgentResponseEvent::ToolCallUpdate(update) => {
564                                acp_thread.update(cx, |thread, cx| {
565                                    thread.update_tool_call(update, cx)
566                                })??;
567                            }
568                            AgentResponseEvent::Stop(stop_reason) => {
569                                log::debug!("Assistant message complete: {:?}", stop_reason);
570                                return Ok(acp::PromptResponse { stop_reason });
571                            }
572                        }
573                    }
574                    Err(e) => {
575                        log::error!("Error in model response stream: {:?}", e);
576                        // TODO: Consider sending an error message to the UI
577                        break;
578                    }
579                }
580            }
581
582            log::info!("Response stream completed");
583            anyhow::Ok(acp::PromptResponse {
584                stop_reason: acp::StopReason::EndTurn,
585            })
586        })
587    }
588
589    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
590        log::info!("Cancelling on session: {}", session_id);
591        self.0.update(cx, |agent, cx| {
592            if let Some(agent) = agent.sessions.get(session_id) {
593                agent.thread.update(cx, |thread, _cx| thread.cancel());
594            }
595        });
596    }
597}
598
599/// Convert ACP content blocks to a message string
600fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
601    log::debug!("Converting {} content blocks to message", blocks.len());
602    let mut message = String::new();
603
604    for block in blocks {
605        match block {
606            acp::ContentBlock::Text(text) => {
607                log::trace!("Processing text block: {} chars", text.text.len());
608                message.push_str(&text.text);
609            }
610            acp::ContentBlock::ResourceLink(link) => {
611                log::trace!("Processing resource link: {}", link.uri);
612                message.push_str(&format!(" @{} ", link.uri));
613            }
614            acp::ContentBlock::Image(_) => {
615                log::trace!("Processing image block");
616                message.push_str(" [image] ");
617            }
618            acp::ContentBlock::Audio(_) => {
619                log::trace!("Processing audio block");
620                message.push_str(" [audio] ");
621            }
622            acp::ContentBlock::Resource(resource) => {
623                log::trace!("Processing resource block: {:?}", resource.resource);
624                message.push_str(&format!(" [resource: {:?}] ", resource.resource));
625            }
626        }
627    }
628
629    message
630}
631
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// The agent's shared project context must follow worktree additions and
    /// the creation of a rules file with no explicit refresh call.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fake_fs = FakeFs::new(cx.executor());
        fake_fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fake_fs.clone(), [], cx).await;
        let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
            .await
            .unwrap();

        // With no worktrees, the context starts out empty.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        // Adding a worktree triggers a rebuild; it has no rules file yet.
        let worktree_a = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fake_fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree_a.read(cx).entry_for_path(".rules").unwrap();
            let expected_rules_file = RulesFileContext {
                path_in_worktree: Path::new(".rules").into(),
                text: "".into(),
                project_entry_id: rules_entry.id.to_usize(),
            };
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(expected_rules_file)
                }]
            )
        });
    }

    /// Installs a test settings store and initializes the project and language
    /// settings used by these tests.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings = SettingsStore::test(cx);
            cx.set_global(settings);
            Project::init_settings(cx);
            language::init(cx);
        });
    }
}