1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17pub use native_agent_server::NativeAgentServer;
18pub use pattern_extraction::*;
19pub use shell_command_parser::extract_commands;
20pub use templates::*;
21pub use thread::*;
22pub use thread_store::*;
23pub use tool_permissions::*;
24pub use tools::*;
25
26use acp_thread::{
27 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
28 AgentSessionListResponse, UserMessageId,
29};
30use agent_client_protocol as acp;
31use anyhow::{Context as _, Result, anyhow};
32use chrono::{DateTime, Utc};
33use collections::{HashMap, HashSet, IndexMap};
34use fs::Fs;
35use futures::channel::{mpsc, oneshot};
36use futures::future::Shared;
37use futures::{FutureExt as _, StreamExt as _, future};
38use gpui::{
39 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
40};
41use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
42use project::{Project, ProjectItem, ProjectPath, Worktree};
43use prompt_store::{
44 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
45 WorktreeContext,
46};
47use serde::{Deserialize, Serialize};
48use settings::{LanguageModelSelection, update_settings_file};
49use std::any::Any;
50use std::path::{Path, PathBuf};
51use std::rc::Rc;
52use std::sync::Arc;
53use std::time::Duration;
54use util::ResultExt;
55use util::rel_path::RelPath;
56
/// Serializable snapshot of the project's worktrees together with a UTC timestamp.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots of the project's worktrees.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Timestamp associated with the snapshot (UTC).
    pub timestamp: DateTime<Utc>,
}
62
/// Error describing why a rules file could not be loaded.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the formatted underlying error).
    pub message: SharedString,
}
66
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: Entity<acp_thread::AcpThread>,
    /// The most recently started background save of `thread`; overwritten every
    /// time a new save begins (presumably cancelling the previous one on drop —
    /// confirm against gpui `Task` semantics).
    pending_save: Task<()>,
    /// Keeps this session's thread subscriptions (title updates, token usage,
    /// save-on-notify) alive for as long as the session exists.
    _subscriptions: Vec<Subscription>,
}
76
/// Cache of the language models offered by authenticated providers, plus the
/// grouped list the model selector displays.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver cloned out to observers; it changes after every list refresh.
    refresh_models_rx: watch::Receiver<()>,
    /// Signaled at the end of every list refresh.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating all visible providers, held for the
    /// lifetime of this cache.
    _authenticate_all_providers_task: Task<()>,
}
86
87impl LanguageModels {
88 fn new(cx: &mut App) -> Self {
89 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
90
91 let mut this = Self {
92 models: HashMap::default(),
93 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
94 refresh_models_rx,
95 refresh_models_tx,
96 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
97 };
98 this.refresh_list(cx);
99 this
100 }
101
102 fn refresh_list(&mut self, cx: &App) {
103 let providers = LanguageModelRegistry::global(cx)
104 .read(cx)
105 .visible_providers()
106 .into_iter()
107 .filter(|provider| provider.is_authenticated(cx))
108 .collect::<Vec<_>>();
109
110 let mut language_model_list = IndexMap::default();
111 let mut recommended_models = HashSet::default();
112
113 let mut recommended = Vec::new();
114 for provider in &providers {
115 for model in provider.recommended_models(cx) {
116 recommended_models.insert((model.provider_id(), model.id()));
117 recommended.push(Self::map_language_model_to_info(&model, provider));
118 }
119 }
120 if !recommended.is_empty() {
121 language_model_list.insert(
122 acp_thread::AgentModelGroupName("Recommended".into()),
123 recommended,
124 );
125 }
126
127 let mut models = HashMap::default();
128 for provider in providers {
129 let mut provider_models = Vec::new();
130 for model in provider.provided_models(cx) {
131 let model_info = Self::map_language_model_to_info(&model, &provider);
132 let model_id = model_info.id.clone();
133 provider_models.push(model_info);
134 models.insert(model_id, model);
135 }
136 if !provider_models.is_empty() {
137 language_model_list.insert(
138 acp_thread::AgentModelGroupName(provider.name().0.clone()),
139 provider_models,
140 );
141 }
142 }
143
144 self.models = models;
145 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
146 self.refresh_models_tx.send(()).ok();
147 }
148
149 fn watch(&self) -> watch::Receiver<()> {
150 self.refresh_models_rx.clone()
151 }
152
153 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
154 self.models.get(model_id).cloned()
155 }
156
157 fn map_language_model_to_info(
158 model: &Arc<dyn LanguageModel>,
159 provider: &Arc<dyn LanguageModelProvider>,
160 ) -> acp_thread::AgentModelInfo {
161 acp_thread::AgentModelInfo {
162 id: Self::model_id(model),
163 name: model.name().0,
164 description: None,
165 icon: Some(match provider.icon() {
166 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
167 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
168 }),
169 is_latest: model.is_latest(),
170 cost: model.model_cost_info().map(|cost| cost.to_shared_string()),
171 }
172 }
173
174 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
175 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
176 }
177
178 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
179 let authenticate_all_providers = LanguageModelRegistry::global(cx)
180 .read(cx)
181 .visible_providers()
182 .iter()
183 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
184 .collect::<Vec<_>>();
185
186 cx.background_spawn(async move {
187 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
188 if let Err(err) = authenticate_task.await {
189 match err {
190 language_model::AuthenticateError::CredentialsNotFound => {
191 // Since we're authenticating these providers in the
192 // background for the purposes of populating the
193 // language selector, we don't care about providers
194 // where the credentials are not found.
195 }
196 language_model::AuthenticateError::ConnectionRefused => {
197 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
198 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
199 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
200 }
201 _ => {
202 // Some providers have noisy failure states that we
203 // don't want to spam the logs with every time the
204 // language model selector is initialized.
205 //
206 // Ideally these should have more clear failure modes
207 // that we know are safe to ignore here, like what we do
208 // with `CredentialsNotFound` above.
209 match provider_id.0.as_ref() {
210 "lmstudio" | "ollama" => {
211 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
212 //
213 // These fail noisily, so we don't log them.
214 }
215 "copilot_chat" => {
216 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
217 }
218 _ => {
219 log::error!(
220 "Failed to authenticate provider: {}: {err:#}",
221 provider_name.0
222 );
223 }
224 }
225 }
226 }
227 }
228 }
229 })
230 }
231}
232
/// State shared by every session of the native agent: the open sessions, the
/// project context, templates, cached models and context-server registry.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store that is reloaded after each thread save.
    thread_store: Entity<ThreadStore>,
    /// Project context handed to newly created threads. When the context is
    /// rebuilt, this field is replaced with a new entity rather than updated
    /// in place.
    project_context: Entity<ProjectContext>,
    /// Signals the background task to rebuild `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task that rebuilds `project_context` on each refresh signal.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context-server tools and prompts; prompts are exposed as
    /// slash commands.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt library supplying the default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, model-registry, context-server and
    /// prompt-store events.
    _subscriptions: Vec<Subscription>,
}
251
252impl NativeAgent {
253 pub async fn new(
254 project: Entity<Project>,
255 thread_store: Entity<ThreadStore>,
256 templates: Arc<Templates>,
257 prompt_store: Option<Entity<PromptStore>>,
258 fs: Arc<dyn Fs>,
259 cx: &mut AsyncApp,
260 ) -> Result<Entity<NativeAgent>> {
261 log::debug!("Creating new NativeAgent");
262
263 let project_context = cx
264 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
265 .await;
266
267 Ok(cx.new(|cx| {
268 let context_server_store = project.read(cx).context_server_store();
269 let context_server_registry =
270 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
271
272 let mut subscriptions = vec![
273 cx.subscribe(&project, Self::handle_project_event),
274 cx.subscribe(
275 &LanguageModelRegistry::global(cx),
276 Self::handle_models_updated_event,
277 ),
278 cx.subscribe(
279 &context_server_store,
280 Self::handle_context_server_store_updated,
281 ),
282 cx.subscribe(
283 &context_server_registry,
284 Self::handle_context_server_registry_event,
285 ),
286 ];
287 if let Some(prompt_store) = prompt_store.as_ref() {
288 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
289 }
290
291 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
292 watch::channel(());
293 Self {
294 sessions: HashMap::default(),
295 thread_store,
296 project_context: cx.new(|_| project_context),
297 project_context_needs_refresh: project_context_needs_refresh_tx,
298 _maintain_project_context: cx.spawn(async move |this, cx| {
299 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
300 }),
301 context_server_registry,
302 templates,
303 models: LanguageModels::new(cx),
304 project,
305 prompt_store,
306 fs,
307 _subscriptions: subscriptions,
308 }
309 }))
310 }
311
312 fn new_session(
313 &mut self,
314 project: Entity<Project>,
315 cx: &mut Context<Self>,
316 ) -> Entity<AcpThread> {
317 // Create Thread
318 // Fetch default model from registry settings
319 let registry = LanguageModelRegistry::read_global(cx);
320 // Log available models for debugging
321 let available_count = registry.available_models(cx).count();
322 log::debug!("Total available models: {}", available_count);
323
324 let default_model = registry.default_model().and_then(|default_model| {
325 self.models
326 .model_from_id(&LanguageModels::model_id(&default_model.model))
327 });
328 let thread = cx.new(|cx| {
329 Thread::new(
330 project.clone(),
331 self.project_context.clone(),
332 self.context_server_registry.clone(),
333 self.templates.clone(),
334 default_model,
335 cx,
336 )
337 });
338
339 self.register_session(thread, None, cx)
340 }
341
342 fn register_session(
343 &mut self,
344 thread_handle: Entity<Thread>,
345 allowed_tool_names: Option<Vec<&str>>,
346 cx: &mut Context<Self>,
347 ) -> Entity<AcpThread> {
348 let connection = Rc::new(NativeAgentConnection(cx.entity()));
349
350 let thread = thread_handle.read(cx);
351 let session_id = thread.id().clone();
352 let parent_session_id = thread.parent_thread_id();
353 let title = thread.title();
354 let project = thread.project.clone();
355 let action_log = thread.action_log.clone();
356 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
357 let acp_thread = cx.new(|cx| {
358 acp_thread::AcpThread::new(
359 parent_session_id,
360 title,
361 connection,
362 project.clone(),
363 action_log.clone(),
364 session_id.clone(),
365 prompt_capabilities_rx,
366 cx,
367 )
368 });
369
370 let registry = LanguageModelRegistry::read_global(cx);
371 let summarization_model = registry.thread_summary_model().map(|c| c.model);
372
373 let weak = cx.weak_entity();
374 thread_handle.update(cx, |thread, cx| {
375 thread.set_summarization_model(summarization_model, cx);
376 thread.add_default_tools(
377 allowed_tool_names,
378 Rc::new(NativeThreadEnvironment {
379 acp_thread: acp_thread.downgrade(),
380 agent: weak,
381 }) as _,
382 cx,
383 )
384 });
385
386 let subscriptions = vec![
387 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
388 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
389 cx.observe(&thread_handle, move |this, thread, cx| {
390 this.save_thread(thread, cx)
391 }),
392 ];
393
394 self.sessions.insert(
395 session_id,
396 Session {
397 thread: thread_handle,
398 acp_thread: acp_thread.clone(),
399 _subscriptions: subscriptions,
400 pending_save: Task::ready(()),
401 },
402 );
403
404 self.update_available_commands(cx);
405
406 acp_thread
407 }
408
409 pub fn models(&self) -> &LanguageModels {
410 &self.models
411 }
412
413 async fn maintain_project_context(
414 this: WeakEntity<Self>,
415 mut needs_refresh: watch::Receiver<()>,
416 cx: &mut AsyncApp,
417 ) -> Result<()> {
418 while needs_refresh.changed().await.is_ok() {
419 let project_context = this
420 .update(cx, |this, cx| {
421 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
422 })?
423 .await;
424 this.update(cx, |this, cx| {
425 this.project_context = cx.new(|_| project_context);
426 })?;
427 }
428
429 Ok(())
430 }
431
432 fn build_project_context(
433 project: &Entity<Project>,
434 prompt_store: Option<&Entity<PromptStore>>,
435 cx: &mut App,
436 ) -> Task<ProjectContext> {
437 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
438 let worktree_tasks = worktrees
439 .into_iter()
440 .map(|worktree| {
441 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
442 })
443 .collect::<Vec<_>>();
444 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
445 prompt_store.read_with(cx, |prompt_store, cx| {
446 let prompts = prompt_store.default_prompt_metadata();
447 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
448 let contents = prompt_store.load(prompt_metadata.id, cx);
449 async move { (contents.await, prompt_metadata) }
450 });
451 cx.background_spawn(future::join_all(load_tasks))
452 })
453 } else {
454 Task::ready(vec![])
455 };
456
457 cx.spawn(async move |_cx| {
458 let (worktrees, default_user_rules) =
459 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
460
461 let worktrees = worktrees
462 .into_iter()
463 .map(|(worktree, _rules_error)| {
464 // TODO: show error message
465 // if let Some(rules_error) = rules_error {
466 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
467 // }
468 worktree
469 })
470 .collect::<Vec<_>>();
471
472 let default_user_rules = default_user_rules
473 .into_iter()
474 .flat_map(|(contents, prompt_metadata)| match contents {
475 Ok(contents) => Some(UserRulesContext {
476 uuid: prompt_metadata.id.as_user()?,
477 title: prompt_metadata.title.map(|title| title.to_string()),
478 contents,
479 }),
480 Err(_err) => {
481 // TODO: show error message
482 // this.update(cx, |_, cx| {
483 // cx.emit(RulesLoadingError {
484 // message: format!("{err:?}").into(),
485 // });
486 // })
487 // .ok();
488 None
489 }
490 })
491 .collect::<Vec<_>>();
492
493 ProjectContext::new(worktrees, default_user_rules)
494 })
495 }
496
497 fn load_worktree_info_for_system_prompt(
498 worktree: Entity<Worktree>,
499 project: Entity<Project>,
500 cx: &mut App,
501 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
502 let tree = worktree.read(cx);
503 let root_name = tree.root_name_str().into();
504 let abs_path = tree.abs_path();
505
506 let mut context = WorktreeContext {
507 root_name,
508 abs_path,
509 rules_file: None,
510 };
511
512 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
513 let Some(rules_task) = rules_task else {
514 return Task::ready((context, None));
515 };
516
517 cx.spawn(async move |_| {
518 let (rules_file, rules_file_error) = match rules_task.await {
519 Ok(rules_file) => (Some(rules_file), None),
520 Err(err) => (
521 None,
522 Some(RulesLoadingError {
523 message: format!("{err}").into(),
524 }),
525 ),
526 };
527 context.rules_file = rules_file;
528 (context, rules_file_error)
529 })
530 }
531
532 fn load_worktree_rules_file(
533 worktree: Entity<Worktree>,
534 project: Entity<Project>,
535 cx: &mut App,
536 ) -> Option<Task<Result<RulesFileContext>>> {
537 let worktree = worktree.read(cx);
538 let worktree_id = worktree.id();
539 let selected_rules_file = RULES_FILE_NAMES
540 .into_iter()
541 .filter_map(|name| {
542 worktree
543 .entry_for_path(RelPath::unix(name).unwrap())
544 .filter(|entry| entry.is_file())
545 .map(|entry| entry.path.clone())
546 })
547 .next();
548
549 // Note that Cline supports `.clinerules` being a directory, but that is not currently
550 // supported. This doesn't seem to occur often in GitHub repositories.
551 selected_rules_file.map(|path_in_worktree| {
552 let project_path = ProjectPath {
553 worktree_id,
554 path: path_in_worktree.clone(),
555 };
556 let buffer_task =
557 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
558 let rope_task = cx.spawn(async move |cx| {
559 let buffer = buffer_task.await?;
560 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
561 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
562 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
563 })?;
564 anyhow::Ok((project_entry_id, rope))
565 });
566 // Build a string from the rope on a background thread.
567 cx.background_spawn(async move {
568 let (project_entry_id, rope) = rope_task.await?;
569 anyhow::Ok(RulesFileContext {
570 path_in_worktree,
571 text: rope.to_string().trim().to_string(),
572 project_entry_id: project_entry_id.to_usize(),
573 })
574 })
575 })
576 }
577
578 fn handle_thread_title_updated(
579 &mut self,
580 thread: Entity<Thread>,
581 _: &TitleUpdated,
582 cx: &mut Context<Self>,
583 ) {
584 let session_id = thread.read(cx).id();
585 let Some(session) = self.sessions.get(session_id) else {
586 return;
587 };
588 let thread = thread.downgrade();
589 let acp_thread = session.acp_thread.downgrade();
590 cx.spawn(async move |_, cx| {
591 let title = thread.read_with(cx, |thread, _| thread.title())?;
592 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
593 task.await
594 })
595 .detach_and_log_err(cx);
596 }
597
598 fn handle_thread_token_usage_updated(
599 &mut self,
600 thread: Entity<Thread>,
601 usage: &TokenUsageUpdated,
602 cx: &mut Context<Self>,
603 ) {
604 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
605 return;
606 };
607 session.acp_thread.update(cx, |acp_thread, cx| {
608 acp_thread.update_token_usage(usage.0.clone(), cx);
609 });
610 }
611
612 fn handle_project_event(
613 &mut self,
614 _project: Entity<Project>,
615 event: &project::Event,
616 _cx: &mut Context<Self>,
617 ) {
618 match event {
619 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
620 self.project_context_needs_refresh.send(()).ok();
621 }
622 project::Event::WorktreeUpdatedEntries(_, items) => {
623 if items.iter().any(|(path, _, _)| {
624 RULES_FILE_NAMES
625 .iter()
626 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
627 }) {
628 self.project_context_needs_refresh.send(()).ok();
629 }
630 }
631 _ => {}
632 }
633 }
634
635 fn handle_prompts_updated_event(
636 &mut self,
637 _prompt_store: Entity<PromptStore>,
638 _event: &prompt_store::PromptsUpdatedEvent,
639 _cx: &mut Context<Self>,
640 ) {
641 self.project_context_needs_refresh.send(()).ok();
642 }
643
644 fn handle_models_updated_event(
645 &mut self,
646 _registry: Entity<LanguageModelRegistry>,
647 _event: &language_model::Event,
648 cx: &mut Context<Self>,
649 ) {
650 self.models.refresh_list(cx);
651
652 let registry = LanguageModelRegistry::read_global(cx);
653 let default_model = registry.default_model().map(|m| m.model);
654 let summarization_model = registry.thread_summary_model().map(|m| m.model);
655
656 for session in self.sessions.values_mut() {
657 session.thread.update(cx, |thread, cx| {
658 if thread.model().is_none()
659 && let Some(model) = default_model.clone()
660 {
661 thread.set_model(model, cx);
662 cx.notify();
663 }
664 thread.set_summarization_model(summarization_model.clone(), cx);
665 });
666 }
667 }
668
669 fn handle_context_server_store_updated(
670 &mut self,
671 _store: Entity<project::context_server_store::ContextServerStore>,
672 _event: &project::context_server_store::ServerStatusChangedEvent,
673 cx: &mut Context<Self>,
674 ) {
675 self.update_available_commands(cx);
676 }
677
678 fn handle_context_server_registry_event(
679 &mut self,
680 _registry: Entity<ContextServerRegistry>,
681 event: &ContextServerRegistryEvent,
682 cx: &mut Context<Self>,
683 ) {
684 match event {
685 ContextServerRegistryEvent::ToolsChanged => {}
686 ContextServerRegistryEvent::PromptsChanged => {
687 self.update_available_commands(cx);
688 }
689 }
690 }
691
692 fn update_available_commands(&self, cx: &mut Context<Self>) {
693 let available_commands = self.build_available_commands(cx);
694 for session in self.sessions.values() {
695 session.acp_thread.update(cx, |thread, cx| {
696 thread
697 .handle_session_update(
698 acp::SessionUpdate::AvailableCommandsUpdate(
699 acp::AvailableCommandsUpdate::new(available_commands.clone()),
700 ),
701 cx,
702 )
703 .log_err();
704 });
705 }
706 }
707
708 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
709 let registry = self.context_server_registry.read(cx);
710
711 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
712 for context_server_prompt in registry.prompts() {
713 *prompt_name_counts
714 .entry(context_server_prompt.prompt.name.as_str())
715 .or_insert(0) += 1;
716 }
717
718 registry
719 .prompts()
720 .flat_map(|context_server_prompt| {
721 let prompt = &context_server_prompt.prompt;
722
723 let should_prefix = prompt_name_counts
724 .get(prompt.name.as_str())
725 .copied()
726 .unwrap_or(0)
727 > 1;
728
729 let name = if should_prefix {
730 format!("{}.{}", context_server_prompt.server_id, prompt.name)
731 } else {
732 prompt.name.clone()
733 };
734
735 let mut command = acp::AvailableCommand::new(
736 name,
737 prompt.description.clone().unwrap_or_default(),
738 );
739
740 match prompt.arguments.as_deref() {
741 Some([arg]) => {
742 let hint = format!("<{}>", arg.name);
743
744 command = command.input(acp::AvailableCommandInput::Unstructured(
745 acp::UnstructuredCommandInput::new(hint),
746 ));
747 }
748 Some([]) | None => {}
749 Some(_) => {
750 // skip >1 argument commands since we don't support them yet
751 return None;
752 }
753 }
754
755 Some(command)
756 })
757 .collect()
758 }
759
760 pub fn load_thread(
761 &mut self,
762 id: acp::SessionId,
763 cx: &mut Context<Self>,
764 ) -> Task<Result<Entity<Thread>>> {
765 let database_future = ThreadsDatabase::connect(cx);
766 cx.spawn(async move |this, cx| {
767 let database = database_future.await.map_err(|err| anyhow!(err))?;
768 let db_thread = database
769 .load_thread(id.clone())
770 .await?
771 .with_context(|| format!("no thread found with ID: {id:?}"))?;
772
773 this.update(cx, |this, cx| {
774 let summarization_model = LanguageModelRegistry::read_global(cx)
775 .thread_summary_model()
776 .map(|c| c.model);
777
778 cx.new(|cx| {
779 let mut thread = Thread::from_db(
780 id.clone(),
781 db_thread,
782 this.project.clone(),
783 this.project_context.clone(),
784 this.context_server_registry.clone(),
785 this.templates.clone(),
786 cx,
787 );
788 thread.set_summarization_model(summarization_model, cx);
789 thread
790 })
791 })
792 })
793 }
794
795 pub fn open_thread(
796 &mut self,
797 id: acp::SessionId,
798 cx: &mut Context<Self>,
799 ) -> Task<Result<Entity<AcpThread>>> {
800 if let Some(session) = self.sessions.get(&id) {
801 return Task::ready(Ok(session.acp_thread.clone()));
802 }
803
804 let task = self.load_thread(id, cx);
805 cx.spawn(async move |this, cx| {
806 let thread = task.await?;
807 let acp_thread = this.update(cx, |this, cx| {
808 this.register_session(thread.clone(), None, cx)
809 })?;
810 let events = thread.update(cx, |thread, cx| thread.replay(cx));
811 cx.update(|cx| {
812 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
813 })
814 .await?;
815 Ok(acp_thread)
816 })
817 }
818
819 pub fn thread_summary(
820 &mut self,
821 id: acp::SessionId,
822 cx: &mut Context<Self>,
823 ) -> Task<Result<SharedString>> {
824 let thread = self.open_thread(id.clone(), cx);
825 cx.spawn(async move |this, cx| {
826 let acp_thread = thread.await?;
827 let result = this
828 .update(cx, |this, cx| {
829 this.sessions
830 .get(&id)
831 .unwrap()
832 .thread
833 .update(cx, |thread, cx| thread.summary(cx))
834 })?
835 .await
836 .context("Failed to generate summary")?;
837 drop(acp_thread);
838 Ok(result)
839 })
840 }
841
842 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
843 if thread.read(cx).is_empty() {
844 return;
845 }
846
847 let database_future = ThreadsDatabase::connect(cx);
848 let (id, db_thread) =
849 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
850 let Some(session) = self.sessions.get_mut(&id) else {
851 return;
852 };
853 let thread_store = self.thread_store.clone();
854 session.pending_save = cx.spawn(async move |_, cx| {
855 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
856 return;
857 };
858 let db_thread = db_thread.await;
859 database.save_thread(id, db_thread).await.log_err();
860 thread_store.update(cx, |store, cx| store.reload(cx));
861 });
862 }
863
864 fn send_mcp_prompt(
865 &self,
866 message_id: UserMessageId,
867 session_id: agent_client_protocol::SessionId,
868 prompt_name: String,
869 server_id: ContextServerId,
870 arguments: HashMap<String, String>,
871 original_content: Vec<acp::ContentBlock>,
872 cx: &mut Context<Self>,
873 ) -> Task<Result<acp::PromptResponse>> {
874 let server_store = self.context_server_registry.read(cx).server_store().clone();
875 let path_style = self.project.read(cx).path_style(cx);
876
877 cx.spawn(async move |this, cx| {
878 let prompt =
879 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
880
881 let (acp_thread, thread) = this.update(cx, |this, _cx| {
882 let session = this
883 .sessions
884 .get(&session_id)
885 .context("Failed to get session")?;
886 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
887 })??;
888
889 let mut last_is_user = true;
890
891 thread.update(cx, |thread, cx| {
892 thread.push_acp_user_block(
893 message_id,
894 original_content.into_iter().skip(1),
895 path_style,
896 cx,
897 );
898 });
899
900 for message in prompt.messages {
901 let context_server::types::PromptMessage { role, content } = message;
902 let block = mcp_message_content_to_acp_content_block(content);
903
904 match role {
905 context_server::types::Role::User => {
906 let id = acp_thread::UserMessageId::new();
907
908 acp_thread.update(cx, |acp_thread, cx| {
909 acp_thread.push_user_content_block_with_indent(
910 Some(id.clone()),
911 block.clone(),
912 true,
913 cx,
914 );
915 });
916
917 thread.update(cx, |thread, cx| {
918 thread.push_acp_user_block(id, [block], path_style, cx);
919 });
920 }
921 context_server::types::Role::Assistant => {
922 acp_thread.update(cx, |acp_thread, cx| {
923 acp_thread.push_assistant_content_block_with_indent(
924 block.clone(),
925 false,
926 true,
927 cx,
928 );
929 });
930
931 thread.update(cx, |thread, cx| {
932 thread.push_acp_agent_block(block, cx);
933 });
934 }
935 }
936
937 last_is_user = role == context_server::types::Role::User;
938 }
939
940 let response_stream = thread.update(cx, |thread, cx| {
941 if last_is_user {
942 thread.send_existing(cx)
943 } else {
944 // Resume if MCP prompt did not end with a user message
945 thread.resume(cx)
946 }
947 })?;
948
949 cx.update(|cx| {
950 NativeAgentConnection::handle_thread_events(
951 response_stream,
952 acp_thread.downgrade(),
953 cx,
954 )
955 })
956 .await
957 })
958 }
959}
960
/// Wrapper struct that implements the AgentConnection trait
///
/// Its only field is the `Entity<NativeAgent>` handle, so every clone of the
/// connection wraps the same agent entity.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
964
965impl NativeAgentConnection {
966 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
967 self.0
968 .read(cx)
969 .sessions
970 .get(session_id)
971 .map(|session| session.thread.clone())
972 }
973
974 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
975 self.0.update(cx, |this, cx| this.load_thread(id, cx))
976 }
977
978 fn run_turn(
979 &self,
980 session_id: acp::SessionId,
981 cx: &mut App,
982 f: impl 'static
983 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
984 ) -> Task<Result<acp::PromptResponse>> {
985 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
986 agent
987 .sessions
988 .get_mut(&session_id)
989 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
990 }) else {
991 return Task::ready(Err(anyhow!("Session not found")));
992 };
993 log::debug!("Found session for: {}", session_id);
994
995 let response_stream = match f(thread, cx) {
996 Ok(stream) => stream,
997 Err(err) => return Task::ready(Err(err)),
998 };
999 Self::handle_thread_events(response_stream, acp_thread.downgrade(), cx)
1000 }
1001
1002 fn handle_thread_events(
1003 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1004 acp_thread: WeakEntity<AcpThread>,
1005 cx: &App,
1006 ) -> Task<Result<acp::PromptResponse>> {
1007 cx.spawn(async move |cx| {
1008 // Handle response stream and forward to session.acp_thread
1009 while let Some(result) = events.next().await {
1010 match result {
1011 Ok(event) => {
1012 log::trace!("Received completion event: {:?}", event);
1013
1014 match event {
1015 ThreadEvent::UserMessage(message) => {
1016 acp_thread.update(cx, |thread, cx| {
1017 for content in message.content {
1018 thread.push_user_content_block(
1019 Some(message.id.clone()),
1020 content.into(),
1021 cx,
1022 );
1023 }
1024 })?;
1025 }
1026 ThreadEvent::AgentText(text) => {
1027 acp_thread.update(cx, |thread, cx| {
1028 thread.push_assistant_content_block(text.into(), false, cx)
1029 })?;
1030 }
1031 ThreadEvent::AgentThinking(text) => {
1032 acp_thread.update(cx, |thread, cx| {
1033 thread.push_assistant_content_block(text.into(), true, cx)
1034 })?;
1035 }
1036 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1037 tool_call,
1038 options,
1039 response,
1040 context: _,
1041 }) => {
1042 let outcome_task = acp_thread.update(cx, |thread, cx| {
1043 thread.request_tool_call_authorization(tool_call, options, cx)
1044 })??;
1045 cx.background_spawn(async move {
1046 if let acp::RequestPermissionOutcome::Selected(
1047 acp::SelectedPermissionOutcome { option_id, .. },
1048 ) = outcome_task.await
1049 {
1050 response
1051 .send(option_id)
1052 .map(|_| anyhow!("authorization receiver was dropped"))
1053 .log_err();
1054 }
1055 })
1056 .detach();
1057 }
1058 ThreadEvent::ToolCall(tool_call) => {
1059 acp_thread.update(cx, |thread, cx| {
1060 thread.upsert_tool_call(tool_call, cx)
1061 })??;
1062 }
1063 ThreadEvent::ToolCallUpdate(update) => {
1064 acp_thread.update(cx, |thread, cx| {
1065 thread.update_tool_call(update, cx)
1066 })??;
1067 }
1068 ThreadEvent::SubagentSpawned(session_id) => {
1069 acp_thread.update(cx, |thread, cx| {
1070 thread.subagent_spawned(session_id, cx);
1071 })?;
1072 }
1073 ThreadEvent::Retry(status) => {
1074 acp_thread.update(cx, |thread, cx| {
1075 thread.update_retry_status(status, cx)
1076 })?;
1077 }
1078 ThreadEvent::Stop(stop_reason) => {
1079 log::debug!("Assistant message complete: {:?}", stop_reason);
1080 return Ok(acp::PromptResponse::new(stop_reason));
1081 }
1082 }
1083 }
1084 Err(e) => {
1085 log::error!("Error in model response stream: {:?}", e);
1086 return Err(e);
1087 }
1088 }
1089 }
1090
1091 log::debug!("Response stream completed");
1092 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1093 })
1094 }
1095}
1096
1097struct Command<'a> {
1098 prompt_name: &'a str,
1099 arg_value: &'a str,
1100 explicit_server_id: Option<&'a str>,
1101}
1102
1103impl<'a> Command<'a> {
1104 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1105 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1106 return None;
1107 };
1108 let text = text_content.text.trim();
1109 let command = text.strip_prefix('/')?;
1110 let (command, arg_value) = command
1111 .split_once(char::is_whitespace)
1112 .unwrap_or((command, ""));
1113
1114 if let Some((server_id, prompt_name)) = command.split_once('.') {
1115 Some(Self {
1116 prompt_name,
1117 arg_value,
1118 explicit_server_id: Some(server_id),
1119 })
1120 } else {
1121 Some(Self {
1122 prompt_name: command,
1123 arg_value,
1124 explicit_server_id: None,
1125 })
1126 }
1127 }
1128}
1129
1130struct NativeAgentModelSelector {
1131 session_id: acp::SessionId,
1132 connection: NativeAgentConnection,
1133}
1134
1135impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1136 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1137 log::debug!("NativeAgentConnection::list_models called");
1138 let list = self.connection.0.read(cx).models.model_list.clone();
1139 Task::ready(if list.is_empty() {
1140 Err(anyhow::anyhow!("No models available"))
1141 } else {
1142 Ok(list)
1143 })
1144 }
1145
1146 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1147 log::debug!(
1148 "Setting model for session {}: {}",
1149 self.session_id,
1150 model_id
1151 );
1152 let Some(thread) = self
1153 .connection
1154 .0
1155 .read(cx)
1156 .sessions
1157 .get(&self.session_id)
1158 .map(|session| session.thread.clone())
1159 else {
1160 return Task::ready(Err(anyhow!("Session not found")));
1161 };
1162
1163 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1164 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1165 };
1166
1167 // We want to reset the effort level when switching models, as the currently-selected effort level may
1168 // not be compatible.
1169 let effort = model
1170 .default_effort_level()
1171 .map(|effort_level| effort_level.value.to_string());
1172
1173 thread.update(cx, |thread, cx| {
1174 thread.set_model(model.clone(), cx);
1175 thread.set_thinking_effort(effort.clone(), cx);
1176 thread.set_thinking_enabled(model.supports_thinking(), cx);
1177 });
1178
1179 update_settings_file(
1180 self.connection.0.read(cx).fs.clone(),
1181 cx,
1182 move |settings, cx| {
1183 let provider = model.provider_id().0.to_string();
1184 let model = model.id().0.to_string();
1185 let enable_thinking = thread.read(cx).thinking_enabled();
1186 settings
1187 .agent
1188 .get_or_insert_default()
1189 .set_model(LanguageModelSelection {
1190 provider: provider.into(),
1191 model,
1192 enable_thinking,
1193 effort,
1194 });
1195 },
1196 );
1197
1198 Task::ready(Ok(()))
1199 }
1200
1201 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1202 let Some(thread) = self
1203 .connection
1204 .0
1205 .read(cx)
1206 .sessions
1207 .get(&self.session_id)
1208 .map(|session| session.thread.clone())
1209 else {
1210 return Task::ready(Err(anyhow!("Session not found")));
1211 };
1212 let Some(model) = thread.read(cx).model() else {
1213 return Task::ready(Err(anyhow!("Model not found")));
1214 };
1215 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1216 else {
1217 return Task::ready(Err(anyhow!("Provider not found")));
1218 };
1219 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1220 model, &provider,
1221 )))
1222 }
1223
1224 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1225 Some(self.connection.0.read(cx).models.watch())
1226 }
1227
1228 fn should_render_footer(&self) -> bool {
1229 true
1230 }
1231}
1232
1233impl acp_thread::AgentConnection for NativeAgentConnection {
1234 fn telemetry_id(&self) -> SharedString {
1235 "zed".into()
1236 }
1237
1238 fn new_session(
1239 self: Rc<Self>,
1240 project: Entity<Project>,
1241 cwd: &Path,
1242 cx: &mut App,
1243 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1244 log::debug!("Creating new thread for project at: {cwd:?}");
1245 Task::ready(Ok(self
1246 .0
1247 .update(cx, |agent, cx| agent.new_session(project, cx))))
1248 }
1249
1250 fn supports_load_session(&self) -> bool {
1251 true
1252 }
1253
1254 fn load_session(
1255 self: Rc<Self>,
1256 session: AgentSessionInfo,
1257 _project: Entity<Project>,
1258 _cwd: &Path,
1259 cx: &mut App,
1260 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1261 self.0
1262 .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1263 }
1264
1265 fn supports_close_session(&self) -> bool {
1266 true
1267 }
1268
1269 fn close_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1270 self.0.update(cx, |agent, _cx| {
1271 agent.sessions.remove(session_id);
1272 });
1273 Task::ready(Ok(()))
1274 }
1275
1276 fn auth_methods(&self) -> &[acp::AuthMethod] {
1277 &[] // No auth for in-process
1278 }
1279
1280 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1281 Task::ready(Ok(()))
1282 }
1283
1284 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1285 Some(Rc::new(NativeAgentModelSelector {
1286 session_id: session_id.clone(),
1287 connection: self.clone(),
1288 }) as Rc<dyn AgentModelSelector>)
1289 }
1290
1291 fn prompt(
1292 &self,
1293 id: Option<acp_thread::UserMessageId>,
1294 params: acp::PromptRequest,
1295 cx: &mut App,
1296 ) -> Task<Result<acp::PromptResponse>> {
1297 let id = id.expect("UserMessageId is required");
1298 let session_id = params.session_id.clone();
1299 log::info!("Received prompt request for session: {}", session_id);
1300 log::debug!("Prompt blocks count: {}", params.prompt.len());
1301
1302 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1303 let registry = self.0.read(cx).context_server_registry.read(cx);
1304
1305 let explicit_server_id = parsed_command
1306 .explicit_server_id
1307 .map(|server_id| ContextServerId(server_id.into()));
1308
1309 if let Some(prompt) =
1310 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1311 {
1312 let arguments = if !parsed_command.arg_value.is_empty()
1313 && let Some(arg_name) = prompt
1314 .prompt
1315 .arguments
1316 .as_ref()
1317 .and_then(|args| args.first())
1318 .map(|arg| arg.name.clone())
1319 {
1320 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1321 } else {
1322 Default::default()
1323 };
1324
1325 let prompt_name = prompt.prompt.name.clone();
1326 let server_id = prompt.server_id.clone();
1327
1328 return self.0.update(cx, |agent, cx| {
1329 agent.send_mcp_prompt(
1330 id,
1331 session_id.clone(),
1332 prompt_name,
1333 server_id,
1334 arguments,
1335 params.prompt,
1336 cx,
1337 )
1338 });
1339 };
1340 };
1341
1342 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1343
1344 self.run_turn(session_id, cx, move |thread, cx| {
1345 let content: Vec<UserMessageContent> = params
1346 .prompt
1347 .into_iter()
1348 .map(|block| UserMessageContent::from_content_block(block, path_style))
1349 .collect::<Vec<_>>();
1350 log::debug!("Converted prompt to message: {} chars", content.len());
1351 log::debug!("Message id: {:?}", id);
1352 log::debug!("Message content: {:?}", content);
1353
1354 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1355 })
1356 }
1357
1358 fn retry(
1359 &self,
1360 session_id: &acp::SessionId,
1361 _cx: &App,
1362 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1363 Some(Rc::new(NativeAgentSessionRetry {
1364 connection: self.clone(),
1365 session_id: session_id.clone(),
1366 }) as _)
1367 }
1368
1369 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1370 log::info!("Cancelling on session: {}", session_id);
1371 self.0.update(cx, |agent, cx| {
1372 if let Some(session) = agent.sessions.get(session_id) {
1373 session
1374 .thread
1375 .update(cx, |thread, cx| thread.cancel(cx))
1376 .detach();
1377 }
1378 });
1379 }
1380
1381 fn truncate(
1382 &self,
1383 session_id: &agent_client_protocol::SessionId,
1384 cx: &App,
1385 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1386 self.0.read_with(cx, |agent, _cx| {
1387 agent.sessions.get(session_id).map(|session| {
1388 Rc::new(NativeAgentSessionTruncate {
1389 thread: session.thread.clone(),
1390 acp_thread: session.acp_thread.downgrade(),
1391 }) as _
1392 })
1393 })
1394 }
1395
1396 fn set_title(
1397 &self,
1398 session_id: &acp::SessionId,
1399 cx: &App,
1400 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1401 self.0.read_with(cx, |agent, _cx| {
1402 agent
1403 .sessions
1404 .get(session_id)
1405 .filter(|s| !s.thread.read(cx).is_subagent())
1406 .map(|session| {
1407 Rc::new(NativeAgentSessionSetTitle {
1408 thread: session.thread.clone(),
1409 }) as _
1410 })
1411 })
1412 }
1413
1414 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1415 let thread_store = self.0.read(cx).thread_store.clone();
1416 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1417 }
1418
1419 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1420 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1421 }
1422
1423 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1424 self
1425 }
1426}
1427
1428impl acp_thread::AgentTelemetry for NativeAgentConnection {
1429 fn thread_data(
1430 &self,
1431 session_id: &acp::SessionId,
1432 cx: &mut App,
1433 ) -> Task<Result<serde_json::Value>> {
1434 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1435 return Task::ready(Err(anyhow!("Session not found")));
1436 };
1437
1438 let task = session.thread.read(cx).to_db(cx);
1439 cx.background_spawn(async move {
1440 serde_json::to_value(task.await).context("Failed to serialize thread")
1441 })
1442 }
1443}
1444
1445pub struct NativeAgentSessionList {
1446 thread_store: Entity<ThreadStore>,
1447 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1448 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1449 _subscription: Subscription,
1450}
1451
1452impl NativeAgentSessionList {
1453 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1454 let (tx, rx) = smol::channel::unbounded();
1455 let this_tx = tx.clone();
1456 let subscription = cx.observe(&thread_store, move |_, _| {
1457 this_tx
1458 .try_send(acp_thread::SessionListUpdate::Refresh)
1459 .ok();
1460 });
1461 Self {
1462 thread_store,
1463 updates_tx: tx,
1464 updates_rx: rx,
1465 _subscription: subscription,
1466 }
1467 }
1468
1469 fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1470 AgentSessionInfo {
1471 session_id: entry.id,
1472 cwd: None,
1473 title: Some(entry.title),
1474 updated_at: Some(entry.updated_at),
1475 meta: None,
1476 }
1477 }
1478
1479 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1480 &self.thread_store
1481 }
1482}
1483
1484impl AgentSessionList for NativeAgentSessionList {
1485 fn list_sessions(
1486 &self,
1487 _request: AgentSessionListRequest,
1488 cx: &mut App,
1489 ) -> Task<Result<AgentSessionListResponse>> {
1490 let sessions = self
1491 .thread_store
1492 .read(cx)
1493 .entries()
1494 .map(Self::to_session_info)
1495 .collect();
1496 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1497 }
1498
1499 fn supports_delete(&self) -> bool {
1500 true
1501 }
1502
1503 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1504 self.thread_store
1505 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1506 }
1507
1508 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1509 self.thread_store
1510 .update(cx, |store, cx| store.delete_threads(cx))
1511 }
1512
1513 fn watch(
1514 &self,
1515 _cx: &mut App,
1516 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1517 Some(self.updates_rx.clone())
1518 }
1519
1520 fn notify_refresh(&self) {
1521 self.updates_tx
1522 .try_send(acp_thread::SessionListUpdate::Refresh)
1523 .ok();
1524 }
1525
1526 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1527 self
1528 }
1529}
1530
1531struct NativeAgentSessionTruncate {
1532 thread: Entity<Thread>,
1533 acp_thread: WeakEntity<AcpThread>,
1534}
1535
1536impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1537 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1538 match self.thread.update(cx, |thread, cx| {
1539 thread.truncate(message_id.clone(), cx)?;
1540 Ok(thread.latest_token_usage())
1541 }) {
1542 Ok(usage) => {
1543 self.acp_thread
1544 .update(cx, |thread, cx| {
1545 thread.update_token_usage(usage, cx);
1546 })
1547 .ok();
1548 Task::ready(Ok(()))
1549 }
1550 Err(error) => Task::ready(Err(error)),
1551 }
1552 }
1553}
1554
1555struct NativeAgentSessionRetry {
1556 connection: NativeAgentConnection,
1557 session_id: acp::SessionId,
1558}
1559
1560impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1561 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1562 self.connection
1563 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1564 thread.update(cx, |thread, cx| thread.resume(cx))
1565 })
1566 }
1567}
1568
1569struct NativeAgentSessionSetTitle {
1570 thread: Entity<Thread>,
1571}
1572
1573impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1574 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1575 self.thread
1576 .update(cx, |thread, cx| thread.set_title(title, cx));
1577 Task::ready(Ok(()))
1578 }
1579}
1580
1581pub struct NativeThreadEnvironment {
1582 agent: WeakEntity<NativeAgent>,
1583 acp_thread: WeakEntity<AcpThread>,
1584}
1585
1586impl NativeThreadEnvironment {
1587 pub(crate) fn create_subagent_thread(
1588 agent: WeakEntity<NativeAgent>,
1589 parent_thread_entity: Entity<Thread>,
1590 label: String,
1591 initial_prompt: String,
1592 timeout: Option<Duration>,
1593 allowed_tools: Option<Vec<String>>,
1594 cx: &mut App,
1595 ) -> Result<Rc<dyn SubagentHandle>> {
1596 let parent_thread = parent_thread_entity.read(cx);
1597 let current_depth = parent_thread.depth();
1598
1599 if current_depth >= MAX_SUBAGENT_DEPTH {
1600 return Err(anyhow!(
1601 "Maximum subagent depth ({}) reached",
1602 MAX_SUBAGENT_DEPTH
1603 ));
1604 }
1605
1606 let allowed_tools = match allowed_tools {
1607 Some(tools) => {
1608 let parent_tool_names: std::collections::HashSet<&str> =
1609 parent_thread.tools.keys().map(|s| s.as_str()).collect();
1610 Some(
1611 tools
1612 .into_iter()
1613 .filter(|t| parent_tool_names.contains(t.as_str()))
1614 .collect::<Vec<_>>(),
1615 )
1616 }
1617 None => Some(parent_thread.tools.keys().map(|s| s.to_string()).collect()),
1618 };
1619
1620 let subagent_thread: Entity<Thread> = cx.new(|cx| {
1621 let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
1622 thread.set_title(label.into(), cx);
1623 thread
1624 });
1625
1626 let session_id = subagent_thread.read(cx).id().clone();
1627
1628 let acp_thread = agent.update(cx, |agent, cx| {
1629 agent.register_session(
1630 subagent_thread.clone(),
1631 allowed_tools
1632 .as_ref()
1633 .map(|v| v.iter().map(|s| s.as_str()).collect()),
1634 cx,
1635 )
1636 })?;
1637
1638 parent_thread_entity.update(cx, |parent_thread, _cx| {
1639 parent_thread.register_running_subagent(subagent_thread.downgrade())
1640 });
1641
1642 let task = acp_thread.update(cx, |agent, cx| agent.send(vec![initial_prompt.into()], cx));
1643
1644 let timeout_timer = timeout.map(|d| cx.background_executor().timer(d));
1645 let wait_for_prompt_to_complete = cx
1646 .background_spawn(async move {
1647 if let Some(timer) = timeout_timer {
1648 futures::select! {
1649 _ = timer.fuse() => SubagentInitialPromptResult::Timeout,
1650 response = task.fuse() => {
1651 let response = response.log_err().flatten();
1652 if response.is_some_and(|response| {
1653 response.stop_reason == acp::StopReason::Cancelled
1654 })
1655 {
1656 SubagentInitialPromptResult::Cancelled
1657 } else {
1658 SubagentInitialPromptResult::Completed
1659 }
1660 },
1661 }
1662 } else {
1663 let response = task.await.log_err().flatten();
1664 if response
1665 .is_some_and(|response| response.stop_reason == acp::StopReason::Cancelled)
1666 {
1667 SubagentInitialPromptResult::Cancelled
1668 } else {
1669 SubagentInitialPromptResult::Completed
1670 }
1671 }
1672 })
1673 .shared();
1674
1675 Ok(Rc::new(NativeSubagentHandle {
1676 session_id,
1677 subagent_thread,
1678 parent_thread: parent_thread_entity.downgrade(),
1679 acp_thread,
1680 wait_for_prompt_to_complete,
1681 }) as _)
1682 }
1683}
1684
1685impl ThreadEnvironment for NativeThreadEnvironment {
1686 fn create_terminal(
1687 &self,
1688 command: String,
1689 cwd: Option<PathBuf>,
1690 output_byte_limit: Option<u64>,
1691 cx: &mut AsyncApp,
1692 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1693 let task = self.acp_thread.update(cx, |thread, cx| {
1694 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1695 });
1696
1697 let acp_thread = self.acp_thread.clone();
1698 cx.spawn(async move |cx| {
1699 let terminal = task?.await?;
1700
1701 let (drop_tx, drop_rx) = oneshot::channel();
1702 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1703
1704 cx.spawn(async move |cx| {
1705 drop_rx.await.ok();
1706 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1707 })
1708 .detach();
1709
1710 let handle = AcpTerminalHandle {
1711 terminal,
1712 _drop_tx: Some(drop_tx),
1713 };
1714
1715 Ok(Rc::new(handle) as _)
1716 })
1717 }
1718
1719 fn create_subagent(
1720 &self,
1721 parent_thread_entity: Entity<Thread>,
1722 label: String,
1723 initial_prompt: String,
1724 timeout: Option<Duration>,
1725 allowed_tools: Option<Vec<String>>,
1726 cx: &mut App,
1727 ) -> Result<Rc<dyn SubagentHandle>> {
1728 Self::create_subagent_thread(
1729 self.agent.clone(),
1730 parent_thread_entity,
1731 label,
1732 initial_prompt,
1733 timeout,
1734 allowed_tools,
1735 cx,
1736 )
1737 }
1738}
1739
1740#[derive(Debug, Clone, Copy)]
1741enum SubagentInitialPromptResult {
1742 Completed,
1743 Timeout,
1744 Cancelled,
1745}
1746
1747pub struct NativeSubagentHandle {
1748 session_id: acp::SessionId,
1749 parent_thread: WeakEntity<Thread>,
1750 subagent_thread: Entity<Thread>,
1751 acp_thread: Entity<AcpThread>,
1752 wait_for_prompt_to_complete: Shared<Task<SubagentInitialPromptResult>>,
1753}
1754
1755impl SubagentHandle for NativeSubagentHandle {
1756 fn id(&self) -> acp::SessionId {
1757 self.session_id.clone()
1758 }
1759
1760 fn wait_for_summary(&self, summary_prompt: String, cx: &AsyncApp) -> Task<Result<String>> {
1761 let thread = self.subagent_thread.clone();
1762 let acp_thread = self.acp_thread.clone();
1763 let wait_for_prompt = self.wait_for_prompt_to_complete.clone();
1764
1765 let wait_for_summary_task = cx.spawn(async move |cx| {
1766 let timed_out = match wait_for_prompt.await {
1767 SubagentInitialPromptResult::Completed => false,
1768 SubagentInitialPromptResult::Timeout => true,
1769 SubagentInitialPromptResult::Cancelled => return Err(anyhow!("User cancelled")),
1770 };
1771
1772 let summary_prompt = if timed_out {
1773 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1774 format!("{}\n{}", "The time to complete the task was exceeded. Stop with the task and follow the directions below:", summary_prompt)
1775 } else {
1776 summary_prompt
1777 };
1778
1779 let response = acp_thread
1780 .update(cx, |thread, cx| thread.send(vec![summary_prompt.into()], cx))
1781 .await?;
1782
1783 let was_canceled = response.is_some_and(|r| r.stop_reason == acp::StopReason::Cancelled);
1784 if was_canceled {
1785 return Err(anyhow!("User cancelled"));
1786 }
1787
1788 thread.read_with(cx, |thread, _cx| {
1789 thread
1790 .last_message()
1791 .map(|m| m.to_markdown())
1792 .context("No response from subagent")
1793 })
1794 });
1795
1796 let subagent_session_id = self.session_id.clone();
1797 let parent_thread = self.parent_thread.clone();
1798 cx.spawn(async move |cx| {
1799 let result = wait_for_summary_task.await;
1800 parent_thread
1801 .update(cx, |parent_thread, cx| {
1802 parent_thread.unregister_running_subagent(&subagent_session_id, cx)
1803 })
1804 .ok();
1805 result
1806 })
1807 }
1808}
1809
1810pub struct AcpTerminalHandle {
1811 terminal: Entity<acp_thread::Terminal>,
1812 _drop_tx: Option<oneshot::Sender<()>>,
1813}
1814
1815impl TerminalHandle for AcpTerminalHandle {
1816 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1817 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1818 }
1819
1820 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1821 Ok(self
1822 .terminal
1823 .read_with(cx, |term, _cx| term.wait_for_exit()))
1824 }
1825
1826 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1827 Ok(self
1828 .terminal
1829 .read_with(cx, |term, cx| term.current_output(cx)))
1830 }
1831
1832 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1833 cx.update(|cx| {
1834 self.terminal.update(cx, |terminal, cx| {
1835 terminal.kill(cx);
1836 });
1837 });
1838 Ok(())
1839 }
1840
1841 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1842 Ok(self
1843 .terminal
1844 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1845 }
1846}
1847
1848#[cfg(test)]
1849mod internal_tests {
1850 use super::*;
1851 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1852 use fs::FakeFs;
1853 use gpui::TestAppContext;
1854 use indoc::formatdoc;
1855 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1856 use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1857 use serde_json::json;
1858 use settings::SettingsStore;
1859 use util::{path, rel_path::rel_path};
1860
1861 #[gpui::test]
1862 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1863 init_test(cx);
1864 let fs = FakeFs::new(cx.executor());
1865 fs.insert_tree(
1866 "/",
1867 json!({
1868 "a": {}
1869 }),
1870 )
1871 .await;
1872 let project = Project::test(fs.clone(), [], cx).await;
1873 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1874 let agent = NativeAgent::new(
1875 project.clone(),
1876 thread_store,
1877 Templates::new(),
1878 None,
1879 fs.clone(),
1880 &mut cx.to_async(),
1881 )
1882 .await
1883 .unwrap();
1884 agent.read_with(cx, |agent, cx| {
1885 assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1886 });
1887
1888 let worktree = project
1889 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1890 .await
1891 .unwrap();
1892 cx.run_until_parked();
1893 agent.read_with(cx, |agent, cx| {
1894 assert_eq!(
1895 agent.project_context.read(cx).worktrees,
1896 vec![WorktreeContext {
1897 root_name: "a".into(),
1898 abs_path: Path::new("/a").into(),
1899 rules_file: None
1900 }]
1901 )
1902 });
1903
1904 // Creating `/a/.rules` updates the project context.
1905 fs.insert_file("/a/.rules", Vec::new()).await;
1906 cx.run_until_parked();
1907 agent.read_with(cx, |agent, cx| {
1908 let rules_entry = worktree
1909 .read(cx)
1910 .entry_for_path(rel_path(".rules"))
1911 .unwrap();
1912 assert_eq!(
1913 agent.project_context.read(cx).worktrees,
1914 vec![WorktreeContext {
1915 root_name: "a".into(),
1916 abs_path: Path::new("/a").into(),
1917 rules_file: Some(RulesFileContext {
1918 path_in_worktree: rel_path(".rules").into(),
1919 text: "".into(),
1920 project_entry_id: rules_entry.id.to_usize()
1921 })
1922 }]
1923 )
1924 });
1925 }
1926
1927 #[gpui::test]
1928 async fn test_listing_models(cx: &mut TestAppContext) {
1929 init_test(cx);
1930 let fs = FakeFs::new(cx.executor());
1931 fs.insert_tree("/", json!({ "a": {} })).await;
1932 let project = Project::test(fs.clone(), [], cx).await;
1933 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1934 let connection = NativeAgentConnection(
1935 NativeAgent::new(
1936 project.clone(),
1937 thread_store,
1938 Templates::new(),
1939 None,
1940 fs.clone(),
1941 &mut cx.to_async(),
1942 )
1943 .await
1944 .unwrap(),
1945 );
1946
1947 // Create a thread/session
1948 let acp_thread = cx
1949 .update(|cx| {
1950 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
1951 })
1952 .await
1953 .unwrap();
1954
1955 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1956
1957 let models = cx
1958 .update(|cx| {
1959 connection
1960 .model_selector(&session_id)
1961 .unwrap()
1962 .list_models(cx)
1963 })
1964 .await
1965 .unwrap();
1966
1967 let acp_thread::AgentModelList::Grouped(models) = models else {
1968 panic!("Unexpected model group");
1969 };
1970 assert_eq!(
1971 models,
1972 IndexMap::from_iter([(
1973 AgentModelGroupName("Fake".into()),
1974 vec![AgentModelInfo {
1975 id: acp::ModelId::new("fake/fake"),
1976 name: "Fake".into(),
1977 description: None,
1978 icon: Some(acp_thread::AgentModelIcon::Named(
1979 ui::IconName::ZedAssistant
1980 )),
1981 is_latest: false,
1982 cost: None,
1983 }]
1984 )])
1985 );
1986 }
1987
1988 #[gpui::test]
1989 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
1990 init_test(cx);
1991 let fs = FakeFs::new(cx.executor());
1992 fs.create_dir(paths::settings_file().parent().unwrap())
1993 .await
1994 .unwrap();
1995 fs.insert_file(
1996 paths::settings_file(),
1997 json!({
1998 "agent": {
1999 "default_model": {
2000 "provider": "foo",
2001 "model": "bar"
2002 }
2003 }
2004 })
2005 .to_string()
2006 .into_bytes(),
2007 )
2008 .await;
2009 let project = Project::test(fs.clone(), [], cx).await;
2010
2011 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2012
2013 // Create the agent and connection
2014 let agent = NativeAgent::new(
2015 project.clone(),
2016 thread_store,
2017 Templates::new(),
2018 None,
2019 fs.clone(),
2020 &mut cx.to_async(),
2021 )
2022 .await
2023 .unwrap();
2024 let connection = NativeAgentConnection(agent.clone());
2025
2026 // Create a thread/session
2027 let acp_thread = cx
2028 .update(|cx| {
2029 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2030 })
2031 .await
2032 .unwrap();
2033
2034 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2035
2036 // Select a model
2037 let selector = connection.model_selector(&session_id).unwrap();
2038 let model_id = acp::ModelId::new("fake/fake");
2039 cx.update(|cx| selector.select_model(model_id.clone(), cx))
2040 .await
2041 .unwrap();
2042
2043 // Verify the thread has the selected model
2044 agent.read_with(cx, |agent, _| {
2045 let session = agent.sessions.get(&session_id).unwrap();
2046 session.thread.read_with(cx, |thread, _| {
2047 assert_eq!(thread.model().unwrap().id().0, "fake");
2048 });
2049 });
2050
2051 cx.run_until_parked();
2052
2053 // Verify settings file was updated
2054 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2055 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2056
2057 // Check that the agent settings contain the selected model
2058 assert_eq!(
2059 settings_json["agent"]["default_model"]["model"],
2060 json!("fake")
2061 );
2062 assert_eq!(
2063 settings_json["agent"]["default_model"]["provider"],
2064 json!("fake")
2065 );
2066
2067 // Register a thinking model and select it.
2068 cx.update(|cx| {
2069 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2070 "fake-corp",
2071 "fake-thinking",
2072 "Fake Thinking",
2073 true,
2074 ));
2075 let thinking_provider = Arc::new(
2076 FakeLanguageModelProvider::new(
2077 LanguageModelProviderId::from("fake-corp".to_string()),
2078 LanguageModelProviderName::from("Fake Corp".to_string()),
2079 )
2080 .with_models(vec![thinking_model]),
2081 );
2082 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2083 registry.register_provider(thinking_provider, cx);
2084 });
2085 });
2086 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2087
2088 let selector = connection.model_selector(&session_id).unwrap();
2089 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2090 .await
2091 .unwrap();
2092 cx.run_until_parked();
2093
2094 // Verify enable_thinking was written to settings as true.
2095 let settings_content = fs.load(paths::settings_file()).await.unwrap();
2096 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
2097 assert_eq!(
2098 settings_json["agent"]["default_model"]["enable_thinking"],
2099 json!(true),
2100 "selecting a thinking model should persist enable_thinking: true to settings"
2101 );
2102 }
2103
2104 #[gpui::test]
2105 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
2106 init_test(cx);
2107 let fs = FakeFs::new(cx.executor());
2108 fs.create_dir(paths::settings_file().parent().unwrap())
2109 .await
2110 .unwrap();
2111 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
2112 let project = Project::test(fs.clone(), [], cx).await;
2113
2114 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2115 let agent = NativeAgent::new(
2116 project.clone(),
2117 thread_store,
2118 Templates::new(),
2119 None,
2120 fs.clone(),
2121 &mut cx.to_async(),
2122 )
2123 .await
2124 .unwrap();
2125 let connection = NativeAgentConnection(agent.clone());
2126
2127 let acp_thread = cx
2128 .update(|cx| {
2129 Rc::new(connection.clone()).new_session(project.clone(), Path::new("/a"), cx)
2130 })
2131 .await
2132 .unwrap();
2133 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
2134
2135 // Register a second provider with a thinking model.
2136 cx.update(|cx| {
2137 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2138 "fake-corp",
2139 "fake-thinking",
2140 "Fake Thinking",
2141 true,
2142 ));
2143 let thinking_provider = Arc::new(
2144 FakeLanguageModelProvider::new(
2145 LanguageModelProviderId::from("fake-corp".to_string()),
2146 LanguageModelProviderName::from("Fake Corp".to_string()),
2147 )
2148 .with_models(vec![thinking_model]),
2149 );
2150 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2151 registry.register_provider(thinking_provider, cx);
2152 });
2153 });
2154 // Refresh the agent's model list so it picks up the new provider.
2155 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2156
2157 // Thread starts with thinking_enabled = false (the default).
2158 agent.read_with(cx, |agent, _| {
2159 let session = agent.sessions.get(&session_id).unwrap();
2160 session.thread.read_with(cx, |thread, _| {
2161 assert!(!thread.thinking_enabled(), "thinking defaults to false");
2162 });
2163 });
2164
2165 // Select the thinking model via select_model.
2166 let selector = connection.model_selector(&session_id).unwrap();
2167 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2168 .await
2169 .unwrap();
2170
2171 // select_model should have enabled thinking based on the model's supports_thinking().
2172 agent.read_with(cx, |agent, _| {
2173 let session = agent.sessions.get(&session_id).unwrap();
2174 session.thread.read_with(cx, |thread, _| {
2175 assert!(
2176 thread.thinking_enabled(),
2177 "select_model should enable thinking when model supports it"
2178 );
2179 });
2180 });
2181
2182 // Switch back to the non-thinking model.
2183 let selector = connection.model_selector(&session_id).unwrap();
2184 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
2185 .await
2186 .unwrap();
2187
2188 // select_model should have disabled thinking.
2189 agent.read_with(cx, |agent, _| {
2190 let session = agent.sessions.get(&session_id).unwrap();
2191 session.thread.read_with(cx, |thread, _| {
2192 assert!(
2193 !thread.thinking_enabled(),
2194 "select_model should disable thinking when model does not support it"
2195 );
2196 });
2197 });
2198 }
2199
    /// Checks that `thinking_enabled` survives persistence.
    ///
    /// Steps:
    /// 1. Select a thinking-capable fake model.
    /// 2. Send one message so the thread is saved.
    /// 3. Close the session and reopen the thread via `open_thread`.
    ///
    /// The reloaded thread must still report thinking as enabled.
    #[gpui::test]
    async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a thinking model.
        // NOTE(review): the trailing `true` presumably makes the fake report
        // thinking support; confirm against `FakeLanguageModel::with_id_and_thinking`.
        let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "fake-thinking",
            "Fake Thinking",
            true,
        ));
        let thinking_provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![thinking_model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(thinking_provider, cx);
            });
        });
        // Refresh the agent's model list so it sees the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the thinking model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // The model id passed to the selector combines the provider id
        // ("fake-corp") and the model id ("fake-thinking") registered above.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
            .await
            .unwrap();

        // Verify thinking is enabled after selecting the thinking model.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking should be enabled after selecting thinking model"
            );
        });

        // Send a message so the thread gets persisted. The send future is
        // spawned so the test can feed the fake model's completion stream
        // while the request is still in flight.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Answer the pending completion with one text chunk, then end the stream.
        thinking_model.send_last_completion_stream_text_chunk("Response.");
        thinking_model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own entity handles, then confirm that the agent
        // no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify thinking_enabled is still true.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            assert!(
                thread.thinking_enabled(),
                "thinking_enabled should be preserved when reloading a thread with a thinking model"
            );
        });

        // The reloaded handle is held until here, presumably to keep the
        // reopened session alive while the assertions above run.
        drop(reloaded_acp_thread);
    }
