thread.rs

   1use crate::{
   2    ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DbLanguageModel, DbThread,
   3    DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool,
   4    ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, SystemPromptTemplate,
   5    Template, Templates, TerminalTool, ThinkingTool, WebSearchTool,
   6};
   7use acp_thread::{MentionUri, UserMessageId};
   8use action_log::ActionLog;
   9use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot};
  10use agent_client_protocol as acp;
  11use agent_settings::{
  12    AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
  13    SUMMARIZE_THREAD_PROMPT,
  14};
  15use anyhow::{Context as _, Result, anyhow};
  16use assistant_tool::adapt_schema_to_format;
  17use chrono::{DateTime, Utc};
  18use client::{ModelRequestUsage, RequestUsage};
  19use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
  20use collections::{HashMap, IndexMap};
  21use fs::Fs;
  22use futures::{
  23    FutureExt,
  24    channel::{mpsc, oneshot},
  25    future::Shared,
  26    stream::FuturesUnordered,
  27};
  28use git::repository::DiffType;
  29use gpui::{
  30    App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
  31};
  32use language_model::{
  33    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
  34    LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
  35    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  36    LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
  37    LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
  38};
  39use project::{
  40    Project,
  41    git_store::{GitStore, RepositoryState},
  42};
  43use prompt_store::ProjectContext;
  44use schemars::{JsonSchema, Schema};
  45use serde::{Deserialize, Serialize};
  46use settings::{Settings, update_settings_file};
  47use smol::stream::StreamExt;
  48use std::{
  49    collections::BTreeMap,
  50    path::Path,
  51    sync::Arc,
  52    time::{Duration, Instant},
  53};
  54use std::{fmt::Write, ops::Range};
  55use util::{ResultExt, markdown::MarkdownCodeBlock};
  56use uuid::Uuid;
  57
/// Text used to describe a tool call that was canceled by the user.
///
/// NOTE(review): not referenced in the visible part of this file; presumably used as the
/// tool result content for canceled calls — confirm at the use site.
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
  59
  60/// The ID of the user prompt that initiated a request.
  61///
  62/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  63#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  64pub struct PromptId(Arc<str>);
  65
  66impl PromptId {
  67    pub fn new() -> Self {
  68        Self(Uuid::new_v4().to_string().into())
  69    }
  70}
  71
  72impl std::fmt::Display for PromptId {
  73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  74        write!(f, "{}", self.0)
  75    }
  76}
  77
/// Maximum number of automatic retry attempts for a failed request.
///
/// NOTE(review): consumed by retry logic outside this view — confirm exact semantics
/// (attempts vs. retries) at the use site.
pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
/// Base delay between retries (5 seconds); how it is scaled is decided by the caller.
pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);

/// How a failed request should be retried.
#[derive(Debug, Clone)]
enum RetryStrategy {
    /// Back off starting from `initial_delay`, for at most `max_attempts` attempts.
    /// By name the delay is meant to grow exponentially; the actual scaling is
    /// implemented by the caller — TODO confirm.
    ExponentialBackoff {
        initial_delay: Duration,
        max_attempts: u8,
    },
    /// Wait the same `delay` before each attempt, for at most `max_attempts` attempts.
    Fixed {
        delay: Duration,
        max_attempts: u8,
    },
}
  92
/// A single entry in a thread's conversation history.
// Variant order is part of the serialized form for index-based serde formats; do not reorder.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Message {
    /// A message submitted by the user.
    User(UserMessage),
    /// A response from the model: content chunks plus any tool results.
    Agent(AgentMessage),
    /// Marker that the conversation was resumed. It is sent to the model as a
    /// "Continue where you left off" user message and rendered in Markdown as
    /// "[resumed after tool use limit was reached]".
    Resume,
}
  99
 100impl Message {
 101    pub fn as_agent_message(&self) -> Option<&AgentMessage> {
 102        match self {
 103            Message::Agent(agent_message) => Some(agent_message),
 104            _ => None,
 105        }
 106    }
 107
 108    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
 109        match self {
 110            Message::User(message) => vec![message.to_request()],
 111            Message::Agent(message) => message.to_request(),
 112            Message::Resume => vec![LanguageModelRequestMessage {
 113                role: Role::User,
 114                content: vec!["Continue where you left off".into()],
 115                cache: false,
 116            }],
 117        }
 118    }
 119
 120    pub fn to_markdown(&self) -> String {
 121        match self {
 122            Message::User(message) => message.to_markdown(),
 123            Message::Agent(message) => message.to_markdown(),
 124            Message::Resume => "[resumed after tool use limit was reached]".into(),
 125        }
 126    }
 127
 128    pub fn role(&self) -> Role {
 129        match self {
 130            Message::User(_) | Message::Resume => Role::User,
 131            Message::Agent(_) => Role::Assistant,
 132        }
 133    }
 134}
 135
/// A message submitted by the user, made of ordered content chunks.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessage {
    /// Identifier of this user message.
    pub id: UserMessageId,
    /// Content chunks, in the order they appear in the message.
    pub content: Vec<UserMessageContent>,
}

