1mod db;
2mod edit_agent;
3mod legacy_thread;
4mod native_agent_server;
5pub mod outline;
6mod pattern_extraction;
7mod templates;
8#[cfg(test)]
9mod tests;
10mod thread;
11mod thread_store;
12mod tool_permissions;
13mod tools;
14
15use context_server::ContextServerId;
16pub use db::*;
17pub use native_agent_server::NativeAgentServer;
18pub use pattern_extraction::*;
19pub use templates::*;
20pub use thread::*;
21pub use thread_store::*;
22pub use tool_permissions::*;
23pub use tools::*;
24
25use acp_thread::{
26 AcpThread, AgentModelSelector, AgentSessionInfo, AgentSessionList, AgentSessionListRequest,
27 AgentSessionListResponse, UserMessageId,
28};
29use agent_client_protocol as acp;
30use anyhow::{Context as _, Result, anyhow};
31use chrono::{DateTime, Utc};
32use collections::{HashMap, HashSet, IndexMap};
33use fs::Fs;
34use futures::channel::{mpsc, oneshot};
35use futures::future::Shared;
36use futures::{StreamExt, future};
37use gpui::{
38 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
39};
40use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
41use project::{Project, ProjectItem, ProjectPath, Worktree};
42use prompt_store::{
43 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
44 WorktreeContext,
45};
46use serde::{Deserialize, Serialize};
47use settings::{LanguageModelSelection, update_settings_file};
48use std::any::Any;
49use std::path::{Path, PathBuf};
50use std::rc::Rc;
51use std::sync::Arc;
52use util::ResultExt;
53use util::rel_path::RelPath;
54
/// A point-in-time capture of the project's worktrees together with the
/// moment it was taken.
///
/// Derives `Serialize`/`Deserialize`: field names become the serialized keys,
/// so renaming a field changes the persisted format.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots of the worktrees (presumably one per worktree — TODO confirm).
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// When the snapshot was taken, in UTC.
    pub timestamp: DateTime<Utc>,
}
60
/// Error describing why a worktree's rules file could not be loaded.
///
/// Produced by `NativeAgent::load_worktree_info_for_system_prompt`, which
/// stores the `Display` form of the underlying error in `message`.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}
64
/// Holds both the internal Thread and the AcpThread for a session.
///
/// Sessions are created by `NativeAgent::register_session` and removed from
/// `NativeAgent::sessions` when the ACP thread entity is released.
struct Session {
    /// The internal thread that processes messages.
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication (held weakly).
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Most recent database save for this thread; overwritten by every
    /// `NativeAgent::save_thread` call.
    pending_save: Task<()>,
    /// Keeps alive the release observer, title/token-usage event handlers and
    /// the save-on-change observer registered for this session.
    _subscriptions: Vec<Subscription>,
}
74
/// Cache of the language models offered by authenticated, visible providers,
/// plus a watch channel signalled every time the cache is rebuilt.
pub struct LanguageModels {
    /// Access language model by ID (formatted as `"{provider_id}/{model_id}"`).
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information: an optional
    /// "Recommended" group followed by one group per provider.
    model_list: acp_thread::AgentModelList,
    /// Receiver cloned out by `watch()` so consumers can observe refreshes.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled at the end of every `refresh_list` call.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating every visible provider; stored so it
    /// lives as long as this cache.
    _authenticate_all_providers_task: Task<()>,
}
84
85impl LanguageModels {
86 fn new(cx: &mut App) -> Self {
87 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
88
89 let mut this = Self {
90 models: HashMap::default(),
91 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
92 refresh_models_rx,
93 refresh_models_tx,
94 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
95 };
96 this.refresh_list(cx);
97 this
98 }
99
100 fn refresh_list(&mut self, cx: &App) {
101 let providers = LanguageModelRegistry::global(cx)
102 .read(cx)
103 .visible_providers()
104 .into_iter()
105 .filter(|provider| provider.is_authenticated(cx))
106 .collect::<Vec<_>>();
107
108 let mut language_model_list = IndexMap::default();
109 let mut recommended_models = HashSet::default();
110
111 let mut recommended = Vec::new();
112 for provider in &providers {
113 for model in provider.recommended_models(cx) {
114 recommended_models.insert((model.provider_id(), model.id()));
115 recommended.push(Self::map_language_model_to_info(&model, provider));
116 }
117 }
118 if !recommended.is_empty() {
119 language_model_list.insert(
120 acp_thread::AgentModelGroupName("Recommended".into()),
121 recommended,
122 );
123 }
124
125 let mut models = HashMap::default();
126 for provider in providers {
127 let mut provider_models = Vec::new();
128 for model in provider.provided_models(cx) {
129 let model_info = Self::map_language_model_to_info(&model, &provider);
130 let model_id = model_info.id.clone();
131 provider_models.push(model_info);
132 models.insert(model_id, model);
133 }
134 if !provider_models.is_empty() {
135 language_model_list.insert(
136 acp_thread::AgentModelGroupName(provider.name().0.clone()),
137 provider_models,
138 );
139 }
140 }
141
142 self.models = models;
143 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
144 self.refresh_models_tx.send(()).ok();
145 }
146
147 fn watch(&self) -> watch::Receiver<()> {
148 self.refresh_models_rx.clone()
149 }
150
151 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
152 self.models.get(model_id).cloned()
153 }
154
155 fn map_language_model_to_info(
156 model: &Arc<dyn LanguageModel>,
157 provider: &Arc<dyn LanguageModelProvider>,
158 ) -> acp_thread::AgentModelInfo {
159 acp_thread::AgentModelInfo {
160 id: Self::model_id(model),
161 name: model.name().0,
162 description: None,
163 icon: Some(match provider.icon() {
164 IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
165 IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
166 }),
167 }
168 }
169
170 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
171 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
172 }
173
174 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
175 let authenticate_all_providers = LanguageModelRegistry::global(cx)
176 .read(cx)
177 .visible_providers()
178 .iter()
179 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
180 .collect::<Vec<_>>();
181
182 cx.background_spawn(async move {
183 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
184 if let Err(err) = authenticate_task.await {
185 match err {
186 language_model::AuthenticateError::CredentialsNotFound => {
187 // Since we're authenticating these providers in the
188 // background for the purposes of populating the
189 // language selector, we don't care about providers
190 // where the credentials are not found.
191 }
192 language_model::AuthenticateError::ConnectionRefused => {
193 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
194 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
195 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
196 }
197 _ => {
198 // Some providers have noisy failure states that we
199 // don't want to spam the logs with every time the
200 // language model selector is initialized.
201 //
202 // Ideally these should have more clear failure modes
203 // that we know are safe to ignore here, like what we do
204 // with `CredentialsNotFound` above.
205 match provider_id.0.as_ref() {
206 "lmstudio" | "ollama" => {
207 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
208 //
209 // These fail noisily, so we don't log them.
210 }
211 "copilot_chat" => {
212 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
213 }
214 _ => {
215 log::error!(
216 "Failed to authenticate provider: {}: {err:#}",
217 provider_name.0
218 );
219 }
220 }
221 }
222 }
223 }
224 }
225 })
226 }
227}
228
/// The native agent for one project: owns every session (internal [`Thread`]
/// paired with an [`AcpThread`]) plus the state those sessions share.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread store, reloaded after each successful thread save.
    thread_store: Entity<ThreadStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Sending on this channel asks `maintain_project_context` to rebuild the
    /// project context (worktree changes, rules-file edits, prompt updates).
    project_context_needs_refresh: watch::Sender<()>,
    /// Long-lived task that performs the requested project-context rebuilds.
    _maintain_project_context: Task<Result<()>>,
    /// Context-server registry; its prompts are exposed as slash commands.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    project: Entity<Project>,
    /// Optional source of default user rules included in the project context.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem used when persisting the selected model to settings.
    fs: Arc<dyn Fs>,
    /// Subscriptions created in `new`, kept alive for the agent's lifetime.
    _subscriptions: Vec<Subscription>,
}
247
248impl NativeAgent {
249 pub async fn new(
250 project: Entity<Project>,
251 thread_store: Entity<ThreadStore>,
252 templates: Arc<Templates>,
253 prompt_store: Option<Entity<PromptStore>>,
254 fs: Arc<dyn Fs>,
255 cx: &mut AsyncApp,
256 ) -> Result<Entity<NativeAgent>> {
257 log::debug!("Creating new NativeAgent");
258
259 let project_context = cx
260 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))
261 .await;
262
263 Ok(cx.new(|cx| {
264 let context_server_store = project.read(cx).context_server_store();
265 let context_server_registry =
266 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
267
268 let mut subscriptions = vec![
269 cx.subscribe(&project, Self::handle_project_event),
270 cx.subscribe(
271 &LanguageModelRegistry::global(cx),
272 Self::handle_models_updated_event,
273 ),
274 cx.subscribe(
275 &context_server_store,
276 Self::handle_context_server_store_updated,
277 ),
278 cx.subscribe(
279 &context_server_registry,
280 Self::handle_context_server_registry_event,
281 ),
282 ];
283 if let Some(prompt_store) = prompt_store.as_ref() {
284 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
285 }
286
287 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
288 watch::channel(());
289 Self {
290 sessions: HashMap::default(),
291 thread_store,
292 project_context: cx.new(|_| project_context),
293 project_context_needs_refresh: project_context_needs_refresh_tx,
294 _maintain_project_context: cx.spawn(async move |this, cx| {
295 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
296 }),
297 context_server_registry,
298 templates,
299 models: LanguageModels::new(cx),
300 project,
301 prompt_store,
302 fs,
303 _subscriptions: subscriptions,
304 }
305 }))
306 }
307
308 fn register_session(
309 &mut self,
310 thread_handle: Entity<Thread>,
311 cx: &mut Context<Self>,
312 ) -> Entity<AcpThread> {
313 let connection = Rc::new(NativeAgentConnection(cx.entity()));
314
315 let thread = thread_handle.read(cx);
316 let session_id = thread.id().clone();
317 let title = thread.title();
318 let project = thread.project.clone();
319 let action_log = thread.action_log.clone();
320 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
321 let acp_thread = cx.new(|cx| {
322 acp_thread::AcpThread::new(
323 title,
324 connection,
325 project.clone(),
326 action_log.clone(),
327 session_id.clone(),
328 prompt_capabilities_rx,
329 cx,
330 )
331 });
332
333 let registry = LanguageModelRegistry::read_global(cx);
334 let summarization_model = registry.thread_summary_model().map(|c| c.model);
335
336 thread_handle.update(cx, |thread, cx| {
337 thread.set_summarization_model(summarization_model, cx);
338 thread.add_default_tools(
339 Rc::new(AcpThreadEnvironment {
340 acp_thread: acp_thread.downgrade(),
341 }) as _,
342 cx,
343 )
344 });
345
346 let subscriptions = vec![
347 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
348 this.sessions.remove(acp_thread.session_id());
349 }),
350 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
351 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
352 cx.observe(&thread_handle, move |this, thread, cx| {
353 this.save_thread(thread, cx)
354 }),
355 ];
356
357 self.sessions.insert(
358 session_id,
359 Session {
360 thread: thread_handle,
361 acp_thread: acp_thread.downgrade(),
362 _subscriptions: subscriptions,
363 pending_save: Task::ready(()),
364 },
365 );
366
367 self.update_available_commands(cx);
368
369 acp_thread
370 }
371
372 pub fn models(&self) -> &LanguageModels {
373 &self.models
374 }
375
376 async fn maintain_project_context(
377 this: WeakEntity<Self>,
378 mut needs_refresh: watch::Receiver<()>,
379 cx: &mut AsyncApp,
380 ) -> Result<()> {
381 while needs_refresh.changed().await.is_ok() {
382 let project_context = this
383 .update(cx, |this, cx| {
384 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
385 })?
386 .await;
387 this.update(cx, |this, cx| {
388 this.project_context = cx.new(|_| project_context);
389 })?;
390 }
391
392 Ok(())
393 }
394
395 fn build_project_context(
396 project: &Entity<Project>,
397 prompt_store: Option<&Entity<PromptStore>>,
398 cx: &mut App,
399 ) -> Task<ProjectContext> {
400 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
401 let worktree_tasks = worktrees
402 .into_iter()
403 .map(|worktree| {
404 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
405 })
406 .collect::<Vec<_>>();
407 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
408 prompt_store.read_with(cx, |prompt_store, cx| {
409 let prompts = prompt_store.default_prompt_metadata();
410 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
411 let contents = prompt_store.load(prompt_metadata.id, cx);
412 async move { (contents.await, prompt_metadata) }
413 });
414 cx.background_spawn(future::join_all(load_tasks))
415 })
416 } else {
417 Task::ready(vec![])
418 };
419
420 cx.spawn(async move |_cx| {
421 let (worktrees, default_user_rules) =
422 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
423
424 let worktrees = worktrees
425 .into_iter()
426 .map(|(worktree, _rules_error)| {
427 // TODO: show error message
428 // if let Some(rules_error) = rules_error {
429 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
430 // }
431 worktree
432 })
433 .collect::<Vec<_>>();
434
435 let default_user_rules = default_user_rules
436 .into_iter()
437 .flat_map(|(contents, prompt_metadata)| match contents {
438 Ok(contents) => Some(UserRulesContext {
439 uuid: prompt_metadata.id.as_user()?,
440 title: prompt_metadata.title.map(|title| title.to_string()),
441 contents,
442 }),
443 Err(_err) => {
444 // TODO: show error message
445 // this.update(cx, |_, cx| {
446 // cx.emit(RulesLoadingError {
447 // message: format!("{err:?}").into(),
448 // });
449 // })
450 // .ok();
451 None
452 }
453 })
454 .collect::<Vec<_>>();
455
456 ProjectContext::new(worktrees, default_user_rules)
457 })
458 }
459
460 fn load_worktree_info_for_system_prompt(
461 worktree: Entity<Worktree>,
462 project: Entity<Project>,
463 cx: &mut App,
464 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
465 let tree = worktree.read(cx);
466 let root_name = tree.root_name_str().into();
467 let abs_path = tree.abs_path();
468
469 let mut context = WorktreeContext {
470 root_name,
471 abs_path,
472 rules_file: None,
473 };
474
475 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
476 let Some(rules_task) = rules_task else {
477 return Task::ready((context, None));
478 };
479
480 cx.spawn(async move |_| {
481 let (rules_file, rules_file_error) = match rules_task.await {
482 Ok(rules_file) => (Some(rules_file), None),
483 Err(err) => (
484 None,
485 Some(RulesLoadingError {
486 message: format!("{err}").into(),
487 }),
488 ),
489 };
490 context.rules_file = rules_file;
491 (context, rules_file_error)
492 })
493 }
494
495 fn load_worktree_rules_file(
496 worktree: Entity<Worktree>,
497 project: Entity<Project>,
498 cx: &mut App,
499 ) -> Option<Task<Result<RulesFileContext>>> {
500 let worktree = worktree.read(cx);
501 let worktree_id = worktree.id();
502 let selected_rules_file = RULES_FILE_NAMES
503 .into_iter()
504 .filter_map(|name| {
505 worktree
506 .entry_for_path(RelPath::unix(name).unwrap())
507 .filter(|entry| entry.is_file())
508 .map(|entry| entry.path.clone())
509 })
510 .next();
511
512 // Note that Cline supports `.clinerules` being a directory, but that is not currently
513 // supported. This doesn't seem to occur often in GitHub repositories.
514 selected_rules_file.map(|path_in_worktree| {
515 let project_path = ProjectPath {
516 worktree_id,
517 path: path_in_worktree.clone(),
518 };
519 let buffer_task =
520 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
521 let rope_task = cx.spawn(async move |cx| {
522 let buffer = buffer_task.await?;
523 let (project_entry_id, rope) = buffer.read_with(cx, |buffer, cx| {
524 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
525 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
526 })?;
527 anyhow::Ok((project_entry_id, rope))
528 });
529 // Build a string from the rope on a background thread.
530 cx.background_spawn(async move {
531 let (project_entry_id, rope) = rope_task.await?;
532 anyhow::Ok(RulesFileContext {
533 path_in_worktree,
534 text: rope.to_string().trim().to_string(),
535 project_entry_id: project_entry_id.to_usize(),
536 })
537 })
538 })
539 }
540
541 fn handle_thread_title_updated(
542 &mut self,
543 thread: Entity<Thread>,
544 _: &TitleUpdated,
545 cx: &mut Context<Self>,
546 ) {
547 let session_id = thread.read(cx).id();
548 let Some(session) = self.sessions.get(session_id) else {
549 return;
550 };
551 let thread = thread.downgrade();
552 let acp_thread = session.acp_thread.clone();
553 cx.spawn(async move |_, cx| {
554 let title = thread.read_with(cx, |thread, _| thread.title())?;
555 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
556 task.await
557 })
558 .detach_and_log_err(cx);
559 }
560
561 fn handle_thread_token_usage_updated(
562 &mut self,
563 thread: Entity<Thread>,
564 usage: &TokenUsageUpdated,
565 cx: &mut Context<Self>,
566 ) {
567 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
568 return;
569 };
570 session
571 .acp_thread
572 .update(cx, |acp_thread, cx| {
573 acp_thread.update_token_usage(usage.0.clone(), cx);
574 })
575 .ok();
576 }
577
578 fn handle_project_event(
579 &mut self,
580 _project: Entity<Project>,
581 event: &project::Event,
582 _cx: &mut Context<Self>,
583 ) {
584 match event {
585 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
586 self.project_context_needs_refresh.send(()).ok();
587 }
588 project::Event::WorktreeUpdatedEntries(_, items) => {
589 if items.iter().any(|(path, _, _)| {
590 RULES_FILE_NAMES
591 .iter()
592 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
593 }) {
594 self.project_context_needs_refresh.send(()).ok();
595 }
596 }
597 _ => {}
598 }
599 }
600
601 fn handle_prompts_updated_event(
602 &mut self,
603 _prompt_store: Entity<PromptStore>,
604 _event: &prompt_store::PromptsUpdatedEvent,
605 _cx: &mut Context<Self>,
606 ) {
607 self.project_context_needs_refresh.send(()).ok();
608 }
609
610 fn handle_models_updated_event(
611 &mut self,
612 _registry: Entity<LanguageModelRegistry>,
613 _event: &language_model::Event,
614 cx: &mut Context<Self>,
615 ) {
616 self.models.refresh_list(cx);
617
618 let registry = LanguageModelRegistry::read_global(cx);
619 let default_model = registry.default_model().map(|m| m.model);
620 let summarization_model = registry.thread_summary_model().map(|m| m.model);
621
622 for session in self.sessions.values_mut() {
623 session.thread.update(cx, |thread, cx| {
624 if thread.model().is_none()
625 && let Some(model) = default_model.clone()
626 {
627 thread.set_model(model, cx);
628 cx.notify();
629 }
630 thread.set_summarization_model(summarization_model.clone(), cx);
631 });
632 }
633 }
634
635 fn handle_context_server_store_updated(
636 &mut self,
637 _store: Entity<project::context_server_store::ContextServerStore>,
638 _event: &project::context_server_store::Event,
639 cx: &mut Context<Self>,
640 ) {
641 self.update_available_commands(cx);
642 }
643
644 fn handle_context_server_registry_event(
645 &mut self,
646 _registry: Entity<ContextServerRegistry>,
647 event: &ContextServerRegistryEvent,
648 cx: &mut Context<Self>,
649 ) {
650 match event {
651 ContextServerRegistryEvent::ToolsChanged => {}
652 ContextServerRegistryEvent::PromptsChanged => {
653 self.update_available_commands(cx);
654 }
655 }
656 }
657
658 fn update_available_commands(&self, cx: &mut Context<Self>) {
659 let available_commands = self.build_available_commands(cx);
660 for session in self.sessions.values() {
661 if let Some(acp_thread) = session.acp_thread.upgrade() {
662 acp_thread.update(cx, |thread, cx| {
663 thread
664 .handle_session_update(
665 acp::SessionUpdate::AvailableCommandsUpdate(
666 acp::AvailableCommandsUpdate::new(available_commands.clone()),
667 ),
668 cx,
669 )
670 .log_err();
671 });
672 }
673 }
674 }
675
676 fn build_available_commands(&self, cx: &App) -> Vec<acp::AvailableCommand> {
677 let registry = self.context_server_registry.read(cx);
678
679 let mut prompt_name_counts: HashMap<&str, usize> = HashMap::default();
680 for context_server_prompt in registry.prompts() {
681 *prompt_name_counts
682 .entry(context_server_prompt.prompt.name.as_str())
683 .or_insert(0) += 1;
684 }
685
686 registry
687 .prompts()
688 .flat_map(|context_server_prompt| {
689 let prompt = &context_server_prompt.prompt;
690
691 let should_prefix = prompt_name_counts
692 .get(prompt.name.as_str())
693 .copied()
694 .unwrap_or(0)
695 > 1;
696
697 let name = if should_prefix {
698 format!("{}.{}", context_server_prompt.server_id, prompt.name)
699 } else {
700 prompt.name.clone()
701 };
702
703 let mut command = acp::AvailableCommand::new(
704 name,
705 prompt.description.clone().unwrap_or_default(),
706 );
707
708 match prompt.arguments.as_deref() {
709 Some([arg]) => {
710 let hint = format!("<{}>", arg.name);
711
712 command = command.input(acp::AvailableCommandInput::Unstructured(
713 acp::UnstructuredCommandInput::new(hint),
714 ));
715 }
716 Some([]) | None => {}
717 Some(_) => {
718 // skip >1 argument commands since we don't support them yet
719 return None;
720 }
721 }
722
723 Some(command)
724 })
725 .collect()
726 }
727
728 pub fn load_thread(
729 &mut self,
730 id: acp::SessionId,
731 cx: &mut Context<Self>,
732 ) -> Task<Result<Entity<Thread>>> {
733 let database_future = ThreadsDatabase::connect(cx);
734 cx.spawn(async move |this, cx| {
735 let database = database_future.await.map_err(|err| anyhow!(err))?;
736 let db_thread = database
737 .load_thread(id.clone())
738 .await?
739 .with_context(|| format!("no thread found with ID: {id:?}"))?;
740
741 this.update(cx, |this, cx| {
742 let summarization_model = LanguageModelRegistry::read_global(cx)
743 .thread_summary_model()
744 .map(|c| c.model);
745
746 cx.new(|cx| {
747 let mut thread = Thread::from_db(
748 id.clone(),
749 db_thread,
750 this.project.clone(),
751 this.project_context.clone(),
752 this.context_server_registry.clone(),
753 this.templates.clone(),
754 cx,
755 );
756 thread.set_summarization_model(summarization_model, cx);
757 thread
758 })
759 })
760 })
761 }
762
763 pub fn open_thread(
764 &mut self,
765 id: acp::SessionId,
766 cx: &mut Context<Self>,
767 ) -> Task<Result<Entity<AcpThread>>> {
768 let task = self.load_thread(id, cx);
769 cx.spawn(async move |this, cx| {
770 let thread = task.await?;
771 let acp_thread =
772 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
773 let events = thread.update(cx, |thread, cx| thread.replay(cx));
774 cx.update(|cx| {
775 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
776 })
777 .await?;
778 Ok(acp_thread)
779 })
780 }
781
782 pub fn thread_summary(
783 &mut self,
784 id: acp::SessionId,
785 cx: &mut Context<Self>,
786 ) -> Task<Result<SharedString>> {
787 let thread = self.open_thread(id.clone(), cx);
788 cx.spawn(async move |this, cx| {
789 let acp_thread = thread.await?;
790 let result = this
791 .update(cx, |this, cx| {
792 this.sessions
793 .get(&id)
794 .unwrap()
795 .thread
796 .update(cx, |thread, cx| thread.summary(cx))
797 })?
798 .await
799 .context("Failed to generate summary")?;
800 drop(acp_thread);
801 Ok(result)
802 })
803 }
804
805 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
806 if thread.read(cx).is_empty() {
807 return;
808 }
809
810 let database_future = ThreadsDatabase::connect(cx);
811 let (id, db_thread) =
812 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
813 let Some(session) = self.sessions.get_mut(&id) else {
814 return;
815 };
816 let thread_store = self.thread_store.clone();
817 session.pending_save = cx.spawn(async move |_, cx| {
818 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
819 return;
820 };
821 let db_thread = db_thread.await;
822 database.save_thread(id, db_thread).await.log_err();
823 thread_store.update(cx, |store, cx| store.reload(cx));
824 });
825 }
826
827 fn send_mcp_prompt(
828 &self,
829 message_id: UserMessageId,
830 session_id: agent_client_protocol::SessionId,
831 prompt_name: String,
832 server_id: ContextServerId,
833 arguments: HashMap<String, String>,
834 original_content: Vec<acp::ContentBlock>,
835 cx: &mut Context<Self>,
836 ) -> Task<Result<acp::PromptResponse>> {
837 let server_store = self.context_server_registry.read(cx).server_store().clone();
838 let path_style = self.project.read(cx).path_style(cx);
839
840 cx.spawn(async move |this, cx| {
841 let prompt =
842 crate::get_prompt(&server_store, &server_id, &prompt_name, arguments, cx).await?;
843
844 let (acp_thread, thread) = this.update(cx, |this, _cx| {
845 let session = this
846 .sessions
847 .get(&session_id)
848 .context("Failed to get session")?;
849 anyhow::Ok((session.acp_thread.clone(), session.thread.clone()))
850 })??;
851
852 let mut last_is_user = true;
853
854 thread.update(cx, |thread, cx| {
855 thread.push_acp_user_block(
856 message_id,
857 original_content.into_iter().skip(1),
858 path_style,
859 cx,
860 );
861 });
862
863 for message in prompt.messages {
864 let context_server::types::PromptMessage { role, content } = message;
865 let block = mcp_message_content_to_acp_content_block(content);
866
867 match role {
868 context_server::types::Role::User => {
869 let id = acp_thread::UserMessageId::new();
870
871 acp_thread.update(cx, |acp_thread, cx| {
872 acp_thread.push_user_content_block_with_indent(
873 Some(id.clone()),
874 block.clone(),
875 true,
876 cx,
877 );
878 })?;
879
880 thread.update(cx, |thread, cx| {
881 thread.push_acp_user_block(id, [block], path_style, cx);
882 });
883 }
884 context_server::types::Role::Assistant => {
885 acp_thread.update(cx, |acp_thread, cx| {
886 acp_thread.push_assistant_content_block_with_indent(
887 block.clone(),
888 false,
889 true,
890 cx,
891 );
892 })?;
893
894 thread.update(cx, |thread, cx| {
895 thread.push_acp_agent_block(block, cx);
896 });
897 }
898 }
899
900 last_is_user = role == context_server::types::Role::User;
901 }
902
903 let response_stream = thread.update(cx, |thread, cx| {
904 if last_is_user {
905 thread.send_existing(cx)
906 } else {
907 // Resume if MCP prompt did not end with a user message
908 thread.resume(cx)
909 }
910 })?;
911
912 cx.update(|cx| {
913 NativeAgentConnection::handle_thread_events(response_stream, acp_thread, cx)
914 })
915 .await
916 })
917 }
918}
919
/// Wrapper struct that implements the AgentConnection trait.
///
/// A cheap, cloneable handle around the shared [`NativeAgent`] entity; each
/// registered session's [`AcpThread`] is given one as its connection.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
923
924impl NativeAgentConnection {
925 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
926 self.0
927 .read(cx)
928 .sessions
929 .get(session_id)
930 .map(|session| session.thread.clone())
931 }
932
933 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
934 self.0.update(cx, |this, cx| this.load_thread(id, cx))
935 }
936
937 fn run_turn(
938 &self,
939 session_id: acp::SessionId,
940 cx: &mut App,
941 f: impl 'static
942 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
943 ) -> Task<Result<acp::PromptResponse>> {
944 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
945 agent
946 .sessions
947 .get_mut(&session_id)
948 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
949 }) else {
950 return Task::ready(Err(anyhow!("Session not found")));
951 };
952 log::debug!("Found session for: {}", session_id);
953
954 let response_stream = match f(thread, cx) {
955 Ok(stream) => stream,
956 Err(err) => return Task::ready(Err(err)),
957 };
958 Self::handle_thread_events(response_stream, acp_thread, cx)
959 }
960
961 fn handle_thread_events(
962 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
963 acp_thread: WeakEntity<AcpThread>,
964 cx: &App,
965 ) -> Task<Result<acp::PromptResponse>> {
966 cx.spawn(async move |cx| {
967 // Handle response stream and forward to session.acp_thread
968 while let Some(result) = events.next().await {
969 match result {
970 Ok(event) => {
971 log::trace!("Received completion event: {:?}", event);
972
973 match event {
974 ThreadEvent::UserMessage(message) => {
975 acp_thread.update(cx, |thread, cx| {
976 for content in message.content {
977 thread.push_user_content_block(
978 Some(message.id.clone()),
979 content.into(),
980 cx,
981 );
982 }
983 })?;
984 }
985 ThreadEvent::AgentText(text) => {
986 acp_thread.update(cx, |thread, cx| {
987 thread.push_assistant_content_block(text.into(), false, cx)
988 })?;
989 }
990 ThreadEvent::AgentThinking(text) => {
991 acp_thread.update(cx, |thread, cx| {
992 thread.push_assistant_content_block(text.into(), true, cx)
993 })?;
994 }
995 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
996 tool_call,
997 options,
998 response,
999 context: _,
1000 }) => {
1001 let outcome_task = acp_thread.update(cx, |thread, cx| {
1002 thread.request_tool_call_authorization(
1003 tool_call, options, true, cx,
1004 )
1005 })??;
1006 cx.background_spawn(async move {
1007 if let acp::RequestPermissionOutcome::Selected(
1008 acp::SelectedPermissionOutcome { option_id, .. },
1009 ) = outcome_task.await
1010 {
1011 response
1012 .send(option_id)
1013 .map(|_| anyhow!("authorization receiver was dropped"))
1014 .log_err();
1015 }
1016 })
1017 .detach();
1018 }
1019 ThreadEvent::ToolCall(tool_call) => {
1020 acp_thread.update(cx, |thread, cx| {
1021 thread.upsert_tool_call(tool_call, cx)
1022 })??;
1023 }
1024 ThreadEvent::ToolCallUpdate(update) => {
1025 acp_thread.update(cx, |thread, cx| {
1026 thread.update_tool_call(update, cx)
1027 })??;
1028 }
1029 ThreadEvent::Retry(status) => {
1030 acp_thread.update(cx, |thread, cx| {
1031 thread.update_retry_status(status, cx)
1032 })?;
1033 }
1034 ThreadEvent::Stop(stop_reason) => {
1035 log::debug!("Assistant message complete: {:?}", stop_reason);
1036 return Ok(acp::PromptResponse::new(stop_reason));
1037 }
1038 }
1039 }
1040 Err(e) => {
1041 log::error!("Error in model response stream: {:?}", e);
1042 return Err(e);
1043 }
1044 }
1045 }
1046
1047 log::debug!("Response stream completed");
1048 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
1049 })
1050 }
1051}
1052
/// A slash command parsed from the first block of a prompt, borrowing from
/// the prompt text (see `Command::parse`).
struct Command<'a> {
    /// Prompt name: the text after `/` (and after the first `.`, if any) up to
    /// the first whitespace.
    prompt_name: &'a str,
    /// Everything after the first whitespace following the command; empty if none.
    arg_value: &'a str,
    /// Server id given explicitly as `/server.prompt`, if any.
    explicit_server_id: Option<&'a str>,
}
1058
1059impl<'a> Command<'a> {
1060 fn parse(prompt: &'a [acp::ContentBlock]) -> Option<Self> {
1061 let acp::ContentBlock::Text(text_content) = prompt.first()? else {
1062 return None;
1063 };
1064 let text = text_content.text.trim();
1065 let command = text.strip_prefix('/')?;
1066 let (command, arg_value) = command
1067 .split_once(char::is_whitespace)
1068 .unwrap_or((command, ""));
1069
1070 if let Some((server_id, prompt_name)) = command.split_once('.') {
1071 Some(Self {
1072 prompt_name,
1073 arg_value,
1074 explicit_server_id: Some(server_id),
1075 })
1076 } else {
1077 Some(Self {
1078 prompt_name: command,
1079 arg_value,
1080 explicit_server_id: None,
1081 })
1082 }
1083 }
1084}
1085
/// Model selector bound to a single session of a [`NativeAgent`].
struct NativeAgentModelSelector {
    /// Session whose thread's model is read and changed.
    session_id: acp::SessionId,
    /// Connection used to reach the agent's sessions, model cache and fs.
    connection: NativeAgentConnection,
}
1090
1091impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
1092 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
1093 log::debug!("NativeAgentConnection::list_models called");
1094 let list = self.connection.0.read(cx).models.model_list.clone();
1095 Task::ready(if list.is_empty() {
1096 Err(anyhow::anyhow!("No models available"))
1097 } else {
1098 Ok(list)
1099 })
1100 }
1101
1102 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
1103 log::debug!(
1104 "Setting model for session {}: {}",
1105 self.session_id,
1106 model_id
1107 );
1108 let Some(thread) = self
1109 .connection
1110 .0
1111 .read(cx)
1112 .sessions
1113 .get(&self.session_id)
1114 .map(|session| session.thread.clone())
1115 else {
1116 return Task::ready(Err(anyhow!("Session not found")));
1117 };
1118
1119 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
1120 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
1121 };
1122
1123 thread.update(cx, |thread, cx| {
1124 thread.set_model(model.clone(), cx);
1125 });
1126
1127 update_settings_file(
1128 self.connection.0.read(cx).fs.clone(),
1129 cx,
1130 move |settings, _cx| {
1131 let provider = model.provider_id().0.to_string();
1132 let model = model.id().0.to_string();
1133 settings
1134 .agent
1135 .get_or_insert_default()
1136 .set_model(LanguageModelSelection {
1137 provider: provider.into(),
1138 model,
1139 });
1140 },
1141 );
1142
1143 Task::ready(Ok(()))
1144 }
1145
1146 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
1147 let Some(thread) = self
1148 .connection
1149 .0
1150 .read(cx)
1151 .sessions
1152 .get(&self.session_id)
1153 .map(|session| session.thread.clone())
1154 else {
1155 return Task::ready(Err(anyhow!("Session not found")));
1156 };
1157 let Some(model) = thread.read(cx).model() else {
1158 return Task::ready(Err(anyhow!("Model not found")));
1159 };
1160 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
1161 else {
1162 return Task::ready(Err(anyhow!("Provider not found")));
1163 };
1164 Task::ready(Ok(LanguageModels::map_language_model_to_info(
1165 model, &provider,
1166 )))
1167 }
1168
1169 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
1170 Some(self.connection.0.read(cx).models.watch())
1171 }
1172
1173 fn should_render_footer(&self) -> bool {
1174 true
1175 }
1176}
1177
1178impl acp_thread::AgentConnection for NativeAgentConnection {
1179 fn telemetry_id(&self) -> SharedString {
1180 "zed".into()
1181 }
1182
1183 fn new_thread(
1184 self: Rc<Self>,
1185 project: Entity<Project>,
1186 cwd: &Path,
1187 cx: &mut App,
1188 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1189 let agent = self.0.clone();
1190 log::debug!("Creating new thread for project at: {:?}", cwd);
1191
1192 cx.spawn(async move |cx| {
1193 log::debug!("Starting thread creation in async context");
1194
1195 // Create Thread
1196 let thread = agent.update(cx, |agent, cx| {
1197 // Fetch default model from registry settings
1198 let registry = LanguageModelRegistry::read_global(cx);
1199 // Log available models for debugging
1200 let available_count = registry.available_models(cx).count();
1201 log::debug!("Total available models: {}", available_count);
1202
1203 let default_model = registry.default_model().and_then(|default_model| {
1204 agent
1205 .models
1206 .model_from_id(&LanguageModels::model_id(&default_model.model))
1207 });
1208 cx.new(|cx| {
1209 Thread::new(
1210 project.clone(),
1211 agent.project_context.clone(),
1212 agent.context_server_registry.clone(),
1213 agent.templates.clone(),
1214 default_model,
1215 cx,
1216 )
1217 })
1218 });
1219 Ok(agent.update(cx, |agent, cx| agent.register_session(thread, cx)))
1220 })
1221 }
1222
1223 fn supports_load_session(&self, _cx: &App) -> bool {
1224 true
1225 }
1226
1227 fn load_session(
1228 self: Rc<Self>,
1229 session: AgentSessionInfo,
1230 _project: Entity<Project>,
1231 _cwd: &Path,
1232 cx: &mut App,
1233 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
1234 self.0
1235 .update(cx, |agent, cx| agent.open_thread(session.session_id, cx))
1236 }
1237
1238 fn auth_methods(&self) -> &[acp::AuthMethod] {
1239 &[] // No auth for in-process
1240 }
1241
1242 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1243 Task::ready(Ok(()))
1244 }
1245
1246 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1247 Some(Rc::new(NativeAgentModelSelector {
1248 session_id: session_id.clone(),
1249 connection: self.clone(),
1250 }) as Rc<dyn AgentModelSelector>)
1251 }
1252
1253 fn prompt(
1254 &self,
1255 id: Option<acp_thread::UserMessageId>,
1256 params: acp::PromptRequest,
1257 cx: &mut App,
1258 ) -> Task<Result<acp::PromptResponse>> {
1259 let id = id.expect("UserMessageId is required");
1260 let session_id = params.session_id.clone();
1261 log::info!("Received prompt request for session: {}", session_id);
1262 log::debug!("Prompt blocks count: {}", params.prompt.len());
1263
1264 if let Some(parsed_command) = Command::parse(¶ms.prompt) {
1265 let registry = self.0.read(cx).context_server_registry.read(cx);
1266
1267 let explicit_server_id = parsed_command
1268 .explicit_server_id
1269 .map(|server_id| ContextServerId(server_id.into()));
1270
1271 if let Some(prompt) =
1272 registry.find_prompt(explicit_server_id.as_ref(), parsed_command.prompt_name)
1273 {
1274 let arguments = if !parsed_command.arg_value.is_empty()
1275 && let Some(arg_name) = prompt
1276 .prompt
1277 .arguments
1278 .as_ref()
1279 .and_then(|args| args.first())
1280 .map(|arg| arg.name.clone())
1281 {
1282 HashMap::from_iter([(arg_name, parsed_command.arg_value.to_string())])
1283 } else {
1284 Default::default()
1285 };
1286
1287 let prompt_name = prompt.prompt.name.clone();
1288 let server_id = prompt.server_id.clone();
1289
1290 return self.0.update(cx, |agent, cx| {
1291 agent.send_mcp_prompt(
1292 id,
1293 session_id.clone(),
1294 prompt_name,
1295 server_id,
1296 arguments,
1297 params.prompt,
1298 cx,
1299 )
1300 });
1301 };
1302 };
1303
1304 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1305
1306 self.run_turn(session_id, cx, move |thread, cx| {
1307 let content: Vec<UserMessageContent> = params
1308 .prompt
1309 .into_iter()
1310 .map(|block| UserMessageContent::from_content_block(block, path_style))
1311 .collect::<Vec<_>>();
1312 log::debug!("Converted prompt to message: {} chars", content.len());
1313 log::debug!("Message id: {:?}", id);
1314 log::debug!("Message content: {:?}", content);
1315
1316 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1317 })
1318 }
1319
1320 fn retry(
1321 &self,
1322 session_id: &acp::SessionId,
1323 _cx: &App,
1324 ) -> Option<Rc<dyn acp_thread::AgentSessionRetry>> {
1325 Some(Rc::new(NativeAgentSessionRetry {
1326 connection: self.clone(),
1327 session_id: session_id.clone(),
1328 }) as _)
1329 }
1330
1331 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1332 log::info!("Cancelling on session: {}", session_id);
1333 self.0.update(cx, |agent, cx| {
1334 if let Some(agent) = agent.sessions.get(session_id) {
1335 agent
1336 .thread
1337 .update(cx, |thread, cx| thread.cancel(cx))
1338 .detach();
1339 }
1340 });
1341 }
1342
1343 fn truncate(
1344 &self,
1345 session_id: &agent_client_protocol::SessionId,
1346 cx: &App,
1347 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1348 self.0.read_with(cx, |agent, _cx| {
1349 agent.sessions.get(session_id).map(|session| {
1350 Rc::new(NativeAgentSessionTruncate {
1351 thread: session.thread.clone(),
1352 acp_thread: session.acp_thread.clone(),
1353 }) as _
1354 })
1355 })
1356 }
1357
1358 fn set_title(
1359 &self,
1360 session_id: &acp::SessionId,
1361 _cx: &App,
1362 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1363 Some(Rc::new(NativeAgentSessionSetTitle {
1364 connection: self.clone(),
1365 session_id: session_id.clone(),
1366 }) as _)
1367 }
1368
1369 fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
1370 let thread_store = self.0.read(cx).thread_store.clone();
1371 Some(Rc::new(NativeAgentSessionList::new(thread_store, cx)) as _)
1372 }
1373
1374 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1375 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1376 }
1377
1378 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1379 self
1380 }
1381}
1382
1383impl acp_thread::AgentTelemetry for NativeAgentConnection {
1384 fn thread_data(
1385 &self,
1386 session_id: &acp::SessionId,
1387 cx: &mut App,
1388 ) -> Task<Result<serde_json::Value>> {
1389 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1390 return Task::ready(Err(anyhow!("Session not found")));
1391 };
1392
1393 let task = session.thread.read(cx).to_db(cx);
1394 cx.background_spawn(async move {
1395 serde_json::to_value(task.await).context("Failed to serialize thread")
1396 })
1397 }
1398}
1399
1400pub struct NativeAgentSessionList {
1401 thread_store: Entity<ThreadStore>,
1402 updates_tx: smol::channel::Sender<acp_thread::SessionListUpdate>,
1403 updates_rx: smol::channel::Receiver<acp_thread::SessionListUpdate>,
1404 _subscription: Subscription,
1405}
1406
1407impl NativeAgentSessionList {
1408 fn new(thread_store: Entity<ThreadStore>, cx: &mut App) -> Self {
1409 let (tx, rx) = smol::channel::unbounded();
1410 let this_tx = tx.clone();
1411 let subscription = cx.observe(&thread_store, move |_, _| {
1412 this_tx
1413 .try_send(acp_thread::SessionListUpdate::Refresh)
1414 .ok();
1415 });
1416 Self {
1417 thread_store,
1418 updates_tx: tx,
1419 updates_rx: rx,
1420 _subscription: subscription,
1421 }
1422 }
1423
1424 fn to_session_info(entry: DbThreadMetadata) -> AgentSessionInfo {
1425 AgentSessionInfo {
1426 session_id: entry.id,
1427 cwd: None,
1428 title: Some(entry.title),
1429 updated_at: Some(entry.updated_at),
1430 meta: None,
1431 }
1432 }
1433
1434 pub fn thread_store(&self) -> &Entity<ThreadStore> {
1435 &self.thread_store
1436 }
1437}
1438
1439impl AgentSessionList for NativeAgentSessionList {
1440 fn list_sessions(
1441 &self,
1442 _request: AgentSessionListRequest,
1443 cx: &mut App,
1444 ) -> Task<Result<AgentSessionListResponse>> {
1445 let sessions = self
1446 .thread_store
1447 .read(cx)
1448 .entries()
1449 .map(Self::to_session_info)
1450 .collect();
1451 Task::ready(Ok(AgentSessionListResponse::new(sessions)))
1452 }
1453
1454 fn supports_delete(&self) -> bool {
1455 true
1456 }
1457
1458 fn delete_session(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<()>> {
1459 self.thread_store
1460 .update(cx, |store, cx| store.delete_thread(session_id.clone(), cx))
1461 }
1462
1463 fn delete_sessions(&self, cx: &mut App) -> Task<Result<()>> {
1464 self.thread_store
1465 .update(cx, |store, cx| store.delete_threads(cx))
1466 }
1467
1468 fn watch(
1469 &self,
1470 _cx: &mut App,
1471 ) -> Option<smol::channel::Receiver<acp_thread::SessionListUpdate>> {
1472 Some(self.updates_rx.clone())
1473 }
1474
1475 fn notify_refresh(&self) {
1476 self.updates_tx
1477 .try_send(acp_thread::SessionListUpdate::Refresh)
1478 .ok();
1479 }
1480
1481 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1482 self
1483 }
1484}
1485
1486struct NativeAgentSessionTruncate {
1487 thread: Entity<Thread>,
1488 acp_thread: WeakEntity<AcpThread>,
1489}
1490
1491impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1492 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1493 match self.thread.update(cx, |thread, cx| {
1494 thread.truncate(message_id.clone(), cx)?;
1495 Ok(thread.latest_token_usage())
1496 }) {
1497 Ok(usage) => {
1498 self.acp_thread
1499 .update(cx, |thread, cx| {
1500 thread.update_token_usage(usage, cx);
1501 })
1502 .ok();
1503 Task::ready(Ok(()))
1504 }
1505 Err(error) => Task::ready(Err(error)),
1506 }
1507 }
1508}
1509
1510struct NativeAgentSessionRetry {
1511 connection: NativeAgentConnection,
1512 session_id: acp::SessionId,
1513}
1514
1515impl acp_thread::AgentSessionRetry for NativeAgentSessionRetry {
1516 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1517 self.connection
1518 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1519 thread.update(cx, |thread, cx| thread.resume(cx))
1520 })
1521 }
1522}
1523
1524struct NativeAgentSessionSetTitle {
1525 connection: NativeAgentConnection,
1526 session_id: acp::SessionId,
1527}
1528
1529impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1530 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1531 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1532 return Task::ready(Err(anyhow!("session not found")));
1533 };
1534 let thread = session.thread.clone();
1535 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1536 Task::ready(Ok(()))
1537 }
1538}
1539
1540pub struct AcpThreadEnvironment {
1541 acp_thread: WeakEntity<AcpThread>,
1542}
1543
1544impl ThreadEnvironment for AcpThreadEnvironment {
1545 fn create_terminal(
1546 &self,
1547 command: String,
1548 cwd: Option<PathBuf>,
1549 output_byte_limit: Option<u64>,
1550 cx: &mut AsyncApp,
1551 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1552 let task = self.acp_thread.update(cx, |thread, cx| {
1553 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1554 });
1555
1556 let acp_thread = self.acp_thread.clone();
1557 cx.spawn(async move |cx| {
1558 let terminal = task?.await?;
1559
1560 let (drop_tx, drop_rx) = oneshot::channel();
1561 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone());
1562
1563 cx.spawn(async move |cx| {
1564 drop_rx.await.ok();
1565 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1566 })
1567 .detach();
1568
1569 let handle = AcpTerminalHandle {
1570 terminal,
1571 _drop_tx: Some(drop_tx),
1572 };
1573
1574 Ok(Rc::new(handle) as _)
1575 })
1576 }
1577}
1578
1579pub struct AcpTerminalHandle {
1580 terminal: Entity<acp_thread::Terminal>,
1581 _drop_tx: Option<oneshot::Sender<()>>,
1582}
1583
1584impl TerminalHandle for AcpTerminalHandle {
1585 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1586 Ok(self.terminal.read_with(cx, |term, _cx| term.id().clone()))
1587 }
1588
1589 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1590 Ok(self
1591 .terminal
1592 .read_with(cx, |term, _cx| term.wait_for_exit()))
1593 }
1594
1595 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1596 Ok(self
1597 .terminal
1598 .read_with(cx, |term, cx| term.current_output(cx)))
1599 }
1600
1601 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1602 cx.update(|cx| {
1603 self.terminal.update(cx, |terminal, cx| {
1604 terminal.kill(cx);
1605 });
1606 });
1607 Ok(())
1608 }
1609
1610 fn was_stopped_by_user(&self, cx: &AsyncApp) -> Result<bool> {
1611 Ok(self
1612 .terminal
1613 .read_with(cx, |term, _cx| term.was_stopped_by_user()))
1614 }
1615}
1616
// Tests that reach into `NativeAgent` internals (its `project_context` and
// `sessions` map) alongside the public connection API.
#[cfg(test)]
mod internal_tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    /// The agent's project context starts empty. It then gains an entry when a
    /// worktree is added. When `.rules` is created in that worktree, the entry
    /// records the rules file.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // No worktrees yet, so no worktree context.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// A session's model selector lists the test registry's single fake model,
    /// grouped under "Fake".
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                thread_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(acp_thread::AgentModelIcon::Named(
                        ui::IconName::ZedAssistant
                    )),
                }]
            )])
        );
    }

    /// Selecting a model switches the session's thread to it. It also rewrites
    /// `agent.default_model` in the settings file, overwriting the initial
    /// `foo`/`bar` value.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed a settings file whose default model differs from the one we select.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let thread_store = cx.new(|cx| ThreadStore::new(cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            thread_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings-file write complete before reading it back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// The test covers the save/reload round trip:
    /// - An empty thread is not saved, even after its models are changed.
    /// - After one completed turn, the thread is listed in the store under its
    ///   summarized title.
    /// - Dropping the ACP thread removes the session.
    /// - Reopening the session restores the same conversation.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let thread_store = cx.new(|cx| ThreadStore::new(cx));
        let agent = NativeAgent::new(
            project.clone(),
            thread_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(thread_entries(&thread_store, cx), vec![]);

        // Send a prompt that mentions `b.md` via a resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Drive the fake completion, then the fake summary that becomes the title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            thread_entries(&thread_store, cx),
            vec![(
                session_id.clone(),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    /// Returns `(session id, title)` for every entry in `thread_store`.
    fn thread_entries(
        thread_store: &Entity<ThreadStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(acp::SessionId, String)> {
        thread_store.read_with(cx, |store, _| {
            store
                .entries()
                .map(|entry| (entry.id.clone(), entry.title.to_string()))
                .collect::<Vec<_>>()
        })
    }

    /// Shared setup: logging (if not already initialized), a test settings
    /// store, and the test language model registry.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            LanguageModelRegistry::test(cx);
        });
    }
}
1994
1995fn mcp_message_content_to_acp_content_block(
1996 content: context_server::types::MessageContent,
1997) -> acp::ContentBlock {
1998 match content {
1999 context_server::types::MessageContent::Text {
2000 text,
2001 annotations: _,
2002 } => text.into(),
2003 context_server::types::MessageContent::Image {
2004 data,
2005 mime_type,
2006 annotations: _,
2007 } => acp::ContentBlock::Image(acp::ImageContent::new(data, mime_type)),
2008 context_server::types::MessageContent::Audio {
2009 data,
2010 mime_type,
2011 annotations: _,
2012 } => acp::ContentBlock::Audio(acp::AudioContent::new(data, mime_type)),
2013 context_server::types::MessageContent::Resource {
2014 resource,
2015 annotations: _,
2016 } => {
2017 let mut link =
2018 acp::ResourceLink::new(resource.uri.to_string(), resource.uri.to_string());
2019 if let Some(mime_type) = resource.mime_type {
2020 link = link.mime_type(mime_type);
2021 }
2022 acp::ContentBlock::ResourceLink(link)
2023 }
2024 }
2025}