agent.rs

   1use crate::{
   2    ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
   3    UserMessageContent, templates::Templates,
   4};
   5use crate::{HistoryStore, TerminalHandle, ThreadEnvironment, TitleUpdated, TokenUsageUpdated};
   6use acp_thread::{AcpThread, AgentModelSelector};
   7use action_log::ActionLog;
   8use agent_client_protocol as acp;
   9use anyhow::{Context as _, Result, anyhow};
  10use collections::{HashSet, IndexMap};
  11use fs::Fs;
  12use futures::channel::{mpsc, oneshot};
  13use futures::future::Shared;
  14use futures::{StreamExt, future};
  15use gpui::{
  16    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  17};
  18use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
  19use project::{Project, ProjectItem, ProjectPath, Worktree};
  20use prompt_store::{
  21    ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
  22};
  23use settings::{LanguageModelSelection, update_settings_file};
  24use std::any::Any;
  25use std::collections::HashMap;
  26use std::path::{Path, PathBuf};
  27use std::rc::Rc;
  28use std::sync::Arc;
  29use util::ResultExt;
  30use util::rel_path::RelPath;
  31
/// File names, relative to a worktree root, that are recognized as rules files.
///
/// Order matters: `load_worktree_rules_file` walks this list front to back and
/// uses the first name that resolves to a file entry in the worktree.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
  43
/// Describes a failure to load a worktree's rules file.
///
/// Built in `load_worktree_info_for_system_prompt` from the error returned by
/// the rules-loading task.
pub struct RulesLoadingError {
    /// Display-formatted text of the underlying error.
    pub message: SharedString,
}
  47
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    ///
    /// Held weakly; `register_session` removes the session from
    /// `NativeAgent::sessions` once this `AcpThread` is released.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Task for the most recent save started by `NativeAgent::save_thread`.
    /// Each new save overwrites this field with a fresh task.
    pending_save: Task<()>,
    /// Subscriptions set up in `register_session`, stored alongside the session.
    _subscriptions: Vec<Subscription>,
}
  57
/// Cache of the language models offered by authenticated providers.
///
/// The contents are rebuilt from the global `LanguageModelRegistry` by
/// `refresh_list`.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiving side of the refresh channel; cloned and handed out by `watch`.
    refresh_models_rx: watch::Receiver<()>,
    /// Sending side of the refresh channel; signalled at the end of every
    /// `refresh_list` call.
    refresh_models_tx: watch::Sender<()>,
    /// Background task started in `new` that awaits the authentication of
    /// every registered provider.
    _authenticate_all_providers_task: Task<()>,
}
  67
  68impl LanguageModels {
  69    fn new(cx: &mut App) -> Self {
  70        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  71
  72        let mut this = Self {
  73            models: HashMap::default(),
  74            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  75            refresh_models_rx,
  76            refresh_models_tx,
  77            _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
  78        };
  79        this.refresh_list(cx);
  80        this
  81    }
  82
  83    fn refresh_list(&mut self, cx: &App) {
  84        let providers = LanguageModelRegistry::global(cx)
  85            .read(cx)
  86            .providers()
  87            .into_iter()
  88            .filter(|provider| provider.is_authenticated(cx))
  89            .collect::<Vec<_>>();
  90
  91        let mut language_model_list = IndexMap::default();
  92        let mut recommended_models = HashSet::default();
  93
  94        let mut recommended = Vec::new();
  95        for provider in &providers {
  96            for model in provider.recommended_models(cx) {
  97                recommended_models.insert((model.provider_id(), model.id()));
  98                recommended.push(Self::map_language_model_to_info(&model, provider));
  99            }
 100        }
 101        if !recommended.is_empty() {
 102            language_model_list.insert(
 103                acp_thread::AgentModelGroupName("Recommended".into()),
 104                recommended,
 105            );
 106        }
 107
 108        let mut models = HashMap::default();
 109        for provider in providers {
 110            let mut provider_models = Vec::new();
 111            for model in provider.provided_models(cx) {
 112                let model_info = Self::map_language_model_to_info(&model, &provider);
 113                let model_id = model_info.id.clone();
 114                if !recommended_models.contains(&(model.provider_id(), model.id())) {
 115                    provider_models.push(model_info);
 116                }
 117                models.insert(model_id, model);
 118            }
 119            if !provider_models.is_empty() {
 120                language_model_list.insert(
 121                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 122                    provider_models,
 123                );
 124            }
 125        }
 126
 127        self.models = models;
 128        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 129        self.refresh_models_tx.send(()).ok();
 130    }
 131
    /// Returns a clone of the refresh receiver.
    ///
    /// `refresh_list` sends on the paired sender each time it finishes
    /// rebuilding the model list.
    fn watch(&self) -> watch::Receiver<()> {
        self.refresh_models_rx.clone()
    }
 135
 136    pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
 137        self.models.get(model_id).cloned()
 138    }
 139
 140    fn map_language_model_to_info(
 141        model: &Arc<dyn LanguageModel>,
 142        provider: &Arc<dyn LanguageModelProvider>,
 143    ) -> acp_thread::AgentModelInfo {
 144        acp_thread::AgentModelInfo {
 145            id: Self::model_id(model),
 146            name: model.name().0,
 147            description: None,
 148            icon: Some(provider.icon()),
 149        }
 150    }
 151
 152    fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
 153        acp::ModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
 154    }
 155
 156    fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
 157        let authenticate_all_providers = LanguageModelRegistry::global(cx)
 158            .read(cx)
 159            .providers()
 160            .iter()
 161            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
 162            .collect::<Vec<_>>();
 163
 164        cx.background_spawn(async move {
 165            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
 166                if let Err(err) = authenticate_task.await {
 167                    match err {
 168                        language_model::AuthenticateError::CredentialsNotFound => {
 169                            // Since we're authenticating these providers in the
 170                            // background for the purposes of populating the
 171                            // language selector, we don't care about providers
 172                            // where the credentials are not found.
 173                        }
 174                        language_model::AuthenticateError::ConnectionRefused => {
 175                            // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
 176                            // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
 177                            // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
 178                        }
 179                        _ => {
 180                            // Some providers have noisy failure states that we
 181                            // don't want to spam the logs with every time the
 182                            // language model selector is initialized.
 183                            //
 184                            // Ideally these should have more clear failure modes
 185                            // that we know are safe to ignore here, like what we do
 186                            // with `CredentialsNotFound` above.
 187                            match provider_id.0.as_ref() {
 188                                "lmstudio" | "ollama" => {
 189                                    // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
 190                                    //
 191                                    // These fail noisily, so we don't log them.
 192                                }
 193                                "copilot_chat" => {
 194                                    // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
 195                                }
 196                                _ => {
 197                                    log::error!(
 198                                        "Failed to authenticate provider: {}: {err}",
 199                                        provider_name.0
 200                                    );
 201                                }
 202                            }
 203                        }
 204                    }
 205                }
 206            }
 207        })
 208    }
 209}
 210
