agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17pub use native_agent_server::NativeAgentServer;
  18pub use pattern_extraction::*;
  19pub use shell_command_parser::extract_commands;
  20pub use templates::*;
  21pub use thread::*;
  22pub use thread_store::*;
  23pub use tool_permissions::*;
  24pub use tools::*;
  25
  26use acp_thread::{
  27    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  28    AgentSessionListResponse, UserMessageId,
  29};
  30use agent_client_protocol as acp;
  31use anyhow::{Context as _, Result, anyhow};
  32use chrono::{DateTime, Utc};
  33use collections::{HashMap, HashSet, IndexMap};
  34use fs::Fs;
  35use futures::channel::{mpsc, oneshot};
  36use futures::future::Shared;
  37use futures::{FutureExt as _, StreamExt as _, future};
  38use gpui::{
  39    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  40};
  41use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  42use project::{Project, ProjectItem, ProjectPath, Worktree};
  43use prompt_store::{
  44    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  45    WorktreeContext,
  46};
  47use serde::{Deserialize, Serialize};
  48use settings::{LanguageModelSelection, update_settings_file};
  49use std::any::Any;
  50use std::path::{Path, PathBuf};
  51use std::rc::Rc;
  52use std::sync::Arc;
  53use std::time::Duration;
  54use util::ResultExt;
  55use util::rel_path::RelPath;
  56
/// A serializable record that pairs telemetry snapshots of a project's
/// worktrees with a timestamp.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots of worktrees.
    /// NOTE(review): presumably one entry per worktree — confirm at the construction site.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// UTC timestamp stored with the snapshot.
    /// NOTE(review): presumably the capture time — confirm at the construction site.
    pub timestamp: DateTime<Utc>,
}
  62
/// Describes a failure to load a rules file.
///
/// Produced by `NativeAgent::load_worktree_info_for_system_prompt` when the
/// task that loads a worktree's rules file resolves to an error.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the formatted error).
    pub message: SharedString,
}
  66
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Task slot for saving the thread; starts out as an already-completed task.
    /// NOTE(review): presumably replaced by `save_thread` with the in-flight
    /// save — that method is outside this view, confirm there.
    pending_save: Task<()>,
    /// Subscriptions/observations on `thread`, stored here so they are held
    /// for as long as the session exists.
    _subscriptions: Vec<Subscription>,
}
  76
/// Cache of the language models exposed by authenticated providers, together
/// with a watch channel that is signalled whenever the cache is rebuilt.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiving half of the refresh channel; cloned and handed out by `watch`.
    refresh_models_rx: watch::Receiver<()>,
    /// Sending half of the refresh channel; signalled at the end of `refresh_list`.
    refresh_models_tx: watch::Sender<()>,
    /// Background task that authenticates all visible providers; stored so the
    /// task handle is held alongside the cache.
    _authenticate_all_providers_task: Task<()>,
}
  86
  87impl LanguageModels {
  88    fn new(cx: &mut App) -> Self {
  89        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  90
  91        let mut this = Self {
  92            models: HashMap::default(),
  93            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  94            refresh_models_rx,
  95            refresh_models_tx,
  96            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  97        };
  98        this.refresh_list(cx);
  99        this
 100    }
 101
 102    fn refresh_list(&mut self, cx: &App) {
 103        let providers = LanguageModelRegistry::global(cx)
 104            .read(cx)
 105            .visible_providers()
 106            .into_iter()
 107            .filter(|provider| provider.is_authenticated(cx))
 108            .collect::<Vec<_>>();
 109
 110        let mut language_model_list = IndexMap::default();
 111        let mut recommended_models = HashSet::default();
 112
 113        let mut recommended = Vec::new();
 114        for provider in &providers {
 115            for model in provider.recommended_models(cx) {
 116                recommended_models.insert((model.provider_id(), model.id()));
 117                recommended.push(Self::map_language_model_to_info(&model, provider));
 118            }
 119        }
 120        if !recommended.is_empty() {
 121            language_model_list.insert(
 122                acp_thread::AgentModelGroupName("Recommended".into()),
 123                recommended,
 124            );
 125        }
 126
 127        let mut models = HashMap::default();
 128        for provider in providers {
 129            let mut provider_models = Vec::new();
 130            for model in provider.provided_models(cx) {
 131                let model_info = Self::map_language_model_to_info(&model, &provider);
 132                let model_id = model_info.id.clone();
 133                provider_models.push(model_info);
 134                models.insert(model_id, model);
 135            }
 136            if !provider_models.is_empty() {
 137                language_model_list.insert(
 138                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 139                    provider_models,
 140                );
 141            }
 142        }
 143
 144        self.models = models;
 145        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 146        self.refresh_models_tx.send(()).ok();
 147    }
 148
 149    fn watch(&self) -> watch::Receiver<()> {
 150        self.refresh_models_rx.clone()
 151    }
 152
 153    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 154        self.models.get(model_id).cloned()
 155    }
 156
 157    fn map_language_model_to_info(
 158        model: &Arc<dyn LanguageModel>,
 159        provider: &Arc<dyn LanguageModelProvider>,
 160    ) -> acp_thread::AgentModelInfo {
 161        acp_thread::AgentModelInfo {
 162            id: Self::model_id(model),
 163            name: model.name().0,
 164            description: None,
 165            icon: Some(match provider.icon() {
 166                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 167                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 168            }),
 169            is_latest: model.is_latest(),
 170            cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
 171        }
 172    }
 173
 174    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 175        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 176    }
 177
 178    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 179        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 180            .read(cx)
 181            .visible_providers()
 182            .iter()
 183            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 184            .collect::<Vec<_>>();
 185
 186        cx.background_spawn(async move {
 187            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 188                if let Err(err) = authenticate_task.await {
 189                    match err {
 190                        language_model::AuthenticateError::CredentialsNotFound => {
 191                            // Since we're authenticating these providers in the
 192                            // background for the purposes of populating the
 193                            // language selector, we don't care about providers
 194                            // where the credentials are not found.
 195                        }
 196                        language_model::AuthenticateError::ConnectionRefused => {
 197                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 198                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 199                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 200                        }
 201                        _ => {
 202                            // Some providers have noisy failure states that we
 203                            // don't want to spam the logs with every time the
 204                            // language model selector is initialized.
 205                            //
 206                            // Ideally these should have more clear failure modes
 207                            // that we know are safe to ignore here, like what we do
 208                            // with `CredentialsNotFound` above.
 209                            match provider_id.0.as_ref() {
 210                                "lmstudio" | "ollama" => {
 211                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 212                                    //
 213                                    // These fail noisily, so we don't log them.
 214                                }
 215                                "copilot_chat" => {
 216                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 217                                }
 218                                _ => {
 219                                    log::error!(
 220                                        "Failed to authenticate provider: {}: {err:#}",
 221                                        provider_name.0
 222                                    );
 223                                }
 224                            }
 225                        }
 226                    }
 227                }
 228            }
 229        })
 230    }
 231}
 232
