agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17use itertools::Itertools;
  18pub use native_agent_server::NativeAgentServer;
  19pub use pattern_extraction::*;
  20pub use shell_command_parser::extract_commands;
  21pub use templates::*;
  22pub use thread::*;
  23pub use thread_store::*;
  24pub use tool_permissions::*;
  25pub use tools::*;
  26
  27use acp_thread::{
  28    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  29    AgentSessionListResponse, TokenUsageRatio, UserMessageId,
  30};
  31use agent_client_protocol as acp;
  32use anyhow::{Context as _, Result, anyhow};
  33use chrono::{DateTime, Utc};
  34use collections::{HashMap, HashSet, IndexMap};
  35use fs::Fs;
  36use futures::channel::{mpsc, oneshot};
  37use futures::future::Shared;
  38use futures::{FutureExt as _, StreamExt as _, future};
  39use gpui::{
  40    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  41};
  42use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  43use project::{Project, ProjectItem, ProjectPath, Worktree};
  44use prompt_store::{
  45    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  46    WorktreeContext,
  47};
  48use serde::{Deserialize, Serialize};
  49use settings::{LanguageModelSelection, update_settings_file};
  50use std::any::Any;
  51use std::path::{Path, PathBuf};
  52use std::rc::Rc;
  53use std::sync::Arc;
  54use util::ResultExt;
  55use util::path_list::PathList;
  56use util::rel_path::RelPath;
  57
/// A serializable snapshot of the project's worktrees, paired with a UTC timestamp.
///
/// NOTE(review): nothing in the visible part of this file constructs or reads this type;
/// presumably it is produced/consumed elsewhere in the crate — confirm before documenting
/// further semantics (e.g. when the timestamp is taken).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry-style snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Timestamp associated with this snapshot, in UTC.
    pub timestamp: DateTime<Utc>,
}
  63
/// Describes a rules file that could not be loaded.
///
/// `NativeAgent::load_worktree_info_for_system_prompt` returns one next to the worktree
/// context when loading that worktree's rules file fails.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the underlying error's `Display` output).
    pub message: SharedString,
}
  67
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// Task for the session's pending save; starts out as an already-completed task.
    /// NOTE(review): presumably replaced by `save_thread` (not visible in this chunk) — confirm.
    pending_save: Task<()>,
    /// Subscriptions on `thread` (title updates, token-usage updates, and a change observer
    /// that saves the thread); stored here so they live as long as the session does.
    _subscriptions: Vec<Subscription>,
}
  77
/// Cache of the language models offered by authenticated providers, together with the
/// grouped model list presented to ACP clients.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver end of the refresh signal; clones are handed out by `watch()`.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender end of the refresh signal; fired at the end of every `refresh_list` call.
    refresh_models_tx: watch::Sender<()>,
    /// Background task that authenticates every visible provider; held so it lives as long
    /// as this struct.
    _authenticate_all_providers_task: Task<()>,
}
  87
  88impl LanguageModels {
  89    fn new(cx: &mut App) -> Self {
  90        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  91
  92        let mut this = Self {
  93            models: HashMap::default(),
  94            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  95            refresh_models_rx,
  96            refresh_models_tx,
  97            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  98        };
  99        this.refresh_list(cx);
 100        this
 101    }
 102
 103    fn refresh_list(&mut self, cx: &App) {
 104        let providers = LanguageModelRegistry::global(cx)
 105            .read(cx)
 106            .visible_providers()
 107            .into_iter()
 108            .filter(|provider| provider.is_authenticated(cx))
 109            .collect::<Vec<_>>();
 110
 111        let mut language_model_list = IndexMap::default();
 112        let mut recommended_models = HashSet::default();
 113
 114        let mut recommended = Vec::new();
 115        for provider in &providers {
 116            for model in provider.recommended_models(cx) {
 117                recommended_models.insert((model.provider_id(), model.id()));
 118                recommended.push(Self::map_language_model_to_info(&model, provider));
 119            }
 120        }
 121        if !recommended.is_empty() {
 122            language_model_list.insert(
 123                acp_thread::AgentModelGroupName("Recommended".into()),
 124                recommended,
 125            );
 126        }
 127
 128        let mut models = HashMap::default();
 129        for provider in providers {
 130            let mut provider_models = Vec::new();
 131            for model in provider.provided_models(cx) {
 132                let model_info = Self::map_language_model_to_info(&model, &provider);
 133                let model_id = model_info.id.clone();
 134                provider_models.push(model_info);
 135                models.insert(model_id, model);
 136            }
 137            if !provider_models.is_empty() {
 138                language_model_list.insert(
 139                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 140                    provider_models,
 141                );
 142            }
 143        }
 144
 145        self.models = models;
 146        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 147        self.refresh_models_tx.send(()).ok();
 148    }
 149
 150    fn watch(&self) -> watch::Receiver<()> {
 151        self.refresh_models_rx.clone()
 152    }
 153
 154    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 155        self.models.get(model_id).cloned()
 156    }
 157
 158    fn map_language_model_to_info(
 159        model: &Arc<dyn LanguageModel>,
 160        provider: &Arc<dyn LanguageModelProvider>,
 161    ) -> acp_thread::AgentModelInfo {
 162        acp_thread::AgentModelInfo {
 163            id: Self::model_id(model),
 164            name: model.name().0,
 165            description: None,
 166            icon: Some(match provider.icon() {
 167                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 168                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 169            }),
 170            is_latest: model.is_latest(),
 171            cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
 172        }
 173    }
 174
 175    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 176        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 177    }
 178
 179    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 180        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 181            .read(cx)
 182            .visible_providers()
 183            .iter()
 184            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 185            .collect::<Vec<_>>();
 186
 187        cx.background_spawn(async move {
 188            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 189                if let Err(err) = authenticate_task.await {
 190                    match err {
 191                        language_model::AuthenticateError::CredentialsNotFound => {
 192                            // Since we're authenticating these providers in the
 193                            // background for the purposes of populating the
 194                            // language selector, we don't care about providers
 195                            // where the credentials are not found.
 196                        }
 197                        language_model::AuthenticateError::ConnectionRefused => {
 198                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 199                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 200                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 201                        }
 202                        _ => {
 203                            // Some providers have noisy failure states that we
 204                            // don't want to spam the logs with every time the
 205                            // language model selector is initialized.
 206                            //
 207                            // Ideally these should have more clear failure modes
 208                            // that we know are safe to ignore here, like what we do
 209                            // with `CredentialsNotFound` above.
 210                            match provider_id.0.as_ref() {
 211                                "lmstudio" | "ollama" => {
 212                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 213                                    //
 214                                    // These fail noisily, so we don't log them.
 215                                }
 216                                "copilot_chat" => {
 217                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 218                                }
 219                                _ => {
 220                                    log::error!(
 221                                        "Failed to authenticate provider: {}: {err:#}",
 222                                        provider_name.0
 223                                    );
 224                                }
 225                            }
 226                        }
 227                    }
 228                }
 229            }
 230        })
 231    }
 232}
 233
