thread.rs

   1use crate::{SystemPromptTemplate, Template, Templates};
   2use acp_thread::Diff;
   3use agent_client_protocol as acp;
   4use anyhow::{anyhow, Context as _, Result};
   5use assistant_tool::{adapt_schema_to_format, ActionLog};
   6use cloud_llm_client::{CompletionIntent, CompletionMode};
   7use collections::HashMap;
   8use futures::{
   9    channel::{mpsc, oneshot},
  10    stream::FuturesUnordered,
  11};
  12use gpui::{App, Context, Entity, SharedString, Task};
  13use language_model::{
  14    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  15    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  16    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
  17    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
  18};
  19use log;
  20use project::Project;
  21use prompt_store::ProjectContext;
  22use schemars::{JsonSchema, Schema};
  23use serde::{Deserialize, Serialize};
  24use smol::stream::StreamExt;
  25use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
  26use util::{markdown::MarkdownCodeBlock, ResultExt};
  27
/// A single message in a [`Thread`]'s conversation: who sent it and the ordered content it carries.
#[derive(Debug, Clone)]
pub struct AgentMessage {
    /// Sender of the message (system, user, or assistant).
    pub role: Role,
    /// Content items in the order they were streamed or appended.
    pub content: Vec<MessageContent>,
}
  33
  34impl AgentMessage {
  35    pub fn to_markdown(&self) -> String {
  36        let mut markdown = format!("## {}\n", self.role);
  37
  38        for content in &self.content {
  39            match content {
  40                MessageContent::Text(text) => {
  41                    markdown.push_str(text);
  42                    markdown.push('\n');
  43                }
  44                MessageContent::Thinking { text, .. } => {
  45                    markdown.push_str("<think>");
  46                    markdown.push_str(text);
  47                    markdown.push_str("</think>\n");
  48                }
  49                MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
  50                MessageContent::Image(_) => {
  51                    markdown.push_str("<image />\n");
  52                }
  53                MessageContent::ToolUse(tool_use) => {
  54                    markdown.push_str(&format!(
  55                        "**Tool Use**: {} (ID: {})\n",
  56                        tool_use.name, tool_use.id
  57                    ));
  58                    markdown.push_str(&format!(
  59                        "{}\n",
  60                        MarkdownCodeBlock {
  61                            tag: "json",
  62                            text: &format!("{:#}", tool_use.input)
  63                        }
  64                    ));
  65                }
  66                MessageContent::ToolResult(tool_result) => {
  67                    markdown.push_str(&format!(
  68                        "**Tool Result**: {} (ID: {})\n\n",
  69                        tool_result.tool_name, tool_result.tool_use_id
  70                    ));
  71                    if tool_result.is_error {
  72                        markdown.push_str("**ERROR:**\n");
  73                    }
  74
  75                    match &tool_result.content {
  76                        LanguageModelToolResultContent::Text(text) => {
  77                            writeln!(markdown, "{text}\n").ok();
  78                        }
  79                        LanguageModelToolResultContent::Image(_) => {
  80                            writeln!(markdown, "<image />\n").ok();
  81                        }
  82                    }
  83
  84                    if let Some(output) = tool_result.output.as_ref() {
  85                        writeln!(
  86                            markdown,
  87                            "**Debug Output**:\n\n```json\n{}\n```\n",
  88                            serde_json::to_string_pretty(output).unwrap()
  89                        )
  90                        .unwrap();
  91                    }
  92                }
  93            }
  94        }
  95
  96        markdown
  97    }
  98}
  99
/// Events delivered on the channel returned by [`Thread::send`] while a turn runs.
#[derive(Debug)]
pub enum AgentResponseEvent {
    /// A chunk of streamed assistant text.
    Text(String),
    /// A chunk of streamed model thinking.
    Thinking(String),
    /// A tool call started by the model.
    ToolCall(acp::ToolCall),
    /// An update to a previously reported tool call (e.g. status, title, or raw input).
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A request for the user to approve or deny a tool call.
    ToolCallAuthorization(ToolCallAuthorization),
    /// The model stopped, with the reason it gave.
    Stop(acp::StopReason),
}
 109
/// A pending request for the user to permit a tool call.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call awaiting permission.
    pub tool_call: acp::ToolCall,
    /// The choices to present to the user.
    pub options: Vec<acp::PermissionOption>,
    /// Send the id of the chosen option here. Dropping the sender without answering makes the
    /// awaiting authorization future resolve to an error.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
 116
