agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17use itertools::Itertools;
  18pub use native_agent_server::NativeAgentServer;
  19pub use pattern_extraction::*;
  20pub use shell_command_parser::extract_commands;
  21pub use templates::*;
  22pub use thread::*;
  23pub use thread_store::*;
  24pub use tool_permissions::*;
  25pub use tools::*;
  26
  27use acp_thread::{
  28    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  29    AgentSessionListResponse, TokenUsageRatio, UserMessageId,
  30};
  31use agent_client_protocol as acp;
  32use anyhow::{Context as _, Result, anyhow};
  33use chrono::{DateTime, Utc};
  34use collections::{HashMap, HashSet, IndexMap};
  35use fs::Fs;
  36use futures::channel::{mpsc, oneshot};
  37use futures::future::Shared;
  38use futures::{FutureExt as _, StreamExt as _, future};
  39use gpui::{
  40    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  41};
  42use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  43use project::{Project, ProjectItem, ProjectPath, Worktree};
  44use prompt_store::{
  45    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  46    WorktreeContext,
  47};
  48use serde::{Deserialize, Serialize};
  49use settings::{LanguageModelSelection, update_settings_file};
  50use std::any::Any;
  51use std::path::{Path, PathBuf};
  52use std::rc::Rc;
  53use std::sync::Arc;
  54use util::ResultExt;
  55use util::path_list::PathList;
  56use util::rel_path::RelPath;
  57
/// A snapshot of the project's worktrees (as telemetry snapshots) together with the
/// timestamp associated with it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots of the project's worktrees.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Timestamp associated with this snapshot, in UTC.
    pub timestamp: DateTime<Utc>,
}
  63
  64pub struct RulesLoadingError {
  65    pub message: SharedString,
  66}
  67
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Task for the most recent save of `thread`; starts as an already-completed task
    /// when the session is registered (presumably replaced by `save_thread` — confirm).
    pending_save: Task<()>,
    /// Subscriptions to `thread`'s events; holding them keeps them active for as long
    /// as the session exists.
    _subscriptions: Vec<Subscription>,
}
  77
/// Cache of the language models exposed by authenticated providers, plus a watch
/// channel that is signalled whenever the cache is rebuilt.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver half; clones are handed out by `watch()`.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender half; signalled at the end of every `refresh_list()`.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating every visible provider; stored here so it is
    /// kept alive.
    _authenticate_all_providers_task: Task<()>,
}
  87
  88impl LanguageModels {
  89    fn new(cx: &mut App) -> Self {
  90        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  91
  92        let mut this = Self {
  93            models: HashMap::default(),
  94            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  95            refresh_models_rx,
  96            refresh_models_tx,
  97            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  98        };
  99        this.refresh_list(cx);
 100        this
 101    }
 102
 103    fn refresh_list(&mut self, cx: &App) {
 104        let providers = LanguageModelRegistry::global(cx)
 105            .read(cx)
 106            .visible_providers()
 107            .into_iter()
 108            .filter(|provider| provider.is_authenticated(cx))
 109            .collect::<Vec<_>>();
 110
 111        let mut language_model_list = IndexMap::default();
 112        let mut recommended_models = HashSet::default();
 113
 114        let mut recommended = Vec::new();
 115        for provider in &providers {
 116            for model in provider.recommended_models(cx) {
 117                recommended_models.insert((model.provider_id(), model.id()));
 118                recommended.push(Self::map_language_model_to_info(&model, provider));
 119            }
 120        }
 121        if !recommended.is_empty() {
 122            language_model_list.insert(
 123                acp_thread::AgentModelGroupName("Recommended".into()),
 124                recommended,
 125            );
 126        }
 127
 128        let mut models = HashMap::default();
 129        for provider in providers {
 130            let mut provider_models = Vec::new();
 131            for model in provider.provided_models(cx) {
 132                let model_info = Self::map_language_model_to_info(&model, &provider);
 133                let model_id = model_info.id.clone();
 134                provider_models.push(model_info);
 135                models.insert(model_id, model);
 136            }
 137            if !provider_models.is_empty() {
 138                language_model_list.insert(
 139                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 140                    provider_models,
 141                );
 142            }
 143        }
 144
 145        self.models = models;
 146        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 147        self.refresh_models_tx.send(()).ok();
 148    }
 149
 150    fn watch(&self) -> watch::Receiver<()> {
 151        self.refresh_models_rx.clone()
 152    }
 153
 154    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 155        self.models.get(model_id).cloned()
 156    }
 157
 158    fn map_language_model_to_info(
 159        model: &Arc<dyn LanguageModel>,
 160        provider: &Arc<dyn LanguageModelProvider>,
 161    ) -> acp_thread::AgentModelInfo {
 162        acp_thread::AgentModelInfo {
 163            id: Self::model_id(model),
 164            name: model.name().0,
 165            description: None,
 166            icon: Some(match provider.icon() {
 167                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 168                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 169            }),
 170            is_latest: model.is_latest(),
 171            cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
 172        }
 173    }
 174
 175    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 176        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 177    }
 178
 179    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 180        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 181            .read(cx)
 182            .visible_providers()
 183            .iter()
 184            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 185            .collect::<Vec<_>>();
 186
 187        cx.background_spawn(async move {
 188            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 189                if let Err(err) = authenticate_task.await {
 190                    match err {
 191                        language_model::AuthenticateError::CredentialsNotFound => {
 192                            // Since we're authenticating these providers in the
 193                            // background for the purposes of populating the
 194                            // language selector, we don't care about providers
 195                            // where the credentials are not found.
 196                        }
 197                        language_model::AuthenticateError::ConnectionRefused => {
 198                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 199                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 200                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 201                        }
 202                        _ => {
 203                            // Some providers have noisy failure states that we
 204                            // don't want to spam the logs with every time the
 205                            // language model selector is initialized.
 206                            //
 207                            // Ideally these should have more clear failure modes
 208                            // that we know are safe to ignore here, like what we do
 209                            // with `CredentialsNotFound` above.
 210                            match provider_id.0.as_ref() {
 211                                "lmstudio" | "ollama" => {
 212                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 213                                    //
 214                                    // These fail noisily, so we don't log them.
 215                                }
 216                                "copilot_chat" => {
 217                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 218                                }
 219                                _ => {
 220                                    log::error!(
 221                                        "Failed to authenticate provider: {}: {err:#}",
 222                                        provider_name.0
 223                                    );
 224                                }
 225                            }
 226                        }
 227                    }
 228                }
 229            }
 230        })
 231    }
 232}
 233
