thread.rs

   1use crate::{
   2    ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DbLanguageModel, DbThread,
   3    DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool,
   4    ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, SystemPromptTemplate,
   5    Template, Templates, TerminalTool, ThinkingTool, WebSearchTool,
   6};
   7use acp_thread::{MentionUri, UserMessageId};
   8use action_log::ActionLog;
   9use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot};
  10use agent_client_protocol as acp;
  11use agent_settings::{
  12    AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
  13    SUMMARIZE_THREAD_PROMPT,
  14};
  15use anyhow::{Context as _, Result, anyhow};
  16use assistant_tool::adapt_schema_to_format;
  17use chrono::{DateTime, Utc};
  18use client::{ModelRequestUsage, RequestUsage};
  19use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
  20use collections::{HashMap, IndexMap};
  21use fs::Fs;
  22use futures::{
  23    FutureExt,
  24    channel::{mpsc, oneshot},
  25    future::Shared,
  26    stream::FuturesUnordered,
  27};
  28use git::repository::DiffType;
  29use gpui::{
  30    App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
  31};
  32use language_model::{
  33    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
  34    LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
  35    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  36    LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
  37    LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
  38};
  39use project::{
  40    Project,
  41    git_store::{GitStore, RepositoryState},
  42};
  43use prompt_store::ProjectContext;
  44use schemars::{JsonSchema, Schema};
  45use serde::{Deserialize, Serialize};
  46use settings::{Settings, update_settings_file};
  47use smol::stream::StreamExt;
  48use std::{
  49    collections::BTreeMap,
  50    path::Path,
  51    sync::Arc,
  52    time::{Duration, Instant},
  53};
  54use std::{fmt::Write, ops::Range};
  55use util::{ResultExt, markdown::MarkdownCodeBlock};
  56use uuid::Uuid;
  57
  58const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
  59
  60/// The ID of the user prompt that initiated a request.
  61///
  62/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  63#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  64pub struct PromptId(Arc<str>);
  65
  66impl PromptId {
  67    pub fn new() -> Self {
  68        Self(Uuid::new_v4().to_string().into())
  69    }
  70}
  71
  72impl std::fmt::Display for PromptId {
  73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  74        write!(f, "{}", self.0)
  75    }
  76}
  77
// NOTE(review): these retry items are not used in the visible part of this file; the
// descriptions below are inferred from their names and should be confirmed against callers.
/// Upper bound on retry attempts (presumably for failed completion requests — confirm).
pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
/// Base delay between retries (presumably the starting point for backoff — confirm).
pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);

/// How retries of a failed request are spaced out.
#[derive(Debug, Clone)]
enum RetryStrategy {
    /// Presumably grows the delay from `initial_delay` on each attempt, for at most
    /// `max_attempts` attempts — confirm against the retry loop.
    ExponentialBackoff {
        initial_delay: Duration,
        max_attempts: u8,
    },
    /// Presumably waits the same `delay` before each of at most `max_attempts` attempts.
    Fixed {
        delay: Duration,
        max_attempts: u8,
    },
}
  92
  93#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  94pub enum Message {
  95    User(UserMessage),
  96    Agent(AgentMessage),
  97    Resume,
  98}
  99
 100impl Message {
 101    pub fn as_agent_message(&self) -> Option<&AgentMessage> {
 102        match self {
 103            Message::Agent(agent_message) => Some(agent_message),
 104            _ => None,
 105        }
 106    }
 107
 108    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
 109        match self {
 110            Message::User(message) => vec![message.to_request()],
 111            Message::Agent(message) => message.to_request(),
 112            Message::Resume => vec![LanguageModelRequestMessage {
 113                role: Role::User,
 114                content: vec!["Continue where you left off".into()],
 115                cache: false,
 116            }],
 117        }
 118    }
 119
 120    pub fn to_markdown(&self) -> String {
 121        match self {
 122            Message::User(message) => message.to_markdown(),
 123            Message::Agent(message) => message.to_markdown(),
 124            Message::Resume => "[resumed after tool use limit was reached]".into(),
 125        }
 126    }
 127
 128    pub fn role(&self) -> Role {
 129        match self {
 130            Message::User(_) | Message::Resume => Role::User,
 131            Message::Agent(_) => Role::Assistant,
 132        }
 133    }
 134}
 135
/// A message submitted by the user, made of ordered content chunks.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessage {
    pub id: UserMessageId,
    pub content: Vec<UserMessageContent>,
}

/// One chunk of a [`UserMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UserMessageContent {
    /// Text typed by the user.
    Text(String),
    /// A reference to a resource (file, directory, symbol, thread, rule, fetched URL, ...)
    /// plus its resolved content; the content is sent to the model as attached context.
    Mention { uri: MentionUri, content: String },
    /// An image attached to the message.
    Image(LanguageModelImage),
}
 148
 149impl UserMessage {
 150    pub fn to_markdown(&self) -> String {
 151        let mut markdown = String::from("## User\n\n");
 152
 153        for content in &self.content {
 154            match content {
 155                UserMessageContent::Text(text) => {
 156                    markdown.push_str(text);
 157                    markdown.push('\n');
 158                }
 159                UserMessageContent::Image(_) => {
 160                    markdown.push_str("<image />\n");
 161                }
 162                UserMessageContent::Mention { uri, content } => {
 163                    if !content.is_empty() {
 164                        let _ = writeln!(&mut markdown, "{}\n\n{}", uri.as_link(), content);
 165                    } else {
 166                        let _ = writeln!(&mut markdown, "{}", uri.as_link());
 167                    }
 168                }
 169            }
 170        }
 171
 172        markdown
 173    }
 174
 175    fn to_request(&self) -> LanguageModelRequestMessage {
 176        let mut message = LanguageModelRequestMessage {
 177            role: Role::User,
 178            content: Vec::with_capacity(self.content.len()),
 179            cache: false,
 180        };
 181
 182        const OPEN_CONTEXT: &str = "<context>\n\
 183            The following items were attached by the user. \
 184            They are up-to-date and don't need to be re-read.\n\n";
 185
 186        const OPEN_FILES_TAG: &str = "<files>";
 187        const OPEN_DIRECTORIES_TAG: &str = "<directories>";
 188        const OPEN_SYMBOLS_TAG: &str = "<symbols>";
 189        const OPEN_THREADS_TAG: &str = "<threads>";
 190        const OPEN_FETCH_TAG: &str = "<fetched_urls>";
 191        const OPEN_RULES_TAG: &str =
 192            "<rules>\nThe user has specified the following rules that should be applied:\n";
 193
 194        let mut file_context = OPEN_FILES_TAG.to_string();
 195        let mut directory_context = OPEN_DIRECTORIES_TAG.to_string();
 196        let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
 197        let mut thread_context = OPEN_THREADS_TAG.to_string();
 198        let mut fetch_context = OPEN_FETCH_TAG.to_string();
 199        let mut rules_context = OPEN_RULES_TAG.to_string();
 200
 201        for chunk in &self.content {
 202            let chunk = match chunk {
 203                UserMessageContent::Text(text) => {
 204                    language_model::MessageContent::Text(text.clone())
 205                }
 206                UserMessageContent::Image(value) => {
 207                    language_model::MessageContent::Image(value.clone())
 208                }
 209                UserMessageContent::Mention { uri, content } => {
 210                    match uri {
 211                        MentionUri::File { abs_path } => {
 212                            write!(
 213                                &mut symbol_context,
 214                                "\n{}",
 215                                MarkdownCodeBlock {
 216                                    tag: &codeblock_tag(abs_path, None),
 217                                    text: &content.to_string(),
 218                                }
 219                            )
 220                            .ok();
 221                        }
 222                        MentionUri::Directory { .. } => {
 223                            write!(&mut directory_context, "\n{}\n", content).ok();
 224                        }
 225                        MentionUri::Symbol {
 226                            path, line_range, ..
 227                        }
 228                        | MentionUri::Selection {
 229                            path, line_range, ..
 230                        } => {
 231                            write!(
 232                                &mut rules_context,
 233                                "\n{}",
 234                                MarkdownCodeBlock {
 235                                    tag: &codeblock_tag(path, Some(line_range)),
 236                                    text: content
 237                                }
 238                            )
 239                            .ok();
 240                        }
 241                        MentionUri::Thread { .. } => {
 242                            write!(&mut thread_context, "\n{}\n", content).ok();
 243                        }
 244                        MentionUri::TextThread { .. } => {
 245                            write!(&mut thread_context, "\n{}\n", content).ok();
 246                        }
 247                        MentionUri::Rule { .. } => {
 248                            write!(
 249                                &mut rules_context,
 250                                "\n{}",
 251                                MarkdownCodeBlock {
 252                                    tag: "",
 253                                    text: content
 254                                }
 255                            )
 256                            .ok();
 257                        }
 258                        MentionUri::Fetch { url } => {
 259                            write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
 260                        }
 261                    }
 262
 263                    language_model::MessageContent::Text(uri.as_link().to_string())
 264                }
 265            };
 266
 267            message.content.push(chunk);
 268        }
 269
 270        let len_before_context = message.content.len();
 271
 272        if file_context.len() > OPEN_FILES_TAG.len() {
 273            file_context.push_str("</files>\n");
 274            message
 275                .content
 276                .push(language_model::MessageContent::Text(file_context));
 277        }
 278
 279        if directory_context.len() > OPEN_DIRECTORIES_TAG.len() {
 280            directory_context.push_str("</directories>\n");
 281            message
 282                .content
 283                .push(language_model::MessageContent::Text(directory_context));
 284        }
 285
 286        if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
 287            symbol_context.push_str("</symbols>\n");
 288            message
 289                .content
 290                .push(language_model::MessageContent::Text(symbol_context));
 291        }
 292
 293        if thread_context.len() > OPEN_THREADS_TAG.len() {
 294            thread_context.push_str("</threads>\n");
 295            message
 296                .content
 297                .push(language_model::MessageContent::Text(thread_context));
 298        }
 299
 300        if fetch_context.len() > OPEN_FETCH_TAG.len() {
 301            fetch_context.push_str("</fetched_urls>\n");
 302            message
 303                .content
 304                .push(language_model::MessageContent::Text(fetch_context));
 305        }
 306
 307        if rules_context.len() > OPEN_RULES_TAG.len() {
 308            rules_context.push_str("</user_rules>\n");
 309            message
 310                .content
 311                .push(language_model::MessageContent::Text(rules_context));
 312        }
 313
 314        if message.content.len() > len_before_context {
 315            message.content.insert(
 316                len_before_context,
 317                language_model::MessageContent::Text(OPEN_CONTEXT.into()),
 318            );
 319            message
 320                .content
 321                .push(language_model::MessageContent::Text("</context>".into()));
 322        }
 323
 324        message
 325    }
 326}
 327
 328fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
 329    let mut result = String::new();
 330
 331    if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) {
 332        let _ = write!(result, "{} ", extension);
 333    }
 334
 335    let _ = write!(result, "{}", full_path.display());
 336
 337    if let Some(range) = line_range {
 338        if range.start == range.end {
 339            let _ = write!(result, ":{}", range.start + 1);
 340        } else {
 341            let _ = write!(result, ":{}-{}", range.start + 1, range.end + 1);
 342        }
 343    }
 344
 345    result
 346}
 347
 348impl AgentMessage {
 349    pub fn to_markdown(&self) -> String {
 350        let mut markdown = String::from("## Assistant\n\n");
 351
 352        for content in &self.content {
 353            match content {
 354                AgentMessageContent::Text(text) => {
 355                    markdown.push_str(text);
 356                    markdown.push('\n');
 357                }
 358                AgentMessageContent::Thinking { text, .. } => {
 359                    markdown.push_str("<think>");
 360                    markdown.push_str(text);
 361                    markdown.push_str("</think>\n");
 362                }
 363                AgentMessageContent::RedactedThinking(_) => {
 364                    markdown.push_str("<redacted_thinking />\n")
 365                }
 366                AgentMessageContent::ToolUse(tool_use) => {
 367                    markdown.push_str(&format!(
 368                        "**Tool Use**: {} (ID: {})\n",
 369                        tool_use.name, tool_use.id
 370                    ));
 371                    markdown.push_str(&format!(
 372                        "{}\n",
 373                        MarkdownCodeBlock {
 374                            tag: "json",
 375                            text: &format!("{:#}", tool_use.input)
 376                        }
 377                    ));
 378                }
 379            }
 380        }
 381
 382        for tool_result in self.tool_results.values() {
 383            markdown.push_str(&format!(
 384                "**Tool Result**: {} (ID: {})\n\n",
 385                tool_result.tool_name, tool_result.tool_use_id
 386            ));
 387            if tool_result.is_error {
 388                markdown.push_str("**ERROR:**\n");
 389            }
 390
 391            match &tool_result.content {
 392                LanguageModelToolResultContent::Text(text) => {
 393                    writeln!(markdown, "{text}\n").ok();
 394                }
 395                LanguageModelToolResultContent::Image(_) => {
 396                    writeln!(markdown, "<image />\n").ok();
 397                }
 398            }
 399
 400            if let Some(output) = tool_result.output.as_ref() {
 401                writeln!(
 402                    markdown,
 403                    "**Debug Output**:\n\n```json\n{}\n```\n",
 404                    serde_json::to_string_pretty(output).unwrap()
 405                )
 406                .unwrap();
 407            }
 408        }
 409
 410        markdown
 411    }
 412
 413    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
 414        let mut assistant_message = LanguageModelRequestMessage {
 415            role: Role::Assistant,
 416            content: Vec::with_capacity(self.content.len()),
 417            cache: false,
 418        };
 419        for chunk in &self.content {
 420            let chunk = match chunk {
 421                AgentMessageContent::Text(text) => {
 422                    language_model::MessageContent::Text(text.clone())
 423                }
 424                AgentMessageContent::Thinking { text, signature } => {
 425                    language_model::MessageContent::Thinking {
 426                        text: text.clone(),
 427                        signature: signature.clone(),
 428                    }
 429                }
 430                AgentMessageContent::RedactedThinking(value) => {
 431                    language_model::MessageContent::RedactedThinking(value.clone())
 432                }
 433                AgentMessageContent::ToolUse(value) => {
 434                    language_model::MessageContent::ToolUse(value.clone())
 435                }
 436            };
 437            assistant_message.content.push(chunk);
 438        }
 439
 440        let mut user_message = LanguageModelRequestMessage {
 441            role: Role::User,
 442            content: Vec::new(),
 443            cache: false,
 444        };
 445
 446        for tool_result in self.tool_results.values() {
 447            user_message
 448                .content
 449                .push(language_model::MessageContent::ToolResult(
 450                    tool_result.clone(),
 451                ));
 452        }
 453
 454        let mut messages = Vec::new();
 455        if !assistant_message.content.is_empty() {
 456            messages.push(assistant_message);
 457        }
 458        if !user_message.content.is_empty() {
 459            messages.push(user_message);
 460        }
 461        messages
 462    }
 463}
 464
