thread.rs

   1use crate::{SystemPromptTemplate, Template, Templates};
   2use action_log::ActionLog;
   3use agent_client_protocol as acp;
   4use anyhow::{Context as _, Result, anyhow};
   5use assistant_tool::adapt_schema_to_format;
   6use cloud_llm_client::{CompletionIntent, CompletionMode};
   7use collections::HashMap;
   8use futures::{
   9    channel::{mpsc, oneshot},
  10    stream::FuturesUnordered,
  11};
  12use gpui::{App, Context, Entity, SharedString, Task};
  13use language_model::{
  14    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  15    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  16    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
  17    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
  18};
  19use log;
  20use project::Project;
  21use prompt_store::ProjectContext;
  22use schemars::{JsonSchema, Schema};
  23use serde::{Deserialize, Serialize};
  24use smol::stream::StreamExt;
  25use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
  26use util::{ResultExt, markdown::MarkdownCodeBlock};
  27
/// A single message in an agent thread: the author's role plus its ordered
/// content parts.
#[derive(Debug, Clone)]
pub struct AgentMessage {
    /// Who authored the message (system, user, or assistant).
    pub role: Role,
    /// Content parts in insertion order: text, thinking, images, tool uses and
    /// tool results.
    pub content: Vec<MessageContent>,
}
  33
  34impl AgentMessage {
  35    pub fn to_markdown(&self) -> String {
  36        let mut markdown = format!("## {}\n", self.role);
  37
  38        for content in &self.content {
  39            match content {
  40                MessageContent::Text(text) => {
  41                    markdown.push_str(text);
  42                    markdown.push('\n');
  43                }
  44                MessageContent::Thinking { text, .. } => {
  45                    markdown.push_str("<think>");
  46                    markdown.push_str(text);
  47                    markdown.push_str("</think>\n");
  48                }
  49                MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
  50                MessageContent::Image(_) => {
  51                    markdown.push_str("<image />\n");
  52                }
  53                MessageContent::ToolUse(tool_use) => {
  54                    markdown.push_str(&format!(
  55                        "**Tool Use**: {} (ID: {})\n",
  56                        tool_use.name, tool_use.id
  57                    ));
  58                    markdown.push_str(&format!(
  59                        "{}\n",
  60                        MarkdownCodeBlock {
  61                            tag: "json",
  62                            text: &format!("{:#}", tool_use.input)
  63                        }
  64                    ));
  65                }
  66                MessageContent::ToolResult(tool_result) => {
  67                    markdown.push_str(&format!(
  68                        "**Tool Result**: {} (ID: {})\n\n",
  69                        tool_result.tool_name, tool_result.tool_use_id
  70                    ));
  71                    if tool_result.is_error {
  72                        markdown.push_str("**ERROR:**\n");
  73                    }
  74
  75                    match &tool_result.content {
  76                        LanguageModelToolResultContent::Text(text) => {
  77                            writeln!(markdown, "{text}\n").ok();
  78                        }
  79                        LanguageModelToolResultContent::Image(_) => {
  80                            writeln!(markdown, "<image />\n").ok();
  81                        }
  82                    }
  83
  84                    if let Some(output) = tool_result.output.as_ref() {
  85                        writeln!(
  86                            markdown,
  87                            "**Debug Output**:\n\n```json\n{}\n```\n",
  88                            serde_json::to_string_pretty(output).unwrap()
  89                        )
  90                        .unwrap();
  91                    }
  92                }
  93            }
  94        }
  95
  96        markdown
  97    }
  98}
  99
/// Events reported to the caller of [`Thread::send`] over the returned channel
/// while an agent turn is running.
#[derive(Debug)]
pub enum AgentResponseEvent {
    /// A chunk of assistant text streamed from the model.
    Text(String),
    /// A chunk of the model's thinking text.
    Thinking(String),
    /// A tool call was seen for the first time.
    ToolCall(acp::ToolCall),
    /// Some fields of a previously reported tool call changed.
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A tool call is asking for authorization before it proceeds.
    ToolCallAuthorization(ToolCallAuthorization),
    /// The model stopped, with the given reason.
    Stop(acp::StopReason),
}
 109