/// An agent conversation: message history, registered tools, and the task driving the
/// current turn.
pub struct Thread {
    /// Conversation history, oldest first. The system prompt is not stored here; it is
    /// rebuilt and prepended for every request.
    messages: Vec<AgentMessage>,
    /// Completion mode attached to every request.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<Task<()>>,
    /// Tool uses issued by the model that have not produced a result yet, keyed by id.
    pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
    /// Registered tools keyed by name; a `BTreeMap` so tools are listed in name order.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Project information rendered into the system prompt.
    project_context: Rc<RefCell<ProjectContext>>,
    /// Templates used to render the system prompt.
    templates: Arc<Templates>,
    /// Model used for completion requests.
    pub selected_model: Arc<dyn LanguageModel>,
    /// Exposed through `Thread::project`.
    project: Entity<Project>,
    /// Exposed through `Thread::action_log`.
    action_log: Entity<ActionLog>,
}
 132
 133impl Thread {
 134    pub fn new(
 135        project: Entity<Project>,
 136        project_context: Rc<RefCell<ProjectContext>>,
 137        action_log: Entity<ActionLog>,
 138        templates: Arc<Templates>,
 139        default_model: Arc<dyn LanguageModel>,
 140    ) -> Self {
 141        Self {
 142            messages: Vec::new(),
 143            completion_mode: CompletionMode::Normal,
 144            running_turn: None,
 145            pending_tool_uses: HashMap::default(),
 146            tools: BTreeMap::default(),
 147            project_context,
 148            templates,
 149            selected_model: default_model,
 150            project,
 151            action_log,
 152        }
 153    }
 154
 155    pub fn project(&self) -> &Entity<Project> {
 156        &self.project
 157    }
 158
 159    pub fn action_log(&self) -> &Entity<ActionLog> {
 160        &self.action_log
 161    }
 162
 163    pub fn set_mode(&mut self, mode: CompletionMode) {
 164        self.completion_mode = mode;
 165    }
 166
 167    pub fn messages(&self) -> &[AgentMessage] {
 168        &self.messages
 169    }
 170
 171    pub fn add_tool(&mut self, tool: impl AgentTool) {
 172        self.tools.insert(tool.name(), tool.erase());
 173    }
 174
 175    pub fn remove_tool(&mut self, name: &str) -> bool {
 176        self.tools.remove(name).is_some()
 177    }
 178
 179    pub fn cancel(&mut self) {
 180        self.running_turn.take();
 181
 182        let tool_results = self
 183            .pending_tool_uses
 184            .drain()
 185            .map(|(tool_use_id, tool_use)| {
 186                MessageContent::ToolResult(LanguageModelToolResult {
 187                    tool_use_id,
 188                    tool_name: tool_use.name.clone(),
 189                    is_error: true,
 190                    content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
 191                    output: None,
 192                })
 193            })
 194            .collect::<Vec<_>>();
 195        self.last_user_message().content.extend(tool_results);
 196    }
 197
 198    /// Sending a message results in the model streaming a response, which could include tool calls.
 199    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
 200    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
 201    pub fn send(
 202        &mut self,
 203        content: impl Into<MessageContent>,
 204        cx: &mut Context<Self>,
 205    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
 206        let content = content.into();
 207        let model = self.selected_model.clone();
 208        log::info!("Thread::send called with model: {:?}", model.name());
 209        log::debug!("Thread::send content: {:?}", content);
 210
 211        cx.notify();
 212        let (events_tx, events_rx) =
 213            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 214        let event_stream = AgentResponseEventStream(events_tx);
 215
 216        let user_message_ix = self.messages.len();
 217        self.messages.push(AgentMessage {
 218            role: Role::User,
 219            content: vec![content],
 220        });
 221        log::info!("Total messages in thread: {}", self.messages.len());
 222        self.running_turn = Some(cx.spawn(async move |thread, cx| {
 223            log::info!("Starting agent turn execution");
 224            let turn_result = async {
 225                // Perform one request, then keep looping if the model makes tool calls.
 226                let mut completion_intent = CompletionIntent::UserPrompt;
 227                'outer: loop {
 228                    log::debug!(
 229                        "Building completion request with intent: {:?}",
 230                        completion_intent
 231                    );
 232                    let request = thread.update(cx, |thread, cx| {
 233                        thread.build_completion_request(completion_intent, cx)
 234                    })?;
 235
 236                    // println!(
 237                    //     "request: {}",
 238                    //     serde_json::to_string_pretty(&request).unwrap()
 239                    // );
 240
 241                    // Stream events, appending to messages and collecting up tool uses.
 242                    log::info!("Calling model.stream_completion");
 243                    let mut events = model.stream_completion(request, cx).await?;
 244                    log::debug!("Stream completion started successfully");
 245                    let mut tool_uses = FuturesUnordered::new();
 246                    while let Some(event) = events.next().await {
 247                        match event {
 248                            Ok(LanguageModelCompletionEvent::Stop(reason)) => {
 249                                event_stream.send_stop(reason);
 250                                if reason == StopReason::Refusal {
 251                                    thread.update(cx, |thread, _cx| {
 252                                        thread.messages.truncate(user_message_ix);
 253                                    })?;
 254                                    break 'outer;
 255                                }
 256                            }
 257                            Ok(event) => {
 258                                log::trace!("Received completion event: {:?}", event);
 259                                thread
 260                                    .update(cx, |thread, cx| {
 261                                        tool_uses.extend(thread.handle_streamed_completion_event(
 262                                            event,
 263                                            &event_stream,
 264                                            cx,
 265                                        ));
 266                                    })
 267                                    .ok();
 268                            }
 269                            Err(error) => {
 270                                log::error!("Error in completion stream: {:?}", error);
 271                                event_stream.send_error(error);
 272                                break;
 273                            }
 274                        }
 275                    }
 276
 277                    // If there are no tool uses, the turn is done.
 278                    if tool_uses.is_empty() {
 279                        log::info!("No tool uses found, completing turn");
 280                        break;
 281                    }
 282                    log::info!("Found {} tool uses to execute", tool_uses.len());
 283
 284                    // As tool results trickle in, insert them in the last user
 285                    // message so that they can be sent on the next tick of the
 286                    // agentic loop.
 287                    while let Some(tool_result) = tool_uses.next().await {
 288                        log::info!("Tool finished {:?}", tool_result);
 289
 290                        event_stream.update_tool_call_fields(
 291                            &tool_result.tool_use_id,
 292                            acp::ToolCallUpdateFields {
 293                                status: Some(if tool_result.is_error {
 294                                    acp::ToolCallStatus::Failed
 295                                } else {
 296                                    acp::ToolCallStatus::Completed
 297                                }),
 298                                ..Default::default()
 299                            },
 300                        );
 301                        thread
 302                            .update(cx, |thread, _cx| {
 303                                thread.pending_tool_uses.remove(&tool_result.tool_use_id);
 304                                thread
 305                                    .last_user_message()
 306                                    .content
 307                                    .push(MessageContent::ToolResult(tool_result));
 308                            })
 309                            .ok();
 310                    }
 311
 312                    completion_intent = CompletionIntent::ToolResults;
 313                }
 314
 315                Ok(())
 316            }
 317            .await;
 318
 319            if let Err(error) = turn_result {
 320                log::error!("Turn execution failed: {:?}", error);
 321                event_stream.send_error(error);
 322            } else {
 323                log::info!("Turn execution completed successfully");
 324            }
 325        }));
 326        events_rx
 327    }
 328
 329    pub fn build_system_message(&self) -> AgentMessage {
 330        log::debug!("Building system message");
 331        let prompt = SystemPromptTemplate {
 332            project: &self.project_context.borrow(),
 333            available_tools: self.tools.keys().cloned().collect(),
 334        }
 335        .render(&self.templates)
 336        .context("failed to build system prompt")
 337        .expect("Invalid template");
 338        log::debug!("System message built");
 339        AgentMessage {
 340            role: Role::System,
 341            content: vec![prompt.into()],
 342        }
 343    }
 344
 345    /// A helper method that's called on every streamed completion event.
 346    /// Returns an optional tool result task, which the main agentic loop in
 347    /// send will send back to the model when it resolves.
 348    fn handle_streamed_completion_event(
 349        &mut self,
 350        event: LanguageModelCompletionEvent,
 351        event_stream: &AgentResponseEventStream,
 352        cx: &mut Context<Self>,
 353    ) -> Option<Task<LanguageModelToolResult>> {
 354        log::trace!("Handling streamed completion event: {:?}", event);
 355        use LanguageModelCompletionEvent::*;
 356
 357        match event {
 358            StartMessage { .. } => {
 359                self.messages.push(AgentMessage {
 360                    role: Role::Assistant,
 361                    content: Vec::new(),
 362                });
 363            }
 364            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
 365            Thinking { text, signature } => {
 366                self.handle_thinking_event(text, signature, event_stream, cx)
 367            }
 368            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
 369            ToolUse(tool_use) => {
 370                return self.handle_tool_use_event(tool_use, event_stream, cx);
 371            }
 372            ToolUseJsonParseError {
 373                id,
 374                tool_name,
 375                raw_input,
 376                json_parse_error,
 377            } => {
 378                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
 379                    id,
 380                    tool_name,
 381                    raw_input,
 382                    json_parse_error,
 383                )));
 384            }
 385            UsageUpdate(_) | StatusUpdate(_) => {}
 386            Stop(_) => unreachable!(),
 387        }
 388
 389        None
 390    }
 391
 392    fn handle_text_event(
 393        &mut self,
 394        new_text: String,
 395        events_stream: &AgentResponseEventStream,
 396        cx: &mut Context<Self>,
 397    ) {
 398        events_stream.send_text(&new_text);
 399
 400        let last_message = self.last_assistant_message();
 401        if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
 402            text.push_str(&new_text);
 403        } else {
 404            last_message.content.push(MessageContent::Text(new_text));
 405        }
 406
 407        cx.notify();
 408    }
 409
 410    fn handle_thinking_event(
 411        &mut self,
 412        new_text: String,
 413        new_signature: Option<String>,
 414        event_stream: &AgentResponseEventStream,
 415        cx: &mut Context<Self>,
 416    ) {
 417        event_stream.send_thinking(&new_text);
 418
 419        let last_message = self.last_assistant_message();
 420        if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
 421        {
 422            text.push_str(&new_text);
 423            *signature = new_signature.or(signature.take());
 424        } else {
 425            last_message.content.push(MessageContent::Thinking {
 426                text: new_text,
 427                signature: new_signature,
 428            });
 429        }
 430
 431        cx.notify();
 432    }
 433
 434    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
 435        let last_message = self.last_assistant_message();
 436        last_message
 437            .content
 438            .push(MessageContent::RedactedThinking(data));
 439        cx.notify();
 440    }
 441
    /// Handles one streamed update of a tool use.
    ///
    /// Every update records the tool use as pending (it is removed from `pending_tool_uses` in
    /// `send` once its result arrives). If the assistant message does not already end with a
    /// `ToolUse` of the same id, the tool use is appended and a new tool call is emitted;
    /// otherwise the trailing entry is replaced and the tool call's fields are updated.
    ///
    /// Returns `None` while the input is still streaming. Once the input is complete, returns
    /// a task resolving to the tool result: an error result if no tool with that name is
    /// registered, if the tool fails, or if it produced an image the selected model does not
    /// support.
    fn handle_tool_use_event(
        &mut self,
        tool_use: LanguageModelToolUse,
        event_stream: &AgentResponseEventStream,
        cx: &mut Context<Self>,
    ) -> Option<Task<LanguageModelToolResult>> {
        cx.notify();

        let tool = self.tools.get(tool_use.name.as_ref()).cloned();

        self.pending_tool_uses
            .insert(tool_use.id.clone(), tool_use.clone());
        let last_message = self.last_assistant_message();

        // Ensure the last message ends in the current tool use
        // (only the trailing content item is compared, so a repeated update for the same id
        // overwrites it in place rather than appending a duplicate).
        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
            if let MessageContent::ToolUse(last_tool_use) = content {
                if last_tool_use.id == tool_use.id {
                    *last_tool_use = tool_use.clone();
                    false
                } else {
                    true
                }
            } else {
                true
            }
        });

        // Unknown tools fall back to their raw name and `ToolKind::Other` for display.
        let mut title = SharedString::from(&tool_use.name);
        let mut kind = acp::ToolKind::Other;
        if let Some(tool) = tool.as_ref() {
            title = tool.initial_title(tool_use.input.clone());
            kind = tool.kind();
        }

        if push_new_tool_use {
            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
            last_message
                .content
                .push(MessageContent::ToolUse(tool_use.clone()));
        } else {
            event_stream.update_tool_call_fields(
                &tool_use.id,
                acp::ToolCallUpdateFields {
                    title: Some(title.into()),
                    kind: Some(kind),
                    raw_input: Some(tool_use.input.clone()),
                    ..Default::default()
                },
            );
        }

        // Only run the tool once the model has finished streaming its input.
        if !tool_use.is_input_complete {
            return None;
        }

        let Some(tool) = tool else {
            let content = format!("No tool named {} exists", tool_use.name);
            return Some(Task::ready(LanguageModelToolResult {
                content: LanguageModelToolResultContent::Text(Arc::from(content)),
                tool_use_id: tool_use.id,
                tool_name: tool_use.name,
                is_error: true,
                output: None,
            }));
        };

        let tool_event_stream =
            ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
        tool_event_stream.update_fields(acp::ToolCallUpdateFields {
            status: Some(acp::ToolCallStatus::InProgress),
            ..Default::default()
        });
        // Captured now so the spawned task does not need access to `self`.
        let supports_images = self.selected_model.supports_images();
        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
        Some(cx.foreground_executor().spawn(async move {
            // Image output is only useful to models that accept images; otherwise report an
            // error to the model instead of the image.
            let tool_result = tool_result.await.and_then(|output| {
                if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
                    if !supports_images {
                        return Err(anyhow!(
                            "Attempted to read an image, but this model doesn't support it.",
                        ));
                    }
                }
                Ok(output)
            });

            match tool_result {
                Ok(output) => LanguageModelToolResult {
                    tool_use_id: tool_use.id,
                    tool_name: tool_use.name,
                    is_error: false,
                    content: output.llm_output,
                    output: Some(output.raw_output),
                },
                Err(error) => LanguageModelToolResult {
                    tool_use_id: tool_use.id,
                    tool_name: tool_use.name,
                    is_error: true,
                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
                    output: None,
                },
            }
        }))
    }
 547
 548    fn handle_tool_use_json_parse_error_event(
 549        &mut self,
 550        tool_use_id: LanguageModelToolUseId,
 551        tool_name: Arc<str>,
 552        raw_input: Arc<str>,
 553        json_parse_error: String,
 554    ) -> LanguageModelToolResult {
 555        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
 556        LanguageModelToolResult {
 557            tool_use_id,
 558            tool_name,
 559            is_error: true,
 560            content: LanguageModelToolResultContent::Text(tool_output.into()),
 561            output: Some(serde_json::Value::String(raw_input.to_string())),
 562        }
 563    }
 564
 565    /// Guarantees the last message is from the assistant and returns a mutable reference.
 566    fn last_assistant_message(&mut self) -> &mut AgentMessage {
 567        if self
 568            .messages
 569            .last()
 570            .map_or(true, |m| m.role != Role::Assistant)
 571        {
 572            self.messages.push(AgentMessage {
 573                role: Role::Assistant,
 574                content: Vec::new(),
 575            });
 576        }
 577        self.messages.last_mut().unwrap()
 578    }
 579
 580    /// Guarantees the last message is from the user and returns a mutable reference.
 581    fn last_user_message(&mut self) -> &mut AgentMessage {
 582        if self.messages.last().map_or(true, |m| m.role != Role::User) {
 583            self.messages.push(AgentMessage {
 584                role: Role::User,
 585                content: Vec::new(),
 586            });
 587        }
 588        self.messages.last_mut().unwrap()
 589    }
 590
 591    pub(crate) fn build_completion_request(
 592        &self,
 593        completion_intent: CompletionIntent,
 594        cx: &mut App,
 595    ) -> LanguageModelRequest {
 596        log::debug!("Building completion request");
 597        log::debug!("Completion intent: {:?}", completion_intent);
 598        log::debug!("Completion mode: {:?}", self.completion_mode);
 599
 600        let messages = self.build_request_messages();
 601        log::info!("Request will include {} messages", messages.len());
 602
 603        let tools: Vec<LanguageModelRequestTool> = self
 604            .tools
 605            .values()
 606            .filter_map(|tool| {
 607                let tool_name = tool.name().to_string();
 608                log::trace!("Including tool: {}", tool_name);
 609                Some(LanguageModelRequestTool {
 610                    name: tool_name,
 611                    description: tool.description(cx).to_string(),
 612                    input_schema: tool
 613                        .input_schema(self.selected_model.tool_input_format())
 614                        .log_err()?,
 615                })
 616            })
 617            .collect();
 618
 619        log::info!("Request includes {} tools", tools.len());
 620
 621        let request = LanguageModelRequest {
 622            thread_id: None,
 623            prompt_id: None,
 624            intent: Some(completion_intent),
 625            mode: Some(self.completion_mode),
 626            messages,
 627            tools,
 628            tool_choice: None,
 629            stop: Vec::new(),
 630            temperature: None,
 631            thinking_allowed: true,
 632        };
 633
 634        log::debug!("Completion request built successfully");
 635        request
 636    }
 637
 638    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
 639        log::trace!(
 640            "Building request messages from {} thread messages",
 641            self.messages.len()
 642        );
 643
 644        let messages = Some(self.build_system_message())
 645            .iter()
 646            .chain(self.messages.iter())
 647            .map(|message| {
 648                log::trace!(
 649                    "  - {} message with {} content items",
 650                    match message.role {
 651                        Role::System => "System",
 652                        Role::User => "User",
 653                        Role::Assistant => "Assistant",
 654                    },
 655                    message.content.len()
 656                );
 657                LanguageModelRequestMessage {
 658                    role: message.role,
 659                    content: message.content.clone(),
 660                    cache: false,
 661                }
 662            })
 663            .collect();
 664        messages
 665    }
 666
 667    pub fn to_markdown(&self) -> String {
 668        let mut markdown = String::new();
 669        for message in &self.messages {
 670            markdown.push_str(&message.to_markdown());
 671        }
 672        markdown
 673    }
 674}
 675
 676pub trait AgentTool
 677where
 678    Self: 'static + Sized,
 679{
 680    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
 681    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
 682
 683    fn name(&self) -> SharedString;
 684
 685    fn description(&self, _cx: &mut App) -> SharedString {
 686        let schema = schemars::schema_for!(Self::Input);
 687        SharedString::new(
 688            schema
 689                .get("description")
 690                .and_then(|description| description.as_str())
 691                .unwrap_or_default(),
 692        )
 693    }
 694
 695    fn kind(&self) -> acp::ToolKind;
 696
 697    /// The initial tool title to display. Can be updated during the tool run.
 698    fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
 699
 700    /// Returns the JSON schema that describes the tool's input.
 701    fn input_schema(&self) -> Schema {
 702        schemars::schema_for!(Self::Input)
 703    }
 704
 705    /// Runs the tool with the provided input.
 706    fn run(
 707        self: Arc<Self>,
 708        input: Self::Input,
 709        event_stream: ToolCallEventStream,
 710        cx: &mut App,
 711    ) -> Task<Result<Self::Output>>;
 712
 713    fn erase(self) -> Arc<dyn AnyAgentTool> {
 714        Arc::new(Erased(Arc::new(self)))
 715    }
 716}
 717