/// One chunk of a [`UserMessage`].
// Variant order is part of the serialized form for index-based serde formats; do not reorder.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UserMessageContent {
    /// Plain text.
    Text(String),
    /// A mentioned resource (file, directory, symbol, selection, thread, rule or
    /// fetched URL) together with its loaded `content`, which may be empty.
    Mention { uri: MentionUri, content: String },
    /// An attached image.
    Image(LanguageModelImage),
}
 148
 149impl UserMessage {
 150    pub fn to_markdown(&self) -> String {
 151        let mut markdown = String::from("## User\n\n");
 152
 153        for content in &self.content {
 154            match content {
 155                UserMessageContent::Text(text) => {
 156                    markdown.push_str(text);
 157                    markdown.push('\n');
 158                }
 159                UserMessageContent::Image(_) => {
 160                    markdown.push_str("<image />\n");
 161                }
 162                UserMessageContent::Mention { uri, content } => {
 163                    if !content.is_empty() {
 164                        let _ = writeln!(&mut markdown, "{}\n\n{}", uri.as_link(), content);
 165                    } else {
 166                        let _ = writeln!(&mut markdown, "{}", uri.as_link());
 167                    }
 168                }
 169            }
 170        }
 171
 172        markdown
 173    }
 174
 175    fn to_request(&self) -> LanguageModelRequestMessage {
 176        let mut message = LanguageModelRequestMessage {
 177            role: Role::User,
 178            content: Vec::with_capacity(self.content.len()),
 179            cache: false,
 180        };
 181
 182        const OPEN_CONTEXT: &str = "<context>\n\
 183            The following items were attached by the user. \
 184            They are up-to-date and don't need to be re-read.\n\n";
 185
 186        const OPEN_FILES_TAG: &str = "<files>";
 187        const OPEN_DIRECTORIES_TAG: &str = "<directories>";
 188        const OPEN_SYMBOLS_TAG: &str = "<symbols>";
 189        const OPEN_THREADS_TAG: &str = "<threads>";
 190        const OPEN_FETCH_TAG: &str = "<fetched_urls>";
 191        const OPEN_RULES_TAG: &str =
 192            "<rules>\nThe user has specified the following rules that should be applied:\n";
 193
 194        let mut file_context = OPEN_FILES_TAG.to_string();
 195        let mut directory_context = OPEN_DIRECTORIES_TAG.to_string();
 196        let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
 197        let mut thread_context = OPEN_THREADS_TAG.to_string();
 198        let mut fetch_context = OPEN_FETCH_TAG.to_string();
 199        let mut rules_context = OPEN_RULES_TAG.to_string();
 200
 201        for chunk in &self.content {
 202            let chunk = match chunk {
 203                UserMessageContent::Text(text) => {
 204                    language_model::MessageContent::Text(text.clone())
 205                }
 206                UserMessageContent::Image(value) => {
 207                    language_model::MessageContent::Image(value.clone())
 208                }
 209                UserMessageContent::Mention { uri, content } => {
 210                    match uri {
 211                        MentionUri::File { abs_path } => {
 212                            write!(
 213                                &mut symbol_context,
 214                                "\n{}",
 215                                MarkdownCodeBlock {
 216                                    tag: &codeblock_tag(abs_path, None),
 217                                    text: &content.to_string(),
 218                                }
 219                            )
 220                            .ok();
 221                        }
 222                        MentionUri::Directory { .. } => {
 223                            write!(&mut directory_context, "\n{}\n", content).ok();
 224                        }
 225                        MentionUri::Symbol {
 226                            path, line_range, ..
 227                        }
 228                        | MentionUri::Selection {
 229                            path, line_range, ..
 230                        } => {
 231                            write!(
 232                                &mut rules_context,
 233                                "\n{}",
 234                                MarkdownCodeBlock {
 235                                    tag: &codeblock_tag(path, Some(line_range)),
 236                                    text: content
 237                                }
 238                            )
 239                            .ok();
 240                        }
 241                        MentionUri::Thread { .. } => {
 242                            write!(&mut thread_context, "\n{}\n", content).ok();
 243                        }
 244                        MentionUri::TextThread { .. } => {
 245                            write!(&mut thread_context, "\n{}\n", content).ok();
 246                        }
 247                        MentionUri::Rule { .. } => {
 248                            write!(
 249                                &mut rules_context,
 250                                "\n{}",
 251                                MarkdownCodeBlock {
 252                                    tag: "",
 253                                    text: content
 254                                }
 255                            )
 256                            .ok();
 257                        }
 258                        MentionUri::Fetch { url } => {
 259                            write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
 260                        }
 261                    }
 262
 263                    language_model::MessageContent::Text(uri.as_link().to_string())
 264                }
 265            };
 266
 267            message.content.push(chunk);
 268        }
 269
 270        let len_before_context = message.content.len();
 271
 272        if file_context.len() > OPEN_FILES_TAG.len() {
 273            file_context.push_str("</files>\n");
 274            message
 275                .content
 276                .push(language_model::MessageContent::Text(file_context));
 277        }
 278
 279        if directory_context.len() > OPEN_DIRECTORIES_TAG.len() {
 280            directory_context.push_str("</directories>\n");
 281            message
 282                .content
 283                .push(language_model::MessageContent::Text(directory_context));
 284        }
 285
 286        if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
 287            symbol_context.push_str("</symbols>\n");
 288            message
 289                .content
 290                .push(language_model::MessageContent::Text(symbol_context));
 291        }
 292
 293        if thread_context.len() > OPEN_THREADS_TAG.len() {
 294            thread_context.push_str("</threads>\n");
 295            message
 296                .content
 297                .push(language_model::MessageContent::Text(thread_context));
 298        }
 299
 300        if fetch_context.len() > OPEN_FETCH_TAG.len() {
 301            fetch_context.push_str("</fetched_urls>\n");
 302            message
 303                .content
 304                .push(language_model::MessageContent::Text(fetch_context));
 305        }
 306
 307        if rules_context.len() > OPEN_RULES_TAG.len() {
 308            rules_context.push_str("</user_rules>\n");
 309            message
 310                .content
 311                .push(language_model::MessageContent::Text(rules_context));
 312        }
 313
 314        if message.content.len() > len_before_context {
 315            message.content.insert(
 316                len_before_context,
 317                language_model::MessageContent::Text(OPEN_CONTEXT.into()),
 318            );
 319            message
 320                .content
 321                .push(language_model::MessageContent::Text("</context>".into()));
 322        }
 323
 324        message
 325    }
 326}
 327
 328fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
 329    let mut result = String::new();
 330
 331    if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) {
 332        let _ = write!(result, "{} ", extension);
 333    }
 334
 335    let _ = write!(result, "{}", full_path.display());
 336
 337    if let Some(range) = line_range {
 338        if range.start == range.end {
 339            let _ = write!(result, ":{}", range.start + 1);
 340        } else {
 341            let _ = write!(result, ":{}-{}", range.start + 1, range.end + 1);
 342        }
 343    }
 344
 345    result
 346}
 347
 348impl AgentMessage {
 349    pub fn to_markdown(&self) -> String {
 350        let mut markdown = String::from("## Assistant\n\n");
 351
 352        for content in &self.content {
 353            match content {
 354                AgentMessageContent::Text(text) => {
 355                    markdown.push_str(text);
 356                    markdown.push('\n');
 357                }
 358                AgentMessageContent::Thinking { text, .. } => {
 359                    markdown.push_str("<think>");
 360                    markdown.push_str(text);
 361                    markdown.push_str("</think>\n");
 362                }
 363                AgentMessageContent::RedactedThinking(_) => {
 364                    markdown.push_str("<redacted_thinking />\n")
 365                }
 366                AgentMessageContent::ToolUse(tool_use) => {
 367                    markdown.push_str(&format!(
 368                        "**Tool Use**: {} (ID: {})\n",
 369                        tool_use.name, tool_use.id
 370                    ));
 371                    markdown.push_str(&format!(
 372                        "{}\n",
 373                        MarkdownCodeBlock {
 374                            tag: "json",
 375                            text: &format!("{:#}", tool_use.input)
 376                        }
 377                    ));
 378                }
 379            }
 380        }
 381
 382        for tool_result in self.tool_results.values() {
 383            markdown.push_str(&format!(
 384                "**Tool Result**: {} (ID: {})\n\n",
 385                tool_result.tool_name, tool_result.tool_use_id
 386            ));
 387            if tool_result.is_error {
 388                markdown.push_str("**ERROR:**\n");
 389            }
 390
 391            match &tool_result.content {
 392                LanguageModelToolResultContent::Text(text) => {
 393                    writeln!(markdown, "{text}\n").ok();
 394                }
 395                LanguageModelToolResultContent::Image(_) => {
 396                    writeln!(markdown, "<image />\n").ok();
 397                }
 398            }
 399
 400            if let Some(output) = tool_result.output.as_ref() {
 401                writeln!(
 402                    markdown,
 403                    "**Debug Output**:\n\n```json\n{}\n```\n",
 404                    serde_json::to_string_pretty(output).unwrap()
 405                )
 406                .unwrap();
 407            }
 408        }
 409
 410        markdown
 411    }
 412
 413    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
 414        let mut assistant_message = LanguageModelRequestMessage {
 415            role: Role::Assistant,
 416            content: Vec::with_capacity(self.content.len()),
 417            cache: false,
 418        };
 419        for chunk in &self.content {
 420            let chunk = match chunk {
 421                AgentMessageContent::Text(text) => {
 422                    language_model::MessageContent::Text(text.clone())
 423                }
 424                AgentMessageContent::Thinking { text, signature } => {
 425                    language_model::MessageContent::Thinking {
 426                        text: text.clone(),
 427                        signature: signature.clone(),
 428                    }
 429                }
 430                AgentMessageContent::RedactedThinking(value) => {
 431                    language_model::MessageContent::RedactedThinking(value.clone())
 432                }
 433                AgentMessageContent::ToolUse(value) => {
 434                    language_model::MessageContent::ToolUse(value.clone())
 435                }
 436            };
 437            assistant_message.content.push(chunk);
 438        }
 439
 440        let mut user_message = LanguageModelRequestMessage {
 441            role: Role::User,
 442            content: Vec::new(),
 443            cache: false,
 444        };
 445
 446        for tool_result in self.tool_results.values() {
 447            user_message
 448                .content
 449                .push(language_model::MessageContent::ToolResult(
 450                    tool_result.clone(),
 451                ));
 452        }
 453
 454        let mut messages = Vec::new();
 455        if !assistant_message.content.is_empty() {
 456            messages.push(assistant_message);
 457        }
 458        if !user_message.content.is_empty() {
 459            messages.push(user_message);
 460        }
 461        messages
 462    }
 463}
 464
/// A response produced by the model during a turn.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AgentMessage {
    /// Content chunks (text, thinking, tool uses) in the order they were produced.
    pub content: Vec<AgentMessageContent>,
    /// Results of this message's tool uses, keyed by tool use id. `IndexMap` keeps
    /// insertion order, which is the order results are rendered and sent back.
    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
}

/// One chunk of an [`AgentMessage`].
// Variant order is part of the serialized form for index-based serde formats; do not reorder.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AgentMessageContent {
    /// Visible response text.
    Text(String),
    /// Model reasoning text, with an optional provider signature that is passed
    /// back unchanged in requests.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque redacted reasoning; sent back in requests but never displayed.
    RedactedThinking(String),
    /// A tool invocation requested by the model.
    ToolUse(LanguageModelToolUse),
}
 481
/// Events emitted by a [`Thread`] to its observers (for example during `replay`).
#[derive(Debug)]
pub enum ThreadEvent {
    /// A user message was added or replayed.
    UserMessage(UserMessage),
    /// Agent response text.
    AgentText(String),
    /// Agent thinking text.
    AgentThinking(String),
    /// A new tool call was started.
    ToolCall(acp::ToolCall),
    /// An existing tool call changed (status, output, content, ...).
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A tool call needs the user's permission before it can run.
    ToolCallAuthorization(ToolCallAuthorization),
    /// A failed request is being retried.
    Retry(acp_thread::RetryStatus),
    /// The turn ended with the given stop reason.
    Stop(acp::StopReason),
}

/// A pending permission request for a tool call.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call the permission request refers to.
    pub tool_call: acp::ToolCallUpdate,
    /// The choices offered to the user.
    pub options: Vec<acp::PermissionOption>,
    /// One-shot channel on which the id of the chosen option is sent back.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}