/// A request to authorize a tool call, carried by
/// [`AgentResponseEvent::ToolCallAuthorization`].
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call that needs authorization.
    pub tool_call: acp::ToolCall,
    /// The permission options that can be chosen from.
    pub options: Vec<acp::PermissionOption>,
    /// One-shot channel used to answer the request with a permission option id
    /// (presumably the id of one of `options` — confirm against the consumer).
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
 116
/// A conversation between the user and a language model, including the tools
/// the model may call and the state of the turn currently in progress.
pub struct Thread {
    /// The conversation so far, excluding the system message (which is rebuilt
    /// for every request).
    messages: Vec<AgentMessage>,
    /// Sent as the `mode` of every completion request.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// It survives across multiple requests while the model performs tool calls
    /// and we run the tools and report their results.
    running_turn: Option<Task<()>>,
    /// Tool uses received from the model whose results have not been recorded
    /// yet, keyed by tool use id.
    pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
    /// Registered tools, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Project information rendered into the system prompt.
    project_context: Rc<RefCell<ProjectContext>>,
    /// Templates used to render the system prompt.
    templates: Arc<Templates>,
    /// The model that completion requests are sent to.
    pub selected_model: Arc<dyn LanguageModel>,
    /// The project this thread operates on.
    project: Entity<Project>,
    /// The action log associated with this thread.
    action_log: Entity<ActionLog>,
}
 132
 133impl Thread {
    /// Creates an empty thread.
    ///
    /// The thread starts with no messages, no registered tools, no running
    /// turn, `CompletionMode::Normal`, and `default_model` as the selected
    /// model.
    pub fn new(
        project: Entity<Project>,
        project_context: Rc<RefCell<ProjectContext>>,
        action_log: Entity<ActionLog>,
        templates: Arc<Templates>,
        default_model: Arc<dyn LanguageModel>,
    ) -> Self {
        Self {
            messages: Vec::new(),
            completion_mode: CompletionMode::Normal,
            running_turn: None,
            pending_tool_uses: HashMap::default(),
            tools: BTreeMap::default(),
            project_context,
            templates,
            selected_model: default_model,
            project,
            action_log,
        }
    }
 154
    /// Returns the project this thread was created with.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
 158
    /// Returns the action log this thread was created with.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
 162
    /// Sets the completion mode used for subsequently built requests.
    pub fn set_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
 166
    /// Returns the messages recorded so far, oldest first. The system message
    /// is not part of this list.
    pub fn messages(&self) -> &[AgentMessage] {
        &self.messages
    }
 170
 171    pub fn add_tool(&mut self, tool: impl AgentTool) {
 172        self.tools.insert(tool.name(), tool.erase());
 173    }
 174
    /// Unregisters the tool called `name`. Returns `true` if such a tool was
    /// registered.
    pub fn remove_tool(&mut self, name: &str) -> bool {
        self.tools.remove(name).is_some()
    }
 178
 179    pub fn cancel(&mut self) {
 180        self.running_turn.take();
 181
 182        let tool_results = self
 183            .pending_tool_uses
 184            .drain()
 185            .map(|(tool_use_id, tool_use)| {
 186                MessageContent::ToolResult(LanguageModelToolResult {
 187                    tool_use_id,
 188                    tool_name: tool_use.name.clone(),
 189                    is_error: true,
 190                    content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
 191                    output: None,
 192                })
 193            })
 194            .collect::<Vec<_>>();
 195        self.last_user_message().content.extend(tool_results);
 196    }
 197
    /// Appends `content` to the thread as a new user message and starts an
    /// agent turn for it.
    ///
    /// The turn runs in a spawned task stored in `running_turn`. It builds a
    /// completion request, streams the model's response into the thread, runs
    /// any tool calls the model made, appends their results to the last user
    /// message, and then sends a follow-up request. The loop ends when a
    /// request produces no tool uses, or when the model stops with a refusal
    /// (in which case the messages from this user message onwards are removed).
    ///
    /// Returns the receiving end of an unbounded channel on which the turn
    /// reports its events and errors.
    pub fn send(
        &mut self,
        content: impl Into<MessageContent>,
        cx: &mut Context<Self>,
    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
        let content = content.into();
        let model = self.selected_model.clone();
        log::info!("Thread::send called with model: {:?}", model.name());
        log::debug!("Thread::send content: {:?}", content);

        cx.notify();
        let (events_tx, events_rx) =
            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
        let event_stream = AgentResponseEventStream(events_tx);

        // Remember where this turn starts so a refusal can roll the thread back.
        let user_message_ix = self.messages.len();
        self.messages.push(AgentMessage {
            role: Role::User,
            content: vec![content],
        });
        log::info!("Total messages in thread: {}", self.messages.len());
        self.running_turn = Some(cx.spawn(async move |thread, cx| {
            log::info!("Starting agent turn execution");
            let turn_result = async {
                // Perform one request, then keep looping if the model makes tool calls.
                let mut completion_intent = CompletionIntent::UserPrompt;
                'outer: loop {
                    log::debug!(
                        "Building completion request with intent: {:?}",
                        completion_intent
                    );
                    let request = thread.update(cx, |thread, cx| {
                        thread.build_completion_request(completion_intent, cx)
                    })?;

                    // Stream events, appending to messages and collecting up tool uses.
                    log::info!("Calling model.stream_completion");
                    let mut events = model.stream_completion(request, cx).await?;
                    log::debug!("Stream completion started successfully");
                    let mut tool_uses = FuturesUnordered::new();
                    while let Some(event) = events.next().await {
                        match event {
                            Ok(LanguageModelCompletionEvent::Stop(reason)) => {
                                event_stream.send_stop(reason);
                                // On refusal, drop this turn's messages (including
                                // the user message) and end the turn.
                                if reason == StopReason::Refusal {
                                    thread.update(cx, |thread, _cx| {
                                        thread.messages.truncate(user_message_ix);
                                    })?;
                                    break 'outer;
                                }
                            }
                            Ok(event) => {
                                log::trace!("Received completion event: {:?}", event);
                                // A failed update is ignored here rather than
                                // aborting the turn.
                                thread
                                    .update(cx, |thread, cx| {
                                        tool_uses.extend(thread.handle_streamed_completion_event(
                                            event,
                                            &event_stream,
                                            cx,
                                        ));
                                    })
                                    .ok();
                            }
                            Err(error) => {
                                log::error!("Error in completion stream: {:?}", error);
                                event_stream.send_error(error);
                                // NOTE(review): this only leaves the streaming loop.
                                // Tool uses collected before the error are still run
                                // below and a follow-up request is sent — confirm
                                // that is intended.
                                break;
                            }
                        }
                    }

                    // If there are no tool uses, the turn is done.
                    if tool_uses.is_empty() {
                        log::info!("No tool uses found, completing turn");
                        break;
                    }
                    log::info!("Found {} tool uses to execute", tool_uses.len());

                    // As tool results trickle in, insert them in the last user
                    // message so that they can be sent on the next tick of the
                    // agentic loop.
                    while let Some(tool_result) = tool_uses.next().await {
                        log::info!("Tool finished {:?}", tool_result);

                        // Report the final status of the tool call to the caller.
                        event_stream.update_tool_call_fields(
                            &tool_result.tool_use_id,
                            acp::ToolCallUpdateFields {
                                status: Some(if tool_result.is_error {
                                    acp::ToolCallStatus::Failed
                                } else {
                                    acp::ToolCallStatus::Completed
                                }),
                                ..Default::default()
                            },
                        );
                        thread
                            .update(cx, |thread, _cx| {
                                thread.pending_tool_uses.remove(&tool_result.tool_use_id);
                                thread
                                    .last_user_message()
                                    .content
                                    .push(MessageContent::ToolResult(tool_result));
                            })
                            .ok();
                    }

                    completion_intent = CompletionIntent::ToolResults;
                }

                Ok(())
            }
            .await;

            // Errors that aborted the turn are forwarded to the caller as well.
            if let Err(error) = turn_result {
                log::error!("Turn execution failed: {:?}", error);
                event_stream.send_error(error);
            } else {
                log::info!("Turn execution completed successfully");
            }
        }));
        events_rx
    }
 328
 329    pub fn build_system_message(&self) -> AgentMessage {
 330        log::debug!("Building system message");
 331        let prompt = SystemPromptTemplate {
 332            project: &self.project_context.borrow(),
 333            available_tools: self.tools.keys().cloned().collect(),
 334        }
 335        .render(&self.templates)
 336        .context("failed to build system prompt")
 337        .expect("Invalid template");
 338        log::debug!("System message built");
 339        AgentMessage {
 340            role: Role::System,
 341            content: vec![prompt.into()],
 342        }
 343    }
 344
 345    /// A helper method that's called on every streamed completion event.
 346    /// Returns an optional tool result task, which the main agentic loop in
 347    /// send will send back to the model when it resolves.
 348    fn handle_streamed_completion_event(
 349        &mut self,
 350        event: LanguageModelCompletionEvent,
 351        event_stream: &AgentResponseEventStream,
 352        cx: &mut Context<Self>,
 353    ) -> Option<Task<LanguageModelToolResult>> {
 354        log::trace!("Handling streamed completion event: {:?}", event);
 355        use LanguageModelCompletionEvent::*;
 356
 357        match event {
 358            StartMessage { .. } => {
 359                self.messages.push(AgentMessage {
 360                    role: Role::Assistant,
 361                    content: Vec::new(),
 362                });
 363            }
 364            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
 365            Thinking { text, signature } => {
 366                self.handle_thinking_event(text, signature, event_stream, cx)
 367            }
 368            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
 369            ToolUse(tool_use) => {
 370                return self.handle_tool_use_event(tool_use, event_stream, cx);
 371            }
 372            ToolUseJsonParseError {
 373                id,
 374                tool_name,
 375                raw_input,
 376                json_parse_error,
 377            } => {
 378                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
 379                    id,
 380                    tool_name,
 381                    raw_input,
 382                    json_parse_error,
 383                )));
 384            }
 385            UsageUpdate(_) | StatusUpdate(_) => {}
 386            Stop(_) => unreachable!(),
 387        }
 388
 389        None
 390    }
 391
 392    fn handle_text_event(
 393        &mut self,
 394        new_text: String,
 395        events_stream: &AgentResponseEventStream,
 396        cx: &mut Context<Self>,
 397    ) {
 398        events_stream.send_text(&new_text);
 399
 400        let last_message = self.last_assistant_message();
 401        if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
 402            text.push_str(&new_text);
 403        } else {
 404            last_message.content.push(MessageContent::Text(new_text));
 405        }
 406
 407        cx.notify();
 408    }
 409
 410    fn handle_thinking_event(
 411        &mut self,
 412        new_text: String,
 413        new_signature: Option<String>,
 414        event_stream: &AgentResponseEventStream,
 415        cx: &mut Context<Self>,
 416    ) {
 417        event_stream.send_thinking(&new_text);
 418
 419        let last_message = self.last_assistant_message();
 420        if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
 421        {
 422            text.push_str(&new_text);
 423            *signature = new_signature.or(signature.take());
 424        } else {
 425            last_message.content.push(MessageContent::Thinking {
 426                text: new_text,
 427                signature: new_signature,
 428            });
 429        }
 430
 431        cx.notify();
 432    }
 433
 434    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
 435        let last_message = self.last_assistant_message();
 436        last_message
 437            .content
 438            .push(MessageContent::RedactedThinking(data));
 439        cx.notify();
 440    }
 441
 442    fn handle_tool_use_event(
 443        &mut self,
 444        tool_use: LanguageModelToolUse,
 445        event_stream: &AgentResponseEventStream,
 446        cx: &mut Context<Self>,
 447    ) -> Option<Task<LanguageModelToolResult>> {
 448        cx.notify();
 449
 450        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
 451
 452        self.pending_tool_uses
 453            .insert(tool_use.id.clone(), tool_use.clone());
 454        let last_message = self.last_assistant_message();
 455
 456        // Ensure the last message ends in the current tool use
 457        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
 458            if let MessageContent::ToolUse(last_tool_use) = content {
 459                if last_tool_use.id == tool_use.id {
 460                    *last_tool_use = tool_use.clone();
 461                    false
 462                } else {
 463                    true
 464                }
 465            } else {
 466                true
 467            }
 468        });
 469
 470        let mut title = SharedString::from(&tool_use.name);
 471        let mut kind = acp::ToolKind::Other;
 472        if let Some(tool) = tool.as_ref() {
 473            title = tool.initial_title(tool_use.input.clone());
 474            kind = tool.kind();
 475        }
 476
 477        if push_new_tool_use {
 478            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
 479            last_message
 480                .content
 481                .push(MessageContent::ToolUse(tool_use.clone()));
 482        } else {
 483            event_stream.update_tool_call_fields(
 484                &tool_use.id,
 485                acp::ToolCallUpdateFields {
 486                    title: Some(title.into()),
 487                    kind: Some(kind),
 488                    raw_input: Some(tool_use.input.clone()),
 489                    ..Default::default()
 490                },
 491            );
 492        }
 493
 494        if !tool_use.is_input_complete {
 495            return None;
 496        }
 497
 498        let Some(tool) = tool else {
 499            let content = format!("No tool named {} exists", tool_use.name);
 500            return Some(Task::ready(LanguageModelToolResult {
 501                content: LanguageModelToolResultContent::Text(Arc::from(content)),
 502                tool_use_id: tool_use.id,
 503                tool_name: tool_use.name,
 504                is_error: true,
 505                output: None,
 506            }));
 507        };
 508
 509        let tool_event_stream =
 510            ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
 511        tool_event_stream.update_fields(acp::ToolCallUpdateFields {
 512            status: Some(acp::ToolCallStatus::InProgress),
 513            ..Default::default()
 514        });
 515        let supports_images = self.selected_model.supports_images();
 516        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
 517        Some(cx.foreground_executor().spawn(async move {
 518            let tool_result = tool_result.await.and_then(|output| {
 519                if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
 520                    if !supports_images {
 521                        return Err(anyhow!(
 522                            "Attempted to read an image, but this model doesn't support it.",
 523                        ));
 524                    }
 525                }
 526                Ok(output)
 527            });
 528
 529            match tool_result {
 530                Ok(output) => LanguageModelToolResult {
 531                    tool_use_id: tool_use.id,
 532                    tool_name: tool_use.name,
 533                    is_error: false,
 534                    content: output.llm_output,
 535                    output: Some(output.raw_output),
 536                },
 537                Err(error) => LanguageModelToolResult {
 538                    tool_use_id: tool_use.id,
 539                    tool_name: tool_use.name,
 540                    is_error: true,
 541                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
 542                    output: None,
 543                },
 544            }
 545        }))
 546    }
 547
 548    fn handle_tool_use_json_parse_error_event(
 549        &mut self,
 550        tool_use_id: LanguageModelToolUseId,
 551        tool_name: Arc<str>,
 552        raw_input: Arc<str>,
 553        json_parse_error: String,
 554    ) -> LanguageModelToolResult {
 555        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
 556        LanguageModelToolResult {
 557            tool_use_id,
 558            tool_name,
 559            is_error: true,
 560            content: LanguageModelToolResultContent::Text(tool_output.into()),
 561            output: Some(serde_json::Value::String(raw_input.to_string())),
 562        }
 563    }
 564
 565    /// Guarantees the last message is from the assistant and returns a mutable reference.
 566    fn last_assistant_message(&mut self) -> &mut AgentMessage {
 567        if self
 568            .messages
 569            .last()
 570            .map_or(true, |m| m.role != Role::Assistant)
 571        {
 572            self.messages.push(AgentMessage {
 573                role: Role::Assistant,
 574                content: Vec::new(),
 575            });
 576        }
 577        self.messages.last_mut().unwrap()
 578    }
 579
 580    /// Guarantees the last message is from the user and returns a mutable reference.
 581    fn last_user_message(&mut self) -> &mut AgentMessage {
 582        if self.messages.last().map_or(true, |m| m.role != Role::User) {
 583            self.messages.push(AgentMessage {
 584                role: Role::User,
 585                content: Vec::new(),
 586            });
 587        }
 588        self.messages.last_mut().unwrap()
 589    }
 590
 591    pub(crate) fn build_completion_request(
 592        &self,
 593        completion_intent: CompletionIntent,
 594        cx: &mut App,
 595    ) -> LanguageModelRequest {
 596        log::debug!("Building completion request");
 597        log::debug!("Completion intent: {:?}", completion_intent);
 598        log::debug!("Completion mode: {:?}", self.completion_mode);
 599
 600        let messages = self.build_request_messages();
 601        log::info!("Request will include {} messages", messages.len());
 602
 603        let tools: Vec<LanguageModelRequestTool> = self
 604            .tools
 605            .values()
 606            .filter_map(|tool| {
 607                let tool_name = tool.name().to_string();
 608                log::trace!("Including tool: {}", tool_name);
 609                Some(LanguageModelRequestTool {
 610                    name: tool_name,
 611                    description: tool.description(cx).to_string(),
 612                    input_schema: tool
 613                        .input_schema(self.selected_model.tool_input_format())
 614                        .log_err()?,
 615                })
 616            })
 617            .collect();
 618
 619        log::info!("Request includes {} tools", tools.len());
 620
 621        let request = LanguageModelRequest {
 622            thread_id: None,
 623            prompt_id: None,
 624            intent: Some(completion_intent),
 625            mode: Some(self.completion_mode),
 626            messages,
 627            tools,
 628            tool_choice: None,
 629            stop: Vec::new(),
 630            temperature: None,
 631            thinking_allowed: true,
 632        };
 633
 634        log::debug!("Completion request built successfully");
 635        request
 636    }
 637
 638    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
 639        log::trace!(
 640            "Building request messages from {} thread messages",
 641            self.messages.len()
 642        );
 643
 644        let messages = Some(self.build_system_message())
 645            .iter()
 646            .chain(self.messages.iter())
 647            .map(|message| {
 648                log::trace!(
 649                    "  - {} message with {} content items",
 650                    match message.role {
 651                        Role::System => "System",
 652                        Role::User => "User",
 653                        Role::Assistant => "Assistant",
 654                    },
 655                    message.content.len()
 656                );
 657                LanguageModelRequestMessage {
 658                    role: message.role,
 659                    content: message.content.clone(),
 660                    cache: false,
 661                }
 662            })
 663            .collect();
 664        messages
 665    }
 666
 667    pub fn to_markdown(&self) -> String {
 668        let mut markdown = String::new();
 669        for message in &self.messages {
 670            markdown.push_str(&message.to_markdown());
 671        }
 672        markdown
 673    }
 674}
 675
 676pub trait AgentTool
 677where
 678    Self: 'static + Sized,
 679{
 680    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
 681    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
 682
 683    fn name(&self) -> SharedString;
 684
 685    fn description(&self, _cx: &mut App) -> SharedString {
 686        let schema = schemars::schema_for!(Self::Input);
 687        SharedString::new(
 688            schema
 689                .get("description")
 690                .and_then(|description| description.as_str())
 691                .unwrap_or_default(),
 692        )
 693    }
 694
 695    fn kind(&self) -> acp::ToolKind;
 696
 697    /// The initial tool title to display. Can be updated during the tool run.
 698    fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
 699
 700    /// Returns the JSON schema that describes the tool's input.
 701    fn input_schema(&self) -> Schema {
 702        schemars::schema_for!(Self::Input)
 703    }
 704
 705    /// Runs the tool with the provided input.
 706    fn run(
 707        self: Arc<Self>,
 708        input: Self::Input,
 709        event_stream: ToolCallEventStream,
 710        cx: &mut App,
 711    ) -> Task<Result<Self::Output>>;
 712
 713    fn erase(self) -> Arc<dyn AnyAgentTool> {
 714        Arc::new(Erased(Arc::new(self)))
 715    }
 716}
 717