/// A response produced by the agent, together with the results of the tools it invoked.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AgentMessage {
    /// Ordered content chunks produced by the model.
    pub content: Vec<AgentMessageContent>,
    /// Results of the tool uses in `content`, keyed by tool use ID. `IndexMap` keeps insertion
    /// order, which is the order results are rendered and sent back to the model.
    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
}

/// One chunk of an [`AgentMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AgentMessageContent {
    Text(String),
    /// Model reasoning; `signature` is passed back unchanged in follow-up requests.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Reasoning returned in redacted form. Its payload is not rendered or replayed; it is
    /// only passed back in requests.
    RedactedThinking(String),
    /// A request from the model to invoke a tool.
    ToolUse(LanguageModelToolUse),
}
 481
/// Events delivered to consumers of a thread's event stream (for example by
/// [`Thread::replay`]).
#[derive(Debug)]
pub enum ThreadEvent {
    UserMessage(UserMessage),
    AgentText(String),
    AgentThinking(String),
    ToolCall(acp::ToolCall),
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    ToolCallAuthorization(ToolCallAuthorization),
    TitleUpdate(SharedString),
    Retry(acp_thread::RetryStatus),
    Stop(acp::StopReason),
}

/// A permission request for a tool call: the consumer picks one of `options`.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    pub tool_call: acp::ToolCallUpdate,
    pub options: Vec<acp::PermissionOption>,
    /// Channel on which the ID of the chosen option is sent back.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
 501
 502pub struct Thread {
 503    id: acp::SessionId,
 504    prompt_id: PromptId,
 505    updated_at: DateTime<Utc>,
 506    title: Option<SharedString>,
 507    summary: Option<SharedString>,
 508    messages: Vec<Message>,
 509    completion_mode: CompletionMode,
 510    /// Holds the task that handles agent interaction until the end of the turn.
 511    /// Survives across multiple requests as the model performs tool calls and
 512    /// we run tools, report their results.
 513    running_turn: Option<RunningTurn>,
 514    pending_message: Option<AgentMessage>,
 515    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 516    tool_use_limit_reached: bool,
 517    request_token_usage: HashMap<UserMessageId, language_model::TokenUsage>,
 518    #[allow(unused)]
 519    cumulative_token_usage: TokenUsage,
 520    #[allow(unused)]
 521    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
 522    context_server_registry: Entity<ContextServerRegistry>,
 523    profile_id: AgentProfileId,
 524    project_context: Entity<ProjectContext>,
 525    templates: Arc<Templates>,
 526    model: Option<Arc<dyn LanguageModel>>,
 527    summarization_model: Option<Arc<dyn LanguageModel>>,
 528    pub(crate) project: Entity<Project>,
 529    pub(crate) action_log: Entity<ActionLog>,
 530}
 531
 532impl Thread {
 533    pub fn new(
 534        project: Entity<Project>,
 535        project_context: Entity<ProjectContext>,
 536        context_server_registry: Entity<ContextServerRegistry>,
 537        action_log: Entity<ActionLog>,
 538        templates: Arc<Templates>,
 539        model: Option<Arc<dyn LanguageModel>>,
 540        cx: &mut Context<Self>,
 541    ) -> Self {
 542        let profile_id = AgentSettings::get_global(cx).default_profile.clone();
 543        Self {
 544            id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()),
 545            prompt_id: PromptId::new(),
 546            updated_at: Utc::now(),
 547            title: None,
 548            summary: None,
 549            messages: Vec::new(),
 550            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
 551            running_turn: None,
 552            pending_message: None,
 553            tools: BTreeMap::default(),
 554            tool_use_limit_reached: false,
 555            request_token_usage: HashMap::default(),
 556            cumulative_token_usage: TokenUsage::default(),
 557            initial_project_snapshot: {
 558                let project_snapshot = Self::project_snapshot(project.clone(), cx);
 559                cx.foreground_executor()
 560                    .spawn(async move { Some(project_snapshot.await) })
 561                    .shared()
 562            },
 563            context_server_registry,
 564            profile_id,
 565            project_context,
 566            templates,
 567            model,
 568            summarization_model: None,
 569            project,
 570            action_log,
 571        }
 572    }
 573
 574    pub fn id(&self) -> &acp::SessionId {
 575        &self.id
 576    }
 577
 578    pub fn replay(
 579        &mut self,
 580        cx: &mut Context<Self>,
 581    ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
 582        let (tx, rx) = mpsc::unbounded();
 583        let stream = ThreadEventStream(tx);
 584        for message in &self.messages {
 585            match message {
 586                Message::User(user_message) => stream.send_user_message(user_message),
 587                Message::Agent(assistant_message) => {
 588                    for content in &assistant_message.content {
 589                        match content {
 590                            AgentMessageContent::Text(text) => stream.send_text(text),
 591                            AgentMessageContent::Thinking { text, .. } => {
 592                                stream.send_thinking(text)
 593                            }
 594                            AgentMessageContent::RedactedThinking(_) => {}
 595                            AgentMessageContent::ToolUse(tool_use) => {
 596                                self.replay_tool_call(
 597                                    tool_use,
 598                                    assistant_message.tool_results.get(&tool_use.id),
 599                                    &stream,
 600                                    cx,
 601                                );
 602                            }
 603                        }
 604                    }
 605                }
 606                Message::Resume => {}
 607            }
 608        }
 609        rx
 610    }
 611
 612    fn replay_tool_call(
 613        &self,
 614        tool_use: &LanguageModelToolUse,
 615        tool_result: Option<&LanguageModelToolResult>,
 616        stream: &ThreadEventStream,
 617        cx: &mut Context<Self>,
 618    ) {
 619        let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
 620            stream
 621                .0
 622                .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
 623                    id: acp::ToolCallId(tool_use.id.to_string().into()),
 624                    title: tool_use.name.to_string(),
 625                    kind: acp::ToolKind::Other,
 626                    status: acp::ToolCallStatus::Failed,
 627                    content: Vec::new(),
 628                    locations: Vec::new(),
 629                    raw_input: Some(tool_use.input.clone()),
 630                    raw_output: None,
 631                })))
 632                .ok();
 633            return;
 634        };
 635
 636        let title = tool.initial_title(tool_use.input.clone());
 637        let kind = tool.kind();
 638        stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
 639
 640        let output = tool_result
 641            .as_ref()
 642            .and_then(|result| result.output.clone());
 643        if let Some(output) = output.clone() {
 644            let tool_event_stream = ToolCallEventStream::new(
 645                tool_use.id.clone(),
 646                stream.clone(),
 647                Some(self.project.read(cx).fs().clone()),
 648            );
 649            tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
 650                .log_err();
 651        }
 652
 653        stream.update_tool_call_fields(
 654            &tool_use.id,
 655            acp::ToolCallUpdateFields {
 656                status: Some(acp::ToolCallStatus::Completed),
 657                raw_output: output,
 658                ..Default::default()
 659            },
 660        );
 661    }
 662
 663    pub fn from_db(
 664        id: acp::SessionId,
 665        db_thread: DbThread,
 666        project: Entity<Project>,
 667        project_context: Entity<ProjectContext>,
 668        context_server_registry: Entity<ContextServerRegistry>,
 669        action_log: Entity<ActionLog>,
 670        templates: Arc<Templates>,
 671        cx: &mut Context<Self>,
 672    ) -> Self {
 673        let profile_id = db_thread
 674            .profile
 675            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
 676        let model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
 677            db_thread
 678                .model
 679                .and_then(|model| {
 680                    let model = SelectedModel {
 681                        provider: model.provider.clone().into(),
 682                        model: model.model.into(),
 683                    };
 684                    registry.select_model(&model, cx)
 685                })
 686                .or_else(|| registry.default_model())
 687                .map(|model| model.model)
 688        });
 689
 690        Self {
 691            id,
 692            prompt_id: PromptId::new(),
 693            title: if db_thread.title.is_empty() {
 694                None
 695            } else {
 696                Some(db_thread.title.clone())
 697            },
 698            summary: db_thread.detailed_summary,
 699            messages: db_thread.messages,
 700            completion_mode: db_thread.completion_mode.unwrap_or_default(),
 701            running_turn: None,
 702            pending_message: None,
 703            tools: BTreeMap::default(),
 704            tool_use_limit_reached: false,
 705            request_token_usage: db_thread.request_token_usage.clone(),
 706            cumulative_token_usage: db_thread.cumulative_token_usage,
 707            initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
 708            context_server_registry,
 709            profile_id,
 710            project_context,
 711            templates,
 712            model,
 713            summarization_model: None,
 714            project,
 715            action_log,
 716            updated_at: db_thread.updated_at,
 717        }
 718    }
 719
 720    pub fn to_db(&self, cx: &App) -> Task<DbThread> {
 721        let initial_project_snapshot = self.initial_project_snapshot.clone();
 722        let mut thread = DbThread {
 723            title: self.title(),
 724            messages: self.messages.clone(),
 725            updated_at: self.updated_at,
 726            detailed_summary: self.summary.clone(),
 727            initial_project_snapshot: None,
 728            cumulative_token_usage: self.cumulative_token_usage,
 729            request_token_usage: self.request_token_usage.clone(),
 730            model: self.model.as_ref().map(|model| DbLanguageModel {
 731                provider: model.provider_id().to_string(),
 732                model: model.name().0.to_string(),
 733            }),
 734            completion_mode: Some(self.completion_mode),
 735            profile: Some(self.profile_id.clone()),
 736        };
 737
 738        cx.background_spawn(async move {
 739            let initial_project_snapshot = initial_project_snapshot.await;
 740            thread.initial_project_snapshot = initial_project_snapshot;
 741            thread
 742        })
 743    }
 744
 745    /// Create a snapshot of the current project state including git information and unsaved buffers.
 746    fn project_snapshot(
 747        project: Entity<Project>,
 748        cx: &mut Context<Self>,
 749    ) -> Task<Arc<agent::thread::ProjectSnapshot>> {
 750        let git_store = project.read(cx).git_store().clone();
 751        let worktree_snapshots: Vec<_> = project
 752            .read(cx)
 753            .visible_worktrees(cx)
 754            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
 755            .collect();
 756
 757        cx.spawn(async move |_, cx| {
 758            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
 759
 760            let mut unsaved_buffers = Vec::new();
 761            cx.update(|app_cx| {
 762                let buffer_store = project.read(app_cx).buffer_store();
 763                for buffer_handle in buffer_store.read(app_cx).buffers() {
 764                    let buffer = buffer_handle.read(app_cx);
 765                    if buffer.is_dirty()
 766                        && let Some(file) = buffer.file()
 767                    {
 768                        let path = file.path().to_string_lossy().to_string();
 769                        unsaved_buffers.push(path);
 770                    }
 771                }
 772            })
 773            .ok();
 774
 775            Arc::new(ProjectSnapshot {
 776                worktree_snapshots,
 777                unsaved_buffer_paths: unsaved_buffers,
 778                timestamp: Utc::now(),
 779            })
 780        })
 781    }
 782
 783    fn worktree_snapshot(
 784        worktree: Entity<project::Worktree>,
 785        git_store: Entity<GitStore>,
 786        cx: &App,
 787    ) -> Task<agent::thread::WorktreeSnapshot> {
 788        cx.spawn(async move |cx| {
 789            // Get worktree path and snapshot
 790            let worktree_info = cx.update(|app_cx| {
 791                let worktree = worktree.read(app_cx);
 792                let path = worktree.abs_path().to_string_lossy().to_string();
 793                let snapshot = worktree.snapshot();
 794                (path, snapshot)
 795            });
 796
 797            let Ok((worktree_path, _snapshot)) = worktree_info else {
 798                return WorktreeSnapshot {
 799                    worktree_path: String::new(),
 800                    git_state: None,
 801                };
 802            };
 803
 804            let git_state = git_store
 805                .update(cx, |git_store, cx| {
 806                    git_store
 807                        .repositories()
 808                        .values()
 809                        .find(|repo| {
 810                            repo.read(cx)
 811                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
 812                                .is_some()
 813                        })
 814                        .cloned()
 815                })
 816                .ok()
 817                .flatten()
 818                .map(|repo| {
 819                    repo.update(cx, |repo, _| {
 820                        let current_branch =
 821                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
 822                        repo.send_job(None, |state, _| async move {
 823                            let RepositoryState::Local { backend, .. } = state else {
 824                                return GitState {
 825                                    remote_url: None,
 826                                    head_sha: None,
 827                                    current_branch,
 828                                    diff: None,
 829                                };
 830                            };
 831
 832                            let remote_url = backend.remote_url("origin");
 833                            let head_sha = backend.head_sha().await;
 834                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
 835
 836                            GitState {
 837                                remote_url,
 838                                head_sha,
 839                                current_branch,
 840                                diff,
 841                            }
 842                        })
 843                    })
 844                });
 845
 846            let git_state = match git_state {
 847                Some(git_state) => match git_state.ok() {
 848                    Some(git_state) => git_state.await.ok(),
 849                    None => None,
 850                },
 851                None => None,
 852            };
 853
 854            WorktreeSnapshot {
 855                worktree_path,
 856                git_state,
 857            }
 858        })
 859    }
 860
 861    pub fn project_context(&self) -> &Entity<ProjectContext> {
 862        &self.project_context
 863    }
 864
 865    pub fn project(&self) -> &Entity<Project> {
 866        &self.project
 867    }
 868
 869    pub fn action_log(&self) -> &Entity<ActionLog> {
 870        &self.action_log
 871    }
 872
 873    pub fn is_empty(&self) -> bool {
 874        self.messages.is_empty() && self.title.is_none()
 875    }
 876
 877    pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
 878        self.model.as_ref()
 879    }
 880
 881    pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
 882        let old_usage = self.latest_token_usage();
 883        self.model = Some(model);
 884        let new_usage = self.latest_token_usage();
 885        if old_usage != new_usage {
 886            cx.emit(TokenUsageUpdated(new_usage));
 887        }
 888        cx.notify()
 889    }
 890
 891    pub fn summarization_model(&self) -> Option<&Arc<dyn LanguageModel>> {
 892        self.summarization_model.as_ref()
 893    }
 894
 895    pub fn set_summarization_model(
 896        &mut self,
 897        model: Option<Arc<dyn LanguageModel>>,
 898        cx: &mut Context<Self>,
 899    ) {
 900        self.summarization_model = model;
 901        cx.notify()
 902    }
 903
 904    pub fn completion_mode(&self) -> CompletionMode {
 905        self.completion_mode
 906    }
 907
 908    pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
 909        let old_usage = self.latest_token_usage();
 910        self.completion_mode = mode;
 911        let new_usage = self.latest_token_usage();
 912        if old_usage != new_usage {
 913            cx.emit(TokenUsageUpdated(new_usage));
 914        }
 915        cx.notify()
 916    }
 917
 918    #[cfg(any(test, feature = "test-support"))]
 919    pub fn last_message(&self) -> Option<Message> {
 920        if let Some(message) = self.pending_message.clone() {
 921            Some(Message::Agent(message))
 922        } else {
 923            self.messages.last().cloned()
 924        }
 925    }
 926
 927    pub fn add_default_tools(&mut self, cx: &mut Context<Self>) {
 928        let language_registry = self.project.read(cx).languages().clone();
 929        self.add_tool(CopyPathTool::new(self.project.clone()));
 930        self.add_tool(CreateDirectoryTool::new(self.project.clone()));
 931        self.add_tool(DeletePathTool::new(
 932            self.project.clone(),
 933            self.action_log.clone(),
 934        ));
 935        self.add_tool(DiagnosticsTool::new(self.project.clone()));
 936        self.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
 937        self.add_tool(FetchTool::new(self.project.read(cx).client().http_client()));
 938        self.add_tool(FindPathTool::new(self.project.clone()));
 939        self.add_tool(GrepTool::new(self.project.clone()));
 940        self.add_tool(ListDirectoryTool::new(self.project.clone()));
 941        self.add_tool(MovePathTool::new(self.project.clone()));
 942        self.add_tool(NowTool);
 943        self.add_tool(OpenTool::new(self.project.clone()));
 944        self.add_tool(ReadFileTool::new(
 945            self.project.clone(),
 946            self.action_log.clone(),
 947        ));
 948        self.add_tool(TerminalTool::new(self.project.clone(), cx));
 949        self.add_tool(ThinkingTool);
 950        self.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
 951    }
 952
 953    pub fn add_tool(&mut self, tool: impl AgentTool) {
 954        self.tools.insert(tool.name(), tool.erase());
 955    }
 956
 957    pub fn remove_tool(&mut self, name: &str) -> bool {
 958        self.tools.remove(name).is_some()
 959    }
 960
 961    pub fn profile(&self) -> &AgentProfileId {
 962        &self.profile_id
 963    }
 964
 965    pub fn set_profile(&mut self, profile_id: AgentProfileId) {
 966        self.profile_id = profile_id;
 967    }
 968
 969    pub fn cancel(&mut self, cx: &mut Context<Self>) {
 970        if let Some(running_turn) = self.running_turn.take() {
 971            running_turn.cancel();
 972        }
 973        self.flush_pending_message(cx);
 974    }
 975
 976    fn update_token_usage(&mut self, update: language_model::TokenUsage, cx: &mut Context<Self>) {
 977        let Some(last_user_message) = self.last_user_message() else {
 978            return;
 979        };
 980
 981        self.request_token_usage
 982            .insert(last_user_message.id.clone(), update);
 983        cx.emit(TokenUsageUpdated(self.latest_token_usage()));
 984        cx.notify();
 985    }
 986
 987    pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
 988        self.cancel(cx);
 989        let Some(position) = self.messages.iter().position(
 990            |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
 991        ) else {
 992            return Err(anyhow!("Message not found"));
 993        };
 994
 995        for message in self.messages.drain(position..) {
 996            match message {
 997                Message::User(message) => {
 998                    self.request_token_usage.remove(&message.id);
 999                }
1000                Message::Agent(_) | Message::Resume => {}
1001            }
1002        }
1003        self.summary = None;
1004        cx.notify();
1005        Ok(())
1006    }
1007
1008    pub fn latest_token_usage(&self) -> Option<acp_thread::TokenUsage> {
1009        let last_user_message = self.last_user_message()?;
1010        let tokens = self.request_token_usage.get(&last_user_message.id)?;
1011        let model = self.model.clone()?;
1012
1013        Some(acp_thread::TokenUsage {
1014            max_tokens: model.max_token_count_for_mode(self.completion_mode.into()),
1015            used_tokens: tokens.total_tokens(),
1016        })
1017    }
1018
1019    pub fn resume(
1020        &mut self,
1021        cx: &mut Context<Self>,
1022    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
1023        anyhow::ensure!(
1024            self.tool_use_limit_reached,
1025            "can only resume after tool use limit is reached"
1026        );
1027
1028        self.messages.push(Message::Resume);
1029        cx.notify();
1030
1031        log::info!("Total messages in thread: {}", self.messages.len());
1032        self.run_turn(cx)
1033    }
1034
1035    /// Sending a message results in the model streaming a response, which could include tool calls.
1036    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
1037    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
1038    pub fn send<T>(
1039        &mut self,
1040        id: UserMessageId,
1041        content: impl IntoIterator<Item = T>,
1042        cx: &mut Context<Self>,
1043    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>
1044    where
1045        T: Into<UserMessageContent>,
1046    {
1047        let model = self.model().context("No language model configured")?;
1048
1049        log::info!("Thread::send called with model: {:?}", model.name());
1050        self.advance_prompt_id();
1051
1052        let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
1053        log::debug!("Thread::send content: {:?}", content);
1054
1055        self.messages
1056            .push(Message::User(UserMessage { id, content }));
1057        cx.notify();
1058
1059        log::info!("Total messages in thread: {}", self.messages.len());
1060        self.run_turn(cx)
1061    }
1062
    /// Runs one agent turn in a background task and returns the receiving end of
    /// its event channel.
    ///
    /// The turn loops as follows:
    /// 1. Build a completion request. The first iteration uses
    ///    `CompletionIntent::UserPrompt`; later ones use `ToolResults`.
    /// 2. Stream the request.
    /// 3. Wait for every tool call it produced.
    /// 4. Repeat until the model ends without tool calls, refuses, hits
    ///    max tokens, or reports the tool-use limit.
    ///
    /// Any running turn is cancelled first. The turn's outcome is delivered on the
    /// returned receiver as a stop reason or an error. On refusal, the message
    /// that triggered the turn (index `message_ix`) and everything after it are
    /// removed.
    ///
    /// # Errors
    /// Returns an error immediately if no model is configured.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
        self.cancel(cx);

        let model = self.model.clone().context("No language model configured")?;
        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
        let event_stream = ThreadEventStream(events_tx);
        // Index of the message that triggered this turn (the last one); used to
        // roll the history back if the model refuses.
        let message_ix = self.messages.len().saturating_sub(1);
        self.tool_use_limit_reached = false;
        self.summary = None;
        self.running_turn = Some(RunningTurn {
            event_stream: event_stream.clone(),
            _task: cx.spawn(async move |this, cx| {
                log::info!("Starting agent turn execution");
                let mut update_title = None;
                let turn_result: Result<StopReason> = async {
                    let mut completion_intent = CompletionIntent::UserPrompt;
                    loop {
                        log::debug!(
                            "Building completion request with intent: {:?}",
                            completion_intent
                        );
                        let request = this.update(cx, |this, cx| {
                            this.build_completion_request(completion_intent, cx)
                        })??;

                        log::info!("Calling model.stream_completion");

                        let mut tool_use_limit_reached = false;
                        let mut refused = false;
                        let mut reached_max_tokens = false;
                        let mut tool_uses = Self::stream_completion_with_retries(
                            this.clone(),
                            model.clone(),
                            request,
                            &event_stream,
                            &mut tool_use_limit_reached,
                            &mut refused,
                            &mut reached_max_tokens,
                            cx,
                        )
                        .await?;

                        // Refusal and max-token stops end the turn immediately; in
                        // those cases no tool tasks are returned.
                        if refused {
                            return Ok(StopReason::Refusal);
                        } else if reached_max_tokens {
                            return Ok(StopReason::MaxTokens);
                        }

                        // A completion with no tool calls means the model is done;
                        // otherwise wait for every tool and record its result.
                        let end_turn = tool_uses.is_empty();
                        while let Some(tool_result) = tool_uses.next().await {
                            log::info!("Tool finished {:?}", tool_result);

                            event_stream.update_tool_call_fields(
                                &tool_result.tool_use_id,
                                acp::ToolCallUpdateFields {
                                    status: Some(if tool_result.is_error {
                                        acp::ToolCallStatus::Failed
                                    } else {
                                        acp::ToolCallStatus::Completed
                                    }),
                                    raw_output: tool_result.output.clone(),
                                    ..Default::default()
                                },
                            );
                            this.update(cx, |this, _cx| {
                                this.pending_message()
                                    .tool_results
                                    .insert(tool_result.tool_use_id.clone(), tool_result);
                            })?;
                        }

                        // Start title generation at most once per turn, and only
                        // while the thread is still untitled.
                        this.update(cx, |this, cx| {
                            if this.title.is_none() && update_title.is_none() {
                                update_title = Some(this.update_title(&event_stream, cx));
                            }
                        })?;

                        if tool_use_limit_reached {
                            log::info!("Tool use limit reached, completing turn");
                            // Remember the limit so `resume` is permitted; the turn
                            // itself is reported as an error.
                            this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
                            return Err(language_model::ToolUseLimitReachedError.into());
                        } else if end_turn {
                            log::info!("No tool uses found, completing turn");
                            return Ok(StopReason::EndTurn);
                        } else {
                            // Move this round's message (with tool results) into the
                            // history and ask the model to continue with the results.
                            this.update(cx, |this, cx| this.flush_pending_message(cx))?;
                            completion_intent = CompletionIntent::ToolResults;
                        }
                    }
                }
                .await;
                // Flush whatever the agent produced, regardless of how the turn ended.
                _ = this.update(cx, |this, cx| this.flush_pending_message(cx));

                match turn_result {
                    Ok(reason) => {
                        log::info!("Turn execution completed: {:?}", reason);

                        // Only successful turns wait for the title task.
                        // NOTE(review): on the error path `update_title` is never
                        // awaited and is dropped with this future — confirm that
                        // abandoning title generation there is intended.
                        if let Some(update_title) = update_title {
                            update_title.await.context("update title failed").log_err();
                        }

                        event_stream.send_stop(reason);
                        if reason == StopReason::Refusal {
                            _ = this.update(cx, |this, _| this.messages.truncate(message_ix));
                        }
                    }
                    Err(error) => {
                        log::error!("Turn execution failed: {:?}", error);
                        event_stream.send_error(error);
                    }
                }

                // The turn is finished; clear it so a new one can be tracked.
                _ = this.update(cx, |this, _| this.running_turn.take());
            }),
        });
        Ok(events_rx)
    }
