agent.rs

   1use crate::{
   2    ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
   3    UserMessageContent, templates::Templates,
   4};
   5use crate::{HistoryStore, TerminalHandle, ThreadEnvironment, TitleUpdated, TokenUsageUpdated};
   6use acp_thread::{AcpThread, AgentModelSelector};
   7use action_log::ActionLog;
   8use agent_client_protocol as acp;
   9use anyhow::{Context as _, Result, anyhow};
  10use collections::{HashSet, IndexMap};
  11use fs::Fs;
  12use futures::channel::{mpsc, oneshot};
  13use futures::future::Shared;
  14use futures::{StreamExt, future};
  15use gpui::{
  16    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  17};
  18use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  19use project::{Project, ProjectItem, ProjectPath, Worktree};
  20use prompt_store::{
  21    ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
  22};
  23use settings::{LanguageModelSelection, update_settings_file};
  24use std::any::Any;
  25use std::collections::HashMap;
  26use std::path::{Path, PathBuf};
  27use std::rc::Rc;
  28use std::sync::{Arc, LazyLock};
  29use util::ResultExt;
  30use util::rel_path::RelPath;
  31
  32static RULES_FILE_NAMES: LazyLock<[&RelPath; 9]> = LazyLock::new(|| {
  33    [
  34        RelPath::unix(".rules").unwrap(),
  35        RelPath::unix(".cursorrules").unwrap(),
  36        RelPath::unix(".windsurfrules").unwrap(),
  37        RelPath::unix(".clinerules").unwrap(),
  38        RelPath::unix(".github/copilot-instructions.md").unwrap(),
  39        RelPath::unix("CLAUDE.md").unwrap(),
  40        RelPath::unix("AGENT.md").unwrap(),
  41        RelPath::unix("AGENTS.md").unwrap(),
  42        RelPath::unix("GEMINI.md").unwrap(),
  43    ]
  44});
  45
/// Failure to load a worktree's rules file while building the project context.
pub struct RulesLoadingError {
    /// Human-readable description of the underlying error (its `Display` output).
    pub message: SharedString,
}
  49
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    /// Held weakly: the session is removed from `NativeAgent::sessions` when it is released.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Most recent database-save task for `thread`; replaced on every save.
    pending_save: Task<()>,
    /// Observers/subscriptions wiring this session's thread to the agent; dropped with the session.
    _subscriptions: Vec<Subscription>,
}
  59
/// Cache of the language models exposed by authenticated providers.
pub struct LanguageModels {
    /// Access language model by ID (IDs have the form `provider_id/model_id`)
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out by `watch()`; observes every list rebuild.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled at the end of every `refresh_list` call.
    refresh_models_tx: watch::Sender<()>,
    /// Background authentication of all registered providers, stored so it stays alive.
    _authenticate_all_providers_task: Task<()>,
}
  69
  70impl LanguageModels {
  71    fn new(cx: &mut App) -> Self {
  72        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  73
  74        let mut this = Self {
  75            models: HashMap::default(),
  76            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  77            refresh_models_rx,
  78            refresh_models_tx,
  79            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  80        };
  81        this.refresh_list(cx);
  82        this
  83    }
  84
  85    fn refresh_list(&mut self, cx: &App) {
  86        let providers = LanguageModelRegistry::global(cx)
  87            .read(cx)
  88            .providers()
  89            .into_iter()
  90            .filter(|provider| provider.is_authenticated(cx))
  91            .collect::<Vec<_>>();
  92
  93        let mut language_model_list = IndexMap::default();
  94        let mut recommended_models = HashSet::default();
  95
  96        let mut recommended = Vec::new();
  97        for provider in &providers {
  98            for model in provider.recommended_models(cx) {
  99                recommended_models.insert((model.provider_id(), model.id()));
 100                recommended.push(Self::map_language_model_to_info(&model, provider));
 101            }
 102        }
 103        if !recommended.is_empty() {
 104            language_model_list.insert(
 105                acp_thread::AgentModelGroupName("Recommended".into()),
 106                recommended,
 107            );
 108        }
 109
 110        let mut models = HashMap::default();
 111        for provider in providers {
 112            let mut provider_models = Vec::new();
 113            for model in provider.provided_models(cx) {
 114                let model_info = Self::map_language_model_to_info(&model, &provider);
 115                let model_id = model_info.id.clone();
 116                if !recommended_models.contains(&(model.provider_id(), model.id())) {
 117                    provider_models.push(model_info);
 118                }
 119                models.insert(model_id, model);
 120            }
 121            if !provider_models.is_empty() {
 122                language_model_list.insert(
 123                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 124                    provider_models,
 125                );
 126            }
 127        }
 128
 129        self.models = models;
 130        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 131        self.refresh_models_tx.send(()).ok();
 132    }
 133
 134    fn watch(&self) -> watch::Receiver<()> {
 135        self.refresh_models_rx.clone()
 136    }
 137
 138    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 139        self.models.get(model_id).cloned()
 140    }
 141
 142    fn map_language_model_to_info(
 143        model: &Arc<dyn LanguageModel>,
 144        provider: &Arc<dyn LanguageModelProvider>,
 145    ) -> acp_thread::AgentModelInfo {
 146        acp_thread::AgentModelInfo {
 147            id: Self::model_id(model),
 148            name: model.name().0,
 149            description: None,
 150            icon: Some(provider.icon()),
 151        }
 152    }
 153
 154    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 155        acp::ModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
 156    }
 157
 158    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 159        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 160            .read(cx)
 161            .providers()
 162            .iter()
 163            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 164            .collect::<Vec<_>>();
 165
 166        cx.background_spawn(async move {
 167            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 168                if let Err(err) = authenticate_task.await {
 169                    match err {
 170                        language_model::AuthenticateError::CredentialsNotFound => {
 171                            // Since we're authenticating these providers in the
 172                            // background for the purposes of populating the
 173                            // language selector, we don't care about providers
 174                            // where the credentials are not found.
 175                        }
 176                        language_model::AuthenticateError::ConnectionRefused => {
 177                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 178                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 179                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 180                        }
 181                        _ => {
 182                            // Some providers have noisy failure states that we
 183                            // don't want to spam the logs with every time the
 184                            // language model selector is initialized.
 185                            //
 186                            // Ideally these should have more clear failure modes
 187                            // that we know are safe to ignore here, like what we do
 188                            // with `CredentialsNotFound` above.
 189                            match provider_id.0.as_ref() {
 190                                "lmstudio" | "ollama" => {
 191                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 192                                    //
 193                                    // These fail noisily, so we don't log them.
 194                                }
 195                                "copilot_chat" => {
 196                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 197                                }
 198                                _ => {
 199                                    log::error!(
 200                                        "Failed to authenticate provider: {}: {err}",
 201                                        provider_name.0
 202                                    );
 203                                }
 204                            }
 205                        }
 206                    }
 207                }
 208            }
 209        })
 210    }
 211}
 212
