agent.rs

   1mod db;
   2mod edit_agent;
   3mod history_store;
   4mod legacy_thread;
   5mod native_agent_server;
   6pub mod outline;
   7mod templates;
   8mod thread;
   9mod tools;
  10
  11#[cfg(test)]
  12mod tests;
  13
  14pub use db::*;
  15pub use history_store::*;
  16pub use native_agent_server::NativeAgentServer;
  17pub use templates::*;
  18pub use thread::*;
  19pub use tools::*;
  20
  21use acp_thread::{AcpThread, AgentModelSelector};
  22use agent_client_protocol as acp;
  23use anyhow::{Context as _, Result, anyhow};
  24use chrono::{DateTime, Utc};
  25use collections::{HashSet, IndexMap};
  26use fs::Fs;
  27use futures::channel::{mpsc, oneshot};
  28use futures::future::Shared;
  29use futures::{StreamExt, future};
  30use gpui::{
  31    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  32};
  33use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  34use project::{Project, ProjectItem, ProjectPath, Worktree};
  35use prompt_store::{
  36    ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
  37    WorktreeContext,
  38};
  39use serde::{Deserialize, Serialize};
  40use settings::{LanguageModelSelection, update_settings_file};
  41use std::any::Any;
  42use std::collections::HashMap;
  43use std::path::{Path, PathBuf};
  44use std::rc::Rc;
  45use std::sync::Arc;
  46use util::ResultExt;
  47use util::rel_path::RelPath;
  48
  49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
  50pub struct ProjectSnapshot {
  51    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
  52    pub timestamp: DateTime<Utc>,
  53}
  54
  55pub struct RulesLoadingError {
  56    pub message: SharedString,
  57}
  58
  59/// Holds both the internal Thread and the AcpThread for a session
  60struct Session {
  61    /// The internal thread that processes messages
  62    thread: Entity<Thread>,
  63    /// The ACP thread that handles protocol communication
  64    acp_thread: WeakEntity<acp_thread::AcpThread>,
  65    pending_save: Task<()>,
  66    _subscriptions: Vec<Subscription>,
  67}
  68
  69pub struct LanguageModels {
  70    /// Access language model by ID
  71    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
  72    /// Cached list for returning language model information
  73    model_list: acp_thread::AgentModelList,
  74    refresh_models_rx: watch::Receiver<()>,
  75    refresh_models_tx: watch::Sender<()>,
  76    _authenticate_all_providers_task: Task<()>,
  77}
  78
  79impl LanguageModels {
  80    fn new(cx: &mut App) -> Self {
  81        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  82
  83        let mut this = Self {
  84            models: HashMap::default(),
  85            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  86            refresh_models_rx,
  87            refresh_models_tx,
  88            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  89        };
  90        this.refresh_list(cx);
  91        this
  92    }
  93
  94    fn refresh_list(&mut self, cx: &App) {
  95        let providers = LanguageModelRegistry::global(cx)
  96            .read(cx)
  97            .providers()
  98            .into_iter()
  99            .filter(|provider| provider.is_authenticated(cx))
 100            .collect::<Vec<_>>();
 101
 102        let mut language_model_list = IndexMap::default();
 103        let mut recommended_models = HashSet::default();
 104
 105        let mut recommended = Vec::new();
 106        for provider in &providers {
 107            for model in provider.recommended_models(cx) {
 108                recommended_models.insert((model.provider_id(), model.id()));
 109                recommended.push(Self::map_language_model_to_info(&model, provider));
 110            }
 111        }
 112        if !recommended.is_empty() {
 113            language_model_list.insert(
 114                acp_thread::AgentModelGroupName("Recommended".into()),
 115                recommended,
 116            );
 117        }
 118
 119        let mut models = HashMap::default();
 120        for provider in providers {
 121            let mut provider_models = Vec::new();
 122            for model in provider.provided_models(cx) {
 123                let model_info = Self::map_language_model_to_info(&model, &provider);
 124                let model_id = model_info.id.clone();
 125                provider_models.push(model_info);
 126                models.insert(model_id, model);
 127            }
 128            if !provider_models.is_empty() {
 129                language_model_list.insert(
 130                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 131                    provider_models,
 132                );
 133            }
 134        }
 135
 136        self.models = models;
 137        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 138        self.refresh_models_tx.send(()).ok();
 139    }
 140
 141    fn watch(&self) -> watch::Receiver<()> {
 142        self.refresh_models_rx.clone()
 143    }
 144
 145    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 146        self.models.get(model_id).cloned()
 147    }
 148
 149    fn map_language_model_to_info(
 150        model: &Arc<dyn LanguageModel>,
 151        provider: &Arc<dyn LanguageModelProvider>,
 152    ) -> acp_thread::AgentModelInfo {
 153        acp_thread::AgentModelInfo {
 154            id: Self::model_id(model),
 155            name: model.name().0,
 156            description: None,
 157            icon: Some(provider.icon()),
 158        }
 159    }
 160
 161    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 162        acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
 163    }
 164
 165    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 166        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 167            .read(cx)
 168            .providers()
 169            .iter()
 170            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 171            .collect::<Vec<_>>();
 172
 173        cx.background_spawn(async move {
 174            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 175                if let Err(err) = authenticate_task.await {
 176                    match err {
 177                        language_model::AuthenticateError::CredentialsNotFound => {
 178                            // Since we're authenticating these providers in the
 179                            // background for the purposes of populating the
 180                            // language selector, we don't care about providers
 181                            // where the credentials are not found.
 182                        }
 183                        language_model::AuthenticateError::ConnectionRefused => {
 184                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 185                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 186                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 187                        }
 188                        _ => {
 189                            // Some providers have noisy failure states that we
 190                            // don't want to spam the logs with every time the
 191                            // language model selector is initialized.
 192                            //
 193                            // Ideally these should have more clear failure modes
 194                            // that we know are safe to ignore here, like what we do
 195                            // with `CredentialsNotFound` above.
 196                            match provider_id.0.as_ref() {
 197                                "lmstudio" | "ollama" => {
 198                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 199                                    //
 200                                    // These fail noisily, so we don't log them.
 201                                }
 202                                "copilot_chat" => {
 203                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 204                                }
 205                                _ => {
 206                                    log::error!(
 207                                        "Failed to authenticate provider: {}: {err:#}",
 208                                        provider_name.0
 209                                    );
 210                                }
 211                            }
 212                        }
 213                    }
 214                }
 215            }
 216        })
 217    }
 218}
 219
 220pub struct NativeAgent {
 221    /// Session ID -> Session mapping
 222    sessions: HashMap<acp::SessionId, Session>,
 223    history: Entity<HistoryStore>,
 224    /// Shared project context for all threads
 225    project_context: Entity<ProjectContext>,
 226    project_context_needs_refresh: watch::Sender<()>,
 227    _maintain_project_context: Task<Result<()>>,
 228    context_server_registry: Entity<ContextServerRegistry>,
 229    /// Shared templates for all threads
 230    templates: Arc<Templates>,
 231    /// Cached model information
 232    models: LanguageModels,
 233    project: Entity<Project>,
 234    prompt_store: Option<Entity<PromptStore>>,
 235    fs: Arc<dyn Fs>,
 236    _subscriptions: Vec<Subscription>,
 237}
 238
 239impl NativeAgent {
 240    pub async fn new(
 241        project: Entity<Project>,
 242        history: Entity<HistoryStore>,
 243        templates: Arc<Templates>,
 244        prompt_store: Option<Entity<PromptStore>>,
 245        fs: Arc<dyn Fs>,
 246        cx: &mut AsyncApp,
 247    ) -> Result<Entity<NativeAgent>> {
 248        log::debug!("Creating new NativeAgent");
 249
 250        let project_context = cx
 251            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 252            .await;
 253
 254        cx.new(|cx| {
 255            let mut subscriptions = vec![
 256                cx.subscribe(&project, Self::handle_project_event),
 257                cx.subscribe(
 258                    &LanguageModelRegistry::global(cx),
 259                    Self::handle_models_updated_event,
 260                ),
 261            ];
 262            if let Some(prompt_store) = prompt_store.as_ref() {
 263                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 264            }
 265
 266            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 267                watch::channel(());
 268            Self {
 269                sessions: HashMap::new(),
 270                history,
 271                project_context: cx.new(|_| project_context),
 272                project_context_needs_refresh: project_context_needs_refresh_tx,
 273                _maintain_project_context: cx.spawn(async move |this, cx| {
 274                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 275                }),
 276                context_server_registry: cx.new(|cx| {
 277                    ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
 278                }),
 279                templates,
 280                models: LanguageModels::new(cx),
 281                project,
 282                prompt_store,
 283                fs,
 284                _subscriptions: subscriptions,
 285            }
 286        })
 287    }
 288
 289    fn register_session(
 290        &mut self,
 291        thread_handle: Entity<Thread>,
 292        cx: &mut Context<Self>,
 293    ) -> Entity<AcpThread> {
 294        let connection = Rc::new(NativeAgentConnection(cx.entity()));
 295
 296        let thread = thread_handle.read(cx);
 297        let session_id = thread.id().clone();
 298        let title = thread.title();
 299        let project = thread.project.clone();
 300        let action_log = thread.action_log.clone();
 301        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
 302        let acp_thread = cx.new(|cx| {
 303            acp_thread::AcpThread::new(
 304                title,
 305                connection,
 306                project.clone(),
 307                action_log.clone(),
 308                session_id.clone(),
 309                prompt_capabilities_rx,
 310                cx,
 311            )
 312        });
 313
 314        let registry = LanguageModelRegistry::read_global(cx);
 315        let summarization_model = registry.thread_summary_model().map(|c| c.model);
 316
 317        thread_handle.update(cx, |thread, cx| {
 318            thread.set_summarization_model(summarization_model, cx);
 319            thread.add_default_tools(
 320                Rc::new(AcpThreadEnvironment {
 321                    acp_thread: acp_thread.downgrade(),
 322                }) as _,
 323                cx,
 324            )
 325        });
 326
 327        let subscriptions = vec![
 328            cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
 329                this.sessions.remove(acp_thread.session_id());
 330            }),
 331            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
 332            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
 333            cx.observe(&thread_handle, move |this, thread, cx| {
 334                this.save_thread(thread, cx)
 335            }),
 336        ];
 337
 338        self.sessions.insert(
 339            session_id,
 340            Session {
 341                thread: thread_handle,
 342                acp_thread: acp_thread.downgrade(),
 343                _subscriptions: subscriptions,
 344                pending_save: Task::ready(()),
 345            },
 346        );
 347        acp_thread
 348    }
 349
 350    pub fn models(&self) -> &LanguageModels {
 351        &self.models
 352    }
 353
 354    async fn maintain_project_context(
 355        this: WeakEntity<Self>,
 356        mut needs_refresh: watch::Receiver<()>,
 357        cx: &mut AsyncApp,
 358    ) -> Result<()> {
 359        while needs_refresh.changed().await.is_ok() {
 360            let project_context = this
 361                .update(cx, |this, cx| {
 362                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 363                })?
 364                .await;
 365            this.update(cx, |this, cx| {
 366                this.project_context = cx.new(|_| project_context);
 367            })?;
 368        }
 369
 370        Ok(())
 371    }
 372
 373    fn build_project_context(
 374        project: &Entity<Project>,
 375        prompt_store: Option<&Entity<PromptStore>>,
 376        cx: &mut App,
 377    ) -> Task<ProjectContext> {
 378        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 379        let worktree_tasks = worktrees
 380            .into_iter()
 381            .map(|worktree| {
 382                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 383            })
 384            .collect::<Vec<_>>();
 385        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 386            prompt_store.read_with(cx, |prompt_store, cx| {
 387                let prompts = prompt_store.default_prompt_metadata();
 388                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 389                    let contents = prompt_store.load(prompt_metadata.id, cx);
 390                    async move { (contents.await, prompt_metadata) }
 391                });
 392                cx.background_spawn(future::join_all(load_tasks))
 393            })
 394        } else {
 395            Task::ready(vec![])
 396        };
 397
 398        cx.spawn(async move |_cx| {
 399            let (worktrees, default_user_rules) =
 400                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 401
 402            let worktrees = worktrees
 403                .into_iter()
 404                .map(|(worktree, _rules_error)| {
 405                    // TODO: show error message
 406                    // if let Some(rules_error) = rules_error {
 407                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 408                    // }
 409                    worktree
 410                })
 411                .collect::<Vec<_>>();
 412
 413            let default_user_rules = default_user_rules
 414                .into_iter()
 415                .flat_map(|(contents, prompt_metadata)| match contents {
 416                    Ok(contents) => Some(UserRulesContext {
 417                        uuid: match prompt_metadata.id {
 418                            prompt_store::PromptId::User { uuid } => uuid,
 419                            prompt_store::PromptId::EditWorkflow => return None,
 420                        },
 421                        title: prompt_metadata.title.map(|title| title.to_string()),
 422                        contents,
 423                    }),
 424                    Err(_err) => {
 425                        // TODO: show error message
 426                        // this.update(cx, |_, cx| {
 427                        //     cx.emit(RulesLoadingError {
 428                        //         message: format!("{err:?}").into(),
 429                        //     });
 430                        // })
 431                        // .ok();
 432                        None
 433                    }
 434                })
 435                .collect::<Vec<_>>();
 436
 437            ProjectContext::new(worktrees, default_user_rules)
 438        })
 439    }
 440
 441    fn load_worktree_info_for_system_prompt(
 442        worktree: Entity<Worktree>,
 443        project: Entity<Project>,
 444        cx: &mut App,
 445    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 446        let tree = worktree.read(cx);
 447        let root_name = tree.root_name_str().into();
 448        let abs_path = tree.abs_path();
 449
 450        let mut context = WorktreeContext {
 451            root_name,
 452            abs_path,
 453            rules_file: None,
 454        };
 455
 456        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 457        let Some(rules_task) = rules_task else {
 458            return Task::ready((context, None));
 459        };
 460
 461        cx.spawn(async move |_| {
 462            let (rules_file, rules_file_error) = match rules_task.await {
 463                Ok(rules_file) => (Some(rules_file), None),
 464                Err(err) => (
 465                    None,
 466                    Some(RulesLoadingError {
 467                        message: format!("{err}").into(),
 468                    }),
 469                ),
 470            };
 471            context.rules_file = rules_file;
 472            (context, rules_file_error)
 473        })
 474    }
 475
 476    fn load_worktree_rules_file(
 477        worktree: Entity<Worktree>,
 478        project: Entity<Project>,
 479        cx: &mut App,
 480    ) -> Option<Task<Result<RulesFileContext>>> {
 481        let worktree = worktree.read(cx);
 482        let worktree_id = worktree.id();
 483        let selected_rules_file = RULES_FILE_NAMES
 484            .into_iter()
 485            .filter_map(|name| {
 486                worktree
 487                    .entry_for_path(RelPath::unix(name).unwrap())
 488                    .filter(|entry| entry.is_file())
 489                    .map(|entry| entry.path.clone())
 490            })
 491            .next();
 492
 493        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 494        // supported. This doesn't seem to occur often in GitHub repositories.
 495        selected_rules_file.map(|path_in_worktree| {
 496            let project_path = ProjectPath {
 497                worktree_id,
 498                path: path_in_worktree.clone(),
 499            };
 500            let buffer_task =
 501                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 502            let rope_task = cx.spawn(async move |cx| {
 503                buffer_task.await?.read_with(cx, |buffer, cx| {
 504                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 505                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 506                })?
 507            });
 508            // Build a string from the rope on a background thread.
 509            cx.background_spawn(async move {
 510                let (project_entry_id, rope) = rope_task.await?;
 511                anyhow::Ok(RulesFileContext {
 512                    path_in_worktree,
 513                    text: rope.to_string().trim().to_string(),
 514                    project_entry_id: project_entry_id.to_usize(),
 515                })
 516            })
 517        })
 518    }
 519
 520    fn handle_thread_title_updated(
 521        &mut self,
 522        thread: Entity<Thread>,
 523        _: &TitleUpdated,
 524        cx: &mut Context<Self>,
 525    ) {
 526        let session_id = thread.read(cx).id();
 527        let Some(session) = self.sessions.get(session_id) else {
 528            return;
 529        };
 530        let thread = thread.downgrade();
 531        let acp_thread = session.acp_thread.clone();
 532        cx.spawn(async move |_, cx| {
 533            let title = thread.read_with(cx, |thread, _| thread.title())?;
 534            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 535            task.await
 536        })
 537        .detach_and_log_err(cx);
 538    }
 539
 540    fn handle_thread_token_usage_updated(
 541        &mut self,
 542        thread: Entity<Thread>,
 543        usage: &TokenUsageUpdated,
 544        cx: &mut Context<Self>,
 545    ) {
 546        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 547            return;
 548        };
 549        session
 550            .acp_thread
 551            .update(cx, |acp_thread, cx| {
 552                acp_thread.update_token_usage(usage.0.clone(), cx);
 553            })
 554            .ok();
 555    }
 556
 557    fn handle_project_event(
 558        &mut self,
 559        _project: Entity<Project>,
 560        event: &project::Event,
 561        _cx: &mut Context<Self>,
 562    ) {
 563        match event {
 564            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 565                self.project_context_needs_refresh.send(()).ok();
 566            }
 567            project::Event::WorktreeUpdatedEntries(_, items) => {
 568                if items.iter().any(|(path, _, _)| {
 569                    RULES_FILE_NAMES
 570                        .iter()
 571                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 572                }) {
 573                    self.project_context_needs_refresh.send(()).ok();
 574                }
 575            }
 576            _ => {}
 577        }
 578    }
 579
 580    fn handle_prompts_updated_event(
 581        &mut self,
 582        _prompt_store: Entity<PromptStore>,
 583        _event: &prompt_store::PromptsUpdatedEvent,
 584        _cx: &mut Context<Self>,
 585    ) {
 586        self.project_context_needs_refresh.send(()).ok();
 587    }
 588
 589    fn handle_models_updated_event(
 590        &mut self,
 591        _registry: Entity<LanguageModelRegistry>,
 592        _event: &language_model::Event,
 593        cx: &mut Context<Self>,
 594    ) {
 595        self.models.refresh_list(cx);
 596
 597        let registry = LanguageModelRegistry::read_global(cx);
 598        let default_model = registry.default_model().map(|m| m.model);
 599        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 600
 601        for session in self.sessions.values_mut() {
 602            session.thread.update(cx, |thread, cx| {
 603                if thread.model().is_none()
 604                    && let Some(model) = default_model.clone()
 605                {
 606                    thread.set_model(model, cx);
 607                    cx.notify();
 608                }
 609                thread.set_summarization_model(summarization_model.clone(), cx);
 610            });
 611        }
 612    }
 613
 614    pub fn load_thread(
 615        &mut self,
 616        id: acp::SessionId,
 617        cx: &mut Context<Self>,
 618    ) -> Task<Result<Entity<Thread>>> {
 619        let database_future = ThreadsDatabase::connect(cx);
 620        cx.spawn(async move |this, cx| {
 621            let database = database_future.await.map_err(|err| anyhow!(err))?;
 622            let db_thread = database
 623                .load_thread(id.clone())
 624                .await?
 625                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 626
 627            this.update(cx, |this, cx| {
 628                let summarization_model = LanguageModelRegistry::read_global(cx)
 629                    .thread_summary_model()
 630                    .map(|c| c.model);
 631
 632                cx.new(|cx| {
 633                    let mut thread = Thread::from_db(
 634                        id.clone(),
 635                        db_thread,
 636                        this.project.clone(),
 637                        this.project_context.clone(),
 638                        this.context_server_registry.clone(),
 639                        this.templates.clone(),
 640                        cx,
 641                    );
 642                    thread.set_summarization_model(summarization_model, cx);
 643                    thread
 644                })
 645            })
 646        })
 647    }
 648
 649    pub fn open_thread(
 650        &mut self,
 651        id: acp::SessionId,
 652        cx: &mut Context<Self>,
 653    ) -> Task<Result<Entity<AcpThread>>> {
 654        let task = self.load_thread(id, cx);
 655        cx.spawn(async move |this, cx| {
 656            let thread = task.await?;
 657            let acp_thread =
 658                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
 659            let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
 660            cx.update(|cx| {
 661                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 662            })?
 663            .await?;
 664            Ok(acp_thread)
 665        })
 666    }
 667
 668    pub fn thread_summary(
 669        &mut self,
 670        id: acp::SessionId,
 671        cx: &mut Context<Self>,
 672    ) -> Task<Result<SharedString>> {
 673        let thread = self.open_thread(id.clone(), cx);
 674        cx.spawn(async move |this, cx| {
 675            let acp_thread = thread.await?;
 676            let result = this
 677                .update(cx, |this, cx| {
 678                    this.sessions
 679                        .get(&id)
 680                        .unwrap()
 681                        .thread
 682                        .update(cx, |thread, cx| thread.summary(cx))
 683                })?
 684                .await
 685                .context("Failed to generate summary")?;
 686            drop(acp_thread);
 687            Ok(result)
 688        })
 689    }
 690
 691    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 692        if thread.read(cx).is_empty() {
 693            return;
 694        }
 695
 696        let database_future = ThreadsDatabase::connect(cx);
 697        let (id, db_thread) =
 698            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
 699        let Some(session) = self.sessions.get_mut(&id) else {
 700            return;
 701        };
 702        let history = self.history.clone();
 703        session.pending_save = cx.spawn(async move |_, cx| {
 704            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 705                return;
 706            };
 707            let db_thread = db_thread.await;
 708            database.save_thread(id, db_thread).await.log_err();
 709            history.update(cx, |history, cx| history.reload(cx)).ok();
 710        });
 711    }
 712}
 713
 714/// Wrapper struct that implements the AgentConnection trait
 715#[derive(Clone)]
 716pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 717
 718impl NativeAgentConnection {
 719    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 720        self.0
 721            .read(cx)
 722            .sessions
 723            .get(session_id)
 724            .map(|session| session.thread.clone())
 725    }
 726
 727    pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
 728        self.0.update(cx, |this, cx| this.load_thread(id, cx))
 729    }
 730
 731    fn run_turn(
 732        &self,
 733        session_id: acp::SessionId,
 734        cx: &mut App,
 735        f: impl 'static
 736        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 737    ) -> Task<Result<acp::PromptResponse>> {
 738        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 739            agent
 740                .sessions
 741                .get_mut(&session_id)
 742                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 743        }) else {
 744            return Task::ready(Err(anyhow!("Session not found")));
 745        };
 746        log::debug!("Found session for: {}", session_id);
 747
 748        let response_stream = match f(thread, cx) {
 749            Ok(stream) => stream,
 750            Err(err) => return Task::ready(Err(err)),
 751        };
 752        Self::handle_thread_events(response_stream, acp_thread, cx)
 753    }
 754
 755    fn handle_thread_events(
 756        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
 757        acp_thread: WeakEntity<AcpThread>,
 758        cx: &App,
 759    ) -> Task<Result<acp::PromptResponse>> {
 760        cx.spawn(async move |cx| {
 761            // Handle response stream and forward to session.acp_thread
 762            while let Some(result) = events.next().await {
 763                match result {
 764                    Ok(event) => {
 765                        log::trace!("Received completion event: {:?}", event);
 766
 767                        match event {
 768                            ThreadEvent::UserMessage(message) => {
 769                                acp_thread.update(cx, |thread, cx| {
 770                                    for content in message.content {
 771                                        thread.push_user_content_block(
 772                                            Some(message.id.clone()),
 773                                            content.into(),
 774                                            cx,
 775                                        );
 776                                    }
 777                                })?;
 778                            }
 779                            ThreadEvent::AgentText(text) => {
 780                                acp_thread.update(cx, |thread, cx| {
 781                                    thread.push_assistant_content_block(text.into(), false, cx)
 782                                })?;
 783                            }
 784                            ThreadEvent::AgentThinking(text) => {
 785                                acp_thread.update(cx, |thread, cx| {
 786                                    thread.push_assistant_content_block(text.into(), true, cx)
 787                                })?;
 788                            }
 789                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
 790                                tool_call,
 791                                options,
 792                                response,
 793                            }) => {
 794                                let outcome_task = acp_thread.update(cx, |thread, cx| {
 795                                    thread.request_tool_call_authorization(
 796                                        tool_call, options, true, cx,
 797                                    )
 798                                })??;
 799                                cx.background_spawn(async move {
 800                                    if let acp::RequestPermissionOutcome::Selected(
 801                                        acp::SelectedPermissionOutcome { option_id, .. },
 802                                    ) = outcome_task.await
 803                                    {
 804                                        response
 805                                            .send(option_id)
 806                                            .map(|_| anyhow!("authorization receiver was dropped"))
 807                                            .log_err();
 808                                    }
 809                                })
 810                                .detach();
 811                            }
 812                            ThreadEvent::ToolCall(tool_call) => {
 813                                acp_thread.update(cx, |thread, cx| {
 814                                    thread.upsert_tool_call(tool_call, cx)
 815                                })??;
 816                            }
 817                            ThreadEvent::ToolCallUpdate(update) => {
 818                                acp_thread.update(cx, |thread, cx| {
 819                                    thread.update_tool_call(update, cx)
 820                                })??;
 821                            }
 822                            ThreadEvent::Retry(status) => {
 823                                acp_thread.update(cx, |thread, cx| {
 824                                    thread.update_retry_status(status, cx)
 825                                })?;
 826                            }
 827                            ThreadEvent::Stop(stop_reason) => {
 828                                log::debug!("Assistant message complete: {:?}", stop_reason);
 829                                return Ok(acp::PromptResponse::new(stop_reason));
 830                            }
 831                        }
 832                    }
 833                    Err(e) => {
 834                        log::error!("Error in model response stream: {:?}", e);
 835                        return Err(e);
 836                    }
 837                }
 838            }
 839
 840            log::debug!("Response stream completed");
 841            anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
 842        })
 843    }
 844}
 845
 846struct NativeAgentModelSelector {
 847    session_id: acp::SessionId,
 848    connection: NativeAgentConnection,
 849}
 850
 851impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
 852    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
 853        log::debug!("NativeAgentConnection::list_models called");
 854        let list = self.connection.0.read(cx).models.model_list.clone();
 855        Task::ready(if list.is_empty() {
 856            Err(anyhow::anyhow!("No models available"))
 857        } else {
 858            Ok(list)
 859        })
 860    }
 861
 862    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
 863        log::debug!(
 864            "Setting model for session {}: {}",
 865            self.session_id,
 866            model_id
 867        );
 868        let Some(thread) = self
 869            .connection
 870            .0
 871            .read(cx)
 872            .sessions
 873            .get(&self.session_id)
 874            .map(|session| session.thread.clone())
 875        else {
 876            return Task::ready(Err(anyhow!("Session not found")));
 877        };
 878
 879        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
 880            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
 881        };
 882
 883        thread.update(cx, |thread, cx| {
 884            thread.set_model(model.clone(), cx);
 885        });
 886
 887        update_settings_file(
 888            self.connection.0.read(cx).fs.clone(),
 889            cx,
 890            move |settings, _cx| {
 891                let provider = model.provider_id().0.to_string();
 892                let model = model.id().0.to_string();
 893                settings
 894                    .agent
 895                    .get_or_insert_default()
 896                    .set_model(LanguageModelSelection {
 897                        provider: provider.into(),
 898                        model,
 899                    });
 900            },
 901        );
 902
 903        Task::ready(Ok(()))
 904    }
 905
 906    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
 907        let Some(thread) = self
 908            .connection
 909            .0
 910            .read(cx)
 911            .sessions
 912            .get(&self.session_id)
 913            .map(|session| session.thread.clone())
 914        else {
 915            return Task::ready(Err(anyhow!("Session not found")));
 916        };
 917        let Some(model) = thread.read(cx).model() else {
 918            return Task::ready(Err(anyhow!("Model not found")));
 919        };
 920        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
 921        else {
 922            return Task::ready(Err(anyhow!("Provider not found")));
 923        };
 924        Task::ready(Ok(LanguageModels::map_language_model_to_info(
 925            model, &provider,
 926        )))
 927    }
 928
 929    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
 930        Some(self.connection.0.read(cx).models.watch())
 931    }
 932
 933    fn should_render_footer(&self) -> bool {
 934        true
 935    }
 936}
 937
 938impl acp_thread::AgentConnection for NativeAgentConnection {
 939    fn telemetry_id(&self) -> SharedString {
 940        "zed".into()
 941    }
 942
 943    fn new_thread(
 944        self: Rc<Self>,
 945        project: Entity<Project>,
 946        cwd: &Path,
 947        cx: &mut App,
 948    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
 949        let agent = self.0.clone();
 950        log::debug!("Creating new thread for project at: {:?}", cwd);
 951
 952        cx.spawn(async move |cx| {
 953            log::debug!("Starting thread creation in async context");
 954
 955            // Create Thread
 956            let thread = agent.update(
 957                cx,
 958                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
 959                    // Fetch default model from registry settings
 960                    let registry = LanguageModelRegistry::read_global(cx);
 961                    // Log available models for debugging
 962                    let available_count = registry.available_models(cx).count();
 963                    log::debug!("Total available models: {}", available_count);
 964
 965                    let default_model = registry.default_model().and_then(|default_model| {
 966                        agent
 967                            .models
 968                            .model_from_id(&LanguageModels::model_id(&default_model.model))
 969                    });
 970                    Ok(cx.new(|cx| {
 971                        Thread::new(
 972                            project.clone(),
 973                            agent.project_context.clone(),
 974                            agent.context_server_registry.clone(),
 975                            agent.templates.clone(),
 976                            default_model,
 977                            cx,
 978                        )
 979                    }))
 980                },
 981            )??;
 982            agent.update(cx, |agent, cx| agent.register_session(thread, cx))
 983        })
 984    }
 985
 986    fn auth_methods(&self) -> &[acp::AuthMethod] {
 987        &[] // No auth for in-process
 988    }
 989
 990    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 991        Task::ready(Ok(()))
 992    }
 993
 994    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
 995        Some(Rc::new(NativeAgentModelSelector {
 996            session_id: session_id.clone(),
 997            connection: self.clone(),
 998        }) as Rc<dyn AgentModelSelector>)
 999    }
