agent.rs

   1mod db;
   2mod edit_agent;
   3mod legacy_thread;
   4mod native_agent_server;
   5pub mod outline;
   6mod pattern_extraction;
   7mod templates;
   8#[cfg(test)]
   9mod tests;
  10mod thread;
  11mod thread_store;
  12mod tool_permissions;
  13mod tools;
  14
  15use context_server::ContextServerId;
  16pub use db::*;
  17pub use native_agent_server::NativeAgentServer;
  18pub use pattern_extraction::*;
  19pub use templates::*;
  20pub use thread::*;
  21pub use thread_store::*;
  22pub use tool_permissions::*;
  23pub use tools::*;
  24
  25use acp_thread::{
  26    AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
  27    AgentSessionListResponse, UserMessageId,
  28};
  29use agent_client_protocol as acp;
  30use anyhow::{Context as _, Result, anyhow};
  31use chrono::{DateTime, Utc};
  32use collections::{HashMap, HashSet, IndexMap};
  33use fs::Fs;
  34use futures::channel::{mpsc, oneshot};
  35use futures::future::Shared;
  36use futures::{StreamExt, future};
  37use gpui::{
  38    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  39};
  40use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  41use project::{Project, ProjectItem, ProjectPath, Worktree};
  42use prompt_store::{
  43    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  44    WorktreeContext,
  45};
  46use serde::{Deserialize, Serialize};
  47use settings::{LanguageModelSelection, update_settings_file};
  48use std::any::Any;
  49use std::path::{Path, PathBuf};
  50use std::rc::Rc;
  51use std::sync::Arc;
  52use util::ResultExt;
  53use util::rel_path::RelPath;
  54
/// Serializable snapshot of the project's worktrees (in their telemetry form)
/// together with a UTC timestamp.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots of the project's worktrees.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Timestamp associated with the snapshot, in UTC.
    pub timestamp: DateTime<Utc>,
}
  60
/// Error describing a worktree rules file that failed to load.
pub struct RulesLoadingError {
    /// Formatted description of the underlying load failure.
    pub message: SharedString,
}
  64
/// Holds both the internal Thread and the AcpThread for a session.
///
/// Sessions are stored in `NativeAgent::sessions`, keyed by the thread's session ID.
struct Session {
    /// The internal thread that processes messages (strong handle).
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication (weak handle; the agent
    /// removes the session when this ACP thread is released).
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Save task for this session's thread; starts out as an already-ready task.
    // NOTE(review): presumably replaced by `save_thread` on each save — confirm.
    pending_save: Task<()>,
    /// Subscriptions created when the session was registered, held for the
    /// session's lifetime.
    _subscriptions: Vec<Subscription>,
}
  74
/// Cache of the language models offered by authenticated providers, available
/// both as an ID-keyed lookup and as a grouped list.
pub struct LanguageModels {
    /// Access language model by ID (`"<provider_id>/<model_id>"`).
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver cloned out to watchers; it observes a change after every refresh.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender signalled at the end of every refresh.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating all visible providers, held for the
    /// lifetime of the cache.
    _authenticate_all_providers_task: Task<()>,
}
  84
  85impl LanguageModels {
  86    fn new(cx: &mut App) -> Self {
  87        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  88
  89        let mut this = Self {
  90            models: HashMap::default(),
  91            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  92            refresh_models_rx,
  93            refresh_models_tx,
  94            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  95        };
  96        this.refresh_list(cx);
  97        this
  98    }
  99
 100    fn refresh_list(&mut self, cx: &App) {
 101        let providers = LanguageModelRegistry::global(cx)
 102            .read(cx)
 103            .visible_providers()
 104            .into_iter()
 105            .filter(|provider| provider.is_authenticated(cx))
 106            .collect::<Vec<_>>();
 107
 108        let mut language_model_list = IndexMap::default();
 109        let mut recommended_models = HashSet::default();
 110
 111        let mut recommended = Vec::new();
 112        for provider in &providers {
 113            for model in provider.recommended_models(cx) {
 114                recommended_models.insert((model.provider_id(), model.id()));
 115                recommended.push(Self::map_language_model_to_info(&model, provider));
 116            }
 117        }
 118        if !recommended.is_empty() {
 119            language_model_list.insert(
 120                acp_thread::AgentModelGroupName("Recommended".into()),
 121                recommended,
 122            );
 123        }
 124
 125        let mut models = HashMap::default();
 126        for provider in providers {
 127            let mut provider_models = Vec::new();
 128            for model in provider.provided_models(cx) {
 129                let model_info = Self::map_language_model_to_info(&model, &provider);
 130                let model_id = model_info.id.clone();
 131                provider_models.push(model_info);
 132                models.insert(model_id, model);
 133            }
 134            if !provider_models.is_empty() {
 135                language_model_list.insert(
 136                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 137                    provider_models,
 138                );
 139            }
 140        }
 141
 142        self.models = models;
 143        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 144        self.refresh_models_tx.send(()).ok();
 145    }
 146
 147    fn watch(&self) -> watch::Receiver<()> {
 148        self.refresh_models_rx.clone()
 149    }
 150
 151    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 152        self.models.get(model_id).cloned()
 153    }
 154
 155    fn map_language_model_to_info(
 156        model: &Arc<dyn LanguageModel>,
 157        provider: &Arc<dyn LanguageModelProvider>,
 158    ) -> acp_thread::AgentModelInfo {
 159        acp_thread::AgentModelInfo {
 160            id: Self::model_id(model),
 161            name: model.name().0,
 162            description: None,
 163            icon: Some(match provider.icon() {
 164                IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
 165                IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
 166            }),
 167            is_latest: model.is_latest(),
 168        }
 169    }
 170
 171    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 172        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 173    }
 174
 175    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 176        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 177            .read(cx)
 178            .visible_providers()
 179            .iter()
 180            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 181            .collect::<Vec<_>>();
 182
 183        cx.background_spawn(async move {
 184            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 185                if let Err(err) = authenticate_task.await {
 186                    match err {
 187                        language_model::AuthenticateError::CredentialsNotFound => {
 188                            // Since we're authenticating these providers in the
 189                            // background for the purposes of populating the
 190                            // language selector, we don't care about providers
 191                            // where the credentials are not found.
 192                        }
 193                        language_model::AuthenticateError::ConnectionRefused => {
 194                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 195                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 196                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 197                        }
 198                        _ => {
 199                            // Some providers have noisy failure states that we
 200                            // don't want to spam the logs with every time the
 201                            // language model selector is initialized.
 202                            //
 203                            // Ideally these should have more clear failure modes
 204                            // that we know are safe to ignore here, like what we do
 205                            // with `CredentialsNotFound` above.
 206                            match provider_id.0.as_ref() {
 207                                "lmstudio" | "ollama" => {
 208                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 209                                    //
 210                                    // These fail noisily, so we don't log them.
 211                                }
 212                                "copilot_chat" => {
 213                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 214                                }
 215                                _ => {
 216                                    log::error!(
 217                                        "Failed to authenticate provider: {}: {err:#}",
 218                                        provider_name.0
 219                                    );
 220                                }
 221                            }
 222                        }
 223                    }
 224                }
 225            }
 226        })
 227    }
 228}
 229