/// Agent state shared by every session: the open threads plus the project context,
/// context-server registry, templates and model cache handed to them.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread history; reloaded after every thread save.
    history: Entity<HistoryStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Sending on this channel asks `_maintain_project_context` to rebuild the project context.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background loop that rebuilds the project context on request.
    _maintain_project_context: Task<Result<()>>,
    /// Context-server registry passed to the threads this agent creates.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    project: Entity<Project>,
    /// Optional source of the user's default rules included in the project context.
    prompt_store: Option<Entity<PromptStore>>,
    fs: Arc<dyn Fs>,
    /// Project, model-registry and prompt-store subscriptions, kept alive with the agent.
    _subscriptions: Vec<Subscription>,
}
 231
 232impl NativeAgent {
 233    pub async fn new(
 234        project: Entity<Project>,
 235        history: Entity<HistoryStore>,
 236        templates: Arc<Templates>,
 237        prompt_store: Option<Entity<PromptStore>>,
 238        fs: Arc<dyn Fs>,
 239        cx: &mut AsyncApp,
 240    ) -> Result<Entity<NativeAgent>> {
 241        log::debug!("Creating new NativeAgent");
 242
 243        let project_context = cx
 244            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 245            .await;
 246
 247        cx.new(|cx| {
 248            let mut subscriptions = vec![
 249                cx.subscribe(&project, Self::handle_project_event),
 250                cx.subscribe(
 251                    &LanguageModelRegistry::global(cx),
 252                    Self::handle_models_updated_event,
 253                ),
 254            ];
 255            if let Some(prompt_store) = prompt_store.as_ref() {
 256                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 257            }
 258
 259            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 260                watch::channel(());
 261            Self {
 262                sessions: HashMap::new(),
 263                history,
 264                project_context: cx.new(|_| project_context),
 265                project_context_needs_refresh: project_context_needs_refresh_tx,
 266                _maintain_project_context: cx.spawn(async move |this, cx| {
 267                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 268                }),
 269                context_server_registry: cx.new(|cx| {
 270                    ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
 271                }),
 272                templates,
 273                models: LanguageModels::new(cx),
 274                project,
 275                prompt_store,
 276                fs,
 277                _subscriptions: subscriptions,
 278            }
 279        })
 280    }
 281
 282    fn register_session(
 283        &mut self,
 284        thread_handle: Entity<Thread>,
 285        cx: &mut Context<Self>,
 286    ) -> Entity<AcpThread> {
 287        let connection = Rc::new(NativeAgentConnection(cx.entity()));
 288
 289        let thread = thread_handle.read(cx);
 290        let session_id = thread.id().clone();
 291        let title = thread.title();
 292        let project = thread.project.clone();
 293        let action_log = thread.action_log.clone();
 294        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
 295        let acp_thread = cx.new(|cx| {
 296            acp_thread::AcpThread::new(
 297                title,
 298                connection,
 299                project.clone(),
 300                action_log.clone(),
 301                session_id.clone(),
 302                prompt_capabilities_rx,
 303                cx,
 304            )
 305        });
 306
 307        let registry = LanguageModelRegistry::read_global(cx);
 308        let summarization_model = registry.thread_summary_model().map(|c| c.model);
 309
 310        thread_handle.update(cx, |thread, cx| {
 311            thread.set_summarization_model(summarization_model, cx);
 312            thread.add_default_tools(
 313                Rc::new(AcpThreadEnvironment {
 314                    acp_thread: acp_thread.downgrade(),
 315                }) as _,
 316                cx,
 317            )
 318        });
 319
 320        let subscriptions = vec![
 321            cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
 322                this.sessions.remove(acp_thread.session_id());
 323            }),
 324            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
 325            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
 326            cx.observe(&thread_handle, move |this, thread, cx| {
 327                this.save_thread(thread, cx)
 328            }),
 329        ];
 330
 331        self.sessions.insert(
 332            session_id,
 333            Session {
 334                thread: thread_handle,
 335                acp_thread: acp_thread.downgrade(),
 336                _subscriptions: subscriptions,
 337                pending_save: Task::ready(()),
 338            },
 339        );
 340        acp_thread
 341    }
 342
 343    pub fn models(&self) -> &LanguageModels {
 344        &self.models
 345    }
 346
 347    async fn maintain_project_context(
 348        this: WeakEntity<Self>,
 349        mut needs_refresh: watch::Receiver<()>,
 350        cx: &mut AsyncApp,
 351    ) -> Result<()> {
 352        while needs_refresh.changed().await.is_ok() {
 353            let project_context = this
 354                .update(cx, |this, cx| {
 355                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 356                })?
 357                .await;
 358            this.update(cx, |this, cx| {
 359                this.project_context = cx.new(|_| project_context);
 360            })?;
 361        }
 362
 363        Ok(())
 364    }
 365
 366    fn build_project_context(
 367        project: &Entity<Project>,
 368        prompt_store: Option<&Entity<PromptStore>>,
 369        cx: &mut App,
 370    ) -> Task<ProjectContext> {
 371        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 372        let worktree_tasks = worktrees
 373            .into_iter()
 374            .map(|worktree| {
 375                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 376            })
 377            .collect::<Vec<_>>();
 378        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 379            prompt_store.read_with(cx, |prompt_store, cx| {
 380                let prompts = prompt_store.default_prompt_metadata();
 381                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 382                    let contents = prompt_store.load(prompt_metadata.id, cx);
 383                    async move { (contents.await, prompt_metadata) }
 384                });
 385                cx.background_spawn(future::join_all(load_tasks))
 386            })
 387        } else {
 388            Task::ready(vec![])
 389        };
 390
 391        cx.spawn(async move |_cx| {
 392            let (worktrees, default_user_rules) =
 393                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 394
 395            let worktrees = worktrees
 396                .into_iter()
 397                .map(|(worktree, _rules_error)| {
 398                    // TODO: show error message
 399                    // if let Some(rules_error) = rules_error {
 400                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 401                    // }
 402                    worktree
 403                })
 404                .collect::<Vec<_>>();
 405
 406            let default_user_rules = default_user_rules
 407                .into_iter()
 408                .flat_map(|(contents, prompt_metadata)| match contents {
 409                    Ok(contents) => Some(UserRulesContext {
 410                        uuid: match prompt_metadata.id {
 411                            PromptId::User { uuid } => uuid,
 412                            PromptId::EditWorkflow => return None,
 413                        },
 414                        title: prompt_metadata.title.map(|title| title.to_string()),
 415                        contents,
 416                    }),
 417                    Err(_err) => {
 418                        // TODO: show error message
 419                        // this.update(cx, |_, cx| {
 420                        //     cx.emit(RulesLoadingError {
 421                        //         message: format!("{err:?}").into(),
 422                        //     });
 423                        // })
 424                        // .ok();
 425                        None
 426                    }
 427                })
 428                .collect::<Vec<_>>();
 429
 430            ProjectContext::new(worktrees, default_user_rules)
 431        })
 432    }
 433
 434    fn load_worktree_info_for_system_prompt(
 435        worktree: Entity<Worktree>,
 436        project: Entity<Project>,
 437        cx: &mut App,
 438    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 439        let tree = worktree.read(cx);
 440        let root_name = tree.root_name_str().into();
 441        let abs_path = tree.abs_path();
 442
 443        let mut context = WorktreeContext {
 444            root_name,
 445            abs_path,
 446            rules_file: None,
 447        };
 448
 449        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 450        let Some(rules_task) = rules_task else {
 451            return Task::ready((context, None));
 452        };
 453
 454        cx.spawn(async move |_| {
 455            let (rules_file, rules_file_error) = match rules_task.await {
 456                Ok(rules_file) => (Some(rules_file), None),
 457                Err(err) => (
 458                    None,
 459                    Some(RulesLoadingError {
 460                        message: format!("{err}").into(),
 461                    }),
 462                ),
 463            };
 464            context.rules_file = rules_file;
 465            (context, rules_file_error)
 466        })
 467    }
 468
 469    fn load_worktree_rules_file(
 470        worktree: Entity<Worktree>,
 471        project: Entity<Project>,
 472        cx: &mut App,
 473    ) -> Option<Task<Result<RulesFileContext>>> {
 474        let worktree = worktree.read(cx);
 475        let worktree_id = worktree.id();
 476        let selected_rules_file = RULES_FILE_NAMES
 477            .into_iter()
 478            .filter_map(|name| {
 479                worktree
 480                    .entry_for_path(name)
 481                    .filter(|entry| entry.is_file())
 482                    .map(|entry| entry.path.clone())
 483            })
 484            .next();
 485
 486        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 487        // supported. This doesn't seem to occur often in GitHub repositories.
 488        selected_rules_file.map(|path_in_worktree| {
 489            let project_path = ProjectPath {
 490                worktree_id,
 491                path: path_in_worktree.clone(),
 492            };
 493            let buffer_task =
 494                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 495            let rope_task = cx.spawn(async move |cx| {
 496                buffer_task.await?.read_with(cx, |buffer, cx| {
 497                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 498                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 499                })?
 500            });
 501            // Build a string from the rope on a background thread.
 502            cx.background_spawn(async move {
 503                let (project_entry_id, rope) = rope_task.await?;
 504                anyhow::Ok(RulesFileContext {
 505                    path_in_worktree,
 506                    text: rope.to_string().trim().to_string(),
 507                    project_entry_id: project_entry_id.to_usize(),
 508                })
 509            })
 510        })
 511    }
 512
 513    fn handle_thread_title_updated(
 514        &mut self,
 515        thread: Entity<Thread>,
 516        _: &TitleUpdated,
 517        cx: &mut Context<Self>,
 518    ) {
 519        let session_id = thread.read(cx).id();
 520        let Some(session) = self.sessions.get(session_id) else {
 521            return;
 522        };
 523        let thread = thread.downgrade();
 524        let acp_thread = session.acp_thread.clone();
 525        cx.spawn(async move |_, cx| {
 526            let title = thread.read_with(cx, |thread, _| thread.title())?;
 527            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 528            task.await
 529        })
 530        .detach_and_log_err(cx);
 531    }
 532
 533    fn handle_thread_token_usage_updated(
 534        &mut self,
 535        thread: Entity<Thread>,
 536        usage: &TokenUsageUpdated,
 537        cx: &mut Context<Self>,
 538    ) {
 539        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 540            return;
 541        };
 542        session
 543            .acp_thread
 544            .update(cx, |acp_thread, cx| {
 545                acp_thread.update_token_usage(usage.0.clone(), cx);
 546            })
 547            .ok();
 548    }
 549
 550    fn handle_project_event(
 551        &mut self,
 552        _project: Entity<Project>,
 553        event: &project::Event,
 554        _cx: &mut Context<Self>,
 555    ) {
 556        match event {
 557            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 558                self.project_context_needs_refresh.send(()).ok();
 559            }
 560            project::Event::WorktreeUpdatedEntries(_, items) => {
 561                if items
 562                    .iter()
 563                    .any(|(path, _, _)| RULES_FILE_NAMES.iter().any(|name| path.as_ref() == *name))
 564                {
 565                    self.project_context_needs_refresh.send(()).ok();
 566                }
 567            }
 568            _ => {}
 569        }
 570    }
 571
 572    fn handle_prompts_updated_event(
 573        &mut self,
 574        _prompt_store: Entity<PromptStore>,
 575        _event: &prompt_store::PromptsUpdatedEvent,
 576        _cx: &mut Context<Self>,
 577    ) {
 578        self.project_context_needs_refresh.send(()).ok();
 579    }
 580
 581    fn handle_models_updated_event(
 582        &mut self,
 583        _registry: Entity<LanguageModelRegistry>,
 584        _event: &language_model::Event,
 585        cx: &mut Context<Self>,
 586    ) {
 587        self.models.refresh_list(cx);
 588
 589        let registry = LanguageModelRegistry::read_global(cx);
 590        let default_model = registry.default_model().map(|m| m.model);
 591        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 592
 593        for session in self.sessions.values_mut() {
 594            session.thread.update(cx, |thread, cx| {
 595                if thread.model().is_none()
 596                    && let Some(model) = default_model.clone()
 597                {
 598                    thread.set_model(model, cx);
 599                    cx.notify();
 600                }
 601                thread.set_summarization_model(summarization_model.clone(), cx);
 602            });
 603        }
 604    }
 605
 606    pub fn open_thread(
 607        &mut self,
 608        id: acp::SessionId,
 609        cx: &mut Context<Self>,
 610    ) -> Task<Result<Entity<AcpThread>>> {
 611        let database_future = ThreadsDatabase::connect(cx);
 612        cx.spawn(async move |this, cx| {
 613            let database = database_future.await.map_err(|err| anyhow!(err))?;
 614            let db_thread = database
 615                .load_thread(id.clone())
 616                .await?
 617                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 618
 619            let thread = this.update(cx, |this, cx| {
 620                let action_log = cx.new(|_cx| ActionLog::new(this.project.clone()));
 621                cx.new(|cx| {
 622                    Thread::from_db(
 623                        id.clone(),
 624                        db_thread,
 625                        this.project.clone(),
 626                        this.project_context.clone(),
 627                        this.context_server_registry.clone(),
 628                        action_log.clone(),
 629                        this.templates.clone(),
 630                        cx,
 631                    )
 632                })
 633            })?;
 634            let acp_thread =
 635                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
 636            let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
 637            cx.update(|cx| {
 638                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 639            })?
 640            .await?;
 641            Ok(acp_thread)
 642        })
 643    }
 644
 645    pub fn thread_summary(
 646        &mut self,
 647        id: acp::SessionId,
 648        cx: &mut Context<Self>,
 649    ) -> Task<Result<SharedString>> {
 650        let thread = self.open_thread(id.clone(), cx);
 651        cx.spawn(async move |this, cx| {
 652            let acp_thread = thread.await?;
 653            let result = this
 654                .update(cx, |this, cx| {
 655                    this.sessions
 656                        .get(&id)
 657                        .unwrap()
 658                        .thread
 659                        .update(cx, |thread, cx| thread.summary(cx))
 660                })?
 661                .await?;
 662            drop(acp_thread);
 663            Ok(result)
 664        })
 665    }
 666
 667    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 668        if thread.read(cx).is_empty() {
 669            return;
 670        }
 671
 672        let database_future = ThreadsDatabase::connect(cx);
 673        let (id, db_thread) =
 674            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
 675        let Some(session) = self.sessions.get_mut(&id) else {
 676            return;
 677        };
 678        let history = self.history.clone();
 679        session.pending_save = cx.spawn(async move |_, cx| {
 680            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 681                return;
 682            };
 683            let db_thread = db_thread.await;
 684            database.save_thread(id, db_thread).await.log_err();
 685            history.update(cx, |history, cx| history.reload(cx)).ok();
 686        });
 687    }
 688}
 689
