1mod db;
2mod edit_agent;
3mod history_store;
4mod legacy_thread;
5mod native_agent_server;
6pub mod outline;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod tools;
12
13use context_server::ContextServerId;
14pub use db::*;
15pub use history_store::*;
16pub use native_agent_server::NativeAgentServer;
17pub use templates::*;
18pub use thread::*;
19pub use tools::*;
20
21use acp_thread::{AcpThread, AgentModelSelector, UserMessageId};
22use agent_client_protocol as acp;
23use anyhow::{Context as _, Result, anyhow};
24use chrono::{DateTime, Utc};
25use collections::{HashMap, HashSet, IndexMap};
26use fs::Fs;
27use futures::channel::{mpsc, oneshot};
28use futures::future::Shared;
29use futures::{StreamExt, future};
30use gpui::{
31 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
32};
33use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
34use project::{Project, ProjectItem, ProjectPath, Worktree};
35use prompt_store::{
36 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
37 WorktreeContext,
38};
39use serde::{Deserialize, Serialize};
40use settings::{LanguageModelSelection, update_settings_file};
41use std::any::Any;
42use std::path::{Path, PathBuf};
43use std::rc::Rc;
44use std::sync::Arc;
45use util::ResultExt;
46use util::rel_path::RelPath;
47
/// A serializable snapshot of the project's worktrees, built from
/// `project::telemetry_snapshot` data.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// UTC timestamp attached to the snapshot (presumably when it was taken — confirm at the
    /// construction site).
    pub timestamp: DateTime<Utc>,
}
53
/// Describes a failure to load a worktree's rules file.
///
/// Produced by `NativeAgent::load_worktree_info_for_system_prompt`; the
/// project-context builder currently discards it instead of showing it.
pub struct RulesLoadingError {
    /// The load error rendered with its `Display` implementation.
    pub message: SharedString,
}
57
/// Holds both the internal Thread and the AcpThread for a session.
struct Session {
    /// The internal thread that processes messages.
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication. Held weakly; the
    /// session is removed from the agent when this entity is released.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// The most recent persistence task; replaced on every `save_thread` call.
    pending_save: Task<()>,
    /// Subscriptions tied to this session (release, title, token usage, save
    /// observers); kept alive for as long as the session exists.
    _subscriptions: Vec<Subscription>,
}
67
/// Cache of the language models offered by authenticated providers.
pub struct LanguageModels {
    /// Access language model by ID (`"{provider_id}/{model_id}"`).
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached, grouped list for returning language model information.
    model_list: acp_thread::AgentModelList,
    /// Template receiver handed out (cloned) by `watch()`.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled every time the model list is rebuilt.
    refresh_models_tx: watch::Sender<()>,
    /// Background task that authenticates every registered provider.
    _authenticate_all_providers_task: Task<()>,
}
77
78impl LanguageModels {
79 fn new(cx: &mut App) -> Self {
80 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
81
82 let mut this = Self {
83 models: HashMap::default(),
84 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
85 refresh_models_rx,
86 refresh_models_tx,
87 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
88 };
89 this.refresh_list(cx);
90 this
91 }
92
93 fn refresh_list(&mut self, cx: &App) {
94 let providers = LanguageModelRegistry::global(cx)
95 .read(cx)
96 .providers()
97 .into_iter()
98 .filter(|provider| provider.is_authenticated(cx))
99 .collect::<Vec<_>>();
100
101 let mut language_model_list = IndexMap::default();
102 let mut recommended_models = HashSet::default();
103
104 let mut recommended = Vec::new();
105 for provider in &providers {
106 for model in provider.recommended_models(cx) {
107 recommended_models.insert((model.provider_id(), model.id()));
108 recommended.push(Self::map_language_model_to_info(&model, provider));
109 }
110 }
111 if !recommended.is_empty() {
112 language_model_list.insert(
113 acp_thread::AgentModelGroupName("Recommended".into()),
114 recommended,
115 );
116 }
117
118 let mut models = HashMap::default();
119 for provider in providers {
120 let mut provider_models = Vec::new();
121 for model in provider.provided_models(cx) {
122 let model_info = Self::map_language_model_to_info(&model, &provider);
123 let model_id = model_info.id.clone();
124 provider_models.push(model_info);
125 models.insert(model_id, model);
126 }
127 if !provider_models.is_empty() {
128 language_model_list.insert(
129 acp_thread::AgentModelGroupName(provider.name().0.clone()),
130 provider_models,
131 );
132 }
133 }
134
135 self.models = models;
136 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
137 self.refresh_models_tx.send(()).ok();
138 }
139
140 fn watch(&self) -> watch::Receiver<()> {
141 self.refresh_models_rx.clone()
142 }
143
144 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
145 self.models.get(model_id).cloned()
146 }
147
148 fn map_language_model_to_info(
149 model: &Arc<dyn LanguageModel>,
150 provider: &Arc<dyn LanguageModelProvider>,
151 ) -> acp_thread::AgentModelInfo {
152 acp_thread::AgentModelInfo {
153 id: Self::model_id(model),
154 name: model.name().0,
155 description: None,
156 icon: Some(provider.icon()),
157 }
158 }
159
160 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
161 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
162 }
163
164 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
165 let authenticate_all_providers = LanguageModelRegistry::global(cx)
166 .read(cx)
167 .providers()
168 .iter()
169 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
170 .collect::<Vec<_>>();
171
172 cx.background_spawn(async move {
173 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
174 if let Err(err) = authenticate_task.await {
175 match err {
176 language_model::AuthenticateError::CredentialsNotFound => {
177 // Since we're authenticating these providers in the
178 // background for the purposes of populating the
179 // language selector, we don't care about providers
180 // where the credentials are not found.
181 }
182 language_model::AuthenticateError::ConnectionRefused => {
183 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
184 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
185 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
186 }
187 _ => {
188 // Some providers have noisy failure states that we
189 // don't want to spam the logs with every time the
190 // language model selector is initialized.
191 //
192 // Ideally these should have more clear failure modes
193 // that we know are safe to ignore here, like what we do
194 // with `CredentialsNotFound` above.
195 match provider_id.0.as_ref() {
196 "lmstudio" | "ollama" => {
197 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
198 //
199 // These fail noisily, so we don't log them.
200 }
201 "copilot_chat" => {
202 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
203 }
204 _ => {
205 log::error!(
206 "Failed to authenticate provider: {}: {err:#}",
207 provider_name.0
208 );
209 }
210 }
211 }
212 }
213 }
214 }
215 })
216 }
217}
218
/// Zed's built-in agent: owns every active session and the state shared
/// between them (project context, templates, models, context servers).
pub struct NativeAgent {
    /// Session ID -> Session mapping.
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread history store; reloaded after each successful save.
    history: Entity<HistoryStore>,
    /// Shared project context for all threads.
    project_context: Entity<ProjectContext>,
    /// Sending on this channel asks `maintain_project_context` to rebuild
    /// `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context`.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context-server (MCP) tools and prompts.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads.
    templates: Arc<Templates>,
    /// Cached model information.
    models: LanguageModels,
    /// The project the agent operates on.
    project: Entity<Project>,
    /// Optional prompt store, the source of default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle, used to persist model selection to the settings file.
    fs: Arc<dyn Fs>,
    /// Agent-wide subscriptions created in `new`.
    _subscriptions: Vec<Subscription>,
}
237
238impl NativeAgent {
    /// Creates the agent entity.
    ///
    /// Builds the initial project context first (awaiting it), then creates the
    /// entity, subscribing to project, model-registry, context-server and
    /// (optionally) prompt-store events, and starting the background task that
    /// rebuilds the project context on demand.
    ///
    /// Fails if the async app context can no longer be updated.
    pub async fn new(
        project: Entity<Project>,
        history: Entity<HistoryStore>,
        templates: Arc<Templates>,
        prompt_store: Option<Entity<PromptStore>>,
        fs: Arc<dyn Fs>,
        cx: &mut AsyncApp,
    ) -> Result<Entity<NativeAgent>> {
        log::debug!("Creating new NativeAgent");

        let project_context = cx
            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
            .await;

        cx.new(|cx| {
            let context_server_store = project.read(cx).context_server_store();
            let context_server_registry =
                cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));

            let mut subscriptions = vec![
                cx.subscribe(&project, Self::handle_project_event),
                cx.subscribe(
                    &LanguageModelRegistry::global(cx),
                    Self::handle_models_updated_event,
                ),
                cx.subscribe(
                    &context_server_store,
                    Self::handle_context_server_store_updated,
                ),
                cx.subscribe(
                    &context_server_registry,
                    Self::handle_context_server_registry_event,
                ),
            ];
            // Prompt-store changes affect the default user rules in the project context.
            if let Some(prompt_store) = prompt_store.as_ref() {
                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
            }

            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
                watch::channel(());
            Self {
                sessions: HashMap::default(),
                history,
                project_context: cx.new(|_| project_context),
                project_context_needs_refresh: project_context_needs_refresh_tx,
                _maintain_project_context: cx.spawn(async move |this, cx| {
                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
                }),
                context_server_registry,
                templates,
                models: LanguageModels::new(cx),
                project,
                prompt_store,
                fs,
                _subscriptions: subscriptions,
            }
        })
    }
