agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17use itertools::Itertools;
  18pub use native_agent_server::NativeAgentServer;
  19pub use pattern_extraction::*;
  20pub use shell_command_parser::extract_commands;
  21pub use templates::*;
  22pub use thread::*;
  23pub use thread_store::*;
  24pub use tool_permissions::*;
  25pub use tools::*;
  26
  27use acp_thread::{
  28    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  29    AgentSessionListResponse, TokenUsageRatio, UserMessageId,
  30};
  31use agent_client_protocol as acp;
  32use anyhow::{Context as _, Result, anyhow};
  33use chrono::{DateTime, Utc};
  34use collections::{HashMap, HashSet, IndexMap};
  35use fs::Fs;
  36use futures::channel::{mpsc, oneshot};
  37use futures::future::Shared;
  38use futures::{FutureExt as _, StreamExt as _, future};
  39use gpui::{
  40    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  41};
  42use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  43use project::{Project, ProjectItem, ProjectPath, Worktree};
  44use prompt_store::{
  45    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  46    WorktreeContext,
  47};
  48use serde::{Deserialize, Serialize};
  49use settings::{LanguageModelSelection, update_settings_file};
  50use std::any::Any;
  51use std::path::{Path, PathBuf};
  52use std::rc::Rc;
  53use std::sync::Arc;
  54use util::ResultExt;
  55use util::path_list::PathList;
  56use util::rel_path::RelPath;
  57
/// A capture of the project's worktrees together with the time it was taken.
///
/// Derives `Serialize`/`Deserialize`; fields serialize in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One snapshot per worktree, using the project's telemetry snapshot type.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// When the snapshot was taken (UTC).
    pub timestamp: DateTime<Utc>,
}
  63
/// Describes a failure to load a worktree's rules file.
///
/// Constructed by `NativeAgent::load_worktree_info_for_system_prompt` from the error of the
/// rules-file loading task.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the underlying error's `Display` output).
    pub message: SharedString,
}
  67
/// Holds both the internal `Thread` and the `AcpThread` for a session.
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Save task for `thread`. `register_session` initializes it to an already-completed
    /// `Task::ready(())`; presumably `save_thread` (not shown here) replaces it with the
    /// in-flight save — TODO confirm.
    pending_save: Task<()>,
    /// Subscriptions to `thread` (title updates, token-usage updates, change observation),
    /// stored so they live as long as the session entry.
    _subscriptions: Vec<Subscription>,
}
  77
/// Cache of the language models offered by visible, authenticated providers, plus the
/// grouped list shown to ACP clients.
pub struct LanguageModels {
    /// Access language model by ID (`"<provider_id>/<model_id>"`, built by `model_id`).
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out (cloned) by `watch()`.
    refresh_models_rx: watch::Receiver<()>,
    /// Sent on at the end of every `refresh_list` call.
    refresh_models_tx: watch::Sender<()>,
    /// Background task that authenticates every visible provider; stored so the task lives
    /// as long as this cache.
    _authenticate_all_providers_task: Task<()>,
}
  87
  88impl LanguageModels {
  89    fn new(cx: &mut App) -> Self {
  90        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  91
  92        let mut this = Self {
  93            models: HashMap::default(),
  94            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  95            refresh_models_rx,
  96            refresh_models_tx,
  97            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  98        };
  99        this.refresh_list(cx);
 100        this
 101    }
 102
 103    fn refresh_list(&mut self, cx: &App) {
 104        let providers = LanguageModelRegistry::global(cx)
 105            .read(cx)
 106            .visible_providers()
 107            .into_iter()
 108            .filter(|provider| provider.is_authenticated(cx))
 109            .collect::<Vec<_>>();
 110
 111        let mut language_model_list = IndexMap::default();
 112        let mut recommended_models = HashSet::default();
 113
 114        let mut recommended = Vec::new();
 115        for provider in &providers {
 116            for model in provider.recommended_models(cx) {
 117                recommended_models.insert((model.provider_id(), model.id()));
 118                recommended.push(Self::map_language_model_to_info(&model, provider));
 119            }
 120        }
 121        if !recommended.is_empty() {
 122            language_model_list.insert(
 123                acp_thread::AgentModelGroupName("Recommended".into()),
 124                recommended,
 125            );
 126        }
 127
 128        let mut models = HashMap::default();
 129        for provider in providers {
 130            let mut provider_models = Vec::new();
 131            for model in provider.provided_models(cx) {
 132                let model_info = Self::map_language_model_to_info(&model, &provider);
 133                let model_id = model_info.id.clone();
 134                provider_models.push(model_info);
 135                models.insert(model_id, model);
 136            }
 137            if !provider_models.is_empty() {
 138                language_model_list.insert(
 139                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 140                    provider_models,
 141                );
 142            }
 143        }
 144
 145        self.models = models;
 146        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 147        self.refresh_models_tx.send(()).ok();
 148    }
 149
 150    fn watch(&self) -> watch::Receiver<()> {
 151        self.refresh_models_rx.clone()
 152    }
 153
 154    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 155        self.models.get(model_id).cloned()
 156    }
 157
 158    fn map_language_model_to_info(
 159        model: &Arc<dyn LanguageModel>,
 160        provider: &Arc<dyn LanguageModelProvider>,
 161    ) -> acp_thread::AgentModelInfo {
 162        acp_thread::AgentModelInfo {
 163            id: Self::model_id(model),
 164            name: model.name().0,
 165            description: None,
 166            icon: Some(match provider.icon() {
 167                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 168                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 169            }),
 170            is_latest: model.is_latest(),
 171            cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
 172        }
 173    }
 174
 175    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 176        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 177    }
 178
 179    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 180        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 181            .read(cx)
 182            .visible_providers()
 183            .iter()
 184            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 185            .collect::<Vec<_>>();
 186
 187        cx.background_spawn(async move {
 188            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 189                if let Err(err) = authenticate_task.await {
 190                    match err {
 191                        language_model::AuthenticateError::CredentialsNotFound => {
 192                            // Since we're authenticating these providers in the
 193                            // background for the purposes of populating the
 194                            // language selector, we don't care about providers
 195                            // where the credentials are not found.
 196                        }
 197                        language_model::AuthenticateError::ConnectionRefused => {
 198                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 199                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 200                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 201                        }
 202                        _ => {
 203                            // Some providers have noisy failure states that we
 204                            // don't want to spam the logs with every time the
 205                            // language model selector is initialized.
 206                            //
 207                            // Ideally these should have more clear failure modes
 208                            // that we know are safe to ignore here, like what we do
 209                            // with `CredentialsNotFound` above.
 210                            match provider_id.0.as_ref() {
 211                                "lmstudio" | "ollama" => {
 212                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 213                                    //
 214                                    // These fail noisily, so we don't log them.
 215                                }
 216                                "copilot_chat" => {
 217                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 218                                }
 219                                _ => {
 220                                    log::error!(
 221                                        "Failed to authenticate provider: {}: {err:#}",
 222                                        provider_name.0
 223                                    );
 224                                }
 225                            }
 226                        }
 227                    }
 228                }
 229            }
 230        })
 231    }
 232}
 233
