agent.rs

  1use crate::{AgentResponseEvent, Thread, templates::Templates};
  2use crate::{
  3    CopyPathTool, CreateDirectoryTool, EditFileTool, FindPathTool, GrepTool, ListDirectoryTool,
  4    MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool,
  5    ToolCallAuthorization, WebSearchTool,
  6};
  7use acp_thread::ModelSelector;
  8use agent_client_protocol as acp;
  9use anyhow::{Context as _, Result, anyhow};
 10use futures::{StreamExt, future};
 11use gpui::{
 12    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
 13};
 14use language_model::{LanguageModel, LanguageModelRegistry};
 15use project::{Project, ProjectItem, ProjectPath, Worktree};
 16use prompt_store::{
 17    ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
 18};
 19use std::cell::RefCell;
 20use std::collections::HashMap;
 21use std::path::Path;
 22use std::rc::Rc;
 23use std::sync::Arc;
 24use util::ResultExt;
 25
 26const RULES_FILE_NAMES: [&'static str; 9] = [
 27    ".rules",
 28    ".cursorrules",
 29    ".windsurfrules",
 30    ".clinerules",
 31    ".github/copilot-instructions.md",
 32    "CLAUDE.md",
 33    "AGENT.md",
 34    "AGENTS.md",
 35    "GEMINI.md",
 36];
 37
 38pub struct RulesLoadingError {
 39    pub message: SharedString,
 40}
 41
 42/// Holds both the internal Thread and the AcpThread for a session
 43struct Session {
 44    /// The internal thread that processes messages
 45    thread: Entity<Thread>,
 46    /// The ACP thread that handles protocol communication
 47    acp_thread: WeakEntity<acp_thread::AcpThread>,
 48    _subscription: Subscription,
 49}
 50
 51pub struct NativeAgent {
 52    /// Session ID -> Session mapping
 53    sessions: HashMap<acp::SessionId, Session>,
 54    /// Shared project context for all threads
 55    project_context: Rc<RefCell<ProjectContext>>,
 56    project_context_needs_refresh: watch::Sender<()>,
 57    _maintain_project_context: Task<Result<()>>,
 58    /// Shared templates for all threads
 59    templates: Arc<Templates>,
 60    project: Entity<Project>,
 61    prompt_store: Option<Entity<PromptStore>>,
 62    _subscriptions: Vec<Subscription>,
 63}
 64
 65impl NativeAgent {
 66    pub async fn new(
 67        project: Entity<Project>,
 68        templates: Arc<Templates>,
 69        prompt_store: Option<Entity<PromptStore>>,
 70        cx: &mut AsyncApp,
 71    ) -> Result<Entity<NativeAgent>> {
 72        log::info!("Creating new NativeAgent");
 73
 74        let project_context = cx
 75            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 76            .await;
 77
 78        cx.new(|cx| {
 79            let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
 80            if let Some(prompt_store) = prompt_store.as_ref() {
 81                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 82            }
 83
 84            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 85                watch::channel(());
 86            Self {
 87                sessions: HashMap::new(),
 88                project_context: Rc::new(RefCell::new(project_context)),
 89                project_context_needs_refresh: project_context_needs_refresh_tx,
 90                _maintain_project_context: cx.spawn(async move |this, cx| {
 91                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 92                }),
 93                templates,
 94                project,
 95                prompt_store,
 96                _subscriptions: subscriptions,
 97            }
 98        })
 99    }
100
101    async fn maintain_project_context(
102        this: WeakEntity<Self>,
103        mut needs_refresh: watch::Receiver<()>,
104        cx: &mut AsyncApp,
105    ) -> Result<()> {
106        while needs_refresh.changed().await.is_ok() {
107            let project_context = this
108                .update(cx, |this, cx| {
109                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
110                })?
111                .await;
112            this.update(cx, |this, _| this.project_context.replace(project_context))?;
113        }
114
115        Ok(())
116    }
117
118    fn build_project_context(
119        project: &Entity<Project>,
120        prompt_store: Option<&Entity<PromptStore>>,
121        cx: &mut App,
122    ) -> Task<ProjectContext> {
123        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
124        let worktree_tasks = worktrees
125            .into_iter()
126            .map(|worktree| {
127                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
128            })
129            .collect::<Vec<_>>();
130        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
131            prompt_store.read_with(cx, |prompt_store, cx| {
132                let prompts = prompt_store.default_prompt_metadata();
133                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
134                    let contents = prompt_store.load(prompt_metadata.id, cx);
135                    async move { (contents.await, prompt_metadata) }
136                });
137                cx.background_spawn(future::join_all(load_tasks))
138            })
139        } else {
140            Task::ready(vec![])
141        };
142
143        cx.spawn(async move |_cx| {
144            let (worktrees, default_user_rules) =
145                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
146
147            let worktrees = worktrees
148                .into_iter()
149                .map(|(worktree, _rules_error)| {
150                    // TODO: show error message
151                    // if let Some(rules_error) = rules_error {
152                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
153                    // }
154                    worktree
155                })
156                .collect::<Vec<_>>();
157
158            let default_user_rules = default_user_rules
159                .into_iter()
160                .flat_map(|(contents, prompt_metadata)| match contents {
161                    Ok(contents) => Some(UserRulesContext {
162                        uuid: match prompt_metadata.id {
163                            PromptId::User { uuid } => uuid,
164                            PromptId::EditWorkflow => return None,
165                        },
166                        title: prompt_metadata.title.map(|title| title.to_string()),
167                        contents,
168                    }),
169                    Err(_err) => {
170                        // TODO: show error message
171                        // this.update(cx, |_, cx| {
172                        //     cx.emit(RulesLoadingError {
173                        //         message: format!("{err:?}").into(),
174                        //     });
175                        // })
176                        // .ok();
177                        None
178                    }
179                })
180                .collect::<Vec<_>>();
181
182            ProjectContext::new(worktrees, default_user_rules)
183        })
184    }
185
186    fn load_worktree_info_for_system_prompt(
187        worktree: Entity<Worktree>,
188        project: Entity<Project>,
189        cx: &mut App,
190    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
191        let tree = worktree.read(cx);
192        let root_name = tree.root_name().into();
193        let abs_path = tree.abs_path();
194
195        let mut context = WorktreeContext {
196            root_name,
197            abs_path,
198            rules_file: None,
199        };
200
201        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
202        let Some(rules_task) = rules_task else {
203            return Task::ready((context, None));
204        };
205
206        cx.spawn(async move |_| {
207            let (rules_file, rules_file_error) = match rules_task.await {
208                Ok(rules_file) => (Some(rules_file), None),
209                Err(err) => (
210                    None,
211                    Some(RulesLoadingError {
212                        message: format!("{err}").into(),
213                    }),
214                ),
215            };
216            context.rules_file = rules_file;
217            (context, rules_file_error)
218        })
219    }
220
221    fn load_worktree_rules_file(
222        worktree: Entity<Worktree>,
223        project: Entity<Project>,
224        cx: &mut App,
225    ) -> Option<Task<Result<RulesFileContext>>> {
226        let worktree = worktree.read(cx);
227        let worktree_id = worktree.id();
228        let selected_rules_file = RULES_FILE_NAMES
229            .into_iter()
230            .filter_map(|name| {
231                worktree
232                    .entry_for_path(name)
233                    .filter(|entry| entry.is_file())
234                    .map(|entry| entry.path.clone())
235            })
236            .next();
237
238        // Note that Cline supports `.clinerules` being a directory, but that is not currently
239        // supported. This doesn't seem to occur often in GitHub repositories.
240        selected_rules_file.map(|path_in_worktree| {
241            let project_path = ProjectPath {
242                worktree_id,
243                path: path_in_worktree.clone(),
244            };
245            let buffer_task =
246                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
247            let rope_task = cx.spawn(async move |cx| {
248                buffer_task.await?.read_with(cx, |buffer, cx| {
249                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
250                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
251                })?
252            });
253            // Build a string from the rope on a background thread.
254            cx.background_spawn(async move {
255                let (project_entry_id, rope) = rope_task.await?;
256                anyhow::Ok(RulesFileContext {
257                    path_in_worktree,
258                    text: rope.to_string().trim().to_string(),
259                    project_entry_id: project_entry_id.to_usize(),
260                })
261            })
262        })
263    }
264
265    fn handle_project_event(
266        &mut self,
267        _project: Entity<Project>,
268        event: &project::Event,
269        _cx: &mut Context<Self>,
270    ) {
271        match event {
272            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
273                self.project_context_needs_refresh.send(()).ok();
274            }
275            project::Event::WorktreeUpdatedEntries(_, items) => {
276                if items.iter().any(|(path, _, _)| {
277                    RULES_FILE_NAMES
278                        .iter()
279                        .any(|name| path.as_ref() == Path::new(name))
280                }) {
281                    self.project_context_needs_refresh.send(()).ok();
282                }
283            }
284            _ => {}
285        }
286    }
287
288    fn handle_prompts_updated_event(
289        &mut self,
290        _prompt_store: Entity<PromptStore>,
291        _event: &prompt_store::PromptsUpdatedEvent,
292        _cx: &mut Context<Self>,
293    ) {
294        self.project_context_needs_refresh.send(()).ok();
295    }
296}
297
298/// Wrapper struct that implements the AgentConnection trait
299#[derive(Clone)]
300pub struct NativeAgentConnection(pub Entity<NativeAgent>);
301
302impl ModelSelector for NativeAgentConnection {
303    fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
304        log::debug!("NativeAgentConnection::list_models called");
305        cx.spawn(async move |cx| {
306            cx.update(|cx| {
307                let registry = LanguageModelRegistry::read_global(cx);
308                let models = registry.available_models(cx).collect::<Vec<_>>();
309                log::info!("Found {} available models", models.len());
310                if models.is_empty() {
311                    Err(anyhow::anyhow!("No models available"))
312                } else {
313                    Ok(models)
314                }
315            })?
316        })
317    }
318
319    fn select_model(
320        &self,
321        session_id: acp::SessionId,
322        model: Arc<dyn LanguageModel>,
323        cx: &mut AsyncApp,
324    ) -> Task<Result<()>> {
325        log::info!(
326            "Setting model for session {}: {:?}",
327            session_id,
328            model.name()
329        );
330        let agent = self.0.clone();
331
332        cx.spawn(async move |cx| {
333            agent.update(cx, |agent, cx| {
334                if let Some(session) = agent.sessions.get(&session_id) {
335                    session.thread.update(cx, |thread, _cx| {
336                        thread.selected_model = model;
337                    });
338                    Ok(())
339                } else {
340                    Err(anyhow!("Session not found"))
341                }
342            })?
343        })
344    }
345
346    fn selected_model(
347        &self,
348        session_id: &acp::SessionId,
349        cx: &mut AsyncApp,
350    ) -> Task<Result<Arc<dyn LanguageModel>>> {
351        let agent = self.0.clone();
352        let session_id = session_id.clone();
353        cx.spawn(async move |cx| {
354            let thread = agent
355                .read_with(cx, |agent, _| {
356                    agent
357                        .sessions
358                        .get(&session_id)
359                        .map(|session| session.thread.clone())
360                })?
361                .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
362            let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
363            Ok(selected)
364        })
365    }
366}
367
368impl acp_thread::AgentConnection for NativeAgentConnection {
369    fn new_thread(
370        self: Rc<Self>,
371        project: Entity<Project>,
372        cwd: &Path,
373        cx: &mut AsyncApp,
374    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
375        let agent = self.0.clone();
376        log::info!("Creating new thread for project at: {:?}", cwd);
377
378        cx.spawn(async move |cx| {
379            log::debug!("Starting thread creation in async context");
380
381            // Generate session ID
382            let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
383            log::info!("Created session with ID: {}", session_id);
384
385            // Create AcpThread
386            let acp_thread = cx.update(|cx| {
387                cx.new(|cx| {
388                    acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
389                })
390            })?;
391            let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
392
393            // Create Thread
394            let thread = agent.update(
395                cx,
396                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
397                    // Fetch default model from registry settings
398                    let registry = LanguageModelRegistry::read_global(cx);
399
400                    // Log available models for debugging
401                    let available_count = registry.available_models(cx).count();
402                    log::debug!("Total available models: {}", available_count);
403
404                    let default_model = registry
405                        .default_model()
406                        .map(|configured| {
407                            log::info!(
408                                "Using configured default model: {:?} from provider: {:?}",
409                                configured.model.name(),
410                                configured.provider.name()
411                            );
412                            configured.model
413                        })
414                        .ok_or_else(|| {
415                            log::warn!("No default model configured in settings");
416                            anyhow!("No default model configured. Please configure a default model in settings.")
417                        })?;
418
419                    let thread = cx.new(|cx| {
420                        let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
421                        thread.add_tool(CreateDirectoryTool::new(project.clone()));
422                        thread.add_tool(CopyPathTool::new(project.clone()));
423                        thread.add_tool(MovePathTool::new(project.clone()));
424                        thread.add_tool(ListDirectoryTool::new(project.clone()));
425                        thread.add_tool(OpenTool::new(project.clone()));
426                        thread.add_tool(ThinkingTool);
427                        thread.add_tool(FindPathTool::new(project.clone()));
428                        thread.add_tool(GrepTool::new(project.clone()));
429                        thread.add_tool(ReadFileTool::new(project.clone(), action_log));
430                        thread.add_tool(EditFileTool::new(cx.entity()));
431                        thread.add_tool(NowTool);
432                        thread.add_tool(TerminalTool::new(project.clone(), cx));
433                        // TODO: Needs to be conditional based on zed model or not
434                        thread.add_tool(WebSearchTool);
435                        thread
436                    });
437
438                    Ok(thread)
439                },
440            )??;
441
442            // Store the session
443            agent.update(cx, |agent, cx| {
444                agent.sessions.insert(
445                    session_id,
446                    Session {
447                        thread,
448                        acp_thread: acp_thread.downgrade(),
449                        _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
450                            this.sessions.remove(acp_thread.session_id());
451                        })
452                    },
453                );
454            })?;
455
456            Ok(acp_thread)
457        })
458    }
459
460    fn auth_methods(&self) -> &[acp::AuthMethod] {
461        &[] // No auth for in-process
462    }
463
464    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
465        Task::ready(Ok(()))
466    }
467
468    fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
469        Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
470    }
471
472    fn prompt(
473        &self,
474        params: acp::PromptRequest,
475        cx: &mut App,
476    ) -> Task<Result<acp::PromptResponse>> {
477        let session_id = params.session_id.clone();
478        let agent = self.0.clone();
479        log::info!("Received prompt request for session: {}", session_id);
480        log::debug!("Prompt blocks count: {}", params.prompt.len());
481
482        cx.spawn(async move |cx| {
483            // Get session
484            let (thread, acp_thread) = agent
485                .update(cx, |agent, _| {
486                    agent
487                        .sessions
488                        .get_mut(&session_id)
489                        .map(|s| (s.thread.clone(), s.acp_thread.clone()))
490                })?
491                .ok_or_else(|| {
492                    log::error!("Session not found: {}", session_id);
493                    anyhow::anyhow!("Session not found")
494                })?;
495            log::debug!("Found session for: {}", session_id);
496
497            // Convert prompt to message
498            let message = convert_prompt_to_message(params.prompt);
499            log::info!("Converted prompt to message: {} chars", message.len());
500            log::debug!("Message content: {}", message);
501
502            // Get model using the ModelSelector capability (always available for agent2)
503            // Get the selected model from the thread directly
504            let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
505
506            // Send to thread
507            log::info!("Sending message to thread with model: {:?}", model.name());
508            let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?;
509
510            // Handle response stream and forward to session.acp_thread
511            while let Some(result) = response_stream.next().await {
512                match result {
513                    Ok(event) => {
514                        log::trace!("Received completion event: {:?}", event);
515
516                        match event {
517                            AgentResponseEvent::Text(text) => {
518                                acp_thread.update(cx, |thread, cx| {
519                                    thread.push_assistant_content_block(
520                                        acp::ContentBlock::Text(acp::TextContent {
521                                            text,
522                                            annotations: None,
523                                        }),
524                                        false,
525                                        cx,
526                                    )
527                                })?;
528                            }
529                            AgentResponseEvent::Thinking(text) => {
530                                acp_thread.update(cx, |thread, cx| {
531                                    thread.push_assistant_content_block(
532                                        acp::ContentBlock::Text(acp::TextContent {
533                                            text,
534                                            annotations: None,
535                                        }),
536                                        true,
537                                        cx,
538                                    )
539                                })?;
540                            }
541                            AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
542                                tool_call,
543                                options,
544                                response,
545                            }) => {
546                                let recv = acp_thread.update(cx, |thread, cx| {
547                                    thread.request_tool_call_authorization(tool_call, options, cx)
548                                })?;
549                                cx.background_spawn(async move {
550                                    if let Some(option) = recv
551                                        .await
552                                        .context("authorization sender was dropped")
553                                        .log_err()
554                                    {
555                                        response
556                                            .send(option)
557                                            .map(|_| anyhow!("authorization receiver was dropped"))
558                                            .log_err();
559                                    }
560                                })
561                                .detach();
562                            }
563                            AgentResponseEvent::ToolCall(tool_call) => {
564                                acp_thread.update(cx, |thread, cx| {
565                                    thread.upsert_tool_call(tool_call, cx)
566                                })?;
567                            }
568                            AgentResponseEvent::ToolCallUpdate(update) => {
569                                acp_thread.update(cx, |thread, cx| {
570                                    thread.update_tool_call(update, cx)
571                                })??;
572                            }
573                            AgentResponseEvent::Stop(stop_reason) => {
574                                log::debug!("Assistant message complete: {:?}", stop_reason);
575                                return Ok(acp::PromptResponse { stop_reason });
576                            }
577                        }
578                    }
579                    Err(e) => {
580                        log::error!("Error in model response stream: {:?}", e);
581                        // TODO: Consider sending an error message to the UI
582                        break;
583                    }
584                }
585            }
586
587            log::info!("Response stream completed");
588            anyhow::Ok(acp::PromptResponse {
589                stop_reason: acp::StopReason::EndTurn,
590            })
591        })
592    }
593
594    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
595        log::info!("Cancelling on session: {}", session_id);
596        self.0.update(cx, |agent, cx| {
597            if let Some(agent) = agent.sessions.get(session_id) {
598                agent.thread.update(cx, |thread, _cx| thread.cancel());
599            }
600        });
601    }
602}
603
604/// Convert ACP content blocks to a message string
605fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
606    log::debug!("Converting {} content blocks to message", blocks.len());
607    let mut message = String::new();
608
609    for block in blocks {
610        match block {
611            acp::ContentBlock::Text(text) => {
612                log::trace!("Processing text block: {} chars", text.text.len());
613                message.push_str(&text.text);
614            }
615            acp::ContentBlock::ResourceLink(link) => {
616                log::trace!("Processing resource link: {}", link.uri);
617                message.push_str(&format!(" @{} ", link.uri));
618            }
619            acp::ContentBlock::Image(_) => {
620                log::trace!("Processing image block");
621                message.push_str(" [image] ");
622            }
623            acp::ContentBlock::Audio(_) => {
624                log::trace!("Processing audio block");
625                message.push_str(" [audio] ");
626            }
627            acp::ContentBlock::Resource(resource) => {
628                log::trace!("Processing resource block: {:?}", resource.resource);
629                message.push_str(&format!(" [resource: {:?}] ", resource.resource));
630            }
631        }
632    }
633
634    message
635}
636
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// The agent's shared project context should follow worktree additions and the creation of
    /// a `.rules` file inside a visible worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);

        let fake_fs = FakeFs::new(cx.executor());
        fake_fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fake_fs.clone(), [], cx).await;
        let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
            .await
            .unwrap();

        // Without worktrees there is nothing to describe.
        agent.read_with(cx, |native_agent, _| {
            assert_eq!(native_agent.project_context.borrow().worktrees, vec![])
        });

        // A newly created visible worktree shows up once background work settles.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |native_agent, _| {
            let expected = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: None,
            }];
            assert_eq!(native_agent.project_context.borrow().worktrees, expected)
        });

        // Creating `/a/.rules` refreshes the context with the (empty) rules file.
        fake_fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |native_agent, app| {
            let rules_entry = worktree.read(app).entry_for_path(".rules").unwrap();
            let expected = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: Some(RulesFileContext {
                    path_in_worktree: Path::new(".rules").into(),
                    text: "".into(),
                    project_entry_id: rules_entry.id.to_usize(),
                }),
            }];
            assert_eq!(native_agent.project_context.borrow().worktrees, expected)
        });
    }

    /// Installs the global settings store and project/language settings the test needs, and
    /// enables logging if no logger has been set up yet.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }
}