2305
    /// Checks that the selected model survives persistence when the model's id
    /// differs from its display name.
    ///
    /// After selecting "custom-model-id", saving the thread, closing the
    /// session and reopening it, the reloaded thread must report the same
    /// model id rather than falling back to a default model.
    #[gpui::test]
    async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        // Register a model where id() != name(), like real Anthropic models
        // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
        let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
            "fake-corp",
            "custom-model-id",
            "Custom Model Display Name",
            false,
        ));
        let provider = Arc::new(
            FakeLanguageModelProvider::new(
                LanguageModelProviderId::from("fake-corp".to_string()),
                LanguageModelProviderName::from("Fake Corp".to_string()),
            )
            .with_models(vec![model.clone()]),
        );
        cx.update(|cx| {
            LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
                registry.register_provider(provider, cx);
            });
        });
        // Refresh the agent's model list so it sees the newly registered provider.
        agent.update(cx, |agent, cx| agent.models.refresh_list(cx));

        // Create a thread and select the model.
        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

        // Selection is by "<provider id>/<model id>"; the display name plays no part.
        let selector = connection.model_selector(&session_id).unwrap();
        cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
            .await
            .unwrap();

        // Sanity check: the in-memory thread uses the selected model before it is saved.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        thread.read_with(cx, |thread, _| {
            assert_eq!(
                thread.model().unwrap().id().0.as_ref(),
                "custom-model-id",
                "model should be set before persisting"
            );
        });

        // Send a message so the thread gets persisted. The send future is
        // spawned so the fake completion stream can be driven while it is pending.
        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Answer the pending completion with one text chunk, then end the stream.
        model.send_last_completion_stream_text_chunk("Response.");
        model.end_last_completion_stream();

        send.await.unwrap();
        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own entity handles, then confirm that the agent
        // no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert!(agent.sessions.is_empty());
        });

        // Reload the thread and verify the model was preserved.
        let reloaded_acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        let reloaded_thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });
        reloaded_thread.read_with(cx, |thread, _| {
            let reloaded_model = thread
                .model()
                .expect("model should be present after reload");
            assert_eq!(
                reloaded_model.id().0.as_ref(),
                "custom-model-id",
                "reloaded thread should have the same model, not fall back to the default"
            );
        });

        // The reloaded handle is held until here, presumably to keep the
        // reopened session alive while the assertions above run.
        drop(reloaded_acp_thread);
    }