/// The native agent: owns every open session and the state those sessions share.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store passed in at construction.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalled when the project context should be rebuilt (worktrees added/removed,
    /// rules-file changes, prompt-store updates); consumed by `maintain_project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background loop that rebuilds the project context on each refresh signal.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context-server tools and prompts.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent works on.
    project: Entity<Project>,
    /// Optional prompt store; supplies the default user rules for the project context.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle passed in at construction.
    fs: Arc<dyn Fs>,
    /// Subscriptions to the project, language-model registry, context server store and
    /// registry, and (when present) the prompt store.
    _subscriptions: Vec<Subscription>,
}
 252
 253impl NativeAgent {
    /// Creates the native agent entity for `project`.
    ///
    /// The initial `ProjectContext` (visible worktrees, their rules files, and the prompt
    /// store's default user rules) is built before the entity exists. The entity then
    /// subscribes to project, language-model registry, context server store/registry and
    /// (if present) prompt store events, and spawns `maintain_project_context`, which
    /// rebuilds the context whenever `project_context_needs_refresh` is signalled.
    ///
    /// Always returns `Ok` in the current implementation.
    pub async fn new(
        project: Entity<Project>,
        thread_store: Entity<ThreadStore>,
        templates: Arc<Templates>,
        prompt_store: Option<Entity<PromptStore>>,
        fs: Arc<dyn Fs>,
        cx: &mut AsyncApp,
    ) -> Result<Entity<NativeAgent>> {
        log::debug!("Creating new NativeAgent");

        // Build the initial context up front; later rebuilds go through
        // `maintain_project_context`.
        let project_context = cx
            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
            .await;

        Ok(cx.new(|cx| {
            let context_server_store = project.read(cx).context_server_store();
            let context_server_registry =
                cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));

            // Stored in `_subscriptions`, so they stay active for the agent's lifetime.
            let mut subscriptions = vec![
                cx.subscribe(&project, Self::handle_project_event),
                cx.subscribe(
                    &LanguageModelRegistry::global(cx),
                    Self::handle_models_updated_event,
                ),
                cx.subscribe(
                    &context_server_store,
                    Self::handle_context_server_store_updated,
                ),
                cx.subscribe(
                    &context_server_registry,
                    Self::handle_context_server_registry_event,
                ),
            ];
            if let Some(prompt_store) = prompt_store.as_ref() {
                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
            }

            // Refresh requests flow through this channel to `maintain_project_context`.
            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
                watch::channel(());
            Self {
                sessions: HashMap::default(),
                thread_store,
                project_context: cx.new(|_| project_context),
                project_context_needs_refresh: project_context_needs_refresh_tx,
                _maintain_project_context: cx.spawn(async move |this, cx| {
                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
                }),
                context_server_registry,
                templates,
                models: LanguageModels::new(cx),
                project,
                prompt_store,
                fs,
                _subscriptions: subscriptions,
            }
        }))
    }
 312
 313    fn new_session(
 314        &mut self,
 315        project: Entity<Project>,
 316        cx: &mut Context<Self>,
 317    ) -> Entity<AcpThread> {
 318        // Create Thread
 319        // Fetch default model from registry settings
 320        let registry = LanguageModelRegistry::read_global(cx);
 321        // Log available models for debugging
 322        let available_count = registry.available_models(cx).count();
 323        log::debug!("Total available models: {}", available_count);
 324
 325        let default_model = registry.default_model().and_then(|default_model| {
 326            self.models
 327                .model_from_id(&LanguageModels::model_id(&default_model.model))
 328        });
 329        let thread = cx.new(|cx| {
 330            Thread::new(
 331                project.clone(),
 332                self.project_context.clone(),
 333                self.context_server_registry.clone(),
 334                self.templates.clone(),
 335                default_model,
 336                cx,
 337            )
 338        });
 339
 340        self.register_session(thread, cx)
 341    }
 342
    /// Wraps `thread_handle` in a new `AcpThread`, records both as a session keyed by the
    /// thread's ID, and returns the `AcpThread`.
    ///
    /// Along the way it sets the registry's thread-summary model on the thread, installs
    /// the default tools (backed by a `NativeThreadEnvironment` holding weak handles to the
    /// ACP thread, the thread, and this agent), subscribes to the thread's title and
    /// token-usage events and change notifications, and rebroadcasts slash commands.
    ///
    /// NOTE(review): `self.sessions.insert` silently replaces any existing session with the
    /// same ID, dropping its subscriptions.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy what the AcpThread needs out of the thread while it is borrowed from `cx`;
        // `cx.new` below needs `cx` mutably.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            acp_thread::AcpThread::new(
                parent_session_id,
                title,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            )
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        // Mirror title and token-usage changes onto the ACP thread, and call `save_thread`
        // whenever the thread notifies of a change.
        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                _subscriptions: subscriptions,
                pending_save: Task::ready(()),
            },
        );

        // Give the new session the current slash-command list (this resends it to every
        // existing session as well).
        self.update_available_commands(cx);

        acp_thread
    }
 409
    /// Returns the agent's cache of available language models.
    pub fn models(&self) -> &LanguageModels {
        &self.models
    }
 413
 414    async fn maintain_project_context(
 415        this: WeakEntity<Self>,
 416        mut needs_refresh: watch::Receiver<()>,
 417        cx: &mut AsyncApp,
 418    ) -> Result<()> {
 419        while needs_refresh.changed().await.is_ok() {
 420            let project_context = this
 421                .update(cx, |this, cx| {
 422                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 423                })?
 424                .await;
 425            this.update(cx, |this, cx| {
 426                this.project_context = cx.new(|_| project_context);
 427            })?;
 428        }
 429
 430        Ok(())
 431    }
 432
 433    fn build_project_context(
 434        project: &Entity<Project>,
 435        prompt_store: Option<&Entity<PromptStore>>,
 436        cx: &mut App,
 437    ) -> Task<ProjectContext> {
 438        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 439        let worktree_tasks = worktrees
 440            .into_iter()
 441            .map(|worktree| {
 442                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 443            })
 444            .collect::<Vec<_>>();
 445        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 446            prompt_store.read_with(cx, |prompt_store, cx| {
 447                let prompts = prompt_store.default_prompt_metadata();
 448                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 449                    let contents = prompt_store.load(prompt_metadata.id, cx);
 450                    async move { (contents.await, prompt_metadata) }
 451                });
 452                cx.background_spawn(future::join_all(load_tasks))
 453            })
 454        } else {
 455            Task::ready(vec![])
 456        };
 457
 458        cx.spawn(async move |_cx| {
 459            let (worktrees, default_user_rules) =
 460                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 461
 462            let worktrees = worktrees
 463                .into_iter()
 464                .map(|(worktree, _rules_error)| {
 465                    // TODO: show error message
 466                    // if let Some(rules_error) = rules_error {
 467                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 468                    // }
 469                    worktree
 470                })
 471                .collect::<Vec<_>>();
 472
 473            let default_user_rules = default_user_rules
 474                .into_iter()
 475                .flat_map(|(contents, prompt_metadata)| match contents {
 476                    Ok(contents) => Some(UserRulesContext {
 477                        uuid: prompt_metadata.id.as_user()?,
 478                        title: prompt_metadata.title.map(|title| title.to_string()),
 479                        contents,
 480                    }),
 481                    Err(_err) => {
 482                        // TODO: show error message
 483                        // this.update(cx, |_, cx| {
 484                        //     cx.emit(RulesLoadingError {
 485                        //         message: format!("{err:?}").into(),
 486                        //     });
 487                        // })
 488                        // .ok();
 489                        None
 490                    }
 491                })
 492                .collect::<Vec<_>>();
 493
 494            ProjectContext::new(worktrees, default_user_rules)
 495        })
 496    }
 497
 498    fn load_worktree_info_for_system_prompt(
 499        worktree: Entity<Worktree>,
 500        project: Entity<Project>,
 501        cx: &mut App,
 502    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 503        let tree = worktree.read(cx);
 504        let root_name = tree.root_name_str().into();
 505        let abs_path = tree.abs_path();
 506
 507        let mut context = WorktreeContext {
 508            root_name,
 509            abs_path,
 510            rules_file: None,
 511        };
 512
 513        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 514        let Some(rules_task) = rules_task else {
 515            return Task::ready((context, None));
 516        };
 517
 518        cx.spawn(async move |_| {
 519            let (rules_file, rules_file_error) = match rules_task.await {
 520                Ok(rules_file) => (Some(rules_file), None),
 521                Err(err) => (
 522                    None,
 523                    Some(RulesLoadingError {
 524                        message: format!("{err}").into(),
 525                    }),
 526                ),
 527            };
 528            context.rules_file = rules_file;
 529            (context, rules_file_error)
 530        })
 531    }
 532
 533    fn load_worktree_rules_file(
 534        worktree: Entity<Worktree>,
 535        project: Entity<Project>,
 536        cx: &mut App,
 537    ) -> Option<Task<Result<RulesFileContext>>> {
 538        let worktree = worktree.read(cx);
 539        let worktree_id = worktree.id();
 540        let selected_rules_file = RULES_FILE_NAMES
 541            .into_iter()
 542            .filter_map(|name| {
 543                worktree
 544                    .entry_for_path(RelPath::unix(name).unwrap())
 545                    .filter(|entry| entry.is_file())
 546                    .map(|entry| entry.path.clone())
 547            })
 548            .next();
 549
 550        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 551        // supported. This doesn't seem to occur often in GitHub repositories.
 552        selected_rules_file.map(|path_in_worktree| {
 553            let project_path = ProjectPath {
 554                worktree_id,
 555                path: path_in_worktree.clone(),
 556            };
 557            let buffer_task =
 558                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 559            let rope_task = cx.spawn(async move |cx| {
 560                let buffer = buffer_task.await?;
 561                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 562                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 563                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 564                })?;
 565                anyhow::Ok((project_entry_id, rope))
 566            });
 567            // Build a string from the rope on a background thread.
 568            cx.background_spawn(async move {
 569                let (project_entry_id, rope) = rope_task.await?;
 570                anyhow::Ok(RulesFileContext {
 571                    path_in_worktree,
 572                    text: rope.to_string().trim().to_string(),
 573                    project_entry_id: project_entry_id.to_usize(),
 574                })
 575            })
 576        })
 577    }
 578
 579    fn handle_thread_title_updated(
 580        &mut self,
 581        thread: Entity<Thread>,
 582        _: &TitleUpdated,
 583        cx: &mut Context<Self>,
 584    ) {
 585        let session_id = thread.read(cx).id();
 586        let Some(session) = self.sessions.get(session_id) else {
 587            return;
 588        };
 589        let thread = thread.downgrade();
 590        let acp_thread = session.acp_thread.downgrade();
 591        cx.spawn(async move |_, cx| {
 592            let title = thread.read_with(cx, |thread, _| thread.title())?;
 593            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 594            task.await
 595        })
 596        .detach_and_log_err(cx);
 597    }
 598
 599    fn handle_thread_token_usage_updated(
 600        &mut self,
 601        thread: Entity<Thread>,
 602        usage: &TokenUsageUpdated,
 603        cx: &mut Context<Self>,
 604    ) {
 605        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 606            return;
 607        };
 608        session.acp_thread.update(cx, |acp_thread, cx| {
 609            acp_thread.update_token_usage(usage.0.clone(), cx);
 610        });
 611    }
 612
 613    fn handle_project_event(
 614        &mut self,
 615        _project: Entity<Project>,
 616        event: &project::Event,
 617        _cx: &mut Context<Self>,
 618    ) {
 619        match event {
 620            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 621                self.project_context_needs_refresh.send(()).ok();
 622            }
 623            project::Event::WorktreeUpdatedEntries(_, items) => {
 624                if items.iter().any(|(path, _, _)| {
 625                    RULES_FILE_NAMES
 626                        .iter()
 627                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 628                }) {
 629                    self.project_context_needs_refresh.send(()).ok();
 630                }
 631            }
 632            _ => {}
 633        }
 634    }
 635
 636    fn handle_prompts_updated_event(
 637        &mut self,
 638        _prompt_store: Entity<PromptStore>,
 639        _event: &prompt_store::PromptsUpdatedEvent,
 640        _cx: &mut Context<Self>,
 641    ) {
 642        self.project_context_needs_refresh.send(()).ok();
 643    }
 644
 645    fn handle_models_updated_event(
 646        &mut self,
 647        _registry: Entity<LanguageModelRegistry>,
 648        _event: &language_model::Event,
 649        cx: &mut Context<Self>,
 650    ) {
 651        self.models.refresh_list(cx);
 652
 653        let registry = LanguageModelRegistry::read_global(cx);
 654        let default_model = registry.default_model().map(|m| m.model);
 655        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 656
 657        for session in self.sessions.values_mut() {
 658            session.thread.update(cx, |thread, cx| {
 659                if thread.model().is_none()
 660                    && let Some(model) = default_model.clone()
 661                {
 662                    thread.set_model(model, cx);
 663                    cx.notify();
 664                }
 665                thread.set_summarization_model(summarization_model.clone(), cx);
 666            });
 667        }
 668    }
 669
    /// Rebroadcasts the available slash commands to every session whenever a context
    /// server's status changes.
    fn handle_context_server_store_updated(
        &mut self,
        _store: Entity<project::context_server_store::ContextServerStore>,
        _event: &project::context_server_store::ServerStatusChangedEvent,
        cx: &mut Context<Self>,
    ) {
        self.update_available_commands(cx);
    }
 678
    /// Rebroadcasts slash commands when the registry's prompts change; tool changes need no
    /// handling here.
    fn handle_context_server_registry_event(
        &mut self,
        _registry: Entity<ContextServerRegistry>,
        event: &ContextServerRegistryEvent,
        cx: &mut Context<Self>,
    ) {
        match event {
            // Slash commands are built from prompts only (see `build_available_commands`).
            ContextServerRegistryEvent::ToolsChanged => {}
            ContextServerRegistryEvent::PromptsChanged => {
                self.update_available_commands(cx);
            }
        }
    }
 692
 693    fn update_available_commands(&self, cx: &mut Context<Self>) {
 694        let available_commands = self.build_available_commands(cx);
 695        for session in self.sessions.values() {
 696            session.acp_thread.update(cx, |thread, cx| {
 697                thread
 698                    .handle_session_update(
 699                        acp::SessionUpdate::AvailableCommandsUpdate(
 700                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 701                        ),
 702                        cx,
 703                    )
 704                    .log_err();
 705            });
 706        }
 707    }
 708
 709    fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
 710        let registry = self.context_server_registry.read(cx);
 711
 712        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 713        for context_server_prompt in registry.prompts() {
 714            *prompt_name_counts
 715                .entry(context_server_prompt.prompt.name.as_str())
 716                .or_insert(0) += 1;
 717        }
 718
 719        registry
 720            .prompts()
 721            .flat_map(|context_server_prompt| {
 722                let prompt = &context_server_prompt.prompt;
 723
 724                let should_prefix = prompt_name_counts
 725                    .get(prompt.name.as_str())
 726                    .copied()
 727                    .unwrap_or(0)
 728                    > 1;
 729
 730                let name = if should_prefix {
 731                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 732                } else {
 733                    prompt.name.clone()
 734                };
 735
 736                let mut command = acp::AvailableCommand::new(
 737                    name,
 738                    prompt.description.clone().unwrap_or_default(),
 739                );
 740
 741                match prompt.arguments.as_deref() {
 742                    Some([arg]) => {
 743                        let hint = format!("<{}>", arg.name);
 744
 745                        command = command.input(acp::AvailableCommandInput::Unstructured(
 746                            acp::UnstructuredCommandInput::new(hint),
 747                        ));
 748                    }
 749                    Some([]) | None => {}
 750                    Some(_) => {
 751                        // skip >1 argument commands since we don't support them yet
 752                        return None;
 753                    }
 754                }
 755
 756                Some(command)
 757            })
 758            .collect()
 759    }
 760
    /// Loads the thread `id` from the threads database and builds a `Thread` entity for it,
    /// configured with the registry's current thread-summary model.
    ///
    /// The thread is *not* registered as a session; `open_thread` does that.
    ///
    /// # Errors
    /// Fails if connecting to the database fails, the lookup fails, no thread with `id` is
    /// stored, or the agent has been released before the thread could be built.
    pub fn load_thread(
        &mut self,
        id: acp::SessionId,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Thread>>> {
        let database_future = ThreadsDatabase::connect(cx);
        cx.spawn(async move |this, cx| {
            // NOTE(review): the connection error is re-wrapped with `anyhow!`, presumably
            // because the connection future's error type is not `anyhow::Error` itself.
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            let db_thread = database
                .load_thread(id.clone())
                .await?
                .with_context(|| format!("no thread found with ID: {id:?}"))?;

            this.update(cx, |this, cx| {
                let summarization_model = LanguageModelRegistry::read_global(cx)
                    .thread_summary_model()
                    .map(|c| c.model);

                cx.new(|cx| {
                    let mut thread = Thread::from_db(
                        id.clone(),
                        db_thread,
                        this.project.clone(),
                        this.project_context.clone(),
                        this.context_server_registry.clone(),
                        this.templates.clone(),
                        cx,
                    );
                    thread.set_summarization_model(summarization_model, cx);
                    thread
                })
            })
        })
    }
 795
 796    pub fn open_thread(
 797        &mut self,
 798        id: acp::SessionId,
 799        cx: &mut Context<Self>,
 800    ) -> Task<Result<Entity<AcpThread>>> {
 801        if let Some(session) = self.sessions.get(&id) {
 802            return Task::ready(Ok(session.acp_thread.clone()));
 803        }
 804
 805        let task = self.load_thread(id, cx);
 806        cx.spawn(async move |this, cx| {
 807            let thread = task.await?;
 808            let acp_thread =
 809                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
 810            let events = thread.update(cx, |thread, cx| thread.replay(cx));
 811            cx.update(|cx| {
 812                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 813            })
 814            .await?;
 815            Ok(acp_thread)
 816        })
 817    }
 818
 819    pub fn thread_summary(
 820        &mut self,
 821        id: acp::SessionId,
 822        cx: &mut Context<Self>,
 823    ) -> Task<Result<SharedString>> {
 824        let thread = self.open_thread(id.clone(), cx);
 825        cx.spawn(async move |this, cx| {
 826            let acp_thread = thread.await?;
 827            let result = this
 828                .update(cx, |this, cx| {
 829                    this.sessions
 830                        .get(&id)
 831                        .unwrap()
 832                        .thread
 833                        .update(cx, |thread, cx| thread.summary(cx))
 834                })?
 835                .await
 836                .context("Failed to generate summary")?;
 837            drop(acp_thread);
 838            Ok(result)
 839        })
 840    }
 841
 842    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 843        if thread.read(cx).is_empty() {
 844            return;
 845        }
 846
 847        let database_future = ThreadsDatabase::connect(cx);
 848        let (id, db_thread) =
 849            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
 850        let Some(session) = self.sessions.get_mut(&id) else {
 851            return;
 852        };
 853
 854        let folder_paths = PathList::new(
 855            &self
 856                .project
 857                .read(cx)
 858                .visible_worktrees(cx)
 859                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
 860                .collect::<Vec<_>>(),
 861        );
 862
 863        let thread_store = self.thread_store.clone();
 864        session.pending_save = cx.spawn(async move |_, cx| {
 865            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 866                return;
 867            };
 868            let db_thread = db_thread.await;
 869            database
 870                .save_thread(id, db_thread, folder_paths)
 871                .await
 872                .log_err();
 873            thread_store.update(cx, |store, cx| store.reload(cx));
 874        });
 875    }
 876
    /// Runs the MCP prompt `prompt_name` from server `server_id` as a turn in session
    /// `session_id`.
    ///
    /// Steps:
    /// 1. Fetch the prompt with `crate::get_prompt`, passing `arguments`.
    /// 2. Push every block of `original_content` except the first (the block
    ///    `Command::parse` read the slash command from) to the thread as the user
    ///    message `message_id`.
    /// 3. Append each message returned by the server to both the `AcpThread` and the
    ///    `Thread`, as a user or assistant block according to its role.
    /// 4. Send the thread if the last prompt message was a user message (or there
    ///    were none); otherwise resume it.
    /// 5. Forward the resulting events to the `AcpThread`.
    ///
    /// # Errors
    /// Fails if the prompt cannot be fetched, the agent entity is gone, the session
    /// does not exist, or the turn itself fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: agent_client_protocol::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let server_store = self.context_server_registry.read(cx).server_store().clone();
        let path_style = self.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Stays `true` when the prompt has no messages, so the turn is sent.
            let mut last_is_user = true;

            // Skip the first block, which contained the slash command itself.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        // Each user message from the prompt gets its own message id,
                        // shared by the AcpThread entry and the Thread entry.
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
 972}
 973