1000
1001    fn prompt(
1002        &self,
1003        id: Option<acp_thread::UserMessageId>,
1004        params: acp::PromptRequest,
1005        cx: &mut App,
1006    ) -> Task<Result<acp::PromptResponse>> {
1007        let id = id.expect("UserMessageId is required");
1008        let session_id = params.session_id.clone();
1009        log::info!("Received prompt request for session: {}", session_id);
1010        log::debug!("Prompt blocks count: {}", params.prompt.len());
1011        let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1012
1013        self.run_turn(session_id, cx, move |thread, cx| {
1014            let content: Vec<UserMessageContent> = params
1015                .prompt
1016                .into_iter()
1017                .map(|block| UserMessageContent::from_content_block(block, path_style))
1018                .collect::<Vec<_>>();
1019            log::debug!("Converted prompt to message: {} chars", content.len());
1020            log::debug!("Message id: {:?}", id);
1021            log::debug!("Message content: {:?}", content);
1022
1023            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1024        })
1025    }
1026
1027    fn resume(
1028        &self,
1029        session_id: &acp::SessionId,
1030        _cx: &App,
1031    ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1032        Some(Rc::new(NativeAgentSessionResume {
1033            connection: self.clone(),
1034            session_id: session_id.clone(),
1035        }) as _)
1036    }
1037
1038    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1039        log::info!("Cancelling on session: {}", session_id);
1040        self.0.update(cx, |agent, cx| {
1041            if let Some(agent) = agent.sessions.get(session_id) {
1042                agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1043            }
1044        });
1045    }
1046
1047    fn truncate(
1048        &self,
1049        session_id: &agent_client_protocol::SessionId,
1050        cx: &App,
1051    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1052        self.0.read_with(cx, |agent, _cx| {
1053            agent.sessions.get(session_id).map(|session| {
1054                Rc::new(NativeAgentSessionTruncate {
1055                    thread: session.thread.clone(),
1056                    acp_thread: session.acp_thread.clone(),
1057                }) as _
1058            })
1059        })
1060    }
1061
1062    fn set_title(
1063        &self,
1064        session_id: &acp::SessionId,
1065        _cx: &App,
1066    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1067        Some(Rc::new(NativeAgentSessionSetTitle {
1068            connection: self.clone(),
1069            session_id: session_id.clone(),
1070        }) as _)
1071    }
1072
1073    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1074        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1075    }
1076
1077    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1078        self
1079    }
1080}
1081
1082impl acp_thread::AgentTelemetry for NativeAgentConnection {
1083    fn thread_data(
1084        &self,
1085        session_id: &acp::SessionId,
1086        cx: &mut App,
1087    ) -> Task<Result<serde_json::Value>> {
1088        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1089            return Task::ready(Err(anyhow!("Session not found")));
1090        };
1091
1092        let task = session.thread.read(cx).to_db(cx);
1093        cx.background_spawn(async move {
1094            serde_json::to_value(task.await).context("Failed to serialize thread")
1095        })
1096    }
1097}
1098
1099struct NativeAgentSessionTruncate {
1100    thread: Entity<Thread>,
1101    acp_thread: WeakEntity<AcpThread>,
1102}
1103
1104impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1105    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1106        match self.thread.update(cx, |thread, cx| {
1107            thread.truncate(message_id.clone(), cx)?;
1108            Ok(thread.latest_token_usage())
1109        }) {
1110            Ok(usage) => {
1111                self.acp_thread
1112                    .update(cx, |thread, cx| {
1113                        thread.update_token_usage(usage, cx);
1114                    })
1115                    .ok();
1116                Task::ready(Ok(()))
1117            }
1118            Err(error) => Task::ready(Err(error)),
1119        }
1120    }
1121}
1122
1123struct NativeAgentSessionResume {
1124    connection: NativeAgentConnection,
1125    session_id: acp::SessionId,
1126}
1127
1128impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1129    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1130        self.connection
1131            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1132                thread.update(cx, |thread, cx| thread.resume(cx))
1133            })
1134    }
1135}
1136
1137struct NativeAgentSessionSetTitle {
1138    connection: NativeAgentConnection,
1139    session_id: acp::SessionId,
1140}
1141
1142impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1143    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1144        let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1145            return Task::ready(Err(anyhow!("session not found")));
1146        };
1147        let thread = session.thread.clone();
1148        thread.update(cx, |thread, cx| thread.set_title(title, cx));
1149        Task::ready(Ok(()))
1150    }
1151}
1152
1153pub struct AcpThreadEnvironment {
1154    acp_thread: WeakEntity<AcpThread>,
1155}
1156
1157impl ThreadEnvironment for AcpThreadEnvironment {
1158    fn create_terminal(
1159        &self,
1160        command: String,
1161        cwd: Option<PathBuf>,
1162        output_byte_limit: Option<u64>,
1163        cx: &mut AsyncApp,
1164    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1165        let task = self.acp_thread.update(cx, |thread, cx| {
1166            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1167        });
1168
1169        let acp_thread = self.acp_thread.clone();
1170        cx.spawn(async move |cx| {
1171            let terminal = task?.await?;
1172
1173            let (drop_tx, drop_rx) = oneshot::channel();
1174            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1175
1176            cx.spawn(async move |cx| {
1177                drop_rx.await.ok();
1178                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1179            })
1180            .detach();
1181
1182            let handle = AcpTerminalHandle {
1183                terminal,
1184                _drop_tx: Some(drop_tx),
1185            };
1186
1187            Ok(Rc::new(handle) as _)
1188        })
1189    }
1190}
1191
1192pub struct AcpTerminalHandle {
1193    terminal: Entity<acp_thread::Terminal>,
1194    _drop_tx: Option<oneshot::Sender<()>>,
1195}
1196
1197impl TerminalHandle for AcpTerminalHandle {
1198    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1199        self.terminal.read_with(cx, |term, _cx| term.id().clone())
1200    }
1201
1202    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1203        self.terminal
1204            .read_with(cx, |term, _cx| term.wait_for_exit())
1205    }
1206
1207    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1208        self.terminal
1209            .read_with(cx, |term, cx| term.current_output(cx))
1210    }
1211
1212    fn kill(&self, cx: &AsyncApp) -> Result<()> {
1213        cx.update(|cx| {
1214            self.terminal.update(cx, |terminal, cx| {
1215                terminal.kill(cx);
1216            });
1217        })?;
1218        Ok(())
1219    }
1220}
1221
// Tests that drive `NativeAgent` / `NativeAgentConnection` and inspect their fields
// (`sessions`, `project_context`) directly.
#[cfg(test)]
mod internal_tests {
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    // Checks three things about the agent's `project_context`:
    // - it starts with no worktrees;
    // - it gains an entry (without rules) once worktree `/a` is created;
    // - it picks up `/a/.rules` as the worktree's rules file once that file appears.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    // Checks that a new session's model selector lists exactly one model, `fake/fake`,
    // grouped under "Fake" (presumably registered by `LanguageModelRegistry::test`).
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    // Checks that `select_model` does two things:
    // - switches the session's thread to the chosen model;
    // - rewrites `agent.default_model` in the settings file, replacing the pre-seeded
    //   `foo`/`bar` selection with `fake`/`fake`.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // The settings write is asynchronous; let it finish before reading the file.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    // End-to-end history check. It verifies that:
    // - a thread with no messages stays out of history even after its models change;
    // - a completed exchange is saved with the summary model's output as its title;
    // - dropping the ACP thread removes the session;
    // - the thread can be reopened from history with the same markdown.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Drive the fake models: first the assistant reply, then the summary used as the
        // history title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    /// Returns `(id, title)` for every entry currently in `history`, in iteration order.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    /// Common test setup: best-effort logger init (ignores "already initialized"), a test
    /// `SettingsStore` installed as a global, and the test `LanguageModelRegistry`.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            LanguageModelRegistry::test(cx);
        });
    }
}