297
    /// Registers `thread_handle` as a new session and returns its `AcpThread`.
    ///
    /// Creates the `AcpThread` (with this agent as its connection), gives the
    /// thread the current summarization model and default tools, and wires up
    /// subscriptions that:
    /// - remove the session when the `AcpThread` is released,
    /// - forward title and token-usage updates to the `AcpThread`,
    /// - persist the thread whenever it notifies.
    ///
    /// Finally, it pushes the current slash-command list to all sessions.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let title = thread.title();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            acp_thread::AcpThread::new(
                title,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            )
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            // Tools reach back into the AcpThread through a weak handle.
            thread.add_default_tools(
                Rc::new(AcpThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            // The session lives exactly as long as its AcpThread.
            cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
                this.sessions.remove(acp_thread.session_id());
            }),
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.downgrade(),
                _subscriptions: subscriptions,
                pending_save: Task::ready(()),
            },
        );

        self.update_available_commands(cx);

        acp_thread
    }
361
    /// Returns the agent's cache of available language models.
    pub fn models(&self) -> &LanguageModels {
        &self.models
    }
365
366 async fn maintain_project_context(
367 this: WeakEntity<Self>,
368 mut needs_refresh: watch::Receiver<()>,
369 cx: &mut AsyncApp,
370 ) -> Result<()> {
371 while needs_refresh.changed().await.is_ok() {
372 let project_context = this
373 .update(cx, |this, cx| {
374 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
375 })?
376 .await;
377 this.update(cx, |this, cx| {
378 this.project_context = cx.new(|_| project_context);
379 })?;
380 }
381
382 Ok(())
383 }
384
385 fn build_project_context(
386 project: &Entity<Project>,
387 prompt_store: Option<&Entity<PromptStore>>,
388 cx: &mut App,
389 ) -> Task<ProjectContext> {
390 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
391 let worktree_tasks = worktrees
392 .into_iter()
393 .map(|worktree| {
394 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
395 })
396 .collect::<Vec<_>>();
397 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
398 prompt_store.read_with(cx, |prompt_store, cx| {
399 let prompts = prompt_store.default_prompt_metadata();
400 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
401 let contents = prompt_store.load(prompt_metadata.id, cx);
402 async move { (contents.await, prompt_metadata) }
403 });
404 cx.background_spawn(future::join_all(load_tasks))
405 })
406 } else {
407 Task::ready(vec![])
408 };
409
410 cx.spawn(async move |_cx| {
411 let (worktrees, default_user_rules) =
412 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
413
414 let worktrees = worktrees
415 .into_iter()
416 .map(|(worktree, _rules_error)| {
417 // TODO: show error message
418 // if let Some(rules_error) = rules_error {
419 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
420 // }
421 worktree
422 })
423 .collect::<Vec<_>>();
424
425 let default_user_rules = default_user_rules
426 .into_iter()
427 .flat_map(|(contents, prompt_metadata)| match contents {
428 Ok(contents) => Some(UserRulesContext {
429 uuid: prompt_metadata.id.user_id()?,
430 title: prompt_metadata.title.map(|title| title.to_string()),
431 contents,
432 }),
433 Err(_err) => {
434 // TODO: show error message
435 // this.update(cx, |_, cx| {
436 // cx.emit(RulesLoadingError {
437 // message: format!("{err:?}").into(),
438 // });
439 // })
440 // .ok();
441 None
442 }
443 })
444 .collect::<Vec<_>>();
445
446 ProjectContext::new(worktrees, default_user_rules)
447 })
448 }
449
450 fn load_worktree_info_for_system_prompt(
451 worktree: Entity<Worktree>,
452 project: Entity<Project>,
453 cx: &mut App,
454 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
455 let tree = worktree.read(cx);
456 let root_name = tree.root_name_str().into();
457 let abs_path = tree.abs_path();
458
459 let mut context = WorktreeContext {
460 root_name,
461 abs_path,
462 rules_file: None,
463 };
464
465 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
466 let Some(rules_task) = rules_task else {
467 return Task::ready((context, None));
468 };
469
470 cx.spawn(async move |_| {
471 let (rules_file, rules_file_error) = match rules_task.await {
472 Ok(rules_file) => (Some(rules_file), None),
473 Err(err) => (
474 None,
475 Some(RulesLoadingError {
476 message: format!("{err}").into(),
477 }),
478 ),
479 };
480 context.rules_file = rules_file;
481 (context, rules_file_error)
482 })
483 }
484
485 fn load_worktree_rules_file(
486 worktree: Entity<Worktree>,
487 project: Entity<Project>,
488 cx: &mut App,
489 ) -> Option<Task<Result<RulesFileContext>>> {
490 let worktree = worktree.read(cx);
491 let worktree_id = worktree.id();
492 let selected_rules_file = RULES_FILE_NAMES
493 .into_iter()
494 .filter_map(|name| {
495 worktree
496 .entry_for_path(RelPath::unix(name).unwrap())
497 .filter(|entry| entry.is_file())
498 .map(|entry| entry.path.clone())
499 })
500 .next();
501
502 // Note that Cline supports `.clinerules` being a directory, but that is not currently
503 // supported. This doesn't seem to occur often in GitHub repositories.
504 selected_rules_file.map(|path_in_worktree| {
505 let project_path = ProjectPath {
506 worktree_id,
507 path: path_in_worktree.clone(),
508 };
509 let buffer_task =
510 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
511 let rope_task = cx.spawn(async move |cx| {
512 buffer_task.await?.read_with(cx, |buffer, cx| {
513 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
514 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
515 })?
516 });
517 // Build a string from the rope on a background thread.
518 cx.background_spawn(async move {
519 let (project_entry_id, rope) = rope_task.await?;
520 anyhow::Ok(RulesFileContext {
521 path_in_worktree,
522 text: rope.to_string().trim().to_string(),
523 project_entry_id: project_entry_id.to_usize(),
524 })
525 })
526 })
527 }
528
529 fn handle_thread_title_updated(
530 &mut self,
531 thread: Entity<Thread>,
532 _: &TitleUpdated,
533 cx: &mut Context<Self>,
534 ) {
535 let session_id = thread.read(cx).id();
536 let Some(session) = self.sessions.get(session_id) else {
537 return;
538 };
539 let thread = thread.downgrade();
540 let acp_thread = session.acp_thread.clone();
541 cx.spawn(async move |_, cx| {
542 let title = thread.read_with(cx, |thread, _| thread.title())?;
543 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
544 task.await
545 })
546 .detach_and_log_err(cx);
547 }
548
549 fn handle_thread_token_usage_updated(
550 &mut self,
551 thread: Entity<Thread>,
552 usage: &TokenUsageUpdated,
553 cx: &mut Context<Self>,
554 ) {
555 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
556 return;
557 };
558 session
559 .acp_thread
560 .update(cx, |acp_thread, cx| {
561 acp_thread.update_token_usage(usage.0.clone(), cx);
562 })
563 .ok();
564 }
565
566 fn handle_project_event(
567 &mut self,
568 _project: Entity<Project>,
569 event: &project::Event,
570 _cx: &mut Context<Self>,
571 ) {
572 match event {
573 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
574 self.project_context_needs_refresh.send(()).ok();
575 }
576 project::Event::WorktreeUpdatedEntries(_, items) => {
577 if items.iter().any(|(path, _, _)| {
578 RULES_FILE_NAMES
579 .iter()
580 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
581 }) {
582 self.project_context_needs_refresh.send(()).ok();
583 }
584 }
585 _ => {}
586 }
587 }
588
    /// Requests a project-context rebuild whenever the prompt store changes,
    /// since the default user rules in the context are loaded from it.
    fn handle_prompts_updated_event(
        &mut self,
        _prompt_store: Entity<PromptStore>,
        _event: &prompt_store::PromptsUpdatedEvent,
        _cx: &mut Context<Self>,
    ) {
        self.project_context_needs_refresh.send(()).ok();
    }
