agent.rs

  1use crate::{AgentResponseEvent, Thread, templates::Templates};
  2use crate::{
  3    EditFileTool, FindPathTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization,
  4};
  5use acp_thread::ModelSelector;
  6use agent_client_protocol as acp;
  7use anyhow::{Context as _, Result, anyhow};
  8use futures::{StreamExt, future};
  9use gpui::{
 10    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
 11};
 12use language_model::{LanguageModel, LanguageModelRegistry};
 13use project::{Project, ProjectItem, ProjectPath, Worktree};
 14use prompt_store::{
 15    ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
 16};
 17use std::cell::RefCell;
 18use std::collections::HashMap;
 19use std::path::Path;
 20use std::rc::Rc;
 21use std::sync::Arc;
 22use util::ResultExt;
 23
 24const RULES_FILE_NAMES: [&'static str; 9] = [
 25    ".rules",
 26    ".cursorrules",
 27    ".windsurfrules",
 28    ".clinerules",
 29    ".github/copilot-instructions.md",
 30    "CLAUDE.md",
 31    "AGENT.md",
 32    "AGENTS.md",
 33    "GEMINI.md",
 34];
 35
 36pub struct RulesLoadingError {
 37    pub message: SharedString,
 38}
 39
 40/// Holds both the internal Thread and the AcpThread for a session
 41struct Session {
 42    /// The internal thread that processes messages
 43    thread: Entity<Thread>,
 44    /// The ACP thread that handles protocol communication
 45    acp_thread: WeakEntity<acp_thread::AcpThread>,
 46    _subscription: Subscription,
 47}
 48
 49pub struct NativeAgent {
 50    /// Session ID -> Session mapping
 51    sessions: HashMap<acp::SessionId, Session>,
 52    /// Shared project context for all threads
 53    project_context: Rc<RefCell<ProjectContext>>,
 54    project_context_needs_refresh: watch::Sender<()>,
 55    _maintain_project_context: Task<Result<()>>,
 56    /// Shared templates for all threads
 57    templates: Arc<Templates>,
 58    project: Entity<Project>,
 59    prompt_store: Option<Entity<PromptStore>>,
 60    _subscriptions: Vec<Subscription>,
 61}
 62
 63impl NativeAgent {
 64    pub async fn new(
 65        project: Entity<Project>,
 66        templates: Arc<Templates>,
 67        prompt_store: Option<Entity<PromptStore>>,
 68        cx: &mut AsyncApp,
 69    ) -> Result<Entity<NativeAgent>> {
 70        log::info!("Creating new NativeAgent");
 71
 72        let project_context = cx
 73            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 74            .await;
 75
 76        cx.new(|cx| {
 77            let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
 78            if let Some(prompt_store) = prompt_store.as_ref() {
 79                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 80            }
 81
 82            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 83                watch::channel(());
 84            Self {
 85                sessions: HashMap::new(),
 86                project_context: Rc::new(RefCell::new(project_context)),
 87                project_context_needs_refresh: project_context_needs_refresh_tx,
 88                _maintain_project_context: cx.spawn(async move |this, cx| {
 89                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 90                }),
 91                templates,
 92                project,
 93                prompt_store,
 94                _subscriptions: subscriptions,
 95            }
 96        })
 97    }
 98
 99    async fn maintain_project_context(
100        this: WeakEntity<Self>,
101        mut needs_refresh: watch::Receiver<()>,
102        cx: &mut AsyncApp,
103    ) -> Result<()> {
104        while needs_refresh.changed().await.is_ok() {
105            let project_context = this
106                .update(cx, |this, cx| {
107                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
108                })?
109                .await;
110            this.update(cx, |this, _| this.project_context.replace(project_context))?;
111        }
112
113        Ok(())
114    }
115
116    fn build_project_context(
117        project: &Entity<Project>,
118        prompt_store: Option<&Entity<PromptStore>>,
119        cx: &mut App,
120    ) -> Task<ProjectContext> {
121        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
122        let worktree_tasks = worktrees
123            .into_iter()
124            .map(|worktree| {
125                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
126            })
127            .collect::<Vec<_>>();
128        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
129            prompt_store.read_with(cx, |prompt_store, cx| {
130                let prompts = prompt_store.default_prompt_metadata();
131                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
132                    let contents = prompt_store.load(prompt_metadata.id, cx);
133                    async move { (contents.await, prompt_metadata) }
134                });
135                cx.background_spawn(future::join_all(load_tasks))
136            })
137        } else {
138            Task::ready(vec![])
139        };
140
141        cx.spawn(async move |_cx| {
142            let (worktrees, default_user_rules) =
143                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
144
145            let worktrees = worktrees
146                .into_iter()
147                .map(|(worktree, _rules_error)| {
148                    // TODO: show error message
149                    // if let Some(rules_error) = rules_error {
150                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
151                    // }
152                    worktree
153                })
154                .collect::<Vec<_>>();
155
156            let default_user_rules = default_user_rules
157                .into_iter()
158                .flat_map(|(contents, prompt_metadata)| match contents {
159                    Ok(contents) => Some(UserRulesContext {
160                        uuid: match prompt_metadata.id {
161                            PromptId::User { uuid } => uuid,
162                            PromptId::EditWorkflow => return None,
163                        },
164                        title: prompt_metadata.title.map(|title| title.to_string()),
165                        contents,
166                    }),
167                    Err(_err) => {
168                        // TODO: show error message
169                        // this.update(cx, |_, cx| {
170                        //     cx.emit(RulesLoadingError {
171                        //         message: format!("{err:?}").into(),
172                        //     });
173                        // })
174                        // .ok();
175                        None
176                    }
177                })
178                .collect::<Vec<_>>();
179
180            ProjectContext::new(worktrees, default_user_rules)
181        })
182    }
183
184    fn load_worktree_info_for_system_prompt(
185        worktree: Entity<Worktree>,
186        project: Entity<Project>,
187        cx: &mut App,
188    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
189        let tree = worktree.read(cx);
190        let root_name = tree.root_name().into();
191        let abs_path = tree.abs_path();
192
193        let mut context = WorktreeContext {
194            root_name,
195            abs_path,
196            rules_file: None,
197        };
198
199        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
200        let Some(rules_task) = rules_task else {
201            return Task::ready((context, None));
202        };
203
204        cx.spawn(async move |_| {
205            let (rules_file, rules_file_error) = match rules_task.await {
206                Ok(rules_file) => (Some(rules_file), None),
207                Err(err) => (
208                    None,
209                    Some(RulesLoadingError {
210                        message: format!("{err}").into(),
211                    }),
212                ),
213            };
214            context.rules_file = rules_file;
215            (context, rules_file_error)
216        })
217    }
218
219    fn load_worktree_rules_file(
220        worktree: Entity<Worktree>,
221        project: Entity<Project>,
222        cx: &mut App,
223    ) -> Option<Task<Result<RulesFileContext>>> {
224        let worktree = worktree.read(cx);
225        let worktree_id = worktree.id();
226        let selected_rules_file = RULES_FILE_NAMES
227            .into_iter()
228            .filter_map(|name| {
229                worktree
230                    .entry_for_path(name)
231                    .filter(|entry| entry.is_file())
232                    .map(|entry| entry.path.clone())
233            })
234            .next();
235
236        // Note that Cline supports `.clinerules` being a directory, but that is not currently
237        // supported. This doesn't seem to occur often in GitHub repositories.
238        selected_rules_file.map(|path_in_worktree| {
239            let project_path = ProjectPath {
240                worktree_id,
241                path: path_in_worktree.clone(),
242            };
243            let buffer_task =
244                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
245            let rope_task = cx.spawn(async move |cx| {
246                buffer_task.await?.read_with(cx, |buffer, cx| {
247                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
248                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
249                })?
250            });
251            // Build a string from the rope on a background thread.
252            cx.background_spawn(async move {
253                let (project_entry_id, rope) = rope_task.await?;
254                anyhow::Ok(RulesFileContext {
255                    path_in_worktree,
256                    text: rope.to_string().trim().to_string(),
257                    project_entry_id: project_entry_id.to_usize(),
258                })
259            })
260        })
261    }
262
263    fn handle_project_event(
264        &mut self,
265        _project: Entity<Project>,
266        event: &project::Event,
267        _cx: &mut Context<Self>,
268    ) {
269        match event {
270            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
271                self.project_context_needs_refresh.send(()).ok();
272            }
273            project::Event::WorktreeUpdatedEntries(_, items) => {
274                if items.iter().any(|(path, _, _)| {
275                    RULES_FILE_NAMES
276                        .iter()
277                        .any(|name| path.as_ref() == Path::new(name))
278                }) {
279                    self.project_context_needs_refresh.send(()).ok();
280                }
281            }
282            _ => {}
283        }
284    }
285
286    fn handle_prompts_updated_event(
287        &mut self,
288        _prompt_store: Entity<PromptStore>,
289        _event: &prompt_store::PromptsUpdatedEvent,
290        _cx: &mut Context<Self>,
291    ) {
292        self.project_context_needs_refresh.send(()).ok();
293    }
294}
295
296/// Wrapper struct that implements the AgentConnection trait
297#[derive(Clone)]
298pub struct NativeAgentConnection(pub Entity<NativeAgent>);
299
300impl ModelSelector for NativeAgentConnection {
301    fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
302        log::debug!("NativeAgentConnection::list_models called");
303        cx.spawn(async move |cx| {
304            cx.update(|cx| {
305                let registry = LanguageModelRegistry::read_global(cx);
306                let models = registry.available_models(cx).collect::<Vec<_>>();
307                log::info!("Found {} available models", models.len());
308                if models.is_empty() {
309                    Err(anyhow::anyhow!("No models available"))
310                } else {
311                    Ok(models)
312                }
313            })?
314        })
315    }
316
317    fn select_model(
318        &self,
319        session_id: acp::SessionId,
320        model: Arc<dyn LanguageModel>,
321        cx: &mut AsyncApp,
322    ) -> Task<Result<()>> {
323        log::info!(
324            "Setting model for session {}: {:?}",
325            session_id,
326            model.name()
327        );
328        let agent = self.0.clone();
329
330        cx.spawn(async move |cx| {
331            agent.update(cx, |agent, cx| {
332                if let Some(session) = agent.sessions.get(&session_id) {
333                    session.thread.update(cx, |thread, _cx| {
334                        thread.selected_model = model;
335                    });
336                    Ok(())
337                } else {
338                    Err(anyhow!("Session not found"))
339                }
340            })?
341        })
342    }
343
344    fn selected_model(
345        &self,
346        session_id: &acp::SessionId,
347        cx: &mut AsyncApp,
348    ) -> Task<Result<Arc<dyn LanguageModel>>> {
349        let agent = self.0.clone();
350        let session_id = session_id.clone();
351        cx.spawn(async move |cx| {
352            let thread = agent
353                .read_with(cx, |agent, _| {
354                    agent
355                        .sessions
356                        .get(&session_id)
357                        .map(|session| session.thread.clone())
358                })?
359                .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
360            let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
361            Ok(selected)
362        })
363    }
364}
365
366impl acp_thread::AgentConnection for NativeAgentConnection {
367    fn new_thread(
368        self: Rc<Self>,
369        project: Entity<Project>,
370        cwd: &Path,
371        cx: &mut AsyncApp,
372    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
373        let agent = self.0.clone();
374        log::info!("Creating new thread for project at: {:?}", cwd);
375
376        cx.spawn(async move |cx| {
377            log::debug!("Starting thread creation in async context");
378
379            // Generate session ID
380            let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
381            log::info!("Created session with ID: {}", session_id);
382
383            // Create AcpThread
384            let acp_thread = cx.update(|cx| {
385                cx.new(|cx| {
386                    acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
387                })
388            })?;
389            let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
390
391            // Create Thread
392            let thread = agent.update(
393                cx,
394                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
395                    // Fetch default model from registry settings
396                    let registry = LanguageModelRegistry::read_global(cx);
397
398                    // Log available models for debugging
399                    let available_count = registry.available_models(cx).count();
400                    log::debug!("Total available models: {}", available_count);
401
402                    let default_model = registry
403                        .default_model()
404                        .map(|configured| {
405                            log::info!(
406                                "Using configured default model: {:?} from provider: {:?}",
407                                configured.model.name(),
408                                configured.provider.name()
409                            );
410                            configured.model
411                        })
412                        .ok_or_else(|| {
413                            log::warn!("No default model configured in settings");
414                            anyhow!("No default model configured. Please configure a default model in settings.")
415                        })?;
416
417                    let thread = cx.new(|cx| {
418                        let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
419                        thread.add_tool(ThinkingTool);
420                        thread.add_tool(FindPathTool::new(project.clone()));
421                        thread.add_tool(ReadFileTool::new(project.clone(), action_log));
422                        thread.add_tool(EditFileTool::new(cx.entity()));
423                        thread.add_tool(TerminalTool::new(project.clone(), cx));
424                        thread
425                    });
426
427                    Ok(thread)
428                },
429            )??;
430
431            // Store the session
432            agent.update(cx, |agent, cx| {
433                agent.sessions.insert(
434                    session_id,
435                    Session {
436                        thread,
437                        acp_thread: acp_thread.downgrade(),
438                        _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
439                            this.sessions.remove(acp_thread.session_id());
440                        })
441                    },
442                );
443            })?;
444
445            Ok(acp_thread)
446        })
447    }
448
449    fn auth_methods(&self) -> &[acp::AuthMethod] {
450        &[] // No auth for in-process
451    }
452
453    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
454        Task::ready(Ok(()))
455    }
456
457    fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
458        Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
459    }
460
461    fn prompt(
462        &self,
463        params: acp::PromptRequest,
464        cx: &mut App,
465    ) -> Task<Result<acp::PromptResponse>> {
466        let session_id = params.session_id.clone();
467        let agent = self.0.clone();
468        log::info!("Received prompt request for session: {}", session_id);
469        log::debug!("Prompt blocks count: {}", params.prompt.len());
470
471        cx.spawn(async move |cx| {
472            // Get session
473            let (thread, acp_thread) = agent
474                .update(cx, |agent, _| {
475                    agent
476                        .sessions
477                        .get_mut(&session_id)
478                        .map(|s| (s.thread.clone(), s.acp_thread.clone()))
479                })?
480                .ok_or_else(|| {
481                    log::error!("Session not found: {}", session_id);
482                    anyhow::anyhow!("Session not found")
483                })?;
484            log::debug!("Found session for: {}", session_id);
485
486            // Convert prompt to message
487            let message = convert_prompt_to_message(params.prompt);
488            log::info!("Converted prompt to message: {} chars", message.len());
489            log::debug!("Message content: {}", message);
490
491            // Get model using the ModelSelector capability (always available for agent2)
492            // Get the selected model from the thread directly
493            let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
494
495            // Send to thread
496            log::info!("Sending message to thread with model: {:?}", model.name());
497            let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?;
498
499            // Handle response stream and forward to session.acp_thread
500            while let Some(result) = response_stream.next().await {
501                match result {
502                    Ok(event) => {
503                        log::trace!("Received completion event: {:?}", event);
504
505                        match event {
506                            AgentResponseEvent::Text(text) => {
507                                acp_thread.update(cx, |thread, cx| {
508                                    thread.push_assistant_content_block(
509                                        acp::ContentBlock::Text(acp::TextContent {
510                                            text,
511                                            annotations: None,
512                                        }),
513                                        false,
514                                        cx,
515                                    )
516                                })?;
517                            }
518                            AgentResponseEvent::Thinking(text) => {
519                                acp_thread.update(cx, |thread, cx| {
520                                    thread.push_assistant_content_block(
521                                        acp::ContentBlock::Text(acp::TextContent {
522                                            text,
523                                            annotations: None,
524                                        }),
525                                        true,
526                                        cx,
527                                    )
528                                })?;
529                            }
530                            AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
531                                tool_call,
532                                options,
533                                response,
534                            }) => {
535                                let recv = acp_thread.update(cx, |thread, cx| {
536                                    thread.request_tool_call_authorization(tool_call, options, cx)
537                                })?;
538                                cx.background_spawn(async move {
539                                    if let Some(option) = recv
540                                        .await
541                                        .context("authorization sender was dropped")
542                                        .log_err()
543                                    {
544                                        response
545                                            .send(option)
546                                            .map(|_| anyhow!("authorization receiver was dropped"))
547                                            .log_err();
548                                    }
549                                })
550                                .detach();
551                            }
552                            AgentResponseEvent::ToolCall(tool_call) => {
553                                acp_thread.update(cx, |thread, cx| {
554                                    thread.upsert_tool_call(tool_call, cx)
555                                })?;
556                            }
557                            AgentResponseEvent::ToolCallUpdate(update) => {
558                                acp_thread.update(cx, |thread, cx| {
559                                    thread.update_tool_call(update, cx)
560                                })??;
561                            }
562                            AgentResponseEvent::Stop(stop_reason) => {
563                                log::debug!("Assistant message complete: {:?}", stop_reason);
564                                return Ok(acp::PromptResponse { stop_reason });
565                            }
566                        }
567                    }
568                    Err(e) => {
569                        log::error!("Error in model response stream: {:?}", e);
570                        // TODO: Consider sending an error message to the UI
571                        break;
572                    }
573                }
574            }
575
576            log::info!("Response stream completed");
577            anyhow::Ok(acp::PromptResponse {
578                stop_reason: acp::StopReason::EndTurn,
579            })
580        })
581    }
582
583    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
584        log::info!("Cancelling on session: {}", session_id);
585        self.0.update(cx, |agent, cx| {
586            if let Some(agent) = agent.sessions.get(session_id) {
587                agent.thread.update(cx, |thread, _cx| thread.cancel());
588            }
589        });
590    }
591}
592
593/// Convert ACP content blocks to a message string
594fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
595    log::debug!("Converting {} content blocks to message", blocks.len());
596    let mut message = String::new();
597
598    for block in blocks {
599        match block {
600            acp::ContentBlock::Text(text) => {
601                log::trace!("Processing text block: {} chars", text.text.len());
602                message.push_str(&text.text);
603            }
604            acp::ContentBlock::ResourceLink(link) => {
605                log::trace!("Processing resource link: {}", link.uri);
606                message.push_str(&format!(" @{} ", link.uri));
607            }
608            acp::ContentBlock::Image(_) => {
609                log::trace!("Processing image block");
610                message.push_str(" [image] ");
611            }
612            acp::ContentBlock::Audio(_) => {
613                log::trace!("Processing audio block");
614                message.push_str(" [audio] ");
615            }
616            acp::ContentBlock::Resource(resource) => {
617                log::trace!("Processing resource block: {:?}", resource.resource);
618                message.push_str(&format!(" [resource: {:?}] ", resource.resource));
619            }
620        }
621    }
622
623    message
624}
625
626#[cfg(test)]
627mod tests {
628    use super::*;
629    use fs::FakeFs;
630    use gpui::TestAppContext;
631    use serde_json::json;
632    use settings::SettingsStore;
633
634    #[gpui::test]
635    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
636        init_test(cx);
637        let fs = FakeFs::new(cx.executor());
638        fs.insert_tree(
639            "/",
640            json!({
641                "a": {}
642            }),
643        )
644        .await;
645        let project = Project::test(fs.clone(), [], cx).await;
646        let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
647            .await
648            .unwrap();
649        agent.read_with(cx, |agent, _| {
650            assert_eq!(agent.project_context.borrow().worktrees, vec![])
651        });
652
653        let worktree = project
654            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
655            .await
656            .unwrap();
657        cx.run_until_parked();
658        agent.read_with(cx, |agent, _| {
659            assert_eq!(
660                agent.project_context.borrow().worktrees,
661                vec![WorktreeContext {
662                    root_name: "a".into(),
663                    abs_path: Path::new("/a").into(),
664                    rules_file: None
665                }]
666            )
667        });
668
669        // Creating `/a/.rules` updates the project context.
670        fs.insert_file("/a/.rules", Vec::new()).await;
671        cx.run_until_parked();
672        agent.read_with(cx, |agent, cx| {
673            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
674            assert_eq!(
675                agent.project_context.borrow().worktrees,
676                vec![WorktreeContext {
677                    root_name: "a".into(),
678                    abs_path: Path::new("/a").into(),
679                    rules_file: Some(RulesFileContext {
680                        path_in_worktree: Path::new(".rules").into(),
681                        text: "".into(),
682                        project_entry_id: rules_entry.id.to_usize()
683                    })
684                }]
685            )
686        });
687    }
688
689    fn init_test(cx: &mut TestAppContext) {
690        env_logger::try_init().ok();
691        cx.update(|cx| {
692            let settings_store = SettingsStore::test(cx);
693            cx.set_global(settings_store);
694            Project::init_settings(cx);
695            language::init(cx);
696        });
697    }
698}