1mod db;
2mod edit_agent;
3mod history_store;
4mod legacy_thread;
5mod native_agent_server;
6pub mod outline;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod tools;
12
13use context_server::ContextServerId;
14pub use db::*;
15pub use history_store::*;
16pub use native_agent_server::NativeAgentServer;
17pub use templates::*;
18pub use thread::*;
19pub use tools::*;
20
21use acp_thread::{AcpThread, AgentModelSelector, UserMessageId};
22use agent_client_protocol as acp;
23use anyhow::{Context as _, Result, anyhow};
24use chrono::{DateTime, Utc};
25use collections::{HashMap, HashSet, IndexMap};
26use fs::Fs;
27use futures::channel::{mpsc, oneshot};
28use futures::future::Shared;
29use futures::{StreamExt, future};
30use gpui::{
31 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
32};
33use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
34use project::{Project, ProjectItem, ProjectPath, Worktree};
35use prompt_store::{
36 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
37 WorktreeContext,
38};
39use serde::{Deserialize, Serialize};
40use settings::{LanguageModelSelection, update_settings_file};
41use std::any::Any;
42use std::path::{Path, PathBuf};
43use std::rc::Rc;
44use std::sync::Arc;
45use util::ResultExt;
46use util::rel_path::RelPath;
47
/// Snapshot of the project's worktrees together with an associated UTC timestamp.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Timestamp associated with this snapshot, in UTC.
    pub timestamp: DateTime<Utc>,
}
53
/// Error describing why a worktree rules file could not be loaded for the system prompt.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the `Display` form of the underlying error).
    pub message: SharedString,
}
57
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication. Held weakly: the agent removes
    /// the session from its map when this entity is released.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Most recent database save of `thread`. Each new save replaces (drops) the previous
    /// task — NOTE(review): in gpui dropping a `Task` cancels it; confirm that is intended.
    pending_save: Task<()>,
    /// Subscriptions that live exactly as long as the session.
    _subscriptions: Vec<Subscription>,
}
67
/// Cache of the language models offered by authenticated, visible providers.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver half handed out by `watch()`; signalled each time the list is rebuilt.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender half; `refresh_list` sends on it after rebuilding the cache.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating every visible provider, kept alive with this struct.
    _authenticate_all_providers_task: Task<()>,
}
77
78impl LanguageModels {
79 fn new(cx: &mut App) -> Self {
80 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
81
82 let mut this = Self {
83 models: HashMap::default(),
84 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
85 refresh_models_rx,
86 refresh_models_tx,
87 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
88 };
89 this.refresh_list(cx);
90 this
91 }
92
    /// Rebuilds the model cache from the currently authenticated, visible providers.
    ///
    /// The resulting grouped list starts with a "Recommended" group (only if any provider
    /// recommends models), followed by one group per provider that provides at least one
    /// model. `self.models` is replaced with an id -> model map of every provided model, and
    /// watchers obtained via `watch()` are signalled.
    fn refresh_list(&mut self, cx: &App) {
        let providers = LanguageModelRegistry::global(cx)
            .read(cx)
            .visible_providers()
            .into_iter()
            .filter(|provider| provider.is_authenticated(cx))
            .collect::<Vec<_>>();

        let mut language_model_list = IndexMap::default();
        // NOTE(review): this set is filled below but never read in this function; it looks
        // like a leftover. Confirm before removing (it is the only visible use of `HashSet`).
        let mut recommended_models = HashSet::default();

        let mut recommended = Vec::new();
        for provider in &providers {
            for model in provider.recommended_models(cx) {
                recommended_models.insert((model.provider_id(), model.id()));
                recommended.push(Self::map_language_model_to_info(&model, provider));
            }
        }
        if !recommended.is_empty() {
            language_model_list.insert(
                acp_thread::AgentModelGroupName("Recommended".into()),
                recommended,
            );
        }

        // Only models from `provided_models` are selectable by id; recommended entries are
        // display-only duplicates of these.
        let mut models = HashMap::default();
        for provider in providers {
            let mut provider_models = Vec::new();
            for model in provider.provided_models(cx) {
                let model_info = Self::map_language_model_to_info(&model, &provider);
                let model_id = model_info.id.clone();
                provider_models.push(model_info);
                models.insert(model_id, model);
            }
            if !provider_models.is_empty() {
                language_model_list.insert(
                    acp_thread::AgentModelGroupName(provider.name().0.clone()),
                    provider_models,
                );
            }
        }

        self.models = models;
        self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
        // Notify watchers; a send error (no receivers) is deliberately ignored.
        self.refresh_models_tx.send(()).ok();
    }
139
140 fn watch(&self) -> watch::Receiver<()> {
141 self.refresh_models_rx.clone()
142 }
143
144 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
145 self.models.get(model_id).cloned()
146 }
147
148 fn map_language_model_to_info(
149 model: &Arc<dyn LanguageModel>,
150 provider: &Arc<dyn LanguageModelProvider>,
151 ) -> acp_thread::AgentModelInfo {
152 acp_thread::AgentModelInfo {
153 id: Self::model_id(model),
154 name: model.name().0,
155 description: None,
156 icon: Some(match provider.icon() {
157 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
158 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
159 }),
160 }
161 }
162
163 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
164 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
165 }
166
167 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
168 let authenticate_all_providers = LanguageModelRegistry::global(cx)
169 .read(cx)
170 .visible_providers()
171 .iter()
172 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
173 .collect::<Vec<_>>();
174
175 cx.background_spawn(async move {
176 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
177 if let Err(err) = authenticate_task.await {
178 match err {
179 language_model::AuthenticateError::CredentialsNotFound => {
180 // Since we're authenticating these providers in the
181 // background for the purposes of populating the
182 // language selector, we don't care about providers
183 // where the credentials are not found.
184 }
185 language_model::AuthenticateError::ConnectionRefused => {
186 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
187 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
188 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
189 }
190 _ => {
191 // Some providers have noisy failure states that we
192 // don't want to spam the logs with every time the
193 // language model selector is initialized.
194 //
195 // Ideally these should have more clear failure modes
196 // that we know are safe to ignore here, like what we do
197 // with `CredentialsNotFound` above.
198 match provider_id.0.as_ref() {
199 "lmstudio" | "ollama" => {
200 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
201 //
202 // These fail noisily, so we don't log them.
203 }
204 "copilot_chat" => {
205 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
206 }
207 _ => {
208 log::error!(
209 "Failed to authenticate provider: {}: {err:#}",
210 provider_name.0
211 );
212 }
213 }
214 }
215 }
216 }
217 }
218 })
219 }
220}
221
/// The built-in agent: owns all live sessions, the shared project context, and model cache.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread history; reloaded after each thread save.
    history: Entity<HistoryStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalling this asks `maintain_project_context` to rebuild `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task running `maintain_project_context`.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context-server tools/prompts shared by all threads.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store providing default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle, used e.g. to persist the selected model into settings.
    fs: Arc<dyn Fs>,
    /// Agent-wide subscriptions (project, model registry, context servers, prompt store).
    _subscriptions: Vec<Subscription>,
}
240
241impl NativeAgent {
    /// Creates the agent entity.
    ///
    /// Builds the initial project context first, then creates the entity, its context-server
    /// registry, the subscriptions that keep the agent in sync (project, language-model
    /// registry, context-server store and registry, and the prompt store when present), and
    /// the background task that rebuilds the project context on demand.
    pub async fn new(
        project: Entity<Project>,
        history: Entity<HistoryStore>,
        templates: Arc<Templates>,
        prompt_store: Option<Entity<PromptStore>>,
        fs: Arc<dyn Fs>,
        cx: &mut AsyncApp,
    ) -> Result<Entity<NativeAgent>> {
        log::debug!("Creating new NativeAgent");

        let project_context = cx
            .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
            .await;

        Ok(cx.new(|cx| {
            let context_server_store = project.read(cx).context_server_store();
            let context_server_registry =
                cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));

            let mut subscriptions = vec![
                cx.subscribe(&project, Self::handle_project_event),
                cx.subscribe(
                    &LanguageModelRegistry::global(cx),
                    Self::handle_models_updated_event,
                ),
                cx.subscribe(
                    &context_server_store,
                    Self::handle_context_server_store_updated,
                ),
                cx.subscribe(
                    &context_server_registry,
                    Self::handle_context_server_registry_event,
                ),
            ];
            if let Some(prompt_store) = prompt_store.as_ref() {
                subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
            }

            // The sender stays on the agent; the receiver drives the refresh loop below.
            let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
                watch::channel(());
            Self {
                sessions: HashMap::default(),
                history,
                project_context: cx.new(|_| project_context),
                project_context_needs_refresh: project_context_needs_refresh_tx,
                _maintain_project_context: cx.spawn(async move |this, cx| {
                    Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
                }),
                context_server_registry,
                templates,
                models: LanguageModels::new(cx),
                project,
                prompt_store,
                fs,
                _subscriptions: subscriptions,
            }
        }))
    }
300
    /// Registers `thread_handle` as a session and returns the `AcpThread` created for it.
    ///
    /// Creates an `AcpThread` backed by a `NativeAgentConnection` to this agent, sets the
    /// thread's summarization model and default tools, subscribes to the thread's title and
    /// token-usage events and change notifications, and stores the session under the
    /// thread's id (replacing any existing session with that id).
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let title = thread.title();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            acp_thread::AcpThread::new(
                title,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            )
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            // Tools reach the ACP thread only through a weak handle.
            thread.add_default_tools(
                Rc::new(AcpThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            // Drop the session once its ACP thread entity is released.
            cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
                this.sessions.remove(acp_thread.session_id());
            }),
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            // Persist the thread whenever it notifies of a change.
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.downgrade(),
                _subscriptions: subscriptions,
                pending_save: Task::ready(()),
            },
        );

        // Push the current slash-command list to every session, including the new one.
        self.update_available_commands(cx);

        acp_thread
    }
364
    /// Returns the agent's cached language-model information.
    pub fn models(&self) -> &LanguageModels {
        &self.models
    }
368
369 async fn maintain_project_context(
370 this: WeakEntity<Self>,
371 mut needs_refresh: watch::Receiver<()>,
372 cx: &mut AsyncApp,
373 ) -> Result<()> {
374 while needs_refresh.changed().await.is_ok() {
375 let project_context = this
376 .update(cx, |this, cx| {
377 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
378 })?
379 .await;
380 this.update(cx, |this, cx| {
381 this.project_context = cx.new(|_| project_context);
382 })?;
383 }
384
385 Ok(())
386 }
387
388 fn build_project_context(
389 project: &Entity<Project>,
390 prompt_store: Option<&Entity<PromptStore>>,
391 cx: &mut App,
392 ) -> Task<ProjectContext> {
393 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
394 let worktree_tasks = worktrees
395 .into_iter()
396 .map(|worktree| {
397 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
398 })
399 .collect::<Vec<_>>();
400 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
401 prompt_store.read_with(cx, |prompt_store, cx| {
402 let prompts = prompt_store.default_prompt_metadata();
403 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
404 let contents = prompt_store.load(prompt_metadata.id, cx);
405 async move { (contents.await, prompt_metadata) }
406 });
407 cx.background_spawn(future::join_all(load_tasks))
408 })
409 } else {
410 Task::ready(vec![])
411 };
412
413 cx.spawn(async move |_cx| {
414 let (worktrees, default_user_rules) =
415 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
416
417 let worktrees = worktrees
418 .into_iter()
419 .map(|(worktree, _rules_error)| {
420 // TODO: show error message
421 // if let Some(rules_error) = rules_error {
422 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
423 // }
424 worktree
425 })
426 .collect::<Vec<_>>();
427
428 let default_user_rules = default_user_rules
429 .into_iter()
430 .flat_map(|(contents, prompt_metadata)| match contents {
431 Ok(contents) => Some(UserRulesContext {
432 uuid: prompt_metadata.id.as_user()?,
433 title: prompt_metadata.title.map(|title| title.to_string()),
434 contents,
435 }),
436 Err(_err) => {
437 // TODO: show error message
438 // this.update(cx, |_, cx| {
439 // cx.emit(RulesLoadingError {
440 // message: format!("{err:?}").into(),
441 // });
442 // })
443 // .ok();
444 None
445 }
446 })
447 .collect::<Vec<_>>();
448
449 ProjectContext::new(worktrees, default_user_rules)
450 })
451 }
452
453 fn load_worktree_info_for_system_prompt(
454 worktree: Entity<Worktree>,
455 project: Entity<Project>,
456 cx: &mut App,
457 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
458 let tree = worktree.read(cx);
459 let root_name = tree.root_name_str().into();
460 let abs_path = tree.abs_path();
461
462 let mut context = WorktreeContext {
463 root_name,
464 abs_path,
465 rules_file: None,
466 };
467
468 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
469 let Some(rules_task) = rules_task else {
470 return Task::ready((context, None));
471 };
472
473 cx.spawn(async move |_| {
474 let (rules_file, rules_file_error) = match rules_task.await {
475 Ok(rules_file) => (Some(rules_file), None),
476 Err(err) => (
477 None,
478 Some(RulesLoadingError {
479 message: format!("{err}").into(),
480 }),
481 ),
482 };
483 context.rules_file = rules_file;
484 (context, rules_file_error)
485 })
486 }
487
488 fn load_worktree_rules_file(
489 worktree: Entity<Worktree>,
490 project: Entity<Project>,
491 cx: &mut App,
492 ) -> Option<Task<Result<RulesFileContext>>> {
493 let worktree = worktree.read(cx);
494 let worktree_id = worktree.id();
495 let selected_rules_file = RULES_FILE_NAMES
496 .into_iter()
497 .filter_map(|name| {
498 worktree
499 .entry_for_path(RelPath::unix(name).unwrap())
500 .filter(|entry| entry.is_file())
501 .map(|entry| entry.path.clone())
502 })
503 .next();
504
505 // Note that Cline supports `.clinerules` being a directory, but that is not currently
506 // supported. This doesn't seem to occur often in GitHub repositories.
507 selected_rules_file.map(|path_in_worktree| {
508 let project_path = ProjectPath {
509 worktree_id,
510 path: path_in_worktree.clone(),
511 };
512 let buffer_task =
513 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
514 let rope_task = cx.spawn(async move |cx| {
515 let buffer = buffer_task.await?;
516 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
517 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
518 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
519 })?;
520 anyhow::Ok((project_entry_id, rope))
521 });
522 // Build a string from the rope on a background thread.
523 cx.background_spawn(async move {
524 let (project_entry_id, rope) = rope_task.await?;
525 anyhow::Ok(RulesFileContext {
526 path_in_worktree,
527 text: rope.to_string().trim().to_string(),
528 project_entry_id: project_entry_id.to_usize(),
529 })
530 })
531 })
532 }
533
    /// Mirrors a thread's new title onto its session's `AcpThread`.
    ///
    /// Does nothing if the thread has no registered session. Errors (e.g. either entity
    /// having been released) are logged by `detach_and_log_err`.
    fn handle_thread_title_updated(
        &mut self,
        thread: Entity<Thread>,
        _: &TitleUpdated,
        cx: &mut Context<Self>,
    ) {
        let session_id = thread.read(cx).id();
        let Some(session) = self.sessions.get(session_id) else {
            return;
        };
        // Weak handles so the spawned task does not keep either thread alive.
        let thread = thread.downgrade();
        let acp_thread = session.acp_thread.clone();
        cx.spawn(async move |_, cx| {
            let title = thread.read_with(cx, |thread, _| thread.title())?;
            let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
            task.await
        })
        .detach_and_log_err(cx);
    }
553
    /// Forwards a thread's updated token usage to its session's `AcpThread`.
    ///
    /// Silently ignores threads without a session and `AcpThread`s that were released.
    fn handle_thread_token_usage_updated(
        &mut self,
        thread: Entity<Thread>,
        usage: &TokenUsageUpdated,
        cx: &mut Context<Self>,
    ) {
        let Some(session) = self.sessions.get(thread.read(cx).id()) else {
            return;
        };
        session
            .acp_thread
            .update(cx, |acp_thread, cx| {
                acp_thread.update_token_usage(usage.0.clone(), cx);
            })
            .ok();
    }
570
571 fn handle_project_event(
572 &mut self,
573 _project: Entity<Project>,
574 event: &project::Event,
575 _cx: &mut Context<Self>,
576 ) {
577 match event {
578 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
579 self.project_context_needs_refresh.send(()).ok();
580 }
581 project::Event::WorktreeUpdatedEntries(_, items) => {
582 if items.iter().any(|(path, _, _)| {
583 RULES_FILE_NAMES
584 .iter()
585 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
586 }) {
587 self.project_context_needs_refresh.send(()).ok();
588 }
589 }
590 _ => {}
591 }
592 }
593
    /// Requests a project-context rebuild whenever the prompt store's prompts change, since
    /// default user rules are part of that context.
    fn handle_prompts_updated_event(
        &mut self,
        _prompt_store: Entity<PromptStore>,
        _event: &prompt_store::PromptsUpdatedEvent,
        _cx: &mut Context<Self>,
    ) {
        self.project_context_needs_refresh.send(()).ok();
    }
602
    /// Reacts to any language-model registry event.
    ///
    /// Rebuilds the cached model list, assigns the registry's default model to every session
    /// thread that has no model yet, and pushes the current summarization model to every
    /// session thread.
    fn handle_models_updated_event(
        &mut self,
        _registry: Entity<LanguageModelRegistry>,
        _event: &language_model::Event,
        cx: &mut Context<Self>,
    ) {
        self.models.refresh_list(cx);

        let registry = LanguageModelRegistry::read_global(cx);
        let default_model = registry.default_model().map(|m| m.model);
        let summarization_model = registry.thread_summary_model().map(|m| m.model);

        for session in self.sessions.values_mut() {
            session.thread.update(cx, |thread, cx| {
                // Never override a model the user already chose for this thread.
                if thread.model().is_none()
                    && let Some(model) = default_model.clone()
                {
                    thread.set_model(model, cx);
                    cx.notify();
                }
                thread.set_summarization_model(summarization_model.clone(), cx);
            });
        }
    }
627
    /// Re-publishes the available slash commands after any context-server store event.
    fn handle_context_server_store_updated(
        &mut self,
        _store: Entity<project::context_server_store::ContextServerStore>,
        _event: &project::context_server_store::Event,
        cx: &mut Context<Self>,
    ) {
        self.update_available_commands(cx);
    }
636
    /// Re-publishes the available slash commands when context-server prompts change.
    ///
    /// Tool changes need no action here.
    fn handle_context_server_registry_event(
        &mut self,
        _registry: Entity<ContextServerRegistry>,
        event: &ContextServerRegistryEvent,
        cx: &mut Context<Self>,
    ) {
        match event {
            ContextServerRegistryEvent::ToolsChanged => {}
            ContextServerRegistryEvent::PromptsChanged => {
                self.update_available_commands(cx);
            }
        }
    }
650
651 fn update_available_commands(&self, cx: &mut Context<Self>) {
652 let available_commands = self.build_available_commands(cx);
653 for session in self.sessions.values() {
654 if let Some(acp_thread) = session.acp_thread.upgrade() {
655 acp_thread.update(cx, |thread, cx| {
656 thread
657 .handle_session_update(
658 acp::SessionUpdate::AvailableCommandsUpdate(
659 acp::AvailableCommandsUpdate::new(available_commands.clone()),
660 ),
661 cx,
662 )
663 .log_err();
664 });
665 }
666 }
667 }
668
669 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
670 let registry = self.context_server_registry.read(cx);
671
672 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
673 for context_server_prompt in registry.prompts() {
674 *prompt_name_counts
675 .entry(context_server_prompt.prompt.name.as_str())
676 .or_insert(0) += 1;
677 }
678
679 registry
680 .prompts()
681 .flat_map(|context_server_prompt| {
682 let prompt = &context_server_prompt.prompt;
683
684 let should_prefix = prompt_name_counts
685 .get(prompt.name.as_str())
686 .copied()
687 .unwrap_or(0)
688 > 1;
689
690 let name = if should_prefix {
691 format!("{}.{}", context_server_prompt.server_id, prompt.name)
692 } else {
693 prompt.name.clone()
694 };
695
696 let mut command = acp::AvailableCommand::new(
697 name,
698 prompt.description.clone().unwrap_or_default(),
699 );
700
701 match prompt.arguments.as_deref() {
702 Some([arg]) => {
703 let hint = format!("<{}>", arg.name);
704
705 command = command.input(acp::AvailableCommandInput::Unstructured(
706 acp::UnstructuredCommandInput::new(hint),
707 ));
708 }
709 Some([]) | None => {}
710 Some(_) => {
711 // skip >1 argument commands since we don't support them yet
712 return None;
713 }
714 }
715
716 Some(command)
717 })
718 .collect()
719 }
720
    /// Loads the thread with `id` from the threads database into a new `Thread` entity.
    ///
    /// The thread is configured with this agent's project, shared project context,
    /// context-server registry, templates, and the current summarization model. It is NOT
    /// registered as a session here; see `open_thread` for that.
    ///
    /// Fails if the database cannot be opened, the load fails, no thread with `id` exists,
    /// or the agent has been released.
    pub fn load_thread(
        &mut self,
        id: acp::SessionId,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Thread>>> {
        let database_future = ThreadsDatabase::connect(cx);
        cx.spawn(async move |this, cx| {
            let database = database_future.await.map_err(|err| anyhow!(err))?;
            let db_thread = database
                .load_thread(id.clone())
                .await?
                .with_context(|| format!("no thread found with ID: {id:?}"))?;

            this.update(cx, |this, cx| {
                let summarization_model = LanguageModelRegistry::read_global(cx)
                    .thread_summary_model()
                    .map(|c| c.model);

                cx.new(|cx| {
                    let mut thread = Thread::from_db(
                        id.clone(),
                        db_thread,
                        this.project.clone(),
                        this.project_context.clone(),
                        this.context_server_registry.clone(),
                        this.templates.clone(),
                        cx,
                    );
                    thread.set_summarization_model(summarization_model, cx);
                    thread
                })
            })
        })
    }
755
    /// Loads the thread with `id`, registers it as a session, and replays its history.
    ///
    /// The stored history is replayed into the returned `AcpThread`, and the task completes
    /// only once that replay has been forwarded.
    pub fn open_thread(
        &mut self,
        id: acp::SessionId,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<AcpThread>>> {
        let task = self.load_thread(id, cx);
        cx.spawn(async move |this, cx| {
            let thread = task.await?;
            let acp_thread =
                this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
            let events = thread.update(cx, |thread, cx| thread.replay(cx));
            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
            })
            .await?;
            Ok(acp_thread)
        })
    }
774
775 pub fn thread_summary(
776 &mut self,
777 id: acp::SessionId,
778 cx: &mut Context<Self>,
779 ) -> Task<Result<SharedString>> {
780 let thread = self.open_thread(id.clone(), cx);
781 cx.spawn(async move |this, cx| {
782 let acp_thread = thread.await?;
783 let result = this
784 .update(cx, |this, cx| {
785 this.sessions
786 .get(&id)
787 .unwrap()
788 .thread
789 .update(cx, |thread, cx| thread.summary(cx))
790 })?
791 .await
792 .context("Failed to generate summary")?;
793 drop(acp_thread);
794 Ok(result)
795 })
796 }
797
    /// Persists `thread` to the threads database and then reloads the history store.
    ///
    /// Empty threads and threads without a registered session are skipped. The save task is
    /// stored in the session's `pending_save`, replacing any previous one. Database errors
    /// are logged rather than returned.
    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
        if thread.read(cx).is_empty() {
            return;
        }

        let database_future = ThreadsDatabase::connect(cx);
        let (id, db_thread) =
            thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
        let Some(session) = self.sessions.get_mut(&id) else {
            return;
        };
        let history = self.history.clone();
        session.pending_save = cx.spawn(async move |_, cx| {
            let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
                return;
            };
            let db_thread = db_thread.await;
            database.save_thread(id, db_thread).await.log_err();
            // History is reloaded even if the save above failed.
            history.update(cx, |history, cx| history.reload(cx));
        });
    }
819
    /// Expands an MCP (context-server) prompt into the session and runs a turn.
    ///
    /// Fetches prompt `prompt_name` from server `server_id` with `arguments`. It first
    /// appends the user's original content to the thread, skipping the first content block
    /// (presumably the slash-command text itself — confirm against caller). Each prompt
    /// message is then pushed into both the `AcpThread` (indented) and the internal
    /// `Thread`. Finally the turn is sent, or resumed if the prompt ended with an assistant
    /// message, and its events are forwarded to the `AcpThread`.
    fn send_mcp_prompt(
        &self,
        message_id: UserMessageId,
        session_id: agent_client_protocol::SessionId,
        prompt_name: String,
        server_id: ContextServerId,
        arguments: HashMap<String, String>,
        original_content: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> Task<Result<acp::PromptResponse>> {
        let server_store = self.context_server_registry.read(cx).server_store().clone();
        let path_style = self.project.read(cx).path_style(cx);

        cx.spawn(async move |this, cx| {
            let prompt =
                crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;

            let (acp_thread, thread) = this.update(cx, |this, _cx| {
                let session = this
                    .sessions
                    .get(&session_id)
                    .context("Failed to get session")?;
                anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
            })??;

            // Tracks whether the last prompt message was from the user; an empty prompt
            // counts as ending with the user's own message.
            let mut last_is_user = true;

            thread.update(cx, |thread, cx| {
                thread.push_acp_user_block(
                    message_id,
                    original_content.into_iter().skip(1),
                    path_style,
                    cx,
                );
            });

            for message in prompt.messages {
                let context_server::types::PromptMessage { role, content } = message;
                let block = mcp_message_content_to_acp_content_block(content);

                match role {
                    context_server::types::Role::User => {
                        let id = acp_thread::UserMessageId::new();

                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_user_content_block_with_indent(
                                Some(id.clone()),
                                block.clone(),
                                true,
                                cx,
                            );
                        })?;

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_user_block(id, [block], path_style, cx);
                        });
                    }
                    context_server::types::Role::Assistant => {
                        acp_thread.update(cx, |acp_thread, cx| {
                            acp_thread.push_assistant_content_block_with_indent(
                                block.clone(),
                                false,
                                true,
                                cx,
                            );
                        })?;

                        thread.update(cx, |thread, cx| {
                            thread.push_acp_agent_block(block, cx);
                        });
                    }
                }

                last_is_user = role == context_server::types::Role::User;
            }

            let response_stream = thread.update(cx, |thread, cx| {
                if last_is_user {
                    thread.send_existing(cx)
                } else {
                    // Resume if MCP prompt did not end with a user message
                    thread.resume(cx)
                }
            })?;

            cx.update(|cx| {
                NativeAgentConnection::handle_thread_events(response_stream, acp_thread, cx)
            })
            .await
        })
    }
911}
912
/// Wrapper struct that implements the AgentConnection trait
///
/// It is a cheap-to-clone handle to the shared `NativeAgent` entity.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
916
917impl NativeAgentConnection {
918 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
919 self.0
920 .read(cx)
921 .sessions
922 .get(session_id)
923 .map(|session| session.thread.clone())
924 }
925
926 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
927 self.0.update(cx, |this, cx| this.load_thread(id, cx))
928 }
929
930 fn run_turn(
931 &self,
932 session_id: acp::SessionId,
933 cx: &mut App,
934 f: impl 'static
935 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
936 ) -> Task<Result<acp::PromptResponse>> {
937 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
938 agent
939 .sessions
940 .get_mut(&session_id)
941 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
942 }) else {
943 return Task::ready(Err(anyhow!("Session not found")));
944 };
945 log::debug!("Found session for: {}", session_id);
946
947 let response_stream = match f(thread, cx) {
948 Ok(stream) => stream,
949 Err(err) => return Task::ready(Err(err)),
950 };
951 Self::handle_thread_events(response_stream, acp_thread, cx)
952 }
953
954 fn handle_thread_events(
955 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
956 acp_thread: WeakEntity<AcpThread>,
957 cx: &App,
958 ) -> Task<Result<acp::PromptResponse>> {
959 cx.spawn(async move |cx| {
960 // Handle response stream and forward to session.acp_thread
961 while let Some(result) = events.next().await {
962 match result {
963 Ok(event) => {
964 log::trace!("Received completion event: {:?}", event);
965
966 match event {
967 ThreadEvent::UserMessage(message) => {
968 acp_thread.update(cx, |thread, cx| {
969 for content in message.content {
970 thread.push_user_content_block(
971 Some(message.id.clone()),
972 content.into(),
973 cx,
974 );
975 }
976 })?;
977 }
978 ThreadEvent::AgentText(text) => {
979 acp_thread.update(cx, |thread, cx| {
980 thread.push_assistant_content_block(text.into(), false, cx)
981 })?;
982 }
983 ThreadEvent::AgentThinking(text) => {
984 acp_thread.update(cx, |thread, cx| {
985 thread.push_assistant_content_block(text.into(), true, cx)
986 })?;
987 }
988 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
989 tool_call,
990 options,
991 response,
992 }) => {
993 let outcome_task = acp_thread.update(cx, |thread, cx| {
994 thread.request_tool_call_authorization(
995 tool_call, options, true, cx,
996 )
997 })??;
998 cx.background_spawn(async move {
999 if let acp::RequestPermissionOutcome::Selected(
1000 acp::SelectedPermissionOutcome { option_id, .. },
1001 ) = outcome_task.await
1002 {
1003 response
1004 .send(option_id)
1005 .map(|_| anyhow!("authorization receiver was dropped"))
1006 .log_err();
1007 }
1008 })
1009 .detach();
1010 }
1011 ThreadEvent::ToolCall(tool_call) => {
1012 acp_thread.update(cx, |thread, cx| {
1013 thread.upsert_tool_call(tool_call, cx)
1014 })??;
1015 }
1016 ThreadEvent::ToolCallUpdate(update) => {
1017 acp_thread.update(cx, |thread, cx| {
1018 thread.update_tool_call(update, cx)
1019 })??;
1020 }
1021 ThreadEvent::Retry(status) => {
1022 acp_thread.update(cx, |thread, cx| {
1023 thread.update_retry_status(status, cx)
1024 })?;
1025 }
1026 ThreadEvent::Stop(stop_reason) => {
1027 log::debug!("Assistant message complete: {:?}", stop_reason);
1028 return Ok(acp::PromptResponse::new(stop_reason));
1029 }
1030 }
1031 }
1032 Err(e) => {
1033 log::error!("Error in model response stream: {:?}", e);
1034 return Err(e);
1035 }
1036 }
1037 }
1038
1039 log::debug!("Response stream completed");
1040 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1041 })
1042 }
1043}
1044
/// A slash command parsed from the start of a prompt, borrowing from the prompt text.
struct Command<'a> {
    /// Prompt name (the part after the `.` when a server id is given).
    prompt_name: &'a str,
    /// Everything after the first whitespace following the command; empty if none.
    arg_value: &'a str,
    /// Server id when written as `/server.prompt`, otherwise `None`.
    explicit_server_id: Option<&'a str>,
}
1050
1051impl<'a> Command<'a> {
1052 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1053 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1054 return None;
1055 };
1056 let text = text_content.text.trim();
1057 let command = text.strip_prefix('/')?;
1058 let (command, arg_value) = command
1059 .split_once(char::is_whitespace)
1060 .unwrap_or((command, ""));
1061
1062 if let Some((server_id, prompt_name)) = command.split_once('.') {
1063 Some(Self {
1064 prompt_name,
1065 arg_value,
1066 explicit_server_id: Some(server_id),
1067 })
1068 } else {
1069 Some(Self {
1070 prompt_name: command,
1071 arg_value,
1072 explicit_server_id: None,
1073 })
1074 }
1075 }
1076}
1077
/// Model selector bound to a single session of the native agent.
struct NativeAgentModelSelector {
    /// Session whose thread's model this selector reads and changes.
    session_id: acp::SessionId,
    /// Connection used to reach the agent's sessions, model cache, and filesystem.
    connection: NativeAgentConnection,
}
1082
1083impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1084 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1085 log::debug!("NativeAgentConnection::list_models called");
1086 let list = self.connection.0.read(cx).models.model_list.clone();
1087 Task::ready(if list.is_empty() {
1088 Err(anyhow::anyhow!("No models available"))
1089 } else {
1090 Ok(list)
1091 })
1092 }
1093
1094 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1095 log::debug!(
1096 "Setting model for session {}: {}",
1097 self.session_id,
1098 model_id
1099 );
1100 let Some(thread) = self
1101 .connection
1102 .0
1103 .read(cx)
1104 .sessions
1105 .get(&self.session_id)
1106 .map(|session| session.thread.clone())
1107 else {
1108 return Task::ready(Err(anyhow!("Session not found")));
1109 };
1110
1111 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1112 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1113 };
1114
1115 thread.update(cx, |thread, cx| {
1116 thread.set_model(model.clone(), cx);
1117 });
1118
1119 update_settings_file(
1120 self.connection.0.read(cx).fs.clone(),
1121 cx,
1122 move |settings, _cx| {
1123 let provider = model.provider_id().0.to_string();
1124 let model = model.id().0.to_string();
1125 settings
1126 .agent
1127 .get_or_insert_default()
1128 .set_model(LanguageModelSelection {
1129 provider: provider.into(),
1130 model,
1131 });
1132 },
1133 );
1134
1135 Task::ready(Ok(()))
1136 }
1137
1138 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1139 let Some(thread) = self
1140 .connection
1141 .0
1142 .read(cx)
1143 .sessions
1144 .get(&self.session_id)
1145 .map(|session| session.thread.clone())
1146 else {
1147 return Task::ready(Err(anyhow!("Session not found")));
1148 };
1149 let Some(model) = thread.read(cx).model() else {
1150 return Task::ready(Err(anyhow!("Model not found")));
1151 };
1152 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1153 else {
1154 return Task::ready(Err(anyhow!("Provider not found")));
1155 };
1156 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1157 model, &provider,
1158 )))
1159 }
1160
1161 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1162 Some(self.connection.0.read(cx).models.watch())
1163 }
1164
1165 fn should_render_footer(&self) -> bool {
1166 true
1167 }
1168}
1169
1170impl acp_thread::AgentConnection for NativeAgentConnection {
1171 fn telemetry_id(&self) -> SharedString {
1172 "zed".into()
1173 }
1174
1175 fn new_thread(
1176 self: Rc<Self>,
1177 project: Entity<Project>,
1178 cwd: &Path,
1179 cx: &mut App,
1180 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1181 let agent = self.0.clone();
1182 log::debug!("Creating new thread for project at: {:?}", cwd);
1183
1184 cx.spawn(async move |cx| {
1185 log::debug!("Starting thread creation in async context");
1186
1187 // Create Thread
1188 let thread = agent.update(cx, |agent, cx| {
1189 // Fetch default model from registry settings
1190 let registry = LanguageModelRegistry::read_global(cx);
1191 // Log available models for debugging
1192 let available_count = registry.available_models(cx).count();
1193 log::debug!("Total available models: {}", available_count);
1194
1195 let default_model = registry.default_model().and_then(|default_model| {
1196 agent
1197 .models
1198 .model_from_id(&LanguageModels::model_id(&default_model.model))
1199 });
1200 cx.new(|cx| {
1201 Thread::new(
1202 project.clone(),
1203 agent.project_context.clone(),
1204 agent.context_server_registry.clone(),
1205 agent.templates.clone(),
1206 default_model,
1207 cx,
1208 )
1209 })
1210 });
1211 Ok(agent.update(cx, |agent, cx| agent.register_session(thread, cx)))
1212 })
1213 }
1214
1215 fn auth_methods(&self) -> &[acp::AuthMethod] {
1216 &[] // No auth for in-process
1217 }
1218
1219 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1220 Task::ready(Ok(()))
1221 }
1222
1223 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1224 Some(Rc::new(NativeAgentModelSelector {
1225 session_id: session_id.clone(),
1226 connection: self.clone(),
1227 }) as Rc<dyn AgentModelSelector>)
1228 }
1229
1230 fn prompt(
1231 &self,
1232 id: Option<acp_thread::UserMessageId>,
1233 params: acp::PromptRequest,
1234 cx: &mut App,
1235 ) -> Task<Result<acp::PromptResponse>> {
1236 let id = id.expect("UserMessageId is required");
1237 let session_id = params.session_id.clone();
1238 log::info!("Received prompt request for session: {}", session_id);
1239 log::debug!("Prompt blocks count: {}", params.prompt.len());
1240
1241 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1242 let registry = self.0.read(cx).context_server_registry.read(cx);
1243
1244 let explicit_server_id = parsed_command
1245 .explicit_server_id
1246 .map(|server_id| ContextServerId(server_id.into()));
1247
1248 if let Some(prompt) =
1249 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1250 {
1251 let arguments = if !parsed_command.arg_value.is_empty()
1252 && let Some(arg_name) = prompt
1253 .prompt
1254 .arguments
1255 .as_ref()
1256 .and_then(|args| args.first())
1257 .map(|arg| arg.name.clone())
1258 {
1259 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1260 } else {
1261 Default::default()
1262 };
1263
1264 let prompt_name = prompt.prompt.name.clone();
1265 let server_id = prompt.server_id.clone();
1266
1267 return self.0.update(cx, |agent, cx| {
1268 agent.send_mcp_prompt(
1269 id,
1270 session_id.clone(),
1271 prompt_name,
1272 server_id,
1273 arguments,
1274 params.prompt,
1275 cx,
1276 )
1277 });
1278 };
1279 };
1280
1281 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1282
1283 self.run_turn(session_id, cx, move |thread, cx| {
1284 let content: Vec<UserMessageContent> = params
1285 .prompt
1286 .into_iter()
1287 .map(|block| UserMessageContent::from_content_block(block, path_style))
1288 .collect::<Vec<_>>();
1289 log::debug!("Converted prompt to message: {} chars", content.len());
1290 log::debug!("Message id: {:?}", id);
1291 log::debug!("Message content: {:?}", content);
1292
1293 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1294 })
1295 }
1296
1297 fn resume(
1298 &self,
1299 session_id: &acp::SessionId,
1300 _cx: &App,
1301 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1302 Some(Rc::new(NativeAgentSessionResume {
1303 connection: self.clone(),
1304 session_id: session_id.clone(),
1305 }) as _)
1306 }
1307
1308 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1309 log::info!("Cancelling on session: {}", session_id);
1310 self.0.update(cx, |agent, cx| {
1311 if let Some(agent) = agent.sessions.get(session_id) {
1312 agent
1313 .thread
1314 .update(cx, |thread, cx| thread.cancel(cx))
1315 .detach();
1316 }
1317 });
1318 }
1319
1320 fn truncate(
1321 &self,
1322 session_id: &agent_client_protocol::SessionId,
1323 cx: &App,
1324 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1325 self.0.read_with(cx, |agent, _cx| {
1326 agent.sessions.get(session_id).map(|session| {
1327 Rc::new(NativeAgentSessionTruncate {
1328 thread: session.thread.clone(),
1329 acp_thread: session.acp_thread.clone(),
1330 }) as _
1331 })
1332 })
1333 }
1334
1335 fn set_title(
1336 &self,
1337 session_id: &acp::SessionId,
1338 _cx: &App,
1339 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1340 Some(Rc::new(NativeAgentSessionSetTitle {
1341 connection: self.clone(),
1342 session_id: session_id.clone(),
1343 }) as _)
1344 }
1345
1346 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1347 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1348 }
1349
1350 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1351 self
1352 }
1353}
1354
1355impl acp_thread::AgentTelemetry for NativeAgentConnection {
1356 fn thread_data(
1357 &self,
1358 session_id: &acp::SessionId,
1359 cx: &mut App,
1360 ) -> Task<Result<serde_json::Value>> {
1361 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1362 return Task::ready(Err(anyhow!("Session not found")));
1363 };
1364
1365 let task = session.thread.read(cx).to_db(cx);
1366 cx.background_spawn(async move {
1367 serde_json::to_value(task.await).context("Failed to serialize thread")
1368 })
1369 }
1370}
1371
1372struct NativeAgentSessionTruncate {
1373 thread: Entity<Thread>,
1374 acp_thread: WeakEntity<AcpThread>,
1375}
1376
1377impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1378 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1379 match self.thread.update(cx, |thread, cx| {
1380 thread.truncate(message_id.clone(), cx)?;
1381 Ok(thread.latest_token_usage())
1382 }) {
1383 Ok(usage) => {
1384 self.acp_thread
1385 .update(cx, |thread, cx| {
1386 thread.update_token_usage(usage, cx);
1387 })
1388 .ok();
1389 Task::ready(Ok(()))
1390 }
1391 Err(error) => Task::ready(Err(error)),
1392 }
1393 }
1394}
1395
1396struct NativeAgentSessionResume {
1397 connection: NativeAgentConnection,
1398 session_id: acp::SessionId,
1399}
1400
1401impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1402 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1403 self.connection
1404 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1405 thread.update(cx, |thread, cx| thread.resume(cx))
1406 })
1407 }
1408}
1409
1410struct NativeAgentSessionSetTitle {
1411 connection: NativeAgentConnection,
1412 session_id: acp::SessionId,
1413}
1414
1415impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1416 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1417 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1418 return Task::ready(Err(anyhow!("session not found")));
1419 };
1420 let thread = session.thread.clone();
1421 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1422 Task::ready(Ok(()))
1423 }
1424}
1425
/// A `ThreadEnvironment` that creates terminals through an ACP thread.
pub struct AcpThreadEnvironment {
    acp_thread: WeakEntity<AcpThread>,
}

impl ThreadEnvironment for AcpThreadEnvironment {
    /// Asks the ACP thread to create a terminal running `command` in `cwd`, and
    /// wraps the result in an `AcpTerminalHandle`.
    ///
    /// The two empty `vec![]` arguments are passed straight to
    /// `AcpThread::create_terminal` (presumably args and env; TODO confirm).
    ///
    /// When the returned handle is dropped, a detached task releases the
    /// terminal from the ACP thread.
    ///
    /// # Errors
    /// The task fails if the ACP thread has been released or terminal creation fails.
    fn create_terminal(
        &self,
        command: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
        // Creation starts immediately; the update itself errors if the ACP thread is gone.
        let task = self.acp_thread.update(cx, |thread, cx| {
            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
        });

        let acp_thread = self.acp_thread.clone();
        cx.spawn(async move |cx| {
            let terminal = task?.await?;

            // Nothing is ever sent on `drop_tx`. The receiver completes (with
            // `Canceled`) when the handle owning `drop_tx` is dropped, and that
            // triggers the release below.
            let (drop_tx, drop_rx) = oneshot::channel();
            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());

            cx.spawn(async move |cx| {
                drop_rx.await.ok();
                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
            })
            .detach();

            let handle = AcpTerminalHandle {
                terminal,
                _drop_tx: Some(drop_tx),
            };

            Ok(Rc::new(handle) as _)
        })
    }
}
1464
1465pub struct AcpTerminalHandle {
1466 terminal: Entity<acp_thread::Terminal>,
1467 _drop_tx: Option<oneshot::Sender<()>>,
1468}
1469
1470impl TerminalHandle for AcpTerminalHandle {
1471 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1472 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1473 }
1474
1475 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1476 Ok(self
1477 .terminal
1478 .read_with(cx, |term, _cx| term.wait_for_exit()))
1479 }
1480
1481 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1482 Ok(self
1483 .terminal
1484 .read_with(cx, |term, cx| term.current_output(cx)))
1485 }
1486
1487 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1488 cx.update(|cx| {
1489 self.terminal.update(cx, |terminal, cx| {
1490 terminal.kill(cx);
1491 });
1492 });
1493 Ok(())
1494 }
1495
1496 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1497 Ok(self
1498 .terminal
1499 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1500 }
1501}
1502
// Tests for NativeAgent/NativeAgentConnection internals: project-context
// maintenance, model listing and selection, and thread save/reload.
#[cfg(test)]
mod internal_tests {
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    /// The agent's project context lists worktrees as they are added. It also
    /// picks up a `.rules` file created inside one.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project starts with no worktrees.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        // Adding a worktree without a rules file.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// `list_models` on a session's selector returns the test registry's fake
    /// model, grouped under its provider.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                }]
            )])
        );
    }

    /// Selecting a model updates the session's thread and overwrites
    /// `agent.default_model` in the settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed the settings file with a different default model.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings-file write complete before reading it back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// A thread with at least one exchange is saved to history under its summary
    /// title. It is then reopened from disk with the same content.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        // Send a message that mentions a file, then drive the fake models.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        // The summarization model's output becomes the history entry title.
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    /// Returns `(id, title)` for every entry currently in `history`.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    /// Installs logging, a test `SettingsStore`, and the test language model registry.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            LanguageModelRegistry::test(cx);
        });
    }
}
1890
1891fn mcp_message_content_to_acp_content_block(
1892 content: context_server::types::MessageContent,
1893) -> acp::ContentBlock {
1894 match content {
1895 context_server::types::MessageContent::Text {
1896 text,
1897 annotations: _,
1898 } => text.into(),
1899 context_server::types::MessageContent::Image {
1900 data,
1901 mime_type,
1902 annotations: _,
1903 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
1904 context_server::types::MessageContent::Audio {
1905 data,
1906 mime_type,
1907 annotations: _,
1908 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
1909 context_server::types::MessageContent::Resource {
1910 resource,
1911 annotations: _,
1912 } => {
1913 let mut link =
1914 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
1915 if let Some(mime_type) = resource.mime_type {
1916 link = link.mime_type(mime_type);
1917 }
1918 acp::ContentBlock::ResourceLink(link)
1919 }
1920 }
1921}