597
598 fn handle_models_updated_event(
599 &mut self,
600 _registry: Entity<LanguageModelRegistry>,
601 _event: &language_model::Event,
602 cx: &mut Context<Self>,
603 ) {
604 self.models.refresh_list(cx);
605
606 let registry = LanguageModelRegistry::read_global(cx);
607 let default_model = registry.default_model().map(|m| m.model);
608 let summarization_model = registry.thread_summary_model().map(|m| m.model);
609
610 for session in self.sessions.values_mut() {
611 session.thread.update(cx, |thread, cx| {
612 if thread.model().is_none()
613 && let Some(model) = default_model.clone()
614 {
615 thread.set_model(model, cx);
616 cx.notify();
617 }
618 thread.set_summarization_model(summarization_model.clone(), cx);
619 });
620 }
621 }
622
    /// Re-publishes the available slash commands on any context-server store event.
    fn handle_context_server_store_updated(
        &mut self,
        _store: Entity<project::context_server_store::ContextServerStore>,
        _event: &project::context_server_store::Event,
        cx: &mut Context<Self>,
    ) {
        self.update_available_commands(cx);
    }
631
632 fn handle_context_server_registry_event(
633 &mut self,
634 _registry: Entity<ContextServerRegistry>,
635 event: &ContextServerRegistryEvent,
636 cx: &mut Context<Self>,
637 ) {
638 match event {
639 ContextServerRegistryEvent::ToolsChanged => {}
640 ContextServerRegistryEvent::PromptsChanged => {
641 self.update_available_commands(cx);
642 }
643 }
644 }
645
646 fn update_available_commands(&self, cx: &mut Context<Self>) {
647 let available_commands = self.build_available_commands(cx);
648 for session in self.sessions.values() {
649 if let Some(acp_thread) = session.acp_thread.upgrade() {
650 acp_thread.update(cx, |thread, cx| {
651 thread
652 .handle_session_update(
653 acp::SessionUpdate::AvailableCommandsUpdate(
654 acp::AvailableCommandsUpdate::new(available_commands.clone()),
655 ),
656 cx,
657 )
658 .log_err();
659 });
660 }
661 }
662 }
663
664 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
665 let registry = self.context_server_registry.read(cx);
666
667 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
668 for context_server_prompt in registry.prompts() {
669 *prompt_name_counts
670 .entry(context_server_prompt.prompt.name.as_str())
671 .or_insert(0) += 1;
672 }
673
674 registry
675 .prompts()
676 .flat_map(|context_server_prompt| {
677 let prompt = &context_server_prompt.prompt;
678
679 let should_prefix = prompt_name_counts
680 .get(prompt.name.as_str())
681 .copied()
682 .unwrap_or(0)
683 > 1;
684
685 let name = if should_prefix {
686 format!("{}.{}", context_server_prompt.server_id, prompt.name)
687 } else {
688 prompt.name.clone()
689 };
690
691 let mut command = acp::AvailableCommand::new(
692 name,
693 prompt.description.clone().unwrap_or_default(),
694 );
695
696 match prompt.arguments.as_deref() {
697 Some([arg]) => {
698 let hint = format!("<{}>", arg.name);
699
700 command = command.input(acp::AvailableCommandInput::Unstructured(
701 acp::UnstructuredCommandInput::new(hint),
702 ));
703 }
704 Some([]) | None => {}
705 Some(_) => {
706 // skip >1 argument commands since we don't support them yet
707 return None;
708 }
709 }
710
711 Some(command)
712 })
713 .collect()
714 }
715
    /// Loads the persisted thread `id` from the threads database and creates a
    /// [`Thread`] entity for it.
    ///
    /// The thread uses the agent's shared project context, context-server
    /// registry and templates, plus the registry's current summarization model.
    /// It is NOT registered as a session; see `open_thread` for that.
    ///
    /// Fails if the database cannot be opened, the thread does not exist, or
    /// the agent entity is gone.
    pub fn load_thread(
        &mut self,
        id: acp::SessionId,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Thread>>> {
        let database_future = ThreadsDatabase::connect(cx);
        cx.spawn(async move |this, cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            let db_thread = database
                .load_thread(id.clone())
                .await?
                .with_context(|| format!("no thread found with ID: {id:?}"))?;

            this.update(cx, |this, cx| {
                let summarization_model = LanguageModelRegistry::read_global(cx)
                    .thread_summary_model()
                    .map(|c| c.model);

                cx.new(|cx| {
                    let mut thread = Thread::from_db(
                        id.clone(),
                        db_thread,
                        this.project.clone(),
                        this.project_context.clone(),
                        this.context_server_registry.clone(),
                        this.templates.clone(),
                        cx,
                    );
                    thread.set_summarization_model(summarization_model, cx);
                    thread
                })
            })
        })
    }
