1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod shell_parser;
8mod templates;
9#[cfg(test)]
10mod tests;
11mod thread;
12mod thread_store;
13mod tool_permissions;
14mod tools;
15
16use context_server::ContextServerId;
17pub use db::*;
18pub use native_agent_server::NativeAgentServer;
19pub use pattern_extraction::*;
20pub use templates::*;
21pub use thread::*;
22pub use thread_store::*;
23pub use tool_permissions::*;
24pub use tools::*;
25
26use acp_thread::{
27 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
28 AgentSessionListResponse, UserMessageId,
29};
30use agent_client_protocol as acp;
31use anyhow::{Context as _, Result, anyhow};
32use chrono::{DateTime, Utc};
33use collections::{HashMap, HashSet, IndexMap};
34use fs::Fs;
35use futures::channel::{mpsc, oneshot};
36use futures::future::Shared;
37use futures::{StreamExt, future};
38use gpui::{
39 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
40};
41use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
42use project::{Project, ProjectItem, ProjectPath, Worktree};
43use prompt_store::{
44 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
45 WorktreeContext,
46};
47use serde::{Deserialize, Serialize};
48use settings::{LanguageModelSelection, update_settings_file};
49use std::any::Any;
50use std::path::{Path, PathBuf};
51use std::rc::Rc;
52use std::sync::Arc;
53use util::ResultExt;
54use util::rel_path::RelPath;
55
/// Point-in-time snapshot of the project's worktrees together with the moment it was taken.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots of the project's worktrees.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// When the snapshot was captured, in UTC.
    pub timestamp: DateTime<Utc>,
}
61
/// Error describing why a worktree's rules file could not be loaded.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the formatted underlying error).
    pub message: SharedString,
}
65
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Handle to the most recent background save of `thread`; overwritten by each new save.
    pending_save: Task<()>,
    /// Subscriptions (release, title, token usage, save-on-change) held for the session's
    /// lifetime.
    _subscriptions: Vec<Subscription>,
}
75
/// Cache of the models offered by authenticated, visible language model providers, kept both
/// as an ID lookup table and as a grouped list for display.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver cloned out by `watch()`; signalled each time the list is rebuilt.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender signalled at the end of every list rebuild.
    refresh_models_tx: watch::Sender<()>,
    /// Held so the background provider-authentication task is not dropped.
    _authenticate_all_providers_task: Task<()>,
}
85
86impl LanguageModels {
87 fn new(cx: &mut App) -> Self {
88 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
89
90 let mut this = Self {
91 models: HashMap::default(),
92 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
93 refresh_models_rx,
94 refresh_models_tx,
95 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
96 };
97 this.refresh_list(cx);
98 this
99 }
100
101 fn refresh_list(&mut self, cx: &App) {
102 let providers = LanguageModelRegistry::global(cx)
103 .read(cx)
104 .visible_providers()
105 .into_iter()
106 .filter(|provider| provider.is_authenticated(cx))
107 .collect::<Vec<_>>();
108
109 let mut language_model_list = IndexMap::default();
110 let mut recommended_models = HashSet::default();
111
112 let mut recommended = Vec::new();
113 for provider in &providers {
114 for model in provider.recommended_models(cx) {
115 recommended_models.insert((model.provider_id(), model.id()));
116 recommended.push(Self::map_language_model_to_info(&model, provider));
117 }
118 }
119 if !recommended.is_empty() {
120 language_model_list.insert(
121 acp_thread::AgentModelGroupName("Recommended".into()),
122 recommended,
123 );
124 }
125
126 let mut models = HashMap::default();
127 for provider in providers {
128 let mut provider_models = Vec::new();
129 for model in provider.provided_models(cx) {
130 let model_info = Self::map_language_model_to_info(&model, &provider);
131 let model_id = model_info.id.clone();
132 provider_models.push(model_info);
133 models.insert(model_id, model);
134 }
135 if !provider_models.is_empty() {
136 language_model_list.insert(
137 acp_thread::AgentModelGroupName(provider.name().0.clone()),
138 provider_models,
139 );
140 }
141 }
142
143 self.models = models;
144 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
145 self.refresh_models_tx.send(()).ok();
146 }
147
148 fn watch(&self) -> watch::Receiver<()> {
149 self.refresh_models_rx.clone()
150 }
151
152 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
153 self.models.get(model_id).cloned()
154 }
155
156 fn map_language_model_to_info(
157 model: &Arc<dyn LanguageModel>,
158 provider: &Arc<dyn LanguageModelProvider>,
159 ) -> acp_thread::AgentModelInfo {
160 acp_thread::AgentModelInfo {
161 id: Self::model_id(model),
162 name: model.name().0,
163 description: None,
164 icon: Some(match provider.icon() {
165 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
166 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
167 }),
168 }
169 }
170
171 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
172 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
173 }
174
175 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
176 let authenticate_all_providers = LanguageModelRegistry::global(cx)
177 .read(cx)
178 .visible_providers()
179 .iter()
180 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
181 .collect::<Vec<_>>();
182
183 cx.background_spawn(async move {
184 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
185 if let Err(err) = authenticate_task.await {
186 match err {
187 language_model::AuthenticateError::CredentialsNotFound => {
188 // Since we're authenticating these providers in the
189 // background for the purposes of populating the
190 // language selector, we don't care about providers
191 // where the credentials are not found.
192 }
193 language_model::AuthenticateError::ConnectionRefused => {
194 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
195 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
196 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
197 }
198 _ => {
199 // Some providers have noisy failure states that we
200 // don't want to spam the logs with every time the
201 // language model selector is initialized.
202 //
203 // Ideally these should have more clear failure modes
204 // that we know are safe to ignore here, like what we do
205 // with `CredentialsNotFound` above.
206 match provider_id.0.as_ref() {
207 "lmstudio" | "ollama" => {
208 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
209 //
210 // These fail noisily, so we don't log them.
211 }
212 "copilot_chat" => {
213 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
214 }
215 _ => {
216 log::error!(
217 "Failed to authenticate provider: {}: {err:#}",
218 provider_name.0
219 );
220 }
221 }
222 }
223 }
224 }
225 }
226 })
227 }
228}
229
/// Agent that backs ACP sessions with native [`Thread`]s.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store that is reloaded after every thread save.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signals `_maintain_project_context` that the project context must be rebuilt.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task that rebuilds `project_context` on each refresh signal.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context-server tools and prompts; prompts are exposed as slash commands.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project the agent operates on.
    project: Entity<Project>,
    /// Source of default user rules; `None` yields no default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem used when persisting the selected model to the settings file.
    fs: Arc<dyn Fs>,
    _subscriptions: Vec<Subscription>,
}
248
249impl NativeAgent {
250 pub async fn new(
251 project: Entity<Project>,
252 thread_store: Entity<ThreadStore>,
253 templates: Arc<Templates>,
254 prompt_store: Option<Entity<PromptStore>>,
255 fs: Arc<dyn Fs>,
256 cx: &mut AsyncApp,
257 ) -> Result<Entity<NativeAgent>> {
258 log::debug!("Creating new NativeAgent");
259
260 let project_context = cx
261 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
262 .await;
263
264 Ok(cx.new(|cx| {
265 let context_server_store = project.read(cx).context_server_store();
266 let context_server_registry =
267 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
268
269 let mut subscriptions = vec![
270 cx.subscribe(&project, Self::handle_project_event),
271 cx.subscribe(
272 &LanguageModelRegistry::global(cx),
273 Self::handle_models_updated_event,
274 ),
275 cx.subscribe(
276 &context_server_store,
277 Self::handle_context_server_store_updated,
278 ),
279 cx.subscribe(
280 &context_server_registry,
281 Self::handle_context_server_registry_event,
282 ),
283 ];
284 if let Some(prompt_store) = prompt_store.as_ref() {
285 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
286 }
287
288 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
289 watch::channel(());
290 Self {
291 sessions: HashMap::default(),
292 thread_store,
293 project_context: cx.new(|_| project_context),
294 project_context_needs_refresh: project_context_needs_refresh_tx,
295 _maintain_project_context: cx.spawn(async move |this, cx| {
296 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
297 }),
298 context_server_registry,
299 templates,
300 models: LanguageModels::new(cx),
301 project,
302 prompt_store,
303 fs,
304 _subscriptions: subscriptions,
305 }
306 }))
307 }
308
309 fn register_session(
310 &mut self,
311 thread_handle: Entity<Thread>,
312 cx: &mut Context<Self>,
313 ) -> Entity<AcpThread> {
314 let connection = Rc::new(NativeAgentConnection(cx.entity()));
315
316 let thread = thread_handle.read(cx);
317 let session_id = thread.id().clone();
318 let title = thread.title();
319 let project = thread.project.clone();
320 let action_log = thread.action_log.clone();
321 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
322 let acp_thread = cx.new(|cx| {
323 acp_thread::AcpThread::new(
324 title,
325 connection,
326 project.clone(),
327 action_log.clone(),
328 session_id.clone(),
329 prompt_capabilities_rx,
330 cx,
331 )
332 });
333
334 let registry = LanguageModelRegistry::read_global(cx);
335 let summarization_model = registry.thread_summary_model().map(|c| c.model);
336
337 thread_handle.update(cx, |thread, cx| {
338 thread.set_summarization_model(summarization_model, cx);
339 thread.add_default_tools(
340 Rc::new(AcpThreadEnvironment {
341 acp_thread: acp_thread.downgrade(),
342 }) as _,
343 cx,
344 )
345 });
346
347 let subscriptions = vec![
348 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
349 this.sessions.remove(acp_thread.session_id());
350 }),
351 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
352 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
353 cx.observe(&thread_handle, move |this, thread, cx| {
354 this.save_thread(thread, cx)
355 }),
356 ];
357
358 self.sessions.insert(
359 session_id,
360 Session {
361 thread: thread_handle,
362 acp_thread: acp_thread.downgrade(),
363 _subscriptions: subscriptions,
364 pending_save: Task::ready(()),
365 },
366 );
367
368 self.update_available_commands(cx);
369
370 acp_thread
371 }
372
373 pub fn models(&self) -> &LanguageModels {
374 &self.models
375 }
376
377 async fn maintain_project_context(
378 this: WeakEntity<Self>,
379 mut needs_refresh: watch::Receiver<()>,
380 cx: &mut AsyncApp,
381 ) -> Result<()> {
382 while needs_refresh.changed().await.is_ok() {
383 let project_context = this
384 .update(cx, |this, cx| {
385 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
386 })?
387 .await;
388 this.update(cx, |this, cx| {
389 this.project_context = cx.new(|_| project_context);
390 })?;
391 }
392
393 Ok(())
394 }
395
396 fn build_project_context(
397 project: &Entity<Project>,
398 prompt_store: Option<&Entity<PromptStore>>,
399 cx: &mut App,
400 ) -> Task<ProjectContext> {
401 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
402 let worktree_tasks = worktrees
403 .into_iter()
404 .map(|worktree| {
405 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
406 })
407 .collect::<Vec<_>>();
408 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
409 prompt_store.read_with(cx, |prompt_store, cx| {
410 let prompts = prompt_store.default_prompt_metadata();
411 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
412 let contents = prompt_store.load(prompt_metadata.id, cx);
413 async move { (contents.await, prompt_metadata) }
414 });
415 cx.background_spawn(future::join_all(load_tasks))
416 })
417 } else {
418 Task::ready(vec![])
419 };
420
421 cx.spawn(async move |_cx| {
422 let (worktrees, default_user_rules) =
423 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
424
425 let worktrees = worktrees
426 .into_iter()
427 .map(|(worktree, _rules_error)| {
428 // TODO: show error message
429 // if let Some(rules_error) = rules_error {
430 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
431 // }
432 worktree
433 })
434 .collect::<Vec<_>>();
435
436 let default_user_rules = default_user_rules
437 .into_iter()
438 .flat_map(|(contents, prompt_metadata)| match contents {
439 Ok(contents) => Some(UserRulesContext {
440 uuid: prompt_metadata.id.as_user()?,
441 title: prompt_metadata.title.map(|title| title.to_string()),
442 contents,
443 }),
444 Err(_err) => {
445 // TODO: show error message
446 // this.update(cx, |_, cx| {
447 // cx.emit(RulesLoadingError {
448 // message: format!("{err:?}").into(),
449 // });
450 // })
451 // .ok();
452 None
453 }
454 })
455 .collect::<Vec<_>>();
456
457 ProjectContext::new(worktrees, default_user_rules)
458 })
459 }
460
461 fn load_worktree_info_for_system_prompt(
462 worktree: Entity<Worktree>,
463 project: Entity<Project>,
464 cx: &mut App,
465 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
466 let tree = worktree.read(cx);
467 let root_name = tree.root_name_str().into();
468 let abs_path = tree.abs_path();
469
470 let mut context = WorktreeContext {
471 root_name,
472 abs_path,
473 rules_file: None,
474 };
475
476 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
477 let Some(rules_task) = rules_task else {
478 return Task::ready((context, None));
479 };
480
481 cx.spawn(async move |_| {
482 let (rules_file, rules_file_error) = match rules_task.await {
483 Ok(rules_file) => (Some(rules_file), None),
484 Err(err) => (
485 None,
486 Some(RulesLoadingError {
487 message: format!("{err}").into(),
488 }),
489 ),
490 };
491 context.rules_file = rules_file;
492 (context, rules_file_error)
493 })
494 }
495
496 fn load_worktree_rules_file(
497 worktree: Entity<Worktree>,
498 project: Entity<Project>,
499 cx: &mut App,
500 ) -> Option<Task<Result<RulesFileContext>>> {
501 let worktree = worktree.read(cx);
502 let worktree_id = worktree.id();
503 let selected_rules_file = RULES_FILE_NAMES
504 .into_iter()
505 .filter_map(|name| {
506 worktree
507 .entry_for_path(RelPath::unix(name).unwrap())
508 .filter(|entry| entry.is_file())
509 .map(|entry| entry.path.clone())
510 })
511 .next();
512
513 // Note that Cline supports `.clinerules` being a directory, but that is not currently
514 // supported. This doesn't seem to occur often in GitHub repositories.
515 selected_rules_file.map(|path_in_worktree| {
516 let project_path = ProjectPath {
517 worktree_id,
518 path: path_in_worktree.clone(),
519 };
520 let buffer_task =
521 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
522 let rope_task = cx.spawn(async move |cx| {
523 let buffer = buffer_task.await?;
524 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
525 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
526 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
527 })?;
528 anyhow::Ok((project_entry_id, rope))
529 });
530 // Build a string from the rope on a background thread.
531 cx.background_spawn(async move {
532 let (project_entry_id, rope) = rope_task.await?;
533 anyhow::Ok(RulesFileContext {
534 path_in_worktree,
535 text: rope.to_string().trim().to_string(),
536 project_entry_id: project_entry_id.to_usize(),
537 })
538 })
539 })
540 }
541
542 fn handle_thread_title_updated(
543 &mut self,
544 thread: Entity<Thread>,
545 _: &TitleUpdated,
546 cx: &mut Context<Self>,
547 ) {
548 let session_id = thread.read(cx).id();
549 let Some(session) = self.sessions.get(session_id) else {
550 return;
551 };
552 let thread = thread.downgrade();
553 let acp_thread = session.acp_thread.clone();
554 cx.spawn(async move |_, cx| {
555 let title = thread.read_with(cx, |thread, _| thread.title())?;
556 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
557 task.await
558 })
559 .detach_and_log_err(cx);
560 }
561
562 fn handle_thread_token_usage_updated(
563 &mut self,
564 thread: Entity<Thread>,
565 usage: &TokenUsageUpdated,
566 cx: &mut Context<Self>,
567 ) {
568 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
569 return;
570 };
571 session
572 .acp_thread
573 .update(cx, |acp_thread, cx| {
574 acp_thread.update_token_usage(usage.0.clone(), cx);
575 })
576 .ok();
577 }
578
579 fn handle_project_event(
580 &mut self,
581 _project: Entity<Project>,
582 event: &project::Event,
583 _cx: &mut Context<Self>,
584 ) {
585 match event {
586 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
587 self.project_context_needs_refresh.send(()).ok();
588 }
589 project::Event::WorktreeUpdatedEntries(_, items) => {
590 if items.iter().any(|(path, _, _)| {
591 RULES_FILE_NAMES
592 .iter()
593 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
594 }) {
595 self.project_context_needs_refresh.send(()).ok();
596 }
597 }
598 _ => {}
599 }
600 }
601
602 fn handle_prompts_updated_event(
603 &mut self,
604 _prompt_store: Entity<PromptStore>,
605 _event: &prompt_store::PromptsUpdatedEvent,
606 _cx: &mut Context<Self>,
607 ) {
608 self.project_context_needs_refresh.send(()).ok();
609 }
610
611 fn handle_models_updated_event(
612 &mut self,
613 _registry: Entity<LanguageModelRegistry>,
614 _event: &language_model::Event,
615 cx: &mut Context<Self>,
616 ) {
617 self.models.refresh_list(cx);
618
619 let registry = LanguageModelRegistry::read_global(cx);
620 let default_model = registry.default_model().map(|m| m.model);
621 let summarization_model = registry.thread_summary_model().map(|m| m.model);
622
623 for session in self.sessions.values_mut() {
624 session.thread.update(cx, |thread, cx| {
625 if thread.model().is_none()
626 && let Some(model) = default_model.clone()
627 {
628 thread.set_model(model, cx);
629 cx.notify();
630 }
631 thread.set_summarization_model(summarization_model.clone(), cx);
632 });
633 }
634 }
635
636 fn handle_context_server_store_updated(
637 &mut self,
638 _store: Entity<project::context_server_store::ContextServerStore>,
639 _event: &project::context_server_store::Event,
640 cx: &mut Context<Self>,
641 ) {
642 self.update_available_commands(cx);
643 }
644
645 fn handle_context_server_registry_event(
646 &mut self,
647 _registry: Entity<ContextServerRegistry>,
648 event: &ContextServerRegistryEvent,
649 cx: &mut Context<Self>,
650 ) {
651 match event {
652 ContextServerRegistryEvent::ToolsChanged => {}
653 ContextServerRegistryEvent::PromptsChanged => {
654 self.update_available_commands(cx);
655 }
656 }
657 }
658
659 fn update_available_commands(&self, cx: &mut Context<Self>) {
660 let available_commands = self.build_available_commands(cx);
661 for session in self.sessions.values() {
662 if let Some(acp_thread) = session.acp_thread.upgrade() {
663 acp_thread.update(cx, |thread, cx| {
664 thread
665 .handle_session_update(
666 acp::SessionUpdate::AvailableCommandsUpdate(
667 acp::AvailableCommandsUpdate::new(available_commands.clone()),
668 ),
669 cx,
670 )
671 .log_err();
672 });
673 }
674 }
675 }
676
677 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
678 let registry = self.context_server_registry.read(cx);
679
680 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
681 for context_server_prompt in registry.prompts() {
682 *prompt_name_counts
683 .entry(context_server_prompt.prompt.name.as_str())
684 .or_insert(0) += 1;
685 }
686
687 registry
688 .prompts()
689 .flat_map(|context_server_prompt| {
690 let prompt = &context_server_prompt.prompt;
691
692 let should_prefix = prompt_name_counts
693 .get(prompt.name.as_str())
694 .copied()
695 .unwrap_or(0)
696 > 1;
697
698 let name = if should_prefix {
699 format!("{}.{}", context_server_prompt.server_id, prompt.name)
700 } else {
701 prompt.name.clone()
702 };
703
704 let mut command = acp::AvailableCommand::new(
705 name,
706 prompt.description.clone().unwrap_or_default(),
707 );
708
709 match prompt.arguments.as_deref() {
710 Some([arg]) => {
711 let hint = format!("<{}>", arg.name);
712
713 command = command.input(acp::AvailableCommandInput::Unstructured(
714 acp::UnstructuredCommandInput::new(hint),
715 ));
716 }
717 Some([]) | None => {}
718 Some(_) => {
719 // skip >1 argument commands since we don't support them yet
720 return None;
721 }
722 }
723
724 Some(command)
725 })
726 .collect()
727 }
728
729 pub fn load_thread(
730 &mut self,
731 id: acp::SessionId,
732 cx: &mut Context<Self>,
733 ) -> Task<Result<Entity<Thread>>> {
734 let database_future = ThreadsDatabase::connect(cx);
735 cx.spawn(async move |this, cx| {
736 let database = database_future.await.map_err(|err| anyhow!(err))?;
737 let db_thread = database
738 .load_thread(id.clone())
739 .await?
740 .with_context(|| format!("no thread found with ID: {id:?}"))?;
741
742 this.update(cx, |this, cx| {
743 let summarization_model = LanguageModelRegistry::read_global(cx)
744 .thread_summary_model()
745 .map(|c| c.model);
746
747 cx.new(|cx| {
748 let mut thread = Thread::from_db(
749 id.clone(),
750 db_thread,
751 this.project.clone(),
752 this.project_context.clone(),
753 this.context_server_registry.clone(),
754 this.templates.clone(),
755 cx,
756 );
757 thread.set_summarization_model(summarization_model, cx);
758 thread
759 })
760 })
761 })
762 }
763
764 pub fn open_thread(
765 &mut self,
766 id: acp::SessionId,
767 cx: &mut Context<Self>,
768 ) -> Task<Result<Entity<AcpThread>>> {
769 let task = self.load_thread(id, cx);
770 cx.spawn(async move |this, cx| {
771 let thread = task.await?;
772 let acp_thread =
773 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
774 let events = thread.update(cx, |thread, cx| thread.replay(cx));
775 cx.update(|cx| {
776 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
777 })
778 .await?;
779 Ok(acp_thread)
780 })
781 }
782
783 pub fn thread_summary(
784 &mut self,
785 id: acp::SessionId,
786 cx: &mut Context<Self>,
787 ) -> Task<Result<SharedString>> {
788 let thread = self.open_thread(id.clone(), cx);
789 cx.spawn(async move |this, cx| {
790 let acp_thread = thread.await?;
791 let result = this
792 .update(cx, |this, cx| {
793 this.sessions
794 .get(&id)
795 .unwrap()
796 .thread
797 .update(cx, |thread, cx| thread.summary(cx))
798 })?
799 .await
800 .context("Failed to generate summary")?;
801 drop(acp_thread);
802 Ok(result)
803 })
804 }
805
806 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
807 if thread.read(cx).is_empty() {
808 return;
809 }
810
811 let database_future = ThreadsDatabase::connect(cx);
812 let (id, db_thread) =
813 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
814 let Some(session) = self.sessions.get_mut(&id) else {
815 return;
816 };
817 let thread_store = self.thread_store.clone();
818 session.pending_save = cx.spawn(async move |_, cx| {
819 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
820 return;
821 };
822 let db_thread = db_thread.await;
823 database.save_thread(id, db_thread).await.log_err();
824 thread_store.update(cx, |store, cx| store.reload(cx));
825 });
826 }
827
828 fn send_mcp_prompt(
829 &self,
830 message_id: UserMessageId,
831 session_id: agent_client_protocol::SessionId,
832 prompt_name: String,
833 server_id: ContextServerId,
834 arguments: HashMap<String, String>,
835 original_content: Vec<acp::ContentBlock>,
836 cx: &mut Context<Self>,
837 ) -> Task<Result<acp::PromptResponse>> {
838 let server_store = self.context_server_registry.read(cx).server_store().clone();
839 let path_style = self.project.read(cx).path_style(cx);
840
841 cx.spawn(async move |this, cx| {
842 let prompt =
843 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
844
845 let (acp_thread, thread) = this.update(cx, |this, _cx| {
846 let session = this
847 .sessions
848 .get(&session_id)
849 .context("Failed to get session")?;
850 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
851 })??;
852
853 let mut last_is_user = true;
854
855 thread.update(cx, |thread, cx| {
856 thread.push_acp_user_block(
857 message_id,
858 original_content.into_iter().skip(1),
859 path_style,
860 cx,
861 );
862 });
863
864 for message in prompt.messages {
865 let context_server::types::PromptMessage { role, content } = message;
866 let block = mcp_message_content_to_acp_content_block(content);
867
868 match role {
869 context_server::types::Role::User => {
870 let id = acp_thread::UserMessageId::new();
871
872 acp_thread.update(cx, |acp_thread, cx| {
873 acp_thread.push_user_content_block_with_indent(
874 Some(id.clone()),
875 block.clone(),
876 true,
877 cx,
878 );
879 })?;
880
881 thread.update(cx, |thread, cx| {
882 thread.push_acp_user_block(id, [block], path_style, cx);
883 });
884 }
885 context_server::types::Role::Assistant => {
886 acp_thread.update(cx, |acp_thread, cx| {
887 acp_thread.push_assistant_content_block_with_indent(
888 block.clone(),
889 false,
890 true,
891 cx,
892 );
893 })?;
894
895 thread.update(cx, |thread, cx| {
896 thread.push_acp_agent_block(block, cx);
897 });
898 }
899 }
900
901 last_is_user = role == context_server::types::Role::User;
902 }
903
904 let response_stream = thread.update(cx, |thread, cx| {
905 if last_is_user {
906 thread.send_existing(cx)
907 } else {
908 // Resume if MCP prompt did not end with a user message
909 thread.resume(cx)
910 }
911 })?;
912
913 cx.update(|cx| {
914 NativeAgentConnection::handle_thread_events(response_stream, acp_thread, cx)
915 })
916 .await
917 })
918 }
919}
920
/// Wrapper struct that implements the AgentConnection trait
///
/// Cheap to clone: it only holds a handle to the shared [`NativeAgent`] entity.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
924
925impl NativeAgentConnection {
926 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
927 self.0
928 .read(cx)
929 .sessions
930 .get(session_id)
931 .map(|session| session.thread.clone())
932 }
933
934 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
935 self.0.update(cx, |this, cx| this.load_thread(id, cx))
936 }
937
938 fn run_turn(
939 &self,
940 session_id: acp::SessionId,
941 cx: &mut App,
942 f: impl 'static
943 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
944 ) -> Task<Result<acp::PromptResponse>> {
945 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
946 agent
947 .sessions
948 .get_mut(&session_id)
949 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
950 }) else {
951 return Task::ready(Err(anyhow!("Session not found")));
952 };
953 log::debug!("Found session for: {}", session_id);
954
955 let response_stream = match f(thread, cx) {
956 Ok(stream) => stream,
957 Err(err) => return Task::ready(Err(err)),
958 };
959 Self::handle_thread_events(response_stream, acp_thread, cx)
960 }
961
962 fn handle_thread_events(
963 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
964 acp_thread: WeakEntity<AcpThread>,
965 cx: &App,
966 ) -> Task<Result<acp::PromptResponse>> {
967 cx.spawn(async move |cx| {
968 // Handle response stream and forward to session.acp_thread
969 while let Some(result) = events.next().await {
970 match result {
971 Ok(event) => {
972 log::trace!("Received completion event: {:?}", event);
973
974 match event {
975 ThreadEvent::UserMessage(message) => {
976 acp_thread.update(cx, |thread, cx| {
977 for content in message.content {
978 thread.push_user_content_block(
979 Some(message.id.clone()),
980 content.into(),
981 cx,
982 );
983 }
984 })?;
985 }
986 ThreadEvent::AgentText(text) => {
987 acp_thread.update(cx, |thread, cx| {
988 thread.push_assistant_content_block(text.into(), false, cx)
989 })?;
990 }
991 ThreadEvent::AgentThinking(text) => {
992 acp_thread.update(cx, |thread, cx| {
993 thread.push_assistant_content_block(text.into(), true, cx)
994 })?;
995 }
996 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
997 tool_call,
998 options,
999 response,
1000 context: _,
1001 }) => {
1002 let outcome_task = acp_thread.update(cx, |thread, cx| {
1003 thread.request_tool_call_authorization(
1004 tool_call, options, true, cx,
1005 )
1006 })??;
1007 cx.background_spawn(async move {
1008 if let acp::RequestPermissionOutcome::Selected(
1009 acp::SelectedPermissionOutcome { option_id, .. },
1010 ) = outcome_task.await
1011 {
1012 response
1013 .send(option_id)
1014 .map(|_| anyhow!("authorization receiver was dropped"))
1015 .log_err();
1016 }
1017 })
1018 .detach();
1019 }
1020 ThreadEvent::ToolCall(tool_call) => {
1021 acp_thread.update(cx, |thread, cx| {
1022 thread.upsert_tool_call(tool_call, cx)
1023 })??;
1024 }
1025 ThreadEvent::ToolCallUpdate(update) => {
1026 acp_thread.update(cx, |thread, cx| {
1027 thread.update_tool_call(update, cx)
1028 })??;
1029 }
1030 ThreadEvent::Retry(status) => {
1031 acp_thread.update(cx, |thread, cx| {
1032 thread.update_retry_status(status, cx)
1033 })?;
1034 }
1035 ThreadEvent::Stop(stop_reason) => {
1036 log::debug!("Assistant message complete: {:?}", stop_reason);
1037 return Ok(acp::PromptResponse::new(stop_reason));
1038 }
1039 }
1040 }
1041 Err(e) => {
1042 log::error!("Error in model response stream: {:?}", e);
1043 return Err(e);
1044 }
1045 }
1046 }
1047
1048 log::debug!("Response stream completed");
1049 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1050 })
1051 }
1052}
1053
/// A slash command parsed from the first block of a prompt, e.g. `/server_id.prompt arg`.
struct Command<'a> {
    /// Name of the prompt to run (the part after an optional `server_id.` qualifier).
    prompt_name: &'a str,
    /// Text after the first whitespace following the command name; empty if there is none.
    arg_value: &'a str,
    /// Server ID given explicitly as `/server_id.prompt_name`, if any.
    explicit_server_id: Option<&'a str>,
}
1059
1060impl<'a> Command<'a> {
1061 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1062 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1063 return None;
1064 };
1065 let text = text_content.text.trim();
1066 let command = text.strip_prefix('/')?;
1067 let (command, arg_value) = command
1068 .split_once(char::is_whitespace)
1069 .unwrap_or((command, ""));
1070
1071 if let Some((server_id, prompt_name)) = command.split_once('.') {
1072 Some(Self {
1073 prompt_name,
1074 arg_value,
1075 explicit_server_id: Some(server_id),
1076 })
1077 } else {
1078 Some(Self {
1079 prompt_name: command,
1080 arg_value,
1081 explicit_server_id: None,
1082 })
1083 }
1084 }
1085}
1086
/// Model selector bound to a single session of a [`NativeAgentConnection`].
struct NativeAgentModelSelector {
    /// Session whose thread's model is queried or changed.
    session_id: acp::SessionId,
    /// Connection used to reach the agent's sessions, model cache and filesystem.
    connection: NativeAgentConnection,
}
1091
1092impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1093 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1094 log::debug!("NativeAgentConnection::list_models called");
1095 let list = self.connection.0.read(cx).models.model_list.clone();
1096 Task::ready(if list.is_empty() {
1097 Err(anyhow::anyhow!("No models available"))
1098 } else {
1099 Ok(list)
1100 })
1101 }
1102
1103 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1104 log::debug!(
1105 "Setting model for session {}: {}",
1106 self.session_id,
1107 model_id
1108 );
1109 let Some(thread) = self
1110 .connection
1111 .0
1112 .read(cx)
1113 .sessions
1114 .get(&self.session_id)
1115 .map(|session| session.thread.clone())
1116 else {
1117 return Task::ready(Err(anyhow!("Session not found")));
1118 };
1119
1120 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1121 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1122 };
1123
1124 thread.update(cx, |thread, cx| {
1125 thread.set_model(model.clone(), cx);
1126 });
1127
1128 update_settings_file(
1129 self.connection.0.read(cx).fs.clone(),
1130 cx,
1131 move |settings, _cx| {
1132 let provider = model.provider_id().0.to_string();
1133 let model = model.id().0.to_string();
1134 settings
1135 .agent
1136 .get_or_insert_default()
1137 .set_model(LanguageModelSelection {
1138 provider: provider.into(),
1139 model,
1140 });
1141 },
1142 );
1143
1144 Task::ready(Ok(()))
1145 }
1146
1147 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1148 let Some(thread) = self
1149 .connection
1150 .0
1151 .read(cx)
1152 .sessions
1153 .get(&self.session_id)
1154 .map(|session| session.thread.clone())
1155 else {
1156 return Task::ready(Err(anyhow!("Session not found")));
1157 };
1158 let Some(model) = thread.read(cx).model() else {
1159 return Task::ready(Err(anyhow!("Model not found")));
1160 };
1161 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1162 else {
1163 return Task::ready(Err(anyhow!("Provider not found")));
1164 };
1165 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1166 model, &provider,
1167 )))
1168 }
1169
1170 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1171 Some(self.connection.0.read(cx).models.watch())
1172 }
1173
1174 fn should_render_footer(&self) -> bool {
1175 true
1176 }
1177}
1178
1179impl acp_thread::AgentConnection for NativeAgentConnection {
1180 fn telemetry_id(&self) -> SharedString {
1181 "zed".into()
1182 }
1183
1184 fn new_thread(
1185 self: Rc<Self>,
1186 project: Entity<Project>,
1187 cwd: &Path,
1188 cx: &mut App,
1189 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1190 let agent = self.0.clone();
1191 log::debug!("Creating new thread for project at: {:?}", cwd);
1192
1193 cx.spawn(async move |cx| {
1194 log::debug!("Starting thread creation in async context");
1195
1196 // Create Thread
1197 let thread = agent.update(cx, |agent, cx| {
1198 // Fetch default model from registry settings
1199 let registry = LanguageModelRegistry::read_global(cx);
1200 // Log available models for debugging
1201 let available_count = registry.available_models(cx).count();
1202 log::debug!("Total available models: {}", available_count);
1203
1204 let default_model = registry.default_model().and_then(|default_model| {
1205 agent
1206 .models
1207 .model_from_id(&LanguageModels::model_id(&default_model.model))
1208 });
1209 cx.new(|cx| {
1210 Thread::new(
1211 project.clone(),
1212 agent.project_context.clone(),
1213 agent.context_server_registry.clone(),
1214 agent.templates.clone(),
1215 default_model,
1216 cx,
1217 )
1218 })
1219 });
1220 Ok(agent.update(cx, |agent, cx| agent.register_session(thread, cx)))
1221 })
1222 }
1223
1224 fn supports_load_session(&self, _cx: &App) -> bool {
1225 true
1226 }
1227
1228 fn load_session(
1229 self: Rc<Self>,
1230 session: AgentSessionInfo,
1231 _project: Entity<Project>,
1232 _cwd: &Path,
1233 cx: &mut App,
1234 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1235 self.0
1236 .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1237 }
1238
1239 fn auth_methods(&self) -> &[acp::AuthMethod] {
1240 &[] // No auth for in-process
1241 }
1242
1243 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1244 Task::ready(Ok(()))
1245 }
1246
1247 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1248 Some(Rc::new(NativeAgentModelSelector {
1249 session_id: session_id.clone(),
1250 connection: self.clone(),
1251 }) as Rc<dyn AgentModelSelector>)
1252 }
1253
1254 fn prompt(
1255 &self,
1256 id: Option<acp_thread::UserMessageId>,
1257 params: acp::PromptRequest,
1258 cx: &mut App,
1259 ) -> Task<Result<acp::PromptResponse>> {
1260 let id = id.expect("UserMessageId is required");
1261 let session_id = params.session_id.clone();
1262 log::info!("Received prompt request for session: {}", session_id);
1263 log::debug!("Prompt blocks count: {}", params.prompt.len());
1264
1265 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1266 let registry = self.0.read(cx).context_server_registry.read(cx);
1267
1268 let explicit_server_id = parsed_command
1269 .explicit_server_id
1270 .map(|server_id| ContextServerId(server_id.into()));
1271
1272 if let Some(prompt) =
1273 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1274 {
1275 let arguments = if !parsed_command.arg_value.is_empty()
1276 && let Some(arg_name) = prompt
1277 .prompt
1278 .arguments
1279 .as_ref()
1280 .and_then(|args| args.first())
1281 .map(|arg| arg.name.clone())
1282 {
1283 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1284 } else {
1285 Default::default()
1286 };
1287
1288 let prompt_name = prompt.prompt.name.clone();
1289 let server_id = prompt.server_id.clone();
1290
1291 return self.0.update(cx, |agent, cx| {
1292 agent.send_mcp_prompt(
1293 id,
1294 session_id.clone(),
1295 prompt_name,
1296 server_id,
1297 arguments,
1298 params.prompt,
1299 cx,
1300 )
1301 });
1302 };
1303 };
1304
1305 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1306
1307 self.run_turn(session_id, cx, move |thread, cx| {
1308 let content: Vec<UserMessageContent> = params
1309 .prompt
1310 .into_iter()
1311 .map(|block| UserMessageContent::from_content_block(block, path_style))
1312 .collect::<Vec<_>>();
1313 log::debug!("Converted prompt to message: {} chars", content.len());
1314 log::debug!("Message id: {:?}", id);
1315 log::debug!("Message content: {:?}", content);
1316
1317 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1318 })
1319 }
1320
1321 fn retry(
1322 &self,
1323 session_id: &acp::SessionId,
1324 _cx: &App,
1325 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1326 Some(Rc::new(NativeAgentSessionRetry {
1327 connection: self.clone(),
1328 session_id: session_id.clone(),
1329 }) as _)
1330 }
1331
1332 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1333 log::info!("Cancelling on session: {}", session_id);
1334 self.0.update(cx, |agent, cx| {
1335 if let Some(agent) = agent.sessions.get(session_id) {
1336 agent
1337 .thread
1338 .update(cx, |thread, cx| thread.cancel(cx))
1339 .detach();
1340 }
1341 });
1342 }
1343
1344 fn truncate(
1345 &self,
1346 session_id: &agent_client_protocol::SessionId,
1347 cx: &App,
1348 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1349 self.0.read_with(cx, |agent, _cx| {
1350 agent.sessions.get(session_id).map(|session| {
1351 Rc::new(NativeAgentSessionTruncate {
1352 thread: session.thread.clone(),
1353 acp_thread: session.acp_thread.clone(),
1354 }) as _
1355 })
1356 })
1357 }
1358
1359 fn set_title(
1360 &self,
1361 session_id: &acp::SessionId,
1362 _cx: &App,
1363 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1364 Some(Rc::new(NativeAgentSessionSetTitle {
1365 connection: self.clone(),
1366 session_id: session_id.clone(),
1367 }) as _)
1368 }
1369
1370 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1371 let thread_store = self.0.read(cx).thread_store.clone();
1372 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1373 }
1374
1375 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1376 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1377 }
1378
1379 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1380 self
1381 }
1382}
1383
1384impl acp_thread::AgentTelemetry for NativeAgentConnection {
1385 fn thread_data(
1386 &self,
1387 session_id: &acp::SessionId,
1388 cx: &mut App,
1389 ) -> Task<Result<serde_json::Value>> {
1390 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1391 return Task::ready(Err(anyhow!("Session not found")));
1392 };
1393
1394 let task = session.thread.read(cx).to_db(cx);
1395 cx.background_spawn(async move {
1396 serde_json::to_value(task.await).context("Failed to serialize thread")
1397 })
1398 }
1399}
1400
1401pub struct NativeAgentSessionList {
1402 thread_store: Entity<ThreadStore>,
1403 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1404 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1405 _subscription: Subscription,
1406}
1407
1408impl NativeAgentSessionList {
1409 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1410 let (tx, rx) = smol::channel::unbounded();
1411 let this_tx = tx.clone();
1412 let subscription = cx.observe(&thread_store, move |_, _| {
1413 this_tx
1414 .try_send(acp_thread::SessionListUpdate::Refresh)
1415 .ok();
1416 });
1417 Self {
1418 thread_store,
1419 updates_tx: tx,
1420 updates_rx: rx,
1421 _subscription: subscription,
1422 }
1423 }
1424
1425 fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1426 AgentSessionInfo {
1427 session_id: entry.id,
1428 cwd: None,
1429 title: Some(entry.title),
1430 updated_at: Some(entry.updated_at),
1431 meta: None,
1432 }
1433 }
1434
1435 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1436 &self.thread_store
1437 }
1438}
1439
1440impl AgentSessionList for NativeAgentSessionList {
1441 fn list_sessions(
1442 &self,
1443 _request: AgentSessionListRequest,
1444 cx: &mut App,
1445 ) -> Task<Result<AgentSessionListResponse>> {
1446 let sessions = self
1447 .thread_store
1448 .read(cx)
1449 .entries()
1450 .map(Self::to_session_info)
1451 .collect();
1452 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1453 }
1454
1455 fn supports_delete(&self) -> bool {
1456 true
1457 }
1458
1459 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1460 self.thread_store
1461 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1462 }
1463
1464 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1465 self.thread_store
1466 .update(cx, |store, cx| store.delete_threads(cx))
1467 }
1468
1469 fn watch(
1470 &self,
1471 _cx: &mut App,
1472 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1473 Some(self.updates_rx.clone())
1474 }
1475
1476 fn notify_refresh(&self) {
1477 self.updates_tx
1478 .try_send(acp_thread::SessionListUpdate::Refresh)
1479 .ok();
1480 }
1481
1482 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1483 self
1484 }
1485}
1486
1487struct NativeAgentSessionTruncate {
1488 thread: Entity<Thread>,
1489 acp_thread: WeakEntity<AcpThread>,
1490}
1491
1492impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1493 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1494 match self.thread.update(cx, |thread, cx| {
1495 thread.truncate(message_id.clone(), cx)?;
1496 Ok(thread.latest_token_usage())
1497 }) {
1498 Ok(usage) => {
1499 self.acp_thread
1500 .update(cx, |thread, cx| {
1501 thread.update_token_usage(usage, cx);
1502 })
1503 .ok();
1504 Task::ready(Ok(()))
1505 }
1506 Err(error) => Task::ready(Err(error)),
1507 }
1508 }
1509}
1510
1511struct NativeAgentSessionRetry {
1512 connection: NativeAgentConnection,
1513 session_id: acp::SessionId,
1514}
1515
1516impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1517 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1518 self.connection
1519 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1520 thread.update(cx, |thread, cx| thread.resume(cx))
1521 })
1522 }
1523}
1524
1525struct NativeAgentSessionSetTitle {
1526 connection: NativeAgentConnection,
1527 session_id: acp::SessionId,
1528}
1529
1530impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1531 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1532 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1533 return Task::ready(Err(anyhow!("session not found")));
1534 };
1535 let thread = session.thread.clone();
1536 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1537 Task::ready(Ok(()))
1538 }
1539}
1540
1541pub struct AcpThreadEnvironment {
1542 acp_thread: WeakEntity<AcpThread>,
1543}
1544
1545impl ThreadEnvironment for AcpThreadEnvironment {
1546 fn create_terminal(
1547 &self,
1548 command: String,
1549 cwd: Option<PathBuf>,
1550 output_byte_limit: Option<u64>,
1551 cx: &mut AsyncApp,
1552 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1553 let task = self.acp_thread.update(cx, |thread, cx| {
1554 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1555 });
1556
1557 let acp_thread = self.acp_thread.clone();
1558 cx.spawn(async move |cx| {
1559 let terminal = task?.await?;
1560
1561 let (drop_tx, drop_rx) = oneshot::channel();
1562 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1563
1564 cx.spawn(async move |cx| {
1565 drop_rx.await.ok();
1566 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1567 })
1568 .detach();
1569
1570 let handle = AcpTerminalHandle {
1571 terminal,
1572 _drop_tx: Some(drop_tx),
1573 };
1574
1575 Ok(Rc::new(handle) as _)
1576 })
1577 }
1578}
1579
1580pub struct AcpTerminalHandle {
1581 terminal: Entity<acp_thread::Terminal>,
1582 _drop_tx: Option<oneshot::Sender<()>>,
1583}
1584
1585impl TerminalHandle for AcpTerminalHandle {
1586 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1587 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1588 }
1589
1590 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1591 Ok(self
1592 .terminal
1593 .read_with(cx, |term, _cx| term.wait_for_exit()))
1594 }
1595
1596 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1597 Ok(self
1598 .terminal
1599 .read_with(cx, |term, cx| term.current_output(cx)))
1600 }
1601
1602 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1603 cx.update(|cx| {
1604 self.terminal.update(cx, |terminal, cx| {
1605 terminal.kill(cx);
1606 });
1607 });
1608 Ok(())
1609 }
1610
1611 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1612 Ok(self
1613 .terminal
1614 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1615 }
1616}
1617
// Tests for the native agent and its ACP connection, run against a fake filesystem and the
// test language-model registry.
#[cfg(test)]
mod internal_tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    /// The agent's project context starts empty, picks up a worktree once it is created, and
    /// then picks up a `.rules` file that is added to that worktree afterwards.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // No worktrees yet, so the project context is empty.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        // Let pending updates settle before checking the project context.
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// Listing models through a session's model selector yields a grouped list. With the test
    /// registry, that list holds a single "Fake" group containing the `fake/fake` model.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                thread_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                }]
            )])
        );
    }

    /// Selecting a model through the model selector does two things: it sets the model on the
    /// session's thread, and it overwrites the pre-existing `agent.default_model` entry in the
    /// settings file with the chosen provider and model.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed the settings file with a different default model so the overwrite is observable.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let pending work (such as the settings-file update) run before reading the file.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// Round-trips a thread through the thread store.
    ///
    /// An empty thread must not be persisted, even after its models are changed. After one
    /// exchange, the thread is stored under the title produced by the summarization model.
    /// Dropping the thread handles removes the session. Reopening the session by id restores
    /// the same transcript.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a message that mentions `b.md` via a resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Complete the assistant reply, then the summary that becomes the thread's title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    /// Returns `(session id, title)` for every entry currently in `thread_store`.
    fn thread_entries(
        thread_store: &Entity<ThreadStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(acp::SessionId, String)> {
        thread_store.read_with(cx, |store, _| {
            store
                .entries()
                .map(|entry| (entry.id.clone(), entry.title.to_string()))
                .collect::<Vec<_>>()
        })
    }

    /// Common test setup.
    ///
    /// Initializes logging (ignoring the error if a logger is already installed), then installs
    /// a test settings store and the test language-model registry.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            LanguageModelRegistry::test(cx);
        });
    }
}
1995
1996fn mcp_message_content_to_acp_content_block(
1997 content: context_server::types::MessageContent,
1998) -> acp::ContentBlock {
1999 match content {
2000 context_server::types::MessageContent::Text {
2001 text,
2002 annotations: _,
2003 } => text.into(),
2004 context_server::types::MessageContent::Image {
2005 data,
2006 mime_type,
2007 annotations: _,
2008 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2009 context_server::types::MessageContent::Audio {
2010 data,
2011 mime_type,
2012 annotations: _,
2013 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2014 context_server::types::MessageContent::Resource {
2015 resource,
2016 annotations: _,
2017 } => {
2018 let mut link =
2019 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2020 if let Some(mime_type) = resource.mime_type {
2021 link = link.mime_type(mime_type);
2022 }
2023 acp::ContentBlock::ResourceLink(link)
2024 }
2025 }
2026}