/// Errors that can end a completion request.
#[derive(Debug, thiserror::Error)]
enum CompletionError {
    /// The model stopped because it hit its output token limit
    /// (presumably mapped from a max-tokens stop reason — TODO confirm).
    #[error("max tokens")]
    MaxTokens,
    /// The model refused to respond.
    #[error("refusal")]
    Refusal,
    /// Any other error, displayed transparently.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
 510
/// An agent conversation for a single ACP session.
// Fields are dropped in declaration order (including the `Task`s below); do not reorder casually.
pub struct Thread {
    /// Session id: freshly generated in `new`, restored in `from_db`.
    id: acp::SessionId,
    /// Id of the current user prompt; `new` and `from_db` both start with a fresh one.
    prompt_id: PromptId,
    /// Last-updated timestamp, persisted by `to_db`.
    updated_at: DateTime<Utc>,
    /// Thread title. `None` until set; an empty stored title is loaded as `None`.
    title: Option<SharedString>,
    /// In-flight title generation task, if any (managed outside this view).
    pending_title_generation: Option<Task<()>>,
    /// Detailed summary, persisted as `DbThread::detailed_summary`.
    summary: Option<SharedString>,
    /// Conversation history.
    messages: Vec<Message>,
    /// Completion mode. `new` takes it from `AgentSettings::preferred_completion_mode`.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<RunningTurn>,
    /// Agent message being assembled for the current turn (presumably; TODO confirm).
    pending_message: Option<AgentMessage>,
    /// Registered tools, keyed by tool name (looked up via `LanguageModelToolUse::name`).
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Whether the tool-use limit was hit; starts `false` and is not persisted by `to_db`.
    tool_use_limit_reached: bool,
    /// Token usage per user message, persisted by `to_db`.
    request_token_usage: HashMap<UserMessageId, language_model::TokenUsage>,
    // NOTE(review): this field is read in `to_db`, so this `allow` looks stale.
    #[allow(unused)]
    cumulative_token_usage: TokenUsage,
    /// Snapshot of the project taken when the thread was created (shared so `to_db`
    /// can await it repeatedly).
    // NOTE(review): this field is read in `to_db`, so this `allow` looks stale.
    #[allow(unused)]
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Registry of context servers available to this thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Active agent profile; persisted by `to_db`.
    profile_id: AgentProfileId,
    /// Project context entity passed in at construction.
    project_context: Entity<ProjectContext>,
    /// Prompt templates.
    templates: Arc<Templates>,
    /// Model used for completions, if one is selected.
    model: Option<Arc<dyn LanguageModel>>,
    /// Model used for summarization; `None` after construction.
    summarization_model: Option<Arc<dyn LanguageModel>>,
    pub(crate) project: Entity<Project>,
    pub(crate) action_log: Entity<ActionLog>,
}
 541
 542impl Thread {
 543    pub fn new(
 544        project: Entity<Project>,
 545        project_context: Entity<ProjectContext>,
 546        context_server_registry: Entity<ContextServerRegistry>,
 547        action_log: Entity<ActionLog>,
 548        templates: Arc<Templates>,
 549        model: Option<Arc<dyn LanguageModel>>,
 550        cx: &mut Context<Self>,
 551    ) -> Self {
 552        let profile_id = AgentSettings::get_global(cx).default_profile.clone();
 553        Self {
 554            id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()),
 555            prompt_id: PromptId::new(),
 556            updated_at: Utc::now(),
 557            title: None,
 558            pending_title_generation: None,
 559            summary: None,
 560            messages: Vec::new(),
 561            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
 562            running_turn: None,
 563            pending_message: None,
 564            tools: BTreeMap::default(),
 565            tool_use_limit_reached: false,
 566            request_token_usage: HashMap::default(),
 567            cumulative_token_usage: TokenUsage::default(),
 568            initial_project_snapshot: {
 569                let project_snapshot = Self::project_snapshot(project.clone(), cx);
 570                cx.foreground_executor()
 571                    .spawn(async move { Some(project_snapshot.await) })
 572                    .shared()
 573            },
 574            context_server_registry,
 575            profile_id,
 576            project_context,
 577            templates,
 578            model,
 579            summarization_model: None,
 580            project,
 581            action_log,
 582        }
 583    }
 584
 585    pub fn id(&self) -> &acp::SessionId {
 586        &self.id
 587    }
 588
 589    pub fn replay(
 590        &mut self,
 591        cx: &mut Context<Self>,
 592    ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
 593        let (tx, rx) = mpsc::unbounded();
 594        let stream = ThreadEventStream(tx);
 595        for message in &self.messages {
 596            match message {
 597                Message::User(user_message) => stream.send_user_message(user_message),
 598                Message::Agent(assistant_message) => {
 599                    for content in &assistant_message.content {
 600                        match content {
 601                            AgentMessageContent::Text(text) => stream.send_text(text),
 602                            AgentMessageContent::Thinking { text, .. } => {
 603                                stream.send_thinking(text)
 604                            }
 605                            AgentMessageContent::RedactedThinking(_) => {}
 606                            AgentMessageContent::ToolUse(tool_use) => {
 607                                self.replay_tool_call(
 608                                    tool_use,
 609                                    assistant_message.tool_results.get(&tool_use.id),
 610                                    &stream,
 611                                    cx,
 612                                );
 613                            }
 614                        }
 615                    }
 616                }
 617                Message::Resume => {}
 618            }
 619        }
 620        rx
 621    }
 622
 623    fn replay_tool_call(
 624        &self,
 625        tool_use: &LanguageModelToolUse,
 626        tool_result: Option<&LanguageModelToolResult>,
 627        stream: &ThreadEventStream,
 628        cx: &mut Context<Self>,
 629    ) {
 630        let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
 631            stream
 632                .0
 633                .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
 634                    id: acp::ToolCallId(tool_use.id.to_string().into()),
 635                    title: tool_use.name.to_string(),
 636                    kind: acp::ToolKind::Other,
 637                    status: acp::ToolCallStatus::Failed,
 638                    content: Vec::new(),
 639                    locations: Vec::new(),
 640                    raw_input: Some(tool_use.input.clone()),
 641                    raw_output: None,
 642                })))
 643                .ok();
 644            return;
 645        };
 646
 647        let title = tool.initial_title(tool_use.input.clone());
 648        let kind = tool.kind();
 649        stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
 650
 651        let output = tool_result
 652            .as_ref()
 653            .and_then(|result| result.output.clone());
 654        if let Some(output) = output.clone() {
 655            let tool_event_stream = ToolCallEventStream::new(
 656                tool_use.id.clone(),
 657                stream.clone(),
 658                Some(self.project.read(cx).fs().clone()),
 659            );
 660            tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
 661                .log_err();
 662        }
 663
 664        stream.update_tool_call_fields(
 665            &tool_use.id,
 666            acp::ToolCallUpdateFields {
 667                status: Some(acp::ToolCallStatus::Completed),
 668                raw_output: output,
 669                ..Default::default()
 670            },
 671        );
 672    }
 673
 674    pub fn from_db(
 675        id: acp::SessionId,
 676        db_thread: DbThread,
 677        project: Entity<Project>,
 678        project_context: Entity<ProjectContext>,
 679        context_server_registry: Entity<ContextServerRegistry>,
 680        action_log: Entity<ActionLog>,
 681        templates: Arc<Templates>,
 682        cx: &mut Context<Self>,
 683    ) -> Self {
 684        let profile_id = db_thread
 685            .profile
 686            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
 687        let model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
 688            db_thread
 689                .model
 690                .and_then(|model| {
 691                    let model = SelectedModel {
 692                        provider: model.provider.clone().into(),
 693                        model: model.model.into(),
 694                    };
 695                    registry.select_model(&model, cx)
 696                })
 697                .or_else(|| registry.default_model())
 698                .map(|model| model.model)
 699        });
 700
 701        Self {
 702            id,
 703            prompt_id: PromptId::new(),
 704            title: if db_thread.title.is_empty() {
 705                None
 706            } else {
 707                Some(db_thread.title.clone())
 708            },
 709            pending_title_generation: None,
 710            summary: db_thread.detailed_summary,
 711            messages: db_thread.messages,
 712            completion_mode: db_thread.completion_mode.unwrap_or_default(),
 713            running_turn: None,
 714            pending_message: None,
 715            tools: BTreeMap::default(),
 716            tool_use_limit_reached: false,
 717            request_token_usage: db_thread.request_token_usage.clone(),
 718            cumulative_token_usage: db_thread.cumulative_token_usage,
 719            initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
 720            context_server_registry,
 721            profile_id,
 722            project_context,
 723            templates,
 724            model,
 725            summarization_model: None,
 726            project,
 727            action_log,
 728            updated_at: db_thread.updated_at,
 729        }
 730    }
 731
 732    pub fn to_db(&self, cx: &App) -> Task<DbThread> {
 733        let initial_project_snapshot = self.initial_project_snapshot.clone();
 734        let mut thread = DbThread {
 735            title: self.title(),
 736            messages: self.messages.clone(),
 737            updated_at: self.updated_at,
 738            detailed_summary: self.summary.clone(),
 739            initial_project_snapshot: None,
 740            cumulative_token_usage: self.cumulative_token_usage,
 741            request_token_usage: self.request_token_usage.clone(),
 742            model: self.model.as_ref().map(|model| DbLanguageModel {
 743                provider: model.provider_id().to_string(),
 744                model: model.name().0.to_string(),
 745            }),
 746            completion_mode: Some(self.completion_mode),
 747            profile: Some(self.profile_id.clone()),
 748        };
 749
 750        cx.background_spawn(async move {
 751            let initial_project_snapshot = initial_project_snapshot.await;
 752            thread.initial_project_snapshot = initial_project_snapshot;
 753            thread
 754        })
 755    }
 756
 757    /// Create a snapshot of the current project state including git information and unsaved buffers.
 758    fn project_snapshot(
 759        project: Entity<Project>,
 760        cx: &mut Context<Self>,
 761    ) -> Task<Arc<agent::thread::ProjectSnapshot>> {
 762        let git_store = project.read(cx).git_store().clone();
 763        let worktree_snapshots: Vec<_> = project
 764            .read(cx)
 765            .visible_worktrees(cx)
 766            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
 767            .collect();
 768
 769        cx.spawn(async move |_, cx| {
 770            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
 771
 772            let mut unsaved_buffers = Vec::new();
 773            cx.update(|app_cx| {
 774                let buffer_store = project.read(app_cx).buffer_store();
 775                for buffer_handle in buffer_store.read(app_cx).buffers() {
 776                    let buffer = buffer_handle.read(app_cx);
 777                    if buffer.is_dirty()
 778                        && let Some(file) = buffer.file()
 779                    {
 780                        let path = file.path().to_string_lossy().to_string();
 781                        unsaved_buffers.push(path);
 782                    }
 783                }
 784            })
 785            .ok();
 786
 787            Arc::new(ProjectSnapshot {
 788                worktree_snapshots,
 789                unsaved_buffer_paths: unsaved_buffers,
 790                timestamp: Utc::now(),
 791            })
 792        })
 793    }
 794
 795    fn worktree_snapshot(
 796        worktree: Entity<project::Worktree>,
 797        git_store: Entity<GitStore>,
 798        cx: &App,
 799    ) -> Task<agent::thread::WorktreeSnapshot> {
 800        cx.spawn(async move |cx| {
 801            // Get worktree path and snapshot
 802            let worktree_info = cx.update(|app_cx| {
 803                let worktree = worktree.read(app_cx);
 804                let path = worktree.abs_path().to_string_lossy().to_string();
 805                let snapshot = worktree.snapshot();
 806                (path, snapshot)
 807            });
 808
 809            let Ok((worktree_path, _snapshot)) = worktree_info else {
 810                return WorktreeSnapshot {
 811                    worktree_path: String::new(),
 812                    git_state: None,
 813                };
 814            };
 815
 816            let git_state = git_store
 817                .update(cx, |git_store, cx| {
 818                    git_store
 819                        .repositories()
 820                        .values()
 821                        .find(|repo| {
 822                            repo.read(cx)
 823                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
 824                                .is_some()
 825                        })
 826                        .cloned()
 827                })
 828                .ok()
 829                .flatten()
 830                .map(|repo| {
 831                    repo.update(cx, |repo, _| {
 832                        let current_branch =
 833                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
 834                        repo.send_job(None, |state, _| async move {
 835                            let RepositoryState::Local { backend, .. } = state else {
 836                                return GitState {
 837                                    remote_url: None,
 838                                    head_sha: None,
 839                                    current_branch,
 840                                    diff: None,
 841                                };
 842                            };
 843
 844                            let remote_url = backend.remote_url("origin");
 845                            let head_sha = backend.head_sha().await;
 846                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
 847
 848                            GitState {
 849                                remote_url,
 850                                head_sha,
 851                                current_branch,
 852                                diff,
 853                            }
 854                        })
 855                    })
 856                });
 857
 858            let git_state = match git_state {
 859                Some(git_state) => match git_state.ok() {
 860                    Some(git_state) => git_state.await.ok(),
 861                    None => None,
 862                },
 863                None => None,
 864            };
 865
 866            WorktreeSnapshot {
 867                worktree_path,
 868                git_state,
 869            }
 870        })
 871    }
 872
 873    pub fn project_context(&self) -> &Entity<ProjectContext> {
 874        &self.project_context
 875    }
 876
 877    pub fn project(&self) -> &Entity<Project> {
 878        &self.project
 879    }
 880
 881    pub fn action_log(&self) -> &Entity<ActionLog> {
 882        &self.action_log
 883    }
 884
 885    pub fn is_empty(&self) -> bool {
 886        self.messages.is_empty() && self.title.is_none()
 887    }
 888
 889    pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
 890        self.model.as_ref()
 891    }
 892
 893    pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
 894        let old_usage = self.latest_token_usage();
 895        self.model = Some(model);
 896        let new_usage = self.latest_token_usage();
 897        if old_usage != new_usage {
 898            cx.emit(TokenUsageUpdated(new_usage));
 899        }
 900        cx.notify()
 901    }
 902
 903    pub fn summarization_model(&self) -> Option<&Arc<dyn LanguageModel>> {
 904        self.summarization_model.as_ref()
 905    }
 906
 907    pub fn set_summarization_model(
 908        &mut self,
 909        model: Option<Arc<dyn LanguageModel>>,
 910        cx: &mut Context<Self>,
 911    ) {
 912        self.summarization_model = model;
 913        cx.notify()
 914    }
 915
 916    pub fn completion_mode(&self) -> CompletionMode {
 917        self.completion_mode
 918    }
 919
 920    pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
 921        let old_usage = self.latest_token_usage();
 922        self.completion_mode = mode;
 923        let new_usage = self.latest_token_usage();
 924        if old_usage != new_usage {
 925            cx.emit(TokenUsageUpdated(new_usage));
 926        }
 927        cx.notify()
 928    }
 929
 930    #[cfg(any(test, feature = "test-support"))]
 931    pub fn last_message(&self) -> Option<Message> {
 932        if let Some(message) = self.pending_message.clone() {
 933            Some(Message::Agent(message))
 934        } else {
 935            self.messages.last().cloned()
 936        }
 937    }
 938
 939    pub fn add_default_tools(&mut self, cx: &mut Context<Self>) {
 940        let language_registry = self.project.read(cx).languages().clone();
 941        self.add_tool(CopyPathTool::new(self.project.clone()));
 942        self.add_tool(CreateDirectoryTool::new(self.project.clone()));
 943        self.add_tool(DeletePathTool::new(
 944            self.project.clone(),
 945            self.action_log.clone(),
 946        ));
 947        self.add_tool(DiagnosticsTool::new(self.project.clone()));
 948        self.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
 949        self.add_tool(FetchTool::new(self.project.read(cx).client().http_client()));
 950        self.add_tool(FindPathTool::new(self.project.clone()));
 951        self.add_tool(GrepTool::new(self.project.clone()));
 952        self.add_tool(ListDirectoryTool::new(self.project.clone()));
 953        self.add_tool(MovePathTool::new(self.project.clone()));
 954        self.add_tool(NowTool);
 955        self.add_tool(OpenTool::new(self.project.clone()));
 956        self.add_tool(ReadFileTool::new(
 957            self.project.clone(),
 958            self.action_log.clone(),
 959        ));
 960        self.add_tool(TerminalTool::new(self.project.clone(), cx));
 961        self.add_tool(ThinkingTool);
 962        self.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
 963    }
 964
 965    pub fn add_tool(&mut self, tool: impl AgentTool) {
 966        self.tools.insert(tool.name(), tool.erase());
 967    }
 968
 969    pub fn remove_tool(&mut self, name: &str) -> bool {
 970        self.tools.remove(name).is_some()
 971    }
 972
 973    pub fn profile(&self) -> &AgentProfileId {
 974        &self.profile_id
 975    }
 976
 977    pub fn set_profile(&mut self, profile_id: AgentProfileId) {
 978        self.profile_id = profile_id;
 979    }
 980
 981    pub fn cancel(&mut self, cx: &mut Context<Self>) {
 982        if let Some(running_turn) = self.running_turn.take() {
 983            running_turn.cancel();
 984        }
 985        self.flush_pending_message(cx);
 986    }
 987
 988    fn update_token_usage(&mut self, update: language_model::TokenUsage, cx: &mut Context<Self>) {
 989        let Some(last_user_message) = self.last_user_message() else {
 990            return;
 991        };
 992
 993        self.request_token_usage
 994            .insert(last_user_message.id.clone(), update);
 995        cx.emit(TokenUsageUpdated(self.latest_token_usage()));
 996        cx.notify();
 997    }
 998
 999    pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