750
    /// Loads the persisted thread `id`, registers it as a session, and replays
    /// its history into the new `AcpThread` before returning that thread.
    pub fn open_thread(
        &mut self,
        id: acp::SessionId,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<AcpThread>>> {
        let task = self.load_thread(id, cx);
        cx.spawn(async move |this, cx| {
            let thread = task.await?;
            let acp_thread =
                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
            // Replay stored events so the AcpThread shows the existing conversation.
            let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
            })?
            .await?;
            Ok(acp_thread)
        })
    }
769
770 pub fn thread_summary(
771 &mut self,
772 id: acp::SessionId,
773 cx: &mut Context<Self>,
774 ) -> Task<Result<SharedString>> {
775 let thread = self.open_thread(id.clone(), cx);
776 cx.spawn(async move |this, cx| {
777 let acp_thread = thread.await?;
778 let result = this
779 .update(cx, |this, cx| {
780 this.sessions
781 .get(&id)
782 .unwrap()
783 .thread
784 .update(cx, |thread, cx| thread.summary(cx))
785 })?
786 .await
787 .context("Failed to generate summary")?;
788 drop(acp_thread);
789 Ok(result)
790 })
791 }
792
    /// Persists `thread` to the threads database and then reloads the history store.
    ///
    /// Empty threads and threads without a registered session are skipped. The
    /// work runs in a task stored as the session's `pending_save`, which
    /// replaces any previous one. Database errors are logged, not returned.
    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
        if thread.read(cx).is_empty() {
            return;
        }

        let database_future = ThreadsDatabase::connect(cx);
        let (id, db_thread) =
            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
        let Some(session) = self.sessions.get_mut(&id) else {
            return;
        };
        let history = self.history.clone();
        session.pending_save = cx.spawn(async move |_, cx| {
            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
                return;
            };
            let db_thread = db_thread.await;
            database.save_thread(id, db_thread).await.log_err();
            history.update(cx, |history, cx| history.reload(cx)).ok();
        });
    }