/// Wrapper that adapts a typed [`AgentTool`] (held as `Erased<Arc<T>>`) to the JSON-based
/// [`AnyAgentTool`] interface; produced by [`AgentTool::erase`].
pub struct Erased<T>(T);
 719
/// Result of a successful type-erased tool run.
pub struct AgentToolOutput {
    /// Content sent back to the model as the tool result.
    llm_output: LanguageModelToolResultContent,
    /// The tool's typed output serialized to JSON; stored as the tool result's `output`
    /// (rendered as "Debug Output" in Markdown transcripts).
    raw_output: serde_json::Value,
}
 724
/// Object-safe, JSON-based view of an [`AgentTool`], letting tools with different
/// `Input`/`Output` types be stored together as `Arc<dyn AnyAgentTool>`.
pub trait AnyAgentTool {
    /// Name the model uses to call the tool.
    fn name(&self) -> SharedString;
    /// Description sent to the model.
    fn description(&self, cx: &mut App) -> SharedString;
    /// Category of the tool.
    fn kind(&self) -> acp::ToolKind;
    /// Title for a tool call given its raw JSON input, which may still be incomplete.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// The tool's input schema as JSON, adapted to the model's schema `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Runs the tool on raw JSON input.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
}
 738
 739impl<T> AnyAgentTool for Erased<Arc<T>>
 740where
 741    T: AgentTool,
 742{
 743    fn name(&self) -> SharedString {
 744        self.0.name()
 745    }
 746
 747    fn description(&self, cx: &mut App) -> SharedString {
 748        self.0.description(cx)
 749    }
 750
 751    fn kind(&self) -> agent_client_protocol::ToolKind {
 752        self.0.kind()
 753    }
 754
 755    fn initial_title(&self, input: serde_json::Value) -> SharedString {
 756        let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
 757        self.0.initial_title(parsed_input)
 758    }
 759
 760    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 761        let mut json = serde_json::to_value(self.0.input_schema())?;
 762        adapt_schema_to_format(&mut json, format)?;
 763        Ok(json)
 764    }
 765
 766    fn run(
 767        self: Arc<Self>,
 768        input: serde_json::Value,
 769        event_stream: ToolCallEventStream,
 770        cx: &mut App,
 771    ) -> Task<Result<AgentToolOutput>> {
 772        cx.spawn(async move |cx| {
 773            let input = serde_json::from_value(input)?;
 774            let output = cx
 775                .update(|cx| self.0.clone().run(input, event_stream, cx))?
 776                .await?;
 777            let raw_output = serde_json::to_value(&output)?;
 778            Ok(AgentToolOutput {
 779                llm_output: output.into(),
 780                raw_output,
 781            })
 782        })
 783    }
 784}
 785