/// Wrapper produced by `AgentTool::erase`; `Erased<Arc<T>>` implements
/// [`AnyAgentTool`] for any `T: AgentTool`.
pub struct Erased<T>(T);
 719
/// The output of a type-erased tool run.
pub struct AgentToolOutput {
    /// The output converted into tool-result content for the model.
    llm_output: LanguageModelToolResultContent,
    /// The tool's typed output serialized to JSON.
    raw_output: serde_json::Value,
}
 724
/// The type-erased counterpart of [`AgentTool`]: inputs and schemas are plain
/// JSON values, so tools with different input/output types can be stored
/// together.
pub trait AnyAgentTool {
    /// The name the tool is registered under.
    fn name(&self) -> SharedString;
    /// The tool's description, as sent to the model in completion requests.
    fn description(&self, cx: &mut App) -> SharedString;
    /// The kind of tool, as reported in tool-call events.
    fn kind(&self) -> acp::ToolKind;
    /// The initial title to display for a tool call with the given raw input.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// The tool's input schema as JSON, adapted to `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Runs the tool with the given raw JSON input.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
}
 738
 739impl<T> AnyAgentTool for Erased<Arc<T>>
 740where
 741    T: AgentTool,
 742{
 743    fn name(&self) -> SharedString {
 744        self.0.name()
 745    }
 746
 747    fn description(&self, cx: &mut App) -> SharedString {
 748        self.0.description(cx)
 749    }
 750
 751    fn kind(&self) -> agent_client_protocol::ToolKind {
 752        self.0.kind()
 753    }
 754
 755    fn initial_title(&self, input: serde_json::Value) -> SharedString {
 756        let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
 757        self.0.initial_title(parsed_input)
 758    }
 759
 760    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 761        let mut json = serde_json::to_value(self.0.input_schema())?;
 762        adapt_schema_to_format(&mut json, format)?;
 763        Ok(json)
 764    }
 765
 766    fn run(
 767        self: Arc<Self>,
 768        input: serde_json::Value,
 769        event_stream: ToolCallEventStream,
 770        cx: &mut App,
 771    ) -> Task<Result<AgentToolOutput>> {
 772        cx.spawn(async move |cx| {
 773            let input = serde_json::from_value(input)?;
 774            let output = cx
 775                .update(|cx| self.0.clone().run(input, event_stream, cx))?
 776                .await?;
 777            let raw_output = serde_json::to_value(&output)?;
 778            Ok(AgentToolOutput {
 779                llm_output: output.into(),
 780                raw_output,
 781            })
 782        })
 783    }
 784}
 785
