thread.rs

   1use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
   2use action_log::ActionLog;
   3use agent_client_protocol as acp;
   4use agent_settings::{AgentProfileId, AgentSettings};
   5use anyhow::{Context as _, Result, anyhow};
   6use assistant_tool::adapt_schema_to_format;
   7use cloud_llm_client::{CompletionIntent, CompletionMode};
   8use collections::HashMap;
   9use fs::Fs;
  10use futures::{
  11    channel::{mpsc, oneshot},
  12    stream::FuturesUnordered,
  13};
  14use gpui::{App, Context, Entity, SharedString, Task};
  15use language_model::{
  16    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  17    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  18    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
  19    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
  20};
  21use log;
  22use project::Project;
  23use prompt_store::ProjectContext;
  24use schemars::{JsonSchema, Schema};
  25use serde::{Deserialize, Serialize};
  26use settings::{Settings, update_settings_file};
  27use smol::stream::StreamExt;
  28use std::{cell::RefCell, collections::BTreeMap, fmt::Write, rc::Rc, sync::Arc};
  29use util::{ResultExt, markdown::MarkdownCodeBlock};
  30
  31#[derive(Debug, Clone)]
  32pub struct AgentMessage {
  33    pub role: Role,
  34    pub content: Vec<MessageContent>,
  35}
  36
  37impl AgentMessage {
  38    pub fn to_markdown(&self) -> String {
  39        let mut markdown = format!("## {}\n", self.role);
  40
  41        for content in &self.content {
  42            match content {
  43                MessageContent::Text(text) => {
  44                    markdown.push_str(text);
  45                    markdown.push('\n');
  46                }
  47                MessageContent::Thinking { text, .. } => {
  48                    markdown.push_str("<think>");
  49                    markdown.push_str(text);
  50                    markdown.push_str("</think>\n");
  51                }
  52                MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
  53                MessageContent::Image(_) => {
  54                    markdown.push_str("<image />\n");
  55                }
  56                MessageContent::ToolUse(tool_use) => {
  57                    markdown.push_str(&format!(
  58                        "**Tool Use**: {} (ID: {})\n",
  59                        tool_use.name, tool_use.id
  60                    ));
  61                    markdown.push_str(&format!(
  62                        "{}\n",
  63                        MarkdownCodeBlock {
  64                            tag: "json",
  65                            text: &format!("{:#}", tool_use.input)
  66                        }
  67                    ));
  68                }
  69                MessageContent::ToolResult(tool_result) => {
  70                    markdown.push_str(&format!(
  71                        "**Tool Result**: {} (ID: {})\n\n",
  72                        tool_result.tool_name, tool_result.tool_use_id
  73                    ));
  74                    if tool_result.is_error {
  75                        markdown.push_str("**ERROR:**\n");
  76                    }
  77
  78                    match &tool_result.content {
  79                        LanguageModelToolResultContent::Text(text) => {
  80                            writeln!(markdown, "{text}\n").ok();
  81                        }
  82                        LanguageModelToolResultContent::Image(_) => {
  83                            writeln!(markdown, "<image />\n").ok();
  84                        }
  85                    }
  86
  87                    if let Some(output) = tool_result.output.as_ref() {
  88                        writeln!(
  89                            markdown,
  90                            "**Debug Output**:\n\n```json\n{}\n```\n",
  91                            serde_json::to_string_pretty(output).unwrap()
  92                        )
  93                        .unwrap();
  94                    }
  95                }
  96            }
  97        }
  98
  99        markdown
 100    }
 101}
 102
 103#[derive(Debug)]
 104pub enum AgentResponseEvent {
 105    Text(String),
 106    Thinking(String),
 107    ToolCall(acp::ToolCall),
 108    ToolCallUpdate(acp_thread::ToolCallUpdate),
 109    ToolCallAuthorization(ToolCallAuthorization),
 110    Stop(acp::StopReason),
 111}
 112
 113#[derive(Debug)]
 114pub struct ToolCallAuthorization {
 115    pub tool_call: acp::ToolCall,
 116    pub options: Vec<acp::PermissionOption>,
 117    pub response: oneshot::Sender<acp::PermissionOptionId>,
 118}
 119
 120pub struct Thread {
 121    messages: Vec<AgentMessage>,
 122    completion_mode: CompletionMode,
 123    /// Holds the task that handles agent interaction until the end of the turn.
 124    /// Survives across multiple requests as the model performs tool calls and
 125    /// we run tools, report their results.
 126    running_turn: Option<Task<()>>,
 127    pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
 128    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 129    context_server_registry: Entity<ContextServerRegistry>,
 130    profile_id: AgentProfileId,
 131    project_context: Rc<RefCell<ProjectContext>>,
 132    templates: Arc<Templates>,
 133    pub selected_model: Arc<dyn LanguageModel>,
 134    project: Entity<Project>,
 135    action_log: Entity<ActionLog>,
 136}
 137
 138impl Thread {
 139    pub fn new(
 140        project: Entity<Project>,
 141        project_context: Rc<RefCell<ProjectContext>>,
 142        context_server_registry: Entity<ContextServerRegistry>,
 143        action_log: Entity<ActionLog>,
 144        templates: Arc<Templates>,
 145        default_model: Arc<dyn LanguageModel>,
 146        cx: &mut Context<Self>,
 147    ) -> Self {
 148        let profile_id = AgentSettings::get_global(cx).default_profile.clone();
 149        Self {
 150            messages: Vec::new(),
 151            completion_mode: CompletionMode::Normal,
 152            running_turn: None,
 153            pending_tool_uses: HashMap::default(),
 154            tools: BTreeMap::default(),
 155            context_server_registry,
 156            profile_id,
 157            project_context,
 158            templates,
 159            selected_model: default_model,
 160            project,
 161            action_log,
 162        }
 163    }
 164
 165    pub fn project(&self) -> &Entity<Project> {
 166        &self.project
 167    }
 168
 169    pub fn action_log(&self) -> &Entity<ActionLog> {
 170        &self.action_log
 171    }
 172
 173    pub fn set_mode(&mut self, mode: CompletionMode) {
 174        self.completion_mode = mode;
 175    }
 176
 177    pub fn messages(&self) -> &[AgentMessage] {
 178        &self.messages
 179    }
 180
 181    pub fn add_tool(&mut self, tool: impl AgentTool) {
 182        self.tools.insert(tool.name(), tool.erase());
 183    }
 184
 185    pub fn remove_tool(&mut self, name: &str) -> bool {
 186        self.tools.remove(name).is_some()
 187    }
 188
 189    pub fn set_profile(&mut self, profile_id: AgentProfileId) {
 190        self.profile_id = profile_id;
 191    }
 192
 193    pub fn cancel(&mut self) {
 194        self.running_turn.take();
 195
 196        let tool_results = self
 197            .pending_tool_uses
 198            .drain()
 199            .map(|(tool_use_id, tool_use)| {
 200                MessageContent::ToolResult(LanguageModelToolResult {
 201                    tool_use_id,
 202                    tool_name: tool_use.name.clone(),
 203                    is_error: true,
 204                    content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
 205                    output: None,
 206                })
 207            })
 208            .collect::<Vec<_>>();
 209        self.last_user_message().content.extend(tool_results);
 210    }
 211
 212    /// Sending a message results in the model streaming a response, which could include tool calls.
 213    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
 214    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
 215    pub fn send(
 216        &mut self,
 217        content: impl Into<MessageContent>,
 218        cx: &mut Context<Self>,
 219    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
 220        let content = content.into();
 221        let model = self.selected_model.clone();
 222        log::info!("Thread::send called with model: {:?}", model.name());
 223        log::debug!("Thread::send content: {:?}", content);
 224
 225        cx.notify();
 226        let (events_tx, events_rx) =
 227            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 228        let event_stream = AgentResponseEventStream(events_tx);
 229
 230        let user_message_ix = self.messages.len();
 231        self.messages.push(AgentMessage {
 232            role: Role::User,
 233            content: vec![content],
 234        });
 235        log::info!("Total messages in thread: {}", self.messages.len());
 236        self.running_turn = Some(cx.spawn(async move |thread, cx| {
 237            log::info!("Starting agent turn execution");
 238            let turn_result = async {
 239                // Perform one request, then keep looping if the model makes tool calls.
 240                let mut completion_intent = CompletionIntent::UserPrompt;
 241                'outer: loop {
 242                    log::debug!(
 243                        "Building completion request with intent: {:?}",
 244                        completion_intent
 245                    );
 246                    let request = thread.update(cx, |thread, cx| {
 247                        thread.build_completion_request(completion_intent, cx)
 248                    })?;
 249
 250                    // println!(
 251                    //     "request: {}",
 252                    //     serde_json::to_string_pretty(&request).unwrap()
 253                    // );
 254
 255                    // Stream events, appending to messages and collecting up tool uses.
 256                    log::info!("Calling model.stream_completion");
 257                    let mut events = model.stream_completion(request, cx).await?;
 258                    log::debug!("Stream completion started successfully");
 259                    let mut tool_uses = FuturesUnordered::new();
 260                    while let Some(event) = events.next().await {
 261                        match event {
 262                            Ok(LanguageModelCompletionEvent::Stop(reason)) => {
 263                                event_stream.send_stop(reason);
 264                                if reason == StopReason::Refusal {
 265                                    thread.update(cx, |thread, _cx| {
 266                                        thread.messages.truncate(user_message_ix);
 267                                    })?;
 268                                    break 'outer;
 269                                }
 270                            }
 271                            Ok(event) => {
 272                                log::trace!("Received completion event: {:?}", event);
 273                                thread
 274                                    .update(cx, |thread, cx| {
 275                                        tool_uses.extend(thread.handle_streamed_completion_event(
 276                                            event,
 277                                            &event_stream,
 278                                            cx,
 279                                        ));
 280                                    })
 281                                    .ok();
 282                            }
 283                            Err(error) => {
 284                                log::error!("Error in completion stream: {:?}", error);
 285                                event_stream.send_error(error);
 286                                break;
 287                            }
 288                        }
 289                    }
 290
 291                    // If there are no tool uses, the turn is done.
 292                    if tool_uses.is_empty() {
 293                        log::info!("No tool uses found, completing turn");
 294                        break;
 295                    }
 296                    log::info!("Found {} tool uses to execute", tool_uses.len());
 297
 298                    // As tool results trickle in, insert them in the last user
 299                    // message so that they can be sent on the next tick of the
 300                    // agentic loop.
 301                    while let Some(tool_result) = tool_uses.next().await {
 302                        log::info!("Tool finished {:?}", tool_result);
 303
 304                        event_stream.update_tool_call_fields(
 305                            &tool_result.tool_use_id,
 306                            acp::ToolCallUpdateFields {
 307                                status: Some(if tool_result.is_error {
 308                                    acp::ToolCallStatus::Failed
 309                                } else {
 310                                    acp::ToolCallStatus::Completed
 311                                }),
 312                                raw_output: tool_result.output.clone(),
 313                                ..Default::default()
 314                            },
 315                        );
 316                        thread
 317                            .update(cx, |thread, _cx| {
 318                                thread.pending_tool_uses.remove(&tool_result.tool_use_id);
 319                                thread
 320                                    .last_user_message()
 321                                    .content
 322                                    .push(MessageContent::ToolResult(tool_result));
 323                            })
 324                            .ok();
 325                    }
 326
 327                    completion_intent = CompletionIntent::ToolResults;
 328                }
 329
 330                Ok(())
 331            }
 332            .await;
 333
 334            if let Err(error) = turn_result {
 335                log::error!("Turn execution failed: {:?}", error);
 336                event_stream.send_error(error);
 337            } else {
 338                log::info!("Turn execution completed successfully");
 339            }
 340        }));
 341        events_rx
 342    }
 343
 344    pub fn build_system_message(&self) -> AgentMessage {
 345        log::debug!("Building system message");
 346        let prompt = SystemPromptTemplate {
 347            project: &self.project_context.borrow(),
 348            available_tools: self.tools.keys().cloned().collect(),
 349        }
 350        .render(&self.templates)
 351        .context("failed to build system prompt")
 352        .expect("Invalid template");
 353        log::debug!("System message built");
 354        AgentMessage {
 355            role: Role::System,
 356            content: vec![prompt.into()],
 357        }
 358    }
 359
 360    /// A helper method that's called on every streamed completion event.
 361    /// Returns an optional tool result task, which the main agentic loop in
 362    /// send will send back to the model when it resolves.
 363    fn handle_streamed_completion_event(
 364        &mut self,
 365        event: LanguageModelCompletionEvent,
 366        event_stream: &AgentResponseEventStream,
 367        cx: &mut Context<Self>,
 368    ) -> Option<Task<LanguageModelToolResult>> {
 369        log::trace!("Handling streamed completion event: {:?}", event);
 370        use LanguageModelCompletionEvent::*;
 371
 372        match event {
 373            StartMessage { .. } => {
 374                self.messages.push(AgentMessage {
 375                    role: Role::Assistant,
 376                    content: Vec::new(),
 377                });
 378            }
 379            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
 380            Thinking { text, signature } => {
 381                self.handle_thinking_event(text, signature, event_stream, cx)
 382            }
 383            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
 384            ToolUse(tool_use) => {
 385                return self.handle_tool_use_event(tool_use, event_stream, cx);
 386            }
 387            ToolUseJsonParseError {
 388                id,
 389                tool_name,
 390                raw_input,
 391                json_parse_error,
 392            } => {
 393                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
 394                    id,
 395                    tool_name,
 396                    raw_input,
 397                    json_parse_error,
 398                )));
 399            }
 400            UsageUpdate(_) | StatusUpdate(_) => {}
 401            Stop(_) => unreachable!(),
 402        }
 403
 404        None
 405    }
 406
 407    fn handle_text_event(
 408        &mut self,
 409        new_text: String,
 410        events_stream: &AgentResponseEventStream,
 411        cx: &mut Context<Self>,
 412    ) {
 413        events_stream.send_text(&new_text);
 414
 415        let last_message = self.last_assistant_message();
 416        if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
 417            text.push_str(&new_text);
 418        } else {
 419            last_message.content.push(MessageContent::Text(new_text));
 420        }
 421
 422        cx.notify();
 423    }
 424
 425    fn handle_thinking_event(
 426        &mut self,
 427        new_text: String,
 428        new_signature: Option<String>,
 429        event_stream: &AgentResponseEventStream,
 430        cx: &mut Context<Self>,
 431    ) {
 432        event_stream.send_thinking(&new_text);
 433
 434        let last_message = self.last_assistant_message();
 435        if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
 436        {
 437            text.push_str(&new_text);
 438            *signature = new_signature.or(signature.take());
 439        } else {
 440            last_message.content.push(MessageContent::Thinking {
 441                text: new_text,
 442                signature: new_signature,
 443            });
 444        }
 445
 446        cx.notify();
 447    }
 448
 449    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
 450        let last_message = self.last_assistant_message();
 451        last_message
 452            .content
 453            .push(MessageContent::RedactedThinking(data));
 454        cx.notify();
 455    }
 456
 457    fn handle_tool_use_event(
 458        &mut self,
 459        tool_use: LanguageModelToolUse,
 460        event_stream: &AgentResponseEventStream,
 461        cx: &mut Context<Self>,
 462    ) -> Option<Task<LanguageModelToolResult>> {
 463        cx.notify();
 464
 465        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
 466
 467        self.pending_tool_uses
 468            .insert(tool_use.id.clone(), tool_use.clone());
 469        let last_message = self.last_assistant_message();
 470
 471        // Ensure the last message ends in the current tool use
 472        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
 473            if let MessageContent::ToolUse(last_tool_use) = content {
 474                if last_tool_use.id == tool_use.id {
 475                    *last_tool_use = tool_use.clone();
 476                    false
 477                } else {
 478                    true
 479                }
 480            } else {
 481                true
 482            }
 483        });
 484
 485        let mut title = SharedString::from(&tool_use.name);
 486        let mut kind = acp::ToolKind::Other;
 487        if let Some(tool) = tool.as_ref() {
 488            title = tool.initial_title(tool_use.input.clone());
 489            kind = tool.kind();
 490        }
 491
 492        if push_new_tool_use {
 493            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
 494            last_message
 495                .content
 496                .push(MessageContent::ToolUse(tool_use.clone()));
 497        } else {
 498            event_stream.update_tool_call_fields(
 499                &tool_use.id,
 500                acp::ToolCallUpdateFields {
 501                    title: Some(title.into()),
 502                    kind: Some(kind),
 503                    raw_input: Some(tool_use.input.clone()),
 504                    ..Default::default()
 505                },
 506            );
 507        }
 508
 509        if !tool_use.is_input_complete {
 510            return None;
 511        }
 512
 513        let Some(tool) = tool else {
 514            let content = format!("No tool named {} exists", tool_use.name);
 515            return Some(Task::ready(LanguageModelToolResult {
 516                content: LanguageModelToolResultContent::Text(Arc::from(content)),
 517                tool_use_id: tool_use.id,
 518                tool_name: tool_use.name,
 519                is_error: true,
 520                output: None,
 521            }));
 522        };
 523
 524        let fs = self.project.read(cx).fs().clone();
 525        let tool_event_stream =
 526            ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
 527        tool_event_stream.update_fields(acp::ToolCallUpdateFields {
 528            status: Some(acp::ToolCallStatus::InProgress),
 529            ..Default::default()
 530        });
 531        let supports_images = self.selected_model.supports_images();
 532        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
 533        Some(cx.foreground_executor().spawn(async move {
 534            let tool_result = tool_result.await.and_then(|output| {
 535                if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
 536                    if !supports_images {
 537                        return Err(anyhow!(
 538                            "Attempted to read an image, but this model doesn't support it.",
 539                        ));
 540                    }
 541                }
 542                Ok(output)
 543            });
 544
 545            match tool_result {
 546                Ok(output) => LanguageModelToolResult {
 547                    tool_use_id: tool_use.id,
 548                    tool_name: tool_use.name,
 549                    is_error: false,
 550                    content: output.llm_output,
 551                    output: Some(output.raw_output),
 552                },
 553                Err(error) => LanguageModelToolResult {
 554                    tool_use_id: tool_use.id,
 555                    tool_name: tool_use.name,
 556                    is_error: true,
 557                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
 558                    output: None,
 559                },
 560            }
 561        }))
 562    }
 563
 564    fn handle_tool_use_json_parse_error_event(
 565        &mut self,
 566        tool_use_id: LanguageModelToolUseId,
 567        tool_name: Arc<str>,
 568        raw_input: Arc<str>,
 569        json_parse_error: String,
 570    ) -> LanguageModelToolResult {
 571        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
 572        LanguageModelToolResult {
 573            tool_use_id,
 574            tool_name,
 575            is_error: true,
 576            content: LanguageModelToolResultContent::Text(tool_output.into()),
 577            output: Some(serde_json::Value::String(raw_input.to_string())),
 578        }
 579    }
 580
 581    /// Guarantees the last message is from the assistant and returns a mutable reference.
 582    fn last_assistant_message(&mut self) -> &mut AgentMessage {
 583        if self
 584            .messages
 585            .last()
 586            .map_or(true, |m| m.role != Role::Assistant)
 587        {
 588            self.messages.push(AgentMessage {
 589                role: Role::Assistant,
 590                content: Vec::new(),
 591            });
 592        }
 593        self.messages.last_mut().unwrap()
 594    }
 595
 596    /// Guarantees the last message is from the user and returns a mutable reference.
 597    fn last_user_message(&mut self) -> &mut AgentMessage {
 598        if self.messages.last().map_or(true, |m| m.role != Role::User) {
 599            self.messages.push(AgentMessage {
 600                role: Role::User,
 601                content: Vec::new(),
 602            });
 603        }
 604        self.messages.last_mut().unwrap()
 605    }
 606
 607    pub(crate) fn build_completion_request(
 608        &self,
 609        completion_intent: CompletionIntent,
 610        cx: &mut App,
 611    ) -> LanguageModelRequest {
 612        log::debug!("Building completion request");
 613        log::debug!("Completion intent: {:?}", completion_intent);
 614        log::debug!("Completion mode: {:?}", self.completion_mode);
 615
 616        let messages = self.build_request_messages();
 617        log::info!("Request will include {} messages", messages.len());
 618
 619        let tools = if let Some(tools) = self.tools(cx).log_err() {
 620            tools
 621                .filter_map(|tool| {
 622                    let tool_name = tool.name().to_string();
 623                    log::trace!("Including tool: {}", tool_name);
 624                    Some(LanguageModelRequestTool {
 625                        name: tool_name,
 626                        description: tool.description().to_string(),
 627                        input_schema: tool
 628                            .input_schema(self.selected_model.tool_input_format())
 629                            .log_err()?,
 630                    })
 631                })
 632                .collect()
 633        } else {
 634            Vec::new()
 635        };
 636
 637        log::info!("Request includes {} tools", tools.len());
 638
 639        let request = LanguageModelRequest {
 640            thread_id: None,
 641            prompt_id: None,
 642            intent: Some(completion_intent),
 643            mode: Some(self.completion_mode),
 644            messages,
 645            tools,
 646            tool_choice: None,
 647            stop: Vec::new(),
 648            temperature: None,
 649            thinking_allowed: true,
 650        };
 651
 652        log::debug!("Completion request built successfully");
 653        request
 654    }
 655
 656    fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
 657        let profile = AgentSettings::get_global(cx)
 658            .profiles
 659            .get(&self.profile_id)
 660            .context("profile not found")?;
 661
 662        Ok(self
 663            .tools
 664            .iter()
 665            .filter_map(|(tool_name, tool)| {
 666                if profile.is_tool_enabled(tool_name) {
 667                    Some(tool)
 668                } else {
 669                    None
 670                }
 671            })
 672            .chain(self.context_server_registry.read(cx).servers().flat_map(
 673                |(server_id, tools)| {
 674                    tools.iter().filter_map(|(tool_name, tool)| {
 675                        if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
 676                            Some(tool)
 677                        } else {
 678                            None
 679                        }
 680                    })
 681                },
 682            )))
 683    }
 684
 685    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
 686        log::trace!(
 687            "Building request messages from {} thread messages",
 688            self.messages.len()
 689        );
 690
 691        let messages = Some(self.build_system_message())
 692            .iter()
 693            .chain(self.messages.iter())
 694            .map(|message| {
 695                log::trace!(
 696                    "  - {} message with {} content items",
 697                    match message.role {
 698                        Role::System => "System",
 699                        Role::User => "User",
 700                        Role::Assistant => "Assistant",
 701                    },
 702                    message.content.len()
 703                );
 704                LanguageModelRequestMessage {
 705                    role: message.role,
 706                    content: message.content.clone(),
 707                    cache: false,
 708                }
 709            })
 710            .collect();
 711        messages
 712    }
 713
 714    pub fn to_markdown(&self) -> String {
 715        let mut markdown = String::new();
 716        for message in &self.messages {
 717            markdown.push_str(&message.to_markdown());
 718        }
 719        markdown
 720    }
 721}
 722
 723pub trait AgentTool
 724where
 725    Self: 'static + Sized,
 726{
 727    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
 728    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
 729
 730    fn name(&self) -> SharedString;
 731
 732    fn description(&self) -> SharedString {
 733        let schema = schemars::schema_for!(Self::Input);
 734        SharedString::new(
 735            schema
 736                .get("description")
 737                .and_then(|description| description.as_str())
 738                .unwrap_or_default(),
 739        )
 740    }
 741
 742    fn kind(&self) -> acp::ToolKind;
 743
 744    /// The initial tool title to display. Can be updated during the tool run.
 745    fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
 746
 747    /// Returns the JSON schema that describes the tool's input.
 748    fn input_schema(&self) -> Schema {
 749        schemars::schema_for!(Self::Input)
 750    }
 751
 752    /// Runs the tool with the provided input.
 753    fn run(
 754        self: Arc<Self>,
 755        input: Self::Input,
 756        event_stream: ToolCallEventStream,
 757        cx: &mut App,
 758    ) -> Task<Result<Self::Output>>;
 759
 760    fn erase(self) -> Arc<dyn AnyAgentTool> {
 761        Arc::new(Erased(Arc::new(self)))
 762    }
 763}
 764
 765pub struct Erased<T>(T);
 766
 767pub struct AgentToolOutput {
 768    pub llm_output: LanguageModelToolResultContent,
 769    pub raw_output: serde_json::Value,
 770}
 771
 772pub trait AnyAgentTool {
 773    fn name(&self) -> SharedString;
 774    fn description(&self) -> SharedString;
 775    fn kind(&self) -> acp::ToolKind;
 776    fn initial_title(&self, input: serde_json::Value) -> SharedString;
 777    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
 778    fn run(
 779        self: Arc<Self>,
 780        input: serde_json::Value,
 781        event_stream: ToolCallEventStream,
 782        cx: &mut App,
 783    ) -> Task<Result<AgentToolOutput>>;
 784}
 785
 786impl<T> AnyAgentTool for Erased<Arc<T>>
 787where
 788    T: AgentTool,
 789{
 790    fn name(&self) -> SharedString {
 791        self.0.name()
 792    }
 793
 794    fn description(&self) -> SharedString {
 795        self.0.description()
 796    }
 797
 798    fn kind(&self) -> agent_client_protocol::ToolKind {
 799        self.0.kind()
 800    }
 801
 802    fn initial_title(&self, input: serde_json::Value) -> SharedString {
 803        let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
 804        self.0.initial_title(parsed_input)
 805    }
 806
 807    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 808        let mut json = serde_json::to_value(self.0.input_schema())?;
 809        adapt_schema_to_format(&mut json, format)?;
 810        Ok(json)
 811    }
 812
 813    fn run(
 814        self: Arc<Self>,
 815        input: serde_json::Value,
 816        event_stream: ToolCallEventStream,
 817        cx: &mut App,
 818    ) -> Task<Result<AgentToolOutput>> {
 819        cx.spawn(async move |cx| {
 820            let input = serde_json::from_value(input)?;
 821            let output = cx
 822                .update(|cx| self.0.clone().run(input, event_stream, cx))?
 823                .await?;
 824            let raw_output = serde_json::to_value(&output)?;
 825            Ok(AgentToolOutput {
 826                llm_output: output.into(),
 827                raw_output,
 828            })
 829        })
 830    }
 831}
 832