/// The native agent.
///
/// Owns every open session, each pairing an internal `Thread` with an `AcpThread`. It also
/// holds the shared state those threads are built from: project context, context-server
/// registry, templates and model cache.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store supplied at construction.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalled (by project, prompt-store and worktree events) to request a rebuild of
    /// `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background loop that rebuilds the project context whenever a refresh is requested.
    _maintain_project_context: Task<Result<()>>,
    /// Context-server registry; its prompts are exposed to sessions as available commands.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store providing the default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle supplied at construction.
    fs: Arc<dyn Fs>,
    /// Agent-level subscriptions (project, model registry, context servers, prompt store).
    _subscriptions: Vec<Subscription>,
}
 252
 253impl NativeAgent {
    /// Creates the agent for `project`.
    ///
    /// The initial [`ProjectContext`] is built and awaited before the entity is created. After
    /// that, a spawned task rebuilds it each time `project_context_needs_refresh` is signalled.
    /// The new agent subscribes to project, language-model-registry, context-server-store,
    /// context-server-registry and, when present, prompt-store events.
    ///
    /// The visible body always returns `Ok`.
    pub async fn new(
        project: Entity<Project>,
        thread_store: Entity<ThreadStore>,
        templates: Arc<Templates>,
        prompt_store: Option<Entity<PromptStore>>,
        fs: Arc<dyn Fs>,
        cx: &mut AsyncApp,
    ) -> Result<Entity<NativeAgent>> {
        log::debug!("Creating new NativeAgent");

        // Build the first project context up front so threads never see an empty one.
        let project_context = cx
            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
            .await;

        Ok(cx.new(|cx| {
            let context_server_store = project.read(cx).context_server_store();
            let context_server_registry =
                cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));

            // Each handler keeps part of the agent's state in sync with its source.
            let mut subscriptions = vec![
                cx.subscribe(&project, Self::handle_project_event),
                cx.subscribe(
                    &LanguageModelRegistry::global(cx),
                    Self::handle_models_updated_event,
                ),
                cx.subscribe(
                    &context_server_store,
                    Self::handle_context_server_store_updated,
                ),
                cx.subscribe(
                    &context_server_registry,
                    Self::handle_context_server_registry_event,
                ),
            ];
            if let Some(prompt_store) = prompt_store.as_ref() {
                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
            }

            // The sender is kept on the agent; the receiver drives the rebuild loop.
            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
                watch::channel(());
            Self {
                sessions: HashMap::default(),
                thread_store,
                project_context: cx.new(|_| project_context),
                project_context_needs_refresh: project_context_needs_refresh_tx,
                _maintain_project_context: cx.spawn(async move |this, cx| {
                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
                }),
                context_server_registry,
                templates,
                models: LanguageModels::new(cx),
                project,
                prompt_store,
                fs,
                _subscriptions: subscriptions,
            }
        }))
    }
 312
 313    fn new_session(
 314        &mut self,
 315        project: Entity<Project>,
 316        cx: &mut Context<Self>,
 317    ) -> Entity<AcpThread> {
 318        // Create Thread
 319        // Fetch default model from registry settings
 320        let registry = LanguageModelRegistry::read_global(cx);
 321        // Log available models for debugging
 322        let available_count = registry.available_models(cx).count();
 323        log::debug!("Total available models: {}", available_count);
 324
 325        let default_model = registry.default_model().and_then(|default_model| {
 326            self.models
 327                .model_from_id(&LanguageModels::model_id(&default_model.model))
 328        });
 329        let thread = cx.new(|cx| {
 330            Thread::new(
 331                project.clone(),
 332                self.project_context.clone(),
 333                self.context_server_registry.clone(),
 334                self.templates.clone(),
 335                default_model,
 336                cx,
 337            )
 338        });
 339
 340        self.register_session(thread, cx)
 341    }
 342
    /// Wraps an existing [`Thread`] in a new [`AcpThread`] and stores the pair as a session
    /// keyed by the thread's id. Returns the new `AcpThread`.
    ///
    /// Besides creating the session, it:
    /// - sets the thread's summarization model from the registry;
    /// - installs the default tools, backed by a [`NativeThreadEnvironment`];
    /// - subscribes to the thread's title and token-usage events, and saves the thread
    ///   whenever it notifies;
    /// - pushes the current command list to every session.
    ///
    /// NOTE(review): an existing session with the same id is overwritten (`HashMap::insert`),
    /// which drops its subscriptions.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        // Copy everything the AcpThread needs out of the thread first; `thread` borrows `cx`,
        // which `cx.new` below needs mutably.
        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let parent_session_id = thread.parent_thread_id();
        let title = thread.title();
        let draft_prompt = thread.draft_prompt().map(Vec::from);
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            let mut acp_thread = acp_thread::AcpThread::new(
                parent_session_id,
                title,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            );
            acp_thread.set_draft_prompt(draft_prompt);
            acp_thread
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        // The tool environment only holds weak handles back to the agent and both threads.
        let weak = cx.weak_entity();
        let weak_thread = thread_handle.downgrade();
        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            thread.add_default_tools(
                Rc::new(NativeThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                    thread: weak_thread,
                    agent: weak,
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            // Persist the thread every time it notifies.
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.clone(),
                _subscriptions: subscriptions,
                pending_save: Task::ready(()),
            },
        );

        self.update_available_commands(cx);

        acp_thread
    }
 412
    /// Returns the agent's cached language-model information.
    pub fn models(&self) -> &LanguageModels {
        &self.models
    }
 416
 417    async fn maintain_project_context(
 418        this: WeakEntity<Self>,
 419        mut needs_refresh: watch::Receiver<()>,
 420        cx: &mut AsyncApp,
 421    ) -> Result<()> {
 422        while needs_refresh.changed().await.is_ok() {
 423            let project_context = this
 424                .update(cx, |this, cx| {
 425                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 426                })?
 427                .await;
 428            this.update(cx, |this, cx| {
 429                this.project_context = cx.new(|_| project_context);
 430            })?;
 431        }
 432
 433        Ok(())
 434    }
 435
 436    fn build_project_context(
 437        project: &Entity<Project>,
 438        prompt_store: Option<&Entity<PromptStore>>,
 439        cx: &mut App,
 440    ) -> Task<ProjectContext> {
 441        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 442        let worktree_tasks = worktrees
 443            .into_iter()
 444            .map(|worktree| {
 445                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 446            })
 447            .collect::<Vec<_>>();
 448        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 449            prompt_store.read_with(cx, |prompt_store, cx| {
 450                let prompts = prompt_store.default_prompt_metadata();
 451                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 452                    let contents = prompt_store.load(prompt_metadata.id, cx);
 453                    async move { (contents.await, prompt_metadata) }
 454                });
 455                cx.background_spawn(future::join_all(load_tasks))
 456            })
 457        } else {
 458            Task::ready(vec![])
 459        };
 460
 461        cx.spawn(async move |_cx| {
 462            let (worktrees, default_user_rules) =
 463                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 464
 465            let worktrees = worktrees
 466                .into_iter()
 467                .map(|(worktree, _rules_error)| {
 468                    // TODO: show error message
 469                    // if let Some(rules_error) = rules_error {
 470                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 471                    // }
 472                    worktree
 473                })
 474                .collect::<Vec<_>>();
 475
 476            let default_user_rules = default_user_rules
 477                .into_iter()
 478                .flat_map(|(contents, prompt_metadata)| match contents {
 479                    Ok(contents) => Some(UserRulesContext {
 480                        uuid: prompt_metadata.id.as_user()?,
 481                        title: prompt_metadata.title.map(|title| title.to_string()),
 482                        contents,
 483                    }),
 484                    Err(_err) => {
 485                        // TODO: show error message
 486                        // this.update(cx, |_, cx| {
 487                        //     cx.emit(RulesLoadingError {
 488                        //         message: format!("{err:?}").into(),
 489                        //     });
 490                        // })
 491                        // .ok();
 492                        None
 493                    }
 494                })
 495                .collect::<Vec<_>>();
 496
 497            ProjectContext::new(worktrees, default_user_rules)
 498        })
 499    }
 500
 501    fn load_worktree_info_for_system_prompt(
 502        worktree: Entity<Worktree>,
 503        project: Entity<Project>,
 504        cx: &mut App,
 505    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 506        let tree = worktree.read(cx);
 507        let root_name = tree.root_name_str().into();
 508        let abs_path = tree.abs_path();
 509
 510        let mut context = WorktreeContext {
 511            root_name,
 512            abs_path,
 513            rules_file: None,
 514        };
 515
 516        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 517        let Some(rules_task) = rules_task else {
 518            return Task::ready((context, None));
 519        };
 520
 521        cx.spawn(async move |_| {
 522            let (rules_file, rules_file_error) = match rules_task.await {
 523                Ok(rules_file) => (Some(rules_file), None),
 524                Err(err) => (
 525                    None,
 526                    Some(RulesLoadingError {
 527                        message: format!("{err}").into(),
 528                    }),
 529                ),
 530            };
 531            context.rules_file = rules_file;
 532            (context, rules_file_error)
 533        })
 534    }
 535
 536    fn load_worktree_rules_file(
 537        worktree: Entity<Worktree>,
 538        project: Entity<Project>,
 539        cx: &mut App,
 540    ) -> Option<Task<Result<RulesFileContext>>> {
 541        let worktree = worktree.read(cx);
 542        let worktree_id = worktree.id();
 543        let selected_rules_file = RULES_FILE_NAMES
 544            .into_iter()
 545            .filter_map(|name| {
 546                worktree
 547                    .entry_for_path(RelPath::unix(name).unwrap())
 548                    .filter(|entry| entry.is_file())
 549                    .map(|entry| entry.path.clone())
 550            })
 551            .next();
 552
 553        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 554        // supported. This doesn't seem to occur often in GitHub repositories.
 555        selected_rules_file.map(|path_in_worktree| {
 556            let project_path = ProjectPath {
 557                worktree_id,
 558                path: path_in_worktree.clone(),
 559            };
 560            let buffer_task =
 561                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 562            let rope_task = cx.spawn(async move |cx| {
 563                let buffer = buffer_task.await?;
 564                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 565                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 566                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 567                })?;
 568                anyhow::Ok((project_entry_id, rope))
 569            });
 570            // Build a string from the rope on a background thread.
 571            cx.background_spawn(async move {
 572                let (project_entry_id, rope) = rope_task.await?;
 573                anyhow::Ok(RulesFileContext {
 574                    path_in_worktree,
 575                    text: rope.to_string().trim().to_string(),
 576                    project_entry_id: project_entry_id.to_usize(),
 577                })
 578            })
 579        })
 580    }
 581
 582    fn handle_thread_title_updated(
 583        &mut self,
 584        thread: Entity<Thread>,
 585        _: &TitleUpdated,
 586        cx: &mut Context<Self>,
 587    ) {
 588        let session_id = thread.read(cx).id();
 589        let Some(session) = self.sessions.get(session_id) else {
 590            return;
 591        };
 592        let thread = thread.downgrade();
 593        let acp_thread = session.acp_thread.downgrade();
 594        cx.spawn(async move |_, cx| {
 595            let title = thread.read_with(cx, |thread, _| thread.title())?;
 596            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 597            task.await
 598        })
 599        .detach_and_log_err(cx);
 600    }
 601
 602    fn handle_thread_token_usage_updated(
 603        &mut self,
 604        thread: Entity<Thread>,
 605        usage: &TokenUsageUpdated,
 606        cx: &mut Context<Self>,
 607    ) {
 608        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 609            return;
 610        };
 611        session.acp_thread.update(cx, |acp_thread, cx| {
 612            acp_thread.update_token_usage(usage.0.clone(), cx);
 613        });
 614    }
 615
 616    fn handle_project_event(
 617        &mut self,
 618        _project: Entity<Project>,
 619        event: &project::Event,
 620        _cx: &mut Context<Self>,
 621    ) {
 622        match event {
 623            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 624                self.project_context_needs_refresh.send(()).ok();
 625            }
 626            project::Event::WorktreeUpdatedEntries(_, items) => {
 627                if items.iter().any(|(path, _, _)| {
 628                    RULES_FILE_NAMES
 629                        .iter()
 630                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 631                }) {
 632                    self.project_context_needs_refresh.send(()).ok();
 633                }
 634            }
 635            _ => {}
 636        }
 637    }
 638
    /// Requests a project-context rebuild whenever the prompt store reports a change, since the
    /// default user rules come from the prompt store.
    fn handle_prompts_updated_event(
        &mut self,
        _prompt_store: Entity<PromptStore>,
        _event: &prompt_store::PromptsUpdatedEvent,
        _cx: &mut Context<Self>,
    ) {
        self.project_context_needs_refresh.send(()).ok();
    }
 647
 648    fn handle_models_updated_event(
 649        &mut self,
 650        _registry: Entity<LanguageModelRegistry>,
 651        _event: &language_model::Event,
 652        cx: &mut Context<Self>,
 653    ) {
 654        self.models.refresh_list(cx);
 655
 656        let registry = LanguageModelRegistry::read_global(cx);
 657        let default_model = registry.default_model().map(|m| m.model);
 658        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 659
 660        for session in self.sessions.values_mut() {
 661            session.thread.update(cx, |thread, cx| {
 662                if thread.model().is_none()
 663                    && let Some(model) = default_model.clone()
 664                {
 665                    thread.set_model(model, cx);
 666                    cx.notify();
 667                }
 668                thread.set_summarization_model(summarization_model.clone(), cx);
 669            });
 670        }
 671    }
 672
    /// Re-publishes the available commands to every session when a context server's status
    /// changes.
    fn handle_context_server_store_updated(
        &mut self,
        _store: Entity<project::context_server_store::ContextServerStore>,
        _event: &project::context_server_store::ServerStatusChangedEvent,
        cx: &mut Context<Self>,
    ) {
        self.update_available_commands(cx);
    }
 681
    /// Re-publishes the available commands when the registry's prompts change.
    ///
    /// Tool changes need no action in this handler.
    fn handle_context_server_registry_event(
        &mut self,
        _registry: Entity<ContextServerRegistry>,
        event: &ContextServerRegistryEvent,
        cx: &mut Context<Self>,
    ) {
        // Exhaustive on purpose: a new event variant should force a decision here.
        match event {
            ContextServerRegistryEvent::ToolsChanged => {}
            ContextServerRegistryEvent::PromptsChanged => {
                self.update_available_commands(cx);
            }
        }
    }
 695
 696    fn update_available_commands(&self, cx: &mut Context<Self>) {
 697        let available_commands = self.build_available_commands(cx);
 698        for session in self.sessions.values() {
 699            session.acp_thread.update(cx, |thread, cx| {
 700                thread
 701                    .handle_session_update(
 702                        acp::SessionUpdate::AvailableCommandsUpdate(
 703                            acp::AvailableCommandsUpdate::new(available_commands.clone()),
 704                        ),
 705                        cx,
 706                    )
 707                    .log_err();
 708            });
 709        }
 710    }
 711
 712    fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
 713        let registry = self.context_server_registry.read(cx);
 714
 715        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 716        for context_server_prompt in registry.prompts() {
 717            *prompt_name_counts
 718                .entry(context_server_prompt.prompt.name.as_str())
 719                .or_insert(0) += 1;
 720        }
 721
 722        registry
 723            .prompts()
 724            .flat_map(|context_server_prompt| {
 725                let prompt = &context_server_prompt.prompt;
 726
 727                let should_prefix = prompt_name_counts
 728                    .get(prompt.name.as_str())
 729                    .copied()
 730                    .unwrap_or(0)
 731                    > 1;
 732
 733                let name = if should_prefix {
 734                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 735                } else {
 736                    prompt.name.clone()
 737                };
 738
 739                let mut command = acp::AvailableCommand::new(
 740                    name,
 741                    prompt.description.clone().unwrap_or_default(),
 742                );
 743
 744                match prompt.arguments.as_deref() {
 745                    Some([arg]) => {
 746                        let hint = format!("<{}>", arg.name);
 747
 748                        command = command.input(acp::AvailableCommandInput::Unstructured(
 749                            acp::UnstructuredCommandInput::new(hint),
 750                        ));
 751                    }
 752                    Some([]) | None => {}
 753                    Some(_) => {
 754                        // skip >1 argument commands since we don't support them yet
 755                        return None;
 756                    }
 757                }
 758
 759                Some(command)
 760            })
 761            .collect()
 762    }
 763
    /// Loads the thread with `id` from the threads database and creates a [`Thread`] entity.
    ///
    /// The thread is bound to this agent's project, project context, context-server registry
    /// and templates, and uses the registry's current thread-summary model. It is *not*
    /// registered as a session; `open_thread` does that.
    ///
    /// # Errors
    /// The task fails if:
    /// - the database cannot be opened;
    /// - loading from the database fails;
    /// - no thread with `id` exists;
    /// - the agent is released before loading finishes.
    pub fn load_thread(
        &mut self,
        id: acp::SessionId,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Thread>>> {
        let database_future = ThreadsDatabase::connect(cx);
        cx.spawn(async move |this, cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            // `load_thread` yields `Ok(None)` when the id is unknown; turn that into an error.
            let db_thread = database
                .load_thread(id.clone())
                .await?
                .with_context(|| format!("no thread found with ID: {id:?}"))?;

            this.update(cx, |this, cx| {
                let summarization_model = LanguageModelRegistry::read_global(cx)
                    .thread_summary_model()
                    .map(|c| c.model);

                cx.new(|cx| {
                    let mut thread = Thread::from_db(
                        id.clone(),
                        db_thread,
                        this.project.clone(),
                        this.project_context.clone(),
                        this.context_server_registry.clone(),
                        this.templates.clone(),
                        cx,
                    );
                    thread.set_summarization_model(summarization_model, cx);
                    thread
                })
            })
        })
    }
 798
 799    pub fn open_thread(
 800        &mut self,
 801        id: acp::SessionId,
 802        cx: &mut Context<Self>,
 803    ) -> Task<Result<Entity<AcpThread>>> {
 804        if let Some(session) = self.sessions.get(&id) {
 805            return Task::ready(Ok(session.acp_thread.clone()));
 806        }
 807
 808        let task = self.load_thread(id, cx);
 809        cx.spawn(async move |this, cx| {
 810            let thread = task.await?;
 811            let acp_thread =
 812                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
 813            let events = thread.update(cx, |thread, cx| thread.replay(cx));
 814            cx.update(|cx| {
 815                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 816            })
 817            .await?;
 818            Ok(acp_thread)
 819        })
 820    }
 821
 822    pub fn thread_summary(
 823        &mut self,
 824        id: acp::SessionId,
 825        cx: &mut Context<Self>,
 826    ) -> Task<Result<SharedString>> {
 827        let thread = self.open_thread(id.clone(), cx);
 828        cx.spawn(async move |this, cx| {
 829            let acp_thread = thread.await?;
 830            let result = this
 831                .update(cx, |this, cx| {
 832                    this.sessions
 833                        .get(&id)
 834                        .unwrap()
 835                        .thread
 836                        .update(cx, |thread, cx| thread.summary(cx))
 837                })?
 838                .await
 839                .context("Failed to generate summary")?;
 840            drop(acp_thread);
 841            Ok(result)
 842        })
 843    }
 844
 845    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 846        if thread.read(cx).is_empty() {
 847            return;
 848        }
 849
 850        let id = thread.read(cx).id().clone();
 851        let Some(session) = self.sessions.get_mut(&id) else {
 852            return;
 853        };
 854
 855        let folder_paths = PathList::new(
 856            &self
 857                .project
 858                .read(cx)
 859                .visible_worktrees(cx)
 860                .map(|worktree| worktree.read(cx).abs_path().to_path_buf())
 861                .collect::<Vec<_>>(),
 862        );
 863
 864        let draft_prompt = session.acp_thread.read(cx).draft_prompt().map(Vec::from);
 865        let database_future = ThreadsDatabase::connect(cx);
 866        let db_thread = thread.update(cx, |thread, cx| {
 867            thread.set_draft_prompt(draft_prompt);
 868            thread.to_db(cx)
 869        });
 870        let thread_store = self.thread_store.clone();
 871        session.pending_save = cx.spawn(async move |_, cx| {
 872            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 873                return;
 874            };
 875            let db_thread = db_thread.await;
 876            database
 877                .save_thread(id, db_thread, folder_paths)
 878                .await
 879                .log_err();
 880            thread_store.update(cx, |store, cx| store.reload(cx));
 881        });
 882    }
 883
 884    fn send_mcp_prompt(
 885        &self,
 886        message_id: UserMessageId,
 887        session_id: agent_client_protocol::SessionId,
 888        prompt_name: String,
 889        server_id: ContextServerId,
 890        arguments: HashMap<String, String>,
 891        original_content: Vec<acp::ContentBlock>,
 892        cx: &mut Context<Self>,
 893    ) -> Task<Result<acp::PromptResponse>> {
 894        let server_store = self.context_server_registry.read(cx).server_store().clone();
 895        let path_style = self.project.read(cx).path_style(cx);
 896
 897        cx.spawn(async move |this, cx| {
 898            let prompt =
 899                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
 900
 901            let (acp_thread, thread) = this.update(cx, |this, _cx| {
 902                let session = this
 903                    .sessions
 904                    .get(&session_id)
 905                    .context("Failed to get session")?;
 906                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
 907            })??;
 908
 909            let mut last_is_user = true;
 910
 911            thread.update(cx, |thread, cx| {
 912                thread.push_acp_user_block(
 913                    message_id,
 914                    original_content.into_iter().skip(1),
 915                    path_style,
 916                    cx,
 917                );
 918            });
 919
 920            for message in prompt.messages {
 921                let context_server::types::PromptMessage { role, content } = message;
 922                let block = mcp_message_content_to_acp_content_block(content);
 923
 924                match role {
 925                    context_server::types::Role::User => {
 926                        let id = acp_thread::UserMessageId::new();
 927
 928                        acp_thread.update(cx, |acp_thread, cx| {
 929                            acp_thread.push_user_content_block_with_indent(
 930                                Some(id.clone()),
 931                                block.clone(),
 932                                true,
 933                                cx,
 934                            );
 935                        });
 936
 937                        thread.update(cx, |thread, cx| {
 938                            thread.push_acp_user_block(id, [block], path_style, cx);
 939                        });
 940                    }
 941                    context_server::types::Role::Assistant => {
 942                        acp_thread.update(cx, |acp_thread, cx| {
 943                            acp_thread.push_assistant_content_block_with_indent(
 944                                block.clone(),
 945                                false,
 946                                true,
 947                                cx,
 948                            );
 949                        });
 950
 951                        thread.update(cx, |thread, cx| {
 952                            thread.push_acp_agent_block(block, cx);
 953                        });
 954                    }
 955                }
 956
 957                last_is_user = role == context_server::types::Role::User;
 958            }
 959
 960            let response_stream = thread.update(cx, |thread, cx| {
 961                if last_is_user {
 962                    thread.send_existing(cx)
 963                } else {
 964                    // Resume if MCP prompt did not end with a user message
 965                    thread.resume(cx)
 966                }
 967            })?;
 968
 969            cx.update(|cx| {
 970                NativeAgentConnection::handle_thread_events(
 971                    response_stream,
 972                    acp_thread.downgrade(),
 973                    cx,
 974                )
 975            })
 976            .await
 977        })
 978    }
 979}
 980
 981/// Wrapper struct that implements the AgentConnection trait
 982#[derive(Clone)]
 983pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 984
 985impl NativeAgentConnection {
 986    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 987        self.0
 988            .read(cx)
 989            .sessions
 990            .get(session_id)
 991            .map(|session| session.thread.clone())
 992    }
 993
 994    pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
 995        self.0.update(cx, |this, cx| this.load_thread(id, cx))
 996    }
 997
 998    fn run_turn(
 999        &self,
1000        session_id: acp::SessionId,
1001        cx: &mut App,
1002        f: impl 'static
1003        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
1004    ) -> Task<Result<acp::PromptResponse>> {
1005        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
1006            agent
1007                .sessions
1008                .get_mut(&session_id)
1009                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
1010        }) else {
1011            return Task::ready(Err(anyhow!("Session not found")));
1012        };
1013        log::debug!("Found session for: {}", session_id);
1014
1015        let response_stream = match f(thread, cx) {
1016            Ok(stream) => stream,
1017            Err(err) => return Task::ready(Err(err)),
1018        };
1019        Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1020    }
1021
1022    fn handle_thread_events(
1023        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1024        acp_thread: WeakEntity<AcpThread>,
1025        cx: &App,
1026    ) -> Task<Result<acp::PromptResponse>> {
1027        cx.spawn(async move |cx| {
1028            // Handle response stream and forward to session.acp_thread
1029            while let Some(result) = events.next().await {
1030                match result {
1031                    Ok(event) => {
1032                        log::trace!("Received completion event: {:?}", event);
1033
1034                        match event {
1035                            ThreadEvent::UserMessage(message) => {
1036                                acp_thread.update(cx, |thread, cx| {
1037                                    for content in message.content {
1038                                        thread.push_user_content_block(
1039                                            Some(message.id.clone()),
1040                                            content.into(),
1041                                            cx,
1042                                        );
1043                                    }
1044                                })?;
1045                            }
1046                            ThreadEvent::AgentText(text) => {
1047                                acp_thread.update(cx, |thread, cx| {
1048                                    thread.push_assistant_content_block(text.into(), false, cx)
1049                                })?;
1050                            }
1051                            ThreadEvent::AgentThinking(text) => {
1052                                acp_thread.update(cx, |thread, cx| {
1053                                    thread.push_assistant_content_block(text.into(), true, cx)
1054                                })?;
1055                            }
1056                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1057                                tool_call,
1058                                options,
1059                                response,
1060                                context: _,
1061                            }) => {
1062                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1063                                    thread.request_tool_call_authorization(tool_call, options, cx)
1064                                })??;
1065                                cx.background_spawn(async move {
1066                                    if let acp::RequestPermissionOutcome::Selected(
1067                                        acp::SelectedPermissionOutcome { option_id, .. },
1068                                    ) = outcome_task.await
1069                                    {
1070                                        response
1071                                            .send(option_id)
1072                                            .map(|_| anyhow!("authorization receiver was dropped"))
1073                                            .log_err();
1074                                    }
1075                                })
1076                                .detach();
1077                            }
1078                            ThreadEvent::ToolCall(tool_call) => {
1079                                acp_thread.update(cx, |thread, cx| {
1080                                    thread.upsert_tool_call(tool_call, cx)
1081                                })??;
1082                            }
1083                            ThreadEvent::ToolCallUpdate(update) => {
1084                                acp_thread.update(cx, |thread, cx| {
1085                                    thread.update_tool_call(update, cx)
1086                                })??;
1087                            }
1088                            ThreadEvent::SubagentSpawned(session_id) => {
1089                                acp_thread.update(cx, |thread, cx| {
1090                                    thread.subagent_spawned(session_id, cx);
1091                                })?;
1092                            }
1093                            ThreadEvent::Retry(status) => {
1094                                acp_thread.update(cx, |thread, cx| {
1095                                    thread.update_retry_status(status, cx)
1096                                })?;
1097                            }
1098                            ThreadEvent::Stop(stop_reason) => {
1099                                log::debug!("Assistant message complete: {:?}", stop_reason);
1100                                return Ok(acp::PromptResponse::new(stop_reason));
1101                            }
1102                        }
1103                    }
1104                    Err(e) => {
1105                        log::error!("Error in model response stream: {:?}", e);
1106                        return Err(e);
1107                    }
1108                }
1109            }
1110
1111            log::debug!("Response stream completed");
1112            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1113        })
1114    }
1115}
1116
1117struct Command<'a> {
1118    prompt_name: &'a str,
1119    arg_value: &'a str,
1120    explicit_server_id: Option<&'a str>,
1121}
1122
1123impl<'a> Command<'a> {
1124    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1125        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1126            return None;
1127        };
1128        let text = text_content.text.trim();
1129        let command = text.strip_prefix('/')?;
1130        let (command, arg_value) = command
1131            .split_once(char::is_whitespace)
1132            .unwrap_or((command, ""));
1133
1134        if let Some((server_id, prompt_name)) = command.split_once('.') {
1135            Some(Self {
1136                prompt_name,
1137                arg_value,
1138                explicit_server_id: Some(server_id),
1139            })
1140        } else {
1141            Some(Self {
1142                prompt_name: command,
1143                arg_value,
1144                explicit_server_id: None,
1145            })
1146        }
1147    }
1148}
1149
1150struct NativeAgentModelSelector {
1151    session_id: acp::SessionId,
1152    connection: NativeAgentConnection,
1153}
1154
1155impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1156    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1157        log::debug!("NativeAgentConnection::list_models called");
1158        let list = self.connection.0.read(cx).models.model_list.clone();
1159        Task::ready(if list.is_empty() {
1160            Err(anyhow::anyhow!("No models available"))
1161        } else {
1162            Ok(list)
1163        })
1164    }
1165
1166    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1167        log::debug!(
1168            "Setting model for session {}: {}",
1169            self.session_id,
1170            model_id
1171        );
1172        let Some(thread) = self
1173            .connection
1174            .0
1175            .read(cx)
1176            .sessions
1177            .get(&self.session_id)
1178            .map(|session| session.thread.clone())
1179        else {
1180            return Task::ready(Err(anyhow!("Session not found")));
1181        };
1182
1183        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1184            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1185        };
1186
1187        // We want to reset the effort level when switching models, as the currently-selected effort level may
1188        // not be compatible.
1189        let effort = model
1190            .default_effort_level()
1191            .map(|effort_level| effort_level.value.to_string());
1192
1193        thread.update(cx, |thread, cx| {
1194            thread.set_model(model.clone(), cx);
1195            thread.set_thinking_effort(effort.clone(), cx);
1196            thread.set_thinking_enabled(model.supports_thinking(), cx);
1197        });
1198
1199        update_settings_file(
1200            self.connection.0.read(cx).fs.clone(),
1201            cx,
1202            move |settings, cx| {
1203                let provider = model.provider_id().0.to_string();
1204                let model = model.id().0.to_string();
1205                let enable_thinking = thread.read(cx).thinking_enabled();
1206                settings
1207                    .agent
1208                    .get_or_insert_default()
1209                    .set_model(LanguageModelSelection {
1210                        provider: provider.into(),
1211                        model,
1212                        enable_thinking,
1213                        effort,
1214                    });
1215            },
1216        );
1217
1218        Task::ready(Ok(()))
1219    }
1220
1221    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1222        let Some(thread) = self
1223            .connection
1224            .0
1225            .read(cx)
1226            .sessions
1227            .get(&self.session_id)
1228            .map(|session| session.thread.clone())
1229        else {
1230            return Task::ready(Err(anyhow!("Session not found")));
1231        };
1232        let Some(model) = thread.read(cx).model() else {
1233            return Task::ready(Err(anyhow!("Model not found")));
1234        };
1235        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1236        else {
1237            return Task::ready(Err(anyhow!("Provider not found")));
1238        };
1239        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1240            model, &provider,
1241        )))
1242    }
1243
1244    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1245        Some(self.connection.0.read(cx).models.watch())
1246    }
1247
1248    fn should_render_footer(&self) -> bool {
1249        true
1250    }
1251}
1252
1253impl acp_thread::AgentConnection for NativeAgentConnection {
1254    fn telemetry_id(&self) -> SharedString {
1255        "zed".into()
1256    }
1257
1258    fn new_session(
1259        self: Rc<Self>,
1260        project: Entity<Project>,
1261        cwd: &Path,
1262        cx: &mut App,
1263    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1264        log::debug!("Creating new thread for project at: {cwd:?}");
1265        Task::ready(Ok(self
1266            .0
1267            .update(cx, |agent, cx| agent.new_session(project, cx))))
1268    }
1269
1270    fn supports_load_session(&self) -> bool {
1271        true
1272    }
1273
1274    fn load_session(
1275        self: Rc<Self>,
1276        session: AgentSessionInfo,
1277        _project: Entity<Project>,
1278        _cwd: &Path,
1279        cx: &mut App,
1280    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1281        self.0
1282            .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1283    }
1284
1285    fn supports_close_session(&self) -> bool {
1286        true
1287    }
1288
1289    fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1290        self.0.update(cx, |agent, _cx| {
1291            agent.sessions.remove(session_id);
1292        });
1293        Task::ready(Ok(()))
1294    }
1295
1296    fn auth_methods(&self) -> &[acp::AuthMethod] {
1297        &[] // No auth for in-process
1298    }
1299
1300    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1301        Task::ready(Ok(()))
1302    }
1303
1304    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1305        Some(Rc::new(NativeAgentModelSelector {
1306            session_id: session_id.clone(),
1307            connection: self.clone(),
1308        }) as Rc<dyn AgentModelSelector>)
1309    }
1310
1311    fn prompt(
1312        &self,
1313        id: Option<acp_thread::UserMessageId>,
1314        params: acp::PromptRequest,
1315        cx: &mut App,
1316    ) -> Task<Result<acp::PromptResponse>> {
1317        let id = id.expect("UserMessageId is required");
1318        let session_id = params.session_id.clone();
1319        log::info!("Received prompt request for session: {}", session_id);
1320        log::debug!("Prompt blocks count: {}", params.prompt.len());
1321
1322        if let Some(parsed_command) = Command::parse(&params.prompt) {
1323            let registry = self.0.read(cx).context_server_registry.read(cx);
1324
1325            let explicit_server_id = parsed_command
1326                .explicit_server_id
1327                .map(|server_id| ContextServerId(server_id.into()));
1328
1329            if let Some(prompt) =
1330                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1331            {
1332                let arguments = if !parsed_command.arg_value.is_empty()
1333                    && let Some(arg_name) = prompt
1334                        .prompt
1335                        .arguments
1336                        .as_ref()
1337                        .and_then(|args| args.first())
1338                        .map(|arg| arg.name.clone())
1339                {
1340                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1341                } else {
1342                    Default::default()
1343                };
1344
1345                let prompt_name = prompt.prompt.name.clone();
1346                let server_id = prompt.server_id.clone();
1347
1348                return self.0.update(cx, |agent, cx| {
1349                    agent.send_mcp_prompt(
1350                        id,
1351                        session_id.clone(),
1352                        prompt_name,
1353                        server_id,
1354                        arguments,
1355                        params.prompt,
1356                        cx,
1357                    )
1358                });
1359            };
1360        };
1361
1362        let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1363
1364        self.run_turn(session_id, cx, move |thread, cx| {
1365            let content: Vec<UserMessageContent> = params
1366                .prompt
1367                .into_iter()
1368                .map(|block| UserMessageContent::from_content_block(block, path_style))
1369                .collect::<Vec<_>>();
1370            log::debug!("Converted prompt to message: {} chars", content.len());
1371            log::debug!("Message id: {:?}", id);
1372            log::debug!("Message content: {:?}", content);
1373
1374            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1375        })
1376    }
1377
1378    fn retry(
1379        &self,
1380        session_id: &acp::SessionId,
1381        _cx: &App,
1382    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1383        Some(Rc::new(NativeAgentSessionRetry {
1384            connection: self.clone(),
1385            session_id: session_id.clone(),
1386        }) as _)
1387    }
1388
1389    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1390        log::info!("Cancelling on session: {}", session_id);
1391        self.0.update(cx, |agent, cx| {
1392            if let Some(session) = agent.sessions.get(session_id) {
1393                session
1394                    .thread
1395                    .update(cx, |thread, cx| thread.cancel(cx))
1396                    .detach();
1397            }
1398        });
1399    }
1400
1401    fn truncate(
1402        &self,
1403        session_id: &agent_client_protocol::SessionId,
1404        cx: &App,
1405    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1406        self.0.read_with(cx, |agent, _cx| {
1407            agent.sessions.get(session_id).map(|session| {
1408                Rc::new(NativeAgentSessionTruncate {
1409                    thread: session.thread.clone(),
1410                    acp_thread: session.acp_thread.downgrade(),
1411                }) as _
1412            })
1413        })
1414    }
1415
1416    fn set_title(
1417        &self,
1418        session_id: &acp::SessionId,
1419        cx: &App,
1420    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1421        self.0.read_with(cx, |agent, _cx| {
1422            agent
1423                .sessions
1424                .get(session_id)
1425                .filter(|s| !s.thread.read(cx).is_subagent())
1426                .map(|session| {
1427                    Rc::new(NativeAgentSessionSetTitle {
1428                        thread: session.thread.clone(),
1429                    }) as _
1430                })
1431        })
1432    }
1433
1434    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1435        let thread_store = self.0.read(cx).thread_store.clone();
1436        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1437    }
1438
1439    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1440        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1441    }
1442
1443    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1444        self
1445    }
1446}
1447
1448impl acp_thread::AgentTelemetry for NativeAgentConnection {
1449    fn thread_data(
1450        &self,
1451        session_id: &acp::SessionId,
1452        cx: &mut App,
1453    ) -> Task<Result<serde_json::Value>> {
1454        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1455            return Task::ready(Err(anyhow!("Session not found")));
1456        };
1457
1458        let task = session.thread.read(cx).to_db(cx);
1459        cx.background_spawn(async move {
1460            serde_json::to_value(task.await).context("Failed to serialize thread")
1461        })
1462    }
1463}
1464
1465pub struct NativeAgentSessionList {
1466    thread_store: Entity<ThreadStore>,
1467    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1468    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1469    _subscription: Subscription,
1470}
1471
1472impl NativeAgentSessionList {
1473    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1474        let (tx, rx) = smol::channel::unbounded();
1475        let this_tx = tx.clone();
1476        let subscription = cx.observe(&thread_store, move |_, _| {
1477            this_tx
1478                .try_send(acp_thread::SessionListUpdate::Refresh)
1479                .ok();
1480        });
1481        Self {
1482            thread_store,
1483            updates_tx: tx,
1484            updates_rx: rx,
1485            _subscription: subscription,
1486        }
1487    }
1488
1489    fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1490        AgentSessionInfo {
1491            session_id: entry.id,
1492            cwd: None,
1493            title: Some(entry.title),
1494            updated_at: Some(entry.updated_at),
1495            meta: None,
1496        }
1497    }
1498
1499    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1500        &self.thread_store
1501    }
1502}
1503
1504impl AgentSessionList for NativeAgentSessionList {
1505    fn list_sessions(
1506        &self,
1507        _request: AgentSessionListRequest,
1508        cx: &mut App,
1509    ) -> Task<Result<AgentSessionListResponse>> {
1510        let sessions = self
1511            .thread_store
1512            .read(cx)
1513            .entries()
1514            .map(Self::to_session_info)
1515            .collect();
1516        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1517    }
1518
1519    fn supports_delete(&self) -> bool {
1520        true
1521    }
1522
1523    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1524        self.thread_store
1525            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1526    }
1527
1528    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1529        self.thread_store
1530            .update(cx, |store, cx| store.delete_threads(cx))
1531    }
1532
1533    fn watch(
1534        &self,
1535        _cx: &mut App,
1536    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1537        Some(self.updates_rx.clone())
1538    }
1539
1540    fn notify_refresh(&self) {
1541        self.updates_tx
1542            .try_send(acp_thread::SessionListUpdate::Refresh)
1543            .ok();
1544    }
1545
1546    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1547        self
1548    }
1549}
1550
1551struct NativeAgentSessionTruncate {
1552    thread: Entity<Thread>,
1553    acp_thread: WeakEntity<AcpThread>,
1554}
1555
1556impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1557    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1558        match self.thread.update(cx, |thread, cx| {
1559            thread.truncate(message_id.clone(), cx)?;
1560            Ok(thread.latest_token_usage())
1561        }) {
1562            Ok(usage) => {
1563                self.acp_thread
1564                    .update(cx, |thread, cx| {
1565                        thread.update_token_usage(usage, cx);
1566                    })
1567                    .ok();
1568                Task::ready(Ok(()))
1569            }
1570            Err(error) => Task::ready(Err(error)),
1571        }
1572    }
1573}
1574
1575struct NativeAgentSessionRetry {
1576    connection: NativeAgentConnection,
1577    session_id: acp::SessionId,
1578}
1579
1580impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1581    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1582        self.connection
1583            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1584                thread.update(cx, |thread, cx| thread.resume(cx))
1585            })
1586    }
1587}
1588
1589struct NativeAgentSessionSetTitle {
1590    thread: Entity<Thread>,
1591}
1592
1593impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1594    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1595        self.thread
1596            .update(cx, |thread, cx| thread.set_title(title, cx));
1597        Task::ready(Ok(()))
1598    }
1599}
1600
/// Weak handles to the agent, a thread, and that thread's `AcpThread`, used to
/// create and resume subagent threads.
pub struct NativeThreadEnvironment {
    /// Agent on which subagent sessions are registered and looked up.
    agent: WeakEntity<NativeAgent>,
    /// Parent thread that new subagents are derived from.
    thread: WeakEntity<Thread>,
    /// NOTE(review): presumably the `AcpThread` paired with `thread`; not read by
    /// the subagent constructors in this impl — confirm its use in `prompt_subagent`.
    acp_thread: WeakEntity<AcpThread>,
}
1606
1607impl NativeThreadEnvironment {
1608    pub(crate) fn create_subagent_thread(
1609        &self,
1610        label: String,
1611        cx: &mut App,
1612    ) -> Result<Rc<dyn SubagentHandle>> {
1613        let Some(parent_thread_entity) = self.thread.upgrade() else {
1614            anyhow::bail!("Parent thread no longer exists".to_string());
1615        };
1616        let parent_thread = parent_thread_entity.read(cx);
1617        let current_depth = parent_thread.depth();
1618
1619        if current_depth >= MAX_SUBAGENT_DEPTH {
1620            return Err(anyhow!(
1621                "Maximum subagent depth ({}) reached",
1622                MAX_SUBAGENT_DEPTH
1623            ));
1624        }
1625
1626        let subagent_thread: Entity<Thread> = cx.new(|cx| {
1627            let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1628            thread.set_title(label.into(), cx);
1629            thread
1630        });
1631
1632        let session_id = subagent_thread.read(cx).id().clone();
1633
1634        let acp_thread = self.agent.update(cx, |agent, cx| {
1635            agent.register_session(subagent_thread.clone(), cx)
1636        })?;
1637
1638        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1639    }
1640
1641    pub(crate) fn resume_subagent_thread(
1642        &self,
1643        session_id: acp::SessionId,
1644        cx: &mut App,
1645    ) -> Result<Rc<dyn SubagentHandle>> {
1646        let (subagent_thread, acp_thread) = self.agent.update(cx, |agent, _cx| {
1647            let session = agent
1648                .sessions
1649                .get(&session_id)
1650                .ok_or_else(|| anyhow!("No subagent session found with id {session_id}"))?;
1651            anyhow::Ok((session.thread.clone(), session.acp_thread.clone()))
1652        })??;
1653
1654        self.prompt_subagent(session_id, subagent_thread, acp_thread)
1655    }
1656
1657    fn prompt_subagent(
1658        &self,
1659        session_id: acp::SessionId,
1660        subagent_thread: Entity<Thread>,
1661        acp_thread: Entity<acp_thread::AcpThread>,
1662    ) -> Result<Rc<dyn SubagentHandle>> {
1663        let Some(parent_thread_entity) = self.thread.upgrade() else {
1664            anyhow::bail!("Parent thread no longer exists".to_string());
1665        };
1666        Ok(Rc::new(NativeSubagentHandle::new(
1667            session_id,
1668            subagent_thread,
1669            acp_thread,
1670            parent_thread_entity,
1671        )) as _)
1672    }
1673}
1674
1675impl ThreadEnvironment for NativeThreadEnvironment {
1676    fn create_terminal(
1677        &self,
1678        command: String,
1679        cwd: Option<PathBuf>,
1680        output_byte_limit: Option<u64>,
1681        cx: &mut AsyncApp,
1682    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1683        let task = self.acp_thread.update(cx, |thread, cx| {
1684            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1685        });
1686
1687        let acp_thread = self.acp_thread.clone();
1688        cx.spawn(async move |cx| {
1689            let terminal = task?.await?;
1690
1691            let (drop_tx, drop_rx) = oneshot::channel();
1692            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1693
1694            cx.spawn(async move |cx| {
1695                drop_rx.await.ok();
1696                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1697            })
1698            .detach();
1699
1700            let handle = AcpTerminalHandle {
1701                terminal,
1702                _drop_tx: Some(drop_tx),
1703            };
1704
1705            Ok(Rc::new(handle) as _)
1706        })
1707    }
1708
1709    fn create_subagent(&self, label: String, cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
1710        self.create_subagent_thread(label, cx)
1711    }
1712
1713    fn resume_subagent(
1714        &self,
1715        session_id: acp::SessionId,
1716        cx: &mut App,
1717    ) -> Result<Rc<dyn SubagentHandle>> {
1718        self.resume_subagent_thread(session_id, cx)
1719    }
1720}
1721
/// Outcome of one prompt sent to a subagent, as computed inside
/// `NativeSubagentHandle::send`.
#[derive(Debug, Clone)]
enum SubagentPromptResult {
    /// The turn finished normally; the reply is read from the thread's last message.
    Completed,
    /// The turn was cancelled.
    Cancelled,
    /// Token usage left the normal range during the prompt, so the subagent is
    /// cancelled early.
    ContextWindowWarning,
    /// The turn failed; carries a human-readable explanation.
    Error(String),
}
1729
1730pub struct NativeSubagentHandle {
1731    session_id: acp::SessionId,
1732    parent_thread: WeakEntity<Thread>,
1733    subagent_thread: Entity<Thread>,
1734    acp_thread: Entity<acp_thread::AcpThread>,
1735}
1736
1737impl NativeSubagentHandle {
1738    fn new(
1739        session_id: acp::SessionId,
1740        subagent_thread: Entity<Thread>,
1741        acp_thread: Entity<acp_thread::AcpThread>,
1742        parent_thread_entity: Entity<Thread>,
1743    ) -> Self {
1744        NativeSubagentHandle {
1745            session_id,
1746            subagent_thread,
1747            parent_thread: parent_thread_entity.downgrade(),
1748            acp_thread,
1749        }
1750    }
1751}
1752
1753impl SubagentHandle for NativeSubagentHandle {
1754    fn id(&self) -> acp::SessionId {
1755        self.session_id.clone()
1756    }
1757
1758    fn num_entries(&self, cx: &App) -> usize {
1759        self.acp_thread.read(cx).entries().len()
1760    }
1761
1762    fn send(&self, message: String, cx: &AsyncApp) -> Task<Result<String>> {
1763        let thread = self.subagent_thread.clone();
1764        let acp_thread = self.acp_thread.clone();
1765        let subagent_session_id = self.session_id.clone();
1766        let parent_thread = self.parent_thread.clone();
1767
1768        cx.spawn(async move |cx| {
1769            let (task, _subscription) = cx.update(|cx| {
1770                let ratio_before_prompt = thread
1771                    .read(cx)
1772                    .latest_token_usage()
1773                    .map(|usage| usage.ratio());
1774
1775                parent_thread
1776                    .update(cx, |parent_thread, _cx| {
1777                        parent_thread.register_running_subagent(thread.downgrade())
1778                    })
1779                    .ok();
1780
1781                let task = acp_thread.update(cx, |acp_thread, cx| {
1782                    acp_thread.send(vec![message.into()], cx)
1783                });
1784
1785                let (token_limit_tx, token_limit_rx) = oneshot::channel::<()>();
1786                let mut token_limit_tx = Some(token_limit_tx);
1787
1788                let subscription = cx.subscribe(
1789                    &thread,
1790                    move |_thread, event: &TokenUsageUpdated, _cx| {
1791                        if let Some(usage) = &event.0 {
1792                            let old_ratio = ratio_before_prompt
1793                                .clone()
1794                                .unwrap_or(TokenUsageRatio::Normal);
1795                            let new_ratio = usage.ratio();
1796                            if old_ratio == TokenUsageRatio::Normal
1797                                && new_ratio == TokenUsageRatio::Warning
1798                            {
1799                                if let Some(tx) = token_limit_tx.take() {
1800                                    tx.send(()).ok();
1801                                }
1802                            }
1803                        }
1804                    },
1805                );
1806
1807                let wait_for_prompt = cx
1808                    .background_spawn(async move {
1809                        futures::select! {
1810                            response = task.fuse() => match response {
1811                                Ok(Some(response)) => {
1812                                    match response.stop_reason {
1813                                        acp::StopReason::Cancelled => SubagentPromptResult::Cancelled,
1814                                        acp::StopReason::MaxTokens => SubagentPromptResult::Error("The agent reached the maximum number of tokens.".into()),
1815                                        acp::StopReason::MaxTurnRequests => SubagentPromptResult::Error("The agent reached the maximum number of allowed requests between user turns. Try prompting again.".into()),
1816                                        acp::StopReason::Refusal => SubagentPromptResult::Error("The agent refused to process that prompt. Try again.".into()),
1817                                        acp::StopReason::EndTurn | _ => SubagentPromptResult::Completed,
1818                                    }
1819                                }
1820                                Ok(None) => SubagentPromptResult::Error("No response from the agent. You can try messaging again.".into()),
1821                                Err(error) => SubagentPromptResult::Error(error.to_string()),
1822                            },
1823                            _ = token_limit_rx.fuse() => SubagentPromptResult::ContextWindowWarning,
1824                        }
1825                    });
1826
1827                (wait_for_prompt, subscription)
1828            });
1829
1830            let result = match task.await {
1831                SubagentPromptResult::Completed => thread.read_with(cx, |thread, _cx| {
1832                    thread
1833                        .last_message()
1834                        .and_then(|message| {
1835                            let content = message.as_agent_message()?
1836                                .content
1837                                .iter()
1838                                .filter_map(|c| match c {
1839                                    AgentMessageContent::Text(text) => Some(text.as_str()),
1840                                    _ => None,
1841                                })
1842                                .join("\n\n");
1843                            if content.is_empty() {
1844                                None
1845                            } else {
1846                                Some( content)
1847                            }
1848                        })
1849                        .context("No response from subagent")
1850                }),
1851                SubagentPromptResult::Cancelled => Err(anyhow!("User canceled")),
1852                SubagentPromptResult::Error(message) => Err(anyhow!("{message}")),
1853                SubagentPromptResult::ContextWindowWarning => {
1854                    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1855                    Err(anyhow!(
1856                        "The agent is nearing the end of its context window and has been \
1857                         stopped. You can prompt the thread again to have the agent wrap up \
1858                         or hand off its work."
1859                    ))
1860                }
1861            };
1862
1863            parent_thread
1864                .update(cx, |parent_thread, cx| {
1865                    parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1866                })
1867                .ok();
1868
1869            result
1870        })
1871    }
1872}
1873
1874pub struct AcpTerminalHandle {
1875    terminal: Entity<acp_thread::Terminal>,
1876    _drop_tx: Option<oneshot::Sender<()>>,
1877}
1878
1879impl TerminalHandle for AcpTerminalHandle {
1880    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1881        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1882    }
1883
1884    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1885        Ok(self
1886            .terminal
1887            .read_with(cx, |term, _cx| term.wait_for_exit()))
1888    }
1889
1890    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1891        Ok(self
1892            .terminal
1893            .read_with(cx, |term, cx| term.current_output(cx)))
1894    }
1895
1896    fn kill(&self, cx: &AsyncApp) -> Result<()> {
1897        cx.update(|cx| {
1898            self.terminal.update(cx, |terminal, cx| {
1899                terminal.kill(cx);
1900            });
1901        });
1902        Ok(())
1903    }
1904
1905    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1906        Ok(self
1907            .terminal
1908            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1909    }
1910}
1911
1912#[cfg(test)]
1913mod internal_tests {
1914    use super::*;
1915    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1916    use fs::FakeFs;
1917    use gpui::TestAppContext;
1918    use indoc::formatdoc;
1919    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1920    use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1921    use serde_json::json;
1922    use settings::SettingsStore;
1923    use util::{path, rel_path::rel_path};
1924
1925    #[gpui::test]
1926    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1927        init_test(cx);
1928        let fs = FakeFs::new(cx.executor());
1929        fs.insert_tree(
1930            "/",
1931            json!({
1932                "a": {}
1933            }),
1934        )
1935        .await;
1936        let project = Project::test(fs.clone(), [], cx).await;
1937        let thread_store = cx.new(|cx| ThreadStore::new(cx));
1938        let agent = NativeAgent::new(
1939            project.clone(),
1940            thread_store,
1941            Templates::new(),
1942            None,
1943            fs.clone(),
1944            &mut cx.to_async(),
1945        )
1946        .await
1947        .unwrap();
1948        agent.read_with(cx, |agent, cx| {
1949            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1950        });
1951
1952        let worktree = project
1953            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1954            .await
1955            .unwrap();
1956        cx.run_until_parked();
1957        agent.read_with(cx, |agent, cx| {
1958            assert_eq!(
1959                agent.project_context.read(cx).worktrees,
1960                vec![WorktreeContext {
1961                    root_name: "a".into(),
1962                    abs_path: Path::new("/a").into(),
1963                    rules_file: None
1964                }]
1965            )
1966        });
1967
1968        // Creating `/a/.rules` updates the project context.
1969        fs.insert_file("/a/.rules", Vec::new()).await;
1970        cx.run_until_parked();
1971        agent.read_with(cx, |agent, cx| {
1972            let rules_entry = worktree
1973                .read(cx)
1974                .entry_for_path(rel_path(".rules"))
1975                .unwrap();
1976            assert_eq!(
1977                agent.project_context.read(cx).worktrees,
1978                vec![WorktreeContext {
1979                    root_name: "a".into(),
1980                    abs_path: Path::new("/a").into(),
1981                    rules_file: Some(RulesFileContext {
1982                        path_in_worktree: rel_path(".rules").into(),
1983                        text: "".into(),
1984                        project_entry_id: rules_entry.id.to_usize()
1985                    })
1986                }]
1987            )
1988        });
1989    }
1990
1991    #[gpui::test]
1992    async fn test_listing_models(cx: &mut TestAppContext) {
1993        init_test(cx);
1994        let fs = FakeFs::new(cx.executor());
1995        fs.insert_tree("/", json!({ "a": {}  })).await;
1996        let project = Project::test(fs.clone(), [], cx).await;
1997        let thread_store = cx.new(|cx| ThreadStore::new(cx));
1998        let connection = NativeAgentConnection(
1999            NativeAgent::new(
2000                project.clone(),
2001                thread_store,
2002                Templates::new(),
2003                None,
2004                fs.clone(),
2005                &mut cx.to_async(),
2006            )
2007            .await
2008            .unwrap(),
2009        );
2010
2011        // Create a thread/session
2012        let acp_thread = cx
2013            .update(|cx| {
2014                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2015            })
2016            .await
2017            .unwrap();
2018
2019        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2020
2021        let models = cx
2022            .update(|cx| {
2023                connection
2024                    .model_selector(&session_id)
2025                    .unwrap()
2026                    .list_models(cx)
2027            })
2028            .await
2029            .unwrap();
2030
2031        let acp_thread::AgentModelList::Grouped(models) = models else {
2032            panic!("Unexpected model group");
2033        };
2034        assert_eq!(
2035            models,
2036            IndexMap::from_iter([(
2037                AgentModelGroupName("Fake".into()),
2038                vec![AgentModelInfo {
2039                    id: acp::ModelId::new("fake/fake"),
2040                    name: "Fake".into(),
2041                    description: None,
2042                    icon: Some(acp_thread::AgentModelIcon::Named(
2043                        ui::IconName::ZedAssistant
2044                    )),
2045                    is_latest: false,
2046                    cost: None,
2047                }]
2048            )])
2049        );
2050    }
2051
2052    #[gpui::test]
2053    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
2054        init_test(cx);
2055        let fs = FakeFs::new(cx.executor());
2056        fs.create_dir(paths::settings_file().parent().unwrap())
2057            .await
2058            .unwrap();
2059        fs.insert_file(
2060            paths::settings_file(),
2061            json!({
2062                "agent": {
2063                    "default_model": {
2064                        "provider": "foo",
2065                        "model": "bar"
2066                    }
2067                }
2068            })
2069            .to_string()
2070            .into_bytes(),
2071        )
2072        .await;
2073        let project = Project::test(fs.clone(), [], cx).await;
2074
2075        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2076
2077        // Create the agent and connection
2078        let agent = NativeAgent::new(
2079            project.clone(),
2080            thread_store,
2081            Templates::new(),
2082            None,
2083            fs.clone(),
2084            &mut cx.to_async(),
2085        )
2086        .await
2087        .unwrap();
2088        let connection = NativeAgentConnection(agent.clone());
2089
2090        // Create a thread/session
2091        let acp_thread = cx
2092            .update(|cx| {
2093                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2094            })
2095            .await
2096            .unwrap();
2097
2098        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2099
2100        // Select a model
2101        let selector = connection.model_selector(&session_id).unwrap();
2102        let model_id = acp::ModelId::new("fake/fake");
2103        cx.update(|cx| selector.select_model(model_id.clone(), cx))
2104            .await
2105            .unwrap();
2106
2107        // Verify the thread has the selected model
2108        agent.read_with(cx, |agent, _| {
2109            let session = agent.sessions.get(&session_id).unwrap();
2110            session.thread.read_with(cx, |thread, _| {
2111                assert_eq!(thread.model().unwrap().id().0, "fake");
2112            });
2113        });
2114
2115        cx.run_until_parked();
2116
2117        // Verify settings file was updated
2118        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2119        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2120
2121        // Check that the agent settings contain the selected model
2122        assert_eq!(
2123            settings_json["agent"]["default_model"]["model"],
2124            json!("fake")
2125        );
2126        assert_eq!(
2127            settings_json["agent"]["default_model"]["provider"],
2128            json!("fake")
2129        );
2130
2131        // Register a thinking model and select it.
2132        cx.update(|cx| {
2133            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2134                "fake-corp",
2135                "fake-thinking",
2136                "Fake Thinking",
2137                true,
2138            ));
2139            let thinking_provider = Arc::new(
2140                FakeLanguageModelProvider::new(
2141                    LanguageModelProviderId::from("fake-corp".to_string()),
2142                    LanguageModelProviderName::from("Fake Corp".to_string()),
2143                )
2144                .with_models(vec![thinking_model]),
2145            );
2146            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2147                registry.register_provider(thinking_provider, cx);
2148            });
2149        });
2150        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2151
2152        let selector = connection.model_selector(&session_id).unwrap();
2153        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2154            .await
2155            .unwrap();
2156        cx.run_until_parked();
2157
2158        // Verify enable_thinking was written to settings as true.
2159        let settings_content = fs.load(paths::settings_file()).await.unwrap();
2160        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2161        assert_eq!(
2162            settings_json["agent"]["default_model"]["enable_thinking"],
2163            json!(true),
2164            "selecting a thinking model should persist enable_thinking: true to settings"
2165        );
2166    }
2167
2168    #[gpui::test]
2169    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2170        init_test(cx);
2171        let fs = FakeFs::new(cx.executor());
2172        fs.create_dir(paths::settings_file().parent().unwrap())
2173            .await
2174            .unwrap();
2175        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2176        let project = Project::test(fs.clone(), [], cx).await;
2177
2178        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2179        let agent = NativeAgent::new(
2180            project.clone(),
2181            thread_store,
2182            Templates::new(),
2183            None,
2184            fs.clone(),
2185            &mut cx.to_async(),
2186        )
2187        .await
2188        .unwrap();
2189        let connection = NativeAgentConnection(agent.clone());
2190
2191        let acp_thread = cx
2192            .update(|cx| {
2193                Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2194            })
2195            .await
2196            .unwrap();
2197        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2198
2199        // Register a second provider with a thinking model.
2200        cx.update(|cx| {
2201            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2202                "fake-corp",
2203                "fake-thinking",
2204                "Fake Thinking",
2205                true,
2206            ));
2207            let thinking_provider = Arc::new(
2208                FakeLanguageModelProvider::new(
2209                    LanguageModelProviderId::from("fake-corp".to_string()),
2210                    LanguageModelProviderName::from("Fake Corp".to_string()),
2211                )
2212                .with_models(vec![thinking_model]),
2213            );
2214            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2215                registry.register_provider(thinking_provider, cx);
2216            });
2217        });
2218        // Refresh the agent's model list so it picks up the new provider.
2219        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2220
2221        // Thread starts with thinking_enabled = false (the default).
2222        agent.read_with(cx, |agent, _| {
2223            let session = agent.sessions.get(&session_id).unwrap();
2224            session.thread.read_with(cx, |thread, _| {
2225                assert!(!thread.thinking_enabled(), "thinking defaults to false");
2226            });
2227        });
2228
2229        // Select the thinking model via select_model.
2230        let selector = connection.model_selector(&session_id).unwrap();
2231        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2232            .await
2233            .unwrap();
2234
2235        // select_model should have enabled thinking based on the model's supports_thinking().
2236        agent.read_with(cx, |agent, _| {
2237            let session = agent.sessions.get(&session_id).unwrap();
2238            session.thread.read_with(cx, |thread, _| {
2239                assert!(
2240                    thread.thinking_enabled(),
2241                    "select_model should enable thinking when model supports it"
2242                );
2243            });
2244        });
2245
2246        // Switch back to the non-thinking model.
2247        let selector = connection.model_selector(&session_id).unwrap();
2248        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2249            .await
2250            .unwrap();
2251
2252        // select_model should have disabled thinking.
2253        agent.read_with(cx, |agent, _| {
2254            let session = agent.sessions.get(&session_id).unwrap();
2255            session.thread.read_with(cx, |thread, _| {
2256                assert!(
2257                    !thread.thinking_enabled(),
2258                    "select_model should disable thinking when model does not support it"
2259                );
2260            });
2261        });
2262    }
2263
2264    #[gpui::test]
2265    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
2266        init_test(cx);
2267        let fs = FakeFs::new(cx.executor());
2268        fs.insert_tree("/", json!({ "a": {} })).await;
2269        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2270        let thread_store = cx.new(|cx| ThreadStore::new(cx));
2271        let agent = NativeAgent::new(
2272            project.clone(),
2273            thread_store.clone(),
2274            Templates::new(),
2275            None,
2276            fs.clone(),
2277            &mut cx.to_async(),
2278        )
2279        .await
2280        .unwrap();
2281        let connection = Rc::new(NativeAgentConnection(agent.clone()));
2282
2283        // Register a thinking model.
2284        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2285            "fake-corp",
2286            "fake-thinking",
2287            "Fake Thinking",
2288            true,
2289        ));
2290        let thinking_provider = Arc::new(
2291            FakeLanguageModelProvider::new(
2292                LanguageModelProviderId::from("fake-corp".to_string()),
2293                LanguageModelProviderName::from("Fake Corp".to_string()),
2294            )
2295            .with_models(vec![thinking_model.clone()]),
2296        );
2297        cx.update(|cx| {
2298            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2299                registry.register_provider(thinking_provider, cx);
2300            });
2301        });
2302        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2303
2304        // Create a thread and select the thinking model.
2305        let acp_thread = cx
2306            .update(|cx| {
2307                connection
2308                    .clone()
2309                    .new_session(project.clone(), Path::new("/a"), cx)
2310            })
2311            .await
2312            .unwrap();
2313        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2314
2315        let selector = connection.model_selector(&session_id).unwrap();
2316        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2317            .await
2318            .unwrap();
2319
2320        // Verify thinking is enabled after selecting the thinking model.
2321        let thread = agent.read_with(cx, |agent, _| {
2322            agent.sessions.get(&session_id).unwrap().thread.clone()
2323        });
2324        thread.read_with(cx, |thread, _| {
2325            assert!(
2326                thread.thinking_enabled(),
2327                "thinking should be enabled after selecting thinking model"
2328            );
2329        });
2330
2331        // Send a message so the thread gets persisted.
2332        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2333        let send = cx.foreground_executor().spawn(send);
2334        cx.run_until_parked();
2335
2336        thinking_model.send_last_completion_stream_text_chunk("Response.");
2337        thinking_model.end_last_completion_stream();
2338
2339        send.await.unwrap();
2340        cx.run_until_parked();
2341
2342        // Close the session so it can be reloaded from disk.
2343        cx.update(|cx| connection.clone().close_session(&session_id, cx))
2344            .await
2345            .unwrap();
2346        drop(thread);
2347        drop(acp_thread);
2348        agent.read_with(cx, |agent, _| {
2349            assert!(agent.sessions.is_empty());
2350        });
2351
2352        // Reload the thread and verify thinking_enabled is still true.
2353        let reloaded_acp_thread = agent
2354            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2355            .await
2356            .unwrap();
2357        let reloaded_thread = agent.read_with(cx, |agent, _| {
2358            agent.sessions.get(&session_id).unwrap().thread.clone()
2359        });
2360        reloaded_thread.read_with(cx, |thread, _| {
2361            assert!(
2362                thread.thinking_enabled(),
2363                "thinking_enabled should be preserved when reloading a thread with a thinking model"
2364            );
2365        });
2366
2367        drop(reloaded_acp_thread);
2368    }
2369
    /// Regression test: a thread persisted while using a non-default model must come back
    /// with that same model after the session is closed and reopened from the thread store.
    ///
    /// The fake model deliberately has an id that differs from its display name, so the
    /// assertions can tell the two apart. Presumably this guards against the persisted
    /// model being resolved by display name instead of by id, which would make the reload
    /// fall back to the default model (TODO confirm against the thread (de)serialization code).
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        // NOTE(review): the trailing `false` presumably means "no thinking support"; confirm
        // against `FakeLanguageModel::with_id_and_thinking`.
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        // Wrap the model in its own provider (id "fake-corp") and register that provider
        // with the global registry.
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Refresh the agent's model list so the newly registered provider's model can be
        // selected below.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // The selector id here is the provider id and the model id joined by '/'.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        // Sanity check: the selection reached the underlying native thread (by model id,
        // not display name) before anything is persisted.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        // The send is spawned so it can make progress while the test drives the fake model;
        // `run_until_parked` lets it issue its completion request before we stream a reply.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the closed session's entities.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // Explicitly release the reloaded session's handle once the assertions are done.
        drop(reloaded_acp_thread);
    }
