thread.rs

   1use crate::{ContextServerRegistry, DbThread, SystemPromptTemplate, Template, Templates};
   2use acp_thread::{MentionUri, UserMessageId};
   3use action_log::ActionLog;
   4use agent_client_protocol as acp;
   5use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
   6use anyhow::{Context as _, Result, anyhow};
   7use assistant_tool::adapt_schema_to_format;
   8use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
   9use collections::IndexMap;
  10use fs::Fs;
  11use futures::{
  12    channel::{mpsc, oneshot},
  13    stream::FuturesUnordered,
  14};
  15use gpui::{App, Context, Entity, SharedString, Task};
  16use language_model::{
  17    LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
  18    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  19    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
  20    LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
  21};
  22use project::Project;
  23use prompt_store::ProjectContext;
  24use schemars::{JsonSchema, Schema};
  25use serde::{Deserialize, Serialize};
  26use settings::{Settings, update_settings_file};
  27use smol::stream::StreamExt;
  28use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
  29use std::{fmt::Write, ops::Range};
  30use util::{ResultExt, markdown::MarkdownCodeBlock};
  31use uuid::Uuid;
  32
/// Content reported for a replayed tool call that has no recorded output.
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
  34
  35#[derive(
  36    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  37)]
  38pub struct ThreadId(pub(crate) Arc<str>);
  39
  40impl ThreadId {
  41    pub fn new() -> Self {
  42        Self(Uuid::new_v4().to_string().into())
  43    }
  44}
  45
  46impl std::fmt::Display for ThreadId {
  47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  48        write!(f, "{}", self.0)
  49    }
  50}
  51
  52impl From<&str> for ThreadId {
  53    fn from(value: &str) -> Self {
  54        Self(value.into())
  55    }
  56}
  57
  58impl From<acp::SessionId> for ThreadId {
  59    fn from(value: acp::SessionId) -> Self {
  60        Self(value.0)
  61    }
  62}
  63
  64impl From<ThreadId> for acp::SessionId {
  65    fn from(value: ThreadId) -> Self {
  66        Self(value.0)
  67    }
  68}
  69
  70/// The ID of the user prompt that initiated a request.
  71///
  72/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  73#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  74pub struct PromptId(Arc<str>);
  75
  76impl PromptId {
  77    pub fn new() -> Self {
  78        Self(Uuid::new_v4().to_string().into())
  79    }
  80}
  81
  82impl std::fmt::Display for PromptId {
  83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  84        write!(f, "{}", self.0)
  85    }
  86}
  87
/// A single entry in a thread's conversation history.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Message {
    /// A message submitted by the user.
    User(UserMessage),
    /// A response from the agent, together with the results of its tool uses.
    Agent(AgentMessage),
    /// Marker recorded by `Thread::resume` when a turn is continued after the
    /// tool use limit was reached.
    Resume,
}
  94
  95impl Message {
  96    pub fn as_agent_message(&self) -> Option<&AgentMessage> {
  97        match self {
  98            Message::Agent(agent_message) => Some(agent_message),
  99            _ => None,
 100        }
 101    }
 102
 103    pub fn to_markdown(&self) -> String {
 104        match self {
 105            Message::User(message) => message.to_markdown(),
 106            Message::Agent(message) => message.to_markdown(),
 107            Message::Resume => "[resumed after tool use limit was reached]".into(),
 108        }
 109    }
 110}
 111
/// A message submitted by the user, made of a sequence of content chunks.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessage {
    /// Identifier used, e.g., by `Thread::truncate` to locate this message.
    pub id: UserMessageId,
    /// The message's chunks, in order.
    pub content: Vec<UserMessageContent>,
}
 117