1000        self.cancel(cx);
1001        let Some(position) = self.messages.iter().position(
1002            |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
1003        ) else {
1004            return Err(anyhow!("Message not found"));
1005        };
1006
1007        for message in self.messages.drain(position..) {
1008            match message {
1009                Message::User(message) => {
1010                    self.request_token_usage.remove(&message.id);
1011                }
1012                Message::Agent(_) | Message::Resume => {}
1013            }
1014        }
1015        self.summary = None;
1016        cx.notify();
1017        Ok(())
1018    }
1019
1020    pub fn latest_token_usage(&self) -> Option<acp_thread::TokenUsage> {
1021        let last_user_message = self.last_user_message()?;
1022        let tokens = self.request_token_usage.get(&last_user_message.id)?;
1023        let model = self.model.clone()?;
1024
1025        Some(acp_thread::TokenUsage {
1026            max_tokens: model.max_token_count_for_mode(self.completion_mode.into()),
1027            used_tokens: tokens.total_tokens(),
1028        })
1029    }
1030
1031    pub fn resume(
1032        &mut self,
1033        cx: &mut Context<Self>,
1034    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
1035        anyhow::ensure!(
1036            self.tool_use_limit_reached,
1037            "can only resume after tool use limit is reached"
1038        );
1039
1040        self.messages.push(Message::Resume);
1041        cx.notify();
1042
1043        log::info!("Total messages in thread: {}", self.messages.len());
1044        self.run_turn(cx)
1045    }
1046
1047    /// Sending a message results in the model streaming a response, which could include tool calls.
1048    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
1049    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
1050    pub fn send<T>(
1051        &mut self,
1052        id: UserMessageId,
1053        content: impl IntoIterator<Item = T>,
1054        cx: &mut Context<Self>,
1055    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>
1056    where
1057        T: Into<UserMessageContent>,
1058    {
1059        let model = self.model().context("No language model configured")?;
1060
1061        log::info!("Thread::send called with model: {:?}", model.name());
1062        self.advance_prompt_id();
1063
1064        let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
1065        log::debug!("Thread::send content: {:?}", content);
1066
1067        self.messages
1068            .push(Message::User(UserMessage { id, content }));
1069        cx.notify();
1070
1071        log::info!("Total messages in thread: {}", self.messages.len());
1072        self.run_turn(cx)
1073    }
1074
1075    fn run_turn(
1076        &mut self,
1077        cx: &mut Context<Self>,
1078    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
1079        self.cancel(cx);
1080
1081        let model = self.model.clone().context("No language model configured")?;
1082        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
1083        let event_stream = ThreadEventStream(events_tx);
1084        let message_ix = self.messages.len().saturating_sub(1);
1085        self.tool_use_limit_reached = false;
1086        self.summary = None;
1087        self.running_turn = Some(RunningTurn {
1088            event_stream: event_stream.clone(),
1089            _task: cx.spawn(async move |this, cx| {
1090                log::info!("Starting agent turn execution");
1091
1092                let turn_result: Result<()> = async {
1093                    let mut intent = CompletionIntent::UserPrompt;
1094                    loop {
1095                        Self::stream_completion(&this, &model, intent, &event_stream, cx).await?;
1096
1097                        let mut end_turn = true;
1098                        this.update(cx, |this, cx| {
1099                            // Generate title if needed.
1100                            if this.title.is_none() && this.pending_title_generation.is_none() {
1101                                this.generate_title(cx);
1102                            }
1103
1104                            // End the turn if the model didn't use tools.
1105                            let message = this.pending_message.as_ref();
1106                            end_turn =
1107                                message.map_or(true, |message| message.tool_results.is_empty());
1108                            this.flush_pending_message(cx);
1109                        })?;
1110
1111                        if this.read_with(cx, |this, _| this.tool_use_limit_reached)? {
1112                            log::info!("Tool use limit reached, completing turn");
1113                            return Err(language_model::ToolUseLimitReachedError.into());
1114                        } else if end_turn {
1115                            log::info!("No tool uses found, completing turn");
1116                            return Ok(());
1117                        } else {
1118                            intent = CompletionIntent::ToolResults;
1119                        }
1120                    }
1121                }
1122                .await;
1123                _ = this.update(cx, |this, cx| this.flush_pending_message(cx));
1124
1125                match turn_result {
1126                    Ok(()) => {
1127                        log::info!("Turn execution completed");
1128                        event_stream.send_stop(acp::StopReason::EndTurn);
1129                    }
1130                    Err(error) => {
1131                        log::error!("Turn execution failed: {:?}", error);
1132                        match error.downcast::<CompletionError>() {
1133                            Ok(CompletionError::Refusal) => {
1134                                event_stream.send_stop(acp::StopReason::Refusal);
1135                                _ = this.update(cx, |this, _| this.messages.truncate(message_ix));
1136                            }
1137                            Ok(CompletionError::MaxTokens) => {
1138                                event_stream.send_stop(acp::StopReason::MaxTokens);
1139                            }
1140                            Ok(CompletionError::Other(error)) | Err(error) => {
1141                                event_stream.send_error(error);
1142                            }
1143                        }
1144                    }
1145                }
1146
1147                _ = this.update(cx, |this, _| this.running_turn.take());
1148            }),
1149        });
1150        Ok(events_rx)
1151    }
1152
1153    async fn stream_completion(
1154        this: &WeakEntity<Self>,
1155        model: &Arc<dyn LanguageModel>,
1156        completion_intent: CompletionIntent,
1157        event_stream: &ThreadEventStream,
1158        cx: &mut AsyncApp,
1159    ) -> Result<()> {
1160        log::debug!("Stream completion started successfully");
1161        let request = this.update(cx, |this, cx| {
1162            this.build_completion_request(completion_intent, cx)
1163        })??;
1164
1165        let mut attempt = None;
1166        'retry: loop {
1167            telemetry::event!(
1168                "Agent Thread Completion",
1169                thread_id = this.read_with(cx, |this, _| this.id.to_string())?,
1170                prompt_id = this.read_with(cx, |this, _| this.prompt_id.to_string())?,
1171                model = model.telemetry_id(),
1172                model_provider = model.provider_id().to_string(),
1173                attempt
1174            );
1175
1176            log::info!(
1177                "Calling model.stream_completion, attempt {}",
1178                attempt.unwrap_or(0)
1179            );
1180            let mut events = model
1181                .stream_completion(request.clone(), cx)
1182                .await
1183                .map_err(|error| anyhow!(error))?;
1184            let mut tool_results = FuturesUnordered::new();
1185
1186            while let Some(event) = events.next().await {
1187                match event {
1188                    Ok(event) => {
1189                        log::trace!("Received completion event: {:?}", event);
1190                        tool_results.extend(this.update(cx, |this, cx| {
1191                            this.handle_streamed_completion_event(event, event_stream, cx)
1192                        })??);
1193                    }
1194                    Err(error) => {
1195                        let completion_mode =
1196                            this.read_with(cx, |thread, _cx| thread.completion_mode())?;
1197                        if completion_mode == CompletionMode::Normal {
1198                            return Err(anyhow!(error))?;
1199                        }
1200
1201                        let Some(strategy) = Self::retry_strategy_for(&error) else {
1202                            return Err(anyhow!(error))?;
1203                        };
1204
1205                        let max_attempts = match &strategy {
1206                            RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
1207                            RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
1208                        };
1209
1210                        let attempt = attempt.get_or_insert(0u8);
1211
1212                        *attempt += 1;
1213
1214                        let attempt = *attempt;
1215                        if attempt > max_attempts {
1216                            return Err(anyhow!(error))?;
1217                        }
1218
1219                        let delay = match &strategy {
1220                            RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
1221                                let delay_secs =
1222                                    initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
1223                                Duration::from_secs(delay_secs)
1224                            }
1225                            RetryStrategy::Fixed { delay, .. } => *delay,
1226                        };
1227                        log::debug!("Retry attempt {attempt} with delay {delay:?}");
1228
1229                        event_stream.send_retry(acp_thread::RetryStatus {
1230                            last_error: error.to_string().into(),
1231                            attempt: attempt as usize,
1232                            max_attempts: max_attempts as usize,
1233                            started_at: Instant::now(),
1234                            duration: delay,
1235                        });
1236
1237                        cx.background_executor().timer(delay).await;
1238                        continue 'retry;
1239                    }
1240                }
1241            }
1242
1243            while let Some(tool_result) = tool_results.next().await {
1244                log::info!("Tool finished {:?}", tool_result);
1245
1246                event_stream.update_tool_call_fields(
1247                    &tool_result.tool_use_id,
1248                    acp::ToolCallUpdateFields {
1249                        status: Some(if tool_result.is_error {
1250                            acp::ToolCallStatus::Failed
1251                        } else {
1252                            acp::ToolCallStatus::Completed
1253                        }),
1254                        raw_output: tool_result.output.clone(),
1255                        ..Default::default()
1256                    },
1257                );
1258                this.update(cx, |this, _cx| {
1259                    this.pending_message()
1260                        .tool_results
1261                        .insert(tool_result.tool_use_id.clone(), tool_result);
1262                })?;
1263            }
1264
1265            return Ok(());
1266        }
1267    }
1268
1269    pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
1270        log::debug!("Building system message");
1271        let prompt = SystemPromptTemplate {
1272            project: self.project_context.read(cx),
1273            available_tools: self.tools.keys().cloned().collect(),
1274        }
1275        .render(&self.templates)
1276        .context("failed to build system prompt")
1277        .expect("Invalid template");
1278        log::debug!("System message built");
1279        LanguageModelRequestMessage {
1280            role: Role::System,
1281            content: vec![prompt.into()],
1282            cache: true,
1283        }
1284    }
1285
1286    /// A helper method that's called on every streamed completion event.
1287    /// Returns an optional tool result task, which the main agentic loop will
1288    /// send back to the model when it resolves.
1289    fn handle_streamed_completion_event(
1290        &mut self,
1291        event: LanguageModelCompletionEvent,
1292        event_stream: &ThreadEventStream,
1293        cx: &mut Context<Self>,
1294    ) -> Result<Option<Task<LanguageModelToolResult>>> {
1295        log::trace!("Handling streamed completion event: {:?}", event);
1296        use LanguageModelCompletionEvent::*;
1297
1298        match event {
1299            StartMessage { .. } => {
1300                self.flush_pending_message(cx);
1301                self.pending_message = Some(AgentMessage::default());
1302            }
1303            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
1304            Thinking { text, signature } => {
1305                self.handle_thinking_event(text, signature, event_stream, cx)
1306            }
1307            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
1308            ToolUse(tool_use) => {
1309                return Ok(self.handle_tool_use_event(tool_use, event_stream, cx));
1310            }
1311            ToolUseJsonParseError {
1312                id,
1313                tool_name,
1314                raw_input,
1315                json_parse_error,
1316            } => {
1317                return Ok(Some(Task::ready(
1318                    self.handle_tool_use_json_parse_error_event(
1319                        id,
1320                        tool_name,
1321                        raw_input,
1322                        json_parse_error,
1323                    ),
1324                )));
1325            }
1326            UsageUpdate(usage) => {
1327                telemetry::event!(
1328                    "Agent Thread Completion Usage Updated",
1329                    thread_id = self.id.to_string(),
1330                    prompt_id = self.prompt_id.to_string(),
1331                    model = self.model.as_ref().map(|m| m.telemetry_id()),
1332                    model_provider = self.model.as_ref().map(|m| m.provider_id().to_string()),
1333                    input_tokens = usage.input_tokens,
1334                    output_tokens = usage.output_tokens,
1335                    cache_creation_input_tokens = usage.cache_creation_input_tokens,
1336                    cache_read_input_tokens = usage.cache_read_input_tokens,
1337                );
1338                self.update_token_usage(usage, cx);
1339            }
1340            StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => {
1341                self.update_model_request_usage(amount, limit, cx);
1342            }
1343            StatusUpdate(
1344                CompletionRequestStatus::Started
1345                | CompletionRequestStatus::Queued { .. }
1346                | CompletionRequestStatus::Failed { .. },
1347            ) => {}
1348            StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => {
1349                self.tool_use_limit_reached = true;
1350            }
1351            Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
1352            Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
1353            Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
1354        }
1355
1356        Ok(None)
1357    }
1358
1359    fn handle_text_event(
1360        &mut self,
1361        new_text: String,
1362        event_stream: &ThreadEventStream,
1363        cx: &mut Context<Self>,
1364    ) {
1365        event_stream.send_text(&new_text);
1366
1367        let last_message = self.pending_message();
1368        if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
1369            text.push_str(&new_text);
1370        } else {
1371            last_message
1372                .content
1373                .push(AgentMessageContent::Text(new_text));
1374        }
1375
1376        cx.notify();
1377    }
1378
1379    fn handle_thinking_event(
1380        &mut self,
1381        new_text: String,
1382        new_signature: Option<String>,
1383        event_stream: &ThreadEventStream,
1384        cx: &mut Context<Self>,
1385    ) {
1386        event_stream.send_thinking(&new_text);
1387
1388        let last_message = self.pending_message();
1389        if let Some(AgentMessageContent::Thinking { text, signature }) =
1390            last_message.content.last_mut()
1391        {
1392            text.push_str(&new_text);
1393            *signature = new_signature.or(signature.take());
1394        } else {
1395            last_message.content.push(AgentMessageContent::Thinking {
1396                text: new_text,
1397                signature: new_signature,
1398            });
1399        }
1400
1401        cx.notify();
1402    }
1403
1404    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
1405        let last_message = self.pending_message();
1406        last_message
1407            .content
1408            .push(AgentMessageContent::RedactedThinking(data));
1409        cx.notify();
1410    }
1411
1412    fn handle_tool_use_event(
1413        &mut self,
1414        tool_use: LanguageModelToolUse,
1415        event_stream: &ThreadEventStream,
1416        cx: &mut Context<Self>,
1417    ) -> Option<Task<LanguageModelToolResult>> {
1418        cx.notify();
1419
1420        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
1421        let mut title = SharedString::from(&tool_use.name);
1422        let mut kind = acp::ToolKind::Other;
1423        if let Some(tool) = tool.as_ref() {
1424            title = tool.initial_title(tool_use.input.clone());
1425            kind = tool.kind();
1426        }
1427
1428        // Ensure the last message ends in the current tool use
1429        let last_message = self.pending_message();
1430        let push_new_tool_use = last_message.content.last_mut().is_none_or(|content| {
1431            if let AgentMessageContent::ToolUse(last_tool_use) = content {
1432                if last_tool_use.id == tool_use.id {
1433                    *last_tool_use = tool_use.clone();
1434                    false
1435                } else {
1436                    true
1437                }
1438            } else {
1439                true
1440            }
1441        });
1442
1443        if push_new_tool_use {
1444            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
1445            last_message
1446                .content
1447                .push(AgentMessageContent::ToolUse(tool_use.clone()));
1448        } else {
1449            event_stream.update_tool_call_fields(
1450                &tool_use.id,
1451                acp::ToolCallUpdateFields {
1452                    title: Some(title.into()),
1453                    kind: Some(kind),
1454                    raw_input: Some(tool_use.input.clone()),
1455                    ..Default::default()
1456                },
1457            );
1458        }
1459
1460        if !tool_use.is_input_complete {
1461            return None;
1462        }
1463
1464        let Some(tool) = tool else {
1465            let content = format!("No tool named {} exists", tool_use.name);
1466            return Some(Task::ready(LanguageModelToolResult {
1467                content: LanguageModelToolResultContent::Text(Arc::from(content)),
1468                tool_use_id: tool_use.id,
1469                tool_name: tool_use.name,
1470                is_error: true,
1471                output: None,
1472            }));
1473        };
1474
1475        let fs = self.project.read(cx).fs().clone();
1476        let tool_event_stream =
1477            ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1478        tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1479            status: Some(acp::ToolCallStatus::InProgress),
1480            ..Default::default()
1481        });
1482        let supports_images = self.model().is_some_and(|model| model.supports_images());
1483        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1484        log::info!("Running tool {}", tool_use.name);
1485        Some(cx.foreground_executor().spawn(async move {
1486            let tool_result = tool_result.await.and_then(|output| {
1487                if let LanguageModelToolResultContent::Image(_) = &output.llm_output
1488                    && !supports_images
1489                {
1490                    return Err(anyhow!(
1491                        "Attempted to read an image, but this model doesn't support it.",
1492                    ));
1493                }
1494                Ok(output)
1495            });
1496
1497            match tool_result {
1498                Ok(output) => LanguageModelToolResult {
1499                    tool_use_id: tool_use.id,
1500                    tool_name: tool_use.name,
1501                    is_error: false,
1502                    content: output.llm_output,
1503                    output: Some(output.raw_output),
1504                },
1505                Err(error) => LanguageModelToolResult {
1506                    tool_use_id: tool_use.id,
1507                    tool_name: tool_use.name,
1508                    is_error: true,
1509                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1510                    output: None,
1511                },
1512            }
1513        }))
1514    }
1515
1516    fn handle_tool_use_json_parse_error_event(
1517        &mut self,
1518        tool_use_id: LanguageModelToolUseId,
1519        tool_name: Arc<str>,
1520        raw_input: Arc<str>,
1521        json_parse_error: String,
1522    ) -> LanguageModelToolResult {
1523        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1524        LanguageModelToolResult {
1525            tool_use_id,
1526            tool_name,
1527            is_error: true,
1528            content: LanguageModelToolResultContent::Text(tool_output.into()),
1529            output: Some(serde_json::Value::String(raw_input.to_string())),
1530        }
1531    }
1532
1533    fn update_model_request_usage(&self, amount: usize, limit: UsageLimit, cx: &mut Context<Self>) {
1534        self.project
1535            .read(cx)
1536            .user_store()
1537            .update(cx, |user_store, cx| {
1538                user_store.update_model_request_usage(
1539                    ModelRequestUsage(RequestUsage {
1540                        amount: amount as i32,
1541                        limit,
1542                    }),
1543                    cx,
1544                )
1545            });
1546    }
1547
1548    pub fn title(&self) -> SharedString {
1549        self.title.clone().unwrap_or("New Thread".into())
1550    }
1551
1552    pub fn summary(&mut self, cx: &mut Context<Self>) -> Task<Result<SharedString>> {
1553        if let Some(summary) = self.summary.as_ref() {
1554            return Task::ready(Ok(summary.clone()));
1555        }
1556        let Some(model) = self.summarization_model.clone() else {
1557            return Task::ready(Err(anyhow!("No summarization model available")));
1558        };
1559        let mut request = LanguageModelRequest {
1560            intent: Some(CompletionIntent::ThreadContextSummarization),
1561            temperature: AgentSettings::temperature_for_model(&model, cx),
1562            ..Default::default()
1563        };
1564
1565        for message in &self.messages {
1566            request.messages.extend(message.to_request());
1567        }
1568
1569        request.messages.push(LanguageModelRequestMessage {
1570            role: Role::User,
1571            content: vec![SUMMARIZE_THREAD_DETAILED_PROMPT.into()],
1572            cache: false,
1573        });
1574        cx.spawn(async move |this, cx| {
1575            let mut summary = String::new();
1576            let mut messages = model.stream_completion(request, cx).await?;
1577            while let Some(event) = messages.next().await {
1578                let event = event?;
1579                let text = match event {
1580                    LanguageModelCompletionEvent::Text(text) => text,
1581                    LanguageModelCompletionEvent::StatusUpdate(
1582                        CompletionRequestStatus::UsageUpdated { amount, limit },
1583                    ) => {
1584                        this.update(cx, |thread, cx| {
1585                            thread.update_model_request_usage(amount, limit, cx);
1586                        })?;
1587                        continue;
1588                    }
1589                    _ => continue,
1590                };
1591
1592                let mut lines = text.lines();
1593                summary.extend(lines.next());
1594            }
1595
1596            log::info!("Setting summary: {}", summary);
1597            let summary = SharedString::from(summary);
1598
1599            this.update(cx, |this, cx| {
1600                this.summary = Some(summary.clone());
1601                cx.notify()
1602            })?;
1603
1604            Ok(summary)
1605        })
1606    }
1607
1608    fn generate_title(&mut self, cx: &mut Context<Self>) {
1609        let Some(model) = self.summarization_model.clone() else {
1610            return;
1611        };
1612
1613        log::info!(
1614            "Generating title with model: {:?}",
1615            self.summarization_model.as_ref().map(|model| model.name())
1616        );
1617        let mut request = LanguageModelRequest {
1618            intent: Some(CompletionIntent::ThreadSummarization),
1619            temperature: AgentSettings::temperature_for_model(&model, cx),
1620            ..Default::default()
1621        };
1622
1623        for message in &self.messages {
1624            request.messages.extend(message.to_request());
1625        }
1626
1627        request.messages.push(LanguageModelRequestMessage {
1628            role: Role::User,
1629            content: vec![SUMMARIZE_THREAD_PROMPT.into()],
1630            cache: false,
1631        });
1632        self.pending_title_generation = Some(cx.spawn(async move |this, cx| {
1633            let mut title = String::new();
1634
1635            let generate = async {
1636                let mut messages = model.stream_completion(request, cx).await?;
1637                while let Some(event) = messages.next().await {
1638                    let event = event?;
1639                    let text = match event {
1640                        LanguageModelCompletionEvent::Text(text) => text,
1641                        LanguageModelCompletionEvent::StatusUpdate(
1642                            CompletionRequestStatus::UsageUpdated { amount, limit },
1643                        ) => {
1644                            this.update(cx, |thread, cx| {
1645                                thread.update_model_request_usage(amount, limit, cx);
1646                            })?;
1647                            continue;
1648                        }
1649                        _ => continue,
1650                    };
1651
1652                    let mut lines = text.lines();
1653                    title.extend(lines.next());
1654
1655                    // Stop if the LLM generated multiple lines.
1656                    if lines.next().is_some() {
1657                        break;
1658                    }
1659                }
1660                anyhow::Ok(())
1661            };
1662
1663            if generate.await.context("failed to generate title").is_ok() {
1664                _ = this.update(cx, |this, cx| this.set_title(title.into(), cx));
1665            }
1666            _ = this.update(cx, |this, _| this.pending_title_generation = None);
1667        }));
1668    }
1669
1670    pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) {
1671        self.pending_title_generation = None;
1672        if Some(&title) != self.title.as_ref() {
1673            self.title = Some(title);
1674            cx.emit(TitleUpdated);
1675            cx.notify();
1676        }
1677    }
1678
1679    fn last_user_message(&self) -> Option<&UserMessage> {
1680        self.messages
1681            .iter()
1682            .rev()
1683            .find_map(|message| match message {
1684                Message::User(user_message) => Some(user_message),
1685                Message::Agent(_) => None,
1686                Message::Resume => None,
1687            })
1688    }
1689
1690    fn pending_message(&mut self) -> &mut AgentMessage {
1691        self.pending_message.get_or_insert_default()
1692    }
1693
1694    fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
1695        let Some(mut message) = self.pending_message.take() else {
1696            return;
1697        };
1698
1699        for content in &message.content {
1700            let AgentMessageContent::ToolUse(tool_use) = content else {
1701                continue;
1702            };
1703
1704            if !message.tool_results.contains_key(&tool_use.id) {
1705                message.tool_results.insert(
1706                    tool_use.id.clone(),
1707                    LanguageModelToolResult {
1708                        tool_use_id: tool_use.id.clone(),
1709                        tool_name: tool_use.name.clone(),
1710                        is_error: true,
1711                        content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1712                        output: None,
1713                    },
1714                );
1715            }
1716        }
1717
1718        self.messages.push(Message::Agent(message));
1719        self.updated_at = Utc::now();
1720        self.summary = None;
1721        cx.notify()
1722    }
1723
1724    pub(crate) fn build_completion_request(
1725        &self,
1726        completion_intent: CompletionIntent,
1727        cx: &mut App,
1728    ) -> Result<LanguageModelRequest> {
1729        let model = self.model().context("No language model configured")?;
1730
1731        log::debug!("Building completion request");
1732        log::debug!("Completion intent: {:?}", completion_intent);
1733        log::debug!("Completion mode: {:?}", self.completion_mode);
1734
1735        let messages = self.build_request_messages(cx);
1736        log::info!("Request will include {} messages", messages.len());
1737
1738        let tools = if let Some(tools) = self.tools(cx).log_err() {
1739            tools
1740                .filter_map(|tool| {
1741                    let tool_name = tool.name().to_string();
1742                    log::trace!("Including tool: {}", tool_name);
1743                    Some(LanguageModelRequestTool {
1744                        name: tool_name,
1745                        description: tool.description().to_string(),
1746                        input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
1747                    })
1748                })
1749                .collect()
1750        } else {
1751            Vec::new()
1752        };
1753
1754        log::info!("Request includes {} tools", tools.len());
1755
1756        let request = LanguageModelRequest {
1757            thread_id: Some(self.id.to_string()),
1758            prompt_id: Some(self.prompt_id.to_string()),
1759            intent: Some(completion_intent),
1760            mode: Some(self.completion_mode.into()),
1761            messages,
1762            tools,
1763            tool_choice: None,
1764            stop: Vec::new(),
1765            temperature: AgentSettings::temperature_for_model(model, cx),
1766            thinking_allowed: true,
1767        };
1768
1769        log::debug!("Completion request built successfully");
1770        Ok(request)
1771    }
1772
1773    fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
1774        let model = self.model().context("No language model configured")?;
1775
1776        let profile = AgentSettings::get_global(cx)
1777            .profiles
1778            .get(&self.profile_id)
1779            .context("profile not found")?;
1780        let provider_id = model.provider_id();
1781
1782        Ok(self
1783            .tools
1784            .iter()
1785            .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1786            .filter_map(|(tool_name, tool)| {
1787                if profile.is_tool_enabled(tool_name) {
1788                    Some(tool)
1789                } else {
1790                    None
1791                }
1792            })
1793            .chain(self.context_server_registry.read(cx).servers().flat_map(
1794                |(server_id, tools)| {
1795                    tools.iter().filter_map(|(tool_name, tool)| {
1796                        if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1797                            Some(tool)
1798                        } else {
1799                            None
1800                        }
1801                    })
1802                },
1803            )))
1804    }
1805
1806    fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
1807        log::trace!(
1808            "Building request messages from {} thread messages",
1809            self.messages.len()
1810        );
1811        let mut messages = vec![self.build_system_message(cx)];
1812        for message in &self.messages {
1813            messages.extend(message.to_request());
1814        }
1815
1816        if let Some(message) = self.pending_message.as_ref() {
1817            messages.extend(message.to_request());
1818        }
1819
1820        if let Some(last_user_message) = messages
1821            .iter_mut()
1822            .rev()
1823            .find(|message| message.role == Role::User)
1824        {
1825            last_user_message.cache = true;
1826        }
1827
1828        messages
1829    }
1830
1831    pub fn to_markdown(&self) -> String {
1832        let mut markdown = String::new();
1833        for (ix, message) in self.messages.iter().enumerate() {
1834            if ix > 0 {
1835                markdown.push('\n');
1836            }
1837            markdown.push_str(&message.to_markdown());
1838        }
1839
1840        if let Some(message) = self.pending_message.as_ref() {
1841            markdown.push('\n');
1842            markdown.push_str(&message.to_markdown());
1843        }
1844
1845        markdown
1846    }
1847
1848    fn advance_prompt_id(&mut self) {
1849        self.prompt_id = PromptId::new();
1850    }
1851
    /// Chooses how (and whether) to retry a failed completion request.
    ///
    /// Returns `None` for errors that retrying cannot fix. Otherwise it returns a
    /// [`RetryStrategy`] with either a fixed delay (the provider's `retry_after`
    /// when given, else `BASE_RETRY_DELAY`) or exponential backoff, plus a
    /// `max_attempts` budget.
    fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
        use LanguageModelCompletionError::*;
        use http_client::StatusCode;

        // General strategy here:
        // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
        // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), use a max_attempts of
        //   MAX_RETRY_ATTEMPTS, waiting for the provider's `retry_after` when it supplies one (a bare HTTP 429
        //   response uses exponential backoff instead).
        // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), use a max_attempts of 3.
        // - Errors not recognized above still get a small fallback budget (max_attempts of 1 or 2).
        match error {
            HttpResponseError {
                status_code: StatusCode::TOO_MANY_REQUESTS,
                ..
            } => Some(RetryStrategy::ExponentialBackoff {
                initial_delay: BASE_RETRY_DELAY,
                max_attempts: MAX_RETRY_ATTEMPTS,
            }),
            ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
                Some(RetryStrategy::Fixed {
                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                    max_attempts: MAX_RETRY_ATTEMPTS,
                })
            }
            UpstreamProviderError {
                status,
                retry_after,
                ..
            } => match *status {
                StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
                    Some(RetryStrategy::Fixed {
                        delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                        max_attempts: MAX_RETRY_ATTEMPTS,
                    })
                }
                StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                    // Internal Server Error could be anything, retry up to 3 times.
                    max_attempts: 3,
                }),
                status => {
                    // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
                    // but we frequently get them in practice. See https://http.dev/529
                    //
                    // NOTE(review): every other upstream status (including 4xx such as 401/403, which the
                    // direct `HttpResponseError` arms below treat as non-retryable) falls into the
                    // max_attempts: 2 branch — confirm that is intended.
                    if status.as_u16() == 529 {
                        Some(RetryStrategy::Fixed {
                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                            max_attempts: MAX_RETRY_ATTEMPTS,
                        })
                    } else {
                        Some(RetryStrategy::Fixed {
                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                            max_attempts: 2,
                        })
                    }
                }
            },
            ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 3,
            }),
            // Transport/response-handling failures: fixed delay, max_attempts of 3.
            // NOTE(review): `BadRequestFormat` is unlikely to succeed on resend — confirm retrying it is intended.
            ApiReadResponseError { .. }
            | HttpSend { .. }
            | DeserializeResponse { .. }
            | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 3,
            }),
            // Retrying these errors definitely shouldn't help.
            HttpResponseError {
                status_code:
                    StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
                ..
            }
            | AuthenticationError { .. }
            | PermissionError { .. }
            | NoApiKey { .. }
            | ApiEndpointNotFound { .. }
            | PromptTooLarge { .. } => None,
            // These errors might be transient, so retry them (max_attempts of 1).
            SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 1,
            }),
            // Any other 4xx or 5xx response not matched above: fixed delay, max_attempts of 3.
            HttpResponseError { status_code, .. }
                if status_code.is_client_error() || status_code.is_server_error() =>
            {
                Some(RetryStrategy::Fixed {
                    delay: BASE_RETRY_DELAY,
                    max_attempts: 3,
                })
            }
            Other(err)
                if err.is::<language_model::PaymentRequiredError>()
                    || err.is::<language_model::ModelRequestLimitReachedError>() =>
            {
                // Retrying won't help for Payment Required or Model Request Limit errors (where
                // the user must upgrade to usage-based billing to get more requests, or else wait
                // for a significant amount of time for the request limit to reset).
                None
            }
            // Anything else (non-4xx/5xx HTTP responses, unrecognized `Other` errors) is still
            // retried as a fallback, with a max_attempts of 2.
            HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 2,
            }),
        }
    }
