thread.rs

   1use crate::{
   2    ContextServerRegistry, DbLanguageModel, DbThread, SystemPromptTemplate, Template, Templates,
   3};
   4use acp_thread::{MentionUri, UserMessageId};
   5use action_log::ActionLog;
   6use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot};
   7use agent_client_protocol as acp;
   8use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
   9use anyhow::{Context as _, Result, anyhow};
  10use assistant_tool::adapt_schema_to_format;
  11use chrono::{DateTime, Utc};
  12use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
  13use collections::IndexMap;
  14use fs::Fs;
  15use futures::{
  16    FutureExt,
  17    channel::{mpsc, oneshot},
  18    future::Shared,
  19    stream::FuturesUnordered,
  20};
  21use git::repository::DiffType;
  22use gpui::{App, AppContext, Context, Entity, SharedString, Task};
  23use language_model::{
  24    LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
  25    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  26    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
  27    LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, TokenUsage,
  28};
  29use project::{
  30    Project,
  31    git_store::{GitStore, RepositoryState},
  32};
  33use prompt_store::ProjectContext;
  34use schemars::{JsonSchema, Schema};
  35use serde::{Deserialize, Serialize};
  36use settings::{Settings, update_settings_file};
  37use smol::stream::StreamExt;
  38use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
  39use std::{fmt::Write, ops::Range};
  40use util::{ResultExt, markdown::MarkdownCodeBlock};
  41use uuid::Uuid;
  42
/// Message text indicating that the user canceled a tool call.
const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
  44
/// The ID of the user prompt that initiated a request.
///
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
/// The wrapped string is a UUID rendered as text (see [`PromptId::new`]).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct PromptId(Arc<str>);
  50
  51impl PromptId {
  52    pub fn new() -> Self {
  53        Self(Uuid::new_v4().to_string().into())
  54    }
  55}
  56
  57impl std::fmt::Display for PromptId {
  58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  59        write!(f, "{}", self.0)
  60    }
  61}
  62
/// A single entry in a thread's conversation history.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Message {
    /// A message submitted by the user.
    User(UserMessage),
    /// A response produced by the agent, together with its tool results.
    Agent(AgentMessage),
    /// Marker recorded when the conversation resumed after the tool use limit was reached.
    Resume,
}
  69
  70impl Message {
  71    pub fn as_agent_message(&self) -> Option<&AgentMessage> {
  72        match self {
  73            Message::Agent(agent_message) => Some(agent_message),
  74            _ => None,
  75        }
  76    }
  77
  78    pub fn to_markdown(&self) -> String {
  79        match self {
  80            Message::User(message) => message.to_markdown(),
  81            Message::Agent(message) => message.to_markdown(),
  82            Message::Resume => "[resumed after tool use limit was reached]".into(),
  83        }
  84    }
  85}
  86
/// A message submitted by the user.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessage {
    /// Identifies this message, e.g. as the truncation point in `Thread::truncate`.
    pub id: UserMessageId,
    /// Ordered content chunks (text, mentions, images) making up the message.
    pub content: Vec<UserMessageContent>,
}
  92