/// Cloneable sending half of the channel that carries agent response events.
///
/// Events are sent as `Ok(AgentResponseEvent)`, and completion errors as `Err(..)`. Every
/// send is best-effort: if the receiver has been dropped, the event is silently discarded.
#[derive(Clone)]
struct AgentResponseEventStream(
    mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
 837
 838impl AgentResponseEventStream {
 839    fn send_text(&self, text: &str) {
 840        self.0
 841            .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
 842            .ok();
 843    }
 844
 845    fn send_thinking(&self, text: &str) {
 846        self.0
 847            .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
 848            .ok();
 849    }
 850
 851    fn send_tool_call(
 852        &self,
 853        id: &LanguageModelToolUseId,
 854        title: SharedString,
 855        kind: acp::ToolKind,
 856        input: serde_json::Value,
 857    ) {
 858        self.0
 859            .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
 860                id,
 861                title.to_string(),
 862                kind,
 863                input,
 864            ))))
 865            .ok();
 866    }
 867
 868    fn initial_tool_call(
 869        id: &LanguageModelToolUseId,
 870        title: String,
 871        kind: acp::ToolKind,
 872        input: serde_json::Value,
 873    ) -> acp::ToolCall {
 874        acp::ToolCall {
 875            id: acp::ToolCallId(id.to_string().into()),
 876            title,
 877            kind,
 878            status: acp::ToolCallStatus::Pending,
 879            content: vec![],
 880            locations: vec![],
 881            raw_input: Some(input),
 882            raw_output: None,
 883        }
 884    }
 885
 886    fn update_tool_call_fields(
 887        &self,
 888        tool_use_id: &LanguageModelToolUseId,
 889        fields: acp::ToolCallUpdateFields,
 890    ) {
 891        self.0
 892            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
 893                acp::ToolCallUpdate {
 894                    id: acp::ToolCallId(tool_use_id.to_string().into()),
 895                    fields,
 896                }
 897                .into(),
 898            )))
 899            .ok();
 900    }
 901
 902    fn send_stop(&self, reason: StopReason) {
 903        match reason {
 904            StopReason::EndTurn => {
 905                self.0
 906                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
 907                    .ok();
 908            }
 909            StopReason::MaxTokens => {
 910                self.0
 911                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
 912                    .ok();
 913            }
 914            StopReason::Refusal => {
 915                self.0
 916                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
 917                    .ok();
 918            }
 919            StopReason::ToolUse => {}
 920        }
 921    }
 922
 923    fn send_error(&self, error: LanguageModelCompletionError) {
 924        self.0.unbounded_send(Err(error)).ok();
 925    }
 926}
 927