/// Agent that owns the open sessions, the shared project context, the
/// context-server integration, and the cached language-model list.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store supplied at construction.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Sending on this channel asks `_maintain_project_context` to rebuild the
    /// project context.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task that rebuilds the project context whenever a refresh is
    /// requested.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context-server tools and prompts; its prompts are exposed as
    /// slash commands.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store; supplies default user rules for the project context.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle supplied at construction.
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, model-registry, context-server and prompt-store
    /// events, held for the agent's lifetime.
    _subscriptions: Vec<Subscription>,
}
 248
 249impl NativeAgent {
 250    pub async fn new(
 251        project: Entity<Project>,
 252        thread_store: Entity<ThreadStore>,
 253        templates: Arc<Templates>,
 254        prompt_store: Option<Entity<PromptStore>>,
 255        fs: Arc<dyn Fs>,
 256        cx: &mut AsyncApp,
 257    ) -> Result<Entity<NativeAgent>> {
 258        log::debug!("Creating new NativeAgent");
 259
 260        let project_context = cx
 261            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
 262            .await;
 263
 264        Ok(cx.new(|cx| {
 265            let context_server_store = project.read(cx).context_server_store();
 266            let context_server_registry =
 267                cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
 268
 269            let mut subscriptions = vec![
 270                cx.subscribe(&project, Self::handle_project_event),
 271                cx.subscribe(
 272                    &LanguageModelRegistry::global(cx),
 273                    Self::handle_models_updated_event,
 274                ),
 275                cx.subscribe(
 276                    &context_server_store,
 277                    Self::handle_context_server_store_updated,
 278                ),
 279                cx.subscribe(
 280                    &context_server_registry,
 281                    Self::handle_context_server_registry_event,
 282                ),
 283            ];
 284            if let Some(prompt_store) = prompt_store.as_ref() {
 285                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 286            }
 287
 288            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 289                watch::channel(());
 290            Self {
 291                sessions: HashMap::default(),
 292                thread_store,
 293                project_context: cx.new(|_| project_context),
 294                project_context_needs_refresh: project_context_needs_refresh_tx,
 295                _maintain_project_context: cx.spawn(async move |this, cx| {
 296                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 297                }),
 298                context_server_registry,
 299                templates,
 300                models: LanguageModels::new(cx),
 301                project,
 302                prompt_store,
 303                fs,
 304                _subscriptions: subscriptions,
 305            }
 306        }))
 307    }
 308
 309    fn new_session(
 310        &mut self,
 311        project: Entity<Project>,
 312        cx: &mut Context<Self>,
 313    ) -> Entity<AcpThread> {
 314        // Create Thread
 315        // Fetch default model from registry settings
 316        let registry = LanguageModelRegistry::read_global(cx);
 317        // Log available models for debugging
 318        let available_count = registry.available_models(cx).count();
 319        log::debug!("Total available models: {}", available_count);
 320
 321        let default_model = registry.default_model().and_then(|default_model| {
 322            self.models
 323                .model_from_id(&LanguageModels::model_id(&default_model.model))
 324        });
 325        let thread = cx.new(|cx| {
 326            Thread::new(
 327                project.clone(),
 328                self.project_context.clone(),
 329                self.context_server_registry.clone(),
 330                self.templates.clone(),
 331                default_model,
 332                cx,
 333            )
 334        });
 335
 336        self.register_session(thread, cx)
 337    }
 338
 339    fn register_session(
 340        &mut self,
 341        thread_handle: Entity<Thread>,
 342        cx: &mut Context<Self>,
 343    ) -> Entity<AcpThread> {
 344        let connection = Rc::new(NativeAgentConnection(cx.entity()));
 345
 346        let thread = thread_handle.read(cx);
 347        let session_id = thread.id().clone();
 348        let title = thread.title();
 349        let project = thread.project.clone();
 350        let action_log = thread.action_log.clone();
 351        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
 352        let acp_thread = cx.new(|cx| {
 353            acp_thread::AcpThread::new(
 354                title,
 355                connection,
 356                project.clone(),
 357                action_log.clone(),
 358                session_id.clone(),
 359                prompt_capabilities_rx,
 360                cx,
 361            )
 362        });
 363
 364        let registry = LanguageModelRegistry::read_global(cx);
 365        let summarization_model = registry.thread_summary_model().map(|c| c.model);
 366
 367        thread_handle.update(cx, |thread, cx| {
 368            thread.set_summarization_model(summarization_model, cx);
 369            thread.add_default_tools(
 370                Rc::new(AcpThreadEnvironment {
 371                    acp_thread: acp_thread.downgrade(),
 372                }) as _,
 373                cx,
 374            )
 375        });
 376
 377        let subscriptions = vec![
 378            cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
 379                this.sessions.remove(acp_thread.session_id());
 380            }),
 381            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
 382            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
 383            cx.observe(&thread_handle, move |this, thread, cx| {
 384                this.save_thread(thread, cx)
 385            }),
 386        ];
 387
 388        self.sessions.insert(
 389            session_id,
 390            Session {
 391                thread: thread_handle,
 392                acp_thread: acp_thread.downgrade(),
 393                _subscriptions: subscriptions,
 394                pending_save: Task::ready(()),
 395            },
 396        );
 397
 398        self.update_available_commands(cx);
 399
 400        acp_thread
 401    }
 402
 403    pub fn models(&self) -> &LanguageModels {
 404        &self.models
 405    }
 406
 407    async fn maintain_project_context(
 408        this: WeakEntity<Self>,
 409        mut needs_refresh: watch::Receiver<()>,
 410        cx: &mut AsyncApp,
 411    ) -> Result<()> {
 412        while needs_refresh.changed().await.is_ok() {
 413            let project_context = this
 414                .update(cx, |this, cx| {
 415                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 416                })?
 417                .await;
 418            this.update(cx, |this, cx| {
 419                this.project_context = cx.new(|_| project_context);
 420            })?;
 421        }
 422
 423        Ok(())
 424    }
 425
 426    fn build_project_context(
 427        project: &Entity<Project>,
 428        prompt_store: Option<&Entity<PromptStore>>,
 429        cx: &mut App,
 430    ) -> Task<ProjectContext> {
 431        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 432        let worktree_tasks = worktrees
 433            .into_iter()
 434            .map(|worktree| {
 435                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 436            })
 437            .collect::<Vec<_>>();
 438        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 439            prompt_store.read_with(cx, |prompt_store, cx| {
 440                let prompts = prompt_store.default_prompt_metadata();
 441                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 442                    let contents = prompt_store.load(prompt_metadata.id, cx);
 443                    async move { (contents.await, prompt_metadata) }
 444                });
 445                cx.background_spawn(future::join_all(load_tasks))
 446            })
 447        } else {
 448            Task::ready(vec![])
 449        };
 450
 451        cx.spawn(async move |_cx| {
 452            let (worktrees, default_user_rules) =
 453                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 454
 455            let worktrees = worktrees
 456                .into_iter()
 457                .map(|(worktree, _rules_error)| {
 458                    // TODO: show error message
 459                    // if let Some(rules_error) = rules_error {
 460                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 461                    // }
 462                    worktree
 463                })
 464                .collect::<Vec<_>>();
 465
 466            let default_user_rules = default_user_rules
 467                .into_iter()
 468                .flat_map(|(contents, prompt_metadata)| match contents {
 469                    Ok(contents) => Some(UserRulesContext {
 470                        uuid: prompt_metadata.id.as_user()?,
 471                        title: prompt_metadata.title.map(|title| title.to_string()),
 472                        contents,
 473                    }),
 474                    Err(_err) => {
 475                        // TODO: show error message
 476                        // this.update(cx, |_, cx| {
 477                        //     cx.emit(RulesLoadingError {
 478                        //         message: format!("{err:?}").into(),
 479                        //     });
 480                        // })
 481                        // .ok();
 482                        None
 483                    }
 484                })
 485                .collect::<Vec<_>>();
 486
 487            ProjectContext::new(worktrees, default_user_rules)
 488        })
 489    }
 490
 491    fn load_worktree_info_for_system_prompt(
 492        worktree: Entity<Worktree>,
 493        project: Entity<Project>,
 494        cx: &mut App,
 495    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 496        let tree = worktree.read(cx);
 497        let root_name = tree.root_name_str().into();
 498        let abs_path = tree.abs_path();
 499
 500        let mut context = WorktreeContext {
 501            root_name,
 502            abs_path,
 503            rules_file: None,
 504        };
 505
 506        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 507        let Some(rules_task) = rules_task else {
 508            return Task::ready((context, None));
 509        };
 510
 511        cx.spawn(async move |_| {
 512            let (rules_file, rules_file_error) = match rules_task.await {
 513                Ok(rules_file) => (Some(rules_file), None),
 514                Err(err) => (
 515                    None,
 516                    Some(RulesLoadingError {
 517                        message: format!("{err}").into(),
 518                    }),
 519                ),
 520            };
 521            context.rules_file = rules_file;
 522            (context, rules_file_error)
 523        })
 524    }
 525
 526    fn load_worktree_rules_file(
 527        worktree: Entity<Worktree>,
 528        project: Entity<Project>,
 529        cx: &mut App,
 530    ) -> Option<Task<Result<RulesFileContext>>> {
 531        let worktree = worktree.read(cx);
 532        let worktree_id = worktree.id();
 533        let selected_rules_file = RULES_FILE_NAMES
 534            .into_iter()
 535            .filter_map(|name| {
 536                worktree
 537                    .entry_for_path(RelPath::unix(name).unwrap())
 538                    .filter(|entry| entry.is_file())
 539                    .map(|entry| entry.path.clone())
 540            })
 541            .next();
 542
 543        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 544        // supported. This doesn't seem to occur often in GitHub repositories.
 545        selected_rules_file.map(|path_in_worktree| {
 546            let project_path = ProjectPath {
 547                worktree_id,
 548                path: path_in_worktree.clone(),
 549            };
 550            let buffer_task =
 551                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 552            let rope_task = cx.spawn(async move |cx| {
 553                let buffer = buffer_task.await?;
 554                let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
 555                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 556                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 557                })?;
 558                anyhow::Ok((project_entry_id, rope))
 559            });
 560            // Build a string from the rope on a background thread.
 561            cx.background_spawn(async move {
 562                let (project_entry_id, rope) = rope_task.await?;
 563                anyhow::Ok(RulesFileContext {
 564                    path_in_worktree,
 565                    text: rope.to_string().trim().to_string(),
 566                    project_entry_id: project_entry_id.to_usize(),
 567                })
 568            })
 569        })
 570    }
 571
 572    fn handle_thread_title_updated(
 573        &mut self,
 574        thread: Entity<Thread>,
 575        _: &TitleUpdated,
 576        cx: &mut Context<Self>,
 577    ) {
 578        let session_id = thread.read(cx).id();
 579        let Some(session) = self.sessions.get(session_id) else {
 580            return;
 581        };
 582        let thread = thread.downgrade();
 583        let acp_thread = session.acp_thread.clone();
 584        cx.spawn(async move |_, cx| {
 585            let title = thread.read_with(cx, |thread, _| thread.title())?;
 586            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 587            task.await
 588        })
 589        .detach_and_log_err(cx);
 590    }
 591
 592    fn handle_thread_token_usage_updated(
 593        &mut self,
 594        thread: Entity<Thread>,
 595        usage: &TokenUsageUpdated,
 596        cx: &mut Context<Self>,
 597    ) {
 598        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 599            return;
 600        };
 601        session
 602            .acp_thread
 603            .update(cx, |acp_thread, cx| {
 604                acp_thread.update_token_usage(usage.0.clone(), cx);
 605            })
 606            .ok();
 607    }
 608
 609    fn handle_project_event(
 610        &mut self,
 611        _project: Entity<Project>,
 612        event: &project::Event,
 613        _cx: &mut Context<Self>,
 614    ) {
 615        match event {
 616            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 617                self.project_context_needs_refresh.send(()).ok();
 618            }
 619            project::Event::WorktreeUpdatedEntries(_, items) => {
 620                if items.iter().any(|(path, _, _)| {
 621                    RULES_FILE_NAMES
 622                        .iter()
 623                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 624                }) {
 625                    self.project_context_needs_refresh.send(()).ok();
 626                }
 627            }
 628            _ => {}
 629        }
 630    }
 631
 632    fn handle_prompts_updated_event(
 633        &mut self,
 634        _prompt_store: Entity<PromptStore>,
 635        _event: &prompt_store::PromptsUpdatedEvent,
 636        _cx: &mut Context<Self>,
 637    ) {
 638        self.project_context_needs_refresh.send(()).ok();
 639    }
 640
 641    fn handle_models_updated_event(
 642        &mut self,
 643        _registry: Entity<LanguageModelRegistry>,
 644        _event: &language_model::Event,
 645        cx: &mut Context<Self>,
 646    ) {
 647        self.models.refresh_list(cx);
 648
 649        let registry = LanguageModelRegistry::read_global(cx);
 650        let default_model = registry.default_model().map(|m| m.model);
 651        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 652
 653        for session in self.sessions.values_mut() {
 654            session.thread.update(cx, |thread, cx| {
 655                if thread.model().is_none()
 656                    && let Some(model) = default_model.clone()
 657                {
 658                    thread.set_model(model, cx);
 659                    cx.notify();
 660                }
 661                thread.set_summarization_model(summarization_model.clone(), cx);
 662            });
 663        }
 664    }
 665
 666    fn handle_context_server_store_updated(
 667        &mut self,
 668        _store: Entity<project::context_server_store::ContextServerStore>,
 669        _event: &project::context_server_store::ServerStatusChangedEvent,
 670        cx: &mut Context<Self>,
 671    ) {
 672        self.update_available_commands(cx);
 673    }
 674
 675    fn handle_context_server_registry_event(
 676        &mut self,
 677        _registry: Entity<ContextServerRegistry>,
 678        event: &ContextServerRegistryEvent,
 679        cx: &mut Context<Self>,
 680    ) {
 681        match event {
 682            ContextServerRegistryEvent::ToolsChanged => {}
 683            ContextServerRegistryEvent::PromptsChanged => {
 684                self.update_available_commands(cx);
 685            }
 686        }
 687    }
 688
 689    fn update_available_commands(&self, cx: &mut Context<Self>) {
 690        let available_commands = self.build_available_commands(cx);
 691        for session in self.sessions.values() {
 692            if let Some(acp_thread) = session.acp_thread.upgrade() {
 693                acp_thread.update(cx, |thread, cx| {
 694                    thread
 695                        .handle_session_update(
 696                            acp::SessionUpdate::AvailableCommandsUpdate(
 697                                acp::AvailableCommandsUpdate::new(available_commands.clone()),
 698                            ),
 699                            cx,
 700                        )
 701                        .log_err();
 702                });
 703            }
 704        }
 705    }
 706
 707    fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
 708        let registry = self.context_server_registry.read(cx);
 709
 710        let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
 711        for context_server_prompt in registry.prompts() {
 712            *prompt_name_counts
 713                .entry(context_server_prompt.prompt.name.as_str())
 714                .or_insert(0) += 1;
 715        }
 716
 717        registry
 718            .prompts()
 719            .flat_map(|context_server_prompt| {
 720                let prompt = &context_server_prompt.prompt;
 721
 722                let should_prefix = prompt_name_counts
 723                    .get(prompt.name.as_str())
 724                    .copied()
 725                    .unwrap_or(0)
 726                    > 1;
 727
 728                let name = if should_prefix {
 729                    format!("{}.{}", context_server_prompt.server_id, prompt.name)
 730                } else {
 731                    prompt.name.clone()
 732                };
 733
 734                let mut command = acp::AvailableCommand::new(
 735                    name,
 736                    prompt.description.clone().unwrap_or_default(),
 737                );
 738
 739                match prompt.arguments.as_deref() {
 740                    Some([arg]) => {
 741                        let hint = format!("<{}>", arg.name);
 742
 743                        command = command.input(acp::AvailableCommandInput::Unstructured(
 744                            acp::UnstructuredCommandInput::new(hint),
 745                        ));
 746                    }
 747                    Some([]) | None => {}
 748                    Some(_) => {
 749                        // skip >1 argument commands since we don't support them yet
 750                        return None;
 751                    }
 752                }
 753
 754                Some(command)
 755            })
 756            .collect()
 757    }
 758
 759    pub fn load_thread(
 760        &mut self,
 761        id: acp::SessionId,
 762        cx: &mut Context<Self>,
 763    ) -> Task<Result<Entity<Thread>>> {
 764        let database_future = ThreadsDatabase::connect(cx);
 765        cx.spawn(async move |this, cx| {
 766            let database = database_future.await.map_err(|err| anyhow!(err))?;
 767            let db_thread = database
 768                .load_thread(id.clone())
 769                .await?
 770                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 771
 772            this.update(cx, |this, cx| {
 773                let summarization_model = LanguageModelRegistry::read_global(cx)
 774                    .thread_summary_model()
 775                    .map(|c| c.model);
 776
 777                cx.new(|cx| {
 778                    let mut thread = Thread::from_db(
 779                        id.clone(),
 780                        db_thread,
 781                        this.project.clone(),
 782                        this.project_context.clone(),
 783                        this.context_server_registry.clone(),
 784                        this.templates.clone(),
 785                        cx,
 786                    );
 787                    thread.set_summarization_model(summarization_model, cx);
 788                    thread
 789                })
 790            })
 791        })
 792    }
 793
 794    pub fn open_thread(
 795        &mut self,
 796        id: acp::SessionId,
 797        cx: &mut Context<Self>,
 798    ) -> Task<Result<Entity<AcpThread>>> {
 799        let task = self.load_thread(id, cx);
 800        cx.spawn(async move |this, cx| {
 801            let thread = task.await?;
 802            let acp_thread =
 803                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
 804            let events = thread.update(cx, |thread, cx| thread.replay(cx));
 805            cx.update(|cx| {
 806                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 807            })
 808            .await?;
 809            Ok(acp_thread)
 810        })
 811    }
 812
 813    pub fn thread_summary(
 814        &mut self,
 815        id: acp::SessionId,
 816        cx: &mut Context<Self>,
 817    ) -> Task<Result<SharedString>> {
 818        let thread = self.open_thread(id.clone(), cx);
 819        cx.spawn(async move |this, cx| {
 820            let acp_thread = thread.await?;
 821            let result = this
 822                .update(cx, |this, cx| {
 823                    this.sessions
 824                        .get(&id)
 825                        .unwrap()
 826                        .thread
 827                        .update(cx, |thread, cx| thread.summary(cx))
 828                })?
 829                .await
 830                .context("Failed to generate summary")?;
 831            drop(acp_thread);
 832            Ok(result)
 833        })
 834    }
 835
 836    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 837        if thread.read(cx).is_empty() {
 838            return;
 839        }
 840
 841        let database_future = ThreadsDatabase::connect(cx);
 842        let (id, db_thread) =
 843            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
 844        let Some(session) = self.sessions.get_mut(&id) else {
 845            return;
 846        };
 847        let thread_store = self.thread_store.clone();
 848        session.pending_save = cx.spawn(async move |_, cx| {
 849            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 850                return;
 851            };
 852            let db_thread = db_thread.await;
 853            database.save_thread(id, db_thread).await.log_err();
 854            thread_store.update(cx, |store, cx| store.reload(cx));
 855        });
 856    }
 857
    /// Runs MCP prompt `prompt_name` from context server `server_id` in
    /// session `session_id`, then starts a model turn from the result.
    ///
    /// Steps, in order:
    /// 1. Fetch the prompt via `crate::get_prompt`, passing `arguments`.
    /// 2. Push `original_content` into the native [`Thread`] as a user
    ///    message with `message_id`. The first block is skipped because it is
    ///    the text the caller parsed as the slash command.
    /// 3. Mirror each prompt message into both the [`AcpThread`] and the
    ///    native `Thread`. Each message is added as user or assistant
    ///    content according to its role.
    /// 4. If the last prompt message came from the user, call
    ///    `send_existing`. Do the same if there were no messages. Otherwise
    ///    call `resume`. The resulting events are forwarded to the
    ///    `AcpThread`.
    ///
    /// # Errors
    /// Fails if any of these happen:
    /// - the prompt cannot be fetched;
    /// - the session is unknown;
    /// - this agent or the `AcpThread` has been dropped;
    /// - the turn fails to start or errors while streaming.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: agent_client_protocol::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let server_store = self.context_server_registry.read(cx).server_store().clone();
        let path_style = self.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Role of the most recent prompt message. It decides below whether
            // to send the pending user message or resume the assistant turn.
            // It stays `true` when the prompt has no messages.
            let mut last_is_user = true;

            // Record the user's own content, minus the leading slash-command block.
            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        // One fresh id shared by the AcpThread entry and the
                        // native Thread entry for this message.
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        })?;

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        // NOTE(review): the positional flags presumably mean
                        // "not a thought" (`false`) and "indented" (`true`),
                        // matching the user branch. Confirm against AcpThread.
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        })?;

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(response_stream, acp_thread, cx)
            })
            .await
        })
    }
 949}
 950
 951/// Wrapper struct that implements the AgentConnection trait
 952#[derive(Clone)]
 953pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 954
 955impl NativeAgentConnection {
 956    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 957        self.0
 958            .read(cx)
 959            .sessions
 960            .get(session_id)
 961            .map(|session| session.thread.clone())
 962    }
 963
 964    pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
 965        self.0.update(cx, |this, cx| this.load_thread(id, cx))
 966    }
 967
 968    fn run_turn(
 969        &self,
 970        session_id: acp::SessionId,
 971        cx: &mut App,
 972        f: impl 'static
 973        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 974    ) -> Task<Result<acp::PromptResponse>> {
 975        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 976            agent
 977                .sessions
 978                .get_mut(&session_id)
 979                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 980        }) else {
 981            return Task::ready(Err(anyhow!("Session not found")));
 982        };
 983        log::debug!("Found session for: {}", session_id);
 984
 985        let response_stream = match f(thread, cx) {
 986            Ok(stream) => stream,
 987            Err(err) => return Task::ready(Err(err)),
 988        };
 989        Self::handle_thread_events(response_stream, acp_thread, cx)
 990    }
 991
 992    fn handle_thread_events(
 993        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
 994        acp_thread: WeakEntity<AcpThread>,
 995        cx: &App,
 996    ) -> Task<Result<acp::PromptResponse>> {
 997        cx.spawn(async move |cx| {
 998            // Handle response stream and forward to session.acp_thread
 999            while let Some(result) = events.next().await {
1000                match result {
1001                    Ok(event) => {
1002                        log::trace!("Received completion event: {:?}", event);
1003
1004                        match event {
1005                            ThreadEvent::UserMessage(message) => {
1006                                acp_thread.update(cx, |thread, cx| {
1007                                    for content in message.content {
1008                                        thread.push_user_content_block(
1009                                            Some(message.id.clone()),
1010                                            content.into(),
1011                                            cx,
1012                                        );
1013                                    }
1014                                })?;
1015                            }
1016                            ThreadEvent::AgentText(text) => {
1017                                acp_thread.update(cx, |thread, cx| {
1018                                    thread.push_assistant_content_block(text.into(), false, cx)
1019                                })?;
1020                            }
1021                            ThreadEvent::AgentThinking(text) => {
1022                                acp_thread.update(cx, |thread, cx| {
1023                                    thread.push_assistant_content_block(text.into(), true, cx)
1024                                })?;
1025                            }
1026                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1027                                tool_call,
1028                                options,
1029                                response,
1030                                context: _,
1031                            }) => {
1032                                let outcome_task = acp_thread.update(cx, |thread, cx| {
1033                                    thread.request_tool_call_authorization(
1034                                        tool_call, options, true, cx,
1035                                    )
1036                                })??;
1037                                cx.background_spawn(async move {
1038                                    if let acp::RequestPermissionOutcome::Selected(
1039                                        acp::SelectedPermissionOutcome { option_id, .. },
1040                                    ) = outcome_task.await
1041                                    {
1042                                        response
1043                                            .send(option_id)
1044                                            .map(|_| anyhow!("authorization receiver was dropped"))
1045                                            .log_err();
1046                                    }
1047                                })
1048                                .detach();
1049                            }
1050                            ThreadEvent::ToolCall(tool_call) => {
1051                                acp_thread.update(cx, |thread, cx| {
1052                                    thread.upsert_tool_call(tool_call, cx)
1053                                })??;
1054                            }
1055                            ThreadEvent::ToolCallUpdate(update) => {
1056                                acp_thread.update(cx, |thread, cx| {
1057                                    thread.update_tool_call(update, cx)
1058                                })??;
1059                            }
1060                            ThreadEvent::Retry(status) => {
1061                                acp_thread.update(cx, |thread, cx| {
1062                                    thread.update_retry_status(status, cx)
1063                                })?;
1064                            }
1065                            ThreadEvent::Stop(stop_reason) => {
1066                                log::debug!("Assistant message complete: {:?}", stop_reason);
1067                                return Ok(acp::PromptResponse::new(stop_reason));
1068                            }
1069                        }
1070                    }
1071                    Err(e) => {
1072                        log::error!("Error in model response stream: {:?}", e);
1073                        return Err(e);
1074                    }
1075                }
1076            }
1077
1078            log::debug!("Response stream completed");
1079            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1080        })
1081    }
1082}
1083
1084struct Command<'a> {
1085    prompt_name: &'a str,
1086    arg_value: &'a str,
1087    explicit_server_id: Option<&'a str>,
1088}
1089
1090impl<'a> Command<'a> {
1091    fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1092        let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1093            return None;
1094        };
1095        let text = text_content.text.trim();
1096        let command = text.strip_prefix('/')?;
1097        let (command, arg_value) = command
1098            .split_once(char::is_whitespace)
1099            .unwrap_or((command, ""));
1100
1101        if let Some((server_id, prompt_name)) = command.split_once('.') {
1102            Some(Self {
1103                prompt_name,
1104                arg_value,
1105                explicit_server_id: Some(server_id),
1106            })
1107        } else {
1108            Some(Self {
1109                prompt_name: command,
1110                arg_value,
1111                explicit_server_id: None,
1112            })
1113        }
1114    }
1115}
1116
1117struct NativeAgentModelSelector {
1118    session_id: acp::SessionId,
1119    connection: NativeAgentConnection,
1120}
1121
1122impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1123    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1124        log::debug!("NativeAgentConnection::list_models called");
1125        let list = self.connection.0.read(cx).models.model_list.clone();
1126        Task::ready(if list.is_empty() {
1127            Err(anyhow::anyhow!("No models available"))
1128        } else {
1129            Ok(list)
1130        })
1131    }
1132
1133    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1134        log::debug!(
1135            "Setting model for session {}: {}",
1136            self.session_id,
1137            model_id
1138        );
1139        let Some(thread) = self
1140            .connection
1141            .0
1142            .read(cx)
1143            .sessions
1144            .get(&self.session_id)
1145            .map(|session| session.thread.clone())
1146        else {
1147            return Task::ready(Err(anyhow!("Session not found")));
1148        };
1149
1150        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1151            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1152        };
1153
1154        // We want to reset the effort level when switching models, as the currently-selected effort level may
1155        // not be compatible.
1156        let effort = model
1157            .default_effort_level()
1158            .map(|effort_level| effort_level.value.to_string());
1159
1160        thread.update(cx, |thread, cx| {
1161            thread.set_model(model.clone(), cx);
1162            thread.set_thinking_effort(effort.clone(), cx);
1163            thread.set_thinking_enabled(model.supports_thinking(), cx);
1164        });
1165
1166        update_settings_file(
1167            self.connection.0.read(cx).fs.clone(),
1168            cx,
1169            move |settings, cx| {
1170                let provider = model.provider_id().0.to_string();
1171                let model = model.id().0.to_string();
1172                let enable_thinking = thread.read(cx).thinking_enabled();
1173                settings
1174                    .agent
1175                    .get_or_insert_default()
1176                    .set_model(LanguageModelSelection {
1177                        provider: provider.into(),
1178                        model,
1179                        enable_thinking,
1180                        effort,
1181                    });
1182            },
1183        );
1184
1185        Task::ready(Ok(()))
1186    }
1187
1188    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1189        let Some(thread) = self
1190            .connection
1191            .0
1192            .read(cx)
1193            .sessions
1194            .get(&self.session_id)
1195            .map(|session| session.thread.clone())
1196        else {
1197            return Task::ready(Err(anyhow!("Session not found")));
1198        };
1199        let Some(model) = thread.read(cx).model() else {
1200            return Task::ready(Err(anyhow!("Model not found")));
1201        };
1202        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1203        else {
1204            return Task::ready(Err(anyhow!("Provider not found")));
1205        };
1206        Task::ready(Ok(LanguageModels::map_language_model_to_info(
1207            model, &provider,
1208        )))
1209    }
1210
1211    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1212        Some(self.connection.0.read(cx).models.watch())
1213    }
1214
1215    fn should_render_footer(&self) -> bool {
1216        true
1217    }
1218}
1219
1220impl acp_thread::AgentConnection for NativeAgentConnection {
1221    fn telemetry_id(&self) -> SharedString {
1222        "zed".into()
1223    }
1224
1225    fn new_thread(
1226        self: Rc<Self>,
1227        project: Entity<Project>,
1228        cwd: &Path,
1229        cx: &mut App,
1230    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1231        log::debug!("Creating new thread for project at: {cwd:?}");
1232        Task::ready(Ok(self
1233            .0
1234            .update(cx, |agent, cx| agent.new_session(project, cx))))
1235    }
1236
1237    fn supports_load_session(&self, _cx: &App) -> bool {
1238        true
1239    }
1240
1241    fn load_session(
1242        self: Rc<Self>,
1243        session: AgentSessionInfo,
1244        _project: Entity<Project>,
1245        _cwd: &Path,
1246        cx: &mut App,
1247    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1248        self.0
1249            .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1250    }
1251
1252    fn auth_methods(&self) -> &[acp::AuthMethod] {
1253        &[] // No auth for in-process
1254    }
1255
1256    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1257        Task::ready(Ok(()))
1258    }
1259
1260    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1261        Some(Rc::new(NativeAgentModelSelector {
1262            session_id: session_id.clone(),
1263            connection: self.clone(),
1264        }) as Rc<dyn AgentModelSelector>)
1265    }
1266
1267    fn prompt(
1268        &self,
1269        id: Option<acp_thread::UserMessageId>,
1270        params: acp::PromptRequest,
1271        cx: &mut App,
1272    ) -> Task<Result<acp::PromptResponse>> {
1273        let id = id.expect("UserMessageId is required");
1274        let session_id = params.session_id.clone();
1275        log::info!("Received prompt request for session: {}", session_id);
1276        log::debug!("Prompt blocks count: {}", params.prompt.len());
1277
1278        if let Some(parsed_command) = Command::parse(&params.prompt) {
1279            let registry = self.0.read(cx).context_server_registry.read(cx);
1280
1281            let explicit_server_id = parsed_command
1282                .explicit_server_id
1283                .map(|server_id| ContextServerId(server_id.into()));
1284
1285            if let Some(prompt) =
1286                registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1287            {
1288                let arguments = if !parsed_command.arg_value.is_empty()
1289                    && let Some(arg_name) = prompt
1290                        .prompt
1291                        .arguments
1292                        .as_ref()
1293                        .and_then(|args| args.first())
1294                        .map(|arg| arg.name.clone())
1295                {
1296                    HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1297                } else {
1298                    Default::default()
1299                };
1300
1301                let prompt_name = prompt.prompt.name.clone();
1302                let server_id = prompt.server_id.clone();
1303
1304                return self.0.update(cx, |agent, cx| {
1305                    agent.send_mcp_prompt(
1306                        id,
1307                        session_id.clone(),
1308                        prompt_name,
1309                        server_id,
1310                        arguments,
1311                        params.prompt,
1312                        cx,
1313                    )
1314                });
1315            };
1316        };
1317
1318        let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1319
1320        self.run_turn(session_id, cx, move |thread, cx| {
1321            let content: Vec<UserMessageContent> = params
1322                .prompt
1323                .into_iter()
1324                .map(|block| UserMessageContent::from_content_block(block, path_style))
1325                .collect::<Vec<_>>();
1326            log::debug!("Converted prompt to message: {} chars", content.len());
1327            log::debug!("Message id: {:?}", id);
1328            log::debug!("Message content: {:?}", content);
1329
1330            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1331        })
1332    }
1333
1334    fn retry(
1335        &self,
1336        session_id: &acp::SessionId,
1337        _cx: &App,
1338    ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1339        Some(Rc::new(NativeAgentSessionRetry {
1340            connection: self.clone(),
1341            session_id: session_id.clone(),
1342        }) as _)
1343    }
1344
1345    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1346        log::info!("Cancelling on session: {}", session_id);
1347        self.0.update(cx, |agent, cx| {
1348            if let Some(agent) = agent.sessions.get(session_id) {
1349                agent
1350                    .thread
1351                    .update(cx, |thread, cx| thread.cancel(cx))
1352                    .detach();
1353            }
1354        });
1355    }
1356
1357    fn truncate(
1358        &self,
1359        session_id: &agent_client_protocol::SessionId,
1360        cx: &App,
1361    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1362        self.0.read_with(cx, |agent, _cx| {
1363            agent.sessions.get(session_id).map(|session| {
1364                Rc::new(NativeAgentSessionTruncate {
1365                    thread: session.thread.clone(),
1366                    acp_thread: session.acp_thread.clone(),
1367                }) as _
1368            })
1369        })
1370    }
1371
1372    fn set_title(
1373        &self,
1374        session_id: &acp::SessionId,
1375        _cx: &App,
1376    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1377        Some(Rc::new(NativeAgentSessionSetTitle {
1378            connection: self.clone(),
1379            session_id: session_id.clone(),
1380        }) as _)
1381    }
1382
1383    fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1384        let thread_store = self.0.read(cx).thread_store.clone();
1385        Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1386    }
1387
1388    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1389        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1390    }
1391
1392    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1393        self
1394    }
1395}
1396
1397impl acp_thread::AgentTelemetry for NativeAgentConnection {
1398    fn thread_data(
1399        &self,
1400        session_id: &acp::SessionId,
1401        cx: &mut App,
1402    ) -> Task<Result<serde_json::Value>> {
1403        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1404            return Task::ready(Err(anyhow!("Session not found")));
1405        };
1406
1407        let task = session.thread.read(cx).to_db(cx);
1408        cx.background_spawn(async move {
1409            serde_json::to_value(task.await).context("Failed to serialize thread")
1410        })
1411    }
1412}
1413
1414pub struct NativeAgentSessionList {
1415    thread_store: Entity<ThreadStore>,
1416    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1417    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1418    _subscription: Subscription,
1419}
1420
1421impl NativeAgentSessionList {
1422    fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1423        let (tx, rx) = smol::channel::unbounded();
1424        let this_tx = tx.clone();
1425        let subscription = cx.observe(&thread_store, move |_, _| {
1426            this_tx
1427                .try_send(acp_thread::SessionListUpdate::Refresh)
1428                .ok();
1429        });
1430        Self {
1431            thread_store,
1432            updates_tx: tx,
1433            updates_rx: rx,
1434            _subscription: subscription,
1435        }
1436    }
1437
1438    fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1439        AgentSessionInfo {
1440            session_id: entry.id,
1441            cwd: None,
1442            title: Some(entry.title),
1443            updated_at: Some(entry.updated_at),
1444            meta: None,
1445        }
1446    }
1447
1448    pub fn thread_store(&self) -> &Entity<ThreadStore> {
1449        &self.thread_store
1450    }
1451}
1452
1453impl AgentSessionList for NativeAgentSessionList {
1454    fn list_sessions(
1455        &self,
1456        _request: AgentSessionListRequest,
1457        cx: &mut App,
1458    ) -> Task<Result<AgentSessionListResponse>> {
1459        let sessions = self
1460            .thread_store
1461            .read(cx)
1462            .entries()
1463            .map(Self::to_session_info)
1464            .collect();
1465        Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1466    }
1467
1468    fn supports_delete(&self) -> bool {
1469        true
1470    }
1471
1472    fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1473        self.thread_store
1474            .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1475    }
1476
1477    fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1478        self.thread_store
1479            .update(cx, |store, cx| store.delete_threads(cx))
1480    }
1481
1482    fn watch(
1483        &self,
1484        _cx: &mut App,
1485    ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1486        Some(self.updates_rx.clone())
1487    }
1488
1489    fn notify_refresh(&self) {
1490        self.updates_tx
1491            .try_send(acp_thread::SessionListUpdate::Refresh)
1492            .ok();
1493    }
1494
1495    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1496        self
1497    }
1498}
1499
1500struct NativeAgentSessionTruncate {
1501    thread: Entity<Thread>,
1502    acp_thread: WeakEntity<AcpThread>,
1503}
1504
1505impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1506    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1507        match self.thread.update(cx, |thread, cx| {
1508            thread.truncate(message_id.clone(), cx)?;
1509            Ok(thread.latest_token_usage())
1510        }) {
1511            Ok(usage) => {
1512                self.acp_thread
1513                    .update(cx, |thread, cx| {
1514                        thread.update_token_usage(usage, cx);
1515                    })
1516                    .ok();
1517                Task::ready(Ok(()))
1518            }
1519            Err(error) => Task::ready(Err(error)),
1520        }
1521    }
1522}
1523
1524struct NativeAgentSessionRetry {
1525    connection: NativeAgentConnection,
1526    session_id: acp::SessionId,
1527}
1528
1529impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1530    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1531        self.connection
1532            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1533                thread.update(cx, |thread, cx| thread.resume(cx))
1534            })
1535    }
1536}
1537
1538struct NativeAgentSessionSetTitle {
1539    connection: NativeAgentConnection,
1540    session_id: acp::SessionId,
1541}
1542
1543impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1544    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1545        let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1546            return Task::ready(Err(anyhow!("session not found")));
1547        };
1548        let thread = session.thread.clone();
1549        thread.update(cx, |thread, cx| thread.set_title(title, cx));
1550        Task::ready(Ok(()))
1551    }
1552}
1553
1554pub struct AcpThreadEnvironment {
1555    acp_thread: WeakEntity<AcpThread>,
1556}
1557
1558impl ThreadEnvironment for AcpThreadEnvironment {
1559    fn create_terminal(
1560        &self,
1561        command: String,
1562        cwd: Option<PathBuf>,
1563        output_byte_limit: Option<u64>,
1564        cx: &mut AsyncApp,
1565    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1566        let task = self.acp_thread.update(cx, |thread, cx| {
1567            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1568        });
1569
1570        let acp_thread = self.acp_thread.clone();
1571        cx.spawn(async move |cx| {
1572            let terminal = task?.await?;
1573
1574            let (drop_tx, drop_rx) = oneshot::channel();
1575            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1576
1577            cx.spawn(async move |cx| {
1578                drop_rx.await.ok();
1579                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1580            })
1581            .detach();
1582
1583            let handle = AcpTerminalHandle {
1584                terminal,
1585                _drop_tx: Some(drop_tx),
1586            };
1587
1588            Ok(Rc::new(handle) as _)
1589        })
1590    }
1591}
1592
1593pub struct AcpTerminalHandle {
1594    terminal: Entity<acp_thread::Terminal>,
1595    _drop_tx: Option<oneshot::Sender<()>>,
1596}
1597
1598impl TerminalHandle for AcpTerminalHandle {
1599    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1600        Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1601    }
1602
1603    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1604        Ok(self
1605            .terminal
1606            .read_with(cx, |term, _cx| term.wait_for_exit()))
1607    }
1608
1609    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1610        Ok(self
1611            .terminal
1612            .read_with(cx, |term, cx| term.current_output(cx)))
1613    }
1614
1615    fn kill(&self, cx: &AsyncApp) -> Result<()> {
1616        cx.update(|cx| {
1617            self.terminal.update(cx, |terminal, cx| {
1618                terminal.kill(cx);
1619            });
1620        });
1621        Ok(())
1622    }
1623
1624    fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1625        Ok(self
1626            .terminal
1627            .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1628    }
1629}
1630
1631#[cfg(test)]
1632mod internal_tests {
1633    use super::*;
1634    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1635    use fs::FakeFs;
1636    use gpui::TestAppContext;
1637    use indoc::formatdoc;
1638    use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1639    use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1640    use serde_json::json;
1641    use settings::SettingsStore;
1642    use util::{path, rel_path::rel_path};
1643
    // Checks that `NativeAgent::project_context` tracks the project:
    // - it starts with no worktrees;
    // - it gains a `WorktreeContext` for `/a` once that worktree is created;
    // - it picks up a `.rules` file that is created in `/a` afterwards.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        // The project is created with no worktrees.
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        // Adding a worktree for `/a` should yield a matching context entry with no rules file.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        // Let pending tasks settle before inspecting the agent's context.
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            // The expected `project_entry_id` comes from the worktree's entry for `.rules`.
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }
1709
    // Checks that the model selector of a fresh session lists models grouped by
    // provider. The test expects exactly one group, "Fake", containing the
    // `fake/fake` model. That model is presumably provided by the
    // `LanguageModelRegistry::test` setup in `init_test`.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                thread_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Ask the session's model selector for the available models.
        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // The native agent is expected to return a grouped list.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                    is_latest: false,
                }]
            )])
        );
    }
