thread.rs

   1use crate::{SystemPromptTemplate, Template, Templates};
   2use acp_thread::Diff;
   3use agent_client_protocol as acp;
   4use anyhow::{anyhow, Context as _, Result};
   5use assistant_tool::{adapt_schema_to_format, ActionLog};
   6use cloud_llm_client::{CompletionIntent, CompletionMode};
   7use collections::HashMap;
   8use futures::{
   9    channel::{mpsc, oneshot},
  10    stream::FuturesUnordered,
  11};
  12use gpui::{App, Context, Entity, SharedString, Task};
  13use language_model::{
  14    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  15    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  16    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
  17    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
  18};
  19use log;
  20use project::Project;
  21use prompt_store::ProjectContext;
  22use schemars::{JsonSchema, Schema};
  23use serde::{Deserialize, Serialize};
  24use smol::stream::StreamExt;
  25use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
  26use util::{markdown::MarkdownCodeBlock, ResultExt};
  27
  28#[derive(Debug, Clone)]
  29pub struct AgentMessage {
  30    pub role: Role,
  31    pub content: Vec<MessageContent>,
  32}
  33
  34impl AgentMessage {
  35    pub fn to_markdown(&self) -> String {
  36        let mut markdown = format!("## {}\n", self.role);
  37
  38        for content in &self.content {
  39            match content {
  40                MessageContent::Text(text) => {
  41                    markdown.push_str(text);
  42                    markdown.push('\n');
  43                }
  44                MessageContent::Thinking { text, .. } => {
  45                    markdown.push_str("<think>");
  46                    markdown.push_str(text);
  47                    markdown.push_str("</think>\n");
  48                }
  49                MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
  50                MessageContent::Image(_) => {
  51                    markdown.push_str("<image />\n");
  52                }
  53                MessageContent::ToolUse(tool_use) => {
  54                    markdown.push_str(&format!(
  55                        "**Tool Use**: {} (ID: {})\n",
  56                        tool_use.name, tool_use.id
  57                    ));
  58                    markdown.push_str(&format!(
  59                        "{}\n",
  60                        MarkdownCodeBlock {
  61                            tag: "json",
  62                            text: &format!("{:#}", tool_use.input)
  63                        }
  64                    ));
  65                }
  66                MessageContent::ToolResult(tool_result) => {
  67                    markdown.push_str(&format!(
  68                        "**Tool Result**: {} (ID: {})\n\n",
  69                        tool_result.tool_name, tool_result.tool_use_id
  70                    ));
  71                    if tool_result.is_error {
  72                        markdown.push_str("**ERROR:**\n");
  73                    }
  74
  75                    match &tool_result.content {
  76                        LanguageModelToolResultContent::Text(text) => {
  77                            writeln!(markdown, "{text}\n").ok();
  78                        }
  79                        LanguageModelToolResultContent::Image(_) => {
  80                            writeln!(markdown, "<image />\n").ok();
  81                        }
  82                    }
  83
  84                    if let Some(output) = tool_result.output.as_ref() {
  85                        writeln!(
  86                            markdown,
  87                            "**Debug Output**:\n\n```json\n{}\n```\n",
  88                            serde_json::to_string_pretty(output).unwrap()
  89                        )
  90                        .unwrap();
  91                    }
  92                }
  93            }
  94        }
  95
  96        markdown
  97    }
  98}
  99
 100#[derive(Debug)]
 101pub enum AgentResponseEvent {
 102    Text(String),
 103    Thinking(String),
 104    ToolCall(acp::ToolCall),
 105    ToolCallUpdate(acp::ToolCallUpdate),
 106    ToolCallAuthorization(ToolCallAuthorization),
 107    ToolCallDiff(ToolCallDiff),
 108    Stop(acp::StopReason),
 109}
 110
 111#[derive(Debug)]
 112pub struct ToolCallAuthorization {
 113    pub tool_call: acp::ToolCall,
 114    pub options: Vec<acp::PermissionOption>,
 115    pub response: oneshot::Sender<acp::PermissionOptionId>,
 116}
 117
 118#[derive(Debug)]
 119pub struct ToolCallDiff {
 120    pub tool_call_id: acp::ToolCallId,
 121    pub diff: Entity<acp_thread::Diff>,
 122}
 123
 124pub struct Thread {
 125    messages: Vec<AgentMessage>,
 126    completion_mode: CompletionMode,
 127    /// Holds the task that handles agent interaction until the end of the turn.
 128    /// Survives across multiple requests as the model performs tool calls and
 129    /// we run tools, report their results.
 130    running_turn: Option<Task<()>>,
 131    pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
 132    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 133    project_context: Rc<RefCell<ProjectContext>>,
 134    templates: Arc<Templates>,
 135    pub selected_model: Arc<dyn LanguageModel>,
 136    project: Entity<Project>,
 137    action_log: Entity<ActionLog>,
 138}
 139
 140impl Thread {
 141    pub fn new(
 142        project: Entity<Project>,
 143        project_context: Rc<RefCell<ProjectContext>>,
 144        action_log: Entity<ActionLog>,
 145        templates: Arc<Templates>,
 146        default_model: Arc<dyn LanguageModel>,
 147    ) -> Self {
 148        Self {
 149            messages: Vec::new(),
 150            completion_mode: CompletionMode::Normal,
 151            running_turn: None,
 152            pending_tool_uses: HashMap::default(),
 153            tools: BTreeMap::default(),
 154            project_context,
 155            templates,
 156            selected_model: default_model,
 157            project,
 158            action_log,
 159        }
 160    }
 161
 162    pub fn project(&self) -> &Entity<Project> {
 163        &self.project
 164    }
 165
 166    pub fn action_log(&self) -> &Entity<ActionLog> {
 167        &self.action_log
 168    }
 169
 170    pub fn set_mode(&mut self, mode: CompletionMode) {
 171        self.completion_mode = mode;
 172    }
 173
 174    pub fn messages(&self) -> &[AgentMessage] {
 175        &self.messages
 176    }
 177
 178    pub fn add_tool(&mut self, tool: impl AgentTool) {
 179        self.tools.insert(tool.name(), tool.erase());
 180    }
 181
 182    pub fn remove_tool(&mut self, name: &str) -> bool {
 183        self.tools.remove(name).is_some()
 184    }
 185
 186    pub fn cancel(&mut self) {
 187        self.running_turn.take();
 188
 189        let tool_results = self
 190            .pending_tool_uses
 191            .drain()
 192            .map(|(tool_use_id, tool_use)| {
 193                MessageContent::ToolResult(LanguageModelToolResult {
 194                    tool_use_id,
 195                    tool_name: tool_use.name.clone(),
 196                    is_error: true,
 197                    content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
 198                    output: None,
 199                })
 200            })
 201            .collect::<Vec<_>>();
 202        self.last_user_message().content.extend(tool_results);
 203    }
 204
 205    /// Sending a message results in the model streaming a response, which could include tool calls.
 206    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
 207    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
 208    pub fn send(
 209        &mut self,
 210        model: Arc<dyn LanguageModel>,
 211        content: impl Into<MessageContent>,
 212        cx: &mut Context<Self>,
 213    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
 214        let content = content.into();
 215        log::info!("Thread::send called with model: {:?}", model.name());
 216        log::debug!("Thread::send content: {:?}", content);
 217
 218        cx.notify();
 219        let (events_tx, events_rx) =
 220            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 221        let event_stream = AgentResponseEventStream(events_tx);
 222
 223        let user_message_ix = self.messages.len();
 224        self.messages.push(AgentMessage {
 225            role: Role::User,
 226            content: vec![content],
 227        });
 228        log::info!("Total messages in thread: {}", self.messages.len());
 229        self.running_turn = Some(cx.spawn(async move |thread, cx| {
 230            log::info!("Starting agent turn execution");
 231            let turn_result = async {
 232                // Perform one request, then keep looping if the model makes tool calls.
 233                let mut completion_intent = CompletionIntent::UserPrompt;
 234                'outer: loop {
 235                    log::debug!(
 236                        "Building completion request with intent: {:?}",
 237                        completion_intent
 238                    );
 239                    let request = thread.update(cx, |thread, cx| {
 240                        thread.build_completion_request(completion_intent, cx)
 241                    })?;
 242
 243                    // println!(
 244                    //     "request: {}",
 245                    //     serde_json::to_string_pretty(&request).unwrap()
 246                    // );
 247
 248                    // Stream events, appending to messages and collecting up tool uses.
 249                    log::info!("Calling model.stream_completion");
 250                    let mut events = model.stream_completion(request, cx).await?;
 251                    log::debug!("Stream completion started successfully");
 252                    let mut tool_uses = FuturesUnordered::new();
 253                    while let Some(event) = events.next().await {
 254                        match event {
 255                            Ok(LanguageModelCompletionEvent::Stop(reason)) => {
 256                                event_stream.send_stop(reason);
 257                                if reason == StopReason::Refusal {
 258                                    thread.update(cx, |thread, _cx| {
 259                                        thread.messages.truncate(user_message_ix);
 260                                    })?;
 261                                    break 'outer;
 262                                }
 263                            }
 264                            Ok(event) => {
 265                                log::trace!("Received completion event: {:?}", event);
 266                                thread
 267                                    .update(cx, |thread, cx| {
 268                                        tool_uses.extend(thread.handle_streamed_completion_event(
 269                                            event,
 270                                            &event_stream,
 271                                            cx,
 272                                        ));
 273                                    })
 274                                    .ok();
 275                            }
 276                            Err(error) => {
 277                                log::error!("Error in completion stream: {:?}", error);
 278                                event_stream.send_error(error);
 279                                break;
 280                            }
 281                        }
 282                    }
 283
 284                    // If there are no tool uses, the turn is done.
 285                    if tool_uses.is_empty() {
 286                        log::info!("No tool uses found, completing turn");
 287                        break;
 288                    }
 289                    log::info!("Found {} tool uses to execute", tool_uses.len());
 290
 291                    // As tool results trickle in, insert them in the last user
 292                    // message so that they can be sent on the next tick of the
 293                    // agentic loop.
 294                    while let Some(tool_result) = tool_uses.next().await {
 295                        log::info!("Tool finished {:?}", tool_result);
 296
 297                        event_stream.send_tool_call_update(
 298                            &tool_result.tool_use_id,
 299                            acp::ToolCallUpdateFields {
 300                                status: Some(if tool_result.is_error {
 301                                    acp::ToolCallStatus::Failed
 302                                } else {
 303                                    acp::ToolCallStatus::Completed
 304                                }),
 305                                ..Default::default()
 306                            },
 307                        );
 308                        thread
 309                            .update(cx, |thread, _cx| {
 310                                thread.pending_tool_uses.remove(&tool_result.tool_use_id);
 311                                thread
 312                                    .last_user_message()
 313                                    .content
 314                                    .push(MessageContent::ToolResult(tool_result));
 315                            })
 316                            .ok();
 317                    }
 318
 319                    completion_intent = CompletionIntent::ToolResults;
 320                }
 321
 322                Ok(())
 323            }
 324            .await;
 325
 326            if let Err(error) = turn_result {
 327                log::error!("Turn execution failed: {:?}", error);
 328                event_stream.send_error(error);
 329            } else {
 330                log::info!("Turn execution completed successfully");
 331            }
 332        }));
 333        events_rx
 334    }
 335
 336    pub fn build_system_message(&self) -> AgentMessage {
 337        log::debug!("Building system message");
 338        let prompt = SystemPromptTemplate {
 339            project: &self.project_context.borrow(),
 340            available_tools: self.tools.keys().cloned().collect(),
 341        }
 342        .render(&self.templates)
 343        .context("failed to build system prompt")
 344        .expect("Invalid template");
 345        log::debug!("System message built");
 346        AgentMessage {
 347            role: Role::System,
 348            content: vec![prompt.into()],
 349        }
 350    }
 351
 352    /// A helper method that's called on every streamed completion event.
 353    /// Returns an optional tool result task, which the main agentic loop in
 354    /// send will send back to the model when it resolves.
 355    fn handle_streamed_completion_event(
 356        &mut self,
 357        event: LanguageModelCompletionEvent,
 358        event_stream: &AgentResponseEventStream,
 359        cx: &mut Context<Self>,
 360    ) -> Option<Task<LanguageModelToolResult>> {
 361        log::trace!("Handling streamed completion event: {:?}", event);
 362        use LanguageModelCompletionEvent::*;
 363
 364        match event {
 365            StartMessage { .. } => {
 366                self.messages.push(AgentMessage {
 367                    role: Role::Assistant,
 368                    content: Vec::new(),
 369                });
 370            }
 371            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
 372            Thinking { text, signature } => {
 373                self.handle_thinking_event(text, signature, event_stream, cx)
 374            }
 375            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
 376            ToolUse(tool_use) => {
 377                return self.handle_tool_use_event(tool_use, event_stream, cx);
 378            }
 379            ToolUseJsonParseError {
 380                id,
 381                tool_name,
 382                raw_input,
 383                json_parse_error,
 384            } => {
 385                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
 386                    id,
 387                    tool_name,
 388                    raw_input,
 389                    json_parse_error,
 390                )));
 391            }
 392            UsageUpdate(_) | StatusUpdate(_) => {}
 393            Stop(_) => unreachable!(),
 394        }
 395
 396        None
 397    }
 398
 399    fn handle_text_event(
 400        &mut self,
 401        new_text: String,
 402        events_stream: &AgentResponseEventStream,
 403        cx: &mut Context<Self>,
 404    ) {
 405        events_stream.send_text(&new_text);
 406
 407        let last_message = self.last_assistant_message();
 408        if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
 409            text.push_str(&new_text);
 410        } else {
 411            last_message.content.push(MessageContent::Text(new_text));
 412        }
 413
 414        cx.notify();
 415    }
 416
 417    fn handle_thinking_event(
 418        &mut self,
 419        new_text: String,
 420        new_signature: Option<String>,
 421        event_stream: &AgentResponseEventStream,
 422        cx: &mut Context<Self>,
 423    ) {
 424        event_stream.send_thinking(&new_text);
 425
 426        let last_message = self.last_assistant_message();
 427        if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
 428        {
 429            text.push_str(&new_text);
 430            *signature = new_signature.or(signature.take());
 431        } else {
 432            last_message.content.push(MessageContent::Thinking {
 433                text: new_text,
 434                signature: new_signature,
 435            });
 436        }
 437
 438        cx.notify();
 439    }
 440
 441    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
 442        let last_message = self.last_assistant_message();
 443        last_message
 444            .content
 445            .push(MessageContent::RedactedThinking(data));
 446        cx.notify();
 447    }
 448
 449    fn handle_tool_use_event(
 450        &mut self,
 451        tool_use: LanguageModelToolUse,
 452        event_stream: &AgentResponseEventStream,
 453        cx: &mut Context<Self>,
 454    ) -> Option<Task<LanguageModelToolResult>> {
 455        cx.notify();
 456
 457        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
 458
 459        self.pending_tool_uses
 460            .insert(tool_use.id.clone(), tool_use.clone());
 461        let last_message = self.last_assistant_message();
 462
 463        // Ensure the last message ends in the current tool use
 464        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
 465            if let MessageContent::ToolUse(last_tool_use) = content {
 466                if last_tool_use.id == tool_use.id {
 467                    *last_tool_use = tool_use.clone();
 468                    false
 469                } else {
 470                    true
 471                }
 472            } else {
 473                true
 474            }
 475        });
 476
 477        if push_new_tool_use {
 478            event_stream.send_tool_call(tool.as_ref(), &tool_use);
 479            last_message
 480                .content
 481                .push(MessageContent::ToolUse(tool_use.clone()));
 482        } else {
 483            event_stream.send_tool_call_update(
 484                &tool_use.id,
 485                acp::ToolCallUpdateFields {
 486                    raw_input: Some(tool_use.input.clone()),
 487                    ..Default::default()
 488                },
 489            );
 490        }
 491
 492        if !tool_use.is_input_complete {
 493            return None;
 494        }
 495
 496        let Some(tool) = tool else {
 497            let content = format!("No tool named {} exists", tool_use.name);
 498            return Some(Task::ready(LanguageModelToolResult {
 499                content: LanguageModelToolResultContent::Text(Arc::from(content)),
 500                tool_use_id: tool_use.id,
 501                tool_name: tool_use.name,
 502                is_error: true,
 503                output: None,
 504            }));
 505        };
 506
 507        let tool_event_stream =
 508            ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
 509        tool_event_stream.send_update(acp::ToolCallUpdateFields {
 510            status: Some(acp::ToolCallStatus::InProgress),
 511            ..Default::default()
 512        });
 513        let supports_images = self.selected_model.supports_images();
 514        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
 515        Some(cx.foreground_executor().spawn(async move {
 516            let tool_result = tool_result.await.and_then(|output| {
 517                if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
 518                    if !supports_images {
 519                        return Err(anyhow!(
 520                            "Attempted to read an image, but this model doesn't support it.",
 521                        ));
 522                    }
 523                }
 524                Ok(output)
 525            });
 526
 527            match tool_result {
 528                Ok(output) => LanguageModelToolResult {
 529                    tool_use_id: tool_use.id,
 530                    tool_name: tool_use.name,
 531                    is_error: false,
 532                    content: output.llm_output,
 533                    output: Some(output.raw_output),
 534                },
 535                Err(error) => LanguageModelToolResult {
 536                    tool_use_id: tool_use.id,
 537                    tool_name: tool_use.name,
 538                    is_error: true,
 539                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
 540                    output: None,
 541                },
 542            }
 543        }))
 544    }
 545
 546    fn handle_tool_use_json_parse_error_event(
 547        &mut self,
 548        tool_use_id: LanguageModelToolUseId,
 549        tool_name: Arc<str>,
 550        raw_input: Arc<str>,
 551        json_parse_error: String,
 552    ) -> LanguageModelToolResult {
 553        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
 554        LanguageModelToolResult {
 555            tool_use_id,
 556            tool_name,
 557            is_error: true,
 558            content: LanguageModelToolResultContent::Text(tool_output.into()),
 559            output: Some(serde_json::Value::String(raw_input.to_string())),
 560        }
 561    }
 562
 563    /// Guarantees the last message is from the assistant and returns a mutable reference.
 564    fn last_assistant_message(&mut self) -> &mut AgentMessage {
 565        if self
 566            .messages
 567            .last()
 568            .map_or(true, |m| m.role != Role::Assistant)
 569        {
 570            self.messages.push(AgentMessage {
 571                role: Role::Assistant,
 572                content: Vec::new(),
 573            });
 574        }
 575        self.messages.last_mut().unwrap()
 576    }
 577
 578    /// Guarantees the last message is from the user and returns a mutable reference.
 579    fn last_user_message(&mut self) -> &mut AgentMessage {
 580        if self.messages.last().map_or(true, |m| m.role != Role::User) {
 581            self.messages.push(AgentMessage {
 582                role: Role::User,
 583                content: Vec::new(),
 584            });
 585        }
 586        self.messages.last_mut().unwrap()
 587    }
 588
 589    pub(crate) fn build_completion_request(
 590        &self,
 591        completion_intent: CompletionIntent,
 592        cx: &mut App,
 593    ) -> LanguageModelRequest {
 594        log::debug!("Building completion request");
 595        log::debug!("Completion intent: {:?}", completion_intent);
 596        log::debug!("Completion mode: {:?}", self.completion_mode);
 597
 598        let messages = self.build_request_messages();
 599        log::info!("Request will include {} messages", messages.len());
 600
 601        let tools: Vec<LanguageModelRequestTool> = self
 602            .tools
 603            .values()
 604            .filter_map(|tool| {
 605                let tool_name = tool.name().to_string();
 606                log::trace!("Including tool: {}", tool_name);
 607                Some(LanguageModelRequestTool {
 608                    name: tool_name,
 609                    description: tool.description(cx).to_string(),
 610                    input_schema: tool
 611                        .input_schema(self.selected_model.tool_input_format())
 612                        .log_err()?,
 613                })
 614            })
 615            .collect();
 616
 617        log::info!("Request includes {} tools", tools.len());
 618
 619        let request = LanguageModelRequest {
 620            thread_id: None,
 621            prompt_id: None,
 622            intent: Some(completion_intent),
 623            mode: Some(self.completion_mode),
 624            messages,
 625            tools,
 626            tool_choice: None,
 627            stop: Vec::new(),
 628            temperature: None,
 629            thinking_allowed: true,
 630        };
 631
 632        log::debug!("Completion request built successfully");
 633        request
 634    }
 635
 636    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
 637        log::trace!(
 638            "Building request messages from {} thread messages",
 639            self.messages.len()
 640        );
 641
 642        let messages = Some(self.build_system_message())
 643            .iter()
 644            .chain(self.messages.iter())
 645            .map(|message| {
 646                log::trace!(
 647                    "  - {} message with {} content items",
 648                    match message.role {
 649                        Role::System => "System",
 650                        Role::User => "User",
 651                        Role::Assistant => "Assistant",
 652                    },
 653                    message.content.len()
 654                );
 655                LanguageModelRequestMessage {
 656                    role: message.role,
 657                    content: message.content.clone(),
 658                    cache: false,
 659                }
 660            })
 661            .collect();
 662        messages
 663    }
 664
 665    pub fn to_markdown(&self) -> String {
 666        let mut markdown = String::new();
 667        for message in &self.messages {
 668            markdown.push_str(&message.to_markdown());
 669        }
 670        markdown
 671    }
 672}
 673
 674pub trait AgentTool
 675where
 676    Self: 'static + Sized,
 677{
 678    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
 679    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
 680
 681    fn name(&self) -> SharedString;
 682
 683    fn description(&self, _cx: &mut App) -> SharedString {
 684        let schema = schemars::schema_for!(Self::Input);
 685        SharedString::new(
 686            schema
 687                .get("description")
 688                .and_then(|description| description.as_str())
 689                .unwrap_or_default(),
 690        )
 691    }
 692
 693    fn kind(&self) -> acp::ToolKind;
 694
 695    /// The initial tool title to display. Can be updated during the tool run.
 696    fn initial_title(&self, input: Self::Input) -> SharedString;
 697
 698    /// Returns the JSON schema that describes the tool's input.
 699    fn input_schema(&self) -> Schema {
 700        schemars::schema_for!(Self::Input)
 701    }
 702
 703    /// Runs the tool with the provided input.
 704    fn run(
 705        self: Arc<Self>,
 706        input: Self::Input,
 707        event_stream: ToolCallEventStream,
 708        cx: &mut App,
 709    ) -> Task<Result<Self::Output>>;
 710
 711    fn erase(self) -> Arc<dyn AnyAgentTool> {
 712        Arc::new(Erased(Arc::new(self)))
 713    }
 714}
 715
 716pub struct Erased<T>(T);
 717
 718pub struct AgentToolOutput {
 719    llm_output: LanguageModelToolResultContent,
 720    raw_output: serde_json::Value,
 721}
 722
 723pub trait AnyAgentTool {
 724    fn name(&self) -> SharedString;
 725    fn description(&self, cx: &mut App) -> SharedString;
 726    fn kind(&self) -> acp::ToolKind;
 727    fn initial_title(&self, input: serde_json::Value) -> Result<SharedString>;
 728    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
 729    fn run(
 730        self: Arc<Self>,
 731        input: serde_json::Value,
 732        event_stream: ToolCallEventStream,
 733        cx: &mut App,
 734    ) -> Task<Result<AgentToolOutput>>;
 735}
 736
 737impl<T> AnyAgentTool for Erased<Arc<T>>
 738where
 739    T: AgentTool,
 740{
 741    fn name(&self) -> SharedString {
 742        self.0.name()
 743    }
 744
 745    fn description(&self, cx: &mut App) -> SharedString {
 746        self.0.description(cx)
 747    }
 748
 749    fn kind(&self) -> agent_client_protocol::ToolKind {
 750        self.0.kind()
 751    }
 752
 753    fn initial_title(&self, input: serde_json::Value) -> Result<SharedString> {
 754        let parsed_input = serde_json::from_value(input)?;
 755        Ok(self.0.initial_title(parsed_input))
 756    }
 757
 758    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 759        let mut json = serde_json::to_value(self.0.input_schema())?;
 760        adapt_schema_to_format(&mut json, format)?;
 761        Ok(json)
 762    }
 763
 764    fn run(
 765        self: Arc<Self>,
 766        input: serde_json::Value,
 767        event_stream: ToolCallEventStream,
 768        cx: &mut App,
 769    ) -> Task<Result<AgentToolOutput>> {
 770        cx.spawn(async move |cx| {
 771            let input = serde_json::from_value(input)?;
 772            let output = cx
 773                .update(|cx| self.0.clone().run(input, event_stream, cx))?
 774                .await?;
 775            let raw_output = serde_json::to_value(&output)?;
 776            Ok(AgentToolOutput {
 777                llm_output: output.into(),
 778                raw_output,
 779            })
 780        })
 781    }
 782}
 783
 784#[derive(Clone)]
 785struct AgentResponseEventStream(
 786    mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
 787);
 788
 789impl AgentResponseEventStream {
 790    fn send_text(&self, text: &str) {
 791        self.0
 792            .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
 793            .ok();
 794    }
 795
 796    fn send_thinking(&self, text: &str) {
 797        self.0
 798            .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
 799            .ok();
 800    }
 801
 802    fn authorize_tool_call(
 803        &self,
 804        id: &LanguageModelToolUseId,
 805        title: String,
 806        kind: acp::ToolKind,
 807        input: serde_json::Value,
 808    ) -> impl use<> + Future<Output = Result<()>> {
 809        let (response_tx, response_rx) = oneshot::channel();
 810        self.0
 811            .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
 812                ToolCallAuthorization {
 813                    tool_call: Self::initial_tool_call(id, title, kind, input),
 814                    options: vec![
 815                        acp::PermissionOption {
 816                            id: acp::PermissionOptionId("always_allow".into()),
 817                            name: "Always Allow".into(),
 818                            kind: acp::PermissionOptionKind::AllowAlways,
 819                        },
 820                        acp::PermissionOption {
 821                            id: acp::PermissionOptionId("allow".into()),
 822                            name: "Allow".into(),
 823                            kind: acp::PermissionOptionKind::AllowOnce,
 824                        },
 825                        acp::PermissionOption {
 826                            id: acp::PermissionOptionId("deny".into()),
 827                            name: "Deny".into(),
 828                            kind: acp::PermissionOptionKind::RejectOnce,
 829                        },
 830                    ],
 831                    response: response_tx,
 832                },
 833            )))
 834            .ok();
 835        async move {
 836            match response_rx.await?.0.as_ref() {
 837                "allow" | "always_allow" => Ok(()),
 838                _ => Err(anyhow!("Permission to run tool denied by user")),
 839            }
 840        }
 841    }
 842
 843    fn send_tool_call(
 844        &self,
 845        tool: Option<&Arc<dyn AnyAgentTool>>,
 846        tool_use: &LanguageModelToolUse,
 847    ) {
 848        self.0
 849            .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
 850                &tool_use.id,
 851                tool.and_then(|t| t.initial_title(tool_use.input.clone()).ok())
 852                    .map(|i| i.into())
 853                    .unwrap_or_else(|| tool_use.name.to_string()),
 854                tool.map(|t| t.kind()).unwrap_or(acp::ToolKind::Other),
 855                tool_use.input.clone(),
 856            ))))
 857            .ok();
 858    }
 859
 860    fn initial_tool_call(
 861        id: &LanguageModelToolUseId,
 862        title: String,
 863        kind: acp::ToolKind,
 864        input: serde_json::Value,
 865    ) -> acp::ToolCall {
 866        acp::ToolCall {
 867            id: acp::ToolCallId(id.to_string().into()),
 868            title,
 869            kind,
 870            status: acp::ToolCallStatus::Pending,
 871            content: vec![],
 872            locations: vec![],
 873            raw_input: Some(input),
 874            raw_output: None,
 875        }
 876    }
 877
 878    fn send_tool_call_update(
 879        &self,
 880        tool_use_id: &LanguageModelToolUseId,
 881        fields: acp::ToolCallUpdateFields,
 882    ) {
 883        self.0
 884            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
 885                acp::ToolCallUpdate {
 886                    id: acp::ToolCallId(tool_use_id.to_string().into()),
 887                    fields,
 888                },
 889            )))
 890            .ok();
 891    }
 892
 893    fn send_tool_call_diff(&self, tool_call_diff: ToolCallDiff) {
 894        self.0
 895            .unbounded_send(Ok(AgentResponseEvent::ToolCallDiff(tool_call_diff)))
 896            .ok();
 897    }
 898
 899    fn send_stop(&self, reason: StopReason) {
 900        match reason {
 901            StopReason::EndTurn => {
 902                self.0
 903                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
 904                    .ok();
 905            }
 906            StopReason::MaxTokens => {
 907                self.0
 908                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
 909                    .ok();
 910            }
 911            StopReason::Refusal => {
 912                self.0
 913                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
 914                    .ok();
 915            }
 916            StopReason::ToolUse => {}
 917        }
 918    }
 919
 920    fn send_error(&self, error: LanguageModelCompletionError) {
 921        self.0.unbounded_send(Err(error)).ok();
 922    }
 923}
 924
 925#[derive(Clone)]
 926pub struct ToolCallEventStream {
 927    tool_use_id: LanguageModelToolUseId,
 928    kind: acp::ToolKind,
 929    input: serde_json::Value,
 930    stream: AgentResponseEventStream,
 931}
 932
 933impl ToolCallEventStream {
 934    #[cfg(test)]
 935    pub fn test() -> (Self, ToolCallEventStreamReceiver) {
 936        let (events_tx, events_rx) =
 937            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 938
 939        let stream = ToolCallEventStream::new(
 940            &LanguageModelToolUse {
 941                id: "test_id".into(),
 942                name: "test_tool".into(),
 943                raw_input: String::new(),
 944                input: serde_json::Value::Null,
 945                is_input_complete: true,
 946            },
 947            acp::ToolKind::Other,
 948            AgentResponseEventStream(events_tx),
 949        );
 950
 951        (stream, ToolCallEventStreamReceiver(events_rx))
 952    }
 953
 954    fn new(
 955        tool_use: &LanguageModelToolUse,
 956        kind: acp::ToolKind,
 957        stream: AgentResponseEventStream,
 958    ) -> Self {
 959        Self {
 960            tool_use_id: tool_use.id.clone(),
 961            kind,
 962            input: tool_use.input.clone(),
 963            stream,
 964        }
 965    }
 966
 967    pub fn send_update(&self, fields: acp::ToolCallUpdateFields) {
 968        self.stream.send_tool_call_update(&self.tool_use_id, fields);
 969    }
 970
 971    pub fn send_diff(&self, diff: Entity<Diff>) {
 972        self.stream.send_tool_call_diff(ToolCallDiff {
 973            tool_call_id: acp::ToolCallId(self.tool_use_id.to_string().into()),
 974            diff,
 975        });
 976    }
 977
 978    pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
 979        self.stream.authorize_tool_call(
 980            &self.tool_use_id,
 981            title,
 982            self.kind.clone(),
 983            self.input.clone(),
 984        )
 985    }
 986}
 987
 988#[cfg(test)]
 989pub struct ToolCallEventStreamReceiver(
 990    mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
 991);
 992
 993#[cfg(test)]
 994impl ToolCallEventStreamReceiver {
 995    pub async fn expect_tool_authorization(&mut self) -> ToolCallAuthorization {
 996        let event = self.0.next().await;
 997        if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event {
 998            auth
 999        } else {
1000            panic!("Expected ToolCallAuthorization but got: {:?}", event);
1001        }
1002    }
1003}
1004
1005#[cfg(test)]
1006impl std::ops::Deref for ToolCallEventStreamReceiver {
1007    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;
1008
1009    fn deref(&self) -> &Self::Target {
1010        &self.0
1011    }
1012}
1013
1014#[cfg(test)]
1015impl std::ops::DerefMut for ToolCallEventStreamReceiver {
1016    fn deref_mut(&mut self) -> &mut Self::Target {
1017        &mut self.0
1018    }
1019}