1958}
1959
/// Bookkeeping for the agent turn that is currently in progress.
struct RunningTurn {
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run the tools and report their results.
    _task: Task<()>,
    /// The current event stream for the running turn. Used to report a final
    /// cancellation event if we cancel the turn.
    event_stream: ThreadEventStream,
}
1969
1970impl RunningTurn {
1971    fn cancel(self) {
1972        log::debug!("Cancelling in progress turn");
1973        self.event_stream.send_canceled();
1974    }
1975}
1976
/// Event emitted by [`Thread`] with an optional [`acp_thread::TokenUsage`] payload.
pub struct TokenUsageUpdated(pub Option<acp_thread::TokenUsage>);

impl EventEmitter<TokenUsageUpdated> for Thread {}

/// Event emitted by [`Thread`] when its title changes (see `Thread::set_title`).
pub struct TitleUpdated;

impl EventEmitter<TitleUpdated> for Thread {}
1984
1985pub trait AgentTool
1986where
1987    Self: 'static + Sized,
1988{
1989    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1990    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1991
1992    fn name(&self) -> SharedString;
1993
1994    fn description(&self) -> SharedString {
1995        let schema = schemars::schema_for!(Self::Input);
1996        SharedString::new(
1997            schema
1998                .get("description")
1999                .and_then(|description| description.as_str())
2000                .unwrap_or_default(),
2001        )
2002    }
2003
2004    fn kind(&self) -> acp::ToolKind;
2005
2006    /// The initial tool title to display. Can be updated during the tool run.
2007    fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
2008
2009    /// Returns the JSON schema that describes the tool's input.
2010    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Schema {
2011        crate::tool_schema::root_schema_for::<Self::Input>(format)
2012    }
2013
2014    /// Some tools rely on a provider for the underlying billing or other reasons.
2015    /// Allow the tool to check if they are compatible, or should be filtered out.
2016    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
2017        true
2018    }
2019
2020    /// Runs the tool with the provided input.
2021    fn run(
2022        self: Arc<Self>,
2023        input: Self::Input,
2024        event_stream: ToolCallEventStream,
2025        cx: &mut App,
2026    ) -> Task<Result<Self::Output>>;
2027
2028    /// Emits events for a previous execution of the tool.
2029    fn replay(
2030        &self,
2031        _input: Self::Input,
2032        _output: Self::Output,
2033        _event_stream: ToolCallEventStream,
2034        _cx: &mut App,
2035    ) -> Result<()> {
2036        Ok(())
2037    }
2038
2039    fn erase(self) -> Arc<dyn AnyAgentTool> {
2040        Arc::new(Erased(Arc::new(self)))
2041    }
2042}
2043
/// Type-erasing wrapper: `Erased<Arc<T>>` adapts a typed [`AgentTool`] to the
/// JSON-based [`AnyAgentTool`] interface.
pub struct Erased<T>(T);

