thread.rs

   1use crate::{
   2    ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DbLanguageModel, DbThread,
   3    DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool,
   4    ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, SystemPromptTemplate,
   5    Template, Templates, TerminalTool, ThinkingTool, WebSearchTool,
   6};
   7use acp_thread::{MentionUri, UserMessageId};
   8use action_log::ActionLog;
   9use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot};
  10use agent_client_protocol as acp;
  11use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT};
  12use anyhow::{Context as _, Result, anyhow};
  13use assistant_tool::adapt_schema_to_format;
  14use chrono::{DateTime, Utc};
  15use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
  16use collections::IndexMap;
  17use fs::Fs;
  18use futures::{
  19    FutureExt,
  20    channel::{mpsc, oneshot},
  21    future::Shared,
  22    stream::FuturesUnordered,
  23};
  24use git::repository::DiffType;
  25use gpui::{App, AppContext, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
  26use language_model::{
  27    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
  28    LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
  29    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  30    LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse,
  31    LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
  32};
  33use project::{
  34    Project,
  35    git_store::{GitStore, RepositoryState},
  36};
  37use prompt_store::ProjectContext;
  38use schemars::{JsonSchema, Schema};
  39use serde::{Deserialize, Serialize};
  40use settings::{Settings, update_settings_file};
  41use smol::stream::StreamExt;
  42use std::{
  43    collections::BTreeMap,
  44    path::Path,
  45    sync::Arc,
  46    time::{Duration, Instant},
  47};
  48use std::{fmt::Write, ops::Range};
  49use util::{ResultExt, markdown::MarkdownCodeBlock};
  50use uuid::Uuid;
  51
// NOTE(review): not referenced in the visible part of this file; presumably the text
// reported for a tool call the user canceled — confirm at the use site.
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
  53
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
/// The ID string is stored as an `Arc<str>`, so clones share the same allocation.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
  59
  60impl PromptId {
  61    pub fn new() -> Self {
  62        Self(Uuid::new_v4().to_string().into())
  63    }
  64}
  65
  66impl std::fmt::Display for PromptId {
  67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  68        write!(f, "{}", self.0)
  69    }
  70}
  71
/// Maximum number of retry attempts.
// NOTE(review): consumers are outside this view; presumably applies to failed completion requests.
pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
/// Base delay between retry attempts.
// NOTE(review): presumably the starting/fixed delay for a `RetryStrategy` — confirm at the use site.
pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
  74
/// How a failed request should be retried.
// NOTE(review): the code that interprets these variants is outside this view.
#[derive(Debug, Clone)]
enum RetryStrategy {
    /// Presumably a delay that grows between attempts, starting at `initial_delay`,
    /// for at most `max_attempts` attempts — TODO confirm against the consumer.
    ExponentialBackoff {
        initial_delay: Duration,
        max_attempts: u8,
    },
    /// Presumably the same `delay` before each attempt, for at most `max_attempts`
    /// attempts — TODO confirm against the consumer.
    Fixed {
        delay: Duration,
        max_attempts: u8,
    },
}
  86
/// One entry in a thread's conversation history.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Message {
    /// A message submitted by the user.
    User(UserMessage),
    /// A model response together with the results of the tools it called.
    Agent(AgentMessage),
    /// Marker that the conversation was resumed after the tool use limit was reached.
    /// It is sent to the model as a "Continue where you left off" user message.
    Resume,
}
  93
  94impl Message {
  95    pub fn as_agent_message(&self) -> Option<&AgentMessage> {
  96        match self {
  97            Message::Agent(agent_message) => Some(agent_message),
  98            _ => None,
  99        }
 100    }
 101
 102    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
 103        match self {
 104            Message::User(message) => vec![message.to_request()],
 105            Message::Agent(message) => message.to_request(),
 106            Message::Resume => vec![LanguageModelRequestMessage {
 107                role: Role::User,
 108                content: vec!["Continue where you left off".into()],
 109                cache: false,
 110            }],
 111        }
 112    }
 113
 114    pub fn to_markdown(&self) -> String {
 115        match self {
 116            Message::User(message) => message.to_markdown(),
 117            Message::Agent(message) => message.to_markdown(),
 118            Message::Resume => "[resumed after tool use limit was reached]".into(),
 119        }
 120    }
 121
 122    pub fn role(&self) -> Role {
 123        match self {
 124            Message::User(_) | Message::Resume => Role::User,
 125            Message::Agent(_) => Role::Assistant,
 126        }
 127    }
 128}
 129
/// A message submitted by the user, made of text, mentions, and images.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessage {
    /// Identifier of this user message.
    pub id: UserMessageId,
    /// Content chunks in the order they appear in the message.
    pub content: Vec<UserMessageContent>,
}
 135
