1use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
2use acp_thread::{MentionUri, UserMessageId};
3use action_log::ActionLog;
4use agent_client_protocol as acp;
5use agent_settings::{AgentProfileId, AgentSettings};
6use anyhow::{Context as _, Result, anyhow};
7use assistant_tool::adapt_schema_to_format;
8use cloud_llm_client::{CompletionIntent, CompletionMode};
9use collections::IndexMap;
10use fs::Fs;
11use futures::{
12 channel::{mpsc, oneshot},
13 stream::FuturesUnordered,
14};
15use gpui::{App, Context, Entity, SharedString, Task};
16use language_model::{
17 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
18 LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
19 LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
20 LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
21};
22use project::Project;
23use prompt_store::ProjectContext;
24use schemars::{JsonSchema, Schema};
25use serde::{Deserialize, Serialize};
26use settings::{Settings, update_settings_file};
27use smol::stream::StreamExt;
28use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
29use std::{fmt::Write, ops::Range};
30use util::{ResultExt, markdown::MarkdownCodeBlock};
31
32#[derive(Debug, Clone, PartialEq, Eq)]
33pub enum Message {
34 User(UserMessage),
35 Agent(AgentMessage),
36}
37
38impl Message {
39 pub fn as_agent_message(&self) -> Option<&AgentMessage> {
40 match self {
41 Message::Agent(agent_message) => Some(agent_message),
42 _ => None,
43 }
44 }
45
46 pub fn to_markdown(&self) -> String {
47 match self {
48 Message::User(message) => message.to_markdown(),
49 Message::Agent(message) => message.to_markdown(),
50 }
51 }
52}
53
54#[derive(Debug, Clone, PartialEq, Eq)]
55pub struct UserMessage {
56 pub id: UserMessageId,
57 pub content: Vec<UserMessageContent>,
58}
59
60#[derive(Debug, Clone, PartialEq, Eq)]
61pub enum UserMessageContent {
62 Text(String),
63 Mention { uri: MentionUri, content: String },
64 Image(LanguageModelImage),
65}
66
67impl UserMessage {
68 pub fn to_markdown(&self) -> String {
69 let mut markdown = String::from("## User\n\n");
70
71 for content in &self.content {
72 match content {
73 UserMessageContent::Text(text) => {
74 markdown.push_str(text);
75 markdown.push('\n');
76 }
77 UserMessageContent::Image(_) => {
78 markdown.push_str("<image />\n");
79 }
80 UserMessageContent::Mention { uri, content } => {
81 if !content.is_empty() {
82 let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
83 } else {
84 let _ = write!(&mut markdown, "{}\n", uri.as_link());
85 }
86 }
87 }
88 }
89
90 markdown
91 }
92
93 fn to_request(&self) -> LanguageModelRequestMessage {
94 let mut message = LanguageModelRequestMessage {
95 role: Role::User,
96 content: Vec::with_capacity(self.content.len()),
97 cache: false,
98 };
99
100 const OPEN_CONTEXT: &str = "<context>\n\
101 The following items were attached by the user. \
102 They are up-to-date and don't need to be re-read.\n\n";
103
104 const OPEN_FILES_TAG: &str = "<files>";
105 const OPEN_SYMBOLS_TAG: &str = "<symbols>";
106 const OPEN_THREADS_TAG: &str = "<threads>";
107 const OPEN_FETCH_TAG: &str = "<fetched_urls>";
108 const OPEN_RULES_TAG: &str =
109 "<rules>\nThe user has specified the following rules that should be applied:\n";
110
111 let mut file_context = OPEN_FILES_TAG.to_string();
112 let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
113 let mut thread_context = OPEN_THREADS_TAG.to_string();
114 let mut fetch_context = OPEN_FETCH_TAG.to_string();
115 let mut rules_context = OPEN_RULES_TAG.to_string();
116
117 for chunk in &self.content {
118 let chunk = match chunk {
119 UserMessageContent::Text(text) => {
120 language_model::MessageContent::Text(text.clone())
121 }
122 UserMessageContent::Image(value) => {
123 language_model::MessageContent::Image(value.clone())
124 }
125 UserMessageContent::Mention { uri, content } => {
126 match uri {
127 MentionUri::File { abs_path, .. } => {
128 write!(
129 &mut symbol_context,
130 "\n{}",
131 MarkdownCodeBlock {
132 tag: &codeblock_tag(&abs_path, None),
133 text: &content.to_string(),
134 }
135 )
136 .ok();
137 }
138 MentionUri::Symbol {
139 path, line_range, ..
140 }
141 | MentionUri::Selection {
142 path, line_range, ..
143 } => {
144 write!(
145 &mut rules_context,
146 "\n{}",
147 MarkdownCodeBlock {
148 tag: &codeblock_tag(&path, Some(line_range)),
149 text: &content
150 }
151 )
152 .ok();
153 }
154 MentionUri::Thread { .. } => {
155 write!(&mut thread_context, "\n{}\n", content).ok();
156 }
157 MentionUri::TextThread { .. } => {
158 write!(&mut thread_context, "\n{}\n", content).ok();
159 }
160 MentionUri::Rule { .. } => {
161 write!(
162 &mut rules_context,
163 "\n{}",
164 MarkdownCodeBlock {
165 tag: "",
166 text: &content
167 }
168 )
169 .ok();
170 }
171 MentionUri::Fetch { url } => {
172 write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
173 }
174 }
175
176 language_model::MessageContent::Text(uri.as_link().to_string())
177 }
178 };
179
180 message.content.push(chunk);
181 }
182
183 let len_before_context = message.content.len();
184
185 if file_context.len() > OPEN_FILES_TAG.len() {
186 file_context.push_str("</files>\n");
187 message
188 .content
189 .push(language_model::MessageContent::Text(file_context));
190 }
191
192 if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
193 symbol_context.push_str("</symbols>\n");
194 message
195 .content
196 .push(language_model::MessageContent::Text(symbol_context));
197 }
198
199 if thread_context.len() > OPEN_THREADS_TAG.len() {
200 thread_context.push_str("</threads>\n");
201 message
202 .content
203 .push(language_model::MessageContent::Text(thread_context));
204 }
205
206 if fetch_context.len() > OPEN_FETCH_TAG.len() {
207 fetch_context.push_str("</fetched_urls>\n");
208 message
209 .content
210 .push(language_model::MessageContent::Text(fetch_context));
211 }
212
213 if rules_context.len() > OPEN_RULES_TAG.len() {
214 rules_context.push_str("</user_rules>\n");
215 message
216 .content
217 .push(language_model::MessageContent::Text(rules_context));
218 }
219
220 if message.content.len() > len_before_context {
221 message.content.insert(
222 len_before_context,
223 language_model::MessageContent::Text(OPEN_CONTEXT.into()),
224 );
225 message
226 .content
227 .push(language_model::MessageContent::Text("</context>".into()));
228 }
229
230 message
231 }
232}
233
/// Builds the info string for a Markdown code fence that shows a file.
///
/// The result is `"<ext> <path>"`, or just `"<path>"` when the path has no
/// UTF-8 extension. When a line range is given, a 1-based `:<line>` suffix is
/// added if the range starts and ends on the same line; otherwise a
/// `:<start>-<end>` suffix is added.
fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
    let shown = full_path.display();
    let mut tag = match full_path.extension().and_then(|ext| ext.to_str()) {
        Some(extension) => format!("{extension} {shown}"),
        None => shown.to_string(),
    };

    match line_range {
        Some(range) if range.start == range.end => {
            let _ = write!(tag, ":{}", range.start + 1);
        }
        Some(range) => {
            let _ = write!(tag, ":{}-{}", range.start + 1, range.end + 1);
        }
        None => {}
    }

    tag
}
253
254impl AgentMessage {
255 pub fn to_markdown(&self) -> String {
256 let mut markdown = String::from("## Assistant\n\n");
257
258 for content in &self.content {
259 match content {
260 AgentMessageContent::Text(text) => {
261 markdown.push_str(text);
262 markdown.push('\n');
263 }
264 AgentMessageContent::Thinking { text, .. } => {
265 markdown.push_str("<think>");
266 markdown.push_str(text);
267 markdown.push_str("</think>\n");
268 }
269 AgentMessageContent::RedactedThinking(_) => {
270 markdown.push_str("<redacted_thinking />\n")
271 }
272 AgentMessageContent::Image(_) => {
273 markdown.push_str("<image />\n");
274 }
275 AgentMessageContent::ToolUse(tool_use) => {
276 markdown.push_str(&format!(
277 "**Tool Use**: {} (ID: {})\n",
278 tool_use.name, tool_use.id
279 ));
280 markdown.push_str(&format!(
281 "{}\n",
282 MarkdownCodeBlock {
283 tag: "json",
284 text: &format!("{:#}", tool_use.input)
285 }
286 ));
287 }
288 }
289 }
290
291 for tool_result in self.tool_results.values() {
292 markdown.push_str(&format!(
293 "**Tool Result**: {} (ID: {})\n\n",
294 tool_result.tool_name, tool_result.tool_use_id
295 ));
296 if tool_result.is_error {
297 markdown.push_str("**ERROR:**\n");
298 }
299
300 match &tool_result.content {
301 LanguageModelToolResultContent::Text(text) => {
302 writeln!(markdown, "{text}\n").ok();
303 }
304 LanguageModelToolResultContent::Image(_) => {
305 writeln!(markdown, "<image />\n").ok();
306 }
307 }
308
309 if let Some(output) = tool_result.output.as_ref() {
310 writeln!(
311 markdown,
312 "**Debug Output**:\n\n```json\n{}\n```\n",
313 serde_json::to_string_pretty(output).unwrap()
314 )
315 .unwrap();
316 }
317 }
318
319 markdown
320 }
321
322 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
323 let mut content = Vec::with_capacity(self.content.len());
324 for chunk in &self.content {
325 let chunk = match chunk {
326 AgentMessageContent::Text(text) => {
327 language_model::MessageContent::Text(text.clone())
328 }
329 AgentMessageContent::Thinking { text, signature } => {
330 language_model::MessageContent::Thinking {
331 text: text.clone(),
332 signature: signature.clone(),
333 }
334 }
335 AgentMessageContent::RedactedThinking(value) => {
336 language_model::MessageContent::RedactedThinking(value.clone())
337 }
338 AgentMessageContent::ToolUse(value) => {
339 language_model::MessageContent::ToolUse(value.clone())
340 }
341 AgentMessageContent::Image(value) => {
342 language_model::MessageContent::Image(value.clone())
343 }
344 };
345 content.push(chunk);
346 }
347
348 let mut messages = vec![LanguageModelRequestMessage {
349 role: Role::Assistant,
350 content,
351 cache: false,
352 }];
353
354 if !self.tool_results.is_empty() {
355 let mut tool_results = Vec::with_capacity(self.tool_results.len());
356 for tool_result in self.tool_results.values() {
357 tool_results.push(language_model::MessageContent::ToolResult(
358 tool_result.clone(),
359 ));
360 }
361 messages.push(LanguageModelRequestMessage {
362 role: Role::User,
363 content: tool_results,
364 cache: false,
365 });
366 }
367
368 messages
369 }
370}
371
372#[derive(Default, Debug, Clone, PartialEq, Eq)]
373pub struct AgentMessage {
374 pub content: Vec<AgentMessageContent>,
375 pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
376}
377
378#[derive(Debug, Clone, PartialEq, Eq)]
379pub enum AgentMessageContent {
380 Text(String),
381 Thinking {
382 text: String,
383 signature: Option<String>,
384 },
385 RedactedThinking(String),
386 Image(LanguageModelImage),
387 ToolUse(LanguageModelToolUse),
388}
389
390#[derive(Debug)]
391pub enum AgentResponseEvent {
392 Text(String),
393 Thinking(String),
394 ToolCall(acp::ToolCall),
395 ToolCallUpdate(acp_thread::ToolCallUpdate),
396 ToolCallAuthorization(ToolCallAuthorization),
397 Stop(acp::StopReason),
398}
399
400#[derive(Debug)]
401pub struct ToolCallAuthorization {
402 pub tool_call: acp::ToolCall,
403 pub options: Vec<acp::PermissionOption>,
404 pub response: oneshot::Sender<acp::PermissionOptionId>,
405}
406
407pub struct Thread {
408 messages: Vec<Message>,
409 completion_mode: CompletionMode,
410 /// Holds the task that handles agent interaction until the end of the turn.
411 /// Survives across multiple requests as the model performs tool calls and
412 /// we run tools, report their results.
413 running_turn: Option<Task<()>>,
414 pending_message: Option<AgentMessage>,
415 tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
416 context_server_registry: Entity<ContextServerRegistry>,
417 profile_id: AgentProfileId,
418 project_context: Rc<RefCell<ProjectContext>>,
419 templates: Arc<Templates>,
420 pub selected_model: Arc<dyn LanguageModel>,
421 project: Entity<Project>,
422 action_log: Entity<ActionLog>,
423}
424
425impl Thread {
426 pub fn new(
427 project: Entity<Project>,
428 project_context: Rc<RefCell<ProjectContext>>,
429 context_server_registry: Entity<ContextServerRegistry>,
430 action_log: Entity<ActionLog>,
431 templates: Arc<Templates>,
432 default_model: Arc<dyn LanguageModel>,
433 cx: &mut Context<Self>,
434 ) -> Self {
435 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
436 Self {
437 messages: Vec::new(),
438 completion_mode: CompletionMode::Normal,
439 running_turn: None,
440 pending_message: None,
441 tools: BTreeMap::default(),
442 context_server_registry,
443 profile_id,
444 project_context,
445 templates,
446 selected_model: default_model,
447 project,
448 action_log,
449 }
450 }
451
452 pub fn project(&self) -> &Entity<Project> {
453 &self.project
454 }
455
456 pub fn action_log(&self) -> &Entity<ActionLog> {
457 &self.action_log
458 }
459
460 pub fn set_mode(&mut self, mode: CompletionMode) {
461 self.completion_mode = mode;
462 }
463
464 #[cfg(any(test, feature = "test-support"))]
465 pub fn last_message(&self) -> Option<Message> {
466 if let Some(message) = self.pending_message.clone() {
467 Some(Message::Agent(message))
468 } else {
469 self.messages.last().cloned()
470 }
471 }
472
473 pub fn add_tool(&mut self, tool: impl AgentTool) {
474 self.tools.insert(tool.name(), tool.erase());
475 }
476
477 pub fn remove_tool(&mut self, name: &str) -> bool {
478 self.tools.remove(name).is_some()
479 }
480
481 pub fn set_profile(&mut self, profile_id: AgentProfileId) {
482 self.profile_id = profile_id;
483 }
484
485 pub fn cancel(&mut self) {
486 // TODO: do we need to emit a stop::cancel for ACP?
487 self.running_turn.take();
488 self.flush_pending_message();
489 }
490
491 pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
492 self.cancel();
493 let Some(position) = self.messages.iter().position(
494 |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
495 ) else {
496 return Err(anyhow!("Message not found"));
497 };
498 self.messages.truncate(position);
499 Ok(())
500 }
501
502 /// Sending a message results in the model streaming a response, which could include tool calls.
503 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
504 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
505 pub fn send<T>(
506 &mut self,
507 message_id: UserMessageId,
508 content: impl IntoIterator<Item = T>,
509 cx: &mut Context<Self>,
510 ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>
511 where
512 T: Into<UserMessageContent>,
513 {
514 let model = self.selected_model.clone();
515 let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
516 log::info!("Thread::send called with model: {:?}", model.name());
517 log::debug!("Thread::send content: {:?}", content);
518
519 cx.notify();
520 let (events_tx, events_rx) =
521 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
522 let event_stream = AgentResponseEventStream(events_tx);
523
524 self.messages.push(Message::User(UserMessage {
525 id: message_id.clone(),
526 content,
527 }));
528 log::info!("Total messages in thread: {}", self.messages.len());
529 self.running_turn = Some(cx.spawn(async move |this, cx| {
530 log::info!("Starting agent turn execution");
531 let turn_result = async {
532 let mut completion_intent = CompletionIntent::UserPrompt;
533 loop {
534 log::debug!(
535 "Building completion request with intent: {:?}",
536 completion_intent
537 );
538 let request = this.update(cx, |this, cx| {
539 this.build_completion_request(completion_intent, cx)
540 })?;
541
542 log::info!("Calling model.stream_completion");
543 let mut events = model.stream_completion(request, cx).await?;
544 log::debug!("Stream completion started successfully");
545
546 let mut tool_uses = FuturesUnordered::new();
547 while let Some(event) = events.next().await {
548 match event? {
549 LanguageModelCompletionEvent::Stop(reason) => {
550 event_stream.send_stop(reason);
551 if reason == StopReason::Refusal {
552 this.update(cx, |this, _cx| this.truncate(message_id))??;
553 return Ok(());
554 }
555 }
556 event => {
557 log::trace!("Received completion event: {:?}", event);
558 this.update(cx, |this, cx| {
559 tool_uses.extend(this.handle_streamed_completion_event(
560 event,
561 &event_stream,
562 cx,
563 ));
564 })
565 .ok();
566 }
567 }
568 }
569
570 if tool_uses.is_empty() {
571 log::info!("No tool uses found, completing turn");
572 return Ok(());
573 }
574 log::info!("Found {} tool uses to execute", tool_uses.len());
575
576 while let Some(tool_result) = tool_uses.next().await {
577 log::info!("Tool finished {:?}", tool_result);
578
579 event_stream.update_tool_call_fields(
580 &tool_result.tool_use_id,
581 acp::ToolCallUpdateFields {
582 status: Some(if tool_result.is_error {
583 acp::ToolCallStatus::Failed
584 } else {
585 acp::ToolCallStatus::Completed
586 }),
587 raw_output: tool_result.output.clone(),
588 ..Default::default()
589 },
590 );
591 this.update(cx, |this, _cx| {
592 this.pending_message()
593 .tool_results
594 .insert(tool_result.tool_use_id.clone(), tool_result);
595 })
596 .ok();
597 }
598
599 this.update(cx, |this, _| this.flush_pending_message())?;
600 completion_intent = CompletionIntent::ToolResults;
601 }
602 }
603 .await;
604
605 this.update(cx, |this, _| this.flush_pending_message()).ok();
606 if let Err(error) = turn_result {
607 log::error!("Turn execution failed: {:?}", error);
608 event_stream.send_error(error);
609 } else {
610 log::info!("Turn execution completed successfully");
611 }
612 }));
613 events_rx
614 }
615
616 pub fn build_system_message(&self) -> LanguageModelRequestMessage {
617 log::debug!("Building system message");
618 let prompt = SystemPromptTemplate {
619 project: &self.project_context.borrow(),
620 available_tools: self.tools.keys().cloned().collect(),
621 }
622 .render(&self.templates)
623 .context("failed to build system prompt")
624 .expect("Invalid template");
625 log::debug!("System message built");
626 LanguageModelRequestMessage {
627 role: Role::System,
628 content: vec![prompt.into()],
629 cache: true,
630 }
631 }
632
633 /// A helper method that's called on every streamed completion event.
634 /// Returns an optional tool result task, which the main agentic loop in
635 /// send will send back to the model when it resolves.
636 fn handle_streamed_completion_event(
637 &mut self,
638 event: LanguageModelCompletionEvent,
639 event_stream: &AgentResponseEventStream,
640 cx: &mut Context<Self>,
641 ) -> Option<Task<LanguageModelToolResult>> {
642 log::trace!("Handling streamed completion event: {:?}", event);
643 use LanguageModelCompletionEvent::*;
644
645 match event {
646 StartMessage { .. } => {
647 self.flush_pending_message();
648 self.pending_message = Some(AgentMessage::default());
649 }
650 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
651 Thinking { text, signature } => {
652 self.handle_thinking_event(text, signature, event_stream, cx)
653 }
654 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
655 ToolUse(tool_use) => {
656 return self.handle_tool_use_event(tool_use, event_stream, cx);
657 }
658 ToolUseJsonParseError {
659 id,
660 tool_name,
661 raw_input,
662 json_parse_error,
663 } => {
664 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
665 id,
666 tool_name,
667 raw_input,
668 json_parse_error,
669 )));
670 }
671 UsageUpdate(_) | StatusUpdate(_) => {}
672 Stop(_) => unreachable!(),
673 }
674
675 None
676 }
677
678 fn handle_text_event(
679 &mut self,
680 new_text: String,
681 events_stream: &AgentResponseEventStream,
682 cx: &mut Context<Self>,
683 ) {
684 events_stream.send_text(&new_text);
685
686 let last_message = self.pending_message();
687 if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
688 text.push_str(&new_text);
689 } else {
690 last_message
691 .content
692 .push(AgentMessageContent::Text(new_text));
693 }
694
695 cx.notify();
696 }
697
698 fn handle_thinking_event(
699 &mut self,
700 new_text: String,
701 new_signature: Option<String>,
702 event_stream: &AgentResponseEventStream,
703 cx: &mut Context<Self>,
704 ) {
705 event_stream.send_thinking(&new_text);
706
707 let last_message = self.pending_message();
708 if let Some(AgentMessageContent::Thinking { text, signature }) =
709 last_message.content.last_mut()
710 {
711 text.push_str(&new_text);
712 *signature = new_signature.or(signature.take());
713 } else {
714 last_message.content.push(AgentMessageContent::Thinking {
715 text: new_text,
716 signature: new_signature,
717 });
718 }
719
720 cx.notify();
721 }
722
723 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
724 let last_message = self.pending_message();
725 last_message
726 .content
727 .push(AgentMessageContent::RedactedThinking(data));
728 cx.notify();
729 }
730
731 fn handle_tool_use_event(
732 &mut self,
733 tool_use: LanguageModelToolUse,
734 event_stream: &AgentResponseEventStream,
735 cx: &mut Context<Self>,
736 ) -> Option<Task<LanguageModelToolResult>> {
737 cx.notify();
738
739 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
740 let mut title = SharedString::from(&tool_use.name);
741 let mut kind = acp::ToolKind::Other;
742 if let Some(tool) = tool.as_ref() {
743 title = tool.initial_title(tool_use.input.clone());
744 kind = tool.kind();
745 }
746
747 // Ensure the last message ends in the current tool use
748 let last_message = self.pending_message();
749 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
750 if let AgentMessageContent::ToolUse(last_tool_use) = content {
751 if last_tool_use.id == tool_use.id {
752 *last_tool_use = tool_use.clone();
753 false
754 } else {
755 true
756 }
757 } else {
758 true
759 }
760 });
761
762 if push_new_tool_use {
763 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
764 last_message
765 .content
766 .push(AgentMessageContent::ToolUse(tool_use.clone()));
767 } else {
768 event_stream.update_tool_call_fields(
769 &tool_use.id,
770 acp::ToolCallUpdateFields {
771 title: Some(title.into()),
772 kind: Some(kind),
773 raw_input: Some(tool_use.input.clone()),
774 ..Default::default()
775 },
776 );
777 }
778
779 if !tool_use.is_input_complete {
780 return None;
781 }
782
783 let Some(tool) = tool else {
784 let content = format!("No tool named {} exists", tool_use.name);
785 return Some(Task::ready(LanguageModelToolResult {
786 content: LanguageModelToolResultContent::Text(Arc::from(content)),
787 tool_use_id: tool_use.id,
788 tool_name: tool_use.name,
789 is_error: true,
790 output: None,
791 }));
792 };
793
794 let fs = self.project.read(cx).fs().clone();
795 let tool_event_stream =
796 ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
797 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
798 status: Some(acp::ToolCallStatus::InProgress),
799 ..Default::default()
800 });
801 let supports_images = self.selected_model.supports_images();
802 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
803 Some(cx.foreground_executor().spawn(async move {
804 let tool_result = tool_result.await.and_then(|output| {
805 if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
806 if !supports_images {
807 return Err(anyhow!(
808 "Attempted to read an image, but this model doesn't support it.",
809 ));
810 }
811 }
812 Ok(output)
813 });
814
815 match tool_result {
816 Ok(output) => LanguageModelToolResult {
817 tool_use_id: tool_use.id,
818 tool_name: tool_use.name,
819 is_error: false,
820 content: output.llm_output,
821 output: Some(output.raw_output),
822 },
823 Err(error) => LanguageModelToolResult {
824 tool_use_id: tool_use.id,
825 tool_name: tool_use.name,
826 is_error: true,
827 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
828 output: None,
829 },
830 }
831 }))
832 }
833
834 fn handle_tool_use_json_parse_error_event(
835 &mut self,
836 tool_use_id: LanguageModelToolUseId,
837 tool_name: Arc<str>,
838 raw_input: Arc<str>,
839 json_parse_error: String,
840 ) -> LanguageModelToolResult {
841 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
842 LanguageModelToolResult {
843 tool_use_id,
844 tool_name,
845 is_error: true,
846 content: LanguageModelToolResultContent::Text(tool_output.into()),
847 output: Some(serde_json::Value::String(raw_input.to_string())),
848 }
849 }
850
851 fn pending_message(&mut self) -> &mut AgentMessage {
852 self.pending_message.get_or_insert_default()
853 }
854
855 fn flush_pending_message(&mut self) {
856 let Some(mut message) = self.pending_message.take() else {
857 return;
858 };
859
860 for content in &message.content {
861 let AgentMessageContent::ToolUse(tool_use) = content else {
862 continue;
863 };
864
865 if !message.tool_results.contains_key(&tool_use.id) {
866 message.tool_results.insert(
867 tool_use.id.clone(),
868 LanguageModelToolResult {
869 tool_use_id: tool_use.id.clone(),
870 tool_name: tool_use.name.clone(),
871 is_error: true,
872 content: LanguageModelToolResultContent::Text(
873 "Tool canceled by user".into(),
874 ),
875 output: None,
876 },
877 );
878 }
879 }
880
881 self.messages.push(Message::Agent(message));
882 }
883
884 pub(crate) fn build_completion_request(
885 &self,
886 completion_intent: CompletionIntent,
887 cx: &mut App,
888 ) -> LanguageModelRequest {
889 log::debug!("Building completion request");
890 log::debug!("Completion intent: {:?}", completion_intent);
891 log::debug!("Completion mode: {:?}", self.completion_mode);
892
893 let messages = self.build_request_messages();
894 log::info!("Request will include {} messages", messages.len());
895
896 let tools = if let Some(tools) = self.tools(cx).log_err() {
897 tools
898 .filter_map(|tool| {
899 let tool_name = tool.name().to_string();
900 log::trace!("Including tool: {}", tool_name);
901 Some(LanguageModelRequestTool {
902 name: tool_name,
903 description: tool.description().to_string(),
904 input_schema: tool
905 .input_schema(self.selected_model.tool_input_format())
906 .log_err()?,
907 })
908 })
909 .collect()
910 } else {
911 Vec::new()
912 };
913
914 log::info!("Request includes {} tools", tools.len());
915
916 let request = LanguageModelRequest {
917 thread_id: None,
918 prompt_id: None,
919 intent: Some(completion_intent),
920 mode: Some(self.completion_mode),
921 messages,
922 tools,
923 tool_choice: None,
924 stop: Vec::new(),
925 temperature: None,
926 thinking_allowed: true,
927 };
928
929 log::debug!("Completion request built successfully");
930 request
931 }
932
933 fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
934 let profile = AgentSettings::get_global(cx)
935 .profiles
936 .get(&self.profile_id)
937 .context("profile not found")?;
938 let provider_id = self.selected_model.provider_id();
939
940 Ok(self
941 .tools
942 .iter()
943 .filter(move |(_, tool)| tool.supported_provider(&provider_id))
944 .filter_map(|(tool_name, tool)| {
945 if profile.is_tool_enabled(tool_name) {
946 Some(tool)
947 } else {
948 None
949 }
950 })
951 .chain(self.context_server_registry.read(cx).servers().flat_map(
952 |(server_id, tools)| {
953 tools.iter().filter_map(|(tool_name, tool)| {
954 if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
955 Some(tool)
956 } else {
957 None
958 }
959 })
960 },
961 )))
962 }
963
964 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
965 log::trace!(
966 "Building request messages from {} thread messages",
967 self.messages.len()
968 );
969 let mut messages = vec![self.build_system_message()];
970 for message in &self.messages {
971 match message {
972 Message::User(message) => messages.push(message.to_request()),
973 Message::Agent(message) => messages.extend(message.to_request()),
974 }
975 }
976
977 if let Some(message) = self.pending_message.as_ref() {
978 messages.extend(message.to_request());
979 }
980
981 messages
982 }
983
984 pub fn to_markdown(&self) -> String {
985 let mut markdown = String::new();
986 for (ix, message) in self.messages.iter().enumerate() {
987 if ix > 0 {
988 markdown.push('\n');
989 }
990 markdown.push_str(&message.to_markdown());
991 }
992
993 if let Some(message) = self.pending_message.as_ref() {
994 markdown.push('\n');
995 markdown.push_str(&message.to_markdown());
996 }
997
998 markdown
999 }
1000}
1001
1002pub trait AgentTool
1003where
1004 Self: 'static + Sized,
1005{
1006 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1007 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1008
1009 fn name(&self) -> SharedString;
1010
1011 fn description(&self) -> SharedString {
1012 let schema = schemars::schema_for!(Self::Input);
1013 SharedString::new(
1014 schema
1015 .get("description")
1016 .and_then(|description| description.as_str())
1017 .unwrap_or_default(),
1018 )
1019 }
1020
1021 fn kind(&self) -> acp::ToolKind;
1022
1023 /// The initial tool title to display. Can be updated during the tool run.
1024 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1025
1026 /// Returns the JSON schema that describes the tool's input.
1027 fn input_schema(&self) -> Schema {
1028 schemars::schema_for!(Self::Input)
1029 }
1030
1031 /// Some tools rely on a provider for the underlying billing or other reasons.
1032 /// Allow the tool to check if they are compatible, or should be filtered out.
1033 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1034 true
1035 }
1036
1037 /// Runs the tool with the provided input.
1038 fn run(
1039 self: Arc<Self>,
1040 input: Self::Input,
1041 event_stream: ToolCallEventStream,
1042 cx: &mut App,
1043 ) -> Task<Result<Self::Output>>;
1044
1045 fn erase(self) -> Arc<dyn AnyAgentTool> {
1046 Arc::new(Erased(Arc::new(self)))
1047 }
1048}
1049
/// Type-erasing wrapper through which an [`AgentTool`] is exposed as an
/// [`AnyAgentTool`].
pub struct Erased<T>(T);
1051
1052pub struct AgentToolOutput {
1053 pub llm_output: LanguageModelToolResultContent,
1054 pub raw_output: serde_json::Value,
1055}
1056
1057pub trait AnyAgentTool {
1058 fn name(&self) -> SharedString;
1059 fn description(&self) -> SharedString;
1060 fn kind(&self) -> acp::ToolKind;
1061 fn initial_title(&self, input: serde_json::Value) -> SharedString;
1062 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
1063 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1064 true
1065 }
1066 fn run(
1067 self: Arc<Self>,
1068 input: serde_json::Value,
1069 event_stream: ToolCallEventStream,
1070 cx: &mut App,
1071 ) -> Task<Result<AgentToolOutput>>;
1072}
1073
1074impl<T> AnyAgentTool for Erased<Arc<T>>
1075where
1076 T: AgentTool,
1077{
1078 fn name(&self) -> SharedString {
1079 self.0.name()
1080 }
1081
1082 fn description(&self) -> SharedString {
1083 self.0.description()
1084 }
1085
1086 fn kind(&self) -> agent_client_protocol::ToolKind {
1087 self.0.kind()
1088 }
1089
1090 fn initial_title(&self, input: serde_json::Value) -> SharedString {
1091 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1092 self.0.initial_title(parsed_input)
1093 }
1094
1095 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1096 let mut json = serde_json::to_value(self.0.input_schema())?;
1097 adapt_schema_to_format(&mut json, format)?;
1098 Ok(json)
1099 }
1100
1101 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1102 self.0.supported_provider(provider)
1103 }
1104
1105 fn run(
1106 self: Arc<Self>,
1107 input: serde_json::Value,
1108 event_stream: ToolCallEventStream,
1109 cx: &mut App,
1110 ) -> Task<Result<AgentToolOutput>> {
1111 cx.spawn(async move |cx| {
1112 let input = serde_json::from_value(input)?;
1113 let output = cx
1114 .update(|cx| self.0.clone().run(input, event_stream, cx))?
1115 .await?;
1116 let raw_output = serde_json::to_value(&output)?;
1117 Ok(AgentToolOutput {
1118 llm_output: output.into(),
1119 raw_output,
1120 })
1121 })
1122 }
1123}
1124
/// Sending half of the channel that carries a thread's response events (or
/// completion errors) to the consumer. Clones send into the same channel.
#[derive(Clone)]
struct AgentResponseEventStream(
    mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
1129
1130impl AgentResponseEventStream {
1131 fn send_text(&self, text: &str) {
1132 self.0
1133 .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
1134 .ok();
1135 }
1136
1137 fn send_thinking(&self, text: &str) {
1138 self.0
1139 .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
1140 .ok();
1141 }
1142
1143 fn send_tool_call(
1144 &self,
1145 id: &LanguageModelToolUseId,
1146 title: SharedString,
1147 kind: acp::ToolKind,
1148 input: serde_json::Value,
1149 ) {
1150 self.0
1151 .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
1152 id,
1153 title.to_string(),
1154 kind,
1155 input,
1156 ))))
1157 .ok();
1158 }
1159
1160 fn initial_tool_call(
1161 id: &LanguageModelToolUseId,
1162 title: String,
1163 kind: acp::ToolKind,
1164 input: serde_json::Value,
1165 ) -> acp::ToolCall {
1166 acp::ToolCall {
1167 id: acp::ToolCallId(id.to_string().into()),
1168 title,
1169 kind,
1170 status: acp::ToolCallStatus::Pending,
1171 content: vec![],
1172 locations: vec![],
1173 raw_input: Some(input),
1174 raw_output: None,
1175 }
1176 }
1177
1178 fn update_tool_call_fields(
1179 &self,
1180 tool_use_id: &LanguageModelToolUseId,
1181 fields: acp::ToolCallUpdateFields,
1182 ) {
1183 self.0
1184 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1185 acp::ToolCallUpdate {
1186 id: acp::ToolCallId(tool_use_id.to_string().into()),
1187 fields,
1188 }
1189 .into(),
1190 )))
1191 .ok();
1192 }
1193
1194 fn send_stop(&self, reason: StopReason) {
1195 match reason {
1196 StopReason::EndTurn => {
1197 self.0
1198 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
1199 .ok();
1200 }
1201 StopReason::MaxTokens => {
1202 self.0
1203 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
1204 .ok();
1205 }
1206 StopReason::Refusal => {
1207 self.0
1208 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
1209 .ok();
1210 }
1211 StopReason::ToolUse => {}
1212 }
1213 }
1214
1215 fn send_error(&self, error: LanguageModelCompletionError) {
1216 self.0.unbounded_send(Err(error)).ok();
1217 }
1218}
1219
/// Handle given to a running tool for reporting on its own tool call: field
/// updates, diffs, terminals, and authorization requests.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// Id of the tool use; converted to an `acp::ToolCallId` for each event sent.
    tool_use_id: LanguageModelToolUseId,
    /// Tool kind, reused when building the authorization request's tool call.
    kind: acp::ToolKind,
    /// The tool-use input, attached as `raw_input` to authorization requests.
    input: serde_json::Value,
    /// Channel the events are sent on.
    stream: AgentResponseEventStream,
    /// Used to persist an "Always Allow" choice to settings; when `None` (as in
    /// `test()`), that choice is honored once but not persisted.
    fs: Option<Arc<dyn Fs>>,
}
1228
1229impl ToolCallEventStream {
1230 #[cfg(test)]
1231 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
1232 let (events_tx, events_rx) =
1233 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
1234
1235 let stream = ToolCallEventStream::new(
1236 &LanguageModelToolUse {
1237 id: "test_id".into(),
1238 name: "test_tool".into(),
1239 raw_input: String::new(),
1240 input: serde_json::Value::Null,
1241 is_input_complete: true,
1242 },
1243 acp::ToolKind::Other,
1244 AgentResponseEventStream(events_tx),
1245 None,
1246 );
1247
1248 (stream, ToolCallEventStreamReceiver(events_rx))
1249 }
1250
1251 fn new(
1252 tool_use: &LanguageModelToolUse,
1253 kind: acp::ToolKind,
1254 stream: AgentResponseEventStream,
1255 fs: Option<Arc<dyn Fs>>,
1256 ) -> Self {
1257 Self {
1258 tool_use_id: tool_use.id.clone(),
1259 kind,
1260 input: tool_use.input.clone(),
1261 stream,
1262 fs,
1263 }
1264 }
1265
1266 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1267 self.stream
1268 .update_tool_call_fields(&self.tool_use_id, fields);
1269 }
1270
1271 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1272 self.stream
1273 .0
1274 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1275 acp_thread::ToolCallUpdateDiff {
1276 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1277 diff,
1278 }
1279 .into(),
1280 )))
1281 .ok();
1282 }
1283
1284 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1285 self.stream
1286 .0
1287 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1288 acp_thread::ToolCallUpdateTerminal {
1289 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1290 terminal,
1291 }
1292 .into(),
1293 )))
1294 .ok();
1295 }
1296
1297 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1298 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1299 return Task::ready(Ok(()));
1300 }
1301
1302 let (response_tx, response_rx) = oneshot::channel();
1303 self.stream
1304 .0
1305 .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
1306 ToolCallAuthorization {
1307 tool_call: AgentResponseEventStream::initial_tool_call(
1308 &self.tool_use_id,
1309 title.into(),
1310 self.kind.clone(),
1311 self.input.clone(),
1312 ),
1313 options: vec![
1314 acp::PermissionOption {
1315 id: acp::PermissionOptionId("always_allow".into()),
1316 name: "Always Allow".into(),
1317 kind: acp::PermissionOptionKind::AllowAlways,
1318 },
1319 acp::PermissionOption {
1320 id: acp::PermissionOptionId("allow".into()),
1321 name: "Allow".into(),
1322 kind: acp::PermissionOptionKind::AllowOnce,
1323 },
1324 acp::PermissionOption {
1325 id: acp::PermissionOptionId("deny".into()),
1326 name: "Deny".into(),
1327 kind: acp::PermissionOptionKind::RejectOnce,
1328 },
1329 ],
1330 response: response_tx,
1331 },
1332 )))
1333 .ok();
1334 let fs = self.fs.clone();
1335 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1336 "always_allow" => {
1337 if let Some(fs) = fs.clone() {
1338 cx.update(|cx| {
1339 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1340 settings.set_always_allow_tool_actions(true);
1341 });
1342 })?;
1343 }
1344
1345 Ok(())
1346 }
1347 "allow" => Ok(()),
1348 _ => Err(anyhow!("Permission to run tool denied by user")),
1349 })
1350 }
1351}
1352
/// Test-only receiving end paired with `ToolCallEventStream::test()`; observes
/// every event the stream sends.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(
    mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