/// The native agent: owns the live sessions and the state shared between
/// their threads (project context, templates, model cache, context servers).
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store handle supplied at construction.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalled to request a rebuild of `project_context`; the receiving half
    /// is consumed by `maintain_project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context`; stored so the task handle is
    /// held by the agent.
    _maintain_project_context: Task<Result<()>>,
    /// Registry built from the project's context server store.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store; when present, its default prompts are loaded
    /// into the project context as user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// File system handle supplied at construction.
    fs: Arc<dyn Fs>,
    /// Event subscriptions (project, model registry, context servers, prompt
    /// store) stored so they are held for as long as the agent exists.
    _subscriptions: Vec<Subscription>,
}
 251
 252impl NativeAgent {
 253    pub async fn new(
 254        project: Entity<Project>,
 255        thread_store: Entity<ThreadStore>,
 256        templates: Arc<Templates>,
 257        prompt_store: Option<Entity<PromptStore>>,
 258        fs: Arc<dyn Fs>,
 259        cx: &mut AsyncApp,
 260    ) -> Result<Entity<NativeAgent>> {
 261        log::debug!("Creating new NativeAgent");
 262
 263        let project_context = cx
 264            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
 265            .await;
 266
 267        Ok(cx.new(|cx| {
 268            let context_server_store = project.read(cx).context_server_store();
 269            let context_server_registry =
 270                cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
 271
 272            let mut subscriptions = vec![
 273                cx.subscribe(&project, Self::handle_project_event),
 274                cx.subscribe(
 275                    &LanguageModelRegistry::global(cx),
 276                    Self::handle_models_updated_event,
 277                ),
 278                cx.subscribe(
 279                    &context_server_store,
 280                    Self::handle_context_server_store_updated,
 281                ),
 282                cx.subscribe(
 283                    &context_server_registry,
 284                    Self::handle_context_server_registry_event,
 285                ),
 286            ];
 287            if let Some(prompt_store) = prompt_store.as_ref() {
 288                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 289            }
 290
 291            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 292                watch::channel(());
 293            Self {
 294                sessions: HashMap::default(),
 295                thread_store,
 296                project_context: cx.new(|_| project_context),
 297                project_context_needs_refresh: project_context_needs_refresh_tx,
 298                _maintain_project_context: cx.spawn(async move |this, cx| {
 299                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 300                }),
 301                context_server_registry,
 302                templates,
 303                models: LanguageModels::new(cx),
 304                project,
 305                prompt_store,
 306                fs,
 307                _subscriptions: subscriptions,
 308            }
 309        }))
 310    }
 311
 312    fn new_session(
 313        &mut self,
 314        project: Entity<Project>,
 315        cx: &mut Context<Self>,
 316    ) -> Entity<AcpThread> {
 317        // Create Thread
 318        // Fetch default model from registry settings
 319        let registry = LanguageModelRegistry::read_global(cx);
 320        // Log available models for debugging
 321        let available_count = registry.available_models(cx).count();
 322        log::debug!("Total available models: {}", available_count);
 323
 324        let default_model = registry.default_model().and_then(|default_model| {
 325            self.models
 326                .model_from_id(&LanguageModels::model_id(&default_model.model))
 327        });
 328        let thread = cx.new(|cx| {
 329            Thread::new(
 330                project.clone(),
 331                self.project_context.clone(),
 332                self.context_server_registry.clone(),
 333                self.templates.clone(),
 334                default_model,
 335                cx,
 336            )
 337        });
 338
 339        self.register_session(thread, cx)
 340    }
 341
 342    fn register_session(
 343        &mut self,
 344        thread_handle: Entity<Thread>,
 345        cx: &mut Context<Self>,
 346    ) -> Entity<AcpThread> {
 347        let connection = Rc::new(NativeAgentConnection(cx.entity()));
 348
 349        let thread = thread_handle.read(cx);
 350        let session_id = thread.id().clone();
 351        let parent_session_id = thread.parent_thread_id();
 352        let title = thread.title();
 353        let project = thread.project.clone();
 354        let action_log = thread.action_log.clone();
 355        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
 356        let acp_thread = cx.new(|cx| {
 357            acp_thread::AcpThread::new(
 358                parent_session_id,
 359                title,
 360                connection,
 361                project.clone(),
 362                action_log.clone(),
 363                session_id.clone(),
 364                prompt_capabilities_rx,
 365                cx,
 366            )
 367        });
 368
 369        let registry = LanguageModelRegistry::read_global(cx);
 370        let summarization_model = registry.thread_summary_model().map(|c| c.model);
 371
 372        let weak = cx.weak_entity();
 373        thread_handle.update(cx, |thread, cx| {
 374            thread.set_summarization_model(summarization_model, cx);
 375            thread.add_default_tools(
 376                Rc::new(NativeThreadEnvironment {
 377                    acp_thread: acp_thread.downgrade(),
 378                    agent: weak,
 379                }) as _,
 380                cx,
 381            )
 382        });
 383
 384        let subscriptions = vec![
 385            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
 386            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
 387            cx.observe(&thread_handle, move |this, thread, cx| {
 388                this.save_thread(thread, cx)
 389            }),
 390        ];
 391
 392        self.sessions.insert(
 393            session_id,
 394            Session {
 395                thread: thread_handle,
 396                acp_thread: acp_thread.clone(),
 397                _subscriptions: subscriptions,
 398                pending_save: Task::ready(()),
 399            },
 400        );
 401
 402        self.update_available_commands(cx);
 403
 404        acp_thread
 405    }
 406
    /// Returns the agent's cached language model information.
    pub fn models(&self) -> &LanguageModels {
        &self.models
    }
 410
 411    async fn maintain_project_context(
 412        this: WeakEntity<Self>,
 413        mut needs_refresh: watch::Receiver<()>,
 414        cx: &mut AsyncApp,
 415    ) -> Result<()> {
 416        while needs_refresh.changed().await.is_ok() {
 417            let project_context = this
 418                .update(cx, |this, cx| {
 419                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 420                })?
 421                .await;
 422            this.update(cx, |this, cx| {
 423                this.project_context = cx.new(|_| project_context);
 424            })?;
 425        }
 426
 427        Ok(())
 428    }
 429
 430    fn build_project_context(
 431        project: &Entity<Project>,
 432        prompt_store: Option<&Entity<PromptStore>>,
 433        cx: &mut App,
 434    ) -> Task<ProjectContext> {
 435        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 436        let worktree_tasks = worktrees
 437            .into_iter()
 438            .map(|worktree| {
 439                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 440            })
 441            .collect::<Vec<_>>();
 442        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 443            prompt_store.read_with(cx, |prompt_store, cx| {
 444                let prompts = prompt_store.default_prompt_metadata();
 445                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 446                    let contents = prompt_store.load(prompt_metadata.id, cx);
 447                    async move { (contents.await, prompt_metadata) }
 448                });
 449                cx.background_spawn(future::join_all(load_tasks))
 450            })
 451        } else {
 452            Task::ready(vec![])
 453        };
 454
 455        cx.spawn(async move |_cx| {
 456            let (worktrees, default_user_rules) =
 457                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 458
 459            let worktrees = worktrees
 460                .into_iter()
 461                .map(|(worktree, _rules_error)| {
 462                    // TODO: show error message
 463                    // if let Some(rules_error) = rules_error {
 464                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 465                    // }
 466                    worktree
 467                })
 468                .collect::<Vec<_>>();
 469
 470            let default_user_rules = default_user_rules
 471                .into_iter()
 472                .flat_map(|(contents, prompt_metadata)| match contents {
 473                    Ok(contents) => Some(UserRulesContext {
 474                        uuid: prompt_metadata.id.as_user()?,
 475                        title: prompt_metadata.title.map(|title| title.to_string()),
 476                        contents,
 477                    }),
 478                    Err(_err) => {
 479                        // TODO: show error message
 480                        // this.update(cx, |_, cx| {
 481                        //     cx.emit(RulesLoadingError {
 482                        //         message: format!("{err:?}").into(),
 483                        //     });
 484                        // })
 485                        // .ok();
 486                        None
 487                    }
 488                })
 489                .collect::<Vec<_>>();
 490
 491            ProjectContext::new(worktrees, default_user_rules)
 492        })
 493    }
 494
 495    fn load_worktree_info_for_system_prompt(
 496        worktree: Entity<Worktree>,
 497        project: Entity<Project>,
 498        cx: &mut App,
 499    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 500        let tree = worktree.read(cx);
 501        let root_name = tree.root_name_str().into();
 502        let abs_path = tree.abs_path();
 503
 504        let mut context = WorktreeContext {
 505            root_name,
 506            abs_path,
 507            rules_file: None,
 508        };
 509
 510        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 511        let Some(rules_task) = rules_task else {
 512            return Task::ready((context, None));
 513        };
 514
 515        cx.spawn(async move |_| {
 516            let (rules_file, rules_file_error) = match rules_task.await {
 517                Ok(rules_file) => (Some(rules_file), None),
 518                Err(err) => (
 519                    None,
 520                    Some(RulesLoadingError {
 521                        message: format!("{err}").into(),
 522                    }),
 523                ),
 524            };
 525            context.rules_file = rules_file;
 526            (context, rules_file_error)
 527        })
 528    }
 529
 530    fn load_worktree_rules_file(
 531        worktree: Entity<Worktree>,
 532        project: Entity<Project>,
 533        cx: &mut App,
 534    ) -> Option<Task<Result<RulesFileContext>>> {
 535        let worktree = worktree.read(cx);
 536        let worktree_id = worktree.id();
 537        let selected_rules_file = RULES_FILE_NAMES
 538            .into_iter()
 539            .filter_map(|name| {
 540                worktree
 541                    .entry_for_path(RelPath::unix(name).unwrap())
 542                    .filter(|entry| entry.is_file())
 543                    .map(|entry| entry.path.clone())
 544            })
 545            .next();
 546
 547        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 548        // supported. This doesn't seem to occur often in GitHub repositories.
 549        selected_rules_file.map(|path_in_worktree| {
 550            let project_path = ProjectPath {
 551                worktree_id,
 552                path: path_in_worktree.clone(),
 553            };
 554            let buffer_task =
 555                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 556            let rope_task = cx.spawn(async move |cx| {
 557                let buffer = buffer_task.await?;
 558                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 559                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 560                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 561                })?;
 562                anyhow::Ok((project_entry_id, rope))
 563            });
 564            // Build a string from the rope on a background thread.
 565            cx.background_spawn(async move {
 566                let (project_entry_id, rope) = rope_task.await?;
 567                anyhow::Ok(RulesFileContext {
 568                    path_in_worktree,
 569                    text: rope.to_string().trim().to_string(),
 570                    project_entry_id: project_entry_id.to_usize(),
 571                })
 572            })
 573        })
 574    }
 575
 576    fn handle_thread_title_updated(
 577        &mut self,
 578        thread: Entity<Thread>,
 579        _: &TitleUpdated,
 580        cx: &mut Context<Self>,
 581    ) {
 582        let session_id = thread.read(cx).id();
 583        let Some(session) = self.sessions.get(session_id) else {
 584            return;
 585        };
 586        let thread = thread.downgrade();
 587        let acp_thread = session.acp_thread.downgrade();
 588        cx.spawn(async move |_, cx| {
 589            let title = thread.read_with(cx, |thread, _| thread.title())?;
 590            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 591            task.await
 592        })
 593        .detach_and_log_err(cx);
 594    }
 595
 596    fn handle_thread_token_usage_updated(
 597        &mut self,
 598        thread: Entity<Thread>,
 599        usage: &TokenUsageUpdated,
 600        cx: &mut Context<Self>,
 601    ) {
 602        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 603            return;
 604        };
 605        session.acp_thread.update(cx, |acp_thread, cx| {
 606            acp_thread.update_token_usage(usage.0.clone(), cx);
 607        });
 608    }
 609
 610    fn handle_project_event(
 611        &mut self,
 612        _project: Entity<Project>,
 613        event: &project::Event,
 614        _cx: &mut Context<Self>,
 615    ) {
 616        match event {
 617            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 618                self.project_context_needs_refresh.send(()).ok();
 619            }
 620            project::Event::WorktreeUpdatedEntries(_, items) => {
 621                if items.iter().any(|(path, _, _)| {
 622                    RULES_FILE_NAMES
 623                        .iter()
 624                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 625                }) {
 626                    self.project_context_needs_refresh.send(()).ok();
 627                }
 628            }
 629            _ => {}
 630        }
 631    }
 632
    /// Requests a project context rebuild whenever the prompt store reports
    /// that its prompts were updated. The event payload is not inspected.
    fn handle_prompts_updated_event(
        &mut self,
        _prompt_store: Entity<PromptStore>,
        _event: &prompt_store::PromptsUpdatedEvent,
        _cx: &mut Context<Self>,
    ) {
        // A failed send is deliberately ignored.
        self.project_context_needs_refresh.send(()).ok();
    }
 641
 642    fn handle_models_updated_event(
 643        &mut self,
 644        _registry: Entity<LanguageModelRegistry>,
 645        _event: &language_model::Event,
 646        cx: &mut Context<Self>,
 647    ) {
 648        self.models.refresh_list(cx);
 649
 650        let registry = LanguageModelRegistry::read_global(cx);
 651        let default_model = registry.default_model().map(|m| m.model);
 652        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 653
 654        for session in self.sessions.values_mut() {
 655            session.thread.update(cx, |thread, cx| {
 656                if thread.model().is_none()
 657                    && let Some(model) = default_model.clone()
 658                {
 659                    thread.set_model(model, cx);
 660                    cx.notify();
 661                }
 662                thread.set_summarization_model(summarization_model.clone(), cx);
 663            });
 664        }
 665    }
 666
    /// Re-sends the available commands to every session whenever the context
    /// server store reports a server status change. The event payload is not
    /// inspected.
    fn handle_context_server_store_updated(
        &mut self,
        _store: Entity<project::context_server_store::ContextServerStore>,
        _event: &project::context_server_store::ServerStatusChangedEvent,
        cx: &mut Context<Self>,
    ) {
        self.update_available_commands(cx);
    }
 675
 676    fn handle_context_server_registry_event(
 677        &mut self,
 678        _registry: Entity<ContextServerRegistry>,
 679        event: &ContextServerRegistryEvent,
 680        cx: &mut Context<Self>,
 681    ) {
 682        match event {
 683            ContextServerRegistryEvent::ToolsChanged => {}
 684            ContextServerRegistryEvent::PromptsChanged => {
 685                self.update_available_commands(cx);
 686            }
 687        }
 688    }
 689
 690    fn update_available_commands(&self, cx: &mut Context<Self>) {
 691        let available_commands = self.build_available_commands(cx);
 692        for session in self.sessions.values() {
 693            session.acp_thread.update(cx, |thread, cx| {
 694                thread
 695                    .handle_session_update(
 696                        acp::SessionUpdate::AvailableCommandsUpdate(
 697                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 698                        ),
 699                        cx,
 700                    )
 701                    .log_err();
 702            });
 703        }
 704    }
 705
 706    fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
 707        let registry = self.context_server_registry.read(cx);
 708
 709        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 710        for context_server_prompt in registry.prompts() {
 711            *prompt_name_counts
 712                .entry(context_server_prompt.prompt.name.as_str())
 713                .or_insert(0) += 1;
 714        }
 715
 716        registry
 717            .prompts()
 718            .flat_map(|context_server_prompt| {
 719                let prompt = &context_server_prompt.prompt;
 720
 721                let should_prefix = prompt_name_counts
 722                    .get(prompt.name.as_str())
 723                    .copied()
 724                    .unwrap_or(0)
 725                    > 1;
 726
 727                let name = if should_prefix {
 728                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 729                } else {
 730                    prompt.name.clone()
 731                };
 732
 733                let mut command = acp::AvailableCommand::new(
 734                    name,
 735                    prompt.description.clone().unwrap_or_default(),
 736                );
 737
 738                match prompt.arguments.as_deref() {
 739                    Some([arg]) => {
 740                        let hint = format!("<{}>", arg.name);
 741
 742                        command = command.input(acp::AvailableCommandInput::Unstructured(
 743                            acp::UnstructuredCommandInput::new(hint),
 744                        ));
 745                    }
 746                    Some([]) | None => {}
 747                    Some(_) => {
 748                        // skip >1 argument commands since we don't support them yet
 749                        return None;
 750                    }
 751                }
 752
 753                Some(command)
 754            })
 755            .collect()
 756    }
 757
 758    pub fn load_thread(
 759        &mut self,
 760        id: acp::SessionId,
 761        cx: &mut Context<Self>,
 762    ) -> Task<Result<Entity<Thread>>> {
 763        let database_future = ThreadsDatabase::connect(cx);
 764        cx.spawn(async move |this, cx| {
 765            let database = database_future.await.map_err(|err| anyhow!(err))?;
 766            let db_thread = database
 767                .load_thread(id.clone())
 768                .await?
 769                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 770
 771            this.update(cx, |this, cx| {
 772                let summarization_model = LanguageModelRegistry::read_global(cx)
 773                    .thread_summary_model()
 774                    .map(|c| c.model);
 775
 776                cx.new(|cx| {
 777                    let mut thread = Thread::from_db(
 778                        id.clone(),
 779                        db_thread,
 780                        this.project.clone(),
 781                        this.project_context.clone(),
 782                        this.context_server_registry.clone(),
 783                        this.templates.clone(),
 784                        cx,
 785                    );
 786                    thread.set_summarization_model(summarization_model, cx);
 787                    thread
 788                })
 789            })
 790        })
 791    }
 792
 793    pub fn open_thread(
 794        &mut self,
 795        id: acp::SessionId,
 796        cx: &mut Context<Self>,
 797    ) -> Task<Result<Entity<AcpThread>>> {
 798        if let Some(session) = self.sessions.get(&id) {
 799            return Task::ready(Ok(session.acp_thread.clone()));
 800        }
 801
 802        let task = self.load_thread(id, cx);
 803        cx.spawn(async move |this, cx| {
 804            let thread = task.await?;
 805            let acp_thread =
 806                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
 807            let events = thread.update(cx, |thread, cx| thread.replay(cx));
 808            cx.update(|cx| {
 809                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 810            })
 811            .await?;
 812            Ok(acp_thread)
 813        })
 814    }
 815
 816    pub fn thread_summary(
 817        &mut self,
 818        id: acp::SessionId,
 819        cx: &mut Context<Self>,
 820    ) -> Task<Result<SharedString>> {
 821        let thread = self.open_thread(id.clone(), cx);
 822        cx.spawn(async move |this, cx| {
 823            let acp_thread = thread.await?;
 824            let result = this
 825                .update(cx, |this, cx| {
 826                    this.sessions
 827                        .get(&id)
 828                        .unwrap()
 829                        .thread
 830                        .update(cx, |thread, cx| thread.summary(cx))
 831                })?
 832                .await
 833                .context("Failed to generate summary")?;
 834            drop(acp_thread);
 835            Ok(result)
 836        })
 837    }
 838
 839    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 840        if thread.read(cx).is_empty() {
 841            return;
 842        }
 843
 844        let database_future = ThreadsDatabase::connect(cx);
 845        let (id, db_thread) =
 846            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
 847        let Some(session) = self.sessions.get_mut(&id) else {
 848            return;
 849        };
 850        let thread_store = self.thread_store.clone();
 851        session.pending_save = cx.spawn(async move |_, cx| {
 852            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 853                return;
 854            };
 855            let db_thread = db_thread.await;
 856            database.save_thread(id, db_thread).await.log_err();
 857            thread_store.update(cx, |store, cx| store.reload(cx));
 858        });
 859    }
 860
    /// Expands an MCP prompt from a context server into the session's threads and starts
    /// a turn.
    ///
    /// The prompt `prompt_name` is fetched from `server_id` with `arguments`. The original
    /// user content (minus its first block) is pushed as the user message `message_id`,
    /// then every message of the fetched prompt is appended to both the ACP thread and the
    /// agent thread according to its role. Finally the model is invoked and the resulting
    /// events are forwarded to the ACP thread.
    ///
    /// The returned task fails when the prompt cannot be fetched, when this agent has been
    /// dropped, when `session_id` is not registered, or when starting or streaming the
    /// turn fails.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: agent_client_protocol::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let server_store = self.context_server_registry.read(cx).server_store().clone();
        let path_style = self.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Tracks the role of the most recently pushed message. Starts as `true` because
            // the original user content is pushed first.
            let mut last_is_user = true;

            // The first content block is skipped: the caller (`prompt`) parsed the slash
            // command out of it.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            // Mirror every prompt message into both the ACP thread and the agent thread.
            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        // Each user message from the prompt gets its own fresh message id.
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        });

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            // Forward the turn's events to the ACP thread until the turn ends.
            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(
                    response_stream,
                    acp_thread.downgrade(),
                    cx,
                )
            })
            .await
        })
    }
 956}
 957
