thread.rs

   1use crate::{SystemPromptTemplate, Template, Templates};
   2use action_log::ActionLog;
   3use agent_client_protocol as acp;
   4use agent_settings::AgentSettings;
   5use anyhow::{Context as _, Result, anyhow};
   6use assistant_tool::adapt_schema_to_format;
   7use cloud_llm_client::{CompletionIntent, CompletionMode};
   8use collections::HashMap;
   9use fs::Fs;
  10use futures::{
  11    channel::{mpsc, oneshot},
  12    stream::FuturesUnordered,
  13};
  14use gpui::{App, Context, Entity, SharedString, Task};
  15use language_model::{
  16    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  17    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  18    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
  19    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
  20};
  21use log;
  22use project::Project;
  23use prompt_store::ProjectContext;
  24use schemars::{JsonSchema, Schema};
  25use serde::{Deserialize, Serialize};
  26use settings::{Settings, update_settings_file};
  27use smol::stream::StreamExt;
  28use std::{cell::RefCell, collections::BTreeMap, fmt::Write, rc::Rc, sync::Arc};
  29use util::{ResultExt, markdown::MarkdownCodeBlock};
  30
  31#[derive(Debug, Clone)]
  32pub struct AgentMessage {
  33    pub role: Role,
  34    pub content: Vec<MessageContent>,
  35}
  36
  37impl AgentMessage {
  38    pub fn to_markdown(&self) -> String {
  39        let mut markdown = format!("## {}\n", self.role);
  40
  41        for content in &self.content {
  42            match content {
  43                MessageContent::Text(text) => {
  44                    markdown.push_str(text);
  45                    markdown.push('\n');
  46                }
  47                MessageContent::Thinking { text, .. } => {
  48                    markdown.push_str("<think>");
  49                    markdown.push_str(text);
  50                    markdown.push_str("</think>\n");
  51                }
  52                MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
  53                MessageContent::Image(_) => {
  54                    markdown.push_str("<image />\n");
  55                }
  56                MessageContent::ToolUse(tool_use) => {
  57                    markdown.push_str(&format!(
  58                        "**Tool Use**: {} (ID: {})\n",
  59                        tool_use.name, tool_use.id
  60                    ));
  61                    markdown.push_str(&format!(
  62                        "{}\n",
  63                        MarkdownCodeBlock {
  64                            tag: "json",
  65                            text: &format!("{:#}", tool_use.input)
  66                        }
  67                    ));
  68                }
  69                MessageContent::ToolResult(tool_result) => {
  70                    markdown.push_str(&format!(
  71                        "**Tool Result**: {} (ID: {})\n\n",
  72                        tool_result.tool_name, tool_result.tool_use_id
  73                    ));
  74                    if tool_result.is_error {
  75                        markdown.push_str("**ERROR:**\n");
  76                    }
  77
  78                    match &tool_result.content {
  79                        LanguageModelToolResultContent::Text(text) => {
  80                            writeln!(markdown, "{text}\n").ok();
  81                        }
  82                        LanguageModelToolResultContent::Image(_) => {
  83                            writeln!(markdown, "<image />\n").ok();
  84                        }
  85                    }
  86
  87                    if let Some(output) = tool_result.output.as_ref() {
  88                        writeln!(
  89                            markdown,
  90                            "**Debug Output**:\n\n```json\n{}\n```\n",
  91                            serde_json::to_string_pretty(output).unwrap()
  92                        )
  93                        .unwrap();
  94                    }
  95                }
  96            }
  97        }
  98
  99        markdown
 100    }
 101}
 102
 103#[derive(Debug)]
 104pub enum AgentResponseEvent {
 105    Text(String),
 106    Thinking(String),
 107    ToolCall(acp::ToolCall),
 108    ToolCallUpdate(acp_thread::ToolCallUpdate),
 109    ToolCallAuthorization(ToolCallAuthorization),
 110    Stop(acp::StopReason),
 111}
 112
 113#[derive(Debug)]
 114pub struct ToolCallAuthorization {
 115    pub tool_call: acp::ToolCall,
 116    pub options: Vec<acp::PermissionOption>,
 117    pub response: oneshot::Sender<acp::PermissionOptionId>,
 118}
 119
 120pub struct Thread {
 121    messages: Vec<AgentMessage>,
 122    completion_mode: CompletionMode,
 123    /// Holds the task that handles agent interaction until the end of the turn.
 124    /// Survives across multiple requests as the model performs tool calls and
 125    /// we run tools, report their results.
 126    running_turn: Option<Task<()>>,
 127    pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
 128    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 129    project_context: Rc<RefCell<ProjectContext>>,
 130    templates: Arc<Templates>,
 131    pub selected_model: Arc<dyn LanguageModel>,
 132    project: Entity<Project>,
 133    action_log: Entity<ActionLog>,
 134}
 135
 136impl Thread {
 137    pub fn new(
 138        project: Entity<Project>,
 139        project_context: Rc<RefCell<ProjectContext>>,
 140        action_log: Entity<ActionLog>,
 141        templates: Arc<Templates>,
 142        default_model: Arc<dyn LanguageModel>,
 143    ) -> Self {
 144        Self {
 145            messages: Vec::new(),
 146            completion_mode: CompletionMode::Normal,
 147            running_turn: None,
 148            pending_tool_uses: HashMap::default(),
 149            tools: BTreeMap::default(),
 150            project_context,
 151            templates,
 152            selected_model: default_model,
 153            project,
 154            action_log,
 155        }
 156    }
 157
 158    pub fn project(&self) -> &Entity<Project> {
 159        &self.project
 160    }
 161
 162    pub fn action_log(&self) -> &Entity<ActionLog> {
 163        &self.action_log
 164    }
 165
 166    pub fn set_mode(&mut self, mode: CompletionMode) {
 167        self.completion_mode = mode;
 168    }
 169
 170    pub fn messages(&self) -> &[AgentMessage] {
 171        &self.messages
 172    }
 173
 174    pub fn add_tool(&mut self, tool: impl AgentTool) {
 175        self.tools.insert(tool.name(), tool.erase());
 176    }
 177
 178    pub fn remove_tool(&mut self, name: &str) -> bool {
 179        self.tools.remove(name).is_some()
 180    }
 181
 182    pub fn cancel(&mut self) {
 183        self.running_turn.take();
 184
 185        let tool_results = self
 186            .pending_tool_uses
 187            .drain()
 188            .map(|(tool_use_id, tool_use)| {
 189                MessageContent::ToolResult(LanguageModelToolResult {
 190                    tool_use_id,
 191                    tool_name: tool_use.name.clone(),
 192                    is_error: true,
 193                    content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
 194                    output: None,
 195                })
 196            })
 197            .collect::<Vec<_>>();
 198        self.last_user_message().content.extend(tool_results);
 199    }
 200
 201    /// Sending a message results in the model streaming a response, which could include tool calls.
 202    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
 203    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
 204    pub fn send(
 205        &mut self,
 206        content: impl Into<MessageContent>,
 207        cx: &mut Context<Self>,
 208    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
 209        let content = content.into();
 210        let model = self.selected_model.clone();
 211        log::info!("Thread::send called with model: {:?}", model.name());
 212        log::debug!("Thread::send content: {:?}", content);
 213
 214        cx.notify();
 215        let (events_tx, events_rx) =
 216            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 217        let event_stream = AgentResponseEventStream(events_tx);
 218
 219        let user_message_ix = self.messages.len();
 220        self.messages.push(AgentMessage {
 221            role: Role::User,
 222            content: vec![content],
 223        });
 224        log::info!("Total messages in thread: {}", self.messages.len());
 225        self.running_turn = Some(cx.spawn(async move |thread, cx| {
 226            log::info!("Starting agent turn execution");
 227            let turn_result = async {
 228                // Perform one request, then keep looping if the model makes tool calls.
 229                let mut completion_intent = CompletionIntent::UserPrompt;
 230                'outer: loop {
 231                    log::debug!(
 232                        "Building completion request with intent: {:?}",
 233                        completion_intent
 234                    );
 235                    let request = thread.update(cx, |thread, cx| {
 236                        thread.build_completion_request(completion_intent, cx)
 237                    })?;
 238
 239                    // println!(
 240                    //     "request: {}",
 241                    //     serde_json::to_string_pretty(&request).unwrap()
 242                    // );
 243
 244                    // Stream events, appending to messages and collecting up tool uses.
 245                    log::info!("Calling model.stream_completion");
 246                    let mut events = model.stream_completion(request, cx).await?;
 247                    log::debug!("Stream completion started successfully");
 248                    let mut tool_uses = FuturesUnordered::new();
 249                    while let Some(event) = events.next().await {
 250                        match event {
 251                            Ok(LanguageModelCompletionEvent::Stop(reason)) => {
 252                                event_stream.send_stop(reason);
 253                                if reason == StopReason::Refusal {
 254                                    thread.update(cx, |thread, _cx| {
 255                                        thread.messages.truncate(user_message_ix);
 256                                    })?;
 257                                    break 'outer;
 258                                }
 259                            }
 260                            Ok(event) => {
 261                                log::trace!("Received completion event: {:?}", event);
 262                                thread
 263                                    .update(cx, |thread, cx| {
 264                                        tool_uses.extend(thread.handle_streamed_completion_event(
 265                                            event,
 266                                            &event_stream,
 267                                            cx,
 268                                        ));
 269                                    })
 270                                    .ok();
 271                            }
 272                            Err(error) => {
 273                                log::error!("Error in completion stream: {:?}", error);
 274                                event_stream.send_error(error);
 275                                break;
 276                            }
 277                        }
 278                    }
 279
 280                    // If there are no tool uses, the turn is done.
 281                    if tool_uses.is_empty() {
 282                        log::info!("No tool uses found, completing turn");
 283                        break;
 284                    }
 285                    log::info!("Found {} tool uses to execute", tool_uses.len());
 286
 287                    // As tool results trickle in, insert them in the last user
 288                    // message so that they can be sent on the next tick of the
 289                    // agentic loop.
 290                    while let Some(tool_result) = tool_uses.next().await {
 291                        log::info!("Tool finished {:?}", tool_result);
 292
 293                        event_stream.update_tool_call_fields(
 294                            &tool_result.tool_use_id,
 295                            acp::ToolCallUpdateFields {
 296                                status: Some(if tool_result.is_error {
 297                                    acp::ToolCallStatus::Failed
 298                                } else {
 299                                    acp::ToolCallStatus::Completed
 300                                }),
 301                                ..Default::default()
 302                            },
 303                        );
 304                        thread
 305                            .update(cx, |thread, _cx| {
 306                                thread.pending_tool_uses.remove(&tool_result.tool_use_id);
 307                                thread
 308                                    .last_user_message()
 309                                    .content
 310                                    .push(MessageContent::ToolResult(tool_result));
 311                            })
 312                            .ok();
 313                    }
 314
 315                    completion_intent = CompletionIntent::ToolResults;
 316                }
 317
 318                Ok(())
 319            }
 320            .await;
 321
 322            if let Err(error) = turn_result {
 323                log::error!("Turn execution failed: {:?}", error);
 324                event_stream.send_error(error);
 325            } else {
 326                log::info!("Turn execution completed successfully");
 327            }
 328        }));
 329        events_rx
 330    }
 331
 332    pub fn build_system_message(&self) -> AgentMessage {
 333        log::debug!("Building system message");
 334        let prompt = SystemPromptTemplate {
 335            project: &self.project_context.borrow(),
 336            available_tools: self.tools.keys().cloned().collect(),
 337        }
 338        .render(&self.templates)
 339        .context("failed to build system prompt")
 340        .expect("Invalid template");
 341        log::debug!("System message built");
 342        AgentMessage {
 343            role: Role::System,
 344            content: vec![prompt.into()],
 345        }
 346    }
 347
 348    /// A helper method that's called on every streamed completion event.
 349    /// Returns an optional tool result task, which the main agentic loop in
 350    /// send will send back to the model when it resolves.
 351    fn handle_streamed_completion_event(
 352        &mut self,
 353        event: LanguageModelCompletionEvent,
 354        event_stream: &AgentResponseEventStream,
 355        cx: &mut Context<Self>,
 356    ) -> Option<Task<LanguageModelToolResult>> {
 357        log::trace!("Handling streamed completion event: {:?}", event);
 358        use LanguageModelCompletionEvent::*;
 359
 360        match event {
 361            StartMessage { .. } => {
 362                self.messages.push(AgentMessage {
 363                    role: Role::Assistant,
 364                    content: Vec::new(),
 365                });
 366            }
 367            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
 368            Thinking { text, signature } => {
 369                self.handle_thinking_event(text, signature, event_stream, cx)
 370            }
 371            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
 372            ToolUse(tool_use) => {
 373                return self.handle_tool_use_event(tool_use, event_stream, cx);
 374            }
 375            ToolUseJsonParseError {
 376                id,
 377                tool_name,
 378                raw_input,
 379                json_parse_error,
 380            } => {
 381                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
 382                    id,
 383                    tool_name,
 384                    raw_input,
 385                    json_parse_error,
 386                )));
 387            }
 388            UsageUpdate(_) | StatusUpdate(_) => {}
 389            Stop(_) => unreachable!(),
 390        }
 391
 392        None
 393    }
 394
 395    fn handle_text_event(
 396        &mut self,
 397        new_text: String,
 398        events_stream: &AgentResponseEventStream,
 399        cx: &mut Context<Self>,
 400    ) {
 401        events_stream.send_text(&new_text);
 402
 403        let last_message = self.last_assistant_message();
 404        if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
 405            text.push_str(&new_text);
 406        } else {
 407            last_message.content.push(MessageContent::Text(new_text));
 408        }
 409
 410        cx.notify();
 411    }
 412
 413    fn handle_thinking_event(
 414        &mut self,
 415        new_text: String,
 416        new_signature: Option<String>,
 417        event_stream: &AgentResponseEventStream,
 418        cx: &mut Context<Self>,
 419    ) {
 420        event_stream.send_thinking(&new_text);
 421
 422        let last_message = self.last_assistant_message();
 423        if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
 424        {
 425            text.push_str(&new_text);
 426            *signature = new_signature.or(signature.take());
 427        } else {
 428            last_message.content.push(MessageContent::Thinking {
 429                text: new_text,
 430                signature: new_signature,
 431            });
 432        }
 433
 434        cx.notify();
 435    }
 436
 437    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
 438        let last_message = self.last_assistant_message();
 439        last_message
 440            .content
 441            .push(MessageContent::RedactedThinking(data));
 442        cx.notify();
 443    }
 444
 445    fn handle_tool_use_event(
 446        &mut self,
 447        tool_use: LanguageModelToolUse,
 448        event_stream: &AgentResponseEventStream,
 449        cx: &mut Context<Self>,
 450    ) -> Option<Task<LanguageModelToolResult>> {
 451        cx.notify();
 452
 453        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
 454
 455        self.pending_tool_uses
 456            .insert(tool_use.id.clone(), tool_use.clone());
 457        let last_message = self.last_assistant_message();
 458
 459        // Ensure the last message ends in the current tool use
 460        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
 461            if let MessageContent::ToolUse(last_tool_use) = content {
 462                if last_tool_use.id == tool_use.id {
 463                    *last_tool_use = tool_use.clone();
 464                    false
 465                } else {
 466                    true
 467                }
 468            } else {
 469                true
 470            }
 471        });
 472
 473        let mut title = SharedString::from(&tool_use.name);
 474        let mut kind = acp::ToolKind::Other;
 475        if let Some(tool) = tool.as_ref() {
 476            title = tool.initial_title(tool_use.input.clone());
 477            kind = tool.kind();
 478        }
 479
 480        if push_new_tool_use {
 481            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
 482            last_message
 483                .content
 484                .push(MessageContent::ToolUse(tool_use.clone()));
 485        } else {
 486            event_stream.update_tool_call_fields(
 487                &tool_use.id,
 488                acp::ToolCallUpdateFields {
 489                    title: Some(title.into()),
 490                    kind: Some(kind),
 491                    raw_input: Some(tool_use.input.clone()),
 492                    ..Default::default()
 493                },
 494            );
 495        }
 496
 497        if !tool_use.is_input_complete {
 498            return None;
 499        }
 500
 501        let Some(tool) = tool else {
 502            let content = format!("No tool named {} exists", tool_use.name);
 503            return Some(Task::ready(LanguageModelToolResult {
 504                content: LanguageModelToolResultContent::Text(Arc::from(content)),
 505                tool_use_id: tool_use.id,
 506                tool_name: tool_use.name,
 507                is_error: true,
 508                output: None,
 509            }));
 510        };
 511
 512        let fs = self.project.read(cx).fs().clone();
 513        let tool_event_stream =
 514            ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
 515        tool_event_stream.update_fields(acp::ToolCallUpdateFields {
 516            status: Some(acp::ToolCallStatus::InProgress),
 517            ..Default::default()
 518        });
 519        let supports_images = self.selected_model.supports_images();
 520        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
 521        Some(cx.foreground_executor().spawn(async move {
 522            let tool_result = tool_result.await.and_then(|output| {
 523                if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
 524                    if !supports_images {
 525                        return Err(anyhow!(
 526                            "Attempted to read an image, but this model doesn't support it.",
 527                        ));
 528                    }
 529                }
 530                Ok(output)
 531            });
 532
 533            match tool_result {
 534                Ok(output) => LanguageModelToolResult {
 535                    tool_use_id: tool_use.id,
 536                    tool_name: tool_use.name,
 537                    is_error: false,
 538                    content: output.llm_output,
 539                    output: Some(output.raw_output),
 540                },
 541                Err(error) => LanguageModelToolResult {
 542                    tool_use_id: tool_use.id,
 543                    tool_name: tool_use.name,
 544                    is_error: true,
 545                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
 546                    output: None,
 547                },
 548            }
 549        }))
 550    }
 551
 552    fn handle_tool_use_json_parse_error_event(
 553        &mut self,
 554        tool_use_id: LanguageModelToolUseId,
 555        tool_name: Arc<str>,
 556        raw_input: Arc<str>,
 557        json_parse_error: String,
 558    ) -> LanguageModelToolResult {
 559        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
 560        LanguageModelToolResult {
 561            tool_use_id,
 562            tool_name,
 563            is_error: true,
 564            content: LanguageModelToolResultContent::Text(tool_output.into()),
 565            output: Some(serde_json::Value::String(raw_input.to_string())),
 566        }
 567    }
 568
 569    /// Guarantees the last message is from the assistant and returns a mutable reference.
 570    fn last_assistant_message(&mut self) -> &mut AgentMessage {
 571        if self
 572            .messages
 573            .last()
 574            .map_or(true, |m| m.role != Role::Assistant)
 575        {
 576            self.messages.push(AgentMessage {
 577                role: Role::Assistant,
 578                content: Vec::new(),
 579            });
 580        }
 581        self.messages.last_mut().unwrap()
 582    }
 583
 584    /// Guarantees the last message is from the user and returns a mutable reference.
 585    fn last_user_message(&mut self) -> &mut AgentMessage {
 586        if self.messages.last().map_or(true, |m| m.role != Role::User) {
 587            self.messages.push(AgentMessage {
 588                role: Role::User,
 589                content: Vec::new(),
 590            });
 591        }
 592        self.messages.last_mut().unwrap()
 593    }
 594
 595    pub(crate) fn build_completion_request(
 596        &self,
 597        completion_intent: CompletionIntent,
 598        cx: &mut App,
 599    ) -> LanguageModelRequest {
 600        log::debug!("Building completion request");
 601        log::debug!("Completion intent: {:?}", completion_intent);
 602        log::debug!("Completion mode: {:?}", self.completion_mode);
 603
 604        let messages = self.build_request_messages();
 605        log::info!("Request will include {} messages", messages.len());
 606
 607        let tools: Vec<LanguageModelRequestTool> = self
 608            .tools
 609            .values()
 610            .filter_map(|tool| {
 611                let tool_name = tool.name().to_string();
 612                log::trace!("Including tool: {}", tool_name);
 613                Some(LanguageModelRequestTool {
 614                    name: tool_name,
 615                    description: tool.description(cx).to_string(),
 616                    input_schema: tool
 617                        .input_schema(self.selected_model.tool_input_format())
 618                        .log_err()?,
 619                })
 620            })
 621            .collect();
 622
 623        log::info!("Request includes {} tools", tools.len());
 624
 625        let request = LanguageModelRequest {
 626            thread_id: None,
 627            prompt_id: None,
 628            intent: Some(completion_intent),
 629            mode: Some(self.completion_mode),
 630            messages,
 631            tools,
 632            tool_choice: None,
 633            stop: Vec::new(),
 634            temperature: None,
 635            thinking_allowed: true,
 636        };
 637
 638        log::debug!("Completion request built successfully");
 639        request
 640    }
 641
 642    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
 643        log::trace!(
 644            "Building request messages from {} thread messages",
 645            self.messages.len()
 646        );
 647
 648        let messages = Some(self.build_system_message())
 649            .iter()
 650            .chain(self.messages.iter())
 651            .map(|message| {
 652                log::trace!(
 653                    "  - {} message with {} content items",
 654                    match message.role {
 655                        Role::System => "System",
 656                        Role::User => "User",
 657                        Role::Assistant => "Assistant",
 658                    },
 659                    message.content.len()
 660                );
 661                LanguageModelRequestMessage {
 662                    role: message.role,
 663                    content: message.content.clone(),
 664                    cache: false,
 665                }
 666            })
 667            .collect();
 668        messages
 669    }
 670
 671    pub fn to_markdown(&self) -> String {
 672        let mut markdown = String::new();
 673        for message in &self.messages {
 674            markdown.push_str(&message.to_markdown());
 675        }
 676        markdown
 677    }
 678}
 679
 680pub trait AgentTool
 681where
 682    Self: 'static + Sized,
 683{
 684    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
 685    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
 686
 687    fn name(&self) -> SharedString;
 688
 689    fn description(&self, _cx: &mut App) -> SharedString {
 690        let schema = schemars::schema_for!(Self::Input);
 691        SharedString::new(
 692            schema
 693                .get("description")
 694                .and_then(|description| description.as_str())
 695                .unwrap_or_default(),
 696        )
 697    }
 698
 699    fn kind(&self) -> acp::ToolKind;
 700
 701    /// The initial tool title to display. Can be updated during the tool run.
 702    fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
 703
 704    /// Returns the JSON schema that describes the tool's input.
 705    fn input_schema(&self) -> Schema {
 706        schemars::schema_for!(Self::Input)
 707    }
 708
 709    /// Runs the tool with the provided input.
 710    fn run(
 711        self: Arc<Self>,
 712        input: Self::Input,
 713        event_stream: ToolCallEventStream,
 714        cx: &mut App,
 715    ) -> Task<Result<Self::Output>>;
 716
 717    fn erase(self) -> Arc<dyn AnyAgentTool> {
 718        Arc::new(Erased(Arc::new(self)))
 719    }
 720}
 721
 722pub struct Erased<T>(T);
 723
 724pub struct AgentToolOutput {
 725    llm_output: LanguageModelToolResultContent,
 726    raw_output: serde_json::Value,
 727}
 728
 729pub trait AnyAgentTool {
 730    fn name(&self) -> SharedString;
 731    fn description(&self, cx: &mut App) -> SharedString;
 732    fn kind(&self) -> acp::ToolKind;
 733    fn initial_title(&self, input: serde_json::Value) -> SharedString;
 734    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
 735    fn run(
 736        self: Arc<Self>,
 737        input: serde_json::Value,
 738        event_stream: ToolCallEventStream,
 739        cx: &mut App,
 740    ) -> Task<Result<AgentToolOutput>>;
 741}
 742
 743impl<T> AnyAgentTool for Erased<Arc<T>>
 744where
 745    T: AgentTool,
 746{
 747    fn name(&self) -> SharedString {
 748        self.0.name()
 749    }
 750
 751    fn description(&self, cx: &mut App) -> SharedString {
 752        self.0.description(cx)
 753    }
 754
 755    fn kind(&self) -> agent_client_protocol::ToolKind {
 756        self.0.kind()
 757    }
 758
 759    fn initial_title(&self, input: serde_json::Value) -> SharedString {
 760        let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
 761        self.0.initial_title(parsed_input)
 762    }
 763
 764    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 765        let mut json = serde_json::to_value(self.0.input_schema())?;
 766        adapt_schema_to_format(&mut json, format)?;
 767        Ok(json)
 768    }
 769
 770    fn run(
 771        self: Arc<Self>,
 772        input: serde_json::Value,
 773        event_stream: ToolCallEventStream,
 774        cx: &mut App,
 775    ) -> Task<Result<AgentToolOutput>> {
 776        cx.spawn(async move |cx| {
 777            let input = serde_json::from_value(input)?;
 778            let output = cx
 779                .update(|cx| self.0.clone().run(input, event_stream, cx))?
 780                .await?;
 781            let raw_output = serde_json::to_value(&output)?;
 782            Ok(AgentToolOutput {
 783                llm_output: output.into(),
 784                raw_output,
 785            })
 786        })
 787    }
 788}
 789
 790#[derive(Clone)]
 791struct AgentResponseEventStream(
 792    mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
 793);
 794
 795impl AgentResponseEventStream {
 796    fn send_text(&self, text: &str) {
 797        self.0
 798            .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
 799            .ok();
 800    }
 801
 802    fn send_thinking(&self, text: &str) {
 803        self.0
 804            .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
 805            .ok();
 806    }
 807
 808    fn send_tool_call(
 809        &self,
 810        id: &LanguageModelToolUseId,
 811        title: SharedString,
 812        kind: acp::ToolKind,
 813        input: serde_json::Value,
 814    ) {
 815        self.0
 816            .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
 817                id,
 818                title.to_string(),
 819                kind,
 820                input,
 821            ))))
 822            .ok();
 823    }
 824
 825    fn initial_tool_call(
 826        id: &LanguageModelToolUseId,
 827        title: String,
 828        kind: acp::ToolKind,
 829        input: serde_json::Value,
 830    ) -> acp::ToolCall {
 831        acp::ToolCall {
 832            id: acp::ToolCallId(id.to_string().into()),
 833            title,
 834            kind,
 835            status: acp::ToolCallStatus::Pending,
 836            content: vec![],
 837            locations: vec![],
 838            raw_input: Some(input),
 839            raw_output: None,
 840        }
 841    }
 842
 843    fn update_tool_call_fields(
 844        &self,
 845        tool_use_id: &LanguageModelToolUseId,
 846        fields: acp::ToolCallUpdateFields,
 847    ) {
 848        self.0
 849            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
 850                acp::ToolCallUpdate {
 851                    id: acp::ToolCallId(tool_use_id.to_string().into()),
 852                    fields,
 853                }
 854                .into(),
 855            )))
 856            .ok();
 857    }
 858
 859    fn send_stop(&self, reason: StopReason) {
 860        match reason {
 861            StopReason::EndTurn => {
 862                self.0
 863                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
 864                    .ok();
 865            }
 866            StopReason::MaxTokens => {
 867                self.0
 868                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
 869                    .ok();
 870            }
 871            StopReason::Refusal => {
 872                self.0
 873                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
 874                    .ok();
 875            }
 876            StopReason::ToolUse => {}
 877        }
 878    }
 879
 880    fn send_error(&self, error: LanguageModelCompletionError) {
 881        self.0.unbounded_send(Err(error)).ok();
 882    }
 883}
 884
 885#[derive(Clone)]
 886pub struct ToolCallEventStream {
 887    tool_use_id: LanguageModelToolUseId,
 888    kind: acp::ToolKind,
 889    input: serde_json::Value,
 890    stream: AgentResponseEventStream,
 891    fs: Option<Arc<dyn Fs>>,
 892}
 893
 894impl ToolCallEventStream {
 895    #[cfg(test)]
 896    pub fn test() -> (Self, ToolCallEventStreamReceiver) {
 897        let (events_tx, events_rx) =
 898            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 899
 900        let stream = ToolCallEventStream::new(
 901            &LanguageModelToolUse {
 902                id: "test_id".into(),
 903                name: "test_tool".into(),
 904                raw_input: String::new(),
 905                input: serde_json::Value::Null,
 906                is_input_complete: true,
 907            },
 908            acp::ToolKind::Other,
 909            AgentResponseEventStream(events_tx),
 910            None,
 911        );
 912
 913        (stream, ToolCallEventStreamReceiver(events_rx))
 914    }
 915
 916    fn new(
 917        tool_use: &LanguageModelToolUse,
 918        kind: acp::ToolKind,
 919        stream: AgentResponseEventStream,
 920        fs: Option<Arc<dyn Fs>>,
 921    ) -> Self {
 922        Self {
 923            tool_use_id: tool_use.id.clone(),
 924            kind,
 925            input: tool_use.input.clone(),
 926            stream,
 927            fs,
 928        }
 929    }
 930
 931    pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
 932        self.stream
 933            .update_tool_call_fields(&self.tool_use_id, fields);
 934    }
 935
 936    pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
 937        self.stream
 938            .0
 939            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
 940                acp_thread::ToolCallUpdateDiff {
 941                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
 942                    diff,
 943                }
 944                .into(),
 945            )))
 946            .ok();
 947    }
 948
 949    pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
 950        self.stream
 951            .0
 952            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
 953                acp_thread::ToolCallUpdateTerminal {
 954                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
 955                    terminal,
 956                }
 957                .into(),
 958            )))
 959            .ok();
 960    }
 961
 962    pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
 963        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
 964            return Task::ready(Ok(()));
 965        }
 966
 967        let (response_tx, response_rx) = oneshot::channel();
 968        self.stream
 969            .0
 970            .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
 971                ToolCallAuthorization {
 972                    tool_call: AgentResponseEventStream::initial_tool_call(
 973                        &self.tool_use_id,
 974                        title.into(),
 975                        self.kind.clone(),
 976                        self.input.clone(),
 977                    ),
 978                    options: vec![
 979                        acp::PermissionOption {
 980                            id: acp::PermissionOptionId("always_allow".into()),
 981                            name: "Always Allow".into(),
 982                            kind: acp::PermissionOptionKind::AllowAlways,
 983                        },
 984                        acp::PermissionOption {
 985                            id: acp::PermissionOptionId("allow".into()),
 986                            name: "Allow".into(),
 987                            kind: acp::PermissionOptionKind::AllowOnce,
 988                        },
 989                        acp::PermissionOption {
 990                            id: acp::PermissionOptionId("deny".into()),
 991                            name: "Deny".into(),
 992                            kind: acp::PermissionOptionKind::RejectOnce,
 993                        },
 994                    ],
 995                    response: response_tx,
 996                },
 997            )))
 998            .ok();
 999        let fs = self.fs.clone();
