thread.rs

   1use crate::{SystemPromptTemplate, Template, Templates};
   2use acp_thread::Diff;
   3use action_log::ActionLog;
   4use agent_client_protocol as acp;
   5use anyhow::{Context as _, Result, anyhow};
   6use assistant_tool::adapt_schema_to_format;
   7use cloud_llm_client::{CompletionIntent, CompletionMode};
   8use collections::HashMap;
   9use futures::{
  10    channel::{mpsc, oneshot},
  11    stream::FuturesUnordered,
  12};
  13use gpui::{App, Context, Entity, SharedString, Task};
  14use language_model::{
  15    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  16    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  17    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
  18    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
  19};
  20use log;
  21use project::Project;
  22use prompt_store::ProjectContext;
  23use schemars::{JsonSchema, Schema};
  24use serde::{Deserialize, Serialize};
  25use smol::stream::StreamExt;
  26use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
  27use util::{ResultExt, markdown::MarkdownCodeBlock};
  28
/// A single message in an agent thread: who sent it and its ordered content
/// parts (text, thinking, images, tool uses, tool results).
#[derive(Debug, Clone)]
pub struct AgentMessage {
    /// The author of this message.
    pub role: Role,
    /// Content parts, in the order they were produced or appended.
    pub content: Vec<MessageContent>,
}
  34
  35impl AgentMessage {
  36    pub fn to_markdown(&self) -> String {
  37        let mut markdown = format!("## {}\n", self.role);
  38
  39        for content in &self.content {
  40            match content {
  41                MessageContent::Text(text) => {
  42                    markdown.push_str(text);
  43                    markdown.push('\n');
  44                }
  45                MessageContent::Thinking { text, .. } => {
  46                    markdown.push_str("<think>");
  47                    markdown.push_str(text);
  48                    markdown.push_str("</think>\n");
  49                }
  50                MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
  51                MessageContent::Image(_) => {
  52                    markdown.push_str("<image />\n");
  53                }
  54                MessageContent::ToolUse(tool_use) => {
  55                    markdown.push_str(&format!(
  56                        "**Tool Use**: {} (ID: {})\n",
  57                        tool_use.name, tool_use.id
  58                    ));
  59                    markdown.push_str(&format!(
  60                        "{}\n",
  61                        MarkdownCodeBlock {
  62                            tag: "json",
  63                            text: &format!("{:#}", tool_use.input)
  64                        }
  65                    ));
  66                }
  67                MessageContent::ToolResult(tool_result) => {
  68                    markdown.push_str(&format!(
  69                        "**Tool Result**: {} (ID: {})\n\n",
  70                        tool_result.tool_name, tool_result.tool_use_id
  71                    ));
  72                    if tool_result.is_error {
  73                        markdown.push_str("**ERROR:**\n");
  74                    }
  75
  76                    match &tool_result.content {
  77                        LanguageModelToolResultContent::Text(text) => {
  78                            writeln!(markdown, "{text}\n").ok();
  79                        }
  80                        LanguageModelToolResultContent::Image(_) => {
  81                            writeln!(markdown, "<image />\n").ok();
  82                        }
  83                    }
  84
  85                    if let Some(output) = tool_result.output.as_ref() {
  86                        writeln!(
  87                            markdown,
  88                            "**Debug Output**:\n\n```json\n{}\n```\n",
  89                            serde_json::to_string_pretty(output).unwrap()
  90                        )
  91                        .unwrap();
  92                    }
  93                }
  94            }
  95        }
  96
  97        markdown
  98    }
  99}
 100
/// Events streamed to the receiver returned by [`Thread::send`] while a turn
/// is running.
#[derive(Debug)]
pub enum AgentResponseEvent {
    /// A chunk of assistant text, forwarded as it streams in.
    Text(String),
    /// A chunk of the model's thinking output.
    Thinking(String),
    /// A tool call the model has just started.
    ToolCall(acp::ToolCall),
    /// An update to a previously reported tool call (presumably status, title,
    /// kind, or raw input changes — confirm against the event-stream helpers).
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A request for the user to approve or reject a tool call.
    ToolCallAuthorization(ToolCallAuthorization),
    /// The model stopped generating, with the given reason.
    Stop(acp::StopReason),
}
 110