/// One chunk of a [`UserMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UserMessageContent {
    /// Plain text typed by the user.
    Text(String),
    /// A reference to an item (file, symbol, selection, thread, rule, fetched URL)
    /// together with the content that was attached for it.
    Mention { uri: MentionUri, content: String },
    /// An attached image.
    Image(LanguageModelImage),
}
  99
 100impl UserMessage {
 101    pub fn to_markdown(&self) -> String {
 102        let mut markdown = String::from("## User\n\n");
 103
 104        for content in &self.content {
 105            match content {
 106                UserMessageContent::Text(text) => {
 107                    markdown.push_str(text);
 108                    markdown.push('\n');
 109                }
 110                UserMessageContent::Image(_) => {
 111                    markdown.push_str("<image />\n");
 112                }
 113                UserMessageContent::Mention { uri, content } => {
 114                    if !content.is_empty() {
 115                        let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
 116                    } else {
 117                        let _ = write!(&mut markdown, "{}\n", uri.as_link());
 118                    }
 119                }
 120            }
 121        }
 122
 123        markdown
 124    }
 125
 126    fn to_request(&self) -> LanguageModelRequestMessage {
 127        let mut message = LanguageModelRequestMessage {
 128            role: Role::User,
 129            content: Vec::with_capacity(self.content.len()),
 130            cache: false,
 131        };
 132
 133        const OPEN_CONTEXT: &str = "<context>\n\
 134            The following items were attached by the user. \
 135            They are up-to-date and don't need to be re-read.\n\n";
 136
 137        const OPEN_FILES_TAG: &str = "<files>";
 138        const OPEN_SYMBOLS_TAG: &str = "<symbols>";
 139        const OPEN_THREADS_TAG: &str = "<threads>";
 140        const OPEN_FETCH_TAG: &str = "<fetched_urls>";
 141        const OPEN_RULES_TAG: &str =
 142            "<rules>\nThe user has specified the following rules that should be applied:\n";
 143
 144        let mut file_context = OPEN_FILES_TAG.to_string();
 145        let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
 146        let mut thread_context = OPEN_THREADS_TAG.to_string();
 147        let mut fetch_context = OPEN_FETCH_TAG.to_string();
 148        let mut rules_context = OPEN_RULES_TAG.to_string();
 149
 150        for chunk in &self.content {
 151            let chunk = match chunk {
 152                UserMessageContent::Text(text) => {
 153                    language_model::MessageContent::Text(text.clone())
 154                }
 155                UserMessageContent::Image(value) => {
 156                    language_model::MessageContent::Image(value.clone())
 157                }
 158                UserMessageContent::Mention { uri, content } => {
 159                    match uri {
 160                        MentionUri::File { abs_path, .. } => {
 161                            write!(
 162                                &mut symbol_context,
 163                                "\n{}",
 164                                MarkdownCodeBlock {
 165                                    tag: &codeblock_tag(&abs_path, None),
 166                                    text: &content.to_string(),
 167                                }
 168                            )
 169                            .ok();
 170                        }
 171                        MentionUri::Symbol {
 172                            path, line_range, ..
 173                        }
 174                        | MentionUri::Selection {
 175                            path, line_range, ..
 176                        } => {
 177                            write!(
 178                                &mut rules_context,
 179                                "\n{}",
 180                                MarkdownCodeBlock {
 181                                    tag: &codeblock_tag(&path, Some(line_range)),
 182                                    text: &content
 183                                }
 184                            )
 185                            .ok();
 186                        }
 187                        MentionUri::Thread { .. } => {
 188                            write!(&mut thread_context, "\n{}\n", content).ok();
 189                        }
 190                        MentionUri::TextThread { .. } => {
 191                            write!(&mut thread_context, "\n{}\n", content).ok();
 192                        }
 193                        MentionUri::Rule { .. } => {
 194                            write!(
 195                                &mut rules_context,
 196                                "\n{}",
 197                                MarkdownCodeBlock {
 198                                    tag: "",
 199                                    text: &content
 200                                }
 201                            )
 202                            .ok();
 203                        }
 204                        MentionUri::Fetch { url } => {
 205                            write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
 206                        }
 207                    }
 208
 209                    language_model::MessageContent::Text(uri.as_link().to_string())
 210                }
 211            };
 212
 213            message.content.push(chunk);
 214        }
 215
 216        let len_before_context = message.content.len();
 217
 218        if file_context.len() > OPEN_FILES_TAG.len() {
 219            file_context.push_str("</files>\n");
 220            message
 221                .content
 222                .push(language_model::MessageContent::Text(file_context));
 223        }
 224
 225        if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
 226            symbol_context.push_str("</symbols>\n");
 227            message
 228                .content
 229                .push(language_model::MessageContent::Text(symbol_context));
 230        }
 231
 232        if thread_context.len() > OPEN_THREADS_TAG.len() {
 233            thread_context.push_str("</threads>\n");
 234            message
 235                .content
 236                .push(language_model::MessageContent::Text(thread_context));
 237        }
 238
 239        if fetch_context.len() > OPEN_FETCH_TAG.len() {
 240            fetch_context.push_str("</fetched_urls>\n");
 241            message
 242                .content
 243                .push(language_model::MessageContent::Text(fetch_context));
 244        }
 245
 246        if rules_context.len() > OPEN_RULES_TAG.len() {
 247            rules_context.push_str("</user_rules>\n");
 248            message
 249                .content
 250                .push(language_model::MessageContent::Text(rules_context));
 251        }
 252
 253        if message.content.len() > len_before_context {
 254            message.content.insert(
 255                len_before_context,
 256                language_model::MessageContent::Text(OPEN_CONTEXT.into()),
 257            );
 258            message
 259                .content
 260                .push(language_model::MessageContent::Text("</context>".into()));
 261        }
 262
 263        message
 264    }
 265}
 266
 267fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
 268    let mut result = String::new();
 269
 270    if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) {
 271        let _ = write!(result, "{} ", extension);
 272    }
 273
 274    let _ = write!(result, "{}", full_path.display());
 275
 276    if let Some(range) = line_range {
 277        if range.start == range.end {
 278            let _ = write!(result, ":{}", range.start + 1);
 279        } else {
 280            let _ = write!(result, ":{}-{}", range.start + 1, range.end + 1);
 281        }
 282    }
 283
 284    result
 285}
 286
 287impl AgentMessage {
 288    pub fn to_markdown(&self) -> String {
 289        let mut markdown = String::from("## Assistant\n\n");
 290
 291        for content in &self.content {
 292            match content {
 293                AgentMessageContent::Text(text) => {
 294                    markdown.push_str(text);
 295                    markdown.push('\n');
 296                }
 297                AgentMessageContent::Thinking { text, .. } => {
 298                    markdown.push_str("<think>");
 299                    markdown.push_str(text);
 300                    markdown.push_str("</think>\n");
 301                }
 302                AgentMessageContent::RedactedThinking(_) => {
 303                    markdown.push_str("<redacted_thinking />\n")
 304                }
 305                AgentMessageContent::ToolUse(tool_use) => {
 306                    markdown.push_str(&format!(
 307                        "**Tool Use**: {} (ID: {})\n",
 308                        tool_use.name, tool_use.id
 309                    ));
 310                    markdown.push_str(&format!(
 311                        "{}\n",
 312                        MarkdownCodeBlock {
 313                            tag: "json",
 314                            text: &format!("{:#}", tool_use.input)
 315                        }
 316                    ));
 317                }
 318            }
 319        }
 320
 321        for tool_result in self.tool_results.values() {
 322            markdown.push_str(&format!(
 323                "**Tool Result**: {} (ID: {})\n\n",
 324                tool_result.tool_name, tool_result.tool_use_id
 325            ));
 326            if tool_result.is_error {
 327                markdown.push_str("**ERROR:**\n");
 328            }
 329
 330            match &tool_result.content {
 331                LanguageModelToolResultContent::Text(text) => {
 332                    writeln!(markdown, "{text}\n").ok();
 333                }
 334                LanguageModelToolResultContent::Image(_) => {
 335                    writeln!(markdown, "<image />\n").ok();
 336                }
 337            }
 338
 339            if let Some(output) = tool_result.output.as_ref() {
 340                writeln!(
 341                    markdown,
 342                    "**Debug Output**:\n\n```json\n{}\n```\n",
 343                    serde_json::to_string_pretty(output).unwrap()
 344                )
 345                .unwrap();
 346            }
 347        }
 348
 349        markdown
 350    }
 351
 352    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
 353        let mut assistant_message = LanguageModelRequestMessage {
 354            role: Role::Assistant,
 355            content: Vec::with_capacity(self.content.len()),
 356            cache: false,
 357        };
 358        for chunk in &self.content {
 359            let chunk = match chunk {
 360                AgentMessageContent::Text(text) => {
 361                    language_model::MessageContent::Text(text.clone())
 362                }
 363                AgentMessageContent::Thinking { text, signature } => {
 364                    language_model::MessageContent::Thinking {
 365                        text: text.clone(),
 366                        signature: signature.clone(),
 367                    }
 368                }
 369                AgentMessageContent::RedactedThinking(value) => {
 370                    language_model::MessageContent::RedactedThinking(value.clone())
 371                }
 372                AgentMessageContent::ToolUse(value) => {
 373                    language_model::MessageContent::ToolUse(value.clone())
 374                }
 375            };
 376            assistant_message.content.push(chunk);
 377        }
 378
 379        let mut user_message = LanguageModelRequestMessage {
 380            role: Role::User,
 381            content: Vec::new(),
 382            cache: false,
 383        };
 384
 385        for tool_result in self.tool_results.values() {
 386            user_message
 387                .content
 388                .push(language_model::MessageContent::ToolResult(
 389                    tool_result.clone(),
 390                ));
 391        }
 392
 393        let mut messages = Vec::new();
 394        if !assistant_message.content.is_empty() {
 395            messages.push(assistant_message);
 396        }
 397        if !user_message.content.is_empty() {
 398            messages.push(user_message);
 399        }
 400        messages
 401    }
 402}
 403
/// A response produced by the agent.
#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AgentMessage {
    /// Ordered content chunks (text, thinking, tool uses) emitted by the model.
    pub content: Vec<AgentMessageContent>,
    /// Results of the tool uses in `content`, keyed by tool use ID and kept in insertion order.
    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
}
 409
/// One chunk of an [`AgentMessage`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AgentMessageContent {
    /// Visible response text.
    Text(String),
    /// Model reasoning text, with an optional signature that is forwarded back to the model.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Reasoning the model returned in redacted form; forwarded as-is but not shown.
    RedactedThinking(String),
    /// A request by the model to invoke a tool.
    ToolUse(LanguageModelToolUse),
}
 420
/// Events sent over a thread's event stream (for example by `Thread::replay`).
#[derive(Debug)]
pub enum ThreadEvent {
    UserMessage(UserMessage),
    AgentText(String),
    AgentThinking(String),
    /// A new tool call has started.
    ToolCall(acp::ToolCall),
    /// An update to a previously announced tool call.
    ToolCallUpdate(acp_thread::ToolCallUpdate),
    /// A tool call is waiting for the user's permission.
    ToolCallAuthorization(ToolCallAuthorization),
    Stop(acp::StopReason),
}
 431
/// A request for the user to authorize a tool call.
#[derive(Debug)]
pub struct ToolCallAuthorization {
    /// The tool call awaiting authorization.
    pub tool_call: acp::ToolCallUpdate,
    /// The choices presented to the user.
    pub options: Vec<acp::PermissionOption>,
    /// Channel through which the ID of the selected option is sent back.
    pub response: oneshot::Sender<acp::PermissionOptionId>,
}
 438
/// Lifecycle of a thread's title.
enum ThreadTitle {
    /// No title has been set or requested.
    None,
    /// A title is being produced; presumably this task generates it (confirm at call sites).
    Pending(Task<()>),
    /// Final title, or the error encountered while producing it.
    Done(Result<SharedString>),
}
 444
 445impl ThreadTitle {
 446    pub fn unwrap_or_default(&self) -> SharedString {
 447        if let ThreadTitle::Done(Ok(title)) = self {
 448            title.clone()
 449        } else {
 450            "New Thread".into()
 451        }
 452    }
 453}
 454
