agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17pub use native_agent_server::NativeAgentServer;
  18pub use pattern_extraction::*;
  19pub use templates::*;
  20pub use thread::*;
  21pub use thread_store::*;
  22pub use tool_permissions::*;
  23pub use tools::*;
  24
  25use acp_thread::{
  26    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  27    AgentSessionListResponse, UserMessageId,
  28};
  29use agent_client_protocol as acp;
  30use anyhow::{Context as _, Result, anyhow};
  31use chrono::{DateTime, Utc};
  32use collections::{HashMap, HashSet, IndexMap};
  33use fs::Fs;
  34use futures::channel::{mpsc, oneshot};
  35use futures::future::Shared;
  36use futures::{FutureExt as _, StreamExt as _, future};
  37use gpui::{
  38    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  39};
  40use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  41use project::{Project, ProjectItem, ProjectPath, Worktree};
  42use prompt_store::{
  43    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  44    WorktreeContext,
  45};
  46use serde::{Deserialize, Serialize};
  47use settings::{LanguageModelSelection, update_settings_file};
  48use std::any::Any;
  49use std::path::{Path, PathBuf};
  50use std::rc::Rc;
  51use std::sync::Arc;
  52use std::time::Duration;
  53use util::ResultExt;
  54use util::rel_path::RelPath;
  55
/// A serializable snapshot of the project's worktrees (as telemetry worktree
/// snapshots) paired with a UTC timestamp.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Timestamp associated with this snapshot, in UTC.
    pub timestamp: DateTime<Utc>,
}
  61
/// Describes a failure to load a worktree's rules file; produced by
/// `NativeAgent::load_worktree_info_for_system_prompt` when the rules-file
/// load task returns an error.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the error's `Display` output).
    pub message: SharedString,
}
  65
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Save task for this session; starts out as an already-completed task.
    /// NOTE(review): presumably replaced by `save_thread` (not visible here) with
    /// the in-flight save — confirm.
    pending_save: Task<()>,
    /// Subscriptions to the thread's title/token-usage events and to thread
    /// changes (which trigger a save); dropped together with the session.
    _subscriptions: Vec<Subscription>,
}
  75
/// Cache of the language models offered by authenticated, visible providers,
/// exposed both as an ID lookup table and as a grouped list for model pickers.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out by `watch()`; signalled after every list refresh.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender signalled at the end of every `refresh_list` call.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating every visible provider; held so it is
    /// not dropped while running.
    _authenticate_all_providers_task: Task<()>,
}
  85
  86impl LanguageModels {
  87    fn new(cx: &mut App) -> Self {
  88        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  89
  90        let mut this = Self {
  91            models: HashMap::default(),
  92            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  93            refresh_models_rx,
  94            refresh_models_tx,
  95            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  96        };
  97        this.refresh_list(cx);
  98        this
  99    }
 100
 101    fn refresh_list(&mut self, cx: &App) {
 102        let providers = LanguageModelRegistry::global(cx)
 103            .read(cx)
 104            .visible_providers()
 105            .into_iter()
 106            .filter(|provider| provider.is_authenticated(cx))
 107            .collect::<Vec<_>>();
 108
 109        let mut language_model_list = IndexMap::default();
 110        let mut recommended_models = HashSet::default();
 111
 112        let mut recommended = Vec::new();
 113        for provider in &providers {
 114            for model in provider.recommended_models(cx) {
 115                recommended_models.insert((model.provider_id(), model.id()));
 116                recommended.push(Self::map_language_model_to_info(&model, provider));
 117            }
 118        }
 119        if !recommended.is_empty() {
 120            language_model_list.insert(
 121                acp_thread::AgentModelGroupName("Recommended".into()),
 122                recommended,
 123            );
 124        }
 125
 126        let mut models = HashMap::default();
 127        for provider in providers {
 128            let mut provider_models = Vec::new();
 129            for model in provider.provided_models(cx) {
 130                let model_info = Self::map_language_model_to_info(&model, &provider);
 131                let model_id = model_info.id.clone();
 132                provider_models.push(model_info);
 133                models.insert(model_id, model);
 134            }
 135            if !provider_models.is_empty() {
 136                language_model_list.insert(
 137                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 138                    provider_models,
 139                );
 140            }
 141        }
 142
 143        self.models = models;
 144        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 145        self.refresh_models_tx.send(()).ok();
 146    }
 147
 148    fn watch(&self) -> watch::Receiver<()> {
 149        self.refresh_models_rx.clone()
 150    }
 151
 152    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 153        self.models.get(model_id).cloned()
 154    }
 155
 156    fn map_language_model_to_info(
 157        model: &Arc<dyn LanguageModel>,
 158        provider: &Arc<dyn LanguageModelProvider>,
 159    ) -> acp_thread::AgentModelInfo {
 160        acp_thread::AgentModelInfo {
 161            id: Self::model_id(model),
 162            name: model.name().0,
 163            description: None,
 164            icon: Some(match provider.icon() {
 165                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 166                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 167            }),
 168            is_latest: model.is_latest(),
 169        }
 170    }
 171
 172    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 173        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 174    }
 175
 176    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 177        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 178            .read(cx)
 179            .visible_providers()
 180            .iter()
 181            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 182            .collect::<Vec<_>>();
 183
 184        cx.background_spawn(async move {
 185            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 186                if let Err(err) = authenticate_task.await {
 187                    match err {
 188                        language_model::AuthenticateError::CredentialsNotFound => {
 189                            // Since we're authenticating these providers in the
 190                            // background for the purposes of populating the
 191                            // language selector, we don't care about providers
 192                            // where the credentials are not found.
 193                        }
 194                        language_model::AuthenticateError::ConnectionRefused => {
 195                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 196                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 197                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 198                        }
 199                        _ => {
 200                            // Some providers have noisy failure states that we
 201                            // don't want to spam the logs with every time the
 202                            // language model selector is initialized.
 203                            //
 204                            // Ideally these should have more clear failure modes
 205                            // that we know are safe to ignore here, like what we do
 206                            // with `CredentialsNotFound` above.
 207                            match provider_id.0.as_ref() {
 208                                "lmstudio" | "ollama" => {
 209                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 210                                    //
 211                                    // These fail noisily, so we don't log them.
 212                                }
 213                                "copilot_chat" => {
 214                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 215                                }
 216                                _ => {
 217                                    log::error!(
 218                                        "Failed to authenticate provider: {}: {err:#}",
 219                                        provider_name.0
 220                                    );
 221                                }
 222                            }
 223                        }
 224                    }
 225                }
 226            }
 227        })
 228    }
 229}
 230