/// The sending half of the channel returned by `Thread::send`, with helper
/// methods for emitting each kind of [`AgentResponseEvent`].
#[derive(Clone)]
struct AgentResponseEventStream(
    mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
 790
 791impl AgentResponseEventStream {
 792    fn send_text(&self, text: &str) {
 793        self.0
 794            .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
 795            .ok();
 796    }
 797
 798    fn send_thinking(&self, text: &str) {
 799        self.0
 800            .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
 801            .ok();
 802    }
 803
 804    fn send_tool_call(
 805        &self,
 806        id: &LanguageModelToolUseId,
 807        title: SharedString,
 808        kind: acp::ToolKind,
 809        input: serde_json::Value,
 810    ) {
 811        self.0
 812            .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
 813                id,
 814                title.to_string(),
 815                kind,
 816                input,
 817            ))))
 818            .ok();
 819    }
 820
 821    fn initial_tool_call(
 822        id: &LanguageModelToolUseId,
 823        title: String,
 824        kind: acp::ToolKind,
 825        input: serde_json::Value,
 826    ) -> acp::ToolCall {
 827        acp::ToolCall {
 828            id: acp::ToolCallId(id.to_string().into()),
 829            title,
 830            kind,
 831            status: acp::ToolCallStatus::Pending,
 832            content: vec![],
 833            locations: vec![],
 834            raw_input: Some(input),
 835            raw_output: None,
 836        }
 837    }
 838
 839    fn update_tool_call_fields(
 840        &self,
 841        tool_use_id: &LanguageModelToolUseId,
 842        fields: acp::ToolCallUpdateFields,
 843    ) {
 844        self.0
 845            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
 846                acp::ToolCallUpdate {
 847                    id: acp::ToolCallId(tool_use_id.to_string().into()),
 848                    fields,
 849                }
 850                .into(),
 851            )))
 852            .ok();
 853    }
 854
 855    fn send_stop(&self, reason: StopReason) {
 856        match reason {
 857            StopReason::EndTurn => {
 858                self.0
 859                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
 860                    .ok();
 861            }
 862            StopReason::MaxTokens => {
 863                self.0
 864                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
 865                    .ok();
 866            }
 867            StopReason::Refusal => {
 868                self.0
 869                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
 870                    .ok();
 871            }
 872            StopReason::ToolUse => {}
 873        }
 874    }
 875
 876    fn send_error(&self, error: LanguageModelCompletionError) {
 877        self.0.unbounded_send(Err(error)).ok();
 878    }
 879}
 880
 881#[derive(Clone)]
 882pub struct ToolCallEventStream {
 883    tool_use_id: LanguageModelToolUseId,
 884    kind: acp::ToolKind,
 885    input: serde_json::Value,
 886    stream: AgentResponseEventStream,
 887}
 888
 889impl ToolCallEventStream {
 890    #[cfg(test)]
 891    pub fn test() -> (Self, ToolCallEventStreamReceiver) {
 892        let (events_tx, events_rx) =
 893            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 894
 895        let stream = ToolCallEventStream::new(
 896            &LanguageModelToolUse {
 897                id: "test_id".into(),
 898                name: "test_tool".into(),
 899                raw_input: String::new(),
 900                input: serde_json::Value::Null,
 901                is_input_complete: true,
 902            },
 903            acp::ToolKind::Other,
 904            AgentResponseEventStream(events_tx),
 905        );
 906
 907        (stream, ToolCallEventStreamReceiver(events_rx))
 908    }
 909
 910    fn new(
 911        tool_use: &LanguageModelToolUse,
 912        kind: acp::ToolKind,
 913        stream: AgentResponseEventStream,
 914    ) -> Self {
 915        Self {
 916            tool_use_id: tool_use.id.clone(),
 917            kind,
 918            input: tool_use.input.clone(),
 919            stream,
 920        }
 921    }
 922
 923    pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
 924        self.stream
 925            .update_tool_call_fields(&self.tool_use_id, fields);
 926    }
 927
 928    pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
 929        self.stream
 930            .0
 931            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
 932                acp_thread::ToolCallUpdateDiff {
 933                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
 934                    diff,
 935                }
 936                .into(),
 937            )))
 938            .ok();
 939    }
 940
 941    pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
 942        self.stream
 943            .0
 944            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
 945                acp_thread::ToolCallUpdateTerminal {
 946                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
 947                    terminal,
 948                }
 949                .into(),
 950            )))
 951            .ok();
 952    }
 953
 954    pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
 955        let (response_tx, response_rx) = oneshot::channel();
 956        self.stream
 957            .0
 958            .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
 959                ToolCallAuthorization {
 960                    tool_call: AgentResponseEventStream::initial_tool_call(
 961                        &self.tool_use_id,
 962                        title,
 963                        self.kind.clone(),
 964                        self.input.clone(),
 965                    ),
 966                    options: vec![
 967                        acp::PermissionOption {
 968                            id: acp::PermissionOptionId("always_allow".into()),
 969                            name: "Always Allow".into(),
 970                            kind: acp::PermissionOptionKind::AllowAlways,
 971                        },
 972                        acp::PermissionOption {
 973                            id: acp::PermissionOptionId("allow".into()),
 974                            name: "Allow".into(),
 975                            kind: acp::PermissionOptionKind::AllowOnce,
 976                        },
 977                        acp::PermissionOption {
 978                            id: acp::PermissionOptionId("deny".into()),
 979                            name: "Deny".into(),
 980                            kind: acp::PermissionOptionKind::RejectOnce,
 981                        },
 982                    ],
 983                    response: response_tx,
 984                },
 985            )))
 986            .ok();
 987        async move {
 988            match response_rx.await?.0.as_ref() {
 989                "allow" | "always_allow" => Ok(()),
 990                _ => Err(anyhow!("Permission to run tool denied by user")),
 991            }
 992        }
 993    }
 994}
 995
 996#[cfg(test)]
 997pub struct ToolCallEventStreamReceiver(
 998    mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
 999);
1000
1001#[cfg(test)]
1002impl ToolCallEventStreamReceiver {
1003    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
1004        let event = self.0.next().await;
1005        if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event {
1006            auth
1007        } else {
1008            panic!("Expected ToolCallAuthorization but got: {:?}", event);
1009        }
1010    }
1011
1012    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
1013        let event = self.0.next().await;
1014        if let Some(Ok(AgentResponseEvent::ToolCallUpdate(
1015            acp_thread::ToolCallUpdate::UpdateTerminal(update),
1016        ))) = event
1017        {
1018            update.terminal
1019        } else {
1020            panic!("Expected terminal but got: {:?}", event);
1021        }
1022    }
1023}
1024
/// Gives shared access to the wrapped `UnboundedReceiver` so its methods can
/// be called directly on the test receiver.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;

    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1033
/// Gives mutable access to the wrapped `UnboundedReceiver`.
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}