thread.rs

   1use crate::{ContextServerRegistry, DbThread, SystemPromptTemplate, Template, Templates};
   2use acp_thread::{MentionUri, UserMessageId};
   3use action_log::ActionLog;
   4use agent_client_protocol as acp;
   5use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
   6use anyhow::{Context as _, Result, anyhow};
   7use assistant_tool::adapt_schema_to_format;
   8use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
   9use collections::IndexMap;
  10use fs::Fs;
  11use futures::{
  12    channel::{mpsc, oneshot},
  13    stream::FuturesUnordered,
  14};
  15use gpui::{App, AppContext, Context, Entity, SharedString, Task};
  16use language_model::{
  17    LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
  18    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  19    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
  20    LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
  21};
  22use project::Project;
  23use prompt_store::ProjectContext;
  24use schemars::{JsonSchema, Schema};
  25use serde::{Deserialize, Serialize};
  26use settings::{Settings, update_settings_file};
  27use smol::stream::StreamExt;
  28use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
  29use std::{fmt::Write, ops::Range};
  30use util::{ResultExt, markdown::MarkdownCodeBlock};
  31use uuid::Uuid;
  32
/// Message text describing a tool call that was canceled by the user.
// NOTE(review): no use of this constant is visible in this part of the file — confirm where it
// is reported before changing the wording.
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
  34
/// Identifier of a thread, stored as a shared string.
///
/// New identifiers are random UUIDs (see `ThreadId::new`); the type also converts to and from
/// `acp::SessionId` by moving the inner string across unchanged.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(pub(crate) Arc<str>);
  39
  40impl ThreadId {
  41    pub fn new() -> Self {
  42        Self(Uuid::new_v4().to_string().into())
  43    }
  44}
  45
  46impl std::fmt::Display for ThreadId {
  47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  48        write!(f, "{}", self.0)
  49    }
  50}
  51
  52impl From<&str> for ThreadId {
  53    fn from(value: &str) -> Self {
  54        Self(value.into())
  55    }
  56}
  57
  58impl From<acp::SessionId> for ThreadId {
  59    fn from(value: acp::SessionId) -> Self {
  60        Self(value.0)
  61    }
  62}
  63
  64impl From<ThreadId> for acp::SessionId {
  65    fn from(value: ThreadId) -> Self {
  66        Self(value.0)
  67    }
  68}
  69
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
///
/// The value is an opaque shared string; `PromptId::new` fills it with a random UUID.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
  75
  76impl PromptId {
  77    pub fn new() -> Self {
  78        Self(Uuid::new_v4().to_string().into())
  79    }
  80}
  81
  82impl std::fmt::Display for PromptId {
  83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  84        write!(f, "{}", self.0)
  85    }
  86}
  87
/// One entry in a thread's message history.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Message {
    /// A message submitted by the user.
    User(UserMessage),
    /// A message produced by the agent, including the results of its tool calls.
    Agent(AgentMessage),
    /// Marker pushed by `Thread::resume` when a turn continues after the tool use limit
    /// was reached.
    Resume,
}
  94
  95impl Message {
  96    pub fn as_agent_message(&self) -> Option<&AgentMessage> {
  97        match self {
  98            Message::Agent(agent_message) => Some(agent_message),
  99            _ => None,
 100        }
 101    }
 102
 103    pub fn to_markdown(&self) -> String {
 104        match self {
 105            Message::User(message) => message.to_markdown(),
 106            Message::Agent(message) => message.to_markdown(),
 107            Message::Resume => "[resumed after tool use limit was reached]".into(),
 108        }
 109    }
 110}
 111
/// A message submitted by the user.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessage {
    /// Identifier of the message; `Thread::truncate` uses it to locate the message.
    pub id: UserMessageId,
    /// The parts of the message, in the order they are rendered and sent.
    pub content: Vec<UserMessageContent>,
}
 117