/// The native agent: owns the active sessions plus the shared state used to create
/// threads (project context, context-server registry, templates, model cache).
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store supplied at construction.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalled when the project context should be rebuilt (worktree changes,
    /// rules-file updates, prompt-store updates); drained by `_maintain_project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task that rebuilds the project context on each refresh signal.
    _maintain_project_context: Task<Result<()>>,
    /// Context-server registry handed to every thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project whose visible worktrees feed the project context and new threads.
    project: Entity<Project>,
    /// Optional prompt store providing the default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle supplied at construction.
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, model-registry, context-server and prompt-store events.
    _subscriptions: Vec<Subscription>,
}
 252
 253impl NativeAgent {
    /// Creates the `NativeAgent` entity for `project`.
    ///
    /// The initial `ProjectContext` is built (and awaited) before the entity exists.
    /// The entity then subscribes to project, language-model registry, context-server
    /// store/registry and, when present, prompt-store events. It also spawns the task
    /// that rebuilds the project context whenever a refresh is signalled.
    pub async fn new(
        project: Entity<Project>,
        thread_store: Entity<ThreadStore>,
        templates: Arc<Templates>,
        prompt_store: Option<Entity<PromptStore>>,
        fs: Arc<dyn Fs>,
        cx: &mut AsyncApp,
    ) -> Result<Entity<NativeAgent>> {
        log::debug!("Creating new NativeAgent");

        let project_context = cx
            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
            .await;

        Ok(cx.new(|cx| {
            let context_server_store = project.read(cx).context_server_store();
            let context_server_registry =
                cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));

            // Keep the agent in sync with the project, the model registry and the
            // context servers (and, below, the prompt store if there is one).
            let mut subscriptions = vec![
                cx.subscribe(&project, Self::handle_project_event),
                cx.subscribe(
                    &LanguageModelRegistry::global(cx),
                    Self::handle_models_updated_event,
                ),
                cx.subscribe(
                    &context_server_store,
                    Self::handle_context_server_store_updated,
                ),
                cx.subscribe(
                    &context_server_registry,
                    Self::handle_context_server_registry_event,
                ),
            ];
            if let Some(prompt_store) = prompt_store.as_ref() {
                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
            }

            // Event handlers signal the sender; the spawned task waits on the receiver
            // and rebuilds the project context each time.
            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
                watch::channel(());
            Self {
                sessions: HashMap::default(),
                thread_store,
                project_context: cx.new(|_| project_context),
                project_context_needs_refresh: project_context_needs_refresh_tx,
                _maintain_project_context: cx.spawn(async move |this, cx| {
                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
                }),
                context_server_registry,
                templates,
                models: LanguageModels::new(cx),
                project,
                prompt_store,
                fs,
                _subscriptions: subscriptions,
            }
        }))
    }
 312
 313    fn new_session(
 314        &mut self,
 315        project: Entity<Project>,
 316        cx: &mut Context<Self>,
 317    ) -> Entity<AcpThread> {
 318        // Create Thread
 319        // Fetch default model from registry settings
 320        let registry = LanguageModelRegistry::read_global(cx);
 321        // Log available models for debugging
 322        let available_count = registry.available_models(cx).count();
 323        log::debug!("Total available models: {}", available_count);
 324
 325        let default_model = registry.default_model().and_then(|default_model| {
 326            self.models
 327                .model_from_id(&LanguageModels::model_id(&default_model.model))
 328        });
 329        let thread = cx.new(|cx| {
 330            Thread::new(
 331                project.clone(),
 332                self.project_context.clone(),
 333                self.context_server_registry.clone(),
 334                self.templates.clone(),
 335                default_model,
 336                cx,
 337            )
 338        });
 339
 340        self.register_session(thread, cx)
 341    }
 342
    /// Wraps `thread_handle` in a new `AcpThread`, connects the thread to this agent,
    /// and records both in `sessions` under the thread's ID. Any existing session with
    /// that ID is replaced.
    ///
    /// Returns the new `AcpThread`.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy out everything the `AcpThread` needs from the thread before creating it.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let draft_prompt = thread.draft_prompt().map(Vec::from);
        let scroll_position = thread.ui_scroll_position();
        let token_usage = thread.latest_token_usage();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            let mut acp_thread = acp_thread::AcpThread::new(
                parent_session_id,
                title,
                None,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            );
            // Seed the ACP thread with the state carried by the thread.
            acp_thread.set_draft_prompt(draft_prompt);
            acp_thread.set_ui_scroll_position(scroll_position);
            acp_thread.update_token_usage(token_usage, cx);
            acp_thread
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        // Give the thread its summarization model and the default tools. The tool
        // environment holds only weak handles to the ACP thread, the thread and the agent.
        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        // Forward title and token-usage events to the `AcpThread`, and save the thread
        // whenever it notifies.
        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                _subscriptions: subscriptions,
                pending_save: Task::ready(()),
            },
        );

        // Push the current available commands to every session, including this one.
        self.update_available_commands(cx);

        acp_thread
    }
 417
 418    pub fn models(&self) -> &LanguageModels {
 419        &self.models
 420    }
 421
 422    async fn maintain_project_context(
 423        this: WeakEntity<Self>,
 424        mut needs_refresh: watch::Receiver<()>,
 425        cx: &mut AsyncApp,
 426    ) -> Result<()> {
 427        while needs_refresh.changed().await.is_ok() {
 428            let project_context = this
 429                .update(cx, |this, cx| {
 430                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 431                })?
 432                .await;
 433            this.update(cx, |this, cx| {
 434                this.project_context = cx.new(|_| project_context);
 435            })?;
 436        }
 437
 438        Ok(())
 439    }
 440
 441    fn build_project_context(
 442        project: &Entity<Project>,
 443        prompt_store: Option<&Entity<PromptStore>>,
 444        cx: &mut App,
 445    ) -> Task<ProjectContext> {
 446        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 447        let worktree_tasks = worktrees
 448            .into_iter()
 449            .map(|worktree| {
 450                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 451            })
 452            .collect::<Vec<_>>();
 453        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 454            prompt_store.read_with(cx, |prompt_store, cx| {
 455                let prompts = prompt_store.default_prompt_metadata();
 456                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 457                    let contents = prompt_store.load(prompt_metadata.id, cx);
 458                    async move { (contents.await, prompt_metadata) }
 459                });
 460                cx.background_spawn(future::join_all(load_tasks))
 461            })
 462        } else {
 463            Task::ready(vec![])
 464        };
 465
 466        cx.spawn(async move |_cx| {
 467            let (worktrees, default_user_rules) =
 468                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 469
 470            let worktrees = worktrees
 471                .into_iter()
 472                .map(|(worktree, _rules_error)| {
 473                    // TODO: show error message
 474                    // if let Some(rules_error) = rules_error {
 475                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 476                    // }
 477                    worktree
 478                })
 479                .collect::<Vec<_>>();
 480
 481            let default_user_rules = default_user_rules
 482                .into_iter()
 483                .flat_map(|(contents, prompt_metadata)| match contents {
 484                    Ok(contents) => Some(UserRulesContext {
 485                        uuid: prompt_metadata.id.as_user()?,
 486                        title: prompt_metadata.title.map(|title| title.to_string()),
 487                        contents,
 488                    }),
 489                    Err(_err) => {
 490                        // TODO: show error message
 491                        // this.update(cx, |_, cx| {
 492                        //     cx.emit(RulesLoadingError {
 493                        //         message: format!("{err:?}").into(),
 494                        //     });
 495                        // })
 496                        // .ok();
 497                        None
 498                    }
 499                })
 500                .collect::<Vec<_>>();
 501
 502            ProjectContext::new(worktrees, default_user_rules)
 503        })
 504    }
 505
 506    fn load_worktree_info_for_system_prompt(
 507        worktree: Entity<Worktree>,
 508        project: Entity<Project>,
 509        cx: &mut App,
 510    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 511        let tree = worktree.read(cx);
 512        let root_name = tree.root_name_str().into();
 513        let abs_path = tree.abs_path();
 514
 515        let mut context = WorktreeContext {
 516            root_name,
 517            abs_path,
 518            rules_file: None,
 519        };
 520
 521        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 522        let Some(rules_task) = rules_task else {
 523            return Task::ready((context, None));
 524        };
 525
 526        cx.spawn(async move |_| {
 527            let (rules_file, rules_file_error) = match rules_task.await {
 528                Ok(rules_file) => (Some(rules_file), None),
 529                Err(err) => (
 530                    None,
 531                    Some(RulesLoadingError {
 532                        message: format!("{err}").into(),
 533                    }),
 534                ),
 535            };
 536            context.rules_file = rules_file;
 537            (context, rules_file_error)
 538        })
 539    }
 540
 541    fn load_worktree_rules_file(
 542        worktree: Entity<Worktree>,
 543        project: Entity<Project>,
 544        cx: &mut App,
 545    ) -> Option<Task<Result<RulesFileContext>>> {
 546        let worktree = worktree.read(cx);
 547        let worktree_id = worktree.id();
 548        let selected_rules_file = RULES_FILE_NAMES
 549            .into_iter()
 550            .filter_map(|name| {
 551                worktree
 552                    .entry_for_path(RelPath::unix(name).unwrap())
 553                    .filter(|entry| entry.is_file())
 554                    .map(|entry| entry.path.clone())
 555            })
 556            .next();
 557
 558        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 559        // supported. This doesn't seem to occur often in GitHub repositories.
 560        selected_rules_file.map(|path_in_worktree| {
 561            let project_path = ProjectPath {
 562                worktree_id,
 563                path: path_in_worktree.clone(),
 564            };
 565            let buffer_task =
 566                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 567            let rope_task = cx.spawn(async move |cx| {
 568                let buffer = buffer_task.await?;
 569                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 570                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 571                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 572                })?;
 573                anyhow::Ok((project_entry_id, rope))
 574            });
 575            // Build a string from the rope on a background thread.
 576            cx.background_spawn(async move {
 577                let (project_entry_id, rope) = rope_task.await?;
 578                anyhow::Ok(RulesFileContext {
 579                    path_in_worktree,
 580                    text: rope.to_string().trim().to_string(),
 581                    project_entry_id: project_entry_id.to_usize(),
 582                })
 583            })
 584        })
 585    }
 586
 587    fn handle_thread_title_updated(
 588        &mut self,
 589        thread: Entity<Thread>,
 590        _: &TitleUpdated,
 591        cx: &mut Context<Self>,
 592    ) {
 593        let session_id = thread.read(cx).id();
 594        let Some(session) = self.sessions.get(session_id) else {
 595            return;
 596        };
 597        let thread = thread.downgrade();
 598        let acp_thread = session.acp_thread.downgrade();
 599        cx.spawn(async move |_, cx| {
 600            let title = thread.read_with(cx, |thread, _| thread.title())?;
 601            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 602            task.await
 603        })
 604        .detach_and_log_err(cx);
 605    }
 606
 607    fn handle_thread_token_usage_updated(
 608        &mut self,
 609        thread: Entity<Thread>,
 610        usage: &TokenUsageUpdated,
 611        cx: &mut Context<Self>,
 612    ) {
 613        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 614            return;
 615        };
 616        session.acp_thread.update(cx, |acp_thread, cx| {
 617            acp_thread.update_token_usage(usage.0.clone(), cx);
 618        });
 619    }
 620
 621    fn handle_project_event(
 622        &mut self,
 623        _project: Entity<Project>,
 624        event: &project::Event,
 625        _cx: &mut Context<Self>,
 626    ) {
 627        match event {
 628            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 629                self.project_context_needs_refresh.send(()).ok();
 630            }
 631            project::Event::WorktreeUpdatedEntries(_, items) => {
 632                if items.iter().any(|(path, _, _)| {
 633                    RULES_FILE_NAMES
 634                        .iter()
 635                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 636                }) {
 637                    self.project_context_needs_refresh.send(()).ok();
 638                }
 639            }
 640            _ => {}
 641        }
 642    }
 643
 644    fn handle_prompts_updated_event(
 645        &mut self,
 646        _prompt_store: Entity<PromptStore>,
 647        _event: &prompt_store::PromptsUpdatedEvent,
 648        _cx: &mut Context<Self>,
 649    ) {
 650        self.project_context_needs_refresh.send(()).ok();
 651    }
 652
 653    fn handle_models_updated_event(
 654        &mut self,
 655        _registry: Entity<LanguageModelRegistry>,
 656        _event: &language_model::Event,
 657        cx: &mut Context<Self>,
 658    ) {
 659        self.models.refresh_list(cx);
 660
 661        let registry = LanguageModelRegistry::read_global(cx);
 662        let default_model = registry.default_model().map(|m| m.model);
 663        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 664
 665        for session in self.sessions.values_mut() {
 666            session.thread.update(cx, |thread, cx| {
 667                if thread.model().is_none()
 668                    && let Some(model) = default_model.clone()
 669                {
 670                    thread.set_model(model, cx);
 671                    cx.notify();
 672                }
 673                thread.set_summarization_model(summarization_model.clone(), cx);
 674            });
 675        }
 676    }
 677
 678    fn handle_context_server_store_updated(
 679        &mut self,
 680        _store: Entity<project::context_server_store::ContextServerStore>,
 681        _event: &project::context_server_store::ServerStatusChangedEvent,
 682        cx: &mut Context<Self>,
 683    ) {
 684        self.update_available_commands(cx);
 685    }
 686
 687    fn handle_context_server_registry_event(
 688        &mut self,
 689        _registry: Entity<ContextServerRegistry>,
 690        event: &ContextServerRegistryEvent,
 691        cx: &mut Context<Self>,
 692    ) {
 693        match event {
 694            ContextServerRegistryEvent::ToolsChanged => {}
 695            ContextServerRegistryEvent::PromptsChanged => {
 696                self.update_available_commands(cx);
 697            }
 698        }
 699    }
 700
 701    fn update_available_commands(&self, cx: &mut Context<Self>) {
 702        let available_commands = self.build_available_commands(cx);
 703        for session in self.sessions.values() {
 704            session.acp_thread.update(cx, |thread, cx| {
 705                thread
 706                    .handle_session_update(
 707                        acp::SessionUpdate::AvailableCommandsUpdate(
 708                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 709                        ),
 710                        cx,
 711                    )
 712                    .log_err();
 713            });
 714        }
 715    }
 716
 717    fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
 718        let registry = self.context_server_registry.read(cx);
 719
 720        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 721        for context_server_prompt in registry.prompts() {
 722            *prompt_name_counts
 723                .entry(context_server_prompt.prompt.name.as_str())
 724                .or_insert(0) += 1;
 725        }
 726
 727        registry
 728            .prompts()
 729            .flat_map(|context_server_prompt| {
 730                let prompt = &context_server_prompt.prompt;
 731
 732                let should_prefix = prompt_name_counts
 733                    .get(prompt.name.as_str())
 734                    .copied()
 735                    .unwrap_or(0)
 736                    > 1;
 737
 738                let name = if should_prefix {
 739                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 740                } else {
 741                    prompt.name.clone()
 742                };
 743
 744                let mut command = acp::AvailableCommand::new(
 745                    name,
 746                    prompt.description.clone().unwrap_or_default(),
 747                );
 748
 749                match prompt.arguments.as_deref() {
 750                    Some([arg]) => {
 751                        let hint = format!("<{}>", arg.name);
 752
 753                        command = command.input(acp::AvailableCommandInput::Unstructured(
 754                            acp::UnstructuredCommandInput::new(hint),
 755                        ));
 756                    }
 757                    Some([]) | None => {}
 758                    Some(_) => {
 759                        // skip >1 argument commands since we don't support them yet
 760                        return None;
 761                    }
 762                }
 763
 764                Some(command)
 765            })
 766            .collect()
 767    }
 768
 769    pub fn load_thread(
 770        &mut self,
 771        id: acp::SessionId,
 772        cx: &mut Context<Self>,
 773    ) -> Task<Result<Entity<Thread>>> {
 774        let database_future = ThreadsDatabase::connect(cx);
 775        cx.spawn(async move |this, cx| {
 776            let database = database_future.await.map_err(|err| anyhow!(err))?;
 777            let db_thread = database
 778                .load_thread(id.clone())
 779                .await?
 780                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 781
 782            this.update(cx, |this, cx| {
 783                let summarization_model = LanguageModelRegistry::read_global(cx)
 784                    .thread_summary_model()
 785                    .map(|c| c.model);
 786
 787                cx.new(|cx| {
 788                    let mut thread = Thread::from_db(
 789                        id.clone(),
 790                        db_thread,
 791                        this.project.clone(),
 792                        this.project_context.clone(),
 793                        this.context_server_registry.clone(),
 794                        this.templates.clone(),
 795                        cx,
 796                    );
 797                    thread.set_summarization_model(summarization_model, cx);
 798                    thread
 799                })
 800            })
 801        })
 802    }
 803
 804    pub fn open_thread(
 805        &mut self,
 806        id: acp::SessionId,
 807        cx: &mut Context<Self>,
 808    ) -> Task<Result<Entity<AcpThread>>> {
 809        if let Some(session) = self.sessions.get(&id) {
 810            return Task::ready(Ok(session.acp_thread.clone()));
 811        }
 812
 813        let task = self.load_thread(id, cx);
 814        cx.spawn(async move |this, cx| {
 815            let thread = task.await?;
 816            let acp_thread =
 817                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
 818            let events = thread.update(cx, |thread, cx| thread.replay(cx));
 819            cx.update(|cx| {
 820                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 821            })
 822            .await?;
 823            Ok(acp_thread)
 824        })
 825    }
 826
 827    pub fn thread_summary(
 828        &mut self,
 829        id: acp::SessionId,
 830        cx: &mut Context<Self>,
 831    ) -> Task<Result<SharedString>> {
 832        let thread = self.open_thread(id.clone(), cx);
 833        cx.spawn(async move |this, cx| {
 834            let acp_thread = thread.await?;
 835            let result = this
 836                .update(cx, |this, cx| {
 837                    this.sessions
 838                        .get(&id)
 839                        .unwrap()
 840                        .thread
 841                        .update(cx, |thread, cx| thread.summary(cx))
 842                })?
 843                .await
 844                .context("Failed to generate summary")?;
 845            drop(acp_thread);
 846            Ok(result)
 847        })
 848    }
 849
 850    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 851        if thread.read(cx).is_empty() {
 852            return;
 853        }
 854
 855        let id = thread.read(cx).id().clone();
 856        let Some(session) = self.sessions.get_mut(&id) else {
 857            return;
 858        };
 859
 860        let folder_paths = PathList::new(
 861            &self
 862                .project
 863                .read(cx)
 864                .visible_worktrees(cx)
 865                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
 866                .collect::<Vec<_>>(),
 867        );
 868
 869        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
 870        let database_future = ThreadsDatabase::connect(cx);
 871        let db_thread = thread.update(cx, |thread, cx| {
 872            thread.set_draft_prompt(draft_prompt);
 873            thread.to_db(cx)
 874        });
 875        let thread_store = self.thread_store.clone();
 876        session.pending_save = cx.spawn(async move |_, cx| {
 877            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 878                return;
 879            };
 880            let db_thread = db_thread.await;
 881            database
 882                .save_thread(id, db_thread, folder_paths)
 883                .await
 884                .log_err();
 885            thread_store.update(cx, |store, cx| store.reload(cx));
 886        });
 887    }
 888
    /// Expands the MCP prompt `prompt_name` from context server `server_id` and runs it as a
    /// turn on session `session_id`.
    ///
    /// Steps:
    /// - The first block of `original_content` (the slash-command text, as passed in by
    ///   `prompt`) is skipped. The remaining blocks are pushed to the native thread as a
    ///   user message with id `message_id`.
    /// - Each message the server returns is appended to both the ACP thread (using the
    ///   `*_with_indent` variants) and the native thread, according to its role.
    /// - If the last expanded message is a user message (or there were none), the thread
    ///   sends it. Otherwise the thread resumes.
    /// - The resulting event stream is forwarded to the ACP thread.
    ///
    /// Fails if fetching the prompt fails, if the session is not registered, or if starting
    /// the turn fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: agent_client_protocol::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let server_store = self.context_server_registry.read(cx).server_store().clone();
        let path_style = self.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            // Look the session up only after the prompt has been fetched.
            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // An expansion with no messages counts as ending on a user message, so the
            // turn is sent rather than resumed.
            let mut last_is_user = true;

            // Record the user's original message (minus the command block) in the native
            // thread under the caller-provided id.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        // Each expanded user message gets its own id, shared by the ACP
                        // thread and the native thread.
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
 984}
 985
 986/// Wrapper struct that implements the AgentConnection trait
 987#[derive(Clone)]
 988pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 989
 990impl NativeAgentConnection {
 991    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 992        self.0
 993            .read(cx)
 994            .sessions
 995            .get(session_id)
 996            .map(|session| session.thread.clone())
 997    }
 998
 999    pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