/// Wrapper struct that implements the AgentConnection trait
/// for the in-process [`NativeAgent`].
///
/// Cloning copies the `Entity` handle; every clone refers to the same agent.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 977
 978impl NativeAgentConnection {
 979    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 980        self.0
 981            .read(cx)
 982            .sessions
 983            .get(session_id)
 984            .map(|session| session.thread.clone())
 985    }
 986
 987    pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
 988        self.0.update(cx, |this, cx| this.load_thread(id, cx))
 989    }
 990
 991    fn run_turn(
 992        &self,
 993        session_id: acp::SessionId,
 994        cx: &mut App,
 995        f: impl 'static
 996        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 997    ) -> Task<Result<acp::PromptResponse>> {
 998        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 999            agent
1000                .sessions
1001                .get_mut(&session_id)
1002                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1003        }) else {
1004            return Task::ready(Err(anyhow!("Session not found")));
1005        };
1006        log::debug!("Found session for: {}", session_id);
1007
1008        let response_stream = match f(thread, cx) {
1009            Ok(stream) => stream,
1010            Err(err) => return Task::ready(Err(err)),
1011        };
1012        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1013    }
1014
1015    fn handle_thread_events(
1016        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1017        acp_thread: WeakEntity<AcpThread>,
1018        cx: &App,
1019    ) -> Task<Result<acp::PromptResponse>> {
1020        cx.spawn(async move |cx| {
1021            // Handle response stream and forward to session.acp_thread
1022            while let Some(result) = events.next().await {
1023                match result {
1024                    Ok(event) => {
1025                        log::trace!("Received completion event: {:?}", event);
1026
1027                        match event {
1028                            ThreadEvent::UserMessage(message) => {
1029                                acp_thread.update(cx, |thread, cx| {
1030                                    for content in message.content {
1031                                        thread.push_user_content_block(
1032                                            Some(message.id.clone()),
1033                                            content.into(),
1034                                            cx,
1035                                        );
1036                                    }
1037                                })?;
1038                            }
1039                            ThreadEvent::AgentText(text) => {
1040                                acp_thread.update(cx, |thread, cx| {
1041                                    thread.push_assistant_content_block(text.into(), false, cx)
1042                                })?;
1043                            }
1044                            ThreadEvent::AgentThinking(text) => {
1045                                acp_thread.update(cx, |thread, cx| {
1046                                    thread.push_assistant_content_block(text.into(), true, cx)
1047                                })?;
1048                            }
1049                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1050                                tool_call,
1051                                options,
1052                                response,
1053                                context: _,
1054                            }) => {
1055                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1056                                    thread.request_tool_call_authorization(tool_call, options, cx)
1057                                })??;
1058                                cx.background_spawn(async move {
1059                                    if let acp::RequestPermissionOutcome::Selected(
1060                                        acp::SelectedPermissionOutcome { option_id, .. },
1061                                    ) = outcome_task.await
1062                                    {
1063                                        response
1064                                            .send(option_id)
1065                                            .map(|_| anyhow!("authorization receiver was dropped"))
1066                                            .log_err();
1067                                    }
1068                                })
1069                                .detach();
1070                            }
1071                            ThreadEvent::ToolCall(tool_call) => {
1072                                acp_thread.update(cx, |thread, cx| {
1073                                    thread.upsert_tool_call(tool_call, cx)
1074                                })??;
1075                            }
1076                            ThreadEvent::ToolCallUpdate(update) => {
1077                                acp_thread.update(cx, |thread, cx| {
1078                                    thread.update_tool_call(update, cx)
1079                                })??;
1080                            }
1081                            ThreadEvent::SubagentSpawned(session_id) => {
1082                                acp_thread.update(cx, |thread, cx| {
1083                                    thread.subagent_spawned(session_id, cx);
1084                                })?;
1085                            }
1086                            ThreadEvent::Retry(status) => {
1087                                acp_thread.update(cx, |thread, cx| {
1088                                    thread.update_retry_status(status, cx)
1089                                })?;
1090                            }
1091                            ThreadEvent::Stop(stop_reason) => {
1092                                log::debug!("Assistant message complete: {:?}", stop_reason);
1093                                return Ok(acp::PromptResponse::new(stop_reason));
1094                            }
1095                        }
1096                    }
1097                    Err(e) => {
1098                        log::error!("Error in model response stream: {:?}", e);
1099                        return Err(e);
1100                    }
1101                }
1102            }
1103
1104            log::debug!("Response stream completed");
1105            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1106        })
1107    }
1108}
1109
/// A slash command read from the first block of a prompt, in the form `/name arg`
/// or `/server.name arg`. All fields borrow from the prompt text.
struct Command<'a> {
    /// Name of the prompt to run.
    prompt_name: &'a str,
    /// Argument text following the command name; empty when there is none.
    arg_value: &'a str,
    /// Server ID given explicitly via the `server.name` form, if any.
    explicit_server_id: Option<&'a str>,
}
1115
1116impl<'a> Command<'a> {
1117    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1118        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1119            return None;
1120        };
1121        let text = text_content.text.trim();
1122        let command = text.strip_prefix('/')?;
1123        let (command, arg_value) = command
1124            .split_once(char::is_whitespace)
1125            .unwrap_or((command, ""));
1126
1127        if let Some((server_id, prompt_name)) = command.split_once('.') {
1128            Some(Self {
1129                prompt_name,
1130                arg_value,
1131                explicit_server_id: Some(server_id),
1132            })
1133        } else {
1134            Some(Self {
1135                prompt_name: command,
1136                arg_value,
1137                explicit_server_id: None,
1138            })
1139        }
1140    }
1141}
1142
/// Per-session model selector backed by the native agent's model list.
struct NativeAgentModelSelector {
    /// Session whose thread's model is read and changed.
    session_id: acp::SessionId,
    /// Connection used to reach the agent, its sessions, and its models.
    connection: NativeAgentConnection,
}
1147
1148impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1149    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1150        log::debug!("NativeAgentConnection::list_models called");
1151        let list = self.connection.0.read(cx).models.model_list.clone();
1152        Task::ready(if list.is_empty() {
1153            Err(anyhow::anyhow!("No models available"))
1154        } else {
1155            Ok(list)
1156        })
1157    }
1158
1159    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1160        log::debug!(
1161            "Setting model for session {}: {}",
1162            self.session_id,
1163            model_id
1164        );
1165        let Some(thread) = self
1166            .connection
1167            .0
1168            .read(cx)
1169            .sessions
1170            .get(&self.session_id)
1171            .map(|session| session.thread.clone())
1172        else {
1173            return Task::ready(Err(anyhow!("Session not found")));
1174        };
1175
1176        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1177            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1178        };
1179
1180        // We want to reset the effort level when switching models, as the currently-selected effort level may
1181        // not be compatible.
1182        let effort = model
1183            .default_effort_level()
1184            .map(|effort_level| effort_level.value.to_string());
1185
1186        thread.update(cx, |thread, cx| {
1187            thread.set_model(model.clone(), cx);
1188            thread.set_thinking_effort(effort.clone(), cx);
1189            thread.set_thinking_enabled(model.supports_thinking(), cx);
1190        });
1191
1192        update_settings_file(
1193            self.connection.0.read(cx).fs.clone(),
1194            cx,
1195            move |settings, cx| {
1196                let provider = model.provider_id().0.to_string();
1197                let model = model.id().0.to_string();
1198                let enable_thinking = thread.read(cx).thinking_enabled();
1199                settings
1200                    .agent
1201                    .get_or_insert_default()
1202                    .set_model(LanguageModelSelection {
1203                        provider: provider.into(),
1204                        model,
1205                        enable_thinking,
1206                        effort,
1207                    });
1208            },
1209        );
1210
1211        Task::ready(Ok(()))
1212    }
1213
1214    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1215        let Some(thread) = self
1216            .connection
1217            .0
1218            .read(cx)
1219            .sessions
1220            .get(&self.session_id)
1221            .map(|session| session.thread.clone())
1222        else {
1223            return Task::ready(Err(anyhow!("Session not found")));
1224        };
1225        let Some(model) = thread.read(cx).model() else {
1226            return Task::ready(Err(anyhow!("Model not found")));
1227        };
1228        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1229        else {
1230            return Task::ready(Err(anyhow!("Provider not found")));
1231        };
1232        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1233            model, &provider,
1234        )))
1235    }
1236
1237    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1238        Some(self.connection.0.read(cx).models.watch())
1239    }
1240
1241    fn should_render_footer(&self) -> bool {
1242        true
1243    }
1244}
1245
1246impl acp_thread::AgentConnection for NativeAgentConnection {
1247    fn telemetry_id(&self) -> SharedString {
1248        "zed".into()
1249    }
1250
1251    fn new_session(
1252        self: Rc<Self>,
1253        project: Entity<Project>,
1254        cwd: &Path,
1255        cx: &mut App,
1256    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1257        log::debug!("Creating new thread for project at: {cwd:?}");
1258        Task::ready(Ok(self
1259            .0
1260            .update(cx, |agent, cx| agent.new_session(project, cx))))
1261    }
1262
1263    fn supports_load_session(&self) -> bool {
1264        true
1265    }
1266
1267    fn load_session(
1268        self: Rc<Self>,
1269        session: AgentSessionInfo,
1270        _project: Entity<Project>,
1271        _cwd: &Path,
1272        cx: &mut App,
1273    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1274        self.0
1275            .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1276    }
1277
1278    fn supports_close_session(&self) -> bool {
1279        true
1280    }
1281
1282    fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1283        self.0.update(cx, |agent, _cx| {
1284            agent.sessions.remove(session_id);
1285        });
1286        Task::ready(Ok(()))
1287    }
1288
1289    fn auth_methods(&self) -> &[acp::AuthMethod] {
1290        &[] // No auth for in-process
1291    }
1292
1293    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1294        Task::ready(Ok(()))
1295    }
1296
1297    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1298        Some(Rc::new(NativeAgentModelSelector {
1299            session_id: session_id.clone(),
1300            connection: self.clone(),
1301        }) as Rc<dyn AgentModelSelector>)
1302    }
1303
1304    fn prompt(
1305        &self,
1306        id: Option<acp_thread::UserMessageId>,
1307        params: acp::PromptRequest,
1308        cx: &mut App,
1309    ) -> Task<Result<acp::PromptResponse>> {
1310        let id = id.expect("UserMessageId is required");
1311        let session_id = params.session_id.clone();
1312        log::info!("Received prompt request for session: {}", session_id);
1313        log::debug!("Prompt blocks count: {}", params.prompt.len());
1314
1315        if let Some(parsed_command) = Command::parse(&params.prompt) {
1316            let registry = self.0.read(cx).context_server_registry.read(cx);
1317
1318            let explicit_server_id = parsed_command
1319                .explicit_server_id
1320                .map(|server_id| ContextServerId(server_id.into()));
1321
1322            if let Some(prompt) =
1323                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1324            {
1325                let arguments = if !parsed_command.arg_value.is_empty()
1326                    && let Some(arg_name) = prompt
1327                        .prompt
1328                        .arguments
1329                        .as_ref()
1330                        .and_then(|args| args.first())
1331                        .map(|arg| arg.name.clone())
1332                {
1333                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1334                } else {
1335                    Default::default()
1336                };
1337
1338                let prompt_name = prompt.prompt.name.clone();
1339                let server_id = prompt.server_id.clone();
1340
1341                return self.0.update(cx, |agent, cx| {
1342                    agent.send_mcp_prompt(
1343                        id,
1344                        session_id.clone(),
1345                        prompt_name,
1346                        server_id,
1347                        arguments,
1348                        params.prompt,
1349                        cx,
1350                    )
1351                });
1352            };
1353        };
1354
1355        let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1356
1357        self.run_turn(session_id, cx, move |thread, cx| {
1358            let content: Vec<UserMessageContent> = params
1359                .prompt
1360                .into_iter()
1361                .map(|block| UserMessageContent::from_content_block(block, path_style))
1362                .collect::<Vec<_>>();
1363            log::debug!("Converted prompt to message: {} chars", content.len());
1364            log::debug!("Message id: {:?}", id);
1365            log::debug!("Message content: {:?}", content);
1366
1367            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1368        })
1369    }
1370
1371    fn retry(
1372        &self,
1373        session_id: &acp::SessionId,
1374        _cx: &App,
1375    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1376        Some(Rc::new(NativeAgentSessionRetry {
1377            connection: self.clone(),
1378            session_id: session_id.clone(),
1379        }) as _)
1380    }
1381
1382    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1383        log::info!("Cancelling on session: {}", session_id);
1384        self.0.update(cx, |agent, cx| {
1385            if let Some(session) = agent.sessions.get(session_id) {
1386                session
1387                    .thread
1388                    .update(cx, |thread, cx| thread.cancel(cx))
1389                    .detach();
1390            }
1391        });
1392    }
1393
1394    fn truncate(
1395        &self,
1396        session_id: &agent_client_protocol::SessionId,
1397        cx: &App,
1398    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1399        self.0.read_with(cx, |agent, _cx| {
1400            agent.sessions.get(session_id).map(|session| {
1401                Rc::new(NativeAgentSessionTruncate {
1402                    thread: session.thread.clone(),
1403                    acp_thread: session.acp_thread.downgrade(),
1404                }) as _
1405            })
1406        })
1407    }
1408
1409    fn set_title(
1410        &self,
1411        session_id: &acp::SessionId,
1412        cx: &App,
1413    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1414        self.0.read_with(cx, |agent, _cx| {
1415            agent
1416                .sessions
1417                .get(session_id)
1418                .filter(|s| !s.thread.read(cx).is_subagent())
1419                .map(|session| {
1420                    Rc::new(NativeAgentSessionSetTitle {
1421                        thread: session.thread.clone(),
1422                    }) as _
1423                })
1424        })
1425    }
1426
1427    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1428        let thread_store = self.0.read(cx).thread_store.clone();
1429        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1430    }
1431
1432    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1433        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1434    }
1435
1436    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1437        self
1438    }
1439}
1440
1441impl acp_thread::AgentTelemetry for NativeAgentConnection {
1442    fn thread_data(
1443        &self,
1444        session_id: &acp::SessionId,
1445        cx: &mut App,
1446    ) -> Task<Result<serde_json::Value>> {
1447        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1448            return Task::ready(Err(anyhow!("Session not found")));
1449        };
1450
1451        let task = session.thread.read(cx).to_db(cx);
1452        cx.background_spawn(async move {
1453            serde_json::to_value(task.await).context("Failed to serialize thread")
1454        })
1455    }
1456}
1457
/// [`AgentSessionList`] backed by a [`ThreadStore`]: lists stored threads and sends
/// `SessionListUpdate::Refresh` messages on a channel when the store changes.
pub struct NativeAgentSessionList {
    thread_store: Entity<ThreadStore>,
    /// Sending half of the update channel.
    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
    /// Receiving half; clones of it are handed to watchers.
    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
    /// Keeps the thread-store observation alive for as long as the list exists.
    _subscription: Subscription,
}
1464
1465impl NativeAgentSessionList {
1466    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1467        let (tx, rx) = smol::channel::unbounded();
1468        let this_tx = tx.clone();
1469        let subscription = cx.observe(&thread_store, move |_, _| {
1470            this_tx
1471                .try_send(acp_thread::SessionListUpdate::Refresh)
1472                .ok();
1473        });
1474        Self {
1475            thread_store,
1476            updates_tx: tx,
1477            updates_rx: rx,
1478            _subscription: subscription,
1479        }
1480    }
1481
1482    fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1483        AgentSessionInfo {
1484            session_id: entry.id,
1485            cwd: None,
1486            title: Some(entry.title),
1487            updated_at: Some(entry.updated_at),
1488            meta: None,
1489        }
1490    }
1491
1492    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1493        &self.thread_store
1494    }
1495}
1496
1497impl AgentSessionList for NativeAgentSessionList {
1498    fn list_sessions(
1499        &self,
1500        _request: AgentSessionListRequest,
1501        cx: &mut App,
1502    ) -> Task<Result<AgentSessionListResponse>> {
1503        let sessions = self
1504            .thread_store
1505            .read(cx)
1506            .entries()
1507            .map(Self::to_session_info)
1508            .collect();
1509        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1510    }
1511
1512    fn supports_delete(&self) -> bool {
1513        true
1514    }
1515
1516    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1517        self.thread_store
1518            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1519    }
1520
1521    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1522        self.thread_store
1523            .update(cx, |store, cx| store.delete_threads(cx))
1524    }
1525
1526    fn watch(
1527        &self,
1528        _cx: &mut App,
1529    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1530        Some(self.updates_rx.clone())
1531    }
1532
1533    fn notify_refresh(&self) {
1534        self.updates_tx
1535            .try_send(acp_thread::SessionListUpdate::Refresh)
1536            .ok();
1537    }
1538
1539    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1540        self
1541    }
1542}
1543
/// Truncation handle for a session. It truncates `thread` at a user message and
/// pushes the resulting token usage to `acp_thread`.
struct NativeAgentSessionTruncate {
    thread: Entity<Thread>,
    /// Held weakly; the update is skipped if the AcpThread has been released.
    acp_thread: WeakEntity<AcpThread>,
}
1548
1549impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1550    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1551        match self.thread.update(cx, |thread, cx| {
1552            thread.truncate(message_id.clone(), cx)?;
1553            Ok(thread.latest_token_usage())
1554        }) {
1555            Ok(usage) => {
1556                self.acp_thread
1557                    .update(cx, |thread, cx| {
1558                        thread.update_token_usage(usage, cx);
1559                    })
1560                    .ok();
1561                Task::ready(Ok(()))
1562            }
1563            Err(error) => Task::ready(Err(error)),
1564        }
1565    }
1566}
1567
/// Retry handle for a session: resumes the thread of `session_id` via `connection`.
struct NativeAgentSessionRetry {
    connection: NativeAgentConnection,
    session_id: acp::SessionId,
}
1572
1573impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1574    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1575        self.connection
1576            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1577                thread.update(cx, |thread, cx| thread.resume(cx))
1578            })
1579    }
1580}
1581
/// Handle for renaming a session's thread.
struct NativeAgentSessionSetTitle {
    thread: Entity<Thread>,
}
1585
1586impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1587    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1588        self.thread
1589            .update(cx, |thread, cx| thread.set_title(title, cx));
1590        Task::ready(Ok(()))
1591    }
1592}
1593
/// Per-thread environment used to create and resume subagent threads.
///
/// All handles are weak, so this environment does not keep the agent or the threads
/// alive.
pub struct NativeThreadEnvironment {
    /// Agent that subagent sessions are registered with.
    agent: WeakEntity<NativeAgent>,
    /// Parent thread that subagents are spawned from.
    thread: WeakEntity<Thread>,
    /// Presumably the `AcpThread` paired with `thread` — confirm in `prompt_subagent`.
    acp_thread: WeakEntity<AcpThread>,
}
1599
1600impl NativeThreadEnvironment {
1601    pub(crate) fn create_subagent_thread(
1602        &self,
1603        label: String,
1604        cx: &mut App,
1605    ) -> Result<Rc<dyn SubagentHandle>> {
1606        let Some(parent_thread_entity) = self.thread.upgrade() else {
1607            anyhow::bail!("Parent thread no longer exists".to_string());
1608        };
1609        let parent_thread = parent_thread_entity.read(cx);
1610        let current_depth = parent_thread.depth();
1611
1612        if current_depth >= MAX_SUBAGENT_DEPTH {
1613            return Err(anyhow!(
1614                "Maximum subagent depth ({}) reached",
1615                MAX_SUBAGENT_DEPTH
1616            ));
1617        }
1618
1619        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1620            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1621            thread.set_title(label.into(), cx);
1622            thread
1623        });
1624
1625        let session_id = subagent_thread.read(cx).id().clone();
1626
1627        let acp_thread = self.agent.update(cx, |agent, cx| {
1628            agent.register_session(subagent_thread.clone(), cx)
1629        })?;
1630
1631        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1632    }
1633
1634    pub(crate) fn resume_subagent_thread(
1635        &self,
1636        session_id: acp::SessionId,
1637        cx: &mut App,
1638    ) -> Result<Rc<dyn SubagentHandle>> {
1639        let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1640            let session = agent
1641                .sessions
1642                .get(&session_id)
1643                .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1644            anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1645        })??;
1646
1647        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1648    }
1649
1650    fn prompt_subagent(
1651        &self,
1652        session_id: acp::SessionId,
1653        subagent_thread: Entity<Thread>,
1654        acp_thread: Entity<acp_thread::AcpThread>,
1655    ) -> Result<Rc<dyn SubagentHandle>> {
1656        let Some(parent_thread_entity) = self.thread.upgrade() else {
1657            anyhow::bail!("Parent thread no longer exists".to_string());
1658        };
1659        Ok(Rc::new(NativeSubagentHandle::new(
1660            session_id,
1661            subagent_thread,
1662            acp_thread,
1663            parent_thread_entity,
1664        )) as _)
1665    }
1666}
1667
1668impl ThreadEnvironment for NativeThreadEnvironment {
1669    fn create_terminal(
1670        &self,
1671        command: String,
1672        cwd: Option<PathBuf>,
1673        output_byte_limit: Option<u64>,
1674        cx: &mut AsyncApp,
1675    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1676        let task = self.acp_thread.update(cx, |thread, cx| {
1677            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1678        });
1679
1680        let acp_thread = self.acp_thread.clone();
1681        cx.spawn(async move |cx| {
1682            let terminal = task?.await?;
1683
1684            let (drop_tx, drop_rx) = oneshot::channel();
1685            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1686
1687            cx.spawn(async move |cx| {
1688                drop_rx.await.ok();
1689                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1690            })
1691            .detach();
1692
1693            let handle = AcpTerminalHandle {
1694                terminal,
1695                _drop_tx: Some(drop_tx),
1696            };
1697
1698            Ok(Rc::new(handle) as _)
1699        })
1700    }
1701
1702    fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1703        self.create_subagent_thread(label, cx)
1704    }
1705
1706    fn resume_subagent(
1707        &self,
1708        session_id: acp::SessionId,
1709        cx: &mut App,
1710    ) -> Result<Rc<dyn SubagentHandle>> {
1711        self.resume_subagent_thread(session_id, cx)
1712    }
1713}
1714
/// Outcome of waiting for a single subagent prompt to finish.
#[derive(Clone, Debug)]
enum SubagentPromptResult {
    /// The turn ended successfully.
    Completed,
    /// The turn stopped with a `Cancelled` stop reason.
    Cancelled,
    /// Token usage reached the warning level during the turn.
    ContextWindowWarning,
    /// The turn failed; carries a human-readable explanation.
    Error(String),
}
1722
1723pub struct NativeSubagentHandle {
1724    session_id: acp::SessionId,
1725    parent_thread: WeakEntity<Thread>,
1726    subagent_thread: Entity<Thread>,
1727    acp_thread: Entity<acp_thread::AcpThread>,
1728}
1729
1730impl NativeSubagentHandle {
1731    fn new(
1732        session_id: acp::SessionId,
1733        subagent_thread: Entity<Thread>,
1734        acp_thread: Entity<acp_thread::AcpThread>,
1735        parent_thread_entity: Entity<Thread>,
1736    ) -> Self {
1737        NativeSubagentHandle {
1738            session_id,
1739            subagent_thread,
1740            parent_thread: parent_thread_entity.downgrade(),
1741            acp_thread,
1742        }
1743    }
1744}
1745
1746impl SubagentHandle for NativeSubagentHandle {
1747    fn id(&self) -> acp::SessionId {
1748        self.session_id.clone()
1749    }
1750
1751    fn num_entries(&self, cx: &App) -> usize {
1752        self.acp_thread.read(cx).entries().len()
1753    }
1754
1755    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
1756        let thread = self.subagent_thread.clone();
1757        let acp_thread = self.acp_thread.clone();
1758        let subagent_session_id = self.session_id.clone();
1759        let parent_thread = self.parent_thread.clone();
1760
1761        cx.spawn(async move |cx| {
1762            let (task, _subscription) = cx.update(|cx| {
1763                let ratio_before_prompt = thread
1764                    .read(cx)
1765                    .latest_token_usage()
1766                    .map(|usage| usage.ratio());
1767
1768                parent_thread
1769                    .update(cx, |parent_thread, _cx| {
1770                        parent_thread.register_running_subagent(thread.downgrade())
1771                    })
1772                    .ok();
1773
1774                let task = acp_thread.update(cx, |acp_thread, cx| {
1775                    acp_thread.send(vec![message.into()], cx)
1776                });
1777
1778                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
1779                let mut token_limit_tx = Some(token_limit_tx);
1780
1781                let subscription = cx.subscribe(
1782                    &thread,
1783                    move |_thread, event: &TokenUsageUpdated, _cx| {
1784                        if let Some(usage) = &event.0 {
1785                            let old_ratio = ratio_before_prompt
1786                                .clone()
1787                                .unwrap_or(TokenUsageRatio::Normal);
1788                            let new_ratio = usage.ratio();
1789                            if old_ratio == TokenUsageRatio::Normal
1790                                && new_ratio == TokenUsageRatio::Warning
1791                            {
1792                                if let Some(tx) = token_limit_tx.take() {
1793                                    tx.send(()).ok();
1794                                }
1795                            }
1796                        }
1797                    },
1798                );
1799
1800                let wait_for_prompt = cx
1801                    .background_spawn(async move {
1802                        futures::select! {
1803                            response = task.fuse() => match response {
1804                                Ok(Some(response)) => {
1805                                    match response.stop_reason {
1806                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
1807                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
1808                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
1809                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
1810                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
1811                                    }
1812                                }
1813                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
1814                                Err(error) => SubagentPromptResult::Error(error.to_string()),
1815                            },
1816                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
1817                        }
1818                    });
1819
1820                (wait_for_prompt, subscription)
1821            });
1822
1823            let result = match task.await {
1824                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
1825                    thread
1826                        .last_message()
1827                        .and_then(|message| {
1828                            let content = message.as_agent_message()?
1829                                .content
1830                                .iter()
1831                                .filter_map(|c| match c {
1832                                    AgentMessageContent::Text(text) => Some(text.as_str()),
1833                                    _ => None,
1834                                })
1835                                .join("\n\n");
1836                            if content.is_empty() {
1837                                None
1838                            } else {
1839                                Some( content)
1840                            }
1841                        })
1842                        .context("No response from subagent")
1843                }),
1844                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
1845                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
1846                SubagentPromptResult::ContextWindowWarning => {
1847                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1848                    Err(anyhow!(
1849                        "The agent is nearing the end of its context window and has been \
1850                         stopped. You can prompt the thread again to have the agent wrap up \
1851                         or hand off its work."
1852                    ))
1853                }
1854            };
1855
1856            parent_thread
1857                .update(cx, |parent_thread, cx| {
1858                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1859                })
1860                .ok();
1861
1862            result
1863        })
1864    }
1865}
1866
1867pub struct AcpTerminalHandle {
1868    terminal: Entity<acp_thread::Terminal>,
1869    _drop_tx: Option<oneshot::Sender<()>>,
1870}
1871
1872impl TerminalHandle for AcpTerminalHandle {
1873    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1874        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1875    }
1876
1877    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1878        Ok(self
1879            .terminal
1880            .read_with(cx, |term, _cx| term.wait_for_exit()))
1881    }
1882
1883    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1884        Ok(self
1885            .terminal
1886            .read_with(cx, |term, cx| term.current_output(cx)))
1887    }
1888
1889    fn kill(&self, cx: &AsyncApp) -> Result<()> {
1890        cx.update(|cx| {
1891            self.terminal.update(cx, |terminal, cx| {
1892                terminal.kill(cx);
1893            });
1894        });
1895        Ok(())
1896    }
1897
1898    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1899        Ok(self
1900            .terminal
1901            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1902    }
1903}
1904
1905#[cfg(test)]
1906mod internal_tests {
1907    use super::*;
1908    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1909    use fs::FakeFs;
1910    use gpui::TestAppContext;
1911    use indoc::formatdoc;
1912    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1913    use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1914    use serde_json::json;
1915    use settings::SettingsStore;
1916    use util::{path, rel_path::rel_path};
1917
    // Checks that the agent's project context follows the project's worktrees:
    // empty at first, one entry without a rules file once `/a` is added, and a
    // `rules_file` entry once `/a/.rules` is created.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        // The project starts with no worktrees.
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        // Adding a worktree for `/a` adds a context entry with no rules file.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }
1983
    // Checks that a new session's model selector lists exactly one group,
    // "Fake", containing the `fake/fake` model.
    // NOTE(review): the fake provider is presumably registered by `init_test`;
    // confirm.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                thread_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // The native agent reports models grouped by provider.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                    is_latest: false,
                    cost: None,
                }]
            )])
        );
    }
