agent.rs

   1use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME;
   2use crate::{
   3    ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
   4    EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
   5    OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
   6    UserMessageContent, WebSearchTool, templates::Templates,
   7};
   8use crate::{DbThread, ThreadId, ThreadsDatabase};
   9use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
  10use agent_client_protocol as acp;
  11use agent_settings::AgentSettings;
  12use anyhow::{Context as _, Result, anyhow};
  13use collections::{HashSet, IndexMap};
  14use fs::Fs;
  15use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender};
  16use futures::future::Shared;
  17use futures::{SinkExt, StreamExt, future};
  18use gpui::{
  19    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  20};
  21use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry, SelectedModel};
  22use project::{Project, ProjectItem, ProjectPath, Worktree};
  23use prompt_store::{
  24    ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
  25};
  26use settings::update_settings_file;
  27use std::any::Any;
  28use std::cell::RefCell;
  29use std::collections::HashMap;
  30use std::path::Path;
  31use std::rc::Rc;
  32use std::sync::Arc;
  33use util::ResultExt;
  34
  35const RULES_FILE_NAMES: [&'static str; 9] = [
  36    ".rules",
  37    ".cursorrules",
  38    ".windsurfrules",
  39    ".clinerules",
  40    ".github/copilot-instructions.md",
  41    "CLAUDE.md",
  42    "AGENT.md",
  43    "AGENTS.md",
  44    "GEMINI.md",
  45];
  46
  47pub struct RulesLoadingError {
  48    pub message: SharedString,
  49}
  50
  51/// Holds both the internal Thread and the AcpThread for a session
  52struct Session {
  53    /// The internal thread that processes messages
  54    thread: Entity<Thread>,
  55    /// The ACP thread that handles protocol communication
  56    acp_thread: WeakEntity<acp_thread::AcpThread>,
  57    _subscription: Subscription,
  58}
  59
  60pub struct LanguageModels {
  61    /// Access language model by ID
  62    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
  63    /// Cached list for returning language model information
  64    model_list: acp_thread::AgentModelList,
  65    refresh_models_rx: watch::Receiver<()>,
  66    refresh_models_tx: watch::Sender<()>,
  67}
  68
  69impl LanguageModels {
  70    fn new(cx: &App) -> Self {
  71        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  72        let mut this = Self {
  73            models: HashMap::default(),
  74            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  75            refresh_models_rx,
  76            refresh_models_tx,
  77        };
  78        this.refresh_list(cx);
  79        this
  80    }
  81
  82    fn refresh_list(&mut self, cx: &App) {
  83        let providers = LanguageModelRegistry::global(cx)
  84            .read(cx)
  85            .providers()
  86            .into_iter()
  87            .filter(|provider| provider.is_authenticated(cx))
  88            .collect::<Vec<_>>();
  89
  90        let mut language_model_list = IndexMap::default();
  91        let mut recommended_models = HashSet::default();
  92
  93        let mut recommended = Vec::new();
  94        for provider in &providers {
  95            for model in provider.recommended_models(cx) {
  96                recommended_models.insert(model.id());
  97                recommended.push(Self::map_language_model_to_info(&model, &provider));
  98            }
  99        }
 100        if !recommended.is_empty() {
 101            language_model_list.insert(
 102                acp_thread::AgentModelGroupName("Recommended".into()),
 103                recommended,
 104            );
 105        }
 106
 107        let mut models = HashMap::default();
 108        for provider in providers {
 109            let mut provider_models = Vec::new();
 110            for model in provider.provided_models(cx) {
 111                let model_info = Self::map_language_model_to_info(&model, &provider);
 112                let model_id = model_info.id.clone();
 113                if !recommended_models.contains(&model.id()) {
 114                    provider_models.push(model_info);
 115                }
 116                models.insert(model_id, model);
 117            }
 118            if !provider_models.is_empty() {
 119                language_model_list.insert(
 120                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 121                    provider_models,
 122                );
 123            }
 124        }
 125
 126        self.models = models;
 127        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 128        self.refresh_models_tx.send(()).ok();
 129    }
 130
 131    fn watch(&self) -> watch::Receiver<()> {
 132        self.refresh_models_rx.clone()
 133    }
 134
 135    pub fn model_from_id(
 136        &self,
 137        model_id: &acp_thread::AgentModelId,
 138    ) -> Option<Arc<dyn LanguageModel>> {
 139        self.models.get(model_id).cloned()
 140    }
 141
 142    fn map_language_model_to_info(
 143        model: &Arc<dyn LanguageModel>,
 144        provider: &Arc<dyn LanguageModelProvider>,
 145    ) -> acp_thread::AgentModelInfo {
 146        acp_thread::AgentModelInfo {
 147            id: Self::model_id(model),
 148            name: model.name().0,
 149            icon: Some(provider.icon()),
 150        }
 151    }
 152
 153    fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
 154        acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
 155    }
 156}
 157
 158pub struct NativeAgent {
 159    /// Session ID -> Session mapping
 160    sessions: HashMap<acp::SessionId, Session>,
 161    /// Shared project context for all threads
 162    project_context: Rc<RefCell<ProjectContext>>,
 163    project_context_needs_refresh: watch::Sender<()>,
 164    _maintain_project_context: Task<Result<()>>,
 165    context_server_registry: Entity<ContextServerRegistry>,
 166    /// Shared templates for all threads
 167    templates: Arc<Templates>,
 168    /// Cached model information
 169    models: LanguageModels,
 170    project: Entity<Project>,
 171    prompt_store: Option<Entity<PromptStore>>,
 172    thread_database: Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
 173    history_listeners: Vec<UnboundedSender<Vec<AcpThreadMetadata>>>,
 174    fs: Arc<dyn Fs>,
 175    _subscriptions: Vec<Subscription>,
 176}
 177
 178impl NativeAgent {
 179    pub async fn new(
 180        project: Entity<Project>,
 181        templates: Arc<Templates>,
 182        prompt_store: Option<Entity<PromptStore>>,
 183        fs: Arc<dyn Fs>,
 184        cx: &mut AsyncApp,
 185    ) -> Result<Entity<NativeAgent>> {
 186        log::info!("Creating new NativeAgent");
 187
 188        let project_context = cx
 189            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 190            .await;
 191
 192        cx.new(|cx| {
 193            let mut subscriptions = vec![
 194                cx.subscribe(&project, Self::handle_project_event),
 195                cx.subscribe(
 196                    &LanguageModelRegistry::global(cx),
 197                    Self::handle_models_updated_event,
 198                ),
 199            ];
 200            if let Some(prompt_store) = prompt_store.as_ref() {
 201                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 202            }
 203
 204            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 205                watch::channel(());
 206            Self {
 207                sessions: HashMap::new(),
 208                project_context: Rc::new(RefCell::new(project_context)),
 209                project_context_needs_refresh: project_context_needs_refresh_tx,
 210                _maintain_project_context: cx.spawn(async move |this, cx| {
 211                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 212                }),
 213                context_server_registry: cx.new(|cx| {
 214                    ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
 215                }),
 216                thread_database: ThreadsDatabase::connect(cx),
 217                templates,
 218                models: LanguageModels::new(cx),
 219                project,
 220                prompt_store,
 221                fs,
 222                history_listeners: Vec::new(),
 223                _subscriptions: subscriptions,
 224            }
 225        })
 226    }
 227
 228    pub fn models(&self) -> &LanguageModels {
 229        &self.models
 230    }
 231
 232    async fn maintain_project_context(
 233        this: WeakEntity<Self>,
 234        mut needs_refresh: watch::Receiver<()>,
 235        cx: &mut AsyncApp,
 236    ) -> Result<()> {
 237        while needs_refresh.changed().await.is_ok() {
 238            let project_context = this
 239                .update(cx, |this, cx| {
 240                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 241                })?
 242                .await;
 243            this.update(cx, |this, _| this.project_context.replace(project_context))?;
 244        }
 245
 246        Ok(())
 247    }
 248
 249    fn build_project_context(
 250        project: &Entity<Project>,
 251        prompt_store: Option<&Entity<PromptStore>>,
 252        cx: &mut App,
 253    ) -> Task<ProjectContext> {
 254        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 255        let worktree_tasks = worktrees
 256            .into_iter()
 257            .map(|worktree| {
 258                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 259            })
 260            .collect::<Vec<_>>();
 261        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 262            prompt_store.read_with(cx, |prompt_store, cx| {
 263                let prompts = prompt_store.default_prompt_metadata();
 264                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 265                    let contents = prompt_store.load(prompt_metadata.id, cx);
 266                    async move { (contents.await, prompt_metadata) }
 267                });
 268                cx.background_spawn(future::join_all(load_tasks))
 269            })
 270        } else {
 271            Task::ready(vec![])
 272        };
 273
 274        cx.spawn(async move |_cx| {
 275            let (worktrees, default_user_rules) =
 276                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 277
 278            let worktrees = worktrees
 279                .into_iter()
 280                .map(|(worktree, _rules_error)| {
 281                    // TODO: show error message
 282                    // if let Some(rules_error) = rules_error {
 283                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 284                    // }
 285                    worktree
 286                })
 287                .collect::<Vec<_>>();
 288
 289            let default_user_rules = default_user_rules
 290                .into_iter()
 291                .flat_map(|(contents, prompt_metadata)| match contents {
 292                    Ok(contents) => Some(UserRulesContext {
 293                        uuid: match prompt_metadata.id {
 294                            PromptId::User { uuid } => uuid,
 295                            PromptId::EditWorkflow => return None,
 296                        },
 297                        title: prompt_metadata.title.map(|title| title.to_string()),
 298                        contents,
 299                    }),
 300                    Err(_err) => {
 301                        // TODO: show error message
 302                        // this.update(cx, |_, cx| {
 303                        //     cx.emit(RulesLoadingError {
 304                        //         message: format!("{err:?}").into(),
 305                        //     });
 306                        // })
 307                        // .ok();
 308                        None
 309                    }
 310                })
 311                .collect::<Vec<_>>();
 312
 313            ProjectContext::new(worktrees, default_user_rules)
 314        })
 315    }
 316
 317    fn load_worktree_info_for_system_prompt(
 318        worktree: Entity<Worktree>,
 319        project: Entity<Project>,
 320        cx: &mut App,
 321    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 322        let tree = worktree.read(cx);
 323        let root_name = tree.root_name().into();
 324        let abs_path = tree.abs_path();
 325
 326        let mut context = WorktreeContext {
 327            root_name,
 328            abs_path,
 329            rules_file: None,
 330        };
 331
 332        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 333        let Some(rules_task) = rules_task else {
 334            return Task::ready((context, None));
 335        };
 336
 337        cx.spawn(async move |_| {
 338            let (rules_file, rules_file_error) = match rules_task.await {
 339                Ok(rules_file) => (Some(rules_file), None),
 340                Err(err) => (
 341                    None,
 342                    Some(RulesLoadingError {
 343                        message: format!("{err}").into(),
 344                    }),
 345                ),
 346            };
 347            context.rules_file = rules_file;
 348            (context, rules_file_error)
 349        })
 350    }
 351
 352    fn load_worktree_rules_file(
 353        worktree: Entity<Worktree>,
 354        project: Entity<Project>,
 355        cx: &mut App,
 356    ) -> Option<Task<Result<RulesFileContext>>> {
 357        let worktree = worktree.read(cx);
 358        let worktree_id = worktree.id();
 359        let selected_rules_file = RULES_FILE_NAMES
 360            .into_iter()
 361            .filter_map(|name| {
 362                worktree
 363                    .entry_for_path(name)
 364                    .filter(|entry| entry.is_file())
 365                    .map(|entry| entry.path.clone())
 366            })
 367            .next();
 368
 369        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 370        // supported. This doesn't seem to occur often in GitHub repositories.
 371        selected_rules_file.map(|path_in_worktree| {
 372            let project_path = ProjectPath {
 373                worktree_id,
 374                path: path_in_worktree.clone(),
 375            };
 376            let buffer_task =
 377                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 378            let rope_task = cx.spawn(async move |cx| {
 379                buffer_task.await?.read_with(cx, |buffer, cx| {
 380                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 381                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 382                })?
 383            });
 384            // Build a string from the rope on a background thread.
 385            cx.background_spawn(async move {
 386                let (project_entry_id, rope) = rope_task.await?;
 387                anyhow::Ok(RulesFileContext {
 388                    path_in_worktree,
 389                    text: rope.to_string().trim().to_string(),
 390                    project_entry_id: project_entry_id.to_usize(),
 391                })
 392            })
 393        })
 394    }
 395
 396    fn handle_project_event(
 397        &mut self,
 398        _project: Entity<Project>,
 399        event: &project::Event,
 400        _cx: &mut Context<Self>,
 401    ) {
 402        match event {
 403            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 404                self.project_context_needs_refresh.send(()).ok();
 405            }
 406            project::Event::WorktreeUpdatedEntries(_, items) => {
 407                if items.iter().any(|(path, _, _)| {
 408                    RULES_FILE_NAMES
 409                        .iter()
 410                        .any(|name| path.as_ref() == Path::new(name))
 411                }) {
 412                    self.project_context_needs_refresh.send(()).ok();
 413                }
 414            }
 415            _ => {}
 416        }
 417    }
 418
 419    fn handle_prompts_updated_event(
 420        &mut self,
 421        _prompt_store: Entity<PromptStore>,
 422        _event: &prompt_store::PromptsUpdatedEvent,
 423        _cx: &mut Context<Self>,
 424    ) {
 425        self.project_context_needs_refresh.send(()).ok();
 426    }
 427
 428    fn handle_models_updated_event(
 429        &mut self,
 430        _registry: Entity<LanguageModelRegistry>,
 431        _event: &language_model::Event,
 432        cx: &mut Context<Self>,
 433    ) {
 434        self.models.refresh_list(cx);
 435        for session in self.sessions.values_mut() {
 436            session.thread.update(cx, |thread, _| {
 437                let model_id = LanguageModels::model_id(&thread.model());
 438                if let Some(model) = self.models.model_from_id(&model_id) {
 439                    thread.set_model(model.clone());
 440                }
 441            });
 442        }
 443    }
 444}
 445
 446/// Wrapper struct that implements the AgentConnection trait
 447#[derive(Clone)]
 448pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 449
 450impl NativeAgentConnection {
 451    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 452        self.0
 453            .read(cx)
 454            .sessions
 455            .get(session_id)
 456            .map(|session| session.thread.clone())
 457    }
 458
 459    fn run_turn(
 460        &self,
 461        session_id: acp::SessionId,
 462        cx: &mut App,
 463        f: impl 'static
 464        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 465    ) -> Task<Result<acp::PromptResponse>> {
 466        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 467            agent
 468                .sessions
 469                .get_mut(&session_id)
 470                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 471        }) else {
 472            return Task::ready(Err(anyhow!("Session not found")));
 473        };
 474        log::debug!("Found session for: {}", session_id);
 475
 476        let response_stream = match f(thread, cx) {
 477            Ok(stream) => stream,
 478            Err(err) => return Task::ready(Err(err)),
 479        };
 480        Self::handle_thread_events(response_stream, acp_thread, cx)
 481    }
 482
 483    fn handle_thread_events(
 484        mut response_stream: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
 485        acp_thread: WeakEntity<AcpThread>,
 486        cx: &mut App,
 487    ) -> Task<Result<acp::PromptResponse>> {
 488        cx.spawn(async move |cx| {
 489            // Handle response stream and forward to session.acp_thread
 490            while let Some(result) = response_stream.next().await {
 491                match result {
 492                    Ok(event) => {
 493                        log::trace!("Received completion event: {:?}", event);
 494
 495                        match event {
 496                            ThreadEvent::UserMessage(message) => {
 497                                acp_thread.update(cx, |thread, cx| {
 498                                    for content in message.content {
 499                                        thread.push_user_content_block(
 500                                            Some(message.id.clone()),
 501                                            content.into(),
 502                                            cx,
 503                                        );
 504                                    }
 505                                })?;
 506                            }
 507                            ThreadEvent::AgentText(text) => {
 508                                acp_thread.update(cx, |thread, cx| {
 509                                    thread.push_assistant_content_block(
 510                                        acp::ContentBlock::Text(acp::TextContent {
 511                                            text,
 512                                            annotations: None,
 513                                        }),
 514                                        false,
 515                                        cx,
 516                                    )
 517                                })?;
 518                            }
 519                            ThreadEvent::AgentThinking(text) => {
 520                                acp_thread.update(cx, |thread, cx| {
 521                                    thread.push_assistant_content_block(
 522                                        acp::ContentBlock::Text(acp::TextContent {
 523                                            text,
 524                                            annotations: None,
 525                                        }),
 526                                        true,
 527                                        cx,
 528                                    )
 529                                })?;
 530                            }
 531                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
 532                                tool_call,
 533                                options,
 534                                response,
 535                            }) => {
 536                                let recv = acp_thread.update(cx, |thread, cx| {
 537                                    thread.request_tool_call_authorization(tool_call, options, cx)
 538                                })?;
 539                                cx.background_spawn(async move {
 540                                    if let Some(recv) = recv.log_err()
 541                                        && let Some(option) = recv
 542                                            .await
 543                                            .context("authorization sender was dropped")
 544                                            .log_err()
 545                                    {
 546                                        response
 547                                            .send(option)
 548                                            .map(|_| anyhow!("authorization receiver was dropped"))
 549                                            .log_err();
 550                                    }
 551                                })
 552                                .detach();
 553                            }
 554                            ThreadEvent::ToolCall(tool_call) => {
 555                                acp_thread.update(cx, |thread, cx| {
 556                                    thread.upsert_tool_call(tool_call, cx)
 557                                })??;
 558                            }
 559                            ThreadEvent::ToolCallUpdate(update) => {
 560                                acp_thread.update(cx, |thread, cx| {
 561                                    thread.update_tool_call(update, cx)
 562                                })??;
 563                            }
 564                            ThreadEvent::Stop(stop_reason) => {
 565                                log::debug!("Assistant message complete: {:?}", stop_reason);
 566                                return Ok(acp::PromptResponse { stop_reason });
 567                            }
 568                        }
 569                    }
 570                    Err(e) => {
 571                        log::error!("Error in model response stream: {:?}", e);
 572                        return Err(e);
 573                    }
 574                }
 575            }
 576
 577            log::info!("Response stream completed");
 578            anyhow::Ok(acp::PromptResponse {
 579                stop_reason: acp::StopReason::EndTurn,
 580            })
 581        })
 582    }
 583
 584    fn register_tools(
 585        thread: &mut Thread,
 586        project: Entity<Project>,
 587        action_log: Entity<action_log::ActionLog>,
 588        cx: &mut Context<Thread>,
 589    ) {
 590        let language_registry = project.read(cx).languages().clone();
 591        thread.add_tool(CopyPathTool::new(project.clone()));
 592        thread.add_tool(CreateDirectoryTool::new(project.clone()));
 593        thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
 594        thread.add_tool(DiagnosticsTool::new(project.clone()));
 595        thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
 596        thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
 597        thread.add_tool(FindPathTool::new(project.clone()));
 598        thread.add_tool(GrepTool::new(project.clone()));
 599        thread.add_tool(ListDirectoryTool::new(project.clone()));
 600        thread.add_tool(MovePathTool::new(project.clone()));
 601        thread.add_tool(NowTool);
 602        thread.add_tool(OpenTool::new(project.clone()));
 603        thread.add_tool(ReadFileTool::new(project.clone(), action_log));
 604        thread.add_tool(TerminalTool::new(project.clone(), cx));
 605        thread.add_tool(ThinkingTool);
 606        thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
 607    }
 608}
 609
 610impl AgentModelSelector for NativeAgentConnection {
 611    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
 612        log::debug!("NativeAgentConnection::list_models called");
 613        let list = self.0.read(cx).models.model_list.clone();
 614        Task::ready(if list.is_empty() {
 615            Err(anyhow::anyhow!("No models available"))
 616        } else {
 617            Ok(list)
 618        })
 619    }
 620
 621    fn select_model(
 622        &self,
 623        session_id: acp::SessionId,
 624        model_id: acp_thread::AgentModelId,
 625        cx: &mut App,
 626    ) -> Task<Result<()>> {
 627        log::info!("Setting model for session {}: {}", session_id, model_id);
 628        let Some(thread) = self
 629            .0
 630            .read(cx)
 631            .sessions
 632            .get(&session_id)
 633            .map(|session| session.thread.clone())
 634        else {
 635            return Task::ready(Err(anyhow!("Session not found")));
 636        };
 637
 638        let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
 639            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
 640        };
 641
 642        thread.update(cx, |thread, _cx| {
 643            thread.set_model(model.clone());
 644        });
 645
 646        update_settings_file::<AgentSettings>(
 647            self.0.read(cx).fs.clone(),
 648            cx,
 649            move |settings, _cx| {
 650                settings.set_model(model);
 651            },
 652        );
 653
 654        Task::ready(Ok(()))
 655    }
 656
 657    fn selected_model(
 658        &self,
 659        session_id: &acp::SessionId,
 660        cx: &mut App,
 661    ) -> Task<Result<acp_thread::AgentModelInfo>> {
 662        let session_id = session_id.clone();
 663
 664        let Some(thread) = self
 665            .0
 666            .read(cx)
 667            .sessions
 668            .get(&session_id)
 669            .map(|session| session.thread.clone())
 670        else {
 671            return Task::ready(Err(anyhow!("Session not found")));
 672        };
 673        let model = thread.read(cx).model().clone();
 674        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
 675        else {
 676            return Task::ready(Err(anyhow!("Provider not found")));
 677        };
 678        Task::ready(Ok(LanguageModels::map_language_model_to_info(
 679            &model, &provider,
 680        )))
 681    }
 682
 683    fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
 684        self.0.read(cx).models.watch()
 685    }
 686}
 687
 688impl acp_thread::AgentConnection for NativeAgentConnection {
 689    fn new_thread(
 690        self: Rc<Self>,
 691        project: Entity<Project>,
 692        cwd: &Path,
 693        cx: &mut App,
 694    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
 695        let agent = self.0.clone();
 696        log::info!("Creating new thread for project at: {:?}", cwd);
 697
 698        cx.spawn(async move |cx| {
 699            log::debug!("Starting thread creation in async context");
 700
 701            // Generate session ID
 702            let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
 703            log::info!("Created session with ID: {}", session_id);
 704
 705            // Create AcpThread
 706            let acp_thread = cx.update(|cx| {
 707                cx.new(|cx| {
 708                    acp_thread::AcpThread::new(
 709                        "agent2",
 710                        self.clone(),
 711                        project.clone(),
 712                        session_id.clone(),
 713                        cx,
 714                    )
 715                })
 716            })?;
 717            let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
 718
 719            // Create Thread
 720            let thread = agent.update(
 721                cx,
 722                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
 723                    // Fetch default model from registry settings
 724                    let registry = LanguageModelRegistry::read_global(cx);
 725
 726                    // Log available models for debugging
 727                    let available_count = registry.available_models(cx).count();
 728                    log::debug!("Total available models: {}", available_count);
 729
 730                    let default_model = registry
 731                        .default_model()
 732                        .and_then(|default_model| {
 733                            agent
 734                                .models
 735                                .model_from_id(&LanguageModels::model_id(&default_model.model))
 736                        })
 737                        .ok_or_else(|| {
 738                            log::warn!("No default model configured in settings");
 739                            anyhow!(
 740                                "No default model. Please configure a default model in settings."
 741                            )
 742                        })?;
 743
 744                    let thread = cx.new(|cx| {
 745                        let mut thread = Thread::new(
 746                            project.clone(),
 747                            agent.project_context.clone(),
 748                            agent.context_server_registry.clone(),
 749                            action_log.clone(),
 750                            agent.templates.clone(),
 751                            default_model,
 752                            cx,
 753                        );
 754                        Self::register_tools(&mut thread, project, action_log, cx);
 755                        thread
 756                    });
 757
 758                    Ok(thread)
 759                },
 760            )??;
 761
 762            // Store the session
 763            agent.update(cx, |agent, cx| {
 764                agent.sessions.insert(
 765                    session_id,
 766                    Session {
 767                        thread,
 768                        acp_thread: acp_thread.downgrade(),
 769                        _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
 770                            this.sessions.remove(acp_thread.session_id());
 771                        }),
 772                    },
 773                );
 774            })?;
 775
 776            Ok(acp_thread)
 777        })
 778    }
 779
 780    fn auth_methods(&self) -> &[acp::AuthMethod] {
 781        &[] // No auth for in-process
 782    }
 783
 784    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 785        Task::ready(Ok(()))
 786    }
 787
 788    fn list_threads(&self, cx: &mut App) -> Option<UnboundedReceiver<Vec<AcpThreadMetadata>>> {
 789        let (mut tx, rx) = futures::channel::mpsc::unbounded();
 790        let database = self.0.update(cx, |this, _| {
 791            this.history_listeners.push(tx.clone());
 792            this.thread_database.clone()
 793        });
 794        cx.background_executor()
 795            .spawn(async move {
 796                dbg!("listing!");
 797                let database = database.await.map_err(|e| anyhow!(e))?;
 798                let results = database.list_threads().await?;
 799
 800                dbg!(&results);
 801                tx.send(
 802                    results
 803                        .into_iter()
 804                        .map(|thread| AcpThreadMetadata {
 805                            agent: NATIVE_AGENT_SERVER_NAME.clone(),
 806                            id: thread.id.into(),
 807                            title: thread.title,
 808                            updated_at: thread.updated_at,
 809                        })
 810                        .collect(),
 811                )
 812                .await?;
 813                anyhow::Ok(())
 814            })
 815            .detach_and_log_err(cx);
 816        Some(rx)
 817    }
 818
 819    fn load_thread(
 820        self: Rc<Self>,
 821        project: Entity<Project>,
 822        _cwd: &Path,
 823        session_id: acp::SessionId,
 824        cx: &mut App,
 825    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
 826        let thread_id = ThreadId::from(session_id.clone());
 827        let database = self.0.update(cx, |this, _| this.thread_database.clone());
 828        cx.spawn(async move |cx| {
 829            let database = database.await.map_err(|e| anyhow!(e))?;
 830            let db_thread = database
 831                .load_thread(thread_id.clone())
 832                .await?
 833                .context("no such thread found")?;
 834
 835            let acp_thread = cx.update(|cx| {
 836                cx.new(|cx| {
 837                    acp_thread::AcpThread::new(
 838                        db_thread.title.clone(),
 839                        self.clone(),
 840                        project.clone(),
 841                        session_id.clone(),
 842                        cx,
 843                    )
 844                })
 845            })?;
 846            let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
 847            let agent = self.0.clone();
 848
 849            // Create Thread
 850            let thread = agent.update(cx, |agent, cx| {
 851                let configured_model = LanguageModelRegistry::global(cx)
 852                    .update(cx, |registry, cx| {
 853                        db_thread
 854                            .model
 855                            .as_ref()
 856                            .and_then(|model| {
 857                                let model = SelectedModel {
 858                                    provider: model.provider.clone().into(),
 859                                    model: model.model.clone().into(),
 860                                };
 861                                registry.select_model(&model, cx)
 862                            })
 863                            .or_else(|| registry.default_model())
 864                    })
 865                    .context("no default model configured")?;
 866
 867                let model = agent
 868                    .models
 869                    .model_from_id(&LanguageModels::model_id(&configured_model.model))
 870                    .context("no model by id")?;
 871
 872                let thread = cx.new(|cx| {
 873                    let mut thread = Thread::from_db(
 874                        thread_id,
 875                        db_thread,
 876                        project.clone(),
 877                        agent.project_context.clone(),
 878                        agent.context_server_registry.clone(),
 879                        action_log.clone(),
 880                        agent.templates.clone(),
 881                        model,
 882                        cx,
 883                    );
 884                    Self::register_tools(&mut thread, project, action_log, cx);
 885                    thread
 886                });
 887
 888                anyhow::Ok(thread)
 889            })??;
 890
 891            // Store the session
 892            agent.update(cx, |agent, cx| {
 893                agent.sessions.insert(
 894                    session_id,
 895                    Session {
 896                        thread: thread.clone(),
 897                        acp_thread: acp_thread.downgrade(),
 898                        _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
 899                            this.sessions.remove(acp_thread.session_id());
 900                        }),
 901                    },
 902                );
 903            })?;
 904
 905            let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
 906            cx.update(|cx| Self::handle_thread_events(events, acp_thread.downgrade(), cx))?
 907                .await?;
 908
 909            Ok(acp_thread)
 910        })
 911    }
 912
 913    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 914        Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
 915    }
 916
 917    fn prompt(
 918        &self,
 919        id: Option<acp_thread::UserMessageId>,
 920        params: acp::PromptRequest,
 921        cx: &mut App,
 922    ) -> Task<Result<acp::PromptResponse>> {
 923        let id = id.expect("UserMessageId is required");
 924        let session_id = params.session_id.clone();
 925        log::info!("Received prompt request for session: {}", session_id);
 926        log::debug!("Prompt blocks count: {}", params.prompt.len());
 927
 928        self.run_turn(session_id, cx, |thread, cx| {
 929            let content: Vec<UserMessageContent> = params
 930                .prompt
 931                .into_iter()
 932                .map(Into::into)
 933                .collect::<Vec<_>>();
 934            log::info!("Converted prompt to message: {} chars", content.len());
 935            log::debug!("Message id: {:?}", id);
 936            log::debug!("Message content: {:?}", content);
 937
 938            Ok(thread.update(cx, |thread, cx| {
 939                log::info!(
 940                    "Sending message to thread with model: {:?}",
 941                    thread.model().name()
 942                );
 943                thread.send(id, content, cx)
 944            }))
 945        })
 946    }
 947
 948    fn resume(
 949        &self,
 950        session_id: &acp::SessionId,
 951        _cx: &mut App,
 952    ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
 953        Some(Rc::new(NativeAgentSessionResume {
 954            connection: self.clone(),
 955            session_id: session_id.clone(),
 956        }) as _)
 957    }
 958
 959    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
 960        log::info!("Cancelling on session: {}", session_id);
 961        self.0.update(cx, |agent, cx| {
 962            if let Some(agent) = agent.sessions.get(session_id) {
 963                agent.thread.update(cx, |thread, _cx| thread.cancel());
 964            }
 965        });
 966    }
 967
 968    fn session_editor(
 969        &self,
 970        session_id: &agent_client_protocol::SessionId,
 971        cx: &mut App,
 972    ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
 973        self.0.update(cx, |agent, _cx| {
 974            agent
 975                .sessions
 976                .get(session_id)
 977                .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
 978        })
 979    }
 980
 981    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 982        self
 983    }
 984}
 985
 986struct NativeAgentSessionEditor(Entity<Thread>);
 987
 988impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
 989    fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
 990        Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
 991    }
 992}
 993
 994struct NativeAgentSessionResume {
 995    connection: NativeAgentConnection,
 996    session_id: acp::SessionId,
 997}
 998
 999impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1000    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1001        self.connection