1000        self.0.update(cx, |this, cx| this.load_thread(id, cx))
1001    }
1002
1003    fn run_turn(
1004        &self,
1005        session_id: acp::SessionId,
1006        cx: &mut App,
1007        f: impl 'static
1008        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1009    ) -> Task<Result<acp::PromptResponse>> {
1010        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1011            agent
1012                .sessions
1013                .get_mut(&session_id)
1014                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1015        }) else {
1016            return Task::ready(Err(anyhow!("Session not found")));
1017        };
1018        log::debug!("Found session for: {}", session_id);
1019
1020        let response_stream = match f(thread, cx) {
1021            Ok(stream) => stream,
1022            Err(err) => return Task::ready(Err(err)),
1023        };
1024        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1025    }
1026
1027    fn handle_thread_events(
1028        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1029        acp_thread: WeakEntity<AcpThread>,
1030        cx: &App,
1031    ) -> Task<Result<acp::PromptResponse>> {
1032        cx.spawn(async move |cx| {
1033            // Handle response stream and forward to session.acp_thread
1034            while let Some(result) = events.next().await {
1035                match result {
1036                    Ok(event) => {
1037                        log::trace!("Received completion event: {:?}", event);
1038
1039                        match event {
1040                            ThreadEvent::UserMessage(message) => {
1041                                acp_thread.update(cx, |thread, cx| {
1042                                    for content in message.content {
1043                                        thread.push_user_content_block(
1044                                            Some(message.id.clone()),
1045                                            content.into(),
1046                                            cx,
1047                                        );
1048                                    }
1049                                })?;
1050                            }
1051                            ThreadEvent::AgentText(text) => {
1052                                acp_thread.update(cx, |thread, cx| {
1053                                    thread.push_assistant_content_block(text.into(), false, cx)
1054                                })?;
1055                            }
1056                            ThreadEvent::AgentThinking(text) => {
1057                                acp_thread.update(cx, |thread, cx| {
1058                                    thread.push_assistant_content_block(text.into(), true, cx)
1059                                })?;
1060                            }
1061                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1062                                tool_call,
1063                                options,
1064                                response,
1065                                context: _,
1066                            }) => {
1067                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1068                                    thread.request_tool_call_authorization(tool_call, options, cx)
1069                                })??;
1070                                cx.background_spawn(async move {
1071                                    if let acp::RequestPermissionOutcome::Selected(
1072                                        acp::SelectedPermissionOutcome { option_id, .. },
1073                                    ) = outcome_task.await
1074                                    {
1075                                        response
1076                                            .send(option_id)
1077                                            .map(|_| anyhow!("authorization receiver was dropped"))
1078                                            .log_err();
1079                                    }
1080                                })
1081                                .detach();
1082                            }
1083                            ThreadEvent::ToolCall(tool_call) => {
1084                                acp_thread.update(cx, |thread, cx| {
1085                                    thread.upsert_tool_call(tool_call, cx)
1086                                })??;
1087                            }
1088                            ThreadEvent::ToolCallUpdate(update) => {
1089                                acp_thread.update(cx, |thread, cx| {
1090                                    thread.update_tool_call(update, cx)
1091                                })??;
1092                            }
1093                            ThreadEvent::SubagentSpawned(session_id) => {
1094                                acp_thread.update(cx, |thread, cx| {
1095                                    thread.subagent_spawned(session_id, cx);
1096                                })?;
1097                            }
1098                            ThreadEvent::Retry(status) => {
1099                                acp_thread.update(cx, |thread, cx| {
1100                                    thread.update_retry_status(status, cx)
1101                                })?;
1102                            }
1103                            ThreadEvent::Stop(stop_reason) => {
1104                                log::debug!("Assistant message complete: {:?}", stop_reason);
1105                                return Ok(acp::PromptResponse::new(stop_reason));
1106                            }
1107                        }
1108                    }
1109                    Err(e) => {
1110                        log::error!("Error in model response stream: {:?}", e);
1111                        return Err(e);
1112                    }
1113                }
1114            }
1115
1116            log::debug!("Response stream completed");
1117            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1118        })
1119    }
1120}
1121
1122struct Command<'a> {
1123    prompt_name: &'a str,
1124    arg_value: &'a str,
1125    explicit_server_id: Option<&'a str>,
1126}
1127
1128impl<'a> Command<'a> {
1129    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1130        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1131            return None;
1132        };
1133        let text = text_content.text.trim();
1134        let command = text.strip_prefix('/')?;
1135        let (command, arg_value) = command
1136            .split_once(char::is_whitespace)
1137            .unwrap_or((command, ""));
1138
1139        if let Some((server_id, prompt_name)) = command.split_once('.') {
1140            Some(Self {
1141                prompt_name,
1142                arg_value,
1143                explicit_server_id: Some(server_id),
1144            })
1145        } else {
1146            Some(Self {
1147                prompt_name: command,
1148                arg_value,
1149                explicit_server_id: None,
1150            })
1151        }
1152    }
1153}
1154
1155struct NativeAgentModelSelector {
1156    session_id: acp::SessionId,
1157    connection: NativeAgentConnection,
1158}
1159
1160impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1161    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1162        log::debug!("NativeAgentConnection::list_models called");
1163        let list = self.connection.0.read(cx).models.model_list.clone();
1164        Task::ready(if list.is_empty() {
1165            Err(anyhow::anyhow!("No models available"))
1166        } else {
1167            Ok(list)
1168        })
1169    }
1170
1171    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1172        log::debug!(
1173            "Setting model for session {}: {}",
1174            self.session_id,
1175            model_id
1176        );
1177        let Some(thread) = self
1178            .connection
1179            .0
1180            .read(cx)
1181            .sessions
1182            .get(&self.session_id)
1183            .map(|session| session.thread.clone())
1184        else {
1185            return Task::ready(Err(anyhow!("Session not found")));
1186        };
1187
1188        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1189            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1190        };
1191
1192        // We want to reset the effort level when switching models, as the currently-selected effort level may
1193        // not be compatible.
1194        let effort = model
1195            .default_effort_level()
1196            .map(|effort_level| effort_level.value.to_string());
1197
1198        thread.update(cx, |thread, cx| {
1199            thread.set_model(model.clone(), cx);
1200            thread.set_thinking_effort(effort.clone(), cx);
1201            thread.set_thinking_enabled(model.supports_thinking(), cx);
1202        });
1203
1204        update_settings_file(
1205            self.connection.0.read(cx).fs.clone(),
1206            cx,
1207            move |settings, cx| {
1208                let provider = model.provider_id().0.to_string();
1209                let model = model.id().0.to_string();
1210                let enable_thinking = thread.read(cx).thinking_enabled();
1211                settings
1212                    .agent
1213                    .get_or_insert_default()
1214                    .set_model(LanguageModelSelection {
1215                        provider: provider.into(),
1216                        model,
1217                        enable_thinking,
1218                        effort,
1219                    });
1220            },
1221        );
1222
1223        Task::ready(Ok(()))
1224    }
1225
1226    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1227        let Some(thread) = self
1228            .connection
1229            .0
1230            .read(cx)
1231            .sessions
1232            .get(&self.session_id)
1233            .map(|session| session.thread.clone())
1234        else {
1235            return Task::ready(Err(anyhow!("Session not found")));
1236        };
1237        let Some(model) = thread.read(cx).model() else {
1238            return Task::ready(Err(anyhow!("Model not found")));
1239        };
1240        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1241        else {
1242            return Task::ready(Err(anyhow!("Provider not found")));
1243        };
1244        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1245            model, &provider,
1246        )))
1247    }
1248
1249    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1250        Some(self.connection.0.read(cx).models.watch())
1251    }
1252
1253    fn should_render_footer(&self) -> bool {
1254        true
1255    }
1256}
1257
1258impl acp_thread::AgentConnection for NativeAgentConnection {
1259    fn telemetry_id(&self) -> SharedString {
1260        "zed".into()
1261    }
1262
1263    fn new_session(
1264        self: Rc<Self>,
1265        project: Entity<Project>,
1266        cwd: &Path,
1267        cx: &mut App,
1268    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1269        log::debug!("Creating new thread for project at: {cwd:?}");
1270        Task::ready(Ok(self
1271            .0
1272            .update(cx, |agent, cx| agent.new_session(project, cx))))
1273    }
1274
1275    fn supports_load_session(&self) -> bool {
1276        true
1277    }
1278
1279    fn load_session(
1280        self: Rc<Self>,
1281        session_id: acp::SessionId,
1282        _project: Entity<Project>,
1283        _cwd: &Path,
1284        _title: Option<SharedString>,
1285        cx: &mut App,
1286    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1287        self.0
1288            .update(cx, |agent, cx| agent.open_thread(session_id, cx))
1289    }
1290
1291    fn supports_close_session(&self) -> bool {
1292        true
1293    }
1294
1295    fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1296        self.0.update(cx, |agent, _cx| {
1297            agent.sessions.remove(session_id);
1298        });
1299        Task::ready(Ok(()))
1300    }
1301
1302    fn auth_methods(&self) -> &[acp::AuthMethod] {
1303        &[] // No auth for in-process
1304    }
1305
1306    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1307        Task::ready(Ok(()))
1308    }
1309
1310    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1311        Some(Rc::new(NativeAgentModelSelector {
1312            session_id: session_id.clone(),
1313            connection: self.clone(),
1314        }) as Rc<dyn AgentModelSelector>)
1315    }
1316
1317    fn prompt(
1318        &self,
1319        id: Option<acp_thread::UserMessageId>,
1320        params: acp::PromptRequest,
1321        cx: &mut App,
1322    ) -> Task<Result<acp::PromptResponse>> {
1323        let id = id.expect("UserMessageId is required");
1324        let session_id = params.session_id.clone();
1325        log::info!("Received prompt request for session: {}", session_id);
1326        log::debug!("Prompt blocks count: {}", params.prompt.len());
1327
1328        if let Some(parsed_command) = Command::parse(&params.prompt) {
1329            let registry = self.0.read(cx).context_server_registry.read(cx);
1330
1331            let explicit_server_id = parsed_command
1332                .explicit_server_id
1333                .map(|server_id| ContextServerId(server_id.into()));
1334
1335            if let Some(prompt) =
1336                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1337            {
1338                let arguments = if !parsed_command.arg_value.is_empty()
1339                    && let Some(arg_name) = prompt
1340                        .prompt
1341                        .arguments
1342                        .as_ref()
1343                        .and_then(|args| args.first())
1344                        .map(|arg| arg.name.clone())
1345                {
1346                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1347                } else {
1348                    Default::default()
1349                };
1350
1351                let prompt_name = prompt.prompt.name.clone();
1352                let server_id = prompt.server_id.clone();
1353
1354                return self.0.update(cx, |agent, cx| {
1355                    agent.send_mcp_prompt(
1356                        id,
1357                        session_id.clone(),
1358                        prompt_name,
1359                        server_id,
1360                        arguments,
1361                        params.prompt,
1362                        cx,
1363                    )
1364                });
1365            };
1366        };
1367
1368        let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1369
1370        self.run_turn(session_id, cx, move |thread, cx| {
1371            let content: Vec<UserMessageContent> = params
1372                .prompt
1373                .into_iter()
1374                .map(|block| UserMessageContent::from_content_block(block, path_style))
1375                .collect::<Vec<_>>();
1376            log::debug!("Converted prompt to message: {} chars", content.len());
1377            log::debug!("Message id: {:?}", id);
1378            log::debug!("Message content: {:?}", content);
1379
1380            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1381        })
1382    }
1383
1384    fn retry(
1385        &self,
1386        session_id: &acp::SessionId,
1387        _cx: &App,
1388    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1389        Some(Rc::new(NativeAgentSessionRetry {
1390            connection: self.clone(),
1391            session_id: session_id.clone(),
1392        }) as _)
1393    }
1394
1395    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1396        log::info!("Cancelling on session: {}", session_id);
1397        self.0.update(cx, |agent, cx| {
1398            if let Some(session) = agent.sessions.get(session_id) {
1399                session
1400                    .thread
1401                    .update(cx, |thread, cx| thread.cancel(cx))
1402                    .detach();
1403            }
1404        });
1405    }
1406
1407    fn truncate(
1408        &self,
1409        session_id: &agent_client_protocol::SessionId,
1410        cx: &App,
1411    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1412        self.0.read_with(cx, |agent, _cx| {
1413            agent.sessions.get(session_id).map(|session| {
1414                Rc::new(NativeAgentSessionTruncate {
1415                    thread: session.thread.clone(),
1416                    acp_thread: session.acp_thread.downgrade(),
1417                }) as _
1418            })
1419        })
1420    }
1421
1422    fn set_title(
1423        &self,
1424        session_id: &acp::SessionId,
1425        cx: &App,
1426    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1427        self.0.read_with(cx, |agent, _cx| {
1428            agent
1429                .sessions
1430                .get(session_id)
1431                .filter(|s| !s.thread.read(cx).is_subagent())
1432                .map(|session| {
1433                    Rc::new(NativeAgentSessionSetTitle {
1434                        thread: session.thread.clone(),
1435                    }) as _
1436                })
1437        })
1438    }
1439
1440    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1441        let thread_store = self.0.read(cx).thread_store.clone();
1442        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1443    }
1444
1445    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1446        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1447    }
1448
1449    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1450        self
1451    }
1452}
1453
1454impl acp_thread::AgentTelemetry for NativeAgentConnection {
1455    fn thread_data(
1456        &self,
1457        session_id: &acp::SessionId,
1458        cx: &mut App,
1459    ) -> Task<Result<serde_json::Value>> {
1460        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1461            return Task::ready(Err(anyhow!("Session not found")));
1462        };
1463
1464        let task = session.thread.read(cx).to_db(cx);
1465        cx.background_spawn(async move {
1466            serde_json::to_value(task.await).context("Failed to serialize thread")
1467        })
1468    }
1469}
1470
1471pub struct NativeAgentSessionList {
1472    thread_store: Entity<ThreadStore>,
1473    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1474    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1475    _subscription: Subscription,
1476}
1477
1478impl NativeAgentSessionList {
1479    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1480        let (tx, rx) = smol::channel::unbounded();
1481        let this_tx = tx.clone();
1482        let subscription = cx.observe(&thread_store, move |_, _| {
1483            this_tx
1484                .try_send(acp_thread::SessionListUpdate::Refresh)
1485                .ok();
1486        });
1487        Self {
1488            thread_store,
1489            updates_tx: tx,
1490            updates_rx: rx,
1491            _subscription: subscription,
1492        }
1493    }
1494
1495    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1496        &self.thread_store
1497    }
1498}
1499
1500impl AgentSessionList for NativeAgentSessionList {
1501    fn list_sessions(
1502        &self,
1503        _request: AgentSessionListRequest,
1504        cx: &mut App,
1505    ) -> Task<Result<AgentSessionListResponse>> {
1506        let sessions = self
1507            .thread_store
1508            .read(cx)
1509            .entries()
1510            .map(|entry| AgentSessionInfo::from(&entry))
1511            .collect();
1512        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1513    }
1514
1515    fn supports_delete(&self) -> bool {
1516        true
1517    }
1518
1519    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1520        self.thread_store
1521            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1522    }
1523
1524    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1525        self.thread_store
1526            .update(cx, |store, cx| store.delete_threads(cx))
1527    }
1528
1529    fn watch(
1530        &self,
1531        _cx: &mut App,
1532    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1533        Some(self.updates_rx.clone())
1534    }
1535
1536    fn notify_refresh(&self) {
1537        self.updates_tx
1538            .try_send(acp_thread::SessionListUpdate::Refresh)
1539            .ok();
1540    }
1541
1542    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1543        self
1544    }
1545}
1546
1547struct NativeAgentSessionTruncate {
1548    thread: Entity<Thread>,
1549    acp_thread: WeakEntity<AcpThread>,
1550}
1551
1552impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1553    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1554        match self.thread.update(cx, |thread, cx| {
1555            thread.truncate(message_id.clone(), cx)?;
1556            Ok(thread.latest_token_usage())
1557        }) {
1558            Ok(usage) => {
1559                self.acp_thread
1560                    .update(cx, |thread, cx| {
1561                        thread.update_token_usage(usage, cx);
1562                    })
1563                    .ok();
1564                Task::ready(Ok(()))
1565            }
1566            Err(error) => Task::ready(Err(error)),
1567        }
1568    }
1569}
1570
1571struct NativeAgentSessionRetry {
1572    connection: NativeAgentConnection,
1573    session_id: acp::SessionId,
1574}
1575
1576impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1577    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1578        self.connection
1579            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1580                thread.update(cx, |thread, cx| thread.resume(cx))
1581            })
1582    }
1583}
1584
1585struct NativeAgentSessionSetTitle {
1586    thread: Entity<Thread>,
1587}
1588
1589impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1590    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1591        self.thread
1592            .update(cx, |thread, cx| thread.set_title(title, cx));
1593        Task::ready(Ok(()))
1594    }
1595}
1596
/// Weak handles to the agent, the thread, and the thread's ACP counterpart, used when
/// creating or resuming subagent threads.
///
/// All handles are weak, so the environment does not keep these entities alive.
pub struct NativeThreadEnvironment {
    /// Agent that new subagent sessions are registered with.
    agent: WeakEntity<NativeAgent>,
    /// Used as the parent thread when creating subagents.
    thread: WeakEntity<Thread>,
    acp_thread: WeakEntity<AcpThread>,
}
1602
1603impl NativeThreadEnvironment {
1604    pub(crate) fn create_subagent_thread(
1605        &self,
1606        label: String,
1607        cx: &mut App,
1608    ) -> Result<Rc<dyn SubagentHandle>> {
1609        let Some(parent_thread_entity) = self.thread.upgrade() else {
1610            anyhow::bail!("Parent thread no longer exists".to_string());
1611        };
1612        let parent_thread = parent_thread_entity.read(cx);
1613        let current_depth = parent_thread.depth();
1614
1615        if current_depth >= MAX_SUBAGENT_DEPTH {
1616            return Err(anyhow!(
1617                "Maximum subagent depth ({}) reached",
1618                MAX_SUBAGENT_DEPTH
1619            ));
1620        }
1621
1622        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1623            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1624            thread.set_title(label.into(), cx);
1625            thread
1626        });
1627
1628        let session_id = subagent_thread.read(cx).id().clone();
1629
1630        let acp_thread = self.agent.update(cx, |agent, cx| {
1631            agent.register_session(subagent_thread.clone(), cx)
1632        })?;
1633
1634        let depth = current_depth + 1;
1635
1636        telemetry::event!(
1637            "Subagent Started",
1638            session = parent_thread_entity.read(cx).id().to_string(),
1639            subagent_session = session_id.to_string(),
1640            depth,
1641            is_resumed = false,
1642        );
1643
1644        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1645    }
1646
1647    pub(crate) fn resume_subagent_thread(
1648        &self,
1649        session_id: acp::SessionId,
1650        cx: &mut App,
1651    ) -> Result<Rc<dyn SubagentHandle>> {
1652        let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1653            let session = agent
1654                .sessions
1655                .get(&session_id)
1656                .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1657            anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1658        })??;
1659
1660        let depth = subagent_thread.read(cx).depth();
1661
1662        if let Some(parent_thread_entity) = self.thread.upgrade() {
1663            telemetry::event!(
1664                "Subagent Started",
1665                session = parent_thread_entity.read(cx).id().to_string(),
1666                subagent_session = session_id.to_string(),
1667                depth,
1668                is_resumed = true,
1669            );
1670        }
1671
1672        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1673    }
1674
1675    fn prompt_subagent(
1676        &self,
1677        session_id: acp::SessionId,
1678        subagent_thread: Entity<Thread>,
1679        acp_thread: Entity<acp_thread::AcpThread>,
1680    ) -> Result<Rc<dyn SubagentHandle>> {
1681        let Some(parent_thread_entity) = self.thread.upgrade() else {
1682            anyhow::bail!("Parent thread no longer exists".to_string());
1683        };
1684        Ok(Rc::new(NativeSubagentHandle::new(
1685            session_id,
1686            subagent_thread,
1687            acp_thread,
1688            parent_thread_entity,
1689        )) as _)
1690    }
1691}
1692
1693impl ThreadEnvironment for NativeThreadEnvironment {
1694    fn create_terminal(
1695        &self,
1696        command: String,
1697        cwd: Option<PathBuf>,
1698        output_byte_limit: Option<u64>,
1699        cx: &mut AsyncApp,
1700    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1701        let task = self.acp_thread.update(cx, |thread, cx| {
1702            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1703        });
1704
1705        let acp_thread = self.acp_thread.clone();
1706        cx.spawn(async move |cx| {
1707            let terminal = task?.await?;
1708
1709            let (drop_tx, drop_rx) = oneshot::channel();
1710            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1711
1712            cx.spawn(async move |cx| {
1713                drop_rx.await.ok();
1714                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1715            })
1716            .detach();
1717
1718            let handle = AcpTerminalHandle {
1719                terminal,
1720                _drop_tx: Some(drop_tx),
1721            };
1722
1723            Ok(Rc::new(handle) as _)
1724        })
1725    }
1726
1727    fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1728        self.create_subagent_thread(label, cx)
1729    }
1730
1731    fn resume_subagent(
1732        &self,
1733        session_id: acp::SessionId,
1734        cx: &mut App,
1735    ) -> Result<Rc<dyn SubagentHandle>> {
1736        self.resume_subagent_thread(session_id, cx)
1737    }
1738}
1739
/// Outcome of waiting on a single subagent prompt in `NativeSubagentHandle::send`.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The turn finished with a stop reason that is not treated as a failure.
    Completed,
    /// The turn stopped with `acp::StopReason::Cancelled`.
    Cancelled,
    /// The subagent's token usage reached the context-window warning range while
    /// the prompt was running.
    ContextWindowWarning,
    /// The prompt failed. The string is the message surfaced to the caller.
    Error(String),
}
1747
1748pub struct NativeSubagentHandle {
1749    session_id: acp::SessionId,
1750    parent_thread: WeakEntity<Thread>,
1751    subagent_thread: Entity<Thread>,
1752    acp_thread: Entity<acp_thread::AcpThread>,
1753}
1754
1755impl NativeSubagentHandle {
1756    fn new(
1757        session_id: acp::SessionId,
1758        subagent_thread: Entity<Thread>,
1759        acp_thread: Entity<acp_thread::AcpThread>,
1760        parent_thread_entity: Entity<Thread>,
1761    ) -> Self {
1762        NativeSubagentHandle {
1763            session_id,
1764            subagent_thread,
1765            parent_thread: parent_thread_entity.downgrade(),
1766            acp_thread,
1767        }
1768    }
1769}
1770
1771impl SubagentHandle for NativeSubagentHandle {
1772    fn id(&self) -> acp::SessionId {
1773        self.session_id.clone()
1774    }
1775
1776    fn num_entries(&self, cx: &App) -> usize {
1777        self.acp_thread.read(cx).entries().len()
1778    }
1779
1780    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
1781        let thread = self.subagent_thread.clone();
1782        let acp_thread = self.acp_thread.clone();
1783        let subagent_session_id = self.session_id.clone();
1784        let parent_thread = self.parent_thread.clone();
1785
1786        cx.spawn(async move |cx| {
1787            let (task, _subscription) = cx.update(|cx| {
1788                let ratio_before_prompt = thread
1789                    .read(cx)
1790                    .latest_token_usage()
1791                    .map(|usage| usage.ratio());
1792
1793                parent_thread
1794                    .update(cx, |parent_thread, _cx| {
1795                        parent_thread.register_running_subagent(thread.downgrade())
1796                    })
1797                    .ok();
1798
1799                let task = acp_thread.update(cx, |acp_thread, cx| {
1800                    acp_thread.send(vec![message.into()], cx)
1801                });
1802
1803                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
1804                let mut token_limit_tx = Some(token_limit_tx);
1805
1806                let subscription = cx.subscribe(
1807                    &thread,
1808                    move |_thread, event: &TokenUsageUpdated, _cx| {
1809                        if let Some(usage) = &event.0 {
1810                            let old_ratio = ratio_before_prompt
1811                                .clone()
1812                                .unwrap_or(TokenUsageRatio::Normal);
1813                            let new_ratio = usage.ratio();
1814                            if old_ratio == TokenUsageRatio::Normal
1815                                && new_ratio == TokenUsageRatio::Warning
1816                            {
1817                                if let Some(tx) = token_limit_tx.take() {
1818                                    tx.send(()).ok();
1819                                }
1820                            }
1821                        }
1822                    },
1823                );
1824
1825                let wait_for_prompt = cx
1826                    .background_spawn(async move {
1827                        futures::select! {
1828                            response = task.fuse() => match response {
1829                                Ok(Some(response)) => {
1830                                    match response.stop_reason {
1831                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
1832                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
1833                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
1834                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
1835                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
1836                                    }
1837                                }
1838                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
1839                                Err(error) => SubagentPromptResult::Error(error.to_string()),
1840                            },
1841                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
1842                        }
1843                    });
1844
1845                (wait_for_prompt, subscription)
1846            });
1847
1848            let result = match task.await {
1849                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
1850                    thread
1851                        .last_message()
1852                        .and_then(|message| {
1853                            let content = message.as_agent_message()?
1854                                .content
1855                                .iter()
1856                                .filter_map(|c| match c {
1857                                    AgentMessageContent::Text(text) => Some(text.as_str()),
1858                                    _ => None,
1859                                })
1860                                .join("\n\n");
1861                            if content.is_empty() {
1862                                None
1863                            } else {
1864                                Some( content)
1865                            }
1866                        })
1867                        .context("No response from subagent")
1868                }),
1869                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
1870                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
1871                SubagentPromptResult::ContextWindowWarning => {
1872                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1873                    Err(anyhow!(
1874                        "The agent is nearing the end of its context window and has been \
1875                         stopped. You can prompt the thread again to have the agent wrap up \
1876                         or hand off its work."
1877                    ))
1878                }
1879            };
1880
1881            parent_thread
1882                .update(cx, |parent_thread, cx| {
1883                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1884                })
1885                .ok();
1886
1887            result
1888        })
1889    }
1890}
1891
1892pub struct AcpTerminalHandle {
1893    terminal: Entity<acp_thread::Terminal>,
1894    _drop_tx: Option<oneshot::Sender<()>>,
1895}
1896
1897impl TerminalHandle for AcpTerminalHandle {
1898    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1899        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1900    }
1901
1902    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1903        Ok(self
1904            .terminal
1905            .read_with(cx, |term, _cx| term.wait_for_exit()))
1906    }
1907
1908    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1909        Ok(self
1910            .terminal
1911            .read_with(cx, |term, cx| term.current_output(cx)))
1912    }
1913
1914    fn kill(&self, cx: &AsyncApp) -> Result<()> {
1915        cx.update(|cx| {
1916            self.terminal.update(cx, |terminal, cx| {
1917                terminal.kill(cx);
1918            });
1919        });
1920        Ok(())
1921    }
1922
1923    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1924        Ok(self
1925            .terminal
1926            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1927    }
1928}
1929
1930#[cfg(test)]
1931mod internal_tests {
1932    use super::*;
1933    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1934    use fs::FakeFs;
1935    use gpui::TestAppContext;
1936    use indoc::formatdoc;
1937    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1938    use language_model::{
1939        LanguageModelCompletionEvent, LanguageModelProviderId, LanguageModelProviderName,
1940    };
1941    use serde_json::json;
1942    use settings::SettingsStore;
1943    use util::{path, rel_path::rel_path};
1944
    // Checks that `NativeAgent` keeps `project_context` in sync with the project:
    // adding a worktree adds a `WorktreeContext`, and creating `.rules` in it
    // fills in that worktree's `rules_file`.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project was created without worktrees, so the context starts empty.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        // Drain pending tasks before inspecting the agent's context.
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }
2010
    // Checks that a new session's model selector lists exactly the fake model,
    // grouped under the "Fake" provider group.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                thread_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // The native agent is expected to return a grouped list.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                    is_latest: false,
                    cost: None,
                }]
            )])
        );
    }