2044
    // Checks that selecting a model through a session's model selector does two
    // things. It updates the thread's model, and it writes
    // `agent.default_model` (provider and model) to the settings file. Selecting
    // a thinking model must also write `enable_thinking: true`.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed a different default model so the overwrite is observable.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let pending tasks (presumably including the settings write) finish
        // before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );

        // Register a thinking model and select it.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so the new provider can be selected.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Verify enable_thinking was written to settings as true.
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
        assert_eq!(
            settings_json["agent"]["default_model"]["enable_thinking"],
            json!(true),
            "selecting a thinking model should persist enable_thinking: true to settings"
        );
    }
2160
    // Checks that `select_model` turns the thread's `thinking_enabled` on when
    // switching to a thinking model, and off again when switching back to a
    // non-thinking one.
    #[gpui::test]
    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Register a second provider with a thinking model.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so it picks up the new provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Thread starts with thinking_enabled = false (the default).
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(!thread.thinking_enabled(), "thinking defaults to false");
            });
        });

        // Select the thinking model via select_model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // select_model should have enabled thinking based on the model's supports_thinking().
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    thread.thinking_enabled(),
                    "select_model should enable thinking when model supports it"
                );
            });
        });

        // Switch back to the non-thinking model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
            .await
            .unwrap();

        // select_model should have disabled thinking.
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    !thread.thinking_enabled(),
                    "select_model should disable thinking when model does not support it"
                );
            });
        });
    }