/// A single chunk of a [`UserMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UserMessageContent {
    /// Plain text.
    Text(String),
    /// A mentioned resource. `uri` identifies it and is rendered as a link; `content`
    /// is the text attached to the model request for it (it may be empty).
    Mention { uri: MentionUri, content: String },
    /// An attached image.
    Image(LanguageModelImage),
}
 142
 143impl UserMessage {
 144    pub fn to_markdown(&self) -> String {
 145        let mut markdown = String::from("## User\n\n");
 146
 147        for content in &self.content {
 148            match content {
 149                UserMessageContent::Text(text) => {
 150                    markdown.push_str(text);
 151                    markdown.push('\n');
 152                }
 153                UserMessageContent::Image(_) => {
 154                    markdown.push_str("<image />\n");
 155                }
 156                UserMessageContent::Mention { uri, content } => {
 157                    if !content.is_empty() {
 158                        let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
 159                    } else {
 160                        let _ = write!(&mut markdown, "{}\n", uri.as_link());
 161                    }
 162                }
 163            }
 164        }
 165
 166        markdown
 167    }
 168
 169    fn to_request(&self) -> LanguageModelRequestMessage {
 170        let mut message = LanguageModelRequestMessage {
 171            role: Role::User,
 172            content: Vec::with_capacity(self.content.len()),
 173            cache: false,
 174        };
 175
 176        const OPEN_CONTEXT: &str = "<context>\n\
 177            The following items were attached by the user. \
 178            They are up-to-date and don't need to be re-read.\n\n";
 179
 180        const OPEN_FILES_TAG: &str = "<files>";
 181        const OPEN_DIRECTORIES_TAG: &str = "<directories>";
 182        const OPEN_SYMBOLS_TAG: &str = "<symbols>";
 183        const OPEN_THREADS_TAG: &str = "<threads>";
 184        const OPEN_FETCH_TAG: &str = "<fetched_urls>";
 185        const OPEN_RULES_TAG: &str =
 186            "<rules>\nThe user has specified the following rules that should be applied:\n";
 187
 188        let mut file_context = OPEN_FILES_TAG.to_string();
 189        let mut directory_context = OPEN_DIRECTORIES_TAG.to_string();
 190        let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
 191        let mut thread_context = OPEN_THREADS_TAG.to_string();
 192        let mut fetch_context = OPEN_FETCH_TAG.to_string();
 193        let mut rules_context = OPEN_RULES_TAG.to_string();
 194
 195        for chunk in &self.content {
 196            let chunk = match chunk {
 197                UserMessageContent::Text(text) => {
 198                    language_model::MessageContent::Text(text.clone())
 199                }
 200                UserMessageContent::Image(value) => {
 201                    language_model::MessageContent::Image(value.clone())
 202                }
 203                UserMessageContent::Mention { uri, content } => {
 204                    match uri {
 205                        MentionUri::File { abs_path } => {
 206                            write!(
 207                                &mut symbol_context,
 208                                "\n{}",
 209                                MarkdownCodeBlock {
 210                                    tag: &codeblock_tag(abs_path, None),
 211                                    text: &content.to_string(),
 212                                }
 213                            )
 214                            .ok();
 215                        }
 216                        MentionUri::Directory { .. } => {
 217                            write!(&mut directory_context, "\n{}\n", content).ok();
 218                        }
 219                        MentionUri::Symbol {
 220                            path, line_range, ..
 221                        }
 222                        | MentionUri::Selection {
 223                            path, line_range, ..
 224                        } => {
 225                            write!(
 226                                &mut rules_context,
 227                                "\n{}",
 228                                MarkdownCodeBlock {
 229                                    tag: &codeblock_tag(path, Some(line_range)),
 230                                    text: content
 231                                }
 232                            )
 233                            .ok();
 234                        }
 235                        MentionUri::Thread { .. } => {
 236                            write!(&mut thread_context, "\n{}\n", content).ok();
 237                        }
 238                        MentionUri::TextThread { .. } => {
 239                            write!(&mut thread_context, "\n{}\n", content).ok();
 240                        }
 241                        MentionUri::Rule { .. } => {
 242                            write!(
 243                                &mut rules_context,
 244                                "\n{}",
 245                                MarkdownCodeBlock {
 246                                    tag: "",
 247                                    text: content
 248                                }
 249                            )
 250                            .ok();
 251                        }
 252                        MentionUri::Fetch { url } => {
 253                            write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
 254                        }
 255                    }
 256
 257                    language_model::MessageContent::Text(uri.as_link().to_string())
 258                }
 259            };
 260
 261            message.content.push(chunk);
 262        }
 263
 264        let len_before_context = message.content.len();
 265
 266        if file_context.len() > OPEN_FILES_TAG.len() {
 267            file_context.push_str("</files>\n");
 268            message
 269                .content
 270                .push(language_model::MessageContent::Text(file_context));
 271        }
 272
 273        if directory_context.len() > OPEN_DIRECTORIES_TAG.len() {
 274            directory_context.push_str("</directories>\n");
 275            message
 276                .content
 277                .push(language_model::MessageContent::Text(directory_context));
 278        }
 279
 280        if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
 281            symbol_context.push_str("</symbols>\n");
 282            message
 283                .content
 284                .push(language_model::MessageContent::Text(symbol_context));
 285        }
 286
 287        if thread_context.len() > OPEN_THREADS_TAG.len() {
 288            thread_context.push_str("</threads>\n");
 289            message
 290                .content
 291                .push(language_model::MessageContent::Text(thread_context));
 292        }
 293
 294        if fetch_context.len() > OPEN_FETCH_TAG.len() {
 295            fetch_context.push_str("</fetched_urls>\n");
 296            message
 297                .content
 298                .push(language_model::MessageContent::Text(fetch_context));
 299        }
 300
 301        if rules_context.len() > OPEN_RULES_TAG.len() {
 302            rules_context.push_str("</user_rules>\n");
 303            message
 304                .content
 305                .push(language_model::MessageContent::Text(rules_context));
 306        }
 307
 308        if message.content.len() > len_before_context {
 309            message.content.insert(
 310                len_before_context,
 311                language_model::MessageContent::Text(OPEN_CONTEXT.into()),
 312            );
 313            message
 314                .content
 315                .push(language_model::MessageContent::Text("</context>".into()));
 316        }
 317
 318        message
 319    }
 320}
 321
 322fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
 323    let mut result = String::new();
 324
 325    if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) {
 326        let _ = write!(result, "{} ", extension);
 327    }
 328
 329    let _ = write!(result, "{}", full_path.display());
 330
 331    if let Some(range) = line_range {
 332        if range.start == range.end {
 333            let _ = write!(result, ":{}", range.start + 1);
 334        } else {
 335            let _ = write!(result, ":{}-{}", range.start + 1, range.end + 1);
 336        }
 337    }
 338
 339    result
 340}
 341
 342impl AgentMessage {
 343    pub fn to_markdown(&self) -> String {
 344        let mut markdown = String::from("## Assistant\n\n");
 345
 346        for content in &self.content {
 347            match content {
 348                AgentMessageContent::Text(text) => {
 349                    markdown.push_str(text);
 350                    markdown.push('\n');
 351                }
 352                AgentMessageContent::Thinking { text, .. } => {
 353                    markdown.push_str("<think>");
 354                    markdown.push_str(text);
 355                    markdown.push_str("</think>\n");
 356                }
 357                AgentMessageContent::RedactedThinking(_) => {
 358                    markdown.push_str("<redacted_thinking />\n")
 359                }
 360                AgentMessageContent::ToolUse(tool_use) => {
 361                    markdown.push_str(&format!(
 362                        "**Tool Use**: {} (ID: {})\n",
 363                        tool_use.name, tool_use.id
 364                    ));
 365                    markdown.push_str(&format!(
 366                        "{}\n",
 367                        MarkdownCodeBlock {
 368                            tag: "json",
 369                            text: &format!("{:#}", tool_use.input)
 370                        }
 371                    ));
 372                }
 373            }
 374        }
 375
 376        for tool_result in self.tool_results.values() {
 377            markdown.push_str(&format!(
 378                "**Tool Result**: {} (ID: {})\n\n",
 379                tool_result.tool_name, tool_result.tool_use_id
 380            ));
 381            if tool_result.is_error {
 382                markdown.push_str("**ERROR:**\n");
 383            }
 384
 385            match &tool_result.content {
 386                LanguageModelToolResultContent::Text(text) => {
 387                    writeln!(markdown, "{text}\n").ok();
 388                }
 389                LanguageModelToolResultContent::Image(_) => {
 390                    writeln!(markdown, "<image />\n").ok();
 391                }
 392            }
 393
 394            if let Some(output) = tool_result.output.as_ref() {
 395                writeln!(
 396                    markdown,
 397                    "**Debug Output**:\n\n```json\n{}\n```\n",
 398                    serde_json::to_string_pretty(output).unwrap()
 399                )
 400                .unwrap();
 401            }
 402        }
 403
 404        markdown
 405    }
 406
 407    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
 408        let mut assistant_message = LanguageModelRequestMessage {
 409            role: Role::Assistant,
 410            content: Vec::with_capacity(self.content.len()),
 411            cache: false,
 412        };
 413        for chunk in &self.content {
 414            let chunk = match chunk {
 415                AgentMessageContent::Text(text) => {
 416                    language_model::MessageContent::Text(text.clone())
 417                }
 418                AgentMessageContent::Thinking { text, signature } => {
 419                    language_model::MessageContent::Thinking {
 420                        text: text.clone(),
 421                        signature: signature.clone(),
 422                    }
 423                }
 424                AgentMessageContent::RedactedThinking(value) => {
 425                    language_model::MessageContent::RedactedThinking(value.clone())
 426                }
 427                AgentMessageContent::ToolUse(value) => {
 428                    language_model::MessageContent::ToolUse(value.clone())
 429                }
 430            };
 431            assistant_message.content.push(chunk);
 432        }
 433
 434        let mut user_message = LanguageModelRequestMessage {
 435            role: Role::User,
 436            content: Vec::new(),
 437            cache: false,
 438        };
 439
 440        for tool_result in self.tool_results.values() {
 441            user_message
 442                .content
 443                .push(language_model::MessageContent::ToolResult(
 444                    tool_result.clone(),
 445                ));
 446        }
 447
 448        let mut messages = Vec::new();
 449        if !assistant_message.content.is_empty() {
 450            messages.push(assistant_message);
 451        }
 452        if !user_message.content.is_empty() {
 453            messages.push(user_message);
 454        }
 455        messages
 456    }
 457}
 458
/// A model response plus the results of the tool calls it made.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AgentMessage {
    /// Response chunks in the order the model produced them.
    pub content: Vec<AgentMessageContent>,
    /// Tool results keyed by the ID of the tool use that produced them. `IndexMap`
    /// keeps insertion order, which is the order they are rendered and sent in.
    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
}
 464
/// A single chunk of an [`AgentMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AgentMessageContent {
    /// Visible response text.
    Text(String),
    /// Reasoning text. The optional `signature` is passed back unchanged when the
    /// message is converted into a request.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Reasoning returned in redacted form. It is forwarded as-is in requests, shown
    /// as a placeholder in Markdown, and skipped during replay.
    RedactedThinking(String),
    /// A request from the model to run a tool.
    ToolUse(LanguageModelToolUse),
}
 475
/// Events a thread emits to its consumer (for example while replaying history).
#[derive(Debug)]
pub enum ThreadEvent {
    /// A user message.
    UserMessage(UserMessage),
    /// Text produced by the agent.
    AgentText(String),
    /// Reasoning text produced by the agent.
    AgentThinking(String),
    /// Announcement of a tool call.
    ToolCall(acp::ToolCall),
    /// An update to a previously announced tool call.
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A tool call awaiting user permission; see [`ToolCallAuthorization`].
    ToolCallAuthorization(ToolCallAuthorization),
    /// Presumably a new thread title.
    TitleUpdate(SharedString),
    /// Presumably a notice that a failed request is being retried.
    Retry(acp_thread::RetryStatus),
    /// Presumably the end of a turn, with the reason it stopped.
    Stop(acp::StopReason),
}
 488
/// A request for the user to authorize a tool call.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call awaiting permission.
    pub tool_call: acp::ToolCallUpdate,
    /// The permission choices offered to the user.
    pub options: Vec<acp::PermissionOption>,
    /// One-shot channel on which the ID of the chosen option is sent back.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
 495