2480
    /// Round-trips a native thread through the thread store. It verifies that:
    /// - an empty thread is not saved, even after it is mutated;
    /// - a thread with a completed exchange is saved under the title produced by the
    ///   summarization model;
    /// - reopening the thread restores both the conversation (including a file mention)
    ///   and a draft prompt made of rich content blocks.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // A single file, referenced below through a file mention in the user message.
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        // Both models are fakes, so the test controls the assistant reply stream and the
        // summary stream (which becomes the saved title, asserted after reload).
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a message that mixes plain text with a resource-link mention of `b.md`.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Spawn the send so it can make progress while the test drives the fake models.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Finish the assistant reply first, then the summary stream.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Set a draft prompt with rich content blocks before saving.
        let draft_blocks = vec![
            acp::ContentBlock::Text(acp::TextContent::new("Check out ")),
            acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())),
            acp::ContentBlock::Text(acp::TextContent::new(" please")),
        ];
        acp_thread.update(cx, |thread, _cx| {
            thread.set_draft_prompt(Some(draft_blocks.clone()));
        });
        // NOTE(review): notifying the native thread presumably makes it observe the draft
        // change and schedule a save. Confirm what actually triggers persistence here.
        thread.update(cx, |_thread, cx| cx.notify());
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own handles to the closed session's entities.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        // The stored entry is titled with the summary model's output.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        // Ensure the draft prompt with rich content blocks survived the round-trip.
        acp_thread.read_with(cx, |thread, _| {
            assert_eq!(thread.draft_prompt(), Some(draft_blocks.as_slice()));
        });
    }
