1use crate::native_agent_server::NATIVE_AGENT_SERVER_NAME;
2use crate::{
3 ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
4 EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
5 OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
6 UserMessageContent, WebSearchTool, templates::Templates,
7};
8use crate::{ThreadsDatabase, generate_session_id};
9use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
10use agent_client_protocol as acp;
11use agent_settings::AgentSettings;
12use anyhow::{Context as _, Result, anyhow};
13use collections::{HashSet, IndexMap};
14use fs::Fs;
15use futures::channel::mpsc;
16use futures::{StreamExt, future};
17use gpui::{
18 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
19};
20use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry, SelectedModel};
21use project::{Project, ProjectItem, ProjectPath, Worktree};
22use prompt_store::{
23 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
24};
25use settings::update_settings_file;
26use std::any::Any;
27use std::cell::RefCell;
28use std::collections::HashMap;
29use std::path::Path;
30use std::rc::Rc;
31use std::sync::Arc;
32use std::time::Duration;
33use util::ResultExt;
34
/// Candidate rules-file names, relative to a worktree root, in priority order.
///
/// When building project context, the first name that resolves to a regular file in a
/// worktree becomes that worktree's rules file; the remaining names are ignored. Changes to
/// an entry whose worktree-relative path equals one of these names trigger a project-context
/// refresh.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
46
/// Delay between a thread change and persisting that thread; each new change replaces the
/// session's pending save task.
const SAVE_THREAD_DEBOUNCE: Duration = Duration::new(0, 500_000_000);
48
/// Describes why a worktree's rules file could not be loaded while building project context.
///
/// NOTE(review): `build_project_context` currently discards these (see its TODOs), so they
/// are not yet shown to the user.
pub struct RulesLoadingError {
    /// Human-readable description of the failure (the formatted underlying error).
    pub message: SharedString,
}
52
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    ///
    /// Held weakly; `NativeAgent::insert_session` removes the session when this entity is
    /// released.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Most recently scheduled (debounced) save of `thread`; replaced on every thread change.
    save_task: Task<Result<()>>,
    /// Keeps the release/change observers registered by `insert_session` alive.
    _subscriptions: Vec<Subscription>,
}
62
/// Cache of the language models offered by authenticated providers.
///
/// Rebuilt by `refresh_list`; interested parties observe refreshes through `watch`.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver half that is cloned and handed out to observers of refreshes.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender half signalled at the end of every refresh.
    refresh_models_tx: watch::Sender<()>,
}
71
72impl LanguageModels {
73 fn new(cx: &App) -> Self {
74 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
75 let mut this = Self {
76 models: HashMap::default(),
77 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
78 refresh_models_rx,
79 refresh_models_tx,
80 };
81 this.refresh_list(cx);
82 this
83 }
84
85 fn refresh_list(&mut self, cx: &App) {
86 let providers = LanguageModelRegistry::global(cx)
87 .read(cx)
88 .providers()
89 .into_iter()
90 .filter(|provider| provider.is_authenticated(cx))
91 .collect::<Vec<_>>();
92
93 let mut language_model_list = IndexMap::default();
94 let mut recommended_models = HashSet::default();
95
96 let mut recommended = Vec::new();
97 for provider in &providers {
98 for model in provider.recommended_models(cx) {
99 recommended_models.insert(model.id());
100 recommended.push(Self::map_language_model_to_info(&model, &provider));
101 }
102 }
103 if !recommended.is_empty() {
104 language_model_list.insert(
105 acp_thread::AgentModelGroupName("Recommended".into()),
106 recommended,
107 );
108 }
109
110 let mut models = HashMap::default();
111 for provider in providers {
112 let mut provider_models = Vec::new();
113 for model in provider.provided_models(cx) {
114 let model_info = Self::map_language_model_to_info(&model, &provider);
115 let model_id = model_info.id.clone();
116 if !recommended_models.contains(&model.id()) {
117 provider_models.push(model_info);
118 }
119 models.insert(model_id, model);
120 }
121 if !provider_models.is_empty() {
122 language_model_list.insert(
123 acp_thread::AgentModelGroupName(provider.name().0.clone()),
124 provider_models,
125 );
126 }
127 }
128
129 self.models = models;
130 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
131 self.refresh_models_tx.send(()).ok();
132 }
133
134 fn watch(&self) -> watch::Receiver<()> {
135 self.refresh_models_rx.clone()
136 }
137
138 pub fn model_from_id(
139 &self,
140 model_id: &acp_thread::AgentModelId,
141 ) -> Option<Arc<dyn LanguageModel>> {
142 self.models.get(model_id).cloned()
143 }
144
145 fn map_language_model_to_info(
146 model: &Arc<dyn LanguageModel>,
147 provider: &Arc<dyn LanguageModelProvider>,
148 ) -> acp_thread::AgentModelInfo {
149 acp_thread::AgentModelInfo {
150 id: Self::model_id(model),
151 name: model.name().0,
152 icon: Some(provider.icon()),
153 }
154 }
155
156 fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
157 acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
158 }
159}
160
/// In-process agent that owns the native threads (sessions) for one project and shares the
/// project context, context-server registry, templates and model cache among them.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Shared project context for all threads
    project_context: Rc<RefCell<ProjectContext>>,
    /// Signalled to ask `_maintain_project_context` to rebuild `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background loop that rebuilds `project_context` whenever a refresh is requested.
    _maintain_project_context: Task<Result<()>>,
    /// Context servers exposed to every thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    project: Entity<Project>,
    /// Source of default user rules included in the project context, if any.
    prompt_store: Option<Entity<PromptStore>>,
    /// Database used to save, list and load threads.
    thread_database: Arc<ThreadsDatabase>,
    /// Latest list of saved threads; `None` until the first history load completes.
    history: watch::Sender<Option<Vec<AcpThreadMetadata>>>,
    /// In-flight history reload; replaced by each `reload_history` call.
    load_history: Task<()>,
    fs: Arc<dyn Fs>,
    /// Project, model-registry and prompt-store subscriptions created in `new`.
    _subscriptions: Vec<Subscription>,
}
181
182impl NativeAgent {
183 pub async fn new(
184 project: Entity<Project>,
185 templates: Arc<Templates>,
186 prompt_store: Option<Entity<PromptStore>>,
187 fs: Arc<dyn Fs>,
188 cx: &mut AsyncApp,
189 ) -> Result<Entity<NativeAgent>> {
190 log::info!("Creating new NativeAgent");
191
192 let project_context = cx
193 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
194 .await;
195
196 let thread_database = cx
197 .update(|cx| ThreadsDatabase::connect(cx))?
198 .await
199 .map_err(|e| anyhow!(e))?;
200
201 cx.new(|cx| {
202 let mut subscriptions = vec![
203 cx.subscribe(&project, Self::handle_project_event),
204 cx.subscribe(
205 &LanguageModelRegistry::global(cx),
206 Self::handle_models_updated_event,
207 ),
208 ];
209 if let Some(prompt_store) = prompt_store.as_ref() {
210 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
211 }
212
213 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
214 watch::channel(());
215 let mut this = Self {
216 sessions: HashMap::new(),
217 project_context: Rc::new(RefCell::new(project_context)),
218 project_context_needs_refresh: project_context_needs_refresh_tx,
219 _maintain_project_context: cx.spawn(async move |this, cx| {
220 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
221 }),
222 context_server_registry: cx.new(|cx| {
223 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
224 }),
225 thread_database,
226 templates,
227 models: LanguageModels::new(cx),
228 project,
229 prompt_store,
230 fs,
231 history: watch::channel(None).0,
232 load_history: Task::ready(()),
233 _subscriptions: subscriptions,
234 };
235 this.reload_history(cx);
236 this
237 })
238 }
239
240 pub fn insert_session(
241 &mut self,
242 thread: Entity<Thread>,
243 acp_thread: Entity<AcpThread>,
244 cx: &mut Context<Self>,
245 ) {
246 let id = thread.read(cx).id().clone();
247 self.sessions.insert(
248 id,
249 Session {
250 thread: thread.clone(),
251 acp_thread: acp_thread.downgrade(),
252 save_task: Task::ready(Ok(())),
253 _subscriptions: vec![
254 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
255 this.sessions.remove(acp_thread.session_id());
256 }),
257 cx.observe(&thread, |this, thread, cx| {
258 this.save_thread(thread.clone(), cx)
259 }),
260 ],
261 },
262 );
263 }
264
265 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
266 let id = thread.read(cx).id().clone();
267 let Some(session) = self.sessions.get_mut(&id) else {
268 return;
269 };
270
271 let thread = thread.downgrade();
272 let thread_database = self.thread_database.clone();
273 session.save_task = cx.spawn(async move |this, cx| {
274 cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await;
275 let db_thread = thread.update(cx, |thread, cx| thread.to_db(cx))?.await;
276 thread_database.save_thread(id, db_thread).await?;
277 this.update(cx, |this, cx| this.reload_history(cx))?;
278 Ok(())
279 });
280 }
281
282 fn reload_history(&mut self, cx: &mut Context<Self>) {
283 let thread_database = self.thread_database.clone();
284 self.load_history = cx.spawn(async move |this, cx| {
285 let results = cx
286 .background_spawn(async move {
287 let results = thread_database.list_threads().await?;
288 anyhow::Ok(
289 results
290 .into_iter()
291 .map(|thread| AcpThreadMetadata {
292 agent: NATIVE_AGENT_SERVER_NAME.clone(),
293 id: thread.id.into(),
294 title: thread.title,
295 updated_at: thread.updated_at,
296 })
297 .collect(),
298 )
299 })
300 .await;
301 if let Some(results) = results.log_err() {
302 this.update(cx, |this, _| this.history.send(Some(results)))
303 .ok();
304 }
305 });
306 }
307
308 pub fn models(&self) -> &LanguageModels {
309 &self.models
310 }
311
312 async fn maintain_project_context(
313 this: WeakEntity<Self>,
314 mut needs_refresh: watch::Receiver<()>,
315 cx: &mut AsyncApp,
316 ) -> Result<()> {
317 while needs_refresh.changed().await.is_ok() {
318 let project_context = this
319 .update(cx, |this, cx| {
320 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
321 })?
322 .await;
323 this.update(cx, |this, _| this.project_context.replace(project_context))?;
324 }
325
326 Ok(())
327 }
328
329 fn build_project_context(
330 project: &Entity<Project>,
331 prompt_store: Option<&Entity<PromptStore>>,
332 cx: &mut App,
333 ) -> Task<ProjectContext> {
334 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
335 let worktree_tasks = worktrees
336 .into_iter()
337 .map(|worktree| {
338 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
339 })
340 .collect::<Vec<_>>();
341 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
342 prompt_store.read_with(cx, |prompt_store, cx| {
343 let prompts = prompt_store.default_prompt_metadata();
344 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
345 let contents = prompt_store.load(prompt_metadata.id, cx);
346 async move { (contents.await, prompt_metadata) }
347 });
348 cx.background_spawn(future::join_all(load_tasks))
349 })
350 } else {
351 Task::ready(vec![])
352 };
353
354 cx.spawn(async move |_cx| {
355 let (worktrees, default_user_rules) =
356 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
357
358 let worktrees = worktrees
359 .into_iter()
360 .map(|(worktree, _rules_error)| {
361 // TODO: show error message
362 // if let Some(rules_error) = rules_error {
363 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
364 // }
365 worktree
366 })
367 .collect::<Vec<_>>();
368
369 let default_user_rules = default_user_rules
370 .into_iter()
371 .flat_map(|(contents, prompt_metadata)| match contents {
372 Ok(contents) => Some(UserRulesContext {
373 uuid: match prompt_metadata.id {
374 PromptId::User { uuid } => uuid,
375 PromptId::EditWorkflow => return None,
376 },
377 title: prompt_metadata.title.map(|title| title.to_string()),
378 contents,
379 }),
380 Err(_err) => {
381 // TODO: show error message
382 // this.update(cx, |_, cx| {
383 // cx.emit(RulesLoadingError {
384 // message: format!("{err:?}").into(),
385 // });
386 // })
387 // .ok();
388 None
389 }
390 })
391 .collect::<Vec<_>>();
392
393 ProjectContext::new(worktrees, default_user_rules)
394 })
395 }
396
397 fn load_worktree_info_for_system_prompt(
398 worktree: Entity<Worktree>,
399 project: Entity<Project>,
400 cx: &mut App,
401 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
402 let tree = worktree.read(cx);
403 let root_name = tree.root_name().into();
404 let abs_path = tree.abs_path();
405
406 let mut context = WorktreeContext {
407 root_name,
408 abs_path,
409 rules_file: None,
410 };
411
412 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
413 let Some(rules_task) = rules_task else {
414 return Task::ready((context, None));
415 };
416
417 cx.spawn(async move |_| {
418 let (rules_file, rules_file_error) = match rules_task.await {
419 Ok(rules_file) => (Some(rules_file), None),
420 Err(err) => (
421 None,
422 Some(RulesLoadingError {
423 message: format!("{err}").into(),
424 }),
425 ),
426 };
427 context.rules_file = rules_file;
428 (context, rules_file_error)
429 })
430 }
431
432 fn load_worktree_rules_file(
433 worktree: Entity<Worktree>,
434 project: Entity<Project>,
435 cx: &mut App,
436 ) -> Option<Task<Result<RulesFileContext>>> {
437 let worktree = worktree.read(cx);
438 let worktree_id = worktree.id();
439 let selected_rules_file = RULES_FILE_NAMES
440 .into_iter()
441 .filter_map(|name| {
442 worktree
443 .entry_for_path(name)
444 .filter(|entry| entry.is_file())
445 .map(|entry| entry.path.clone())
446 })
447 .next();
448
449 // Note that Cline supports `.clinerules` being a directory, but that is not currently
450 // supported. This doesn't seem to occur often in GitHub repositories.
451 selected_rules_file.map(|path_in_worktree| {
452 let project_path = ProjectPath {
453 worktree_id,
454 path: path_in_worktree.clone(),
455 };
456 let buffer_task =
457 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
458 let rope_task = cx.spawn(async move |cx| {
459 buffer_task.await?.read_with(cx, |buffer, cx| {
460 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
461 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
462 })?
463 });
464 // Build a string from the rope on a background thread.
465 cx.background_spawn(async move {
466 let (project_entry_id, rope) = rope_task.await?;
467 anyhow::Ok(RulesFileContext {
468 path_in_worktree,
469 text: rope.to_string().trim().to_string(),
470 project_entry_id: project_entry_id.to_usize(),
471 })
472 })
473 })
474 }
475
476 fn handle_project_event(
477 &mut self,
478 _project: Entity<Project>,
479 event: &project::Event,
480 _cx: &mut Context<Self>,
481 ) {
482 match event {
483 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
484 self.project_context_needs_refresh.send(()).ok();
485 }
486 project::Event::WorktreeUpdatedEntries(_, items) => {
487 if items.iter().any(|(path, _, _)| {
488 RULES_FILE_NAMES
489 .iter()
490 .any(|name| path.as_ref() == Path::new(name))
491 }) {
492 self.project_context_needs_refresh.send(()).ok();
493 }
494 }
495 _ => {}
496 }
497 }
498
499 fn handle_prompts_updated_event(
500 &mut self,
501 _prompt_store: Entity<PromptStore>,
502 _event: &prompt_store::PromptsUpdatedEvent,
503 _cx: &mut Context<Self>,
504 ) {
505 self.project_context_needs_refresh.send(()).ok();
506 }
507
508 fn handle_models_updated_event(
509 &mut self,
510 _registry: Entity<LanguageModelRegistry>,
511 _event: &language_model::Event,
512 cx: &mut Context<Self>,
513 ) {
514 self.models.refresh_list(cx);
515 for session in self.sessions.values_mut() {
516 session.thread.update(cx, |thread, cx| {
517 let model_id = LanguageModels::model_id(&thread.model());
518 if let Some(model) = self.models.model_from_id(&model_id) {
519 thread.set_model(model.clone(), cx);
520 }
521 });
522 }
523 }
524}
525
/// Wrapper struct that implements the AgentConnection trait
///
/// Clones refer to the same underlying `NativeAgent` entity.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
529
530impl NativeAgentConnection {
531 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
532 self.0
533 .read(cx)
534 .sessions
535 .get(session_id)
536 .map(|session| session.thread.clone())
537 }
538
539 fn run_turn(
540 &self,
541 session_id: acp::SessionId,
542 cx: &mut App,
543 f: impl 'static
544 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
545 ) -> Task<Result<acp::PromptResponse>> {
546 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
547 agent
548 .sessions
549 .get_mut(&session_id)
550 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
551 }) else {
552 return Task::ready(Err(anyhow!("Session not found")));
553 };
554 log::debug!("Found session for: {}", session_id);
555
556 let response_stream = match f(thread, cx) {
557 Ok(stream) => stream,
558 Err(err) => return Task::ready(Err(err)),
559 };
560 Self::handle_thread_events(response_stream, acp_thread, cx)
561 }
562
563 fn handle_thread_events(
564 mut response_stream: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
565 acp_thread: WeakEntity<AcpThread>,
566 cx: &mut App,
567 ) -> Task<Result<acp::PromptResponse>> {
568 cx.spawn(async move |cx| {
569 // Handle response stream and forward to session.acp_thread
570 while let Some(result) = response_stream.next().await {
571 match result {
572 Ok(event) => {
573 log::trace!("Received completion event: {:?}", event);
574
575 match event {
576 ThreadEvent::UserMessage(message) => {
577 acp_thread.update(cx, |thread, cx| {
578 for content in message.content {
579 thread.push_user_content_block(
580 Some(message.id.clone()),
581 content.into(),
582 cx,
583 );
584 }
585 })?;
586 }
587 ThreadEvent::AgentText(text) => {
588 acp_thread.update(cx, |thread, cx| {
589 thread.push_assistant_content_block(
590 acp::ContentBlock::Text(acp::TextContent {
591 text,
592 annotations: None,
593 }),
594 false,
595 cx,
596 )
597 })?;
598 }
599 ThreadEvent::AgentThinking(text) => {
600 acp_thread.update(cx, |thread, cx| {
601 thread.push_assistant_content_block(
602 acp::ContentBlock::Text(acp::TextContent {
603 text,
604 annotations: None,
605 }),
606 true,
607 cx,
608 )
609 })?;
610 }
611 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
612 tool_call,
613 options,
614 response,
615 }) => {
616 let recv = acp_thread.update(cx, |thread, cx| {
617 thread.request_tool_call_authorization(tool_call, options, cx)
618 })?;
619 cx.background_spawn(async move {
620 if let Some(recv) = recv.log_err()
621 && let Some(option) = recv
622 .await
623 .context("authorization sender was dropped")
624 .log_err()
625 {
626 response
627 .send(option)
628 .map(|_| anyhow!("authorization receiver was dropped"))
629 .log_err();
630 }
631 })
632 .detach();
633 }
634 ThreadEvent::ToolCall(tool_call) => {
635 acp_thread.update(cx, |thread, cx| {
636 thread.upsert_tool_call(tool_call, cx)
637 })??;
638 }
639 ThreadEvent::ToolCallUpdate(update) => {
640 acp_thread.update(cx, |thread, cx| {
641 thread.update_tool_call(update, cx)
642 })??;
643 }
644 ThreadEvent::Stop(stop_reason) => {
645 log::debug!("Assistant message complete: {:?}", stop_reason);
646 return Ok(acp::PromptResponse { stop_reason });
647 }
648 }
649 }
650 Err(e) => {
651 log::error!("Error in model response stream: {:?}", e);
652 return Err(e);
653 }
654 }
655 }
656
657 log::info!("Response stream completed");
658 anyhow::Ok(acp::PromptResponse {
659 stop_reason: acp::StopReason::EndTurn,
660 })
661 })
662 }
663
664 fn register_tools(
665 thread: &mut Thread,
666 project: Entity<Project>,
667 action_log: Entity<action_log::ActionLog>,
668 cx: &mut Context<Thread>,
669 ) {
670 let language_registry = project.read(cx).languages().clone();
671 thread.add_tool(CopyPathTool::new(project.clone()));
672 thread.add_tool(CreateDirectoryTool::new(project.clone()));
673 thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
674 thread.add_tool(DiagnosticsTool::new(project.clone()));
675 thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
676 thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
677 thread.add_tool(FindPathTool::new(project.clone()));
678 thread.add_tool(GrepTool::new(project.clone()));
679 thread.add_tool(ListDirectoryTool::new(project.clone()));
680 thread.add_tool(MovePathTool::new(project.clone()));
681 thread.add_tool(NowTool);
682 thread.add_tool(OpenTool::new(project.clone()));
683 thread.add_tool(ReadFileTool::new(project.clone(), action_log));
684 thread.add_tool(TerminalTool::new(project.clone(), cx));
685 thread.add_tool(ThinkingTool);
686 thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
687 }
688}
689
690impl AgentModelSelector for NativeAgentConnection {
691 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
692 log::debug!("NativeAgentConnection::list_models called");
693 let list = self.0.read(cx).models.model_list.clone();
694 Task::ready(if list.is_empty() {
695 Err(anyhow::anyhow!("No models available"))
696 } else {
697 Ok(list)
698 })
699 }
700
701 fn select_model(
702 &self,
703 session_id: acp::SessionId,
704 model_id: acp_thread::AgentModelId,
705 cx: &mut App,
706 ) -> Task<Result<()>> {
707 log::info!("Setting model for session {}: {}", session_id, model_id);
708 let Some(thread) = self
709 .0
710 .read(cx)
711 .sessions
712 .get(&session_id)
713 .map(|session| session.thread.clone())
714 else {
715 return Task::ready(Err(anyhow!("Session not found")));
716 };
717
718 let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
719 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
720 };
721
722 thread.update(cx, |thread, cx| {
723 thread.set_model(model.clone(), cx);
724 });
725
726 update_settings_file::<AgentSettings>(
727 self.0.read(cx).fs.clone(),
728 cx,
729 move |settings, _cx| {
730 settings.set_model(model);
731 },
732 );
733
734 Task::ready(Ok(()))
735 }
736
737 fn selected_model(
738 &self,
739 session_id: &acp::SessionId,
740 cx: &mut App,
741 ) -> Task<Result<acp_thread::AgentModelInfo>> {
742 let session_id = session_id.clone();
743
744 let Some(thread) = self
745 .0
746 .read(cx)
747 .sessions
748 .get(&session_id)
749 .map(|session| session.thread.clone())
750 else {
751 return Task::ready(Err(anyhow!("Session not found")));
752 };
753 let model = thread.read(cx).model().clone();
754 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
755 else {
756 return Task::ready(Err(anyhow!("Provider not found")));
757 };
758 Task::ready(Ok(LanguageModels::map_language_model_to_info(
759 &model, &provider,
760 )))
761 }
762
763 fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
764 self.0.read(cx).models.watch()
765 }
766}
767
768impl acp_thread::AgentConnection for NativeAgentConnection {
769 fn new_thread(
770 self: Rc<Self>,
771 project: Entity<Project>,
772 cwd: &Path,
773 cx: &mut App,
774 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
775 let agent = self.0.clone();
776 log::info!("Creating new thread for project at: {:?}", cwd);
777
778 cx.spawn(async move |cx| {
779 log::debug!("Starting thread creation in async context");
780
781 // Generate session ID
782 let session_id = generate_session_id();
783 log::info!("Created session with ID: {}", session_id);
784
785 // Create AcpThread
786 let acp_thread = cx.update(|cx| {
787 cx.new(|cx| {
788 acp_thread::AcpThread::new(
789 "agent2",
790 self.clone(),
791 project.clone(),
792 session_id.clone(),
793 cx,
794 )
795 })
796 })?;
797 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
798
799 // Create Thread
800 let thread = agent.update(
801 cx,
802 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
803 // Fetch default model from registry settings
804 let registry = LanguageModelRegistry::read_global(cx);
805
806 // Log available models for debugging
807 let available_count = registry.available_models(cx).count();
808 log::debug!("Total available models: {}", available_count);
809
810 let default_model = registry
811 .default_model()
812 .and_then(|default_model| {
813 agent
814 .models
815 .model_from_id(&LanguageModels::model_id(&default_model.model))
816 })
817 .ok_or_else(|| {
818 log::warn!("No default model configured in settings");
819 anyhow!(
820 "No default model. Please configure a default model in settings."
821 )
822 })?;
823
824 let thread = cx.new(|cx| {
825 let mut thread = Thread::new(
826 session_id.clone(),
827 project.clone(),
828 agent.project_context.clone(),
829 agent.context_server_registry.clone(),
830 action_log.clone(),
831 agent.templates.clone(),
832 default_model,
833 cx,
834 );
835 Self::register_tools(&mut thread, project, action_log, cx);
836 thread
837 });
838
839 Ok(thread)
840 },
841 )??;
842
843 // Store the session
844 agent.update(cx, |agent, cx| {
845 agent.insert_session(thread, acp_thread.clone(), cx)
846 })?;
847
848 Ok(acp_thread)
849 })
850 }
851
852 fn auth_methods(&self) -> &[acp::AuthMethod] {
853 &[] // No auth for in-process
854 }
855
856 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
857 Task::ready(Ok(()))
858 }
859
860 fn list_threads(
861 &self,
862 cx: &mut App,
863 ) -> Option<watch::Receiver<Option<Vec<AcpThreadMetadata>>>> {
864 Some(self.0.read(cx).history.receiver())
865 }
866
867 fn load_thread(
868 self: Rc<Self>,
869 project: Entity<Project>,
870 _cwd: &Path,
871 session_id: acp::SessionId,
872 cx: &mut App,
873 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
874 let database = self.0.update(cx, |this, _| this.thread_database.clone());
875 cx.spawn(async move |cx| {
876 let db_thread = database
877 .load_thread(session_id.clone())
878 .await?
879 .context("no such thread found")?;
880
881 let acp_thread = cx.update(|cx| {
882 cx.new(|cx| {
883 acp_thread::AcpThread::new(
884 db_thread.title.clone(),
885 self.clone(),
886 project.clone(),
887 session_id.clone(),
888 cx,
889 )
890 })
891 })?;
892 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
893 let agent = self.0.clone();
894
895 // Create Thread
896 let thread = agent.update(cx, |agent, cx| {
897 let configured_model = LanguageModelRegistry::global(cx)
898 .update(cx, |registry, cx| {
899 db_thread
900 .model
901 .as_ref()
902 .and_then(|model| {
903 let model = SelectedModel {
904 provider: model.provider.clone().into(),
905 model: model.model.clone().into(),
906 };
907 registry.select_model(&model, cx)
908 })
909 .or_else(|| registry.default_model())
910 })
911 .context("no default model configured")?;
912
913 let model = agent
914 .models
915 .model_from_id(&LanguageModels::model_id(&configured_model.model))
916 .context("no model by id")?;
917
918 let thread = cx.new(|cx| {
919 let mut thread = Thread::from_db(
920 session_id,
921 db_thread,
922 project.clone(),
923 agent.project_context.clone(),
924 agent.context_server_registry.clone(),
925 action_log.clone(),
926 agent.templates.clone(),
927 model,
928 cx,
929 );
930 Self::register_tools(&mut thread, project, action_log, cx);
931 thread
932 });
933
934 anyhow::Ok(thread)
935 })??;
936
937 // Store the session
938 agent.update(cx, |agent, cx| {
939 agent.insert_session(thread.clone(), acp_thread.clone(), cx)
940 })?;
941
942 let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
943 cx.update(|cx| Self::handle_thread_events(events, acp_thread.downgrade(), cx))?
944 .await?;
945
946 Ok(acp_thread)
947 })
948 }
949
950 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
951 Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
952 }
953
954 fn prompt(
955 &self,
956 id: Option<acp_thread::UserMessageId>,
957 params: acp::PromptRequest,
958 cx: &mut App,
959 ) -> Task<Result<acp::PromptResponse>> {
960 let id = id.expect("UserMessageId is required");
961 let session_id = params.session_id.clone();
962 log::info!("Received prompt request for session: {}", session_id);
963 log::debug!("Prompt blocks count: {}", params.prompt.len());
964
965 self.run_turn(session_id, cx, |thread, cx| {
966 let content: Vec<UserMessageContent> = params
967 .prompt
968 .into_iter()
969 .map(Into::into)
970 .collect::<Vec<_>>();
971 log::info!("Converted prompt to message: {} chars", content.len());
972 log::debug!("Message id: {:?}", id);
973 log::debug!("Message content: {:?}", content);
974
975 Ok(thread.update(cx, |thread, cx| {
976 log::info!(
977 "Sending message to thread with model: {:?}",
978 thread.model().name()
979 );
980 thread.send(id, content, cx)
981 }))
982 })
983 }
984
985 fn resume(
986 &self,
987 session_id: &acp::SessionId,
988 _cx: &mut App,
989 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
990 Some(Rc::new(NativeAgentSessionResume {
991 connection: self.clone(),
992 session_id: session_id.clone(),
993 }) as _)
994 }
995
996 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
997 log::info!("Cancelling on session: {}", session_id);
998 self.0.update(cx, |agent, cx| {
999 if let Some(agent) = agent.sessions.get(session_id) {
1000 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1001 }
1002 });
1003 }
1004
1005 fn session_editor(
1006 &self,
1007 session_id: &agent_client_protocol::SessionId,
1008 cx: &mut App,
1009 ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
1010 self.0.update(cx, |agent, _cx| {
1011 agent
1012 .sessions
1013 .get(session_id)
1014 .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
1015 })
1016 }
1017
1018 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1019 self
1020 }
1021}
1022
/// Session editor that applies edits (currently truncation) directly to a native `Thread`.
struct NativeAgentSessionEditor(Entity<Thread>);
1024
1025impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
1026 fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1027 Task::ready(
1028 self.0
1029 .update(cx, |thread, cx| thread.truncate(message_id, cx)),
1030 )
1031 }
1032}
1033
/// Handle that resumes the turn of one native session through its connection.
struct NativeAgentSessionResume {
    /// Connection used to locate the session and run the resumed turn.
    connection: NativeAgentConnection,
    /// Session whose thread is resumed.
    session_id: acp::SessionId,
}
1038
1039impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1040 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1041 self.connection
1042 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1043 thread.update(cx, |thread, cx| thread.resume(cx))
1044 })
1045 }
1046}
1047
#[cfg(test)]
mod tests {
    use crate::{HistoryEntry, HistoryStore};

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    // Verifies that the agent's shared project context follows the project:
    // it starts empty, gains an entry when a worktree is created, and picks up
    // a `.rules` file once one appears in that worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project was created with no worktrees, so the context is empty.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        // Creating a worktree for `/a` adds it to the project context.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    // Verifies that `list_models` reports models grouped by provider; with
    // the test registry this is a single "Fake" group holding `fake/fake`.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    // Verifies that selecting a model for a session both switches the
    // session's thread to that model and overwrites the `default_model`
    // recorded in the user settings file (seeded here with `foo`/`bar`).
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().id().0, "fake");
            });
        });

        // Let the settings-file write triggered by the selection complete.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    // Verifies that a completed turn on a native thread shows up in the
    // `HistoryStore` registered against this connection, titled after the
    // user's first message.
    #[gpui::test]
    async fn test_history(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [], cx).await;

        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());
        let history_store = cx.new(|cx| {
            let mut store = HistoryStore::new(cx);
            store.register_agent(NATIVE_AGENT_SERVER_NAME.clone(), &connection, cx);
            store
        });

        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new(path!("")), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let selector = connection.model_selector().unwrap();

        // Resolve the model actually bound to this session (via the selector)
        // so the test can drive its fake completion stream directly.
        let model = cx
            .update(|cx| selector.selected_model(&session_id, cx))
            .await
            .expect("selected_model should succeed");
        let model = cx
            .update(|cx| agent.read(cx).models().model_from_id(&model.id))
            .unwrap();
        let model = model.as_fake();

        // Send a prompt, then have the fake model answer and end its stream so
        // the turn can finish.
        let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Hi", cx));
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();
        model.send_last_completion_stream_text_chunk("Hey");
        model.end_last_completion_stream();
        send.await.unwrap();
        // Advance past SAVE_THREAD_DEBOUNCE so the (presumably debounced)
        // thread save has happened before history is inspected.
        cx.executor().advance_clock(SAVE_THREAD_DEBOUNCE);

        let history = history_store.update(cx, |store, cx| store.entries(cx));
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].title(), "Hi");
    }

    // Shared setup: installs a test `SettingsStore` global, initializes project,
    // agent and language settings, and a test `LanguageModelRegistry`.
    // Logger initialization is best-effort because multiple tests may call it.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}