thread.rs

   1use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
   2use acp_thread::MentionUri;
   3use action_log::ActionLog;
   4use agent_client_protocol as acp;
   5use agent_settings::{AgentProfileId, AgentSettings};
   6use anyhow::{Context as _, Result, anyhow};
   7use assistant_tool::adapt_schema_to_format;
   8use cloud_llm_client::{CompletionIntent, CompletionMode};
   9use collections::HashMap;
  10use fs::Fs;
  11use futures::{
  12    channel::{mpsc, oneshot},
  13    stream::FuturesUnordered,
  14};
  15use gpui::{App, Context, Entity, SharedString, Task};
  16use language_model::{
  17    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
  18    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  19    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
  20    LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
  21};
  22use log;
  23use project::Project;
  24use prompt_store::ProjectContext;
  25use schemars::{JsonSchema, Schema};
  26use serde::{Deserialize, Serialize};
  27use settings::{Settings, update_settings_file};
  28use smol::stream::StreamExt;
  29use std::fmt::Write;
  30use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
  31use util::{ResultExt, markdown::MarkdownCodeBlock};
  32
  33#[derive(Debug, Clone)]
  34pub struct AgentMessage {
  35    pub role: Role,
  36    pub content: Vec<MessageContent>,
  37}
  38
  39#[derive(Debug, Clone, PartialEq, Eq)]
  40pub enum MessageContent {
  41    Text(String),
  42    Thinking {
  43        text: String,
  44        signature: Option<String>,
  45    },
  46    Mention {
  47        uri: MentionUri,
  48        content: String,
  49    },
  50    RedactedThinking(String),
  51    Image(LanguageModelImage),
  52    ToolUse(LanguageModelToolUse),
  53    ToolResult(LanguageModelToolResult),
  54}
  55
  56impl AgentMessage {
  57    pub fn to_markdown(&self) -> String {
  58        let mut markdown = format!("## {}\n", self.role);
  59
  60        for content in &self.content {
  61            match content {
  62                MessageContent::Text(text) => {
  63                    markdown.push_str(text);
  64                    markdown.push('\n');
  65                }
  66                MessageContent::Thinking { text, .. } => {
  67                    markdown.push_str("<think>");
  68                    markdown.push_str(text);
  69                    markdown.push_str("</think>\n");
  70                }
  71                MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
  72                MessageContent::Image(_) => {
  73                    markdown.push_str("<image />\n");
  74                }
  75                MessageContent::ToolUse(tool_use) => {
  76                    markdown.push_str(&format!(
  77                        "**Tool Use**: {} (ID: {})\n",
  78                        tool_use.name, tool_use.id
  79                    ));
  80                    markdown.push_str(&format!(
  81                        "{}\n",
  82                        MarkdownCodeBlock {
  83                            tag: "json",
  84                            text: &format!("{:#}", tool_use.input)
  85                        }
  86                    ));
  87                }
  88                MessageContent::ToolResult(tool_result) => {
  89                    markdown.push_str(&format!(
  90                        "**Tool Result**: {} (ID: {})\n\n",
  91                        tool_result.tool_name, tool_result.tool_use_id
  92                    ));
  93                    if tool_result.is_error {
  94                        markdown.push_str("**ERROR:**\n");
  95                    }
  96
  97                    match &tool_result.content {
  98                        LanguageModelToolResultContent::Text(text) => {
  99                            writeln!(markdown, "{text}\n").ok();
 100                        }
 101                        LanguageModelToolResultContent::Image(_) => {
 102                            writeln!(markdown, "<image />\n").ok();
 103                        }
 104                    }
 105
 106                    if let Some(output) = tool_result.output.as_ref() {
 107                        writeln!(
 108                            markdown,
 109                            "**Debug Output**:\n\n```json\n{}\n```\n",
 110                            serde_json::to_string_pretty(output).unwrap()
 111                        )
 112                        .unwrap();
 113                    }
 114                }
 115                MessageContent::Mention { uri, .. } => {
 116                    write!(markdown, "{}", uri.to_link()).ok();
 117                }
 118            }
 119        }
 120
 121        markdown
 122    }
 123}
 124
 125#[derive(Debug)]
 126pub enum AgentResponseEvent {
 127    Text(String),
 128    Thinking(String),
 129    ToolCall(acp::ToolCall),
 130    ToolCallUpdate(acp_thread::ToolCallUpdate),
 131    ToolCallAuthorization(ToolCallAuthorization),
 132    Stop(acp::StopReason),
 133}
 134
 135#[derive(Debug)]
 136pub struct ToolCallAuthorization {
 137    pub tool_call: acp::ToolCall,
 138    pub options: Vec<acp::PermissionOption>,
 139    pub response: oneshot::Sender<acp::PermissionOptionId>,
 140}
 141
 142pub struct Thread {
 143    messages: Vec<AgentMessage>,
 144    completion_mode: CompletionMode,
 145    /// Holds the task that handles agent interaction until the end of the turn.
 146    /// Survives across multiple requests as the model performs tool calls and
 147    /// we run tools, report their results.
 148    running_turn: Option<Task<()>>,
 149    pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
 150    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 151    context_server_registry: Entity<ContextServerRegistry>,
 152    profile_id: AgentProfileId,
 153    project_context: Rc<RefCell<ProjectContext>>,
 154    templates: Arc<Templates>,
 155    pub selected_model: Arc<dyn LanguageModel>,
 156    project: Entity<Project>,
 157    action_log: Entity<ActionLog>,
 158}
 159
 160impl Thread {
 161    pub fn new(
 162        project: Entity<Project>,
 163        project_context: Rc<RefCell<ProjectContext>>,
 164        context_server_registry: Entity<ContextServerRegistry>,
 165        action_log: Entity<ActionLog>,
 166        templates: Arc<Templates>,
 167        default_model: Arc<dyn LanguageModel>,
 168        cx: &mut Context<Self>,
 169    ) -> Self {
 170        let profile_id = AgentSettings::get_global(cx).default_profile.clone();
 171        Self {
 172            messages: Vec::new(),
 173            completion_mode: CompletionMode::Normal,
 174            running_turn: None,
 175            pending_tool_uses: HashMap::default(),
 176            tools: BTreeMap::default(),
 177            context_server_registry,
 178            profile_id,
 179            project_context,
 180            templates,
 181            selected_model: default_model,
 182            project,
 183            action_log,
 184        }
 185    }
 186
 187    pub fn project(&self) -> &Entity<Project> {
 188        &self.project
 189    }
 190
 191    pub fn action_log(&self) -> &Entity<ActionLog> {
 192        &self.action_log
 193    }
 194
 195    pub fn set_mode(&mut self, mode: CompletionMode) {
 196        self.completion_mode = mode;
 197    }
 198
 199    pub fn messages(&self) -> &[AgentMessage] {
 200        &self.messages
 201    }
 202
 203    pub fn add_tool(&mut self, tool: impl AgentTool) {
 204        self.tools.insert(tool.name(), tool.erase());
 205    }
 206
 207    pub fn remove_tool(&mut self, name: &str) -> bool {
 208        self.tools.remove(name).is_some()
 209    }
 210
 211    pub fn set_profile(&mut self, profile_id: AgentProfileId) {
 212        self.profile_id = profile_id;
 213    }
 214
 215    pub fn cancel(&mut self) {
 216        self.running_turn.take();
 217
 218        let tool_results = self
 219            .pending_tool_uses
 220            .drain()
 221            .map(|(tool_use_id, tool_use)| {
 222                MessageContent::ToolResult(LanguageModelToolResult {
 223                    tool_use_id,
 224                    tool_name: tool_use.name.clone(),
 225                    is_error: true,
 226                    content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
 227                    output: None,
 228                })
 229            })
 230            .collect::<Vec<_>>();
 231        self.last_user_message().content.extend(tool_results);
 232    }
 233
 234    /// Sending a message results in the model streaming a response, which could include tool calls.
 235    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
 236    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
 237    pub fn send(
 238        &mut self,
 239        content: impl Into<UserMessage>,
 240        cx: &mut Context<Self>,
 241    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
 242        let content = content.into().0;
 243
 244        let model = self.selected_model.clone();
 245        log::info!("Thread::send called with model: {:?}", model.name());
 246        log::debug!("Thread::send content: {:?}", content);
 247
 248        cx.notify();
 249        let (events_tx, events_rx) =
 250            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 251        let event_stream = AgentResponseEventStream(events_tx);
 252
 253        let user_message_ix = self.messages.len();
 254        self.messages.push(AgentMessage {
 255            role: Role::User,
 256            content,
 257        });
 258        log::info!("Total messages in thread: {}", self.messages.len());
 259        self.running_turn = Some(cx.spawn(async move |thread, cx| {
 260            log::info!("Starting agent turn execution");
 261            let turn_result = async {
 262                // Perform one request, then keep looping if the model makes tool calls.
 263                let mut completion_intent = CompletionIntent::UserPrompt;
 264                'outer: loop {
 265                    log::debug!(
 266                        "Building completion request with intent: {:?}",
 267                        completion_intent
 268                    );
 269                    let request = thread.update(cx, |thread, cx| {
 270                        thread.build_completion_request(completion_intent, cx)
 271                    })?;
 272
 273                    // println!(
 274                    //     "request: {}",
 275                    //     serde_json::to_string_pretty(&request).unwrap()
 276                    // );
 277
 278                    // Stream events, appending to messages and collecting up tool uses.
 279                    log::info!("Calling model.stream_completion");
 280                    let mut events = model.stream_completion(request, cx).await?;
 281                    log::debug!("Stream completion started successfully");
 282                    let mut tool_uses = FuturesUnordered::new();
 283                    while let Some(event) = events.next().await {
 284                        match event {
 285                            Ok(LanguageModelCompletionEvent::Stop(reason)) => {
 286                                event_stream.send_stop(reason);
 287                                if reason == StopReason::Refusal {
 288                                    thread.update(cx, |thread, _cx| {
 289                                        thread.messages.truncate(user_message_ix);
 290                                    })?;
 291                                    break 'outer;
 292                                }
 293                            }
 294                            Ok(event) => {
 295                                log::trace!("Received completion event: {:?}", event);
 296                                thread
 297                                    .update(cx, |thread, cx| {
 298                                        tool_uses.extend(thread.handle_streamed_completion_event(
 299                                            event,
 300                                            &event_stream,
 301                                            cx,
 302                                        ));
 303                                    })
 304                                    .ok();
 305                            }
 306                            Err(error) => {
 307                                log::error!("Error in completion stream: {:?}", error);
 308                                event_stream.send_error(error);
 309                                break;
 310                            }
 311                        }
 312                    }
 313
 314                    // If there are no tool uses, the turn is done.
 315                    if tool_uses.is_empty() {
 316                        log::info!("No tool uses found, completing turn");
 317                        break;
 318                    }
 319                    log::info!("Found {} tool uses to execute", tool_uses.len());
 320
 321                    // As tool results trickle in, insert them in the last user
 322                    // message so that they can be sent on the next tick of the
 323                    // agentic loop.
 324                    while let Some(tool_result) = tool_uses.next().await {
 325                        log::info!("Tool finished {:?}", tool_result);
 326
 327                        event_stream.update_tool_call_fields(
 328                            &tool_result.tool_use_id,
 329                            acp::ToolCallUpdateFields {
 330                                status: Some(if tool_result.is_error {
 331                                    acp::ToolCallStatus::Failed
 332                                } else {
 333                                    acp::ToolCallStatus::Completed
 334                                }),
 335                                raw_output: tool_result.output.clone(),
 336                                ..Default::default()
 337                            },
 338                        );
 339                        thread
 340                            .update(cx, |thread, _cx| {
 341                                thread.pending_tool_uses.remove(&tool_result.tool_use_id);
 342                                thread
 343                                    .last_user_message()
 344                                    .content
 345                                    .push(MessageContent::ToolResult(tool_result));
 346                            })
 347                            .ok();
 348                    }
 349
 350                    completion_intent = CompletionIntent::ToolResults;
 351                }
 352
 353                Ok(())
 354            }
 355            .await;
 356
 357            if let Err(error) = turn_result {
 358                log::error!("Turn execution failed: {:?}", error);
 359                event_stream.send_error(error);
 360            } else {
 361                log::info!("Turn execution completed successfully");
 362            }
 363        }));
 364        events_rx
 365    }
 366
 367    pub fn build_system_message(&self) -> AgentMessage {
 368        log::debug!("Building system message");
 369        let prompt = SystemPromptTemplate {
 370            project: &self.project_context.borrow(),
 371            available_tools: self.tools.keys().cloned().collect(),
 372        }
 373        .render(&self.templates)
 374        .context("failed to build system prompt")
 375        .expect("Invalid template");
 376        log::debug!("System message built");
 377        AgentMessage {
 378            role: Role::System,
 379            content: vec![prompt.as_str().into()],
 380        }
 381    }
 382
 383    /// A helper method that's called on every streamed completion event.
 384    /// Returns an optional tool result task, which the main agentic loop in
 385    /// send will send back to the model when it resolves.
 386    fn handle_streamed_completion_event(
 387        &mut self,
 388        event: LanguageModelCompletionEvent,
 389        event_stream: &AgentResponseEventStream,
 390        cx: &mut Context<Self>,
 391    ) -> Option<Task<LanguageModelToolResult>> {
 392        log::trace!("Handling streamed completion event: {:?}", event);
 393        use LanguageModelCompletionEvent::*;
 394
 395        match event {
 396            StartMessage { .. } => {
 397                self.messages.push(AgentMessage {
 398                    role: Role::Assistant,
 399                    content: Vec::new(),
 400                });
 401            }
 402            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
 403            Thinking { text, signature } => {
 404                self.handle_thinking_event(text, signature, event_stream, cx)
 405            }
 406            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
 407            ToolUse(tool_use) => {
 408                return self.handle_tool_use_event(tool_use, event_stream, cx);
 409            }
 410            ToolUseJsonParseError {
 411                id,
 412                tool_name,
 413                raw_input,
 414                json_parse_error,
 415            } => {
 416                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
 417                    id,
 418                    tool_name,
 419                    raw_input,
 420                    json_parse_error,
 421                )));
 422            }
 423            UsageUpdate(_) | StatusUpdate(_) => {}
 424            Stop(_) => unreachable!(),
 425        }
 426
 427        None
 428    }
 429
 430    fn handle_text_event(
 431        &mut self,
 432        new_text: String,
 433        events_stream: &AgentResponseEventStream,
 434        cx: &mut Context<Self>,
 435    ) {
 436        events_stream.send_text(&new_text);
 437
 438        let last_message = self.last_assistant_message();
 439        if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
 440            text.push_str(&new_text);
 441        } else {
 442            last_message.content.push(MessageContent::Text(new_text));
 443        }
 444
 445        cx.notify();
 446    }
 447
 448    fn handle_thinking_event(
 449        &mut self,
 450        new_text: String,
 451        new_signature: Option<String>,
 452        event_stream: &AgentResponseEventStream,
 453        cx: &mut Context<Self>,
 454    ) {
 455        event_stream.send_thinking(&new_text);
 456
 457        let last_message = self.last_assistant_message();
 458        if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
 459        {
 460            text.push_str(&new_text);
 461            *signature = new_signature.or(signature.take());
 462        } else {
 463            last_message.content.push(MessageContent::Thinking {
 464                text: new_text,
 465                signature: new_signature,
 466            });
 467        }
 468
 469        cx.notify();
 470    }
 471
 472    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
 473        let last_message = self.last_assistant_message();
 474        last_message
 475            .content
 476            .push(MessageContent::RedactedThinking(data));
 477        cx.notify();
 478    }
 479
 480    fn handle_tool_use_event(
 481        &mut self,
 482        tool_use: LanguageModelToolUse,
 483        event_stream: &AgentResponseEventStream,
 484        cx: &mut Context<Self>,
 485    ) -> Option<Task<LanguageModelToolResult>> {
 486        cx.notify();
 487
 488        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
 489
 490        self.pending_tool_uses
 491            .insert(tool_use.id.clone(), tool_use.clone());
 492        let last_message = self.last_assistant_message();
 493
 494        // Ensure the last message ends in the current tool use
 495        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
 496            if let MessageContent::ToolUse(last_tool_use) = content {
 497                if last_tool_use.id == tool_use.id {
 498                    *last_tool_use = tool_use.clone();
 499                    false
 500                } else {
 501                    true
 502                }
 503            } else {
 504                true
 505            }
 506        });
 507
 508        let mut title = SharedString::from(&tool_use.name);
 509        let mut kind = acp::ToolKind::Other;
 510        if let Some(tool) = tool.as_ref() {
 511            title = tool.initial_title(tool_use.input.clone());
 512            kind = tool.kind();
 513        }
 514
 515        if push_new_tool_use {
 516            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
 517            last_message
 518                .content
 519                .push(MessageContent::ToolUse(tool_use.clone()));
 520        } else {
 521            event_stream.update_tool_call_fields(
 522                &tool_use.id,
 523                acp::ToolCallUpdateFields {
 524                    title: Some(title.into()),
 525                    kind: Some(kind),
 526                    raw_input: Some(tool_use.input.clone()),
 527                    ..Default::default()
 528                },
 529            );
 530        }
 531
 532        if !tool_use.is_input_complete {
 533            return None;
 534        }
 535
 536        let Some(tool) = tool else {
 537            let content = format!("No tool named {} exists", tool_use.name);
 538            return Some(Task::ready(LanguageModelToolResult {
 539                content: LanguageModelToolResultContent::Text(Arc::from(content)),
 540                tool_use_id: tool_use.id,
 541                tool_name: tool_use.name,
 542                is_error: true,
 543                output: None,
 544            }));
 545        };
 546
 547        let fs = self.project.read(cx).fs().clone();
 548        let tool_event_stream =
 549            ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
 550        tool_event_stream.update_fields(acp::ToolCallUpdateFields {
 551            status: Some(acp::ToolCallStatus::InProgress),
 552            ..Default::default()
 553        });
 554        let supports_images = self.selected_model.supports_images();
 555        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
 556        Some(cx.foreground_executor().spawn(async move {
 557            let tool_result = tool_result.await.and_then(|output| {
 558                if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
 559                    if !supports_images {
 560                        return Err(anyhow!(
 561                            "Attempted to read an image, but this model doesn't support it.",
 562                        ));
 563                    }
 564                }
 565                Ok(output)
 566            });
 567
 568            match tool_result {
 569                Ok(output) => LanguageModelToolResult {
 570                    tool_use_id: tool_use.id,
 571                    tool_name: tool_use.name,
 572                    is_error: false,
 573                    content: output.llm_output,
 574                    output: Some(output.raw_output),
 575                },
 576                Err(error) => LanguageModelToolResult {
 577                    tool_use_id: tool_use.id,
 578                    tool_name: tool_use.name,
 579                    is_error: true,
 580                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
 581                    output: None,
 582                },
 583            }
 584        }))
 585    }
 586
 587    fn handle_tool_use_json_parse_error_event(
 588        &mut self,
 589        tool_use_id: LanguageModelToolUseId,
 590        tool_name: Arc<str>,
 591        raw_input: Arc<str>,
 592        json_parse_error: String,
 593    ) -> LanguageModelToolResult {
 594        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
 595        LanguageModelToolResult {
 596            tool_use_id,
 597            tool_name,
 598            is_error: true,
 599            content: LanguageModelToolResultContent::Text(tool_output.into()),
 600            output: Some(serde_json::Value::String(raw_input.to_string())),
 601        }
 602    }
 603
 604    /// Guarantees the last message is from the assistant and returns a mutable reference.
 605    fn last_assistant_message(&mut self) -> &mut AgentMessage {
 606        if self
 607            .messages
 608            .last()
 609            .map_or(true, |m| m.role != Role::Assistant)
 610        {
 611            self.messages.push(AgentMessage {
 612                role: Role::Assistant,
 613                content: Vec::new(),
 614            });
 615        }
 616        self.messages.last_mut().unwrap()
 617    }
 618
 619    /// Guarantees the last message is from the user and returns a mutable reference.
 620    fn last_user_message(&mut self) -> &mut AgentMessage {
 621        if self.messages.last().map_or(true, |m| m.role != Role::User) {
 622            self.messages.push(AgentMessage {
 623                role: Role::User,
 624                content: Vec::new(),
 625            });
 626        }
 627        self.messages.last_mut().unwrap()
 628    }
 629
 630    pub(crate) fn build_completion_request(
 631        &self,
 632        completion_intent: CompletionIntent,
 633        cx: &mut App,
 634    ) -> LanguageModelRequest {
 635        log::debug!("Building completion request");
 636        log::debug!("Completion intent: {:?}", completion_intent);
 637        log::debug!("Completion mode: {:?}", self.completion_mode);
 638
 639        let messages = self.build_request_messages();
 640        log::info!("Request will include {} messages", messages.len());
 641
 642        let tools = if let Some(tools) = self.tools(cx).log_err() {
 643            tools
 644                .filter_map(|tool| {
 645                    let tool_name = tool.name().to_string();
 646                    log::trace!("Including tool: {}", tool_name);
 647                    Some(LanguageModelRequestTool {
 648                        name: tool_name,
 649                        description: tool.description().to_string(),
 650                        input_schema: tool
 651                            .input_schema(self.selected_model.tool_input_format())
 652                            .log_err()?,
 653                    })
 654                })
 655                .collect()
 656        } else {
 657            Vec::new()
 658        };
 659
 660        log::info!("Request includes {} tools", tools.len());
 661
 662        let request = LanguageModelRequest {
 663            thread_id: None,
 664            prompt_id: None,
 665            intent: Some(completion_intent),
 666            mode: Some(self.completion_mode),
 667            messages,
 668            tools,
 669            tool_choice: None,
 670            stop: Vec::new(),
 671            temperature: None,
 672            thinking_allowed: true,
 673        };
 674
 675        log::debug!("Completion request built successfully");
 676        request
 677    }
 678
 679    fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
 680        let profile = AgentSettings::get_global(cx)
 681            .profiles
 682            .get(&self.profile_id)
 683            .context("profile not found")?;
 684
 685        Ok(self
 686            .tools
 687            .iter()
 688            .filter_map(|(tool_name, tool)| {
 689                if profile.is_tool_enabled(tool_name) {
 690                    Some(tool)
 691                } else {
 692                    None
 693                }
 694            })
 695            .chain(self.context_server_registry.read(cx).servers().flat_map(
 696                |(server_id, tools)| {
 697                    tools.iter().filter_map(|(tool_name, tool)| {
 698                        if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
 699                            Some(tool)
 700                        } else {
 701                            None
 702                        }
 703                    })
 704                },
 705            )))
 706    }
 707
 708    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
 709        log::trace!(
 710            "Building request messages from {} thread messages",
 711            self.messages.len()
 712        );
 713
 714        let messages = Some(self.build_system_message())
 715            .iter()
 716            .chain(self.messages.iter())
 717            .map(|message| {
 718                log::trace!(
 719                    "  - {} message with {} content items",
 720                    match message.role {
 721                        Role::System => "System",
 722                        Role::User => "User",
 723                        Role::Assistant => "Assistant",
 724                    },
 725                    message.content.len()
 726                );
 727                message.to_request()
 728            })
 729            .collect();
 730        messages
 731    }
 732
 733    pub fn to_markdown(&self) -> String {
 734        let mut markdown = String::new();
 735        for message in &self.messages {
 736            markdown.push_str(&message.to_markdown());
 737        }
 738        markdown
 739    }
 740}
 741
 742pub struct UserMessage(Vec<MessageContent>);
 743
 744impl From<Vec<MessageContent>> for UserMessage {
 745    fn from(content: Vec<MessageContent>) -> Self {
 746        UserMessage(content)
 747    }
 748}
 749
 750impl<T: Into<MessageContent>> From<T> for UserMessage {
 751    fn from(content: T) -> Self {
 752        UserMessage(vec![content.into()])
 753    }
 754}
 755
 756pub trait AgentTool
 757where
 758    Self: 'static + Sized,
 759{
 760    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
 761    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
 762
 763    fn name(&self) -> SharedString;
 764
 765    fn description(&self) -> SharedString {
 766        let schema = schemars::schema_for!(Self::Input);
 767        SharedString::new(
 768            schema
 769                .get("description")
 770                .and_then(|description| description.as_str())
 771                .unwrap_or_default(),
 772        )
 773    }
 774
 775    fn kind(&self) -> acp::ToolKind;
 776
 777    /// The initial tool title to display. Can be updated during the tool run.
 778    fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
 779
 780    /// Returns the JSON schema that describes the tool's input.
 781    fn input_schema(&self) -> Schema {
 782        schemars::schema_for!(Self::Input)
 783    }
 784
 785    /// Runs the tool with the provided input.
 786    fn run(
 787        self: Arc<Self>,
 788        input: Self::Input,
 789        event_stream: ToolCallEventStream,
 790        cx: &mut App,
 791    ) -> Task<Result<Self::Output>>;
 792
 793    fn erase(self) -> Arc<dyn AnyAgentTool> {
 794        Arc::new(Erased(Arc::new(self)))
 795    }
 796}
 797
