thread.rs

   1use crate::{ContextServerRegistry, DbThread, SystemPromptTemplate, Template, Templates};
   2use acp_thread::{MentionUri, UserMessageId};
   3use action_log::ActionLog;
   4use agent_client_protocol as acp;
   5use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
   6use anyhow::{Context as _, Result, anyhow};
   7use assistant_tool::adapt_schema_to_format;
   8use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
   9use collections::IndexMap;
  10use fs::Fs;
  11use futures::{
  12    channel::{mpsc, oneshot},
  13    stream::FuturesUnordered,
  14};
  15use gpui::{App, AppContext, Context, Entity, SharedString, Task};
  16use language_model::{
  17    LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
  18    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  19    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
  20    LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
  21};
  22use project::Project;
  23use prompt_store::ProjectContext;
  24use schemars::{JsonSchema, Schema};
  25use serde::{Deserialize, Serialize};
  26use settings::{Settings, update_settings_file};
  27use smol::stream::StreamExt;
  28use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
  29use std::{fmt::Write, ops::Range};
  30use util::{ResultExt, markdown::MarkdownCodeBlock};
  31use uuid::Uuid;
  32
  33const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
  34
  35#[derive(
  36    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  37)]
  38pub struct ThreadId(pub(crate) Arc<str>);
  39
  40impl ThreadId {
  41    pub fn new() -> Self {
  42        Self(Uuid::new_v4().to_string().into())
  43    }
  44}
  45
  46impl std::fmt::Display for ThreadId {
  47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  48        write!(f, "{}", self.0)
  49    }
  50}
  51
  52impl From<&str> for ThreadId {
  53    fn from(value: &str) -> Self {
  54        Self(value.into())
  55    }
  56}
  57
  58impl From<acp::SessionId> for ThreadId {
  59    fn from(value: acp::SessionId) -> Self {
  60        Self(value.0)
  61    }
  62}
  63
  64impl From<ThreadId> for acp::SessionId {
  65    fn from(value: ThreadId) -> Self {
  66        Self(value.0)
  67    }
  68}
  69
  70/// The ID of the user prompt that initiated a request.
  71///
  72/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  73#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  74pub struct PromptId(Arc<str>);
  75
  76impl PromptId {
  77    pub fn new() -> Self {
  78        Self(Uuid::new_v4().to_string().into())
  79    }
  80}
  81
  82impl std::fmt::Display for PromptId {
  83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  84        write!(f, "{}", self.0)
  85    }
  86}
  87
  88#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  89pub enum Message {
  90    User(UserMessage),
  91    Agent(AgentMessage),
  92    Resume,
  93}
  94
  95impl Message {
  96    pub fn as_agent_message(&self) -> Option<&AgentMessage> {
  97        match self {
  98            Message::Agent(agent_message) => Some(agent_message),
  99            _ => None,
 100        }
 101    }
 102
 103    pub fn to_markdown(&self) -> String {
 104        match self {
 105            Message::User(message) => message.to_markdown(),
 106            Message::Agent(message) => message.to_markdown(),
 107            Message::Resume => "[resumed after tool use limit was reached]".into(),
 108        }
 109    }
 110}
 111
 112#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 113pub struct UserMessage {
 114    pub id: UserMessageId,
 115    pub content: Vec<UserMessageContent>,
 116}
 117
 118#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 119pub enum UserMessageContent {
 120    Text(String),
 121    Mention { uri: MentionUri, content: String },
 122    Image(LanguageModelImage),
 123}
 124
 125impl UserMessage {
 126    pub fn to_markdown(&self) -> String {
 127        let mut markdown = String::from("## User\n\n");
 128
 129        for content in &self.content {
 130            match content {
 131                UserMessageContent::Text(text) => {
 132                    markdown.push_str(text);
 133                    markdown.push('\n');
 134                }
 135                UserMessageContent::Image(_) => {
 136                    markdown.push_str("<image />\n");
 137                }
 138                UserMessageContent::Mention { uri, content } => {
 139                    if !content.is_empty() {
 140                        let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
 141                    } else {
 142                        let _ = write!(&mut markdown, "{}\n", uri.as_link());
 143                    }
 144                }
 145            }
 146        }
 147
 148        markdown
 149    }
 150
 151    fn to_request(&self) -> LanguageModelRequestMessage {
 152        let mut message = LanguageModelRequestMessage {
 153            role: Role::User,
 154            content: Vec::with_capacity(self.content.len()),
 155            cache: false,
 156        };
 157
 158        const OPEN_CONTEXT: &str = "<context>\n\
 159            The following items were attached by the user. \
 160            They are up-to-date and don't need to be re-read.\n\n";
 161
 162        const OPEN_FILES_TAG: &str = "<files>";
 163        const OPEN_SYMBOLS_TAG: &str = "<symbols>";
 164        const OPEN_THREADS_TAG: &str = "<threads>";
 165        const OPEN_FETCH_TAG: &str = "<fetched_urls>";
 166        const OPEN_RULES_TAG: &str =
 167            "<rules>\nThe user has specified the following rules that should be applied:\n";
 168
 169        let mut file_context = OPEN_FILES_TAG.to_string();
 170        let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
 171        let mut thread_context = OPEN_THREADS_TAG.to_string();
 172        let mut fetch_context = OPEN_FETCH_TAG.to_string();
 173        let mut rules_context = OPEN_RULES_TAG.to_string();
 174
 175        for chunk in &self.content {
 176            let chunk = match chunk {
 177                UserMessageContent::Text(text) => {
 178                    language_model::MessageContent::Text(text.clone())
 179                }
 180                UserMessageContent::Image(value) => {
 181                    language_model::MessageContent::Image(value.clone())
 182                }
 183                UserMessageContent::Mention { uri, content } => {
 184                    match uri {
 185                        MentionUri::File { abs_path, .. } => {
 186                            write!(
 187                                &mut symbol_context,
 188                                "\n{}",
 189                                MarkdownCodeBlock {
 190                                    tag: &codeblock_tag(&abs_path, None),
 191                                    text: &content.to_string(),
 192                                }
 193                            )
 194                            .ok();
 195                        }
 196                        MentionUri::Symbol {
 197                            path, line_range, ..
 198                        }
 199                        | MentionUri::Selection {
 200                            path, line_range, ..
 201                        } => {
 202                            write!(
 203                                &mut rules_context,
 204                                "\n{}",
 205                                MarkdownCodeBlock {
 206                                    tag: &codeblock_tag(&path, Some(line_range)),
 207                                    text: &content
 208                                }
 209                            )
 210                            .ok();
 211                        }
 212                        MentionUri::Thread { .. } => {
 213                            write!(&mut thread_context, "\n{}\n", content).ok();
 214                        }
 215                        MentionUri::TextThread { .. } => {
 216                            write!(&mut thread_context, "\n{}\n", content).ok();
 217                        }
 218                        MentionUri::Rule { .. } => {
 219                            write!(
 220                                &mut rules_context,
 221                                "\n{}",
 222                                MarkdownCodeBlock {
 223                                    tag: "",
 224                                    text: &content
 225                                }
 226                            )
 227                            .ok();
 228                        }
 229                        MentionUri::Fetch { url } => {
 230                            write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
 231                        }
 232                    }
 233
 234                    language_model::MessageContent::Text(uri.as_link().to_string())
 235                }
 236            };
 237
 238            message.content.push(chunk);
 239        }
 240
 241        let len_before_context = message.content.len();
 242
 243        if file_context.len() > OPEN_FILES_TAG.len() {
 244            file_context.push_str("</files>\n");
 245            message
 246                .content
 247                .push(language_model::MessageContent::Text(file_context));
 248        }
 249
 250        if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
 251            symbol_context.push_str("</symbols>\n");
 252            message
 253                .content
 254                .push(language_model::MessageContent::Text(symbol_context));
 255        }
 256
 257        if thread_context.len() > OPEN_THREADS_TAG.len() {
 258            thread_context.push_str("</threads>\n");
 259            message
 260                .content
 261                .push(language_model::MessageContent::Text(thread_context));
 262        }
 263
 264        if fetch_context.len() > OPEN_FETCH_TAG.len() {
 265            fetch_context.push_str("</fetched_urls>\n");
 266            message
 267                .content
 268                .push(language_model::MessageContent::Text(fetch_context));
 269        }
 270
 271        if rules_context.len() > OPEN_RULES_TAG.len() {
 272            rules_context.push_str("</user_rules>\n");
 273            message
 274                .content
 275                .push(language_model::MessageContent::Text(rules_context));
 276        }
 277
 278        if message.content.len() > len_before_context {
 279            message.content.insert(
 280                len_before_context,
 281                language_model::MessageContent::Text(OPEN_CONTEXT.into()),
 282            );
 283            message
 284                .content
 285                .push(language_model::MessageContent::Text("</context>".into()));
 286        }
 287
 288        message
 289    }
 290}
 291
 292fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
 293    let mut result = String::new();
 294
 295    if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) {
 296        let _ = write!(result, "{} ", extension);
 297    }
 298
 299    let _ = write!(result, "{}", full_path.display());
 300
 301    if let Some(range) = line_range {
 302        if range.start == range.end {
 303            let _ = write!(result, ":{}", range.start + 1);
 304        } else {
 305            let _ = write!(result, ":{}-{}", range.start + 1, range.end + 1);
 306        }
 307    }
 308
 309    result
 310}
 311
 312impl AgentMessage {
 313    pub fn to_markdown(&self) -> String {
 314        let mut markdown = String::from("## Assistant\n\n");
 315
 316        for content in &self.content {
 317            match content {
 318                AgentMessageContent::Text(text) => {
 319                    markdown.push_str(text);
 320                    markdown.push('\n');
 321                }
 322                AgentMessageContent::Thinking { text, .. } => {
 323                    markdown.push_str("<think>");
 324                    markdown.push_str(text);
 325                    markdown.push_str("</think>\n");
 326                }
 327                AgentMessageContent::RedactedThinking(_) => {
 328                    markdown.push_str("<redacted_thinking />\n")
 329                }
 330                AgentMessageContent::ToolUse(tool_use) => {
 331                    markdown.push_str(&format!(
 332                        "**Tool Use**: {} (ID: {})\n",
 333                        tool_use.name, tool_use.id
 334                    ));
 335                    markdown.push_str(&format!(
 336                        "{}\n",
 337                        MarkdownCodeBlock {
 338                            tag: "json",
 339                            text: &format!("{:#}", tool_use.input)
 340                        }
 341                    ));
 342                }
 343            }
 344        }
 345
 346        for tool_result in self.tool_results.values() {
 347            markdown.push_str(&format!(
 348                "**Tool Result**: {} (ID: {})\n\n",
 349                tool_result.tool_name, tool_result.tool_use_id
 350            ));
 351            if tool_result.is_error {
 352                markdown.push_str("**ERROR:**\n");
 353            }
 354
 355            match &tool_result.content {
 356                LanguageModelToolResultContent::Text(text) => {
 357                    writeln!(markdown, "{text}\n").ok();
 358                }
 359                LanguageModelToolResultContent::Image(_) => {
 360                    writeln!(markdown, "<image />\n").ok();
 361                }
 362            }
 363
 364            if let Some(output) = tool_result.output.as_ref() {
 365                writeln!(
 366                    markdown,
 367                    "**Debug Output**:\n\n```json\n{}\n```\n",
 368                    serde_json::to_string_pretty(output).unwrap()
 369                )
 370                .unwrap();
 371            }
 372        }
 373
 374        markdown
 375    }
 376
 377    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
 378        let mut assistant_message = LanguageModelRequestMessage {
 379            role: Role::Assistant,
 380            content: Vec::with_capacity(self.content.len()),
 381            cache: false,
 382        };
 383        for chunk in &self.content {
 384            let chunk = match chunk {
 385                AgentMessageContent::Text(text) => {
 386                    language_model::MessageContent::Text(text.clone())
 387                }
 388                AgentMessageContent::Thinking { text, signature } => {
 389                    language_model::MessageContent::Thinking {
 390                        text: text.clone(),
 391                        signature: signature.clone(),
 392                    }
 393                }
 394                AgentMessageContent::RedactedThinking(value) => {
 395                    language_model::MessageContent::RedactedThinking(value.clone())
 396                }
 397                AgentMessageContent::ToolUse(value) => {
 398                    language_model::MessageContent::ToolUse(value.clone())
 399                }
 400            };
 401            assistant_message.content.push(chunk);
 402        }
 403
 404        let mut user_message = LanguageModelRequestMessage {
 405            role: Role::User,
 406            content: Vec::new(),
 407            cache: false,
 408        };
 409
 410        for tool_result in self.tool_results.values() {
 411            user_message
 412                .content
 413                .push(language_model::MessageContent::ToolResult(
 414                    tool_result.clone(),
 415                ));
 416        }
 417
 418        let mut messages = Vec::new();
 419        if !assistant_message.content.is_empty() {
 420            messages.push(assistant_message);
 421        }
 422        if !user_message.content.is_empty() {
 423            messages.push(user_message);
 424        }
 425        messages
 426    }
 427}
 428
 429#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 430pub struct AgentMessage {
 431    pub content: Vec<AgentMessageContent>,
 432    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
 433}
 434
 435#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 436pub enum AgentMessageContent {
 437    Text(String),
 438    Thinking {
 439        text: String,
 440        signature: Option<String>,
 441    },
 442    RedactedThinking(String),
 443    ToolUse(LanguageModelToolUse),
 444}
 445
 446#[derive(Debug)]
 447pub enum ThreadEvent {
 448    UserMessage(UserMessage),
 449    AgentText(String),
 450    AgentThinking(String),
 451    ToolCall(acp::ToolCall),
 452    ToolCallUpdate(acp_thread::ToolCallUpdate),
 453    ToolCallAuthorization(ToolCallAuthorization),
 454    Stop(acp::StopReason),
 455}
 456
 457#[derive(Debug)]
 458pub struct ToolCallAuthorization {
 459    pub tool_call: acp::ToolCallUpdate,
 460    pub options: Vec<acp::PermissionOption>,
 461    pub response: oneshot::Sender<acp::PermissionOptionId>,
 462}
 463
 464pub struct Thread {
 465    id: ThreadId,
 466    prompt_id: PromptId,
 467    messages: Vec<Message>,
 468    completion_mode: CompletionMode,
 469    /// Holds the task that handles agent interaction until the end of the turn.
 470    /// Survives across multiple requests as the model performs tool calls and
 471    /// we run tools, report their results.
 472    running_turn: Option<RunningTurn>,
 473    pending_message: Option<AgentMessage>,
 474    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 475    tool_use_limit_reached: bool,
 476    context_server_registry: Entity<ContextServerRegistry>,
 477    profile_id: AgentProfileId,
 478    project_context: Rc<RefCell<ProjectContext>>,
 479    templates: Arc<Templates>,
 480    model: Arc<dyn LanguageModel>,
 481    project: Entity<Project>,
 482    action_log: Entity<ActionLog>,
 483}
 484
 485impl Thread {
 486    pub fn new(
 487        project: Entity<Project>,
 488        project_context: Rc<RefCell<ProjectContext>>,
 489        context_server_registry: Entity<ContextServerRegistry>,
 490        action_log: Entity<ActionLog>,
 491        templates: Arc<Templates>,
 492        model: Arc<dyn LanguageModel>,
 493        cx: &mut Context<Self>,
 494    ) -> Self {
 495        let profile_id = AgentSettings::get_global(cx).default_profile.clone();
 496        Self {
 497            id: ThreadId::new(),
 498            prompt_id: PromptId::new(),
 499            messages: Vec::new(),
 500            completion_mode: CompletionMode::Normal,
 501            running_turn: None,
 502            pending_message: None,
 503            tools: BTreeMap::default(),
 504            tool_use_limit_reached: false,
 505            context_server_registry,
 506            profile_id,
 507            project_context,
 508            templates,
 509            model,
 510            project,
 511            action_log,
 512        }
 513    }
 514
 515    pub fn from_db(
 516        id: ThreadId,
 517        db_thread: DbThread,
 518        project: Entity<Project>,
 519        project_context: Rc<RefCell<ProjectContext>>,
 520        context_server_registry: Entity<ContextServerRegistry>,
 521        action_log: Entity<ActionLog>,
 522        templates: Arc<Templates>,
 523        model: Arc<dyn LanguageModel>,
 524        cx: &mut Context<Self>,
 525    ) -> Self {
 526        let profile_id = db_thread
 527            .profile
 528            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
 529        Self {
 530            id,
 531            prompt_id: PromptId::new(),
 532            messages: db_thread.messages,
 533            completion_mode: CompletionMode::Normal,
 534            running_turn: None,
 535            pending_message: None,
 536            tools: BTreeMap::default(),
 537            tool_use_limit_reached: false,
 538            context_server_registry,
 539            profile_id,
 540            project_context,
 541            templates,
 542            model,
 543            project,
 544            action_log,
 545        }
 546    }
 547
 548    pub fn replay(
 549        &mut self,
 550        cx: &mut Context<Self>,
 551    ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
 552        let (tx, rx) = mpsc::unbounded();
 553        let stream = ThreadEventStream(tx);
 554        for message in &self.messages {
 555            match message {
 556                Message::User(user_message) => stream.send_user_message(&user_message),
 557                Message::Agent(assistant_message) => {
 558                    for content in &assistant_message.content {
 559                        match content {
 560                            AgentMessageContent::Text(text) => stream.send_text(text),
 561                            AgentMessageContent::Thinking { text, .. } => {
 562                                stream.send_thinking(text)
 563                            }
 564                            AgentMessageContent::RedactedThinking(_) => {}
 565                            AgentMessageContent::ToolUse(tool_use) => {
 566                                self.replay_tool_call(
 567                                    tool_use,
 568                                    assistant_message.tool_results.get(&tool_use.id),
 569                                    &stream,
 570                                    cx,
 571                                );
 572                            }
 573                        }
 574                    }
 575                }
 576                Message::Resume => {}
 577            }
 578        }
 579        rx
 580    }
 581
 582    fn replay_tool_call(
 583        &self,
 584        tool_use: &LanguageModelToolUse,
 585        tool_result: Option<&LanguageModelToolResult>,
 586        stream: &ThreadEventStream,
 587        cx: &mut Context<Self>,
 588    ) {
 589        let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
 590            stream
 591                .0
 592                .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
 593                    id: acp::ToolCallId(tool_use.id.to_string().into()),
 594                    title: tool_use.name.to_string(),
 595                    kind: acp::ToolKind::Other,
 596                    status: acp::ToolCallStatus::Failed,
 597                    content: Vec::new(),
 598                    locations: Vec::new(),
 599                    raw_input: Some(tool_use.input.clone()),
 600                    raw_output: None,
 601                })))
 602                .ok();
 603            return;
 604        };
 605
 606        let title = tool.initial_title(tool_use.input.clone());
 607        let kind = tool.kind();
 608        stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
 609
 610        if let Some(output) = tool_result
 611            .as_ref()
 612            .and_then(|result| result.output.clone())
 613        {
 614            let tool_event_stream = ToolCallEventStream::new(
 615                tool_use.id.clone(),
 616                stream.clone(),
 617                Some(self.project.read(cx).fs().clone()),
 618            );
 619            tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
 620                .log_err();
 621        }
 622
 623        stream.update_tool_call_fields(
 624            &tool_use.id,
 625            acp::ToolCallUpdateFields {
 626                status: Some(acp::ToolCallStatus::Completed),
 627                ..Default::default()
 628            },
 629        );
 630    }
 631
 632    pub fn project(&self) -> &Entity<Project> {
 633        &self.project
 634    }
 635
 636    pub fn action_log(&self) -> &Entity<ActionLog> {
 637        &self.action_log
 638    }
 639
 640    pub fn model(&self) -> &Arc<dyn LanguageModel> {
 641        &self.model
 642    }
 643
 644    pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
 645        self.model = model;
 646    }
 647
 648    pub fn completion_mode(&self) -> CompletionMode {
 649        self.completion_mode
 650    }
 651
 652    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
 653        self.completion_mode = mode;
 654    }
 655
 656    #[cfg(any(test, feature = "test-support"))]
 657    pub fn last_message(&self) -> Option<Message> {
 658        if let Some(message) = self.pending_message.clone() {
 659            Some(Message::Agent(message))
 660        } else {
 661            self.messages.last().cloned()
 662        }
 663    }
 664
 665    pub fn add_tool(&mut self, tool: impl AgentTool) {
 666        self.tools.insert(tool.name(), tool.erase());
 667    }
 668
 669    pub fn remove_tool(&mut self, name: &str) -> bool {
 670        self.tools.remove(name).is_some()
 671    }
 672
 673    pub fn profile(&self) -> &AgentProfileId {
 674        &self.profile_id
 675    }
 676
 677    pub fn set_profile(&mut self, profile_id: AgentProfileId) {
 678        self.profile_id = profile_id;
 679    }
 680
 681    pub fn cancel(&mut self) {
 682        if let Some(running_turn) = self.running_turn.take() {
 683            running_turn.cancel();
 684        }
 685        self.flush_pending_message();
 686    }
 687
 688    pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
 689        self.cancel();
 690        let Some(position) = self.messages.iter().position(
 691            |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
 692        ) else {
 693            return Err(anyhow!("Message not found"));
 694        };
 695        self.messages.truncate(position);
 696        Ok(())
 697    }
 698
 699    pub fn resume(
 700        &mut self,
 701        cx: &mut Context<Self>,
 702    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
 703        anyhow::ensure!(
 704            self.tool_use_limit_reached,
 705            "can only resume after tool use limit is reached"
 706        );
 707
 708        self.messages.push(Message::Resume);
 709        cx.notify();
 710
 711        log::info!("Total messages in thread: {}", self.messages.len());
 712        Ok(self.run_turn(cx))
 713    }
 714
 715    /// Sending a message results in the model streaming a response, which could include tool calls.
 716    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
 717    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
 718    pub fn send<T>(
 719        &mut self,
 720        id: UserMessageId,
 721        content: impl IntoIterator<Item = T>,
 722        cx: &mut Context<Self>,
 723    ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>>
 724    where
 725        T: Into<UserMessageContent>,
 726    {
 727        log::info!("Thread::send called with model: {:?}", self.model.name());
 728        self.advance_prompt_id();
 729
 730        let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
 731        log::debug!("Thread::send content: {:?}", content);
 732
 733        self.messages
 734            .push(Message::User(UserMessage { id, content }));
 735        cx.notify();
 736
 737        log::info!("Total messages in thread: {}", self.messages.len());
 738        self.run_turn(cx)
 739    }
 740
 741    fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
 742        self.cancel();
 743
 744        let model = self.model.clone();
 745        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
 746        let event_stream = ThreadEventStream(events_tx);
 747        let message_ix = self.messages.len().saturating_sub(1);
 748        self.tool_use_limit_reached = false;
 749        self.running_turn = Some(RunningTurn {
 750            event_stream: event_stream.clone(),
 751            _task: cx.spawn(async move |this, cx| {
 752                log::info!("Starting agent turn execution");
 753                let turn_result: Result<()> = async {
 754                    let mut completion_intent = CompletionIntent::UserPrompt;
 755                    loop {
 756                        log::debug!(
 757                            "Building completion request with intent: {:?}",
 758                            completion_intent
 759                        );
 760                        let request = this.update(cx, |this, cx| {
 761                            this.build_completion_request(completion_intent, cx)
 762                        })?;
 763
 764                        log::info!("Calling model.stream_completion");
 765                        let mut events = model.stream_completion(request, cx).await?;
 766                        log::debug!("Stream completion started successfully");
 767
 768                        let mut tool_use_limit_reached = false;
 769                        let mut tool_uses = FuturesUnordered::new();
 770                        while let Some(event) = events.next().await {
 771                            match event? {
 772                                LanguageModelCompletionEvent::StatusUpdate(
 773                                    CompletionRequestStatus::ToolUseLimitReached,
 774                                ) => {
 775                                    tool_use_limit_reached = true;
 776                                }
 777                                LanguageModelCompletionEvent::Stop(reason) => {
 778                                    event_stream.send_stop(reason);
 779                                    if reason == StopReason::Refusal {
 780                                        this.update(cx, |this, _cx| {
 781                                            this.flush_pending_message();
 782                                            this.messages.truncate(message_ix);
 783                                        })?;
 784                                        return Ok(());
 785                                    }
 786                                }
 787                                event => {
 788                                    log::trace!("Received completion event: {:?}", event);
 789                                    this.update(cx, |this, cx| {
 790                                        tool_uses.extend(this.handle_streamed_completion_event(
 791                                            event,
 792                                            &event_stream,
 793                                            cx,
 794                                        ));
 795                                    })
 796                                    .ok();
 797                                }
 798                            }
 799                        }
 800
 801                        let used_tools = tool_uses.is_empty();
 802                        while let Some(tool_result) = tool_uses.next().await {
 803                            log::info!("Tool finished {:?}", tool_result);
 804
 805                            event_stream.update_tool_call_fields(
 806                                &tool_result.tool_use_id,
 807                                acp::ToolCallUpdateFields {
 808                                    status: Some(if tool_result.is_error {
 809                                        acp::ToolCallStatus::Failed
 810                                    } else {
 811                                        acp::ToolCallStatus::Completed
 812                                    }),
 813                                    raw_output: tool_result.output.clone(),
 814                                    ..Default::default()
 815                                },
 816                            );
 817                            this.update(cx, |this, _cx| {
 818                                this.pending_message()
 819                                    .tool_results
 820                                    .insert(tool_result.tool_use_id.clone(), tool_result);
 821                            })
 822                            .ok();
 823                        }
 824
 825                        if tool_use_limit_reached {
 826                            log::info!("Tool use limit reached, completing turn");
 827                            this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
 828                            return Err(language_model::ToolUseLimitReachedError.into());
 829                        } else if used_tools {
 830                            log::info!("No tool uses found, completing turn");
 831                            return Ok(());
 832                        } else {
 833                            this.update(cx, |this, _| this.flush_pending_message())?;
 834                            completion_intent = CompletionIntent::ToolResults;
 835                        }
 836                    }
 837                }
 838                .await;
 839
 840                if let Err(error) = turn_result {
 841                    log::error!("Turn execution failed: {:?}", error);
 842                    event_stream.send_error(error);
 843                } else {
 844                    log::info!("Turn execution completed successfully");
 845                }
 846
 847                this.update(cx, |this, _| {
 848                    this.flush_pending_message();
 849                    this.running_turn.take();
 850                })
 851                .ok();
 852            }),
 853        });
 854        events_rx
 855    }
 856
 857    pub fn build_system_message(&self) -> LanguageModelRequestMessage {
 858        log::debug!("Building system message");
 859        let prompt = SystemPromptTemplate {
 860            project: &self.project_context.borrow(),
 861            available_tools: self.tools.keys().cloned().collect(),
 862        }
 863        .render(&self.templates)
 864        .context("failed to build system prompt")
 865        .expect("Invalid template");
 866        log::debug!("System message built");
 867        LanguageModelRequestMessage {
 868            role: Role::System,
 869            content: vec![prompt.into()],
 870            cache: true,
 871        }
 872    }
 873
 874    /// A helper method that's called on every streamed completion event.
 875    /// Returns an optional tool result task, which the main agentic loop in
 876    /// send will send back to the model when it resolves.
 877    fn handle_streamed_completion_event(
 878        &mut self,
 879        event: LanguageModelCompletionEvent,
 880        event_stream: &ThreadEventStream,
 881        cx: &mut Context<Self>,
 882    ) -> Option<Task<LanguageModelToolResult>> {
 883        log::trace!("Handling streamed completion event: {:?}", event);
 884        use LanguageModelCompletionEvent::*;
 885
 886        match event {
 887            StartMessage { .. } => {
 888                self.flush_pending_message();
 889                self.pending_message = Some(AgentMessage::default());
 890            }
 891            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
 892            Thinking { text, signature } => {
 893                self.handle_thinking_event(text, signature, event_stream, cx)
 894            }
 895            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
 896            ToolUse(tool_use) => {
 897                return self.handle_tool_use_event(tool_use, event_stream, cx);
 898            }
 899            ToolUseJsonParseError {
 900                id,
 901                tool_name,
 902                raw_input,
 903                json_parse_error,
 904            } => {
 905                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
 906                    id,
 907                    tool_name,
 908                    raw_input,
 909                    json_parse_error,
 910                )));
 911            }
 912            UsageUpdate(_) | StatusUpdate(_) => {}
 913            Stop(_) => unreachable!(),
 914        }
 915
 916        None
 917    }
 918
 919    fn handle_text_event(
 920        &mut self,
 921        new_text: String,
 922        event_stream: &ThreadEventStream,
 923        cx: &mut Context<Self>,
 924    ) {
 925        event_stream.send_text(&new_text);
 926
 927        let last_message = self.pending_message();
 928        if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
 929            text.push_str(&new_text);
 930        } else {
 931            last_message
 932                .content
 933                .push(AgentMessageContent::Text(new_text));
 934        }
 935
 936        cx.notify();
 937    }
 938
 939    fn handle_thinking_event(
 940        &mut self,
 941        new_text: String,
 942        new_signature: Option<String>,
 943        event_stream: &ThreadEventStream,
 944        cx: &mut Context<Self>,
 945    ) {
 946        event_stream.send_thinking(&new_text);
 947
 948        let last_message = self.pending_message();
 949        if let Some(AgentMessageContent::Thinking { text, signature }) =
 950            last_message.content.last_mut()
 951        {
 952            text.push_str(&new_text);
 953            *signature = new_signature.or(signature.take());
 954        } else {
 955            last_message.content.push(AgentMessageContent::Thinking {
 956                text: new_text,
 957                signature: new_signature,
 958            });
 959        }
 960
 961        cx.notify();
 962    }
 963
 964    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
 965        let last_message = self.pending_message();
 966        last_message
 967            .content
 968            .push(AgentMessageContent::RedactedThinking(data));
 969        cx.notify();
 970    }
 971
 972    fn handle_tool_use_event(
 973        &mut self,
 974        tool_use: LanguageModelToolUse,
 975        event_stream: &ThreadEventStream,
 976        cx: &mut Context<Self>,
 977    ) -> Option<Task<LanguageModelToolResult>> {
 978        cx.notify();
 979
 980        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
 981        let mut title = SharedString::from(&tool_use.name);
 982        let mut kind = acp::ToolKind::Other;
 983        if let Some(tool) = tool.as_ref() {
 984            title = tool.initial_title(tool_use.input.clone());
 985            kind = tool.kind();
 986        }
 987
 988        // Ensure the last message ends in the current tool use
 989        let last_message = self.pending_message();
 990        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
 991            if let AgentMessageContent::ToolUse(last_tool_use) = content {
 992                if last_tool_use.id == tool_use.id {
 993                    *last_tool_use = tool_use.clone();
 994                    false
 995                } else {
 996                    true
 997                }
 998            } else {
 999                true
1000            }
1001        });
1002
1003        if push_new_tool_use {
1004            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
1005            last_message
1006                .content
1007                .push(AgentMessageContent::ToolUse(tool_use.clone()));
1008        } else {
1009            event_stream.update_tool_call_fields(
1010                &tool_use.id,
1011                acp::ToolCallUpdateFields {
1012                    title: Some(title.into()),
1013                    kind: Some(kind),
1014                    raw_input: Some(tool_use.input.clone()),
1015                    ..Default::default()
1016                },
1017            );
1018        }
1019
1020        if !tool_use.is_input_complete {
1021            return None;
1022        }
1023
1024        let Some(tool) = tool else {
1025            let content = format!("No tool named {} exists", tool_use.name);
1026            return Some(Task::ready(LanguageModelToolResult {
1027                content: LanguageModelToolResultContent::Text(Arc::from(content)),
1028                tool_use_id: tool_use.id,
1029                tool_name: tool_use.name,
1030                is_error: true,
1031                output: None,
1032            }));
1033        };
1034
1035        let fs = self.project.read(cx).fs().clone();
1036        let tool_event_stream =
1037            ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1038        tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1039            status: Some(acp::ToolCallStatus::InProgress),
1040            ..Default::default()
1041        });
1042        let supports_images = self.model.supports_images();
1043        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1044        log::info!("Running tool {}", tool_use.name);
1045        Some(cx.foreground_executor().spawn(async move {
1046            let tool_result = tool_result.await.and_then(|output| {
1047                if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
1048                    if !supports_images {
1049                        return Err(anyhow!(
1050                            "Attempted to read an image, but this model doesn't support it.",
1051                        ));
1052                    }
1053                }
1054                Ok(output)
1055            });
1056
1057            match tool_result {
1058                Ok(output) => LanguageModelToolResult {
1059                    tool_use_id: tool_use.id,
1060                    tool_name: tool_use.name,
1061                    is_error: false,
1062                    content: output.llm_output,
1063                    output: Some(output.raw_output),
1064                },
1065                Err(error) => LanguageModelToolResult {
1066                    tool_use_id: tool_use.id,
1067                    tool_name: tool_use.name,
1068                    is_error: true,
1069                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1070                    output: None,
1071                },
1072            }
1073        }))
1074    }
1075
1076    fn handle_tool_use_json_parse_error_event(
1077        &mut self,
1078        tool_use_id: LanguageModelToolUseId,
1079        tool_name: Arc<str>,
1080        raw_input: Arc<str>,
1081        json_parse_error: String,
1082    ) -> LanguageModelToolResult {
1083        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1084        LanguageModelToolResult {
1085            tool_use_id,
1086            tool_name,
1087            is_error: true,
1088            content: LanguageModelToolResultContent::Text(tool_output.into()),
1089            output: Some(serde_json::Value::String(raw_input.to_string())),
1090        }
1091    }
1092
1093    fn pending_message(&mut self) -> &mut AgentMessage {
1094        self.pending_message.get_or_insert_default()
1095    }
1096
1097    fn flush_pending_message(&mut self) {
1098        let Some(mut message) = self.pending_message.take() else {
1099            return;
1100        };
1101
1102        for content in &message.content {
1103            let AgentMessageContent::ToolUse(tool_use) = content else {
1104                continue;
1105            };
1106
1107            if !message.tool_results.contains_key(&tool_use.id) {
1108                message.tool_results.insert(
1109                    tool_use.id.clone(),
1110                    LanguageModelToolResult {
1111                        tool_use_id: tool_use.id.clone(),
1112                        tool_name: tool_use.name.clone(),
1113                        is_error: true,
1114                        content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1115                        output: None,
1116                    },
1117                );
1118            }
1119        }
1120
1121        self.messages.push(Message::Agent(message));
1122    }
1123
1124    pub(crate) fn build_completion_request(
1125        &self,
1126        completion_intent: CompletionIntent,
1127        cx: &mut App,
1128    ) -> LanguageModelRequest {
1129        log::debug!("Building completion request");
1130        log::debug!("Completion intent: {:?}", completion_intent);
1131        log::debug!("Completion mode: {:?}", self.completion_mode);
1132
1133        let messages = self.build_request_messages();
1134        log::info!("Request will include {} messages", messages.len());
1135
1136        let tools = if let Some(tools) = self.tools(cx).log_err() {
1137            tools
1138                .filter_map(|tool| {
1139                    let tool_name = tool.name().to_string();
1140                    log::trace!("Including tool: {}", tool_name);
1141                    Some(LanguageModelRequestTool {
1142                        name: tool_name,
1143                        description: tool.description().to_string(),
1144                        input_schema: tool
1145                            .input_schema(self.model.tool_input_format())
1146                            .log_err()?,
1147                    })
1148                })
1149                .collect()
1150        } else {
1151            Vec::new()
1152        };
1153
1154        log::info!("Request includes {} tools", tools.len());
1155
1156        let request = LanguageModelRequest {
1157            thread_id: Some(self.id.to_string()),
1158            prompt_id: Some(self.prompt_id.to_string()),
1159            intent: Some(completion_intent),
1160            mode: Some(self.completion_mode.into()),
1161            messages,
1162            tools,
1163            tool_choice: None,
1164            stop: Vec::new(),
1165            temperature: AgentSettings::temperature_for_model(self.model(), cx),
1166            thinking_allowed: true,
1167        };
1168
1169        log::debug!("Completion request built successfully");
1170        request
1171    }
1172
1173    fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
1174        let profile = AgentSettings::get_global(cx)
1175            .profiles
1176            .get(&self.profile_id)
1177            .context("profile not found")?;
1178        let provider_id = self.model.provider_id();
1179
1180        Ok(self
1181            .tools
1182            .iter()
1183            .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1184            .filter_map(|(tool_name, tool)| {
1185                if profile.is_tool_enabled(tool_name) {
1186                    Some(tool)
1187                } else {
1188                    None
1189                }
1190            })
1191            .chain(self.context_server_registry.read(cx).servers().flat_map(
1192                |(server_id, tools)| {
1193                    tools.iter().filter_map(|(tool_name, tool)| {
1194                        if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1195                            Some(tool)
1196                        } else {
1197                            None
1198                        }
1199                    })
1200                },
1201            )))
1202    }
1203
1204    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
1205        log::trace!(
1206            "Building request messages from {} thread messages",
1207            self.messages.len()
1208        );
1209        let mut messages = vec![self.build_system_message()];
1210        for message in &self.messages {
1211            match message {
1212                Message::User(message) => messages.push(message.to_request()),
1213                Message::Agent(message) => messages.extend(message.to_request()),
1214                Message::Resume => messages.push(LanguageModelRequestMessage {
1215                    role: Role::User,
1216                    content: vec!["Continue where you left off".into()],
1217                    cache: false,
1218                }),
1219            }
1220        }
1221
1222        if let Some(message) = self.pending_message.as_ref() {
1223            messages.extend(message.to_request());
1224        }
1225
1226        if let Some(last_user_message) = messages
1227            .iter_mut()
1228            .rev()
1229            .find(|message| message.role == Role::User)
1230        {
1231            last_user_message.cache = true;
1232        }
1233
1234        messages
1235    }
1236
1237    pub fn to_markdown(&self) -> String {
1238        let mut markdown = String::new();
1239        for (ix, message) in self.messages.iter().enumerate() {
1240            if ix > 0 {
1241                markdown.push('\n');
1242            }
1243            markdown.push_str(&message.to_markdown());
1244        }
1245
1246        if let Some(message) = self.pending_message.as_ref() {
1247            markdown.push('\n');
1248            markdown.push_str(&message.to_markdown());
1249        }
1250
1251        markdown
1252    }
1253
1254    fn advance_prompt_id(&mut self) {
1255        self.prompt_id = PromptId::new();
1256    }
1257}
1258
/// Bookkeeping for the turn a thread is currently executing.
struct RunningTurn {
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run the tools and report their results.
    _task: Task<()>,
    /// The current event stream for the running turn. Used to report a final
    /// cancellation event if we cancel the turn.
    event_stream: ThreadEventStream,
}
1268
impl RunningTurn {
    /// Reports a `Canceled` stop on the turn's event stream, then consumes the
    /// turn.
    ///
    /// Because `self` is taken by value, the turn's `_task` is dropped when this
    /// returns, after the cancellation event has been sent.
    /// NOTE(review): presumably this relies on dropping a gpui `Task` cancelling
    /// it — confirm.
    fn cancel(self) {
        log::debug!("Cancelling in progress turn");
        self.event_stream.send_canceled();
    }
}
1275
1276pub trait AgentTool
1277where
1278    Self: 'static + Sized,
1279{
1280    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1281    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1282
1283    fn name(&self) -> SharedString;
1284
1285    fn description(&self) -> SharedString {
1286        let schema = schemars::schema_for!(Self::Input);
1287        SharedString::new(
1288            schema
1289                .get("description")
1290                .and_then(|description| description.as_str())
1291                .unwrap_or_default(),
1292        )
1293    }
1294
1295    fn kind(&self) -> acp::ToolKind;
1296
1297    /// The initial tool title to display. Can be updated during the tool run.
1298    fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1299
1300    /// Returns the JSON schema that describes the tool's input.
1301    fn input_schema(&self) -> Schema {
1302        schemars::schema_for!(Self::Input)
1303    }
1304
1305    /// Some tools rely on a provider for the underlying billing or other reasons.
1306    /// Allow the tool to check if they are compatible, or should be filtered out.
1307    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1308        true
1309    }
1310
1311    /// Runs the tool with the provided input.
1312    fn run(
1313        self: Arc<Self>,
1314        input: Self::Input,
1315        event_stream: ToolCallEventStream,
1316        cx: &mut App,
1317    ) -> Task<Result<Self::Output>>;
1318
1319    /// Emits events for a previous execution of the tool.
1320    fn replay(
1321        &self,
1322        _input: Self::Input,
1323        _output: Self::Output,
1324        _event_stream: ToolCallEventStream,
1325        _cx: &mut App,
1326    ) -> Result<()> {
1327        Ok(())
1328    }
1329
1330    fn erase(self) -> Arc<dyn AnyAgentTool> {
1331        Arc::new(Erased(Arc::new(self)))
1332    }
1333}
1334
/// Adapter exposing a typed [`AgentTool`] (held as `Erased<Arc<T>>`) through
/// the object-safe, JSON-based [`AnyAgentTool`] interface; produced by
/// [`AgentTool::erase`].
pub struct Erased<T>(T);
1336
/// Output of running a type-erased tool.
pub struct AgentToolOutput {
    /// Content returned to the language model as the tool result.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's typed output serialized to JSON; recorded as the tool
    /// result's `output` and reported as the tool call's raw output.
    pub raw_output: serde_json::Value,
}
1341
/// Object-safe, JSON-based counterpart of [`AgentTool`], used to hold tools of
/// different types as `Arc<dyn AnyAgentTool>`.
pub trait AnyAgentTool {
    /// Name of the tool as advertised to the model.
    fn name(&self) -> SharedString;
    /// Description of the tool as advertised to the model.
    fn description(&self) -> SharedString;
    /// Kind reported for this tool's calls on the event stream.
    fn kind(&self) -> acp::ToolKind;
    /// Initial display title computed from the raw, possibly incomplete, JSON input.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// JSON schema of the tool's input, adapted to the model's schema `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether the tool may be offered to models from this provider; defaults to `true`.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool on raw JSON input, producing model-facing and raw JSON output.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
    /// Re-emits events for a previous execution from its recorded JSON input and output.
    fn replay(
        &self,
        input: serde_json::Value,
        output: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Result<()>;
}
1365
1366impl<T> AnyAgentTool for Erased<Arc<T>>
1367where
1368    T: AgentTool,
1369{
1370    fn name(&self) -> SharedString {
1371        self.0.name()
1372    }
1373
1374    fn description(&self) -> SharedString {
1375        self.0.description()
1376    }
1377
1378    fn kind(&self) -> agent_client_protocol::ToolKind {
1379        self.0.kind()
1380    }
1381
1382    fn initial_title(&self, input: serde_json::Value) -> SharedString {
1383        let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1384        self.0.initial_title(parsed_input)
1385    }
1386
1387    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1388        let mut json = serde_json::to_value(self.0.input_schema())?;
1389        adapt_schema_to_format(&mut json, format)?;
1390        Ok(json)
1391    }
1392
1393    fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1394        self.0.supported_provider(provider)
1395    }
1396
1397    fn run(
1398        self: Arc<Self>,
1399        input: serde_json::Value,
1400        event_stream: ToolCallEventStream,
1401        cx: &mut App,
1402    ) -> Task<Result<AgentToolOutput>> {
1403        cx.spawn(async move |cx| {
1404            let input = serde_json::from_value(input)?;
1405            let output = cx
1406                .update(|cx| self.0.clone().run(input, event_stream, cx))?
1407                .await?;
1408            let raw_output = serde_json::to_value(&output)?;
1409            Ok(AgentToolOutput {
1410                llm_output: output.into(),
1411                raw_output,
1412            })
1413        })
1414    }
1415
1416    fn replay(
1417        &self,
1418        input: serde_json::Value,
1419        output: serde_json::Value,
1420        event_stream: ToolCallEventStream,
1421        cx: &mut App,
1422    ) -> Result<()> {
1423        let input = serde_json::from_value(input)?;
1424        let output = serde_json::from_value(output)?;
1425        self.0.replay(input, output, event_stream, cx)
1426    }
1427}
1428
/// Sending half of a thread's event channel.
///
/// Cloned into each [`ToolCallEventStream`] and held by the running turn. Its
/// methods ignore send errors.
#[derive(Clone)]
struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
1431
1432impl ThreadEventStream {
1433    fn send_user_message(&self, message: &UserMessage) {
1434        self.0
1435            .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
1436            .ok();
1437    }
1438
1439    fn send_text(&self, text: &str) {
1440        self.0
1441            .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
1442            .ok();
1443    }
1444
1445    fn send_thinking(&self, text: &str) {
1446        self.0
1447            .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
1448            .ok();
1449    }
1450
1451    fn send_tool_call(
1452        &self,
1453        id: &LanguageModelToolUseId,
1454        title: SharedString,
1455        kind: acp::ToolKind,
1456        input: serde_json::Value,
1457    ) {
1458        self.0
1459            .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
1460                id,
1461                title.to_string(),
1462                kind,
1463                input,
1464            ))))
1465            .ok();
1466    }
1467
1468    fn initial_tool_call(
1469        id: &LanguageModelToolUseId,
1470        title: String,
1471        kind: acp::ToolKind,
1472        input: serde_json::Value,
1473    ) -> acp::ToolCall {
1474        acp::ToolCall {
1475            id: acp::ToolCallId(id.to_string().into()),
1476            title,
1477            kind,
1478            status: acp::ToolCallStatus::Pending,
1479            content: vec![],
1480            locations: vec![],
1481            raw_input: Some(input),
1482            raw_output: None,
1483        }
1484    }
1485
1486    fn update_tool_call_fields(
1487        &self,
1488        tool_use_id: &LanguageModelToolUseId,
1489        fields: acp::ToolCallUpdateFields,
1490    ) {
1491        self.0
1492            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1493                acp::ToolCallUpdate {
1494                    id: acp::ToolCallId(tool_use_id.to_string().into()),
1495                    fields,
1496                }
1497                .into(),
1498            )))
1499            .ok();
1500    }
1501
1502    fn send_stop(&self, reason: StopReason) {
1503        match reason {
1504            StopReason::EndTurn => {
1505                self.0
1506                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
1507                    .ok();
1508            }
1509            StopReason::MaxTokens => {
1510                self.0
1511                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
1512                    .ok();
1513            }
1514            StopReason::Refusal => {
1515                self.0
1516                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
1517                    .ok();
1518            }
1519            StopReason::ToolUse => {}
1520        }
1521    }
1522
1523    fn send_canceled(&self) {
1524        self.0
1525            .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
1526            .ok();
1527    }
1528
1529    fn send_error(&self, error: impl Into<anyhow::Error>) {
1530        self.0.unbounded_send(Err(error.into())).ok();
1531    }
1532}
1533
/// Event stream scoped to a single tool call.
///
/// Updates sent through it are tagged with the tool use's id. When `fs` is
/// present it is used to persist an "always allow" authorization choice.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// Id of the tool use this stream reports on.
    tool_use_id: LanguageModelToolUseId,
    /// Thread-level stream that events are forwarded to.
    stream: ThreadEventStream,
    /// Filesystem used to update the settings file; `None` in `ToolCallEventStream::test`.
    fs: Option<Arc<dyn Fs>>,
}
1540
1541impl ToolCallEventStream {
1542    #[cfg(test)]
1543    pub fn test() -> (Self, ToolCallEventStreamReceiver) {
1544        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
1545
1546        let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
1547
1548        (stream, ToolCallEventStreamReceiver(events_rx))
1549    }
1550
1551    fn new(
1552        tool_use_id: LanguageModelToolUseId,
1553        stream: ThreadEventStream,
1554        fs: Option<Arc<dyn Fs>>,
1555    ) -> Self {
1556        Self {
1557            tool_use_id,
1558            stream,
1559            fs,
1560        }
1561    }
1562
1563    pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1564        self.stream
1565            .update_tool_call_fields(&self.tool_use_id, fields);
1566    }
1567
1568    pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1569        self.stream
1570            .0
1571            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1572                acp_thread::ToolCallUpdateDiff {
1573                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1574                    diff,
1575                }
1576                .into(),
1577            )))
1578            .ok();
1579    }
1580
1581    pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1582        self.stream
1583            .0
1584            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1585                acp_thread::ToolCallUpdateTerminal {
1586                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1587                    terminal,
1588                }
1589                .into(),
1590            )))
1591            .ok();
1592    }
1593
1594    pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1595        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1596            return Task::ready(Ok(()));
1597        }
1598
1599        let (response_tx, response_rx) = oneshot::channel();
1600        self.stream
1601            .0
1602            .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
1603                ToolCallAuthorization {
1604                    tool_call: acp::ToolCallUpdate {
1605                        id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1606                        fields: acp::ToolCallUpdateFields {
1607                            title: Some(title.into()),
1608                            ..Default::default()
1609                        },
1610                    },
1611                    options: vec![
1612                        acp::PermissionOption {
1613                            id: acp::PermissionOptionId("always_allow".into()),
1614                            name: "Always Allow".into(),
1615                            kind: acp::PermissionOptionKind::AllowAlways,
1616                        },
1617                        acp::PermissionOption {
1618                            id: acp::PermissionOptionId("allow".into()),
1619                            name: "Allow".into(),
1620                            kind: acp::PermissionOptionKind::AllowOnce,
1621                        },
1622                        acp::PermissionOption {
1623                            id: acp::PermissionOptionId("deny".into()),
1624                            name: "Deny".into(),
1625                            kind: acp::PermissionOptionKind::RejectOnce,
1626                        },
1627                    ],
1628                    response: response_tx,
1629                },
1630            )))
1631            .ok();
1632        let fs = self.fs.clone();
1633        cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1634            "always_allow" => {
1635                if let Some(fs) = fs.clone() {
1636                    cx.update(|cx| {
1637                        update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1638                            settings.set_always_allow_tool_actions(true);
1639                        });
1640                    })?;
1641                }
1642
1643                Ok(())
1644            }
1645            "allow" => Ok(()),
1646            _ => Err(anyhow!("Permission to run tool denied by user")),
1647        })
1648    }
1649}
1650
/// Test-only receiving end paired with the stream from `ToolCallEventStream::test`.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
1653
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it as an authorization request.
    ///
    /// # Panics
    ///
    /// Panics if the next item is anything else or the stream has ended.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) => auth,
            other => panic!("Expected ToolCallAuthorization but got: {:?}", other),
        }
    }

    /// Awaits the next event and returns the terminal it attaches.
    ///
    /// # Panics
    ///
    /// Panics if the next item is not a terminal update or the stream has ended.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            other => panic!("Expected terminal but got: {:?}", other),
        }
    }
}
1677
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;

    /// Borrows the wrapped receiver so tests can call its methods directly.
    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1686
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    /// Mutably borrows the wrapped receiver, for example to poll it as a stream.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1693
1694impl From<&str> for UserMessageContent {
1695    fn from(text: &str) -> Self {
1696        Self::Text(text.into())
1697    }
1698}
1699
1700impl From<acp::ContentBlock> for UserMessageContent {
1701    fn from(value: acp::ContentBlock) -> Self {
1702        match value {
1703            acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
1704            acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
1705            acp::ContentBlock::Audio(_) => {
1706                // TODO
1707                Self::Text("[audio]".to_string())
1708            }
1709            acp::ContentBlock::ResourceLink(resource_link) => {
1710                match MentionUri::parse(&resource_link.uri) {
1711                    Ok(uri) => Self::Mention {
1712                        uri,
1713                        content: String::new(),
1714                    },
1715                    Err(err) => {
1716                        log::error!("Failed to parse mention link: {}", err);
1717                        Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
1718                    }
1719                }
1720            }
1721            acp::ContentBlock::Resource(resource) => match resource.resource {
1722                acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1723                    match MentionUri::parse(&resource.uri) {
1724                        Ok(uri) => Self::Mention {
1725                            uri,
1726                            content: resource.text,
1727                        },
1728                        Err(err) => {
1729                            log::error!("Failed to parse mention link: {}", err);
1730                            Self::Text(
1731                                MarkdownCodeBlock {
1732                                    tag: &resource.uri,
1733                                    text: &resource.text,
1734                                }
1735                                .to_string(),
1736                            )
1737                        }
1738                    }
1739                }
1740                acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1741                    // TODO
1742                    Self::Text("[blob]".to_string())
1743                }
1744            },
1745        }
1746    }
1747}
1748
1749impl From<UserMessageContent> for acp::ContentBlock {
1750    fn from(content: UserMessageContent) -> Self {
1751        match content {
1752            UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
1753                text,
1754                annotations: None,
1755            }),
1756            UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
1757                data: image.source.to_string(),
1758                mime_type: "image/png".to_string(),
1759                annotations: None,
1760                uri: None,
1761            }),
1762            UserMessageContent::Mention { uri, content } => {
1763                todo!()
1764            }
1765        }
1766    }
1767}
1768
1769fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1770    LanguageModelImage {
1771        source: image_content.data.into(),
1772        // TODO: make this optional?
1773        size: gpui::Size::new(0.into(), 0.into()),
1774    }
1775}