/// A single part of a [`UserMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UserMessageContent {
    /// Plain text.
    Text(String),
    /// A mention of some resource (`uri`) together with the text captured for it (`content`).
    Mention { uri: MentionUri, content: String },
    /// An image.
    Image(LanguageModelImage),
}
 124
 125impl UserMessage {
 126    pub fn to_markdown(&self) -> String {
 127        let mut markdown = String::from("## User\n\n");
 128
 129        for content in &self.content {
 130            match content {
 131                UserMessageContent::Text(text) => {
 132                    markdown.push_str(text);
 133                    markdown.push('\n');
 134                }
 135                UserMessageContent::Image(_) => {
 136                    markdown.push_str("<image />\n");
 137                }
 138                UserMessageContent::Mention { uri, content } => {
 139                    if !content.is_empty() {
 140                        let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
 141                    } else {
 142                        let _ = write!(&mut markdown, "{}\n", uri.as_link());
 143                    }
 144                }
 145            }
 146        }
 147
 148        markdown
 149    }
 150
 151    fn to_request(&self) -> LanguageModelRequestMessage {
 152        let mut message = LanguageModelRequestMessage {
 153            role: Role::User,
 154            content: Vec::with_capacity(self.content.len()),
 155            cache: false,
 156        };
 157
 158        const OPEN_CONTEXT: &str = "<context>\n\
 159            The following items were attached by the user. \
 160            They are up-to-date and don't need to be re-read.\n\n";
 161
 162        const OPEN_FILES_TAG: &str = "<files>";
 163        const OPEN_SYMBOLS_TAG: &str = "<symbols>";
 164        const OPEN_THREADS_TAG: &str = "<threads>";
 165        const OPEN_FETCH_TAG: &str = "<fetched_urls>";
 166        const OPEN_RULES_TAG: &str =
 167            "<rules>\nThe user has specified the following rules that should be applied:\n";
 168
 169        let mut file_context = OPEN_FILES_TAG.to_string();
 170        let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
 171        let mut thread_context = OPEN_THREADS_TAG.to_string();
 172        let mut fetch_context = OPEN_FETCH_TAG.to_string();
 173        let mut rules_context = OPEN_RULES_TAG.to_string();
 174
 175        for chunk in &self.content {
 176            let chunk = match chunk {
 177                UserMessageContent::Text(text) => {
 178                    language_model::MessageContent::Text(text.clone())
 179                }
 180                UserMessageContent::Image(value) => {
 181                    language_model::MessageContent::Image(value.clone())
 182                }
 183                UserMessageContent::Mention { uri, content } => {
 184                    match uri {
 185                        MentionUri::File { abs_path, .. } => {
 186                            write!(
 187                                &mut symbol_context,
 188                                "\n{}",
 189                                MarkdownCodeBlock {
 190                                    tag: &codeblock_tag(&abs_path, None),
 191                                    text: &content.to_string(),
 192                                }
 193                            )
 194                            .ok();
 195                        }
 196                        MentionUri::Symbol {
 197                            path, line_range, ..
 198                        }
 199                        | MentionUri::Selection {
 200                            path, line_range, ..
 201                        } => {
 202                            write!(
 203                                &mut rules_context,
 204                                "\n{}",
 205                                MarkdownCodeBlock {
 206                                    tag: &codeblock_tag(&path, Some(line_range)),
 207                                    text: &content
 208                                }
 209                            )
 210                            .ok();
 211                        }
 212                        MentionUri::Thread { .. } => {
 213                            write!(&mut thread_context, "\n{}\n", content).ok();
 214                        }
 215                        MentionUri::TextThread { .. } => {
 216                            write!(&mut thread_context, "\n{}\n", content).ok();
 217                        }
 218                        MentionUri::Rule { .. } => {
 219                            write!(
 220                                &mut rules_context,
 221                                "\n{}",
 222                                MarkdownCodeBlock {
 223                                    tag: "",
 224                                    text: &content
 225                                }
 226                            )
 227                            .ok();
 228                        }
 229                        MentionUri::Fetch { url } => {
 230                            write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
 231                        }
 232                    }
 233
 234                    language_model::MessageContent::Text(uri.as_link().to_string())
 235                }
 236            };
 237
 238            message.content.push(chunk);
 239        }
 240
 241        let len_before_context = message.content.len();
 242
 243        if file_context.len() > OPEN_FILES_TAG.len() {
 244            file_context.push_str("</files>\n");
 245            message
 246                .content
 247                .push(language_model::MessageContent::Text(file_context));
 248        }
 249
 250        if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
 251            symbol_context.push_str("</symbols>\n");
 252            message
 253                .content
 254                .push(language_model::MessageContent::Text(symbol_context));
 255        }
 256
 257        if thread_context.len() > OPEN_THREADS_TAG.len() {
 258            thread_context.push_str("</threads>\n");
 259            message
 260                .content
 261                .push(language_model::MessageContent::Text(thread_context));
 262        }
 263
 264        if fetch_context.len() > OPEN_FETCH_TAG.len() {
 265            fetch_context.push_str("</fetched_urls>\n");
 266            message
 267                .content
 268                .push(language_model::MessageContent::Text(fetch_context));
 269        }
 270
 271        if rules_context.len() > OPEN_RULES_TAG.len() {
 272            rules_context.push_str("</user_rules>\n");
 273            message
 274                .content
 275                .push(language_model::MessageContent::Text(rules_context));
 276        }
 277
 278        if message.content.len() > len_before_context {
 279            message.content.insert(
 280                len_before_context,
 281                language_model::MessageContent::Text(OPEN_CONTEXT.into()),
 282            );
 283            message
 284                .content
 285                .push(language_model::MessageContent::Text("</context>".into()));
 286        }
 287
 288        message
 289    }
 290}
 291
 292fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
 293    let mut result = String::new();
 294
 295    if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) {
 296        let _ = write!(result, "{} ", extension);
 297    }
 298
 299    let _ = write!(result, "{}", full_path.display());
 300
 301    if let Some(range) = line_range {
 302        if range.start == range.end {
 303            let _ = write!(result, ":{}", range.start + 1);
 304        } else {
 305            let _ = write!(result, ":{}-{}", range.start + 1, range.end + 1);
 306        }
 307    }
 308
 309    result
 310}
 311
 312impl AgentMessage {
 313    pub fn to_markdown(&self) -> String {
 314        let mut markdown = String::from("## Assistant\n\n");
 315
 316        for content in &self.content {
 317            match content {
 318                AgentMessageContent::Text(text) => {
 319                    markdown.push_str(text);
 320                    markdown.push('\n');
 321                }
 322                AgentMessageContent::Thinking { text, .. } => {
 323                    markdown.push_str("<think>");
 324                    markdown.push_str(text);
 325                    markdown.push_str("</think>\n");
 326                }
 327                AgentMessageContent::RedactedThinking(_) => {
 328                    markdown.push_str("<redacted_thinking />\n")
 329                }
 330                AgentMessageContent::ToolUse(tool_use) => {
 331                    markdown.push_str(&format!(
 332                        "**Tool Use**: {} (ID: {})\n",
 333                        tool_use.name, tool_use.id
 334                    ));
 335                    markdown.push_str(&format!(
 336                        "{}\n",
 337                        MarkdownCodeBlock {
 338                            tag: "json",
 339                            text: &format!("{:#}", tool_use.input)
 340                        }
 341                    ));
 342                }
 343            }
 344        }
 345
 346        for tool_result in self.tool_results.values() {
 347            markdown.push_str(&format!(
 348                "**Tool Result**: {} (ID: {})\n\n",
 349                tool_result.tool_name, tool_result.tool_use_id
 350            ));
 351            if tool_result.is_error {
 352                markdown.push_str("**ERROR:**\n");
 353            }
 354
 355            match &tool_result.content {
 356                LanguageModelToolResultContent::Text(text) => {
 357                    writeln!(markdown, "{text}\n").ok();
 358                }
 359                LanguageModelToolResultContent::Image(_) => {
 360                    writeln!(markdown, "<image />\n").ok();
 361                }
 362            }
 363
 364            if let Some(output) = tool_result.output.as_ref() {
 365                writeln!(
 366                    markdown,
 367                    "**Debug Output**:\n\n```json\n{}\n```\n",
 368                    serde_json::to_string_pretty(output).unwrap()
 369                )
 370                .unwrap();
 371            }
 372        }
 373
 374        markdown
 375    }
 376
 377    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
 378        let mut assistant_message = LanguageModelRequestMessage {
 379            role: Role::Assistant,
 380            content: Vec::with_capacity(self.content.len()),
 381            cache: false,
 382        };
 383        for chunk in &self.content {
 384            let chunk = match chunk {
 385                AgentMessageContent::Text(text) => {
 386                    language_model::MessageContent::Text(text.clone())
 387                }
 388                AgentMessageContent::Thinking { text, signature } => {
 389                    language_model::MessageContent::Thinking {
 390                        text: text.clone(),
 391                        signature: signature.clone(),
 392                    }
 393                }
 394                AgentMessageContent::RedactedThinking(value) => {
 395                    language_model::MessageContent::RedactedThinking(value.clone())
 396                }
 397                AgentMessageContent::ToolUse(value) => {
 398                    language_model::MessageContent::ToolUse(value.clone())
 399                }
 400            };
 401            assistant_message.content.push(chunk);
 402        }
 403
 404        let mut user_message = LanguageModelRequestMessage {
 405            role: Role::User,
 406            content: Vec::new(),
 407            cache: false,
 408        };
 409
 410        for tool_result in self.tool_results.values() {
 411            user_message
 412                .content
 413                .push(language_model::MessageContent::ToolResult(
 414                    tool_result.clone(),
 415                ));
 416        }
 417
 418        let mut messages = Vec::new();
 419        if !assistant_message.content.is_empty() {
 420            messages.push(assistant_message);
 421        }
 422        if !user_message.content.is_empty() {
 423            messages.push(user_message);
 424        }
 425        messages
 426    }
 427}
 428