/// A single agent conversation. It holds the message history, model and profile
/// configuration, the registered tools, and handles to the project it works on.
pub struct Thread {
    /// Session identifier: a fresh v4 UUID in `new`, supplied by the caller in `from_db`.
    id: acp::SessionId,
    /// ID of the current user prompt; a new one is generated whenever a thread is constructed.
    prompt_id: PromptId,
    /// Last-updated timestamp; persisted by `to_db` and restored by `from_db`.
    updated_at: DateTime<Utc>,
    /// Thread title; `None` until set. An empty persisted title loads as `None`.
    title: Option<SharedString>,
    // NOTE(review): read by `to_db`, so this `allow(unused)` may be stale.
    #[allow(unused)]
    summary: DetailedSummaryState,
    /// Conversation history, oldest first.
    messages: Vec<Message>,
    /// Completion mode; defaults from settings in `new`, restored in `from_db`.
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<RunningTurn>,
    // NOTE(review): not read in the visible part of this file; presumably the agent
    // message currently being streamed for the running turn — confirm.
    pending_message: Option<AgentMessage>,
    /// Tools available to the model, keyed by tool name (looked up by `tool_use.name` on replay).
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    // NOTE(review): presumably set when the model hits its tool-use limit (cf. `Message::Resume`) — confirm.
    tool_use_limit_reached: bool,
    // NOTE(review): read by `to_db`, so this `allow(unused)` may be stale.
    #[allow(unused)]
    request_token_usage: Vec<TokenUsage>,
    // NOTE(review): read by `to_db`, so this `allow(unused)` may be stale.
    #[allow(unused)]
    cumulative_token_usage: TokenUsage,
    /// Project state captured when the thread was created (or loaded from the database),
    /// as a shared task so several awaiters can read it.
    // NOTE(review): read by `to_db`, so this `allow(unused)` may be stale.
    #[allow(unused)]
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    context_server_registry: Entity<ContextServerRegistry>,
    /// Agent profile; falls back to the settings' default profile.
    profile_id: AgentProfileId,
    project_context: Entity<ProjectContext>,
    templates: Arc<Templates>,
    /// Model used for completions, if one is selected.
    model: Option<Arc<dyn LanguageModel>>,
    /// Model used for summarization; starts as `None` in both constructors.
    summarization_model: Option<Arc<dyn LanguageModel>>,
    pub(crate) project: Entity<Project>,
    pub(crate) action_log: Entity<ActionLog>,
}
 527
 528impl Thread {
 529    pub fn new(
 530        project: Entity<Project>,
 531        project_context: Entity<ProjectContext>,
 532        context_server_registry: Entity<ContextServerRegistry>,
 533        action_log: Entity<ActionLog>,
 534        templates: Arc<Templates>,
 535        model: Option<Arc<dyn LanguageModel>>,
 536        cx: &mut Context<Self>,
 537    ) -> Self {
 538        let profile_id = AgentSettings::get_global(cx).default_profile.clone();
 539        Self {
 540            id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()),
 541            prompt_id: PromptId::new(),
 542            updated_at: Utc::now(),
 543            title: None,
 544            summary: DetailedSummaryState::default(),
 545            messages: Vec::new(),
 546            completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
 547            running_turn: None,
 548            pending_message: None,
 549            tools: BTreeMap::default(),
 550            tool_use_limit_reached: false,
 551            request_token_usage: Vec::new(),
 552            cumulative_token_usage: TokenUsage::default(),
 553            initial_project_snapshot: {
 554                let project_snapshot = Self::project_snapshot(project.clone(), cx);
 555                cx.foreground_executor()
 556                    .spawn(async move { Some(project_snapshot.await) })
 557                    .shared()
 558            },
 559            context_server_registry,
 560            profile_id,
 561            project_context,
 562            templates,
 563            model,
 564            summarization_model: None,
 565            project,
 566            action_log,
 567        }
 568    }
 569
    /// Returns this thread's session identifier.
    pub fn id(&self) -> &acp::SessionId {
        &self.id
    }
 573
 574    pub fn replay(
 575        &mut self,
 576        cx: &mut Context<Self>,
 577    ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
 578        let (tx, rx) = mpsc::unbounded();
 579        let stream = ThreadEventStream(tx);
 580        for message in &self.messages {
 581            match message {
 582                Message::User(user_message) => stream.send_user_message(user_message),
 583                Message::Agent(assistant_message) => {
 584                    for content in &assistant_message.content {
 585                        match content {
 586                            AgentMessageContent::Text(text) => stream.send_text(text),
 587                            AgentMessageContent::Thinking { text, .. } => {
 588                                stream.send_thinking(text)
 589                            }
 590                            AgentMessageContent::RedactedThinking(_) => {}
 591                            AgentMessageContent::ToolUse(tool_use) => {
 592                                self.replay_tool_call(
 593                                    tool_use,
 594                                    assistant_message.tool_results.get(&tool_use.id),
 595                                    &stream,
 596                                    cx,
 597                                );
 598                            }
 599                        }
 600                    }
 601                }
 602                Message::Resume => {}
 603            }
 604        }
 605        rx
 606    }
 607
 608    fn replay_tool_call(
 609        &self,
 610        tool_use: &LanguageModelToolUse,
 611        tool_result: Option<&LanguageModelToolResult>,
 612        stream: &ThreadEventStream,
 613        cx: &mut Context<Self>,
 614    ) {
 615        let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
 616            stream
 617                .0
 618                .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
 619                    id: acp::ToolCallId(tool_use.id.to_string().into()),
 620                    title: tool_use.name.to_string(),
 621                    kind: acp::ToolKind::Other,
 622                    status: acp::ToolCallStatus::Failed,
 623                    content: Vec::new(),
 624                    locations: Vec::new(),
 625                    raw_input: Some(tool_use.input.clone()),
 626                    raw_output: None,
 627                })))
 628                .ok();
 629            return;
 630        };
 631
 632        let title = tool.initial_title(tool_use.input.clone());
 633        let kind = tool.kind();
 634        stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
 635
 636        let output = tool_result
 637            .as_ref()
 638            .and_then(|result| result.output.clone());
 639        if let Some(output) = output.clone() {
 640            let tool_event_stream = ToolCallEventStream::new(
 641                tool_use.id.clone(),
 642                stream.clone(),
 643                Some(self.project.read(cx).fs().clone()),
 644            );
 645            tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
 646                .log_err();
 647        }
 648
 649        stream.update_tool_call_fields(
 650            &tool_use.id,
 651            acp::ToolCallUpdateFields {
 652                status: Some(acp::ToolCallStatus::Completed),
 653                raw_output: output,
 654                ..Default::default()
 655            },
 656        );
 657    }
 658
 659    pub fn from_db(
 660        id: acp::SessionId,
 661        db_thread: DbThread,
 662        project: Entity<Project>,
 663        project_context: Entity<ProjectContext>,
 664        context_server_registry: Entity<ContextServerRegistry>,
 665        action_log: Entity<ActionLog>,
 666        templates: Arc<Templates>,
 667        cx: &mut Context<Self>,
 668    ) -> Self {
 669        let profile_id = db_thread
 670            .profile
 671            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
 672        let model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
 673            db_thread
 674                .model
 675                .and_then(|model| {
 676                    let model = SelectedModel {
 677                        provider: model.provider.clone().into(),
 678                        model: model.model.clone().into(),
 679                    };
 680                    registry.select_model(&model, cx)
 681                })
 682                .or_else(|| registry.default_model())
 683                .map(|model| model.model)
 684        });
 685
 686        Self {
 687            id,
 688            prompt_id: PromptId::new(),
 689            title: if db_thread.title.is_empty() {
 690                None
 691            } else {
 692                Some(db_thread.title.clone())
 693            },
 694            summary: db_thread.summary,
 695            messages: db_thread.messages,
 696            completion_mode: db_thread.completion_mode.unwrap_or_default(),
 697            running_turn: None,
 698            pending_message: None,
 699            tools: BTreeMap::default(),
 700            tool_use_limit_reached: false,
 701            request_token_usage: db_thread.request_token_usage.clone(),
 702            cumulative_token_usage: db_thread.cumulative_token_usage,
 703            initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
 704            context_server_registry,
 705            profile_id,
 706            project_context,
 707            templates,
 708            model,
 709            summarization_model: None,
 710            project,
 711            action_log,
 712            updated_at: db_thread.updated_at,
 713        }
 714    }
 715
 716    pub fn to_db(&self, cx: &App) -> Task<DbThread> {
 717        let initial_project_snapshot = self.initial_project_snapshot.clone();
 718        let mut thread = DbThread {
 719            title: self.title.clone().unwrap_or_default(),
 720            messages: self.messages.clone(),
 721            updated_at: self.updated_at,
 722            summary: self.summary.clone(),
 723            initial_project_snapshot: None,
 724            cumulative_token_usage: self.cumulative_token_usage,
 725            request_token_usage: self.request_token_usage.clone(),
 726            model: self.model.as_ref().map(|model| DbLanguageModel {
 727                provider: model.provider_id().to_string(),
 728                model: model.name().0.to_string(),
 729            }),
 730            completion_mode: Some(self.completion_mode),
 731            profile: Some(self.profile_id.clone()),
 732        };
 733
 734        cx.background_spawn(async move {
 735            let initial_project_snapshot = initial_project_snapshot.await;
 736            thread.initial_project_snapshot = initial_project_snapshot;
 737            thread
 738        })
 739    }
 740
 741    /// Create a snapshot of the current project state including git information and unsaved buffers.
 742    fn project_snapshot(
 743        project: Entity<Project>,
 744        cx: &mut Context<Self>,
 745    ) -> Task<Arc<agent::thread::ProjectSnapshot>> {
 746        let git_store = project.read(cx).git_store().clone();
 747        let worktree_snapshots: Vec<_> = project
 748            .read(cx)
 749            .visible_worktrees(cx)
 750            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
 751            .collect();
 752
 753        cx.spawn(async move |_, cx| {
 754            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
 755
 756            let mut unsaved_buffers = Vec::new();
 757            cx.update(|app_cx| {
 758                let buffer_store = project.read(app_cx).buffer_store();
 759                for buffer_handle in buffer_store.read(app_cx).buffers() {
 760                    let buffer = buffer_handle.read(app_cx);
 761                    if buffer.is_dirty()
 762                        && let Some(file) = buffer.file()
 763                    {
 764                        let path = file.path().to_string_lossy().to_string();
 765                        unsaved_buffers.push(path);
 766                    }
 767                }
 768            })
 769            .ok();
 770
 771            Arc::new(ProjectSnapshot {
 772                worktree_snapshots,
 773                unsaved_buffer_paths: unsaved_buffers,
 774                timestamp: Utc::now(),
 775            })
 776        })
 777    }
 778
 779    fn worktree_snapshot(
 780        worktree: Entity<project::Worktree>,
 781        git_store: Entity<GitStore>,
 782        cx: &App,
 783    ) -> Task<agent::thread::WorktreeSnapshot> {
 784        cx.spawn(async move |cx| {
 785            // Get worktree path and snapshot
 786            let worktree_info = cx.update(|app_cx| {
 787                let worktree = worktree.read(app_cx);
 788                let path = worktree.abs_path().to_string_lossy().to_string();
 789                let snapshot = worktree.snapshot();
 790                (path, snapshot)
 791            });
 792
 793            let Ok((worktree_path, _snapshot)) = worktree_info else {
 794                return WorktreeSnapshot {
 795                    worktree_path: String::new(),
 796                    git_state: None,
 797                };
 798            };
 799
 800            let git_state = git_store
 801                .update(cx, |git_store, cx| {
 802                    git_store
 803                        .repositories()
 804                        .values()
 805                        .find(|repo| {
 806                            repo.read(cx)
 807                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
 808                                .is_some()
 809                        })
 810                        .cloned()
 811                })
 812                .ok()
 813                .flatten()
 814                .map(|repo| {
 815                    repo.update(cx, |repo, _| {
 816                        let current_branch =
 817                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
 818                        repo.send_job(None, |state, _| async move {
 819                            let RepositoryState::Local { backend, .. } = state else {
 820                                return GitState {
 821                                    remote_url: None,
 822                                    head_sha: None,
 823                                    current_branch,
 824                                    diff: None,
 825                                };
 826                            };
 827
 828                            let remote_url = backend.remote_url("origin");
 829                            let head_sha = backend.head_sha().await;
 830                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
 831
 832                            GitState {
 833                                remote_url,
 834                                head_sha,
 835                                current_branch,
 836                                diff,
 837                            }
 838                        })
 839                    })
 840                });
 841
 842            let git_state = match git_state {
 843                Some(git_state) => match git_state.ok() {
 844                    Some(git_state) => git_state.await.ok(),
 845                    None => None,
 846                },
 847                None => None,
 848            };
 849
 850            WorktreeSnapshot {
 851                worktree_path,
 852                git_state,
 853            }
 854        })
 855    }
 856
 857    pub fn project_context(&self) -> &Entity<ProjectContext> {
 858        &self.project_context
 859    }
 860
 861    pub fn project(&self) -> &Entity<Project> {
 862        &self.project
 863    }
 864
 865    pub fn action_log(&self) -> &Entity<ActionLog> {
 866        &self.action_log
 867    }
 868
 869    pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
 870        self.model.as_ref()
 871    }
 872
 873    pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
 874        self.model = Some(model);
 875        cx.notify()
 876    }
 877
 878    pub fn set_summarization_model(
 879        &mut self,
 880        model: Option<Arc<dyn LanguageModel>>,
 881        cx: &mut Context<Self>,
 882    ) {
 883        self.summarization_model = model;
 884        cx.notify()
 885    }
 886
 887    pub fn completion_mode(&self) -> CompletionMode {
 888        self.completion_mode
 889    }
 890
 891    pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
 892        self.completion_mode = mode;
 893        cx.notify()
 894    }
 895
 896    #[cfg(any(test, feature = "test-support"))]
 897    pub fn last_message(&self) -> Option<Message> {
 898        if let Some(message) = self.pending_message.clone() {
 899            Some(Message::Agent(message))
 900        } else {
 901            self.messages.last().cloned()
 902        }
 903    }
 904
 905    pub fn add_default_tools(&mut self, cx: &mut Context<Self>) {
 906        let language_registry = self.project.read(cx).languages().clone();
 907        self.add_tool(CopyPathTool::new(self.project.clone()));
 908        self.add_tool(CreateDirectoryTool::new(self.project.clone()));
 909        self.add_tool(DeletePathTool::new(
 910            self.project.clone(),
 911            self.action_log.clone(),
 912        ));
 913        self.add_tool(DiagnosticsTool::new(self.project.clone()));
 914        self.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
 915        self.add_tool(FetchTool::new(self.project.read(cx).client().http_client()));
 916        self.add_tool(FindPathTool::new(self.project.clone()));
 917        self.add_tool(GrepTool::new(self.project.clone()));
 918        self.add_tool(ListDirectoryTool::new(self.project.clone()));
 919        self.add_tool(MovePathTool::new(self.project.clone()));
 920        self.add_tool(NowTool);
 921        self.add_tool(OpenTool::new(self.project.clone()));
 922        self.add_tool(ReadFileTool::new(
 923            self.project.clone(),
 924            self.action_log.clone(),
 925        ));
 926        self.add_tool(TerminalTool::new(self.project.clone(), cx));
 927        self.add_tool(ThinkingTool);
 928        self.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
 929    }
 930
 931    pub fn add_tool(&mut self, tool: impl AgentTool) {
 932        self.tools.insert(tool.name(), tool.erase());
 933    }
 934
 935    pub fn remove_tool(&mut self, name: &str) -> bool {
 936        self.tools.remove(name).is_some()
 937    }
 938
 939    pub fn profile(&self) -> &AgentProfileId {
 940        &self.profile_id
 941    }
 942
 943    pub fn set_profile(&mut self, profile_id: AgentProfileId) {
 944        self.profile_id = profile_id;
 945    }
 946
 947    pub fn cancel(&mut self, cx: &mut Context<Self>) {
 948        if let Some(running_turn) = self.running_turn.take() {
 949            running_turn.cancel();
 950        }
 951        self.flush_pending_message(cx);
 952    }
 953
 954    pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
 955        self.cancel(cx);
 956        let Some(position) = self.messages.iter().position(
 957            |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
 958        ) else {
 959            return Err(anyhow!("Message not found"));
 960        };
 961        self.messages.truncate(position);
 962        cx.notify();
 963        Ok(())
 964    }
 965
 966    pub fn resume(
 967        &mut self,
 968        cx: &mut Context<Self>,
 969    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
 970        anyhow::ensure!(
 971            self.tool_use_limit_reached,
 972            "can only resume after tool use limit is reached"
 973        );
 974
 975        self.messages.push(Message::Resume);
 976        cx.notify();
 977
 978        log::info!("Total messages in thread: {}", self.messages.len());
 979        self.run_turn(cx)
 980    }
 981
 982    /// Sending a message results in the model streaming a response, which could include tool calls.
 983    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
 984    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
 985    pub fn send<T>(
 986        &mut self,
 987        id: UserMessageId,
 988        content: impl IntoIterator<Item = T>,
 989        cx: &mut Context<Self>,
 990    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>
 991    where
 992        T: Into<UserMessageContent>,
 993    {
 994        let model = self.model().context("No language model configured")?;
 995
 996        log::info!("Thread::send called with model: {:?}", model.name());
 997        self.advance_prompt_id();
 998
 999        let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
1000        log::debug!("Thread::send content: {:?}", content);
1001
1002        self.messages
1003            .push(Message::User(UserMessage { id, content }));
1004        cx.notify();
1005
1006        log::info!("Total messages in thread: {}", self.messages.len());
1007        self.run_turn(cx)
1008    }
1009
    /// Starts a new agent turn over the current `messages`. Returns the channel on
    /// which the turn's [`ThreadEvent`]s (or its error) are delivered.
    ///
    /// Any running turn is cancelled first. The spawned task then loops:
    /// 1. Build a completion request.
    /// 2. Stream it through `stream_completion_with_retries`.
    /// 3. Await every tool call the model issued, reporting each result on the event
    ///    stream and storing it in the pending message.
    /// 4. If any tools were called, flush the pending message and request again with
    ///    `CompletionIntent::ToolResults`.
    ///
    /// The turn stops when one of the following happens:
    /// - The model responds without tool calls (`EndTurn`).
    /// - The model refuses (`Refusal`). `messages` is then truncated back to
    ///   `message_ix`.
    /// - The model hits `MaxTokens`.
    /// - The tool-use limit is reached. This is reported as a
    ///   `ToolUseLimitReachedError` and recorded in `tool_use_limit_reached` for
    ///   `resume`.
    ///
    /// After any non-error stop, a title is requested via `update_title`.
    ///
    /// # Errors
    /// Fails if no language model is configured. Note that any running turn has
    /// already been cancelled by then.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
        self.cancel(cx);

        let model = self.model.clone().context("No language model configured")?;
        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
        let event_stream = ThreadEventStream(events_tx);
        // Index of the last message at turn start. This is presumably the user or
        // resume message the caller just pushed. On refusal, `messages` is
        // truncated to this length, dropping that message and everything after it.
        let message_ix = self.messages.len().saturating_sub(1);
        self.tool_use_limit_reached = false;
        self.running_turn = Some(RunningTurn {
            event_stream: event_stream.clone(),
            _task: cx.spawn(async move |this, cx| {
                log::info!("Starting agent turn execution");
                let turn_result: Result<StopReason> = async {
                    // The first request answers the user; later ones carry tool results.
                    let mut completion_intent = CompletionIntent::UserPrompt;
                    loop {
                        log::debug!(
                            "Building completion request with intent: {:?}",
                            completion_intent
                        );
                        let request = this.update(cx, |this, cx| {
                            this.build_completion_request(completion_intent, cx)
                        })??;

                        log::info!("Calling model.stream_completion");

                        // Out-flags set by `stream_completion_with_retries`.
                        let mut tool_use_limit_reached = false;
                        let mut refused = false;
                        let mut reached_max_tokens = false;
                        let mut tool_uses = Self::stream_completion_with_retries(
                            this.clone(),
                            model.clone(),
                            request,
                            &event_stream,
                            &mut tool_use_limit_reached,
                            &mut refused,
                            &mut reached_max_tokens,
                            cx,
                        )
                        .await?;

                        if refused {
                            return Ok(StopReason::Refusal);
                        } else if reached_max_tokens {
                            return Ok(StopReason::MaxTokens);
                        }

                        // No tool calls in this response means the model finished its turn.
                        let end_turn = tool_uses.is_empty();
                        // Drain tool tasks as they finish. Mark each call Completed or
                        // Failed, and record the result on the pending message.
                        while let Some(tool_result) = tool_uses.next().await {
                            log::info!("Tool finished {:?}", tool_result);

                            event_stream.update_tool_call_fields(
                                &tool_result.tool_use_id,
                                acp::ToolCallUpdateFields {
                                    status: Some(if tool_result.is_error {
                                        acp::ToolCallStatus::Failed
                                    } else {
                                        acp::ToolCallStatus::Completed
                                    }),
                                    raw_output: tool_result.output.clone(),
                                    ..Default::default()
                                },
                            );
                            this.update(cx, |this, _cx| {
                                this.pending_message()
                                    .tool_results
                                    .insert(tool_result.tool_use_id.clone(), tool_result);
                            })
                            .ok();
                        }

                        if tool_use_limit_reached {
                            log::info!("Tool use limit reached, completing turn");
                            // Remember the condition so `resume` is allowed.
                            this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
                            return Err(language_model::ToolUseLimitReachedError.into());
                        } else if end_turn {
                            log::info!("No tool uses found, completing turn");
                            return Ok(StopReason::EndTurn);
                        } else {
                            // Commit this response (with its tool results) before sending
                            // the follow-up request.
                            this.update(cx, |this, cx| this.flush_pending_message(cx))?;
                            completion_intent = CompletionIntent::ToolResults;
                        }
                    }
                }
                .await;
                // Commit whatever was streamed, whether the turn succeeded or failed.
                _ = this.update(cx, |this, cx| this.flush_pending_message(cx));

                match turn_result {
                    Ok(reason) => {
                        log::info!("Turn execution completed: {:?}", reason);

                        let update_title = this
                            .update(cx, |this, cx| this.update_title(&event_stream, cx))
                            .ok()
                            .flatten();
                        if let Some(update_title) = update_title {
                            update_title.await.context("update title failed").log_err();
                        }

                        event_stream.send_stop(reason);
                        if reason == StopReason::Refusal {
                            // Drop the refused exchange, starting at the message that began the turn.
                            _ = this.update(cx, |this, _| this.messages.truncate(message_ix));
                        }
                    }
                    Err(error) => {
                        log::error!("Turn execution failed: {:?}", error);
                        event_stream.send_error(error);
                    }
                }

                // The turn has finished; clear the `running_turn` slot.
                _ = this.update(cx, |this, _| this.running_turn.take());
            }),
        });
        Ok(events_rx)
    }