2636
2637    fn thread_entries(
2638        thread_store: &Entity<ThreadStore>,
2639        cx: &mut TestAppContext,
2640    ) -> Vec<(acp::SessionId, String)> {
2641        thread_store.read_with(cx, |store, _| {
2642            store
2643                .entries()
2644                .map(|entry| (entry.id.clone(), entry.title.to_string()))
2645                .collect::<Vec<_>>()
2646        })
2647    }
2648
2649    fn init_test(cx: &mut TestAppContext) {
2650        env_logger::try_init().ok();
2651        cx.update(|cx| {
2652            let settings_store = SettingsStore::test(cx);
2653            cx.set_global(settings_store);
2654
2655            LanguageModelRegistry::test(cx);
2656        });
2657    }
2658}
2659
2660fn mcp_message_content_to_acp_content_block(
2661    content: context_server::types::MessageContent,
2662) -> acp::ContentBlock {
2663    match content {
2664        context_server::types::MessageContent::Text {
2665            text,
2666            annotations: _,
2667        } => text.into(),
2668        context_server::types::MessageContent::Image {
2669            data,
2670            mime_type,
2671            annotations: _,
2672        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2673        context_server::types::MessageContent::Audio {
2674            data,
2675            mime_type,
2676            annotations: _,
2677        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2678        context_server::types::MessageContent::Resource {
2679            resource,
2680            annotations: _,
2681        } => {
2682            let mut link =
2683                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2684            if let Some(mime_type) = resource.mime_type {
2685                link = link.mime_type(mime_type);
2686            }
2687            acp::ContentBlock::ResourceLink(link)
2688        }
2689    }
2690}