1183
1184    async fn stream_completion_with_retries(
1185        this: WeakEntity<Self>,
1186        model: Arc<dyn LanguageModel>,
1187        request: LanguageModelRequest,
1188        event_stream: &ThreadEventStream,
1189        tool_use_limit_reached: &mut bool,
1190        refusal: &mut bool,
1191        max_tokens_reached: &mut bool,
1192        cx: &mut AsyncApp,
1193    ) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
1194        log::debug!("Stream completion started successfully");
1195
1196        let mut attempt = None;
1197        'retry: loop {
1198            telemetry::event!(
1199                "Agent Thread Completion",
1200                thread_id = this.read_with(cx, |this, _| this.id.to_string())?,
1201                prompt_id = this.read_with(cx, |this, _| this.prompt_id.to_string())?,
1202                model = model.telemetry_id(),
1203                model_provider = model.provider_id().to_string(),
1204                attempt
1205            );
1206
1207            let mut events = model.stream_completion(request.clone(), cx).await?;
1208            let mut tool_uses = FuturesUnordered::new();
1209            while let Some(event) = events.next().await {
1210                match event {
1211                    Ok(LanguageModelCompletionEvent::StatusUpdate(
1212                        CompletionRequestStatus::ToolUseLimitReached,
1213                    )) => {
1214                        *tool_use_limit_reached = true;
1215                    }
1216                    Ok(LanguageModelCompletionEvent::StatusUpdate(
1217                        CompletionRequestStatus::UsageUpdated { amount, limit },
1218                    )) => {
1219                        this.update(cx, |this, cx| {
1220                            this.update_model_request_usage(amount, limit, cx)
1221                        })?;
1222                    }
1223                    Ok(LanguageModelCompletionEvent::UsageUpdate(usage)) => {
1224                        telemetry::event!(
1225                            "Agent Thread Completion Usage Updated",
1226                            thread_id = this.read_with(cx, |this, _| this.id.to_string())?,
1227                            prompt_id = this.read_with(cx, |this, _| this.prompt_id.to_string())?,
1228                            model = model.telemetry_id(),
1229                            model_provider = model.provider_id().to_string(),
1230                            attempt,
1231                            input_tokens = usage.input_tokens,
1232                            output_tokens = usage.output_tokens,
1233                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1234                            cache_read_input_tokens = usage.cache_read_input_tokens,
1235                        );
1236
1237                        this.update(cx, |this, cx| this.update_token_usage(usage, cx))?;
1238                    }
1239                    Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
1240                        *refusal = true;
1241                        return Ok(FuturesUnordered::default());
1242                    }
1243                    Ok(LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)) => {
1244                        *max_tokens_reached = true;
1245                        return Ok(FuturesUnordered::default());
1246                    }
1247                    Ok(LanguageModelCompletionEvent::Stop(
1248                        StopReason::ToolUse | StopReason::EndTurn,
1249                    )) => break,
1250                    Ok(event) => {
1251                        log::trace!("Received completion event: {:?}", event);
1252                        this.update(cx, |this, cx| {
1253                            tool_uses.extend(this.handle_streamed_completion_event(
1254                                event,
1255                                event_stream,
1256                                cx,
1257                            ));
1258                        })?;
1259                    }
1260                    Err(error) => {
1261                        let completion_mode =
1262                            this.read_with(cx, |thread, _cx| thread.completion_mode())?;
1263                        if completion_mode == CompletionMode::Normal {
1264                            return Err(error.into());
1265                        }
1266
1267                        let Some(strategy) = Self::retry_strategy_for(&error) else {
1268                            return Err(error.into());
1269                        };
1270
1271                        let max_attempts = match &strategy {
1272                            RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
1273                            RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
1274                        };
1275
1276                        let attempt = attempt.get_or_insert(0u8);
1277
1278                        *attempt += 1;
1279
1280                        let attempt = *attempt;
1281                        if attempt > max_attempts {
1282                            return Err(error.into());
1283                        }
1284
1285                        let delay = match &strategy {
1286                            RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
1287                                let delay_secs =
1288                                    initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
1289                                Duration::from_secs(delay_secs)
1290                            }
1291                            RetryStrategy::Fixed { delay, .. } => *delay,
1292                        };
1293                        log::debug!("Retry attempt {attempt} with delay {delay:?}");
1294
1295                        event_stream.send_retry(acp_thread::RetryStatus {
1296                            last_error: error.to_string().into(),
1297                            attempt: attempt as usize,
1298                            max_attempts: max_attempts as usize,
1299                            started_at: Instant::now(),
1300                            duration: delay,
1301                        });
1302
1303                        cx.background_executor().timer(delay).await;
1304                        continue 'retry;
1305                    }
1306                }
1307            }
1308
1309            return Ok(tool_uses);
1310        }
1311    }
1312
1313    pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
1314        log::debug!("Building system message");
1315        let prompt = SystemPromptTemplate {
1316            project: self.project_context.read(cx),
1317            available_tools: self.tools.keys().cloned().collect(),
1318        }
1319        .render(&self.templates)
1320        .context("failed to build system prompt")
1321        .expect("Invalid template");
1322        log::debug!("System message built");
1323        LanguageModelRequestMessage {
1324            role: Role::System,
1325            content: vec![prompt.into()],
1326            cache: true,
1327        }
1328    }
1329
1330    /// A helper method that's called on every streamed completion event.
1331    /// Returns an optional tool result task, which the main agentic loop in
1332    /// send will send back to the model when it resolves.
1333    fn handle_streamed_completion_event(
1334        &mut self,
1335        event: LanguageModelCompletionEvent,
1336        event_stream: &ThreadEventStream,
1337        cx: &mut Context<Self>,
1338    ) -> Option<Task<LanguageModelToolResult>> {
1339        log::trace!("Handling streamed completion event: {:?}", event);
1340        use LanguageModelCompletionEvent::*;
1341
1342        match event {
1343            StartMessage { .. } => {
1344                self.flush_pending_message(cx);
1345                self.pending_message = Some(AgentMessage::default());
1346            }
1347            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
1348            Thinking { text, signature } => {
1349                self.handle_thinking_event(text, signature, event_stream, cx)
1350            }
1351            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
1352            ToolUse(tool_use) => {
1353                return self.handle_tool_use_event(tool_use, event_stream, cx);
1354            }
1355            ToolUseJsonParseError {
1356                id,
1357                tool_name,
1358                raw_input,
1359                json_parse_error,
1360            } => {
1361                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
1362                    id,
1363                    tool_name,
1364                    raw_input,
1365                    json_parse_error,
1366                )));
1367            }
1368            StatusUpdate(_) => {}
1369            UsageUpdate(_) | Stop(_) => unreachable!(),
1370        }
1371
1372        None
1373    }
1374
1375    fn handle_text_event(
1376        &mut self,
1377        new_text: String,
1378        event_stream: &ThreadEventStream,
1379        cx: &mut Context<Self>,
1380    ) {
1381        event_stream.send_text(&new_text);
1382
1383        let last_message = self.pending_message();
1384        if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
1385            text.push_str(&new_text);
1386        } else {
1387            last_message
1388                .content
1389                .push(AgentMessageContent::Text(new_text));
1390        }
1391
1392        cx.notify();
1393    }
1394
1395    fn handle_thinking_event(
1396        &mut self,
1397        new_text: String,
1398        new_signature: Option<String>,
1399        event_stream: &ThreadEventStream,
1400        cx: &mut Context<Self>,
1401    ) {
1402        event_stream.send_thinking(&new_text);
1403
1404        let last_message = self.pending_message();
1405        if let Some(AgentMessageContent::Thinking { text, signature }) =
1406            last_message.content.last_mut()
1407        {
1408            text.push_str(&new_text);
1409            *signature = new_signature.or(signature.take());
1410        } else {
1411            last_message.content.push(AgentMessageContent::Thinking {
1412                text: new_text,
1413                signature: new_signature,
1414            });
1415        }
1416
1417        cx.notify();
1418    }
1419
1420    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
1421        let last_message = self.pending_message();
1422        last_message
1423            .content
1424            .push(AgentMessageContent::RedactedThinking(data));
1425        cx.notify();
1426    }
1427
    /// Records a (possibly partial) tool use in the pending agent message and,
    /// once its input is complete, starts running the tool.
    ///
    /// If the pending message already ends with a tool use with the same id, that
    /// entry is replaced and the existing tool call's title, kind and raw input
    /// are updated. Otherwise a new tool call is emitted and the tool use is
    /// appended.
    ///
    /// Returns `None` while the input is still streaming. Otherwise it returns a
    /// task resolving to the tool result. The result is an error when:
    /// - no tool with that name is registered,
    /// - the tool itself fails,
    /// - the tool produces an image for a model that does not support images.
    fn handle_tool_use_event(
        &mut self,
        tool_use: LanguageModelToolUse,
        event_stream: &ThreadEventStream,
        cx: &mut Context<Self>,
    ) -> Option<Task<LanguageModelToolResult>> {
        cx.notify();

        // Unknown tools still get a tool call, titled with the raw tool name and
        // kind `Other`.
        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
        let mut title = SharedString::from(&tool_use.name);
        let mut kind = acp::ToolKind::Other;
        if let Some(tool) = tool.as_ref() {
            title = tool.initial_title(tool_use.input.clone());
            kind = tool.kind();
        }

        // Ensure the pending message ends with this tool use: replace it in place
        // if it is already the last item (same id), otherwise append it below.
        let last_message = self.pending_message();
        let push_new_tool_use = last_message.content.last_mut().is_none_or(|content| {
            if let AgentMessageContent::ToolUse(last_tool_use) = content {
                if last_tool_use.id == tool_use.id {
                    *last_tool_use = tool_use.clone();
                    false
                } else {
                    true
                }
            } else {
                true
            }
        });

        if push_new_tool_use {
            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
            last_message
                .content
                .push(AgentMessageContent::ToolUse(tool_use.clone()));
        } else {
            event_stream.update_tool_call_fields(
                &tool_use.id,
                acp::ToolCallUpdateFields {
                    title: Some(title.into()),
                    kind: Some(kind),
                    raw_input: Some(tool_use.input.clone()),
                    ..Default::default()
                },
            );
        }

        // Only run the tool once the model has finished streaming its input.
        if !tool_use.is_input_complete {
            return None;
        }

        let Some(tool) = tool else {
            let content = format!("No tool named {} exists", tool_use.name);
            return Some(Task::ready(LanguageModelToolResult {
                content: LanguageModelToolResultContent::Text(Arc::from(content)),
                tool_use_id: tool_use.id,
                tool_name: tool_use.name,
                is_error: true,
                output: None,
            }));
        };

        let fs = self.project.read(cx).fs().clone();
        let tool_event_stream =
            ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
        // Mark the call as in progress before the tool starts running.
        tool_event_stream.update_fields(acp::ToolCallUpdateFields {
            status: Some(acp::ToolCallStatus::InProgress),
            ..Default::default()
        });
        // Captured before spawning so the result check does not need `self`.
        let supports_images = self.model().is_some_and(|model| model.supports_images());
        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
        log::info!("Running tool {}", tool_use.name);
        Some(cx.foreground_executor().spawn(async move {
            // Reject image output for models that cannot accept images.
            let tool_result = tool_result.await.and_then(|output| {
                if let LanguageModelToolResultContent::Image(_) = &output.llm_output
                    && !supports_images
                {
                    return Err(anyhow!(
                        "Attempted to read an image, but this model doesn't support it.",
                    ));
                }
                Ok(output)
            });

            match tool_result {
                Ok(output) => LanguageModelToolResult {
                    tool_use_id: tool_use.id,
                    tool_name: tool_use.name,
                    is_error: false,
                    content: output.llm_output,
                    output: Some(output.raw_output),
                },
                Err(error) => LanguageModelToolResult {
                    tool_use_id: tool_use.id,
                    tool_name: tool_use.name,
                    is_error: true,
                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
                    output: None,
                },
            }
        }))
    }