2416
    /// End-to-end persistence test for a native agent thread. It checks that:
    ///
    /// - An empty thread is not written to the thread store, even after its
    ///   models are changed.
    /// - A thread with one user/assistant exchange is stored under the title
    ///   produced by the summarization model.
    /// - The stored thread reopens with identical markdown, including the
    ///   `@b.md` file mention.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a prompt containing a resource link to /a/b.md. The expected
        // markdown below renders it as an `[@b.md](...)` mention.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Finish the assistant reply first, then the summarization model's
        // pending completion. The summary text is what is later asserted as
        // the stored thread title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Close the session so it can be reloaded from disk.
        cx.update(|cx| connection.clone().close_session(&session_id, cx))
            .await
            .unwrap();
        // Release the test's own entity handles, then confirm that the agent
        // no longer tracks any session.
        drop(thread);
        drop(acp_thread);
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        // The reopened thread must render exactly the same conversation.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
2555
2556 fn thread_entries(
2557 thread_store: &Entity<ThreadStore>,
2558 cx: &mut TestAppContext,
2559 ) -> Vec<(acp::SessionId, String)> {
2560 thread_store.read_with(cx, |store, _| {
2561 store
2562 .entries()
2563 .map(|entry| (entry.id.clone(), entry.title.to_string()))
2564 .collect::<Vec<_>>()
2565 })
2566 }
2567
2568 fn init_test(cx: &mut TestAppContext) {
2569 env_logger::try_init().ok();
2570 cx.update(|cx| {
2571 let settings_store = SettingsStore::test(cx);
2572 cx.set_global(settings_store);
2573
2574 LanguageModelRegistry::test(cx);
2575 });
2576 }
2577}
2578
2579fn mcp_message_content_to_acp_content_block(
2580 content: context_server::types::MessageContent,
2581) -> acp::ContentBlock {
2582 match content {
2583 context_server::types::MessageContent::Text {
2584 text,
2585 annotations: _,
2586 } => text.into(),
2587 context_server::types::MessageContent::Image {
2588 data,
2589 mime_type,
2590 annotations: _,
2591 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2592 context_server::types::MessageContent::Audio {
2593 data,
2594 mime_type,
2595 annotations: _,
2596 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2597 context_server::types::MessageContent::Resource {
2598 resource,
2599 annotations: _,
2600 } => {
2601 let mut link =
2602 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2603 if let Some(mime_type) = resource.mime_type {
2604 link = link.mime_type(mime_type);
2605 }
2606 acp::ContentBlock::ResourceLink(link)
2607 }
2608 }
2609}