814
    /// Runs an MCP prompt (slash command) in the session `session_id`.
    ///
    /// Fetches the prompt `prompt_name` from `server_id` with `arguments`, then:
    /// 1. adds `original_content` to the thread as user message `message_id`,
    ///    skipping its first block (presumably the `/command` text itself),
    /// 2. mirrors each prompt message into both the `AcpThread` (indented) and
    ///    the internal thread, by role,
    /// 3. sends the thread if the last message was from the user (or there were
    ///    none), otherwise resumes it, and forwards the resulting events.
    ///
    /// Fails if the prompt cannot be fetched, the session is missing, or an
    /// entity can no longer be updated.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: agent_client_protocol::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let server_store = self.context_server_registry.read(cx).server_store().clone();
        let path_style = self.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // With no prompt messages this stays `true`, so the thread is sent as-is.
            let mut last_is_user = true;

            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            })?;

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        // Each user message from the prompt gets its own id, shared
                        // by the AcpThread and the internal thread.
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                            anyhow::Ok(())
                        })??;

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                            anyhow::Ok(())
                        })??;
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                            anyhow::Ok(())
                        })??;

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                            anyhow::Ok(())
                        })??;
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })??;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(response_stream, acp_thread, cx)
            })?
            .await
        })
    }
910}
911
/// Wrapper struct that implements the `AgentConnection` trait for a
/// [`NativeAgent`] entity; one is handed to each `AcpThread` created by
/// `NativeAgent::register_session`.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
915
916impl NativeAgentConnection {
917 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
918 self.0
919 .read(cx)
920 .sessions
921 .get(session_id)
922 .map(|session| session.thread.clone())
923 }
924
925 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
926 self.0.update(cx, |this, cx| this.load_thread(id, cx))
927 }
928
929 fn run_turn(
930 &self,
931 session_id: acp::SessionId,
932 cx: &mut App,
933 f: impl 'static
934 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
935 ) -> Task<Result<acp::PromptResponse>> {
936 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
937 agent
938 .sessions
939 .get_mut(&session_id)
940 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
941 }) else {
942 return Task::ready(Err(anyhow!("Session not found")));
943 };
944 log::debug!("Found session for: {}", session_id);
945
946 let response_stream = match f(thread, cx) {
947 Ok(stream) => stream,
948 Err(err) => return Task::ready(Err(err)),
949 };
950 Self::handle_thread_events(response_stream, acp_thread, cx)
951 }
952
953 fn handle_thread_events(
954 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
955 acp_thread: WeakEntity<AcpThread>,
956 cx: &App,
957 ) -> Task<Result<acp::PromptResponse>> {
958 cx.spawn(async move |cx| {
959 // Handle response stream and forward to session.acp_thread
960 while let Some(result) = events.next().await {
961 match result {
962 Ok(event) => {
963 log::trace!("Received completion event: {:?}", event);
964
965 match event {
966 ThreadEvent::UserMessage(message) => {
967 acp_thread.update(cx, |thread, cx| {
968 for content in message.content {
969 thread.push_user_content_block(
970 Some(message.id.clone()),
971 content.into(),
972 cx,
973 );
974 }
975 })?;
976 }
977 ThreadEvent::AgentText(text) => {
978 acp_thread.update(cx, |thread, cx| {
979 thread.push_assistant_content_block(text.into(), false, cx)
980 })?;
981 }
982 ThreadEvent::AgentThinking(text) => {
983 acp_thread.update(cx, |thread, cx| {
984 thread.push_assistant_content_block(text.into(), true, cx)
985 })?;
986 }
987 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
988 tool_call,
989 options,
990 response,
991 }) => {
992 let outcome_task = acp_thread.update(cx, |thread, cx| {
993 thread.request_tool_call_authorization(
994 tool_call, options, true, cx,
995 )
996 })??;
997 cx.background_spawn(async move {
998 if let acp::RequestPermissionOutcome::Selected(
999 acp::SelectedPermissionOutcome { option_id, .. },
1000 ) = outcome_task.await
1001 {
1002 response
1003 .send(option_id)
1004 .map(|_| anyhow!("authorization receiver was dropped"))
1005 .log_err();
1006 }
1007 })
1008 .detach();
1009 }
1010 ThreadEvent::ToolCall(tool_call) => {
1011 acp_thread.update(cx, |thread, cx| {
1012 thread.upsert_tool_call(tool_call, cx)
1013 })??;
1014 }
1015 ThreadEvent::ToolCallUpdate(update) => {
1016 acp_thread.update(cx, |thread, cx| {
1017 thread.update_tool_call(update, cx)
1018 })??;
1019 }
1020 ThreadEvent::Retry(status) => {
1021 acp_thread.update(cx, |thread, cx| {
1022 thread.update_retry_status(status, cx)
1023 })?;
1024 }
1025 ThreadEvent::Stop(stop_reason) => {
1026 log::debug!("Assistant message complete: {:?}", stop_reason);
1027 return Ok(acp::PromptResponse::new(stop_reason));
1028 }
1029 }
1030 }
1031 Err(e) => {
1032 log::error!("Error in model response stream: {:?}", e);
1033 return Err(e);
1034 }
1035 }
1036 }
1037
1038 log::debug!("Response stream completed");
1039 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1040 })
1041 }
1042}
1043
/// A slash command parsed from the start of a user prompt, such as
/// `/prompt argument` or `/server.prompt argument`.
struct Command<'a> {
    /// Prompt name, without the leading `/` or any `server.` prefix.
    prompt_name: &'a str,
    /// Everything after the first whitespace character following the command
    /// (empty when there is none).
    arg_value: &'a str,
    /// Server id given explicitly via `server.prompt` syntax, if any.
    explicit_server_id: Option<&'a str>,
}
1049
1050impl<'a> Command<'a> {
1051 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1052 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1053 return None;
1054 };
1055 let text = text_content.text.trim();
1056 let command = text.strip_prefix('/')?;
1057 let (command, arg_value) = command
1058 .split_once(char::is_whitespace)
1059 .unwrap_or((command, ""));
1060
1061 if let Some((server_id, prompt_name)) = command.split_once('.') {
1062 Some(Self {
1063 prompt_name,
1064 arg_value,
1065 explicit_server_id: Some(server_id),
1066 })
1067 } else {
1068 Some(Self {
1069 prompt_name: command,
1070 arg_value,
1071 explicit_server_id: None,
1072 })
1073 }
1074 }
1075}
1076
/// Model selector bound to a single session of a native agent connection.
struct NativeAgentModelSelector {
    /// The session whose thread's model is listed/selected.
    session_id: acp::SessionId,
    /// Connection to the agent that owns the session.
    connection: NativeAgentConnection,
}
1081
1082impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1083 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1084 log::debug!("NativeAgentConnection::list_models called");
1085 let list = self.connection.0.read(cx).models.model_list.clone();
1086 Task::ready(if list.is_empty() {
1087 Err(anyhow::anyhow!("No models available"))
1088 } else {
1089 Ok(list)
1090 })
1091 }
1092
1093 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1094 log::debug!(
1095 "Setting model for session {}: {}",
1096 self.session_id,
1097 model_id
1098 );
1099 let Some(thread) = self
1100 .connection
1101 .0
1102 .read(cx)
1103 .sessions
1104 .get(&self.session_id)
1105 .map(|session| session.thread.clone())
1106 else {
1107 return Task::ready(Err(anyhow!("Session not found")));
1108 };
1109
1110 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1111 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1112 };
1113
1114 thread.update(cx, |thread, cx| {
1115 thread.set_model(model.clone(), cx);
1116 });
1117
1118 update_settings_file(
1119 self.connection.0.read(cx).fs.clone(),
1120 cx,
1121 move |settings, _cx| {
1122 let provider = model.provider_id().0.to_string();
1123 let model = model.id().0.to_string();
1124 settings
1125 .agent
1126 .get_or_insert_default()
1127 .set_model(LanguageModelSelection {
1128 provider: provider.into(),
1129 model,
1130 });
1131 },
1132 );
1133
1134 Task::ready(Ok(()))
1135 }
1136
1137 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1138 let Some(thread) = self
1139 .connection
1140 .0
1141 .read(cx)
1142 .sessions
1143 .get(&self.session_id)
1144 .map(|session| session.thread.clone())
1145 else {
1146 return Task::ready(Err(anyhow!("Session not found")));
1147 };
1148 let Some(model) = thread.read(cx).model() else {
1149 return Task::ready(Err(anyhow!("Model not found")));
1150 };
1151 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1152 else {
1153 return Task::ready(Err(anyhow!("Provider not found")));
1154 };
1155 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1156 model, &provider,
1157 )))
1158 }
1159
1160 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1161 Some(self.connection.0.read(cx).models.watch())
1162 }
1163
1164 fn should_render_footer(&self) -> bool {
1165 true
1166 }
1167
1168 fn supports_favorites(&self) -> bool {
1169 true
1170 }
1171}
1172
1173impl acp_thread::AgentConnection for NativeAgentConnection {
1174 fn telemetry_id(&self) -> SharedString {
1175 "zed".into()
1176 }
1177
1178 fn new_thread(
1179 self: Rc<Self>,
1180 project: Entity<Project>,
1181 cwd: &Path,
1182 cx: &mut App,
1183 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1184 let agent = self.0.clone();
1185 log::debug!("Creating new thread for project at: {:?}", cwd);
1186
1187 cx.spawn(async move |cx| {
1188 log::debug!("Starting thread creation in async context");
1189
1190 // Create Thread
1191 let thread = agent.update(
1192 cx,
1193 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
1194 // Fetch default model from registry settings
1195 let registry = LanguageModelRegistry::read_global(cx);
1196 // Log available models for debugging
1197 let available_count = registry.available_models(cx).count();
1198 log::debug!("Total available models: {}", available_count);
1199
1200 let default_model = registry.default_model().and_then(|default_model| {
1201 agent
1202 .models
1203 .model_from_id(&LanguageModels::model_id(&default_model.model))
1204 });
1205 Ok(cx.new(|cx| {
1206 Thread::new(
1207 project.clone(),
1208 agent.project_context.clone(),
1209 agent.context_server_registry.clone(),
1210 agent.templates.clone(),
1211 default_model,
1212 cx,
1213 )
1214 }))
1215 },
1216 )??;
1217 agent.update(cx, |agent, cx| agent.register_session(thread, cx))
1218 })
1219 }
1220
1221 fn auth_methods(&self) -> &[acp::AuthMethod] {
1222 &[] // No auth for in-process
1223 }
1224
1225 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1226 Task::ready(Ok(()))
1227 }
1228
1229 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1230 Some(Rc::new(NativeAgentModelSelector {
1231 session_id: session_id.clone(),
1232 connection: self.clone(),
1233 }) as Rc<dyn AgentModelSelector>)
1234 }
1235
1236 fn prompt(
1237 &self,
1238 id: Option<acp_thread::UserMessageId>,
1239 params: acp::PromptRequest,
1240 cx: &mut App,
1241 ) -> Task<Result<acp::PromptResponse>> {
1242 let id = id.expect("UserMessageId is required");
1243 let session_id = params.session_id.clone();
1244 log::info!("Received prompt request for session: {}", session_id);
1245 log::debug!("Prompt blocks count: {}", params.prompt.len());
1246
1247 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1248 let registry = self.0.read(cx).context_server_registry.read(cx);
1249
1250 let explicit_server_id = parsed_command
1251 .explicit_server_id
1252 .map(|server_id| ContextServerId(server_id.into()));
1253
1254 if let Some(prompt) =
1255 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1256 {
1257 let arguments = if !parsed_command.arg_value.is_empty()
1258 && let Some(arg_name) = prompt
1259 .prompt
1260 .arguments
1261 .as_ref()
1262 .and_then(|args| args.first())
1263 .map(|arg| arg.name.clone())
1264 {
1265 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1266 } else {
1267 Default::default()
1268 };
1269
1270 let prompt_name = prompt.prompt.name.clone();
1271 let server_id = prompt.server_id.clone();
1272
1273 return self.0.update(cx, |agent, cx| {
1274 agent.send_mcp_prompt(
1275 id,
1276 session_id.clone(),
1277 prompt_name,
1278 server_id,
1279 arguments,
1280 params.prompt,
1281 cx,
1282 )
1283 });
1284 };
1285 };
1286
1287 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1288
1289 self.run_turn(session_id, cx, move |thread, cx| {
1290 let content: Vec<UserMessageContent> = params
1291 .prompt
1292 .into_iter()
1293 .map(|block| UserMessageContent::from_content_block(block, path_style))
1294 .collect::<Vec<_>>();
1295 log::debug!("Converted prompt to message: {} chars", content.len());
1296 log::debug!("Message id: {:?}", id);
1297 log::debug!("Message content: {:?}", content);
1298
1299 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1300 })
1301 }
1302
1303 fn resume(
1304 &self,
1305 session_id: &acp::SessionId,
1306 _cx: &App,
1307 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1308 Some(Rc::new(NativeAgentSessionResume {
1309 connection: self.clone(),
1310 session_id: session_id.clone(),
1311 }) as _)
1312 }
1313
1314 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1315 log::info!("Cancelling on session: {}", session_id);
1316 self.0.update(cx, |agent, cx| {
1317 if let Some(agent) = agent.sessions.get(session_id) {
1318 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1319 }
1320 });
1321 }
1322
1323 fn truncate(
1324 &self,
1325 session_id: &agent_client_protocol::SessionId,
1326 cx: &App,
1327 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1328 self.0.read_with(cx, |agent, _cx| {
1329 agent.sessions.get(session_id).map(|session| {
1330 Rc::new(NativeAgentSessionTruncate {
1331 thread: session.thread.clone(),
1332 acp_thread: session.acp_thread.clone(),
1333 }) as _
1334 })
1335 })
1336 }
1337
1338 fn set_title(
1339 &self,
1340 session_id: &acp::SessionId,
1341 _cx: &App,
1342 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1343 Some(Rc::new(NativeAgentSessionSetTitle {
1344 connection: self.clone(),
1345 session_id: session_id.clone(),
1346 }) as _)
1347 }
1348
1349 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1350 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1351 }
1352
1353 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1354 self
1355 }
1356}
1357
1358impl acp_thread::AgentTelemetry for NativeAgentConnection {
1359 fn thread_data(
1360 &self,
1361 session_id: &acp::SessionId,
1362 cx: &mut App,
1363 ) -> Task<Result<serde_json::Value>> {
1364 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1365 return Task::ready(Err(anyhow!("Session not found")));
1366 };
1367
1368 let task = session.thread.read(cx).to_db(cx);
1369 cx.background_spawn(async move {
1370 serde_json::to_value(task.await).context("Failed to serialize thread")
1371 })
1372 }
1373}
1374
1375struct NativeAgentSessionTruncate {
1376 thread: Entity<Thread>,
1377 acp_thread: WeakEntity<AcpThread>,
1378}
1379
1380impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1381 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1382 match self.thread.update(cx, |thread, cx| {
1383 thread.truncate(message_id.clone(), cx)?;
1384 Ok(thread.latest_token_usage())
1385 }) {
1386 Ok(usage) => {
1387 self.acp_thread
1388 .update(cx, |thread, cx| {
1389 thread.update_token_usage(usage, cx);
1390 })
1391 .ok();
1392 Task::ready(Ok(()))
1393 }
1394 Err(error) => Task::ready(Err(error)),
1395 }
1396 }
1397}
1398
1399struct NativeAgentSessionResume {
1400 connection: NativeAgentConnection,
1401 session_id: acp::SessionId,
1402}
1403
1404impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1405 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1406 self.connection
1407 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1408 thread.update(cx, |thread, cx| thread.resume(cx))
1409 })
1410 }
1411}
1412
1413struct NativeAgentSessionSetTitle {
1414 connection: NativeAgentConnection,
1415 session_id: acp::SessionId,
1416}
1417
1418impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1419 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1420 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1421 return Task::ready(Err(anyhow!("session not found")));
1422 };
1423 let thread = session.thread.clone();
1424 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1425 Task::ready(Ok(()))
1426 }
1427}
1428
1429pub struct AcpThreadEnvironment {
1430 acp_thread: WeakEntity<AcpThread>,
1431}
1432
1433impl ThreadEnvironment for AcpThreadEnvironment {
1434 fn create_terminal(
1435 &self,
1436 command: String,
1437 cwd: Option<PathBuf>,
1438 output_byte_limit: Option<u64>,
1439 cx: &mut AsyncApp,
1440 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1441 let task = self.acp_thread.update(cx, |thread, cx| {
1442 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1443 });
1444
1445 let acp_thread = self.acp_thread.clone();
1446 cx.spawn(async move |cx| {
1447 let terminal = task?.await?;
1448
1449 let (drop_tx, drop_rx) = oneshot::channel();
1450 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1451
1452 cx.spawn(async move |cx| {
1453 drop_rx.await.ok();
1454 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1455 })
1456 .detach();
1457
1458 let handle = AcpTerminalHandle {
1459 terminal,
1460 _drop_tx: Some(drop_tx),
1461 };
1462
1463 Ok(Rc::new(handle) as _)
1464 })
1465 }
1466}
1467
1468pub struct AcpTerminalHandle {
1469 terminal: Entity<acp_thread::Terminal>,
1470 _drop_tx: Option<oneshot::Sender<()>>,
1471}
1472
1473impl TerminalHandle for AcpTerminalHandle {
1474 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1475 self.terminal.read_with(cx, |term, _cx| term.id().clone())
1476 }
1477
1478 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1479 self.terminal
1480 .read_with(cx, |term, _cx| term.wait_for_exit())
1481 }
1482
1483 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1484 self.terminal
1485 .read_with(cx, |term, cx| term.current_output(cx))
1486 }
1487
1488 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1489 cx.update(|cx| {
1490 self.terminal.update(cx, |terminal, cx| {
1491 terminal.kill(cx);
1492 });
1493 })?;
1494 Ok(())
1495 }
1496}
1497
/// Tests for `NativeAgent` and `NativeAgentConnection`: project context
/// tracking, model listing and selection, and thread persistence.
#[cfg(test)]
mod internal_tests {
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    /// The agent's project context starts with no worktrees, picks up a newly
    /// created worktree, and records a `.rules` file once one appears in it.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // No worktrees yet, so the context is empty.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        // Adding `/a` adds a matching worktree entry without a rules file.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// `list_models` on a session's model selector returns a single "Fake"
    /// group containing the `fake/fake` model.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Selecting a model updates the session's thread and overwrites the
    /// `agent.default_model` entry in the settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed the settings file with a different default model.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings write complete before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// Empty threads are not written to history. A thread with one exchange
    /// is saved under its generated summary title, its session is removed
    /// once the entities are dropped, and it can be reopened with the same
    /// markdown content.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        // Send a message that mentions `/a/b.md` via a resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Drive the fake models: first the reply, then the summary used as the title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    /// Returns `(id, title)` for every entry currently in `history`.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    /// Shared setup: best-effort logger init, a test `SettingsStore` global,
    /// and the test `LanguageModelRegistry` (presumably the source of the
    /// "Fake" provider the tests expect — confirm in `language_model`).
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            LanguageModelRegistry::test(cx);
        });
    }
}
1883
1884fn mcp_message_content_to_acp_content_block(
1885 content: context_server::types::MessageContent,
1886) -> acp::ContentBlock {
1887 match content {
1888 context_server::types::MessageContent::Text {
1889 text,
1890 annotations: _,
1891 } => text.into(),
1892 context_server::types::MessageContent::Image {
1893 data,
1894 mime_type,
1895 annotations: _,
1896 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
1897 context_server::types::MessageContent::Audio {
1898 data,
1899 mime_type,
1900 annotations: _,
1901 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
1902 context_server::types::MessageContent::Resource {
1903 resource,
1904 annotations: _,
1905 } => {
1906 let mut link =
1907 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
1908 if let Some(mime_type) = resource.mime_type {
1909 link = link.mime_type(mime_type);
1910 }
1911 acp::ContentBlock::ResourceLink(link)
1912 }
1913 }
1914}