1357
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it as an authorization request.
    ///
    /// # Panics
    /// If the channel is closed or the next event is anything else.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) => auth,
            other => panic!("Expected ToolCallAuthorization but got: {:?}", other),
        }
    }

    /// Awaits the next event and returns the terminal from a terminal update.
    ///
    /// # Panics
    /// If the channel is closed or the next event is anything else.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            other => panic!("Expected terminal but got: {:?}", other),
        }
    }
}
1381
/// Exposes the underlying receiver so tests can use it directly.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;

    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1390
/// Mutable access to the underlying receiver, e.g. for polling it in tests.
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1397
1398impl From<&str> for UserMessageContent {
1399 fn from(text: &str) -> Self {
1400 Self::Text(text.into())
1401 }
1402}
1403
1404impl From<acp::ContentBlock> for UserMessageContent {
1405 fn from(value: acp::ContentBlock) -> Self {
1406 match value {
1407 acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
1408 acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
1409 acp::ContentBlock::Audio(_) => {
1410 // TODO
1411 Self::Text("[audio]".to_string())
1412 }
1413 acp::ContentBlock::ResourceLink(resource_link) => {
1414 match MentionUri::parse(&resource_link.uri) {
1415 Ok(uri) => Self::Mention {
1416 uri,
1417 content: String::new(),
1418 },
1419 Err(err) => {
1420 log::error!("Failed to parse mention link: {}", err);
1421 Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
1422 }
1423 }
1424 }
1425 acp::ContentBlock::Resource(resource) => match resource.resource {
1426 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1427 match MentionUri::parse(&resource.uri) {
1428 Ok(uri) => Self::Mention {
1429 uri,
1430 content: resource.text,
1431 },
1432 Err(err) => {
1433 log::error!("Failed to parse mention link: {}", err);
1434 Self::Text(
1435 MarkdownCodeBlock {
1436 tag: &resource.uri,
1437 text: &resource.text,
1438 }
1439 .to_string(),
1440 )
1441 }
1442 }
1443 }
1444 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1445 // TODO
1446 Self::Text("[blob]".to_string())
1447 }
1448 },
1449 }
1450 }
1451}
1452
1453fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1454 LanguageModelImage {
1455 source: image_content.data.into(),
1456 // TODO: make this optional?
1457 size: gpui::Size::new(0.into(), 0.into()),
1458 }
1459}