1use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
2use acp_thread::{MentionUri, UserMessageId};
3use action_log::ActionLog;
4use agent_client_protocol as acp;
5use agent_settings::{AgentProfileId, AgentSettings};
6use anyhow::{Context as _, Result, anyhow};
7use assistant_tool::adapt_schema_to_format;
8use cloud_llm_client::{CompletionIntent, CompletionMode};
9use collections::IndexMap;
10use fs::Fs;
11use futures::{
12 channel::{mpsc, oneshot},
13 stream::FuturesUnordered,
14};
15use gpui::{App, Context, Entity, SharedString, Task};
16use language_model::{
17 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
18 LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
19 LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
20 LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
21};
22use project::Project;
23use prompt_store::ProjectContext;
24use schemars::{JsonSchema, Schema};
25use serde::{Deserialize, Serialize};
26use settings::{Settings, update_settings_file};
27use smol::stream::StreamExt;
28use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
29use std::{fmt::Write, ops::Range};
30use util::{ResultExt, markdown::MarkdownCodeBlock};
31
/// A single entry in a thread's conversation history.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
    /// A message authored by the user.
    User(UserMessage),
    /// A message produced by the agent.
    Agent(AgentMessage),
}
37
38impl Message {
39 pub fn as_agent_message(&self) -> Option<&AgentMessage> {
40 match self {
41 Message::Agent(agent_message) => Some(agent_message),
42 _ => None,
43 }
44 }
45
46 pub fn to_markdown(&self) -> String {
47 match self {
48 Message::User(message) => message.to_markdown(),
49 Message::Agent(message) => message.to_markdown(),
50 }
51 }
52}
53
/// A user-authored message made up of an ordered list of content chunks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserMessage {
    /// Identifier of this message; `Thread::truncate` locates messages by it.
    pub id: UserMessageId,
    /// Content chunks, in the order they are rendered and sent.
    pub content: Vec<UserMessageContent>,
}
59
/// One chunk of a [`UserMessage`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UserMessageContent {
    /// Plain text.
    Text(String),
    /// A mention identified by `uri`, together with the text `content`
    /// attached for it.
    Mention { uri: MentionUri, content: String },
    /// An image.
    Image(LanguageModelImage),
}
66
67impl UserMessage {
68 pub fn to_markdown(&self) -> String {
69 let mut markdown = String::from("## User\n\n");
70
71 for content in &self.content {
72 match content {
73 UserMessageContent::Text(text) => {
74 markdown.push_str(text);
75 markdown.push('\n');
76 }
77 UserMessageContent::Image(_) => {
78 markdown.push_str("<image />\n");
79 }
80 UserMessageContent::Mention { uri, content } => {
81 if !content.is_empty() {
82 let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
83 } else {
84 let _ = write!(&mut markdown, "{}\n", uri.as_link());
85 }
86 }
87 }
88 }
89
90 markdown
91 }
92
93 fn to_request(&self) -> LanguageModelRequestMessage {
94 let mut message = LanguageModelRequestMessage {
95 role: Role::User,
96 content: Vec::with_capacity(self.content.len()),
97 cache: false,
98 };
99
100 const OPEN_CONTEXT: &str = "<context>\n\
101 The following items were attached by the user. \
102 They are up-to-date and don't need to be re-read.\n\n";
103
104 const OPEN_FILES_TAG: &str = "<files>";
105 const OPEN_SYMBOLS_TAG: &str = "<symbols>";
106 const OPEN_THREADS_TAG: &str = "<threads>";
107 const OPEN_FETCH_TAG: &str = "<fetched_urls>";
108 const OPEN_RULES_TAG: &str =
109 "<rules>\nThe user has specified the following rules that should be applied:\n";
110
111 let mut file_context = OPEN_FILES_TAG.to_string();
112 let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
113 let mut thread_context = OPEN_THREADS_TAG.to_string();
114 let mut fetch_context = OPEN_FETCH_TAG.to_string();
115 let mut rules_context = OPEN_RULES_TAG.to_string();
116
117 for chunk in &self.content {
118 let chunk = match chunk {
119 UserMessageContent::Text(text) => {
120 language_model::MessageContent::Text(text.clone())
121 }
122 UserMessageContent::Image(value) => {
123 language_model::MessageContent::Image(value.clone())
124 }
125 UserMessageContent::Mention { uri, content } => {
126 match uri {
127 MentionUri::File(path) => {
128 write!(
129 &mut symbol_context,
130 "\n{}",
131 MarkdownCodeBlock {
132 tag: &codeblock_tag(&path, None),
133 text: &content.to_string(),
134 }
135 )
136 .ok();
137 }
138 MentionUri::Symbol {
139 path, line_range, ..
140 }
141 | MentionUri::Selection {
142 path, line_range, ..
143 } => {
144 write!(
145 &mut rules_context,
146 "\n{}",
147 MarkdownCodeBlock {
148 tag: &codeblock_tag(&path, Some(line_range)),
149 text: &content
150 }
151 )
152 .ok();
153 }
154 MentionUri::Thread { .. } => {
155 write!(&mut thread_context, "\n{}\n", content).ok();
156 }
157 MentionUri::TextThread { .. } => {
158 write!(&mut thread_context, "\n{}\n", content).ok();
159 }
160 MentionUri::Rule { .. } => {
161 write!(
162 &mut rules_context,
163 "\n{}",
164 MarkdownCodeBlock {
165 tag: "",
166 text: &content
167 }
168 )
169 .ok();
170 }
171 MentionUri::Fetch { url } => {
172 write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
173 }
174 }
175
176 language_model::MessageContent::Text(uri.as_link().to_string())
177 }
178 };
179
180 message.content.push(chunk);
181 }
182
183 let len_before_context = message.content.len();
184
185 if file_context.len() > OPEN_FILES_TAG.len() {
186 file_context.push_str("</files>\n");
187 message
188 .content
189 .push(language_model::MessageContent::Text(file_context));
190 }
191
192 if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
193 symbol_context.push_str("</symbols>\n");
194 message
195 .content
196 .push(language_model::MessageContent::Text(symbol_context));
197 }
198
199 if thread_context.len() > OPEN_THREADS_TAG.len() {
200 thread_context.push_str("</threads>\n");
201 message
202 .content
203 .push(language_model::MessageContent::Text(thread_context));
204 }
205
206 if fetch_context.len() > OPEN_FETCH_TAG.len() {
207 fetch_context.push_str("</fetched_urls>\n");
208 message
209 .content
210 .push(language_model::MessageContent::Text(fetch_context));
211 }
212
213 if rules_context.len() > OPEN_RULES_TAG.len() {
214 rules_context.push_str("</user_rules>\n");
215 message
216 .content
217 .push(language_model::MessageContent::Text(rules_context));
218 }
219
220 if message.content.len() > len_before_context {
221 message.content.insert(
222 len_before_context,
223 language_model::MessageContent::Text(OPEN_CONTEXT.into()),
224 );
225 message
226 .content
227 .push(language_model::MessageContent::Text("</context>".into()));
228 }
229
230 message
231 }
232}
233
/// Builds the info string for a fenced code block describing `full_path`.
///
/// The result is the file extension (when the path has one that is valid
/// UTF-8) followed by a space and the displayed path. When `line_range` is
/// given, a `:line` suffix is appended if the range's start and end are equal,
/// or `:start-end` otherwise; both bounds are printed with one added.
fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
    let mut tag = match full_path.extension().and_then(|ext| ext.to_str()) {
        Some(extension) => format!("{} ", extension),
        None => String::new(),
    };

    let _ = write!(tag, "{}", full_path.display());

    match line_range {
        Some(range) if range.start == range.end => {
            let _ = write!(tag, ":{}", range.start + 1);
        }
        Some(range) => {
            let _ = write!(tag, ":{}-{}", range.start + 1, range.end + 1);
        }
        None => {}
    }

    tag
}
253
254impl AgentMessage {
255 pub fn to_markdown(&self) -> String {
256 let mut markdown = String::from("## Assistant\n\n");
257
258 for content in &self.content {
259 match content {
260 AgentMessageContent::Text(text) => {
261 markdown.push_str(text);
262 markdown.push('\n');
263 }
264 AgentMessageContent::Thinking { text, .. } => {
265 markdown.push_str("<think>");
266 markdown.push_str(text);
267 markdown.push_str("</think>\n");
268 }
269 AgentMessageContent::RedactedThinking(_) => {
270 markdown.push_str("<redacted_thinking />\n")
271 }
272 AgentMessageContent::Image(_) => {
273 markdown.push_str("<image />\n");
274 }
275 AgentMessageContent::ToolUse(tool_use) => {
276 markdown.push_str(&format!(
277 "**Tool Use**: {} (ID: {})\n",
278 tool_use.name, tool_use.id
279 ));
280 markdown.push_str(&format!(
281 "{}\n",
282 MarkdownCodeBlock {
283 tag: "json",
284 text: &format!("{:#}", tool_use.input)
285 }
286 ));
287 }
288 }
289 }
290
291 for tool_result in self.tool_results.values() {
292 markdown.push_str(&format!(
293 "**Tool Result**: {} (ID: {})\n\n",
294 tool_result.tool_name, tool_result.tool_use_id
295 ));
296 if tool_result.is_error {
297 markdown.push_str("**ERROR:**\n");
298 }
299
300 match &tool_result.content {
301 LanguageModelToolResultContent::Text(text) => {
302 writeln!(markdown, "{text}\n").ok();
303 }
304 LanguageModelToolResultContent::Image(_) => {
305 writeln!(markdown, "<image />\n").ok();
306 }
307 }
308
309 if let Some(output) = tool_result.output.as_ref() {
310 writeln!(
311 markdown,
312 "**Debug Output**:\n\n```json\n{}\n```\n",
313 serde_json::to_string_pretty(output).unwrap()
314 )
315 .unwrap();
316 }
317 }
318
319 markdown
320 }
321
322 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
323 let mut content = Vec::with_capacity(self.content.len());
324 for chunk in &self.content {
325 let chunk = match chunk {
326 AgentMessageContent::Text(text) => {
327 language_model::MessageContent::Text(text.clone())
328 }
329 AgentMessageContent::Thinking { text, signature } => {
330 language_model::MessageContent::Thinking {
331 text: text.clone(),
332 signature: signature.clone(),
333 }
334 }
335 AgentMessageContent::RedactedThinking(value) => {
336 language_model::MessageContent::RedactedThinking(value.clone())
337 }
338 AgentMessageContent::ToolUse(value) => {
339 language_model::MessageContent::ToolUse(value.clone())
340 }
341 AgentMessageContent::Image(value) => {
342 language_model::MessageContent::Image(value.clone())
343 }
344 };
345 content.push(chunk);
346 }
347
348 let mut messages = vec![LanguageModelRequestMessage {
349 role: Role::Assistant,
350 content,
351 cache: false,
352 }];
353
354 if !self.tool_results.is_empty() {
355 let mut tool_results = Vec::with_capacity(self.tool_results.len());
356 for tool_result in self.tool_results.values() {
357 tool_results.push(language_model::MessageContent::ToolResult(
358 tool_result.clone(),
359 ));
360 }
361 messages.push(LanguageModelRequestMessage {
362 role: Role::User,
363 content: tool_results,
364 cache: false,
365 });
366 }
367
368 messages
369 }
370}
371
/// A message produced by the agent, plus the results of the tool calls it made.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct AgentMessage {
    /// Content chunks in the order they were streamed.
    pub content: Vec<AgentMessageContent>,
    /// Tool results keyed by the id of the tool use they answer.
    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
}
377
/// One chunk of an [`AgentMessage`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentMessageContent {
    /// Plain response text.
    Text(String),
    /// Thinking text, with an optional signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Redacted thinking data, stored as received.
    RedactedThinking(String),
    /// An image.
    Image(LanguageModelImage),
    /// A tool call requested by the model.
    ToolUse(LanguageModelToolUse),
}
389
/// Events reported on the channel returned by `Thread::send`.
#[derive(Debug)]
pub enum AgentResponseEvent {
    /// A piece of streamed response text.
    Text(String),
    /// A piece of streamed thinking text.
    Thinking(String),
    /// A newly started tool call.
    ToolCall(acp::ToolCall),
    /// An update to a previously reported tool call.
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A request to authorize a tool call.
    ToolCallAuthorization(ToolCallAuthorization),
    /// The model stopped, for the given reason.
    Stop(acp::StopReason),
}
399
/// A request to authorize a tool call.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call that needs authorization.
    pub tool_call: acp::ToolCall,
    /// The permission options that can be chosen from.
    pub options: Vec<acp::PermissionOption>,
    /// One-shot channel on which the id of a permission option is sent back.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