/// Cloneable wrapper around a [`NativeAgent`] entity that implements the
/// `AgentConnection` trait.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 961
 962impl NativeAgentConnection {
 963    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 964        self.0
 965            .read(cx)
 966            .sessions
 967            .get(session_id)
 968            .map(|session| session.thread.clone())
 969    }
 970
 971    pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
 972        self.0.update(cx, |this, cx| this.load_thread(id, cx))
 973    }
 974
 975    fn run_turn(
 976        &self,
 977        session_id: acp::SessionId,
 978        cx: &mut App,
 979        f: impl 'static
 980        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 981    ) -> Task<Result<acp::PromptResponse>> {
 982        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 983            agent
 984                .sessions
 985                .get_mut(&session_id)
 986                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 987        }) else {
 988            return Task::ready(Err(anyhow!("Session not found")));
 989        };
 990        log::debug!("Found session for: {}", session_id);
 991
 992        let response_stream = match f(thread, cx) {
 993            Ok(stream) => stream,
 994            Err(err) => return Task::ready(Err(err)),
 995        };
 996        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
 997    }
 998
 999    fn handle_thread_events(
1000        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1001        acp_thread: WeakEntity<AcpThread>,
1002        cx: &App,
1003    ) -> Task<Result<acp::PromptResponse>> {
1004        cx.spawn(async move |cx| {
1005            // Handle response stream and forward to session.acp_thread
1006            while let Some(result) = events.next().await {
1007                match result {
1008                    Ok(event) => {
1009                        log::trace!("Received completion event: {:?}", event);
1010
1011                        match event {
1012                            ThreadEvent::UserMessage(message) => {
1013                                acp_thread.update(cx, |thread, cx| {
1014                                    for content in message.content {
1015                                        thread.push_user_content_block(
1016                                            Some(message.id.clone()),
1017                                            content.into(),
1018                                            cx,
1019                                        );
1020                                    }
1021                                })?;
1022                            }
1023                            ThreadEvent::AgentText(text) => {
1024                                acp_thread.update(cx, |thread, cx| {
1025                                    thread.push_assistant_content_block(text.into(), false, cx)
1026                                })?;
1027                            }
1028                            ThreadEvent::AgentThinking(text) => {
1029                                acp_thread.update(cx, |thread, cx| {
1030                                    thread.push_assistant_content_block(text.into(), true, cx)
1031                                })?;
1032                            }
1033                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1034                                tool_call,
1035                                options,
1036                                response,
1037                                context: _,
1038                            }) => {
1039                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1040                                    thread.request_tool_call_authorization(tool_call, options, cx)
1041                                })??;
1042                                cx.background_spawn(async move {
1043                                    if let acp::RequestPermissionOutcome::Selected(
1044                                        acp::SelectedPermissionOutcome { option_id, .. },
1045                                    ) = outcome_task.await
1046                                    {
1047                                        response
1048                                            .send(option_id)
1049                                            .map(|_| anyhow!("authorization receiver was dropped"))
1050                                            .log_err();
1051                                    }
1052                                })
1053                                .detach();
1054                            }
1055                            ThreadEvent::ToolCall(tool_call) => {
1056                                acp_thread.update(cx, |thread, cx| {
1057                                    thread.upsert_tool_call(tool_call, cx)
1058                                })??;
1059                            }
1060                            ThreadEvent::ToolCallUpdate(update) => {
1061                                acp_thread.update(cx, |thread, cx| {
1062                                    thread.update_tool_call(update, cx)
1063                                })??;
1064                            }
1065                            ThreadEvent::SubagentSpawned(session_id) => {
1066                                acp_thread.update(cx, |thread, cx| {
1067                                    thread.subagent_spawned(session_id, cx);
1068                                })?;
1069                            }
1070                            ThreadEvent::Retry(status) => {
1071                                acp_thread.update(cx, |thread, cx| {
1072                                    thread.update_retry_status(status, cx)
1073                                })?;
1074                            }
1075                            ThreadEvent::Stop(stop_reason) => {
1076                                log::debug!("Assistant message complete: {:?}", stop_reason);
1077                                return Ok(acp::PromptResponse::new(stop_reason));
1078                            }
1079                        }
1080                    }
1081                    Err(e) => {
1082                        log::error!("Error in model response stream: {:?}", e);
1083                        return Err(e);
1084                    }
1085                }
1086            }
1087
1088            log::debug!("Response stream completed");
1089            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1090        })
1091    }
1092}
1093
/// A slash command borrowed from the text of a prompt's first content block.
struct Command<'a> {
    /// Name of the prompt to run: the command word, without any `server.` prefix.
    prompt_name: &'a str,
    /// Text following the first whitespace after the command word; empty if there is none.
    arg_value: &'a str,
    /// Part of the command word before its first `.`, when one is present.
    explicit_server_id: Option<&'a str>,
}
1099
1100impl<'a> Command<'a> {
1101    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1102        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1103            return None;
1104        };
1105        let text = text_content.text.trim();
1106        let command = text.strip_prefix('/')?;
1107        let (command, arg_value) = command
1108            .split_once(char::is_whitespace)
1109            .unwrap_or((command, ""));
1110
1111        if let Some((server_id, prompt_name)) = command.split_once('.') {
1112            Some(Self {
1113                prompt_name,
1114                arg_value,
1115                explicit_server_id: Some(server_id),
1116            })
1117        } else {
1118            Some(Self {
1119                prompt_name: command,
1120                arg_value,
1121                explicit_server_id: None,
1122            })
1123        }
1124    }
1125}
1126
/// Model selector bound to a single session of a [`NativeAgentConnection`].
struct NativeAgentModelSelector {
    /// Session whose thread has its model read and changed.
    session_id: acp::SessionId,
    /// Connection used to reach the agent's sessions, model list and filesystem handle.
    connection: NativeAgentConnection,
}
1131
1132impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1133    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1134        log::debug!("NativeAgentConnection::list_models called");
1135        let list = self.connection.0.read(cx).models.model_list.clone();
1136        Task::ready(if list.is_empty() {
1137            Err(anyhow::anyhow!("No models available"))
1138        } else {
1139            Ok(list)
1140        })
1141    }
1142
1143    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1144        log::debug!(
1145            "Setting model for session {}: {}",
1146            self.session_id,
1147            model_id
1148        );
1149        let Some(thread) = self
1150            .connection
1151            .0
1152            .read(cx)
1153            .sessions
1154            .get(&self.session_id)
1155            .map(|session| session.thread.clone())
1156        else {
1157            return Task::ready(Err(anyhow!("Session not found")));
1158        };
1159
1160        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1161            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1162        };
1163
1164        // We want to reset the effort level when switching models, as the currently-selected effort level may
1165        // not be compatible.
1166        let effort = model
1167            .default_effort_level()
1168            .map(|effort_level| effort_level.value.to_string());
1169
1170        thread.update(cx, |thread, cx| {
1171            thread.set_model(model.clone(), cx);
1172            thread.set_thinking_effort(effort.clone(), cx);
1173            thread.set_thinking_enabled(model.supports_thinking(), cx);
1174        });
1175
1176        update_settings_file(
1177            self.connection.0.read(cx).fs.clone(),
1178            cx,
1179            move |settings, cx| {
1180                let provider = model.provider_id().0.to_string();
1181                let model = model.id().0.to_string();
1182                let enable_thinking = thread.read(cx).thinking_enabled();
1183                settings
1184                    .agent
1185                    .get_or_insert_default()
1186                    .set_model(LanguageModelSelection {
1187                        provider: provider.into(),
1188                        model,
1189                        enable_thinking,
1190                        effort,
1191                    });
1192            },
1193        );
1194
1195        Task::ready(Ok(()))
1196    }
1197
1198    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1199        let Some(thread) = self
1200            .connection
1201            .0
1202            .read(cx)
1203            .sessions
1204            .get(&self.session_id)
1205            .map(|session| session.thread.clone())
1206        else {
1207            return Task::ready(Err(anyhow!("Session not found")));
1208        };
1209        let Some(model) = thread.read(cx).model() else {
1210            return Task::ready(Err(anyhow!("Model not found")));
1211        };
1212        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1213        else {
1214            return Task::ready(Err(anyhow!("Provider not found")));
1215        };
1216        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1217            model, &provider,
1218        )))
1219    }
1220
1221    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1222        Some(self.connection.0.read(cx).models.watch())
1223    }
1224
1225    fn should_render_footer(&self) -> bool {
1226        true
1227    }
1228}
1229
1230impl acp_thread::AgentConnection for NativeAgentConnection {
1231    fn telemetry_id(&self) -> SharedString {
1232        "zed".into()
1233    }
1234
1235    fn new_session(
1236        self: Rc<Self>,
1237        project: Entity<Project>,
1238        cwd: &Path,
1239        cx: &mut App,
1240    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1241        log::debug!("Creating new thread for project at: {cwd:?}");
1242        Task::ready(Ok(self
1243            .0
1244            .update(cx, |agent, cx| agent.new_session(project, cx))))
1245    }
1246
1247    fn supports_load_session(&self) -> bool {
1248        true
1249    }
1250
1251    fn load_session(
1252        self: Rc<Self>,
1253        session: AgentSessionInfo,
1254        _project: Entity<Project>,
1255        _cwd: &Path,
1256        cx: &mut App,
1257    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1258        self.0
1259            .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1260    }
1261
1262    fn supports_close_session(&self) -> bool {
1263        true
1264    }
1265
1266    fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1267        self.0.update(cx, |agent, _cx| {
1268            agent.sessions.remove(session_id);
1269        });
1270        Task::ready(Ok(()))
1271    }
1272
1273    fn auth_methods(&self) -> &[acp::AuthMethod] {
1274        &[] // No auth for in-process
1275    }
1276
1277    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1278        Task::ready(Ok(()))
1279    }
1280
1281    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1282        Some(Rc::new(NativeAgentModelSelector {
1283            session_id: session_id.clone(),
1284            connection: self.clone(),
1285        }) as Rc<dyn AgentModelSelector>)
1286    }
1287
1288    fn prompt(
1289        &self,
1290        id: Option<acp_thread::UserMessageId>,
1291        params: acp::PromptRequest,
1292        cx: &mut App,
1293    ) -> Task<Result<acp::PromptResponse>> {
1294        let id = id.expect("UserMessageId is required");
1295        let session_id = params.session_id.clone();
1296        log::info!("Received prompt request for session: {}", session_id);
1297        log::debug!("Prompt blocks count: {}", params.prompt.len());
1298
1299        if let Some(parsed_command) = Command::parse(&params.prompt) {
1300            let registry = self.0.read(cx).context_server_registry.read(cx);
1301
1302            let explicit_server_id = parsed_command
1303                .explicit_server_id
1304                .map(|server_id| ContextServerId(server_id.into()));
1305
1306            if let Some(prompt) =
1307                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1308            {
1309                let arguments = if !parsed_command.arg_value.is_empty()
1310                    && let Some(arg_name) = prompt
1311                        .prompt
1312                        .arguments
1313                        .as_ref()
1314                        .and_then(|args| args.first())
1315                        .map(|arg| arg.name.clone())
1316                {
1317                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1318                } else {
1319                    Default::default()
1320                };
1321
1322                let prompt_name = prompt.prompt.name.clone();
1323                let server_id = prompt.server_id.clone();
1324
1325                return self.0.update(cx, |agent, cx| {
1326                    agent.send_mcp_prompt(
1327                        id,
1328                        session_id.clone(),
1329                        prompt_name,
1330                        server_id,
1331                        arguments,
1332                        params.prompt,
1333                        cx,
1334                    )
1335                });
1336            };
1337        };
1338
1339        let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1340
1341        self.run_turn(session_id, cx, move |thread, cx| {
1342            let content: Vec<UserMessageContent> = params
1343                .prompt
1344                .into_iter()
1345                .map(|block| UserMessageContent::from_content_block(block, path_style))
1346                .collect::<Vec<_>>();
1347            log::debug!("Converted prompt to message: {} chars", content.len());
1348            log::debug!("Message id: {:?}", id);
1349            log::debug!("Message content: {:?}", content);
1350
1351            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1352        })
1353    }
1354
1355    fn retry(
1356        &self,
1357        session_id: &acp::SessionId,
1358        _cx: &App,
1359    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1360        Some(Rc::new(NativeAgentSessionRetry {
1361            connection: self.clone(),
1362            session_id: session_id.clone(),
1363        }) as _)
1364    }
1365
1366    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1367        log::info!("Cancelling on session: {}", session_id);
1368        self.0.update(cx, |agent, cx| {
1369            if let Some(session) = agent.sessions.get(session_id) {
1370                session
1371                    .thread
1372                    .update(cx, |thread, cx| thread.cancel(cx))
1373                    .detach();
1374            }
1375        });
1376    }
1377
1378    fn truncate(
1379        &self,
1380        session_id: &agent_client_protocol::SessionId,
1381        cx: &App,
1382    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1383        self.0.read_with(cx, |agent, _cx| {
1384            agent.sessions.get(session_id).map(|session| {
1385                Rc::new(NativeAgentSessionTruncate {
1386                    thread: session.thread.clone(),
1387                    acp_thread: session.acp_thread.downgrade(),
1388                }) as _
1389            })
1390        })
1391    }
1392
1393    fn set_title(
1394        &self,
1395        session_id: &acp::SessionId,
1396        cx: &App,
1397    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1398        self.0.read_with(cx, |agent, _cx| {
1399            agent
1400                .sessions
1401                .get(session_id)
1402                .filter(|s| !s.thread.read(cx).is_subagent())
1403                .map(|session| {
1404                    Rc::new(NativeAgentSessionSetTitle {
1405                        thread: session.thread.clone(),
1406                    }) as _
1407                })
1408        })
1409    }
1410
1411    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1412        let thread_store = self.0.read(cx).thread_store.clone();
1413        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1414    }
1415
1416    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1417        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1418    }
1419
1420    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1421        self
1422    }
1423}
1424
1425impl acp_thread::AgentTelemetry for NativeAgentConnection {
1426    fn thread_data(
1427        &self,
1428        session_id: &acp::SessionId,
1429        cx: &mut App,
1430    ) -> Task<Result<serde_json::Value>> {
1431        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1432            return Task::ready(Err(anyhow!("Session not found")));
1433        };
1434
1435        let task = session.thread.read(cx).to_db(cx);
1436        cx.background_spawn(async move {
1437            serde_json::to_value(task.await).context("Failed to serialize thread")
1438        })
1439    }
1440}
1441
/// Session list backed by a [`ThreadStore`] entity.
pub struct NativeAgentSessionList {
    /// Store whose entries are listed and deleted.
    thread_store: Entity<ThreadStore>,
    /// Sending half of the session-list update channel.
    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
    /// Receiving half of the same channel; clones are handed out to watchers.
    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
    /// Keeps the observation of `thread_store` alive for the lifetime of the list.
    _subscription: Subscription,
}
1448
1449impl NativeAgentSessionList {
1450    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1451        let (tx, rx) = smol::channel::unbounded();
1452        let this_tx = tx.clone();
1453        let subscription = cx.observe(&thread_store, move |_, _| {
1454            this_tx
1455                .try_send(acp_thread::SessionListUpdate::Refresh)
1456                .ok();
1457        });
1458        Self {
1459            thread_store,
1460            updates_tx: tx,
1461            updates_rx: rx,
1462            _subscription: subscription,
1463        }
1464    }
1465
1466    fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1467        AgentSessionInfo {
1468            session_id: entry.id,
1469            cwd: None,
1470            title: Some(entry.title),
1471            updated_at: Some(entry.updated_at),
1472            meta: None,
1473        }
1474    }
1475
1476    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1477        &self.thread_store
1478    }
1479}
1480
1481impl AgentSessionList for NativeAgentSessionList {
1482    fn list_sessions(
1483        &self,
1484        _request: AgentSessionListRequest,
1485        cx: &mut App,
1486    ) -> Task<Result<AgentSessionListResponse>> {
1487        let sessions = self
1488            .thread_store
1489            .read(cx)
1490            .entries()
1491            .map(Self::to_session_info)
1492            .collect();
1493        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1494    }
1495
1496    fn supports_delete(&self) -> bool {
1497        true
1498    }
1499
1500    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1501        self.thread_store
1502            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1503    }
1504
1505    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1506        self.thread_store
1507            .update(cx, |store, cx| store.delete_threads(cx))
1508    }
1509
1510    fn watch(
1511        &self,
1512        _cx: &mut App,
1513    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1514        Some(self.updates_rx.clone())
1515    }
1516
1517    fn notify_refresh(&self) {
1518        self.updates_tx
1519            .try_send(acp_thread::SessionListUpdate::Refresh)
1520            .ok();
1521    }
1522
1523    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1524        self
1525    }
1526}
1527
/// Truncate handle for one session.
struct NativeAgentSessionTruncate {
    /// Agent thread that is truncated.
    thread: Entity<Thread>,
    /// ACP thread whose token usage is refreshed after a successful truncation.
    acp_thread: WeakEntity<AcpThread>,
}
1532
1533impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1534    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1535        match self.thread.update(cx, |thread, cx| {
1536            thread.truncate(message_id.clone(), cx)?;
1537            Ok(thread.latest_token_usage())
1538        }) {
1539            Ok(usage) => {
1540                self.acp_thread
1541                    .update(cx, |thread, cx| {
1542                        thread.update_token_usage(usage, cx);
1543                    })
1544                    .ok();
1545                Task::ready(Ok(()))
1546            }
1547            Err(error) => Task::ready(Err(error)),
1548        }
1549    }
1550}
1551
/// Retry handle for one session.
struct NativeAgentSessionRetry {
    /// Connection used to run the retried turn.
    connection: NativeAgentConnection,
    /// Session whose thread is resumed.
    session_id: acp::SessionId,
}
1556
1557impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1558    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1559        self.connection
1560            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1561                thread.update(cx, |thread, cx| thread.resume(cx))
1562            })
1563    }
1564}
1565
/// Set-title handle for one session.
struct NativeAgentSessionSetTitle {
    /// Thread whose title is changed.
    thread: Entity<Thread>,
}
1569
1570impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1571    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1572        self.thread
1573            .update(cx, |thread, cx| thread.set_title(title, cx));
1574        Task::ready(Ok(()))
1575    }
1576}
1577
/// Holds weak handles to a [`NativeAgent`] and an [`AcpThread`].
// NOTE(review): presumably the environment a thread uses to reach its agent and ACP
// thread; the construction site is outside this section — confirm.
pub struct NativeThreadEnvironment {
    /// Weak handle to the agent.
    agent: WeakEntity<NativeAgent>,
    /// Weak handle to the ACP thread.
    acp_thread: WeakEntity<AcpThread>,
}
1582
1583impl NativeThreadEnvironment {
1584    pub(crate) fn create_subagent_thread(
1585        agent: WeakEntity<NativeAgent>,
1586        parent_thread_entity: Entity<Thread>,
1587        label: String,
1588        initial_prompt: String,
1589        timeout: Option<Duration>,
1590        cx: &mut App,
1591    ) -> Result<Rc<dyn SubagentHandle>> {
1592        let parent_thread = parent_thread_entity.read(cx);
1593        let current_depth = parent_thread.depth();
1594
1595        if current_depth >= MAX_SUBAGENT_DEPTH {
1596            return Err(anyhow!(
1597                "Maximum subagent depth ({}) reached",
1598                MAX_SUBAGENT_DEPTH
1599            ));
1600        }
1601
1602        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1603            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1604            thread.set_title(label.into(), cx);
1605            thread
1606        });
1607
1608        let session_id = subagent_thread.read(cx).id().clone();
1609
1610        let acp_thread = agent.update(cx, |agent, cx| {
1611            agent.register_session(subagent_thread.clone(), cx)
1612        })?;
1613
1614        Self::prompt_subagent(
1615            session_id,
1616            subagent_thread,
1617            acp_thread,
1618            parent_thread_entity,
1619            initial_prompt,
1620            timeout,
1621            cx,
1622        )
1623    }
1624
1625    pub(crate) fn resume_subagent_thread(
1626        agent: WeakEntity<NativeAgent>,
1627        parent_thread_entity: Entity<Thread>,
1628        session_id: acp::SessionId,
1629        follow_up_prompt: String,
1630        timeout: Option<Duration>,
1631        cx: &mut App,
1632    ) -> Result<Rc<dyn SubagentHandle>> {
1633        let (subagent_thread, acp_thread) = agent.update(cx, |agent, _cx| {
1634            let session = agent
1635                .sessions
1636                .get(&session_id)
1637                .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1638            anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1639        })??;
1640
1641        Self::prompt_subagent(
1642            session_id,
1643            subagent_thread,
1644            acp_thread,
1645            parent_thread_entity,
1646            follow_up_prompt,
1647            timeout,
1648            cx,
1649        )
1650    }
1651
1652    fn prompt_subagent(
1653        session_id: acp::SessionId,
1654        subagent_thread: Entity<Thread>,
1655        acp_thread: Entity<acp_thread::AcpThread>,
1656        parent_thread_entity: Entity<Thread>,
1657        prompt: String,
1658        timeout: Option<Duration>,
1659        cx: &mut App,
1660    ) -> Result<Rc<dyn SubagentHandle>> {
1661        parent_thread_entity.update(cx, |parent_thread, _cx| {
1662            parent_thread.register_running_subagent(subagent_thread.downgrade())
1663        });
1664
1665        let task = acp_thread.update(cx, |acp_thread, cx| {
1666            acp_thread.send(vec![prompt.into()], cx)
1667        });
1668
1669        let timeout_timer = timeout.map(|d| cx.background_executor().timer(d));
1670        let wait_for_prompt_to_complete = cx
1671            .background_spawn(async move {
1672                if let Some(timer) = timeout_timer {
1673                    futures::select! {
1674                        _ = timer.fuse() => SubagentInitialPromptResult::Timeout,
1675                        response = task.fuse() => {
1676                            let response = response.log_err().flatten();
1677                            if response.is_some_and(|response| {
1678                                response.stop_reason == acp::StopReason::Cancelled
1679                            })
1680                            {
1681                                SubagentInitialPromptResult::Cancelled
1682                            } else {
1683                                SubagentInitialPromptResult::Completed
1684                            }
1685                        },
1686                    }
1687                } else {
1688                    let response = task.await.log_err().flatten();
1689                    if response
1690                        .is_some_and(|response| response.stop_reason == acp::StopReason::Cancelled)
1691                    {
1692                        SubagentInitialPromptResult::Cancelled
1693                    } else {
1694                        SubagentInitialPromptResult::Completed
1695                    }
1696                }
1697            })
1698            .shared();
1699
1700        Ok(Rc::new(NativeSubagentHandle {
1701            session_id,
1702            subagent_thread,
1703            parent_thread: parent_thread_entity.downgrade(),
1704            wait_for_prompt_to_complete,
1705        }) as _)
1706    }
1707}
1708
1709impl ThreadEnvironment for NativeThreadEnvironment {
1710    fn create_terminal(
1711        &self,
1712        command: String,
1713        cwd: Option<PathBuf>,
1714        output_byte_limit: Option<u64>,
1715        cx: &mut AsyncApp,
1716    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1717        let task = self.acp_thread.update(cx, |thread, cx| {
1718            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1719        });
1720
1721        let acp_thread = self.acp_thread.clone();
1722        cx.spawn(async move |cx| {
1723            let terminal = task?.await?;
1724
1725            let (drop_tx, drop_rx) = oneshot::channel();
1726            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1727
1728            cx.spawn(async move |cx| {
1729                drop_rx.await.ok();
1730                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1731            })
1732            .detach();
1733
1734            let handle = AcpTerminalHandle {
1735                terminal,
1736                _drop_tx: Some(drop_tx),
1737            };
1738
1739            Ok(Rc::new(handle) as _)
1740        })
1741    }
1742
1743    fn create_subagent(
1744        &self,
1745        parent_thread_entity: Entity<Thread>,
1746        label: String,
1747        initial_prompt: String,
1748        timeout: Option<Duration>,
1749        cx: &mut App,
1750    ) -> Result<Rc<dyn SubagentHandle>> {
1751        Self::create_subagent_thread(
1752            self.agent.clone(),
1753            parent_thread_entity,
1754            label,
1755            initial_prompt,
1756            timeout,
1757            cx,
1758        )
1759    }
1760
1761    fn resume_subagent(
1762        &self,
1763        parent_thread_entity: Entity<Thread>,
1764        session_id: acp::SessionId,
1765        follow_up_prompt: String,
1766        timeout: Option<Duration>,
1767        cx: &mut App,
1768    ) -> Result<Rc<dyn SubagentHandle>> {
1769        Self::resume_subagent_thread(
1770            self.agent.clone(),
1771            parent_thread_entity,
1772            session_id,
1773            follow_up_prompt,
1774            timeout,
1775            cx,
1776        )
1777    }
1778}
1779
/// How a prompt sent to a subagent ended.
#[derive(Clone, Copy, Debug)]
enum SubagentInitialPromptResult {
    /// The prompt finished (possibly with a logged error) without a cancelled stop reason.
    Completed,
    /// The timeout timer fired before the prompt finished.
    Timeout,
    /// The prompt finished with a cancelled stop reason.
    Cancelled,
}
1786
1787pub struct NativeSubagentHandle {
1788    session_id: acp::SessionId,
1789    parent_thread: WeakEntity<Thread>,
1790    subagent_thread: Entity<Thread>,
1791    wait_for_prompt_to_complete: Shared<Task<SubagentInitialPromptResult>>,
1792}
1793
1794impl SubagentHandle for NativeSubagentHandle {
1795    fn id(&self) -> acp::SessionId {
1796        self.session_id.clone()
1797    }
1798
1799    fn wait_for_output(&self, cx: &AsyncApp) -> Task<Result<String>> {
1800        let thread = self.subagent_thread.clone();
1801        let wait_for_prompt = self.wait_for_prompt_to_complete.clone();
1802
1803        let subagent_session_id = self.session_id.clone();
1804        let parent_thread = self.parent_thread.clone();
1805
1806        cx.spawn(async move |cx| {
1807            let result = match wait_for_prompt.await {
1808                SubagentInitialPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
1809                    thread
1810                        .last_message()
1811                        .map(|m| m.to_markdown())
1812                        .context("No response from subagent")
1813                }),
1814                SubagentInitialPromptResult::Timeout => {
1815                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1816                    Err(anyhow!("The time to complete the task was exceeded."))
1817                }
1818                SubagentInitialPromptResult::Cancelled => Err(anyhow!("User cancelled")),
1819            };
1820
1821            parent_thread
1822                .update(cx, |parent_thread, cx| {
1823                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1824                })
1825                .ok();
1826
1827            result
1828        })
1829    }
1830}
1831
1832pub struct AcpTerminalHandle {
1833    terminal: Entity<acp_thread::Terminal>,
1834    _drop_tx: Option<oneshot::Sender<()>>,
1835}
1836
1837impl TerminalHandle for AcpTerminalHandle {
1838    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1839        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1840    }
1841
1842    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1843        Ok(self
1844            .terminal
1845            .read_with(cx, |term, _cx| term.wait_for_exit()))
1846    }
1847
1848    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1849        Ok(self
1850            .terminal
1851            .read_with(cx, |term, cx| term.current_output(cx)))
1852    }
1853
1854    fn kill(&self, cx: &AsyncApp) -> Result<()> {
1855        cx.update(|cx| {
1856            self.terminal.update(cx, |terminal, cx| {
1857                terminal.kill(cx);
1858            });
1859        });
1860        Ok(())
1861    }
1862
1863    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1864        Ok(self
1865            .terminal
1866            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1867    }
1868}
1869
1870#[cfg(test)]
1871mod internal_tests {
1872    use super::*;
1873    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1874    use fs::FakeFs;
1875    use gpui::TestAppContext;
1876    use indoc::formatdoc;
1877    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1878    use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1879    use serde_json::json;
1880    use settings::SettingsStore;
1881    use util::{path, rel_path::rel_path};
1882
1883    #[gpui::test]
1884    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1885        init_test(cx);
1886        let fs = FakeFs::new(cx.executor());
1887        fs.insert_tree(
1888            "/",
1889            json!({
1890                "a": {}
1891            }),
1892        )
1893        .await;
1894        let project = Project::test(fs.clone(), [], cx).await;
1895        let thread_store = cx.new(|cx| ThreadStore::new(cx));
1896        let agent = NativeAgent::new(
1897            project.clone(),
1898            thread_store,
1899            Templates::new(),
1900            None,
1901            fs.clone(),
1902            &mut cx.to_async(),
1903        )
1904        .await
1905        .unwrap();
1906        agent.read_with(cx, |agent, cx| {
1907            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1908        });
1909
1910        let worktree = project
1911            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1912            .await
1913            .unwrap();
1914        cx.run_until_parked();
1915        agent.read_with(cx, |agent, cx| {
1916            assert_eq!(
1917                agent.project_context.read(cx).worktrees,
1918                vec![WorktreeContext {
1919                    root_name: "a".into(),
1920                    abs_path: Path::new("/a").into(),
1921                    rules_file: None
1922                }]
1923            )
1924        });
1925
1926        // Creating `/a/.rules` updates the project context.
1927        fs.insert_file("/a/.rules", Vec::new()).await;
1928        cx.run_until_parked();
1929        agent.read_with(cx, |agent, cx| {
1930            let rules_entry = worktree
1931                .read(cx)
1932                .entry_for_path(rel_path(".rules"))
1933                .unwrap();
1934            assert_eq!(
1935                agent.project_context.read(cx).worktrees,
1936                vec![WorktreeContext {
1937                    root_name: "a".into(),
1938                    abs_path: Path::new("/a").into(),
1939                    rules_file: Some(RulesFileContext {
1940                        path_in_worktree: rel_path(".rules").into(),
1941                        text: "".into(),
1942                        project_entry_id: rules_entry.id.to_usize()
1943                    })
1944                }]
1945            )
1946        });
1947    }
1948
1949    #[gpui::test]
1950    async fn test_listing_models(cx: &mut TestAppContext) {
1951        init_test(cx);
1952        let fs = FakeFs::new(cx.executor());
1953        fs.insert_tree("/", json!({ "a": {}  })).await;
1954        let project = Project::test(fs.clone(), [], cx).await;
1955        let thread_store = cx.new(|cx| ThreadStore::new(cx));
1956        let connection = NativeAgentConnection(
1957            NativeAgent::new(
1958                project.clone(),
1959                thread_store,
1960                Templates::new(),
1961                None,
1962                fs.clone(),
1963                &mut cx.to_async(),
1964            )
1965            .await
1966            .unwrap(),
1967        );
1968
1969        // Create a thread/session
1970        let acp_thread = cx
1971            .update(|cx| {
1972                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
1973            })
1974            .await
1975            .unwrap();
1976
1977        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1978
1979        let models = cx
1980            .update(|cx| {
1981                connection
1982                    .model_selector(&session_id)
1983                    .unwrap()
1984                    .list_models(cx)
1985            })
1986            .await
1987            .unwrap();
1988
1989        let acp_thread::AgentModelList::Grouped(models) = models else {
1990            panic!("Unexpected model group");
1991        };
1992        assert_eq!(
1993            models,
1994            IndexMap::from_iter([(
1995                AgentModelGroupName("Fake".into()),
1996                vec![AgentModelInfo {
1997                    id: acp::ModelId::new("fake/fake"),
1998                    name: "Fake".into(),
1999                    description: None,
2000                    icon: Some(acp_thread::AgentModelIcon::Named(
2001                        ui::IconName::ZedAssistant
2002                    )),
2003                    is_latest: false,
2004                    cost: None,
2005                }]
2006            )])
2007        );
2008    }
2009
2010    #[gpui::test]
2011    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2012        init_test(cx);
2013        let fs = FakeFs::new(cx.executor());
2014        fs.create_dir(paths::settings_file().parent().unwrap())
2015            .await
2016            .unwrap();
2017        fs.insert_file(
2018            paths::settings_file(),
2019            json!({
2020                "agent": {
2021                    "default_model": {
2022                        "provider": "foo",
2023                        "model": "bar"
2024                    }
2025                }
2026            })
2027            .to_string()
2028            .into_bytes(),
2029        )
2030        .await;
2031        let project = Project::test(fs.clone(), [], cx).await;
2032
2033        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2034
2035        // Create the agent and connection
2036        let agent = NativeAgent::new(
2037            project.clone(),
2038            thread_store,
2039            Templates::new(),
2040            None,
2041            fs.clone(),
2042            &mut cx.to_async(),
2043        )
2044        .await
2045        .unwrap();
2046        let connection = NativeAgentConnection(agent.clone());
2047
2048        // Create a thread/session
2049        let acp_thread = cx
2050            .update(|cx| {
2051                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2052            })
2053            .await
2054            .unwrap();
2055
2056        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2057
2058        // Select a model
2059        let selector = connection.model_selector(&session_id).unwrap();
2060        let model_id = acp::ModelId::new("fake/fake");
2061        cx.update(|cx| selector.select_model(model_id.clone(), cx))
2062            .await
2063            .unwrap();
2064
2065        // Verify the thread has the selected model
2066        agent.read_with(cx, |agent, _| {
2067            let session = agent.sessions.get(&session_id).unwrap();
2068            session.thread.read_with(cx, |thread, _| {
2069                assert_eq!(thread.model().unwrap().id().0, "fake");
2070            });
2071        });
2072
2073        cx.run_until_parked();
2074
2075        // Verify settings file was updated
2076        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2077        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2078
2079        // Check that the agent settings contain the selected model
2080        assert_eq!(
2081            settings_json["agent"]["default_model"]["model"],
2082            json!("fake")
2083        );
2084        assert_eq!(
2085            settings_json["agent"]["default_model"]["provider"],
2086            json!("fake")
2087        );
2088
2089        // Register a thinking model and select it.
2090        cx.update(|cx| {
2091            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2092                "fake-corp",
2093                "fake-thinking",
2094                "Fake Thinking",
2095                true,
2096            ));
2097            let thinking_provider = Arc::new(
2098                FakeLanguageModelProvider::new(
2099                    LanguageModelProviderId::from("fake-corp".to_string()),
2100                    LanguageModelProviderName::from("Fake Corp".to_string()),
2101                )
2102                .with_models(vec![thinking_model]),
2103            );
2104            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2105                registry.register_provider(thinking_provider, cx);
2106            });
2107        });
2108        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2109
2110        let selector = connection.model_selector(&session_id).unwrap();
2111        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2112            .await
2113            .unwrap();
2114        cx.run_until_parked();
2115
2116        // Verify enable_thinking was written to settings as true.
2117        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2118        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2119        assert_eq!(
2120            settings_json["agent"]["default_model"]["enable_thinking"],
2121            json!(true),
2122            "selecting a thinking model should persist enable_thinking: true to settings"
2123        );
2124    }
2125
2126    #[gpui::test]
2127    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2128        init_test(cx);
2129        let fs = FakeFs::new(cx.executor());
2130        fs.create_dir(paths::settings_file().parent().unwrap())
2131            .await
2132            .unwrap();
2133        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2134        let project = Project::test(fs.clone(), [], cx).await;
2135
2136        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2137        let agent = NativeAgent::new(
2138            project.clone(),
2139            thread_store,
2140            Templates::new(),
2141            None,
2142            fs.clone(),
2143            &mut cx.to_async(),
2144        )
2145        .await
2146        .unwrap();
2147        let connection = NativeAgentConnection(agent.clone());
2148
2149        let acp_thread = cx
2150            .update(|cx| {
2151                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2152            })
2153            .await
2154            .unwrap();
2155        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2156
2157        // Register a second provider with a thinking model.
2158        cx.update(|cx| {
2159            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2160                "fake-corp",
2161                "fake-thinking",
2162                "Fake Thinking",
2163                true,
2164            ));
2165            let thinking_provider = Arc::new(
2166                FakeLanguageModelProvider::new(
2167                    LanguageModelProviderId::from("fake-corp".to_string()),
2168                    LanguageModelProviderName::from("Fake Corp".to_string()),
2169                )
2170                .with_models(vec![thinking_model]),
2171            );
2172            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2173                registry.register_provider(thinking_provider, cx);
2174            });
2175        });
2176        // Refresh the agent's model list so it picks up the new provider.
2177        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2178
2179        // Thread starts with thinking_enabled = false (the default).
2180        agent.read_with(cx, |agent, _| {
2181            let session = agent.sessions.get(&session_id).unwrap();
2182            session.thread.read_with(cx, |thread, _| {
2183                assert!(!thread.thinking_enabled(), "thinking defaults to false");
2184            });
2185        });
2186
2187        // Select the thinking model via select_model.
2188        let selector = connection.model_selector(&session_id).unwrap();
2189        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2190            .await
2191            .unwrap();
2192
2193        // select_model should have enabled thinking based on the model's supports_thinking().
2194        agent.read_with(cx, |agent, _| {
2195            let session = agent.sessions.get(&session_id).unwrap();
2196            session.thread.read_with(cx, |thread, _| {
2197                assert!(
2198                    thread.thinking_enabled(),
2199                    "select_model should enable thinking when model supports it"
2200                );
2201            });
2202        });
2203
2204        // Switch back to the non-thinking model.
2205        let selector = connection.model_selector(&session_id).unwrap();
2206        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2207            .await
2208            .unwrap();
2209
2210        // select_model should have disabled thinking.
2211        agent.read_with(cx, |agent, _| {
2212            let session = agent.sessions.get(&session_id).unwrap();
2213            session.thread.read_with(cx, |thread, _| {
2214                assert!(
2215                    !thread.thinking_enabled(),
2216                    "select_model should disable thinking when model does not support it"
2217                );
2218            });
2219        });
2220    }
2221
2222    #[gpui::test]
2223    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
2224        init_test(cx);
2225        let fs = FakeFs::new(cx.executor());
2226        fs.insert_tree("/", json!({ "a": {} })).await;
2227        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2228        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2229        let agent = NativeAgent::new(
2230            project.clone(),
2231            thread_store.clone(),
2232            Templates::new(),
2233            None,
2234            fs.clone(),
2235            &mut cx.to_async(),
2236        )
2237        .await
2238        .unwrap();
2239        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2240
2241        // Register a thinking model.
2242        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2243            "fake-corp",
2244            "fake-thinking",
2245            "Fake Thinking",
2246            true,
2247        ));
2248        let thinking_provider = Arc::new(
2249            FakeLanguageModelProvider::new(
2250                LanguageModelProviderId::from("fake-corp".to_string()),
2251                LanguageModelProviderName::from("Fake Corp".to_string()),
2252            )
2253            .with_models(vec![thinking_model.clone()]),
2254        );
2255        cx.update(|cx| {
2256            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2257                registry.register_provider(thinking_provider, cx);
2258            });
2259        });
2260        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2261
2262        // Create a thread and select the thinking model.
2263        let acp_thread = cx
2264            .update(|cx| {
2265                connection
2266                    .clone()
2267                    .new_session(project.clone(), Path::new("/a"), cx)
2268            })
2269            .await
2270            .unwrap();
2271        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2272
2273        let selector = connection.model_selector(&session_id).unwrap();
2274        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2275            .await
2276            .unwrap();
2277
2278        // Verify thinking is enabled after selecting the thinking model.
2279        let thread = agent.read_with(cx, |agent, _| {
2280            agent.sessions.get(&session_id).unwrap().thread.clone()
2281        });
2282        thread.read_with(cx, |thread, _| {
2283            assert!(
2284                thread.thinking_enabled(),
2285                "thinking should be enabled after selecting thinking model"
2286            );
2287        });
2288
2289        // Send a message so the thread gets persisted.
2290        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2291        let send = cx.foreground_executor().spawn(send);
2292        cx.run_until_parked();
2293
2294        thinking_model.send_last_completion_stream_text_chunk("Response.");
2295        thinking_model.end_last_completion_stream();
2296
2297        send.await.unwrap();
2298        cx.run_until_parked();
2299
2300        // Close the session so it can be reloaded from disk.
2301        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2302            .await
2303            .unwrap();
2304        drop(thread);
2305        drop(acp_thread);
2306        agent.read_with(cx, |agent, _| {
2307            assert!(agent.sessions.is_empty());
2308        });
2309
2310        // Reload the thread and verify thinking_enabled is still true.
2311        let reloaded_acp_thread = agent
2312            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2313            .await
2314            .unwrap();
2315        let reloaded_thread = agent.read_with(cx, |agent, _| {
2316            agent.sessions.get(&session_id).unwrap().thread.clone()
2317        });
2318        reloaded_thread.read_with(cx, |thread, _| {
2319            assert!(
2320                thread.thinking_enabled(),
2321                "thinking_enabled should be preserved when reloading a thread with a thinking model"
2322            );
2323        });
2324
2325        drop(reloaded_acp_thread);
2326    }
2327
2328    #[gpui::test]
2329    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
2330        init_test(cx);
2331        let fs = FakeFs::new(cx.executor());
2332        fs.insert_tree("/", json!({ "a": {} })).await;
2333        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2334        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2335        let agent = NativeAgent::new(
2336            project.clone(),
2337            thread_store.clone(),
2338            Templates::new(),
2339            None,
2340            fs.clone(),
2341            &mut cx.to_async(),
2342        )
2343        .await
2344        .unwrap();
2345        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2346
2347        // Register a model where id() != name(), like real Anthropic models
2348        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
2349        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2350            "fake-corp",
2351            "custom-model-id",
2352            "Custom Model Display Name",
2353            false,
2354        ));
2355        let provider = Arc::new(
2356            FakeLanguageModelProvider::new(
2357                LanguageModelProviderId::from("fake-corp".to_string()),
2358                LanguageModelProviderName::from("Fake Corp".to_string()),
2359            )
2360            .with_models(vec![model.clone()]),
2361        );
2362        cx.update(|cx| {
2363            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2364                registry.register_provider(provider, cx);
2365            });
2366        });
2367        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2368
2369        // Create a thread and select the model.
2370        let acp_thread = cx
2371            .update(|cx| {
2372                connection
2373                    .clone()
2374                    .new_session(project.clone(), Path::new("/a"), cx)
2375            })
2376            .await
2377            .unwrap();
2378        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2379
2380        let selector = connection.model_selector(&session_id).unwrap();
2381        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
2382            .await
2383            .unwrap();
2384
2385        let thread = agent.read_with(cx, |agent, _| {
2386            agent.sessions.get(&session_id).unwrap().thread.clone()
2387        });
2388        thread.read_with(cx, |thread, _| {
2389            assert_eq!(
2390                thread.model().unwrap().id().0.as_ref(),
2391                "custom-model-id",
2392                "model should be set before persisting"
2393            );
2394        });
2395
2396        // Send a message so the thread gets persisted.
2397        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2398        let send = cx.foreground_executor().spawn(send);
2399        cx.run_until_parked();
2400
2401        model.send_last_completion_stream_text_chunk("Response.");
2402        model.end_last_completion_stream();
2403
2404        send.await.unwrap();
2405        cx.run_until_parked();
2406
2407        // Close the session so it can be reloaded from disk.
2408        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2409            .await
2410            .unwrap();
2411        drop(thread);
2412        drop(acp_thread);
2413        agent.read_with(cx, |agent, _| {
2414            assert!(agent.sessions.is_empty());
2415        });
2416
2417        // Reload the thread and verify the model was preserved.
2418        let reloaded_acp_thread = agent
2419            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2420            .await
2421            .unwrap();
2422        let reloaded_thread = agent.read_with(cx, |agent, _| {
2423            agent.sessions.get(&session_id).unwrap().thread.clone()
2424        });
2425        reloaded_thread.read_with(cx, |thread, _| {
2426            let reloaded_model = thread
2427                .model()
2428                .expect("model should be present after reload");
2429            assert_eq!(
2430                reloaded_model.id().0.as_ref(),
2431                "custom-model-id",
2432                "reloaded thread should have the same model, not fall back to the default"
2433            );
2434        });
2435
2436        drop(reloaded_acp_thread);
2437    }
2438
2439    #[gpui::test]
2440    async fn test_save_load_thread(cx: &mut TestAppContext) {
2441        init_test(cx);
2442        let fs = FakeFs::new(cx.executor());
2443        fs.insert_tree(
2444            "/",
2445            json!({
2446                "a": {
2447                    "b.md": "Lorem"
2448                }
2449            }),
2450        )
2451        .await;
2452        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2453        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2454        let agent = NativeAgent::new(
2455            project.clone(),
2456            thread_store.clone(),
2457            Templates::new(),
2458            None,
2459            fs.clone(),
2460            &mut cx.to_async(),
2461        )
2462        .await
2463        .unwrap();
2464        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2465
2466        let acp_thread = cx
2467            .update(|cx| {
2468                connection
2469                    .clone()
2470                    .new_session(project.clone(), Path::new(""), cx)
2471            })
2472            .await
2473            .unwrap();
2474        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2475        let thread = agent.read_with(cx, |agent, _| {
2476            agent.sessions.get(&session_id).unwrap().thread.clone()
2477        });
2478
2479        // Ensure empty threads are not saved, even if they get mutated.
2480        let model = Arc::new(FakeLanguageModel::default());
2481        let summary_model = Arc::new(FakeLanguageModel::default());
2482        thread.update(cx, |thread, cx| {
2483            thread.set_model(model.clone(), cx);
2484            thread.set_summarization_model(Some(summary_model.clone()), cx);
2485        });
2486        cx.run_until_parked();
2487        assert_eq!(thread_entries(&thread_store, cx), vec![]);
2488
2489        let send = acp_thread.update(cx, |thread, cx| {
2490            thread.send(
2491                vec![
2492                    "What does ".into(),
2493                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
2494                        "b.md",
2495                        MentionUri::File {
2496                            abs_path: path!("/a/b.md").into(),
2497                        }
2498                        .to_uri()
2499                        .to_string(),
2500                    )),
2501                    " mean?".into(),
2502                ],
2503                cx,
2504            )
2505        });
2506        let send = cx.foreground_executor().spawn(send);
2507        cx.run_until_parked();
2508
2509        model.send_last_completion_stream_text_chunk("Lorem.");
2510        model.end_last_completion_stream();
2511        cx.run_until_parked();
2512        summary_model
2513            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
2514        summary_model.end_last_completion_stream();
2515
2516        send.await.unwrap();
2517        let uri = MentionUri::File {
2518            abs_path: path!("/a/b.md").into(),
2519        }
2520        .to_uri();
2521        acp_thread.read_with(cx, |thread, cx| {
2522            assert_eq!(
2523                thread.to_markdown(cx),
2524                formatdoc! {"
2525                    ## User
2526
2527                    What does [@b.md]({uri}) mean?
2528
2529                    ## Assistant
2530
2531                    Lorem.
2532
2533                "}
2534            )
2535        });
2536
2537        cx.run_until_parked();
2538
2539        // Close the session so it can be reloaded from disk.
2540        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2541            .await
2542            .unwrap();
2543        drop(thread);
2544        drop(acp_thread);
2545        agent.read_with(cx, |agent, _| {
2546            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
2547        });
2548
2549        // Ensure the thread can be reloaded from disk.
2550        assert_eq!(
2551            thread_entries(&thread_store, cx),
2552            vec![(
2553                session_id.clone(),
2554                format!("Explaining {}", path!("/a/b.md"))
2555            )]
2556        );
2557        let acp_thread = agent
2558            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2559            .await
2560            .unwrap();
2561        acp_thread.read_with(cx, |thread, cx| {
2562            assert_eq!(
2563                thread.to_markdown(cx),
2564                formatdoc! {"
2565                    ## User
2566
2567                    What does [@b.md]({uri}) mean?
2568
2569                    ## Assistant
2570
2571                    Lorem.
2572
2573                "}
2574            )
2575        });
2576    }
2577
2578    fn thread_entries(
2579        thread_store: &Entity<ThreadStore>,
2580        cx: &mut TestAppContext,
2581    ) -> Vec<(acp::SessionId, String)> {
2582        thread_store.read_with(cx, |store, _| {
2583            store
2584                .entries()
2585                .map(|entry| (entry.id.clone(), entry.title.to_string()))
2586                .collect::<Vec<_>>()
2587        })
2588    }
2589
2590    fn init_test(cx: &mut TestAppContext) {
2591        env_logger::try_init().ok();
2592        cx.update(|cx| {
2593            let settings_store = SettingsStore::test(cx);
2594            cx.set_global(settings_store);
2595
2596            LanguageModelRegistry::test(cx);
2597        });
2598    }
2599}
2600
2601fn mcp_message_content_to_acp_content_block(
2602    content: context_server::types::MessageContent,
2603) -> acp::ContentBlock {
2604    match content {
2605        context_server::types::MessageContent::Text {
2606            text,
2607            annotations: _,
2608        } => text.into(),
2609        context_server::types::MessageContent::Image {
2610            data,
2611            mime_type,
2612            annotations: _,
2613        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2614        context_server::types::MessageContent::Audio {
2615            data,
2616            mime_type,
2617            annotations: _,
2618        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2619        context_server::types::MessageContent::Resource {
2620            resource,
2621            annotations: _,
2622        } => {
2623            let mut link =
2624                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2625            if let Some(mime_type) = resource.mime_type {
2626                link = link.mime_type(mime_type);
2627            }
2628            acp::ContentBlock::ResourceLink(link)
2629        }
2630    }
2631}