thread.rs

   1use crate::{
   2    ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DbLanguageModel, DbThread,
   3    DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool,
   4    ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, SystemPromptTemplate,
   5    Template, Templates, TerminalTool, ThinkingTool, WebSearchTool,
   6};
   7use acp_thread::{MentionUri, UserMessageId};
   8use action_log::ActionLog;
   9use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot};
  10use agent_client_protocol as acp;
  11use agent_settings::{
  12    AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT,
  13    SUMMARIZE_THREAD_PROMPT,
  14};
  15use anyhow::{Context as _, Result, anyhow};
  16use assistant_tool::adapt_schema_to_format;
  17use chrono::{DateTime, Utc};
  18use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
  19use collections::{HashMap, IndexMap};
  20use fs::Fs;
  21use futures::{
  22    FutureExt,
  23    channel::{mpsc, oneshot},
  24    future::Shared,
  25    stream::FuturesUnordered,
  26};
  27use git::repository::DiffType;
  28use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
  29use language_model::{
  30    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
  31    LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
  32    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  33    LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
  34    LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
  35};
  36use project::{
  37    Project,
  38    git_store::{GitStore, RepositoryState},
  39};
  40use prompt_store::ProjectContext;
  41use schemars::{JsonSchema, Schema};
  42use serde::{Deserialize, Serialize};
  43use settings::{Settings, update_settings_file};
  44use smol::stream::StreamExt;
  45use std::{
  46    collections::BTreeMap,
  47    path::Path,
  48    sync::Arc,
  49    time::{Duration, Instant},
  50};
  51use std::{fmt::Write, ops::Range};
  52use util::{ResultExt, markdown::MarkdownCodeBlock};
  53use uuid::Uuid;
  54
  55const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
  56
  57/// The ID of the user prompt that initiated a request.
  58///
  59/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
  60#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  61pub struct PromptId(Arc<str>);
  62
  63impl PromptId {
  64    pub fn new() -> Self {
  65        Self(Uuid::new_v4().to_string().into())
  66    }
  67}
  68
  69impl std::fmt::Display for PromptId {
  70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  71        write!(f, "{}", self.0)
  72    }
  73}
  74
  75pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
  76pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
  77
  78#[derive(Debug, Clone)]
  79enum RetryStrategy {
  80    ExponentialBackoff {
  81        initial_delay: Duration,
  82        max_attempts: u8,
  83    },
  84    Fixed {
  85        delay: Duration,
  86        max_attempts: u8,
  87    },
  88}
  89
  90#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  91pub enum Message {
  92    User(UserMessage),
  93    Agent(AgentMessage),
  94    Resume,
  95}
  96
  97impl Message {
  98    pub fn as_agent_message(&self) -> Option<&AgentMessage> {
  99        match self {
 100            Message::Agent(agent_message) => Some(agent_message),
 101            _ => None,
 102        }
 103    }
 104
 105    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
 106        match self {
 107            Message::User(message) => vec![message.to_request()],
 108            Message::Agent(message) => message.to_request(),
 109            Message::Resume => vec![LanguageModelRequestMessage {
 110                role: Role::User,
 111                content: vec!["Continue where you left off".into()],
 112                cache: false,
 113            }],
 114        }
 115    }
 116
 117    pub fn to_markdown(&self) -> String {
 118        match self {
 119            Message::User(message) => message.to_markdown(),
 120            Message::Agent(message) => message.to_markdown(),
 121            Message::Resume => "[resumed after tool use limit was reached]".into(),
 122        }
 123    }
 124
 125    pub fn role(&self) -> Role {
 126        match self {
 127            Message::User(_) | Message::Resume => Role::User,
 128            Message::Agent(_) => Role::Assistant,
 129        }
 130    }
 131}
 132
 133#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 134pub struct UserMessage {
 135    pub id: UserMessageId,
 136    pub content: Vec<UserMessageContent>,
 137}
 138
 139#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 140pub enum UserMessageContent {
 141    Text(String),
 142    Mention { uri: MentionUri, content: String },
 143    Image(LanguageModelImage),
 144}
 145
 146impl UserMessage {
 147    pub fn to_markdown(&self) -> String {
 148        let mut markdown = String::from("## User\n\n");
 149
 150        for content in &self.content {
 151            match content {
 152                UserMessageContent::Text(text) => {
 153                    markdown.push_str(text);
 154                    markdown.push('\n');
 155                }
 156                UserMessageContent::Image(_) => {
 157                    markdown.push_str("<image />\n");
 158                }
 159                UserMessageContent::Mention { uri, content } => {
 160                    if !content.is_empty() {
 161                        let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
 162                    } else {
 163                        let _ = write!(&mut markdown, "{}\n", uri.as_link());
 164                    }
 165                }
 166            }
 167        }
 168
 169        markdown
 170    }
 171
 172    fn to_request(&self) -> LanguageModelRequestMessage {
 173        let mut message = LanguageModelRequestMessage {
 174            role: Role::User,
 175            content: Vec::with_capacity(self.content.len()),
 176            cache: false,
 177        };
 178
 179        const OPEN_CONTEXT: &str = "<context>\n\
 180            The following items were attached by the user. \
 181            They are up-to-date and don't need to be re-read.\n\n";
 182
 183        const OPEN_FILES_TAG: &str = "<files>";
 184        const OPEN_DIRECTORIES_TAG: &str = "<directories>";
 185        const OPEN_SYMBOLS_TAG: &str = "<symbols>";
 186        const OPEN_THREADS_TAG: &str = "<threads>";
 187        const OPEN_FETCH_TAG: &str = "<fetched_urls>";
 188        const OPEN_RULES_TAG: &str =
 189            "<rules>\nThe user has specified the following rules that should be applied:\n";
 190
 191        let mut file_context = OPEN_FILES_TAG.to_string();
 192        let mut directory_context = OPEN_DIRECTORIES_TAG.to_string();
 193        let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
 194        let mut thread_context = OPEN_THREADS_TAG.to_string();
 195        let mut fetch_context = OPEN_FETCH_TAG.to_string();
 196        let mut rules_context = OPEN_RULES_TAG.to_string();
 197
 198        for chunk in &self.content {
 199            let chunk = match chunk {
 200                UserMessageContent::Text(text) => {
 201                    language_model::MessageContent::Text(text.clone())
 202                }
 203                UserMessageContent::Image(value) => {
 204                    language_model::MessageContent::Image(value.clone())
 205                }
 206                UserMessageContent::Mention { uri, content } => {
 207                    match uri {
 208                        MentionUri::File { abs_path } => {
 209                            write!(
 210                                &mut symbol_context,
 211                                "\n{}",
 212                                MarkdownCodeBlock {
 213                                    tag: &codeblock_tag(abs_path, None),
 214                                    text: &content.to_string(),
 215                                }
 216                            )
 217                            .ok();
 218                        }
 219                        MentionUri::Directory { .. } => {
 220                            write!(&mut directory_context, "\n{}\n", content).ok();
 221                        }
 222                        MentionUri::Symbol {
 223                            path, line_range, ..
 224                        }
 225                        | MentionUri::Selection {
 226                            path, line_range, ..
 227                        } => {
 228                            write!(
 229                                &mut rules_context,
 230                                "\n{}",
 231                                MarkdownCodeBlock {
 232                                    tag: &codeblock_tag(path, Some(line_range)),
 233                                    text: content
 234                                }
 235                            )
 236                            .ok();
 237                        }
 238                        MentionUri::Thread { .. } => {
 239                            write!(&mut thread_context, "\n{}\n", content).ok();
 240                        }
 241                        MentionUri::TextThread { .. } => {
 242                            write!(&mut thread_context, "\n{}\n", content).ok();
 243                        }
 244                        MentionUri::Rule { .. } => {
 245                            write!(
 246                                &mut rules_context,
 247                                "\n{}",
 248                                MarkdownCodeBlock {
 249                                    tag: "",
 250                                    text: content
 251                                }
 252                            )
 253                            .ok();
 254                        }
 255                        MentionUri::Fetch { url } => {
 256                            write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
 257                        }
 258                    }
 259
 260                    language_model::MessageContent::Text(uri.as_link().to_string())
 261                }
 262            };
 263
 264            message.content.push(chunk);
 265        }
 266
 267        let len_before_context = message.content.len();
 268
 269        if file_context.len() > OPEN_FILES_TAG.len() {
 270            file_context.push_str("</files>\n");
 271            message
 272                .content
 273                .push(language_model::MessageContent::Text(file_context));
 274        }
 275
 276        if directory_context.len() > OPEN_DIRECTORIES_TAG.len() {
 277            directory_context.push_str("</directories>\n");
 278            message
 279                .content
 280                .push(language_model::MessageContent::Text(directory_context));
 281        }
 282
 283        if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
 284            symbol_context.push_str("</symbols>\n");
 285            message
 286                .content
 287                .push(language_model::MessageContent::Text(symbol_context));
 288        }
 289
 290        if thread_context.len() > OPEN_THREADS_TAG.len() {
 291            thread_context.push_str("</threads>\n");
 292            message
 293                .content
 294                .push(language_model::MessageContent::Text(thread_context));
 295        }
 296
 297        if fetch_context.len() > OPEN_FETCH_TAG.len() {
 298            fetch_context.push_str("</fetched_urls>\n");
 299            message
 300                .content
 301                .push(language_model::MessageContent::Text(fetch_context));
 302        }
 303
 304        if rules_context.len() > OPEN_RULES_TAG.len() {
 305            rules_context.push_str("</user_rules>\n");
 306            message
 307                .content
 308                .push(language_model::MessageContent::Text(rules_context));
 309        }
 310
 311        if message.content.len() > len_before_context {
 312            message.content.insert(
 313                len_before_context,
 314                language_model::MessageContent::Text(OPEN_CONTEXT.into()),
 315            );
 316            message
 317                .content
 318                .push(language_model::MessageContent::Text("</context>".into()));
 319        }
 320
 321        message
 322    }
 323}
 324
 325fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
 326    let mut result = String::new();
 327
 328    if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) {
 329        let _ = write!(result, "{} ", extension);
 330    }
 331
 332    let _ = write!(result, "{}", full_path.display());
 333
 334    if let Some(range) = line_range {
 335        if range.start == range.end {
 336            let _ = write!(result, ":{}", range.start + 1);
 337        } else {
 338            let _ = write!(result, ":{}-{}", range.start + 1, range.end + 1);
 339        }
 340    }
 341
 342    result
 343}
 344
 345impl AgentMessage {
 346    pub fn to_markdown(&self) -> String {
 347        let mut markdown = String::from("## Assistant\n\n");
 348
 349        for content in &self.content {
 350            match content {
 351                AgentMessageContent::Text(text) => {
 352                    markdown.push_str(text);
 353                    markdown.push('\n');
 354                }
 355                AgentMessageContent::Thinking { text, .. } => {
 356                    markdown.push_str("<think>");
 357                    markdown.push_str(text);
 358                    markdown.push_str("</think>\n");
 359                }
 360                AgentMessageContent::RedactedThinking(_) => {
 361                    markdown.push_str("<redacted_thinking />\n")
 362                }
 363                AgentMessageContent::ToolUse(tool_use) => {
 364                    markdown.push_str(&format!(
 365                        "**Tool Use**: {} (ID: {})\n",
 366                        tool_use.name, tool_use.id
 367                    ));
 368                    markdown.push_str(&format!(
 369                        "{}\n",
 370                        MarkdownCodeBlock {
 371                            tag: "json",
 372                            text: &format!("{:#}", tool_use.input)
 373                        }
 374                    ));
 375                }
 376            }
 377        }
 378
 379        for tool_result in self.tool_results.values() {
 380            markdown.push_str(&format!(
 381                "**Tool Result**: {} (ID: {})\n\n",
 382                tool_result.tool_name, tool_result.tool_use_id
 383            ));
 384            if tool_result.is_error {
 385                markdown.push_str("**ERROR:**\n");
 386            }
 387
 388            match &tool_result.content {
 389                LanguageModelToolResultContent::Text(text) => {
 390                    writeln!(markdown, "{text}\n").ok();
 391                }
 392                LanguageModelToolResultContent::Image(_) => {
 393                    writeln!(markdown, "<image />\n").ok();
 394                }
 395            }
 396
 397            if let Some(output) = tool_result.output.as_ref() {
 398                writeln!(
 399                    markdown,
 400                    "**Debug Output**:\n\n```json\n{}\n```\n",
 401                    serde_json::to_string_pretty(output).unwrap()
 402                )
 403                .unwrap();
 404            }
 405        }
 406
 407        markdown
 408    }
 409
 410    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
 411        let mut assistant_message = LanguageModelRequestMessage {
 412            role: Role::Assistant,
 413            content: Vec::with_capacity(self.content.len()),
 414            cache: false,
 415        };
 416        for chunk in &self.content {
 417            let chunk = match chunk {
 418                AgentMessageContent::Text(text) => {
 419                    language_model::MessageContent::Text(text.clone())
 420                }
 421                AgentMessageContent::Thinking { text, signature } => {
 422                    language_model::MessageContent::Thinking {
 423                        text: text.clone(),
 424                        signature: signature.clone(),
 425                    }
 426                }
 427                AgentMessageContent::RedactedThinking(value) => {
 428                    language_model::MessageContent::RedactedThinking(value.clone())
 429                }
 430                AgentMessageContent::ToolUse(value) => {
 431                    language_model::MessageContent::ToolUse(value.clone())
 432                }
 433            };
 434            assistant_message.content.push(chunk);
 435        }
 436
 437        let mut user_message = LanguageModelRequestMessage {
 438            role: Role::User,
 439            content: Vec::new(),
 440            cache: false,
 441        };
 442
 443        for tool_result in self.tool_results.values() {
 444            user_message
 445                .content
 446                .push(language_model::MessageContent::ToolResult(
 447                    tool_result.clone(),
 448                ));
 449        }
 450
 451        let mut messages = Vec::new();
 452        if !assistant_message.content.is_empty() {
 453            messages.push(assistant_message);
 454        }
 455        if !user_message.content.is_empty() {
 456            messages.push(user_message);
 457        }
 458        messages
 459    }
 460}
 461
 462#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 463pub struct AgentMessage {
 464    pub content: Vec<AgentMessageContent>,
 465    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
 466}
 467
 468#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 469pub enum AgentMessageContent {
 470    Text(String),
 471    Thinking {
 472        text: String,
 473        signature: Option<String>,
 474    },
 475    RedactedThinking(String),
 476    ToolUse(LanguageModelToolUse),
 477}
 478
 479#[derive(Debug)]
 480pub enum ThreadEvent {
 481    UserMessage(UserMessage),
 482    AgentText(String),
 483    AgentThinking(String),
 484    ToolCall(acp::ToolCall),
 485    ToolCallUpdate(acp_thread::ToolCallUpdate),
 486    ToolCallAuthorization(ToolCallAuthorization),
 487    TokenUsageUpdate(acp_thread::TokenUsage),
 488    TitleUpdate(SharedString),
 489    Retry(acp_thread::RetryStatus),
 490    Stop(acp::StopReason),
 491}
 492
 493#[derive(Debug)]
 494pub struct ToolCallAuthorization {
 495    pub tool_call: acp::ToolCallUpdate,
 496    pub options: Vec<acp::PermissionOption>,
 497    pub response: oneshot::Sender<acp::PermissionOptionId>,
 498}
 499
 500pub struct Thread {
 501    id: acp::SessionId,
 502    prompt_id: PromptId,
 503    updated_at: DateTime<Utc>,
 504    title: Option<SharedString>,
 505    summary: Option<SharedString>,
 506    messages: Vec<Message>,
 507    completion_mode: CompletionMode,
 508    /// Holds the task that handles agent interaction until the end of the turn.
 509    /// Survives across multiple requests as the model performs tool calls and
 510    /// we run tools, report their results.
 511    running_turn: Option<RunningTurn>,
 512    pending_message: Option<AgentMessage>,
 513    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
 514    tool_use_limit_reached: bool,
 515    request_token_usage: HashMap<UserMessageId, language_model::TokenUsage>,
 516    #[allow(unused)]
 517    cumulative_token_usage: TokenUsage,
 518    #[allow(unused)]
 519    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
 520    context_server_registry: Entity<ContextServerRegistry>,
 521    profile_id: AgentProfileId,
 522    project_context: Entity<ProjectContext>,
 523    templates: Arc<Templates>,
 524    model: Option<Arc<dyn LanguageModel>>,
 525    summarization_model: Option<Arc<dyn LanguageModel>>,
 526    pub(crate) project: Entity<Project>,
 527    pub(crate) action_log: Entity<ActionLog>,
 528}
 529
 530impl Thread {
 531    pub fn new(
 532        project: Entity<Project>,
 533        project_context: Entity<ProjectContext>,
 534        context_server_registry: Entity<ContextServerRegistry>,
 535        action_log: Entity<ActionLog>,
 536        templates: Arc<Templates>,
 537        model: Option<Arc<dyn LanguageModel>>,
 538        cx: &mut Context<Self>,
 539    ) -> Self {
 540        let profile_id = AgentSettings::get_global(cx).default_profile.clone();
 541        Self {
 542            id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()),
 543            prompt_id: PromptId::new(),
 544            updated_at: Utc::now(),
 545            title: None,
 546            summary: None,
 547            messages: Vec::new(),
 548            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
 549            running_turn: None,
 550            pending_message: None,
 551            tools: BTreeMap::default(),
 552            tool_use_limit_reached: false,
 553            request_token_usage: HashMap::default(),
 554            cumulative_token_usage: TokenUsage::default(),
 555            initial_project_snapshot: {
 556                let project_snapshot = Self::project_snapshot(project.clone(), cx);
 557                cx.foreground_executor()
 558                    .spawn(async move { Some(project_snapshot.await) })
 559                    .shared()
 560            },
 561            context_server_registry,
 562            profile_id,
 563            project_context,
 564            templates,
 565            model,
 566            summarization_model: None,
 567            project,
 568            action_log,
 569        }
 570    }
 571
 572    pub fn id(&self) -> &acp::SessionId {
 573        &self.id
 574    }
 575
 576    pub fn replay(
 577        &mut self,
 578        cx: &mut Context<Self>,
 579    ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
 580        let (tx, rx) = mpsc::unbounded();
 581        let stream = ThreadEventStream(tx);
 582        for message in &self.messages {
 583            match message {
 584                Message::User(user_message) => stream.send_user_message(user_message),
 585                Message::Agent(assistant_message) => {
 586                    for content in &assistant_message.content {
 587                        match content {
 588                            AgentMessageContent::Text(text) => stream.send_text(text),
 589                            AgentMessageContent::Thinking { text, .. } => {
 590                                stream.send_thinking(text)
 591                            }
 592                            AgentMessageContent::RedactedThinking(_) => {}
 593                            AgentMessageContent::ToolUse(tool_use) => {
 594                                self.replay_tool_call(
 595                                    tool_use,
 596                                    assistant_message.tool_results.get(&tool_use.id),
 597                                    &stream,
 598                                    cx,
 599                                );
 600                            }
 601                        }
 602                    }
 603                }
 604                Message::Resume => {}
 605            }
 606        }
 607        rx
 608    }
 609
 610    fn replay_tool_call(
 611        &self,
 612        tool_use: &LanguageModelToolUse,
 613        tool_result: Option<&LanguageModelToolResult>,
 614        stream: &ThreadEventStream,
 615        cx: &mut Context<Self>,
 616    ) {
 617        let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
 618            stream
 619                .0
 620                .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
 621                    id: acp::ToolCallId(tool_use.id.to_string().into()),
 622                    title: tool_use.name.to_string(),
 623                    kind: acp::ToolKind::Other,
 624                    status: acp::ToolCallStatus::Failed,
 625                    content: Vec::new(),
 626                    locations: Vec::new(),
 627                    raw_input: Some(tool_use.input.clone()),
 628                    raw_output: None,
 629                })))
 630                .ok();
 631            return;
 632        };
 633
 634        let title = tool.initial_title(tool_use.input.clone());
 635        let kind = tool.kind();
 636        stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
 637
 638        let output = tool_result
 639            .as_ref()
 640            .and_then(|result| result.output.clone());
 641        if let Some(output) = output.clone() {
 642            let tool_event_stream = ToolCallEventStream::new(
 643                tool_use.id.clone(),
 644                stream.clone(),
 645                Some(self.project.read(cx).fs().clone()),
 646            );
 647            tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
 648                .log_err();
 649        }
 650
 651        stream.update_tool_call_fields(
 652            &tool_use.id,
 653            acp::ToolCallUpdateFields {
 654                status: Some(acp::ToolCallStatus::Completed),
 655                raw_output: output,
 656                ..Default::default()
 657            },
 658        );
 659    }
 660
 661    pub fn from_db(
 662        id: acp::SessionId,
 663        db_thread: DbThread,
 664        project: Entity<Project>,
 665        project_context: Entity<ProjectContext>,
 666        context_server_registry: Entity<ContextServerRegistry>,
 667        action_log: Entity<ActionLog>,
 668        templates: Arc<Templates>,
 669        cx: &mut Context<Self>,
 670    ) -> Self {
 671        let profile_id = db_thread
 672            .profile
 673            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
 674        let model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
 675            db_thread
 676                .model
 677                .and_then(|model| {
 678                    let model = SelectedModel {
 679                        provider: model.provider.clone().into(),
 680                        model: model.model.clone().into(),
 681                    };
 682                    registry.select_model(&model, cx)
 683                })
 684                .or_else(|| registry.default_model())
 685                .map(|model| model.model)
 686        });
 687
 688        Self {
 689            id,
 690            prompt_id: PromptId::new(),
 691            title: if db_thread.title.is_empty() {
 692                None
 693            } else {
 694                Some(db_thread.title.clone())
 695            },
 696            summary: db_thread.detailed_summary,
 697            messages: db_thread.messages,
 698            completion_mode: db_thread.completion_mode.unwrap_or_default(),
 699            running_turn: None,
 700            pending_message: None,
 701            tools: BTreeMap::default(),
 702            tool_use_limit_reached: false,
 703            request_token_usage: db_thread.request_token_usage.clone(),
 704            cumulative_token_usage: db_thread.cumulative_token_usage,
 705            initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
 706            context_server_registry,
 707            profile_id,
 708            project_context,
 709            templates,
 710            model,
 711            summarization_model: None,
 712            project,
 713            action_log,
 714            updated_at: db_thread.updated_at,
 715        }
 716    }
 717
 718    pub fn to_db(&self, cx: &App) -> Task<DbThread> {
 719        let initial_project_snapshot = self.initial_project_snapshot.clone();
 720        let mut thread = DbThread {
 721            title: self.title.clone().unwrap_or_default(),
 722            messages: self.messages.clone(),
 723            updated_at: self.updated_at,
 724            detailed_summary: self.summary.clone(),
 725            initial_project_snapshot: None,
 726            cumulative_token_usage: self.cumulative_token_usage,
 727            request_token_usage: self.request_token_usage.clone(),
 728            model: self.model.as_ref().map(|model| DbLanguageModel {
 729                provider: model.provider_id().to_string(),
 730                model: model.name().0.to_string(),
 731            }),
 732            completion_mode: Some(self.completion_mode),
 733            profile: Some(self.profile_id.clone()),
 734        };
 735
 736        cx.background_spawn(async move {
 737            let initial_project_snapshot = initial_project_snapshot.await;
 738            thread.initial_project_snapshot = initial_project_snapshot;
 739            thread
 740        })
 741    }
 742
 743    /// Create a snapshot of the current project state including git information and unsaved buffers.
 744    fn project_snapshot(
 745        project: Entity<Project>,
 746        cx: &mut Context<Self>,
 747    ) -> Task<Arc<agent::thread::ProjectSnapshot>> {
 748        let git_store = project.read(cx).git_store().clone();
 749        let worktree_snapshots: Vec<_> = project
 750            .read(cx)
 751            .visible_worktrees(cx)
 752            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
 753            .collect();
 754
 755        cx.spawn(async move |_, cx| {
 756            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
 757
 758            let mut unsaved_buffers = Vec::new();
 759            cx.update(|app_cx| {
 760                let buffer_store = project.read(app_cx).buffer_store();
 761                for buffer_handle in buffer_store.read(app_cx).buffers() {
 762                    let buffer = buffer_handle.read(app_cx);
 763                    if buffer.is_dirty()
 764                        && let Some(file) = buffer.file()
 765                    {
 766                        let path = file.path().to_string_lossy().to_string();
 767                        unsaved_buffers.push(path);
 768                    }
 769                }
 770            })
 771            .ok();
 772
 773            Arc::new(ProjectSnapshot {
 774                worktree_snapshots,
 775                unsaved_buffer_paths: unsaved_buffers,
 776                timestamp: Utc::now(),
 777            })
 778        })
 779    }
 780
 781    fn worktree_snapshot(
 782        worktree: Entity<project::Worktree>,
 783        git_store: Entity<GitStore>,
 784        cx: &App,
 785    ) -> Task<agent::thread::WorktreeSnapshot> {
 786        cx.spawn(async move |cx| {
 787            // Get worktree path and snapshot
 788            let worktree_info = cx.update(|app_cx| {
 789                let worktree = worktree.read(app_cx);
 790                let path = worktree.abs_path().to_string_lossy().to_string();
 791                let snapshot = worktree.snapshot();
 792                (path, snapshot)
 793            });
 794
 795            let Ok((worktree_path, _snapshot)) = worktree_info else {
 796                return WorktreeSnapshot {
 797                    worktree_path: String::new(),
 798                    git_state: None,
 799                };
 800            };
 801
 802            let git_state = git_store
 803                .update(cx, |git_store, cx| {
 804                    git_store
 805                        .repositories()
 806                        .values()
 807                        .find(|repo| {
 808                            repo.read(cx)
 809                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
 810                                .is_some()
 811                        })
 812                        .cloned()
 813                })
 814                .ok()
 815                .flatten()
 816                .map(|repo| {
 817                    repo.update(cx, |repo, _| {
 818                        let current_branch =
 819                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
 820                        repo.send_job(None, |state, _| async move {
 821                            let RepositoryState::Local { backend, .. } = state else {
 822                                return GitState {
 823                                    remote_url: None,
 824                                    head_sha: None,
 825                                    current_branch,
 826                                    diff: None,
 827                                };
 828                            };
 829
 830                            let remote_url = backend.remote_url("origin");
 831                            let head_sha = backend.head_sha().await;
 832                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
 833
 834                            GitState {
 835                                remote_url,
 836                                head_sha,
 837                                current_branch,
 838                                diff,
 839                            }
 840                        })
 841                    })
 842                });
 843
 844            let git_state = match git_state {
 845                Some(git_state) => match git_state.ok() {
 846                    Some(git_state) => git_state.await.ok(),
 847                    None => None,
 848                },
 849                None => None,
 850            };
 851
 852            WorktreeSnapshot {
 853                worktree_path,
 854                git_state,
 855            }
 856        })
 857    }
 858
 859    pub fn project_context(&self) -> &Entity<ProjectContext> {
 860        &self.project_context
 861    }
 862
 863    pub fn project(&self) -> &Entity<Project> {
 864        &self.project
 865    }
 866
 867    pub fn action_log(&self) -> &Entity<ActionLog> {
 868        &self.action_log
 869    }
 870
 871    pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
 872        self.model.as_ref()
 873    }
 874
 875    pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
 876        self.model = Some(model);
 877        cx.notify()
 878    }
 879
 880    pub fn set_summarization_model(
 881        &mut self,
 882        model: Option<Arc<dyn LanguageModel>>,
 883        cx: &mut Context<Self>,
 884    ) {
 885        self.summarization_model = model;
 886        cx.notify()
 887    }
 888
 889    pub fn completion_mode(&self) -> CompletionMode {
 890        self.completion_mode
 891    }
 892
 893    pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
 894        self.completion_mode = mode;
 895        cx.notify()
 896    }
 897
 898    #[cfg(any(test, feature = "test-support"))]
 899    pub fn last_message(&self) -> Option<Message> {
 900        if let Some(message) = self.pending_message.clone() {
 901            Some(Message::Agent(message))
 902        } else {
 903            self.messages.last().cloned()
 904        }
 905    }
 906
 907    pub fn add_default_tools(&mut self, cx: &mut Context<Self>) {
 908        let language_registry = self.project.read(cx).languages().clone();
 909        self.add_tool(CopyPathTool::new(self.project.clone()));
 910        self.add_tool(CreateDirectoryTool::new(self.project.clone()));
 911        self.add_tool(DeletePathTool::new(
 912            self.project.clone(),
 913            self.action_log.clone(),
 914        ));
 915        self.add_tool(DiagnosticsTool::new(self.project.clone()));
 916        self.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
 917        self.add_tool(FetchTool::new(self.project.read(cx).client().http_client()));
 918        self.add_tool(FindPathTool::new(self.project.clone()));
 919        self.add_tool(GrepTool::new(self.project.clone()));
 920        self.add_tool(ListDirectoryTool::new(self.project.clone()));
 921        self.add_tool(MovePathTool::new(self.project.clone()));
 922        self.add_tool(NowTool);
 923        self.add_tool(OpenTool::new(self.project.clone()));
 924        self.add_tool(ReadFileTool::new(
 925            self.project.clone(),
 926            self.action_log.clone(),
 927        ));
 928        self.add_tool(TerminalTool::new(self.project.clone(), cx));
 929        self.add_tool(ThinkingTool);
 930        self.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
 931    }
 932
 933    pub fn add_tool(&mut self, tool: impl AgentTool) {
 934        self.tools.insert(tool.name(), tool.erase());
 935    }
 936
 937    pub fn remove_tool(&mut self, name: &str) -> bool {
 938        self.tools.remove(name).is_some()
 939    }
 940
 941    pub fn profile(&self) -> &AgentProfileId {
 942        &self.profile_id
 943    }
 944
 945    pub fn set_profile(&mut self, profile_id: AgentProfileId) {
 946        self.profile_id = profile_id;
 947    }
 948
 949    pub fn cancel(&mut self, cx: &mut Context<Self>) {
 950        if let Some(running_turn) = self.running_turn.take() {
 951            running_turn.cancel();
 952        }
 953        self.flush_pending_message(cx);
 954    }
 955
 956    pub fn update_token_usage(&mut self, update: language_model::TokenUsage) {
 957        let Some(last_user_message) = self.last_user_message() else {
 958            return;
 959        };
 960
 961        self.request_token_usage
 962            .insert(last_user_message.id.clone(), update);
 963    }
 964
 965    pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
 966        self.cancel(cx);
 967        let Some(position) = self.messages.iter().position(
 968            |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
 969        ) else {
 970            return Err(anyhow!("Message not found"));
 971        };
 972
 973        for message in self.messages.drain(position..) {
 974            match message {
 975                Message::User(message) => {
 976                    self.request_token_usage.remove(&message.id);
 977                }
 978                Message::Agent(_) | Message::Resume => {}
 979            }
 980        }
 981        self.summary = None;
 982        cx.notify();
 983        Ok(())
 984    }
 985
 986    pub fn latest_token_usage(&self) -> Option<acp_thread::TokenUsage> {
 987        let last_user_message = self.last_user_message()?;
 988        let tokens = self.request_token_usage.get(&last_user_message.id)?;
 989        let model = self.model.clone()?;
 990
 991        Some(acp_thread::TokenUsage {
 992            max_tokens: model.max_token_count_for_mode(self.completion_mode.into()),
 993            used_tokens: tokens.total_tokens(),
 994        })
 995    }
 996
 997    pub fn resume(
 998        &mut self,
 999        cx: &mut Context<Self>,
1000    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
1001        anyhow::ensure!(
1002            self.tool_use_limit_reached,
1003            "can only resume after tool use limit is reached"
1004        );
1005
1006        self.messages.push(Message::Resume);
1007        cx.notify();
1008
1009        log::info!("Total messages in thread: {}", self.messages.len());
1010        self.run_turn(cx)
1011    }
1012
1013    /// Sending a message results in the model streaming a response, which could include tool calls.
1014    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
1015    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
1016    pub fn send<T>(
1017        &mut self,
1018        id: UserMessageId,
1019        content: impl IntoIterator<Item = T>,
1020        cx: &mut Context<Self>,
1021    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>
1022    where
1023        T: Into<UserMessageContent>,
1024    {
1025        let model = self.model().context("No language model configured")?;
1026
1027        log::info!("Thread::send called with model: {:?}", model.name());
1028        self.advance_prompt_id();
1029
1030        let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
1031        log::debug!("Thread::send content: {:?}", content);
1032
1033        self.messages
1034            .push(Message::User(UserMessage { id, content }));
1035        cx.notify();
1036
1037        log::info!("Total messages in thread: {}", self.messages.len());
1038        self.run_turn(cx)
1039    }
1040
    /// Starts a new agent turn over the current `messages` and returns the
    /// receiving end of the turn's event channel.
    ///
    /// Any turn already in progress is cancelled first (which also flushes its
    /// pending message). The spawned task repeats these steps:
    /// 1. build a completion request;
    /// 2. stream it with retries;
    /// 3. wait for every requested tool call to finish, recording each result in
    ///    the pending message;
    /// 4. if tools were called, flush the pending message and request again with
    ///    `CompletionIntent::ToolResults`.
    ///
    /// The turn ends on refusal, max tokens, the tool use limit, or a response
    /// without tool calls. On success the title may be generated and the stop
    /// reason is sent with `send_stop`; failures are sent with `send_error`.
    ///
    /// # Errors
    /// Returns an error synchronously if no language model is configured.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
        self.cancel(cx);

        let model = self.model.clone().context("No language model configured")?;
        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
        let event_stream = ThreadEventStream(events_tx);
        // Index of the last message at turn start (the user message or `Resume`
        // marker the caller just pushed); on refusal the thread is truncated here.
        let message_ix = self.messages.len().saturating_sub(1);
        self.tool_use_limit_reached = false;
        // The thread is about to change, so any cached summary is stale.
        self.summary = None;
        self.running_turn = Some(RunningTurn {
            event_stream: event_stream.clone(),
            _task: cx.spawn(async move |this, cx| {
                log::info!("Starting agent turn execution");
                let turn_result: Result<StopReason> = async {
                    let mut completion_intent = CompletionIntent::UserPrompt;
                    loop {
                        log::debug!(
                            "Building completion request with intent: {:?}",
                            completion_intent
                        );
                        let request = this.update(cx, |this, cx| {
                            this.build_completion_request(completion_intent, cx)
                        })??;

                        log::info!("Calling model.stream_completion");

                        // Out-parameters filled in while the response streams.
                        let mut tool_use_limit_reached = false;
                        let mut refused = false;
                        let mut reached_max_tokens = false;
                        let mut tool_uses = Self::stream_completion_with_retries(
                            this.clone(),
                            model.clone(),
                            request,
                            &event_stream,
                            &mut tool_use_limit_reached,
                            &mut refused,
                            &mut reached_max_tokens,
                            cx,
                        )
                        .await?;

                        if refused {
                            return Ok(StopReason::Refusal);
                        } else if reached_max_tokens {
                            return Ok(StopReason::MaxTokens);
                        }

                        // Decide before draining: an empty set means the model did
                        // not call any tools, so the turn is over.
                        let end_turn = tool_uses.is_empty();
                        while let Some(tool_result) = tool_uses.next().await {
                            log::info!("Tool finished {:?}", tool_result);

                            // Mark the UI tool call as completed or failed.
                            event_stream.update_tool_call_fields(
                                &tool_result.tool_use_id,
                                acp::ToolCallUpdateFields {
                                    status: Some(if tool_result.is_error {
                                        acp::ToolCallStatus::Failed
                                    } else {
                                        acp::ToolCallStatus::Completed
                                    }),
                                    raw_output: tool_result.output.clone(),
                                    ..Default::default()
                                },
                            );
                            // Store the result so the next request can send it back.
                            this.update(cx, |this, _cx| {
                                this.pending_message()
                                    .tool_results
                                    .insert(tool_result.tool_use_id.clone(), tool_result);
                            })
                            .ok();
                        }

                        if tool_use_limit_reached {
                            log::info!("Tool use limit reached, completing turn");
                            // Setting this flag is what allows `resume` afterwards.
                            this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
                            return Err(language_model::ToolUseLimitReachedError.into());
                        } else if end_turn {
                            log::info!("No tool uses found, completing turn");
                            return Ok(StopReason::EndTurn);
                        } else {
                            // Flush the message carrying the tool results, then ask
                            // the model to continue from them.
                            this.update(cx, |this, cx| this.flush_pending_message(cx))?;
                            completion_intent = CompletionIntent::ToolResults;
                        }
                    }
                }
                .await;
                // Flush whatever was produced, whether the turn succeeded or failed.
                _ = this.update(cx, |this, cx| this.flush_pending_message(cx));

                match turn_result {
                    Ok(reason) => {
                        log::info!("Turn execution completed: {:?}", reason);

                        // Title generation is best-effort: failures are only logged.
                        let update_title = this
                            .update(cx, |this, cx| this.update_title(&event_stream, cx))
                            .ok()
                            .flatten();
                        if let Some(update_title) = update_title {
                            update_title.await.context("update title failed").log_err();
                        }

                        event_stream.send_stop(reason);
                        if reason == StopReason::Refusal {
                            // Drop the message that started this turn and everything after it.
                            _ = this.update(cx, |this, _| this.messages.truncate(message_ix));
                        }
                    }
                    Err(error) => {
                        log::error!("Turn execution failed: {:?}", error);
                        event_stream.send_error(error);
                    }
                }

                // The turn is finished; clear the running-turn handle.
                _ = this.update(cx, |this, _| this.running_turn.take());
            }),
        });
        Ok(events_rx)
    }
