1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17pub use native_agent_server::NativeAgentServer;
18pub use pattern_extraction::*;
19pub use shell_command_parser::extract_commands;
20pub use templates::*;
21pub use thread::*;
22pub use thread_store::*;
23pub use tool_permissions::*;
24pub use tools::*;
25
26use acp_thread::{
27 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
28 AgentSessionListResponse, UserMessageId,
29};
30use agent_client_protocol as acp;
31use anyhow::{Context as _, Result, anyhow};
32use chrono::{DateTime, Utc};
33use collections::{HashMap, HashSet, IndexMap};
34use fs::Fs;
35use futures::channel::{mpsc, oneshot};
36use futures::future::Shared;
37use futures::{FutureExt as _, StreamExt as _, future};
38use gpui::{
39 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
40};
41use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
42use project::{Project, ProjectItem, ProjectPath, Worktree};
43use prompt_store::{
44 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
45 WorktreeContext,
46};
47use serde::{Deserialize, Serialize};
48use settings::{LanguageModelSelection, update_settings_file};
49use std::any::Any;
50use std::path::{Path, PathBuf};
51use std::rc::Rc;
52use std::sync::Arc;
53use std::time::Duration;
54use util::ResultExt;
55use util::rel_path::RelPath;
56
/// A serializable capture of a project's worktrees together with a timestamp.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots of the project's worktrees.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// UTC timestamp attached to the snapshot (presumably the capture time; confirm at the call site).
    pub timestamp: DateTime<Utc>,
}
62
63pub struct RulesLoadingError {
64 pub message: SharedString,
65}
66
/// Holds both the internal `Thread` and the `AcpThread` for a session,
/// along with the bookkeeping that ties them to the owning `NativeAgent`.
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// The most recently started save of this session's thread; replaced on every save.
    pending_save: Task<()>,
    /// Subscriptions registered for this session; held so they live as long as the session.
    _subscriptions: Vec<Subscription>,
}
76
/// Cache of the language models offered by authenticated providers.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver side of the refresh signal; cloned out to observers by `watch`.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender side of the refresh signal; fired each time the list is rebuilt.
    refresh_models_tx: watch::Sender<()>,
    /// Background task that authenticates every visible provider; held for the cache's lifetime.
    _authenticate_all_providers_task: Task<()>,
}
86
87impl LanguageModels {
88 fn new(cx: &mut App) -> Self {
89 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
90
91 let mut this = Self {
92 models: HashMap::default(),
93 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
94 refresh_models_rx,
95 refresh_models_tx,
96 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
97 };
98 this.refresh_list(cx);
99 this
100 }
101
102 fn refresh_list(&mut self, cx: &App) {
103 let providers = LanguageModelRegistry::global(cx)
104 .read(cx)
105 .visible_providers()
106 .into_iter()
107 .filter(|provider| provider.is_authenticated(cx))
108 .collect::<Vec<_>>();
109
110 let mut language_model_list = IndexMap::default();
111 let mut recommended_models = HashSet::default();
112
113 let mut recommended = Vec::new();
114 for provider in &providers {
115 for model in provider.recommended_models(cx) {
116 recommended_models.insert((model.provider_id(), model.id()));
117 recommended.push(Self::map_language_model_to_info(&model, provider));
118 }
119 }
120 if !recommended.is_empty() {
121 language_model_list.insert(
122 acp_thread::AgentModelGroupName("Recommended".into()),
123 recommended,
124 );
125 }
126
127 let mut models = HashMap::default();
128 for provider in providers {
129 let mut provider_models = Vec::new();
130 for model in provider.provided_models(cx) {
131 let model_info = Self::map_language_model_to_info(&model, &provider);
132 let model_id = model_info.id.clone();
133 provider_models.push(model_info);
134 models.insert(model_id, model);
135 }
136 if !provider_models.is_empty() {
137 language_model_list.insert(
138 acp_thread::AgentModelGroupName(provider.name().0.clone()),
139 provider_models,
140 );
141 }
142 }
143
144 self.models = models;
145 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
146 self.refresh_models_tx.send(()).ok();
147 }
148
149 fn watch(&self) -> watch::Receiver<()> {
150 self.refresh_models_rx.clone()
151 }
152
153 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
154 self.models.get(model_id).cloned()
155 }
156
157 fn map_language_model_to_info(
158 model: &Arc<dyn LanguageModel>,
159 provider: &Arc<dyn LanguageModelProvider>,
160 ) -> acp_thread::AgentModelInfo {
161 acp_thread::AgentModelInfo {
162 id: Self::model_id(model),
163 name: model.name().0,
164 description: None,
165 icon: Some(match provider.icon() {
166 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
167 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
168 }),
169 is_latest: model.is_latest(),
170 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
171 }
172 }
173
174 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
175 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
176 }
177
178 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
179 let authenticate_all_providers = LanguageModelRegistry::global(cx)
180 .read(cx)
181 .visible_providers()
182 .iter()
183 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
184 .collect::<Vec<_>>();
185
186 cx.background_spawn(async move {
187 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
188 if let Err(err) = authenticate_task.await {
189 match err {
190 language_model::AuthenticateError::CredentialsNotFound => {
191 // Since we're authenticating these providers in the
192 // background for the purposes of populating the
193 // language selector, we don't care about providers
194 // where the credentials are not found.
195 }
196 language_model::AuthenticateError::ConnectionRefused => {
197 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
198 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
199 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
200 }
201 _ => {
202 // Some providers have noisy failure states that we
203 // don't want to spam the logs with every time the
204 // language model selector is initialized.
205 //
206 // Ideally these should have more clear failure modes
207 // that we know are safe to ignore here, like what we do
208 // with `CredentialsNotFound` above.
209 match provider_id.0.as_ref() {
210 "lmstudio" | "ollama" => {
211 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
212 //
213 // These fail noisily, so we don't log them.
214 }
215 "copilot_chat" => {
216 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
217 }
218 _ => {
219 log::error!(
220 "Failed to authenticate provider: {}: {err:#}",
221 provider_name.0
222 );
223 }
224 }
225 }
226 }
227 }
228 }
229 })
230 }
231}
232
/// The native agent: owns every open session for a project together with the
/// state shared between them (project context, templates, model cache).
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Store that is asked to reload after a thread has been saved.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signal that asks the maintenance task to rebuild `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task that rebuilds `project_context` whenever the refresh signal fires.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context servers; source of the prompts exposed as commands.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store supplying the default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// File system handle (not used by the code visible in this part of the file).
    fs: Arc<dyn Fs>,
    /// Event subscriptions held for the lifetime of the agent.
    _subscriptions: Vec<Subscription>,
}
251
252impl NativeAgent {
253 pub async fn new(
254 project: Entity<Project>,
255 thread_store: Entity<ThreadStore>,
256 templates: Arc<Templates>,
257 prompt_store: Option<Entity<PromptStore>>,
258 fs: Arc<dyn Fs>,
259 cx: &mut AsyncApp,
260 ) -> Result<Entity<NativeAgent>> {
261 log::debug!("Creating new NativeAgent");
262
263 let project_context = cx
264 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
265 .await;
266
267 Ok(cx.new(|cx| {
268 let context_server_store = project.read(cx).context_server_store();
269 let context_server_registry =
270 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
271
272 let mut subscriptions = vec![
273 cx.subscribe(&project, Self::handle_project_event),
274 cx.subscribe(
275 &LanguageModelRegistry::global(cx),
276 Self::handle_models_updated_event,
277 ),
278 cx.subscribe(
279 &context_server_store,
280 Self::handle_context_server_store_updated,
281 ),
282 cx.subscribe(
283 &context_server_registry,
284 Self::handle_context_server_registry_event,
285 ),
286 ];
287 if let Some(prompt_store) = prompt_store.as_ref() {
288 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
289 }
290
291 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
292 watch::channel(());
293 Self {
294 sessions: HashMap::default(),
295 thread_store,
296 project_context: cx.new(|_| project_context),
297 project_context_needs_refresh: project_context_needs_refresh_tx,
298 _maintain_project_context: cx.spawn(async move |this, cx| {
299 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
300 }),
301 context_server_registry,
302 templates,
303 models: LanguageModels::new(cx),
304 project,
305 prompt_store,
306 fs,
307 _subscriptions: subscriptions,
308 }
309 }))
310 }
311
312 fn new_session(
313 &mut self,
314 project: Entity<Project>,
315 cx: &mut Context<Self>,
316 ) -> Entity<AcpThread> {
317 // Create Thread
318 // Fetch default model from registry settings
319 let registry = LanguageModelRegistry::read_global(cx);
320 // Log available models for debugging
321 let available_count = registry.available_models(cx).count();
322 log::debug!("Total available models: {}", available_count);
323
324 let default_model = registry.default_model().and_then(|default_model| {
325 self.models
326 .model_from_id(&LanguageModels::model_id(&default_model.model))
327 });
328 let thread = cx.new(|cx| {
329 Thread::new(
330 project.clone(),
331 self.project_context.clone(),
332 self.context_server_registry.clone(),
333 self.templates.clone(),
334 default_model,
335 cx,
336 )
337 });
338
339 self.register_session(thread, cx)
340 }
341
342 fn register_session(
343 &mut self,
344 thread_handle: Entity<Thread>,
345 cx: &mut Context<Self>,
346 ) -> Entity<AcpThread> {
347 let connection = Rc::new(NativeAgentConnection(cx.entity()));
348
349 let thread = thread_handle.read(cx);
350 let session_id = thread.id().clone();
351 let parent_session_id = thread.parent_thread_id();
352 let title = thread.title();
353 let project = thread.project.clone();
354 let action_log = thread.action_log.clone();
355 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
356 let acp_thread = cx.new(|cx| {
357 acp_thread::AcpThread::new(
358 parent_session_id,
359 title,
360 connection,
361 project.clone(),
362 action_log.clone(),
363 session_id.clone(),
364 prompt_capabilities_rx,
365 cx,
366 )
367 });
368
369 let registry = LanguageModelRegistry::read_global(cx);
370 let summarization_model = registry.thread_summary_model().map(|c| c.model);
371
372 let weak = cx.weak_entity();
373 thread_handle.update(cx, |thread, cx| {
374 thread.set_summarization_model(summarization_model, cx);
375 thread.add_default_tools(
376 Rc::new(NativeThreadEnvironment {
377 acp_thread: acp_thread.downgrade(),
378 agent: weak,
379 }) as _,
380 cx,
381 )
382 });
383
384 let subscriptions = vec![
385 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
386 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
387 cx.observe(&thread_handle, move |this, thread, cx| {
388 this.save_thread(thread, cx)
389 }),
390 ];
391
392 self.sessions.insert(
393 session_id,
394 Session {
395 thread: thread_handle,
396 acp_thread: acp_thread.clone(),
397 _subscriptions: subscriptions,
398 pending_save: Task::ready(()),
399 },
400 );
401
402 self.update_available_commands(cx);
403
404 acp_thread
405 }
406
407 pub fn models(&self) -> &LanguageModels {
408 &self.models
409 }
410
411 async fn maintain_project_context(
412 this: WeakEntity<Self>,
413 mut needs_refresh: watch::Receiver<()>,
414 cx: &mut AsyncApp,
415 ) -> Result<()> {
416 while needs_refresh.changed().await.is_ok() {
417 let project_context = this
418 .update(cx, |this, cx| {
419 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
420 })?
421 .await;
422 this.update(cx, |this, cx| {
423 this.project_context = cx.new(|_| project_context);
424 })?;
425 }
426
427 Ok(())
428 }
429
430 fn build_project_context(
431 project: &Entity<Project>,
432 prompt_store: Option<&Entity<PromptStore>>,
433 cx: &mut App,
434 ) -> Task<ProjectContext> {
435 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
436 let worktree_tasks = worktrees
437 .into_iter()
438 .map(|worktree| {
439 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
440 })
441 .collect::<Vec<_>>();
442 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
443 prompt_store.read_with(cx, |prompt_store, cx| {
444 let prompts = prompt_store.default_prompt_metadata();
445 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
446 let contents = prompt_store.load(prompt_metadata.id, cx);
447 async move { (contents.await, prompt_metadata) }
448 });
449 cx.background_spawn(future::join_all(load_tasks))
450 })
451 } else {
452 Task::ready(vec![])
453 };
454
455 cx.spawn(async move |_cx| {
456 let (worktrees, default_user_rules) =
457 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
458
459 let worktrees = worktrees
460 .into_iter()
461 .map(|(worktree, _rules_error)| {
462 // TODO: show error message
463 // if let Some(rules_error) = rules_error {
464 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
465 // }
466 worktree
467 })
468 .collect::<Vec<_>>();
469
470 let default_user_rules = default_user_rules
471 .into_iter()
472 .flat_map(|(contents, prompt_metadata)| match contents {
473 Ok(contents) => Some(UserRulesContext {
474 uuid: prompt_metadata.id.as_user()?,
475 title: prompt_metadata.title.map(|title| title.to_string()),
476 contents,
477 }),
478 Err(_err) => {
479 // TODO: show error message
480 // this.update(cx, |_, cx| {
481 // cx.emit(RulesLoadingError {
482 // message: format!("{err:?}").into(),
483 // });
484 // })
485 // .ok();
486 None
487 }
488 })
489 .collect::<Vec<_>>();
490
491 ProjectContext::new(worktrees, default_user_rules)
492 })
493 }
494
495 fn load_worktree_info_for_system_prompt(
496 worktree: Entity<Worktree>,
497 project: Entity<Project>,
498 cx: &mut App,
499 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
500 let tree = worktree.read(cx);
501 let root_name = tree.root_name_str().into();
502 let abs_path = tree.abs_path();
503
504 let mut context = WorktreeContext {
505 root_name,
506 abs_path,
507 rules_file: None,
508 };
509
510 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
511 let Some(rules_task) = rules_task else {
512 return Task::ready((context, None));
513 };
514
515 cx.spawn(async move |_| {
516 let (rules_file, rules_file_error) = match rules_task.await {
517 Ok(rules_file) => (Some(rules_file), None),
518 Err(err) => (
519 None,
520 Some(RulesLoadingError {
521 message: format!("{err}").into(),
522 }),
523 ),
524 };
525 context.rules_file = rules_file;
526 (context, rules_file_error)
527 })
528 }
529
530 fn load_worktree_rules_file(
531 worktree: Entity<Worktree>,
532 project: Entity<Project>,
533 cx: &mut App,
534 ) -> Option<Task<Result<RulesFileContext>>> {
535 let worktree = worktree.read(cx);
536 let worktree_id = worktree.id();
537 let selected_rules_file = RULES_FILE_NAMES
538 .into_iter()
539 .filter_map(|name| {
540 worktree
541 .entry_for_path(RelPath::unix(name).unwrap())
542 .filter(|entry| entry.is_file())
543 .map(|entry| entry.path.clone())
544 })
545 .next();
546
547 // Note that Cline supports `.clinerules` being a directory, but that is not currently
548 // supported. This doesn't seem to occur often in GitHub repositories.
549 selected_rules_file.map(|path_in_worktree| {
550 let project_path = ProjectPath {
551 worktree_id,
552 path: path_in_worktree.clone(),
553 };
554 let buffer_task =
555 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
556 let rope_task = cx.spawn(async move |cx| {
557 let buffer = buffer_task.await?;
558 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
559 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
560 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
561 })?;
562 anyhow::Ok((project_entry_id, rope))
563 });
564 // Build a string from the rope on a background thread.
565 cx.background_spawn(async move {
566 let (project_entry_id, rope) = rope_task.await?;
567 anyhow::Ok(RulesFileContext {
568 path_in_worktree,
569 text: rope.to_string().trim().to_string(),
570 project_entry_id: project_entry_id.to_usize(),
571 })
572 })
573 })
574 }
575
576 fn handle_thread_title_updated(
577 &mut self,
578 thread: Entity<Thread>,
579 _: &TitleUpdated,
580 cx: &mut Context<Self>,
581 ) {
582 let session_id = thread.read(cx).id();
583 let Some(session) = self.sessions.get(session_id) else {
584 return;
585 };
586 let thread = thread.downgrade();
587 let acp_thread = session.acp_thread.downgrade();
588 cx.spawn(async move |_, cx| {
589 let title = thread.read_with(cx, |thread, _| thread.title())?;
590 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
591 task.await
592 })
593 .detach_and_log_err(cx);
594 }
595
596 fn handle_thread_token_usage_updated(
597 &mut self,
598 thread: Entity<Thread>,
599 usage: &TokenUsageUpdated,
600 cx: &mut Context<Self>,
601 ) {
602 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
603 return;
604 };
605 session.acp_thread.update(cx, |acp_thread, cx| {
606 acp_thread.update_token_usage(usage.0.clone(), cx);
607 });
608 }
609
610 fn handle_project_event(
611 &mut self,
612 _project: Entity<Project>,
613 event: &project::Event,
614 _cx: &mut Context<Self>,
615 ) {
616 match event {
617 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
618 self.project_context_needs_refresh.send(()).ok();
619 }
620 project::Event::WorktreeUpdatedEntries(_, items) => {
621 if items.iter().any(|(path, _, _)| {
622 RULES_FILE_NAMES
623 .iter()
624 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
625 }) {
626 self.project_context_needs_refresh.send(()).ok();
627 }
628 }
629 _ => {}
630 }
631 }
632
633 fn handle_prompts_updated_event(
634 &mut self,
635 _prompt_store: Entity<PromptStore>,
636 _event: &prompt_store::PromptsUpdatedEvent,
637 _cx: &mut Context<Self>,
638 ) {
639 self.project_context_needs_refresh.send(()).ok();
640 }
641
642 fn handle_models_updated_event(
643 &mut self,
644 _registry: Entity<LanguageModelRegistry>,
645 _event: &language_model::Event,
646 cx: &mut Context<Self>,
647 ) {
648 self.models.refresh_list(cx);
649
650 let registry = LanguageModelRegistry::read_global(cx);
651 let default_model = registry.default_model().map(|m| m.model);
652 let summarization_model = registry.thread_summary_model().map(|m| m.model);
653
654 for session in self.sessions.values_mut() {
655 session.thread.update(cx, |thread, cx| {
656 if thread.model().is_none()
657 && let Some(model) = default_model.clone()
658 {
659 thread.set_model(model, cx);
660 cx.notify();
661 }
662 thread.set_summarization_model(summarization_model.clone(), cx);
663 });
664 }
665 }
666
667 fn handle_context_server_store_updated(
668 &mut self,
669 _store: Entity<project::context_server_store::ContextServerStore>,
670 _event: &project::context_server_store::ServerStatusChangedEvent,
671 cx: &mut Context<Self>,
672 ) {
673 self.update_available_commands(cx);
674 }
675
676 fn handle_context_server_registry_event(
677 &mut self,
678 _registry: Entity<ContextServerRegistry>,
679 event: &ContextServerRegistryEvent,
680 cx: &mut Context<Self>,
681 ) {
682 match event {
683 ContextServerRegistryEvent::ToolsChanged => {}
684 ContextServerRegistryEvent::PromptsChanged => {
685 self.update_available_commands(cx);
686 }
687 }
688 }
689
690 fn update_available_commands(&self, cx: &mut Context<Self>) {
691 let available_commands = self.build_available_commands(cx);
692 for session in self.sessions.values() {
693 session.acp_thread.update(cx, |thread, cx| {
694 thread
695 .handle_session_update(
696 acp::SessionUpdate::AvailableCommandsUpdate(
697 acp::AvailableCommandsUpdate::new(available_commands.clone()),
698 ),
699 cx,
700 )
701 .log_err();
702 });
703 }
704 }
705
706 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
707 let registry = self.context_server_registry.read(cx);
708
709 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
710 for context_server_prompt in registry.prompts() {
711 *prompt_name_counts
712 .entry(context_server_prompt.prompt.name.as_str())
713 .or_insert(0) += 1;
714 }
715
716 registry
717 .prompts()
718 .flat_map(|context_server_prompt| {
719 let prompt = &context_server_prompt.prompt;
720
721 let should_prefix = prompt_name_counts
722 .get(prompt.name.as_str())
723 .copied()
724 .unwrap_or(0)
725 > 1;
726
727 let name = if should_prefix {
728 format!("{}.{}", context_server_prompt.server_id, prompt.name)
729 } else {
730 prompt.name.clone()
731 };
732
733 let mut command = acp::AvailableCommand::new(
734 name,
735 prompt.description.clone().unwrap_or_default(),
736 );
737
738 match prompt.arguments.as_deref() {
739 Some([arg]) => {
740 let hint = format!("<{}>", arg.name);
741
742 command = command.input(acp::AvailableCommandInput::Unstructured(
743 acp::UnstructuredCommandInput::new(hint),
744 ));
745 }
746 Some([]) | None => {}
747 Some(_) => {
748 // skip >1 argument commands since we don't support them yet
749 return None;
750 }
751 }
752
753 Some(command)
754 })
755 .collect()
756 }
757
758 pub fn load_thread(
759 &mut self,
760 id: acp::SessionId,
761 cx: &mut Context<Self>,
762 ) -> Task<Result<Entity<Thread>>> {
763 let database_future = ThreadsDatabase::connect(cx);
764 cx.spawn(async move |this, cx| {
765 let database = database_future.await.map_err(|err| anyhow!(err))?;
766 let db_thread = database
767 .load_thread(id.clone())
768 .await?
769 .with_context(|| format!("no thread found with ID: {id:?}"))?;
770
771 this.update(cx, |this, cx| {
772 let summarization_model = LanguageModelRegistry::read_global(cx)
773 .thread_summary_model()
774 .map(|c| c.model);
775
776 cx.new(|cx| {
777 let mut thread = Thread::from_db(
778 id.clone(),
779 db_thread,
780 this.project.clone(),
781 this.project_context.clone(),
782 this.context_server_registry.clone(),
783 this.templates.clone(),
784 cx,
785 );
786 thread.set_summarization_model(summarization_model, cx);
787 thread
788 })
789 })
790 })
791 }
792
793 pub fn open_thread(
794 &mut self,
795 id: acp::SessionId,
796 cx: &mut Context<Self>,
797 ) -> Task<Result<Entity<AcpThread>>> {
798 if let Some(session) = self.sessions.get(&id) {
799 return Task::ready(Ok(session.acp_thread.clone()));
800 }
801
802 let task = self.load_thread(id, cx);
803 cx.spawn(async move |this, cx| {
804 let thread = task.await?;
805 let acp_thread =
806 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
807 let events = thread.update(cx, |thread, cx| thread.replay(cx));
808 cx.update(|cx| {
809 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
810 })
811 .await?;
812 Ok(acp_thread)
813 })
814 }
815
816 pub fn thread_summary(
817 &mut self,
818 id: acp::SessionId,
819 cx: &mut Context<Self>,
820 ) -> Task<Result<SharedString>> {
821 let thread = self.open_thread(id.clone(), cx);
822 cx.spawn(async move |this, cx| {
823 let acp_thread = thread.await?;
824 let result = this
825 .update(cx, |this, cx| {
826 this.sessions
827 .get(&id)
828 .unwrap()
829 .thread
830 .update(cx, |thread, cx| thread.summary(cx))
831 })?
832 .await
833 .context("Failed to generate summary")?;
834 drop(acp_thread);
835 Ok(result)
836 })
837 }
838
839 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
840 if thread.read(cx).is_empty() {
841 return;
842 }
843
844 let database_future = ThreadsDatabase::connect(cx);
845 let (id, db_thread) =
846 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
847 let Some(session) = self.sessions.get_mut(&id) else {
848 return;
849 };
850 let thread_store = self.thread_store.clone();
851 session.pending_save = cx.spawn(async move |_, cx| {
852 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
853 return;
854 };
855 let db_thread = db_thread.await;
856 database.save_thread(id, db_thread).await.log_err();
857 thread_store.update(cx, |store, cx| store.reload(cx));
858 });
859 }
860
861 fn send_mcp_prompt(
862 &self,
863 message_id: UserMessageId,
864 session_id: agent_client_protocol::SessionId,
865 prompt_name: String,
866 server_id: ContextServerId,
867 arguments: HashMap<String, String>,
868 original_content: Vec<acp::ContentBlock>,
869 cx: &mut Context<Self>,
870 ) -> Task<Result<acp::PromptResponse>> {
871 let server_store = self.context_server_registry.read(cx).server_store().clone();
872 let path_style = self.project.read(cx).path_style(cx);
873
874 cx.spawn(async move |this, cx| {
875 let prompt =
876 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
877
878 let (acp_thread, thread) = this.update(cx, |this, _cx| {
879 let session = this
880 .sessions
881 .get(&session_id)
882 .context("Failed to get session")?;
883 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
884 })??;
885
886 let mut last_is_user = true;
887
888 thread.update(cx, |thread, cx| {
889 thread.push_acp_user_block(
890 message_id,
891 original_content.into_iter().skip(1),
892 path_style,
893 cx,
894 );
895 });
896
897 for message in prompt.messages {
898 let context_server::types::PromptMessage { role, content } = message;
899 let block = mcp_message_content_to_acp_content_block(content);
900
901 match role {
902 context_server::types::Role::User => {
903 let id = acp_thread::UserMessageId::new();
904
905 acp_thread.update(cx, |acp_thread, cx| {
906 acp_thread.push_user_content_block_with_indent(
907 Some(id.clone()),
908 block.clone(),
909 true,
910 cx,
911 );
912 });
913
914 thread.update(cx, |thread, cx| {
915 thread.push_acp_user_block(id, [block], path_style, cx);
916 });
917 }
918 context_server::types::Role::Assistant => {
919 acp_thread.update(cx, |acp_thread, cx| {
920 acp_thread.push_assistant_content_block_with_indent(
921 block.clone(),
922 false,
923 true,
924 cx,
925 );
926 });
927
928 thread.update(cx, |thread, cx| {
929 thread.push_acp_agent_block(block, cx);
930 });
931 }
932 }
933
934 last_is_user = role == context_server::types::Role::User;
935 }
936
937 let response_stream = thread.update(cx, |thread, cx| {
938 if last_is_user {
939 thread.send_existing(cx)
940 } else {
941 // Resume if MCP prompt did not end with a user message
942 thread.resume(cx)
943 }
944 })?;
945
946 cx.update(|cx| {
947 NativeAgentConnection::handle_thread_events(
948 response_stream,
949 acp_thread.downgrade(),
950 cx,
951 )
952 })
953 .await
954 })
955 }
956}
957
/// Wrapper struct that implements the AgentConnection trait.
///
/// Holds a handle to the `NativeAgent` entity; cloning the connection clones
/// that handle.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
961
962impl NativeAgentConnection {
963 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
964 self.0
965 .read(cx)
966 .sessions
967 .get(session_id)
968 .map(|session| session.thread.clone())
969 }
970
971 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
972 self.0.update(cx, |this, cx| this.load_thread(id, cx))
973 }
974
975 fn run_turn(
976 &self,
977 session_id: acp::SessionId,
978 cx: &mut App,
979 f: impl 'static
980 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
981 ) -> Task<Result<acp::PromptResponse>> {
982 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
983 agent
984 .sessions
985 .get_mut(&session_id)
986 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
987 }) else {
988 return Task::ready(Err(anyhow!("Session not found")));
989 };
990 log::debug!("Found session for: {}", session_id);
991
992 let response_stream = match f(thread, cx) {
993 Ok(stream) => stream,
994 Err(err) => return Task::ready(Err(err)),
995 };
996 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
997 }
998
999 fn handle_thread_events(
1000 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1001 acp_thread: WeakEntity<AcpThread>,
1002 cx: &App,
1003 ) -> Task<Result<acp::PromptResponse>> {
1004 cx.spawn(async move |cx| {
1005 // Handle response stream and forward to session.acp_thread
1006 while let Some(result) = events.next().await {
1007 match result {
1008 Ok(event) => {
1009 log::trace!("Received completion event: {:?}", event);
1010
1011 match event {
1012 ThreadEvent::UserMessage(message) => {
1013 acp_thread.update(cx, |thread, cx| {
1014 for content in message.content {
1015 thread.push_user_content_block(
1016 Some(message.id.clone()),
1017 content.into(),
1018 cx,
1019 );
1020 }
1021 })?;
1022 }
1023 ThreadEvent::AgentText(text) => {
1024 acp_thread.update(cx, |thread, cx| {
1025 thread.push_assistant_content_block(text.into(), false, cx)
1026 })?;
1027 }
1028 ThreadEvent::AgentThinking(text) => {
1029 acp_thread.update(cx, |thread, cx| {
1030 thread.push_assistant_content_block(text.into(), true, cx)
1031 })?;
1032 }
1033 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1034 tool_call,
1035 options,
1036 response,
1037 context: _,
1038 }) => {
1039 let outcome_task = acp_thread.update(cx, |thread, cx| {
1040 thread.request_tool_call_authorization(tool_call, options, cx)
1041 })??;
1042 cx.background_spawn(async move {
1043 if let acp::RequestPermissionOutcome::Selected(
1044 acp::SelectedPermissionOutcome { option_id, .. },
1045 ) = outcome_task.await
1046 {
1047 response
1048 .send(option_id)
1049 .map(|_| anyhow!("authorization receiver was dropped"))
1050 .log_err();
1051 }
1052 })
1053 .detach();
1054 }
1055 ThreadEvent::ToolCall(tool_call) => {
1056 acp_thread.update(cx, |thread, cx| {
1057 thread.upsert_tool_call(tool_call, cx)
1058 })??;
1059 }
1060 ThreadEvent::ToolCallUpdate(update) => {
1061 acp_thread.update(cx, |thread, cx| {
1062 thread.update_tool_call(update, cx)
1063 })??;
1064 }
1065 ThreadEvent::SubagentSpawned(session_id) => {
1066 acp_thread.update(cx, |thread, cx| {
1067 thread.subagent_spawned(session_id, cx);
1068 })?;
1069 }
1070 ThreadEvent::Retry(status) => {
1071 acp_thread.update(cx, |thread, cx| {
1072 thread.update_retry_status(status, cx)
1073 })?;
1074 }
1075 ThreadEvent::Stop(stop_reason) => {
1076 log::debug!("Assistant message complete: {:?}", stop_reason);
1077 return Ok(acp::PromptResponse::new(stop_reason));
1078 }
1079 }
1080 }
1081 Err(e) => {
1082 log::error!("Error in model response stream: {:?}", e);
1083 return Err(e);
1084 }
1085 }
1086 }
1087
1088 log::debug!("Response stream completed");
1089 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1090 })
1091 }
1092}
1093
/// A slash command parsed out of the first text block of a prompt.
///
/// Every field borrows from the prompt text it was parsed from.
struct Command<'a> {
    /// Command name, with any leading `server_id.` qualifier removed.
    prompt_name: &'a str,
    /// Text following the first whitespace after the command; empty when there is none.
    arg_value: &'a str,
    /// The part of the command before its first `.`, when the command contains one.
    explicit_server_id: Option<&'a str>,
}
1099
1100impl<'a> Command<'a> {
1101 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1102 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1103 return None;
1104 };
1105 let text = text_content.text.trim();
1106 let command = text.strip_prefix('/')?;
1107 let (command, arg_value) = command
1108 .split_once(char::is_whitespace)
1109 .unwrap_or((command, ""));
1110
1111 if let Some((server_id, prompt_name)) = command.split_once('.') {
1112 Some(Self {
1113 prompt_name,
1114 arg_value,
1115 explicit_server_id: Some(server_id),
1116 })
1117 } else {
1118 Some(Self {
1119 prompt_name: command,
1120 arg_value,
1121 explicit_server_id: None,
1122 })
1123 }
1124 }
1125}
1126
/// Model selector bound to a single session of a native agent connection.
struct NativeAgentModelSelector {
    /// Session whose thread receives the selected model.
    session_id: acp::SessionId,
    /// Connection used to reach the agent's sessions, model list and file system handle.
    connection: NativeAgentConnection,
}
1131
1132impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1133 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1134 log::debug!("NativeAgentConnection::list_models called");
1135 let list = self.connection.0.read(cx).models.model_list.clone();
1136 Task::ready(if list.is_empty() {
1137 Err(anyhow::anyhow!("No models available"))
1138 } else {
1139 Ok(list)
1140 })
1141 }
1142
1143 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1144 log::debug!(
1145 "Setting model for session {}: {}",
1146 self.session_id,
1147 model_id
1148 );
1149 let Some(thread) = self
1150 .connection
1151 .0
1152 .read(cx)
1153 .sessions
1154 .get(&self.session_id)
1155 .map(|session| session.thread.clone())
1156 else {
1157 return Task::ready(Err(anyhow!("Session not found")));
1158 };
1159
1160 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1161 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1162 };
1163
1164 // We want to reset the effort level when switching models, as the currently-selected effort level may
1165 // not be compatible.
1166 let effort = model
1167 .default_effort_level()
1168 .map(|effort_level| effort_level.value.to_string());
1169
1170 thread.update(cx, |thread, cx| {
1171 thread.set_model(model.clone(), cx);
1172 thread.set_thinking_effort(effort.clone(), cx);
1173 thread.set_thinking_enabled(model.supports_thinking(), cx);
1174 });
1175
1176 update_settings_file(
1177 self.connection.0.read(cx).fs.clone(),
1178 cx,
1179 move |settings, cx| {
1180 let provider = model.provider_id().0.to_string();
1181 let model = model.id().0.to_string();
1182 let enable_thinking = thread.read(cx).thinking_enabled();
1183 settings
1184 .agent
1185 .get_or_insert_default()
1186 .set_model(LanguageModelSelection {
1187 provider: provider.into(),
1188 model,
1189 enable_thinking,
1190 effort,
1191 });
1192 },
1193 );
1194
1195 Task::ready(Ok(()))
1196 }
1197
1198 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1199 let Some(thread) = self
1200 .connection
1201 .0
1202 .read(cx)
1203 .sessions
1204 .get(&self.session_id)
1205 .map(|session| session.thread.clone())
1206 else {
1207 return Task::ready(Err(anyhow!("Session not found")));
1208 };
1209 let Some(model) = thread.read(cx).model() else {
1210 return Task::ready(Err(anyhow!("Model not found")));
1211 };
1212 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1213 else {
1214 return Task::ready(Err(anyhow!("Provider not found")));
1215 };
1216 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1217 model, &provider,
1218 )))
1219 }
1220
1221 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1222 Some(self.connection.0.read(cx).models.watch())
1223 }
1224
1225 fn should_render_footer(&self) -> bool {
1226 true
1227 }
1228}
1229
1230impl acp_thread::AgentConnection for NativeAgentConnection {
1231 fn telemetry_id(&self) -> SharedString {
1232 "zed".into()
1233 }
1234
1235 fn new_session(
1236 self: Rc<Self>,
1237 project: Entity<Project>,
1238 cwd: &Path,
1239 cx: &mut App,
1240 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1241 log::debug!("Creating new thread for project at: {cwd:?}");
1242 Task::ready(Ok(self
1243 .0
1244 .update(cx, |agent, cx| agent.new_session(project, cx))))
1245 }
1246
1247 fn supports_load_session(&self) -> bool {
1248 true
1249 }
1250
1251 fn load_session(
1252 self: Rc<Self>,
1253 session: AgentSessionInfo,
1254 _project: Entity<Project>,
1255 _cwd: &Path,
1256 cx: &mut App,
1257 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1258 self.0
1259 .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1260 }
1261
1262 fn supports_close_session(&self) -> bool {
1263 true
1264 }
1265
1266 fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1267 self.0.update(cx, |agent, _cx| {
1268 agent.sessions.remove(session_id);
1269 });
1270 Task::ready(Ok(()))
1271 }
1272
1273 fn auth_methods(&self) -> &[acp::AuthMethod] {
1274 &[] // No auth for in-process
1275 }
1276
1277 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1278 Task::ready(Ok(()))
1279 }
1280
1281 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1282 Some(Rc::new(NativeAgentModelSelector {
1283 session_id: session_id.clone(),
1284 connection: self.clone(),
1285 }) as Rc<dyn AgentModelSelector>)
1286 }
1287
1288 fn prompt(
1289 &self,
1290 id: Option<acp_thread::UserMessageId>,
1291 params: acp::PromptRequest,
1292 cx: &mut App,
1293 ) -> Task<Result<acp::PromptResponse>> {
1294 let id = id.expect("UserMessageId is required");
1295 let session_id = params.session_id.clone();
1296 log::info!("Received prompt request for session: {}", session_id);
1297 log::debug!("Prompt blocks count: {}", params.prompt.len());
1298
1299 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1300 let registry = self.0.read(cx).context_server_registry.read(cx);
1301
1302 let explicit_server_id = parsed_command
1303 .explicit_server_id
1304 .map(|server_id| ContextServerId(server_id.into()));
1305
1306 if let Some(prompt) =
1307 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1308 {
1309 let arguments = if !parsed_command.arg_value.is_empty()
1310 && let Some(arg_name) = prompt
1311 .prompt
1312 .arguments
1313 .as_ref()
1314 .and_then(|args| args.first())
1315 .map(|arg| arg.name.clone())
1316 {
1317 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1318 } else {
1319 Default::default()
1320 };
1321
1322 let prompt_name = prompt.prompt.name.clone();
1323 let server_id = prompt.server_id.clone();
1324
1325 return self.0.update(cx, |agent, cx| {
1326 agent.send_mcp_prompt(
1327 id,
1328 session_id.clone(),
1329 prompt_name,
1330 server_id,
1331 arguments,
1332 params.prompt,
1333 cx,
1334 )
1335 });
1336 };
1337 };
1338
1339 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1340
1341 self.run_turn(session_id, cx, move |thread, cx| {
1342 let content: Vec<UserMessageContent> = params
1343 .prompt
1344 .into_iter()
1345 .map(|block| UserMessageContent::from_content_block(block, path_style))
1346 .collect::<Vec<_>>();
1347 log::debug!("Converted prompt to message: {} chars", content.len());
1348 log::debug!("Message id: {:?}", id);
1349 log::debug!("Message content: {:?}", content);
1350
1351 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1352 })
1353 }
1354
1355 fn retry(
1356 &self,
1357 session_id: &acp::SessionId,
1358 _cx: &App,
1359 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1360 Some(Rc::new(NativeAgentSessionRetry {
1361 connection: self.clone(),
1362 session_id: session_id.clone(),
1363 }) as _)
1364 }
1365
1366 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1367 log::info!("Cancelling on session: {}", session_id);
1368 self.0.update(cx, |agent, cx| {
1369 if let Some(session) = agent.sessions.get(session_id) {
1370 session
1371 .thread
1372 .update(cx, |thread, cx| thread.cancel(cx))
1373 .detach();
1374 }
1375 });
1376 }
1377
1378 fn truncate(
1379 &self,
1380 session_id: &agent_client_protocol::SessionId,
1381 cx: &App,
1382 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1383 self.0.read_with(cx, |agent, _cx| {
1384 agent.sessions.get(session_id).map(|session| {
1385 Rc::new(NativeAgentSessionTruncate {
1386 thread: session.thread.clone(),
1387 acp_thread: session.acp_thread.downgrade(),
1388 }) as _
1389 })
1390 })
1391 }
1392
1393 fn set_title(
1394 &self,
1395 session_id: &acp::SessionId,
1396 cx: &App,
1397 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1398 self.0.read_with(cx, |agent, _cx| {
1399 agent
1400 .sessions
1401 .get(session_id)
1402 .filter(|s| !s.thread.read(cx).is_subagent())
1403 .map(|session| {
1404 Rc::new(NativeAgentSessionSetTitle {
1405 thread: session.thread.clone(),
1406 }) as _
1407 })
1408 })
1409 }
1410
1411 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1412 let thread_store = self.0.read(cx).thread_store.clone();
1413 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1414 }
1415
1416 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1417 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1418 }
1419
1420 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1421 self
1422 }
1423}
1424
1425impl acp_thread::AgentTelemetry for NativeAgentConnection {
1426 fn thread_data(
1427 &self,
1428 session_id: &acp::SessionId,
1429 cx: &mut App,
1430 ) -> Task<Result<serde_json::Value>> {
1431 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1432 return Task::ready(Err(anyhow!("Session not found")));
1433 };
1434
1435 let task = session.thread.read(cx).to_db(cx);
1436 cx.background_spawn(async move {
1437 serde_json::to_value(task.await).context("Failed to serialize thread")
1438 })
1439 }
1440}
1441
/// Session list backed by a [`ThreadStore`] entity.
pub struct NativeAgentSessionList {
    /// Store whose entries are exposed as sessions.
    thread_store: Entity<ThreadStore>,
    /// Sending half of the update channel; receives a `Refresh` on `notify_refresh`.
    updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
    /// Receiving half of the update channel; clones are handed out by `watch`.
    updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
    /// Observation of the thread store that pushes a `Refresh` when the store notifies.
    _subscription: Subscription,
}
1448
1449impl NativeAgentSessionList {
1450 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1451 let (tx, rx) = smol::channel::unbounded();
1452 let this_tx = tx.clone();
1453 let subscription = cx.observe(&thread_store, move |_, _| {
1454 this_tx
1455 .try_send(acp_thread::SessionListUpdate::Refresh)
1456 .ok();
1457 });
1458 Self {
1459 thread_store,
1460 updates_tx: tx,
1461 updates_rx: rx,
1462 _subscription: subscription,
1463 }
1464 }
1465
1466 fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1467 AgentSessionInfo {
1468 session_id: entry.id,
1469 cwd: None,
1470 title: Some(entry.title),
1471 updated_at: Some(entry.updated_at),
1472 meta: None,
1473 }
1474 }
1475
1476 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1477 &self.thread_store
1478 }
1479}
1480
1481impl AgentSessionList for NativeAgentSessionList {
1482 fn list_sessions(
1483 &self,
1484 _request: AgentSessionListRequest,
1485 cx: &mut App,
1486 ) -> Task<Result<AgentSessionListResponse>> {
1487 let sessions = self
1488 .thread_store
1489 .read(cx)
1490 .entries()
1491 .map(Self::to_session_info)
1492 .collect();
1493 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1494 }
1495
1496 fn supports_delete(&self) -> bool {
1497 true
1498 }
1499
1500 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1501 self.thread_store
1502 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1503 }
1504
1505 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1506 self.thread_store
1507 .update(cx, |store, cx| store.delete_threads(cx))
1508 }
1509
1510 fn watch(
1511 &self,
1512 _cx: &mut App,
1513 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1514 Some(self.updates_rx.clone())
1515 }
1516
1517 fn notify_refresh(&self) {
1518 self.updates_tx
1519 .try_send(acp_thread::SessionListUpdate::Refresh)
1520 .ok();
1521 }
1522
1523 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1524 self
1525 }
1526}
1527
/// Truncate handle for a single session.
struct NativeAgentSessionTruncate {
    /// Thread that is truncated.
    thread: Entity<Thread>,
    /// ACP thread whose token usage is refreshed after truncation; held weakly.
    acp_thread: WeakEntity<AcpThread>,
}
1532
1533impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1534 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1535 match self.thread.update(cx, |thread, cx| {
1536 thread.truncate(message_id.clone(), cx)?;
1537 Ok(thread.latest_token_usage())
1538 }) {
1539 Ok(usage) => {
1540 self.acp_thread
1541 .update(cx, |thread, cx| {
1542 thread.update_token_usage(usage, cx);
1543 })
1544 .ok();
1545 Task::ready(Ok(()))
1546 }
1547 Err(error) => Task::ready(Err(error)),
1548 }
1549 }
1550}
1551
/// Retry handle for a single session.
struct NativeAgentSessionRetry {
    /// Connection used to run the retried turn.
    connection: NativeAgentConnection,
    /// Session whose thread is resumed.
    session_id: acp::SessionId,
}
1556
1557impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1558 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1559 self.connection
1560 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1561 thread.update(cx, |thread, cx| thread.resume(cx))
1562 })
1563 }
1564}
1565
/// Set-title handle for a single session.
struct NativeAgentSessionSetTitle {
    /// Thread whose title is changed.
    thread: Entity<Thread>,
}
1569
1570impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1571 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1572 self.thread
1573 .update(cx, |thread, cx| thread.set_title(title, cx));
1574 Task::ready(Ok(()))
1575 }
1576}
1577
/// Thread environment that creates terminals and subagents for a native agent thread.
pub struct NativeThreadEnvironment {
    /// Agent on which subagent sessions are registered; held weakly.
    agent: WeakEntity<NativeAgent>,
    /// ACP thread on which terminals are created and released; held weakly.
    acp_thread: WeakEntity<AcpThread>,
}
1582
1583impl NativeThreadEnvironment {
1584 pub(crate) fn create_subagent_thread(
1585 agent: WeakEntity<NativeAgent>,
1586 parent_thread_entity: Entity<Thread>,
1587 label: String,
1588 initial_prompt: String,
1589 timeout: Option<Duration>,
1590 cx: &mut App,
1591 ) -> Result<Rc<dyn SubagentHandle>> {
1592 let parent_thread = parent_thread_entity.read(cx);
1593 let current_depth = parent_thread.depth();
1594
1595 if current_depth >= MAX_SUBAGENT_DEPTH {
1596 return Err(anyhow!(
1597 "Maximum subagent depth ({}) reached",
1598 MAX_SUBAGENT_DEPTH
1599 ));
1600 }
1601
1602 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1603 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1604 thread.set_title(label.into(), cx);
1605 thread
1606 });
1607
1608 let session_id = subagent_thread.read(cx).id().clone();
1609
1610 let acp_thread = agent.update(cx, |agent, cx| {
1611 agent.register_session(subagent_thread.clone(), cx)
1612 })?;
1613
1614 parent_thread_entity.update(cx, |parent_thread, _cx| {
1615 parent_thread.register_running_subagent(subagent_thread.downgrade())
1616 });
1617
1618 let task = acp_thread.update(cx, |agent, cx| agent.send(vec![initial_prompt.into()], cx));
1619
1620 let timeout_timer = timeout.map(|d| cx.background_executor().timer(d));
1621 let wait_for_prompt_to_complete = cx
1622 .background_spawn(async move {
1623 if let Some(timer) = timeout_timer {
1624 futures::select! {
1625 _ = timer.fuse() => SubagentInitialPromptResult::Timeout,
1626 response = task.fuse() => {
1627 let response = response.log_err().flatten();
1628 if response.is_some_and(|response| {
1629 response.stop_reason == acp::StopReason::Cancelled
1630 })
1631 {
1632 SubagentInitialPromptResult::Cancelled
1633 } else {
1634 SubagentInitialPromptResult::Completed
1635 }
1636 },
1637 }
1638 } else {
1639 let response = task.await.log_err().flatten();
1640 if response
1641 .is_some_and(|response| response.stop_reason == acp::StopReason::Cancelled)
1642 {
1643 SubagentInitialPromptResult::Cancelled
1644 } else {
1645 SubagentInitialPromptResult::Completed
1646 }
1647 }
1648 })
1649 .shared();
1650
1651 Ok(Rc::new(NativeSubagentHandle {
1652 session_id,
1653 subagent_thread,
1654 parent_thread: parent_thread_entity.downgrade(),
1655 wait_for_prompt_to_complete,
1656 }) as _)
1657 }
1658}
1659
1660impl ThreadEnvironment for NativeThreadEnvironment {
1661 fn create_terminal(
1662 &self,
1663 command: String,
1664 cwd: Option<PathBuf>,
1665 output_byte_limit: Option<u64>,
1666 cx: &mut AsyncApp,
1667 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1668 let task = self.acp_thread.update(cx, |thread, cx| {
1669 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1670 });
1671
1672 let acp_thread = self.acp_thread.clone();
1673 cx.spawn(async move |cx| {
1674 let terminal = task?.await?;
1675
1676 let (drop_tx, drop_rx) = oneshot::channel();
1677 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1678
1679 cx.spawn(async move |cx| {
1680 drop_rx.await.ok();
1681 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1682 })
1683 .detach();
1684
1685 let handle = AcpTerminalHandle {
1686 terminal,
1687 _drop_tx: Some(drop_tx),
1688 };
1689
1690 Ok(Rc::new(handle) as _)
1691 })
1692 }
1693
1694 fn create_subagent(
1695 &self,
1696 parent_thread_entity: Entity<Thread>,
1697 label: String,
1698 initial_prompt: String,
1699 timeout: Option<Duration>,
1700 cx: &mut App,
1701 ) -> Result<Rc<dyn SubagentHandle>> {
1702 Self::create_subagent_thread(
1703 self.agent.clone(),
1704 parent_thread_entity,
1705 label,
1706 initial_prompt,
1707 timeout,
1708 cx,
1709 )
1710 }
1711}
1712
/// Outcome of waiting for a subagent's initial prompt.
#[derive(Debug, Clone, Copy)]
enum SubagentInitialPromptResult {
    /// The prompt finished without a `Cancelled` stop reason (failed prompts,
    /// which are logged, also end up here).
    Completed,
    /// The timeout timer fired before the prompt finished.
    Timeout,
    /// The prompt finished with `acp::StopReason::Cancelled`.
    Cancelled,
}
1719
/// Handle to a subagent created by `NativeThreadEnvironment::create_subagent_thread`.
pub struct NativeSubagentHandle {
    /// Session id of the subagent thread.
    session_id: acp::SessionId,
    /// Parent thread on which the subagent is registered as running; held weakly.
    parent_thread: WeakEntity<Thread>,
    /// The subagent's own thread, from which the final message is read.
    subagent_thread: Entity<Thread>,
    /// Shared task resolving to the outcome of the subagent's initial prompt.
    wait_for_prompt_to_complete: Shared<Task<SubagentInitialPromptResult>>,
}
1726
1727impl SubagentHandle for NativeSubagentHandle {
1728 fn id(&self) -> acp::SessionId {
1729 self.session_id.clone()
1730 }
1731
1732 fn wait_for_output(&self, cx: &AsyncApp) -> Task<Result<String>> {
1733 let thread = self.subagent_thread.clone();
1734 let wait_for_prompt = self.wait_for_prompt_to_complete.clone();
1735
1736 let subagent_session_id = self.session_id.clone();
1737 let parent_thread = self.parent_thread.clone();
1738
1739 cx.spawn(async move |cx| {
1740 match wait_for_prompt.await {
1741 SubagentInitialPromptResult::Completed => {}
1742 SubagentInitialPromptResult::Timeout => {
1743 return Err(anyhow!("The time to complete the task was exceeded."));
1744 }
1745 SubagentInitialPromptResult::Cancelled => return Err(anyhow!("User cancelled")),
1746 };
1747
1748 let result = thread.read_with(cx, |thread, _cx| {
1749 thread
1750 .last_message()
1751 .map(|m| m.to_markdown())
1752 .context("No response from subagent")
1753 });
1754
1755 parent_thread
1756 .update(cx, |parent_thread, cx| {
1757 parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1758 })
1759 .ok();
1760
1761 result
1762 })
1763 }
1764}
1765
/// Terminal handle wrapping an `acp_thread::Terminal` entity.
pub struct AcpTerminalHandle {
    /// The wrapped terminal.
    terminal: Entity<acp_thread::Terminal>,
    /// Never sent on; kept only so that dropping the handle drops the sender.
    /// In `create_terminal`, the paired receiver resolving triggers the terminal's release.
    _drop_tx: Option<oneshot::Sender<()>>,
}
1770
1771impl TerminalHandle for AcpTerminalHandle {
1772 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1773 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1774 }
1775
1776 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1777 Ok(self
1778 .terminal
1779 .read_with(cx, |term, _cx| term.wait_for_exit()))
1780 }
1781
1782 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1783 Ok(self
1784 .terminal
1785 .read_with(cx, |term, cx| term.current_output(cx)))
1786 }
1787
1788 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1789 cx.update(|cx| {
1790 self.terminal.update(cx, |terminal, cx| {
1791 terminal.kill(cx);
1792 });
1793 });
1794 Ok(())
1795 }
1796
1797 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1798 Ok(self
1799 .terminal
1800 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1801 }
1802}
1803
1804#[cfg(test)]
1805mod internal_tests {
1806 use super::*;
1807 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1808 use fs::FakeFs;
1809 use gpui::TestAppContext;
1810 use indoc::formatdoc;
1811 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1812 use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1813 use serde_json::json;
1814 use settings::SettingsStore;
1815 use util::{path, rel_path::rel_path};
1816
1817 #[gpui::test]
1818 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1819 init_test(cx);
1820 let fs = FakeFs::new(cx.executor());
1821 fs.insert_tree(
1822 "/",
1823 json!({
1824 "a": {}
1825 }),
1826 )
1827 .await;
1828 let project = Project::test(fs.clone(), [], cx).await;
1829 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1830 let agent = NativeAgent::new(
1831 project.clone(),
1832 thread_store,
1833 Templates::new(),
1834 None,
1835 fs.clone(),
1836 &mut cx.to_async(),
1837 )
1838 .await
1839 .unwrap();
1840 agent.read_with(cx, |agent, cx| {
1841 assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1842 });
1843
1844 let worktree = project
1845 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1846 .await
1847 .unwrap();
1848 cx.run_until_parked();
1849 agent.read_with(cx, |agent, cx| {
1850 assert_eq!(
1851 agent.project_context.read(cx).worktrees,
1852 vec![WorktreeContext {
1853 root_name: "a".into(),
1854 abs_path: Path::new("/a").into(),
1855 rules_file: None
1856 }]
1857 )
1858 });
1859
1860 // Creating `/a/.rules` updates the project context.
1861 fs.insert_file("/a/.rules", Vec::new()).await;
1862 cx.run_until_parked();
1863 agent.read_with(cx, |agent, cx| {
1864 let rules_entry = worktree
1865 .read(cx)
1866 .entry_for_path(rel_path(".rules"))
1867 .unwrap();
1868 assert_eq!(
1869 agent.project_context.read(cx).worktrees,
1870 vec![WorktreeContext {
1871 root_name: "a".into(),
1872 abs_path: Path::new("/a").into(),
1873 rules_file: Some(RulesFileContext {
1874 path_in_worktree: rel_path(".rules").into(),
1875 text: "".into(),
1876 project_entry_id: rules_entry.id.to_usize()
1877 })
1878 }]
1879 )
1880 });
1881 }
1882
1883 #[gpui::test]
1884 async fn test_listing_models(cx: &mut TestAppContext) {
1885 init_test(cx);
1886 let fs = FakeFs::new(cx.executor());
1887 fs.insert_tree("/", json!({ "a": {} })).await;
1888 let project = Project::test(fs.clone(), [], cx).await;
1889 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1890 let connection = NativeAgentConnection(
1891 NativeAgent::new(
1892 project.clone(),
1893 thread_store,
1894 Templates::new(),
1895 None,
1896 fs.clone(),
1897 &mut cx.to_async(),
1898 )
1899 .await
1900 .unwrap(),
1901 );
1902
1903 // Create a thread/session
1904 let acp_thread = cx
1905 .update(|cx| {
1906 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
1907 })
1908 .await
1909 .unwrap();
1910
1911 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1912
1913 let models = cx
1914 .update(|cx| {
1915 connection
1916 .model_selector(&session_id)
1917 .unwrap()
1918 .list_models(cx)
1919 })
1920 .await
1921 .unwrap();
1922
1923 let acp_thread::AgentModelList::Grouped(models) = models else {
1924 panic!("Unexpected model group");
1925 };
1926 assert_eq!(
1927 models,
1928 IndexMap::from_iter([(
1929 AgentModelGroupName("Fake".into()),
1930 vec![AgentModelInfo {
1931 id: acp::ModelId::new("fake/fake"),
1932 name: "Fake".into(),
1933 description: None,
1934 icon: Some(acp_thread::AgentModelIcon::Named(
1935 ui::IconName::ZedAssistant
1936 )),
1937 is_latest: false,
1938 cost: None,
1939 }]
1940 )])
1941 );
1942 }
1943
1944 #[gpui::test]
1945 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
1946 init_test(cx);
1947 let fs = FakeFs::new(cx.executor());
1948 fs.create_dir(paths::settings_file().parent().unwrap())
1949 .await
1950 .unwrap();
1951 fs.insert_file(
1952 paths::settings_file(),
1953 json!({
1954 "agent": {
1955 "default_model": {
1956 "provider": "foo",
1957 "model": "bar"
1958 }
1959 }
1960 })
1961 .to_string()
1962 .into_bytes(),
1963 )
1964 .await;
1965 let project = Project::test(fs.clone(), [], cx).await;
1966
1967 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1968
1969 // Create the agent and connection
1970 let agent = NativeAgent::new(
1971 project.clone(),
1972 thread_store,
1973 Templates::new(),
1974 None,
1975 fs.clone(),
1976 &mut cx.to_async(),
1977 )
1978 .await
1979 .unwrap();
1980 let connection = NativeAgentConnection(agent.clone());
1981
1982 // Create a thread/session
1983 let acp_thread = cx
1984 .update(|cx| {
1985 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
1986 })
1987 .await
1988 .unwrap();
1989
1990 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1991
1992 // Select a model
1993 let selector = connection.model_selector(&session_id).unwrap();
1994 let model_id = acp::ModelId::new("fake/fake");
1995 cx.update(|cx| selector.select_model(model_id.clone(), cx))
1996 .await
1997 .unwrap();
1998
1999 // Verify the thread has the selected model
2000 agent.read_with(cx, |agent, _| {
2001 let session = agent.sessions.get(&session_id).unwrap();
2002 session.thread.read_with(cx, |thread, _| {
2003 assert_eq!(thread.model().unwrap().id().0, "fake");
2004 });
2005 });
2006
2007 cx.run_until_parked();
2008
2009 // Verify settings file was updated
2010 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2011 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2012
2013 // Check that the agent settings contain the selected model
2014 assert_eq!(
2015 settings_json["agent"]["default_model"]["model"],
2016 json!("fake")
2017 );
2018 assert_eq!(
2019 settings_json["agent"]["default_model"]["provider"],
2020 json!("fake")
2021 );
2022
2023 // Register a thinking model and select it.
2024 cx.update(|cx| {
2025 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2026 "fake-corp",
2027 "fake-thinking",
2028 "Fake Thinking",
2029 true,
2030 ));
2031 let thinking_provider = Arc::new(
2032 FakeLanguageModelProvider::new(
2033 LanguageModelProviderId::from("fake-corp".to_string()),
2034 LanguageModelProviderName::from("Fake Corp".to_string()),
2035 )
2036 .with_models(vec![thinking_model]),
2037 );
2038 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2039 registry.register_provider(thinking_provider, cx);
2040 });
2041 });
2042 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2043
2044 let selector = connection.model_selector(&session_id).unwrap();
2045 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2046 .await
2047 .unwrap();
2048 cx.run_until_parked();
2049
2050 // Verify enable_thinking was written to settings as true.
2051 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2052 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2053 assert_eq!(
2054 settings_json["agent"]["default_model"]["enable_thinking"],
2055 json!(true),
2056 "selecting a thinking model should persist enable_thinking: true to settings"
2057 );
2058 }
2059
2060 #[gpui::test]
2061 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2062 init_test(cx);
2063 let fs = FakeFs::new(cx.executor());
2064 fs.create_dir(paths::settings_file().parent().unwrap())
2065 .await
2066 .unwrap();
2067 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2068 let project = Project::test(fs.clone(), [], cx).await;
2069
2070 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2071 let agent = NativeAgent::new(
2072 project.clone(),
2073 thread_store,
2074 Templates::new(),
2075 None,
2076 fs.clone(),
2077 &mut cx.to_async(),
2078 )
2079 .await
2080 .unwrap();
2081 let connection = NativeAgentConnection(agent.clone());
2082
2083 let acp_thread = cx
2084 .update(|cx| {
2085 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2086 })
2087 .await
2088 .unwrap();
2089 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2090
2091 // Register a second provider with a thinking model.
2092 cx.update(|cx| {
2093 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2094 "fake-corp",
2095 "fake-thinking",
2096 "Fake Thinking",
2097 true,
2098 ));
2099 let thinking_provider = Arc::new(
2100 FakeLanguageModelProvider::new(
2101 LanguageModelProviderId::from("fake-corp".to_string()),
2102 LanguageModelProviderName::from("Fake Corp".to_string()),
2103 )
2104 .with_models(vec![thinking_model]),
2105 );
2106 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2107 registry.register_provider(thinking_provider, cx);
2108 });
2109 });
2110 // Refresh the agent's model list so it picks up the new provider.
2111 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2112
2113 // Thread starts with thinking_enabled = false (the default).
2114 agent.read_with(cx, |agent, _| {
2115 let session = agent.sessions.get(&session_id).unwrap();
2116 session.thread.read_with(cx, |thread, _| {
2117 assert!(!thread.thinking_enabled(), "thinking defaults to false");
2118 });
2119 });
2120
2121 // Select the thinking model via select_model.
2122 let selector = connection.model_selector(&session_id).unwrap();
2123 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2124 .await
2125 .unwrap();
2126
2127 // select_model should have enabled thinking based on the model's supports_thinking().
2128 agent.read_with(cx, |agent, _| {
2129 let session = agent.sessions.get(&session_id).unwrap();
2130 session.thread.read_with(cx, |thread, _| {
2131 assert!(
2132 thread.thinking_enabled(),
2133 "select_model should enable thinking when model supports it"
2134 );
2135 });
2136 });
2137
2138 // Switch back to the non-thinking model.
2139 let selector = connection.model_selector(&session_id).unwrap();
2140 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2141 .await
2142 .unwrap();
2143
2144 // select_model should have disabled thinking.
2145 agent.read_with(cx, |agent, _| {
2146 let session = agent.sessions.get(&session_id).unwrap();
2147 session.thread.read_with(cx, |thread, _| {
2148 assert!(
2149 !thread.thinking_enabled(),
2150 "select_model should disable thinking when model does not support it"
2151 );
2152 });
2153 });
2154 }
2155
2156 #[gpui::test]
2157 async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
2158 init_test(cx);
2159 let fs = FakeFs::new(cx.executor());
2160 fs.insert_tree("/", json!({ "a": {} })).await;
2161 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2162 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2163 let agent = NativeAgent::new(
2164 project.clone(),
2165 thread_store.clone(),
2166 Templates::new(),
2167 None,
2168 fs.clone(),
2169 &mut cx.to_async(),
2170 )
2171 .await
2172 .unwrap();
2173 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2174
2175 // Register a thinking model.
2176 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2177 "fake-corp",
2178 "fake-thinking",
2179 "Fake Thinking",
2180 true,
2181 ));
2182 let thinking_provider = Arc::new(
2183 FakeLanguageModelProvider::new(
2184 LanguageModelProviderId::from("fake-corp".to_string()),
2185 LanguageModelProviderName::from("Fake Corp".to_string()),
2186 )
2187 .with_models(vec![thinking_model.clone()]),
2188 );
2189 cx.update(|cx| {
2190 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2191 registry.register_provider(thinking_provider, cx);
2192 });
2193 });
2194 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2195
2196 // Create a thread and select the thinking model.
2197 let acp_thread = cx
2198 .update(|cx| {
2199 connection
2200 .clone()
2201 .new_session(project.clone(), Path::new("/a"), cx)
2202 })
2203 .await
2204 .unwrap();
2205 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2206
2207 let selector = connection.model_selector(&session_id).unwrap();
2208 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2209 .await
2210 .unwrap();
2211
2212 // Verify thinking is enabled after selecting the thinking model.
2213 let thread = agent.read_with(cx, |agent, _| {
2214 agent.sessions.get(&session_id).unwrap().thread.clone()
2215 });
2216 thread.read_with(cx, |thread, _| {
2217 assert!(
2218 thread.thinking_enabled(),
2219 "thinking should be enabled after selecting thinking model"
2220 );
2221 });
2222
2223 // Send a message so the thread gets persisted.
2224 let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2225 let send = cx.foreground_executor().spawn(send);
2226 cx.run_until_parked();
2227
2228 thinking_model.send_last_completion_stream_text_chunk("Response.");
2229 thinking_model.end_last_completion_stream();
2230
2231 send.await.unwrap();
2232 cx.run_until_parked();
2233
2234 // Close the session so it can be reloaded from disk.
2235 cx.update(|cx| connection.clone().close_session(&session_id, cx))
2236 .await
2237 .unwrap();
2238 drop(thread);
2239 drop(acp_thread);
2240 agent.read_with(cx, |agent, _| {
2241 assert!(agent.sessions.is_empty());
2242 });
2243
2244 // Reload the thread and verify thinking_enabled is still true.
2245 let reloaded_acp_thread = agent
2246 .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2247 .await
2248 .unwrap();
2249 let reloaded_thread = agent.read_with(cx, |agent, _| {
2250 agent.sessions.get(&session_id).unwrap().thread.clone()
2251 });
2252 reloaded_thread.read_with(cx, |thread, _| {
2253 assert!(
2254 thread.thinking_enabled(),
2255 "thinking_enabled should be preserved when reloading a thread with a thinking model"
2256 );
2257 });
2258
2259 drop(reloaded_acp_thread);
2260 }
2261
2262 #[gpui::test]
2263 async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
2264 init_test(cx);
2265 let fs = FakeFs::new(cx.executor());
2266 fs.insert_tree("/", json!({ "a": {} })).await;
2267 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2268 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2269 let agent = NativeAgent::new(
2270 project.clone(),
2271 thread_store.clone(),
2272 Templates::new(),
2273 None,
2274 fs.clone(),
2275 &mut cx.to_async(),
2276 )
2277 .await
2278 .unwrap();
2279 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2280
2281 // Register a model where id() != name(), like real Anthropic models
2282 // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
2283 let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2284 "fake-corp",
2285 "custom-model-id",
2286 "Custom Model Display Name",
2287 false,
2288 ));
2289 let provider = Arc::new(
2290 FakeLanguageModelProvider::new(
2291 LanguageModelProviderId::from("fake-corp".to_string()),
2292 LanguageModelProviderName::from("Fake Corp".to_string()),
2293 )
2294 .with_models(vec![model.clone()]),
2295 );
2296 cx.update(|cx| {
2297 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2298 registry.register_provider(provider, cx);
2299 });
2300 });
2301 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2302
2303 // Create a thread and select the model.
2304 let acp_thread = cx
2305 .update(|cx| {
2306 connection
2307 .clone()
2308 .new_session(project.clone(), Path::new("/a"), cx)
2309 })
2310 .await
2311 .unwrap();
2312 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2313
2314 let selector = connection.model_selector(&session_id).unwrap();
2315 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
2316 .await
2317 .unwrap();
2318
2319 let thread = agent.read_with(cx, |agent, _| {
2320 agent.sessions.get(&session_id).unwrap().thread.clone()
2321 });
2322 thread.read_with(cx, |thread, _| {
2323 assert_eq!(
2324 thread.model().unwrap().id().0.as_ref(),
2325 "custom-model-id",
2326 "model should be set before persisting"
2327 );
2328 });
2329
2330 // Send a message so the thread gets persisted.
2331 let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2332 let send = cx.foreground_executor().spawn(send);
2333 cx.run_until_parked();
2334
2335 model.send_last_completion_stream_text_chunk("Response.");
2336 model.end_last_completion_stream();
2337
2338 send.await.unwrap();
2339 cx.run_until_parked();
2340
2341 // Close the session so it can be reloaded from disk.
2342 cx.update(|cx| connection.clone().close_session(&session_id, cx))
2343 .await
2344 .unwrap();
2345 drop(thread);
2346 drop(acp_thread);
2347 agent.read_with(cx, |agent, _| {
2348 assert!(agent.sessions.is_empty());
2349 });
2350
2351 // Reload the thread and verify the model was preserved.
2352 let reloaded_acp_thread = agent
2353 .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2354 .await
2355 .unwrap();
2356 let reloaded_thread = agent.read_with(cx, |agent, _| {
2357 agent.sessions.get(&session_id).unwrap().thread.clone()
2358 });
2359 reloaded_thread.read_with(cx, |thread, _| {
2360 let reloaded_model = thread
2361 .model()
2362 .expect("model should be present after reload");
2363 assert_eq!(
2364 reloaded_model.id().0.as_ref(),
2365 "custom-model-id",
2366 "reloaded thread should have the same model, not fall back to the default"
2367 );
2368 });
2369
2370 drop(reloaded_acp_thread);
2371 }
2372
2373 #[gpui::test]
2374 async fn test_save_load_thread(cx: &mut TestAppContext) {
2375 init_test(cx);
2376 let fs = FakeFs::new(cx.executor());
2377 fs.insert_tree(
2378 "/",
2379 json!({
2380 "a": {
2381 "b.md": "Lorem"
2382 }
2383 }),
2384 )
2385 .await;
2386 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2387 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2388 let agent = NativeAgent::new(
2389 project.clone(),
2390 thread_store.clone(),
2391 Templates::new(),
2392 None,
2393 fs.clone(),
2394 &mut cx.to_async(),
2395 )
2396 .await
2397 .unwrap();
2398 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2399
2400 let acp_thread = cx
2401 .update(|cx| {
2402 connection
2403 .clone()
2404 .new_session(project.clone(), Path::new(""), cx)
2405 })
2406 .await
2407 .unwrap();
2408 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2409 let thread = agent.read_with(cx, |agent, _| {
2410 agent.sessions.get(&session_id).unwrap().thread.clone()
2411 });
2412
2413 // Ensure empty threads are not saved, even if they get mutated.
2414 let model = Arc::new(FakeLanguageModel::default());
2415 let summary_model = Arc::new(FakeLanguageModel::default());
2416 thread.update(cx, |thread, cx| {
2417 thread.set_model(model.clone(), cx);
2418 thread.set_summarization_model(Some(summary_model.clone()), cx);
2419 });
2420 cx.run_until_parked();
2421 assert_eq!(thread_entries(&thread_store, cx), vec![]);
2422
2423 let send = acp_thread.update(cx, |thread, cx| {
2424 thread.send(
2425 vec![
2426 "What does ".into(),
2427 acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
2428 "b.md",
2429 MentionUri::File {
2430 abs_path: path!("/a/b.md").into(),
2431 }
2432 .to_uri()
2433 .to_string(),
2434 )),
2435 " mean?".into(),
2436 ],
2437 cx,
2438 )
2439 });
2440 let send = cx.foreground_executor().spawn(send);
2441 cx.run_until_parked();
2442
2443 model.send_last_completion_stream_text_chunk("Lorem.");
2444 model.end_last_completion_stream();
2445 cx.run_until_parked();
2446 summary_model
2447 .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
2448 summary_model.end_last_completion_stream();
2449
2450 send.await.unwrap();
2451 let uri = MentionUri::File {
2452 abs_path: path!("/a/b.md").into(),
2453 }
2454 .to_uri();
2455 acp_thread.read_with(cx, |thread, cx| {
2456 assert_eq!(
2457 thread.to_markdown(cx),
2458 formatdoc! {"
2459 ## User
2460
2461 What does [@b.md]({uri}) mean?
2462
2463 ## Assistant
2464
2465 Lorem.
2466
2467 "}
2468 )
2469 });
2470
2471 cx.run_until_parked();
2472
2473 // Close the session so it can be reloaded from disk.
2474 cx.update(|cx| connection.clone().close_session(&session_id, cx))
2475 .await
2476 .unwrap();
2477 drop(thread);
2478 drop(acp_thread);
2479 agent.read_with(cx, |agent, _| {
2480 assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
2481 });
2482
2483 // Ensure the thread can be reloaded from disk.
2484 assert_eq!(
2485 thread_entries(&thread_store, cx),
2486 vec![(
2487 session_id.clone(),
2488 format!("Explaining {}", path!("/a/b.md"))
2489 )]
2490 );
2491 let acp_thread = agent
2492 .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2493 .await
2494 .unwrap();
2495 acp_thread.read_with(cx, |thread, cx| {
2496 assert_eq!(
2497 thread.to_markdown(cx),
2498 formatdoc! {"
2499 ## User
2500
2501 What does [@b.md]({uri}) mean?
2502
2503 ## Assistant
2504
2505 Lorem.
2506
2507 "}
2508 )
2509 });
2510 }
2511
2512 fn thread_entries(
2513 thread_store: &Entity<ThreadStore>,
2514 cx: &mut TestAppContext,
2515 ) -> Vec<(acp::SessionId, String)> {
2516 thread_store.read_with(cx, |store, _| {
2517 store
2518 .entries()
2519 .map(|entry| (entry.id.clone(), entry.title.to_string()))
2520 .collect::<Vec<_>>()
2521 })
2522 }
2523
2524 fn init_test(cx: &mut TestAppContext) {
2525 env_logger::try_init().ok();
2526 cx.update(|cx| {
2527 let settings_store = SettingsStore::test(cx);
2528 cx.set_global(settings_store);
2529
2530 LanguageModelRegistry::test(cx);
2531 });
2532 }
2533}
2534
2535fn mcp_message_content_to_acp_content_block(
2536 content: context_server::types::MessageContent,
2537) -> acp::ContentBlock {
2538 match content {
2539 context_server::types::MessageContent::Text {
2540 text,
2541 annotations: _,
2542 } => text.into(),
2543 context_server::types::MessageContent::Image {
2544 data,
2545 mime_type,
2546 annotations: _,
2547 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2548 context_server::types::MessageContent::Audio {
2549 data,
2550 mime_type,
2551 annotations: _,
2552 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2553 context_server::types::MessageContent::Resource {
2554 resource,
2555 annotations: _,
2556 } => {
2557 let mut link =
2558 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2559 if let Some(mime_type) = resource.mime_type {
2560 link = link.mime_type(mime_type);
2561 }
2562 acp::ContentBlock::ResourceLink(link)
2563 }
2564 }
2565}