2256
    // Checks that `thinking_enabled` survives a round trip: the thread is
    // persisted (by completing one turn), its session is closed, and the thread
    // is reopened with `open_thread`.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Drive the fake model: stream one text chunk, then end the stream so
        // the pending turn can finish.
        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's remaining handles to the old session's entities.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        drop(reloaded_acp_thread);
    }
2362
    /// Verifies that a thread's selected model survives a persist → close →
    /// reload round trip through `NativeAgent`.
    ///
    /// The model is registered with an id (`"custom-model-id"`) that differs
    /// from its display name (`"Custom Model Display Name"`). The assertions
    /// compare `id()` both before persisting and after `open_thread`. If the
    /// reloaded model id differs, the failure message describes it as having
    /// fallen back to the default model.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        // NOTE(review): the trailing `false` is presumably the "thinking" flag
        // suggested by the constructor name — confirm against
        // `FakeLanguageModel::with_id_and_thinking`.
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Presumably makes the agent's model list pick up the provider that was
        // just registered, so the selector below can resolve it — confirm.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // The model is selected by "<provider id>/<model id>", not by its display name.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        // Sanity check: the selection reached the underlying thread before persisting.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // The send is spawned rather than awaited directly so the test can feed
        // the fake model's completion stream while the send is still in flight.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the pre-reload entities before
        // checking that the agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // The reloaded ACP thread handle is held until here (presumably to keep
        // the reopened session alive while the assertions above run).
        drop(reloaded_acp_thread);
    }