406
/// A conversation between a user and an agent backed by a language model.
pub struct Thread {
    /// Messages that are part of the conversation history.
    messages: Vec<Message>,
    /// Completion mode included in every request built by this thread.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<Task<()>>,
    /// Agent message being assembled from streamed events; it is moved into
    /// `messages` by `flush_pending_agent_message`.
    pending_agent_message: Option<AgentMessage>,
    /// Tools registered on this thread, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Registry supplying tools provided by context servers.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Id of the profile consulted to decide which tools are enabled.
    profile_id: AgentProfileId,
    /// Project context used when rendering the system prompt.
    project_context: Rc<RefCell<ProjectContext>>,
    /// Templates used to render the system prompt.
    templates: Arc<Templates>,
    /// Model used for completions.
    pub selected_model: Arc<dyn LanguageModel>,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
}
424
425impl Thread {
426 pub fn new(
427 project: Entity<Project>,
428 project_context: Rc<RefCell<ProjectContext>>,
429 context_server_registry: Entity<ContextServerRegistry>,
430 action_log: Entity<ActionLog>,
431 templates: Arc<Templates>,
432 default_model: Arc<dyn LanguageModel>,
433 cx: &mut Context<Self>,
434 ) -> Self {
435 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
436 Self {
437 messages: Vec::new(),
438 completion_mode: CompletionMode::Normal,
439 running_turn: None,
440 pending_agent_message: None,
441 tools: BTreeMap::default(),
442 context_server_registry,
443 profile_id,
444 project_context,
445 templates,
446 selected_model: default_model,
447 project,
448 action_log,
449 }
450 }
451
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
455
    /// Returns the action log associated with this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
459
    /// Sets the completion mode used for subsequently built requests.
    pub fn set_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
463
464 #[cfg(any(test, feature = "test-support"))]
465 pub fn last_message(&self) -> Option<Message> {
466 if let Some(message) = self.pending_agent_message.clone() {
467 Some(Message::Agent(message))
468 } else {
469 self.messages.last().cloned()
470 }
471 }
472
    /// Registers `tool` under its name, replacing any tool already stored
    /// under that name.
    pub fn add_tool(&mut self, tool: impl AgentTool) {
        self.tools.insert(tool.name(), tool.erase());
    }
476
477 pub fn remove_tool(&mut self, name: &str) -> bool {
478 self.tools.remove(name).is_some()
479 }
480
    /// Sets the profile consulted when deciding which tools are enabled.
    pub fn set_profile(&mut self, profile_id: AgentProfileId) {
        self.profile_id = profile_id;
    }
484
    /// Cancels the running turn, if any.
    ///
    /// Drops the turn's task and then flushes the pending agent message into
    /// the message history.
    pub fn cancel(&mut self) {
        // TODO: do we need to emit a stop::cancel for ACP?
        self.running_turn.take();
        self.flush_pending_agent_message();
    }
490
491 pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
492 self.cancel();
493 let Some(position) = self.messages.iter().position(
494 |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
495 ) else {
496 return Err(anyhow!("Message not found"));
497 };
498 self.messages.truncate(position);
499 Ok(())
500 }
501
    /// Appends a user message to the thread and starts a new agent turn.
    ///
    /// Sending a message results in the model streaming a response, which could include tool calls.
    /// After calling tools, the model stops and waits for any outstanding tool calls to be completed and their results sent.
    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
    pub fn send<T>(
        &mut self,
        message_id: UserMessageId,
        content: impl IntoIterator<Item = T>,
        cx: &mut Context<Self>,
    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>
    where
        T: Into<UserMessageContent>,
    {
        let model = self.selected_model.clone();
        let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
        log::info!("Thread::send called with model: {:?}", model.name());
        log::debug!("Thread::send content: {:?}", content);

        cx.notify();
        let (events_tx, events_rx) =
            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
        let event_stream = AgentResponseEventStream(events_tx);

        // Remember where the user message lands so that a refusal can remove
        // it together with everything that follows.
        let user_message_ix = self.messages.len();
        self.messages.push(Message::User(UserMessage {
            id: message_id,
            content,
        }));
        log::info!("Total messages in thread: {}", self.messages.len());
        self.running_turn = Some(cx.spawn(async move |thread, cx| {
            log::info!("Starting agent turn execution");
            let turn_result = async {
                // Perform one request, then keep looping if the model makes tool calls.
                let mut completion_intent = CompletionIntent::UserPrompt;
                'outer: loop {
                    log::debug!(
                        "Building completion request with intent: {:?}",
                        completion_intent
                    );
                    let request = thread.update(cx, |thread, cx| {
                        thread.build_completion_request(completion_intent, cx)
                    })?;

                    // Stream events, appending to messages and collecting up tool uses.
                    log::info!("Calling model.stream_completion");
                    let mut events = model.stream_completion(request, cx).await?;
                    log::debug!("Stream completion started successfully");

                    let mut tool_uses = FuturesUnordered::new();
                    while let Some(event) = events.next().await {
                        match event {
                            Ok(LanguageModelCompletionEvent::Stop(reason)) => {
                                event_stream.send_stop(reason);
                                if reason == StopReason::Refusal {
                                    // On refusal, discard the pending agent message and drop
                                    // the user message plus everything after it, then end the turn.
                                    thread.update(cx, |thread, _cx| {
                                        thread.pending_agent_message = None;
                                        thread.messages.truncate(user_message_ix);
                                    })?;
                                    break 'outer;
                                }
                            }
                            Ok(event) => {
                                log::trace!("Received completion event: {:?}", event);
                                thread
                                    .update(cx, |thread, cx| {
                                        tool_uses.extend(thread.handle_streamed_completion_event(
                                            event,
                                            &event_stream,
                                            cx,
                                        ));
                                    })
                                    .ok();
                            }
                            Err(error) => {
                                log::error!("Error in completion stream: {:?}", error);
                                event_stream.send_error(error);
                                // Only stops reading this stream; tool uses collected so far
                                // are still handled below.
                                break;
                            }
                        }
                    }

                    // If there are no tool uses, the turn is done.
                    if tool_uses.is_empty() {
                        log::info!("No tool uses found, completing turn");
                        break;
                    }
                    log::info!("Found {} tool uses to execute", tool_uses.len());

                    // As tool results trickle in, insert them into the pending
                    // agent message so that they can be sent on the next tick
                    // of the agentic loop.
                    while let Some(tool_result) = tool_uses.next().await {
                        log::info!("Tool finished {:?}", tool_result);

                        event_stream.update_tool_call_fields(
                            &tool_result.tool_use_id,
                            acp::ToolCallUpdateFields {
                                status: Some(if tool_result.is_error {
                                    acp::ToolCallStatus::Failed
                                } else {
                                    acp::ToolCallStatus::Completed
                                }),
                                raw_output: tool_result.output.clone(),
                                ..Default::default()
                            },
                        );
                        thread
                            .update(cx, |thread, _cx| {
                                thread
                                    .pending_agent_message()
                                    .tool_results
                                    .insert(tool_result.tool_use_id.clone(), tool_result);
                            })
                            .ok();
                    }

                    // Move the finished agent message into the history before the next request.
                    thread.update(cx, |thread, _cx| thread.flush_pending_agent_message())?;

                    completion_intent = CompletionIntent::ToolResults;
                }

                Ok(())
            }
            .await;

            // Whatever the outcome, keep anything still pending in the history.
            thread
                .update(cx, |thread, _cx| thread.flush_pending_agent_message())
                .ok();

            if let Err(error) = turn_result {
                log::error!("Turn execution failed: {:?}", error);
                event_stream.send_error(error);
            } else {
                log::info!("Turn execution completed successfully");
            }
        }));
        events_rx
    }