/// A message produced by the agent, together with the results of the tool calls it made.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AgentMessage {
    /// Content chunks, in the order they are rendered and sent back to the model.
    pub content: Vec<AgentMessageContent>,
    /// Tool results keyed by the id of the tool use they belong to. `to_request` sends them
    /// in a separate `Role::User` message after the assistant content.
    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
}
 434
/// A single content chunk of an [`AgentMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AgentMessageContent {
    /// Plain response text.
    Text(String),
    /// Thinking text, rendered inside `<think>` tags in markdown.
    Thinking {
        text: String,
        // Passed back to the model unchanged when the message is converted to a request.
        signature: Option<String>,
    },
    /// Thinking data that is not displayed; rendered as `<redacted_thinking />` in markdown.
    RedactedThinking(String),
    /// A tool invocation requested by the model.
    ToolUse(LanguageModelToolUse),
}
 445
/// Events delivered over the channels returned by `Thread::replay`, `Thread::resume` and
/// `Thread::send`.
#[derive(Debug)]
pub enum ThreadEvent {
    /// A user message.
    UserMessage(UserMessage),
    /// Text produced by the agent.
    AgentText(String),
    /// Thinking text produced by the agent.
    AgentThinking(String),
    /// A tool call, reported with its initial state.
    ToolCall(acp::ToolCall),
    /// An update to a previously reported tool call.
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A set of permission options for a tool call, with a channel for the chosen option.
    ToolCallAuthorization(ToolCallAuthorization),
    /// The model stopped, with the given reason.
    Stop(acp::StopReason),
}
 456
/// Payload of `ThreadEvent::ToolCallAuthorization`.
// NOTE(review): presumably sent to ask the receiver to authorize a tool call — the sending
// side is not visible in this part of the file.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call the options apply to.
    pub tool_call: acp::ToolCallUpdate,
    /// The permission options to choose from.
    pub options: Vec<acp::PermissionOption>,
    /// One-shot channel on which the id of the chosen option is sent back.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
 463