/// An agent conversation: its message history, the model and tools used to
/// continue it, token accounting, and the project it operates on.
pub struct Thread {
    id: acp::SessionId,
    /// ID of the user prompt that initiated the current request.
    prompt_id: PromptId,
    updated_at: DateTime<Utc>,
    /// Title of the thread; `"New Thread"` is used when no title is available.
    title: ThreadTitle,
    summary: DetailedSummaryState,
    /// Committed conversation history.
    messages: Vec<Message>,
    completion_mode: CompletionMode,
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run tools, report their results.
    running_turn: Option<RunningTurn>,
    /// Agent message not yet flushed into `messages` (see `flush_pending_message`).
    pending_message: Option<AgentMessage>,
    /// Tools available to the model, keyed by tool name.
    tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
    /// Set when the tool use limit was reached; `resume` requires it to be true.
    tool_use_limit_reached: bool,
    /// Per-request token usage; presumably one entry per completion request (confirm).
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    /// Project snapshot taken when the thread was created, or restored from the database.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    context_server_registry: Entity<ContextServerRegistry>,
    profile_id: AgentProfileId,
    project_context: Rc<RefCell<ProjectContext>>,
    templates: Arc<Templates>,
    model: Arc<dyn LanguageModel>,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
}
 481
 482impl Thread {
 483    pub fn new(
 484        id: acp::SessionId,
 485        project: Entity<Project>,
 486        project_context: Rc<RefCell<ProjectContext>>,
 487        context_server_registry: Entity<ContextServerRegistry>,
 488        action_log: Entity<ActionLog>,
 489        templates: Arc<Templates>,
 490        model: Arc<dyn LanguageModel>,
 491        cx: &mut Context<Self>,
 492    ) -> Self {
 493        let profile_id = AgentSettings::get_global(cx).default_profile.clone();
 494        Self {
 495            id,
 496            prompt_id: PromptId::new(),
 497            updated_at: Utc::now(),
 498            title: ThreadTitle::None,
 499            summary: DetailedSummaryState::default(),
 500            messages: Vec::new(),
 501            completion_mode: CompletionMode::Normal,
 502            running_turn: None,
 503            pending_message: None,
 504            tools: BTreeMap::default(),
 505            tool_use_limit_reached: false,
 506            request_token_usage: Vec::new(),
 507            cumulative_token_usage: TokenUsage::default(),
 508            initial_project_snapshot: {
 509                let project_snapshot = Self::project_snapshot(project.clone(), cx);
 510                cx.foreground_executor()
 511                    .spawn(async move { Some(project_snapshot.await) })
 512                    .shared()
 513            },
 514            context_server_registry,
 515            profile_id,
 516            project_context,
 517            templates,
 518            model,
 519            project,
 520            action_log,
 521        }
 522    }
 523
 524    pub fn id(&self) -> &acp::SessionId {
 525        &self.id
 526    }
 527
 528    pub fn from_db(
 529        id: acp::SessionId,
 530        db_thread: DbThread,
 531        project: Entity<Project>,
 532        project_context: Rc<RefCell<ProjectContext>>,
 533        context_server_registry: Entity<ContextServerRegistry>,
 534        action_log: Entity<ActionLog>,
 535        templates: Arc<Templates>,
 536        model: Arc<dyn LanguageModel>,
 537        cx: &mut Context<Self>,
 538    ) -> Self {
 539        let profile_id = db_thread
 540            .profile
 541            .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
 542        Self {
 543            id,
 544            prompt_id: PromptId::new(),
 545            title: ThreadTitle::Done(Ok(db_thread.title.clone())),
 546            summary: db_thread.summary,
 547            messages: db_thread.messages,
 548            completion_mode: CompletionMode::Normal,
 549            running_turn: None,
 550            pending_message: None,
 551            tools: BTreeMap::default(),
 552            tool_use_limit_reached: false,
 553            request_token_usage: db_thread.request_token_usage.clone(),
 554            cumulative_token_usage: db_thread.cumulative_token_usage.clone(),
 555            initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(),
 556            context_server_registry,
 557            profile_id,
 558            project_context,
 559            templates,
 560            model,
 561            project,
 562            action_log,
 563            updated_at: db_thread.updated_at, // todo!(figure out if we can remove the "recently opened" list)
 564        }
 565    }
 566
 567    pub fn to_db(&self, cx: &App) -> Task<DbThread> {
 568        let initial_project_snapshot = self.initial_project_snapshot.clone();
 569        let mut thread = DbThread {
 570            title: self.title.unwrap_or_default(),
 571            messages: self.messages.clone(),
 572            updated_at: self.updated_at.clone(),
 573            summary: self.summary.clone(),
 574            initial_project_snapshot: None,
 575            cumulative_token_usage: self.cumulative_token_usage.clone(),
 576            request_token_usage: self.request_token_usage.clone(),
 577            model: Some(DbLanguageModel {
 578                provider: self.model.provider_id().to_string(),
 579                model: self.model.name().0.to_string(),
 580            }),
 581            completion_mode: Some(self.completion_mode.into()),
 582            profile: Some(self.profile_id.clone()),
 583        };
 584
 585        cx.background_spawn(async move {
 586            let initial_project_snapshot = initial_project_snapshot.await;
 587            thread.initial_project_snapshot = initial_project_snapshot;
 588            thread
 589        })
 590    }
 591
 592    pub fn replay(
 593        &mut self,
 594        cx: &mut Context<Self>,
 595    ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
 596        let (tx, rx) = mpsc::unbounded();
 597        let stream = ThreadEventStream(tx);
 598        for message in &self.messages {
 599            match message {
 600                Message::User(user_message) => stream.send_user_message(&user_message),
 601                Message::Agent(assistant_message) => {
 602                    for content in &assistant_message.content {
 603                        match content {
 604                            AgentMessageContent::Text(text) => stream.send_text(text),
 605                            AgentMessageContent::Thinking { text, .. } => {
 606                                stream.send_thinking(text)
 607                            }
 608                            AgentMessageContent::RedactedThinking(_) => {}
 609                            AgentMessageContent::ToolUse(tool_use) => {
 610                                self.replay_tool_call(
 611                                    tool_use,
 612                                    assistant_message.tool_results.get(&tool_use.id),
 613                                    &stream,
 614                                    cx,
 615                                );
 616                            }
 617                        }
 618                    }
 619                }
 620                Message::Resume => {}
 621            }
 622        }
 623        rx
 624    }
 625
 626    fn replay_tool_call(
 627        &self,
 628        tool_use: &LanguageModelToolUse,
 629        tool_result: Option<&LanguageModelToolResult>,
 630        stream: &ThreadEventStream,
 631        cx: &mut Context<Self>,
 632    ) {
 633        let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
 634            stream
 635                .0
 636                .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
 637                    id: acp::ToolCallId(tool_use.id.to_string().into()),
 638                    title: tool_use.name.to_string(),
 639                    kind: acp::ToolKind::Other,
 640                    status: acp::ToolCallStatus::Failed,
 641                    content: Vec::new(),
 642                    locations: Vec::new(),
 643                    raw_input: Some(tool_use.input.clone()),
 644                    raw_output: None,
 645                })))
 646                .ok();
 647            return;
 648        };
 649
 650        let title = tool.initial_title(tool_use.input.clone());
 651        let kind = tool.kind();
 652        stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
 653
 654        let output = tool_result
 655            .as_ref()
 656            .and_then(|result| result.output.clone());
 657        if let Some(output) = output.clone() {
 658            let tool_event_stream = ToolCallEventStream::new(
 659                tool_use.id.clone(),
 660                stream.clone(),
 661                Some(self.project.read(cx).fs().clone()),
 662            );
 663            tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
 664                .log_err();
 665        }
 666
 667        stream.update_tool_call_fields(
 668            &tool_use.id,
 669            acp::ToolCallUpdateFields {
 670                status: Some(acp::ToolCallStatus::Completed),
 671                raw_output: output,
 672                ..Default::default()
 673            },
 674        );
 675    }
 676
 677    /// Create a snapshot of the current project state including git information and unsaved buffers.
 678    fn project_snapshot(
 679        project: Entity<Project>,
 680        cx: &mut Context<Self>,
 681    ) -> Task<Arc<agent::thread::ProjectSnapshot>> {
 682        let git_store = project.read(cx).git_store().clone();
 683        let worktree_snapshots: Vec<_> = project
 684            .read(cx)
 685            .visible_worktrees(cx)
 686            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
 687            .collect();
 688
 689        cx.spawn(async move |_, cx| {
 690            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
 691
 692            let mut unsaved_buffers = Vec::new();
 693            cx.update(|app_cx| {
 694                let buffer_store = project.read(app_cx).buffer_store();
 695                for buffer_handle in buffer_store.read(app_cx).buffers() {
 696                    let buffer = buffer_handle.read(app_cx);
 697                    if buffer.is_dirty() {
 698                        if let Some(file) = buffer.file() {
 699                            let path = file.path().to_string_lossy().to_string();
 700                            unsaved_buffers.push(path);
 701                        }
 702                    }
 703                }
 704            })
 705            .ok();
 706
 707            Arc::new(ProjectSnapshot {
 708                worktree_snapshots,
 709                unsaved_buffer_paths: unsaved_buffers,
 710                timestamp: Utc::now(),
 711            })
 712        })
 713    }
 714
    /// Builds a [`WorktreeSnapshot`] for `worktree`.
    ///
    /// The snapshot contains the worktree's absolute path. When a repository in
    /// `git_store` contains the worktree root, it also contains that
    /// repository's git state.
    ///
    /// The task never fails. If the worktree cannot be read, an empty path with
    /// no git state is returned. Failures during the git lookup yield
    /// `git_state: None`.
    fn worktree_snapshot(
        worktree: Entity<project::Worktree>,
        git_store: Entity<GitStore>,
        cx: &App,
    ) -> Task<agent::thread::WorktreeSnapshot> {
        cx.spawn(async move |cx| {
            // Get worktree path and snapshot
            let worktree_info = cx.update(|app_cx| {
                let worktree = worktree.read(app_cx);
                let path = worktree.abs_path().to_string_lossy().to_string();
                let snapshot = worktree.snapshot();
                (path, snapshot)
            });

            // If the app context is gone, fall back to an empty snapshot.
            let Ok((worktree_path, _snapshot)) = worktree_info else {
                return WorktreeSnapshot {
                    worktree_path: String::new(),
                    git_state: None,
                };
            };

            // Take the first repository (in map iteration order) whose tree contains the
            // worktree root. Then queue a job on it that gathers the git state.
            let git_state = git_store
                .update(cx, |git_store, cx| {
                    git_store
                        .repositories()
                        .values()
                        .find(|repo| {
                            repo.read(cx)
                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
                                .is_some()
                        })
                        .cloned()
                })
                .ok()
                .flatten()
                .map(|repo| {
                    repo.update(cx, |repo, _| {
                        // Branch name is read up front, outside the queued job.
                        let current_branch =
                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
                        repo.send_job(None, |state, _| async move {
                            // Without a local backend, report only the branch name.
                            let RepositoryState::Local { backend, .. } = state else {
                                return GitState {
                                    remote_url: None,
                                    head_sha: None,
                                    current_branch,
                                    diff: None,
                                };
                            };

                            let remote_url = backend.remote_url("origin");
                            let head_sha = backend.head_sha().await;
                            // A failed diff is recorded as `None` rather than failing the job.
                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();

                            GitState {
                                remote_url,
                                head_sha,
                                current_branch,
                                diff,
                            }
                        })
                    })
                });

            // Unwrap the nested results in two steps. First, did the entity update
            // succeed? Second, await the queued job and treat any error as no git state.
            let git_state = match git_state {
                Some(git_state) => match git_state.ok() {
                    Some(git_state) => git_state.await.ok(),
                    None => None,
                },
                None => None,
            };

            WorktreeSnapshot {
                worktree_path,
                git_state,
            }
        })
    }
 792
 793    pub fn project(&self) -> &Entity<Project> {
 794        &self.project
 795    }
 796
 797    pub fn action_log(&self) -> &Entity<ActionLog> {
 798        &self.action_log
 799    }
 800
 801    pub fn model(&self) -> &Arc<dyn LanguageModel> {
 802        &self.model
 803    }
 804
 805    pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
 806        self.model = model;
 807        cx.notify()
 808    }
 809
 810    pub fn completion_mode(&self) -> CompletionMode {
 811        self.completion_mode
 812    }
 813
 814    pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
 815        self.completion_mode = mode;
 816        cx.notify()
 817    }
 818
 819    #[cfg(any(test, feature = "test-support"))]
 820    pub fn last_message(&self) -> Option<Message> {
 821        if let Some(message) = self.pending_message.clone() {
 822            Some(Message::Agent(message))
 823        } else {
 824            self.messages.last().cloned()
 825        }
 826    }
 827
 828    pub fn add_tool(&mut self, tool: impl AgentTool) {
 829        self.tools.insert(tool.name(), tool.erase());
 830    }
 831
 832    pub fn remove_tool(&mut self, name: &str) -> bool {
 833        self.tools.remove(name).is_some()
 834    }
 835
 836    pub fn profile(&self) -> &AgentProfileId {
 837        &self.profile_id
 838    }
 839
 840    pub fn set_profile(&mut self, profile_id: AgentProfileId) {
 841        self.profile_id = profile_id;
 842    }
 843
 844    pub fn cancel(&mut self, cx: &mut Context<Self>) {
 845        if let Some(running_turn) = self.running_turn.take() {
 846            running_turn.cancel();
 847        }
 848        self.flush_pending_message(cx);
 849    }
 850
 851    pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
 852        self.cancel(cx);
 853        let Some(position) = self.messages.iter().position(
 854            |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
 855        ) else {
 856            return Err(anyhow!("Message not found"));
 857        };
 858        self.messages.truncate(position);
 859        cx.notify();
 860        Ok(())
 861    }
 862
 863    pub fn resume(
 864        &mut self,
 865        cx: &mut Context<Self>,
 866    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
 867        anyhow::ensure!(
 868            self.tool_use_limit_reached,
 869            "can only resume after tool use limit is reached"
 870        );
 871
 872        self.messages.push(Message::Resume);
 873        cx.notify();
 874
 875        log::info!("Total messages in thread: {}", self.messages.len());
 876        Ok(self.run_turn(cx))
 877    }
 878
 879    /// Sending a message results in the model streaming a response, which could include tool calls.
 880    /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
 881    /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
 882    pub fn send<T>(
 883        &mut self,
 884        id: UserMessageId,
 885        content: impl IntoIterator<Item = T>,
 886        cx: &mut Context<Self>,
 887    ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>>
 888    where
 889        T: Into<UserMessageContent>,
 890    {
 891        log::info!("Thread::send called with model: {:?}", self.model.name());
 892        self.advance_prompt_id();
 893
 894        let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
 895        log::debug!("Thread::send content: {:?}", content);
 896
 897        self.messages
 898            .push(Message::User(UserMessage { id, content }));
 899        cx.notify();
 900
 901        log::info!("Total messages in thread: {}", self.messages.len());
 902        self.run_turn(cx)
 903    }
 904
 905    fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
 906        self.cancel(cx);
 907
 908        let model = self.model.clone();
 909        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
 910        let event_stream = ThreadEventStream(events_tx);
 911        let message_ix = self.messages.len().saturating_sub(1);
 912        self.tool_use_limit_reached = false;
 913        self.running_turn = Some(RunningTurn {
 914            event_stream: event_stream.clone(),
 915            _task: cx.spawn(async move |this, cx| {
 916                log::info!("Starting agent turn execution");
 917                let turn_result: Result<()> = async {
 918                    let mut completion_intent = CompletionIntent::UserPrompt;
 919                    loop {
 920                        log::debug!(
 921                            "Building completion request with intent: {:?}",
 922                            completion_intent
 923                        );
 924                        let request = this.update(cx, |this, cx| {
 925                            this.build_completion_request(completion_intent, cx)
 926                        })?;
 927
 928                        log::info!("Calling model.stream_completion");
 929                        let mut events = model.stream_completion(request, cx).await?;
 930                        log::debug!("Stream completion started successfully");
 931
 932                        let mut tool_use_limit_reached = false;
 933                        let mut tool_uses = FuturesUnordered::new();
 934                        while let Some(event) = events.next().await {
 935                            match event? {
 936                                LanguageModelCompletionEvent::StatusUpdate(
 937                                    CompletionRequestStatus::ToolUseLimitReached,
 938                                ) => {
 939                                    tool_use_limit_reached = true;
 940                                }
 941                                LanguageModelCompletionEvent::Stop(reason) => {
 942                                    event_stream.send_stop(reason);
 943                                    if reason == StopReason::Refusal {
 944                                        this.update(cx, |this, cx| {
 945                                            this.flush_pending_message(cx);
 946                                            this.messages.truncate(message_ix);
 947                                        })?;
 948                                        return Ok(());
 949                                    }
 950                                }
 951                                event => {
 952                                    log::trace!("Received completion event: {:?}", event);
 953                                    this.update(cx, |this, cx| {
 954                                        tool_uses.extend(this.handle_streamed_completion_event(
 955                                            event,
 956                                            &event_stream,
 957                                            cx,
 958                                        ));
 959                                    })
 960                                    .ok();
 961                                }
 962                            }
 963                        }
 964
 965                        let used_tools = tool_uses.is_empty();
 966                        while let Some(tool_result) = tool_uses.next().await {
 967                            log::info!("Tool finished {:?}", tool_result);
 968
 969                            event_stream.update_tool_call_fields(
 970                                &tool_result.tool_use_id,
 971                                acp::ToolCallUpdateFields {
 972                                    status: Some(if tool_result.is_error {
 973                                        acp::ToolCallStatus::Failed
 974                                    } else {
 975                                        acp::ToolCallStatus::Completed
 976                                    }),
 977                                    raw_output: tool_result.output.clone(),
 978                                    ..Default::default()
 979                                },
 980                            );
 981                            this.update(cx, |this, _cx| {
 982                                this.pending_message()
 983                                    .tool_results
 984                                    .insert(tool_result.tool_use_id.clone(), tool_result);
 985                            })
 986                            .ok();
 987                        }
 988
 989                        if tool_use_limit_reached {
 990                            log::info!("Tool use limit reached, completing turn");
 991                            this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
 992                            return Err(language_model::ToolUseLimitReachedError.into());
 993                        } else if used_tools {
 994                            log::info!("No tool uses found, completing turn");
 995                            return Ok(());
 996                        } else {
 997                            this.update(cx, |this, cx| this.flush_pending_message(cx))?;
 998                            completion_intent = CompletionIntent::ToolResults;
 999                        }
1000                    }
1001                }
1002                .await;
1003
1004                if let Err(error) = turn_result {
1005                    log::error!("Turn execution failed: {:?}", error);
1006                    event_stream.send_error(error);
1007                } else {
1008                    log::info!("Turn execution completed successfully");
1009                }
1010
1011                this.update(cx, |this, cx| {
1012                    this.flush_pending_message(cx);
1013                    this.running_turn.take();
1014                })
1015                .ok();
1016            }),
1017        });
1018        events_rx
1019    }
1020
1021    pub fn build_system_message(&self) -> LanguageModelRequestMessage {
1022        log::debug!("Building system message");
1023        let prompt = SystemPromptTemplate {
1024            project: &self.project_context.borrow(),
1025            available_tools: self.tools.keys().cloned().collect(),
1026        }
1027        .render(&self.templates)
1028        .context("failed to build system prompt")
1029        .expect("Invalid template");
1030        log::debug!("System message built");
1031        LanguageModelRequestMessage {
1032            role: Role::System,
1033            content: vec![prompt.into()],
1034            cache: true,
1035        }
1036    }
1037
1038    /// A helper method that's called on every streamed completion event.
1039    /// Returns an optional tool result task, which the main agentic loop in
1040    /// send will send back to the model when it resolves.
1041    fn handle_streamed_completion_event(
1042        &mut self,
1043        event: LanguageModelCompletionEvent,
1044        event_stream: &ThreadEventStream,
1045        cx: &mut Context<Self>,
1046    ) -> Option<Task<LanguageModelToolResult>> {
1047        log::trace!("Handling streamed completion event: {:?}", event);
1048        use LanguageModelCompletionEvent::*;
1049
1050        match event {
1051            StartMessage { .. } => {
1052                self.flush_pending_message(cx);
1053                self.pending_message = Some(AgentMessage::default());
1054            }
1055            Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
1056            Thinking { text, signature } => {
1057                self.handle_thinking_event(text, signature, event_stream, cx)
1058            }
1059            RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
1060            ToolUse(tool_use) => {
1061                return self.handle_tool_use_event(tool_use, event_stream, cx);
1062            }
1063            ToolUseJsonParseError {
1064                id,
1065                tool_name,
1066                raw_input,
1067                json_parse_error,
1068            } => {
1069                return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
1070                    id,
1071                    tool_name,
1072                    raw_input,
1073                    json_parse_error,
1074                )));
1075            }
1076            UsageUpdate(_) | StatusUpdate(_) => {}
1077            Stop(_) => unreachable!(),
1078        }
1079
1080        None
1081    }
1082
1083    fn handle_text_event(
1084        &mut self,
1085        new_text: String,
1086        event_stream: &ThreadEventStream,
1087        cx: &mut Context<Self>,
1088    ) {
1089        event_stream.send_text(&new_text);
1090
1091        let last_message = self.pending_message();
1092        if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
1093            text.push_str(&new_text);
1094        } else {
1095            last_message
1096                .content
1097                .push(AgentMessageContent::Text(new_text));
1098        }
1099
1100        cx.notify();
1101    }
1102
1103    fn handle_thinking_event(
1104        &mut self,
1105        new_text: String,
1106        new_signature: Option<String>,
1107        event_stream: &ThreadEventStream,
1108        cx: &mut Context<Self>,
1109    ) {
1110        event_stream.send_thinking(&new_text);
1111
1112        let last_message = self.pending_message();
1113        if let Some(AgentMessageContent::Thinking { text, signature }) =
1114            last_message.content.last_mut()
1115        {
1116            text.push_str(&new_text);
1117            *signature = new_signature.or(signature.take());
1118        } else {
1119            last_message.content.push(AgentMessageContent::Thinking {
1120                text: new_text,
1121                signature: new_signature,
1122            });
1123        }
1124
1125        cx.notify();
1126    }
1127
1128    fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
1129        let last_message = self.pending_message();
1130        last_message
1131            .content
1132            .push(AgentMessageContent::RedactedThinking(data));
1133        cx.notify();
1134    }
1135
1136    fn handle_tool_use_event(
1137        &mut self,
1138        tool_use: LanguageModelToolUse,
1139        event_stream: &ThreadEventStream,
1140        cx: &mut Context<Self>,
1141    ) -> Option<Task<LanguageModelToolResult>> {
1142        cx.notify();
1143
1144        let tool = self.tools.get(tool_use.name.as_ref()).cloned();
1145        let mut title = SharedString::from(&tool_use.name);
1146        let mut kind = acp::ToolKind::Other;
1147        if let Some(tool) = tool.as_ref() {
1148            title = tool.initial_title(tool_use.input.clone());
1149            kind = tool.kind();
1150        }
1151
1152        // Ensure the last message ends in the current tool use
1153        let last_message = self.pending_message();
1154        let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
1155            if let AgentMessageContent::ToolUse(last_tool_use) = content {
1156                if last_tool_use.id == tool_use.id {
1157                    *last_tool_use = tool_use.clone();
1158                    false
1159                } else {
1160                    true
1161                }
1162            } else {
1163                true
1164            }
1165        });
1166
1167        if push_new_tool_use {
1168            event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
1169            last_message
1170                .content
1171                .push(AgentMessageContent::ToolUse(tool_use.clone()));
1172        } else {
1173            event_stream.update_tool_call_fields(
1174                &tool_use.id,
1175                acp::ToolCallUpdateFields {
1176                    title: Some(title.into()),
1177                    kind: Some(kind),
1178                    raw_input: Some(tool_use.input.clone()),
1179                    ..Default::default()
1180                },
1181            );
1182        }
1183
1184        if !tool_use.is_input_complete {
1185            return None;
1186        }
1187
1188        let Some(tool) = tool else {
1189            let content = format!("No tool named {} exists", tool_use.name);
1190            return Some(Task::ready(LanguageModelToolResult {
1191                content: LanguageModelToolResultContent::Text(Arc::from(content)),
1192                tool_use_id: tool_use.id,
1193                tool_name: tool_use.name,
1194                is_error: true,
1195                output: None,
1196            }));
1197        };
1198
1199        let fs = self.project.read(cx).fs().clone();
1200        let tool_event_stream =
1201            ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
1202        tool_event_stream.update_fields(acp::ToolCallUpdateFields {
1203            status: Some(acp::ToolCallStatus::InProgress),
1204            ..Default::default()
1205        });
1206        let supports_images = self.model.supports_images();
1207        let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
1208        log::info!("Running tool {}", tool_use.name);
1209        Some(cx.foreground_executor().spawn(async move {
1210            let tool_result = tool_result.await.and_then(|output| {
1211                if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
1212                    if !supports_images {
1213                        return Err(anyhow!(
1214                            "Attempted to read an image, but this model doesn't support it.",
1215                        ));
1216                    }
1217                }
1218                Ok(output)
1219            });
1220
1221            match tool_result {
1222                Ok(output) => LanguageModelToolResult {
1223                    tool_use_id: tool_use.id,
1224                    tool_name: tool_use.name,
1225                    is_error: false,
1226                    content: output.llm_output,
1227                    output: Some(output.raw_output),
1228                },
1229                Err(error) => LanguageModelToolResult {
1230                    tool_use_id: tool_use.id,
1231                    tool_name: tool_use.name,
1232                    is_error: true,
1233                    content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
1234                    output: None,
1235                },
1236            }
1237        }))
1238    }
1239
1240    fn handle_tool_use_json_parse_error_event(
1241        &mut self,
1242        tool_use_id: LanguageModelToolUseId,
1243        tool_name: Arc<str>,
1244        raw_input: Arc<str>,
1245        json_parse_error: String,
1246    ) -> LanguageModelToolResult {
1247        let tool_output = format!("Error parsing input JSON: {json_parse_error}");
1248        LanguageModelToolResult {
1249            tool_use_id,
1250            tool_name,
1251            is_error: true,
1252            content: LanguageModelToolResultContent::Text(tool_output.into()),
1253            output: Some(serde_json::Value::String(raw_input.to_string())),
1254        }
1255    }
1256
1257    fn pending_message(&mut self) -> &mut AgentMessage {
1258        self.pending_message.get_or_insert_default()
1259    }
1260
1261    fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
1262        let Some(mut message) = self.pending_message.take() else {
1263            return;
1264        };
1265
1266        for content in &message.content {
1267            let AgentMessageContent::ToolUse(tool_use) = content else {
1268                continue;
1269            };
1270
1271            if !message.tool_results.contains_key(&tool_use.id) {
1272                message.tool_results.insert(
1273                    tool_use.id.clone(),
1274                    LanguageModelToolResult {
1275                        tool_use_id: tool_use.id.clone(),
1276                        tool_name: tool_use.name.clone(),
1277                        is_error: true,
1278                        content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
1279                        output: None,
1280                    },
1281                );
1282            }
1283        }
1284
1285        self.messages.push(Message::Agent(message));
1286        cx.notify()
1287    }
1288
1289    pub(crate) fn build_completion_request(
1290        &self,
1291        completion_intent: CompletionIntent,
1292        cx: &mut App,
1293    ) -> LanguageModelRequest {
1294        log::debug!("Building completion request");
1295        log::debug!("Completion intent: {:?}", completion_intent);
1296        log::debug!("Completion mode: {:?}", self.completion_mode);
1297
1298        let messages = self.build_request_messages();
1299        log::info!("Request will include {} messages", messages.len());
1300
1301        let tools = if let Some(tools) = self.tools(cx).log_err() {
1302            tools
1303                .filter_map(|tool| {
1304                    let tool_name = tool.name().to_string();
1305                    log::trace!("Including tool: {}", tool_name);
1306                    Some(LanguageModelRequestTool {
1307                        name: tool_name,
1308                        description: tool.description().to_string(),
1309                        input_schema: tool
1310                            .input_schema(self.model.tool_input_format())
1311                            .log_err()?,
1312                    })
1313                })
1314                .collect()
1315        } else {
1316            Vec::new()
1317        };
1318
1319        log::info!("Request includes {} tools", tools.len());
1320
1321        let request = LanguageModelRequest {
1322            thread_id: Some(self.id.to_string()),
1323            prompt_id: Some(self.prompt_id.to_string()),
1324            intent: Some(completion_intent),
1325            mode: Some(self.completion_mode.into()),
1326            messages,
1327            tools,
1328            tool_choice: None,
1329            stop: Vec::new(),
1330            temperature: AgentSettings::temperature_for_model(self.model(), cx),
1331            thinking_allowed: true,
1332        };
1333
1334        log::debug!("Completion request built successfully");
1335        request
1336    }
1337
1338    fn tools<'a>(&'a self, cx: &'a App) -> Result<impl Iterator<Item = &'a Arc<dyn AnyAgentTool>>> {
1339        let profile = AgentSettings::get_global(cx)
1340            .profiles
1341            .get(&self.profile_id)
1342            .context("profile not found")?;
1343        let provider_id = self.model.provider_id();
1344
1345        Ok(self
1346            .tools
1347            .iter()
1348            .filter(move |(_, tool)| tool.supported_provider(&provider_id))
1349            .filter_map(|(tool_name, tool)| {
1350                if profile.is_tool_enabled(tool_name) {
1351                    Some(tool)
1352                } else {
1353                    None
1354                }
1355            })
1356            .chain(self.context_server_registry.read(cx).servers().flat_map(
1357                |(server_id, tools)| {
1358                    tools.iter().filter_map(|(tool_name, tool)| {
1359                        if profile.is_context_server_tool_enabled(&server_id.0, tool_name) {
1360                            Some(tool)
1361                        } else {
1362                            None
1363                        }
1364                    })
1365                },
1366            )))
1367    }
1368
1369    fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
1370        log::trace!(
1371            "Building request messages from {} thread messages",
1372            self.messages.len()
1373        );
1374        let mut messages = vec![self.build_system_message()];
1375        for message in &self.messages {
1376            match message {
1377                Message::User(message) => messages.push(message.to_request()),
1378                Message::Agent(message) => messages.extend(message.to_request()),
1379                Message::Resume => messages.push(LanguageModelRequestMessage {
1380                    role: Role::User,
1381                    content: vec!["Continue where you left off".into()],
1382                    cache: false,
1383                }),
1384            }
1385        }
1386
1387        if let Some(message) = self.pending_message.as_ref() {
1388            messages.extend(message.to_request());
1389        }
1390
1391        if let Some(last_user_message) = messages
1392            .iter_mut()
1393            .rev()
1394            .find(|message| message.role == Role::User)
1395        {
1396            last_user_message.cache = true;
1397        }
1398
1399        messages
1400    }
1401
1402    pub fn to_markdown(&self) -> String {
1403        let mut markdown = String::new();
1404        for (ix, message) in self.messages.iter().enumerate() {
1405            if ix > 0 {
1406                markdown.push('\n');
1407            }
1408            markdown.push_str(&message.to_markdown());
1409        }
1410
1411        if let Some(message) = self.pending_message.as_ref() {
1412            markdown.push('\n');
1413            markdown.push_str(&message.to_markdown());
1414        }
1415
1416        markdown
1417    }
1418
1419    fn advance_prompt_id(&mut self) {
1420        self.prompt_id = PromptId::new();
1421    }
1422}
1423
/// Bookkeeping for the agent turn that is currently executing (created by `run_turn`).
struct RunningTurn {
    /// Holds the task that handles agent interaction until the end of the turn.
    /// Survives across multiple requests as the model performs tool calls and
    /// we run the tools and report their results.
    // NOTE(review): fields drop in declaration order, so `_task` is dropped before
    // `event_stream`. Keep this in mind before reordering.
    _task: Task<()>,
    /// The current event stream for the running turn. Used to report a final
    /// cancellation event if we cancel the turn.
    event_stream: ThreadEventStream,
}
1433
impl RunningTurn {
    /// Consumes the turn and reports a cancellation event on its event stream.
    ///
    /// The turn, including its task, is dropped when this method returns.
    fn cancel(self) {
        log::debug!("Cancelling in progress turn");
        self.event_stream.send_canceled();
    }
}
1440
1441pub trait AgentTool
1442where
1443    Self: 'static + Sized,
1444{
1445    type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
1446    type Output: for<'de> Deserialize<'de> + Serialize + Into<LanguageModelToolResultContent>;
1447
1448    fn name(&self) -> SharedString;
1449
1450    fn description(&self) -> SharedString {
1451        let schema = schemars::schema_for!(Self::Input);
1452        SharedString::new(
1453            schema
1454                .get("description")
1455                .and_then(|description| description.as_str())
1456                .unwrap_or_default(),
1457        )
1458    }
1459
1460    fn kind(&self) -> acp::ToolKind;
1461
1462    /// The initial tool title to display. Can be updated during the tool run.
1463    fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
1464
1465    /// Returns the JSON schema that describes the tool's input.
1466    fn input_schema(&self) -> Schema {
1467        schemars::schema_for!(Self::Input)
1468    }
1469
1470    /// Some tools rely on a provider for the underlying billing or other reasons.
1471    /// Allow the tool to check if they are compatible, or should be filtered out.
1472    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
1473        true
1474    }
1475
1476    /// Runs the tool with the provided input.
1477    fn run(
1478        self: Arc<Self>,
1479        input: Self::Input,
1480        event_stream: ToolCallEventStream,
1481        cx: &mut App,
1482    ) -> Task<Result<Self::Output>>;
1483
1484    /// Emits events for a previous execution of the tool.
1485    fn replay(
1486        &self,
1487        _input: Self::Input,
1488        _output: Self::Output,
1489        _event_stream: ToolCallEventStream,
1490        _cx: &mut App,
1491    ) -> Result<()> {
1492        Ok(())
1493    }
1494
1495    fn erase(self) -> Arc<dyn AnyAgentTool> {
1496        Arc::new(Erased(Arc::new(self)))
1497    }
1498}
1499
/// Wrapper that adapts a typed [`AgentTool`] (held as `Erased<Arc<T>>`) to the type-erased
/// [`AnyAgentTool`] interface.
pub struct Erased<T>(T);
1501
/// Result of running a type-erased tool.
pub struct AgentToolOutput {
    /// Content placed in the tool result returned to the language model.
    pub llm_output: LanguageModelToolResultContent,
    /// The tool's output serialized to JSON.
    pub raw_output: serde_json::Value,
}
1506
/// Type-erased, JSON-based tool interface, used to store tools as `Arc<dyn AnyAgentTool>`.
///
/// Typed [`AgentTool`]s are adapted to it via [`AgentTool::erase`].
pub trait AnyAgentTool {
    /// Name the model uses to call the tool.
    fn name(&self) -> SharedString;
    /// Description sent to the model in the request's tool list.
    fn description(&self) -> SharedString;
    /// Kind of tool, reported in tool call events.
    fn kind(&self) -> acp::ToolKind;
    /// Initial title for a tool call, given its (possibly still incomplete) raw JSON input.
    fn initial_title(&self, input: serde_json::Value) -> SharedString;
    /// JSON schema for the tool's input, adapted to `format`.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
    /// Whether the tool may be offered to models from the given provider. Defaults to `true`.
    fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
        true
    }
    /// Runs the tool on raw JSON input, reporting progress on `event_stream`.
    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<AgentToolOutput>>;
    /// Emits events for a previous execution, given its raw JSON input and output.
    fn replay(
        &self,
        input: serde_json::Value,
        output: serde_json::Value,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Result<()>;
}
1530
1531impl<T> AnyAgentTool for Erased<Arc<T>>
1532where
1533    T: AgentTool,
1534{
1535    fn name(&self) -> SharedString {
1536        self.0.name()
1537    }
1538
1539    fn description(&self) -> SharedString {
1540        self.0.description()
1541    }
1542
1543    fn kind(&self) -> agent_client_protocol::ToolKind {
1544        self.0.kind()
1545    }
1546
1547    fn initial_title(&self, input: serde_json::Value) -> SharedString {
1548        let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
1549        self.0.initial_title(parsed_input)
1550    }
1551
1552    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
1553        let mut json = serde_json::to_value(self.0.input_schema())?;
1554        adapt_schema_to_format(&mut json, format)?;
1555        Ok(json)
1556    }
1557
1558    fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
1559        self.0.supported_provider(provider)
1560    }
1561
1562    fn run(
1563        self: Arc<Self>,
1564        input: serde_json::Value,
1565        event_stream: ToolCallEventStream,
1566        cx: &mut App,
1567    ) -> Task<Result<AgentToolOutput>> {
1568        cx.spawn(async move |cx| {
1569            let input = serde_json::from_value(input)?;
1570            let output = cx
1571                .update(|cx| self.0.clone().run(input, event_stream, cx))?
1572                .await?;
1573            let raw_output = serde_json::to_value(&output)?;
1574            Ok(AgentToolOutput {
1575                llm_output: output.into(),
1576                raw_output,
1577            })
1578        })
1579    }
1580
1581    fn replay(
1582        &self,
1583        input: serde_json::Value,
1584        output: serde_json::Value,
1585        event_stream: ToolCallEventStream,
1586        cx: &mut App,
1587    ) -> Result<()> {
1588        let input = serde_json::from_value(input)?;
1589        let output = serde_json::from_value(output)?;
1590        self.0.replay(input, output, event_stream, cx)
1591    }
1592}
1593
/// Cloneable handle to the sending half of a turn's event channel. The receiving half is
/// returned by `run_turn` to callers of `send`/`resume`.
#[derive(Clone)]
struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
1596
1597impl ThreadEventStream {
1598    fn send_user_message(&self, message: &UserMessage) {
1599        self.0
1600            .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
1601            .ok();
1602    }
1603
1604    fn send_text(&self, text: &str) {
1605        self.0
1606            .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
1607            .ok();
1608    }
1609
1610    fn send_thinking(&self, text: &str) {
1611        self.0
1612            .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
1613            .ok();
1614    }
1615
1616    fn send_tool_call(
1617        &self,
1618        id: &LanguageModelToolUseId,
1619        title: SharedString,
1620        kind: acp::ToolKind,
1621        input: serde_json::Value,
1622    ) {
1623        self.0
1624            .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
1625                id,
1626                title.to_string(),
1627                kind,
1628                input,
1629            ))))
1630            .ok();
1631    }
1632
1633    fn initial_tool_call(
1634        id: &LanguageModelToolUseId,
1635        title: String,
1636        kind: acp::ToolKind,
1637        input: serde_json::Value,
1638    ) -> acp::ToolCall {
1639        acp::ToolCall {
1640            id: acp::ToolCallId(id.to_string().into()),
1641            title,
1642            kind,
1643            status: acp::ToolCallStatus::Pending,
1644            content: vec![],
1645            locations: vec![],
1646            raw_input: Some(input),
1647            raw_output: None,
1648        }
1649    }
1650
1651    fn update_tool_call_fields(
1652        &self,
1653        tool_use_id: &LanguageModelToolUseId,
1654        fields: acp::ToolCallUpdateFields,
1655    ) {
1656        self.0
1657            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1658                acp::ToolCallUpdate {
1659                    id: acp::ToolCallId(tool_use_id.to_string().into()),
1660                    fields,
1661                }
1662                .into(),
1663            )))
1664            .ok();
1665    }
1666
1667    fn send_stop(&self, reason: StopReason) {
1668        match reason {
1669            StopReason::EndTurn => {
1670                self.0
1671                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
1672                    .ok();
1673            }
1674            StopReason::MaxTokens => {
1675                self.0
1676                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
1677                    .ok();
1678            }
1679            StopReason::Refusal => {
1680                self.0
1681                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
1682                    .ok();
1683            }
1684            StopReason::ToolUse => {}
1685        }
1686    }
1687
1688    fn send_canceled(&self) {
1689        self.0
1690            .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
1691            .ok();
1692    }
1693
1694    fn send_error(&self, error: impl Into<anyhow::Error>) {
1695        self.0.unbounded_send(Err(error.into())).ok();
1696    }
1697}
1698
/// A [`ThreadEventStream`] bound to a single tool use.
///
/// Every update sent through this type is tagged with `tool_use_id`.
#[derive(Clone)]
pub struct ToolCallEventStream {
    /// ID of the tool use that all emitted updates refer to.
    tool_use_id: LanguageModelToolUseId,
    /// Underlying thread event channel.
    stream: ThreadEventStream,
    /// Filesystem used by `authorize` to persist an "always allow" choice to the settings file.
    /// When `None` (as in `test()`), that choice is not persisted.
    fs: Option<Arc<dyn Fs>>,
}
1705
1706impl ToolCallEventStream {
1707    #[cfg(test)]
1708    pub fn test() -> (Self, ToolCallEventStreamReceiver) {
1709        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
1710
1711        let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None);
1712
1713        (stream, ToolCallEventStreamReceiver(events_rx))
1714    }
1715
1716    fn new(
1717        tool_use_id: LanguageModelToolUseId,
1718        stream: ThreadEventStream,
1719        fs: Option<Arc<dyn Fs>>,
1720    ) -> Self {
1721        Self {
1722            tool_use_id,
1723            stream,
1724            fs,
1725        }
1726    }
1727
1728    pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) {
1729        self.stream
1730            .update_tool_call_fields(&self.tool_use_id, fields);
1731    }
1732
1733    pub fn update_diff(&self, diff: Entity<acp_thread::Diff>) {
1734        self.stream
1735            .0
1736            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1737                acp_thread::ToolCallUpdateDiff {
1738                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1739                    diff,
1740                }
1741                .into(),
1742            )))
1743            .ok();
1744    }
1745
1746    pub fn update_terminal(&self, terminal: Entity<acp_thread::Terminal>) {
1747        self.stream
1748            .0
1749            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
1750                acp_thread::ToolCallUpdateTerminal {
1751                    id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1752                    terminal,
1753                }
1754                .into(),
1755            )))
1756            .ok();
1757    }
1758
1759    pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
1760        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
1761            return Task::ready(Ok(()));
1762        }
1763
1764        let (response_tx, response_rx) = oneshot::channel();
1765        self.stream
1766            .0
1767            .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization(
1768                ToolCallAuthorization {
1769                    tool_call: acp::ToolCallUpdate {
1770                        id: acp::ToolCallId(self.tool_use_id.to_string().into()),
1771                        fields: acp::ToolCallUpdateFields {
1772                            title: Some(title.into()),
1773                            ..Default::default()
1774                        },
1775                    },
1776                    options: vec![
1777                        acp::PermissionOption {
1778                            id: acp::PermissionOptionId("always_allow".into()),
1779                            name: "Always Allow".into(),
1780                            kind: acp::PermissionOptionKind::AllowAlways,
1781                        },
1782                        acp::PermissionOption {
1783                            id: acp::PermissionOptionId("allow".into()),
1784                            name: "Allow".into(),
1785                            kind: acp::PermissionOptionKind::AllowOnce,
1786                        },
1787                        acp::PermissionOption {
1788                            id: acp::PermissionOptionId("deny".into()),
1789                            name: "Deny".into(),
1790                            kind: acp::PermissionOptionKind::RejectOnce,
1791                        },
1792                    ],
1793                    response: response_tx,
1794                },
1795            )))
1796            .ok();
1797        let fs = self.fs.clone();
1798        cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
1799            "always_allow" => {
1800                if let Some(fs) = fs.clone() {
1801                    cx.update(|cx| {
1802                        update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
1803                            settings.set_always_allow_tool_actions(true);
1804                        });
1805                    })?;
1806                }
1807
1808                Ok(())
1809            }
1810            "allow" => Ok(()),
1811            _ => Err(anyhow!("Permission to run tool denied by user")),
1812        })
1813    }
1814}
1815
/// Test-only receiving end of the channel created by `ToolCallEventStream::test`.
#[cfg(test)]
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<ThreadEvent>>);
1818
#[cfg(test)]
impl ToolCallEventStreamReceiver {
    /// Awaits the next event and returns it if it is a tool call authorization request.
    ///
    /// Panics (with the received item) on any other event, an error, or a closed channel.
    pub async fn expect_authorization(&mut self) -> ToolCallAuthorization {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) => auth,
            other => panic!("Expected ToolCallAuthorization but got: {:?}", other),
        }
    }

    /// Awaits the next event and returns the terminal if it is a terminal update.
    ///
    /// Panics (with the received item) on any other event, an error, or a closed channel.
    pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
        match self.0.next().await {
            Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal(
                update,
            )))) => update.terminal,
            other => panic!("Expected terminal but got: {:?}", other),
        }
    }
}
1842
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
    type Target = mpsc::UnboundedReceiver<Result<ThreadEvent>>;

    /// Exposes the wrapped receiver so tests can use its methods directly.
    fn deref(&self) -> &Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1851