/// One chunk of a [`UserMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UserMessageContent {
    /// Plain text typed by the user.
    Text(String),
    /// A reference to another resource (file, symbol, selection, thread, rule,
    /// fetched URL) plus that resource's content, which may be empty.
    Mention { uri: MentionUri, content: String },
    /// An attached image.
    Image(LanguageModelImage),
}
 124
 125impl UserMessage {
 126    pub fn to_markdown(&self) -> String {
 127        let mut markdown = String::from("## User\n\n");
 128
 129        for content in &self.content {
 130            match content {
 131                UserMessageContent::Text(text) => {
 132                    markdown.push_str(text);
 133                    markdown.push('\n');
 134                }
 135                UserMessageContent::Image(_) => {
 136                    markdown.push_str("<image />\n");
 137                }
 138                UserMessageContent::Mention { uri, content } => {
 139                    if !content.is_empty() {
 140                        let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
 141                    } else {
 142                        let _ = write!(&mut markdown, "{}\n", uri.as_link());
 143                    }
 144                }
 145            }
 146        }
 147
 148        markdown
 149    }
 150
 151    fn to_request(&self) -> LanguageModelRequestMessage {
 152        let mut message = LanguageModelRequestMessage {
 153            role: Role::User,
 154            content: Vec::with_capacity(self.content.len()),
 155            cache: false,
 156        };
 157
 158        const OPEN_CONTEXT: &str = "<context>\n\
 159            The following items were attached by the user. \
 160            They are up-to-date and don't need to be re-read.\n\n";
 161
 162        const OPEN_FILES_TAG: &str = "<files>";
 163        const OPEN_SYMBOLS_TAG: &str = "<symbols>";
 164        const OPEN_THREADS_TAG: &str = "<threads>";
 165        const OPEN_FETCH_TAG: &str = "<fetched_urls>";
 166        const OPEN_RULES_TAG: &str =
 167            "<rules>\nThe user has specified the following rules that should be applied:\n";
 168
 169        let mut file_context = OPEN_FILES_TAG.to_string();
 170        let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
 171        let mut thread_context = OPEN_THREADS_TAG.to_string();
 172        let mut fetch_context = OPEN_FETCH_TAG.to_string();
 173        let mut rules_context = OPEN_RULES_TAG.to_string();
 174
 175        for chunk in &self.content {
 176            let chunk = match chunk {
 177                UserMessageContent::Text(text) => {
 178                    language_model::MessageContent::Text(text.clone())
 179                }
 180                UserMessageContent::Image(value) => {
 181                    language_model::MessageContent::Image(value.clone())
 182                }
 183                UserMessageContent::Mention { uri, content } => {
 184                    match uri {
 185                        MentionUri::File { abs_path, .. } => {
 186                            write!(
 187                                &mut symbol_context,
 188                                "\n{}",
 189                                MarkdownCodeBlock {
 190                                    tag: &codeblock_tag(&abs_path, None),
 191                                    text: &content.to_string(),
 192                                }
 193                            )
 194                            .ok();
 195                        }
 196                        MentionUri::Symbol {
 197                            path, line_range, ..
 198                        }
 199                        | MentionUri::Selection {
 200                            path, line_range, ..
 201                        } => {
 202                            write!(
 203                                &mut rules_context,
 204                                "\n{}",
 205                                MarkdownCodeBlock {
 206                                    tag: &codeblock_tag(&path, Some(line_range)),
 207                                    text: &content
 208                                }
 209                            )
 210                            .ok();
 211                        }
 212                        MentionUri::Thread { .. } => {
 213                            write!(&mut thread_context, "\n{}\n", content).ok();
 214                        }
 215                        MentionUri::TextThread { .. } => {
 216                            write!(&mut thread_context, "\n{}\n", content).ok();
 217                        }
 218                        MentionUri::Rule { .. } => {
 219                            write!(
 220                                &mut rules_context,
 221                                "\n{}",
 222                                MarkdownCodeBlock {
 223                                    tag: "",
 224                                    text: &content
 225                                }
 226                            )
 227                            .ok();
 228                        }
 229                        MentionUri::Fetch { url } => {
 230                            write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
 231                        }
 232                    }
 233
 234                    language_model::MessageContent::Text(uri.as_link().to_string())
 235                }
 236            };
 237
 238            message.content.push(chunk);
 239        }
 240
 241        let len_before_context = message.content.len();
 242
 243        if file_context.len() > OPEN_FILES_TAG.len() {
 244            file_context.push_str("</files>\n");
 245            message
 246                .content
 247                .push(language_model::MessageContent::Text(file_context));
 248        }
 249
 250        if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
 251            symbol_context.push_str("</symbols>\n");
 252            message
 253                .content
 254                .push(language_model::MessageContent::Text(symbol_context));
 255        }
 256
 257        if thread_context.len() > OPEN_THREADS_TAG.len() {
 258            thread_context.push_str("</threads>\n");
 259            message
 260                .content
 261                .push(language_model::MessageContent::Text(thread_context));
 262        }
 263
 264        if fetch_context.len() > OPEN_FETCH_TAG.len() {
 265            fetch_context.push_str("</fetched_urls>\n");
 266            message
 267                .content
 268                .push(language_model::MessageContent::Text(fetch_context));
 269        }
 270
 271        if rules_context.len() > OPEN_RULES_TAG.len() {
 272            rules_context.push_str("</user_rules>\n");
 273            message
 274                .content
 275                .push(language_model::MessageContent::Text(rules_context));
 276        }
 277
 278        if message.content.len() > len_before_context {
 279            message.content.insert(
 280                len_before_context,
 281                language_model::MessageContent::Text(OPEN_CONTEXT.into()),
 282            );
 283            message
 284                .content
 285                .push(language_model::MessageContent::Text("</context>".into()));
 286        }
 287
 288        message
 289    }
 290}
 291
 292fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
 293    let mut result = String::new();
 294
 295    if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) {
 296        let _ = write!(result, "{} ", extension);
 297    }
 298
 299    let _ = write!(result, "{}", full_path.display());
 300
 301    if let Some(range) = line_range {
 302        if range.start == range.end {
 303            let _ = write!(result, ":{}", range.start + 1);
 304        } else {
 305            let _ = write!(result, ":{}-{}", range.start + 1, range.end + 1);
 306        }
 307    }
 308
 309    result
 310}
 311
 312impl AgentMessage {
 313    pub fn to_markdown(&self) -> String {
 314        let mut markdown = String::from("## Assistant\n\n");
 315
 316        for content in &self.content {
 317            match content {
 318                AgentMessageContent::Text(text) => {
 319                    markdown.push_str(text);
 320                    markdown.push('\n');
 321                }
 322                AgentMessageContent::Thinking { text, .. } => {
 323                    markdown.push_str("<think>");
 324                    markdown.push_str(text);
 325                    markdown.push_str("</think>\n");
 326                }
 327                AgentMessageContent::RedactedThinking(_) => {
 328                    markdown.push_str("<redacted_thinking />\n")
 329                }
 330                AgentMessageContent::ToolUse(tool_use) => {
 331                    markdown.push_str(&format!(
 332                        "**Tool Use**: {} (ID: {})\n",
 333                        tool_use.name, tool_use.id
 334                    ));
 335                    markdown.push_str(&format!(
 336                        "{}\n",
 337                        MarkdownCodeBlock {
 338                            tag: "json",
 339                            text: &format!("{:#}", tool_use.input)
 340                        }
 341                    ));
 342                }
 343            }
 344        }
 345
 346        for tool_result in self.tool_results.values() {
 347            markdown.push_str(&format!(
 348                "**Tool Result**: {} (ID: {})\n\n",
 349                tool_result.tool_name, tool_result.tool_use_id
 350            ));
 351            if tool_result.is_error {
 352                markdown.push_str("**ERROR:**\n");
 353            }
 354
 355            match &tool_result.content {
 356                LanguageModelToolResultContent::Text(text) => {
 357                    writeln!(markdown, "{text}\n").ok();
 358                }
 359                LanguageModelToolResultContent::Image(_) => {
 360                    writeln!(markdown, "<image />\n").ok();
 361                }
 362            }
 363
 364            if let Some(output) = tool_result.output.as_ref() {
 365                writeln!(
 366                    markdown,
 367                    "**Debug Output**:\n\n```json\n{}\n```\n",
 368                    serde_json::to_string_pretty(output).unwrap()
 369                )
 370                .unwrap();
 371            }
 372        }
 373
 374        markdown
 375    }
 376
 377    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
 378        let mut assistant_message = LanguageModelRequestMessage {
 379            role: Role::Assistant,
 380            content: Vec::with_capacity(self.content.len()),
 381            cache: false,
 382        };
 383        for chunk in &self.content {
 384            let chunk = match chunk {
 385                AgentMessageContent::Text(text) => {
 386                    language_model::MessageContent::Text(text.clone())
 387                }
 388                AgentMessageContent::Thinking { text, signature } => {
 389                    language_model::MessageContent::Thinking {
 390                        text: text.clone(),
 391                        signature: signature.clone(),
 392                    }
 393                }
 394                AgentMessageContent::RedactedThinking(value) => {
 395                    language_model::MessageContent::RedactedThinking(value.clone())
 396                }
 397                AgentMessageContent::ToolUse(value) => {
 398                    language_model::MessageContent::ToolUse(value.clone())
 399                }
 400            };
 401            assistant_message.content.push(chunk);
 402        }
 403
 404        let mut user_message = LanguageModelRequestMessage {
 405            role: Role::User,
 406            content: Vec::new(),
 407            cache: false,
 408        };
 409
 410        for tool_result in self.tool_results.values() {
 411            user_message
 412                .content
 413                .push(language_model::MessageContent::ToolResult(
 414                    tool_result.clone(),
 415                ));
 416        }
 417
 418        let mut messages = Vec::new();
 419        if !assistant_message.content.is_empty() {
 420            messages.push(assistant_message);
 421        }
 422        if !user_message.content.is_empty() {
 423            messages.push(user_message);
 424        }
 425        messages
 426    }
 427}
 428
/// A response produced by the agent within a turn.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AgentMessage {
    /// Response chunks (text, thinking, tool uses), in order.
    pub content: Vec<AgentMessageContent>,
    /// Results of tool uses, keyed by tool use id and kept in insertion order.
    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
}
 434
