agent.rs

  1use crate::{AgentResponseEvent, Thread, templates::Templates};
  2use crate::{
  3    ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
  4    EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
  5    OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent,
  6    WebSearchTool,
  7};
  8use acp_thread::AgentModelSelector;
  9use agent_client_protocol as acp;
 10use agent_settings::AgentSettings;
 11use anyhow::{Context as _, Result, anyhow};
 12use collections::{HashSet, IndexMap};
 13use fs::Fs;
 14use futures::{StreamExt, future};
 15use gpui::{
 16    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
 17};
 18use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
 19use project::{Project, ProjectItem, ProjectPath, Worktree};
 20use prompt_store::{
 21    ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
 22};
 23use settings::update_settings_file;
 24use std::cell::RefCell;
 25use std::collections::HashMap;
 26use std::path::Path;
 27use std::rc::Rc;
 28use std::sync::Arc;
 29use util::ResultExt;
 30
 31const RULES_FILE_NAMES: [&'static str; 9] = [
 32    ".rules",
 33    ".cursorrules",
 34    ".windsurfrules",
 35    ".clinerules",
 36    ".github/copilot-instructions.md",
 37    "CLAUDE.md",
 38    "AGENT.md",
 39    "AGENTS.md",
 40    "GEMINI.md",
 41];
 42
/// Describes a failure to load a worktree's rules file.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the formatted underlying error).
    pub message: SharedString,
}
 46
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    ///
    /// Held weakly, so the session does not keep the ACP thread alive.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Release observer on the ACP thread; it removes this session from the
    /// agent's session map once the ACP thread is released.
    _subscription: Subscription,
}
 55