/// Type-erasing wrapper that adapts an [`AgentTool`] (held as `Erased<Arc<T>>`) to the
/// object-safe, JSON-based [`AnyAgentTool`] interface.
pub struct Erased<T>(T);

/// Result of running an [`AnyAgentTool`].
pub struct AgentToolOutput {
    /// The tool's typed output converted into language-model tool-result content.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's typed output serialized to a JSON value.
    pub raw_output: serde_json::Value,
}
 804
/// Object-safe counterpart of [`AgentTool`] that works on raw JSON, so tools with
/// different input/output types can be stored and invoked uniformly.
pub trait AnyAgentTool {
    /// Name of the tool.
    fn name(&self) -> SharedString;
    /// Description of the tool.
    fn description(&self) -> SharedString;
    /// Kind of tool, reported over ACP.
    fn kind(&self) -> acp::ToolKind;
    /// Initial title for a tool call, given its raw JSON input (which may not match the
    /// tool's input type).
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// JSON schema of the tool's input, adapted to the requested schema format.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Runs the tool on raw JSON input.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
}
 818
 819impl<T> AnyAgentTool for Erased<Arc<T>>
 820where
 821    T: AgentTool,
 822{
 823    fn name(&self) -> SharedString {
 824        self.0.name()
 825    }
 826
 827    fn description(&self) -> SharedString {
 828        self.0.description()
 829    }
 830
 831    fn kind(&self) -> agent_client_protocol::ToolKind {
 832        self.0.kind()
 833    }
 834
 835    fn initial_title(&self, input: serde_json::Value) -> SharedString {
 836        let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
 837        self.0.initial_title(parsed_input)
 838    }
 839
 840    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 841        let mut json = serde_json::to_value(self.0.input_schema())?;
 842        adapt_schema_to_format(&mut json, format)?;
 843        Ok(json)
 844    }
 845
 846    fn run(
 847        self: Arc<Self>,
 848        input: serde_json::Value,
 849        event_stream: ToolCallEventStream,
 850        cx: &mut App,
 851    ) -> Task<Result<AgentToolOutput>> {
 852        cx.spawn(async move |cx| {
 853            let input = serde_json::from_value(input)?;
 854            let output = cx
 855                .update(|cx| self.0.clone().run(input, event_stream, cx))?
 856                .await?;
 857            let raw_output = serde_json::to_value(&output)?;
 858            Ok(AgentToolOutput {
 859                llm_output: output.into(),
 860                raw_output,
 861            })
 862        })
 863    }
 864}
 865
