use crate::{
    ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DbLanguageModel, DbThread,
    DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool,
    ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, SystemPromptTemplate,
    Template, Templates, TerminalTool, ThinkingTool, WebSearchTool,
};
use acp_thread::{MentionUri, UserMessageId};
use action_log::ActionLog;
use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot};
use agent_client_protocol as acp;
use agent_settings::{
    AgentProfileId, AgentProfileSettings, AgentSettings, CompletionMode,
    SUMMARIZE_THREAD_DETAILED_PROMPT, SUMMARIZE_THREAD_PROMPT,
};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format;
use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
use collections::{HashMap, HashSet, IndexMap};
use fs::Fs;
use futures::{
    FutureExt,
    channel::{mpsc, oneshot},
    future::Shared,
    stream::FuturesUnordered,
};
use git::repository::DiffType;
use gpui::{
    App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
};
use language_model::{
    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
    LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
    LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
    LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
};
use project::{
    Project,
    git_store::{GitStore, RepositoryState},
};
use prompt_store::ProjectContext;
use schemars::{JsonSchema, Schema};
use serde::{Deserialize, Serialize};
use settings::{Settings, update_settings_file};
use smol::stream::StreamExt;
use std::{
    collections::BTreeMap,
    path::Path,
    sync::Arc,
    time::{Duration, Instant},
};
use std::{fmt::Write, ops::Range};
use util::{ResultExt, markdown::MarkdownCodeBlock};
use uuid::Uuid;

const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
pub const MAX_TOOL_NAME_LENGTH: usize = 64;

/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);

impl PromptId {
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl std::fmt::Display for PromptId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);

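// How these strategies are applied in `stream_completion`: exponential backoff
// waits `initial_delay * 2^(attempt - 1)` between attempts, so with
// `BASE_RETRY_DELAY` (5s) the delays are 5s, 10s, 20s, and 40s across
// `MAX_RETRY_ATTEMPTS`. Fixed retries wait `delay`, which is usually the
// server-provided `retry_after` when one is available.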
#[derive(Debug, Clone)]
enum RetryStrategy {
    ExponentialBackoff {
        initial_delay: Duration,
        max_attempts: u8,
    },
    Fixed {
        delay: Duration,
        max_attempts: u8,
    },
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Message {
    User(UserMessage),
    Agent(AgentMessage),
    Resume,
}

impl Message {
    pub fn as_agent_message(&self) -> Option<&AgentMessage> {
        match self {
            Message::Agent(agent_message) => Some(agent_message),
            _ => None,
        }
    }

    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
        match self {
            Message::User(message) => vec![message.to_request()],
            Message::Agent(message) => message.to_request(),
            Message::Resume => vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: false,
            }],
        }
    }

    pub fn to_markdown(&self) -> String {
        match self {
            Message::User(message) => message.to_markdown(),
            Message::Agent(message) => message.to_markdown(),
            Message::Resume => "[resumed after tool use limit was reached]".into(),
        }
    }

    pub fn role(&self) -> Role {
        match self {
            Message::User(_) | Message::Resume => Role::User,
            Message::Agent(_) => Role::Assistant,
        }
    }
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessage {
    pub id: UserMessageId,
    pub content: Vec<UserMessageContent>,
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UserMessageContent {
    Text(String),
    Mention { uri: MentionUri, content: String },
    Image(LanguageModelImage),
}

impl UserMessage {
    pub fn to_markdown(&self) -> String {
        let mut markdown = String::from("## User\n\n");

        for content in &self.content {
            match content {
                UserMessageContent::Text(text) => {
                    markdown.push_str(text);
                    markdown.push('\n');
                }
                UserMessageContent::Image(_) => {
                    markdown.push_str("<image />\n");
                }
                UserMessageContent::Mention { uri, content } => {
                    if !content.is_empty() {
                        let _ = writeln!(&mut markdown, "{}\n\n{}", uri.as_link(), content);
                    } else {
                        let _ = writeln!(&mut markdown, "{}", uri.as_link());
                    }
                }
            }
        }

        markdown
    }

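    /// Converts this user message into a request message. The chunks are
    /// emitted in order (text, images, and mention links), followed by a
    /// single trailing `<context>` block that groups the mentioned content
    /// into `<files>`, `<directories>`, `<symbols>`, `<threads>`,
    /// `<fetched_urls>`, and `<user_rules>` sections; only non-empty sections
    /// are included.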
    fn to_request(&self) -> LanguageModelRequestMessage {
        let mut message = LanguageModelRequestMessage {
            role: Role::User,
            content: Vec::with_capacity(self.content.len()),
            cache: false,
        };

        const OPEN_CONTEXT: &str = "<context>\n\
            The following items were attached by the user. \
            They are up-to-date and don't need to be re-read.\n\n";

        const OPEN_FILES_TAG: &str = "<files>";
        const OPEN_DIRECTORIES_TAG: &str = "<directories>";
        const OPEN_SYMBOLS_TAG: &str = "<symbols>";
        const OPEN_THREADS_TAG: &str = "<threads>";
        const OPEN_FETCH_TAG: &str = "<fetched_urls>";
        const OPEN_RULES_TAG: &str =
            "<user_rules>\nThe user has specified the following rules that should be applied:\n";

        let mut file_context = OPEN_FILES_TAG.to_string();
        let mut directory_context = OPEN_DIRECTORIES_TAG.to_string();
        let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
        let mut thread_context = OPEN_THREADS_TAG.to_string();
        let mut fetch_context = OPEN_FETCH_TAG.to_string();
        let mut rules_context = OPEN_RULES_TAG.to_string();

        for chunk in &self.content {
            let chunk = match chunk {
                UserMessageContent::Text(text) => {
                    language_model::MessageContent::Text(text.clone())
                }
                UserMessageContent::Image(value) => {
                    language_model::MessageContent::Image(value.clone())
                }
                UserMessageContent::Mention { uri, content } => {
                    match uri {
                        MentionUri::File { abs_path } => {
                            write!(
                                &mut file_context,
                                "\n{}",
                                MarkdownCodeBlock {
                                    tag: &codeblock_tag(abs_path, None),
                                    text: &content.to_string(),
                                }
                            )
                            .ok();
                        }
                        MentionUri::Directory { .. } => {
                            write!(&mut directory_context, "\n{}\n", content).ok();
                        }
                        MentionUri::Symbol {
                            path, line_range, ..
                        }
                        | MentionUri::Selection {
                            path, line_range, ..
                        } => {
                            write!(
                                &mut symbol_context,
                                "\n{}",
                                MarkdownCodeBlock {
                                    tag: &codeblock_tag(path, Some(line_range)),
                                    text: content
                                }
                            )
                            .ok();
                        }
                        MentionUri::Thread { .. } => {
                            write!(&mut thread_context, "\n{}\n", content).ok();
                        }
                        MentionUri::TextThread { .. } => {
                            write!(&mut thread_context, "\n{}\n", content).ok();
                        }
                        MentionUri::Rule { .. } => {
                            write!(
                                &mut rules_context,
                                "\n{}",
                                MarkdownCodeBlock {
                                    tag: "",
                                    text: content
                                }
                            )
                            .ok();
                        }
                        MentionUri::Fetch { url } => {
                            write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
                        }
                    }

                    language_model::MessageContent::Text(uri.as_link().to_string())
                }
            };

            message.content.push(chunk);
        }

        let len_before_context = message.content.len();

        if file_context.len() > OPEN_FILES_TAG.len() {
            file_context.push_str("</files>\n");
            message
                .content
                .push(language_model::MessageContent::Text(file_context));
        }

        if directory_context.len() > OPEN_DIRECTORIES_TAG.len() {
            directory_context.push_str("</directories>\n");
            message
                .content
                .push(language_model::MessageContent::Text(directory_context));
        }

        if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
            symbol_context.push_str("</symbols>\n");
            message
                .content
                .push(language_model::MessageContent::Text(symbol_context));
        }

        if thread_context.len() > OPEN_THREADS_TAG.len() {
            thread_context.push_str("</threads>\n");
            message
                .content
                .push(language_model::MessageContent::Text(thread_context));
        }

        if fetch_context.len() > OPEN_FETCH_TAG.len() {
            fetch_context.push_str("</fetched_urls>\n");
            message
                .content
                .push(language_model::MessageContent::Text(fetch_context));
        }

        if rules_context.len() > OPEN_RULES_TAG.len() {
            rules_context.push_str("</user_rules>\n");
            message
                .content
                .push(language_model::MessageContent::Text(rules_context));
        }

        if message.content.len() > len_before_context {
            message.content.insert(
                len_before_context,
                language_model::MessageContent::Text(OPEN_CONTEXT.into()),
            );
            message
                .content
                .push(language_model::MessageContent::Text("</context>".into()));
        }

        message
    }
}

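/// Builds the info string for a fenced code block: the file extension (if
/// any), the full path, and an optional 1-based line or line-range suffix.
/// For example, `src/main.rs` with the range `4..9` becomes `rs src/main.rs:5-10`.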
fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
    let mut result = String::new();

    if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) {
        let _ = write!(result, "{} ", extension);
    }

    let _ = write!(result, "{}", full_path.display());

    if let Some(range) = line_range {
        if range.start == range.end {
            let _ = write!(result, ":{}", range.start + 1);
        } else {
            let _ = write!(result, ":{}-{}", range.start + 1, range.end + 1);
        }
    }

    result
}
348
349impl AgentMessage {
350 pub fn to_markdown(&self) -> String {
351 let mut markdown = String::from("## Assistant\n\n");
352
353 for content in &self.content {
354 match content {
355 AgentMessageContent::Text(text) => {
356 markdown.push_str(text);
357 markdown.push('\n');
358 }
359 AgentMessageContent::Thinking { text, .. } => {
360 markdown.push_str("<think>");
361 markdown.push_str(text);
362 markdown.push_str("</think>\n");
363 }
364 AgentMessageContent::RedactedThinking(_) => {
365 markdown.push_str("<redacted_thinking />\n")
366 }
367 AgentMessageContent::ToolUse(tool_use) => {
368 markdown.push_str(&format!(
369 "**Tool Use**: {} (ID: {})\n",
370 tool_use.name, tool_use.id
371 ));
372 markdown.push_str(&format!(
373 "{}\n",
374 MarkdownCodeBlock {
375 tag: "json",
376 text: &format!("{:#}", tool_use.input)
377 }
378 ));
379 }
380 }
381 }
382
383 for tool_result in self.tool_results.values() {
384 markdown.push_str(&format!(
385 "**Tool Result**: {} (ID: {})\n\n",
386 tool_result.tool_name, tool_result.tool_use_id
387 ));
388 if tool_result.is_error {
389 markdown.push_str("**ERROR:**\n");
390 }
391
392 match &tool_result.content {
393 LanguageModelToolResultContent::Text(text) => {
394 writeln!(markdown, "{text}\n").ok();
395 }
396 LanguageModelToolResultContent::Image(_) => {
397 writeln!(markdown, "<image />\n").ok();
398 }
399 }
400
401 if let Some(output) = tool_result.output.as_ref() {
402 writeln!(
403 markdown,
404 "**Debug Output**:\n\n```json\n{}\n```\n",
405 serde_json::to_string_pretty(output).unwrap()
406 )
407 .unwrap();
408 }
409 }
410
411 markdown
412 }
413
414 pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
415 let mut assistant_message = LanguageModelRequestMessage {
416 role: Role::Assistant,
417 content: Vec::with_capacity(self.content.len()),
418 cache: false,
419 };
420 for chunk in &self.content {
421 let chunk = match chunk {
422 AgentMessageContent::Text(text) => {
423 language_model::MessageContent::Text(text.clone())
424 }
425 AgentMessageContent::Thinking { text, signature } => {
426 language_model::MessageContent::Thinking {
427 text: text.clone(),
428 signature: signature.clone(),
429 }
430 }
431 AgentMessageContent::RedactedThinking(value) => {
432 language_model::MessageContent::RedactedThinking(value.clone())
433 }
434 AgentMessageContent::ToolUse(value) => {
435 language_model::MessageContent::ToolUse(value.clone())
436 }
437 };
438 assistant_message.content.push(chunk);
439 }
440
441 let mut user_message = LanguageModelRequestMessage {
442 role: Role::User,
443 content: Vec::new(),
444 cache: false,
445 };
446
447 for tool_result in self.tool_results.values() {
448 user_message
449 .content
450 .push(language_model::MessageContent::ToolResult(
451 tool_result.clone(),
452 ));
453 }
454
455 let mut messages = Vec::new();
456 if !assistant_message.content.is_empty() {
457 messages.push(assistant_message);
458 }
459 if !user_message.content.is_empty() {
460 messages.push(user_message);
461 }
462 messages
463 }
464}
465
466#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
467pub struct AgentMessage {
468 pub content: Vec<AgentMessageContent>,
469 pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
470}
471
472#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
473pub enum AgentMessageContent {
474 Text(String),
475 Thinking {
476 text: String,
477 signature: Option<String>,
478 },
479 RedactedThinking(String),
480 ToolUse(LanguageModelToolUse),
481}
482
#[derive(Debug)]
pub enum ThreadEvent {
    UserMessage(UserMessage),
    AgentText(String),
    AgentThinking(String),
    ToolCall(acp::ToolCall),
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    ToolCallAuthorization(ToolCallAuthorization),
    Retry(acp_thread::RetryStatus),
    Stop(acp::StopReason),
}

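/// Emitted when a tool call requires the user's permission before running.
/// The receiver is expected to pick one of `options` and send its
/// `PermissionOptionId` back over `response` to let the pending tool proceed.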
#[derive(Debug)]
pub struct ToolCallAuthorization {
    pub tool_call: acp::ToolCallUpdate,
    pub options: Vec<acp::PermissionOption>,
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}

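/// Errors that end a completion early. `MaxTokens` and `Refusal` are reported
/// to the client as the corresponding `acp::StopReason`; any other error is
/// surfaced as an error event on the thread's event stream.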
#[derive(Debug, thiserror::Error)]
enum CompletionError {
    #[error("max tokens")]
    MaxTokens,
    #[error("refusal")]
    Refusal,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub struct Thread {
    id: acp::SessionId,
    prompt_id: PromptId,
    updated_at: DateTime<Utc>,
    title: Option<SharedString>,
    pending_title_generation: Option<Task<()>>,
    summary: Option<SharedString>,
    messages: Vec<Message>,
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools and report their results.
    running_turn: Option<RunningTurn>,
    pending_message: Option<AgentMessage>,
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    tool_use_limit_reached: bool,
    request_token_usage: HashMap<UserMessageId, language_model::TokenUsage>,
    #[allow(unused)]
    cumulative_token_usage: TokenUsage,
    #[allow(unused)]
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    context_server_registry: Entity<ContextServerRegistry>,
    profile_id: AgentProfileId,
    project_context: Entity<ProjectContext>,
    templates: Arc<Templates>,
    model: Option<Arc<dyn LanguageModel>>,
    summarization_model: Option<Arc<dyn LanguageModel>>,
    pub(crate) project: Entity<Project>,
    pub(crate) action_log: Entity<ActionLog>,
}
542
543impl Thread {
544 pub fn new(
545 project: Entity<Project>,
546 project_context: Entity<ProjectContext>,
547 context_server_registry: Entity<ContextServerRegistry>,
548 templates: Arc<Templates>,
549 model: Option<Arc<dyn LanguageModel>>,
550 cx: &mut Context<Self>,
551 ) -> Self {
552 let profile_id = AgentSettings::get_global(cx).default_profile.clone();
553 let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
554 Self {
555 id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()),
556 prompt_id: PromptId::new(),
557 updated_at: Utc::now(),
558 title: None,
559 pending_title_generation: None,
560 summary: None,
561 messages: Vec::new(),
562 completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
563 running_turn: None,
564 pending_message: None,
565 tools: BTreeMap::default(),
566 tool_use_limit_reached: false,
567 request_token_usage: HashMap::default(),
568 cumulative_token_usage: TokenUsage::default(),
569 initial_project_snapshot: {
570 let project_snapshot = Self::project_snapshot(project.clone(), cx);
571 cx.foreground_executor()
572 .spawn(async move { Some(project_snapshot.await) })
573 .shared()
574 },
575 context_server_registry,
576 profile_id,
577 project_context,
578 templates,
579 model,
580 summarization_model: None,
581 project,
582 action_log,
583 }
584 }
585
586 pub fn id(&self) -> &acp::SessionId {
587 &self.id
588 }
589
590 pub fn replay(
591 &mut self,
592 cx: &mut Context<Self>,
593 ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
594 let (tx, rx) = mpsc::unbounded();
595 let stream = ThreadEventStream(tx);
596 for message in &self.messages {
597 match message {
598 Message::User(user_message) => stream.send_user_message(user_message),
599 Message::Agent(assistant_message) => {
600 for content in &assistant_message.content {
601 match content {
602 AgentMessageContent::Text(text) => stream.send_text(text),
603 AgentMessageContent::Thinking { text, .. } => {
604 stream.send_thinking(text)
605 }
606 AgentMessageContent::RedactedThinking(_) => {}
607 AgentMessageContent::ToolUse(tool_use) => {
608 self.replay_tool_call(
609 tool_use,
610 assistant_message.tool_results.get(&tool_use.id),
611 &stream,
612 cx,
613 );
614 }
615 }
616 }
617 }
618 Message::Resume => {}
619 }
620 }
621 rx
622 }
623
624 fn replay_tool_call(
625 &self,
626 tool_use: &LanguageModelToolUse,
627 tool_result: Option<&LanguageModelToolResult>,
628 stream: &ThreadEventStream,
629 cx: &mut Context<Self>,
630 ) {
        let tool = self.tools.get(tool_use.name.as_ref()).cloned().or_else(|| {
            self.context_server_registry
                .read(cx)
                .servers()
                .find_map(|(_, tools)| tools.get(tool_use.name.as_ref()).cloned())
        });
643
644 let Some(tool) = tool else {
645 stream
646 .0
647 .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
648 id: acp::ToolCallId(tool_use.id.to_string().into()),
649 title: tool_use.name.to_string(),
650 kind: acp::ToolKind::Other,
651 status: acp::ToolCallStatus::Failed,
652 content: Vec::new(),
653 locations: Vec::new(),
654 raw_input: Some(tool_use.input.clone()),
655 raw_output: None,
656 })))
657 .ok();
658 return;
659 };
660
661 let title = tool.initial_title(tool_use.input.clone());
662 let kind = tool.kind();
663 stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
664
665 let output = tool_result
666 .as_ref()
667 .and_then(|result| result.output.clone());
668 if let Some(output) = output.clone() {
669 let tool_event_stream = ToolCallEventStream::new(
670 tool_use.id.clone(),
671 stream.clone(),
672 Some(self.project.read(cx).fs().clone()),
673 );
674 tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
675 .log_err();
676 }
677
678 stream.update_tool_call_fields(
679 &tool_use.id,
680 acp::ToolCallUpdateFields {
681 status: Some(acp::ToolCallStatus::Completed),
682 raw_output: output,
683 ..Default::default()
684 },
685 );
686 }
687
688 pub fn from_db(
689 id: acp::SessionId,
690 db_thread: DbThread,
691 project: Entity<Project>,
692 project_context: Entity<ProjectContext>,
693 context_server_registry: Entity<ContextServerRegistry>,
694 action_log: Entity<ActionLog>,
695 templates: Arc<Templates>,
696 cx: &mut Context<Self>,
697 ) -> Self {
698 let profile_id = db_thread
699 .profile
700 .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
701 let model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
702 db_thread
703 .model
704 .and_then(|model| {
705 let model = SelectedModel {
706 provider: model.provider.clone().into(),
707 model: model.model.into(),
708 };
709 registry.select_model(&model, cx)
710 })
711 .or_else(|| registry.default_model())
712 .map(|model| model.model)
713 });
714
715 Self {
716 id,
717 prompt_id: PromptId::new(),
718 title: if db_thread.title.is_empty() {
719 None
720 } else {
721 Some(db_thread.title.clone())
722 },
723 pending_title_generation: None,
724 summary: db_thread.detailed_summary,
725 messages: db_thread.messages,
726 completion_mode: db_thread.completion_mode.unwrap_or_default(),
727 running_turn: None,
728 pending_message: None,
729 tools: BTreeMap::default(),
730 tool_use_limit_reached: false,
731 request_token_usage: db_thread.request_token_usage.clone(),
732 cumulative_token_usage: db_thread.cumulative_token_usage,
733 initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
734 context_server_registry,
735 profile_id,
736 project_context,
737 templates,
738 model,
739 summarization_model: None,
740 project,
741 action_log,
742 updated_at: db_thread.updated_at,
743 }
744 }
745
746 pub fn to_db(&self, cx: &App) -> Task<DbThread> {
747 let initial_project_snapshot = self.initial_project_snapshot.clone();
748 let mut thread = DbThread {
749 title: self.title(),
750 messages: self.messages.clone(),
751 updated_at: self.updated_at,
752 detailed_summary: self.summary.clone(),
753 initial_project_snapshot: None,
754 cumulative_token_usage: self.cumulative_token_usage,
755 request_token_usage: self.request_token_usage.clone(),
756 model: self.model.as_ref().map(|model| DbLanguageModel {
757 provider: model.provider_id().to_string(),
758 model: model.name().0.to_string(),
759 }),
760 completion_mode: Some(self.completion_mode),
761 profile: Some(self.profile_id.clone()),
762 };
763
764 cx.background_spawn(async move {
765 let initial_project_snapshot = initial_project_snapshot.await;
766 thread.initial_project_snapshot = initial_project_snapshot;
767 thread
768 })
769 }
770
771 /// Create a snapshot of the current project state including git information and unsaved buffers.
772 fn project_snapshot(
773 project: Entity<Project>,
774 cx: &mut Context<Self>,
775 ) -> Task<Arc<agent::thread::ProjectSnapshot>> {
776 let git_store = project.read(cx).git_store().clone();
777 let worktree_snapshots: Vec<_> = project
778 .read(cx)
779 .visible_worktrees(cx)
780 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
781 .collect();
782
783 cx.spawn(async move |_, cx| {
784 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
785
786 let mut unsaved_buffers = Vec::new();
787 cx.update(|app_cx| {
788 let buffer_store = project.read(app_cx).buffer_store();
789 for buffer_handle in buffer_store.read(app_cx).buffers() {
790 let buffer = buffer_handle.read(app_cx);
791 if buffer.is_dirty()
792 && let Some(file) = buffer.file()
793 {
794 let path = file.path().to_string_lossy().to_string();
795 unsaved_buffers.push(path);
796 }
797 }
798 })
799 .ok();
800
801 Arc::new(ProjectSnapshot {
802 worktree_snapshots,
803 unsaved_buffer_paths: unsaved_buffers,
804 timestamp: Utc::now(),
805 })
806 })
807 }
808
809 fn worktree_snapshot(
810 worktree: Entity<project::Worktree>,
811 git_store: Entity<GitStore>,
812 cx: &App,
813 ) -> Task<agent::thread::WorktreeSnapshot> {
814 cx.spawn(async move |cx| {
815 // Get worktree path and snapshot
816 let worktree_info = cx.update(|app_cx| {
817 let worktree = worktree.read(app_cx);
818 let path = worktree.abs_path().to_string_lossy().to_string();
819 let snapshot = worktree.snapshot();
820 (path, snapshot)
821 });
822
823 let Ok((worktree_path, _snapshot)) = worktree_info else {
824 return WorktreeSnapshot {
825 worktree_path: String::new(),
826 git_state: None,
827 };
828 };
829
830 let git_state = git_store
831 .update(cx, |git_store, cx| {
832 git_store
833 .repositories()
834 .values()
835 .find(|repo| {
836 repo.read(cx)
837 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
838 .is_some()
839 })
840 .cloned()
841 })
842 .ok()
843 .flatten()
844 .map(|repo| {
845 repo.update(cx, |repo, _| {
846 let current_branch =
847 repo.branch.as_ref().map(|branch| branch.name().to_owned());
848 repo.send_job(None, |state, _| async move {
849 let RepositoryState::Local { backend, .. } = state else {
850 return GitState {
851 remote_url: None,
852 head_sha: None,
853 current_branch,
854 diff: None,
855 };
856 };
857
858 let remote_url = backend.remote_url("origin");
859 let head_sha = backend.head_sha().await;
860 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
861
862 GitState {
863 remote_url,
864 head_sha,
865 current_branch,
866 diff,
867 }
868 })
869 })
870 });
871
872 let git_state = match git_state {
873 Some(git_state) => match git_state.ok() {
874 Some(git_state) => git_state.await.ok(),
875 None => None,
876 },
877 None => None,
878 };
879
880 WorktreeSnapshot {
881 worktree_path,
882 git_state,
883 }
884 })
885 }
886
887 pub fn project_context(&self) -> &Entity<ProjectContext> {
888 &self.project_context
889 }
890
891 pub fn project(&self) -> &Entity<Project> {
892 &self.project
893 }
894
895 pub fn action_log(&self) -> &Entity<ActionLog> {
896 &self.action_log
897 }
898
899 pub fn is_empty(&self) -> bool {
900 self.messages.is_empty() && self.title.is_none()
901 }
902
903 pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
904 self.model.as_ref()
905 }
906
907 pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
908 let old_usage = self.latest_token_usage();
909 self.model = Some(model);
910 let new_usage = self.latest_token_usage();
911 if old_usage != new_usage {
912 cx.emit(TokenUsageUpdated(new_usage));
913 }
914 cx.notify()
915 }
916
917 pub fn summarization_model(&self) -> Option<&Arc<dyn LanguageModel>> {
918 self.summarization_model.as_ref()
919 }
920
921 pub fn set_summarization_model(
922 &mut self,
923 model: Option<Arc<dyn LanguageModel>>,
924 cx: &mut Context<Self>,
925 ) {
926 self.summarization_model = model;
927 cx.notify()
928 }
929
930 pub fn completion_mode(&self) -> CompletionMode {
931 self.completion_mode
932 }
933
934 pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
935 let old_usage = self.latest_token_usage();
936 self.completion_mode = mode;
937 let new_usage = self.latest_token_usage();
938 if old_usage != new_usage {
939 cx.emit(TokenUsageUpdated(new_usage));
940 }
941 cx.notify()
942 }
943
944 #[cfg(any(test, feature = "test-support"))]
945 pub fn last_message(&self) -> Option<Message> {
946 if let Some(message) = self.pending_message.clone() {
947 Some(Message::Agent(message))
948 } else {
949 self.messages.last().cloned()
950 }
951 }
952
953 pub fn add_default_tools(&mut self, cx: &mut Context<Self>) {
954 let language_registry = self.project.read(cx).languages().clone();
955 self.add_tool(CopyPathTool::new(self.project.clone()));
956 self.add_tool(CreateDirectoryTool::new(self.project.clone()));
957 self.add_tool(DeletePathTool::new(
958 self.project.clone(),
959 self.action_log.clone(),
960 ));
961 self.add_tool(DiagnosticsTool::new(self.project.clone()));
962 self.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
963 self.add_tool(FetchTool::new(self.project.read(cx).client().http_client()));
964 self.add_tool(FindPathTool::new(self.project.clone()));
965 self.add_tool(GrepTool::new(self.project.clone()));
966 self.add_tool(ListDirectoryTool::new(self.project.clone()));
967 self.add_tool(MovePathTool::new(self.project.clone()));
968 self.add_tool(NowTool);
969 self.add_tool(OpenTool::new(self.project.clone()));
970 self.add_tool(ReadFileTool::new(
971 self.project.clone(),
972 self.action_log.clone(),
973 ));
974 self.add_tool(TerminalTool::new(self.project.clone(), cx));
975 self.add_tool(ThinkingTool);
976 self.add_tool(WebSearchTool);
977 }
978
979 pub fn add_tool<T: AgentTool>(&mut self, tool: T) {
980 self.tools.insert(T::name().into(), tool.erase());
981 }
982
983 pub fn remove_tool(&mut self, name: &str) -> bool {
984 self.tools.remove(name).is_some()
985 }
986
987 pub fn profile(&self) -> &AgentProfileId {
988 &self.profile_id
989 }
990
991 pub fn set_profile(&mut self, profile_id: AgentProfileId) {
992 self.profile_id = profile_id;
993 }
994
995 pub fn cancel(&mut self, cx: &mut Context<Self>) {
996 if let Some(running_turn) = self.running_turn.take() {
997 running_turn.cancel();
998 }
999 self.flush_pending_message(cx);
1000 }
1001
1002 fn update_token_usage(&mut self, update: language_model::TokenUsage, cx: &mut Context<Self>) {
1003 let Some(last_user_message) = self.last_user_message() else {
1004 return;
1005 };
1006
1007 self.request_token_usage
1008 .insert(last_user_message.id.clone(), update);
1009 cx.emit(TokenUsageUpdated(self.latest_token_usage()));
1010 cx.notify();
1011 }
1012
1013 pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
1014 self.cancel(cx);
1015 let Some(position) = self.messages.iter().position(
1016 |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
1017 ) else {
1018 return Err(anyhow!("Message not found"));
1019 };
1020
1021 for message in self.messages.drain(position..) {
1022 match message {
1023 Message::User(message) => {
1024 self.request_token_usage.remove(&message.id);
1025 }
1026 Message::Agent(_) | Message::Resume => {}
1027 }
1028 }
1029 self.summary = None;
1030 cx.notify();
1031 Ok(())
1032 }
1033
1034 pub fn latest_token_usage(&self) -> Option<acp_thread::TokenUsage> {
1035 let last_user_message = self.last_user_message()?;
1036 let tokens = self.request_token_usage.get(&last_user_message.id)?;
1037 let model = self.model.clone()?;
1038
1039 Some(acp_thread::TokenUsage {
1040 max_tokens: model.max_token_count_for_mode(self.completion_mode.into()),
1041 used_tokens: tokens.total_tokens(),
1042 })
1043 }
1044
1045 pub fn resume(
1046 &mut self,
1047 cx: &mut Context<Self>,
1048 ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
1049 anyhow::ensure!(
1050 self.tool_use_limit_reached,
1051 "can only resume after tool use limit is reached"
1052 );
1053
1054 self.messages.push(Message::Resume);
1055 cx.notify();
1056
1057 log::info!("Total messages in thread: {}", self.messages.len());
1058 self.run_turn(cx)
1059 }
1060
    /// Sending a message results in the model streaming a response, which may include tool calls.
    /// After requesting tools, the model stops and waits for the outstanding tool calls to complete and for their results to be sent back.
    /// The returned channel reports every stop the model makes before erroring or ending its turn.
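    ///
    /// A rough usage sketch (not compiled as a doctest; `UserMessageId::new()`
    /// is assumed to be available from `acp_thread`):
    ///
    /// ```ignore
    /// let mut events = thread.update(cx, |thread, cx| {
    ///     thread.send(
    ///         UserMessageId::new(),
    ///         [UserMessageContent::Text("Hello!".into())],
    ///         cx,
    ///     )
    /// })?;
    /// ```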
    pub fn send<T>(
        &mut self,
        id: UserMessageId,
        content: impl IntoIterator<Item = T>,
        cx: &mut Context<Self>,
    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>
    where
        T: Into<UserMessageContent>,
    {
        let model = self.model().context("No language model configured")?;

        log::info!("Thread::send called with model: {:?}", model.name());
        self.advance_prompt_id();

        let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
        log::debug!("Thread::send content: {:?}", content);

        self.messages
            .push(Message::User(UserMessage { id, content }));
        cx.notify();

        log::info!("Total messages in thread: {}", self.messages.len());
        self.run_turn(cx)
    }
1088
1089 fn run_turn(
1090 &mut self,
1091 cx: &mut Context<Self>,
1092 ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
1093 self.cancel(cx);
1094
1095 let model = self.model.clone().context("No language model configured")?;
1096 let profile = AgentSettings::get_global(cx)
1097 .profiles
1098 .get(&self.profile_id)
1099 .context("Profile not found")?;
1100 let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
1101 let event_stream = ThreadEventStream(events_tx);
1102 let message_ix = self.messages.len().saturating_sub(1);
1103 self.tool_use_limit_reached = false;
1104 self.summary = None;
1105 self.running_turn = Some(RunningTurn {
1106 event_stream: event_stream.clone(),
1107 tools: self.enabled_tools(profile, &model, cx),
1108 _task: cx.spawn(async move |this, cx| {
1109 log::info!("Starting agent turn execution");
1110
1111 let turn_result: Result<()> = async {
1112 let mut intent = CompletionIntent::UserPrompt;
1113 loop {
1114 Self::stream_completion(&this, &model, intent, &event_stream, cx).await?;
1115
1116 let mut end_turn = true;
1117 this.update(cx, |this, cx| {
1118 // Generate title if needed.
1119 if this.title.is_none() && this.pending_title_generation.is_none() {
1120 this.generate_title(cx);
1121 }
1122
1123 // End the turn if the model didn't use tools.
1124 let message = this.pending_message.as_ref();
1125 end_turn =
1126 message.map_or(true, |message| message.tool_results.is_empty());
1127 this.flush_pending_message(cx);
1128 })?;
1129
1130 if this.read_with(cx, |this, _| this.tool_use_limit_reached)? {
1131 log::info!("Tool use limit reached, completing turn");
1132 return Err(language_model::ToolUseLimitReachedError.into());
1133 } else if end_turn {
1134 log::info!("No tool uses found, completing turn");
1135 return Ok(());
1136 } else {
1137 intent = CompletionIntent::ToolResults;
1138 }
1139 }
1140 }
1141 .await;
1142 _ = this.update(cx, |this, cx| this.flush_pending_message(cx));
1143
1144 match turn_result {
1145 Ok(()) => {
1146 log::info!("Turn execution completed");
1147 event_stream.send_stop(acp::StopReason::EndTurn);
1148 }
1149 Err(error) => {
1150 log::error!("Turn execution failed: {:?}", error);
1151 match error.downcast::<CompletionError>() {
1152 Ok(CompletionError::Refusal) => {
1153 event_stream.send_stop(acp::StopReason::Refusal);
1154 _ = this.update(cx, |this, _| this.messages.truncate(message_ix));
1155 }
1156 Ok(CompletionError::MaxTokens) => {
1157 event_stream.send_stop(acp::StopReason::MaxTokens);
1158 }
1159 Ok(CompletionError::Other(error)) | Err(error) => {
1160 event_stream.send_error(error);
1161 }
1162 }
1163 }
1164 }
1165
1166 _ = this.update(cx, |this, _| this.running_turn.take());
1167 }),
1168 });
1169 Ok(events_rx)
1170 }
1171
1172 async fn stream_completion(
1173 this: &WeakEntity<Self>,
1174 model: &Arc<dyn LanguageModel>,
1175 completion_intent: CompletionIntent,
1176 event_stream: &ThreadEventStream,
1177 cx: &mut AsyncApp,
1178 ) -> Result<()> {
1179 log::debug!("Stream completion started successfully");
1180 let request = this.update(cx, |this, cx| {
1181 this.build_completion_request(completion_intent, cx)
1182 })??;
1183
1184 let mut attempt = None;
1185 'retry: loop {
1186 telemetry::event!(
1187 "Agent Thread Completion",
1188 thread_id = this.read_with(cx, |this, _| this.id.to_string())?,
1189 prompt_id = this.read_with(cx, |this, _| this.prompt_id.to_string())?,
1190 model = model.telemetry_id(),
1191 model_provider = model.provider_id().to_string(),
1192 attempt
1193 );
1194
1195 log::info!(
1196 "Calling model.stream_completion, attempt {}",
1197 attempt.unwrap_or(0)
1198 );
1199 let mut events = model
1200 .stream_completion(request.clone(), cx)
1201 .await
1202 .map_err(|error| anyhow!(error))?;
1203 let mut tool_results = FuturesUnordered::new();
1204
1205 while let Some(event) = events.next().await {
1206 match event {
1207 Ok(event) => {
1208 log::trace!("Received completion event: {:?}", event);
1209 tool_results.extend(this.update(cx, |this, cx| {
1210 this.handle_streamed_completion_event(event, event_stream, cx)
1211 })??);
1212 }
1213 Err(error) => {
1214 let completion_mode =
1215 this.read_with(cx, |thread, _cx| thread.completion_mode())?;
1216 if completion_mode == CompletionMode::Normal {
1217 return Err(anyhow!(error))?;
1218 }
1219
1220 let Some(strategy) = Self::retry_strategy_for(&error) else {
1221 return Err(anyhow!(error))?;
1222 };
1223
1224 let max_attempts = match &strategy {
1225 RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
1226 RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
1227 };
1228
1229 let attempt = attempt.get_or_insert(0u8);
1230
1231 *attempt += 1;
1232
1233 let attempt = *attempt;
1234 if attempt > max_attempts {
1235 return Err(anyhow!(error))?;
1236 }
1237
1238 let delay = match &strategy {
1239 RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
1240 let delay_secs =
1241 initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
1242 Duration::from_secs(delay_secs)
1243 }
1244 RetryStrategy::Fixed { delay, .. } => *delay,
1245 };
1246 log::debug!("Retry attempt {attempt} with delay {delay:?}");
1247
1248 event_stream.send_retry(acp_thread::RetryStatus {
1249 last_error: error.to_string().into(),
1250 attempt: attempt as usize,
1251 max_attempts: max_attempts as usize,
1252 started_at: Instant::now(),
1253 duration: delay,
1254 });
1255
1256 cx.background_executor().timer(delay).await;
1257 continue 'retry;
1258 }
1259 }
1260 }
1261
1262 while let Some(tool_result) = tool_results.next().await {
1263 log::info!("Tool finished {:?}", tool_result);
1264
1265 event_stream.update_tool_call_fields(
1266 &tool_result.tool_use_id,
1267 acp::ToolCallUpdateFields {
1268 status: Some(if tool_result.is_error {
1269 acp::ToolCallStatus::Failed
1270 } else {
1271 acp::ToolCallStatus::Completed
1272 }),
1273 raw_output: tool_result.output.clone(),
1274 ..Default::default()
1275 },
1276 );
1277 this.update(cx, |this, _cx| {
1278 this.pending_message()
1279 .tool_results
1280 .insert(tool_result.tool_use_id.clone(), tool_result);
1281 })?;
1282 }
1283
1284 return Ok(());
1285 }
1286 }
1287
1288 pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
1289 log::debug!("Building system message");
1290 let prompt = SystemPromptTemplate {
1291 project: self.project_context.read(cx),
1292 available_tools: self.tools.keys().cloned().collect(),
1293 }
1294 .render(&self.templates)
1295 .context("failed to build system prompt")
1296 .expect("Invalid template");
1297 log::debug!("System message built");
1298 LanguageModelRequestMessage {
1299 role: Role::System,
1300 content: vec![prompt.into()],
1301 cache: true,
1302 }
1303 }
1304
1305 /// A helper method that's called on every streamed completion event.
1306 /// Returns an optional tool result task, which the main agentic loop will
1307 /// send back to the model when it resolves.
1308 fn handle_streamed_completion_event(
1309 &mut self,
1310 event: LanguageModelCompletionEvent,
1311 event_stream: &ThreadEventStream,
1312 cx: &mut Context<Self>,
1313 ) -> Result<Option<Task<LanguageModelToolResult>>> {
1314 log::trace!("Handling streamed completion event: {:?}", event);
1315 use LanguageModelCompletionEvent::*;
1316
1317 match event {
1318 StartMessage { .. } => {
1319 self.flush_pending_message(cx);
1320 self.pending_message = Some(AgentMessage::default());
1321 }
1322 Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
1323 Thinking { text, signature } => {
1324 self.handle_thinking_event(text, signature, event_stream, cx)
1325 }
1326 RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
1327 ToolUse(tool_use) => {
1328 return Ok(self.handle_tool_use_event(tool_use, event_stream, cx));
1329 }
1330 ToolUseJsonParseError {
1331 id,
1332 tool_name,
1333 raw_input,
1334 json_parse_error,
1335 } => {
1336 return Ok(Some(Task::ready(
1337 self.handle_tool_use_json_parse_error_event(
1338 id,
1339 tool_name,
1340 raw_input,
1341 json_parse_error,
1342 ),
1343 )));
1344 }
1345 UsageUpdate(usage) => {
1346 telemetry::event!(
1347 "Agent Thread Completion Usage Updated",
1348 thread_id = self.id.to_string(),
1349 prompt_id = self.prompt_id.to_string(),
1350 model = self.model.as_ref().map(|m| m.telemetry_id()),
1351 model_provider = self.model.as_ref().map(|m| m.provider_id().to_string()),
1352 input_tokens = usage.input_tokens,
1353 output_tokens = usage.output_tokens,
1354 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1355 cache_read_input_tokens = usage.cache_read_input_tokens,
1356 );
1357 self.update_token_usage(usage, cx);
1358 }
1359 StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => {
1360 self.update_model_request_usage(amount, limit, cx);
1361 }
1362 StatusUpdate(
1363 CompletionRequestStatus::Started
1364 | CompletionRequestStatus::Queued { .. }
1365 | CompletionRequestStatus::Failed { .. },
1366 ) => {}
1367 StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => {
1368 self.tool_use_limit_reached = true;
1369 }
1370 Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
1371 Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
1372 Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
1373 }
1374
1375 Ok(None)
1376 }
1377
1378 fn handle_text_event(
1379 &mut self,
1380 new_text: String,
1381 event_stream: &ThreadEventStream,
1382 cx: &mut Context<Self>,
1383 ) {
1384 event_stream.send_text(&new_text);
1385
1386 let last_message = self.pending_message();
1387 if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
1388 text.push_str(&new_text);
1389 } else {
1390 last_message
1391 .content
1392 .push(AgentMessageContent::Text(new_text));
1393 }
1394
1395 cx.notify();
1396 }
1397
1398 fn handle_thinking_event(
1399 &mut self,
1400 new_text: String,
1401 new_signature: Option<String>,
1402 event_stream: &ThreadEventStream,
1403 cx: &mut Context<Self>,
1404 ) {
1405 event_stream.send_thinking(&new_text);
1406
1407 let last_message = self.pending_message();
1408 if let Some(AgentMessageContent::Thinking { text, signature }) =
1409 last_message.content.last_mut()
1410 {
1411 text.push_str(&new_text);
1412 *signature = new_signature.or(signature.take());
1413 } else {
1414 last_message.content.push(AgentMessageContent::Thinking {
1415 text: new_text,
1416 signature: new_signature,
1417 });
1418 }
1419
1420 cx.notify();
1421 }
1422
1423 fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
1424 let last_message = self.pending_message();
1425 last_message
1426 .content
1427 .push(AgentMessageContent::RedactedThinking(data));
1428 cx.notify();
1429 }
1430
1431 fn handle_tool_use_event(
1432 &mut self,
1433 tool_use: LanguageModelToolUse,
1434 event_stream: &ThreadEventStream,
1435 cx: &mut Context<Self>,
1436 ) -> Option<Task<LanguageModelToolResult>> {
1437 cx.notify();
1438
1439 let tool = self.tool(tool_use.name.as_ref());
1440 let mut title = SharedString::from(&tool_use.name);
1441 let mut kind = acp::ToolKind::Other;
1442 if let Some(tool) = tool.as_ref() {
1443 title = tool.initial_title(tool_use.input.clone());
1444 kind = tool.kind();
1445 }
1446
1447 // Ensure the last message ends in the current tool use
1448 let last_message = self.pending_message();
1449 let push_new_tool_use = last_message.content.last_mut().is_none_or(|content| {
1450 if let AgentMessageContent::ToolUse(last_tool_use) = content {
1451 if last_tool_use.id == tool_use.id {
1452 *last_tool_use = tool_use.clone();
1453 false
1454 } else {
1455 true
1456 }
1457 } else {
1458 true
1459 }
1460 });
1461
1462 if push_new_tool_use {
1463 event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
1464 last_message
1465 .content
1466 .push(AgentMessageContent::ToolUse(tool_use.clone()));
1467 } else {
1468 event_stream.update_tool_call_fields(
1469 &tool_use.id,
1470 acp::ToolCallUpdateFields {
1471 title: Some(title.into()),
1472 kind: Some(kind),
1473 raw_input: Some(tool_use.input.clone()),
1474 ..Default::default()
1475 },
1476 );
1477 }
1478
1479 if !tool_use.is_input_complete {
1480 return None;
1481 }
1482
1483 let Some(tool) = tool else {
1484 let content = format!("No tool named {} exists", tool_use.name);
1485 return Some(Task::ready(LanguageModelToolResult {
1486 content: LanguageModelToolResultContent::Text(Arc::from(content)),
1487 tool_use_id: tool_use.id,
1488 tool_name: tool_use.name,
1489 is_error: true,
1490 output: None,
1491 }));
1492 };
1493
1494 let fs = self.project.read(cx).fs().clone();
1495 let tool_event_stream =
1496 ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1497 tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1498 status: Some(acp::ToolCallStatus::InProgress),
1499 ..Default::default()
1500 });
1501 let supports_images = self.model().is_some_and(|model| model.supports_images());
1502 let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1503 log::info!("Running tool {}", tool_use.name);
1504 Some(cx.foreground_executor().spawn(async move {
1505 let tool_result = tool_result.await.and_then(|output| {
1506 if let LanguageModelToolResultContent::Image(_) = &output.llm_output
1507 && !supports_images
1508 {
1509 return Err(anyhow!(
1510 "Attempted to read an image, but this model doesn't support it.",
1511 ));
1512 }
1513 Ok(output)
1514 });
1515
1516 match tool_result {
1517 Ok(output) => LanguageModelToolResult {
1518 tool_use_id: tool_use.id,
1519 tool_name: tool_use.name,
1520 is_error: false,
1521 content: output.llm_output,
1522 output: Some(output.raw_output),
1523 },
1524 Err(error) => LanguageModelToolResult {
1525 tool_use_id: tool_use.id,
1526 tool_name: tool_use.name,
1527 is_error: true,
1528 content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1529 output: None,
1530 },
1531 }
1532 }))
1533 }
1534
1535 fn handle_tool_use_json_parse_error_event(
1536 &mut self,
1537 tool_use_id: LanguageModelToolUseId,
1538 tool_name: Arc<str>,
1539 raw_input: Arc<str>,
1540 json_parse_error: String,
1541 ) -> LanguageModelToolResult {
1542 let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1543 LanguageModelToolResult {
1544 tool_use_id,
1545 tool_name,
1546 is_error: true,
1547 content: LanguageModelToolResultContent::Text(tool_output.into()),
1548 output: Some(serde_json::Value::String(raw_input.to_string())),
1549 }
1550 }
1551
1552 fn update_model_request_usage(&self, amount: usize, limit: UsageLimit, cx: &mut Context<Self>) {
1553 self.project
1554 .read(cx)
1555 .user_store()
1556 .update(cx, |user_store, cx| {
1557 user_store.update_model_request_usage(
1558 ModelRequestUsage(RequestUsage {
1559 amount: amount as i32,
1560 limit,
1561 }),
1562 cx,
1563 )
1564 });
1565 }
1566
1567 pub fn title(&self) -> SharedString {
1568 self.title.clone().unwrap_or("New Thread".into())
1569 }
1570
1571 pub fn summary(&mut self, cx: &mut Context<Self>) -> Task<Result<SharedString>> {
1572 if let Some(summary) = self.summary.as_ref() {
1573 return Task::ready(Ok(summary.clone()));
1574 }
1575 let Some(model) = self.summarization_model.clone() else {
1576 return Task::ready(Err(anyhow!("No summarization model available")));
1577 };
1578 let mut request = LanguageModelRequest {
1579 intent: Some(CompletionIntent::ThreadContextSummarization),
1580 temperature: AgentSettings::temperature_for_model(&model, cx),
1581 ..Default::default()
1582 };
1583
1584 for message in &self.messages {
1585 request.messages.extend(message.to_request());
1586 }
1587
1588 request.messages.push(LanguageModelRequestMessage {
1589 role: Role::User,
1590 content: vec![SUMMARIZE_THREAD_DETAILED_PROMPT.into()],
1591 cache: false,
1592 });
1593 cx.spawn(async move |this, cx| {
1594 let mut summary = String::new();
1595 let mut messages = model.stream_completion(request, cx).await?;
1596 while let Some(event) = messages.next().await {
1597 let event = event?;
1598 let text = match event {
1599 LanguageModelCompletionEvent::Text(text) => text,
1600 LanguageModelCompletionEvent::StatusUpdate(
1601 CompletionRequestStatus::UsageUpdated { amount, limit },
1602 ) => {
1603 this.update(cx, |thread, cx| {
1604 thread.update_model_request_usage(amount, limit, cx);
1605 })?;
1606 continue;
1607 }
1608 _ => continue,
1609 };
1610
1611 let mut lines = text.lines();
1612 summary.extend(lines.next());
1613 }
1614
1615 log::info!("Setting summary: {}", summary);
1616 let summary = SharedString::from(summary);
1617
1618 this.update(cx, |this, cx| {
1619 this.summary = Some(summary.clone());
1620 cx.notify()
1621 })?;
1622
1623 Ok(summary)
1624 })
1625 }
1626
1627 fn generate_title(&mut self, cx: &mut Context<Self>) {
1628 let Some(model) = self.summarization_model.clone() else {
1629 return;
1630 };
1631
1632 log::info!(
1633 "Generating title with model: {:?}",
1634 self.summarization_model.as_ref().map(|model| model.name())
1635 );
1636 let mut request = LanguageModelRequest {
1637 intent: Some(CompletionIntent::ThreadSummarization),
1638 temperature: AgentSettings::temperature_for_model(&model, cx),
1639 ..Default::default()
1640 };
1641
1642 for message in &self.messages {
1643 request.messages.extend(message.to_request());
1644 }
1645
1646 request.messages.push(LanguageModelRequestMessage {
1647 role: Role::User,
1648 content: vec![SUMMARIZE_THREAD_PROMPT.into()],
1649 cache: false,
1650 });
1651 self.pending_title_generation = Some(cx.spawn(async move |this, cx| {
1652 let mut title = String::new();
1653
1654 let generate = async {
1655 let mut messages = model.stream_completion(request, cx).await?;
1656 while let Some(event) = messages.next().await {
1657 let event = event?;
1658 let text = match event {
1659 LanguageModelCompletionEvent::Text(text) => text,
1660 LanguageModelCompletionEvent::StatusUpdate(
1661 CompletionRequestStatus::UsageUpdated { amount, limit },
1662 ) => {
1663 this.update(cx, |thread, cx| {
1664 thread.update_model_request_usage(amount, limit, cx);
1665 })?;
1666 continue;
1667 }
1668 _ => continue,
1669 };
1670
1671 let mut lines = text.lines();
1672 title.extend(lines.next());
1673
1674 // Stop if the LLM generated multiple lines.
1675 if lines.next().is_some() {
1676 break;
1677 }
1678 }
1679 anyhow::Ok(())
1680 };
1681
1682 if generate.await.context("failed to generate title").is_ok() {
1683 _ = this.update(cx, |this, cx| this.set_title(title.into(), cx));
1684 }
1685 _ = this.update(cx, |this, _| this.pending_title_generation = None);
1686 }));
1687 }
1688
1689 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) {
1690 self.pending_title_generation = None;
1691 if Some(&title) != self.title.as_ref() {
1692 self.title = Some(title);
1693 cx.emit(TitleUpdated);
1694 cx.notify();
1695 }
1696 }
1697
1698 fn last_user_message(&self) -> Option<&UserMessage> {
1699 self.messages
1700 .iter()
1701 .rev()
1702 .find_map(|message| match message {
1703 Message::User(user_message) => Some(user_message),
1704 Message::Agent(_) => None,
1705 Message::Resume => None,
1706 })
1707 }
1708
    fn pending_message(&mut self) -> &mut AgentMessage {
        self.pending_message.get_or_insert_default()
    }

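    /// Moves the in-progress assistant message, if any, into the message
    /// history. Tool uses that never produced a result are given canceled-tool
    /// error results so the conversation sent to the model stays well formed.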
    fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
        let Some(mut message) = self.pending_message.take() else {
            return;
        };

        for content in &message.content {
            let AgentMessageContent::ToolUse(tool_use) = content else {
                continue;
            };

            if !message.tool_results.contains_key(&tool_use.id) {
                message.tool_results.insert(
                    tool_use.id.clone(),
                    LanguageModelToolResult {
                        tool_use_id: tool_use.id.clone(),
                        tool_name: tool_use.name.clone(),
                        is_error: true,
                        content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
                        output: None,
                    },
                );
            }
        }

        self.messages.push(Message::Agent(message));
        self.updated_at = Utc::now();
        self.summary = None;
        cx.notify()
    }
1742
1743 pub(crate) fn build_completion_request(
1744 &self,
1745 completion_intent: CompletionIntent,
1746 cx: &mut App,
1747 ) -> Result<LanguageModelRequest> {
1748 let model = self.model().context("No language model configured")?;
1749 let tools = if let Some(turn) = self.running_turn.as_ref() {
1750 turn.tools
1751 .iter()
1752 .filter_map(|(tool_name, tool)| {
1753 log::trace!("Including tool: {}", tool_name);
1754 Some(LanguageModelRequestTool {
1755 name: tool_name.to_string(),
1756 description: tool.description().to_string(),
1757 input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
1758 })
1759 })
1760 .collect::<Vec<_>>()
1761 } else {
1762 Vec::new()
1763 };
1764
1765 log::debug!("Building completion request");
1766 log::debug!("Completion intent: {:?}", completion_intent);
1767 log::debug!("Completion mode: {:?}", self.completion_mode);
1768
1769 let messages = self.build_request_messages(cx);
1770 log::info!("Request will include {} messages", messages.len());
1771 log::info!("Request includes {} tools", tools.len());
1772
1773 let request = LanguageModelRequest {
1774 thread_id: Some(self.id.to_string()),
1775 prompt_id: Some(self.prompt_id.to_string()),
1776 intent: Some(completion_intent),
1777 mode: Some(self.completion_mode.into()),
1778 messages,
1779 tools,
1780 tool_choice: None,
1781 stop: Vec::new(),
1782 temperature: AgentSettings::temperature_for_model(model, cx),
1783 thinking_allowed: true,
1784 };
1785
1786 log::debug!("Completion request built successfully");
1787 Ok(request)
1788 }
1789
    fn enabled_tools(
        &self,
        profile: &AgentProfileSettings,
        model: &Arc<dyn LanguageModel>,
        cx: &App,
    ) -> BTreeMap<SharedString, Arc<dyn AnyAgentTool>> {
        fn truncate(tool_name: &SharedString) -> SharedString {
            if tool_name.len() > MAX_TOOL_NAME_LENGTH {
                let mut truncated = tool_name.to_string();
                truncated.truncate(MAX_TOOL_NAME_LENGTH);
                truncated.into()
            } else {
                tool_name.clone()
            }
        }

        let mut tools = self
            .tools
            .iter()
            .filter_map(|(tool_name, tool)| {
                if tool.supported_provider(&model.provider_id())
                    && profile.is_tool_enabled(tool_name)
                {
                    Some((truncate(tool_name), tool.clone()))
                } else {
                    None
                }
            })
            .collect::<BTreeMap<_, _>>();

        let mut context_server_tools = Vec::new();
        let mut seen_tools = tools.keys().cloned().collect::<HashSet<_>>();
        let mut duplicate_tool_names = HashSet::default();
        for (server_id, server_tools) in self.context_server_registry.read(cx).servers() {
            for (tool_name, tool) in server_tools {
                if profile.is_context_server_tool_enabled(&server_id.0, &tool_name) {
                    let tool_name = truncate(tool_name);
                    if !seen_tools.insert(tool_name.clone()) {
                        duplicate_tool_names.insert(tool_name.clone());
                    }
                    context_server_tools.push((server_id.clone(), tool_name, tool.clone()));
                }
            }
        }

        // When there are duplicate tool names, disambiguate by prefixing them
        // with the server ID. In the rare case there isn't enough space for the
        // disambiguated tool name, keep only the last tool with this name.
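        // For example, if servers `alpha` and `beta` both expose a `search`
        // tool, the request advertises them as `alpha_search` and `beta_search`.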
        for (server_id, tool_name, tool) in context_server_tools {
            if duplicate_tool_names.contains(&tool_name) {
                let available = MAX_TOOL_NAME_LENGTH.saturating_sub(tool_name.len());
                if available >= 2 {
                    let mut disambiguated = server_id.0.to_string();
                    disambiguated.truncate(available - 1);
                    disambiguated.push('_');
                    disambiguated.push_str(&tool_name);
                    tools.insert(disambiguated.into(), tool.clone());
                } else {
                    tools.insert(tool_name, tool.clone());
                }
            } else {
                tools.insert(tool_name, tool.clone());
            }
        }

        tools
    }
1857
1858 fn tool(&self, name: &str) -> Option<Arc<dyn AnyAgentTool>> {
1859 self.running_turn.as_ref()?.tools.get(name).cloned()
1860 }
1861
1862 fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
1863 log::trace!(
1864 "Building request messages from {} thread messages",
1865 self.messages.len()
1866 );
1867 let mut messages = vec![self.build_system_message(cx)];
1868 for message in &self.messages {
1869 messages.extend(message.to_request());
1870 }
1871
1872 if let Some(message) = self.pending_message.as_ref() {
1873 messages.extend(message.to_request());
1874 }
1875
1876 if let Some(last_user_message) = messages
1877 .iter_mut()
1878 .rev()
1879 .find(|message| message.role == Role::User)
1880 {
1881 last_user_message.cache = true;
1882 }
1883
1884 messages
1885 }
1886
1887 pub fn to_markdown(&self) -> String {
1888 let mut markdown = String::new();
1889 for (ix, message) in self.messages.iter().enumerate() {
1890 if ix > 0 {
1891 markdown.push('\n');
1892 }
1893 markdown.push_str(&message.to_markdown());
1894 }
1895
1896 if let Some(message) = self.pending_message.as_ref() {
1897 markdown.push('\n');
1898 markdown.push_str(&message.to_markdown());
1899 }
1900
1901 markdown
1902 }
1903
1904 fn advance_prompt_id(&mut self) {
1905 self.prompt_id = PromptId::new();
1906 }
1907
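    /// Maps a completion error to a retry strategy, or `None` when retrying
    /// cannot help (e.g. authentication failures or an oversized prompt).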
1908 fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
1909 use LanguageModelCompletionError::*;
1910 use http_client::StatusCode;
1911
        // General strategy here:
        // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
        // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to MAX_RETRY_ATTEMPTS
        //   times, waiting the server-provided `retry_after` (or a fixed delay) between attempts, and using
        //   exponential backoff for plain 429 responses.
        // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
1916 match error {
1917 HttpResponseError {
1918 status_code: StatusCode::TOO_MANY_REQUESTS,
1919 ..
1920 } => Some(RetryStrategy::ExponentialBackoff {
1921 initial_delay: BASE_RETRY_DELAY,
1922 max_attempts: MAX_RETRY_ATTEMPTS,
1923 }),
1924 ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
1925 Some(RetryStrategy::Fixed {
1926 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1927 max_attempts: MAX_RETRY_ATTEMPTS,
1928 })
1929 }
1930 UpstreamProviderError {
1931 status,
1932 retry_after,
1933 ..
1934 } => match *status {
1935 StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
1936 Some(RetryStrategy::Fixed {
1937 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1938 max_attempts: MAX_RETRY_ATTEMPTS,
1939 })
1940 }
1941 StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
1942 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1943 // Internal Server Error could be anything, retry up to 3 times.
1944 max_attempts: 3,
1945 }),
1946 status => {
                    // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
                    // but we see it frequently in practice. See https://http.dev/529
1949 if status.as_u16() == 529 {
1950 Some(RetryStrategy::Fixed {
1951 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1952 max_attempts: MAX_RETRY_ATTEMPTS,
1953 })
1954 } else {
1955 Some(RetryStrategy::Fixed {
1956 delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1957 max_attempts: 2,
1958 })
1959 }
1960 }
1961 },
1962 ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
1963 delay: BASE_RETRY_DELAY,
1964 max_attempts: 3,
1965 }),
1966 ApiReadResponseError { .. }
1967 | HttpSend { .. }
1968 | DeserializeResponse { .. }
1969 | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
1970 delay: BASE_RETRY_DELAY,
1971 max_attempts: 3,
1972 }),
1973 // Retrying these errors definitely shouldn't help.
1974 HttpResponseError {
1975 status_code:
1976 StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
1977 ..
1978 }
1979 | AuthenticationError { .. }
1980 | PermissionError { .. }
1981 | NoApiKey { .. }
1982 | ApiEndpointNotFound { .. }
1983 | PromptTooLarge { .. } => None,
            // These errors might be transient, so retry them once.
1985 SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
1986 delay: BASE_RETRY_DELAY,
1987 max_attempts: 1,
1988 }),
            // Retry all other 4xx and 5xx errors up to 3 times.
1990 HttpResponseError { status_code, .. }
1991 if status_code.is_client_error() || status_code.is_server_error() =>
1992 {
1993 Some(RetryStrategy::Fixed {
1994 delay: BASE_RETRY_DELAY,
1995 max_attempts: 3,
1996 })
1997 }
1998 Other(err)
1999 if err.is::<language_model::PaymentRequiredError>()
2000 || err.is::<language_model::ModelRequestLimitReachedError>() =>
2001 {
2002 // Retrying won't help for Payment Required or Model Request Limit errors (where
2003 // the user must upgrade to usage-based billing to get more requests, or else wait
2004 // for a significant amount of time for the request limit to reset).
2005 None
2006 }
            // For any other errors, conservatively retry a couple of times before giving up.
2008 HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
2009 delay: BASE_RETRY_DELAY,
2010 max_attempts: 2,
2011 }),
2012 }
2013 }
2014}
2015
2016struct RunningTurn {
2017 /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run the tools and report their results back to it.
2020 _task: Task<()>,
2021 /// The current event stream for the running turn. Used to report a final
2022 /// cancellation event if we cancel the turn.
2023 event_stream: ThreadEventStream,
2024 /// The tools that were enabled for this turn.
2025 tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
2026}
2027
2028impl RunningTurn {
2029 fn cancel(self) {
        log::debug!("Cancelling in-progress turn");
2031 self.event_stream.send_canceled();
2032 }
2033}
2034
2035pub struct TokenUsageUpdated(pub Option<acp_thread::TokenUsage>);
2036
2037impl EventEmitter<TokenUsageUpdated> for Thread {}
2038
2039pub struct TitleUpdated;
2040
2041impl EventEmitter<TitleUpdated> for Thread {}
2042
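/// A strongly typed tool that the agent can invoke during a turn.
///
/// Implementors define typed `Input` and `Output`; [`AgentTool::erase`] converts
/// the tool into an [`AnyAgentTool`] so it can be registered alongside tools with
/// different types.
///
/// A minimal sketch of an implementation (the `EchoTool` type below is
/// hypothetical and for illustration only; it also assumes `String` converts
/// into `LanguageModelToolResultContent`):
///
/// ```ignore
/// use std::sync::Arc;
///
/// use agent_client_protocol as acp;
/// use anyhow::Result;
/// use gpui::{App, SharedString, Task};
/// use schemars::JsonSchema;
/// use serde::{Deserialize, Serialize};
///
/// /// Echoes the provided text back to the model.
/// #[derive(Deserialize, Serialize, JsonSchema)]
/// struct EchoToolInput {
///     /// The text to echo back.
///     text: String,
/// }
///
/// struct EchoTool;
///
/// impl AgentTool for EchoTool {
///     type Input = EchoToolInput;
///     type Output = String;
///
///     fn name() -> &'static str {
///         "echo"
///     }
///
///     fn kind() -> acp::ToolKind {
///         acp::ToolKind::Other
///     }
///
///     fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
///         "Echo".into()
///     }
///
///     fn run(
///         self: Arc<Self>,
///         input: Self::Input,
///         _event_stream: ToolCallEventStream,
///         _cx: &mut App,
///     ) -> Task<Result<Self::Output>> {
///         Task::ready(Ok(input.text))
///     }
/// }
///
/// // Type-erase it so it can be stored with the other enabled tools:
/// // let tool: Arc<dyn AnyAgentTool> = EchoTool.erase();
/// ```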
2043pub trait AgentTool
2044where
2045 Self: 'static + Sized,
2046{
2047 type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
2048 type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
2049
2050 fn name() -> &'static str;
2051
2052 fn description(&self) -> SharedString {
2053 let schema = schemars::schema_for!(Self::Input);
2054 SharedString::new(
2055 schema
2056 .get("description")
2057 .and_then(|description| description.as_str())
2058 .unwrap_or_default(),
2059 )
2060 }
2061
2062 fn kind() -> acp::ToolKind;
2063
2064 /// The initial tool title to display. Can be updated during the tool run.
2065 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
2066
2067 /// Returns the JSON schema that describes the tool's input.
2068 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Schema {
2069 crate::tool_schema::root_schema_for::<Self::Input>(format)
2070 }
2071
    /// Some tools only work with specific providers (e.g. for billing or other reasons).
    /// This lets a tool report whether it is compatible with the given provider or
    /// should be filtered out of the enabled tool set.
2074 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
2075 true
2076 }
2077
2078 /// Runs the tool with the provided input.
2079 fn run(
2080 self: Arc<Self>,
2081 input: Self::Input,
2082 event_stream: ToolCallEventStream,
2083 cx: &mut App,
2084 ) -> Task<Result<Self::Output>>;
2085
2086 /// Emits events for a previous execution of the tool.
2087 fn replay(
2088 &self,
2089 _input: Self::Input,
2090 _output: Self::Output,
2091 _event_stream: ToolCallEventStream,
2092 _cx: &mut App,
2093 ) -> Result<()> {
2094 Ok(())
2095 }
2096
2097 fn erase(self) -> Arc<dyn AnyAgentTool> {
2098 Arc::new(Erased(Arc::new(self)))
2099 }
2100}
2101
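/// Wrapper created by [`AgentTool::erase`] to implement [`AnyAgentTool`] for a
/// concrete [`AgentTool`] implementation.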
2102pub struct Erased<T>(T);
2103
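/// The result of running a tool: the content handed back to the language model,
/// plus the raw JSON form of the tool's typed output.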
2104pub struct AgentToolOutput {
2105 pub llm_output: LanguageModelToolResultContent,
2106 pub raw_output: serde_json::Value,
2107}
2108
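/// Object-safe, type-erased counterpart of [`AgentTool`], allowing tools with
/// different `Input` and `Output` types to be stored and invoked uniformly.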
2109pub trait AnyAgentTool {
2110 fn name(&self) -> SharedString;
2111 fn description(&self) -> SharedString;
2112 fn kind(&self) -> acp::ToolKind;
2113 fn initial_title(&self, input: serde_json::Value) -> SharedString;
2114 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
2115 fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
2116 true
2117 }
2118 fn run(
2119 self: Arc<Self>,
2120 input: serde_json::Value,
2121 event_stream: ToolCallEventStream,
2122 cx: &mut App,
2123 ) -> Task<Result<AgentToolOutput>>;
2124 fn replay(
2125 &self,
2126 input: serde_json::Value,
2127 output: serde_json::Value,
2128 event_stream: ToolCallEventStream,
2129 cx: &mut App,
2130 ) -> Result<()>;
2131}
2132
2133impl<T> AnyAgentTool for Erased<Arc<T>>
2134where
2135 T: AgentTool,
2136{
2137 fn name(&self) -> SharedString {
2138 T::name().into()
2139 }
2140
2141 fn description(&self) -> SharedString {
2142 self.0.description()
2143 }
2144
2145 fn kind(&self) -> agent_client_protocol::ToolKind {
2146 T::kind()
2147 }
2148
2149 fn initial_title(&self, input: serde_json::Value) -> SharedString {
2150 let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
2151 self.0.initial_title(parsed_input)
2152 }
2153
2154 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
2155 let mut json = serde_json::to_value(self.0.input_schema(format))?;
2156 adapt_schema_to_format(&mut json, format)?;
2157 Ok(json)
2158 }
2159
2160 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
2161 self.0.supported_provider(provider)
2162 }
2163
2164 fn run(
2165 self: Arc<Self>,
2166 input: serde_json::Value,
2167 event_stream: ToolCallEventStream,
2168 cx: &mut App,
2169 ) -> Task<Result<AgentToolOutput>> {
2170 cx.spawn(async move |cx| {
2171 let input = serde_json::from_value(input)?;
2172 let output = cx
2173 .update(|cx| self.0.clone().run(input, event_stream, cx))?
2174 .await?;
2175 let raw_output = serde_json::to_value(&output)?;
2176 Ok(AgentToolOutput {
2177 llm_output: output.into(),
2178 raw_output,
2179 })
2180 })
2181 }
2182
2183 fn replay(
2184 &self,
2185 input: serde_json::Value,
2186 output: serde_json::Value,
2187 event_stream: ToolCallEventStream,
2188 cx: &mut App,
2189 ) -> Result<()> {
2190 let input = serde_json::from_value(input)?;
2191 let output = serde_json::from_value(output)?;
2192 self.0.replay(input, output, event_stream, cx)
2193 }
2194}
2195
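/// Sender half of a running turn's event channel. All sends are best-effort: if
/// the receiving end has been dropped, events are silently discarded.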
2196#[derive(Clone)]
2197struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
2198
2199impl ThreadEventStream {
2200 fn send_user_message(&self, message: &UserMessage) {
2201 self.0
2202 .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
2203 .ok();
2204 }
2205
2206 fn send_text(&self, text: &str) {
2207 self.0
2208 .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
2209 .ok();
2210 }
2211
2212 fn send_thinking(&self, text: &str) {
2213 self.0
2214 .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
2215 .ok();
2216 }
2217
2218 fn send_tool_call(
2219 &self,
2220 id: &LanguageModelToolUseId,
2221 title: SharedString,
2222 kind: acp::ToolKind,
2223 input: serde_json::Value,
2224 ) {
2225 self.0
2226 .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
2227 id,
2228 title.to_string(),
2229 kind,
2230 input,
2231 ))))
2232 .ok();
2233 }
2234
2235 fn initial_tool_call(
2236 id: &LanguageModelToolUseId,
2237 title: String,
2238 kind: acp::ToolKind,
2239 input: serde_json::Value,
2240 ) -> acp::ToolCall {
2241 acp::ToolCall {
2242 id: acp::ToolCallId(id.to_string().into()),
2243 title,
2244 kind,
2245 status: acp::ToolCallStatus::Pending,
2246 content: vec![],
2247 locations: vec![],
2248 raw_input: Some(input),
2249 raw_output: None,
2250 }
2251 }
2252
2253 fn update_tool_call_fields(
2254 &self,
2255 tool_use_id: &LanguageModelToolUseId,
2256 fields: acp::ToolCallUpdateFields,
2257 ) {
2258 self.0
2259 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2260 acp::ToolCallUpdate {
2261 id: acp::ToolCallId(tool_use_id.to_string().into()),
2262 fields,
2263 }
2264 .into(),
2265 )))
2266 .ok();
2267 }
2268
2269 fn send_retry(&self, status: acp_thread::RetryStatus) {
2270 self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
2271 }
2272
2273 fn send_stop(&self, reason: acp::StopReason) {
2274 self.0.unbounded_send(Ok(ThreadEvent::Stop(reason))).ok();
2275 }
2276
2277 fn send_canceled(&self) {
2278 self.0
2279 .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
2280 .ok();
2281 }
2282
2283 fn send_error(&self, error: impl Into<anyhow::Error>) {
2284 self.0.unbounded_send(Err(error.into())).ok();
2285 }
2286}
2287
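/// A view over the thread's event stream scoped to a single tool call, letting a
/// running tool report progress, diffs, terminal output, and authorization
/// requests.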
2288#[derive(Clone)]
2289pub struct ToolCallEventStream {
2290 tool_use_id: LanguageModelToolUseId,
2291 stream: ThreadEventStream,
2292 fs: Option<Arc<dyn Fs>>,
2293}
2294
2295impl ToolCallEventStream {
2296 #[cfg(test)]
2297 pub fn test() -> (Self, ToolCallEventStreamReceiver) {
2298 let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
2299
2300 let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
2301
2302 (stream, ToolCallEventStreamReceiver(events_rx))
2303 }
2304
2305 fn new(
2306 tool_use_id: LanguageModelToolUseId,
2307 stream: ThreadEventStream,
2308 fs: Option<Arc<dyn Fs>>,
2309 ) -> Self {
2310 Self {
2311 tool_use_id,
2312 stream,
2313 fs,
2314 }
2315 }
2316
2317 pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
2318 self.stream
2319 .update_tool_call_fields(&self.tool_use_id, fields);
2320 }
2321
2322 pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
2323 self.stream
2324 .0
2325 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2326 acp_thread::ToolCallUpdateDiff {
2327 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2328 diff,
2329 }
2330 .into(),
2331 )))
2332 .ok();
2333 }
2334
2335 pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
2336 self.stream
2337 .0
2338 .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2339 acp_thread::ToolCallUpdateTerminal {
2340 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2341 terminal,
2342 }
2343 .into(),
2344 )))
2345 .ok();
2346 }
2347
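    /// Asks the user to authorize this tool call, unless `always_allow_tool_actions`
    /// is already enabled in the agent settings.
    ///
    /// The returned task resolves to `Ok(())` when the user chooses "Allow" or
    /// "Always Allow" (the latter also persists the setting), and to an error when
    /// the request is denied.
    ///
    /// A sketch of gating a destructive action on approval (names hypothetical):
    ///
    /// ```ignore
    /// let authorized = event_stream.authorize("Delete build artifacts?", cx);
    /// cx.spawn(async move |_cx| {
    ///     authorized.await?; // Err(..) if the user denied the request.
    ///     // ...perform the destructive action...
    ///     anyhow::Ok(())
    /// })
    /// ```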
2348 pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
2349 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
2350 return Task::ready(Ok(()));
2351 }
2352
2353 let (response_tx, response_rx) = oneshot::channel();
2354 self.stream
2355 .0
2356 .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
2357 ToolCallAuthorization {
2358 tool_call: acp::ToolCallUpdate {
2359 id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2360 fields: acp::ToolCallUpdateFields {
2361 title: Some(title.into()),
2362 ..Default::default()
2363 },
2364 },
2365 options: vec![
2366 acp::PermissionOption {
2367 id: acp::PermissionOptionId("always_allow".into()),
2368 name: "Always Allow".into(),
2369 kind: acp::PermissionOptionKind::AllowAlways,
2370 },
2371 acp::PermissionOption {
2372 id: acp::PermissionOptionId("allow".into()),
2373 name: "Allow".into(),
2374 kind: acp::PermissionOptionKind::AllowOnce,
2375 },
2376 acp::PermissionOption {
2377 id: acp::PermissionOptionId("deny".into()),
2378 name: "Deny".into(),
2379 kind: acp::PermissionOptionKind::RejectOnce,
2380 },
2381 ],
2382 response: response_tx,
2383 },
2384 )))
2385 .ok();
2386 let fs = self.fs.clone();
2387 cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
2388 "always_allow" => {
2389 if let Some(fs) = fs.clone() {
2390 cx.update(|cx| {
2391 update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
2392 settings.set_always_allow_tool_actions(true);
2393 });
2394 })?;
2395 }
2396
2397 Ok(())
2398 }
2399 "allow" => Ok(()),
2400 _ => Err(anyhow!("Permission to run tool denied by user")),
2401 })
2402 }
2403}
2404
2405#[cfg(test)]
2406pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
2407
2408#[cfg(test)]
2409impl ToolCallEventStreamReceiver {
2410 pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
2411 let event = self.0.next().await;
2412 if let Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) = event {
2413 auth
2414 } else {
2415 panic!("Expected ToolCallAuthorization but got: {:?}", event);
2416 }
2417 }
2418
2419 pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
2420 let event = self.0.next().await;
2421 if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal(
2422 update,
2423 )))) = event
2424 {
2425 update.terminal
2426 } else {
2427 panic!("Expected terminal but got: {:?}", event);
2428 }
2429 }
2430}
2431
2432#[cfg(test)]
2433impl std::ops::Deref for ToolCallEventStreamReceiver {
2434 type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;
2435
2436 fn deref(&self) -> &Self::Target {
2437 &self.0
2438 }
2439}
2440
2441#[cfg(test)]
2442impl std::ops::DerefMut for ToolCallEventStreamReceiver {
2443 fn deref_mut(&mut self) -> &mut Self::Target {
2444 &mut self.0
2445 }
2446}
2447
2448impl From<&str> for UserMessageContent {
2449 fn from(text: &str) -> Self {
2450 Self::Text(text.into())
2451 }
2452}
2453
2454impl From<acp::ContentBlock> for UserMessageContent {
2455 fn from(value: acp::ContentBlock) -> Self {
2456 match value {
2457 acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
2458 acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
2459 acp::ContentBlock::Audio(_) => {
2460 // TODO
2461 Self::Text("[audio]".to_string())
2462 }
2463 acp::ContentBlock::ResourceLink(resource_link) => {
2464 match MentionUri::parse(&resource_link.uri) {
2465 Ok(uri) => Self::Mention {
2466 uri,
2467 content: String::new(),
2468 },
2469 Err(err) => {
2470 log::error!("Failed to parse mention link: {}", err);
2471 Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
2472 }
2473 }
2474 }
2475 acp::ContentBlock::Resource(resource) => match resource.resource {
2476 acp::EmbeddedResourceResource::TextResourceContents(resource) => {
2477 match MentionUri::parse(&resource.uri) {
2478 Ok(uri) => Self::Mention {
2479 uri,
2480 content: resource.text,
2481 },
2482 Err(err) => {
2483 log::error!("Failed to parse mention link: {}", err);
2484 Self::Text(
2485 MarkdownCodeBlock {
2486 tag: &resource.uri,
2487 text: &resource.text,
2488 }
2489 .to_string(),
2490 )
2491 }
2492 }
2493 }
2494 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
2495 // TODO
2496 Self::Text("[blob]".to_string())
2497 }
2498 },
2499 }
2500 }
2501}
2502
2503impl From<UserMessageContent> for acp::ContentBlock {
2504 fn from(content: UserMessageContent) -> Self {
2505 match content {
2506 UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
2507 text,
2508 annotations: None,
2509 }),
2510 UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
2511 data: image.source.to_string(),
2512 mime_type: "image/png".to_string(),
2513 annotations: None,
2514 uri: None,
2515 }),
2516 UserMessageContent::Mention { uri, content } => {
2517 acp::ContentBlock::Resource(acp::EmbeddedResource {
2518 resource: acp::EmbeddedResourceResource::TextResourceContents(
2519 acp::TextResourceContents {
2520 mime_type: None,
2521 text: content,
2522 uri: uri.to_uri().to_string(),
2523 },
2524 ),
2525 annotations: None,
2526 })
2527 }
2528 }
2529 }
2530}
2531
2532fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
2533 LanguageModelImage {
2534 source: image_content.data.into(),
2535 // TODO: make this optional?
2536 size: gpui::Size::new(0.into(), 0.into()),
2537 }
2538}