639
640 pub fn build_system_message(&self) -> LanguageModelRequestMessage {
641 log::debug!("Building system message");
642 let prompt = SystemPromptTemplate {
643 project: &self.project_context.borrow(),
644 available_tools: self.tools.keys().cloned().collect(),
645 }
646 .render(&self.templates)
647 .context("failed to build system prompt")
648 .expect("Invalid template");
649 log::debug!("System message built");
650 LanguageModelRequestMessage {
651 role: Role::System,
652 content: vec![prompt.into()],
653 cache: true,
654 }
655 }
656
657 /// A helper method that's called on every streamed completion event.
658 /// Returns an optional tool result task, which the main agentic loop in
659 /// send will send back to the model when it resolves.
660 fn handle_streamed_completion_event(
661 &mut self,
662 event: LanguageModelCompletionEvent,
663 event_stream: &AgentResponseEventStream,
664 cx: &mut Context<Self>,
665 ) -> Option<Task<LanguageModelToolResult>> {
666 log::trace!("Handling streamed completion event: {:?}", event);
667 use LanguageModelCompletionEvent::*;
668
669 match event {
670 StartMessage { .. } => {
671 self.messages.push(Message::Agent(AgentMessage::default()));
672 }
673 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
674 Thinking { text, signature } => {
675 self.handle_thinking_event(text, signature, event_stream, cx)
676 }
677 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
678 ToolUse(tool_use) => {
679 return self.handle_tool_use_event(tool_use, event_stream, cx);
680 }
681 ToolUseJsonParseError {
682 id,
683 tool_name,
684 raw_input,
685 json_parse_error,
686 } => {
687 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
688 id,
689 tool_name,
690 raw_input,
691 json_parse_error,
692 )));
693 }
694 UsageUpdate(_) | StatusUpdate(_) => {}
695 Stop(_) => unreachable!(),
696 }
697
698 None
699 }
700
701 fn handle_text_event(
702 &mut self,
703 new_text: String,
704 events_stream: &AgentResponseEventStream,
705 cx: &mut Context<Self>,
706 ) {
707 events_stream.send_text(&new_text);
708
709 let last_message = self.pending_agent_message();
710 if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
711 text.push_str(&new_text);
712 } else {
713 last_message
714 .content
715 .push(AgentMessageContent::Text(new_text));
716 }
717
718 cx.notify();
719 }
720
721 fn handle_thinking_event(
722 &mut self,
723 new_text: String,
724 new_signature: Option<String>,
725 event_stream: &AgentResponseEventStream,
726 cx: &mut Context<Self>,
727 ) {
728 event_stream.send_thinking(&new_text);
729
730 let last_message = self.pending_agent_message();
731 if let Some(AgentMessageContent::Thinking { text, signature }) =
732 last_message.content.last_mut()
733 {
734 text.push_str(&new_text);
735 *signature = new_signature.or(signature.take());
736 } else {
737 last_message.content.push(AgentMessageContent::Thinking {
738 text: new_text,
739 signature: new_signature,
740 });
741 }
742
743 cx.notify();
744 }
745
746 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
747 let last_message = self.pending_agent_message();
748 last_message
749 .content
750 .push(AgentMessageContent::RedactedThinking(data));
751 cx.notify();
752 }
753
    /// Records a streamed tool use and, once its input is complete, starts
    /// running the tool.
    ///
    /// Returns `None` while the input is still incomplete. Otherwise returns a
    /// task resolving to the tool's result; an unknown tool name or a failed
    /// run is reported as a result with `is_error` set.
    fn handle_tool_use_event(
        &mut self,
        tool_use: LanguageModelToolUse,
        event_stream: &AgentResponseEventStream,
        cx: &mut Context<Self>,
    ) -> Option<Task<LanguageModelToolResult>> {
        cx.notify();

        // Fall back to the raw tool name and `Other` kind when the tool is not registered.
        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
        let mut title = SharedString::from(&tool_use.name);
        let mut kind = acp::ToolKind::Other;
        if let Some(tool) = tool.as_ref() {
            title = tool.initial_title(tool_use.input.clone());
            kind = tool.kind();
        }

        // Ensure the last message ends in the current tool use
        let last_message = self.pending_agent_message();
        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
            if let AgentMessageContent::ToolUse(last_tool_use) = content {
                if last_tool_use.id == tool_use.id {
                    // Same tool use streamed again: replace it in place.
                    *last_tool_use = tool_use.clone();
                    false
                } else {
                    true
                }
            } else {
                true
            }
        });

        if push_new_tool_use {
            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
            last_message
                .content
                .push(AgentMessageContent::ToolUse(tool_use.clone()));
        } else {
            event_stream.update_tool_call_fields(
                &tool_use.id,
                acp::ToolCallUpdateFields {
                    title: Some(title.into()),
                    kind: Some(kind),
                    raw_input: Some(tool_use.input.clone()),
                    ..Default::default()
                },
            );
        }

        // Nothing to run until the full input has been received.
        if !tool_use.is_input_complete {
            return None;
        }

        let Some(tool) = tool else {
            let content = format!("No tool named {} exists", tool_use.name);
            return Some(Task::ready(LanguageModelToolResult {
                content: LanguageModelToolResultContent::Text(Arc::from(content)),
                tool_use_id: tool_use.id,
                tool_name: tool_use.name,
                is_error: true,
                output: None,
            }));
        };

        let fs = self.project.read(cx).fs().clone();
        let tool_event_stream =
            ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
        tool_event_stream.update_fields(acp::ToolCallUpdateFields {
            status: Some(acp::ToolCallStatus::InProgress),
            ..Default::default()
        });
        let supports_images = self.selected_model.supports_images();
        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
        Some(cx.foreground_executor().spawn(async move {
            // Turn an image output into an error when the model cannot accept images.
            let tool_result = tool_result.await.and_then(|output| {
                if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
                    if !supports_images {
                        return Err(anyhow!(
                            "Attempted to read an image, but this model doesn't support it.",
                        ));
                    }
                }
                Ok(output)
            });

            match tool_result {
                Ok(output) => LanguageModelToolResult {
                    tool_use_id: tool_use.id,
                    tool_name: tool_use.name,
                    is_error: false,
                    content: output.llm_output,
                    output: Some(output.raw_output),
                },
                Err(error) => LanguageModelToolResult {
                    tool_use_id: tool_use.id,
                    tool_name: tool_use.name,
                    is_error: true,
                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
                    output: None,
                },
            }
        }))
    }
