1use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
2use acp_thread::{MentionUri, UserMessageId};
3use action_log::ActionLog;
4use agent_client_protocol as acp;
5use agent_settings::{AgentProfileId, AgentSettings};
6use anyhow::{Context as _, Result, anyhow};
7use assistant_tool::adapt_schema_to_format;
8use cloud_llm_client::{CompletionIntent, CompletionMode};
9use collections::IndexMap;
10use fs::Fs;
11use futures::{
12 channel::{mpsc, oneshot},
13 stream::FuturesUnordered,
14};
15use gpui::{App, Context, Entity, SharedString, Task};
16use language_model::{
17 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
18 LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
19 LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
20 LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
21};
22use project::Project;
23use prompt_store::ProjectContext;
24use schemars::{JsonSchema, Schema};
25use serde::{Deserialize, Serialize};
26use settings::{Settings, update_settings_file};
27use smol::stream::StreamExt;
28use std::fmt::Write;
29use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
30use util::{ResultExt, markdown::MarkdownCodeBlock};
31
/// One entry in a [`Thread`]'s conversation history.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
    /// A message authored by the user.
    User(UserMessage),
    /// A response produced by the agent (rendered under `## Assistant`).
    Agent(AgentMessage),
}
37
38impl Message {
39 pub fn as_agent_message(&self) -> Option<&AgentMessage> {
40 match self {
41 Message::Agent(agent_message) => Some(agent_message),
42 _ => None,
43 }
44 }
45
46 pub fn to_markdown(&self) -> String {
47 match self {
48 Message::User(message) => message.to_markdown(),
49 Message::Agent(message) => message.to_markdown(),
50 }
51 }
52}
53
/// A message sent by the user, made of ordered content chunks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserMessage {
    /// Identifier used to locate this message, e.g. by `Thread::truncate`.
    pub id: UserMessageId,
    /// The message content, in the order the user provided it.
    pub content: Vec<UserMessageContent>,
}
59
/// A single chunk of a [`UserMessage`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UserMessageContent {
    /// Plain text typed by the user.
    Text(String),
    /// A reference to an external item (file, symbol, thread, or rule).
    /// Only the link is sent inline; `content` holds the item's text, which
    /// is attached to the request as context.
    Mention { uri: MentionUri, content: String },
    /// An image attached by the user.
    Image(LanguageModelImage),
}
66
67impl UserMessage {
68 pub fn to_markdown(&self) -> String {
69 let mut markdown = String::from("## User\n\n");
70
71 for content in &self.content {
72 match content {
73 UserMessageContent::Text(text) => {
74 markdown.push_str(text);
75 markdown.push('\n');
76 }
77 UserMessageContent::Image(_) => {
78 markdown.push_str("<image />\n");
79 }
80 UserMessageContent::Mention { uri, content } => {
81 if !content.is_empty() {
82 markdown.push_str(&format!("{}\n\n{}\n", uri.to_link(), content));
83 } else {
84 markdown.push_str(&format!("{}\n", uri.to_link()));
85 }
86 }
87 }
88 }
89
90 markdown
91 }
92
93 fn to_request(&self) -> LanguageModelRequestMessage {
94 let mut message = LanguageModelRequestMessage {
95 role: Role::User,
96 content: Vec::with_capacity(self.content.len()),
97 cache: false,
98 };
99
100 const OPEN_CONTEXT: &str = "<context>\n\
101 The following items were attached by the user. \
102 They are up-to-date and don't need to be re-read.\n\n";
103
104 const OPEN_FILES_TAG: &str = "<files>";
105 const OPEN_SYMBOLS_TAG: &str = "<symbols>";
106 const OPEN_THREADS_TAG: &str = "<threads>";
107 const OPEN_RULES_TAG: &str =
108 "<rules>\nThe user has specified the following rules that should be applied:\n";
109
110 let mut file_context = OPEN_FILES_TAG.to_string();
111 let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
112 let mut thread_context = OPEN_THREADS_TAG.to_string();
113 let mut rules_context = OPEN_RULES_TAG.to_string();
114
115 for chunk in &self.content {
116 let chunk = match chunk {
117 UserMessageContent::Text(text) => {
118 language_model::MessageContent::Text(text.clone())
119 }
120 UserMessageContent::Image(value) => {
121 language_model::MessageContent::Image(value.clone())
122 }
123 UserMessageContent::Mention { uri, content } => {
124 match uri {
125 MentionUri::File(path) | MentionUri::Symbol(path, _) => {
126 write!(
127 &mut symbol_context,
128 "\n{}",
129 MarkdownCodeBlock {
130 tag: &codeblock_tag(&path),
131 text: &content.to_string(),
132 }
133 )
134 .ok();
135 }
136 MentionUri::Thread(_session_id) => {
137 write!(&mut thread_context, "\n{}\n", content).ok();
138 }
139 MentionUri::Rule(_user_prompt_id) => {
140 write!(
141 &mut rules_context,
142 "\n{}",
143 MarkdownCodeBlock {
144 tag: "",
145 text: &content
146 }
147 )
148 .ok();
149 }
150 }
151
152 language_model::MessageContent::Text(uri.to_link())
153 }
154 };
155
156 message.content.push(chunk);
157 }
158
159 let len_before_context = message.content.len();
160
161 if file_context.len() > OPEN_FILES_TAG.len() {
162 file_context.push_str("</files>\n");
163 message
164 .content
165 .push(language_model::MessageContent::Text(file_context));
166 }
167
168 if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
169 symbol_context.push_str("</symbols>\n");
170 message
171 .content
172 .push(language_model::MessageContent::Text(symbol_context));
173 }
174
175 if thread_context.len() > OPEN_THREADS_TAG.len() {
176 thread_context.push_str("</threads>\n");
177 message
178 .content
179 .push(language_model::MessageContent::Text(thread_context));
180 }
181
182 if rules_context.len() > OPEN_RULES_TAG.len() {
183 rules_context.push_str("</user_rules>\n");
184 message
185 .content
186 .push(language_model::MessageContent::Text(rules_context));
187 }
188
189 if message.content.len() > len_before_context {
190 message.content.insert(
191 len_before_context,
192 language_model::MessageContent::Text(OPEN_CONTEXT.into()),
193 );
194 message
195 .content
196 .push(language_model::MessageContent::Text("</context>".into()));
197 }
198
199 message
200 }
201}
202
203impl AgentMessage {
204 pub fn to_markdown(&self) -> String {
205 let mut markdown = String::from("## Assistant\n\n");
206
207 for content in &self.content {
208 match content {
209 AgentMessageContent::Text(text) => {
210 markdown.push_str(text);
211 markdown.push('\n');
212 }
213 AgentMessageContent::Thinking { text, .. } => {
214 markdown.push_str("<think>");
215 markdown.push_str(text);
216 markdown.push_str("</think>\n");
217 }
218 AgentMessageContent::RedactedThinking(_) => {
219 markdown.push_str("<redacted_thinking />\n")
220 }
221 AgentMessageContent::Image(_) => {
222 markdown.push_str("<image />\n");
223 }
224 AgentMessageContent::ToolUse(tool_use) => {
225 markdown.push_str(&format!(
226 "**Tool Use**: {} (ID: {})\n",
227 tool_use.name, tool_use.id
228 ));
229 markdown.push_str(&format!(
230 "{}\n",
231 MarkdownCodeBlock {
232 tag: "json",
233 text: &format!("{:#}", tool_use.input)
234 }
235 ));
236 }
237 }
238 }
239
240 for tool_result in self.tool_results.values() {
241 markdown.push_str(&format!(
242 "**Tool Result**: {} (ID: {})\n\n",
243 tool_result.tool_name, tool_result.tool_use_id
244 ));
245 if tool_result.is_error {
246 markdown.push_str("**ERROR:**\n");
247 }
248
249 match &tool_result.content {
250 LanguageModelToolResultContent::Text(text) => {
251 writeln!(markdown, "{text}\n").ok();
252 }
253 LanguageModelToolResultContent::Image(_) => {
254 writeln!(markdown, "<image />\n").ok();
255 }
256 }
257
258 if let Some(output) = tool_result.output.as_ref() {
259 writeln!(
260 markdown,
261 "**Debug Output**:\n\n```json\n{}\n```\n",
262 serde_json::to_string_pretty(output).unwrap()
263 )
264 .unwrap();
265 }
266 }
267
268 markdown
269 }
270
271 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
272 let mut content = Vec::with_capacity(self.content.len());
273 for chunk in &self.content {
274 let chunk = match chunk {
275 AgentMessageContent::Text(text) => {
276 language_model::MessageContent::Text(text.clone())
277 }
278 AgentMessageContent::Thinking { text, signature } => {
279 language_model::MessageContent::Thinking {
280 text: text.clone(),
281 signature: signature.clone(),
282 }
283 }
284 AgentMessageContent::RedactedThinking(value) => {
285 language_model::MessageContent::RedactedThinking(value.clone())
286 }
287 AgentMessageContent::ToolUse(value) => {
288 language_model::MessageContent::ToolUse(value.clone())
289 }
290 AgentMessageContent::Image(value) => {
291 language_model::MessageContent::Image(value.clone())
292 }
293 };
294 content.push(chunk);
295 }
296
297 let mut messages = vec![LanguageModelRequestMessage {
298 role: Role::Assistant,
299 content,
300 cache: false,
301 }];
302
303 if !self.tool_results.is_empty() {
304 let mut tool_results = Vec::with_capacity(self.tool_results.len());
305 for tool_result in self.tool_results.values() {
306 tool_results.push(language_model::MessageContent::ToolResult(
307 tool_result.clone(),
308 ));
309 }
310 messages.push(LanguageModelRequestMessage {
311 role: Role::User,
312 content: tool_results,
313 cache: false,
314 });
315 }
316
317 messages
318 }
319}
320
/// A response produced by the agent, together with the results of any tools
/// it invoked.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct AgentMessage {
    /// Content chunks, in the order the model streamed them.
    pub content: Vec<AgentMessageContent>,
    /// Results of this message's tool uses, keyed by tool use id and iterated
    /// in insertion order. Sent to the model as a following user-role message.
    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
}
326
/// A single chunk of an [`AgentMessage`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentMessageContent {
    /// Visible response text.
    Text(String),
    /// Model reasoning text. The optional signature is echoed back unchanged
    /// when the message is included in later requests.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque redacted reasoning data, echoed back to the model as-is.
    RedactedThinking(String),
    /// An image included in the response.
    Image(LanguageModelImage),
    /// A request by the model to invoke a tool.
    ToolUse(LanguageModelToolUse),
}
338
/// An event delivered on the channel returned by [`Thread::send`].
#[derive(Debug)]
pub enum AgentResponseEvent {
    /// A chunk of streamed response text.
    Text(String),
    /// A chunk of streamed thinking text.
    Thinking(String),
    /// The model requested a new tool call.
    ToolCall(acp::ToolCall),
    /// An update (status, title, kind, input, or output) to an existing tool call.
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A request for the user to authorize a tool call; see [`ToolCallAuthorization`].
    ToolCallAuthorization(ToolCallAuthorization),
    /// The model stopped (end of turn, max tokens, or refusal).
    Stop(acp::StopReason),
}
348
/// A request for the user to authorize a tool call.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call awaiting authorization.
    pub tool_call: acp::ToolCall,
    /// The choices offered to the user.
    pub options: Vec<acp::PermissionOption>,
    /// Channel on which the id of the chosen option is sent back.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