/// One chunk of an [`AgentMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AgentMessageContent {
    /// Response text.
    Text(String),
    /// Reasoning text. The optional signature is forwarded back to the model
    /// in subsequent requests.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Reasoning returned in redacted form. It is forwarded in requests, shown
    /// only as a placeholder in Markdown, and skipped on replay.
    RedactedThinking(String),
    /// A request from the model to invoke a tool.
    ToolUse(LanguageModelToolUse),
}
 445
/// An event produced while running or replaying a thread.
///
/// Delivered, wrapped in `Result`, on the channels returned by `Thread::send`,
/// `Thread::resume` and `Thread::replay`.
#[derive(Debug)]
pub enum ThreadEvent {
    /// A user message (e.g. emitted when replaying history).
    UserMessage(UserMessage),
    /// Agent response text.
    AgentText(String),
    /// Agent thinking text.
    AgentThinking(String),
    /// A newly announced tool call.
    ToolCall(acp::ToolCall),
    /// An update to a previously announced tool call.
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A request for the user to authorize a tool call.
    ToolCallAuthorization(ToolCallAuthorization),
    /// The model stopped, for the given reason.
    Stop(acp::StopReason),
}
 456
/// A request asking the user to pick one of `options` for a tool call.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call this request refers to.
    pub tool_call: acp::ToolCallUpdate,
    /// The permission choices offered to the user.
    pub options: Vec<acp::PermissionOption>,
    /// One-shot channel on which the id of the chosen option is sent back.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
 463
