thread.rs

   1use crate::{SystemPromptTemplate, Template, Templates};
   2use acp_thread::Diff;
   3use agent_client_protocol as acp;
   4use anyhow::{anyhow, Context as _, Result};
   5use assistant_tool::{adapt_schema_to_format, ActionLog};
   6use cloud_llm_client::{CompletionIntent, CompletionMode};
   7use collections::HashMap;
   8use futures::{
   9    channel::{mpsc, oneshot},
  10    stream::FuturesUnordered,
  11};
  12use gpui::{App, Context, Entity, SharedString, Task};
  13use language_model::{
  14    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  15    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  16    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
  17    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
  18};
  19use log;
  20use project::Project;
  21use prompt_store::ProjectContext;
  22use schemars::{JsonSchema, Schema};
  23use serde::{Deserialize, Serialize};
  24use smol::stream::StreamExt;
  25use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
  26use util::{markdown::MarkdownCodeBlock, ResultExt};
  27
  28#[derive(Debug, Clone)]
  29pub struct AgentMessage {
  30    pub role: Role,
  31    pub content: Vec<MessageContent>,
  32}
  33
  34impl AgentMessage {
  35    pub fn to_markdown(&self) -> String {
  36        let mut markdown = format!("## {}\n", self.role);
  37
  38        for content in &self.content {
  39            match content {
  40                MessageContent::Text(text) => {
  41                    markdown.push_str(text);
  42                    markdown.push('\n');
  43                }
  44                MessageContent::Thinking { text, .. } => {
  45                    markdown.push_str("<think>");
  46                    markdown.push_str(text);
  47                    markdown.push_str("</think>\n");
  48                }
  49                MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
  50                MessageContent::Image(_) => {
  51                    markdown.push_str("<image />\n");
  52                }
  53                MessageContent::ToolUse(tool_use) => {
  54                    markdown.push_str(&format!(
  55                        "**Tool Use**: {} (ID: {})\n",
  56                        tool_use.name, tool_use.id
  57                    ));
  58                    markdown.push_str(&format!(
  59                        "{}\n",
  60                        MarkdownCodeBlock {
  61                            tag: "json",
  62                            text: &format!("{:#}", tool_use.input)
  63                        }
  64                    ));
  65                }
  66                MessageContent::ToolResult(tool_result) => {
  67                    markdown.push_str(&format!(
  68                        "**Tool Result**: {} (ID: {})\n\n",
  69                        tool_result.tool_name, tool_result.tool_use_id
  70                    ));
  71                    if tool_result.is_error {
  72                        markdown.push_str("**ERROR:**\n");
  73                    }
  74
  75                    match &tool_result.content {
  76                        LanguageModelToolResultContent::Text(text) => {
  77                            writeln!(markdown, "{text}\n").ok();
  78                        }
  79                        LanguageModelToolResultContent::Image(_) => {
  80                            writeln!(markdown, "<image />\n").ok();
  81                        }
  82                    }
  83
  84                    if let Some(output) = tool_result.output.as_ref() {
  85                        writeln!(
  86                            markdown,
  87                            "**Debug Output**:\n\n```json\n{}\n```\n",
  88                            serde_json::to_string_pretty(output).unwrap()
  89                        )
  90                        .unwrap();
  91                    }
  92                }
  93            }
  94        }
  95
  96        markdown
  97    }
  98}
  99
 100#[derive(Debug)]
 101pub enum AgentResponseEvent {
 102    Text(String),
 103    Thinking(String),
 104    ToolCall(acp::ToolCall),
 105    ToolCallUpdate(acp::ToolCallUpdate),
 106    ToolCallAuthorization(ToolCallAuthorization),
 107    ToolCallDiff(ToolCallDiff),
 108    Stop(acp::StopReason),
 109}
 110
 111#[derive(Debug)]
 112pub struct ToolCallAuthorization {
 113    pub tool_call: acp::ToolCall,
 114    pub options: Vec<acp::PermissionOption>,
 115    pub response: oneshot::Sender<acp::PermissionOptionId>,
 116}
 117
 118#[derive(Debug)]
 119pub struct ToolCallDiff {
 120    pub tool_call_id: acp::ToolCallId,
 121    pub diff: Entity<acp_thread::Diff>,
 122}
 123
 124pub struct Thread {
 125    messages: Vec<AgentMessage>,
 126    completion_mode: CompletionMode,
 127    /// Holds the task that handles agent interaction until the end of the turn.
 128    /// Survives across multiple requests as the model performs tool calls and
 129    /// we run tools, report their results.
 130    running_turn: Option<Task<()>>,
 131    pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
 132    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 133    project_context: Rc<RefCell<ProjectContext>>,
 134    templates: Arc<Templates>,
 135    pub selected_model: Arc<dyn LanguageModel>,
 136    project: Entity<Project>,
 137    action_log: Entity<ActionLog>,
 138}
 139
 140impl Thread {
 141    pub fn new(
 142        project: Entity<Project>,
 143        project_context: Rc<RefCell<ProjectContext>>,
 144        action_log: Entity<ActionLog>,
 145        templates: Arc<Templates>,
 146        default_model: Arc<dyn LanguageModel>,
 147    ) -> Self {
 148        Self {
 149            messages: Vec::new(),
 150            completion_mode: CompletionMode::Normal,
 151            running_turn: None,
 152            pending_tool_uses: HashMap::default(),
 153            tools: BTreeMap::default(),
 154            project_context,
 155            templates,
 156            selected_model: default_model,
 157            project,
 158            action_log,
 159        }
 160    }
 161
 162    pub fn project(&self) -> &Entity<Project> {
 163        &self.project
 164    }
 165
 166    pub fn action_log(&self) -> &Entity<ActionLog> {
 167        &self.action_log
 168    }
 169
 170    pub fn set_mode(&mut self, mode: CompletionMode) {
 171        self.completion_mode = mode;
 172    }
 173
 174    pub fn messages(&self) -> &[AgentMessage] {
 175        &self.messages
 176    }
 177
 178    pub fn add_tool(&mut self, tool: impl AgentTool) {
 179        self.tools.insert(tool.name(), tool.erase());
 180    }
 181
 182    pub fn remove_tool(&mut self, name: &str) -> bool {
 183        self.tools.remove(name).is_some()
 184    }
 185
 186    pub fn cancel(&mut self) {
 187        self.running_turn.take();
 188
 189        let tool_results = self
 190            .pending_tool_uses
 191            .drain()
 192            .map(|(tool_use_id, tool_use)| {
 193                MessageContent::ToolResult(LanguageModelToolResult {
 194                    tool_use_id,
 195                    tool_name: tool_use.name.clone(),
 196                    is_error: true,
 197                    content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
 198                    output: None,
 199                })
 200            })
 201            .collect::<Vec<_>>();
 202        self.last_user_message().content.extend(tool_results);
 203    }
 204
 205    /// Sending a message results in the model streaming a response, which could include tool calls.
 206    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
 207    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
 208    pub fn send(
 209        &mut self,
 210        model: Arc<dyn LanguageModel>,
 211        content: impl Into<MessageContent>,
 212        cx: &mut Context<Self>,
 213    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
 214        let content = content.into();
 215        log::info!("Thread::send called with model: {:?}", model.name());
 216        log::debug!("Thread::send content: {:?}", content);
 217
 218        cx.notify();
 219        let (events_tx, events_rx) =
 220            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 221        let event_stream = AgentResponseEventStream(events_tx);
 222
 223        let user_message_ix = self.messages.len();
 224        self.messages.push(AgentMessage {
 225            role: Role::User,
 226            content: vec![content],
 227        });
 228        log::info!("Total messages in thread: {}", self.messages.len());
 229        self.running_turn = Some(cx.spawn(async move |thread, cx| {
 230            log::info!("Starting agent turn execution");
 231            let turn_result = async {
 232                // Perform one request, then keep looping if the model makes tool calls.
 233                let mut completion_intent = CompletionIntent::UserPrompt;
 234                'outer: loop {
 235                    log::debug!(
 236                        "Building completion request with intent: {:?}",
 237                        completion_intent
 238                    );
 239                    let request = thread.update(cx, |thread, cx| {
 240                        thread.build_completion_request(completion_intent, cx)
 241                    })?;
 242
 243                    // println!(
 244                    //     "request: {}",
 245                    //     serde_json::to_string_pretty(&request).unwrap()
 246                    // );
 247
 248                    // Stream events, appending to messages and collecting up tool uses.
 249                    log::info!("Calling model.stream_completion");
 250                    let mut events = model.stream_completion(request, cx).await?;
 251                    log::debug!("Stream completion started successfully");
 252                    let mut tool_uses = FuturesUnordered::new();
 253                    while let Some(event) = events.next().await {
 254                        match event {
 255                            Ok(LanguageModelCompletionEvent::Stop(reason)) => {
 256                                event_stream.send_stop(reason);
 257                                if reason == StopReason::Refusal {
 258                                    thread.update(cx, |thread, _cx| {
 259                                        thread.messages.truncate(user_message_ix);
 260                                    })?;
 261                                    break 'outer;
 262                                }
 263                            }
 264                            Ok(event) => {
 265                                log::trace!("Received completion event: {:?}", event);
 266                                thread
 267                                    .update(cx, |thread, cx| {
 268                                        tool_uses.extend(thread.handle_streamed_completion_event(
 269                                            event,
 270                                            &event_stream,
 271                                            cx,
 272                                        ));
 273                                    })
 274                                    .ok();
 275                            }
 276                            Err(error) => {
 277                                log::error!("Error in completion stream: {:?}", error);
 278                                event_stream.send_error(error);
 279                                break;
 280                            }
 281                        }
 282                    }
 283
 284                    // If there are no tool uses, the turn is done.
 285                    if tool_uses.is_empty() {
 286                        log::info!("No tool uses found, completing turn");
 287                        break;
 288                    }
 289                    log::info!("Found {} tool uses to execute", tool_uses.len());
 290
 291                    // As tool results trickle in, insert them in the last user
 292                    // message so that they can be sent on the next tick of the
 293                    // agentic loop.
 294                    while let Some(tool_result) = tool_uses.next().await {
 295                        log::info!("Tool finished {:?}", tool_result);
 296
 297                        event_stream.send_tool_call_update(
 298                            &tool_result.tool_use_id,
 299                            acp::ToolCallUpdateFields {
 300                                status: Some(if tool_result.is_error {
 301                                    acp::ToolCallStatus::Failed
 302                                } else {
 303                                    acp::ToolCallStatus::Completed
 304                                }),
 305                                ..Default::default()
 306                            },
 307                        );
 308                        thread
 309                            .update(cx, |thread, _cx| {
 310                                thread.pending_tool_uses.remove(&tool_result.tool_use_id);
 311                                thread
 312                                    .last_user_message()
 313                                    .content
 314                                    .push(MessageContent::ToolResult(tool_result));
 315                            })
 316                            .ok();
 317                    }
 318
 319                    completion_intent = CompletionIntent::ToolResults;
 320                }
 321
 322                Ok(())
 323            }
 324            .await;
 325
 326            if let Err(error) = turn_result {
 327                log::error!("Turn execution failed: {:?}", error);
 328                event_stream.send_error(error);
 329            } else {
 330                log::info!("Turn execution completed successfully");
 331            }
 332        }));
 333        events_rx
 334    }
 335
 336    pub fn build_system_message(&self) -> AgentMessage {
 337        log::debug!("Building system message");
 338        let prompt = SystemPromptTemplate {
 339            project: &self.project_context.borrow(),
 340            available_tools: self.tools.keys().cloned().collect(),
 341        }
 342        .render(&self.templates)
 343        .context("failed to build system prompt")
 344        .expect("Invalid template");
 345        log::debug!("System message built");
 346        AgentMessage {
 347            role: Role::System,
 348            content: vec![prompt.into()],
 349        }
 350    }
 351
 352    /// A helper method that's called on every streamed completion event.
 353    /// Returns an optional tool result task, which the main agentic loop in
 354    /// send will send back to the model when it resolves.
 355    fn handle_streamed_completion_event(
 356        &mut self,
 357        event: LanguageModelCompletionEvent,
 358        event_stream: &AgentResponseEventStream,
 359        cx: &mut Context<Self>,
 360    ) -> Option<Task<LanguageModelToolResult>> {
 361        log::trace!("Handling streamed completion event: {:?}", event);
 362        use LanguageModelCompletionEvent::*;
 363
 364        match event {
 365            StartMessage { .. } => {
 366                self.messages.push(AgentMessage {
 367                    role: Role::Assistant,
 368                    content: Vec::new(),
 369                });
 370            }
 371            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
 372            Thinking { text, signature } => {
 373                self.handle_thinking_event(text, signature, event_stream, cx)
 374            }
 375            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
 376            ToolUse(tool_use) => {
 377                return self.handle_tool_use_event(tool_use, event_stream, cx);
 378            }
 379            ToolUseJsonParseError {
 380                id,
 381                tool_name,
 382                raw_input,
 383                json_parse_error,
 384            } => {
 385                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
 386                    id,
 387                    tool_name,
 388                    raw_input,
 389                    json_parse_error,
 390                )));
 391            }
 392            UsageUpdate(_) | StatusUpdate(_) => {}
 393            Stop(_) => unreachable!(),
 394        }
 395
 396        None
 397    }
 398
 399    fn handle_text_event(
 400        &mut self,
 401        new_text: String,
 402        events_stream: &AgentResponseEventStream,
 403        cx: &mut Context<Self>,
 404    ) {
 405        events_stream.send_text(&new_text);
 406
 407        let last_message = self.last_assistant_message();
 408        if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
 409            text.push_str(&new_text);
 410        } else {
 411            last_message.content.push(MessageContent::Text(new_text));
 412        }
 413
 414        cx.notify();
 415    }
 416
 417    fn handle_thinking_event(
 418        &mut self,
 419        new_text: String,
 420        new_signature: Option<String>,
 421        event_stream: &AgentResponseEventStream,
 422        cx: &mut Context<Self>,
 423    ) {
 424        event_stream.send_thinking(&new_text);
 425
 426        let last_message = self.last_assistant_message();
 427        if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
 428        {
 429            text.push_str(&new_text);
 430            *signature = new_signature.or(signature.take());
 431        } else {
 432            last_message.content.push(MessageContent::Thinking {
 433                text: new_text,
 434                signature: new_signature,
 435            });
 436        }
 437
 438        cx.notify();
 439    }
 440
 441    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
 442        let last_message = self.last_assistant_message();
 443        last_message
 444            .content
 445            .push(MessageContent::RedactedThinking(data));
 446        cx.notify();
 447    }
 448
 449    fn handle_tool_use_event(
 450        &mut self,
 451        tool_use: LanguageModelToolUse,
 452        event_stream: &AgentResponseEventStream,
 453        cx: &mut Context<Self>,
 454    ) -> Option<Task<LanguageModelToolResult>> {
 455        cx.notify();
 456
 457        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
 458
 459        self.pending_tool_uses
 460            .insert(tool_use.id.clone(), tool_use.clone());
 461        let last_message = self.last_assistant_message();
 462
 463        // Ensure the last message ends in the current tool use
 464        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
 465            if let MessageContent::ToolUse(last_tool_use) = content {
 466                if last_tool_use.id == tool_use.id {
 467                    *last_tool_use = tool_use.clone();
 468                    false
 469                } else {
 470                    true
 471                }
 472            } else {
 473                true
 474            }
 475        });
 476
 477        let mut title = SharedString::from(&tool_use.name);
 478        let mut kind = acp::ToolKind::Other;
 479        if let Some(tool) = tool.as_ref() {
 480            if let Ok(initial_title) = tool.initial_title(tool_use.input.clone()) {
 481                title = initial_title;
 482            }
 483            kind = tool.kind();
 484        }
 485
 486        if push_new_tool_use {
 487            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
 488            last_message
 489                .content
 490                .push(MessageContent::ToolUse(tool_use.clone()));
 491        } else {
 492            event_stream.send_tool_call_update(
 493                &tool_use.id,
 494                acp::ToolCallUpdateFields {
 495                    title: Some(title.into()),
 496                    kind: Some(kind),
 497                    raw_input: Some(tool_use.input.clone()),
 498                    ..Default::default()
 499                },
 500            );
 501        }
 502
 503        if !tool_use.is_input_complete {
 504            return None;
 505        }
 506
 507        let Some(tool) = tool else {
 508            let content = format!("No tool named {} exists", tool_use.name);
 509            return Some(Task::ready(LanguageModelToolResult {
 510                content: LanguageModelToolResultContent::Text(Arc::from(content)),
 511                tool_use_id: tool_use.id,
 512                tool_name: tool_use.name,
 513                is_error: true,
 514                output: None,
 515            }));
 516        };
 517
 518        let tool_event_stream =
 519            ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
 520        tool_event_stream.send_update(acp::ToolCallUpdateFields {
 521            status: Some(acp::ToolCallStatus::InProgress),
 522            ..Default::default()
 523        });
 524        let supports_images = self.selected_model.supports_images();
 525        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
 526        Some(cx.foreground_executor().spawn(async move {
 527            let tool_result = tool_result.await.and_then(|output| {
 528                if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
 529                    if !supports_images {
 530                        return Err(anyhow!(
 531                            "Attempted to read an image, but this model doesn't support it.",
 532                        ));
 533                    }
 534                }
 535                Ok(output)
 536            });
 537
 538            match tool_result {
 539                Ok(output) => LanguageModelToolResult {
 540                    tool_use_id: tool_use.id,
 541                    tool_name: tool_use.name,
 542                    is_error: false,
 543                    content: output.llm_output,
 544                    output: Some(output.raw_output),
 545                },
 546                Err(error) => LanguageModelToolResult {
 547                    tool_use_id: tool_use.id,
 548                    tool_name: tool_use.name,
 549                    is_error: true,
 550                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
 551                    output: None,
 552                },
 553            }
 554        }))
 555    }
 556
 557    fn handle_tool_use_json_parse_error_event(
 558        &mut self,
 559        tool_use_id: LanguageModelToolUseId,
 560        tool_name: Arc<str>,
 561        raw_input: Arc<str>,
 562        json_parse_error: String,
 563    ) -> LanguageModelToolResult {
 564        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
 565        LanguageModelToolResult {
 566            tool_use_id,
 567            tool_name,
 568            is_error: true,
 569            content: LanguageModelToolResultContent::Text(tool_output.into()),
 570            output: Some(serde_json::Value::String(raw_input.to_string())),
 571        }
 572    }
 573
 574    /// Guarantees the last message is from the assistant and returns a mutable reference.
 575    fn last_assistant_message(&mut self) -> &mut AgentMessage {
 576        if self
 577            .messages
 578            .last()
 579            .map_or(true, |m| m.role != Role::Assistant)
 580        {
 581            self.messages.push(AgentMessage {
 582                role: Role::Assistant,
 583                content: Vec::new(),
 584            });
 585        }
 586        self.messages.last_mut().unwrap()
 587    }
 588
 589    /// Guarantees the last message is from the user and returns a mutable reference.
 590    fn last_user_message(&mut self) -> &mut AgentMessage {
 591        if self.messages.last().map_or(true, |m| m.role != Role::User) {
 592            self.messages.push(AgentMessage {
 593                role: Role::User,
 594                content: Vec::new(),
 595            });
 596        }
 597        self.messages.last_mut().unwrap()
 598    }
 599
 600    pub(crate) fn build_completion_request(
 601        &self,
 602        completion_intent: CompletionIntent,
 603        cx: &mut App,
 604    ) -> LanguageModelRequest {
 605        log::debug!("Building completion request");
 606        log::debug!("Completion intent: {:?}", completion_intent);
 607        log::debug!("Completion mode: {:?}", self.completion_mode);
 608
 609        let messages = self.build_request_messages();
 610        log::info!("Request will include {} messages", messages.len());
 611
 612        let tools: Vec<LanguageModelRequestTool> = self
 613            .tools
 614            .values()
 615            .filter_map(|tool| {
 616                let tool_name = tool.name().to_string();
 617                log::trace!("Including tool: {}", tool_name);
 618                Some(LanguageModelRequestTool {
 619                    name: tool_name,
 620                    description: tool.description(cx).to_string(),
 621                    input_schema: tool
 622                        .input_schema(self.selected_model.tool_input_format())
 623                        .log_err()?,
 624                })
 625            })
 626            .collect();
 627
 628        log::info!("Request includes {} tools", tools.len());
 629
 630        let request = LanguageModelRequest {
 631            thread_id: None,
 632            prompt_id: None,
 633            intent: Some(completion_intent),
 634            mode: Some(self.completion_mode),
 635            messages,
 636            tools,
 637            tool_choice: None,
 638            stop: Vec::new(),
 639            temperature: None,
 640            thinking_allowed: true,
 641        };
 642
 643        log::debug!("Completion request built successfully");
 644        request
 645    }
 646
 647    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
 648        log::trace!(
 649            "Building request messages from {} thread messages",
 650            self.messages.len()
 651        );
 652
 653        let messages = Some(self.build_system_message())
 654            .iter()
 655            .chain(self.messages.iter())
 656            .map(|message| {
 657                log::trace!(
 658                    "  - {} message with {} content items",
 659                    match message.role {
 660                        Role::System => "System",
 661                        Role::User => "User",
 662                        Role::Assistant => "Assistant",
 663                    },
 664                    message.content.len()
 665                );
 666                LanguageModelRequestMessage {
 667                    role: message.role,
 668                    content: message.content.clone(),
 669                    cache: false,
 670                }
 671            })
 672            .collect();
 673        messages
 674    }
 675
 676    pub fn to_markdown(&self) -> String {
 677        let mut markdown = String::new();
 678        for message in &self.messages {
 679            markdown.push_str(&message.to_markdown());
 680        }
 681        markdown
 682    }
 683}
 684
 685pub trait AgentTool
 686where
 687    Self: 'static + Sized,
 688{
 689    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
 690    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
 691
 692    fn name(&self) -> SharedString;
 693
 694    fn description(&self, _cx: &mut App) -> SharedString {
 695        let schema = schemars::schema_for!(Self::Input);
 696        SharedString::new(
 697            schema
 698                .get("description")
 699                .and_then(|description| description.as_str())
 700                .unwrap_or_default(),
 701        )
 702    }
 703
 704    fn kind(&self) -> acp::ToolKind;
 705
 706    /// The initial tool title to display. Can be updated during the tool run.
 707    fn initial_title(&self, input: Self::Input) -> SharedString;
 708
 709    /// Returns the JSON schema that describes the tool's input.
 710    fn input_schema(&self) -> Schema {
 711        schemars::schema_for!(Self::Input)
 712    }
 713
 714    /// Runs the tool with the provided input.
 715    fn run(
 716        self: Arc<Self>,
 717        input: Self::Input,
 718        event_stream: ToolCallEventStream,
 719        cx: &mut App,
 720    ) -> Task<Result<Self::Output>>;
 721
 722    fn erase(self) -> Arc<dyn AnyAgentTool> {
 723        Arc::new(Erased(Arc::new(self)))
 724    }
 725}
 726
 727pub struct Erased<T>(T);
 728
 729pub struct AgentToolOutput {
 730    llm_output: LanguageModelToolResultContent,
 731    raw_output: serde_json::Value,
 732}
 733
 734pub trait AnyAgentTool {
 735    fn name(&self) -> SharedString;
 736    fn description(&self, cx: &mut App) -> SharedString;
 737    fn kind(&self) -> acp::ToolKind;
 738    fn initial_title(&self, input: serde_json::Value) -> Result<SharedString>;
 739    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
 740    fn run(
 741        self: Arc<Self>,
 742        input: serde_json::Value,
 743        event_stream: ToolCallEventStream,
 744        cx: &mut App,
 745    ) -> Task<Result<AgentToolOutput>>;
 746}
 747
 748impl<T> AnyAgentTool for Erased<Arc<T>>
 749where
 750    T: AgentTool,
 751{
 752    fn name(&self) -> SharedString {
 753        self.0.name()
 754    }
 755
 756    fn description(&self, cx: &mut App) -> SharedString {
 757        self.0.description(cx)
 758    }
 759
 760    fn kind(&self) -> agent_client_protocol::ToolKind {
 761        self.0.kind()
 762    }
 763
 764    fn initial_title(&self, input: serde_json::Value) -> Result<SharedString> {
 765        let parsed_input = serde_json::from_value(input)?;
 766        Ok(self.0.initial_title(parsed_input))
 767    }
 768
 769    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 770        let mut json = serde_json::to_value(self.0.input_schema())?;
 771        adapt_schema_to_format(&mut json, format)?;
 772        Ok(json)
 773    }
 774
 775    fn run(
 776        self: Arc<Self>,
 777        input: serde_json::Value,
 778        event_stream: ToolCallEventStream,
 779        cx: &mut App,
 780    ) -> Task<Result<AgentToolOutput>> {
 781        cx.spawn(async move |cx| {
 782            let input = serde_json::from_value(input)?;
 783            let output = cx
 784                .update(|cx| self.0.clone().run(input, event_stream, cx))?
 785                .await?;
 786            let raw_output = serde_json::to_value(&output)?;
 787            Ok(AgentToolOutput {
 788                llm_output: output.into(),
 789                raw_output,
 790            })
 791        })
 792    }
 793}
 794
 795#[derive(Clone)]
 796struct AgentResponseEventStream(
 797    mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
 798);
 799
 800impl AgentResponseEventStream {
 801    fn send_text(&self, text: &str) {
 802        self.0
 803            .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
 804            .ok();
 805    }
 806
 807    fn send_thinking(&self, text: &str) {
 808        self.0
 809            .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
 810            .ok();
 811    }
 812
 813    fn authorize_tool_call(
 814        &self,
 815        id: &LanguageModelToolUseId,
 816        title: String,
 817        kind: acp::ToolKind,
 818        input: serde_json::Value,
 819    ) -> impl use<> + Future<Output = Result<()>> {
 820        let (response_tx, response_rx) = oneshot::channel();
 821        self.0
 822            .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
 823                ToolCallAuthorization {
 824                    tool_call: Self::initial_tool_call(id, title, kind, input),
 825                    options: vec![
 826                        acp::PermissionOption {
 827                            id: acp::PermissionOptionId("always_allow".into()),
 828                            name: "Always Allow".into(),
 829                            kind: acp::PermissionOptionKind::AllowAlways,
 830                        },
 831                        acp::PermissionOption {
 832                            id: acp::PermissionOptionId("allow".into()),
 833                            name: "Allow".into(),
 834                            kind: acp::PermissionOptionKind::AllowOnce,
 835                        },
 836                        acp::PermissionOption {
 837                            id: acp::PermissionOptionId("deny".into()),
 838                            name: "Deny".into(),
 839                            kind: acp::PermissionOptionKind::RejectOnce,
 840                        },
 841                    ],
 842                    response: response_tx,
 843                },
 844            )))
 845            .ok();
 846        async move {
 847            match response_rx.await?.0.as_ref() {
 848                "allow" | "always_allow" => Ok(()),
 849                _ => Err(anyhow!("Permission to run tool denied by user")),
 850            }
 851        }
 852    }
 853
 854    fn send_tool_call(
 855        &self,
 856        id: &LanguageModelToolUseId,
 857        title: SharedString,
 858        kind: acp::ToolKind,
 859        input: serde_json::Value,
 860    ) {
 861        self.0
 862            .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
 863                id,
 864                title.to_string(),
 865                kind,
 866                input,
 867            ))))
 868            .ok();
 869    }
 870
 871    fn initial_tool_call(
 872        id: &LanguageModelToolUseId,
 873        title: String,
 874        kind: acp::ToolKind,
 875        input: serde_json::Value,
 876    ) -> acp::ToolCall {
 877        acp::ToolCall {
 878            id: acp::ToolCallId(id.to_string().into()),
 879            title,
 880            kind,
 881            status: acp::ToolCallStatus::Pending,
 882            content: vec![],
 883            locations: vec![],
 884            raw_input: Some(input),
 885            raw_output: None,
 886        }
 887    }
 888
 889    fn send_tool_call_update(
 890        &self,
 891        tool_use_id: &LanguageModelToolUseId,
 892        fields: acp::ToolCallUpdateFields,
 893    ) {
 894        self.0
 895            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
 896                acp::ToolCallUpdate {
 897                    id: acp::ToolCallId(tool_use_id.to_string().into()),
 898                    fields,
 899                },
 900            )))
 901            .ok();
 902    }
 903
 904    fn send_tool_call_diff(&self, tool_call_diff: ToolCallDiff) {
 905        self.0
 906            .unbounded_send(Ok(AgentResponseEvent::ToolCallDiff(tool_call_diff)))
 907            .ok();
 908    }
 909
 910    fn send_stop(&self, reason: StopReason) {
 911        match reason {
 912            StopReason::EndTurn => {
 913                self.0
 914                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
 915                    .ok();
 916            }
 917            StopReason::MaxTokens => {
 918                self.0
 919                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
 920                    .ok();
 921            }
 922            StopReason::Refusal => {
 923                self.0
 924                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
 925                    .ok();
 926            }
 927            StopReason::ToolUse => {}
 928        }
 929    }
 930
 931    fn send_error(&self, error: LanguageModelCompletionError) {
 932        self.0.unbounded_send(Err(error)).ok();
 933    }
 934}
 935
