agent.rs

   1use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME;
   2use crate::{
   3    ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
   4    EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
   5    OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
   6    UserMessageContent, WebSearchTool, templates::Templates,
   7};
   8use crate::{DbThread, ThreadsDatabase};
   9use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
  10use agent_client_protocol as acp;
  11use agent_settings::AgentSettings;
  12use anyhow::{Context as _, Result, anyhow};
  13use collections::{HashSet, IndexMap};
  14use fs::Fs;
  15use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender};
  16use futures::future::Shared;
  17use futures::{SinkExt, StreamExt, future};
  18use gpui::{
  19    App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
  20};
  21use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry, SelectedModel};
  22use project::{Project, ProjectItem, ProjectPath, Worktree};
  23use prompt_store::{
  24    ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
  25};
  26use settings::update_settings_file;
  27use std::any::Any;
  28use std::cell::RefCell;
  29use std::collections::HashMap;
  30use std::path::Path;
  31use std::rc::Rc;
  32use std::sync::Arc;
  33use util::ResultExt;
  34
  35const RULES_FILE_NAMES: [&'static str; 9] = [
  36    ".rules",
  37    ".cursorrules",
  38    ".windsurfrules",
  39    ".clinerules",
  40    ".github/copilot-instructions.md",
  41    "CLAUDE.md",
  42    "AGENT.md",
  43    "AGENTS.md",
  44    "GEMINI.md",
  45];
  46
/// Describes why a worktree's rules file could not be loaded.
pub struct RulesLoadingError {
    /// Human-readable description, formatted from the underlying load error.
    pub message: SharedString,
}
  50
/// Holds both the internal Thread and the AcpThread for a session.
struct Session {
    /// The internal thread that processes messages.
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    /// This is a weak handle, so the session itself does not keep the `AcpThread` alive.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Keeps the release observer registered; when the `AcpThread` is released the
    /// observer removes this session from the agent's session map.
    _subscription: Subscription,
}
  59
/// Cache of the language models offered by authenticated providers.
///
/// Models are keyed by an `provider_id/model_id` string, and a grouped list is kept
/// ready for display.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiving end of the refresh notification; cloned for each watcher.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled every time the cache is rebuilt.
    refresh_models_tx: watch::Sender<()>,
}
  68
  69impl LanguageModels {
  70    fn new(cx: &App) -> Self {
  71        let (refresh_models_tx, refresh_models_rx) = watch::channel(());
  72        let mut this = Self {
  73            models: HashMap::default(),
  74            model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
  75            refresh_models_rx,
  76            refresh_models_tx,
  77        };
  78        this.refresh_list(cx);
  79        this
  80    }
  81
  82    fn refresh_list(&mut self, cx: &App) {
  83        let providers = LanguageModelRegistry::global(cx)
  84            .read(cx)
  85            .providers()
  86            .into_iter()
  87            .filter(|provider| provider.is_authenticated(cx))
  88            .collect::<Vec<_>>();
  89
  90        let mut language_model_list = IndexMap::default();
  91        let mut recommended_models = HashSet::default();
  92
  93        let mut recommended = Vec::new();
  94        for provider in &providers {
  95            for model in provider.recommended_models(cx) {
  96                recommended_models.insert(model.id());
  97                recommended.push(Self::map_language_model_to_info(&model, &provider));
  98            }
  99        }
 100        if !recommended.is_empty() {
 101            language_model_list.insert(
 102                acp_thread::AgentModelGroupName("Recommended".into()),
 103                recommended,
 104            );
 105        }
 106
 107        let mut models = HashMap::default();
 108        for provider in providers {
 109            let mut provider_models = Vec::new();
 110            for model in provider.provided_models(cx) {
 111                let model_info = Self::map_language_model_to_info(&model, &provider);
 112                let model_id = model_info.id.clone();
 113                if !recommended_models.contains(&model.id()) {
 114                    provider_models.push(model_info);
 115                }
 116                models.insert(model_id, model);
 117            }
 118            if !provider_models.is_empty() {
 119                language_model_list.insert(
 120                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
 121                    provider_models,
 122                );
 123            }
 124        }
 125
 126        self.models = models;
 127        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
 128        self.refresh_models_tx.send(()).ok();
 129    }
 130
 131    fn watch(&self) -> watch::Receiver<()> {
 132        self.refresh_models_rx.clone()
 133    }
 134
 135    pub fn model_from_id(
 136        &self,
 137        model_id: &acp_thread::AgentModelId,
 138    ) -> Option<Arc<dyn LanguageModel>> {
 139        self.models.get(model_id).cloned()
 140    }
 141
 142    fn map_language_model_to_info(
 143        model: &Arc<dyn LanguageModel>,
 144        provider: &Arc<dyn LanguageModelProvider>,
 145    ) -> acp_thread::AgentModelInfo {
 146        acp_thread::AgentModelInfo {
 147            id: Self::model_id(model),
 148            name: model.name().0,
 149            icon: Some(provider.icon()),
 150        }
 151    }
 152
 153    fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
 154        acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
 155    }
 156}
 157