/// An agent conversation: its message history, registered tools, selected
/// model and profile, and the turn currently running (if any).
pub struct Thread {
    /// Identifier of this thread.
    id: ThreadId,
    /// Identifier of the current user prompt; advanced via `advance_prompt_id`
    /// whenever a message is sent.
    prompt_id: PromptId,
    /// Committed conversation history.
    messages: Vec<Message>,
    /// Completion mode setting; both constructors start it at `CompletionMode::Normal`.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests while the model performs tool calls and
    /// we run the tools and report their results.
    running_turn: Option<RunningTurn>,
    /// Agent message accumulated during the current turn (tool results are
    /// inserted into it) until `flush_pending_message` is called.
    pending_message: Option<AgentMessage>,
    /// Tools available to the model, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Set when a turn ended because the tool use limit was reached, and
    /// cleared at the start of every turn. `resume` requires it to be set.
    tool_use_limit_reached: bool,
    context_server_registry: Entity<ContextServerRegistry>,
    /// Agent profile in use; falls back to `AgentSettings::default_profile`.
    profile_id: AgentProfileId,
    project_context: Rc<RefCell<ProjectContext>>,
    templates: Arc<Templates>,
    /// Language model for new turns; a running turn keeps the model it cloned at start.
    model: Arc<dyn LanguageModel>,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
}
 484
 485impl Thread {
 486    pub fn new(
 487        project: Entity<Project>,
 488        project_context: Rc<RefCell<ProjectContext>>,
 489        context_server_registry: Entity<ContextServerRegistry>,
 490        action_log: Entity<ActionLog>,
 491        templates: Arc<Templates>,
 492        model: Arc<dyn LanguageModel>,
 493        cx: &mut Context<Self>,
 494    ) -> Self {
 495        let profile_id = AgentSettings::get_global(cx).default_profile.clone();
 496        Self {
 497            id: ThreadId::new(),
 498            prompt_id: PromptId::new(),
 499            messages: Vec::new(),
 500            completion_mode: CompletionMode::Normal,
 501            running_turn: None,
 502            pending_message: None,
 503            tools: BTreeMap::default(),
 504            tool_use_limit_reached: false,
 505            context_server_registry,
 506            profile_id,
 507            project_context,
 508            templates,
 509            model,
 510            project,
 511            action_log,
 512        }
 513    }
 514
 515    pub fn from_db(
 516        id: ThreadId,
 517        db_thread: DbThread,
 518        project: Entity<Project>,
 519        project_context: Rc<RefCell<ProjectContext>>,
 520        context_server_registry: Entity<ContextServerRegistry>,
 521        action_log: Entity<ActionLog>,
 522        templates: Arc<Templates>,
 523        model: Arc<dyn LanguageModel>,
 524        cx: &mut Context<Self>,
 525    ) -> Self {
 526        let profile_id = db_thread
 527            .profile
 528            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
 529        Self {
 530            id,
 531            prompt_id: PromptId::new(),
 532            messages: db_thread.messages,
 533            completion_mode: CompletionMode::Normal,
 534            running_turn: None,
 535            pending_message: None,
 536            tools: BTreeMap::default(),
 537            tool_use_limit_reached: false,
 538            context_server_registry,
 539            profile_id,
 540            project_context,
 541            templates,
 542            model,
 543            project,
 544            action_log,
 545        }
 546    }
 547
 548    pub fn replay(&self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
 549        let (tx, rx) = mpsc::unbounded();
 550        let stream = ThreadEventStream(tx);
 551        for message in &self.messages {
 552            match message {
 553                Message::User(user_message) => stream.send_user_message(&user_message),
 554                Message::Agent(assistant_message) => {
 555                    for content in &assistant_message.content {
 556                        match content {
 557                            AgentMessageContent::Text(text) => stream.send_text(text),
 558                            AgentMessageContent::Thinking { text, .. } => {
 559                                stream.send_thinking(text)
 560                            }
 561                            AgentMessageContent::RedactedThinking(_) => {}
 562                            AgentMessageContent::ToolUse(tool_use) => {
 563                                self.replay_tool_call(
 564                                    tool_use,
 565                                    assistant_message.tool_results.get(&tool_use.id),
 566                                    &stream,
 567                                    cx,
 568                                );
 569                            }
 570                        }
 571                    }
 572                }
 573                Message::Resume => {}
 574            }
 575        }
 576        rx
 577    }
 578
 579    fn replay_tool_call(
 580        &self,
 581        tool_use: &LanguageModelToolUse,
 582        tool_result: Option<&LanguageModelToolResult>,
 583        stream: &ThreadEventStream,
 584        cx: &mut Context<Self>,
 585    ) {
 586        let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
 587            stream
 588                .0
 589                .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
 590                    id: acp::ToolCallId(tool_use.id.to_string().into()),
 591                    title: tool_use.name.to_string(),
 592                    kind: acp::ToolKind::Other,
 593                    status: acp::ToolCallStatus::Failed,
 594                    content: Vec::new(),
 595                    locations: Vec::new(),
 596                    raw_input: Some(tool_use.input.clone()),
 597                    raw_output: None,
 598                })))
 599                .ok();
 600            return;
 601        };
 602
 603        let title = tool.initial_title(tool_use.input.clone());
 604        let kind = tool.kind();
 605        stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
 606
 607        if let Some(output) = tool_result
 608            .as_ref()
 609            .and_then(|result| result.output.clone())
 610        {
 611            let tool_event_stream = ToolCallEventStream::new(
 612                tool_use.id.clone(),
 613                stream.clone(),
 614                Some(self.project.read(cx).fs().clone()),
 615            );
 616            tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
 617                .log_err();
 618        } else {
 619            stream.update_tool_call_fields(
 620                &tool_use.id,
 621                acp::ToolCallUpdateFields {
 622                    content: Some(vec![TOOL_CANCELED_MESSAGE.into()]),
 623                    status: Some(acp::ToolCallStatus::Failed),
 624                    ..Default::default()
 625                },
 626            );
 627        }
 628    }
 629
 630    pub fn project(&self) -> &Entity<Project> {
 631        &self.project
 632    }
 633
 634    pub fn action_log(&self) -> &Entity<ActionLog> {
 635        &self.action_log
 636    }
 637
 638    pub fn model(&self) -> &Arc<dyn LanguageModel> {
 639        &self.model
 640    }
 641
 642    pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
 643        self.model = model;
 644    }
 645
 646    pub fn completion_mode(&self) -> CompletionMode {
 647        self.completion_mode
 648    }
 649
 650    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
 651        self.completion_mode = mode;
 652    }
 653
 654    #[cfg(any(test, feature = "test-support"))]
 655    pub fn last_message(&self) -> Option<Message> {
 656        if let Some(message) = self.pending_message.clone() {
 657            Some(Message::Agent(message))
 658        } else {
 659            self.messages.last().cloned()
 660        }
 661    }
 662
 663    pub fn add_tool(&mut self, tool: impl AgentTool) {
 664        self.tools.insert(tool.name(), tool.erase());
 665    }
 666
 667    pub fn remove_tool(&mut self, name: &str) -> bool {
 668        self.tools.remove(name).is_some()
 669    }
 670
 671    pub fn profile(&self) -> &AgentProfileId {
 672        &self.profile_id
 673    }
 674
 675    pub fn set_profile(&mut self, profile_id: AgentProfileId) {
 676        self.profile_id = profile_id;
 677    }
 678
 679    pub fn cancel(&mut self) {
 680        if let Some(running_turn) = self.running_turn.take() {
 681            running_turn.cancel();
 682        }
 683        self.flush_pending_message();
 684    }
 685
 686    pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
 687        self.cancel();
 688        let Some(position) = self.messages.iter().position(
 689            |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
 690        ) else {
 691            return Err(anyhow!("Message not found"));
 692        };
 693        self.messages.truncate(position);
 694        Ok(())
 695    }
 696
 697    pub fn resume(
 698        &mut self,
 699        cx: &mut Context<Self>,
 700    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
 701        anyhow::ensure!(
 702            self.tool_use_limit_reached,
 703            "can only resume after tool use limit is reached"
 704        );
 705
 706        self.messages.push(Message::Resume);
 707        cx.notify();
 708
 709        log::info!("Total messages in thread: {}", self.messages.len());
 710        Ok(self.run_turn(cx))
 711    }
 712
 713    /// Sending a message results in the model streaming a response, which could include tool calls.
 714    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
 715    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
 716    pub fn send<T>(
 717        &mut self,
 718        id: UserMessageId,
 719        content: impl IntoIterator<Item = T>,
 720        cx: &mut Context<Self>,
 721    ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>>
 722    where
 723        T: Into<UserMessageContent>,
 724    {
 725        log::info!("Thread::send called with model: {:?}", self.model.name());
 726        self.advance_prompt_id();
 727
 728        let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
 729        log::debug!("Thread::send content: {:?}", content);
 730
 731        self.messages
 732            .push(Message::User(UserMessage { id, content }));
 733        cx.notify();
 734
 735        log::info!("Total messages in thread: {}", self.messages.len());
 736        self.run_turn(cx)
 737    }
 738
 739    fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
 740        self.cancel();
 741
 742        let model = self.model.clone();
 743        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
 744        let event_stream = ThreadEventStream(events_tx);
 745        let message_ix = self.messages.len().saturating_sub(1);
 746        self.tool_use_limit_reached = false;
 747        self.running_turn = Some(RunningTurn {
 748            event_stream: event_stream.clone(),
 749            _task: cx.spawn(async move |this, cx| {
 750                log::info!("Starting agent turn execution");
 751                let turn_result: Result<()> = async {
 752                    let mut completion_intent = CompletionIntent::UserPrompt;
 753                    loop {
 754                        log::debug!(
 755                            "Building completion request with intent: {:?}",
 756                            completion_intent
 757                        );
 758                        let request = this.update(cx, |this, cx| {
 759                            this.build_completion_request(completion_intent, cx)
 760                        })?;
 761
 762                        log::info!("Calling model.stream_completion");
 763                        let mut events = model.stream_completion(request, cx).await?;
 764                        log::debug!("Stream completion started successfully");
 765
 766                        let mut tool_use_limit_reached = false;
 767                        let mut tool_uses = FuturesUnordered::new();
 768                        while let Some(event) = events.next().await {
 769                            match event? {
 770                                LanguageModelCompletionEvent::StatusUpdate(
 771                                    CompletionRequestStatus::ToolUseLimitReached,
 772                                ) => {
 773                                    tool_use_limit_reached = true;
 774                                }
 775                                LanguageModelCompletionEvent::Stop(reason) => {
 776                                    event_stream.send_stop(reason);
 777                                    if reason == StopReason::Refusal {
 778                                        this.update(cx, |this, _cx| {
 779                                            this.flush_pending_message();
 780                                            this.messages.truncate(message_ix);
 781                                        })?;
 782                                        return Ok(());
 783                                    }
 784                                }
 785                                event => {
 786                                    log::trace!("Received completion event: {:?}", event);
 787                                    this.update(cx, |this, cx| {
 788                                        tool_uses.extend(this.handle_streamed_completion_event(
 789                                            event,
 790                                            &event_stream,
 791                                            cx,
 792                                        ));
 793                                    })
 794                                    .ok();
 795                                }
 796                            }
 797                        }
 798
 799                        let used_tools = tool_uses.is_empty();
 800                        while let Some(tool_result) = tool_uses.next().await {
 801                            log::info!("Tool finished {:?}", tool_result);
 802
 803                            event_stream.update_tool_call_fields(
 804                                &tool_result.tool_use_id,
 805                                acp::ToolCallUpdateFields {
 806                                    status: Some(if tool_result.is_error {
 807                                        acp::ToolCallStatus::Failed
 808                                    } else {
 809                                        acp::ToolCallStatus::Completed
 810                                    }),
 811                                    raw_output: tool_result.output.clone(),
 812                                    ..Default::default()
 813                                },
 814                            );
 815                            this.update(cx, |this, _cx| {
 816                                this.pending_message()
 817                                    .tool_results
 818                                    .insert(tool_result.tool_use_id.clone(), tool_result);
 819                            })
 820                            .ok();
 821                        }
 822
 823                        if tool_use_limit_reached {
 824                            log::info!("Tool use limit reached, completing turn");
 825                            this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
 826                            return Err(language_model::ToolUseLimitReachedError.into());
 827                        } else if used_tools {
 828                            log::info!("No tool uses found, completing turn");
 829                            return Ok(());
 830                        } else {
 831                            this.update(cx, |this, _| this.flush_pending_message())?;
 832                            completion_intent = CompletionIntent::ToolResults;
 833                        }
 834                    }
 835                }
 836                .await;
 837
 838                if let Err(error) = turn_result {
 839                    log::error!("Turn execution failed: {:?}", error);
 840                    event_stream.send_error(error);
 841                } else {
 842                    log::info!("Turn execution completed successfully");
 843                }
 844
 845                this.update(cx, |this, _| {
 846                    this.flush_pending_message();
 847                    this.running_turn.take();
 848                })
 849                .ok();
 850            }),
 851        });
 852        events_rx
 853    }
 854
 855    pub fn build_system_message(&self) -> LanguageModelRequestMessage {
 856        log::debug!("Building system message");
 857        let prompt = SystemPromptTemplate {
 858            project: &self.project_context.borrow(),
 859            available_tools: self.tools.keys().cloned().collect(),
 860        }
 861        .render(&self.templates)
 862        .context("failed to build system prompt")
 863        .expect("Invalid template");
 864        log::debug!("System message built");
 865        LanguageModelRequestMessage {
 866            role: Role::System,
 867            content: vec![prompt.into()],
 868            cache: true,
 869        }
 870    }
 871
 872    /// A helper method that's called on every streamed completion event.
 873    /// Returns an optional tool result task, which the main agentic loop in
 874    /// send will send back to the model when it resolves.
 875    fn handle_streamed_completion_event(
 876        &mut self,
 877        event: LanguageModelCompletionEvent,
 878        event_stream: &ThreadEventStream,
 879        cx: &mut Context<Self>,
 880    ) -> Option<Task<LanguageModelToolResult>> {
 881        log::trace!("Handling streamed completion event: {:?}", event);
 882        use LanguageModelCompletionEvent::*;
 883
 884        match event {
 885            StartMessage { .. } => {
 886                self.flush_pending_message();
 887                self.pending_message = Some(AgentMessage::default());
 888            }
 889            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
 890            Thinking { text, signature } => {
 891                self.handle_thinking_event(text, signature, event_stream, cx)
 892            }
 893            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
 894            ToolUse(tool_use) => {
 895                return self.handle_tool_use_event(tool_use, event_stream, cx);
 896            }
 897            ToolUseJsonParseError {
 898                id,
 899                tool_name,
 900                raw_input,
 901                json_parse_error,
 902            } => {
 903                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
 904                    id,
 905                    tool_name,
 906                    raw_input,
 907                    json_parse_error,
 908                )));
 909            }
 910            UsageUpdate(_) | StatusUpdate(_) => {}
 911            Stop(_) => unreachable!(),
 912        }
 913
 914        None
 915    }
 916
 917    fn handle_text_event(
 918        &mut self,
 919        new_text: String,
 920        event_stream: &ThreadEventStream,
 921        cx: &mut Context<Self>,
 922    ) {
 923        event_stream.send_text(&new_text);
 924
 925        let last_message = self.pending_message();
 926        if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
 927            text.push_str(&new_text);
 928        } else {
 929            last_message
 930                .content
 931                .push(AgentMessageContent::Text(new_text));
 932        }
 933
 934        cx.notify();
 935    }
 936
 937    fn handle_thinking_event(
 938        &mut self,
 939        new_text: String,
 940        new_signature: Option<String>,
 941        event_stream: &ThreadEventStream,
 942        cx: &mut Context<Self>,
 943    ) {
 944        event_stream.send_thinking(&new_text);
 945
 946        let last_message = self.pending_message();
 947        if let Some(AgentMessageContent::Thinking { text, signature }) =
 948            last_message.content.last_mut()
 949        {
 950            text.push_str(&new_text);
 951            *signature = new_signature.or(signature.take());
 952        } else {
 953            last_message.content.push(AgentMessageContent::Thinking {
 954                text: new_text,
 955                signature: new_signature,
 956            });
 957        }
 958
 959        cx.notify();
 960    }
 961
 962    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
 963        let last_message = self.pending_message();
 964        last_message
 965            .content
 966            .push(AgentMessageContent::RedactedThinking(data));
 967        cx.notify();
 968    }
 969
 970    fn handle_tool_use_event(
 971        &mut self,
 972        tool_use: LanguageModelToolUse,
 973        event_stream: &ThreadEventStream,
 974        cx: &mut Context<Self>,
 975    ) -> Option<Task<LanguageModelToolResult>> {
 976        cx.notify();
 977
 978        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
 979        let mut title = SharedString::from(&tool_use.name);
 980        let mut kind = acp::ToolKind::Other;
 981        if let Some(tool) = tool.as_ref() {
 982            title = tool.initial_title(tool_use.input.clone());
 983            kind = tool.kind();
 984        }
 985
 986        // Ensure the last message ends in the current tool use
 987        let last_message = self.pending_message();
 988        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
 989            if let AgentMessageContent::ToolUse(last_tool_use) = content {
 990                if last_tool_use.id == tool_use.id {
 991                    *last_tool_use = tool_use.clone();
 992                    false
 993                } else {
 994                    true
 995                }
 996            } else {
 997                true
 998            }
 999        });