1531
1532    fn handle_tool_use_json_parse_error_event(
1533        &mut self,
1534        tool_use_id: LanguageModelToolUseId,
1535        tool_name: Arc<str>,
1536        raw_input: Arc<str>,
1537        json_parse_error: String,
1538    ) -> LanguageModelToolResult {
1539        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1540        LanguageModelToolResult {
1541            tool_use_id,
1542            tool_name,
1543            is_error: true,
1544            content: LanguageModelToolResultContent::Text(tool_output.into()),
1545            output: Some(serde_json::Value::String(raw_input.to_string())),
1546        }
1547    }
1548
1549    fn update_model_request_usage(&self, amount: usize, limit: UsageLimit, cx: &mut Context<Self>) {
1550        self.project
1551            .read(cx)
1552            .user_store()
1553            .update(cx, |user_store, cx| {
1554                user_store.update_model_request_usage(
1555                    ModelRequestUsage(RequestUsage {
1556                        amount: amount as i32,
1557                        limit,
1558                    }),
1559                    cx,
1560                )
1561            });
1562    }
1563
1564    pub fn title(&self) -> SharedString {
1565        self.title.clone().unwrap_or("New Thread".into())
1566    }
1567
1568    pub fn summary(&mut self, cx: &mut Context<Self>) -> Task<Result<SharedString>> {
1569        if let Some(summary) = self.summary.as_ref() {
1570            return Task::ready(Ok(summary.clone()));
1571        }
1572        let Some(model) = self.summarization_model.clone() else {
1573            return Task::ready(Err(anyhow!("No summarization model available")));
1574        };
1575        let mut request = LanguageModelRequest {
1576            intent: Some(CompletionIntent::ThreadContextSummarization),
1577            temperature: AgentSettings::temperature_for_model(&model, cx),
1578            ..Default::default()
1579        };
1580
1581        for message in &self.messages {
1582            request.messages.extend(message.to_request());
1583        }
1584
1585        request.messages.push(LanguageModelRequestMessage {
1586            role: Role::User,
1587            content: vec![SUMMARIZE_THREAD_DETAILED_PROMPT.into()],
1588            cache: false,
1589        });
1590        cx.spawn(async move |this, cx| {
1591            let mut summary = String::new();
1592            let mut messages = model.stream_completion(request, cx).await?;
1593            while let Some(event) = messages.next().await {
1594                let event = event?;
1595                let text = match event {
1596                    LanguageModelCompletionEvent::Text(text) => text,
1597                    LanguageModelCompletionEvent::StatusUpdate(
1598                        CompletionRequestStatus::UsageUpdated { amount, limit },
1599                    ) => {
1600                        this.update(cx, |thread, cx| {
1601                            thread.update_model_request_usage(amount, limit, cx);
1602                        })?;
1603                        continue;
1604                    }
1605                    _ => continue,
1606                };
1607
1608                let mut lines = text.lines();
1609                summary.extend(lines.next());
1610            }
1611
1612            log::info!("Setting summary: {}", summary);
1613            let summary = SharedString::from(summary);
1614
1615            this.update(cx, |this, cx| {
1616                this.summary = Some(summary.clone());
1617                cx.notify()
1618            })?;
1619
1620            Ok(summary)
1621        })
1622    }
1623
1624    fn update_title(
1625        &mut self,
1626        event_stream: &ThreadEventStream,
1627        cx: &mut Context<Self>,
1628    ) -> Task<Result<()>> {
1629        log::info!(
1630            "Generating title with model: {:?}",
1631            self.summarization_model.as_ref().map(|model| model.name())
1632        );
1633        let Some(model) = self.summarization_model.clone() else {
1634            return Task::ready(Ok(()));
1635        };
1636        let event_stream = event_stream.clone();
1637        let mut request = LanguageModelRequest {
1638            intent: Some(CompletionIntent::ThreadSummarization),
1639            temperature: AgentSettings::temperature_for_model(&model, cx),
1640            ..Default::default()
1641        };
1642
1643        for message in &self.messages {
1644            request.messages.extend(message.to_request());
1645        }
1646
1647        request.messages.push(LanguageModelRequestMessage {
1648            role: Role::User,
1649            content: vec![SUMMARIZE_THREAD_PROMPT.into()],
1650            cache: false,
1651        });
1652        cx.spawn(async move |this, cx| {
1653            let mut title = String::new();
1654            let mut messages = model.stream_completion(request, cx).await?;
1655            while let Some(event) = messages.next().await {
1656                let event = event?;
1657                let text = match event {
1658                    LanguageModelCompletionEvent::Text(text) => text,
1659                    LanguageModelCompletionEvent::StatusUpdate(
1660                        CompletionRequestStatus::UsageUpdated { amount, limit },
1661                    ) => {
1662                        this.update(cx, |thread, cx| {
1663                            thread.update_model_request_usage(amount, limit, cx);
1664                        })?;
1665                        continue;
1666                    }
1667                    _ => continue,
1668                };
1669
1670                let mut lines = text.lines();
1671                title.extend(lines.next());
1672
1673                // Stop if the LLM generated multiple lines.
1674                if lines.next().is_some() {
1675                    break;
1676                }
1677            }
1678
1679            log::info!("Setting title: {}", title);
1680
1681            this.update(cx, |this, cx| {
1682                let title = SharedString::from(title);
1683                event_stream.send_title_update(title.clone());
1684                this.title = Some(title);
1685                cx.notify();
1686            })
1687        })
1688    }
1689
1690    fn last_user_message(&self) -> Option<&UserMessage> {
1691        self.messages
1692            .iter()
1693            .rev()
1694            .find_map(|message| match message {
1695                Message::User(user_message) => Some(user_message),
1696                Message::Agent(_) => None,
1697                Message::Resume => None,
1698            })
1699    }
1700
1701    fn pending_message(&mut self) -> &mut AgentMessage {
1702        self.pending_message.get_or_insert_default()
1703    }
1704
1705    fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
1706        let Some(mut message) = self.pending_message.take() else {
1707            return;
1708        };
1709
1710        for content in &message.content {
1711            let AgentMessageContent::ToolUse(tool_use) = content else {
1712                continue;
1713            };
1714
1715            if !message.tool_results.contains_key(&tool_use.id) {
1716                message.tool_results.insert(
1717                    tool_use.id.clone(),
1718                    LanguageModelToolResult {
1719                        tool_use_id: tool_use.id.clone(),
1720                        tool_name: tool_use.name.clone(),
1721                        is_error: true,
1722                        content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1723                        output: None,
1724                    },
1725                );
1726            }
1727        }
1728
1729        self.messages.push(Message::Agent(message));
1730        self.updated_at = Utc::now();
1731        self.summary = None;
1732        cx.notify()
1733    }
1734
1735    pub(crate) fn build_completion_request(
1736        &self,
1737        completion_intent: CompletionIntent,
1738        cx: &mut App,
1739    ) -> Result<LanguageModelRequest> {
1740        let model = self.model().context("No language model configured")?;
1741
1742        log::debug!("Building completion request");
1743        log::debug!("Completion intent: {:?}", completion_intent);
1744        log::debug!("Completion mode: {:?}", self.completion_mode);
1745
1746        let messages = self.build_request_messages(cx);
1747        log::info!("Request will include {} messages", messages.len());
1748
1749        let tools = if let Some(tools) = self.tools(cx).log_err() {
1750            tools
1751                .filter_map(|tool| {
1752                    let tool_name = tool.name().to_string();
1753                    log::trace!("Including tool: {}", tool_name);
1754                    Some(LanguageModelRequestTool {
1755                        name: tool_name,
1756                        description: tool.description().to_string(),
1757                        input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
1758                    })
1759                })
1760                .collect()
1761        } else {
1762            Vec::new()
1763        };
1764
1765        log::info!("Request includes {} tools", tools.len());
1766
1767        let request = LanguageModelRequest {
1768            thread_id: Some(self.id.to_string()),
1769            prompt_id: Some(self.prompt_id.to_string()),
1770            intent: Some(completion_intent),
1771            mode: Some(self.completion_mode.into()),
1772            messages,
1773            tools,
1774            tool_choice: None,
1775            stop: Vec::new(),
1776            temperature: AgentSettings::temperature_for_model(model, cx),
1777            thinking_allowed: true,
1778        };
1779
1780        log::debug!("Completion request built successfully");
1781        Ok(request)
1782    }
1783
1784    fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
1785        let model = self.model().context("No language model configured")?;
1786
1787        let profile = AgentSettings::get_global(cx)
1788            .profiles
1789            .get(&self.profile_id)
1790            .context("profile not found")?;
1791        let provider_id = model.provider_id();
1792
1793        Ok(self
1794            .tools
1795            .iter()
1796            .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1797            .filter_map(|(tool_name, tool)| {
1798                if profile.is_tool_enabled(tool_name) {
1799                    Some(tool)
1800                } else {
1801                    None
1802                }
1803            })
1804            .chain(self.context_server_registry.read(cx).servers().flat_map(
1805                |(server_id, tools)| {
1806                    tools.iter().filter_map(|(tool_name, tool)| {
1807                        if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1808                            Some(tool)
1809                        } else {
1810                            None
1811                        }
1812                    })
1813                },
1814            )))
1815    }
1816
1817    fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
1818        log::trace!(
1819            "Building request messages from {} thread messages",
1820            self.messages.len()
1821        );
1822        let mut messages = vec![self.build_system_message(cx)];
1823        for message in &self.messages {
1824            messages.extend(message.to_request());
1825        }
1826
1827        if let Some(message) = self.pending_message.as_ref() {
1828            messages.extend(message.to_request());
1829        }
1830
1831        if let Some(last_user_message) = messages
1832            .iter_mut()
1833            .rev()
1834            .find(|message| message.role == Role::User)
1835        {
1836            last_user_message.cache = true;
1837        }
1838
1839        messages
1840    }
1841
1842    pub fn to_markdown(&self) -> String {
1843        let mut markdown = String::new();
1844        for (ix, message) in self.messages.iter().enumerate() {
1845            if ix > 0 {
1846                markdown.push('\n');
1847            }
1848            markdown.push_str(&message.to_markdown());
1849        }
1850
1851        if let Some(message) = self.pending_message.as_ref() {
1852            markdown.push('\n');
1853            markdown.push_str(&message.to_markdown());
1854        }
1855
1856        markdown
1857    }
1858
1859    fn advance_prompt_id(&mut self) {
1860        self.prompt_id = PromptId::new();
1861    }
1862
    /// Decides whether, and how, a failed completion request should be retried.
    ///
    /// Returns `None` when retrying cannot help; otherwise returns the delay
    /// policy and `max_attempts` to use for this error.
    fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
        use LanguageModelCompletionError::*;
        use http_client::StatusCode;

        // General strategy here:
        // - If retrying won't help (e.g. invalid API key, payload too large, payment required), return None so we don't retry at all.
        // - HTTP 429 responses use exponential backoff; other rate-limit / overload errors use a fixed delay
        //   (the server's `retry_after` when provided). Both allow MAX_RETRY_ATTEMPTS.
        // - Errors that *might* be fixed by retrying (e.g. internal server errors) get a small, fixed number of attempts.
        //
        // NOTE: match arms are tried top to bottom, so the specific `HttpResponseError` arms must stay
        // above the guarded and catch-all `HttpResponseError` arms further down.
        match error {
            HttpResponseError {
                status_code: StatusCode::TOO_MANY_REQUESTS,
                ..
            } => Some(RetryStrategy::ExponentialBackoff {
                initial_delay: BASE_RETRY_DELAY,
                max_attempts: MAX_RETRY_ATTEMPTS,
            }),
            ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
                Some(RetryStrategy::Fixed {
                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                    max_attempts: MAX_RETRY_ATTEMPTS,
                })
            }
            UpstreamProviderError {
                status,
                retry_after,
                ..
            } => match *status {
                StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
                    Some(RetryStrategy::Fixed {
                        delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                        max_attempts: MAX_RETRY_ATTEMPTS,
                    })
                }
                StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                    // Internal Server Error could be anything, retry up to 3 times.
                    max_attempts: 3,
                }),
                status => {
                    // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
                    // but we frequently get them in practice. See https://http.dev/529
                    // Any other upstream status gets `max_attempts: 2`.
                    if status.as_u16() == 529 {
                        Some(RetryStrategy::Fixed {
                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                            max_attempts: MAX_RETRY_ATTEMPTS,
                        })
                    } else {
                        Some(RetryStrategy::Fixed {
                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                            max_attempts: 2,
                        })
                    }
                }
            },
            ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 3,
            }),
            // Transport / response-decoding problems (and bad request format) get `max_attempts: 3`.
            ApiReadResponseError { .. }
            | HttpSend { .. }
            | DeserializeResponse { .. }
            | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 3,
            }),
            // Retrying these errors definitely shouldn't help.
            HttpResponseError {
                status_code:
                    StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
                ..
            }
            | AuthenticationError { .. }
            | PermissionError { .. }
            | NoApiKey { .. }
            | ApiEndpointNotFound { .. }
            | PromptTooLarge { .. } => None,
            // These errors might be transient, so allow a single attempt (`max_attempts: 1`).
            SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 1,
            }),
            // Any other 4xx / 5xx HTTP response gets `max_attempts: 3`.
            HttpResponseError { status_code, .. }
                if status_code.is_client_error() || status_code.is_server_error() =>
            {
                Some(RetryStrategy::Fixed {
                    delay: BASE_RETRY_DELAY,
                    max_attempts: 3,
                })
            }
            Other(err)
                if err.is::<language_model::PaymentRequiredError>()
                    || err.is::<language_model::ModelRequestLimitReachedError>() =>
            {
                // Retrying won't help for Payment Required or Model Request Limit errors (where
                // the user must upgrade to usage-based billing to get more requests, or else wait
                // for a significant amount of time for the request limit to reset).
                None
            }
            // Everything else (non-4xx/5xx HTTP responses and unclassified errors) is still
            // retried, but conservatively: `max_attempts: 2`.
            HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 2,
            }),
        }
    }