/// Sending half of the unbounded channel on which response events (`Ok`) or completion
/// errors (`Err`) are reported.
#[derive(Clone)]
struct AgentResponseEventStream(
    mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
 870
 871impl AgentResponseEventStream {
 872    fn send_text(&self, text: &str) {
 873        self.0
 874            .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
 875            .ok();
 876    }
 877
 878    fn send_thinking(&self, text: &str) {
 879        self.0
 880            .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
 881            .ok();
 882    }
 883
 884    fn send_tool_call(
 885        &self,
 886        id: &LanguageModelToolUseId,
 887        title: SharedString,
 888        kind: acp::ToolKind,
 889        input: serde_json::Value,
 890    ) {
 891        self.0
 892            .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
 893                id,
 894                title.to_string(),
 895                kind,
 896                input,
 897            ))))
 898            .ok();
 899    }
 900
 901    fn initial_tool_call(
 902        id: &LanguageModelToolUseId,
 903        title: String,
 904        kind: acp::ToolKind,
 905        input: serde_json::Value,
 906    ) -> acp::ToolCall {
 907        acp::ToolCall {
 908            id: acp::ToolCallId(id.to_string().into()),
 909            title,
 910            kind,
 911            status: acp::ToolCallStatus::Pending,
 912            content: vec![],
 913            locations: vec![],
 914            raw_input: Some(input),
 915            raw_output: None,
 916        }
 917    }
 918
 919    fn update_tool_call_fields(
 920        &self,
 921        tool_use_id: &LanguageModelToolUseId,
 922        fields: acp::ToolCallUpdateFields,
 923    ) {
 924        self.0
 925            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
 926                acp::ToolCallUpdate {
 927                    id: acp::ToolCallId(tool_use_id.to_string().into()),
 928                    fields,
 929                }
 930                .into(),
 931            )))
 932            .ok();
 933    }
 934
 935    fn send_stop(&self, reason: StopReason) {
 936        match reason {
 937            StopReason::EndTurn => {
 938                self.0
 939                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
 940                    .ok();
 941            }
 942            StopReason::MaxTokens => {
 943                self.0
 944                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
 945                    .ok();
 946            }
 947            StopReason::Refusal => {
 948                self.0
 949                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
 950                    .ok();
 951            }
 952            StopReason::ToolUse => {}
 953        }
 954    }
 955
 956    fn send_error(&self, error: LanguageModelCompletionError) {
 957        self.0.unbounded_send(Err(error)).ok();
 958    }
 959}
 960
