1use crate::{
2 ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
3 UserMessageContent, templates::Templates,
4};
5use crate::{HistoryStore, TitleUpdated, TokenUsageUpdated};
6use acp_thread::{AcpThread, AgentModelSelector};
7use action_log::ActionLog;
8use agent_client_protocol as acp;
9use agent_settings::AgentSettings;
10use anyhow::{Context as _, Result, anyhow};
11use collections::{HashSet, IndexMap};
12use fs::Fs;
13use futures::channel::mpsc;
14use futures::{StreamExt, future};
15use gpui::{
16 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
17};
18use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
19use project::{Project, ProjectItem, ProjectPath, Worktree};
20use prompt_store::{
21 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
22};
23use settings::update_settings_file;
24use std::any::Any;
25use std::collections::HashMap;
26use std::path::Path;
27use std::rc::Rc;
28use std::sync::Arc;
29use util::ResultExt;
30
/// Candidate rules-file paths, relative to a worktree, in priority order.
///
/// `load_worktree_rules_file` picks the first entry that exists as a file in a
/// worktree, so the order of this list matters. `handle_project_event` also
/// uses it to decide whether an updated worktree entry should trigger a
/// project-context refresh.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
42
/// A failure to load rules for the system prompt.
///
/// Currently produced when a worktree's rules file cannot be read; `message`
/// is the formatted underlying error.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}
46
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication. Held weakly; when it
    /// is released the session is removed from `NativeAgent::sessions`.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// The most recent background save started by `NativeAgent::save_thread`.
    /// Each new save replaces (and thereby drops) the previous task.
    pending_save: Task<()>,
    /// Subscriptions created in `NativeAgent::register_session`, kept alive for
    /// as long as the session exists.
    _subscriptions: Vec<Subscription>,
}
56
/// Cache of the language models offered by authenticated providers.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Handed out by `watch()`; signalled each time the list is refreshed.
    refresh_models_rx: watch::Receiver<()>,
    /// Sending side of the refresh notification.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating every registered provider; kept so the
    /// task stays alive alongside this cache.
    _authenticate_all_providers_task: Task<()>,
}
66
67impl LanguageModels {
68 fn new(cx: &mut App) -> Self {
69 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
70
71 let mut this = Self {
72 models: HashMap::default(),
73 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
74 refresh_models_rx,
75 refresh_models_tx,
76 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
77 };
78 this.refresh_list(cx);
79 this
80 }
81
82 fn refresh_list(&mut self, cx: &App) {
83 let providers = LanguageModelRegistry::global(cx)
84 .read(cx)
85 .providers()
86 .into_iter()
87 .filter(|provider| provider.is_authenticated(cx))
88 .collect::<Vec<_>>();
89
90 let mut language_model_list = IndexMap::default();
91 let mut recommended_models = HashSet::default();
92
93 let mut recommended = Vec::new();
94 for provider in &providers {
95 for model in provider.recommended_models(cx) {
96 recommended_models.insert((model.provider_id(), model.id()));
97 recommended.push(Self::map_language_model_to_info(&model, provider));
98 }
99 }
100 if !recommended.is_empty() {
101 language_model_list.insert(
102 acp_thread::AgentModelGroupName("Recommended".into()),
103 recommended,
104 );
105 }
106
107 let mut models = HashMap::default();
108 for provider in providers {
109 let mut provider_models = Vec::new();
110 for model in provider.provided_models(cx) {
111 let model_info = Self::map_language_model_to_info(&model, &provider);
112 let model_id = model_info.id.clone();
113 if !recommended_models.contains(&(model.provider_id(), model.id())) {
114 provider_models.push(model_info);
115 }
116 models.insert(model_id, model);
117 }
118 if !provider_models.is_empty() {
119 language_model_list.insert(
120 acp_thread::AgentModelGroupName(provider.name().0.clone()),
121 provider_models,
122 );
123 }
124 }
125
126 self.models = models;
127 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
128 self.refresh_models_tx.send(()).ok();
129 }
130
131 fn watch(&self) -> watch::Receiver<()> {
132 self.refresh_models_rx.clone()
133 }
134
135 pub fn model_from_id(
136 &self,
137 model_id: &acp_thread::AgentModelId,
138 ) -> Option<Arc<dyn LanguageModel>> {
139 self.models.get(model_id).cloned()
140 }
141
142 fn map_language_model_to_info(
143 model: &Arc<dyn LanguageModel>,
144 provider: &Arc<dyn LanguageModelProvider>,
145 ) -> acp_thread::AgentModelInfo {
146 acp_thread::AgentModelInfo {
147 id: Self::model_id(model),
148 name: model.name().0,
149 icon: Some(provider.icon()),
150 }
151 }
152
153 fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
154 acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
155 }
156
157 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
158 let authenticate_all_providers = LanguageModelRegistry::global(cx)
159 .read(cx)
160 .providers()
161 .iter()
162 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
163 .collect::<Vec<_>>();
164
165 cx.background_spawn(async move {
166 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
167 if let Err(err) = authenticate_task.await {
168 if matches!(err, language_model::AuthenticateError::CredentialsNotFound) {
169 // Since we're authenticating these providers in the
170 // background for the purposes of populating the
171 // language selector, we don't care about providers
172 // where the credentials are not found.
173 } else {
174 // Some providers have noisy failure states that we
175 // don't want to spam the logs with every time the
176 // language model selector is initialized.
177 //
178 // Ideally these should have more clear failure modes
179 // that we know are safe to ignore here, like what we do
180 // with `CredentialsNotFound` above.
181 match provider_id.0.as_ref() {
182 "lmstudio" | "ollama" => {
183 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
184 //
185 // These fail noisily, so we don't log them.
186 }
187 "copilot_chat" => {
188 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
189 }
190 _ => {
191 log::error!(
192 "Failed to authenticate provider: {}: {err}",
193 provider_name.0
194 );
195 }
196 }
197 }
198 }
199 }
200 })
201 }
202}
203
/// The in-process ("native") agent: owns every open session, the project
/// context shared with its threads, and the cached language-model list.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread history; reloaded after each successful save.
    history: Entity<HistoryStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalled whenever the project context should be rebuilt.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background loop that rebuilds the project context on demand.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context servers, handed to every thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Source of default user rules, if available.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle used when persisting the selected model to settings.
    fs: Arc<dyn Fs>,
    /// Project, model-registry, and prompt-store subscriptions.
    _subscriptions: Vec<Subscription>,
}
222
223impl NativeAgent {
224 pub async fn new(
225 project: Entity<Project>,
226 history: Entity<HistoryStore>,
227 templates: Arc<Templates>,
228 prompt_store: Option<Entity<PromptStore>>,
229 fs: Arc<dyn Fs>,
230 cx: &mut AsyncApp,
231 ) -> Result<Entity<NativeAgent>> {
232 log::debug!("Creating new NativeAgent");
233
234 let project_context = cx
235 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
236 .await;
237
238 cx.new(|cx| {
239 let mut subscriptions = vec![
240 cx.subscribe(&project, Self::handle_project_event),
241 cx.subscribe(
242 &LanguageModelRegistry::global(cx),
243 Self::handle_models_updated_event,
244 ),
245 ];
246 if let Some(prompt_store) = prompt_store.as_ref() {
247 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
248 }
249
250 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
251 watch::channel(());
252 Self {
253 sessions: HashMap::new(),
254 history,
255 project_context: cx.new(|_| project_context),
256 project_context_needs_refresh: project_context_needs_refresh_tx,
257 _maintain_project_context: cx.spawn(async move |this, cx| {
258 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
259 }),
260 context_server_registry: cx.new(|cx| {
261 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
262 }),
263 templates,
264 models: LanguageModels::new(cx),
265 project,
266 prompt_store,
267 fs,
268 _subscriptions: subscriptions,
269 }
270 })
271 }
272
273 fn register_session(
274 &mut self,
275 thread_handle: Entity<Thread>,
276 cx: &mut Context<Self>,
277 ) -> Entity<AcpThread> {
278 let connection = Rc::new(NativeAgentConnection(cx.entity()));
279 let registry = LanguageModelRegistry::read_global(cx);
280 let summarization_model = registry.thread_summary_model().map(|c| c.model);
281
282 thread_handle.update(cx, |thread, cx| {
283 thread.set_summarization_model(summarization_model, cx);
284 thread.add_default_tools(cx)
285 });
286
287 let thread = thread_handle.read(cx);
288 let session_id = thread.id().clone();
289 let title = thread.title();
290 let project = thread.project.clone();
291 let action_log = thread.action_log.clone();
292 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
293 let acp_thread = cx.new(|cx| {
294 acp_thread::AcpThread::new(
295 title,
296 connection,
297 project.clone(),
298 action_log.clone(),
299 session_id.clone(),
300 prompt_capabilities_rx,
301 cx,
302 )
303 });
304 let subscriptions = vec![
305 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
306 this.sessions.remove(acp_thread.session_id());
307 }),
308 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
309 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
310 cx.observe(&thread_handle, move |this, thread, cx| {
311 this.save_thread(thread, cx)
312 }),
313 ];
314
315 self.sessions.insert(
316 session_id,
317 Session {
318 thread: thread_handle,
319 acp_thread: acp_thread.downgrade(),
320 _subscriptions: subscriptions,
321 pending_save: Task::ready(()),
322 },
323 );
324 acp_thread
325 }
326
327 pub fn models(&self) -> &LanguageModels {
328 &self.models
329 }
330
331 async fn maintain_project_context(
332 this: WeakEntity<Self>,
333 mut needs_refresh: watch::Receiver<()>,
334 cx: &mut AsyncApp,
335 ) -> Result<()> {
336 while needs_refresh.changed().await.is_ok() {
337 let project_context = this
338 .update(cx, |this, cx| {
339 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
340 })?
341 .await;
342 this.update(cx, |this, cx| {
343 this.project_context = cx.new(|_| project_context);
344 })?;
345 }
346
347 Ok(())
348 }
349
350 fn build_project_context(
351 project: &Entity<Project>,
352 prompt_store: Option<&Entity<PromptStore>>,
353 cx: &mut App,
354 ) -> Task<ProjectContext> {
355 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
356 let worktree_tasks = worktrees
357 .into_iter()
358 .map(|worktree| {
359 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
360 })
361 .collect::<Vec<_>>();
362 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
363 prompt_store.read_with(cx, |prompt_store, cx| {
364 let prompts = prompt_store.default_prompt_metadata();
365 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
366 let contents = prompt_store.load(prompt_metadata.id, cx);
367 async move { (contents.await, prompt_metadata) }
368 });
369 cx.background_spawn(future::join_all(load_tasks))
370 })
371 } else {
372 Task::ready(vec![])
373 };
374
375 cx.spawn(async move |_cx| {
376 let (worktrees, default_user_rules) =
377 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
378
379 let worktrees = worktrees
380 .into_iter()
381 .map(|(worktree, _rules_error)| {
382 // TODO: show error message
383 // if let Some(rules_error) = rules_error {
384 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
385 // }
386 worktree
387 })
388 .collect::<Vec<_>>();
389
390 let default_user_rules = default_user_rules
391 .into_iter()
392 .flat_map(|(contents, prompt_metadata)| match contents {
393 Ok(contents) => Some(UserRulesContext {
394 uuid: match prompt_metadata.id {
395 PromptId::User { uuid } => uuid,
396 PromptId::EditWorkflow => return None,
397 },
398 title: prompt_metadata.title.map(|title| title.to_string()),
399 contents,
400 }),
401 Err(_err) => {
402 // TODO: show error message
403 // this.update(cx, |_, cx| {
404 // cx.emit(RulesLoadingError {
405 // message: format!("{err:?}").into(),
406 // });
407 // })
408 // .ok();
409 None
410 }
411 })
412 .collect::<Vec<_>>();
413
414 ProjectContext::new(worktrees, default_user_rules)
415 })
416 }
417
418 fn load_worktree_info_for_system_prompt(
419 worktree: Entity<Worktree>,
420 project: Entity<Project>,
421 cx: &mut App,
422 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
423 let tree = worktree.read(cx);
424 let root_name = tree.root_name().into();
425 let abs_path = tree.abs_path();
426
427 let mut context = WorktreeContext {
428 root_name,
429 abs_path,
430 rules_file: None,
431 };
432
433 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
434 let Some(rules_task) = rules_task else {
435 return Task::ready((context, None));
436 };
437
438 cx.spawn(async move |_| {
439 let (rules_file, rules_file_error) = match rules_task.await {
440 Ok(rules_file) => (Some(rules_file), None),
441 Err(err) => (
442 None,
443 Some(RulesLoadingError {
444 message: format!("{err}").into(),
445 }),
446 ),
447 };
448 context.rules_file = rules_file;
449 (context, rules_file_error)
450 })
451 }
452
453 fn load_worktree_rules_file(
454 worktree: Entity<Worktree>,
455 project: Entity<Project>,
456 cx: &mut App,
457 ) -> Option<Task<Result<RulesFileContext>>> {
458 let worktree = worktree.read(cx);
459 let worktree_id = worktree.id();
460 let selected_rules_file = RULES_FILE_NAMES
461 .into_iter()
462 .filter_map(|name| {
463 worktree
464 .entry_for_path(name)
465 .filter(|entry| entry.is_file())
466 .map(|entry| entry.path.clone())
467 })
468 .next();
469
470 // Note that Cline supports `.clinerules` being a directory, but that is not currently
471 // supported. This doesn't seem to occur often in GitHub repositories.
472 selected_rules_file.map(|path_in_worktree| {
473 let project_path = ProjectPath {
474 worktree_id,
475 path: path_in_worktree.clone(),
476 };
477 let buffer_task =
478 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
479 let rope_task = cx.spawn(async move |cx| {
480 buffer_task.await?.read_with(cx, |buffer, cx| {
481 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
482 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
483 })?
484 });
485 // Build a string from the rope on a background thread.
486 cx.background_spawn(async move {
487 let (project_entry_id, rope) = rope_task.await?;
488 anyhow::Ok(RulesFileContext {
489 path_in_worktree,
490 text: rope.to_string().trim().to_string(),
491 project_entry_id: project_entry_id.to_usize(),
492 })
493 })
494 })
495 }
496
497 fn handle_thread_title_updated(
498 &mut self,
499 thread: Entity<Thread>,
500 _: &TitleUpdated,
501 cx: &mut Context<Self>,
502 ) {
503 let session_id = thread.read(cx).id();
504 let Some(session) = self.sessions.get(session_id) else {
505 return;
506 };
507 let thread = thread.downgrade();
508 let acp_thread = session.acp_thread.clone();
509 cx.spawn(async move |_, cx| {
510 let title = thread.read_with(cx, |thread, _| thread.title())?;
511 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
512 task.await
513 })
514 .detach_and_log_err(cx);
515 }
516
517 fn handle_thread_token_usage_updated(
518 &mut self,
519 thread: Entity<Thread>,
520 usage: &TokenUsageUpdated,
521 cx: &mut Context<Self>,
522 ) {
523 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
524 return;
525 };
526 session
527 .acp_thread
528 .update(cx, |acp_thread, cx| {
529 acp_thread.update_token_usage(usage.0.clone(), cx);
530 })
531 .ok();
532 }
533
534 fn handle_project_event(
535 &mut self,
536 _project: Entity<Project>,
537 event: &project::Event,
538 _cx: &mut Context<Self>,
539 ) {
540 match event {
541 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
542 self.project_context_needs_refresh.send(()).ok();
543 }
544 project::Event::WorktreeUpdatedEntries(_, items) => {
545 if items.iter().any(|(path, _, _)| {
546 RULES_FILE_NAMES
547 .iter()
548 .any(|name| path.as_ref() == Path::new(name))
549 }) {
550 self.project_context_needs_refresh.send(()).ok();
551 }
552 }
553 _ => {}
554 }
555 }
556
557 fn handle_prompts_updated_event(
558 &mut self,
559 _prompt_store: Entity<PromptStore>,
560 _event: &prompt_store::PromptsUpdatedEvent,
561 _cx: &mut Context<Self>,
562 ) {
563 self.project_context_needs_refresh.send(()).ok();
564 }
565
566 fn handle_models_updated_event(
567 &mut self,
568 _registry: Entity<LanguageModelRegistry>,
569 _event: &language_model::Event,
570 cx: &mut Context<Self>,
571 ) {
572 self.models.refresh_list(cx);
573
574 let registry = LanguageModelRegistry::read_global(cx);
575 let default_model = registry.default_model().map(|m| m.model);
576 let summarization_model = registry.thread_summary_model().map(|m| m.model);
577
578 for session in self.sessions.values_mut() {
579 session.thread.update(cx, |thread, cx| {
580 if thread.model().is_none()
581 && let Some(model) = default_model.clone()
582 {
583 thread.set_model(model, cx);
584 cx.notify();
585 }
586 thread.set_summarization_model(summarization_model.clone(), cx);
587 });
588 }
589 }
590
591 pub fn open_thread(
592 &mut self,
593 id: acp::SessionId,
594 cx: &mut Context<Self>,
595 ) -> Task<Result<Entity<AcpThread>>> {
596 let database_future = ThreadsDatabase::connect(cx);
597 cx.spawn(async move |this, cx| {
598 let database = database_future.await.map_err(|err| anyhow!(err))?;
599 let db_thread = database
600 .load_thread(id.clone())
601 .await?
602 .with_context(|| format!("no thread found with ID: {id:?}"))?;
603
604 let thread = this.update(cx, |this, cx| {
605 let action_log = cx.new(|_cx| ActionLog::new(this.project.clone()));
606 cx.new(|cx| {
607 Thread::from_db(
608 id.clone(),
609 db_thread,
610 this.project.clone(),
611 this.project_context.clone(),
612 this.context_server_registry.clone(),
613 action_log.clone(),
614 this.templates.clone(),
615 cx,
616 )
617 })
618 })?;
619 let acp_thread =
620 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
621 let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
622 cx.update(|cx| {
623 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
624 })?
625 .await?;
626 Ok(acp_thread)
627 })
628 }
629
630 pub fn thread_summary(
631 &mut self,
632 id: acp::SessionId,
633 cx: &mut Context<Self>,
634 ) -> Task<Result<SharedString>> {
635 let thread = self.open_thread(id.clone(), cx);
636 cx.spawn(async move |this, cx| {
637 let acp_thread = thread.await?;
638 let result = this
639 .update(cx, |this, cx| {
640 this.sessions
641 .get(&id)
642 .unwrap()
643 .thread
644 .update(cx, |thread, cx| thread.summary(cx))
645 })?
646 .await?;
647 drop(acp_thread);
648 Ok(result)
649 })
650 }
651
652 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
653 if thread.read(cx).is_empty() {
654 return;
655 }
656
657 let database_future = ThreadsDatabase::connect(cx);
658 let (id, db_thread) =
659 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
660 let Some(session) = self.sessions.get_mut(&id) else {
661 return;
662 };
663 let history = self.history.clone();
664 session.pending_save = cx.spawn(async move |_, cx| {
665 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
666 return;
667 };
668 let db_thread = db_thread.await;
669 database.save_thread(id, db_thread).await.log_err();
670 history.update(cx, |history, cx| history.reload(cx)).ok();
671 });
672 }
673}
674
/// Wrapper struct that implements the AgentConnection trait
///
/// Also implements `AgentModelSelector` and `AgentTelemetry`. It holds only an
/// `Entity<NativeAgent>` handle, so every clone refers to the same agent.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
678
679impl NativeAgentConnection {
680 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
681 self.0
682 .read(cx)
683 .sessions
684 .get(session_id)
685 .map(|session| session.thread.clone())
686 }
687
688 fn run_turn(
689 &self,
690 session_id: acp::SessionId,
691 cx: &mut App,
692 f: impl 'static
693 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
694 ) -> Task<Result<acp::PromptResponse>> {
695 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
696 agent
697 .sessions
698 .get_mut(&session_id)
699 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
700 }) else {
701 return Task::ready(Err(anyhow!("Session not found")));
702 };
703 log::debug!("Found session for: {}", session_id);
704
705 let response_stream = match f(thread, cx) {
706 Ok(stream) => stream,
707 Err(err) => return Task::ready(Err(err)),
708 };
709 Self::handle_thread_events(response_stream, acp_thread, cx)
710 }
711
712 fn handle_thread_events(
713 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
714 acp_thread: WeakEntity<AcpThread>,
715 cx: &App,
716 ) -> Task<Result<acp::PromptResponse>> {
717 cx.spawn(async move |cx| {
718 // Handle response stream and forward to session.acp_thread
719 while let Some(result) = events.next().await {
720 match result {
721 Ok(event) => {
722 log::trace!("Received completion event: {:?}", event);
723
724 match event {
725 ThreadEvent::UserMessage(message) => {
726 acp_thread.update(cx, |thread, cx| {
727 for content in message.content {
728 thread.push_user_content_block(
729 Some(message.id.clone()),
730 content.into(),
731 cx,
732 );
733 }
734 })?;
735 }
736 ThreadEvent::AgentText(text) => {
737 acp_thread.update(cx, |thread, cx| {
738 thread.push_assistant_content_block(
739 acp::ContentBlock::Text(acp::TextContent {
740 text,
741 annotations: None,
742 }),
743 false,
744 cx,
745 )
746 })?;
747 }
748 ThreadEvent::AgentThinking(text) => {
749 acp_thread.update(cx, |thread, cx| {
750 thread.push_assistant_content_block(
751 acp::ContentBlock::Text(acp::TextContent {
752 text,
753 annotations: None,
754 }),
755 true,
756 cx,
757 )
758 })?;
759 }
760 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
761 tool_call,
762 options,
763 response,
764 }) => {
765 let recv = acp_thread.update(cx, |thread, cx| {
766 thread.request_tool_call_authorization(tool_call, options, cx)
767 })?;
768 cx.background_spawn(async move {
769 if let Some(recv) = recv.log_err()
770 && let Some(option) = recv
771 .await
772 .context("authorization sender was dropped")
773 .log_err()
774 {
775 response
776 .send(option)
777 .map(|_| anyhow!("authorization receiver was dropped"))
778 .log_err();
779 }
780 })
781 .detach();
782 }
783 ThreadEvent::ToolCall(tool_call) => {
784 acp_thread.update(cx, |thread, cx| {
785 thread.upsert_tool_call(tool_call, cx)
786 })??;
787 }
788 ThreadEvent::ToolCallUpdate(update) => {
789 acp_thread.update(cx, |thread, cx| {
790 thread.update_tool_call(update, cx)
791 })??;
792 }
793 ThreadEvent::Retry(status) => {
794 acp_thread.update(cx, |thread, cx| {
795 thread.update_retry_status(status, cx)
796 })?;
797 }
798 ThreadEvent::Stop(stop_reason) => {
799 log::debug!("Assistant message complete: {:?}", stop_reason);
800 return Ok(acp::PromptResponse { stop_reason });
801 }
802 }
803 }
804 Err(e) => {
805 log::error!("Error in model response stream: {:?}", e);
806 return Err(e);
807 }
808 }
809 }
810
811 log::debug!("Response stream completed");
812 anyhow::Ok(acp::PromptResponse {
813 stop_reason: acp::StopReason::EndTurn,
814 })
815 })
816 }
817}
818
819impl AgentModelSelector for NativeAgentConnection {
820 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
821 log::debug!("NativeAgentConnection::list_models called");
822 let list = self.0.read(cx).models.model_list.clone();
823 Task::ready(if list.is_empty() {
824 Err(anyhow::anyhow!("No models available"))
825 } else {
826 Ok(list)
827 })
828 }
829
830 fn select_model(
831 &self,
832 session_id: acp::SessionId,
833 model_id: acp_thread::AgentModelId,
834 cx: &mut App,
835 ) -> Task<Result<()>> {
836 log::debug!("Setting model for session {}: {}", session_id, model_id);
837 let Some(thread) = self
838 .0
839 .read(cx)
840 .sessions
841 .get(&session_id)
842 .map(|session| session.thread.clone())
843 else {
844 return Task::ready(Err(anyhow!("Session not found")));
845 };
846
847 let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
848 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
849 };
850
851 thread.update(cx, |thread, cx| {
852 thread.set_model(model.clone(), cx);
853 });
854
855 update_settings_file::<AgentSettings>(
856 self.0.read(cx).fs.clone(),
857 cx,
858 move |settings, _cx| {
859 settings.set_model(model);
860 },
861 );
862
863 Task::ready(Ok(()))
864 }
865
866 fn selected_model(
867 &self,
868 session_id: &acp::SessionId,
869 cx: &mut App,
870 ) -> Task<Result<acp_thread::AgentModelInfo>> {
871 let session_id = session_id.clone();
872
873 let Some(thread) = self
874 .0
875 .read(cx)
876 .sessions
877 .get(&session_id)
878 .map(|session| session.thread.clone())
879 else {
880 return Task::ready(Err(anyhow!("Session not found")));
881 };
882 let Some(model) = thread.read(cx).model() else {
883 return Task::ready(Err(anyhow!("Model not found")));
884 };
885 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
886 else {
887 return Task::ready(Err(anyhow!("Provider not found")));
888 };
889 Task::ready(Ok(LanguageModels::map_language_model_to_info(
890 model, &provider,
891 )))
892 }
893
894 fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
895 self.0.read(cx).models.watch()
896 }
897}
898
899impl acp_thread::AgentConnection for NativeAgentConnection {
900 fn new_thread(
901 self: Rc<Self>,
902 project: Entity<Project>,
903 cwd: &Path,
904 cx: &mut App,
905 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
906 let agent = self.0.clone();
907 log::debug!("Creating new thread for project at: {:?}", cwd);
908
909 cx.spawn(async move |cx| {
910 log::debug!("Starting thread creation in async context");
911
912 // Create Thread
913 let thread = agent.update(
914 cx,
915 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
916 // Fetch default model from registry settings
917 let registry = LanguageModelRegistry::read_global(cx);
918 // Log available models for debugging
919 let available_count = registry.available_models(cx).count();
920 log::debug!("Total available models: {}", available_count);
921
922 let default_model = registry.default_model().and_then(|default_model| {
923 agent
924 .models
925 .model_from_id(&LanguageModels::model_id(&default_model.model))
926 });
927 Ok(cx.new(|cx| {
928 Thread::new(
929 project.clone(),
930 agent.project_context.clone(),
931 agent.context_server_registry.clone(),
932 agent.templates.clone(),
933 default_model,
934 cx,
935 )
936 }))
937 },
938 )??;
939 agent.update(cx, |agent, cx| agent.register_session(thread, cx))
940 })
941 }
942
943 fn auth_methods(&self) -> &[acp::AuthMethod] {
944 &[] // No auth for in-process
945 }
946
947 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
948 Task::ready(Ok(()))
949 }
950
951 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
952 Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
953 }
954
955 fn prompt(
956 &self,
957 id: Option<acp_thread::UserMessageId>,
958 params: acp::PromptRequest,
959 cx: &mut App,
960 ) -> Task<Result<acp::PromptResponse>> {
961 let id = id.expect("UserMessageId is required");
962 let session_id = params.session_id.clone();
963 log::info!("Received prompt request for session: {}", session_id);
964 log::debug!("Prompt blocks count: {}", params.prompt.len());
965
966 self.run_turn(session_id, cx, |thread, cx| {
967 let content: Vec<UserMessageContent> = params
968 .prompt
969 .into_iter()
970 .map(Into::into)
971 .collect::<Vec<_>>();
972 log::debug!("Converted prompt to message: {} chars", content.len());
973 log::debug!("Message id: {:?}", id);
974 log::debug!("Message content: {:?}", content);
975
976 thread.update(cx, |thread, cx| thread.send(id, content, cx))
977 })
978 }
979
980 fn resume(
981 &self,
982 session_id: &acp::SessionId,
983 _cx: &App,
984 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
985 Some(Rc::new(NativeAgentSessionResume {
986 connection: self.clone(),
987 session_id: session_id.clone(),
988 }) as _)
989 }
990
991 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
992 log::info!("Cancelling on session: {}", session_id);
993 self.0.update(cx, |agent, cx| {
994 if let Some(agent) = agent.sessions.get(session_id) {
995 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
996 }
997 });
998 }
999
1000 fn truncate(
1001 &self,
1002 session_id: &agent_client_protocol::SessionId,
1003 cx: &App,
1004 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1005 self.0.read_with(cx, |agent, _cx| {
1006 agent.sessions.get(session_id).map(|session| {
1007 Rc::new(NativeAgentSessionEditor {
1008 thread: session.thread.clone(),
1009 acp_thread: session.acp_thread.clone(),
1010 }) as _
1011 })
1012 })
1013 }
1014
1015 fn set_title(
1016 &self,
1017 session_id: &acp::SessionId,
1018 _cx: &App,
1019 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1020 Some(Rc::new(NativeAgentSessionSetTitle {
1021 connection: self.clone(),
1022 session_id: session_id.clone(),
1023 }) as _)
1024 }
1025
1026 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1027 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1028 }
1029
1030 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1031 self
1032 }
1033}
1034
1035impl acp_thread::AgentTelemetry for NativeAgentConnection {
1036 fn agent_name(&self) -> String {
1037 "Zed".into()
1038 }
1039
1040 fn thread_data(
1041 &self,
1042 session_id: &acp::SessionId,
1043 cx: &mut App,
1044 ) -> Task<Result<serde_json::Value>> {
1045 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1046 return Task::ready(Err(anyhow!("Session not found")));
1047 };
1048
1049 let task = session.thread.read(cx).to_db(cx);
1050 cx.background_spawn(async move {
1051 serde_json::to_value(task.await).context("Failed to serialize thread")
1052 })
1053 }
1054}
1055
/// Truncation handle for one session, returned by
/// `NativeAgentConnection::truncate`.
struct NativeAgentSessionEditor {
    /// The internal thread to truncate.
    thread: Entity<Thread>,
    /// The session's ACP thread, which receives the post-truncation token usage.
    acp_thread: WeakEntity<AcpThread>,
}
1060
1061impl acp_thread::AgentSessionTruncate for NativeAgentSessionEditor {
1062 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1063 match self.thread.update(cx, |thread, cx| {
1064 thread.truncate(message_id.clone(), cx)?;
1065 Ok(thread.latest_token_usage())
1066 }) {
1067 Ok(usage) => {
1068 self.acp_thread
1069 .update(cx, |thread, cx| {
1070 thread.update_token_usage(usage, cx);
1071 })
1072 .ok();
1073 Task::ready(Ok(()))
1074 }
1075 Err(error) => Task::ready(Err(error)),
1076 }
1077 }
1078}
1079
/// Resume handle for one session, returned by `NativeAgentConnection::resume`.
struct NativeAgentSessionResume {
    /// Connection used to run the resumed turn.
    connection: NativeAgentConnection,
    /// Session to resume.
    session_id: acp::SessionId,
}
1084
1085impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1086 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1087 self.connection
1088 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1089 thread.update(cx, |thread, cx| thread.resume(cx))
1090 })
1091 }
1092}
1093
/// Title-setting handle for one session, returned by
/// `NativeAgentConnection::set_title`.
struct NativeAgentSessionSetTitle {
    /// Connection used to look up the session's thread.
    connection: NativeAgentConnection,
    /// Session whose thread is retitled.
    session_id: acp::SessionId,
}
1098
1099impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1100 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1101 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1102 return Task::ready(Err(anyhow!("session not found")));
1103 };
1104 let thread = session.thread.clone();
1105 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1106 Task::ready(Ok(()))
1107 }
1108}
1109
/// Integration tests for `NativeAgent` / `NativeAgentConnection`. They run
/// against a `FakeFs`, fake language models, and a test settings store.
#[cfg(test)]
mod tests {
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{
        AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo, MentionUri,
    };
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::indoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// The agent's project context starts with no worktrees. It gains an entry
    /// when a worktree is created, and that entry picks up a `.rules` file once
    /// one is added to the worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project was created with no worktrees.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        // Adding a worktree for `/a` (no rules file yet) is reflected in the context.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// `list_models` returns the fake provider's single model in a group named
    /// after the provider.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Selecting a model on a session switches that session's thread to the
    /// model. It also overwrites `agent.default_model` in the settings file,
    /// which starts out pointing at a different provider/model ("foo"/"bar").
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let any pending settings-file write settle before reading it back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// Round-trips a thread through history. Empty threads are not saved. A
    /// thread with a completed turn is saved under the title produced by the
    /// summarization model, and after its session is dropped it can be
    /// reopened with the same content.
    #[gpui::test]
    #[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        // Send a user message that mentions `/a/b.md` as a resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
                        name: "b.md".into(),
                        uri: MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                        annotations: None,
                        description: None,
                        mime_type: None,
                        size: None,
                        title: None,
                    }),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Complete the assistant reply first, then the summary model's output,
        // which becomes the saved thread's title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model.send_last_completion_stream_text_chunk("Explaining /a/b.md");
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    What does [@b.md](file:///a/b.md) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                "Explaining /a/b.md".into()
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    What does [@b.md](file:///a/b.md) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    /// Returns the `(id, title)` pair of every entry currently in `history`,
    /// in the order yielded by `HistoryStore::entries`.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    /// Shared test setup. It installs a test `SettingsStore` as a global, then
    /// initializes project/agent settings, the language subsystem, and a test
    /// `LanguageModelRegistry`.
    fn init_test(cx: &mut TestAppContext) {
        // Logger setup is best-effort; ignore the error if one is already installed.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}