1969}
1970
1971struct RunningTurn {
1972    /// Holds the task that handles agent interaction until the end of the turn.
1973    /// Survives across multiple requests as the model performs tool calls and
1974    /// we run tools, report their results.
1975    _task: Task<()>,
1976    /// The current event stream for the running turn. Used to report a final
1977    /// cancellation event if we cancel the turn.
1978    event_stream: ThreadEventStream,
1979}
1980
1981impl RunningTurn {
1982    fn cancel(self) {
1983        log::debug!("Cancelling in progress turn");
1984        self.event_stream.send_canceled();
1985    }
1986}
1987
/// Event a [`Thread`] can emit (see the `EventEmitter` impl below) carrying its
/// token usage. NOTE(review): `None` presumably means "usage unknown / cleared";
/// confirm against the emitting code.
pub struct TokenUsageUpdated(pub Option<acp_thread::TokenUsage>);

impl EventEmitter<TokenUsageUpdated> for Thread {}
1991
1992pub trait AgentTool
1993where
1994    Self: 'static + Sized,
1995{
1996    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1997    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1998
1999    fn name(&self) -> SharedString;
2000
2001    fn description(&self) -> SharedString {
2002        let schema = schemars::schema_for!(Self::Input);
2003        SharedString::new(
2004            schema
2005                .get("description")
2006                .and_then(|description| description.as_str())
2007                .unwrap_or_default(),
2008        )
2009    }
2010
2011    fn kind(&self) -> acp::ToolKind;
2012
2013    /// The initial tool title to display. Can be updated during the tool run.
2014    fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
2015
2016    /// Returns the JSON schema that describes the tool's input.
2017    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Schema {
2018        crate::tool_schema::root_schema_for::<Self::Input>(format)
2019    }
2020
2021    /// Some tools rely on a provider for the underlying billing or other reasons.
2022    /// Allow the tool to check if they are compatible, or should be filtered out.
2023    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
2024        true
2025    }
2026
2027    /// Runs the tool with the provided input.
2028    fn run(
2029        self: Arc<Self>,
2030        input: Self::Input,
2031        event_stream: ToolCallEventStream,
2032        cx: &mut App,
2033    ) -> Task<Result<Self::Output>>;
2034
2035    /// Emits events for a previous execution of the tool.
2036    fn replay(
2037        &self,
2038        _input: Self::Input,
2039        _output: Self::Output,
2040        _event_stream: ToolCallEventStream,
2041        _cx: &mut App,
2042    ) -> Result<()> {
2043        Ok(())
2044    }
2045
2046    fn erase(self) -> Arc<dyn AnyAgentTool> {
2047        Arc::new(Erased(Arc::new(self)))
2048    }
2049}
2050
/// Type-erasing wrapper. `Erased<Arc<T>>` implements [`AnyAgentTool`] for any
/// [`AgentTool`] `T` by converting its input/output to and from JSON values.
pub struct Erased<T>(T);
2052
/// Result of running a type-erased tool.
pub struct AgentToolOutput {
    /// Content returned to the model as the tool result.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's typed output serialized to JSON (used for replay / persistence).
    pub raw_output: serde_json::Value,
}
2057
/// Object-safe, JSON-based interface to an agent tool; see [`AgentTool::erase`].
pub trait AnyAgentTool {
    /// The name the tool is exposed under.
    fn name(&self) -> SharedString;
    /// Description of the tool shown to the model.
    fn description(&self) -> SharedString;
    /// The ACP kind reported for calls to this tool.
    fn kind(&self) -> acp::ToolKind;
    /// Initial title for a call with the given raw JSON input.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// JSON schema of the tool input, adapted to `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether the tool may be offered for models from `_provider`. Defaults to `true`.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool on raw JSON input.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
    /// Re-emits events for a previous call, given its raw JSON input and output.
    fn replay(
        &self,
        input: serde_json::Value,
        output: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Result<()>;
}
2081
2082impl<T> AnyAgentTool for Erased<Arc<T>>
2083where
2084    T: AgentTool,
2085{
2086    fn name(&self) -> SharedString {
2087        self.0.name()
2088    }
2089
2090    fn description(&self) -> SharedString {
2091        self.0.description()
2092    }
2093
2094    fn kind(&self) -> agent_client_protocol::ToolKind {
2095        self.0.kind()
2096    }
2097
2098    fn initial_title(&self, input: serde_json::Value) -> SharedString {
2099        let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
2100        self.0.initial_title(parsed_input)
2101    }
2102
2103    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
2104        let mut json = serde_json::to_value(self.0.input_schema(format))?;
2105        adapt_schema_to_format(&mut json, format)?;
2106        Ok(json)
2107    }
2108
2109    fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
2110        self.0.supported_provider(provider)
2111    }
2112
2113    fn run(
2114        self: Arc<Self>,
2115        input: serde_json::Value,
2116        event_stream: ToolCallEventStream,
2117        cx: &mut App,
2118    ) -> Task<Result<AgentToolOutput>> {
2119        cx.spawn(async move |cx| {
2120            let input = serde_json::from_value(input)?;
2121            let output = cx
2122                .update(|cx| self.0.clone().run(input, event_stream, cx))?
2123                .await?;
2124            let raw_output = serde_json::to_value(&output)?;
2125            Ok(AgentToolOutput {
2126                llm_output: output.into(),
2127                raw_output,
2128            })
2129        })
2130    }
2131
2132    fn replay(
2133        &self,
2134        input: serde_json::Value,
2135        output: serde_json::Value,
2136        event_stream: ToolCallEventStream,
2137        cx: &mut App,
2138    ) -> Result<()> {
2139        let input = serde_json::from_value(input)?;
2140        let output = serde_json::from_value(output)?;
2141        self.0.replay(input, output, event_stream, cx)
2142    }
2143}
2144
2145#[derive(Clone)]
2146struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
2147
2148impl ThreadEventStream {
2149    fn send_title_update(&self, text: SharedString) {
2150        self.0
2151            .unbounded_send(Ok(ThreadEvent::TitleUpdate(text)))
2152            .ok();
2153    }
2154
2155    fn send_user_message(&self, message: &UserMessage) {
2156        self.0
2157            .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
2158            .ok();
2159    }
2160
2161    fn send_text(&self, text: &str) {
2162        self.0
2163            .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
2164            .ok();
2165    }
2166
2167    fn send_thinking(&self, text: &str) {
2168        self.0
2169            .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
2170            .ok();
2171    }
2172
2173    fn send_tool_call(
2174        &self,
2175        id: &LanguageModelToolUseId,
2176        title: SharedString,
2177        kind: acp::ToolKind,
2178        input: serde_json::Value,
2179    ) {
2180        self.0
2181            .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
2182                id,
2183                title.to_string(),
2184                kind,
2185                input,
2186            ))))
2187            .ok();
2188    }
2189
2190    fn initial_tool_call(
2191        id: &LanguageModelToolUseId,
2192        title: String,
2193        kind: acp::ToolKind,
2194        input: serde_json::Value,
2195    ) -> acp::ToolCall {
2196        acp::ToolCall {
2197            id: acp::ToolCallId(id.to_string().into()),
2198            title,
2199            kind,
2200            status: acp::ToolCallStatus::Pending,
2201            content: vec![],
2202            locations: vec![],
2203            raw_input: Some(input),
2204            raw_output: None,
2205        }
2206    }
2207
2208    fn update_tool_call_fields(
2209        &self,
2210        tool_use_id: &LanguageModelToolUseId,
2211        fields: acp::ToolCallUpdateFields,
2212    ) {
2213        self.0
2214            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2215                acp::ToolCallUpdate {
2216                    id: acp::ToolCallId(tool_use_id.to_string().into()),
2217                    fields,
2218                }
2219                .into(),
2220            )))
2221            .ok();
2222    }
2223
2224    fn send_retry(&self, status: acp_thread::RetryStatus) {
2225        self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
2226    }
2227
2228    fn send_stop(&self, reason: StopReason) {
2229        match reason {
2230            StopReason::EndTurn => {
2231                self.0
2232                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
2233                    .ok();
2234            }
2235            StopReason::MaxTokens => {
2236                self.0
2237                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
2238                    .ok();
2239            }
2240            StopReason::Refusal => {
2241                self.0
2242                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
2243                    .ok();
2244            }
2245            StopReason::ToolUse => {}
2246        }
2247    }
2248
2249    fn send_canceled(&self) {
2250        self.0
2251            .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
2252            .ok();
2253    }
2254
2255    fn send_error(&self, error: impl Into<anyhow::Error>) {
2256        self.0.unbounded_send(Err(error.into())).ok();
2257    }
2258}
2259
2260#[derive(Clone)]
2261pub struct ToolCallEventStream {
2262    tool_use_id: LanguageModelToolUseId,
2263    stream: ThreadEventStream,
2264    fs: Option<Arc<dyn Fs>>,
2265}
2266
2267impl ToolCallEventStream {
2268    #[cfg(test)]
2269    pub fn test() -> (Self, ToolCallEventStreamReceiver) {
2270        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
2271
2272        let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
2273
2274        (stream, ToolCallEventStreamReceiver(events_rx))
2275    }
2276
2277    fn new(
2278        tool_use_id: LanguageModelToolUseId,
2279        stream: ThreadEventStream,
2280        fs: Option<Arc<dyn Fs>>,
2281    ) -> Self {
2282        Self {
2283            tool_use_id,
2284            stream,
2285            fs,
2286        }
2287    }
2288
2289    pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
2290        self.stream
2291            .update_tool_call_fields(&self.tool_use_id, fields);
2292    }
2293
2294    pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
2295        self.stream
2296            .0
2297            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2298                acp_thread::ToolCallUpdateDiff {
2299                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2300                    diff,
2301                }
2302                .into(),
2303            )))
2304            .ok();
2305    }
2306
2307    pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
2308        self.stream
2309            .0
2310            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2311                acp_thread::ToolCallUpdateTerminal {
2312                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2313                    terminal,
2314                }
2315                .into(),
2316            )))
2317            .ok();
2318    }
2319
2320    pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
2321        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
2322            return Task::ready(Ok(()));
2323        }
2324
2325        let (response_tx, response_rx) = oneshot::channel();
2326        self.stream
2327            .0
2328            .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
2329                ToolCallAuthorization {
2330                    tool_call: acp::ToolCallUpdate {
2331                        id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2332                        fields: acp::ToolCallUpdateFields {
2333                            title: Some(title.into()),
2334                            ..Default::default()
2335                        },
2336                    },
2337                    options: vec![
2338                        acp::PermissionOption {
2339                            id: acp::PermissionOptionId("always_allow".into()),
2340                            name: "Always Allow".into(),
2341                            kind: acp::PermissionOptionKind::AllowAlways,
2342                        },
2343                        acp::PermissionOption {
2344                            id: acp::PermissionOptionId("allow".into()),
2345                            name: "Allow".into(),
2346                            kind: acp::PermissionOptionKind::AllowOnce,
2347                        },
2348                        acp::PermissionOption {
2349                            id: acp::PermissionOptionId("deny".into()),
2350                            name: "Deny".into(),
2351                            kind: acp::PermissionOptionKind::RejectOnce,
2352                        },
2353                    ],
2354                    response: response_tx,
2355                },
2356            )))
2357            .ok();
2358        let fs = self.fs.clone();
2359        cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
2360            "always_allow" => {
2361                if let Some(fs) = fs.clone() {
2362                    cx.update(|cx| {
2363                        update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
2364                            settings.set_always_allow_tool_actions(true);
2365                        });
2366                    })?;
2367                }
2368
2369                Ok(())
2370            }
2371            "allow" => Ok(()),
2372            _ => Err(anyhow!("Permission to run tool denied by user")),
2373        })
2374    }
2375}
2376
/// Test-only receiving end of a [`ToolCallEventStream::test`] stream, with
/// helpers that assert on the next event.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);

