agent.rs

   1use crate::{
   2    ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
   3    UserMessageContent, templates::Templates,
   4};
   5use crate::{HistoryStore, TerminalHandle, ThreadEnvironment, TitleUpdated, TokenUsageUpdated};
   6use acp_thread::{AcpThread, AgentModelSelector};
   7use action_log::ActionLog;
   8use agent_client_protocol as acp;
   9use agent_settings::AgentSettings;
  10use anyhow::{Context as _, Result, anyhow};
  11use collections::{HashSet, IndexMap};
  12use fs::Fs;
  13use futures::channel::{mpsc, oneshot};
  14use futures::future::Shared;
  15use futures::{StreamExt, future};
  16use gpui::{
  17    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  18};
  19use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  20use project::{Project, ProjectItem, ProjectPath, Worktree};
  21use prompt_store::{
  22    ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
  23};
  24use settings::update_settings_file;
  25use std::any::Any;
  26use std::collections::HashMap;
  27use std::path::{Path, PathBuf};
  28use std::rc::Rc;
  29use std::sync::Arc;
  30use util::ResultExt;
  31
/// Paths that are recognized as rules files when found inside a worktree.
///
/// `load_worktree_rules_file` probes these in order and uses the first one that exists as a
/// file, so earlier entries take precedence over later ones. `handle_project_event` also
/// compares changed worktree paths against this list to decide whether the project context
/// needs to be rebuilt.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
  43
/// Error produced when a worktree's rules file could not be loaded.
pub struct RulesLoadingError {
    /// Display-formatted text of the underlying error.
    pub message: SharedString,
}
  47
/// Holds both the internal `Thread` and the `AcpThread` for a session, together with the
/// bookkeeping the agent keeps per session.
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    ///
    /// Only a weak handle is stored: the session is removed from the agent when the
    /// `AcpThread` is released (see `register_session`).
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Task for the most recent save of this thread; `save_thread` replaces it each time it
    /// starts a new save.
    pending_save: Task<()>,
    /// Subscriptions created when the session is registered, stored here so they live as
    /// long as the session does.
    _subscriptions: Vec<Subscription>,
}
  57
/// Cache of the language models offered by authenticated providers, rebuilt by
/// `refresh_list`.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiving half of the refresh channel; clones are handed out by `watch`.
    refresh_models_rx: watch::Receiver<()>,
    /// Sending half of the refresh channel; signalled at the end of every `refresh_list`.
    refresh_models_tx: watch::Sender<()>,
    /// Background task, started in `new`, that awaits authentication of every registered
    /// provider.
    _authenticate_all_providers_task: Task<()>,
}
  67
  68impl LanguageModels {
  69    fn new(cx: &mut App) -> Self {
  70        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  71
  72        let mut this = Self {
  73            models: HashMap::default(),
  74            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  75            refresh_models_rx,
  76            refresh_models_tx,
  77            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  78        };
  79        this.refresh_list(cx);
  80        this
  81    }
  82
  83    fn refresh_list(&mut self, cx: &App) {
  84        let providers = LanguageModelRegistry::global(cx)
  85            .read(cx)
  86            .providers()
  87            .into_iter()
  88            .filter(|provider| provider.is_authenticated(cx))
  89            .collect::<Vec<_>>();
  90
  91        let mut language_model_list = IndexMap::default();
  92        let mut recommended_models = HashSet::default();
  93
  94        let mut recommended = Vec::new();
  95        for provider in &providers {
  96            for model in provider.recommended_models(cx) {
  97                recommended_models.insert((model.provider_id(), model.id()));
  98                recommended.push(Self::map_language_model_to_info(&model, provider));
  99            }
 100        }
 101        if !recommended.is_empty() {
 102            language_model_list.insert(
 103                acp_thread::AgentModelGroupName("Recommended".into()),
 104                recommended,
 105            );
 106        }
 107
 108        let mut models = HashMap::default();
 109        for provider in providers {
 110            let mut provider_models = Vec::new();
 111            for model in provider.provided_models(cx) {
 112                let model_info = Self::map_language_model_to_info(&model, &provider);
 113                let model_id = model_info.id.clone();
 114                if !recommended_models.contains(&(model.provider_id(), model.id())) {
 115                    provider_models.push(model_info);
 116                }
 117                models.insert(model_id, model);
 118            }
 119            if !provider_models.is_empty() {
 120                language_model_list.insert(
 121                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 122                    provider_models,
 123                );
 124            }
 125        }
 126
 127        self.models = models;
 128        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 129        self.refresh_models_tx.send(()).ok();
 130    }
 131
    /// Returns a clone of the receiving half of the refresh channel, which `refresh_list`
    /// signals each time it finishes rebuilding the model cache.
    fn watch(&self) -> watch::Receiver<()> {
        self.refresh_models_rx.clone()
    }
 135
 136    pub fn model_from_id(
 137        &self,
 138        model_id: &acp_thread::AgentModelId,
 139    ) -> Option<Arc<dyn LanguageModel>> {
 140        self.models.get(model_id).cloned()
 141    }
 142
 143    fn map_language_model_to_info(
 144        model: &Arc<dyn LanguageModel>,
 145        provider: &Arc<dyn LanguageModelProvider>,
 146    ) -> acp_thread::AgentModelInfo {
 147        acp_thread::AgentModelInfo {
 148            id: Self::model_id(model),
 149            name: model.name().0,
 150            icon: Some(provider.icon()),
 151        }
 152    }
 153
 154    fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
 155        acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
 156    }
 157
 158    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 159        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 160            .read(cx)
 161            .providers()
 162            .iter()
 163            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 164            .collect::<Vec<_>>();
 165
 166        cx.background_spawn(async move {
 167            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 168                if let Err(err) = authenticate_task.await {
 169                    if matches!(err, language_model::AuthenticateError::CredentialsNotFound) {
 170                        // Since we're authenticating these providers in the
 171                        // background for the purposes of populating the
 172                        // language selector, we don't care about providers
 173                        // where the credentials are not found.
 174                    } else {
 175                        // Some providers have noisy failure states that we
 176                        // don't want to spam the logs with every time the
 177                        // language model selector is initialized.
 178                        //
 179                        // Ideally these should have more clear failure modes
 180                        // that we know are safe to ignore here, like what we do
 181                        // with `CredentialsNotFound` above.
 182                        match provider_id.0.as_ref() {
 183                            "lmstudio" | "ollama" => {
 184                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 185                                //
 186                                // These fail noisily, so we don't log them.
 187                            }
 188                            "copilot_chat" => {
 189                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 190                            }
 191                            _ => {
 192                                log::error!(
 193                                    "Failed to authenticate provider: {}: {err}",
 194                                    provider_name.0
 195                                );
 196                            }
 197                        }
 198                    }
 199                }
 200            }
 201        })
 202    }
 203}
 204