/// A pending request for the user to authorize a tool call.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call awaiting approval.
    pub tool_call: acp::ToolCall,
    /// The choices offered to the user.
    pub options: Vec<acp::PermissionOption>,
    /// Channel on which the id of the chosen option is sent back.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
 117
/// A conversation between the user and a language model, together with the
/// tools the model may call while answering.
pub struct Thread {
    /// Conversation history. The system prompt is not stored here; it is
    /// rebuilt and prepended for every request.
    messages: Vec<AgentMessage>,
    /// Completion mode attached to every request.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<Task<()>>,
    /// Tool uses issued by the model whose results have not been recorded yet.
    pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
    /// Tools offered to the model, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Project information rendered into the system prompt.
    project_context: Rc<RefCell<ProjectContext>>,
    /// Templates used to render the system prompt.
    templates: Arc<Templates>,
    /// The model used for completions.
    pub selected_model: Arc<dyn LanguageModel>,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
}
 133
 134impl Thread {
 135    pub fn new(
 136        project: Entity<Project>,
 137        project_context: Rc<RefCell<ProjectContext>>,
 138        action_log: Entity<ActionLog>,
 139        templates: Arc<Templates>,
 140        default_model: Arc<dyn LanguageModel>,
 141    ) -> Self {
 142        Self {
 143            messages: Vec::new(),
 144            completion_mode: CompletionMode::Normal,
 145            running_turn: None,
 146            pending_tool_uses: HashMap::default(),
 147            tools: BTreeMap::default(),
 148            project_context,
 149            templates,
 150            selected_model: default_model,
 151            project,
 152            action_log,
 153        }
 154    }
 155
 156    pub fn project(&self) -> &Entity<Project> {
 157        &self.project
 158    }
 159
 160    pub fn action_log(&self) -> &Entity<ActionLog> {
 161        &self.action_log
 162    }
 163
 164    pub fn set_mode(&mut self, mode: CompletionMode) {
 165        self.completion_mode = mode;
 166    }
 167
 168    pub fn messages(&self) -> &[AgentMessage] {
 169        &self.messages
 170    }
 171
 172    pub fn add_tool(&mut self, tool: impl AgentTool) {
 173        self.tools.insert(tool.name(), tool.erase());
 174    }
 175
 176    pub fn remove_tool(&mut self, name: &str) -> bool {
 177        self.tools.remove(name).is_some()
 178    }
 179
 180    pub fn cancel(&mut self) {
 181        self.running_turn.take();
 182
 183        let tool_results = self
 184            .pending_tool_uses
 185            .drain()
 186            .map(|(tool_use_id, tool_use)| {
 187                MessageContent::ToolResult(LanguageModelToolResult {
 188                    tool_use_id,
 189                    tool_name: tool_use.name.clone(),
 190                    is_error: true,
 191                    content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
 192                    output: None,
 193                })
 194            })
 195            .collect::<Vec<_>>();
 196        self.last_user_message().content.extend(tool_results);
 197    }
 198
 199    /// Sending a message results in the model streaming a response, which could include tool calls.
 200    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
 201    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
 202    pub fn send(
 203        &mut self,
 204        content: impl Into<MessageContent>,
 205        cx: &mut Context<Self>,
 206    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
 207        let content = content.into();
 208        let model = self.selected_model.clone();
 209        log::info!("Thread::send called with model: {:?}", model.name());
 210        log::debug!("Thread::send content: {:?}", content);
 211
 212        cx.notify();
 213        let (events_tx, events_rx) =
 214            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 215        let event_stream = AgentResponseEventStream(events_tx);
 216
 217        let user_message_ix = self.messages.len();
 218        self.messages.push(AgentMessage {
 219            role: Role::User,
 220            content: vec![content],
 221        });
 222        log::info!("Total messages in thread: {}", self.messages.len());
 223        self.running_turn = Some(cx.spawn(async move |thread, cx| {
 224            log::info!("Starting agent turn execution");
 225            let turn_result = async {
 226                // Perform one request, then keep looping if the model makes tool calls.
 227                let mut completion_intent = CompletionIntent::UserPrompt;
 228                'outer: loop {
 229                    log::debug!(
 230                        "Building completion request with intent: {:?}",
 231                        completion_intent
 232                    );
 233                    let request = thread.update(cx, |thread, cx| {
 234                        thread.build_completion_request(completion_intent, cx)
 235                    })?;
 236
 237                    // println!(
 238                    //     "request: {}",
 239                    //     serde_json::to_string_pretty(&request).unwrap()
 240                    // );
 241
 242                    // Stream events, appending to messages and collecting up tool uses.
 243                    log::info!("Calling model.stream_completion");
 244                    let mut events = model.stream_completion(request, cx).await?;
 245                    log::debug!("Stream completion started successfully");
 246                    let mut tool_uses = FuturesUnordered::new();
 247                    while let Some(event) = events.next().await {
 248                        match event {
 249                            Ok(LanguageModelCompletionEvent::Stop(reason)) => {
 250                                event_stream.send_stop(reason);
 251                                if reason == StopReason::Refusal {
 252                                    thread.update(cx, |thread, _cx| {
 253                                        thread.messages.truncate(user_message_ix);
 254                                    })?;
 255                                    break 'outer;
 256                                }
 257                            }
 258                            Ok(event) => {
 259                                log::trace!("Received completion event: {:?}", event);
 260                                thread
 261                                    .update(cx, |thread, cx| {
 262                                        tool_uses.extend(thread.handle_streamed_completion_event(
 263                                            event,
 264                                            &event_stream,
 265                                            cx,
 266                                        ));
 267                                    })
 268                                    .ok();
 269                            }
 270                            Err(error) => {
 271                                log::error!("Error in completion stream: {:?}", error);
 272                                event_stream.send_error(error);
 273                                break;
 274                            }
 275                        }
 276                    }
 277
 278                    // If there are no tool uses, the turn is done.
 279                    if tool_uses.is_empty() {
 280                        log::info!("No tool uses found, completing turn");
 281                        break;
 282                    }
 283                    log::info!("Found {} tool uses to execute", tool_uses.len());
 284
 285                    // As tool results trickle in, insert them in the last user
 286                    // message so that they can be sent on the next tick of the
 287                    // agentic loop.
 288                    while let Some(tool_result) = tool_uses.next().await {
 289                        log::info!("Tool finished {:?}", tool_result);
 290
 291                        event_stream.update_tool_call_fields(
 292                            &tool_result.tool_use_id,
 293                            acp::ToolCallUpdateFields {
 294                                status: Some(if tool_result.is_error {
 295                                    acp::ToolCallStatus::Failed
 296                                } else {
 297                                    acp::ToolCallStatus::Completed
 298                                }),
 299                                ..Default::default()
 300                            },
 301                        );
 302                        thread
 303                            .update(cx, |thread, _cx| {
 304                                thread.pending_tool_uses.remove(&tool_result.tool_use_id);
 305                                thread
 306                                    .last_user_message()
 307                                    .content
 308                                    .push(MessageContent::ToolResult(tool_result));
 309                            })
 310                            .ok();
 311                    }
 312
 313                    completion_intent = CompletionIntent::ToolResults;
 314                }
 315
 316                Ok(())
 317            }
 318            .await;
 319
 320            if let Err(error) = turn_result {
 321                log::error!("Turn execution failed: {:?}", error);
 322                event_stream.send_error(error);
 323            } else {
 324                log::info!("Turn execution completed successfully");
 325            }
 326        }));
 327        events_rx
 328    }
 329
 330    pub fn build_system_message(&self) -> AgentMessage {
 331        log::debug!("Building system message");
 332        let prompt = SystemPromptTemplate {
 333            project: &self.project_context.borrow(),
 334            available_tools: self.tools.keys().cloned().collect(),
 335        }
 336        .render(&self.templates)
 337        .context("failed to build system prompt")
 338        .expect("Invalid template");
 339        log::debug!("System message built");
 340        AgentMessage {
 341            role: Role::System,
 342            content: vec![prompt.into()],
 343        }
 344    }
 345
 346    /// A helper method that's called on every streamed completion event.
 347    /// Returns an optional tool result task, which the main agentic loop in
 348    /// send will send back to the model when it resolves.
 349    fn handle_streamed_completion_event(
 350        &mut self,
 351        event: LanguageModelCompletionEvent,
 352        event_stream: &AgentResponseEventStream,
 353        cx: &mut Context<Self>,
 354    ) -> Option<Task<LanguageModelToolResult>> {
 355        log::trace!("Handling streamed completion event: {:?}", event);
 356        use LanguageModelCompletionEvent::*;
 357
 358        match event {
 359            StartMessage { .. } => {
 360                self.messages.push(AgentMessage {
 361                    role: Role::Assistant,
 362                    content: Vec::new(),
 363                });
 364            }
 365            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
 366            Thinking { text, signature } => {
 367                self.handle_thinking_event(text, signature, event_stream, cx)
 368            }
 369            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
 370            ToolUse(tool_use) => {
 371                return self.handle_tool_use_event(tool_use, event_stream, cx);
 372            }
 373            ToolUseJsonParseError {
 374                id,
 375                tool_name,
 376                raw_input,
 377                json_parse_error,
 378            } => {
 379                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
 380                    id,
 381                    tool_name,
 382                    raw_input,
 383                    json_parse_error,
 384                )));
 385            }
 386            UsageUpdate(_) | StatusUpdate(_) => {}
 387            Stop(_) => unreachable!(),
 388        }
 389
 390        None
 391    }
 392
 393    fn handle_text_event(
 394        &mut self,
 395        new_text: String,
 396        events_stream: &AgentResponseEventStream,
 397        cx: &mut Context<Self>,
 398    ) {
 399        events_stream.send_text(&new_text);
 400
 401        let last_message = self.last_assistant_message();
 402        if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
 403            text.push_str(&new_text);
 404        } else {
 405            last_message.content.push(MessageContent::Text(new_text));
 406        }
 407
 408        cx.notify();
 409    }
 410
 411    fn handle_thinking_event(
 412        &mut self,
 413        new_text: String,
 414        new_signature: Option<String>,
 415        event_stream: &AgentResponseEventStream,
 416        cx: &mut Context<Self>,
 417    ) {
 418        event_stream.send_thinking(&new_text);
 419
 420        let last_message = self.last_assistant_message();
 421        if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
 422        {
 423            text.push_str(&new_text);
 424            *signature = new_signature.or(signature.take());
 425        } else {
 426            last_message.content.push(MessageContent::Thinking {
 427                text: new_text,
 428                signature: new_signature,
 429            });
 430        }
 431
 432        cx.notify();
 433    }
 434
 435    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
 436        let last_message = self.last_assistant_message();
 437        last_message
 438            .content
 439            .push(MessageContent::RedactedThinking(data));
 440        cx.notify();
 441    }
 442
 443    fn handle_tool_use_event(
 444        &mut self,
 445        tool_use: LanguageModelToolUse,
 446        event_stream: &AgentResponseEventStream,
 447        cx: &mut Context<Self>,
 448    ) -> Option<Task<LanguageModelToolResult>> {
 449        cx.notify();
 450
 451        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
 452
 453        self.pending_tool_uses
 454            .insert(tool_use.id.clone(), tool_use.clone());
 455        let last_message = self.last_assistant_message();
 456
 457        // Ensure the last message ends in the current tool use
 458        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
 459            if let MessageContent::ToolUse(last_tool_use) = content {
 460                if last_tool_use.id == tool_use.id {
 461                    *last_tool_use = tool_use.clone();
 462                    false
 463                } else {
 464                    true
 465                }
 466            } else {
 467                true
 468            }
 469        });
 470
 471        let mut title = SharedString::from(&tool_use.name);
 472        let mut kind = acp::ToolKind::Other;
 473        if let Some(tool) = tool.as_ref() {
 474            title = tool.initial_title(tool_use.input.clone());
 475            kind = tool.kind();
 476        }
 477
 478        if push_new_tool_use {
 479            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
 480            last_message
 481                .content
 482                .push(MessageContent::ToolUse(tool_use.clone()));
 483        } else {
 484            event_stream.update_tool_call_fields(
 485                &tool_use.id,
 486                acp::ToolCallUpdateFields {
 487                    title: Some(title.into()),
 488                    kind: Some(kind),
 489                    raw_input: Some(tool_use.input.clone()),
 490                    ..Default::default()
 491                },
 492            );
 493        }
 494
 495        if !tool_use.is_input_complete {
 496            return None;
 497        }
 498
 499        let Some(tool) = tool else {
 500            let content = format!("No tool named {} exists", tool_use.name);
 501            return Some(Task::ready(LanguageModelToolResult {
 502                content: LanguageModelToolResultContent::Text(Arc::from(content)),
 503                tool_use_id: tool_use.id,
 504                tool_name: tool_use.name,
 505                is_error: true,
 506                output: None,
 507            }));
 508        };
 509
 510        let tool_event_stream =
 511            ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
 512        tool_event_stream.update_fields(acp::ToolCallUpdateFields {
 513            status: Some(acp::ToolCallStatus::InProgress),
 514            ..Default::default()
 515        });
 516        let supports_images = self.selected_model.supports_images();
 517        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
 518        Some(cx.foreground_executor().spawn(async move {
 519            let tool_result = tool_result.await.and_then(|output| {
 520                if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
 521                    if !supports_images {
 522                        return Err(anyhow!(
 523                            "Attempted to read an image, but this model doesn't support it.",
 524                        ));
 525                    }
 526                }
 527                Ok(output)
 528            });
 529
 530            match tool_result {
 531                Ok(output) => LanguageModelToolResult {
 532                    tool_use_id: tool_use.id,
 533                    tool_name: tool_use.name,
 534                    is_error: false,
 535                    content: output.llm_output,
 536                    output: Some(output.raw_output),
 537                },
 538                Err(error) => LanguageModelToolResult {
 539                    tool_use_id: tool_use.id,
 540                    tool_name: tool_use.name,
 541                    is_error: true,
 542                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
 543                    output: None,
 544                },
 545            }
 546        }))
 547    }
 548
 549    fn handle_tool_use_json_parse_error_event(
 550        &mut self,
 551        tool_use_id: LanguageModelToolUseId,
 552        tool_name: Arc<str>,
 553        raw_input: Arc<str>,
 554        json_parse_error: String,
 555    ) -> LanguageModelToolResult {
 556        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
 557        LanguageModelToolResult {
 558            tool_use_id,
 559            tool_name,
 560            is_error: true,
 561            content: LanguageModelToolResultContent::Text(tool_output.into()),
 562            output: Some(serde_json::Value::String(raw_input.to_string())),
 563        }
 564    }
 565
 566    /// Guarantees the last message is from the assistant and returns a mutable reference.
 567    fn last_assistant_message(&mut self) -> &mut AgentMessage {
 568        if self
 569            .messages
 570            .last()
 571            .map_or(true, |m| m.role != Role::Assistant)
 572        {
 573            self.messages.push(AgentMessage {
 574                role: Role::Assistant,
 575                content: Vec::new(),
 576            });
 577        }
 578        self.messages.last_mut().unwrap()
 579    }
 580
 581    /// Guarantees the last message is from the user and returns a mutable reference.
 582    fn last_user_message(&mut self) -> &mut AgentMessage {
 583        if self.messages.last().map_or(true, |m| m.role != Role::User) {
 584            self.messages.push(AgentMessage {
 585                role: Role::User,
 586                content: Vec::new(),
 587            });
 588        }
 589        self.messages.last_mut().unwrap()
 590    }
 591
 592    pub(crate) fn build_completion_request(
 593        &self,
 594        completion_intent: CompletionIntent,
 595        cx: &mut App,
 596    ) -> LanguageModelRequest {
 597        log::debug!("Building completion request");
 598        log::debug!("Completion intent: {:?}", completion_intent);
 599        log::debug!("Completion mode: {:?}", self.completion_mode);
 600
 601        let messages = self.build_request_messages();
 602        log::info!("Request will include {} messages", messages.len());
 603
 604        let tools: Vec<LanguageModelRequestTool> = self
 605            .tools
 606            .values()
 607            .filter_map(|tool| {
 608                let tool_name = tool.name().to_string();
 609                log::trace!("Including tool: {}", tool_name);
 610                Some(LanguageModelRequestTool {
 611                    name: tool_name,
 612                    description: tool.description(cx).to_string(),
 613                    input_schema: tool
 614                        .input_schema(self.selected_model.tool_input_format())
 615                        .log_err()?,
 616                })
 617            })
 618            .collect();
 619
 620        log::info!("Request includes {} tools", tools.len());
 621
 622        let request = LanguageModelRequest {
 623            thread_id: None,
 624            prompt_id: None,
 625            intent: Some(completion_intent),
 626            mode: Some(self.completion_mode),
 627            messages,
 628            tools,
 629            tool_choice: None,
 630            stop: Vec::new(),
 631            temperature: None,
 632            thinking_allowed: true,
 633        };
 634
 635        log::debug!("Completion request built successfully");
 636        request
 637    }
 638
 639    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
 640        log::trace!(
 641            "Building request messages from {} thread messages",
 642            self.messages.len()
 643        );
 644
 645        let messages = Some(self.build_system_message())
 646            .iter()
 647            .chain(self.messages.iter())
 648            .map(|message| {
 649                log::trace!(
 650                    "  - {} message with {} content items",
 651                    match message.role {
 652                        Role::System => "System",
 653                        Role::User => "User",
 654                        Role::Assistant => "Assistant",
 655                    },
 656                    message.content.len()
 657                );
 658                LanguageModelRequestMessage {
 659                    role: message.role,
 660                    content: message.content.clone(),
 661                    cache: false,
 662                }
 663            })
 664            .collect();
 665        messages
 666    }
 667
 668    pub fn to_markdown(&self) -> String {
 669        let mut markdown = String::new();
 670        for message in &self.messages {
 671            markdown.push_str(&message.to_markdown());
 672        }
 673        markdown
 674    }
 675}
 676
 677pub trait AgentTool
 678where
 679    Self: 'static + Sized,
 680{
 681    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
 682    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
 683
 684    fn name(&self) -> SharedString;
 685
 686    fn description(&self, _cx: &mut App) -> SharedString {
 687        let schema = schemars::schema_for!(Self::Input);
 688        SharedString::new(
 689            schema
 690                .get("description")
 691                .and_then(|description| description.as_str())
 692                .unwrap_or_default(),
 693        )
 694    }
 695
 696    fn kind(&self) -> acp::ToolKind;
 697
 698    /// The initial tool title to display. Can be updated during the tool run.
 699    fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
 700
 701    /// Returns the JSON schema that describes the tool's input.
 702    fn input_schema(&self) -> Schema {
 703        schemars::schema_for!(Self::Input)
 704    }
 705
 706    /// Runs the tool with the provided input.
 707    fn run(
 708        self: Arc<Self>,
 709        input: Self::Input,
 710        event_stream: ToolCallEventStream,
 711        cx: &mut App,
 712    ) -> Task<Result<Self::Output>>;
 713
 714    fn erase(self) -> Arc<dyn AnyAgentTool> {
 715        Arc::new(Erased(Arc::new(self)))
 716    }
 717}
 718