1000
1001        if push_new_tool_use {
1002            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
1003            last_message
1004                .content
1005                .push(AgentMessageContent::ToolUse(tool_use.clone()));
1006        } else {
1007            event_stream.update_tool_call_fields(
1008                &tool_use.id,
1009                acp::ToolCallUpdateFields {
1010                    title: Some(title.into()),
1011                    kind: Some(kind),
1012                    raw_input: Some(tool_use.input.clone()),
1013                    ..Default::default()
1014                },
1015            );
1016        }
1017
1018        if !tool_use.is_input_complete {
1019            return None;
1020        }
1021
1022        let Some(tool) = tool else {
1023            let content = format!("No tool named {} exists", tool_use.name);
1024            return Some(Task::ready(LanguageModelToolResult {
1025                content: LanguageModelToolResultContent::Text(Arc::from(content)),
1026                tool_use_id: tool_use.id,
1027                tool_name: tool_use.name,
1028                is_error: true,
1029                output: None,
1030            }));
1031        };
1032
1033        let fs = self.project.read(cx).fs().clone();
1034        let tool_event_stream =
1035            ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1036        tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1037            status: Some(acp::ToolCallStatus::InProgress),
1038            ..Default::default()
1039        });
1040        let supports_images = self.model.supports_images();
1041        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1042        log::info!("Running tool {}", tool_use.name);
1043        Some(cx.foreground_executor().spawn(async move {
1044            let tool_result = tool_result.await.and_then(|output| {
1045                if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
1046                    if !supports_images {
1047                        return Err(anyhow!(
1048                            "Attempted to read an image, but this model doesn't support it.",
1049                        ));
1050                    }
1051                }
1052                Ok(output)
1053            });
1054
1055            match tool_result {
1056                Ok(output) => LanguageModelToolResult {
1057                    tool_use_id: tool_use.id,
1058                    tool_name: tool_use.name,
1059                    is_error: false,
1060                    content: output.llm_output,
1061                    output: Some(output.raw_output),
1062                },
1063                Err(error) => LanguageModelToolResult {
1064                    tool_use_id: tool_use.id,
1065                    tool_name: tool_use.name,
1066                    is_error: true,
1067                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1068                    output: None,
1069                },
1070            }
1071        }))
1072    }
1073
1074    fn handle_tool_use_json_parse_error_event(
1075        &mut self,
1076        tool_use_id: LanguageModelToolUseId,
1077        tool_name: Arc<str>,
1078        raw_input: Arc<str>,
1079        json_parse_error: String,
1080    ) -> LanguageModelToolResult {
1081        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1082        LanguageModelToolResult {
1083            tool_use_id,
1084            tool_name,
1085            is_error: true,
1086            content: LanguageModelToolResultContent::Text(tool_output.into()),
1087            output: Some(serde_json::Value::String(raw_input.to_string())),
1088        }
1089    }
1090
1091    fn pending_message(&mut self) -> &mut AgentMessage {
1092        self.pending_message.get_or_insert_default()
1093    }
1094
1095    fn flush_pending_message(&mut self) {
1096        let Some(mut message) = self.pending_message.take() else {
1097            return;
1098        };
1099
1100        for content in &message.content {
1101            let AgentMessageContent::ToolUse(tool_use) = content else {
1102                continue;
1103            };
1104
1105            if !message.tool_results.contains_key(&tool_use.id) {
1106                message.tool_results.insert(
1107                    tool_use.id.clone(),
1108                    LanguageModelToolResult {
1109                        tool_use_id: tool_use.id.clone(),
1110                        tool_name: tool_use.name.clone(),
1111                        is_error: true,
1112                        content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1113                        output: None,
1114                    },
1115                );
1116            }
1117        }
1118
1119        self.messages.push(Message::Agent(message));
1120    }
1121
1122    pub(crate) fn build_completion_request(
1123        &self,
1124        completion_intent: CompletionIntent,
1125        cx: &mut App,
1126    ) -> LanguageModelRequest {
1127        log::debug!("Building completion request");
1128        log::debug!("Completion intent: {:?}", completion_intent);
1129        log::debug!("Completion mode: {:?}", self.completion_mode);
1130
1131        let messages = self.build_request_messages();
1132        log::info!("Request will include {} messages", messages.len());
1133
1134        let tools = if let Some(tools) = self.tools(cx).log_err() {
1135            tools
1136                .filter_map(|tool| {
1137                    let tool_name = tool.name().to_string();
1138                    log::trace!("Including tool: {}", tool_name);
1139                    Some(LanguageModelRequestTool {
1140                        name: tool_name,
1141                        description: tool.description().to_string(),
1142                        input_schema: tool
1143                            .input_schema(self.model.tool_input_format())
1144                            .log_err()?,
1145                    })
1146                })
1147                .collect()
1148        } else {
1149            Vec::new()
1150        };
1151
1152        log::info!("Request includes {} tools", tools.len());
1153
1154        let request = LanguageModelRequest {
1155            thread_id: Some(self.id.to_string()),
1156            prompt_id: Some(self.prompt_id.to_string()),
1157            intent: Some(completion_intent),
1158            mode: Some(self.completion_mode.into()),
1159            messages,
1160            tools,
1161            tool_choice: None,
1162            stop: Vec::new(),
1163            temperature: AgentSettings::temperature_for_model(self.model(), cx),
1164            thinking_allowed: true,
1165        };
1166
1167        log::debug!("Completion request built successfully");
1168        request
1169    }
1170
    /// Returns the tools the model may call under this thread's profile.
    ///
    /// First come the thread's own tools whose `supported_provider` accepts the
    /// current model's provider and that the profile enables. They are followed
    /// by every context-server tool the profile enables for its server.
    ///
    /// # Errors
    ///
    /// Fails with "profile not found" when `self.profile_id` has no entry in
    /// the `AgentSettings` profiles.
    fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
        let profile = AgentSettings::get_global(cx)
            .profiles
            .get(&self.profile_id)
            .context("profile not found")?;
        let provider_id = self.model.provider_id();

        Ok(self
            .tools
            .iter()
            // Drop the thread's tools that this model's provider cannot use.
            .filter(move |(_, tool)| tool.supported_provider(&provider_id))
            // Keep only the ones the profile enables.
            .filter_map(|(tool_name, tool)| {
                if profile.is_tool_enabled(tool_name) {
                    Some(tool)
                } else {
                    None
                }
            })
            // Then append the enabled tools of every registered context server.
            .chain(self.context_server_registry.read(cx).servers().flat_map(
                |(server_id, tools)| {
                    tools.iter().filter_map(|(tool_name, tool)| {
                        if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
                            Some(tool)
                        } else {
                            None
                        }
                    })
                },
            )))
    }