355
/// A conversation between the user and an agent, including the tools the
/// agent may call and the turn currently in progress.
pub struct Thread {
    /// Committed conversation history.
    messages: Vec<Message>,
    /// Completion mode sent with every request.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<Task<()>>,
    /// The agent message currently being streamed; moved into `messages`
    /// when flushed.
    pending_agent_message: Option<AgentMessage>,
    /// Tools registered via `add_tool`, keyed by name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Source of additional tools provided by context servers.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Agent profile that decides which tools are enabled.
    profile_id: AgentProfileId,
    /// Project information rendered into the system prompt.
    project_context: Rc<RefCell<ProjectContext>>,
    /// Templates used to render the system prompt.
    templates: Arc<Templates>,
    /// Model used for completions.
    pub selected_model: Arc<dyn LanguageModel>,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
}
373
374impl Thread {
375 pub fn new(
376 project: Entity<Project>,
377 project_context: Rc<RefCell<ProjectContext>>,
378 context_server_registry: Entity<ContextServerRegistry>,
379 action_log: Entity<ActionLog>,
380 templates: Arc<Templates>,
381 default_model: Arc<dyn LanguageModel>,
382 cx: &mut Context<Self>,
383 ) -> Self {
384 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
385 Self {
386 messages: Vec::new(),
387 completion_mode: CompletionMode::Normal,
388 running_turn: None,
389 pending_agent_message: None,
390 tools: BTreeMap::default(),
391 context_server_registry,
392 profile_id,
393 project_context,
394 templates,
395 selected_model: default_model,
396 project,
397 action_log,
398 }
399 }
400
401 pub fn project(&self) -> &Entity<Project> {
402 &self.project
403 }
404
405 pub fn action_log(&self) -> &Entity<ActionLog> {
406 &self.action_log
407 }
408
409 pub fn set_mode(&mut self, mode: CompletionMode) {
410 self.completion_mode = mode;
411 }
412
413 #[cfg(any(test, feature = "test-support"))]
414 pub fn last_message(&self) -> Option<Message> {
415 if let Some(message) = self.pending_agent_message.clone() {
416 Some(Message::Agent(message))
417 } else {
418 self.messages.last().cloned()
419 }
420 }
421
422 pub fn add_tool(&mut self, tool: impl AgentTool) {
423 self.tools.insert(tool.name(), tool.erase());
424 }
425
426 pub fn remove_tool(&mut self, name: &str) -> bool {
427 self.tools.remove(name).is_some()
428 }
429
430 pub fn set_profile(&mut self, profile_id: AgentProfileId) {
431 self.profile_id = profile_id;
432 }
433
434 pub fn cancel(&mut self) {
435 // TODO: do we need to emit a stop::cancel for ACP?
436 self.running_turn.take();
437 self.flush_pending_agent_message();
438 }
439
440 pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
441 self.cancel();
442 let Some(position) = self.messages.iter().position(
443 |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
444 ) else {
445 return Err(anyhow!("Message not found"));
446 };
447 self.messages.truncate(position);
448 Ok(())
449 }
450
451 /// Sending a message results in the model streaming a response, which could include tool calls.
452 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
453 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
454 pub fn send<T>(
455 &mut self,
456 message_id: UserMessageId,
457 content: impl IntoIterator<Item = T>,
458 cx: &mut Context<Self>,
459 ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>
460 where
461 T: Into<UserMessageContent>,
462 {
463 let model = self.selected_model.clone();
464 let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
465 log::info!("Thread::send called with model: {:?}", model.name());
466 log::debug!("Thread::send content: {:?}", content);
467
468 cx.notify();
469 let (events_tx, events_rx) =
470 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
471 let event_stream = AgentResponseEventStream(events_tx);
472
473 let user_message_ix = self.messages.len();
474 self.messages.push(Message::User(UserMessage {
475 id: message_id,
476 content,
477 }));
478 log::info!("Total messages in thread: {}", self.messages.len());
479 self.running_turn = Some(cx.spawn(async move |thread, cx| {
480 log::info!("Starting agent turn execution");
481 let turn_result = async {
482 // Perform one request, then keep looping if the model makes tool calls.
483 let mut completion_intent = CompletionIntent::UserPrompt;
484 'outer: loop {
485 log::debug!(
486 "Building completion request with intent: {:?}",
487 completion_intent
488 );
489 let request = thread.update(cx, |thread, cx| {
490 thread.build_completion_request(completion_intent, cx)
491 })?;
492
493 // Stream events, appending to messages and collecting up tool uses.
494 log::info!("Calling model.stream_completion");
495 let mut events = model.stream_completion(request, cx).await?;
496 log::debug!("Stream completion started successfully");
497
498 let mut tool_uses = FuturesUnordered::new();
499 while let Some(event) = events.next().await {
500 match event {
501 Ok(LanguageModelCompletionEvent::Stop(reason)) => {
502 event_stream.send_stop(reason);
503 if reason == StopReason::Refusal {
504 thread.update(cx, |thread, _cx| {
505 thread.pending_agent_message = None;
506 thread.messages.truncate(user_message_ix);
507 })?;
508 break 'outer;
509 }
510 }
511 Ok(event) => {
512 log::trace!("Received completion event: {:?}", event);
513 thread
514 .update(cx, |thread, cx| {
515 tool_uses.extend(thread.handle_streamed_completion_event(
516 event,
517 &event_stream,
518 cx,
519 ));
520 })
521 .ok();
522 }
523 Err(error) => {
524 log::error!("Error in completion stream: {:?}", error);
525 event_stream.send_error(error);
526 break;
527 }
528 }
529 }
530
531 // If there are no tool uses, the turn is done.
532 if tool_uses.is_empty() {
533 log::info!("No tool uses found, completing turn");
534 break;
535 }
536 log::info!("Found {} tool uses to execute", tool_uses.len());
537
538 // As tool results trickle in, insert them in the last user
539 // message so that they can be sent on the next tick of the
540 // agentic loop.
541 while let Some(tool_result) = tool_uses.next().await {
542 log::info!("Tool finished {:?}", tool_result);
543
544 event_stream.update_tool_call_fields(
545 &tool_result.tool_use_id,
546 acp::ToolCallUpdateFields {
547 status: Some(if tool_result.is_error {
548 acp::ToolCallStatus::Failed
549 } else {
550 acp::ToolCallStatus::Completed
551 }),
552 raw_output: tool_result.output.clone(),
553 ..Default::default()
554 },
555 );
556 thread
557 .update(cx, |thread, _cx| {
558 thread
559 .pending_agent_message()
560 .tool_results
561 .insert(tool_result.tool_use_id.clone(), tool_result);
562 })
563 .ok();
564 }
565
566 thread.update(cx, |thread, _cx| thread.flush_pending_agent_message())?;
567
568 completion_intent = CompletionIntent::ToolResults;
569 }
570
571 Ok(())
572 }
573 .await;
574
575 thread
576 .update(cx, |thread, _cx| thread.flush_pending_agent_message())
577 .ok();
578
579 if let Err(error) = turn_result {
580 log::error!("Turn execution failed: {:?}", error);
581 event_stream.send_error(error);
582 } else {
583 log::info!("Turn execution completed successfully");
584 }
585 }));
586 events_rx
587 }
588
589 pub fn build_system_message(&self) -> LanguageModelRequestMessage {
590 log::debug!("Building system message");
591 let prompt = SystemPromptTemplate {
592 project: &self.project_context.borrow(),
593 available_tools: self.tools.keys().cloned().collect(),
594 }
595 .render(&self.templates)
596 .context("failed to build system prompt")
597 .expect("Invalid template");
598 log::debug!("System message built");
599 LanguageModelRequestMessage {
600 role: Role::System,
601 content: vec![prompt.into()],
602 cache: true,
603 }
604 }
605
606 /// A helper method that's called on every streamed completion event.
607 /// Returns an optional tool result task, which the main agentic loop in
608 /// send will send back to the model when it resolves.
609 fn handle_streamed_completion_event(
610 &mut self,
611 event: LanguageModelCompletionEvent,
612 event_stream: &AgentResponseEventStream,
613 cx: &mut Context<Self>,
614 ) -> Option<Task<LanguageModelToolResult>> {
615 log::trace!("Handling streamed completion event: {:?}", event);
616 use LanguageModelCompletionEvent::*;
617
618 match event {
619 StartMessage { .. } => {
620 self.messages.push(Message::Agent(AgentMessage::default()));
621 }
622 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
623 Thinking { text, signature } => {
624 self.handle_thinking_event(text, signature, event_stream, cx)
625 }
626 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
627 ToolUse(tool_use) => {
628 return self.handle_tool_use_event(tool_use, event_stream, cx);
629 }
630 ToolUseJsonParseError {
631 id,
632 tool_name,
633 raw_input,
634 json_parse_error,
635 } => {
636 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
637 id,
638 tool_name,
639 raw_input,
640 json_parse_error,
641 )));
642 }
643 UsageUpdate(_) | StatusUpdate(_) => {}
644 Stop(_) => unreachable!(),
645 }
646
647 None
648 }
649
650 fn handle_text_event(
651 &mut self,
652 new_text: String,
653 events_stream: &AgentResponseEventStream,
654 cx: &mut Context<Self>,
655 ) {
656 events_stream.send_text(&new_text);
657
658 let last_message = self.pending_agent_message();
659 if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
660 text.push_str(&new_text);
661 } else {
662 last_message
663 .content
664 .push(AgentMessageContent::Text(new_text));
665 }
666
667 cx.notify();
668 }
669
670 fn handle_thinking_event(
671 &mut self,
672 new_text: String,
673 new_signature: Option<String>,
674 event_stream: &AgentResponseEventStream,
675 cx: &mut Context<Self>,
676 ) {
677 event_stream.send_thinking(&new_text);
678
679 let last_message = self.pending_agent_message();
680 if let Some(AgentMessageContent::Thinking { text, signature }) =
681 last_message.content.last_mut()
682 {
683 text.push_str(&new_text);
684 *signature = new_signature.or(signature.take());
685 } else {
686 last_message.content.push(AgentMessageContent::Thinking {
687 text: new_text,
688 signature: new_signature,
689 });
690 }
691
692 cx.notify();
693 }
694
695 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
696 let last_message = self.pending_agent_message();
697 last_message
698 .content
699 .push(AgentMessageContent::RedactedThinking(data));
700 cx.notify();
701 }
702
703 fn handle_tool_use_event(
704 &mut self,
705 tool_use: LanguageModelToolUse,
706 event_stream: &AgentResponseEventStream,
707 cx: &mut Context<Self>,
708 ) -> Option<Task<LanguageModelToolResult>> {
709 cx.notify();
710
711 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
712 let mut title = SharedString::from(&tool_use.name);
713 let mut kind = acp::ToolKind::Other;
714 if let Some(tool) = tool.as_ref() {
715 title = tool.initial_title(tool_use.input.clone());
716 kind = tool.kind();
717 }
718
719 // Ensure the last message ends in the current tool use
720 let last_message = self.pending_agent_message();
721 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
722 if let AgentMessageContent::ToolUse(last_tool_use) = content {
723 if last_tool_use.id == tool_use.id {
724 *last_tool_use = tool_use.clone();
725 false
726 } else {
727 true
728 }
729 } else {
730 true
731 }
732 });
733
734 if push_new_tool_use {
735 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
736 last_message
737 .content
738 .push(AgentMessageContent::ToolUse(tool_use.clone()));
739 } else {
740 event_stream.update_tool_call_fields(
741 &tool_use.id,
742 acp::ToolCallUpdateFields {
743 title: Some(title.into()),
744 kind: Some(kind),
745 raw_input: Some(tool_use.input.clone()),
746 ..Default::default()
747 },
748 );
749 }
750
751 if !tool_use.is_input_complete {
752 return None;
753 }
754
755 let Some(tool) = tool else {
756 let content = format!("No tool named {} exists", tool_use.name);
757 return Some(Task::ready(LanguageModelToolResult {
758 content: LanguageModelToolResultContent::Text(Arc::from(content)),
759 tool_use_id: tool_use.id,
760 tool_name: tool_use.name,
761 is_error: true,
762 output: None,
763 }));
764 };
765
766 let fs = self.project.read(cx).fs().clone();
767 let tool_event_stream =
768 ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
769 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
770 status: Some(acp::ToolCallStatus::InProgress),
771 ..Default::default()
772 });
773 let supports_images = self.selected_model.supports_images();
774 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
775 Some(cx.foreground_executor().spawn(async move {
776 let tool_result = tool_result.await.and_then(|output| {
777 if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
778 if !supports_images {
779 return Err(anyhow!(
780 "Attempted to read an image, but this model doesn't support it.",
781 ));
782 }
783 }
784 Ok(output)
785 });
786
787 match tool_result {
788 Ok(output) => LanguageModelToolResult {
789 tool_use_id: tool_use.id,
790 tool_name: tool_use.name,
791 is_error: false,
792 content: output.llm_output,
793 output: Some(output.raw_output),
794 },
795 Err(error) => LanguageModelToolResult {
796 tool_use_id: tool_use.id,
797 tool_name: tool_use.name,
798 is_error: true,
799 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
800 output: None,
801 },
802 }
803 }))
804 }
805
806 fn handle_tool_use_json_parse_error_event(
807 &mut self,
808 tool_use_id: LanguageModelToolUseId,
809 tool_name: Arc<str>,
810 raw_input: Arc<str>,
811 json_parse_error: String,
812 ) -> LanguageModelToolResult {
813 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
814 LanguageModelToolResult {
815 tool_use_id,
816 tool_name,
817 is_error: true,
818 content: LanguageModelToolResultContent::Text(tool_output.into()),
819 output: Some(serde_json::Value::String(raw_input.to_string())),
820 }
821 }
822
823 fn pending_agent_message(&mut self) -> &mut AgentMessage {
824 self.pending_agent_message.get_or_insert_default()
825 }
826
827 fn flush_pending_agent_message(&mut self) {
828 let Some(mut message) = self.pending_agent_message.take() else {
829 return;
830 };
831
832 for content in &message.content {
833 let AgentMessageContent::ToolUse(tool_use) = content else {
834 continue;
835 };
836
837 if !message.tool_results.contains_key(&tool_use.id) {
838 message.tool_results.insert(
839 tool_use.id.clone(),
840 LanguageModelToolResult {
841 tool_use_id: tool_use.id.clone(),
842 tool_name: tool_use.name.clone(),
843 is_error: true,
844 content: LanguageModelToolResultContent::Text(
845 "Tool canceled by user".into(),
846 ),
847 output: None,
848 },
849 );
850 }
851 }
852
853 self.messages.push(Message::Agent(message));
854 }
855
856 pub(crate) fn build_completion_request(
857 &self,
858 completion_intent: CompletionIntent,
859 cx: &mut App,
860 ) -> LanguageModelRequest {
861 log::debug!("Building completion request");
862 log::debug!("Completion intent: {:?}", completion_intent);
863 log::debug!("Completion mode: {:?}", self.completion_mode);
864
865 let messages = self.build_request_messages();
866 log::info!("Request will include {} messages", messages.len());
867
868 let tools = if let Some(tools) = self.tools(cx).log_err() {
869 tools
870 .filter_map(|tool| {
871 let tool_name = tool.name().to_string();
872 log::trace!("Including tool: {}", tool_name);
873 Some(LanguageModelRequestTool {
874 name: tool_name,
875 description: tool.description().to_string(),
876 input_schema: tool
877 .input_schema(self.selected_model.tool_input_format())
878 .log_err()?,
879 })
880 })
881 .collect()
882 } else {
883 Vec::new()
884 };
885
886 log::info!("Request includes {} tools", tools.len());
887
888 let request = LanguageModelRequest {
889 thread_id: None,
890 prompt_id: None,
891 intent: Some(completion_intent),
892 mode: Some(self.completion_mode),
893 messages,
894 tools,
895 tool_choice: None,
896 stop: Vec::new(),
897 temperature: None,
898 thinking_allowed: true,
899 };
900
901 log::debug!("Completion request built successfully");
902 request
903 }
904
905 fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
906 let profile = AgentSettings::get_global(cx)
907 .profiles
908 .get(&self.profile_id)
909 .context("profile not found")?;
910 let provider_id = self.selected_model.provider_id();
911
912 Ok(self
913 .tools
914 .iter()
915 .filter(move |(_, tool)| tool.supported_provider(&provider_id))
916 .filter_map(|(tool_name, tool)| {
917 if profile.is_tool_enabled(tool_name) {
918 Some(tool)
919 } else {
920 None
921 }
922 })
923 .chain(self.context_server_registry.read(cx).servers().flat_map(
924 |(server_id, tools)| {
925 tools.iter().filter_map(|(tool_name, tool)| {
926 if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
927 Some(tool)
928 } else {
929 None
930 }
931 })
932 },
933 )))
934 }
935
936 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
937 log::trace!(
938 "Building request messages from {} thread messages",
939 self.messages.len()
940 );
941 let mut messages = vec![self.build_system_message()];
942 for message in &self.messages {
943 match message {
944 Message::User(message) => messages.push(message.to_request()),
945 Message::Agent(message) => messages.extend(message.to_request()),
946 }
947 }
948
949 if let Some(message) = self.pending_agent_message.as_ref() {
950 messages.extend(message.to_request());
951 }
952
953 messages
954 }
955
956 pub fn to_markdown(&self) -> String {
957 let mut markdown = String::new();
958 for (ix, message) in self.messages.iter().enumerate() {
959 if ix > 0 {
960 markdown.push('\n');
961 }
962 markdown.push_str(&message.to_markdown());
963 }
964
965 if let Some(message) = self.pending_agent_message.as_ref() {
966 markdown.push('\n');
967 markdown.push_str(&message.to_markdown());
968 }
969
970 markdown
971 }
972}
973
974pub trait AgentTool
975where
976 Self: 'static + Sized,
977{
978 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
979 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
980
981 fn name(&self) -> SharedString;
982
983 fn description(&self) -> SharedString {
984 let schema = schemars::schema_for!(Self::Input);
985 SharedString::new(
986 schema
987 .get("description")
988 .and_then(|description| description.as_str())
989 .unwrap_or_default(),
990 )
991 }
992
993 fn kind(&self) -> acp::ToolKind;
994
995 /// The initial tool title to display. Can be updated during the tool run.
996 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
997
998 /// Returns the JSON schema that describes the tool's input.
999 fn input_schema(&self) -> Schema {
1000 schemars::schema_for!(Self::Input)
1001 }
1002
1003 /// Some tools rely on a provider for the underlying billing or other reasons.
1004 /// Allow the tool to check if they are compatible, or should be filtered out.
1005 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1006 true
1007 }
1008
1009 /// Runs the tool with the provided input.
1010 fn run(
1011 self: Arc<Self>,
1012 input: Self::Input,
1013 event_stream: ToolCallEventStream,
1014 cx: &mut App,
1015 ) -> Task<Result<Self::Output>>;
1016
1017 fn erase(self) -> Arc<dyn AnyAgentTool> {
1018 Arc::new(Erased(Arc::new(self)))
1019 }
1020}
1021
/// Type-erasing wrapper that adapts an [`AgentTool`] to [`AnyAgentTool`].
pub struct Erased<T>(T);
1023
/// The output of a type-erased tool run.
pub struct AgentToolOutput {
    /// Content sent back to the model as the tool result.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's typed output serialized to JSON; reported as the tool
    /// call's raw output.
    pub raw_output: serde_json::Value,
}
1028
/// Object-safe counterpart of [`AgentTool`] that takes JSON input, so tools of
/// different types can be stored together and invoked by the model.
pub trait AnyAgentTool {
    /// The name under which the tool is exposed to the model.
    fn name(&self) -> SharedString;
    /// A description of the tool for the model.
    fn description(&self) -> SharedString;
    /// The kind of tool, reported alongside its tool calls.
    fn kind(&self) -> acp::ToolKind;
    /// The title to display for a call with the given (possibly still streaming) input.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// The tool's input JSON schema, adapted to `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether the tool may be offered when using models from this provider.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool with JSON `input`.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
}
1045
1046impl<T> AnyAgentTool for Erased<Arc<T>>
1047where
1048 T: AgentTool,
1049{
1050 fn name(&self) -> SharedString {
1051 self.0.name()
1052 }
1053
1054 fn description(&self) -> SharedString {
1055 self.0.description()
1056 }
1057
1058 fn kind(&self) -> agent_client_protocol::ToolKind {
1059 self.0.kind()
1060 }
1061
1062 fn initial_title(&self, input: serde_json::Value) -> SharedString {
1063 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1064 self.0.initial_title(parsed_input)
1065 }
1066
1067 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1068 let mut json = serde_json::to_value(self.0.input_schema())?;
1069 adapt_schema_to_format(&mut json, format)?;
1070 Ok(json)
1071 }
1072
1073 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1074 self.0.supported_provider(provider)
1075 }
1076
1077 fn run(
1078 self: Arc<Self>,
1079 input: serde_json::Value,
1080 event_stream: ToolCallEventStream,
1081 cx: &mut App,
1082 ) -> Task<Result<AgentToolOutput>> {
1083 cx.spawn(async move |cx| {
1084 let input = serde_json::from_value(input)?;
1085 let output = cx
1086 .update(|cx| self.0.clone().run(input, event_stream, cx))?
1087 .await?;
1088 let raw_output = serde_json::to_value(&output)?;
1089 Ok(AgentToolOutput {
1090 llm_output: output.into(),
1091 raw_output,
1092 })
1093 })
1094 }
1095}
1096
/// Sending half of the channel returned by [`Thread::send`], with helpers that
/// turn model and tool activity into [`AgentResponseEvent`]s. Sends are
/// best-effort: send failures are ignored.
#[derive(Clone)]
struct AgentResponseEventStream(
    mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
1101
1102impl AgentResponseEventStream {
1103 fn send_text(&self, text: &str) {
1104 self.0
1105 .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
1106 .ok();
1107 }
1108
1109 fn send_thinking(&self, text: &str) {
1110 self.0
1111 .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
1112 .ok();
1113 }
1114
1115 fn send_tool_call(
1116 &self,
1117 id: &LanguageModelToolUseId,
1118 title: SharedString,
1119 kind: acp::ToolKind,
1120 input: serde_json::Value,
1121 ) {
1122 self.0
1123 .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
1124 id,
1125 title.to_string(),
1126 kind,
1127 input,
1128 ))))
1129 .ok();
1130 }
1131
1132 fn initial_tool_call(
1133 id: &LanguageModelToolUseId,
1134 title: String,
1135 kind: acp::ToolKind,
1136 input: serde_json::Value,
1137 ) -> acp::ToolCall {
1138 acp::ToolCall {
1139 id: acp::ToolCallId(id.to_string().into()),
1140 title,
1141 kind,
1142 status: acp::ToolCallStatus::Pending,
1143 content: vec![],
1144 locations: vec![],
1145 raw_input: Some(input),
1146 raw_output: None,
1147 }
1148 }
1149
1150 fn update_tool_call_fields(
1151 &self,
1152 tool_use_id: &LanguageModelToolUseId,
1153 fields: acp::ToolCallUpdateFields,
1154 ) {
1155 self.0
1156 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1157 acp::ToolCallUpdate {
1158 id: acp::ToolCallId(tool_use_id.to_string().into()),
1159 fields,
1160 }
1161 .into(),
1162 )))
1163 .ok();
1164 }
1165
1166 fn send_stop(&self, reason: StopReason) {
1167 match reason {
1168 StopReason::EndTurn => {
1169 self.0
1170 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
1171 .ok();
1172 }
1173 StopReason::MaxTokens => {
1174 self.0
1175 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
1176 .ok();
1177 }
1178 StopReason::Refusal => {
1179 self.0
1180 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
1181 .ok();
1182 }
1183 StopReason::ToolUse => {}
1184 }
1185 }
1186
1187 fn send_error(&self, error: LanguageModelCompletionError) {
1188 self.0.unbounded_send(Err(error)).ok();
1189 }
1190}
1191
/// Per-tool-call handle that a running tool uses to report progress (field
/// updates, diffs, terminals) and to request user authorization.
///
/// Every event it emits is tagged with the originating tool use id.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// Id of the tool use this stream reports on.
    tool_use_id: LanguageModelToolUseId,
    /// Tool kind, echoed into the tool call sent with authorization requests.
    kind: acp::ToolKind,
    /// The tool's JSON input, echoed as `raw_input` in authorization requests.
    input: serde_json::Value,
    /// Channel back to the agent's response event stream.
    stream: AgentResponseEventStream,
    /// Filesystem used to persist an "Always Allow" choice; `None` skips persisting it.
    fs: Option<Arc<dyn Fs>>,
}
1200
1201impl ToolCallEventStream {
1202 #[cfg(test)]
1203 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
1204 let (events_tx, events_rx) =
1205 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
1206
1207 let stream = ToolCallEventStream::new(
1208 &LanguageModelToolUse {
1209 id: "test_id".into(),
1210 name: "test_tool".into(),
1211 raw_input: String::new(),
1212 input: serde_json::Value::Null,
1213 is_input_complete: true,
1214 },
1215 acp::ToolKind::Other,
1216 AgentResponseEventStream(events_tx),
1217 None,
1218 );
1219
1220 (stream, ToolCallEventStreamReceiver(events_rx))
1221 }
1222
1223 fn new(
1224 tool_use: &LanguageModelToolUse,
1225 kind: acp::ToolKind,
1226 stream: AgentResponseEventStream,
1227 fs: Option<Arc<dyn Fs>>,
1228 ) -> Self {
1229 Self {
1230 tool_use_id: tool_use.id.clone(),
1231 kind,
1232 input: tool_use.input.clone(),
1233 stream,
1234 fs,
1235 }
1236 }
1237
1238 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1239 self.stream
1240 .update_tool_call_fields(&self.tool_use_id, fields);
1241 }
1242
1243 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1244 self.stream
1245 .0
1246 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1247 acp_thread::ToolCallUpdateDiff {
1248 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1249 diff,
1250 }
1251 .into(),
1252 )))
1253 .ok();
1254 }
1255
1256 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1257 self.stream
1258 .0
1259 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1260 acp_thread::ToolCallUpdateTerminal {
1261 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1262 terminal,
1263 }
1264 .into(),
1265 )))
1266 .ok();
1267 }
1268
1269 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1270 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1271 return Task::ready(Ok(()));
1272 }
1273
1274 let (response_tx, response_rx) = oneshot::channel();
1275 self.stream
1276 .0
1277 .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
1278 ToolCallAuthorization {
1279 tool_call: AgentResponseEventStream::initial_tool_call(
1280 &self.tool_use_id,
1281 title.into(),
1282 self.kind.clone(),
1283 self.input.clone(),
1284 ),
1285 options: vec![
1286 acp::PermissionOption {
1287 id: acp::PermissionOptionId("always_allow".into()),
1288 name: "Always Allow".into(),
1289 kind: acp::PermissionOptionKind::AllowAlways,
1290 },
1291 acp::PermissionOption {
1292 id: acp::PermissionOptionId("allow".into()),
1293 name: "Allow".into(),
1294 kind: acp::PermissionOptionKind::AllowOnce,
1295 },
1296 acp::PermissionOption {
1297 id: acp::PermissionOptionId("deny".into()),
1298 name: "Deny".into(),
1299 kind: acp::PermissionOptionKind::RejectOnce,
1300 },
1301 ],
1302 response: response_tx,
1303 },
1304 )))
1305 .ok();
1306 let fs = self.fs.clone();
1307 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1308 "always_allow" => {
1309 if let Some(fs) = fs.clone() {
1310 cx.update(|cx| {
1311 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1312 settings.set_always_allow_tool_actions(true);
1313 });
1314 })?;
1315 }
1316
1317 Ok(())
1318 }
1319 "allow" => Ok(()),
1320 _ => Err(anyhow!("Permission to run tool denied by user")),
1321 })
1322 }
1323}
1324
/// Test-only receiving half of the channel behind a [`ToolCallEventStream`]
/// created by [`ToolCallEventStream::test`], used to assert on emitted events.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(
    mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