2071
    // Checks that selecting a model through a session's model selector changes the
    // thread's model and writes `agent.default_model` to the settings file. That
    // includes `enable_thinking: true` when a thinking model is selected.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Seed the settings with a different default model ("foo"/"bar") so the
        // overwrite is observable.
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Drain pending tasks before reading the settings file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );

        // Register a thinking model and select it.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so the newly registered provider is visible.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Verify enable_thinking was written to settings as true.
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
        assert_eq!(
            settings_json["agent"]["default_model"]["enable_thinking"],
            json!(true),
            "selecting a thinking model should persist enable_thinking: true to settings"
        );
    }
2187
    // Checks that `select_model` turns thread thinking on for a model that
    // supports it, and off again after switching back to one that does not.
    #[gpui::test]
    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Start from an empty settings object.
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Register a second provider with a thinking model.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so it picks up the new provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Thread starts with thinking_enabled = false (the default).
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(!thread.thinking_enabled(), "thinking defaults to false");
            });
        });

        // Select the thinking model via select_model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // select_model should have enabled thinking based on the model's supports_thinking().
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    thread.thinking_enabled(),
                    "select_model should enable thinking when model supports it"
                );
            });
        });

        // Switch back to the non-thinking model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
            .await
            .unwrap();

        // select_model should have disabled thinking.
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    !thread.thinking_enabled(),
                    "select_model should disable thinking when model does not support it"
                );
            });
        });
    }