/// Per-tool-call handle passed to [`AgentTool::run`]. It reports updates (fields, diffs,
/// terminals) for one tool call and requests user authorization.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// Id of the model's tool use; stringified into the ACP tool call id.
    tool_use_id: LanguageModelToolUseId,
    /// Tool kind, echoed in authorization requests.
    kind: acp::ToolKind,
    /// Tool input, echoed in authorization requests.
    input: serde_json::Value,
    /// Channel that events are sent on.
    stream: AgentResponseEventStream,
    /// Filesystem used to persist an "Always Allow" choice; `None` skips persisting it.
    fs: Option<Arc<dyn Fs>>,
}
 969
 970impl ToolCallEventStream {
 971    #[cfg(test)]
 972    pub fn test() -> (Self, ToolCallEventStreamReceiver) {
 973        let (events_tx, events_rx) =
 974            mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
 975
 976        let stream = ToolCallEventStream::new(
 977            &LanguageModelToolUse {
 978                id: "test_id".into(),
 979                name: "test_tool".into(),
 980                raw_input: String::new(),
 981                input: serde_json::Value::Null,
 982                is_input_complete: true,
 983            },
 984            acp::ToolKind::Other,
 985            AgentResponseEventStream(events_tx),
 986            None,
 987        );
 988
 989        (stream, ToolCallEventStreamReceiver(events_rx))
 990    }
 991
 992    fn new(
 993        tool_use: &LanguageModelToolUse,
 994        kind: acp::ToolKind,
 995        stream: AgentResponseEventStream,
 996        fs: Option<Arc<dyn Fs>>,
 997    ) -> Self {
 998        Self {
 999            tool_use_id: tool_use.id.clone(),
1000            kind,
1001            input: tool_use.input.clone(),
1002            stream,
1003            fs,
1004        }
1005    }
1006
1007    pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1008        self.stream
1009            .update_tool_call_fields(&self.tool_use_id, fields);
1010    }
1011
1012    pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1013        self.stream
1014            .0
1015            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1016                acp_thread::ToolCallUpdateDiff {
1017                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1018                    diff,
1019                }
1020                .into(),
1021            )))
1022            .ok();
1023    }
1024
1025    pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1026        self.stream
1027            .0
1028            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
1029                acp_thread::ToolCallUpdateTerminal {
1030                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1031                    terminal,
1032                }
1033                .into(),
1034            )))
1035            .ok();
1036    }
1037
1038    pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1039        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1040            return Task::ready(Ok(()));
1041        }
1042
1043        let (response_tx, response_rx) = oneshot::channel();
1044        self.stream
1045            .0
1046            .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
1047                ToolCallAuthorization {
1048                    tool_call: AgentResponseEventStream::initial_tool_call(
1049                        &self.tool_use_id,
1050                        title.into(),
1051                        self.kind.clone(),
1052                        self.input.clone(),
1053                    ),
1054                    options: vec![
1055                        acp::PermissionOption {
1056                            id: acp::PermissionOptionId("always_allow".into()),
1057                            name: "Always Allow".into(),
1058                            kind: acp::PermissionOptionKind::AllowAlways,
1059                        },
1060                        acp::PermissionOption {
1061                            id: acp::PermissionOptionId("allow".into()),
1062                            name: "Allow".into(),
1063                            kind: acp::PermissionOptionKind::AllowOnce,
1064                        },
1065                        acp::PermissionOption {
1066                            id: acp::PermissionOptionId("deny".into()),
1067                            name: "Deny".into(),
1068                            kind: acp::PermissionOptionKind::RejectOnce,
1069                        },
1070                    ],
1071                    response: response_tx,
1072                },
1073            )))
1074            .ok();
1075        let fs = self.fs.clone();
1076        cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1077            "always_allow" => {
1078                if let Some(fs) = fs.clone() {
1079                    cx.update(|cx| {
1080                        update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1081                            settings.set_always_allow_tool_actions(true);
1082                        });
1083                    })?;
1084                }
1085
1086                Ok(())
1087            }
1088            "allow" => Ok(()),
1089            _ => Err(anyhow!("Permission to run tool denied by user")),
1090        })
1091    }
1092}
1093
/// Test-only receiving end of the channel created by `ToolCallEventStream::test`.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(
    mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
);
1098
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it if it is an authorization request.
    ///
    /// # Panics
    /// Panics if the stream ends or the next event is anything else.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) => auth,
            other => panic!("Expected ToolCallAuthorization but got: {:?}", other),
        }
    }

    /// Awaits the next event and returns the terminal if it is a terminal update.
    ///
    /// # Panics
    /// Panics if the stream ends or the next event is anything else.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(AgentResponseEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            other => panic!("Expected terminal but got: {:?}", other),
        }
    }
}
1122
/// Exposes the wrapped receiver so tests can call receiver/stream methods directly.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// Mutable access to the wrapped receiver (needed, e.g., for polling it in tests).
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
1138
1139impl AgentMessage {
1140    fn to_request(&self) -> language_model::LanguageModelRequestMessage {
1141        let mut message = LanguageModelRequestMessage {
1142            role: self.role,
1143            content: Vec::with_capacity(self.content.len()),
1144            cache: false,
1145        };
1146
1147        const OPEN_CONTEXT: &str = "<context>\n\
1148            The following items were attached by the user. \
1149            They are up-to-date and don't need to be re-read.\n\n";
1150
1151        const OPEN_FILES_TAG: &str = "<files>";
1152        const OPEN_SYMBOLS_TAG: &str = "<symbols>";
1153        const OPEN_THREADS_TAG: &str = "<threads>";
1154        const OPEN_RULES_TAG: &str =
1155            "<rules>\nThe user has specified the following rules that should be applied:\n";
1156
1157        let mut file_context = OPEN_FILES_TAG.to_string();
1158        let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
1159        let mut thread_context = OPEN_THREADS_TAG.to_string();
1160        let mut rules_context = OPEN_RULES_TAG.to_string();
1161
1162        for chunk in &self.content {
1163            let chunk = match chunk {
1164                MessageContent::Text(text) => language_model::MessageContent::Text(text.clone()),
1165                MessageContent::Thinking { text, signature } => {
1166                    language_model::MessageContent::Thinking {
1167                        text: text.clone(),
1168                        signature: signature.clone(),
1169                    }
1170                }
1171                MessageContent::RedactedThinking(value) => {
1172                    language_model::MessageContent::RedactedThinking(value.clone())
1173                }
1174                MessageContent::ToolUse(value) => {
1175                    language_model::MessageContent::ToolUse(value.clone())
1176                }
1177                MessageContent::ToolResult(value) => {
1178                    language_model::MessageContent::ToolResult(value.clone())
1179                }
1180                MessageContent::Image(value) => {
1181                    language_model::MessageContent::Image(value.clone())
1182                }
1183                MessageContent::Mention { uri, content } => {
1184                    match uri {
1185                        MentionUri::File(path) | MentionUri::Symbol(path, _) => {
1186                            write!(
1187                                &mut symbol_context,
1188                                "\n{}",
1189                                MarkdownCodeBlock {
1190                                    tag: &codeblock_tag(&path),
1191                                    text: &content.to_string(),
1192                                }
1193                            )
1194                            .ok();
1195                        }
1196                        MentionUri::Thread(_session_id) => {
1197                            write!(&mut thread_context, "\n{}\n", content).ok();
1198                        }
1199                        MentionUri::Rule(_user_prompt_id) => {
1200                            write!(
1201                                &mut rules_context,
1202                                "\n{}",
1203                                MarkdownCodeBlock {
1204                                    tag: "",
1205                                    text: &content
1206                                }
1207                            )
1208                            .ok();
1209                        }
1210                    }
1211
1212                    language_model::MessageContent::Text(uri.to_link())
1213                }
1214            };
1215
1216            message.content.push(chunk);
1217        }
1218
1219        let len_before_context = message.content.len();
1220
1221        if file_context.len() > OPEN_FILES_TAG.len() {
1222            file_context.push_str("</files>\n");
1223            message
1224                .content
1225                .push(language_model::MessageContent::Text(file_context));
1226        }
1227
1228        if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
1229            symbol_context.push_str("</symbols>\n");
1230            message
1231                .content
1232                .push(language_model::MessageContent::Text(symbol_context));
1233        }
1234
1235        if thread_context.len() > OPEN_THREADS_TAG.len() {
1236            thread_context.push_str("</threads>\n");
1237            message
1238                .content
1239                .push(language_model::MessageContent::Text(thread_context));
1240        }
1241
1242        if rules_context.len() > OPEN_RULES_TAG.len() {
1243            rules_context.push_str("</user_rules>\n");
1244            message
1245                .content
1246                .push(language_model::MessageContent::Text(rules_context));
1247        }
1248
1249        if message.content.len() > len_before_context {
1250            message.content.insert(
1251                len_before_context,
1252                language_model::MessageContent::Text(OPEN_CONTEXT.into()),
1253            );
1254            message
1255                .content
1256                .push(language_model::MessageContent::Text("</context>".into()));
1257        }
1258
1259        message
1260    }
1261}
1262
/// Builds the info string for a fenced code block showing the file at `full_path`.
///
/// The result is the file's extension (when it has a UTF-8 one) followed by a space and
/// the displayed path, e.g. `"rs src/main.rs"`. Without an extension it is just the
/// path. The extension is presumably there so renderers can pick syntax highlighting.
fn codeblock_tag(full_path: &Path) -> String {
    match full_path.extension().and_then(|ext| ext.to_str()) {
        Some(extension) => format!("{extension} {}", full_path.display()),
        None => full_path.display().to_string(),
    }
}
1274
1275impl From<acp::ContentBlock> for MessageContent {
1276    fn from(value: acp::ContentBlock) -> Self {
1277        match value {
1278            acp::ContentBlock::Text(text_content) => MessageContent::Text(text_content.text),
1279            acp::ContentBlock::Image(image_content) => {
1280                MessageContent::Image(convert_image(image_content))
1281            }
1282            acp::ContentBlock::Audio(_) => {
1283                // TODO
1284                MessageContent::Text("[audio]".to_string())
1285            }
1286            acp::ContentBlock::ResourceLink(resource_link) => {
1287                match MentionUri::parse(&resource_link.uri) {
1288                    Ok(uri) => Self::Mention {
1289                        uri,
1290                        content: String::new(),
1291                    },
1292                    Err(err) => {
1293                        log::error!("Failed to parse mention link: {}", err);
1294                        MessageContent::Text(format!(
1295                            "[{}]({})",
1296                            resource_link.name, resource_link.uri
1297                        ))
1298                    }
1299                }
1300            }
1301            acp::ContentBlock::Resource(resource) => match resource.resource {
1302                acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1303                    match MentionUri::parse(&resource.uri) {
1304                        Ok(uri) => Self::Mention {
1305                            uri,
1306                            content: resource.text,
1307                        },
1308                        Err(err) => {
1309                            log::error!("Failed to parse mention link: {}", err);
1310                            MessageContent::Text(
1311                                MarkdownCodeBlock {
1312                                    tag: &resource.uri,
1313                                    text: &resource.text,
1314                                }
1315                                .to_string(),
1316                            )
1317                        }
1318                    }
1319                }
1320                acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1321                    // TODO
1322                    MessageContent::Text("[blob]".to_string())
1323                }
1324            },
1325        }
1326    }
1327}
1328
1329fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1330    LanguageModelImage {
1331        source: image_content.data.into(),
1332        // TODO: make this optional?
1333        size: gpui::Size::new(0.into(), 0.into()),
1334    }
1335}
1336
1337impl From<&str> for MessageContent {
1338    fn from(text: &str) -> Self {
1339        MessageContent::Text(text.into())
1340    }
1341}