1201
1202    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
1203        log::trace!(
1204            "Building request messages from {} thread messages",
1205            self.messages.len()
1206        );
1207        let mut messages = vec![self.build_system_message()];
1208        for message in &self.messages {
1209            match message {
1210                Message::User(message) => messages.push(message.to_request()),
1211                Message::Agent(message) => messages.extend(message.to_request()),
1212                Message::Resume => messages.push(LanguageModelRequestMessage {
1213                    role: Role::User,
1214                    content: vec!["Continue where you left off".into()],
1215                    cache: false,
1216                }),
1217            }
1218        }
1219
1220        if let Some(message) = self.pending_message.as_ref() {
1221            messages.extend(message.to_request());
1222        }
1223
1224        if let Some(last_user_message) = messages
1225            .iter_mut()
1226            .rev()
1227            .find(|message| message.role == Role::User)
1228        {
1229            last_user_message.cache = true;
1230        }
1231
1232        messages
1233    }
1234
1235    pub fn to_markdown(&self) -> String {
1236        let mut markdown = String::new();
1237        for (ix, message) in self.messages.iter().enumerate() {
1238            if ix > 0 {
1239                markdown.push('\n');
1240            }
1241            markdown.push_str(&message.to_markdown());
1242        }
1243
1244        if let Some(message) = self.pending_message.as_ref() {
1245            markdown.push('\n');
1246            markdown.push_str(&message.to_markdown());
1247        }
1248
1249        markdown
1250    }
1251
1252    fn advance_prompt_id(&mut self) {
1253        self.prompt_id = PromptId::new();
1254    }
1255}
1256
/// State for the turn currently running on a thread.
struct RunningTurn {
    /// Holds the task that handles agent interaction until the end of the turn.
    /// It survives across multiple requests while the model makes tool calls and
    /// we run those tools and report their results.
    _task: Task<()>,
    /// The event stream for the running turn. Used to report a final
    /// cancellation event if the turn is canceled.
    event_stream: ThreadEventStream,
}
1266
1267impl RunningTurn {
1268    fn cancel(self) {
1269        log::debug!("Cancelling in progress turn");
1270        self.event_stream.send_canceled();
1271    }
1272}
1273
1274pub trait AgentTool
1275where
1276    Self: 'static + Sized,
1277{
1278    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1279    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1280
1281    fn name(&self) -> SharedString;
1282
1283    fn description(&self) -> SharedString {
1284        let schema = schemars::schema_for!(Self::Input);
1285        SharedString::new(
1286            schema
1287                .get("description")
1288                .and_then(|description| description.as_str())
1289                .unwrap_or_default(),
1290        )
1291    }
1292
1293    fn kind(&self) -> acp::ToolKind;
1294
1295    /// The initial tool title to display. Can be updated during the tool run.
1296    fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1297
1298    /// Returns the JSON schema that describes the tool's input.
1299    fn input_schema(&self) -> Schema {
1300        schemars::schema_for!(Self::Input)
1301    }
1302
1303    /// Some tools rely on a provider for the underlying billing or other reasons.
1304    /// Allow the tool to check if they are compatible, or should be filtered out.
1305    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1306        true
1307    }
1308
1309    /// Runs the tool with the provided input.
1310    fn run(
1311        self: Arc<Self>,
1312        input: Self::Input,
1313        event_stream: ToolCallEventStream,
1314        cx: &mut App,
1315    ) -> Task<Result<Self::Output>>;
1316
1317    /// Emits events for a previous execution of the tool.
1318    fn replay(
1319        &self,
1320        _input: Self::Input,
1321        _output: Self::Output,
1322        _event_stream: ToolCallEventStream,
1323        _cx: &mut App,
1324    ) -> Result<()> {
1325        Ok(())
1326    }
1327
1328    fn erase(self) -> Arc<dyn AnyAgentTool> {
1329        Arc::new(Erased(Arc::new(self)))
1330    }
1331}
1332
/// Type-erasing wrapper: `Erased<Arc<T>>` adapts a typed [`AgentTool`] `T` to
/// the JSON-based [`AnyAgentTool`] trait (see [`AgentTool::erase`]).
pub struct Erased<T>(T);
1334
/// Output of a type-erased tool run.
pub struct AgentToolOutput {
    /// Content returned to the model as the tool result.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's typed output serialized to JSON. It is reported as the tool
    /// call's raw output.
    pub raw_output: serde_json::Value,
}
1339
/// Object-safe, JSON-based interface to an agent tool.
///
/// It lets tools with different `Input`/`Output` types be stored and invoked
/// as `Arc<dyn AnyAgentTool>`. Typed [`AgentTool`]s are adapted to it with
/// [`AgentTool::erase`].
pub trait AnyAgentTool {
    /// The tool's name as presented to the model.
    fn name(&self) -> SharedString;
    /// Description presented to the model.
    fn description(&self) -> SharedString;
    /// The ACP tool kind shown for calls to this tool.
    fn kind(&self) -> acp::ToolKind;
    /// Title for a tool call, given its (possibly still incomplete) JSON input.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// JSON schema of the tool's input, adapted to `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether the tool may be offered with models from `_provider`. Defaults to `true`.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool on JSON input. Resolves to the model-facing content plus
    /// the raw JSON output.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
    /// Re-emits events for a previous execution, from its recorded JSON input
    /// and output.
    fn replay(
        &self,
        input: serde_json::Value,
        output: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Result<()>;
}
1363
1364impl<T> AnyAgentTool for Erased<Arc<T>>
1365where
1366    T: AgentTool,
1367{
1368    fn name(&self) -> SharedString {
1369        self.0.name()
1370    }
1371
1372    fn description(&self) -> SharedString {
1373        self.0.description()
1374    }
1375
1376    fn kind(&self) -> agent_client_protocol::ToolKind {
1377        self.0.kind()
1378    }
1379
1380    fn initial_title(&self, input: serde_json::Value) -> SharedString {
1381        let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1382        self.0.initial_title(parsed_input)
1383    }
1384
1385    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1386        let mut json = serde_json::to_value(self.0.input_schema())?;
1387        adapt_schema_to_format(&mut json, format)?;
1388        Ok(json)
1389    }
1390
1391    fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1392        self.0.supported_provider(provider)
1393    }
1394
1395    fn run(
1396        self: Arc<Self>,
1397        input: serde_json::Value,
1398        event_stream: ToolCallEventStream,
1399        cx: &mut App,
1400    ) -> Task<Result<AgentToolOutput>> {
1401        cx.spawn(async move |cx| {
1402            let input = serde_json::from_value(input)?;
1403            let output = cx
1404                .update(|cx| self.0.clone().run(input, event_stream, cx))?
1405                .await?;
1406            let raw_output = serde_json::to_value(&output)?;
1407            Ok(AgentToolOutput {
1408                llm_output: output.into(),
1409                raw_output,
1410            })
1411        })
1412    }
1413
1414    fn replay(
1415        &self,
1416        input: serde_json::Value,
1417        output: serde_json::Value,
1418        event_stream: ToolCallEventStream,
1419        cx: &mut App,
1420    ) -> Result<()> {
1421        let input = serde_json::from_value(input)?;
1422        let output = serde_json::from_value(output)?;
1423        self.0.replay(input, output, event_stream, cx)
1424    }
1425}
1426
/// Sending half of a thread's unbounded event channel.
///
/// Successful events are sent as `Ok(ThreadEvent)`, and turn errors as `Err`.
/// Clones send into the same channel.
#[derive(Clone)]
struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
1429
1430impl ThreadEventStream {
1431    fn send_user_message(&self, message: &UserMessage) {
1432        self.0
1433            .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
1434            .ok();
1435    }
1436
1437    fn send_text(&self, text: &str) {
1438        self.0
1439            .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
1440            .ok();
1441    }
1442
1443    fn send_thinking(&self, text: &str) {
1444        self.0
1445            .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
1446            .ok();
1447    }
1448
1449    fn send_tool_call(
1450        &self,
1451        id: &LanguageModelToolUseId,
1452        title: SharedString,
1453        kind: acp::ToolKind,
1454        input: serde_json::Value,
1455    ) {
1456        self.0
1457            .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
1458                id,
1459                title.to_string(),
1460                kind,
1461                input,
1462            ))))
1463            .ok();
1464    }
1465
1466    fn initial_tool_call(
1467        id: &LanguageModelToolUseId,
1468        title: String,
1469        kind: acp::ToolKind,
1470        input: serde_json::Value,
1471    ) -> acp::ToolCall {
1472        acp::ToolCall {
1473            id: acp::ToolCallId(id.to_string().into()),
1474            title,
1475            kind,
1476            status: acp::ToolCallStatus::Pending,
1477            content: vec![],
1478            locations: vec![],
1479            raw_input: Some(input),
1480            raw_output: None,
1481        }
1482    }
1483
1484    fn update_tool_call_fields(
1485        &self,
1486        tool_use_id: &LanguageModelToolUseId,
1487        fields: acp::ToolCallUpdateFields,
1488    ) {
1489        self.0
1490            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1491                acp::ToolCallUpdate {
1492                    id: acp::ToolCallId(tool_use_id.to_string().into()),
1493                    fields,
1494                }
1495                .into(),
1496            )))
1497            .ok();
1498    }
1499
1500    fn send_stop(&self, reason: StopReason) {
1501        match reason {
1502            StopReason::EndTurn => {
1503                self.0
1504                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
1505                    .ok();
1506            }
1507            StopReason::MaxTokens => {
1508                self.0
1509                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
1510                    .ok();
1511            }
1512            StopReason::Refusal => {
1513                self.0
1514                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
1515                    .ok();
1516            }
1517            StopReason::ToolUse => {}
1518        }
1519    }
1520
1521    fn send_canceled(&self) {
1522        self.0
1523            .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
1524            .ok();
1525    }
1526
1527    fn send_error(&self, error: impl Into<anyhow::Error>) {
1528        self.0.unbounded_send(Err(error.into())).ok();
1529    }
1530}
1531
/// Event stream scoped to a single tool call.
///
/// It is handed to a running tool so the tool can report field updates, diffs,
/// and terminals, and can request authorization.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// The model's id for the tool use this stream reports on.
    tool_use_id: LanguageModelToolUseId,
    /// Thread-level stream that events are forwarded to.
    stream: ThreadEventStream,
    /// Filesystem used to persist an "Always Allow" choice. It may be `None`
    /// (as in `ToolCallEventStream::test`), in which case the choice is not
    /// persisted.
    fs: Option<Arc<dyn Fs>>,
}
1538
1539impl ToolCallEventStream {
1540    #[cfg(test)]
1541    pub fn test() -> (Self, ToolCallEventStreamReceiver) {
1542        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
1543
1544        let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
1545
1546        (stream, ToolCallEventStreamReceiver(events_rx))
1547    }
1548
1549    fn new(
1550        tool_use_id: LanguageModelToolUseId,
1551        stream: ThreadEventStream,
1552        fs: Option<Arc<dyn Fs>>,
1553    ) -> Self {
1554        Self {
1555            tool_use_id,
1556            stream,
1557            fs,
1558        }
1559    }
1560
1561    pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1562        self.stream
1563            .update_tool_call_fields(&self.tool_use_id, fields);
1564    }
1565
1566    pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1567        self.stream
1568            .0
1569            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1570                acp_thread::ToolCallUpdateDiff {
1571                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1572                    diff,
1573                }
1574                .into(),
1575            )))
1576            .ok();
1577    }
1578
1579    pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1580        self.stream
1581            .0
1582            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1583                acp_thread::ToolCallUpdateTerminal {
1584                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1585                    terminal,
1586                }
1587                .into(),
1588            )))
1589            .ok();
1590    }
1591
1592    pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1593        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1594            return Task::ready(Ok(()));
1595        }
1596
1597        let (response_tx, response_rx) = oneshot::channel();
1598        self.stream
1599            .0
1600            .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
1601                ToolCallAuthorization {
1602                    tool_call: acp::ToolCallUpdate {
1603                        id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1604                        fields: acp::ToolCallUpdateFields {
1605                            title: Some(title.into()),
1606                            ..Default::default()
1607                        },
1608                    },
1609                    options: vec![
1610                        acp::PermissionOption {
1611                            id: acp::PermissionOptionId("always_allow".into()),
1612                            name: "Always Allow".into(),
1613                            kind: acp::PermissionOptionKind::AllowAlways,
1614                        },
1615                        acp::PermissionOption {
1616                            id: acp::PermissionOptionId("allow".into()),
1617                            name: "Allow".into(),
1618                            kind: acp::PermissionOptionKind::AllowOnce,
1619                        },
1620                        acp::PermissionOption {
1621                            id: acp::PermissionOptionId("deny".into()),
1622                            name: "Deny".into(),
1623                            kind: acp::PermissionOptionKind::RejectOnce,
1624                        },
1625                    ],
1626                    response: response_tx,
1627                },
1628            )))
1629            .ok();
1630        let fs = self.fs.clone();
1631        cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1632            "always_allow" => {
1633                if let Some(fs) = fs.clone() {
1634                    cx.update(|cx| {
1635                        update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1636                            settings.set_always_allow_tool_actions(true);
1637                        });
1638                    })?;
1639                }
1640
1641                Ok(())
1642            }
1643            "allow" => Ok(()),
1644            _ => Err(anyhow!("Permission to run tool denied by user")),
1645        })
1646    }
1647}
1648
/// Receiving end of a [`ToolCallEventStream::test`] channel, used by tests to
/// assert on the events a tool emits.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
1651
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it as a tool call authorization.
    ///
    /// # Panics
    ///
    /// Panics if the next event is anything else.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) => auth,
            other => panic!("Expected ToolCallAuthorization but got: {:?}", other),
        }
    }

    /// Awaits the next event and returns the terminal from a terminal update.
    ///
    /// # Panics
    ///
    /// Panics if the next event is anything else.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            other => panic!("Expected terminal but got: {:?}", other),
        }
    }
}
1675
/// Exposes the wrapped receiver directly, so tests can use its API through the wrapper.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;

    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1684