/// A conversation between the user and the agent: the message history plus everything needed
/// to run a turn against the model.
pub struct Thread {
    id: ThreadId,
    prompt_id: PromptId,
    /// The message history, oldest first.
    messages: Vec<Message>,
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<RunningTurn>,
    /// The agent message currently being built; tool results are recorded on it while a turn
    /// runs, and `last_message` reports it in preference to the stored history.
    pending_message: Option<AgentMessage>,
    /// Registered tools, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Set when a turn ended because the tool use limit was reached; `resume` requires it.
    tool_use_limit_reached: bool,
    context_server_registry: Entity<ContextServerRegistry>,
    profile_id: AgentProfileId,
    project_context: Rc<RefCell<ProjectContext>>,
    templates: Arc<Templates>,
    /// The model that completion requests are streamed from.
    model: Arc<dyn LanguageModel>,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
}
 484
 485impl Thread {
 486    pub fn new(
 487        project: Entity<Project>,
 488        project_context: Rc<RefCell<ProjectContext>>,
 489        context_server_registry: Entity<ContextServerRegistry>,
 490        action_log: Entity<ActionLog>,
 491        templates: Arc<Templates>,
 492        model: Arc<dyn LanguageModel>,
 493        cx: &mut Context<Self>,
 494    ) -> Self {
 495        let profile_id = AgentSettings::get_global(cx).default_profile.clone();
 496        Self {
 497            id: ThreadId::new(),
 498            prompt_id: PromptId::new(),
 499            messages: Vec::new(),
 500            completion_mode: CompletionMode::Normal,
 501            running_turn: None,
 502            pending_message: None,
 503            tools: BTreeMap::default(),
 504            tool_use_limit_reached: false,
 505            context_server_registry,
 506            profile_id,
 507            project_context,
 508            templates,
 509            model,
 510            project,
 511            action_log,
 512        }
 513    }
 514
 515    pub fn from_db(
 516        id: ThreadId,
 517        db_thread: DbThread,
 518        project: Entity<Project>,
 519        project_context: Rc<RefCell<ProjectContext>>,
 520        context_server_registry: Entity<ContextServerRegistry>,
 521        action_log: Entity<ActionLog>,
 522        templates: Arc<Templates>,
 523        model: Arc<dyn LanguageModel>,
 524        cx: &mut Context<Self>,
 525    ) -> Self {
 526        let profile_id = db_thread
 527            .profile
 528            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
 529        Self {
 530            id,
 531            prompt_id: PromptId::new(),
 532            messages: db_thread.messages,
 533            completion_mode: CompletionMode::Normal,
 534            running_turn: None,
 535            pending_message: None,
 536            tools: BTreeMap::default(),
 537            tool_use_limit_reached: false,
 538            context_server_registry,
 539            profile_id,
 540            project_context,
 541            templates,
 542            model,
 543            project,
 544            action_log,
 545        }
 546    }
 547
 548    pub fn replay(
 549        &mut self,
 550        cx: &mut Context<Self>,
 551    ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
 552        let (tx, rx) = mpsc::unbounded();
 553        let stream = ThreadEventStream(tx);
 554        for message in &self.messages {
 555            match message {
 556                Message::User(user_message) => stream.send_user_message(&user_message),
 557                Message::Agent(assistant_message) => {
 558                    for content in &assistant_message.content {
 559                        match content {
 560                            AgentMessageContent::Text(text) => stream.send_text(text),
 561                            AgentMessageContent::Thinking { text, .. } => {
 562                                stream.send_thinking(text)
 563                            }
 564                            AgentMessageContent::RedactedThinking(_) => {}
 565                            AgentMessageContent::ToolUse(tool_use) => {
 566                                self.replay_tool_call(
 567                                    tool_use,
 568                                    assistant_message.tool_results.get(&tool_use.id),
 569                                    &stream,
 570                                    cx,
 571                                );
 572                            }
 573                        }
 574                    }
 575                }
 576                Message::Resume => {}
 577            }
 578        }
 579        rx
 580    }
 581
 582    fn replay_tool_call(
 583        &self,
 584        tool_use: &LanguageModelToolUse,
 585        tool_result: Option<&LanguageModelToolResult>,
 586        stream: &ThreadEventStream,
 587        cx: &mut Context<Self>,
 588    ) {
 589        let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
 590            stream
 591                .0
 592                .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
 593                    id: acp::ToolCallId(tool_use.id.to_string().into()),
 594                    title: tool_use.name.to_string(),
 595                    kind: acp::ToolKind::Other,
 596                    status: acp::ToolCallStatus::Failed,
 597                    content: Vec::new(),
 598                    locations: Vec::new(),
 599                    raw_input: Some(tool_use.input.clone()),
 600                    raw_output: None,
 601                })))
 602                .ok();
 603            return;
 604        };
 605
 606        let title = tool.initial_title(tool_use.input.clone());
 607        let kind = tool.kind();
 608        stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
 609
 610        let output = tool_result
 611            .as_ref()
 612            .and_then(|result| result.output.clone());
 613        if let Some(output) = output.clone() {
 614            let tool_event_stream = ToolCallEventStream::new(
 615                tool_use.id.clone(),
 616                stream.clone(),
 617                Some(self.project.read(cx).fs().clone()),
 618            );
 619            tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
 620                .log_err();
 621        }
 622
 623        stream.update_tool_call_fields(
 624            &tool_use.id,
 625            acp::ToolCallUpdateFields {
 626                status: Some(acp::ToolCallStatus::Completed),
 627                raw_output: output,
 628                ..Default::default()
 629            },
 630        );
 631    }
 632
    /// Returns the project entity this thread was constructed with.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
 636
    /// Returns the action log entity this thread was constructed with.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
 640
    /// Returns the language model currently set on this thread.
    pub fn model(&self) -> &Arc<dyn LanguageModel> {
        &self.model
    }
 644
    /// Replaces the language model. A turn clones the model when it starts, so the new model
    /// is picked up by turns started after this call.
    pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
        self.model = model;
    }
 648
    /// Returns the current completion mode.
    pub fn completion_mode(&self) -> CompletionMode {
        self.completion_mode
    }
 652
    /// Sets the completion mode.
    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
        self.completion_mode = mode;
    }
 656
 657    #[cfg(any(test, feature = "test-support"))]
 658    pub fn last_message(&self) -> Option<Message> {
 659        if let Some(message) = self.pending_message.clone() {
 660            Some(Message::Agent(message))
 661        } else {
 662            self.messages.last().cloned()
 663        }
 664    }
 665
 666    pub fn add_tool(&mut self, tool: impl AgentTool) {
 667        self.tools.insert(tool.name(), tool.erase());
 668    }
 669
 670    pub fn remove_tool(&mut self, name: &str) -> bool {
 671        self.tools.remove(name).is_some()
 672    }
 673
    /// Returns the id of the agent profile currently set on this thread.
    pub fn profile(&self) -> &AgentProfileId {
        &self.profile_id
    }
 677
    /// Sets the id of the agent profile used by this thread.
    pub fn set_profile(&mut self, profile_id: AgentProfileId) {
        self.profile_id = profile_id;
    }
 681
 682    pub fn cancel(&mut self) {
 683        if let Some(running_turn) = self.running_turn.take() {
 684            running_turn.cancel();
 685        }
 686        self.flush_pending_message();
 687    }
 688
 689    pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
 690        self.cancel();
 691        let Some(position) = self.messages.iter().position(
 692            |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
 693        ) else {
 694            return Err(anyhow!("Message not found"));
 695        };
 696        self.messages.truncate(position);
 697        Ok(())
 698    }
 699
 700    pub fn resume(
 701        &mut self,
 702        cx: &mut Context<Self>,
 703    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
 704        anyhow::ensure!(
 705            self.tool_use_limit_reached,
 706            "can only resume after tool use limit is reached"
 707        );
 708
 709        self.messages.push(Message::Resume);
 710        cx.notify();
 711
 712        log::info!("Total messages in thread: {}", self.messages.len());
 713        Ok(self.run_turn(cx))
 714    }
 715
 716    /// Sending a message results in the model streaming a response, which could include tool calls.
 717    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
 718    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
 719    pub fn send<T>(
 720        &mut self,
 721        id: UserMessageId,
 722        content: impl IntoIterator<Item = T>,
 723        cx: &mut Context<Self>,
 724    ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>>
 725    where
 726        T: Into<UserMessageContent>,
 727    {
 728        log::info!("Thread::send called with model: {:?}", self.model.name());
 729        self.advance_prompt_id();
 730
 731        let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
 732        log::debug!("Thread::send content: {:?}", content);
 733
 734        self.messages
 735            .push(Message::User(UserMessage { id, content }));
 736        cx.notify();
 737
 738        log::info!("Total messages in thread: {}", self.messages.len());
 739        self.run_turn(cx)
 740    }
 741
 742    fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
 743        self.cancel();
 744
 745        let model = self.model.clone();
 746        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
 747        let event_stream = ThreadEventStream(events_tx);
 748        let message_ix = self.messages.len().saturating_sub(1);
 749        self.tool_use_limit_reached = false;
 750        self.running_turn = Some(RunningTurn {
 751            event_stream: event_stream.clone(),
 752            _task: cx.spawn(async move |this, cx| {
 753                log::info!("Starting agent turn execution");
 754                let turn_result: Result<()> = async {
 755                    let mut completion_intent = CompletionIntent::UserPrompt;
 756                    loop {
 757                        log::debug!(
 758                            "Building completion request with intent: {:?}",
 759                            completion_intent
 760                        );
 761                        let request = this.update(cx, |this, cx| {
 762                            this.build_completion_request(completion_intent, cx)
 763                        })?;
 764
 765                        log::info!("Calling model.stream_completion");
 766                        let mut events = model.stream_completion(request, cx).await?;
 767                        log::debug!("Stream completion started successfully");
 768
 769                        let mut tool_use_limit_reached = false;
 770                        let mut tool_uses = FuturesUnordered::new();
 771                        while let Some(event) = events.next().await {
 772                            match event? {
 773                                LanguageModelCompletionEvent::StatusUpdate(
 774                                    CompletionRequestStatus::ToolUseLimitReached,
 775                                ) => {
 776                                    tool_use_limit_reached = true;
 777                                }
 778                                LanguageModelCompletionEvent::Stop(reason) => {
 779                                    event_stream.send_stop(reason);
 780                                    if reason == StopReason::Refusal {
 781                                        this.update(cx, |this, _cx| {
 782                                            this.flush_pending_message();
 783                                            this.messages.truncate(message_ix);
 784                                        })?;
 785                                        return Ok(());
 786                                    }
 787                                }
 788                                event => {
 789                                    log::trace!("Received completion event: {:?}", event);
 790                                    this.update(cx, |this, cx| {
 791                                        tool_uses.extend(this.handle_streamed_completion_event(
 792                                            event,
 793                                            &event_stream,
 794                                            cx,
 795                                        ));
 796                                    })
 797                                    .ok();
 798                                }
 799                            }
 800                        }
 801
 802                        let used_tools = tool_uses.is_empty();
 803                        while let Some(tool_result) = tool_uses.next().await {
 804                            log::info!("Tool finished {:?}", tool_result);
 805
 806                            event_stream.update_tool_call_fields(
 807                                &tool_result.tool_use_id,
 808                                acp::ToolCallUpdateFields {
 809                                    status: Some(if tool_result.is_error {
 810                                        acp::ToolCallStatus::Failed
 811                                    } else {
 812                                        acp::ToolCallStatus::Completed
 813                                    }),
 814                                    raw_output: tool_result.output.clone(),
 815                                    ..Default::default()
 816                                },
 817                            );
 818                            this.update(cx, |this, _cx| {
 819                                this.pending_message()
 820                                    .tool_results
 821                                    .insert(tool_result.tool_use_id.clone(), tool_result);
 822                            })
 823                            .ok();
 824                        }
 825
 826                        if tool_use_limit_reached {
 827                            log::info!("Tool use limit reached, completing turn");
 828                            this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
 829                            return Err(language_model::ToolUseLimitReachedError.into());
 830                        } else if used_tools {
 831                            log::info!("No tool uses found, completing turn");
 832                            return Ok(());
 833                        } else {
 834                            this.update(cx, |this, _| this.flush_pending_message())?;
 835                            completion_intent = CompletionIntent::ToolResults;
 836                        }
 837                    }
 838                }
 839                .await;
 840
 841                if let Err(error) = turn_result {
 842                    log::error!("Turn execution failed: {:?}", error);
 843                    event_stream.send_error(error);
 844                } else {
 845                    log::info!("Turn execution completed successfully");
 846                }
 847
 848                this.update(cx, |this, _| {
 849                    this.flush_pending_message();
 850                    this.running_turn.take();
 851                })
 852                .ok();
 853            }),
 854        });
 855        events_rx
 856    }
 857
 858    pub fn build_system_message(&self) -> LanguageModelRequestMessage {
 859        log::debug!("Building system message");
 860        let prompt = SystemPromptTemplate {
 861            project: &self.project_context.borrow(),
 862            available_tools: self.tools.keys().cloned().collect(),
 863        }
 864        .render(&self.templates)
 865        .context("failed to build system prompt")
 866        .expect("Invalid template");
 867        log::debug!("System message built");
 868        LanguageModelRequestMessage {
 869            role: Role::System,
 870            content: vec![prompt.into()],
 871            cache: true,
 872        }
 873    }
 874
 875    /// A helper method that's called on every streamed completion event.
 876    /// Returns an optional tool result task, which the main agentic loop in
 877    /// send will send back to the model when it resolves.
 878    fn handle_streamed_completion_event(
 879        &mut self,
 880        event: LanguageModelCompletionEvent,
 881        event_stream: &ThreadEventStream,
 882        cx: &mut Context<Self>,
 883    ) -> Option<Task<LanguageModelToolResult>> {
 884        log::trace!("Handling streamed completion event: {:?}", event);
 885        use LanguageModelCompletionEvent::*;
 886
 887        match event {
 888            StartMessage { .. } => {
 889                self.flush_pending_message();
 890                self.pending_message = Some(AgentMessage::default());
 891            }
 892            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
 893            Thinking { text, signature } => {
 894                self.handle_thinking_event(text, signature, event_stream, cx)
 895            }
 896            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
 897            ToolUse(tool_use) => {
 898                return self.handle_tool_use_event(tool_use, event_stream, cx);
 899            }
 900            ToolUseJsonParseError {
 901                id,
 902                tool_name,
 903                raw_input,
 904                json_parse_error,
 905            } => {
 906                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
 907                    id,
 908                    tool_name,
 909                    raw_input,
 910                    json_parse_error,
 911                )));
 912            }
 913            UsageUpdate(_) | StatusUpdate(_) => {}
 914            Stop(_) => unreachable!(),
 915        }
 916
 917        None
 918    }
 919
 920    fn handle_text_event(
 921        &mut self,
 922        new_text: String,
 923        event_stream: &ThreadEventStream,
 924        cx: &mut Context<Self>,
 925    ) {
 926        event_stream.send_text(&new_text);
 927
 928        let last_message = self.pending_message();
 929        if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
 930            text.push_str(&new_text);
 931        } else {
 932            last_message
 933                .content
 934                .push(AgentMessageContent::Text(new_text));
 935        }
 936
 937        cx.notify();
 938    }
 939
 940    fn handle_thinking_event(
 941        &mut self,
 942        new_text: String,
 943        new_signature: Option<String>,
 944        event_stream: &ThreadEventStream,
 945        cx: &mut Context<Self>,
 946    ) {
 947        event_stream.send_thinking(&new_text);
 948
 949        let last_message = self.pending_message();
 950        if let Some(AgentMessageContent::Thinking { text, signature }) =
 951            last_message.content.last_mut()
 952        {
 953            text.push_str(&new_text);
 954            *signature = new_signature.or(signature.take());
 955        } else {
 956            last_message.content.push(AgentMessageContent::Thinking {
 957                text: new_text,
 958                signature: new_signature,
 959            });
 960        }
 961
 962        cx.notify();
 963    }
 964
 965    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
 966        let last_message = self.pending_message();
 967        last_message
 968            .content
 969            .push(AgentMessageContent::RedactedThinking(data));
 970        cx.notify();
 971    }
 972
 973    fn handle_tool_use_event(
 974        &mut self,
 975        tool_use: LanguageModelToolUse,
 976        event_stream: &ThreadEventStream,
 977        cx: &mut Context<Self>,
 978    ) -> Option<Task<LanguageModelToolResult>> {
 979        cx.notify();
 980
 981        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
 982        let mut title = SharedString::from(&tool_use.name);
 983        let mut kind = acp::ToolKind::Other;
 984        if let Some(tool) = tool.as_ref() {
 985            title = tool.initial_title(tool_use.input.clone());
 986            kind = tool.kind();
 987        }
 988
 989        // Ensure the last message ends in the current tool use
 990        let last_message = self.pending_message();
 991        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
 992            if let AgentMessageContent::ToolUse(last_tool_use) = content {
 993                if last_tool_use.id == tool_use.id {
 994                    *last_tool_use = tool_use.clone();
 995                    false
 996                } else {
 997                    true
 998                }
 999            } else {
1000                true
1001            }
1002        });
1003
1004        if push_new_tool_use {
1005            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
1006            last_message
1007                .content
1008                .push(AgentMessageContent::ToolUse(tool_use.clone()));
1009        } else {
1010            event_stream.update_tool_call_fields(
1011                &tool_use.id,
1012                acp::ToolCallUpdateFields {
1013                    title: Some(title.into()),
1014                    kind: Some(kind),
1015                    raw_input: Some(tool_use.input.clone()),
1016                    ..Default::default()
1017                },
1018            );
1019        }
1020
1021        if !tool_use.is_input_complete {
1022            return None;
1023        }
1024
1025        let Some(tool) = tool else {
1026            let content = format!("No tool named {} exists", tool_use.name);
1027            return Some(Task::ready(LanguageModelToolResult {
1028                content: LanguageModelToolResultContent::Text(Arc::from(content)),
1029                tool_use_id: tool_use.id,
1030                tool_name: tool_use.name,
1031                is_error: true,
1032                output: None,
1033            }));
1034        };
1035
1036        let fs = self.project.read(cx).fs().clone();
1037        let tool_event_stream =
1038            ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1039        tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1040            status: Some(acp::ToolCallStatus::InProgress),
1041            ..Default::default()
1042        });
1043        let supports_images = self.model.supports_images();
1044        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1045        log::info!("Running tool {}", tool_use.name);
1046        Some(cx.foreground_executor().spawn(async move {
1047            let tool_result = tool_result.await.and_then(|output| {
1048                if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
1049                    if !supports_images {
1050                        return Err(anyhow!(
1051                            "Attempted to read an image, but this model doesn't support it.",
1052                        ));
1053                    }
1054                }
1055                Ok(output)
1056            });
1057
1058            match tool_result {
1059                Ok(output) => LanguageModelToolResult {
1060                    tool_use_id: tool_use.id,
1061                    tool_name: tool_use.name,
1062                    is_error: false,
1063                    content: output.llm_output,
1064                    output: Some(output.raw_output),
1065                },
1066                Err(error) => LanguageModelToolResult {
1067                    tool_use_id: tool_use.id,
1068                    tool_name: tool_use.name,
1069                    is_error: true,
1070                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1071                    output: None,
1072                },
1073            }
1074        }))
1075    }
1076
1077    fn handle_tool_use_json_parse_error_event(
1078        &mut self,
1079        tool_use_id: LanguageModelToolUseId,
1080        tool_name: Arc<str>,
1081        raw_input: Arc<str>,
1082        json_parse_error: String,
1083    ) -> LanguageModelToolResult {
1084        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1085        LanguageModelToolResult {
1086            tool_use_id,
1087            tool_name,
1088            is_error: true,
1089            content: LanguageModelToolResultContent::Text(tool_output.into()),
1090            output: Some(serde_json::Value::String(raw_input.to_string())),
1091        }
1092    }
1093
1094    fn pending_message(&mut self) -> &mut AgentMessage {
1095        self.pending_message.get_or_insert_default()
1096    }
1097
1098    fn flush_pending_message(&mut self) {
1099        let Some(mut message) = self.pending_message.take() else {
1100            return;
1101        };
1102
1103        for content in &message.content {
1104            let AgentMessageContent::ToolUse(tool_use) = content else {
1105                continue;
1106            };
1107
1108            if !message.tool_results.contains_key(&tool_use.id) {
1109                message.tool_results.insert(
1110                    tool_use.id.clone(),
1111                    LanguageModelToolResult {
1112                        tool_use_id: tool_use.id.clone(),
1113                        tool_name: tool_use.name.clone(),
1114                        is_error: true,
1115                        content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1116                        output: None,
1117                    },
1118                );
1119            }
1120        }
1121
1122        self.messages.push(Message::Agent(message));
1123    }
1124
1125    pub(crate) fn build_completion_request(
1126        &self,
1127        completion_intent: CompletionIntent,
1128        cx: &mut App,
1129    ) -> LanguageModelRequest {
1130        log::debug!("Building completion request");
1131        log::debug!("Completion intent: {:?}", completion_intent);
1132        log::debug!("Completion mode: {:?}", self.completion_mode);
1133
1134        let messages = self.build_request_messages();
1135        log::info!("Request will include {} messages", messages.len());
1136
1137        let tools = if let Some(tools) = self.tools(cx).log_err() {
1138            tools
1139                .filter_map(|tool| {
1140                    let tool_name = tool.name().to_string();
1141                    log::trace!("Including tool: {}", tool_name);
1142                    Some(LanguageModelRequestTool {
1143                        name: tool_name,
1144                        description: tool.description().to_string(),
1145                        input_schema: tool
1146                            .input_schema(self.model.tool_input_format())
1147                            .log_err()?,
1148                    })
1149                })
1150                .collect()
1151        } else {
1152            Vec::new()
1153        };
1154
1155        log::info!("Request includes {} tools", tools.len());
1156
1157        let request = LanguageModelRequest {
1158            thread_id: Some(self.id.to_string()),
1159            prompt_id: Some(self.prompt_id.to_string()),
1160            intent: Some(completion_intent),
1161            mode: Some(self.completion_mode.into()),
1162            messages,
1163            tools,
1164            tool_choice: None,
1165            stop: Vec::new(),
1166            temperature: AgentSettings::temperature_for_model(self.model(), cx),
1167            thinking_allowed: true,
1168        };
1169
1170        log::debug!("Completion request built successfully");
1171        request
1172    }
1173
1174    fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
1175        let profile = AgentSettings::get_global(cx)
1176            .profiles
1177            .get(&self.profile_id)
1178            .context("profile not found")?;
1179        let provider_id = self.model.provider_id();
1180
1181        Ok(self
1182            .tools
1183            .iter()
1184            .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1185            .filter_map(|(tool_name, tool)| {
1186                if profile.is_tool_enabled(tool_name) {
1187                    Some(tool)
1188                } else {
1189                    None
1190                }
1191            })
1192            .chain(self.context_server_registry.read(cx).servers().flat_map(
1193                |(server_id, tools)| {
1194                    tools.iter().filter_map(|(tool_name, tool)| {
1195                        if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1196                            Some(tool)
1197                        } else {
1198                            None
1199                        }
1200                    })
1201                },
1202            )))
1203    }
1204
1205    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
1206        log::trace!(
1207            "Building request messages from {} thread messages",
1208            self.messages.len()
1209        );
1210        let mut messages = vec![self.build_system_message()];
1211        for message in &self.messages {
1212            match message {
1213                Message::User(message) => messages.push(message.to_request()),
1214                Message::Agent(message) => messages.extend(message.to_request()),
1215                Message::Resume => messages.push(LanguageModelRequestMessage {
1216                    role: Role::User,
1217                    content: vec!["Continue where you left off".into()],
1218                    cache: false,
1219                }),
1220            }
1221        }
1222
1223        if let Some(message) = self.pending_message.as_ref() {
1224            messages.extend(message.to_request());
1225        }
1226
1227        if let Some(last_user_message) = messages
1228            .iter_mut()
1229            .rev()
1230            .find(|message| message.role == Role::User)
1231        {
1232            last_user_message.cache = true;
1233        }
1234
1235        messages
1236    }
1237
1238    pub fn to_markdown(&self) -> String {
1239        let mut markdown = String::new();
1240        for (ix, message) in self.messages.iter().enumerate() {
1241            if ix > 0 {
1242                markdown.push('\n');
1243            }
1244            markdown.push_str(&message.to_markdown());
1245        }
1246
1247        if let Some(message) = self.pending_message.as_ref() {
1248            markdown.push('\n');
1249            markdown.push_str(&message.to_markdown());
1250        }
1251
1252        markdown
1253    }
1254
    /// Replaces the thread's prompt id with a newly created one.
    fn advance_prompt_id(&mut self) {
        self.prompt_id = PromptId::new();
    }
