thread.rs

   1use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
   2use acp_thread::MentionUri;
   3use action_log::ActionLog;
   4use agent_client_protocol as acp;
   5use agent_settings::{AgentProfileId, AgentSettings};
   6use anyhow::{Context as _, Result, anyhow};
   7use assistant_tool::adapt_schema_to_format;
   8use cloud_llm_client::{CompletionIntent, CompletionMode};
   9use collections::HashMap;
  10use fs::Fs;
  11use futures::{
  12    channel::{mpsc, oneshot},
  13    stream::FuturesUnordered,
  14};
  15use gpui::{App, Context, Entity, SharedString, Task};
  16use language_model::{
  17    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
  18    LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
  19    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
  20    LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
  21};
  22use log;
  23use project::Project;
  24use prompt_store::ProjectContext;
  25use schemars::{JsonSchema, Schema};
  26use serde::{Deserialize, Serialize};
  27use settings::{Settings, update_settings_file};
  28use smol::stream::StreamExt;
  29use std::fmt::Write;
  30use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
  31use util::{ResultExt, markdown::MarkdownCodeBlock};
  32
  33#[derive(Debug, Clone)]
  34pub struct AgentMessage {
  35    pub role: Role,
  36    pub content: Vec<MessageContent>,
  37}
  38
  39#[derive(Debug, Clone, PartialEq, Eq)]
  40pub enum MessageContent {
  41    Text(String),
  42    Thinking {
  43        text: String,
  44        signature: Option<String>,
  45    },
  46    Mention {
  47        uri: MentionUri,
  48        content: String,
  49    },
  50    RedactedThinking(String),
  51    Image(LanguageModelImage),
  52    ToolUse(LanguageModelToolUse),
  53    ToolResult(LanguageModelToolResult),
  54}
  55
  56impl AgentMessage {
  57    pub fn to_markdown(&self) -> String {
  58        let mut markdown = format!("## {}\n", self.role);
  59
  60        for content in &self.content {
  61            match content {
  62                MessageContent::Text(text) => {
  63                    markdown.push_str(text);
  64                    markdown.push('\n');
  65                }
  66                MessageContent::Thinking { text, .. } => {
  67                    markdown.push_str("<think>");
  68                    markdown.push_str(text);
  69                    markdown.push_str("</think>\n");
  70                }
  71                MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
  72                MessageContent::Image(_) => {
  73                    markdown.push_str("<image />\n");
  74                }
  75                MessageContent::ToolUse(tool_use) => {
  76                    markdown.push_str(&format!(
  77                        "**Tool Use**: {} (ID: {})\n",
  78                        tool_use.name, tool_use.id
  79                    ));
  80                    markdown.push_str(&format!(
  81                        "{}\n",
  82                        MarkdownCodeBlock {
  83                            tag: "json",
  84                            text: &format!("{:#}", tool_use.input)
  85                        }
  86                    ));
  87                }
  88                MessageContent::ToolResult(tool_result) => {
  89                    markdown.push_str(&format!(
  90                        "**Tool Result**: {} (ID: {})\n\n",
  91                        tool_result.tool_name, tool_result.tool_use_id
  92                    ));
  93                    if tool_result.is_error {
  94                        markdown.push_str("**ERROR:**\n");
  95                    }
  96
  97                    match &tool_result.content {
  98                        LanguageModelToolResultContent::Text(text) => {
  99                            writeln!(markdown, "{text}\n").ok();
 100                        }
 101                        LanguageModelToolResultContent::Image(_) => {
 102                            writeln!(markdown, "<image />\n").ok();
 103                        }
 104                    }
 105
 106                    if let Some(output) = tool_result.output.as_ref() {
 107                        writeln!(
 108                            markdown,
 109                            "**Debug Output**:\n\n```json\n{}\n```\n",
 110                            serde_json::to_string_pretty(output).unwrap()
 111                        )
 112                        .unwrap();
 113                    }
 114                }
 115                MessageContent::Mention { uri, .. } => {
 116                    write!(markdown, "{}", uri.to_link()).ok();
 117                }
 118            }
 119        }
 120
 121        markdown
 122    }
 123}
 124
 125#[derive(Debug)]
 126pub enum AgentResponseEvent {
 127    Text(String),
 128    Thinking(String),
 129    ToolCall(acp::ToolCall),
 130    ToolCallUpdate(acp_thread::ToolCallUpdate),
 131    ToolCallAuthorization(ToolCallAuthorization),
 132    Stop(acp::StopReason),
 133}
 134
 135#[derive(Debug)]
 136pub struct ToolCallAuthorization {
 137    pub tool_call: acp::ToolCall,
 138    pub options: Vec<acp::PermissionOption>,
 139    pub response: oneshot::Sender<acp::PermissionOptionId>,
 140}
 141
 142pub struct Thread {
 143    messages: Vec<AgentMessage>,
 144    completion_mode: CompletionMode,
 145    /// Holds the task that handles agent interaction until the end of the turn.
 146    /// Survives across multiple requests as the model performs tool calls and
 147    /// we run tools, report their results.
 148    running_turn: Option<Task<()>>,
 149    pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
 150    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 151    context_server_registry: Entity<ContextServerRegistry>,
 152    profile_id: AgentProfileId,
 153    project_context: Rc<RefCell<ProjectContext>>,
 154    templates: Arc<Templates>,
 155    pub selected_model: Arc<dyn LanguageModel>,
 156    project: Entity<Project>,
 157    action_log: Entity<ActionLog>,
 158}
 159
 160impl Thread {
 161    pub fn new(
 162        project: Entity<Project>,
 163        project_context: Rc<RefCell<ProjectContext>>,
 164        context_server_registry: Entity<ContextServerRegistry>,
 165        action_log: Entity<ActionLog>,
 166        templates: Arc<Templates>,
 167        default_model: Arc<dyn LanguageModel>,
 168        cx: &mut Context<Self>,
 169    ) -> Self {
 170        let profile_id = AgentSettings::get_global(cx).default_profile.clone();
 171        Self {
 172            messages: Vec::new(),
 173            completion_mode: CompletionMode::Normal,
 174            running_turn: None,
 175            pending_tool_uses: HashMap::default(),
 176            tools: BTreeMap::default(),
 177            context_server_registry,
 178            profile_id,
 179            project_context,
 180            templates,
 181            selected_model: default_model,
 182            project,
 183            action_log,
 184        }
 185    }
 186
 187    pub fn project(&self) -> &Entity<Project> {
 188        &self.project
 189    }
 190
 191    pub fn action_log(&self) -> &Entity<ActionLog> {
 192        &self.action_log
 193    }
 194
 195    pub fn set_mode(&mut self, mode: CompletionMode) {
 196        self.completion_mode = mode;
 197    }
 198
 199    pub fn messages(&self) -> &[AgentMessage] {
 200        &self.messages
 201    }
 202
 203    pub fn add_tool(&mut self, tool: impl AgentTool) {
 204        self.tools.insert(tool.name(), tool.erase());
 205    }
 206
 207    pub fn remove_tool(&mut self, name: &str) -> bool {
 208        self.tools.remove(name).is_some()
 209    }
 210
 211    pub fn set_profile(&mut self, profile_id: AgentProfileId) {
 212        self.profile_id = profile_id;
 213    }
 214
 215    pub fn cancel(&mut self) {
 216        self.running_turn.take();
 217
 218        let tool_results = self
 219            .pending_tool_uses
 220            .drain()
 221            .map(|(tool_use_id, tool_use)| {
 222                MessageContent::ToolResult(LanguageModelToolResult {
 223                    tool_use_id,
 224                    tool_name: tool_use.name.clone(),
 225                    is_error: true,
 226                    content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
 227                    output: None,
 228                })
 229            })
 230            .collect::<Vec<_>>();
 231        self.last_user_message().content.extend(tool_results);
 232    }
 233
 234    /// Sending a message results in the model streaming a response, which could include tool calls.
 235    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
 236    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
 237    pub fn send(
 238        &mut self,
 239        content: impl Into<UserMessage>,
 240        cx: &mut Context<Self>,
 241    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
 242        let content = content.into().0;
 243
 244        let model = self.selected_model.clone();
 245        log::info!("Thread::send called with model: {:?}", model.name());
 246        log::debug!("Thread::send content: {:?}", content);
 247
 248        cx.notify();
 249        let (events_tx, events_rx) =
 250            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 251        let event_stream = AgentResponseEventStream(events_tx);
 252
 253        let user_message_ix = self.messages.len();
 254        self.messages.push(AgentMessage {
 255            role: Role::User,
 256            content,
 257        });
 258        log::info!("Total messages in thread: {}", self.messages.len());
 259        self.running_turn = Some(cx.spawn(async move |thread, cx| {
 260            log::info!("Starting agent turn execution");
 261            let turn_result = async {
 262                // Perform one request, then keep looping if the model makes tool calls.
 263                let mut completion_intent = CompletionIntent::UserPrompt;
 264                'outer: loop {
 265                    log::debug!(
 266                        "Building completion request with intent: {:?}",
 267                        completion_intent
 268                    );
 269                    let request = thread.update(cx, |thread, cx| {
 270                        thread.build_completion_request(completion_intent, cx)
 271                    })?;
 272
 273                    // println!(
 274                    //     "request: {}",
 275                    //     serde_json::to_string_pretty(&request).unwrap()
 276                    // );
 277
 278                    // Stream events, appending to messages and collecting up tool uses.
 279                    log::info!("Calling model.stream_completion");
 280                    let mut events = model.stream_completion(request, cx).await?;
 281                    log::debug!("Stream completion started successfully");
 282                    let mut tool_uses = FuturesUnordered::new();
 283                    while let Some(event) = events.next().await {
 284                        match event {
 285                            Ok(LanguageModelCompletionEvent::Stop(reason)) => {
 286                                event_stream.send_stop(reason);
 287                                if reason == StopReason::Refusal {
 288                                    thread.update(cx, |thread, _cx| {
 289                                        thread.messages.truncate(user_message_ix);
 290                                    })?;
 291                                    break 'outer;
 292                                }
 293                            }
 294                            Ok(event) => {
 295                                log::trace!("Received completion event: {:?}", event);
 296                                thread
 297                                    .update(cx, |thread, cx| {
 298                                        tool_uses.extend(thread.handle_streamed_completion_event(
 299                                            event,
 300                                            &event_stream,
 301                                            cx,
 302                                        ));
 303                                    })
 304                                    .ok();
 305                            }
 306                            Err(error) => {
 307                                log::error!("Error in completion stream: {:?}", error);
 308                                event_stream.send_error(error);
 309                                break;
 310                            }
 311                        }
 312                    }
 313
 314                    // If there are no tool uses, the turn is done.
 315                    if tool_uses.is_empty() {
 316                        log::info!("No tool uses found, completing turn");
 317                        break;
 318                    }
 319                    log::info!("Found {} tool uses to execute", tool_uses.len());
 320
 321                    // As tool results trickle in, insert them in the last user
 322                    // message so that they can be sent on the next tick of the
 323                    // agentic loop.
 324                    while let Some(tool_result) = tool_uses.next().await {
 325                        log::info!("Tool finished {:?}", tool_result);
 326
 327                        event_stream.update_tool_call_fields(
 328                            &tool_result.tool_use_id,
 329                            acp::ToolCallUpdateFields {
 330                                status: Some(if tool_result.is_error {
 331                                    acp::ToolCallStatus::Failed
 332                                } else {
 333                                    acp::ToolCallStatus::Completed
 334                                }),
 335                                raw_output: tool_result.output.clone(),
 336                                ..Default::default()
 337                            },
 338                        );
 339                        thread
 340                            .update(cx, |thread, _cx| {
 341                                thread.pending_tool_uses.remove(&tool_result.tool_use_id);
 342                                thread
 343                                    .last_user_message()
 344                                    .content
 345                                    .push(MessageContent::ToolResult(tool_result));
 346                            })
 347                            .ok();
 348                    }
 349
 350                    completion_intent = CompletionIntent::ToolResults;
 351                }
 352
 353                Ok(())
 354            }
 355            .await;
 356
 357            if let Err(error) = turn_result {
 358                log::error!("Turn execution failed: {:?}", error);
 359                event_stream.send_error(error);
 360            } else {
 361                log::info!("Turn execution completed successfully");
 362            }
 363        }));
 364        events_rx
 365    }
 366
 367    pub fn build_system_message(&self) -> AgentMessage {
 368        log::debug!("Building system message");
 369        let prompt = SystemPromptTemplate {
 370            project: &self.project_context.borrow(),
 371            available_tools: self.tools.keys().cloned().collect(),
 372        }
 373        .render(&self.templates)
 374        .context("failed to build system prompt")
 375        .expect("Invalid template");
 376        log::debug!("System message built");
 377        AgentMessage {
 378            role: Role::System,
 379            content: vec![prompt.as_str().into()],
 380        }
 381    }
 382
 383    /// A helper method that's called on every streamed completion event.
 384    /// Returns an optional tool result task, which the main agentic loop in
 385    /// send will send back to the model when it resolves.
 386    fn handle_streamed_completion_event(
 387        &mut self,
 388        event: LanguageModelCompletionEvent,
 389        event_stream: &AgentResponseEventStream,
 390        cx: &mut Context<Self>,
 391    ) -> Option<Task<LanguageModelToolResult>> {
 392        log::trace!("Handling streamed completion event: {:?}", event);
 393        use LanguageModelCompletionEvent::*;
 394
 395        match event {
 396            StartMessage { .. } => {
 397                self.messages.push(AgentMessage {
 398                    role: Role::Assistant,
 399                    content: Vec::new(),
 400                });
 401            }
 402            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
 403            Thinking { text, signature } => {
 404                self.handle_thinking_event(text, signature, event_stream, cx)
 405            }
 406            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
 407            ToolUse(tool_use) => {
 408                return self.handle_tool_use_event(tool_use, event_stream, cx);
 409            }
 410            ToolUseJsonParseError {
 411                id,
 412                tool_name,
 413                raw_input,
 414                json_parse_error,
 415            } => {
 416                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
 417                    id,
 418                    tool_name,
 419                    raw_input,
 420                    json_parse_error,
 421                )));
 422            }
 423            UsageUpdate(_) | StatusUpdate(_) => {}
 424            Stop(_) => unreachable!(),
 425        }
 426
 427        None
 428    }
 429
 430    fn handle_text_event(
 431        &mut self,
 432        new_text: String,
 433        events_stream: &AgentResponseEventStream,
 434        cx: &mut Context<Self>,
 435    ) {
 436        events_stream.send_text(&new_text);
 437
 438        let last_message = self.last_assistant_message();
 439        if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
 440            text.push_str(&new_text);
 441        } else {
 442            last_message.content.push(MessageContent::Text(new_text));
 443        }
 444
 445        cx.notify();
 446    }
 447
 448    fn handle_thinking_event(
 449        &mut self,
 450        new_text: String,
 451        new_signature: Option<String>,
 452        event_stream: &AgentResponseEventStream,
 453        cx: &mut Context<Self>,
 454    ) {
 455        event_stream.send_thinking(&new_text);
 456
 457        let last_message = self.last_assistant_message();
 458        if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
 459        {
 460            text.push_str(&new_text);
 461            *signature = new_signature.or(signature.take());
 462        } else {
 463            last_message.content.push(MessageContent::Thinking {
 464                text: new_text,
 465                signature: new_signature,
 466            });
 467        }
 468
 469        cx.notify();
 470    }
 471
 472    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
 473        let last_message = self.last_assistant_message();
 474        last_message
 475            .content
 476            .push(MessageContent::RedactedThinking(data));
 477        cx.notify();
 478    }
 479
 480    fn handle_tool_use_event(
 481        &mut self,
 482        tool_use: LanguageModelToolUse,
 483        event_stream: &AgentResponseEventStream,
 484        cx: &mut Context<Self>,
 485    ) -> Option<Task<LanguageModelToolResult>> {
 486        cx.notify();
 487
 488        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
 489
 490        self.pending_tool_uses
 491            .insert(tool_use.id.clone(), tool_use.clone());
 492        let last_message = self.last_assistant_message();
 493
 494        // Ensure the last message ends in the current tool use
 495        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
 496            if let MessageContent::ToolUse(last_tool_use) = content {
 497                if last_tool_use.id == tool_use.id {
 498                    *last_tool_use = tool_use.clone();
 499                    false
 500                } else {
 501                    true
 502                }
 503            } else {
 504                true
 505            }
 506        });
 507
 508        let mut title = SharedString::from(&tool_use.name);
 509        let mut kind = acp::ToolKind::Other;
 510        if let Some(tool) = tool.as_ref() {
 511            title = tool.initial_title(tool_use.input.clone());
 512            kind = tool.kind();
 513        }
 514
 515        if push_new_tool_use {
 516            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
 517            last_message
 518                .content
 519                .push(MessageContent::ToolUse(tool_use.clone()));
 520        } else {
 521            event_stream.update_tool_call_fields(
 522                &tool_use.id,
 523                acp::ToolCallUpdateFields {
 524                    title: Some(title.into()),
 525                    kind: Some(kind),
 526                    raw_input: Some(tool_use.input.clone()),
 527                    ..Default::default()
 528                },
 529            );
 530        }
 531
 532        if !tool_use.is_input_complete {
 533            return None;
 534        }
 535
 536        let Some(tool) = tool else {
 537            let content = format!("No tool named {} exists", tool_use.name);
 538            return Some(Task::ready(LanguageModelToolResult {
 539                content: LanguageModelToolResultContent::Text(Arc::from(content)),
 540                tool_use_id: tool_use.id,
 541                tool_name: tool_use.name,
 542                is_error: true,
 543                output: None,
 544            }));
 545        };
 546
 547        let fs = self.project.read(cx).fs().clone();
 548        let tool_event_stream =
 549            ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
 550        tool_event_stream.update_fields(acp::ToolCallUpdateFields {
 551            status: Some(acp::ToolCallStatus::InProgress),
 552            ..Default::default()
 553        });
 554        let supports_images = self.selected_model.supports_images();
 555        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
 556        Some(cx.foreground_executor().spawn(async move {
 557            let tool_result = tool_result.await.and_then(|output| {
 558                if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
 559                    if !supports_images {
 560                        return Err(anyhow!(
 561                            "Attempted to read an image, but this model doesn't support it.",
 562                        ));
 563                    }
 564                }
 565                Ok(output)
 566            });
 567
 568            match tool_result {
 569                Ok(output) => LanguageModelToolResult {
 570                    tool_use_id: tool_use.id,
 571                    tool_name: tool_use.name,
 572                    is_error: false,
 573                    content: output.llm_output,
 574                    output: Some(output.raw_output),
 575                },
 576                Err(error) => LanguageModelToolResult {
 577                    tool_use_id: tool_use.id,
 578                    tool_name: tool_use.name,
 579                    is_error: true,
 580                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
 581                    output: None,
 582                },
 583            }
 584        }))
 585    }
 586
 587    fn handle_tool_use_json_parse_error_event(
 588        &mut self,
 589        tool_use_id: LanguageModelToolUseId,
 590        tool_name: Arc<str>,
 591        raw_input: Arc<str>,
 592        json_parse_error: String,
 593    ) -> LanguageModelToolResult {
 594        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
 595        LanguageModelToolResult {
 596            tool_use_id,
 597            tool_name,
 598            is_error: true,
 599            content: LanguageModelToolResultContent::Text(tool_output.into()),
 600            output: Some(serde_json::Value::String(raw_input.to_string())),
 601        }
 602    }
 603
 604    /// Guarantees the last message is from the assistant and returns a mutable reference.
 605    fn last_assistant_message(&mut self) -> &mut AgentMessage {
 606        if self
 607            .messages
 608            .last()
 609            .map_or(true, |m| m.role != Role::Assistant)
 610        {
 611            self.messages.push(AgentMessage {
 612                role: Role::Assistant,
 613                content: Vec::new(),
 614            });
 615        }
 616        self.messages.last_mut().unwrap()
 617    }
 618
 619    /// Guarantees the last message is from the user and returns a mutable reference.
 620    fn last_user_message(&mut self) -> &mut AgentMessage {
 621        if self.messages.last().map_or(true, |m| m.role != Role::User) {
 622            self.messages.push(AgentMessage {
 623                role: Role::User,
 624                content: Vec::new(),
 625            });
 626        }
 627        self.messages.last_mut().unwrap()
 628    }
 629
 630    pub(crate) fn build_completion_request(
 631        &self,
 632        completion_intent: CompletionIntent,
 633        cx: &mut App,
 634    ) -> LanguageModelRequest {
 635        log::debug!("Building completion request");
 636        log::debug!("Completion intent: {:?}", completion_intent);
 637        log::debug!("Completion mode: {:?}", self.completion_mode);
 638
 639        let messages = self.build_request_messages();
 640        log::info!("Request will include {} messages", messages.len());
 641
 642        let tools = if let Some(tools) = self.tools(cx).log_err() {
 643            tools
 644                .filter_map(|tool| {
 645                    let tool_name = tool.name().to_string();
 646                    log::trace!("Including tool: {}", tool_name);
 647                    Some(LanguageModelRequestTool {
 648                        name: tool_name,
 649                        description: tool.description().to_string(),
 650                        input_schema: tool
 651                            .input_schema(self.selected_model.tool_input_format())
 652                            .log_err()?,
 653                    })
 654                })
 655                .collect()
 656        } else {
 657            Vec::new()
 658        };
 659
 660        log::info!("Request includes {} tools", tools.len());
 661
 662        let request = LanguageModelRequest {
 663            thread_id: None,
 664            prompt_id: None,
 665            intent: Some(completion_intent),
 666            mode: Some(self.completion_mode),
 667            messages,
 668            tools,
 669            tool_choice: None,
 670            stop: Vec::new(),
 671            temperature: None,
 672            thinking_allowed: true,
 673        };
 674
 675        log::debug!("Completion request built successfully");
 676        request
 677    }
 678
 679    fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
 680        let profile = AgentSettings::get_global(cx)
 681            .profiles
 682            .get(&self.profile_id)
 683            .context("profile not found")?;
 684        let provider_id = self.selected_model.provider_id();
 685
 686        Ok(self
 687            .tools
 688            .iter()
 689            .filter(move |(_, tool)| tool.supported_provider(&provider_id))
 690            .filter_map(|(tool_name, tool)| {
 691                if profile.is_tool_enabled(tool_name) {
 692                    Some(tool)
 693                } else {
 694                    None
 695                }
 696            })
 697            .chain(self.context_server_registry.read(cx).servers().flat_map(
 698                |(server_id, tools)| {
 699                    tools.iter().filter_map(|(tool_name, tool)| {
 700                        if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
 701                            Some(tool)
 702                        } else {
 703                            None
 704                        }
 705                    })
 706                },
 707            )))
 708    }
 709
 710    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
 711        log::trace!(
 712            "Building request messages from {} thread messages",
 713            self.messages.len()
 714        );
 715
 716        let messages = Some(self.build_system_message())
 717            .iter()
 718            .chain(self.messages.iter())
 719            .map(|message| {
 720                log::trace!(
 721                    "  - {} message with {} content items",
 722                    match message.role {
 723                        Role::System => "System",
 724                        Role::User => "User",
 725                        Role::Assistant => "Assistant",
 726                    },
 727                    message.content.len()
 728                );
 729                message.to_request()
 730            })
 731            .collect();
 732        messages
 733    }
 734
 735    pub fn to_markdown(&self) -> String {
 736        let mut markdown = String::new();
 737        for message in &self.messages {
 738            markdown.push_str(&message.to_markdown());
 739        }
 740        markdown
 741    }
 742}
 743
 744pub struct UserMessage(Vec<MessageContent>);
 745
 746impl From<Vec<MessageContent>> for UserMessage {
 747    fn from(content: Vec<MessageContent>) -> Self {
 748        UserMessage(content)
 749    }
 750}
 751
 752impl<T: Into<MessageContent>> From<T> for UserMessage {
 753    fn from(content: T) -> Self {
 754        UserMessage(vec![content.into()])
 755    }
 756}
 757
 758pub trait AgentTool
 759where
 760    Self: 'static + Sized,
 761{
 762    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
 763    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
 764
 765    fn name(&self) -> SharedString;
 766
 767    fn description(&self) -> SharedString {
 768        let schema = schemars::schema_for!(Self::Input);
 769        SharedString::new(
 770            schema
 771                .get("description")
 772                .and_then(|description| description.as_str())
 773                .unwrap_or_default(),
 774        )
 775    }
 776
 777    fn kind(&self) -> acp::ToolKind;
 778
 779    /// The initial tool title to display. Can be updated during the tool run.
 780    fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
 781
 782    /// Returns the JSON schema that describes the tool's input.
 783    fn input_schema(&self) -> Schema {
 784        schemars::schema_for!(Self::Input)
 785    }
 786
 787    /// Some tools rely on a provider for the underlying billing or other reasons.
 788    /// Allow the tool to check if they are compatible, or should be filtered out.
 789    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
 790        true
 791    }
 792
 793    /// Runs the tool with the provided input.
 794    fn run(
 795        self: Arc<Self>,
 796        input: Self::Input,
 797        event_stream: ToolCallEventStream,
 798        cx: &mut App,
 799    ) -> Task<Result<Self::Output>>;
 800
 801    fn erase(self) -> Arc<dyn AnyAgentTool> {
 802        Arc::new(Erased(Arc::new(self)))
 803    }
 804}
 805
