agent.rs

  1use crate::{templates::Templates, AgentResponseEvent, Thread};
  2use crate::{FindPathTool, ReadFileTool, ThinkingTool, ToolCallAuthorization};
  3use acp_thread::ModelSelector;
  4use agent_client_protocol as acp;
  5use anyhow::{anyhow, Context as _, Result};
  6use futures::{future, StreamExt};
  7use gpui::{
  8    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  9};
 10use language_model::{LanguageModel, LanguageModelRegistry};
 11use project::{Project, ProjectItem, ProjectPath, Worktree};
 12use prompt_store::{
 13    ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
 14};
 15use std::cell::RefCell;
 16use std::collections::HashMap;
 17use std::path::Path;
 18use std::rc::Rc;
 19use std::sync::Arc;
 20use util::ResultExt;
 21
 22const RULES_FILE_NAMES: [&'static str; 9] = [
 23    ".rules",
 24    ".cursorrules",
 25    ".windsurfrules",
 26    ".clinerules",
 27    ".github/copilot-instructions.md",
 28    "CLAUDE.md",
 29    "AGENT.md",
 30    "AGENTS.md",
 31    "GEMINI.md",
 32];
 33
 34pub struct RulesLoadingError {
 35    pub message: SharedString,
 36}
 37
 38/// Holds both the internal Thread and the AcpThread for a session
 39struct Session {
 40    /// The internal thread that processes messages
 41    thread: Entity<Thread>,
 42    /// The ACP thread that handles protocol communication
 43    acp_thread: WeakEntity<acp_thread::AcpThread>,
 44    _subscription: Subscription,
 45}
 46
 47pub struct NativeAgent {
 48    /// Session ID -> Session mapping
 49    sessions: HashMap<acp::SessionId, Session>,
 50    /// Shared project context for all threads
 51    project_context: Rc<RefCell<ProjectContext>>,
 52    project_context_needs_refresh: watch::Sender<()>,
 53    _maintain_project_context: Task<Result<()>>,
 54    /// Shared templates for all threads
 55    templates: Arc<Templates>,
 56    project: Entity<Project>,
 57    prompt_store: Option<Entity<PromptStore>>,
 58    _subscriptions: Vec<Subscription>,
 59}
 60
 61impl NativeAgent {
 62    pub async fn new(
 63        project: Entity<Project>,
 64        templates: Arc<Templates>,
 65        prompt_store: Option<Entity<PromptStore>>,
 66        cx: &mut AsyncApp,
 67    ) -> Result<Entity<NativeAgent>> {
 68        log::info!("Creating new NativeAgent");
 69
 70        let project_context = cx
 71            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 72            .await;
 73
 74        cx.new(|cx| {
 75            let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
 76            if let Some(prompt_store) = prompt_store.as_ref() {
 77                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 78            }
 79
 80            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 81                watch::channel(());
 82            Self {
 83                sessions: HashMap::new(),
 84                project_context: Rc::new(RefCell::new(project_context)),
 85                project_context_needs_refresh: project_context_needs_refresh_tx,
 86                _maintain_project_context: cx.spawn(async move |this, cx| {
 87                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 88                }),
 89                templates,
 90                project,
 91                prompt_store,
 92                _subscriptions: subscriptions,
 93            }
 94        })
 95    }
 96
 97    async fn maintain_project_context(
 98        this: WeakEntity<Self>,
 99        mut needs_refresh: watch::Receiver<()>,
100        cx: &mut AsyncApp,
101    ) -> Result<()> {
102        while needs_refresh.changed().await.is_ok() {
103            let project_context = this
104                .update(cx, |this, cx| {
105                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
106                })?
107                .await;
108            this.update(cx, |this, _| this.project_context.replace(project_context))?;
109        }
110
111        Ok(())
112    }
113
114    fn build_project_context(
115        project: &Entity<Project>,
116        prompt_store: Option<&Entity<PromptStore>>,
117        cx: &mut App,
118    ) -> Task<ProjectContext> {
119        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
120        let worktree_tasks = worktrees
121            .into_iter()
122            .map(|worktree| {
123                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
124            })
125            .collect::<Vec<_>>();
126        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
127            prompt_store.read_with(cx, |prompt_store, cx| {
128                let prompts = prompt_store.default_prompt_metadata();
129                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
130                    let contents = prompt_store.load(prompt_metadata.id, cx);
131                    async move { (contents.await, prompt_metadata) }
132                });
133                cx.background_spawn(future::join_all(load_tasks))
134            })
135        } else {
136            Task::ready(vec![])
137        };
138
139        cx.spawn(async move |_cx| {
140            let (worktrees, default_user_rules) =
141                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
142
143            let worktrees = worktrees
144                .into_iter()
145                .map(|(worktree, _rules_error)| {
146                    // TODO: show error message
147                    // if let Some(rules_error) = rules_error {
148                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
149                    // }
150                    worktree
151                })
152                .collect::<Vec<_>>();
153
154            let default_user_rules = default_user_rules
155                .into_iter()
156                .flat_map(|(contents, prompt_metadata)| match contents {
157                    Ok(contents) => Some(UserRulesContext {
158                        uuid: match prompt_metadata.id {
159                            PromptId::User { uuid } => uuid,
160                            PromptId::EditWorkflow => return None,
161                        },
162                        title: prompt_metadata.title.map(|title| title.to_string()),
163                        contents,
164                    }),
165                    Err(_err) => {
166                        // TODO: show error message
167                        // this.update(cx, |_, cx| {
168                        //     cx.emit(RulesLoadingError {
169                        //         message: format!("{err:?}").into(),
170                        //     });
171                        // })
172                        // .ok();
173                        None
174                    }
175                })
176                .collect::<Vec<_>>();
177
178            ProjectContext::new(worktrees, default_user_rules)
179        })
180    }
181
182    fn load_worktree_info_for_system_prompt(
183        worktree: Entity<Worktree>,
184        project: Entity<Project>,
185        cx: &mut App,
186    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
187        let tree = worktree.read(cx);
188        let root_name = tree.root_name().into();
189        let abs_path = tree.abs_path();
190
191        let mut context = WorktreeContext {
192            root_name,
193            abs_path,
194            rules_file: None,
195        };
196
197        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
198        let Some(rules_task) = rules_task else {
199            return Task::ready((context, None));
200        };
201
202        cx.spawn(async move |_| {
203            let (rules_file, rules_file_error) = match rules_task.await {
204                Ok(rules_file) => (Some(rules_file), None),
205                Err(err) => (
206                    None,
207                    Some(RulesLoadingError {
208                        message: format!("{err}").into(),
209                    }),
210                ),
211            };
212            context.rules_file = rules_file;
213            (context, rules_file_error)
214        })
215    }
216
217    fn load_worktree_rules_file(
218        worktree: Entity<Worktree>,
219        project: Entity<Project>,
220        cx: &mut App,
221    ) -> Option<Task<Result<RulesFileContext>>> {
222        let worktree = worktree.read(cx);
223        let worktree_id = worktree.id();
224        let selected_rules_file = RULES_FILE_NAMES
225            .into_iter()
226            .filter_map(|name| {
227                worktree
228                    .entry_for_path(name)
229                    .filter(|entry| entry.is_file())
230                    .map(|entry| entry.path.clone())
231            })
232            .next();
233
234        // Note that Cline supports `.clinerules` being a directory, but that is not currently
235        // supported. This doesn't seem to occur often in GitHub repositories.
236        selected_rules_file.map(|path_in_worktree| {
237            let project_path = ProjectPath {
238                worktree_id,
239                path: path_in_worktree.clone(),
240            };
241            let buffer_task =
242                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
243            let rope_task = cx.spawn(async move |cx| {
244                buffer_task.await?.read_with(cx, |buffer, cx| {
245                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
246                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
247                })?
248            });
249            // Build a string from the rope on a background thread.
250            cx.background_spawn(async move {
251                let (project_entry_id, rope) = rope_task.await?;
252                anyhow::Ok(RulesFileContext {
253                    path_in_worktree,
254                    text: rope.to_string().trim().to_string(),
255                    project_entry_id: project_entry_id.to_usize(),
256                })
257            })
258        })
259    }
260
261    fn handle_project_event(
262        &mut self,
263        _project: Entity<Project>,
264        event: &project::Event,
265        _cx: &mut Context<Self>,
266    ) {
267        match event {
268            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
269                self.project_context_needs_refresh.send(()).ok();
270            }
271            project::Event::WorktreeUpdatedEntries(_, items) => {
272                if items.iter().any(|(path, _, _)| {
273                    RULES_FILE_NAMES
274                        .iter()
275                        .any(|name| path.as_ref() == Path::new(name))
276                }) {
277                    self.project_context_needs_refresh.send(()).ok();
278                }
279            }
280            _ => {}
281        }
282    }
283
284    fn handle_prompts_updated_event(
285        &mut self,
286        _prompt_store: Entity<PromptStore>,
287        _event: &prompt_store::PromptsUpdatedEvent,
288        _cx: &mut Context<Self>,
289    ) {
290        self.project_context_needs_refresh.send(()).ok();
291    }
292}
293
294/// Wrapper struct that implements the AgentConnection trait
295#[derive(Clone)]
296pub struct NativeAgentConnection(pub Entity<NativeAgent>);
297
298impl ModelSelector for NativeAgentConnection {
299    fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
300        log::debug!("NativeAgentConnection::list_models called");
301        cx.spawn(async move |cx| {
302            cx.update(|cx| {
303                let registry = LanguageModelRegistry::read_global(cx);
304                let models = registry.available_models(cx).collect::<Vec<_>>();
305                log::info!("Found {} available models", models.len());
306                if models.is_empty() {
307                    Err(anyhow::anyhow!("No models available"))
308                } else {
309                    Ok(models)
310                }
311            })?
312        })
313    }
314
315    fn select_model(
316        &self,
317        session_id: acp::SessionId,
318        model: Arc<dyn LanguageModel>,
319        cx: &mut AsyncApp,
320    ) -> Task<Result<()>> {
321        log::info!(
322            "Setting model for session {}: {:?}",
323            session_id,
324            model.name()
325        );
326        let agent = self.0.clone();
327
328        cx.spawn(async move |cx| {
329            agent.update(cx, |agent, cx| {
330                if let Some(session) = agent.sessions.get(&session_id) {
331                    session.thread.update(cx, |thread, _cx| {
332                        thread.selected_model = model;
333                    });
334                    Ok(())
335                } else {
336                    Err(anyhow!("Session not found"))
337                }
338            })?
339        })
340    }
341
342    fn selected_model(
343        &self,
344        session_id: &acp::SessionId,
345        cx: &mut AsyncApp,
346    ) -> Task<Result<Arc<dyn LanguageModel>>> {
347        let agent = self.0.clone();
348        let session_id = session_id.clone();
349        cx.spawn(async move |cx| {
350            let thread = agent
351                .read_with(cx, |agent, _| {
352                    agent
353                        .sessions
354                        .get(&session_id)
355                        .map(|session| session.thread.clone())
356                })?
357                .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
358            let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
359            Ok(selected)
360        })
361    }
362}
363
364impl acp_thread::AgentConnection for NativeAgentConnection {
365    fn new_thread(
366        self: Rc<Self>,
367        project: Entity<Project>,
368        cwd: &Path,
369        cx: &mut AsyncApp,
370    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
371        let agent = self.0.clone();
372        log::info!("Creating new thread for project at: {:?}", cwd);
373
374        cx.spawn(async move |cx| {
375            log::debug!("Starting thread creation in async context");
376
377            // Generate session ID
378            let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
379            log::info!("Created session with ID: {}", session_id);
380
381            // Create AcpThread
382            let acp_thread = cx.update(|cx| {
383                cx.new(|cx| {
384                    acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
385                })
386            })?;
387            let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
388
389            // Create Thread
390            let thread = agent.update(
391                cx,
392                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
393                    // Fetch default model from registry settings
394                    let registry = LanguageModelRegistry::read_global(cx);
395
396                    // Log available models for debugging
397                    let available_count = registry.available_models(cx).count();
398                    log::debug!("Total available models: {}", available_count);
399
400                    let default_model = registry
401                        .default_model()
402                        .map(|configured| {
403                            log::info!(
404                                "Using configured default model: {:?} from provider: {:?}",
405                                configured.model.name(),
406                                configured.provider.name()
407                            );
408                            configured.model
409                        })
410                        .ok_or_else(|| {
411                            log::warn!("No default model configured in settings");
412                            anyhow!("No default model configured. Please configure a default model in settings.")
413                        })?;
414
415                    let thread = cx.new(|_| {
416                        let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
417                        thread.add_tool(ThinkingTool);
418                        thread.add_tool(FindPathTool::new(project.clone()));
419                        thread.add_tool(ReadFileTool::new(project.clone(), action_log));
420                        thread
421                    });
422
423                    Ok(thread)
424                },
425            )??;
426
427            // Store the session
428            agent.update(cx, |agent, cx| {
429                agent.sessions.insert(
430                    session_id,
431                    Session {
432                        thread,
433                        acp_thread: acp_thread.downgrade(),
434                        _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
435                            this.sessions.remove(acp_thread.session_id());
436                        })
437                    },
438                );
439            })?;
440
441            Ok(acp_thread)
442        })
443    }
444
445    fn auth_methods(&self) -> &[acp::AuthMethod] {
446        &[] // No auth for in-process
447    }
448
449    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
450        Task::ready(Ok(()))
451    }
452
453    fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
454        Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
455    }
456
457    fn prompt(
458        &self,
459        params: acp::PromptRequest,
460        cx: &mut App,
461    ) -> Task<Result<acp::PromptResponse>> {
462        let session_id = params.session_id.clone();
463        let agent = self.0.clone();
464        log::info!("Received prompt request for session: {}", session_id);
465        log::debug!("Prompt blocks count: {}", params.prompt.len());
466
467        cx.spawn(async move |cx| {
468            // Get session
469            let (thread, acp_thread) = agent
470                .update(cx, |agent, _| {
471                    agent
472                        .sessions
473                        .get_mut(&session_id)
474                        .map(|s| (s.thread.clone(), s.acp_thread.clone()))
475                })?
476                .ok_or_else(|| {
477                    log::error!("Session not found: {}", session_id);
478                    anyhow::anyhow!("Session not found")
479                })?;
480            log::debug!("Found session for: {}", session_id);
481
482            // Convert prompt to message
483            let message = convert_prompt_to_message(params.prompt);
484            log::info!("Converted prompt to message: {} chars", message.len());
485            log::debug!("Message content: {}", message);
486
487            // Get model using the ModelSelector capability (always available for agent2)
488            // Get the selected model from the thread directly
489            let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
490
491            // Send to thread
492            log::info!("Sending message to thread with model: {:?}", model.name());
493            let mut response_stream =
494                thread.update(cx, |thread, cx| thread.send(model, message, cx))?;
495
496            // Handle response stream and forward to session.acp_thread
497            while let Some(result) = response_stream.next().await {
498                match result {
499                    Ok(event) => {
500                        log::trace!("Received completion event: {:?}", event);
501
502                        match event {
503                            AgentResponseEvent::Text(text) => {
504                                acp_thread.update(cx, |thread, cx| {
505                                    thread.handle_session_update(
506                                        acp::SessionUpdate::AgentMessageChunk {
507                                            content: acp::ContentBlock::Text(acp::TextContent {
508                                                text,
509                                                annotations: None,
510                                            }),
511                                        },
512                                        cx,
513                                    )
514                                })??;
515                            }
516                            AgentResponseEvent::Thinking(text) => {
517                                acp_thread.update(cx, |thread, cx| {
518                                    thread.handle_session_update(
519                                        acp::SessionUpdate::AgentThoughtChunk {
520                                            content: acp::ContentBlock::Text(acp::TextContent {
521                                                text,
522                                                annotations: None,
523                                            }),
524                                        },
525                                        cx,
526                                    )
527                                })??;
528                            }
529                            AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
530                                tool_call,
531                                options,
532                                response,
533                            }) => {
534                                let recv = acp_thread.update(cx, |thread, cx| {
535                                    thread.request_tool_call_authorization(tool_call, options, cx)
536                                })?;
537                                cx.background_spawn(async move {
538                                    if let Some(option) = recv
539                                        .await
540                                        .context("authorization sender was dropped")
541                                        .log_err()
542                                    {
543                                        response
544                                            .send(option)
545                                            .map(|_| anyhow!("authorization receiver was dropped"))
546                                            .log_err();
547                                    }
548                                })
549                                .detach();
550                            }
551                            AgentResponseEvent::ToolCall(tool_call) => {
552                                acp_thread.update(cx, |thread, cx| {
553                                    thread.handle_session_update(
554                                        acp::SessionUpdate::ToolCall(tool_call),
555                                        cx,
556                                    )
557                                })??;
558                            }
559                            AgentResponseEvent::ToolCallUpdate(tool_call_update) => {
560                                acp_thread.update(cx, |thread, cx| {
561                                    thread.handle_session_update(
562                                        acp::SessionUpdate::ToolCallUpdate(tool_call_update),
563                                        cx,
564                                    )
565                                })??;
566                            }
567                            AgentResponseEvent::Stop(stop_reason) => {
568                                log::debug!("Assistant message complete: {:?}", stop_reason);
569                                return Ok(acp::PromptResponse { stop_reason });
570                            }
571                        }
572                    }
573                    Err(e) => {
574                        log::error!("Error in model response stream: {:?}", e);
575                        // TODO: Consider sending an error message to the UI
576                        break;
577                    }
578                }
579            }
580
581            log::info!("Response stream completed");
582            anyhow::Ok(acp::PromptResponse {
583                stop_reason: acp::StopReason::EndTurn,
584            })
585        })
586    }
587
588    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
589        log::info!("Cancelling on session: {}", session_id);
590        self.0.update(cx, |agent, cx| {
591            if let Some(agent) = agent.sessions.get(session_id) {
592                agent.thread.update(cx, |thread, _cx| thread.cancel());
593            }
594        });
595    }
596}
597
598/// Convert ACP content blocks to a message string
599fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
600    log::debug!("Converting {} content blocks to message", blocks.len());
601    let mut message = String::new();
602
603    for block in blocks {
604        match block {
605            acp::ContentBlock::Text(text) => {
606                log::trace!("Processing text block: {} chars", text.text.len());
607                message.push_str(&text.text);
608            }
609            acp::ContentBlock::ResourceLink(link) => {
610                log::trace!("Processing resource link: {}", link.uri);
611                message.push_str(&format!(" @{} ", link.uri));
612            }
613            acp::ContentBlock::Image(_) => {
614                log::trace!("Processing image block");
615                message.push_str(" [image] ");
616            }
617            acp::ContentBlock::Audio(_) => {
618                log::trace!("Processing audio block");
619                message.push_str(" [audio] ");
620            }
621            acp::ContentBlock::Resource(resource) => {
622                log::trace!("Processing resource block: {:?}", resource.resource);
623                message.push_str(&format!(" [resource: {:?}] ", resource.resource));
624            }
625        }
626    }
627
628    message
629}
630
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// The shared project context follows worktree additions and picks up a `.rules` file
    /// once it appears on disk.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
            .await
            .unwrap();

        // No worktrees yet, so the context is empty.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        // Expected context for the `/a` worktree, given its rules file (if any).
        let expected_worktree = |rules_file| WorktreeContext {
            root_name: "a".into(),
            abs_path: Path::new("/a").into(),
            rules_file,
        };

        // Adding a visible worktree triggers a refresh.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![expected_worktree(None)]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            let rules_file = RulesFileContext {
                path_in_worktree: Path::new(".rules").into(),
                text: "".into(),
                project_entry_id: rules_entry.id.to_usize(),
            };
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![expected_worktree(Some(rules_file))]
            )
        });
    }

    /// Installs the global settings store plus project and language settings.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            cx.set_global(SettingsStore::test(cx));
            Project::init_settings(cx);
            language::init(cx);
        });
    }
}