/// Handle given to a running tool for reporting progress on its tool call and for
/// requesting the user's permission.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// Id of the originating tool use; converted into the `acp::ToolCallId` of every event.
    tool_use_id: LanguageModelToolUseId,
    /// Tool kind reported in the authorization request's tool call.
    kind: acp::ToolKind,
    /// The tool use's JSON input, echoed as `raw_input` in authorization requests.
    input: serde_json::Value,
    /// Shared channel all events are sent on.
    stream: AgentResponseEventStream,
    /// Filesystem used to persist an "Always Allow" choice; `None` skips persisting it.
    fs: Option<Arc<dyn Fs>>,
}
 936
 937impl ToolCallEventStream {
 938    #[cfg(test)]
 939    pub fn test() -> (Self, ToolCallEventStreamReceiver) {
 940        let (events_tx, events_rx) =
 941            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 942
 943        let stream = ToolCallEventStream::new(
 944            &LanguageModelToolUse {
 945                id: "test_id".into(),
 946                name: "test_tool".into(),
 947                raw_input: String::new(),
 948                input: serde_json::Value::Null,
 949                is_input_complete: true,
 950            },
 951            acp::ToolKind::Other,
 952            AgentResponseEventStream(events_tx),
 953            None,
 954        );
 955
 956        (stream, ToolCallEventStreamReceiver(events_rx))
 957    }
 958
 959    fn new(
 960        tool_use: &LanguageModelToolUse,
 961        kind: acp::ToolKind,
 962        stream: AgentResponseEventStream,
 963        fs: Option<Arc<dyn Fs>>,
 964    ) -> Self {
 965        Self {
 966            tool_use_id: tool_use.id.clone(),
 967            kind,
 968            input: tool_use.input.clone(),
 969            stream,
 970            fs,
 971        }
 972    }
 973
 974    pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
 975        self.stream
 976            .update_tool_call_fields(&self.tool_use_id, fields);
 977    }
 978
 979    pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
 980        self.stream
 981            .0
 982            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
 983                acp_thread::ToolCallUpdateDiff {
 984                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
 985                    diff,
 986                }
 987                .into(),
 988            )))
 989            .ok();
 990    }
 991
 992    pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
 993        self.stream
 994            .0
 995            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
 996                acp_thread::ToolCallUpdateTerminal {
 997                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
 998                    terminal,
 999                }