1159
1160    async fn stream_completion_with_retries(
1161        this: WeakEntity<Self>,
1162        model: Arc<dyn LanguageModel>,
1163        request: LanguageModelRequest,
1164        event_stream: &ThreadEventStream,
1165        tool_use_limit_reached: &mut bool,
1166        refusal: &mut bool,
1167        max_tokens_reached: &mut bool,
1168        cx: &mut AsyncApp,
1169    ) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
1170        log::debug!("Stream completion started successfully");
1171
1172        let mut attempt = None;
1173        'retry: loop {
1174            let mut events = model.stream_completion(request.clone(), cx).await?;
1175            let mut tool_uses = FuturesUnordered::new();
1176            while let Some(event) = events.next().await {
1177                match event {
1178                    Ok(LanguageModelCompletionEvent::StatusUpdate(
1179                        CompletionRequestStatus::ToolUseLimitReached,
1180                    )) => {
1181                        *tool_use_limit_reached = true;
1182                    }
1183                    Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1184                        let usage = acp_thread::TokenUsage {
1185                            max_tokens: model.max_token_count_for_mode(
1186                                request
1187                                    .mode
1188                                    .unwrap_or(cloud_llm_client::CompletionMode::Normal),
1189                            ),
1190                            used_tokens: token_usage.total_tokens(),
1191                        };
1192
1193                        this.update(cx, |this, _cx| this.update_token_usage(token_usage))
1194                            .ok();
1195
1196                        event_stream.send_token_usage_update(usage);
1197                    }
1198                    Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
1199                        *refusal = true;
1200                        return Ok(FuturesUnordered::default());
1201                    }
1202                    Ok(LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)) => {
1203                        *max_tokens_reached = true;
1204                        return Ok(FuturesUnordered::default());
1205                    }
1206                    Ok(LanguageModelCompletionEvent::Stop(
1207                        StopReason::ToolUse | StopReason::EndTurn,
1208                    )) => break,
1209                    Ok(event) => {
1210                        log::trace!("Received completion event: {:?}", event);
1211                        this.update(cx, |this, cx| {
1212                            tool_uses.extend(this.handle_streamed_completion_event(
1213                                event,
1214                                event_stream,
1215                                cx,
1216                            ));
1217                        })
1218                        .ok();
1219                    }
1220                    Err(error) => {
1221                        let completion_mode =
1222                            this.read_with(cx, |thread, _cx| thread.completion_mode())?;
1223                        if completion_mode == CompletionMode::Normal {
1224                            return Err(error.into());
1225                        }
1226
1227                        let Some(strategy) = Self::retry_strategy_for(&error) else {
1228                            return Err(error.into());
1229                        };
1230
1231                        let max_attempts = match &strategy {
1232                            RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
1233                            RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
1234                        };
1235
1236                        let attempt = attempt.get_or_insert(0u8);
1237
1238                        *attempt += 1;
1239
1240                        let attempt = *attempt;
1241                        if attempt > max_attempts {
1242                            return Err(error.into());
1243                        }
1244
1245                        let delay = match &strategy {
1246                            RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
1247                                let delay_secs =
1248                                    initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
1249                                Duration::from_secs(delay_secs)
1250                            }
1251                            RetryStrategy::Fixed { delay, .. } => *delay,
1252                        };
1253                        log::debug!("Retry attempt {attempt} with delay {delay:?}");
1254
1255                        event_stream.send_retry(acp_thread::RetryStatus {
1256                            last_error: error.to_string().into(),
1257                            attempt: attempt as usize,
1258                            max_attempts: max_attempts as usize,
1259                            started_at: Instant::now(),
1260                            duration: delay,
1261                        });
1262
1263                        cx.background_executor().timer(delay).await;
1264                        continue 'retry;
1265                    }
1266                }
1267            }
1268
1269            return Ok(tool_uses);
1270        }
1271    }
1272
1273    pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
1274        log::debug!("Building system message");
1275        let prompt = SystemPromptTemplate {
1276            project: self.project_context.read(cx),
1277            available_tools: self.tools.keys().cloned().collect(),
1278        }
1279        .render(&self.templates)
1280        .context("failed to build system prompt")
1281        .expect("Invalid template");
1282        log::debug!("System message built");
1283        LanguageModelRequestMessage {
1284            role: Role::System,
1285            content: vec![prompt.into()],
1286            cache: true,
1287        }
1288    }
1289
1290    /// A helper method that's called on every streamed completion event.
1291    /// Returns an optional tool result task, which the main agentic loop in
1292    /// send will send back to the model when it resolves.
1293    fn handle_streamed_completion_event(
1294        &mut self,
1295        event: LanguageModelCompletionEvent,
1296        event_stream: &ThreadEventStream,
1297        cx: &mut Context<Self>,
1298    ) -> Option<Task<LanguageModelToolResult>> {
1299        log::trace!("Handling streamed completion event: {:?}", event);
1300        use LanguageModelCompletionEvent::*;
1301
1302        match event {
1303            StartMessage { .. } => {
1304                self.flush_pending_message(cx);
1305                self.pending_message = Some(AgentMessage::default());
1306            }
1307            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
1308            Thinking { text, signature } => {
1309                self.handle_thinking_event(text, signature, event_stream, cx)
1310            }
1311            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
1312            ToolUse(tool_use) => {
1313                return self.handle_tool_use_event(tool_use, event_stream, cx);
1314            }
1315            ToolUseJsonParseError {
1316                id,
1317                tool_name,
1318                raw_input,
1319                json_parse_error,
1320            } => {
1321                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
1322                    id,
1323                    tool_name,
1324                    raw_input,
1325                    json_parse_error,
1326                )));
1327            }
1328            UsageUpdate(_) | StatusUpdate(_) => {}
1329            Stop(_) => unreachable!(),
1330        }
1331
1332        None
1333    }
1334
1335    fn handle_text_event(
1336        &mut self,
1337        new_text: String,
1338        event_stream: &ThreadEventStream,
1339        cx: &mut Context<Self>,
1340    ) {
1341        event_stream.send_text(&new_text);
1342
1343        let last_message = self.pending_message();
1344        if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
1345            text.push_str(&new_text);
1346        } else {
1347            last_message
1348                .content
1349                .push(AgentMessageContent::Text(new_text));
1350        }
1351
1352        cx.notify();
1353    }
1354
1355    fn handle_thinking_event(
1356        &mut self,
1357        new_text: String,
1358        new_signature: Option<String>,
1359        event_stream: &ThreadEventStream,
1360        cx: &mut Context<Self>,
1361    ) {
1362        event_stream.send_thinking(&new_text);
1363
1364        let last_message = self.pending_message();
1365        if let Some(AgentMessageContent::Thinking { text, signature }) =
1366            last_message.content.last_mut()
1367        {
1368            text.push_str(&new_text);
1369            *signature = new_signature.or(signature.take());
1370        } else {
1371            last_message.content.push(AgentMessageContent::Thinking {
1372                text: new_text,
1373                signature: new_signature,
1374            });
1375        }
1376
1377        cx.notify();
1378    }
1379
1380    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
1381        let last_message = self.pending_message();
1382        last_message
1383            .content
1384            .push(AgentMessageContent::RedactedThinking(data));
1385        cx.notify();
1386    }
1387
1388    fn handle_tool_use_event(
1389        &mut self,
1390        tool_use: LanguageModelToolUse,
1391        event_stream: &ThreadEventStream,
1392        cx: &mut Context<Self>,
1393    ) -> Option<Task<LanguageModelToolResult>> {
1394        cx.notify();
1395
1396        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
1397        let mut title = SharedString::from(&tool_use.name);
1398        let mut kind = acp::ToolKind::Other;
1399        if let Some(tool) = tool.as_ref() {
1400            title = tool.initial_title(tool_use.input.clone());
1401            kind = tool.kind();
1402        }
1403
1404        // Ensure the last message ends in the current tool use
1405        let last_message = self.pending_message();
1406        let push_new_tool_use = last_message.content.last_mut().is_none_or(|content| {
1407            if let AgentMessageContent::ToolUse(last_tool_use) = content {
1408                if last_tool_use.id == tool_use.id {
1409                    *last_tool_use = tool_use.clone();
1410                    false
1411                } else {
1412                    true
1413                }
1414            } else {
1415                true
1416            }
1417        });
1418
1419        if push_new_tool_use {
1420            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
1421            last_message
1422                .content
1423                .push(AgentMessageContent::ToolUse(tool_use.clone()));
1424        } else {
1425            event_stream.update_tool_call_fields(
1426                &tool_use.id,
1427                acp::ToolCallUpdateFields {
1428                    title: Some(title.into()),
1429                    kind: Some(kind),
1430                    raw_input: Some(tool_use.input.clone()),
1431                    ..Default::default()
1432                },
1433            );
1434        }
1435
1436        if !tool_use.is_input_complete {
1437            return None;
1438        }
1439
1440        let Some(tool) = tool else {
1441            let content = format!("No tool named {} exists", tool_use.name);
1442            return Some(Task::ready(LanguageModelToolResult {
1443                content: LanguageModelToolResultContent::Text(Arc::from(content)),
1444                tool_use_id: tool_use.id,
1445                tool_name: tool_use.name,
1446                is_error: true,
1447                output: None,
1448            }));
1449        };
1450
1451        let fs = self.project.read(cx).fs().clone();
1452        let tool_event_stream =
1453            ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1454        tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1455            status: Some(acp::ToolCallStatus::InProgress),
1456            ..Default::default()
1457        });
1458        let supports_images = self.model().is_some_and(|model| model.supports_images());
1459        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1460        log::info!("Running tool {}", tool_use.name);
1461        Some(cx.foreground_executor().spawn(async move {
1462            let tool_result = tool_result.await.and_then(|output| {
1463                if let LanguageModelToolResultContent::Image(_) = &output.llm_output
1464                    && !supports_images
1465                {
1466                    return Err(anyhow!(
1467                        "Attempted to read an image, but this model doesn't support it.",
1468                    ));
1469                }
1470                Ok(output)
1471            });
1472
1473            match tool_result {
1474                Ok(output) => LanguageModelToolResult {
1475                    tool_use_id: tool_use.id,
1476                    tool_name: tool_use.name,
1477                    is_error: false,
1478                    content: output.llm_output,
1479                    output: Some(output.raw_output),
1480                },
1481                Err(error) => LanguageModelToolResult {
1482                    tool_use_id: tool_use.id,
1483                    tool_name: tool_use.name,
1484                    is_error: true,
1485                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1486                    output: None,
1487                },
1488            }
1489        }))
1490    }
1491
1492    fn handle_tool_use_json_parse_error_event(
1493        &mut self,
1494        tool_use_id: LanguageModelToolUseId,
1495        tool_name: Arc<str>,
1496        raw_input: Arc<str>,
1497        json_parse_error: String,
1498    ) -> LanguageModelToolResult {
1499        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1500        LanguageModelToolResult {
1501            tool_use_id,
1502            tool_name,
1503            is_error: true,
1504            content: LanguageModelToolResultContent::Text(tool_output.into()),
1505            output: Some(serde_json::Value::String(raw_input.to_string())),
1506        }
1507    }
1508
1509    pub fn title(&self) -> SharedString {
1510        self.title.clone().unwrap_or("New Thread".into())
1511    }
1512
1513    pub fn summary(&mut self, cx: &mut Context<Self>) -> Task<Result<SharedString>> {
1514        if let Some(summary) = self.summary.as_ref() {
1515            return Task::ready(Ok(summary.clone()));
1516        }
1517        let Some(model) = self.summarization_model.clone() else {
1518            return Task::ready(Err(anyhow!("No summarization model available")));
1519        };
1520        let mut request = LanguageModelRequest {
1521            intent: Some(CompletionIntent::ThreadSummarization),
1522            temperature: AgentSettings::temperature_for_model(&model, cx),
1523            ..Default::default()
1524        };
1525
1526        for message in &self.messages {
1527            request.messages.extend(message.to_request());
1528        }
1529
1530        request.messages.push(LanguageModelRequestMessage {
1531            role: Role::User,
1532            content: vec![SUMMARIZE_THREAD_DETAILED_PROMPT.into()],
1533            cache: false,
1534        });
1535        cx.spawn(async move |this, cx| {
1536            let mut summary = String::new();
1537            let mut messages = model.stream_completion(request, cx).await?;
1538            while let Some(event) = messages.next().await {
1539                let event = event?;
1540                let text = match event {
1541                    LanguageModelCompletionEvent::Text(text) => text,
1542                    LanguageModelCompletionEvent::StatusUpdate(
1543                        CompletionRequestStatus::UsageUpdated { .. },
1544                    ) => {
1545                        // this.update(cx, |thread, cx| {
1546                        //     thread.update_model_request_usage(amount as u32, limit, cx);
1547                        // })?;
1548                        // TODO: handle usage update
1549                        continue;
1550                    }
1551                    _ => continue,
1552                };
1553
1554                let mut lines = text.lines();
1555                summary.extend(lines.next());
1556            }
1557
1558            log::info!("Setting summary: {}", summary);
1559            let summary = SharedString::from(summary);
1560
1561            this.update(cx, |this, cx| {
1562                this.summary = Some(summary.clone());
1563                cx.notify()
1564            })?;
1565
1566            Ok(summary)
1567        })
1568    }
1569
1570    fn update_title(
1571        &mut self,
1572        event_stream: &ThreadEventStream,
1573        cx: &mut Context<Self>,
1574    ) -> Option<Task<Result<()>>> {
1575        if self.title.is_some() {
1576            log::debug!("Skipping title generation because we already have one.");
1577            return None;
1578        }
1579
1580        log::info!(
1581            "Generating title with model: {:?}",
1582            self.summarization_model.as_ref().map(|model| model.name())
1583        );
1584        let model = self.summarization_model.clone()?;
1585        let event_stream = event_stream.clone();
1586        let mut request = LanguageModelRequest {
1587            intent: Some(CompletionIntent::ThreadSummarization),
1588            temperature: AgentSettings::temperature_for_model(&model, cx),
1589            ..Default::default()
1590        };
1591
1592        for message in &self.messages {
1593            request.messages.extend(message.to_request());
1594        }
1595
1596        request.messages.push(LanguageModelRequestMessage {
1597            role: Role::User,
1598            content: vec![SUMMARIZE_THREAD_PROMPT.into()],
1599            cache: false,
1600        });
1601        Some(cx.spawn(async move |this, cx| {
1602            let mut title = String::new();
1603            let mut messages = model.stream_completion(request, cx).await?;
1604            while let Some(event) = messages.next().await {
1605                let event = event?;
1606                let text = match event {
1607                    LanguageModelCompletionEvent::Text(text) => text,
1608                    LanguageModelCompletionEvent::StatusUpdate(
1609                        CompletionRequestStatus::UsageUpdated { .. },
1610                    ) => {
1611                        // this.update(cx, |thread, cx| {
1612                        //     thread.update_model_request_usage(amount as u32, limit, cx);
1613                        // })?;
1614                        // TODO: handle usage update
1615                        continue;
1616                    }
1617                    _ => continue,
1618                };
1619
1620                let mut lines = text.lines();
1621                title.extend(lines.next());
1622
1623                // Stop if the LLM generated multiple lines.
1624                if lines.next().is_some() {
1625                    break;
1626                }
1627            }
1628
1629            log::info!("Setting title: {}", title);
1630
1631            this.update(cx, |this, cx| {
1632                let title = SharedString::from(title);
1633                event_stream.send_title_update(title.clone());
1634                this.title = Some(title);
1635                cx.notify();
1636            })
1637        }))
1638    }
1639    fn last_user_message(&self) -> Option<&UserMessage> {
1640        self.messages
1641            .iter()
1642            .rev()
1643            .find_map(|message| match message {
1644                Message::User(user_message) => Some(user_message),
1645                Message::Agent(_) => None,
1646                Message::Resume => None,
1647            })
1648    }
1649
1650    fn pending_message(&mut self) -> &mut AgentMessage {
1651        self.pending_message.get_or_insert_default()
1652    }
1653
1654    fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
1655        let Some(mut message) = self.pending_message.take() else {
1656            return;
1657        };
1658
1659        for content in &message.content {
1660            let AgentMessageContent::ToolUse(tool_use) = content else {
1661                continue;
1662            };
1663
1664            if !message.tool_results.contains_key(&tool_use.id) {
1665                message.tool_results.insert(
1666                    tool_use.id.clone(),
1667                    LanguageModelToolResult {
1668                        tool_use_id: tool_use.id.clone(),
1669                        tool_name: tool_use.name.clone(),
1670                        is_error: true,
1671                        content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1672                        output: None,
1673                    },
1674                );
1675            }
1676        }
1677
1678        self.messages.push(Message::Agent(message));
1679        self.updated_at = Utc::now();
1680        self.summary = None;
1681        cx.notify()
1682    }
1683
1684    pub(crate) fn build_completion_request(
1685        &self,
1686        completion_intent: CompletionIntent,
1687        cx: &mut App,
1688    ) -> Result<LanguageModelRequest> {
1689        let model = self.model().context("No language model configured")?;
1690
1691        log::debug!("Building completion request");
1692        log::debug!("Completion intent: {:?}", completion_intent);
1693        log::debug!("Completion mode: {:?}", self.completion_mode);
1694
1695        let messages = self.build_request_messages(cx);
1696        log::info!("Request will include {} messages", messages.len());
1697
1698        let tools = if let Some(tools) = self.tools(cx).log_err() {
1699            tools
1700                .filter_map(|tool| {
1701                    let tool_name = tool.name().to_string();
1702                    log::trace!("Including tool: {}", tool_name);
1703                    Some(LanguageModelRequestTool {
1704                        name: tool_name,
1705                        description: tool.description().to_string(),
1706                        input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
1707                    })
1708                })
1709                .collect()
1710        } else {
1711            Vec::new()
1712        };
1713
1714        log::info!("Request includes {} tools", tools.len());
1715
1716        let request = LanguageModelRequest {
1717            thread_id: Some(self.id.to_string()),
1718            prompt_id: Some(self.prompt_id.to_string()),
1719            intent: Some(completion_intent),
1720            mode: Some(self.completion_mode.into()),
1721            messages,
1722            tools,
1723            tool_choice: None,
1724            stop: Vec::new(),
1725            temperature: AgentSettings::temperature_for_model(model, cx),
1726            thinking_allowed: true,
1727        };
1728
1729        log::debug!("Completion request built successfully");
1730        Ok(request)
1731    }
1732
1733    fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
1734        let model = self.model().context("No language model configured")?;
1735
1736        let profile = AgentSettings::get_global(cx)
1737            .profiles
1738            .get(&self.profile_id)
1739            .context("profile not found")?;
1740        let provider_id = model.provider_id();
1741
1742        Ok(self
1743            .tools
1744            .iter()
1745            .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1746            .filter_map(|(tool_name, tool)| {
1747                if profile.is_tool_enabled(tool_name) {
1748                    Some(tool)
1749                } else {
1750                    None
1751                }
1752            })
1753            .chain(self.context_server_registry.read(cx).servers().flat_map(
1754                |(server_id, tools)| {
1755                    tools.iter().filter_map(|(tool_name, tool)| {
1756                        if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1757                            Some(tool)
1758                        } else {
1759                            None
1760                        }
1761                    })
1762                },
1763            )))
1764    }
1765
1766    fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
1767        log::trace!(
1768            "Building request messages from {} thread messages",
1769            self.messages.len()
1770        );
1771        let mut messages = vec![self.build_system_message(cx)];
1772        for message in &self.messages {
1773            messages.extend(message.to_request());
1774        }
1775
1776        if let Some(message) = self.pending_message.as_ref() {
1777            messages.extend(message.to_request());
1778        }
1779
1780        if let Some(last_user_message) = messages
1781            .iter_mut()
1782            .rev()
1783            .find(|message| message.role == Role::User)
1784        {
1785            last_user_message.cache = true;
1786        }
1787
1788        messages
1789    }
1790
1791    pub fn to_markdown(&self) -> String {
1792        let mut markdown = String::new();
1793        for (ix, message) in self.messages.iter().enumerate() {
1794            if ix > 0 {
1795                markdown.push('\n');
1796            }
1797            markdown.push_str(&message.to_markdown());
1798        }
1799
1800        if let Some(message) = self.pending_message.as_ref() {
1801            markdown.push('\n');
1802            markdown.push_str(&message.to_markdown());
1803        }
1804
1805        markdown
1806    }
1807
1808    fn advance_prompt_id(&mut self) {
1809        self.prompt_id = PromptId::new();
1810    }
1811
    /// Decides whether, and how, a failed completion request should be retried.
    ///
    /// Returns `None` when retrying cannot help. Otherwise returns the delay
    /// and `max_attempts` budget to use. Arms are matched top to bottom, so the
    /// specific `HttpResponseError` status arms must stay ahead of the guarded
    /// 4xx/5xx arm and the final catch-all arm.
    fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
        use LanguageModelCompletionError::*;
        use http_client::StatusCode;

        // General strategy here:
        // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
        // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), allow
        //   `MAX_RETRY_ATTEMPTS` attempts, waiting for the provider's `retry_after` when it sends
        //   one (a raw HTTP 429 uses exponential backoff instead).
        // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), allow a
        //   small fixed number of attempts (between 1 and 3 depending on the arm).
        match error {
            // Raw HTTP 429: back off exponentially.
            HttpResponseError {
                status_code: StatusCode::TOO_MANY_REQUESTS,
                ..
            } => Some(RetryStrategy::ExponentialBackoff {
                initial_delay: BASE_RETRY_DELAY,
                max_attempts: MAX_RETRY_ATTEMPTS,
            }),
            ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
                Some(RetryStrategy::Fixed {
                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                    max_attempts: MAX_RETRY_ATTEMPTS,
                })
            }
            UpstreamProviderError {
                status,
                retry_after,
                ..
            } => match *status {
                StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
                    Some(RetryStrategy::Fixed {
                        delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                        max_attempts: MAX_RETRY_ATTEMPTS,
                    })
                }
                StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                    // Internal Server Error could be anything, retry up to 3 times.
                    max_attempts: 3,
                }),
                status => {
                    // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
                    // but we frequently get them in practice. See https://http.dev/529
                    if status.as_u16() == 529 {
                        Some(RetryStrategy::Fixed {
                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                            max_attempts: MAX_RETRY_ATTEMPTS,
                        })
                    } else {
                        // Any other upstream status: allow two attempts.
                        Some(RetryStrategy::Fixed {
                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                            max_attempts: 2,
                        })
                    }
                }
            },
            ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 3,
            }),
            // Transport and response-decoding failures: allow up to 3 attempts.
            // NOTE(review): retrying `BadRequestFormat` is unlikely to help if the request itself is
            // malformed; confirm it belongs in this group.
            ApiReadResponseError { .. }
            | HttpSend { .. }
            | DeserializeResponse { .. }
            | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 3,
            }),
            // Retrying these errors definitely shouldn't help.
            HttpResponseError {
                status_code:
                    StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
                ..
            }
            | AuthenticationError { .. }
            | PermissionError { .. }
            | NoApiKey { .. }
            | ApiEndpointNotFound { .. }
            | PromptTooLarge { .. } => None,
            // Request-construction failures. These might be transient, so they get `max_attempts: 1`.
            // NOTE(review): confirm whether `max_attempts: 1` permits any retry at all.
            SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 1,
            }),
            // Retry all other 4xx and 5xx errors with a fixed delay, up to `max_attempts: 3`.
            HttpResponseError { status_code, .. }
                if status_code.is_client_error() || status_code.is_server_error() =>
            {
                Some(RetryStrategy::Fixed {
                    delay: BASE_RETRY_DELAY,
                    max_attempts: 3,
                })
            }
            Other(err)
                if err.is::<language_model::PaymentRequiredError>()
                    || err.is::<language_model::ModelRequestLimitReachedError>() =>
            {
                // Retrying won't help for Payment Required or Model Request Limit errors (where
                // the user must upgrade to usage-based billing to get more requests, or else wait
                // for a significant amount of time for the request limit to reset).
                None
            }
            // Everything left (HTTP responses outside 4xx/5xx, and other errors) has an unknown
            // cause; these are retried with `max_attempts: 2` rather than treated as fatal.
            HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 2,
            }),
        }
    }
