agent.rs

   1use crate::{
   2    ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
   3    UserMessageContent, templates::Templates,
   4};
   5use crate::{HistoryStore, TerminalHandle, ThreadEnvironment, TitleUpdated, TokenUsageUpdated};
   6use acp_thread::{AcpThread, AgentModelSelector};
   7use action_log::ActionLog;
   8use agent_client_protocol as acp;
   9use anyhow::{Context as _, Result, anyhow};
  10use collections::{HashSet, IndexMap};
  11use fs::Fs;
  12use futures::channel::{mpsc, oneshot};
  13use futures::future::Shared;
  14use futures::{StreamExt, future};
  15use gpui::{
  16    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  17};
  18use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  19use project::{Project, ProjectItem, ProjectPath, Worktree};
  20use prompt_store::{
  21    ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
  22};
  23use settings::{LanguageModelSelection, update_settings_file};
  24use std::any::Any;
  25use std::collections::HashMap;
  26use std::path::{Path, PathBuf};
  27use std::rc::Rc;
  28use std::sync::Arc;
  29use util::ResultExt;
  30
/// Candidate rules-file names, relative to a worktree root, in priority order.
///
/// `load_worktree_rules_file` picks the first of these that exists as a file in
/// the worktree, and `handle_project_event` requests a project-context refresh
/// whenever an updated worktree entry has one of these paths.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
  42
/// Error describing a worktree rules file that was found but could not be loaded.
///
/// NOTE(review): `load_worktree_info_for_system_prompt` produces this, but
/// `build_project_context` currently discards it; nothing in this section shows
/// it to the user yet.
pub struct RulesLoadingError {
    /// The underlying error, formatted with `Display`.
    pub message: SharedString,
}
  46
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Task for the most recent save of `thread`; `save_thread` replaces it on
    /// every save.
    pending_save: Task<()>,
    /// Subscriptions created in `register_session`: removal of the session when
    /// the `AcpThread` is released, title/token-usage forwarding, and saving on
    /// thread notifications.
    _subscriptions: Vec<Subscription>,
}
  56
/// Cache of the language models offered by authenticated providers.
pub struct LanguageModels {
    /// Access language model by ID
    ///
    /// IDs have the form `"{provider_id}/{model_id}"` (see `model_id`).
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    ///
    /// Grouped as "Recommended" first (when non-empty), then one group per
    /// provider containing its non-recommended models.
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out (cloned) by `watch()`.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled at the end of every `refresh_list`.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating all registered providers; owned here so
    /// it lives as long as this value.
    _authenticate_all_providers_task: Task<()>,
}
  66
  67impl LanguageModels {
  68    fn new(cx: &mut App) -> Self {
  69        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  70
  71        let mut this = Self {
  72            models: HashMap::default(),
  73            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  74            refresh_models_rx,
  75            refresh_models_tx,
  76            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  77        };
  78        this.refresh_list(cx);
  79        this
  80    }
  81
  82    fn refresh_list(&mut self, cx: &App) {
  83        let providers = LanguageModelRegistry::global(cx)
  84            .read(cx)
  85            .providers()
  86            .into_iter()
  87            .filter(|provider| provider.is_authenticated(cx))
  88            .collect::<Vec<_>>();
  89
  90        let mut language_model_list = IndexMap::default();
  91        let mut recommended_models = HashSet::default();
  92
  93        let mut recommended = Vec::new();
  94        for provider in &providers {
  95            for model in provider.recommended_models(cx) {
  96                recommended_models.insert((model.provider_id(), model.id()));
  97                recommended.push(Self::map_language_model_to_info(&model, provider));
  98            }
  99        }
 100        if !recommended.is_empty() {
 101            language_model_list.insert(
 102                acp_thread::AgentModelGroupName("Recommended".into()),
 103                recommended,
 104            );
 105        }
 106
 107        let mut models = HashMap::default();
 108        for provider in providers {
 109            let mut provider_models = Vec::new();
 110            for model in provider.provided_models(cx) {
 111                let model_info = Self::map_language_model_to_info(&model, &provider);
 112                let model_id = model_info.id.clone();
 113                if !recommended_models.contains(&(model.provider_id(), model.id())) {
 114                    provider_models.push(model_info);
 115                }
 116                models.insert(model_id, model);
 117            }
 118            if !provider_models.is_empty() {
 119                language_model_list.insert(
 120                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 121                    provider_models,
 122                );
 123            }
 124        }
 125
 126        self.models = models;
 127        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 128        self.refresh_models_tx.send(()).ok();
 129    }
 130
 131    fn watch(&self) -> watch::Receiver<()> {
 132        self.refresh_models_rx.clone()
 133    }
 134
 135    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 136        self.models.get(model_id).cloned()
 137    }
 138
 139    fn map_language_model_to_info(
 140        model: &Arc<dyn LanguageModel>,
 141        provider: &Arc<dyn LanguageModelProvider>,
 142    ) -> acp_thread::AgentModelInfo {
 143        acp_thread::AgentModelInfo {
 144            id: Self::model_id(model),
 145            name: model.name().0,
 146            description: None,
 147            icon: Some(provider.icon()),
 148        }
 149    }
 150
 151    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 152        acp::ModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
 153    }
 154
 155    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 156        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 157            .read(cx)
 158            .providers()
 159            .iter()
 160            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 161            .collect::<Vec<_>>();
 162
 163        cx.background_spawn(async move {
 164            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 165                if let Err(err) = authenticate_task.await {
 166                    match err {
 167                        language_model::AuthenticateError::CredentialsNotFound => {
 168                            // Since we're authenticating these providers in the
 169                            // background for the purposes of populating the
 170                            // language selector, we don't care about providers
 171                            // where the credentials are not found.
 172                        }
 173                        language_model::AuthenticateError::ConnectionRefused => {
 174                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 175                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 176                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 177                        }
 178                        _ => {
 179                            // Some providers have noisy failure states that we
 180                            // don't want to spam the logs with every time the
 181                            // language model selector is initialized.
 182                            //
 183                            // Ideally these should have more clear failure modes
 184                            // that we know are safe to ignore here, like what we do
 185                            // with `CredentialsNotFound` above.
 186                            match provider_id.0.as_ref() {
 187                                "lmstudio" | "ollama" => {
 188                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 189                                    //
 190                                    // These fail noisily, so we don't log them.
 191                                }
 192                                "copilot_chat" => {
 193                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 194                                }
 195                                _ => {
 196                                    log::error!(
 197                                        "Failed to authenticate provider: {}: {err}",
 198                                        provider_name.0
 199                                    );
 200                                }
 201                            }
 202                        }
 203                    }
 204                }
 205            }
 206        })
 207    }
 208}
 209