/// The built-in agent: owns the sessions for a project, the shared project
/// context, the context-server registry and the cached language-model list.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store entity supplied at construction.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Sending on this channel asks `maintain_project_context` to rebuild the
    /// project context.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context`; held so it is not dropped.
    _maintain_project_context: Task<Result<()>>,
    /// Registry built from the project's context-server store; shared with every thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store; when present, its default prompts are loaded as user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle supplied at construction.
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, model-registry, context-server and prompt-store events.
    _subscriptions: Vec<Subscription>,
}
 249
 250impl NativeAgent {
    /// Creates the native agent for `project`.
    ///
    /// The initial [`ProjectContext`] is built and awaited before the entity is
    /// created. The new agent subscribes to the following events:
    /// - project events,
    /// - language-model registry events,
    /// - context-server store and registry events,
    /// - prompt-store events, when a prompt store is given.
    ///
    /// It also spawns a task that rebuilds the project context whenever a
    /// refresh is requested.
    pub async fn new(
        project: Entity<Project>,
        thread_store: Entity<ThreadStore>,
        templates: Arc<Templates>,
        prompt_store: Option<Entity<PromptStore>>,
        fs: Arc<dyn Fs>,
        cx: &mut AsyncApp,
    ) -> Result<Entity<NativeAgent>> {
        log::debug!("Creating new NativeAgent");

        let project_context = cx
            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
            .await;

        Ok(cx.new(|cx| {
            let context_server_store = project.read(cx).context_server_store();
            let context_server_registry =
                cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));

            let mut subscriptions = vec![
                cx.subscribe(&project, Self::handle_project_event),
                cx.subscribe(
                    &LanguageModelRegistry::global(cx),
                    Self::handle_models_updated_event,
                ),
                cx.subscribe(
                    &context_server_store,
                    Self::handle_context_server_store_updated,
                ),
                cx.subscribe(
                    &context_server_registry,
                    Self::handle_context_server_registry_event,
                ),
            ];
            if let Some(prompt_store) = prompt_store.as_ref() {
                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
            }

            // Any send on this channel wakes `maintain_project_context`, which
            // rebuilds the shared project context.
            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
                watch::channel(());
            Self {
                sessions: HashMap::default(),
                thread_store,
                project_context: cx.new(|_| project_context),
                project_context_needs_refresh: project_context_needs_refresh_tx,
                _maintain_project_context: cx.spawn(async move |this, cx| {
                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
                }),
                context_server_registry,
                templates,
                models: LanguageModels::new(cx),
                project,
                prompt_store,
                fs,
                _subscriptions: subscriptions,
            }
        }))
    }
 309
 310    fn new_session(
 311        &mut self,
 312        project: Entity<Project>,
 313        cx: &mut Context<Self>,
 314    ) -> Entity<AcpThread> {
 315        // Create Thread
 316        // Fetch default model from registry settings
 317        let registry = LanguageModelRegistry::read_global(cx);
 318        // Log available models for debugging
 319        let available_count = registry.available_models(cx).count();
 320        log::debug!("Total available models: {}", available_count);
 321
 322        let default_model = registry.default_model().and_then(|default_model| {
 323            self.models
 324                .model_from_id(&LanguageModels::model_id(&default_model.model))
 325        });
 326        let thread = cx.new(|cx| {
 327            Thread::new(
 328                project.clone(),
 329                self.project_context.clone(),
 330                self.context_server_registry.clone(),
 331                self.templates.clone(),
 332                default_model,
 333                cx,
 334            )
 335        });
 336
 337        self.register_session(thread, None, cx)
 338    }
 339
    /// Wraps `thread_handle` in a new [`AcpThread`] and records both as a session.
    ///
    /// Setup steps:
    /// - Configure the thread's summarization model from the registry.
    /// - Install the default tools, restricted to `allowed_tool_names` when `Some`.
    /// - Forward the thread's title and token-usage events to the ACP thread.
    /// - Save the thread whenever it notifies.
    ///
    /// Finally, pushes the current available-commands list to every session.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        allowed_tool_names: Option<Vec<&str>>,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy out everything the AcpThread needs before `cx` is borrowed mutably.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            acp_thread::AcpThread::new(
                parent_session_id,
                title,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            )
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        let weak = cx.weak_entity();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                allowed_tool_names,
                // The tools' environment only holds weak handles to the ACP thread and agent.
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    agent: weak,
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            // Any notification from the thread triggers a save.
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        // NOTE(review): `insert` replaces any existing session with the same ID.
        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                _subscriptions: subscriptions,
                pending_save: Task::ready(()),
            },
        );

        self.update_available_commands(cx);

        acp_thread
    }
 406
    /// Returns the agent's cached language-model list and ID lookup table.
    pub fn models(&self) -> &LanguageModels {
        &self.models
    }
 410
 411    async fn maintain_project_context(
 412        this: WeakEntity<Self>,
 413        mut needs_refresh: watch::Receiver<()>,
 414        cx: &mut AsyncApp,
 415    ) -> Result<()> {
 416        while needs_refresh.changed().await.is_ok() {
 417            let project_context = this
 418                .update(cx, |this, cx| {
 419                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 420                })?
 421                .await;
 422            this.update(cx, |this, cx| {
 423                this.project_context = cx.new(|_| project_context);
 424            })?;
 425        }
 426
 427        Ok(())
 428    }
 429
    /// Builds a [`ProjectContext`] asynchronously from two sources:
    /// - the project's visible worktrees, including each worktree's rules file when present;
    /// - the prompt store's default prompts, when a prompt store is given.
    ///
    /// Errors loading worktree rules files or default prompts are currently
    /// dropped (see the TODOs below); failed prompts are simply omitted.
    fn build_project_context(
        project: &Entity<Project>,
        prompt_store: Option<&Entity<PromptStore>>,
        cx: &mut App,
    ) -> Task<ProjectContext> {
        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
        let worktree_tasks = worktrees
            .into_iter()
            .map(|worktree| {
                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
            })
            .collect::<Vec<_>>();
        // Load the contents of every default prompt, keeping each result paired
        // with its metadata.
        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
            prompt_store.read_with(cx, |prompt_store, cx| {
                let prompts = prompt_store.default_prompt_metadata();
                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
                    let contents = prompt_store.load(prompt_metadata.id, cx);
                    async move { (contents.await, prompt_metadata) }
                });
                cx.background_spawn(future::join_all(load_tasks))
            })
        } else {
            Task::ready(vec![])
        };

        cx.spawn(async move |_cx| {
            // Worktree info and default prompts are loaded concurrently.
            let (worktrees, default_user_rules) =
                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;

            let worktrees = worktrees
                .into_iter()
                .map(|(worktree, _rules_error)| {
                    // TODO: show error message
                    // if let Some(rules_error) = rules_error {
                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
                    // }
                    worktree
                })
                .collect::<Vec<_>>();

            // Keep only prompts that loaded successfully and whose ID converts
            // via `as_user()`; everything else is skipped.
            let default_user_rules = default_user_rules
                .into_iter()
                .flat_map(|(contents, prompt_metadata)| match contents {
                    Ok(contents) => Some(UserRulesContext {
                        uuid: prompt_metadata.id.as_user()?,
                        title: prompt_metadata.title.map(|title| title.to_string()),
                        contents,
                    }),
                    Err(_err) => {
                        // TODO: show error message
                        // this.update(cx, |_, cx| {
                        //     cx.emit(RulesLoadingError {
                        //         message: format!("{err:?}").into(),
                        //     });
                        // })
                        // .ok();
                        None
                    }
                })
                .collect::<Vec<_>>();

            ProjectContext::new(worktrees, default_user_rules)
        })
    }
 494
 495    fn load_worktree_info_for_system_prompt(
 496        worktree: Entity<Worktree>,
 497        project: Entity<Project>,
 498        cx: &mut App,
 499    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 500        let tree = worktree.read(cx);
 501        let root_name = tree.root_name_str().into();
 502        let abs_path = tree.abs_path();
 503
 504        let mut context = WorktreeContext {
 505            root_name,
 506            abs_path,
 507            rules_file: None,
 508        };
 509
 510        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 511        let Some(rules_task) = rules_task else {
 512            return Task::ready((context, None));
 513        };
 514
 515        cx.spawn(async move |_| {
 516            let (rules_file, rules_file_error) = match rules_task.await {
 517                Ok(rules_file) => (Some(rules_file), None),
 518                Err(err) => (
 519                    None,
 520                    Some(RulesLoadingError {
 521                        message: format!("{err}").into(),
 522                    }),
 523                ),
 524            };
 525            context.rules_file = rules_file;
 526            (context, rules_file_error)
 527        })
 528    }
 529
 530    fn load_worktree_rules_file(
 531        worktree: Entity<Worktree>,
 532        project: Entity<Project>,
 533        cx: &mut App,
 534    ) -> Option<Task<Result<RulesFileContext>>> {
 535        let worktree = worktree.read(cx);
 536        let worktree_id = worktree.id();
 537        let selected_rules_file = RULES_FILE_NAMES
 538            .into_iter()
 539            .filter_map(|name| {
 540                worktree
 541                    .entry_for_path(RelPath::unix(name).unwrap())
 542                    .filter(|entry| entry.is_file())
 543                    .map(|entry| entry.path.clone())
 544            })
 545            .next();
 546
 547        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 548        // supported. This doesn't seem to occur often in GitHub repositories.
 549        selected_rules_file.map(|path_in_worktree| {
 550            let project_path = ProjectPath {
 551                worktree_id,
 552                path: path_in_worktree.clone(),
 553            };
 554            let buffer_task =
 555                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 556            let rope_task = cx.spawn(async move |cx| {
 557                let buffer = buffer_task.await?;
 558                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 559                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 560                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 561                })?;
 562                anyhow::Ok((project_entry_id, rope))
 563            });
 564            // Build a string from the rope on a background thread.
 565            cx.background_spawn(async move {
 566                let (project_entry_id, rope) = rope_task.await?;
 567                anyhow::Ok(RulesFileContext {
 568                    path_in_worktree,
 569                    text: rope.to_string().trim().to_string(),
 570                    project_entry_id: project_entry_id.to_usize(),
 571                })
 572            })
 573        })
 574    }
 575
 576    fn handle_thread_title_updated(
 577        &mut self,
 578        thread: Entity<Thread>,
 579        _: &TitleUpdated,
 580        cx: &mut Context<Self>,
 581    ) {
 582        let session_id = thread.read(cx).id();
 583        let Some(session) = self.sessions.get(session_id) else {
 584            return;
 585        };
 586        let thread = thread.downgrade();
 587        let acp_thread = session.acp_thread.downgrade();
 588        cx.spawn(async move |_, cx| {
 589            let title = thread.read_with(cx, |thread, _| thread.title())?;
 590            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 591            task.await
 592        })
 593        .detach_and_log_err(cx);
 594    }
 595
 596    fn handle_thread_token_usage_updated(
 597        &mut self,
 598        thread: Entity<Thread>,
 599        usage: &TokenUsageUpdated,
 600        cx: &mut Context<Self>,
 601    ) {
 602        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 603            return;
 604        };
 605        session.acp_thread.update(cx, |acp_thread, cx| {
 606            acp_thread.update_token_usage(usage.0.clone(), cx);
 607        });
 608    }
 609
 610    fn handle_project_event(
 611        &mut self,
 612        _project: Entity<Project>,
 613        event: &project::Event,
 614        _cx: &mut Context<Self>,
 615    ) {
 616        match event {
 617            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 618                self.project_context_needs_refresh.send(()).ok();
 619            }
 620            project::Event::WorktreeUpdatedEntries(_, items) => {
 621                if items.iter().any(|(path, _, _)| {
 622                    RULES_FILE_NAMES
 623                        .iter()
 624                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 625                }) {
 626                    self.project_context_needs_refresh.send(()).ok();
 627                }
 628            }
 629            _ => {}
 630        }
 631    }
 632
 633    fn handle_prompts_updated_event(
 634        &mut self,
 635        _prompt_store: Entity<PromptStore>,
 636        _event: &prompt_store::PromptsUpdatedEvent,
 637        _cx: &mut Context<Self>,
 638    ) {
 639        self.project_context_needs_refresh.send(()).ok();
 640    }
 641
 642    fn handle_models_updated_event(
 643        &mut self,
 644        _registry: Entity<LanguageModelRegistry>,
 645        _event: &language_model::Event,
 646        cx: &mut Context<Self>,
 647    ) {
 648        self.models.refresh_list(cx);
 649
 650        let registry = LanguageModelRegistry::read_global(cx);
 651        let default_model = registry.default_model().map(|m| m.model);
 652        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 653
 654        for session in self.sessions.values_mut() {
 655            session.thread.update(cx, |thread, cx| {
 656                if thread.model().is_none()
 657                    && let Some(model) = default_model.clone()
 658                {
 659                    thread.set_model(model, cx);
 660                    cx.notify();
 661                }
 662                thread.set_summarization_model(summarization_model.clone(), cx);
 663            });
 664        }
 665    }
 666
    /// A context server's status changed; republish the available slash
    /// commands to every session.
    fn handle_context_server_store_updated(
        &mut self,
        _store: Entity<project::context_server_store::ContextServerStore>,
        _event: &project::context_server_store::ServerStatusChangedEvent,
        cx: &mut Context<Self>,
    ) {
        self.update_available_commands(cx);
    }
 675
    /// Republishes the available slash commands when the registry's prompts
    /// change. Tool changes need no action here.
    fn handle_context_server_registry_event(
        &mut self,
        _registry: Entity<ContextServerRegistry>,
        event: &ContextServerRegistryEvent,
        cx: &mut Context<Self>,
    ) {
        // Exhaustive on purpose: a new event variant should force a decision here.
        match event {
            ContextServerRegistryEvent::ToolsChanged => {}
            ContextServerRegistryEvent::PromptsChanged => {
                self.update_available_commands(cx);
            }
        }
    }
 689
 690    fn update_available_commands(&self, cx: &mut Context<Self>) {
 691        let available_commands = self.build_available_commands(cx);
 692        for session in self.sessions.values() {
 693            session.acp_thread.update(cx, |thread, cx| {
 694                thread
 695                    .handle_session_update(
 696                        acp::SessionUpdate::AvailableCommandsUpdate(
 697                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 698                        ),
 699                        cx,
 700                    )
 701                    .log_err();
 702            });
 703        }
 704    }
 705
 706    fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
 707        let registry = self.context_server_registry.read(cx);
 708
 709        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 710        for context_server_prompt in registry.prompts() {
 711            *prompt_name_counts
 712                .entry(context_server_prompt.prompt.name.as_str())
 713                .or_insert(0) += 1;
 714        }
 715
 716        registry
 717            .prompts()
 718            .flat_map(|context_server_prompt| {
 719                let prompt = &context_server_prompt.prompt;
 720
 721                let should_prefix = prompt_name_counts
 722                    .get(prompt.name.as_str())
 723                    .copied()
 724                    .unwrap_or(0)
 725                    > 1;
 726
 727                let name = if should_prefix {
 728                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 729                } else {
 730                    prompt.name.clone()
 731                };
 732
 733                let mut command = acp::AvailableCommand::new(
 734                    name,
 735                    prompt.description.clone().unwrap_or_default(),
 736                );
 737
 738                match prompt.arguments.as_deref() {
 739                    Some([arg]) => {
 740                        let hint = format!("<{}>", arg.name);
 741
 742                        command = command.input(acp::AvailableCommandInput::Unstructured(
 743                            acp::UnstructuredCommandInput::new(hint),
 744                        ));
 745                    }
 746                    Some([]) | None => {}
 747                    Some(_) => {
 748                        // skip >1 argument commands since we don't support them yet
 749                        return None;
 750                    }
 751                }
 752
 753                Some(command)
 754            })
 755            .collect()
 756    }
 757
    /// Loads the thread with `id` from the threads database and constructs a
    /// [`Thread`] entity for it, with the registry's current summarization model.
    ///
    /// The thread is NOT registered as a session; see `open_thread` for that.
    ///
    /// # Errors
    /// Any of the following makes the returned task fail:
    /// - the database connection or the lookup fails;
    /// - no thread with `id` exists;
    /// - `this.update` fails (e.g. the agent was released).
    pub fn load_thread(
        &mut self,
        id: acp::SessionId,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Thread>>> {
        let database_future = ThreadsDatabase::connect(cx);
        cx.spawn(async move |this, cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            let db_thread = database
                .load_thread(id.clone())
                .await?
                .with_context(|| format!("no thread found with ID: {id:?}"))?;

            this.update(cx, |this, cx| {
                let summarization_model = LanguageModelRegistry::read_global(cx)
                    .thread_summary_model()
                    .map(|c| c.model);

                cx.new(|cx| {
                    let mut thread = Thread::from_db(
                        id.clone(),
                        db_thread,
                        this.project.clone(),
                        this.project_context.clone(),
                        this.context_server_registry.clone(),
                        this.templates.clone(),
                        cx,
                    );
                    thread.set_summarization_model(summarization_model, cx);
                    thread
                })
            })
        })
    }
 792
 793    pub fn open_thread(
 794        &mut self,
 795        id: acp::SessionId,
 796        cx: &mut Context<Self>,
 797    ) -> Task<Result<Entity<AcpThread>>> {
 798        if let Some(session) = self.sessions.get(&id) {
 799            return Task::ready(Ok(session.acp_thread.clone()));
 800        }
 801
 802        let task = self.load_thread(id, cx);
 803        cx.spawn(async move |this, cx| {
 804            let thread = task.await?;
 805            let acp_thread = this.update(cx, |this, cx| {
 806                this.register_session(thread.clone(), None, cx)
 807            })?;
 808            let events = thread.update(cx, |thread, cx| thread.replay(cx));
 809            cx.update(|cx| {
 810                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 811            })
 812            .await?;
 813            Ok(acp_thread)
 814        })
 815    }
 816
 817    pub fn thread_summary(
 818        &mut self,
 819        id: acp::SessionId,
 820        cx: &mut Context<Self>,
 821    ) -> Task<Result<SharedString>> {
 822        let thread = self.open_thread(id.clone(), cx);
 823        cx.spawn(async move |this, cx| {
 824            let acp_thread = thread.await?;
 825            let result = this
 826                .update(cx, |this, cx| {
 827                    this.sessions
 828                        .get(&id)
 829                        .unwrap()
 830                        .thread
 831                        .update(cx, |thread, cx| thread.summary(cx))
 832                })?
 833                .await
 834                .context("Failed to generate summary")?;
 835            drop(acp_thread);
 836            Ok(result)
 837        })
 838    }
 839
 840    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 841        if thread.read(cx).is_empty() {
 842            return;
 843        }
 844
 845        let database_future = ThreadsDatabase::connect(cx);
 846        let (id, db_thread) =
 847            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
 848        let Some(session) = self.sessions.get_mut(&id) else {
 849            return;
 850        };
 851        let thread_store = self.thread_store.clone();
 852        session.pending_save = cx.spawn(async move |_, cx| {
 853            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 854                return;
 855            };
 856            let db_thread = db_thread.await;
 857            database.save_thread(id, db_thread).await.log_err();
 858            thread_store.update(cx, |store, cx| store.reload(cx));
 859        });
 860    }
 861
    /// Runs the MCP prompt `prompt_name` from context server `server_id` as a
    /// turn in session `session_id`.
    ///
    /// The steps are:
    /// 1. Fetch the prompt (with `arguments`) via `crate::get_prompt`.
    /// 2. Record `original_content` on the session's `Thread` under
    ///    `message_id`, skipping its first block (the slash-command text).
    /// 3. Mirror each prompt message into both the `AcpThread` and the
    ///    `Thread`.
    /// 4. Call `send_existing` if the last prompt message was from the user (or
    ///    there were none); otherwise call `resume`.
    /// 5. Forward the resulting events to the `AcpThread`.
    ///
    /// # Errors
    /// Fails if the prompt cannot be fetched, the agent was released, the
    /// session is not registered, or starting/streaming the turn fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: agent_client_protocol::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let server_store = self.context_server_registry.read(cx).server_store().clone();
        let path_style = self.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Whether the most recent prompt message had the user role. Starts as
            // true so that a prompt with no messages still goes through `send_existing`.
            let mut last_is_user = true;

            // Record the user's content, minus the leading command block, on the Thread.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                // Each message is pushed to the AcpThread first, then to the Thread.
                match role {
                    context_server::types::Role::User => {
                        // Every user-role prompt message gets its own fresh message id.
                        let id = acp_thread::UserMessageId::new();

                        // NOTE(review): the `true` flag presumably requests indentation
                        // (per the `_with_indent` name) — confirm against AcpThread.
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        // NOTE(review): `false, true` presumably mean "not a thought" and
                        // "indented" (compare `push_assistant_content_block(_, true, _)` for
                        // thinking text in `handle_thread_events`) — confirm.
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            // Stream the turn's events into the AcpThread and resolve with its outcome.
            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
 957}
 958
 959/// Wrapper struct that implements the AgentConnection trait
 960#[derive(Clone)]
 961pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 962
 963impl NativeAgentConnection {
 964    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 965        self.0
 966            .read(cx)
 967            .sessions
 968            .get(session_id)
 969            .map(|session| session.thread.clone())
 970    }
 971
 972    pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
 973        self.0.update(cx, |this, cx| this.load_thread(id, cx))
 974    }
 975
 976    fn run_turn(
 977        &self,
 978        session_id: acp::SessionId,
 979        cx: &mut App,
 980        f: impl 'static
 981        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 982    ) -> Task<Result<acp::PromptResponse>> {
 983        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 984            agent
 985                .sessions
 986                .get_mut(&session_id)
 987                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 988        }) else {
 989            return Task::ready(Err(anyhow!("Session not found")));
 990        };
 991        log::debug!("Found session for: {}", session_id);
 992
 993        let response_stream = match f(thread, cx) {
 994            Ok(stream) => stream,
 995            Err(err) => return Task::ready(Err(err)),
 996        };
 997        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
 998    }
 999