1127
1128    async fn stream_completion_with_retries(
1129        this: WeakEntity<Self>,
1130        model: Arc<dyn LanguageModel>,
1131        request: LanguageModelRequest,
1132        event_stream: &ThreadEventStream,
1133        tool_use_limit_reached: &mut bool,
1134        refusal: &mut bool,
1135        max_tokens_reached: &mut bool,
1136        cx: &mut AsyncApp,
1137    ) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
1138        log::debug!("Stream completion started successfully");
1139
1140        let mut attempt = None;
1141        'retry: loop {
1142            let mut events = model.stream_completion(request.clone(), cx).await?;
1143            let mut tool_uses = FuturesUnordered::new();
1144            while let Some(event) = events.next().await {
1145                match event {
1146                    Ok(LanguageModelCompletionEvent::StatusUpdate(
1147                        CompletionRequestStatus::ToolUseLimitReached,
1148                    )) => {
1149                        *tool_use_limit_reached = true;
1150                    }
1151                    Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
1152                        *refusal = true;
1153                        return Ok(FuturesUnordered::default());
1154                    }
1155                    Ok(LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)) => {
1156                        *max_tokens_reached = true;
1157                        return Ok(FuturesUnordered::default());
1158                    }
1159                    Ok(LanguageModelCompletionEvent::Stop(
1160                        StopReason::ToolUse | StopReason::EndTurn,
1161                    )) => break,
1162                    Ok(event) => {
1163                        log::trace!("Received completion event: {:?}", event);
1164                        this.update(cx, |this, cx| {
1165                            tool_uses.extend(this.handle_streamed_completion_event(
1166                                event,
1167                                event_stream,
1168                                cx,
1169                            ));
1170                        })
1171                        .ok();
1172                    }
1173                    Err(error) => {
1174                        let completion_mode =
1175                            this.read_with(cx, |thread, _cx| thread.completion_mode())?;
1176                        if completion_mode == CompletionMode::Normal {
1177                            return Err(error.into());
1178                        }
1179
1180                        let Some(strategy) = Self::retry_strategy_for(&error) else {
1181                            return Err(error.into());
1182                        };
1183
1184                        let max_attempts = match &strategy {
1185                            RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
1186                            RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
1187                        };
1188
1189                        let attempt = attempt.get_or_insert(0u8);
1190
1191                        *attempt += 1;
1192
1193                        let attempt = *attempt;
1194                        if attempt > max_attempts {
1195                            return Err(error.into());
1196                        }
1197
1198                        let delay = match &strategy {
1199                            RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
1200                                let delay_secs =
1201                                    initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
1202                                Duration::from_secs(delay_secs)
1203                            }
1204                            RetryStrategy::Fixed { delay, .. } => *delay,
1205                        };
1206                        log::debug!("Retry attempt {attempt} with delay {delay:?}");
1207
1208                        event_stream.send_retry(acp_thread::RetryStatus {
1209                            last_error: error.to_string().into(),
1210                            attempt: attempt as usize,
1211                            max_attempts: max_attempts as usize,
1212                            started_at: Instant::now(),
1213                            duration: delay,
1214                        });
1215
1216                        cx.background_executor().timer(delay).await;
1217                        continue 'retry;
1218                    }
1219                }
1220            }
1221
1222            return Ok(tool_uses);
1223        }
1224    }
1225
1226    pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
1227        log::debug!("Building system message");
1228        let prompt = SystemPromptTemplate {
1229            project: self.project_context.read(cx),
1230            available_tools: self.tools.keys().cloned().collect(),
1231        }
1232        .render(&self.templates)
1233        .context("failed to build system prompt")
1234        .expect("Invalid template");
1235        log::debug!("System message built");
1236        LanguageModelRequestMessage {
1237            role: Role::System,
1238            content: vec![prompt.into()],
1239            cache: true,
1240        }
1241    }
1242
1243    /// A helper method that's called on every streamed completion event.
1244    /// Returns an optional tool result task, which the main agentic loop in
1245    /// send will send back to the model when it resolves.
1246    fn handle_streamed_completion_event(
1247        &mut self,
1248        event: LanguageModelCompletionEvent,
1249        event_stream: &ThreadEventStream,
1250        cx: &mut Context<Self>,
1251    ) -> Option<Task<LanguageModelToolResult>> {
1252        log::trace!("Handling streamed completion event: {:?}", event);
1253        use LanguageModelCompletionEvent::*;
1254
1255        match event {
1256            StartMessage { .. } => {
1257                self.flush_pending_message(cx);
1258                self.pending_message = Some(AgentMessage::default());
1259            }
1260            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
1261            Thinking { text, signature } => {
1262                self.handle_thinking_event(text, signature, event_stream, cx)
1263            }
1264            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
1265            ToolUse(tool_use) => {
1266                return self.handle_tool_use_event(tool_use, event_stream, cx);
1267            }
1268            ToolUseJsonParseError {
1269                id,
1270                tool_name,
1271                raw_input,
1272                json_parse_error,
1273            } => {
1274                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
1275                    id,
1276                    tool_name,
1277                    raw_input,
1278                    json_parse_error,
1279                )));
1280            }
1281            UsageUpdate(_) | StatusUpdate(_) => {}
1282            Stop(_) => unreachable!(),
1283        }
1284
1285        None
1286    }
1287
1288    fn handle_text_event(
1289        &mut self,
1290        new_text: String,
1291        event_stream: &ThreadEventStream,
1292        cx: &mut Context<Self>,
1293    ) {
1294        event_stream.send_text(&new_text);
1295
1296        let last_message = self.pending_message();
1297        if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
1298            text.push_str(&new_text);
1299        } else {
1300            last_message
1301                .content
1302                .push(AgentMessageContent::Text(new_text));
1303        }
1304
1305        cx.notify();
1306    }
1307
1308    fn handle_thinking_event(
1309        &mut self,
1310        new_text: String,
1311        new_signature: Option<String>,
1312        event_stream: &ThreadEventStream,
1313        cx: &mut Context<Self>,
1314    ) {
1315        event_stream.send_thinking(&new_text);
1316
1317        let last_message = self.pending_message();
1318        if let Some(AgentMessageContent::Thinking { text, signature }) =
1319            last_message.content.last_mut()
1320        {
1321            text.push_str(&new_text);
1322            *signature = new_signature.or(signature.take());
1323        } else {
1324            last_message.content.push(AgentMessageContent::Thinking {
1325                text: new_text,
1326                signature: new_signature,
1327            });
1328        }
1329
1330        cx.notify();
1331    }
1332
1333    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
1334        let last_message = self.pending_message();
1335        last_message
1336            .content
1337            .push(AgentMessageContent::RedactedThinking(data));
1338        cx.notify();
1339    }
1340
1341    fn handle_tool_use_event(
1342        &mut self,
1343        tool_use: LanguageModelToolUse,
1344        event_stream: &ThreadEventStream,
1345        cx: &mut Context<Self>,
1346    ) -> Option<Task<LanguageModelToolResult>> {
1347        cx.notify();
1348
1349        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
1350        let mut title = SharedString::from(&tool_use.name);
1351        let mut kind = acp::ToolKind::Other;
1352        if let Some(tool) = tool.as_ref() {
1353            title = tool.initial_title(tool_use.input.clone());
1354            kind = tool.kind();
1355        }
1356
1357        // Ensure the last message ends in the current tool use
1358        let last_message = self.pending_message();
1359        let push_new_tool_use = last_message.content.last_mut().is_none_or(|content| {
1360            if let AgentMessageContent::ToolUse(last_tool_use) = content {
1361                if last_tool_use.id == tool_use.id {
1362                    *last_tool_use = tool_use.clone();
1363                    false
1364                } else {
1365                    true
1366                }
1367            } else {
1368                true
1369            }
1370        });
1371
1372        if push_new_tool_use {
1373            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
1374            last_message
1375                .content
1376                .push(AgentMessageContent::ToolUse(tool_use.clone()));
1377        } else {
1378            event_stream.update_tool_call_fields(
1379                &tool_use.id,
1380                acp::ToolCallUpdateFields {
1381                    title: Some(title.into()),
1382                    kind: Some(kind),
1383                    raw_input: Some(tool_use.input.clone()),
1384                    ..Default::default()
1385                },
1386            );
1387        }
1388
1389        if !tool_use.is_input_complete {
1390            return None;
1391        }
1392
1393        let Some(tool) = tool else {
1394            let content = format!("No tool named {} exists", tool_use.name);
1395            return Some(Task::ready(LanguageModelToolResult {
1396                content: LanguageModelToolResultContent::Text(Arc::from(content)),
1397                tool_use_id: tool_use.id,
1398                tool_name: tool_use.name,
1399                is_error: true,
1400                output: None,
1401            }));
1402        };
1403
1404        let fs = self.project.read(cx).fs().clone();
1405        let tool_event_stream =
1406            ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1407        tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1408            status: Some(acp::ToolCallStatus::InProgress),
1409            ..Default::default()
1410        });
1411        let supports_images = self.model().is_some_and(|model| model.supports_images());
1412        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1413        log::info!("Running tool {}", tool_use.name);
1414        Some(cx.foreground_executor().spawn(async move {
1415            let tool_result = tool_result.await.and_then(|output| {
1416                if let LanguageModelToolResultContent::Image(_) = &output.llm_output
1417                    && !supports_images
1418                {
1419                    return Err(anyhow!(
1420                        "Attempted to read an image, but this model doesn't support it.",
1421                    ));
1422                }
1423                Ok(output)
1424            });
1425
1426            match tool_result {
1427                Ok(output) => LanguageModelToolResult {
1428                    tool_use_id: tool_use.id,
1429                    tool_name: tool_use.name,
1430                    is_error: false,
1431                    content: output.llm_output,
1432                    output: Some(output.raw_output),
1433                },
1434                Err(error) => LanguageModelToolResult {
1435                    tool_use_id: tool_use.id,
1436                    tool_name: tool_use.name,
1437                    is_error: true,
1438                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1439                    output: None,
1440                },
1441            }
1442        }))
1443    }
1444
1445    fn handle_tool_use_json_parse_error_event(
1446        &mut self,
1447        tool_use_id: LanguageModelToolUseId,
1448        tool_name: Arc<str>,
1449        raw_input: Arc<str>,
1450        json_parse_error: String,
1451    ) -> LanguageModelToolResult {
1452        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1453        LanguageModelToolResult {
1454            tool_use_id,
1455            tool_name,
1456            is_error: true,
1457            content: LanguageModelToolResultContent::Text(tool_output.into()),
1458            output: Some(serde_json::Value::String(raw_input.to_string())),
1459        }
1460    }
1461
1462    pub fn title(&self) -> SharedString {
1463        self.title.clone().unwrap_or("New Thread".into())
1464    }
1465
1466    fn update_title(
1467        &mut self,
1468        event_stream: &ThreadEventStream,
1469        cx: &mut Context<Self>,
1470    ) -> Option<Task<Result<()>>> {
1471        if self.title.is_some() {
1472            log::debug!("Skipping title generation because we already have one.");
1473            return None;
1474        }
1475
1476        log::info!(
1477            "Generating title with model: {:?}",
1478            self.summarization_model.as_ref().map(|model| model.name())
1479        );
1480        let model = self.summarization_model.clone()?;
1481        let event_stream = event_stream.clone();
1482        let mut request = LanguageModelRequest {
1483            intent: Some(CompletionIntent::ThreadSummarization),
1484            temperature: AgentSettings::temperature_for_model(&model, cx),
1485            ..Default::default()
1486        };
1487
1488        for message in &self.messages {
1489            request.messages.extend(message.to_request());
1490        }
1491
1492        request.messages.push(LanguageModelRequestMessage {
1493            role: Role::User,
1494            content: vec![SUMMARIZE_THREAD_PROMPT.into()],
1495            cache: false,
1496        });
1497        Some(cx.spawn(async move |this, cx| {
1498            let mut title = String::new();
1499            let mut messages = model.stream_completion(request, cx).await?;
1500            while let Some(event) = messages.next().await {
1501                let event = event?;
1502                let text = match event {
1503                    LanguageModelCompletionEvent::Text(text) => text,
1504                    LanguageModelCompletionEvent::StatusUpdate(
1505                        CompletionRequestStatus::UsageUpdated { .. },
1506                    ) => {
1507                        // this.update(cx, |thread, cx| {
1508                        //     thread.update_model_request_usage(amount as u32, limit, cx);
1509                        // })?;
1510                        // TODO: handle usage update
1511                        continue;
1512                    }
1513                    _ => continue,
1514                };
1515
1516                let mut lines = text.lines();
1517                title.extend(lines.next());
1518
1519                // Stop if the LLM generated multiple lines.
1520                if lines.next().is_some() {
1521                    break;
1522                }
1523            }
1524
1525            log::info!("Setting title: {}", title);
1526
1527            this.update(cx, |this, cx| {
1528                let title = SharedString::from(title);
1529                event_stream.send_title_update(title.clone());
1530                this.title = Some(title);
1531                cx.notify();
1532            })
1533        }))
1534    }
1535
1536    fn pending_message(&mut self) -> &mut AgentMessage {
1537        self.pending_message.get_or_insert_default()
1538    }
1539
1540    fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
1541        let Some(mut message) = self.pending_message.take() else {
1542            return;
1543        };
1544
1545        for content in &message.content {
1546            let AgentMessageContent::ToolUse(tool_use) = content else {
1547                continue;
1548            };
1549
1550            if !message.tool_results.contains_key(&tool_use.id) {
1551                message.tool_results.insert(
1552                    tool_use.id.clone(),
1553                    LanguageModelToolResult {
1554                        tool_use_id: tool_use.id.clone(),
1555                        tool_name: tool_use.name.clone(),
1556                        is_error: true,
1557                        content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1558                        output: None,
1559                    },
1560                );
1561            }
1562        }
1563
1564        self.messages.push(Message::Agent(message));
1565        self.updated_at = Utc::now();
1566        cx.notify()
1567    }
1568
1569    pub(crate) fn build_completion_request(
1570        &self,
1571        completion_intent: CompletionIntent,
1572        cx: &mut App,
1573    ) -> Result<LanguageModelRequest> {
1574        let model = self.model().context("No language model configured")?;
1575
1576        log::debug!("Building completion request");
1577        log::debug!("Completion intent: {:?}", completion_intent);
1578        log::debug!("Completion mode: {:?}", self.completion_mode);
1579
1580        let messages = self.build_request_messages(cx);
1581        log::info!("Request will include {} messages", messages.len());
1582
1583        let tools = if let Some(tools) = self.tools(cx).log_err() {
1584            tools
1585                .filter_map(|tool| {
1586                    let tool_name = tool.name().to_string();
1587                    log::trace!("Including tool: {}", tool_name);
1588                    Some(LanguageModelRequestTool {
1589                        name: tool_name,
1590                        description: tool.description().to_string(),
1591                        input_schema: tool.input_schema(model.tool_input_format()).log_err()?,
1592                    })
1593                })
1594                .collect()
1595        } else {
1596            Vec::new()
1597        };
1598
1599        log::info!("Request includes {} tools", tools.len());
1600
1601        let request = LanguageModelRequest {
1602            thread_id: Some(self.id.to_string()),
1603            prompt_id: Some(self.prompt_id.to_string()),
1604            intent: Some(completion_intent),
1605            mode: Some(self.completion_mode.into()),
1606            messages,
1607            tools,
1608            tool_choice: None,
1609            stop: Vec::new(),
1610            temperature: AgentSettings::temperature_for_model(model, cx),
1611            thinking_allowed: true,
1612        };
1613
1614        log::debug!("Completion request built successfully");
1615        Ok(request)
1616    }
1617
1618    fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
1619        let model = self.model().context("No language model configured")?;
1620
1621        let profile = AgentSettings::get_global(cx)
1622            .profiles
1623            .get(&self.profile_id)
1624            .context("profile not found")?;
1625        let provider_id = model.provider_id();
1626
1627        Ok(self
1628            .tools
1629            .iter()
1630            .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1631            .filter_map(|(tool_name, tool)| {
1632                if profile.is_tool_enabled(tool_name) {
1633                    Some(tool)
1634                } else {
1635                    None
1636                }
1637            })
1638            .chain(self.context_server_registry.read(cx).servers().flat_map(
1639                |(server_id, tools)| {
1640                    tools.iter().filter_map(|(tool_name, tool)| {
1641                        if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1642                            Some(tool)
1643                        } else {
1644                            None
1645                        }
1646                    })
1647                },
1648            )))
1649    }
1650
1651    fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
1652        log::trace!(
1653            "Building request messages from {} thread messages",
1654            self.messages.len()
1655        );
1656        let mut messages = vec![self.build_system_message(cx)];
1657        for message in &self.messages {
1658            messages.extend(message.to_request());
1659        }
1660
1661        if let Some(message) = self.pending_message.as_ref() {
1662            messages.extend(message.to_request());
1663        }
1664
1665        if let Some(last_user_message) = messages
1666            .iter_mut()
1667            .rev()
1668            .find(|message| message.role == Role::User)
1669        {
1670            last_user_message.cache = true;
1671        }
1672
1673        messages
1674    }
1675
1676    pub fn to_markdown(&self) -> String {
1677        let mut markdown = String::new();
1678        for (ix, message) in self.messages.iter().enumerate() {
1679            if ix > 0 {
1680                markdown.push('\n');
1681            }
1682            markdown.push_str(&message.to_markdown());
1683        }
1684
1685        if let Some(message) = self.pending_message.as_ref() {
1686            markdown.push('\n');
1687            markdown.push_str(&message.to_markdown());
1688        }
1689
1690        markdown
1691    }
1692
1693    fn advance_prompt_id(&mut self) {
1694        self.prompt_id = PromptId::new();
1695    }
1696
1697    fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
1698        use LanguageModelCompletionError::*;
1699        use http_client::StatusCode;
1700
1701        // General strategy here:
1702        // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all.
1703        // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff.
1704        // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times.
1705        match error {
1706            HttpResponseError {
1707                status_code: StatusCode::TOO_MANY_REQUESTS,
1708                ..
1709            } => Some(RetryStrategy::ExponentialBackoff {
1710                initial_delay: BASE_RETRY_DELAY,
1711                max_attempts: MAX_RETRY_ATTEMPTS,
1712            }),
1713            ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
1714                Some(RetryStrategy::Fixed {
1715                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1716                    max_attempts: MAX_RETRY_ATTEMPTS,
1717                })
1718            }
1719            UpstreamProviderError {
1720                status,
1721                retry_after,
1722                ..
1723            } => match *status {
1724                StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
1725                    Some(RetryStrategy::Fixed {
1726                        delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1727                        max_attempts: MAX_RETRY_ATTEMPTS,
1728                    })
1729                }
1730                StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed {
1731                    delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1732                    // Internal Server Error could be anything, retry up to 3 times.
1733                    max_attempts: 3,
1734                }),
1735                status => {
1736                    // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"),
1737                    // but we frequently get them in practice. See https://http.dev/529
1738                    if status.as_u16() == 529 {
1739                        Some(RetryStrategy::Fixed {
1740                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1741                            max_attempts: MAX_RETRY_ATTEMPTS,
1742                        })
1743                    } else {
1744                        Some(RetryStrategy::Fixed {
1745                            delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
1746                            max_attempts: 2,
1747                        })
1748                    }
1749                }
1750            },
1751            ApiInternalServerError { .. } => Some(RetryStrategy::Fixed {
1752                delay: BASE_RETRY_DELAY,
1753                max_attempts: 3,
1754            }),
1755            ApiReadResponseError { .. }
1756            | HttpSend { .. }
1757            | DeserializeResponse { .. }
1758            | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
1759                delay: BASE_RETRY_DELAY,
1760                max_attempts: 3,
1761            }),
1762            // Retrying these errors definitely shouldn't help.
1763            HttpResponseError {
1764                status_code:
1765                    StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
1766                ..
1767            }
1768            | AuthenticationError { .. }
1769            | PermissionError { .. }
1770            | NoApiKey { .. }
1771            | ApiEndpointNotFound { .. }
1772            | PromptTooLarge { .. } => None,
1773            // These errors might be transient, so retry them
1774            SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
1775                delay: BASE_RETRY_DELAY,
1776                max_attempts: 1,
1777            }),
1778            // Retry all other 4xx and 5xx errors once.
1779            HttpResponseError { status_code, .. }
1780                if status_code.is_client_error() || status_code.is_server_error() =>
1781            {
1782                Some(RetryStrategy::Fixed {
1783                    delay: BASE_RETRY_DELAY,
1784                    max_attempts: 3,
1785                })
1786            }
1787            Other(err)
1788                if err.is::<language_model::PaymentRequiredError>()
1789                    || err.is::<language_model::ModelRequestLimitReachedError>() =>
1790            {
1791                // Retrying won't help for Payment Required or Model Request Limit errors (where
1792                // the user must upgrade to usage-based billing to get more requests, or else wait
1793                // for a significant amount of time for the request limit to reset).
1794                None
1795            }
1796            // Conservatively assume that any other errors are non-retryable
1797            HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
1798                delay: BASE_RETRY_DELAY,
1799                max_attempts: 2,
1800            }),
1801        }
1802    }
1803}
1804
/// State kept for an agent turn that is currently in progress.
struct RunningTurn {
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    _task: Task<()>,
    /// The current event stream for the running turn. Used to report a final
    /// cancellation event if we cancel the turn.
    event_stream: ThreadEventStream,
}
1814
1815impl RunningTurn {
1816    fn cancel(self) {
1817        log::debug!("Cancelling in progress turn");
1818        self.event_stream.send_canceled();
1819    }
1820}
1821
1822pub trait AgentTool
1823where
1824    Self: 'static + Sized,
1825{
1826    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1827    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1828
1829    fn name(&self) -> SharedString;
1830
1831    fn description(&self) -> SharedString {
1832        let schema = schemars::schema_for!(Self::Input);
1833        SharedString::new(
1834            schema
1835                .get("description")
1836                .and_then(|description| description.as_str())
1837                .unwrap_or_default(),
1838        )
1839    }
1840
1841    fn kind(&self) -> acp::ToolKind;
1842
1843    /// The initial tool title to display. Can be updated during the tool run.
1844    fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1845
1846    /// Returns the JSON schema that describes the tool's input.
1847    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Schema {
1848        crate::tool_schema::root_schema_for::<Self::Input>(format)
1849    }
1850
1851    /// Some tools rely on a provider for the underlying billing or other reasons.
1852    /// Allow the tool to check if they are compatible, or should be filtered out.
1853    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1854        true
1855    }
1856
1857    /// Runs the tool with the provided input.
1858    fn run(
1859        self: Arc<Self>,
1860        input: Self::Input,
1861        event_stream: ToolCallEventStream,
1862        cx: &mut App,
1863    ) -> Task<Result<Self::Output>>;
1864
1865    /// Emits events for a previous execution of the tool.
1866    fn replay(
1867        &self,
1868        _input: Self::Input,
1869        _output: Self::Output,
1870        _event_stream: ToolCallEventStream,
1871        _cx: &mut App,
1872    ) -> Result<()> {
1873        Ok(())
1874    }
1875
1876    fn erase(self) -> Arc<dyn AnyAgentTool> {
1877        Arc::new(Erased(Arc::new(self)))
1878    }
1879}
1880
/// Adapter that exposes a typed [`AgentTool`] (held as `Erased<Arc<T>>`) through the
/// JSON-based [`AnyAgentTool`] interface; created by [`AgentTool::erase`].
pub struct Erased<T>(T);
1882
/// Result of running a type-erased tool.
pub struct AgentToolOutput {
    /// Content handed back to the language model (the tool's `Output` converted via `Into`).
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's `Output` serialized to JSON (presumably what is later fed to
    /// `AnyAgentTool::replay` as `output` — confirm with callers).
    pub raw_output: serde_json::Value,
}
1887
/// Object-safe, JSON-based interface to an agent tool.
///
/// Typed [`AgentTool`]s are adapted to it via [`AgentTool::erase`].
pub trait AnyAgentTool {
    /// Identifier under which the tool is offered to the model.
    fn name(&self) -> SharedString;
    /// Description offered to the model.
    fn description(&self) -> SharedString;
    /// The kind of operation the tool performs, as reported over ACP.
    fn kind(&self) -> acp::ToolKind;
    /// Initial title to display for a call with the given raw JSON `input`.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// JSON schema describing the tool's input, in the given `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether the tool may be offered for models from `_provider`. Defaults to `true`.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool on raw JSON `input`, yielding both model-facing and raw JSON output.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
    /// Re-emits events for a previous run, given its raw JSON `input` and `output`.
    fn replay(
        &self,
        input: serde_json::Value,
        output: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Result<()>;
}
1911
1912impl<T> AnyAgentTool for Erased<Arc<T>>
1913where
1914    T: AgentTool,
1915{
1916    fn name(&self) -> SharedString {
1917        self.0.name()
1918    }
1919
1920    fn description(&self) -> SharedString {
1921        self.0.description()
1922    }
1923
1924    fn kind(&self) -> agent_client_protocol::ToolKind {
1925        self.0.kind()
1926    }
1927
1928    fn initial_title(&self, input: serde_json::Value) -> SharedString {
1929        let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1930        self.0.initial_title(parsed_input)
1931    }
1932
1933    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1934        let mut json = serde_json::to_value(self.0.input_schema(format))?;
1935        adapt_schema_to_format(&mut json, format)?;
1936        Ok(json)
1937    }
1938
1939    fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1940        self.0.supported_provider(provider)
1941    }
1942
1943    fn run(
1944        self: Arc<Self>,
1945        input: serde_json::Value,
1946        event_stream: ToolCallEventStream,
1947        cx: &mut App,
1948    ) -> Task<Result<AgentToolOutput>> {
1949        cx.spawn(async move |cx| {
1950            let input = serde_json::from_value(input)?;
1951            let output = cx
1952                .update(|cx| self.0.clone().run(input, event_stream, cx))?
1953                .await?;
1954            let raw_output = serde_json::to_value(&output)?;
1955            Ok(AgentToolOutput {
1956                llm_output: output.into(),
1957                raw_output,
1958            })
1959        })
1960    }
1961
1962    fn replay(
1963        &self,
1964        input: serde_json::Value,
1965        output: serde_json::Value,
1966        event_stream: ToolCallEventStream,
1967        cx: &mut App,
1968    ) -> Result<()> {
1969        let input = serde_json::from_value(input)?;
1970        let output = serde_json::from_value(output)?;
1971        self.0.replay(input, output, event_stream, cx)
1972    }
1973}
1974
/// Sending half of the unbounded channel that carries a thread's events (or errors)
/// to whoever consumes them.
#[derive(Clone)]
struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
1977
1978impl ThreadEventStream {
1979    fn send_title_update(&self, text: SharedString) {
1980        self.0
1981            .unbounded_send(Ok(ThreadEvent::TitleUpdate(text)))
1982            .ok();
1983    }
1984
1985    fn send_user_message(&self, message: &UserMessage) {
1986        self.0
1987            .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
1988            .ok();
1989    }
1990
1991    fn send_text(&self, text: &str) {
1992        self.0
1993            .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
1994            .ok();
1995    }
1996
1997    fn send_thinking(&self, text: &str) {
1998        self.0
1999            .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
2000            .ok();
2001    }
2002
2003    fn send_tool_call(
2004        &self,
2005        id: &LanguageModelToolUseId,
2006        title: SharedString,
2007        kind: acp::ToolKind,
2008        input: serde_json::Value,
2009    ) {
2010        self.0
2011            .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
2012                id,
2013                title.to_string(),
2014                kind,
2015                input,
2016            ))))
2017            .ok();
2018    }
2019
2020    fn initial_tool_call(
2021        id: &LanguageModelToolUseId,
2022        title: String,
2023        kind: acp::ToolKind,
2024        input: serde_json::Value,
2025    ) -> acp::ToolCall {
2026        acp::ToolCall {
2027            id: acp::ToolCallId(id.to_string().into()),
2028            title,
2029            kind,
2030            status: acp::ToolCallStatus::Pending,
2031            content: vec![],
2032            locations: vec![],
2033            raw_input: Some(input),
2034            raw_output: None,
2035        }
2036    }
2037
2038    fn update_tool_call_fields(
2039        &self,
2040        tool_use_id: &LanguageModelToolUseId,
2041        fields: acp::ToolCallUpdateFields,
2042    ) {
2043        self.0
2044            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2045                acp::ToolCallUpdate {
2046                    id: acp::ToolCallId(tool_use_id.to_string().into()),
2047                    fields,
2048                }
2049                .into(),
2050            )))
2051            .ok();
2052    }
2053
2054    fn send_retry(&self, status: acp_thread::RetryStatus) {
2055        self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
2056    }
2057
2058    fn send_stop(&self, reason: StopReason) {
2059        match reason {
2060            StopReason::EndTurn => {
2061                self.0
2062                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
2063                    .ok();
2064            }
2065            StopReason::MaxTokens => {
2066                self.0
2067                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
2068                    .ok();
2069            }
2070            StopReason::Refusal => {
2071                self.0
2072                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
2073                    .ok();
2074            }
2075            StopReason::ToolUse => {}
2076        }
2077    }
2078
2079    fn send_canceled(&self) {
2080        self.0
2081            .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
2082            .ok();
2083    }
2084
2085    fn send_error(&self, error: impl Into<anyhow::Error>) {
2086        self.0.unbounded_send(Err(error.into())).ok();
2087    }
2088}
2089
/// Event stream handed to a running tool, scoped to a single tool call.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// Id of the tool use whose updates this stream reports.
    tool_use_id: LanguageModelToolUseId,
    /// Thread event stream that all updates are forwarded to.
    stream: ThreadEventStream,
    /// Used by `authorize` to persist an "Always Allow" choice to the settings file;
    /// when `None`, that choice is not persisted.
    fs: Option<Arc<dyn Fs>>,
}
2096
2097impl ToolCallEventStream {
2098    #[cfg(test)]
2099    pub fn test() -> (Self, ToolCallEventStreamReceiver) {
2100        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
2101
2102        let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
2103
2104        (stream, ToolCallEventStreamReceiver(events_rx))
2105    }
2106
2107    fn new(
2108        tool_use_id: LanguageModelToolUseId,
2109        stream: ThreadEventStream,
2110        fs: Option<Arc<dyn Fs>>,
2111    ) -> Self {
2112        Self {
2113            tool_use_id,
2114            stream,
2115            fs,
2116        }
2117    }
2118
2119    pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
2120        self.stream
2121            .update_tool_call_fields(&self.tool_use_id, fields);
2122    }
2123
2124    pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
2125        self.stream
2126            .0
2127            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2128                acp_thread::ToolCallUpdateDiff {
2129                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2130                    diff,
2131                }
2132                .into(),
2133            )))
2134            .ok();
2135    }
2136
2137    pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
2138        self.stream
2139            .0
2140            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
2141                acp_thread::ToolCallUpdateTerminal {
2142                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2143                    terminal,
2144                }
2145                .into(),
2146            )))
2147            .ok();
2148    }
2149
2150    pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
2151        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
2152            return Task::ready(Ok(()));
2153        }
2154
2155        let (response_tx, response_rx) = oneshot::channel();
2156        self.stream
2157            .0
2158            .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
2159                ToolCallAuthorization {
2160                    tool_call: acp::ToolCallUpdate {
2161                        id: acp::ToolCallId(self.tool_use_id.to_string().into()),
2162                        fields: acp::ToolCallUpdateFields {
2163                            title: Some(title.into()),
2164                            ..Default::default()
2165                        },
2166                    },
2167                    options: vec![
2168                        acp::PermissionOption {
2169                            id: acp::PermissionOptionId("always_allow".into()),
2170                            name: "Always Allow".into(),
2171                            kind: acp::PermissionOptionKind::AllowAlways,
2172                        },
2173                        acp::PermissionOption {
2174                            id: acp::PermissionOptionId("allow".into()),
2175                            name: "Allow".into(),
2176                            kind: acp::PermissionOptionKind::AllowOnce,
2177                        },
2178                        acp::PermissionOption {
2179                            id: acp::PermissionOptionId("deny".into()),
2180                            name: "Deny".into(),
2181                            kind: acp::PermissionOptionKind::RejectOnce,
2182                        },
2183                    ],
2184                    response: response_tx,
2185                },
2186            )))
2187            .ok();
2188        let fs = self.fs.clone();
2189        cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
2190            "always_allow" => {
2191                if let Some(fs) = fs.clone() {
2192                    cx.update(|cx| {
2193                        update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
2194                            settings.set_always_allow_tool_actions(true);
2195                        });
2196                    })?;
2197                }
2198
2199                Ok(())
2200            }
2201            "allow" => Ok(()),
2202            _ => Err(anyhow!("Permission to run tool denied by user")),
2203        })
2204    }
2205}
2206
/// Test-only receiving half for the events emitted by [`ToolCallEventStream::test`].
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
2209
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it if it is a tool-call authorization
    /// request. Panics on anything else, including the end of the stream.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) => auth,
            event => panic!("Expected ToolCallAuthorization but got: {:?}", event),
        }
    }

    /// Awaits the next event and returns the terminal if it is a terminal update.
    /// Panics on anything else, including the end of the stream.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal(
                update,
            )))) => update.terminal,
            event => panic!("Expected terminal but got: {:?}", event),
        }
    }
}
2233
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;

    /// Exposes the wrapped receiver so tests can call its methods directly.
    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