1258}
1259
/// State kept for a turn that is currently in progress: the task driving it
/// and the event stream it reports on.
struct RunningTurn {
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    _task: Task<()>,
    /// The current event stream for the running turn. Used to report a final
    /// cancellation event if we cancel the turn.
    event_stream: ThreadEventStream,
}
1269
impl RunningTurn {
    /// Consumes the running turn and reports a final `Canceled` stop event
    /// on its event stream.
    ///
    /// Because `self` is taken by value, the turn (including `_task`) is
    /// dropped when this method returns.
    fn cancel(self) {
        log::debug!("Cancelling in progress turn");
        self.event_stream.send_canceled();
    }
}
1276
1277pub trait AgentTool
1278where
1279    Self: 'static + Sized,
1280{
1281    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1282    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1283
1284    fn name(&self) -> SharedString;
1285
1286    fn description(&self) -> SharedString {
1287        let schema = schemars::schema_for!(Self::Input);
1288        SharedString::new(
1289            schema
1290                .get("description")
1291                .and_then(|description| description.as_str())
1292                .unwrap_or_default(),
1293        )
1294    }
1295
1296    fn kind(&self) -> acp::ToolKind;
1297
1298    /// The initial tool title to display. Can be updated during the tool run.
1299    fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1300
1301    /// Returns the JSON schema that describes the tool's input.
1302    fn input_schema(&self) -> Schema {
1303        schemars::schema_for!(Self::Input)
1304    }
1305
1306    /// Some tools rely on a provider for the underlying billing or other reasons.
1307    /// Allow the tool to check if they are compatible, or should be filtered out.
1308    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1309        true
1310    }
1311
1312    /// Runs the tool with the provided input.
1313    fn run(
1314        self: Arc<Self>,
1315        input: Self::Input,
1316        event_stream: ToolCallEventStream,
1317        cx: &mut App,
1318    ) -> Task<Result<Self::Output>>;
1319
1320    /// Emits events for a previous execution of the tool.
1321    fn replay(
1322        &self,
1323        _input: Self::Input,
1324        _output: Self::Output,
1325        _event_stream: ToolCallEventStream,
1326        _cx: &mut App,
1327    ) -> Result<()> {
1328        Ok(())
1329    }
1330
1331    fn erase(self) -> Arc<dyn AnyAgentTool> {
1332        Arc::new(Erased(Arc::new(self)))
1333    }
1334}
1335
/// Wrapper created by `AgentTool::erase`; `Erased<Arc<T>>` implements
/// `AnyAgentTool` for any `T: AgentTool`.
pub struct Erased<T>(T);
1337
/// The result of running a type-erased tool.
pub struct AgentToolOutput {
    /// The tool output converted into content for the language model.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool output serialized as a JSON value.
    pub raw_output: serde_json::Value,
}
1342
/// JSON-typed counterpart of `AgentTool`, used as `dyn AnyAgentTool`.
///
/// Inputs and outputs are exchanged as `serde_json::Value`s instead of the
/// tool's concrete types.
pub trait AnyAgentTool {
    /// The name of the tool.
    fn name(&self) -> SharedString;
    /// The tool's description.
    fn description(&self) -> SharedString;
    /// The kind of tool call this tool is reported as.
    fn kind(&self) -> acp::ToolKind;
    /// The initial tool title to display for the given raw JSON input.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// Returns the tool's input schema as JSON for the requested schema format.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether the tool can be used with the given provider. Defaults to `true`.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool with the provided raw JSON input.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
    /// Emits events for a previous execution of the tool, given its raw JSON
    /// input and output.
    fn replay(
        &self,
        input: serde_json::Value,
        output: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Result<()>;
}
1366
1367impl<T> AnyAgentTool for Erased<Arc<T>>
1368where
1369    T: AgentTool,
1370{
1371    fn name(&self) -> SharedString {
1372        self.0.name()
1373    }
1374
1375    fn description(&self) -> SharedString {
1376        self.0.description()
1377    }
1378
1379    fn kind(&self) -> agent_client_protocol::ToolKind {
1380        self.0.kind()
1381    }
1382
1383    fn initial_title(&self, input: serde_json::Value) -> SharedString {
1384        let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1385        self.0.initial_title(parsed_input)
1386    }
1387
1388    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1389        let mut json = serde_json::to_value(self.0.input_schema())?;
1390        adapt_schema_to_format(&mut json, format)?;
1391        Ok(json)
1392    }
1393
1394    fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1395        self.0.supported_provider(provider)
1396    }
1397
1398    fn run(
1399        self: Arc<Self>,
1400        input: serde_json::Value,
1401        event_stream: ToolCallEventStream,
1402        cx: &mut App,
1403    ) -> Task<Result<AgentToolOutput>> {
1404        cx.spawn(async move |cx| {
1405            let input = serde_json::from_value(input)?;
1406            let output = cx
1407                .update(|cx| self.0.clone().run(input, event_stream, cx))?
1408                .await?;
1409            let raw_output = serde_json::to_value(&output)?;
1410            Ok(AgentToolOutput {
1411                llm_output: output.into(),
1412                raw_output,
1413            })
1414        })
1415    }
1416
1417    fn replay(
1418        &self,
1419        input: serde_json::Value,
1420        output: serde_json::Value,
1421        event_stream: ToolCallEventStream,
1422        cx: &mut App,
1423    ) -> Result<()> {
1424        let input = serde_json::from_value(input)?;
1425        let output = serde_json::from_value(output)?;
1426        self.0.replay(input, output, event_stream, cx)
1427    }
1428}
1429
/// Sending half of the unbounded channel on which a thread reports its
/// `ThreadEvent`s (or an error).
#[derive(Clone)]
struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
1432
1433impl ThreadEventStream {
1434    fn send_user_message(&self, message: &UserMessage) {
1435        self.0
1436            .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
1437            .ok();
1438    }
1439
1440    fn send_text(&self, text: &str) {
1441        self.0
1442            .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
1443            .ok();
1444    }
1445
1446    fn send_thinking(&self, text: &str) {
1447        self.0
1448            .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
1449            .ok();
1450    }
1451
1452    fn send_tool_call(
1453        &self,
1454        id: &LanguageModelToolUseId,
1455        title: SharedString,
1456        kind: acp::ToolKind,
1457        input: serde_json::Value,
1458    ) {
1459        self.0
1460            .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
1461                id,
1462                title.to_string(),
1463                kind,
1464                input,
1465            ))))
1466            .ok();
1467    }
1468
1469    fn initial_tool_call(
1470        id: &LanguageModelToolUseId,
1471        title: String,
1472        kind: acp::ToolKind,
1473        input: serde_json::Value,
1474    ) -> acp::ToolCall {
1475        acp::ToolCall {
1476            id: acp::ToolCallId(id.to_string().into()),
1477            title,
1478            kind,
1479            status: acp::ToolCallStatus::Pending,
1480            content: vec![],
1481            locations: vec![],
1482            raw_input: Some(input),
1483            raw_output: None,
1484        }
1485    }
1486
1487    fn update_tool_call_fields(
1488        &self,
1489        tool_use_id: &LanguageModelToolUseId,
1490        fields: acp::ToolCallUpdateFields,
1491    ) {
1492        self.0
1493            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1494                acp::ToolCallUpdate {
1495                    id: acp::ToolCallId(tool_use_id.to_string().into()),
1496                    fields,
1497                }
1498                .into(),
1499            )))
1500            .ok();
1501    }
1502
1503    fn send_stop(&self, reason: StopReason) {
1504        match reason {
1505            StopReason::EndTurn => {
1506                self.0
1507                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
1508                    .ok();
1509            }
1510            StopReason::MaxTokens => {
1511                self.0
1512                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
1513                    .ok();
1514            }
1515            StopReason::Refusal => {
1516                self.0
1517                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
1518                    .ok();
1519            }
1520            StopReason::ToolUse => {}
1521        }
1522    }
1523
1524    fn send_canceled(&self) {
1525        self.0
1526            .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
1527            .ok();
1528    }
1529
1530    fn send_error(&self, error: impl Into<anyhow::Error>) {
1531        self.0.unbounded_send(Err(error.into())).ok();
1532    }
1533}
1534
/// An event stream scoped to a single tool call, handed to tools while they run.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// Id of the tool use; every event sent through this stream carries it.
    tool_use_id: LanguageModelToolUseId,
    /// The thread-level stream the events are sent on.
    stream: ThreadEventStream,
    /// Used by `authorize` to update the settings file when the user picks
    /// "Always Allow"; `None` for streams created by `test`.
    fs: Option<Arc<dyn Fs>>,
}
1541
1542impl ToolCallEventStream {
1543    #[cfg(test)]
1544    pub fn test() -> (Self, ToolCallEventStreamReceiver) {
1545        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
1546
1547        let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
1548
1549        (stream, ToolCallEventStreamReceiver(events_rx))
1550    }
1551
1552    fn new(
1553        tool_use_id: LanguageModelToolUseId,
1554        stream: ThreadEventStream,
1555        fs: Option<Arc<dyn Fs>>,
1556    ) -> Self {
1557        Self {
1558            tool_use_id,
1559            stream,
1560            fs,
1561        }
1562    }
1563
1564    pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1565        self.stream
1566            .update_tool_call_fields(&self.tool_use_id, fields);
1567    }
1568
1569    pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1570        self.stream
1571            .0
1572            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1573                acp_thread::ToolCallUpdateDiff {
1574                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1575                    diff,
1576                }
1577                .into(),
1578            )))
1579            .ok();
1580    }
1581
1582    pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1583        self.stream
1584            .0
1585            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1586                acp_thread::ToolCallUpdateTerminal {
1587                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1588                    terminal,
1589                }
1590                .into(),
1591            )))
1592            .ok();
1593    }
1594
1595    pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1596        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1597            return Task::ready(Ok(()));
1598        }
1599
1600        let (response_tx, response_rx) = oneshot::channel();
1601        self.stream
1602            .0
1603            .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
1604                ToolCallAuthorization {
1605                    tool_call: acp::ToolCallUpdate {
1606                        id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1607                        fields: acp::ToolCallUpdateFields {
1608                            title: Some(title.into()),
1609                            ..Default::default()
1610                        },
1611                    },
1612                    options: vec![
1613                        acp::PermissionOption {
1614                            id: acp::PermissionOptionId("always_allow".into()),
1615                            name: "Always Allow".into(),
1616                            kind: acp::PermissionOptionKind::AllowAlways,
1617                        },
1618                        acp::PermissionOption {
1619                            id: acp::PermissionOptionId("allow".into()),
1620                            name: "Allow".into(),
1621                            kind: acp::PermissionOptionKind::AllowOnce,
1622                        },
1623                        acp::PermissionOption {
1624                            id: acp::PermissionOptionId("deny".into()),
1625                            name: "Deny".into(),
1626                            kind: acp::PermissionOptionKind::RejectOnce,
1627                        },
1628                    ],
1629                    response: response_tx,
1630                },
1631            )))
1632            .ok();
1633        let fs = self.fs.clone();
1634        cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1635            "always_allow" => {
1636                if let Some(fs) = fs.clone() {
1637                    cx.update(|cx| {
1638                        update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1639                            settings.set_always_allow_tool_actions(true);
1640                        });
1641                    })?;
1642                }
1643
1644                Ok(())
1645            }
1646            "allow" => Ok(()),
1647            _ => Err(anyhow!("Permission to run tool denied by user")),
1648        })
1649    }
1650}
1651
/// Test-only receiving end of the channel created by `ToolCallEventStream::test`.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
1654
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Waits for the next event and returns it as a tool call authorization.
    ///
    /// # Panics
    ///
    /// Panics if the next event is anything else, or if the stream has ended.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) => auth,
            event => panic!("Expected ToolCallAuthorization but got: {:?}", event),
        }
    }

    /// Waits for the next event and returns the terminal of a terminal update.
    ///
    /// # Panics
    ///
    /// Panics if the next event is anything else, or if the stream has ended.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            event => panic!("Expected terminal but got: {:?}", event),
        }
    }
}
1678
/// Exposes the wrapped channel receiver by shared reference.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;

    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1687