1918}
1919
/// The agent turn currently in progress on a thread.
struct RunningTurn {
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools and report their results.
    ///
    /// NOTE(review): presumably dropping this gpui `Task` is what actually stops
    /// the turn, since `cancel` only reports the cancellation. Confirm.
    _task: Task<()>,
    /// The current event stream for the running turn. Used to report a final
    /// cancellation event if we cancel the turn.
    event_stream: ThreadEventStream,
}
1929
1930impl RunningTurn {
1931    fn cancel(self) {
1932        log::debug!("Cancelling in progress turn");
1933        self.event_stream.send_canceled();
1934    }
1935}
1936
1937pub trait AgentTool
1938where
1939    Self: 'static + Sized,
1940{
1941    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1942    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1943
1944    fn name(&self) -> SharedString;
1945
1946    fn description(&self) -> SharedString {
1947        let schema = schemars::schema_for!(Self::Input);
1948        SharedString::new(
1949            schema
1950                .get("description")
1951                .and_then(|description| description.as_str())
1952                .unwrap_or_default(),
1953        )
1954    }
1955
1956    fn kind(&self) -> acp::ToolKind;
1957
1958    /// The initial tool title to display. Can be updated during the tool run.
1959    fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1960
1961    /// Returns the JSON schema that describes the tool's input.
1962    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Schema {
1963        crate::tool_schema::root_schema_for::<Self::Input>(format)
1964    }
1965
1966    /// Some tools rely on a provider for the underlying billing or other reasons.
1967    /// Allow the tool to check if they are compatible, or should be filtered out.
1968    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1969        true
1970    }
1971
1972    /// Runs the tool with the provided input.
1973    fn run(
1974        self: Arc<Self>,
1975        input: Self::Input,
1976        event_stream: ToolCallEventStream,
1977        cx: &mut App,
1978    ) -> Task<Result<Self::Output>>;
1979
1980    /// Emits events for a previous execution of the tool.
1981    fn replay(
1982        &self,
1983        _input: Self::Input,
1984        _output: Self::Output,
1985        _event_stream: ToolCallEventStream,
1986        _cx: &mut App,
1987    ) -> Result<()> {
1988        Ok(())
1989    }
1990
1991    fn erase(self) -> Arc<dyn AnyAgentTool> {
1992        Arc::new(Erased(Arc::new(self)))
1993    }
1994}
1995
/// Adapter implementing [`AnyAgentTool`] for a typed [`AgentTool`] by converting
/// JSON input and output to and from the tool's `Input`/`Output` types.
/// Constructed by [`AgentTool::erase`].
pub struct Erased<T>(T);
1997
/// Result of running a type-erased tool.
pub struct AgentToolOutput {
    /// The tool output converted into tool-result content for the model.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's typed output serialized to JSON.
    pub raw_output: serde_json::Value,
}
2002
/// Object-safe, JSON-based form of [`AgentTool`], so tools with different
/// `Input`/`Output` types can be stored together as `Arc<dyn AnyAgentTool>`.
/// Typed tools obtain an implementation through [`AgentTool::erase`].
pub trait AnyAgentTool {
    /// The name the tool is exposed under.
    fn name(&self) -> SharedString;
    /// Description sent to the model alongside the input schema.
    fn description(&self) -> SharedString;
    /// The ACP kind reported for calls to this tool.
    fn kind(&self) -> acp::ToolKind;
    /// Initial title for a call with the given JSON input, which may not be a
    /// valid input for the tool.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// JSON schema of the tool's input for the given schema `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether the tool may be offered to models of `_provider`. Defaults to `true`.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool on JSON input.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
    /// Re-emits events for a previous call from its stored JSON input and output.
    fn replay(
        &self,
        input: serde_json::Value,
        output: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Result<()>;
}
2026
2027impl<T> AnyAgentTool for Erased<Arc<T>>
2028where
2029    T: AgentTool,
2030{
2031    fn name(&self) -> SharedString {
2032        self.0.name()
2033    }
2034
2035    fn description(&self) -> SharedString {
2036        self.0.description()
2037    }
2038
2039    fn kind(&self) -> agent_client_protocol::ToolKind {
2040        self.0.kind()
2041    }
2042
2043    fn initial_title(&self, input: serde_json::Value) -> SharedString {
2044        let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
2045        self.0.initial_title(parsed_input)
2046    }
2047
2048    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
2049        let mut json = serde_json::to_value(self.0.input_schema(format))?;
2050        adapt_schema_to_format(&mut json, format)?;
2051        Ok(json)
2052    }
2053
2054    fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
2055        self.0.supported_provider(provider)
2056    }
2057
2058    fn run(
2059        self: Arc<Self>,
2060        input: serde_json::Value,
2061        event_stream: ToolCallEventStream,
2062        cx: &mut App,
2063    ) -> Task<Result<AgentToolOutput>> {
2064        cx.spawn(async move |cx| {
2065            let input = serde_json::from_value(input)?;
2066            let output = cx
2067                .update(|cx| self.0.clone().run(input, event_stream, cx))?
2068                .await?;
2069            let raw_output = serde_json::to_value(&output)?;
2070            Ok(AgentToolOutput {
2071                llm_output: output.into(),
2072                raw_output,
2073            })
2074        })
2075    }
2076
2077    fn replay(
2078        &self,
2079        input: serde_json::Value,
2080        output: serde_json::Value,
2081        event_stream: ToolCallEventStream,
2082        cx: &mut App,
2083    ) -> Result<()> {
2084        let input = serde_json::from_value(input)?;
2085        let output = serde_json::from_value(output)?;
2086        self.0.replay(input, output, event_stream, cx)
2087    }
2088}
2089
/// Sending half of a thread's event channel. Events are either `Ok(ThreadEvent)`
/// or an error to surface to the receiver.
#[derive(Clone)]
struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
2092
2093impl ThreadEventStream {
2094    fn send_title_update(&self, text: SharedString) {
2095        self.0
2096            .unbounded_send(Ok(ThreadEvent::TitleUpdate(text)))
2097            .ok();
2098    }
2099
2100    fn send_user_message(&self, message: &UserMessage) {
2101        self.0
2102            .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
2103            .ok();
2104    }
2105
2106    fn send_text(&self, text: &str) {
2107        self.0
2108            .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
2109            .ok();
2110    }
2111
2112    fn send_thinking(&self, text: &str) {
2113        self.0
2114            .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
2115            .ok();
2116    }
2117
2118    fn send_tool_call(
2119        &self,
2120        id: &LanguageModelToolUseId,
2121        title: SharedString,
2122        kind: acp::ToolKind,
2123        input: serde_json::Value,
2124    ) {
2125        self.0
2126            .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
2127                id,
2128                title.to_string(),
2129                kind,
2130                input,
2131            ))))
2132            .ok();
2133    }
2134
2135    fn initial_tool_call(
2136        id: &LanguageModelToolUseId,
2137        title: String,
2138        kind: acp::ToolKind,
2139        input: serde_json::Value,
2140    ) -> acp::ToolCall {
2141        acp::ToolCall {
2142            id: acp::ToolCallId(id.to_string().into()),
2143            title,
2144            kind,
2145            status: acp::ToolCallStatus::Pending,
2146            content: vec![],
2147            locations: vec![],
2148            raw_input: Some(input),
2149            raw_output: None,
2150        }
2151    }
2152
2153    fn update_tool_call_fields(
2154        &self,
2155        tool_use_id: &LanguageModelToolUseId,
2156        fields: acp::ToolCallUpdateFields,
2157    ) {
2158        self.0
2159            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2160                acp::ToolCallUpdate {
2161                    id: acp::ToolCallId(tool_use_id.to_string().into()),
2162                    fields,
2163                }
2164                .into(),
2165            )))
2166            .ok();
2167    }
2168
2169    fn send_token_usage_update(&self, usage: acp_thread::TokenUsage) {
2170        self.0
2171            .unbounded_send(Ok(ThreadEvent::TokenUsageUpdate(usage)))
2172            .ok();
2173    }
2174
2175    fn send_retry(&self, status: acp_thread::RetryStatus) {
2176        self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
2177    }
2178
2179    fn send_stop(&self, reason: StopReason) {
2180        match reason {
2181            StopReason::EndTurn => {
2182                self.0
2183                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
2184                    .ok();
2185            }
2186            StopReason::MaxTokens => {
2187                self.0
2188                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
2189                    .ok();
2190            }
2191            StopReason::Refusal => {
2192                self.0
2193                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
2194                    .ok();
2195            }
2196            StopReason::ToolUse => {}
2197        }
2198    }
2199
2200    fn send_canceled(&self) {
2201        self.0
2202            .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
2203            .ok();
2204    }
2205
2206    fn send_error(&self, error: impl Into<anyhow::Error>) {
2207        self.0.unbounded_send(Err(error.into())).ok();
2208    }
2209}
2210
/// Per-tool-call handle for reporting progress. Every update it sends is tagged
/// with `tool_use_id` and forwarded to the thread's event stream.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// The tool use this stream reports on.
    tool_use_id: LanguageModelToolUseId,
    /// The thread-level stream the events are forwarded to.
    stream: ThreadEventStream,
    /// Used by `authorize` to persist an "Always Allow" answer to the settings
    /// file. When `None`, that answer still allows this call but is not saved.
    fs: Option<Arc<dyn Fs>>,
}
2217
2218impl ToolCallEventStream {
2219    #[cfg(test)]
2220    pub fn test() -> (Self, ToolCallEventStreamReceiver) {
2221        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
2222
2223        let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
2224
2225        (stream, ToolCallEventStreamReceiver(events_rx))
2226    }
2227
2228    fn new(
2229        tool_use_id: LanguageModelToolUseId,
2230        stream: ThreadEventStream,
2231        fs: Option<Arc<dyn Fs>>,
2232    ) -> Self {
2233        Self {
2234            tool_use_id,
2235            stream,
2236            fs,
2237        }
2238    }
2239
2240    pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
2241        self.stream
2242            .update_tool_call_fields(&self.tool_use_id, fields);
2243    }
2244
2245    pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
2246        self.stream
2247            .0
2248            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2249                acp_thread::ToolCallUpdateDiff {
2250                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2251                    diff,
2252                }
2253                .into(),
2254            )))
2255            .ok();
2256    }
2257
2258    pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
2259        self.stream
2260            .0
2261            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2262                acp_thread::ToolCallUpdateTerminal {
2263                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2264                    terminal,
2265                }
2266                .into(),
2267            )))
2268            .ok();
2269    }
2270
2271    pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
2272        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
2273            return Task::ready(Ok(()));
2274        }
2275
2276        let (response_tx, response_rx) = oneshot::channel();
2277        self.stream
2278            .0
2279            .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
2280                ToolCallAuthorization {
2281                    tool_call: acp::ToolCallUpdate {
2282                        id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2283                        fields: acp::ToolCallUpdateFields {
2284                            title: Some(title.into()),
2285                            ..Default::default()
2286                        },
2287                    },
2288                    options: vec![
2289                        acp::PermissionOption {
2290                            id: acp::PermissionOptionId("always_allow".into()),
2291                            name: "Always Allow".into(),
2292                            kind: acp::PermissionOptionKind::AllowAlways,
2293                        },
2294                        acp::PermissionOption {
2295                            id: acp::PermissionOptionId("allow".into()),
2296                            name: "Allow".into(),
2297                            kind: acp::PermissionOptionKind::AllowOnce,
2298                        },
2299                        acp::PermissionOption {
2300                            id: acp::PermissionOptionId("deny".into()),
2301                            name: "Deny".into(),
2302                            kind: acp::PermissionOptionKind::RejectOnce,
2303                        },
2304                    ],
2305                    response: response_tx,
2306                },
2307            )))
2308            .ok();
2309        let fs = self.fs.clone();
2310        cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
2311            "always_allow" => {
2312                if let Some(fs) = fs.clone() {
2313                    cx.update(|cx| {
2314                        update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
2315                            settings.set_always_allow_tool_actions(true);
2316                        });
2317                    })?;
2318                }
2319
2320                Ok(())
2321            }
2322            "allow" => Ok(()),
2323            _ => Err(anyhow!("Permission to run tool denied by user")),
2324        })
2325    }
2326}
2327
/// Test-only receiving end of the channel behind a [`ToolCallEventStream`]
/// created with `ToolCallEventStream::test`.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
2330
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it if it is a tool-call authorization
    /// request. Panics, printing the received event, otherwise.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) => auth,
            event => panic!("Expected ToolCallAuthorization but got: {:?}", event),
        }
    }

    /// Awaits the next event and returns the terminal carried by a terminal
    /// update. Panics, printing the received event, otherwise.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal(
                update,
            )))) => update.terminal,
            event => panic!("Expected terminal but got: {:?}", event),
        }
    }
}
2354
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;

    /// Exposes the wrapped receiver.
    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