/// Per-tool-call handle that a running tool uses to report updates and diffs
/// and to request user authorization.
///
/// It remembers the originating tool use's id, kind and JSON input, so every
/// event it emits is tagged with the right tool call.
#[derive(Clone)]
pub struct ToolCallEventStream {
    // Id of the originating tool use; stringified into the `acp::ToolCallId`
    // of every emitted update/diff/authorization.
    tool_use_id: LanguageModelToolUseId,
    // Tool kind reported in authorization requests.
    kind: acp::ToolKind,
    // Copy of `LanguageModelToolUse::input`; sent as `raw_input` in authorization requests.
    input: serde_json::Value,
    // Channel the events are pushed onto.
    stream: AgentResponseEventStream,
}
 943
 944impl ToolCallEventStream {
 945    #[cfg(test)]
 946    pub fn test() -> (Self, ToolCallEventStreamReceiver) {
 947        let (events_tx, events_rx) =
 948            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 949
 950        let stream = ToolCallEventStream::new(
 951            &LanguageModelToolUse {
 952                id: "test_id".into(),
 953                name: "test_tool".into(),
 954                raw_input: String::new(),
 955                input: serde_json::Value::Null,
 956                is_input_complete: true,
 957            },
 958            acp::ToolKind::Other,
 959            AgentResponseEventStream(events_tx),
 960        );
 961
 962        (stream, ToolCallEventStreamReceiver(events_rx))
 963    }
 964
 965    fn new(
 966        tool_use: &LanguageModelToolUse,
 967        kind: acp::ToolKind,
 968        stream: AgentResponseEventStream,
 969    ) -> Self {
 970        Self {
 971            tool_use_id: tool_use.id.clone(),
 972            kind,
 973            input: tool_use.input.clone(),
 974            stream,
 975        }
 976    }
 977
 978    pub fn send_update(&self, fields: acp::ToolCallUpdateFields) {
 979        self.stream.send_tool_call_update(&self.tool_use_id, fields);
 980    }
 981
 982    pub fn send_diff(&self, diff: Entity<Diff>) {
 983        self.stream.send_tool_call_diff(ToolCallDiff {
 984            tool_call_id: acp::ToolCallId(self.tool_use_id.to_string().into()),
 985            diff,
 986        });
 987    }
 988
 989    pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
 990        self.stream.authorize_tool_call(
 991            &self.tool_use_id,
 992            title,
 993            self.kind.clone(),
 994            self.input.clone(),
 995        )
 996    }
 997}
 998
/// Test-only receiving end returned by `ToolCallEventStream::test`, wrapping
/// the unbounded receiver of agent response events.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(
    mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
1003
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it if it is a tool-call authorization.
    ///
    /// # Panics
    /// If the next item is any other event, an error, or the channel is closed (`None`).
    pub async fn expect_tool_authorization(&mut self) -> ToolCallAuthorization {
        let next_event = self.0.next().await;
        match next_event {
            Some(Ok(AgentResponseEvent::ToolCallAuthorization(authorization))) => authorization,
            other => panic!("Expected ToolCallAuthorization but got: {:?}", other),
        }
    }
}
1015
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;

    /// Exposes the wrapped receiver so tests can use its methods directly.
    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1024
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    /// Mutable access to the wrapped receiver (needed to poll it).
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}