/// Wrapper type; `Erased<Arc<T>>` implements [`AnyAgentTool`] for any
/// `T: AgentTool`.
pub struct Erased<T>(T);
 807
/// Result produced by [`AnyAgentTool::run`].
pub struct AgentToolOutput {
    /// The output in the form of language-model tool-result content.
    pub llm_output: LanguageModelToolResultContent,
    /// The output as a raw JSON value.
    pub raw_output: serde_json::Value,
}
 812
/// Tool interface that exchanges JSON values instead of typed inputs and
/// outputs, so that tools can be used as `dyn AnyAgentTool`.
pub trait AnyAgentTool {
    /// The name of the tool.
    fn name(&self) -> SharedString;
    /// A description of the tool.
    fn description(&self) -> SharedString;
    /// The protocol-level kind of this tool.
    fn kind(&self) -> acp::ToolKind;
    /// The initial title to display for a call made with the given JSON input.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// The JSON schema of the tool's input for the requested schema `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether the tool may be used with the given provider; `true` by default.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool with a JSON input.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
}
 829
 830impl<T> AnyAgentTool for Erased<Arc<T>>
 831where
 832    T: AgentTool,
 833{
 834    fn name(&self) -> SharedString {
 835        self.0.name()
 836    }
 837
 838    fn description(&self) -> SharedString {
 839        self.0.description()
 840    }
 841
 842    fn kind(&self) -> agent_client_protocol::ToolKind {
 843        self.0.kind()
 844    }
 845
 846    fn initial_title(&self, input: serde_json::Value) -> SharedString {
 847        let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
 848        self.0.initial_title(parsed_input)
 849    }
 850
 851    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 852        let mut json = serde_json::to_value(self.0.input_schema())?;
 853        adapt_schema_to_format(&mut json, format)?;
 854        Ok(json)
 855    }
 856
 857    fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
 858        self.0.supported_provider(provider)
 859    }
 860
 861    fn run(
 862        self: Arc<Self>,
 863        input: serde_json::Value,
 864        event_stream: ToolCallEventStream,
 865        cx: &mut App,
 866    ) -> Task<Result<AgentToolOutput>> {
 867        cx.spawn(async move |cx| {
 868            let input = serde_json::from_value(input)?;
 869            let output = cx
 870                .update(|cx| self.0.clone().run(input, event_stream, cx))?
 871                .await?;
 872            let raw_output = serde_json::to_value(&output)?;
 873            Ok(AgentToolOutput {
 874                llm_output: output.into(),
 875                raw_output,
 876            })
 877        })
 878    }
 879}
 880
