agent.rs

  1use crate::{AgentResponseEvent, Thread, templates::Templates};
  2use crate::{
  3    ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool,
  4    FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MessageContent, MovePathTool, NowTool,
  5    OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool,
  6};
  7use acp_thread::ModelSelector;
  8use agent_client_protocol as acp;
  9use anyhow::{Context as _, Result, anyhow};
 10use futures::{StreamExt, future};
 11use gpui::{
 12    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
 13};
 14use language_model::{LanguageModel, LanguageModelRegistry};
 15use project::{Project, ProjectItem, ProjectPath, Worktree};
 16use prompt_store::{
 17    ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
 18};
 19use std::cell::RefCell;
 20use std::collections::HashMap;
 21use std::path::Path;
 22use std::rc::Rc;
 23use std::sync::Arc;
 24use util::ResultExt;
 25
 26const RULES_FILE_NAMES: [&'static str; 9] = [
 27    ".rules",
 28    ".cursorrules",
 29    ".windsurfrules",
 30    ".clinerules",
 31    ".github/copilot-instructions.md",
 32    "CLAUDE.md",
 33    "AGENT.md",
 34    "AGENTS.md",
 35    "GEMINI.md",
 36];
 37
 38pub struct RulesLoadingError {
 39    pub message: SharedString,
 40}
 41
 42/// Holds both the internal Thread and the AcpThread for a session
 43struct Session {
 44    /// The internal thread that processes messages
 45    thread: Entity<Thread>,
 46    /// The ACP thread that handles protocol communication
 47    acp_thread: WeakEntity<acp_thread::AcpThread>,
 48    _subscription: Subscription,
 49}
 50
 51pub struct NativeAgent {
 52    /// Session ID -> Session mapping
 53    sessions: HashMap<acp::SessionId, Session>,
 54    /// Shared project context for all threads
 55    project_context: Rc<RefCell<ProjectContext>>,
 56    project_context_needs_refresh: watch::Sender<()>,
 57    _maintain_project_context: Task<Result<()>>,
 58    context_server_registry: Entity<ContextServerRegistry>,
 59    /// Shared templates for all threads
 60    templates: Arc<Templates>,
 61    project: Entity<Project>,
 62    prompt_store: Option<Entity<PromptStore>>,
 63    _subscriptions: Vec<Subscription>,
 64}
 65
 66impl NativeAgent {
 67    pub async fn new(
 68        project: Entity<Project>,
 69        templates: Arc<Templates>,
 70        prompt_store: Option<Entity<PromptStore>>,
 71        cx: &mut AsyncApp,
 72    ) -> Result<Entity<NativeAgent>> {
 73        log::info!("Creating new NativeAgent");
 74
 75        let project_context = cx
 76            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 77            .await;
 78
 79        cx.new(|cx| {
 80            let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
 81            if let Some(prompt_store) = prompt_store.as_ref() {
 82                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 83            }
 84
 85            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 86                watch::channel(());
 87            Self {
 88                sessions: HashMap::new(),
 89                project_context: Rc::new(RefCell::new(project_context)),
 90                project_context_needs_refresh: project_context_needs_refresh_tx,
 91                _maintain_project_context: cx.spawn(async move |this, cx| {
 92                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 93                }),
 94                context_server_registry: cx.new(|cx| {
 95                    ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
 96                }),
 97                templates,
 98                project,
 99                prompt_store,
100                _subscriptions: subscriptions,
101            }
102        })
103    }
104
105    async fn maintain_project_context(
106        this: WeakEntity<Self>,
107        mut needs_refresh: watch::Receiver<()>,
108        cx: &mut AsyncApp,
109    ) -> Result<()> {
110        while needs_refresh.changed().await.is_ok() {
111            let project_context = this
112                .update(cx, |this, cx| {
113                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
114                })?
115                .await;
116            this.update(cx, |this, _| this.project_context.replace(project_context))?;
117        }
118
119        Ok(())
120    }
121
122    fn build_project_context(
123        project: &Entity<Project>,
124        prompt_store: Option<&Entity<PromptStore>>,
125        cx: &mut App,
126    ) -> Task<ProjectContext> {
127        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
128        let worktree_tasks = worktrees
129            .into_iter()
130            .map(|worktree| {
131                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
132            })
133            .collect::<Vec<_>>();
134        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
135            prompt_store.read_with(cx, |prompt_store, cx| {
136                let prompts = prompt_store.default_prompt_metadata();
137                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
138                    let contents = prompt_store.load(prompt_metadata.id, cx);
139                    async move { (contents.await, prompt_metadata) }
140                });
141                cx.background_spawn(future::join_all(load_tasks))
142            })
143        } else {
144            Task::ready(vec![])
145        };
146
147        cx.spawn(async move |_cx| {
148            let (worktrees, default_user_rules) =
149                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
150
151            let worktrees = worktrees
152                .into_iter()
153                .map(|(worktree, _rules_error)| {
154                    // TODO: show error message
155                    // if let Some(rules_error) = rules_error {
156                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
157                    // }
158                    worktree
159                })
160                .collect::<Vec<_>>();
161
162            let default_user_rules = default_user_rules
163                .into_iter()
164                .flat_map(|(contents, prompt_metadata)| match contents {
165                    Ok(contents) => Some(UserRulesContext {
166                        uuid: match prompt_metadata.id {
167                            PromptId::User { uuid } => uuid,
168                            PromptId::EditWorkflow => return None,
169                        },
170                        title: prompt_metadata.title.map(|title| title.to_string()),
171                        contents,
172                    }),
173                    Err(_err) => {
174                        // TODO: show error message
175                        // this.update(cx, |_, cx| {
176                        //     cx.emit(RulesLoadingError {
177                        //         message: format!("{err:?}").into(),
178                        //     });
179                        // })
180                        // .ok();
181                        None
182                    }
183                })
184                .collect::<Vec<_>>();
185
186            ProjectContext::new(worktrees, default_user_rules)
187        })
188    }
189
190    fn load_worktree_info_for_system_prompt(
191        worktree: Entity<Worktree>,
192        project: Entity<Project>,
193        cx: &mut App,
194    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
195        let tree = worktree.read(cx);
196        let root_name = tree.root_name().into();
197        let abs_path = tree.abs_path();
198
199        let mut context = WorktreeContext {
200            root_name,
201            abs_path,
202            rules_file: None,
203        };
204
205        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
206        let Some(rules_task) = rules_task else {
207            return Task::ready((context, None));
208        };
209
210        cx.spawn(async move |_| {
211            let (rules_file, rules_file_error) = match rules_task.await {
212                Ok(rules_file) => (Some(rules_file), None),
213                Err(err) => (
214                    None,
215                    Some(RulesLoadingError {
216                        message: format!("{err}").into(),
217                    }),
218                ),
219            };
220            context.rules_file = rules_file;
221            (context, rules_file_error)
222        })
223    }
224
225    fn load_worktree_rules_file(
226        worktree: Entity<Worktree>,
227        project: Entity<Project>,
228        cx: &mut App,
229    ) -> Option<Task<Result<RulesFileContext>>> {
230        let worktree = worktree.read(cx);
231        let worktree_id = worktree.id();
232        let selected_rules_file = RULES_FILE_NAMES
233            .into_iter()
234            .filter_map(|name| {
235                worktree
236                    .entry_for_path(name)
237                    .filter(|entry| entry.is_file())
238                    .map(|entry| entry.path.clone())
239            })
240            .next();
241
242        // Note that Cline supports `.clinerules` being a directory, but that is not currently
243        // supported. This doesn't seem to occur often in GitHub repositories.
244        selected_rules_file.map(|path_in_worktree| {
245            let project_path = ProjectPath {
246                worktree_id,
247                path: path_in_worktree.clone(),
248            };
249            let buffer_task =
250                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
251            let rope_task = cx.spawn(async move |cx| {
252                buffer_task.await?.read_with(cx, |buffer, cx| {
253                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
254                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
255                })?
256            });
257            // Build a string from the rope on a background thread.
258            cx.background_spawn(async move {
259                let (project_entry_id, rope) = rope_task.await?;
260                anyhow::Ok(RulesFileContext {
261                    path_in_worktree,
262                    text: rope.to_string().trim().to_string(),
263                    project_entry_id: project_entry_id.to_usize(),
264                })
265            })
266        })
267    }
268
269    fn handle_project_event(
270        &mut self,
271        _project: Entity<Project>,
272        event: &project::Event,
273        _cx: &mut Context<Self>,
274    ) {
275        match event {
276            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
277                self.project_context_needs_refresh.send(()).ok();
278            }
279            project::Event::WorktreeUpdatedEntries(_, items) => {
280                if items.iter().any(|(path, _, _)| {
281                    RULES_FILE_NAMES
282                        .iter()
283                        .any(|name| path.as_ref() == Path::new(name))
284                }) {
285                    self.project_context_needs_refresh.send(()).ok();
286                }
287            }
288            _ => {}
289        }
290    }
291
292    fn handle_prompts_updated_event(
293        &mut self,
294        _prompt_store: Entity<PromptStore>,
295        _event: &prompt_store::PromptsUpdatedEvent,
296        _cx: &mut Context<Self>,
297    ) {
298        self.project_context_needs_refresh.send(()).ok();
299    }
300}
301
302/// Wrapper struct that implements the AgentConnection trait
303#[derive(Clone)]
304pub struct NativeAgentConnection(pub Entity<NativeAgent>);
305
306impl ModelSelector for NativeAgentConnection {
307    fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
308        log::debug!("NativeAgentConnection::list_models called");
309        cx.spawn(async move |cx| {
310            cx.update(|cx| {
311                let registry = LanguageModelRegistry::read_global(cx);
312                let models = registry.available_models(cx).collect::<Vec<_>>();
313                log::info!("Found {} available models", models.len());
314                if models.is_empty() {
315                    Err(anyhow::anyhow!("No models available"))
316                } else {
317                    Ok(models)
318                }
319            })?
320        })
321    }
322
323    fn select_model(
324        &self,
325        session_id: acp::SessionId,
326        model: Arc<dyn LanguageModel>,
327        cx: &mut AsyncApp,
328    ) -> Task<Result<()>> {
329        log::info!(
330            "Setting model for session {}: {:?}",
331            session_id,
332            model.name()
333        );
334        let agent = self.0.clone();
335
336        cx.spawn(async move |cx| {
337            agent.update(cx, |agent, cx| {
338                if let Some(session) = agent.sessions.get(&session_id) {
339                    session.thread.update(cx, |thread, _cx| {
340                        thread.selected_model = model;
341                    });
342                    Ok(())
343                } else {
344                    Err(anyhow!("Session not found"))
345                }
346            })?
347        })
348    }
349
350    fn selected_model(
351        &self,
352        session_id: &acp::SessionId,
353        cx: &mut AsyncApp,
354    ) -> Task<Result<Arc<dyn LanguageModel>>> {
355        let agent = self.0.clone();
356        let session_id = session_id.clone();
357        cx.spawn(async move |cx| {
358            let thread = agent
359                .read_with(cx, |agent, _| {
360                    agent
361                        .sessions
362                        .get(&session_id)
363                        .map(|session| session.thread.clone())
364                })?
365                .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
366            let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
367            Ok(selected)
368        })
369    }
370}
371
372impl acp_thread::AgentConnection for NativeAgentConnection {
373    fn new_thread(
374        self: Rc<Self>,
375        project: Entity<Project>,
376        cwd: &Path,
377        cx: &mut AsyncApp,
378    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
379        let agent = self.0.clone();
380        log::info!("Creating new thread for project at: {:?}", cwd);
381
382        cx.spawn(async move |cx| {
383            log::debug!("Starting thread creation in async context");
384
385            // Generate session ID
386            let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
387            log::info!("Created session with ID: {}", session_id);
388
389            // Create AcpThread
390            let acp_thread = cx.update(|cx| {
391                cx.new(|cx| {
392                    acp_thread::AcpThread::new(
393                        "agent2",
394                        self.clone(),
395                        project.clone(),
396                        session_id.clone(),
397                        cx,
398                    )
399                })
400            })?;
401            let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
402
403            // Create Thread
404            let thread = agent.update(
405                cx,
406                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
407                    // Fetch default model from registry settings
408                    let registry = LanguageModelRegistry::read_global(cx);
409
410                    // Log available models for debugging
411                    let available_count = registry.available_models(cx).count();
412                    log::debug!("Total available models: {}", available_count);
413
414                    let default_model = registry
415                        .default_model()
416                        .map(|configured| {
417                            log::info!(
418                                "Using configured default model: {:?} from provider: {:?}",
419                                configured.model.name(),
420                                configured.provider.name()
421                            );
422                            configured.model
423                        })
424                        .ok_or_else(|| {
425                            log::warn!("No default model configured in settings");
426                            anyhow!(
427                                "No default model. Please configure a default model in settings."
428                            )
429                        })?;
430
431                    let thread = cx.new(|cx| {
432                        let mut thread = Thread::new(
433                            project.clone(),
434                            agent.project_context.clone(),
435                            agent.context_server_registry.clone(),
436                            action_log.clone(),
437                            agent.templates.clone(),
438                            default_model,
439                            cx,
440                        );
441                        thread.add_tool(CreateDirectoryTool::new(project.clone()));
442                        thread.add_tool(CopyPathTool::new(project.clone()));
443                        thread.add_tool(DiagnosticsTool::new(project.clone()));
444                        thread.add_tool(MovePathTool::new(project.clone()));
445                        thread.add_tool(ListDirectoryTool::new(project.clone()));
446                        thread.add_tool(OpenTool::new(project.clone()));
447                        thread.add_tool(ThinkingTool);
448                        thread.add_tool(FindPathTool::new(project.clone()));
449                        thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
450                        thread.add_tool(GrepTool::new(project.clone()));
451                        thread.add_tool(ReadFileTool::new(project.clone(), action_log));
452                        thread.add_tool(EditFileTool::new(cx.entity()));
453                        thread.add_tool(NowTool);
454                        thread.add_tool(TerminalTool::new(project.clone(), cx));
455                        // TODO: Needs to be conditional based on zed model or not
456                        thread.add_tool(WebSearchTool);
457                        thread
458                    });
459
460                    Ok(thread)
461                },
462            )??;
463
464            // Store the session
465            agent.update(cx, |agent, cx| {
466                agent.sessions.insert(
467                    session_id,
468                    Session {
469                        thread,
470                        acp_thread: acp_thread.downgrade(),
471                        _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
472                            this.sessions.remove(acp_thread.session_id());
473                        }),
474                    },
475                );
476            })?;
477
478            Ok(acp_thread)
479        })
480    }
481
482    fn auth_methods(&self) -> &[acp::AuthMethod] {
483        &[] // No auth for in-process
484    }
485
486    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
487        Task::ready(Ok(()))
488    }
489
490    fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
491        Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
492    }
493
494    fn prompt(
495        &self,
496        params: acp::PromptRequest,
497        cx: &mut App,
498    ) -> Task<Result<acp::PromptResponse>> {
499        let session_id = params.session_id.clone();
500        let agent = self.0.clone();
501        log::info!("Received prompt request for session: {}", session_id);
502        log::debug!("Prompt blocks count: {}", params.prompt.len());
503
504        cx.spawn(async move |cx| {
505            // Get session
506            let (thread, acp_thread) = agent
507                .update(cx, |agent, _| {
508                    agent
509                        .sessions
510                        .get_mut(&session_id)
511                        .map(|s| (s.thread.clone(), s.acp_thread.clone()))
512                })?
513                .ok_or_else(|| {
514                    log::error!("Session not found: {}", session_id);
515                    anyhow::anyhow!("Session not found")
516                })?;
517            log::debug!("Found session for: {}", session_id);
518
519            let message: Vec<MessageContent> = params
520                .prompt
521                .into_iter()
522                .map(Into::into)
523                .collect::<Vec<_>>();
524            log::info!("Converted prompt to message: {} chars", message.len());
525            log::debug!("Message content: {:?}", message);
526
527            // Get model using the ModelSelector capability (always available for agent2)
528            // Get the selected model from the thread directly
529            let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
530
531            // Send to thread
532            log::info!("Sending message to thread with model: {:?}", model.name());
533            let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?;
534
535            // Handle response stream and forward to session.acp_thread
536            while let Some(result) = response_stream.next().await {
537                match result {
538                    Ok(event) => {
539                        log::trace!("Received completion event: {:?}", event);
540
541                        match event {
542                            AgentResponseEvent::Text(text) => {
543                                acp_thread.update(cx, |thread, cx| {
544                                    thread.push_assistant_content_block(
545                                        acp::ContentBlock::Text(acp::TextContent {
546                                            text,
547                                            annotations: None,
548                                        }),
549                                        false,
550                                        cx,
551                                    )
552                                })?;
553                            }
554                            AgentResponseEvent::Thinking(text) => {
555                                acp_thread.update(cx, |thread, cx| {
556                                    thread.push_assistant_content_block(
557                                        acp::ContentBlock::Text(acp::TextContent {
558                                            text,
559                                            annotations: None,
560                                        }),
561                                        true,
562                                        cx,
563                                    )
564                                })?;
565                            }
566                            AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
567                                tool_call,
568                                options,
569                                response,
570                            }) => {
571                                let recv = acp_thread.update(cx, |thread, cx| {
572                                    thread.request_tool_call_authorization(tool_call, options, cx)
573                                })?;
574                                cx.background_spawn(async move {
575                                    if let Some(option) = recv
576                                        .await
577                                        .context("authorization sender was dropped")
578                                        .log_err()
579                                    {
580                                        response
581                                            .send(option)
582                                            .map(|_| anyhow!("authorization receiver was dropped"))
583                                            .log_err();
584                                    }
585                                })
586                                .detach();
587                            }
588                            AgentResponseEvent::ToolCall(tool_call) => {
589                                acp_thread.update(cx, |thread, cx| {
590                                    thread.upsert_tool_call(tool_call, cx)
591                                })?;
592                            }
593                            AgentResponseEvent::ToolCallUpdate(update) => {
594                                acp_thread.update(cx, |thread, cx| {
595                                    thread.update_tool_call(update, cx)
596                                })??;
597                            }
598                            AgentResponseEvent::Stop(stop_reason) => {
599                                log::debug!("Assistant message complete: {:?}", stop_reason);
600                                return Ok(acp::PromptResponse { stop_reason });
601                            }
602                        }
603                    }
604                    Err(e) => {
605                        log::error!("Error in model response stream: {:?}", e);
606                        // TODO: Consider sending an error message to the UI
607                        break;
608                    }
609                }
610            }
611
612            log::info!("Response stream completed");
613            anyhow::Ok(acp::PromptResponse {
614                stop_reason: acp::StopReason::EndTurn,
615            })
616        })
617    }
618
619    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
620        log::info!("Cancelling on session: {}", session_id);
621        self.0.update(cx, |agent, cx| {
622            if let Some(agent) = agent.sessions.get(session_id) {
623                agent.thread.update(cx, |thread, _cx| thread.cancel());
624            }
625        });
626    }
627}
628
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// The shared project context must follow worktree additions and the
    /// appearance of a rules file inside an existing worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
            .await
            .unwrap();

        // No worktrees yet, so the context starts out empty.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        // Adding a worktree refreshes the context; there is no rules file yet.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![expected_context_for_a(None)]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            let expected_rules = RulesFileContext {
                path_in_worktree: Path::new(".rules").into(),
                text: "".into(),
                project_entry_id: rules_entry.id.to_usize(),
            };
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![expected_context_for_a(Some(expected_rules))]
            )
        });
    }

    /// Expected [`WorktreeContext`] for the `/a` worktree created in the test above.
    fn expected_context_for_a(rules_file: Option<RulesFileContext>) -> WorktreeContext {
        WorktreeContext {
            root_name: "a".into(),
            abs_path: Path::new("/a").into(),
            rules_file,
        }
    }

    /// Installs logging, a test settings store, project settings and language support.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }
}