#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it as a tool-call authorization request.
    ///
    /// # Panics
    /// Panics if the next event is anything else, including an error or the end
    /// of the stream.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) => auth,
            event => panic!("Expected ToolCallAuthorization but got: {:?}", event),
        }
    }

    /// Awaits the next event and returns the terminal it attaches to the tool call.
    ///
    /// # Panics
    /// Panics if the next event is not a terminal tool-call update.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            event => panic!("Expected terminal but got: {:?}", event),
        }
    }
}

/// Lets tests use the underlying receiver (e.g. its stream methods) directly.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;

    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}

#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
2419
2420impl From<&str> for UserMessageContent {
2421    fn from(text: &str) -> Self {
2422        Self::Text(text.into())
2423    }
2424}
2425
2426impl From<acp::ContentBlock> for UserMessageContent {
2427    fn from(value: acp::ContentBlock) -> Self {
2428        match value {
2429            acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
2430            acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
2431            acp::ContentBlock::Audio(_) => {
2432                // TODO
2433                Self::Text("[audio]".to_string())
2434            }
2435            acp::ContentBlock::ResourceLink(resource_link) => {
2436                match MentionUri::parse(&resource_link.uri) {
2437                    Ok(uri) => Self::Mention {
2438                        uri,
2439                        content: String::new(),
2440                    },
2441                    Err(err) => {
2442                        log::error!("Failed to parse mention link: {}", err);
2443                        Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
2444                    }
2445                }
2446            }
2447            acp::ContentBlock::Resource(resource) => match resource.resource {
2448                acp::EmbeddedResourceResource::TextResourceContents(resource) => {
2449                    match MentionUri::parse(&resource.uri) {
2450                        Ok(uri) => Self::Mention {
2451                            uri,
2452                            content: resource.text,
2453                        },
2454                        Err(err) => {
2455                            log::error!("Failed to parse mention link: {}", err);
2456                            Self::Text(
2457                                MarkdownCodeBlock {
2458                                    tag: &resource.uri,
2459                                    text: &resource.text,
2460                                }
2461                                .to_string(),
2462                            )
2463                        }
2464                    }
2465                }
2466                acp::EmbeddedResourceResource::BlobResourceContents(_) => {
2467                    // TODO
2468                    Self::Text("[blob]".to_string())
2469                }
2470            },
2471        }
2472    }
2473}
2474
2475impl From<UserMessageContent> for acp::ContentBlock {
2476    fn from(content: UserMessageContent) -> Self {
2477        match content {
2478            UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
2479                text,
2480                annotations: None,
2481            }),
2482            UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
2483                data: image.source.to_string(),
2484                mime_type: "image/png".to_string(),
2485                annotations: None,
2486                uri: None,
2487            }),
2488            UserMessageContent::Mention { uri, content } => {
2489                acp::ContentBlock::Resource(acp::EmbeddedResource {
2490                    resource: acp::EmbeddedResourceResource::TextResourceContents(
2491                        acp::TextResourceContents {
2492                            mime_type: None,
2493                            text: content,
2494                            uri: uri.to_uri().to_string(),
2495                        },
2496                    ),
2497                    annotations: None,
2498                })
2499            }
2500        }
2501    }
2502}
2503
2504fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
2505    LanguageModelImage {
2506        source: image_content.data.into(),
2507        // TODO: make this optional?
2508        size: gpui::Size::new(0.into(), 0.into()),
2509    }
2510}