1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod shell_parser;
8mod templates;
9#[cfg(test)]
10mod tests;
11mod thread;
12mod thread_store;
13mod tool_permissions;
14mod tools;
15
16use context_server::ContextServerId;
17pub use db::*;
18pub use native_agent_server::NativeAgentServer;
19pub use pattern_extraction::*;
20pub use templates::*;
21pub use thread::*;
22pub use thread_store::*;
23pub use tool_permissions::*;
24pub use tools::*;
25
26use acp_thread::{
27 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
28 AgentSessionListResponse, UserMessageId,
29};
30use agent_client_protocol as acp;
31use anyhow::{Context as _, Result, anyhow};
32use chrono::{DateTime, Utc};
33use collections::{HashMap, HashSet, IndexMap};
34use fs::Fs;
35use futures::channel::{mpsc, oneshot};
36use futures::future::Shared;
37use futures::{StreamExt, future};
38use gpui::{
39 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
40};
41use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
42use project::{Project, ProjectItem, ProjectPath, Worktree};
43use prompt_store::{
44 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
45 WorktreeContext,
46};
47use serde::{Deserialize, Serialize};
48use settings::{LanguageModelSelection, update_settings_file};
49use std::any::Any;
50use std::path::{Path, PathBuf};
51use std::rc::Rc;
52use std::sync::Arc;
53use util::ResultExt;
54use util::rel_path::RelPath;
55
/// A serializable record of a project's worktree telemetry snapshots together
/// with a UTC timestamp.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots for the project's worktrees.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// UTC timestamp stored with the snapshot.
    // NOTE(review): presumably the capture time — confirm at the construction site.
    pub timestamp: DateTime<Utc>,
}
61
/// Describes a failure to load a rules file.
///
/// Produced by `NativeAgent::load_worktree_info_for_system_prompt` when the
/// rules-file task returns an error.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the formatted error).
    pub message: SharedString,
}
65
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    ///
    /// Held weakly: the session is removed from `NativeAgent::sessions` when
    /// the ACP thread entity is released (see `register_session`).
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// The most recent save task spawned by `NativeAgent::save_thread`;
    /// assigning a new one replaces the previous task.
    pending_save: Task<()>,
    /// Subscriptions created in `register_session`, stored so they live as
    /// long as the session does.
    _subscriptions: Vec<Subscription>,
}
75
/// Cache of the language models offered by authenticated providers, plus a
/// watch channel that is signalled whenever the cache is rebuilt.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiving half of the refresh channel; cloned out by `watch()`.
    refresh_models_rx: watch::Receiver<()>,
    /// Sending half of the refresh channel; signalled at the end of `refresh_list`.
    refresh_models_tx: watch::Sender<()>,
    /// Background task that authenticates every visible provider; stored so
    /// the task is owned by this struct.
    _authenticate_all_providers_task: Task<()>,
}
85
86impl LanguageModels {
87 fn new(cx: &mut App) -> Self {
88 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
89
90 let mut this = Self {
91 models: HashMap::default(),
92 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
93 refresh_models_rx,
94 refresh_models_tx,
95 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
96 };
97 this.refresh_list(cx);
98 this
99 }
100
101 fn refresh_list(&mut self, cx: &App) {
102 let providers = LanguageModelRegistry::global(cx)
103 .read(cx)
104 .visible_providers()
105 .into_iter()
106 .filter(|provider| provider.is_authenticated(cx))
107 .collect::<Vec<_>>();
108
109 let mut language_model_list = IndexMap::default();
110 let mut recommended_models = HashSet::default();
111
112 let mut recommended = Vec::new();
113 for provider in &providers {
114 for model in provider.recommended_models(cx) {
115 recommended_models.insert((model.provider_id(), model.id()));
116 recommended.push(Self::map_language_model_to_info(&model, provider));
117 }
118 }
119 if !recommended.is_empty() {
120 language_model_list.insert(
121 acp_thread::AgentModelGroupName("Recommended".into()),
122 recommended,
123 );
124 }
125
126 let mut models = HashMap::default();
127 for provider in providers {
128 let mut provider_models = Vec::new();
129 for model in provider.provided_models(cx) {
130 let model_info = Self::map_language_model_to_info(&model, &provider);
131 let model_id = model_info.id.clone();
132 provider_models.push(model_info);
133 models.insert(model_id, model);
134 }
135 if !provider_models.is_empty() {
136 language_model_list.insert(
137 acp_thread::AgentModelGroupName(provider.name().0.clone()),
138 provider_models,
139 );
140 }
141 }
142
143 self.models = models;
144 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
145 self.refresh_models_tx.send(()).ok();
146 }
147
148 fn watch(&self) -> watch::Receiver<()> {
149 self.refresh_models_rx.clone()
150 }
151
152 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
153 self.models.get(model_id).cloned()
154 }
155
156 fn map_language_model_to_info(
157 model: &Arc<dyn LanguageModel>,
158 provider: &Arc<dyn LanguageModelProvider>,
159 ) -> acp_thread::AgentModelInfo {
160 acp_thread::AgentModelInfo {
161 id: Self::model_id(model),
162 name: model.name().0,
163 description: None,
164 icon: Some(match provider.icon() {
165 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
166 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
167 }),
168 is_latest: model.is_latest(),
169 }
170 }
171
172 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
173 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
174 }
175
176 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
177 let authenticate_all_providers = LanguageModelRegistry::global(cx)
178 .read(cx)
179 .visible_providers()
180 .iter()
181 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
182 .collect::<Vec<_>>();
183
184 cx.background_spawn(async move {
185 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
186 if let Err(err) = authenticate_task.await {
187 match err {
188 language_model::AuthenticateError::CredentialsNotFound => {
189 // Since we're authenticating these providers in the
190 // background for the purposes of populating the
191 // language selector, we don't care about providers
192 // where the credentials are not found.
193 }
194 language_model::AuthenticateError::ConnectionRefused => {
195 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
196 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
197 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
198 }
199 _ => {
200 // Some providers have noisy failure states that we
201 // don't want to spam the logs with every time the
202 // language model selector is initialized.
203 //
204 // Ideally these should have more clear failure modes
205 // that we know are safe to ignore here, like what we do
206 // with `CredentialsNotFound` above.
207 match provider_id.0.as_ref() {
208 "lmstudio" | "ollama" => {
209 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
210 //
211 // These fail noisily, so we don't log them.
212 }
213 "copilot_chat" => {
214 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
215 }
216 _ => {
217 log::error!(
218 "Failed to authenticate provider: {}: {err:#}",
219 provider_name.0
220 );
221 }
222 }
223 }
224 }
225 }
226 }
227 })
228 }
229}
230
/// The native agent: owns all sessions for a project together with the shared
/// state (project context, templates, model cache, context-server registry)
/// that new threads are created from.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Store that is asked to reload after a thread has been saved.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalled whenever the project context must be rebuilt
    /// (worktree changes, rules-file changes, prompt-store updates).
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context`; stored so it is owned by the agent.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context servers, used for prompts (slash commands) and passed to new threads.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store from which default user rules are loaded.
    prompt_store: Option<Entity<PromptStore>>,
    /// File system handle supplied at construction.
    fs: Arc<dyn Fs>,
    /// Event subscriptions created in `new`, stored so they stay alive with the agent.
    _subscriptions: Vec<Subscription>,
}
249
250impl NativeAgent {
251 pub async fn new(
252 project: Entity<Project>,
253 thread_store: Entity<ThreadStore>,
254 templates: Arc<Templates>,
255 prompt_store: Option<Entity<PromptStore>>,
256 fs: Arc<dyn Fs>,
257 cx: &mut AsyncApp,
258 ) -> Result<Entity<NativeAgent>> {
259 log::debug!("Creating new NativeAgent");
260
261 let project_context = cx
262 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
263 .await;
264
265 Ok(cx.new(|cx| {
266 let context_server_store = project.read(cx).context_server_store();
267 let context_server_registry =
268 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
269
270 let mut subscriptions = vec![
271 cx.subscribe(&project, Self::handle_project_event),
272 cx.subscribe(
273 &LanguageModelRegistry::global(cx),
274 Self::handle_models_updated_event,
275 ),
276 cx.subscribe(
277 &context_server_store,
278 Self::handle_context_server_store_updated,
279 ),
280 cx.subscribe(
281 &context_server_registry,
282 Self::handle_context_server_registry_event,
283 ),
284 ];
285 if let Some(prompt_store) = prompt_store.as_ref() {
286 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
287 }
288
289 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
290 watch::channel(());
291 Self {
292 sessions: HashMap::default(),
293 thread_store,
294 project_context: cx.new(|_| project_context),
295 project_context_needs_refresh: project_context_needs_refresh_tx,
296 _maintain_project_context: cx.spawn(async move |this, cx| {
297 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
298 }),
299 context_server_registry,
300 templates,
301 models: LanguageModels::new(cx),
302 project,
303 prompt_store,
304 fs,
305 _subscriptions: subscriptions,
306 }
307 }))
308 }
309
310 fn new_session(
311 &mut self,
312 project: Entity<Project>,
313 cx: &mut Context<Self>,
314 ) -> Entity<AcpThread> {
315 // Create Thread
316 // Fetch default model from registry settings
317 let registry = LanguageModelRegistry::read_global(cx);
318 // Log available models for debugging
319 let available_count = registry.available_models(cx).count();
320 log::debug!("Total available models: {}", available_count);
321
322 let default_model = registry.default_model().and_then(|default_model| {
323 self.models
324 .model_from_id(&LanguageModels::model_id(&default_model.model))
325 });
326 let thread = cx.new(|cx| {
327 Thread::new(
328 project.clone(),
329 self.project_context.clone(),
330 self.context_server_registry.clone(),
331 self.templates.clone(),
332 default_model,
333 cx,
334 )
335 });
336
337 self.register_session(thread, cx)
338 }
339
340 fn register_session(
341 &mut self,
342 thread_handle: Entity<Thread>,
343 cx: &mut Context<Self>,
344 ) -> Entity<AcpThread> {
345 let connection = Rc::new(NativeAgentConnection(cx.entity()));
346
347 let thread = thread_handle.read(cx);
348 let session_id = thread.id().clone();
349 let title = thread.title();
350 let project = thread.project.clone();
351 let action_log = thread.action_log.clone();
352 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
353 let acp_thread = cx.new(|cx| {
354 acp_thread::AcpThread::new(
355 title,
356 connection,
357 project.clone(),
358 action_log.clone(),
359 session_id.clone(),
360 prompt_capabilities_rx,
361 cx,
362 )
363 });
364
365 let registry = LanguageModelRegistry::read_global(cx);
366 let summarization_model = registry.thread_summary_model().map(|c| c.model);
367
368 thread_handle.update(cx, |thread, cx| {
369 thread.set_summarization_model(summarization_model, cx);
370 thread.add_default_tools(
371 Rc::new(AcpThreadEnvironment {
372 acp_thread: acp_thread.downgrade(),
373 }) as _,
374 cx,
375 )
376 });
377
378 let subscriptions = vec![
379 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
380 this.sessions.remove(acp_thread.session_id());
381 }),
382 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
383 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
384 cx.observe(&thread_handle, move |this, thread, cx| {
385 this.save_thread(thread, cx)
386 }),
387 ];
388
389 self.sessions.insert(
390 session_id,
391 Session {
392 thread: thread_handle,
393 acp_thread: acp_thread.downgrade(),
394 _subscriptions: subscriptions,
395 pending_save: Task::ready(()),
396 },
397 );
398
399 self.update_available_commands(cx);
400
401 acp_thread
402 }
403
404 pub fn models(&self) -> &LanguageModels {
405 &self.models
406 }
407
408 async fn maintain_project_context(
409 this: WeakEntity<Self>,
410 mut needs_refresh: watch::Receiver<()>,
411 cx: &mut AsyncApp,
412 ) -> Result<()> {
413 while needs_refresh.changed().await.is_ok() {
414 let project_context = this
415 .update(cx, |this, cx| {
416 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
417 })?
418 .await;
419 this.update(cx, |this, cx| {
420 this.project_context = cx.new(|_| project_context);
421 })?;
422 }
423
424 Ok(())
425 }
426
427 fn build_project_context(
428 project: &Entity<Project>,
429 prompt_store: Option<&Entity<PromptStore>>,
430 cx: &mut App,
431 ) -> Task<ProjectContext> {
432 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
433 let worktree_tasks = worktrees
434 .into_iter()
435 .map(|worktree| {
436 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
437 })
438 .collect::<Vec<_>>();
439 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
440 prompt_store.read_with(cx, |prompt_store, cx| {
441 let prompts = prompt_store.default_prompt_metadata();
442 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
443 let contents = prompt_store.load(prompt_metadata.id, cx);
444 async move { (contents.await, prompt_metadata) }
445 });
446 cx.background_spawn(future::join_all(load_tasks))
447 })
448 } else {
449 Task::ready(vec![])
450 };
451
452 cx.spawn(async move |_cx| {
453 let (worktrees, default_user_rules) =
454 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
455
456 let worktrees = worktrees
457 .into_iter()
458 .map(|(worktree, _rules_error)| {
459 // TODO: show error message
460 // if let Some(rules_error) = rules_error {
461 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
462 // }
463 worktree
464 })
465 .collect::<Vec<_>>();
466
467 let default_user_rules = default_user_rules
468 .into_iter()
469 .flat_map(|(contents, prompt_metadata)| match contents {
470 Ok(contents) => Some(UserRulesContext {
471 uuid: prompt_metadata.id.as_user()?,
472 title: prompt_metadata.title.map(|title| title.to_string()),
473 contents,
474 }),
475 Err(_err) => {
476 // TODO: show error message
477 // this.update(cx, |_, cx| {
478 // cx.emit(RulesLoadingError {
479 // message: format!("{err:?}").into(),
480 // });
481 // })
482 // .ok();
483 None
484 }
485 })
486 .collect::<Vec<_>>();
487
488 ProjectContext::new(worktrees, default_user_rules)
489 })
490 }
491
492 fn load_worktree_info_for_system_prompt(
493 worktree: Entity<Worktree>,
494 project: Entity<Project>,
495 cx: &mut App,
496 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
497 let tree = worktree.read(cx);
498 let root_name = tree.root_name_str().into();
499 let abs_path = tree.abs_path();
500
501 let mut context = WorktreeContext {
502 root_name,
503 abs_path,
504 rules_file: None,
505 };
506
507 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
508 let Some(rules_task) = rules_task else {
509 return Task::ready((context, None));
510 };
511
512 cx.spawn(async move |_| {
513 let (rules_file, rules_file_error) = match rules_task.await {
514 Ok(rules_file) => (Some(rules_file), None),
515 Err(err) => (
516 None,
517 Some(RulesLoadingError {
518 message: format!("{err}").into(),
519 }),
520 ),
521 };
522 context.rules_file = rules_file;
523 (context, rules_file_error)
524 })
525 }
526
527 fn load_worktree_rules_file(
528 worktree: Entity<Worktree>,
529 project: Entity<Project>,
530 cx: &mut App,
531 ) -> Option<Task<Result<RulesFileContext>>> {
532 let worktree = worktree.read(cx);
533 let worktree_id = worktree.id();
534 let selected_rules_file = RULES_FILE_NAMES
535 .into_iter()
536 .filter_map(|name| {
537 worktree
538 .entry_for_path(RelPath::unix(name).unwrap())
539 .filter(|entry| entry.is_file())
540 .map(|entry| entry.path.clone())
541 })
542 .next();
543
544 // Note that Cline supports `.clinerules` being a directory, but that is not currently
545 // supported. This doesn't seem to occur often in GitHub repositories.
546 selected_rules_file.map(|path_in_worktree| {
547 let project_path = ProjectPath {
548 worktree_id,
549 path: path_in_worktree.clone(),
550 };
551 let buffer_task =
552 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
553 let rope_task = cx.spawn(async move |cx| {
554 let buffer = buffer_task.await?;
555 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
556 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
557 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
558 })?;
559 anyhow::Ok((project_entry_id, rope))
560 });
561 // Build a string from the rope on a background thread.
562 cx.background_spawn(async move {
563 let (project_entry_id, rope) = rope_task.await?;
564 anyhow::Ok(RulesFileContext {
565 path_in_worktree,
566 text: rope.to_string().trim().to_string(),
567 project_entry_id: project_entry_id.to_usize(),
568 })
569 })
570 })
571 }
572
573 fn handle_thread_title_updated(
574 &mut self,
575 thread: Entity<Thread>,
576 _: &TitleUpdated,
577 cx: &mut Context<Self>,
578 ) {
579 let session_id = thread.read(cx).id();
580 let Some(session) = self.sessions.get(session_id) else {
581 return;
582 };
583 let thread = thread.downgrade();
584 let acp_thread = session.acp_thread.clone();
585 cx.spawn(async move |_, cx| {
586 let title = thread.read_with(cx, |thread, _| thread.title())?;
587 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
588 task.await
589 })
590 .detach_and_log_err(cx);
591 }
592
593 fn handle_thread_token_usage_updated(
594 &mut self,
595 thread: Entity<Thread>,
596 usage: &TokenUsageUpdated,
597 cx: &mut Context<Self>,
598 ) {
599 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
600 return;
601 };
602 session
603 .acp_thread
604 .update(cx, |acp_thread, cx| {
605 acp_thread.update_token_usage(usage.0.clone(), cx);
606 })
607 .ok();
608 }
609
610 fn handle_project_event(
611 &mut self,
612 _project: Entity<Project>,
613 event: &project::Event,
614 _cx: &mut Context<Self>,
615 ) {
616 match event {
617 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
618 self.project_context_needs_refresh.send(()).ok();
619 }
620 project::Event::WorktreeUpdatedEntries(_, items) => {
621 if items.iter().any(|(path, _, _)| {
622 RULES_FILE_NAMES
623 .iter()
624 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
625 }) {
626 self.project_context_needs_refresh.send(()).ok();
627 }
628 }
629 _ => {}
630 }
631 }
632
633 fn handle_prompts_updated_event(
634 &mut self,
635 _prompt_store: Entity<PromptStore>,
636 _event: &prompt_store::PromptsUpdatedEvent,
637 _cx: &mut Context<Self>,
638 ) {
639 self.project_context_needs_refresh.send(()).ok();
640 }
641
642 fn handle_models_updated_event(
643 &mut self,
644 _registry: Entity<LanguageModelRegistry>,
645 _event: &language_model::Event,
646 cx: &mut Context<Self>,
647 ) {
648 self.models.refresh_list(cx);
649
650 let registry = LanguageModelRegistry::read_global(cx);
651 let default_model = registry.default_model().map(|m| m.model);
652 let summarization_model = registry.thread_summary_model().map(|m| m.model);
653
654 for session in self.sessions.values_mut() {
655 session.thread.update(cx, |thread, cx| {
656 if thread.model().is_none()
657 && let Some(model) = default_model.clone()
658 {
659 thread.set_model(model, cx);
660 cx.notify();
661 }
662 thread.set_summarization_model(summarization_model.clone(), cx);
663 });
664 }
665 }
666
667 fn handle_context_server_store_updated(
668 &mut self,
669 _store: Entity<project::context_server_store::ContextServerStore>,
670 _event: &project::context_server_store::Event,
671 cx: &mut Context<Self>,
672 ) {
673 self.update_available_commands(cx);
674 }
675
676 fn handle_context_server_registry_event(
677 &mut self,
678 _registry: Entity<ContextServerRegistry>,
679 event: &ContextServerRegistryEvent,
680 cx: &mut Context<Self>,
681 ) {
682 match event {
683 ContextServerRegistryEvent::ToolsChanged => {}
684 ContextServerRegistryEvent::PromptsChanged => {
685 self.update_available_commands(cx);
686 }
687 }
688 }
689
690 fn update_available_commands(&self, cx: &mut Context<Self>) {
691 let available_commands = self.build_available_commands(cx);
692 for session in self.sessions.values() {
693 if let Some(acp_thread) = session.acp_thread.upgrade() {
694 acp_thread.update(cx, |thread, cx| {
695 thread
696 .handle_session_update(
697 acp::SessionUpdate::AvailableCommandsUpdate(
698 acp::AvailableCommandsUpdate::new(available_commands.clone()),
699 ),
700 cx,
701 )
702 .log_err();
703 });
704 }
705 }
706 }
707
708 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
709 let registry = self.context_server_registry.read(cx);
710
711 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
712 for context_server_prompt in registry.prompts() {
713 *prompt_name_counts
714 .entry(context_server_prompt.prompt.name.as_str())
715 .or_insert(0) += 1;
716 }
717
718 registry
719 .prompts()
720 .flat_map(|context_server_prompt| {
721 let prompt = &context_server_prompt.prompt;
722
723 let should_prefix = prompt_name_counts
724 .get(prompt.name.as_str())
725 .copied()
726 .unwrap_or(0)
727 > 1;
728
729 let name = if should_prefix {
730 format!("{}.{}", context_server_prompt.server_id, prompt.name)
731 } else {
732 prompt.name.clone()
733 };
734
735 let mut command = acp::AvailableCommand::new(
736 name,
737 prompt.description.clone().unwrap_or_default(),
738 );
739
740 match prompt.arguments.as_deref() {
741 Some([arg]) => {
742 let hint = format!("<{}>", arg.name);
743
744 command = command.input(acp::AvailableCommandInput::Unstructured(
745 acp::UnstructuredCommandInput::new(hint),
746 ));
747 }
748 Some([]) | None => {}
749 Some(_) => {
750 // skip >1 argument commands since we don't support them yet
751 return None;
752 }
753 }
754
755 Some(command)
756 })
757 .collect()
758 }
759
760 pub fn load_thread(
761 &mut self,
762 id: acp::SessionId,
763 cx: &mut Context<Self>,
764 ) -> Task<Result<Entity<Thread>>> {
765 let database_future = ThreadsDatabase::connect(cx);
766 cx.spawn(async move |this, cx| {
767 let database = database_future.await.map_err(|err| anyhow!(err))?;
768 let db_thread = database
769 .load_thread(id.clone())
770 .await?
771 .with_context(|| format!("no thread found with ID: {id:?}"))?;
772
773 this.update(cx, |this, cx| {
774 let summarization_model = LanguageModelRegistry::read_global(cx)
775 .thread_summary_model()
776 .map(|c| c.model);
777
778 cx.new(|cx| {
779 let mut thread = Thread::from_db(
780 id.clone(),
781 db_thread,
782 this.project.clone(),
783 this.project_context.clone(),
784 this.context_server_registry.clone(),
785 this.templates.clone(),
786 cx,
787 );
788 thread.set_summarization_model(summarization_model, cx);
789 thread
790 })
791 })
792 })
793 }
794
795 pub fn open_thread(
796 &mut self,
797 id: acp::SessionId,
798 cx: &mut Context<Self>,
799 ) -> Task<Result<Entity<AcpThread>>> {
800 let task = self.load_thread(id, cx);
801 cx.spawn(async move |this, cx| {
802 let thread = task.await?;
803 let acp_thread =
804 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
805 let events = thread.update(cx, |thread, cx| thread.replay(cx));
806 cx.update(|cx| {
807 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
808 })
809 .await?;
810 Ok(acp_thread)
811 })
812 }
813
814 pub fn thread_summary(
815 &mut self,
816 id: acp::SessionId,
817 cx: &mut Context<Self>,
818 ) -> Task<Result<SharedString>> {
819 let thread = self.open_thread(id.clone(), cx);
820 cx.spawn(async move |this, cx| {
821 let acp_thread = thread.await?;
822 let result = this
823 .update(cx, |this, cx| {
824 this.sessions
825 .get(&id)
826 .unwrap()
827 .thread
828 .update(cx, |thread, cx| thread.summary(cx))
829 })?
830 .await
831 .context("Failed to generate summary")?;
832 drop(acp_thread);
833 Ok(result)
834 })
835 }
836
837 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
838 if thread.read(cx).is_empty() {
839 return;
840 }
841
842 let database_future = ThreadsDatabase::connect(cx);
843 let (id, db_thread) =
844 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
845 let Some(session) = self.sessions.get_mut(&id) else {
846 return;
847 };
848 let thread_store = self.thread_store.clone();
849 session.pending_save = cx.spawn(async move |_, cx| {
850 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
851 return;
852 };
853 let db_thread = db_thread.await;
854 database.save_thread(id, db_thread).await.log_err();
855 thread_store.update(cx, |store, cx| store.reload(cx));
856 });
857 }
858
859 fn send_mcp_prompt(
860 &self,
861 message_id: UserMessageId,
862 session_id: agent_client_protocol::SessionId,
863 prompt_name: String,
864 server_id: ContextServerId,
865 arguments: HashMap<String, String>,
866 original_content: Vec<acp::ContentBlock>,
867 cx: &mut Context<Self>,
868 ) -> Task<Result<acp::PromptResponse>> {
869 let server_store = self.context_server_registry.read(cx).server_store().clone();
870 let path_style = self.project.read(cx).path_style(cx);
871
872 cx.spawn(async move |this, cx| {
873 let prompt =
874 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
875
876 let (acp_thread, thread) = this.update(cx, |this, _cx| {
877 let session = this
878 .sessions
879 .get(&session_id)
880 .context("Failed to get session")?;
881 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
882 })??;
883
884 let mut last_is_user = true;
885
886 thread.update(cx, |thread, cx| {
887 thread.push_acp_user_block(
888 message_id,
889 original_content.into_iter().skip(1),
890 path_style,
891 cx,
892 );
893 });
894
895 for message in prompt.messages {
896 let context_server::types::PromptMessage { role, content } = message;
897 let block = mcp_message_content_to_acp_content_block(content);
898
899 match role {
900 context_server::types::Role::User => {
901 let id = acp_thread::UserMessageId::new();
902
903 acp_thread.update(cx, |acp_thread, cx| {
904 acp_thread.push_user_content_block_with_indent(
905 Some(id.clone()),
906 block.clone(),
907 true,
908 cx,
909 );
910 })?;
911
912 thread.update(cx, |thread, cx| {
913 thread.push_acp_user_block(id, [block], path_style, cx);
914 });
915 }
916 context_server::types::Role::Assistant => {
917 acp_thread.update(cx, |acp_thread, cx| {
918 acp_thread.push_assistant_content_block_with_indent(
919 block.clone(),
920 false,
921 true,
922 cx,
923 );
924 })?;
925
926 thread.update(cx, |thread, cx| {
927 thread.push_acp_agent_block(block, cx);
928 });
929 }
930 }
931
932 last_is_user = role == context_server::types::Role::User;
933 }
934
935 let response_stream = thread.update(cx, |thread, cx| {
936 if last_is_user {
937 thread.send_existing(cx)
938 } else {
939 // Resume if MCP prompt did not end with a user message
940 thread.resume(cx)
941 }
942 })?;
943
944 cx.update(|cx| {
945 NativeAgentConnection::handle_thread_events(response_stream, acp_thread, cx)
946 })
947 .await
948 })
949 }
950}
951
/// Wrapper struct that implements the AgentConnection trait
///
/// Holds a strong handle to the `NativeAgent` entity; cloning the connection
/// clones that handle.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
955
956impl NativeAgentConnection {
957 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
958 self.0
959 .read(cx)
960 .sessions
961 .get(session_id)
962 .map(|session| session.thread.clone())
963 }
964
965 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
966 self.0.update(cx, |this, cx| this.load_thread(id, cx))
967 }
968
969 fn run_turn(
970 &self,
971 session_id: acp::SessionId,
972 cx: &mut App,
973 f: impl 'static
974 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
975 ) -> Task<Result<acp::PromptResponse>> {
976 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
977 agent
978 .sessions
979 .get_mut(&session_id)
980 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
981 }) else {
982 return Task::ready(Err(anyhow!("Session not found")));
983 };
984 log::debug!("Found session for: {}", session_id);
985
986 let response_stream = match f(thread, cx) {
987 Ok(stream) => stream,
988 Err(err) => return Task::ready(Err(err)),
989 };
990 Self::handle_thread_events(response_stream, acp_thread, cx)
991 }
992
993 fn handle_thread_events(
994 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
995 acp_thread: WeakEntity<AcpThread>,
996 cx: &App,
997 ) -> Task<Result<acp::PromptResponse>> {
998 cx.spawn(async move |cx| {
999 // Handle response stream and forward to session.acp_thread
1000 while let Some(result) = events.next().await {
1001 match result {
1002 Ok(event) => {
1003 log::trace!("Received completion event: {:?}", event);
1004
1005 match event {
1006 ThreadEvent::UserMessage(message) => {
1007 acp_thread.update(cx, |thread, cx| {
1008 for content in message.content {
1009 thread.push_user_content_block(
1010 Some(message.id.clone()),
1011 content.into(),
1012 cx,
1013 );
1014 }
1015 })?;
1016 }
1017 ThreadEvent::AgentText(text) => {
1018 acp_thread.update(cx, |thread, cx| {
1019 thread.push_assistant_content_block(text.into(), false, cx)
1020 })?;
1021 }
1022 ThreadEvent::AgentThinking(text) => {
1023 acp_thread.update(cx, |thread, cx| {
1024 thread.push_assistant_content_block(text.into(), true, cx)
1025 })?;
1026 }
1027 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
1028 tool_call,
1029 options,
1030 response,
1031 context: _,
1032 }) => {
1033 let outcome_task = acp_thread.update(cx, |thread, cx| {
1034 thread.request_tool_call_authorization(
1035 tool_call, options, true, cx,
1036 )
1037 })??;
1038 cx.background_spawn(async move {
1039 if let acp::RequestPermissionOutcome::Selected(
1040 acp::SelectedPermissionOutcome { option_id, .. },
1041 ) = outcome_task.await
1042 {
1043 response
1044 .send(option_id)
1045 .map(|_| anyhow!("authorization receiver was dropped"))
1046 .log_err();
1047 }
1048 })
1049 .detach();
1050 }
1051 ThreadEvent::ToolCall(tool_call) => {
1052 acp_thread.update(cx, |thread, cx| {
1053 thread.upsert_tool_call(tool_call, cx)
1054 })??;
1055 }
1056 ThreadEvent::ToolCallUpdate(update) => {
1057 acp_thread.update(cx, |thread, cx| {
1058 thread.update_tool_call(update, cx)
1059 })??;
1060 }
1061 ThreadEvent::Retry(status) => {
1062 acp_thread.update(cx, |thread, cx| {
1063 thread.update_retry_status(status, cx)
1064 })?;
1065 }
1066 ThreadEvent::Stop(stop_reason) => {
1067 log::debug!("Assistant message complete: {:?}", stop_reason);
1068 return Ok(acp::PromptResponse::new(stop_reason));
1069 }
1070 }
1071 }
1072 Err(e) => {
1073 log::error!("Error in model response stream: {:?}", e);
1074 return Err(e);
1075 }
1076 }
1077 }
1078
1079 log::debug!("Response stream completed");
1080 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1081 })
1082 }
1083}
1084
/// A slash command parsed from the first text block of a prompt.
///
/// All fields borrow from the prompt that was parsed.
struct Command<'a> {
    /// The prompt name: the command word after `/`, without any `server.` prefix.
    prompt_name: &'a str,
    /// Everything after the first whitespace following the command word, or
    /// `""` when the command has no argument.
    arg_value: &'a str,
    /// The part of the command word before the first `.`, when there is one.
    explicit_server_id: Option<&'a str>,
}
1090
1091impl<'a> Command<'a> {
1092 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1093 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1094 return None;
1095 };
1096 let text = text_content.text.trim();
1097 let command = text.strip_prefix('/')?;
1098 let (command, arg_value) = command
1099 .split_once(char::is_whitespace)
1100 .unwrap_or((command, ""));
1101
1102 if let Some((server_id, prompt_name)) = command.split_once('.') {
1103 Some(Self {
1104 prompt_name,
1105 arg_value,
1106 explicit_server_id: Some(server_id),
1107 })
1108 } else {
1109 Some(Self {
1110 prompt_name: command,
1111 arg_value,
1112 explicit_server_id: None,
1113 })
1114 }
1115 }
1116}
1117
1118struct NativeAgentModelSelector {
1119 session_id: acp::SessionId,
1120 connection: NativeAgentConnection,
1121}
1122
1123impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1124 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1125 log::debug!("NativeAgentConnection::list_models called");
1126 let list = self.connection.0.read(cx).models.model_list.clone();
1127 Task::ready(if list.is_empty() {
1128 Err(anyhow::anyhow!("No models available"))
1129 } else {
1130 Ok(list)
1131 })
1132 }
1133
1134 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1135 log::debug!(
1136 "Setting model for session {}: {}",
1137 self.session_id,
1138 model_id
1139 );
1140 let Some(thread) = self
1141 .connection
1142 .0
1143 .read(cx)
1144 .sessions
1145 .get(&self.session_id)
1146 .map(|session| session.thread.clone())
1147 else {
1148 return Task::ready(Err(anyhow!("Session not found")));
1149 };
1150
1151 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1152 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1153 };
1154
1155 // We want to reset the effort level when switching models, as the currently-selected effort level may
1156 // not be compatible.
1157 let effort = model
1158 .default_effort_level()
1159 .map(|effort_level| effort_level.value.to_string());
1160
1161 thread.update(cx, |thread, cx| {
1162 thread.set_model(model.clone(), cx);
1163 thread.set_thinking_effort(effort.clone(), cx);
1164 });
1165
1166 update_settings_file(
1167 self.connection.0.read(cx).fs.clone(),
1168 cx,
1169 move |settings, cx| {
1170 let provider = model.provider_id().0.to_string();
1171 let model = model.id().0.to_string();
1172 let enable_thinking = settings
1173 .agent
1174 .as_ref()
1175 .and_then(|agent| {
1176 agent
1177 .default_model
1178 .as_ref()
1179 .map(|default_model| default_model.enable_thinking)
1180 })
1181 .unwrap_or_else(|| thread.read(cx).thinking_enabled());
1182 settings
1183 .agent
1184 .get_or_insert_default()
1185 .set_model(LanguageModelSelection {
1186 provider: provider.into(),
1187 model,
1188 enable_thinking,
1189 effort,
1190 });
1191 },
1192 );
1193
1194 Task::ready(Ok(()))
1195 }
1196
1197 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1198 let Some(thread) = self
1199 .connection
1200 .0
1201 .read(cx)
1202 .sessions
1203 .get(&self.session_id)
1204 .map(|session| session.thread.clone())
1205 else {
1206 return Task::ready(Err(anyhow!("Session not found")));
1207 };
1208 let Some(model) = thread.read(cx).model() else {
1209 return Task::ready(Err(anyhow!("Model not found")));
1210 };
1211 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1212 else {
1213 return Task::ready(Err(anyhow!("Provider not found")));
1214 };
1215 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1216 model, &provider,
1217 )))
1218 }
1219
1220 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1221 Some(self.connection.0.read(cx).models.watch())
1222 }
1223
1224 fn should_render_footer(&self) -> bool {
1225 true
1226 }
1227}
1228
1229impl acp_thread::AgentConnection for NativeAgentConnection {
1230 fn telemetry_id(&self) -> SharedString {
1231 "zed".into()
1232 }
1233
1234 fn new_thread(
1235 self: Rc<Self>,
1236 project: Entity<Project>,
1237 cwd: &Path,
1238 cx: &mut App,
1239 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1240 log::debug!("Creating new thread for project at: {cwd:?}");
1241 Task::ready(Ok(self
1242 .0
1243 .update(cx, |agent, cx| agent.new_session(project, cx))))
1244 }
1245
1246 fn supports_load_session(&self, _cx: &App) -> bool {
1247 true
1248 }
1249
1250 fn load_session(
1251 self: Rc<Self>,
1252 session: AgentSessionInfo,
1253 _project: Entity<Project>,
1254 _cwd: &Path,
1255 cx: &mut App,
1256 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1257 self.0
1258 .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1259 }
1260
1261 fn auth_methods(&self) -> &[acp::AuthMethod] {
1262 &[] // No auth for in-process
1263 }
1264
1265 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1266 Task::ready(Ok(()))
1267 }
1268
1269 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1270 Some(Rc::new(NativeAgentModelSelector {
1271 session_id: session_id.clone(),
1272 connection: self.clone(),
1273 }) as Rc<dyn AgentModelSelector>)
1274 }
1275
1276 fn prompt(
1277 &self,
1278 id: Option<acp_thread::UserMessageId>,
1279 params: acp::PromptRequest,
1280 cx: &mut App,
1281 ) -> Task<Result<acp::PromptResponse>> {
1282 let id = id.expect("UserMessageId is required");
1283 let session_id = params.session_id.clone();
1284 log::info!("Received prompt request for session: {}", session_id);
1285 log::debug!("Prompt blocks count: {}", params.prompt.len());
1286
1287 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1288 let registry = self.0.read(cx).context_server_registry.read(cx);
1289
1290 let explicit_server_id = parsed_command
1291 .explicit_server_id
1292 .map(|server_id| ContextServerId(server_id.into()));
1293
1294 if let Some(prompt) =
1295 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1296 {
1297 let arguments = if !parsed_command.arg_value.is_empty()
1298 && let Some(arg_name) = prompt
1299 .prompt
1300 .arguments
1301 .as_ref()
1302 .and_then(|args| args.first())
1303 .map(|arg| arg.name.clone())
1304 {
1305 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1306 } else {
1307 Default::default()
1308 };
1309
1310 let prompt_name = prompt.prompt.name.clone();
1311 let server_id = prompt.server_id.clone();
1312
1313 return self.0.update(cx, |agent, cx| {
1314 agent.send_mcp_prompt(
1315 id,
1316 session_id.clone(),
1317 prompt_name,
1318 server_id,
1319 arguments,
1320 params.prompt,
1321 cx,
1322 )
1323 });
1324 };
1325 };
1326
1327 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1328
1329 self.run_turn(session_id, cx, move |thread, cx| {
1330 let content: Vec<UserMessageContent> = params
1331 .prompt
1332 .into_iter()
1333 .map(|block| UserMessageContent::from_content_block(block, path_style))
1334 .collect::<Vec<_>>();
1335 log::debug!("Converted prompt to message: {} chars", content.len());
1336 log::debug!("Message id: {:?}", id);
1337 log::debug!("Message content: {:?}", content);
1338
1339 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1340 })
1341 }
1342
1343 fn retry(
1344 &self,
1345 session_id: &acp::SessionId,
1346 _cx: &App,
1347 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1348 Some(Rc::new(NativeAgentSessionRetry {
1349 connection: self.clone(),
1350 session_id: session_id.clone(),
1351 }) as _)
1352 }
1353
1354 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1355 log::info!("Cancelling on session: {}", session_id);
1356 self.0.update(cx, |agent, cx| {
1357 if let Some(agent) = agent.sessions.get(session_id) {
1358 agent
1359 .thread
1360 .update(cx, |thread, cx| thread.cancel(cx))
1361 .detach();
1362 }
1363 });
1364 }
1365
1366 fn truncate(
1367 &self,
1368 session_id: &agent_client_protocol::SessionId,
1369 cx: &App,
1370 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1371 self.0.read_with(cx, |agent, _cx| {
1372 agent.sessions.get(session_id).map(|session| {
1373 Rc::new(NativeAgentSessionTruncate {
1374 thread: session.thread.clone(),
1375 acp_thread: session.acp_thread.clone(),
1376 }) as _
1377 })
1378 })
1379 }
1380
1381 fn set_title(
1382 &self,
1383 session_id: &acp::SessionId,
1384 _cx: &App,
1385 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1386 Some(Rc::new(NativeAgentSessionSetTitle {
1387 connection: self.clone(),
1388 session_id: session_id.clone(),
1389 }) as _)
1390 }
1391
1392 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1393 let thread_store = self.0.read(cx).thread_store.clone();
1394 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1395 }
1396
1397 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1398 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1399 }
1400
1401 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1402 self
1403 }
1404}
1405
1406impl acp_thread::AgentTelemetry for NativeAgentConnection {
1407 fn thread_data(
1408 &self,
1409 session_id: &acp::SessionId,
1410 cx: &mut App,
1411 ) -> Task<Result<serde_json::Value>> {
1412 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1413 return Task::ready(Err(anyhow!("Session not found")));
1414 };
1415
1416 let task = session.thread.read(cx).to_db(cx);
1417 cx.background_spawn(async move {
1418 serde_json::to_value(task.await).context("Failed to serialize thread")
1419 })
1420 }
1421}
1422
1423pub struct NativeAgentSessionList {
1424 thread_store: Entity<ThreadStore>,
1425 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1426 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1427 _subscription: Subscription,
1428}
1429
1430impl NativeAgentSessionList {
1431 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1432 let (tx, rx) = smol::channel::unbounded();
1433 let this_tx = tx.clone();
1434 let subscription = cx.observe(&thread_store, move |_, _| {
1435 this_tx
1436 .try_send(acp_thread::SessionListUpdate::Refresh)
1437 .ok();
1438 });
1439 Self {
1440 thread_store,
1441 updates_tx: tx,
1442 updates_rx: rx,
1443 _subscription: subscription,
1444 }
1445 }
1446
1447 fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1448 AgentSessionInfo {
1449 session_id: entry.id,
1450 cwd: None,
1451 title: Some(entry.title),
1452 updated_at: Some(entry.updated_at),
1453 meta: None,
1454 }
1455 }
1456
1457 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1458 &self.thread_store
1459 }
1460}
1461
1462impl AgentSessionList for NativeAgentSessionList {
1463 fn list_sessions(
1464 &self,
1465 _request: AgentSessionListRequest,
1466 cx: &mut App,
1467 ) -> Task<Result<AgentSessionListResponse>> {
1468 let sessions = self
1469 .thread_store
1470 .read(cx)
1471 .entries()
1472 .map(Self::to_session_info)
1473 .collect();
1474 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1475 }
1476
1477 fn supports_delete(&self) -> bool {
1478 true
1479 }
1480
1481 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1482 self.thread_store
1483 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1484 }
1485
1486 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1487 self.thread_store
1488 .update(cx, |store, cx| store.delete_threads(cx))
1489 }
1490
1491 fn watch(
1492 &self,
1493 _cx: &mut App,
1494 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1495 Some(self.updates_rx.clone())
1496 }
1497
1498 fn notify_refresh(&self) {
1499 self.updates_tx
1500 .try_send(acp_thread::SessionListUpdate::Refresh)
1501 .ok();
1502 }
1503
1504 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1505 self
1506 }
1507}
1508
1509struct NativeAgentSessionTruncate {
1510 thread: Entity<Thread>,
1511 acp_thread: WeakEntity<AcpThread>,
1512}
1513
1514impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1515 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1516 match self.thread.update(cx, |thread, cx| {
1517 thread.truncate(message_id.clone(), cx)?;
1518 Ok(thread.latest_token_usage())
1519 }) {
1520 Ok(usage) => {
1521 self.acp_thread
1522 .update(cx, |thread, cx| {
1523 thread.update_token_usage(usage, cx);
1524 })
1525 .ok();
1526 Task::ready(Ok(()))
1527 }
1528 Err(error) => Task::ready(Err(error)),
1529 }
1530 }
1531}
1532
1533struct NativeAgentSessionRetry {
1534 connection: NativeAgentConnection,
1535 session_id: acp::SessionId,
1536}
1537
1538impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1539 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1540 self.connection
1541 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1542 thread.update(cx, |thread, cx| thread.resume(cx))
1543 })
1544 }
1545}
1546
1547struct NativeAgentSessionSetTitle {
1548 connection: NativeAgentConnection,
1549 session_id: acp::SessionId,
1550}
1551
1552impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1553 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1554 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1555 return Task::ready(Err(anyhow!("session not found")));
1556 };
1557 let thread = session.thread.clone();
1558 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1559 Task::ready(Ok(()))
1560 }
1561}
1562
1563pub struct AcpThreadEnvironment {
1564 acp_thread: WeakEntity<AcpThread>,
1565}
1566
1567impl ThreadEnvironment for AcpThreadEnvironment {
1568 fn create_terminal(
1569 &self,
1570 command: String,
1571 cwd: Option<PathBuf>,
1572 output_byte_limit: Option<u64>,
1573 cx: &mut AsyncApp,
1574 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1575 let task = self.acp_thread.update(cx, |thread, cx| {
1576 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1577 });
1578
1579 let acp_thread = self.acp_thread.clone();
1580 cx.spawn(async move |cx| {
1581 let terminal = task?.await?;
1582
1583 let (drop_tx, drop_rx) = oneshot::channel();
1584 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1585
1586 cx.spawn(async move |cx| {
1587 drop_rx.await.ok();
1588 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1589 })
1590 .detach();
1591
1592 let handle = AcpTerminalHandle {
1593 terminal,
1594 _drop_tx: Some(drop_tx),
1595 };
1596
1597 Ok(Rc::new(handle) as _)
1598 })
1599 }
1600}
1601
1602pub struct AcpTerminalHandle {
1603 terminal: Entity<acp_thread::Terminal>,
1604 _drop_tx: Option<oneshot::Sender<()>>,
1605}
1606
1607impl TerminalHandle for AcpTerminalHandle {
1608 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1609 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1610 }
1611
1612 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1613 Ok(self
1614 .terminal
1615 .read_with(cx, |term, _cx| term.wait_for_exit()))
1616 }
1617
1618 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1619 Ok(self
1620 .terminal
1621 .read_with(cx, |term, cx| term.current_output(cx)))
1622 }
1623
1624 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1625 cx.update(|cx| {
1626 self.terminal.update(cx, |terminal, cx| {
1627 terminal.kill(cx);
1628 });
1629 });
1630 Ok(())
1631 }
1632
1633 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1634 Ok(self
1635 .terminal
1636 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1637 }
1638}
1639
#[cfg(test)]
mod internal_tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    /// Builds a `NativeAgent` for `project` with default templates and no
    /// prompt store, panicking if construction fails.
    async fn new_test_agent(
        project: &Entity<Project>,
        thread_store: Entity<ThreadStore>,
        fs: &Arc<FakeFs>,
        cx: &mut TestAppContext,
    ) -> Entity<NativeAgent> {
        NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap()
    }

    /// The agent's project context follows worktree and rules-file changes.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = new_test_agent(&project, thread_store, &fs, cx).await;

        // No worktrees yet.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        // Adding a worktree shows up in the project context.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        let without_rules = vec![WorktreeContext {
            root_name: "a".into(),
            abs_path: Path::new("/a").into(),
            rules_file: None,
        }];
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, without_rules)
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            let with_rules = vec![WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file: Some(RulesFileContext {
                    path_in_worktree: rel_path(".rules").into(),
                    text: "".into(),
                    project_entry_id: rules_entry.id.to_usize(),
                }),
            }];
            assert_eq!(agent.project_context.read(cx).worktrees, with_rules)
        });
    }

    /// The model selector lists the fake provider's single model.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = new_test_agent(&project, thread_store, &fs, cx).await;
        let connection = NativeAgentConnection(agent);

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let selector = connection.model_selector(&session_id).unwrap();
        let models = cx.update(|cx| selector.list_models(cx)).await.unwrap();

        let grouped = match models {
            acp_thread::AgentModelList::Grouped(grouped) => grouped,
            _ => panic!("Unexpected model group"),
        };
        let expected = IndexMap::from_iter([(
            AgentModelGroupName("Fake".into()),
            vec![AgentModelInfo {
                id: acp::ModelId::new("fake/fake"),
                name: "Fake".into(),
                description: None,
                icon: Some(acp_thread::AgentModelIcon::Named(
                    ui::IconName::ZedAssistant,
                )),
                is_latest: false,
            }],
        )]);
        assert_eq!(grouped, expected);
    }

    /// Selecting a model updates the thread and is written to the settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        let initial_settings = json!({
            "agent": {
                "default_model": {
                    "provider": "foo",
                    "model": "bar"
                }
            }
        });
        fs.insert_file(
            paths::settings_file(),
            initial_settings.to_string().into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent = new_test_agent(&project, thread_store, &fs, cx).await;
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();
        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        let default_model = &settings_json["agent"]["default_model"];
        assert_eq!(default_model["model"], json!("fake"));
        assert_eq!(default_model["provider"], json!("fake"));
    }

    /// A thread with content is saved when dropped and can be reopened.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = new_test_agent(&project, thread_store.clone(), &fs, cx).await;
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Values shared between the request and the expectations below.
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        let summary_title = format!("Explaining {}", path!("/a/b.md"));
        let expected_markdown = formatdoc! {"
            ## User

            What does [@b.md]({uri}) mean?

            ## Assistant

            Lorem.

        "};

        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        uri.to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model.send_last_completion_stream_text_chunk(&summary_title);
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(thread.to_markdown(cx), expected_markdown)
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(session_id.clone(), summary_title.clone())]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(thread.to_markdown(cx), expected_markdown)
        });
    }

    /// Collects `(id, title)` for every entry currently in the thread store.
    fn thread_entries(
        thread_store: &Entity<ThreadStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(acp::SessionId, String)> {
        thread_store.read_with(cx, |store, _| {
            let mut entries = Vec::new();
            for entry in store.entries() {
                entries.push((entry.id.clone(), entry.title.to_string()));
            }
            entries
        })
    }

    /// Installs the logger, a test settings store and the test language model
    /// registry.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            LanguageModelRegistry::test(cx);
        });
    }
}
2018
2019fn mcp_message_content_to_acp_content_block(
2020 content: context_server::types::MessageContent,
2021) -> acp::ContentBlock {
2022 match content {
2023 context_server::types::MessageContent::Text {
2024 text,
2025 annotations: _,
2026 } => text.into(),
2027 context_server::types::MessageContent::Image {
2028 data,
2029 mime_type,
2030 annotations: _,
2031 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2032 context_server::types::MessageContent::Audio {
2033 data,
2034 mime_type,
2035 annotations: _,
2036 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2037 context_server::types::MessageContent::Resource {
2038 resource,
2039 annotations: _,
2040 } => {
2041 let mut link =
2042 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2043 if let Some(mime_type) = resource.mime_type {
2044 link = link.mime_type(mime_type);
2045 }
2046 acp::ContentBlock::ResourceLink(link)
2047 }
2048 }
2049}