/// Output of running a tool through [`AnyAgentTool::run`].
pub struct AgentToolOutput {
    /// Tool-result content for the language model, converted from the tool's typed output.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's typed output serialized to JSON.
    pub raw_output: serde_json::Value,
}
2050
/// Object-safe, JSON-based tool interface.
///
/// Typed [`AgentTool`]s are adapted to it via [`AgentTool::erase`]. The thread
/// hands tools around as `Arc<dyn AnyAgentTool>`.
pub trait AnyAgentTool {
    /// Tool name sent to the model in the completion request.
    fn name(&self) -> SharedString;
    /// Tool description sent to the model in the completion request.
    fn description(&self) -> SharedString;
    /// ACP tool kind reported for calls to this tool.
    fn kind(&self) -> acp::ToolKind;
    /// Initial title for a tool call, given its raw JSON input.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// JSON schema of the tool's input, adapted to the model's schema `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether the tool may be offered for models from `_provider`; defaults to `true`.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool on raw JSON input, returning model-facing and raw JSON output.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
    /// Replays a previous execution from its recorded JSON input and output.
    fn replay(
        &self,
        input: serde_json::Value,
        output: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Result<()>;
}
2074
2075impl<T> AnyAgentTool for Erased<Arc<T>>
2076where
2077    T: AgentTool,
2078{
2079    fn name(&self) -> SharedString {
2080        self.0.name()
2081    }
2082
2083    fn description(&self) -> SharedString {
2084        self.0.description()
2085    }
2086
2087    fn kind(&self) -> agent_client_protocol::ToolKind {
2088        self.0.kind()
2089    }
2090
2091    fn initial_title(&self, input: serde_json::Value) -> SharedString {
2092        let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
2093        self.0.initial_title(parsed_input)
2094    }
2095
2096    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
2097        let mut json = serde_json::to_value(self.0.input_schema(format))?;
2098        adapt_schema_to_format(&mut json, format)?;
2099        Ok(json)
2100    }
2101
2102    fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
2103        self.0.supported_provider(provider)
2104    }
2105
2106    fn run(
2107        self: Arc<Self>,
2108        input: serde_json::Value,
2109        event_stream: ToolCallEventStream,
2110        cx: &mut App,
2111    ) -> Task<Result<AgentToolOutput>> {
2112        cx.spawn(async move |cx| {
2113            let input = serde_json::from_value(input)?;
2114            let output = cx
2115                .update(|cx| self.0.clone().run(input, event_stream, cx))?
2116                .await?;
2117            let raw_output = serde_json::to_value(&output)?;
2118            Ok(AgentToolOutput {
2119                llm_output: output.into(),
2120                raw_output,
2121            })
2122        })
2123    }
2124
2125    fn replay(
2126        &self,
2127        input: serde_json::Value,
2128        output: serde_json::Value,
2129        event_stream: ToolCallEventStream,
2130        cx: &mut App,
2131    ) -> Result<()> {
2132        let input = serde_json::from_value(input)?;
2133        let output = serde_json::from_value(output)?;
2134        self.0.replay(input, output, event_stream, cx)
2135    }
2136}
2137
/// Sending half of the channel on which a thread reports [`ThreadEvent`]s (or errors).
/// Sends on it ignore failures (e.g. a dropped receiver).
#[derive(Clone)]
struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
2140
2141impl ThreadEventStream {
2142    fn send_user_message(&self, message: &UserMessage) {
2143        self.0
2144            .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
2145            .ok();
2146    }
2147
2148    fn send_text(&self, text: &str) {
2149        self.0
2150            .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
2151            .ok();
2152    }
2153
2154    fn send_thinking(&self, text: &str) {
2155        self.0
2156            .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
2157            .ok();
2158    }
2159
2160    fn send_tool_call(
2161        &self,
2162        id: &LanguageModelToolUseId,
2163        title: SharedString,
2164        kind: acp::ToolKind,
2165        input: serde_json::Value,
2166    ) {
2167        self.0
2168            .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
2169                id,
2170                title.to_string(),
2171                kind,
2172                input,
2173            ))))
2174            .ok();
2175    }
2176
2177    fn initial_tool_call(
2178        id: &LanguageModelToolUseId,
2179        title: String,
2180        kind: acp::ToolKind,
2181        input: serde_json::Value,
2182    ) -> acp::ToolCall {
2183        acp::ToolCall {
2184            id: acp::ToolCallId(id.to_string().into()),
2185            title,
2186            kind,
2187            status: acp::ToolCallStatus::Pending,
2188            content: vec![],
2189            locations: vec![],
2190            raw_input: Some(input),
2191            raw_output: None,
2192        }
2193    }
2194
2195    fn update_tool_call_fields(
2196        &self,
2197        tool_use_id: &LanguageModelToolUseId,
2198        fields: acp::ToolCallUpdateFields,
2199    ) {
2200        self.0
2201            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2202                acp::ToolCallUpdate {
2203                    id: acp::ToolCallId(tool_use_id.to_string().into()),
2204                    fields,
2205                }
2206                .into(),
2207            )))
2208            .ok();
2209    }
2210
2211    fn send_retry(&self, status: acp_thread::RetryStatus) {
2212        self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
2213    }
2214
2215    fn send_stop(&self, reason: acp::StopReason) {
2216        self.0.unbounded_send(Ok(ThreadEvent::Stop(reason))).ok();
2217    }
2218
2219    fn send_canceled(&self) {
2220        self.0
2221            .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
2222            .ok();
2223    }
2224
2225    fn send_error(&self, error: impl Into<anyhow::Error>) {
2226        self.0.unbounded_send(Err(error.into())).ok();
2227    }
2228}
2229
/// Per-tool-call view of a [`ThreadEventStream`], passed to [`AgentTool::run`] so a
/// tool can update its call's fields, attach diffs/terminals, and request authorization.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// Id of the tool use this stream reports on; converted into the ACP tool call id.
    tool_use_id: LanguageModelToolUseId,
    /// Thread-level stream that events are forwarded to.
    stream: ThreadEventStream,
    /// Used to persist an "Always Allow" choice to the settings file; `None` skips persisting.
    fs: Option<Arc<dyn Fs>>,
}
2236
2237impl ToolCallEventStream {
2238    #[cfg(test)]
2239    pub fn test() -> (Self, ToolCallEventStreamReceiver) {
2240        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
2241
2242        let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
2243
2244        (stream, ToolCallEventStreamReceiver(events_rx))
2245    }
2246
2247    fn new(
2248        tool_use_id: LanguageModelToolUseId,
2249        stream: ThreadEventStream,
2250        fs: Option<Arc<dyn Fs>>,
2251    ) -> Self {
2252        Self {
2253            tool_use_id,
2254            stream,
2255            fs,
2256        }
2257    }
2258
2259    pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
2260        self.stream
2261            .update_tool_call_fields(&self.tool_use_id, fields);
2262    }
2263
2264    pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
2265        self.stream
2266            .0
2267            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2268                acp_thread::ToolCallUpdateDiff {
2269                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2270                    diff,
2271                }
2272                .into(),
2273            )))
2274            .ok();
2275    }
2276
2277    pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
2278        self.stream
2279            .0
2280            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2281                acp_thread::ToolCallUpdateTerminal {
2282                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2283                    terminal,
2284                }
2285                .into(),
2286            )))
2287            .ok();
2288    }
2289
2290    pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
2291        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
2292            return Task::ready(Ok(()));
2293        }
2294
2295        let (response_tx, response_rx) = oneshot::channel();
2296        self.stream
2297            .0
2298            .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
2299                ToolCallAuthorization {
2300                    tool_call: acp::ToolCallUpdate {
2301                        id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2302                        fields: acp::ToolCallUpdateFields {
2303                            title: Some(title.into()),
2304                            ..Default::default()
2305                        },
2306                    },
2307                    options: vec![
2308                        acp::PermissionOption {
2309                            id: acp::PermissionOptionId("always_allow".into()),
2310                            name: "Always Allow".into(),
2311                            kind: acp::PermissionOptionKind::AllowAlways,
2312                        },
2313                        acp::PermissionOption {
2314                            id: acp::PermissionOptionId("allow".into()),
2315                            name: "Allow".into(),
2316                            kind: acp::PermissionOptionKind::AllowOnce,
2317                        },
2318                        acp::PermissionOption {
2319                            id: acp::PermissionOptionId("deny".into()),
2320                            name: "Deny".into(),
2321                            kind: acp::PermissionOptionKind::RejectOnce,
2322                        },
2323                    ],
2324                    response: response_tx,
2325                },
2326            )))
2327            .ok();
2328        let fs = self.fs.clone();
2329        cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
2330            "always_allow" => {
2331                if let Some(fs) = fs.clone() {
2332                    cx.update(|cx| {
2333                        update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
2334                            settings.set_always_allow_tool_actions(true);
2335                        });
2336                    })?;
2337                }
2338
2339                Ok(())
2340            }
2341            "allow" => Ok(()),
2342            _ => Err(anyhow!("Permission to run tool denied by user")),
2343        })
2344    }
2345}
2346
/// Test-only receiving half of a [`ToolCallEventStream`] created by `ToolCallEventStream::test`.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
2349
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event, which must be a tool call authorization request.
    ///
    /// # Panics
    ///
    /// Panics if the next event is anything else or the stream has ended.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) => auth,
            other => panic!("Expected ToolCallAuthorization but got: {:?}", other),
        }
    }

    /// Awaits the next event, which must attach a terminal to the tool call.
    ///
    /// # Panics
    ///
    /// Panics if the next event is anything else or the stream has ended.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallUpdate(
                acp_thread::ToolCallUpdate::UpdateTerminal(update),
            ))) => update.terminal,
            other => panic!("Expected terminal but got: {:?}", other),
        }
    }
}
2373
// Lets tests call the underlying receiver's methods directly on the wrapper.
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