1769
    // Checks that selecting a model through the ACP model selector does two things:
    // - it updates the session's thread;
    // - it rewrites `agent.default_model` in the user settings file, replacing a
    //   pre-existing `foo`/`bar` value.
    // It also checks that selecting a thinking-capable model persists
    // `enable_thinking: true` to the settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Seed the settings file with a different default model so the overwrite is observable.
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let pending tasks (presumably including the settings write) finish before reading the file.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );

        // Register a thinking model and select it.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so the newly registered provider can be selected.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();
        cx.run_until_parked();

        // Verify enable_thinking was written to settings as true.
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
        assert_eq!(
            settings_json["agent"]["default_model"]["enable_thinking"],
            json!(true),
            "selecting a thinking model should persist enable_thinking: true to settings"
        );
    }
1885
    // Checks that `select_model` toggles the thread's `thinking_enabled` flag.
    // It turns on when a model created with thinking support is selected, and
    // turns back off when switching to `fake/fake`. The flag starts out false.
    #[gpui::test]
    async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Start from an empty settings file.
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Register a second provider with a thinking model.
        cx.update(|cx| {
            let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
                "fake-corp",
                "fake-thinking",
                "Fake Thinking",
                true,
            ));
            let thinking_provider = Arc::new(
                FakeLanguageModelProvider::new(
                    LanguageModelProviderId::from("fake-corp".to_string()),
                    LanguageModelProviderName::from("Fake Corp".to_string()),
                )
                .with_models(vec![thinking_model]),
            );
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so it picks up the new provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Thread starts with thinking_enabled = false (the default).
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(!thread.thinking_enabled(), "thinking defaults to false");
            });
        });

        // Select the thinking model via select_model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // select_model should have enabled thinking based on the model's supports_thinking().
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    thread.thinking_enabled(),
                    "select_model should enable thinking when model supports it"
                );
            });
        });

        // Switch back to the non-thinking model.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
            .await
            .unwrap();

        // select_model should have disabled thinking.
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert!(
                    !thread.thinking_enabled(),
                    "select_model should disable thinking when model does not support it"
                );
            });
        });
    }