856
857 fn handle_tool_use_json_parse_error_event(
858 &mut self,
859 tool_use_id: LanguageModelToolUseId,
860 tool_name: Arc<str>,
861 raw_input: Arc<str>,
862 json_parse_error: String,
863 ) -> LanguageModelToolResult {
864 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
865 LanguageModelToolResult {
866 tool_use_id,
867 tool_name,
868 is_error: true,
869 content: LanguageModelToolResultContent::Text(tool_output.into()),
870 output: Some(serde_json::Value::String(raw_input.to_string())),
871 }
872 }
873
    /// Returns the pending agent message, creating an empty one first when
    /// there is none.
    fn pending_agent_message(&mut self) -> &mut AgentMessage {
        self.pending_agent_message.get_or_insert_default()
    }
877
878 fn flush_pending_agent_message(&mut self) {
879 let Some(mut message) = self.pending_agent_message.take() else {
880 return;
881 };
882
883 for content in &message.content {
884 let AgentMessageContent::ToolUse(tool_use) = content else {
885 continue;
886 };
887
888 if !message.tool_results.contains_key(&tool_use.id) {
889 message.tool_results.insert(
890 tool_use.id.clone(),
891 LanguageModelToolResult {
892 tool_use_id: tool_use.id.clone(),
893 tool_name: tool_use.name.clone(),
894 is_error: true,
895 content: LanguageModelToolResultContent::Text(
896 "Tool canceled by user".into(),
897 ),
898 output: None,
899 },
900 );
901 }
902 }
903
904 self.messages.push(Message::Agent(message));
905 }
906
907 pub(crate) fn build_completion_request(
908 &self,
909 completion_intent: CompletionIntent,
910 cx: &mut App,
911 ) -> LanguageModelRequest {
912 log::debug!("Building completion request");
913 log::debug!("Completion intent: {:?}", completion_intent);
914 log::debug!("Completion mode: {:?}", self.completion_mode);
915
916 let messages = self.build_request_messages();
917 log::info!("Request will include {} messages", messages.len());
918
919 let tools = if let Some(tools) = self.tools(cx).log_err() {
920 tools
921 .filter_map(|tool| {
922 let tool_name = tool.name().to_string();
923 log::trace!("Including tool: {}", tool_name);
924 Some(LanguageModelRequestTool {
925 name: tool_name,
926 description: tool.description().to_string(),
927 input_schema: tool
928 .input_schema(self.selected_model.tool_input_format())
929 .log_err()?,
930 })
931 })
932 .collect()
933 } else {
934 Vec::new()
935 };
936
937 log::info!("Request includes {} tools", tools.len());
938
939 let request = LanguageModelRequest {
940 thread_id: None,
941 prompt_id: None,
942 intent: Some(completion_intent),
943 mode: Some(self.completion_mode),
944 messages,
945 tools,
946 tool_choice: None,
947 stop: Vec::new(),
948 temperature: None,
949 thinking_allowed: true,
950 };
951
952 log::debug!("Completion request built successfully");
953 request
954 }
955
    /// Returns the tools enabled for the thread's current profile.
    ///
    /// Yields the thread's own tools that support the selected model's
    /// provider and are enabled in the profile, followed by the context-server
    /// tools that the profile enables.
    ///
    /// # Errors
    ///
    /// Returns an error when the thread's profile id is not present in the
    /// global agent settings.
    fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
        let profile = AgentSettings::get_global(cx)
            .profiles
            .get(&self.profile_id)
            .context("profile not found")?;
        let provider_id = self.selected_model.provider_id();

        Ok(self
            .tools
            .iter()
            // `provider_id` is a local, so it is moved into the closure.
            .filter(move |(_, tool)| tool.supported_provider(&provider_id))
            .filter_map(|(tool_name, tool)| {
                if profile.is_tool_enabled(tool_name) {
                    Some(tool)
                } else {
                    None
                }
            })
            .chain(self.context_server_registry.read(cx).servers().flat_map(
                |(server_id, tools)| {
                    tools.iter().filter_map(|(tool_name, tool)| {
                        if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
                            Some(tool)
                        } else {
                            None
                        }
                    })
                },
            )))
    }