/// Wrapper struct that implements the AgentConnection trait
///
/// Holds a handle to the shared [`NativeAgent`]; `register_session` creates one per session.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 693
 694impl NativeAgentConnection {
 695    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 696        self.0
 697            .read(cx)
 698            .sessions
 699            .get(session_id)
 700            .map(|session| session.thread.clone())
 701    }
 702
 703    fn run_turn(
 704        &self,
 705        session_id: acp::SessionId,
 706        cx: &mut App,
 707        f: impl 'static
 708        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 709    ) -> Task<Result<acp::PromptResponse>> {
 710        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 711            agent
 712                .sessions
 713                .get_mut(&session_id)
 714                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 715        }) else {
 716            return Task::ready(Err(anyhow!("Session not found")));
 717        };
 718        log::debug!("Found session for: {}", session_id);
 719
 720        let response_stream = match f(thread, cx) {
 721            Ok(stream) => stream,
 722            Err(err) => return Task::ready(Err(err)),
 723        };
 724        Self::handle_thread_events(response_stream, acp_thread, cx)
 725    }
 726
 727    fn handle_thread_events(
 728        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
 729        acp_thread: WeakEntity<AcpThread>,
 730        cx: &App,
 731    ) -> Task<Result<acp::PromptResponse>> {
 732        cx.spawn(async move |cx| {
 733            // Handle response stream and forward to session.acp_thread
 734            while let Some(result) = events.next().await {
 735                match result {
 736                    Ok(event) => {
 737                        log::trace!("Received completion event: {:?}", event);
 738
 739                        match event {
 740                            ThreadEvent::UserMessage(message) => {
 741                                acp_thread.update(cx, |thread, cx| {
 742                                    for content in message.content {
 743                                        thread.push_user_content_block(
 744                                            Some(message.id.clone()),
 745                                            content.into(),
 746                                            cx,
 747                                        );
 748                                    }
 749                                })?;
 750                            }
 751                            ThreadEvent::AgentText(text) => {
 752                                acp_thread.update(cx, |thread, cx| {
 753                                    thread.push_assistant_content_block(
 754                                        acp::ContentBlock::Text(acp::TextContent {
 755                                            text,
 756                                            annotations: None,
 757                                            meta: None,
 758                                        }),
 759                                        false,
 760                                        cx,
 761                                    )
 762                                })?;
 763                            }
 764                            ThreadEvent::AgentThinking(text) => {
 765                                acp_thread.update(cx, |thread, cx| {
 766                                    thread.push_assistant_content_block(
 767                                        acp::ContentBlock::Text(acp::TextContent {
 768                                            text,
 769                                            annotations: None,
 770                                            meta: None,
 771                                        }),
 772                                        true,
 773                                        cx,
 774                                    )
 775                                })?;
 776                            }
 777                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
 778                                tool_call,
 779                                options,
 780                                response,
 781                            }) => {
 782                                let outcome_task = acp_thread.update(cx, |thread, cx| {
 783                                    thread.request_tool_call_authorization(
 784                                        tool_call, options, true, cx,
 785                                    )
 786                                })??;
 787                                cx.background_spawn(async move {
 788                                    if let acp::RequestPermissionOutcome::Selected { option_id } =
 789                                        outcome_task.await
 790                                    {
 791                                        response
 792                                            .send(option_id)
 793                                            .map(|_| anyhow!("authorization receiver was dropped"))
 794                                            .log_err();
 795                                    }
 796                                })
 797                                .detach();
 798                            }
 799                            ThreadEvent::ToolCall(tool_call) => {
 800                                acp_thread.update(cx, |thread, cx| {
 801                                    thread.upsert_tool_call(tool_call, cx)
 802                                })??;
 803                            }
 804                            ThreadEvent::ToolCallUpdate(update) => {
 805                                acp_thread.update(cx, |thread, cx| {
 806                                    thread.update_tool_call(update, cx)
 807                                })??;
 808                            }
 809                            ThreadEvent::Retry(status) => {
 810                                acp_thread.update(cx, |thread, cx| {
 811                                    thread.update_retry_status(status, cx)
 812                                })?;
 813                            }
 814                            ThreadEvent::Stop(stop_reason) => {
 815                                log::debug!("Assistant message complete: {:?}", stop_reason);
 816                                return Ok(acp::PromptResponse {
 817                                    stop_reason,
 818                                    meta: None,
 819                                });
 820                            }
 821                        }
 822                    }
 823                    Err(e) => {
 824                        log::error!("Error in model response stream: {:?}", e);
 825                        return Err(e);
 826                    }
 827                }
 828            }
 829
 830            log::debug!("Response stream completed");
 831            anyhow::Ok(acp::PromptResponse {
 832                stop_reason: acp::StopReason::EndTurn,
 833                meta: None,
 834            })
 835        })
 836    }
 837}
 838