1002            .run_turn(self.session_id.clone(), cx, |thread, cx| {
1003                thread.update(cx, |thread, cx| thread.resume(cx))
1004            })
1005    }
1006}
1007
#[cfg(test)]
mod tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// Installs the globals each test needs: logging, a test settings store,
    /// project and agent settings, languages, and a test model registry.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }

    /// Creates a `NativeAgent` for `project` with fresh templates and no
    /// prompt store, panicking if construction fails.
    async fn build_agent(
        project: &Entity<Project>,
        fs: &Arc<FakeFs>,
        cx: &mut TestAppContext,
    ) -> Entity<NativeAgent> {
        NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap()
    }

    /// The agent's project context follows worktree and rules-file changes.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = build_agent(&project, &fs, cx).await;

        // A project without worktrees yields an empty context.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        // Adding a worktree makes it show up, without a rules file.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        let without_rules = vec![WorktreeContext {
            root_name: "a".into(),
            abs_path: Path::new("/a").into(),
            rules_file: None,
        }];
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, without_rules)
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        let rules_entry_id = worktree.read_with(cx, |worktree, _| {
            worktree.entry_for_path(".rules").unwrap().id.to_usize()
        });
        let with_rules = vec![WorktreeContext {
            root_name: "a".into(),
            abs_path: Path::new("/a").into(),
            rules_file: Some(RulesFileContext {
                path_in_worktree: Path::new(".rules").into(),
                text: "".into(),
                project_entry_id: rules_entry_id,
            }),
        }];
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, with_rules)
        });
    }

    /// The connection lists the fake model, grouped under its provider.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(build_agent(&project, &fs, cx).await);

        let listed = cx.update(|cx| connection.list_models(cx)).await.unwrap();
        let grouped = match listed {
            acp_thread::AgentModelList::Grouped(grouped) => grouped,
            _ => panic!("Unexpected model group"),
        };

        let fake_model = AgentModelInfo {
            id: AgentModelId("fake/fake".into()),
            name: "Fake".into(),
            icon: Some(ui::IconName::ZedAssistant),
        };
        let expected = IndexMap::from_iter([(
            AgentModelGroupName("Fake".into()),
            vec![fake_model],
        )]);
        assert_eq!(grouped, expected);
    }

    /// Selecting a model updates the thread and is written to the settings
    /// file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        let settings_path = paths::settings_file();
        fs.create_dir(settings_path.parent().unwrap())
            .await
            .unwrap();
        let initial_settings = json!({
            "agent": {
                "default_model": {
                    "provider": "foo",
                    "model": "bar"
                }
            }
        })
        .to_string()
        .into_bytes();
        fs.insert_file(settings_path, initial_settings).await;
        let project = Project::test(fs.clone(), [], cx).await;

        // Create the agent and connection
        let agent = build_agent(&project, &fs, cx).await;
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().id().0, "fake");
            });
        });

        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(settings_path).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        let default_model = &settings_json["agent"]["default_model"];
        assert_eq!(default_model["model"], json!("fake"));
        assert_eq!(default_model["provider"], json!("fake"));
    }
}