/// Adapter exposing an [`AgentTool`] through the type-erased [`AnyAgentTool`]
/// interface; created by [`AgentTool::erase`].
pub struct Erased<T>(T);
 720
/// Output of a type-erased tool run.
pub struct AgentToolOutput {
    /// Content returned to the model as the tool result.
    llm_output: LanguageModelToolResultContent,
    /// The tool's output serialized to JSON; stored as the tool result's `output`.
    raw_output: serde_json::Value,
}
 725
/// Object-safe, JSON-based counterpart of [`AgentTool`], allowing tools with
/// different input/output types to be stored together.
pub trait AnyAgentTool {
    /// The name the model uses to call this tool.
    fn name(&self) -> SharedString;
    /// Description sent to the model alongside the tool.
    fn description(&self, cx: &mut App) -> SharedString;
    /// The kind of tool, as reported to the client.
    fn kind(&self) -> acp::ToolKind;
    /// Initial display title computed from raw JSON input, which may still be
    /// incomplete or invalid while it streams in.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// The tool's input JSON schema, adapted to `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Runs the tool with raw JSON input.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
}
 739
 740impl<T> AnyAgentTool for Erased<Arc<T>>
 741where
 742    T: AgentTool,
 743{
 744    fn name(&self) -> SharedString {
 745        self.0.name()
 746    }
 747
 748    fn description(&self, cx: &mut App) -> SharedString {
 749        self.0.description(cx)
 750    }
 751
 752    fn kind(&self) -> agent_client_protocol::ToolKind {
 753        self.0.kind()
 754    }
 755
 756    fn initial_title(&self, input: serde_json::Value) -> SharedString {
 757        let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
 758        self.0.initial_title(parsed_input)
 759    }
 760
 761    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 762        let mut json = serde_json::to_value(self.0.input_schema())?;
 763        adapt_schema_to_format(&mut json, format)?;
 764        Ok(json)
 765    }
 766
 767    fn run(
 768        self: Arc<Self>,
 769        input: serde_json::Value,
 770        event_stream: ToolCallEventStream,
 771        cx: &mut App,
 772    ) -> Task<Result<AgentToolOutput>> {
 773        cx.spawn(async move |cx| {
 774            let input = serde_json::from_value(input)?;
 775            let output = cx
 776                .update(|cx| self.0.clone().run(input, event_stream, cx))?
 777                .await?;
 778            let raw_output = serde_json::to_value(&output)?;
 779            Ok(AgentToolOutput {
 780                llm_output: output.into(),
 781                raw_output,
 782            })
 783        })
 784    }
 785}
 786