/// Cache of the language models offered by authenticated providers, together
/// with a watch channel that is signalled every time the cache is rebuilt.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiving half of the refresh channel; cloned and handed out by `watch`.
    refresh_models_rx: watch::Receiver<()>,
    /// Sending half of the refresh channel; notified at the end of `refresh_list`.
    refresh_models_tx: watch::Sender<()>,
}
 64
 65impl LanguageModels {
 66    fn new(cx: &App) -> Self {
 67        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
 68        let mut this = Self {
 69            models: HashMap::default(),
 70            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
 71            refresh_models_rx,
 72            refresh_models_tx,
 73        };
 74        this.refresh_list(cx);
 75        this
 76    }
 77
 78    fn refresh_list(&mut self, cx: &App) {
 79        let providers = LanguageModelRegistry::global(cx)
 80            .read(cx)
 81            .providers()
 82            .into_iter()
 83            .filter(|provider| provider.is_authenticated(cx))
 84            .collect::<Vec<_>>();
 85
 86        let mut language_model_list = IndexMap::default();
 87        let mut recommended_models = HashSet::default();
 88
 89        let mut recommended = Vec::new();
 90        for provider in &providers {
 91            for model in provider.recommended_models(cx) {
 92                recommended_models.insert(model.id());
 93                recommended.push(Self::map_language_model_to_info(&model, &provider));
 94            }
 95        }
 96        if !recommended.is_empty() {
 97            language_model_list.insert(
 98                acp_thread::AgentModelGroupName("Recommended".into()),
 99                recommended,
100            );
101        }
102
103        let mut models = HashMap::default();
104        for provider in providers {
105            let mut provider_models = Vec::new();
106            for model in provider.provided_models(cx) {
107                let model_info = Self::map_language_model_to_info(&model, &provider);
108                let model_id = model_info.id.clone();
109                if !recommended_models.contains(&model.id()) {
110                    provider_models.push(model_info);
111                }
112                models.insert(model_id, model);
113            }
114            if !provider_models.is_empty() {
115                language_model_list.insert(
116                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
117                    provider_models,
118                );
119            }
120        }
121
122        self.models = models;
123        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
124        self.refresh_models_tx.send(()).ok();
125    }
126
127    fn watch(&self) -> watch::Receiver<()> {
128        self.refresh_models_rx.clone()
129    }
130
131    pub fn model_from_id(
132        &self,
133        model_id: &acp_thread::AgentModelId,
134    ) -> Option<Arc<dyn LanguageModel>> {
135        self.models.get(model_id).cloned()
136    }
137
138    fn map_language_model_to_info(
139        model: &Arc<dyn LanguageModel>,
140        provider: &Arc<dyn LanguageModelProvider>,
141    ) -> acp_thread::AgentModelInfo {
142        acp_thread::AgentModelInfo {
143            id: Self::model_id(model),
144            name: model.name().0,
145            icon: Some(provider.icon()),
146        }
147    }
148
149    fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
150        acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
151    }
152}
153
/// In-process agent that owns the sessions, the shared project context and the
/// cached model list used by its threads.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Shared project context for all threads
    project_context: Rc<RefCell<ProjectContext>>,
    /// Signalled whenever the project context should be rebuilt (worktree,
    /// rules-file or prompt-store changes).
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task that rebuilds `project_context` on every refresh signal.
    _maintain_project_context: Task<Result<()>>,
    /// Registry built from the project's context server store; handed to each new thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project whose worktrees are used to build the project context.
    project: Entity<Project>,
    /// Optional prompt store supplying the default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// File system handle used when writing the selected model to the settings file.
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, model-registry and (optionally) prompt-store events.
    _subscriptions: Vec<Subscription>,
}
171
172impl NativeAgent {
173    pub async fn new(
174        project: Entity<Project>,
175        templates: Arc<Templates>,
176        prompt_store: Option<Entity<PromptStore>>,
177        fs: Arc<dyn Fs>,
178        cx: &mut AsyncApp,
179    ) -> Result<Entity<NativeAgent>> {
180        log::info!("Creating new NativeAgent");
181
182        let project_context = cx
183            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
184            .await;
185
186        cx.new(|cx| {
187            let mut subscriptions = vec![
188                cx.subscribe(&project, Self::handle_project_event),
189                cx.subscribe(
190                    &LanguageModelRegistry::global(cx),
191                    Self::handle_models_updated_event,
192                ),
193            ];
194            if let Some(prompt_store) = prompt_store.as_ref() {
195                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
196            }
197
198            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
199                watch::channel(());
200            Self {
201                sessions: HashMap::new(),
202                project_context: Rc::new(RefCell::new(project_context)),
203                project_context_needs_refresh: project_context_needs_refresh_tx,
204                _maintain_project_context: cx.spawn(async move |this, cx| {
205                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
206                }),
207                context_server_registry: cx.new(|cx| {
208                    ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
209                }),
210                templates,
211                models: LanguageModels::new(cx),
212                project,
213                prompt_store,
214                fs,
215                _subscriptions: subscriptions,
216            }
217        })
218    }
219
220    pub fn models(&self) -> &LanguageModels {
221        &self.models
222    }
223
224    async fn maintain_project_context(
225        this: WeakEntity<Self>,
226        mut needs_refresh: watch::Receiver<()>,
227        cx: &mut AsyncApp,
228    ) -> Result<()> {
229        while needs_refresh.changed().await.is_ok() {
230            let project_context = this
231                .update(cx, |this, cx| {
232                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
233                })?
234                .await;
235            this.update(cx, |this, _| this.project_context.replace(project_context))?;
236        }
237
238        Ok(())
239    }
240
241    fn build_project_context(
242        project: &Entity<Project>,
243        prompt_store: Option<&Entity<PromptStore>>,
244        cx: &mut App,
245    ) -> Task<ProjectContext> {
246        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
247        let worktree_tasks = worktrees
248            .into_iter()
249            .map(|worktree| {
250                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
251            })
252            .collect::<Vec<_>>();
253        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
254            prompt_store.read_with(cx, |prompt_store, cx| {
255                let prompts = prompt_store.default_prompt_metadata();
256                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
257                    let contents = prompt_store.load(prompt_metadata.id, cx);
258                    async move { (contents.await, prompt_metadata) }
259                });
260                cx.background_spawn(future::join_all(load_tasks))
261            })
262        } else {
263            Task::ready(vec![])
264        };
265
266        cx.spawn(async move |_cx| {
267            let (worktrees, default_user_rules) =
268                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
269
270            let worktrees = worktrees
271                .into_iter()
272                .map(|(worktree, _rules_error)| {
273                    // TODO: show error message
274                    // if let Some(rules_error) = rules_error {
275                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
276                    // }
277                    worktree
278                })
279                .collect::<Vec<_>>();
280
281            let default_user_rules = default_user_rules
282                .into_iter()
283                .flat_map(|(contents, prompt_metadata)| match contents {
284                    Ok(contents) => Some(UserRulesContext {
285                        uuid: match prompt_metadata.id {
286                            PromptId::User { uuid } => uuid,
287                            PromptId::EditWorkflow => return None,
288                        },
289                        title: prompt_metadata.title.map(|title| title.to_string()),
290                        contents,
291                    }),
292                    Err(_err) => {
293                        // TODO: show error message
294                        // this.update(cx, |_, cx| {
295                        //     cx.emit(RulesLoadingError {
296                        //         message: format!("{err:?}").into(),
297                        //     });
298                        // })
299                        // .ok();
300                        None
301                    }
302                })
303                .collect::<Vec<_>>();
304
305            ProjectContext::new(worktrees, default_user_rules)
306        })
307    }
308
309    fn load_worktree_info_for_system_prompt(
310        worktree: Entity<Worktree>,
311        project: Entity<Project>,
312        cx: &mut App,
313    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
314        let tree = worktree.read(cx);
315        let root_name = tree.root_name().into();
316        let abs_path = tree.abs_path();
317
318        let mut context = WorktreeContext {
319            root_name,
320            abs_path,
321            rules_file: None,
322        };
323
324        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
325        let Some(rules_task) = rules_task else {
326            return Task::ready((context, None));
327        };
328
329        cx.spawn(async move |_| {
330            let (rules_file, rules_file_error) = match rules_task.await {
331                Ok(rules_file) => (Some(rules_file), None),
332                Err(err) => (
333                    None,
334                    Some(RulesLoadingError {
335                        message: format!("{err}").into(),
336                    }),
337                ),
338            };
339            context.rules_file = rules_file;
340            (context, rules_file_error)
341        })
342    }
343
344    fn load_worktree_rules_file(
345        worktree: Entity<Worktree>,
346        project: Entity<Project>,
347        cx: &mut App,
348    ) -> Option<Task<Result<RulesFileContext>>> {
349        let worktree = worktree.read(cx);
350        let worktree_id = worktree.id();
351        let selected_rules_file = RULES_FILE_NAMES
352            .into_iter()
353            .filter_map(|name| {
354                worktree
355                    .entry_for_path(name)
356                    .filter(|entry| entry.is_file())
357                    .map(|entry| entry.path.clone())
358            })
359            .next();
360
361        // Note that Cline supports `.clinerules` being a directory, but that is not currently
362        // supported. This doesn't seem to occur often in GitHub repositories.
363        selected_rules_file.map(|path_in_worktree| {
364            let project_path = ProjectPath {
365                worktree_id,
366                path: path_in_worktree.clone(),
367            };
368            let buffer_task =
369                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
370            let rope_task = cx.spawn(async move |cx| {
371                buffer_task.await?.read_with(cx, |buffer, cx| {
372                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
373                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
374                })?
375            });
376            // Build a string from the rope on a background thread.
377            cx.background_spawn(async move {
378                let (project_entry_id, rope) = rope_task.await?;
379                anyhow::Ok(RulesFileContext {
380                    path_in_worktree,
381                    text: rope.to_string().trim().to_string(),
382                    project_entry_id: project_entry_id.to_usize(),
383                })
384            })
385        })
386    }
387
388    fn handle_project_event(
389        &mut self,
390        _project: Entity<Project>,
391        event: &project::Event,
392        _cx: &mut Context<Self>,
393    ) {
394        match event {
395            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
396                self.project_context_needs_refresh.send(()).ok();
397            }
398            project::Event::WorktreeUpdatedEntries(_, items) => {
399                if items.iter().any(|(path, _, _)| {
400                    RULES_FILE_NAMES
401                        .iter()
402                        .any(|name| path.as_ref() == Path::new(name))
403                }) {
404                    self.project_context_needs_refresh.send(()).ok();
405                }
406            }
407            _ => {}
408        }
409    }
410
411    fn handle_prompts_updated_event(
412        &mut self,
413        _prompt_store: Entity<PromptStore>,
414        _event: &prompt_store::PromptsUpdatedEvent,
415        _cx: &mut Context<Self>,
416    ) {
417        self.project_context_needs_refresh.send(()).ok();
418    }
419
420    fn handle_models_updated_event(
421        &mut self,
422        _registry: Entity<LanguageModelRegistry>,
423        _event: &language_model::Event,
424        cx: &mut Context<Self>,
425    ) {
426        self.models.refresh_list(cx);
427        for session in self.sessions.values_mut() {
428            session.thread.update(cx, |thread, _| {
429                let model_id = LanguageModels::model_id(&thread.selected_model);
430                if let Some(model) = self.models.model_from_id(&model_id) {
431                    thread.selected_model = model.clone();
432                }
433            });
434        }
435    }
436}
437
/// Wrapper struct that implements the AgentConnection trait
///
/// Also implements `AgentModelSelector`. Cloning it clones the handle to the
/// wrapped `NativeAgent` entity.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
441
442impl AgentModelSelector for NativeAgentConnection {
443    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
444        log::debug!("NativeAgentConnection::list_models called");
445        let list = self.0.read(cx).models.model_list.clone();
446        Task::ready(if list.is_empty() {
447            Err(anyhow::anyhow!("No models available"))
448        } else {
449            Ok(list)
450        })
451    }
452
453    fn select_model(
454        &self,
455        session_id: acp::SessionId,
456        model_id: acp_thread::AgentModelId,
457        cx: &mut App,
458    ) -> Task<Result<()>> {
459        log::info!("Setting model for session {}: {}", session_id, model_id);
460        let Some(thread) = self
461            .0
462            .read(cx)
463            .sessions
464            .get(&session_id)
465            .map(|session| session.thread.clone())
466        else {
467            return Task::ready(Err(anyhow!("Session not found")));
468        };
469
470        let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
471            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
472        };
473
474        thread.update(cx, |thread, _cx| {
475            thread.selected_model = model.clone();
476        });
477
478        update_settings_file::<AgentSettings>(
479            self.0.read(cx).fs.clone(),
480            cx,
481            move |settings, _cx| {
482                settings.set_model(model);
483            },
484        );
485
486        Task::ready(Ok(()))
487    }
488
489    fn selected_model(
490        &self,
491        session_id: &acp::SessionId,
492        cx: &mut App,
493    ) -> Task<Result<acp_thread::AgentModelInfo>> {
494        let session_id = session_id.clone();
495
496        let Some(thread) = self
497            .0
498            .read(cx)
499            .sessions
500            .get(&session_id)
501            .map(|session| session.thread.clone())
502        else {
503            return Task::ready(Err(anyhow!("Session not found")));
504        };
505        let model = thread.read(cx).selected_model.clone();
506        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
507        else {
508            return Task::ready(Err(anyhow!("Provider not found")));
509        };
510        Task::ready(Ok(LanguageModels::map_language_model_to_info(
511            &model, &provider,
512        )))
513    }
514
515    fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
516        self.0.read(cx).models.watch()
517    }
518}
519
520impl acp_thread::AgentConnection for NativeAgentConnection {
    /// Creates a new session: an `AcpThread` for protocol communication plus an
    /// internal `Thread` configured with the default model and the built-in tools.
    ///
    /// The session is stored in the agent under a freshly generated UUID and is
    /// removed again when the `AcpThread` is released.
    ///
    /// # Errors
    ///
    /// Fails if no default model is configured in the registry, if the default
    /// model is not present in the agent's model cache, or if an app/entity
    /// update fails.
    fn new_thread(
        self: Rc<Self>,
        project: Entity<Project>,
        cwd: &Path,
        cx: &mut App,
    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
        let agent = self.0.clone();
        log::info!("Creating new thread for project at: {:?}", cwd);

        cx.spawn(async move |cx| {
            log::debug!("Starting thread creation in async context");

            // Generate session ID
            let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
            log::info!("Created session with ID: {}", session_id);

            // Create AcpThread
            let acp_thread = cx.update(|cx| {
                cx.new(|cx| {
                    acp_thread::AcpThread::new(
                        "agent2",
                        self.clone(),
                        project.clone(),
                        session_id.clone(),
                        cx,
                    )
                })
            })?;
            // The internal thread and its tools share the ACP thread's action log.
            let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;

            // Create Thread
            let thread = agent.update(
                cx,
                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
                    // Fetch default model from registry settings
                    let registry = LanguageModelRegistry::read_global(cx);

                    // Log available models for debugging
                    let available_count = registry.available_models(cx).count();
                    log::debug!("Total available models: {}", available_count);

                    // Resolve the default through the agent's own cache, so a default
                    // that is missing from the cache is treated like no default at all.
                    let default_model = registry
                        .default_model()
                        .and_then(|default_model| {
                            agent
                                .models
                                .model_from_id(&LanguageModels::model_id(&default_model.model))
                        })
                        .ok_or_else(|| {
                            log::warn!("No default model configured in settings");
                            anyhow!(
                                "No default model. Please configure a default model in settings."
                            )
                        })?;

                    let thread = cx.new(|cx| {
                        let mut thread = Thread::new(
                            project.clone(),
                            agent.project_context.clone(),
                            agent.context_server_registry.clone(),
                            action_log.clone(),
                            agent.templates.clone(),
                            default_model,
                            cx,
                        );
                        // Register the built-in tools.
                        thread.add_tool(CopyPathTool::new(project.clone()));
                        thread.add_tool(CreateDirectoryTool::new(project.clone()));
                        thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
                        thread.add_tool(DiagnosticsTool::new(project.clone()));
                        thread.add_tool(EditFileTool::new(cx.entity()));
                        thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
                        thread.add_tool(FindPathTool::new(project.clone()));
                        thread.add_tool(GrepTool::new(project.clone()));
                        thread.add_tool(ListDirectoryTool::new(project.clone()));
                        thread.add_tool(MovePathTool::new(project.clone()));
                        thread.add_tool(NowTool);
                        thread.add_tool(OpenTool::new(project.clone()));
                        thread.add_tool(ReadFileTool::new(project.clone(), action_log));
                        thread.add_tool(TerminalTool::new(project.clone(), cx));
                        thread.add_tool(ThinkingTool);
                        thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
                        thread
                    });

                    Ok(thread)
                },
            )??;

            // Store the session
            agent.update(cx, |agent, cx| {
                agent.sessions.insert(
                    session_id,
                    Session {
                        thread,
                        acp_thread: acp_thread.downgrade(),
                        // Drop the session entry once the ACP thread is released.
                        _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
                            this.sessions.remove(acp_thread.session_id());
                        }),
                    },
                );
            })?;

            Ok(acp_thread)
        })
    }