/// Per-session [`AgentModelSelector`] returned by `NativeAgentConnection::model_selector`.
struct NativeAgentModelSelector {
    /// Session whose thread's model is read or changed.
    session_id: acp::SessionId,
    /// Connection used to reach the agent's sessions, model list, and filesystem.
    connection: NativeAgentConnection,
}
 843
 844impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
 845    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
 846        log::debug!("NativeAgentConnection::list_models called");
 847        let list = self.connection.0.read(cx).models.model_list.clone();
 848        Task::ready(if list.is_empty() {
 849            Err(anyhow::anyhow!("No models available"))
 850        } else {
 851            Ok(list)
 852        })
 853    }
 854
 855    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
 856        log::debug!(
 857            "Setting model for session {}: {}",
 858            self.session_id,
 859            model_id
 860        );
 861        let Some(thread) = self
 862            .connection
 863            .0
 864            .read(cx)
 865            .sessions
 866            .get(&self.session_id)
 867            .map(|session| session.thread.clone())
 868        else {
 869            return Task::ready(Err(anyhow!("Session not found")));
 870        };
 871
 872        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
 873            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
 874        };
 875
 876        thread.update(cx, |thread, cx| {
 877            thread.set_model(model.clone(), cx);
 878        });
 879
 880        update_settings_file(
 881            self.connection.0.read(cx).fs.clone(),
 882            cx,
 883            move |settings, _cx| {
 884                let provider = model.provider_id().0.to_string();
 885                let model = model.id().0.to_string();
 886                settings
 887                    .agent
 888                    .get_or_insert_default()
 889                    .set_model(LanguageModelSelection {
 890                        provider: provider.into(),
 891                        model,
 892                    });
 893            },
 894        );
 895
 896        Task::ready(Ok(()))
 897    }
 898
 899    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
 900        let Some(thread) = self
 901            .connection
 902            .0
 903            .read(cx)
 904            .sessions
 905            .get(&self.session_id)
 906            .map(|session| session.thread.clone())
 907        else {
 908            return Task::ready(Err(anyhow!("Session not found")));
 909        };
 910        let Some(model) = thread.read(cx).model() else {
 911            return Task::ready(Err(anyhow!("Model not found")));
 912        };
 913        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
 914        else {
 915            return Task::ready(Err(anyhow!("Provider not found")));
 916        };
 917        Task::ready(Ok(LanguageModels::map_language_model_to_info(
 918            model, &provider,
 919        )))
 920    }
 921
 922    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
 923        Some(self.connection.0.read(cx).models.watch())
 924    }
 925}
 926
 927impl acp_thread::AgentConnection for NativeAgentConnection {
 928    fn new_thread(
 929        self: Rc<Self>,
 930        project: Entity<Project>,
 931        cwd: &Path,
 932        cx: &mut App,
 933    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
 934        let agent = self.0.clone();
 935        log::debug!("Creating new thread for project at: {:?}", cwd);
 936
 937        cx.spawn(async move |cx| {
 938            log::debug!("Starting thread creation in async context");
 939
 940            // Create Thread
 941            let thread = agent.update(
 942                cx,
 943                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
 944                    // Fetch default model from registry settings
 945                    let registry = LanguageModelRegistry::read_global(cx);
 946                    // Log available models for debugging
 947                    let available_count = registry.available_models(cx).count();
 948                    log::debug!("Total available models: {}", available_count);
 949
 950                    let default_model = registry.default_model().and_then(|default_model| {
 951                        agent
 952                            .models
 953                            .model_from_id(&LanguageModels::model_id(&default_model.model))
 954                    });
 955                    Ok(cx.new(|cx| {
 956                        Thread::new(
 957                            project.clone(),
 958                            agent.project_context.clone(),
 959                            agent.context_server_registry.clone(),
 960                            agent.templates.clone(),
 961                            default_model,
 962                            cx,
 963                        )
 964                    }))
 965                },
 966            )??;
 967            agent.update(cx, |agent, cx| agent.register_session(thread, cx))
 968        })
 969    }
 970
 971    fn auth_methods(&self) -> &[acp::AuthMethod] {
 972        &[] // No auth for in-process
 973    }
 974
 975    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 976        Task::ready(Ok(()))
 977    }
 978
 979    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
 980        Some(Rc::new(NativeAgentModelSelector {
 981            session_id: session_id.clone(),
 982            connection: self.clone(),
 983        }) as Rc<dyn AgentModelSelector>)
 984    }
 985
 986    fn prompt(
 987        &self,
 988        id: Option<acp_thread::UserMessageId>,
 989        params: acp::PromptRequest,
 990        cx: &mut App,
 991    ) -> Task<Result<acp::PromptResponse>> {
 992        let id = id.expect("UserMessageId is required");
 993        let session_id = params.session_id.clone();
 994        log::info!("Received prompt request for session: {}", session_id);
 995        log::debug!("Prompt blocks count: {}", params.prompt.len());
 996
 997        self.run_turn(session_id, cx, |thread, cx| {
 998            let content: Vec<UserMessageContent> = params
 999                .prompt
1000                .into_iter()
1001                .map(Into::into)
1002                .collect::<Vec<_>>();
1003            log::debug!("Converted prompt to message: {} chars", content.len());
1004            log::debug!("Message id: {:?}", id);
1005            log::debug!("Message content: {:?}", content);
1006
1007            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1008        })
1009    }
1010
1011    fn resume(
1012        &self,
1013        session_id: &acp::SessionId,
1014        _cx: &App,
1015    ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1016        Some(Rc::new(NativeAgentSessionResume {
1017            connection: self.clone(),
1018            session_id: session_id.clone(),
1019        }) as _)
1020    }
1021
1022    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1023        log::info!("Cancelling on session: {}", session_id);
1024        self.0.update(cx, |agent, cx| {
1025            if let Some(agent) = agent.sessions.get(session_id) {
1026                agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1027            }
1028        });
1029    }
1030
1031    fn truncate(
1032        &self,
1033        session_id: &agent_client_protocol::SessionId,
1034        cx: &App,
1035    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1036        self.0.read_with(cx, |agent, _cx| {
1037            agent.sessions.get(session_id).map(|session| {
1038                Rc::new(NativeAgentSessionTruncate {
1039                    thread: session.thread.clone(),
1040                    acp_thread: session.acp_thread.clone(),
1041                }) as _
1042            })
1043        })
1044    }
1045
1046    fn set_title(
1047        &self,
1048        session_id: &acp::SessionId,
1049        _cx: &App,
1050    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1051        Some(Rc::new(NativeAgentSessionSetTitle {
1052            connection: self.clone(),
1053            session_id: session_id.clone(),
1054        }) as _)
1055    }
1056
1057    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1058        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1059    }
1060
1061    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1062        self
1063    }
1064}
1065
1066impl acp_thread::AgentTelemetry for NativeAgentConnection {
1067    fn agent_name(&self) -> String {
1068        "Zed".into()
1069    }
1070
1071    fn thread_data(
1072        &self,
1073        session_id: &acp::SessionId,
1074        cx: &mut App,
1075    ) -> Task<Result<serde_json::Value>> {
1076        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1077            return Task::ready(Err(anyhow!("Session not found")));
1078        };
1079
1080        let task = session.thread.read(cx).to_db(cx);
1081        cx.background_spawn(async move {
1082            serde_json::to_value(task.await).context("Failed to serialize thread")
1083        })
1084    }
1085}
1086
/// [`acp_thread::AgentSessionTruncate`] implementation for a native session.
struct NativeAgentSessionTruncate {
    /// Native thread that gets truncated.
    thread: Entity<Thread>,
    /// ACP thread whose token usage is refreshed after truncation. Held weakly, so it may
    /// already be gone.
    acp_thread: WeakEntity<AcpThread>,
}
1091
1092impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1093    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1094        match self.thread.update(cx, |thread, cx| {
1095            thread.truncate(message_id.clone(), cx)?;
1096            Ok(thread.latest_token_usage())
1097        }) {
1098            Ok(usage) => {
1099                self.acp_thread
1100                    .update(cx, |thread, cx| {
1101                        thread.update_token_usage(usage, cx);
1102                    })
1103                    .ok();
1104                Task::ready(Ok(()))
1105            }
1106            Err(error) => Task::ready(Err(error)),
1107        }
1108    }
1109}
1110
/// [`acp_thread::AgentSessionResume`] implementation that resumes a native session's turn.
struct NativeAgentSessionResume {
    /// Connection used to run the resumed turn.
    connection: NativeAgentConnection,
    /// Session to resume.
    session_id: acp::SessionId,
}
1115
1116impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1117    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1118        self.connection
1119            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1120                thread.update(cx, |thread, cx| thread.resume(cx))
1121            })
1122    }
1123}
1124
/// [`acp_thread::AgentSessionSetTitle`] implementation that renames a native session's thread.
struct NativeAgentSessionSetTitle {
    /// Connection used to look up the session when the rename runs.
    connection: NativeAgentConnection,
    /// Session whose thread is renamed.
    session_id: acp::SessionId,
}
1129
1130impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1131    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1132        let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1133            return Task::ready(Err(anyhow!("session not found")));
1134        };
1135        let thread = session.thread.clone();
1136        thread.update(cx, |thread, cx| thread.set_title(title, cx));
1137        Task::ready(Ok(()))
1138    }
1139}
1140
/// [`ThreadEnvironment`] that creates terminals through an [`AcpThread`].
pub struct AcpThreadEnvironment {
    /// Weak handle, so the environment does not keep the ACP thread alive.
    acp_thread: WeakEntity<AcpThread>,
}
1144
1145impl ThreadEnvironment for AcpThreadEnvironment {
1146    fn create_terminal(
1147        &self,
1148        command: String,
1149        cwd: Option<PathBuf>,
1150        output_byte_limit: Option<u64>,
1151        cx: &mut AsyncApp,
1152    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1153        let task = self.acp_thread.update(cx, |thread, cx| {
1154            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1155        });
1156
1157        let acp_thread = self.acp_thread.clone();
1158        cx.spawn(async move |cx| {
1159            let terminal = task?.await?;
1160
1161            let (drop_tx, drop_rx) = oneshot::channel();
1162            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1163
1164            cx.spawn(async move |cx| {
1165                drop_rx.await.ok();
1166                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1167            })
1168            .detach();
1169
1170            let handle = AcpTerminalHandle {
1171                terminal,
1172                _drop_tx: Some(drop_tx),
1173            };
1174
1175            Ok(Rc::new(handle) as _)
1176        })
1177    }
1178}
1179
/// [`TerminalHandle`] backed by an `acp_thread::Terminal` entity.
///
/// Dropping the handle drops `_drop_tx`. That wakes the task spawned in
/// `AcpThreadEnvironment::create_terminal`, which then releases the terminal from the
/// `AcpThread`.
pub struct AcpTerminalHandle {
    /// The terminal entity all reads are delegated to.
    terminal: Entity<acp_thread::Terminal>,
    /// Never sent on; held only so that dropping it signals the release task.
    _drop_tx: Option<oneshot::Sender<()>>,
}
1184
1185impl TerminalHandle for AcpTerminalHandle {
1186    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1187        self.terminal.read_with(cx, |term, _cx| term.id().clone())
1188    }
1189
1190    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1191        self.terminal
1192            .read_with(cx, |term, _cx| term.wait_for_exit())
1193    }
1194
1195    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1196        self.terminal
1197            .read_with(cx, |term, cx| term.current_output(cx))
1198    }
1199}
1200
// Tests for `NativeAgent` / `NativeAgentConnection`, run against a `FakeFs` and the test
// language-model registry set up in `init_test`.
#[cfg(test)]
mod tests {
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    /// The agent's project context tracks added worktrees and picks up a `.rules` file
    /// created inside one of them.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project was created with no worktrees, so the context starts empty.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        // Adding a worktree with no rules file yields an entry with `rules_file: None`.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// The model selector for a new session lists the single fake model, grouped under its
    /// provider.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId("fake/fake".into()),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Selecting a model updates the session's thread and overwrites the `default_model`
    /// previously stored in the settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed the settings with a different model so the test can see it being replaced.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId("fake/fake".into());
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings-file write triggered by `select_model` finish before reading it.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// Round-trip: an empty thread is not saved, a thread with one exchange is saved under
    /// its generated summary, and reopening it restores the same transcript.
    #[gpui::test]
    #[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        // Send a prompt that mentions `b.md` through a resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
                        name: "b.md".into(),
                        uri: MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                        annotations: None,
                        description: None,
                        mime_type: None,
                        size: None,
                        title: None,
                        meta: None,
                    }),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Drive the fake models by hand: the reply first, then the summary used as the
        // history title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model.send_last_completion_stream_text_chunk("Explaining /a/b.md");
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                "Explaining /a/b.md".into()
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    /// Snapshot of the history store as `(id, title)` pairs, in `entries()` order.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    /// Installs the globals these tests rely on: logging (if not already initialized), a
    /// test settings store, project/agent/language settings, and a test language-model
    /// registry.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}