/// Cloneable sending half of an unbounded channel whose items are either an
/// [`AgentResponseEvent`] or a [`LanguageModelCompletionError`].
#[derive(Clone)]
struct AgentResponseEventStream(
    mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
 885
 886impl AgentResponseEventStream {
 887    fn send_text(&self, text: &str) {
 888        self.0
 889            .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
 890            .ok();
 891    }
 892
 893    fn send_thinking(&self, text: &str) {
 894        self.0
 895            .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
 896            .ok();
 897    }
 898
 899    fn send_tool_call(
 900        &self,
 901        id: &LanguageModelToolUseId,
 902        title: SharedString,
 903        kind: acp::ToolKind,
 904        input: serde_json::Value,
 905    ) {
 906        self.0
 907            .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
 908                id,
 909                title.to_string(),
 910                kind,
 911                input,
 912            ))))
 913            .ok();
 914    }
 915
 916    fn initial_tool_call(
 917        id: &LanguageModelToolUseId,
 918        title: String,
 919        kind: acp::ToolKind,
 920        input: serde_json::Value,
 921    ) -> acp::ToolCall {
 922        acp::ToolCall {
 923            id: acp::ToolCallId(id.to_string().into()),
 924            title,
 925            kind,
 926            status: acp::ToolCallStatus::Pending,
 927            content: vec![],
 928            locations: vec![],
 929            raw_input: Some(input),
 930            raw_output: None,
 931        }
 932    }
 933
 934    fn update_tool_call_fields(
 935        &self,
 936        tool_use_id: &LanguageModelToolUseId,
 937        fields: acp::ToolCallUpdateFields,
 938    ) {
 939        self.0
 940            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
 941                acp::ToolCallUpdate {
 942                    id: acp::ToolCallId(tool_use_id.to_string().into()),
 943                    fields,
 944                }
 945                .into(),
 946            )))
 947            .ok();
 948    }
 949
 950    fn send_stop(&self, reason: StopReason) {
 951        match reason {
 952            StopReason::EndTurn => {
 953                self.0
 954                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
 955                    .ok();
 956            }
 957            StopReason::MaxTokens => {
 958                self.0
 959                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
 960                    .ok();
 961            }
 962            StopReason::Refusal => {
 963                self.0
 964                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
 965                    .ok();
 966            }
 967            StopReason::ToolUse => {}
 968        }
 969    }
 970
 971    fn send_error(&self, error: LanguageModelCompletionError) {
 972        self.0.unbounded_send(Err(error)).ok();
 973    }
 974}
 975