986
987 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
988 log::trace!(
989 "Building request messages from {} thread messages",
990 self.messages.len()
991 );
992 let mut messages = vec![self.build_system_message()];
993 for message in &self.messages {
994 match message {
995 Message::User(message) => messages.push(message.to_request()),
996 Message::Agent(message) => messages.extend(message.to_request()),
997 }
998 }
999
1000 if let Some(message) = self.pending_agent_message.as_ref() {
1001 messages.extend(message.to_request());
1002 }
1003
1004 messages
1005 }
1006
1007 pub fn to_markdown(&self) -> String {
1008 let mut markdown = String::new();
1009 for (ix, message) in self.messages.iter().enumerate() {
1010 if ix > 0 {
1011 markdown.push('\n');
1012 }
1013 markdown.push_str(&message.to_markdown());
1014 }
1015
1016 if let Some(message) = self.pending_agent_message.as_ref() {
1017 markdown.push('\n');
1018 markdown.push_str(&message.to_markdown());
1019 }
1020
1021 markdown
1022 }
1023}
1024
/// A tool with strongly typed input and output that an agent can run.
pub trait AgentTool
where
    Self: 'static + Sized,
{
    /// The tool's input type; its JSON schema describes the tool's input.
    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
    /// The tool's output type, convertible into tool result content.
    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;

    /// The tool's name.
    fn name(&self) -> SharedString;

    /// The tool's description.
    ///
    /// By default this is the `description` entry of the input type's JSON
    /// schema, or an empty string when that entry is missing or not a string.
    fn description(&self) -> SharedString {
        let schema = schemars::schema_for!(Self::Input);
        SharedString::new(
            schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap_or_default(),
        )
    }

    /// The kind of tool this is.
    fn kind(&self) -> acp::ToolKind;

    /// The initial tool title to display. Can be updated during the tool run.
    ///
    /// `input` is `Ok` with the parsed input, or `Err` with the raw JSON value
    /// when it could not be parsed as `Self::Input`.
    fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;

    /// Returns the JSON schema that describes the tool's input.
    fn input_schema(&self) -> Schema {
        schemars::schema_for!(Self::Input)
    }

    /// Some tools rely on a provider for the underlying billing or other reasons.
    /// Allow the tool to check if they are compatible, or should be filtered out.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }

    /// Runs the tool with the provided input.
    fn run(
        self: Arc<Self>,
        input: Self::Input,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<Self::Output>>;

    /// Wraps this tool in [`Erased`] and returns it as a type-erased
    /// [`AnyAgentTool`].
    fn erase(self) -> Arc<dyn AnyAgentTool> {
        Arc::new(Erased(Arc::new(self)))
    }
}
1072
/// Wrapper around a tool; `Erased<Arc<T>>` implements [`AnyAgentTool`] for any
/// `T: AgentTool`.
pub struct Erased<T>(T);
1074
/// The output of a type-erased tool run.
pub struct AgentToolOutput {
    /// The output converted into tool result content for the model.
    pub llm_output: LanguageModelToolResultContent,
    /// The output serialized as JSON.
    pub raw_output: serde_json::Value,
}
1079
/// Type-erased counterpart of [`AgentTool`] that works with JSON values.
pub trait AnyAgentTool {
    /// The tool's name.
    fn name(&self) -> SharedString;
    /// The tool's description.
    fn description(&self) -> SharedString;
    /// The kind of tool this is.
    fn kind(&self) -> acp::ToolKind;
    /// The initial tool title to display for the given raw JSON input.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// Returns the tool's input schema as JSON for the given schema format.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether the tool can be used with the given provider; `true` by default.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool with the given raw JSON input.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
}
1096
1097impl<T> AnyAgentTool for Erased<Arc<T>>
1098where
1099 T: AgentTool,
1100{
1101 fn name(&self) -> SharedString {
1102 self.0.name()
1103 }
1104
1105 fn description(&self) -> SharedString {
1106 self.0.description()
1107 }
1108
1109 fn kind(&self) -> agent_client_protocol::ToolKind {
1110 self.0.kind()
1111 }
1112
1113 fn initial_title(&self, input: serde_json::Value) -> SharedString {
1114 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1115 self.0.initial_title(parsed_input)
1116 }
1117
1118 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1119 let mut json = serde_json::to_value(self.0.input_schema())?;
1120 adapt_schema_to_format(&mut json, format)?;
1121 Ok(json)
1122 }
1123
1124 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1125 self.0.supported_provider(provider)
1126 }
1127
1128 fn run(
1129 self: Arc<Self>,
1130 input: serde_json::Value,
1131 event_stream: ToolCallEventStream,
1132 cx: &mut App,
1133 ) -> Task<Result<AgentToolOutput>> {
1134 cx.spawn(async move |cx| {
1135 let input = serde_json::from_value(input)?;
1136 let output = cx
1137 .update(|cx| self.0.clone().run(input, event_stream, cx))?
1138 .await?;
1139 let raw_output = serde_json::to_value(&output)?;
1140 Ok(AgentToolOutput {
1141 llm_output: output.into(),
1142 raw_output,
1143 })
1144 })
1145 }
1146}
1147
/// Cloneable wrapper around the unbounded sender used to publish agent response events
/// (or completion errors).
#[derive(Clone)]
struct AgentResponseEventStream(
    mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
1152
1153impl AgentResponseEventStream {
1154 fn send_text(&self, text: &str) {
1155 self.0
1156 .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
1157 .ok();
1158 }
1159
1160 fn send_thinking(&self, text: &str) {
1161 self.0
1162 .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
1163 .ok();
1164 }
1165
1166 fn send_tool_call(
1167 &self,
1168 id: &LanguageModelToolUseId,
1169 title: SharedString,
1170 kind: acp::ToolKind,
1171 input: serde_json::Value,
1172 ) {
1173 self.0
1174 .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
1175 id,
1176 title.to_string(),
1177 kind,
1178 input,
1179 ))))
1180 .ok();
1181 }
1182
1183 fn initial_tool_call(
1184 id: &LanguageModelToolUseId,
1185 title: String,
1186 kind: acp::ToolKind,
1187 input: serde_json::Value,
1188 ) -> acp::ToolCall {
1189 acp::ToolCall {
1190 id: acp::ToolCallId(id.to_string().into()),
1191 title,
1192 kind,
1193 status: acp::ToolCallStatus::Pending,
1194 content: vec![],
1195 locations: vec![],
1196 raw_input: Some(input),
1197 raw_output: None,
1198 }
1199 }
1200
1201 fn update_tool_call_fields(
1202 &self,
1203 tool_use_id: &LanguageModelToolUseId,
1204 fields: acp::ToolCallUpdateFields,
1205 ) {
1206 self.0
1207 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1208 acp::ToolCallUpdate {
1209 id: acp::ToolCallId(tool_use_id.to_string().into()),
1210 fields,
1211 }
1212 .into(),
1213 )))
1214 .ok();
1215 }
1216
1217 fn send_stop(&self, reason: StopReason) {
1218 match reason {
1219 StopReason::EndTurn => {
1220 self.0
1221 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
1222 .ok();
1223 }
1224 StopReason::MaxTokens => {
1225 self.0
1226 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
1227 .ok();
1228 }
1229 StopReason::Refusal => {
1230 self.0
1231 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
1232 .ok();
1233 }
1234 StopReason::ToolUse => {}
1235 }
1236 }
1237
1238 fn send_error(&self, error: LanguageModelCompletionError) {
1239 self.0.unbounded_send(Err(error)).ok();
1240 }
1241}
1242
/// Handle given to a running tool so it can publish updates for its own tool call and
/// request authorization.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// Id of the tool use this stream reports on; converted to an ACP tool call id when
    /// events are sent.
    tool_use_id: LanguageModelToolUseId,
    /// Tool kind included in authorization requests.
    kind: acp::ToolKind,
    /// Copy of the tool use's JSON input, included in authorization requests.
    input: serde_json::Value,
    /// Underlying response event stream that events are sent on.
    stream: AgentResponseEventStream,
    /// Optional filesystem handle; when present, `authorize` uses it to update the
    /// settings file after an "always allow" response.
    fs: Option<Arc<dyn Fs>>,
}
1251
1252impl ToolCallEventStream {
1253 #[cfg(test)]
1254 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
1255 let (events_tx, events_rx) =
1256 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
1257
1258 let stream = ToolCallEventStream::new(
1259 &LanguageModelToolUse {
1260 id: "test_id".into(),
1261 name: "test_tool".into(),
1262 raw_input: String::new(),
1263 input: serde_json::Value::Null,
1264 is_input_complete: true,
1265 },
1266 acp::ToolKind::Other,
1267 AgentResponseEventStream(events_tx),
1268 None,
1269 );
1270
1271 (stream, ToolCallEventStreamReceiver(events_rx))
1272 }
1273
1274 fn new(
1275 tool_use: &LanguageModelToolUse,
1276 kind: acp::ToolKind,
1277 stream: AgentResponseEventStream,
1278 fs: Option<Arc<dyn Fs>>,
1279 ) -> Self {
1280 Self {
1281 tool_use_id: tool_use.id.clone(),
1282 kind,
1283 input: tool_use.input.clone(),
1284 stream,
1285 fs,
1286 }
1287 }
1288
1289 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1290 self.stream
1291 .update_tool_call_fields(&self.tool_use_id, fields);
1292 }
1293
1294 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1295 self.stream
1296 .0
1297 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1298 acp_thread::ToolCallUpdateDiff {
1299 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1300 diff,
1301 }
1302 .into(),
1303 )))
1304 .ok();
1305 }
1306
1307 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1308 self.stream
1309 .0
1310 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1311 acp_thread::ToolCallUpdateTerminal {
1312 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1313 terminal,
1314 }
1315 .into(),
1316 )))
1317 .ok();
1318 }
1319
1320 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1321 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1322 return Task::ready(Ok(()));
1323 }
1324
1325 let (response_tx, response_rx) = oneshot::channel();
1326 self.stream
1327 .0
1328 .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
1329 ToolCallAuthorization {
1330 tool_call: AgentResponseEventStream::initial_tool_call(
1331 &self.tool_use_id,
1332 title.into(),
1333 self.kind.clone(),
1334 self.input.clone(),
1335 ),
1336 options: vec![
1337 acp::PermissionOption {
1338 id: acp::PermissionOptionId("always_allow".into()),
1339 name: "Always Allow".into(),
1340 kind: acp::PermissionOptionKind::AllowAlways,
1341 },
1342 acp::PermissionOption {
1343 id: acp::PermissionOptionId("allow".into()),
1344 name: "Allow".into(),
1345 kind: acp::PermissionOptionKind::AllowOnce,
1346 },
1347 acp::PermissionOption {
1348 id: acp::PermissionOptionId("deny".into()),
1349 name: "Deny".into(),
1350 kind: acp::PermissionOptionKind::RejectOnce,
1351 },
1352 ],
1353 response: response_tx,
1354 },
1355 )))
1356 .ok();
1357 let fs = self.fs.clone();
1358 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1359 "always_allow" => {
1360 if let Some(fs) = fs.clone() {
1361 cx.update(|cx| {
1362 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1363 settings.set_always_allow_tool_actions(true);
1364 });
1365 })?;
1366 }
1367
1368 Ok(())
1369 }
1370 "allow" => Ok(()),
1371 _ => Err(anyhow!("Permission to run tool denied by user")),
1372 })
1373 }
1374}
1375
/// Test-only receiving end of the channel created by `ToolCallEventStream::test`.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(
    mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