/// The native agent: tracks open sessions and owns the state shared by their
/// threads (project context, context servers, templates and language models).
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread history, reloaded after every thread save.
    history: Entity<HistoryStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Sending on this channel asks `maintain_project_context` to rebuild
    /// `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context`, owned by the agent.
    _maintain_project_context: Task<Result<()>>,
    /// Context-server registry handed to each thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project whose worktrees feed the project context.
    project: Entity<Project>,
    /// Optional source of the default user rules included in the project context.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle.
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, language-model-registry and prompt-store events.
    _subscriptions: Vec<Subscription>,
}
 228
 229impl NativeAgent {
 230    pub async fn new(
 231        project: Entity<Project>,
 232        history: Entity<HistoryStore>,
 233        templates: Arc<Templates>,
 234        prompt_store: Option<Entity<PromptStore>>,
 235        fs: Arc<dyn Fs>,
 236        cx: &mut AsyncApp,
 237    ) -> Result<Entity<NativeAgent>> {
 238        log::debug!("Creating new NativeAgent");
 239
 240        let project_context = cx
 241            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 242            .await;
 243
 244        cx.new(|cx| {
 245            let mut subscriptions = vec![
 246                cx.subscribe(&project, Self::handle_project_event),
 247                cx.subscribe(
 248                    &LanguageModelRegistry::global(cx),
 249                    Self::handle_models_updated_event,
 250                ),
 251            ];
 252            if let Some(prompt_store) = prompt_store.as_ref() {
 253                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 254            }
 255
 256            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 257                watch::channel(());
 258            Self {
 259                sessions: HashMap::new(),
 260                history,
 261                project_context: cx.new(|_| project_context),
 262                project_context_needs_refresh: project_context_needs_refresh_tx,
 263                _maintain_project_context: cx.spawn(async move |this, cx| {
 264                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 265                }),
 266                context_server_registry: cx.new(|cx| {
 267                    ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
 268                }),
 269                templates,
 270                models: LanguageModels::new(cx),
 271                project,
 272                prompt_store,
 273                fs,
 274                _subscriptions: subscriptions,
 275            }
 276        })
 277    }
 278
 279    fn register_session(
 280        &mut self,
 281        thread_handle: Entity<Thread>,
 282        cx: &mut Context<Self>,
 283    ) -> Entity<AcpThread> {
 284        let connection = Rc::new(NativeAgentConnection(cx.entity()));
 285
 286        let thread = thread_handle.read(cx);
 287        let session_id = thread.id().clone();
 288        let title = thread.title();
 289        let project = thread.project.clone();
 290        let action_log = thread.action_log.clone();
 291        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
 292        let acp_thread = cx.new(|cx| {
 293            acp_thread::AcpThread::new(
 294                title,
 295                connection,
 296                project.clone(),
 297                action_log.clone(),
 298                session_id.clone(),
 299                prompt_capabilities_rx,
 300                cx,
 301            )
 302        });
 303
 304        let registry = LanguageModelRegistry::read_global(cx);
 305        let summarization_model = registry.thread_summary_model().map(|c| c.model);
 306
 307        thread_handle.update(cx, |thread, cx| {
 308            thread.set_summarization_model(summarization_model, cx);
 309            thread.add_default_tools(
 310                Rc::new(AcpThreadEnvironment {
 311                    acp_thread: acp_thread.downgrade(),
 312                }) as _,
 313                cx,
 314            )
 315        });
 316
 317        let subscriptions = vec![
 318            cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
 319                this.sessions.remove(acp_thread.session_id());
 320            }),
 321            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
 322            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
 323            cx.observe(&thread_handle, move |this, thread, cx| {
 324                this.save_thread(thread, cx)
 325            }),
 326        ];
 327
 328        self.sessions.insert(
 329            session_id,
 330            Session {
 331                thread: thread_handle,
 332                acp_thread: acp_thread.downgrade(),
 333                _subscriptions: subscriptions,
 334                pending_save: Task::ready(()),
 335            },
 336        );
 337        acp_thread
 338    }
 339
 340    pub fn models(&self) -> &LanguageModels {
 341        &self.models
 342    }
 343
 344    async fn maintain_project_context(
 345        this: WeakEntity<Self>,
 346        mut needs_refresh: watch::Receiver<()>,
 347        cx: &mut AsyncApp,
 348    ) -> Result<()> {
 349        while needs_refresh.changed().await.is_ok() {
 350            let project_context = this
 351                .update(cx, |this, cx| {
 352                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 353                })?
 354                .await;
 355            this.update(cx, |this, cx| {
 356                this.project_context = cx.new(|_| project_context);
 357            })?;
 358        }
 359
 360        Ok(())
 361    }
 362
 363    fn build_project_context(
 364        project: &Entity<Project>,
 365        prompt_store: Option<&Entity<PromptStore>>,
 366        cx: &mut App,
 367    ) -> Task<ProjectContext> {
 368        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 369        let worktree_tasks = worktrees
 370            .into_iter()
 371            .map(|worktree| {
 372                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 373            })
 374            .collect::<Vec<_>>();
 375        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 376            prompt_store.read_with(cx, |prompt_store, cx| {
 377                let prompts = prompt_store.default_prompt_metadata();
 378                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 379                    let contents = prompt_store.load(prompt_metadata.id, cx);
 380                    async move { (contents.await, prompt_metadata) }
 381                });
 382                cx.background_spawn(future::join_all(load_tasks))
 383            })
 384        } else {
 385            Task::ready(vec![])
 386        };
 387
 388        cx.spawn(async move |_cx| {
 389            let (worktrees, default_user_rules) =
 390                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 391
 392            let worktrees = worktrees
 393                .into_iter()
 394                .map(|(worktree, _rules_error)| {
 395                    // TODO: show error message
 396                    // if let Some(rules_error) = rules_error {
 397                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 398                    // }
 399                    worktree
 400                })
 401                .collect::<Vec<_>>();
 402
 403            let default_user_rules = default_user_rules
 404                .into_iter()
 405                .flat_map(|(contents, prompt_metadata)| match contents {
 406                    Ok(contents) => Some(UserRulesContext {
 407                        uuid: match prompt_metadata.id {
 408                            PromptId::User { uuid } => uuid,
 409                            PromptId::EditWorkflow => return None,
 410                        },
 411                        title: prompt_metadata.title.map(|title| title.to_string()),
 412                        contents,
 413                    }),
 414                    Err(_err) => {
 415                        // TODO: show error message
 416                        // this.update(cx, |_, cx| {
 417                        //     cx.emit(RulesLoadingError {
 418                        //         message: format!("{err:?}").into(),
 419                        //     });
 420                        // })
 421                        // .ok();
 422                        None
 423                    }
 424                })
 425                .collect::<Vec<_>>();
 426
 427            ProjectContext::new(worktrees, default_user_rules)
 428        })
 429    }
 430
 431    fn load_worktree_info_for_system_prompt(
 432        worktree: Entity<Worktree>,
 433        project: Entity<Project>,
 434        cx: &mut App,
 435    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 436        let tree = worktree.read(cx);
 437        let root_name = tree.root_name().into();
 438        let abs_path = tree.abs_path();
 439
 440        let mut context = WorktreeContext {
 441            root_name,
 442            abs_path,
 443            rules_file: None,
 444        };
 445
 446        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 447        let Some(rules_task) = rules_task else {
 448            return Task::ready((context, None));
 449        };
 450
 451        cx.spawn(async move |_| {
 452            let (rules_file, rules_file_error) = match rules_task.await {
 453                Ok(rules_file) => (Some(rules_file), None),
 454                Err(err) => (
 455                    None,
 456                    Some(RulesLoadingError {
 457                        message: format!("{err}").into(),
 458                    }),
 459                ),
 460            };
 461            context.rules_file = rules_file;
 462            (context, rules_file_error)
 463        })
 464    }
 465
 466    fn load_worktree_rules_file(
 467        worktree: Entity<Worktree>,
 468        project: Entity<Project>,
 469        cx: &mut App,
 470    ) -> Option<Task<Result<RulesFileContext>>> {
 471        let worktree = worktree.read(cx);
 472        let worktree_id = worktree.id();
 473        let selected_rules_file = RULES_FILE_NAMES
 474            .into_iter()
 475            .filter_map(|name| {
 476                worktree
 477                    .entry_for_path(name)
 478                    .filter(|entry| entry.is_file())
 479                    .map(|entry| entry.path.clone())
 480            })
 481            .next();
 482
 483        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 484        // supported. This doesn't seem to occur often in GitHub repositories.
 485        selected_rules_file.map(|path_in_worktree| {
 486            let project_path = ProjectPath {
 487                worktree_id,
 488                path: path_in_worktree.clone(),
 489            };
 490            let buffer_task =
 491                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 492            let rope_task = cx.spawn(async move |cx| {
 493                buffer_task.await?.read_with(cx, |buffer, cx| {
 494                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 495                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 496                })?
 497            });
 498            // Build a string from the rope on a background thread.
 499            cx.background_spawn(async move {
 500                let (project_entry_id, rope) = rope_task.await?;
 501                anyhow::Ok(RulesFileContext {
 502                    path_in_worktree,
 503                    text: rope.to_string().trim().to_string(),
 504                    project_entry_id: project_entry_id.to_usize(),
 505                })
 506            })
 507        })
 508    }
 509
 510    fn handle_thread_title_updated(
 511        &mut self,
 512        thread: Entity<Thread>,
 513        _: &TitleUpdated,
 514        cx: &mut Context<Self>,
 515    ) {
 516        let session_id = thread.read(cx).id();
 517        let Some(session) = self.sessions.get(session_id) else {
 518            return;
 519        };
 520        let thread = thread.downgrade();
 521        let acp_thread = session.acp_thread.clone();
 522        cx.spawn(async move |_, cx| {
 523            let title = thread.read_with(cx, |thread, _| thread.title())?;
 524            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 525            task.await
 526        })
 527        .detach_and_log_err(cx);
 528    }
 529
 530    fn handle_thread_token_usage_updated(
 531        &mut self,
 532        thread: Entity<Thread>,
 533        usage: &TokenUsageUpdated,
 534        cx: &mut Context<Self>,
 535    ) {
 536        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 537            return;
 538        };
 539        session
 540            .acp_thread
 541            .update(cx, |acp_thread, cx| {
 542                acp_thread.update_token_usage(usage.0.clone(), cx);
 543            })
 544            .ok();
 545    }
 546
 547    fn handle_project_event(
 548        &mut self,
 549        _project: Entity<Project>,
 550        event: &project::Event,
 551        _cx: &mut Context<Self>,
 552    ) {
 553        match event {
 554            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 555                self.project_context_needs_refresh.send(()).ok();
 556            }
 557            project::Event::WorktreeUpdatedEntries(_, items) => {
 558                if items.iter().any(|(path, _, _)| {
 559                    RULES_FILE_NAMES
 560                        .iter()
 561                        .any(|name| path.as_ref() == Path::new(name))
 562                }) {
 563                    self.project_context_needs_refresh.send(()).ok();
 564                }
 565            }
 566            _ => {}
 567        }
 568    }
 569
 570    fn handle_prompts_updated_event(
 571        &mut self,
 572        _prompt_store: Entity<PromptStore>,
 573        _event: &prompt_store::PromptsUpdatedEvent,
 574        _cx: &mut Context<Self>,
 575    ) {
 576        self.project_context_needs_refresh.send(()).ok();
 577    }
 578
 579    fn handle_models_updated_event(
 580        &mut self,
 581        _registry: Entity<LanguageModelRegistry>,
 582        _event: &language_model::Event,
 583        cx: &mut Context<Self>,
 584    ) {
 585        self.models.refresh_list(cx);
 586
 587        let registry = LanguageModelRegistry::read_global(cx);
 588        let default_model = registry.default_model().map(|m| m.model);
 589        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 590
 591        for session in self.sessions.values_mut() {
 592            session.thread.update(cx, |thread, cx| {
 593                if thread.model().is_none()
 594                    && let Some(model) = default_model.clone()
 595                {
 596                    thread.set_model(model, cx);
 597                    cx.notify();
 598                }
 599                thread.set_summarization_model(summarization_model.clone(), cx);
 600            });
 601        }
 602    }
 603
 604    pub fn open_thread(
 605        &mut self,
 606        id: acp::SessionId,
 607        cx: &mut Context<Self>,
 608    ) -> Task<Result<Entity<AcpThread>>> {
 609        let database_future = ThreadsDatabase::connect(cx);
 610        cx.spawn(async move |this, cx| {
 611            let database = database_future.await.map_err(|err| anyhow!(err))?;
 612            let db_thread = database
 613                .load_thread(id.clone())
 614                .await?
 615                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 616
 617            let thread = this.update(cx, |this, cx| {
 618                let action_log = cx.new(|_cx| ActionLog::new(this.project.clone()));
 619                cx.new(|cx| {
 620                    Thread::from_db(
 621                        id.clone(),
 622                        db_thread,
 623                        this.project.clone(),
 624                        this.project_context.clone(),
 625                        this.context_server_registry.clone(),
 626                        action_log.clone(),
 627                        this.templates.clone(),
 628                        cx,
 629                    )
 630                })
 631            })?;
 632            let acp_thread =
 633                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
 634            let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
 635            cx.update(|cx| {
 636                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 637            })?
 638            .await?;
 639            Ok(acp_thread)
 640        })
 641    }
 642
 643    pub fn thread_summary(
 644        &mut self,
 645        id: acp::SessionId,
 646        cx: &mut Context<Self>,
 647    ) -> Task<Result<SharedString>> {
 648        let thread = self.open_thread(id.clone(), cx);
 649        cx.spawn(async move |this, cx| {
 650            let acp_thread = thread.await?;
 651            let result = this
 652                .update(cx, |this, cx| {
 653                    this.sessions
 654                        .get(&id)
 655                        .unwrap()
 656                        .thread
 657                        .update(cx, |thread, cx| thread.summary(cx))
 658                })?
 659                .await?;
 660            drop(acp_thread);
 661            Ok(result)
 662        })
 663    }
 664
 665    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 666        if thread.read(cx).is_empty() {
 667            return;
 668        }
 669
 670        let database_future = ThreadsDatabase::connect(cx);
 671        let (id, db_thread) =
 672            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
 673        let Some(session) = self.sessions.get_mut(&id) else {
 674            return;
 675        };
 676        let history = self.history.clone();
 677        session.pending_save = cx.spawn(async move |_, cx| {
 678            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 679                return;
 680            };
 681            let db_thread = db_thread.await;
 682            database.save_thread(id, db_thread).await.log_err();
 683            history.update(cx, |history, cx| history.reload(cx)).ok();
 684        });
 685    }
 686}
 687