/// Agent state shared by all sessions of one project: the open sessions, the
/// project context handed to threads, and the cached model information.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// History store that is asked to reload after a thread is saved
    /// (see `save_thread`).
    history: Entity<HistoryStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Sender used by the event handlers to request a rebuild of
    /// `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context`, which rebuilds
    /// `project_context` each time a refresh is requested.
    _maintain_project_context: Task<Result<()>>,
    /// Registry created from the project's context server store; passed to
    /// threads loaded in `open_thread`.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store that supplies the default user rules for the
    /// project context.
    prompt_store: Option<Entity<PromptStore>>,
    /// File system handle supplied at construction.
    fs: Arc<dyn Fs>,
    /// Subscriptions to project events, model registry events and, when a
    /// prompt store is present, prompt store events.
    _subscriptions: Vec<Subscription>,
}
 229
 230impl NativeAgent {
 231    pub async fn new(
 232        project: Entity<Project>,
 233        history: Entity<HistoryStore>,
 234        templates: Arc<Templates>,
 235        prompt_store: Option<Entity<PromptStore>>,
 236        fs: Arc<dyn Fs>,
 237        cx: &mut AsyncApp,
 238    ) -> Result<Entity<NativeAgent>> {
 239        log::debug!("Creating new NativeAgent");
 240
 241        let project_context = cx
 242            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 243            .await;
 244
 245        cx.new(|cx| {
 246            let mut subscriptions = vec![
 247                cx.subscribe(&project, Self::handle_project_event),
 248                cx.subscribe(
 249                    &LanguageModelRegistry::global(cx),
 250                    Self::handle_models_updated_event,
 251                ),
 252            ];
 253            if let Some(prompt_store) = prompt_store.as_ref() {
 254                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 255            }
 256
 257            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 258                watch::channel(());
 259            Self {
 260                sessions: HashMap::new(),
 261                history,
 262                project_context: cx.new(|_| project_context),
 263                project_context_needs_refresh: project_context_needs_refresh_tx,
 264                _maintain_project_context: cx.spawn(async move |this, cx| {
 265                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 266                }),
 267                context_server_registry: cx.new(|cx| {
 268                    ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
 269                }),
 270                templates,
 271                models: LanguageModels::new(cx),
 272                project,
 273                prompt_store,
 274                fs,
 275                _subscriptions: subscriptions,
 276            }
 277        })
 278    }
 279
 280    fn register_session(
 281        &mut self,
 282        thread_handle: Entity<Thread>,
 283        cx: &mut Context<Self>,
 284    ) -> Entity<AcpThread> {
 285        let connection = Rc::new(NativeAgentConnection(cx.entity()));
 286
 287        let thread = thread_handle.read(cx);
 288        let session_id = thread.id().clone();
 289        let title = thread.title();
 290        let project = thread.project.clone();
 291        let action_log = thread.action_log.clone();
 292        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
 293        let acp_thread = cx.new(|cx| {
 294            acp_thread::AcpThread::new(
 295                title,
 296                connection,
 297                project.clone(),
 298                action_log.clone(),
 299                session_id.clone(),
 300                prompt_capabilities_rx,
 301                cx,
 302            )
 303        });
 304
 305        let registry = LanguageModelRegistry::read_global(cx);
 306        let summarization_model = registry.thread_summary_model().map(|c| c.model);
 307
 308        thread_handle.update(cx, |thread, cx| {
 309            thread.set_summarization_model(summarization_model, cx);
 310            thread.add_default_tools(
 311                Rc::new(AcpThreadEnvironment {
 312                    acp_thread: acp_thread.downgrade(),
 313                }) as _,
 314                cx,
 315            )
 316        });
 317
 318        let subscriptions = vec![
 319            cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
 320                this.sessions.remove(acp_thread.session_id());
 321            }),
 322            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
 323            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
 324            cx.observe(&thread_handle, move |this, thread, cx| {
 325                this.save_thread(thread, cx)
 326            }),
 327        ];
 328
 329        self.sessions.insert(
 330            session_id,
 331            Session {
 332                thread: thread_handle,
 333                acp_thread: acp_thread.downgrade(),
 334                _subscriptions: subscriptions,
 335                pending_save: Task::ready(()),
 336            },
 337        );
 338        acp_thread
 339    }
 340
    /// Returns the agent's cached language model information.
    pub fn models(&self) -> &LanguageModels {
        &self.models
    }
 344
 345    async fn maintain_project_context(
 346        this: WeakEntity<Self>,
 347        mut needs_refresh: watch::Receiver<()>,
 348        cx: &mut AsyncApp,
 349    ) -> Result<()> {
 350        while needs_refresh.changed().await.is_ok() {
 351            let project_context = this
 352                .update(cx, |this, cx| {
 353                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 354                })?
 355                .await;
 356            this.update(cx, |this, cx| {
 357                this.project_context = cx.new(|_| project_context);
 358            })?;
 359        }
 360
 361        Ok(())
 362    }
 363
 364    fn build_project_context(
 365        project: &Entity<Project>,
 366        prompt_store: Option<&Entity<PromptStore>>,
 367        cx: &mut App,
 368    ) -> Task<ProjectContext> {
 369        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 370        let worktree_tasks = worktrees
 371            .into_iter()
 372            .map(|worktree| {
 373                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 374            })
 375            .collect::<Vec<_>>();
 376        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 377            prompt_store.read_with(cx, |prompt_store, cx| {
 378                let prompts = prompt_store.default_prompt_metadata();
 379                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 380                    let contents = prompt_store.load(prompt_metadata.id, cx);
 381                    async move { (contents.await, prompt_metadata) }
 382                });
 383                cx.background_spawn(future::join_all(load_tasks))
 384            })
 385        } else {
 386            Task::ready(vec![])
 387        };
 388
 389        cx.spawn(async move |_cx| {
 390            let (worktrees, default_user_rules) =
 391                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 392
 393            let worktrees = worktrees
 394                .into_iter()
 395                .map(|(worktree, _rules_error)| {
 396                    // TODO: show error message
 397                    // if let Some(rules_error) = rules_error {
 398                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 399                    // }
 400                    worktree
 401                })
 402                .collect::<Vec<_>>();
 403
 404            let default_user_rules = default_user_rules
 405                .into_iter()
 406                .flat_map(|(contents, prompt_metadata)| match contents {
 407                    Ok(contents) => Some(UserRulesContext {
 408                        uuid: match prompt_metadata.id {
 409                            PromptId::User { uuid } => uuid,
 410                            PromptId::EditWorkflow => return None,
 411                        },
 412                        title: prompt_metadata.title.map(|title| title.to_string()),
 413                        contents,
 414                    }),
 415                    Err(_err) => {
 416                        // TODO: show error message
 417                        // this.update(cx, |_, cx| {
 418                        //     cx.emit(RulesLoadingError {
 419                        //         message: format!("{err:?}").into(),
 420                        //     });
 421                        // })
 422                        // .ok();
 423                        None
 424                    }
 425                })
 426                .collect::<Vec<_>>();
 427
 428            ProjectContext::new(worktrees, default_user_rules)
 429        })
 430    }
 431
 432    fn load_worktree_info_for_system_prompt(
 433        worktree: Entity<Worktree>,
 434        project: Entity<Project>,
 435        cx: &mut App,
 436    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 437        let tree = worktree.read(cx);
 438        let root_name = tree.root_name_str().into();
 439        let abs_path = tree.abs_path();
 440
 441        let mut context = WorktreeContext {
 442            root_name,
 443            abs_path,
 444            rules_file: None,
 445        };
 446
 447        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 448        let Some(rules_task) = rules_task else {
 449            return Task::ready((context, None));
 450        };
 451
 452        cx.spawn(async move |_| {
 453            let (rules_file, rules_file_error) = match rules_task.await {
 454                Ok(rules_file) => (Some(rules_file), None),
 455                Err(err) => (
 456                    None,
 457                    Some(RulesLoadingError {
 458                        message: format!("{err}").into(),
 459                    }),
 460                ),
 461            };
 462            context.rules_file = rules_file;
 463            (context, rules_file_error)
 464        })
 465    }
 466
 467    fn load_worktree_rules_file(
 468        worktree: Entity<Worktree>,
 469        project: Entity<Project>,
 470        cx: &mut App,
 471    ) -> Option<Task<Result<RulesFileContext>>> {
 472        let worktree = worktree.read(cx);
 473        let worktree_id = worktree.id();
 474        let selected_rules_file = RULES_FILE_NAMES
 475            .into_iter()
 476            .filter_map(|name| {
 477                worktree
 478                    .entry_for_path(RelPath::unix(name).unwrap())
 479                    .filter(|entry| entry.is_file())
 480                    .map(|entry| entry.path.clone())
 481            })
 482            .next();
 483
 484        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 485        // supported. This doesn't seem to occur often in GitHub repositories.
 486        selected_rules_file.map(|path_in_worktree| {
 487            let project_path = ProjectPath {
 488                worktree_id,
 489                path: path_in_worktree.clone(),
 490            };
 491            let buffer_task =
 492                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 493            let rope_task = cx.spawn(async move |cx| {
 494                buffer_task.await?.read_with(cx, |buffer, cx| {
 495                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 496                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 497                })?
 498            });
 499            // Build a string from the rope on a background thread.
 500            cx.background_spawn(async move {
 501                let (project_entry_id, rope) = rope_task.await?;
 502                anyhow::Ok(RulesFileContext {
 503                    path_in_worktree,
 504                    text: rope.to_string().trim().to_string(),
 505                    project_entry_id: project_entry_id.to_usize(),
 506                })
 507            })
 508        })
 509    }
 510
 511    fn handle_thread_title_updated(
 512        &mut self,
 513        thread: Entity<Thread>,
 514        _: &TitleUpdated,
 515        cx: &mut Context<Self>,
 516    ) {
 517        let session_id = thread.read(cx).id();
 518        let Some(session) = self.sessions.get(session_id) else {
 519            return;
 520        };
 521        let thread = thread.downgrade();
 522        let acp_thread = session.acp_thread.clone();
 523        cx.spawn(async move |_, cx| {
 524            let title = thread.read_with(cx, |thread, _| thread.title())?;
 525            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
 526            task.await
 527        })
 528        .detach_and_log_err(cx);
 529    }
 530
 531    fn handle_thread_token_usage_updated(
 532        &mut self,
 533        thread: Entity<Thread>,
 534        usage: &TokenUsageUpdated,
 535        cx: &mut Context<Self>,
 536    ) {
 537        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
 538            return;
 539        };
 540        session
 541            .acp_thread
 542            .update(cx, |acp_thread, cx| {
 543                acp_thread.update_token_usage(usage.0.clone(), cx);
 544            })
 545            .ok();
 546    }
 547
 548    fn handle_project_event(
 549        &mut self,
 550        _project: Entity<Project>,
 551        event: &project::Event,
 552        _cx: &mut Context<Self>,
 553    ) {
 554        match event {
 555            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 556                self.project_context_needs_refresh.send(()).ok();
 557            }
 558            project::Event::WorktreeUpdatedEntries(_, items) => {
 559                if items.iter().any(|(path, _, _)| {
 560                    RULES_FILE_NAMES
 561                        .iter()
 562                        .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
 563                }) {
 564                    self.project_context_needs_refresh.send(()).ok();
 565                }
 566            }
 567            _ => {}
 568        }
 569    }
 570
    /// Requests a project-context rebuild whenever the prompt store reports
    /// an update. The result of the send is ignored.
    fn handle_prompts_updated_event(
        &mut self,
        _prompt_store: Entity<PromptStore>,
        _event: &prompt_store::PromptsUpdatedEvent,
        _cx: &mut Context<Self>,
    ) {
        self.project_context_needs_refresh.send(()).ok();
    }
 579
 580    fn handle_models_updated_event(
 581        &mut self,
 582        _registry: Entity<LanguageModelRegistry>,
 583        _event: &language_model::Event,
 584        cx: &mut Context<Self>,
 585    ) {
 586        self.models.refresh_list(cx);
 587
 588        let registry = LanguageModelRegistry::read_global(cx);
 589        let default_model = registry.default_model().map(|m| m.model);
 590        let summarization_model = registry.thread_summary_model().map(|m| m.model);
 591
 592        for session in self.sessions.values_mut() {
 593            session.thread.update(cx, |thread, cx| {
 594                if thread.model().is_none()
 595                    && let Some(model) = default_model.clone()
 596                {
 597                    thread.set_model(model, cx);
 598                    cx.notify();
 599                }
 600                thread.set_summarization_model(summarization_model.clone(), cx);
 601            });
 602        }
 603    }
 604
 605    pub fn open_thread(
 606        &mut self,
 607        id: acp::SessionId,
 608        cx: &mut Context<Self>,
 609    ) -> Task<Result<Entity<AcpThread>>> {
 610        let database_future = ThreadsDatabase::connect(cx);
 611        cx.spawn(async move |this, cx| {
 612            let database = database_future.await.map_err(|err| anyhow!(err))?;
 613            let db_thread = database
 614                .load_thread(id.clone())
 615                .await?
 616                .with_context(|| format!("no thread found with ID: {id:?}"))?;
 617
 618            let thread = this.update(cx, |this, cx| {
 619                let action_log = cx.new(|_cx| ActionLog::new(this.project.clone()));
 620                cx.new(|cx| {
 621                    Thread::from_db(
 622                        id.clone(),
 623                        db_thread,
 624                        this.project.clone(),
 625                        this.project_context.clone(),
 626                        this.context_server_registry.clone(),
 627                        action_log.clone(),
 628                        this.templates.clone(),
 629                        cx,
 630                    )
 631                })
 632            })?;
 633            let acp_thread =
 634                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
 635            let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
 636            cx.update(|cx| {
 637                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
 638            })?
 639            .await?;
 640            Ok(acp_thread)
 641        })
 642    }
 643
 644    pub fn thread_summary(
 645        &mut self,
 646        id: acp::SessionId,
 647        cx: &mut Context<Self>,
 648    ) -> Task<Result<SharedString>> {
 649        let thread = self.open_thread(id.clone(), cx);
 650        cx.spawn(async move |this, cx| {
 651            let acp_thread = thread.await?;
 652            let result = this
 653                .update(cx, |this, cx| {
 654                    this.sessions
 655                        .get(&id)
 656                        .unwrap()
 657                        .thread
 658                        .update(cx, |thread, cx| thread.summary(cx))
 659                })?
 660                .await?;
 661            drop(acp_thread);
 662            Ok(result)
 663        })
 664    }
 665
 666    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
 667        if thread.read(cx).is_empty() {
 668            return;
 669        }
 670
 671        let database_future = ThreadsDatabase::connect(cx);
 672        let (id, db_thread) =
 673            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
 674        let Some(session) = self.sessions.get_mut(&id) else {
 675            return;
 676        };
 677        let history = self.history.clone();
 678        session.pending_save = cx.spawn(async move |_, cx| {
 679            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
 680                return;
 681            };
 682            let db_thread = db_thread.await;
 683            database.save_thread(id, db_thread).await.log_err();
 684            history.update(cx, |history, cx| history.reload(cx)).ok();
 685        });
 686    }
 687}
 688
