agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17pub use native_agent_server::NativeAgentServer;
  18pub use pattern_extraction::*;
  19pub use shell_command_parser::extract_commands;
  20pub use templates::*;
  21pub use thread::*;
  22pub use thread_store::*;
  23pub use tool_permissions::*;
  24pub use tools::*;
  25
  26use acp_thread::{
  27    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  28    AgentSessionListResponse, UserMessageId,
  29};
  30use agent_client_protocol as acp;
  31use anyhow::{Context as _, Result, anyhow};
  32use chrono::{DateTime, Utc};
  33use collections::{HashMap, HashSet, IndexMap};
  34use fs::Fs;
  35use futures::channel::{mpsc, oneshot};
  36use futures::future::Shared;
  37use futures::{FutureExt as _, StreamExt as _, future};
  38use gpui::{
  39    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  40};
  41use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  42use project::{Project, ProjectItem, ProjectPath, Worktree};
  43use prompt_store::{
  44    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  45    WorktreeContext,
  46};
  47use serde::{Deserialize, Serialize};
  48use settings::{LanguageModelSelection, update_settings_file};
  49use std::any::Any;
  50use std::path::{Path, PathBuf};
  51use std::rc::Rc;
  52use std::sync::Arc;
  53use std::time::Duration;
  54use util::ResultExt;
  55use util::rel_path::RelPath;
  56
/// Snapshot of the project's worktrees (as telemetry worktree snapshots),
/// paired with the UTC time at which the snapshot was taken.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}
  62
/// Error describing why a worktree's rules file could not be loaded.
pub struct RulesLoadingError {
    /// Human-readable description, built from the underlying error's `Display`.
    pub message: SharedString,
}
  66
/// Holds both the internal Thread and the AcpThread for a session.
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// In-flight save of `thread`. Initialized to `Task::ready(())` when the
    /// session is registered; presumably replaced by `save_thread` (TODO confirm).
    pending_save: Task<()>,
    /// Subscriptions on `thread` (title updates, token-usage updates, and an
    /// observer that saves the thread); kept alive for the session's lifetime.
    _subscriptions: Vec<Subscription>,
}
  76
/// Cache of the models offered by visible, authenticated language-model
/// providers, plus a watch channel that is signalled whenever it is rebuilt.
pub struct LanguageModels {
    /// Access language model by ID (formatted `"<provider_id>/<model_id>"`).
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information: an optional
    /// "Recommended" group followed by one group per provider name.
    model_list: acp_thread::AgentModelList,
    /// Receiver cloned out by `watch()`; signalled after every refresh.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender used by `refresh_list` to notify watchers.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating every visible provider; kept alive by
    /// this field.
    _authenticate_all_providers_task: Task<()>,
}
  86
  87impl LanguageModels {
  88    fn new(cx: &mut App) -> Self {
  89        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  90
  91        let mut this = Self {
  92            models: HashMap::default(),
  93            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  94            refresh_models_rx,
  95            refresh_models_tx,
  96            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  97        };
  98        this.refresh_list(cx);
  99        this
 100    }
 101
 102    fn refresh_list(&mut self, cx: &App) {
 103        let providers = LanguageModelRegistry::global(cx)
 104            .read(cx)
 105            .visible_providers()
 106            .into_iter()
 107            .filter(|provider| provider.is_authenticated(cx))
 108            .collect::<Vec<_>>();
 109
 110        let mut language_model_list = IndexMap::default();
 111        let mut recommended_models = HashSet::default();
 112
 113        let mut recommended = Vec::new();
 114        for provider in &providers {
 115            for model in provider.recommended_models(cx) {
 116                recommended_models.insert((model.provider_id(), model.id()));
 117                recommended.push(Self::map_language_model_to_info(&model, provider));
 118            }
 119        }
 120        if !recommended.is_empty() {
 121            language_model_list.insert(
 122                acp_thread::AgentModelGroupName("Recommended".into()),
 123                recommended,
 124            );
 125        }
 126
 127        let mut models = HashMap::default();
 128        for provider in providers {
 129            let mut provider_models = Vec::new();
 130            for model in provider.provided_models(cx) {
 131                let model_info = Self::map_language_model_to_info(&model, &provider);
 132                let model_id = model_info.id.clone();
 133                provider_models.push(model_info);
 134                models.insert(model_id, model);
 135            }
 136            if !provider_models.is_empty() {
 137                language_model_list.insert(
 138                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 139                    provider_models,
 140                );
 141            }
 142        }
 143
 144        self.models = models;
 145        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 146        self.refresh_models_tx.send(()).ok();
 147    }
 148
 149    fn watch(&self) -> watch::Receiver<()> {
 150        self.refresh_models_rx.clone()
 151    }
 152
 153    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 154        self.models.get(model_id).cloned()
 155    }
 156
 157    fn map_language_model_to_info(
 158        model: &Arc<dyn LanguageModel>,
 159        provider: &Arc<dyn LanguageModelProvider>,
 160    ) -> acp_thread::AgentModelInfo {
 161        acp_thread::AgentModelInfo {
 162            id: Self::model_id(model),
 163            name: model.name().0,
 164            description: None,
 165            icon: Some(match provider.icon() {
 166                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 167                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 168            }),
 169            is_latest: model.is_latest(),
 170            cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
 171        }
 172    }
 173
 174    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 175        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 176    }
 177
 178    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 179        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 180            .read(cx)
 181            .visible_providers()
 182            .iter()
 183            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 184            .collect::<Vec<_>>();
 185
 186        cx.background_spawn(async move {
 187            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 188                if let Err(err) = authenticate_task.await {
 189                    match err {
 190                        language_model::AuthenticateError::CredentialsNotFound => {
 191                            // Since we're authenticating these providers in the
 192                            // background for the purposes of populating the
 193                            // language selector, we don't care about providers
 194                            // where the credentials are not found.
 195                        }
 196                        language_model::AuthenticateError::ConnectionRefused => {
 197                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 198                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 199                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 200                        }
 201                        _ => {
 202                            // Some providers have noisy failure states that we
 203                            // don't want to spam the logs with every time the
 204                            // language model selector is initialized.
 205                            //
 206                            // Ideally these should have more clear failure modes
 207                            // that we know are safe to ignore here, like what we do
 208                            // with `CredentialsNotFound` above.
 209                            match provider_id.0.as_ref() {
 210                                "lmstudio" | "ollama" => {
 211                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 212                                    //
 213                                    // These fail noisily, so we don't log them.
 214                                }
 215                                "copilot_chat" => {
 216                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 217                                }
 218                                _ => {
 219                                    log::error!(
 220                                        "Failed to authenticate provider: {}: {err:#}",
 221                                        provider_name.0
 222                                    );
 223                                }
 224                            }
 225                        }
 226                    }
 227                }
 228            }
 229        })
 230    }
 231}
 232