/// Wrapper struct that implements the AgentConnection trait
///
/// Wraps an `Entity<NativeAgent>`, so clones refer to the same agent.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 691
 692impl NativeAgentConnection {
 693    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 694        self.0
 695            .read(cx)
 696            .sessions
 697            .get(session_id)
 698            .map(|session| session.thread.clone())
 699    }
 700
 701    fn run_turn(
 702        &self,
 703        session_id: acp::SessionId,
 704        cx: &mut App,
 705        f: impl 'static
 706        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 707    ) -> Task<Result<acp::PromptResponse>> {
 708        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 709            agent
 710                .sessions
 711                .get_mut(&session_id)
 712                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 713        }) else {
 714            return Task::ready(Err(anyhow!("Session not found")));
 715        };
 716        log::debug!("Found session for: {}", session_id);
 717
 718        let response_stream = match f(thread, cx) {
 719            Ok(stream) => stream,
 720            Err(err) => return Task::ready(Err(err)),
 721        };
 722        Self::handle_thread_events(response_stream, acp_thread, cx)
 723    }
 724
 725    fn handle_thread_events(
 726        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
 727        acp_thread: WeakEntity<AcpThread>,
 728        cx: &App,
 729    ) -> Task<Result<acp::PromptResponse>> {
 730        cx.spawn(async move |cx| {
 731            // Handle response stream and forward to session.acp_thread
 732            while let Some(result) = events.next().await {
 733                match result {
 734                    Ok(event) => {
 735                        log::trace!("Received completion event: {:?}", event);
 736
 737                        match event {
 738                            ThreadEvent::UserMessage(message) => {
 739                                acp_thread.update(cx, |thread, cx| {
 740                                    for content in message.content {
 741                                        thread.push_user_content_block(
 742                                            Some(message.id.clone()),
 743                                            content.into(),
 744                                            cx,
 745                                        );
 746                                    }
 747                                })?;
 748                            }
 749                            ThreadEvent::AgentText(text) => {
 750                                acp_thread.update(cx, |thread, cx| {
 751                                    thread.push_assistant_content_block(
 752                                        acp::ContentBlock::Text(acp::TextContent {
 753                                            text,
 754                                            annotations: None,
 755                                            meta: None,
 756                                        }),
 757                                        false,
 758                                        cx,
 759                                    )
 760                                })?;
 761                            }
 762                            ThreadEvent::AgentThinking(text) => {
 763                                acp_thread.update(cx, |thread, cx| {
 764                                    thread.push_assistant_content_block(
 765                                        acp::ContentBlock::Text(acp::TextContent {
 766                                            text,
 767                                            annotations: None,
 768                                            meta: None,
 769                                        }),
 770                                        true,
 771                                        cx,
 772                                    )
 773                                })?;
 774                            }
 775                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
 776                                tool_call,
 777                                options,
 778                                response,
 779                            }) => {
 780                                let outcome_task = acp_thread.update(cx, |thread, cx| {
 781                                    thread.request_tool_call_authorization(
 782                                        tool_call, options, true, cx,
 783                                    )
 784                                })??;
 785                                cx.background_spawn(async move {
 786                                    if let acp::RequestPermissionOutcome::Selected { option_id } =
 787                                        outcome_task.await
 788                                    {
 789                                        response
 790                                            .send(option_id)
 791                                            .map(|_| anyhow!("authorization receiver was dropped"))
 792                                            .log_err();
 793                                    }
 794                                })
 795                                .detach();
 796                            }
 797                            ThreadEvent::ToolCall(tool_call) => {
 798                                acp_thread.update(cx, |thread, cx| {
 799                                    thread.upsert_tool_call(tool_call, cx)
 800                                })??;
 801                            }
 802                            ThreadEvent::ToolCallUpdate(update) => {
 803                                acp_thread.update(cx, |thread, cx| {
 804                                    thread.update_tool_call(update, cx)
 805                                })??;
 806                            }
 807                            ThreadEvent::Retry(status) => {
 808                                acp_thread.update(cx, |thread, cx| {
 809                                    thread.update_retry_status(status, cx)
 810                                })?;
 811                            }
 812                            ThreadEvent::Stop(stop_reason) => {
 813                                log::debug!("Assistant message complete: {:?}", stop_reason);
 814                                return Ok(acp::PromptResponse {
 815                                    stop_reason,
 816                                    meta: None,
 817                                });
 818                            }
 819                        }
 820                    }
 821                    Err(e) => {
 822                        log::error!("Error in model response stream: {:?}", e);
 823                        return Err(e);
 824                    }
 825                }
 826            }
 827
 828            log::debug!("Response stream completed");
 829            anyhow::Ok(acp::PromptResponse {
 830                stop_reason: acp::StopReason::EndTurn,
 831                meta: None,
 832            })
 833        })
 834    }
 835}
 836
 837struct NativeAgentModelSelector {
 838    session_id: acp::SessionId,
 839    connection: NativeAgentConnection,
 840}
 841
 842impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
 843    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
 844        log::debug!("NativeAgentConnection::list_models called");
 845        let list = self.connection.0.read(cx).models.model_list.clone();
 846        Task::ready(if list.is_empty() {
 847            Err(anyhow::anyhow!("No models available"))
 848        } else {
 849            Ok(list)
 850        })
 851    }
 852
 853    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
 854        log::debug!(
 855            "Setting model for session {}: {}",
 856            self.session_id,
 857            model_id
 858        );
 859        let Some(thread) = self
 860            .connection
 861            .0
 862            .read(cx)
 863            .sessions
 864            .get(&self.session_id)
 865            .map(|session| session.thread.clone())
 866        else {
 867            return Task::ready(Err(anyhow!("Session not found")));
 868        };
 869
 870        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
 871            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
 872        };
 873
 874        thread.update(cx, |thread, cx| {
 875            thread.set_model(model.clone(), cx);
 876        });
 877
 878        update_settings_file(
 879            self.connection.0.read(cx).fs.clone(),
 880            cx,
 881            move |settings, _cx| {
 882                let provider = model.provider_id().0.to_string();
 883                let model = model.id().0.to_string();
 884                settings
 885                    .agent
 886                    .get_or_insert_default()
 887                    .set_model(LanguageModelSelection {
 888                        provider: provider.into(),
 889                        model,
 890                    });
 891            },
 892        );
 893
 894        Task::ready(Ok(()))
 895    }
 896
 897    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
 898        let Some(thread) = self
 899            .connection
 900            .0
 901            .read(cx)
 902            .sessions
 903            .get(&self.session_id)
 904            .map(|session| session.thread.clone())
 905        else {
 906            return Task::ready(Err(anyhow!("Session not found")));
 907        };
 908        let Some(model) = thread.read(cx).model() else {
 909            return Task::ready(Err(anyhow!("Model not found")));
 910        };
 911        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
 912        else {
 913            return Task::ready(Err(anyhow!("Provider not found")));
 914        };
 915        Task::ready(Ok(LanguageModels::map_language_model_to_info(
 916            model, &provider,
 917        )))
 918    }
 919
 920    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
 921        Some(self.connection.0.read(cx).models.watch())
 922    }
 923}
 924
 925impl acp_thread::AgentConnection for NativeAgentConnection {
 926    fn new_thread(
 927        self: Rc<Self>,
 928        project: Entity<Project>,
 929        cwd: &Path,
 930        cx: &mut App,
 931    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
 932        let agent = self.0.clone();
 933        log::debug!("Creating new thread for project at: {:?}", cwd);
 934
 935        cx.spawn(async move |cx| {
 936            log::debug!("Starting thread creation in async context");
 937
 938            // Create Thread
 939            let thread = agent.update(
 940                cx,
 941                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
 942                    // Fetch default model from registry settings
 943                    let registry = LanguageModelRegistry::read_global(cx);
 944                    // Log available models for debugging
 945                    let available_count = registry.available_models(cx).count();
 946                    log::debug!("Total available models: {}", available_count);
 947
 948                    let default_model = registry.default_model().and_then(|default_model| {
 949                        agent
 950                            .models
 951                            .model_from_id(&LanguageModels::model_id(&default_model.model))
 952                    });
 953                    Ok(cx.new(|cx| {
 954                        Thread::new(
 955                            project.clone(),
 956                            agent.project_context.clone(),
 957                            agent.context_server_registry.clone(),
 958                            agent.templates.clone(),
 959                            default_model,
 960                            cx,
 961                        )
 962                    }))
 963                },
 964            )??;
 965            agent.update(cx, |agent, cx| agent.register_session(thread, cx))
 966        })
 967    }
 968
 969    fn auth_methods(&self) -> &[acp::AuthMethod] {
 970        &[] // No auth for in-process
 971    }
 972
 973    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 974        Task::ready(Ok(()))
 975    }
 976
 977    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
 978        Some(Rc::new(NativeAgentModelSelector {
 979            session_id: session_id.clone(),
 980            connection: self.clone(),
 981        }) as Rc<dyn AgentModelSelector>)
 982    }
 983
 984    fn prompt(
 985        &self,
 986        id: Option<acp_thread::UserMessageId>,
 987        params: acp::PromptRequest,
 988        cx: &mut App,
 989    ) -> Task<Result<acp::PromptResponse>> {
 990        let id = id.expect("UserMessageId is required");
 991        let session_id = params.session_id.clone();
 992        log::info!("Received prompt request for session: {}", session_id);
 993        log::debug!("Prompt blocks count: {}", params.prompt.len());
 994
 995        self.run_turn(session_id, cx, |thread, cx| {
 996            let content: Vec<UserMessageContent> = params
 997                .prompt
 998                .into_iter()
 999                .map(Into::into)
1000                .collect::<Vec<_>>();
1001            log::debug!("Converted prompt to message: {} chars", content.len());
1002            log::debug!("Message id: {:?}", id);
1003            log::debug!("Message content: {:?}", content);
1004
1005            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1006        })
1007    }
1008
1009    fn resume(
1010        &self,
1011        session_id: &acp::SessionId,
1012        _cx: &App,
1013    ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1014        Some(Rc::new(NativeAgentSessionResume {
1015            connection: self.clone(),
1016            session_id: session_id.clone(),
1017        }) as _)
1018    }
1019
1020    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1021        log::info!("Cancelling on session: {}", session_id);
1022        self.0.update(cx, |agent, cx| {
1023            if let Some(agent) = agent.sessions.get(session_id) {
1024                agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1025            }
1026        });
1027    }
1028
1029    fn truncate(
1030        &self,
1031        session_id: &agent_client_protocol::SessionId,
1032        cx: &App,
1033    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1034        self.0.read_with(cx, |agent, _cx| {
1035            agent.sessions.get(session_id).map(|session| {
1036                Rc::new(NativeAgentSessionTruncate {
1037                    thread: session.thread.clone(),
1038                    acp_thread: session.acp_thread.clone(),
1039                }) as _
1040            })
1041        })
1042    }
1043
1044    fn set_title(
1045        &self,
1046        session_id: &acp::SessionId,
1047        _cx: &App,
1048    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1049        Some(Rc::new(NativeAgentSessionSetTitle {
1050            connection: self.clone(),
1051            session_id: session_id.clone(),
1052        }) as _)
1053    }
1054
1055    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1056        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1057    }
1058
1059    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1060        self
1061    }
1062}
1063
1064impl acp_thread::AgentTelemetry for NativeAgentConnection {
1065    fn agent_name(&self) -> String {
1066        "Zed".into()
1067    }
1068
1069    fn thread_data(
1070        &self,
1071        session_id: &acp::SessionId,
1072        cx: &mut App,
1073    ) -> Task<Result<serde_json::Value>> {
1074        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1075            return Task::ready(Err(anyhow!("Session not found")));
1076        };
1077
1078        let task = session.thread.read(cx).to_db(cx);
1079        cx.background_spawn(async move {
1080            serde_json::to_value(task.await).context("Failed to serialize thread")
1081        })
1082    }
1083}
1084
1085struct NativeAgentSessionTruncate {
1086    thread: Entity<Thread>,
1087    acp_thread: WeakEntity<AcpThread>,
1088}
1089
1090impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1091    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1092        match self.thread.update(cx, |thread, cx| {
1093            thread.truncate(message_id.clone(), cx)?;
1094            Ok(thread.latest_token_usage())
1095        }) {
1096            Ok(usage) => {
1097                self.acp_thread
1098                    .update(cx, |thread, cx| {
1099                        thread.update_token_usage(usage, cx);
1100                    })
1101                    .ok();
1102                Task::ready(Ok(()))
1103            }
1104            Err(error) => Task::ready(Err(error)),
1105        }
1106    }
1107}
1108
1109struct NativeAgentSessionResume {
1110    connection: NativeAgentConnection,
1111    session_id: acp::SessionId,
1112}
1113
1114impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1115    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1116        self.connection
1117            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1118                thread.update(cx, |thread, cx| thread.resume(cx))
1119            })
1120    }
1121}
1122
1123struct NativeAgentSessionSetTitle {
1124    connection: NativeAgentConnection,
1125    session_id: acp::SessionId,
1126}
1127
1128impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1129    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1130        let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1131            return Task::ready(Err(anyhow!("session not found")));
1132        };
1133        let thread = session.thread.clone();
1134        thread.update(cx, |thread, cx| thread.set_title(title, cx));
1135        Task::ready(Ok(()))
1136    }
1137}
1138
1139pub struct AcpThreadEnvironment {
1140    acp_thread: WeakEntity<AcpThread>,
1141}
1142
1143impl ThreadEnvironment for AcpThreadEnvironment {
1144    fn create_terminal(
1145        &self,
1146        command: String,
1147        cwd: Option<PathBuf>,
1148        output_byte_limit: Option<u64>,
1149        cx: &mut AsyncApp,
1150    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1151        let task = self.acp_thread.update(cx, |thread, cx| {
1152            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1153        });
1154
1155        let acp_thread = self.acp_thread.clone();
1156        cx.spawn(async move |cx| {
1157            let terminal = task?.await?;
1158
1159            let (drop_tx, drop_rx) = oneshot::channel();
1160            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1161
1162            cx.spawn(async move |cx| {
1163                drop_rx.await.ok();
1164                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1165            })
1166            .detach();
1167
1168            let handle = AcpTerminalHandle {
1169                terminal,
1170                _drop_tx: Some(drop_tx),
1171            };
1172
1173            Ok(Rc::new(handle) as _)
1174        })
1175    }
1176}
1177
1178pub struct AcpTerminalHandle {
1179    terminal: Entity<acp_thread::Terminal>,
1180    _drop_tx: Option<oneshot::Sender<()>>,
1181}
1182
1183impl TerminalHandle for AcpTerminalHandle {
1184    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1185        self.terminal.read_with(cx, |term, _cx| term.id().clone())
1186    }
1187
1188    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1189        self.terminal
1190            .read_with(cx, |term, _cx| term.wait_for_exit())
1191    }
1192
1193    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1194        self.terminal
1195            .read_with(cx, |term, cx| term.current_output(cx))
1196    }
1197}
1198
#[cfg(test)]
mod tests {
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::indoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// The agent's `project_context` starts with no worktrees, then reflects a worktree
    /// added to the project, and then picks up a `.rules` file created in that worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// A session's model selector lists the test registry's models grouped by provider.
    /// Here that is a single "Fake" group containing `fake/fake`.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId("fake/fake".into()),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Selecting a model through the selector updates the session's thread and rewrites
    /// `agent.default_model` in the settings file, replacing the pre-existing `foo`/`bar` entry.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId("fake/fake".into());
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // The settings write happens asynchronously; let it finish before reading the file.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// End-to-end persistence check:
    /// - an untouched thread is not saved;
    /// - after one turn, dropping the ACP thread drops the session;
    /// - history holds one entry titled by the summarization model;
    /// - reopening the thread reproduces the same markdown.
    #[gpui::test]
    #[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
                        name: "b.md".into(),
                        uri: MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                        annotations: None,
                        description: None,
                        mime_type: None,
                        size: None,
                        title: None,
                        meta: None,
                    }),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Drive the fake completion, then the fake summarization that produces the title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model.send_last_completion_stream_text_chunk("Explaining /a/b.md");
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    What does [@b.md](file:///a/b.md) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                "Explaining /a/b.md".into()
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    What does [@b.md](file:///a/b.md) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    /// Collects `(id, title)` pairs for every entry currently in the history store.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    /// Shared test setup.
    ///
    /// Installs a test `SettingsStore` and initializes project, agent and language settings.
    /// It also sets up the test language-model registry, which presumably supplies the
    /// "Fake" provider asserted above (confirm in `LanguageModelRegistry::test`).
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}