1use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
2use acp_thread::MentionUri;
3use action_log::ActionLog;
4use agent_client_protocol as acp;
5use agent_settings::{AgentProfileId, AgentSettings};
6use anyhow::{Context as _, Result, anyhow};
7use assistant_tool::adapt_schema_to_format;
8use cloud_llm_client::{CompletionIntent, CompletionMode};
9use collections::HashMap;
10use fs::Fs;
11use futures::{
12 channel::{mpsc, oneshot},
13 stream::FuturesUnordered,
14};
15use gpui::{App, Context, Entity, SharedString, Task};
16use language_model::{
17 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
18 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
19 LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
20 LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
21};
22use log;
23use project::Project;
24use prompt_store::ProjectContext;
25use schemars::{JsonSchema, Schema};
26use serde::{Deserialize, Serialize};
27use settings::{Settings, update_settings_file};
28use smol::stream::StreamExt;
29use std::fmt::Write;
30use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
31use util::{ResultExt, markdown::MarkdownCodeBlock};
32
/// A single entry in a [`Thread`]'s history: who authored it and the ordered
/// content chunks it carries.
#[derive(Debug, Clone)]
pub struct AgentMessage {
    /// Author of the message (`System`, `User` or `Assistant`).
    pub role: Role,
    /// Content chunks, in the order they were streamed or attached.
    pub content: Vec<MessageContent>,
}
38
/// One chunk of content inside an [`AgentMessage`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageContent {
    /// Plain text.
    Text(String),
    /// Model reasoning text, plus the optional signature the provider attached to
    /// it; both are forwarded to the model when the history is replayed.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Something the user mentioned (e.g. a file, symbol, thread or rule),
    /// together with the content resolved for that mention.
    Mention {
        uri: MentionUri,
        content: String,
    },
    /// Reasoning the provider redacted; only its opaque payload is kept.
    RedactedThinking(String),
    /// An image attachment.
    Image(LanguageModelImage),
    /// A tool invocation requested by the model.
    ToolUse(LanguageModelToolUse),
    /// The result of a tool invocation, sent back to the model on the next request.
    ToolResult(LanguageModelToolResult),
}
55
56impl AgentMessage {
57 pub fn to_markdown(&self) -> String {
58 let mut markdown = format!("## {}\n", self.role);
59
60 for content in &self.content {
61 match content {
62 MessageContent::Text(text) => {
63 markdown.push_str(text);
64 markdown.push('\n');
65 }
66 MessageContent::Thinking { text, .. } => {
67 markdown.push_str("<think>");
68 markdown.push_str(text);
69 markdown.push_str("</think>\n");
70 }
71 MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
72 MessageContent::Image(_) => {
73 markdown.push_str("<image />\n");
74 }
75 MessageContent::ToolUse(tool_use) => {
76 markdown.push_str(&format!(
77 "**Tool Use**: {} (ID: {})\n",
78 tool_use.name, tool_use.id
79 ));
80 markdown.push_str(&format!(
81 "{}\n",
82 MarkdownCodeBlock {
83 tag: "json",
84 text: &format!("{:#}", tool_use.input)
85 }
86 ));
87 }
88 MessageContent::ToolResult(tool_result) => {
89 markdown.push_str(&format!(
90 "**Tool Result**: {} (ID: {})\n\n",
91 tool_result.tool_name, tool_result.tool_use_id
92 ));
93 if tool_result.is_error {
94 markdown.push_str("**ERROR:**\n");
95 }
96
97 match &tool_result.content {
98 LanguageModelToolResultContent::Text(text) => {
99 writeln!(markdown, "{text}\n").ok();
100 }
101 LanguageModelToolResultContent::Image(_) => {
102 writeln!(markdown, "<image />\n").ok();
103 }
104 }
105
106 if let Some(output) = tool_result.output.as_ref() {
107 writeln!(
108 markdown,
109 "**Debug Output**:\n\n```json\n{}\n```\n",
110 serde_json::to_string_pretty(output).unwrap()
111 )
112 .unwrap();
113 }
114 }
115 MessageContent::Mention { uri, .. } => {
116 write!(markdown, "{}", uri.to_link()).ok();
117 }
118 }
119 }
120
121 markdown
122 }
123}
124
/// Events streamed to the caller of [`Thread::send`] while a turn is running.
#[derive(Debug)]
pub enum AgentResponseEvent {
    /// A chunk of assistant text.
    Text(String),
    /// A chunk of assistant reasoning.
    Thinking(String),
    /// The model started a new tool call.
    ToolCall(acp::ToolCall),
    /// An update to a tool call that was already announced (fields, diff or terminal).
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A tool is asking the user for permission before it proceeds.
    ToolCallAuthorization(ToolCallAuthorization),
    /// The model stopped for the given reason.
    Stop(acp::StopReason),
}
134
/// A permission request emitted by a tool through [`ToolCallEventStream::authorize`].
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call the permission is requested for.
    pub tool_call: acp::ToolCall,
    /// The choices offered to the user.
    pub options: Vec<acp::PermissionOption>,
    /// Channel on which the id of the chosen option must be sent back.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