#[cfg(test)]
impl std::ops::DerefMut for ToolCallEventStreamReceiver {
    /// Mutable access to the wrapped receiver (needed for polling it).
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(receiver) = self;
        receiver
    }
}
1858
1859impl From<&str> for UserMessageContent {
1860    fn from(text: &str) -> Self {
1861        Self::Text(text.into())
1862    }
1863}
1864
1865impl From<acp::ContentBlock> for UserMessageContent {
1866    fn from(value: acp::ContentBlock) -> Self {
1867        match value {
1868            acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
1869            acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
1870            acp::ContentBlock::Audio(_) => {
1871                // TODO
1872                Self::Text("[audio]".to_string())
1873            }
1874            acp::ContentBlock::ResourceLink(resource_link) => {
1875                match MentionUri::parse(&resource_link.uri) {
1876                    Ok(uri) => Self::Mention {
1877                        uri,
1878                        content: String::new(),
1879                    },
1880                    Err(err) => {
1881                        log::error!("Failed to parse mention link: {}", err);
1882                        Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
1883                    }
1884                }
1885            }
1886            acp::ContentBlock::Resource(resource) => match resource.resource {
1887                acp::EmbeddedResourceResource::TextResourceContents(resource) => {
1888                    match MentionUri::parse(&resource.uri) {
1889                        Ok(uri) => Self::Mention {
1890                            uri,
1891                            content: resource.text,
1892                        },
1893                        Err(err) => {
1894                            log::error!("Failed to parse mention link: {}", err);
1895                            Self::Text(
1896                                MarkdownCodeBlock {
1897                                    tag: &resource.uri,
1898                                    text: &resource.text,
1899                                }
1900                                .to_string(),
1901                            )
1902                        }
1903                    }
1904                }
1905                acp::EmbeddedResourceResource::BlobResourceContents(_) => {
1906                    // TODO
1907                    Self::Text("[blob]".to_string())
1908                }
1909            },
1910        }
1911    }
1912}
1913
1914impl From<UserMessageContent> for acp::ContentBlock {
1915    fn from(content: UserMessageContent) -> Self {
1916        match content {
1917            UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
1918                text,
1919                annotations: None,
1920            }),
1921            UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
1922                data: image.source.to_string(),
1923                mime_type: "image/png".to_string(),
1924                annotations: None,
1925                uri: None,
1926            }),
1927            UserMessageContent::Mention { uri, content } => {
1928                todo!()
1929            }
1930        }
1931    }
1932}
1933
1934fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
1935    LanguageModelImage {
1936        source: image_content.data.into(),
1937        // TODO: make this optional?
1938        size: gpui::Size::new(0.into(), 0.into()),
1939    }
1940}