agent.rs

  1use crate::{AgentResponseEvent, Thread, templates::Templates};
  2use crate::{
  3    ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool,
  4    FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MessageContent, MovePathTool, NowTool,
  5    OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool,
  6};
  7use acp_thread::AgentModelSelector;
  8use agent_client_protocol as acp;
  9use agent_settings::AgentSettings;
 10use anyhow::{Context as _, Result, anyhow};
 11use collections::{HashSet, IndexMap};
 12use fs::Fs;
 13use futures::{StreamExt, future};
 14use gpui::{
 15    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
 16};
 17use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
 18use project::{Project, ProjectItem, ProjectPath, Worktree};
 19use prompt_store::{
 20    ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
 21};
 22use settings::update_settings_file;
 23use std::cell::RefCell;
 24use std::collections::HashMap;
 25use std::path::Path;
 26use std::rc::Rc;
 27use std::sync::Arc;
 28use util::ResultExt;
 29
 30const RULES_FILE_NAMES: [&'static str; 9] = [
 31    ".rules",
 32    ".cursorrules",
 33    ".windsurfrules",
 34    ".clinerules",
 35    ".github/copilot-instructions.md",
 36    "CLAUDE.md",
 37    "AGENT.md",
 38    "AGENTS.md",
 39    "GEMINI.md",
 40];
 41
 42pub struct RulesLoadingError {
 43    pub message: SharedString,
 44}
 45
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    /// Stored as a weak handle (created with `downgrade()` in `new_thread`), so the session
    /// itself does not keep the ACP thread alive.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Release observer on the ACP thread; when it fires, this session is removed from
    /// `NativeAgent::sessions` (registered in `NativeAgentConnection::new_thread`).
    _subscription: Subscription,
}
 54
