1use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
2use action_log::ActionLog;
3use agent_client_protocol as acp;
4use agent_settings::{AgentProfileId, AgentSettings};
5use anyhow::{Context as _, Result, anyhow};
6use assistant_tool::adapt_schema_to_format;
7use cloud_llm_client::{CompletionIntent, CompletionMode};
8use collections::HashMap;
9use fs::Fs;
10use futures::{
11 channel::{mpsc, oneshot},
12 stream::FuturesUnordered,
13};
14use gpui::{App, Context, Entity, SharedString, Task};
15use language_model::{
16 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
17 LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
18 LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
19 LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
20};
21use log;
22use project::Project;
23use prompt_store::ProjectContext;
24use schemars::{JsonSchema, Schema};
25use serde::{Deserialize, Serialize};
26use settings::{Settings, update_settings_file};
27use smol::stream::StreamExt;
28use std::{cell::RefCell, collections::BTreeMap, fmt::Write, rc::Rc, sync::Arc};
29use util::{ResultExt, markdown::MarkdownCodeBlock};
30
/// A single message in an agent [`Thread`]: the role it is attributed to plus
/// its content items, in the order they were appended.
#[derive(Debug, Clone)]
pub struct AgentMessage {
    /// The role (system, user or assistant) this message is attributed to.
    pub role: Role,
    /// Ordered content items: text, thinking, images, tool uses and tool results.
    pub content: Vec<MessageContent>,
}
36
37impl AgentMessage {
38 pub fn to_markdown(&self) -> String {
39 let mut markdown = format!("## {}\n", self.role);
40
41 for content in &self.content {
42 match content {
43 MessageContent::Text(text) => {
44 markdown.push_str(text);
45 markdown.push('\n');
46 }
47 MessageContent::Thinking { text, .. } => {
48 markdown.push_str("<think>");
49 markdown.push_str(text);
50 markdown.push_str("</think>\n");
51 }
52 MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
53 MessageContent::Image(_) => {
54 markdown.push_str("<image />\n");
55 }
56 MessageContent::ToolUse(tool_use) => {
57 markdown.push_str(&format!(
58 "**Tool Use**: {} (ID: {})\n",
59 tool_use.name, tool_use.id
60 ));
61 markdown.push_str(&format!(
62 "{}\n",
63 MarkdownCodeBlock {
64 tag: "json",
65 text: &format!("{:#}", tool_use.input)
66 }
67 ));
68 }
69 MessageContent::ToolResult(tool_result) => {
70 markdown.push_str(&format!(
71 "**Tool Result**: {} (ID: {})\n\n",
72 tool_result.tool_name, tool_result.tool_use_id
73 ));
74 if tool_result.is_error {
75 markdown.push_str("**ERROR:**\n");
76 }
77
78 match &tool_result.content {
79 LanguageModelToolResultContent::Text(text) => {
80 writeln!(markdown, "{text}\n").ok();
81 }
82 LanguageModelToolResultContent::Image(_) => {
83 writeln!(markdown, "<image />\n").ok();
84 }
85 }
86
87 if let Some(output) = tool_result.output.as_ref() {
88 writeln!(
89 markdown,
90 "**Debug Output**:\n\n```json\n{}\n```\n",
91 serde_json::to_string_pretty(output).unwrap()
92 )
93 .unwrap();
94 }
95 }
96 }
97 }
98
99 markdown
100 }
101}
102
/// Events reported to the receiver returned by [`Thread::send`] while a turn runs.
#[derive(Debug)]
pub enum AgentResponseEvent {
    /// A chunk of streamed assistant text.
    Text(String),
    /// A chunk of streamed thinking text.
    Thinking(String),
    /// A tool call seen for the first time.
    ToolCall(acp::ToolCall),
    /// An update to a previously reported tool call (fields, diff or terminal).
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A request for permission to run a tool; answered through the contained channel.
    ToolCallAuthorization(ToolCallAuthorization),
    /// The model stopped for the given reason.
    Stop(acp::StopReason),
}
112
/// A pending permission request for a tool call.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call that is asking for permission.
    pub tool_call: acp::ToolCall,
    /// The choices that may be offered to the user.
    pub options: Vec<acp::PermissionOption>,
    /// Channel on which the ID of the chosen option is sent back.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