2363
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    /// Exposes the wrapped receiver mutably.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
2370
2371impl From<&str> for UserMessageContent {
2372    fn from(text: &str) -> Self {
2373        Self::Text(text.into())
2374    }
2375}
2376
2377impl From<acp::ContentBlock> for UserMessageContent {
2378    fn from(value: acp::ContentBlock) -> Self {
2379        match value {
2380            acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
2381            acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
2382            acp::ContentBlock::Audio(_) => {
2383                // TODO
2384                Self::Text("[audio]".to_string())
2385            }
2386            acp::ContentBlock::ResourceLink(resource_link) => {
2387                match MentionUri::parse(&resource_link.uri) {
2388                    Ok(uri) => Self::Mention {
2389                        uri,
2390                        content: String::new(),
2391                    },
2392                    Err(err) => {
2393                        log::error!("Failed to parse mention link: {}", err);
2394                        Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
2395                    }
2396                }
2397            }
2398            acp::ContentBlock::Resource(resource) => match resource.resource {
2399                acp::EmbeddedResourceResource::TextResourceContents(resource) => {
2400                    match MentionUri::parse(&resource.uri) {
2401                        Ok(uri) => Self::Mention {
2402                            uri,
2403                            content: resource.text,
2404                        },
2405                        Err(err) => {
2406                            log::error!("Failed to parse mention link: {}", err);
2407                            Self::Text(
2408                                MarkdownCodeBlock {
2409                                    tag: &resource.uri,
2410                                    text: &resource.text,
2411                                }
2412                                .to_string(),
2413                            )
2414                        }
2415                    }
2416                }
2417                acp::EmbeddedResourceResource::BlobResourceContents(_) => {
2418                    // TODO
2419                    Self::Text("[blob]".to_string())
2420                }
2421            },
2422        }
2423    }
2424}
2425
2426impl From<UserMessageContent> for acp::ContentBlock {
2427    fn from(content: UserMessageContent) -> Self {
2428        match content {
2429            UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
2430                text,
2431                annotations: None,
2432            }),
2433            UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
2434                data: image.source.to_string(),
2435                mime_type: "image/png".to_string(),
2436                annotations: None,
2437                uri: None,
2438            }),
2439            UserMessageContent::Mention { uri, content } => {
2440                acp::ContentBlock::ResourceLink(acp::ResourceLink {
2441                    uri: uri.to_uri().to_string(),
2442                    name: uri.name(),
2443                    annotations: None,
2444                    description: if content.is_empty() {
2445                        None
2446                    } else {
2447                        Some(content)
2448                    },
2449                    mime_type: None,
2450                    size: None,
2451                    title: None,
2452                })
2453            }
2454        }
2455    }
2456}
2457
2458fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
2459    LanguageModelImage {
2460        source: image_content.data.into(),
2461        // TODO: make this optional?
2462        size: gpui::Size::new(0.into(), 0.into()),
2463    }
2464}