1981
    // Checks that `thinking_enabled` survives a round trip through the thread store.
    // The test enables thinking by selecting a thinking model, sends one message
    // so the thread is persisted, and drops every handle so the session goes away.
    // It then reopens the thread with `open_thread` and asserts the flag is still set.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        // A clone of `thinking_model` is kept so the test can drive its completion stream later.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted.
        // The send future runs on the foreground executor while the fake model's
        // stream is fed and closed below.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Drop the thread so it can be reloaded from disk.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        // With every handle gone, the agent should have no live sessions.
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        drop(reloaded_acp_thread);
    }
2086
    // Checks that a thread's selected model survives a round trip through the
    // thread store, even when the model's id differs from its display name.
    // The test selects the model, persists the thread by sending one message,
    // drops the session, reopens it with `open_thread`, and expects the same
    // model id rather than a fallback default.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // The selector id has the form "<provider id>/<model id>".
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Feed and close the fake model's response stream so the send can complete.
        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Drop the thread so it can be reloaded from disk.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        drop(reloaded_acp_thread);
    }
2196
    // End-to-end persistence test for the native agent. It checks that:
    // - an empty thread is not written to the thread store, even after its
    //   models are changed;
    // - after one exchange, the thread is stored under a title taken from the
    //   summarization model's output;
    // - dropping the ACP thread removes the session;
    // - reopening the thread reproduces the same markdown transcript.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a prompt that mentions `/a/b.md` via a resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Finish the main response first, then the summary; the summary text
        // becomes the stored thread title (asserted below).
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        // The reloaded transcript must match what was rendered before the drop.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
2334
2335    fn thread_entries(
2336        thread_store: &Entity<ThreadStore>,
2337        cx: &mut TestAppContext,
2338    ) -> Vec<(acp::SessionId, String)> {
2339        thread_store.read_with(cx, |store, _| {
2340            store
2341                .entries()
2342                .map(|entry| (entry.id.clone(), entry.title.to_string()))
2343                .collect::<Vec<_>>()
2344        })
2345    }
2346
2347    fn init_test(cx: &mut TestAppContext) {
2348        env_logger::try_init().ok();
2349        cx.update(|cx| {
2350            let settings_store = SettingsStore::test(cx);
2351            cx.set_global(settings_store);
2352
2353            LanguageModelRegistry::test(cx);
2354        });
2355    }
2356}
2357
2358fn mcp_message_content_to_acp_content_block(
2359    content: context_server::types::MessageContent,
2360) -> acp::ContentBlock {
2361    match content {
2362        context_server::types::MessageContent::Text {
2363            text,
2364            annotations: _,
2365        } => text.into(),
2366        context_server::types::MessageContent::Image {
2367            data,
2368            mime_type,
2369            annotations: _,
2370        } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2371        context_server::types::MessageContent::Audio {
2372            data,
2373            mime_type,
2374            annotations: _,
2375        } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2376        context_server::types::MessageContent::Resource {
2377            resource,
2378            annotations: _,
2379        } => {
2380            let mut link =
2381                acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2382            if let Some(mime_type) = resource.mime_type {
2383                link = link.mime_type(mime_type);
2384            }
2385            acp::ContentBlock::ResourceLink(link)
2386        }
2387    }
2388}