/// Cloneable sending half of the channel returned by [`Thread::send`]; all turn events and
/// errors flow through it.
#[derive(Clone)]
struct AgentResponseEventStream(
    mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
 790
 791impl AgentResponseEventStream {
 792    fn send_text(&self, text: &str) {
 793        self.0
 794            .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
 795            .ok();
 796    }
 797
 798    fn send_thinking(&self, text: &str) {
 799        self.0
 800            .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
 801            .ok();
 802    }
 803
 804    fn authorize_tool_call(
 805        &self,
 806        id: &LanguageModelToolUseId,
 807        title: String,
 808        kind: acp::ToolKind,
 809        input: serde_json::Value,
 810    ) -> impl use<> + Future<Output = Result<()>> {
 811        let (response_tx, response_rx) = oneshot::channel();
 812        self.0
 813            .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
 814                ToolCallAuthorization {
 815                    tool_call: Self::initial_tool_call(id, title, kind, input),
 816                    options: vec![
 817                        acp::PermissionOption {
 818                            id: acp::PermissionOptionId("always_allow".into()),
 819                            name: "Always Allow".into(),
 820                            kind: acp::PermissionOptionKind::AllowAlways,
 821                        },
 822                        acp::PermissionOption {
 823                            id: acp::PermissionOptionId("allow".into()),
 824                            name: "Allow".into(),
 825                            kind: acp::PermissionOptionKind::AllowOnce,
 826                        },
 827                        acp::PermissionOption {
 828                            id: acp::PermissionOptionId("deny".into()),
 829                            name: "Deny".into(),
 830                            kind: acp::PermissionOptionKind::RejectOnce,
 831                        },
 832                    ],
 833                    response: response_tx,
 834                },
 835            )))
 836            .ok();
 837        async move {
 838            match response_rx.await?.0.as_ref() {
 839                "allow" | "always_allow" => Ok(()),
 840                _ => Err(anyhow!("Permission to run tool denied by user")),
 841            }
 842        }
 843    }
 844
 845    fn send_tool_call(
 846        &self,
 847        id: &LanguageModelToolUseId,
 848        title: SharedString,
 849        kind: acp::ToolKind,
 850        input: serde_json::Value,
 851    ) {
 852        self.0
 853            .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
 854                id,
 855                title.to_string(),
 856                kind,
 857                input,
 858            ))))
 859            .ok();
 860    }
 861
 862    fn initial_tool_call(
 863        id: &LanguageModelToolUseId,
 864        title: String,
 865        kind: acp::ToolKind,
 866        input: serde_json::Value,
 867    ) -> acp::ToolCall {
 868        acp::ToolCall {
 869            id: acp::ToolCallId(id.to_string().into()),
 870            title,
 871            kind,
 872            status: acp::ToolCallStatus::Pending,
 873            content: vec![],
 874            locations: vec![],
 875            raw_input: Some(input),
 876            raw_output: None,
 877        }
 878    }
 879
 880    fn update_tool_call_fields(
 881        &self,
 882        tool_use_id: &LanguageModelToolUseId,
 883        fields: acp::ToolCallUpdateFields,
 884    ) {
 885        self.0
 886            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
 887                acp::ToolCallUpdate {
 888                    id: acp::ToolCallId(tool_use_id.to_string().into()),
 889                    fields,
 890                }
 891                .into(),
 892            )))
 893            .ok();
 894    }
 895
 896    fn update_tool_call_diff(&self, tool_use_id: &LanguageModelToolUseId, diff: Entity<Diff>) {
 897        self.0
 898            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
 899                acp_thread::ToolCallUpdateDiff {
 900                    id: acp::ToolCallId(tool_use_id.to_string().into()),
 901                    diff,
 902                }
 903                .into(),
 904            )))
 905            .ok();
 906    }
 907
 908    fn send_stop(&self, reason: StopReason) {
 909        match reason {
 910            StopReason::EndTurn => {
 911                self.0
 912                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
 913                    .ok();
 914            }
 915            StopReason::MaxTokens => {
 916                self.0
 917                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
 918                    .ok();
 919            }
 920            StopReason::Refusal => {
 921                self.0
 922                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
 923                    .ok();
 924            }
 925            StopReason::ToolUse => {}
 926        }
 927    }
 928
 929    fn send_error(&self, error: LanguageModelCompletionError) {
 930        self.0.unbounded_send(Err(error)).ok();
 931    }
 932}
 933