1000                .into(),
1001            )))
1002            .ok();
1003    }
1004
1005    pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1006        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1007            return Task::ready(Ok(()));
1008        }
1009
1010        let (response_tx, response_rx) = oneshot::channel();
1011        self.stream
1012            .0
1013            .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
1014                ToolCallAuthorization {
1015                    tool_call: AgentResponseEventStream::initial_tool_call(
1016                        &self.tool_use_id,
1017                        title.into(),
1018                        self.kind.clone(),
1019                        self.input.clone(),
1020                    ),
1021                    options: vec![
1022                        acp::PermissionOption {
1023                            id: acp::PermissionOptionId("always_allow".into()),
1024                            name: "Always Allow".into(),
1025                            kind: acp::PermissionOptionKind::AllowAlways,
1026                        },
1027                        acp::PermissionOption {
1028                            id: acp::PermissionOptionId("allow".into()),
1029                            name: "Allow".into(),
1030                            kind: acp::PermissionOptionKind::AllowOnce,
1031                        },
1032                        acp::PermissionOption {
1033                            id: acp::PermissionOptionId("deny".into()),
1034                            name: "Deny".into(),
1035                            kind: acp::PermissionOptionKind::RejectOnce,
1036                        },
1037                    ],
1038                    response: response_tx,
1039                },
1040            )))
1041            .ok();
1042        let fs = self.fs.clone();
1043        cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1044            "always_allow" => {
1045                if let Some(fs) = fs.clone() {
1046                    cx.update(|cx| {
1047                        update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1048                            settings.set_always_allow_tool_actions(true);
1049                        });
1050                    })?;
1051                }
1052
1053                Ok(())
1054            }
1055            "allow" => Ok(()),
1056            _ => Err(anyhow!("Permission to run tool denied by user")),
1057        })
1058    }
1059}
1060
1061#[cfg(test)]
1062pub struct ToolCallEventStreamReceiver(
1063    mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
1064);
1065
1066#[cfg(test)]
1067impl ToolCallEventStreamReceiver {
1068    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
1069        let event = self.0.next().await;
1070        if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event {
1071            auth
1072        } else {
1073            panic!("Expected ToolCallAuthorization but got: {:?}", event);
1074        }
1075    }
1076
1077    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
1078        let event = self.0.next().await;
1079        if let Some(Ok(AgentResponseEvent::ToolCallUpdate(
1080            acp_thread::ToolCallUpdate::UpdateTerminal(update),
1081        ))) = event
1082        {
1083            update.terminal
1084        } else {
1085            panic!("Expected terminal but got: {:?}", event);
1086        }
1087    }
1088}
1089
1090#[cfg(test)]
1091impl std::ops::Deref for ToolCallEventStreamReceiver {
1092    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;
1093
1094    fn deref(&self) -> &Self::Target {
1095        &self.0
1096    }
1097}
1098
1099#[cfg(test)]
1100impl std::ops::DerefMut for ToolCallEventStreamReceiver {
1101    fn deref_mut(&mut self) -> &mut Self::Target {
1102        &mut self.0
1103    }
1104}