2283
    // Checks that `thinking_enabled` survives a round trip: enable it by selecting
    // a thinking model, send a message, close the session, then reopen the thread
    // with `open_thread` and confirm the flag is still set.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Complete the pending fake completion so `send` can resolve.
        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Drop the test's remaining handles before checking that the session is gone.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        drop(reloaded_acp_thread);
    }
2389
    /// Checks that a thread's selected model survives being persisted, closed, and
    /// reloaded from disk.
    ///
    /// The fake model's id ("custom-model-id") deliberately differs from its display
    /// name, and the reloaded thread must report that same id rather than falling
    /// back to the default model.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // NOTE(review): presumably required so the agent's model list picks up the
        // provider registered just above — confirm.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // The selector id below joins the provider id ("fake-corp") and the
        // model id ("custom-model-id").
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        // Sanity check: the live thread uses the selected model before anything is saved.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // The send is spawned so it can make progress while the test drives the
        // fake model's completion stream below.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the old entities, then confirm the
        // agent no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        drop(reloaded_acp_thread);
    }
2500
    /// Round-trips a native thread through the thread store.
    ///
    /// The test first checks that an empty thread is not saved, even after its
    /// models are changed. It then checks that each of the following is restored
    /// after the session is closed and reopened from disk:
    /// - the conversation, including a file mention
    /// - the summary-generated title
    /// - a rich draft prompt
    /// - the reported token usage
    /// - the UI scroll position
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        // `summary_model` later supplies the title asserted through `thread_entries`.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a prompt that mentions `b.md` through a resource link, so the
        // mention's rendering can be compared before and after reload.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Stream the assistant's reply, reporting token usage that is checked
        // again after reload.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
            language_model::TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            },
        ));
        model.end_last_completion_stream();
        cx.run_until_parked();
        // Answer the summarization request; this text is expected as the stored title.
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks before saving.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });
        thread.update(cx, |thread, _cx| {
            thread.set_ui_scroll_position(Some(gpui::ListOffset {
                item_ix: 5,
                offset_in_item: gpui::px(12.5),
            }));
        });
        // NOTE(review): the setters above are called without a notify; presumably
        // this explicit notify is what gets the draft and scroll position persisted — confirm.
        thread.update(cx, |_thread, cx| cx.notify());
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });

        // Ensure token usage survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let usage = thread
                .token_usage()
                .expect("token usage should be restored after reload");
            assert_eq!(usage.input_tokens, 150);
            assert_eq!(usage.output_tokens, 75);
        });

        // Ensure scroll position survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            let scroll = thread
                .ui_scroll_position()
                .expect("scroll position should be restored after reload");
            assert_eq!(scroll.item_ix, 5);
            assert_eq!(scroll.offset_in_item, gpui::px(12.5));
        });
    }