626
    /// Returns the supported authentication methods; always empty for this connection.
    fn auth_methods(&self) -> &[acp::AuthMethod] {
        &[] // No auth for in-process
    }
630
    /// No-op: ignores the requested method and resolves immediately with `Ok(())`.
    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
634
635    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
636        Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
637    }
638
639    fn prompt(
640        &self,
641        id: Option<acp_thread::UserMessageId>,
642        params: acp::PromptRequest,
643        cx: &mut App,
644    ) -> Task<Result<acp::PromptResponse>> {
645        let id = id.expect("UserMessageId is required");
646        let session_id = params.session_id.clone();
647        let agent = self.0.clone();
648        log::info!("Received prompt request for session: {}", session_id);
649        log::debug!("Prompt blocks count: {}", params.prompt.len());
650
651        cx.spawn(async move |cx| {
652            // Get session
653            let (thread, acp_thread) = agent
654                .update(cx, |agent, _| {
655                    agent
656                        .sessions
657                        .get_mut(&session_id)
658                        .map(|s| (s.thread.clone(), s.acp_thread.clone()))
659                })?
660                .ok_or_else(|| {
661                    log::error!("Session not found: {}", session_id);
662                    anyhow::anyhow!("Session not found")
663                })?;
664            log::debug!("Found session for: {}", session_id);
665
666            let content: Vec<UserMessageContent> = params
667                .prompt
668                .into_iter()
669                .map(Into::into)
670                .collect::<Vec<_>>();
671            log::info!("Converted prompt to message: {} chars", content.len());
672            log::debug!("Message id: {:?}", id);
673            log::debug!("Message content: {:?}", content);
674
675            // Get model using the ModelSelector capability (always available for agent2)
676            // Get the selected model from the thread directly
677            let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
678
679            // Send to thread
680            log::info!("Sending message to thread with model: {:?}", model.name());
681            let mut response_stream =
682                thread.update(cx, |thread, cx| thread.send(id, content, cx))?;
683
684            // Handle response stream and forward to session.acp_thread
685            while let Some(result) = response_stream.next().await {
686                match result {
687                    Ok(event) => {
688                        log::trace!("Received completion event: {:?}", event);
689
690                        match event {
691                            AgentResponseEvent::Text(text) => {
692                                acp_thread.update(cx, |thread, cx| {
693                                    thread.push_assistant_content_block(
694                                        acp::ContentBlock::Text(acp::TextContent {
695                                            text,
696                                            annotations: None,
697                                        }),
698                                        false,
699                                        cx,
700                                    )
701                                })?;
702                            }
703                            AgentResponseEvent::Thinking(text) => {
704                                acp_thread.update(cx, |thread, cx| {
705                                    thread.push_assistant_content_block(
706                                        acp::ContentBlock::Text(acp::TextContent {
707                                            text,
708                                            annotations: None,
709                                        }),
710                                        true,
711                                        cx,
712                                    )
713                                })?;
714                            }
715                            AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
716                                tool_call,
717                                options,
718                                response,
719                            }) => {
720                                let recv = acp_thread.update(cx, |thread, cx| {
721                                    thread.request_tool_call_authorization(tool_call, options, cx)
722                                })?;
723                                cx.background_spawn(async move {
724                                    if let Some(option) = recv
725                                        .await
726                                        .context("authorization sender was dropped")
727                                        .log_err()
728                                    {
729                                        response
730                                            .send(option)
731                                            .map(|_| anyhow!("authorization receiver was dropped"))
732                                            .log_err();
733                                    }
734                                })
735                                .detach();
736                            }
737                            AgentResponseEvent::ToolCall(tool_call) => {
738                                acp_thread.update(cx, |thread, cx| {
739                                    thread.upsert_tool_call(tool_call, cx)
740                                })?;
741                            }
742                            AgentResponseEvent::ToolCallUpdate(update) => {
743                                acp_thread.update(cx, |thread, cx| {
744                                    thread.update_tool_call(update, cx)
745                                })??;
746                            }
747                            AgentResponseEvent::Stop(stop_reason) => {
748                                log::debug!("Assistant message complete: {:?}", stop_reason);
749                                return Ok(acp::PromptResponse { stop_reason });
750                            }
751                        }
752                    }
753                    Err(e) => {
754                        log::error!("Error in model response stream: {:?}", e);
755                        // TODO: Consider sending an error message to the UI
756                        break;
757                    }
758                }
759            }
760
761            log::info!("Response stream completed");
762            anyhow::Ok(acp::PromptResponse {
763                stop_reason: acp::StopReason::EndTurn,
764            })
765        })
766    }
767
768    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
769        log::info!("Cancelling on session: {}", session_id);
770        self.0.update(cx, |agent, cx| {
771            if let Some(agent) = agent.sessions.get(session_id) {
772                agent.thread.update(cx, |thread, _cx| thread.cancel());
773            }
774        });
775    }
776
777    fn session_editor(
778        &self,
779        session_id: &agent_client_protocol::SessionId,
780        cx: &mut App,
781    ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
782        self.0.update(cx, |agent, _cx| {
783            agent
784                .sessions
785                .get(session_id)
786                .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
787        })
788    }
789}
790
791struct NativeAgentSessionEditor(Entity<Thread>);
792
793impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
794    fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
795        Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
796    }
797}
798
799#[cfg(test)]
800mod tests {
801    use super::*;
802    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
803    use fs::FakeFs;
804    use gpui::TestAppContext;
805    use serde_json::json;
806    use settings::SettingsStore;
807
808    #[gpui::test]
809    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
810        init_test(cx);
811        let fs = FakeFs::new(cx.executor());
812        fs.insert_tree(
813            "/",
814            json!({
815                "a": {}
816            }),
817        )
818        .await;
819        let project = Project::test(fs.clone(), [], cx).await;
820        let agent = NativeAgent::new(
821            project.clone(),
822            Templates::new(),
823            None,
824            fs.clone(),
825            &mut cx.to_async(),
826        )
827        .await
828        .unwrap();
829        agent.read_with(cx, |agent, _| {
830            assert_eq!(agent.project_context.borrow().worktrees, vec![])
831        });
832
833        let worktree = project
834            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
835            .await
836            .unwrap();
837        cx.run_until_parked();
838        agent.read_with(cx, |agent, _| {
839            assert_eq!(
840                agent.project_context.borrow().worktrees,
841                vec![WorktreeContext {
842                    root_name: "a".into(),
843                    abs_path: Path::new("/a").into(),
844                    rules_file: None
845                }]
846            )
847        });
848
849        // Creating `/a/.rules` updates the project context.
850        fs.insert_file("/a/.rules", Vec::new()).await;
851        cx.run_until_parked();
852        agent.read_with(cx, |agent, cx| {
853            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
854            assert_eq!(
855                agent.project_context.borrow().worktrees,
856                vec![WorktreeContext {
857                    root_name: "a".into(),
858                    abs_path: Path::new("/a").into(),
859                    rules_file: Some(RulesFileContext {
860                        path_in_worktree: Path::new(".rules").into(),
861                        text: "".into(),
862                        project_entry_id: rules_entry.id.to_usize()
863                    })
864                }]
865            )
866        });
867    }
868
869    #[gpui::test]
870    async fn test_listing_models(cx: &mut TestAppContext) {
871        init_test(cx);
872        let fs = FakeFs::new(cx.executor());
873        fs.insert_tree("/", json!({ "a": {}  })).await;
874        let project = Project::test(fs.clone(), [], cx).await;
875        let connection = NativeAgentConnection(
876            NativeAgent::new(
877                project.clone(),
878                Templates::new(),
879                None,
880                fs.clone(),
881                &mut cx.to_async(),
882            )
883            .await
884            .unwrap(),
885        );
886
887        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();
888
889        let acp_thread::AgentModelList::Grouped(models) = models else {
890            panic!("Unexpected model group");
891        };
892        assert_eq!(
893            models,
894            IndexMap::from_iter([(
895                AgentModelGroupName("Fake".into()),
896                vec![AgentModelInfo {
897                    id: AgentModelId("fake/fake".into()),
898                    name: "Fake".into(),
899                    icon: Some(ui::IconName::ZedAssistant),
900                }]
901            )])
902        );
903    }
904
905    #[gpui::test]
906    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
907        init_test(cx);
908        let fs = FakeFs::new(cx.executor());
909        fs.create_dir(paths::settings_file().parent().unwrap())
910            .await
911            .unwrap();
912        fs.insert_file(
913            paths::settings_file(),
914            json!({
915                "agent": {
916                    "default_model": {
917                        "provider": "foo",
918                        "model": "bar"
919                    }
920                }
921            })
922            .to_string()
923            .into_bytes(),
924        )
925        .await;
926        let project = Project::test(fs.clone(), [], cx).await;
927
928        // Create the agent and connection
929        let agent = NativeAgent::new(
930            project.clone(),
931            Templates::new(),
932            None,
933            fs.clone(),
934            &mut cx.to_async(),
935        )
936        .await
937        .unwrap();
938        let connection = NativeAgentConnection(agent.clone());
939
940        // Create a thread/session
941        let acp_thread = cx
942            .update(|cx| {
943                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
944            })
945            .await
946            .unwrap();
947
948        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
949
950        // Select a model
951        let model_id = AgentModelId("fake/fake".into());
952        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
953            .await
954            .unwrap();
955
956        // Verify the thread has the selected model
957        agent.read_with(cx, |agent, _| {
958            let session = agent.sessions.get(&session_id).unwrap();
959            session.thread.read_with(cx, |thread, _| {
960                assert_eq!(thread.selected_model.id().0, "fake");
961            });
962        });
963
964        cx.run_until_parked();
965
966        // Verify settings file was updated
967        let settings_content = fs.load(paths::settings_file()).await.unwrap();
968        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
969
970        // Check that the agent settings contain the selected model
971        assert_eq!(
972            settings_json["agent"]["default_model"]["model"],
973            json!("fake")
974        );
975        assert_eq!(
976            settings_json["agent"]["default_model"]["provider"],
977            json!("fake")
978        );
979    }
980
981    fn init_test(cx: &mut TestAppContext) {
982        env_logger::try_init().ok();
983        cx.update(|cx| {
984            let settings_store = SettingsStore::test(cx);
985            cx.set_global(settings_store);
986            Project::init_settings(cx);
987            agent_settings::init(cx);
988            language::init(cx);
989            LanguageModelRegistry::test(cx);
990        });
991    }
992}