2242
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    /// Mutable access to the wrapped receiver (e.g. for polling it as a stream).
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
2249
2250impl From<&str> for UserMessageContent {
2251    fn from(text: &str) -> Self {
2252        Self::Text(text.into())
2253    }
2254}
2255
2256impl From<acp::ContentBlock> for UserMessageContent {
2257    fn from(value: acp::ContentBlock) -> Self {
2258        match value {
2259            acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
2260            acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
2261            acp::ContentBlock::Audio(_) => {
2262                // TODO
2263                Self::Text("[audio]".to_string())
2264            }
2265            acp::ContentBlock::ResourceLink(resource_link) => {
2266                match MentionUri::parse(&resource_link.uri) {
2267                    Ok(uri) => Self::Mention {
2268                        uri,
2269                        content: String::new(),
2270                    },
2271                    Err(err) => {
2272                        log::error!("Failed to parse mention link: {}", err);
2273                        Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
2274                    }
2275                }
2276            }
2277            acp::ContentBlock::Resource(resource) => match resource.resource {
2278                acp::EmbeddedResourceResource::TextResourceContents(resource) => {
2279                    match MentionUri::parse(&resource.uri) {
2280                        Ok(uri) => Self::Mention {
2281                            uri,
2282                            content: resource.text,
2283                        },
2284                        Err(err) => {
2285                            log::error!("Failed to parse mention link: {}", err);
2286                            Self::Text(
2287                                MarkdownCodeBlock {
2288                                    tag: &resource.uri,
2289                                    text: &resource.text,
2290                                }
2291                                .to_string(),
2292                            )
2293                        }
2294                    }
2295                }
2296                acp::EmbeddedResourceResource::BlobResourceContents(_) => {
2297                    // TODO
2298                    Self::Text("[blob]".to_string())
2299                }
2300            },
2301        }
2302    }
2303}
2304
2305impl From<UserMessageContent> for acp::ContentBlock {
2306    fn from(content: UserMessageContent) -> Self {
2307        match content {
2308            UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
2309                text,
2310                annotations: None,
2311            }),
2312            UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
2313                data: image.source.to_string(),
2314                mime_type: "image/png".to_string(),
2315                annotations: None,
2316                uri: None,
2317            }),
2318            UserMessageContent::Mention { uri, content } => {
2319                acp::ContentBlock::ResourceLink(acp::ResourceLink {
2320                    uri: uri.to_uri().to_string(),
2321                    name: uri.name(),
2322                    annotations: None,
2323                    description: if content.is_empty() {
2324                        None
2325                    } else {
2326                        Some(content)
2327                    },
2328                    mime_type: None,
2329                    size: None,
2330                    title: None,
2331                })
2332            }
2333        }
2334    }
2335}
2336
2337fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
2338    LanguageModelImage {
2339        source: image_content.data.into(),
2340        // TODO: make this optional?
2341        size: gpui::Size::new(0.into(), 0.into()),
2342    }
2343}