/// In-process agent that owns the native [`Thread`] sessions for a project.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Shared project context for all threads
    project_context: Rc<RefCell<ProjectContext>>,
    /// Signalled whenever `project_context` should be rebuilt.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task that rebuilds `project_context` on each refresh signal.
    _maintain_project_context: Task<Result<()>>,
    /// Context-server registry handed to every new thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent works on.
    project: Entity<Project>,
    /// Optional source of the user's default rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Shared handle to the thread database, which is connected asynchronously.
    thread_database: Shared<Task<Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
    /// Senders registered through `list_threads`.
    history_listeners: Vec<UnboundedSender<Vec<AcpThreadMetadata>>>,
    /// Filesystem handle, used when persisting the selected model to settings.
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, model-registry and prompt-store events.
    _subscriptions: Vec<Subscription>,
}
 177
 178impl NativeAgent {
 179    pub async fn new(
 180        project: Entity<Project>,
 181        templates: Arc<Templates>,
 182        prompt_store: Option<Entity<PromptStore>>,
 183        fs: Arc<dyn Fs>,
 184        cx: &mut AsyncApp,
 185    ) -> Result<Entity<NativeAgent>> {
 186        log::info!("Creating new NativeAgent");
 187
 188        let project_context = cx
 189            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
 190            .await;
 191
 192        cx.new(|cx| {
 193            let mut subscriptions = vec![
 194                cx.subscribe(&project, Self::handle_project_event),
 195                cx.subscribe(
 196                    &LanguageModelRegistry::global(cx),
 197                    Self::handle_models_updated_event,
 198                ),
 199            ];
 200            if let Some(prompt_store) = prompt_store.as_ref() {
 201                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
 202            }
 203
 204            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
 205                watch::channel(());
 206            Self {
 207                sessions: HashMap::new(),
 208                project_context: Rc::new(RefCell::new(project_context)),
 209                project_context_needs_refresh: project_context_needs_refresh_tx,
 210                _maintain_project_context: cx.spawn(async move |this, cx| {
 211                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
 212                }),
 213                context_server_registry: cx.new(|cx| {
 214                    ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
 215                }),
 216                thread_database: ThreadsDatabase::connect(cx),
 217                templates,
 218                models: LanguageModels::new(cx),
 219                project,
 220                prompt_store,
 221                fs,
 222                history_listeners: Vec::new(),
 223                _subscriptions: subscriptions,
 224            }
 225        })
 226    }
 227
 228    pub fn models(&self) -> &LanguageModels {
 229        &self.models
 230    }
 231
 232    async fn maintain_project_context(
 233        this: WeakEntity<Self>,
 234        mut needs_refresh: watch::Receiver<()>,
 235        cx: &mut AsyncApp,
 236    ) -> Result<()> {
 237        while needs_refresh.changed().await.is_ok() {
 238            let project_context = this
 239                .update(cx, |this, cx| {
 240                    Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
 241                })?
 242                .await;
 243            this.update(cx, |this, _| this.project_context.replace(project_context))?;
 244        }
 245
 246        Ok(())
 247    }
 248
 249    fn build_project_context(
 250        project: &Entity<Project>,
 251        prompt_store: Option<&Entity<PromptStore>>,
 252        cx: &mut App,
 253    ) -> Task<ProjectContext> {
 254        let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
 255        let worktree_tasks = worktrees
 256            .into_iter()
 257            .map(|worktree| {
 258                Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
 259            })
 260            .collect::<Vec<_>>();
 261        let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
 262            prompt_store.read_with(cx, |prompt_store, cx| {
 263                let prompts = prompt_store.default_prompt_metadata();
 264                let load_tasks = prompts.into_iter().map(|prompt_metadata| {
 265                    let contents = prompt_store.load(prompt_metadata.id, cx);
 266                    async move { (contents.await, prompt_metadata) }
 267                });
 268                cx.background_spawn(future::join_all(load_tasks))
 269            })
 270        } else {
 271            Task::ready(vec![])
 272        };
 273
 274        cx.spawn(async move |_cx| {
 275            let (worktrees, default_user_rules) =
 276                future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
 277
 278            let worktrees = worktrees
 279                .into_iter()
 280                .map(|(worktree, _rules_error)| {
 281                    // TODO: show error message
 282                    // if let Some(rules_error) = rules_error {
 283                    //     this.update(cx, |_, cx| cx.emit(rules_error)).ok();
 284                    // }
 285                    worktree
 286                })
 287                .collect::<Vec<_>>();
 288
 289            let default_user_rules = default_user_rules
 290                .into_iter()
 291                .flat_map(|(contents, prompt_metadata)| match contents {
 292                    Ok(contents) => Some(UserRulesContext {
 293                        uuid: match prompt_metadata.id {
 294                            PromptId::User { uuid } => uuid,
 295                            PromptId::EditWorkflow => return None,
 296                        },
 297                        title: prompt_metadata.title.map(|title| title.to_string()),
 298                        contents,
 299                    }),
 300                    Err(_err) => {
 301                        // TODO: show error message
 302                        // this.update(cx, |_, cx| {
 303                        //     cx.emit(RulesLoadingError {
 304                        //         message: format!("{err:?}").into(),
 305                        //     });
 306                        // })
 307                        // .ok();
 308                        None
 309                    }
 310                })
 311                .collect::<Vec<_>>();
 312
 313            ProjectContext::new(worktrees, default_user_rules)
 314        })
 315    }
 316
 317    fn load_worktree_info_for_system_prompt(
 318        worktree: Entity<Worktree>,
 319        project: Entity<Project>,
 320        cx: &mut App,
 321    ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
 322        let tree = worktree.read(cx);
 323        let root_name = tree.root_name().into();
 324        let abs_path = tree.abs_path();
 325
 326        let mut context = WorktreeContext {
 327            root_name,
 328            abs_path,
 329            rules_file: None,
 330        };
 331
 332        let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
 333        let Some(rules_task) = rules_task else {
 334            return Task::ready((context, None));
 335        };
 336
 337        cx.spawn(async move |_| {
 338            let (rules_file, rules_file_error) = match rules_task.await {
 339                Ok(rules_file) => (Some(rules_file), None),
 340                Err(err) => (
 341                    None,
 342                    Some(RulesLoadingError {
 343                        message: format!("{err}").into(),
 344                    }),
 345                ),
 346            };
 347            context.rules_file = rules_file;
 348            (context, rules_file_error)
 349        })
 350    }
 351
 352    fn load_worktree_rules_file(
 353        worktree: Entity<Worktree>,
 354        project: Entity<Project>,
 355        cx: &mut App,
 356    ) -> Option<Task<Result<RulesFileContext>>> {
 357        let worktree = worktree.read(cx);
 358        let worktree_id = worktree.id();
 359        let selected_rules_file = RULES_FILE_NAMES
 360            .into_iter()
 361            .filter_map(|name| {
 362                worktree
 363                    .entry_for_path(name)
 364                    .filter(|entry| entry.is_file())
 365                    .map(|entry| entry.path.clone())
 366            })
 367            .next();
 368
 369        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 370        // supported. This doesn't seem to occur often in GitHub repositories.
 371        selected_rules_file.map(|path_in_worktree| {
 372            let project_path = ProjectPath {
 373                worktree_id,
 374                path: path_in_worktree.clone(),
 375            };
 376            let buffer_task =
 377                project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 378            let rope_task = cx.spawn(async move |cx| {
 379                buffer_task.await?.read_with(cx, |buffer, cx| {
 380                    let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
 381                    anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
 382                })?
 383            });
 384            // Build a string from the rope on a background thread.
 385            cx.background_spawn(async move {
 386                let (project_entry_id, rope) = rope_task.await?;
 387                anyhow::Ok(RulesFileContext {
 388                    path_in_worktree,
 389                    text: rope.to_string().trim().to_string(),
 390                    project_entry_id: project_entry_id.to_usize(),
 391                })
 392            })
 393        })
 394    }
 395
 396    fn handle_project_event(
 397        &mut self,
 398        _project: Entity<Project>,
 399        event: &project::Event,
 400        _cx: &mut Context<Self>,
 401    ) {
 402        match event {
 403            project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
 404                self.project_context_needs_refresh.send(()).ok();
 405            }
 406            project::Event::WorktreeUpdatedEntries(_, items) => {
 407                if items.iter().any(|(path, _, _)| {
 408                    RULES_FILE_NAMES
 409                        .iter()
 410                        .any(|name| path.as_ref() == Path::new(name))
 411                }) {
 412                    self.project_context_needs_refresh.send(()).ok();
 413                }
 414            }
 415            _ => {}
 416        }
 417    }
 418
 419    fn handle_prompts_updated_event(
 420        &mut self,
 421        _prompt_store: Entity<PromptStore>,
 422        _event: &prompt_store::PromptsUpdatedEvent,
 423        _cx: &mut Context<Self>,
 424    ) {
 425        self.project_context_needs_refresh.send(()).ok();
 426    }
 427
 428    fn handle_models_updated_event(
 429        &mut self,
 430        _registry: Entity<LanguageModelRegistry>,
 431        _event: &language_model::Event,
 432        cx: &mut Context<Self>,
 433    ) {
 434        self.models.refresh_list(cx);
 435        for session in self.sessions.values_mut() {
 436            session.thread.update(cx, |thread, _| {
 437                let model_id = LanguageModels::model_id(&thread.model());
 438                if let Some(model) = self.models.model_from_id(&model_id) {
 439                    thread.set_model(model.clone());
 440                }
 441            });
 442        }
 443    }
 444}
 445