/// Event stream scoped to a single tool use.
///
/// Wraps an [`AgentResponseEventStream`]. Field updates, diffs and
/// authorization requests sent through it are all tagged with this tool use's id.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// Id of the tool use that every emitted event refers to.
    tool_use_id: LanguageModelToolUseId,
    /// Tool kind, forwarded with authorization requests.
    kind: acp::ToolKind,
    /// The tool use's input, forwarded as the raw input of authorization requests.
    input: serde_json::Value,
    /// Underlying channel that events are sent on.
    stream: AgentResponseEventStream,
}
 941
 942impl ToolCallEventStream {
 943    #[cfg(test)]
 944    pub fn test() -> (Self, ToolCallEventStreamReceiver) {
 945        let (events_tx, events_rx) =
 946            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 947
 948        let stream = ToolCallEventStream::new(
 949            &LanguageModelToolUse {
 950                id: "test_id".into(),
 951                name: "test_tool".into(),
 952                raw_input: String::new(),
 953                input: serde_json::Value::Null,
 954                is_input_complete: true,
 955            },
 956            acp::ToolKind::Other,
 957            AgentResponseEventStream(events_tx),
 958        );
 959
 960        (stream, ToolCallEventStreamReceiver(events_rx))
 961    }
 962
 963    fn new(
 964        tool_use: &LanguageModelToolUse,
 965        kind: acp::ToolKind,
 966        stream: AgentResponseEventStream,
 967    ) -> Self {
 968        Self {
 969            tool_use_id: tool_use.id.clone(),
 970            kind,
 971            input: tool_use.input.clone(),
 972            stream,
 973        }
 974    }
 975
 976    pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
 977        self.stream
 978            .update_tool_call_fields(&self.tool_use_id, fields);
 979    }
 980
 981    pub fn update_diff(&self, diff: Entity<Diff>) {
 982        self.stream.update_tool_call_diff(&self.tool_use_id, diff);
 983    }
 984
 985    pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
 986        self.stream.authorize_tool_call(
 987            &self.tool_use_id,
 988            title,
 989            self.kind.clone(),
 990            self.input.clone(),
 991        )
 992    }
 993}
 994
/// Test-only receiving end of the channel created by [`ToolCallEventStream::test`].
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(
    mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
 999
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Waits for the next event and returns it as a tool call authorization.
    ///
    /// # Panics
    ///
    /// Panics if the next event is anything else. This includes an error, a
    /// different event variant, or the channel ending.
    pub async fn expect_tool_authorization(&mut self) -> ToolCallAuthorization {
        let next_event = self.0.next().await;
        match next_event {
            Some(Ok(AgentResponseEvent::ToolCallAuthorization(authorization))) => authorization,
            unexpected => panic!("Expected ToolCallAuthorization but got: {:?}", unexpected),
        }
    }
}
1011
/// Lets tests use the wrapped channel receiver's API directly.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;

    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1020
/// Mutable access to the wrapped receiver, needed for polling it as a stream.
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}