/// Language models offered by the currently authenticated providers.
///
/// Models are indexed by their provider-qualified [`acp_thread::AgentModelId`] and also kept
/// as a grouped list (a "Recommended" group followed by one group per provider).
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Kept so that `watch()` can hand out clones of a receiver on this channel.
    refresh_models_rx: watch::Receiver<()>,
    /// Sent on at the end of every `refresh_list` call to notify watchers.
    refresh_models_tx: watch::Sender<()>,
}
 63
 64impl LanguageModels {
 65    fn new(cx: &App) -> Self {
 66        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
 67        let mut this = Self {
 68            models: HashMap::default(),
 69            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
 70            refresh_models_rx,
 71            refresh_models_tx,
 72        };
 73        this.refresh_list(cx);
 74        this
 75    }
 76
 77    fn refresh_list(&mut self, cx: &App) {
 78        let providers = LanguageModelRegistry::global(cx)
 79            .read(cx)
 80            .providers()
 81            .into_iter()
 82            .filter(|provider| provider.is_authenticated(cx))
 83            .collect::<Vec<_>>();
 84
 85        let mut language_model_list = IndexMap::default();
 86        let mut recommended_models = HashSet::default();
 87
 88        let mut recommended = Vec::new();
 89        for provider in &providers {
 90            for model in provider.recommended_models(cx) {
 91                recommended_models.insert(model.id());
 92                recommended.push(Self::map_language_model_to_info(&model, &provider));
 93            }
 94        }
 95        if !recommended.is_empty() {
 96            language_model_list.insert(
 97                acp_thread::AgentModelGroupName("Recommended".into()),
 98                recommended,
 99            );
100        }
101
102        let mut models = HashMap::default();
103        for provider in providers {
104            let mut provider_models = Vec::new();
105            for model in provider.provided_models(cx) {
106                let model_info = Self::map_language_model_to_info(&model, &provider);
107                let model_id = model_info.id.clone();
108                if !recommended_models.contains(&model.id()) {
109                    provider_models.push(model_info);
110                }
111                models.insert(model_id, model);
112            }
113            if !provider_models.is_empty() {
114                language_model_list.insert(
115                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
116                    provider_models,
117                );
118            }
119        }
120
121        self.models = models;
122        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
123        self.refresh_models_tx.send(()).ok();
124    }
125
126    fn watch(&self) -> watch::Receiver<()> {
127        self.refresh_models_rx.clone()
128    }
129
130    pub fn model_from_id(
131        &self,
132        model_id: &acp_thread::AgentModelId,
133    ) -> Option<Arc<dyn LanguageModel>> {
134        self.models.get(model_id).cloned()
135    }
136
137    fn map_language_model_to_info(
138        model: &Arc<dyn LanguageModel>,
139        provider: &Arc<dyn LanguageModelProvider>,
140    ) -> acp_thread::AgentModelInfo {
141        acp_thread::AgentModelInfo {
142            id: Self::model_id(model),
143            name: model.name().0,
144            icon: Some(provider.icon()),
145        }
146    }
147
148    fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
149        acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
150    }
151}
152
/// In-process agent for a single project.
///
/// Owns every native session and the state they share: the project context, templates,
/// context-server registry, and the cached language-model list.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Shared project context for all threads
    project_context: Rc<RefCell<ProjectContext>>,
    /// Sending on this asks the `_maintain_project_context` task to rebuild `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task that rebuilds `project_context` each time a refresh is requested.
    _maintain_project_context: Task<Result<()>>,
    /// Context servers shared by every thread, built from the project's context server store.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project whose worktrees (and their rules files) feed the project context.
    project: Entity<Project>,
    /// Source of the default user rules included in the project context, if any.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem used to persist the selected model to the settings file.
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, language-model-registry and (optional) prompt-store events.
    _subscriptions: Vec<Subscription>,
}
170
171impl NativeAgent {
172    pub async fn new(
173        project: Entity<Project>,
174        templates: Arc<Templates>,
175        prompt_store: Option<Entity<PromptStore>>,
176        fs: Arc<dyn Fs>,
177        cx: &mut AsyncApp,
178    ) -> Result<Entity<NativeAgent>> {
179        log::info!("Creating new NativeAgent");
180
181        let project_context = cx
182            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
183            .await;
184
185        cx.new(|cx| {
186            let mut subscriptions = vec![
187                cx.subscribe(&project, Self::handle_project_event),
188                cx.subscribe(
189                    &LanguageModelRegistry::global(cx),
190                    Self::handle_models_updated_event,
191                ),
192            ];
193            if let Some(prompt_store) = prompt_store.as_ref() {
194                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
195            }
196
197            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
198                watch::channel(());
199            Self {
200                sessions: HashMap::new(),
201                project_context: Rc::new(RefCell::new(project_context)),
202                project_context_needs_refresh: project_context_needs_refresh_tx,
203                _maintain_project_context: cx.spawn(async move |this, cx| {
204                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
205                }),
206                context_server_registry: cx.new(|cx| {
207                    ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
208                }),
209                templates,
210                models: LanguageModels::new(cx),
211                project,
212                prompt_store,
213                fs,
214                _subscriptions: subscriptions,
215            }
216        })
217    }
218
219    pub fn models(&self) -> &LanguageModels {
220        &self.models
221    }
222
223    async fn maintain_project_context(
224        this: WeakEntity<Self>,
225        mut needs_refresh: watch::Receiver<()>,
226        cx: &mut AsyncApp,
227    ) -> Result<()> {
228        while needs_refresh.changed().await.is_ok() {
229            let project_context = this
230                .update(cx, |this, cx| {
231                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
232                })?
233                .await;
234            this.update(cx, |this, _| this.project_context.replace(project_context))?;
235        }
236
237        Ok(())
238    }
239
240    fn build_project_context(
241        project: &Entity<Project>,
242        prompt_store: Option<&Entity<PromptStore>>,
243        cx: &mut App,
244    ) -> Task<ProjectContext> {
245        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
246        let worktree_tasks = worktrees
247            .into_iter()
248            .map(|worktree| {
249                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
250            })
251            .collect::<Vec<_>>();
252        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
253            prompt_store.read_with(cx, |prompt_store, cx| {
254                let prompts = prompt_store.default_prompt_metadata();
255                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
256                    let contents = prompt_store.load(prompt_metadata.id, cx);
257                    async move { (contents.await, prompt_metadata) }
258                });
259                cx.background_spawn(future::join_all(load_tasks))
260            })
261        } else {
262            Task::ready(vec![])
263        };
264
265        cx.spawn(async move |_cx| {
266            let (worktrees, default_user_rules) =
267                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
268
269            let worktrees = worktrees
270                .into_iter()
271                .map(|(worktree, _rules_error)| {
272                    // TODO: show error message
273                    // if let Some(rules_error) = rules_error {
274                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
275                    // }
276                    worktree
277                })
278                .collect::<Vec<_>>();
279
280            let default_user_rules = default_user_rules
281                .into_iter()
282                .flat_map(|(contents, prompt_metadata)| match contents {
283                    Ok(contents) => Some(UserRulesContext {
284                        uuid: match prompt_metadata.id {
285                            PromptId::User { uuid } => uuid,
286                            PromptId::EditWorkflow => return None,
287                        },
288                        title: prompt_metadata.title.map(|title| title.to_string()),
289                        contents,
290                    }),
291                    Err(_err) => {
292                        // TODO: show error message
293                        // this.update(cx, |_, cx| {
294                        //     cx.emit(RulesLoadingError {
295                        //         message: format!("{err:?}").into(),
296                        //     });
297                        // })
298                        // .ok();
299                        None
300                    }
301                })
302                .collect::<Vec<_>>();
303
304            ProjectContext::new(worktrees, default_user_rules)
305        })
306    }
307
308    fn load_worktree_info_for_system_prompt(
309        worktree: Entity<Worktree>,
310        project: Entity<Project>,
311        cx: &mut App,
312    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
313        let tree = worktree.read(cx);
314        let root_name = tree.root_name().into();
315        let abs_path = tree.abs_path();
316
317        let mut context = WorktreeContext {
318            root_name,
319            abs_path,
320            rules_file: None,
321        };
322
323        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
324        let Some(rules_task) = rules_task else {
325            return Task::ready((context, None));
326        };
327
328        cx.spawn(async move |_| {
329            let (rules_file, rules_file_error) = match rules_task.await {
330                Ok(rules_file) => (Some(rules_file), None),
331                Err(err) => (
332                    None,
333                    Some(RulesLoadingError {
334                        message: format!("{err}").into(),
335                    }),
336                ),
337            };
338            context.rules_file = rules_file;
339            (context, rules_file_error)
340        })
341    }
342
343    fn load_worktree_rules_file(
344        worktree: Entity<Worktree>,
345        project: Entity<Project>,
346        cx: &mut App,
347    ) -> Option<Task<Result<RulesFileContext>>> {
348        let worktree = worktree.read(cx);
349        let worktree_id = worktree.id();
350        let selected_rules_file = RULES_FILE_NAMES
351            .into_iter()
352            .filter_map(|name| {
353                worktree
354                    .entry_for_path(name)
355                    .filter(|entry| entry.is_file())
356                    .map(|entry| entry.path.clone())
357            })
358            .next();
359
360        // Note that Cline supports `.clinerules` being a directory, but that is not currently
361        // supported. This doesn't seem to occur often in GitHub repositories.
362        selected_rules_file.map(|path_in_worktree| {
363            let project_path = ProjectPath {
364                worktree_id,
365                path: path_in_worktree.clone(),
366            };
367            let buffer_task =
368                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
369            let rope_task = cx.spawn(async move |cx| {
370                buffer_task.await?.read_with(cx, |buffer, cx| {
371                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
372                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
373                })?
374            });
375            // Build a string from the rope on a background thread.
376            cx.background_spawn(async move {
377                let (project_entry_id, rope) = rope_task.await?;
378                anyhow::Ok(RulesFileContext {
379                    path_in_worktree,
380                    text: rope.to_string().trim().to_string(),
381                    project_entry_id: project_entry_id.to_usize(),
382                })
383            })
384        })
385    }
386
387    fn handle_project_event(
388        &mut self,
389        _project: Entity<Project>,
390        event: &project::Event,
391        _cx: &mut Context<Self>,
392    ) {
393        match event {
394            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
395                self.project_context_needs_refresh.send(()).ok();
396            }
397            project::Event::WorktreeUpdatedEntries(_, items) => {
398                if items.iter().any(|(path, _, _)| {
399                    RULES_FILE_NAMES
400                        .iter()
401                        .any(|name| path.as_ref() == Path::new(name))
402                }) {
403                    self.project_context_needs_refresh.send(()).ok();
404                }
405            }
406            _ => {}
407        }
408    }
409
410    fn handle_prompts_updated_event(
411        &mut self,
412        _prompt_store: Entity<PromptStore>,
413        _event: &prompt_store::PromptsUpdatedEvent,
414        _cx: &mut Context<Self>,
415    ) {
416        self.project_context_needs_refresh.send(()).ok();
417    }
418
419    fn handle_models_updated_event(
420        &mut self,
421        _registry: Entity<LanguageModelRegistry>,
422        _event: &language_model::Event,
423        cx: &mut Context<Self>,
424    ) {
425        self.models.refresh_list(cx);
426        for session in self.sessions.values_mut() {
427            session.thread.update(cx, |thread, _| {
428                let model_id = LanguageModels::model_id(&thread.selected_model);
429                if let Some(model) = self.models.model_from_id(&model_id) {
430                    thread.selected_model = model.clone();
431                }
432            });
433        }
434    }
435}
436
/// Wrapper struct that implements the [`acp_thread::AgentConnection`] and
/// [`AgentModelSelector`] traits on top of a shared [`NativeAgent`] entity.
///
/// Clones share the same underlying agent entity.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
440
441impl AgentModelSelector for NativeAgentConnection {
442    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
443        log::debug!("NativeAgentConnection::list_models called");
444        let list = self.0.read(cx).models.model_list.clone();
445        Task::ready(if list.is_empty() {
446            Err(anyhow::anyhow!("No models available"))
447        } else {
448            Ok(list)
449        })
450    }
451
452    fn select_model(
453        &self,
454        session_id: acp::SessionId,
455        model_id: acp_thread::AgentModelId,
456        cx: &mut App,
457    ) -> Task<Result<()>> {
458        log::info!("Setting model for session {}: {}", session_id, model_id);
459        let Some(thread) = self
460            .0
461            .read(cx)
462            .sessions
463            .get(&session_id)
464            .map(|session| session.thread.clone())
465        else {
466            return Task::ready(Err(anyhow!("Session not found")));
467        };
468
469        let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
470            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
471        };
472
473        thread.update(cx, |thread, _cx| {
474            thread.selected_model = model.clone();
475        });
476
477        update_settings_file::<AgentSettings>(
478            self.0.read(cx).fs.clone(),
479            cx,
480            move |settings, _cx| {
481                settings.set_model(model);
482            },
483        );
484
485        Task::ready(Ok(()))
486    }
487
488    fn selected_model(
489        &self,
490        session_id: &acp::SessionId,
491        cx: &mut App,
492    ) -> Task<Result<acp_thread::AgentModelInfo>> {
493        let session_id = session_id.clone();
494
495        let Some(thread) = self
496            .0
497            .read(cx)
498            .sessions
499            .get(&session_id)
500            .map(|session| session.thread.clone())
501        else {
502            return Task::ready(Err(anyhow!("Session not found")));
503        };
504        let model = thread.read(cx).selected_model.clone();
505        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
506        else {
507            return Task::ready(Err(anyhow!("Provider not found")));
508        };
509        Task::ready(Ok(LanguageModels::map_language_model_to_info(
510            &model, &provider,
511        )))
512    }
513
514    fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
515        self.0.read(cx).models.watch()
516    }
517}
518
519impl acp_thread::AgentConnection for NativeAgentConnection {
520    fn new_thread(
521        self: Rc<Self>,
522        project: Entity<Project>,
523        cwd: &Path,
524        cx: &mut AsyncApp,
525    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
526        let agent = self.0.clone();
527        log::info!("Creating new thread for project at: {:?}", cwd);
528
529        cx.spawn(async move |cx| {
530            log::debug!("Starting thread creation in async context");
531
532            // Generate session ID
533            let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
534            log::info!("Created session with ID: {}", session_id);
535
536            // Create AcpThread
537            let acp_thread = cx.update(|cx| {
538                cx.new(|cx| {
539                    acp_thread::AcpThread::new(
540                        "agent2",
541                        self.clone(),
542                        project.clone(),
543                        session_id.clone(),
544                        cx,
545                    )
546                })
547            })?;
548            let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
549
550            // Create Thread
551            let thread = agent.update(
552                cx,
553                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
554                    // Fetch default model from registry settings
555                    let registry = LanguageModelRegistry::read_global(cx);
556
557                    // Log available models for debugging
558                    let available_count = registry.available_models(cx).count();
559                    log::debug!("Total available models: {}", available_count);
560
561                    let default_model = registry
562                        .default_model()
563                        .and_then(|default_model| {
564                            agent
565                                .models
566                                .model_from_id(&LanguageModels::model_id(&default_model.model))
567                        })
568                        .ok_or_else(|| {
569                            log::warn!("No default model configured in settings");
570                            anyhow!(
571                                "No default model. Please configure a default model in settings."
572                            )
573                        })?;
574
575                    let thread = cx.new(|cx| {
576                        let mut thread = Thread::new(
577                            project.clone(),
578                            agent.project_context.clone(),
579                            agent.context_server_registry.clone(),
580                            action_log.clone(),
581                            agent.templates.clone(),
582                            default_model,
583                            cx,
584                        );
585                        thread.add_tool(CreateDirectoryTool::new(project.clone()));
586                        thread.add_tool(CopyPathTool::new(project.clone()));
587                        thread.add_tool(DiagnosticsTool::new(project.clone()));
588                        thread.add_tool(MovePathTool::new(project.clone()));
589                        thread.add_tool(ListDirectoryTool::new(project.clone()));
590                        thread.add_tool(OpenTool::new(project.clone()));
591                        thread.add_tool(ThinkingTool);
592                        thread.add_tool(FindPathTool::new(project.clone()));
593                        thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
594                        thread.add_tool(GrepTool::new(project.clone()));
595                        thread.add_tool(ReadFileTool::new(project.clone(), action_log));
596                        thread.add_tool(EditFileTool::new(cx.entity()));
597                        thread.add_tool(NowTool);
598                        thread.add_tool(TerminalTool::new(project.clone(), cx));
599                        // TODO: Needs to be conditional based on zed model or not
600                        thread.add_tool(WebSearchTool);
601                        thread
602                    });
603
604                    Ok(thread)
605                },
606            )??;
607
608            // Store the session
609            agent.update(cx, |agent, cx| {
610                agent.sessions.insert(
611                    session_id,
612                    Session {
613                        thread,
614                        acp_thread: acp_thread.downgrade(),
615                        _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
616                            this.sessions.remove(acp_thread.session_id());
617                        }),
618                    },
619                );
620            })?;
621
622            Ok(acp_thread)
623        })
624    }
625
626    fn auth_methods(&self) -> &[acp::AuthMethod] {
627        &[] // No auth for in-process
628    }
629
630    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
631        Task::ready(Ok(()))
632    }
633
634    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
635        Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
636    }
637
638    fn prompt(
639        &self,
640        params: acp::PromptRequest,
641        cx: &mut App,
642    ) -> Task<Result<acp::PromptResponse>> {
643        let session_id = params.session_id.clone();
644        let agent = self.0.clone();
645        log::info!("Received prompt request for session: {}", session_id);
646        log::debug!("Prompt blocks count: {}", params.prompt.len());
647
648        cx.spawn(async move |cx| {
649            // Get session
650            let (thread, acp_thread) = agent
651                .update(cx, |agent, _| {
652                    agent
653                        .sessions
654                        .get_mut(&session_id)
655                        .map(|s| (s.thread.clone(), s.acp_thread.clone()))
656                })?
657                .ok_or_else(|| {
658                    log::error!("Session not found: {}", session_id);
659                    anyhow::anyhow!("Session not found")
660                })?;
661            log::debug!("Found session for: {}", session_id);
662
663            let message: Vec<MessageContent> = params
664                .prompt
665                .into_iter()
666                .map(Into::into)
667                .collect::<Vec<_>>();
668            log::info!("Converted prompt to message: {} chars", message.len());
669            log::debug!("Message content: {:?}", message);
670
671            // Get model using the ModelSelector capability (always available for agent2)
672            // Get the selected model from the thread directly
673            let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
674
675            // Send to thread
676            log::info!("Sending message to thread with model: {:?}", model.name());
677            let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?;
678
679            // Handle response stream and forward to session.acp_thread
680            while let Some(result) = response_stream.next().await {
681                match result {
682                    Ok(event) => {
683                        log::trace!("Received completion event: {:?}", event);
684
685                        match event {
686                            AgentResponseEvent::Text(text) => {
687                                acp_thread.update(cx, |thread, cx| {
688                                    thread.push_assistant_content_block(
689                                        acp::ContentBlock::Text(acp::TextContent {
690                                            text,
691                                            annotations: None,
692                                        }),
693                                        false,
694                                        cx,
695                                    )
696                                })?;
697                            }
698                            AgentResponseEvent::Thinking(text) => {
699                                acp_thread.update(cx, |thread, cx| {
700                                    thread.push_assistant_content_block(
701                                        acp::ContentBlock::Text(acp::TextContent {
702                                            text,
703                                            annotations: None,
704                                        }),
705                                        true,
706                                        cx,
707                                    )
708                                })?;
709                            }
710                            AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
711                                tool_call,
712                                options,
713                                response,
714                            }) => {
715                                let recv = acp_thread.update(cx, |thread, cx| {
716                                    thread.request_tool_call_authorization(tool_call, options, cx)
717                                })?;
718                                cx.background_spawn(async move {
719                                    if let Some(option) = recv
720                                        .await
721                                        .context("authorization sender was dropped")
722                                        .log_err()
723                                    {
724                                        response
725                                            .send(option)
726                                            .map(|_| anyhow!("authorization receiver was dropped"))
727                                            .log_err();
728                                    }
729                                })
730                                .detach();
731                            }
732                            AgentResponseEvent::ToolCall(tool_call) => {
733                                acp_thread.update(cx, |thread, cx| {
734                                    thread.upsert_tool_call(tool_call, cx)
735                                })?;
736                            }
737                            AgentResponseEvent::ToolCallUpdate(update) => {
738                                acp_thread.update(cx, |thread, cx| {
739                                    thread.update_tool_call(update, cx)
740                                })??;
741                            }
742                            AgentResponseEvent::Stop(stop_reason) => {
743                                log::debug!("Assistant message complete: {:?}", stop_reason);
744                                return Ok(acp::PromptResponse { stop_reason });
745                            }
746                        }
747                    }
748                    Err(e) => {
749                        log::error!("Error in model response stream: {:?}", e);
750                        // TODO: Consider sending an error message to the UI
751                        break;
752                    }
753                }
754            }
755
756            log::info!("Response stream completed");
757            anyhow::Ok(acp::PromptResponse {
758                stop_reason: acp::StopReason::EndTurn,
759            })
760        })
761    }
762
763    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
764        log::info!("Cancelling on session: {}", session_id);
765        self.0.update(cx, |agent, cx| {
766            if let Some(agent) = agent.sessions.get(session_id) {
767                agent.thread.update(cx, |thread, _cx| thread.cancel());
768            }
769        });
770    }
771}
772
#[cfg(test)]
mod tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// Checks that `NativeAgent` keeps `project_context.worktrees` in sync with the
    /// project. The list starts empty and gains an entry when a worktree is added.
    /// That entry then picks up a `.rules` file created inside the worktree afterwards.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project was created without any worktrees, so the context is empty.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        // Adding a worktree (which has no rules file yet) adds a context entry.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// Checks that `list_models` returns a grouped list. With the test language
    /// model registry, that list is a single `Fake` group containing `fake/fake`.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Checks that selecting a model for a session does two things. It switches the
    /// session's thread to that model. It also overwrites `agent.default_model` in the
    /// user settings file, which starts out pointing at provider `foo` / model `bar`.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(
                    project.clone(),
                    Path::new("/a"),
                    &mut cx.to_async(),
                )
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model. Use the `&App` handed to the
        // callback rather than re-entering the outer test context from inside it.
        agent.read_with(cx, |agent, cx| {
            let session = agent
                .sessions
                .get(&session_id)
                .expect("session should be registered after new_thread");
            let thread = session.thread.read(cx);
            assert_eq!(thread.selected_model.id().0, "fake");
        });

        // Let pending background work (presumably the settings-file write) finish
        // before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// Shared test setup. It installs a test `SettingsStore` as a global and
    /// initializes the project, agent, and language settings. It also installs the
    /// test `LanguageModelRegistry`.
    fn init_test(cx: &mut TestAppContext) {
        // Ignore the error when another test in the same process already set up logging.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}