1000    fn handle_thread_events(
1001        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1002        acp_thread: WeakEntity<AcpThread>,
1003        cx: &App,
1004    ) -> Task<Result<acp::PromptResponse>> {
1005        cx.spawn(async move |cx| {
1006            // Handle response stream and forward to session.acp_thread
1007            while let Some(result) = events.next().await {
1008                match result {
1009                    Ok(event) => {
1010                        log::trace!("Received completion event: {:?}", event);
1011
1012                        match event {
1013                            ThreadEvent::UserMessage(message) => {
1014                                acp_thread.update(cx, |thread, cx| {
1015                                    for content in message.content {
1016                                        thread.push_user_content_block(
1017                                            Some(message.id.clone()),
1018                                            content.into(),
1019                                            cx,
1020                                        );
1021                                    }
1022                                })?;
1023                            }
1024                            ThreadEvent::AgentText(text) => {
1025                                acp_thread.update(cx, |thread, cx| {
1026                                    thread.push_assistant_content_block(text.into(), false, cx)
1027                                })?;
1028                            }
1029                            ThreadEvent::AgentThinking(text) => {
1030                                acp_thread.update(cx, |thread, cx| {
1031                                    thread.push_assistant_content_block(text.into(), true, cx)
1032                                })?;
1033                            }
1034                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1035                                tool_call,
1036                                options,
1037                                response,
1038                                context: _,
1039                            }) => {
1040                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1041                                    thread.request_tool_call_authorization(
1042                                        tool_call, options, true, cx,
1043                                    )
1044                                })??;
1045                                cx.background_spawn(async move {
1046                                    if let acp::RequestPermissionOutcome::Selected(
1047                                        acp::SelectedPermissionOutcome { option_id, .. },
1048                                    ) = outcome_task.await
1049                                    {
1050                                        response
1051                                            .send(option_id)
1052                                            .map(|_| anyhow!("authorization receiver was dropped"))
1053                                            .log_err();
1054                                    }
1055                                })
1056                                .detach();
1057                            }
1058                            ThreadEvent::ToolCall(tool_call) => {
1059                                acp_thread.update(cx, |thread, cx| {
1060                                    thread.upsert_tool_call(tool_call, cx)
1061                                })??;
1062                            }
1063                            ThreadEvent::ToolCallUpdate(update) => {
1064                                acp_thread.update(cx, |thread, cx| {
1065                                    thread.update_tool_call(update, cx)
1066                                })??;
1067                            }
1068                            ThreadEvent::SubagentSpawned(session_id) => {
1069                                acp_thread.update(cx, |thread, cx| {
1070                                    thread.subagent_spawned(session_id, cx);
1071                                })?;
1072                            }
1073                            ThreadEvent::Retry(status) => {
1074                                acp_thread.update(cx, |thread, cx| {
1075                                    thread.update_retry_status(status, cx)
1076                                })?;
1077                            }
1078                            ThreadEvent::Stop(stop_reason) => {
1079                                log::debug!("Assistant message complete: {:?}", stop_reason);
1080                                return Ok(acp::PromptResponse::new(stop_reason));
1081                            }
1082                        }
1083                    }
1084                    Err(e) => {
1085                        log::error!("Error in model response stream: {:?}", e);
1086                        return Err(e);
1087                    }
1088                }
1089            }
1090
1091            log::debug!("Response stream completed");
1092            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1093        })
1094    }
1095}
1096
1097struct Command<'a> {
1098    prompt_name: &'a str,
1099    arg_value: &'a str,
1100    explicit_server_id: Option<&'a str>,
1101}
1102
1103impl<'a> Command<'a> {
1104    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1105        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1106            return None;
1107        };
1108        let text = text_content.text.trim();
1109        let command = text.strip_prefix('/')?;
1110        let (command, arg_value) = command
1111            .split_once(char::is_whitespace)
1112            .unwrap_or((command, ""));
1113
1114        if let Some((server_id, prompt_name)) = command.split_once('.') {
1115            Some(Self {
1116                prompt_name,
1117                arg_value,
1118                explicit_server_id: Some(server_id),
1119            })
1120        } else {
1121            Some(Self {
1122                prompt_name: command,
1123                arg_value,
1124                explicit_server_id: None,
1125            })
1126        }
1127    }
1128}
1129
1130struct NativeAgentModelSelector {
1131    session_id: acp::SessionId,
1132    connection: NativeAgentConnection,
1133}
1134
1135impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1136    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1137        log::debug!("NativeAgentConnection::list_models called");
1138        let list = self.connection.0.read(cx).models.model_list.clone();
1139        Task::ready(if list.is_empty() {
1140            Err(anyhow::anyhow!("No models available"))
1141        } else {
1142            Ok(list)
1143        })
1144    }
1145
1146    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1147        log::debug!(
1148            "Setting model for session {}: {}",
1149            self.session_id,
1150            model_id
1151        );
1152        let Some(thread) = self
1153            .connection
1154            .0
1155            .read(cx)
1156            .sessions
1157            .get(&self.session_id)
1158            .map(|session| session.thread.clone())
1159        else {
1160            return Task::ready(Err(anyhow!("Session not found")));
1161        };
1162
1163        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1164            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1165        };
1166
1167        // We want to reset the effort level when switching models, as the currently-selected effort level may
1168        // not be compatible.
1169        let effort = model
1170            .default_effort_level()
1171            .map(|effort_level| effort_level.value.to_string());
1172
1173        thread.update(cx, |thread, cx| {
1174            thread.set_model(model.clone(), cx);
1175            thread.set_thinking_effort(effort.clone(), cx);
1176            thread.set_thinking_enabled(model.supports_thinking(), cx);
1177        });
1178
1179        update_settings_file(
1180            self.connection.0.read(cx).fs.clone(),
1181            cx,
1182            move |settings, cx| {
1183                let provider = model.provider_id().0.to_string();
1184                let model = model.id().0.to_string();
1185                let enable_thinking = thread.read(cx).thinking_enabled();
1186                settings
1187                    .agent
1188                    .get_or_insert_default()
1189                    .set_model(LanguageModelSelection {
1190                        provider: provider.into(),
1191                        model,
1192                        enable_thinking,
1193                        effort,
1194                    });
1195            },
1196        );
1197
1198        Task::ready(Ok(()))
1199    }
1200
1201    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1202        let Some(thread) = self
1203            .connection
1204            .0
1205            .read(cx)
1206            .sessions
1207            .get(&self.session_id)
1208            .map(|session| session.thread.clone())
1209        else {
1210            return Task::ready(Err(anyhow!("Session not found")));
1211        };
1212        let Some(model) = thread.read(cx).model() else {
1213            return Task::ready(Err(anyhow!("Model not found")));
1214        };
1215        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1216        else {
1217            return Task::ready(Err(anyhow!("Provider not found")));
1218        };
1219        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1220            model, &provider,
1221        )))
1222    }
1223
1224    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1225        Some(self.connection.0.read(cx).models.watch())
1226    }
1227
1228    fn should_render_footer(&self) -> bool {
1229        true
1230    }
1231}
1232
1233impl acp_thread::AgentConnection for NativeAgentConnection {
1234    fn telemetry_id(&self) -> SharedString {
1235        "zed".into()
1236    }
1237
1238    fn new_session(
1239        self: Rc<Self>,
1240        project: Entity<Project>,
1241        cwd: &Path,
1242        cx: &mut App,
1243    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1244        log::debug!("Creating new thread for project at: {cwd:?}");
1245        Task::ready(Ok(self
1246            .0
1247            .update(cx, |agent, cx| agent.new_session(project, cx))))
1248    }
1249
1250    fn supports_load_session(&self, _cx: &App) -> bool {
1251        true
1252    }
1253
1254    fn load_session(
1255        self: Rc<Self>,
1256        session: AgentSessionInfo,
1257        _project: Entity<Project>,
1258        _cwd: &Path,
1259        cx: &mut App,
1260    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1261        self.0
1262            .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1263    }
1264
1265    fn supports_close_session(&self, _cx: &App) -> bool {
1266        true
1267    }
1268
1269    fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1270        self.0.update(cx, |agent, _cx| {
1271            agent.sessions.remove(session_id);
1272        });
1273        Task::ready(Ok(()))
1274    }
1275
1276    fn auth_methods(&self) -> &[acp::AuthMethod] {
1277        &[] // No auth for in-process
1278    }
1279
1280    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1281        Task::ready(Ok(()))
1282    }
1283
1284    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1285        Some(Rc::new(NativeAgentModelSelector {
1286            session_id: session_id.clone(),
1287            connection: self.clone(),
1288        }) as Rc<dyn AgentModelSelector>)
1289    }
1290
1291    fn prompt(
1292        &self,
1293        id: Option<acp_thread::UserMessageId>,
1294        params: acp::PromptRequest,
1295        cx: &mut App,
1296    ) -> Task<Result<acp::PromptResponse>> {
1297        let id = id.expect("UserMessageId is required");
1298        let session_id = params.session_id.clone();
1299        log::info!("Received prompt request for session: {}", session_id);
1300        log::debug!("Prompt blocks count: {}", params.prompt.len());
1301
1302        if let Some(parsed_command) = Command::parse(&params.prompt) {
1303            let registry = self.0.read(cx).context_server_registry.read(cx);
1304
1305            let explicit_server_id = parsed_command
1306                .explicit_server_id
1307                .map(|server_id| ContextServerId(server_id.into()));
1308
1309            if let Some(prompt) =
1310                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1311            {
1312                let arguments = if !parsed_command.arg_value.is_empty()
1313                    && let Some(arg_name) = prompt
1314                        .prompt
1315                        .arguments
1316                        .as_ref()
1317                        .and_then(|args| args.first())
1318                        .map(|arg| arg.name.clone())
1319                {
1320                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1321                } else {
1322                    Default::default()
1323                };
1324
1325                let prompt_name = prompt.prompt.name.clone();
1326                let server_id = prompt.server_id.clone();
1327
1328                return self.0.update(cx, |agent, cx| {
1329                    agent.send_mcp_prompt(
1330                        id,
1331                        session_id.clone(),
1332                        prompt_name,
1333                        server_id,
1334                        arguments,
1335                        params.prompt,
1336                        cx,
1337                    )
1338                });
1339            };
1340        };
1341
1342        let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1343
1344        self.run_turn(session_id, cx, move |thread, cx| {
1345            let content: Vec<UserMessageContent> = params
1346                .prompt
1347                .into_iter()
1348                .map(|block| UserMessageContent::from_content_block(block, path_style))
1349                .collect::<Vec<_>>();
1350            log::debug!("Converted prompt to message: {} chars", content.len());
1351            log::debug!("Message id: {:?}", id);
1352            log::debug!("Message content: {:?}", content);
1353
1354            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1355        })
1356    }
1357
1358    fn retry(
1359        &self,
1360        session_id: &acp::SessionId,
1361        _cx: &App,
1362    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1363        Some(Rc::new(NativeAgentSessionRetry {
1364            connection: self.clone(),
1365            session_id: session_id.clone(),
1366        }) as _)
1367    }
1368
1369    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1370        log::info!("Cancelling on session: {}", session_id);
1371        self.0.update(cx, |agent, cx| {
1372            if let Some(agent) = agent.sessions.get(session_id) {
1373                agent
1374                    .thread
1375                    .update(cx, |thread, cx| thread.cancel(cx))
1376                    .detach();
1377            }
1378        });
1379    }
1380
1381    fn truncate(
1382        &self,
1383        session_id: &agent_client_protocol::SessionId,
1384        cx: &App,
1385    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1386        self.0.read_with(cx, |agent, _cx| {
1387            agent.sessions.get(session_id).map(|session| {
1388                Rc::new(NativeAgentSessionTruncate {
1389                    thread: session.thread.clone(),
1390                    acp_thread: session.acp_thread.downgrade(),
1391                }) as _
1392            })
1393        })
1394    }
1395
1396    fn set_title(
1397        &self,
1398        session_id: &acp::SessionId,
1399        _cx: &App,
1400    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1401        Some(Rc::new(NativeAgentSessionSetTitle {
1402            connection: self.clone(),
1403            session_id: session_id.clone(),
1404        }) as _)
1405    }
1406
1407    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1408        let thread_store = self.0.read(cx).thread_store.clone();
1409        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1410    }
1411
1412    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1413        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1414    }
1415
1416    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1417        self
1418    }
1419}
1420
1421impl acp_thread::AgentTelemetry for NativeAgentConnection {
1422    fn thread_data(
1423        &self,
1424        session_id: &acp::SessionId,
1425        cx: &mut App,
1426    ) -> Task<Result<serde_json::Value>> {
1427        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1428            return Task::ready(Err(anyhow!("Session not found")));
1429        };
1430
1431        let task = session.thread.read(cx).to_db(cx);
1432        cx.background_spawn(async move {
1433            serde_json::to_value(task.await).context("Failed to serialize thread")
1434        })
1435    }
1436}
1437
1438pub struct NativeAgentSessionList {
1439    thread_store: Entity<ThreadStore>,
1440    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1441    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1442    _subscription: Subscription,
1443}
1444
1445impl NativeAgentSessionList {
1446    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1447        let (tx, rx) = smol::channel::unbounded();
1448        let this_tx = tx.clone();
1449        let subscription = cx.observe(&thread_store, move |_, _| {
1450            this_tx
1451                .try_send(acp_thread::SessionListUpdate::Refresh)
1452                .ok();
1453        });
1454        Self {
1455            thread_store,
1456            updates_tx: tx,
1457            updates_rx: rx,
1458            _subscription: subscription,
1459        }
1460    }
1461
1462    fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1463        AgentSessionInfo {
1464            session_id: entry.id,
1465            cwd: None,
1466            title: Some(entry.title),
1467            updated_at: Some(entry.updated_at),
1468            meta: None,
1469        }
1470    }
1471
1472    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1473        &self.thread_store
1474    }
1475}
1476
1477impl AgentSessionList for NativeAgentSessionList {
1478    fn list_sessions(
1479        &self,
1480        _request: AgentSessionListRequest,
1481        cx: &mut App,
1482    ) -> Task<Result<AgentSessionListResponse>> {
1483        let sessions = self
1484            .thread_store
1485            .read(cx)
1486            .entries()
1487            .map(Self::to_session_info)
1488            .collect();
1489        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1490    }
1491
1492    fn supports_delete(&self) -> bool {
1493        true
1494    }
1495
1496    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1497        self.thread_store
1498            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1499    }
1500
1501    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1502        self.thread_store
1503            .update(cx, |store, cx| store.delete_threads(cx))
1504    }
1505
1506    fn watch(
1507        &self,
1508        _cx: &mut App,
1509    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1510        Some(self.updates_rx.clone())
1511    }
1512
1513    fn notify_refresh(&self) {
1514        self.updates_tx
1515            .try_send(acp_thread::SessionListUpdate::Refresh)
1516            .ok();
1517    }
1518
1519    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1520        self
1521    }
1522}
1523
1524struct NativeAgentSessionTruncate {
1525    thread: Entity<Thread>,
1526    acp_thread: WeakEntity<AcpThread>,
1527}
1528
1529impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1530    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1531        match self.thread.update(cx, |thread, cx| {
1532            thread.truncate(message_id.clone(), cx)?;
1533            Ok(thread.latest_token_usage())
1534        }) {
1535            Ok(usage) => {
1536                self.acp_thread
1537                    .update(cx, |thread, cx| {
1538                        thread.update_token_usage(usage, cx);
1539                    })
1540                    .ok();
1541                Task::ready(Ok(()))
1542            }
1543            Err(error) => Task::ready(Err(error)),
1544        }
1545    }
1546}
1547
1548struct NativeAgentSessionRetry {
1549    connection: NativeAgentConnection,
1550    session_id: acp::SessionId,
1551}
1552
1553impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1554    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1555        self.connection
1556            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1557                thread.update(cx, |thread, cx| thread.resume(cx))
1558            })
1559    }
1560}
1561
1562struct NativeAgentSessionSetTitle {
1563    connection: NativeAgentConnection,
1564    session_id: acp::SessionId,
1565}
1566
1567impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1568    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1569        let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1570            return Task::ready(Err(anyhow!("session not found")));
1571        };
1572        let thread = session.thread.clone();
1573        thread.update(cx, |thread, cx| thread.set_title(title, cx));
1574        Task::ready(Ok(()))
1575    }
1576}
1577
/// Holds weak handles to a `NativeAgent` and an `AcpThread`.
///
/// NOTE(review): judging by the name and the `create_subagent_thread`
/// constructor helper, this is the environment a `Thread` uses to reach its
/// agent and UI-side thread. Confirm against its trait impl, which is outside
/// this view.
pub struct NativeThreadEnvironment {
    // Weak, so this environment does not keep the agent alive.
    agent: WeakEntity<NativeAgent>,
    // Weak, so this environment does not keep the AcpThread alive.
    acp_thread: WeakEntity<AcpThread>,
}
1582
1583impl NativeThreadEnvironment {
1584    pub(crate) fn create_subagent_thread(
1585        agent: WeakEntity<NativeAgent>,
1586        parent_thread_entity: Entity<Thread>,
1587        label: String,
1588        initial_prompt: String,
1589        timeout: Option<Duration>,
1590        allowed_tools: Option<Vec<String>>,
1591        cx: &mut App,
1592    ) -> Result<Rc<dyn SubagentHandle>> {
1593        let parent_thread = parent_thread_entity.read(cx);
1594        let current_depth = parent_thread.depth();
1595
1596        if current_depth >= MAX_SUBAGENT_DEPTH {
1597            return Err(anyhow!(
1598                "Maximum subagent depth ({}) reached",
1599                MAX_SUBAGENT_DEPTH
1600            ));
1601        }
1602
1603        let running_count = parent_thread.running_subagent_count();
1604        if running_count >= MAX_PARALLEL_SUBAGENTS {
1605            return Err(anyhow!(
1606                "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.",
1607                MAX_PARALLEL_SUBAGENTS
1608            ));
1609        }
1610
1611        let allowed_tools = match allowed_tools {
1612            Some(tools) => {
1613                let parent_tool_names: std::collections::HashSet<&str> =
1614                    parent_thread.tools.keys().map(|s| s.as_str()).collect();
1615                Some(
1616                    tools
1617                        .into_iter()
1618                        .filter(|t| parent_tool_names.contains(t.as_str()))
1619                        .collect::<Vec<_>>(),
1620                )
1621            }
1622            None => Some(parent_thread.tools.keys().map(|s| s.to_string()).collect()),
1623        };
1624
1625        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1626            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1627            thread.set_title(label.into(), cx);
1628            thread
1629        });
1630
1631        let session_id = subagent_thread.read(cx).id().clone();
1632
1633        let acp_thread = agent.update(cx, |agent, cx| {
1634            agent.register_session(
1635                subagent_thread.clone(),
1636                allowed_tools
1637                    .as_ref()
1638                    .map(|v| v.iter().map(|s| s.as_str()).collect()),
1639                cx,
1640            )
1641        })?;
1642
1643        parent_thread_entity.update(cx, |parent_thread, _cx| {
1644            parent_thread.register_running_subagent(subagent_thread.downgrade())
1645        });
1646
1647        let task = acp_thread.update(cx, |agent, cx| agent.send(vec![initial_prompt.into()], cx));
1648
1649        let timeout_timer = timeout.map(|d| cx.background_executor().timer(d));
1650        let wait_for_prompt_to_complete = cx
1651            .background_spawn(async move {
1652                if let Some(timer) = timeout_timer {
1653                    futures::select! {
1654                        _ = timer.fuse() => SubagentInitialPromptResult::Timeout,
1655                        _ = task.fuse() => SubagentInitialPromptResult::Completed,
1656                    }
1657                } else {
1658                    task.await.log_err();
1659                    SubagentInitialPromptResult::Completed
1660                }
1661            })
1662            .shared();
1663
1664        let mut user_stop_rx: watch::Receiver<bool> =
1665            acp_thread.update(cx, |thread, _| thread.user_stop_receiver());
1666
1667        let user_cancelled = cx
1668            .background_spawn(async move {
1669                loop {
1670                    if *user_stop_rx.borrow() {
1671                        return;
1672                    }
1673                    if user_stop_rx.changed().await.is_err() {
1674                        std::future::pending::<()>().await;
1675                    }
1676                }
1677            })
1678            .shared();
1679
1680        Ok(Rc::new(NativeSubagentHandle {
1681            session_id,
1682            subagent_thread,
1683            parent_thread: parent_thread_entity.downgrade(),
1684            acp_thread,
1685            wait_for_prompt_to_complete,
1686            user_cancelled,
1687        }) as _)
1688    }
1689}
1690
1691impl ThreadEnvironment for NativeThreadEnvironment {
1692    fn create_terminal(
1693        &self,
1694        command: String,
1695        cwd: Option<PathBuf>,
1696        output_byte_limit: Option<u64>,
1697        cx: &mut AsyncApp,
1698    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1699        let task = self.acp_thread.update(cx, |thread, cx| {
1700            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1701        });
1702
1703        let acp_thread = self.acp_thread.clone();
1704        cx.spawn(async move |cx| {
1705            let terminal = task?.await?;
1706
1707            let (drop_tx, drop_rx) = oneshot::channel();
1708            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1709
1710            cx.spawn(async move |cx| {
1711                drop_rx.await.ok();
1712                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1713            })
1714            .detach();
1715
1716            let handle = AcpTerminalHandle {
1717                terminal,
1718                _drop_tx: Some(drop_tx),
1719            };
1720
1721            Ok(Rc::new(handle) as _)
1722        })
1723    }
1724
1725    fn create_subagent(
1726        &self,
1727        parent_thread_entity: Entity<Thread>,
1728        label: String,
1729        initial_prompt: String,
1730        timeout: Option<Duration>,
1731        allowed_tools: Option<Vec<String>>,
1732        cx: &mut App,
1733    ) -> Result<Rc<dyn SubagentHandle>> {
1734        Self::create_subagent_thread(
1735            self.agent.clone(),
1736            parent_thread_entity,
1737            label,
1738            initial_prompt,
1739            timeout,
1740            allowed_tools,
1741            cx,
1742        )
1743    }
1744}
1745
/// How waiting on a subagent's initial prompt ended.
#[derive(Debug, Clone, Copy)]
enum SubagentInitialPromptResult {
    /// The initial prompt's future resolved (whether it succeeded or not).
    Completed,
    /// The configured timeout fired before the initial prompt's future resolved.
    Timeout,
}
1751
/// Handle to a subagent created by `NativeThreadEnvironment::create_subagent_thread`.
pub struct NativeSubagentHandle {
    /// The subagent's session id (taken from the subagent thread's id).
    session_id: acp::SessionId,
    /// The parent thread, held weakly; used to unregister this subagent when it finishes.
    parent_thread: WeakEntity<Thread>,
    /// The subagent's own thread.
    subagent_thread: Entity<Thread>,
    /// The `AcpThread` registered for the subagent session; prompts are sent through it.
    acp_thread: Entity<AcpThread>,
    /// Resolves once the initial prompt completes or its timeout (if any) fires.
    wait_for_prompt_to_complete: Shared<Task<SubagentInitialPromptResult>>,
    /// Resolves once the subagent's `AcpThread` reports that the user stopped it.
    user_cancelled: Shared<Task<()>>,
}
1760
1761impl SubagentHandle for NativeSubagentHandle {
1762    fn id(&self) -> acp::SessionId {
1763        self.session_id.clone()
1764    }
1765
1766    fn wait_for_summary(&self, summary_prompt: String, cx: &AsyncApp) -> Task<Result<String>> {
1767        let thread = self.subagent_thread.clone();
1768        let acp_thread = self.acp_thread.clone();
1769        let wait_for_prompt = self.wait_for_prompt_to_complete.clone();
1770
1771        let wait_for_summary_task = cx.spawn(async move |cx| {
1772            let timed_out = match wait_for_prompt.await {
1773                SubagentInitialPromptResult::Completed => false,
1774                SubagentInitialPromptResult::Timeout => true,
1775            };
1776
1777            let summary_prompt = if timed_out {
1778                thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1779                format!("{}\n{}", "The time to complete the task was exceeded. Stop with the task and follow the directions below:", summary_prompt)
1780            } else {
1781                summary_prompt
1782            };
1783
1784            acp_thread
1785                .update(cx, |thread, cx| thread.send(vec![summary_prompt.into()], cx))
1786                .await?;
1787
1788            thread.read_with(cx, |thread, _cx| {
1789                thread
1790                    .last_message()
1791                    .map(|m| m.to_markdown())
1792                    .context("No response from subagent")
1793            })
1794        });
1795
1796        let user_cancelled = self.user_cancelled.clone();
1797        let thread = self.subagent_thread.clone();
1798        let subagent_session_id = self.session_id.clone();
1799        let parent_thread = self.parent_thread.clone();
1800        cx.spawn(async move |cx| {
1801            let result = futures::select! {
1802                result = wait_for_summary_task.fuse() => result,
1803                _ = user_cancelled.fuse() => {
1804                    thread.update(cx, |thread, cx| thread.cancel(cx).detach());
1805                    Err(anyhow!("User cancelled"))
1806                },
1807            };
1808            parent_thread
1809                .update(cx, |parent_thread, cx| {
1810                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1811                })
1812                .ok();
1813            result
1814        })
1815    }
1816}
1817
/// [`TerminalHandle`] backed by an `acp_thread::Terminal` entity.
pub struct AcpTerminalHandle {
    /// The underlying terminal entity all trait methods read from or update.
    terminal: Entity<acp_thread::Terminal>,
    /// Held only for its drop: when the handle is dropped, the matching receiver in
    /// `NativeThreadEnvironment::create_terminal` completes and the terminal is released
    /// from its `AcpThread`.
    _drop_tx: Option<oneshot::Sender<()>>,
}
1822
1823impl TerminalHandle for AcpTerminalHandle {
1824    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1825        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1826    }
1827
1828    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1829        Ok(self
1830            .terminal
1831            .read_with(cx, |term, _cx| term.wait_for_exit()))
1832    }
1833
1834    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1835        Ok(self
1836            .terminal
1837            .read_with(cx, |term, cx| term.current_output(cx)))
1838    }
1839
1840    fn kill(&self, cx: &AsyncApp) -> Result<()> {
1841        cx.update(|cx| {
1842            self.terminal.update(cx, |terminal, cx| {
1843                terminal.kill(cx);
1844            });
1845        });
1846        Ok(())
1847    }
1848
1849    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1850        Ok(self
1851            .terminal
1852            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1853    }
1854}
1855
1856#[cfg(test)]
1857mod internal_tests {
1858    use super::*;
1859    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1860    use fs::FakeFs;
1861    use gpui::TestAppContext;
1862    use indoc::formatdoc;
1863    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1864    use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1865    use serde_json::json;
1866    use settings::SettingsStore;
1867    use util::{path, rel_path::rel_path};
1868
1869    #[gpui::test]
1870    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1871        init_test(cx);
1872        let fs = FakeFs::new(cx.executor());
1873        fs.insert_tree(
1874            "/",
1875            json!({
1876                "a": {}
1877            }),
1878        )
1879        .await;
1880        let project = Project::test(fs.clone(), [], cx).await;
1881        let thread_store = cx.new(|cx| ThreadStore::new(cx));
1882        let agent = NativeAgent::new(
1883            project.clone(),
1884            thread_store,
1885            Templates::new(),
1886            None,
1887            fs.clone(),
1888            &mut cx.to_async(),
1889        )
1890        .await
1891        .unwrap();
1892        agent.read_with(cx, |agent, cx| {
1893            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1894        });
1895
1896        let worktree = project
1897            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1898            .await
1899            .unwrap();
1900        cx.run_until_parked();
1901        agent.read_with(cx, |agent, cx| {
1902            assert_eq!(
1903                agent.project_context.read(cx).worktrees,
1904                vec![WorktreeContext {
1905                    root_name: "a".into(),
1906                    abs_path: Path::new("/a").into(),
1907                    rules_file: None
1908                }]
1909            )
1910        });
1911
1912        // Creating `/a/.rules` updates the project context.
1913        fs.insert_file("/a/.rules", Vec::new()).await;
1914        cx.run_until_parked();
1915        agent.read_with(cx, |agent, cx| {
1916            let rules_entry = worktree
1917                .read(cx)
1918                .entry_for_path(rel_path(".rules"))
1919                .unwrap();
1920            assert_eq!(
1921                agent.project_context.read(cx).worktrees,
1922                vec![WorktreeContext {
1923                    root_name: "a".into(),
1924                    abs_path: Path::new("/a").into(),
1925                    rules_file: Some(RulesFileContext {
1926                        path_in_worktree: rel_path(".rules").into(),
1927                        text: "".into(),
1928                        project_entry_id: rules_entry.id.to_usize()
1929                    })
1930                }]
1931            )
1932        });
1933    }
1934
1935    #[gpui::test]
1936    async fn test_listing_models(cx: &mut TestAppContext) {
1937        init_test(cx);
1938        let fs = FakeFs::new(cx.executor());
1939        fs.insert_tree("/", json!({ "a": {}  })).await;
1940        let project = Project::test(fs.clone(), [], cx).await;
1941        let thread_store = cx.new(|cx| ThreadStore::new(cx));
1942        let connection = NativeAgentConnection(
1943            NativeAgent::new(
1944                project.clone(),
1945                thread_store,
1946                Templates::new(),
1947                None,
1948                fs.clone(),
1949                &mut cx.to_async(),
1950            )
1951            .await
1952            .unwrap(),
1953        );
1954
1955        // Create a thread/session
1956        let acp_thread = cx
1957            .update(|cx| {
1958                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
1959            })
1960            .await
1961            .unwrap();
1962
1963        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1964
1965        let models = cx
1966            .update(|cx| {
1967                connection
1968                    .model_selector(&session_id)
1969                    .unwrap()
1970                    .list_models(cx)
1971            })
1972            .await
1973            .unwrap();
1974
1975        let acp_thread::AgentModelList::Grouped(models) = models else {
1976            panic!("Unexpected model group");
1977        };
1978        assert_eq!(
1979            models,
1980            IndexMap::from_iter([(
1981                AgentModelGroupName("Fake".into()),
1982                vec![AgentModelInfo {
1983                    id: acp::ModelId::new("fake/fake"),
1984                    name: "Fake".into(),
1985                    description: None,
1986                    icon: Some(acp_thread::AgentModelIcon::Named(
1987                        ui::IconName::ZedAssistant
1988                    )),
1989                    is_latest: false,
1990                }]
1991            )])
1992        );
1993    }
1994
1995    #[gpui::test]
1996    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
1997        init_test(cx);
1998        let fs = FakeFs::new(cx.executor());
1999        fs.create_dir(paths::settings_file().parent().unwrap())
2000            .await
2001            .unwrap();
2002        fs.insert_file(
2003            paths::settings_file(),
2004            json!({
2005                "agent": {
2006                    "default_model": {
2007                        "provider": "foo",
2008                        "model": "bar"
2009                    }
2010                }
2011            })
2012            .to_string()
2013            .into_bytes(),
2014        )
2015        .await;
2016        let project = Project::test(fs.clone(), [], cx).await;
2017
2018        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2019
2020        // Create the agent and connection
2021        let agent = NativeAgent::new(
2022            project.clone(),
2023            thread_store,
2024            Templates::new(),
2025            None,
2026            fs.clone(),
2027            &mut cx.to_async(),
2028        )
2029        .await
2030        .unwrap();
2031        let connection = NativeAgentConnection(agent.clone());
2032
2033        // Create a thread/session
2034        let acp_thread = cx
2035            .update(|cx| {
2036                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2037            })
2038            .await
2039            .unwrap();
2040
2041        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2042
2043        // Select a model
2044        let selector = connection.model_selector(&session_id).unwrap();
2045        let model_id = acp::ModelId::new("fake/fake");
2046        cx.update(|cx| selector.select_model(model_id.clone(), cx))
2047            .await
2048            .unwrap();
2049
2050        // Verify the thread has the selected model
2051        agent.read_with(cx, |agent, _| {
2052            let session = agent.sessions.get(&session_id).unwrap();
2053            session.thread.read_with(cx, |thread, _| {
2054                assert_eq!(thread.model().unwrap().id().0, "fake");
2055            });
2056        });
2057
2058        cx.run_until_parked();
2059
2060        // Verify settings file was updated
2061        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2062        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2063
2064        // Check that the agent settings contain the selected model
2065        assert_eq!(
2066            settings_json["agent"]["default_model"]["model"],
2067            json!("fake")
2068        );
2069        assert_eq!(
2070            settings_json["agent"]["default_model"]["provider"],
2071            json!("fake")
2072        );
2073
2074        // Register a thinking model and select it.
2075        cx.update(|cx| {
2076            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2077                "fake-corp",
2078                "fake-thinking",
2079                "Fake Thinking",
2080                true,
2081            ));
2082            let thinking_provider = Arc::new(
2083                FakeLanguageModelProvider::new(
2084                    LanguageModelProviderId::from("fake-corp".to_string()),
2085                    LanguageModelProviderName::from("Fake Corp".to_string()),
2086                )
2087                .with_models(vec![thinking_model]),
2088            );
2089            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2090                registry.register_provider(thinking_provider, cx);
2091            });
2092        });
2093        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2094
2095        let selector = connection.model_selector(&session_id).unwrap();
2096        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2097            .await
2098            .unwrap();
2099        cx.run_until_parked();
2100
2101        // Verify enable_thinking was written to settings as true.
2102        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2103        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2104        assert_eq!(
2105            settings_json["agent"]["default_model"]["enable_thinking"],
2106            json!(true),
2107            "selecting a thinking model should persist enable_thinking: true to settings"
2108        );
2109    }
2110
2111    #[gpui::test]
2112    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2113        init_test(cx);
2114        let fs = FakeFs::new(cx.executor());
2115        fs.create_dir(paths::settings_file().parent().unwrap())
2116            .await
2117            .unwrap();
2118        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2119        let project = Project::test(fs.clone(), [], cx).await;
2120
2121        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2122        let agent = NativeAgent::new(
2123            project.clone(),
2124            thread_store,
2125            Templates::new(),
2126            None,
2127            fs.clone(),
2128            &mut cx.to_async(),
2129        )
2130        .await
2131        .unwrap();
2132        let connection = NativeAgentConnection(agent.clone());
2133
2134        let acp_thread = cx
2135            .update(|cx| {
2136                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2137            })
2138            .await
2139            .unwrap();
2140        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2141
2142        // Register a second provider with a thinking model.
2143        cx.update(|cx| {
2144            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2145                "fake-corp",
2146                "fake-thinking",
2147                "Fake Thinking",
2148                true,
2149            ));
2150            let thinking_provider = Arc::new(
2151                FakeLanguageModelProvider::new(
2152                    LanguageModelProviderId::from("fake-corp".to_string()),
2153                    LanguageModelProviderName::from("Fake Corp".to_string()),
2154                )
2155                .with_models(vec![thinking_model]),
2156            );
2157            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2158                registry.register_provider(thinking_provider, cx);
2159            });
2160        });
2161        // Refresh the agent's model list so it picks up the new provider.
2162        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2163
2164        // Thread starts with thinking_enabled = false (the default).
2165        agent.read_with(cx, |agent, _| {
2166            let session = agent.sessions.get(&session_id).unwrap();
2167            session.thread.read_with(cx, |thread, _| {
2168                assert!(!thread.thinking_enabled(), "thinking defaults to false");
2169            });
2170        });
2171
2172        // Select the thinking model via select_model.
2173        let selector = connection.model_selector(&session_id).unwrap();
2174        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2175            .await
2176            .unwrap();
2177
2178        // select_model should have enabled thinking based on the model's supports_thinking().
2179        agent.read_with(cx, |agent, _| {
2180            let session = agent.sessions.get(&session_id).unwrap();
2181            session.thread.read_with(cx, |thread, _| {
2182                assert!(
2183                    thread.thinking_enabled(),
2184                    "select_model should enable thinking when model supports it"
2185                );
2186            });
2187        });
2188
2189        // Switch back to the non-thinking model.
2190        let selector = connection.model_selector(&session_id).unwrap();
2191        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2192            .await
2193            .unwrap();
2194
2195        // select_model should have disabled thinking.
2196        agent.read_with(cx, |agent, _| {
2197            let session = agent.sessions.get(&session_id).unwrap();
2198            session.thread.read_with(cx, |thread, _| {
2199                assert!(
2200                    !thread.thinking_enabled(),
2201                    "select_model should disable thinking when model does not support it"
2202                );
2203            });
2204        });
2205    }
2206
2207    #[gpui::test]
2208    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
2209        init_test(cx);
2210        let fs = FakeFs::new(cx.executor());
2211        fs.insert_tree("/", json!({ "a": {} })).await;
2212        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2213        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2214        let agent = NativeAgent::new(
2215            project.clone(),
2216            thread_store.clone(),
2217            Templates::new(),
2218            None,
2219            fs.clone(),
2220            &mut cx.to_async(),
2221        )
2222        .await
2223        .unwrap();
2224        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2225
2226        // Register a thinking model.
2227        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2228            "fake-corp",
2229            "fake-thinking",
2230            "Fake Thinking",
2231            true,
2232        ));
2233        let thinking_provider = Arc::new(
2234            FakeLanguageModelProvider::new(
2235                LanguageModelProviderId::from("fake-corp".to_string()),
2236                LanguageModelProviderName::from("Fake Corp".to_string()),
2237            )
2238            .with_models(vec![thinking_model.clone()]),
2239        );
2240        cx.update(|cx| {
2241            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2242                registry.register_provider(thinking_provider, cx);
2243            });
2244        });
2245        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2246
2247        // Create a thread and select the thinking model.
2248        let acp_thread = cx
2249            .update(|cx| {
2250                connection
2251                    .clone()
2252                    .new_session(project.clone(), Path::new("/a"), cx)
2253            })
2254            .await
2255            .unwrap();
2256        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2257
2258        let selector = connection.model_selector(&session_id).unwrap();
2259        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2260            .await
2261            .unwrap();
2262
2263        // Verify thinking is enabled after selecting the thinking model.
2264        let thread = agent.read_with(cx, |agent, _| {
2265            agent.sessions.get(&session_id).unwrap().thread.clone()
2266        });
2267        thread.read_with(cx, |thread, _| {
2268            assert!(
2269                thread.thinking_enabled(),
2270                "thinking should be enabled after selecting thinking model"
2271            );
2272        });
2273
2274        // Send a message so the thread gets persisted.
2275        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2276        let send = cx.foreground_executor().spawn(send);
2277        cx.run_until_parked();
2278
2279        thinking_model.send_last_completion_stream_text_chunk("Response.");
2280        thinking_model.end_last_completion_stream();
2281
2282        send.await.unwrap();
2283        cx.run_until_parked();
2284
2285        // Close the session so it can be reloaded from disk.
2286        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2287            .await
2288            .unwrap();
2289        drop(thread);
2290        drop(acp_thread);
2291        agent.read_with(cx, |agent, _| {
2292            assert!(agent.sessions.is_empty());
2293        });
2294
2295        // Reload the thread and verify thinking_enabled is still true.
2296        let reloaded_acp_thread = agent
2297            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2298            .await
2299            .unwrap();
2300        let reloaded_thread = agent.read_with(cx, |agent, _| {
2301            agent.sessions.get(&session_id).unwrap().thread.clone()
2302        });
2303        reloaded_thread.read_with(cx, |thread, _| {
2304            assert!(
2305                thread.thinking_enabled(),
2306                "thinking_enabled should be preserved when reloading a thread with a thinking model"
2307            );
2308        });
2309
2310        drop(reloaded_acp_thread);
2311    }
2312
2313    #[gpui::test]
2314    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
2315        init_test(cx);
2316        let fs = FakeFs::new(cx.executor());
2317        fs.insert_tree("/", json!({ "a": {} })).await;
2318        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2319        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2320        let agent = NativeAgent::new(
2321            project.clone(),
2322            thread_store.clone(),
2323            Templates::new(),
2324            None,
2325            fs.clone(),
2326            &mut cx.to_async(),
2327        )
2328        .await
2329        .unwrap();
2330        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2331
2332        // Register a model where id() != name(), like real Anthropic models
2333        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
2334        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2335            "fake-corp",
2336            "custom-model-id",
2337            "Custom Model Display Name",
2338            false,
2339        ));
2340        let provider = Arc::new(
2341            FakeLanguageModelProvider::new(
2342                LanguageModelProviderId::from("fake-corp".to_string()),
2343                LanguageModelProviderName::from("Fake Corp".to_string()),
2344            )
2345            .with_models(vec![model.clone()]),
2346        );
2347        cx.update(|cx| {
2348            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2349                registry.register_provider(provider, cx);
2350            });
2351        });
2352        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2353
2354        // Create a thread and select the model.
2355        let acp_thread = cx
2356            .update(|cx| {
2357                connection
2358                    .clone()
2359                    .new_session(project.clone(), Path::new("/a"), cx)
2360            })
2361            .await
2362            .unwrap();
2363        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2364
2365        let selector = connection.model_selector(&session_id).unwrap();
2366        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
2367            .await
2368            .unwrap();
2369
2370        let thread = agent.read_with(cx, |agent, _| {
2371            agent.sessions.get(&session_id).unwrap().thread.clone()
2372        });
2373        thread.read_with(cx, |thread, _| {
2374            assert_eq!(
2375                thread.model().unwrap().id().0.as_ref(),
2376                "custom-model-id",
2377                "model should be set before persisting"
2378            );
2379        });
2380
2381        // Send a message so the thread gets persisted.
2382        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2383        let send = cx.foreground_executor().spawn(send);
2384        cx.run_until_parked();
2385
2386        model.send_last_completion_stream_text_chunk("Response.");
2387        model.end_last_completion_stream();
2388
2389        send.await.unwrap();
2390        cx.run_until_parked();
2391
2392        // Close the session so it can be reloaded from disk.
2393        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2394            .await
2395            .unwrap();
2396        drop(thread);
2397        drop(acp_thread);
2398        agent.read_with(cx, |agent, _| {
2399            assert!(agent.sessions.is_empty());
2400        });
2401
2402        // Reload the thread and verify the model was preserved.
2403        let reloaded_acp_thread = agent
2404            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2405            .await
2406            .unwrap();
2407        let reloaded_thread = agent.read_with(cx, |agent, _| {
2408            agent.sessions.get(&session_id).unwrap().thread.clone()
2409        });
2410        reloaded_thread.read_with(cx, |thread, _| {
2411            let reloaded_model = thread
2412                .model()
2413                .expect("model should be present after reload");
2414            assert_eq!(
2415                reloaded_model.id().0.as_ref(),
2416                "custom-model-id",
2417                "reloaded thread should have the same model, not fall back to the default"
2418            );
2419        });
2420
2421        drop(reloaded_acp_thread);
2422    }
2423
    /// Round-trips a native agent thread through the `ThreadStore`:
    /// - an empty thread is not persisted, even after its models are changed;
    /// - after one user/assistant exchange, the thread is persisted under the title
    ///   produced by the summarization model;
    /// - once the session is closed, reopening it by session id yields the same
    ///   markdown transcript.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        // Fixture: a single file `/a/b.md` that the user message below mentions.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Open a fresh session and grab both views of it: the ACP-facing thread and
        // the agent's native thread entity stored in `agent.sessions`.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // User message with a resource-link mention of `/a/b.md` between two text chunks.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Spawn the send future so it keeps making progress while the test drives the
        // fake models below; it is awaited only after both streams have been ended.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Answer the main model's pending completion, then let the thread react
        // (presumably this is what issues the summarization request answered next).
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        // The summarization model's output is what the title assertion below expects.
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        // The mention is expected to render as `[@b.md](<file uri>)` in the transcript.
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the closed session's entities.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        // Reopen by session id; the reloaded transcript must match the one asserted
        // before the session was closed.
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
2562
2563    fn thread_entries(
2564        thread_store: &Entity<ThreadStore>,
2565        cx: &mut TestAppContext,
2566    ) -> Vec<(acp::SessionId, String)> {
2567        thread_store.read_with(cx, |store, _| {
2568            store
2569                .entries()
2570                .map(|entry| (entry.id.clone(), entry.title.to_string()))
2571                .collect::<Vec<_>>()
2572        })
2573    }
2574
2575    fn init_test(cx: &mut TestAppContext) {
2576        env_logger::try_init().ok();
2577        cx.update(|cx| {
2578            let settings_store = SettingsStore::test(cx);
2579            cx.set_global(settings_store);
2580
2581            LanguageModelRegistry::test(cx);
2582        });
2583    }
2584}
2585
2586fn mcp_message_content_to_acp_content_block(
2587    content: context_server::types::MessageContent,
2588) -> acp::ContentBlock {
2589    match content {
2590        context_server::types::MessageContent::Text {
2591            text,
2592            annotations: _,
2593        } => text.into(),
2594        context_server::types::MessageContent::Image {
2595            data,
2596            mime_type,
2597            annotations: _,
2598        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2599        context_server::types::MessageContent::Audio {
2600            data,
2601            mime_type,
2602            annotations: _,
2603        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2604        context_server::types::MessageContent::Resource {
2605            resource,
2606            annotations: _,
2607        } => {
2608            let mut link =
2609                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2610            if let Some(mime_type) = resource.mime_type {
2611                link = link.mime_type(mime_type);
2612            }
2613            acp::ContentBlock::ResourceLink(link)
2614        }
2615    }
2616}