1380
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Waits for the next event and returns it if it is an authorization request.
    ///
    /// # Panics
    ///
    /// Panics if the next item is anything else, including an error or the end of the
    /// stream.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) => auth,
            event => panic!("Expected ToolCallAuthorization but got: {:?}", event),
        }
    }

    /// Waits for the next event and returns its terminal if it is a terminal update.
    ///
    /// # Panics
    ///
    /// Panics if the next item is anything else, including an error or the end of the
    /// stream.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            event => panic!("Expected terminal but got: {:?}", event),
        }
    }
}
1404
/// Exposes the wrapped receiver by shared reference.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;

    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1413
/// Exposes the wrapped receiver by mutable reference.
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1420
1421impl From<&str> for UserMessageContent {
1422 fn from(text: &str) -> Self {
1423 Self::Text(text.into())
1424 }
1425}
1426
1427impl From<acp::ContentBlock> for UserMessageContent {
1428 fn from(value: acp::ContentBlock) -> Self {
1429 match value {
1430 acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
1431 acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
1432 acp::ContentBlock::Audio(_) => {
1433 // TODO
1434 Self::Text("[audio]".to_string())
1435 }
1436 acp::ContentBlock::ResourceLink(resource_link) => {
1437 match MentionUri::parse(&resource_link.uri) {
1438 Ok(uri) => Self::Mention {
1439 uri,
1440 content: String::new(),
1441 },
1442 Err(err) => {
1443 log::error!("Failed to parse mention link: {}", err);
1444 Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
1445 }
1446 }
1447 }
1448 acp::ContentBlock::Resource(resource) => match resource.resource {
1449 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1450 match MentionUri::parse(&resource.uri) {
1451 Ok(uri) => Self::Mention {
1452 uri,
1453 content: resource.text,
1454 },
1455 Err(err) => {
1456 log::error!("Failed to parse mention link: {}", err);
1457 Self::Text(
1458 MarkdownCodeBlock {
1459 tag: &resource.uri,
1460 text: &resource.text,
1461 }
1462 .to_string(),
1463 )
1464 }
1465 }
1466 }
1467 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1468 // TODO
1469 Self::Text("[blob]".to_string())
1470 }
1471 },
1472 }
1473 }
1474}
1475
1476fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1477 LanguageModelImage {
1478 source: image_content.data.into(),
1479 // TODO: make this optional?
1480 size: gpui::Size::new(0.into(), 0.into()),
1481 }
1482}