// Mutable access to the underlying receiver (e.g. for polling the next event).
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
2389
2390impl From<&str> for UserMessageContent {
2391    fn from(text: &str) -> Self {
2392        Self::Text(text.into())
2393    }
2394}
2395
2396impl From<acp::ContentBlock> for UserMessageContent {
2397    fn from(value: acp::ContentBlock) -> Self {
2398        match value {
2399            acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
2400            acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
2401            acp::ContentBlock::Audio(_) => {
2402                // TODO
2403                Self::Text("[audio]".to_string())
2404            }
2405            acp::ContentBlock::ResourceLink(resource_link) => {
2406                match MentionUri::parse(&resource_link.uri) {
2407                    Ok(uri) => Self::Mention {
2408                        uri,
2409                        content: String::new(),
2410                    },
2411                    Err(err) => {
2412                        log::error!("Failed to parse mention link: {}", err);
2413                        Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
2414                    }
2415                }
2416            }
2417            acp::ContentBlock::Resource(resource) => match resource.resource {
2418                acp::EmbeddedResourceResource::TextResourceContents(resource) => {
2419                    match MentionUri::parse(&resource.uri) {
2420                        Ok(uri) => Self::Mention {
2421                            uri,
2422                            content: resource.text,
2423                        },
2424                        Err(err) => {
2425                            log::error!("Failed to parse mention link: {}", err);
2426                            Self::Text(
2427                                MarkdownCodeBlock {
2428                                    tag: &resource.uri,
2429                                    text: &resource.text,
2430                                }
2431                                .to_string(),
2432                            )
2433                        }
2434                    }
2435                }
2436                acp::EmbeddedResourceResource::BlobResourceContents(_) => {
2437                    // TODO
2438                    Self::Text("[blob]".to_string())
2439                }
2440            },
2441        }
2442    }
2443}
2444
2445impl From<UserMessageContent> for acp::ContentBlock {
2446    fn from(content: UserMessageContent) -> Self {
2447        match content {
2448            UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
2449                text,
2450                annotations: None,
2451            }),
2452            UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
2453                data: image.source.to_string(),
2454                mime_type: "image/png".to_string(),
2455                annotations: None,
2456                uri: None,
2457            }),
2458            UserMessageContent::Mention { uri, content } => {
2459                acp::ContentBlock::Resource(acp::EmbeddedResource {
2460                    resource: acp::EmbeddedResourceResource::TextResourceContents(
2461                        acp::TextResourceContents {
2462                            mime_type: None,
2463                            text: content,
2464                            uri: uri.to_uri().to_string(),
2465                        },
2466                    ),
2467                    annotations: None,
2468                })
2469            }
2470        }
2471    }
2472}
2473
2474fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
2475    LanguageModelImage {
2476        source: image_content.data.into(),
2477        // TODO: make this optional?
2478        size: gpui::Size::new(0.into(), 0.into()),
2479    }
2480}