119
/// A conversation between the user and a language model, including the tools
/// the model may call.
pub struct Thread {
    /// User and assistant messages recorded so far. The system message is not
    /// stored here; it is rebuilt for every request.
    messages: Vec<AgentMessage>,
    /// Completion mode attached to every request.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<Task<()>>,
    /// Tool uses received from the model whose results have not been recorded yet.
    pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
    /// Tools registered through `add_tool`, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Source of additional tools provided by context servers.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Profile that decides which tools are enabled for requests.
    profile_id: AgentProfileId,
    /// Project information rendered into the system prompt.
    project_context: Rc<RefCell<ProjectContext>>,
    /// Templates used to render the system prompt.
    templates: Arc<Templates>,
    /// Model used for completions.
    pub selected_model: Arc<dyn LanguageModel>,
    /// Project this thread operates on.
    project: Entity<Project>,
    /// Action log associated with this thread.
    action_log: Entity<ActionLog>,
}
137
138impl Thread {
139 pub fn new(
140 project: Entity<Project>,
141 project_context: Rc<RefCell<ProjectContext>>,
142 context_server_registry: Entity<ContextServerRegistry>,
143 action_log: Entity<ActionLog>,
144 templates: Arc<Templates>,
145 default_model: Arc<dyn LanguageModel>,
146 cx: &mut Context<Self>,
147 ) -> Self {
148 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
149 Self {
150 messages: Vec::new(),
151 completion_mode: CompletionMode::Normal,
152 running_turn: None,
153 pending_tool_uses: HashMap::default(),
154 tools: BTreeMap::default(),
155 context_server_registry,
156 profile_id,
157 project_context,
158 templates,
159 selected_model: default_model,
160 project,
161 action_log,
162 }
163 }
164
/// Returns the project this thread operates on.
pub fn project(&self) -> &Entity<Project> {
    &self.project
}
168
/// Returns the action log associated with this thread.
pub fn action_log(&self) -> &Entity<ActionLog> {
    &self.action_log
}
172
/// Sets the completion mode attached to subsequently built requests.
pub fn set_mode(&mut self, mode: CompletionMode) {
    self.completion_mode = mode;
}
176
177 pub fn messages(&self) -> &[AgentMessage] {
178 &self.messages
179 }
180
181 pub fn add_tool(&mut self, tool: impl AgentTool) {
182 self.tools.insert(tool.name(), tool.erase());
183 }
184
/// Unregisters the tool called `name`; returns `true` if such a tool was registered.
pub fn remove_tool(&mut self, name: &str) -> bool {
    self.tools.remove(name).is_some()
}
188
/// Selects the profile used to decide which tools are enabled for requests.
pub fn set_profile(&mut self, profile_id: AgentProfileId) {
    self.profile_id = profile_id;
}
192
193 pub fn cancel(&mut self) {
194 self.running_turn.take();
195
196 let tool_results = self
197 .pending_tool_uses
198 .drain()
199 .map(|(tool_use_id, tool_use)| {
200 MessageContent::ToolResult(LanguageModelToolResult {
201 tool_use_id,
202 tool_name: tool_use.name.clone(),
203 is_error: true,
204 content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
205 output: None,
206 })
207 })
208 .collect::<Vec<_>>();
209 self.last_user_message().content.extend(tool_results);
210 }
211
212 /// Sending a message results in the model streaming a response, which could include tool calls.
213 /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
214 /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
215 pub fn send(
216 &mut self,
217 content: impl Into<MessageContent>,
218 cx: &mut Context<Self>,
219 ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
220 let content = content.into();
221 let model = self.selected_model.clone();
222 log::info!("Thread::send called with model: {:?}", model.name());
223 log::debug!("Thread::send content: {:?}", content);
224
225 cx.notify();
226 let (events_tx, events_rx) =
227 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
228 let event_stream = AgentResponseEventStream(events_tx);
229
230 let user_message_ix = self.messages.len();
231 self.messages.push(AgentMessage {
232 role: Role::User,
233 content: vec![content],
234 });
235 log::info!("Total messages in thread: {}", self.messages.len());
236 self.running_turn = Some(cx.spawn(async move |thread, cx| {
237 log::info!("Starting agent turn execution");
238 let turn_result = async {
239 // Perform one request, then keep looping if the model makes tool calls.
240 let mut completion_intent = CompletionIntent::UserPrompt;
241 'outer: loop {
242 log::debug!(
243 "Building completion request with intent: {:?}",
244 completion_intent
245 );
246 let request = thread.update(cx, |thread, cx| {
247 thread.build_completion_request(completion_intent, cx)
248 })?;
249
250 // println!(
251 // "request: {}",
252 // serde_json::to_string_pretty(&request).unwrap()
253 // );
254
255 // Stream events, appending to messages and collecting up tool uses.
256 log::info!("Calling model.stream_completion");
257 let mut events = model.stream_completion(request, cx).await?;
258 log::debug!("Stream completion started successfully");
259 let mut tool_uses = FuturesUnordered::new();
260 while let Some(event) = events.next().await {
261 match event {
262 Ok(LanguageModelCompletionEvent::Stop(reason)) => {
263 event_stream.send_stop(reason);
264 if reason == StopReason::Refusal {
265 thread.update(cx, |thread, _cx| {
266 thread.messages.truncate(user_message_ix);
267 })?;
268 break 'outer;
269 }
270 }
271 Ok(event) => {
272 log::trace!("Received completion event: {:?}", event);
273 thread
274 .update(cx, |thread, cx| {
275 tool_uses.extend(thread.handle_streamed_completion_event(
276 event,
277 &event_stream,
278 cx,
279 ));
280 })
281 .ok();
282 }
283 Err(error) => {
284 log::error!("Error in completion stream: {:?}", error);
285 event_stream.send_error(error);
286 break;
287 }
288 }
289 }
290
291 // If there are no tool uses, the turn is done.
292 if tool_uses.is_empty() {
293 log::info!("No tool uses found, completing turn");
294 break;
295 }
296 log::info!("Found {} tool uses to execute", tool_uses.len());
297
298 // As tool results trickle in, insert them in the last user
299 // message so that they can be sent on the next tick of the
300 // agentic loop.
301 while let Some(tool_result) = tool_uses.next().await {
302 log::info!("Tool finished {:?}", tool_result);
303
304 event_stream.update_tool_call_fields(
305 &tool_result.tool_use_id,
306 acp::ToolCallUpdateFields {
307 status: Some(if tool_result.is_error {
308 acp::ToolCallStatus::Failed
309 } else {
310 acp::ToolCallStatus::Completed
311 }),
312 raw_output: tool_result.output.clone(),
313 ..Default::default()
314 },
315 );
316 thread
317 .update(cx, |thread, _cx| {
318 thread.pending_tool_uses.remove(&tool_result.tool_use_id);
319 thread
320 .last_user_message()
321 .content
322 .push(MessageContent::ToolResult(tool_result));
323 })
324 .ok();
325 }
326
327 completion_intent = CompletionIntent::ToolResults;
328 }
329
330 Ok(())
331 }
332 .await;
333
334 if let Err(error) = turn_result {
335 log::error!("Turn execution failed: {:?}", error);
336 event_stream.send_error(error);
337 } else {
338 log::info!("Turn execution completed successfully");
339 }
340 }));
341 events_rx
342 }
343
344 pub fn build_system_message(&self) -> AgentMessage {
345 log::debug!("Building system message");
346 let prompt = SystemPromptTemplate {
347 project: &self.project_context.borrow(),
348 available_tools: self.tools.keys().cloned().collect(),
349 }
350 .render(&self.templates)
351 .context("failed to build system prompt")
352 .expect("Invalid template");
353 log::debug!("System message built");
354 AgentMessage {
355 role: Role::System,
356 content: vec![prompt.into()],
357 }
358 }
359
360 /// A helper method that's called on every streamed completion event.
361 /// Returns an optional tool result task, which the main agentic loop in
362 /// send will send back to the model when it resolves.
363 fn handle_streamed_completion_event(
364 &mut self,
365 event: LanguageModelCompletionEvent,
366 event_stream: &AgentResponseEventStream,
367 cx: &mut Context<Self>,
368 ) -> Option<Task<LanguageModelToolResult>> {
369 log::trace!("Handling streamed completion event: {:?}", event);
370 use LanguageModelCompletionEvent::*;
371
372 match event {
373 StartMessage { .. } => {
374 self.messages.push(AgentMessage {
375 role: Role::Assistant,
376 content: Vec::new(),
377 });
378 }
379 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
380 Thinking { text, signature } => {
381 self.handle_thinking_event(text, signature, event_stream, cx)
382 }
383 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
384 ToolUse(tool_use) => {
385 return self.handle_tool_use_event(tool_use, event_stream, cx);
386 }
387 ToolUseJsonParseError {
388 id,
389 tool_name,
390 raw_input,
391 json_parse_error,
392 } => {
393 return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
394 id,
395 tool_name,
396 raw_input,
397 json_parse_error,
398 )));
399 }
400 UsageUpdate(_) | StatusUpdate(_) => {}
401 Stop(_) => unreachable!(),
402 }
403
404 None
405 }
406
407 fn handle_text_event(
408 &mut self,
409 new_text: String,
410 events_stream: &AgentResponseEventStream,
411 cx: &mut Context<Self>,
412 ) {
413 events_stream.send_text(&new_text);
414
415 let last_message = self.last_assistant_message();
416 if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
417 text.push_str(&new_text);
418 } else {
419 last_message.content.push(MessageContent::Text(new_text));
420 }
421
422 cx.notify();
423 }
424
425 fn handle_thinking_event(
426 &mut self,
427 new_text: String,
428 new_signature: Option<String>,
429 event_stream: &AgentResponseEventStream,
430 cx: &mut Context<Self>,
431 ) {
432 event_stream.send_thinking(&new_text);
433
434 let last_message = self.last_assistant_message();
435 if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
436 {
437 text.push_str(&new_text);
438 *signature = new_signature.or(signature.take());
439 } else {
440 last_message.content.push(MessageContent::Thinking {
441 text: new_text,
442 signature: new_signature,
443 });
444 }
445
446 cx.notify();
447 }
448
449 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
450 let last_message = self.last_assistant_message();
451 last_message
452 .content
453 .push(MessageContent::RedactedThinking(data));
454 cx.notify();
455 }
456
457 fn handle_tool_use_event(
458 &mut self,
459 tool_use: LanguageModelToolUse,
460 event_stream: &AgentResponseEventStream,
461 cx: &mut Context<Self>,
462 ) -> Option<Task<LanguageModelToolResult>> {
463 cx.notify();
464
465 let tool = self.tools.get(tool_use.name.as_ref()).cloned();
466
467 self.pending_tool_uses
468 .insert(tool_use.id.clone(), tool_use.clone());
469 let last_message = self.last_assistant_message();
470
471 // Ensure the last message ends in the current tool use
472 let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
473 if let MessageContent::ToolUse(last_tool_use) = content {
474 if last_tool_use.id == tool_use.id {
475 *last_tool_use = tool_use.clone();
476 false
477 } else {
478 true
479 }
480 } else {
481 true
482 }
483 });
484
485 let mut title = SharedString::from(&tool_use.name);
486 let mut kind = acp::ToolKind::Other;
487 if let Some(tool) = tool.as_ref() {
488 title = tool.initial_title(tool_use.input.clone());
489 kind = tool.kind();
490 }
491
492 if push_new_tool_use {
493 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
494 last_message
495 .content
496 .push(MessageContent::ToolUse(tool_use.clone()));
497 } else {
498 event_stream.update_tool_call_fields(
499 &tool_use.id,
500 acp::ToolCallUpdateFields {
501 title: Some(title.into()),
502 kind: Some(kind),
503 raw_input: Some(tool_use.input.clone()),
504 ..Default::default()
505 },
506 );
507 }
508
509 if !tool_use.is_input_complete {
510 return None;
511 }
512
513 let Some(tool) = tool else {
514 let content = format!("No tool named {} exists", tool_use.name);
515 return Some(Task::ready(LanguageModelToolResult {
516 content: LanguageModelToolResultContent::Text(Arc::from(content)),
517 tool_use_id: tool_use.id,
518 tool_name: tool_use.name,
519 is_error: true,
520 output: None,
521 }));
522 };
523
524 let fs = self.project.read(cx).fs().clone();
525 let tool_event_stream =
526 ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
527 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
528 status: Some(acp::ToolCallStatus::InProgress),
529 ..Default::default()
530 });
531 let supports_images = self.selected_model.supports_images();
532 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
533 Some(cx.foreground_executor().spawn(async move {
534 let tool_result = tool_result.await.and_then(|output| {
535 if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
536 if !supports_images {
537 return Err(anyhow!(
538 "Attempted to read an image, but this model doesn't support it.",
539 ));
540 }
541 }
542 Ok(output)
543 });
544
545 match tool_result {
546 Ok(output) => LanguageModelToolResult {
547 tool_use_id: tool_use.id,
548 tool_name: tool_use.name,
549 is_error: false,
550 content: output.llm_output,
551 output: Some(output.raw_output),
552 },
553 Err(error) => LanguageModelToolResult {
554 tool_use_id: tool_use.id,
555 tool_name: tool_use.name,
556 is_error: true,
557 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
558 output: None,
559 },
560 }
561 }))
562 }
563
564 fn handle_tool_use_json_parse_error_event(
565 &mut self,
566 tool_use_id: LanguageModelToolUseId,
567 tool_name: Arc<str>,
568 raw_input: Arc<str>,
569 json_parse_error: String,
570 ) -> LanguageModelToolResult {
571 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
572 LanguageModelToolResult {
573 tool_use_id,
574 tool_name,
575 is_error: true,
576 content: LanguageModelToolResultContent::Text(tool_output.into()),
577 output: Some(serde_json::Value::String(raw_input.to_string())),
578 }
579 }
580
581 /// Guarantees the last message is from the assistant and returns a mutable reference.
582 fn last_assistant_message(&mut self) -> &mut AgentMessage {
583 if self
584 .messages
585 .last()
586 .map_or(true, |m| m.role != Role::Assistant)
587 {
588 self.messages.push(AgentMessage {
589 role: Role::Assistant,
590 content: Vec::new(),
591 });
592 }
593 self.messages.last_mut().unwrap()
594 }
595
596 /// Guarantees the last message is from the user and returns a mutable reference.
597 fn last_user_message(&mut self) -> &mut AgentMessage {
598 if self.messages.last().map_or(true, |m| m.role != Role::User) {
599 self.messages.push(AgentMessage {
600 role: Role::User,
601 content: Vec::new(),
602 });
603 }
604 self.messages.last_mut().unwrap()
605 }
606
607 pub(crate) fn build_completion_request(
608 &self,
609 completion_intent: CompletionIntent,
610 cx: &mut App,
611 ) -> LanguageModelRequest {
612 log::debug!("Building completion request");
613 log::debug!("Completion intent: {:?}", completion_intent);
614 log::debug!("Completion mode: {:?}", self.completion_mode);
615
616 let messages = self.build_request_messages();
617 log::info!("Request will include {} messages", messages.len());
618
619 let tools = if let Some(tools) = self.tools(cx).log_err() {
620 tools
621 .filter_map(|tool| {
622 let tool_name = tool.name().to_string();
623 log::trace!("Including tool: {}", tool_name);
624 Some(LanguageModelRequestTool {
625 name: tool_name,
626 description: tool.description().to_string(),
627 input_schema: tool
628 .input_schema(self.selected_model.tool_input_format())
629 .log_err()?,
630 })
631 })
632 .collect()
633 } else {
634 Vec::new()
635 };
636
637 log::info!("Request includes {} tools", tools.len());
638
639 let request = LanguageModelRequest {
640 thread_id: None,
641 prompt_id: None,
642 intent: Some(completion_intent),
643 mode: Some(self.completion_mode),
644 messages,
645 tools,
646 tool_choice: None,
647 stop: Vec::new(),
648 temperature: None,
649 thinking_allowed: true,
650 };
651
652 log::debug!("Completion request built successfully");
653 request
654 }
655
656 fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
657 let profile = AgentSettings::get_global(cx)
658 .profiles
659 .get(&self.profile_id)
660 .context("profile not found")?;
661
662 Ok(self
663 .tools
664 .iter()
665 .filter_map(|(tool_name, tool)| {
666 if profile.is_tool_enabled(tool_name) {
667 Some(tool)
668 } else {
669 None
670 }
671 })
672 .chain(self.context_server_registry.read(cx).servers().flat_map(
673 |(server_id, tools)| {
674 tools.iter().filter_map(|(tool_name, tool)| {
675 if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
676 Some(tool)
677 } else {
678 None
679 }
680 })
681 },
682 )))
683 }
684
685 fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
686 log::trace!(
687 "Building request messages from {} thread messages",
688 self.messages.len()
689 );
690
691 let messages = Some(self.build_system_message())
692 .iter()
693 .chain(self.messages.iter())
694 .map(|message| {
695 log::trace!(
696 " - {} message with {} content items",
697 match message.role {
698 Role::System => "System",
699 Role::User => "User",
700 Role::Assistant => "Assistant",
701 },
702 message.content.len()
703 );
704 LanguageModelRequestMessage {
705 role: message.role,
706 content: message.content.clone(),
707 cache: false,
708 }
709 })
710 .collect();
711 messages
712 }
713
714 pub fn to_markdown(&self) -> String {
715 let mut markdown = String::new();
716 for message in &self.messages {
717 markdown.push_str(&message.to_markdown());
718 }
719 markdown
720 }
721}
722
723pub trait AgentTool
724where
725 Self: 'static + Sized,
726{
727 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
728 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
729
730 fn name(&self) -> SharedString;
731
732 fn description(&self) -> SharedString {
733 let schema = schemars::schema_for!(Self::Input);
734 SharedString::new(
735 schema
736 .get("description")
737 .and_then(|description| description.as_str())
738 .unwrap_or_default(),
739 )
740 }
741
742 fn kind(&self) -> acp::ToolKind;
743
744 /// The initial tool title to display. Can be updated during the tool run.
745 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
746
747 /// Returns the JSON schema that describes the tool's input.
748 fn input_schema(&self) -> Schema {
749 schemars::schema_for!(Self::Input)
750 }
751
752 /// Runs the tool with the provided input.
753 fn run(
754 self: Arc<Self>,
755 input: Self::Input,
756 event_stream: ToolCallEventStream,
757 cx: &mut App,
758 ) -> Task<Result<Self::Output>>;
759
760 fn erase(self) -> Arc<dyn AnyAgentTool> {
761 Arc::new(Erased(Arc::new(self)))
762 }
763}
764
/// Wrapper through which a typed [`AgentTool`] is exposed as an [`AnyAgentTool`].
pub struct Erased<T>(T);
766
/// The output of a type-erased tool run.
pub struct AgentToolOutput {
    /// The content handed back to the language model.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's output serialized as JSON.
    pub raw_output: serde_json::Value,
}
771
/// Type-erased counterpart of [`AgentTool`] that works with JSON values, so
/// tools of different input/output types can be stored together.
pub trait AnyAgentTool {
    /// The name under which the tool is offered to the model.
    fn name(&self) -> SharedString;
    /// The description of the tool offered to the model.
    fn description(&self) -> SharedString;
    /// The kind of tool, reported alongside its tool calls.
    fn kind(&self) -> acp::ToolKind;
    /// The initial title to display for a call with the given raw input.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// The JSON schema of the tool's input, adapted to `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Runs the tool with raw JSON input.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
}
785
786impl<T> AnyAgentTool for Erased<Arc<T>>
787where
788 T: AgentTool,
789{
790 fn name(&self) -> SharedString {
791 self.0.name()
792 }
793
794 fn description(&self) -> SharedString {
795 self.0.description()
796 }
797
798 fn kind(&self) -> agent_client_protocol::ToolKind {
799 self.0.kind()
800 }
801
802 fn initial_title(&self, input: serde_json::Value) -> SharedString {
803 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
804 self.0.initial_title(parsed_input)
805 }
806
807 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
808 let mut json = serde_json::to_value(self.0.input_schema())?;
809 adapt_schema_to_format(&mut json, format)?;
810 Ok(json)
811 }
812
813 fn run(
814 self: Arc<Self>,
815 input: serde_json::Value,
816 event_stream: ToolCallEventStream,
817 cx: &mut App,
818 ) -> Task<Result<AgentToolOutput>> {
819 cx.spawn(async move |cx| {
820 let input = serde_json::from_value(input)?;
821 let output = cx
822 .update(|cx| self.0.clone().run(input, event_stream, cx))?
823 .await?;
824 let raw_output = serde_json::to_value(&output)?;
825 Ok(AgentToolOutput {
826 llm_output: output.into(),
827 raw_output,
828 })
829 })
830 }
831}
832
/// Sending half of the channel on which a turn reports its events and errors.
#[derive(Clone)]
struct AgentResponseEventStream(
    mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
837
838impl AgentResponseEventStream {
839 fn send_text(&self, text: &str) {
840 self.0
841 .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
842 .ok();
843 }
844
845 fn send_thinking(&self, text: &str) {
846 self.0
847 .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
848 .ok();
849 }
850
851 fn send_tool_call(
852 &self,
853 id: &LanguageModelToolUseId,
854 title: SharedString,
855 kind: acp::ToolKind,
856 input: serde_json::Value,
857 ) {
858 self.0
859 .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
860 id,
861 title.to_string(),
862 kind,
863 input,
864 ))))
865 .ok();
866 }
867
868 fn initial_tool_call(
869 id: &LanguageModelToolUseId,
870 title: String,
871 kind: acp::ToolKind,
872 input: serde_json::Value,
873 ) -> acp::ToolCall {
874 acp::ToolCall {
875 id: acp::ToolCallId(id.to_string().into()),
876 title,
877 kind,
878 status: acp::ToolCallStatus::Pending,
879 content: vec![],
880 locations: vec![],
881 raw_input: Some(input),
882 raw_output: None,
883 }
884 }
885
886 fn update_tool_call_fields(
887 &self,
888 tool_use_id: &LanguageModelToolUseId,
889 fields: acp::ToolCallUpdateFields,
890 ) {
891 self.0
892 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
893 acp::ToolCallUpdate {
894 id: acp::ToolCallId(tool_use_id.to_string().into()),
895 fields,
896 }
897 .into(),
898 )))
899 .ok();
900 }
901
902 fn send_stop(&self, reason: StopReason) {
903 match reason {
904 StopReason::EndTurn => {
905 self.0
906 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
907 .ok();
908 }
909 StopReason::MaxTokens => {
910 self.0
911 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
912 .ok();
913 }
914 StopReason::Refusal => {
915 self.0
916 .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
917 .ok();
918 }
919 StopReason::ToolUse => {}
920 }
921 }
922
923 fn send_error(&self, error: LanguageModelCompletionError) {
924 self.0.unbounded_send(Err(error)).ok();
925 }
926}
927
/// Handle given to a running tool for reporting progress on its own tool call
/// and for asking permission to proceed.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// ID of the tool use this stream reports on.
    tool_use_id: LanguageModelToolUseId,
    /// Kind of the tool, included in authorization requests.
    kind: acp::ToolKind,
    /// Input of the tool use, included in authorization requests.
    input: serde_json::Value,
    /// The turn's event stream that events are forwarded to.
    stream: AgentResponseEventStream,
    /// File system used to persist the "always allow" choice; `None` disables persisting.
    fs: Option<Arc<dyn Fs>>,
}
936
937impl ToolCallEventStream {
938 #[cfg(test)]
939 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
940 let (events_tx, events_rx) =
941 mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
942
943 let stream = ToolCallEventStream::new(
944 &LanguageModelToolUse {
945 id: "test_id".into(),
946 name: "test_tool".into(),
947 raw_input: String::new(),
948 input: serde_json::Value::Null,
949 is_input_complete: true,
950 },
951 acp::ToolKind::Other,
952 AgentResponseEventStream(events_tx),
953 None,
954 );
955
956 (stream, ToolCallEventStreamReceiver(events_rx))
957 }
958
959 fn new(
960 tool_use: &LanguageModelToolUse,
961 kind: acp::ToolKind,
962 stream: AgentResponseEventStream,
963 fs: Option<Arc<dyn Fs>>,
964 ) -> Self {
965 Self {
966 tool_use_id: tool_use.id.clone(),
967 kind,
968 input: tool_use.input.clone(),
969 stream,
970 fs,
971 }
972 }
973
974 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
975 self.stream
976 .update_tool_call_fields(&self.tool_use_id, fields);
977 }
978
979 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
980 self.stream
981 .0
982 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
983 acp_thread::ToolCallUpdateDiff {
984 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
985 diff,
986 }
987 .into(),
988 )))
989 .ok();
990 }
991
992 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
993 self.stream
994 .0
995 .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
996 acp_thread::ToolCallUpdateTerminal {
997 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
998 terminal,
999 }
1000 .into(),
1001 )))
1002 .ok();
1003 }
1004
1005 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1006 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1007 return Task::ready(Ok(()));
1008 }
1009
1010 let (response_tx, response_rx) = oneshot::channel();
1011 self.stream
1012 .0
1013 .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
1014 ToolCallAuthorization {
1015 tool_call: AgentResponseEventStream::initial_tool_call(
1016 &self.tool_use_id,
1017 title.into(),
1018 self.kind.clone(),
1019 self.input.clone(),
1020 ),
1021 options: vec![
1022 acp::PermissionOption {
1023 id: acp::PermissionOptionId("always_allow".into()),
1024 name: "Always Allow".into(),
1025 kind: acp::PermissionOptionKind::AllowAlways,
1026 },
1027 acp::PermissionOption {
1028 id: acp::PermissionOptionId("allow".into()),
1029 name: "Allow".into(),
1030 kind: acp::PermissionOptionKind::AllowOnce,
1031 },
1032 acp::PermissionOption {
1033 id: acp::PermissionOptionId("deny".into()),
1034 name: "Deny".into(),
1035 kind: acp::PermissionOptionKind::RejectOnce,
1036 },
1037 ],
1038 response: response_tx,
1039 },
1040 )))
1041 .ok();
1042 let fs = self.fs.clone();
1043 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1044 "always_allow" => {
1045 if let Some(fs) = fs.clone() {
1046 cx.update(|cx| {
1047 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1048 settings.set_always_allow_tool_actions(true);
1049 });
1050 })?;
1051 }
1052
1053 Ok(())
1054 }
1055 "allow" => Ok(()),
1056 _ => Err(anyhow!("Permission to run tool denied by user")),
1057 })
1058 }
1059}
1060
/// Test-only receiving half paired with [`ToolCallEventStream::test`].
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(
    mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
1065
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Waits for the next event and returns it as an authorization request.
    ///
    /// # Panics
    ///
    /// Panics if the next event is anything else or the stream has ended.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) => auth,
            event => panic!("Expected ToolCallAuthorization but got: {:?}", event),
        }
    }

    /// Waits for the next event and returns the terminal it attaches.
    ///
    /// # Panics
    ///
    /// Panics if the next event is anything else or the stream has ended.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            event => panic!("Expected terminal but got: {:?}", event),
        }
    }
}
1089
/// Exposes the wrapped receiver so tests can use it directly.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
1098
/// Exposes the wrapped receiver mutably so tests can poll it directly.
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}