/// Mutable access to the wrapped receiver (e.g. for polling it directly in tests).
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1691
1692impl From<&str> for UserMessageContent {
1693    fn from(text: &str) -> Self {
1694        Self::Text(text.into())
1695    }
1696}
1697
1698impl From<acp::ContentBlock> for UserMessageContent {
1699    fn from(value: acp::ContentBlock) -> Self {
1700        match value {
1701            acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
1702            acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
1703            acp::ContentBlock::Audio(_) => {
1704                // TODO
1705                Self::Text("[audio]".to_string())
1706            }
1707            acp::ContentBlock::ResourceLink(resource_link) => {
1708                match MentionUri::parse(&resource_link.uri) {
1709                    Ok(uri) => Self::Mention {
1710                        uri,
1711                        content: String::new(),
1712                    },
1713                    Err(err) => {
1714                        log::error!("Failed to parse mention link: {}", err);
1715                        Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
1716                    }
1717                }
1718            }
1719            acp::ContentBlock::Resource(resource) => match resource.resource {
1720                acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1721                    match MentionUri::parse(&resource.uri) {
1722                        Ok(uri) => Self::Mention {
1723                            uri,
1724                            content: resource.text,
1725                        },
1726                        Err(err) => {
1727                            log::error!("Failed to parse mention link: {}", err);
1728                            Self::Text(
1729                                MarkdownCodeBlock {
1730                                    tag: &resource.uri,
1731                                    text: &resource.text,
1732                                }
1733                                .to_string(),
1734                            )
1735                        }
1736                    }
1737                }
1738                acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1739                    // TODO
1740                    Self::Text("[blob]".to_string())
1741                }
1742            },
1743        }
1744    }
1745}
1746
1747fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1748    LanguageModelImage {
1749        source: image_content.data.into(),
1750        // TODO: make this optional?
1751        size: gpui::Size::new(0.into(), 0.into()),
1752    }
1753}