/// Wrapper struct that implements the `AgentConnection` and `AgentModelSelector`
/// traits on top of the shared [`NativeAgent`] entity.
///
/// Cloning copies the `Entity` handle, so all clones refer to the same agent.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 449
 450impl NativeAgentConnection {
 451    pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
 452        self.0
 453            .read(cx)
 454            .sessions
 455            .get(session_id)
 456            .map(|session| session.thread.clone())
 457    }
 458
 459    fn run_turn(
 460        &self,
 461        session_id: acp::SessionId,
 462        cx: &mut App,
 463        f: impl 'static
 464        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
 465    ) -> Task<Result<acp::PromptResponse>> {
 466        let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
 467            agent
 468                .sessions
 469                .get_mut(&session_id)
 470                .map(|s| (s.thread.clone(), s.acp_thread.clone()))
 471        }) else {
 472            return Task::ready(Err(anyhow!("Session not found")));
 473        };
 474        log::debug!("Found session for: {}", session_id);
 475
 476        let mut response_stream = match f(thread, cx) {
 477            Ok(stream) => stream,
 478            Err(err) => return Task::ready(Err(err)),
 479        };
 480        cx.spawn(async move |cx| {
 481            // Handle response stream and forward to session.acp_thread
 482            while let Some(result) = response_stream.next().await {
 483                match result {
 484                    Ok(event) => {
 485                        log::trace!("Received completion event: {:?}", event);
 486
 487                        match event {
 488                            ThreadEvent::UserMessage(message) => {
 489                                todo!()
 490                            }
 491                            ThreadEvent::AgentText(text) => {
 492                                acp_thread.update(cx, |thread, cx| {
 493                                    thread.push_assistant_content_block(
 494                                        acp::ContentBlock::Text(acp::TextContent {
 495                                            text,
 496                                            annotations: None,
 497                                        }),
 498                                        false,
 499                                        cx,
 500                                    )
 501                                })?;
 502                            }
 503                            ThreadEvent::AgentThinking(text) => {
 504                                acp_thread.update(cx, |thread, cx| {
 505                                    thread.push_assistant_content_block(
 506                                        acp::ContentBlock::Text(acp::TextContent {
 507                                            text,
 508                                            annotations: None,
 509                                        }),
 510                                        true,
 511                                        cx,
 512                                    )
 513                                })?;
 514                            }
 515                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
 516                                tool_call,
 517                                options,
 518                                response,
 519                            }) => {
 520                                let recv = acp_thread.update(cx, |thread, cx| {
 521                                    thread.request_tool_call_authorization(tool_call, options, cx)
 522                                })?;
 523                                cx.background_spawn(async move {
 524                                    if let Some(recv) = recv.log_err()
 525                                        && let Some(option) = recv
 526                                            .await
 527                                            .context("authorization sender was dropped")
 528                                            .log_err()
 529                                    {
 530                                        response
 531                                            .send(option)
 532                                            .map(|_| anyhow!("authorization receiver was dropped"))
 533                                            .log_err();
 534                                    }
 535                                })
 536                                .detach();
 537                            }
 538                            ThreadEvent::ToolCall(tool_call) => {
 539                                acp_thread.update(cx, |thread, cx| {
 540                                    thread.upsert_tool_call(tool_call, cx)
 541                                })??;
 542                            }
 543                            ThreadEvent::ToolCallUpdate(update) => {
 544                                acp_thread.update(cx, |thread, cx| {
 545                                    thread.update_tool_call(update, cx)
 546                                })??;
 547                            }
 548                            ThreadEvent::Stop(stop_reason) => {
 549                                log::debug!("Assistant message complete: {:?}", stop_reason);
 550                                return Ok(acp::PromptResponse { stop_reason });
 551                            }
 552                        }
 553                    }
 554                    Err(e) => {
 555                        log::error!("Error in model response stream: {:?}", e);
 556                        return Err(e);
 557                    }
 558                }
 559            }
 560
 561            log::info!("Response stream completed");
 562            anyhow::Ok(acp::PromptResponse {
 563                stop_reason: acp::StopReason::EndTurn,
 564            })
 565        })
 566    }
 567
 568    fn register_tools(
 569        thread: &mut Thread,
 570        project: Entity<Project>,
 571        action_log: Entity<action_log::ActionLog>,
 572        cx: &mut Context<Thread>,
 573    ) {
 574        thread.add_tool(CopyPathTool::new(project.clone()));
 575        thread.add_tool(CreateDirectoryTool::new(project.clone()));
 576        thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
 577        thread.add_tool(DiagnosticsTool::new(project.clone()));
 578        thread.add_tool(EditFileTool::new(cx.weak_entity()));
 579        thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
 580        thread.add_tool(FindPathTool::new(project.clone()));
 581        thread.add_tool(GrepTool::new(project.clone()));
 582        thread.add_tool(ListDirectoryTool::new(project.clone()));
 583        thread.add_tool(MovePathTool::new(project.clone()));
 584        thread.add_tool(NowTool);
 585        thread.add_tool(OpenTool::new(project.clone()));
 586        thread.add_tool(ReadFileTool::new(project.clone(), action_log));
 587        thread.add_tool(TerminalTool::new(project.clone(), cx));
 588        thread.add_tool(ThinkingTool);
 589        thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
 590    }
 591}
 592
 593impl AgentModelSelector for NativeAgentConnection {
 594    fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
 595        log::debug!("NativeAgentConnection::list_models called");
 596        let list = self.0.read(cx).models.model_list.clone();
 597        Task::ready(if list.is_empty() {
 598            Err(anyhow::anyhow!("No models available"))
 599        } else {
 600            Ok(list)
 601        })
 602    }
 603
 604    fn select_model(
 605        &self,
 606        session_id: acp::SessionId,
 607        model_id: acp_thread::AgentModelId,
 608        cx: &mut App,
 609    ) -> Task<Result<()>> {
 610        log::info!("Setting model for session {}: {}", session_id, model_id);
 611        let Some(thread) = self
 612            .0
 613            .read(cx)
 614            .sessions
 615            .get(&session_id)
 616            .map(|session| session.thread.clone())
 617        else {
 618            return Task::ready(Err(anyhow!("Session not found")));
 619        };
 620
 621        let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
 622            return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
 623        };
 624
 625        thread.update(cx, |thread, _cx| {
 626            thread.set_model(model.clone());
 627        });
 628
 629        update_settings_file::<AgentSettings>(
 630            self.0.read(cx).fs.clone(),
 631            cx,
 632            move |settings, _cx| {
 633                settings.set_model(model);
 634            },
 635        );
 636
 637        Task::ready(Ok(()))
 638    }
 639
 640    fn selected_model(
 641        &self,
 642        session_id: &acp::SessionId,
 643        cx: &mut App,
 644    ) -> Task<Result<acp_thread::AgentModelInfo>> {
 645        let session_id = session_id.clone();
 646
 647        let Some(thread) = self
 648            .0
 649            .read(cx)
 650            .sessions
 651            .get(&session_id)
 652            .map(|session| session.thread.clone())
 653        else {
 654            return Task::ready(Err(anyhow!("Session not found")));
 655        };
 656        let model = thread.read(cx).model().clone();
 657        let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
 658        else {
 659            return Task::ready(Err(anyhow!("Provider not found")));
 660        };
 661        Task::ready(Ok(LanguageModels::map_language_model_to_info(
 662            &model, &provider,
 663        )))
 664    }
 665
 666    fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
 667        self.0.read(cx).models.watch()
 668    }
 669}
 670
 671impl acp_thread::AgentConnection for NativeAgentConnection {
 672    fn new_thread(
 673        self: Rc<Self>,
 674        project: Entity<Project>,
 675        cwd: &Path,
 676        cx: &mut App,
 677    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
 678        let agent = self.0.clone();
 679        log::info!("Creating new thread for project at: {:?}", cwd);
 680
 681        cx.spawn(async move |cx| {
 682            log::debug!("Starting thread creation in async context");
 683
 684            // Generate session ID
 685            let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
 686            log::info!("Created session with ID: {}", session_id);
 687
 688            // Create AcpThread
 689            let acp_thread = cx.update(|cx| {
 690                cx.new(|cx| {
 691                    acp_thread::AcpThread::new(
 692                        "agent2",
 693                        self.clone(),
 694                        project.clone(),
 695                        session_id.clone(),
 696                        cx,
 697                    )
 698                })
 699            })?;
 700            let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
 701
 702            // Create Thread
 703            let thread = agent.update(
 704                cx,
 705                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
 706                    // Fetch default model from registry settings
 707                    let registry = LanguageModelRegistry::read_global(cx);
 708
 709                    // Log available models for debugging
 710                    let available_count = registry.available_models(cx).count();
 711                    log::debug!("Total available models: {}", available_count);
 712
 713                    let default_model = registry
 714                        .default_model()
 715                        .and_then(|default_model| {
 716                            agent
 717                                .models
 718                                .model_from_id(&LanguageModels::model_id(&default_model.model))
 719                        })
 720                        .ok_or_else(|| {
 721                            log::warn!("No default model configured in settings");
 722                            anyhow!(
 723                                "No default model. Please configure a default model in settings."
 724                            )
 725                        })?;
 726
 727                    let thread = cx.new(|cx| {
 728                        let mut thread = Thread::new(
 729                            project.clone(),
 730                            agent.project_context.clone(),
 731                            agent.context_server_registry.clone(),
 732                            action_log.clone(),
 733                            agent.templates.clone(),
 734                            default_model,
 735                            cx,
 736                        );
 737                        Self::register_tools(&mut thread, project, action_log, cx);
 738                        thread
 739                    });
 740
 741                    Ok(thread)
 742                },
 743            )??;
 744
 745            // Store the session
 746            agent.update(cx, |agent, cx| {
 747                agent.sessions.insert(
 748                    session_id,
 749                    Session {
 750                        thread,
 751                        acp_thread: acp_thread.downgrade(),
 752                        _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
 753                            this.sessions.remove(acp_thread.session_id());
 754                        }),
 755                    },
 756                );
 757            })?;
 758
 759            Ok(acp_thread)
 760        })
 761    }
 762
 763    fn auth_methods(&self) -> &[acp::AuthMethod] {
 764        &[] // No auth for in-process
 765    }
 766
 767    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
 768        Task::ready(Ok(()))
 769    }
 770
 771    fn list_threads(&self, cx: &mut App) -> Option<UnboundedReceiver<Vec<AcpThreadMetadata>>> {
 772        let (mut tx, rx) = futures::channel::mpsc::unbounded();
 773        let database = self.0.update(cx, |this, _| {
 774            this.history_listeners.push(tx.clone());
 775            this.thread_database.clone()
 776        });
 777        cx.background_executor()
 778            .spawn(async move {
 779                dbg!("listing!");
 780                let database = database.await.map_err(|e| anyhow!(e))?;
 781                let results = database.list_threads().await?;
 782
 783                dbg!(&results);
 784                tx.send(
 785                    results
 786                        .into_iter()
 787                        .map(|thread| AcpThreadMetadata {
 788                            agent: NATIVE_AGENT_SERVER_NAME.clone(),
 789                            id: thread.id.into(),
 790                            title: thread.title,
 791                            updated_at: thread.updated_at,
 792                        })
 793                        .collect(),
 794                )
 795                .await?;
 796                anyhow::Ok(())
 797            })
 798            .detach_and_log_err(cx);
 799        Some(rx)
 800    }
 801
    /// Loads a persisted thread from the threads database and registers it as a
    /// new session on this agent.
    ///
    /// Steps: load the `DbThread` for `session_id`, create an `AcpThread` titled
    /// with the stored title, build a native `Thread` (with tools registered)
    /// using the stored model or the registry default, then insert a `Session`
    /// that is removed again when the `AcpThread` entity is released.
    ///
    /// Errors if the database cannot be opened or queried, if no thread with this
    /// id exists ("no such thread found"), if neither the persisted model nor a
    /// registry default is available ("no default model configured"), or if the
    /// resolved model is not known to this agent ("no model by id").
    ///
    /// `_cwd` is currently unused.
    fn load_thread(
        self: Rc<Self>,
        project: Entity<Project>,
        _cwd: &Path,
        session_id: acp::SessionId,
        cx: &mut App,
    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
        let thread_id = session_id.clone().into();
        let database = self.0.update(cx, |this, _| this.thread_database.clone());
        cx.spawn(async move |cx| {
            let database = database.await.map_err(|e| anyhow!(e))?;
            // `load_thread` yields an `Option`; a missing row becomes an error here.
            let db_thread = database
                .load_thread(thread_id)
                .await?
                .context("no such thread found")?;

            let acp_thread = cx.update(|cx| {
                cx.new(|cx| {
                    acp_thread::AcpThread::new(
                        db_thread.title,
                        self.clone(),
                        project.clone(),
                        session_id.clone(),
                        cx,
                    )
                })
            })?;
            // The native Thread shares the AcpThread's action log.
            let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
            let agent = self.0.clone();

            // Create Thread
            let thread = agent.update(cx, |agent, cx| {
                // Prefer the model persisted with the thread; fall back to the
                // registry's default model if there is none or it can't be selected.
                // NOTE(review): confirm whether `select_model` has side effects on
                // the registry beyond resolving the model.
                let configured_model = LanguageModelRegistry::global(cx)
                    .update(cx, |registry, cx| {
                        db_thread
                            .model
                            .and_then(|model| {
                                let model = SelectedModel {
                                    provider: model.provider.clone().into(),
                                    model: model.model.clone().into(),
                                };
                                registry.select_model(&model, cx)
                            })
                            .or_else(|| registry.default_model())
                    })
                    .context("no default model configured")?;

                let model = agent
                    .models
                    .model_from_id(&LanguageModels::model_id(&configured_model.model))
                    .context("no model by id")?;

                let thread = cx.new(|cx| {
                    let mut thread = Thread::new(
                        project.clone(),
                        agent.project_context.clone(),
                        agent.context_server_registry.clone(),
                        action_log.clone(),
                        agent.templates.clone(),
                        model,
                        cx,
                    );
                    Self::register_tools(&mut thread, project, action_log, cx);
                    thread
                });

                anyhow::Ok(thread)
            })??;

            // Store the session; it is dropped automatically once the AcpThread
            // entity is released.
            agent.update(cx, |agent, cx| {
                agent.sessions.insert(
                    session_id,
                    Session {
                        thread,
                        acp_thread: acp_thread.downgrade(),
                        _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
                            this.sessions.remove(acp_thread.session_id());
                        }),
                    },
                );
            })?;

            // Only the title and model are restored from `db_thread` so far; the
            // rest of the DbThread still needs to be deserialized into `thread`.
            // todo!()

            Ok(acp_thread)
        })
    }
 891
 892    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 893        Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
 894    }
 895
 896    fn prompt(
 897        &self,
 898        id: Option<acp_thread::UserMessageId>,
 899        params: acp::PromptRequest,
 900        cx: &mut App,
 901    ) -> Task<Result<acp::PromptResponse>> {
 902        let id = id.expect("UserMessageId is required");
 903        let session_id = params.session_id.clone();
 904        log::info!("Received prompt request for session: {}", session_id);
 905        log::debug!("Prompt blocks count: {}", params.prompt.len());
 906
 907        self.run_turn(session_id, cx, |thread, cx| {
 908            let content: Vec<UserMessageContent> = params
 909                .prompt
 910                .into_iter()
 911                .map(Into::into)
 912                .collect::<Vec<_>>();
 913            log::info!("Converted prompt to message: {} chars", content.len());
 914            log::debug!("Message id: {:?}", id);
 915            log::debug!("Message content: {:?}", content);
 916
 917            Ok(thread.update(cx, |thread, cx| {
 918                log::info!(
 919                    "Sending message to thread with model: {:?}",
 920                    thread.model().name()
 921                );
 922                thread.send(id, content, cx)
 923            }))
 924        })
 925    }
 926
 927    fn resume(
 928        &self,
 929        session_id: &acp::SessionId,
 930        _cx: &mut App,
 931    ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
 932        Some(Rc::new(NativeAgentSessionResume {
 933            connection: self.clone(),
 934            session_id: session_id.clone(),
 935        }) as _)
 936    }
 937
 938    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
 939        log::info!("Cancelling on session: {}", session_id);
 940        self.0.update(cx, |agent, cx| {
 941            if let Some(agent) = agent.sessions.get(session_id) {
 942                agent.thread.update(cx, |thread, _cx| thread.cancel());
 943            }
 944        });
 945    }
 946
 947    fn session_editor(
 948        &self,
 949        session_id: &agent_client_protocol::SessionId,
 950        cx: &mut App,
 951    ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
 952        self.0.update(cx, |agent, _cx| {
 953            agent
 954                .sessions
 955                .get(session_id)
 956                .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
 957        })
 958    }
 959
 960    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 961        self
 962    }
 963}
 964
 965struct NativeAgentSessionEditor(Entity<Thread>);
 966
 967impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
 968    fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
 969        Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
 970    }
 971}
 972
 973struct NativeAgentSessionResume {
 974    connection: NativeAgentConnection,
 975    session_id: acp::SessionId,
 976}
 977
 978impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
 979    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
 980        self.connection
 981            .run_turn(self.session_id.clone(), cx, |thread, cx| {
 982                thread.update(cx, |thread, cx| thread.resume(cx))
 983            })
 984    }
 985}
 986
// Tests for the native agent: project-context tracking, model listing, and
// persisting model selection to the settings file.
#[cfg(test)]
mod tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// The agent's project context starts empty, picks up a worktree once it is
    /// added to the project, and then records a `.rules` file created in that
    /// worktree's root.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        // Let the agent observe the new worktree before checking its context.
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// `list_models` returns the test registry's single fake model, grouped
    /// under its provider name ("Fake").
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {}  })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Selecting a model for a session switches that session's thread to the
    /// model and writes it to `agent.default_model` in the settings file,
    /// replacing the pre-existing `foo`/`bar` entry.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().id().0, "fake");
            });
        });

        // Let the background settings-file write complete before reading it back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// Installs the globals these tests rely on: a test `SettingsStore`,
    /// project, agent and language settings, and a test `LanguageModelRegistry`.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}