/// Event stream handle bound to a single tool use.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// Id of the tool use; used as the tool-call id of every emitted update.
    tool_use_id: LanguageModelToolUseId,
    /// Kind reported in the authorization request.
    kind: acp::ToolKind,
    /// Input of the tool use; reported as the raw input in the authorization request.
    input: serde_json::Value,
    /// Underlying stream the events are sent on.
    stream: AgentResponseEventStream,
    /// When present, used to write the "always allow" choice to the settings file.
    fs: Option<Arc<dyn Fs>>,
}
 984
 985impl ToolCallEventStream {
 986    #[cfg(test)]
 987    pub fn test() -> (Self, ToolCallEventStreamReceiver) {
 988        let (events_tx, events_rx) =
 989            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 990
 991        let stream = ToolCallEventStream::new(
 992            &LanguageModelToolUse {
 993                id: "test_id".into(),
 994                name: "test_tool".into(),
 995                raw_input: String::new(),
 996                input: serde_json::Value::Null,
 997                is_input_complete: true,
 998            },
 999            acp::ToolKind::Other,
1000            AgentResponseEventStream(events_tx),
1001            None,
1002        );
1003
1004        (stream, ToolCallEventStreamReceiver(events_rx))
1005    }
1006
1007    fn new(
1008        tool_use: &LanguageModelToolUse,
1009        kind: acp::ToolKind,
1010        stream: AgentResponseEventStream,
1011        fs: Option<Arc<dyn Fs>>,
1012    ) -> Self {
1013        Self {
1014            tool_use_id: tool_use.id.clone(),
1015            kind,
1016            input: tool_use.input.clone(),
1017            stream,
1018            fs,
1019        }
1020    }
1021
1022    pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1023        self.stream
1024            .update_tool_call_fields(&self.tool_use_id, fields);
1025    }
1026
1027    pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1028        self.stream
1029            .0
1030            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1031                acp_thread::ToolCallUpdateDiff {
1032                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1033                    diff,
1034                }
1035                .into(),
1036            )))
1037            .ok();
1038    }
1039
1040    pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1041        self.stream
1042            .0
1043            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1044                acp_thread::ToolCallUpdateTerminal {
1045                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1046                    terminal,
1047                }
1048                .into(),
1049            )))
1050            .ok();
1051    }
1052
1053    pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1054        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1055            return Task::ready(Ok(()));
1056        }
1057
1058        let (response_tx, response_rx) = oneshot::channel();
1059        self.stream
1060            .0
1061            .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
1062                ToolCallAuthorization {
1063                    tool_call: AgentResponseEventStream::initial_tool_call(
1064                        &self.tool_use_id,
1065                        title.into(),
1066                        self.kind.clone(),
1067                        self.input.clone(),
1068                    ),
1069                    options: vec![
1070                        acp::PermissionOption {
1071                            id: acp::PermissionOptionId("always_allow".into()),
1072                            name: "Always Allow".into(),
1073                            kind: acp::PermissionOptionKind::AllowAlways,
1074                        },
1075                        acp::PermissionOption {
1076                            id: acp::PermissionOptionId("allow".into()),
1077                            name: "Allow".into(),
1078                            kind: acp::PermissionOptionKind::AllowOnce,
1079                        },
1080                        acp::PermissionOption {
1081                            id: acp::PermissionOptionId("deny".into()),
1082                            name: "Deny".into(),
1083                            kind: acp::PermissionOptionKind::RejectOnce,
1084                        },
1085                    ],
1086                    response: response_tx,
1087                },
1088            )))
1089            .ok();
1090        let fs = self.fs.clone();
1091        cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1092            "always_allow" => {
1093                if let Some(fs) = fs.clone() {
1094                    cx.update(|cx| {
1095                        update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1096                            settings.set_always_allow_tool_actions(true);
1097                        });
1098                    })?;
1099                }
1100
1101                Ok(())
1102            }
1103            "allow" => Ok(()),
1104            _ => Err(anyhow!("Permission to run tool denied by user")),
1105        })
1106    }
1107}
1108
/// Test-only wrapper around the receiving half of the event channel, as
/// returned by `ToolCallEventStream::test`.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(
    mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