141
/// A conversation between the user and an agent model, together with the tools
/// the agent may call and the state of the turn currently in progress.
pub struct Thread {
    /// User and assistant messages; the system prompt is not stored here but
    /// rebuilt for every request.
    messages: Vec<AgentMessage>,
    /// Completion mode sent with every request.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<Task<()>>,
    /// Tool uses the model issued whose results have not been recorded yet.
    pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
    /// Native tools, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Provides the tools exposed by context servers.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Agent profile deciding which tools are enabled.
    profile_id: AgentProfileId,
    /// Project information rendered into the system prompt.
    project_context: Rc<RefCell<ProjectContext>>,
    /// Templates used to render the system prompt.
    templates: Arc<Templates>,
    /// Model used for new turns.
    pub selected_model: Arc<dyn LanguageModel>,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
}
159
160impl Thread {
161 pub fn new(
162 project: Entity<Project>,
163 project_context: Rc<RefCell<ProjectContext>>,
164 context_server_registry: Entity<ContextServerRegistry>,
165 action_log: Entity<ActionLog>,
166 templates: Arc<Templates>,
167 default_model: Arc<dyn LanguageModel>,
168 cx: &mut Context<Self>,
169 ) -> Self {
170 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
171 Self {
172 messages: Vec::new(),
173 completion_mode: CompletionMode::Normal,
174 running_turn: None,
175 pending_tool_uses: HashMap::default(),
176 tools: BTreeMap::default(),
177 context_server_registry,
178 profile_id,
179 project_context,
180 templates,
181 selected_model: default_model,
182 project,
183 action_log,
184 }
185 }
186
187 pub fn project(&self) -> &Entity<Project> {
188 &self.project
189 }
190
191 pub fn action_log(&self) -> &Entity<ActionLog> {
192 &self.action_log
193 }
194
195 pub fn set_mode(&mut self, mode: CompletionMode) {
196 self.completion_mode = mode;
197 }
198
199 pub fn messages(&self) -> &[AgentMessage] {
200 &self.messages
201 }
202
203 pub fn add_tool(&mut self, tool: impl AgentTool) {
204 self.tools.insert(tool.name(), tool.erase());
205 }
206
207 pub fn remove_tool(&mut self, name: &str) -> bool {
208 self.tools.remove(name).is_some()
209 }
210
211 pub fn set_profile(&mut self, profile_id: AgentProfileId) {
212 self.profile_id = profile_id;
213 }
214
215 pub fn cancel(&mut self) {
216 self.running_turn.take();
217
218 let tool_results = self
219 .pending_tool_uses
220 .drain()
221 .map(|(tool_use_id, tool_use)| {
222 MessageContent::ToolResult(LanguageModelToolResult {
223 tool_use_id,
224 tool_name: tool_use.name.clone(),
225 is_error: true,
226 content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
227 output: None,
228 })
229 })
230 .collect::<Vec<_>>();
231 self.last_user_message().content.extend(tool_results);
232 }
233
234 /// Sending a message results in the model streaming a response, which could include tool calls.
235 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
236 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
237 pub fn send(
238 &mut self,
239 content: impl Into<UserMessage>,
240 cx: &mut Context<Self>,
241 ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
242 let content = content.into().0;
243
244 let model = self.selected_model.clone();
245 log::info!("Thread::send called with model: {:?}", model.name());
246 log::debug!("Thread::send content: {:?}", content);
247
248 cx.notify();
249 let (events_tx, events_rx) =
250 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
251 let event_stream = AgentResponseEventStream(events_tx);
252
253 let user_message_ix = self.messages.len();
254 self.messages.push(AgentMessage {
255 role: Role::User,
256 content,
257 });
258 log::info!("Total messages in thread: {}", self.messages.len());
259 self.running_turn = Some(cx.spawn(async move |thread, cx| {
260 log::info!("Starting agent turn execution");
261 let turn_result = async {
262 // Perform one request, then keep looping if the model makes tool calls.
263 let mut completion_intent = CompletionIntent::UserPrompt;
264 'outer: loop {
265 log::debug!(
266 "Building completion request with intent: {:?}",
267 completion_intent
268 );
269 let request = thread.update(cx, |thread, cx| {
270 thread.build_completion_request(completion_intent, cx)
271 })?;
272
273 // println!(
274 // "request: {}",
275 // serde_json::to_string_pretty(&request).unwrap()
276 // );
277
278 // Stream events, appending to messages and collecting up tool uses.
279 log::info!("Calling model.stream_completion");
280 let mut events = model.stream_completion(request, cx).await?;
281 log::debug!("Stream completion started successfully");
282 let mut tool_uses = FuturesUnordered::new();
283 while let Some(event) = events.next().await {
284 match event {
285 Ok(LanguageModelCompletionEvent::Stop(reason)) => {
286 event_stream.send_stop(reason);
287 if reason == StopReason::Refusal {
288 thread.update(cx, |thread, _cx| {
289 thread.messages.truncate(user_message_ix);
290 })?;
291 break 'outer;
292 }
293 }
294 Ok(event) => {
295 log::trace!("Received completion event: {:?}", event);
296 thread
297 .update(cx, |thread, cx| {
298 tool_uses.extend(thread.handle_streamed_completion_event(
299 event,
300 &event_stream,
301 cx,
302 ));
303 })
304 .ok();
305 }
306 Err(error) => {
307 log::error!("Error in completion stream: {:?}", error);
308 event_stream.send_error(error);
309 break;
310 }
311 }
312 }
313
314 // If there are no tool uses, the turn is done.
315 if tool_uses.is_empty() {
316 log::info!("No tool uses found, completing turn");
317 break;
318 }
319 log::info!("Found {} tool uses to execute", tool_uses.len());
320
321 // As tool results trickle in, insert them in the last user
322 // message so that they can be sent on the next tick of the
323 // agentic loop.
324 while let Some(tool_result) = tool_uses.next().await {
325 log::info!("Tool finished {:?}", tool_result);
326
327 event_stream.update_tool_call_fields(
328 &tool_result.tool_use_id,
329 acp::ToolCallUpdateFields {
330 status: Some(if tool_result.is_error {
331 acp::ToolCallStatus::Failed
332 } else {
333 acp::ToolCallStatus::Completed
334 }),
335 raw_output: tool_result.output.clone(),
336 ..Default::default()
337 },
338 );
339 thread
340 .update(cx, |thread, _cx| {
341 thread.pending_tool_uses.remove(&tool_result.tool_use_id);
342 thread
343 .last_user_message()
344 .content
345 .push(MessageContent::ToolResult(tool_result));
346 })
347 .ok();
348 }
349
350 completion_intent = CompletionIntent::ToolResults;
351 }
352
353 Ok(())
354 }
355 .await;
356
357 if let Err(error) = turn_result {
358 log::error!("Turn execution failed: {:?}", error);
359 event_stream.send_error(error);
360 } else {
361 log::info!("Turn execution completed successfully");
362 }
363 }));
364 events_rx
365 }
366
367 pub fn build_system_message(&self) -> AgentMessage {
368 log::debug!("Building system message");
369 let prompt = SystemPromptTemplate {
370 project: &self.project_context.borrow(),
371 available_tools: self.tools.keys().cloned().collect(),
372 }
373 .render(&self.templates)
374 .context("failed to build system prompt")
375 .expect("Invalid template");
376 log::debug!("System message built");
377 AgentMessage {
378 role: Role::System,
379 content: vec![prompt.as_str().into()],
380 }
381 }
382
383 /// A helper method that's called on every streamed completion event.
384 /// Returns an optional tool result task, which the main agentic loop in
385 /// send will send back to the model when it resolves.
386 fn handle_streamed_completion_event(
387 &mut self,
388 event: LanguageModelCompletionEvent,
389 event_stream: &AgentResponseEventStream,
390 cx: &mut Context<Self>,
391 ) -> Option<Task<LanguageModelToolResult>> {
392 log::trace!("Handling streamed completion event: {:?}", event);
393 use LanguageModelCompletionEvent::*;
394
395 match event {
396 StartMessage { .. } => {
397 self.messages.push(AgentMessage {
398 role: Role::Assistant,
399 content: Vec::new(),
400 });
401 }
402 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
403 Thinking { text, signature } => {
404 self.handle_thinking_event(text, signature, event_stream, cx)
405 }
406 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
407 ToolUse(tool_use) => {
408 return self.handle_tool_use_event(tool_use, event_stream, cx);
409 }
410 ToolUseJsonParseError {
411 id,
412 tool_name,
413 raw_input,
414 json_parse_error,
415 } => {
416 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
417 id,
418 tool_name,
419 raw_input,
420 json_parse_error,
421 )));
422 }
423 UsageUpdate(_) | StatusUpdate(_) => {}
424 Stop(_) => unreachable!(),
425 }
426
427 None
428 }
429
430 fn handle_text_event(
431 &mut self,
432 new_text: String,
433 events_stream: &AgentResponseEventStream,
434 cx: &mut Context<Self>,
435 ) {
436 events_stream.send_text(&new_text);
437
438 let last_message = self.last_assistant_message();
439 if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
440 text.push_str(&new_text);
441 } else {
442 last_message.content.push(MessageContent::Text(new_text));
443 }
444
445 cx.notify();
446 }
447
448 fn handle_thinking_event(
449 &mut self,
450 new_text: String,
451 new_signature: Option<String>,
452 event_stream: &AgentResponseEventStream,
453 cx: &mut Context<Self>,
454 ) {
455 event_stream.send_thinking(&new_text);
456
457 let last_message = self.last_assistant_message();
458 if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
459 {
460 text.push_str(&new_text);
461 *signature = new_signature.or(signature.take());
462 } else {
463 last_message.content.push(MessageContent::Thinking {
464 text: new_text,
465 signature: new_signature,
466 });
467 }
468
469 cx.notify();
470 }
471
472 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
473 let last_message = self.last_assistant_message();
474 last_message
475 .content
476 .push(MessageContent::RedactedThinking(data));
477 cx.notify();
478 }
479
480 fn handle_tool_use_event(
481 &mut self,
482 tool_use: LanguageModelToolUse,
483 event_stream: &AgentResponseEventStream,
484 cx: &mut Context<Self>,
485 ) -> Option<Task<LanguageModelToolResult>> {
486 cx.notify();
487
488 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
489
490 self.pending_tool_uses
491 .insert(tool_use.id.clone(), tool_use.clone());
492 let last_message = self.last_assistant_message();
493
494 // Ensure the last message ends in the current tool use
495 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
496 if let MessageContent::ToolUse(last_tool_use) = content {
497 if last_tool_use.id == tool_use.id {
498 *last_tool_use = tool_use.clone();
499 false
500 } else {
501 true
502 }
503 } else {
504 true
505 }
506 });
507
508 let mut title = SharedString::from(&tool_use.name);
509 let mut kind = acp::ToolKind::Other;
510 if let Some(tool) = tool.as_ref() {
511 title = tool.initial_title(tool_use.input.clone());
512 kind = tool.kind();
513 }
514
515 if push_new_tool_use {
516 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
517 last_message
518 .content
519 .push(MessageContent::ToolUse(tool_use.clone()));
520 } else {
521 event_stream.update_tool_call_fields(
522 &tool_use.id,
523 acp::ToolCallUpdateFields {
524 title: Some(title.into()),
525 kind: Some(kind),
526 raw_input: Some(tool_use.input.clone()),
527 ..Default::default()
528 },
529 );
530 }
531
532 if !tool_use.is_input_complete {
533 return None;
534 }
535
536 let Some(tool) = tool else {
537 let content = format!("No tool named {} exists", tool_use.name);
538 return Some(Task::ready(LanguageModelToolResult {
539 content: LanguageModelToolResultContent::Text(Arc::from(content)),
540 tool_use_id: tool_use.id,
541 tool_name: tool_use.name,
542 is_error: true,
543 output: None,
544 }));
545 };
546
547 let fs = self.project.read(cx).fs().clone();
548 let tool_event_stream =
549 ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
550 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
551 status: Some(acp::ToolCallStatus::InProgress),
552 ..Default::default()
553 });
554 let supports_images = self.selected_model.supports_images();
555 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
556 Some(cx.foreground_executor().spawn(async move {
557 let tool_result = tool_result.await.and_then(|output| {
558 if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
559 if !supports_images {
560 return Err(anyhow!(
561 "Attempted to read an image, but this model doesn't support it.",
562 ));
563 }
564 }
565 Ok(output)
566 });
567
568 match tool_result {
569 Ok(output) => LanguageModelToolResult {
570 tool_use_id: tool_use.id,
571 tool_name: tool_use.name,
572 is_error: false,
573 content: output.llm_output,
574 output: Some(output.raw_output),
575 },
576 Err(error) => LanguageModelToolResult {
577 tool_use_id: tool_use.id,
578 tool_name: tool_use.name,
579 is_error: true,
580 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
581 output: None,
582 },
583 }
584 }))
585 }
586
587 fn handle_tool_use_json_parse_error_event(
588 &mut self,
589 tool_use_id: LanguageModelToolUseId,
590 tool_name: Arc<str>,
591 raw_input: Arc<str>,
592 json_parse_error: String,
593 ) -> LanguageModelToolResult {
594 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
595 LanguageModelToolResult {
596 tool_use_id,
597 tool_name,
598 is_error: true,
599 content: LanguageModelToolResultContent::Text(tool_output.into()),
600 output: Some(serde_json::Value::String(raw_input.to_string())),
601 }
602 }
603
604 /// Guarantees the last message is from the assistant and returns a mutable reference.
605 fn last_assistant_message(&mut self) -> &mut AgentMessage {
606 if self
607 .messages
608 .last()
609 .map_or(true, |m| m.role != Role::Assistant)
610 {
611 self.messages.push(AgentMessage {
612 role: Role::Assistant,
613 content: Vec::new(),
614 });
615 }
616 self.messages.last_mut().unwrap()
617 }
618
619 /// Guarantees the last message is from the user and returns a mutable reference.
620 fn last_user_message(&mut self) -> &mut AgentMessage {
621 if self.messages.last().map_or(true, |m| m.role != Role::User) {
622 self.messages.push(AgentMessage {
623 role: Role::User,
624 content: Vec::new(),
625 });
626 }
627 self.messages.last_mut().unwrap()
628 }
629
630 pub(crate) fn build_completion_request(
631 &self,
632 completion_intent: CompletionIntent,
633 cx: &mut App,
634 ) -> LanguageModelRequest {
635 log::debug!("Building completion request");
636 log::debug!("Completion intent: {:?}", completion_intent);
637 log::debug!("Completion mode: {:?}", self.completion_mode);
638
639 let messages = self.build_request_messages();
640 log::info!("Request will include {} messages", messages.len());
641
642 let tools = if let Some(tools) = self.tools(cx).log_err() {
643 tools
644 .filter_map(|tool| {
645 let tool_name = tool.name().to_string();
646 log::trace!("Including tool: {}", tool_name);
647 Some(LanguageModelRequestTool {
648 name: tool_name,
649 description: tool.description().to_string(),
650 input_schema: tool
651 .input_schema(self.selected_model.tool_input_format())
652 .log_err()?,
653 })
654 })
655 .collect()
656 } else {
657 Vec::new()
658 };
659
660 log::info!("Request includes {} tools", tools.len());
661
662 let request = LanguageModelRequest {
663 thread_id: None,
664 prompt_id: None,
665 intent: Some(completion_intent),
666 mode: Some(self.completion_mode),
667 messages,
668 tools,
669 tool_choice: None,
670 stop: Vec::new(),
671 temperature: None,
672 thinking_allowed: true,
673 };
674
675 log::debug!("Completion request built successfully");
676 request
677 }
678
679 fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
680 let profile = AgentSettings::get_global(cx)
681 .profiles
682 .get(&self.profile_id)
683 .context("profile not found")?;
684
685 Ok(self
686 .tools
687 .iter()
688 .filter_map(|(tool_name, tool)| {
689 if profile.is_tool_enabled(tool_name) {
690 Some(tool)
691 } else {
692 None
693 }
694 })
695 .chain(self.context_server_registry.read(cx).servers().flat_map(
696 |(server_id, tools)| {
697 tools.iter().filter_map(|(tool_name, tool)| {
698 if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
699 Some(tool)
700 } else {
701 None
702 }
703 })
704 },
705 )))
706 }
707
708 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
709 log::trace!(
710 "Building request messages from {} thread messages",
711 self.messages.len()
712 );
713
714 let messages = Some(self.build_system_message())
715 .iter()
716 .chain(self.messages.iter())
717 .map(|message| {
718 log::trace!(
719 " - {} message with {} content items",
720 match message.role {
721 Role::System => "System",
722 Role::User => "User",
723 Role::Assistant => "Assistant",
724 },
725 message.content.len()
726 );
727 message.to_request()
728 })
729 .collect();
730 messages
731 }
732
733 pub fn to_markdown(&self) -> String {
734 let mut markdown = String::new();
735 for message in &self.messages {
736 markdown.push_str(&message.to_markdown());
737 }
738 markdown
739 }
740}
741
/// The content of a user message passed to [`Thread::send`].
pub struct UserMessage(Vec<MessageContent>);
743
744impl From<Vec<MessageContent>> for UserMessage {
745 fn from(content: Vec<MessageContent>) -> Self {
746 UserMessage(content)
747 }
748}
749
750impl<T: Into<MessageContent>> From<T> for UserMessage {
751 fn from(content: T) -> Self {
752 UserMessage(vec![content.into()])
753 }
754}
755
/// A strongly typed tool the agent can call.
///
/// Implementors describe their input with a JSON-schema-capable type and return
/// a serializable output; [`AgentTool::erase`] wraps them into the object-safe
/// [`AnyAgentTool`] that a [`Thread`] stores.
pub trait AgentTool
where
    Self: 'static + Sized,
{
    /// Typed input, deserialized from the JSON arguments supplied by the model.
    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
    /// Typed output; it is serialized as the raw output and converted into the
    /// content sent back to the model.
    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;

    /// Name the tool is registered and advertised under.
    fn name(&self) -> SharedString;

    /// Description sent to the model.
    ///
    /// Defaults to the top-level `description` of the input's JSON schema, or an
    /// empty string when the schema has none (presumably derived from the input
    /// type's doc comment by schemars — confirm per tool).
    fn description(&self) -> SharedString {
        let schema = schemars::schema_for!(Self::Input);
        SharedString::new(
            schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap_or_default(),
        )
    }

    /// Kind of tool, reported to the client with each tool call.
    fn kind(&self) -> acp::ToolKind;

    /// The initial tool title to display. Can be updated during the tool run.
    ///
    /// Receives the parsed input, or the raw JSON when it doesn't parse (for
    /// example while the input is still streaming in).
    fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;

    /// Returns the JSON schema that describes the tool's input.
    fn input_schema(&self) -> Schema {
        schemars::schema_for!(Self::Input)
    }

    /// Runs the tool with the provided input.
    fn run(
        self: Arc<Self>,
        input: Self::Input,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<Self::Output>>;

    /// Type-erases the tool so it can be stored alongside tools of other types.
    fn erase(self) -> Arc<dyn AnyAgentTool> {
        Arc::new(Erased(Arc::new(self)))
    }
}
797
/// Wrapper that adapts an [`AgentTool`] (held as `Arc<T>`) to [`AnyAgentTool`].
pub struct Erased<T>(T);
799
/// Result of running a type-erased tool.
pub struct AgentToolOutput {
    /// Content sent back to the model.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's output serialized as JSON, reported as the tool call's raw output.
    pub raw_output: serde_json::Value,
}
804
/// Object-safe, JSON-in/JSON-out view of an [`AgentTool`], used by [`Thread`]
/// to store and invoke tools of different types uniformly.
pub trait AnyAgentTool {
    /// Name the tool is advertised under.
    fn name(&self) -> SharedString;
    /// Description sent to the model.
    fn description(&self) -> SharedString;
    /// Kind of tool reported to the client.
    fn kind(&self) -> acp::ToolKind;
    /// Title to show for a call with the given (possibly incomplete) JSON input.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// Input JSON schema adapted to the schema format a model accepts.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Runs the tool on JSON input.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
}
818
819impl<T> AnyAgentTool for Erased<Arc<T>>
820where
821 T: AgentTool,
822{
823 fn name(&self) -> SharedString {
824 self.0.name()
825 }
826
827 fn description(&self) -> SharedString {
828 self.0.description()
829 }
830
831 fn kind(&self) -> agent_client_protocol::ToolKind {
832 self.0.kind()
833 }
834
835 fn initial_title(&self, input: serde_json::Value) -> SharedString {
836 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
837 self.0.initial_title(parsed_input)
838 }
839
840 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
841 let mut json = serde_json::to_value(self.0.input_schema())?;
842 adapt_schema_to_format(&mut json, format)?;
843 Ok(json)
844 }
845
846 fn run(
847 self: Arc<Self>,
848 input: serde_json::Value,
849 event_stream: ToolCallEventStream,
850 cx: &mut App,
851 ) -> Task<Result<AgentToolOutput>> {
852 cx.spawn(async move |cx| {
853 let input = serde_json::from_value(input)?;
854 let output = cx
855 .update(|cx| self.0.clone().run(input, event_stream, cx))?
856 .await?;
857 let raw_output = serde_json::to_value(&output)?;
858 Ok(AgentToolOutput {
859 llm_output: output.into(),
860 raw_output,
861 })
862 })
863 }
864}
865
/// Sending half of the channel returned by [`Thread::send`]; send failures
/// (a dropped receiver) are ignored by all of its methods.
#[derive(Clone)]
struct AgentResponseEventStream(
    mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
870
871impl AgentResponseEventStream {
872 fn send_text(&self, text: &str) {
873 self.0
874 .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
875 .ok();
876 }
877
878 fn send_thinking(&self, text: &str) {
879 self.0
880 .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
881 .ok();
882 }
883
884 fn send_tool_call(
885 &self,
886 id: &LanguageModelToolUseId,
887 title: SharedString,
888 kind: acp::ToolKind,
889 input: serde_json::Value,
890 ) {
891 self.0
892 .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
893 id,
894 title.to_string(),
895 kind,
896 input,
897 ))))
898 .ok();
899 }
900
901 fn initial_tool_call(
902 id: &LanguageModelToolUseId,
903 title: String,
904 kind: acp::ToolKind,
905 input: serde_json::Value,
906 ) -> acp::ToolCall {
907 acp::ToolCall {
908 id: acp::ToolCallId(id.to_string().into()),
909 title,
910 kind,
911 status: acp::ToolCallStatus::Pending,
912 content: vec![],
913 locations: vec![],
914 raw_input: Some(input),
915 raw_output: None,
916 }
917 }
918
919 fn update_tool_call_fields(
920 &self,
921 tool_use_id: &LanguageModelToolUseId,
922 fields: acp::ToolCallUpdateFields,
923 ) {
924 self.0
925 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
926 acp::ToolCallUpdate {
927 id: acp::ToolCallId(tool_use_id.to_string().into()),
928 fields,
929 }
930 .into(),
931 )))
932 .ok();
933 }
934
935 fn send_stop(&self, reason: StopReason) {
936 match reason {
937 StopReason::EndTurn => {
938 self.0
939 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
940 .ok();
941 }
942 StopReason::MaxTokens => {
943 self.0
944 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
945 .ok();
946 }
947 StopReason::Refusal => {
948 self.0
949 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
950 .ok();
951 }
952 StopReason::ToolUse => {}
953 }
954 }
955
956 fn send_error(&self, error: LanguageModelCompletionError) {
957 self.0.unbounded_send(Err(error)).ok();
958 }
959}
960
/// Handle given to a running tool so it can report progress on its own tool call
/// and request permission from the user.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// Id of the tool use this stream reports on.
    tool_use_id: LanguageModelToolUseId,
    /// Tool kind, reused when building the authorization request.
    kind: acp::ToolKind,
    /// Tool input, reused when building the authorization request.
    input: serde_json::Value,
    /// Thread-level event channel the updates are sent through.
    stream: AgentResponseEventStream,
    /// Used to persist "always allow" to the settings file; `None` skips persisting.
    fs: Option<Arc<dyn Fs>>,
}
969
970impl ToolCallEventStream {
971 #[cfg(test)]
972 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
973 let (events_tx, events_rx) =
974 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
975
976 let stream = ToolCallEventStream::new(
977 &LanguageModelToolUse {
978 id: "test_id".into(),
979 name: "test_tool".into(),
980 raw_input: String::new(),
981 input: serde_json::Value::Null,
982 is_input_complete: true,
983 },
984 acp::ToolKind::Other,
985 AgentResponseEventStream(events_tx),
986 None,
987 );
988
989 (stream, ToolCallEventStreamReceiver(events_rx))
990 }
991
992 fn new(
993 tool_use: &LanguageModelToolUse,
994 kind: acp::ToolKind,
995 stream: AgentResponseEventStream,
996 fs: Option<Arc<dyn Fs>>,
997 ) -> Self {
998 Self {
999 tool_use_id: tool_use.id.clone(),
1000 kind,
1001 input: tool_use.input.clone(),
1002 stream,
1003 fs,
1004 }
1005 }
1006
1007 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1008 self.stream
1009 .update_tool_call_fields(&self.tool_use_id, fields);
1010 }
1011
1012 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1013 self.stream
1014 .0
1015 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1016 acp_thread::ToolCallUpdateDiff {
1017 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1018 diff,
1019 }
1020 .into(),
1021 )))
1022 .ok();
1023 }
1024
1025 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1026 self.stream
1027 .0
1028 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1029 acp_thread::ToolCallUpdateTerminal {
1030 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1031 terminal,
1032 }
1033 .into(),
1034 )))
1035 .ok();
1036 }
1037
1038 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1039 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1040 return Task::ready(Ok(()));
1041 }
1042
1043 let (response_tx, response_rx) = oneshot::channel();
1044 self.stream
1045 .0
1046 .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
1047 ToolCallAuthorization {
1048 tool_call: AgentResponseEventStream::initial_tool_call(
1049 &self.tool_use_id,
1050 title.into(),
1051 self.kind.clone(),
1052 self.input.clone(),
1053 ),
1054 options: vec![
1055 acp::PermissionOption {
1056 id: acp::PermissionOptionId("always_allow".into()),
1057 name: "Always Allow".into(),
1058 kind: acp::PermissionOptionKind::AllowAlways,
1059 },
1060 acp::PermissionOption {
1061 id: acp::PermissionOptionId("allow".into()),
1062 name: "Allow".into(),
1063 kind: acp::PermissionOptionKind::AllowOnce,
1064 },
1065 acp::PermissionOption {
1066 id: acp::PermissionOptionId("deny".into()),
1067 name: "Deny".into(),
1068 kind: acp::PermissionOptionKind::RejectOnce,
1069 },
1070 ],
1071 response: response_tx,
1072 },
1073 )))
1074 .ok();
1075 let fs = self.fs.clone();
1076 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1077 "always_allow" => {
1078 if let Some(fs) = fs.clone() {
1079 cx.update(|cx| {
1080 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1081 settings.set_always_allow_tool_actions(true);
1082 });
1083 })?;
1084 }
1085
1086 Ok(())
1087 }
1088 "allow" => Ok(()),
1089 _ => Err(anyhow!("Permission to run tool denied by user")),
1090 })
1091 }
1092}
1093
/// Test-only receiving end of a [`ToolCallEventStream`] created by
/// [`ToolCallEventStream::test`].
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(
    mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
1098
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it as an authorization request.
    ///
    /// # Panics
    /// Panics if the next event is anything else or the stream has ended.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) => auth,
            event => panic!("Expected ToolCallAuthorization but got: {:?}", event),
        }
    }

    /// Awaits the next event and returns the terminal it attaches.
    ///
    /// # Panics
    /// Panics if the next event is not a terminal update or the stream has ended.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            event => panic!("Expected terminal but got: {:?}", event),
        }
    }
}
1122
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;

    /// Exposes the wrapped receiver so tests can call its methods directly.
    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1131
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    /// Mutable access to the wrapped receiver, e.g. to poll it as a stream.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1138
1139impl AgentMessage {
1140 fn to_request(&self) -> language_model::LanguageModelRequestMessage {
1141 let mut message = LanguageModelRequestMessage {
1142 role: self.role,
1143 content: Vec::with_capacity(self.content.len()),
1144 cache: false,
1145 };
1146
1147 const OPEN_CONTEXT: &str = "<context>\n\
1148 The following items were attached by the user. \
1149 They are up-to-date and don't need to be re-read.\n\n";
1150
1151 const OPEN_FILES_TAG: &str = "<files>";
1152 const OPEN_SYMBOLS_TAG: &str = "<symbols>";
1153 const OPEN_THREADS_TAG: &str = "<threads>";
1154 const OPEN_RULES_TAG: &str =
1155 "<rules>\nThe user has specified the following rules that should be applied:\n";
1156
1157 let mut file_context = OPEN_FILES_TAG.to_string();
1158 let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
1159 let mut thread_context = OPEN_THREADS_TAG.to_string();
1160 let mut rules_context = OPEN_RULES_TAG.to_string();
1161
1162 for chunk in &self.content {
1163 let chunk = match chunk {
1164 MessageContent::Text(text) => language_model::MessageContent::Text(text.clone()),
1165 MessageContent::Thinking { text, signature } => {
1166 language_model::MessageContent::Thinking {
1167 text: text.clone(),
1168 signature: signature.clone(),
1169 }
1170 }
1171 MessageContent::RedactedThinking(value) => {
1172 language_model::MessageContent::RedactedThinking(value.clone())
1173 }
1174 MessageContent::ToolUse(value) => {
1175 language_model::MessageContent::ToolUse(value.clone())
1176 }
1177 MessageContent::ToolResult(value) => {
1178 language_model::MessageContent::ToolResult(value.clone())
1179 }
1180 MessageContent::Image(value) => {
1181 language_model::MessageContent::Image(value.clone())
1182 }
1183 MessageContent::Mention { uri, content } => {
1184 match uri {
1185 MentionUri::File(path) | MentionUri::Symbol(path, _) => {
1186 write!(
1187 &mut symbol_context,
1188 "\n{}",
1189 MarkdownCodeBlock {
1190 tag: &codeblock_tag(&path),
1191 text: &content.to_string(),
1192 }
1193 )
1194 .ok();
1195 }
1196 MentionUri::Thread(_session_id) => {
1197 write!(&mut thread_context, "\n{}\n", content).ok();
1198 }
1199 MentionUri::Rule(_user_prompt_id) => {
1200 write!(
1201 &mut rules_context,
1202 "\n{}",
1203 MarkdownCodeBlock {
1204 tag: "",
1205 text: &content
1206 }
1207 )
1208 .ok();
1209 }
1210 }
1211
1212 language_model::MessageContent::Text(uri.to_link())
1213 }
1214 };
1215
1216 message.content.push(chunk);
1217 }
1218
1219 let len_before_context = message.content.len();
1220
1221 if file_context.len() > OPEN_FILES_TAG.len() {
1222 file_context.push_str("</files>\n");
1223 message
1224 .content
1225 .push(language_model::MessageContent::Text(file_context));
1226 }
1227
1228 if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
1229 symbol_context.push_str("</symbols>\n");
1230 message
1231 .content
1232 .push(language_model::MessageContent::Text(symbol_context));
1233 }
1234
1235 if thread_context.len() > OPEN_THREADS_TAG.len() {
1236 thread_context.push_str("</threads>\n");
1237 message
1238 .content
1239 .push(language_model::MessageContent::Text(thread_context));
1240 }
1241
1242 if rules_context.len() > OPEN_RULES_TAG.len() {
1243 rules_context.push_str("</user_rules>\n");
1244 message
1245 .content
1246 .push(language_model::MessageContent::Text(rules_context));
1247 }
1248
1249 if message.content.len() > len_before_context {
1250 message.content.insert(
1251 len_before_context,
1252 language_model::MessageContent::Text(OPEN_CONTEXT.into()),
1253 );
1254 message
1255 .content
1256 .push(language_model::MessageContent::Text("</context>".into()));
1257 }
1258
1259 message
1260 }
1261}
1262
/// Builds the info string for a Markdown fenced code block that shows `full_path`.
///
/// When the path has a UTF-8 extension, the result is the extension, a space, and
/// the displayed path (e.g. `rs src/main.rs`). Otherwise it is just the displayed
/// path.
fn codeblock_tag(full_path: &Path) -> String {
    let shown = full_path.display();
    match full_path.extension().and_then(|ext| ext.to_str()) {
        Some(language) => format!("{language} {shown}"),
        None => shown.to_string(),
    }
}
1274
1275impl From<acp::ContentBlock> for MessageContent {
1276 fn from(value: acp::ContentBlock) -> Self {
1277 match value {
1278 acp::ContentBlock::Text(text_content) => MessageContent::Text(text_content.text),
1279 acp::ContentBlock::Image(image_content) => {
1280 MessageContent::Image(convert_image(image_content))
1281 }
1282 acp::ContentBlock::Audio(_) => {
1283 // TODO
1284 MessageContent::Text("[audio]".to_string())
1285 }
1286 acp::ContentBlock::ResourceLink(resource_link) => {
1287 match MentionUri::parse(&resource_link.uri) {
1288 Ok(uri) => Self::Mention {
1289 uri,
1290 content: String::new(),
1291 },
1292 Err(err) => {
1293 log::error!("Failed to parse mention link: {}", err);
1294 MessageContent::Text(format!(
1295 "[{}]({})",
1296 resource_link.name, resource_link.uri
1297 ))
1298 }
1299 }
1300 }
1301 acp::ContentBlock::Resource(resource) => match resource.resource {
1302 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1303 match MentionUri::parse(&resource.uri) {
1304 Ok(uri) => Self::Mention {
1305 uri,
1306 content: resource.text,
1307 },
1308 Err(err) => {
1309 log::error!("Failed to parse mention link: {}", err);
1310 MessageContent::Text(
1311 MarkdownCodeBlock {
1312 tag: &resource.uri,
1313 text: &resource.text,
1314 }
1315 .to_string(),
1316 )
1317 }
1318 }
1319 }
1320 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1321 // TODO
1322 MessageContent::Text("[blob]".to_string())
1323 }
1324 },
1325 }
1326 }
1327}
1328
1329fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1330 LanguageModelImage {
1331 source: image_content.data.into(),
1332 // TODO: make this optional?
1333 size: gpui::Size::new(0.into(), 0.into()),
1334 }
1335}
1336
1337impl From<&str> for MessageContent {
1338 fn from(text: &str) -> Self {
1339 MessageContent::Text(text.into())
1340 }
1341}