/// Agent state shared by all sessions of one project: the live sessions, the project
/// context used by threads, cached model information and the handles needed to create and
/// persist threads.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// History store that is asked to reload after a thread has been saved.
    history: Entity<HistoryStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalled whenever the project context should be rebuilt (worktree changes, rules
    /// file changes, prompt updates).
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context`, which rebuilds the project context each
    /// time a refresh is requested.
    _maintain_project_context: Task<Result<()>>,
    /// Registry built from the project's context server store; passed to threads loaded
    /// from the database.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store from which default user rules are loaded.
    prompt_store: Option<Entity<PromptStore>>,
    /// File system handle supplied at construction.
    fs: Arc<dyn Fs>,
    /// Agent-level subscriptions (project events, model registry events and, when a prompt
    /// store is present, prompt updates).
    _subscriptions: Vec<Subscription>,
}
 223
 224impl NativeAgent {
 225    pub async fn new(
 226        project: Entity<Project>,
 227        history: Entity<HistoryStore>,
 228        templates: Arc<Templates>,
 229        prompt_store: Option<Entity<PromptStore>>,
 230        fs: Arc<dyn Fs>,
 231        cx: &mut AsyncApp,
 232    ) -> Result<Entity<NativeAgent>> {
 233        log::debug!("Creating new NativeAgent");
 234
 235        let project_context = cx
 236            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 237            .await;
 238
 239        cx.new(|cx| {
 240            let mut subscriptions = vec![
 241                cx.subscribe(&project, Self::handle_project_event),
 242                cx.subscribe(
 243                    &LanguageModelRegistry::global(cx),
 244                    Self::handle_models_updated_event,
 245                ),
 246            ];
 247            if let Some(prompt_store) = prompt_store.as_ref() {
 248                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 249            }
 250
 251            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 252                watch::channel(());
 253            Self {
 254                sessions: HashMap::new(),
 255                history,
 256                project_context: cx.new(|_| project_context),
 257                project_context_needs_refresh: project_context_needs_refresh_tx,
 258                _maintain_project_context: cx.spawn(async move |this, cx| {
 259                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 260                }),
 261                context_server_registry: cx.new(|cx| {
 262                    ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
 263                }),
 264                templates,
 265                models: LanguageModels::new(cx),
 266                project,
 267                prompt_store,
 268                fs,
 269                _subscriptions: subscriptions,
 270            }
 271        })
 272    }
 273
 274    fn register_session(
 275        &mut self,
 276        thread_handle: Entity<Thread>,
 277        cx: &mut Context<Self>,
 278    ) -> Entity<AcpThread> {
 279        let connection = Rc::new(NativeAgentConnection(cx.entity()));
 280
 281        let thread = thread_handle.read(cx);
 282        let session_id = thread.id().clone();
 283        let title = thread.title();
 284        let project = thread.project.clone();
 285        let action_log = thread.action_log.clone();
 286        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
 287        let acp_thread = cx.new(|cx| {
 288            acp_thread::AcpThread::new(
 289                title,
 290                connection,
 291                project.clone(),
 292                action_log.clone(),
 293                session_id.clone(),
 294                prompt_capabilities_rx,
 295                cx,
 296            )
 297        });
 298
 299        let registry = LanguageModelRegistry::read_global(cx);
 300        let summarization_model = registry.thread_summary_model().map(|c| c.model);
 301
 302        thread_handle.update(cx, |thread, cx| {
 303            thread.set_summarization_model(summarization_model, cx);
 304            thread.add_default_tools(
 305                Rc::new(AcpThreadEnvironment {
 306                    acp_thread: acp_thread.downgrade(),
 307                }) as _,
 308                cx,
 309            )
 310        });
 311
 312        let subscriptions = vec![
 313            cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
 314                this.sessions.remove(acp_thread.session_id());
 315            }),
 316            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
 317            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
 318            cx.observe(&thread_handle, move |this, thread, cx| {
 319                this.save_thread(thread, cx)
 320            }),
 321        ];
 322
 323        self.sessions.insert(
 324            session_id,
 325            Session {
 326                thread: thread_handle,
 327                acp_thread: acp_thread.downgrade(),
 328                _subscriptions: subscriptions,
 329                pending_save: Task::ready(()),
 330            },
 331        );
 332        acp_thread
 333    }
 334
    /// Returns the agent's cached language model information.
    pub fn models(&self) -> &LanguageModels {
        &self.models
    }
 338
 339    async fn maintain_project_context(
 340        this: WeakEntity<Self>,
 341        mut needs_refresh: watch::Receiver<()>,
 342        cx: &mut AsyncApp,
 343    ) -> Result<()> {
 344        while needs_refresh.changed().await.is_ok() {
 345            let project_context = this
 346                .update(cx, |this, cx| {
 347                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 348                })?
 349                .await;
 350            this.update(cx, |this, cx| {
 351                this.project_context = cx.new(|_| project_context);
 352            })?;
 353        }
 354
 355        Ok(())
 356    }
 357
 358    fn build_project_context(
 359        project: &Entity<Project>,
 360        prompt_store: Option<&Entity<PromptStore>>,
 361        cx: &mut App,
 362    ) -> Task<ProjectContext> {
 363        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 364        let worktree_tasks = worktrees
 365            .into_iter()
 366            .map(|worktree| {
 367                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 368            })
 369            .collect::<Vec<_>>();
 370        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 371            prompt_store.read_with(cx, |prompt_store, cx| {
 372                let prompts = prompt_store.default_prompt_metadata();
 373                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 374                    let contents = prompt_store.load(prompt_metadata.id, cx);
 375                    async move { (contents.await, prompt_metadata) }
 376                });
 377                cx.background_spawn(future::join_all(load_tasks))
 378            })
 379        } else {
 380            Task::ready(vec![])
 381        };
 382
 383        cx.spawn(async move |_cx| {
 384            let (worktrees, default_user_rules) =
 385                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 386
 387            let worktrees = worktrees
 388                .into_iter()
 389                .map(|(worktree, _rules_error)| {
 390                    // TODO: show error message
 391                    // if let Some(rules_error) = rules_error {
 392                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 393                    // }
 394                    worktree
 395                })
 396                .collect::<Vec<_>>();
 397
 398            let default_user_rules = default_user_rules
 399                .into_iter()
 400                .flat_map(|(contents, prompt_metadata)| match contents {
 401                    Ok(contents) => Some(UserRulesContext {
 402                        uuid: match prompt_metadata.id {
 403                            PromptId::User { uuid } => uuid,
 404                            PromptId::EditWorkflow => return None,
 405                        },
 406                        title: prompt_metadata.title.map(|title| title.to_string()),
 407                        contents,
 408                    }),
 409                    Err(_err) => {
 410                        // TODO: show error message
 411                        // this.update(cx, |_, cx| {
 412                        //     cx.emit(RulesLoadingError {
 413                        //         message: format!("{err:?}").into(),
 414                        //     });
 415                        // })
 416                        // .ok();
 417                        None
 418                    }
 419                })
 420                .collect::<Vec<_>>();
 421
 422            ProjectContext::new(worktrees, default_user_rules)
 423        })
 424    }
 425
 426    fn load_worktree_info_for_system_prompt(
 427        worktree: Entity<Worktree>,
 428        project: Entity<Project>,
 429        cx: &mut App,
 430    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 431        let tree = worktree.read(cx);
 432        let root_name = tree.root_name().into();
 433        let abs_path = tree.abs_path();
 434
 435        let mut context = WorktreeContext {
 436            root_name,
 437            abs_path,
 438            rules_file: None,
 439        };
 440
 441        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 442        let Some(rules_task) = rules_task else {
 443            return Task::ready((context, None));
 444        };
 445
 446        cx.spawn(async move |_| {
 447            let (rules_file, rules_file_error) = match rules_task.await {
 448                Ok(rules_file) => (Some(rules_file), None),
 449                Err(err) => (
 450                    None,
 451                    Some(RulesLoadingError {
 452                        message: format!("{err}").into(),
 453                    }),
 454                ),
 455            };
 456            context.rules_file = rules_file;
 457            (context, rules_file_error)
 458        })
 459    }
 460
 461    fn load_worktree_rules_file(
 462        worktree: Entity<Worktree>,
 463        project: Entity<Project>,
 464        cx: &mut App,
 465    ) -> Option<Task<Result<RulesFileContext>>> {
 466        let worktree = worktree.read(cx);
 467        let worktree_id = worktree.id();
 468        let selected_rules_file = RULES_FILE_NAMES
 469            .into_iter()
 470            .filter_map(|name| {
 471                worktree
 472                    .entry_for_path(name)
 473                    .filter(|entry| entry.is_file())
 474                    .map(|entry| entry.path.clone())
 475            })
 476            .next();
 477
 478        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 479        // supported. This doesn't seem to occur often in GitHub repositories.
 480        selected_rules_file.map(|path_in_worktree| {
 481            let project_path = ProjectPath {
 482                worktree_id,
 483                path: path_in_worktree.clone(),
 484            };
 485            let buffer_task =
 486                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 487            let rope_task = cx.spawn(async move |cx| {
 488                buffer_task.await?.read_with(cx, |buffer, cx| {
 489                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 490                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 491                })?
 492            });
 493            // Build a string from the rope on a background thread.
 494            cx.background_spawn(async move {
 495                let (project_entry_id, rope) = rope_task.await?;
 496                anyhow::Ok(RulesFileContext {
 497                    path_in_worktree,
 498                    text: rope.to_string().trim().to_string(),
 499                    project_entry_id: project_entry_id.to_usize(),
 500                })
 501            })
 502        })
 503    }
 504
 505    fn handle_thread_title_updated(
 506        &mut self,
 507        thread: Entity<Thread>,
 508        _: &TitleUpdated,
 509        cx: &mut Context<Self>,
 510    ) {
 511        let session_id = thread.read(cx).id();
 512        let Some(session) = self.sessions.get(session_id) else {
 513            return;
 514        };
 515        let thread = thread.downgrade();
 516        let acp_thread = session.acp_thread.clone();
 517        cx.spawn(async move |_, cx| {
 518            let title = thread.read_with(cx, |thread, _| thread.title())?;
 519            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 520            task.await
 521        })
 522        .detach_and_log_err(cx);
 523    }
 524
 525    fn handle_thread_token_usage_updated(
 526        &mut self,
 527        thread: Entity<Thread>,
 528        usage: &TokenUsageUpdated,
 529        cx: &mut Context<Self>,
 530    ) {
 531        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 532            return;
 533        };
 534        session
 535            .acp_thread
 536            .update(cx, |acp_thread, cx| {
 537                acp_thread.update_token_usage(usage.0.clone(), cx);
 538            })
 539            .ok();
 540    }
 541
 542    fn handle_project_event(
 543        &mut self,
 544        _project: Entity<Project>,
 545        event: &project::Event,
 546        _cx: &mut Context<Self>,
 547    ) {
 548        match event {
 549            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 550                self.project_context_needs_refresh.send(()).ok();
 551            }
 552            project::Event::WorktreeUpdatedEntries(_, items) => {
 553                if items.iter().any(|(path, _, _)| {
 554                    RULES_FILE_NAMES
 555                        .iter()
 556                        .any(|name| path.as_ref() == Path::new(name))
 557                }) {
 558                    self.project_context_needs_refresh.send(()).ok();
 559                }
 560            }
 561            _ => {}
 562        }
 563    }
 564
    /// Requests a project context refresh whenever the prompt store reports that prompts
    /// were updated. A failure to send the request is ignored.
    fn handle_prompts_updated_event(
        &mut self,
        _prompt_store: Entity<PromptStore>,
        _event: &prompt_store::PromptsUpdatedEvent,
        _cx: &mut Context<Self>,
    ) {
        self.project_context_needs_refresh.send(()).ok();
    }
 573
 574    fn handle_models_updated_event(
 575        &mut self,
 576        _registry: Entity<LanguageModelRegistry>,
 577        _event: &language_model::Event,
 578        cx: &mut Context<Self>,
 579    ) {
 580        self.models.refresh_list(cx);
 581
 582        let registry = LanguageModelRegistry::read_global(cx);
 583        let default_model = registry.default_model().map(|m| m.model);
 584        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 585
 586        for session in self.sessions.values_mut() {
 587            session.thread.update(cx, |thread, cx| {
 588                if thread.model().is_none()
 589                    && let Some(model) = default_model.clone()
 590                {
 591                    thread.set_model(model, cx);
 592                    cx.notify();
 593                }
 594                thread.set_summarization_model(summarization_model.clone(), cx);
 595            });
 596        }
 597    }
 598
 599    pub fn open_thread(
 600        &mut self,
 601        id: acp::SessionId,
 602        cx: &mut Context<Self>,
 603    ) -> Task<Result<Entity<AcpThread>>> {
 604        let database_future = ThreadsDatabase::connect(cx);
 605        cx.spawn(async move |this, cx| {
 606            let database = database_future.await.map_err(|err| anyhow!(err))?;
 607            let db_thread = database
 608                .load_thread(id.clone())
 609                .await?
 610                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 611
 612            let thread = this.update(cx, |this, cx| {
 613                let action_log = cx.new(|_cx| ActionLog::new(this.project.clone()));
 614                cx.new(|cx| {
 615                    Thread::from_db(
 616                        id.clone(),
 617                        db_thread,
 618                        this.project.clone(),
 619                        this.project_context.clone(),
 620                        this.context_server_registry.clone(),
 621                        action_log.clone(),
 622                        this.templates.clone(),
 623                        cx,
 624                    )
 625                })
 626            })?;
 627            let acp_thread =
 628                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
 629            let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
 630            cx.update(|cx| {
 631                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 632            })?
 633            .await?;
 634            Ok(acp_thread)
 635        })
 636    }
 637
 638    pub fn thread_summary(
 639        &mut self,
 640        id: acp::SessionId,
 641        cx: &mut Context<Self>,
 642    ) -> Task<Result<SharedString>> {
 643        let thread = self.open_thread(id.clone(), cx);
 644        cx.spawn(async move |this, cx| {
 645            let acp_thread = thread.await?;
 646            let result = this
 647                .update(cx, |this, cx| {
 648                    this.sessions
 649                        .get(&id)
 650                        .unwrap()
 651                        .thread
 652                        .update(cx, |thread, cx| thread.summary(cx))
 653                })?
 654                .await?;
 655            drop(acp_thread);
 656            Ok(result)
 657        })
 658    }
 659
 660    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 661        if thread.read(cx).is_empty() {
 662            return;
 663        }
 664
 665        let database_future = ThreadsDatabase::connect(cx);
 666        let (id, db_thread) =
 667            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
 668        let Some(session) = self.sessions.get_mut(&id) else {
 669            return;
 670        };
 671        let history = self.history.clone();
 672        session.pending_save = cx.spawn(async move |_, cx| {
 673            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 674                return;
 675            };
 676            let db_thread = db_thread.await;
 677            database.save_thread(id, db_thread).await.log_err();
 678            history.update(cx, |history, cx| history.reload(cx)).ok();
 679        });
 680    }
 681}
 682
/// Wrapper struct that implements the AgentConnection trait.
///
/// The wrapped entity is the `NativeAgent` whose sessions this connection operates on.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 686
 687impl NativeAgentConnection {
 688    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 689        self.0
 690            .read(cx)
 691            .sessions
 692            .get(session_id)
 693            .map(|session| session.thread.clone())
 694    }
 695
 696    fn run_turn(
 697        &self,
 698        session_id: acp::SessionId,
 699        cx: &mut App,
 700        f: impl 'static
 701        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 702    ) -> Task<Result<acp::PromptResponse>> {
 703        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 704            agent
 705                .sessions
 706                .get_mut(&session_id)
 707                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 708        }) else {
 709            return Task::ready(Err(anyhow!("Session not found")));
 710        };
 711        log::debug!("Found session for: {}", session_id);
 712
 713        let response_stream = match f(thread, cx) {
 714            Ok(stream) => stream,
 715            Err(err) => return Task::ready(Err(err)),
 716        };
 717        Self::handle_thread_events(response_stream, acp_thread, cx)
 718    }
 719
 720    fn handle_thread_events(
 721        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
 722        acp_thread: WeakEntity<AcpThread>,
 723        cx: &App,
 724    ) -> Task<Result<acp::PromptResponse>> {
 725        cx.spawn(async move |cx| {
 726            // Handle response stream and forward to session.acp_thread
 727            while let Some(result) = events.next().await {
 728                match result {
 729                    Ok(event) => {
 730                        log::trace!("Received completion event: {:?}", event);
 731
 732                        match event {
 733                            ThreadEvent::UserMessage(message) => {
 734                                acp_thread.update(cx, |thread, cx| {
 735                                    for content in message.content {
 736                                        thread.push_user_content_block(
 737                                            Some(message.id.clone()),
 738                                            content.into(),
 739                                            cx,
 740                                        );
 741                                    }
 742                                })?;
 743                            }
 744                            ThreadEvent::AgentText(text) => {
 745                                acp_thread.update(cx, |thread, cx| {
 746                                    thread.push_assistant_content_block(
 747                                        acp::ContentBlock::Text(acp::TextContent {
 748                                            text,
 749                                            annotations: None,
 750                                            meta: None,
 751                                        }),
 752                                        false,
 753                                        cx,
 754                                    )
 755                                })?;
 756                            }
 757                            ThreadEvent::AgentThinking(text) => {
 758                                acp_thread.update(cx, |thread, cx| {
 759                                    thread.push_assistant_content_block(
 760                                        acp::ContentBlock::Text(acp::TextContent {
 761                                            text,
 762                                            annotations: None,
 763                                            meta: None,
 764                                        }),
 765                                        true,
 766                                        cx,
 767                                    )
 768                                })?;
 769                            }
 770                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
 771                                tool_call,
 772                                options,
 773                                response,
 774                            }) => {
 775                                let outcome_task = acp_thread.update(cx, |thread, cx| {
 776                                    thread.request_tool_call_authorization(
 777                                        tool_call, options, true, cx,
 778                                    )
 779                                })??;
 780                                cx.background_spawn(async move {
 781                                    if let acp::RequestPermissionOutcome::Selected { option_id } =
 782                                        outcome_task.await
 783                                    {
 784                                        response
 785                                            .send(option_id)
 786                                            .map(|_| anyhow!("authorization receiver was dropped"))
 787                                            .log_err();
 788                                    }
 789                                })
 790                                .detach();
 791                            }
 792                            ThreadEvent::ToolCall(tool_call) => {
 793                                acp_thread.update(cx, |thread, cx| {
 794                                    thread.upsert_tool_call(tool_call, cx)
 795                                })??;
 796                            }
 797                            ThreadEvent::ToolCallUpdate(update) => {
 798                                acp_thread.update(cx, |thread, cx| {
 799                                    thread.update_tool_call(update, cx)
 800                                })??;
 801                            }
 802                            ThreadEvent::Retry(status) => {
 803                                acp_thread.update(cx, |thread, cx| {
 804                                    thread.update_retry_status(status, cx)
 805                                })?;
 806                            }
 807                            ThreadEvent::Stop(stop_reason) => {
 808                                log::debug!("Assistant message complete: {:?}", stop_reason);
 809                                return Ok(acp::PromptResponse {
 810                                    stop_reason,
 811                                    meta: None,
 812                                });
 813                            }
 814                        }
 815                    }
 816                    Err(e) => {
 817                        log::error!("Error in model response stream: {:?}", e);
 818                        return Err(e);
 819                    }
 820                }
 821            }
 822
 823            log::debug!("Response stream completed");
 824            anyhow::Ok(acp::PromptResponse {
 825                stop_reason: acp::StopReason::EndTurn,
 826                meta: None,
 827            })
 828        })
 829    }
 830}
 831
 832impl AgentModelSelector for NativeAgentConnection {
 833    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
 834        log::debug!("NativeAgentConnection::list_models called");
 835        let list = self.0.read(cx).models.model_list.clone();
 836        Task::ready(if list.is_empty() {
 837            Err(anyhow::anyhow!("No models available"))
 838        } else {
 839            Ok(list)
 840        })
 841    }
 842
 843    fn select_model(
 844        &self,
 845        session_id: acp::SessionId,
 846        model_id: acp_thread::AgentModelId,
 847        cx: &mut App,
 848    ) -> Task<Result<()>> {
 849        log::debug!("Setting model for session {}: {}", session_id, model_id);
 850        let Some(thread) = self
 851            .0
 852            .read(cx)
 853            .sessions
 854            .get(&session_id)
 855            .map(|session| session.thread.clone())
 856        else {
 857            return Task::ready(Err(anyhow!("Session not found")));
 858        };
 859
 860        let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
 861            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
 862        };
 863
 864        thread.update(cx, |thread, cx| {
 865            thread.set_model(model.clone(), cx);
 866        });
 867
 868        update_settings_file::<AgentSettings>(
 869            self.0.read(cx).fs.clone(),
 870            cx,
 871            move |settings, _cx| {
 872                settings.set_model(model);
 873            },
 874        );
 875
 876        Task::ready(Ok(()))
 877    }
 878
 879    fn selected_model(
 880        &self,
 881        session_id: &acp::SessionId,
 882        cx: &mut App,
 883    ) -> Task<Result<acp_thread::AgentModelInfo>> {
 884        let session_id = session_id.clone();
 885
 886        let Some(thread) = self
 887            .0
 888            .read(cx)
 889            .sessions
 890            .get(&session_id)
 891            .map(|session| session.thread.clone())
 892        else {
 893            return Task::ready(Err(anyhow!("Session not found")));
 894        };
 895        let Some(model) = thread.read(cx).model() else {
 896            return Task::ready(Err(anyhow!("Model not found")));
 897        };
 898        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
 899        else {
 900            return Task::ready(Err(anyhow!("Provider not found")));
 901        };
 902        Task::ready(Ok(LanguageModels::map_language_model_to_info(
 903            model, &provider,
 904        )))
 905    }
 906
 907    fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
 908        self.0.read(cx).models.watch()
 909    }
 910}
 911
 912impl acp_thread::AgentConnection for NativeAgentConnection {
 913    fn new_thread(
 914        self: Rc<Self>,
 915        project: Entity<Project>,
 916        cwd: &Path,
 917        cx: &mut App,
 918    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
 919        let agent = self.0.clone();
 920        log::debug!("Creating new thread for project at: {:?}", cwd);
 921
 922        cx.spawn(async move |cx| {
 923            log::debug!("Starting thread creation in async context");
 924
 925            // Create Thread
 926            let thread = agent.update(
 927                cx,
 928                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
 929                    // Fetch default model from registry settings
 930                    let registry = LanguageModelRegistry::read_global(cx);
 931                    // Log available models for debugging
 932                    let available_count = registry.available_models(cx).count();
 933                    log::debug!("Total available models: {}", available_count);
 934
 935                    let default_model = registry.default_model().and_then(|default_model| {
 936                        agent
 937                            .models
 938                            .model_from_id(&LanguageModels::model_id(&default_model.model))
 939                    });
 940                    Ok(cx.new(|cx| {
 941                        Thread::new(
 942                            project.clone(),
 943                            agent.project_context.clone(),
 944                            agent.context_server_registry.clone(),
 945                            agent.templates.clone(),
 946                            default_model,
 947                            cx,
 948                        )
 949                    }))
 950                },
 951            )??;
 952            agent.update(cx, |agent, cx| agent.register_session(thread, cx))
 953        })
 954    }
 955
 956    fn auth_methods(&self) -> &[acp::AuthMethod] {
 957        &[] // No auth for in-process
 958    }
 959
 960    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 961        Task::ready(Ok(()))
 962    }
 963
 964    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 965        Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
 966    }
 967
 968    fn prompt(
 969        &self,
 970        id: Option<acp_thread::UserMessageId>,
 971        params: acp::PromptRequest,
 972        cx: &mut App,
 973    ) -> Task<Result<acp::PromptResponse>> {
 974        let id = id.expect("UserMessageId is required");
 975        let session_id = params.session_id.clone();
 976        log::info!("Received prompt request for session: {}", session_id);
 977        log::debug!("Prompt blocks count: {}", params.prompt.len());
 978
 979        self.run_turn(session_id, cx, |thread, cx| {
 980            let content: Vec<UserMessageContent> = params
 981                .prompt
 982                .into_iter()
 983                .map(Into::into)
 984                .collect::<Vec<_>>();
 985            log::debug!("Converted prompt to message: {} chars", content.len());
 986            log::debug!("Message id: {:?}", id);
 987            log::debug!("Message content: {:?}", content);
 988
 989            thread.update(cx, |thread, cx| thread.send(id, content, cx))
 990        })
 991    }
 992
 993    fn resume(
 994        &self,
 995        session_id: &acp::SessionId,
 996        _cx: &App,
 997    ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
 998        Some(Rc::new(NativeAgentSessionResume {
 999            connection: self.clone(),
1000            session_id: session_id.clone(),
1001        }) as _)
1002    }
1003
1004    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1005        log::info!("Cancelling on session: {}", session_id);
1006        self.0.update(cx, |agent, cx| {
1007            if let Some(agent) = agent.sessions.get(session_id) {
1008                agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1009            }
1010        });
1011    }
1012
1013    fn truncate(
1014        &self,
1015        session_id: &agent_client_protocol::SessionId,
1016        cx: &App,
1017    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1018        self.0.read_with(cx, |agent, _cx| {
1019            agent.sessions.get(session_id).map(|session| {
1020                Rc::new(NativeAgentSessionTruncate {
1021                    thread: session.thread.clone(),
1022                    acp_thread: session.acp_thread.clone(),
1023                }) as _
1024            })
1025        })
1026    }
1027
1028    fn set_title(
1029        &self,
1030        session_id: &acp::SessionId,
1031        _cx: &App,
1032    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1033        Some(Rc::new(NativeAgentSessionSetTitle {
1034            connection: self.clone(),
1035            session_id: session_id.clone(),
1036        }) as _)
1037    }
1038
1039    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1040        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1041    }
1042
1043    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1044        self
1045    }
1046}
1047
1048impl acp_thread::AgentTelemetry for NativeAgentConnection {
1049    fn agent_name(&self) -> String {
1050        "Zed".into()
1051    }
1052
1053    fn thread_data(
1054        &self,
1055        session_id: &acp::SessionId,
1056        cx: &mut App,
1057    ) -> Task<Result<serde_json::Value>> {
1058        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1059            return Task::ready(Err(anyhow!("Session not found")));
1060        };
1061
1062        let task = session.thread.read(cx).to_db(cx);
1063        cx.background_spawn(async move {
1064            serde_json::to_value(task.await).context("Failed to serialize thread")
1065        })
1066    }
1067}
1068
1069struct NativeAgentSessionTruncate {
1070    thread: Entity<Thread>,
1071    acp_thread: WeakEntity<AcpThread>,
1072}
1073
1074impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1075    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1076        match self.thread.update(cx, |thread, cx| {
1077            thread.truncate(message_id.clone(), cx)?;
1078            Ok(thread.latest_token_usage())
1079        }) {
1080            Ok(usage) => {
1081                self.acp_thread
1082                    .update(cx, |thread, cx| {
1083                        thread.update_token_usage(usage, cx);
1084                    })
1085                    .ok();
1086                Task::ready(Ok(()))
1087            }
1088            Err(error) => Task::ready(Err(error)),
1089        }
1090    }
1091}
1092
1093struct NativeAgentSessionResume {
1094    connection: NativeAgentConnection,
1095    session_id: acp::SessionId,
1096}
1097
1098impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1099    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1100        self.connection
1101            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1102                thread.update(cx, |thread, cx| thread.resume(cx))
1103            })
1104    }
1105}
1106
1107struct NativeAgentSessionSetTitle {
1108    connection: NativeAgentConnection,
1109    session_id: acp::SessionId,
1110}
1111
1112impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1113    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1114        let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1115            return Task::ready(Err(anyhow!("session not found")));
1116        };
1117        let thread = session.thread.clone();
1118        thread.update(cx, |thread, cx| thread.set_title(title, cx));
1119        Task::ready(Ok(()))
1120    }
1121}
1122
1123pub struct AcpThreadEnvironment {
1124    acp_thread: WeakEntity<AcpThread>,
1125}
1126
1127impl ThreadEnvironment for AcpThreadEnvironment {
1128    fn create_terminal(
1129        &self,
1130        command: String,
1131        cwd: Option<PathBuf>,
1132        output_byte_limit: Option<u64>,
1133        cx: &mut AsyncApp,
1134    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1135        let task = self.acp_thread.update(cx, |thread, cx| {
1136            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1137        });
1138
1139        let acp_thread = self.acp_thread.clone();
1140        cx.spawn(async move |cx| {
1141            let terminal = task?.await?;
1142
1143            let (drop_tx, drop_rx) = oneshot::channel();
1144            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1145
1146            cx.spawn(async move |cx| {
1147                drop_rx.await.ok();
1148                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1149            })
1150            .detach();
1151
1152            let handle = AcpTerminalHandle {
1153                terminal,
1154                _drop_tx: Some(drop_tx),
1155            };
1156
1157            Ok(Rc::new(handle) as _)
1158        })
1159    }
1160}
1161
1162pub struct AcpTerminalHandle {
1163    terminal: Entity<acp_thread::Terminal>,
1164    _drop_tx: Option<oneshot::Sender<()>>,
1165}
1166
1167impl TerminalHandle for AcpTerminalHandle {
1168    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1169        self.terminal.read_with(cx, |term, _cx| term.id().clone())
1170    }
1171
1172    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1173        self.terminal
1174            .read_with(cx, |term, _cx| term.wait_for_exit())
1175    }
1176
1177    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1178        self.terminal
1179            .read_with(cx, |term, cx| term.current_output(cx))
1180    }
1181}
1182
1183#[cfg(test)]
1184mod tests {
1185    use crate::HistoryEntryId;
1186
1187    use super::*;
1188    use acp_thread::{
1189        AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo, MentionUri,
1190    };
1191    use fs::FakeFs;
1192    use gpui::TestAppContext;
1193    use indoc::indoc;
1194    use language_model::fake_provider::FakeLanguageModel;
1195    use serde_json::json;
1196    use settings::SettingsStore;
1197    use util::path;
1198
1199    #[gpui::test]
1200    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1201        init_test(cx);
1202        let fs = FakeFs::new(cx.executor());
1203        fs.insert_tree(
1204            "/",
1205            json!({
1206                "a": {}
1207            }),
1208        )
1209        .await;
1210        let project = Project::test(fs.clone(), [], cx).await;
1211        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
1212        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
1213        let agent = NativeAgent::new(
1214            project.clone(),
1215            history_store,
1216            Templates::new(),
1217            None,
1218            fs.clone(),
1219            &mut cx.to_async(),
1220        )
1221        .await
1222        .unwrap();
1223        agent.read_with(cx, |agent, cx| {
1224            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1225        });
1226
1227        let worktree = project
1228            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1229            .await
1230            .unwrap();
1231        cx.run_until_parked();
1232        agent.read_with(cx, |agent, cx| {
1233            assert_eq!(
1234                agent.project_context.read(cx).worktrees,
1235                vec![WorktreeContext {
1236                    root_name: "a".into(),
1237                    abs_path: Path::new("/a").into(),
1238                    rules_file: None
1239                }]
1240            )
1241        });
1242
1243        // Creating `/a/.rules` updates the project context.
1244        fs.insert_file("/a/.rules", Vec::new()).await;
1245        cx.run_until_parked();
1246        agent.read_with(cx, |agent, cx| {
1247            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
1248            assert_eq!(
1249                agent.project_context.read(cx).worktrees,
1250                vec![WorktreeContext {
1251                    root_name: "a".into(),
1252                    abs_path: Path::new("/a").into(),
1253                    rules_file: Some(RulesFileContext {
1254                        path_in_worktree: Path::new(".rules").into(),
1255                        text: "".into(),
1256                        project_entry_id: rules_entry.id.to_usize()
1257                    })
1258                }]
1259            )
1260        });
1261    }
1262
1263    #[gpui::test]
1264    async fn test_listing_models(cx: &mut TestAppContext) {
1265        init_test(cx);
1266        let fs = FakeFs::new(cx.executor());
1267        fs.insert_tree("/", json!({ "a": {}  })).await;
1268        let project = Project::test(fs.clone(), [], cx).await;
1269        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
1270        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
1271        let connection = NativeAgentConnection(
1272            NativeAgent::new(
1273                project.clone(),
1274                history_store,
1275                Templates::new(),
1276                None,
1277                fs.clone(),
1278                &mut cx.to_async(),
1279            )
1280            .await
1281            .unwrap(),
1282        );
1283
1284        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();
1285
1286        let acp_thread::AgentModelList::Grouped(models) = models else {
1287            panic!("Unexpected model group");
1288        };
1289        assert_eq!(
1290            models,
1291            IndexMap::from_iter([(
1292                AgentModelGroupName("Fake".into()),
1293                vec![AgentModelInfo {
1294                    id: AgentModelId("fake/fake".into()),
1295                    name: "Fake".into(),
1296                    icon: Some(ui::IconName::ZedAssistant),
1297                }]
1298            )])
1299        );
1300    }
1301
1302    #[gpui::test]
1303    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
1304        init_test(cx);
1305        let fs = FakeFs::new(cx.executor());
1306        fs.create_dir(paths::settings_file().parent().unwrap())
1307            .await
1308            .unwrap();
1309        fs.insert_file(
1310            paths::settings_file(),
1311            json!({
1312                "agent": {
1313                    "default_model": {
1314                        "provider": "foo",
1315                        "model": "bar"
1316                    }
1317                }
1318            })
1319            .to_string()
1320            .into_bytes(),
1321        )
1322        .await;
1323        let project = Project::test(fs.clone(), [], cx).await;
1324
1325        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
1326        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
1327
1328        // Create the agent and connection
1329        let agent = NativeAgent::new(
1330            project.clone(),
1331            history_store,
1332            Templates::new(),
1333            None,
1334            fs.clone(),
1335            &mut cx.to_async(),
1336        )
1337        .await
1338        .unwrap();
1339        let connection = NativeAgentConnection(agent.clone());
1340
1341        // Create a thread/session
1342        let acp_thread = cx
1343            .update(|cx| {
1344                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
1345            })
1346            .await
1347            .unwrap();
1348
1349        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1350
1351        // Select a model
1352        let model_id = AgentModelId("fake/fake".into());
1353        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
1354            .await
1355            .unwrap();
1356
1357        // Verify the thread has the selected model
1358        agent.read_with(cx, |agent, _| {
1359            let session = agent.sessions.get(&session_id).unwrap();
1360            session.thread.read_with(cx, |thread, _| {
1361                assert_eq!(thread.model().unwrap().id().0, "fake");
1362            });
1363        });
1364
1365        cx.run_until_parked();
1366
1367        // Verify settings file was updated
1368        let settings_content = fs.load(paths::settings_file()).await.unwrap();
1369        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
1370
1371        // Check that the agent settings contain the selected model
1372        assert_eq!(
1373            settings_json["agent"]["default_model"]["model"],
1374            json!("fake")
1375        );
1376        assert_eq!(
1377            settings_json["agent"]["default_model"]["provider"],
1378            json!("fake")
1379        );
1380    }
1381
    // End-to-end check of thread persistence: an empty thread is not listed in
    // history, a thread with one completed exchange is listed under its
    // generated title, and after the session is dropped the thread can be
    // reopened from the agent with identical markdown.
    #[gpui::test]
    #[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        // Grab the internal thread behind the session so its models can be swapped.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        // Send a user message that mentions `/a/b.md` through a resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
                        name: "b.md".into(),
                        uri: MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                        annotations: None,
                        description: None,
                        mime_type: None,
                        size: None,
                        title: None,
                        meta: None,
                    }),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Run the send on the foreground executor so the fake models can be
        // driven while it is in flight.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Complete the assistant's reply, then the summarization request.
        // The summary text is what the history entry title is compared to below.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model.send_last_completion_stream_text_chunk("Explaining /a/b.md");
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    What does [@b.md](file:///a/b.md) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                "Explaining /a/b.md".into()
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        // The reopened thread must render exactly like the original one.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    What does [@b.md](file:///a/b.md) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
1522
1523    fn history_entries(
1524        history: &Entity<HistoryStore>,
1525        cx: &mut TestAppContext,
1526    ) -> Vec<(HistoryEntryId, String)> {
1527        history.read_with(cx, |history, _| {
1528            history
1529                .entries()
1530                .map(|e| (e.id(), e.title().to_string()))
1531                .collect::<Vec<_>>()
1532        })
1533    }
1534
1535    fn init_test(cx: &mut TestAppContext) {
1536        env_logger::try_init().ok();
1537        cx.update(|cx| {
1538            let settings_store = SettingsStore::test(cx);
1539            cx.set_global(settings_store);
1540            Project::init_settings(cx);
1541            agent_settings::init(cx);
1542            language::init(cx);
1543            LanguageModelRegistry::test(cx);
1544        });
1545    }
1546}