1113
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Waits for the next event and returns it if it is an authorization
    /// request.
    ///
    /// # Panics
    ///
    /// Panics if the next item is anything else, including the end of the
    /// stream.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) => auth,
            event => panic!("Expected ToolCallAuthorization but got: {:?}", event),
        }
    }

    /// Waits for the next event and returns its terminal if it is a terminal
    /// update.
    ///
    /// # Panics
    ///
    /// Panics if the next item is anything else, including the end of the
    /// stream.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            event => panic!("Expected terminal but got: {:?}", event),
        }
    }
}
1137
/// Gives shared access to the wrapped receiver.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;

    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1146
/// Gives mutable access to the wrapped receiver.
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1153
1154impl AgentMessage {
1155    fn to_request(&self) -> language_model::LanguageModelRequestMessage {
1156        let mut message = LanguageModelRequestMessage {
1157            role: self.role,
1158            content: Vec::with_capacity(self.content.len()),
1159            cache: false,
1160        };
1161
1162        const OPEN_CONTEXT: &str = "<context>\n\
1163            The following items were attached by the user. \
1164            They are up-to-date and don't need to be re-read.\n\n";
1165
1166        const OPEN_FILES_TAG: &str = "<files>";
1167        const OPEN_SYMBOLS_TAG: &str = "<symbols>";
1168        const OPEN_THREADS_TAG: &str = "<threads>";
1169        const OPEN_RULES_TAG: &str =
1170            "<rules>\nThe user has specified the following rules that should be applied:\n";
1171
1172        let mut file_context = OPEN_FILES_TAG.to_string();
1173        let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
1174        let mut thread_context = OPEN_THREADS_TAG.to_string();
1175        let mut rules_context = OPEN_RULES_TAG.to_string();
1176
1177        for chunk in &self.content {
1178            let chunk = match chunk {
1179                MessageContent::Text(text) => language_model::MessageContent::Text(text.clone()),
1180                MessageContent::Thinking { text, signature } => {
1181                    language_model::MessageContent::Thinking {
1182                        text: text.clone(),
1183                        signature: signature.clone(),
1184                    }
1185                }
1186                MessageContent::RedactedThinking(value) => {
1187                    language_model::MessageContent::RedactedThinking(value.clone())
1188                }
1189                MessageContent::ToolUse(value) => {
1190                    language_model::MessageContent::ToolUse(value.clone())
1191                }
1192                MessageContent::ToolResult(value) => {
1193                    language_model::MessageContent::ToolResult(value.clone())
1194                }
1195                MessageContent::Image(value) => {
1196                    language_model::MessageContent::Image(value.clone())
1197                }
1198                MessageContent::Mention { uri, content } => {
1199                    match uri {
1200                        MentionUri::File(path) | MentionUri::Symbol(path, _) => {
1201                            write!(
1202                                &mut symbol_context,
1203                                "\n{}",
1204                                MarkdownCodeBlock {
1205                                    tag: &codeblock_tag(&path),
1206                                    text: &content.to_string(),
1207                                }
1208                            )
1209                            .ok();
1210                        }
1211                        MentionUri::Thread(_session_id) => {
1212                            write!(&mut thread_context, "\n{}\n", content).ok();
1213                        }
1214                        MentionUri::Rule(_user_prompt_id) => {
1215                            write!(
1216                                &mut rules_context,
1217                                "\n{}",
1218                                MarkdownCodeBlock {
1219                                    tag: "",
1220                                    text: &content
1221                                }
1222                            )
1223                            .ok();
1224                        }
1225                    }
1226
1227                    language_model::MessageContent::Text(uri.to_link())
1228                }
1229            };
1230
1231            message.content.push(chunk);
1232        }
1233
1234        let len_before_context = message.content.len();
1235
1236        if file_context.len() > OPEN_FILES_TAG.len() {
1237            file_context.push_str("</files>\n");
1238            message
1239                .content
1240                .push(language_model::MessageContent::Text(file_context));
1241        }
1242
1243        if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
1244            symbol_context.push_str("</symbols>\n");
1245            message
1246                .content
1247                .push(language_model::MessageContent::Text(symbol_context));
1248        }
1249
1250        if thread_context.len() > OPEN_THREADS_TAG.len() {
1251            thread_context.push_str("</threads>\n");
1252            message
1253                .content
1254                .push(language_model::MessageContent::Text(thread_context));
1255        }
1256
1257        if rules_context.len() > OPEN_RULES_TAG.len() {
1258            rules_context.push_str("</user_rules>\n");
1259            message
1260                .content
1261                .push(language_model::MessageContent::Text(rules_context));
1262        }
1263
1264        if message.content.len() > len_before_context {
1265            message.content.insert(
1266                len_before_context,
1267                language_model::MessageContent::Text(OPEN_CONTEXT.into()),
1268            );
1269            message
1270                .content
1271                .push(language_model::MessageContent::Text("</context>".into()));
1272        }
1273
1274        message
1275    }
1276}
1277
/// Builds the tag for a code block showing `full_path`: the file extension (if
/// there is one and it is valid UTF-8), a space, then the displayed path.
/// Without such an extension, the tag is just the displayed path.
fn codeblock_tag(full_path: &Path) -> String {
    let shown_path = full_path.display();
    match full_path.extension().and_then(|ext| ext.to_str()) {
        Some(extension) => format!("{extension} {shown_path}"),
        None => shown_path.to_string(),
    }
}
1289
1290impl From<acp::ContentBlock> for MessageContent {
1291    fn from(value: acp::ContentBlock) -> Self {
1292        match value {
1293            acp::ContentBlock::Text(text_content) => MessageContent::Text(text_content.text),
1294            acp::ContentBlock::Image(image_content) => {
1295                MessageContent::Image(convert_image(image_content))
1296            }
1297            acp::ContentBlock::Audio(_) => {
1298                // TODO
1299                MessageContent::Text("[audio]".to_string())
1300            }
1301            acp::ContentBlock::ResourceLink(resource_link) => {
1302                match MentionUri::parse(&resource_link.uri) {
1303                    Ok(uri) => Self::Mention {
1304                        uri,
1305                        content: String::new(),
1306                    },
1307                    Err(err) => {
1308                        log::error!("Failed to parse mention link: {}", err);
1309                        MessageContent::Text(format!(
1310                            "[{}]({})",
1311                            resource_link.name, resource_link.uri
1312                        ))
1313                    }
1314                }
1315            }
1316            acp::ContentBlock::Resource(resource) => match resource.resource {
1317                acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1318                    match MentionUri::parse(&resource.uri) {
1319                        Ok(uri) => Self::Mention {
1320                            uri,
1321                            content: resource.text,
1322                        },
1323                        Err(err) => {
1324                            log::error!("Failed to parse mention link: {}", err);
1325                            MessageContent::Text(
1326                                MarkdownCodeBlock {
1327                                    tag: &resource.uri,
1328                                    text: &resource.text,
1329                                }
1330                                .to_string(),
1331                            )
1332                        }
1333                    }
1334                }
1335                acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1336                    // TODO
1337                    MessageContent::Text("[blob]".to_string())
1338                }
1339            },
1340        }
1341    }
1342}
1343
1344fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1345    LanguageModelImage {
1346        source: image_content.data.into(),
1347        // TODO: make this optional?
1348        size: gpui::Size::new(0.into(), 0.into()),
1349    }
1350}
1351
1352impl From<&str> for MessageContent {
1353    fn from(text: &str) -> Self {
1354        MessageContent::Text(text.into())
1355    }
1356}