2473
    /// Round-trips a thread through the `ThreadStore`. The test checks that:
    ///
    /// - a thread with no messages is not listed in the store, even after its
    ///   model and summarization model are changed;
    /// - after one user/assistant exchange, the store lists the thread with the
    ///   text streamed by the summarization model as its title;
    /// - closing the session and reopening it with `open_thread` yields the same
    ///   markdown transcript, including the rendered `@b.md` mention.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // The user message mixes plain text with a `ResourceLink` mention of
        // `b.md`, so the markdown assertions below also cover mention rendering.
        // The send is spawned so the test can drive the fake models' streams
        // while it is in flight.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        // The summarization model's streamed text is what later shows up as the
        // stored title (see the `thread_entries` assertion after reload).
        // Presumably the summary request is issued once the main response ends.
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Let any outstanding background work settle (presumably including the
        // thread being saved) before the session is closed.
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles before asserting no sessions remain.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
2612
2613    fn thread_entries(
2614        thread_store: &Entity<ThreadStore>,
2615        cx: &mut TestAppContext,
2616    ) -> Vec<(acp::SessionId, String)> {
2617        thread_store.read_with(cx, |store, _| {
2618            store
2619                .entries()
2620                .map(|entry| (entry.id.clone(), entry.title.to_string()))
2621                .collect::<Vec<_>>()
2622        })
2623    }
2624
2625    fn init_test(cx: &mut TestAppContext) {
2626        env_logger::try_init().ok();
2627        cx.update(|cx| {
2628            let settings_store = SettingsStore::test(cx);
2629            cx.set_global(settings_store);
2630
2631            LanguageModelRegistry::test(cx);
2632        });
2633    }
2634}
2635
2636fn mcp_message_content_to_acp_content_block(
2637    content: context_server::types::MessageContent,
2638) -> acp::ContentBlock {
2639    match content {
2640        context_server::types::MessageContent::Text {
2641            text,
2642            annotations: _,
2643        } => text.into(),
2644        context_server::types::MessageContent::Image {
2645            data,
2646            mime_type,
2647            annotations: _,
2648        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2649        context_server::types::MessageContent::Audio {
2650            data,
2651            mime_type,
2652            annotations: _,
2653        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2654        context_server::types::MessageContent::Resource {
2655            resource,
2656            annotations: _,
2657        } => {
2658            let mut link =
2659                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2660            if let Some(mime_type) = resource.mime_type {
2661                link = link.mime_type(mime_type);
2662            }
2663            acp::ContentBlock::ResourceLink(link)
2664        }
2665    }
2666}