/// The native agent: owns every active session plus the state shared by their
/// threads (project context, context-server registry, templates, model cache).
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store supplied to `NativeAgent::new`.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signal to request a rebuild of `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task that rebuilds `project_context` each time a refresh is signalled.
    _maintain_project_context: Task<Result<()>>,
    /// Context-server tools/prompts registry handed to every thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    project: Entity<Project>,
    /// Optional source of default user rules for the project context.
    prompt_store: Option<Entity<PromptStore>>,
    fs: Arc<dyn Fs>,
    /// Subscriptions created in `new` (project, model registry, context
    /// servers and, if present, prompt store).
    _subscriptions: Vec<Subscription>,
}
 251
 252impl NativeAgent {
    /// Creates the agent entity for `project`.
    ///
    /// The initial [`ProjectContext`] is built and awaited before the entity is
    /// created; afterwards `_maintain_project_context` rebuilds it whenever
    /// `project_context_needs_refresh` is signalled.
    ///
    /// Subscribes to project events, language-model registry events,
    /// context-server store/registry events and, when a prompt store is given,
    /// prompt-store events.
    pub async fn new(
        project: Entity<Project>,
        thread_store: Entity<ThreadStore>,
        templates: Arc<Templates>,
        prompt_store: Option<Entity<PromptStore>>,
        fs: Arc<dyn Fs>,
        cx: &mut AsyncApp,
    ) -> Result<Entity<NativeAgent>> {
        log::debug!("Creating new NativeAgent");

        let project_context = cx
            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
            .await;

        Ok(cx.new(|cx| {
            let context_server_store = project.read(cx).context_server_store();
            let context_server_registry =
                cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));

            let mut subscriptions = vec![
                cx.subscribe(&project, Self::handle_project_event),
                cx.subscribe(
                    &LanguageModelRegistry::global(cx),
                    Self::handle_models_updated_event,
                ),
                cx.subscribe(
                    &context_server_store,
                    Self::handle_context_server_store_updated,
                ),
                cx.subscribe(
                    &context_server_registry,
                    Self::handle_context_server_registry_event,
                ),
            ];
            if let Some(prompt_store) = prompt_store.as_ref() {
                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
            }

            // The receiver is owned by the maintenance task spawned below; event
            // handlers signal the sender to request a project-context rebuild.
            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
                watch::channel(());
            Self {
                sessions: HashMap::default(),
                thread_store,
                project_context: cx.new(|_| project_context),
                project_context_needs_refresh: project_context_needs_refresh_tx,
                _maintain_project_context: cx.spawn(async move |this, cx| {
                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
                }),
                context_server_registry,
                templates,
                models: LanguageModels::new(cx),
                project,
                prompt_store,
                fs,
                _subscriptions: subscriptions,
            }
        }))
    }
 311
 312    fn new_session(
 313        &mut self,
 314        project: Entity<Project>,
 315        cx: &mut Context<Self>,
 316    ) -> Entity<AcpThread> {
 317        // Create Thread
 318        // Fetch default model from registry settings
 319        let registry = LanguageModelRegistry::read_global(cx);
 320        // Log available models for debugging
 321        let available_count = registry.available_models(cx).count();
 322        log::debug!("Total available models: {}", available_count);
 323
 324        let default_model = registry.default_model().and_then(|default_model| {
 325            self.models
 326                .model_from_id(&LanguageModels::model_id(&default_model.model))
 327        });
 328        let thread = cx.new(|cx| {
 329            Thread::new(
 330                project.clone(),
 331                self.project_context.clone(),
 332                self.context_server_registry.clone(),
 333                self.templates.clone(),
 334                default_model,
 335                cx,
 336            )
 337        });
 338
 339        self.register_session(thread, None, cx)
 340    }
 341
    /// Wraps `thread_handle` in a new `AcpThread` and stores both as a session
    /// keyed by the thread's id (replacing any existing session with that id).
    ///
    /// Also:
    /// - sets the registry's thread-summary model on the thread;
    /// - installs the default tools, passing `allowed_tool_names` through
    ///   (presumably an allow-list when `Some` — confirm in `add_default_tools`)
    ///   and an environment pointing back at this agent and the new `AcpThread`;
    /// - subscribes to title and token-usage events and calls `save_thread`
    ///   whenever the thread is observed to change;
    /// - rebroadcasts the available slash commands to every session.
    ///
    /// Returns the new `AcpThread`.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        allowed_tool_names: Option<Vec<SharedString>>,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy everything the AcpThread needs out of the thread up front.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            acp_thread::AcpThread::new(
                parent_session_id,
                title,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            )
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        // Tools receive weak handles to both the agent and the AcpThread.
        let weak = cx.weak_entity();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                allowed_tool_names,
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    agent: weak,
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                _subscriptions: subscriptions,
                pending_save: Task::ready(()),
            },
        );

        self.update_available_commands(cx);

        acp_thread
    }
 408
    /// Returns the agent's cache of available language models.
    pub fn models(&self) -> &LanguageModels {
        &self.models
    }
 412
 413    async fn maintain_project_context(
 414        this: WeakEntity<Self>,
 415        mut needs_refresh: watch::Receiver<()>,
 416        cx: &mut AsyncApp,
 417    ) -> Result<()> {
 418        while needs_refresh.changed().await.is_ok() {
 419            let project_context = this
 420                .update(cx, |this, cx| {
 421                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 422                })?
 423                .await;
 424            this.update(cx, |this, cx| {
 425                this.project_context = cx.new(|_| project_context);
 426            })?;
 427        }
 428
 429        Ok(())
 430    }
 431
 432    fn build_project_context(
 433        project: &Entity<Project>,
 434        prompt_store: Option<&Entity<PromptStore>>,
 435        cx: &mut App,
 436    ) -> Task<ProjectContext> {
 437        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 438        let worktree_tasks = worktrees
 439            .into_iter()
 440            .map(|worktree| {
 441                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 442            })
 443            .collect::<Vec<_>>();
 444        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 445            prompt_store.read_with(cx, |prompt_store, cx| {
 446                let prompts = prompt_store.default_prompt_metadata();
 447                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 448                    let contents = prompt_store.load(prompt_metadata.id, cx);
 449                    async move { (contents.await, prompt_metadata) }
 450                });
 451                cx.background_spawn(future::join_all(load_tasks))
 452            })
 453        } else {
 454            Task::ready(vec![])
 455        };
 456
 457        cx.spawn(async move |_cx| {
 458            let (worktrees, default_user_rules) =
 459                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 460
 461            let worktrees = worktrees
 462                .into_iter()
 463                .map(|(worktree, _rules_error)| {
 464                    // TODO: show error message
 465                    // if let Some(rules_error) = rules_error {
 466                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 467                    // }
 468                    worktree
 469                })
 470                .collect::<Vec<_>>();
 471
 472            let default_user_rules = default_user_rules
 473                .into_iter()
 474                .flat_map(|(contents, prompt_metadata)| match contents {
 475                    Ok(contents) => Some(UserRulesContext {
 476                        uuid: prompt_metadata.id.as_user()?,
 477                        title: prompt_metadata.title.map(|title| title.to_string()),
 478                        contents,
 479                    }),
 480                    Err(_err) => {
 481                        // TODO: show error message
 482                        // this.update(cx, |_, cx| {
 483                        //     cx.emit(RulesLoadingError {
 484                        //         message: format!("{err:?}").into(),
 485                        //     });
 486                        // })
 487                        // .ok();
 488                        None
 489                    }
 490                })
 491                .collect::<Vec<_>>();
 492
 493            ProjectContext::new(worktrees, default_user_rules)
 494        })
 495    }
 496
 497    fn load_worktree_info_for_system_prompt(
 498        worktree: Entity<Worktree>,
 499        project: Entity<Project>,
 500        cx: &mut App,
 501    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 502        let tree = worktree.read(cx);
 503        let root_name = tree.root_name_str().into();
 504        let abs_path = tree.abs_path();
 505
 506        let mut context = WorktreeContext {
 507            root_name,
 508            abs_path,
 509            rules_file: None,
 510        };
 511
 512        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 513        let Some(rules_task) = rules_task else {
 514            return Task::ready((context, None));
 515        };
 516
 517        cx.spawn(async move |_| {
 518            let (rules_file, rules_file_error) = match rules_task.await {
 519                Ok(rules_file) => (Some(rules_file), None),
 520                Err(err) => (
 521                    None,
 522                    Some(RulesLoadingError {
 523                        message: format!("{err}").into(),
 524                    }),
 525                ),
 526            };
 527            context.rules_file = rules_file;
 528            (context, rules_file_error)
 529        })
 530    }
 531
 532    fn load_worktree_rules_file(
 533        worktree: Entity<Worktree>,
 534        project: Entity<Project>,
 535        cx: &mut App,
 536    ) -> Option<Task<Result<RulesFileContext>>> {
 537        let worktree = worktree.read(cx);
 538        let worktree_id = worktree.id();
 539        let selected_rules_file = RULES_FILE_NAMES
 540            .into_iter()
 541            .filter_map(|name| {
 542                worktree
 543                    .entry_for_path(RelPath::unix(name).unwrap())
 544                    .filter(|entry| entry.is_file())
 545                    .map(|entry| entry.path.clone())
 546            })
 547            .next();
 548
 549        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 550        // supported. This doesn't seem to occur often in GitHub repositories.
 551        selected_rules_file.map(|path_in_worktree| {
 552            let project_path = ProjectPath {
 553                worktree_id,
 554                path: path_in_worktree.clone(),
 555            };
 556            let buffer_task =
 557                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 558            let rope_task = cx.spawn(async move |cx| {
 559                let buffer = buffer_task.await?;
 560                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 561                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 562                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 563                })?;
 564                anyhow::Ok((project_entry_id, rope))
 565            });
 566            // Build a string from the rope on a background thread.
 567            cx.background_spawn(async move {
 568                let (project_entry_id, rope) = rope_task.await?;
 569                anyhow::Ok(RulesFileContext {
 570                    path_in_worktree,
 571                    text: rope.to_string().trim().to_string(),
 572                    project_entry_id: project_entry_id.to_usize(),
 573                })
 574            })
 575        })
 576    }
 577
 578    fn handle_thread_title_updated(
 579        &mut self,
 580        thread: Entity<Thread>,
 581        _: &TitleUpdated,
 582        cx: &mut Context<Self>,
 583    ) {
 584        let session_id = thread.read(cx).id();
 585        let Some(session) = self.sessions.get(session_id) else {
 586            return;
 587        };
 588        let thread = thread.downgrade();
 589        let acp_thread = session.acp_thread.downgrade();
 590        cx.spawn(async move |_, cx| {
 591            let title = thread.read_with(cx, |thread, _| thread.title())?;
 592            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 593            task.await
 594        })
 595        .detach_and_log_err(cx);
 596    }
 597
 598    fn handle_thread_token_usage_updated(
 599        &mut self,
 600        thread: Entity<Thread>,
 601        usage: &TokenUsageUpdated,
 602        cx: &mut Context<Self>,
 603    ) {
 604        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 605            return;
 606        };
 607        session.acp_thread.update(cx, |acp_thread, cx| {
 608            acp_thread.update_token_usage(usage.0.clone(), cx);
 609        });
 610    }
 611
 612    fn handle_project_event(
 613        &mut self,
 614        _project: Entity<Project>,
 615        event: &project::Event,
 616        _cx: &mut Context<Self>,
 617    ) {
 618        match event {
 619            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 620                self.project_context_needs_refresh.send(()).ok();
 621            }
 622            project::Event::WorktreeUpdatedEntries(_, items) => {
 623                if items.iter().any(|(path, _, _)| {
 624                    RULES_FILE_NAMES
 625                        .iter()
 626                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 627                }) {
 628                    self.project_context_needs_refresh.send(()).ok();
 629                }
 630            }
 631            _ => {}
 632        }
 633    }
 634
 635    fn handle_prompts_updated_event(
 636        &mut self,
 637        _prompt_store: Entity<PromptStore>,
 638        _event: &prompt_store::PromptsUpdatedEvent,
 639        _cx: &mut Context<Self>,
 640    ) {
 641        self.project_context_needs_refresh.send(()).ok();
 642    }
 643
 644    fn handle_models_updated_event(
 645        &mut self,
 646        _registry: Entity<LanguageModelRegistry>,
 647        _event: &language_model::Event,
 648        cx: &mut Context<Self>,
 649    ) {
 650        self.models.refresh_list(cx);
 651
 652        let registry = LanguageModelRegistry::read_global(cx);
 653        let default_model = registry.default_model().map(|m| m.model);
 654        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 655
 656        for session in self.sessions.values_mut() {
 657            session.thread.update(cx, |thread, cx| {
 658                if thread.model().is_none()
 659                    && let Some(model) = default_model.clone()
 660                {
 661                    thread.set_model(model, cx);
 662                    cx.notify();
 663                }
 664                thread.set_summarization_model(summarization_model.clone(), cx);
 665            });
 666        }
 667    }
 668
    /// A context server's status changed; recompute and broadcast the slash
    /// command list (presumably because server availability affects the set of
    /// registry prompts — confirm).
    fn handle_context_server_store_updated(
        &mut self,
        _store: Entity<project::context_server_store::ContextServerStore>,
        _event: &project::context_server_store::ServerStatusChangedEvent,
        cx: &mut Context<Self>,
    ) {
        self.update_available_commands(cx);
    }
 677
    /// Rebroadcasts slash commands when the registry's prompts change.
    ///
    /// Tool changes need no action here: commands are built only from
    /// registry prompts.
    fn handle_context_server_registry_event(
        &mut self,
        _registry: Entity<ContextServerRegistry>,
        event: &ContextServerRegistryEvent,
        cx: &mut Context<Self>,
    ) {
        match event {
            ContextServerRegistryEvent::ToolsChanged => {}
            ContextServerRegistryEvent::PromptsChanged => {
                self.update_available_commands(cx);
            }
        }
    }
 691
 692    fn update_available_commands(&self, cx: &mut Context<Self>) {
 693        let available_commands = self.build_available_commands(cx);
 694        for session in self.sessions.values() {
 695            session.acp_thread.update(cx, |thread, cx| {
 696                thread
 697                    .handle_session_update(
 698                        acp::SessionUpdate::AvailableCommandsUpdate(
 699                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 700                        ),
 701                        cx,
 702                    )
 703                    .log_err();
 704            });
 705        }
 706    }
 707
 708    fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
 709        let registry = self.context_server_registry.read(cx);
 710
 711        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 712        for context_server_prompt in registry.prompts() {
 713            *prompt_name_counts
 714                .entry(context_server_prompt.prompt.name.as_str())
 715                .or_insert(0) += 1;
 716        }
 717
 718        registry
 719            .prompts()
 720            .flat_map(|context_server_prompt| {
 721                let prompt = &context_server_prompt.prompt;
 722
 723                let should_prefix = prompt_name_counts
 724                    .get(prompt.name.as_str())
 725                    .copied()
 726                    .unwrap_or(0)
 727                    > 1;
 728
 729                let name = if should_prefix {
 730                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 731                } else {
 732                    prompt.name.clone()
 733                };
 734
 735                let mut command = acp::AvailableCommand::new(
 736                    name,
 737                    prompt.description.clone().unwrap_or_default(),
 738                );
 739
 740                match prompt.arguments.as_deref() {
 741                    Some([arg]) => {
 742                        let hint = format!("<{}>", arg.name);
 743
 744                        command = command.input(acp::AvailableCommandInput::Unstructured(
 745                            acp::UnstructuredCommandInput::new(hint),
 746                        ));
 747                    }
 748                    Some([]) | None => {}
 749                    Some(_) => {
 750                        // skip >1 argument commands since we don't support them yet
 751                        return None;
 752                    }
 753                }
 754
 755                Some(command)
 756            })
 757            .collect()
 758    }
 759
 760    pub fn load_thread(
 761        &mut self,
 762        id: acp::SessionId,
 763        cx: &mut Context<Self>,
 764    ) -> Task<Result<Entity<Thread>>> {
 765        let database_future = ThreadsDatabase::connect(cx);
 766        cx.spawn(async move |this, cx| {
 767            let database = database_future.await.map_err(|err| anyhow!(err))?;
 768            let db_thread = database
 769                .load_thread(id.clone())
 770                .await?
 771                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 772
 773            this.update(cx, |this, cx| {
 774                let summarization_model = LanguageModelRegistry::read_global(cx)
 775                    .thread_summary_model()
 776                    .map(|c| c.model);
 777
 778                cx.new(|cx| {
 779                    let mut thread = Thread::from_db(
 780                        id.clone(),
 781                        db_thread,
 782                        this.project.clone(),
 783                        this.project_context.clone(),
 784                        this.context_server_registry.clone(),
 785                        this.templates.clone(),
 786                        cx,
 787                    );
 788                    thread.set_summarization_model(summarization_model, cx);
 789                    thread
 790                })
 791            })
 792        })
 793    }
 794
 795    pub fn open_thread(
 796        &mut self,
 797        id: acp::SessionId,
 798        cx: &mut Context<Self>,
 799    ) -> Task<Result<Entity<AcpThread>>> {
 800        if let Some(session) = self.sessions.get(&id) {
 801            return Task::ready(Ok(session.acp_thread.clone()));
 802        }
 803
 804        let task = self.load_thread(id, cx);
 805        cx.spawn(async move |this, cx| {
 806            let thread = task.await?;
 807            let acp_thread = this.update(cx, |this, cx| {
 808                this.register_session(thread.clone(), None, cx)
 809            })?;
 810            let events = thread.update(cx, |thread, cx| thread.replay(cx));
 811            cx.update(|cx| {
 812                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 813            })
 814            .await?;
 815            Ok(acp_thread)
 816        })
 817    }
 818
 819    pub fn thread_summary(
 820        &mut self,
 821        id: acp::SessionId,
 822        cx: &mut Context<Self>,
 823    ) -> Task<Result<SharedString>> {
 824        let thread = self.open_thread(id.clone(), cx);
 825        cx.spawn(async move |this, cx| {
 826            let acp_thread = thread.await?;
 827            let result = this
 828                .update(cx, |this, cx| {
 829                    this.sessions
 830                        .get(&id)
 831                        .unwrap()
 832                        .thread
 833                        .update(cx, |thread, cx| thread.summary(cx))
 834                })?
 835                .await
 836                .context("Failed to generate summary")?;
 837            drop(acp_thread);
 838            Ok(result)
 839        })
 840    }
 841
 842    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 843        if thread.read(cx).is_empty() {
 844            return;
 845        }
 846
 847        let database_future = ThreadsDatabase::connect(cx);
 848        let (id, db_thread) =
 849            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
 850        let Some(session) = self.sessions.get_mut(&id) else {
 851            return;
 852        };
 853        let thread_store = self.thread_store.clone();
 854        session.pending_save = cx.spawn(async move |_, cx| {
 855            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 856                return;
 857            };
 858            let db_thread = db_thread.await;
 859            database.save_thread(id, db_thread).await.log_err();
 860            thread_store.update(cx, |store, cx| store.reload(cx));
 861        });
 862    }
 863
    /// Expands the MCP prompt `prompt_name` from context server `server_id` into the
    /// session's conversation and starts a turn on it.
    ///
    /// * `message_id` – id under which the user's remaining (non-command) blocks are
    ///   recorded on the native thread.
    /// * `arguments` – prompt arguments forwarded to `get_prompt`.
    /// * `original_content` – the user's prompt blocks; the first block (the slash
    ///   command text) is skipped.
    ///
    /// Each message returned by the server is pushed, with indent, onto both the
    /// `AcpThread` (for display) and the native `Thread` (for the model). If the last
    /// message came from the user, the thread sends the existing conversation;
    /// otherwise it resumes. The returned task resolves when that turn's events have
    /// been forwarded.
    ///
    /// # Errors
    /// Fails if fetching the prompt fails, if the agent or session is gone, or if
    /// starting/forwarding the turn fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: agent_client_protocol::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let server_store = self.context_server_registry.read(cx).server_store().clone();
        let path_style = self.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Tracks whether the most recently pushed prompt message was a user message;
            // stays `true` if the server returned no messages.
            let mut last_is_user = true;

            // Record the user's own blocks on the native thread, minus the first one
            // (the slash-command text that selected this prompt).
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        // Every user message from the prompt gets its own fresh id, shared
                        // by the AcpThread entry and the native thread entry.
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            // Forward the resulting turn's events to the AcpThread and resolve with its
            // stop reason.
            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
 959}
 960
 961/// Wrapper struct that implements the AgentConnection trait
 962#[derive(Clone)]
 963pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 964
 965impl NativeAgentConnection {
 966    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 967        self.0
 968            .read(cx)
 969            .sessions
 970            .get(session_id)
 971            .map(|session| session.thread.clone())
 972    }
 973
 974    pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
 975        self.0.update(cx, |this, cx| this.load_thread(id, cx))
 976    }
 977
 978    fn run_turn(
 979        &self,
 980        session_id: acp::SessionId,
 981        cx: &mut App,
 982        f: impl 'static
 983        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 984    ) -> Task<Result<acp::PromptResponse>> {
 985        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 986            agent
 987                .sessions
 988                .get_mut(&session_id)
 989                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 990        }) else {
 991            return Task::ready(Err(anyhow!("Session not found")));
 992        };
 993        log::debug!("Found session for: {}", session_id);
 994
 995        let response_stream = match f(thread, cx) {
 996            Ok(stream) => stream,
 997            Err(err) => return Task::ready(Err(err)),
 998        };
 999        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1000    }
1001
1002    fn handle_thread_events(
1003        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1004        acp_thread: WeakEntity<AcpThread>,
1005        cx: &App,
1006    ) -> Task<Result<acp::PromptResponse>> {
1007        cx.spawn(async move |cx| {
1008            // Handle response stream and forward to session.acp_thread
1009            while let Some(result) = events.next().await {
1010                match result {
1011                    Ok(event) => {
1012                        log::trace!("Received completion event: {:?}", event);
1013
1014                        match event {
1015                            ThreadEvent::UserMessage(message) => {
1016                                acp_thread.update(cx, |thread, cx| {
1017                                    for content in message.content {
1018                                        thread.push_user_content_block(
1019                                            Some(message.id.clone()),
1020                                            content.into(),
1021                                            cx,
1022                                        );
1023                                    }
1024                                })?;
1025                            }
1026                            ThreadEvent::AgentText(text) => {
1027                                acp_thread.update(cx, |thread, cx| {
1028                                    thread.push_assistant_content_block(text.into(), false, cx)
1029                                })?;
1030                            }
1031                            ThreadEvent::AgentThinking(text) => {
1032                                acp_thread.update(cx, |thread, cx| {
1033                                    thread.push_assistant_content_block(text.into(), true, cx)
1034                                })?;
1035                            }
1036                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1037                                tool_call,
1038                                options,
1039                                response,
1040                                context: _,
1041                            }) => {
1042                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1043                                    thread.request_tool_call_authorization(tool_call, options, cx)
1044                                })??;
1045                                cx.background_spawn(async move {
1046                                    if let acp::RequestPermissionOutcome::Selected(
1047                                        acp::SelectedPermissionOutcome { option_id, .. },
1048                                    ) = outcome_task.await
1049                                    {
1050                                        response
1051                                            .send(option_id)
1052                                            .map(|_| anyhow!("authorization receiver was dropped"))
1053                                            .log_err();
1054                                    }
1055                                })
1056                                .detach();
1057                            }
1058                            ThreadEvent::ToolCall(tool_call) => {
1059                                acp_thread.update(cx, |thread, cx| {
1060                                    thread.upsert_tool_call(tool_call, cx)
1061                                })??;
1062                            }
1063                            ThreadEvent::ToolCallUpdate(update) => {
1064                                acp_thread.update(cx, |thread, cx| {
1065                                    thread.update_tool_call(update, cx)
1066                                })??;
1067                            }
1068                            ThreadEvent::SubagentSpawned(session_id) => {
1069                                acp_thread.update(cx, |thread, cx| {
1070                                    thread.subagent_spawned(session_id, cx);
1071                                })?;
1072                            }
1073                            ThreadEvent::Retry(status) => {
1074                                acp_thread.update(cx, |thread, cx| {
1075                                    thread.update_retry_status(status, cx)
1076                                })?;
1077                            }
1078                            ThreadEvent::Stop(stop_reason) => {
1079                                log::debug!("Assistant message complete: {:?}", stop_reason);
1080                                return Ok(acp::PromptResponse::new(stop_reason));
1081                            }
1082                        }
1083                    }
1084                    Err(e) => {
1085                        log::error!("Error in model response stream: {:?}", e);
1086                        return Err(e);
1087                    }
1088                }
1089            }
1090
1091            log::debug!("Response stream completed");
1092            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1093        })
1094    }
1095}
1096
1097struct Command<'a> {
1098    prompt_name: &'a str,
1099    arg_value: &'a str,
1100    explicit_server_id: Option<&'a str>,
1101}
1102
1103impl<'a> Command<'a> {
1104    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1105        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1106            return None;
1107        };
1108        let text = text_content.text.trim();
1109        let command = text.strip_prefix('/')?;
1110        let (command, arg_value) = command
1111            .split_once(char::is_whitespace)
1112            .unwrap_or((command, ""));
1113
1114        if let Some((server_id, prompt_name)) = command.split_once('.') {
1115            Some(Self {
1116                prompt_name,
1117                arg_value,
1118                explicit_server_id: Some(server_id),
1119            })
1120        } else {
1121            Some(Self {
1122                prompt_name: command,
1123                arg_value,
1124                explicit_server_id: None,
1125            })
1126        }
1127    }
1128}
1129
1130struct NativeAgentModelSelector {
1131    session_id: acp::SessionId,
1132    connection: NativeAgentConnection,
1133}
1134
1135impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1136    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1137        log::debug!("NativeAgentConnection::list_models called");
1138        let list = self.connection.0.read(cx).models.model_list.clone();
1139        Task::ready(if list.is_empty() {
1140            Err(anyhow::anyhow!("No models available"))
1141        } else {
1142            Ok(list)
1143        })
1144    }
1145
1146    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1147        log::debug!(
1148            "Setting model for session {}: {}",
1149            self.session_id,
1150            model_id
1151        );
1152        let Some(thread) = self
1153            .connection
1154            .0
1155            .read(cx)
1156            .sessions
1157            .get(&self.session_id)
1158            .map(|session| session.thread.clone())
1159        else {
1160            return Task::ready(Err(anyhow!("Session not found")));
1161        };
1162
1163        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1164            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1165        };
1166
1167        // We want to reset the effort level when switching models, as the currently-selected effort level may
1168        // not be compatible.
1169        let effort = model
1170            .default_effort_level()
1171            .map(|effort_level| effort_level.value.to_string());
1172
1173        thread.update(cx, |thread, cx| {
1174            thread.set_model(model.clone(), cx);
1175            thread.set_thinking_effort(effort.clone(), cx);
1176            thread.set_thinking_enabled(model.supports_thinking(), cx);
1177        });
1178
1179        update_settings_file(
1180            self.connection.0.read(cx).fs.clone(),
1181            cx,
1182            move |settings, cx| {
1183                let provider = model.provider_id().0.to_string();
1184                let model = model.id().0.to_string();
1185                let enable_thinking = thread.read(cx).thinking_enabled();
1186                settings
1187                    .agent
1188                    .get_or_insert_default()
1189                    .set_model(LanguageModelSelection {
1190                        provider: provider.into(),
1191                        model,
1192                        enable_thinking,
1193                        effort,
1194                    });
1195            },
1196        );
1197
1198        Task::ready(Ok(()))
1199    }
1200
1201    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1202        let Some(thread) = self
1203            .connection
1204            .0
1205            .read(cx)
1206            .sessions
1207            .get(&self.session_id)
1208            .map(|session| session.thread.clone())
1209        else {
1210            return Task::ready(Err(anyhow!("Session not found")));
1211        };
1212        let Some(model) = thread.read(cx).model() else {
1213            return Task::ready(Err(anyhow!("Model not found")));
1214        };
1215        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1216        else {
1217            return Task::ready(Err(anyhow!("Provider not found")));
1218        };
1219        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1220            model, &provider,
1221        )))
1222    }
1223
1224    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1225        Some(self.connection.0.read(cx).models.watch())
1226    }
1227
1228    fn should_render_footer(&self) -> bool {
1229        true
1230    }
1231}
1232
1233impl acp_thread::AgentConnection for NativeAgentConnection {
1234    fn telemetry_id(&self) -> SharedString {
1235        "zed".into()
1236    }
1237
1238    fn new_session(
1239        self: Rc<Self>,
1240        project: Entity<Project>,
1241        cwd: &Path,
1242        cx: &mut App,
1243    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1244        log::debug!("Creating new thread for project at: {cwd:?}");
1245        Task::ready(Ok(self
1246            .0
1247            .update(cx, |agent, cx| agent.new_session(project, cx))))
1248    }
1249
1250    fn supports_load_session(&self) -> bool {
1251        true
1252    }
1253
1254    fn load_session(
1255        self: Rc<Self>,
1256        session: AgentSessionInfo,
1257        _project: Entity<Project>,
1258        _cwd: &Path,
1259        cx: &mut App,
1260    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1261        self.0
1262            .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1263    }
1264
1265    fn supports_close_session(&self) -> bool {
1266        true
1267    }
1268
1269    fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1270        self.0.update(cx, |agent, _cx| {
1271            agent.sessions.remove(session_id);
1272        });
1273        Task::ready(Ok(()))
1274    }
1275
1276    fn auth_methods(&self) -> &[acp::AuthMethod] {
1277        &[] // No auth for in-process
1278    }
1279
1280    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1281        Task::ready(Ok(()))
1282    }
1283
1284    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1285        Some(Rc::new(NativeAgentModelSelector {
1286            session_id: session_id.clone(),
1287            connection: self.clone(),
1288        }) as Rc<dyn AgentModelSelector>)
1289    }
1290
1291    fn prompt(
1292        &self,
1293        id: Option<acp_thread::UserMessageId>,
1294        params: acp::PromptRequest,
1295        cx: &mut App,
1296    ) -> Task<Result<acp::PromptResponse>> {
1297        let id = id.expect("UserMessageId is required");
1298        let session_id = params.session_id.clone();
1299        log::info!("Received prompt request for session: {}", session_id);
1300        log::debug!("Prompt blocks count: {}", params.prompt.len());
1301
1302        if let Some(parsed_command) = Command::parse(&params.prompt) {
1303            let registry = self.0.read(cx).context_server_registry.read(cx);
1304
1305            let explicit_server_id = parsed_command
1306                .explicit_server_id
1307                .map(|server_id| ContextServerId(server_id.into()));
1308
1309            if let Some(prompt) =
1310                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1311            {
1312                let arguments = if !parsed_command.arg_value.is_empty()
1313                    && let Some(arg_name) = prompt
1314                        .prompt
1315                        .arguments
1316                        .as_ref()
1317                        .and_then(|args| args.first())
1318                        .map(|arg| arg.name.clone())
1319                {
1320                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1321                } else {
1322                    Default::default()
1323                };
1324
1325                let prompt_name = prompt.prompt.name.clone();
1326                let server_id = prompt.server_id.clone();
1327
1328                return self.0.update(cx, |agent, cx| {
1329                    agent.send_mcp_prompt(
1330                        id,
1331                        session_id.clone(),
1332                        prompt_name,
1333                        server_id,
1334                        arguments,
1335                        params.prompt,
1336                        cx,
1337                    )
1338                });
1339            };
1340        };
1341
1342        let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1343
1344        self.run_turn(session_id, cx, move |thread, cx| {
1345            let content: Vec<UserMessageContent> = params
1346                .prompt
1347                .into_iter()
1348                .map(|block| UserMessageContent::from_content_block(block, path_style))
1349                .collect::<Vec<_>>();
1350            log::debug!("Converted prompt to message: {} chars", content.len());
1351            log::debug!("Message id: {:?}", id);
1352            log::debug!("Message content: {:?}", content);
1353
1354            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1355        })
1356    }
1357
1358    fn retry(
1359        &self,
1360        session_id: &acp::SessionId,
1361        _cx: &App,
1362    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1363        Some(Rc::new(NativeAgentSessionRetry {
1364            connection: self.clone(),
1365            session_id: session_id.clone(),
1366        }) as _)
1367    }
1368
1369    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1370        log::info!("Cancelling on session: {}", session_id);
1371        self.0.update(cx, |agent, cx| {
1372            if let Some(session) = agent.sessions.get(session_id) {
1373                session
1374                    .thread
1375                    .update(cx, |thread, cx| thread.cancel(cx))
1376                    .detach();
1377            }
1378        });
1379    }
1380
1381    fn truncate(
1382        &self,
1383        session_id: &agent_client_protocol::SessionId,
1384        cx: &App,
1385    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1386        self.0.read_with(cx, |agent, _cx| {
1387            agent.sessions.get(session_id).map(|session| {
1388                Rc::new(NativeAgentSessionTruncate {
1389                    thread: session.thread.clone(),
1390                    acp_thread: session.acp_thread.downgrade(),
1391                }) as _
1392            })
1393        })
1394    }
1395
1396    fn set_title(
1397        &self,
1398        session_id: &acp::SessionId,
1399        cx: &App,
1400    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1401        self.0.read_with(cx, |agent, _cx| {
1402            agent
1403                .sessions
1404                .get(session_id)
1405                .filter(|s| !s.thread.read(cx).is_subagent())
1406                .map(|session| {
1407                    Rc::new(NativeAgentSessionSetTitle {
1408                        thread: session.thread.clone(),
1409                    }) as _
1410                })
1411        })
1412    }
1413
1414    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1415        let thread_store = self.0.read(cx).thread_store.clone();
1416        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1417    }
1418
1419    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1420        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1421    }
1422
1423    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1424        self
1425    }
1426}
1427
1428impl acp_thread::AgentTelemetry for NativeAgentConnection {
1429    fn thread_data(
1430        &self,
1431        session_id: &acp::SessionId,
1432        cx: &mut App,
1433    ) -> Task<Result<serde_json::Value>> {
1434        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1435            return Task::ready(Err(anyhow!("Session not found")));
1436        };
1437
1438        let task = session.thread.read(cx).to_db(cx);
1439        cx.background_spawn(async move {
1440            serde_json::to_value(task.await).context("Failed to serialize thread")
1441        })
1442    }
1443}
1444
1445pub struct NativeAgentSessionList {
1446    thread_store: Entity<ThreadStore>,
1447    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1448    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1449    _subscription: Subscription,
1450}
1451
1452impl NativeAgentSessionList {
1453    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1454        let (tx, rx) = smol::channel::unbounded();
1455        let this_tx = tx.clone();
1456        let subscription = cx.observe(&thread_store, move |_, _| {
1457            this_tx
1458                .try_send(acp_thread::SessionListUpdate::Refresh)
1459                .ok();
1460        });
1461        Self {
1462            thread_store,
1463            updates_tx: tx,
1464            updates_rx: rx,
1465            _subscription: subscription,
1466        }
1467    }
1468
1469    fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1470        AgentSessionInfo {
1471            session_id: entry.id,
1472            cwd: None,
1473            title: Some(entry.title),
1474            updated_at: Some(entry.updated_at),
1475            meta: None,
1476        }
1477    }
1478
1479    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1480        &self.thread_store
1481    }
1482}
1483
1484impl AgentSessionList for NativeAgentSessionList {
1485    fn list_sessions(
1486        &self,
1487        _request: AgentSessionListRequest,
1488        cx: &mut App,
1489    ) -> Task<Result<AgentSessionListResponse>> {
1490        let sessions = self
1491            .thread_store
1492            .read(cx)
1493            .entries()
1494            .map(Self::to_session_info)
1495            .collect();
1496        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1497    }
1498
1499    fn supports_delete(&self) -> bool {
1500        true
1501    }
1502
1503    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1504        self.thread_store
1505            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1506    }
1507
1508    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1509        self.thread_store
1510            .update(cx, |store, cx| store.delete_threads(cx))
1511    }
1512
1513    fn watch(
1514        &self,
1515        _cx: &mut App,
1516    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1517        Some(self.updates_rx.clone())
1518    }
1519
1520    fn notify_refresh(&self) {
1521        self.updates_tx
1522            .try_send(acp_thread::SessionListUpdate::Refresh)
1523            .ok();
1524    }
1525
1526    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1527        self
1528    }
1529}
1530
1531struct NativeAgentSessionTruncate {
1532    thread: Entity<Thread>,
1533    acp_thread: WeakEntity<AcpThread>,
1534}
1535
1536impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1537    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1538        match self.thread.update(cx, |thread, cx| {
1539            thread.truncate(message_id.clone(), cx)?;
1540            Ok(thread.latest_token_usage())
1541        }) {
1542            Ok(usage) => {
1543                self.acp_thread
1544                    .update(cx, |thread, cx| {
1545                        thread.update_token_usage(usage, cx);
1546                    })
1547                    .ok();
1548                Task::ready(Ok(()))
1549            }
1550            Err(error) => Task::ready(Err(error)),
1551        }
1552    }
1553}
1554
1555struct NativeAgentSessionRetry {
1556    connection: NativeAgentConnection,
1557    session_id: acp::SessionId,
1558}
1559
1560impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1561    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1562        self.connection
1563            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1564                thread.update(cx, |thread, cx| thread.resume(cx))
1565            })
1566    }
1567}
1568
1569struct NativeAgentSessionSetTitle {
1570    thread: Entity<Thread>,
1571}
1572
1573impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1574    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1575        self.thread
1576            .update(cx, |thread, cx| thread.set_title(title, cx));
1577        Task::ready(Ok(()))
1578    }
1579}
1580
/// Holds weak handles to the owning [`NativeAgent`] and an [`AcpThread`].
///
/// Both handles are weak, so this value does not by itself keep the agent or the
/// `AcpThread` alive.
pub struct NativeThreadEnvironment {
    agent: WeakEntity<NativeAgent>,
    acp_thread: WeakEntity<AcpThread>,
}
1585
1586impl NativeThreadEnvironment {
1587    pub(crate) fn create_subagent_thread(
1588        agent: WeakEntity<NativeAgent>,
1589        parent_thread_entity: Entity<Thread>,
1590        label: String,
1591        initial_prompt: String,
1592        timeout: Option<Duration>,
1593        cx: &mut App,
1594    ) -> Result<Rc<dyn SubagentHandle>> {
1595        let parent_thread = parent_thread_entity.read(cx);
1596        let current_depth = parent_thread.depth();
1597
1598        if current_depth >= MAX_SUBAGENT_DEPTH {
1599            return Err(anyhow!(
1600                "Maximum subagent depth ({}) reached",
1601                MAX_SUBAGENT_DEPTH
1602            ));
1603        }
1604        let allowed_tool_names = Some(parent_thread.tools.keys().cloned().collect::<Vec<_>>());
1605
1606        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1607            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1608            thread.set_title(label.into(), cx);
1609            thread
1610        });
1611
1612        let session_id = subagent_thread.read(cx).id().clone();
1613
1614        let acp_thread = agent.update(cx, |agent, cx| {
1615            agent.register_session(subagent_thread.clone(), allowed_tool_names, cx)
1616        })?;
1617
1618        parent_thread_entity.update(cx, |parent_thread, _cx| {
1619            parent_thread.register_running_subagent(subagent_thread.downgrade())
1620        });
1621
1622        let task = acp_thread.update(cx, |agent, cx| agent.send(vec![initial_prompt.into()], cx));
1623
1624        let timeout_timer = timeout.map(|d| cx.background_executor().timer(d));
1625        let wait_for_prompt_to_complete = cx
1626            .background_spawn(async move {
1627                if let Some(timer) = timeout_timer {
1628                    futures::select! {
1629                        _ = timer.fuse() => SubagentInitialPromptResult::Timeout,
1630                        response = task.fuse() => {
1631                            let response = response.log_err().flatten();
1632                            if response.is_some_and(|response| {
1633                                response.stop_reason == acp::StopReason::Cancelled
1634                            })
1635                            {
1636                                SubagentInitialPromptResult::Cancelled
1637                            } else {
1638                                SubagentInitialPromptResult::Completed
1639                            }
1640                        },
1641                    }
1642                } else {
1643                    let response = task.await.log_err().flatten();
1644                    if response
1645                        .is_some_and(|response| response.stop_reason == acp::StopReason::Cancelled)
1646                    {
1647                        SubagentInitialPromptResult::Cancelled
1648                    } else {
1649                        SubagentInitialPromptResult::Completed
1650                    }
1651                }
1652            })
1653            .shared();
1654
1655        Ok(Rc::new(NativeSubagentHandle {
1656            session_id,
1657            subagent_thread,
1658            parent_thread: parent_thread_entity.downgrade(),
1659            wait_for_prompt_to_complete,
1660        }) as _)
1661    }
1662}
1663
1664impl ThreadEnvironment for NativeThreadEnvironment {
1665    fn create_terminal(
1666        &self,
1667        command: String,
1668        cwd: Option<PathBuf>,
1669        output_byte_limit: Option<u64>,
1670        cx: &mut AsyncApp,
1671    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1672        let task = self.acp_thread.update(cx, |thread, cx| {
1673            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1674        });
1675
1676        let acp_thread = self.acp_thread.clone();
1677        cx.spawn(async move |cx| {
1678            let terminal = task?.await?;
1679
1680            let (drop_tx, drop_rx) = oneshot::channel();
1681            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1682
1683            cx.spawn(async move |cx| {
1684                drop_rx.await.ok();
1685                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1686            })
1687            .detach();
1688
1689            let handle = AcpTerminalHandle {
1690                terminal,
1691                _drop_tx: Some(drop_tx),
1692            };
1693
1694            Ok(Rc::new(handle) as _)
1695        })
1696    }
1697
1698    fn create_subagent(
1699        &self,
1700        parent_thread_entity: Entity<Thread>,
1701        label: String,
1702        initial_prompt: String,
1703        timeout: Option<Duration>,
1704        cx: &mut App,
1705    ) -> Result<Rc<dyn SubagentHandle>> {
1706        Self::create_subagent_thread(
1707            self.agent.clone(),
1708            parent_thread_entity,
1709            label,
1710            initial_prompt,
1711            timeout,
1712            cx,
1713        )
1714    }
1715}
1716
/// How the initial prompt sent to a subagent ended.
#[derive(Clone, Copy, Debug)]
enum SubagentInitialPromptResult {
    /// The prompt finished without a cancelled stop reason (failed prompts are logged and
    /// reported here as well).
    Completed,
    /// The caller-supplied timeout elapsed before the prompt finished.
    Timeout,
    /// The prompt's response reported `StopReason::Cancelled`.
    Cancelled,
}
1723
1724pub struct NativeSubagentHandle {
1725    session_id: acp::SessionId,
1726    parent_thread: WeakEntity<Thread>,
1727    subagent_thread: Entity<Thread>,
1728    wait_for_prompt_to_complete: Shared<Task<SubagentInitialPromptResult>>,
1729}
1730
1731impl SubagentHandle for NativeSubagentHandle {
1732    fn id(&self) -> acp::SessionId {
1733        self.session_id.clone()
1734    }
1735
1736    fn wait_for_output(&self, cx: &AsyncApp) -> Task<Result<String>> {
1737        let thread = self.subagent_thread.clone();
1738        let wait_for_prompt = self.wait_for_prompt_to_complete.clone();
1739
1740        let subagent_session_id = self.session_id.clone();
1741        let parent_thread = self.parent_thread.clone();
1742
1743        cx.spawn(async move |cx| {
1744            match wait_for_prompt.await {
1745                SubagentInitialPromptResult::Completed => {}
1746                SubagentInitialPromptResult::Timeout => {
1747                    return Err(anyhow!("The time to complete the task was exceeded."));
1748                }
1749                SubagentInitialPromptResult::Cancelled => return Err(anyhow!("User cancelled")),
1750            };
1751
1752            let result = thread.read_with(cx, |thread, _cx| {
1753                thread
1754                    .last_message()
1755                    .map(|m| m.to_markdown())
1756                    .context("No response from subagent")
1757            });
1758
1759            parent_thread
1760                .update(cx, |parent_thread, cx| {
1761                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1762                })
1763                .ok();
1764
1765            result
1766        })
1767    }
1768}
1769
1770pub struct AcpTerminalHandle {
1771    terminal: Entity<acp_thread::Terminal>,
1772    _drop_tx: Option<oneshot::Sender<()>>,
1773}
1774
1775impl TerminalHandle for AcpTerminalHandle {
1776    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1777        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1778    }
1779
1780    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1781        Ok(self
1782            .terminal
1783            .read_with(cx, |term, _cx| term.wait_for_exit()))
1784    }
1785
1786    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1787        Ok(self
1788            .terminal
1789            .read_with(cx, |term, cx| term.current_output(cx)))
1790    }
1791
1792    fn kill(&self, cx: &AsyncApp) -> Result<()> {
1793        cx.update(|cx| {
1794            self.terminal.update(cx, |terminal, cx| {
1795                terminal.kill(cx);
1796            });
1797        });
1798        Ok(())
1799    }
1800
1801    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1802        Ok(self
1803            .terminal
1804            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1805    }
1806}
1807
1808#[cfg(test)]
1809mod internal_tests {
1810    use super::*;
1811    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1812    use fs::FakeFs;
1813    use gpui::TestAppContext;
1814    use indoc::formatdoc;
1815    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1816    use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1817    use serde_json::json;
1818    use settings::SettingsStore;
1819    use util::{path, rel_path::rel_path};
1820
1821    #[gpui::test]
1822    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1823        init_test(cx);
1824        let fs = FakeFs::new(cx.executor());
1825        fs.insert_tree(
1826            "/",
1827            json!({
1828                "a": {}
1829            }),
1830        )
1831        .await;
1832        let project = Project::test(fs.clone(), [], cx).await;
1833        let thread_store = cx.new(|cx| ThreadStore::new(cx));
1834        let agent = NativeAgent::new(
1835            project.clone(),
1836            thread_store,
1837            Templates::new(),
1838            None,
1839            fs.clone(),
1840            &mut cx.to_async(),
1841        )
1842        .await
1843        .unwrap();
1844        agent.read_with(cx, |agent, cx| {
1845            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1846        });
1847
1848        let worktree = project
1849            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1850            .await
1851            .unwrap();
1852        cx.run_until_parked();
1853        agent.read_with(cx, |agent, cx| {
1854            assert_eq!(
1855                agent.project_context.read(cx).worktrees,
1856                vec![WorktreeContext {
1857                    root_name: "a".into(),
1858                    abs_path: Path::new("/a").into(),
1859                    rules_file: None
1860                }]
1861            )
1862        });
1863
1864        // Creating `/a/.rules` updates the project context.
1865        fs.insert_file("/a/.rules", Vec::new()).await;
1866        cx.run_until_parked();
1867        agent.read_with(cx, |agent, cx| {
1868            let rules_entry = worktree
1869                .read(cx)
1870                .entry_for_path(rel_path(".rules"))
1871                .unwrap();
1872            assert_eq!(
1873                agent.project_context.read(cx).worktrees,
1874                vec![WorktreeContext {
1875                    root_name: "a".into(),
1876                    abs_path: Path::new("/a").into(),
1877                    rules_file: Some(RulesFileContext {
1878                        path_in_worktree: rel_path(".rules").into(),
1879                        text: "".into(),
1880                        project_entry_id: rules_entry.id.to_usize()
1881                    })
1882                }]
1883            )
1884        });
1885    }
1886
1887    #[gpui::test]
1888    async fn test_listing_models(cx: &mut TestAppContext) {
1889        init_test(cx);
1890        let fs = FakeFs::new(cx.executor());
1891        fs.insert_tree("/", json!({ "a": {}  })).await;
1892        let project = Project::test(fs.clone(), [], cx).await;
1893        let thread_store = cx.new(|cx| ThreadStore::new(cx));
1894        let connection = NativeAgentConnection(
1895            NativeAgent::new(
1896                project.clone(),
1897                thread_store,
1898                Templates::new(),
1899                None,
1900                fs.clone(),
1901                &mut cx.to_async(),
1902            )
1903            .await
1904            .unwrap(),
1905        );
1906
1907        // Create a thread/session
1908        let acp_thread = cx
1909            .update(|cx| {
1910                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
1911            })
1912            .await
1913            .unwrap();
1914
1915        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1916
1917        let models = cx
1918            .update(|cx| {
1919                connection
1920                    .model_selector(&session_id)
1921                    .unwrap()
1922                    .list_models(cx)
1923            })
1924            .await
1925            .unwrap();
1926
1927        let acp_thread::AgentModelList::Grouped(models) = models else {
1928            panic!("Unexpected model group");
1929        };
1930        assert_eq!(
1931            models,
1932            IndexMap::from_iter([(
1933                AgentModelGroupName("Fake".into()),
1934                vec![AgentModelInfo {
1935                    id: acp::ModelId::new("fake/fake"),
1936                    name: "Fake".into(),
1937                    description: None,
1938                    icon: Some(acp_thread::AgentModelIcon::Named(
1939                        ui::IconName::ZedAssistant
1940                    )),
1941                    is_latest: false,
1942                    cost: None,
1943                }]
1944            )])
1945        );
1946    }
1947
1948    #[gpui::test]
1949    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
1950        init_test(cx);
1951        let fs = FakeFs::new(cx.executor());
1952        fs.create_dir(paths::settings_file().parent().unwrap())
1953            .await
1954            .unwrap();
1955        fs.insert_file(
1956            paths::settings_file(),
1957            json!({
1958                "agent": {
1959                    "default_model": {
1960                        "provider": "foo",
1961                        "model": "bar"
1962                    }
1963                }
1964            })
1965            .to_string()
1966            .into_bytes(),
1967        )
1968        .await;
1969        let project = Project::test(fs.clone(), [], cx).await;
1970
1971        let thread_store = cx.new(|cx| ThreadStore::new(cx));
1972
1973        // Create the agent and connection
1974        let agent = NativeAgent::new(
1975            project.clone(),
1976            thread_store,
1977            Templates::new(),
1978            None,
1979            fs.clone(),
1980            &mut cx.to_async(),
1981        )
1982        .await
1983        .unwrap();
1984        let connection = NativeAgentConnection(agent.clone());
1985
1986        // Create a thread/session
1987        let acp_thread = cx
1988            .update(|cx| {
1989                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
1990            })
1991            .await
1992            .unwrap();
1993
1994        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1995
1996        // Select a model
1997        let selector = connection.model_selector(&session_id).unwrap();
1998        let model_id = acp::ModelId::new("fake/fake");
1999        cx.update(|cx| selector.select_model(model_id.clone(), cx))
2000            .await
2001            .unwrap();
2002
2003        // Verify the thread has the selected model
2004        agent.read_with(cx, |agent, _| {
2005            let session = agent.sessions.get(&session_id).unwrap();
2006            session.thread.read_with(cx, |thread, _| {
2007                assert_eq!(thread.model().unwrap().id().0, "fake");
2008            });
2009        });
2010
2011        cx.run_until_parked();
2012
2013        // Verify settings file was updated
2014        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2015        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2016
2017        // Check that the agent settings contain the selected model
2018        assert_eq!(
2019            settings_json["agent"]["default_model"]["model"],
2020            json!("fake")
2021        );
2022        assert_eq!(
2023            settings_json["agent"]["default_model"]["provider"],
2024            json!("fake")
2025        );
2026
2027        // Register a thinking model and select it.
2028        cx.update(|cx| {
2029            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2030                "fake-corp",
2031                "fake-thinking",
2032                "Fake Thinking",
2033                true,
2034            ));
2035            let thinking_provider = Arc::new(
2036                FakeLanguageModelProvider::new(
2037                    LanguageModelProviderId::from("fake-corp".to_string()),
2038                    LanguageModelProviderName::from("Fake Corp".to_string()),
2039                )
2040                .with_models(vec![thinking_model]),
2041            );
2042            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2043                registry.register_provider(thinking_provider, cx);
2044            });
2045        });
2046        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2047
2048        let selector = connection.model_selector(&session_id).unwrap();
2049        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2050            .await
2051            .unwrap();
2052        cx.run_until_parked();
2053
2054        // Verify enable_thinking was written to settings as true.
2055        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2056        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2057        assert_eq!(
2058            settings_json["agent"]["default_model"]["enable_thinking"],
2059            json!(true),
2060            "selecting a thinking model should persist enable_thinking: true to settings"
2061        );
2062    }
2063
2064    #[gpui::test]
2065    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2066        init_test(cx);
2067        let fs = FakeFs::new(cx.executor());
2068        fs.create_dir(paths::settings_file().parent().unwrap())
2069            .await
2070            .unwrap();
2071        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2072        let project = Project::test(fs.clone(), [], cx).await;
2073
2074        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2075        let agent = NativeAgent::new(
2076            project.clone(),
2077            thread_store,
2078            Templates::new(),
2079            None,
2080            fs.clone(),
2081            &mut cx.to_async(),
2082        )
2083        .await
2084        .unwrap();
2085        let connection = NativeAgentConnection(agent.clone());
2086
2087        let acp_thread = cx
2088            .update(|cx| {
2089                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2090            })
2091            .await
2092            .unwrap();
2093        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2094
2095        // Register a second provider with a thinking model.
2096        cx.update(|cx| {
2097            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2098                "fake-corp",
2099                "fake-thinking",
2100                "Fake Thinking",
2101                true,
2102            ));
2103            let thinking_provider = Arc::new(
2104                FakeLanguageModelProvider::new(
2105                    LanguageModelProviderId::from("fake-corp".to_string()),
2106                    LanguageModelProviderName::from("Fake Corp".to_string()),
2107                )
2108                .with_models(vec![thinking_model]),
2109            );
2110            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2111                registry.register_provider(thinking_provider, cx);
2112            });
2113        });
2114        // Refresh the agent's model list so it picks up the new provider.
2115        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2116
2117        // Thread starts with thinking_enabled = false (the default).
2118        agent.read_with(cx, |agent, _| {
2119            let session = agent.sessions.get(&session_id).unwrap();
2120            session.thread.read_with(cx, |thread, _| {
2121                assert!(!thread.thinking_enabled(), "thinking defaults to false");
2122            });
2123        });
2124
2125        // Select the thinking model via select_model.
2126        let selector = connection.model_selector(&session_id).unwrap();
2127        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2128            .await
2129            .unwrap();
2130
2131        // select_model should have enabled thinking based on the model's supports_thinking().
2132        agent.read_with(cx, |agent, _| {
2133            let session = agent.sessions.get(&session_id).unwrap();
2134            session.thread.read_with(cx, |thread, _| {
2135                assert!(
2136                    thread.thinking_enabled(),
2137                    "select_model should enable thinking when model supports it"
2138                );
2139            });
2140        });
2141
2142        // Switch back to the non-thinking model.
2143        let selector = connection.model_selector(&session_id).unwrap();
2144        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2145            .await
2146            .unwrap();
2147
2148        // select_model should have disabled thinking.
2149        agent.read_with(cx, |agent, _| {
2150            let session = agent.sessions.get(&session_id).unwrap();
2151            session.thread.read_with(cx, |thread, _| {
2152                assert!(
2153                    !thread.thinking_enabled(),
2154                    "select_model should disable thinking when model does not support it"
2155                );
2156            });
2157        });
2158    }
2159
2160    #[gpui::test]
2161    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
2162        init_test(cx);
2163        let fs = FakeFs::new(cx.executor());
2164        fs.insert_tree("/", json!({ "a": {} })).await;
2165        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2166        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2167        let agent = NativeAgent::new(
2168            project.clone(),
2169            thread_store.clone(),
2170            Templates::new(),
2171            None,
2172            fs.clone(),
2173            &mut cx.to_async(),
2174        )
2175        .await
2176        .unwrap();
2177        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2178
2179        // Register a thinking model.
2180        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2181            "fake-corp",
2182            "fake-thinking",
2183            "Fake Thinking",
2184            true,
2185        ));
2186        let thinking_provider = Arc::new(
2187            FakeLanguageModelProvider::new(
2188                LanguageModelProviderId::from("fake-corp".to_string()),
2189                LanguageModelProviderName::from("Fake Corp".to_string()),
2190            )
2191            .with_models(vec![thinking_model.clone()]),
2192        );
2193        cx.update(|cx| {
2194            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2195                registry.register_provider(thinking_provider, cx);
2196            });
2197        });
2198        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2199
2200        // Create a thread and select the thinking model.
2201        let acp_thread = cx
2202            .update(|cx| {
2203                connection
2204                    .clone()
2205                    .new_session(project.clone(), Path::new("/a"), cx)
2206            })
2207            .await
2208            .unwrap();
2209        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2210
2211        let selector = connection.model_selector(&session_id).unwrap();
2212        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2213            .await
2214            .unwrap();
2215
2216        // Verify thinking is enabled after selecting the thinking model.
2217        let thread = agent.read_with(cx, |agent, _| {
2218            agent.sessions.get(&session_id).unwrap().thread.clone()
2219        });
2220        thread.read_with(cx, |thread, _| {
2221            assert!(
2222                thread.thinking_enabled(),
2223                "thinking should be enabled after selecting thinking model"
2224            );
2225        });
2226
2227        // Send a message so the thread gets persisted.
2228        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2229        let send = cx.foreground_executor().spawn(send);
2230        cx.run_until_parked();
2231
2232        thinking_model.send_last_completion_stream_text_chunk("Response.");
2233        thinking_model.end_last_completion_stream();
2234
2235        send.await.unwrap();
2236        cx.run_until_parked();
2237
2238        // Close the session so it can be reloaded from disk.
2239        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2240            .await
2241            .unwrap();
2242        drop(thread);
2243        drop(acp_thread);
2244        agent.read_with(cx, |agent, _| {
2245            assert!(agent.sessions.is_empty());
2246        });
2247
2248        // Reload the thread and verify thinking_enabled is still true.
2249        let reloaded_acp_thread = agent
2250            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2251            .await
2252            .unwrap();
2253        let reloaded_thread = agent.read_with(cx, |agent, _| {
2254            agent.sessions.get(&session_id).unwrap().thread.clone()
2255        });
2256        reloaded_thread.read_with(cx, |thread, _| {
2257            assert!(
2258                thread.thinking_enabled(),
2259                "thinking_enabled should be preserved when reloading a thread with a thinking model"
2260            );
2261        });
2262
2263        drop(reloaded_acp_thread);
2264    }
2265
2266    #[gpui::test]
2267    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
2268        init_test(cx);
2269        let fs = FakeFs::new(cx.executor());
2270        fs.insert_tree("/", json!({ "a": {} })).await;
2271        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2272        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2273        let agent = NativeAgent::new(
2274            project.clone(),
2275            thread_store.clone(),
2276            Templates::new(),
2277            None,
2278            fs.clone(),
2279            &mut cx.to_async(),
2280        )
2281        .await
2282        .unwrap();
2283        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2284
2285        // Register a model where id() != name(), like real Anthropic models
2286        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
2287        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2288            "fake-corp",
2289            "custom-model-id",
2290            "Custom Model Display Name",
2291            false,
2292        ));
2293        let provider = Arc::new(
2294            FakeLanguageModelProvider::new(
2295                LanguageModelProviderId::from("fake-corp".to_string()),
2296                LanguageModelProviderName::from("Fake Corp".to_string()),
2297            )
2298            .with_models(vec![model.clone()]),
2299        );
2300        cx.update(|cx| {
2301            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2302                registry.register_provider(provider, cx);
2303            });
2304        });
2305        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2306
2307        // Create a thread and select the model.
2308        let acp_thread = cx
2309            .update(|cx| {
2310                connection
2311                    .clone()
2312                    .new_session(project.clone(), Path::new("/a"), cx)
2313            })
2314            .await
2315            .unwrap();
2316        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2317
2318        let selector = connection.model_selector(&session_id).unwrap();
2319        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
2320            .await
2321            .unwrap();
2322
2323        let thread = agent.read_with(cx, |agent, _| {
2324            agent.sessions.get(&session_id).unwrap().thread.clone()
2325        });
2326        thread.read_with(cx, |thread, _| {
2327            assert_eq!(
2328                thread.model().unwrap().id().0.as_ref(),
2329                "custom-model-id",
2330                "model should be set before persisting"
2331            );
2332        });
2333
2334        // Send a message so the thread gets persisted.
2335        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2336        let send = cx.foreground_executor().spawn(send);
2337        cx.run_until_parked();
2338
2339        model.send_last_completion_stream_text_chunk("Response.");
2340        model.end_last_completion_stream();
2341
2342        send.await.unwrap();
2343        cx.run_until_parked();
2344
2345        // Close the session so it can be reloaded from disk.
2346        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2347            .await
2348            .unwrap();
2349        drop(thread);
2350        drop(acp_thread);
2351        agent.read_with(cx, |agent, _| {
2352            assert!(agent.sessions.is_empty());
2353        });
2354
2355        // Reload the thread and verify the model was preserved.
2356        let reloaded_acp_thread = agent
2357            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2358            .await
2359            .unwrap();
2360        let reloaded_thread = agent.read_with(cx, |agent, _| {
2361            agent.sessions.get(&session_id).unwrap().thread.clone()
2362        });
2363        reloaded_thread.read_with(cx, |thread, _| {
2364            let reloaded_model = thread
2365                .model()
2366                .expect("model should be present after reload");
2367            assert_eq!(
2368                reloaded_model.id().0.as_ref(),
2369                "custom-model-id",
2370                "reloaded thread should have the same model, not fall back to the default"
2371            );
2372        });
2373
2374        drop(reloaded_acp_thread);
2375    }
2376
    // Round-trip test for native-agent thread persistence:
    // - a thread that has only had its models changed is NOT saved to the store;
    // - after one user/assistant exchange the thread is saved with the
    //   summarization model's output as its title;
    // - after closing the session, reopening it by session id yields the same
    //   markdown transcript.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        // Single-file project; `b.md` is mentioned in the prompt sent below.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        // Grab the native `Thread` behind this session so its models can be
        // replaced with fakes that the test drives by hand.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a prompt that mentions `b.md` as a ResourceLink. The send future
        // is spawned so the fake models can be fed while it is still pending.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Finish the assistant reply first, then the summarization request. The
        // summary text is what later shows up as the stored title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        // The mention is expected to render as a markdown link to this file URI.
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the closed session; the agent must
        // no longer track any sessions.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        // Reopening by session id must reproduce the exact same transcript.
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
2515
2516    fn thread_entries(
2517        thread_store: &Entity<ThreadStore>,
2518        cx: &mut TestAppContext,
2519    ) -> Vec<(acp::SessionId, String)> {
2520        thread_store.read_with(cx, |store, _| {
2521            store
2522                .entries()
2523                .map(|entry| (entry.id.clone(), entry.title.to_string()))
2524                .collect::<Vec<_>>()
2525        })
2526    }
2527
2528    fn init_test(cx: &mut TestAppContext) {
2529        env_logger::try_init().ok();
2530        cx.update(|cx| {
2531            let settings_store = SettingsStore::test(cx);
2532            cx.set_global(settings_store);
2533
2534            LanguageModelRegistry::test(cx);
2535        });
2536    }
2537}
2538
2539fn mcp_message_content_to_acp_content_block(
2540    content: context_server::types::MessageContent,
2541) -> acp::ContentBlock {
2542    match content {
2543        context_server::types::MessageContent::Text {
2544            text,
2545            annotations: _,
2546        } => text.into(),
2547        context_server::types::MessageContent::Image {
2548            data,
2549            mime_type,
2550            annotations: _,
2551        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2552        context_server::types::MessageContent::Audio {
2553            data,
2554            mime_type,
2555            annotations: _,
2556        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2557        context_server::types::MessageContent::Resource {
2558            resource,
2559            annotations: _,
2560        } => {
2561            let mut link =
2562                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2563            if let Some(mime_type) = resource.mime_type {
2564                link = link.mime_type(mime_type);
2565            }
2566            acp::ContentBlock::ResourceLink(link)
2567        }
2568    }
2569}