/// Exposes the wrapped channel receiver by mutable reference.
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1694
1695impl From<&str> for UserMessageContent {
1696    fn from(text: &str) -> Self {
1697        Self::Text(text.into())
1698    }
1699}
1700
1701impl From<acp::ContentBlock> for UserMessageContent {
1702    fn from(value: acp::ContentBlock) -> Self {
1703        match value {
1704            acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
1705            acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
1706            acp::ContentBlock::Audio(_) => {
1707                // TODO
1708                Self::Text("[audio]".to_string())
1709            }
1710            acp::ContentBlock::ResourceLink(resource_link) => {
1711                match MentionUri::parse(&resource_link.uri) {
1712                    Ok(uri) => Self::Mention {
1713                        uri,
1714                        content: String::new(),
1715                    },
1716                    Err(err) => {
1717                        log::error!("Failed to parse mention link: {}", err);
1718                        Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
1719                    }
1720                }
1721            }
1722            acp::ContentBlock::Resource(resource) => match resource.resource {
1723                acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1724                    match MentionUri::parse(&resource.uri) {
1725                        Ok(uri) => Self::Mention {
1726                            uri,
1727                            content: resource.text,
1728                        },
1729                        Err(err) => {
1730                            log::error!("Failed to parse mention link: {}", err);
1731                            Self::Text(
1732                                MarkdownCodeBlock {
1733                                    tag: &resource.uri,
1734                                    text: &resource.text,
1735                                }
1736                                .to_string(),
1737                            )
1738                        }
1739                    }
1740                }
1741                acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1742                    // TODO
1743                    Self::Text("[blob]".to_string())
1744                }
1745            },
1746        }
1747    }
1748}
1749
1750impl From<UserMessageContent> for acp::ContentBlock {
1751    fn from(content: UserMessageContent) -> Self {
1752        match content {
1753            UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
1754                text,
1755                annotations: None,
1756            }),
1757            UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
1758                data: image.source.to_string(),
1759                mime_type: "image/png".to_string(),
1760                annotations: None,
1761                uri: None,
1762            }),
1763            UserMessageContent::Mention { uri, content } => {
1764                todo!()
1765            }
1766        }
1767    }
1768}
1769
1770fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1771    LanguageModelImage {
1772        source: image_content.data.into(),
1773        // TODO: make this optional?
1774        size: gpui::Size::new(0.into(), 0.into()),
1775    }
1776}