1329
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it as an authorization request.
    ///
    /// # Panics
    ///
    /// Panics if the next item is any other event, an error, or the end of the stream.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) => auth,
            other => panic!("Expected ToolCallAuthorization but got: {:?}", other),
        }
    }

    /// Awaits the next event and returns the terminal from a terminal update.
    ///
    /// # Panics
    ///
    /// Panics if the next item is any other event, an error, or the end of the stream.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            other => panic!("Expected terminal but got: {:?}", other),
        }
    }
}
1353
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;

    /// Exposes the wrapped receiver so tests can call its methods directly.
    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1362
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    /// Exposes the wrapped receiver mutably (e.g. for polling) in tests.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1369
/// Builds a code-block tag for the file at `full_path`.
///
/// The tag is the file extension followed by a space and the displayed path,
/// e.g. `src/main.rs` yields `"rs src/main.rs"`. When the path has no extension,
/// or the extension is not valid UTF-8, only the displayed path is returned.
fn codeblock_tag(full_path: &Path) -> String {
    let shown = full_path.display();
    match full_path.extension().and_then(|ext| ext.to_str()) {
        Some(ext) => format!("{ext} {shown}"),
        None => shown.to_string(),
    }
}
1381
1382impl From<&str> for UserMessageContent {
1383 fn from(text: &str) -> Self {
1384 Self::Text(text.into())
1385 }
1386}
1387
1388impl From<acp::ContentBlock> for UserMessageContent {
1389 fn from(value: acp::ContentBlock) -> Self {
1390 match value {
1391 acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
1392 acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
1393 acp::ContentBlock::Audio(_) => {
1394 // TODO
1395 Self::Text("[audio]".to_string())
1396 }
1397 acp::ContentBlock::ResourceLink(resource_link) => {
1398 match MentionUri::parse(&resource_link.uri) {
1399 Ok(uri) => Self::Mention {
1400 uri,
1401 content: String::new(),
1402 },
1403 Err(err) => {
1404 log::error!("Failed to parse mention link: {}", err);
1405 Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
1406 }
1407 }
1408 }
1409 acp::ContentBlock::Resource(resource) => match resource.resource {
1410 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1411 match MentionUri::parse(&resource.uri) {
1412 Ok(uri) => Self::Mention {
1413 uri,
1414 content: resource.text,
1415 },
1416 Err(err) => {
1417 log::error!("Failed to parse mention link: {}", err);
1418 Self::Text(
1419 MarkdownCodeBlock {
1420 tag: &resource.uri,
1421 text: &resource.text,
1422 }
1423 .to_string(),
1424 )
1425 }
1426 }
1427 }
1428 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1429 // TODO
1430 Self::Text("[blob]".to_string())
1431 }
1432 },
1433 }
1434 }
1435}
1436
1437fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1438 LanguageModelImage {
1439 source: image_content.data.into(),
1440 // TODO: make this optional?
1441 size: gpui::Size::new(0.into(), 0.into()),
1442 }
1443}