1000        cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1001            "always_allow" => {
1002                if let Some(fs) = fs.clone() {
1003                    cx.update(|cx| {
1004                        update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1005                            settings.set_always_allow_tool_actions(true);
1006                        });
1007                    })?;
1008                }
1009
1010                Ok(())
1011            }
1012            "allow" => Ok(()),
1013            _ => Err(anyhow!("Permission to run tool denied by user")),
1014        })
1015    }
1016}
1017
/// Test-only wrapper around the receiving half of the channel that a
/// `ToolCallEventStream` sends on.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(
    mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
1022
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next item and returns it if it is a tool call
    /// authorization request.
    ///
    /// # Panics
    ///
    /// Panics if the next item is anything else, including an error or the
    /// end of the stream.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) => auth,
            event => panic!("Expected ToolCallAuthorization but got: {:?}", event),
        }
    }

    /// Awaits the next item and returns the terminal it carries if it is a
    /// terminal tool call update.
    ///
    /// # Panics
    ///
    /// Panics if the next item is anything else, including an error or the
    /// end of the stream.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            event => panic!("Expected terminal but got: {:?}", event),
        }
    }
}
1046
/// Gives tests shared access to the wrapped receiver's own API.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;

    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1055
/// Gives tests mutable access to the wrapped receiver's own API.
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}