/// Wrapper struct that implements the AgentConnection trait
///
/// Holds a strong handle to the `NativeAgent` whose sessions it operates on.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 692
 693impl NativeAgentConnection {
 694    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 695        self.0
 696            .read(cx)
 697            .sessions
 698            .get(session_id)
 699            .map(|session| session.thread.clone())
 700    }
 701
 702    fn run_turn(
 703        &self,
 704        session_id: acp::SessionId,
 705        cx: &mut App,
 706        f: impl 'static
 707        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 708    ) -> Task<Result<acp::PromptResponse>> {
 709        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 710            agent
 711                .sessions
 712                .get_mut(&session_id)
 713                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 714        }) else {
 715            return Task::ready(Err(anyhow!("Session not found")));
 716        };
 717        log::debug!("Found session for: {}", session_id);
 718
 719        let response_stream = match f(thread, cx) {
 720            Ok(stream) => stream,
 721            Err(err) => return Task::ready(Err(err)),
 722        };
 723        Self::handle_thread_events(response_stream, acp_thread, cx)
 724    }
 725
 726    fn handle_thread_events(
 727        mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
 728        acp_thread: WeakEntity<AcpThread>,
 729        cx: &App,
 730    ) -> Task<Result<acp::PromptResponse>> {
 731        cx.spawn(async move |cx| {
 732            // Handle response stream and forward to session.acp_thread
 733            while let Some(result) = events.next().await {
 734                match result {
 735                    Ok(event) => {
 736                        log::trace!("Received completion event: {:?}", event);
 737
 738                        match event {
 739                            ThreadEvent::UserMessage(message) => {
 740                                acp_thread.update(cx, |thread, cx| {
 741                                    for content in message.content {
 742                                        thread.push_user_content_block(
 743                                            Some(message.id.clone()),
 744                                            content.into(),
 745                                            cx,
 746                                        );
 747                                    }
 748                                })?;
 749                            }
 750                            ThreadEvent::AgentText(text) => {
 751                                acp_thread.update(cx, |thread, cx| {
 752                                    thread.push_assistant_content_block(
 753                                        acp::ContentBlock::Text(acp::TextContent {
 754                                            text,
 755                                            annotations: None,
 756                                            meta: None,
 757                                        }),
 758                                        false,
 759                                        cx,
 760                                    )
 761                                })?;
 762                            }
 763                            ThreadEvent::AgentThinking(text) => {
 764                                acp_thread.update(cx, |thread, cx| {
 765                                    thread.push_assistant_content_block(
 766                                        acp::ContentBlock::Text(acp::TextContent {
 767                                            text,
 768                                            annotations: None,
 769                                            meta: None,
 770                                        }),
 771                                        true,
 772                                        cx,
 773                                    )
 774                                })?;
 775                            }
 776                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
 777                                tool_call,
 778                                options,
 779                                response,
 780                            }) => {
 781                                let outcome_task = acp_thread.update(cx, |thread, cx| {
 782                                    thread.request_tool_call_authorization(
 783                                        tool_call, options, true, cx,
 784                                    )
 785                                })??;
 786                                cx.background_spawn(async move {
 787                                    if let acp::RequestPermissionOutcome::Selected { option_id } =
 788                                        outcome_task.await
 789                                    {
 790                                        response
 791                                            .send(option_id)
 792                                            .map(|_| anyhow!("authorization receiver was dropped"))
 793                                            .log_err();
 794                                    }
 795                                })
 796                                .detach();
 797                            }
 798                            ThreadEvent::ToolCall(tool_call) => {
 799                                acp_thread.update(cx, |thread, cx| {
 800                                    thread.upsert_tool_call(tool_call, cx)
 801                                })??;
 802                            }
 803                            ThreadEvent::ToolCallUpdate(update) => {
 804                                acp_thread.update(cx, |thread, cx| {
 805                                    thread.update_tool_call(update, cx)
 806                                })??;
 807                            }
 808                            ThreadEvent::Retry(status) => {
 809                                acp_thread.update(cx, |thread, cx| {
 810                                    thread.update_retry_status(status, cx)
 811                                })?;
 812                            }
 813                            ThreadEvent::Stop(stop_reason) => {
 814                                log::debug!("Assistant message complete: {:?}", stop_reason);
 815                                return Ok(acp::PromptResponse {
 816                                    stop_reason,
 817                                    meta: None,
 818                                });
 819                            }
 820                        }
 821                    }
 822                    Err(e) => {
 823                        log::error!("Error in model response stream: {:?}", e);
 824                        return Err(e);
 825                    }
 826                }
 827            }
 828
 829            log::debug!("Response stream completed");
 830            anyhow::Ok(acp::PromptResponse {
 831                stop_reason: acp::StopReason::EndTurn,
 832                meta: None,
 833            })
 834        })
 835    }
 836}
 837
 838struct NativeAgentModelSelector {
 839    session_id: acp::SessionId,
 840    connection: NativeAgentConnection,
 841}
 842
 843impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
 844    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
 845        log::debug!("NativeAgentConnection::list_models called");
 846        let list = self.connection.0.read(cx).models.model_list.clone();
 847        Task::ready(if list.is_empty() {
 848            Err(anyhow::anyhow!("No models available"))
 849        } else {
 850            Ok(list)
 851        })
 852    }
 853
 854    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
 855        log::debug!(
 856            "Setting model for session {}: {}",
 857            self.session_id,
 858            model_id
 859        );
 860        let Some(thread) = self
 861            .connection
 862            .0
 863            .read(cx)
 864            .sessions
 865            .get(&self.session_id)
 866            .map(|session| session.thread.clone())
 867        else {
 868            return Task::ready(Err(anyhow!("Session not found")));
 869        };
 870
 871        let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
 872            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
 873        };
 874
 875        thread.update(cx, |thread, cx| {
 876            thread.set_model(model.clone(), cx);
 877        });
 878
 879        update_settings_file(
 880            self.connection.0.read(cx).fs.clone(),
 881            cx,
 882            move |settings, _cx| {
 883                let provider = model.provider_id().0.to_string();
 884                let model = model.id().0.to_string();
 885                settings
 886                    .agent
 887                    .get_or_insert_default()
 888                    .set_model(LanguageModelSelection {
 889                        provider: provider.into(),
 890                        model,
 891                    });
 892            },
 893        );
 894
 895        Task::ready(Ok(()))
 896    }
 897
 898    fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
 899        let Some(thread) = self
 900            .connection
 901            .0
 902            .read(cx)
 903            .sessions
 904            .get(&self.session_id)
 905            .map(|session| session.thread.clone())
 906        else {
 907            return Task::ready(Err(anyhow!("Session not found")));
 908        };
 909        let Some(model) = thread.read(cx).model() else {
 910            return Task::ready(Err(anyhow!("Model not found")));
 911        };
 912        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
 913        else {
 914            return Task::ready(Err(anyhow!("Provider not found")));
 915        };
 916        Task::ready(Ok(LanguageModels::map_language_model_to_info(
 917            model, &provider,
 918        )))
 919    }
 920
 921    fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
 922        Some(self.connection.0.read(cx).models.watch())
 923    }
 924}
 925
 926impl acp_thread::AgentConnection for NativeAgentConnection {
 927    fn new_thread(
 928        self: Rc<Self>,
 929        project: Entity<Project>,
 930        cwd: &Path,
 931        cx: &mut App,
 932    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
 933        let agent = self.0.clone();
 934        log::debug!("Creating new thread for project at: {:?}", cwd);
 935
 936        cx.spawn(async move |cx| {
 937            log::debug!("Starting thread creation in async context");
 938
 939            // Create Thread
 940            let thread = agent.update(
 941                cx,
 942                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
 943                    // Fetch default model from registry settings
 944                    let registry = LanguageModelRegistry::read_global(cx);
 945                    // Log available models for debugging
 946                    let available_count = registry.available_models(cx).count();
 947                    log::debug!("Total available models: {}", available_count);
 948
 949                    let default_model = registry.default_model().and_then(|default_model| {
 950                        agent
 951                            .models
 952                            .model_from_id(&LanguageModels::model_id(&default_model.model))
 953                    });
 954                    Ok(cx.new(|cx| {
 955                        Thread::new(
 956                            project.clone(),
 957                            agent.project_context.clone(),
 958                            agent.context_server_registry.clone(),
 959                            agent.templates.clone(),
 960                            default_model,
 961                            cx,
 962                        )
 963                    }))
 964                },
 965            )??;
 966            agent.update(cx, |agent, cx| agent.register_session(thread, cx))
 967        })
 968    }
 969
 970    fn auth_methods(&self) -> &[acp::AuthMethod] {
 971        &[] // No auth for in-process
 972    }
 973
 974    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 975        Task::ready(Ok(()))
 976    }
 977
 978    fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
 979        Some(Rc::new(NativeAgentModelSelector {
 980            session_id: session_id.clone(),
 981            connection: self.clone(),
 982        }) as Rc<dyn AgentModelSelector>)
 983    }
 984
 985    fn prompt(
 986        &self,
 987        id: Option<acp_thread::UserMessageId>,
 988        params: acp::PromptRequest,
 989        cx: &mut App,
 990    ) -> Task<Result<acp::PromptResponse>> {
 991        let id = id.expect("UserMessageId is required");
 992        let session_id = params.session_id.clone();
 993        log::info!("Received prompt request for session: {}", session_id);
 994        log::debug!("Prompt blocks count: {}", params.prompt.len());
 995
 996        self.run_turn(session_id, cx, |thread, cx| {
 997            let content: Vec<UserMessageContent> = params
 998                .prompt
 999                .into_iter()
1000                .map(Into::into)
1001                .collect::<Vec<_>>();
1002            log::debug!("Converted prompt to message: {} chars", content.len());
1003            log::debug!("Message id: {:?}", id);
1004            log::debug!("Message content: {:?}", content);
1005
1006            thread.update(cx, |thread, cx| thread.send(id, content, cx))
1007        })
1008    }
1009
1010    fn resume(
1011        &self,
1012        session_id: &acp::SessionId,
1013        _cx: &App,
1014    ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1015        Some(Rc::new(NativeAgentSessionResume {
1016            connection: self.clone(),
1017            session_id: session_id.clone(),
1018        }) as _)
1019    }
1020
1021    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1022        log::info!("Cancelling on session: {}", session_id);
1023        self.0.update(cx, |agent, cx| {
1024            if let Some(agent) = agent.sessions.get(session_id) {
1025                agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1026            }
1027        });
1028    }
1029
1030    fn truncate(
1031        &self,
1032        session_id: &agent_client_protocol::SessionId,
1033        cx: &App,
1034    ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1035        self.0.read_with(cx, |agent, _cx| {
1036            agent.sessions.get(session_id).map(|session| {
1037                Rc::new(NativeAgentSessionTruncate {
1038                    thread: session.thread.clone(),
1039                    acp_thread: session.acp_thread.clone(),
1040                }) as _
1041            })
1042        })
1043    }
1044
1045    fn set_title(
1046        &self,
1047        session_id: &acp::SessionId,
1048        _cx: &App,
1049    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1050        Some(Rc::new(NativeAgentSessionSetTitle {
1051            connection: self.clone(),
1052            session_id: session_id.clone(),
1053        }) as _)
1054    }
1055
1056    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1057        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1058    }
1059
1060    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1061        self
1062    }
1063}
1064
1065impl acp_thread::AgentTelemetry for NativeAgentConnection {
1066    fn agent_name(&self) -> String {
1067        "Zed".into()
1068    }
1069
1070    fn thread_data(
1071        &self,
1072        session_id: &acp::SessionId,
1073        cx: &mut App,
1074    ) -> Task<Result<serde_json::Value>> {
1075        let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1076            return Task::ready(Err(anyhow!("Session not found")));
1077        };
1078
1079        let task = session.thread.read(cx).to_db(cx);
1080        cx.background_spawn(async move {
1081            serde_json::to_value(task.await).context("Failed to serialize thread")
1082        })
1083    }
1084}
1085
1086struct NativeAgentSessionTruncate {
1087    thread: Entity<Thread>,
1088    acp_thread: WeakEntity<AcpThread>,
1089}
1090
1091impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1092    fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1093        match self.thread.update(cx, |thread, cx| {
1094            thread.truncate(message_id.clone(), cx)?;
1095            Ok(thread.latest_token_usage())
1096        }) {
1097            Ok(usage) => {
1098                self.acp_thread
1099                    .update(cx, |thread, cx| {
1100                        thread.update_token_usage(usage, cx);
1101                    })
1102                    .ok();
1103                Task::ready(Ok(()))
1104            }
1105            Err(error) => Task::ready(Err(error)),
1106        }
1107    }
1108}
1109
1110struct NativeAgentSessionResume {
1111    connection: NativeAgentConnection,
1112    session_id: acp::SessionId,
1113}
1114
1115impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1116    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1117        self.connection
1118            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1119                thread.update(cx, |thread, cx| thread.resume(cx))
1120            })
1121    }
1122}
1123
1124struct NativeAgentSessionSetTitle {
1125    connection: NativeAgentConnection,
1126    session_id: acp::SessionId,
1127}
1128
1129impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1130    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1131        let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1132            return Task::ready(Err(anyhow!("session not found")));
1133        };
1134        let thread = session.thread.clone();
1135        thread.update(cx, |thread, cx| thread.set_title(title, cx));
1136        Task::ready(Ok(()))
1137    }
1138}
1139
1140pub struct AcpThreadEnvironment {
1141    acp_thread: WeakEntity<AcpThread>,
1142}
1143
1144impl ThreadEnvironment for AcpThreadEnvironment {
1145    fn create_terminal(
1146        &self,
1147        command: String,
1148        cwd: Option<PathBuf>,
1149        output_byte_limit: Option<u64>,
1150        cx: &mut AsyncApp,
1151    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1152        let task = self.acp_thread.update(cx, |thread, cx| {
1153            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1154        });
1155
1156        let acp_thread = self.acp_thread.clone();
1157        cx.spawn(async move |cx| {
1158            let terminal = task?.await?;
1159
1160            let (drop_tx, drop_rx) = oneshot::channel();
1161            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1162
1163            cx.spawn(async move |cx| {
1164                drop_rx.await.ok();
1165                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1166            })
1167            .detach();
1168
1169            let handle = AcpTerminalHandle {
1170                terminal,
1171                _drop_tx: Some(drop_tx),
1172            };
1173
1174            Ok(Rc::new(handle) as _)
1175        })
1176    }
1177}
1178
1179pub struct AcpTerminalHandle {
1180    terminal: Entity<acp_thread::Terminal>,
1181    _drop_tx: Option<oneshot::Sender<()>>,
1182}
1183
1184impl TerminalHandle for AcpTerminalHandle {
1185    fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1186        self.terminal.read_with(cx, |term, _cx| term.id().clone())
1187    }
1188
1189    fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1190        self.terminal
1191            .read_with(cx, |term, _cx| term.wait_for_exit())
1192    }
1193
1194    fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1195        self.terminal
1196            .read_with(cx, |term, cx| term.current_output(cx))
1197    }
1198}
1199
1200#[cfg(test)]
1201mod tests {
1202    use crate::HistoryEntryId;
1203
1204    use super::*;
1205    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1206    use fs::FakeFs;
1207    use gpui::TestAppContext;
1208    use indoc::formatdoc;
1209    use language_model::fake_provider::FakeLanguageModel;
1210    use serde_json::json;
1211    use settings::SettingsStore;
1212    use util::{path, rel_path::rel_path};
1213
1214    #[gpui::test]
1215    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1216        init_test(cx);
1217        let fs = FakeFs::new(cx.executor());
1218        fs.insert_tree(
1219            "/",
1220            json!({
1221                "a": {}
1222            }),
1223        )
1224        .await;
1225        let project = Project::test(fs.clone(), [], cx).await;
1226        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
1227        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
1228        let agent = NativeAgent::new(
1229            project.clone(),
1230            history_store,
1231            Templates::new(),
1232            None,
1233            fs.clone(),
1234            &mut cx.to_async(),
1235        )
1236        .await
1237        .unwrap();
1238        agent.read_with(cx, |agent, cx| {
1239            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1240        });
1241
1242        let worktree = project
1243            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1244            .await
1245            .unwrap();
1246        cx.run_until_parked();
1247        agent.read_with(cx, |agent, cx| {
1248            assert_eq!(
1249                agent.project_context.read(cx).worktrees,
1250                vec![WorktreeContext {
1251                    root_name: "a".into(),
1252                    abs_path: Path::new("/a").into(),
1253                    rules_file: None
1254                }]
1255            )
1256        });
1257
1258        // Creating `/a/.rules` updates the project context.
1259        fs.insert_file("/a/.rules", Vec::new()).await;
1260        cx.run_until_parked();
1261        agent.read_with(cx, |agent, cx| {
1262            let rules_entry = worktree
1263                .read(cx)
1264                .entry_for_path(rel_path(".rules"))
1265                .unwrap();
1266            assert_eq!(
1267                agent.project_context.read(cx).worktrees,
1268                vec![WorktreeContext {
1269                    root_name: "a".into(),
1270                    abs_path: Path::new("/a").into(),
1271                    rules_file: Some(RulesFileContext {
1272                        path_in_worktree: rel_path(".rules").into(),
1273                        text: "".into(),
1274                        project_entry_id: rules_entry.id.to_usize()
1275                    })
1276                }]
1277            )
1278        });
1279    }
1280
1281    #[gpui::test]
1282    async fn test_listing_models(cx: &mut TestAppContext) {
1283        init_test(cx);
1284        let fs = FakeFs::new(cx.executor());
1285        fs.insert_tree("/", json!({ "a": {}  })).await;
1286        let project = Project::test(fs.clone(), [], cx).await;
1287        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
1288        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
1289        let connection = NativeAgentConnection(
1290            NativeAgent::new(
1291                project.clone(),
1292                history_store,
1293                Templates::new(),
1294                None,
1295                fs.clone(),
1296                &mut cx.to_async(),
1297            )
1298            .await
1299            .unwrap(),
1300        );
1301
1302        // Create a thread/session
1303        let acp_thread = cx
1304            .update(|cx| {
1305                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
1306            })
1307            .await
1308            .unwrap();
1309
1310        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1311
1312        let models = cx
1313            .update(|cx| {
1314                connection
1315                    .model_selector(&session_id)
1316                    .unwrap()
1317                    .list_models(cx)
1318            })
1319            .await
1320            .unwrap();
1321
1322        let acp_thread::AgentModelList::Grouped(models) = models else {
1323            panic!("Unexpected model group");
1324        };
1325        assert_eq!(
1326            models,
1327            IndexMap::from_iter([(
1328                AgentModelGroupName("Fake".into()),
1329                vec![AgentModelInfo {
1330                    id: acp::ModelId("fake/fake".into()),
1331                    name: "Fake".into(),
1332                    description: None,
1333                    icon: Some(ui::IconName::ZedAssistant),
1334                }]
1335            )])
1336        );
1337    }
1338
1339    #[gpui::test]
1340    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
1341        init_test(cx);
1342        let fs = FakeFs::new(cx.executor());
1343        fs.create_dir(paths::settings_file().parent().unwrap())
1344            .await
1345            .unwrap();
1346        fs.insert_file(
1347            paths::settings_file(),
1348            json!({
1349                "agent": {
1350                    "default_model": {
1351                        "provider": "foo",
1352                        "model": "bar"
1353                    }
1354                }
1355            })
1356            .to_string()
1357            .into_bytes(),
1358        )
1359        .await;
1360        let project = Project::test(fs.clone(), [], cx).await;
1361
1362        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
1363        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
1364
1365        // Create the agent and connection
1366        let agent = NativeAgent::new(
1367            project.clone(),
1368            history_store,
1369            Templates::new(),
1370            None,
1371            fs.clone(),
1372            &mut cx.to_async(),
1373        )
1374        .await
1375        .unwrap();
1376        let connection = NativeAgentConnection(agent.clone());
1377
1378        // Create a thread/session
1379        let acp_thread = cx
1380            .update(|cx| {
1381                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
1382            })
1383            .await
1384            .unwrap();
1385
1386        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1387
1388        // Select a model
1389        let selector = connection.model_selector(&session_id).unwrap();
1390        let model_id = acp::ModelId("fake/fake".into());
1391        cx.update(|cx| selector.select_model(model_id.clone(), cx))
1392            .await
1393            .unwrap();
1394
1395        // Verify the thread has the selected model
1396        agent.read_with(cx, |agent, _| {
1397            let session = agent.sessions.get(&session_id).unwrap();
1398            session.thread.read_with(cx, |thread, _| {
1399                assert_eq!(thread.model().unwrap().id().0, "fake");
1400            });
1401        });
1402
1403        cx.run_until_parked();
1404
1405        // Verify settings file was updated
1406        let settings_content = fs.load(paths::settings_file()).await.unwrap();
1407        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
1408
1409        // Check that the agent settings contain the selected model
1410        assert_eq!(
1411            settings_json["agent"]["default_model"]["model"],
1412            json!("fake")
1413        );
1414        assert_eq!(
1415            settings_json["agent"]["default_model"]["provider"],
1416            json!("fake")
1417        );
1418    }
1419
1420    #[gpui::test]
1421    #[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
1422    async fn test_save_load_thread(cx: &mut TestAppContext) {
1423        init_test(cx);
1424        let fs = FakeFs::new(cx.executor());
1425        fs.insert_tree(
1426            "/",
1427            json!({
1428                "a": {
1429                    "b.md": "Lorem"
1430                }
1431            }),
1432        )
1433        .await;
1434        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
1435        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
1436        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
1437        let agent = NativeAgent::new(
1438            project.clone(),
1439            history_store.clone(),
1440            Templates::new(),
1441            None,
1442            fs.clone(),
1443            &mut cx.to_async(),
1444        )
1445        .await
1446        .unwrap();
1447        let connection = Rc::new(NativeAgentConnection(agent.clone()));
1448
1449        let acp_thread = cx
1450            .update(|cx| {
1451                connection
1452                    .clone()
1453                    .new_thread(project.clone(), Path::new(""), cx)
1454            })
1455            .await
1456            .unwrap();
1457        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
1458        let thread = agent.read_with(cx, |agent, _| {
1459            agent.sessions.get(&session_id).unwrap().thread.clone()
1460        });
1461
1462        // Ensure empty threads are not saved, even if they get mutated.
1463        let model = Arc::new(FakeLanguageModel::default());
1464        let summary_model = Arc::new(FakeLanguageModel::default());
1465        thread.update(cx, |thread, cx| {
1466            thread.set_model(model.clone(), cx);
1467            thread.set_summarization_model(Some(summary_model.clone()), cx);
1468        });
1469        cx.run_until_parked();
1470        assert_eq!(history_entries(&history_store, cx), vec![]);
1471
1472        let send = acp_thread.update(cx, |thread, cx| {
1473            thread.send(
1474                vec![
1475                    "What does ".into(),
1476                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
1477                        name: "b.md".into(),
1478                        uri: MentionUri::File {
1479                            abs_path: path!("/a/b.md").into(),
1480                        }
1481                        .to_uri()
1482                        .to_string(),
1483                        annotations: None,
1484                        description: None,
1485                        mime_type: None,
1486                        size: None,
1487                        title: None,
1488                        meta: None,
1489                    }),
1490                    " mean?".into(),
1491                ],
1492                cx,
1493            )
1494        });
1495        let send = cx.foreground_executor().spawn(send);
1496        cx.run_until_parked();
1497
1498        model.send_last_completion_stream_text_chunk("Lorem.");
1499        model.end_last_completion_stream();
1500        cx.run_until_parked();
1501        summary_model.send_last_completion_stream_text_chunk("Explaining /a/b.md");
1502        summary_model.end_last_completion_stream();
1503
1504        send.await.unwrap();
1505        let uri = MentionUri::File {
1506            abs_path: path!("/a/b.md").into(),
1507        }
1508        .to_uri();
1509        acp_thread.read_with(cx, |thread, cx| {
1510            assert_eq!(
1511                thread.to_markdown(cx),
1512                formatdoc! {"
1513                    ## User
1514
1515                    What does [@b.md]({uri}) mean?
1516
1517                    ## Assistant
1518
1519                    Lorem.
1520
1521                "}
1522            )
1523        });
1524
1525        cx.run_until_parked();
1526
1527        // Drop the ACP thread, which should cause the session to be dropped as well.
1528        cx.update(|_| {
1529            drop(thread);
1530            drop(acp_thread);
1531        });
1532        agent.read_with(cx, |agent, _| {
1533            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
1534        });
1535
1536        // Ensure the thread can be reloaded from disk.
1537        assert_eq!(
1538            history_entries(&history_store, cx),
1539            vec![(
1540                HistoryEntryId::AcpThread(session_id.clone()),
1541                "Explaining /a/b.md".into()
1542            )]
1543        );
1544        let acp_thread = agent
1545            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
1546            .await
1547            .unwrap();
1548        acp_thread.read_with(cx, |thread, cx| {
1549            assert_eq!(
1550                thread.to_markdown(cx),
1551                formatdoc! {"
1552                    ## User
1553
1554                    What does [@b.md]({uri}) mean?
1555
1556                    ## Assistant
1557
1558                    Lorem.
1559
1560                "}
1561            )
1562        });
1563    }
1564
1565    fn history_entries(
1566        history: &Entity<HistoryStore>,
1567        cx: &mut TestAppContext,
1568    ) -> Vec<(HistoryEntryId, String)> {
1569        history.read_with(cx, |history, _| {
1570            history
1571                .entries()
1572                .map(|e| (e.id(), e.title().to_string()))
1573                .collect::<Vec<_>>()
1574        })
1575    }
1576
1577    fn init_test(cx: &mut TestAppContext) {
1578        env_logger::try_init().ok();
1579        cx.update(|cx| {
1580            let settings_store = SettingsStore::test(cx);
1581            cx.set_global(settings_store);
1582            Project::init_settings(cx);
1583            agent_settings::init(cx);
1584            language::init(cx);
1585            LanguageModelRegistry::test(cx);
1586        });
1587    }
1588}