/// Sending half of the channel returned by [`Thread::send`].
///
/// Cloned into each tool's `ToolCallEventStream` so running tools can report
/// events too. The send helpers discard send errors (`.ok()`), so events are
/// silently dropped once the receiver is gone.
#[derive(Clone)]
struct AgentResponseEventStream(
    mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
 791
 792impl AgentResponseEventStream {
 793    fn send_text(&self, text: &str) {
 794        self.0
 795            .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
 796            .ok();
 797    }
 798
 799    fn send_thinking(&self, text: &str) {
 800        self.0
 801            .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
 802            .ok();
 803    }
 804
 805    fn authorize_tool_call(
 806        &self,
 807        id: &LanguageModelToolUseId,
 808        title: String,
 809        kind: acp::ToolKind,
 810        input: serde_json::Value,
 811    ) -> impl use<> + Future<Output = Result<()>> {
 812        let (response_tx, response_rx) = oneshot::channel();
 813        self.0
 814            .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
 815                ToolCallAuthorization {
 816                    tool_call: Self::initial_tool_call(id, title, kind, input),
 817                    options: vec![
 818                        acp::PermissionOption {
 819                            id: acp::PermissionOptionId("always_allow".into()),
 820                            name: "Always Allow".into(),
 821                            kind: acp::PermissionOptionKind::AllowAlways,
 822                        },
 823                        acp::PermissionOption {
 824                            id: acp::PermissionOptionId("allow".into()),
 825                            name: "Allow".into(),
 826                            kind: acp::PermissionOptionKind::AllowOnce,
 827                        },
 828                        acp::PermissionOption {
 829                            id: acp::PermissionOptionId("deny".into()),
 830                            name: "Deny".into(),
 831                            kind: acp::PermissionOptionKind::RejectOnce,
 832                        },
 833                    ],
 834                    response: response_tx,
 835                },
 836            )))
 837            .ok();
 838        async move {
 839            match response_rx.await?.0.as_ref() {
 840                "allow" | "always_allow" => Ok(()),
 841                _ => Err(anyhow!("Permission to run tool denied by user")),
 842            }
 843        }
 844    }
 845
 846    fn send_tool_call(
 847        &self,
 848        id: &LanguageModelToolUseId,
 849        title: SharedString,
 850        kind: acp::ToolKind,
 851        input: serde_json::Value,
 852    ) {
 853        self.0
 854            .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
 855                id,
 856                title.to_string(),
 857                kind,
 858                input,
 859            ))))
 860            .ok();
 861    }
 862
 863    fn initial_tool_call(
 864        id: &LanguageModelToolUseId,
 865        title: String,
 866        kind: acp::ToolKind,
 867        input: serde_json::Value,
 868    ) -> acp::ToolCall {
 869        acp::ToolCall {
 870            id: acp::ToolCallId(id.to_string().into()),
 871            title,
 872            kind,
 873            status: acp::ToolCallStatus::Pending,
 874            content: vec![],
 875            locations: vec![],
 876            raw_input: Some(input),
 877            raw_output: None,
 878        }
 879    }
 880
 881    fn update_tool_call_fields(
 882        &self,
 883        tool_use_id: &LanguageModelToolUseId,
 884        fields: acp::ToolCallUpdateFields,
 885    ) {
 886        self.0
 887            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
 888                acp::ToolCallUpdate {
 889                    id: acp::ToolCallId(tool_use_id.to_string().into()),
 890                    fields,
 891                }
 892                .into(),
 893            )))
 894            .ok();
 895    }
 896
 897    fn update_tool_call_diff(&self, tool_use_id: &LanguageModelToolUseId, diff: Entity<Diff>) {
 898        self.0
 899            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
 900                acp_thread::ToolCallUpdateDiff {
 901                    id: acp::ToolCallId(tool_use_id.to_string().into()),
 902                    diff,
 903                }
 904                .into(),
 905            )))
 906            .ok();
 907    }
 908
 909    fn send_stop(&self, reason: StopReason) {
 910        match reason {
 911            StopReason::EndTurn => {
 912                self.0
 913                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
 914                    .ok();
 915            }
 916            StopReason::MaxTokens => {
 917                self.0
 918                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
 919                    .ok();
 920            }
 921            StopReason::Refusal => {
 922                self.0
 923                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
 924                    .ok();
 925            }
 926            StopReason::ToolUse => {}
 927        }
 928    }
 929
 930    fn send_error(&self, error: LanguageModelCompletionError) {
 931        self.0.unbounded_send(Err(error)).ok();
 932    }
 933}
 934
 935#[derive(Clone)]
 936pub struct ToolCallEventStream {
 937    tool_use_id: LanguageModelToolUseId,
 938    kind: acp::ToolKind,
 939    input: serde_json::Value,
 940    stream: AgentResponseEventStream,
 941}
 942
 943impl ToolCallEventStream {
 944    #[cfg(test)]
 945    pub fn test() -> (Self, ToolCallEventStreamReceiver) {
 946        let (events_tx, events_rx) =
 947            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 948
 949        let stream = ToolCallEventStream::new(
 950            &LanguageModelToolUse {
 951                id: "test_id".into(),
 952                name: "test_tool".into(),
 953                raw_input: String::new(),
 954                input: serde_json::Value::Null,
 955                is_input_complete: true,
 956            },
 957            acp::ToolKind::Other,
 958            AgentResponseEventStream(events_tx),
 959        );
 960
 961        (stream, ToolCallEventStreamReceiver(events_rx))
 962    }
 963
 964    fn new(
 965        tool_use: &LanguageModelToolUse,
 966        kind: acp::ToolKind,
 967        stream: AgentResponseEventStream,
 968    ) -> Self {
 969        Self {
 970            tool_use_id: tool_use.id.clone(),
 971            kind,
 972            input: tool_use.input.clone(),
 973            stream,
 974        }
 975    }
 976
 977    pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
 978        self.stream
 979            .update_tool_call_fields(&self.tool_use_id, fields);
 980    }
 981
 982    pub fn update_diff(&self, diff: Entity<Diff>) {
 983        self.stream.update_tool_call_diff(&self.tool_use_id, diff);
 984    }
 985
 986    pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
 987        self.stream.authorize_tool_call(
 988            &self.tool_use_id,
 989            title,
 990            self.kind.clone(),
 991            self.input.clone(),
 992        )
 993    }
 994}
 995
/// Test-only receiving half paired with [`ToolCallEventStream::test`]; yields
/// every event (or completion error) the tool call stream emits.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(
    mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
1000
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it if it is a tool-call authorization.
    ///
    /// # Panics
    /// Panics in any of these cases:
    /// - the next item is a different event;
    /// - the next item is an error;
    /// - the stream has ended.
    pub async fn expect_tool_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallAuthorization(authorization))) => authorization,
            other => panic!("Expected ToolCallAuthorization but got: {:?}", other),
        }
    }
}
1012
/// Lets tests call receiver methods directly on the wrapper.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;

    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1021
/// Mutable counterpart of the `Deref` impl, needed for polling the receiver.
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}