1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17pub use native_agent_server::NativeAgentServer;
18pub use pattern_extraction::*;
19pub use templates::*;
20pub use thread::*;
21pub use thread_store::*;
22pub use tool_permissions::*;
23pub use tools::*;
24
25use acp_thread::{
26 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
27 AgentSessionListResponse, UserMessageId,
28};
29use agent_client_protocol as acp;
30use anyhow::{Context as _, Result, anyhow};
31use chrono::{DateTime, Utc};
32use collections::{HashMap, HashSet, IndexMap};
33use fs::Fs;
34use futures::channel::{mpsc, oneshot};
35use futures::future::Shared;
36use futures::{StreamExt, future};
37use gpui::{
38 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
39};
40use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
41use project::{Project, ProjectItem, ProjectPath, Worktree};
42use prompt_store::{
43 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
44 WorktreeContext,
45};
46use serde::{Deserialize, Serialize};
47use settings::{LanguageModelSelection, update_settings_file};
48use std::any::Any;
49use std::path::{Path, PathBuf};
50use std::rc::Rc;
51use std::sync::Arc;
52use util::ResultExt;
53use util::rel_path::RelPath;
54
/// Serializable point-in-time snapshot of the project's worktrees.
///
/// Each worktree is captured in its telemetry form; `timestamp` records when
/// the snapshot was taken (UTC).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshot of each worktree in the project.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Time at which the snapshot was captured.
    pub timestamp: DateTime<Utc>,
}
60
/// Error raised when a worktree's rules file could not be loaded.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the formatted load error).
    pub message: SharedString,
}
64
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    ///
    /// Held weakly so the session does not keep it alive; the session is
    /// removed from the agent when this `AcpThread` is released.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Most recent background save of this session's thread; replaced on every save.
    pending_save: Task<()>,
    /// Observers and subscriptions held for as long as the session exists.
    _subscriptions: Vec<Subscription>,
}
74
/// Cache of the language models exposed by authenticated, visible providers.
///
/// Models are available both by ID (`models`) and as a grouped list
/// (`model_list`) suitable for a model picker.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver half handed out by `watch()`; signalled after each refresh.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender half; `refresh_list` sends on it after rebuilding the cache.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating every visible provider, held so it is not dropped.
    _authenticate_all_providers_task: Task<()>,
}
84
85impl LanguageModels {
86 fn new(cx: &mut App) -> Self {
87 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
88
89 let mut this = Self {
90 models: HashMap::default(),
91 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
92 refresh_models_rx,
93 refresh_models_tx,
94 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
95 };
96 this.refresh_list(cx);
97 this
98 }
99
100 fn refresh_list(&mut self, cx: &App) {
101 let providers = LanguageModelRegistry::global(cx)
102 .read(cx)
103 .visible_providers()
104 .into_iter()
105 .filter(|provider| provider.is_authenticated(cx))
106 .collect::<Vec<_>>();
107
108 let mut language_model_list = IndexMap::default();
109 let mut recommended_models = HashSet::default();
110
111 let mut recommended = Vec::new();
112 for provider in &providers {
113 for model in provider.recommended_models(cx) {
114 recommended_models.insert((model.provider_id(), model.id()));
115 recommended.push(Self::map_language_model_to_info(&model, provider));
116 }
117 }
118 if !recommended.is_empty() {
119 language_model_list.insert(
120 acp_thread::AgentModelGroupName("Recommended".into()),
121 recommended,
122 );
123 }
124
125 let mut models = HashMap::default();
126 for provider in providers {
127 let mut provider_models = Vec::new();
128 for model in provider.provided_models(cx) {
129 let model_info = Self::map_language_model_to_info(&model, &provider);
130 let model_id = model_info.id.clone();
131 provider_models.push(model_info);
132 models.insert(model_id, model);
133 }
134 if !provider_models.is_empty() {
135 language_model_list.insert(
136 acp_thread::AgentModelGroupName(provider.name().0.clone()),
137 provider_models,
138 );
139 }
140 }
141
142 self.models = models;
143 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
144 self.refresh_models_tx.send(()).ok();
145 }
146
147 fn watch(&self) -> watch::Receiver<()> {
148 self.refresh_models_rx.clone()
149 }
150
151 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
152 self.models.get(model_id).cloned()
153 }
154
155 fn map_language_model_to_info(
156 model: &Arc<dyn LanguageModel>,
157 provider: &Arc<dyn LanguageModelProvider>,
158 ) -> acp_thread::AgentModelInfo {
159 acp_thread::AgentModelInfo {
160 id: Self::model_id(model),
161 name: model.name().0,
162 description: None,
163 icon: Some(match provider.icon() {
164 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
165 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
166 }),
167 is_latest: model.is_latest(),
168 }
169 }
170
171 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
172 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
173 }
174
175 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
176 let authenticate_all_providers = LanguageModelRegistry::global(cx)
177 .read(cx)
178 .visible_providers()
179 .iter()
180 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
181 .collect::<Vec<_>>();
182
183 cx.background_spawn(async move {
184 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
185 if let Err(err) = authenticate_task.await {
186 match err {
187 language_model::AuthenticateError::CredentialsNotFound => {
188 // Since we're authenticating these providers in the
189 // background for the purposes of populating the
190 // language selector, we don't care about providers
191 // where the credentials are not found.
192 }
193 language_model::AuthenticateError::ConnectionRefused => {
194 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
195 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
196 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
197 }
198 _ => {
199 // Some providers have noisy failure states that we
200 // don't want to spam the logs with every time the
201 // language model selector is initialized.
202 //
203 // Ideally these should have more clear failure modes
204 // that we know are safe to ignore here, like what we do
205 // with `CredentialsNotFound` above.
206 match provider_id.0.as_ref() {
207 "lmstudio" | "ollama" => {
208 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
209 //
210 // These fail noisily, so we don't log them.
211 }
212 "copilot_chat" => {
213 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
214 }
215 _ => {
216 log::error!(
217 "Failed to authenticate provider: {}: {err:#}",
218 provider_name.0
219 );
220 }
221 }
222 }
223 }
224 }
225 }
226 })
227 }
228}
229
/// Agent that drives internal [`Thread`]s for a project and exposes each one
/// as an ACP session.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store, reloaded after each thread save.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalled whenever the project context should be rebuilt (worktree
    /// added/removed, rules file changed, prompts updated).
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task that rebuilds the project context on each refresh signal.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context-server (MCP) tools and prompts.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project this agent operates on.
    project: Entity<Project>,
    /// Source of default user rules, if a prompt store is available.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle.
    fs: Arc<dyn Fs>,
    /// Event subscriptions held for the agent's lifetime.
    _subscriptions: Vec<Subscription>,
}
248
249impl NativeAgent {
250 pub async fn new(
251 project: Entity<Project>,
252 thread_store: Entity<ThreadStore>,
253 templates: Arc<Templates>,
254 prompt_store: Option<Entity<PromptStore>>,
255 fs: Arc<dyn Fs>,
256 cx: &mut AsyncApp,
257 ) -> Result<Entity<NativeAgent>> {
258 log::debug!("Creating new NativeAgent");
259
260 let project_context = cx
261 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
262 .await;
263
264 Ok(cx.new(|cx| {
265 let context_server_store = project.read(cx).context_server_store();
266 let context_server_registry =
267 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
268
269 let mut subscriptions = vec![
270 cx.subscribe(&project, Self::handle_project_event),
271 cx.subscribe(
272 &LanguageModelRegistry::global(cx),
273 Self::handle_models_updated_event,
274 ),
275 cx.subscribe(
276 &context_server_store,
277 Self::handle_context_server_store_updated,
278 ),
279 cx.subscribe(
280 &context_server_registry,
281 Self::handle_context_server_registry_event,
282 ),
283 ];
284 if let Some(prompt_store) = prompt_store.as_ref() {
285 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
286 }
287
288 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
289 watch::channel(());
290 Self {
291 sessions: HashMap::default(),
292 thread_store,
293 project_context: cx.new(|_| project_context),
294 project_context_needs_refresh: project_context_needs_refresh_tx,
295 _maintain_project_context: cx.spawn(async move |this, cx| {
296 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
297 }),
298 context_server_registry,
299 templates,
300 models: LanguageModels::new(cx),
301 project,
302 prompt_store,
303 fs,
304 _subscriptions: subscriptions,
305 }
306 }))
307 }
308
309 fn new_session(
310 &mut self,
311 project: Entity<Project>,
312 cx: &mut Context<Self>,
313 ) -> Entity<AcpThread> {
314 // Create Thread
315 // Fetch default model from registry settings
316 let registry = LanguageModelRegistry::read_global(cx);
317 // Log available models for debugging
318 let available_count = registry.available_models(cx).count();
319 log::debug!("Total available models: {}", available_count);
320
321 let default_model = registry.default_model().and_then(|default_model| {
322 self.models
323 .model_from_id(&LanguageModels::model_id(&default_model.model))
324 });
325 let thread = cx.new(|cx| {
326 Thread::new(
327 project.clone(),
328 self.project_context.clone(),
329 self.context_server_registry.clone(),
330 self.templates.clone(),
331 default_model,
332 cx,
333 )
334 });
335
336 self.register_session(thread, cx)
337 }
338
339 fn register_session(
340 &mut self,
341 thread_handle: Entity<Thread>,
342 cx: &mut Context<Self>,
343 ) -> Entity<AcpThread> {
344 let connection = Rc::new(NativeAgentConnection(cx.entity()));
345
346 let thread = thread_handle.read(cx);
347 let session_id = thread.id().clone();
348 let title = thread.title();
349 let project = thread.project.clone();
350 let action_log = thread.action_log.clone();
351 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
352 let acp_thread = cx.new(|cx| {
353 acp_thread::AcpThread::new(
354 title,
355 connection,
356 project.clone(),
357 action_log.clone(),
358 session_id.clone(),
359 prompt_capabilities_rx,
360 cx,
361 )
362 });
363
364 let registry = LanguageModelRegistry::read_global(cx);
365 let summarization_model = registry.thread_summary_model().map(|c| c.model);
366
367 thread_handle.update(cx, |thread, cx| {
368 thread.set_summarization_model(summarization_model, cx);
369 thread.add_default_tools(
370 Rc::new(AcpThreadEnvironment {
371 acp_thread: acp_thread.downgrade(),
372 }) as _,
373 cx,
374 )
375 });
376
377 let subscriptions = vec![
378 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
379 this.sessions.remove(acp_thread.session_id());
380 }),
381 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
382 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
383 cx.observe(&thread_handle, move |this, thread, cx| {
384 this.save_thread(thread, cx)
385 }),
386 ];
387
388 self.sessions.insert(
389 session_id,
390 Session {
391 thread: thread_handle,
392 acp_thread: acp_thread.downgrade(),
393 _subscriptions: subscriptions,
394 pending_save: Task::ready(()),
395 },
396 );
397
398 self.update_available_commands(cx);
399
400 acp_thread
401 }
402
403 pub fn models(&self) -> &LanguageModels {
404 &self.models
405 }
406
407 async fn maintain_project_context(
408 this: WeakEntity<Self>,
409 mut needs_refresh: watch::Receiver<()>,
410 cx: &mut AsyncApp,
411 ) -> Result<()> {
412 while needs_refresh.changed().await.is_ok() {
413 let project_context = this
414 .update(cx, |this, cx| {
415 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
416 })?
417 .await;
418 this.update(cx, |this, cx| {
419 this.project_context = cx.new(|_| project_context);
420 })?;
421 }
422
423 Ok(())
424 }
425
426 fn build_project_context(
427 project: &Entity<Project>,
428 prompt_store: Option<&Entity<PromptStore>>,
429 cx: &mut App,
430 ) -> Task<ProjectContext> {
431 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
432 let worktree_tasks = worktrees
433 .into_iter()
434 .map(|worktree| {
435 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
436 })
437 .collect::<Vec<_>>();
438 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
439 prompt_store.read_with(cx, |prompt_store, cx| {
440 let prompts = prompt_store.default_prompt_metadata();
441 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
442 let contents = prompt_store.load(prompt_metadata.id, cx);
443 async move { (contents.await, prompt_metadata) }
444 });
445 cx.background_spawn(future::join_all(load_tasks))
446 })
447 } else {
448 Task::ready(vec![])
449 };
450
451 cx.spawn(async move |_cx| {
452 let (worktrees, default_user_rules) =
453 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
454
455 let worktrees = worktrees
456 .into_iter()
457 .map(|(worktree, _rules_error)| {
458 // TODO: show error message
459 // if let Some(rules_error) = rules_error {
460 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
461 // }
462 worktree
463 })
464 .collect::<Vec<_>>();
465
466 let default_user_rules = default_user_rules
467 .into_iter()
468 .flat_map(|(contents, prompt_metadata)| match contents {
469 Ok(contents) => Some(UserRulesContext {
470 uuid: prompt_metadata.id.as_user()?,
471 title: prompt_metadata.title.map(|title| title.to_string()),
472 contents,
473 }),
474 Err(_err) => {
475 // TODO: show error message
476 // this.update(cx, |_, cx| {
477 // cx.emit(RulesLoadingError {
478 // message: format!("{err:?}").into(),
479 // });
480 // })
481 // .ok();
482 None
483 }
484 })
485 .collect::<Vec<_>>();
486
487 ProjectContext::new(worktrees, default_user_rules)
488 })
489 }
490
491 fn load_worktree_info_for_system_prompt(
492 worktree: Entity<Worktree>,
493 project: Entity<Project>,
494 cx: &mut App,
495 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
496 let tree = worktree.read(cx);
497 let root_name = tree.root_name_str().into();
498 let abs_path = tree.abs_path();
499
500 let mut context = WorktreeContext {
501 root_name,
502 abs_path,
503 rules_file: None,
504 };
505
506 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
507 let Some(rules_task) = rules_task else {
508 return Task::ready((context, None));
509 };
510
511 cx.spawn(async move |_| {
512 let (rules_file, rules_file_error) = match rules_task.await {
513 Ok(rules_file) => (Some(rules_file), None),
514 Err(err) => (
515 None,
516 Some(RulesLoadingError {
517 message: format!("{err}").into(),
518 }),
519 ),
520 };
521 context.rules_file = rules_file;
522 (context, rules_file_error)
523 })
524 }
525
526 fn load_worktree_rules_file(
527 worktree: Entity<Worktree>,
528 project: Entity<Project>,
529 cx: &mut App,
530 ) -> Option<Task<Result<RulesFileContext>>> {
531 let worktree = worktree.read(cx);
532 let worktree_id = worktree.id();
533 let selected_rules_file = RULES_FILE_NAMES
534 .into_iter()
535 .filter_map(|name| {
536 worktree
537 .entry_for_path(RelPath::unix(name).unwrap())
538 .filter(|entry| entry.is_file())
539 .map(|entry| entry.path.clone())
540 })
541 .next();
542
543 // Note that Cline supports `.clinerules` being a directory, but that is not currently
544 // supported. This doesn't seem to occur often in GitHub repositories.
545 selected_rules_file.map(|path_in_worktree| {
546 let project_path = ProjectPath {
547 worktree_id,
548 path: path_in_worktree.clone(),
549 };
550 let buffer_task =
551 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
552 let rope_task = cx.spawn(async move |cx| {
553 let buffer = buffer_task.await?;
554 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
555 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
556 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
557 })?;
558 anyhow::Ok((project_entry_id, rope))
559 });
560 // Build a string from the rope on a background thread.
561 cx.background_spawn(async move {
562 let (project_entry_id, rope) = rope_task.await?;
563 anyhow::Ok(RulesFileContext {
564 path_in_worktree,
565 text: rope.to_string().trim().to_string(),
566 project_entry_id: project_entry_id.to_usize(),
567 })
568 })
569 })
570 }
571
572 fn handle_thread_title_updated(
573 &mut self,
574 thread: Entity<Thread>,
575 _: &TitleUpdated,
576 cx: &mut Context<Self>,
577 ) {
578 let session_id = thread.read(cx).id();
579 let Some(session) = self.sessions.get(session_id) else {
580 return;
581 };
582 let thread = thread.downgrade();
583 let acp_thread = session.acp_thread.clone();
584 cx.spawn(async move |_, cx| {
585 let title = thread.read_with(cx, |thread, _| thread.title())?;
586 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
587 task.await
588 })
589 .detach_and_log_err(cx);
590 }
591
592 fn handle_thread_token_usage_updated(
593 &mut self,
594 thread: Entity<Thread>,
595 usage: &TokenUsageUpdated,
596 cx: &mut Context<Self>,
597 ) {
598 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
599 return;
600 };
601 session
602 .acp_thread
603 .update(cx, |acp_thread, cx| {
604 acp_thread.update_token_usage(usage.0.clone(), cx);
605 })
606 .ok();
607 }
608
609 fn handle_project_event(
610 &mut self,
611 _project: Entity<Project>,
612 event: &project::Event,
613 _cx: &mut Context<Self>,
614 ) {
615 match event {
616 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
617 self.project_context_needs_refresh.send(()).ok();
618 }
619 project::Event::WorktreeUpdatedEntries(_, items) => {
620 if items.iter().any(|(path, _, _)| {
621 RULES_FILE_NAMES
622 .iter()
623 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
624 }) {
625 self.project_context_needs_refresh.send(()).ok();
626 }
627 }
628 _ => {}
629 }
630 }
631
632 fn handle_prompts_updated_event(
633 &mut self,
634 _prompt_store: Entity<PromptStore>,
635 _event: &prompt_store::PromptsUpdatedEvent,
636 _cx: &mut Context<Self>,
637 ) {
638 self.project_context_needs_refresh.send(()).ok();
639 }
640
641 fn handle_models_updated_event(
642 &mut self,
643 _registry: Entity<LanguageModelRegistry>,
644 _event: &language_model::Event,
645 cx: &mut Context<Self>,
646 ) {
647 self.models.refresh_list(cx);
648
649 let registry = LanguageModelRegistry::read_global(cx);
650 let default_model = registry.default_model().map(|m| m.model);
651 let summarization_model = registry.thread_summary_model().map(|m| m.model);
652
653 for session in self.sessions.values_mut() {
654 session.thread.update(cx, |thread, cx| {
655 if thread.model().is_none()
656 && let Some(model) = default_model.clone()
657 {
658 thread.set_model(model, cx);
659 cx.notify();
660 }
661 thread.set_summarization_model(summarization_model.clone(), cx);
662 });
663 }
664 }
665
666 fn handle_context_server_store_updated(
667 &mut self,
668 _store: Entity<project::context_server_store::ContextServerStore>,
669 _event: &project::context_server_store::ServerStatusChangedEvent,
670 cx: &mut Context<Self>,
671 ) {
672 self.update_available_commands(cx);
673 }
674
675 fn handle_context_server_registry_event(
676 &mut self,
677 _registry: Entity<ContextServerRegistry>,
678 event: &ContextServerRegistryEvent,
679 cx: &mut Context<Self>,
680 ) {
681 match event {
682 ContextServerRegistryEvent::ToolsChanged => {}
683 ContextServerRegistryEvent::PromptsChanged => {
684 self.update_available_commands(cx);
685 }
686 }
687 }
688
689 fn update_available_commands(&self, cx: &mut Context<Self>) {
690 let available_commands = self.build_available_commands(cx);
691 for session in self.sessions.values() {
692 if let Some(acp_thread) = session.acp_thread.upgrade() {
693 acp_thread.update(cx, |thread, cx| {
694 thread
695 .handle_session_update(
696 acp::SessionUpdate::AvailableCommandsUpdate(
697 acp::AvailableCommandsUpdate::new(available_commands.clone()),
698 ),
699 cx,
700 )
701 .log_err();
702 });
703 }
704 }
705 }
706
707 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
708 let registry = self.context_server_registry.read(cx);
709
710 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
711 for context_server_prompt in registry.prompts() {
712 *prompt_name_counts
713 .entry(context_server_prompt.prompt.name.as_str())
714 .or_insert(0) += 1;
715 }
716
717 registry
718 .prompts()
719 .flat_map(|context_server_prompt| {
720 let prompt = &context_server_prompt.prompt;
721
722 let should_prefix = prompt_name_counts
723 .get(prompt.name.as_str())
724 .copied()
725 .unwrap_or(0)
726 > 1;
727
728 let name = if should_prefix {
729 format!("{}.{}", context_server_prompt.server_id, prompt.name)
730 } else {
731 prompt.name.clone()
732 };
733
734 let mut command = acp::AvailableCommand::new(
735 name,
736 prompt.description.clone().unwrap_or_default(),
737 );
738
739 match prompt.arguments.as_deref() {
740 Some([arg]) => {
741 let hint = format!("<{}>", arg.name);
742
743 command = command.input(acp::AvailableCommandInput::Unstructured(
744 acp::UnstructuredCommandInput::new(hint),
745 ));
746 }
747 Some([]) | None => {}
748 Some(_) => {
749 // skip >1 argument commands since we don't support them yet
750 return None;
751 }
752 }
753
754 Some(command)
755 })
756 .collect()
757 }
758
759 pub fn load_thread(
760 &mut self,
761 id: acp::SessionId,
762 cx: &mut Context<Self>,
763 ) -> Task<Result<Entity<Thread>>> {
764 let database_future = ThreadsDatabase::connect(cx);
765 cx.spawn(async move |this, cx| {
766 let database = database_future.await.map_err(|err| anyhow!(err))?;
767 let db_thread = database
768 .load_thread(id.clone())
769 .await?
770 .with_context(|| format!("no thread found with ID: {id:?}"))?;
771
772 this.update(cx, |this, cx| {
773 let summarization_model = LanguageModelRegistry::read_global(cx)
774 .thread_summary_model()
775 .map(|c| c.model);
776
777 cx.new(|cx| {
778 let mut thread = Thread::from_db(
779 id.clone(),
780 db_thread,
781 this.project.clone(),
782 this.project_context.clone(),
783 this.context_server_registry.clone(),
784 this.templates.clone(),
785 cx,
786 );
787 thread.set_summarization_model(summarization_model, cx);
788 thread
789 })
790 })
791 })
792 }
793
794 pub fn open_thread(
795 &mut self,
796 id: acp::SessionId,
797 cx: &mut Context<Self>,
798 ) -> Task<Result<Entity<AcpThread>>> {
799 let task = self.load_thread(id, cx);
800 cx.spawn(async move |this, cx| {
801 let thread = task.await?;
802 let acp_thread =
803 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
804 let events = thread.update(cx, |thread, cx| thread.replay(cx));
805 cx.update(|cx| {
806 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
807 })
808 .await?;
809 Ok(acp_thread)
810 })
811 }
812
813 pub fn thread_summary(
814 &mut self,
815 id: acp::SessionId,
816 cx: &mut Context<Self>,
817 ) -> Task<Result<SharedString>> {
818 let thread = self.open_thread(id.clone(), cx);
819 cx.spawn(async move |this, cx| {
820 let acp_thread = thread.await?;
821 let result = this
822 .update(cx, |this, cx| {
823 this.sessions
824 .get(&id)
825 .unwrap()
826 .thread
827 .update(cx, |thread, cx| thread.summary(cx))
828 })?
829 .await
830 .context("Failed to generate summary")?;
831 drop(acp_thread);
832 Ok(result)
833 })
834 }
835
836 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
837 if thread.read(cx).is_empty() {
838 return;
839 }
840
841 let database_future = ThreadsDatabase::connect(cx);
842 let (id, db_thread) =
843 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
844 let Some(session) = self.sessions.get_mut(&id) else {
845 return;
846 };
847 let thread_store = self.thread_store.clone();
848 session.pending_save = cx.spawn(async move |_, cx| {
849 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
850 return;
851 };
852 let db_thread = db_thread.await;
853 database.save_thread(id, db_thread).await.log_err();
854 thread_store.update(cx, |store, cx| store.reload(cx));
855 });
856 }
857
858 fn send_mcp_prompt(
859 &self,
860 message_id: UserMessageId,
861 session_id: agent_client_protocol::SessionId,
862 prompt_name: String,
863 server_id: ContextServerId,
864 arguments: HashMap<String, String>,
865 original_content: Vec<acp::ContentBlock>,
866 cx: &mut Context<Self>,
867 ) -> Task<Result<acp::PromptResponse>> {
868 let server_store = self.context_server_registry.read(cx).server_store().clone();
869 let path_style = self.project.read(cx).path_style(cx);
870
871 cx.spawn(async move |this, cx| {
872 let prompt =
873 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
874
875 let (acp_thread, thread) = this.update(cx, |this, _cx| {
876 let session = this
877 .sessions
878 .get(&session_id)
879 .context("Failed to get session")?;
880 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
881 })??;
882
883 let mut last_is_user = true;
884
885 thread.update(cx, |thread, cx| {
886 thread.push_acp_user_block(
887 message_id,
888 original_content.into_iter().skip(1),
889 path_style,
890 cx,
891 );
892 });
893
894 for message in prompt.messages {
895 let context_server::types::PromptMessage { role, content } = message;
896 let block = mcp_message_content_to_acp_content_block(content);
897
898 match role {
899 context_server::types::Role::User => {
900 let id = acp_thread::UserMessageId::new();
901
902 acp_thread.update(cx, |acp_thread, cx| {
903 acp_thread.push_user_content_block_with_indent(
904 Some(id.clone()),
905 block.clone(),
906 true,
907 cx,
908 );
909 })?;
910
911 thread.update(cx, |thread, cx| {
912 thread.push_acp_user_block(id, [block], path_style, cx);
913 });
914 }
915 context_server::types::Role::Assistant => {
916 acp_thread.update(cx, |acp_thread, cx| {
917 acp_thread.push_assistant_content_block_with_indent(
918 block.clone(),
919 false,
920 true,
921 cx,
922 );
923 })?;
924
925 thread.update(cx, |thread, cx| {
926 thread.push_acp_agent_block(block, cx);
927 });
928 }
929 }
930
931 last_is_user = role == context_server::types::Role::User;
932 }
933
934 let response_stream = thread.update(cx, |thread, cx| {
935 if last_is_user {
936 thread.send_existing(cx)
937 } else {
938 // Resume if MCP prompt did not end with a user message
939 thread.resume(cx)
940 }
941 })?;
942
943 cx.update(|cx| {
944 NativeAgentConnection::handle_thread_events(response_stream, acp_thread, cx)
945 })
946 .await
947 })
948 }
949}
950
/// Wrapper struct that implements the AgentConnection trait
///
/// Wraps the shared [`NativeAgent`] entity; each session's `AcpThread` is
/// given an `Rc` of one of these when the session is registered.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
954
955impl NativeAgentConnection {
956 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
957 self.0
958 .read(cx)
959 .sessions
960 .get(session_id)
961 .map(|session| session.thread.clone())
962 }
963
964 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
965 self.0.update(cx, |this, cx| this.load_thread(id, cx))
966 }
967
968 fn run_turn(
969 &self,
970 session_id: acp::SessionId,
971 cx: &mut App,
972 f: impl 'static
973 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
974 ) -> Task<Result<acp::PromptResponse>> {
975 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
976 agent
977 .sessions
978 .get_mut(&session_id)
979 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
980 }) else {
981 return Task::ready(Err(anyhow!("Session not found")));
982 };
983 log::debug!("Found session for: {}", session_id);
984
985 let response_stream = match f(thread, cx) {
986 Ok(stream) => stream,
987 Err(err) => return Task::ready(Err(err)),
988 };
989 Self::handle_thread_events(response_stream, acp_thread, cx)
990 }
991
992 fn handle_thread_events(
993 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
994 acp_thread: WeakEntity<AcpThread>,
995 cx: &App,
996 ) -> Task<Result<acp::PromptResponse>> {
997 cx.spawn(async move |cx| {
998 // Handle response stream and forward to session.acp_thread
999 while let Some(result) = events.next().await {
1000 match result {
1001 Ok(event) => {
1002 log::trace!("Received completion event: {:?}", event);
1003
1004 match event {
1005 ThreadEvent::UserMessage(message) => {
1006 acp_thread.update(cx, |thread, cx| {
1007 for content in message.content {
1008 thread.push_user_content_block(
1009 Some(message.id.clone()),
1010 content.into(),
1011 cx,
1012 );
1013 }
1014 })?;
1015 }
1016 ThreadEvent::AgentText(text) => {
1017 acp_thread.update(cx, |thread, cx| {
1018 thread.push_assistant_content_block(text.into(), false, cx)
1019 })?;
1020 }
1021 ThreadEvent::AgentThinking(text) => {
1022 acp_thread.update(cx, |thread, cx| {
1023 thread.push_assistant_content_block(text.into(), true, cx)
1024 })?;
1025 }
1026 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1027 tool_call,
1028 options,
1029 response,
1030 context: _,
1031 }) => {
1032 let outcome_task = acp_thread.update(cx, |thread, cx| {
1033 thread.request_tool_call_authorization(
1034 tool_call, options, true, cx,
1035 )
1036 })??;
1037 cx.background_spawn(async move {
1038 if let acp::RequestPermissionOutcome::Selected(
1039 acp::SelectedPermissionOutcome { option_id, .. },
1040 ) = outcome_task.await
1041 {
1042 response
1043 .send(option_id)
1044 .map(|_| anyhow!("authorization receiver was dropped"))
1045 .log_err();
1046 }
1047 })
1048 .detach();
1049 }
1050 ThreadEvent::ToolCall(tool_call) => {
1051 acp_thread.update(cx, |thread, cx| {
1052 thread.upsert_tool_call(tool_call, cx)
1053 })??;
1054 }
1055 ThreadEvent::ToolCallUpdate(update) => {
1056 acp_thread.update(cx, |thread, cx| {
1057 thread.update_tool_call(update, cx)
1058 })??;
1059 }
1060 ThreadEvent::Retry(status) => {
1061 acp_thread.update(cx, |thread, cx| {
1062 thread.update_retry_status(status, cx)
1063 })?;
1064 }
1065 ThreadEvent::Stop(stop_reason) => {
1066 log::debug!("Assistant message complete: {:?}", stop_reason);
1067 return Ok(acp::PromptResponse::new(stop_reason));
1068 }
1069 }
1070 }
1071 Err(e) => {
1072 log::error!("Error in model response stream: {:?}", e);
1073 return Err(e);
1074 }
1075 }
1076 }
1077
1078 log::debug!("Response stream completed");
1079 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1080 })
1081 }
1082}
1083
/// Slash command parsed from the first text block of a prompt, borrowing
/// from that text.
///
/// Accepted shapes are `/name args` and `/server.name args`.
struct Command<'a> {
    /// Prompt name (the part after the first `.`, if a server prefix was given).
    prompt_name: &'a str,
    /// Everything after the first whitespace following the command; empty if none.
    arg_value: &'a str,
    /// Server prefix before the first `.` in the command, if present.
    explicit_server_id: Option<&'a str>,
}
1089
1090impl<'a> Command<'a> {
1091 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1092 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1093 return None;
1094 };
1095 let text = text_content.text.trim();
1096 let command = text.strip_prefix('/')?;
1097 let (command, arg_value) = command
1098 .split_once(char::is_whitespace)
1099 .unwrap_or((command, ""));
1100
1101 if let Some((server_id, prompt_name)) = command.split_once('.') {
1102 Some(Self {
1103 prompt_name,
1104 arg_value,
1105 explicit_server_id: Some(server_id),
1106 })
1107 } else {
1108 Some(Self {
1109 prompt_name: command,
1110 arg_value,
1111 explicit_server_id: None,
1112 })
1113 }
1114 }
1115}
1116
1117struct NativeAgentModelSelector {
1118 session_id: acp::SessionId,
1119 connection: NativeAgentConnection,
1120}
1121
1122impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1123 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1124 log::debug!("NativeAgentConnection::list_models called");
1125 let list = self.connection.0.read(cx).models.model_list.clone();
1126 Task::ready(if list.is_empty() {
1127 Err(anyhow::anyhow!("No models available"))
1128 } else {
1129 Ok(list)
1130 })
1131 }
1132
1133 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1134 log::debug!(
1135 "Setting model for session {}: {}",
1136 self.session_id,
1137 model_id
1138 );
1139 let Some(thread) = self
1140 .connection
1141 .0
1142 .read(cx)
1143 .sessions
1144 .get(&self.session_id)
1145 .map(|session| session.thread.clone())
1146 else {
1147 return Task::ready(Err(anyhow!("Session not found")));
1148 };
1149
1150 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1151 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1152 };
1153
1154 // We want to reset the effort level when switching models, as the currently-selected effort level may
1155 // not be compatible.
1156 let effort = model
1157 .default_effort_level()
1158 .map(|effort_level| effort_level.value.to_string());
1159
1160 thread.update(cx, |thread, cx| {
1161 thread.set_model(model.clone(), cx);
1162 thread.set_thinking_effort(effort.clone(), cx);
1163 thread.set_thinking_enabled(model.supports_thinking(), cx);
1164 });
1165
1166 update_settings_file(
1167 self.connection.0.read(cx).fs.clone(),
1168 cx,
1169 move |settings, cx| {
1170 let provider = model.provider_id().0.to_string();
1171 let model = model.id().0.to_string();
1172 let enable_thinking = thread.read(cx).thinking_enabled();
1173 settings
1174 .agent
1175 .get_or_insert_default()
1176 .set_model(LanguageModelSelection {
1177 provider: provider.into(),
1178 model,
1179 enable_thinking,
1180 effort,
1181 });
1182 },
1183 );
1184
1185 Task::ready(Ok(()))
1186 }
1187
1188 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1189 let Some(thread) = self
1190 .connection
1191 .0
1192 .read(cx)
1193 .sessions
1194 .get(&self.session_id)
1195 .map(|session| session.thread.clone())
1196 else {
1197 return Task::ready(Err(anyhow!("Session not found")));
1198 };
1199 let Some(model) = thread.read(cx).model() else {
1200 return Task::ready(Err(anyhow!("Model not found")));
1201 };
1202 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1203 else {
1204 return Task::ready(Err(anyhow!("Provider not found")));
1205 };
1206 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1207 model, &provider,
1208 )))
1209 }
1210
1211 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1212 Some(self.connection.0.read(cx).models.watch())
1213 }
1214
1215 fn should_render_footer(&self) -> bool {
1216 true
1217 }
1218}
1219
1220impl acp_thread::AgentConnection for NativeAgentConnection {
1221 fn telemetry_id(&self) -> SharedString {
1222 "zed".into()
1223 }
1224
1225 fn new_thread(
1226 self: Rc<Self>,
1227 project: Entity<Project>,
1228 cwd: &Path,
1229 cx: &mut App,
1230 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1231 log::debug!("Creating new thread for project at: {cwd:?}");
1232 Task::ready(Ok(self
1233 .0
1234 .update(cx, |agent, cx| agent.new_session(project, cx))))
1235 }
1236
1237 fn supports_load_session(&self, _cx: &App) -> bool {
1238 true
1239 }
1240
1241 fn load_session(
1242 self: Rc<Self>,
1243 session: AgentSessionInfo,
1244 _project: Entity<Project>,
1245 _cwd: &Path,
1246 cx: &mut App,
1247 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1248 self.0
1249 .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1250 }
1251
1252 fn auth_methods(&self) -> &[acp::AuthMethod] {
1253 &[] // No auth for in-process
1254 }
1255
1256 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1257 Task::ready(Ok(()))
1258 }
1259
1260 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1261 Some(Rc::new(NativeAgentModelSelector {
1262 session_id: session_id.clone(),
1263 connection: self.clone(),
1264 }) as Rc<dyn AgentModelSelector>)
1265 }
1266
1267 fn prompt(
1268 &self,
1269 id: Option<acp_thread::UserMessageId>,
1270 params: acp::PromptRequest,
1271 cx: &mut App,
1272 ) -> Task<Result<acp::PromptResponse>> {
1273 let id = id.expect("UserMessageId is required");
1274 let session_id = params.session_id.clone();
1275 log::info!("Received prompt request for session: {}", session_id);
1276 log::debug!("Prompt blocks count: {}", params.prompt.len());
1277
1278 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1279 let registry = self.0.read(cx).context_server_registry.read(cx);
1280
1281 let explicit_server_id = parsed_command
1282 .explicit_server_id
1283 .map(|server_id| ContextServerId(server_id.into()));
1284
1285 if let Some(prompt) =
1286 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1287 {
1288 let arguments = if !parsed_command.arg_value.is_empty()
1289 && let Some(arg_name) = prompt
1290 .prompt
1291 .arguments
1292 .as_ref()
1293 .and_then(|args| args.first())
1294 .map(|arg| arg.name.clone())
1295 {
1296 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1297 } else {
1298 Default::default()
1299 };
1300
1301 let prompt_name = prompt.prompt.name.clone();
1302 let server_id = prompt.server_id.clone();
1303
1304 return self.0.update(cx, |agent, cx| {
1305 agent.send_mcp_prompt(
1306 id,
1307 session_id.clone(),
1308 prompt_name,
1309 server_id,
1310 arguments,
1311 params.prompt,
1312 cx,
1313 )
1314 });
1315 };
1316 };
1317
1318 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1319
1320 self.run_turn(session_id, cx, move |thread, cx| {
1321 let content: Vec<UserMessageContent> = params
1322 .prompt
1323 .into_iter()
1324 .map(|block| UserMessageContent::from_content_block(block, path_style))
1325 .collect::<Vec<_>>();
1326 log::debug!("Converted prompt to message: {} chars", content.len());
1327 log::debug!("Message id: {:?}", id);
1328 log::debug!("Message content: {:?}", content);
1329
1330 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1331 })
1332 }
1333
1334 fn retry(
1335 &self,
1336 session_id: &acp::SessionId,
1337 _cx: &App,
1338 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1339 Some(Rc::new(NativeAgentSessionRetry {
1340 connection: self.clone(),
1341 session_id: session_id.clone(),
1342 }) as _)
1343 }
1344
1345 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1346 log::info!("Cancelling on session: {}", session_id);
1347 self.0.update(cx, |agent, cx| {
1348 if let Some(agent) = agent.sessions.get(session_id) {
1349 agent
1350 .thread
1351 .update(cx, |thread, cx| thread.cancel(cx))
1352 .detach();
1353 }
1354 });
1355 }
1356
1357 fn truncate(
1358 &self,
1359 session_id: &agent_client_protocol::SessionId,
1360 cx: &App,
1361 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1362 self.0.read_with(cx, |agent, _cx| {
1363 agent.sessions.get(session_id).map(|session| {
1364 Rc::new(NativeAgentSessionTruncate {
1365 thread: session.thread.clone(),
1366 acp_thread: session.acp_thread.clone(),
1367 }) as _
1368 })
1369 })
1370 }
1371
1372 fn set_title(
1373 &self,
1374 session_id: &acp::SessionId,
1375 _cx: &App,
1376 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1377 Some(Rc::new(NativeAgentSessionSetTitle {
1378 connection: self.clone(),
1379 session_id: session_id.clone(),
1380 }) as _)
1381 }
1382
1383 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1384 let thread_store = self.0.read(cx).thread_store.clone();
1385 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1386 }
1387
1388 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1389 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1390 }
1391
1392 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1393 self
1394 }
1395}
1396
1397impl acp_thread::AgentTelemetry for NativeAgentConnection {
1398 fn thread_data(
1399 &self,
1400 session_id: &acp::SessionId,
1401 cx: &mut App,
1402 ) -> Task<Result<serde_json::Value>> {
1403 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1404 return Task::ready(Err(anyhow!("Session not found")));
1405 };
1406
1407 let task = session.thread.read(cx).to_db(cx);
1408 cx.background_spawn(async move {
1409 serde_json::to_value(task.await).context("Failed to serialize thread")
1410 })
1411 }
1412}
1413
1414pub struct NativeAgentSessionList {
1415 thread_store: Entity<ThreadStore>,
1416 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1417 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1418 _subscription: Subscription,
1419}
1420
1421impl NativeAgentSessionList {
1422 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1423 let (tx, rx) = smol::channel::unbounded();
1424 let this_tx = tx.clone();
1425 let subscription = cx.observe(&thread_store, move |_, _| {
1426 this_tx
1427 .try_send(acp_thread::SessionListUpdate::Refresh)
1428 .ok();
1429 });
1430 Self {
1431 thread_store,
1432 updates_tx: tx,
1433 updates_rx: rx,
1434 _subscription: subscription,
1435 }
1436 }
1437
1438 fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1439 AgentSessionInfo {
1440 session_id: entry.id,
1441 cwd: None,
1442 title: Some(entry.title),
1443 updated_at: Some(entry.updated_at),
1444 meta: None,
1445 }
1446 }
1447
1448 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1449 &self.thread_store
1450 }
1451}
1452
1453impl AgentSessionList for NativeAgentSessionList {
1454 fn list_sessions(
1455 &self,
1456 _request: AgentSessionListRequest,
1457 cx: &mut App,
1458 ) -> Task<Result<AgentSessionListResponse>> {
1459 let sessions = self
1460 .thread_store
1461 .read(cx)
1462 .entries()
1463 .map(Self::to_session_info)
1464 .collect();
1465 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1466 }
1467
1468 fn supports_delete(&self) -> bool {
1469 true
1470 }
1471
1472 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1473 self.thread_store
1474 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1475 }
1476
1477 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1478 self.thread_store
1479 .update(cx, |store, cx| store.delete_threads(cx))
1480 }
1481
1482 fn watch(
1483 &self,
1484 _cx: &mut App,
1485 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1486 Some(self.updates_rx.clone())
1487 }
1488
1489 fn notify_refresh(&self) {
1490 self.updates_tx
1491 .try_send(acp_thread::SessionListUpdate::Refresh)
1492 .ok();
1493 }
1494
1495 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1496 self
1497 }
1498}
1499
1500struct NativeAgentSessionTruncate {
1501 thread: Entity<Thread>,
1502 acp_thread: WeakEntity<AcpThread>,
1503}
1504
1505impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1506 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1507 match self.thread.update(cx, |thread, cx| {
1508 thread.truncate(message_id.clone(), cx)?;
1509 Ok(thread.latest_token_usage())
1510 }) {
1511 Ok(usage) => {
1512 self.acp_thread
1513 .update(cx, |thread, cx| {
1514 thread.update_token_usage(usage, cx);
1515 })
1516 .ok();
1517 Task::ready(Ok(()))
1518 }
1519 Err(error) => Task::ready(Err(error)),
1520 }
1521 }
1522}
1523
1524struct NativeAgentSessionRetry {
1525 connection: NativeAgentConnection,
1526 session_id: acp::SessionId,
1527}
1528
1529impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1530 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1531 self.connection
1532 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1533 thread.update(cx, |thread, cx| thread.resume(cx))
1534 })
1535 }
1536}
1537
1538struct NativeAgentSessionSetTitle {
1539 connection: NativeAgentConnection,
1540 session_id: acp::SessionId,
1541}
1542
1543impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1544 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1545 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1546 return Task::ready(Err(anyhow!("session not found")));
1547 };
1548 let thread = session.thread.clone();
1549 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1550 Task::ready(Ok(()))
1551 }
1552}
1553
1554pub struct AcpThreadEnvironment {
1555 acp_thread: WeakEntity<AcpThread>,
1556}
1557
1558impl ThreadEnvironment for AcpThreadEnvironment {
1559 fn create_terminal(
1560 &self,
1561 command: String,
1562 cwd: Option<PathBuf>,
1563 output_byte_limit: Option<u64>,
1564 cx: &mut AsyncApp,
1565 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1566 let task = self.acp_thread.update(cx, |thread, cx| {
1567 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1568 });
1569
1570 let acp_thread = self.acp_thread.clone();
1571 cx.spawn(async move |cx| {
1572 let terminal = task?.await?;
1573
1574 let (drop_tx, drop_rx) = oneshot::channel();
1575 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1576
1577 cx.spawn(async move |cx| {
1578 drop_rx.await.ok();
1579 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1580 })
1581 .detach();
1582
1583 let handle = AcpTerminalHandle {
1584 terminal,
1585 _drop_tx: Some(drop_tx),
1586 };
1587
1588 Ok(Rc::new(handle) as _)
1589 })
1590 }
1591}
1592
1593pub struct AcpTerminalHandle {
1594 terminal: Entity<acp_thread::Terminal>,
1595 _drop_tx: Option<oneshot::Sender<()>>,
1596}
1597
1598impl TerminalHandle for AcpTerminalHandle {
1599 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1600 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1601 }
1602
1603 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1604 Ok(self
1605 .terminal
1606 .read_with(cx, |term, _cx| term.wait_for_exit()))
1607 }
1608
1609 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1610 Ok(self
1611 .terminal
1612 .read_with(cx, |term, cx| term.current_output(cx)))
1613 }
1614
1615 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1616 cx.update(|cx| {
1617 self.terminal.update(cx, |terminal, cx| {
1618 terminal.kill(cx);
1619 });
1620 });
1621 Ok(())
1622 }
1623
1624 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1625 Ok(self
1626 .terminal
1627 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1628 }
1629}
1630
1631#[cfg(test)]
1632mod internal_tests {
1633 use super::*;
1634 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1635 use fs::FakeFs;
1636 use gpui::TestAppContext;
1637 use indoc::formatdoc;
1638 use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
1639 use language_model::{LanguageModelProviderId, LanguageModelProviderName};
1640 use serde_json::json;
1641 use settings::SettingsStore;
1642 use util::{path, rel_path::rel_path};
1643
1644 #[gpui::test]
1645 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1646 init_test(cx);
1647 let fs = FakeFs::new(cx.executor());
1648 fs.insert_tree(
1649 "/",
1650 json!({
1651 "a": {}
1652 }),
1653 )
1654 .await;
1655 let project = Project::test(fs.clone(), [], cx).await;
1656 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1657 let agent = NativeAgent::new(
1658 project.clone(),
1659 thread_store,
1660 Templates::new(),
1661 None,
1662 fs.clone(),
1663 &mut cx.to_async(),
1664 )
1665 .await
1666 .unwrap();
1667 agent.read_with(cx, |agent, cx| {
1668 assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1669 });
1670
1671 let worktree = project
1672 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1673 .await
1674 .unwrap();
1675 cx.run_until_parked();
1676 agent.read_with(cx, |agent, cx| {
1677 assert_eq!(
1678 agent.project_context.read(cx).worktrees,
1679 vec![WorktreeContext {
1680 root_name: "a".into(),
1681 abs_path: Path::new("/a").into(),
1682 rules_file: None
1683 }]
1684 )
1685 });
1686
1687 // Creating `/a/.rules` updates the project context.
1688 fs.insert_file("/a/.rules", Vec::new()).await;
1689 cx.run_until_parked();
1690 agent.read_with(cx, |agent, cx| {
1691 let rules_entry = worktree
1692 .read(cx)
1693 .entry_for_path(rel_path(".rules"))
1694 .unwrap();
1695 assert_eq!(
1696 agent.project_context.read(cx).worktrees,
1697 vec![WorktreeContext {
1698 root_name: "a".into(),
1699 abs_path: Path::new("/a").into(),
1700 rules_file: Some(RulesFileContext {
1701 path_in_worktree: rel_path(".rules").into(),
1702 text: "".into(),
1703 project_entry_id: rules_entry.id.to_usize()
1704 })
1705 }]
1706 )
1707 });
1708 }
1709
1710 #[gpui::test]
1711 async fn test_listing_models(cx: &mut TestAppContext) {
1712 init_test(cx);
1713 let fs = FakeFs::new(cx.executor());
1714 fs.insert_tree("/", json!({ "a": {} })).await;
1715 let project = Project::test(fs.clone(), [], cx).await;
1716 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1717 let connection = NativeAgentConnection(
1718 NativeAgent::new(
1719 project.clone(),
1720 thread_store,
1721 Templates::new(),
1722 None,
1723 fs.clone(),
1724 &mut cx.to_async(),
1725 )
1726 .await
1727 .unwrap(),
1728 );
1729
1730 // Create a thread/session
1731 let acp_thread = cx
1732 .update(|cx| {
1733 Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
1734 })
1735 .await
1736 .unwrap();
1737
1738 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1739
1740 let models = cx
1741 .update(|cx| {
1742 connection
1743 .model_selector(&session_id)
1744 .unwrap()
1745 .list_models(cx)
1746 })
1747 .await
1748 .unwrap();
1749
1750 let acp_thread::AgentModelList::Grouped(models) = models else {
1751 panic!("Unexpected model group");
1752 };
1753 assert_eq!(
1754 models,
1755 IndexMap::from_iter([(
1756 AgentModelGroupName("Fake".into()),
1757 vec![AgentModelInfo {
1758 id: acp::ModelId::new("fake/fake"),
1759 name: "Fake".into(),
1760 description: None,
1761 icon: Some(acp_thread::AgentModelIcon::Named(
1762 ui::IconName::ZedAssistant
1763 )),
1764 is_latest: false,
1765 }]
1766 )])
1767 );
1768 }
1769
1770 #[gpui::test]
1771 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
1772 init_test(cx);
1773 let fs = FakeFs::new(cx.executor());
1774 fs.create_dir(paths::settings_file().parent().unwrap())
1775 .await
1776 .unwrap();
1777 fs.insert_file(
1778 paths::settings_file(),
1779 json!({
1780 "agent": {
1781 "default_model": {
1782 "provider": "foo",
1783 "model": "bar"
1784 }
1785 }
1786 })
1787 .to_string()
1788 .into_bytes(),
1789 )
1790 .await;
1791 let project = Project::test(fs.clone(), [], cx).await;
1792
1793 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1794
1795 // Create the agent and connection
1796 let agent = NativeAgent::new(
1797 project.clone(),
1798 thread_store,
1799 Templates::new(),
1800 None,
1801 fs.clone(),
1802 &mut cx.to_async(),
1803 )
1804 .await
1805 .unwrap();
1806 let connection = NativeAgentConnection(agent.clone());
1807
1808 // Create a thread/session
1809 let acp_thread = cx
1810 .update(|cx| {
1811 Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
1812 })
1813 .await
1814 .unwrap();
1815
1816 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1817
1818 // Select a model
1819 let selector = connection.model_selector(&session_id).unwrap();
1820 let model_id = acp::ModelId::new("fake/fake");
1821 cx.update(|cx| selector.select_model(model_id.clone(), cx))
1822 .await
1823 .unwrap();
1824
1825 // Verify the thread has the selected model
1826 agent.read_with(cx, |agent, _| {
1827 let session = agent.sessions.get(&session_id).unwrap();
1828 session.thread.read_with(cx, |thread, _| {
1829 assert_eq!(thread.model().unwrap().id().0, "fake");
1830 });
1831 });
1832
1833 cx.run_until_parked();
1834
1835 // Verify settings file was updated
1836 let settings_content = fs.load(paths::settings_file()).await.unwrap();
1837 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
1838
1839 // Check that the agent settings contain the selected model
1840 assert_eq!(
1841 settings_json["agent"]["default_model"]["model"],
1842 json!("fake")
1843 );
1844 assert_eq!(
1845 settings_json["agent"]["default_model"]["provider"],
1846 json!("fake")
1847 );
1848
1849 // Register a thinking model and select it.
1850 cx.update(|cx| {
1851 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
1852 "fake-corp",
1853 "fake-thinking",
1854 "Fake Thinking",
1855 true,
1856 ));
1857 let thinking_provider = Arc::new(
1858 FakeLanguageModelProvider::new(
1859 LanguageModelProviderId::from("fake-corp".to_string()),
1860 LanguageModelProviderName::from("Fake Corp".to_string()),
1861 )
1862 .with_models(vec![thinking_model]),
1863 );
1864 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
1865 registry.register_provider(thinking_provider, cx);
1866 });
1867 });
1868 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
1869
1870 let selector = connection.model_selector(&session_id).unwrap();
1871 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
1872 .await
1873 .unwrap();
1874 cx.run_until_parked();
1875
1876 // Verify enable_thinking was written to settings as true.
1877 let settings_content = fs.load(paths::settings_file()).await.unwrap();
1878 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
1879 assert_eq!(
1880 settings_json["agent"]["default_model"]["enable_thinking"],
1881 json!(true),
1882 "selecting a thinking model should persist enable_thinking: true to settings"
1883 );
1884 }
1885
1886 #[gpui::test]
1887 async fn test_select_model_updates_thinking_enabled(cx: &mut TestAppContext) {
1888 init_test(cx);
1889 let fs = FakeFs::new(cx.executor());
1890 fs.create_dir(paths::settings_file().parent().unwrap())
1891 .await
1892 .unwrap();
1893 fs.insert_file(paths::settings_file(), b"{}".to_vec()).await;
1894 let project = Project::test(fs.clone(), [], cx).await;
1895
1896 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1897 let agent = NativeAgent::new(
1898 project.clone(),
1899 thread_store,
1900 Templates::new(),
1901 None,
1902 fs.clone(),
1903 &mut cx.to_async(),
1904 )
1905 .await
1906 .unwrap();
1907 let connection = NativeAgentConnection(agent.clone());
1908
1909 let acp_thread = cx
1910 .update(|cx| {
1911 Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
1912 })
1913 .await
1914 .unwrap();
1915 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1916
1917 // Register a second provider with a thinking model.
1918 cx.update(|cx| {
1919 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
1920 "fake-corp",
1921 "fake-thinking",
1922 "Fake Thinking",
1923 true,
1924 ));
1925 let thinking_provider = Arc::new(
1926 FakeLanguageModelProvider::new(
1927 LanguageModelProviderId::from("fake-corp".to_string()),
1928 LanguageModelProviderName::from("Fake Corp".to_string()),
1929 )
1930 .with_models(vec![thinking_model]),
1931 );
1932 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
1933 registry.register_provider(thinking_provider, cx);
1934 });
1935 });
1936 // Refresh the agent's model list so it picks up the new provider.
1937 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
1938
1939 // Thread starts with thinking_enabled = false (the default).
1940 agent.read_with(cx, |agent, _| {
1941 let session = agent.sessions.get(&session_id).unwrap();
1942 session.thread.read_with(cx, |thread, _| {
1943 assert!(!thread.thinking_enabled(), "thinking defaults to false");
1944 });
1945 });
1946
1947 // Select the thinking model via select_model.
1948 let selector = connection.model_selector(&session_id).unwrap();
1949 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
1950 .await
1951 .unwrap();
1952
1953 // select_model should have enabled thinking based on the model's supports_thinking().
1954 agent.read_with(cx, |agent, _| {
1955 let session = agent.sessions.get(&session_id).unwrap();
1956 session.thread.read_with(cx, |thread, _| {
1957 assert!(
1958 thread.thinking_enabled(),
1959 "select_model should enable thinking when model supports it"
1960 );
1961 });
1962 });
1963
1964 // Switch back to the non-thinking model.
1965 let selector = connection.model_selector(&session_id).unwrap();
1966 cx.update(|cx| selector.select_model(acp::ModelId::new("fake/fake"), cx))
1967 .await
1968 .unwrap();
1969
1970 // select_model should have disabled thinking.
1971 agent.read_with(cx, |agent, _| {
1972 let session = agent.sessions.get(&session_id).unwrap();
1973 session.thread.read_with(cx, |thread, _| {
1974 assert!(
1975 !thread.thinking_enabled(),
1976 "select_model should disable thinking when model does not support it"
1977 );
1978 });
1979 });
1980 }
1981
1982 #[gpui::test]
1983 async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) {
1984 init_test(cx);
1985 let fs = FakeFs::new(cx.executor());
1986 fs.insert_tree("/", json!({ "a": {} })).await;
1987 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
1988 let thread_store = cx.new(|cx| ThreadStore::new(cx));
1989 let agent = NativeAgent::new(
1990 project.clone(),
1991 thread_store.clone(),
1992 Templates::new(),
1993 None,
1994 fs.clone(),
1995 &mut cx.to_async(),
1996 )
1997 .await
1998 .unwrap();
1999 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2000
2001 // Register a thinking model.
2002 let thinking_model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2003 "fake-corp",
2004 "fake-thinking",
2005 "Fake Thinking",
2006 true,
2007 ));
2008 let thinking_provider = Arc::new(
2009 FakeLanguageModelProvider::new(
2010 LanguageModelProviderId::from("fake-corp".to_string()),
2011 LanguageModelProviderName::from("Fake Corp".to_string()),
2012 )
2013 .with_models(vec![thinking_model.clone()]),
2014 );
2015 cx.update(|cx| {
2016 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2017 registry.register_provider(thinking_provider, cx);
2018 });
2019 });
2020 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2021
2022 // Create a thread and select the thinking model.
2023 let acp_thread = cx
2024 .update(|cx| {
2025 connection
2026 .clone()
2027 .new_thread(project.clone(), Path::new("/a"), cx)
2028 })
2029 .await
2030 .unwrap();
2031 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2032
2033 let selector = connection.model_selector(&session_id).unwrap();
2034 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/fake-thinking"), cx))
2035 .await
2036 .unwrap();
2037
2038 // Verify thinking is enabled after selecting the thinking model.
2039 let thread = agent.read_with(cx, |agent, _| {
2040 agent.sessions.get(&session_id).unwrap().thread.clone()
2041 });
2042 thread.read_with(cx, |thread, _| {
2043 assert!(
2044 thread.thinking_enabled(),
2045 "thinking should be enabled after selecting thinking model"
2046 );
2047 });
2048
2049 // Send a message so the thread gets persisted.
2050 let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2051 let send = cx.foreground_executor().spawn(send);
2052 cx.run_until_parked();
2053
2054 thinking_model.send_last_completion_stream_text_chunk("Response.");
2055 thinking_model.end_last_completion_stream();
2056
2057 send.await.unwrap();
2058 cx.run_until_parked();
2059
2060 // Drop the thread so it can be reloaded from disk.
2061 cx.update(|_| {
2062 drop(thread);
2063 drop(acp_thread);
2064 });
2065 agent.read_with(cx, |agent, _| {
2066 assert!(agent.sessions.is_empty());
2067 });
2068
2069 // Reload the thread and verify thinking_enabled is still true.
2070 let reloaded_acp_thread = agent
2071 .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2072 .await
2073 .unwrap();
2074 let reloaded_thread = agent.read_with(cx, |agent, _| {
2075 agent.sessions.get(&session_id).unwrap().thread.clone()
2076 });
2077 reloaded_thread.read_with(cx, |thread, _| {
2078 assert!(
2079 thread.thinking_enabled(),
2080 "thinking_enabled should be preserved when reloading a thread with a thinking model"
2081 );
2082 });
2083
2084 drop(reloaded_acp_thread);
2085 }
2086
2087 #[gpui::test]
2088 async fn test_loaded_thread_preserves_model(cx: &mut TestAppContext) {
2089 init_test(cx);
2090 let fs = FakeFs::new(cx.executor());
2091 fs.insert_tree("/", json!({ "a": {} })).await;
2092 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
2093 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2094 let agent = NativeAgent::new(
2095 project.clone(),
2096 thread_store.clone(),
2097 Templates::new(),
2098 None,
2099 fs.clone(),
2100 &mut cx.to_async(),
2101 )
2102 .await
2103 .unwrap();
2104 let connection = Rc::new(NativeAgentConnection(agent.clone()));
2105
2106 // Register a model where id() != name(), like real Anthropic models
2107 // (e.g. id="claude-sonnet-4-5-thinking-latest", name="Claude Sonnet 4.5 Thinking").
2108 let model = Arc::new(FakeLanguageModel::with_id_and_thinking(
2109 "fake-corp",
2110 "custom-model-id",
2111 "Custom Model Display Name",
2112 false,
2113 ));
2114 let provider = Arc::new(
2115 FakeLanguageModelProvider::new(
2116 LanguageModelProviderId::from("fake-corp".to_string()),
2117 LanguageModelProviderName::from("Fake Corp".to_string()),
2118 )
2119 .with_models(vec![model.clone()]),
2120 );
2121 cx.update(|cx| {
2122 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
2123 registry.register_provider(provider, cx);
2124 });
2125 });
2126 agent.update(cx, |agent, cx| agent.models.refresh_list(cx));
2127
2128 // Create a thread and select the model.
2129 let acp_thread = cx
2130 .update(|cx| {
2131 connection
2132 .clone()
2133 .new_thread(project.clone(), Path::new("/a"), cx)
2134 })
2135 .await
2136 .unwrap();
2137 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2138
2139 let selector = connection.model_selector(&session_id).unwrap();
2140 cx.update(|cx| selector.select_model(acp::ModelId::new("fake-corp/custom-model-id"), cx))
2141 .await
2142 .unwrap();
2143
2144 let thread = agent.read_with(cx, |agent, _| {
2145 agent.sessions.get(&session_id).unwrap().thread.clone()
2146 });
2147 thread.read_with(cx, |thread, _| {
2148 assert_eq!(
2149 thread.model().unwrap().id().0.as_ref(),
2150 "custom-model-id",
2151 "model should be set before persisting"
2152 );
2153 });
2154
2155 // Send a message so the thread gets persisted.
2156 let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["Hello".into()], cx));
2157 let send = cx.foreground_executor().spawn(send);
2158 cx.run_until_parked();
2159
2160 model.send_last_completion_stream_text_chunk("Response.");
2161 model.end_last_completion_stream();
2162
2163 send.await.unwrap();
2164 cx.run_until_parked();
2165
2166 // Drop the thread so it can be reloaded from disk.
2167 cx.update(|_| {
2168 drop(thread);
2169 drop(acp_thread);
2170 });
2171 agent.read_with(cx, |agent, _| {
2172 assert!(agent.sessions.is_empty());
2173 });
2174
2175 // Reload the thread and verify the model was preserved.
2176 let reloaded_acp_thread = agent
2177 .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
2178 .await
2179 .unwrap();
2180 let reloaded_thread = agent.read_with(cx, |agent, _| {
2181 agent.sessions.get(&session_id).unwrap().thread.clone()
2182 });
2183 reloaded_thread.read_with(cx, |thread, _| {
2184 let reloaded_model = thread
2185 .model()
2186 .expect("model should be present after reload");
2187 assert_eq!(
2188 reloaded_model.id().0.as_ref(),
2189 "custom-model-id",
2190 "reloaded thread should have the same model, not fall back to the default"
2191 );
2192 });
2193
2194 drop(reloaded_acp_thread);
2195 }
2196
    /// Round-trips a native-agent thread through the `ThreadStore`.
    ///
    /// Checks that:
    /// - an empty thread is not written to the store, even after its model and
    ///   summarization model are changed;
    /// - after one user/assistant exchange the store holds a single entry whose
    ///   title is the text streamed by the summarization model;
    /// - dropping the `Thread` and `AcpThread` handles removes the session from the agent;
    /// - reopening the session with `open_thread` reproduces the same markdown transcript.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        // Grab the agent-side `Thread` backing this session so its models can be set directly.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a user message that mentions `/a/b.md` through a resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Spawn the send so it keeps running while the test drives the fake models below;
        // it is awaited once both completions have been answered.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Answer the main completion first, then (after parking, presumably so the
        // summarization request gets issued) answer the summarization request. The
        // summary text is what the store later reports as the thread title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        // The mention is rendered back as a markdown link to this URI.
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        // Reopening the session must reproduce the transcript asserted before the drop.
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
2334
2335 fn thread_entries(
2336 thread_store: &Entity<ThreadStore>,
2337 cx: &mut TestAppContext,
2338 ) -> Vec<(acp::SessionId, String)> {
2339 thread_store.read_with(cx, |store, _| {
2340 store
2341 .entries()
2342 .map(|entry| (entry.id.clone(), entry.title.to_string()))
2343 .collect::<Vec<_>>()
2344 })
2345 }
2346
2347 fn init_test(cx: &mut TestAppContext) {
2348 env_logger::try_init().ok();
2349 cx.update(|cx| {
2350 let settings_store = SettingsStore::test(cx);
2351 cx.set_global(settings_store);
2352
2353 LanguageModelRegistry::test(cx);
2354 });
2355 }
2356}
2357
2358fn mcp_message_content_to_acp_content_block(
2359 content: context_server::types::MessageContent,
2360) -> acp::ContentBlock {
2361 match content {
2362 context_server::types::MessageContent::Text {
2363 text,
2364 annotations: _,
2365 } => text.into(),
2366 context_server::types::MessageContent::Image {
2367 data,
2368 mime_type,
2369 annotations: _,
2370 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2371 context_server::types::MessageContent::Audio {
2372 data,
2373 mime_type,
2374 annotations: _,
2375 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2376 context_server::types::MessageContent::Resource {
2377 resource,
2378 annotations: _,
2379 } => {
2380 let mut link =
2381 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2382 if let Some(mime_type) = resource.mime_type {
2383 link = link.mime_type(mime_type);
2384 }
2385 acp::ContentBlock::ResourceLink(link)
2386 }
2387 }
2388}