2687
2688    fn thread_entries(
2689        thread_store: &Entity<ThreadStore>,
2690        cx: &mut TestAppContext,
2691    ) -> Vec<(acp::SessionId, String)> {
2692        thread_store.read_with(cx, |store, _| {
2693            store
2694                .entries()
2695                .map(|entry| (entry.id.clone(), entry.title.to_string()))
2696                .collect::<Vec<_>>()
2697        })
2698    }
2699
2700    fn init_test(cx: &mut TestAppContext) {
2701        env_logger::try_init().ok();
2702        cx.update(|cx| {
2703            let settings_store = SettingsStore::test(cx);
2704            cx.set_global(settings_store);
2705
2706            LanguageModelRegistry::test(cx);
2707        });
2708    }
2709}
2710
2711fn mcp_message_content_to_acp_content_block(
2712    content: context_server::types::MessageContent,
2713) -> acp::ContentBlock {
2714    match content {
2715        context_server::types::MessageContent::Text {
2716            text,
2717            annotations: _,
2718        } => text.into(),
2719        context_server::types::MessageContent::Image {
2720            data,
2721            mime_type,
2722            annotations: _,
2723        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2724        context_server::types::MessageContent::Audio {
2725            data,
2726            mime_type,
2727            annotations: _,
2728        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2729        context_server::types::MessageContent::Resource {
2730            resource,
2731            annotations: _,
2732        } => {
2733            let mut link =
2734                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2735            if let Some(mime_type) = resource.mime_type {
2736                link = link.mime_type(mime_type);
2737            }
2738            acp::ContentBlock::ResourceLink(link)
2739        }
2740    }
2741}