agent.rs

   1use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME;
   2use crate::{
   3    AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
   4    DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
   5    MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
   6    ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
   7};
   8use crate::{DbThread, ThreadsDatabase};
   9use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
  10use agent_client_protocol as acp;
  11use agent_settings::AgentSettings;
  12use anyhow::{Context as _, Result, anyhow};
  13use collections::{HashSet, IndexMap};
  14use fs::Fs;
  15use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender};
  16use futures::future::Shared;
  17use futures::{SinkExt, StreamExt, future};
  18use gpui::{
  19    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  20};
  21use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry, SelectedModel};
  22use project::{Project, ProjectItem, ProjectPath, Worktree};
  23use prompt_store::{
  24    ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
  25};
  26use settings::update_settings_file;
  27use std::any::Any;
  28use std::cell::RefCell;
  29use std::collections::HashMap;
  30use std::path::Path;
  31use std::rc::Rc;
  32use std::sync::Arc;
  33use util::ResultExt;
  34
  35const RULES_FILE_NAMES: [&'static str; 9] = [
  36    ".rules",
  37    ".cursorrules",
  38    ".windsurfrules",
  39    ".clinerules",
  40    ".github/copilot-instructions.md",
  41    "CLAUDE.md",
  42    "AGENT.md",
  43    "AGENTS.md",
  44    "GEMINI.md",
  45];
  46
  47pub struct RulesLoadingError {
  48    pub message: SharedString,
  49}
  50
  51/// Holds both the internal Thread and the AcpThread for a session
  52struct Session {
  53    /// The internal thread that processes messages
  54    thread: Entity<Thread>,
  55    /// The ACP thread that handles protocol communication
  56    acp_thread: WeakEntity<acp_thread::AcpThread>,
  57    _subscription: Subscription,
  58}
  59
  60pub struct LanguageModels {
  61    /// Access language model by ID
  62    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
  63    /// Cached list for returning language model information
  64    model_list: acp_thread::AgentModelList,
  65    refresh_models_rx: watch::Receiver<()>,
  66    refresh_models_tx: watch::Sender<()>,
  67}
  68
  69impl LanguageModels {
  70    fn new(cx: &App) -> Self {
  71        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  72        let mut this = Self {
  73            models: HashMap::default(),
  74            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  75            refresh_models_rx,
  76            refresh_models_tx,
  77        };
  78        this.refresh_list(cx);
  79        this
  80    }
  81
  82    fn refresh_list(&mut self, cx: &App) {
  83        let providers = LanguageModelRegistry::global(cx)
  84            .read(cx)
  85            .providers()
  86            .into_iter()
  87            .filter(|provider| provider.is_authenticated(cx))
  88            .collect::<Vec<_>>();
  89
  90        let mut language_model_list = IndexMap::default();
  91        let mut recommended_models = HashSet::default();
  92
  93        let mut recommended = Vec::new();
  94        for provider in &providers {
  95            for model in provider.recommended_models(cx) {
  96                recommended_models.insert(model.id());
  97                recommended.push(Self::map_language_model_to_info(&model, &provider));
  98            }
  99        }
 100        if !recommended.is_empty() {
 101            language_model_list.insert(
 102                acp_thread::AgentModelGroupName("Recommended".into()),
 103                recommended,
 104            );
 105        }
 106
 107        let mut models = HashMap::default();
 108        for provider in providers {
 109            let mut provider_models = Vec::new();
 110            for model in provider.provided_models(cx) {
 111                let model_info = Self::map_language_model_to_info(&model, &provider);
 112                let model_id = model_info.id.clone();
 113                if !recommended_models.contains(&model.id()) {
 114                    provider_models.push(model_info);
 115                }
 116                models.insert(model_id, model);
 117            }
 118            if !provider_models.is_empty() {
 119                language_model_list.insert(
 120                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 121                    provider_models,
 122                );
 123            }
 124        }
 125
 126        self.models = models;
 127        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 128        self.refresh_models_tx.send(()).ok();
 129    }
 130
 131    fn watch(&self) -> watch::Receiver<()> {
 132        self.refresh_models_rx.clone()
 133    }
 134
 135    pub fn model_from_id(
 136        &self,
 137        model_id: &acp_thread::AgentModelId,
 138    ) -> Option<Arc<dyn LanguageModel>> {
 139        self.models.get(model_id).cloned()
 140    }
 141
 142    fn map_language_model_to_info(
 143        model: &Arc<dyn LanguageModel>,
 144        provider: &Arc<dyn LanguageModelProvider>,
 145    ) -> acp_thread::AgentModelInfo {
 146        acp_thread::AgentModelInfo {
 147            id: Self::model_id(model),
 148            name: model.name().0,
 149            icon: Some(provider.icon()),
 150        }
 151    }
 152
 153    fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
 154        acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
 155    }
 156}
 157
 158pub struct NativeAgent {
 159    /// Session ID -> Session mapping
 160    sessions: HashMap<acp::SessionId, Session>,
 161    /// Shared project context for all threads
 162    project_context: Rc<RefCell<ProjectContext>>,
 163    project_context_needs_refresh: watch::Sender<()>,
 164    _maintain_project_context: Task<Result<()>>,
 165    context_server_registry: Entity<ContextServerRegistry>,
 166    /// Shared templates for all threads
 167    templates: Arc<Templates>,
 168    /// Cached model information
 169    models: LanguageModels,
 170    project: Entity<Project>,
 171    prompt_store: Option<Entity<PromptStore>>,
 172    thread_database: Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
 173    history_listeners: Vec<UnboundedSender<Vec<AcpThreadMetadata>>>,
 174    fs: Arc<dyn Fs>,
 175    _subscriptions: Vec<Subscription>,
 176}
 177
 178impl NativeAgent {
 179    pub async fn new(
 180        project: Entity<Project>,
 181        templates: Arc<Templates>,
 182        prompt_store: Option<Entity<PromptStore>>,
 183        fs: Arc<dyn Fs>,
 184        cx: &mut AsyncApp,
 185    ) -> Result<Entity<NativeAgent>> {
 186        log::info!("Creating new NativeAgent");
 187
 188        let project_context = cx
 189            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 190            .await;
 191
 192        cx.new(|cx| {
 193            let mut subscriptions = vec![
 194                cx.subscribe(&project, Self::handle_project_event),
 195                cx.subscribe(
 196                    &LanguageModelRegistry::global(cx),
 197                    Self::handle_models_updated_event,
 198                ),
 199            ];
 200            if let Some(prompt_store) = prompt_store.as_ref() {
 201                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 202            }
 203
 204            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 205                watch::channel(());
 206            Self {
 207                sessions: HashMap::new(),
 208                project_context: Rc::new(RefCell::new(project_context)),
 209                project_context_needs_refresh: project_context_needs_refresh_tx,
 210                _maintain_project_context: cx.spawn(async move |this, cx| {
 211                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 212                }),
 213                context_server_registry: cx.new(|cx| {
 214                    ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
 215                }),
 216                thread_database: ThreadsDatabase::connect(cx),
 217                templates,
 218                models: LanguageModels::new(cx),
 219                project,
 220                prompt_store,
 221                fs,
 222                history_listeners: Vec::new(),
 223                _subscriptions: subscriptions,
 224            }
 225        })
 226    }
 227
 228    pub fn models(&self) -> &LanguageModels {
 229        &self.models
 230    }
 231
 232    async fn maintain_project_context(
 233        this: WeakEntity<Self>,
 234        mut needs_refresh: watch::Receiver<()>,
 235        cx: &mut AsyncApp,
 236    ) -> Result<()> {
 237        while needs_refresh.changed().await.is_ok() {
 238            let project_context = this
 239                .update(cx, |this, cx| {
 240                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 241                })?
 242                .await;
 243            this.update(cx, |this, _| this.project_context.replace(project_context))?;
 244        }
 245
 246        Ok(())
 247    }
 248
 249    fn build_project_context(
 250        project: &Entity<Project>,
 251        prompt_store: Option<&Entity<PromptStore>>,
 252        cx: &mut App,
 253    ) -> Task<ProjectContext> {
 254        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 255        let worktree_tasks = worktrees
 256            .into_iter()
 257            .map(|worktree| {
 258                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 259            })
 260            .collect::<Vec<_>>();
 261        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 262            prompt_store.read_with(cx, |prompt_store, cx| {
 263                let prompts = prompt_store.default_prompt_metadata();
 264                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 265                    let contents = prompt_store.load(prompt_metadata.id, cx);
 266                    async move { (contents.await, prompt_metadata) }
 267                });
 268                cx.background_spawn(future::join_all(load_tasks))
 269            })
 270        } else {
 271            Task::ready(vec![])
 272        };
 273
 274        cx.spawn(async move |_cx| {
 275            let (worktrees, default_user_rules) =
 276                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 277
 278            let worktrees = worktrees
 279                .into_iter()
 280                .map(|(worktree, _rules_error)| {
 281                    // TODO: show error message
 282                    // if let Some(rules_error) = rules_error {
 283                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 284                    // }
 285                    worktree
 286                })
 287                .collect::<Vec<_>>();
 288
 289            let default_user_rules = default_user_rules
 290                .into_iter()
 291                .flat_map(|(contents, prompt_metadata)| match contents {
 292                    Ok(contents) => Some(UserRulesContext {
 293                        uuid: match prompt_metadata.id {
 294                            PromptId::User { uuid } => uuid,
 295                            PromptId::EditWorkflow => return None,
 296                        },
 297                        title: prompt_metadata.title.map(|title| title.to_string()),
 298                        contents,
 299                    }),
 300                    Err(_err) => {
 301                        // TODO: show error message
 302                        // this.update(cx, |_, cx| {
 303                        //     cx.emit(RulesLoadingError {
 304                        //         message: format!("{err:?}").into(),
 305                        //     });
 306                        // })
 307                        // .ok();
 308                        None
 309                    }
 310                })
 311                .collect::<Vec<_>>();
 312
 313            ProjectContext::new(worktrees, default_user_rules)
 314        })
 315    }
 316
 317    fn load_worktree_info_for_system_prompt(
 318        worktree: Entity<Worktree>,
 319        project: Entity<Project>,
 320        cx: &mut App,
 321    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 322        let tree = worktree.read(cx);
 323        let root_name = tree.root_name().into();
 324        let abs_path = tree.abs_path();
 325
 326        let mut context = WorktreeContext {
 327            root_name,
 328            abs_path,
 329            rules_file: None,
 330        };
 331
 332        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 333        let Some(rules_task) = rules_task else {
 334            return Task::ready((context, None));
 335        };
 336
 337        cx.spawn(async move |_| {
 338            let (rules_file, rules_file_error) = match rules_task.await {
 339                Ok(rules_file) => (Some(rules_file), None),
 340                Err(err) => (
 341                    None,
 342                    Some(RulesLoadingError {
 343                        message: format!("{err}").into(),
 344                    }),
 345                ),
 346            };
 347            context.rules_file = rules_file;
 348            (context, rules_file_error)
 349        })
 350    }
 351
 352    fn load_worktree_rules_file(
 353        worktree: Entity<Worktree>,
 354        project: Entity<Project>,
 355        cx: &mut App,
 356    ) -> Option<Task<Result<RulesFileContext>>> {
 357        let worktree = worktree.read(cx);
 358        let worktree_id = worktree.id();
 359        let selected_rules_file = RULES_FILE_NAMES
 360            .into_iter()
 361            .filter_map(|name| {
 362                worktree
 363                    .entry_for_path(name)
 364                    .filter(|entry| entry.is_file())
 365                    .map(|entry| entry.path.clone())
 366            })
 367            .next();
 368
 369        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 370        // supported. This doesn't seem to occur often in GitHub repositories.
 371        selected_rules_file.map(|path_in_worktree| {
 372            let project_path = ProjectPath {
 373                worktree_id,
 374                path: path_in_worktree.clone(),
 375            };
 376            let buffer_task =
 377                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 378            let rope_task = cx.spawn(async move |cx| {
 379                buffer_task.await?.read_with(cx, |buffer, cx| {
 380                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 381                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 382                })?
 383            });
 384            // Build a string from the rope on a background thread.
 385            cx.background_spawn(async move {
 386                let (project_entry_id, rope) = rope_task.await?;
 387                anyhow::Ok(RulesFileContext {
 388                    path_in_worktree,
 389                    text: rope.to_string().trim().to_string(),
 390                    project_entry_id: project_entry_id.to_usize(),
 391                })
 392            })
 393        })
 394    }
 395
 396    fn handle_project_event(
 397        &mut self,
 398        _project: Entity<Project>,
 399        event: &project::Event,
 400        _cx: &mut Context<Self>,
 401    ) {
 402        match event {
 403            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 404                self.project_context_needs_refresh.send(()).ok();
 405            }
 406            project::Event::WorktreeUpdatedEntries(_, items) => {
 407                if items.iter().any(|(path, _, _)| {
 408                    RULES_FILE_NAMES
 409                        .iter()
 410                        .any(|name| path.as_ref() == Path::new(name))
 411                }) {
 412                    self.project_context_needs_refresh.send(()).ok();
 413                }
 414            }
 415            _ => {}
 416        }
 417    }
 418
 419    fn handle_prompts_updated_event(
 420        &mut self,
 421        _prompt_store: Entity<PromptStore>,
 422        _event: &prompt_store::PromptsUpdatedEvent,
 423        _cx: &mut Context<Self>,
 424    ) {
 425        self.project_context_needs_refresh.send(()).ok();
 426    }
 427
 428    fn handle_models_updated_event(
 429        &mut self,
 430        _registry: Entity<LanguageModelRegistry>,
 431        _event: &language_model::Event,
 432        cx: &mut Context<Self>,
 433    ) {
 434        self.models.refresh_list(cx);
 435        for session in self.sessions.values_mut() {
 436            session.thread.update(cx, |thread, _| {
 437                let model_id = LanguageModels::model_id(&thread.model());
 438                if let Some(model) = self.models.model_from_id(&model_id) {
 439                    thread.set_model(model.clone());
 440                }
 441            });
 442        }
 443    }
 444}
 445
 446/// Wrapper struct that implements the AgentConnection trait
 447#[derive(Clone)]
 448pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 449
 450impl NativeAgentConnection {
 451    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 452        self.0
 453            .read(cx)
 454            .sessions
 455            .get(session_id)
 456            .map(|session| session.thread.clone())
 457    }
 458
 459    fn run_turn(
 460        &self,
 461        session_id: acp::SessionId,
 462        cx: &mut App,
 463        f: impl 'static
 464        + FnOnce(
 465            Entity<Thread>,
 466            &mut App,
 467        ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
 468    ) -> Task<Result<acp::PromptResponse>> {
 469        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 470            agent
 471                .sessions
 472                .get_mut(&session_id)
 473                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 474        }) else {
 475            return Task::ready(Err(anyhow!("Session not found")));
 476        };
 477        log::debug!("Found session for: {}", session_id);
 478
 479        let mut response_stream = match f(thread, cx) {
 480            Ok(stream) => stream,
 481            Err(err) => return Task::ready(Err(err)),
 482        };
 483        cx.spawn(async move |cx| {
 484            // Handle response stream and forward to session.acp_thread
 485            while let Some(result) = response_stream.next().await {
 486                match result {
 487                    Ok(event) => {
 488                        log::trace!("Received completion event: {:?}", event);
 489
 490                        match event {
 491                            AgentResponseEvent::Text(text) => {
 492                                acp_thread.update(cx, |thread, cx| {
 493                                    thread.push_assistant_content_block(
 494                                        acp::ContentBlock::Text(acp::TextContent {
 495                                            text,
 496                                            annotations: None,
 497                                        }),
 498                                        false,
 499                                        cx,
 500                                    )
 501                                })?;
 502                            }
 503                            AgentResponseEvent::Thinking(text) => {
 504                                acp_thread.update(cx, |thread, cx| {
 505                                    thread.push_assistant_content_block(
 506                                        acp::ContentBlock::Text(acp::TextContent {
 507                                            text,
 508                                            annotations: None,
 509                                        }),
 510                                        true,
 511                                        cx,
 512                                    )
 513                                })?;
 514                            }
 515                            AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
 516                                tool_call,
 517                                options,
 518                                response,
 519                            }) => {
 520                                let recv = acp_thread.update(cx, |thread, cx| {
 521                                    thread.request_tool_call_authorization(tool_call, options, cx)
 522                                })?;
 523                                cx.background_spawn(async move {
 524                                    if let Some(recv) = recv.log_err()
 525                                        && let Some(option) = recv
 526                                            .await
 527                                            .context("authorization sender was dropped")
 528                                            .log_err()
 529                                    {
 530                                        response
 531                                            .send(option)
 532                                            .map(|_| anyhow!("authorization receiver was dropped"))
 533                                            .log_err();
 534                                    }
 535                                })
 536                                .detach();
 537                            }
 538                            AgentResponseEvent::ToolCall(tool_call) => {
 539                                acp_thread.update(cx, |thread, cx| {
 540                                    thread.upsert_tool_call(tool_call, cx)
 541                                })??;
 542                            }
 543                            AgentResponseEvent::ToolCallUpdate(update) => {
 544                                acp_thread.update(cx, |thread, cx| {
 545                                    thread.update_tool_call(update, cx)
 546                                })??;
 547                            }
 548                            AgentResponseEvent::Stop(stop_reason) => {
 549                                log::debug!("Assistant message complete: {:?}", stop_reason);
 550                                return Ok(acp::PromptResponse { stop_reason });
 551                            }
 552                        }
 553                    }
 554                    Err(e) => {
 555                        log::error!("Error in model response stream: {:?}", e);
 556                        return Err(e);
 557                    }
 558                }
 559            }
 560
 561            log::info!("Response stream completed");
 562            anyhow::Ok(acp::PromptResponse {
 563                stop_reason: acp::StopReason::EndTurn,
 564            })
 565        })
 566    }
 567
 568    fn register_tools(
 569        thread: &mut Thread,
 570        project: Entity<Project>,
 571        action_log: Entity<action_log::ActionLog>,
 572        cx: &mut Context<Thread>,
 573    ) {
 574        thread.add_tool(CopyPathTool::new(project.clone()));
 575        thread.add_tool(CreateDirectoryTool::new(project.clone()));
 576        thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
 577        thread.add_tool(DiagnosticsTool::new(project.clone()));
 578        thread.add_tool(EditFileTool::new(cx.weak_entity()));
 579        thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
 580        thread.add_tool(FindPathTool::new(project.clone()));
 581        thread.add_tool(GrepTool::new(project.clone()));
 582        thread.add_tool(ListDirectoryTool::new(project.clone()));
 583        thread.add_tool(MovePathTool::new(project.clone()));
 584        thread.add_tool(NowTool);
 585        thread.add_tool(OpenTool::new(project.clone()));
 586        thread.add_tool(ReadFileTool::new(project.clone(), action_log));
 587        thread.add_tool(TerminalTool::new(project.clone(), cx));
 588        thread.add_tool(ThinkingTool);
 589        thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
 590    }
 591}
 592
 593impl AgentModelSelector for NativeAgentConnection {
 594    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
 595        log::debug!("NativeAgentConnection::list_models called");
 596        let list = self.0.read(cx).models.model_list.clone();
 597        Task::ready(if list.is_empty() {
 598            Err(anyhow::anyhow!("No models available"))
 599        } else {
 600            Ok(list)
 601        })
 602    }
 603
 604    fn select_model(
 605        &self,
 606        session_id: acp::SessionId,
 607        model_id: acp_thread::AgentModelId,
 608        cx: &mut App,
 609    ) -> Task<Result<()>> {
 610        log::info!("Setting model for session {}: {}", session_id, model_id);
 611        let Some(thread) = self
 612            .0
 613            .read(cx)
 614            .sessions
 615            .get(&session_id)
 616            .map(|session| session.thread.clone())
 617        else {
 618            return Task::ready(Err(anyhow!("Session not found")));
 619        };
 620
 621        let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
 622            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
 623        };
 624
 625        thread.update(cx, |thread, _cx| {
 626            thread.set_model(model.clone());
 627        });
 628
 629        update_settings_file::<AgentSettings>(
 630            self.0.read(cx).fs.clone(),
 631            cx,
 632            move |settings, _cx| {
 633                settings.set_model(model);
 634            },
 635        );
 636
 637        Task::ready(Ok(()))
 638    }
 639
 640    fn selected_model(
 641        &self,
 642        session_id: &acp::SessionId,
 643        cx: &mut App,
 644    ) -> Task<Result<acp_thread::AgentModelInfo>> {
 645        let session_id = session_id.clone();
 646
 647        let Some(thread) = self
 648            .0
 649            .read(cx)
 650            .sessions
 651            .get(&session_id)
 652            .map(|session| session.thread.clone())
 653        else {
 654            return Task::ready(Err(anyhow!("Session not found")));
 655        };
 656        let model = thread.read(cx).model().clone();
 657        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
 658        else {
 659            return Task::ready(Err(anyhow!("Provider not found")));
 660        };
 661        Task::ready(Ok(LanguageModels::map_language_model_to_info(
 662            &model, &provider,
 663        )))
 664    }
 665
 666    fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
 667        self.0.read(cx).models.watch()
 668    }
 669}
 670
 671impl acp_thread::AgentConnection for NativeAgentConnection {
 672    fn new_thread(
 673        self: Rc<Self>,
 674        project: Entity<Project>,
 675        cwd: &Path,
 676        cx: &mut App,
 677    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
 678        let agent = self.0.clone();
 679        log::info!("Creating new thread for project at: {:?}", cwd);
 680
 681        cx.spawn(async move |cx| {
 682            log::debug!("Starting thread creation in async context");
 683
 684            // Generate session ID
 685            let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
 686            log::info!("Created session with ID: {}", session_id);
 687
 688            // Create AcpThread
 689            let acp_thread = cx.update(|cx| {
 690                cx.new(|cx| {
 691                    acp_thread::AcpThread::new(
 692                        "agent2",
 693                        self.clone(),
 694                        project.clone(),
 695                        session_id.clone(),
 696                        cx,
 697                    )
 698                })
 699            })?;
 700            let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
 701
 702            // Create Thread
 703            let thread = agent.update(
 704                cx,
 705                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
 706                    // Fetch default model from registry settings
 707                    let registry = LanguageModelRegistry::read_global(cx);
 708
 709                    // Log available models for debugging
 710                    let available_count = registry.available_models(cx).count();
 711                    log::debug!("Total available models: {}", available_count);
 712
 713                    let default_model = registry
 714                        .default_model()
 715                        .and_then(|default_model| {
 716                            agent
 717                                .models
 718                                .model_from_id(&LanguageModels::model_id(&default_model.model))
 719                        })
 720                        .ok_or_else(|| {
 721                            log::warn!("No default model configured in settings");
 722                            anyhow!(
 723                                "No default model. Please configure a default model in settings."
 724                            )
 725                        })?;
 726
 727                    let thread = cx.new(|cx| {
 728                        let mut thread = Thread::new(
 729                            project.clone(),
 730                            agent.project_context.clone(),
 731                            agent.context_server_registry.clone(),
 732                            action_log.clone(),
 733                            agent.templates.clone(),
 734                            default_model,
 735                            cx,
 736                        );
 737                        Self::register_tools(&mut thread, project, action_log, cx);
 738                        thread
 739                    });
 740
 741                    Ok(thread)
 742                },
 743            )??;
 744
 745            // Store the session
 746            agent.update(cx, |agent, cx| {
 747                agent.sessions.insert(
 748                    session_id,
 749                    Session {
 750                        thread,
 751                        acp_thread: acp_thread.downgrade(),
 752                        _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
 753                            this.sessions.remove(acp_thread.session_id());
 754                        }),
 755                    },
 756                );
 757            })?;
 758
 759            Ok(acp_thread)
 760        })
 761    }
 762
 763    fn auth_methods(&self) -> &[acp::AuthMethod] {
 764        &[] // No auth for in-process
 765    }
 766
 767    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 768        Task::ready(Ok(()))
 769    }
 770
 771    fn list_threads(&self, cx: &mut App) -> Option<UnboundedReceiver<Vec<AcpThreadMetadata>>> {
 772        let (mut tx, rx) = futures::channel::mpsc::unbounded();
 773        let database = self.0.update(cx, |this, _| {
 774            this.history_listeners.push(tx.clone());
 775            this.thread_database.clone()
 776        });
 777        cx.background_executor()
 778            .spawn(async move {
 779                dbg!("listing!");
 780                let database = database.await.map_err(|e| anyhow!(e))?;
 781                let results = database.list_threads().await?;
 782
 783                dbg!(&results);
 784                tx.send(
 785                    results
 786                        .into_iter()
 787                        .map(|thread| AcpThreadMetadata {
 788                            agent: NATIVE_AGENT_SERVER_NAME.clone(),
 789                            id: thread.id,
 790                            title: thread.title,
 791                            updated_at: thread.updated_at,
 792                        })
 793                        .collect(),
 794                )
 795                .await?;
 796                anyhow::Ok(())
 797            })
 798            .detach_and_log_err(cx);
 799        Some(rx)
 800    }
 801
    /// Re-opens a previously saved native-agent thread as a new session.
    ///
    /// Steps:
    /// 1. Read the `DbThread` for `session_id` from the thread database.
    /// 2. Create an `AcpThread` that uses the stored title.
    /// 3. Build a fresh `Thread`. It uses the stored model when the registry can
    ///    still resolve it and falls back to the registry's default model otherwise.
    /// 4. Register the session and drop it again when the `AcpThread` is released.
    ///
    /// # Errors
    /// Fails if the database is unavailable, no thread is stored under
    /// `session_id`, no usable model can be resolved, or an entity update fails
    /// (for example, because the app or agent has been dropped).
    fn load_thread(
        self: Rc<Self>,
        project: Entity<Project>,
        _cwd: &Path,
        session_id: acp::SessionId,
        cx: &mut App,
    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
        let database = self.0.update(cx, |this, _| this.thread_database.clone());
        cx.spawn(async move |cx| {
            let database = database.await.map_err(|e| anyhow!(e))?;
            // Outer `?`: database error; `.context(..)?`: the thread does not exist.
            let db_thread = database
                .load_thread(session_id.clone())
                .await?
                .context("no such thread found")?;

            let acp_thread = cx.update(|cx| {
                cx.new(|cx| {
                    acp_thread::AcpThread::new(
                        db_thread.title,
                        self.clone(),
                        project.clone(),
                        session_id.clone(),
                        cx,
                    )
                })
            })?;
            let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
            let agent = self.0.clone();

            // Create Thread
            let thread = agent.update(cx, |agent, cx| {
                // Prefer the model saved with the thread. If it is missing or the
                // registry cannot resolve it, fall back to the registry default.
                let configured_model = LanguageModelRegistry::global(cx)
                    .update(cx, |registry, cx| {
                        db_thread
                            .model
                            .and_then(|model| {
                                let model = SelectedModel {
                                    provider: model.provider.clone().into(),
                                    model: model.model.clone().into(),
                                };
                                registry.select_model(&model, cx)
                            })
                            .or_else(|| registry.default_model())
                    })
                    .context("no default model configured")?;

                // Map the configured model onto the agent's own model list.
                let model = agent
                    .models
                    .model_from_id(&LanguageModels::model_id(&configured_model.model))
                    .context("no model by id")?;

                let thread = cx.new(|cx| {
                    let mut thread = Thread::new(
                        project.clone(),
                        agent.project_context.clone(),
                        agent.context_server_registry.clone(),
                        action_log.clone(),
                        agent.templates.clone(),
                        model,
                        cx,
                    );
                    Self::register_tools(&mut thread, project, action_log, cx);
                    thread
                });

                anyhow::Ok(thread)
                // First `?`: the agent entity update failed; second `?`: the closure's own error.
            })??;

            // Store the session
            agent.update(cx, |agent, cx| {
                agent.sessions.insert(
                    session_id,
                    Session {
                        thread,
                        acp_thread: acp_thread.downgrade(),
                        // Remove the session once its AcpThread entity is released.
                        _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
                            this.sessions.remove(acp_thread.session_id());
                        }),
                    },
                );
            })?;

            // TODO: restore the stored conversation from `db_thread`. Only its
            // `title` and `model` are used above; the `Thread` is built fresh with
            // `Thread::new` rather than from the persisted data.

            Ok(acp_thread)
        })
    }
 890
 891    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 892        Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
 893    }
 894
 895    fn prompt(
 896        &self,
 897        id: Option<acp_thread::UserMessageId>,
 898        params: acp::PromptRequest,
 899        cx: &mut App,
 900    ) -> Task<Result<acp::PromptResponse>> {
 901        let id = id.expect("UserMessageId is required");
 902        let session_id = params.session_id.clone();
 903        log::info!("Received prompt request for session: {}", session_id);
 904        log::debug!("Prompt blocks count: {}", params.prompt.len());
 905
 906        self.run_turn(session_id, cx, |thread, cx| {
 907            let content: Vec<UserMessageContent> = params
 908                .prompt
 909                .into_iter()
 910                .map(Into::into)
 911                .collect::<Vec<_>>();
 912            log::info!("Converted prompt to message: {} chars", content.len());
 913            log::debug!("Message id: {:?}", id);
 914            log::debug!("Message content: {:?}", content);
 915
 916            Ok(thread.update(cx, |thread, cx| {
 917                log::info!(
 918                    "Sending message to thread with model: {:?}",
 919                    thread.model().name()
 920                );
 921                thread.send(id, content, cx)
 922            }))
 923        })
 924    }
 925
 926    fn resume(
 927        &self,
 928        session_id: &acp::SessionId,
 929        _cx: &mut App,
 930    ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
 931        Some(Rc::new(NativeAgentSessionResume {
 932            connection: self.clone(),
 933            session_id: session_id.clone(),
 934        }) as _)
 935    }
 936
 937    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
 938        log::info!("Cancelling on session: {}", session_id);
 939        self.0.update(cx, |agent, cx| {
 940            if let Some(agent) = agent.sessions.get(session_id) {
 941                agent.thread.update(cx, |thread, _cx| thread.cancel());
 942            }
 943        });
 944    }
 945
 946    fn session_editor(
 947        &self,
 948        session_id: &agent_client_protocol::SessionId,
 949        cx: &mut App,
 950    ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
 951        self.0.update(cx, |agent, _cx| {
 952            agent
 953                .sessions
 954                .get(session_id)
 955                .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
 956        })
 957    }
 958
 959    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 960        self
 961    }
 962}
 963
 964struct NativeAgentSessionEditor(Entity<Thread>);
 965
 966impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
 967    fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
 968        Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
 969    }
 970}
 971
 972struct NativeAgentSessionResume {
 973    connection: NativeAgentConnection,
 974    session_id: acp::SessionId,
 975}
 976
 977impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
 978    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
 979        self.connection
 980            .run_turn(self.session_id.clone(), cx, |thread, cx| {
 981                thread.update(cx, |thread, cx| thread.resume(cx))
 982            })
 983    }
 984}
 985
#[cfg(test)]
mod tests {
    // Tests for the native agent: project-context tracking, model listing, and
    // persisting model selection to the settings file.
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// Checks that the agent's project context follows the project's worktrees.
    /// The context starts empty, gains an entry when `/a` is added, and picks up
    /// a `.rules` file created inside that worktree afterwards.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // No worktrees yet, so the context has no worktree entries.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        // The new worktree appears, with no rules file yet.
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// Checks that `list_models` returns grouped models. With the test language
    /// model registry this is a single "Fake" group containing `fake/fake`.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Checks that selecting a model for a session switches that session's
    /// thread to the model. It also checks that the choice is written to the
    /// settings file's `agent.default_model`, replacing the preexisting
    /// `foo`/`bar` value.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().id().0, "fake");
            });
        });

        // Let the settings-file update finish before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// Installs the global state these tests rely on. That covers a test
    /// settings store, project and agent settings, language support, and the
    /// test language model registry.
    fn init_test(cx: &mut TestAppContext) {
        // Ignore the error if a logger was already installed by another test.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}